In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Per-metric plotting settings, one entry per metric column:
#   (plot title, y-axis unit, offset below the top y-limit at which the per-level CV text is drawn)
# The public lookup tables below are all derived from this single table so they cannot drift apart.
_METRIC_SETTINGS = {
    'MEAN(diameter_AP)': ('AP Diameter', 'mm', 0.4),
    'MEAN(area)': ('Cross-Sectional Area', 'mm²', 6),
    'MEAN(diameter_RL)': ('RL Diameter', 'mm', 0.7),
    'MEAN(eccentricity)': ('Eccentricity', 'a.u.', 0.03),
    'MEAN(solidity)': ('Solidity', '%', 1),
    'MEAN(compression_ratio)': ('Compression Ratio', 'a.u.', 0.03),
}

# Plot title for each metric column
METRIC_TO_TITLE = {metric: title for metric, (title, _unit, _offset) in _METRIC_SETTINGS.items()}

# y-axis label for each metric column, formatted as "<title> [<unit>]"
METRIC_TO_AXIS = {metric: f'{title} [{unit}]' for metric, (title, unit, _offset) in _METRIC_SETTINGS.items()}

# ylim max offset (used for showing text): the CV text is drawn at (top y-limit - offset)
METRICS_TO_YLIM = {metric: offset for metric, (_title, _unit, offset) in _METRIC_SETTINGS.items()}

LABELS_FONT_SIZE = 14
TICKS_FONT_SIZE = 12

def _vert_level_label(level_num):
    """
    Return the text label for a vertebral level number.

    Levels <= 7 get the prefix 'C' (e.g. 3 -> 'C3'); levels > 7 are renumbered from 1 and get the
    prefix 'T' (e.g. 8 -> 'T1').
    """
    if level_num > 7:
        return 'T' + str(level_num - 7)
    return 'C' + str(level_num)


def create_lineplot(df, metric, df_single_subject=None, hue=None, show_cv=False):
    """
    Create lineplot for selected metric per vertebral levels.
    Note: we are plotting slices, not levels, to avoid averaging across levels.
    Args:
        df (pd.DataFrame): dataframe with metric values across all healthy subjects; must contain the
            columns 'Slice (I->S)', 'VertLevel' and `metric`, plus 'participant_id' (read by
            get_vert_indices to locate the vertebral level boundaries)
        metric (str): metric column to plot; must be a key of METRIC_TO_TITLE, METRIC_TO_AXIS and
            METRICS_TO_YLIM (e.g., 'MEAN(area)')
        df_single_subject (pd.DataFrame): dataframe with metric values for a single subject, overlaid on
            the group plot; if None, nothing is overlaid
        hue (str): column name of the dataframe to use for grouping; if None, no grouping is applied
        show_cv (bool): if True, draw the coefficient of variation (in %) of each vertebral level near
            the top of the plot and print it
    """
    _, ax = plt.subplots()
    # Plot slices, not levels, to avoid averaging across levels; the band is +/- 1 standard deviation
    sns.lineplot(ax=ax, x="Slice (I->S)", y=metric, data=df, errorbar='sd', hue=hue, legend='auto')
    if df_single_subject is not None:
        sns.lineplot(ax=ax, x="Slice (I->S)", y=metric, data=df_single_subject, errorbar=None,
                     label='single subject')

    # Without hue grouping, replace the legend labels with explicit names (mean line, std band and,
    # if present, the single-subject line); with hue, seaborn's automatic legend is kept as-is
    if hue is None:
        legend_labels = ['spine-generic mean', 'spine-generic std']
        if df_single_subject is not None:
            legend_labels.append('single subject')
        ax.legend(legend_labels)

    # Move y-axis tick labels to the right
    ax.tick_params(axis='y', which='both', labelleft=False, labelright=True)
    ax.grid(color='lightgrey', zorder=0)
    ax.set_title('Spinal Cord ' + METRIC_TO_TITLE[metric], fontsize=LABELS_FONT_SIZE)
    # Adjust ylim for solidity (it has low variance)
    if metric == 'MEAN(solidity)':
        ax.set_ylim(90, 100)
    # y-limits are read once here: level labels are anchored at ymin, CV labels just below ymax
    ymin, ymax = ax.get_ylim()
    ax.set_ylabel(METRIC_TO_AXIS[metric], fontsize=LABELS_FONT_SIZE)
    ax.set_xlabel('Vertebral Level (S->I)', fontsize=LABELS_FONT_SIZE)
    # Slice numbers are replaced by vertebral level labels below, so remove the x ticks
    ax.set_xticks([])

    # Get indices of slices corresponding to level starts and mid-vertebrae
    vert, ind_vert, ind_vert_mid = get_vert_indices(df)
    # Dashed vertical line at the start of every level except the first one in ind_vert
    for x in ind_vert[1:]:
        ax.axvline(df.loc[x, 'Slice (I->S)'], color='black', linestyle='--', alpha=0.5)

    text_style = dict(horizontalalignment='center', verticalalignment='bottom', color='black')
    # Extra x shift (in slices) applied to the C1 labels to adjust their position
    c1_x_offset = 15
    # ind_vert_mid has one extra leading entry, so level number idx (1-based) maps to ind_vert_mid[idx]
    for idx, x in enumerate(ind_vert, 1):
        level_num = vert[x]
        level = _vert_level_label(level_num)
        x_text = df.loc[ind_vert_mid[idx], 'Slice (I->S)']
        if level_num == 1:
            x_text = x_text + c1_x_offset
        # Vertebral level label at the bottom of the plot
        ax.text(x_text, ymin, level, **text_style)
        if show_cv:
            # CV across all rows (all subjects, all slices) belonging to this vertebral level
            cv = compute_cv(df[(df['VertLevel'] == level_num)], metric)
            ax.text(x_text, ymax - METRICS_TO_YLIM[metric], str(round(cv, 1)) + '%', **text_style)
            print(f'{metric}, {level}, COV: {cv}')

    # Invert x-axis
    ax.invert_xaxis()


def compute_cv(df, metric):
    """
    Return the coefficient of variation (CV) of one dataframe column, in percent.

    CV = std / mean * 100, where std is pandas' default sample standard deviation (ddof=1).
    Args:
        df (pd.DataFrame): dataframe holding the values
        metric (str): name of the column to summarise
    Returns:
        cv (float): coefficient of variation of df[metric], in percent
    """
    values = df[metric]
    # Left-to-right evaluation: (std / mean) first, then scale to percent
    return values.std() / values.mean() * 100


def get_vert_indices(df, subject='sub-amu01'):
    """
    Get indices of slices corresponding to the start and the middle of each vertebral level.

    The level boundaries are taken from a single reference subject. Its rows are presumably ordered
    by slice (the runs of constant 'VertLevel' are detected in row order) — confirm against the caller.
    Args:
        df (pd.DataFrame): dataframe with per-slice values; must contain the columns
            'participant_id' and 'VertLevel'
        subject (str): participant_id of the reference subject used to define the level boundaries
    Returns:
        vert (pd.Series): 'VertLevel' values of the reference subject's rows, keeping df's index
        ind_vert (np.ndarray): df index labels of the first row of each run of constant 'VertLevel'
        ind_vert_mid (list): df index labels of the middle of each run, preceded by one extra leading
            entry (ind_vert[0] - 20). ind_vert_mid[i + 1] is the integer mid-point between ind_vert[i]
            and ind_vert[i + 1]; for the last run it is ind_vert[-1] itself.
    Raises:
        ValueError: if `subject` has no rows in `df`
    """
    # Get vert levels for the reference subject
    vert = df[df['participant_id'] == subject]['VertLevel']
    if vert.empty:
        raise ValueError(f"Reference subject '{subject}' not found in the dataframe")

    # Indexes where the level changes; the first diff is NaN (!= 0), so the first row always starts a run
    level_diff = vert.diff()
    ind_vert = level_diff[level_diff != 0].index.values

    # Mid-point of each run = halfway to the next run's start; the last run has no successor,
    # so its own start is used (matches the mean of a one-element slice)
    next_starts = np.append(ind_vert[1:], ind_vert[-1])
    ind_vert_mid = [int(mid) for mid in (ind_vert + next_starts) / 2]
    ind_vert_mid.insert(0, ind_vert[0] - 20)

    return vert, ind_vert, ind_vert_mid