In [1]:
import numpy as np
import pandas as pd
import torch
import re
pd.set_option('display.max_rows', 200)

In [65]:
class Const:
    data_dir = '../data'
    twin_data = data_dir + 'digital_twin_data.csv'
    twin_ln_data = data_dir + 'digital_twin_ln_data.csv'
    
    rename_dict = {
        'Dummy ID': 'id',
        'Age at Diagnosis (Calculated)': 'age',
        'Feeding tube 6m': 'FT',
        'Affected Lymph node UPPER': 'affected_nodes',
        'Aspiration rate(Y/N)': 'AS',
        'Neck boost (Y/N)': 'neck_boost',
        'Gender': 'gender',
        'Tm Laterality (R/L)': 'laterality',
        'AJCC 8th edition': 'ajcc8',
        'AJCC 7th edition':'ajcc7',
        'N_category_full': 'N-category',
        'HPV/P16 status': 'hpv',
        'Tumor subsite (BOT/Tonsil/Soft Palate/Pharyngeal wall/GPS/NOS)': 'subsite',
        'Total dose': 'total_dose',
        'Therapeutic combination': 'treatment',
        'Smoking status at Diagnosis (Never/Former/Current)': 'smoking_status',
        'Smoking status (Packs/Year)': 'packs_per_year',
        'Overall Survival (1=alive,0=dead)': 'os',
        'Dose/fraction (Gy)': 'dose_fraction'
    }
    
    dlt_dict = {
         'Allergic reaction to Cetuximab': 'DLT_Other',
         'Cardiological (A-fib)': 'DLT_Other',
         'Dermatological': 'DLT_Dermatological',
         'Failure to Thrive': 'DLT_Other',
         'Failure to thrive': 'DLT_Other',
         'GIT [elevated liver enzymes]': 'DLT_Gastrointestinal',
         'Gastrointestina': 'DLT_Gastrointestinal',
         'Gastrointestinal': 'DLT_Gastrointestinal',
         'General': 'DLT_Other',
         'Hematological': 'DLT_Hematological',
         'Hematological (Neutropenia)': 'DLT_Hematological',
         'Hyponatremia': 'DLT_Other',
         'Immunological': 'DLT_Other',
         'Infection': 'DLT_Infection (Pneumonia)',
         'NOS': 'DLT_Other',
         'Nephrological': 'DLT_Nephrological',
         'Nephrological (ARF)': 'DLT_Nephrological',
         'Neurological': 'DLT_Neurological',
         'Neutropenia': 'DLT_Hematological',
         'Nutritional': 'DLT_Other',
         'Pancreatitis': 'DLT_Other',
         'Pulmonary': 'DLT_Other',
         'Respiratory (Pneumonia)': 'DLT_Infection (Pneumonia)',
         'Sepsis': 'DLT_Infection (Pneumonia)',
         'Suboptimal response to treatment' : 'DLT_Other',
         'Vascular': 'DLT_Vascular'
    }
    
    decision1 = 'Decision 1 (Induction Chemo) Y/N'
    decision2 = 'Decision 2 (CC / RT alone)'
    decision3 = 'Decision 3 Neck Dissection (Y/N)'
    decisions = [decision1,decision2, decision3]
    outcomes = ['Overall Survival (4 Years)', 'FT', 'Aspiration rate Post-therapy']
    
    modification_types = {
        0: 'no_dose_adjustment',
        1: 'dose_modified',
        2: 'dose_delayed',
        3: 'dose_cancelled',
        4: 'dose_delayed_&_modified',
        5: 'regiment_modification',
        9: 'unknown'
    }
    
    cc_types = {
        0: 'cc_none',
        1: 'cc_platinum',
        2: 'cc_cetuximab',
        3: 'cc_others',
    }
    
    primary_disease_states = ['CR Primary','PR Primary','SD Primary']
    nodal_disease_states = [t.replace('Primary','Nodal') for t in primary_disease_states]
    dlt1 = list(set(dlt_dict.values()))
    
    modifications =  list(modification_types.values())
    state2 = modifications + primary_disease_states+nodal_disease_states +dlt1 #+['No imaging 0=N,1=Y']
    
    primary_disease_states2 = [t + ' 2' for t in primary_disease_states]
    nodal_disease_states2 = [t + ' 2' for t in nodal_disease_states]
    dlt2 = [d + ' 2' for d in dlt1]
    
    ccs = list(cc_types.values())
    state3 = ccs + primary_disease_states2 + nodal_disease_states2 + dlt2
    
Const.modification_types

{0: 'no_dose_adjustment',
 1: 'dose_modified',
 2: 'dose_delayed',
 3: 'dose_cancelled',
 4: 'dose_delayed_&_modified',
 5: 'regiment_modification',
 9: 'unknown'}

In [77]:
def preprocess(data_cleaned):
    #this was Elisa's preprocessing except I removed all the Ifs because that's dumb
    if len(data_cleaned.shape) < 2:
        data_cleaned = pd.DataFrame([data], columns=data.index)
        
    data_cleaned.loc[data_cleaned['Aspiration rate Pre-therapy'] == 'N', 'Aspiration rate Pre-therapy'] = 0
    data_cleaned.loc[data_cleaned['Aspiration rate Pre-therapy'] == 'Y', 'Aspiration rate Pre-therapy'] = 1

    data_cleaned.loc[data_cleaned['Pathological Grade'] == 'I', 'Pathological Grade'] = 1
    data_cleaned.loc[data_cleaned['Pathological Grade'] == 'II', 'Pathological Grade'] = 2
    data_cleaned.loc[data_cleaned['Pathological Grade'] == 'III', 'Pathological Grade'] = 3
    data_cleaned.loc[data_cleaned['Pathological Grade'] == 'IV', 'Pathological Grade'] = 4

    data_cleaned.loc[(data_cleaned['T-category'] == 'Tx') | (data_cleaned['T-category'] == 'Tis'), 'T-category'] = 1
    data_cleaned.loc[data_cleaned['T-category'] == 'T1', 'T-category'] = 1
    data_cleaned.loc[data_cleaned['T-category'] == 'T2', 'T-category'] = 2
    data_cleaned.loc[data_cleaned['T-category'] == 'T3', 'T-category'] = 3
    data_cleaned.loc[data_cleaned['T-category'] == 'T4', 'T-category'] = 4

    data_cleaned.loc[data_cleaned['N-category'] == 'N0', 'N-category'] = 0
    data_cleaned.loc[data_cleaned['N-category'] == 'N1', 'N-category'] = 1
    data_cleaned.loc[data_cleaned['N-category'] == 'N2', 'N-category'] = 2
    data_cleaned.loc[data_cleaned['N-category'] == 'N3', 'N-category'] = 3

    data_cleaned.loc[data_cleaned['N-category_8th_edition'] == 'N0', 'N-category_8th_edition'] = 0
    data_cleaned.loc[data_cleaned['N-category_8th_edition'] == 'N1', 'N-category_8th_edition'] = 1
    data_cleaned.loc[data_cleaned['N-category_8th_edition'] == 'N2', 'N-category_8th_edition'] = 2
    data_cleaned.loc[data_cleaned['N-category_8th_edition'] == 'N3', 'N-category_8th_edition'] = 3

    data_cleaned.loc[data_cleaned['AJCC 7th edition'] == 'I', 'AJCC 7th edition'] = 1
    data_cleaned.loc[data_cleaned['AJCC 7th edition'] == 'II', 'AJCC 7th edition'] = 2
    data_cleaned.loc[data_cleaned['AJCC 7th edition'] == 'III', 'AJCC 7th edition'] = 3
    data_cleaned.loc[data_cleaned['AJCC 7th edition'] == 'IV', 'AJCC 7th edition'] = 4

    data_cleaned.loc[data_cleaned['AJCC 8th edition'] == 'I', 'AJCC 8th edition'] = 1
    data_cleaned.loc[data_cleaned['AJCC 8th edition'] == 'II', 'AJCC 8th edition'] = 2
    data_cleaned.loc[data_cleaned['AJCC 8th edition'] == 'III', 'AJCC 8th edition'] = 3
    data_cleaned.loc[data_cleaned['AJCC 8th edition'] == 'IV', 'AJCC 8th edition'] = 4

    data_cleaned.loc[data_cleaned['Gender'] == 'Male', 'Gender'] = 1
    data_cleaned.loc[data_cleaned['Gender'] == 'Female', 'Gender'] = 0


    data_cleaned.loc[data_cleaned['HPV/P16 status'] == 'Positive', 'HPV/P16 status'] = 1
    data_cleaned.loc[data_cleaned['HPV/P16 status'] == 'Negative', 'HPV/P16 status'] = -1
    data_cleaned.loc[data_cleaned['HPV/P16 status'] == 'Unknown', 'HPV/P16 status'] = 0

    data_cleaned.loc[data_cleaned[
                         'Smoking status at Diagnosis (Never/Former/Current)'] == 'Formar', 'Smoking status at Diagnosis (Never/Former/Current)'] = .5
    data_cleaned.loc[data_cleaned[
                         'Smoking status at Diagnosis (Never/Former/Current)'] == 'Current', 'Smoking status at Diagnosis (Never/Former/Current)'] = 1
    data_cleaned.loc[data_cleaned[
                         'Smoking status at Diagnosis (Never/Former/Current)'] == 'Never', 'Smoking status at Diagnosis (Never/Former/Current)'] = 0


    data_cleaned.loc[data_cleaned['Chemo Modification (Y/N)'] == 'Y', 'Chemo Modification (Y/N)'] = 1

    data_cleaned.loc[data_cleaned['DLT (Y/N)'] == 'N', 'DLT (Y/N)'] = 0
    data_cleaned.loc[data_cleaned['DLT (Y/N)'] == 'Y', 'DLT (Y/N)'] = 1

    data_cleaned['DLT_Other'] = 0
    for index, row in data_cleaned.iterrows():
        if row['DLT_Type'] == 'None':
            continue
        for i in re.split('&|and|,', row['DLT_Type']):
            if i.strip() != '' and data_cleaned.loc[index, Const.dlt_dict[i.strip()]] == 0:
                data_cleaned.loc[index, Const.dlt_dict[i.strip()]] = 1

    data_cleaned.loc[data_cleaned['Decision 2 (CC / RT alone)'] == 'RT alone', 'Decision 2 (CC / RT alone)'] = 0
    data_cleaned.loc[data_cleaned['Decision 2 (CC / RT alone)'] == 'CC', 'Decision 2 (CC / RT alone)'] = 1

    data_cleaned.loc[data_cleaned['CC modification (Y/N)'] == 'N', 'CC modification (Y/N)'] = 0
    data_cleaned.loc[data_cleaned['CC modification (Y/N)'] == 'Y', 'CC modification (Y/N)'] = 1

    data_cleaned['DLT_Dermatological 2'] = 0
    data_cleaned['DLT_Neurological 2'] = 0
    data_cleaned['DLT_Gastrointestinal 2'] = 0
    data_cleaned['DLT_Hematological 2'] = 0
    data_cleaned['DLT_Nephrological 2'] = 0
    data_cleaned['DLT_Vascular 2'] = 0
    data_cleaned['DLT_Infection (Pneumonia) 2'] = 0
    data_cleaned['DLT_Other 2'] = 0
    for index, row in data_cleaned.iterrows():
        if row['DLT 2'] == 'None':
            continue
        for i in re.split('&|and|,', row['DLT 2']):
            if i.strip() != '':
                data_cleaned.loc[index, Const.dlt_dict[i.strip()] + ' 2'] = 1

    data_cleaned.loc[
        data_cleaned['Decision 3 Neck Dissection (Y/N)'] == 'N', 'Decision 3 Neck Dissection (Y/N)'] = 0
    data_cleaned.loc[
        data_cleaned['Decision 3 Neck Dissection (Y/N)'] == 'Y', 'Decision 3 Neck Dissection (Y/N)'] = 1

    return data_cleaned

def merge_editions(row,basecol='AJCC 8th edition',fallback='AJCC 7th edition'):
    if pd.isnull(row[basecol]):
        return row[fallback]
    return row[basecol]


def preprocess_dt_data(df,extra_to_keep=None):
    
    to_keep = ['id','hpv','age','packs_per_year','smoking_status','gender','Aspiration rate Pre-therapy','total_dose','dose_fraction'] 
    to_onehot = ['T-category','N-category','AJCC','Pathological Grade','subsite','treatment','ln_cluster']
    if extra_to_keep is not None:
        to_keep = to_keep + [c for c in extra_to_keep if c not in to_keep and c not in to_onehot]
    
    decisions =Const.decisions
    outcomes = Const.outcomes
    
    modification_types = {
        0: 'no_dose_adjustment',
        1: 'dose_modified',
        2: 'dose_delayed',
        3: 'dose_cancelled',
        4: 'dose_delayed_&_modified',
        5: 'regiment_modification',
        9: 'unknown'
    }
    
    cc_types = {
        0: 'cc_none',
        1: 'cc_platinum',
        2: 'cc_cetuximab',
        3: 'cc_others',
    }
    
    for k,v in Const.cc_types.items():
        df[v] = df['CC Regimen(0= none, 1= platinum based, 2= cetuximab based, 3= others, 9=unknown)'].apply(lambda x: int(Const.cc_types.get(int(x),0) == v))
        to_keep.append(v)
    for k,v in Const.modification_types.items():
        name = 'Modification Type (0= no dose adjustment, 1=dose modified, 2=dose delayed, 3=dose cancelled, 4=dose delayed & modified, 5=regimen modification, 9=unknown)'
        df[v] = df[name].apply(lambda x: int(Const.modification_types.get(int(x),0) == v))
        to_keep.append(v)
    #Features to keep. I think gender is is in 
    
    keywords = []
    for keyword in keywords:
        toadd = [c for c in df.columns if keyword in c and c not in to_keep]
        to_keep = to_keep + toadd
    
    df['packs_per_year'] = df['packs_per_year'].apply(lambda x: str(x).replace('>','').replace('<','')).astype(float).fillna(0)
    #so I'm actually not sure if this is biological sex or gender given this is texas
    df['AJCC'] = df.apply(lambda row: merge_editions(row,'ajcc8','ajcc7'),axis=1)
    df['N-category'] = df.apply(lambda row: merge_editions(row,'N-category_8th_edition','N-category'),axis=1)
    
    dummy_df = pd.get_dummies(df[to_onehot].fillna(0).astype(str),drop_first=False)
    for col in dummy_df.columns:
        df[col] = dummy_df[col]
        to_keep.append(col)
        
    yn_to_binary = ['FT','Aspiration rate Post-therapy','Decision 1 (Induction Chemo) Y/N']
    for col in yn_to_binary:
        df[col] = df[col].apply(lambda x: int(x == 'Y'))
        
    to_keep = to_keep + [c for c in df.columns if 'DLT' in c]
    
        
    for statelist in [Const.state2,Const.state3,Const.decisions,Const.outcomes]:
        toadd = [c for c in statelist if c not in to_keep]
        to_keep = to_keep + toadd
    return df[to_keep].set_index('id')



In [198]:
def load_digital_twin(file='../data/digital_twin_data.csv'):
    df = pd.read_csv(file)
    return df.rename(columns = Const.rename_dict)

def get_dt_ids():
    df = load_digital_twin()
    return df.id.values
get_dt_ids()

  df = pd.read_csv(file)


array([    3,     5,     6,     7,     8,     9,    10,    11,    13,
          14,    15,    16,    17,    18,    21,    23,    24,    25,
          26,    27,    28,    31,    32,    33,    35,    36,    37,
          38,    39,    40,    41,    42,    44,    45,    47,    48,
          49,    50,    51,    53,    55,    56,    57,    60,    64,
          65,    67,    68,    69,    71,    74,    75,    77,    78,
          79,    80,    81,    82,    87,    88,    91,    94,    96,
          99,   103,   109,   116,   117,   119,   120,   121,   125,
         133,   148,   150,   153,   168,   178,   181,   183,   184,
         185,   186,   187,   188,   189,   190,   191,   192,   193,
         194,   195,   196,   197,   198,   199,   200,   201,   202,
         203,   204,   205,   206,   207,   208,   209,   210,   211,
         212,   213,   214,   215,   216,   217,   218,   219,   220,
         221,   222,   223,   224,   225,   226,   227,   228,   229,
         230,   231,

In [160]:
class DTDataset():
    
    def __init__(self,data_file = '../data/digital_twin_data.csv',ln_data_file = '../data/digital_twin_ln_data.csv',ids=None):
        df = pd.read_csv(data_file)
        
        df = preprocess(df)
        df = df.rename(columns = Const.rename_dict).copy()
        df = df.drop('MRN OPC',axis=1)

        ln_data = pd.read_csv(ln_data_file)
        ln_data = ln_data.rename(columns={'cluster':'ln_cluster'})
        self.ln_cols = [c for c in ln_data.columns if c not in df.columns]
        df = df.merge(ln_data,on='id')
        df.index = df.index.astype(int)
        if ids is not None:
            df = df[df.id.apply(lambda x: x in ids)]
        self.processed_df = preprocess_dt_data(df,self.ln_cols).fillna(0)
        
        self.means = self.processed_df.mean(axis=0)
        self.stds = self.processed_df.std(axis=0)
        self.maxes = self.processed_df.max(axis=0)
        self.mins = self.processed_df.min(axis=0)
        
        arrays = self.get_states()
        self.state_sizes = {k: (v.shape[1] if v.ndim > 1 else 1) for k,v in arrays.items()}
        
    def get_data(self):
        return self.processed_df
    
    def sample(self,frac=.5):
        return self.processed_df.sample(frac=frac)
    
    def split_sample(self,ratio = .3):
        assert(ratio > 0 and ratio <= 1)
        df1 = self.processed_df.sample(frac=1-ratio)
        df2 = self.processed_df.drop(index=df1.index)
        return df1,df2
    
    def get_states(self,fixed=None,ids = None):
        processed_df = self.processed_df.copy()
        if ids is not None:
            processed_df = processed_df.loc[ids]
        if fixed is not None:
            for col,val in fixed.items():
                if col in processed_df.columns:
                    processed_df[col] = val
                else:
                    print('bad fixed entry',col)
                    
        to_skip = ['CC Regimen(0= none, 1= platinum based, 2= cetuximab based, 3= others, 9=unknown)','DLT_Type','DLT 2'] + [c for c in processed_df.columns if 'treatment' in c]
        other_states = set(Const.decisions + Const.state3 + Const.state2 + Const.outcomes  + to_skip)

        base_state = sorted([c for c in processed_df.columns if c not in other_states])

        dlt1 = Const.dlt1
        dlt2 = Const.dlt2
        
        modifications = Const.modifications
        ccs = Const.ccs
        pds = Const.primary_disease_states
        nds = Const.nodal_disease_states
        pds2 = Const.primary_disease_states2
        nds2 = Const.nodal_disease_states2
        outcomes = Const.outcomes
        decisions= Const.decisions
        
        #intermediate states are only udated values. Models should use baseline + state2 etc
        results = {
            'baseline': processed_df[base_state],
            'pd_states1': processed_df[pds],
            'nd_states1': processed_df[nds],
            'modifications': processed_df[modifications],
            'ccs': processed_df[ccs],
            'pd_states2': processed_df[pds2],
            'nd_states2': processed_df[nds2],
            'outcomes': processed_df[outcomes],
            'dlt1': processed_df[dlt1],
            'dlt2': processed_df[dlt2],
            'decision1': processed_df[decisions[0]],
            'decision2': processed_df[decisions[1]],
            'decision3': processed_df[decisions[2]],
        }
    
        return results
    
    def get_state(self,name,**kwargs):
        return self.get_states(**kwargs)[name]
    
    def normalize(self,df):
        means = self.means[df.columns]
        std = self.stds[df.columns]
        return ((df - means)/std).fillna(0)
    
    def get_intermediate_outcomes(self,step=1,**kwargs):
        assert(step in [1,2])
        states = self.get_states(**kwargs)
        if step == 1:
            keys = ['pd_states1','nd_states1','modifications','dlt1']
        else:
            keys =  ['pd_states2','nd_states2','ccs','dlt2']
        return [states[key] for key in keys]
    
    def get_input_state(self,step=1,**kwargs):
        assert(step in [1,2,3])
        states = self.get_states(**kwargs)
        if step == 1:
            keys = ['baseline']
        if step == 2:
            keys =  ['baseline','pd_states1','nd_states1','modifications','dlt1']
        if step == 3:
            keys = ['baseline','pd_states2','nd_states2','ccs','dlt2']
        arrays = [states[key].values for key in keys]
        return np.concatenate(arrays,axis=1)
    
data = DTDataset()
data.get_states()['modifications']

  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


Unnamed: 0_level_0,no_dose_adjustment,dose_modified,dose_delayed,dose_cancelled,dose_delayed_&_modified,regiment_modification,unknown
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
3,1,0,0,0,0,0,0
5,1,0,0,0,0,0,0
6,1,0,0,0,0,0,0
7,1,0,0,0,0,0,0
8,1,0,0,0,0,0,0
...,...,...,...,...,...,...,...
10201,1,0,0,0,0,0,0
10202,1,0,0,0,0,0,0
10203,1,0,0,0,0,0,0
10204,1,0,0,0,0,0,0


In [146]:
data.get_states()['dlt1']

Unnamed: 0_level_0,DLT_Gastrointestinal,DLT_Other,DLT_Neurological,DLT_Dermatological,DLT_Infection (Pneumonia),DLT_Vascular,DLT_Nephrological,DLT_Hematological
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
3,0,0,0,0,0,0,0,0
5,0,0,0,0,0,0,0,0
6,0,0,0,0,0,0,0,0
7,0,0,0,0,0,0,0,0
8,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...
10201,0,0,0,0,0,0,0,0
10202,0,0,0,0,0,0,0,0
10203,0,0,0,0,0,0,0,0
10204,0,0,0,0,0,0,0,0


In [105]:
def df_to_torch(df,ttype  = torch.FloatTensor):
    values = df.values.astype(float)
    values = torch.from_numpy(values)
    return values.type(ttype)

In [390]:
class OutcomeSimulator(torch.nn.Module):
    
    def __init__(self,
                 input_size,
                 hidden_layers = [1000],
                 dropout = 0.5,
                 input_dropout=0.1,
                 state = 1,
                ):
        #predicts disease state (sd, pr, cr) for primar and nodal, then dose modications or cc type (depending on state), and [dlt ratings]
        torch.nn.Module.__init__(self)
        self.state = state
        self.input_dropout = torch.nn.Dropout(input_dropout)
        
        first_layer =torch.nn.Linear(input_size,hidden_layers[0],bias=True)
        layers = [first_layer,torch.nn.ReLU()]
        curr_size = hidden_layers[0]
        for ndim in hidden_layers[1:]:
            layer = torch.nn.Linear(curr_size,ndim)
            curr_size = ndim
            layers.append(layer)
            layers.append(torch.nn.ReLU())
        self.layers = torch.nn.ModuleList(layers)
        self.batchnorm = torch.nn.BatchNorm1d(hidden_layers[-1])
        self.dropout = torch.nn.Dropout(dropout)
        
        self.disease_layer = torch.nn.Linear(hidden_layers[-1],len(Const.primary_disease_states))
        self.nodal_disease_layer = torch.nn.Linear(hidden_layers[-1],len(Const.nodal_disease_states))
        #dlt ratings are 0-4 even though they don't always appear
        
        assert( state in [1,2])
        if state == 1:
            self.dlt_layers = torch.nn.ModuleList([torch.nn.Linear(hidden_layers[-1],5) for i in Const.dlt1])
            self.treatment_layer = torch.nn.Linear(hidden_layers[-1],len(Const.modifications))
        else:
            #we only have dlt yes or no for the second state?
            self.dlt_layers = torch.nn.ModuleList([torch.nn.Linear(hidden_layers[-1],2) for i in Const.dlt2])
            self.treatment_layer = torch.nn.Linear(hidden_layers[-1],len(Const.ccs))
        self.softmax = torch.nn.LogSoftmax(dim=1)
        
    def forward(self,x):
        x = self.input_dropout(x)
        for layer in self.layers:
            x = layer(x)
#         x = self.batchnorm(x)
        x = self.dropout(x)
        x_pd = self.disease_layer(x)
        x_nd = self.nodal_disease_layer(x)
        x_mod = self.treatment_layer(x)
        x_dlts = [layer(x) for layer in self.dlt_layers]
        
        x_pd = self.softmax(x_pd)
        x_nd = self.softmax(x_nd)
        x_mod = self.softmax(x_mod)
        #dlts are array of nbatch x n_dlts x predictions
        x_dlts = torch.stack([self.softmax(xx) for xx in x_dlts],axis=1)
        return [x_pd, x_nd, x_mod, x_dlts]
    


In [391]:
def nllloss(ytrue,ypred):
    #nll loss with argmax added in
    loss = torch.nn.NLLLoss()
    return loss(ypred,ytrue.argmax(axis=1))

def state_loss(ytrue,ypred,weights=[1,1,1,1]):
    pd_loss = nllloss(ytrue[0],ypred[0])*weights[0]
    nd_loss = nllloss(ytrue[1],ypred[1])*weights[1]
    mod_loss = nllloss(ytrue[2],ypred[2])*weights[2]
    loss = pd_loss + nd_loss + mod_loss
    dlt_true = ytrue[3]
    dlt_pred = ypred[3]
    ndlt = dlt_true.shape[1]
    nloss = torch.nn.NLLLoss()
    for i in range(ndlt):
        dlt_loss = nloss(dlt_pred[:,i,:],dlt_true[:,i].type(torch.LongTensor))
        loss += dlt_loss*weights[3]/ndlt
    return loss

from sklearn.metrics import balanced_accuracy_score, roc_auc_score

def mc_metrics(yt,yp):
    yt = yt .cpu().detach().numpy()
    yp = yp.cpu().detach().numpy()
    #this is a catch for when I se the dlt prediction format (encoded integer ordinal, predict as a categorical and take the argmax)
    if yt.ndim > 1:
        bacc = balanced_accuracy_score(yt.argmax(axis=1),yp.argmax(axis=1))
        roc_micro = roc_auc_score(yt,yp,average='micro')
        roc_macro = roc_auc_score(yt,yp,average='macro')
        return {'accuracy': bacc, 'roc_micro': roc_micro,'roc_macro': roc_macro}
    else:
        yp = yp.argmax(axis=1)
        bacc = balanced_accuracy_score(yt,yp)
        roc = roc_auc_score(yt > 0,yp > 0)
        error = np.mean((yt-yp)**2)
        return {'accuracy': bacc, 'mse': error, 'auc': roc}

def state_metrics(ytrue,ypred):
    pd_metrics = mc_metrics(ytrue[0],ypred[0])
    nd_metrics = mc_metrics(ytrue[1],ypred[1])
    mod_metrics = mc_metrics(ytrue[1],ypred[1])
    
    dlt_metrics = []
    dlt_true = ytrue[3]
    dlt_pred = ypred[3]
    ndlt = dlt_true.shape[1]
    nloss = torch.nn.NLLLoss()
    for i in range(ndlt):
        dm = mc_metrics(dlt_true[:,i],dlt_pred[:,i,:])
        dlt_metrics.append(dm)
    dlt_acc =[d['accuracy'] for d in dlt_metrics]
    dlt_error = [d['mse'] for d in dlt_metrics]
    dlt_auc = [d['auc'] for d in dlt_metrics]
    return {'pd': pd_metrics,'nd': nd_metrics,'mod': mod_metrics,'dlts': {'accuracy': dlt_acc,'accuracy_mean': np.mean(dlt_acc),'auc': dlt_auc,'auc_mean': np.mean(dlt_auc)}}
    
ytrue = [df_to_torch(temp) for temp in data.get_intermediate_outcomes()]
ypred = outcomes1_model(x)
state_metrics(ytrue,ypred)



{'pd': {'accuracy': 0.4825127057996726,
  'roc_micro': 0.5565003663898388,
  'roc_macro': 0.5933534338186263},
 'nd': {'accuracy': 0.33612843612843607,
  'roc_micro': 0.4755714131590709,
  'roc_macro': 0.5566864156955712},
 'mod': {'accuracy': 0.33612843612843607,
  'roc_micro': 0.4755714131590709,
  'roc_macro': 0.5566864156955712},
 'dlts': {'accuracy': [0.34924973965729433,
   0.19270503900484925,
   0.16143156291642868,
   0.14562747035573123,
   0.09380863039399624,
   0.35205992509363293,
   0.04690431519699812,
   0.07082766611253842],
  'accuracy_mean': 0.17657679359143363,
  'auc': [0.5134345442496536,
   0.5815939278937381,
   0.5441516412390199,
   0.5370882740447958,
   0.42714196372732954,
   0.5280898876404494,
   0.3802376485303314,
   0.5307429236702321],
  'auc_mean': 0.5053101013744437}}

In [None]:
from sklearn.ensemble import RandomForestClassifier

def train_state_rf(model_args={}):
    ids = get_dt_ids()
    
    dataset = DTDataset()

    train_ids = ids[0:int(len(ids)*.7)]
    test_ids = ids[int(len(ids)*.7):]
    
    xtrain = dataset.get_state('baseline',ids=train_ids)
    xtest = dataset.get_state('baseline',ids=test_ids)
    ytrain = dataset.get_intermediate_outcomes(ids=train_ids)
    ytest = dataset.get_intermediate_outcomes(ids=test_ids)
    
    model = RandomForestClassifier(**model_args)
    
    model = model.fit(xtrain)
    return model.predict()

In [396]:
def train_state1(model_args={},lr=.01,epochs=10,patience=1000,weights=[1,1,1,10]):
    
    ids = get_dt_ids()
    
    dataset = DTDataset()
    
    
    model = OutcomeSimulator(dataset.state_sizes['baseline'],state=1,**model_args)
    
    train_ids = ids[0:int(len(ids)*.7)]
    test_ids = ids[int(len(ids)*.7):]
    
    
    xtrain = dataset.get_state('baseline',ids=train_ids)
    xtest = dataset.get_state('baseline',ids=test_ids)
    ytrain = dataset.get_intermediate_outcomes(ids=train_ids)
    ytest = dataset.get_intermediate_outcomes(ids=test_ids)

    xtrain = df_to_torch(xtrain)
    xtest = df_to_torch(xtest)
    ytrain = [df_to_torch(t) for t in ytrain]
    ytest= [df_to_torch(t) for t in ytest]
    
    normalize = lambda x: (x - xtrain.mean(axis=0)+.01)/(xtrain.std(axis=0)+.01)
    unnormalize = lambda x: (x * (xtrain.std(axis=0) +.01)) + xtrain.mean(axis=0) - .01
    
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)
    best_val_loss = 1000000000000000000000000000
    best_loss_metrics = {}
    for epoch in range(epochs):
        
        model.train(True)
        optimizer.zero_grad()
        
        xtrain_sample = xtrain#[torch.randint(len(xtrain),(len(xtrain),) )]
        ypred = model(normalize(xtrain_sample))
        loss = state_loss(ytrain,ypred,weights=weights)
        loss.backward()
        optimizer.step()
        print('epoch',epoch,'train loss',loss.item())
        
        model.eval()
        yval = model(normalize(xtest))
        val_loss = state_loss(ytest,yval,weights=weights)
        val_metrics = state_metrics(ytest,yval)
        if val_loss.item() < best_val_loss:
            best_val_loss = val_loss.item()
            best_loss_metrics = val_metrics
            steps_since_improvement = 0
        else:
            steps_since_improvement += 1
        print('val loss',val_loss.item())
        print('______________')
        if steps_since_improvement > patience:
            break
    print('best loss',best_val_loss,best_loss_metrics)
    return model.eval()
model = train_state1()

  df = pd.read_csv(file)
  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


epoch 0 train loss 20.273086547851562
val loss 4.8459930419921875
______________
epoch 1 train loss 4.562682628631592
val loss 5.105750560760498
______________
epoch 2 train loss 4.529652118682861
val loss 4.746618270874023
______________
epoch 3 train loss 4.034062385559082
val loss 4.432910442352295
______________
epoch 4 train loss 3.5463716983795166
val loss 4.185063362121582
______________
epoch 5 train loss 3.1654531955718994
val loss 3.6867668628692627
______________
epoch 6 train loss 2.677767515182495
val loss 3.1905763149261475
______________
epoch 7 train loss 2.332118034362793
val loss 2.8737423419952393
______________
epoch 8 train loss 2.1158735752105713
val loss 2.7195701599121094
______________
epoch 9 train loss 2.135974407196045
val loss 2.6624999046325684
______________
best loss 2.6624999046325684 {'pd': {'accuracy': 0.4856220657276995, 'roc_micro': 0.712682379349046, 'roc_macro': 0.6423000688223298}, 'nd': {'accuracy': 0.4897025171624714, 'roc_micro': 0.68268151601

In [401]:
from captum.attr import IntegratedGradients

ig = IntegratedGradients(model)
test = df_to_torch(DTDataset().get_state('baseline'))
ig.attribute(test,torch.zeros(test.shape),target=0)

  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


AttributeError: 'list' object has no attribute 'shape'