# Normative DTI values in the pediatric spinal cord

This jupyter notebook includes scripts to generate figures related to normative DTI values in the pediatric spinal cord.

In [305]:
import os
import pandas as pd
import json
import yaml
import re
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
import webbrowser


### Load config file to get path to dataset 

In [53]:
# Load config file
with open('../../../config/config_preprocessing.yaml' , 'r') as file:
    config = yaml.safe_load(file)

# Get data path from config file
path_data = config['path_data']

### Get the `participants.tsv` file from the dataset

In [54]:
# Get path to participants.tsv file
participants_tsv = pd.read_csv(os.path.join(path_data, 'participants.tsv'), sep='\t')
participants_tsv

Unnamed: 0,participant_id,age,sex,group,scan_series
0,sub-101,17,M,control,complete
1,sub-102,15,F,control,complete
2,sub-103,15,M,control,complete
3,sub-104,15,F,control,complete
4,sub-105,13,M,control,complete
...,...,...,...,...,...
110,sub-214,6,M,control,complete
111,sub-215,16,F,control,complete
112,sub-216,15,F,control,complete
113,sub-217,15,M,control,complete


## Function that gets the number of subjects used in the analysis pipelines

This function takes the subject count from the `participants.tsv` file, and then extracts the subjects contained in the `exclude.yml` file and the subjects with missing dwi data.

In [71]:
def get_list_of_subjects_to_include(contrast, path_data, missing_data_subjects):
    """
    This function takes an image contrast (T2w, dwi, etc.), a path to a dataset, and a list of subjects with missing data,
    and returns a list of subjects to include in the analysis.

    The dataset needs to be in BIDS format, and the function will look for the participants.tsv file to get the list of subjects.
    The dataset should also contain an `exclude.yml` file that lists subjects to exclude from the analysis.
    """

    # Get the `participants.tsv` file and read it into a dataframe
    participants_tsv = pd.read_csv(os.path.join(path_data, 'participants.tsv'), sep='\t')

    # Get all subject IDs from the participants.tsv
    all_subjects = participants_tsv['participant_id'].tolist()

    # Get list of subjects to exclude from the analysis from the `exclude.yml` file (under the 'dwi' key)
    with open(os.path.join(path_data, 'exclude.yml'), 'r') as file:
        exclude_yml = yaml.safe_load(file)

    exclude_dwi_key = exclude_yml.get('dwi', []) # Extract subjects under 'dwi' key
    exclude_subjects = sorted(set(re.match(r"(sub-\d+)", entry).group(1) for entry in exclude_dwi_key if re.match(r"(sub-\d+)", entry))) # Extract the subject ID 

    # Add the list of subjects with missing data to the exclude_subjects list
    exclude_subjects.extend(missing_data_subjects)
    
    # Remove duplicates (if any), sort and print the list of subjects to exclude from the analysis
    exclude_subjects = sorted(set(exclude_subjects))

    # Compute the list of subjects to include in the analysis 
    include_subjects = [sub for sub in all_subjects if sub not in exclude_subjects]

    # Convert the list of included subjects to a dataframe
    include_subjects = participants_tsv[participants_tsv['participant_id'].isin(include_subjects)]

    return include_subjects

In [None]:
# List of subjects with missing dwi data
missing_dwi_subjects = ["sub-125",
                        "sub-152",
                        "sub-174",
                        "sub-200",
                        "sub-205",
                        "sub-213"]

missing_dwi_subjects = []

# Get the list of subjects to include in the analysis
include_dwi_subjects = get_list_of_subjects_to_include('dwi', path_data, missing_dwi_subjects)
include_dwi_subjects.to_csv(os.path.join('include_dwi_subjects.csv'), sep='\t', index=False)

Unnamed: 0,participant_id,age,sex,group,scan_series
0,sub-101,17,M,control,complete
1,sub-102,15,F,control,complete
2,sub-103,15,M,control,complete
3,sub-104,15,F,control,complete
4,sub-105,13,M,control,complete
...,...,...,...,...,...
108,sub-212,12,F,control,complete
109,sub-213,8,M,control,incomplete
111,sub-215,16,F,control,complete
112,sub-216,15,F,control,complete


## Plot demographics

This function plots the age and sex distribution of the subjects included in a pipeline analysis, according to the include list generated above. 

In [76]:
def plot_demographics(df):
    """
    This function plots the demographic information of participants, given a dataframe with the list of subjects to include in the analysis.
    """

    # Sort by sex
    df_M = df[df['sex'] == 'M']
    df_F = df[df['sex'] == 'F']

    # Round down age to nearest month 
    df['age'] = np.floor(df['age']) 

    # Create subplot
    fig = make_subplots(rows=1, cols=1)

    # Add histogram for female subjects
    fig.add_trace(go.Histogram(
        x=df_F['age'], 
        name='F', 
        marker=dict(color= "#D19D88"),
        opacity=1.0,
        legendgroup='F',
        ),
        row=1, col=1
    )

    # Add histogram for male subjects
    fig.add_trace(go.Histogram(
        x=df_M['age'], 
        name='M', 
        
        marker=dict(color="#5C8EA1"),
        opacity=1.0,
        legendgroup='M',
        ), 
        row=1, col=1
    )

    # Generate tick values (every 1 month)
    min_age = int(df['age'].min())
    max_age = int(df['age'].max())

    # Define age tick range
    tick_vals = list(range(6, 18)) 

    # Update layout
    fig.update_layout(
        width=900,
        height=500,
        font=dict(family='Arial', size=18, color='black'), 
        legend=dict(
            orientation="h", 
            yanchor="bottom", 
            y=1.0, 
            xanchor="center",  
            x=0.5,
        ),
        xaxis=dict(
            range=[5, 18],  # Set x-axis range from 6 to 17
        ),
        plot_bgcolor='white',
        barmode='stack',
        bargap=0.3,  
        xaxis_title='Age (years)',
        yaxis_title='Number of Subjects',
        xaxis_title_standoff=50, 
    )

    # Generate tick values (every 1 month)
    min_age = int(df['age'].min())
    max_age = int(df['age'].max())

    fig.update_xaxes(
        tickmode='array',
        tickvals=tick_vals,
        ticktext=['' for _ in tick_vals], # Hide default tick labels (to be replaced with custom annotations)
        showgrid=False,
        gridwidth=1
    )

    # Add annotations for year ticks only
    for val in tick_vals:
        fig.add_annotation(
        x=val,
        y=-0.01,  # position of the text below the x-axis
        text=f"{val}",
        showarrow=False,
        xref='x',
        yref='paper',
        font=dict(size=18),
        xanchor='center',
        yanchor='top'
        )

    fig.update_yaxes(
        showgrid=True,             # Enable horizontal grid lines
        gridcolor='lightgrey',
        gridwidth=1
    )

    # Set bin size to 1 year
    fig.update_traces(xbins=dict(size=1))

    fig.show()

In [77]:
# Plot demographics for included subjects in DWI analysis
plot_demographics(include_dwi_subjects)

## Get average DTI metrics 

The cells below defines a dataframe for each DTI metric, based on the CSV files (one for each subject) stored under "results/tables/DWI/DTI_metrics/"

In [155]:
# DTI metric folders
DTI_folder = "../../tables/DWI/DTI_metrics/"
metrics = ['FA', 'MD', 'AD', 'RD']

DTI_df = {}

for metric in metrics:
    metric_folder = os.path.join(DTI_folder, metric)
    metric_dfs = []
    
    for filename in os.listdir(metric_folder):
        if filename.endswith(".csv"):
            subject_path = os.path.join(metric_folder, filename)
            df = pd.read_csv(subject_path)
            subject_id = filename.split("_")[0]  # Get subject id from the filename (i.e., from 'sub-101_FA.csv')
            df["participant_id"] = subject_id
            metric_dfs.append(df)
    
    # Combine csv files of all subjects into a single dataframe for this metric
    DTI_df[metric] = pd.concat(metric_dfs, ignore_index=True)

# Add age and sex to DTI metric dataframe
for metric in DTI_df:
    DTI_df[metric] = DTI_df[metric].merge(include_dwi_subjects, on="participant_id", how="left")


In [164]:
DTI_df['FA']

Unnamed: 0,Timestamp,SCT Version,Filename,Slice (I->S),VertLevel,DistancePMJ,Label,Size [vox],WA(),STD(),participant_id,age,sex,group,scan_series
0,2025-08-07 14:00:58,7.0,/home/samuelle/Documents/datasets/philadelphia...,5:6,5,,WM right fasciculus gracilis,8.085680,0.658081,0.166286,sub-173,6.0,F,control,complete
1,2025-08-07 14:00:58,7.0,/home/samuelle/Documents/datasets/philadelphia...,7:8,4,,WM right fasciculus gracilis,7.302660,0.727178,0.130589,sub-173,6.0,F,control,complete
2,2025-08-07 14:00:58,7.0,/home/samuelle/Documents/datasets/philadelphia...,9:11,3,,WM right fasciculus gracilis,11.414467,0.712787,0.181275,sub-173,6.0,F,control,complete
3,2025-08-07 14:00:58,7.0,/home/samuelle/Documents/datasets/philadelphia...,12:13,2,,WM right fasciculus gracilis,6.087038,0.818279,0.073614,sub-173,6.0,F,control,complete
4,2025-08-07 14:00:58,7.0,/home/samuelle/Documents/datasets/philadelphia...,5:6,5,,WM left fasciculus cuneatus,9.537961,0.648535,0.178546,sub-173,6.0,F,control,complete
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3127,2025-08-07 13:59:00,7.0,/home/samuelle/Documents/datasets/philadelphia...,12:13,2,,lateral funiculi,56.713801,0.703749,0.159815,sub-131,15.0,F,control,complete
3128,2025-08-07 13:59:00,7.0,/home/samuelle/Documents/datasets/philadelphia...,3:5,5,,ventral funiculi,46.090072,0.643242,0.154078,sub-131,15.0,F,control,complete
3129,2025-08-07 13:59:00,7.0,/home/samuelle/Documents/datasets/philadelphia...,6:8,4,,ventral funiculi,39.876820,0.601502,0.136456,sub-131,15.0,F,control,complete
3130,2025-08-07 13:59:00,7.0,/home/samuelle/Documents/datasets/philadelphia...,9:11,3,,ventral funiculi,34.215026,0.613628,0.185898,sub-131,15.0,F,control,complete


# Mean and DTI values per age, per vertebral level

In [258]:
def get_mean_and_std_by_age(df, vertlevel, label, metric_col):
    """
    Returns a DataFrame showing the mean ± std of the given metric,
    grouped by age, filtered by vertebral level and label.
    """
    # Filter the DataFrame
    filtered_df = df[(df['Label'] == label) & (df['VertLevel'] == vertlevel)]

    # Group by age and calculate mean and std
    age_grouped_df = filtered_df.groupby('age').agg(
        mean=('WA()', 'mean'),
        std=('WA()', 'std')
    ).reset_index()

    # Rename 'age' to 'Age'
    age_grouped_df.rename(columns={'age': 'Age'}, inplace=True)

    # Format the mean ± std string
    age_grouped_df[metric_col] = age_grouped_df.apply(
        lambda row: f"{row['mean']:.2g} ± {row['std']:.2g}", axis=1
    )

    return age_grouped_df[['Age', metric_col]]

In [303]:
def plot_mean_std_table(DTI_df, label):
    """
    Plots a table showing the mean ± std of DTI metrics by age.
    """
    for vertlevel in [2, 3, 4, 5]:
        combined_df = None
        for metric in DTI_df:
            result = get_mean_and_std_by_age(DTI_df[metric], vertlevel, label, metric_col=metric)
            if combined_df is None:
                combined_df = result
            else:
                combined_df = pd.merge(combined_df, result, on='Age')

        fig = go.Figure(data=[go.Table(
            columnwidth=[80, 200, 200, 200, 200],  # width in pixels per column
            header=dict(values=list(combined_df.columns),
                        fill_color='lightgrey',
                        align='center',
                        font=dict(size=18, color='black', family='Arial', weight='bold')),
            cells=dict(values=[combined_df[col] for col in combined_df.columns],
                    height=30,
                    fill_color='white',
                    align='center')) 
        ])

        fig.update_layout(
            width=1200,   # in pixels
            height=550,   # in pixels,
            font=dict(size=18, color='black', family='Arial'),
            title=f'DTI metrics by age in {label} for vertebral level {vertlevel}')
        
        fig.show()

In [304]:
plot_mean_std_table(DTI_df, label='spinal cord')

# Plot DTI metrics per vert level, by age

In [None]:
def plot_morphometrics(df, DTI_metric):

    # Compute mean and std dev per age and vertebral level
    grouped = df.groupby(['age', 'VertLevel'])['WA()'].agg(['mean', 'std']).reset_index()

    # Initialize figure
    fig = go.Figure()

    # Assign unique colors using Plotly's color sequence
    unique_ages = sorted(grouped['age'].unique())

    # Generate 12 evenly spaced colors from a sequential colormap
    colorscale = px.colors.diverging.Portland
    num_colors = 12
    color_gradient = px.colors.sample_colorscale(colorscale, [i / (num_colors - 1) for i in range(num_colors)])

    # Sort the ages and assign a color from the gradient
    unique_ages = sorted(grouped['age'].unique())
    color_map = {age: color_gradient[i] for i, age in enumerate(unique_ages)}

    # Plot each age group
    for age in unique_ages:
        group = grouped[grouped['age'] == age].sort_values('VertLevel')

        x = group['VertLevel']
        y_mean = group['mean']
        y_std = group['std']

        y_upper = y_mean + y_std
        y_lower = y_mean - y_std
        color = color_map[age]

        # Add mean line
        fig.add_trace(go.Scatter(
            x=x,
            y=y_mean,
            mode='lines',
            name=f'{age}',
            line=dict(color=color, width=2),
            marker=dict(size=6),
            legendgroup=f'Age {age}'
        ))

    # Layout
    fig.update_layout(
        width=1000,
        height=600,
        title=f" ",
        xaxis_title='Vertebral Level',
        xaxis_title_font=dict(size=20, color='black', family='Arial', weight='bold'),
        yaxis_title=DTI_metric,
        yaxis_title_font=dict(size=20, color='black', family='Arial', weight='bold'),
        legend_title='Age',
        legend_title_font=dict(size=16, color='black', family='Arial', weight='bold'),
        legend_font=dict(size=14, color='black', family='Arial'),
        xaxis_tickangle=-45,
        template='plotly_white',
    )
    
    # Map VertLevel numbers to anatomical labels
    vertebral_label_map = {1: "C1", 2: "C2", 3: "C3", 4: "C4", 5: "C5"}

    # Extract all vertebral levels used in the data
    all_vert_levels = sorted(df['VertLevel'].dropna().unique())

    # Filter only integer levels (to show only C1, C2, etc.)
    integer_vert_levels = [v for v in all_vert_levels if v.is_integer() and v in vertebral_label_map]

    # Map to anatomical labels
    ticktexts = [vertebral_label_map[v] for v in integer_vert_levels]

    # Update the x-axis to show anatomical vertebral labels
    fig.update_xaxes(
        tickmode='array',
        tickvals=integer_vert_levels,
        ticktext=ticktexts,
        tickcolor='black',
        tickfont=dict(size=16, color='black', family='Arial'),
        tickangle=0,
        gridcolor='rgba(0, 0, 0, 0.1)',  
        
    )
    
    fig.update_yaxes(
        gridcolor='rgba(0, 0, 0, 0.1)',
        tickcolor='black',
        tickfont=dict(size=16, color='black', family='Arial')
    )

    # # Define the save path
    # save_path = os.path.join(f'{DTI_metric}_vs_VertLevels.html')

    # # Save the figure to an HTML file
    # fig.write_html(save_path)

    # # Open the file in the default web browser
    # webbrowser.open(f'file://{os.path.abspath(save_path)}')

    # Show the figure in the notebook
    fig.show()

In [314]:
plot_morphometrics(DTI_df['FA'], 'FA')

In [315]:
plot_morphometrics(DTI_df['MD'], 'MD')

In [316]:
plot_morphometrics(DTI_df['FA'], 'AD')

In [317]:
plot_morphometrics(DTI_df['FA'], 'RD')