In [None]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Metric columns that can be plotted; the subplot functions lay them out in a
# 2x3 grid in this order (row-major)
METRICS = ['MEAN(area)', 'MEAN(diameter_AP)', 'MEAN(diameter_RL)', 'MEAN(compression_ratio)', 'MEAN(eccentricity)', 'MEAN(solidity)']

# Metric plotted by default by create_lineplot ('MEAN(diameter_AP)')
INITIAL_METRIC = METRICS[1]

# Plot title suffix per metric (create_lineplot prefixes it with "Spinal Cord ")
METRIC_TO_TITLE = {
    'MEAN(area)': 'Cross-Sectional Area',
    'MEAN(diameter_AP)': 'AP Diameter',
    'MEAN(diameter_RL)': 'RL Diameter',
    'MEAN(compression_ratio)': 'Compression Ratio',
    'MEAN(eccentricity)': 'Eccentricity',
    'MEAN(solidity)': 'Solidity',
}

# y-axis label (including units) per metric
METRIC_TO_AXIS = {
    'MEAN(diameter_AP)': 'AP Diameter [mm]',
    'MEAN(area)': 'Cross-Sectional Area [mm²]',
    'MEAN(diameter_RL)': 'RL Diameter [mm]',
    'MEAN(eccentricity)': 'Eccentricity [a.u.]',
    'MEAN(solidity)': 'Solidity [%]',
    'MEAN(compression_ratio)': 'Compression Ratio [a.u.]',
}

# Fixed y-axis range (min, max) per metric, chosen so that the horizontal grid
# lines do not overlap the vertebral-level labels drawn near the bottom of each plot
METRICS_TO_YLIM = {
    'MEAN(diameter_AP)': (5.7, 9.3),
    'MEAN(area)': (35, 95),
    'MEAN(diameter_RL)': (8.5, 14.5),
    'MEAN(eccentricity)': (0.51, 0.89),
    'MEAN(solidity)': (90, 99.9),
    'MEAN(compression_ratio)': (0.41, 0.84),
}

# Offset added to the lower y-limit to obtain the y position of the
# vertebral-level text labels (C1, C2, ..., T1)
METRICS_TO_YLIM_OFFSET = {
    'MEAN(diameter_AP)': 0.15,
    'MEAN(area)': 2.5,
    'MEAN(diameter_RL)': 0.25,
    'MEAN(eccentricity)': 0.015,
    'MEAN(solidity)': 0.4,
    'MEAN(compression_ratio)': 0.02,
}

# Tick values (axial slice numbers).
# NOTE: despite the "Y" in the name, the subplot functions in this file apply
# these to the x-axis ("Axial Slice #").
YTICKVALS = [950, 900, 850, 800, 750, 700]

# Font sizes: axis titles / legend, tick labels, and the smaller font used in
# the 2x3 subplot grids
LABELS_FONT_SIZE = 16
TICKS_FONT_SIZE = 16
TICKS_FONT_SIZE_SUBPLOT=14


# Legend label for each category of each supported hue column
LEGEND_ITEMS = {
    'sex': {'M': 'Males', 'F': 'Females'},
    'manufacturer': {'Siemens': 'Siemens', 'Philips': 'Philips', 'GE': 'GE'},
    'age': {'10-20': '10-20', '21-30': '21-30', '31-40': '31-40', '41-50': '41-50', '51-60': '51-60'},
    }

# Line color for each category of each supported hue column
PALETTE = {
    'sex': {'M': 'blue', 'F': 'red'},
    'manufacturer': {'Siemens': 'green', 'Philips': 'dodgerblue', 'GE': 'black'},
    'age': {'10-20': 'blue', '21-30': 'green', '31-40': 'black', '41-50': 'red', '51-60': 'purple'},
    }

# Same colors as PALETTE with 0.2 opacity -- used as fillcolor of the STD bands
PALETTE_RGBA = {
    'sex': {'M': 'rgba(0, 0, 255, 0.2)', 'F': 'rgba(255, 0, 0, 0.2)'},
    'manufacturer': {'Siemens': 'rgba(0, 128, 0, 0.2)',  'Philips': 'rgba(30, 144, 255, 0.2)', 'GE': 'rgba(0, 0, 0, 0.2)'},
    'age': {'10-20': 'rgba(0, 0, 255, 0.2)', '21-30': 'rgba(0, 128, 0, 0.2)', '31-40': 'rgba(0, 0, 0, 0.2)', '41-50': 'rgba(255, 0, 0, 0.2)', '51-60': 'rgba(128, 0, 128, 0.2)'}
    }


def create_lineplot(df, metric=INITIAL_METRIC, df_single_subject=None, hue=None, show_cv=False):
    """
    Create a lineplot (mean +/- STD across subjects) of the selected metric per axial
    slice, with dashed lines at vertebral-level boundaries and C/T level labels.
    Note: we are plotting slices not levels to avoid averaging across levels.
    Args:
        df (pd.DataFrame): dataframe with metric values across all healthy subjects
        metric (str): metric column to plot; must be a key of METRIC_TO_TITLE,
            METRIC_TO_AXIS, METRICS_TO_YLIM and METRICS_TO_YLIM_OFFSET
            (e.g. 'MEAN(area)', 'MEAN(compression_ratio)')
        df_single_subject (pd.DataFrame): dataframe with metric values for a single
            subject. NOTE: accepted for API compatibility but currently unused.
        hue (str): column name to use for grouping. NOTE: currently unused; see
            create_subplot_hue for grouped plots.
        show_cv (bool): if True, show coefficient of variation. NOTE: currently unused.
    """
    slice_col = "Slice (I->S)"
    ylim_low, ylim_high = METRICS_TO_YLIM[metric]

    # Mean and STD across subjects for each slice. x is taken from the groupby index
    # so that every aggregated y value is drawn at its own slice; the raw column has
    # one entry per (subject, slice) row and does not line up with the aggregates.
    grouped = df.groupby(slice_col)[metric]
    mean = grouped.mean()
    std = grouped.std()
    slices = mean.index

    fig = go.Figure()
    # Upper bound of the mean +/- STD band (hidden from the legend)
    fig.add_trace(go.Scatter(x=slices,
                             y=mean + std,
                             mode='lines',
                             line=dict(color='blue', width=0.1),
                             name='STD',
                             showlegend=False))
    # Lower bound; fill='tonexty' shades the area up to the previous trace (upper bound)
    fig.add_trace(go.Scatter(x=slices,
                             y=mean - std,
                             mode='lines',
                             line=dict(color='blue', width=0.1),
                             fill='tonexty',
                             name='STD',
                             showlegend=False))
    # Mean across subjects
    fig.add_trace(go.Scatter(x=slices,
                             y=mean,
                             mode='lines',
                             line=dict(color='blue'),
                             name='Mean'))

    fig.update_layout(
        width=800,
        height=600,
        xaxis_title="Vertebral Level (S->I)",
        yaxis_title=METRIC_TO_AXIS[metric],
        # Reverse the x-axis
        xaxis=dict(autorange="reversed",
                   title_font=dict(size=LABELS_FONT_SIZE),
                   tickfont=dict(size=TICKS_FONT_SIZE)),
        # Fixed y-range (see METRICS_TO_YLIM) so the level labels do not collide with the grid
        yaxis=dict(range=[ylim_low, ylim_high],
                   title_font=dict(size=LABELS_FONT_SIZE),
                   tickfont=dict(size=TICKS_FONT_SIZE)),
        title={
            "text": "Spinal Cord " + METRIC_TO_TITLE[metric],
            "x": 0.5,  # Center-align the title
            "y": 0.9,
        },
        title_font_size=LABELS_FONT_SIZE,
        legend=dict(font=dict(size=LABELS_FONT_SIZE)),
    )

    # Vertebral boundaries and mid-levels of the reference subject
    vert, ind_vert, ind_vert_mid = get_vert_indices(df)

    # Dashed vertical line at each inner boundary; the first and last entries of
    # ind_vert are the ends of the covered range and are skipped
    for boundary in ind_vert[1:-1]:
        x_boundary = df.loc[boundary, slice_col]
        fig.add_trace(
            go.Scatter(
                x=[x_boundary, x_boundary],
                y=[ylim_low, ylim_high],
                mode='lines',
                line=dict(color='black', width=1, dash='dash'),
                showlegend=False,
                hoverinfo='none',
            )
        )

    # Text label at each mid-level: VertLevel values above 7 are labelled thoracic
    # (8 -> T1), the others cervical (C<n>)
    for mid in ind_vert_mid:
        level_num = vert[mid]
        level = 'T' + str(level_num - 7) if level_num > 7 else 'C' + str(level_num)
        fig.add_annotation(
            x=df.loc[mid, slice_col],
            y=ylim_low + METRICS_TO_YLIM_OFFSET[metric],
            text=level,
            showarrow=False,
            font=dict(size=TICKS_FONT_SIZE),
        )

    fig.show()
    

def create_subplot(df):
    """
    Create a 2x3 subplot grid with lineplots (mean +/- STD across subjects) of all
    MRI metrics in METRICS per axial slice, with dashed lines at vertebral-level
    boundaries and C/T level labels.
    Note: we are plotting slices not levels to avoid averaging across levels.
    Args:
        df (pd.DataFrame): dataframe with metric values across all healthy subjects
    """
    slice_col = "Slice (I->S)"
    line_color = PALETTE['sex']['M']
    fill_color = PALETTE_RGBA['sex']['M']

    # Vertebral boundaries / mid-levels do not depend on the metric: compute them once
    vert, ind_vert, ind_vert_mid = get_vert_indices(df)

    fig = make_subplots(rows=2, cols=3, vertical_spacing=0.1, horizontal_spacing=0.08)

    for i, metric in enumerate(METRICS):
        # Row-major placement in the 2x3 grid (plotly rows/cols are 1-based)
        row = i // 3 + 1
        col = i % 3 + 1
        ylim_low, ylim_high = METRICS_TO_YLIM[metric]

        # Mean and STD across subjects per slice; x comes from the groupby index so
        # that each aggregated value is drawn at its own slice
        grouped = df.groupby(slice_col)[metric]
        mean = grouped.mean()
        std = grouped.std()
        slices = mean.index

        # Upper bound of the mean +/- STD band
        fig.add_trace(
            go.Scatter(
                x=slices,
                y=mean + std,
                mode='lines',
                line=dict(color=line_color, width=0.5),
                name='',
                hovertemplate='Mean+STD: %{y:.2f}' + '<br>Slice: %{x}',
                showlegend=False,
            ),
            row=row,
            col=col,
        )
        # Lower bound; fill='tonexty' shades the area up to the upper bound just added
        fig.add_trace(
            go.Scatter(
                x=slices,
                y=mean - std,
                mode='lines',
                line=dict(color=line_color, width=0.5),
                fill='tonexty',
                fillcolor=fill_color,
                name='',
                hovertemplate='Mean-STD: %{y:.2f}' + '<br>Slice: %{x}',
                showlegend=False,
            ),
            row=row,
            col=col,
        )
        # Mean across subjects
        fig.add_trace(
            go.Scatter(
                x=slices,
                y=mean,
                mode='lines',
                line=dict(color=line_color, width=3),
                name='',
                hovertemplate='Mean: %{y:.2f}' + '<br>Slice: %{x}',
                showlegend=False,
            ),
            row=row,
            col=col,
        )

        # Dashed vertical line at each inner vertebral boundary (ends of range skipped)
        for boundary in ind_vert[1:-1]:
            x_boundary = df.loc[boundary, slice_col]
            fig.add_trace(
                go.Scatter(
                    x=[x_boundary, x_boundary],
                    y=[ylim_low, ylim_high],
                    mode='lines',
                    line=dict(color='black', width=1, dash='dash'),
                    showlegend=False,
                    hoverinfo='none',
                ),
                row=row,
                col=col,
            )

        fig.update_xaxes(
            autorange="reversed",  # Reverse the x-axis for axial slices
            title="Axial Slice #",
            title_font=dict(size=TICKS_FONT_SIZE_SUBPLOT),
            tickfont=dict(size=TICKS_FONT_SIZE_SUBPLOT),
            title_standoff=0,
            showgrid=False,
            tickvals=YTICKVALS,  # slice-number ticks (used on the x-axis here)
            showline=True,
            linecolor='gray',
            showticklabels=True,
            row=row,
            col=col,
        )
        fig.update_yaxes(
            title=METRIC_TO_AXIS[metric],
            title_font=dict(size=TICKS_FONT_SIZE_SUBPLOT),
            tickfont=dict(size=TICKS_FONT_SIZE_SUBPLOT),
            title_standoff=0,
            range=[ylim_low, ylim_high],
            row=row,
            col=col,
            showgrid=True,
            gridcolor='lightgray',
        )

        # Text label at each mid-level: VertLevel values above 7 are labelled
        # thoracic (8 -> T1), the others cervical (C<n>)
        for mid in ind_vert_mid:
            level_num = vert[mid]
            level = 'T' + str(level_num - 7) if level_num > 7 else 'C' + str(level_num)
            fig.add_annotation(
                x=df.loc[mid, slice_col],
                y=ylim_low + METRICS_TO_YLIM_OFFSET[metric],
                text=level,
                showarrow=False,
                font=dict(size=TICKS_FONT_SIZE_SUBPLOT),
                row=row,
                col=col,
            )

    fig.update_layout(
        height=800,
        width=1300,
        plot_bgcolor='white',
    )

    fig.show()

    
def create_subplot_hue(df, hue):
    """
    Create a 2x3 subplot grid with lineplots of all MRI metrics in METRICS per axial
    slice, with one mean +/- STD band per category of `hue` ('age', 'sex', 'manufacturer').
    Note: we are plotting slices not levels to avoid averaging across levels.
    Args:
        df (pd.DataFrame): dataframe with metric values across all healthy subjects
        hue (str): column name of the dataframe to use for grouping; its values must be
            keys of PALETTE[hue], PALETTE_RGBA[hue] and LEGEND_ITEMS[hue]. If None, no
            grouping is applied and the ungrouped grid from create_subplot is shown.
    """
    if hue is None:
        # Honour the documented "no grouping" behavior instead of failing on df[None]
        return create_subplot(df)

    slice_col = "Slice (I->S)"

    # Sorted categories of the hue variable. sorted() does not rely on the array
    # returned by unique() supporting in-place .sort(); rows with a missing hue
    # value have no color/legend entry and are skipped.
    categories = sorted(df[hue].dropna().unique())
    print(categories)

    # Vertebral boundaries / mid-levels do not depend on the metric: compute them once
    vert, ind_vert, ind_vert_mid = get_vert_indices(df)

    fig = make_subplots(rows=2, cols=3, vertical_spacing=0.1, horizontal_spacing=0.08)

    for i, metric in enumerate(METRICS):
        # Row-major placement in the 2x3 grid (plotly rows/cols are 1-based)
        row = i // 3 + 1
        col = i % 3 + 1
        ylim_low, ylim_high = METRICS_TO_YLIM[metric]

        for category in categories:
            line_color = PALETTE[hue][category]
            legend_name = LEGEND_ITEMS[hue][category]

            # Mean and STD per slice within this category; x comes from the groupby
            # index so each value is drawn at its own slice even when categories
            # cover different slices
            grouped = df[df[hue] == category].groupby(slice_col)[metric]
            mean = grouped.mean()
            std = grouped.std()
            slices = mean.index

            # Upper bound of the mean +/- STD band
            fig.add_trace(
                go.Scatter(
                    x=slices,
                    y=mean + std,
                    mode='lines',
                    line=dict(color=line_color, width=0.5),
                    name=legend_name,
                    legendgroup=category,
                    hovertemplate='Mean+STD: %{y:.2f}' + '<br>Slice: %{x}',
                    showlegend=False,
                ),
                row=row,
                col=col,
            )
            # Lower bound; fill='tonexty' shades up to the previous trace, i.e. the
            # upper bound of the same category added just above
            fig.add_trace(
                go.Scatter(
                    x=slices,
                    y=mean - std,
                    mode='lines',
                    line=dict(color=line_color, width=0.5),
                    fill='tonexty',
                    fillcolor=PALETTE_RGBA[hue][category],
                    name=legend_name,
                    legendgroup=category,
                    hovertemplate='Mean-STD: %{y:.2f}' + '<br>Slice: %{x}',
                    showlegend=False,
                ),
                row=row,
                col=col,
            )
            # Mean; the legend entry is shown only in the first subplot to avoid
            # duplicated legend items (legendgroup ties the other subplots to it)
            fig.add_trace(
                go.Scatter(
                    x=slices,
                    y=mean,
                    mode='lines',
                    line=dict(color=line_color, width=3),
                    name=legend_name,
                    legendgroup=category,
                    hovertemplate='Mean: %{y:.2f}' + '<br>Slice: %{x}',
                    showlegend=(i == 0),
                ),
                row=row,
                col=col,
            )

        # Dashed vertical line at each inner vertebral boundary (ends of range skipped)
        for boundary in ind_vert[1:-1]:
            x_boundary = df.loc[boundary, slice_col]
            fig.add_trace(
                go.Scatter(
                    x=[x_boundary, x_boundary],
                    y=[ylim_low, ylim_high],
                    mode='lines',
                    line=dict(color='black', width=1, dash='dash'),
                    showlegend=False,
                    hoverinfo='none',
                ),
                row=row,
                col=col,
            )

        fig.update_xaxes(
            autorange="reversed",  # Reverse the x-axis for axial slices
            title="Axial Slice #",
            title_font=dict(size=TICKS_FONT_SIZE_SUBPLOT),
            tickfont=dict(size=TICKS_FONT_SIZE_SUBPLOT),
            title_standoff=0,
            showgrid=False,
            tickvals=YTICKVALS,  # slice-number ticks (used on the x-axis here)
            showline=True,
            linecolor='gray',
            showticklabels=True,
            row=row,
            col=col,
        )
        fig.update_yaxes(
            title=METRIC_TO_AXIS[metric],
            title_font=dict(size=TICKS_FONT_SIZE_SUBPLOT),
            tickfont=dict(size=TICKS_FONT_SIZE_SUBPLOT),
            title_standoff=0,
            range=[ylim_low, ylim_high],
            row=row,
            col=col,
            showgrid=True,
            gridcolor='lightgray',
        )

        # Text label at each mid-level: VertLevel values above 7 are labelled
        # thoracic (8 -> T1), the others cervical (C<n>)
        for mid in ind_vert_mid:
            level_num = vert[mid]
            level = 'T' + str(level_num - 7) if level_num > 7 else 'C' + str(level_num)
            fig.add_annotation(
                x=df.loc[mid, slice_col],
                y=ylim_low + METRICS_TO_YLIM_OFFSET[metric],
                text=level,
                showarrow=False,
                font=dict(size=TICKS_FONT_SIZE_SUBPLOT),
                row=row,
                col=col,
            )

    fig.update_layout(
        height=800,
        width=1300,
        plot_bgcolor='white',
        legend_title_text=hue,
        legend={'traceorder': 'normal'},
    )

    fig.show()



def get_vert_indices(df, participant_id='sub-amu01'):
    """
    Get indices of slices corresponding to vertebral-level boundaries and mid-vertebrae.

    The vertebral labelling of a single reference subject is used for all subjects.

    Args:
        df (pd.DataFrame): dataframe with 'participant_id' and 'VertLevel' columns
        participant_id (str): reference subject whose vertebral levels are used
            (default: 'sub-amu01')
    Returns:
        vert (pd.Series): vertebral levels across the reference subject's slices,
            indexed by the original df index labels
        ind_vert (np.ndarray): df index labels where the vertebral level changes
            (beginning of each level, i.e. the intervertebral disc), followed by the
            label of the subject's last row
        ind_vert_mid (list of int): df index labels halfway (truncated) between
            consecutive entries of ind_vert, i.e. mid-levels
    Raises:
        ValueError: if `participant_id` has no rows in df
    """
    vert = df[df['participant_id'] == participant_id]['VertLevel']
    if vert.empty:
        raise ValueError(f"No rows found for participant_id '{participant_id}'")

    # Index labels where the level changes value. The first row always qualifies
    # because its diff() is NaN, and NaN != 0.
    level_diff = vert.diff()
    ind_vert = level_diff[level_diff != 0].index.values
    # Close the last level with the subject's last row (the original comment calls
    # this "the beginning of C1" -- presumably the top of the cord; TODO confirm)
    ind_vert = np.append(ind_vert, vert.index.values[-1])
    # Midpoint between consecutive boundaries, truncated to an int index label
    ind_vert_mid = [int((start + end) / 2) for start, end in zip(ind_vert[:-1], ind_vert[1:])]

    return vert, ind_vert, ind_vert_mid