In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
import pandas as pd
import torch
import re
from sklearn.metrics import balanced_accuracy_score, roc_auc_score,accuracy_score
from Constants import *
from Preprocessing import *
from Models import *
pd.set_option('display.max_rows', 200)



  from .autonotebook import tqdm as notebook_tqdm


In [4]:
model_dir = Const.data_dir + '/models/'

tuned_transition_models = [
    'final_transition1_model_state1_input63_dims500,500_dropout0.25,0.5.pt',
    'final_transition2_model_state2_input85_dims100_dropout0.25,0.pt',
    'final_outcome_model_state1_input83_dims1000_dropout0,0.pt'
]
tuned_transition_models = [model_dir + f for f in tuned_transition_models]
Const.tuned_transition_models = tuned_transition_models
Const.tuned_decision_model = model_dir +  'final_decision_model_statedecisions_input132_dims100,100_dropout0.1,0.7.pt'

In [5]:
def get_dt_ids():
    df = load_digital_twin()
    return df.id.values

def get_tt_split(ids=None,use_default_split=True,use_bagging_split=False,resample_training=False):
        if ids is None:
            ids = get_dt_ids()
        #pre-made, stratified by decision and outcome 72:28
        if use_default_split:
            train_ids = Const.stratified_train_ids[:]
            test_ids = Const.stratified_test_ids[:]
        elif use_bagging_split:
            train_ids = np.random.choice(ids,len(ids),replace=True)
            test_ids = [i for i in ids if i not in train_ids]
        else:
            test_ids = ids[0: int(len(ids)*(1-split))]
            train_ids = [i for i in ids if i not in test_ids]

        if resample_training:
            train_ids = np.random.choice(train_ids,len(train_ids),replace=True)
            test_ids = [i for i in ids if i not in train_ids]
        return train_ids,test_ids
    
def df_to_torch(df,ttype  = torch.FloatTensor):
    values = df.values.astype(float)
    values = torch.from_numpy(values)
    return values.type(ttype)

In [6]:
def nllloss(ytrue,ypred):
    #nll loss with argmax added in
    loss = torch.nn.NLLLoss()
    return loss(ypred,ytrue.argmax(axis=1))

def state_loss(ytrue,ypred,weights=[1,1,1,1]):
    pd_loss = nllloss(ytrue[0],ypred[0])*weights[0]
    nd_loss = nllloss(ytrue[1],ypred[1])*weights[1]
    mod_loss = nllloss(ytrue[2],ypred[2])*weights[2]
    loss = pd_loss + nd_loss + mod_loss
    dlt_true = ytrue[3]
    dlt_pred = ypred[3]
    ndlt = dlt_true.shape[1]
#     nloss = torch.nn.NLLLoss()
    bce = torch.nn.BCELoss()
    for i in range(ndlt):
        dlt_loss = bce(dlt_pred[:,i].view(-1),dlt_true[:,i].view(-1))
        loss += dlt_loss*weights[3]/ndlt
    return loss

def outcome_loss(ytrue,ypred,weights=[1,1,1]):
    loss = 0
    nloss = torch.nn.BCELoss()
    for i in range(len(weights)):
        iloss = nloss(ypred[:,i],ytrue[i])*weights[i]
        loss += iloss
    return loss

def mc_metrics(yt,yp,numpy=False,is_dlt=False):
    if not numpy:
        yt = yt .cpu().detach().numpy()
        yp = yp.cpu().detach().numpy()
    #dlt prediction (binary)
    if is_dlt:
        acc = accuracy_score(yt,yp>.5)
        if yt.sum() > 1:
            auc = roc_auc_score(yt,yp)
        else:
            auc=-1
        error = np.mean((yt-yp)**2)
        return {'accuracy': acc, 'mse': error, 'auc': auc}
    #this is a catch for when I se the dlt prediction format (encoded integer ordinal, predict as a categorical and take the argmax)
    elif yt.ndim > 1:
        try:
            bacc = balanced_accuracy_score(yt.argmax(axis=1),yp.argmax(axis=1))
        except:
            bacc = -1
        try:
            roc_micro = roc_auc_score(yt,yp,average='micro')
        except:
            roc_micro=-1
        try:
            roc_macro = roc_auc_score(yt,yp,average='macro')
        except:
            roc_macro = -1
        return {'accuracy': bacc, 'roc_micro': roc_micro,'roc_macro': roc_macro}
    #outcomes (binary)
    else:
        if yp.ndim > 1:
            yp = yp.argmax(axis=1)
        try:
            bacc = accuracy_score(yt,yp)
        except:
            bacc = -1
        try:
            roc = roc_auc_score(yt,yp)
        except:
            roc = -1
        error = np.mean((yt-yp)**2)
        return {'accuracy': bacc, 'mse': error, 'auc': roc}

def state_metrics(ytrue,ypred,numpy=False):
    pd_metrics = mc_metrics(ytrue[0],ypred[0],numpy=numpy)
    nd_metrics = mc_metrics(ytrue[1],ypred[1],numpy=numpy)
    mod_metrics = mc_metrics(ytrue[1],ypred[1],numpy=numpy)
    
    dlt_metrics = []
    dlt_true = ytrue[3]
    dlt_pred = ypred[3]
    ndlt = dlt_true.shape[1]
    nloss = torch.nn.NLLLoss()
    for i in range(ndlt):
        dm = mc_metrics(dlt_true[:,i],dlt_pred[:,i].view(-1),is_dlt=True)
        dlt_metrics.append(dm)
    dlt_acc =[d['accuracy'] for d in dlt_metrics]
    dlt_error = [d['mse'] for d in dlt_metrics]
    dlt_auc = [d['auc'] for d in dlt_metrics]
    return {'pd': pd_metrics,'nd': nd_metrics,'mod': mod_metrics,'dlts': {'accuracy': dlt_acc,'accuracy_mean': np.mean(dlt_acc),'auc': dlt_auc,'auc_mean': np.mean(dlt_auc)}}
    
def outcome_metrics(ytrue,ypred,numpy=False):
    res = {}
    for i, outcome in enumerate(Const.outcomes):
        metrics = mc_metrics(ytrue[i],ypred[:,i])
        res[outcome] = metrics
    return res


In [7]:
from sklearn.ensemble import RandomForestClassifier

def train_state_rf(model_args={}):
    ids = get_dt_ids()
    
    dataset = DTDataset()

    train_ids = ids[0:int(len(ids)*.7)]
    test_ids = ids[int(len(ids)*.7):]
    
    #most things are multiclass, dlts are several ordinal and outcomes are multiple binary
    xtrain1 = dataset.get_state('baseline',ids=train_ids)
    xtest1 = dataset.get_state('baseline',ids=test_ids)
    
    xtrain2 = dataset.get_input_state(step=2,ids=train_ids)
    xtest2 = dataset.get_input_state(step=2,ids=test_ids)
    
    xtrain3 = dataset.get_input_state(step=3,ids=train_ids)
    xtest3 = dataset.get_input_state(step=3,ids=test_ids)
    
    [pd1_train,nd1_train, mod_train,dlts1_train] = dataset.get_intermediate_outcomes(ids=train_ids)
    [pd2_train,nd2_train, cc_train,dlts2_train] = dataset.get_intermediate_outcomes(step=2,ids=train_ids)
    [pd1_test,nd1_test, mod_test,dlts1_test] = dataset.get_intermediate_outcomes(ids=test_ids)
    [pd2_test,nd2_test, cc_test,dlts2_test] = dataset.get_intermediate_outcomes(step=2,ids=test_ids)
    outcomes_train = dataset.get_state('outcomes',ids=train_ids)
    outcomes_test = dataset.get_state('outcomes',ids=test_ids)
    

    def train_multiclass_rf(xtrain,xtest,ytrain,ytest):
        model = RandomForestClassifier(class_weight='balanced',**model_args).fit(xtrain,ytrain)
        ypred = model.predict(xtest)
        metrics = mc_metrics(ytest.values,ypred,numpy=True)
        return model, metrics
    
    all_metrics = {}
    pd1_model, all_metrics['pd1'] = train_multiclass_rf(xtrain1,xtest1,pd1_train,pd1_test)
    nd1_model, all_metrics['nd1']  = train_multiclass_rf(xtrain1,xtest1,nd1_train,nd1_test)
    mod_model, all_metrics['mod']  = train_multiclass_rf(xtrain1,xtest1,mod_train,mod_test)
    
    return all_metrics

train_state_rf({'max_depth': 5,'n_estimators': 100})

  df = pd.read_csv(file)
  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


{'pd1': {'accuracy': 0.34712441314553993,
  'roc_micro': 0.5057627557627558,
  'roc_macro': 0.5069683908045977},
 'nd1': {'accuracy': 0.38253241800152554,
  'roc_micro': 0.5694120694120695,
  'roc_macro': 0.5231884057971015},
 'mod': {'accuracy': 0.16666666666666666,
  'roc_micro': 0.5093167701863354,
  'roc_macro': -1}}

In [8]:
 DTDataset().get_state('baseline').T

  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


id,3,5,6,7,8,9,10,11,13,14,...,10196,10197,10198,10199,10200,10201,10202,10203,10204,10205
1A,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1A1B,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1A6,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1B,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1B2A,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1B3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2A,1.0,0.0,1.0,1.0,1.0,1.0,0.0,1.0,1.0,2.0,...,0.0,1.0,1.0,1.0,1.0,0.0,1.0,1.0,1.0,1.0
2A2B,1.0,0.0,1.0,1.0,1.0,1.0,0.0,1.0,1.0,2.0,...,0.0,1.0,1.0,1.0,1.0,0.0,1.0,1.0,1.0,1.0
2A3,1.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,1.0,2.0,...,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
2B,1.0,0.0,1.0,1.0,1.0,1.0,0.0,1.0,1.0,2.0,...,0.0,1.0,1.0,1.0,1.0,0.0,1.0,1.0,1.0,1.0


In [9]:
def train_state(model=None,
                model_args={},
                state=1,
                split=.7,
                lr=.0001,
                epochs=1000,
                patience=10,
                use_attention=True,
                weights=[1,1,1,10],
                save_path='../data/models/',
                use_default_split=True,
                use_bagging_split=False,
                resample_training=False,#use bootstraping on training data after splitting
                n_validation_trainsteps=2,
                verbose=True,
                file_suffix=''):
    
    ids = get_dt_ids()
    
    train_ids, test_ids = get_tt_split(use_default_split=use_default_split,use_bagging_split=use_bagging_split,resample_training=resample_training)
    
    dataset = DTDataset()
    
    xtrain = dataset.get_input_state(step=state,ids=train_ids)
    xtest = dataset.get_input_state(step=state,ids=test_ids)
    ytrain = dataset.get_intermediate_outcomes(step=state,ids=train_ids)
    ytest = dataset.get_intermediate_outcomes(step=state,ids=test_ids)
    

    if state < 3:
        if model is None:
            if use_attention:
                model = OutcomeAttentionSimulator(xtrain.shape[1],state=state,**model_args)
            else:
                model = OutcomeSimulator(xtrain.shape[1],state=state,**model_args)
        lfunc = state_loss
    else:
        if model is None:
            if use_attention:
                model = EndpointAttentionSimulator(xtrain.shape[1],**model_args)
            else:
                model = EndpointSimulator(xtrain.shape[1],**model_args)
        weights = weights[:3]
        lfunc = outcome_loss
        
    hashcode = str(hash(','.join([str(i) for i in train_ids])))
    save_file = save_path + 'model_' + model.identifier + '_split' + str(split) + '_resample' + str(resample_training) +  '_hash' + hashcode + file_suffix + '.tar'
    xtrain = df_to_torch(xtrain)
    xtest = df_to_torch(xtest)
    ytrain = [df_to_torch(t) for t in ytrain]
    ytest= [df_to_torch(t) for t in ytest]
    
    model.fit_normalizer(xtrain)
#     normalize = lambda x: (x - xtrain.mean(axis=0)+.01)/(xtrain.std(axis=0)+.01)
#     unnormalize = lambda x: (x * (xtrain.std(axis=0) +.01)) + xtrain.mean(axis=0) - .01
    
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)
    best_val_loss = 1000000000000000000000000000
    best_loss_metrics = {}
    last_epoch = False
    for epoch in range(epochs):
        
        model.train(True)
        optimizer.zero_grad()
        
        xtrain_sample = xtrain#[torch.randint(len(xtrain),(len(xtrain),) )]
        ypred = model(xtrain_sample)
        loss = lfunc(ytrain,ypred,weights=weights)

        loss.backward()
        optimizer.step()
        if verbose:
            print('epoch',epoch,'train loss',loss.item())
        
        model.eval()
        yval = model(xtest)
        val_loss = lfunc(ytest,yval,weights=weights)
        if state < 3:
            val_metrics = state_metrics(ytest,yval)
        else:
            val_metrics = outcome_metrics(ytest,yval)
        if val_loss.item() < best_val_loss:
            best_val_loss = val_loss.item()
            best_loss_metrics = val_metrics
            steps_since_improvement = 0
            torch.save(model.state_dict(),save_file)
        else:
            steps_since_improvement += 1
        if verbose:
            print('val loss',val_loss.item())
            print('______________')
        if steps_since_improvement > patience:
            break
    print('best loss',best_val_loss,best_loss_metrics)
    model.load_state_dict(torch.load(save_file))
    
    #train one step on validation data
    for i in range(n_validation_trainsteps):
        model.train()
        yval = model(xtest)
        val_loss = lfunc(ytest,yval,weights=weights)
        val_loss.backward()
        optimizer.step()
        torch.save(model.state_dict(),save_file)
    
    model.eval()
    return model,  best_val_loss, best_loss_metrics

In [10]:



def gridsearch_transition_models(state=1):
#     model_arglist = [
#         {
#             'hidden_layers': [100],
#             'attention_heads': [1],
#         },
#         {
#             'hidden_layers': [500],
#             'attention_heads': [1],
#         },
#         {
#             'hidden_layers': [1000],
#             'attention_heads': [1],
#         },
#         {
#             'hidden_layers': [100],
#             'attention_heads': [5],
#         },
#         {
#             'hidden_layers': [500],
#             'attention_heads': [5],
#         },
#         {
#             'hidden_layers': [1000],
#             'attention_heads': [5],
#         },
#         {
#             'hidden_layers': [100,100],
#             'attention_heads': [1,1],
#         },
#         {
#             'hidden_layers': [500,500],
#             'attention_heads': [1,1],
#         },
#         {
#             'hidden_layers': [1000,1000],
#             'attention_heads': [1,1],
#         },
#         {
#             'hidden_layers': [100,100],
#             'attention_heads': [5,5],
#         },
#         {
#             'hidden_layers': [500,500],
#             'attention_heads': [5,5],
#         },
#         {
#             'hidden_layers': [1000,1000],
#             'attention_heads': [5,5],
#         },
#         {
#             'hidden_layers': [500,500,500],
#             'attention_heads': [5,5,5]
#         }
#     ]
    model_arglist = [
        {
            'hidden_layers': [100],
            'attention_heads': [10],
        },
        {
            'hidden_layers': [500],
            'attention_heads': [10],
        },
        {
            'hidden_layers': [1000],
            'attention_heads': [10],
        },
        {
            'hidden_layers': [100],
            'attention_heads': [5],
        },
        {
            'hidden_layers': [500],
            'attention_heads': [5],
        },
        {
            'hidden_layers': [1000],
            'attention_heads': [5],
        },
        {
            'hidden_layers': [100,100],
            'attention_heads': [10,10],
        },
        {
            'hidden_layers': [500,500],
            'attention_heads': [10,10],
        },
        {
            'hidden_layers': [100,100],
            'attention_heads': [5,5],
        },
        {
            'hidden_layers': [500,500],
            'attention_heads': [5,5],
        },
    ]
    best_loss = 100000000000
    best_metrics = {}
    best_args = {}
    best_model = None
    k = 0
    for margs in model_arglist:
        args = {k:v for k,v in margs.items()}
        for embed_size in [200,400,800]:
            args['embed_size'] = embed_size
            for dropout in [.9,.95]:
                args['dropout'] = dropout
                for input_dropout in [.25,.35,.5]:
                    args['input_dropout'] = input_dropout
                    model,m_loss,m_metrics = train_state(model_args=args,state=state,verbose=False)
                    print('done',k,m_loss)
                    k+=1
                    if m_loss < best_loss:
                        best_loss = m_loss
                        best_metrics  = m_metrics
                        best_model = model
                        best_args = args
                        print('_++++++++++New Best++++____')
                        print(best_loss)
                        print(best_metrics)
                        print(best_args)
                        print('___________')
                        print('++++++++')
                        print()
    print('_________')
    print('+++++++++++')
    print('best stuff',best_loss)
    print(best_metrics)
    print(best_args)
    return best_model
# model = gridsearch_transition_models(1)

In [11]:
# model2 = gridsearch_transition_models(2)

In [12]:
# model3 = gridsearch_transition_models(3)

In [14]:
Const.tuned_transition_models

['../data/models/final_transition1_model_state1_input63_dims500,500_dropout0.25,0.5.pt',
 '../data/models/final_transition2_model_state2_input85_dims100_dropout0.25,0.pt',
 '../data/models/final_outcome_model_state1_input83_dims1000_dropout0,0.pt']

In [16]:
def load_trained_models():
    files = Const.tuned_transition_models
    decision_file = Const.tuned_decision_model
    [model1,model2,model3] = [torch.load(file) for file in files]
    decision_model = torch.load(decision_file)
    return decision_model, model1,model2,model3
_, model1, model2, model3 =load_trained_models()

In [62]:

def shuffle_col(v,col=None):
    if col is None:
        col = np.random.choice([i for i in range(v.shape[1])])
    idx = torch.randperm(v.shape[0])
    vv = torch.clone(v)
    vv[:,col] = vv[idx,col]
    return vv
    
def train_decision_model(
    tmodel1,
    tmodel2,
    tmodel3,
    use_default_split=True,
    use_bagging_split=False,
    lr=.0001,
    epochs=10000,
    patience=100,
    weights=[3,1,1], #realtive weight of survival, feeding tube, and aspiration
    imitation_weight=1,
    shufflecol_chance = 0.1,
    reward_weight=10,
    split=.7,
    resample_training=False,
    save_path='../data/models/',
    file_suffix='',
    use_attention=False,
    verbose=True,
    threshold_decisions=True,
    **model_kwargs,
):
    
    tmodel1.eval()
    tmodel2.eval()
    tmodel3.eval()
    
    train_ids, test_ids = get_tt_split(use_default_split=use_default_split,use_bagging_split=use_bagging_split,resample_training=resample_training)
    
    dataset = DTDataset()
    
    data = dataset.processed_df.copy()
    
    def get_dlt(state):
        if state == 2:
            return data[Const.dlt2].copy()
        d = data[Const.dlt1].copy()
        if state < 1:
            d.values[:,:] = 0
        return d
    
    def get_pd(state):
        if state == 2:
            return data[Const.primary_disease_states2].copy()
        d = data[Const.primary_disease_states].copy()
        if state < 1:
            d.values[:,:] = 0
        return d
    
    def get_nd(state):
        if state == 2:
            return data[Const.nodal_disease_states2].copy()
        d = data[Const.nodal_disease_states].copy()
        if state < 1:
            d.values[:,:] = 0
        return d
    
    def get_cc(state):
        res = data[Const.ccs].copy()
        if state == 1:
            res.values[:,:] = np.zeros(res.values.shape)
        return res
    
    def get_mod(state):
        res = data[Const.modifications].copy()
        return res
        
    outcomedf = data[Const.outcomes]
    baseline = dataset.get_state('baseline')
    
    def formatdf(d,dids=train_ids):
        d = df_to_torch(d.loc[dids])
        return d
    
    def makegrad(v):
        if not v.requires_grad:
            v.requires_grad=True
        return v
    
    if use_attention:
        model = DecisionAttentionModel(baseline.shape[1],**model_kwargs)
    else:
        model = DecisionModel(baseline.shape[1],**model_kwargs)

    hashcode = str(hash(','.join([str(i) for i in train_ids])))
    
    save_file = save_path + 'model_' + model.identifier +'_hash' + hashcode + file_suffix + '.tar'
    model.fit_normalizer(df_to_torch(baseline.loc[train_ids]))
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)

    def outcome_loss(ypred):
        #convert survival to death
        loss = torch.mul(torch.mean(-1*(ypred[:,0] - 1)),weights[0])
        for i,weight in enumerate(weights[1:]):
            newloss = torch.mean(ypred[:,i])*weight
            loss = torch.add(loss,torch.mul(newloss,weight))
        return loss
    
    mse = torch.nn.MSELoss()
    nllloss = torch.nn.NLLLoss()
    bce = torch.nn.BCELoss()
    
    def compare_decisions(d1,d2,d3,ids):
#         ypred = np.concatenate([dd.cpu().detach().numpy().reshape(-1,1) for dd in [d1,d2,d3]],axis=1)
        ytrue = df_to_torch(outcomedf.loc[ids])
        dloss = bce(d1.view(-1),ytrue[:,0])
        dloss += bce(d2.view(-1),ytrue[:,1])
        dloss += bce(d3,view(-1),ytrue[:,2])
        return dloss
        
    def remove_decisions(df):
        cols = [c for c in df.columns if c not in Const.decisions ]
        ddf = df[cols]
        return ddf
    
    makeinput = lambda step,dids: df_to_torch(remove_decisions(dataset.get_input_state(step=step,ids=dids)))
    threshold = lambda x: torch.gt(x,.5).type(torch.FloatTensor)
    def step(train=True):
        if train:
            model.train(True)
            optimizer.zero_grad()
            ids = train_ids
        else:
            ids = test_ids
            model.eval()
            
            
        ytrain = df_to_torch(outcomedf.loc[ids])
        #imitation losses and decision 1
        xxtrain = [baseline, get_dlt(0),get_dlt(0),get_pd(0),get_nd(0),get_cc(0),get_mod(0)]
        xxtrain = [formatdf(xx,ids) for xx in xxtrain]
        o1 = model(torch.cat(xxtrain,axis=1),position=0)
        decision1_imitation = o1[:,3]
        decision1 = o1[:,0]
        if threshold_decisions:
            decision1 = threshold(decision1)
#         imitation_loss1 = bce(threshold(decision1_imitation),ytrain[:,0])
        imitation_loss1 = bce(decision1_imitation,ytrain[:,0])
        
        x1_imitation = [baseline, get_dlt(1),get_dlt(0),get_pd(1),get_nd(1),get_cc(1),get_mod(1)]
        x1_imitation = [formatdf(xx1,ids) for xx1 in x1_imitation]
        decision2_imitation = model(torch.cat(x1_imitation,axis=1),position=1)[:,4]
        
#         imitation_loss2 =  bce(threshold(decision2_imitation),ytrain[:,1])
        imitation_loss2 =  bce(decision2_imitation,ytrain[:,1])
        
        x2_imitation = [baseline, get_dlt(1),get_dlt(2),get_pd(2),get_nd(2),get_cc(2),get_mod(2)]
        x2_imitation = [formatdf(xx2,ids) for xx2 in x2_imitation]
        decision3_imitation = model(torch.cat(x2_imitation,axis=1),position=2)[:,5]
        
#         imitation_loss3 = bce(threshold(decision3_imitation),ytrain[:,2])
        imitation_loss3 = bce(decision3_imitation,ytrain[:,2])
        
        #reward decisions
        xx1 = makeinput(1,ids)
        xx2 = makeinput(2,ids)
        xx3 = makeinput(3,ids)

        baseline_train_base = formatdf(baseline,ids)
            
        baseline_train = torch.clone(baseline_train_base)
        if train and shufflecol_chance > 0.0001:
            for col in range(baseline_train_base.shape[1]): 
                if np.random.random() < shufflecol_chance:
                    baseline_train = shuffle_col(baseline_train,col)
                    
        
        xi1 = torch.cat([xx1,decision1.view(-1,1)],axis=1)
        [ypd1, ynd1, ymod, ydlt1] = tmodel1(xi1)
        #this outputs log likelihoods (except for dlts) -> convert to probability
        ypd1 = torch.exp(ypd1)
        ynd1 = torch.exp(ynd1)
        ymod = torch.exp(ymod)
        x1 = [baseline_train,ydlt1,formatdf(get_dlt(0),ids),ypd1,ynd1,formatdf(get_cc(1),ids),ymod]
        
        decision2 = model(torch.cat(x1,axis=1),position=1)[:,1] 
        if threshold_decisions:
            decision2 = threshold(decision2)
            
        xi2 = torch.cat([xx2,decision1.view(-1,1),decision2.view(-1,1)],axis=1)
        [ypd2,ynd2,ycc,ydlt2] = tmodel2(xi2)
        ypd2 = torch.exp(ypd2)
        ynd2 = torch.exp(ynd2)
        ycc = torch.exp(ycc)
        x2 = [baseline_train,ydlt1,ydlt2,ypd2,ynd2,ycc,ymod]
            
        decision3 = model(torch.cat(x2,axis=1),position=2)[:,2]
        if threshold_decisions:
            decision3 = threshold(decision3)
            
        
        xi3 = torch.cat([xx3,decision1.view(-1,1),decision2.view(-1,1),decision3.view(-1,1)],axis=1)
        outcomes = tmodel3(xi3)

        reward_loss = outcome_loss(outcomes)
        loss = torch.add(imitation_loss1,imitation_loss2)
        loss = torch.add(loss,imitation_loss3)
        loss = torch.mul(loss,imitation_weight/3)
        loss = torch.add(loss,torch.mul(reward_loss,reward_weight))
        losses = [imitation_loss1+imitation_loss2+imitation_loss3,reward_loss]
        if train:
            loss.backward()
            optimizer.step()
            return losses
        else:
            scores = []
            for i,decision in enumerate([decision1_imitation,decision2_imitation,decision3_imitation]):
                dec = decision.cpu().detach().numpy()
                dec0 = (dec > .5).astype(int)
                out = ytrain[:,i].cpu().detach().numpy()
                acc = accuracy_score(out,dec > .5)
                auc = roc_auc_score(out,dec)
                scores.append({'decision': i,'accuracy': acc,'auc': auc})
            return losses, scores
        
    best_val_loss = torch.tensor(1000000000.0)
    steps_since_improvement = 0
    best_val_score = {}
    for epoch in range(epochs):
        losses = step(True)
        val_losses,val_metrics = step(False)
        vl = val_losses[0] + val_losses[1]
        if verbose:
            print('______epoch',str(epoch),'_____')
            print('train imitation',losses[0].item(),'reward',losses[1].item())
            print('val imitation',val_losses[0].item(),'reward',val_losses[1].item())
            print('val loss',vl.item(),best_val_loss.item())
            print(val_metrics)
        if vl < best_val_loss:
            best_val_loss = vl
            best_val_score = val_metrics
            steps_since_improvement = 0
            torch.save(model.state_dict(),save_file)
        else:
            steps_since_improvement += 1
        if steps_since_improvement > patience:
            break
    print('++++++++++Final+++++++++++')
    print('best',best_val_loss)
    print(best_val_score)
    model.load_state_dict(torch.load(save_file))
    model.eval()
    return model, best_val_score, best_val_loss

from Models import *
args = {
    'hidden_layers': [600], 
    'attention_heads': [3], 
    'embed_size': 210, 
    'dropout': 0.9, 
    'input_dropout': 0.5, 
    'shufflecol_chance': 0.1,
}
decision_model, _, _ = train_decision_model(model1,model2,model3,lr=.0001,use_attention=True,**args)

  df = pd.read_csv(file)
  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


______epoch 0 _____
train imitation 2.213536262512207 reward 1.535596489906311
val imitation 2.0247085094451904 reward 1.5417895317077637
val loss 3.566498041152954 1000000000.0
[{'decision': 0, 'accuracy': 0.3835616438356164, 'auc': 0.3802884615384616}, {'decision': 1, 'accuracy': 0.773972602739726, 'auc': 0.5778508771929824}, {'decision': 2, 'accuracy': 0.6301369863013698, 'auc': 0.36250000000000004}]
______epoch 1 _____
train imitation 2.175110101699829 reward 1.5305566787719727
val imitation 1.9839191436767578 reward 1.5427231788635254
val loss 3.526642322540283 3.566498041152954
[{'decision': 0, 'accuracy': 0.5205479452054794, 'auc': 0.3889423076923077}, {'decision': 1, 'accuracy': 0.773972602739726, 'auc': 0.5797697368421053}, {'decision': 2, 'accuracy': 0.726027397260274, 'auc': 0.3698717948717949}]
______epoch 2 _____
train imitation 2.055936574935913 reward 1.5291204452514648
val imitation 1.9450087547302246 reward 1.5427231788635254
val loss 3.48773193359375 3.526642322540283

______epoch 21 _____
train imitation 1.6462006568908691 reward 1.5312477350234985
val imitation 1.472330927848816 reward 1.5400276184082031
val loss 3.0123586654663086 3.027879238128662
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.4745192307692308}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.6178728070175439}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.5195512820512821}]
______epoch 22 _____
train imitation 1.5814028978347778 reward 1.5333291292190552
val imitation 1.4588595628738403 reward 1.5400276184082031
val loss 2.998887062072754 3.0123586654663086
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.4798076923076924}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.6165021929824561}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.5256410256410257}]
______epoch 23 _____
train imitation 1.5970546007156372 reward 1.5326300859451294
val imitation 1.4463844299316406 reward 1.5400745868682861
val loss 2.9864590167999268 2.

______epoch 42 _____
train imitation 1.477730631828308 reward 1.5387372970581055
val imitation 1.3328640460968018 reward 1.5484507083892822
val loss 2.881314754486084 2.8834593296051025
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.5432692307692308}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.6737938596491229}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.6823717948717949}]
______epoch 43 _____
train imitation 1.4722785949707031 reward 1.536476969718933
val imitation 1.330936312675476 reward 1.5487303733825684
val loss 2.879666805267334 2.881314754486084
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.54375}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.6751644736842106}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.6865384615384615}]
______epoch 44 _____
train imitation 1.494117259979248 reward 1.536814570426941
val imitation 1.3291969299316406 reward 1.5487303733825684
val loss 2.877927303314209 2.879666805267334
[

______epoch 62 _____
train imitation 1.4177806377410889 reward 1.5383515357971191
val imitation 1.3077411651611328 reward 1.559348702430725
val loss 2.8670897483825684 2.8681585788726807
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.5701923076923078}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.7132675438596491}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7282051282051283}]
______epoch 63 _____
train imitation 1.4150912761688232 reward 1.533603549003601
val imitation 1.3066819906234741 reward 1.559348702430725
val loss 2.866030693054199 2.8670897483825684
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.5725961538461538}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.7143640350877193}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7288461538461539}]
______epoch 64 _____
train imitation 1.436271071434021 reward 1.5239148139953613
val imitation 1.3056297302246094 reward 1.5598499774932861
val loss 2.8654797077178955 2.86

______epoch 82 _____
train imitation 1.3664249181747437 reward 1.5423691272735596
val imitation 1.2894186973571777 reward 1.5603561401367188
val loss 2.8497748374938965 2.850212335586548
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.5903846153846154}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.7283442982456141}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7442307692307693}]
______epoch 83 _____
train imitation 1.3608098030090332 reward 1.5380144119262695
val imitation 1.2886989116668701 reward 1.5603406429290771
val loss 2.8490395545959473 2.8497748374938965
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.5908653846153846}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.7280701754385964}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7439102564102564}]
______epoch 84 _____
train imitation 1.370718240737915 reward 1.5307117700576782
val imitation 1.287951946258545 reward 1.5606660842895508
val loss 2.8486180305480957 2.

______epoch 102 _____
train imitation 1.3170132637023926 reward 1.5426756143569946
val imitation 1.2751433849334717 reward 1.564955234527588
val loss 2.8400986194610596 2.840745210647583
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6096153846153846}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.7291666666666665}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7516025641025641}]
______epoch 103 _____
train imitation 1.341646432876587 reward 1.5312227010726929
val imitation 1.2745071649551392 reward 1.564955234527588
val loss 2.8394622802734375 2.8400986194610596
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6100961538461539}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.7294407894736843}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7509615384615385}]
______epoch 104 _____
train imitation 1.3743512630462646 reward 1.537083625793457
val imitation 1.2738226652145386 reward 1.5654144287109375
val loss 2.8392372131347656 2

______epoch 123 _____
train imitation 1.3334777355194092 reward 1.5432628393173218
val imitation 1.2644308805465698 reward 1.5621693134307861
val loss 2.8266000747680664 2.8285622596740723
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6168269230769231}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.7242324561403508}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7557692307692307}]
______epoch 124 _____
train imitation 1.3114463090896606 reward 1.5366657972335815
val imitation 1.2638649940490723 reward 1.561967372894287
val loss 2.8258323669433594 2.8266000747680664
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6168269230769231}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.7239583333333334}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7554487179487179}]
______epoch 125 _____
train imitation 1.3343546390533447 reward 1.5292513370513916
val imitation 1.2633171081542969 reward 1.5633150339126587
val loss 2.82663202285766

______epoch 144 _____
train imitation 1.3293651342391968 reward 1.5443511009216309
val imitation 1.2526657581329346 reward 1.5619724988937378
val loss 2.814638137817383 2.816528797149658
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6235576923076923}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.7228618421052633}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7538461538461538}]
______epoch 145 _____
train imitation 1.2795491218566895 reward 1.544779658317566
val imitation 1.2522480487823486 reward 1.5619724988937378
val loss 2.814220428466797 2.814638137817383
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6225961538461539}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.7228618421052632}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7532051282051282}]
______epoch 146 _____
train imitation 1.2911376953125 reward 1.5413686037063599
val imitation 1.2518527507781982 reward 1.5617529153823853
val loss 2.813605785369873 2.814

______epoch 164 _____
train imitation 1.2818254232406616 reward 1.543867588043213
val imitation 1.245447039604187 reward 1.5614368915557861
val loss 2.8068838119506836 2.8071978092193604
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6326923076923077}, {'decision': 1, 'accuracy': 0.773972602739726, 'auc': 0.7250548245614036}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7576923076923077}]
______epoch 165 _____
train imitation 1.2351776361465454 reward 1.5429327487945557
val imitation 1.2451493740081787 reward 1.5613198280334473
val loss 2.806469202041626 2.8068838119506836
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.633173076923077}, {'decision': 1, 'accuracy': 0.773972602739726, 'auc': 0.7250548245614036}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7573717948717948}]
______epoch 166 _____
train imitation 1.2507483959197998 reward 1.5407251119613647
val imitation 1.2448668479919434 reward 1.5613198280334473
val loss 2.8061866760253906 2.

______epoch 184 _____
train imitation 1.231065034866333 reward 1.5418410301208496
val imitation 1.2379446029663086 reward 1.5679761171340942
val loss 2.8059206008911133 2.8031511306762695
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6365384615384615}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.7245065789473685}, {'decision': 2, 'accuracy': 0.815068493150685, 'auc': 0.7615384615384615}]
______epoch 185 _____
train imitation 1.2604572772979736 reward 1.55339777469635
val imitation 1.2374578714370728 reward 1.5679761171340942
val loss 2.805433988571167 2.8031511306762695
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6365384615384615}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.7239583333333334}, {'decision': 2, 'accuracy': 0.815068493150685, 'auc': 0.7612179487179487}]
______epoch 186 _____
train imitation 1.239485502243042 reward 1.5352487564086914
val imitation 1.2369191646575928 reward 1.5667892694473267
val loss 2.803708553314209 2.

______epoch 205 _____
train imitation 1.2323276996612549 reward 1.534374713897705
val imitation 1.2320419549942017 reward 1.5650955438613892
val loss 2.797137498855591 2.797147035598755
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.645673076923077}, {'decision': 1, 'accuracy': 0.7671232876712328, 'auc': 0.7288925438596491}, {'decision': 2, 'accuracy': 0.815068493150685, 'auc': 0.7685897435897435}]
______epoch 206 _____
train imitation 1.235102653503418 reward 1.5456551313400269
val imitation 1.2317914962768555 reward 1.5650955438613892
val loss 2.796886920928955 2.797137498855591
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6451923076923076}, {'decision': 1, 'accuracy': 0.7671232876712328, 'auc': 0.7275219298245613}, {'decision': 2, 'accuracy': 0.815068493150685, 'auc': 0.7689102564102563}]
______epoch 207 _____
train imitation 1.272595763206482 reward 1.5461760759353638
val imitation 1.2314578294754028 reward 1.5650955438613892
val loss 2.796553373336792 2.796

______epoch 225 _____
train imitation 1.2429537773132324 reward 1.5445237159729004
val imitation 1.2305107116699219 reward 1.5677517652511597
val loss 2.798262596130371 2.7956864833831787
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6528846153846155}, {'decision': 1, 'accuracy': 0.7671232876712328, 'auc': 0.7294407894736842}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7705128205128206}]
______epoch 226 _____
train imitation 1.2670388221740723 reward 1.5381250381469727
val imitation 1.2308478355407715 reward 1.5680749416351318
val loss 2.7989227771759033 2.7956864833831787
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6528846153846155}, {'decision': 1, 'accuracy': 0.7671232876712328, 'auc': 0.7302631578947368}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.771474358974359}]
______epoch 227 _____
train imitation 1.1980640888214111 reward 1.5386240482330322
val imitation 1.2312490940093994 reward 1.5677517652511597
val loss 2.799000740051269

______epoch 245 _____
train imitation 1.184584617614746 reward 1.5395187139511108
val imitation 1.2258107662200928 reward 1.567713737487793
val loss 2.7935245037078857 2.793931007385254
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6557692307692309}, {'decision': 1, 'accuracy': 0.773972602739726, 'auc': 0.731907894736842}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7721153846153846}]
______epoch 246 _____
train imitation 1.1738641262054443 reward 1.5466978549957275
val imitation 1.2253062725067139 reward 1.567713737487793
val loss 2.793020009994507 2.7935245037078857
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.65625}, {'decision': 1, 'accuracy': 0.7671232876712328, 'auc': 0.7319078947368421}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7733974358974359}]
______epoch 247 _____
train imitation 1.18866765499115 reward 1.5470819473266602
val imitation 1.224778413772583 reward 1.567713737487793
val loss 2.792492151260376 2.793020009994507
[

______epoch 267 _____
train imitation 1.2253814935684204 reward 1.5498250722885132
val imitation 1.2208244800567627 reward 1.569972038269043
val loss 2.7907965183258057 2.7903926372528076
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.666826923076923}, {'decision': 1, 'accuracy': 0.773972602739726, 'auc': 0.7341008771929824}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7753205128205128}]
______epoch 268 _____
train imitation 1.246215581893921 reward 1.544671893119812
val imitation 1.2209993600845337 reward 1.569972038269043
val loss 2.790971279144287 2.7903926372528076
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6673076923076924}, {'decision': 1, 'accuracy': 0.773972602739726, 'auc': 0.7338267543859649}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7746794871794872}]
______epoch 269 _____
train imitation 1.2209887504577637 reward 1.5418567657470703
val imitation 1.2211453914642334 reward 1.5703911781311035
val loss 2.791536569595337 2.790

______epoch 287 _____
train imitation 1.1963703632354736 reward 1.5420432090759277
val imitation 1.2257903814315796 reward 1.5715242624282837
val loss 2.7973146438598633 2.7903926372528076
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6692307692307692}, {'decision': 1, 'accuracy': 0.773972602739726, 'auc': 0.7365679824561404}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7740384615384616}]
______epoch 288 _____
train imitation 1.2164748907089233 reward 1.5491544008255005
val imitation 1.225372076034546 reward 1.5715242624282837
val loss 2.796896457672119 2.7903926372528076
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6697115384615384}, {'decision': 1, 'accuracy': 0.773972602739726, 'auc': 0.7360197368421053}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7733974358974359}]
______epoch 289 _____
train imitation 1.2043367624282837 reward 1.5488295555114746
val imitation 1.224969744682312 reward 1.5715242624282837
val loss 2.7964940071105957 2

______epoch 307 _____
train imitation 1.2289069890975952 reward 1.5439480543136597
val imitation 1.2239362001419067 reward 1.5653820037841797
val loss 2.789318084716797 2.788292407989502
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6783653846153846}, {'decision': 1, 'accuracy': 0.7671232876712328, 'auc': 0.7335526315789473}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7772435897435896}]
______epoch 308 _____
train imitation 1.15598464012146 reward 1.5434300899505615
val imitation 1.2234879732131958 reward 1.5653340816497803
val loss 2.7888221740722656 2.788292407989502
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6788461538461539}, {'decision': 1, 'accuracy': 0.7671232876712328, 'auc': 0.7324561403508772}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7772435897435896}]
______epoch 309 _____
train imitation 1.1636983156204224 reward 1.5453182458877563
val imitation 1.223294973373413 reward 1.5653340816497803
val loss 2.7886290550231934 2.

______epoch 327 _____
train imitation 1.2178138494491577 reward 1.538156509399414
val imitation 1.2129765748977661 reward 1.5615036487579346
val loss 2.7744803428649902 2.7737741470336914
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6836538461538462}, {'decision': 1, 'accuracy': 0.7534246575342466, 'auc': 0.734375}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7810897435897436}]
______epoch 328 _____
train imitation 1.1122136116027832 reward 1.5447901487350464
val imitation 1.213221549987793 reward 1.5616282224655151
val loss 2.7748498916625977 2.7737741470336914
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6831730769230769}, {'decision': 1, 'accuracy': 0.7534246575342466, 'auc': 0.7351973684210527}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7814102564102564}]
______epoch 329 _____
train imitation 1.1591808795928955 reward 1.5469942092895508
val imitation 1.2140004634857178 reward 1.5612144470214844
val loss 2.775214910507202 2.7737741

______epoch 348 _____
train imitation 1.2183024883270264 reward 1.5460684299468994
val imitation 1.2131760120391846 reward 1.5575156211853027
val loss 2.7706916332244873 2.770792007446289
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6855769230769231}, {'decision': 1, 'accuracy': 0.7602739726027398, 'auc': 0.7398574561403509}, {'decision': 2, 'accuracy': 0.8287671232876712, 'auc': 0.7814102564102564}]
______epoch 349 _____
train imitation 1.150640845298767 reward 1.5346527099609375
val imitation 1.2133232355117798 reward 1.5590181350708008
val loss 2.772341251373291 2.7706916332244873
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6855769230769231}, {'decision': 1, 'accuracy': 0.7602739726027398, 'auc': 0.7401315789473684}, {'decision': 2, 'accuracy': 0.8287671232876712, 'auc': 0.7814102564102564}]
______epoch 350 _____
train imitation 1.1936639547348022 reward 1.5453333854675293
val imitation 1.2133090496063232 reward 1.5592257976531982
val loss 2.77253484725952

______epoch 368 _____
train imitation 1.1379607915878296 reward 1.5415208339691162
val imitation 1.2219502925872803 reward 1.5644221305847168
val loss 2.786372423171997 2.7706916332244873
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6870192307692308}, {'decision': 1, 'accuracy': 0.7602739726027398, 'auc': 0.7404057017543859}, {'decision': 2, 'accuracy': 0.8287671232876712, 'auc': 0.7788461538461537}]
______epoch 369 _____
train imitation 1.148348093032837 reward 1.5427634716033936
val imitation 1.2226805686950684 reward 1.5644797086715698
val loss 2.7871603965759277 2.7706916332244873
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6870192307692308}, {'decision': 1, 'accuracy': 0.7602739726027398, 'auc': 0.7415021929824561}, {'decision': 2, 'accuracy': 0.8287671232876712, 'auc': 0.7794871794871794}]
______epoch 370 _____
train imitation 1.1549850702285767 reward 1.5454974174499512
val imitation 1.223056435585022 reward 1.5644797086715698
val loss 2.78753614425659

______epoch 388 _____
train imitation 1.156724214553833 reward 1.5474193096160889
val imitation 1.2211552858352661 reward 1.5617254972457886
val loss 2.7828807830810547 2.7706916332244873
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6870192307692308}, {'decision': 1, 'accuracy': 0.7602739726027398, 'auc': 0.7423245614035088}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7772435897435896}]
______epoch 389 _____
train imitation 1.1393070220947266 reward 1.5455015897750854
val imitation 1.2230719327926636 reward 1.5617254972457886
val loss 2.784797430038452 2.7706916332244873
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6870192307692308}, {'decision': 1, 'accuracy': 0.7602739726027398, 'auc': 0.7420504385964912}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7782051282051282}]
______epoch 390 _____
train imitation 1.126009464263916 reward 1.551119089126587
val imitation 1.224608302116394 reward 1.5617254972457886
val loss 2.7863337993621826 2

______epoch 408 _____
train imitation 1.1642241477966309 reward 1.5406465530395508
val imitation 1.230198621749878 reward 1.568003535270691
val loss 2.7982020378112793 2.7706916332244873
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6822115384615385}, {'decision': 1, 'accuracy': 0.7602739726027398, 'auc': 0.7393092105263157}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7714743589743589}]
______epoch 409 _____
train imitation 1.1457912921905518 reward 1.5463306903839111
val imitation 1.229892611503601 reward 1.568003535270691
val loss 2.797896146774292 2.7706916332244873
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6822115384615385}, {'decision': 1, 'accuracy': 0.7602739726027398, 'auc': 0.7393092105263157}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7714743589743589}]
______epoch 410 _____
train imitation 1.1276837587356567 reward 1.5446178913116455
val imitation 1.2293195724487305 reward 1.568003535270691
val loss 2.797323226928711 2.7

______epoch 428 _____
train imitation 1.1038633584976196 reward 1.545990228652954
val imitation 1.2287447452545166 reward 1.568638563156128
val loss 2.7973833084106445 2.7706916332244873
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6836538461538462}, {'decision': 1, 'accuracy': 0.7602739726027398, 'auc': 0.740953947368421}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7708333333333334}]
______epoch 429 _____
train imitation 1.1351114511489868 reward 1.5443835258483887
val imitation 1.2272142171859741 reward 1.568638563156128
val loss 2.7958526611328125 2.7706916332244873
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6841346153846154}, {'decision': 1, 'accuracy': 0.7602739726027398, 'auc': 0.7412280701754386}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.771474358974359}]
______epoch 430 _____
train imitation 1.1273794174194336 reward 1.5475362539291382
val imitation 1.2264297008514404 reward 1.568638563156128
val loss 2.7950682640075684 2.

______epoch 448 _____
train imitation 1.1397809982299805 reward 1.544212818145752
val imitation 1.230861783027649 reward 1.568291425704956
val loss 2.7991533279418945 2.7706916332244873
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6855769230769231}, {'decision': 1, 'accuracy': 0.7602739726027398, 'auc': 0.7384868421052632}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.769551282051282}]
______epoch 449 _____
train imitation 1.1461656093597412 reward 1.5468840599060059
val imitation 1.2302359342575073 reward 1.5676478147506714
val loss 2.7978837490081787 2.7706916332244873
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6855769230769231}, {'decision': 1, 'accuracy': 0.7602739726027398, 'auc': 0.7376644736842106}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7698717948717948}]
++++++++++Final+++++++++++
best tensor(2.7707, grad_fn=<AddBackward0>)
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6855769230769231}, {'decision': 1, 'accur

In [None]:
def gridsearch_decision_model(m1,m2,m3):
#     model_arglist = [
#         {
#             'hidden_layers': [100],
#             'attention_heads': [1],
#         },
#         {
#             'hidden_layers': [500],
#             'attention_heads': [1],
#         },
#         {
#             'hidden_layers': [1000],
#             'attention_heads': [1],
#         },
#         {
#             'hidden_layers': [100],
#             'attention_heads': [5],
#         },
#         {
#             'hidden_layers': [500],
#             'attention_heads': [5],
#         },
#         {
#             'hidden_layers': [1000],
#             'attention_heads': [5],
#         },
#         {
#             'hidden_layers': [100,100],
#             'attention_heads': [1,1],
#         },
#         {
#             'hidden_layers': [500,500],
#             'attention_heads': [1,1],
#         },
#         {
#             'hidden_layers': [1000,1000],
#             'attention_heads': [1,1],
#         },
#         {
#             'hidden_layers': [100,100],
#             'attention_heads': [5,5],
#         },
#         {
#             'hidden_layers': [500,500],
#             'attention_heads': [5,5],
#         },
#         {
#             'hidden_layers': [1000,1000],
#             'attention_heads': [5,5],
#         },
#         {
#             'hidden_layers': [500,500,500],
#             'attention_heads': [5,5,5]
#         }
#     ]
    model_arglist = [

        {
            'hidden_layers': [100,100],
            'attention_heads': [1,1],
        },
        {
            'hidden_layers': [50,50],
            'attention_heads': [1,1],
        },
        {
            'hidden_layers': [100,100],
            'attention_heads': [2,2],
        },
        {
            'hidden_layers': [300],
            'attention_heads': [3],
        },
        {
            'hidden_layers': [600],
            'attention_heads': [3],
        },
        {
            'hidden_layers': [300,300],
            'attention_heads': [3,3],
        },
        {
            'hidden_layers': [600,600],
            'attention_heads': [3,3],
        },
        {
            'hidden_layers': [1000],
            'attention_heads': [1],
        },
        {
            'hidden_layers': [1000],
            'attention_heads': [5],
        },
    ]
    best_loss = 100000000000
    best_metrics = {}
    best_args = {}
    best_model = None
    k = 0
    for margs in model_arglist:
        args = {k:v for k,v in margs.items()}
        for embed_size in [0,120,210]:
            #embed_size = 0 skips the firt layer that makes the sizes right
            if embed_size == 0 and args['attention_heads'][0] != 1:
                continue
            args['embed_size'] = embed_size
            for dropout in [.5,.9]:
                args['dropout'] = dropout
                for input_dropout in [.5]:
                    args['input_dropout'] = input_dropout
                    for shufflecol_chance in [.1,.5]:
                        args['shufflecol_chance'] = shufflecol_chance
                        model,m_metrics,m_loss = train_decision_model(m1,m2,m3,use_attention=True,verbose=False,**args)
                        print('done',k,m_loss)
                        print('curr best',best_loss)
                        k+=1
                        if m_loss < best_loss:
                            best_loss = m_loss
                            best_metrics  = m_metrics
                            best_model = model
                            best_args = args
                            print('_++++++++++New Best++++____')
                            print(best_loss)
                            print(best_metrics)
                            print(best_args)
                            print('___________')
                            print('++++++++')
                            print()
    print('_________')
    print('+++++++++++')
    print('best stuff',best_loss)
    print(best_metrics)
    print(best_args)
    return best_model

from Models import *
decision_model = gridsearch_decision_model(model1,model2,model3)
decision_model

  df = pd.read_csv(file)
  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


++++++++++Final+++++++++++
best tensor(2.7922, grad_fn=<AddBackward0>)
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6634615384615385}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.6461074561403509}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.728525641025641}]
done 0 tensor(2.7922, grad_fn=<AddBackward0>)
curr best 100000000000
_++++++++++New Best++++____
tensor(2.7922, grad_fn=<AddBackward0>)
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6634615384615385}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.6461074561403509}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.728525641025641}]
{'hidden_layers': [100, 100], 'attention_heads': [1, 1], 'embed_size': 0, 'dropout': 0.5, 'input_dropout': 0.5, 'shufflecol_chance': 0.1}
___________
++++++++



  df = pd.read_csv(file)
  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


++++++++++Final+++++++++++
best tensor(2.8300, grad_fn=<AddBackward0>)
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6725961538461539}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.6639254385964912}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7371794871794872}]
done 1 tensor(2.8300, grad_fn=<AddBackward0>)
curr best tensor(2.7922, grad_fn=<AddBackward0>)


  df = pd.read_csv(file)
  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


++++++++++Final+++++++++++
best tensor(2.7624, grad_fn=<AddBackward0>)
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6519230769230769}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.6787280701754387}, {'decision': 2, 'accuracy': 0.8287671232876712, 'auc': 0.7544871794871795}]
done 2 tensor(2.7624, grad_fn=<AddBackward0>)
curr best tensor(2.7922, grad_fn=<AddBackward0>)
_++++++++++New Best++++____
tensor(2.7624, grad_fn=<AddBackward0>)
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6519230769230769}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.6787280701754387}, {'decision': 2, 'accuracy': 0.8287671232876712, 'auc': 0.7544871794871795}]
{'hidden_layers': [100, 100], 'attention_heads': [1, 1], 'embed_size': 0, 'dropout': 0.9, 'input_dropout': 0.5, 'shufflecol_chance': 0.1}
___________
++++++++



  df = pd.read_csv(file)
  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


++++++++++Final+++++++++++
best tensor(2.7685, grad_fn=<AddBackward0>)
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.664423076923077}, {'decision': 1, 'accuracy': 0.773972602739726, 'auc': 0.6990131578947368}, {'decision': 2, 'accuracy': 0.8287671232876712, 'auc': 0.7730769230769231}]
done 3 tensor(2.7685, grad_fn=<AddBackward0>)
curr best tensor(2.7624, grad_fn=<AddBackward0>)


  df = pd.read_csv(file)
  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


++++++++++Final+++++++++++
best tensor(2.8643, grad_fn=<AddBackward0>)
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6596153846153846}, {'decision': 1, 'accuracy': 0.7534246575342466, 'auc': 0.7025767543859649}, {'decision': 2, 'accuracy': 0.8013698630136986, 'auc': 0.7615384615384615}]
done 4 tensor(2.8643, grad_fn=<AddBackward0>)
curr best tensor(2.7624, grad_fn=<AddBackward0>)


  df = pd.read_csv(file)
  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


++++++++++Final+++++++++++
best tensor(2.7816, grad_fn=<AddBackward0>)
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.5918269230769231}, {'decision': 1, 'accuracy': 0.7671232876712328, 'auc': 0.6866776315789473}, {'decision': 2, 'accuracy': 0.815068493150685, 'auc': 0.7685897435897436}]
done 5 tensor(2.7816, grad_fn=<AddBackward0>)
curr best tensor(2.7624, grad_fn=<AddBackward0>)


  df = pd.read_csv(file)
  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


++++++++++Final+++++++++++
best tensor(2.8555, grad_fn=<AddBackward0>)
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6475961538461538}, {'decision': 1, 'accuracy': 0.7876712328767124, 'auc': 0.6636513157894737}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7413461538461539}]
done 6 tensor(2.8555, grad_fn=<AddBackward0>)
curr best tensor(2.7624, grad_fn=<AddBackward0>)


  df = pd.read_csv(file)
  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


In [63]:
torch.save(decision_model,'../data/models/final_decision_model_' + decision_model.identifier + '.pt')
print('../data/models/final_decision_model_' + decision_model.identifier + '.pt')

../data/models/final_decision_model_statedecisions_input132_dims600_dropout0.5,0.9.pt


In [None]:
torch.save(model,'../data/models/final_transition1_model_' + model.identifier + '.pt')
torch.save(model2,'../data/models/final_transition2_model_' + model2.identifier + '.pt')
torch.save(model3,'../data/models/final_outcome_model_' + model3.identifier + '.pt')
print('../data/models/final_transition1_model_' + model.identifier + '.pt')
print('../data/models/final_transition2_model_' + model2.identifier + '.pt')
print('../data/models/final_outcome_model_' + model3.identifier + '.pt')

In [None]:
xatt = []
for att,xxdf in zip(list(attributions),xdf):
    new = pd.DataFrame(att.cpu().detach().numpy(),columns=xxdf.columns,index=xxdf.index)
    xatt.append(new)
attributions = pd.concat(xatt,axis=1)
attributions

In [None]:
attributions.sum().sort_values()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
fig = plt.subplots(1,1,figsize=(100,100))
sns.heatmap(data=attributions.T,ax=fig[1])

In [None]:
def breakup_state_models(state_model):
    #for state 1 and 2
    models = {}
    models['pd'] = lambda x: state_model(x)[0]
    models['nd'] = lambda x: state_model(x)[1]
    models['chemo'] = lambda x: state_model(x)[2]
    for i,dlt in enumerate(Const.dlt1):
        models[dlt] = lambda x: state_model(x)[3][:,i]
    return models

def breakup_outcome_models(omodel):
    models = {}
    for i,name in enumerate(Const.outcomes):
        models[name] = lambda x: omodel(x)[:,i].reshape(-1,1)
    return models

def get_all_models(m1,m2,m3):
    state1_models = breakup_state_models(m1)
    state2_models = breakup_state_models(m2)
    state3_models = breakup_outcome_models(m3)
    all_models = {}
    for i,sm in enumerate([state1_models,state2_models,state3_models]):
        for ii,m in sm.items():
            all_models[ii +  '_state' + str(i+1)] = m
    return all_models

all_models = get_all_models(model,model2,model3)
all_models

In [None]:

def get_ytrue(name,df):
    outcomes=None
    value = None
    if name == 'pd_state1':
        outcomes = df[Const.primary_disease_states]
    elif name == 'pd_state2':
        outcomes = df[Const.primary_disease_states2]
    elif name == 'nd_state1':
        outcomes = df[Const.nodal_disease_states]
    elif name == 'nd_state2':
        outcomes = df[Const.nodal_disease_states2]
    elif name == 'chemo_state1':
        outcomes = df[Const.modifications]
    elif name == 'chemo_state2':
        outcomes = df[Const.ccs]   
    if outcomes is not None:
        value = outcomes.idxmax(axis=1)
    if 'DLT' in name:
        newname = name.replace('_state', ' ').replace('1','').strip()
        value = df[newname]
    if name.replace('_state3','') in Const.outcomes:
        value = df[name.replace('_state3','')]
    if value is None:
        print(name,df.columns)
    return value

def check_impact_of_decisions(model_dict,data):
    results = []
    #todo: this is wrong fix it
    ids = []
    df = data.get_data()
    outcomedict = {step: pd.concat(data.get_intermediate_outcomes(step=step),axis=1) for step in [1,2,3]}
    for decision in Const.decisions:
        for name, model in model_dict.items():
            step = int(name[-1])
            subset0 = dataset.get_input_state(step=step,fixed={decision: 0})
            subset1 = dataset.get_input_state(step=step,fixed={decision: 1})
            outcomes = outcomedict[step]
            ids = subset0.index.values
            x0 = df_to_torch(subset0)
            x1 = df_to_torch(subset1)
            y0 = model(x0).detach().cpu().numpy()
            y1 = model(x1).detach().cpu().numpy()
            original = data.get_input_state(step=step)
            xx = df_to_torch(original)
            yy = model(xx).detach().cpu().numpy()
            ytrue = get_ytrue(name,outcomes)
            if "DLT" in name:
                y0 = y0.argmax(axis=1).reshape(-1,1)
                y1 = y1.argmax(axis=1).reshape(-1,1)
                yy = yy.argmax(axis=1).reshape(-1,1)
                change = y0 - y1
                decision_change = (y0 != y1).astype(int)
            elif y0.shape[1] == 1:
                change = y1 - y0
                decision_change = np.abs((y0 > .5).astype(int) - (y1 > .5).astype(int))
            else:
                index = np.unravel_index(np.argmax(yy, axis=1), yy.shape)
                change = (y0[index] - y1[index]).reshape(-1,1)
                decision_change =  (y0.argmax(axis=1).reshape(-1,1) != y1.argmax(axis=1).reshape(-1,1)).astype(int)
                yy = yy.argmax(axis=1).reshape(-1,1)
                y1 = y1.argmax(axis=1).reshape(-1,1)
                y0 = y0.argmax(axis=1).reshape(-1,1)
            outcome = name.replace('_state','')
            for ii,pid in enumerate(ids):
                oo = ytrue.loc[pid]
                onew = y0[ii][0]
                original_decision = df.loc[pid,decision]
                if original_decision > 0:
                    onew = y0[ii][0]
                oname = Const.name_dict.get(name)
                if oname is not None:
                    onew = oname[onew]
                entry = {'id': pid, 'decision': decision,'outcome': outcome,'original_choice': original_decision, 'original_result': oo, 'alt_result': onew, 'change': change[ii][0], 'decision_change': decision_change[ii][0]}
                results.append(entry)
    return pd.DataFrame(results)

test = check_impact_of_decisions(all_models,dataset)
test

In [None]:
data.get_data()['SD Primary 2'].sum()

In [None]:
(test[test.outcome == 'pd2'].original_result == 'SD Primary 2').sum()

In [None]:
check_impact_of_decisions(all_models,dataset).to_csv('../data/decision_impacts.csv')