In [111]:
from matplotlib import pyplot as plt 
import numpy as np 
import pandas as pd 
import plotly.express as px
import plotly.io as pio
import plotly.graph_objs as go
import plotly.offline as pyo
import pickle

from collections import Counter

import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

from matplotlib import rc

### Load pickle

In [112]:
# Load the pickled results tuple (T, S, Sboot).
# NOTE(review): pickle.load can execute arbitrary code — only load this file from a trusted source.
# The `with` block guarantees the file is closed even if unpickling fails.
with open('data/Data.pickle', 'rb') as read_input:
    load_inputs = pickle.load(read_input)

T, S, Sboot = load_inputs

diagnosis = np.load('data/diagnosis.npy', allow_pickle=True)

In [113]:
# Raw per-subject subtype weights, one row per subject.
# NOTE(review): rows are assumed to follow the same order as the patients in the CSV loaded below — confirm.
S['subtypes_weights']

array([[ 90.21351566, 111.08181437,  70.60416769,   0.        ],
       [ 71.65399619,  42.5054571 ,  51.82561177,  96.54300176],
       [ 55.09250334,  65.73655685,  97.09525292,  36.30677329],
       ...,
       [ 57.15528021,  60.57542443,  54.95834977,  76.12637423],
       [  0.        , 177.97538759,  33.47145564,  19.8527195 ],
       [ 52.83029462,  58.05090857,  39.02333413,  94.88530476]])

The pickle holds the tuple `(T, S, Sboot)`:

- `T` --> timeline (timeline object)

- `S` --> subjects (dict)

- `Sboot` --> subjects bootstrapping

## Load patients' data
> **Note:** this section will not run as-is, because it uses private UMC patient data that is not distributed with this notebook.

In [114]:
# Regional data per patient, identified by the 'PTID' column (private UMC data, not distributed).
data = pd.read_csv("data/EDADS_data.csv")
data.columns

Index(['PTID', 'Diagnosis', 'Temporal_lobe', 'Superior_frontal_gyrus',
       'Middle_frontal_gyrus', 'Inferior_frontal_gyrus', 'Gyrus_rectus',
       'Orbitofrontal_gyri', 'Precentral_gyrus', 'Postcentral_gyrus',
       'Superior_parietal_gyrus', 'Inferolateral_remainder_of_parietal_lobe',
       'Lateral_remainder_of_occipital_lobe', 'Lingual_gyrus', 'Insula',
       'Gyrus_cinguli_anterior_part', 'Gyrus_cinguli_posterior_part',
       'Parahippocampal_and_ambient_gyri', 'Thalamus', 'Caudate', 'Putamen',
       'Hippocampus', 'Amygdala', 'Accumbens-area'],
      dtype='object')

In [116]:
# Human-readable subtype names. The functions below map them positionally onto the
# distinct numeric subtype codes sorted in ascending order, so the order here matters.
subtype_labels = [
    'Subcortical subtype',
    'Frontal subtype',
    'Parietal subtype',
    'Typical subtype',
]

## Get prediction

In [119]:
def get_prediction(data, S, patient_id = 0, subtype_labels=None):
    """
    Return the predicted subtype label for a single patient.

    :param data: DataFrame with a 'PTID' column; its row order is assumed to
        match the order of S['subtypes'] (one entry per patient)
    :param S: subtyping dictionary containing 'subtypes' (numeric subtype code
        per patient, NaN for outliers)
    :param patient_id: ID to look up in data['PTID']
    :param subtype_labels: labels for the distinct subtype codes in ascending
        numeric order (optional); defaults to 'Subtype <k>'
    :return: the subtype label, 'Outlier' when the patient's subtype is NaN,
        or the tuple ('Wrong patient ID', 'No prediction') when patient_id is
        None or not present in data['PTID']
    :raises ValueError: if fewer labels than distinct subtypes are given
    """
    if patient_id is None:
        return 'Wrong patient ID', 'No prediction'

    # Find the patient by position so the lookup does not depend on data's index
    # (a label-aligned comparison fails for any index other than 0..n-1).
    matches = np.flatnonzero((data['PTID'] == patient_id).to_numpy())
    if matches.size == 0:
        return 'Wrong patient ID', 'No prediction'

    subtypes = np.asarray(S['subtypes'], dtype=float)
    unique_subtypes = np.unique(subtypes[~np.isnan(subtypes)])
    if subtype_labels is None:
        subtype_labels = ['Subtype ' + str(int(s)) for s in unique_subtypes]
    if len(subtype_labels) < len(unique_subtypes):
        raise ValueError(
            f"subtype_labels has {len(subtype_labels)} labels but "
            f"{len(unique_subtypes)} distinct subtypes were found"
        )
    # i-th smallest subtype code -> i-th label
    subtype_map = dict(zip(unique_subtypes, subtype_labels))

    subtype = subtypes[matches[0]]
    return 'Outlier' if np.isnan(subtype) else subtype_map[subtype]

In [120]:
# Predict the subtype of one example patient.
patient_id = 102

pred = get_prediction(
    data=data,
    S=S,
    patient_id=patient_id,
    subtype_labels=subtype_labels,
)

pred

'Parietal subtype'

## Subtype Probabilities

In [121]:
def subtype_probabilities(info, S, patient_id=0, subtype_labels = None, color=['#000000'],fontlist = [24, 18, 14, 22], width=900, height=600):
    """
    Create a bar plot of one patient's subtype probabilities.

    A patient's probabilities are their row of S['subtypes_weights'] divided by
    the NaN-skipping row sum, expressed in percent and rounded to 2 decimals.

    :param info: DataFrame with patients' data; needs a 'PTID' column whose row
        order is assumed to match the rows of S['subtypes'] and S['subtypes_weights']
    :param S: subtyping dictionary with 'subtypes' (numeric subtype code per
        patient, NaN for outliers) and 'subtypes_weights' (one row of weights
        per patient, one column per subtype label)
    :param patient_id: ID of the patient to visualize, matched against info['PTID']
    :param subtype_labels: labels for the distinct subtype codes in ascending
        numeric order (optional); defaults to 'Subtype <k>'
    :param color: value passed to plotly's marker_color for the bars (optional)
    :param fontlist: font sizes [title, axis titles, tick labels, bar labels] (optional)
    :param width: figure width in pixels (optional)
    :param height: figure height in pixels (optional)
    :return: (plotly express bar figure, predicted subtype label); the label is
        'Outlier' when the patient's subtype is NaN. Returns the tuple
        ('Wrong patient ID', 'No prediction') when patient_id is None or unknown.
    :raises ValueError: if fewer labels than distinct subtypes are given
    """
    if patient_id is None:
        return 'Wrong patient ID', 'No prediction'

    # Positional lookup: independent of info's index. Aligning info['PTID'] by label
    # against the RangeIndex-ed weights broke for any other index.
    matches = np.flatnonzero((info['PTID'] == patient_id).to_numpy())
    if matches.size == 0:
        return 'Wrong patient ID', 'No prediction'
    row = matches[0]

    # Map the sorted distinct subtype codes onto labels (i-th code -> i-th label).
    subtypes = np.asarray(S['subtypes'], dtype=float)
    unique_subtypes = np.unique(subtypes[~np.isnan(subtypes)])
    if subtype_labels is None:
        subtype_labels = ['Subtype ' + str(int(s)) for s in unique_subtypes]
    if len(subtype_labels) < len(unique_subtypes):
        raise ValueError(
            f"subtype_labels has {len(subtype_labels)} labels but "
            f"{len(unique_subtypes)} distinct subtypes were found"
        )
    subtype_map = dict(zip(unique_subtypes, subtype_labels))
    subtype = subtypes[row]
    prediction = 'Outlier' if np.isnan(subtype) else subtype_map[subtype]

    # Only this patient's weights are needed. Normalise them to percentages;
    # NaN weights are skipped in the sum and yield NaN probabilities.
    weights = pd.Series(np.asarray(S['subtypes_weights'], dtype=float)[row],
                        index=subtype_labels)
    probabilities = (weights / weights.sum(skipna=True) * 100).round(2)

    df = probabilities.to_frame(name="Probability")

    # Bar labels come from `text` + the texttemplate set below (text_auto would be overridden anyway).
    fig = px.bar(df, x=df.index, y="Probability",
            text=probabilities.to_numpy(),
            width=width,
            height=height)

    # Styling
    font_title, font_axes, font_ticks, font_bars = fontlist

    fig.update_layout(
        title_text='Subtype probabilities', # title of plot
        title_x=0.5,
        title_font_size=font_title,
        xaxis_title_text='Subtype', # xaxis label
        yaxis_title_text='Probability (%)', # yaxis label
        bargap=0.2, # gap between bars of adjacent location coordinates
    )

    fig.update_traces(marker_color=color,
                    textfont_size=font_bars,
                    texttemplate='%{text} %')

    fig.update_yaxes(title_font_size = font_axes,
                    tickfont_size=font_ticks)

    fig.update_xaxes(title_font_size = font_axes,
                    tickfont_size = font_ticks)

    return fig, prediction

In [122]:
# Settings for the subtype-probability bar chart.
patient_id = 102              # PTID of the patient to visualise
color = '#EB89B5'             # bar colour (hex)
fontlist = [24, 18, 14, 22]   # font sizes: [title, axis titles, tick labels, bar labels]
width, height = 900, 600      # figure size in pixels


In [123]:
# Plot the configured patient's subtype probabilities and keep the predicted label.
p, pred = subtype_probabilities(
    info=data,
    S=S,
    patient_id=patient_id,
    subtype_labels=subtype_labels,
    color=color,
    fontlist=fontlist,
    width=width,
    height=height,
)

p

In [124]:
# Predicted subtype label returned alongside the figure
pred

'Parietal subtype'