# Normative DTI values in the pediatric spinal cord

This jupyter notebook includes scripts to generate figures related to normative DTI values in the pediatric spinal cord.

In [456]:
import os
import pandas as pd
import json
import yaml
import re
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
import webbrowser


### Load config file to get path to dataset 

In [457]:
# Load config file
with open('../../../config/config_preprocessing.yaml' , 'r') as file:
    config = yaml.safe_load(file)

# Get data path from config file
path_data = config['path_data']

### Get the `participants.tsv` file from the dataset

In [458]:
# Get path to participants.tsv file
participants_tsv = pd.read_csv(os.path.join(path_data, 'participants.tsv'), sep='\t')
participants_tsv

Unnamed: 0,participant_id,age,sex,group,scan_series
0,sub-101,17,M,control,complete
1,sub-102,15,F,control,complete
2,sub-103,15,M,control,complete
3,sub-104,15,F,control,complete
4,sub-105,13,M,control,complete
...,...,...,...,...,...
110,sub-214,6,M,control,complete
111,sub-215,16,F,control,complete
112,sub-216,15,F,control,complete
113,sub-217,15,M,control,complete


## Function that gets the number of subjects used in the analysis pipelines

This function takes the subject count from the `participants.tsv` file, and then extracts the subjects contained in the `exclude.yml` file and the subjects with missing dwi data.

In [459]:
def get_list_of_subjects_to_include(contrast, path_data, missing_data_subjects):
    """
    This function takes an image contrast (T2w, dwi, etc.), a path to a dataset, and a list of subjects with missing data,
    and returns a list of subjects to include in the analysis.

    The dataset needs to be in BIDS format, and the function will look for the participants.tsv file to get the list of subjects.
    The dataset should also contain an `exclude.yml` file that lists subjects to exclude from the analysis.
    """

    # Get the `participants.tsv` file and read it into a dataframe
    participants_tsv = pd.read_csv(os.path.join(path_data, 'participants.tsv'), sep='\t')

    # Get all subject IDs from the participants.tsv
    all_subjects = participants_tsv['participant_id'].tolist()

    # Get list of subjects to exclude from the analysis from the `exclude.yml` file (under the 'dwi' key)
    with open(os.path.join(path_data, 'exclude.yml'), 'r') as file:
        exclude_yml = yaml.safe_load(file)

    exclude_dwi_key = exclude_yml.get('dwi', []) # Extract subjects under 'dwi' key
    exclude_subjects = sorted(set(re.match(r"(sub-\d+)", entry).group(1) for entry in exclude_dwi_key if re.match(r"(sub-\d+)", entry))) # Extract the subject ID 

    # Add the list of subjects with missing data to the exclude_subjects list
    exclude_subjects.extend(missing_data_subjects)
    
    # Remove duplicates (if any), sort and print the list of subjects to exclude from the analysis
    exclude_subjects = sorted(set(exclude_subjects))

    # Compute the list of subjects to include in the analysis 
    include_subjects = [sub for sub in all_subjects if sub not in exclude_subjects]

    # Convert the list of included subjects to a dataframe
    include_subjects = participants_tsv[participants_tsv['participant_id'].isin(include_subjects)]

    return include_subjects

In [460]:
# List of subjects with missing dwi data
missing_dwi_subjects = ["sub-125",
                        "sub-152",
                        "sub-174",
                        "sub-200",
                        "sub-205",
                        "sub-213"]

missing_dwi_subjects = []

# Get the list of subjects to include in the analysis
include_dwi_subjects = get_list_of_subjects_to_include('dwi', path_data, missing_dwi_subjects)
include_dwi_subjects.to_csv(os.path.join('include_dwi_subjects.csv'), sep='\t', index=False)

## Plot demographics

This function plots the age and sex distribution of the subjects included in a pipeline analysis, according to the include list generated above. 

In [461]:
def plot_demographics(df):
    """
    This function plots the demographic information of participants, given a dataframe with the list of subjects to include in the analysis.
    """

    # Sort by sex
    df_M = df[df['sex'] == 'M']
    df_F = df[df['sex'] == 'F']

    # Round down age to nearest month 
    df['age'] = np.floor(df['age']) 

    # Create subplot
    fig = make_subplots(rows=1, cols=1)

    # Add histogram for female subjects
    fig.add_trace(go.Histogram(
        x=df_F['age'], 
        name='F', 
        marker=dict(color= "#D19D88"),
        opacity=1.0,
        legendgroup='F',
        ),
        row=1, col=1
    )

    # Add histogram for male subjects
    fig.add_trace(go.Histogram(
        x=df_M['age'], 
        name='M', 
        
        marker=dict(color="#5C8EA1"),
        opacity=1.0,
        legendgroup='M',
        ), 
        row=1, col=1
    )

    # Generate tick values (every 1 month)
    min_age = int(df['age'].min())
    max_age = int(df['age'].max())

    # Define age tick range
    tick_vals = list(range(6, 18)) 

    # Update layout
    fig.update_layout(
        width=900,
        height=500,
        font=dict(family='Arial', size=18, color='black'), 
        legend=dict(
            orientation="h", 
            yanchor="bottom", 
            y=1.0, 
            xanchor="center",  
            x=0.5,
        ),
        xaxis=dict(
            range=[5, 18],  # Set x-axis range from 6 to 17
        ),
        plot_bgcolor='white',
        barmode='stack',
        bargap=0.3,  
        xaxis_title='Age (years)',
        xaxis_title_font=dict(family='Arial', size=20, weight='bold'),
        yaxis_title='Number of Subjects',
        yaxis_title_font=dict(family='Arial', size=20, weight='bold'),
        xaxis_title_standoff=50, 
    )

    # Generate tick values (every 1 month)
    min_age = int(df['age'].min())
    max_age = int(df['age'].max())

    fig.update_xaxes(
        tickmode='array',
        tickvals=tick_vals,
        ticktext=['' for _ in tick_vals], # Hide default tick labels (to be replaced with custom annotations)
        showgrid=False,
        gridwidth=1
    )

    # Add annotations for year ticks only
    for val in tick_vals:
        fig.add_annotation(
        x=val,
        y=-0.01,  # position of the text below the x-axis
        text=f"{val}",
        showarrow=False,
        xref='x',
        yref='paper',
        font=dict(size=18),
        xanchor='center',
        yanchor='top'
        )

    fig.update_yaxes(
        showgrid=True,             # Enable horizontal grid lines
        gridcolor='lightgrey',
        gridwidth=1
    )

    # Set bin size to 1 year
    fig.update_traces(xbins=dict(size=1))

    fig.show()

In [462]:
# Plot demographics for included subjects in DWI analysis
plot_demographics(include_dwi_subjects)

## Get average DTI metrics 

The cells below defines a dataframe for each DTI metric, based on the CSV files (one for each subject) stored under "results/tables/DWI/DTI_metrics/"

In [463]:
# DTI metric folders
DTI_folder = "../../tables/DWI/DTI_metrics/"
metrics = ['FA', 'MD', 'AD', 'RD']

DTI_df = {}

for metric in metrics:
    metric_folder = os.path.join(DTI_folder, metric)
    metric_dfs = []
    
    for filename in os.listdir(metric_folder):
        if filename.endswith(".csv"):
            subject_path = os.path.join(metric_folder, filename)
            df = pd.read_csv(subject_path)
            subject_id = filename.split("_")[0]  # Get subject id from the filename (i.e., from 'sub-101_FA.csv')
            df["participant_id"] = subject_id
            metric_dfs.append(df)
    
    # Combine csv files of all subjects into a single dataframe for this metric
    DTI_df[metric] = pd.concat(metric_dfs, ignore_index=True)

# Add age and sex to DTI metric dataframe
for metric in DTI_df:
    DTI_df[metric] = DTI_df[metric].merge(include_dwi_subjects, on="participant_id", how="left")


In [464]:
DTI_df['FA']

Unnamed: 0,Timestamp,SCT Version,Filename,Slice (I->S),VertLevel,DistancePMJ,Label,Size [vox],WA(),STD(),participant_id,age,sex,group,scan_series
0,2025-08-11 18:05:35,7.0,/Users/samuellestonge/Documents/datasets/phila...,4:5,6,,WM right fasciculus gracilis,7.766826,0.651349,0.237501,sub-110,,,,
1,2025-08-11 18:05:35,7.0,/Users/samuellestonge/Documents/datasets/phila...,6:8,5,,WM right fasciculus gracilis,14.669004,0.584605,0.208661,sub-110,,,,
2,2025-08-11 18:05:35,7.0,/Users/samuellestonge/Documents/datasets/phila...,9:11,4,,WM right fasciculus gracilis,14.909128,0.606167,0.236954,sub-110,,,,
3,2025-08-11 18:05:35,7.0,/Users/samuellestonge/Documents/datasets/phila...,12:13,3,,WM right fasciculus gracilis,8.453892,0.386441,0.111381,sub-110,,,,
4,2025-08-11 18:05:35,7.0,/Users/samuellestonge/Documents/datasets/phila...,14,2,,WM right fasciculus gracilis,3.771380,0.505080,0.094191,sub-110,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5875,2025-08-11 18:07:08,7.0,/Users/samuellestonge/Documents/datasets/phila...,1:2,6,,ventral funiculi,31.540796,0.631858,0.164875,sub-131,15.0,F,control,complete
5876,2025-08-11 18:07:08,7.0,/Users/samuellestonge/Documents/datasets/phila...,3:5,5,,ventral funiculi,46.090072,0.643242,0.154078,sub-131,15.0,F,control,complete
5877,2025-08-11 18:07:08,7.0,/Users/samuellestonge/Documents/datasets/phila...,6:8,4,,ventral funiculi,39.876820,0.601502,0.136456,sub-131,15.0,F,control,complete
5878,2025-08-11 18:07:08,7.0,/Users/samuellestonge/Documents/datasets/phila...,9:11,3,,ventral funiculi,34.215026,0.613628,0.185898,sub-131,15.0,F,control,complete


# Mean and DTI values per age, per vertebral level

In [465]:
def get_mean_and_std_by_age(df, vertlevel, label, metric_col):
    """
    Returns a DataFrame showing the mean ± std of the given metric,
    grouped by age, filtered by vertebral level and label.
    """
    # Filter the DataFrame
    filtered_df = df[(df['Label'] == label) & (df['VertLevel'] == vertlevel)]

    # Group by age and calculate mean and std
    age_grouped_df = filtered_df.groupby('age').agg(
        mean=('WA()', 'mean'),
        std=('WA()', 'std')
    ).reset_index()

    # Rename 'age' to 'Age'
    age_grouped_df.rename(columns={'age': 'Age'}, inplace=True)

    # Format the mean ± std string
    age_grouped_df[metric_col] = age_grouped_df.apply(
        lambda row: f"{row['mean']:.2g} ± {row['std']:.2g}", axis=1
    )

    return age_grouped_df[['Age', metric_col]]

In [466]:
def plot_mean_std_table(DTI_df, label):
    """
    Plots a table showing the mean ± std of DTI metrics by age.
    """
    for vertlevel in [2, 3, 4, 5, 6]:
        combined_df = None
        for metric in DTI_df:
            result = get_mean_and_std_by_age(DTI_df[metric], vertlevel, label, metric_col=metric)
            if combined_df is None:
                combined_df = result
            else:
                combined_df = pd.merge(combined_df, result, on='Age')

        fig = go.Figure(data=[go.Table(
            columnwidth=[80, 200, 200, 200, 200],  # width in pixels per column
            header=dict(values=list(combined_df.columns),
                        fill_color='lightgrey',
                        align='center',
                        font=dict(size=18, color='black', family='Arial', weight='bold')),
            cells=dict(values=[combined_df[col] for col in combined_df.columns],
                    height=30,
                    fill_color='white',
                    align='center')) 
        ])

        fig.update_layout(
            width=1200,   # in pixels
            height=550,   # in pixels,
            font=dict(size=18, color='black', family='Arial'),
            title=f'DTI metrics by age in {label} for vertebral level {vertlevel}')
        
        fig.show()

In [467]:
plot_mean_std_table(DTI_df, label='spinal cord')

# Plot DTI metrics per vert level, by age

In [468]:
def plot_morphometrics(df, DTI_metric):

    # Compute mean and std dev per age and vertebral level
    grouped = df.groupby(['age', 'VertLevel'])['WA()'].agg(['mean', 'std']).reset_index()

    # Initialize figure
    fig = go.Figure()

    # Assign unique colors using Plotly's color sequence
    unique_ages = sorted(grouped['age'].unique())

    # Generate 12 evenly spaced colors from a sequential colormap
    colorscale = px.colors.diverging.Portland
    num_colors = 12
    color_gradient = px.colors.sample_colorscale(colorscale, [i / (num_colors - 1) for i in range(num_colors)])

    # Sort the ages and assign a color from the gradient
    unique_ages = sorted(grouped['age'].unique())
    color_map = {age: color_gradient[i] for i, age in enumerate(unique_ages)}

    # Plot each age group
    for age in unique_ages:
        group = grouped[grouped['age'] == age].sort_values('VertLevel')

        x = group['VertLevel']
        y_mean = group['mean']
        y_std = group['std']

        y_upper = y_mean + y_std
        y_lower = y_mean - y_std
        color = color_map[age]

        # Add mean line
        fig.add_trace(go.Scatter(
            x=x,
            y=y_mean,
            mode='lines',
            name=f'{age}',
            line=dict(color=color, width=2),
            marker=dict(size=6),
            legendgroup=f'Age {age}'
        ))

    # Layout
    fig.update_layout(
        width=1000,
        height=600,
        title=f" ",
        xaxis_title='Vertebral Level',
        xaxis_title_font=dict(size=20, color='black', family='Arial', weight='bold'),
        yaxis_title=DTI_metric,
        yaxis_title_font=dict(size=20, color='black', family='Arial', weight='bold'),
        legend_title='Age',
        legend_title_font=dict(size=16, color='black', family='Arial', weight='bold'),
        legend_font=dict(size=14, color='black', family='Arial'),
        xaxis_tickangle=-45,
        template='plotly_white',
    )
    
    # Map VertLevel numbers to anatomical labels
    vertebral_label_map = {1: "C1", 2: "C2", 3: "C3", 4: "C4", 5: "C5"}

    # Extract all vertebral levels used in the data
    all_vert_levels = sorted(df['VertLevel'].dropna().unique())

    # Filter only integer levels (to show only C1, C2, etc.)
    integer_vert_levels = [v for v in all_vert_levels if v.is_integer() and v in vertebral_label_map]

    # Map to anatomical labels
    ticktexts = [vertebral_label_map[v] for v in integer_vert_levels]

    # Update the x-axis to show anatomical vertebral labels
    fig.update_xaxes(
        tickmode='array',
        tickvals=integer_vert_levels,
        ticktext=ticktexts,
        tickcolor='black',
        tickfont=dict(size=16, color='black', family='Arial'),
        tickangle=0,
        gridcolor='rgba(0, 0, 0, 0.1)',  
        
    )
    
    fig.update_yaxes(
        gridcolor='rgba(0, 0, 0, 0.1)',
        tickcolor='black',
        tickfont=dict(size=16, color='black', family='Arial')
    )

    # # Define the save path
    # save_path = os.path.join(f'{DTI_metric}_vs_VertLevels.html')

    # # Save the figure to an HTML file
    # fig.write_html(save_path)

    # # Open the file in the default web browser
    # webbrowser.open(f'file://{os.path.abspath(save_path)}')

    # Show the figure in the notebook
    fig.show()

In [469]:
plot_morphometrics(DTI_df['FA'], 'FA')

In [470]:
plot_morphometrics(DTI_df['MD'], 'MD')

In [471]:
plot_morphometrics(DTI_df['FA'], 'AD')

In [472]:
plot_morphometrics(DTI_df['FA'], 'RD')

Plot DTI metrics per age

In [473]:
DTI_df['FA'].columns

Index(['Timestamp', 'SCT Version', 'Filename', 'Slice (I->S)', 'VertLevel',
       'DistancePMJ', 'Label', 'Size [vox]', 'WA()', 'STD()', 'participant_id',
       'age', 'sex', 'group', 'scan_series'],
      dtype='object')

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

def plot_dti_per_age(df, metric_name):
    
    df = df.copy()
    df.rename(columns={'WA()': metric_name}, inplace=True)

    # Filter the spinal cord label
    df = df[df['Label'] == 'spinal cord'].copy()

    # Ensure correct types
    df['age'] = pd.to_numeric(df['age'], errors='coerce')
    df['VertLevel'] = df['VertLevel'].astype(str)

    # Get sorted unique VertLevels
    vert_levels = sorted(df['VertLevel'].unique())
    num_plots = len(vert_levels)

    # Hardcode 3x2 layout and skip (1,1)
    rows, cols = 3, 2
    total_cells = rows * cols
    assert num_plots <= total_cells - 1, "Too many VertLevels for this layout!"

    # Define colors for each group (Male, Female, All subjects)
    colors = {
        'M': "#5C8EA1",   
        'F': "#D19D88",   
        'All': "#7A7A7A" 
    }

    # Matching confidence interval (CI) fill colors with ~20% opacity
    ci_colors = {
    'M': 'rgba(92, 142, 161, 0.2)',   # Hex #5C8EA1 to rgba
    'F': 'rgba(209, 157, 136, 0.4)',  # Hex #D19D88 to rgba
    'All': 'rgba(122, 122, 122, 0.15)' # Hex #7A7A7A to rgba
    }   

    metric_y_ranges = {
    'FA': [0.3, 0.9],
    'MD': [0.0005, 0.0015],
    'RD': [0.0002, 0.0012],
    'AD': [0.0014, 0.0022],
    }

    fig = make_subplots(
        rows=rows,
        cols=cols,
        shared_xaxes=False,
        subplot_titles = [""] + [f"VertLevel {v}" for v in vert_levels],
        vertical_spacing=0.12,
        horizontal_spacing=0.3
    )

    # Predefine subplot positions, skipping (1,1)
    subplot_positions = [(2, 1), (3, 1), (1, 2), (2, 2), (3, 2)]

    for i, vert in enumerate(vert_levels):
        row, col = subplot_positions[i]

        # Use filtered data for current VertLevel
        data = df[df['VertLevel'] == vert].dropna(subset=['age', metric_name])

        # Use age range from the current VertLevel data
        x_range = np.linspace(data['age'].min(), data['age'].max(), 100)
        pred_df = pd.DataFrame({'age': x_range})

        for group in ['F', 'M', 'All']:
            if group in ['F', 'M']:
                # Use data filtered by current VertLevel and sex group
                group_df = data[data['sex'] == group].copy()
                pred_df['sex'] = group
            else:
                # Use all subjects within current VertLevel
                group_df = data.copy()
                pred_df['sex'] = group_df['sex'].mode()[0]

            # Perform OLS regression 
            ols_formula = f'{metric_name} ~ age + sex'
            ols_model = smf.ols(formula=ols_formula, data=group_df)
            ols_results = ols_model.fit()

            # Add confidence intervals
            pred = ols_results.get_prediction(pred_df)
            pred_summary = pred.summary_frame(alpha=0.05)
            y_fit = pred_summary['mean']
            ci_lower = pred_summary['mean_ci_lower']
            ci_upper = pred_summary['mean_ci_upper']

            if group == 'F':
                legend_group = 'Male'
            elif group == 'M':
                legend_group = 'Female'
            else:
                legend_group = 'All'

            # Add scatter plot for male and female subjects
            if group != 'All':
                fig.add_trace(
                    go.Scatter(
                        x=group_df['age'], y=group_df[metric_name],
                        mode='markers',
                        marker=dict(color=colors[group], size=8, symbol='circle', opacity=1.0),
                        name=('Male' if group == 'M' else 'Female'),
                        legendgroup=legend_group,
                        showlegend=(i == 0),
                    ),
                    row=row, col=col
                )

                # Add regression line
                fig.add_trace(
                    go.Scatter(
                        x=x_range, y=y_fit,
                        mode='lines',
                        line=dict(color=colors[group], width=2, dash='solid'),
                        name=f"{'All subjects' if group == 'All' else group} - Fit",
                        legendgroup=legend_group,
                        showlegend=False,
                    ),
                    row=row, col=col
                )

            # Add regression line for all subjects only
            else:
                fig.add_trace(
                    go.Scatter(
                        x=x_range, y=y_fit,
                        mode='lines',
                        line=dict(color=colors[group], width=2, dash='dash'),
                        name=f'All subjects',
                        legendgroup=legend_group,
                        showlegend=False,
                    ),
                    row=row, col=col
                )

            # Add confidence interval
            fig.add_trace(
                go.Scatter(
                    x=np.concatenate([x_range, x_range[::-1]]),
                    y=np.concatenate([ci_upper, ci_lower[::-1]]),
                    fill='toself',
                    fillcolor=ci_colors[group],
                    line=dict(color='rgba(255,255,255,0)'),
                    hoverinfo='skip',
                    showlegend=False,
                    legendgroup=legend_group,
                ),
                row=row, col=col
            )

    # Layout
    fig.update_layout(
        height=1000,
        width=1000,
        plot_bgcolor="#ffffff",
        paper_bgcolor="#ffffff",
        legend=dict(
            x=1.02,
            y=1,
            bgcolor="#ffffff",
            bordercolor="#ffffff",
            font=dict(family='Arial', size=14)
        )
    )

    # Add axis titles and tick labels with Arial font
    for r in range(1, rows + 1):
        for c in range(1, cols + 1):
            fig.update_xaxes(
                title_text='Age',
                title_font=dict(family='Arial', size=14, weight='bold'),
                tickfont=dict(family='Arial', size=14),
                tickvals=list(range(6, 18)),  # All ticks from 6 to 17
                range=[5, 18],                # Force visible range
                row=r, col=c
            )
            fig.update_yaxes(
                title_text=metric_name,
                title_font=dict(family='Arial', size=14, weight='bold'),
                tickfont=dict(family='Arial', size=14),
                range=metric_y_ranges.get(metric_name, None),
                row=r, col=c
            )

    fig.show()


In [475]:
plot_dti_per_age(DTI_df['FA'], metric_name='FA')

In [476]:
plot_dti_per_age(DTI_df['MD'], metric_name='MD')
print(DTI_df['RD'])

                Timestamp  SCT Version  \
0     2025-08-11 18:10:10          7.0   
1     2025-08-11 18:10:10          7.0   
2     2025-08-11 18:10:10          7.0   
3     2025-08-11 18:10:10          7.0   
4     2025-08-11 18:10:10          7.0   
...                   ...          ...   
5875  2025-08-11 18:07:09          7.0   
5876  2025-08-11 18:07:09          7.0   
5877  2025-08-11 18:07:09          7.0   
5878  2025-08-11 18:07:09          7.0   
5879  2025-08-11 18:07:09          7.0   

                                               Filename Slice (I->S)  \
0     /Users/samuellestonge/Documents/datasets/phila...          3:4   
1     /Users/samuellestonge/Documents/datasets/phila...          5:6   
2     /Users/samuellestonge/Documents/datasets/phila...          7:8   
3     /Users/samuellestonge/Documents/datasets/phila...         9:11   
4     /Users/samuellestonge/Documents/datasets/phila...        12:13   
...                                                 ...        

In [477]:
plot_dti_per_age(DTI_df['RD'], metric_name='RD')

In [478]:
plot_dti_per_age(DTI_df['AD'], metric_name='AD')