In [1]:
import numpy as np
import pandas as pd
import torch
import re
pd.set_option('display.max_rows', 200)

In [71]:
class Const:
    data_dir = '../data'
    twin_data = data_dir + 'digital_twin_data.csv'
    twin_ln_data = data_dir + 'digital_twin_ln_data.csv'
    
    rename_dict = {
        'Dummy ID': 'id',
        'Age at Diagnosis (Calculated)': 'age',
        'Feeding tube 6m': 'FT',
        'Affected Lymph node UPPER': 'affected_nodes',
        'Aspiration rate(Y/N)': 'AS',
        'Neck boost (Y/N)': 'neck_boost',
        'Gender': 'gender',
        'Tm Laterality (R/L)': 'laterality',
        'AJCC 8th edition': 'ajcc8',
        'AJCC 7th edition':'ajcc7',
        'N_category_full': 'N-category',
        'HPV/P16 status': 'hpv',
        'Tumor subsite (BOT/Tonsil/Soft Palate/Pharyngeal wall/GPS/NOS)': 'subsite',
        'Total dose': 'total_dose',
        'Therapeutic combination': 'treatment',
        'Smoking status at Diagnosis (Never/Former/Current)': 'smoking_status',
        'Smoking status (Packs/Year)': 'packs_per_year',
        'Overall Survival (1=alive,0=dead)': 'os',
        'Dose/fraction (Gy)': 'dose_fraction'
    }
    
    dlt_dict = {
         'Allergic reaction to Cetuximab': 'DLT_Other',
         'Cardiological (A-fib)': 'DLT_Other',
         'Dermatological': 'DLT_Dermatological',
         'Failure to Thrive': 'DLT_Other',
         'Failure to thrive': 'DLT_Other',
         'GIT [elevated liver enzymes]': 'DLT_Gastrointestinal',
         'Gastrointestina': 'DLT_Gastrointestinal',
         'Gastrointestinal': 'DLT_Gastrointestinal',
         'General': 'DLT_Other',
         'Hematological': 'DLT_Hematological',
         'Hematological (Neutropenia)': 'DLT_Hematological',
         'Hyponatremia': 'DLT_Other',
         'Immunological': 'DLT_Other',
         'Infection': 'DLT_Infection (Pneumonia)',
         'NOS': 'DLT_Other',
         'Nephrological': 'DLT_Nephrological',
         'Nephrological (ARF)': 'DLT_Nephrological',
         'Neurological': 'DLT_Neurological',
         'Neutropenia': 'DLT_Hematological',
         'Nutritional': 'DLT_Other',
         'Pancreatitis': 'DLT_Other',
         'Pulmonary': 'DLT_Other',
         'Respiratory (Pneumonia)': 'DLT_Infection (Pneumonia)',
         'Sepsis': 'DLT_Infection (Pneumonia)',
         'Suboptimal response to treatment' : 'DLT_Other',
         'Vascular': 'DLT_Vascular'
    }
    
    decision1 = 'Decision 1 (Induction Chemo) Y/N'
    decision2 = 'Decision 2 (CC / RT alone)'
    decision3 = 'Decision 3 Neck Dissection (Y/N)'
    decisions = [decision1,decision2, decision3]
    outcomes = ['Overall Survival (4 Years)', 'FT', 'Aspiration rate Post-therapy']
    
    modification_types = {
        0: 'no_dose_adjustment',
        1: 'dose_modified',
        2: 'dose_delayed',
        3: 'dose_cancelled',
        4: 'dose_delayed_&_modified',
        5: 'regiment_modification',
        9: 'unknown'
    }
    
    cc_types = {
        0: 'cc_none',
        1: 'cc_platinum',
        2: 'cc_cetuximab',
        3: 'cc_others',
    }
    
    primary_disease_states = ['CR Primary','PR Primary','SD Primary']
    nodal_disease_states = [t.replace('Primary','Nodal') for t in primary_disease_states]
    dlt1 = list(set(dlt_dict.values()))
    
    modifications =  list(modification_types.values())
    state2 = modifications + primary_disease_states+nodal_disease_states +dlt1 #+['No imaging 0=N,1=Y']
    
    primary_disease_states2 = [t + ' 2' for t in primary_disease_states]
    nodal_disease_states2 = [t + ' 2' for t in nodal_disease_states]
    dlt2 = [d + ' 2' for d in dlt1]
    
    ccs = list(cc_types.values())
    state3 = ccs + primary_disease_states2 + nodal_disease_states2 + dlt2
    name_dict = {
        'pd_state1': primary_disease_states,
        'nd_state1': nodal_disease_states,
        'chemo_state1': modifications,
        'chemo_state2': ccs,
        'pd_state2': primary_disease_states2,
        'pd_state2': nodal_disease_states2,
    }
    
    stratified_train_ids = [
        5,6,8,11,13,14,15,16,17,18,21,23,24,26,27,28,32,33,37,38,39,40,
        41,42,48,49,50,51,53,55,56,57,60,64,65,67,69,71,74,75,78,79,80,
        81,82,87,88,91,94,96,99,103,109,116,119,120,121,125,148,150,
        153,178,181,183,185,186,188,191,192,193,196,197,198,200,201,
        203,204,205,206,207,210,212,213,214,216,218,219,220,221,222,
        223,225,226,229,230,231,232,233,234,235,237,238,239,240,241,
        243,244,246,247,248,249,251,252,253,255,256,257,258,259,260,
        261,262,263,265,266,269,270,273,275,276,277,278,280,281,282,
        283,285,289,2000,2002,2003,2004,2007,2008,2009,2010,2011,
        2012,2013,2014,2016,2018,2021,2022,2023,2025,2027,2028,2030,
        2033,5000,5002,5004,5005,5006,5008,5009,5010,5011,5012,5013,
        5014,5015,5016,5017,5018,5019,5021,5022,5023,5024,5025,5026,
        5027,5028,5029,5030,5031,5034,5037,5039,5041,5042,5043,5044,
        5045,5047,5050,5051,5055,5057,5058,5059,5060,5061,5062,5063,
        5064,5066,5067,5068,5069,5070,5071,5072,5073,5074,5075,5076,
        5079,5081,5083,5085,5087,5088,5089,5090,5091,5092,5094,5095,
        5096,5097,5100,5102,5104,5106,5108,5110,5111,5112,5113,5114,
        5119,10001,10002,10003,10004,10006,10008,10009,10011,10015,
        10018,10019,10020,10021,10022,10024,10025,10027,10028,10029,
        10031,10033,10034,10035,10036,10037,10038,10039,10041,10042,
        10043,10044,10045,10047,10048,10051,10052,10053,10054,10055,
        10056,10057,10059,10060,10061,10062,10064,10065,10067,10069,
        10070,10071,10072,10073,10074,10075,10077,10078,10079,10080,
        10081,10082,10083,10085,10087,10089,10090,10093,10095,10096,
        10098,10099,10103,10107,10108,10109,10110,10111,10113,10114,
        10115,10116,10117,10118,10119,10120,10121,10124,10127,10128,
        10129,10132,10134,10136,10138,10139,10140,10141,10142,10143,
        10144,10146,10147,10148,10149,10150,10151,10152,10154,10155,
        10156,10157,10158,10159,10162,10163,10164,10167,10168,10171,
        10173,10174,10175,10181,10182,10183,10184,10185,10186,10187,
        10188,10189,10191,10192,10193,10194,10195,10196,10197,10198,
        10199,10200,10201,10202,10203,10204,10205]
    
    stratified_test_ids = [
        133,47,35,10,279,5056,5035,224,209,10063,2006,5020,271,10014,
        5080,10097,10125,10106,2032,10169,2024,286,2015,2019,10026,
        5040,236,187,10161,211,5103,10178,2026,10137,184,199,10040,
        272,68,5105,10177,228,44,242,9,5101,10104,10165,10007,10133,
        10145,10016,264,5098,10023,10050,5120,227,5118,2005,5053,10135,
        5007,10092,36,2001,5115,10005,10102,189,5036,10088,254,10130,
        10086,25,5001,5065,10084,195,5099,3,5093,10094,7,5038,10068,5032,
        202,274,45,2017,10176,217,10160,5082,10012,10017,10100,2031,77,
        10066,5078,117,10010,10170,10190,10058,5049,5086,5052,268,2029,
        5084,10105,10013,245,5048,2020,215,10046,5117,5033,267,5003,168,
        31,10049,10180,190,287,284,5054,10101,208,5077,10091,10172,288,5109,
        10126,10153,10123,5107,194,10131]

Const.stratified_test_ids

[133,
 47,
 35,
 10,
 279,
 5056,
 5035,
 224,
 209,
 10063,
 2006,
 5020,
 271,
 10014,
 5080,
 10097,
 10125,
 10106,
 2032,
 10169,
 2024,
 286,
 2015,
 2019,
 10026,
 5040,
 236,
 187,
 10161,
 211,
 5103,
 10178,
 2026,
 10137,
 184,
 199,
 10040,
 272,
 68,
 5105,
 10177,
 228,
 44,
 242,
 9,
 5101,
 10104,
 10165,
 10007,
 10133,
 10145,
 10016,
 264,
 5098,
 10023,
 10050,
 5120,
 227,
 5118,
 2005,
 5053,
 10135,
 5007,
 10092,
 36,
 2001,
 5115,
 10005,
 10102,
 189,
 5036,
 10088,
 254,
 10130,
 10086,
 25,
 5001,
 5065,
 10084,
 195,
 5099,
 3,
 5093,
 10094,
 7,
 5038,
 10068,
 5032,
 202,
 274,
 45,
 2017,
 10176,
 217,
 10160,
 5082,
 10012,
 10017,
 10100,
 2031,
 77,
 10066,
 5078,
 117,
 10010,
 10170,
 10190,
 10058,
 5049,
 5086,
 5052,
 268,
 2029,
 5084,
 10105,
 10013,
 245,
 5048,
 2020,
 215,
 10046,
 5117,
 5033,
 267,
 5003,
 168,
 31,
 10049,
 10180,
 190,
 287,
 284,
 5054,
 10101,
 208,
 5077,
 10091,
 10172,
 288,
 5109,
 10126,
 10153,
 10123,
 5107,
 19

In [3]:
def preprocess(data_cleaned):
    #this was Elisa's preprocessing except I removed all the Ifs because that's dumb
    if len(data_cleaned.shape) < 2:
        data_cleaned = pd.DataFrame([data], columns=data.index)
        
    data_cleaned.loc[data_cleaned['Aspiration rate Pre-therapy'] == 'N', 'Aspiration rate Pre-therapy'] = 0
    data_cleaned.loc[data_cleaned['Aspiration rate Pre-therapy'] == 'Y', 'Aspiration rate Pre-therapy'] = 1

    data_cleaned.loc[data_cleaned['Pathological Grade'] == 'I', 'Pathological Grade'] = 1
    data_cleaned.loc[data_cleaned['Pathological Grade'] == 'II', 'Pathological Grade'] = 2
    data_cleaned.loc[data_cleaned['Pathological Grade'] == 'III', 'Pathological Grade'] = 3
    data_cleaned.loc[data_cleaned['Pathological Grade'] == 'IV', 'Pathological Grade'] = 4

    data_cleaned.loc[(data_cleaned['T-category'] == 'Tx') | (data_cleaned['T-category'] == 'Tis'), 'T-category'] = 1
    data_cleaned.loc[data_cleaned['T-category'] == 'T1', 'T-category'] = 1
    data_cleaned.loc[data_cleaned['T-category'] == 'T2', 'T-category'] = 2
    data_cleaned.loc[data_cleaned['T-category'] == 'T3', 'T-category'] = 3
    data_cleaned.loc[data_cleaned['T-category'] == 'T4', 'T-category'] = 4

    data_cleaned.loc[data_cleaned['N-category'] == 'N0', 'N-category'] = 0
    data_cleaned.loc[data_cleaned['N-category'] == 'N1', 'N-category'] = 1
    data_cleaned.loc[data_cleaned['N-category'] == 'N2', 'N-category'] = 2
    data_cleaned.loc[data_cleaned['N-category'] == 'N3', 'N-category'] = 3

    data_cleaned.loc[data_cleaned['N-category_8th_edition'] == 'N0', 'N-category_8th_edition'] = 0
    data_cleaned.loc[data_cleaned['N-category_8th_edition'] == 'N1', 'N-category_8th_edition'] = 1
    data_cleaned.loc[data_cleaned['N-category_8th_edition'] == 'N2', 'N-category_8th_edition'] = 2
    data_cleaned.loc[data_cleaned['N-category_8th_edition'] == 'N3', 'N-category_8th_edition'] = 3

    data_cleaned.loc[data_cleaned['AJCC 7th edition'] == 'I', 'AJCC 7th edition'] = 1
    data_cleaned.loc[data_cleaned['AJCC 7th edition'] == 'II', 'AJCC 7th edition'] = 2
    data_cleaned.loc[data_cleaned['AJCC 7th edition'] == 'III', 'AJCC 7th edition'] = 3
    data_cleaned.loc[data_cleaned['AJCC 7th edition'] == 'IV', 'AJCC 7th edition'] = 4

    data_cleaned.loc[data_cleaned['AJCC 8th edition'] == 'I', 'AJCC 8th edition'] = 1
    data_cleaned.loc[data_cleaned['AJCC 8th edition'] == 'II', 'AJCC 8th edition'] = 2
    data_cleaned.loc[data_cleaned['AJCC 8th edition'] == 'III', 'AJCC 8th edition'] = 3
    data_cleaned.loc[data_cleaned['AJCC 8th edition'] == 'IV', 'AJCC 8th edition'] = 4

    data_cleaned.loc[data_cleaned['Gender'] == 'Male', 'Gender'] = 1
    data_cleaned.loc[data_cleaned['Gender'] == 'Female', 'Gender'] = 0


    data_cleaned.loc[data_cleaned['HPV/P16 status'] == 'Positive', 'HPV/P16 status'] = 1
    data_cleaned.loc[data_cleaned['HPV/P16 status'] == 'Negative', 'HPV/P16 status'] = -1
    data_cleaned.loc[data_cleaned['HPV/P16 status'] == 'Unknown', 'HPV/P16 status'] = 0

    data_cleaned.loc[data_cleaned[
                         'Smoking status at Diagnosis (Never/Former/Current)'] == 'Formar', 'Smoking status at Diagnosis (Never/Former/Current)'] = .5
    data_cleaned.loc[data_cleaned[
                         'Smoking status at Diagnosis (Never/Former/Current)'] == 'Current', 'Smoking status at Diagnosis (Never/Former/Current)'] = 1
    data_cleaned.loc[data_cleaned[
                         'Smoking status at Diagnosis (Never/Former/Current)'] == 'Never', 'Smoking status at Diagnosis (Never/Former/Current)'] = 0


    data_cleaned.loc[data_cleaned['Chemo Modification (Y/N)'] == 'Y', 'Chemo Modification (Y/N)'] = 1

    data_cleaned.loc[data_cleaned['DLT (Y/N)'] == 'N', 'DLT (Y/N)'] = 0
    data_cleaned.loc[data_cleaned['DLT (Y/N)'] == 'Y', 'DLT (Y/N)'] = 1

    data_cleaned['DLT_Other'] = 0
    for index, row in data_cleaned.iterrows():
        if row['DLT_Type'] == 'None':
            continue
        for i in re.split('&|and|,', row['DLT_Type']):
            if i.strip() != '' and data_cleaned.loc[index, Const.dlt_dict[i.strip()]] == 0:
                data_cleaned.loc[index, Const.dlt_dict[i.strip()]] = 1

    data_cleaned.loc[data_cleaned['Decision 2 (CC / RT alone)'] == 'RT alone', 'Decision 2 (CC / RT alone)'] = 0
    data_cleaned.loc[data_cleaned['Decision 2 (CC / RT alone)'] == 'CC', 'Decision 2 (CC / RT alone)'] = 1

    data_cleaned.loc[data_cleaned['CC modification (Y/N)'] == 'N', 'CC modification (Y/N)'] = 0
    data_cleaned.loc[data_cleaned['CC modification (Y/N)'] == 'Y', 'CC modification (Y/N)'] = 1

    data_cleaned['DLT_Dermatological 2'] = 0
    data_cleaned['DLT_Neurological 2'] = 0
    data_cleaned['DLT_Gastrointestinal 2'] = 0
    data_cleaned['DLT_Hematological 2'] = 0
    data_cleaned['DLT_Nephrological 2'] = 0
    data_cleaned['DLT_Vascular 2'] = 0
    data_cleaned['DLT_Infection (Pneumonia) 2'] = 0
    data_cleaned['DLT_Other 2'] = 0
    for index, row in data_cleaned.iterrows():
        if row['DLT 2'] == 'None':
            continue
        for i in re.split('&|and|,', row['DLT 2']):
            if i.strip() != '':
                data_cleaned.loc[index, Const.dlt_dict[i.strip()] + ' 2'] = 1

    data_cleaned.loc[
        data_cleaned['Decision 3 Neck Dissection (Y/N)'] == 'N', 'Decision 3 Neck Dissection (Y/N)'] = 0
    data_cleaned.loc[
        data_cleaned['Decision 3 Neck Dissection (Y/N)'] == 'Y', 'Decision 3 Neck Dissection (Y/N)'] = 1

    return data_cleaned

def merge_editions(row,basecol='AJCC 8th edition',fallback='AJCC 7th edition'):
    if pd.isnull(row[basecol]):
        return row[fallback]
    return row[basecol]


def preprocess_dt_data(df,extra_to_keep=None):
    
    to_keep = ['id','hpv','age','packs_per_year','smoking_status','gender','Aspiration rate Pre-therapy','total_dose','dose_fraction'] 
    to_onehot = ['T-category','N-category','AJCC','Pathological Grade','subsite','treatment','ln_cluster']
    if extra_to_keep is not None:
        to_keep = to_keep + [c for c in extra_to_keep if c not in to_keep and c not in to_onehot]
    
    df['bilateral'] = df['laterality'].apply(lambda x: x.lower().strip() == 'bilateral')
    to_keep.append('bilateral')
    
    decisions =Const.decisions
    outcomes = Const.outcomes
    
    modification_types = {
        0: 'no_dose_adjustment',
        1: 'dose_modified',
        2: 'dose_delayed',
        3: 'dose_cancelled',
        4: 'dose_delayed_&_modified',
        5: 'regiment_modification',
        9: 'unknown'
    }
    
    cc_types = {
        0: 'cc_none',
        1: 'cc_platinum',
        2: 'cc_cetuximab',
        3: 'cc_others',
    }
    
#     races_shortened = ['White/Caucasian','Hispanic/Latino','African American/Black']
#     for race in races_shortened:
#         df[race] = df['Race'].apply(lambda x: x.strip() == race)
#         to_keep.append(race)
    
    for k,v in Const.cc_types.items():
        df[v] = df['CC Regimen(0= none, 1= platinum based, 2= cetuximab based, 3= others, 9=unknown)'].apply(lambda x: int(Const.cc_types.get(int(x),0) == v))
        to_keep.append(v)
    for k,v in Const.modification_types.items():
        name = 'Modification Type (0= no dose adjustment, 1=dose modified, 2=dose delayed, 3=dose cancelled, 4=dose delayed & modified, 5=regimen modification, 9=unknown)'
        df[v] = df[name].apply(lambda x: int(Const.modification_types.get(int(x),0) == v))
        to_keep.append(v)
    #Features to keep. I think gender is is in 
    
    for col in Const.dlt1 + Const.dlt2:
        df[col] = (df[col] > 0).astype(float)
    keywords = []
    for keyword in keywords:
        toadd = [c for c in df.columns if keyword in c and c not in to_keep]
        to_keep = to_keep + toadd
    
    df['packs_per_year'] = df['packs_per_year'].apply(lambda x: str(x).replace('>','').replace('<','')).astype(float).fillna(0)
    #so I'm actually not sure if this is biological sex or gender given this is texas
    df['AJCC'] = df.apply(lambda row: merge_editions(row,'ajcc8','ajcc7'),axis=1)
    df['N-category'] = df.apply(lambda row: merge_editions(row,'N-category_8th_edition','N-category'),axis=1)
    
    dummy_df = pd.get_dummies(df[to_onehot].fillna(0).astype(str),drop_first=False)
    for col in dummy_df.columns:
        df[col] = dummy_df[col]
        to_keep.append(col)
        
    yn_to_binary = ['FT','Aspiration rate Post-therapy','Decision 1 (Induction Chemo) Y/N']
    for col in yn_to_binary:
        df[col] = df[col].apply(lambda x: int(x == 'Y'))
        
    to_keep = to_keep + [c for c in df.columns if 'DLT' in c]
    
    for statelist in [Const.state2,Const.state3,Const.decisions,Const.outcomes]:
        toadd = [c for c in statelist if c not in to_keep]
        to_keep = to_keep + toadd
    return df[to_keep].set_index('id')



In [78]:
def load_digital_twin(file='../data/digital_twin_data.csv'):
    df = pd.read_csv(file)
    return df.rename(columns = Const.rename_dict)

def get_dt_ids():
    df = load_digital_twin()
    return df.id.values

def get_tt_split(ids=None,use_default_split=True,use_bagging_split=False,resample_training=False):
        if ids is None:
            ids = get_dt_ids()
        #pre-made, stratified by decision and outcome 72:28
        if use_default_split:
            train_ids = Const.stratified_train_ids[:]
            test_ids = Const.stratified_test_ids[:]
        elif use_bagging_split:
            train_ids = np.random.choice(ids,len(ids),replace=True)
            test_ids = [i for i in ids if i not in train_ids]
        else:
            test_ids = ids[0: int(len(ids)*(1-split))]
            train_ids = [i for i in ids if i not in test_ids]

        if resample_training:
            train_ids = np.random.choice(train_ids,len(train_ids),replace=True)
            test_ids = [i for i in ids if i not in train_ids]
        return train_ids,test_ids
    
get_tt_split()

  


([5,
  6,
  8,
  11,
  13,
  14,
  15,
  16,
  17,
  18,
  21,
  23,
  24,
  26,
  27,
  28,
  32,
  33,
  37,
  38,
  39,
  40,
  41,
  42,
  48,
  49,
  50,
  51,
  53,
  55,
  56,
  57,
  60,
  64,
  65,
  67,
  69,
  71,
  74,
  75,
  78,
  79,
  80,
  81,
  82,
  87,
  88,
  91,
  94,
  96,
  99,
  103,
  109,
  116,
  119,
  120,
  121,
  125,
  148,
  150,
  153,
  178,
  181,
  183,
  185,
  186,
  188,
  191,
  192,
  193,
  196,
  197,
  198,
  200,
  201,
  203,
  204,
  205,
  206,
  207,
  210,
  212,
  213,
  214,
  216,
  218,
  219,
  220,
  221,
  222,
  223,
  225,
  226,
  229,
  230,
  231,
  232,
  233,
  234,
  235,
  237,
  238,
  239,
  240,
  241,
  243,
  244,
  246,
  247,
  248,
  249,
  251,
  252,
  253,
  255,
  256,
  257,
  258,
  259,
  260,
  261,
  262,
  263,
  265,
  266,
  269,
  270,
  273,
  275,
  276,
  277,
  278,
  280,
  281,
  282,
  283,
  285,
  289,
  2000,
  2002,
  2003,
  2004,
  2007,
  2008,
  2009,
  2010,
  2011,
  2012,
  2013,


In [63]:
class DTDataset():
    
    def __init__(self,data_file = '../data/digital_twin_data.csv',ln_data_file = '../data/digital_twin_ln_data.csv',ids=None):
        df = pd.read_csv(data_file)
        df = preprocess(df)
        df = df.rename(columns = Const.rename_dict).copy()
        df = df.drop('MRN OPC',axis=1)

        ln_data = pd.read_csv(ln_data_file)
        ln_data = ln_data.rename(columns={'cluster':'ln_cluster'})
        self.ln_cols = [c for c in ln_data.columns if c not in df.columns]
        df = df.merge(ln_data,on='id')
        df.index = df.index.astype(int)
        if ids is not None:
            df = df[df.id.apply(lambda x: x in ids)]
        self.processed_df = preprocess_dt_data(df,self.ln_cols).fillna(0)
        
        self.means = self.processed_df.mean(axis=0)
        self.stds = self.processed_df.std(axis=0)
        self.maxes = self.processed_df.max(axis=0)
        self.mins = self.processed_df.min(axis=0)
        
        arrays = self.get_states()
        self.state_sizes = {k: (v.shape[1] if v.ndim > 1 else 1) for k,v in arrays.items()}
        
    def get_data(self):
        return self.processed_df
    
    def sample(self,frac=.5):
        return self.processed_df.sample(frac=frac)
    
    def split_sample(self,ratio = .3):
        assert(ratio > 0 and ratio <= 1)
        df1 = self.processed_df.sample(frac=1-ratio)
        df2 = self.processed_df.drop(index=df1.index)
        return df1,df2
    
    def get_states(self,fixed=None,ids = None):
        processed_df = self.processed_df.copy()
        if ids is not None:
            processed_df = processed_df.loc[ids]
        if fixed is not None:
            for col,val in fixed.items():
                if col in processed_df.columns:
                    processed_df[col] = val
                else:
                    print('bad fixed entry',col)
                    
        to_skip = ['CC Regimen(0= none, 1= platinum based, 2= cetuximab based, 3= others, 9=unknown)','DLT_Type','DLT 2'] + [c for c in processed_df.columns if 'treatment' in c]
        other_states = set(Const.decisions + Const.state3 + Const.state2 + Const.outcomes  + to_skip)

        base_state = sorted([c for c in processed_df.columns if c not in other_states])

        dlt1 = Const.dlt1
        dlt2 = Const.dlt2
        
        modifications = Const.modifications
        ccs = Const.ccs
        pds = Const.primary_disease_states
        nds = Const.nodal_disease_states
        pds2 = Const.primary_disease_states2
        nds2 = Const.nodal_disease_states2
        outcomes = Const.outcomes
        decisions= Const.decisions
        
        #intermediate states are only udated values. Models should use baseline + state2 etc
        results = {
            'baseline': processed_df[base_state],
            'pd_states1': processed_df[pds],
            'nd_states1': processed_df[nds],
            'modifications': processed_df[modifications],
            'ccs': processed_df[ccs],
            'pd_states2': processed_df[pds2],
            'nd_states2': processed_df[nds2],
            'outcomes': processed_df[outcomes],
            'dlt1': processed_df[dlt1],
            'dlt2': processed_df[dlt2],
            'decision1': processed_df[decisions[0]],
            'decision2': processed_df[decisions[1]],
            'decision3': processed_df[decisions[2]],
            'survival': processed_df[outcomes[0]],
            'ft': processed_df[outcomes[1]],
            'aspiration': processed_df[outcomes[2]],
        }
    
        return results
    
    def get_state(self,name,**kwargs):
        return self.get_states(**kwargs)[name]
    
    def normalize(self,df):
        means = self.means[df.columns]
        std = self.stds[df.columns]
        return ((df - means)/std).fillna(0)
    
    def get_intermediate_outcomes(self,step=1,**kwargs):
        assert(step in [0,1,2,3])
        states = self.get_states(**kwargs)
        if step == 1:
            keys = ['pd_states1','nd_states1','modifications','dlt1']
        if step == 2:
            keys =  ['pd_states2','nd_states2','ccs','dlt2']
        if step == 3:
            keys = ['survival','ft','aspiration']# ['decision1','decision2','decision3']
        if step == 0:
            keys = ['decision1','decision2','decision3']
        if len(keys) < 2:
            return states[keys[0]]
        return [states[key] for key in keys]
    
    def get_input_state(self,step=1,**kwargs):
        assert(step in [0,1,2,3])
        states = self.get_states(**kwargs)
        if step == 0:
            keys = ['baseline']
        if step == 1:
            keys = ['baseline','decision1']
        if step == 2:
            keys =  ['baseline','pd_states1','nd_states1','modifications','dlt1','decision1','decision2']
        if step == 3:
            keys = ['baseline','pd_states2','nd_states2','ccs','dlt2','decision1','decision2','decision3']
        arrays = [states[key] for key in keys]
        if len(arrays) < 2:
            return arrays[0]
        return pd.concat(arrays,axis=1)
    
# data = DTDataset()
test_ids = []
outcomes = data.processed_df[Const.outcomes+Const.decisions]
ids = outcomes.index.values
for outcome in Const.outcomes+Const.decisions:
    for res in [0,1]:
        sample = outcomes[outcomes[outcome].astype(int) == res].sample(frac = .05).reset_index()
        toadd = [i for i in sample.id.values if i not in test_ids]
        test_ids.extend(toadd)
train_ids = [i for i in ids if i not in test_ids]
print(len(test_ids),len(train_ids)/len(ids))

146 0.7276119402985075


In [64]:
outcomes.loc[test_ids].mean()

Overall Survival (4 Years)          0.890411
FT                                  0.219178
Aspiration rate Post-therapy        0.178082
Decision 1 (Induction Chemo) Y/N    0.376712
Decision 2 (CC / RT alone)          0.808219
Decision 3 Neck Dissection (Y/N)    0.253425
dtype: float64

In [66]:
outcomes.loc[train_ids].mean()

Overall Survival (4 Years)          0.838462
FT                                  0.169231
Aspiration rate Post-therapy        0.184615
Decision 1 (Induction Chemo) Y/N    0.356410
Decision 2 (CC / RT alone)          0.748718
Decision 3 Neck Dissection (Y/N)    0.189744
dtype: float64

In [70]:
','.join([str(i) for i in test_ids])

'133,47,35,10,279,5056,5035,224,209,10063,2006,5020,271,10014,5080,10097,10125,10106,2032,10169,2024,286,2015,2019,10026,5040,236,187,10161,211,5103,10178,2026,10137,184,199,10040,272,68,5105,10177,228,44,242,9,5101,10104,10165,10007,10133,10145,10016,264,5098,10023,10050,5120,227,5118,2005,5053,10135,5007,10092,36,2001,5115,10005,10102,189,5036,10088,254,10130,10086,25,5001,5065,10084,195,5099,3,5093,10094,7,5038,10068,5032,202,274,45,2017,10176,217,10160,5082,10012,10017,10100,2031,77,10066,5078,117,10010,10170,10190,10058,5049,5086,5052,268,2029,5084,10105,10013,245,5048,2020,215,10046,5117,5033,267,5003,168,31,10049,10180,190,287,284,5054,10101,208,5077,10091,10172,288,5109,10126,10153,10123,5107,194,10131'

In [8]:
def df_to_torch(df,ttype  = torch.FloatTensor):
    values = df.values.astype(float)
    values = torch.from_numpy(values)
    return values.type(ttype)

In [9]:
class SimulatorBase(torch.nn.Module):
    
    def __init__(self,
                 input_size,
                 hidden_layers = [1000],
                 dropout = 0.5,
                 input_dropout=0.1,
                 state = 1,
                 eps = 0.01,
                ):
        #predicts disease state (sd, pr, cr) for primar and nodal, then dose modications or cc type (depending on state), and [dlt ratings]
        torch.nn.Module.__init__(self)
        self.state = state
        self.input_dropout = torch.nn.Dropout(input_dropout)
        
        first_layer =torch.nn.Linear(input_size,hidden_layers[0],bias=True)
        layers = [first_layer,torch.nn.ReLU()]
        curr_size = hidden_layers[0]
        for ndim in hidden_layers[1:]:
            layer = torch.nn.Linear(curr_size,ndim)
            curr_size = ndim
            layers.append(layer)
            layers.append(torch.nn.ReLU())
        self.layers = torch.nn.ModuleList(layers)
        self.batchnorm = torch.nn.BatchNorm1d(hidden_layers[-1])
        self.dropout = torch.nn.Dropout(dropout)
        
    
        input_mean = torch.tensor([0])
        input_std = torch.tensor([1])
        self.eps = eps
        self.register_buffer('input_mean', input_mean)
        self.register_buffer('input_std',input_std)
        
        self.sigmoid = torch.nn.Sigmoid()
        self.softmax = torch.nn.LogSoftmax(dim=1)
        self.identifier = 'state'  +str(state) + '_input'+str(input_size) + '_dims' + ','.join([str(h) for h in hidden_layers]) + '_dropout' + str(input_dropout) + ',' + str(dropout)
        
    def normalize(self,x):
        x = (x - self.input_mean + self.eps)/(self.input_std + self.eps)
        return x
    
    def fit_normalizer(self,x):
        input_mean = x.mean(axis=0)
        input_std = x.std(axis=0)
        self.register_buffer('input_mean', input_mean)
        self.register_buffer('input_std',input_std)
        return True
        
class OutcomeSimulator(SimulatorBase):
    
    def __init__(self,
                 input_size,
                 hidden_layers = [500,500],
                 dropout = 0.7,
                 input_dropout=0.1,
                 state = 1,
                ):
        #predicts disease state (sd, pr, cr) for primar and nodal, then dose modications or cc type (depending on state), and [dlt ratings]
        super(OutcomeSimulator,self).__init__(input_size,hidden_layers=hidden_layers,dropout=dropout,input_dropout=input_dropout,state=state)
    
        self.disease_layer = torch.nn.Linear(hidden_layers[-1],len(Const.primary_disease_states))
        self.nodal_disease_layer = torch.nn.Linear(hidden_layers[-1],len(Const.nodal_disease_states))
        #dlt ratings are 0-4 even though they don't always appear
        
        self.dlt_layers = torch.nn.ModuleList([torch.nn.Linear(hidden_layers[-1],1) for i in Const.dlt1])
        assert( state in [1,2])
        if state == 1:
#             self.dlt_layers = torch.nn.ModuleList([torch.nn.Linear(hidden_layers[-1],5) for i in Const.dlt1])
            self.treatment_layer = torch.nn.Linear(hidden_layers[-1],len(Const.modifications))
        else:
            #we only have dlt yes or no for the second state?
#             self.dlt_layers = torch.nn.ModuleList([torch.nn.Linear(hidden_layers[-1],2) for i in Const.dlt2])
            self.treatment_layer = torch.nn.Linear(hidden_layers[-1],len(Const.ccs))

    def forward(self,x):
        x = self.normalize(x)
        x = self.input_dropout(x)
        for layer in self.layers:
            x = layer(x)
#         x = self.batchnorm(x)
        x = self.dropout(x)
        x_pd = self.disease_layer(x)
        x_nd = self.nodal_disease_layer(x)
        x_mod = self.treatment_layer(x)
        x_dlts = [layer(x) for layer in self.dlt_layers]
        
        x_pd = self.softmax(x_pd)
        x_nd = self.softmax(x_nd)
        x_mod = self.softmax(x_mod)
        #dlts are array of nbatch x n_dlts x predictions
#         x_dlts = torch.stack([self.sigmoid(xx) for xx in x_dlts],axis=1)
        x_dlts = torch.cat([self.sigmoid(xx) for xx in x_dlts],axis=1)
        return [x_pd, x_nd, x_mod, x_dlts]
    
class EndpointSimulator(SimulatorBase):
    
    def __init__(self,
                 input_size,
                 hidden_layers = [500],
                 dropout = 0.7,
                 input_dropout=0.1,
                 state = 1,
                ):
        #predicts disease state (sd, pr, cr) for primar and nodal, then dose modications or cc type (depending on state), and [dlt ratings]
        super(EndpointSimulator,self).__init__(input_size,hidden_layers=hidden_layers,dropout=dropout,input_dropout=input_dropout,state=state)
        
        self.outcome_layer = torch.nn.Linear(hidden_layers[-1],len(Const.outcomes))
      
        
    def forward(self,x):
        x = self.normalize(x)
        x = self.input_dropout(x)
        for layer in self.layers:
            x = layer(x)
#         x = self.batchnorm(x)
        x = self.dropout(x)
        x= self.outcome_layer(x)
        x = self.sigmoid(x)
        return x

OutcomeSimulator(3).input_mean

tensor([0])

In [10]:
def nllloss(ytrue,ypred):
    #nll loss with argmax added in
    loss = torch.nn.NLLLoss()
    return loss(ypred,ytrue.argmax(axis=1))

def state_loss(ytrue,ypred,weights=[1,1,1,1]):
    pd_loss = nllloss(ytrue[0],ypred[0])*weights[0]
    nd_loss = nllloss(ytrue[1],ypred[1])*weights[1]
    mod_loss = nllloss(ytrue[2],ypred[2])*weights[2]
    loss = pd_loss + nd_loss + mod_loss
    dlt_true = ytrue[3]
    dlt_pred = ypred[3]
    ndlt = dlt_true.shape[1]
#     nloss = torch.nn.NLLLoss()
    bce = torch.nn.BCELoss()
    for i in range(ndlt):
        dlt_loss = bce(dlt_pred[:,i].view(-1),dlt_true[:,i].view(-1))
        loss += dlt_loss*weights[3]/ndlt
    return loss

def outcome_loss(ytrue,ypred,weights=[1,1,1]):
    loss = 0
    nloss = torch.nn.BCELoss()
    for i in range(len(weights)):
        iloss = nloss(ypred[:,i],ytrue[i])*weights[i]
        loss += iloss
    return loss

from sklearn.metrics import balanced_accuracy_score, roc_auc_score,accuracy_score

def mc_metrics(yt,yp,numpy=False,is_dlt=False):
    if not numpy:
        yt = yt .cpu().detach().numpy()
        yp = yp.cpu().detach().numpy()
    #dlt prediction (binary)
    if is_dlt:
        acc = accuracy_score(yt,yp>.5)
        if yt.sum() > 1:
            auc = roc_auc_score(yt,yp)
        else:
            auc=-1
        error = np.mean((yt-yp)**2)
        return {'accuracy': acc, 'mse': error, 'auc': auc}
    #this is a catch for when I se the dlt prediction format (encoded integer ordinal, predict as a categorical and take the argmax)
    elif yt.ndim > 1:
        try:
            bacc = balanced_accuracy_score(yt.argmax(axis=1),yp.argmax(axis=1))
        except:
            bacc = -1
        try:
            roc_micro = roc_auc_score(yt,yp,average='micro')
        except:
            roc_micro=-1
        try:
            roc_macro = roc_auc_score(yt,yp,average='macro')
        except:
            roc_macro = -1
        return {'accuracy': bacc, 'roc_micro': roc_micro,'roc_macro': roc_macro}
    #outcomes (binary)
    else:
        if yp.ndim > 1:
            yp = yp.argmax(axis=1)
        try:
            bacc = accuracy_score(yt,yp)
        except:
            bacc = -1
        try:
            roc = roc_auc_score(yt,yp)
        except:
            roc = -1
        error = np.mean((yt-yp)**2)
        return {'accuracy': bacc, 'mse': error, 'auc': roc}

def state_metrics(ytrue,ypred,numpy=False):
    pd_metrics = mc_metrics(ytrue[0],ypred[0],numpy=numpy)
    nd_metrics = mc_metrics(ytrue[1],ypred[1],numpy=numpy)
    mod_metrics = mc_metrics(ytrue[1],ypred[1],numpy=numpy)
    
    dlt_metrics = []
    dlt_true = ytrue[3]
    dlt_pred = ypred[3]
    ndlt = dlt_true.shape[1]
    nloss = torch.nn.NLLLoss()
    for i in range(ndlt):
        dm = mc_metrics(dlt_true[:,i],dlt_pred[:,i].view(-1),is_dlt=True)
        dlt_metrics.append(dm)
    dlt_acc =[d['accuracy'] for d in dlt_metrics]
    dlt_error = [d['mse'] for d in dlt_metrics]
    dlt_auc = [d['auc'] for d in dlt_metrics]
    return {'pd': pd_metrics,'nd': nd_metrics,'mod': mod_metrics,'dlts': {'accuracy': dlt_acc,'accuracy_mean': np.mean(dlt_acc),'auc': dlt_auc,'auc_mean': np.mean(dlt_auc)}}
    
def outcome_metrics(ytrue,ypred,numpy=False):
    res = {}
    for i, outcome in enumerate(Const.outcomes):
        metrics = mc_metrics(ytrue[i],ypred[:,i])
        res[outcome] = metrics
    return res


In [11]:
from sklearn.ensemble import RandomForestClassifier

def train_state_rf(model_args={}):
    ids = get_dt_ids()
    
    dataset = DTDataset()

    train_ids = ids[0:int(len(ids)*.7)]
    test_ids = ids[int(len(ids)*.7):]
    
    #most things are multiclass, dlts are several ordinal and outcomes are multiple binary
    xtrain1 = dataset.get_state('baseline',ids=train_ids)
    xtest1 = dataset.get_state('baseline',ids=test_ids)
    
    xtrain2 = dataset.get_input_state(step=2,ids=train_ids)
    xtest2 = dataset.get_input_state(step=2,ids=test_ids)
    
    xtrain3 = dataset.get_input_state(step=3,ids=train_ids)
    xtest3 = dataset.get_input_state(step=3,ids=test_ids)
    
    [pd1_train,nd1_train, mod_train,dlts1_train] = dataset.get_intermediate_outcomes(ids=train_ids)
    [pd2_train,nd2_train, cc_train,dlts2_train] = dataset.get_intermediate_outcomes(step=2,ids=train_ids)
    [pd1_test,nd1_test, mod_test,dlts1_test] = dataset.get_intermediate_outcomes(ids=test_ids)
    [pd2_test,nd2_test, cc_test,dlts2_test] = dataset.get_intermediate_outcomes(step=2,ids=test_ids)
    outcomes_train = dataset.get_state('outcomes',ids=train_ids)
    outcomes_test = dataset.get_state('outcomes',ids=test_ids)
    

    def train_multiclass_rf(xtrain,xtest,ytrain,ytest):
        model = RandomForestClassifier(class_weight='balanced',**model_args).fit(xtrain,ytrain)
        ypred = model.predict(xtest)
        metrics = mc_metrics(ytest.values,ypred,numpy=True)
        return model, metrics
    
    all_metrics = {}
    pd1_model, all_metrics['pd1'] = train_multiclass_rf(xtrain1,xtest1,pd1_train,pd1_test)
    nd1_model, all_metrics['nd1']  = train_multiclass_rf(xtrain1,xtest1,nd1_train,nd1_test)
    mod_model, all_metrics['mod']  = train_multiclass_rf(xtrain1,xtest1,mod_train,mod_test)
    
    return all_metrics

# train_state_rf({'max_depth': 5,'n_estimators': 100})

In [79]:
def train_state(model_args={},
                state=1,
                split=.7,
                lr=.0001,
                epochs=1000,
                patience=10,
                weights=[1,1,1,10],
                save_path='../data/models/',
                use_default_split=True,
                use_bagging_split=False,
                resample_training=False,#use bootstraping on training data after splitting
                n_validation_trainsteps=2,
                file_suffix=''):
    
    ids = get_dt_ids()
    
    train_ids, test_ids = get_tt_split(use_default_split=use_default_split,use_bagging_split=use_bagging_split,resample_training=resample_training)
    
    dataset = DTDataset()
    
    xtrain = dataset.get_input_state(step=state,ids=train_ids)
    xtest = dataset.get_input_state(step=state,ids=test_ids)
    ytrain = dataset.get_intermediate_outcomes(step=state,ids=train_ids)
    ytest = dataset.get_intermediate_outcomes(step=state,ids=test_ids)
    

    if state < 3:
        model = OutcomeSimulator(xtrain.shape[1],state=state,**model_args)
        lfunc = state_loss
    else:
        model = EndpointSimulator(xtrain.shape[1],**model_args)
        weights = weights[:3]
        lfunc = outcome_loss
        
    hashcode = str(hash(','.join([str(i) for i in train_ids])))
    save_file = save_path + 'model_' + model.identifier + '_split' + str(split) + '_resample' + str(resample_training) +  '_hash' + hashcode + file_suffix + '.tar'
    xtrain = df_to_torch(xtrain)
    xtest = df_to_torch(xtest)
    ytrain = [df_to_torch(t) for t in ytrain]
    ytest= [df_to_torch(t) for t in ytest]
    
    model.fit_normalizer(xtrain)
#     normalize = lambda x: (x - xtrain.mean(axis=0)+.01)/(xtrain.std(axis=0)+.01)
#     unnormalize = lambda x: (x * (xtrain.std(axis=0) +.01)) + xtrain.mean(axis=0) - .01
    
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)
    best_val_loss = 1000000000000000000000000000
    best_loss_metrics = {}
    last_epoch = False
    for epoch in range(epochs):
        
        model.train(True)
        optimizer.zero_grad()
        
        xtrain_sample = xtrain#[torch.randint(len(xtrain),(len(xtrain),) )]
        ypred = model(xtrain_sample)
        loss = lfunc(ytrain,ypred,weights=weights)

        loss.backward()
        optimizer.step()
        print('epoch',epoch,'train loss',loss.item())
        
        model.eval()
        yval = model(xtest)
        val_loss = lfunc(ytest,yval,weights=weights)
        if state < 3:
            val_metrics = state_metrics(ytest,yval)
        else:
            val_metrics = outcome_metrics(ytest,yval)
        if val_loss.item() < best_val_loss:
            best_val_loss = val_loss.item()
            best_loss_metrics = val_metrics
            steps_since_improvement = 0
            torch.save(model.state_dict(),save_file)
        else:
            steps_since_improvement += 1
        print('val loss',val_loss.item())
        print('______________')
        if steps_since_improvement > patience:
            break
    print('best loss',best_val_loss,best_loss_metrics)
    model.load_state_dict(torch.load(save_file))
    
    #train one step on validation data
    for i in range(n_validation_trainsteps):
        model.train()
        yval = model(xtest)
        val_loss = lfunc(ytest,yval,weights=weights)
        val_loss.backward()
        optimizer.step()
        torch.save(model.state_dict(),save_file)
    
    model.eval()
    return model

model = train_state(state=1)
model

  


epoch 0 train loss 11.321202278137207
val loss 11.089699745178223
______________
epoch 1 train loss 11.161775588989258
val loss 10.948609352111816
______________
epoch 2 train loss 10.995708465576172
val loss 10.809993743896484
______________
epoch 3 train loss 10.84663200378418
val loss 10.673552513122559
______________
epoch 4 train loss 10.72703742980957
val loss 10.53877067565918
______________
epoch 5 train loss 10.583897590637207
val loss 10.405231475830078
______________
epoch 6 train loss 10.429883003234863
val loss 10.272747993469238
______________
epoch 7 train loss 10.273176193237305
val loss 10.140966415405273
______________
epoch 8 train loss 10.197378158569336
val loss 10.009674072265625
______________
epoch 9 train loss 10.046801567077637
val loss 9.878469467163086
______________
epoch 10 train loss 9.916838645935059
val loss 9.747241020202637
______________
epoch 11 train loss 9.772366523742676
val loss 9.615583419799805
______________
epoch 12 train loss 9.635926246643

epoch 103 train loss 2.832357168197632
val loss 2.9768619537353516
______________
epoch 104 train loss 2.7890467643737793
val loss 2.964097738265991
______________
epoch 105 train loss 2.791548490524292
val loss 2.9515721797943115
______________
epoch 106 train loss 2.8215065002441406
val loss 2.9391398429870605
______________
epoch 107 train loss 2.779176712036133
val loss 2.926924705505371
______________
epoch 108 train loss 2.730639934539795
val loss 2.914900302886963
______________
epoch 109 train loss 2.76462459564209
val loss 2.902996301651001
______________
epoch 110 train loss 2.742753744125366
val loss 2.8912649154663086
______________
epoch 111 train loss 2.7175374031066895
val loss 2.879735231399536
______________
epoch 112 train loss 2.693880558013916
val loss 2.8682944774627686
______________
epoch 113 train loss 2.7065536975860596
val loss 2.8569464683532715
______________
epoch 114 train loss 2.577432155609131
val loss 2.845661163330078
______________
epoch 115 train los

epoch 208 train loss 1.949485182762146
val loss 2.200037956237793
______________
epoch 209 train loss 1.838420033454895
val loss 2.196558952331543
______________
epoch 210 train loss 1.88373601436615
val loss 2.193074941635132
______________
epoch 211 train loss 1.8465343713760376
val loss 2.1896445751190186
______________
epoch 212 train loss 1.8421432971954346
val loss 2.1863186359405518
______________
epoch 213 train loss 1.8484703302383423
val loss 2.183077812194824
______________
epoch 214 train loss 1.878557801246643
val loss 2.179844379425049
______________
epoch 215 train loss 1.846774697303772
val loss 2.176680088043213
______________
epoch 216 train loss 1.8362205028533936
val loss 2.173482894897461
______________
epoch 217 train loss 1.7877600193023682
val loss 2.1703855991363525
______________
epoch 218 train loss 1.7917028665542603
val loss 2.1672887802124023
______________
epoch 219 train loss 1.8314651250839233
val loss 2.1642162799835205
______________
epoch 220 train l

epoch 313 train loss 1.5138837099075317
val loss 1.9934993982315063
______________
epoch 314 train loss 1.5356178283691406
val loss 1.9923373460769653
______________
epoch 315 train loss 1.4871935844421387
val loss 1.991365909576416
______________
epoch 316 train loss 1.532147765159607
val loss 1.9905414581298828
______________
epoch 317 train loss 1.5246645212173462
val loss 1.9898080825805664
______________
epoch 318 train loss 1.5088404417037964
val loss 1.9891016483306885
______________
epoch 319 train loss 1.5125858783721924
val loss 1.9883733987808228
______________
epoch 320 train loss 1.5278825759887695
val loss 1.9875526428222656
______________
epoch 321 train loss 1.545936942100525
val loss 1.9867939949035645
______________
epoch 322 train loss 1.5151132345199585
val loss 1.9861658811569214
______________
epoch 323 train loss 1.5465316772460938
val loss 1.9856656789779663
______________
epoch 324 train loss 1.4800807237625122
val loss 1.985215663909912
______________
epoch 32

epoch 418 train loss 1.3642195463180542
val loss 1.9542382955551147
______________
epoch 419 train loss 1.3180352449417114
val loss 1.9541406631469727
______________
epoch 420 train loss 1.3588937520980835
val loss 1.9539542198181152
______________
epoch 421 train loss 1.3461133241653442
val loss 1.9537417888641357
______________
epoch 422 train loss 1.3705823421478271
val loss 1.9533976316452026
______________
epoch 423 train loss 1.2959614992141724
val loss 1.9532790184020996
______________
epoch 424 train loss 1.3617805242538452
val loss 1.9531097412109375
______________
epoch 425 train loss 1.3407927751541138
val loss 1.952874779701233
______________
epoch 426 train loss 1.2977375984191895
val loss 1.9526695013046265
______________
epoch 427 train loss 1.349993348121643
val loss 1.952338695526123
______________
epoch 428 train loss 1.3461475372314453
val loss 1.9522383213043213
______________
epoch 429 train loss 1.3086047172546387
val loss 1.9522358179092407
______________
epoch 4

OutcomeSimulator(
  (input_dropout): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0): Linear(in_features=63, out_features=500, bias=True)
    (1): ReLU()
    (2): Linear(in_features=500, out_features=500, bias=True)
    (3): ReLU()
  )
  (batchnorm): BatchNorm1d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.7, inplace=False)
  (sigmoid): Sigmoid()
  (softmax): LogSoftmax(dim=1)
  (disease_layer): Linear(in_features=500, out_features=3, bias=True)
  (nodal_disease_layer): Linear(in_features=500, out_features=3, bias=True)
  (dlt_layers): ModuleList(
    (0): Linear(in_features=500, out_features=1, bias=True)
    (1): Linear(in_features=500, out_features=1, bias=True)
    (2): Linear(in_features=500, out_features=1, bias=True)
    (3): Linear(in_features=500, out_features=1, bias=True)
    (4): Linear(in_features=500, out_features=1, bias=True)
    (5): Linear(in_features=500, out_features=1, bias=True)
    (6): Linear(in_feat

In [82]:
model2 = train_state(state=2)
model2

  


epoch 0 train loss 10.57856559753418
val loss 10.362581253051758
______________
epoch 1 train loss 10.434846878051758
val loss 10.217052459716797
______________
epoch 2 train loss 10.315289497375488
val loss 10.075185775756836
______________
epoch 3 train loss 10.122429847717285
val loss 9.936813354492188
______________
epoch 4 train loss 10.021270751953125
val loss 9.800786972045898
______________
epoch 5 train loss 9.910812377929688
val loss 9.666879653930664
______________
epoch 6 train loss 9.745349884033203
val loss 9.534619331359863
______________




epoch 7 train loss 9.647045135498047
val loss 9.403688430786133
______________
epoch 8 train loss 9.494948387145996
val loss 9.273765563964844
______________
epoch 9 train loss 9.372906684875488
val loss 9.144530296325684
______________
epoch 10 train loss 9.221898078918457
val loss 9.015680313110352
______________
epoch 11 train loss 9.106910705566406
val loss 8.887109756469727
______________
epoch 12 train loss 8.969865798950195
val loss 8.758516311645508
______________
epoch 13 train loss 8.84380054473877
val loss 8.629800796508789
______________
epoch 14 train loss 8.7363862991333
val loss 8.50080680847168
______________
epoch 15 train loss 8.607129096984863
val loss 8.371540069580078
______________
epoch 16 train loss 8.47675895690918
val loss 8.242104530334473
______________
epoch 17 train loss 8.370414733886719
val loss 8.112591743469238
______________
epoch 18 train loss 8.172536849975586
val loss 7.982658863067627
______________
epoch 19 train loss 8.088911056518555
val loss 7

epoch 112 train loss 3.3845198154449463
val loss 3.429992914199829
______________
epoch 113 train loss 3.3996238708496094
val loss 3.425759792327881
______________
epoch 114 train loss 3.309802293777466
val loss 3.42155122756958
______________
epoch 115 train loss 3.2966971397399902
val loss 3.417501449584961
______________
epoch 116 train loss 3.307574987411499
val loss 3.413276433944702
______________
epoch 117 train loss 3.327221393585205
val loss 3.4087932109832764
______________
epoch 118 train loss 3.2286341190338135
val loss 3.404167413711548
______________
epoch 119 train loss 3.332287311553955
val loss 3.399627208709717
______________
epoch 120 train loss 3.339592456817627
val loss 3.395141839981079
______________
epoch 121 train loss 3.310680627822876
val loss 3.3910908699035645
______________
epoch 122 train loss 3.3207902908325195
val loss 3.386984348297119
______________
epoch 123 train loss 3.2160208225250244
val loss 3.383061170578003
______________
epoch 124 train loss 

epoch 217 train loss 2.7304670810699463
val loss 3.1781890392303467
______________
epoch 218 train loss 2.748652696609497
val loss 3.178069591522217
______________
epoch 219 train loss 2.7663278579711914
val loss 3.1777775287628174
______________
epoch 220 train loss 2.74898099899292
val loss 3.1775193214416504
______________
epoch 221 train loss 2.777428388595581
val loss 3.176987648010254
______________
epoch 222 train loss 2.7731614112854004
val loss 3.1758391857147217
______________
epoch 223 train loss 2.7273526191711426
val loss 3.174546957015991
______________
epoch 224 train loss 2.704627513885498
val loss 3.1734626293182373
______________
epoch 225 train loss 2.826051950454712
val loss 3.172553300857544
______________
epoch 226 train loss 2.726706027984619
val loss 3.171705961227417
______________
epoch 227 train loss 2.7222886085510254
val loss 3.1705539226531982
______________
epoch 228 train loss 2.7674498558044434
val loss 3.169396162033081
______________
epoch 229 train l

OutcomeSimulator(
  (input_dropout): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0): Linear(in_features=85, out_features=500, bias=True)
    (1): ReLU()
    (2): Linear(in_features=500, out_features=500, bias=True)
    (3): ReLU()
  )
  (batchnorm): BatchNorm1d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.7, inplace=False)
  (sigmoid): Sigmoid()
  (softmax): LogSoftmax(dim=1)
  (disease_layer): Linear(in_features=500, out_features=3, bias=True)
  (nodal_disease_layer): Linear(in_features=500, out_features=3, bias=True)
  (dlt_layers): ModuleList(
    (0): Linear(in_features=500, out_features=1, bias=True)
    (1): Linear(in_features=500, out_features=1, bias=True)
    (2): Linear(in_features=500, out_features=1, bias=True)
    (3): Linear(in_features=500, out_features=1, bias=True)
    (4): Linear(in_features=500, out_features=1, bias=True)
    (5): Linear(in_features=500, out_features=1, bias=True)
    (6): Linear(in_feat

In [83]:
model3 = train_state(state=3,epochs=10000,lr=.0001)
model3

  


epoch 0 train loss 2.2150607109069824
val loss 2.1706087589263916
______________
epoch 1 train loss 2.187577247619629
val loss 2.1569275856018066
______________
epoch 2 train loss 2.212770938873291
val loss 2.1434311866760254
______________
epoch 3 train loss 2.170140266418457
val loss 2.1301097869873047
______________
epoch 4 train loss 2.152308940887451
val loss 2.1169543266296387
______________
epoch 5 train loss 2.126622438430786
val loss 2.103971004486084
______________
epoch 6 train loss 2.14467191696167
val loss 2.0911693572998047
______________
epoch 7 train loss 2.1148276329040527
val loss 2.0785794258117676
______________
epoch 8 train loss 2.1283211708068848
val loss 2.0662038326263428
______________
epoch 9 train loss 2.0900354385375977
val loss 2.0540413856506348
______________
epoch 10 train loss 2.062458038330078
val loss 2.042065143585205
______________
epoch 11 train loss 2.0469765663146973
val loss 2.030250072479248
______________
epoch 12 train loss 2.073414802551269

val loss 1.4905436038970947
______________
epoch 103 train loss 1.4160832166671753
val loss 1.4880157709121704
______________
epoch 104 train loss 1.408632516860962
val loss 1.485527753829956
______________
epoch 105 train loss 1.4423518180847168
val loss 1.4830715656280518
______________
epoch 106 train loss 1.3963558673858643
val loss 1.4806634187698364
______________
epoch 107 train loss 1.4090056419372559
val loss 1.47830331325531
______________
epoch 108 train loss 1.368709683418274
val loss 1.4759641885757446
______________
epoch 109 train loss 1.397509217262268
val loss 1.4736557006835938
______________
epoch 110 train loss 1.3965606689453125
val loss 1.4713646173477173
______________
epoch 111 train loss 1.3788970708847046
val loss 1.4691144227981567
______________
epoch 112 train loss 1.3915494680404663
val loss 1.4668760299682617
______________
epoch 113 train loss 1.3736129999160767
val loss 1.4646542072296143
______________
epoch 114 train loss 1.4062837362289429
val loss 1

epoch 206 train loss 1.1793824434280396
val loss 1.3416240215301514
______________
epoch 207 train loss 1.1762053966522217
val loss 1.3408727645874023
______________
epoch 208 train loss 1.202772855758667
val loss 1.3401318788528442
______________
epoch 209 train loss 1.2178564071655273
val loss 1.3393971920013428
______________
epoch 210 train loss 1.204088807106018
val loss 1.3386671543121338
______________
epoch 211 train loss 1.2195024490356445
val loss 1.337947130203247
______________
epoch 212 train loss 1.2011961936950684
val loss 1.3372386693954468
______________
epoch 213 train loss 1.2096660137176514
val loss 1.3365379571914673
______________
epoch 214 train loss 1.2016918659210205
val loss 1.3358380794525146
______________
epoch 215 train loss 1.177890419960022
val loss 1.3351479768753052
______________
epoch 216 train loss 1.1805782318115234
val loss 1.334464430809021
______________
epoch 217 train loss 1.1675283908843994
val loss 1.3337929248809814
______________
epoch 218

epoch 312 train loss 1.073913335800171
val loss 1.2947970628738403
______________
epoch 313 train loss 1.091803789138794
val loss 1.294551968574524
______________
epoch 314 train loss 1.0972635746002197
val loss 1.2943083047866821
______________
epoch 315 train loss 1.106309175491333
val loss 1.2940764427185059
______________
epoch 316 train loss 1.1151673793792725
val loss 1.293856143951416
______________
epoch 317 train loss 1.0560593605041504
val loss 1.293648362159729
______________
epoch 318 train loss 1.0739567279815674
val loss 1.2934504747390747
______________
epoch 319 train loss 1.0756667852401733
val loss 1.2932730913162231
______________
epoch 320 train loss 1.089507818222046
val loss 1.2930877208709717
______________
epoch 321 train loss 1.0701841115951538
val loss 1.2929078340530396
______________
epoch 322 train loss 1.0558733940124512
val loss 1.2927379608154297
______________
epoch 323 train loss 1.0734230279922485
val loss 1.2925739288330078
______________
epoch 324 t

val loss 1.2828675508499146
______________
epoch 414 train loss 1.0388920307159424
val loss 1.2828478813171387
______________
epoch 415 train loss 1.0145819187164307
val loss 1.2828361988067627
______________
epoch 416 train loss 1.0266978740692139
val loss 1.2827904224395752
______________
epoch 417 train loss 1.00140380859375
val loss 1.2827345132827759
______________
epoch 418 train loss 1.0093512535095215
val loss 1.282717227935791
______________
epoch 419 train loss 1.012784719467163
val loss 1.2827187776565552
______________
epoch 420 train loss 1.0083370208740234
val loss 1.2827190160751343
______________
epoch 421 train loss 0.9897663593292236
val loss 1.2827210426330566
______________
epoch 422 train loss 1.0184508562088013
val loss 1.2827229499816895
______________
epoch 423 train loss 0.9804401397705078
val loss 1.2827154397964478
______________
epoch 424 train loss 0.9961447715759277
val loss 1.282726764678955
______________
epoch 425 train loss 1.0275061130523682
val loss 

EndpointSimulator(
  (input_dropout): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0): Linear(in_features=83, out_features=500, bias=True)
    (1): ReLU()
  )
  (batchnorm): BatchNorm1d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.7, inplace=False)
  (sigmoid): Sigmoid()
  (softmax): LogSoftmax(dim=1)
  (outcome_layer): Linear(in_features=500, out_features=3, bias=True)
)

In [130]:
class DecisionModel(SimulatorBase):
    
    def __init__(self,
                 baseline_input_size,#number of baseline features used
                 hidden_layers = [500],
                 dropout = 0.05,
                 input_dropout=0,
                 state = 1,
                 eps = 0.01,
                 ):
        #input will be all states up until treatment 3
        input_size = baseline_input_size  + len(Const.dlt1) + len(Const.primary_disease_states)  + len(Const.nodal_disease_states)  + len(Const.ccs)  + len(Const.modifications) + 2
            
        super(DecisionModel,self).__init__(input_size,hidden_layers=hidden_layers,dropout=dropout,input_dropout=input_dropout,eps=eps,state='decisions')
        self.final_layer = torch.nn.Linear(hidden_layers[-1],len(Const.decisions))

#         self.final_layer = torch.nn.Linear(hidden_layers[-1],1)
        self.sigmoid = torch.nn.Sigmoid()
        
    def add_position_token(self,x,position):
        #add 2 binary variables for if the state has already passed
        if position == 0:
            token = torch.zeros((x.shape[0],2))
            x = torch.cat([x,token],dim=1)
        if position == 1:
            token1 = torch.ones((x.shape[0],1))
            token2 = torch.zeros((x.shape[0],1))
            x = torch.cat([x,token1,token2],dim=1)
        if position == 2:
            token1 = torch.zeros((x.shape[0],1))
            token2 = torch.ones((x.shape[0],1))
            x = torch.cat([x,token1,token2],dim=1)
        if position == 3:
            token1 = torch.ones((x.shape[0],1))
            token2 = torch.ones((x.shape[0],1))
            x = torch.cat([x,token1,token2],dim=1)
        return x
    
    def get_embedding(self,xbase,xdlt,xpd,xnd,xcc,xmod,position=0):
        xbase = self.normalize(xbase)
        x = torch.cat([xbase,xdlt,xpd,xnd,xcc,xmod],dim=1)
        x = self.add_position_token(x,position)
        for layer in self.layers:
            x = layer(x)
        return x
    
    def forward(self,xbase,xdlt,xpd,xnd,xcc,xmod,position=0):
        #position is 0-2
#         [xbase, xdlt, xpd, xnd, xcc,xmod] = x
        xbase = self.normalize(xbase)
        x = torch.cat([xbase,xdlt,xpd,xnd,xcc,xmod],dim=1)
        x = self.input_dropout(x)
        x = self.add_position_token(x,position)
        for layer in self.layers:
            x = layer(x)
        x = self.dropout(x)
        x = self.final_layer(x)
        x = self.sigmoid(x)
        return x
test = DecisionModel(3)
test.identifier

'statedecisions_input30_dims500_dropout0,0.05'

In [101]:
def shuffle_col(v,col=None):
    if col is None:
        col = np.random.choice([i for i in range(v.shape[1])])
    idx = torch.randperm(v.shape[0])
    vv = torch.clone(v)
    vv[:,col] = vv[idx,col]
    return vv


test = torch.rand((3,5))
print(test)
print(shuffle_col(test))

tensor([[0.8709, 0.1051, 0.9089, 0.3796, 0.1637],
        [0.9283, 0.7037, 0.7089, 0.0402, 0.7564],
        [0.7775, 0.7739, 0.6718, 0.3050, 0.9293]])
tensor([[0.8709, 0.1051, 0.9089, 0.0402, 0.1637],
        [0.9283, 0.7037, 0.7089, 0.3796, 0.7564],
        [0.7775, 0.7739, 0.6718, 0.3050, 0.9293]])


In [None]:

    
def train_decision_model(
    tmodel1,
    tmodel2,
    tmodel3,
    use_default_split=True,
    use_bagging_split=False,
    lr=.0001,
    epochs=1000,
    patience=100,
    weights=[1,1,1], #realtive weight of survival, feeding tube, and aspiration
    imitation_weight=1,
    shufflecol_chance = 0.1,
    reward_weight=10,
    split=.7,
    resample_training=False,
    save_path='../data/models/',
    file_suffix='',
):
    
    tmodel1.eval()
    tmodel2.eval()
    tmodel3.eval()

    
    train_ids, test_ids = get_tt_split(use_default_split=use_default_split,use_bagging_split=use_bagging_split,resample_training=resample_training)
    
    dataset = DTDataset()
    
    data = dataset.processed_df.copy()
    
    def get_dlt(state):
        if state == 2:
            return data[Const.dlt2].copy()
        d = data[Const.dlt1].copy()
        if state < 1:
            d.values[:,:] = 0
        return d
    
    def get_pd(state):
        if state == 2:
            return data[Const.primary_disease_states2].copy()
        d = data[Const.primary_disease_states].copy()
        if state < 1:
            d.values[:,:] = 0
        return d
    
    def get_nd(state):
        if state == 2:
            return data[Const.nodal_disease_states2].copy()
        d = data[Const.nodal_disease_states].copy()
        if state < 1:
            d.values[:,:] = 0
        return d
    
    def get_cc(state):
        res = data[Const.ccs].copy()
        if state == 1:
            res.values[:,:] = np.zeros(res.values.shape)
        return res
    
    def get_mod(state):
        res = data[Const.modifications].copy()
        return res
        
    outcomedf = data[Const.outcomes]
    baseline = dataset.get_state('baseline')
    
    def formatdf(d,dids=train_ids):
        d = df_to_torch(d.loc[dids])
        return d
    
    def makegrad(v):
        if not v.requires_grad:
            v.requires_grad=True
        return v
    
    model = DecisionModel(baseline.shape[1])

    hashcode = str(hash(','.join([str(i) for i in train_ids])))
    
    save_file = save_path + 'model_' + model.identifier +'_hash' + hashcode + file_suffix + '.tar'
    model.fit_normalizer(df_to_torch(baseline.loc[train_ids]))
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)

    def outcome_loss(ypred):
        #convert survival to death
        loss = torch.mul(torch.mean(-1*(ypred[:,0] - 1)),weights[0])
        for i,weight in enumerate(weights[1:]):
            newloss = torch.mean(ypred[:,i])*weight
            loss = torch.add(loss,torch.mul(newloss,weight))
        return loss
    
    mse = torch.nn.MSELoss()
    nllloss = torch.nn.NLLLoss()
    bce = torch.nn.BCELoss()
    
    def compare_decisions(d1,d2,d3,ids):
#         ypred = np.concatenate([dd.cpu().detach().numpy().reshape(-1,1) for dd in [d1,d2,d3]],axis=1)
        ytrue = df_to_torch(outcomedf.loc[ids])
        dloss = bce(d1.view(-1),ytrue[:,0])
        dloss += bce(d2.view(-1),ytrue[:,1])
        dloss += bce(d3,view(-1),ytrue[:,2])
        return dloss
        
    def remove_decisions(df):
        cols = [c for c in df.columns if c not in Const.decisions ]
        ddf = df[cols]
        return ddf
    
    makeinput = lambda step,dids: df_to_torch(remove_decisions(dataset.get_input_state(step=step,ids=dids)))
    
    def step(train=True):
        if train:
            model.train(True)
            optimizer.zero_grad()
            ids = train_ids
        else:
            ids = test_ids
            model.eval()
            
        xx1 = makeinput(1,ids)
        xx2 = makeinput(2,ids)
        xx3 = makeinput(3,ids)
        ytrain = df_to_torch(outcomedf.loc[ids])

        baseline_train_base = formatdf(baseline,ids)
        xxtrain = [baseline, get_dlt(0),get_pd(0),get_nd(0),get_cc(0),get_mod(0)]
        xxtrain = [formatdf(xx,ids) for xx in xxtrain]
            
        baseline_train = torch.clone(baseline_train_base)
        if train and shufflecol_chance > 0.0001:
            for col in range(baseline_train_base.shape[1]): 
                if np.random.random() < shufflecol_chance:
                    baseline_train = shuffle_col(baseline_train,col)
            
            
        decision1 = model(*xxtrain,position=0)[:,0]
        imitation_loss1 = bce(decision1,ytrain[:,0])

        xi1 = torch.cat([xx1,decision1.view(-1,1)],axis=1)
        [ypd1, ynd1, ymod, ydlt1] = tmodel1(xi1)
        x1 = [baseline_train,ydlt1,ypd1,ynd1,formatdf(get_cc(1),ids),ymod]
            
        decision2 = model(*x1,position=1)[:,1]
        imitation_loss2 =  bce(decision2,ytrain[:,1])

        xi2 = torch.cat([xx2,decision1.view(-1,1),decision2.view(-1,1)],axis=1)
        [ypd2,ynd2,ycc,ydlt2] = tmodel2(xi2)
        x2 = [baseline_train,ydlt2,ypd2,ynd2,ycc,ymod]
            
        decision3 = model(*x2,position=2)[:,2]
        imitation_loss3 = bce(decision3,ytrain[:,2])
        
        xi3 = torch.cat([xx3,decision1.view(-1,1),decision2.view(-1,1),decision3.view(-1,1)],axis=1)
        outcomes = tmodel3(xi3)

        reward_loss = outcome_loss(outcomes)
        loss = torch.add(imitation_loss1,imitation_loss2)
        loss = torch.add(loss,imitation_loss3)
        loss = torch.mul(loss,imitation_weight/3)
        loss = torch.add(loss,torch.mul(reward_loss,reward_weight))
        losses = [imitation_loss1+imitation_loss2+imitation_loss3,reward_loss]
        if train:
            loss.backward()
            optimizer.step()
            return losses
        else:
            scores = []
            for i,decision in enumerate([decision1,decision2,decision3]):
                dec = decision.cpu().detach().numpy()
                dec0 = (dec > .5).astype(int)
                out = ytrain[:,i].cpu().detach().numpy()
                acc = accuracy_score(out,dec > .5)
                auc = roc_auc_score(out,dec)
                scores.append({'decision': i,'accuracy': acc,'auc': auc})
            return losses, scores
        
    best_val_loss = torch.tensor(1000000000.0)
    steps_since_improvement = 0
    best_val_score = {}
    for epoch in range(epochs):
        print('______epoch',str(epoch),'_____')
        losses = step(True)
        print('imitation',losses[0].item(),'reward',losses[1].item())
        val_losses,val_metrics = step(False)
        vl = val_losses[0] + val_losses[1]
        print('imitation',val_losses[0].item(),'reward',val_losses[1].item())
        print(vl.item(),best_val_loss.item())
        print(val_metrics)
        if vl < best_val_loss:
            best_val_loss = vl
            best_val_score = val_metrics
            steps_since_improvement = 0
            torch.save(model.state_dict(),save_file)
        else:
            steps_since_improvement += 1
        if steps_since_improvement > patience:
            break
    print('++++++++++Final+++++++++++')
    print('best',best_val_loss)
    print(best_val_score)
    model.load_state_dict(torch.load(save_file))
    model.eval()
    return model, best_val_score

decision_model, _ = train_decision_model(model,model2,model3,imitation_weight=.1,reward_weight=1,lr=.0001)

  


______epoch 0 _____
imitation 1.9935460090637207 reward 1.1700947284698486
imitation 1.9789232015609741 reward 1.1610314846038818
3.1399545669555664 1000000000.0
[{'decision': 0, 'accuracy': 0.5684931506849316, 'auc': 0.5355769230769231}, {'decision': 1, 'accuracy': 0.6095890410958904, 'auc': 0.4673793859649123}, {'decision': 2, 'accuracy': 0.8082191780821918, 'auc': 0.5173076923076924}]
______epoch 1 _____
imitation 1.9583206176757812 reward 1.169677495956421
imitation 1.9481981992721558 reward 1.160551905632019
3.108750104904175 3.1399545669555664
[{'decision': 0, 'accuracy': 0.5958904109589042, 'auc': 0.539423076923077}, {'decision': 1, 'accuracy': 0.678082191780822, 'auc': 0.4723135964912281}, {'decision': 2, 'accuracy': 0.8013698630136986, 'auc': 0.5230769230769231}]
______epoch 2 _____
imitation 1.9349699020385742 reward 1.1691067218780518
imitation 1.91902494430542 reward 1.1600759029388428
3.0791008472442627 3.108750104904175
[{'decision': 0, 'accuracy': 0.6027397260273972, 'au

imitation 1.5459859371185303 reward 1.1611050367355347
imitation 1.6114351749420166 reward 1.1524486541748047
2.7638838291168213 2.7701916694641113
[{'decision': 0, 'accuracy': 0.8698630136986302, 'auc': 0.5682692307692307}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.5764802631578947}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.5961538461538461}]
______epoch 23 _____
imitation 1.5437557697296143 reward 1.1607632637023926
imitation 1.6060019731521606 reward 1.152190923690796
2.758193016052246 2.7638838291168213
[{'decision': 0, 'accuracy': 0.8698630136986302, 'auc': 0.5682692307692309}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.5767543859649122}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.6000000000000001}]
______epoch 24 _____
imitation 1.5284197330474854 reward 1.160513162612915
imitation 1.6011345386505127 reward 1.1519453525543213
2.753079891204834 2.758193016052246
[{'decision': 0, 'accuracy': 0.8698630136986302, 'auc': 0.56923

imitation 1.4527946710586548 reward 1.1573429107666016
imitation 1.5623259544372559 reward 1.1490483283996582
2.711374282836914 2.7123494148254395
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.5764423076923078}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.5844298245614035}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.6339743589743589}]
______epoch 45 _____
imitation 1.4497404098510742 reward 1.15727961063385
imitation 1.5614336729049683 reward 1.1489769220352173
2.7104105949401855 2.711374282836914
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.5783653846153847}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.5838815789473685}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.6333333333333333}]
______epoch 46 _____
imitation 1.438704013824463 reward 1.1572072505950928
imitation 1.5605117082595825 reward 1.148910403251648
2.7094221115112305 2.7104105949401855
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.579326

imitation 1.4009910821914673 reward 1.1564899682998657
imitation 1.534541130065918 reward 1.1482508182525635
2.6827919483184814 2.684283494949341
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6004807692307692}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.5858004385964912}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.6560897435897436}]
______epoch 68 _____
imitation 1.387574553489685 reward 1.1564518213272095
imitation 1.5330464839935303 reward 1.1482411623001099
2.6812877655029297 2.6827919483184814
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6009615384615385}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.5866228070175439}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.6564102564102565}]
______epoch 69 _____
imitation 1.3930456638336182 reward 1.1564115285873413
imitation 1.5315347909927368 reward 1.1482328176498413
2.679767608642578 2.6812877655029297
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6014

imitation 1.3520839214324951 reward 1.1564733982086182
imitation 1.5014947652816772 reward 1.1482086181640625
2.6497035026550293 2.651230573654175
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6197115384615385}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.5964912280701755}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.6650641025641026}]
______epoch 89 _____
imitation 1.345836877822876 reward 1.156463384628296
imitation 1.4999616146087646 reward 1.1482112407684326
2.6481728553771973 2.6497035026550293
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6201923076923076}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.5967653508771931}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.6647435897435897}]
______epoch 90 _____
imitation 1.3436801433563232 reward 1.156445860862732
imitation 1.4984498023986816 reward 1.1482139825820923
2.6466636657714844 2.6481728553771973
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6211

imitation 1.2933346033096313 reward 1.156572699546814
imitation 1.4730502367019653 reward 1.1482806205749512
2.621330738067627 2.622519016265869
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6451923076923077}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.6052631578947368}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.671474358974359}]
______epoch 110 _____
imitation 1.2957953214645386 reward 1.1565717458724976
imitation 1.4718806743621826 reward 1.1482834815979004
2.620164155960083 2.621330738067627
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.645673076923077}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.6055372807017544}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.6708333333333333}]
______epoch 111 _____
imitation 1.2965202331542969 reward 1.1565231084823608
imitation 1.4707262516021729 reward 1.1482864618301392
2.6190128326416016 2.620164155960083
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.647596

In [126]:
from captum.attr import IntegratedGradients

ig = IntegratedGradients(decision_model)
ds = DTDataset()
states = DTDataset().get_states()
x = [states['baseline'],states['dlt1'],states['pd_states1'],states['nd_states1'],states['ccs'],states['modifications']]
x = tuple([df_to_torch(xx) for xx in x])
base = tuple([torch.zeros(xx.shape,requires_grad=True) for xx in x])
attributions = ig.attribute(x,base,target=1,method='gausslegendre')
attributions
# test = pd.DataFrame(attributions,columns = xtestdf.columns, index=xtestdf.index)
# test.describe().T

  after removing the cwd from sys.path.
  """


(tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.3672],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000, -0.0000,  0.4814],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000, -0.0000,  0.5516],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000, -0.0022,  0.4200],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000, -0.0062,  0.4519],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000, -0.0000,  0.5110]],
        dtype=torch.float64, grad_fn=<MulBackward0>),
 tensor([[-0., -0., -0.,  ..., -0., -0., -0.],
         [-0., -0., -0.,  ..., -0., -0., -0.],
         [-0., -0., -0.,  ..., -0., -0., -0.],
         ...,
         [-0., -0., -0.,  ..., -0., -0., -0.],
         [-0., -0., -0.,  ..., -0., -0., -0.],
         [-0., -0., -0.,  ..., -0., -0., -0.]], dtype=torch.float64,
        grad_fn=<MulBackward0>),
 tensor([[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         ...,
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]], d

In [127]:
states['baseline']

Unnamed: 0_level_0,1A,1A1B,1A6,1B,1B2A,1B3,2A,2A2B,2A3,2B,...,ln_cluster_3,ln_cluster_4,packs_per_year,smoking_status,subsite_BOT,subsite_GPS,subsite_NOS,subsite_Soft palate,subsite_Tonsil,total_dose
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
3,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,...,0,0,0.0,0.0,1,0,0,0,0,66.00
5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,38.0,1.0,1,0,0,0,0,72.00
6,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,...,0,0,35.0,1.0,1,0,0,0,0,70.00
7,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,...,0,0,0.0,1.0,0,0,1,0,0,70.00
8,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,...,0,0,0.0,0.0,0,0,0,0,1,66.00
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10201,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,30.0,1.0,1,0,0,0,0,70.00
10202,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,...,0,0,30.0,1.0,0,0,1,0,0,72.00
10203,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,...,0,0,0.0,0.0,0,0,0,0,1,70.00
10204,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,...,0,0,5.0,0.5,0,0,0,0,1,69.96


In [None]:
def breakup_state_models(state_model):
    #for state 1 and 2
    models = {}
    models['pd'] = lambda x: state_model(x)[0]
    models['nd'] = lambda x: state_model(x)[1]
    models['chemo'] = lambda x: state_model(x)[2]
    for i,dlt in enumerate(Const.dlt1):
        models[dlt] = lambda x: state_model(x)[3][:,i]
    return models

def breakup_outcome_models(omodel):
    models = {}
    for i,name in enumerate(Const.outcomes):
        models[name] = lambda x: omodel(x)[:,i].reshape(-1,1)
    return models

def get_all_models(m1,m2,m3):
    state1_models = breakup_state_models(m1)
    state2_models = breakup_state_models(m2)
    state3_models = breakup_outcome_models(m3)
    all_models = {}
    for i,sm in enumerate([state1_models,state2_models,state3_models]):
        for ii,m in sm.items():
            all_models[ii +  '_state' + str(i+1)] = m
    return all_models

all_models = get_all_models(model,model2,model3)
all_models

In [None]:

def get_ytrue(name,df):
    outcomes=None
    value = None
    if name == 'pd_state1':
        outcomes = df[Const.primary_disease_states]
    elif name == 'pd_state2':
        outcomes = df[Const.primary_disease_states2]
    elif name == 'nd_state1':
        outcomes = df[Const.nodal_disease_states]
    elif name == 'nd_state2':
        outcomes = df[Const.nodal_disease_states2]
    elif name == 'chemo_state1':
        outcomes = df[Const.modifications]
    elif name == 'chemo_state2':
        outcomes = df[Const.ccs]   
    if outcomes is not None:
        value = outcomes.idxmax(axis=1)
    if 'DLT' in name:
        newname = name.replace('_state', ' ').replace('1','').strip()
        value = df[newname]
    if name.replace('_state3','') in Const.outcomes:
        value = df[name.replace('_state3','')]
    if value is None:
        print(name,df.columns)
    return value

def check_impact_of_decisions(model_dict,data):
    results = []
    #todo: this is wrong fix it
    ids = []
    df = data.get_data()
    outcomedict = {step: pd.concat(data.get_intermediate_outcomes(step=step),axis=1) for step in [1,2,3]}
    for decision in Const.decisions:
        for name, model in model_dict.items():
            step = int(name[-1])
            subset0 = dataset.get_input_state(step=step,fixed={decision: 0})
            subset1 = dataset.get_input_state(step=step,fixed={decision: 1})
            outcomes = outcomedict[step]
            ids = subset0.index.values
            x0 = df_to_torch(subset0)
            x1 = df_to_torch(subset1)
            y0 = model(x0).detach().cpu().numpy()
            y1 = model(x1).detach().cpu().numpy()
            original = data.get_input_state(step=step)
            xx = df_to_torch(original)
            yy = model(xx).detach().cpu().numpy()
            ytrue = get_ytrue(name,outcomes)
            if "DLT" in name:
                y0 = y0.argmax(axis=1).reshape(-1,1)
                y1 = y1.argmax(axis=1).reshape(-1,1)
                yy = yy.argmax(axis=1).reshape(-1,1)
                change = y0 - y1
                decision_change = (y0 != y1).astype(int)
            elif y0.shape[1] == 1:
                change = y1 - y0
                decision_change = np.abs((y0 > .5).astype(int) - (y1 > .5).astype(int))
            else:
                index = np.unravel_index(np.argmax(yy, axis=1), yy.shape)
                change = (y0[index] - y1[index]).reshape(-1,1)
                decision_change =  (y0.argmax(axis=1).reshape(-1,1) != y1.argmax(axis=1).reshape(-1,1)).astype(int)
                yy = yy.argmax(axis=1).reshape(-1,1)
                y1 = y1.argmax(axis=1).reshape(-1,1)
                y0 = y0.argmax(axis=1).reshape(-1,1)
            outcome = name.replace('_state','')
            for ii,pid in enumerate(ids):
                oo = ytrue.loc[pid]
                onew = y0[ii][0]
                original_decision = df.loc[pid,decision]
                if original_decision > 0:
                    onew = y0[ii][0]
                oname = Const.name_dict.get(name)
                if oname is not None:
                    onew = oname[onew]
                entry = {'id': pid, 'decision': decision,'outcome': outcome,'original_choice': original_decision, 'original_result': oo, 'alt_result': onew, 'change': change[ii][0], 'decision_change': decision_change[ii][0]}
                results.append(entry)
    return pd.DataFrame(results)

test = check_impact_of_decisions(all_models,dataset)
test

In [None]:
data.get_data()['SD Primary 2'].sum()

In [None]:
(test[test.outcome == 'pd2'].original_result == 'SD Primary 2').sum()

In [None]:
check_impact_of_decisions(all_models,dataset).to_csv('../data/decision_impacts.csv')