In [145]:
#| label: fig4cell

import plotly.express as px
from plotly.offline import plot
from IPython.core.display import HTML
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import statsmodels.formula.api as smf
from statsmodels.formula.api import ols
from statsmodels.stats.anova import anova_lm

# Create the 6 x 4 grid of subplots that every box plot below is drawn into.
fig = make_subplots(
    rows=6,
    cols=4,
    vertical_spacing=0.05,    # gap between rows
    horizontal_spacing=0.09,  # gap between columns
)

# Colours are looked up by group position (CTRL, low, med, adv) in
# add_boxplot_for_subplot: the first list fills the boxes, the second colours
# the individual data points.
palette_boxplots = [
    'steelblue',
    '#F0B0B0',
    'lightcoral',
    '#B4464F',
]
# NOTE(review): the 'low' and 'med' point colours are identical — confirm
# this is intended.
palette_points = [
    '#00517F',
    '#B4464F',
    '#B4464F',
    '#5E000E',
]


# Add traces with color palette in each subplot

# Dictionaries holding one filtered dataframe per metric, for white matter,
# gray matter and the WM/GM ratio respectively.
df_WM = {}
df_GM = {}
df_WMGM = {}

# Directory containing one CSV per metric.
DATA_DIR = '../data/parkinsons-spinalcord-mri-metrics/data'
# Value of the 'VertLevel' column to keep.
# NOTE(review): presumably the C2-C5 average — confirm against the CSV files.
VERT_LEVEL = '2:05'


def _select_tissue(df, label):
    """Return the rows of *df* whose 'Label' is *label* at VERT_LEVEL."""
    return df[(df['Label'] == label) & (df['VertLevel'] == VERT_LEVEL)]


metrics_in_WM = ['FA', 'MD', 'AD', 'RD', 'ODI', 'FISO', 'FICVF', 'MTR']
metrics_in_GM = ['ODI', 'FISO', 'FICVF']

# Read every CSV once, even when the metric is needed for both tissues.
# dict.fromkeys de-duplicates while keeping the listing order.
for metric in dict.fromkeys(metrics_in_WM + metrics_in_GM):
    df = pd.read_csv(f'{DATA_DIR}/{metric}.csv')
    if metric in metrics_in_WM:
        df_WM[metric] = _select_tissue(df, 'white matter')
    if metric in metrics_in_GM:
        df_GM[metric] = _select_tissue(df, 'gray matter')

# T2*: keep both tissues and derive the WM/GM ratio of the 'WA' column.
df = pd.read_csv(f'{DATA_DIR}/T2star.csv')
df_WM['T2star'] = _select_tissue(df, 'white matter')
df_GM['T2star'] = _select_tissue(df, 'gray matter')

# Work on a copy so the white-matter dataframe keeps its original 'WA' values.
t2star_ratio = df_WM['T2star'].copy()
# .values drops the index, so the division is positional.
# NOTE(review): this assumes the WM and GM rows appear in the same order in
# the CSV — confirm.
t2star_ratio['WA'] = df_WM['T2star']['WA'] / df_GM['T2star']['WA'].values
df_WMGM['T2star'] = t2star_ratio

def add_boxplot_for_subplot(data: pd.DataFrame, row: int, col: int) -> None:
    """Fit a group + age model on *data* and draw its box plots in one subplot.

    An OLS model ``WA ~ C(UPDRS_class_bis) + Age`` is fitted; its summary and
    the Type II ANOVA table are printed. One box per group (with all points
    shown) is then added to the module-level ``fig`` at (*row*, *col*), and
    the ANOVA p-values of the group and age terms are annotated once in the
    subplot, followed by an asterisk when below 0.05.

    Args:
        data: Rows for one metric; must provide the ``WA``,
            ``UPDRS_class_bis`` and ``Age`` columns.
        row: Subplot row in ``fig``.
        col: Subplot column in ``fig``.
    """
    groups = ['CTRL', 'low', 'med', 'adv']

    # OLS analysis; C() marks UPDRS_class_bis as a categorical variable.
    ols_results = smf.ols(formula='WA ~ C(UPDRS_class_bis) + Age', data=data).fit()
    # The function runs for every metric and tissue, so the report is labelled
    # with the subplot position instead of one hard-coded metric name.
    print(f'OLS results for subplot (row {row}, col {col}): {ols_results.summary()}')

    # Type II ANOVA on the fitted model.
    anova_results = anova_lm(ols_results, typ=2)
    print(f'ANOVA results : {anova_results}')

    pvalue_UPDRS_class_bis = anova_results.loc['C(UPDRS_class_bis)', 'PR(>F)']
    pvalue_age = anova_results.loc['Age', 'PR(>F)']

    for i, group in enumerate(groups):
        # Modulo keeps the lookup valid if groups ever outnumber the colours.
        box_color = palette_boxplots[i % len(palette_boxplots)]
        points_color = palette_points[i % len(palette_points)]
        in_group = data['UPDRS_class_bis'] == group

        fig.add_trace(go.Box(
            x=data.loc[in_group, 'UPDRS_class_bis'],
            y=data.loc[in_group, 'WA'],
            boxpoints='all',  # show every data point
            jitter=0.7,  # spread the points horizontally for visibility
            whiskerwidth=0.8,
            fillcolor=box_color,
            marker_size=2.5,
            marker_color=points_color,
            marker_opacity=0.8,
            line_width=1,  # box border width
            line_color="black",  # box border colour
            pointpos=0,  # centre the points on the box
        ), row=row, col=col)

    # The annotation depends only on the model, not on the group, so it is
    # added once per subplot rather than once per box.
    red_asterisk = '<span style="color:red; font-size:14">*</span>' if pvalue_UPDRS_class_bis < 0.05 else ""
    black_asterisk = '<span style="color:black; font-size:14">*</span>' if pvalue_age < 0.05 else ""

    fig.add_annotation(
        x=0.66,
        y=0.99,
        text=f"p-Group: {pvalue_UPDRS_class_bis:.4f}{red_asterisk}<br>p-Age: {pvalue_age:.4f}{black_asterisk}",
        showarrow=False,
        font=dict(size=11),
        align="right",
        row=row,
        col=col,
        xref="x domain",  # x is a fraction of the subplot's x-axis domain
        yref="y domain",  # y is a fraction of the subplot's y-axis domain
    )


# One entry per populated subplot:
# (dataframe, row, column, y-axis title, y-axis range).
# The list order is the order in which the OLS/ANOVA reports are printed.
subplot_specs = [
    # White matter
    (df_WM['FA'], 2, 3, "FA", [0.3, 0.8]),
    (df_WM['MD'], 3, 3, "MD", [0.0002, 0.0016]),
    (df_WM['AD'], 4, 3, "AD", [0.0002, 0.0026]),
    (df_WM['RD'], 5, 3, "RD", [0.00025, 0.001]),
    (df_WM['ODI'], 2, 4, "ODI", [0, 0.5]),
    (df_WM['FISO'], 3, 4, "FISO", [0.1, 0.7]),
    (df_WM['FICVF'], 4, 4, "FICVF", [0.4, 1.2]),
    (df_WM['MTR'], 5, 4, 'MTR', [35, 55]),
    # Gray matter
    (df_GM['ODI'], 4, 1, "ODI", [0, 0.5]),
    (df_GM['FISO'], 5, 1, "FISO", [0.1, 0.7]),
    (df_GM['FICVF'], 6, 1, 'FICVF', [0.4, 1.2]),
    # WM/GM ratio
    (df_WMGM['T2star'], 6, 4, "T2* ratio", [0.8, 1.1]),
]

# Fit the models and draw the box plots, one subplot at a time.
for metric_df, row, col, axis_title, axis_range in subplot_specs:
    add_boxplot_for_subplot(metric_df, row, col)

# Build the per-axis title settings. The y-axis of the subplot at (row, col)
# in this 4-column grid is addressed as yaxis<(row - 1) * 4 + col>.
grid_cols = 4
axis_title_layout = {}
for metric_df, row, col, axis_title, axis_range in subplot_specs:
    axis_name = f"yaxis{(row - 1) * grid_cols + col}"
    axis_title_layout[f"{axis_name}_title"] = axis_title
    axis_title_layout[f"{axis_name}_title_font"] = dict(
        size=20, family="Arial", color="black", weight='bold',
    )

# Overall figure size, margins and axis titles.
fig.update_layout(
    margin=dict(l=200, r=200, t=200, b=200),
    width=1300,
    height=1400,
    showlegend=False,
    **axis_title_layout,
)

# Static background image drawn beneath the subplots.
# NOTE(review): 'source' is a relative path — confirm it resolves wherever the
# figure is rendered.
background_image = dict(
    source='Figure4_template.png',
    x=-0.22,
    y=0.97,
    xanchor="left",
    yanchor="top",
    sizex=1.4,
    sizey=1.39,
    layer="below",
)
fig.update_layout(images=[background_image])

# Fix the y-axis range of each subplot.
for metric_df, row, col, axis_title, axis_range in subplot_specs:
    fig.update_yaxes(range=axis_range, row=row, col=col)

# Subplot positions whose x-axis tick labels are replaced.
xaxis_subplots = [(spec[1], spec[2]) for spec in subplot_specs]

# Map the group codes stored in the data to the labels shown on the x-axis.
group_codes = ['CTRL', 'low', 'med', 'adv']
group_labels = ['HC', 'Low', 'Med', 'Adv']
for row, col in xaxis_subplots:
    fig.update_xaxes(
        ticktext=group_labels,
        tickvals=group_codes,
        tickfont=dict(size=14, weight='bold'),
        row=row,
        col=col,
    )

fig.show()

OLS results for WM MTR in C2-C5:                             OLS Regression Results                            
Dep. Variable:                     WA   R-squared:                       0.213
Model:                            OLS   Adj. R-squared:                  0.181
Method:                 Least Squares   F-statistic:                     6.692
Date:                Thu, 10 Apr 2025   Prob (F-statistic):           8.27e-05
Time:                        13:06:05   Log-Likelihood:                 173.77
No. Observations:                 104   AIC:                            -337.5
Df Residuals:                      99   BIC:                            -324.3
Df Model:                           4                                         
Covariance Type:            nonrobust                                         
                                coef    std err          t      P>|t|      [0.025      0.975]
-----------------------------------------------------------------------------------