In [1]:
import numpy as np
import pandas as pd
import torch
import re
pd.set_option('display.max_rows', 200)

In [49]:
class Const:
    data_dir = '../data'
    twin_data = data_dir + 'digital_twin_data.csv'
    twin_ln_data = data_dir + 'digital_twin_ln_data.csv'
    
    rename_dict = {
        'Dummy ID': 'id',
        'Age at Diagnosis (Calculated)': 'age',
        'Feeding tube 6m': 'FT',
        'Affected Lymph node UPPER': 'affected_nodes',
        'Aspiration rate(Y/N)': 'AS',
        'Neck boost (Y/N)': 'neck_boost',
        'Gender': 'gender',
        'Tm Laterality (R/L)': 'laterality',
        'AJCC 8th edition': 'ajcc8',
        'AJCC 7th edition':'ajcc7',
        'N_category_full': 'N-category',
        'HPV/P16 status': 'hpv',
        'Tumor subsite (BOT/Tonsil/Soft Palate/Pharyngeal wall/GPS/NOS)': 'subsite',
        'Total dose': 'total_dose',
        'Therapeutic combination': 'treatment',
        'Smoking status at Diagnosis (Never/Former/Current)': 'smoking_status',
        'Smoking status (Packs/Year)': 'packs_per_year',
        'Overall Survival (1=alive,0=dead)': 'os',
        'Dose/fraction (Gy)': 'dose_fraction'
    }
    
    dlt_dict = {
         'Allergic reaction to Cetuximab': 'DLT_Other',
         'Cardiological (A-fib)': 'DLT_Other',
         'Dermatological': 'DLT_Dermatological',
         'Failure to Thrive': 'DLT_Other',
         'Failure to thrive': 'DLT_Other',
         'GIT [elevated liver enzymes]': 'DLT_Gastrointestinal',
         'Gastrointestina': 'DLT_Gastrointestinal',
         'Gastrointestinal': 'DLT_Gastrointestinal',
         'General': 'DLT_Other',
         'Hematological': 'DLT_Hematological',
         'Hematological (Neutropenia)': 'DLT_Hematological',
         'Hyponatremia': 'DLT_Other',
         'Immunological': 'DLT_Other',
         'Infection': 'DLT_Infection (Pneumonia)',
         'NOS': 'DLT_Other',
         'Nephrological': 'DLT_Nephrological',
         'Nephrological (ARF)': 'DLT_Nephrological',
         'Neurological': 'DLT_Neurological',
         'Neutropenia': 'DLT_Hematological',
         'Nutritional': 'DLT_Other',
         'Pancreatitis': 'DLT_Other',
         'Pulmonary': 'DLT_Other',
         'Respiratory (Pneumonia)': 'DLT_Infection (Pneumonia)',
         'Sepsis': 'DLT_Infection (Pneumonia)',
         'Suboptimal response to treatment' : 'DLT_Other',
         'Vascular': 'DLT_Vascular'
    }
    
    decision1 = 'Decision 1 (Induction Chemo) Y/N'
    decision2 = 'Decision 2 (CC / RT alone)'
    decision3 = 'Decision 3 Neck Dissection (Y/N)'
    decisions = [decision1,decision2, decision3]
    outcomes = ['Overall Survival (4 Years)', 'FT', 'Aspiration rate Post-therapy']
    
    modification_types = {
        0: 'no_dose_adjustment',
        1: 'dose_modified',
        2: 'dose_delayed',
        3: 'dose_cancelled',
        4: 'dose_delayed_&_modified',
        5: 'regiment_modification',
        9: 'unknown'
    }
    
    cc_types = {
        0: 'cc_none',
        1: 'cc_platinum',
        2: 'cc_cetuximab',
        3: 'cc_others',
    }
    
    primary_disease_states = ['CR Primary','PR Primary','SD Primary']
    nodal_disease_states = [t.replace('Primary','Nodal') for t in primary_disease_states]
    dlt1 = list(set(dlt_dict.values()))
    state2 = list(modification_types.values()) + primary_disease_states+nodal_disease_states +dlt1 #+['No imaging 0=N,1=Y']
    
    primary_disease_states2 = [t + ' 2' for t in primary_disease_states]
    nodal_disease_states2 = [t + ' 2' for t in nodal_disease_states]
    dlt2 = [d + ' 2' for d in dlt1]
    
    state3 = list(cc_types.values()) + primary_disease_states2 + nodal_disease_states2 + dlt2
    
Const.modification_types

{0: 'no_dose_adjustment',
 1: 'dose_modified',
 2: 'dose_delayed',
 3: 'dose_cancelled',
 4: 'dose_delayed_&_modified',
 5: 'regiment_modification',
 9: 'unknown'}

In [57]:
def preprocess(data_cleaned):
    #this was Elisa's preprocessing except I removed all the Ifs because that's dumb
    if len(data_cleaned.shape) < 2:
        data_cleaned = pd.DataFrame([data], columns=data.index)
        
    data_cleaned.loc[data_cleaned['Aspiration rate Pre-therapy'] == 'N', 'Aspiration rate Pre-therapy'] = 0
    data_cleaned.loc[data_cleaned['Aspiration rate Pre-therapy'] == 'Y', 'Aspiration rate Pre-therapy'] = 1

    data_cleaned.loc[data_cleaned['Pathological Grade'] == 'I', 'Pathological Grade'] = 1
    data_cleaned.loc[data_cleaned['Pathological Grade'] == 'II', 'Pathological Grade'] = 2
    data_cleaned.loc[data_cleaned['Pathological Grade'] == 'III', 'Pathological Grade'] = 3
    data_cleaned.loc[data_cleaned['Pathological Grade'] == 'IV', 'Pathological Grade'] = 4

    data_cleaned.loc[(data_cleaned['T-category'] == 'Tx') | (data_cleaned['T-category'] == 'Tis'), 'T-category'] = 1
    data_cleaned.loc[data_cleaned['T-category'] == 'T1', 'T-category'] = 1
    data_cleaned.loc[data_cleaned['T-category'] == 'T2', 'T-category'] = 2
    data_cleaned.loc[data_cleaned['T-category'] == 'T3', 'T-category'] = 3
    data_cleaned.loc[data_cleaned['T-category'] == 'T4', 'T-category'] = 4

    data_cleaned.loc[data_cleaned['N-category'] == 'N0', 'N-category'] = 0
    data_cleaned.loc[data_cleaned['N-category'] == 'N1', 'N-category'] = 1
    data_cleaned.loc[data_cleaned['N-category'] == 'N2', 'N-category'] = 2
    data_cleaned.loc[data_cleaned['N-category'] == 'N3', 'N-category'] = 3

    data_cleaned.loc[data_cleaned['N-category_8th_edition'] == 'N0', 'N-category_8th_edition'] = 0
    data_cleaned.loc[data_cleaned['N-category_8th_edition'] == 'N1', 'N-category_8th_edition'] = 1
    data_cleaned.loc[data_cleaned['N-category_8th_edition'] == 'N2', 'N-category_8th_edition'] = 2
    data_cleaned.loc[data_cleaned['N-category_8th_edition'] == 'N3', 'N-category_8th_edition'] = 3

    data_cleaned.loc[data_cleaned['AJCC 7th edition'] == 'I', 'AJCC 7th edition'] = 1
    data_cleaned.loc[data_cleaned['AJCC 7th edition'] == 'II', 'AJCC 7th edition'] = 2
    data_cleaned.loc[data_cleaned['AJCC 7th edition'] == 'III', 'AJCC 7th edition'] = 3
    data_cleaned.loc[data_cleaned['AJCC 7th edition'] == 'IV', 'AJCC 7th edition'] = 4

    data_cleaned.loc[data_cleaned['AJCC 8th edition'] == 'I', 'AJCC 8th edition'] = 1
    data_cleaned.loc[data_cleaned['AJCC 8th edition'] == 'II', 'AJCC 8th edition'] = 2
    data_cleaned.loc[data_cleaned['AJCC 8th edition'] == 'III', 'AJCC 8th edition'] = 3
    data_cleaned.loc[data_cleaned['AJCC 8th edition'] == 'IV', 'AJCC 8th edition'] = 4

    data_cleaned.loc[data_cleaned['Gender'] == 'Male', 'Gender'] = 1
    data_cleaned.loc[data_cleaned['Gender'] == 'Female', 'Gender'] = 0


    data_cleaned.loc[data_cleaned['HPV/P16 status'] == 'Positive', 'HPV/P16 status'] = 1
    data_cleaned.loc[data_cleaned['HPV/P16 status'] == 'Negative', 'HPV/P16 status'] = -1
    data_cleaned.loc[data_cleaned['HPV/P16 status'] == 'Unknown', 'HPV/P16 status'] = 0

    data_cleaned.loc[data_cleaned[
                         'Smoking status at Diagnosis (Never/Former/Current)'] == 'Formar', 'Smoking status at Diagnosis (Never/Former/Current)'] = .5
    data_cleaned.loc[data_cleaned[
                         'Smoking status at Diagnosis (Never/Former/Current)'] == 'Current', 'Smoking status at Diagnosis (Never/Former/Current)'] = 1
    data_cleaned.loc[data_cleaned[
                         'Smoking status at Diagnosis (Never/Former/Current)'] == 'Never', 'Smoking status at Diagnosis (Never/Former/Current)'] = 0


    data_cleaned.loc[data_cleaned['Chemo Modification (Y/N)'] == 'Y', 'Chemo Modification (Y/N)'] = 1

    data_cleaned.loc[data_cleaned['DLT (Y/N)'] == 'N', 'DLT (Y/N)'] = 0
    data_cleaned.loc[data_cleaned['DLT (Y/N)'] == 'Y', 'DLT (Y/N)'] = 1

    data_cleaned['DLT_Other'] = 0
    for index, row in data_cleaned.iterrows():
        if row['DLT_Type'] == 'None':
            continue
        for i in re.split('&|and|,', row['DLT_Type']):
            if i.strip() != '' and data_cleaned.loc[index, Const.dlt_dict[i.strip()]] == 0:
                data_cleaned.loc[index, Const.dlt_dict[i.strip()]] = 1

    data_cleaned.loc[data_cleaned['Decision 2 (CC / RT alone)'] == 'RT alone', 'Decision 2 (CC / RT alone)'] = 0
    data_cleaned.loc[data_cleaned['Decision 2 (CC / RT alone)'] == 'CC', 'Decision 2 (CC / RT alone)'] = 1

    data_cleaned.loc[data_cleaned['CC modification (Y/N)'] == 'N', 'CC modification (Y/N)'] = 0
    data_cleaned.loc[data_cleaned['CC modification (Y/N)'] == 'Y', 'CC modification (Y/N)'] = 1

    data_cleaned['DLT_Dermatological 2'] = 0
    data_cleaned['DLT_Neurological 2'] = 0
    data_cleaned['DLT_Gastrointestinal 2'] = 0
    data_cleaned['DLT_Hematological 2'] = 0
    data_cleaned['DLT_Nephrological 2'] = 0
    data_cleaned['DLT_Vascular 2'] = 0
    data_cleaned['DLT_Infection (Pneumonia) 2'] = 0
    data_cleaned['DLT_Other 2'] = 0
    for index, row in data_cleaned.iterrows():
        if row['DLT 2'] == 'None':
            continue
        for i in re.split('&|and|,', row['DLT 2']):
            if i.strip() != '':
                data_cleaned.loc[index, Const.dlt_dict[i.strip()] + ' 2'] = 1

    data_cleaned.loc[
        data_cleaned['Decision 3 Neck Dissection (Y/N)'] == 'N', 'Decision 3 Neck Dissection (Y/N)'] = 0
    data_cleaned.loc[
        data_cleaned['Decision 3 Neck Dissection (Y/N)'] == 'Y', 'Decision 3 Neck Dissection (Y/N)'] = 1

    return data_cleaned

def merge_editions(row,basecol='AJCC 8th edition',fallback='AJCC 7th edition'):
    if pd.isnull(row[basecol]):
        return row[fallback]
    return row[basecol]


def preprocess_dt_data(df,extra_to_keep=None):
    
    to_keep = ['id','hpv','age','packs_per_year','smoking_status','gender','Aspiration rate Pre-therapy','total_dose','dose_fraction'] 
    to_onehot = ['T-category','N-category','AJCC','Pathological Grade','subsite','treatment','ln_cluster']
    if extra_to_keep is not None:
        to_keep = to_keep + [c for c in extra_to_keep if c not in to_keep and c not in to_onehot]
    
    decisions =Const.decisions
    outcomes = Const.outcomes
    
    modification_types = {
        0: 'no_dose_adjustment',
        1: 'dose_modified',
        2: 'dose_delayed',
        3: 'dose_cancelled',
        4: 'dose_delayed_&_modified',
        5: 'regiment_modification',
        9: 'unknown'
    }
    
    cc_types = {
        0: 'cc_none',
        1: 'cc_platinum',
        2: 'cc_cetuximab',
        3: 'cc_others',
    }
    
    for k,v in Const.cc_types.items():
        df[v] = df['CC Regimen(0= none, 1= platinum based, 2= cetuximab based, 3= others, 9=unknown)'].apply(lambda x: int(Const.cc_types.get(int(x),0) == v))
        to_keep.append(v)
    for k,v in Const.modification_types.items():
        name = 'Modification Type (0= no dose adjustment, 1=dose modified, 2=dose delayed, 3=dose cancelled, 4=dose delayed & modified, 5=regimen modification, 9=unknown)'
        df[v] = df[name].apply(lambda x: int(Const.modification_types.get(int(x),0) == v))
        to_keep.append(v)
    #Features to keep. I think gender is is in 
    
    keywords = []
    for keyword in keywords:
        toadd = [c for c in df.columns if keyword in c and c not in to_keep]
        to_keep = to_keep + toadd
    
    df['packs_per_year'] = df['packs_per_year'].apply(lambda x: str(x).replace('>','').replace('<','')).astype(float).fillna(0)
    #so I'm actually not sure if this is biological sex or gender given this is texas
    df['AJCC'] = df.apply(lambda row: merge_editions(row,'ajcc8','ajcc7'),axis=1)
    df['N-category'] = df.apply(lambda row: merge_editions(row,'N-category_8th_edition','N-category'),axis=1)
    
    dummy_df = pd.get_dummies(df[to_onehot].fillna(0).astype(str),drop_first=False)
    for col in dummy_df.columns:
        df[col] = dummy_df[col]
        to_keep.append(col)
        
    yn_to_binary = ['FT','Aspiration rate Post-therapy']
    for col in yn_to_binary:
        df[col] = df[col].apply(lambda x: int(x == 'Y'))
        
    to_keep = to_keep + [c for c in df.columns if 'DLT' in c]
    
        
    for statelist in [Const.state2,Const.state3,Const.decisions,Const.outcomes]:
        toadd = [c for c in statelist if c not in to_keep]
        to_keep = to_keep + toadd
    return df[to_keep].set_index('id')



In [58]:
class DTDataset():
    
    def __init__(self,data_file = '../data/digital_twin_data.csv',ln_data_file = '../data/digital_twin_ln_data.csv'):
        df = pd.read_csv(data_file)
        
        df = preprocess(df)
        df = df.rename(columns = Const.rename_dict).copy()
        df = df.drop('MRN OPC',axis=1)

        ln_data = pd.read_csv(ln_data_file)
        ln_data = ln_data.rename(columns={'cluster':'ln_cluster'})
        self.ln_cols = [c for c in ln_data.columns if c not in df.columns]
        df = df.merge(ln_data,on='id')
        self.df=df
        self.processed_df = preprocess_dt_data(df,self.ln_cols).fillna(0)
        
        self.means = self.processed_df.mean(axis=0)
        self.stds = self.processed_df.std(axis=0)
        self.maxes = self.processed_df.max(axis=0)
        self.mins = self.processed_df.min(axis=0)
        
        arrays = self.get_states()
        self.state_sizes = {k: (v.shape[1] if v.ndim > 1 else 1) for k,v in arrays.items()}
        
    def get_data(self):
        return self.processed_df
    
    def sample(self,frac=.5):
        return self.processed_df.sample(frac=frac)
    
    def split_sample(self,ratio = .3):
        assert(ratio > 0 and ratio <= 1)
        df1 = self.processed_df.sample(frac=1-ratio)
        df2 = self.processed_df.drop(index=df1.index)
        return df1,df2
    
    def get_states(self,fixed=None,ids = None):
        processed_df = self.processed_df.copy()
        if ids is not None:
            processed_df = processed_df.loc[ids]
        if fixed is not None:
            for col,val in fixed.items():
                if col in processed_df.columns:
                    processed_df[col] = val
                else:
                    print('bad fixed entry',col)
                    
        to_skip = ['CC Regimen(0= none, 1= platinum based, 2= cetuximab based, 3= others, 9=unknown)','DLT_Type','DLT 2'] + [c for c in processed_df.columns if 'treatment' in c]
        other_states = set(Const.decisions + Const.state3 + Const.state2 + Const.outcomes  + to_skip)

        base_state = sorted([c for c in processed_df.columns if c not in other_states])

        dlt1 = Const.dlt1
        state2 = Const.state2 
        state3 =  Const.state3 

        outcomes = Const.outcomes
        decisions= Const.decisions
        #intermediate states are only udated values. Models should use baseline + state2 etc
        results = {
            'baseline': processed_df[base_state],
            'state2': processed_df[state2],
            'state3': processed_df[state3],
            'outcomes': processed_df[outcomes],
            'decision1': processed_df[decisions[0]],
            'decision2': processed_df[decisions[1]],
            'decision3': processed_df[decisions[2]],
        }
    
        return results
    
    def normalize(self,df):
        means = self.means[df.columns]
        std = self.stds[df.columns]
        return ((df - means)/std).fillna(0)
    
    def get_state(self,num=1,normalize=False,**kwargs):
        states = self.get_states(**kwargs)
        state = states['baseline']
        assert(num in [1,2,3])
        if num > 1:
            state = state.merge(states['state'+str(num)],on='id')
        if normalize:
            return self.normalize(state)
        return state
    
    def get_intermediate_outcomes(self,num=2,**kwargs):
        states = self.get_states(**kwargs)
        assert(num in [2,3])
        return states['state'+str(num)]
    
    def get_outcomes(self,**kwargs):
        states = self.get_outcomes(**kwargs)
        return states['outcomes']
    
data = DTDataset()
data.get_intermediate_outcomes()

  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


Unnamed: 0_level_0,no_dose_adjustment,dose_modified,dose_delayed,dose_cancelled,dose_delayed_&_modified,regiment_modification,unknown,CR Primary,PR Primary,SD Primary,...,PR Nodal,SD Nodal,DLT_Gastrointestinal,DLT_Other,DLT_Neurological,DLT_Dermatological,DLT_Infection (Pneumonia),DLT_Vascular,DLT_Nephrological,DLT_Hematological
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
3,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
5,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
6,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
7,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
8,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10201,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
10202,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
10203,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
10204,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [59]:
[c for c in data.df.columns if 'modification' in c.lower()]

['Chemo Modification (Y/N)',
 'Modification Type (0= no dose adjustment, 1=dose modified, 2=dose delayed, 3=dose cancelled, 4=dose delayed & modified, 5=regimen modification, 9=unknown)',
 'CC modification (Y/N)',
 'regiment_modification']

In [60]:
def df_to_torch(df,ttype  = torch.FloatTensor):
    values = df.values.astype(float)
    values = torch.from_numpy(values)
    return values.type(ttype)

df_to_torch(data.get_state(1,fixed={'1A': 3},normalize=True,ids=data.split_sample()[1].index.values))

tensor([[34.7387, -0.1063,  0.0000,  ..., -0.1063, -0.8433, -1.1674],
        [34.7387, -0.1063,  0.0000,  ..., -0.1063, -0.8433,  0.5397],
        [34.7387, -0.1063,  0.0000,  ..., -0.1063, -0.8433,  0.5570],
        ...,
        [34.7387, -0.1063,  0.0000,  ..., -0.1063, -0.8433,  1.4192],
        [34.7387, -0.1063,  0.0000,  ..., -0.1063, -0.8433,  0.5570],
        [34.7387, -0.1063,  0.0000,  ..., -0.1063, -0.8433,  0.5397]])

In [61]:
class OutcomeSimulator1(torch.nn.Module):
    
    def __init__(self,
                 input_size,
                 hidden_layers = [1000],
                 dropout = .5,
                 input_dropout=0,
                ):
        torch.nn.Module.__init__(self)
        
        self.input_dropout = torch.nn.Dropout(input_dropout)
        
        first_layer =torch.nn.Linear(input_size,hidden_layers[0],bias=True)
        layers = [first_layer,torch.nn.ReLU()]
        curr_size = hidden_layers[0]
        for ndim in hidden_layers[1:]:
            layer = torch.nn.Linear(curr_size,ndim)
            curr_size = ndim
            layers.append(layer)
            layers.append(torch.nn.ReLU())
        self.layers = torch.nn.ModuleList(layers)
        self.batchnorm = torch.nn.BatchNorm1d(hidden_layers[-1])
        self.dropout = torch.nn.Dropout(dropout)
        
        self.disease_layer = torch.nn.Linear(hidden_layers[-1],len(Const.primary_disease_states))
        self.nodal_disease_layer = torch.nn.Linear(hidden_layers[-1],len(Const.nodal_disease_states))
#         final_layers = []
#         for size in output_sizes:
#             sin= torch.nn.Linear(hidden_layers[-1],output_size)
        self.sigmoid = torch.nn.Sigmoid()
        
    def forward(self,x):
        x = self.input_dropout(x)
        for layer in self.layers:
            x = layer(x)
        x = self.batchnorm(x)
        x = self.dropout(x)
        x = self.final_layer(x)
        x = self.sigmoid(x)
        return x
    
outcomes1_model = OutcomeSimulator(data.state_sizes['baseline'],data.state_sizes['state2'])
outcomes2_model = OutcomeSimulator(data.state_sizes['baseline']+data.state_sizes['state2'],data.state_sizes['state3'])
x = df_to_torch(data.get_state(1,normalize=True))
y = df_to_torch(data.get_intermediate_outcomes(2))
outcomes1_model(x)

tensor([[0.5358, 0.5109, 0.5215,  ..., 0.5798, 0.3287, 0.4508],
        [0.3211, 0.4001, 0.7209,  ..., 0.5084, 0.4035, 0.5867],
        [0.3345, 0.7183, 0.4821,  ..., 0.1275, 0.3235, 0.4268],
        ...,
        [0.3504, 0.5385, 0.2055,  ..., 0.3352, 0.5954, 0.3697],
        [0.4871, 0.4768, 0.5843,  ..., 0.5876, 0.7455, 0.1747],
        [0.5662, 0.6564, 0.2497,  ..., 0.5460, 0.4809, 0.3786]],
       grad_fn=<SigmoidBackward0>)

In [8]:
y.detach().numpy().max(axis=0)

array([1., 1., 1., 1., 1., 3., 3., 3., 4., 1., 3., 1., 1., 4., 1., 1., 1.,
       1., 1., 1., 1.], dtype=float32)

In [62]:
from sklearn.metrics import roc_auc_score , f1_score
def multi_bce_loss(ypred,target,weights=None):
    #loss for predicting multiple of non-exclusive binary values
    nclasses = ypred.shape[1]
    if weights is None:
        weights = torch.ones(nclasses).to('cuda' if torch.cuda.is_available() else 'cpu')
    bce = torch.nn.BCELoss()
    total_loss = torch.tensor([0],dtype=torch.float32,device='cuda' if torch.cuda.is_available() else 'cpu')
    for i in range(nclasses):
        closs = bce(ypred[:,i],target[:,i])
        total_loss += weights[i]*closs.item()
    return total_loss

def multiclass_metrics(ypred,target):
    nclasses = ypred.shape[1]
    ypred = ypred.cpu().detach().numpy().astype(float)
    target = (target.cpu().detach().numpy() > 0).astype(float)
    aucs = []
    f1s = []
    for i in range(nclasses):
        if target[:,i].std() < .0001 and ypred[:].argmax(axis=1).sum() > 0:
            auc_score = -1
            f1_scores = -1
        else:
            auc_score = roc_auc_score(target[:,i],ypred[:,i])
            f1_scores = f1_score(target[:,i],ypred[:].argmax(axis=1) == i)
        aucs.append(auc_score)
        f1s.append(f1_scores)
    return {'auc': aucs, 'f1': f1s}



In [63]:
y.shape

torch.Size([536, 21])

In [64]:
def train_state1(dataset,model_args={},lr=.00001,epochs=1000,patience=10):
    model = OutcomeSimulator(dataset.state_sizes['baseline'],dataset.state_sizes['state2'],**model_args)
    for param in model.parameters():
        param.required_grad=True
    train_idx,test_idx = [i.index for i in dataset.split_sample()]
    xtrain = df_to_torch(dataset.get_state(1,normalize=False,ids=train_idx))
    xtest = df_to_torch(dataset.get_state(1,normalize=False,ids=test_idx))
    ytrain = df_to_torch(data.get_intermediate_outcomes(2,ids=train_idx),torch.LongTensor)
    ytest = df_to_torch(data.get_intermediate_outcomes(2,ids=test_idx),torch.LongTensor)
    
    maxy = torch.max(torch.tensor([ytrain.max(),ytest.max()]))
    ytrain = (ytrain > 0).type(torch.FloatTensor)
    ytest = (ytest > 0).type(torch.FloatTensor)
    model = model.to('cuda')
    if torch.cuda.is_available():
        model = model.to('cuda')
        xtrain = xtrain.to('cuda')
        xtest = xtest.to('cuda')
        ytrain = ytrain.to('cuda')
        ytest = ytest.to('cuda')
        
   
    
    ytrain.requires_grad=True
    normalize = lambda x: (x - xtrain.mean(axis=0)+.01)/(xtrain.std(axis=0)+.01)
    unnormalize = lambda x: (x * (xtrain.std(axis=0) +.01)) + xtrain.mean(axis=0) - .01
    
    optimizer = torch.optim.SGD(model.parameters(),lr=lr)
    best_val_loss = 1000000000000000000000000000
    best_val_metrics = {}
    for epoch in range(epochs):
        
        model.train(True)
        optimizer.zero_grad()
        
        ypred = model(normalize(xtrain))
        print(ypred.get_device(),ytrain.get_device())
        loss = multi_bce_loss(ypred.to('cuda'),ytrain.to('cuda'))
        loss.requires_grad=True
        print(loss.item())
        loss.backward()
        optimizer.step()
        print('epoch',epoch,'train loss',loss.item())
        
        model.train(False)
        yval = model(normalize(xtest))
        val_loss = multi_bce_loss(yval,ytest)
        val_metrics = multiclass_metrics(yval,ytest)
        if val_loss.item() < best_val_loss:
            best_val_loss = val_loss.item()
            best_val_metrics = val_metrics
            steps_since_improvement = 0
        else:
            steps_since_improvement += 1
        print('val loss',val_loss.item(),np.mean(val_metrics['auc']),np.mean(val_metrics['f1']))
        if steps_since_improvement > patience:
            break
    print('best loss',best_val_loss)
    best_val_metrics['loss']= best_val_loss
    return model, best_val_metrics
train_state1(data)

0 0
16.000381469726562
epoch 0 train loss 16.000381469726562
val loss 15.041558265686035 0.43012409908877586 -0.04082214082214083
0 0
16.07006072998047
epoch 1 train loss 16.07006072998047
val loss 15.024404525756836 0.429974817128273 -0.042326121244220326
0 0
16.033220291137695
epoch 2 train loss 16.033220291137695
val loss 15.009340286254883 0.4302254228160023 -0.045679855632437476
0 0
15.986144065856934
epoch 3 train loss 15.986144065856934
val loss 14.99630355834961 0.4295757641202247 -0.044862971014732536
0 0
15.930232048034668
epoch 4 train loss 15.930232048034668


  precision = _prf_divide(tp_sum, pred_sum,
  precision = _prf_divide(tp_sum, pred_sum,
  precision = _prf_divide(tp_sum, pred_sum,
  precision = _prf_divide(tp_sum, pred_sum,
  precision = _prf_divide(tp_sum, pred_sum,
  precision = _prf_divide(tp_sum, pred_sum,
  precision = _prf_divide(tp_sum, pred_sum,
  precision = _prf_divide(tp_sum, pred_sum,
  precision = _prf_divide(tp_sum, pred_sum,
  precision = _prf_divide(tp_sum, pred_sum,
  precision = _prf_divide(tp_sum, pred_sum,
  precision = _prf_divide(tp_sum, pred_sum,


val loss 14.985233306884766 0.4296094158072125 -0.04244945428520308
0 0
16.050121307373047
epoch 5 train loss 16.050121307373047
val loss 14.976068496704102 0.42996292784929985 -0.03620637169182446
0 0
16.150718688964844
epoch 6 train loss 16.150718688964844
val loss 14.968751907348633 0.4296945457580045 -0.03686327304748358
0 0
16.164289474487305
epoch 7 train loss 16.164289474487305
val loss 14.963224411010742 0.4294917381651211 -0.0319042944042944
0 0
16.07388687133789
epoch 8 train loss 16.07388687133789
val loss 14.95942497253418 0.42881449973886465 -0.03105395426823998
0 0
16.04673194885254
epoch 9 train loss 16.04673194885254
val loss 14.95728874206543 0.4282160876946966 -0.031189428087924333
0 0
16.081722259521484
epoch 10 train loss 16.081722259521484


  precision = _prf_divide(tp_sum, pred_sum,
  precision = _prf_divide(tp_sum, pred_sum,
  precision = _prf_divide(tp_sum, pred_sum,
  precision = _prf_divide(tp_sum, pred_sum,
  precision = _prf_divide(tp_sum, pred_sum,
  precision = _prf_divide(tp_sum, pred_sum,
  precision = _prf_divide(tp_sum, pred_sum,


val loss 14.956747055053711 0.4286299139100457 -0.03131795453224025
0 0
15.944466590881348
epoch 11 train loss 15.944466590881348
val loss 14.95772647857666 0.4290229554605365 -0.03131795453224025
0 0
15.902392387390137
epoch 12 train loss 15.902392387390137
val loss 14.960153579711914 0.428694135064083 -0.030930807803630094
0 0
15.986706733703613
epoch 13 train loss 15.986706733703613
val loss 14.963940620422363 0.42908663948198733 -0.03290525611954183
0 0
16.095243453979492
epoch 14 train loss 16.095243453979492
val loss 14.969000816345215 0.42945346846836074 -0.032381969881969876
0 0
16.13425636291504
epoch 15 train loss 16.13425636291504
val loss 14.975238800048828 0.42949804386162427 -0.03273210993799229
0 0
15.939824104309082
epoch 16 train loss 15.939824104309082
val loss 14.982558250427246 0.4293163071994378 -0.03480781201369437
0 0
15.958976745605469
epoch 17 train loss 15.958976745605469
val loss 14.990853309631348 0.4291753922555666 -0.036697614056104624
0 0
16.0497055053710

(OutcomeSimulator(
   (input_dropout): Dropout(p=0, inplace=False)
   (layers): ModuleList(
     (0): Linear(in_features=61, out_features=1000, bias=True)
     (1): ReLU()
   )
   (batchnorm): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (dropout): Dropout(p=0.5, inplace=False)
   (final_layer): Linear(in_features=1000, out_features=21, bias=True)
   (sigmoid): Sigmoid()
 ),
 {'auc': [0.6376716808371485,
   0.8075268817204301,
   0.3080168776371308,
   0.43312101910828027,
   0.4789029535864979,
   0.4365591397849462,
   -1,
   0.5909367396593674,
   0.5472636815920398,
   0.2515923566878981,
   0.6217948717948718,
   0.4561302681992337,
   0.8016877637130801,
   0.6419354838709677,
   0.7628205128205129,
   0.7064516129032259,
   0.6666666666666666,
   0.41874999999999996,
   -1,
   0.81875,
   0.6146496815286624],
  'f1': [0.2625,
   0.2222222222222222,
   0.0,
   0.0,
   0.0,
   0.0,
   -1,
   0.3333333333333333,
   0.10256410256410256,
   0.0