In [1]:
import numpy as np
import pandas as pd
import torch
import re
pd.set_option('display.max_rows', 200)

In [2]:
class Const:
    data_dir = '../data'
    twin_data = data_dir + 'digital_twin_data.csv'
    twin_ln_data = data_dir + 'digital_twin_ln_data.csv'
    
    rename_dict = {
        'Dummy ID': 'id',
        'Age at Diagnosis (Calculated)': 'age',
        'Feeding tube 6m': 'FT',
        'Affected Lymph node UPPER': 'affected_nodes',
        'Aspiration rate(Y/N)': 'AS',
        'Neck boost (Y/N)': 'neck_boost',
        'Gender': 'gender',
        'Tm Laterality (R/L)': 'laterality',
        'AJCC 8th edition': 'ajcc8',
        'AJCC 7th edition':'ajcc7',
        'N_category_full': 'N-category',
        'HPV/P16 status': 'hpv',
        'Tumor subsite (BOT/Tonsil/Soft Palate/Pharyngeal wall/GPS/NOS)': 'subsite',
        'Total dose': 'total_dose',
        'Therapeutic combination': 'treatment',
        'Smoking status at Diagnosis (Never/Former/Current)': 'smoking_status',
        'Smoking status (Packs/Year)': 'packs_per_year',
        'Overall Survival (1=alive,0=dead)': 'os',
        'Dose/fraction (Gy)': 'dose_fraction'
    }
    
    dlt_dict = {
         'Allergic reaction to Cetuximab': 'DLT_Other',
         'Cardiological (A-fib)': 'DLT_Other',
         'Dermatological': 'DLT_Dermatological',
         'Failure to Thrive': 'DLT_Other',
         'Failure to thrive': 'DLT_Other',
         'GIT [elevated liver enzymes]': 'DLT_Gastrointestinal',
         'Gastrointestina': 'DLT_Gastrointestinal',
         'Gastrointestinal': 'DLT_Gastrointestinal',
         'General': 'DLT_Other',
         'Hematological': 'DLT_Hematological',
         'Hematological (Neutropenia)': 'DLT_Hematological',
         'Hyponatremia': 'DLT_Other',
         'Immunological': 'DLT_Other',
         'Infection': 'DLT_Infection (Pneumonia)',
         'NOS': 'DLT_Other',
         'Nephrological': 'DLT_Nephrological',
         'Nephrological (ARF)': 'DLT_Nephrological',
         'Neurological': 'DLT_Neurological',
         'Neutropenia': 'DLT_Hematological',
         'Nutritional': 'DLT_Other',
         'Pancreatitis': 'DLT_Other',
         'Pulmonary': 'DLT_Other',
         'Respiratory (Pneumonia)': 'DLT_Infection (Pneumonia)',
         'Sepsis': 'DLT_Infection (Pneumonia)',
         'Suboptimal response to treatment' : 'DLT_Other',
         'Vascular': 'DLT_Vascular'
    }
    
    decision1 = 'Decision 1 (Induction Chemo) Y/N'
    decision2 = 'Decision 2 (CC / RT alone)'
    decision3 = 'Decision 3 Neck Dissection (Y/N)'
    decisions = [decision1,decision2, decision3]
    outcomes = ['Overall Survival (4 Years)', 'FT', 'Aspiration rate Post-therapy']
    
    modification_types = {
        0: 'no_dose_adjustment',
        1: 'dose_modified',
        2: 'dose_delayed',
        3: 'dose_cancelled',
        4: 'dose_delayed_&_modified',
        5: 'regiment_modification',
        9: 'unknown'
    }
    
    cc_types = {
        0: 'cc_none',
        1: 'cc_platinum',
        2: 'cc_cetuximab',
        3: 'cc_others',
    }
    
    primary_disease_states = ['CR Primary','PR Primary','SD Primary']
    nodal_disease_states = [t.replace('Primary','Nodal') for t in primary_disease_states]
    dlt1 = list(set(dlt_dict.values()))
    
    modifications =  list(modification_types.values())
    state2 = modifications + primary_disease_states+nodal_disease_states +dlt1 #+['No imaging 0=N,1=Y']
    
    primary_disease_states2 = [t + ' 2' for t in primary_disease_states]
    nodal_disease_states2 = [t + ' 2' for t in nodal_disease_states]
    dlt2 = [d + ' 2' for d in dlt1]
    
    ccs = list(cc_types.values())
    state3 = ccs + primary_disease_states2 + nodal_disease_states2 + dlt2
    name_dict = {
        'pd_state1': primary_disease_states,
        'nd_state1': nodal_disease_states,
        'chemo_state1': modifications,
        'chemo_state2': ccs,
        'pd_state2': primary_disease_states2,
        'pd_state2': nodal_disease_states2,
    }
    
Const.primary_disease_states

['CR Primary', 'PR Primary', 'SD Primary']

In [75]:
def preprocess(data_cleaned):
    #this was Elisa's preprocessing except I removed all the Ifs because that's dumb
    if len(data_cleaned.shape) < 2:
        data_cleaned = pd.DataFrame([data], columns=data.index)
        
    data_cleaned.loc[data_cleaned['Aspiration rate Pre-therapy'] == 'N', 'Aspiration rate Pre-therapy'] = 0
    data_cleaned.loc[data_cleaned['Aspiration rate Pre-therapy'] == 'Y', 'Aspiration rate Pre-therapy'] = 1

    data_cleaned.loc[data_cleaned['Pathological Grade'] == 'I', 'Pathological Grade'] = 1
    data_cleaned.loc[data_cleaned['Pathological Grade'] == 'II', 'Pathological Grade'] = 2
    data_cleaned.loc[data_cleaned['Pathological Grade'] == 'III', 'Pathological Grade'] = 3
    data_cleaned.loc[data_cleaned['Pathological Grade'] == 'IV', 'Pathological Grade'] = 4

    data_cleaned.loc[(data_cleaned['T-category'] == 'Tx') | (data_cleaned['T-category'] == 'Tis'), 'T-category'] = 1
    data_cleaned.loc[data_cleaned['T-category'] == 'T1', 'T-category'] = 1
    data_cleaned.loc[data_cleaned['T-category'] == 'T2', 'T-category'] = 2
    data_cleaned.loc[data_cleaned['T-category'] == 'T3', 'T-category'] = 3
    data_cleaned.loc[data_cleaned['T-category'] == 'T4', 'T-category'] = 4

    data_cleaned.loc[data_cleaned['N-category'] == 'N0', 'N-category'] = 0
    data_cleaned.loc[data_cleaned['N-category'] == 'N1', 'N-category'] = 1
    data_cleaned.loc[data_cleaned['N-category'] == 'N2', 'N-category'] = 2
    data_cleaned.loc[data_cleaned['N-category'] == 'N3', 'N-category'] = 3

    data_cleaned.loc[data_cleaned['N-category_8th_edition'] == 'N0', 'N-category_8th_edition'] = 0
    data_cleaned.loc[data_cleaned['N-category_8th_edition'] == 'N1', 'N-category_8th_edition'] = 1
    data_cleaned.loc[data_cleaned['N-category_8th_edition'] == 'N2', 'N-category_8th_edition'] = 2
    data_cleaned.loc[data_cleaned['N-category_8th_edition'] == 'N3', 'N-category_8th_edition'] = 3

    data_cleaned.loc[data_cleaned['AJCC 7th edition'] == 'I', 'AJCC 7th edition'] = 1
    data_cleaned.loc[data_cleaned['AJCC 7th edition'] == 'II', 'AJCC 7th edition'] = 2
    data_cleaned.loc[data_cleaned['AJCC 7th edition'] == 'III', 'AJCC 7th edition'] = 3
    data_cleaned.loc[data_cleaned['AJCC 7th edition'] == 'IV', 'AJCC 7th edition'] = 4

    data_cleaned.loc[data_cleaned['AJCC 8th edition'] == 'I', 'AJCC 8th edition'] = 1
    data_cleaned.loc[data_cleaned['AJCC 8th edition'] == 'II', 'AJCC 8th edition'] = 2
    data_cleaned.loc[data_cleaned['AJCC 8th edition'] == 'III', 'AJCC 8th edition'] = 3
    data_cleaned.loc[data_cleaned['AJCC 8th edition'] == 'IV', 'AJCC 8th edition'] = 4

    data_cleaned.loc[data_cleaned['Gender'] == 'Male', 'Gender'] = 1
    data_cleaned.loc[data_cleaned['Gender'] == 'Female', 'Gender'] = 0


    data_cleaned.loc[data_cleaned['HPV/P16 status'] == 'Positive', 'HPV/P16 status'] = 1
    data_cleaned.loc[data_cleaned['HPV/P16 status'] == 'Negative', 'HPV/P16 status'] = -1
    data_cleaned.loc[data_cleaned['HPV/P16 status'] == 'Unknown', 'HPV/P16 status'] = 0

    data_cleaned.loc[data_cleaned[
                         'Smoking status at Diagnosis (Never/Former/Current)'] == 'Formar', 'Smoking status at Diagnosis (Never/Former/Current)'] = .5
    data_cleaned.loc[data_cleaned[
                         'Smoking status at Diagnosis (Never/Former/Current)'] == 'Current', 'Smoking status at Diagnosis (Never/Former/Current)'] = 1
    data_cleaned.loc[data_cleaned[
                         'Smoking status at Diagnosis (Never/Former/Current)'] == 'Never', 'Smoking status at Diagnosis (Never/Former/Current)'] = 0


    data_cleaned.loc[data_cleaned['Chemo Modification (Y/N)'] == 'Y', 'Chemo Modification (Y/N)'] = 1

    data_cleaned.loc[data_cleaned['DLT (Y/N)'] == 'N', 'DLT (Y/N)'] = 0
    data_cleaned.loc[data_cleaned['DLT (Y/N)'] == 'Y', 'DLT (Y/N)'] = 1

    data_cleaned['DLT_Other'] = 0
    for index, row in data_cleaned.iterrows():
        if row['DLT_Type'] == 'None':
            continue
        for i in re.split('&|and|,', row['DLT_Type']):
            if i.strip() != '' and data_cleaned.loc[index, Const.dlt_dict[i.strip()]] == 0:
                data_cleaned.loc[index, Const.dlt_dict[i.strip()]] = 1

    data_cleaned.loc[data_cleaned['Decision 2 (CC / RT alone)'] == 'RT alone', 'Decision 2 (CC / RT alone)'] = 0
    data_cleaned.loc[data_cleaned['Decision 2 (CC / RT alone)'] == 'CC', 'Decision 2 (CC / RT alone)'] = 1

    data_cleaned.loc[data_cleaned['CC modification (Y/N)'] == 'N', 'CC modification (Y/N)'] = 0
    data_cleaned.loc[data_cleaned['CC modification (Y/N)'] == 'Y', 'CC modification (Y/N)'] = 1

    data_cleaned['DLT_Dermatological 2'] = 0
    data_cleaned['DLT_Neurological 2'] = 0
    data_cleaned['DLT_Gastrointestinal 2'] = 0
    data_cleaned['DLT_Hematological 2'] = 0
    data_cleaned['DLT_Nephrological 2'] = 0
    data_cleaned['DLT_Vascular 2'] = 0
    data_cleaned['DLT_Infection (Pneumonia) 2'] = 0
    data_cleaned['DLT_Other 2'] = 0
    for index, row in data_cleaned.iterrows():
        if row['DLT 2'] == 'None':
            continue
        for i in re.split('&|and|,', row['DLT 2']):
            if i.strip() != '':
                data_cleaned.loc[index, Const.dlt_dict[i.strip()] + ' 2'] = 1

    data_cleaned.loc[
        data_cleaned['Decision 3 Neck Dissection (Y/N)'] == 'N', 'Decision 3 Neck Dissection (Y/N)'] = 0
    data_cleaned.loc[
        data_cleaned['Decision 3 Neck Dissection (Y/N)'] == 'Y', 'Decision 3 Neck Dissection (Y/N)'] = 1

    return data_cleaned

def merge_editions(row,basecol='AJCC 8th edition',fallback='AJCC 7th edition'):
    if pd.isnull(row[basecol]):
        return row[fallback]
    return row[basecol]


def preprocess_dt_data(df,extra_to_keep=None):
    
    to_keep = ['id','hpv','age','packs_per_year','smoking_status','gender','Aspiration rate Pre-therapy','total_dose','dose_fraction'] 
    to_onehot = ['T-category','N-category','AJCC','Pathological Grade','subsite','treatment','ln_cluster']
    if extra_to_keep is not None:
        to_keep = to_keep + [c for c in extra_to_keep if c not in to_keep and c not in to_onehot]
    
    df['bilateral'] = df['laterality'].apply(lambda x: x.lower().strip() == 'bilateral')
    to_keep.append('bilateral')
    
    decisions =Const.decisions
    outcomes = Const.outcomes
    
    modification_types = {
        0: 'no_dose_adjustment',
        1: 'dose_modified',
        2: 'dose_delayed',
        3: 'dose_cancelled',
        4: 'dose_delayed_&_modified',
        5: 'regiment_modification',
        9: 'unknown'
    }
    
    cc_types = {
        0: 'cc_none',
        1: 'cc_platinum',
        2: 'cc_cetuximab',
        3: 'cc_others',
    }
    
#     races_shortened = ['White/Caucasian','Hispanic/Latino','African American/Black']
#     for race in races_shortened:
#         df[race] = df['Race'].apply(lambda x: x.strip() == race)
#         to_keep.append(race)
    
    for k,v in Const.cc_types.items():
        df[v] = df['CC Regimen(0= none, 1= platinum based, 2= cetuximab based, 3= others, 9=unknown)'].apply(lambda x: int(Const.cc_types.get(int(x),0) == v))
        to_keep.append(v)
    for k,v in Const.modification_types.items():
        name = 'Modification Type (0= no dose adjustment, 1=dose modified, 2=dose delayed, 3=dose cancelled, 4=dose delayed & modified, 5=regimen modification, 9=unknown)'
        df[v] = df[name].apply(lambda x: int(Const.modification_types.get(int(x),0) == v))
        to_keep.append(v)
    #Features to keep. I think gender is is in 
    
    for col in Const.dlt1 + Const.dlt2:
        df[col] = (df[col] > 0).astype(float)
    keywords = []
    for keyword in keywords:
        toadd = [c for c in df.columns if keyword in c and c not in to_keep]
        to_keep = to_keep + toadd
    
    df['packs_per_year'] = df['packs_per_year'].apply(lambda x: str(x).replace('>','').replace('<','')).astype(float).fillna(0)
    #so I'm actually not sure if this is biological sex or gender given this is texas
    df['AJCC'] = df.apply(lambda row: merge_editions(row,'ajcc8','ajcc7'),axis=1)
    df['N-category'] = df.apply(lambda row: merge_editions(row,'N-category_8th_edition','N-category'),axis=1)
    
    dummy_df = pd.get_dummies(df[to_onehot].fillna(0).astype(str),drop_first=False)
    for col in dummy_df.columns:
        df[col] = dummy_df[col]
        to_keep.append(col)
        
    yn_to_binary = ['FT','Aspiration rate Post-therapy','Decision 1 (Induction Chemo) Y/N']
    for col in yn_to_binary:
        df[col] = df[col].apply(lambda x: int(x == 'Y'))
        
    to_keep = to_keep + [c for c in df.columns if 'DLT' in c]
    
    for statelist in [Const.state2,Const.state3,Const.decisions,Const.outcomes]:
        toadd = [c for c in statelist if c not in to_keep]
        to_keep = to_keep + toadd
    return df[to_keep].set_index('id')



In [4]:
def load_digital_twin(file='../data/digital_twin_data.csv'):
    df = pd.read_csv(file)
    return df.rename(columns = Const.rename_dict)

def get_dt_ids():
    df = load_digital_twin()
    return df.id.values
get_dt_ids()

  df = pd.read_csv(file)


array([    3,     5,     6,     7,     8,     9,    10,    11,    13,
          14,    15,    16,    17,    18,    21,    23,    24,    25,
          26,    27,    28,    31,    32,    33,    35,    36,    37,
          38,    39,    40,    41,    42,    44,    45,    47,    48,
          49,    50,    51,    53,    55,    56,    57,    60,    64,
          65,    67,    68,    69,    71,    74,    75,    77,    78,
          79,    80,    81,    82,    87,    88,    91,    94,    96,
          99,   103,   109,   116,   117,   119,   120,   121,   125,
         133,   148,   150,   153,   168,   178,   181,   183,   184,
         185,   186,   187,   188,   189,   190,   191,   192,   193,
         194,   195,   196,   197,   198,   199,   200,   201,   202,
         203,   204,   205,   206,   207,   208,   209,   210,   211,
         212,   213,   214,   215,   216,   217,   218,   219,   220,
         221,   222,   223,   224,   225,   226,   227,   228,   229,
         230,   231,

In [None]:
for c in pd.read_csv('../data/digital_twin_data.csv').columns:
    if c == 'Tm Laterality (R/L)':
        print(c)


In [None]:
pd.get_dummies(pd.read_csv('../data/digital_twin_data.csv')[Const.outcomes[0]]).sum()

In [79]:
class DTDataset():
    
    def __init__(self,data_file = '../data/digital_twin_data.csv',ln_data_file = '../data/digital_twin_ln_data.csv',ids=None):
        df = pd.read_csv(data_file)
        df = preprocess(df)
        df = df.rename(columns = Const.rename_dict).copy()
        df = df.drop('MRN OPC',axis=1)

        ln_data = pd.read_csv(ln_data_file)
        ln_data = ln_data.rename(columns={'cluster':'ln_cluster'})
        self.ln_cols = [c for c in ln_data.columns if c not in df.columns]
        df = df.merge(ln_data,on='id')
        df.index = df.index.astype(int)
        if ids is not None:
            df = df[df.id.apply(lambda x: x in ids)]
        self.processed_df = preprocess_dt_data(df,self.ln_cols).fillna(0)
        
        self.means = self.processed_df.mean(axis=0)
        self.stds = self.processed_df.std(axis=0)
        self.maxes = self.processed_df.max(axis=0)
        self.mins = self.processed_df.min(axis=0)
        
        arrays = self.get_states()
        self.state_sizes = {k: (v.shape[1] if v.ndim > 1 else 1) for k,v in arrays.items()}
        
    def get_data(self):
        return self.processed_df
    
    def sample(self,frac=.5):
        return self.processed_df.sample(frac=frac)
    
    def split_sample(self,ratio = .3):
        assert(ratio > 0 and ratio <= 1)
        df1 = self.processed_df.sample(frac=1-ratio)
        df2 = self.processed_df.drop(index=df1.index)
        return df1,df2
    
    def get_states(self,fixed=None,ids = None):
        processed_df = self.processed_df.copy()
        if ids is not None:
            processed_df = processed_df.loc[ids]
        if fixed is not None:
            for col,val in fixed.items():
                if col in processed_df.columns:
                    processed_df[col] = val
                else:
                    print('bad fixed entry',col)
                    
        to_skip = ['CC Regimen(0= none, 1= platinum based, 2= cetuximab based, 3= others, 9=unknown)','DLT_Type','DLT 2'] + [c for c in processed_df.columns if 'treatment' in c]
        other_states = set(Const.decisions + Const.state3 + Const.state2 + Const.outcomes  + to_skip)

        base_state = sorted([c for c in processed_df.columns if c not in other_states])

        dlt1 = Const.dlt1
        dlt2 = Const.dlt2
        
        modifications = Const.modifications
        ccs = Const.ccs
        pds = Const.primary_disease_states
        nds = Const.nodal_disease_states
        pds2 = Const.primary_disease_states2
        nds2 = Const.nodal_disease_states2
        outcomes = Const.outcomes
        decisions= Const.decisions
        
        #intermediate states are only udated values. Models should use baseline + state2 etc
        results = {
            'baseline': processed_df[base_state],
            'pd_states1': processed_df[pds],
            'nd_states1': processed_df[nds],
            'modifications': processed_df[modifications],
            'ccs': processed_df[ccs],
            'pd_states2': processed_df[pds2],
            'nd_states2': processed_df[nds2],
            'outcomes': processed_df[outcomes],
            'dlt1': processed_df[dlt1],
            'dlt2': processed_df[dlt2],
            'decision1': processed_df[decisions[0]],
            'decision2': processed_df[decisions[1]],
            'decision3': processed_df[decisions[2]],
            'survival': processed_df[outcomes[0]],
            'ft': processed_df[outcomes[1]],
            'aspiration': processed_df[outcomes[2]],
        }
    
        return results
    
    def get_state(self,name,**kwargs):
        return self.get_states(**kwargs)[name]
    
    def normalize(self,df):
        means = self.means[df.columns]
        std = self.stds[df.columns]
        return ((df - means)/std).fillna(0)
    
    def get_intermediate_outcomes(self,step=1,**kwargs):
        assert(step in [0,1,2,3])
        states = self.get_states(**kwargs)
        if step == 1:
            keys = ['pd_states1','nd_states1','modifications','dlt1']
        if step == 2:
            keys =  ['pd_states2','nd_states2','ccs','dlt2']
        if step == 3:
            keys = ['survival','ft','aspiration']# ['decision1','decision2','decision3']
        if step == 0:
            keys = ['decision1','decision2','decision3']
        if len(keys) < 2:
            return states[keys[0]]
        return [states[key] for key in keys]
    
    def get_input_state(self,step=1,**kwargs):
        assert(step in [0,1,2,3])
        states = self.get_states(**kwargs)
        if step == 0:
            keys = ['baseline']
        if step == 1:
            keys = ['baseline','decision1']
        if step == 2:
            keys =  ['baseline','pd_states1','nd_states1','modifications','dlt1','decision1','decision2']
        if step == 3:
            keys = ['baseline','pd_states2','nd_states2','ccs','dlt2','decision1','decision2','decision3']
        arrays = [states[key] for key in keys]
        if len(arrays) < 2:
            return arrays[0]
        return pd.concat(arrays,axis=1)
    
data = DTDataset()
data.get_input_state(2)

  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


Unnamed: 0_level_0,1A,1A1B,1A6,1B,1B2A,1B3,2A,2A2B,2A3,2B,...,DLT_Other,DLT_Hematological,DLT_Gastrointestinal,DLT_Neurological,DLT_Vascular,DLT_Nephrological,DLT_Dermatological,DLT_Infection (Pneumonia),Decision 1 (Induction Chemo) Y/N,Decision 2 (CC / RT alone)
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
3,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,1
5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,1
6,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,1
7,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0
8,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10201,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,1
10202,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,1
10203,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0
10204,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,1


In [6]:
def df_to_torch(df,ttype  = torch.FloatTensor):
    values = df.values.astype(float)
    values = torch.from_numpy(values)
    return values.type(ttype)

In [188]:
class SimulatorBase(torch.nn.Module):
    
    def __init__(self,
                 input_size,
                 hidden_layers = [1000],
                 dropout = 0.5,
                 input_dropout=0.1,
                 state = 1,
                 eps = 0.01,
                ):
        #predicts disease state (sd, pr, cr) for primar and nodal, then dose modications or cc type (depending on state), and [dlt ratings]
        torch.nn.Module.__init__(self)
        self.state = state
        self.input_dropout = torch.nn.Dropout(input_dropout)
        
        first_layer =torch.nn.Linear(input_size,hidden_layers[0],bias=True)
        layers = [first_layer,torch.nn.ReLU()]
        curr_size = hidden_layers[0]
        for ndim in hidden_layers[1:]:
            layer = torch.nn.Linear(curr_size,ndim)
            curr_size = ndim
            layers.append(layer)
            layers.append(torch.nn.ReLU())
        self.layers = torch.nn.ModuleList(layers)
        self.batchnorm = torch.nn.BatchNorm1d(hidden_layers[-1])
        self.dropout = torch.nn.Dropout(dropout)
        
    
        input_mean = torch.tensor([0])
        input_std = torch.tensor([1])
        self.eps = eps
        self.register_buffer('input_mean', input_mean)
        self.register_buffer('input_std',input_std)
        
        self.sigmoid = torch.nn.Sigmoid()
        self.softmax = torch.nn.LogSoftmax(dim=1)
        self.identifier = 'state'  +str(state) + '_input'+str(input_size) + '_dims' + ','.join([str(h) for h in hidden_layers]) + '_dropout' + str(input_dropout) + ',' + str(dropout)
        
    def normalize(self,x):
        x = (x - self.input_mean + self.eps)/(self.input_std + self.eps)
        return x
    
    def fit_normalizer(self,x):
        input_mean = x.mean(axis=0)
        input_std = x.std(axis=0)
        self.register_buffer('input_mean', input_mean)
        self.register_buffer('input_std',input_std)
        return True
        
class OutcomeSimulator(SimulatorBase):
    
    def __init__(self,
                 input_size,
                 hidden_layers = [500,500],
                 dropout = 0.7,
                 input_dropout=0.1,
                 state = 1,
                ):
        #predicts disease state (sd, pr, cr) for primar and nodal, then dose modications or cc type (depending on state), and [dlt ratings]
        super(OutcomeSimulator,self).__init__(input_size,hidden_layers=hidden_layers,dropout=dropout,input_dropout=input_dropout,state=state)
    
        self.disease_layer = torch.nn.Linear(hidden_layers[-1],len(Const.primary_disease_states))
        self.nodal_disease_layer = torch.nn.Linear(hidden_layers[-1],len(Const.nodal_disease_states))
        #dlt ratings are 0-4 even though they don't always appear
        
        self.dlt_layers = torch.nn.ModuleList([torch.nn.Linear(hidden_layers[-1],1) for i in Const.dlt1])
        assert( state in [1,2])
        if state == 1:
#             self.dlt_layers = torch.nn.ModuleList([torch.nn.Linear(hidden_layers[-1],5) for i in Const.dlt1])
            self.treatment_layer = torch.nn.Linear(hidden_layers[-1],len(Const.modifications))
        else:
            #we only have dlt yes or no for the second state?
#             self.dlt_layers = torch.nn.ModuleList([torch.nn.Linear(hidden_layers[-1],2) for i in Const.dlt2])
            self.treatment_layer = torch.nn.Linear(hidden_layers[-1],len(Const.ccs))

    def forward(self,x):
        x = self.normalize(x)
        x = self.input_dropout(x)
        for layer in self.layers:
            x = layer(x)
#         x = self.batchnorm(x)
        x = self.dropout(x)
        x_pd = self.disease_layer(x)
        x_nd = self.nodal_disease_layer(x)
        x_mod = self.treatment_layer(x)
        x_dlts = [layer(x) for layer in self.dlt_layers]
        
        x_pd = self.softmax(x_pd)
        x_nd = self.softmax(x_nd)
        x_mod = self.softmax(x_mod)
        #dlts are array of nbatch x n_dlts x predictions
#         x_dlts = torch.stack([self.sigmoid(xx) for xx in x_dlts],axis=1)
        x_dlts = torch.cat([self.sigmoid(xx) for xx in x_dlts],axis=1)
        return [x_pd, x_nd, x_mod, x_dlts]
    
class EndpointSimulator(SimulatorBase):
    
    def __init__(self,
                 input_size,
                 hidden_layers = [500],
                 dropout = 0.7,
                 input_dropout=0.1,
                 state = 1,
                ):
        #predicts disease state (sd, pr, cr) for primar and nodal, then dose modications or cc type (depending on state), and [dlt ratings]
        super(EndpointSimulator,self).__init__(input_size,hidden_layers=hidden_layers,dropout=dropout,input_dropout=input_dropout,state=state)
        
        self.outcome_layer = torch.nn.Linear(hidden_layers[-1],len(Const.outcomes))
      
        
    def forward(self,x):
        x = self.normalize(x)
        x = self.input_dropout(x)
        for layer in self.layers:
            x = layer(x)
#         x = self.batchnorm(x)
        x = self.dropout(x)
        x= self.outcome_layer(x)
        x = self.sigmoid(x)
        return x

OutcomeSimulator(3).input_mean

tensor([0])

In [131]:
def nllloss(ytrue,ypred):
    #nll loss with argmax added in
    loss = torch.nn.NLLLoss()
    return loss(ypred,ytrue.argmax(axis=1))

def state_loss(ytrue,ypred,weights=[1,1,1,1]):
    pd_loss = nllloss(ytrue[0],ypred[0])*weights[0]
    nd_loss = nllloss(ytrue[1],ypred[1])*weights[1]
    mod_loss = nllloss(ytrue[2],ypred[2])*weights[2]
    loss = pd_loss + nd_loss + mod_loss
    dlt_true = ytrue[3]
    dlt_pred = ypred[3]
    ndlt = dlt_true.shape[1]
#     nloss = torch.nn.NLLLoss()
    bce = torch.nn.BCELoss()
    for i in range(ndlt):
        dlt_loss = bce(dlt_pred[:,i].view(-1),dlt_true[:,i].view(-1))
        loss += dlt_loss*weights[3]/ndlt
    return loss

def outcome_loss(ytrue,ypred,weights=[1,1,1]):
    loss = 0
    nloss = torch.nn.BCELoss()
    for i in range(len(weights)):
        iloss = nloss(ypred[:,i],ytrue[i])*weights[i]
        loss += iloss
    return loss

from sklearn.metrics import balanced_accuracy_score, roc_auc_score,accuracy_score

def mc_metrics(yt,yp,numpy=False,is_dlt=False):
    if not numpy:
        yt = yt .cpu().detach().numpy()
        yp = yp.cpu().detach().numpy()
    #dlt prediction (binary)
    if is_dlt:
        acc = accuracy_score(yt,yp>.5)
        if yt.sum() > 1:
            auc = roc_auc_score(yt,yp)
        else:
            auc=-1
        error = np.mean((yt-yp)**2)
        return {'accuracy': acc, 'mse': error, 'auc': auc}
    #this is a catch for when I se the dlt prediction format (encoded integer ordinal, predict as a categorical and take the argmax)
    elif yt.ndim > 1:
        try:
            bacc = balanced_accuracy_score(yt.argmax(axis=1),yp.argmax(axis=1))
        except:
            bacc = -1
        try:
            roc_micro = roc_auc_score(yt,yp,average='micro')
        except:
            roc_micro=-1
        try:
            roc_macro = roc_auc_score(yt,yp,average='macro')
        except:
            roc_macro = -1
        return {'accuracy': bacc, 'roc_micro': roc_micro,'roc_macro': roc_macro}
    #outcomes (binary)
    else:
        if yp.ndim > 1:
            yp = yp.argmax(axis=1)
        try:
            bacc = accuracy_score(yt,yp)
        except:
            bacc = -1
        try:
            roc = roc_auc_score(yt,yp)
        except:
            roc = -1
        error = np.mean((yt-yp)**2)
        return {'accuracy': bacc, 'mse': error, 'auc': roc}

def state_metrics(ytrue,ypred,numpy=False):
    pd_metrics = mc_metrics(ytrue[0],ypred[0],numpy=numpy)
    nd_metrics = mc_metrics(ytrue[1],ypred[1],numpy=numpy)
    mod_metrics = mc_metrics(ytrue[1],ypred[1],numpy=numpy)
    
    dlt_metrics = []
    dlt_true = ytrue[3]
    dlt_pred = ypred[3]
    ndlt = dlt_true.shape[1]
    nloss = torch.nn.NLLLoss()
    for i in range(ndlt):
        dm = mc_metrics(dlt_true[:,i],dlt_pred[:,i].view(-1),is_dlt=True)
        dlt_metrics.append(dm)
    dlt_acc =[d['accuracy'] for d in dlt_metrics]
    dlt_error = [d['mse'] for d in dlt_metrics]
    dlt_auc = [d['auc'] for d in dlt_metrics]
    return {'pd': pd_metrics,'nd': nd_metrics,'mod': mod_metrics,'dlts': {'accuracy': dlt_acc,'accuracy_mean': np.mean(dlt_acc),'auc': dlt_auc,'auc_mean': np.mean(dlt_auc)}}
    
def outcome_metrics(ytrue,ypred,numpy=False):
    res = {}
    for i, outcome in enumerate(Const.outcomes):
        metrics = mc_metrics(ytrue[i],ypred[:,i])
        res[outcome] = metrics
    return res


In [135]:
from sklearn.ensemble import RandomForestClassifier

def train_state_rf(model_args={}):
    ids = get_dt_ids()
    
    dataset = DTDataset()

    train_ids = ids[0:int(len(ids)*.7)]
    test_ids = ids[int(len(ids)*.7):]
    
    #most things are multiclass, dlts are several ordinal and outcomes are multiple binary
    xtrain1 = dataset.get_state('baseline',ids=train_ids)
    xtest1 = dataset.get_state('baseline',ids=test_ids)
    
    xtrain2 = dataset.get_input_state(step=2,ids=train_ids)
    xtest2 = dataset.get_input_state(step=2,ids=test_ids)
    
    xtrain3 = dataset.get_input_state(step=3,ids=train_ids)
    xtest3 = dataset.get_input_state(step=3,ids=test_ids)
    
    [pd1_train,nd1_train, mod_train,dlts1_train] = dataset.get_intermediate_outcomes(ids=train_ids)
    [pd2_train,nd2_train, cc_train,dlts2_train] = dataset.get_intermediate_outcomes(step=2,ids=train_ids)
    [pd1_test,nd1_test, mod_test,dlts1_test] = dataset.get_intermediate_outcomes(ids=test_ids)
    [pd2_test,nd2_test, cc_test,dlts2_test] = dataset.get_intermediate_outcomes(step=2,ids=test_ids)
    outcomes_train = dataset.get_state('outcomes',ids=train_ids)
    outcomes_test = dataset.get_state('outcomes',ids=test_ids)
    

    def train_multiclass_rf(xtrain,xtest,ytrain,ytest):
        model = RandomForestClassifier(class_weight='balanced',**model_args).fit(xtrain,ytrain)
        ypred = model.predict(xtest)
        metrics = mc_metrics(ytest.values,ypred,numpy=True)
        return model, metrics
    
    all_metrics = {}
    pd1_model, all_metrics['pd1'] = train_multiclass_rf(xtrain1,xtest1,pd1_train,pd1_test)
    nd1_model, all_metrics['nd1']  = train_multiclass_rf(xtrain1,xtest1,nd1_train,nd1_test)
    mod_model, all_metrics['mod']  = train_multiclass_rf(xtrain1,xtest1,mod_train,mod_test)
    
    return all_metrics

# train_state_rf({'max_depth': 5,'n_estimators': 100})

In [136]:
def train_state(model_args={},
                state=1,
                split=.7,
                lr=.0001,
                epochs=1000,
                patience=10,
                weights=[1,1,1,10],
                save_path='../data/models/',
                resample_training=False,#use bootstraping on training data after splitting
                resample_all = False,# use bootstrapping, then validate with out-of-bag data
                n_validation_trainsteps=2,
                file_suffix=''):
    
    ids = get_dt_ids()
    
    if resample_all:
        train_ids = np.random.choice(ids,len(ids),replace=True)
        test_ids = [i for i in ids if i not in train_ids]
    
    else:
        train_ids = ids[0:int(len(ids)*split)]
        if resample_training:
            train_ids = np.random.choice(train_ids,len(train_ids),replace=True)
        test_ids = ids[int(len(ids)*split):]
    
    dataset = DTDataset()
    
    xtrain = dataset.get_input_state(step=state,ids=train_ids)
    xtest = dataset.get_input_state(step=state,ids=test_ids)
    ytrain = dataset.get_intermediate_outcomes(step=state,ids=train_ids)
    ytest = dataset.get_intermediate_outcomes(step=state,ids=test_ids)
    

    if state < 3:
        model = OutcomeSimulator(xtrain.shape[1],state=state,**model_args)
        lfunc = state_loss
    else:
        model = EndpointSimulator(xtrain.shape[1],**model_args)
        weights = weights[:3]
        lfunc = outcome_loss
        
    hashcode = str(hash(','.join([str(i) for i in train_ids])))
    save_file = save_path + 'model_' + model.identifier + '_split' + str(split) + '_resample' + str(resample_training) +  '_hash' + hashcode + file_suffix + '.tar'
    xtrain = df_to_torch(xtrain)
    xtest = df_to_torch(xtest)
    ytrain = [df_to_torch(t) for t in ytrain]
    ytest= [df_to_torch(t) for t in ytest]
    
    model.fit_normalizer(xtrain)
#     normalize = lambda x: (x - xtrain.mean(axis=0)+.01)/(xtrain.std(axis=0)+.01)
#     unnormalize = lambda x: (x * (xtrain.std(axis=0) +.01)) + xtrain.mean(axis=0) - .01
    
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)
    best_val_loss = 1000000000000000000000000000
    best_loss_metrics = {}
    last_epoch = False
    for epoch in range(epochs):
        
        model.train(True)
        optimizer.zero_grad()
        
        xtrain_sample = xtrain#[torch.randint(len(xtrain),(len(xtrain),) )]
        ypred = model(xtrain_sample)
        loss = lfunc(ytrain,ypred,weights=weights)

        loss.backward()
        optimizer.step()
        print('epoch',epoch,'train loss',loss.item())
        
        model.eval()
        yval = model(xtest)
        val_loss = lfunc(ytest,yval,weights=weights)
        if state < 3:
            val_metrics = state_metrics(ytest,yval)
        else:
            val_metrics = outcome_metrics(ytest,yval)
        if val_loss.item() < best_val_loss:
            best_val_loss = val_loss.item()
            best_loss_metrics = val_metrics
            steps_since_improvement = 0
            torch.save(model.state_dict(),save_file)
        else:
            steps_since_improvement += 1
        print('val loss',val_loss.item())
        print('______________')
        if steps_since_improvement > patience:
            break
    print('best loss',best_val_loss,best_loss_metrics)
    model.load_state_dict(torch.load(save_file))
    
    #train one step on validation data
    for i in range(n_validation_trainsteps):
        model.train()
        yval = model(xtest)
        val_loss = lfunc(ytest,yval,weights=weights)
        val_loss.backward()
        optimizer.step()
        torch.save(model.state_dict(),save_file)
    
    model.eval()
    return model

model = train_state(state=1)
model

  df = pd.read_csv(file)
  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


torch.Size([375, 8])
epoch 0 train loss 10.93934154510498
torch.Size([161, 8])
val loss 10.71257209777832
______________
torch.Size([375, 8])
epoch 1 train loss 10.764174461364746
torch.Size([161, 8])
val loss 10.569899559020996
______________
torch.Size([375, 8])
epoch 2 train loss 10.621047973632812
torch.Size([161, 8])
val loss 10.429350852966309
______________
torch.Size([375, 8])
epoch 3 train loss 10.535173416137695
torch.Size([161, 8])
val loss 10.290160179138184
______________
torch.Size([375, 8])
epoch 4 train loss 10.371516227722168
torch.Size([161, 8])
val loss 10.151975631713867
______________
torch.Size([375, 8])
epoch 5 train loss 10.235965728759766
torch.Size([161, 8])
val loss 10.014631271362305
______________
torch.Size([375, 8])
epoch 6 train loss 10.101807594299316
torch.Size([161, 8])
val loss 9.877754211425781
______________
torch.Size([375, 8])
epoch 7 train loss 9.937782287597656
torch.Size([161, 8])
val loss 9.74122142791748
______________
torch.Size([375, 8])
e

val loss 3.6322174072265625
______________
torch.Size([375, 8])
epoch 68 train loss 3.520026206970215
torch.Size([161, 8])
val loss 3.6033501625061035
______________
torch.Size([375, 8])
epoch 69 train loss 3.4553682804107666
torch.Size([161, 8])
val loss 3.575464963912964
______________
torch.Size([375, 8])
epoch 70 train loss 3.383659839630127
torch.Size([161, 8])
val loss 3.5485994815826416
______________
torch.Size([375, 8])
epoch 71 train loss 3.4115796089172363
torch.Size([161, 8])
val loss 3.5225296020507812
______________
torch.Size([375, 8])
epoch 72 train loss 3.429332971572876
torch.Size([161, 8])
val loss 3.497224807739258
______________
torch.Size([375, 8])
epoch 73 train loss 3.3457467555999756
torch.Size([161, 8])
val loss 3.4727094173431396
______________
torch.Size([375, 8])
epoch 74 train loss 3.3257861137390137
torch.Size([161, 8])
val loss 3.4489595890045166
______________
torch.Size([375, 8])
epoch 75 train loss 3.3371849060058594
torch.Size([161, 8])
val loss 3.42

epoch 136 train loss 2.3230459690093994
torch.Size([161, 8])
val loss 2.5637004375457764
______________
torch.Size([375, 8])
epoch 137 train loss 2.3839914798736572
torch.Size([161, 8])
val loss 2.554749011993408
______________
torch.Size([375, 8])
epoch 138 train loss 2.331793785095215
torch.Size([161, 8])
val loss 2.5457053184509277
______________
torch.Size([375, 8])
epoch 139 train loss 2.291590929031372
torch.Size([161, 8])
val loss 2.536470890045166
______________
torch.Size([375, 8])
epoch 140 train loss 2.3071773052215576
torch.Size([161, 8])
val loss 2.527391195297241
______________
torch.Size([375, 8])
epoch 141 train loss 2.2815301418304443
torch.Size([161, 8])
val loss 2.5184326171875
______________
torch.Size([375, 8])
epoch 142 train loss 2.251798152923584
torch.Size([161, 8])
val loss 2.509364366531372
______________
torch.Size([375, 8])
epoch 143 train loss 2.2468669414520264
torch.Size([161, 8])
val loss 2.500276565551758
______________
torch.Size([375, 8])
epoch 144 t

val loss 2.129121780395508
______________
torch.Size([375, 8])
epoch 206 train loss 1.8608014583587646
torch.Size([161, 8])
val loss 2.1258246898651123
______________
torch.Size([375, 8])
epoch 207 train loss 1.8387558460235596
torch.Size([161, 8])
val loss 2.122523784637451
______________
torch.Size([375, 8])
epoch 208 train loss 1.879799723625183
torch.Size([161, 8])
val loss 2.1191964149475098
______________
torch.Size([375, 8])
epoch 209 train loss 1.8462135791778564
torch.Size([161, 8])
val loss 2.1158227920532227
______________
torch.Size([375, 8])
epoch 210 train loss 1.8634463548660278
torch.Size([161, 8])
val loss 2.1125569343566895
______________
torch.Size([375, 8])
epoch 211 train loss 1.830691933631897
torch.Size([161, 8])
val loss 2.1095147132873535
______________
torch.Size([375, 8])
epoch 212 train loss 1.8402045965194702
torch.Size([161, 8])
val loss 2.106292724609375
______________
torch.Size([375, 8])
epoch 213 train loss 1.7900208234786987
torch.Size([161, 8])
val l

torch.Size([375, 8])
epoch 274 train loss 1.6322078704833984
torch.Size([161, 8])
val loss 1.9784153699874878
______________
torch.Size([375, 8])
epoch 275 train loss 1.619339108467102
torch.Size([161, 8])
val loss 1.9770835638046265
______________
torch.Size([375, 8])
epoch 276 train loss 1.5922596454620361
torch.Size([161, 8])
val loss 1.9759118556976318
______________
torch.Size([375, 8])
epoch 277 train loss 1.557849645614624
torch.Size([161, 8])
val loss 1.9747458696365356
______________
torch.Size([375, 8])
epoch 278 train loss 1.611525058746338
torch.Size([161, 8])
val loss 1.9734736680984497
______________
torch.Size([375, 8])
epoch 279 train loss 1.6725330352783203
torch.Size([161, 8])
val loss 1.9723656177520752
______________
torch.Size([375, 8])
epoch 280 train loss 1.5866079330444336
torch.Size([161, 8])
val loss 1.971360445022583
______________
torch.Size([375, 8])
epoch 281 train loss 1.5687716007232666
torch.Size([161, 8])
val loss 1.9703576564788818
______________
torc

epoch 344 train loss 1.4496715068817139
torch.Size([161, 8])
val loss 1.9293510913848877
______________
torch.Size([375, 8])
epoch 345 train loss 1.4966094493865967
torch.Size([161, 8])
val loss 1.9290759563446045
______________
torch.Size([375, 8])
epoch 346 train loss 1.4613925218582153
torch.Size([161, 8])
val loss 1.9286974668502808
______________
torch.Size([375, 8])
epoch 347 train loss 1.443617343902588
torch.Size([161, 8])
val loss 1.9282366037368774
______________
best loss 1.9277799129486084 {'pd': {'accuracy': 0.4970657276995305, 'roc_micro': 0.6841060174393507, 'roc_macro': 0.6229905883979276}, 'nd': {'accuracy': 0.6244596999745742, 'roc_micro': 0.7341362341362341, 'roc_macro': 0.6942363773142871}, 'mod': {'accuracy': 0.6244596999745742, 'roc_micro': 0.7341362341362341, 'roc_macro': 0.6942363773142871}, 'dlts': {'accuracy': [0.9875776397515528, 0.9503105590062112, 0.9192546583850931, 0.9503105590062112, 0.9875776397515528, 0.9937888198757764, 0.9440993788819876, 0.993788819

OutcomeSimulator(
  (input_dropout): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0): Linear(in_features=63, out_features=500, bias=True)
    (1): ReLU()
    (2): Linear(in_features=500, out_features=500, bias=True)
    (3): ReLU()
  )
  (batchnorm): BatchNorm1d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.7, inplace=False)
  (sigmoid): Sigmoid()
  (softmax): LogSoftmax(dim=1)
  (disease_layer): Linear(in_features=500, out_features=3, bias=True)
  (nodal_disease_layer): Linear(in_features=500, out_features=3, bias=True)
  (dlt_layers): ModuleList(
    (0): Linear(in_features=500, out_features=1, bias=True)
    (1): Linear(in_features=500, out_features=1, bias=True)
    (2): Linear(in_features=500, out_features=1, bias=True)
    (3): Linear(in_features=500, out_features=1, bias=True)
    (4): Linear(in_features=500, out_features=1, bias=True)
    (5): Linear(in_features=500, out_features=1, bias=True)
    (6): Linear(in_feat

In [138]:
model2 = train_state(state=2)
model2

  df = pd.read_csv(file)
  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


epoch 0 train loss 10.824926376342773
val loss 10.559334754943848
______________
epoch 1 train loss 10.633543014526367
val loss 10.402231216430664
______________
epoch 2 train loss 10.499116897583008
val loss 10.251803398132324
______________
epoch 3 train loss 10.349471092224121
val loss 10.10681438446045
______________




epoch 4 train loss 10.246648788452148
val loss 9.966307640075684
______________
epoch 5 train loss 10.081233978271484
val loss 9.829595565795898
______________
epoch 6 train loss 9.986676216125488
val loss 9.696270942687988
______________
epoch 7 train loss 9.84835433959961
val loss 9.565403938293457
______________
epoch 8 train loss 9.70019817352295
val loss 9.436539649963379
______________
epoch 9 train loss 9.594338417053223
val loss 9.309237480163574
______________
epoch 10 train loss 9.44973373413086
val loss 9.183113098144531
______________
epoch 11 train loss 9.327949523925781
val loss 9.057470321655273
______________
epoch 12 train loss 9.184783935546875
val loss 8.932377815246582
______________
epoch 13 train loss 9.080824851989746
val loss 8.807476043701172
______________
epoch 14 train loss 8.970341682434082
val loss 8.682559967041016
______________
epoch 15 train loss 8.832427978515625
val loss 8.557458877563477
______________
epoch 16 train loss 8.705297470092773
val loss 

epoch 109 train loss 3.4936859607696533
val loss 3.344481945037842
______________
epoch 110 train loss 3.476807117462158
val loss 3.3395726680755615
______________
epoch 111 train loss 3.456163167953491
val loss 3.3347506523132324
______________
epoch 112 train loss 3.4975638389587402
val loss 3.3298404216766357
______________
epoch 113 train loss 3.462568759918213
val loss 3.3248934745788574
______________
epoch 114 train loss 3.5250649452209473
val loss 3.3200106620788574
______________
epoch 115 train loss 3.4841322898864746
val loss 3.3150322437286377
______________
epoch 116 train loss 3.5095221996307373
val loss 3.3101730346679688
______________
epoch 117 train loss 3.4353599548339844
val loss 3.3051886558532715
______________
epoch 118 train loss 3.46453595161438
val loss 3.300067663192749
______________
epoch 119 train loss 3.4357690811157227
val loss 3.294727325439453
______________
epoch 120 train loss 3.404020071029663
val loss 3.289228677749634
______________
epoch 121 trai

val loss 2.996961832046509
______________
epoch 210 train loss 2.9641823768615723
val loss 2.9946846961975098
______________
epoch 211 train loss 2.912851095199585
val loss 2.9923439025878906
______________
epoch 212 train loss 2.905364990234375
val loss 2.9899978637695312
______________
epoch 213 train loss 2.90580677986145
val loss 2.987903594970703
______________
epoch 214 train loss 2.8555729389190674
val loss 2.9861764907836914
______________
epoch 215 train loss 2.923715591430664
val loss 2.9847052097320557
______________
epoch 216 train loss 2.924726724624634
val loss 2.983126163482666
______________
epoch 217 train loss 2.8948700428009033
val loss 2.981353759765625
______________
epoch 218 train loss 2.869149923324585
val loss 2.9795684814453125
______________
epoch 219 train loss 2.8636679649353027
val loss 2.9778060913085938
______________
epoch 220 train loss 2.8237929344177246
val loss 2.9756650924682617
______________
epoch 221 train loss 2.86818528175354
val loss 2.973704

val loss 2.8787097930908203
______________
epoch 312 train loss 2.5532350540161133
val loss 2.8780312538146973
______________
epoch 313 train loss 2.530701160430908
val loss 2.8773062229156494
______________
epoch 314 train loss 2.539682149887085
val loss 2.8766777515411377
______________
epoch 315 train loss 2.5304579734802246
val loss 2.8760647773742676
______________
epoch 316 train loss 2.538092613220215
val loss 2.8756935596466064
______________
epoch 317 train loss 2.491151809692383
val loss 2.875434160232544
______________
epoch 318 train loss 2.5252819061279297
val loss 2.8753788471221924
______________
epoch 319 train loss 2.4834141731262207
val loss 2.8752541542053223
______________
epoch 320 train loss 2.5202746391296387
val loss 2.875352382659912
______________
epoch 321 train loss 2.543156385421753
val loss 2.8752477169036865
______________
epoch 322 train loss 2.540236711502075
val loss 2.875518321990967
______________
epoch 323 train loss 2.5715434551239014
val loss 2.87

OutcomeSimulator(
  (input_dropout): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0): Linear(in_features=85, out_features=500, bias=True)
    (1): ReLU()
    (2): Linear(in_features=500, out_features=500, bias=True)
    (3): ReLU()
  )
  (batchnorm): BatchNorm1d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.7, inplace=False)
  (sigmoid): Sigmoid()
  (softmax): LogSoftmax(dim=1)
  (disease_layer): Linear(in_features=500, out_features=3, bias=True)
  (nodal_disease_layer): Linear(in_features=500, out_features=3, bias=True)
  (dlt_layers): ModuleList(
    (0): Linear(in_features=500, out_features=1, bias=True)
    (1): Linear(in_features=500, out_features=1, bias=True)
    (2): Linear(in_features=500, out_features=1, bias=True)
    (3): Linear(in_features=500, out_features=1, bias=True)
    (4): Linear(in_features=500, out_features=1, bias=True)
    (5): Linear(in_features=500, out_features=1, bias=True)
    (6): Linear(in_feat

In [123]:
model3 = train_state(state=3,epochs=10000,lr=.0001)
model3

  df = pd.read_csv(file)
  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more 

epoch 0 train loss 2.218344211578369
val loss 2.210022449493408
______________
epoch 1 train loss 2.183475971221924
val loss 2.1929216384887695
______________
epoch 2 train loss 2.174403667449951
val loss 2.1761274337768555
______________
epoch 3 train loss 2.212148666381836
val loss 2.1594808101654053
______________
epoch 4 train loss 2.1500658988952637
val loss 2.142989158630371
______________
epoch 5 train loss 2.164356231689453
val loss 2.1267170906066895
______________
epoch 6 train loss 2.135077953338623
val loss 2.1106340885162354
______________
epoch 7 train loss 2.10724139213562
val loss 2.0947723388671875
______________
epoch 8 train loss 2.12799072265625
val loss 2.0791072845458984
______________
epoch 9 train loss 2.126326560974121
val loss 2.063631534576416
______________
epoch 10 train loss 2.1051721572875977
val loss 2.048356294631958
______________
epoch 11 train loss 2.090031623840332
val loss 2.0332491397857666
______________
epoch 12 train loss 2.0584282875061035
val

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

epoch 20 train loss 1.982211947441101
val loss 1.9055818319320679
______________
epoch 21 train loss 1.9939837455749512
val loss 1.8923559188842773
______________
epoch 22 train loss 1.9311997890472412
val loss 1.879347801208496
______________
epoch 23 train loss 1.95609712600708
val loss 1.8665647506713867
______________
epoch 24 train loss 1.9366852045059204
val loss 1.853963017463684
______________
epoch 25 train loss 1.9314186573028564
val loss 1.8415237665176392
______________
epoch 26 train loss 1.8774622678756714
val loss 1.829293966293335
______________
epoch 27 train loss 1.932045817375183
val loss 1.817254900932312
______________
epoch 28 train loss 1.9049208164215088
val loss 1.805382490158081
______________
epoch 29 train loss 1.9003539085388184
val loss 1.7936649322509766
______________
epoch 30 train loss 1.8695964813232422
val loss 1.7821364402770996
______________
epoch 31 train loss 1.8689465522766113
val loss 1.7707622051239014
______________
epoch 32 train loss 1.853

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

val loss 1.686305284500122
______________
epoch 40 train loss 1.8151648044586182
val loss 1.6765446662902832
______________
epoch 41 train loss 1.7473175525665283
val loss 1.6669597625732422
______________
epoch 42 train loss 1.793830156326294
val loss 1.6574945449829102
______________
epoch 43 train loss 1.767322063446045
val loss 1.648215413093567
______________
epoch 44 train loss 1.786980152130127
val loss 1.6390645503997803
______________
epoch 45 train loss 1.7632697820663452
val loss 1.6300427913665771
______________
epoch 46 train loss 1.7553386688232422
val loss 1.621190071105957
______________
epoch 47 train loss 1.7306346893310547
val loss 1.6124708652496338
______________
epoch 48 train loss 1.6919059753417969
val loss 1.6039042472839355
______________
epoch 49 train loss 1.7396246194839478
val loss 1.595459222793579
______________
epoch 50 train loss 1.7337429523468018
val loss 1.587141513824463
______________
epoch 51 train loss 1.7292180061340332
val loss 1.5789737701416

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

val loss 1.5188868045806885
______________
epoch 60 train loss 1.683190107345581
val loss 1.511968731880188
______________
epoch 61 train loss 1.6636810302734375
val loss 1.50516939163208
______________
epoch 62 train loss 1.6554410457611084
val loss 1.4984455108642578
______________
epoch 63 train loss 1.6604485511779785
val loss 1.491806149482727
______________
epoch 64 train loss 1.6544256210327148
val loss 1.4852690696716309
______________
epoch 65 train loss 1.6428678035736084
val loss 1.4788458347320557
______________
epoch 66 train loss 1.618099331855774
val loss 1.4724886417388916
______________
epoch 67 train loss 1.6621730327606201
val loss 1.4662233591079712
______________
epoch 68 train loss 1.6350445747375488
val loss 1.4600450992584229
______________
epoch 69 train loss 1.627980351448059
val loss 1.4539787769317627
______________
epoch 70 train loss 1.6239428520202637
val loss 1.4480061531066895
______________
epoch 71 train loss 1.6380044221878052
val loss 1.442116260528

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

val loss 1.3982510566711426
______________
epoch 80 train loss 1.565911054611206
val loss 1.3931598663330078
______________
epoch 81 train loss 1.5879796743392944
val loss 1.3881311416625977
______________
epoch 82 train loss 1.547845482826233
val loss 1.3832142353057861
______________
epoch 83 train loss 1.5594079494476318
val loss 1.37835693359375
______________
epoch 84 train loss 1.5832908153533936
val loss 1.3736053705215454
______________
epoch 85 train loss 1.5628514289855957
val loss 1.3689500093460083
______________
epoch 86 train loss 1.559173583984375
val loss 1.3643817901611328
______________
epoch 87 train loss 1.5498645305633545
val loss 1.3598824739456177
______________
epoch 88 train loss 1.5539543628692627
val loss 1.3554925918579102
______________
epoch 89 train loss 1.5537748336791992
val loss 1.35115647315979
______________
epoch 90 train loss 1.5438522100448608
val loss 1.3468739986419678
______________
epoch 91 train loss 1.5483629703521729
val loss 1.342680931091

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

val loss 1.3081632852554321
______________
epoch 101 train loss 1.5068765878677368
val loss 1.3046435117721558
______________
epoch 102 train loss 1.469712257385254
val loss 1.301161289215088
______________
epoch 103 train loss 1.5116785764694214
val loss 1.2977745532989502
______________
epoch 104 train loss 1.4965943098068237
val loss 1.2944434881210327
______________
epoch 105 train loss 1.5095473527908325
val loss 1.2912112474441528
______________
epoch 106 train loss 1.4951059818267822
val loss 1.2880228757858276
______________
epoch 107 train loss 1.4716500043869019
val loss 1.2849090099334717
______________
epoch 108 train loss 1.4835295677185059
val loss 1.2818323373794556
______________
epoch 109 train loss 1.457260012626648
val loss 1.2787821292877197
______________
epoch 110 train loss 1.4685068130493164
val loss 1.2757682800292969
______________
epoch 111 train loss 1.4879016876220703
val loss 1.2728086709976196
______________
epoch 112 train loss 1.4752800464630127
val los

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

val loss 1.247314214706421
______________
epoch 121 train loss 1.4572542905807495
val loss 1.244698166847229
______________
epoch 122 train loss 1.4695650339126587
val loss 1.2421256303787231
______________
epoch 123 train loss 1.4493823051452637
val loss 1.2395895719528198
______________
epoch 124 train loss 1.434012770652771
val loss 1.237097144126892
______________
epoch 125 train loss 1.4258068799972534
val loss 1.2346456050872803
______________
epoch 126 train loss 1.451338529586792
val loss 1.23225736618042
______________
epoch 127 train loss 1.4521903991699219
val loss 1.2299144268035889
______________
epoch 128 train loss 1.438658356666565
val loss 1.2276263236999512
______________
epoch 129 train loss 1.4573379755020142
val loss 1.2253661155700684
______________
epoch 130 train loss 1.4260454177856445
val loss 1.2231390476226807
______________
epoch 131 train loss 1.4298896789550781
val loss 1.2209243774414062
______________
epoch 132 train loss 1.4270350933074951
val loss 1.2

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

val loss 1.2019667625427246
______________
epoch 141 train loss 1.4423015117645264
val loss 1.1999714374542236
______________
epoch 142 train loss 1.3891801834106445
val loss 1.198035478591919
______________
epoch 143 train loss 1.3882923126220703
val loss 1.196122407913208
______________
epoch 144 train loss 1.4034450054168701
val loss 1.194230079650879
______________
epoch 145 train loss 1.3888261318206787
val loss 1.1923631429672241
______________
epoch 146 train loss 1.4051902294158936
val loss 1.190535068511963
______________
epoch 147 train loss 1.38688063621521
val loss 1.1887164115905762
______________
epoch 148 train loss 1.3679534196853638
val loss 1.186880350112915
______________
epoch 149 train loss 1.408483862876892
val loss 1.1850736141204834
______________
epoch 150 train loss 1.3961130380630493
val loss 1.1833066940307617
______________
epoch 151 train loss 1.4076420068740845
val loss 1.1815506219863892
______________
epoch 152 train loss 1.3802143335342407
val loss 1.1

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

val loss 1.1669617891311646
______________
epoch 161 train loss 1.3371922969818115
val loss 1.165500521659851
______________
epoch 162 train loss 1.3908711671829224
val loss 1.1640815734863281
______________
epoch 163 train loss 1.3625434637069702
val loss 1.1627147197723389
______________
epoch 164 train loss 1.3765066862106323
val loss 1.161353349685669
______________
epoch 165 train loss 1.350749135017395
val loss 1.1599504947662354
______________
epoch 166 train loss 1.3836688995361328
val loss 1.1585791110992432
______________
epoch 167 train loss 1.3480528593063354
val loss 1.1572372913360596
______________
epoch 168 train loss 1.353005290031433
val loss 1.1559009552001953
______________
epoch 169 train loss 1.3375009298324585
val loss 1.1545751094818115
______________
epoch 170 train loss 1.3495255708694458
val loss 1.1532577276229858
______________
epoch 171 train loss 1.3317406177520752
val loss 1.1519722938537598
______________
epoch 172 train loss 1.342077374458313
val loss 

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

val loss 1.1404882669448853
______________
epoch 182 train loss 1.31235933303833
val loss 1.1394129991531372
______________
epoch 183 train loss 1.3391340970993042
val loss 1.1383341550827026
______________
epoch 184 train loss 1.3254098892211914
val loss 1.1372774839401245
______________
epoch 185 train loss 1.29573392868042
val loss 1.136276364326477
______________
epoch 186 train loss 1.3316372632980347
val loss 1.1353192329406738
______________
epoch 187 train loss 1.3529144525527954
val loss 1.134350299835205
______________
epoch 188 train loss 1.3449221849441528
val loss 1.1333909034729004
______________
epoch 189 train loss 1.316981315612793
val loss 1.1324089765548706
______________
epoch 190 train loss 1.303309679031372
val loss 1.131460428237915
______________
epoch 191 train loss 1.316555380821228
val loss 1.1304982900619507
______________
epoch 192 train loss 1.3382809162139893
val loss 1.1295318603515625
______________
epoch 193 train loss 1.316618800163269
val loss 1.1285

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

val loss 1.1206450462341309
______________
epoch 203 train loss 1.282752513885498
val loss 1.1198129653930664
______________
epoch 204 train loss 1.291266918182373
val loss 1.1190046072006226
______________
epoch 205 train loss 1.2817109823226929
val loss 1.1182512044906616
______________
epoch 206 train loss 1.2926836013793945
val loss 1.1175220012664795
______________
epoch 207 train loss 1.2778239250183105
val loss 1.1168122291564941
______________
epoch 208 train loss 1.291414737701416
val loss 1.1161203384399414
______________
epoch 209 train loss 1.3043330907821655
val loss 1.1154415607452393
______________
epoch 210 train loss 1.2769412994384766
val loss 1.1147615909576416
______________
epoch 211 train loss 1.2899832725524902
val loss 1.1141400337219238
______________
epoch 212 train loss 1.2859195470809937
val loss 1.11353600025177
______________
epoch 213 train loss 1.2754340171813965
val loss 1.1129651069641113
______________
epoch 214 train loss 1.2710018157958984
val loss 

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

val loss 1.1082181930541992
______________
epoch 223 train loss 1.2811672687530518
val loss 1.1076949834823608
______________
epoch 224 train loss 1.272822380065918
val loss 1.107169270515442
______________
epoch 225 train loss 1.2738507986068726
val loss 1.1066490411758423
______________
epoch 226 train loss 1.2930551767349243
val loss 1.1061451435089111
______________
epoch 227 train loss 1.2869210243225098
val loss 1.1056509017944336
______________
epoch 228 train loss 1.2624739408493042
val loss 1.1051534414291382
______________
epoch 229 train loss 1.26705801486969
val loss 1.1045973300933838
______________
epoch 230 train loss 1.257765293121338
val loss 1.1040544509887695
______________
epoch 231 train loss 1.255210041999817
val loss 1.1035206317901611
______________
epoch 232 train loss 1.265093207359314
val loss 1.102988362312317
______________
epoch 233 train loss 1.2702738046646118
val loss 1.1024585962295532
______________
epoch 234 train loss 1.2535185813903809
val loss 1.1

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

epoch 242 train loss 1.2606756687164307
val loss 1.0977232456207275
______________
epoch 243 train loss 1.2781271934509277
val loss 1.0971765518188477
______________
epoch 244 train loss 1.261707067489624
val loss 1.0966607332229614
______________
epoch 245 train loss 1.2515524625778198
val loss 1.0961501598358154
______________
epoch 246 train loss 1.2474435567855835
val loss 1.0956432819366455
______________
epoch 247 train loss 1.2588074207305908
val loss 1.0951437950134277
______________
epoch 248 train loss 1.2768776416778564
val loss 1.0946242809295654
______________
epoch 249 train loss 1.252355933189392
val loss 1.0941283702850342
______________
epoch 250 train loss 1.2388842105865479
val loss 1.0935862064361572
______________
epoch 251 train loss 1.2153706550598145
val loss 1.0930335521697998
______________
epoch 252 train loss 1.2450180053710938
val loss 1.0925260782241821
______________
epoch 253 train loss 1.2508444786071777
val loss 1.092094898223877
______________
epoch 2

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

epoch 261 train loss 1.2352464199066162
val loss 1.088696837425232
______________
epoch 262 train loss 1.2203960418701172
val loss 1.0883067846298218
______________
epoch 263 train loss 1.2256852388381958
val loss 1.087909460067749
______________
epoch 264 train loss 1.2275879383087158
val loss 1.0875473022460938
______________
epoch 265 train loss 1.226521611213684
val loss 1.0871723890304565
______________
epoch 266 train loss 1.2012603282928467
val loss 1.0867741107940674
______________
epoch 267 train loss 1.228930950164795
val loss 1.0863996744155884
______________
epoch 268 train loss 1.2483384609222412
val loss 1.0860430002212524
______________
epoch 269 train loss 1.2169607877731323
val loss 1.0857023000717163
______________
epoch 270 train loss 1.200575351715088
val loss 1.0853512287139893
______________
epoch 271 train loss 1.2149012088775635
val loss 1.084983468055725
______________
epoch 272 train loss 1.210311770439148
val loss 1.0845921039581299
______________
epoch 273 t

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

epoch 280 train loss 1.2028381824493408
val loss 1.0815625190734863
______________
epoch 281 train loss 1.2401208877563477
val loss 1.081196904182434
______________
epoch 282 train loss 1.2127610445022583
val loss 1.0808019638061523
______________
epoch 283 train loss 1.219635248184204
val loss 1.080428123474121
______________
epoch 284 train loss 1.2420127391815186
val loss 1.0800549983978271
______________
epoch 285 train loss 1.2384637594223022
val loss 1.0797175168991089
______________
epoch 286 train loss 1.1976566314697266
val loss 1.079380750656128
______________
epoch 287 train loss 1.219142198562622
val loss 1.0790400505065918
______________
epoch 288 train loss 1.203047513961792
val loss 1.0787090063095093
______________
epoch 289 train loss 1.2064025402069092
val loss 1.0783840417861938
______________
epoch 290 train loss 1.2456648349761963
val loss 1.0780367851257324
______________
epoch 291 train loss 1.189598798751831
val loss 1.0777308940887451
______________
epoch 292 t

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

epoch 299 train loss 1.2157334089279175
val loss 1.0756425857543945
______________
epoch 300 train loss 1.1812562942504883
val loss 1.0754690170288086
______________
epoch 301 train loss 1.1996971368789673
val loss 1.0752792358398438
______________
epoch 302 train loss 1.178942084312439
val loss 1.0750913619995117
______________
epoch 303 train loss 1.17598295211792
val loss 1.0749661922454834
______________
epoch 304 train loss 1.1904927492141724
val loss 1.0748088359832764
______________
epoch 305 train loss 1.1838793754577637
val loss 1.0746402740478516
______________
epoch 306 train loss 1.1704142093658447
val loss 1.0744588375091553
______________
epoch 307 train loss 1.1980042457580566
val loss 1.0742721557617188
______________
epoch 308 train loss 1.1998405456542969
val loss 1.0741138458251953
______________
epoch 309 train loss 1.1699802875518799
val loss 1.073923110961914
______________
epoch 310 train loss 1.18284010887146
val loss 1.07371985912323
______________
epoch 311 tr

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

epoch 318 train loss 1.1780306100845337
val loss 1.0721580982208252
______________
epoch 319 train loss 1.1766102313995361
val loss 1.0720081329345703
______________
epoch 320 train loss 1.1786820888519287
val loss 1.0718505382537842
______________
epoch 321 train loss 1.1678012609481812
val loss 1.0717049837112427
______________
epoch 322 train loss 1.17955482006073
val loss 1.071592092514038
______________
epoch 323 train loss 1.1680147647857666
val loss 1.0714830160140991
______________
epoch 324 train loss 1.1843730211257935
val loss 1.0713471174240112
______________
epoch 325 train loss 1.1951937675476074
val loss 1.0712192058563232
______________
epoch 326 train loss 1.1682565212249756
val loss 1.0711400508880615
______________
epoch 327 train loss 1.172165870666504
val loss 1.0710172653198242
______________
epoch 328 train loss 1.1589001417160034
val loss 1.0708848237991333
______________
epoch 329 train loss 1.1840319633483887
val loss 1.0707120895385742
______________
epoch 33

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

val loss 1.0688207149505615
______________
epoch 339 train loss 1.141921043395996
val loss 1.0685540437698364
______________
epoch 340 train loss 1.1381779909133911
val loss 1.0682426691055298
______________
epoch 341 train loss 1.1883059740066528
val loss 1.0679349899291992
______________
epoch 342 train loss 1.1454377174377441
val loss 1.067609190940857
______________
epoch 343 train loss 1.1651747226715088
val loss 1.0673274993896484
______________
epoch 344 train loss 1.1548490524291992
val loss 1.067038655281067
______________
epoch 345 train loss 1.1496789455413818
val loss 1.0667798519134521
______________
epoch 346 train loss 1.164698600769043
val loss 1.0665853023529053
______________
epoch 347 train loss 1.1641991138458252
val loss 1.0663594007492065
______________
epoch 348 train loss 1.1340746879577637
val loss 1.0661532878875732
______________
epoch 349 train loss 1.1485326290130615
val loss 1.065993309020996
______________
epoch 350 train loss 1.1509534120559692
val loss 

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

val loss 1.0646626949310303
______________
epoch 358 train loss 1.1570991277694702
val loss 1.0644718408584595
______________
epoch 359 train loss 1.1702977418899536
val loss 1.0642892122268677
______________
epoch 360 train loss 1.1568173170089722
val loss 1.0640987157821655
______________
epoch 361 train loss 1.1435490846633911
val loss 1.0639172792434692
______________
epoch 362 train loss 1.1513476371765137
val loss 1.0637500286102295
______________
epoch 363 train loss 1.141204595565796
val loss 1.0636340379714966
______________
epoch 364 train loss 1.1997660398483276
val loss 1.0635219812393188
______________
epoch 365 train loss 1.1538372039794922
val loss 1.0634214878082275
______________
epoch 366 train loss 1.1287474632263184
val loss 1.0632375478744507
______________
epoch 367 train loss 1.1814581155776978
val loss 1.0630733966827393
______________
epoch 368 train loss 1.1477134227752686
val loss 1.0629088878631592
______________
epoch 369 train loss 1.1384459733963013
val l

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

epoch 376 train loss 1.1353230476379395
val loss 1.0620818138122559
______________
epoch 377 train loss 1.1259982585906982
val loss 1.0620172023773193
______________
epoch 378 train loss 1.1593427658081055
val loss 1.061974287033081
______________
epoch 379 train loss 1.1329973936080933
val loss 1.0619146823883057
______________
epoch 380 train loss 1.1603161096572876
val loss 1.0618524551391602
______________
epoch 381 train loss 1.136909008026123
val loss 1.0617663860321045
______________
epoch 382 train loss 1.1118218898773193
val loss 1.0616968870162964
______________
epoch 383 train loss 1.151719331741333
val loss 1.0616353750228882
______________
epoch 384 train loss 1.1403226852416992
val loss 1.061580777168274
______________
epoch 385 train loss 1.1379244327545166
val loss 1.061497688293457
______________
epoch 386 train loss 1.165289282798767
val loss 1.0613808631896973
______________
epoch 387 train loss 1.1098217964172363
val loss 1.0612387657165527
______________
epoch 388 

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

val loss 1.060354232788086
______________
epoch 396 train loss 1.114006519317627
val loss 1.0603084564208984
______________
epoch 397 train loss 1.1458704471588135
val loss 1.0602431297302246
______________
epoch 398 train loss 1.1223887205123901
val loss 1.0601961612701416
______________
epoch 399 train loss 1.106978416442871
val loss 1.0601364374160767
______________
epoch 400 train loss 1.116714358329773
val loss 1.0601236820220947
______________
epoch 401 train loss 1.1279525756835938
val loss 1.0600742101669312
______________
epoch 402 train loss 1.1079354286193848
val loss 1.060011386871338
______________
epoch 403 train loss 1.1436927318572998
val loss 1.059918761253357
______________
epoch 404 train loss 1.1192282438278198
val loss 1.0599033832550049
______________
epoch 405 train loss 1.1220163106918335
val loss 1.0598728656768799
______________
epoch 406 train loss 1.1162127256393433
val loss 1.0598042011260986
______________
epoch 407 train loss 1.1303926706314087
val loss 1

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


EndpointSimulator(
  (input_dropout): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0): Linear(in_features=83, out_features=500, bias=True)
    (1): ReLU()
  )
  (batchnorm): BatchNorm1d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.7, inplace=False)
  (sigmoid): Sigmoid()
  (softmax): LogSoftmax(dim=1)
  (outcome_layer): Linear(in_features=500, out_features=3, bias=True)
)

In [210]:
class DecisionModel(SimulatorBase):
    
    def __init__(self,
                 baseline_input_size,#number of baseline features used
                 hidden_layers = [1000],
                 dropout = 0.5,
                 input_dropout=0.1,
                 state = 1,
                 eps = 0.01,
                 ):
        #input will be all states up until treatment 3
        input_size = baseline_input_size  + len(Const.dlt1) + len(Const.primary_disease_states)  + len(Const.nodal_disease_states)  + len(Const.ccs)  + len(Const.modifications) + 2
            
        super(DecisionModel,self).__init__(input_size,hidden_layers=hidden_layers,dropout=dropout,input_dropout=input_dropout,eps=eps,state='decisions')
        self.final_layer = torch.nn.Linear(hidden_layers[-1],len(Const.decisions))

#         self.final_layer = torch.nn.Linear(hidden_layers[-1],1)
        self.sigmoid = torch.nn.Sigmoid()
        
    def add_position_token(self,x,position):
        #add 2 binary variables for if the state has already passed
        if position == 0:
            token = torch.zeros((x.shape[0],2))
            x = torch.cat([x,token],dim=1)
        if position == 1:
            token1 = torch.ones((x.shape[0],1))
            token2 = torch.zeros((x.shape[0],1))
            x = torch.cat([x,token1,token2],dim=1)
        if position == 2:
            token1 = torch.zeros((x.shape[0],1))
            token2 = torch.ones((x.shape[0],1))
            x = torch.cat([x,token1,token2],dim=1)
        if position == 3:
            token1 = torch.ones((x.shape[0],1))
            token2 = torch.ones((x.shape[0],1))
            x = torch.cat([x,token1,token2],dim=1)
        return x
        
    def forward(self,x,position=0):
        #position is 0-2
        [xbase, xdlt, xpd, xnd, xcc,xmod] = x
        xbase = self.normalize(xbase)
        x = torch.cat([xbase,xdlt,xpd,xnd,xcc,xmod],dim=1)
        x = self.input_dropout(x)
        x = self.add_position_token(x,position)
#         print(x)
        for layer in self.layers:
            x = layer(x)
        x = self.dropout(x)
        x = self.final_layer(x)
        x = self.sigmoid(x)
        return x
test = DecisionModel(3)
test.identifier

'statedecisions_input30_dims1000_dropout0.1,0.5'

In [217]:
def train_decision_model(
    tmodel1,
    tmodel2,
    tmodel3,
    lr=.0001,
    epochs=100,
    patience=10,
    weights=[1,1,1], #realtive weight of survival, feeding tube, and aspiration
    imitation_weight=1,
    reward_weight=10,
    split=.8,
    resample_all=False,
    resample_training=False,
    save_path='../data/models/',
    file_suffix='',
):
    ids = get_dt_ids()
    tmodel1.eval()
    tmodel2.eval()
    tmodel3.eval()
    if resample_all:
        train_ids = np.random.choice(ids,len(ids),replace=True)
        test_ids = [i for i in ids if i not in train_ids]
    
    else:
        train_ids = ids[0:int(len(ids)*split)]
        if resample_training:
            train_ids = np.random.choice(train_ids,len(train_ids),replace=True)
        test_ids = ids[int(len(ids)*split):]
    
    dataset = DTDataset()
    
    data = dataset.processed_df.copy()
    
    def get_dlt(state):
        if state == 2:
            return data[Const.dlt2].copy()
        d = data[Const.dlt1].copy()
        if state < 1:
            d.values[:,:] = 0
        return d
    
    def get_pd(state):
        if state == 2:
            return data[Const.primary_disease_states2].copy()
        d = data[Const.primary_disease_states].copy()
        if state < 1:
            d.values[:,:] = 0
        return d
    
    def get_nd(state):
        if state == 2:
            return data[Const.nodal_disease_states2].copy()
        d = data[Const.nodal_disease_states].copy()
        if state < 1:
            d.values[:,:] = 0
        return d
    
    def get_cc(state):
        res = data[Const.ccs].copy()
        if state == 1:
            res.values[:,:] = np.zeros(res.values.shape)
        return res
    
    def get_mod(state):
        res = data[Const.modifications].copy()
        return res
    
#     impact_of_decisions = check_impact_of_decisions(all_models,dataset)
    
    outcomedf = data[Const.outcomes]
    baseline = dataset.get_state('baseline')
    
    def formatdf(d):
        d = df_to_torch(d.loc[train_ids])
        return d
    
    def makegrad(v):
        if not v.requires_grad:
            v.requires_grad=True
        return v
    
    model = DecisionModel(baseline.shape[1])

    hashcode = str(hash(str(lr) + str(epochs) + str(patience) + str(weights[1]) + str(weights[-1])))
    
    save_file = save_path + 'model_' + model.identifier +'_hash' + hashcode + file_suffix + '.tar'
    model.fit_normalizer(df_to_torch(baseline.loc[train_ids]))
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)
    
    best_val_loss = 1000000000000000000000000000
    best_loss_metrics = {}
    last_epoch = False
    
    def outcome_loss(ypred):
        #convert survival to death
        loss = torch.mul(torch.mean(-1*(ypred[:,0] - 1)),weights[0])
        for i,weight in enumerate(weights[1:]):
            newloss = torch.mean(ypred[:,i])*weight
            loss = torch.add(loss,torch.mul(newloss,weight))
        return loss
    
    mse = torch.nn.MSELoss()
    nllloss = torch.nn.NLLLoss()
    bce = torch.nn.BCELoss()
    
    def compare_decisions(d1,d2,d3,ids):
#         ypred = np.concatenate([dd.cpu().detach().numpy().reshape(-1,1) for dd in [d1,d2,d3]],axis=1)
        ytrue = df_to_torch(outcomedf.loc[ids])
        dloss = bce(d1.view(-1),ytrue[:,0])
        dloss += bce(d2.view(-1),ytrue[:,1])
        dloss += bce(d3,view(-1),ytrue[:,2])
        return dloss
        
    def remove_decisions(df):
        cols = [c for c in df.columns if c not in Const.decisions ]
        ddf = df[cols]
        return ddf
    
    ytrain = df_to_torch(outcomedf.loc[train_ids])
    ytest = df_to_torch(outcomedf.loc[test_ids])
    
    makeinput = lambda step: df_to_torch(remove_decisions(dataset.get_input_state(step=step,ids=train_ids)))
    xx1 = makeinput(1)
    xx2 = makeinput(2)
    xx3 = makeinput(3)
    baseline_train = formatdf(baseline)
    xxtrain = [baseline, get_dlt(0),get_pd(0),get_nd(0),get_cc(0),get_mod(0)]
    xxtrain = [formatdf(xx) for xx in xxtrain]
    for epoch in range(epochs):
        
        model.train(True)
        optimizer.zero_grad()
        
        state = 0

        decision1 = model(xxtrain,0)[:,0]
        imitation_loss1 = bce(decision1,ytrain[:,0])

        xi1 = torch.cat([xx1,decision1.view(-1,1)],axis=1)
        [ypd1, ynd1, ymod, ydlt1] = tmodel1(xi1)
        x1 = [baseline_train,ydlt1,ypd1,ynd1,formatdf(get_c,ymod]
            
        decision2 = model(x1,1)[:,1]
        imitation_loss2 =  bce(decision2,ytrain[:,1])

        xi2 = torch.cat([xx2,decision1.view(-1,1),decision2.view(-1,1)],axis=1)
        [ypd2,ynd2,ycc,ydlt2] = tmodel2(xi2)
        x2 = [baseline_train,ydlt2,ypd2,ynd2,ycc,ymod]
            
        decision3 = model(x2,2)[:,2]
        imitation_loss3 = bce(decision3,ytrain[:,2])
        
        xi3 = torch.cat([xx3,decision1.view(-1,1),decision2.view(-1,1),decision3.view(-1,1)],axis=1)
        outcomes = tmodel3(xi3)

        reward_loss = outcome_loss(outcomes)
        loss = torch.add(imitation_loss1,imitation_loss2)
        loss = torch.add(loss,imitation_loss3)
        loss = torch.mul(loss,imitation_weight/3)
        loss = torch.add(loss,torch.mul(reward_loss,reward_weight))
        loss.backward()
        print(loss.item(),imitation_loss1.item(),imitation_loss2.item(),imitation_loss3.item(),reward_loss.item())
        optimizer.step()
    model.eval()
    return model

train_decision_model(model,model2,model3,imitation_weight=0.1,reward_weight=10)

  df = pd.read_csv(file)
  df = pd.read_csv(data_file)


torch.Size([428, 8])


  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


NameError: name 'x1' is not defined

In [None]:
from captum.attr import IntegratedGradients

key ='FT_state3'
ig = IntegratedGradients(all_models[key])
dataset = DTDataset()
xtestdf = dataset.get_input_state(step=int(key[-1]))
xtest = df_to_torch(xtestdf)
all_models[key](xtest).shape
# attributions = ig.attribute(xtest,torch.zeros(xtest.shape),target=1)
# test = pd.DataFrame(attributions,columns = xtestdf.columns, index=xtestdf.index)
# test.describe().T

In [None]:
def breakup_state_models(state_model):
    #for state 1 and 2
    models = {}
    models['pd'] = lambda x: state_model(x)[0]
    models['nd'] = lambda x: state_model(x)[1]
    models['chemo'] = lambda x: state_model(x)[2]
    for i,dlt in enumerate(Const.dlt1):
        models[dlt] = lambda x: state_model(x)[3][:,i]
    return models

def breakup_outcome_models(omodel):
    models = {}
    for i,name in enumerate(Const.outcomes):
        models[name] = lambda x: omodel(x)[:,i].reshape(-1,1)
    return models

def get_all_models(m1,m2,m3):
    state1_models = breakup_state_models(m1)
    state2_models = breakup_state_models(m2)
    state3_models = breakup_outcome_models(m3)
    all_models = {}
    for i,sm in enumerate([state1_models,state2_models,state3_models]):
        for ii,m in sm.items():
            all_models[ii +  '_state' + str(i+1)] = m
    return all_models

all_models = get_all_models(model,model2,model3)
all_models

In [None]:

def get_ytrue(name,df):
    outcomes=None
    value = None
    if name == 'pd_state1':
        outcomes = df[Const.primary_disease_states]
    elif name == 'pd_state2':
        outcomes = df[Const.primary_disease_states2]
    elif name == 'nd_state1':
        outcomes = df[Const.nodal_disease_states]
    elif name == 'nd_state2':
        outcomes = df[Const.nodal_disease_states2]
    elif name == 'chemo_state1':
        outcomes = df[Const.modifications]
    elif name == 'chemo_state2':
        outcomes = df[Const.ccs]   
    if outcomes is not None:
        value = outcomes.idxmax(axis=1)
    if 'DLT' in name:
        newname = name.replace('_state', ' ').replace('1','').strip()
        value = df[newname]
    if name.replace('_state3','') in Const.outcomes:
        value = df[name.replace('_state3','')]
    if value is None:
        print(name,df.columns)
    return value

def check_impact_of_decisions(model_dict,data):
    results = []
    #todo: this is wrong fix it
    ids = []
    df = data.get_data()
    outcomedict = {step: pd.concat(data.get_intermediate_outcomes(step=step),axis=1) for step in [1,2,3]}
    for decision in Const.decisions:
        for name, model in model_dict.items():
            step = int(name[-1])
            subset0 = dataset.get_input_state(step=step,fixed={decision: 0})
            subset1 = dataset.get_input_state(step=step,fixed={decision: 1})
            outcomes = outcomedict[step]
            ids = subset0.index.values
            x0 = df_to_torch(subset0)
            x1 = df_to_torch(subset1)
            y0 = model(x0).detach().cpu().numpy()
            y1 = model(x1).detach().cpu().numpy()
            original = data.get_input_state(step=step)
            xx = df_to_torch(original)
            yy = model(xx).detach().cpu().numpy()
            ytrue = get_ytrue(name,outcomes)
            if "DLT" in name:
                y0 = y0.argmax(axis=1).reshape(-1,1)
                y1 = y1.argmax(axis=1).reshape(-1,1)
                yy = yy.argmax(axis=1).reshape(-1,1)
                change = y0 - y1
                decision_change = (y0 != y1).astype(int)
            elif y0.shape[1] == 1:
                change = y1 - y0
                decision_change = np.abs((y0 > .5).astype(int) - (y1 > .5).astype(int))
            else:
                index = np.unravel_index(np.argmax(yy, axis=1), yy.shape)
                change = (y0[index] - y1[index]).reshape(-1,1)
                decision_change =  (y0.argmax(axis=1).reshape(-1,1) != y1.argmax(axis=1).reshape(-1,1)).astype(int)
                yy = yy.argmax(axis=1).reshape(-1,1)
                y1 = y1.argmax(axis=1).reshape(-1,1)
                y0 = y0.argmax(axis=1).reshape(-1,1)
            outcome = name.replace('_state','')
            for ii,pid in enumerate(ids):
                oo = ytrue.loc[pid]
                onew = y0[ii][0]
                original_decision = df.loc[pid,decision]
                if original_decision > 0:
                    onew = y0[ii][0]
                oname = Const.name_dict.get(name)
                if oname is not None:
                    onew = oname[onew]
                entry = {'id': pid, 'decision': decision,'outcome': outcome,'original_choice': original_decision, 'original_result': oo, 'alt_result': onew, 'change': change[ii][0], 'decision_change': decision_change[ii][0]}
                results.append(entry)
    return pd.DataFrame(results)

test = check_impact_of_decisions(all_models,dataset)
test

In [None]:
data.get_data()['SD Primary 2'].sum()

In [None]:
(test[test.outcome == 'pd2'].original_result == 'SD Primary 2').sum()

In [None]:
check_impact_of_decisions(all_models,dataset).to_csv('../data/decision_impacts.csv')