In [16]:
from matplotlib import pyplot as plt 
import numpy as np 
import pandas as pd 
import plotly.express as px
import plotly.io as pio
import plotly.graph_objs as go
import plotly.offline as pyo
import pickle

from collections import Counter

import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

from matplotlib import rc

### Load pickle

In [60]:
# NEW FILE
read_input = open('data/EDADS_subtype_timelines_agecorrected_opt.pickle','rb')
load_inputs = pickle.load(read_input)
read_input.close()

T, S, Sboot = load_inputs

diagnosis = np.load('data/diagnosis.npy', allow_pickle=True)

- T --> timeline (timeline object)

- S --> subjects (dict)

In [61]:
len(S['staging'])

2236

In [62]:
diagnosis

array(['AD', 'AD', 'MCI', ..., 'AD', 'SCD', 'MCI'], dtype=object)

In [63]:
set(diagnosis)

{'AD', 'CN', 'MCI', 'SCD'}

In [64]:
counter = dict(Counter(diagnosis))
counter

{'AD': 1054, 'MCI': 322, 'CN': 676, 'SCD': 184}

### Convert Nans to 0.0

In [65]:
staging = [np.float64(0.0) if np.isnan(stage) else stage for stage in staging]

### Get labels

In [66]:
labels = list(set(diagnosis))
labels

['CN', 'AD', 'SCD', 'MCI']

In [67]:
color_list = ['#4daf4a','#377eb8','#e41a1c', '#ffff00']

'#4daf4a'

## Patient staging function

In [80]:
def patient_staging(S, diagnosis, color_list=['#000000'], num_bins=10, bin_width=0.02):
    """
    Creates a barplot
    :param S: dictionary, Snowphlake output
    :param diagnosis: np.array or list; with diagnosis labels corresponding to records in S
    :param color_list: list with color hex values
    :param num_bins: int, how many bins should be displayed
    :param bin_width: int
    :return: plotly go Bar figure
    """  
    
    # Convert NaNs to 0.0
    staging = np.array([np.float64(0.0) if np.isnan(stage) else stage for stage in S['staging']])
#     staging = S['staging']
   
    # Count number of each subtype occurences
    counter = dict(Counter(diagnosis))
        
    # Get labels
    labels = list(set(diagnosis))
    
    # Get indexes
    diagnosis = np.array(diagnosis)
    staging = np.array(staging)
    
    # Get indexes for each diagnostic label
    idx_list = []
    for l in labels:
        idx = np.where(diagnosis==l)
        idx = idx[0]
        idx_list.append(idx)

    # Bar settings
    num_bins = num_bins
    bin_width = np.repeat(bin_width, num_bins)
          
    color_list = color_list
        
    count=-1    
    num_bins = num_bins
    bar_width = np.repeat(0.02, num_bins)
    counter = dict(Counter(diagnosis))

    fig = go.Figure()
    
    for count,idx in enumerate(idx_list):
        freq,binc=np.histogram(staging[idx],bins=num_bins)
        freq = (1.*freq)/len(staging)

        label = labels[count]

        fig.add_trace(go.Bar(
                    x=binc[:-1],
                    y=freq,
                    name=f'{label} (n = {counter[label]})',
                    width=bin_width,
                    marker_color=color_list[count]
            )) 

    fig.update_layout(
        title="Patient Staging",
        title_font_size=34,
        title_x=0.5,
        xaxis_title="Disease Stage",
        yaxis_title="Frequency of occurences",
        xaxis = dict(
            tickmode = 'linear',
            tick0 = 0.0,
            dtick = 0.1
        ),
        barmode='group',
        legend_font_size=16,
        legend=dict(
            yanchor="top",
            y=0.95,
            xanchor="right",
            x=0.95),
        autosize = False,
        width=1000,
        height=800
    )
    
    fig.update_xaxes(range=[-0.05, 1.0])
    
    fig.update_yaxes(title_font_size = 18, 
                    tickfont_size=14)
    
    fig.update_xaxes(title_font_size = 18, 
                    tickfont_size = 14)

    return fig

In [81]:
# FOR TESTING
# diagnosis = np.array(['Control' if np.isnan(subtype) else "FTD" for subtype in S['subtypes']])

num_bins = 10
bin_width = 0.02
color_list = ['#4daf4a','#377eb8','#e41a1c', '#ffff00']


In [82]:
fig = patient_staging(S=S,
                      diagnosis=diagnosis, 
                      color_list = color_list,
                      num_bins=10, 
                      bin_width=0.04)
fig

## Staging Boxplots function

In [89]:
def staging_boxes(S, diagnosis, color_list='#000000'):
    """
    Creates a boxplot
    :param S: dictionary, Snowphlake output
    :param diagnosis: np.array or list; with diagnosis labels corresponding to records in S
    :param color_list: list with color hex values
    :return: plotly go Box figure
    """
    
    # Convert NaNs to 0.0
    staging = np.array([np.float64(0.0) if np.isnan(stage) else stage for stage in S['staging']])
   
    # Count number of each subtype occurences
    counter = dict(Counter(diagnosis))
        
    # Get labels
    labels = list(set(diagnosis))
    
    # Get indexes
    diagnosis = np.array(diagnosis)
    staging = np.array(staging)
    
    # Get indexes for each diagnostic label
    idx_list = []
    for l in labels:
        idx = np.where(diagnosis==l)
        idx = idx[0]
        idx_list.append(idx)
        
    
    fig = go.Figure()

    for count, idx in enumerate(idx_list):
        fig.add_trace(go.Box(x=staging[idx], name=labels[count],
                             fillcolor=color_list[count],
                            line_color='#000000'))


    fig.update_xaxes(range=[-0.05, 1.0])

    # ADD BOXPOINTS
#     fig.update_traces(boxpoints='all', jitter=0)

    fig.update_layout(
            title="Staging - Boxplots",
            title_font_size=34,
            title_x=0.5,
            xaxis_title="Disease Stage",
            yaxis_title="Diagnosis",
            xaxis = dict(
                tickmode = 'linear',
                tick0 = 0.0,
                dtick = 0.1
            ),
            legend_font_size=16,
            legend=dict(
                yanchor="top",
                y=0.97,
                xanchor="right",
                x=0.97),
            autosize = False,
            width=1000,
            height=600
        )
    
    fig.update_yaxes(title_font_size = 18, 
                    tickfont_size=14)
    
    fig.update_xaxes(title_font_size = 18, 
                    tickfont_size = 14)

    return fig

In [90]:
# FOR TESTING
# diagnosis = np.array(['Control' if np.isnan(subtype) else "FTD" for subtype in S['subtypes']])
color_list = ['#4daf4a','#377eb8','#e41a1c', '#ffff00']

In [91]:
fig = staging_boxes(S=S,
                    diagnosis=diagnosis,
                    color_list=color_list)
fig

In [None]:
s

# DOUBLE FUNCTION

In [420]:
def staging(S, diagnosis, color_list=['#000000'], num_bins=10, bin_width=0.02):
    """
    Creates a barplot
    :param S: dictionary, Snowphlake output
    :param diagnosis: np.array or list with diagnosis labels corresponding to records in S
    :param color_list: list with color hex values (optional)
    :param num_bins: int, how many bins should be displayed (optional, defaults to 10)
    :param bin_width: int, desired width of the bars on the plot (optional, defaults to 0.02)
    :return: plotly go Bar figure
    """   
    
    # Convert NaNs to 0.0
    staging = np.array([np.float64(0.0) if np.isnan(stage) else stage for stage in S['staging']])
   
    # Count number of each subtype occurences
    counter = dict(Counter(diagnosis))
        
    # Get labels
    labels = list(set(diagnosis))
#     labels = labels[::-1]
    
    # Convers lists to np.arrays
    diagnosis = np.array(diagnosis)
    staging = np.array(staging)
    
    # Get indexes for each diagnostic label
    idx_list = []
    for l in labels:
        idx = np.where(diagnosis==l)
        idx = idx[0]
        idx_list.append(idx)

    # Assign bar settings
    bin_width = np.repeat(bin_width, num_bins)        
    if color_list == ['#000000']:
        color_list = ['#4daf4a','#377eb8','#e41a1c', '#ffff00']
            
    # Create subplots
    fig = make_subplots(rows=2, cols=1)
    
    for count, idx in enumerate(idx_list):
                freq,binc=np.histogram(staging[idx],bins=num_bins)
                freq = (1.*freq)/len(staging)
                label = labels[count]

                fig.add_trace(go.Bar(
                            x=binc,
                            y=freq,
                            name=f'{label} (n = {counter[label]})',
                            width=bin_width,
                            marker_color=color_list[count]
                ),row=1, col=1) 

    fig.update_layout(
        title="Patient Staging",
        title_font_size=34,
        title_x=0.5,
        xaxis_title="Disease Stage",
        yaxis_title="Frequency of occurences",
        xaxis = dict(
            tickmode = 'linear',
            tick0 = 0.0,
            dtick = 0.1
        ),
        barmode='group',
        legend_font_size=16,
        autosize = False,
        width=1000,
        height=1200
    )
    
    # ADD BOXES BELOW
        
    for count, idx in enumerate(idx_list):
        fig.add_trace(go.Box(x=staging[idx], 
                             name=labels[count],
                             fillcolor=color_list[count],
                            line_color='#000000',
                            showlegend=False),
                      row=2,col=1)    
        
    #STYLE THE PLOT
    fig.update_layout(hovermode="closest")
    
    fig.update_xaxes(range=[-0.05, 1.0],row=1, col=1)
    
    fig.update_yaxes(title_font_size = 18, 
                    tickfont_size=14)
    
    fig.update_xaxes(title_font_size = 18, 
                    tickfont_size = 14)

    return fig

In [421]:
fig = staging(S=S,
                      diagnosis=diagnosis, 
                      color_list = color_list,
#                       num_bins=10, 
                      bin_width=0.04)
fig