In [1]:
from matplotlib import pyplot as plt 
import numpy as np 
import pandas as pd 
import plotly.express as px
import plotly.io as pio
import plotly.graph_objs as go
import plotly.offline as pyo
import pickle

from collections import Counter

import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

from matplotlib import rc

### Load data: pickled E-DADS results and per-subject diagnoses

In [2]:
# Load the E-DADS subtype timelines (data/EDADS_subtype_timelines_agecorrected_opt.pickle).
# The pickle holds a 3-tuple (T, S, Sboot); only S -- the per-subject subtyping dict with
# 'subtypes', 'staging' and 'atypicality' -- is used in this section.
# NOTE(review): T and Sboot are presumably the timelines and bootstrap results -- confirm.
# NOTE: the pickle was written with scikit-learn 0.24.1; unpickling its MinCovDet estimator
# under a newer version emits a compatibility warning and may give different results.
# SECURITY: pickle.load and np.load(allow_pickle=True) can execute arbitrary code --
# only load these files from a trusted source.
with open('data/EDADS_subtype_timelines_agecorrected_opt.pickle', 'rb') as read_input:
    load_inputs = pickle.load(read_input)

T, S, Sboot = load_inputs

# Per-subject diagnosis labels (e.g. 'CN', 'MCI'); expected to be aligned with S's arrays.
diagnosis = np.load('data/diagnosis.npy', allow_pickle=True)


Trying to unpickle estimator MinCovDet from version 0.24.1 when using version 1.0.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations



## Plotting function: disease stage vs. atypicality scatterplot

In [13]:
def staging_scatterplot(S, diagnosis, subtype_labels=None, chosen_subtypes=None,
                        color_list=('#000000',), width=1100, height=800,
                        fontsize=(34, 18, 14, 22)):
    """
    Create a scatterplot of disease stage (x) against atypicality (y), coloured by subtype.

    Controls ('CN') and subjects without an assigned subtype (NaN) are always excluded.

    :param S: subtyping dictionary holding per-subject arrays under the keys
        'subtypes' (float subtype id, NaN if unassigned), 'staging' and 'atypicality'
    :param diagnosis: per-subject diagnosis labels, aligned with the arrays in S
    :param subtype_labels: list with a display name per subtype, in ascending order of
        subtype id (optional; defaults to 'Subtype <id>')
    :param chosen_subtypes: list of diagnosis labels (e.g. ['MCI']) whose subjects are
        plotted (optional; defaults to every non-control diagnosis).
        Note: despite the name, this filters on diagnosis, not on subtype.
    :param color_list: list of colour hex values, paired with subtype_labels in order
        (optional); labels without a colour get plotly's default colours
    :param width: figure width in pixels (optional)
    :param height: figure height in pixels (optional)
    :param fontsize: 4 ints, [font_title, font_axes, font_ticks, font_legend] (optional)
    :return: plotly scatterplot figure
    :raises ValueError: if the per-subject inputs differ in length, or fewer
        subtype_labels than distinct subtypes are given
    """
    subtypes = np.asarray(S['subtypes'])
    staging = np.asarray(S['staging'])
    atypical = np.asarray(S['atypicality'])
    diagnosis = np.asarray(diagnosis)

    # Combining misaligned per-subject arrays would silently pair the wrong values,
    # so fail loudly instead.
    n_subjects = len(diagnosis)
    for key, values in (('subtypes', subtypes), ('staging', staging), ('atypicality', atypical)):
        if len(values) != n_subjects:
            raise ValueError(
                f"S['{key}'] has {len(values)} entries but diagnosis has {n_subjects}")

    # Subtype ids present in the data (NaN = no subtype assigned), sorted ascending.
    unique_subtypes = np.unique(subtypes[~np.isnan(subtypes)])
    if subtype_labels is None:
        subtype_labels = ['Subtype ' + str(int(s)) for s in unique_subtypes]
    elif len(subtype_labels) < len(unique_subtypes):
        raise ValueError(
            f"{len(unique_subtypes)} subtypes found but only "
            f"{len(subtype_labels)} subtype_labels given")
    subtype_map = dict(zip(unique_subtypes, subtype_labels))

    # Default selection: every diagnosis except controls.
    if chosen_subtypes is None:
        chosen_subtypes = sorted(set(diagnosis) - {'CN'})

    df = pd.DataFrame({'Stage': staging, 'Atypicality': atypical,
                       'Subtype': subtypes, 'Diagnosis': diagnosis})
    # Controls are excluded even if 'CN' is listed in chosen_subtypes.
    keep = (df['Diagnosis'] != 'CN') & df['Diagnosis'].isin(chosen_subtypes)
    df_plot = (
        df.loc[keep]
        # Numeric id -> display label; unassigned (NaN) subjects stay NaN and are dropped.
        .assign(Subtype=lambda d: d['Subtype'].map(subtype_map))
        .dropna(subset=['Subtype'])
    )

    # Colours pair with labels in order; surplus colours are ignored.
    color_map = dict(zip(subtype_labels, color_list))

    font_title, font_axes, font_ticks, font_legend = fontsize

    fig = px.scatter(df_plot, x='Stage', y='Atypicality', color='Subtype',
                     color_discrete_map=color_map)
    fig.update_layout(
        title="Staging ~ Atypicality",
        title_font_size=font_title,
        title_x=0.5,
        xaxis_title="Stage",
        yaxis_title="Atypicality",
        # Ticks every 2 stages, starting at 0.
        xaxis=dict(tickmode='linear', tick0=0.0, dtick=2),
        legend_font_size=font_legend,
        autosize=False,
        width=width,
        height=height,
    )

    # The x-axis shows Stage, so size it from the plotted stages; pad both ends so
    # markers at the extreme stages are not clipped.
    plotted_stages = df_plot['Stage'].dropna()
    if not plotted_stages.empty:
        pad = 1.5
        fig.update_xaxes(range=[float(plotted_stages.min()) - pad,
                                float(plotted_stages.max()) + pad])

    fig.update_xaxes(title_font_size=font_axes, tickfont_size=font_ticks)
    fig.update_yaxes(title_font_size=font_axes, tickfont_size=font_ticks)

    return fig

In [14]:
# Example configuration for the plot below: keep only MCI subjects
# (chosen_subtypes filters on diagnosis) and draw the first subtype in black.
chosen_subtypes = ['MCI']
color_list = ['#000000']

In [15]:
# Build the stage-vs-atypicality figure with the example configuration above.
p = staging_scatterplot(
    S=S,
    diagnosis=diagnosis,
    chosen_subtypes=chosen_subtypes,
    color_list=color_list,
    width=900,
)

# Bare last expression: Jupyter renders the plotly figure.
p