In [1]:
import numpy as np
import pandas as pd
import torch
import re
pd.set_option('display.max_rows', 200)

In [2]:
class Const:
    data_dir = '../data'
    twin_data = data_dir + 'digital_twin_data.csv'
    twin_ln_data = data_dir + 'digital_twin_ln_data.csv'
    
    rename_dict = {
        'Dummy ID': 'id',
        'Age at Diagnosis (Calculated)': 'age',
        'Feeding tube 6m': 'FT',
        'Affected Lymph node UPPER': 'affected_nodes',
        'Aspiration rate(Y/N)': 'AS',
        'Neck boost (Y/N)': 'neck_boost',
        'Gender': 'gender',
        'Tm Laterality (R/L)': 'laterality',
        'AJCC 8th edition': 'ajcc8',
        'AJCC 7th edition':'ajcc7',
        'N_category_full': 'N-category',
        'HPV/P16 status': 'hpv',
        'Tumor subsite (BOT/Tonsil/Soft Palate/Pharyngeal wall/GPS/NOS)': 'subsite',
        'Total dose': 'total_dose',
        'Therapeutic combination': 'treatment',
        'Smoking status at Diagnosis (Never/Former/Current)': 'smoking_status',
        'Smoking status (Packs/Year)': 'packs_per_year',
        'Overall Survival (1=alive,0=dead)': 'os',
        'Dose/fraction (Gy)': 'dose_fraction'
    }
    
    dlt_dict = {
         'Allergic reaction to Cetuximab': 'DLT_Other',
         'Cardiological (A-fib)': 'DLT_Other',
         'Dermatological': 'DLT_Dermatological',
         'Failure to Thrive': 'DLT_Other',
         'Failure to thrive': 'DLT_Other',
         'GIT [elevated liver enzymes]': 'DLT_Gastrointestinal',
         'Gastrointestina': 'DLT_Gastrointestinal',
         'Gastrointestinal': 'DLT_Gastrointestinal',
         'General': 'DLT_Other',
         'Hematological': 'DLT_Hematological',
         'Hematological (Neutropenia)': 'DLT_Hematological',
         'Hyponatremia': 'DLT_Other',
         'Immunological': 'DLT_Other',
         'Infection': 'DLT_Infection (Pneumonia)',
         'NOS': 'DLT_Other',
         'Nephrological': 'DLT_Nephrological',
         'Nephrological (ARF)': 'DLT_Nephrological',
         'Neurological': 'DLT_Neurological',
         'Neutropenia': 'DLT_Hematological',
         'Nutritional': 'DLT_Other',
         'Pancreatitis': 'DLT_Other',
         'Pulmonary': 'DLT_Other',
         'Respiratory (Pneumonia)': 'DLT_Infection (Pneumonia)',
         'Sepsis': 'DLT_Infection (Pneumonia)',
         'Suboptimal response to treatment' : 'DLT_Other',
         'Vascular': 'DLT_Vascular'
    }
    
    decision1 = 'Decision 1 (Induction Chemo) Y/N'
    decision2 = 'Decision 2 (CC / RT alone)'
    decision3 = 'Decision 3 Neck Dissection (Y/N)'
    decisions = [decision1,decision2, decision3]
    outcomes = ['Overall Survival (4 Years)', 'FT', 'Aspiration rate Post-therapy']
    
    modification_types = {
        0: 'no_dose_adjustment',
        1: 'dose_modified',
        2: 'dose_delayed',
        3: 'dose_cancelled',
        4: 'dose_delayed_&_modified',
        5: 'regiment_modification',
        9: 'unknown'
    }
    
    cc_types = {
        0: 'cc_none',
        1: 'cc_platinum',
        2: 'cc_cetuximab',
        3: 'cc_others',
    }
    
    primary_disease_states = ['CR Primary','PR Primary','SD Primary']
    nodal_disease_states = [t.replace('Primary','Nodal') for t in primary_disease_states]
    dlt1 = list(set(dlt_dict.values()))
    
    modifications =  list(modification_types.values())
    state2 = modifications + primary_disease_states+nodal_disease_states +dlt1 #+['No imaging 0=N,1=Y']
    
    primary_disease_states2 = [t + ' 2' for t in primary_disease_states]
    nodal_disease_states2 = [t + ' 2' for t in nodal_disease_states]
    dlt2 = [d + ' 2' for d in dlt1]
    
    ccs = list(cc_types.values())
    state3 = ccs + primary_disease_states2 + nodal_disease_states2 + dlt2
    
Const.modification_types

{0: 'no_dose_adjustment',
 1: 'dose_modified',
 2: 'dose_delayed',
 3: 'dose_cancelled',
 4: 'dose_delayed_&_modified',
 5: 'regiment_modification',
 9: 'unknown'}

In [3]:
def preprocess(data_cleaned):
    #this was Elisa's preprocessing except I removed all the Ifs because that's dumb
    if len(data_cleaned.shape) < 2:
        data_cleaned = pd.DataFrame([data], columns=data.index)
        
    data_cleaned.loc[data_cleaned['Aspiration rate Pre-therapy'] == 'N', 'Aspiration rate Pre-therapy'] = 0
    data_cleaned.loc[data_cleaned['Aspiration rate Pre-therapy'] == 'Y', 'Aspiration rate Pre-therapy'] = 1

    data_cleaned.loc[data_cleaned['Pathological Grade'] == 'I', 'Pathological Grade'] = 1
    data_cleaned.loc[data_cleaned['Pathological Grade'] == 'II', 'Pathological Grade'] = 2
    data_cleaned.loc[data_cleaned['Pathological Grade'] == 'III', 'Pathological Grade'] = 3
    data_cleaned.loc[data_cleaned['Pathological Grade'] == 'IV', 'Pathological Grade'] = 4

    data_cleaned.loc[(data_cleaned['T-category'] == 'Tx') | (data_cleaned['T-category'] == 'Tis'), 'T-category'] = 1
    data_cleaned.loc[data_cleaned['T-category'] == 'T1', 'T-category'] = 1
    data_cleaned.loc[data_cleaned['T-category'] == 'T2', 'T-category'] = 2
    data_cleaned.loc[data_cleaned['T-category'] == 'T3', 'T-category'] = 3
    data_cleaned.loc[data_cleaned['T-category'] == 'T4', 'T-category'] = 4

    data_cleaned.loc[data_cleaned['N-category'] == 'N0', 'N-category'] = 0
    data_cleaned.loc[data_cleaned['N-category'] == 'N1', 'N-category'] = 1
    data_cleaned.loc[data_cleaned['N-category'] == 'N2', 'N-category'] = 2
    data_cleaned.loc[data_cleaned['N-category'] == 'N3', 'N-category'] = 3

    data_cleaned.loc[data_cleaned['N-category_8th_edition'] == 'N0', 'N-category_8th_edition'] = 0
    data_cleaned.loc[data_cleaned['N-category_8th_edition'] == 'N1', 'N-category_8th_edition'] = 1
    data_cleaned.loc[data_cleaned['N-category_8th_edition'] == 'N2', 'N-category_8th_edition'] = 2
    data_cleaned.loc[data_cleaned['N-category_8th_edition'] == 'N3', 'N-category_8th_edition'] = 3

    data_cleaned.loc[data_cleaned['AJCC 7th edition'] == 'I', 'AJCC 7th edition'] = 1
    data_cleaned.loc[data_cleaned['AJCC 7th edition'] == 'II', 'AJCC 7th edition'] = 2
    data_cleaned.loc[data_cleaned['AJCC 7th edition'] == 'III', 'AJCC 7th edition'] = 3
    data_cleaned.loc[data_cleaned['AJCC 7th edition'] == 'IV', 'AJCC 7th edition'] = 4

    data_cleaned.loc[data_cleaned['AJCC 8th edition'] == 'I', 'AJCC 8th edition'] = 1
    data_cleaned.loc[data_cleaned['AJCC 8th edition'] == 'II', 'AJCC 8th edition'] = 2
    data_cleaned.loc[data_cleaned['AJCC 8th edition'] == 'III', 'AJCC 8th edition'] = 3
    data_cleaned.loc[data_cleaned['AJCC 8th edition'] == 'IV', 'AJCC 8th edition'] = 4

    data_cleaned.loc[data_cleaned['Gender'] == 'Male', 'Gender'] = 1
    data_cleaned.loc[data_cleaned['Gender'] == 'Female', 'Gender'] = 0


    data_cleaned.loc[data_cleaned['HPV/P16 status'] == 'Positive', 'HPV/P16 status'] = 1
    data_cleaned.loc[data_cleaned['HPV/P16 status'] == 'Negative', 'HPV/P16 status'] = -1
    data_cleaned.loc[data_cleaned['HPV/P16 status'] == 'Unknown', 'HPV/P16 status'] = 0

    data_cleaned.loc[data_cleaned[
                         'Smoking status at Diagnosis (Never/Former/Current)'] == 'Formar', 'Smoking status at Diagnosis (Never/Former/Current)'] = .5
    data_cleaned.loc[data_cleaned[
                         'Smoking status at Diagnosis (Never/Former/Current)'] == 'Current', 'Smoking status at Diagnosis (Never/Former/Current)'] = 1
    data_cleaned.loc[data_cleaned[
                         'Smoking status at Diagnosis (Never/Former/Current)'] == 'Never', 'Smoking status at Diagnosis (Never/Former/Current)'] = 0


    data_cleaned.loc[data_cleaned['Chemo Modification (Y/N)'] == 'Y', 'Chemo Modification (Y/N)'] = 1

    data_cleaned.loc[data_cleaned['DLT (Y/N)'] == 'N', 'DLT (Y/N)'] = 0
    data_cleaned.loc[data_cleaned['DLT (Y/N)'] == 'Y', 'DLT (Y/N)'] = 1

    data_cleaned['DLT_Other'] = 0
    for index, row in data_cleaned.iterrows():
        if row['DLT_Type'] == 'None':
            continue
        for i in re.split('&|and|,', row['DLT_Type']):
            if i.strip() != '' and data_cleaned.loc[index, Const.dlt_dict[i.strip()]] == 0:
                data_cleaned.loc[index, Const.dlt_dict[i.strip()]] = 1

    data_cleaned.loc[data_cleaned['Decision 2 (CC / RT alone)'] == 'RT alone', 'Decision 2 (CC / RT alone)'] = 0
    data_cleaned.loc[data_cleaned['Decision 2 (CC / RT alone)'] == 'CC', 'Decision 2 (CC / RT alone)'] = 1

    data_cleaned.loc[data_cleaned['CC modification (Y/N)'] == 'N', 'CC modification (Y/N)'] = 0
    data_cleaned.loc[data_cleaned['CC modification (Y/N)'] == 'Y', 'CC modification (Y/N)'] = 1

    data_cleaned['DLT_Dermatological 2'] = 0
    data_cleaned['DLT_Neurological 2'] = 0
    data_cleaned['DLT_Gastrointestinal 2'] = 0
    data_cleaned['DLT_Hematological 2'] = 0
    data_cleaned['DLT_Nephrological 2'] = 0
    data_cleaned['DLT_Vascular 2'] = 0
    data_cleaned['DLT_Infection (Pneumonia) 2'] = 0
    data_cleaned['DLT_Other 2'] = 0
    for index, row in data_cleaned.iterrows():
        if row['DLT 2'] == 'None':
            continue
        for i in re.split('&|and|,', row['DLT 2']):
            if i.strip() != '':
                data_cleaned.loc[index, Const.dlt_dict[i.strip()] + ' 2'] = 1

    data_cleaned.loc[
        data_cleaned['Decision 3 Neck Dissection (Y/N)'] == 'N', 'Decision 3 Neck Dissection (Y/N)'] = 0
    data_cleaned.loc[
        data_cleaned['Decision 3 Neck Dissection (Y/N)'] == 'Y', 'Decision 3 Neck Dissection (Y/N)'] = 1

    return data_cleaned

def merge_editions(row,basecol='AJCC 8th edition',fallback='AJCC 7th edition'):
    if pd.isnull(row[basecol]):
        return row[fallback]
    return row[basecol]


def preprocess_dt_data(df,extra_to_keep=None):
    
    to_keep = ['id','hpv','age','packs_per_year','smoking_status','gender','Aspiration rate Pre-therapy','total_dose','dose_fraction'] 
    to_onehot = ['T-category','N-category','AJCC','Pathological Grade','subsite','treatment','ln_cluster']
    if extra_to_keep is not None:
        to_keep = to_keep + [c for c in extra_to_keep if c not in to_keep and c not in to_onehot]
    
    decisions =Const.decisions
    outcomes = Const.outcomes
    
    modification_types = {
        0: 'no_dose_adjustment',
        1: 'dose_modified',
        2: 'dose_delayed',
        3: 'dose_cancelled',
        4: 'dose_delayed_&_modified',
        5: 'regiment_modification',
        9: 'unknown'
    }
    
    cc_types = {
        0: 'cc_none',
        1: 'cc_platinum',
        2: 'cc_cetuximab',
        3: 'cc_others',
    }
    
    for k,v in Const.cc_types.items():
        df[v] = df['CC Regimen(0= none, 1= platinum based, 2= cetuximab based, 3= others, 9=unknown)'].apply(lambda x: int(Const.cc_types.get(int(x),0) == v))
        to_keep.append(v)
    for k,v in Const.modification_types.items():
        name = 'Modification Type (0= no dose adjustment, 1=dose modified, 2=dose delayed, 3=dose cancelled, 4=dose delayed & modified, 5=regimen modification, 9=unknown)'
        df[v] = df[name].apply(lambda x: int(Const.modification_types.get(int(x),0) == v))
        to_keep.append(v)
    #Features to keep. I think gender is is in 
    
    keywords = []
    for keyword in keywords:
        toadd = [c for c in df.columns if keyword in c and c not in to_keep]
        to_keep = to_keep + toadd
    
    df['packs_per_year'] = df['packs_per_year'].apply(lambda x: str(x).replace('>','').replace('<','')).astype(float).fillna(0)
    #so I'm actually not sure if this is biological sex or gender given this is texas
    df['AJCC'] = df.apply(lambda row: merge_editions(row,'ajcc8','ajcc7'),axis=1)
    df['N-category'] = df.apply(lambda row: merge_editions(row,'N-category_8th_edition','N-category'),axis=1)
    
    dummy_df = pd.get_dummies(df[to_onehot].fillna(0).astype(str),drop_first=False)
    for col in dummy_df.columns:
        df[col] = dummy_df[col]
        to_keep.append(col)
        
    yn_to_binary = ['FT','Aspiration rate Post-therapy','Decision 1 (Induction Chemo) Y/N']
    for col in yn_to_binary:
        df[col] = df[col].apply(lambda x: int(x == 'Y'))
        
    to_keep = to_keep + [c for c in df.columns if 'DLT' in c]
    
        
    for statelist in [Const.state2,Const.state3,Const.decisions,Const.outcomes]:
        toadd = [c for c in statelist if c not in to_keep]
        to_keep = to_keep + toadd
    return df[to_keep].set_index('id')



In [4]:
def load_digital_twin(file='../data/digital_twin_data.csv'):
    df = pd.read_csv(file)
    return df.rename(columns = Const.rename_dict)

def get_dt_ids():
    df = load_digital_twin()
    return df.id.values
get_dt_ids()

  df = pd.read_csv(file)


array([    3,     5,     6,     7,     8,     9,    10,    11,    13,
          14,    15,    16,    17,    18,    21,    23,    24,    25,
          26,    27,    28,    31,    32,    33,    35,    36,    37,
          38,    39,    40,    41,    42,    44,    45,    47,    48,
          49,    50,    51,    53,    55,    56,    57,    60,    64,
          65,    67,    68,    69,    71,    74,    75,    77,    78,
          79,    80,    81,    82,    87,    88,    91,    94,    96,
          99,   103,   109,   116,   117,   119,   120,   121,   125,
         133,   148,   150,   153,   168,   178,   181,   183,   184,
         185,   186,   187,   188,   189,   190,   191,   192,   193,
         194,   195,   196,   197,   198,   199,   200,   201,   202,
         203,   204,   205,   206,   207,   208,   209,   210,   211,
         212,   213,   214,   215,   216,   217,   218,   219,   220,
         221,   222,   223,   224,   225,   226,   227,   228,   229,
         230,   231,

In [5]:
class DTDataset():
    
    def __init__(self,data_file = '../data/digital_twin_data.csv',ln_data_file = '../data/digital_twin_ln_data.csv',ids=None):
        df = pd.read_csv(data_file)
        
        df = preprocess(df)
        df = df.rename(columns = Const.rename_dict).copy()
        df = df.drop('MRN OPC',axis=1)

        ln_data = pd.read_csv(ln_data_file)
        ln_data = ln_data.rename(columns={'cluster':'ln_cluster'})
        self.ln_cols = [c for c in ln_data.columns if c not in df.columns]
        df = df.merge(ln_data,on='id')
        df.index = df.index.astype(int)
        if ids is not None:
            df = df[df.id.apply(lambda x: x in ids)]
        self.processed_df = preprocess_dt_data(df,self.ln_cols).fillna(0)
        
        self.means = self.processed_df.mean(axis=0)
        self.stds = self.processed_df.std(axis=0)
        self.maxes = self.processed_df.max(axis=0)
        self.mins = self.processed_df.min(axis=0)
        
        arrays = self.get_states()
        self.state_sizes = {k: (v.shape[1] if v.ndim > 1 else 1) for k,v in arrays.items()}
        
    def get_data(self):
        return self.processed_df
    
    def sample(self,frac=.5):
        return self.processed_df.sample(frac=frac)
    
    def split_sample(self,ratio = .3):
        assert(ratio > 0 and ratio <= 1)
        df1 = self.processed_df.sample(frac=1-ratio)
        df2 = self.processed_df.drop(index=df1.index)
        return df1,df2
    
    def get_states(self,fixed=None,ids = None):
        processed_df = self.processed_df.copy()
        if ids is not None:
            processed_df = processed_df.loc[ids]
        if fixed is not None:
            for col,val in fixed.items():
                if col in processed_df.columns:
                    processed_df[col] = val
                else:
                    print('bad fixed entry',col)
                    
        to_skip = ['CC Regimen(0= none, 1= platinum based, 2= cetuximab based, 3= others, 9=unknown)','DLT_Type','DLT 2'] + [c for c in processed_df.columns if 'treatment' in c]
        other_states = set(Const.decisions + Const.state3 + Const.state2 + Const.outcomes  + to_skip)

        base_state = sorted([c for c in processed_df.columns if c not in other_states])

        dlt1 = Const.dlt1
        dlt2 = Const.dlt2
        
        modifications = Const.modifications
        ccs = Const.ccs
        pds = Const.primary_disease_states
        nds = Const.nodal_disease_states
        pds2 = Const.primary_disease_states2
        nds2 = Const.nodal_disease_states2
        outcomes = Const.outcomes
        decisions= Const.decisions
        
        #intermediate states are only udated values. Models should use baseline + state2 etc
        results = {
            'baseline': processed_df[base_state],
            'pd_states1': processed_df[pds],
            'nd_states1': processed_df[nds],
            'modifications': processed_df[modifications],
            'ccs': processed_df[ccs],
            'pd_states2': processed_df[pds2],
            'nd_states2': processed_df[nds2],
            'outcomes': processed_df[outcomes],
            'dlt1': processed_df[dlt1],
            'dlt2': processed_df[dlt2],
            'decision1': processed_df[decisions[0]],
            'decision2': processed_df[decisions[1]],
            'decision3': processed_df[decisions[2]],
        }
    
        return results
    
    def get_state(self,name,**kwargs):
        return self.get_states(**kwargs)[name]
    
    def normalize(self,df):
        means = self.means[df.columns]
        std = self.stds[df.columns]
        return ((df - means)/std).fillna(0)
    
    def get_intermediate_outcomes(self,step=1,**kwargs):
        assert(step in [1,2,3])
        states = self.get_states(**kwargs)
        if step == 1:
            keys = ['pd_states1','nd_states1','modifications','dlt1']
        if step == 2:
            keys =  ['pd_states2','nd_states2','ccs','dlt2']
        if step == 3:
            keys = ['decision1','decision2','decision3']
        return [states[key] for key in keys]
    
    def get_input_state(self,step=1,**kwargs):
        assert(step in [1,2,3])
        states = self.get_states(**kwargs)
        if step == 1:
            keys = ['baseline','decision1']
        if step == 2:
            keys =  ['baseline','pd_states1','nd_states1','modifications','dlt1','decision1','decision2']
        if step == 3:
            keys = ['baseline','pd_states2','nd_states2','ccs','dlt2','decision1','decision2','decision3']
        arrays = [states[key] for key in keys]
        return pd.concat(arrays,axis=1)
    
data = DTDataset()
data.get_intermediate_outcomes(step=3)

  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


[id
 3        0
 5        0
 6        0
 7        0
 8        0
         ..
 10201    0
 10202    0
 10203    0
 10204    0
 10205    0
 Name: Decision 1 (Induction Chemo) Y/N, Length: 536, dtype: int64,
 id
 3        1
 5        1
 6        1
 7        0
 8        0
         ..
 10201    1
 10202    1
 10203    0
 10204    1
 10205    1
 Name: Decision 2 (CC / RT alone), Length: 536, dtype: int64,
 id
 3        0
 5        0
 6        0
 7        0
 8        0
         ..
 10201    0
 10202    1
 10203    1
 10204    0
 10205    1
 Name: Decision 3 Neck Dissection (Y/N), Length: 536, dtype: int64]

In [6]:
def df_to_torch(df,ttype  = torch.FloatTensor):
    values = df.values.astype(float)
    values = torch.from_numpy(values)
    return values.type(ttype)

In [8]:
class SimulatorBase(torch.nn.Module):
    
    def __init__(self,
                 input_size,
                 hidden_layers = [1000],
                 dropout = 0.5,
                 input_dropout=0.1,
                 state = 1,
                 eps = 0.01,
                ):
        #predicts disease state (sd, pr, cr) for primar and nodal, then dose modications or cc type (depending on state), and [dlt ratings]
        torch.nn.Module.__init__(self)
        self.state = state
        self.input_dropout = torch.nn.Dropout(input_dropout)
        
        first_layer =torch.nn.Linear(input_size,hidden_layers[0],bias=True)
        layers = [first_layer,torch.nn.ReLU()]
        curr_size = hidden_layers[0]
        for ndim in hidden_layers[1:]:
            layer = torch.nn.Linear(curr_size,ndim)
            curr_size = ndim
            layers.append(layer)
            layers.append(torch.nn.ReLU())
        self.layers = torch.nn.ModuleList(layers)
        self.batchnorm = torch.nn.BatchNorm1d(hidden_layers[-1])
        self.dropout = torch.nn.Dropout(dropout)
        
    
        input_mean = torch.tensor([0])
        input_std = torch.tensor([1])
        self.eps = eps
        self.register_buffer('input_mean', input_mean)
        self.register_buffer('input_std',input_std)
        
        self.softmax = torch.nn.LogSoftmax(dim=1)
        self.identifier = 'state'  +str(state) + '_input'+str(input_size) + '_dims' + ','.join([str(h) for h in hidden_layers]) + '_dropout' + str(input_dropout) + ',' + str(dropout)
        
    def normalize(self,x):
        x = (x - self.input_mean + self.eps)/(self.input_std + self.eps)
        return x
    
    def fit_normalizer(self,x):
        input_mean = x.mean(axis=0)
        input_std = x.std(axis=0)
        self.register_buffer('input_mean', input_mean)
        self.register_buffer('input_std',input_std)
        return True
        
class OutcomeSimulator(SimulatorBase):
    
    def __init__(self,
                 input_size,
                 hidden_layers = [1000,1000],
                 dropout = 0.5,
                 input_dropout=0.1,
                 state = 1,
                ):
        #predicts disease state (sd, pr, cr) for primar and nodal, then dose modications or cc type (depending on state), and [dlt ratings]
        super(OutcomeSimulator,self).__init__(input_size,hidden_layers=hidden_layers,dropout=dropout,input_dropout=input_dropout,state=state)
    
        self.disease_layer = torch.nn.Linear(hidden_layers[-1],len(Const.primary_disease_states))
        self.nodal_disease_layer = torch.nn.Linear(hidden_layers[-1],len(Const.nodal_disease_states))
        #dlt ratings are 0-4 even though they don't always appear
        
        assert( state in [1,2])
        if state == 1:
            self.dlt_layers = torch.nn.ModuleList([torch.nn.Linear(hidden_layers[-1],5) for i in Const.dlt1])
            self.treatment_layer = torch.nn.Linear(hidden_layers[-1],len(Const.modifications))
        else:
            #we only have dlt yes or no for the second state?
            self.dlt_layers = torch.nn.ModuleList([torch.nn.Linear(hidden_layers[-1],2) for i in Const.dlt2])
            self.treatment_layer = torch.nn.Linear(hidden_layers[-1],len(Const.ccs))

    def forward(self,x):
        x = self.normalize(x)
        x = self.input_dropout(x)
        for layer in self.layers:
            x = layer(x)
#         x = self.batchnorm(x)
        x = self.dropout(x)
        x_pd = self.disease_layer(x)
        x_nd = self.nodal_disease_layer(x)
        x_mod = self.treatment_layer(x)
        x_dlts = [layer(x) for layer in self.dlt_layers]
        
        x_pd = self.softmax(x_pd)
        x_nd = self.softmax(x_nd)
        x_mod = self.softmax(x_mod)
        #dlts are array of nbatch x n_dlts x predictions
        x_dlts = torch.stack([self.softmax(xx) for xx in x_dlts],axis=1)
        return [x_pd, x_nd, x_mod, x_dlts]
    
class EndpointSimulator(SimulatorBase):
    
    def __init__(self,
                 input_size,
                 hidden_layers = [500],
                 dropout = 0.5,
                 input_dropout=0.1,
                 state = 1,
                ):
        #predicts disease state (sd, pr, cr) for primar and nodal, then dose modications or cc type (depending on state), and [dlt ratings]
        super(EndpointSimulator,self).__init__(input_size,hidden_layers=hidden_layers,dropout=dropout,input_dropout=input_dropout,state=state)
        
        self.outcome_layer = torch.nn.Linear(hidden_layers[-1],len(Const.outcomes))
        self.sigmoid = torch.nn.Sigmoid()
      
        
    def forward(self,x):
        x = self.normalize(x)
        x = self.input_dropout(x)
        for layer in self.layers:
            x = layer(x)
#         x = self.batchnorm(x)
        x = self.dropout(x)
        x= self.outcome_layer(x)
        x = self.sigmoid(x)
        return x

OutcomeSimulator(3).input_mean

tensor([0])

In [9]:
def nllloss(ytrue,ypred):
    #nll loss with argmax added in
    loss = torch.nn.NLLLoss()
    return loss(ypred,ytrue.argmax(axis=1))

def state_loss(ytrue,ypred,weights=[1,1,1,1]):
    pd_loss = nllloss(ytrue[0],ypred[0])*weights[0]
    nd_loss = nllloss(ytrue[1],ypred[1])*weights[1]
    mod_loss = nllloss(ytrue[2],ypred[2])*weights[2]
    loss = pd_loss + nd_loss + mod_loss
    dlt_true = ytrue[3]
    dlt_pred = ypred[3]
    ndlt = dlt_true.shape[1]
    nloss = torch.nn.NLLLoss()
    for i in range(ndlt):
        dlt_loss = nloss(dlt_pred[:,i,:],dlt_true[:,i].type(torch.LongTensor))
        loss += dlt_loss*weights[3]/ndlt
    return loss

def outcome_loss(ytrue,ypred,weights=[1,1,1]):
    loss = 0
    nloss = torch.nn.BCELoss()
    for i in range(len(weights)):
        iloss = nloss(ypred[:,i],ytrue[i])*weights[i]
        loss += iloss
    return loss

from sklearn.metrics import balanced_accuracy_score, roc_auc_score

def mc_metrics(yt,yp,numpy=False,is_dlt=False):
    if not numpy:
        yt = yt .cpu().detach().numpy()
        yp = yp.cpu().detach().numpy()
    #this is a catch for when I se the dlt prediction format (encoded integer ordinal, predict as a categorical and take the argmax)
    if yt.ndim > 1:
        try:
            bacc = balanced_accuracy_score(yt.argmax(axis=1),yp.argmax(axis=1))
        except:
            bacc = -1
        try:
            roc_micro = roc_auc_score(yt,yp,average='micro')
        except:
            roc_micro=-1
        try:
            roc_macro = roc_auc_score(yt,yp,average='macro')
        except:
            roc_macro = -1
        return {'accuracy': bacc, 'roc_micro': roc_micro,'roc_macro': roc_macro}
    else:
        if yp.ndim > 1:
            yp = yp.argmax(axis=1)
        try:
            bacc = balanced_accuracy_score(yt,yp)
        except:
            bacc = -1
        try:
            if is_dlt:
                roc = roc_auc_score(yt > 0, yp > 0)
            else:
                roc = roc_auc_score(yt ,yp )
        except:
            roc = -1
        error = np.mean((yt-yp)**2)
        return {'accuracy': bacc, 'mse': error, 'auc': roc}

def state_metrics(ytrue,ypred,numpy=False):
    pd_metrics = mc_metrics(ytrue[0],ypred[0],numpy=numpy)
    nd_metrics = mc_metrics(ytrue[1],ypred[1],numpy=numpy)
    mod_metrics = mc_metrics(ytrue[1],ypred[1],numpy=numpy)
    
    dlt_metrics = []
    dlt_true = ytrue[3]
    dlt_pred = ypred[3]
    ndlt = dlt_true.shape[1]
    nloss = torch.nn.NLLLoss()
    for i in range(ndlt):
        dm = mc_metrics(dlt_true[:,i],dlt_pred[:,i,:],is_dlt=True)
        dlt_metrics.append(dm)
    dlt_acc =[d['accuracy'] for d in dlt_metrics]
    dlt_error = [d['mse'] for d in dlt_metrics]
    dlt_auc = [d['auc'] for d in dlt_metrics]
    return {'pd': pd_metrics,'nd': nd_metrics,'mod': mod_metrics,'dlts': {'accuracy': dlt_acc,'accuracy_mean': np.mean(dlt_acc),'auc': dlt_auc,'auc_mean': np.mean(dlt_auc)}}
    
def outcome_metrics(ytrue,ypred,numpy=False):
    res = {}
    for i, outcome in enumerate(Const.outcomes):
        metrics = mc_metrics(ytrue[i],ypred[:,i])
        res[outcome] = metrics
    return res


  from scipy.sparse.base import spmatrix


In [10]:
from sklearn.ensemble import RandomForestClassifier

def train_state_rf(model_args={}):
    ids = get_dt_ids()
    
    dataset = DTDataset()

    train_ids = ids[0:int(len(ids)*.7)]
    test_ids = ids[int(len(ids)*.7):]
    
    #most things are multiclass, dlts are several ordinal and outcomes are multiple binary
    xtrain1 = dataset.get_state('baseline',ids=train_ids)
    xtest1 = dataset.get_state('baseline',ids=test_ids)
    
    xtrain2 = dataset.get_input_state(step=2,ids=train_ids)
    xtest2 = dataset.get_input_state(step=2,ids=test_ids)
    
    xtrain3 = dataset.get_input_state(step=3,ids=train_ids)
    xtest3 = dataset.get_input_state(step=3,ids=test_ids)
    
    [pd1_train,nd1_train, mod_train,dlts1_train] = dataset.get_intermediate_outcomes(ids=train_ids)
    [pd2_train,nd2_train, cc_train,dlts2_train] = dataset.get_intermediate_outcomes(step=2,ids=train_ids)
    [pd1_test,nd1_test, mod_test,dlts1_test] = dataset.get_intermediate_outcomes(ids=test_ids)
    [pd2_test,nd2_test, cc_test,dlts2_test] = dataset.get_intermediate_outcomes(step=2,ids=test_ids)
    outcomes_train = dataset.get_state('outcomes',ids=train_ids)
    outcomes_test = dataset.get_state('outcomes',ids=test_ids)
    

    def train_multiclass_rf(xtrain,xtest,ytrain,ytest):
        model = RandomForestClassifier(class_weight='balanced',**model_args).fit(xtrain,ytrain)
        ypred = model.predict(xtest)
        metrics = mc_metrics(ytest.values,ypred,numpy=True)
        return model, metrics
    
    all_metrics = {}
    pd1_model, all_metrics['pd1'] = train_multiclass_rf(xtrain1,xtest1,pd1_train,pd1_test)
    nd1_model, all_metrics['nd1']  = train_multiclass_rf(xtrain1,xtest1,nd1_train,nd1_test)
    mod_model, all_metrics['mod']  = train_multiclass_rf(xtrain1,xtest1,mod_train,mod_test)
    
    return all_metrics

# train_state_rf({'max_depth': 5,'n_estimators': 100})

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  method='lar', copy_X=True, eps=np.finfo(np.float).eps,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  method='lar', copy_X=True, eps=np.finfo(np.float).eps,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps, copy_Gram=True, verbose=0,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps, copy_X=True, fit_path=True,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps, copy_X=True, fit_path=True,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes

In [11]:
def train_state(model_args={},
                state=1,
                split=.7,
                lr=.0001,
                epochs=1000,
                patience=10,
                weights=[1,1,1,10],
                save_path='../data/models/',
                resample_training=False,#use bootstraping on training data after splitting
                resample_all = False,# use bootstrapping, then validate with out-of-bag data
                file_suffix=''):
    
    ids = get_dt_ids()
    
    if resample_all:
        train_ids = np.random.choice(ids,len(ids),replace=True)
        test_ids = [i for i in ids if i not in train_ids]
    
    else:
        train_ids = ids[0:int(len(ids)*split)]
        if resample_training:
            train_ids = np.random.choice(train_ids,len(train_ids),replace=True)
        test_ids = ids[int(len(ids)*split):]
    
    dataset = DTDataset()
    
    xtrain = dataset.get_input_state(step=state,ids=train_ids)
    xtest = dataset.get_input_state(step=state,ids=test_ids)
    ytrain = dataset.get_intermediate_outcomes(step=state,ids=train_ids)
    ytest = dataset.get_intermediate_outcomes(step=state,ids=test_ids)
    

    if state < 3:
        model = OutcomeSimulator(xtrain.shape[1],state=state,**model_args)
        lfunc = state_loss
    else:
        model = EndpointSimulator(xtrain.shape[1],**model_args)
        weights = weights[:3]
        lfunc = outcome_loss
        
    hashcode = str(hash(','.join([str(i) for i in train_ids])))
    save_file = save_path + 'model_' + model.identifier + '_split' + str(split) + '_resample' + str(resample_training) +  '_hash' + hashcode + file_suffix + '.tar'
    xtrain = df_to_torch(xtrain)
    xtest = df_to_torch(xtest)
    ytrain = [df_to_torch(t) for t in ytrain]
    ytest= [df_to_torch(t) for t in ytest]
    
    model.fit_normalizer(xtrain)
#     normalize = lambda x: (x - xtrain.mean(axis=0)+.01)/(xtrain.std(axis=0)+.01)
#     unnormalize = lambda x: (x * (xtrain.std(axis=0) +.01)) + xtrain.mean(axis=0) - .01
    
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)
    best_val_loss = 1000000000000000000000000000
    best_loss_metrics = {}
    for epoch in range(epochs):
        
        model.train(True)
        optimizer.zero_grad()
        
        xtrain_sample = xtrain#[torch.randint(len(xtrain),(len(xtrain),) )]
        ypred = model(xtrain_sample)
        loss = lfunc(ytrain,ypred,weights=weights)

        loss.backward()
        optimizer.step()
        print('epoch',epoch,'train loss',loss.item())
        
        model.eval()
        yval = model(xtest)
        val_loss = lfunc(ytest,yval,weights=weights)
        if state < 3:
            val_metrics = state_metrics(ytest,yval)
        else:
            val_metrics = outcome_metrics(ytest,yval)
        if val_loss.item() < best_val_loss:
            best_val_loss = val_loss.item()
            best_loss_metrics = val_metrics
            steps_since_improvement = 0
            torch.save(model.state_dict(),save_file)
        else:
            steps_since_improvement += 1
        print('val loss',val_loss.item())
        print('______________')
        if steps_since_improvement > patience:
            break
    print('best loss',best_val_loss,best_loss_metrics)
    model.load_state_dict(torch.load(save_file))
    model = model.eval()
    return model
model = train_state(state=1)
model

  df = pd.read_csv(file)
  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


epoch 0 train loss 20.31390380859375
val loss 19.554641723632812
______________
epoch 1 train loss 19.594619750976562
val loss 18.907915115356445
______________




epoch 2 train loss 18.947132110595703
val loss 18.26899528503418
______________
epoch 3 train loss 18.26720428466797
val loss 17.633028030395508
______________
epoch 4 train loss 17.62619400024414




val loss 16.996232986450195
______________
epoch 5 train loss 16.96844482421875
val loss 16.355268478393555
______________
epoch 6 train loss 16.319337844848633
val loss 15.707328796386719
______________
epoch 7 train loss 15.674251556396484
val loss 15.051858901977539
______________
epoch 8 train loss 15.008088111877441
val loss 14.388801574707031
______________
epoch 9 train loss 14.296045303344727
val loss 13.719712257385254
______________
epoch 10 train loss 13.625718116760254
val loss 13.046521186828613
______________
epoch 11 train loss 12.941590309143066
val loss 12.372965812683105
______________
epoch 12 train loss 12.259780883789062
val loss 11.703165054321289
______________
epoch 13 train loss 11.576982498168945
val loss 11.042352676391602
______________
epoch 14 train loss 10.88326644897461
val loss 10.39614200592041
______________
epoch 15 train loss 10.168752670288086
val loss 9.769389152526855
______________
epoch 16 train loss 9.645013809204102
val loss 9.167439460754395

val loss 2.557722806930542
______________
epoch 107 train loss 1.9708058834075928
val loss 2.548135995864868
______________
epoch 108 train loss 1.9384851455688477
val loss 2.5380640029907227
______________
epoch 109 train loss 1.931792974472046
val loss 2.5281875133514404
______________
epoch 110 train loss 1.9543131589889526
val loss 2.5185060501098633
______________
epoch 111 train loss 1.9031983613967896
val loss 2.5094523429870605
______________
epoch 112 train loss 1.9226182699203491
val loss 2.500915050506592
______________
epoch 113 train loss 1.9000120162963867
val loss 2.493039846420288
______________
epoch 114 train loss 1.9530097246170044
val loss 2.4857616424560547
______________
epoch 115 train loss 1.863186240196228
val loss 2.4787819385528564
______________
epoch 116 train loss 1.8908538818359375
val loss 2.47210955619812
______________
epoch 117 train loss 1.8609235286712646
val loss 2.465716600418091
______________
epoch 118 train loss 1.8190964460372925
val loss 2.45

OutcomeSimulator(
  (input_dropout): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0): Linear(in_features=62, out_features=1000, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1000, out_features=1000, bias=True)
    (3): ReLU()
  )
  (batchnorm): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.5, inplace=False)
  (softmax): LogSoftmax(dim=1)
  (disease_layer): Linear(in_features=1000, out_features=3, bias=True)
  (nodal_disease_layer): Linear(in_features=1000, out_features=3, bias=True)
  (dlt_layers): ModuleList(
    (0): Linear(in_features=1000, out_features=5, bias=True)
    (1): Linear(in_features=1000, out_features=5, bias=True)
    (2): Linear(in_features=1000, out_features=5, bias=True)
    (3): Linear(in_features=1000, out_features=5, bias=True)
    (4): Linear(in_features=1000, out_features=5, bias=True)
    (5): Linear(in_features=1000, out_features=5, bias=True)
    (6): Linear(in_features=1000, 

In [12]:
model2 = train_state(state=2)
model2

  df = pd.read_csv(file)
  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


epoch 0 train loss 10.293455123901367
val loss 9.65517807006836
______________
epoch 1 train loss 9.782721519470215
val loss 9.158468246459961
______________
epoch 2 train loss 9.307807922363281


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 8.710000038146973
______________
epoch 3 train loss 8.863341331481934
val loss 8.289332389831543
______________
epoch 4 train loss 8.45036506652832
val loss 7.890841484069824
______________


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


epoch 5 train loss 8.046409606933594
val loss 7.511332035064697
______________
epoch 6 train loss 7.64516544342041
val loss 7.147595405578613
______________
epoch 7 train loss 7.31526517868042


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 6.799703121185303
______________
epoch 8 train loss 6.922632217407227
val loss 6.468106269836426
______________
epoch 9 train loss 6.586355209350586
val loss 6.153602123260498
______________
epoch 10 train loss 6.281045436859131
val loss 5.857736587524414
______________
epoch 11 train loss 6.003660678863525
val loss 5.580427169799805
______________
epoch 12 train loss 5.702635288238525


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 5.322508335113525
______________
epoch 13 train loss 5.484636306762695
val loss 5.085526943206787
______________
epoch 14 train loss 5.243514537811279
val loss 4.869521141052246
______________


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


epoch 15 train loss 5.022521495819092
val loss 4.674798965454102
______________
epoch 16 train loss 4.837058067321777
val loss 4.501087188720703
______________
epoch 17 train loss 4.636934280395508


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 4.347265243530273
______________
epoch 18 train loss 4.503328800201416
val loss 4.212705612182617
______________
epoch 19 train loss 4.333676338195801
val loss 4.095096111297607
______________


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


epoch 20 train loss 4.257548809051514
val loss 3.9918432235717773
______________
epoch 21 train loss 4.153221130371094
val loss 3.9024770259857178
______________
epoch 22 train loss 4.092041969299316


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 3.8250410556793213
______________
epoch 23 train loss 4.03037166595459
val loss 3.7579050064086914
______________
epoch 24 train loss 3.9390017986297607
val loss 3.6987476348876953
______________
epoch 25 train loss 3.9231300354003906
val loss 3.647230863571167
______________
epoch 26 train loss 3.8656165599823
val loss 3.6015827655792236
______________
epoch 27 train loss 3.803894519805908


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 3.5606184005737305
______________
epoch 28 train loss 3.780531167984009
val loss 3.5240886211395264
______________
epoch 29 train loss 3.740468978881836
val loss 3.492593288421631
______________


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


epoch 30 train loss 3.71014142036438
val loss 3.4643454551696777
______________
epoch 31 train loss 3.6975150108337402
val loss 3.437966823577881
______________
epoch 32 train loss 3.6622302532196045


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 3.414724111557007
______________
epoch 33 train loss 3.6114768981933594
val loss 3.393545627593994
______________
epoch 34 train loss 3.6216704845428467
val loss 3.3754544258117676
______________
epoch 35 train loss 3.5745790004730225


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 3.3602564334869385
______________
epoch 36 train loss 3.554840326309204
val loss 3.3468925952911377
______________
epoch 37 train loss 3.5837433338165283
val loss 3.334935426712036
______________


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


epoch 38 train loss 3.5327858924865723
val loss 3.3237617015838623
______________
epoch 39 train loss 3.4894204139709473
val loss 3.313202381134033
______________
epoch 40 train loss 3.495757579803467


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 3.3036985397338867
______________
epoch 41 train loss 3.454291582107544
val loss 3.2962098121643066
______________
epoch 42 train loss 3.4351892471313477
val loss 3.2898764610290527
______________
epoch 43 train loss 3.3971259593963623


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 3.2835781574249268
______________
epoch 44 train loss 3.425179958343506
val loss 3.2778358459472656
______________
epoch 45 train loss 3.358964204788208
val loss 3.2716798782348633
______________
epoch 46 train loss 3.3377130031585693
val loss 3.265054225921631
______________
epoch 47 train loss 3.3983142375946045
val loss 3.2590363025665283
______________
epoch 48 train loss 3.340468645095825


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 3.252532482147217
______________
epoch 49 train loss 3.298657178878784
val loss 3.2447173595428467
______________
epoch 50 train loss 3.2648942470550537
val loss 3.2370893955230713
______________


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


epoch 51 train loss 3.2688615322113037
val loss 3.228454351425171
______________
epoch 52 train loss 3.2721946239471436
val loss 3.220998525619507
______________
epoch 53 train loss 3.2336838245391846


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 3.2135908603668213
______________
epoch 54 train loss 3.2446439266204834
val loss 3.2072296142578125
______________
epoch 55 train loss 3.2071385383605957
val loss 3.200655698776245
______________


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


epoch 56 train loss 3.1916348934173584
val loss 3.1949570178985596
______________
epoch 57 train loss 3.179058790206909
val loss 3.189561605453491
______________
epoch 58 train loss 3.1939010620117188


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 3.1847550868988037
______________
epoch 59 train loss 3.15960431098938
val loss 3.1799092292785645
______________
epoch 60 train loss 3.1432671546936035
val loss 3.173334836959839
______________
epoch 61 train loss 3.1289327144622803
val loss 3.1667299270629883
______________
epoch 62 train loss 3.1241719722747803
val loss 3.159519672393799
______________
epoch 63 train loss 3.107797861099243


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

val loss 3.1519927978515625
______________
epoch 64 train loss 3.060866355895996
val loss 3.145702838897705
______________
epoch 65 train loss 3.0811522006988525
val loss 3.13985013961792
______________
epoch 66 train loss 3.052199363708496
val loss 3.134077787399292
______________
epoch 67 train loss 3.0525476932525635
val loss 3.1281967163085938
______________
epoch 68 train loss 3.0269322395324707


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 3.1219382286071777
______________
epoch 69 train loss 3.012512445449829
val loss 3.1146440505981445
______________
epoch 70 train loss 2.9745190143585205
val loss 3.1087982654571533
______________


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


epoch 71 train loss 3.011258125305176
val loss 3.103254556655884
______________
epoch 72 train loss 2.946201801300049
val loss 3.097930431365967
______________
epoch 73 train loss 2.9810774326324463


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 3.0935521125793457
______________
epoch 74 train loss 2.926347255706787
val loss 3.088517427444458
______________
epoch 75 train loss 2.9472248554229736
val loss 3.084005117416382
______________


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


epoch 76 train loss 2.909857749938965
val loss 3.0795702934265137
______________
epoch 77 train loss 2.9466989040374756
val loss 3.0753226280212402
______________


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


epoch 78 train loss 2.9165711402893066
val loss 3.0711638927459717
______________
epoch 79 train loss 2.876617431640625
val loss 3.0666184425354004
______________


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


epoch 80 train loss 2.9288554191589355
val loss 3.062391519546509
______________
epoch 81 train loss 2.8886260986328125


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 3.0587573051452637
______________
epoch 82 train loss 2.8618695735931396
val loss 3.0546700954437256
______________
epoch 83 train loss 2.8534505367279053
val loss 3.0496666431427
______________
epoch 84 train loss 2.8331031799316406
val loss 3.0440433025360107
______________
epoch 85 train loss 2.837718963623047
val loss 3.037994861602783
______________
epoch 86 train loss 2.8417375087738037


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 3.032968282699585
______________
epoch 87 train loss 2.815450429916382
val loss 3.0271992683410645
______________
epoch 88 train loss 2.777944326400757
val loss 3.0211727619171143
______________
epoch 89 train loss 2.752603054046631


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 3.0161001682281494
______________
epoch 90 train loss 2.7788491249084473
val loss 3.011270046234131
______________
epoch 91 train loss 2.7819931507110596
val loss 3.0066158771514893
______________


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


epoch 92 train loss 2.7573978900909424
val loss 3.003504753112793
______________
epoch 93 train loss 2.749675750732422
val loss 3.0013370513916016
______________


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


epoch 94 train loss 2.729771375656128
val loss 2.9989120960235596
______________
epoch 95 train loss 2.7445452213287354
val loss 2.9977831840515137
______________


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


epoch 96 train loss 2.7038025856018066
val loss 2.9955661296844482
______________
epoch 97 train loss 2.723533868789673
val loss 2.993093252182007
______________
epoch 98 train loss 2.7081336975097656


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 2.9902396202087402
______________
epoch 99 train loss 2.683424234390259
val loss 2.987056255340576
______________
epoch 100 train loss 2.707948923110962
val loss 2.9859914779663086
______________
epoch 101 train loss 2.6641831398010254


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 2.983588933944702
______________
epoch 102 train loss 2.630382537841797
val loss 2.9819891452789307
______________
epoch 103 train loss 2.6635963916778564
val loss 2.979915142059326
______________
epoch 104 train loss 2.6766607761383057
val loss 2.978588342666626
______________
epoch 105 train loss 2.6336441040039062
val loss 2.9774961471557617
______________
epoch 106 train loss 2.681339979171753


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 2.9760007858276367
______________
epoch 107 train loss 2.577894449234009
val loss 2.9749741554260254
______________
epoch 108 train loss 2.580425977706909
val loss 2.97365140914917
______________


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


epoch 109 train loss 2.6181492805480957
val loss 2.971081495285034
______________
epoch 110 train loss 2.5845093727111816
val loss 2.9667465686798096
______________


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


epoch 111 train loss 2.5615289211273193
val loss 2.9622650146484375
______________
epoch 112 train loss 2.5710842609405518
val loss 2.9566850662231445
______________
epoch 113 train loss 2.559952974319458


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 2.950462818145752
______________
epoch 114 train loss 2.5440566539764404
val loss 2.9454121589660645
______________
epoch 115 train loss 2.5244078636169434
val loss 2.9429831504821777
______________
epoch 116 train loss 2.530428647994995


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


val loss 2.940115451812744
______________
epoch 117 train loss 2.534360647201538
val loss 2.9402432441711426
______________
epoch 118 train loss 2.4976305961608887
val loss 2.9414355754852295
______________
epoch 119 train loss 2.4785401821136475
val loss 2.943080186843872
______________
epoch 120 train loss 2.477588415145874
val loss 2.9452521800994873
______________
epoch 121 train loss 2.5029568672180176


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

val loss 2.9485318660736084
______________
epoch 122 train loss 2.514151096343994
val loss 2.952755928039551
______________
epoch 123 train loss 2.4670259952545166
val loss 2.9557721614837646
______________
epoch 124 train loss 2.447359800338745
val loss 2.9573276042938232
______________
epoch 125 train loss 2.4992177486419678
val loss 2.958098888397217
______________
epoch 126 train loss 2.4381864070892334
val loss 2.95611310005188
______________


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)
Deprecated in Nu

epoch 127 train loss 2.4426982402801514
val loss 2.9522035121917725
______________
best loss 2.940115451812744 {'pd': {'accuracy': 0.4965034965034965, 'roc_micro': 0.9470455780728317, 'roc_macro': -1}, 'nd': {'accuracy': 0.36074629977069, 'roc_micro': 0.7678779408339157, 'roc_macro': 0.6578843865490452}, 'mod': {'accuracy': 0.36074629977069, 'roc_micro': 0.7678779408339157, 'roc_macro': 0.6578843865490452}, 'dlts': {'accuracy': [0.5, 0.5, 0.5, 0.5, 1.0, 0.5, 0.5, 1.0], 'accuracy_mean': 0.625, 'auc': [0.5, 0.5, 0.5, 0.5, -1, 0.5, 0.5, -1], 'auc_mean': 0.125}}


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  Y = np.zeros((len(y), 1), dtype=np.int)


OutcomeSimulator(
  (input_dropout): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0): Linear(in_features=84, out_features=1000, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1000, out_features=1000, bias=True)
    (3): ReLU()
  )
  (batchnorm): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.5, inplace=False)
  (softmax): LogSoftmax(dim=1)
  (disease_layer): Linear(in_features=1000, out_features=3, bias=True)
  (nodal_disease_layer): Linear(in_features=1000, out_features=3, bias=True)
  (dlt_layers): ModuleList(
    (0): Linear(in_features=1000, out_features=2, bias=True)
    (1): Linear(in_features=1000, out_features=2, bias=True)
    (2): Linear(in_features=1000, out_features=2, bias=True)
    (3): Linear(in_features=1000, out_features=2, bias=True)
    (4): Linear(in_features=1000, out_features=2, bias=True)
    (5): Linear(in_features=1000, out_features=2, bias=True)
    (6): Linear(in_features=1000, 

In [13]:
model3 = train_state(state=3,epochs=10000)
model3

  df = pd.read_csv(file)
  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


epoch 0 train loss 2.1015212535858154
val loss 2.0966429710388184
______________
epoch 1 train loss 2.1059699058532715
val loss 2.0853188037872314
______________
epoch 2 train loss 2.086408853530884
val loss 2.074215888977051
______________
epoch 3 train loss 2.077211618423462
val loss 2.063295364379883
______________
epoch 4 train loss 2.048168182373047
val loss 2.0525667667388916
______________
epoch 5 train loss 2.033522129058838
val loss 2.041977643966675
______________
epoch 6 train loss 2.0281002521514893
val loss 2.0315542221069336
______________
epoch 7 train loss 2.0261025428771973
val loss 2.0213098526000977
______________
epoch 8 train loss 2.00384521484375
val loss 2.0112171173095703
______________
epoch 9 train loss 2.004908561706543
val loss 2.001276731491089
______________
epoch 10 train loss 1.9926023483276367
val loss 1.9914941787719727
______________
epoch 11 train loss 1.9776113033294678
val loss 1.9818599224090576
______________
epoch 12 train loss 1.977262735366821

val loss 1.5131748914718628
______________
epoch 103 train loss 1.364267110824585
val loss 1.5100288391113281
______________
epoch 104 train loss 1.3561034202575684
val loss 1.506890058517456
______________
epoch 105 train loss 1.338455319404602
val loss 1.503773808479309
______________
epoch 106 train loss 1.337105631828308
val loss 1.500678539276123
______________
epoch 107 train loss 1.3370769023895264
val loss 1.497603416442871
______________
epoch 108 train loss 1.343722939491272
val loss 1.4945294857025146
______________
epoch 109 train loss 1.3152530193328857
val loss 1.4914586544036865
______________
epoch 110 train loss 1.3143858909606934
val loss 1.4883888959884644
______________
epoch 111 train loss 1.3281302452087402
val loss 1.4853273630142212
______________
epoch 112 train loss 1.3213024139404297
val loss 1.4822919368743896
______________
epoch 113 train loss 1.310210108757019
val loss 1.4792516231536865
______________
epoch 114 train loss 1.2941986322402954
val loss 1.47

val loss 1.2258294820785522
______________
epoch 204 train loss 1.0298409461975098
val loss 1.2231554985046387
______________
epoch 205 train loss 1.0041135549545288
val loss 1.2204821109771729
______________
epoch 206 train loss 1.012751579284668
val loss 1.2178139686584473
______________
epoch 207 train loss 0.9914513230323792
val loss 1.2151459455490112
______________
epoch 208 train loss 0.9986729621887207
val loss 1.2124624252319336
______________
epoch 209 train loss 0.9916490316390991
val loss 1.2097591161727905
______________
epoch 210 train loss 0.9785895347595215
val loss 1.207074761390686
______________
epoch 211 train loss 0.9941638708114624
val loss 1.2043933868408203
______________
epoch 212 train loss 0.9614653587341309
val loss 1.2017185688018799
______________
epoch 213 train loss 0.9748309254646301
val loss 1.1990429162979126
______________
epoch 214 train loss 0.9763867855072021
val loss 1.1963834762573242
______________
epoch 215 train loss 0.9727275371551514
val lo

val loss 0.9463107585906982
______________
epoch 319 train loss 0.7253338098526001
val loss 0.9441831111907959
______________
epoch 320 train loss 0.7140083312988281
val loss 0.9420498013496399
______________
epoch 321 train loss 0.7160630226135254
val loss 0.9399226903915405
______________
epoch 322 train loss 0.7212742567062378
val loss 0.9377880096435547
______________
epoch 323 train loss 0.7199103236198425
val loss 0.9356637001037598
______________
epoch 324 train loss 0.7109595537185669
val loss 0.9335514307022095
______________
epoch 325 train loss 0.7064943313598633
val loss 0.9314417839050293
______________
epoch 326 train loss 0.7031735181808472
val loss 0.9293434619903564
______________
epoch 327 train loss 0.7207792401313782
val loss 0.9272761344909668
______________
epoch 328 train loss 0.7099193334579468
val loss 0.9252252578735352
______________
epoch 329 train loss 0.7070194482803345
val loss 0.9231759309768677
______________
epoch 330 train loss 0.693734347820282
val l

val loss 0.7597965002059937
______________
epoch 420 train loss 0.5548861026763916
val loss 0.7582204341888428
______________
epoch 421 train loss 0.560340166091919
val loss 0.7566412687301636
______________
epoch 422 train loss 0.5415385961532593
val loss 0.7550570964813232
______________
epoch 423 train loss 0.5503355264663696
val loss 0.7534670233726501
______________
epoch 424 train loss 0.5547986030578613
val loss 0.7518699169158936
______________
epoch 425 train loss 0.5347933769226074
val loss 0.7502496242523193
______________
epoch 426 train loss 0.5352985858917236
val loss 0.7486287355422974
______________
epoch 427 train loss 0.5358232259750366
val loss 0.7470083236694336
______________
epoch 428 train loss 0.5567042231559753
val loss 0.7453973293304443
______________
epoch 429 train loss 0.5340100526809692
val loss 0.7438057661056519
______________
epoch 430 train loss 0.5202455520629883
val loss 0.7422386407852173
______________
epoch 431 train loss 0.5394269824028015
val l

epoch 536 train loss 0.39350858330726624
val loss 0.6005626916885376
______________
epoch 537 train loss 0.41144222021102905
val loss 0.5994390249252319
______________
epoch 538 train loss 0.39379754662513733
val loss 0.5983330011367798
______________
epoch 539 train loss 0.41423487663269043
val loss 0.5972167253494263
______________
epoch 540 train loss 0.4106537699699402
val loss 0.5961236953735352
______________
epoch 541 train loss 0.3880569338798523
val loss 0.5950787663459778
______________
epoch 542 train loss 0.3963806629180908
val loss 0.5940229296684265
______________
epoch 543 train loss 0.398185133934021
val loss 0.5929632186889648
______________
epoch 544 train loss 0.40398305654525757
val loss 0.591899037361145
______________
epoch 545 train loss 0.39435628056526184
val loss 0.5908199548721313
______________
epoch 546 train loss 0.4140913486480713
val loss 0.5897439122200012
______________
epoch 547 train loss 0.41953590512275696
val loss 0.5886635780334473
______________

val loss 0.5027691721916199
______________
epoch 636 train loss 0.31435999274253845
val loss 0.5019323825836182
______________
epoch 637 train loss 0.3071950674057007
val loss 0.5010979175567627
______________
epoch 638 train loss 0.32649531960487366
val loss 0.5002340078353882
______________
epoch 639 train loss 0.31934690475463867
val loss 0.49938374757766724
______________
epoch 640 train loss 0.33262312412261963
val loss 0.49851590394973755
______________
epoch 641 train loss 0.31769776344299316
val loss 0.4976823329925537
______________
epoch 642 train loss 0.321165531873703
val loss 0.49682676792144775
______________
epoch 643 train loss 0.3004528880119324
val loss 0.49598556756973267
______________
epoch 644 train loss 0.3061668872833252
val loss 0.4951556324958801
______________
epoch 645 train loss 0.30596888065338135
val loss 0.4943065941333771
______________
epoch 646 train loss 0.31751203536987305
val loss 0.49345898628234863
______________
epoch 647 train loss 0.3295696377

epoch 736 train loss 0.25470805168151855
val loss 0.42401042580604553
______________
epoch 737 train loss 0.2568811774253845
val loss 0.4232921004295349
______________
epoch 738 train loss 0.2459629774093628
val loss 0.4226069450378418
______________
epoch 739 train loss 0.2656131386756897
val loss 0.42190104722976685
______________
epoch 740 train loss 0.29209452867507935
val loss 0.4211850166320801
______________
epoch 741 train loss 0.2532423138618469
val loss 0.4204750061035156
______________
epoch 742 train loss 0.2714530825614929
val loss 0.4197287857532501
______________
epoch 743 train loss 0.2660319209098816
val loss 0.4189770519733429
______________
epoch 744 train loss 0.266239196062088
val loss 0.4182116687297821
______________
epoch 745 train loss 0.2628461718559265
val loss 0.4174472689628601
______________
epoch 746 train loss 0.25955793261528015
val loss 0.416687548160553
______________
epoch 747 train loss 0.27983230352401733
val loss 0.4158672094345093
______________


val loss 0.35454076528549194
______________
epoch 853 train loss 0.20980745553970337
val loss 0.3539932370185852
______________
epoch 854 train loss 0.2356913536787033
val loss 0.35344475507736206
______________
epoch 855 train loss 0.20582038164138794
val loss 0.35293376445770264
______________
epoch 856 train loss 0.23354794085025787
val loss 0.3524025082588196
______________
epoch 857 train loss 0.19390860199928284
val loss 0.3518618941307068
______________
epoch 858 train loss 0.2225935459136963
val loss 0.3513113856315613
______________
epoch 859 train loss 0.19428715109825134
val loss 0.35076504945755005
______________
epoch 860 train loss 0.20874634385108948
val loss 0.35022711753845215
______________
epoch 861 train loss 0.1975315511226654
val loss 0.3496742844581604
______________
epoch 862 train loss 0.2117052972316742
val loss 0.3490836024284363
______________
epoch 863 train loss 0.2067120373249054
val loss 0.34848925471305847
______________
epoch 864 train loss 0.206390574

val loss 0.3051312565803528
______________
epoch 952 train loss 0.17744839191436768
val loss 0.30468982458114624
______________
epoch 953 train loss 0.17522652447223663
val loss 0.304268479347229
______________
epoch 954 train loss 0.1836375892162323
val loss 0.30385905504226685
______________
epoch 955 train loss 0.16385847330093384
val loss 0.3034534454345703
______________
epoch 956 train loss 0.16314569115638733
val loss 0.3030148148536682
______________
epoch 957 train loss 0.16750836372375488
val loss 0.30260777473449707
______________
epoch 958 train loss 0.1871088743209839
val loss 0.3021909296512604
______________
epoch 959 train loss 0.16418002545833588
val loss 0.3017989993095398
______________
epoch 960 train loss 0.19186033308506012
val loss 0.301398366689682
______________
epoch 961 train loss 0.15912549197673798
val loss 0.3010390102863312
______________
epoch 962 train loss 0.18880274891853333
val loss 0.30069097876548767
______________
epoch 963 train loss 0.1840918660

val loss 0.2652358412742615
______________
epoch 1054 train loss 0.17326992750167847
val loss 0.2648453712463379
______________
epoch 1055 train loss 0.1465451866388321
val loss 0.2644449472427368
______________
epoch 1056 train loss 0.14691162109375
val loss 0.2640461027622223
______________
epoch 1057 train loss 0.1537531614303589
val loss 0.2636430859565735
______________
epoch 1058 train loss 0.14364194869995117
val loss 0.26325279474258423
______________
epoch 1059 train loss 0.1494346261024475
val loss 0.26285520195961
______________
epoch 1060 train loss 0.15195631980895996
val loss 0.26252442598342896
______________
epoch 1061 train loss 0.16855791211128235
val loss 0.26218289136886597
______________
epoch 1062 train loss 0.17471176385879517
val loss 0.2618129253387451
______________
epoch 1063 train loss 0.15344740450382233
val loss 0.26145485043525696
______________
epoch 1064 train loss 0.15068957209587097
val loss 0.2611173689365387
______________
epoch 1065 train loss 0.16

val loss 0.23342807590961456
______________
epoch 1156 train loss 0.12820357084274292
val loss 0.23309406638145447
______________
epoch 1157 train loss 0.13428595662117004
val loss 0.2327612340450287
______________
epoch 1158 train loss 0.12364634871482849
val loss 0.23244836926460266
______________
epoch 1159 train loss 0.13025644421577454
val loss 0.2321535050868988
______________
epoch 1160 train loss 0.13248591125011444
val loss 0.23187649250030518
______________
epoch 1161 train loss 0.14368698000907898
val loss 0.23158834874629974
______________
epoch 1162 train loss 0.14191806316375732
val loss 0.23129770159721375
______________
epoch 1163 train loss 0.11413101851940155
val loss 0.23104722797870636
______________
epoch 1164 train loss 0.16127288341522217
val loss 0.2307988405227661
______________
epoch 1165 train loss 0.11613081395626068
val loss 0.23056258261203766
______________
epoch 1166 train loss 0.14574715495109558
val loss 0.23031604290008545
______________
epoch 1167 tr

val loss 0.20887017250061035
______________
epoch 1256 train loss 0.11570204794406891
val loss 0.2087177038192749
______________
epoch 1257 train loss 0.12707604467868805
val loss 0.208563432097435
______________
epoch 1258 train loss 0.11655738949775696
val loss 0.2083624005317688
______________
epoch 1259 train loss 0.11846226453781128
val loss 0.2081623673439026
______________
epoch 1260 train loss 0.1038283258676529
val loss 0.20801550149917603
______________
epoch 1261 train loss 0.13801036775112152
val loss 0.20787571370601654
______________
epoch 1262 train loss 0.13895854353904724
val loss 0.20771439373493195
______________
epoch 1263 train loss 0.10381878912448883
val loss 0.207579106092453
______________
epoch 1264 train loss 0.12461298704147339
val loss 0.20742595195770264
______________
epoch 1265 train loss 0.12666232883930206
val loss 0.20722663402557373
______________
epoch 1266 train loss 0.12072121351957321
val loss 0.2070099413394928
______________
epoch 1267 train lo

val loss 0.18631722033023834
______________
epoch 1359 train loss 0.1071953997015953
val loss 0.18616081774234772
______________
epoch 1360 train loss 0.09649166464805603
val loss 0.1859986037015915
______________
epoch 1361 train loss 0.11857417225837708
val loss 0.1858362853527069
______________
epoch 1362 train loss 0.09162862598896027
val loss 0.18569491803646088
______________
epoch 1363 train loss 0.10403409600257874
val loss 0.18553954362869263
______________
epoch 1364 train loss 0.11686218529939651
val loss 0.18536928296089172
______________
epoch 1365 train loss 0.10814688354730606
val loss 0.18520468473434448
______________
epoch 1366 train loss 0.11856906116008759
val loss 0.18498235940933228
______________
epoch 1367 train loss 0.09912233054637909
val loss 0.18480348587036133
______________
epoch 1368 train loss 0.09842359274625778
val loss 0.18465149402618408
______________
epoch 1369 train loss 0.10622565448284149
val loss 0.1845298707485199
______________
epoch 1370 tra

val loss 0.17064018547534943
______________
epoch 1458 train loss 0.13460560142993927
val loss 0.17029544711112976
______________
epoch 1459 train loss 0.125975102186203
val loss 0.16993507742881775
______________
epoch 1460 train loss 0.09818544238805771
val loss 0.16961997747421265
______________
epoch 1461 train loss 0.10737866163253784
val loss 0.16930973529815674
______________
epoch 1462 train loss 0.10181179642677307
val loss 0.169008269906044
______________
epoch 1463 train loss 0.10307560861110687
val loss 0.16875198483467102
______________
epoch 1464 train loss 0.12748658657073975
val loss 0.1684880554676056
______________
epoch 1465 train loss 0.12258991599082947
val loss 0.16821670532226562
______________
epoch 1466 train loss 0.10157199203968048
val loss 0.16797907650470734
______________
epoch 1467 train loss 0.08742427825927734
val loss 0.16777467727661133
______________
epoch 1468 train loss 0.1146957278251648
val loss 0.16756537556648254
______________
epoch 1469 train

val loss 0.15230952203273773
______________
epoch 1569 train loss 0.09713344275951385
val loss 0.15224888920783997
______________
epoch 1570 train loss 0.0941154807806015
val loss 0.15220554172992706
______________
epoch 1571 train loss 0.09064488112926483
val loss 0.15217220783233643
______________
epoch 1572 train loss 0.1019456759095192
val loss 0.1520928591489792
______________
epoch 1573 train loss 0.0724455714225769
val loss 0.1520647406578064
______________
epoch 1574 train loss 0.10971759259700775
val loss 0.15202338993549347
______________
epoch 1575 train loss 0.08788830041885376
val loss 0.15198519825935364
______________
epoch 1576 train loss 0.10382348299026489
val loss 0.15198427438735962
______________
epoch 1577 train loss 0.0906401127576828
val loss 0.15198230743408203
______________
epoch 1578 train loss 0.07825427502393723
val loss 0.15199053287506104
______________
epoch 1579 train loss 0.10504946857690811
val loss 0.15200257301330566
______________
epoch 1580 train

val loss 0.14041246473789215
______________
epoch 1667 train loss 0.09299381822347641
val loss 0.14025422930717468
______________
epoch 1668 train loss 0.09801840782165527
val loss 0.1400904506444931
______________
epoch 1669 train loss 0.09073350578546524
val loss 0.13991405069828033
______________
epoch 1670 train loss 0.1049758791923523
val loss 0.13972938060760498
______________
epoch 1671 train loss 0.06852348148822784
val loss 0.13959258794784546
______________
epoch 1672 train loss 0.10389634221792221
val loss 0.13943824172019958
______________
epoch 1673 train loss 0.08356232196092606
val loss 0.1392812430858612
______________
epoch 1674 train loss 0.0748407244682312
val loss 0.13915766775608063
______________
epoch 1675 train loss 0.09282203018665314
val loss 0.13901659846305847
______________
epoch 1676 train loss 0.09867441654205322
val loss 0.1388765275478363
______________
epoch 1677 train loss 0.07723255455493927
val loss 0.13873769342899323
______________
epoch 1678 trai

val loss 0.12928219139575958
______________
epoch 1771 train loss 0.1063118577003479
val loss 0.12907841801643372
______________
epoch 1772 train loss 0.1094861775636673
val loss 0.12887518107891083
______________
epoch 1773 train loss 0.07343532145023346
val loss 0.12872351706027985
______________
epoch 1774 train loss 0.12401885539293289
val loss 0.12856237590312958
______________
epoch 1775 train loss 0.0783233791589737
val loss 0.128409281373024
______________
epoch 1776 train loss 0.09133182466030121
val loss 0.1282447874546051
______________
epoch 1777 train loss 0.07010054588317871
val loss 0.12808549404144287
______________
epoch 1778 train loss 0.08254209160804749
val loss 0.12791350483894348
______________
epoch 1779 train loss 0.09671226143836975
val loss 0.12772700190544128
______________
epoch 1780 train loss 0.07311300933361053
val loss 0.12755607068538666
______________
epoch 1781 train loss 0.10699000954627991
val loss 0.12736383080482483
______________
epoch 1782 train

val loss 0.12008346617221832
______________
epoch 1876 train loss 0.07395332306623459
val loss 0.12003910541534424
______________
epoch 1877 train loss 0.09155160188674927
val loss 0.11995736509561539
______________
epoch 1878 train loss 0.07852783054113388
val loss 0.11987482756376266
______________
epoch 1879 train loss 0.08789939433336258
val loss 0.11977994441986084
______________
epoch 1880 train loss 0.06467851251363754
val loss 0.11968136578798294
______________
epoch 1881 train loss 0.05937909334897995
val loss 0.11960476636886597
______________
epoch 1882 train loss 0.0994437113404274
val loss 0.11951343715190887
______________
epoch 1883 train loss 0.08659110963344574
val loss 0.1194242388010025
______________
epoch 1884 train loss 0.08818425238132477
val loss 0.11930045485496521
______________
epoch 1885 train loss 0.08864790201187134
val loss 0.11918330192565918
______________
epoch 1886 train loss 0.08167676627635956
val loss 0.11905103921890259
______________
epoch 1887 t

epoch 1974 train loss 0.06210070848464966
val loss 0.11049351096153259
______________
epoch 1975 train loss 0.07640516757965088
val loss 0.11039852350950241
______________
epoch 1976 train loss 0.07738958299160004
val loss 0.11029817163944244
______________
epoch 1977 train loss 0.07857005298137665
val loss 0.11020424962043762
______________
epoch 1978 train loss 0.10593549907207489
val loss 0.11006031185388565
______________
epoch 1979 train loss 0.06290172785520554
val loss 0.109925776720047
______________
epoch 1980 train loss 0.0633811503648758
val loss 0.10979132354259491
______________
epoch 1981 train loss 0.049256086349487305
val loss 0.10969626158475876
______________
epoch 1982 train loss 0.06682717800140381
val loss 0.1096091940999031
______________
epoch 1983 train loss 0.08670559525489807
val loss 0.10951701551675797
______________
epoch 1984 train loss 0.0818590596318245
val loss 0.10942773520946503
______________
epoch 1985 train loss 0.07078488171100616
val loss 0.10935

EndpointSimulator(
  (input_dropout): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0): Linear(in_features=82, out_features=500, bias=True)
    (1): ReLU()
  )
  (batchnorm): BatchNorm1d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.5, inplace=False)
  (softmax): LogSoftmax(dim=1)
  (outcome_layer): Linear(in_features=500, out_features=3, bias=True)
  (sigmoid): Sigmoid()
)

In [14]:
def breakup_state_models(state_model):
    #for state 1 and 2
    models = {}
    models['pd'] = lambda x: state_model(x)[0]
    models['nd'] = lambda x: state_model(x)[1]
    models['chemo'] = lambda x: state_model(x)[2]
    for i,dlt in enumerate(Const.dlt1):
        models[dlt] = lambda x: state_model(x)[3][:,i]
    return models

def breakup_outcome_models(omodel):
    models = {}
    for i,name in enumerate(Const.outcomes):
        models[name] = lambda x: omodel(x)[:,i].reshape(-1,1)
    return models

def get_all_models(m1,m2,m3):
    state1_models = breakup_state_models(m1)
    state2_models = breakup_state_models(m2)
    state3_models = breakup_outcome_models(m3)
    all_models = {}
    for i,sm in enumerate([state1_models,state2_models,state3_models]):
        for ii,m in sm.items():
            all_models[ii +  '_state' + str(i+1)] = m
    return all_models

all_models = get_all_models(model,model2,model3)
all_models

{'pd_state1': <function __main__.breakup_state_models.<locals>.<lambda>(x)>,
 'nd_state1': <function __main__.breakup_state_models.<locals>.<lambda>(x)>,
 'chemo_state1': <function __main__.breakup_state_models.<locals>.<lambda>(x)>,
 'DLT_Hematological_state1': <function __main__.breakup_state_models.<locals>.<lambda>(x)>,
 'DLT_Gastrointestinal_state1': <function __main__.breakup_state_models.<locals>.<lambda>(x)>,
 'DLT_Nephrological_state1': <function __main__.breakup_state_models.<locals>.<lambda>(x)>,
 'DLT_Other_state1': <function __main__.breakup_state_models.<locals>.<lambda>(x)>,
 'DLT_Infection (Pneumonia)_state1': <function __main__.breakup_state_models.<locals>.<lambda>(x)>,
 'DLT_Dermatological_state1': <function __main__.breakup_state_models.<locals>.<lambda>(x)>,
 'DLT_Neurological_state1': <function __main__.breakup_state_models.<locals>.<lambda>(x)>,
 'DLT_Vascular_state1': <function __main__.breakup_state_models.<locals>.<lambda>(x)>,
 'pd_state2': <function __main__

In [15]:
from captum.attr import IntegratedGradients

key ='FT_state3'
ig = IntegratedGradients(all_models[key])
dataset = DTDataset()
xtestdf = dataset.get_input_state(step=int(key[-1]))
xtest = df_to_torch(xtestdf)
all_models[key](xtest).shape
# attributions = ig.attribute(xtest,torch.zeros(xtest.shape),target=1)
# test = pd.DataFrame(attributions,columns = xtestdf.columns, index=xtestdf.index)
# test.describe().T

  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


torch.Size([536, 1])

In [23]:
def check_impact_of_decisions(model_dict,data):
    results = []
    ids = []
    for decision in Const.decisions:
        subset0 = dataset.get_input_state(step=3,fixed={decision: 0})
        subset1 = dataset.get_input_state(step=3,fixed={decision: 1})
        ids = subset0.index.values
        x0 = df_to_torch(subset0)
        x1 = df_to_torch(subset1)
        entry = {}
        for outcome in Const.outcomes:
            model = model_dict[outcome+"_state3"]
            y0 = model(x0).detach().cpu().numpy()
            y1 = model(x1).detach().cpu().numpy()
            change = y1 - y0
            decision_change = np.abs((y0 > .5).astype(int) - (y1 > .5).astype(int))
            results.append((decision + ': ' + outcome+'_change',change ))
            results.append((decision + ': ' + outcome+'_decision_change',decision_change))
#         impacts[decision] = entry
    return pd.DataFrame([r[1] for r in results],columns = [r[0] for r in results],index=ids)

check_impact_of_decisions(all_models,dataset)

ValueError: Must pass 2-d input. shape=(18, 536, 1)