# Normative DTI values in the pediatric spinal cord

This Jupyter notebook contains the code used to generate the figures and tables of normative DTI values in the pediatric spinal cord.

In [1]:
import os
import pandas as pd
import json
import yaml
import re
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
import webbrowser
import statsmodels.formula.api as smf

### Load config file to get path to dataset 

In [2]:
# Read the preprocessing configuration (YAML) to locate the dataset
config_path = '../../config/config_preprocessing.yaml'
with open(config_path, 'r') as config_file:
    config = yaml.safe_load(config_file)

# Root folder of the dataset, used by all the cells below
path_data = config['path_data']

### Get the `participants.tsv` file from the dataset

In [3]:
# Read the tab-separated participants table from the dataset root and display it
participants_tsv_path = os.path.join(path_data, 'participants.tsv')
participants_tsv = pd.read_csv(participants_tsv_path, sep='\t')
participants_tsv

Unnamed: 0,participant_id,age,sex,group,scan_series,height,weight
0,sub-101,17,M,control,complete,1.778004,68.038864
1,sub-102,15,F,control,complete,1.625603,52.163129
2,sub-103,15,M,control,complete,1.651003,54.431091
3,sub-104,15,F,control,complete,1.625603,52.163129
4,sub-105,13,M,control,complete,1.524000,35.381000
...,...,...,...,...,...,...,...
110,sub-214,6,M,control,complete,,
111,sub-215,16,F,control,complete,,
112,sub-216,15,F,control,complete,,
113,sub-217,15,M,control,complete,,


## Function that gets the list of subjects used in the analysis pipelines

This function starts from the subjects listed in the `participants.tsv` file and removes those listed in the `exclude.yml` file (for the given contrast) as well as those with missing DWI data.

In [4]:
def get_list_of_subjects_to_include(contrast, path_data, missing_data_subjects):
    """
    Return the participants to include in the analysis of a given image contrast.

    The subject list comes from the dataset's `participants.tsv` file (BIDS layout). A subject
    is removed when:
      - an entry under the `contrast` key of the dataset's `exclude.yml` file starts with its
        subject ID (e.g. an entry 'sub-101_...' excludes 'sub-101'), or
      - it appears in `missing_data_subjects`.

    Parameters
    ----------
    contrast : str
        Key of `exclude.yml` to read (e.g. 'dwi', 'T2w').
    path_data : str
        Root folder of the dataset; must contain `participants.tsv` and `exclude.yml`.
    missing_data_subjects : iterable of str
        Subject IDs (e.g. 'sub-125') to exclude because their data are missing.

    Returns
    -------
    pandas.DataFrame
        The rows of `participants.tsv` for the included subjects, in their original order
        and with their original index.
    """
    # Read the participants table
    participants_tsv = pd.read_csv(os.path.join(path_data, 'participants.tsv'), sep='\t')

    # Read the exclusion lists (one list per contrast key)
    with open(os.path.join(path_data, 'exclude.yml'), 'r') as file:
        exclude_yml = yaml.safe_load(file)

    # An empty exclude.yml loads as None, and a key with no entries also loads as None:
    # treat both as "nothing to exclude" instead of crashing.
    exclude_entries = (exclude_yml or {}).get(contrast) or []

    # Keep only the leading subject ID of each entry; entries without one are ignored
    subject_pattern = re.compile(r"(sub-\d+)")
    exclude_subjects = set()
    for entry in exclude_entries:
        match = subject_pattern.match(str(entry))
        if match:
            exclude_subjects.add(match.group(1))

    # Subjects with missing data are excluded as well
    exclude_subjects.update(missing_data_subjects)

    # Keep every participant whose ID is not in the exclusion set
    include_mask = ~participants_tsv['participant_id'].isin(exclude_subjects)
    return participants_tsv[include_mask]

In [5]:
# Subjects without DWI data; they are excluded in addition to those listed in exclude.yml
missing_dwi_subjects = [
    "sub-125",
    "sub-152",
    "sub-174",
    "sub-200",
    "sub-205",
    "sub-213",
]

# Build the include list for the DWI analysis and save it next to the DTI tables.
# NOTE(review): the file is tab-separated despite its .csv extension; read it with sep='\t'.
include_dwi_subjects = get_list_of_subjects_to_include('dwi', path_data, missing_dwi_subjects)
include_dwi_subjects.to_csv('../tables/DTI/include_dwi_subjects.csv', sep='\t', index=False)

n_included = len(include_dwi_subjects)
print(n_included, "subjects to include in the analysis")

70 subjects to include in the analysis


In [6]:
# Display the participants retained for the DWI analysis
include_dwi_subjects

Unnamed: 0,participant_id,age,sex,group,scan_series,height,weight
0,sub-101,17,M,control,complete,1.778004,68.038864
1,sub-102,15,F,control,complete,1.625603,52.163129
2,sub-103,15,M,control,complete,1.651003,54.431091
3,sub-104,15,F,control,complete,1.625603,52.163129
4,sub-105,13,M,control,complete,1.524000,35.381000
...,...,...,...,...,...,...,...
106,sub-210,6,F,control,complete,,
108,sub-212,12,F,control,complete,,
111,sub-215,16,F,control,complete,,
112,sub-216,15,F,control,complete,,


## Plot demographics

This function plots the age and sex distribution of the subjects included in a pipeline analysis, according to the include list generated above. 

In [7]:
def plot_demographics(df):
    """
    Plot a stacked histogram of participant age (1-year bins), split by sex.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per participant, with an 'age' column (years) and a 'sex' column.
        Only rows with sex 'M' or 'F' are plotted.

    The input dataframe is not modified.
    """
    # Work on a copy so the caller's dataframe (e.g. the include list) is left untouched
    df = df.copy()

    # Truncate ages to whole years BEFORE splitting by sex, so both histograms use them
    df['age'] = np.floor(df['age'])

    # Split by sex
    df_M = df[df['sex'] == 'M']
    df_F = df[df['sex'] == 'F']

    # Create subplot
    fig = make_subplots(rows=1, cols=1)

    # Add histogram for female subjects
    fig.add_trace(go.Histogram(
        x=df_F['age'],
        name='F',
        marker=dict(color="#1E59BE"),
        opacity=1.0,
        legendgroup='F',
        ),
        row=1, col=1
    )

    # Add histogram for male subjects
    fig.add_trace(go.Histogram(
        x=df_M['age'],
        name='M',
        marker=dict(color="#06273B"),
        opacity=1.0,
        legendgroup='M',
        ),
        row=1, col=1
    )

    # Age ticks: 6 to 17 years
    tick_vals = list(range(6, 18))

    # Update layout
    fig.update_layout(
        width=700,
        height=500,
        font=dict(family='Arial', size=20, color='black'),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.0,
            xanchor="center",
            x=0.5,
        ),
        xaxis=dict(
            range=[5, 18],  # Leave a margin around the 6-17 tick range
        ),
        plot_bgcolor='white',
        barmode='stack',
        bargap=0.3,
        xaxis_title='Age (years)',
        xaxis_title_font=dict(family='Arial', size=20, weight='bold'),
        yaxis_title='Number of Participants',
        yaxis_title_font=dict(family='Arial', size=20, weight='bold'),
        xaxis_title_standoff=20,
    )

    fig.update_xaxes(
        tickmode='array',
        tickvals=tick_vals,
        showgrid=False,
        gridwidth=1
    )

    fig.update_yaxes(
        showgrid=True,             # Horizontal grid lines
        gridcolor='lightgrey',
        gridwidth=1
    )

    # Set bin size to 1 year
    fig.update_traces(xbins=dict(size=1))

    fig.show()

In [8]:
# Plot the age and sex distribution of the subjects included in the DWI analysis
plot_demographics(include_dwi_subjects)

## Get average DTI metrics 

The cell below builds one dataframe per DTI metric from the per-subject CSV files stored under `../tables/DTI/<metric>/` (e.g. `sub-101_FA.csv`), and adds the demographic information of the included subjects.

In [9]:
# Folder holding one sub-folder per DTI metric, each with one CSV per subject
# (file names start with the subject ID, e.g. 'sub-101_FA.csv')
DTI_folder = "../tables/DTI/"
metrics = ['FA', 'MD', 'AD', 'RD']

DTI_df = {}

for metric in metrics:
    metric_folder = os.path.join(DTI_folder, metric)

    # Sort file names so the row order (and any exported table) is reproducible;
    # os.listdir returns entries in an arbitrary, filesystem-dependent order.
    csv_files = sorted(name for name in os.listdir(metric_folder) if name.endswith(".csv"))
    if not csv_files:
        raise FileNotFoundError(f"No CSV files found in {metric_folder}")

    metric_dfs = []
    for filename in csv_files:
        df = pd.read_csv(os.path.join(metric_folder, filename))
        # Subject ID is the file-name prefix before the first underscore (e.g. 'sub-101')
        df["participant_id"] = filename.split("_")[0]
        metric_dfs.append(df)

    # Combine the CSV files of all subjects into a single dataframe for this metric
    DTI_df[metric] = pd.concat(metric_dfs, ignore_index=True)

# Add demographics (age, sex, ...) and keep only subjects on the include list.
# An inner merge drops subjects that have metric files but were excluded (exclude.yml or
# missing data); a left merge would keep them with NaN age/sex.
for metric in DTI_df:
    DTI_df[metric] = DTI_df[metric].merge(include_dwi_subjects, on="participant_id", how="inner")


In [10]:
# Keep a handle on the FA table and export it (written to the current working directory)
df = DTI_df['FA']
df.to_csv('fa_values.csv', index=False)

# Mean ± SD of DTI values per age, per vertebral level

In [11]:
def get_mean_and_std_by_age(df, vertlevel, label, metric_col):
    """
    Summarise one DTI metric by age for a single label and vertebral level.

    Rows of `df` are kept when 'Label' equals `label` and 'VertLevel' equals `vertlevel`.
    Their 'WA()' values are then grouped by 'age'.

    Returns a DataFrame with columns:
      - 'Age'      : the age group,
      - metric_col : the string "mean ± std", each part formatted with 2 significant digits
                     (std is NaN, shown as 'nan', for a group with a single value),
      - 'N'        : the number of non-missing 'WA()' values in the group.
    """
    selection = df.loc[(df['Label'] == label) & (df['VertLevel'] == vertlevel)]

    stats = (
        selection.groupby('age')['WA()']
        .agg(['mean', 'std', 'count'])
        .rename(columns={'count': 'N'})
        .reset_index()
        .rename(columns={'age': 'Age'})
    )

    stats[metric_col] = [f"{m:.2g} ± {s:.2g}" for m, s in zip(stats['mean'], stats['std'])]

    return stats[['Age', metric_col, 'N']]

In [12]:
def plot_mean_std_table(DTI_df, label, path_csv=None):
    """
    Show per-age tables of mean ± std for every DTI metric (one table per vertebral level)
    and save the same numbers to a CSV file.

    Parameters
    ----------
    DTI_df : dict of str -> pandas.DataFrame
        Metric name -> per-subject table with 'Label', 'VertLevel', 'age' and 'WA()' columns.
        Each metric becomes one column of the tables, in dict order.
    label : str
        Region to summarise (e.g. 'spinal cord', 'white matter', 'gray matter').
    path_csv : str, optional
        Output CSV path. Defaults to
        '../tables/DTI/DTI_all_levels_table_<label>.csv' (spaces replaced by underscores),
        so that calls for different labels do not overwrite each other.

    Raises
    ------
    ValueError
        If `DTI_df` contains no metric.
    """
    metric_names = list(DTI_df)
    if not metric_names:
        raise ValueError("DTI_df must contain at least one metric")

    # Label-specific default: a single shared file name would be overwritten by every call,
    # keeping only the table of the last label processed
    if path_csv is None:
        path_csv = f"../tables/DTI/DTI_all_levels_table_{label.replace(' ', '_')}.csv"

    # Vertebral levels to tabulate, one table each
    vertlevels = ['3:5', '2', '3', '4', '5', '6']

    # Per-level results, concatenated into a single CSV file at the end
    all_results = []

    # Create subplots: one row per vertebral level, 1 column
    fig = make_subplots(
        rows=len(vertlevels), cols=1,
        shared_xaxes=False,
        vertical_spacing=0.02,
        specs=[[{"type": "table"}] for _ in vertlevels]
    )

    for i, vertlevel in enumerate(vertlevels):
        combined_df = None
        subject_counts = None

        for metric in metric_names:
            result = get_mean_and_std_by_age(DTI_df[metric], vertlevel, label, metric_col=metric)

            # The number of subjects (N) is taken from the first metric only
            if subject_counts is None:
                subject_counts = result[['N', 'Age']]
            result = result.drop(columns=['N'])

            # Place the metric columns side by side (inner join on age)
            if combined_df is None:
                combined_df = result
            else:
                combined_df = pd.merge(combined_df, result, on='Age')

        # Append the number of subjects (N) as the last column
        combined_df = pd.merge(combined_df, subject_counts, on='Age')

        # CSV copy: numeric ages plus the vertebral level
        combined_df_for_CSV = combined_df.copy()
        combined_df_for_CSV['VertLevel'] = f"C{vertlevel}"
        all_results.append(combined_df_for_CSV)

        # Display copy: ages rounded to integers and shown in bold
        combined_df['Age'] = combined_df['Age'].round().astype(int).map(lambda x: f"<b>{x}</b>")

        # Plot
        fig.add_trace(
            go.Table(
                # Narrow Age and N columns, one wide column per metric
                columnwidth=[80] + [200] * len(metric_names) + [80],
                header=dict(
                    values=list(combined_df.columns),
                    fill_color='lightgrey',
                    align='center',
                    font=dict(size=16, color='black', family='Arial', weight="bold")
                ),
                cells=dict(
                    values=[combined_df[col] for col in combined_df.columns],
                    height=30,
                    fill_color='white',
                    align='center',
                    font=dict(size=16, family='Arial', color='black')
                )
            ),
            row=i+1, col=1
        )

        # Add vertebral level annotation on the left side of each row
        fig.add_annotation(
            text=f"C{vertlevel}",
            xref="paper",
            yref="paper",
            x=0,
            y=1 - (i + 0.5) / len(vertlevels),
            showarrow=False,
            font=dict(size=26, color="black", family="Arial", weight="bold"),
            xanchor="right",
            align="right"
        )

    # Table layout
    fig.update_layout(
        height=480 * len(vertlevels),
        width=1200,
        showlegend=False,
        margin=dict(l=80)
    )

    fig.show()

    # Concatenate all levels, put VertLevel first, and save to a single CSV file
    final_df = pd.concat(all_results, ignore_index=True)
    cols = ['VertLevel'] + [col for col in final_df.columns if col != 'VertLevel']
    final_df = final_df[cols]
    final_df.to_csv(path_csv, index=False)


In [13]:
# Mean ± std tables of all DTI metrics for the 'spinal cord' label
plot_mean_std_table(DTI_df, 'spinal cord')

In [14]:
# Mean ± std tables of all DTI metrics for the 'white matter' label
plot_mean_std_table(DTI_df, 'white matter')

In [15]:
# Mean ± std tables of all DTI metrics for the 'gray matter' label
plot_mean_std_table(DTI_df, 'gray matter')

# Plot DTI metrics by age, per vertebral level

In [16]:
def plot_dti_per_age_SC_WM_GM(df, metric_name):
    """
    Plot one DTI metric against age in a 3x3 grid.

    Rows are vertebral levels C3, C4 and C5. Columns are the 'spinal cord', 'white matter'
    and 'gray matter' labels. Each subplot shows the individual values of female (F) and
    male (M) subjects with a separate linear fit (metric ~ age) per sex, plus a dashed
    linear fit over all subjects. Shaded bands are the 95% confidence intervals of the
    fitted mean.

    Parameters
    ----------
    df : pandas.DataFrame
        Per-subject metric table with 'WA()', 'VertLevel', 'Label', 'age' and 'sex' columns.
    metric_name : str
        Metric name (e.g. 'FA'). It is used for axis titles and the y-axis range, and as
        the regression variable name, so it must be a valid formula identifier.
    """
    df = df.copy()
    df.rename(columns={'WA()': metric_name}, inplace=True)

    # Normalize types: numeric ages, and string vertebral levels so that comparisons with
    # '3', '4', '5' also work if pandas parsed the column as numbers
    df['age'] = pd.to_numeric(df['age'], errors='coerce')
    df['VertLevel'] = df['VertLevel'].astype(str)

    # Labels (columns) and vertebral levels (rows) to plot
    labels = ['spinal cord', 'white matter', 'gray matter']
    vert_levels = ['3', '4', '5']

    # Marker/line colors for female, male and all subjects
    colors = {
        'M': "#06273B",
        'F': "#1E59BE",
        'All': "#7A7A7A"
    }

    # Translucent fill colors for the confidence bands
    ci_colors = {
        'M': 'rgba(13, 62, 93, 0.2)',
        'F': 'rgba(30, 89, 190, 0.2)',
        'All': 'rgba(122, 122, 122, 0.15)'
    }

    # Fixed y-axis range per metric, so subplots are directly comparable
    metric_y_ranges = {
        'FA': [0.3, 0.9],
        'MD': [0.0004, 0.0016],
        'RD': [0.0002, 0.0012],
        'AD': [0.0010, 0.0024],
    }

    fig = make_subplots(
        rows=3,
        cols=3,
        shared_xaxes=False,
        shared_yaxes=False,
        vertical_spacing=0.12,
        horizontal_spacing=0.1
    )

    for row_idx, vert in enumerate(vert_levels, start=1):
        for col_idx, label in enumerate(labels, start=1):

            # Rows for this vertebral level and label that can be plotted and fitted
            data = df[(df['VertLevel'] == vert) & (df['Label'] == label)]
            data = data.dropna(subset=['age', metric_name])

            if not data.empty:
                # Ages at which the fitted lines and confidence bands are evaluated
                x_range = np.linspace(data['age'].min(), data['age'].max(), 100)
                pred_df = pd.DataFrame({'age': x_range})

                for group in ['F', 'M', 'All']:
                    group_df = data if group == 'All' else data[data['sex'] == group]

                    # Individual values (per sex only)
                    if group != 'All':
                        fig.add_trace(
                            go.Scatter(
                                x=group_df['age'],
                                y=group_df[metric_name],
                                mode='markers',
                                marker=dict(color=colors[group], size=11, symbol='circle', opacity=0.8),
                                name=group,
                                legendgroup=group,
                                showlegend=(row_idx == 1 and col_idx == 1)
                            ),
                            row=row_idx, col=col_idx
                        )

                    # A straight-line fit needs at least two observations
                    if len(group_df) < 2:
                        continue

                    # OLS fit of the metric vs age within this group. Each sex is fitted on its
                    # own subset, so a sex term would be constant and is not included.
                    ols_results = smf.ols(formula=f'{metric_name} ~ age', data=group_df).fit()
                    pred_summary = ols_results.get_prediction(pred_df).summary_frame(alpha=0.05)
                    y_fit = pred_summary['mean'].to_numpy()
                    ci_lower = pred_summary['mean_ci_lower'].to_numpy()
                    ci_upper = pred_summary['mean_ci_upper'].to_numpy()

                    # Fit line: solid per sex, dashed for all subjects
                    fig.add_trace(
                        go.Scatter(
                            x=x_range,
                            y=y_fit,
                            mode='lines',
                            line=dict(color=colors[group], width=2,
                                      dash='dash' if group == 'All' else 'solid'),
                            name='All subjects' if group == 'All' else f"{group} - Fit",
                            legendgroup=group,
                            showlegend=False
                        ),
                        row=row_idx, col=col_idx
                    )

                    # 95% confidence band of the fitted mean (closed polygon)
                    fig.add_trace(
                        go.Scatter(
                            x=np.concatenate([x_range, x_range[::-1]]),
                            y=np.concatenate([ci_upper, ci_lower[::-1]]),
                            fill='toself',
                            fillcolor=ci_colors[group],
                            line=dict(color='rgba(255,255,255,0)'),
                            hoverinfo='skip',
                            showlegend=False,
                            legendgroup=group,
                        ),
                        row=row_idx, col=col_idx
                    )

            # Add subplot title
            fig.add_annotation(
                text=f"C{vert} - {label}",
                xref="x domain",
                yref="y domain",
                x=0.5,
                y=1.1,
                showarrow=False,
                font=dict(size=18, family="Arial", weight='bold'),
                row=row_idx,
                col=col_idx
            )

    # Update layout
    fig.update_layout(
        height=1200,
        width=1350,
        plot_bgcolor="#ffffff",
        paper_bgcolor="#ffffff",
        legend=dict(
            x=1.05,
            y=1,
            bgcolor="#ffffff",
            bordercolor="#ffffff",
            font=dict(family='Arial', color='black', size=30),
        )
    )

    # Axis formatting (same ticks and ranges on every subplot)
    for r in range(1, 4):
        for c in range(1, 4):
            fig.update_xaxes(
                title_text='Age',
                title_font=dict(family='Arial', size=20, color='black', weight='bold'),
                tickfont=dict(family='Arial', size=18),
                tickvals=list(range(6, 18)),
                range=[5, 18],
                row=r, col=c
            )
            fig.update_yaxes(
                title_text=f'{metric_name}',
                title_font=dict(family='Arial', size=20, color='black', weight='bold'),
                tickfont=dict(family='Arial', size=18),
                range=metric_y_ranges.get(metric_name, None),
                row=r, col=c
            )

    fig.show()


In [17]:
# One 3x3 figure (C3-C5 x spinal cord / white matter / gray matter) per DTI metric
for metric in ['FA', 'MD', 'AD', 'RD']:
    plot_dti_per_age_SC_WM_GM(DTI_df[metric], metric_name=metric)

In [18]:
def plot_dti_per_age_per_vertlevel(df, metric_name, label):
    """
    Plot one DTI metric against age for a single label, one subplot per vertebral level.

    The levels '3:5', '2', '3', '4', '5' and '6' fill a 3x2 grid column by column. The
    '3:5' entry is presumably C3-C5 aggregated; confirm against the metric extraction step.
    Each subplot shows the individual values of female (F) and male (M) subjects with a
    separate linear fit (metric ~ age) per sex, plus a dashed linear fit over all subjects.
    Shaded bands are the 95% confidence intervals of the fitted mean.

    Parameters
    ----------
    df : pandas.DataFrame
        Per-subject metric table with 'WA()', 'VertLevel', 'Label', 'age' and 'sex' columns.
    metric_name : str
        Metric name (e.g. 'FA'). It is used for axis titles and the y-axis range, and as
        the regression variable name, so it must be a valid formula identifier.
    label : str
        Region to plot (e.g. 'spinal cord', 'white matter', 'gray matter').
    """
    df = df.copy()
    df.rename(columns={'WA()': metric_name}, inplace=True)

    # Filter for the given label (i.e., spinal cord, white matter, gray matter)
    df = df[df['Label'] == label].copy()

    # Normalize types: numeric ages, string vertebral levels (to match '3:5', '2', ...)
    df['age'] = pd.to_numeric(df['age'], errors='coerce')
    df['VertLevel'] = df['VertLevel'].astype(str)

    # 3x2 grid; vertebral levels in display order, and where each one goes (column-major)
    rows, cols = 3, 2
    vert_levels = ['3:5', '2', '3', '4', '5', '6']
    subplot_positions = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (3, 2)]

    # Marker/line colors for female, male and all subjects
    colors = {
        'M': "#06273B",
        'F': "#1E59BE",
        'All': "#7A7A7A"
    }

    # Translucent fill colors for the confidence bands
    ci_colors = {
        'M': 'rgba(13, 62, 93, 0.2)',
        'F': 'rgba(30, 89, 190, 0.2)',
        'All': 'rgba(122, 122, 122, 0.15)'
    }

    # Fixed y-axis range per metric, so subplots are directly comparable
    metric_y_ranges = {
        'FA': [0.3, 0.9],
        'MD': [0.0005, 0.0016],
        'RD': [0.0002, 0.0012],
        'AD': [0.0014, 0.0022],
    }

    fig = make_subplots(
        rows=rows,
        cols=cols,
        shared_xaxes=False,
        vertical_spacing=0.13,
        horizontal_spacing=0.2
    )

    for i, vert in enumerate(vert_levels):
        row, col = subplot_positions[i]

        # Rows for this vertebral level that can be plotted and fitted
        data = df[df['VertLevel'] == vert].dropna(subset=['age', metric_name])
        if data.empty:
            continue

        # Ages at which the fitted lines and confidence bands are evaluated
        x_range = np.linspace(data['age'].min(), data['age'].max(), 100)
        pred_df = pd.DataFrame({'age': x_range})

        for group in ['F', 'M', 'All']:
            group_df = data if group == 'All' else data[data['sex'] == group]

            # Individual values (per sex only); legend entries come from the first subplot
            if group != 'All':
                fig.add_trace(
                    go.Scatter(
                        x=group_df['age'], y=group_df[metric_name],
                        mode='markers',
                        marker=dict(color=colors[group], size=9, symbol='circle', opacity=0.8),
                        name=group,
                        legendgroup=group,
                        showlegend=(i == 0),
                    ),
                    row=row, col=col
                )

            # A straight-line fit needs at least two observations
            if len(group_df) < 2:
                continue

            # OLS fit of the metric vs age within this group. Each sex is fitted on its own
            # subset, so a sex term would be constant and is not included.
            ols_results = smf.ols(formula=f'{metric_name} ~ age', data=group_df).fit()
            pred_summary = ols_results.get_prediction(pred_df).summary_frame(alpha=0.05)
            y_fit = pred_summary['mean'].to_numpy()
            ci_lower = pred_summary['mean_ci_lower'].to_numpy()
            ci_upper = pred_summary['mean_ci_upper'].to_numpy()

            # Fit line: solid per sex, dashed for all subjects
            fig.add_trace(
                go.Scatter(
                    x=x_range, y=y_fit,
                    mode='lines',
                    line=dict(color=colors[group], width=2,
                              dash='dash' if group == 'All' else 'solid'),
                    name='All subjects' if group == 'All' else f"{group} - Fit",
                    legendgroup=group,
                    showlegend=False,
                ),
                row=row, col=col
            )

            # 95% confidence band of the fitted mean (closed polygon)
            fig.add_trace(
                go.Scatter(
                    x=np.concatenate([x_range, x_range[::-1]]),
                    y=np.concatenate([ci_upper, ci_lower[::-1]]),
                    fill='toself',
                    fillcolor=ci_colors[group],
                    line=dict(color='rgba(255,255,255,0)'),
                    hoverinfo='skip',
                    showlegend=False,
                    legendgroup=group,
                ),
                row=row, col=col
            )

    # Layout
    fig.update_layout(
        height=1000,
        width=1100,
        plot_bgcolor="#ffffff",
        paper_bgcolor="#ffffff",
        legend=dict(
            x=1.02,
            y=1,
            bgcolor="#ffffff",
            bordercolor="#ffffff",
            font=dict(family='Arial', size=20)
        )
    )

    # Add axis titles and tick labels
    for r in range(1, rows+1):
        for c in range(1, cols+1):
            fig.update_xaxes(
                title_text='Age',
                title_font=dict(family='Arial', size=20, weight='bold'),
                tickfont=dict(family='Arial', size=20),
                tickvals=list(range(6, 18)),  # All ticks from 6 to 17
                range=[5, 18],  # Start the x-axis at 5 to leave space before the first tick
                row=r, col=c
            )
            fig.update_yaxes(
                title_text=f'{metric_name} in {label}',
                title_font=dict(family='Arial', size=20, weight='bold'),
                tickfont=dict(family='Arial', size=20),
                range=metric_y_ranges.get(metric_name, None),
                row=r, col=c
            )

    # Add vertebral level on top of each subplot
    for i, vert in enumerate(vert_levels):
        row, col = subplot_positions[i]
        fig.add_annotation(
            text=f"C{vert}",
            xref="x domain",
            yref="y domain",
            x=0.5,
            y=1.1,
            showarrow=False,
            font=dict(size=26, family="Arial", weight='bold'),
            row=row,
            col=col
        )

    fig.show()


# DTI metrics in the spinal cord 

In [19]:
# One figure per DTI metric for the spinal cord
for metric in ['FA', 'MD', 'AD', 'RD']:
    plot_dti_per_age_per_vertlevel(DTI_df[metric], metric_name=metric, label='spinal cord')

# DTI metrics in the gray matter

In [20]:
# One figure per DTI metric for the gray matter
for metric in ['FA', 'MD', 'AD', 'RD']:
    plot_dti_per_age_per_vertlevel(DTI_df[metric], metric_name=metric, label='gray matter')

# DTI metrics in the white matter

In [21]:
# One figure per DTI metric for the white matter
for metric in ['FA', 'MD', 'AD', 'RD']:
    plot_dti_per_age_per_vertlevel(DTI_df[metric], metric_name=metric, label='white matter')