In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import torch
import re
from sklearn.metrics import balanced_accuracy_score, roc_auc_score,accuracy_score,precision_recall_fscore_support
from Constants import *
from Preprocessing import *
from Models import *
import copy
from Utils import *
from DeepSurvivalModels import *
pd.set_option('display.max_rows', 200)



  from .autonotebook import tqdm as notebook_tqdm


In [3]:
data = DTDataset(use_smote=False)
# Sanity check: shape of the model input at step 1.
# Only the last expression of a cell is displayed, so intermediate bare
# expressions (e.g. `data.processed_df.T`) would be silently discarded.
data.get_input_state(1).shape

(536, 62)

In [4]:
# NOTE(review): Utils is already star-imported in the imports cell; this re-import is redundant.
from Utils import *
# Load the transition models (model1 -> model2 -> model3) and the survival model
# smodel3; all four are passed to calc_optimal_decisions / the policy training below.
model1,model2,model3,smodel3 = load_transition_models()
smodel3

DSM(
  (act): Tanh()
  (shape): ParameterList(
      (0): Parameter containing: [torch.float32 of size 3]
      (1): Parameter containing: [torch.float32 of size 3]
      (2): Parameter containing: [torch.float32 of size 3]
      (3): Parameter containing: [torch.float32 of size 3]
  )
  (scale): ParameterList(
      (0): Parameter containing: [torch.float32 of size 3]
      (1): Parameter containing: [torch.float32 of size 3]
      (2): Parameter containing: [torch.float32 of size 3]
      (3): Parameter containing: [torch.float32 of size 3]
  )
  (gate): ModuleList(
    (0-3): 4 x Sequential(
      (0): Linear(in_features=103, out_features=3, bias=False)
    )
  )
  (scaleg): ModuleList(
    (0-3): 4 x Sequential(
      (0): Linear(in_features=103, out_features=3, bias=True)
    )
  )
  (shapeg): ModuleList(
    (0-3): 4 x Sequential(
      (0): Linear(in_features=103, out_features=3, bias=False)
    )
  )
  (embedding): Sequential(
    (0): Linear(in_features=79, out_features=100, b

In [5]:
def temporal_loss(timestoevents,weights=None,maxtime=48,threshold=True):
    """Penalty that grows as predicted time-to-event shrinks.

    Each outcome contributes ``weight * maxtime / time``, so longer predicted
    times (better) give smaller penalties; contributions are summed elementwise
    across outcomes.

    Args:
        timestoevents: sequence of tensors of expected times to events (weeks),
            usually in the order of ``Const.temporal_outcomes``.
        weights: per-outcome multipliers; defaults to 1 for every outcome.
        maxtime: horizon in weeks. With ``threshold`` set, times >= maxtime
            contribute nothing (counted as "no event").
        threshold: whether to zero out contributions with time >= maxtime.

    Returns:
        Tensor: per-outcome penalties stacked and summed over the outcome axis.
    """
    if weights is None:
        weights = [1] * len(timestoevents)
    per_outcome = []
    for w, t in zip(weights, timestoevents):
        penalty = w * maxtime / t
        if threshold:
            # bool mask from torch.lt zeroes out times at/after the horizon
            penalty = penalty * torch.lt(t, maxtime)
        per_outcome.append(penalty)
    return torch.stack(per_outcome).sum(axis=0)

def outcome_loss(ypred,weights=None):
    """Weighted sum of predicted outcome columns, one value per row.

    Args:
        ypred: tensor indexed as ``[row, outcome_column]``.
        weights: per-column multipliers; only the first ``len(weights)``
            columns are used. A negative weight flips the sense of an outcome
            (e.g. "regional control" becomes a penalty for "no regional control").
            Defaults to all ones, which is usually wrong (a warning is printed).

    Returns:
        Tensor of ``sum_i ypred[:, i] * weights[i]`` per row.
    """
    if weights is None:
        print('using default outcome loss weights, which is probably wrong since bad stuff should be negative')
        weights = [1] * ypred.shape[1]
    total = torch.mul(ypred[:,0],weights[0])
    for col, w in enumerate(weights[1:], start=1):
        total = torch.add(total, torch.mul(ypred[:,col], w))
    return total

def calc_optimal_decisions(dataset,ids,m1,m2,m3,sm3,
                           weights=[0,0.5,.5,0], #weight for OS, FT, AS, and LRC as binary probabilities
                           tweights=[1,1,1,1], #weight for OS, LRC, FDM, and event (any + FT or AS at 6m) as time to event in weeks
                           outcome_loss_func=None,
                           threshold_temporal_loss = False,
                           maxtime=48,
                           get_transitions=True):
    """Brute-force the 3-step binary decision sequence that minimises the simulated loss.

    For the patients in ``ids`` every one of the 2**3 sequences (d1, d2, d3) is
    rolled forward through ``m1`` -> ``m2`` -> ``m3`` (outcome predictions) and
    ``sm3.time_to_event`` (event times). Each sequence is scored per row with
    ``outcome_loss_func(outcomes, weights) + temporal_loss(times, tweights, ...)``
    and the per-row argmin is returned.

    Args:
        dataset: object exposing ``get_state('baseline')`` as a DataFrame indexed by id.
        ids: row labels of the baseline frame to evaluate.
        m1, m2, m3, sm3: transition / outcome / survival models; all are switched
            to eval mode. Inputs are moved to ``m1.get_device()``.
        weights: forwarded to ``outcome_loss_func``.
        tweights: forwarded to ``temporal_loss`` as per-outcome weights.
        outcome_loss_func: callable(outcomes, weights) -> per-row loss; defaults
            to ``outcome_loss``.
        threshold_temporal_loss: forwarded as ``threshold`` to ``temporal_loss``.
        maxtime: forwarded to ``temporal_loss``.
        get_transitions: also return the simulated intermediate states of each
            row's optimal sequence.

    Returns:
        ``result``: CPU float tensor of shape (len(ids), 3) holding the optimal
        [d1, d2, d3] per row. If ``get_transitions`` is true, also a dict mapping
        'pd1', 'nd1', 'nd2', 'pd2', 'mod', 'cc', 'dlt1', 'dlt2' to CPU float
        tensors whose row i comes from row i's optimal sequence.
    """
    m1.eval()
    m2.eval()
    m3.eval()
    sm3.eval()
    device = m1.get_device()

    baseline = dataset.get_state('baseline').loc[ids]
    baseline_input = df_to_torch(baseline).to(device)

    if outcome_loss_func is None:
        outcome_loss_func = outcome_loss

    n_rows = len(ids)

    def cat(tensors):
        # column-wise concatenation of feature blocks on the model device
        return torch.cat([t.to(device) for t in tensors],axis=1).to(device)

    def decision_column(value):
        # constant (n_rows, 1) float column encoding one binary decision
        return torch.full((n_rows,1),value).type(torch.FloatTensor)

    def get_outcome(d1,d2,d3):
        d1 = decision_column(d1)
        d2 = decision_column(d2)
        d3 = decision_column(d3)

        ytransition = m1(cat([baseline_input,d1]))
        [ypd1,ynd1,ymod,ydlt1] = [xx.to(device) for xx in ytransition['predictions']]
        # when decision 1 is off, the first two columns of ypd1 / ynd1 are zeroed
        d1_thresh = torch.gt(d1,.5).view(-1,1).to(device)
        ypd1[:,0:2] = ypd1[:,0:2]*d1_thresh
        ynd1[:,0:2] = ynd1[:,0:2]*d1_thresh

        ytransition2 = m2(cat([baseline_input,ypd1,ynd1,ymod,ydlt1,d1,d2]))
        [ypd2,ynd2,ycc,ydlt2] = [xx.to(device) for xx in ytransition2['predictions']]

        input3 = cat([baseline_input, ypd2, ynd2, ycc, ydlt2, d1, d2,d3])
        outcome = m3(input3)['predictions']
        temporal_outcomes = sm3.time_to_event(input3,n_samples=1)

        transitions = {
            'pd1': ypd1,
            'nd1': ynd1,
            'nd2': ynd2,
            'pd2': ypd2,
            'mod': ymod,
            'cc': ycc,
            'dlt1': ydlt1,
            'dlt2': ydlt2,
        }
        return outcome, temporal_outcomes, transitions

    losses = []
    loss_order = []
    transitions = {}
    # enumeration order 000, 001, ..., 111 determines argmin tie-breaking
    for d1 in [0,1]:
        for d2 in [0,1]:
            for d3 in [0,1]:
                outcomes, tte, transition_entry = get_outcome(d1,d2,d3)
                loss = outcome_loss_func(outcomes,weights)
                loss += temporal_loss(tte,tweights,maxtime=maxtime,threshold=threshold_temporal_loss)
                losses.append(loss)
                loss_order.append([d1,d2,d3])
                transitions[str(d1)+str(d2)+str(d3)] = transition_entry

    # one column per candidate sequence; pick the cheapest per row
    losses = torch.stack(losses,axis=1)
    optimal_decisions = [loss_order[i] for i in torch.argmin(losses,axis=1)]
    result = torch.tensor(optimal_decisions).type(torch.FloatTensor)
    print(result.sum(axis=0),result.shape[0])
    if not get_transitions:
        return result

    opt_transitions = {k: torch.zeros(v.shape).type(torch.FloatTensor) for k,v in transitions['000'].items()}
    for i,od in enumerate(optimal_decisions):
        entry = transitions[''.join([str(o) for o in od])]
        for kk,vv in entry.items():
            opt_transitions[kk][i,:] = vv[i,:]
    return result, opt_transitions

# get_tt_split returns (train_ids, test_ids): despite its name, `test` holds the
# TRAINING ids and `testtest` the held-out test ids evaluated below.
test, testtest = get_tt_split()
# Sanity check of the brute-force optimiser on the held-out ids with all binary-outcome
# weights off and only the time-to-event weights active (OS=2, LRC=0.1).
calc_optimal_decisions(DTDataset(),
                       testtest,model1,model2,model3,smodel3,
                       threshold_temporal_loss=False,
                       maxtime=48,
                       weights=[0,0,0,0],
                       tweights=[2,0.1,0,0],
                      )

tensor([8., 0., 0.]) 147


(tensor([[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [1., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [1., 0., 0.],
         [0

In [7]:
def get_unique_sequence(array):
    """Encode each row of 0/1 flags as one base-10 number.

    Element ``k`` of a row contributes ``value * 10**k``, so the first element
    is the ones digit: ``[1,1,0] -> 11`` and ``[0,0,1] -> 100``. The encoder is
    applied via ``torch_apply_along_axis`` (presumably once per row).
    """
    def encode_row(row):
        digits = []
        for position, value in enumerate(row):
            digits.append(value * (10 ** position))
        return torch.sum(torch.stack(digits))
    return torch_apply_along_axis(encode_row, array)

def train_decision_model_triplet(
    tmodel1,
    tmodel2,
    tmodel3,
    smodel3,
    use_default_split=True,
    use_bagging_split=False,
    use_attention=True,
    lr=.001,
    epochs=10000,
    patience=5,
    weights=[0,.5,.5,0], #relative weight of survival, feeding tube, aspiration, and lrc
    tweights=[1,1,1,0], #weight for OS, LRC, FDM, and event (any + FT or AS at 6m) as time to event in weeks
    opt_weights=[1,1,1], #weights for policy model for optimal decisions
    imitation_weights=[.5,1,1],#weights of imitation decisions, because ic overtrains too quickly
    imitation_weight=1,
    reward_weight=1,
    imitation_triplet_weight=2,
    reward_triplet_weight = 2,
    shufflecol_chance = 0.1,
    split=.7,
    resample_training=False,
    save_path='../data/models/',
    file_suffix='',
    verbose=True,
    use_gpu=False,
    **model_kwargs,
):
    """Train a policy model on observed (imitation) and simulated-optimal decisions.

    Optimal targets come from ``calc_optimal_decisions`` run with the transition
    models. The policy model's outputs are indexed as columns 0-2 (optimal
    decisions 1-3) and 3-5 (imitation decisions 1-3) and trained with BCE.

    Training loss per epoch:
      * imitation BCE per decision, weighted by ``imitation_weights`` and scaled by
        ``imitation_weight``;
      * optimal-decision BCE per decision, weighted by ``opt_weights`` and scaled by
        ``reward_weight``;
      * triplet-margin losses on model embeddings, averaged over rows and scaled by
        ``imitation_triplet_weight`` / ``reward_triplet_weight`` (skipped when both
        sum to <= 1e-4).

    Early stopping: the validation score is (imitation + reward BCE) minus a tenth
    of the reward/imitation-weighted AUCs summed over decisions. Training stops
    after more than ``patience`` epochs without improvement. The best state dict
    is saved to ``save_path`` and reloaded at the end.

    Args:
        tmodel1, tmodel2, tmodel3, smodel3: models passed to calc_optimal_decisions.
        use_default_split, use_bagging_split, resample_training: forwarded to get_tt_split.
        use_attention: build a DecisionAttentionModel (else DecisionModel, with
            'attention'/'embed' kwargs dropped).
        lr: Adam learning rate.
        epochs, patience: epoch budget and early-stopping patience.
        weights, tweights: forwarded to calc_optimal_decisions.
        save_path, file_suffix: checkpoint location / name suffix.
        verbose: print per-epoch validation metrics.
        use_gpu: use CUDA when available.
        **model_kwargs: forwarded to the model constructor.

    NOTE(review): ``shufflecol_chance`` and ``split`` are accepted but never used;
    because ``shufflecol_chance`` is a named parameter, a value passed through
    ``**args`` is captured here and NOT forwarded to the model constructor.

    Returns:
        (model, best_val_score, best_val_loss, best_val_distributions)
    """
    import hashlib

    tmodel1.eval()
    tmodel2.eval()
    tmodel3.eval()

    train_ids, test_ids = get_tt_split(use_default_split=use_default_split,use_bagging_split=use_bagging_split,resample_training=resample_training)
    true_ids = train_ids + test_ids #for saving memory without upsampling

    dataset = DTDataset()
    data = dataset.processed_df.copy()

    # Per-step feature blocks (state 0, 1, 2); blocks not yet observed are zeroed.
    def get_dlt(state):
        if state == 2:
            return data[Const.dlt2].copy()
        d = data[Const.dlt1].copy()
        if state < 1:
            d.values[:,:] = 0
        return d

    def get_pd(state):
        if state == 2:
            return data[Const.primary_disease_states2].copy()
        d = data[Const.primary_disease_states].copy()
        if state < 1:
            d.values[:,:] = 0
        return d

    def get_nd(state):
        if state == 2:
            return data[Const.nodal_disease_states2].copy()
        d = data[Const.nodal_disease_states].copy()
        if state < 1:
            d.values[:,:] = 0
        return d

    def get_cc(state):
        res = data[Const.ccs].copy()
        if state == 1:
            res.values[:,:] = np.zeros(res.values.shape)
        return res

    def get_mod(state):
        res = data[Const.modifications].copy()
        #this should have an ic condition but we don't use it anymore anyway
        return res

    outcomedf = data[Const.outcomes]
    baseline = dataset.get_state('baseline')

    def formatdf(d,dids=train_ids):
        # rows `dids` of a DataFrame as a tensor on the policy model's device
        return df_to_torch(d.loc[dids]).to(model.get_device())

    if use_attention:
        model = DecisionAttentionModel(baseline.shape[1],**model_kwargs)
    else:
        model_kwargs = {k:v for k,v in model_kwargs.items() if 'attention' not in k and 'embed' not in k}
        model = DecisionModel(baseline.shape[1],**model_kwargs)

    device = 'cpu'
    if use_gpu and torch.cuda.is_available():
        device = 'cuda'

    model.set_device(device)
    tmodel1.set_device(device)
    tmodel2.set_device(device)
    tmodel3.set_device(device)
    smodel3.set_device(device)

    # Built-in hash() of a str is salted per interpreter process, so it gave a
    # different checkpoint name every session; use a stable content hash instead.
    hashcode = hashlib.md5(','.join([str(i) for i in train_ids]).encode('utf-8')).hexdigest()
    save_file = save_path + 'model_' + model.identifier +'_hash' + hashcode + file_suffix + '.tar'
    model.fit_normalizer(df_to_torch(baseline.loc[train_ids]))
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)

    optimal_train,transitions_train = calc_optimal_decisions(dataset,train_ids,tmodel1,tmodel2,tmodel3,smodel3,
                                           weights=weights,tweights=tweights,
                                          )
    optimal_test,transitions_test = calc_optimal_decisions(dataset,test_ids,tmodel1,tmodel2,tmodel3,smodel3,
                                          weights=weights,tweights=tweights,
                                         )
    # These are BCE targets only; they must not require grad (doing so just
    # accumulated an unused .grad on them every epoch).
    optimal_train = optimal_train.to(model.get_device())
    optimal_test = optimal_test.to(model.get_device())
    bce = torch.nn.BCELoss()

    randchoice = lambda x: x[torch.randint(len(x),(1,))[0]]
    tloss_func = torch.nn.TripletMarginLoss()
    def get_tloss(row,step,yt,x,imitation=True):
        """Triplet-margin loss for anchor ``row`` at decision ``step``.

        Positives share the anchor's label ``yt[row, step]``, negatives do not;
        one of each is sampled at random. Returns a zero tensor when the column
        is (near-)constant or no valid positive/negative exists.
        """
        if yt[:,step].std() < .001:
            return torch.tensor([0]).to(device)
        same_label = yt[:,step] == yt[row,step]
        positive_idx = torch.nonzero(same_label)
        if len(positive_idx) <= 1:
            print('no losses','n positive',len(positive_idx),'yt',yt[row,step],'row',row,'step',step,'imitation',imitation,end='\r')
            return torch.tensor([0]).to(device)
        positive_idx = positive_idx.view(-1)
        positive_idx = positive_idx[positive_idx != row]
        # every row whose label differs from the anchor's (ascending, as before),
        # without the O(n^2) `in tensor` membership scan
        negative_idx = torch.nonzero(~same_label).view(-1)
        if len(positive_idx) < 1 or len(negative_idx) < 1:
            print('no losses','n positive',len(positive_idx),'n negative',len(negative_idx),'yt',yt[row,step],'row',row,'step',step,'imitation',imitation,end='\r')
            return torch.tensor([0]).to(device)
        positive = x[randchoice(positive_idx)]
        negative = x[randchoice(negative_idx)]
        anchor = x[row]
        if use_attention:
            [anchor_embedding,pos_embedding,neg_embedding] = [model.get_embedding(xx.view(1,-1),position=step,use_saved_memory=True) for xx in [anchor,positive,negative]]
        else:
            [anchor_embedding,pos_embedding,neg_embedding] = [model.get_embedding(xx.view(1,-1),position=step,concatenate=False)[int(imitation)] for xx in [anchor,positive,negative]]
        return tloss_func(anchor_embedding,pos_embedding,neg_embedding)

    #save the inputs from the whole dataset for future reference
    if use_attention:
        full_data = []
        for mstep in [0,1,2]:
            # NOTE(review): at mstep=1 the second DLT slot is get_dlt(1) (the dlt1 block),
            # while the step-1 inputs built in step() (x1_imitation, opt_input2) put
            # get_dlt(0) (zeros) there -- confirm which is intended.
            full_data_step = [baseline, get_dlt(min(mstep,1)),
                         get_dlt(mstep),get_pd(mstep),get_nd(mstep),get_cc(mstep),get_mod(mstep)]
            full_data_step = torch.cat([formatdf(fd,true_ids) for fd in full_data_step],axis=1)
            full_data.append(full_data_step)
        full_data = torch.stack(full_data)
        model.save_memory(full_data)
        print(full_data.shape)

    def step(train=True):
        """One pass over the train split (with an optimizer step) or the validation split.

        Returns ``[iloss, reward_loss, imitation_triplet, opt_triplet]`` when
        training; for validation also per-decision metric dicts and the mean
        optimal-decision outputs.
        """
        if train:
            model.train(True)
            optimizer.zero_grad()
            ids = train_ids
            y_opt = optimal_train
            transition_dict = {k: torch.clone(v).detach() for k,v in transitions_train.items()}
        else:
            ids = test_ids
            model.eval()
            y_opt = optimal_test
            if verbose:
                print(y_opt.mean(axis=0))
            transition_dict = {k: torch.clone(v).detach() for k,v in transitions_test.items()}
        model.set_device(device)
        ytrain = df_to_torch(outcomedf.loc[ids]).to(device)

        # decision 1: observed inputs, shared by the imitation and optimal heads
        xxtrained = [baseline, get_dlt(0),get_dlt(0),get_pd(0),get_nd(0),get_cc(0),get_mod(0)]
        xxtrain = torch.cat([formatdf(xx,ids) for xx in xxtrained],axis=1).to(device)
        o1 = model(xxtrain,position=0,use_saved_memory= (not train))
        decision1_imitation = o1[:,3]
        decision1_opt = o1[:,0]

        imitation_loss1 = torch.mul(bce(decision1_imitation,ytrain[:,0]),imitation_weights[0])
        opt_loss1 = torch.mul(bce(decision1_opt,y_opt[:,0]),opt_weights[0])

        # decision 2 imitation: observed step-1 state
        x1_imitation = [baseline, get_dlt(1),get_dlt(0),get_pd(1),get_nd(1),get_cc(1),get_mod(1)]
        x1_imitation = torch.cat([formatdf(xx1,ids) for xx1 in x1_imitation],axis=1).to(device)
        o2 = model(x1_imitation,position=1,use_saved_memory= (not train))
        decision2_imitation = o2[:,4]
        imitation_loss2 = torch.mul(bce(decision2_imitation,ytrain[:,1]),imitation_weights[1])

        # decision 3 imitation: observed step-2 state
        x2_imitation = [baseline, get_dlt(1),get_dlt(2),get_pd(2),get_nd(2),get_cc(2),get_mod(2)]
        x2_imitation = torch.cat([formatdf(xx2,ids) for xx2 in x2_imitation],axis=1).to(device)
        o3 = model(x2_imitation,position=2,use_saved_memory= (not train))
        decision3_imitation = o3[:,5]
        imitation_loss3 = torch.mul(bce(decision3_imitation,ytrain[:,2]),imitation_weights[2])

        # decision 2 optimal: step-1 state simulated under the optimal sequence
        opt_input2 = [
            formatdf(baseline,ids),
            transition_dict['dlt1'],
            formatdf(get_dlt(0),ids),
            transition_dict['pd1'],
            transition_dict['nd1'],
            formatdf(get_cc(0),ids),
            transition_dict['mod']
                 ]
        opt_input2 = torch.cat([o.to(device) for o in opt_input2],axis=1).to(device)
        decision2_opt = model(opt_input2,position=1,use_saved_memory= (not train))[:,1]
        opt_loss2 = torch.mul(bce(decision2_opt,y_opt[:,1]),opt_weights[1])

        # decision 3 optimal: step-2 state simulated under the optimal sequence
        opt_input3 = [
            formatdf(baseline,ids),
            transition_dict['dlt1'],
            transition_dict['dlt2'],
            transition_dict['pd2'],
            transition_dict['nd2'],
            transition_dict['cc'],
            transition_dict['mod'],
        ]
        opt_input3 = torch.cat([o.to(device) for o in opt_input3],axis=1).to(device)
        decision3_opt = model(opt_input3,position=2,use_saved_memory= (not train))[:,2]
        opt_loss3 = torch.mul(bce(decision3_opt,y_opt[:,2]),opt_weights[2])

        iloss = torch.add(torch.add(imitation_loss1,imitation_loss2),imitation_loss3)
        iloss = torch.mul(iloss,imitation_weight)

        reward_loss = torch.add(torch.add(opt_loss1,opt_loss2),opt_loss3)
        reward_loss = torch.mul(reward_loss,reward_weight)

        loss = torch.add(iloss,reward_loss)

        imitation_tloss = torch.FloatTensor([0]).to(device)
        opt_tloss = torch.FloatTensor([0]).to(device)
        n_rows = xxtrain.shape[0]
        if reward_triplet_weight + imitation_triplet_weight > 0.0001:
            for i in range(n_rows):
                if imitation_triplet_weight > .0001:
                    imitation_tloss += get_tloss(i,0,ytrain,xxtrain,True)
                    imitation_tloss += get_tloss(i,1,ytrain,x1_imitation,True)
                    imitation_tloss += get_tloss(i,2,ytrain,x2_imitation,True)
                if reward_triplet_weight > .0001:
                    opt_tloss += get_tloss(i,0,y_opt,xxtrain,False)
                    opt_tloss += get_tloss(i,1,y_opt,opt_input2,False)
                    opt_tloss += get_tloss(i,2,y_opt,opt_input3,False)
            loss += torch.mul(imitation_tloss[0],imitation_triplet_weight/n_rows)
            loss += torch.mul(opt_tloss[0],reward_triplet_weight/n_rows)
        losses = [iloss,reward_loss,imitation_tloss*imitation_triplet_weight/n_rows,opt_tloss*reward_triplet_weight/n_rows]

        if train:
            loss.backward()
            optimizer.step()
            return losses

        scores = []
        distributions = [decision1_opt.mean().item(),decision2_opt.mean().item(),decision3_opt.mean().item()]
        imitation = [decision1_imitation,decision2_imitation,decision3_imitation]
        optimal = [decision1_opt,decision2_opt,decision3_opt]
        for i,decision_im in enumerate(imitation):
            deci = decision_im.cpu().detach().numpy()
            iout = ytrain[:,i].cpu().detach().numpy()
            acci = accuracy_score(iout,(deci > .5).astype(int))
            try:
                auci = roc_auc_score(iout,deci)
            except ValueError:
                # AUC undefined, e.g. only one class present in the targets
                auci = -1

            deco = optimal[i].cpu().detach().numpy()
            oout = y_opt[:,i].cpu().detach().numpy()
            acco = accuracy_score(oout,(deco > .5).astype(int))
            try:
                auco = roc_auc_score(oout,deco)
            except ValueError:
                auco = -1
            scores.append({'decision': i,'optimal_auc': auco,'imitation_auc': auci,'optimal_acc': acco,'imitation_acc': acci})
        return losses, scores, distributions

    best_val_loss = torch.tensor(1000000000.0)
    steps_since_improvement = 0
    best_val_score = {}
    best_val_distributions = []
    checkpoint_saved = False

    for epoch in range(epochs):
        step(True)
        # validation only reads values: no autograd graph needed
        with torch.no_grad():
            val_losses,val_metrics,val_distributions = step(False)
        vl = val_losses[0] + val_losses[1]
        for vm in val_metrics:
            vl += (-((vm['optimal_auc']*reward_weight) + (vm['imitation_auc']*imitation_weight)))/10
        if verbose:
            print('______epoch',str(epoch),'_____')
            print('val reward',val_losses[1].item())
            print('imitation reward', val_losses[0].item())
            print('distance losses',val_losses[2].item(),val_losses[-1].item())
            print('distributions',val_distributions)
            print(val_metrics)
        if vl < best_val_loss:
            best_val_loss = vl
            best_val_score = val_metrics
            best_val_distributions = val_distributions
            steps_since_improvement = 0
            torch.save(model.state_dict(),save_file)
            checkpoint_saved = True
        else:
            steps_since_improvement += 1
        if steps_since_improvement > patience:
            break
    print('++++++++++Final+++++++++++')
    print('best',best_val_loss)
    print(best_val_score)
    if checkpoint_saved:
        model.load_state_dict(torch.load(save_file))
    else:
        # avoid loading a missing (or stale, from an earlier run) checkpoint
        print('no checkpoint saved (validation loss never improved); keeping last-epoch weights')
    model.eval()
    return model, best_val_score, best_val_loss, best_val_distributions

# NOTE(review): Models is already star-imported in the imports cell; this re-import is redundant.
from Models import *
# Hyper-parameters forwarded to the policy model constructor through **model_kwargs.
# NOTE(review): 'shufflecol_chance' matches a named parameter of train_decision_model_triplet,
# so that (unused) parameter captures it and it never reaches the model constructor --
# confirm whether the model is meant to receive it.
args = {
    'hidden_layers': [1000], 
    'opt_layer_size': 20, 
    'imitation_layer_size': 20, 
    'dropout': 0.25, 
    'input_dropout': 0.25, 
    'shufflecol_chance': 0.5
}

#1.8424  (presumably a previously recorded best validation loss -- TODO confirm)
decision_model, decision_score, decision_loss, _ = train_decision_model_triplet(
    model1,model2,model3,smodel3,
    lr=.01,
    imitation_weight=1,
    reward_weight=1,
    patience=10,
    imitation_triplet_weight=1,
    reward_triplet_weight =1,
    verbose=True,
    weights=[0,0,0,0], #relative weight of survival, feeding tube, aspiration, and lrc
    tweights=[2,.1,0,0], #weight for OS, LRC, FDM, and event (any + FT or AS at 6m) as time to event in weeks
    use_attention=True,
    **args)
decision_model

tensor([25.,  0.,  0.]) 389
tensor([8., 0., 0.]) 147
torch.Size([3, 536, 87])
tensor([0.0544, 0.0000, 0.0000], grad_fn=<MeanBackward1>)
______epoch 0 _____
val reward 0.325850248336792
imitation reward 1.7759063243865967
distance losses 2.728210210800171 0.6751828193664551
distributions [0.0030181524343788624, 0.004276015795767307, 0.0045722066424787045]
[{'decision': 0, 'optimal_auc': 0.5620503597122302, 'imitation_auc': 0.6645992366412214, 'optimal_acc': 0.9455782312925171, 'imitation_acc': 0.891156462585034}, {'decision': 1, 'optimal_auc': -1, 'imitation_auc': 0.5002717391304348, 'optimal_acc': 1.0, 'imitation_acc': 0.782312925170068}, {'decision': 2, 'optimal_auc': -1, 'imitation_auc': 0.5696122059758424, 'optimal_acc': 1.0, 'imitation_acc': 0.8231292517006803}]
tensor([0.0544, 0.0000, 0.0000], grad_fn=<MeanBackward1>)
______epoch 1 _____
val reward 0.3722558319568634
imitation reward 1.297990083694458
distance losses 2.4456264972686768 0.5394548177719116
distributions [0.001479243

tensor([0.0544, 0.0000, 0.0000], grad_fn=<MeanBackward1>)
______epoch 12 _____
val reward 0.21059641242027283
imitation reward 1.0823335647583008
distance losses 2.7554006576538086 0.35311341285705566
distributions [0.10066709667444229, 1.4174489137985802e-07, 2.383086439294857e-07]
[{'decision': 0, 'optimal_auc': 0.7589928057553957, 'imitation_auc': 0.5706106870229007, 'optimal_acc': 0.9455782312925171, 'imitation_acc': 0.891156462585034}, {'decision': 1, 'optimal_auc': -1, 'imitation_auc': 0.6980978260869565, 'optimal_acc': 1.0, 'imitation_acc': 0.7891156462585034}, {'decision': 2, 'optimal_auc': -1, 'imitation_auc': 0.819135410044501, 'optimal_acc': 1.0, 'imitation_acc': 0.8231292517006803}]
tensor([0.0544, 0.0000, 0.0000], grad_fn=<MeanBackward1>)
______epoch 13 _____
val reward 0.2146233320236206
imitation reward 1.0837063789367676
distance losses 2.5318968296051025 0.5357825756072998
distributions [0.11034917831420898, 1.2370111335258116e-07, 2.1511141312657855e-07]
[{'decision':

tensor([0.0544, 0.0000, 0.0000], grad_fn=<MeanBackward1>)
______epoch 24 _____
val reward 0.18063783645629883
imitation reward 1.1303930282592773
distance losses 2.7410621643066406 0.42590248584747314
distributions [0.09573030471801758, 1.322642447121325e-06, 1.7915821217684424e-06]
[{'decision': 0, 'optimal_auc': 0.8875899280575539, 'imitation_auc': 0.6397900763358778, 'optimal_acc': 0.9455782312925171, 'imitation_acc': 0.8707482993197279}, {'decision': 1, 'optimal_auc': -1, 'imitation_auc': 0.7366847826086956, 'optimal_acc': 1.0, 'imitation_acc': 0.7414965986394558}, {'decision': 2, 'optimal_auc': -1, 'imitation_auc': 0.8296249205340115, 'optimal_acc': 1.0, 'imitation_acc': 0.8163265306122449}]
tensor([0.0544, 0.0000, 0.0000], grad_fn=<MeanBackward1>)
______epoch 25 _____
val reward 0.17772386968135834
imitation reward 1.1936062574386597
distance losses 2.6197683811187744 0.467324435710907
distributions [0.10334429144859314, 1.898229470498336e-06, 2.4810176455503097e-06]
[{'decision'

tensor([0.0544, 0.0000, 0.0000], grad_fn=<MeanBackward1>)
______epoch 36 _____
val reward 0.16212324798107147
imitation reward 1.1287782192230225
distance losses 2.800011157989502 0.49350109696388245
distributions [0.0813124030828476, 1.9080920537817292e-05, 2.0029459847137332e-05]
[{'decision': 0, 'optimal_auc': 0.9082733812949639, 'imitation_auc': 0.6698473282442747, 'optimal_acc': 0.9319727891156463, 'imitation_acc': 0.8775510204081632}, {'decision': 1, 'optimal_auc': -1, 'imitation_auc': 0.695108695652174, 'optimal_acc': 1.0, 'imitation_acc': 0.7755102040816326}, {'decision': 2, 'optimal_auc': -1, 'imitation_auc': 0.7949777495232041, 'optimal_acc': 1.0, 'imitation_acc': 0.8299319727891157}]
tensor([0.0544, 0.0000, 0.0000], grad_fn=<MeanBackward1>)
______epoch 37 _____
val reward 0.15071257948875427
imitation reward 1.1395437717437744
distance losses 2.7426280975341797 0.5628157258033752
distributions [0.0689985454082489, 2.113919072144199e-05, 2.207921352237463e-05]
[{'decision': 0

DecisionAttentionModel(
  (input_dropout): Dropout(p=0.25, inplace=False)
  (layers): ModuleList(
    (0): Linear(in_features=100, out_features=1000, bias=True)
  )
  (batchnorm): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.25, inplace=False)
  (relu): Softplus(beta=1, threshold=20)
  (sigmoid): Sigmoid()
  (softmax): Softmax(dim=1)
  (final_opt_layer): Linear(in_features=1000, out_features=100, bias=True)
  (final_imitation_layer): Linear(in_features=1000, out_features=100, bias=True)
  (final_layer): Linear(in_features=1000, out_features=6, bias=True)
  (resize_layer): Linear(in_features=91, out_features=100, bias=True)
  (attentions): ModuleList(
    (0): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=100, out_features=100, bias=True)
    )
  )
  (norms): ModuleList(
    (0): LayerNorm((100,), eps=1e-05, elementwise_affine=True)
  )
  (activation): ReLU()
)

In [12]:
# Persist the policy model (moved to CPU first) and its best validation scores,
# then display the scores table.
decision_model.set_device('cpu')
torch.save(decision_model,'../resources/decision_model.pt')
policy_scores = pd.DataFrame(decision_score)
policy_scores.to_csv('../results/policy_model_score.csv')
policy_scores

Unnamed: 0,decision,optimal_auc,imitation_auc,optimal_acc,imitation_acc
0,0,0.895683,0.620706,0.945578,0.891156
1,1,-1.0,0.757065,1.0,0.77551
2,2,-1.0,0.828353,1.0,0.829932


In [27]:
# Ablation: both triplet-loss weights are 0, so train_decision_model_triplet skips its
# per-row triplet loop entirely (its check is reward + imitation triplet weight > 1e-4);
# time-to-event weights are OS=1, LRC=1.
#1.8424  (presumably a previously recorded best validation loss -- TODO confirm)
decision_model2, decision_score2, decision_loss2, _ = train_decision_model_triplet(
    model1,model2,model3,smodel3,
    lr=.01,
    imitation_weight=1,
    reward_weight=1,
    patience=10,
    imitation_triplet_weight=0,
    reward_triplet_weight =0,
    verbose=True,
    weights=[0,0,0,0], #relative weight of survival, feeding tube, aspiration, and lrc
    tweights=[1,1,0,0], #weight for OS, LRC, FDM, and event (any + FT or AS at 6m) as time to event in weeks
    use_attention=True,
    **args)

tensor([65.,  0.,  0.]) 389
tensor([26.,  0.,  0.]) 147
torch.Size([3, 536, 87])
tensor([0.1769, 0.0000, 0.0000], grad_fn=<MeanBackward1>)
______epoch 0 _____
val reward 0.8488877415657043
imitation reward 2.0121195316314697
distance losses 0.0 0.0
distributions [0.008012257516384125, 0.0033295094035565853, 0.0008333181613124907]
[{'decision': 0, 'optimal_auc': 0.6433566433566433, 'imitation_auc': 0.5854007633587787, 'optimal_acc': 0.8231292517006803, 'imitation_acc': 0.891156462585034}, {'decision': 1, 'optimal_auc': -1, 'imitation_auc': 0.6309782608695651, 'optimal_acc': 1.0, 'imitation_acc': 0.782312925170068}, {'decision': 2, 'optimal_auc': -1, 'imitation_auc': 0.7094723458359822, 'optimal_acc': 1.0, 'imitation_acc': 0.8231292517006803}]
tensor([0.1769, 0.0000, 0.0000], grad_fn=<MeanBackward1>)
______epoch 1 _____
val reward 0.5268533825874329
imitation reward 1.3480204343795776
distance losses 0.0 0.0
distributions [0.089411661028862, 0.0002860285749193281, 5.930369297857396e-05]


tensor([0.1769, 0.0000, 0.0000], grad_fn=<MeanBackward1>)
______epoch 13 _____
val reward 0.39634063839912415
imitation reward 1.1749008893966675
distance losses 0.0 0.0
distributions [0.2865397334098816, 3.9219514746946516e-07, 7.78546151991577e-08]
[{'decision': 0, 'optimal_auc': 0.8302606484424666, 'imitation_auc': 0.6192748091603053, 'optimal_acc': 0.8163265306122449, 'imitation_acc': 0.8775510204081632}, {'decision': 1, 'optimal_auc': -1, 'imitation_auc': 0.7133152173913044, 'optimal_acc': 1.0, 'imitation_acc': 0.7551020408163265}, {'decision': 2, 'optimal_auc': -1, 'imitation_auc': 0.7835346471710107, 'optimal_acc': 1.0, 'imitation_acc': 0.8027210884353742}]
tensor([0.1769, 0.0000, 0.0000], grad_fn=<MeanBackward1>)
______epoch 14 _____
val reward 0.3510810136795044
imitation reward 1.1551833152770996
distance losses 0.0 0.0
distributions [0.21866436302661896, 3.9559540709888097e-07, 8.445369559240135e-08]
[{'decision': 0, 'optimal_auc': 0.8445645263827082, 'imitation_auc': 0.6230

tensor([0.1769, 0.0000, 0.0000], grad_fn=<MeanBackward1>)
______epoch 27 _____
val reward 0.3324897885322571
imitation reward 1.2066798210144043
distance losses 0.0 0.0
distributions [0.19053611159324646, 1.1477845163199163e-07, 1.4610249898794336e-08]
[{'decision': 0, 'optimal_auc': 0.8855689764780674, 'imitation_auc': 0.623091603053435, 'optimal_acc': 0.8639455782312925, 'imitation_acc': 0.8571428571428571}, {'decision': 1, 'optimal_auc': -1, 'imitation_auc': 0.752445652173913, 'optimal_acc': 1.0, 'imitation_acc': 0.7755102040816326}, {'decision': 2, 'optimal_auc': -1, 'imitation_auc': 0.7819453273998728, 'optimal_acc': 1.0, 'imitation_acc': 0.8163265306122449}]
tensor([0.1769, 0.0000, 0.0000], grad_fn=<MeanBackward1>)
______epoch 28 _____
val reward 0.336943656206131
imitation reward 1.2627991437911987
distance losses 0.0 0.0
distributions [0.15760163962841034, 6.795051632479954e-08, 1.0446188802859524e-08]
[{'decision': 0, 'optimal_auc': 0.8906547997457088, 'imitation_auc': 0.62452

In [None]:
def train_decision_model(
    tmodel1,
    tmodel2,
    tmodel3,
    smodel3,
    use_default_split=True,
    use_bagging_split=False,
    lr=.0001,
    epochs=10000,
    patience=50,
    weights=[0,.5,.5,0], #relative weight of survival, feeding tube, aspiration, and lrc
    tweights=[1,1,1,0],
    imitation_weights=[.5,1,1],#weights of imitation decisions, because ic overtrains too quickly
    imitation_weight=0.1,
    shufflecol_chance = 0.1,
    reward_weight=1,
    imitation_triplet_weight=0,
    reward_triplet_weight = 0,
    split=.7,
    resample_training=False,
    save_path='../data/models/',
    file_suffix='',
    use_gpu=True,
    use_attention=True,
    verbose=True,
    threshold_decisions=True,#convert decisions to binary in simulation, usually breaks it
    use_smote=False,
    validate_with_memory=True,
    **model_kwargs):
    """Train a decision model against fixed transition / survival models.

    Each epoch runs one full-batch training step and one validation step.
    The decision model is evaluated at positions 0, 1 and 2. Output columns
    0-2 are used as the "optimal" (reward-driven) decisions and columns 3-5
    as the "imitation" decisions.

    The loss combines three parts:
      * BCE between the imitation decisions and the observed decisions
        (``outcomedf`` columns 0-2). Each decision is weighted by
        ``imitation_weights``, and the sum is scaled by ``imitation_weight / 3``.
      * A reward term, scaled by ``reward_weight``: the optimal decisions are
        rolled through ``tmodel1`` -> ``tmodel2`` -> ``tmodel3`` / ``smodel3``.
        The result is scored with ``outcome_loss(outcomes, weights)`` plus
        ``temporal_loss(survival, tweights)``.
      * Optional triplet-margin losses on the decision embeddings
        (``imitation_triplet_weight`` / ``reward_triplet_weight``).

    Early stopping uses the validation imitation loss plus the reward loss,
    minus a tenth of the weighted optimal and imitation AUCs. The best state
    dict is saved to ``save_file`` and reloaded at the end.

    NOTE(review): ``split`` and ``shufflecol_chance`` are accepted but not
    used or forwarded to the model constructor. Because they are explicit
    parameters, a ``shufflecol_chance`` passed via ``**args`` never reaches
    ``DecisionModel``/``DecisionAttentionModel``. Confirm this is intended.

    Returns:
        (model, best_val_score, best_val_loss, best_val_distributions).
        ``best_val_distributions`` is None if no epoch ever improved.
    """
    #outdated method of doing stuff, haven't updated with new loss functions idk
    tmodel1.eval()
    tmodel2.eval()
    tmodel3.eval()
    smodel3.eval()
    train_ids, test_ids = get_tt_split(use_default_split=use_default_split,use_bagging_split=use_bagging_split,resample_training=resample_training)
    true_ids = train_ids + test_ids #for saving memory without upsampling
    if use_smote:
        dataset = DTDataset(use_smote=True,smote_ids = train_ids)
        train_ids = [i for i in dataset.processed_df.index.values if i not in test_ids]
    else:
        dataset = DTDataset()
    data = dataset.processed_df.copy()
    
    # The get_* helpers return feature blocks for a given decision state.
    # For dlt/pd/nd, state 2 uses the "2" columns and states < 1 are zeroed.
    def get_dlt(state):
        if state == 2:
            return data[Const.dlt2].copy()
        d = data[Const.dlt1].copy()
        if state < 1:
            d.values[:,:] = 0
        return d
    
    def get_pd(state):
        if state == 2:
            return data[Const.primary_disease_states2].copy()
        d = data[Const.primary_disease_states].copy()
        if state < 1:
            d.values[:,:] = 0
        return d
    
    def get_nd(state):
        if state == 2:
            return data[Const.nodal_disease_states2].copy()
        d = data[Const.nodal_disease_states].copy()
        if state < 1:
            d.values[:,:] = 0
        return d
    
    def get_cc(state):
        res = data[Const.ccs].copy()
        if state == 1:
            res.values[:,:] = np.zeros(res.values.shape)
        return res
    
    def get_mod(state):
        res = data[Const.modifications].copy()
        #this should have an ic condition but we don't use it anymore anyway
        return res
        
    outcomedf = data[Const.outcomes]
    baseline = dataset.get_state('baseline')
    
    def formatdf(d,dids=train_ids):
        # `model` is resolved at call time; it is always bound before the first call
        d = df_to_torch(d.loc[dids]).to(model.get_device())
        return d
    
    def makegrad(v):
        if not v.requires_grad:
            v.requires_grad=True
        return v
    
    if use_attention:
        model = DecisionAttentionModel(baseline.shape[1],**model_kwargs)
    else:
        model_kwargs = {k:v for k,v in model_kwargs.items() if 'attention' not in k and 'embed' not in k}
        model = DecisionModel(baseline.shape[1],**model_kwargs)

    device = 'cpu'
    if use_gpu and torch.cuda.is_available():
        device = 'cuda'
        
    model.set_device(device)

    tmodel1.set_device(device)
    tmodel2.set_device(device)
    tmodel3.set_device(device)
    smodel3.set_device(device)
    
    hashcode = str(hash(','.join([str(i) for i in train_ids])))
    
    save_file = save_path + 'model_' + model.identifier +'_hash' + hashcode + file_suffix + '.tar'
    model.fit_normalizer(df_to_torch(baseline.loc[train_ids]).to(model.get_device()))
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)
    
    mse = torch.nn.MSELoss()
    bce = torch.nn.BCELoss()
    device = model.get_device()
    def compare_decisions(d1,d2,d3,ids):
        # Sum of BCE losses of three decision tensors against the observed decisions.
        # NOTE(review): currently unused in this function.
        ytrue = df_to_torch(outcomedf.loc[ids])
        dloss = bce(d1.view(-1),ytrue[:,0])
        dloss += bce(d2.view(-1),ytrue[:,1])
        dloss += bce(d3.view(-1),ytrue[:,2])
        return dloss
        
    def remove_decisions(df):
        cols = [c for c in df.columns if c not in Const.decisions ]
        ddf = df[cols]
        return ddf
    
    makeinput = lambda step,dids: df_to_torch(remove_decisions(dataset.get_input_state(step=step,ids=dids))).to(device)
    # steep sigmoid used as a near-step function around 0.5 so thresholding stays differentiable
    thresh = lambda x: torch.sigmoid(100000000*(x - .5))

    optimal_train,transitions_train = calc_optimal_decisions(dataset,train_ids,tmodel1,tmodel2,tmodel3,smodel3,
                                           weights=weights,tweights=tweights,
                                          )
    optimal_test,transitions_test = calc_optimal_decisions(dataset,test_ids,tmodel1,tmodel2,tmodel3,smodel3,
                                          weights=weights,tweights=tweights,
                                         )
    optimal_train = optimal_train.to(model.get_device())
    optimal_test = optimal_test.to(model.get_device())
    
    #save the inputs from the whole dataset for future reference
    if use_attention:
        full_data = []
        for mstep in [0,1,2]:
            full_data_step = [baseline, get_dlt(min(mstep,1)),
                         get_dlt(mstep),get_pd(mstep),get_nd(mstep),get_cc(mstep),get_mod(mstep)]
            full_data_step = torch.cat([formatdf(fd,true_ids) for fd in full_data_step],axis=1)
            full_data.append(full_data_step)
        full_data = torch.stack(full_data)
        model.save_memory(full_data.to(device))
        print(full_data.shape)
        
    randchoice = lambda x: x[torch.randint(len(x),(1,))[0]]
    tloss_func = torch.nn.TripletMarginLoss()
    def get_tloss(row,step,yt,x,imitation=True):
        """Triplet-margin loss for one row.

        The positive is a random other row with the same label yt[:,step].
        The negative is a random row with a different label. Returns a
        zero tensor when the labels are (near) constant or no triplet exists.
        """
        if yt[:,step].std() < .001:
            return torch.tensor([0]).to(device)
        positive_idx= torch.nonzero(yt[:,step] == yt[row,step])
        positive_idx = torch.stack([ii for ii in positive_idx if ii != row]).view(-1)
        negative_idx = torch.tensor([ii for ii in range(x.shape[0]) if ii not in positive_idx and ii != row])
        if len(positive_idx) < 1 or len(negative_idx) < 1:
            print('no losses','n positive',len(positive_idx),'n negative',len(negative_idx),'yt',yt[row,step],'row',row,'step',step,'imitation',imitation,end='\r')
            return torch.tensor([0]).to(device)
        positive = x[randchoice(positive_idx)]
        negative = x[randchoice(negative_idx)]
        anchor = x[row]
        if use_attention:
            [anchor_embedding,pos_embedding,neg_embedding] = [model.get_embedding(xx.view(1,-1),position=step,use_saved_memory=True) for xx in [anchor,positive,negative]]
        else:    
            [anchor_embedding,pos_embedding,neg_embedding] = [model.get_embedding(xx.view(1,-1),position=step,concatenate=False)[int(imitation)] for xx in [anchor,positive,negative]]
        tloss = tloss_func(anchor_embedding,pos_embedding,neg_embedding)
        return tloss
    
    def step(train=True):
        """Run one full-batch pass on the train split (train=True) or the test split.

        Returns [imitation_loss, reward_loss, imitation_triplet_loss,
        optimal_triplet_loss]. In validation mode it additionally returns
        per-decision AUC/accuracy scores and the mean of each optimal decision.
        """
        if train:
            model.train(True)
            # NOTE(review): the transition models are switched to train mode here
            # (e.g. any dropout becomes active) even though they start in eval mode
            # and are not in the optimizer -- confirm this is intended.
            tmodel1.train(True)
            tmodel2.train(True)
            tmodel3.train(True)
            optimizer.zero_grad()
            ids = train_ids
            y_opt = makegrad(optimal_train)
        else:
            ids = test_ids
            model.eval()
            tmodel1.eval()
            tmodel2.eval()
            tmodel3.eval()
            y_opt = makegrad(optimal_test)
            
        ytrain = df_to_torch(outcomedf.loc[ids]).to(device)
        #imitation losses and decision 1
        xxtrained = [baseline, get_dlt(0),get_dlt(0),get_pd(0),get_nd(0),get_cc(0),get_mod(0)]
        xxtrain = torch.cat([formatdf(xx,ids) for xx in xxtrained],axis=1).to(device)
        
        use_memory = (not train) and validate_with_memory

        o1 = model(xxtrain,position=0,use_saved_memory = use_memory)

        decision1_imitation = o1[:,3]
        
        decision1_opt = o1[:,0]
        if threshold_decisions:
            decision1_opt = thresh(decision1_opt)

        imitation_loss1 = bce(decision1_imitation,ytrain[:,0])
        imitation_loss1 = torch.mul(imitation_loss1,imitation_weights[0])
        
        x1_imitation = [baseline, get_dlt(1),get_dlt(0),get_pd(1),get_nd(1),get_cc(1),get_mod(1)]
        x1_imitation = [formatdf(xx1,ids) for xx1 in x1_imitation]
        x1_imitation = torch.cat(x1_imitation,axis=1).to(device)
        decision2_imitation = model(x1_imitation,position=1,use_saved_memory = use_memory)[:,4]

        imitation_loss2 =  bce(decision2_imitation,ytrain[:,1])
        imitation_loss2 = torch.mul(imitation_loss2,imitation_weights[1])
        
        x2_imitation = [baseline, get_dlt(1),get_dlt(2),get_pd(2),get_nd(2),get_cc(2),get_mod(2)]
        
        x2_imitation = [formatdf(xx2,ids) for xx2 in x2_imitation]
        x2_imitation = torch.cat(x2_imitation,axis=1)
        decision3_imitation = model(x2_imitation,position=2,use_saved_memory = use_memory)[:,5]

        imitation_loss3 = bce(decision3_imitation,ytrain[:,2])
        imitation_loss3 = torch.mul(imitation_loss3,imitation_weights[2])
        
        #reward decisions: roll the optimal decisions through the transition models
        xx1 = makeinput(1,ids)
        xx2 = makeinput(2,ids)
        xx3 = makeinput(3,ids)

        xx1 = makegrad(xx1)
        xx2 = makegrad(xx2)
        xx3 = makegrad(xx3)
        baseline_train_base = formatdf(baseline,ids)
            
        baseline_train = torch.clone(baseline_train_base)

        xi1 = torch.cat([xx1,decision1_opt.view(-1,1)],axis=1)
        [ypd1, ynd1, ymod, ydlt1] = tmodel1(xi1)['predictions']
        # zero the first two pd/nd columns unless decision 1 is taken (> .5); keep the last column
        d1_thresh = torch.gt(decision1_opt.view(-1,1),.5).to(ypd1.device)
        d1_scale = torch.cat([d1_thresh,d1_thresh,torch.ones(d1_thresh.view(-1,1).shape).to(ypd1.device)],dim=1)
        ypd1= torch.mul(ypd1,d1_scale)
        ynd1= torch.mul(ynd1,d1_scale)
        
        x1 = [baseline_train,ydlt1,formatdf(get_dlt(0),ids),ypd1,ynd1,formatdf(get_cc(1),ids),ymod]
        x1= torch.cat([xx1.to(model.get_device()) for xx1 in x1],axis=1)
        
        decision2_opt = model(x1,position=1,use_saved_memory = use_memory)[:,1] 
        if threshold_decisions:
            decision2_opt = thresh(decision2_opt)
            
        xi2 = torch.cat([xx2,decision1_opt.view(-1,1),decision2_opt.view(-1,1)],axis=1)
        [ypd2,ynd2,ycc,ydlt2] = tmodel2(xi2)['predictions']

        x2 = [baseline_train,ydlt1,ydlt2,ypd2,ynd2,ycc,ymod]
        x2 = torch.cat([xx2.to(model.get_device()) for xx2 in x2],axis=1)
        decision3_opt = model(x2,position=2,use_saved_memory = use_memory)[:,2]
        
        if threshold_decisions:
            decision3_opt = thresh(decision3_opt)
            
        xi3 = torch.cat([xx3,decision1_opt.view(-1,1),decision2_opt.view(-1,1),decision3_opt.view(-1,1)],axis=1)
        
        outcomes = tmodel3(xi3)['predictions']
        survival = smodel3.time_to_event(xi3,n_samples=1)
        if not train and verbose:
            print(torch.mean(outcomes,dim=0))
            
        reward_loss = torch.mean(outcome_loss(outcomes,weights) + temporal_loss(survival,tweights))
        loss = torch.add(imitation_loss1,imitation_loss2)
        loss = torch.add(loss,imitation_loss3)
        loss = torch.mul(loss,imitation_weight/3)
        loss = torch.add(loss,torch.mul(reward_loss,reward_weight))
        
        imitation_tloss = torch.FloatTensor([0]).to(device)
        opt_tloss = torch.FloatTensor([0]).to(device)
        n_rows = x1.shape[0]
        if reward_triplet_weight + imitation_triplet_weight > 0.0001:
            for i in range(n_rows):
                #imitation triplets are skipped for attention models
                if not use_attention and imitation_triplet_weight > .0001:
                    imitation_tloss += get_tloss(i,0,ytrain,xxtrain,True)
                    imitation_tloss += get_tloss(i,1,ytrain,x1_imitation,True)
                    imitation_tloss += get_tloss(i,2,ytrain,x2_imitation,True)
                if reward_triplet_weight > .0001:
                    opt_tloss += get_tloss(i,0,y_opt,xxtrain,False)
                    opt_tloss += get_tloss(i,1,y_opt,x1,False)
                    opt_tloss += get_tloss(i,2,y_opt,x2,False)
            loss += torch.mul(imitation_tloss[0],imitation_triplet_weight/n_rows)
            loss += torch.mul(opt_tloss[0],reward_triplet_weight/n_rows)
        
        losses = [imitation_loss1+imitation_loss2+imitation_loss3,reward_loss,imitation_tloss,opt_tloss]

        if train:
            loss.backward()
            optimizer.step()
            return losses
        else:
            scores = []
            distributions = [decision1_opt.mean().item(),decision2_opt.mean().item(),decision3_opt.mean().item()]
            imitation = [decision1_imitation,decision2_imitation,decision3_imitation]
            optimal = [decision1_opt,decision2_opt,decision3_opt]
            for i,decision_im in enumerate(imitation):
                deci = decision_im.cpu().detach().numpy()
                deci0 = (deci > .5).astype(int)
                iout = ytrain[:,i].cpu().detach().numpy()
                acci = accuracy_score(iout,deci0)
                try:
                    auci = roc_auc_score(iout,deci)
                except ValueError:
                    # e.g. only one class present in the targets
                    auci = -1
                
                deco = optimal[i].cpu().detach().numpy()
                deci0 = (deco > .5).astype(int)
                oout = y_opt[:,i].cpu().detach().numpy()
                acco = accuracy_score(oout,deci0)
                try:
                    auco = roc_auc_score(oout,deco)
                except ValueError:
                    auco=-1
                scores.append({'decision': i,'optimal_auc': auco,'imitation_auc': auci,'optimal_acc': acco,'imitation_acc': acci})
            return losses, scores, distributions
        
    best_val_loss = torch.tensor(1000000000.0)
    steps_since_improvement = 0
    best_val_score = {}
    best_val_distributions = None  # stays None if no epoch ever improves
    saved_checkpoint = False       # only reload weights that this run actually saved
    
    for epoch in range(epochs):
        losses = step(True)
        val_losses,val_metrics,val_distributions = step(False)
        # detach so the stored best loss does not keep the validation graph alive
        vl = (val_losses[0] + val_losses[1]).detach()
        for vm in val_metrics:
            vl += (-((vm['optimal_auc']*reward_weight) + (vm['imitation_auc']*imitation_weight)))/10
        if verbose:
            print('______epoch',str(epoch),'_____')
            print('val reward',val_losses[1].item())
            print('imitation reward', val_losses[0].item())
            if len(val_losses) > 2:
                print('distance losses',val_losses[2].item(),val_losses[-1].item())
            print('distributions',val_distributions)
            print(val_metrics)
        if vl < best_val_loss:
            best_val_loss = vl
            best_val_score = val_metrics
            best_val_distributions = val_distributions
            steps_since_improvement = 0
            torch.save(model.state_dict(),save_file)
            saved_checkpoint = True
        else:
            steps_since_improvement += 1
        if steps_since_improvement > patience:
            break
    print('++++++++++Final+++++++++++')
    print('best',best_val_loss)
    print(best_val_score)
    if saved_checkpoint:
        model.load_state_dict(torch.load(save_file))
    model.eval()
    return model, best_val_score, best_val_loss, best_val_distributions

from Models import *
# args = {
#     'hidden_layers': [50,50], 
#     'attention_heads': [2,2],
#     'embed_size': 120, 
#     'dropout': 0.5, 
#     'input_dropout': 0.2, 
#     'shufflecol_chance':  0.2,
# }
# Hyperparameters for the (non-commented) decision-model architecture.
args = dict(
    hidden_layers=[500],
    opt_layer_size=20,
    imitation_layer_size=20,
    dropout=0.25,
    input_dropout=0.25,
    shufflecol_chance=0.5,
)
from Models import *
# Train the attention decision model with a doubled reward weight and no
# triplet losses; only the trained model is kept from the returned 4-tuple.
decision_model, _, _, _ = train_decision_model(
    model1, model2, model3, smodel3,
    # optimisation
    lr=.001,
    # architecture / data
    use_attention=True,
    use_smote=False,
    validate_with_memory=True,
    # loss weighting
    reward_weight=2,
    imitation_weight=1,
    imitation_triplet_weight=0,
    reward_triplet_weight=0,
    **args,
)