# Shedding Hub Data Analysis

This notebook provides two ways to work with the shedding data:
1. **Launch the Interactive Dashboard** - Run the full Dash web application
2. **Standalone Analysis** - Explore data interactively in the notebook

Choose which mode you want to use below.

## Setup and Data Loading

In [None]:
# Import required packages
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import yaml
import fsspec
import os
import glob
import re
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

print("Packages imported successfully!")

In [2]:
from shedding_hub.viz import (
        plot_time_course,
        plot_time_courses,
        plot_shedding_heatmap,
        plot_mean_trajectory,
        plot_value_distribution_by_time,
        plot_detection_probability,
        plot_clearance_curve
    )

In [None]:
# Optional GitHub credentials, read from the environment. Both names stay None
# when the variables are unset.
GITHUB_USERNAME, GITHUB_TOKEN = (
    os.environ.get(variable) for variable in ("GITHUB_USERNAME", "GITHUB_TOKEN")
)

# Only the username is echoed; the token is never printed.
print(f"GitHub Username: {GITHUB_USERNAME if GITHUB_USERNAME else 'Not set (using public access)'}")

In [None]:
# Helper functions
def key_missing(x, key):
    """Look up ``key`` in ``x``, falling back to ``'unknown'``.

    The fallback covers a key that is absent (``KeyError``) and a container
    that cannot be indexed with ``key`` at all (``TypeError``), e.g. ``None``
    or the ``'unknown'`` string returned by an earlier lookup.
    """
    try:
        value = x[key]
    except (KeyError, TypeError):
        return 'unknown'
    return value

def list_join(x):
    """Collapse a list into a single ``'; '``-separated string.

    Any value that is not a ``list`` is handed back unchanged.
    """
    return '; '.join(x) if isinstance(x, list) else x

print("Helper functions defined!")

## Download and Load Data

In [None]:
# Download YAML files from shedding-hub repository
# Local target directory: ./data under the current working directory. It is
# created if missing and reused if it already exists. Nothing here checks for
# previously downloaded files, so every run of this cell fetches again.
destination = Path.cwd() / "data"
destination.mkdir(exist_ok=True, parents=True)

# fsspec "github" filesystem for the shedding-hub/shedding-hub repository.
# username/token come from the credentials cell and are None when the
# environment variables are not set.
fs = fsspec.filesystem(
    "github",
    org="shedding-hub",
    repo="shedding-hub",
    username=GITHUB_USERNAME,
    token=GITHUB_TOKEN
)

# Copy every .yaml file matched anywhere under the repository's data/ folder
# into the local destination directory.
# NOTE(review): the loading cell globs "data/*.yaml" (top level only), so this
# presumably leaves the files directly inside data/ rather than in per-study
# subfolders - confirm against fsspec's handling of a list of source paths.
fs.get(fs.glob("data/**/*.yaml"), destination.as_posix(), recursive=True)
print(f"Data downloaded to {destination}")

In [None]:
# Load all YAML files
# The pattern is deliberately relative ("data/<study>.yaml"): the extraction
# cells derive the study ID from the second path component of each entry.
# glob.glob returns matches in arbitrary order, so the list is sorted to keep
# row order (and the ' index' columns built from it) identical between runs.
list_file = sorted(glob.glob("data/*.yaml"))

# list_yaml[i] is the parsed content of list_file[i]. The files are decoded as
# UTF-8 explicitly instead of relying on the platform's default encoding.
list_yaml = [
    yaml.safe_load(Path(file_name).read_text(encoding="utf-8"))
    for file_name in list_file
]

print(f"Loaded {len(list_yaml)} YAML files")

## Extract Analyte Data

In [None]:
# Extract analyte information: one row per (study, analyte) pair.
list_analyte = []

for file_name, study in zip(list_file, list_yaml):
    # Study ID: second component of the path once it is split on path
    # separators and dots, i.e. "<study>" for "data/<study>.yaml".
    study_id = re.split(r'[\\/\.]+', file_name)[1]
    for analyte_name, analyte_info in study['analytes'].items():
        list_analyte.append([
            study_id,
            analyte_name,
            analyte_info['biomarker'],
            # gene_target is the only field read with a fallback to 'unknown'.
            key_missing(analyte_info, 'gene_target'),
            # A list of specimens becomes one '; '-joined string.
            list_join(analyte_info['specimen']),
            analyte_info['unit'],
            analyte_info['limit_of_detection'],
            analyte_info['limit_of_quantification'],
            analyte_info['reference_event']
        ])

df_analyte = pd.DataFrame(
    list_analyte,
    columns=['ID', 'analyte', 'biomarker', 'gene_target', 'specimen', 'unit', 'LOD', 'LOQ', 'reference_event']
)
# 1-based running row number. The column name really is ' index' with a
# leading space. NOTE(review): presumably so that it sorts first in the
# dashboard tables, which order their columns with sorted() - confirm.
df_analyte[' index'] = range(1, len(df_analyte) + 1)

print(f"Extracted {len(df_analyte)} analyte records")
df_analyte.head()

## Extract Participant Data

In [None]:
# Extract participant information: one row per participant of each study.
list_participant = []
# Attributes copied into the table; each is 'unknown' when the participant has
# no 'attributes' mapping or the mapping lacks that key.
attribute_names = ['age', 'sex', 'race', 'ethnicity', 'vaccinated']

for file_name, study in zip(list_file, list_yaml):
    # Study ID: second component of the path once it is split on path
    # separators and dots, i.e. "<study>" for "data/<study>.yaml".
    study_id = re.split(r'[\\/\.]+', file_name)[1]
    # participant_ID is the 1-based position of the participant in the file.
    for participant_number, participant in enumerate(study['participants'], start=1):
        attributes = key_missing(participant, 'attributes')
        list_participant.append(
            [study_id, participant_number]
            + [key_missing(attributes, name) for name in attribute_names]
        )

df_participant = pd.DataFrame(
    list_participant,
    columns=['ID', 'participant_ID', 'age', 'sex', 'race', 'ethnicity', 'vaccinated']
)
# 1-based running row number; the column name keeps its leading space.
df_participant[' index'] = range(1, len(df_participant) + 1)

print(f"Extracted {len(df_participant)} participant records")
df_participant.head()

## Extract Measurement Data

In [None]:
# Extract measurement data: one row per measurement of each participant.
list_measurement = []

for file_name, study in zip(list_file, list_yaml):
    # Study ID: second component of the path once it is split on path
    # separators and dots, i.e. "<study>" for "data/<study>.yaml".
    study_id = re.split(r'[\\/\.]+', file_name)[1]
    # participant_ID and measurement_ID are 1-based positions within the file.
    for participant_number, participant in enumerate(study['participants'], start=1):
        for measurement_number, measurement in enumerate(participant['measurements'], start=1):
            list_measurement.append([
                study_id,
                participant_number,
                measurement_number,
                # Each field falls back to 'unknown' when absent.
                key_missing(measurement, 'analyte'),
                key_missing(measurement, 'time'),
                key_missing(measurement, 'value')
            ])

df_measurement = pd.DataFrame(
    list_measurement,
    columns=['ID', 'participant_ID', 'measurement_ID', 'analyte', 'time', 'value']
)
# 1-based running row number; the column name keeps its leading space.
df_measurement[' index'] = range(1, len(df_measurement) + 1)

print(f"Extracted {len(df_measurement)} measurement records")
df_measurement.head()

## Data Summary

In [None]:
# Print the distinct values of the three categorical analyte columns that the
# rest of the notebook filters on.
summary_sections = (
    ("Available Biomarkers:", 'biomarker'),
    ("\nAvailable Specimens:", 'specimen'),
    ("\nAvailable Gene Targets:", 'gene_target'),
)

for heading, column_name in summary_sections:
    print(heading)
    print(df_analyte[column_name].unique())

---

# Choose Your Analysis Mode

You now have two options:

**Option 1**: Run the full interactive Dash dashboard (see next cell)  
**Option 2**: Do standalone analysis in the notebook (skip ahead to the *Option 2: Standalone Interactive Analysis* section)

---

# Option 1: Launch Interactive Dashboard

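Run the cell below to launch the full Dash dashboard. Once the server is running, the dashboard is available at http://localhost:8050/

**Prerequisite:** The plot callbacks call `create_viral_load_plot` and `create_ct_plot`, which are defined under *Interactive Visualization Functions* in the Option 2 section below. Run those two definition cells before launching the dashboard; otherwise the plots fail with a `NameError`.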

**Note:** The dashboard will run until you stop the cell (use the stop button in Jupyter or Ctrl+C in terminal).

In [None]:
# Run this cell to launch the Dash dashboard
# The dashboard will open at http://localhost:8050/ or http://127.0.0.1:8050/
# To stop the server, click the stop button or press Ctrl+C

from dash import Dash, html, dash_table, dcc, callback, Output, Input

# Get unique values for dropdowns
# list_biomarker supplies the options of the biomarker dropdown built in
# generate_control_card().
list_biomarker = df_analyte["biomarker"].unique()
# list_specimen is not referenced again in this cell: the specimen options are
# recomputed per biomarker by the update_specimen_dropdown callback.
list_specimen = df_analyte["specimen"].unique()

# Functions to generate dashboard components
def description_card():
    """Build the introductory card (heading, welcome line and short blurb)."""
    heading = html.H5("Shedding Information Analytics")
    welcome = html.H3("Welcome to the Shedding Hub Dashboard")
    blurb = html.Div(
        id="intro",
        children="Explore the pathogen/biomarker shedding in different human specimens.",
    )
    return html.Div(id="description-card", children=[heading, welcome, blurb])

def generate_control_card():
    """Build the card holding the three filter dropdowns and the reset button.

    Only the biomarker dropdown gets its options here (from ``list_biomarker``);
    the specimen and gene dropdowns start empty and are filled by callbacks.
    """
    biomarker_dropdown = dcc.Dropdown(
        id="biomarker-select",
        options=[{"label": name, "value": name} for name in list_biomarker],
        value="SARS-CoV-2",
    )
    specimen_dropdown = dcc.Dropdown(id="specimen-select")
    gene_dropdown = dcc.Dropdown(id="gene-select", multi=True)
    # NOTE(review): no callback in this cell listens to 'reset-btn' - confirm
    # whether the button is meant to be wired up.
    reset_button = html.Div(
        id="reset-btn-outer",
        children=html.Button(id="reset-btn", children="Reset", n_clicks=0),
    )

    return html.Div(
        id="control-card",
        children=[
            html.P("Select Biomarker"),
            biomarker_dropdown,
            html.Br(),
            html.P("Select Specimen"),
            specimen_dropdown,
            html.Br(),
            html.Br(),
            html.P("Select Gene Targets"),
            gene_dropdown,
            html.Br(),
            reset_button,
        ],
    )

# Initialize the app
# The only page-level customisation is a viewport <meta> tag
# (device width, initial scale 1).
app = Dash(
    __name__,
    meta_tags=[
        {"name": "viewport", "content": "width=device-width, initial-scale=1"}
    ]
)
# Title assigned to the Dash application.
app.title = "Shedding Hub Dashboard"

# App layout
def _paged_table(table_id, frame):
    """Return a DataTable for ``frame`` whose paging and sorting are 'custom'.

    'custom' means the table does not page or sort by itself: a callback
    listening to page_current / page_size / sort_by has to supply the rows.
    Columns are listed in sorted() order of the frame's column names.
    """
    return dash_table.DataTable(
        id=table_id,
        columns=[
            {'name': i, 'id': i, 'deletable': True} for i in sorted(frame.columns)
        ],
        page_current=0,
        page_size=10,
        page_action='custom',
        sort_action='custom',
        sort_mode='single',
        sort_by=[]
    )


def _graph_row(left_graph_id, right_graph_id):
    """Return one row with two graphs side by side, each 48% wide."""
    return html.Div([
        html.Div(dcc.Graph(id=left_graph_id), style={'width': '48%', 'display': 'inline-block'}),
        html.Div(dcc.Graph(id=right_graph_id), style={'width': '48%', 'display': 'inline-block'}),
    ])


app.layout = html.Div(
    id="app-container",
    children=[
        # Banner
        html.Div(
            id="banner",
            className="banner",
            children=[html.H1("Shedding Hub Dashboard")],
        ),
        # Left column
        html.Div(
            id="left-column",
            className="three columns",
            children=[description_card(), generate_control_card()],
            # Dash style dictionaries use camelCased CSS property names.
            style={'width': '25%', 'display': 'inline-block', 'verticalAlign': 'top'}
        ),
        # Right column
        html.Div(
            id="right-column",
            className="nine columns",
            children=[
                # Tables
                html.Div(
                    id="tables",
                    children=[
                        html.B("Data Tables"),
                        html.Hr(),
                        dcc.Tabs([
                            dcc.Tab(label='Participant Information', children=[
                                _paged_table('table-participant-paging-and-sorting', df_participant),
                            ]),
                            dcc.Tab(label='Analyte Information', children=[
                                _paged_table('table-analyte-paging-and-sorting', df_analyte),
                            ]),
                            dcc.Tab(label='Measurement', children=[
                                _paged_table('table-measurement-paging-and-sorting', df_measurement),
                            ]),
                        ]),
                    ],
                ),
                # Plots: one row per reference event, viral load on the left
                # and the matching Ct plot on the right.
                html.Div(
                    id="plots",
                    children=[
                        html.Hr(),
                        _graph_row('scatter_plot_symptom_onset', 'scatter_plot_symptom_onset_ct'),
                        html.Hr(),
                        _graph_row('scatter_plot_confirmation', 'scatter_plot_confirmation_ct'),
                        html.Hr(),
                        _graph_row('scatter_plot_enrollment', 'scatter_plot_enrollment_ct'),
                    ],
                ),
            ],
            style={'width': '70%', 'display': 'inline-block', 'verticalAlign': 'top'}
        ),
    ],
)

# Callbacks for tables
def _table_page(frame, page_current, page_size, sort_by):
    """Return the records of one page of ``frame``, sorted first if requested.

    Parameters:
    - frame: DataFrame backing the table
    - page_current: zero-based page number
    - page_size: number of rows per page
    - sort_by: DataTable sort spec; only its first entry is used
      ('column_id' and 'direction')

    Returns a list of row dicts (``to_dict('records')``).
    """
    if sort_by:
        column = sort_by[0]['column_id']
        ascending = sort_by[0]['direction'] == 'asc'
        # Columns such as 'value' or 'age' mix numbers with strings
        # ('negative', 'positive', 'unknown'). Sorting such a column directly
        # raises TypeError, so text and non-text cells are sorted separately:
        # non-text first when ascending, last when descending.
        is_text = frame[column].map(lambda cell: isinstance(cell, str)).astype(bool)
        non_text_rows = frame[~is_text].sort_values(column, ascending=ascending)
        text_rows = frame[is_text].sort_values(column, ascending=ascending)
        ordered_parts = [non_text_rows, text_rows] if ascending else [text_rows, non_text_rows]
        frame = pd.concat(ordered_parts)
    first_row = page_current * page_size
    return frame.iloc[first_row:first_row + page_size].to_dict('records')


@callback(
    Output('table-participant-paging-and-sorting', 'data'),
    Input('table-participant-paging-and-sorting', "page_current"),
    Input('table-participant-paging-and-sorting', "page_size"),
    Input('table-participant-paging-and-sorting', 'sort_by'),
)
def update_participant_table(page_current, page_size, sort_by):
    """Supply the current page of the participant table."""
    return _table_page(df_participant, page_current, page_size, sort_by)


@callback(
    Output('table-analyte-paging-and-sorting', 'data'),
    Input('table-analyte-paging-and-sorting', "page_current"),
    Input('table-analyte-paging-and-sorting', "page_size"),
    Input('table-analyte-paging-and-sorting', 'sort_by'),
)
def update_analyte_table(page_current, page_size, sort_by):
    """Supply the current page of the analyte table."""
    return _table_page(df_analyte, page_current, page_size, sort_by)


@callback(
    Output('table-measurement-paging-and-sorting', 'data'),
    Input('table-measurement-paging-and-sorting', "page_current"),
    Input('table-measurement-paging-and-sorting', "page_size"),
    Input('table-measurement-paging-and-sorting', 'sort_by'),
)
def update_measurement_table(page_current, page_size, sort_by):
    """Supply the current page of the measurement table."""
    return _table_page(df_measurement, page_current, page_size, sort_by)

# Callbacks for dropdowns
@app.callback(
    [Output('specimen-select', 'options'),
     Output('specimen-select', 'value')],
    Input('biomarker-select', 'value')
)
def update_specimen_dropdown(selected_biomarker):
    """Refresh the specimen dropdown for the chosen biomarker.

    Returns ``(options, value)``: the specimens recorded for the biomarker and
    the first of them as the selection. When no biomarker is chosen, or the
    chosen one has no analytes in ``df_analyte`` (possible for the hard-coded
    default of the biomarker dropdown), the dropdown is emptied.
    """
    if selected_biomarker:
        items = df_analyte["specimen"][df_analyte["biomarker"]==selected_biomarker].unique()
        # Guard items[0]: an unmatched biomarker yields an empty array.
        if len(items):
            return [{'label': item, 'value': item} for item in items], items[0]
    # 'specimen-select' is a single-select dropdown, so "nothing selected" is
    # None rather than an empty list.
    return [], None

@app.callback(
    [Output('gene-select', 'options'),
     Output('gene-select', 'value')],
    Input('biomarker-select', 'value')
)
def update_gene_dropdown(selected_biomarker):
    """Refresh the gene-target dropdown for the chosen biomarker.

    Returns ``(options, value)`` where every available gene target is also
    pre-selected; both are empty lists when no biomarker is chosen.
    """
    if not selected_biomarker:
        return [], []
    matches_biomarker = df_analyte["biomarker"] == selected_biomarker
    gene_targets = df_analyte.loc[matches_biomarker, "gene_target"].unique()
    options = [{'label': gene, 'value': gene} for gene in gene_targets]
    return options, list(gene_targets)

# Callbacks for plots. They call create_viral_load_plot / create_ct_plot, which
# are defined in the "Interactive Visualization Functions" cells further down
# the notebook; those cells must have been run before a callback fires.
def _figure_or_blank(build_figure, selected_biomarker, selected_specimen, selected_gene, reference_event):
    """Build one figure, or return ``{}`` when there is nothing to plot.

    ``{}`` is returned when no gene target is selected and when
    ``build_figure`` itself returns nothing (it returns None for "no data").
    """
    if not selected_gene:
        return {}
    fig = build_figure(selected_biomarker, selected_specimen, selected_gene, reference_event)
    return fig if fig else {}


@callback(
    Output('scatter_plot_symptom_onset', 'figure'),
    Input('biomarker-select', 'value'),
    Input('specimen-select', 'value'),
    Input('gene-select', 'value'),
)
def update_symptom_onset_plot(selected_biomarker, selected_specimen, selected_gene):
    """Viral load relative to symptom onset."""
    return _figure_or_blank(
        create_viral_load_plot, selected_biomarker, selected_specimen, selected_gene, 'symptom onset'
    )


@callback(
    Output('scatter_plot_confirmation', 'figure'),
    Input('biomarker-select', 'value'),
    Input('specimen-select', 'value'),
    Input('gene-select', 'value'),
)
def update_confirmation_plot(selected_biomarker, selected_specimen, selected_gene):
    """Viral load relative to the confirmation date."""
    return _figure_or_blank(
        create_viral_load_plot, selected_biomarker, selected_specimen, selected_gene, 'confirmation date'
    )


@callback(
    Output('scatter_plot_enrollment', 'figure'),
    Input('biomarker-select', 'value'),
    Input('specimen-select', 'value'),
    Input('gene-select', 'value'),
)
def update_enrollment_plot(selected_biomarker, selected_specimen, selected_gene):
    """Viral load relative to enrollment."""
    return _figure_or_blank(
        create_viral_load_plot, selected_biomarker, selected_specimen, selected_gene, 'enrollment'
    )


@callback(
    Output('scatter_plot_symptom_onset_ct', 'figure'),
    Input('biomarker-select', 'value'),
    Input('specimen-select', 'value'),
    Input('gene-select', 'value'),
)
def update_symptom_onset_ct_plot(selected_biomarker, selected_specimen, selected_gene):
    """Ct values relative to symptom onset."""
    return _figure_or_blank(
        create_ct_plot, selected_biomarker, selected_specimen, selected_gene, 'symptom onset'
    )


@callback(
    Output('scatter_plot_confirmation_ct', 'figure'),
    Input('biomarker-select', 'value'),
    Input('specimen-select', 'value'),
    Input('gene-select', 'value'),
)
def update_confirmation_ct_plot(selected_biomarker, selected_specimen, selected_gene):
    """Ct values relative to the confirmation date."""
    return _figure_or_blank(
        create_ct_plot, selected_biomarker, selected_specimen, selected_gene, 'confirmation date'
    )


@callback(
    Output('scatter_plot_enrollment_ct', 'figure'),
    Input('biomarker-select', 'value'),
    Input('specimen-select', 'value'),
    Input('gene-select', 'value'),
)
def update_enrollment_ct_plot(selected_biomarker, selected_specimen, selected_gene):
    """Ct values relative to enrollment."""
    return _figure_or_blank(
        create_ct_plot, selected_biomarker, selected_specimen, selected_gene, 'enrollment'
    )

# Run the app
# The plot callbacks look up create_viral_load_plot / create_ct_plot when they
# fire. Those functions are defined in later cells of this notebook, so tell
# the user up front when they are not available yet instead of leaving the
# problem to surface as a NameError inside a callback.
_missing_plot_functions = [
    name for name in ("create_viral_load_plot", "create_ct_plot") if name not in globals()
]
if _missing_plot_functions:
    print(
        "Note: " + ", ".join(_missing_plot_functions) + " not defined yet. "
        "Run the cells under 'Interactive Visualization Functions' first, "
        "otherwise the plots will fail to render."
    )

print("Starting Dash server...")
print("Dashboard will be available at: http://127.0.0.1:8050/")
print("Press the Stop button to stop the server")
app.run(debug=True, port=8050)

# Option 2: Standalone Interactive Analysis

Use the cells below to explore the data directly in the notebook without running the dashboard.

## Interactive Visualization Functions

In [None]:
def create_viral_load_plot(selected_biomarker, selected_specimen, selected_genes, reference_event='symptom onset'):
    """
    Create a viral load scatter plot for the selected parameters.

    Reads the notebook-level ``df_analyte`` and ``df_measurement`` frames and
    plots every measurement of the matching analytes whose unit is not
    "cycle threshold", on a logarithmic y axis, coloured by study.

    Qualitative results carry no number, so they are drawn at a stand-in value:
    the analyte's LOQ if numeric, otherwise its LOD, otherwise the smallest
    numeric value measured for that analyte. Their status is kept distinct:
    - "positive": quantified value (filled circle, shown in the legend)
    - "positive (not quantified)": value was "positive" (open circle)
    - "negative": value was "negative" (x)

    Parameters:
    - selected_biomarker: str, e.g., 'SARS-CoV-2'
    - selected_specimen: str, e.g., 'stool'
    - selected_genes: list of str, e.g., ['N1', 'N2']; a single str is accepted
      and treated as a one-element list
    - reference_event: str, 'symptom onset', 'confirmation date', or 'enrollment'

    Returns:
    - the plotly figure, or None (after printing a message) when no analyte or
      no measurement matches the selection
    """
    # A bare string would make ``isin`` raise; treat it as one gene target.
    if isinstance(selected_genes, str):
        selected_genes = [selected_genes]

    # Filter data
    filtered_df_analyte = df_analyte.loc[
        (df_analyte["biomarker"] == selected_biomarker) &
        (df_analyte["specimen"] == selected_specimen) &
        (df_analyte["gene_target"].isin(selected_genes)) &
        (df_analyte["reference_event"] == reference_event) &
        (df_analyte["unit"] != "cycle threshold")
    ]

    if len(filtered_df_analyte) == 0:
        print(f"No data found for {selected_biomarker} in {selected_specimen} with genes {selected_genes}")
        return None

    # Keep the measurements of the selected analytes and attach each analyte's
    # limits. Joining on the (ID, analyte) columns avoids building string keys,
    # and keeps analytes that only have qualitative results.
    filtered_df_measurement = df_measurement.merge(
        filtered_df_analyte[["ID", "analyte", "LOD", "LOQ"]],
        on=["ID", "analyte"]
    )

    if filtered_df_measurement.empty:
        print(f"No data found for {selected_biomarker} in {selected_specimen} with genes {selected_genes}")
        return None

    is_negative = filtered_df_measurement["value"] == "negative"
    is_unquantified_positive = filtered_df_measurement["value"] == "positive"
    is_qualitative = is_negative | is_unquantified_positive

    # Anything that is not a number ('negative', 'positive', 'unknown')
    # becomes NaN here.
    numeric_value = pd.to_numeric(filtered_df_measurement["value"], errors="coerce")

    # Stand-in for qualitative results: LOQ, else LOD, else the smallest value
    # measured for the same study and analyte. 'unknown' limits coerce to NaN.
    smallest_measured = numeric_value.groupby(
        [filtered_df_measurement["ID"], filtered_df_measurement["analyte"]]
    ).transform("min")
    limit_of_quantification = pd.to_numeric(filtered_df_measurement["LOQ"], errors="coerce")
    limit_of_detection = pd.to_numeric(filtered_df_measurement["LOD"], errors="coerce")
    filtered_df_measurement["neg_replacement_value"] = (
        limit_of_quantification.fillna(limit_of_detection).fillna(smallest_measured)
    )

    # Quantified rows keep their own value; qualitative rows get the stand-in.
    filtered_df_measurement["value_w_replacement"] = numeric_value.where(
        ~is_qualitative, filtered_df_measurement["neg_replacement_value"]
    )

    filtered_df_measurement["pos"] = "positive"
    filtered_df_measurement.loc[is_negative, "pos"] = "negative"
    filtered_df_measurement.loc[is_unquantified_positive, "pos"] = "positive (not quantified)"

    # Create plot
    fig = px.scatter(
        filtered_df_measurement,
        x='time',
        y='value_w_replacement',
        log_y=True,
        color="ID",
        symbol="pos",
        labels={
            "time": f"Days after {reference_event.capitalize()}",
            "value_w_replacement": "Viral Load (gc/mL or gc/gram or gc/swab)",
            "ID": "Study"
        },
        title=f"{selected_biomarker} Shedding in {selected_specimen.capitalize()} - {reference_event.capitalize()}"
    )

    # Style traces
    marker_by_status = {
        "positive": "circle",
        "negative": "x",
        "positive (not quantified)": "circle-open",
    }

    def style_trace(trace):
        # Each trace is named "<study>, <status>" (colour group, symbol group).
        study, status = trace.name.rsplit(",", 1)
        status = status.strip()
        if status in marker_by_status:
            trace.marker.symbol = marker_by_status[status]
        # One legend entry per study: only the quantified trace is listed, and
        # the shared legend group lets it toggle the study's other traces too.
        trace.showlegend = status == "positive"
        trace.name = study
        trace.legendgroup = study

    fig.for_each_trace(style_trace)
    fig.update_layout(
        legend_title_text='Study',
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    return fig

In [None]:
def create_ct_plot(selected_biomarker, selected_specimen, selected_genes, reference_event='symptom onset'):
    """
    Create a CT value scatter plot for the selected parameters.

    Reads the notebook-level ``df_analyte`` and ``df_measurement`` frames and
    plots every measurement of the matching analytes whose unit is
    "cycle threshold", coloured by study, with the y axis reversed.

    Qualitative results carry no number, so they are drawn at a stand-in value:
    the analyte's LOQ if numeric, otherwise its LOD, otherwise the largest
    numeric value measured for that analyte. Their status is kept distinct:
    - "positive": quantified value (filled circle, shown in the legend)
    - "positive (not quantified)": value was "positive" (open circle)
    - "negative": value was "negative" (x)

    Parameters:
    - selected_biomarker: str, e.g., 'SARS-CoV-2'
    - selected_specimen: str, e.g., 'stool'
    - selected_genes: list of str, e.g., ['N1', 'N2']; a single str is accepted
      and treated as a one-element list
    - reference_event: str, 'symptom onset', 'confirmation date', or 'enrollment'

    Returns:
    - the plotly figure, or None (after printing a message) when no analyte or
      no measurement matches the selection
    """
    # A bare string would make ``isin`` raise; treat it as one gene target.
    if isinstance(selected_genes, str):
        selected_genes = [selected_genes]

    # Filter data for CT values
    filtered_df_analyte = df_analyte.loc[
        (df_analyte["biomarker"] == selected_biomarker) &
        (df_analyte["specimen"] == selected_specimen) &
        (df_analyte["gene_target"].isin(selected_genes)) &
        (df_analyte["reference_event"] == reference_event) &
        (df_analyte["unit"] == "cycle threshold")
    ]

    if len(filtered_df_analyte) == 0:
        print(f"No CT data found for {selected_biomarker} in {selected_specimen} with genes {selected_genes}")
        return None

    # Keep the measurements of the selected analytes and attach each analyte's
    # limits. Joining on the (ID, analyte) columns avoids building string keys,
    # and keeps analytes that only have qualitative results.
    filtered_df_measurement = df_measurement.merge(
        filtered_df_analyte[["ID", "analyte", "LOD", "LOQ"]],
        on=["ID", "analyte"]
    )

    if filtered_df_measurement.empty:
        print(f"No CT data found for {selected_biomarker} in {selected_specimen} with genes {selected_genes}")
        return None

    is_negative = filtered_df_measurement["value"] == "negative"
    is_unquantified_positive = filtered_df_measurement["value"] == "positive"
    is_qualitative = is_negative | is_unquantified_positive

    # Anything that is not a number ('negative', 'positive', 'unknown')
    # becomes NaN here.
    numeric_value = pd.to_numeric(filtered_df_measurement["value"], errors="coerce")

    # Stand-in for qualitative results: LOQ, else LOD, else the largest value
    # measured for the same study and analyte. 'unknown' limits coerce to NaN.
    largest_measured = numeric_value.groupby(
        [filtered_df_measurement["ID"], filtered_df_measurement["analyte"]]
    ).transform("max")
    limit_of_quantification = pd.to_numeric(filtered_df_measurement["LOQ"], errors="coerce")
    limit_of_detection = pd.to_numeric(filtered_df_measurement["LOD"], errors="coerce")
    filtered_df_measurement["neg_replacement_value"] = (
        limit_of_quantification.fillna(limit_of_detection).fillna(largest_measured)
    )

    # Quantified rows keep their own value; qualitative rows get the stand-in.
    filtered_df_measurement["value_w_replacement"] = numeric_value.where(
        ~is_qualitative, filtered_df_measurement["neg_replacement_value"]
    )

    filtered_df_measurement["pos"] = "positive"
    filtered_df_measurement.loc[is_negative, "pos"] = "negative"
    filtered_df_measurement.loc[is_unquantified_positive, "pos"] = "positive (not quantified)"

    # Create plot
    fig = px.scatter(
        filtered_df_measurement,
        x='time',
        y='value_w_replacement',
        log_y=True,
        color="ID",
        symbol="pos",
        labels={
            "time": f"Days after {reference_event.capitalize()}",
            "value_w_replacement": "Ct value",
            "ID": "Study"
        },
        title=f"{selected_biomarker} CT Values in {selected_specimen.capitalize()} - {reference_event.capitalize()}"
    )

    # Style traces
    marker_by_status = {
        "positive": "circle",
        "negative": "x",
        "positive (not quantified)": "circle-open",
    }

    def style_trace(trace):
        # Each trace is named "<study>, <status>" (colour group, symbol group).
        study, status = trace.name.rsplit(",", 1)
        status = status.strip()
        if status in marker_by_status:
            trace.marker.symbol = marker_by_status[status]
        # One legend entry per study: only the quantified trace is listed, and
        # the shared legend group lets it toggle the study's other traces too.
        trace.showlegend = status == "positive"
        trace.name = study
        trace.legendgroup = study

    fig.for_each_trace(style_trace)
    fig.update_layout(
        legend_title_text='Study',
        # Reversed axis: smaller Ct values are drawn higher up.
        yaxis_autorange="reversed",
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    return fig

## Example Visualizations

### Modify the parameters below to explore different biomarkers, specimens, and gene targets

In [None]:
# Parameters to explore: edit these two values, then re-run the plotting cells.
biomarker, specimen = "SARS-CoV-2", "stool"

# Every gene target recorded for the chosen biomarker (across all specimens).
matches_biomarker = df_analyte["biomarker"] == biomarker
available_genes = df_analyte.loc[matches_biomarker, "gene_target"].unique()
print(f"Available gene targets for {biomarker}: {available_genes}")

# Default selection: all gene targets except the 'unknown' placeholder
# (modify as needed).
gene_targets = list(filter(lambda gene: gene != 'unknown', available_genes))

### Viral Load Plots

In [None]:
# Symptom Onset - Viral Load
# create_viral_load_plot returns None when nothing matches the selection.
fig1 = create_viral_load_plot(biomarker, specimen, gene_targets, 'symptom onset')
if fig1 is not None:
    fig1.show()

In [None]:
# Confirmation Date - Viral Load
# create_viral_load_plot returns None when nothing matches the selection.
fig2 = create_viral_load_plot(biomarker, specimen, gene_targets, 'confirmation date')
if fig2 is not None:
    fig2.show()

In [None]:
# Enrollment - Viral Load
# create_viral_load_plot returns None when nothing matches the selection.
fig3 = create_viral_load_plot(biomarker, specimen, gene_targets, 'enrollment')
if fig3 is not None:
    fig3.show()

### CT Value Plots

In [None]:
# Symptom Onset - CT Values
# create_ct_plot returns None when nothing matches the selection.
fig4 = create_ct_plot(biomarker, specimen, gene_targets, 'symptom onset')
if fig4 is not None:
    fig4.show()

In [None]:
# Confirmation Date - CT Values
# create_ct_plot returns None when nothing matches the selection.
fig5 = create_ct_plot(biomarker, specimen, gene_targets, 'confirmation date')
if fig5 is not None:
    fig5.show()

In [None]:
# Enrollment - CT Values
# create_ct_plot returns None when nothing matches the selection.
fig6 = create_ct_plot(biomarker, specimen, gene_targets, 'enrollment')
if fig6 is not None:
    fig6.show()

## Data Export

Export the processed DataFrames to CSV files in the current working directory for further analysis.

In [None]:
# Export to CSV
# Each frame is written next to the notebook without the pandas row index.
for frame, csv_name in (
    (df_analyte, 'analyte_data.csv'),
    (df_participant, 'participant_data.csv'),
    (df_measurement, 'measurement_data.csv'),
):
    frame.to_csv(csv_name, index=False)

print("Data exported to CSV files:")
print("- analyte_data.csv")
print("- participant_data.csv")
print("- measurement_data.csv")