In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
import pandas as pd
import torch
import re
from sklearn.metrics import balanced_accuracy_score, roc_auc_score,accuracy_score
from Constants import *
from Preprocessing import *
from Models import *
pd.set_option('display.max_rows', 200)



  from .autonotebook import tqdm as notebook_tqdm


In [4]:
model_dir = Const.data_dir + '/models/'

tuned_transition_models = [
    'final_transition1_model_state1_input63_dims500,500_dropout0.25,0.5.pt',
    'final_transition2_model_state2_input85_dims100_dropout0.25,0.pt',
    'final_outcome_model_state1_input83_dims1000_dropout0,0.pt'
]
tuned_transition_models = [model_dir + f for f in tuned_transition_models]
Const.tuned_transition_models = tuned_transition_models
Const.tuned_decision_model = model_dir +  'final_decision_model_statedecisions_input132_dims100,100_dropout0.1,0.7.pt'

In [5]:
def get_dt_ids():
    df = load_digital_twin()
    return df.id.values

def get_tt_split(ids=None,use_default_split=True,use_bagging_split=False,resample_training=False):
        if ids is None:
            ids = get_dt_ids()
        #pre-made, stratified by decision and outcome 72:28
        if use_default_split:
            train_ids = Const.stratified_train_ids[:]
            test_ids = Const.stratified_test_ids[:]
        elif use_bagging_split:
            train_ids = np.random.choice(ids,len(ids),replace=True)
            test_ids = [i for i in ids if i not in train_ids]
        else:
            test_ids = ids[0: int(len(ids)*(1-split))]
            train_ids = [i for i in ids if i not in test_ids]

        if resample_training:
            train_ids = np.random.choice(train_ids,len(train_ids),replace=True)
            test_ids = [i for i in ids if i not in train_ids]
        return train_ids,test_ids
    
def df_to_torch(df,ttype  = torch.FloatTensor):
    values = df.values.astype(float)
    values = torch.from_numpy(values)
    return values.type(ttype)

In [6]:
def nllloss(ytrue,ypred):
    #nll loss with argmax added in
    loss = torch.nn.NLLLoss()
    return loss(ypred,ytrue.argmax(axis=1))

def state_loss(ytrue,ypred,weights=[1,1,1,1]):
    pd_loss = nllloss(ytrue[0],ypred[0])*weights[0]
    nd_loss = nllloss(ytrue[1],ypred[1])*weights[1]
    mod_loss = nllloss(ytrue[2],ypred[2])*weights[2]
    loss = pd_loss + nd_loss + mod_loss
    dlt_true = ytrue[3]
    dlt_pred = ypred[3]
    ndlt = dlt_true.shape[1]
#     nloss = torch.nn.NLLLoss()
    bce = torch.nn.BCELoss()
    for i in range(ndlt):
        dlt_loss = bce(dlt_pred[:,i].view(-1),dlt_true[:,i].view(-1))
        loss += dlt_loss*weights[3]/ndlt
    return loss

def outcome_loss(ytrue,ypred,weights=[1,1,1]):
    loss = 0
    nloss = torch.nn.BCELoss()
    for i in range(len(weights)):
        iloss = nloss(ypred[:,i],ytrue[i])*weights[i]
        loss += iloss
    return loss

def mc_metrics(yt,yp,numpy=False,is_dlt=False):
    if not numpy:
        yt = yt .cpu().detach().numpy()
        yp = yp.cpu().detach().numpy()
    #dlt prediction (binary)
    if is_dlt:
        acc = accuracy_score(yt,yp>.5)
        if yt.sum() > 1:
            auc = roc_auc_score(yt,yp)
        else:
            auc=-1
        error = np.mean((yt-yp)**2)
        return {'accuracy': acc, 'mse': error, 'auc': auc}
    #this is a catch for when I se the dlt prediction format (encoded integer ordinal, predict as a categorical and take the argmax)
    elif yt.ndim > 1:
        try:
            bacc = balanced_accuracy_score(yt.argmax(axis=1),yp.argmax(axis=1))
        except:
            bacc = -1
        try:
            roc_micro = roc_auc_score(yt,yp,average='micro')
        except:
            roc_micro=-1
        try:
            roc_macro = roc_auc_score(yt,yp,average='macro')
        except:
            roc_macro = -1
        return {'accuracy': bacc, 'roc_micro': roc_micro,'roc_macro': roc_macro}
    #outcomes (binary)
    else:
        if yp.ndim > 1:
            yp = yp.argmax(axis=1)
        try:
            bacc = accuracy_score(yt,yp)
        except:
            bacc = -1
        try:
            roc = roc_auc_score(yt,yp)
        except:
            roc = -1
        error = np.mean((yt-yp)**2)
        return {'accuracy': bacc, 'mse': error, 'auc': roc}

def state_metrics(ytrue,ypred,numpy=False):
    pd_metrics = mc_metrics(ytrue[0],ypred[0],numpy=numpy)
    nd_metrics = mc_metrics(ytrue[1],ypred[1],numpy=numpy)
    mod_metrics = mc_metrics(ytrue[1],ypred[1],numpy=numpy)
    
    dlt_metrics = []
    dlt_true = ytrue[3]
    dlt_pred = ypred[3]
    ndlt = dlt_true.shape[1]
    nloss = torch.nn.NLLLoss()
    for i in range(ndlt):
        dm = mc_metrics(dlt_true[:,i],dlt_pred[:,i].view(-1),is_dlt=True)
        dlt_metrics.append(dm)
    dlt_acc =[d['accuracy'] for d in dlt_metrics]
    dlt_error = [d['mse'] for d in dlt_metrics]
    dlt_auc = [d['auc'] for d in dlt_metrics]
    return {'pd': pd_metrics,'nd': nd_metrics,'mod': mod_metrics,'dlts': {'accuracy': dlt_acc,'accuracy_mean': np.mean(dlt_acc),'auc': dlt_auc,'auc_mean': np.mean(dlt_auc)}}
    
def outcome_metrics(ytrue,ypred,numpy=False):
    res = {}
    for i, outcome in enumerate(Const.outcomes):
        metrics = mc_metrics(ytrue[i],ypred[:,i])
        res[outcome] = metrics
    return res


In [7]:
from sklearn.ensemble import RandomForestClassifier

def train_state_rf(model_args={}):
    ids = get_dt_ids()
    
    dataset = DTDataset()

    train_ids = ids[0:int(len(ids)*.7)]
    test_ids = ids[int(len(ids)*.7):]
    
    #most things are multiclass, dlts are several ordinal and outcomes are multiple binary
    xtrain1 = dataset.get_state('baseline',ids=train_ids)
    xtest1 = dataset.get_state('baseline',ids=test_ids)
    
    xtrain2 = dataset.get_input_state(step=2,ids=train_ids)
    xtest2 = dataset.get_input_state(step=2,ids=test_ids)
    
    xtrain3 = dataset.get_input_state(step=3,ids=train_ids)
    xtest3 = dataset.get_input_state(step=3,ids=test_ids)
    
    [pd1_train,nd1_train, mod_train,dlts1_train] = dataset.get_intermediate_outcomes(ids=train_ids)
    [pd2_train,nd2_train, cc_train,dlts2_train] = dataset.get_intermediate_outcomes(step=2,ids=train_ids)
    [pd1_test,nd1_test, mod_test,dlts1_test] = dataset.get_intermediate_outcomes(ids=test_ids)
    [pd2_test,nd2_test, cc_test,dlts2_test] = dataset.get_intermediate_outcomes(step=2,ids=test_ids)
    outcomes_train = dataset.get_state('outcomes',ids=train_ids)
    outcomes_test = dataset.get_state('outcomes',ids=test_ids)
    

    def train_multiclass_rf(xtrain,xtest,ytrain,ytest):
        model = RandomForestClassifier(class_weight='balanced',**model_args).fit(xtrain,ytrain)
        ypred = model.predict(xtest)
        metrics = mc_metrics(ytest.values,ypred,numpy=True)
        return model, metrics
    
    all_metrics = {}
    pd1_model, all_metrics['pd1'] = train_multiclass_rf(xtrain1,xtest1,pd1_train,pd1_test)
    nd1_model, all_metrics['nd1']  = train_multiclass_rf(xtrain1,xtest1,nd1_train,nd1_test)
    mod_model, all_metrics['mod']  = train_multiclass_rf(xtrain1,xtest1,mod_train,mod_test)
    
    return all_metrics

train_state_rf({'max_depth': 5,'n_estimators': 100})

  df = pd.read_csv(file)
  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


{'pd1': {'accuracy': 0.34712441314553993,
  'roc_micro': 0.5057627557627558,
  'roc_macro': 0.5069683908045977},
 'nd1': {'accuracy': 0.38253241800152554,
  'roc_micro': 0.5694120694120695,
  'roc_macro': 0.5231884057971015},
 'mod': {'accuracy': 0.16666666666666666,
  'roc_micro': 0.5093167701863354,
  'roc_macro': -1}}

In [8]:
 DTDataset().get_state('baseline').T

  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


id,3,5,6,7,8,9,10,11,13,14,...,10196,10197,10198,10199,10200,10201,10202,10203,10204,10205
1A,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1A1B,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1A6,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1B,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1B2A,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1B3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2A,1.0,0.0,1.0,1.0,1.0,1.0,0.0,1.0,1.0,2.0,...,0.0,1.0,1.0,1.0,1.0,0.0,1.0,1.0,1.0,1.0
2A2B,1.0,0.0,1.0,1.0,1.0,1.0,0.0,1.0,1.0,2.0,...,0.0,1.0,1.0,1.0,1.0,0.0,1.0,1.0,1.0,1.0
2A3,1.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,1.0,2.0,...,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
2B,1.0,0.0,1.0,1.0,1.0,1.0,0.0,1.0,1.0,2.0,...,0.0,1.0,1.0,1.0,1.0,0.0,1.0,1.0,1.0,1.0


In [9]:
def train_state(model=None,
                model_args={},
                state=1,
                split=.7,
                lr=.0001,
                epochs=1000,
                patience=10,
                use_attention=True,
                weights=[1,1,1,10],
                save_path='../data/models/',
                use_default_split=True,
                use_bagging_split=False,
                resample_training=False,#use bootstraping on training data after splitting
                n_validation_trainsteps=2,
                verbose=True,
                file_suffix=''):
    
    ids = get_dt_ids()
    
    train_ids, test_ids = get_tt_split(use_default_split=use_default_split,use_bagging_split=use_bagging_split,resample_training=resample_training)
    
    dataset = DTDataset()
    
    xtrain = dataset.get_input_state(step=state,ids=train_ids)
    xtest = dataset.get_input_state(step=state,ids=test_ids)
    ytrain = dataset.get_intermediate_outcomes(step=state,ids=train_ids)
    ytest = dataset.get_intermediate_outcomes(step=state,ids=test_ids)
    

    if state < 3:
        if model is None:
            if use_attention:
                model = OutcomeAttentionSimulator(xtrain.shape[1],state=state,**model_args)
            else:
                model = OutcomeSimulator(xtrain.shape[1],state=state,**model_args)
        lfunc = state_loss
    else:
        if model is None:
            if use_attention:
                model = EndpointAttentionSimulator(xtrain.shape[1],**model_args)
            else:
                model = EndpointSimulator(xtrain.shape[1],**model_args)
        weights = weights[:3]
        lfunc = outcome_loss
        
    hashcode = str(hash(','.join([str(i) for i in train_ids])))
    save_file = save_path + 'model_' + model.identifier + '_split' + str(split) + '_resample' + str(resample_training) +  '_hash' + hashcode + file_suffix + '.tar'
    xtrain = df_to_torch(xtrain)
    xtest = df_to_torch(xtest)
    ytrain = [df_to_torch(t) for t in ytrain]
    ytest= [df_to_torch(t) for t in ytest]
    
    model.fit_normalizer(xtrain)
#     normalize = lambda x: (x - xtrain.mean(axis=0)+.01)/(xtrain.std(axis=0)+.01)
#     unnormalize = lambda x: (x * (xtrain.std(axis=0) +.01)) + xtrain.mean(axis=0) - .01
    
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)
    best_val_loss = 1000000000000000000000000000
    best_loss_metrics = {}
    last_epoch = False
    for epoch in range(epochs):
        
        model.train(True)
        optimizer.zero_grad()
        
        xtrain_sample = xtrain#[torch.randint(len(xtrain),(len(xtrain),) )]
        ypred = model(xtrain_sample)
        loss = lfunc(ytrain,ypred,weights=weights)

        loss.backward()
        optimizer.step()
        if verbose:
            print('epoch',epoch,'train loss',loss.item())
        
        model.eval()
        yval = model(xtest)
        val_loss = lfunc(ytest,yval,weights=weights)
        if state < 3:
            val_metrics = state_metrics(ytest,yval)
        else:
            val_metrics = outcome_metrics(ytest,yval)
        if val_loss.item() < best_val_loss:
            best_val_loss = val_loss.item()
            best_loss_metrics = val_metrics
            steps_since_improvement = 0
            torch.save(model.state_dict(),save_file)
        else:
            steps_since_improvement += 1
        if verbose:
            print('val loss',val_loss.item())
            print('______________')
        if steps_since_improvement > patience:
            break
    print('best loss',best_val_loss,best_loss_metrics)
    model.load_state_dict(torch.load(save_file))
    
    #train one step on validation data
    for i in range(n_validation_trainsteps):
        model.train()
        yval = model(xtest)
        val_loss = lfunc(ytest,yval,weights=weights)
        val_loss.backward()
        optimizer.step()
        torch.save(model.state_dict(),save_file)
    
    model.eval()
    return model,  best_val_loss, best_loss_metrics

In [10]:



def gridsearch_transition_models(state=1):
#     model_arglist = [
#         {
#             'hidden_layers': [100],
#             'attention_heads': [1],
#         },
#         {
#             'hidden_layers': [500],
#             'attention_heads': [1],
#         },
#         {
#             'hidden_layers': [1000],
#             'attention_heads': [1],
#         },
#         {
#             'hidden_layers': [100],
#             'attention_heads': [5],
#         },
#         {
#             'hidden_layers': [500],
#             'attention_heads': [5],
#         },
#         {
#             'hidden_layers': [1000],
#             'attention_heads': [5],
#         },
#         {
#             'hidden_layers': [100,100],
#             'attention_heads': [1,1],
#         },
#         {
#             'hidden_layers': [500,500],
#             'attention_heads': [1,1],
#         },
#         {
#             'hidden_layers': [1000,1000],
#             'attention_heads': [1,1],
#         },
#         {
#             'hidden_layers': [100,100],
#             'attention_heads': [5,5],
#         },
#         {
#             'hidden_layers': [500,500],
#             'attention_heads': [5,5],
#         },
#         {
#             'hidden_layers': [1000,1000],
#             'attention_heads': [5,5],
#         },
#         {
#             'hidden_layers': [500,500,500],
#             'attention_heads': [5,5,5]
#         }
#     ]
    model_arglist = [
        {
            'hidden_layers': [100],
            'attention_heads': [10],
        },
        {
            'hidden_layers': [500],
            'attention_heads': [10],
        },
        {
            'hidden_layers': [1000],
            'attention_heads': [10],
        },
        {
            'hidden_layers': [100],
            'attention_heads': [5],
        },
        {
            'hidden_layers': [500],
            'attention_heads': [5],
        },
        {
            'hidden_layers': [1000],
            'attention_heads': [5],
        },
        {
            'hidden_layers': [100,100],
            'attention_heads': [10,10],
        },
        {
            'hidden_layers': [500,500],
            'attention_heads': [10,10],
        },
        {
            'hidden_layers': [100,100],
            'attention_heads': [5,5],
        },
        {
            'hidden_layers': [500,500],
            'attention_heads': [5,5],
        },
    ]
    best_loss = 100000000000
    best_metrics = {}
    best_args = {}
    best_model = None
    k = 0
    for margs in model_arglist:
        args = {k:v for k,v in margs.items()}
        for embed_size in [200,400,800]:
            args['embed_size'] = embed_size
            for dropout in [.9,.95]:
                args['dropout'] = dropout
                for input_dropout in [.25,.35,.5]:
                    args['input_dropout'] = input_dropout
                    model,m_loss,m_metrics = train_state(model_args=args,state=state,verbose=False)
                    print('done',k,m_loss)
                    k+=1
                    if m_loss < best_loss:
                        best_loss = m_loss
                        best_metrics  = m_metrics
                        best_model = model
                        best_args = args
                        print('_++++++++++New Best++++____')
                        print(best_loss)
                        print(best_metrics)
                        print(best_args)
                        print('___________')
                        print('++++++++')
                        print()
    print('_________')
    print('+++++++++++')
    print('best stuff',best_loss)
    print(best_metrics)
    print(best_args)
    return best_model
# model = gridsearch_transition_models(1)

In [11]:
# model2 = gridsearch_transition_models(2)

In [12]:
# model3 = gridsearch_transition_models(3)

In [14]:
Const.tuned_transition_models

['../data/models/final_transition1_model_state1_input63_dims500,500_dropout0.25,0.5.pt',
 '../data/models/final_transition2_model_state2_input85_dims100_dropout0.25,0.pt',
 '../data/models/final_outcome_model_state1_input83_dims1000_dropout0,0.pt']

In [16]:
def load_trained_models():
    files = Const.tuned_transition_models
    decision_file = Const.tuned_decision_model
    [model1,model2,model3] = [torch.load(file) for file in files]
    decision_model = torch.load(decision_file)
    return decision_model, model1,model2,model3
_, model1, model2, model3 =load_trained_models()

In [59]:

def shuffle_col(v,col=None):
    if col is None:
        col = np.random.choice([i for i in range(v.shape[1])])
    idx = torch.randperm(v.shape[0])
    vv = torch.clone(v)
    vv[:,col] = vv[idx,col]
    return vv
    
def train_decision_model(
    tmodel1,
    tmodel2,
    tmodel3,
    use_default_split=True,
    use_bagging_split=False,
    lr=.0001,
    epochs=10000,
    patience=100,
    weights=[1,1,1], #realtive weight of survival, feeding tube, and aspiration
    imitation_weight=1,
    shufflecol_chance = 0.1,
    reward_weight=10,
    split=.7,
    resample_training=False,
    save_path='../data/models/',
    file_suffix='',
    use_attention=False,
    verbose=True,
    **model_kwargs,
):
    
    tmodel1.eval()
    tmodel2.eval()
    tmodel3.eval()
    
    train_ids, test_ids = get_tt_split(use_default_split=use_default_split,use_bagging_split=use_bagging_split,resample_training=resample_training)
    
    dataset = DTDataset()
    
    data = dataset.processed_df.copy()
    
    def get_dlt(state):
        if state == 2:
            return data[Const.dlt2].copy()
        d = data[Const.dlt1].copy()
        if state < 1:
            d.values[:,:] = 0
        return d
    
    def get_pd(state):
        if state == 2:
            return data[Const.primary_disease_states2].copy()
        d = data[Const.primary_disease_states].copy()
        if state < 1:
            d.values[:,:] = 0
        return d
    
    def get_nd(state):
        if state == 2:
            return data[Const.nodal_disease_states2].copy()
        d = data[Const.nodal_disease_states].copy()
        if state < 1:
            d.values[:,:] = 0
        return d
    
    def get_cc(state):
        res = data[Const.ccs].copy()
        if state == 1:
            res.values[:,:] = np.zeros(res.values.shape)
        return res
    
    def get_mod(state):
        res = data[Const.modifications].copy()
        return res
        
    outcomedf = data[Const.outcomes]
    baseline = dataset.get_state('baseline')
    
    def formatdf(d,dids=train_ids):
        d = df_to_torch(d.loc[dids])
        return d
    
    def makegrad(v):
        if not v.requires_grad:
            v.requires_grad=True
        return v
    
    if use_attention:
        model = DecisionAttentionModel(baseline.shape[1],**model_kwargs)
    else:
        model = DecisionModel(baseline.shape[1],**model_kwargs)

    hashcode = str(hash(','.join([str(i) for i in train_ids])))
    
    save_file = save_path + 'model_' + model.identifier +'_hash' + hashcode + file_suffix + '.tar'
    model.fit_normalizer(df_to_torch(baseline.loc[train_ids]))
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)

    def outcome_loss(ypred):
        #convert survival to death
        loss = torch.mul(torch.mean(-1*(ypred[:,0] - 1)),weights[0])
        for i,weight in enumerate(weights[1:]):
            newloss = torch.mean(ypred[:,i])*weight
            loss = torch.add(loss,torch.mul(newloss,weight))
        return loss
    
    mse = torch.nn.MSELoss()
    nllloss = torch.nn.NLLLoss()
    bce = torch.nn.BCELoss()
    
    def compare_decisions(d1,d2,d3,ids):
#         ypred = np.concatenate([dd.cpu().detach().numpy().reshape(-1,1) for dd in [d1,d2,d3]],axis=1)
        ytrue = df_to_torch(outcomedf.loc[ids])
        dloss = bce(d1.view(-1),ytrue[:,0])
        dloss += bce(d2.view(-1),ytrue[:,1])
        dloss += bce(d3,view(-1),ytrue[:,2])
        return dloss
        
    def remove_decisions(df):
        cols = [c for c in df.columns if c not in Const.decisions ]
        ddf = df[cols]
        return ddf
    
    makeinput = lambda step,dids: df_to_torch(remove_decisions(dataset.get_input_state(step=step,ids=dids)))
    threshold = lambda x: torch.gt(x,.5).type(torch.FloatTensor)
    def step(train=True):
        if train:
            model.train(True)
            optimizer.zero_grad()
            ids = train_ids
        else:
            ids = test_ids
            model.eval()
            
            
        ytrain = df_to_torch(outcomedf.loc[ids])
        #imitation losses and decision 1
        xxtrain = [baseline, get_dlt(0),get_dlt(0),get_pd(0),get_nd(0),get_cc(0),get_mod(0)]
        xxtrain = [formatdf(xx,ids) for xx in xxtrain]
        o1 = model(torch.cat(xxtrain,axis=1),position=0)
        decision1_imitation = o1[:,3]
        decision1 = o1[:,0]
        
#         imitation_loss1 = bce(threshold(decision1_imitation),ytrain[:,0])
        imitation_loss1 = bce(decision1_imitation,ytrain[:,0])
        
        x1_imitation = [baseline, get_dlt(1),get_dlt(0),get_pd(1),get_nd(1),get_cc(1),get_mod(1)]
        x1_imitation = [formatdf(xx1,ids) for xx1 in x1_imitation]
        decision2_imitation = model(torch.cat(x1_imitation,axis=1),position=1)[:,4]
        
#         imitation_loss2 =  bce(threshold(decision2_imitation),ytrain[:,1])
        imitation_loss2 =  bce(decision2_imitation,ytrain[:,1])
        
        x2_imitation = [baseline, get_dlt(1),get_dlt(2),get_pd(2),get_nd(2),get_cc(2),get_mod(2)]
        x2_imitation = [formatdf(xx2,ids) for xx2 in x2_imitation]
        decision3_imitation = model(torch.cat(x2_imitation,axis=1),position=2)[:,5]
        
#         imitation_loss3 = bce(threshold(decision3_imitation),ytrain[:,2])
        imitation_loss3 = bce(decision3_imitation,ytrain[:,2])
        
        #reward decisions
        xx1 = makeinput(1,ids)
        xx2 = makeinput(2,ids)
        xx3 = makeinput(3,ids)

        baseline_train_base = formatdf(baseline,ids)
            
        baseline_train = torch.clone(baseline_train_base)
        if train and shufflecol_chance > 0.0001:
            for col in range(baseline_train_base.shape[1]): 
                if np.random.random() < shufflecol_chance:
                    baseline_train = shuffle_col(baseline_train,col)
                    
        xi1 = torch.cat([xx1,decision1.view(-1,1)],axis=1)
        [ypd1, ynd1, ymod, ydlt1] = tmodel1(xi1)
        #this outputs log likelihoods (except for dlts) -> convert to probability
        ypd1 = torch.exp(ypd1)
        ynd1 = torch.exp(ynd1)
        ymod = torch.exp(ymod)
        x1 = [baseline_train,ydlt1,formatdf(get_dlt(0),ids),ypd1,ynd1,formatdf(get_cc(1),ids),ymod]
        
        decision2 = model(torch.cat(x1,axis=1),position=1)[:,1] 
        
        xi2 = torch.cat([xx2,decision1.view(-1,1),decision2.view(-1,1)],axis=1)
        [ypd2,ynd2,ycc,ydlt2] = tmodel2(xi2)
        ypd2 = torch.exp(ypd2)
        ynd2 = torch.exp(ynd2)
        ycc = torch.exp(ycc)
        x2 = [baseline_train,ydlt1,ydlt2,ypd2,ynd2,ycc,ymod]
            
        decision3 = model(torch.cat(x2,axis=1),position=2)[:,2]
        
        decision1 = threshold(decision1)
        decision2 = threshold(decision2)
        decision3 = threshold(decision3)
        
        xi3 = torch.cat([xx3,decision1.view(-1,1),decision2.view(-1,1),decision3.view(-1,1)],axis=1)
        outcomes = tmodel3(xi3)

        reward_loss = outcome_loss(outcomes)
        loss = torch.add(imitation_loss1,imitation_loss2)
        loss = torch.add(loss,imitation_loss3)
        loss = torch.mul(loss,imitation_weight/3)
        loss = torch.add(loss,torch.mul(reward_loss,reward_weight))
        losses = [imitation_loss1+imitation_loss2+imitation_loss3,reward_loss]
        if train:
            loss.backward()
            optimizer.step()
            return losses
        else:
            scores = []
            for i,decision in enumerate([decision1_imitation,decision2_imitation,decision3_imitation]):
                dec = decision.cpu().detach().numpy()
                dec0 = (dec > .5).astype(int)
                out = ytrain[:,i].cpu().detach().numpy()
                acc = accuracy_score(out,dec > .5)
                auc = roc_auc_score(out,dec)
                scores.append({'decision': i,'accuracy': acc,'auc': auc})
            return losses, scores
        
    best_val_loss = torch.tensor(1000000000.0)
    steps_since_improvement = 0
    best_val_score = {}
    for epoch in range(epochs):
        losses = step(True)
        val_losses,val_metrics = step(False)
        vl = val_losses[0] + val_losses[1]
        if verbose:
            print('______epoch',str(epoch),'_____')
            print('train imitation',losses[0].item(),'reward',losses[1].item())
            print('val imitation',val_losses[0].item(),'reward',val_losses[1].item())
            print('val loss',vl.item(),best_val_loss.item())
            print(val_metrics)
        if vl < best_val_loss:
            best_val_loss = vl
            best_val_score = val_metrics
            steps_since_improvement = 0
            torch.save(model.state_dict(),save_file)
        else:
            steps_since_improvement += 1
        if steps_since_improvement > patience:
            break
    print('++++++++++Final+++++++++++')
    print('best',best_val_loss)
    print(best_val_score)
    model.load_state_dict(torch.load(save_file))
    model.eval()
    return model, best_val_score, best_val_loss

from Models import *
args = {
    'hidden_layers': [600], 
    'attention_heads': [3], 
    'embed_size': 210, 
    'dropout': 0.9, 
    'input_dropout': 0.5, 
    'shufflecol_chance': 0.1,
}
decision_model, _, _ = train_decision_model(model1,model2,model3,lr=.0001,use_attention=True,**args)

  df = pd.read_csv(file)
  df = pd.read_csv(data_file)
  self.means = self.processed_df.mean(axis=0)
  self.stds = self.processed_df.std(axis=0)


______epoch 0 _____
train imitation 2.161153554916382 reward 1.1807951927185059
val imitation 1.9953837394714355 reward 1.184358835220337
val loss 3.1797425746917725 1000000000.0
[{'decision': 0, 'accuracy': 0.6232876712328768, 'auc': 0.36875}, {'decision': 1, 'accuracy': 0.6506849315068494, 'auc': 0.43229166666666663}, {'decision': 2, 'accuracy': 0.678082191780822, 'auc': 0.5198717948717948}]
______epoch 1 _____
train imitation 2.119168996810913 reward 1.1801512241363525
val imitation 1.9548444747924805 reward 1.1844807863235474
val loss 3.1393251419067383 3.1797425746917725
[{'decision': 0, 'accuracy': 0.726027397260274, 'auc': 0.37115384615384617}, {'decision': 1, 'accuracy': 0.7191780821917808, 'auc': 0.4413377192982456}, {'decision': 2, 'accuracy': 0.7328767123287672, 'auc': 0.5266025641025641}]
______epoch 2 _____
train imitation 2.083787679672241 reward 1.1801612377166748
val imitation 1.9162837266921997 reward 1.1845741271972656
val loss 3.100857734680176 3.1393251419067383
[{'

______epoch 21 _____
train imitation 1.6016905307769775 reward 1.1839709281921387
val imitation 1.4559180736541748 reward 1.184873342514038
val loss 2.640791416168213 2.6547584533691406
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.39278846153846153}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.6014254385964912}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.6580128205128205}]
______epoch 22 _____
train imitation 1.588714361190796 reward 1.1800607442855835
val imitation 1.4429160356521606 reward 1.1848335266113281
val loss 2.627749443054199 2.640791416168213
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.39278846153846153}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.6052631578947368}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.6673076923076924}]
______epoch 23 _____
train imitation 1.5808918476104736 reward 1.1814935207366943
val imitation 1.430832862854004 reward 1.1847141981124878
val loss 2.6155471801757812 2.6

______epoch 42 _____
train imitation 1.4334405660629272 reward 1.1813918352127075
val imitation 1.3269062042236328 reward 1.1866122484207153
val loss 2.5135183334350586 2.515374183654785
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.44759615384615387}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.6674890350877193}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7528846153846154}]
______epoch 43 _____
train imitation 1.4425970315933228 reward 1.182062029838562
val imitation 1.325310468673706 reward 1.1866122484207153
val loss 2.511922836303711 2.5135183334350586
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.451923076923077}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.6702302631578947}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.757051282051282}]
______epoch 44 _____
train imitation 1.4762611389160156 reward 1.1827021837234497
val imitation 1.3238191604614258 reward 1.1866122484207153
val loss 2.5104312896728516 2.51

______epoch 62 _____
train imitation 1.3845112323760986 reward 1.180848479270935
val imitation 1.3032461404800415 reward 1.1889135837554932
val loss 2.492159843444824 2.4931397438049316
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.5278846153846154}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.7064144736842105}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7826923076923078}]
______epoch 63 _____
train imitation 1.4348039627075195 reward 1.1819158792495728
val imitation 1.3021972179412842 reward 1.1884677410125732
val loss 2.4906649589538574 2.492159843444824
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.5322115384615385}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.7077850877192983}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7817307692307692}]
______epoch 64 _____
train imitation 1.381913185119629 reward 1.1810784339904785
val imitation 1.3010834455490112 reward 1.188529133796692
val loss 2.489612579345703 2.490

______epoch 82 _____
train imitation 1.392754077911377 reward 1.1798515319824219
val imitation 1.2845208644866943 reward 1.1903281211853027
val loss 2.474848985671997 2.475585699081421
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.5716346153846154}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.7203947368421053}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7794871794871795}]
______epoch 83 _____
train imitation 1.3786760568618774 reward 1.1827833652496338
val imitation 1.283811330795288 reward 1.190372347831726
val loss 2.4741835594177246 2.474848985671997
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.5716346153846154}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.7201206140350878}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7807692307692308}]
______epoch 84 _____
train imitation 1.4199111461639404 reward 1.1807448863983154
val imitation 1.2831017971038818 reward 1.1902741193771362
val loss 2.4733757972717285 2.474

______epoch 102 _____
train imitation 1.3492770195007324 reward 1.1804249286651611
val imitation 1.2734383344650269 reward 1.189482569694519
val loss 2.462920904159546 2.46356201171875
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6067307692307693}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.7239583333333333}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7746794871794872}]
______epoch 103 _____
train imitation 1.3360636234283447 reward 1.184448480606079
val imitation 1.2727577686309814 reward 1.1893690824508667
val loss 2.4621267318725586 2.462920904159546
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6086538461538462}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.7242324561403508}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7746794871794872}]
______epoch 104 _____
train imitation 1.3587414026260376 reward 1.182149052619934
val imitation 1.2720887660980225 reward 1.1893690824508667
val loss 2.4614577293395996 2.4

______epoch 122 _____
train imitation 1.3353101015090942 reward 1.1837143898010254
val imitation 1.2614336013793945 reward 1.1888949871063232
val loss 2.4503285884857178 2.4512290954589844
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6375000000000001}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.7225877192982456}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.771474358974359}]
______epoch 123 _____
train imitation 1.3172979354858398 reward 1.183728814125061
val imitation 1.2610281705856323 reward 1.1888731718063354
val loss 2.4499013423919678 2.4503285884857178
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6394230769230769}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.7225877192982457}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7708333333333334}]
______epoch 124 _____
train imitation 1.2882364988327026 reward 1.1834731101989746
val imitation 1.2606478929519653 reward 1.1888731718063354
val loss 2.449521064758301

______epoch 142 _____
train imitation 1.308988094329834 reward 1.1827898025512695
val imitation 1.2518820762634277 reward 1.188625693321228
val loss 2.4405078887939453 2.441009759902954
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6423076923076922}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.7198464912280702}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7647435897435897}]
______epoch 143 _____
train imitation 1.2897448539733887 reward 1.1842409372329712
val imitation 1.2514846324920654 reward 1.188625693321228
val loss 2.440110206604004 2.4405078887939453
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6442307692307692}, {'decision': 1, 'accuracy': 0.7808219178082192, 'auc': 0.7190241228070176}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7650641025641026}]
______epoch 144 _____
train imitation 1.2968952655792236 reward 1.1830050945281982
val imitation 1.2510011196136475 reward 1.1885663270950317
val loss 2.4395675659179688 2

______epoch 163 _____
train imitation 1.254044771194458 reward 1.1838613748550415
val imitation 1.2435396909713745 reward 1.1884181499481201
val loss 2.431957721710205 2.4328200817108154
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6427884615384616}, {'decision': 1, 'accuracy': 0.773972602739726, 'auc': 0.7160087719298246}, {'decision': 2, 'accuracy': 0.8287671232876712, 'auc': 0.7628205128205128}]
______epoch 164 _____
train imitation 1.2578409910202026 reward 1.1831963062286377
val imitation 1.2432926893234253 reward 1.1884151697158813
val loss 2.4317078590393066 2.431957721710205
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.64375}, {'decision': 1, 'accuracy': 0.773972602739726, 'auc': 0.7168311403508771}, {'decision': 2, 'accuracy': 0.8287671232876712, 'auc': 0.7615384615384615}]
______epoch 165 _____
train imitation 1.2932919263839722 reward 1.1822419166564941
val imitation 1.2429471015930176 reward 1.188717246055603
val loss 2.43166446685791 2.43170785903

______epoch 184 _____
train imitation 1.253989577293396 reward 1.1804355382919312
val imitation 1.2441534996032715 reward 1.1887258291244507
val loss 2.4328794479370117 2.430473804473877
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6461538461538461}, {'decision': 1, 'accuracy': 0.773972602739726, 'auc': 0.715734649122807}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.760576923076923}]
______epoch 185 _____
train imitation 1.2083740234375 reward 1.1802603006362915
val imitation 1.2442569732666016 reward 1.188940405845642
val loss 2.433197498321533 2.430473804473877
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6466346153846153}, {'decision': 1, 'accuracy': 0.773972602739726, 'auc': 0.7154605263157894}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7602564102564102}]
______epoch 186 _____
train imitation 1.2683528661727905 reward 1.1815779209136963
val imitation 1.2444853782653809 reward 1.1891231536865234
val loss 2.4336085319519043 2.430473

______epoch 204 _____
train imitation 1.2465823888778687 reward 1.1798620223999023
val imitation 1.2495278120040894 reward 1.1898938417434692
val loss 2.4394216537475586 2.430473804473877
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6475961538461539}, {'decision': 1, 'accuracy': 0.7671232876712328, 'auc': 0.7132675438596491}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7583333333333333}]
______epoch 205 _____
train imitation 1.220320701599121 reward 1.1817084550857544
val imitation 1.2501115798950195 reward 1.1898938417434692
val loss 2.440005302429199 2.430473804473877
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6495192307692308}, {'decision': 1, 'accuracy': 0.7671232876712328, 'auc': 0.7127192982456141}, {'decision': 2, 'accuracy': 0.821917808219178, 'auc': 0.7580128205128205}]
______epoch 206 _____
train imitation 1.243523120880127 reward 1.1805429458618164
val imitation 1.2499876022338867 reward 1.1898938417434692
val loss 2.4398813247680664 2

______epoch 224 _____
train imitation 1.2258397340774536 reward 1.182205319404602
val imitation 1.2443815469741821 reward 1.188377857208252
val loss 2.4327592849731445 2.430473804473877
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6543269230769231}, {'decision': 1, 'accuracy': 0.7671232876712328, 'auc': 0.7143640350877193}, {'decision': 2, 'accuracy': 0.8356164383561644, 'auc': 0.7608974358974359}]
______epoch 225 _____
train imitation 1.2110233306884766 reward 1.1800861358642578
val imitation 1.2449017763137817 reward 1.188420057296753
val loss 2.433321952819824 2.430473804473877
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6552884615384615}, {'decision': 1, 'accuracy': 0.7671232876712328, 'auc': 0.7143640350877193}, {'decision': 2, 'accuracy': 0.8356164383561644, 'auc': 0.760576923076923}]
______epoch 226 _____
train imitation 1.1963790655136108 reward 1.18160080909729
val imitation 1.2460083961486816 reward 1.1888593435287476
val loss 2.4348678588867188 2.4

______epoch 244 _____
train imitation 1.2084933519363403 reward 1.1809087991714478
val imitation 1.2454630136489868 reward 1.1881095170974731
val loss 2.43357253074646 2.430473804473877
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6572115384615385}, {'decision': 1, 'accuracy': 0.7602739726027398, 'auc': 0.7129934210526315}, {'decision': 2, 'accuracy': 0.8287671232876712, 'auc': 0.7583333333333333}]
______epoch 245 _____
train imitation 1.1809507608413696 reward 1.1809451580047607
val imitation 1.2456588745117188 reward 1.1881095170974731
val loss 2.4337682723999023 2.430473804473877
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6572115384615385}, {'decision': 1, 'accuracy': 0.7602739726027398, 'auc': 0.7129934210526315}, {'decision': 2, 'accuracy': 0.8287671232876712, 'auc': 0.7583333333333333}]
______epoch 246 _____
train imitation 1.223212480545044 reward 1.180373191833496
val imitation 1.2459386587142944 reward 1.1881095170974731
val loss 2.4340481758117676 

______epoch 264 _____
train imitation 1.2176625728607178 reward 1.1809096336364746
val imitation 1.2542054653167725 reward 1.1884838342666626
val loss 2.4426894187927246 2.430473804473877
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6581730769230769}, {'decision': 1, 'accuracy': 0.7602739726027398, 'auc': 0.7118969298245614}, {'decision': 2, 'accuracy': 0.8287671232876712, 'auc': 0.7583333333333333}]
______epoch 265 _____
train imitation 1.2210216522216797 reward 1.181593894958496
val imitation 1.254934310913086 reward 1.1889853477478027
val loss 2.4439196586608887 2.430473804473877
[{'decision': 0, 'accuracy': 0.8904109589041096, 'auc': 0.6591346153846154}, {'decision': 1, 'accuracy': 0.7602739726027398, 'auc': 0.7118969298245614}, {'decision': 2, 'accuracy': 0.8287671232876712, 'auc': 0.7573717948717948}]
______epoch 266 _____
train imitation 1.1644319295883179 reward 1.180776834487915
val imitation 1.2555088996887207 reward 1.189017653465271
val loss 2.4445266723632812 

In [None]:
def gridsearch_decision_model(m1,m2,m3):
#     model_arglist = [
#         {
#             'hidden_layers': [100],
#             'attention_heads': [1],
#         },
#         {
#             'hidden_layers': [500],
#             'attention_heads': [1],
#         },
#         {
#             'hidden_layers': [1000],
#             'attention_heads': [1],
#         },
#         {
#             'hidden_layers': [100],
#             'attention_heads': [5],
#         },
#         {
#             'hidden_layers': [500],
#             'attention_heads': [5],
#         },
#         {
#             'hidden_layers': [1000],
#             'attention_heads': [5],
#         },
#         {
#             'hidden_layers': [100,100],
#             'attention_heads': [1,1],
#         },
#         {
#             'hidden_layers': [500,500],
#             'attention_heads': [1,1],
#         },
#         {
#             'hidden_layers': [1000,1000],
#             'attention_heads': [1,1],
#         },
#         {
#             'hidden_layers': [100,100],
#             'attention_heads': [5,5],
#         },
#         {
#             'hidden_layers': [500,500],
#             'attention_heads': [5,5],
#         },
#         {
#             'hidden_layers': [1000,1000],
#             'attention_heads': [5,5],
#         },
#         {
#             'hidden_layers': [500,500,500],
#             'attention_heads': [5,5,5]
#         }
#     ]
    model_arglist = [

        {
            'hidden_layers': [100,100],
            'attention_heads': [1,1],
        },
        {
            'hidden_layers': [50,50],
            'attention_heads': [1,1],
        },
        {
            'hidden_layers': [100,100],
            'attention_heads': [2,2],
        },
        {
            'hidden_layers': [300],
            'attention_heads': [3],
        },
        {
            'hidden_layers': [300,300],
            'attention_heads': [3,3],
        },
    ]
    best_loss = 100000000000
    best_metrics = {}
    best_args = {}
    best_model = None
    k = 0
    for margs in model_arglist:
        args = {k:v for k,v in margs.items()}
        for embed_size in [0,100,200]:
            #embed_size = 0 skips the firt layer that makes the sizes right
            if embed_size == 0 and args['attention_heads'][0] != 1:
                continue
            args['embed_size'] = embed_size
            for dropout in [.5,.9,.95]:
                args['dropout'] = dropout
                for input_dropout in [.35,.45,.55]:
                    args['input_dropout'] = input_dropout
                    for shufflecol_chance in [.5,.75,.9]:
                        args['shufflecol_chance'] = shufflecol_chance
                        model,m_metrics,m_loss = train_decision_model(m1,m2,m3,use_attention=True,verbose=False,**args)
                        print('done',k,m_loss)
                        print('curr best',best_loss)
                        k+=1
                        if m_loss < best_loss:
                            best_loss = m_loss
                            best_metrics  = m_metrics
                            best_model = model
                            best_args = args
                            print('_++++++++++New Best++++____')
                            print(best_loss)
                            print(best_metrics)
                            print(best_args)
                            print('___________')
                            print('++++++++')
                            print()
    print('_________')
    print('+++++++++++')
    print('best stuff',best_loss)
    print(best_metrics)
    print(best_args)
    return best_model

from Models import *
decision_model = gridsearch_decision_model(model,model2,model3)
decision_model

In [None]:
from captum.attr import IntegratedGradients

ds = DTDataset()
states = DTDataset().get_states()
xdf = [states['baseline'],states['dlt1'],states['dlt2'],states['pd_states2'],states['nd_states2'],states['ccs'],states['modifications']]
x = tuple([df_to_torch(xx) for xx in xdf])
attributions = decision_model.get_attributions(x)
attributions[0]

In [60]:
torch.save(decision_model,'../data/models/final_decision_model_' + decision_model.identifier + '.pt')
print('../data/models/final_decision_model_' + decision_model.identifier + '.pt')

../data/models/final_decision_model_statedecisions_input132_dims600_dropout0.5,0.9.pt


In [None]:
torch.save(model,'../data/models/final_transition1_model_' + model.identifier + '.pt')
torch.save(model2,'../data/models/final_transition2_model_' + model2.identifier + '.pt')
torch.save(model3,'../data/models/final_outcome_model_' + model3.identifier + '.pt')
print('../data/models/final_transition1_model_' + model.identifier + '.pt')
print('../data/models/final_transition2_model_' + model2.identifier + '.pt')
print('../data/models/final_outcome_model_' + model3.identifier + '.pt')

In [None]:
xatt = []
for att,xxdf in zip(list(attributions),xdf):
    new = pd.DataFrame(att.cpu().detach().numpy(),columns=xxdf.columns,index=xxdf.index)
    xatt.append(new)
attributions = pd.concat(xatt,axis=1)
attributions

In [None]:
attributions.sum().sort_values()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
fig = plt.subplots(1,1,figsize=(100,100))
sns.heatmap(data=attributions.T,ax=fig[1])

In [None]:
def breakup_state_models(state_model):
    #for state 1 and 2
    models = {}
    models['pd'] = lambda x: state_model(x)[0]
    models['nd'] = lambda x: state_model(x)[1]
    models['chemo'] = lambda x: state_model(x)[2]
    for i,dlt in enumerate(Const.dlt1):
        models[dlt] = lambda x: state_model(x)[3][:,i]
    return models

def breakup_outcome_models(omodel):
    models = {}
    for i,name in enumerate(Const.outcomes):
        models[name] = lambda x: omodel(x)[:,i].reshape(-1,1)
    return models

def get_all_models(m1,m2,m3):
    state1_models = breakup_state_models(m1)
    state2_models = breakup_state_models(m2)
    state3_models = breakup_outcome_models(m3)
    all_models = {}
    for i,sm in enumerate([state1_models,state2_models,state3_models]):
        for ii,m in sm.items():
            all_models[ii +  '_state' + str(i+1)] = m
    return all_models

all_models = get_all_models(model,model2,model3)
all_models

In [None]:

def get_ytrue(name,df):
    outcomes=None
    value = None
    if name == 'pd_state1':
        outcomes = df[Const.primary_disease_states]
    elif name == 'pd_state2':
        outcomes = df[Const.primary_disease_states2]
    elif name == 'nd_state1':
        outcomes = df[Const.nodal_disease_states]
    elif name == 'nd_state2':
        outcomes = df[Const.nodal_disease_states2]
    elif name == 'chemo_state1':
        outcomes = df[Const.modifications]
    elif name == 'chemo_state2':
        outcomes = df[Const.ccs]   
    if outcomes is not None:
        value = outcomes.idxmax(axis=1)
    if 'DLT' in name:
        newname = name.replace('_state', ' ').replace('1','').strip()
        value = df[newname]
    if name.replace('_state3','') in Const.outcomes:
        value = df[name.replace('_state3','')]
    if value is None:
        print(name,df.columns)
    return value

def check_impact_of_decisions(model_dict,data):
    results = []
    #todo: this is wrong fix it
    ids = []
    df = data.get_data()
    outcomedict = {step: pd.concat(data.get_intermediate_outcomes(step=step),axis=1) for step in [1,2,3]}
    for decision in Const.decisions:
        for name, model in model_dict.items():
            step = int(name[-1])
            subset0 = dataset.get_input_state(step=step,fixed={decision: 0})
            subset1 = dataset.get_input_state(step=step,fixed={decision: 1})
            outcomes = outcomedict[step]
            ids = subset0.index.values
            x0 = df_to_torch(subset0)
            x1 = df_to_torch(subset1)
            y0 = model(x0).detach().cpu().numpy()
            y1 = model(x1).detach().cpu().numpy()
            original = data.get_input_state(step=step)
            xx = df_to_torch(original)
            yy = model(xx).detach().cpu().numpy()
            ytrue = get_ytrue(name,outcomes)
            if "DLT" in name:
                y0 = y0.argmax(axis=1).reshape(-1,1)
                y1 = y1.argmax(axis=1).reshape(-1,1)
                yy = yy.argmax(axis=1).reshape(-1,1)
                change = y0 - y1
                decision_change = (y0 != y1).astype(int)
            elif y0.shape[1] == 1:
                change = y1 - y0
                decision_change = np.abs((y0 > .5).astype(int) - (y1 > .5).astype(int))
            else:
                index = np.unravel_index(np.argmax(yy, axis=1), yy.shape)
                change = (y0[index] - y1[index]).reshape(-1,1)
                decision_change =  (y0.argmax(axis=1).reshape(-1,1) != y1.argmax(axis=1).reshape(-1,1)).astype(int)
                yy = yy.argmax(axis=1).reshape(-1,1)
                y1 = y1.argmax(axis=1).reshape(-1,1)
                y0 = y0.argmax(axis=1).reshape(-1,1)
            outcome = name.replace('_state','')
            for ii,pid in enumerate(ids):
                oo = ytrue.loc[pid]
                onew = y0[ii][0]
                original_decision = df.loc[pid,decision]
                if original_decision > 0:
                    onew = y0[ii][0]
                oname = Const.name_dict.get(name)
                if oname is not None:
                    onew = oname[onew]
                entry = {'id': pid, 'decision': decision,'outcome': outcome,'original_choice': original_decision, 'original_result': oo, 'alt_result': onew, 'change': change[ii][0], 'decision_change': decision_change[ii][0]}
                results.append(entry)
    return pd.DataFrame(results)

test = check_impact_of_decisions(all_models,dataset)
test

In [None]:
data.get_data()['SD Primary 2'].sum()

In [None]:
(test[test.outcome == 'pd2'].original_result == 'SD Primary 2').sum()

In [None]:
check_impact_of_decisions(all_models,dataset).to_csv('../data/decision_impacts.csv')