In [17]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
import numpy as np
import pandas as pd
import torch
import re
from sklearn.metrics import balanced_accuracy_score, roc_auc_score,accuracy_score,precision_recall_fscore_support
from Constants import *
from Preprocessing import *
from Models import *
import copy
from Utils import *
from DeepSurvivalModels import *
pd.set_option('display.max_rows', 200)



In [19]:
# Build the dataset with use_smote=False and display the processed frame
# transposed, so each column of the rendered table corresponds to one id and
# each row to one feature/outcome.
data = DTDataset(use_smote=False)
data.processed_df.T

id,3,5,6,7,8,9,10,11,13,14,...,10196,10197,10198,10199,10200,10201,10202,10203,10204,10205
hpv,1,0,1,1,1,1,-1,1,0,1,...,0,1,-1,0,1,1,0,1,0,1
age,55.969444,20.95,69.930556,72.319444,59.730556,60.083333,67.708333,57.858333,51.758333,56.25,...,47.619444,50.163889,70.888889,67.825,56.336111,49.566667,48.705556,77.116667,45.95,49.733333
packs_per_year,0.0,38.0,35.0,0.0,0.0,0.0,40.0,44.0,0.0,40.0,...,5.0,0.0,50.0,0.0,0.0,30.0,30.0,0.0,5.0,0.0
gender,1,1,0,1,1,1,1,1,1,1,...,0,1,0,1,1,1,1,1,1,1
Aspiration rate Pre-therapy,0,0,1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
total_dose,66.0,72.0,70.0,70.0,66.0,66.0,69.96,70.0,70.0,70.0,...,70.0,72.0,66.0,70.0,69.96,70.0,72.0,70.0,69.96,69.96
dose_fraction,2.2,1.8,2.121212,2.121212,2.2,2.2,2.12,2.121212,2.0,2.121212,...,2.121212,1.8,2.2,2.121212,2.12,2.121212,1.714286,2.333333,2.12,2.12
OS (Calculated),6.033333,7.333333,7.466667,7.8,8.066667,8.733333,9.1,9.8,10.033333,10.033333,...,139.033333,139.3,140.6,142.833333,143.033333,143.2,144.366667,148.366667,152.6,155.533333
Locoregional control (Time),4.7,7.333333,7.466667,7.8,8.066667,8.733333,6.7,8.5,10.033333,10.033333,...,139.033333,139.3,140.6,142.833333,143.033333,143.2,144.366667,148.366667,152.6,155.533333
FDM (months),6.033333,7.333333,7.466667,7.8,8.066667,6.633333,9.1,9.8,10.033333,10.033333,...,139.033333,139.3,140.6,142.833333,143.033333,143.2,144.366667,136.033333,152.6,155.533333


In [20]:
# Load the four saved models used downstream (three transition/outcome models
# plus the time-to-event model) and display the last one's module repr.
model1,model2,model3,smodel3 = load_transition_models()
smodel3

DSM(
  (act): Tanh()
  (shape): ParameterList(
      (0): Parameter containing: [torch.float32 of size 6]
      (1): Parameter containing: [torch.float32 of size 6]
      (2): Parameter containing: [torch.float32 of size 6]
      (3): Parameter containing: [torch.float32 of size 6]
  )
  (scale): ParameterList(
      (0): Parameter containing: [torch.float32 of size 6]
      (1): Parameter containing: [torch.float32 of size 6]
      (2): Parameter containing: [torch.float32 of size 6]
      (3): Parameter containing: [torch.float32 of size 6]
  )
  (gate): ModuleList(
    (0-3): 4 x Sequential(
      (0): Linear(in_features=103, out_features=6, bias=False)
    )
  )
  (scaleg): ModuleList(
    (0-3): 4 x Sequential(
      (0): Linear(in_features=103, out_features=6, bias=True)
    )
  )
  (shapeg): ModuleList(
    (0-3): 4 x Sequential(
      (0): Linear(in_features=103, out_features=6, bias=False)
    )
  )
  (embedding): Sequential(
    (0): Linear(in_features=78, out_features=100, b

In [21]:
def temporal_loss(timestoevents,weights=None,maxtime=48,threshold=True):
    """Turn predicted times-to-event into a penalty where shorter times cost more.

    Every entry of ``timestoevents`` contributes ``weight * maxtime / time``
    element-wise; the contributions are then summed across entries.

    Args:
        timestoevents: sequence of tensors of expected times to event. The
            original author's note says these are usually ordered like
            Const.temporal_outcomes. Entries must be stackable with
            ``torch.stack``.
        weights: one multiplier per entry of ``timestoevents``. Defaults to 1
            for every entry.
        maxtime: numerator of the penalty and, when ``threshold`` is set, the
            cutoff (the original note describes it as weeks).
        threshold: if true, elements whose time is not strictly below
            ``maxtime`` contribute zero (treated as "no event").

    Returns:
        Tensor holding the element-wise sum of the per-entry penalties.
    """
    if weights is None:
        weights = [1]*len(timestoevents)
    penalties = []
    for weight,tte in zip(weights,timestoevents):
        penalty = weight*maxtime/tte
        if threshold:
            # boolean mask: keep the penalty only where the event falls before maxtime
            penalty = penalty*torch.lt(tte,maxtime)
        penalties.append(penalty)
    return torch.stack(penalties).sum(axis=0)

def outcome_loss(ypred,weights=None):
    """Weighted sum of the columns of ``ypred``, one value per row.

    Args:
        ypred: 2-D tensor; column ``k`` is multiplied by ``weights[k]``.
        weights: per-column multipliers. A negative weight flips the meaning
            of a column (e.g. regional control becomes "no regional control")
            so that it acts as a penalty in the right direction. When omitted,
            every column gets weight 1 and a warning is printed, because
            unfavourable outcomes would then not be penalised correctly.

    Returns:
        1-D tensor with the weighted sum over the first ``len(weights)`` columns.
    """
    if weights is None:
        print('using default outcome loss weights, which is probably wrong since bad stuff should be negative')
        weights = [1]*ypred.shape[1]
    total = torch.mul(ypred[:,0],weights[0])
    for col in range(1,len(weights)):
        total = torch.add(total,torch.mul(ypred[:,col],weights[col]))
    return total

def calc_optimal_decisions(dataset,ids,m1,m2,m3,sm3,
                           weights=[0,0.5,.5,0], #weight for OS, FT, AS, and LRC as binary probabilities
                           tweights=[1,1,1,1], #weight for OS, LRC, FDM, and event (any + FT or AS at 6m) as time to event in weeks
                           outcome_loss_func=None,
                           threshold_temporal_loss = False,
                           maxtime=48,
                           get_transitions=True):
    """Pick, for every id, the decision triple (d1, d2, d3) with the lowest simulated loss.

    All eight binary combinations of (d1, d2, d3) are pushed through the chain
    m1 -> m2 -> m3 / sm3, starting from the dataset's 'baseline' state. For each
    combination the loss is ``outcome_loss_func(m3 predictions, weights)`` plus
    ``temporal_loss(sm3 time-to-event, tweights)``; the combination with the
    smallest loss is selected independently for each row.

    Args:
        dataset: object exposing ``get_state('baseline')`` that returns a
            DataFrame indexable with ``.loc[ids]``.
        ids: row labels to evaluate.
        m1, m2, m3: models whose call result is a mapping with a 'predictions'
            entry. ``m1`` must also expose ``get_device()``.
        sm3: model exposing ``time_to_event(input, n_samples=...)``.
        weights: per-column weights passed to ``outcome_loss_func``.
        tweights: per-entry weights passed to ``temporal_loss``.
        outcome_loss_func: callable ``(predictions, weights) -> per-row loss``.
            Defaults to ``outcome_loss``.
        threshold_temporal_loss: forwarded to ``temporal_loss`` as ``threshold``.
        maxtime: forwarded to ``temporal_loss``.
        get_transitions: when true, also return the intermediate predictions
            that belong to each row's selected triple.

    Returns:
        ``result`` when ``get_transitions`` is false, otherwise
        ``(result, opt_transitions)``:
        - ``result`` is a CPU float tensor with one row per id and three
          columns (d1, d2, d3).
        - ``opt_transitions`` is a dict with keys 'pd1', 'nd1', 'nd2', 'pd2',
          'mod', 'cc', 'dlt1', 'dlt2' holding CPU float tensors.

    Side effects:
        Puts all four models in eval mode and prints the per-decision counts
        of selected 1s together with the number of rows.
    """
    m1.eval()
    m2.eval()
    m3.eval()
    sm3.eval()
    device = m1.get_device()

    # Only the baseline state is read from the dataset; every later state is
    # simulated by the transition models below.
    baseline = dataset.get_state('baseline').loc[ids]
    baseline_input = df_to_torch(baseline).to(device)

    if outcome_loss_func is None:
        outcome_loss_func = outcome_loss

    cat = lambda x: torch.cat([xx.to(device) for xx in x],axis=1).to(device)
    format_transition = lambda x: x.to(device)
    def get_outcome(d1,d2,d3):
        # Broadcast each scalar decision to a (n_ids, 1) float column.
        d1 = torch.full((len(ids),1),d1).type(torch.FloatTensor)
        d2 = torch.full((len(ids),1),d2).type(torch.FloatTensor)
        d3 = torch.full((len(ids),1),d3).type(torch.FloatTensor)

        tinput1 = cat([baseline_input,d1])
        ytransition = m1(tinput1)
        [ypd1,ynd1,ymod,ydlt1] = [format_transition(xx) for xx in ytransition['predictions']]
        # When d1 is 0 the first two columns of the primary/nodal predictions
        # are zeroed in place.
        # NOTE(review): presumably those columns are responses that can only
        # occur if the first treatment was given -- confirm against Const.
        d1_thresh = torch.gt(d1,.5).view(-1,1).to(device)
        ypd1[:,0:2] = ypd1[:,0:2]*d1_thresh
        ynd1[:,0:2] = ynd1[:,0:2]*d1_thresh

        tinput2 = cat([baseline_input,ypd1,ynd1,ymod,ydlt1,d1,d2])
        ytransition2 = m2(tinput2)
        [ypd2,ynd2,ycc,ydlt2] = [format_transition(xx) for xx in ytransition2['predictions']]

        input3 = cat([baseline_input, ypd2, ynd2, ycc, ydlt2, d1, d2,d3])
        outcome = m3(input3)['predictions']
        temporal_outcomes = sm3.time_to_event(input3,n_samples=1)

        transitions = {
            'pd1': ypd1,
            'nd1': ynd1,
            'nd2': ynd2,
            'pd2': ypd2,
            'mod': ymod,
            'cc': ycc,
            'dlt1': ydlt1,
            'dlt2': ydlt2,
        }
        return outcome, temporal_outcomes, transitions

    losses = []
    loss_order = []
    transitions = {}
    # Exhaustive search over the 2**3 decision combinations.
    for d1 in [0,1]:
        for d2 in [0,1]:
            for d3 in [0,1]:
                outcomes, tte, transition_entry = get_outcome(d1,d2,d3)
                loss = outcome_loss_func(outcomes,weights)
                tloss = temporal_loss(tte,tweights,maxtime=maxtime,threshold=threshold_temporal_loss)
                loss += tloss
                losses.append(loss)
                loss_order.append([d1,d2,d3])
                transitions[str(d1)+str(d2)+str(d3)] = transition_entry
    # One column per combination, in the same order as loss_order.
    losses = torch.stack(losses,axis=1)
    optimal_decisions = [loss_order[i] for i in torch.argmin(losses,axis=1)]
    result = torch.tensor(optimal_decisions).type(torch.FloatTensor)
    print(result.sum(axis=0),result.shape[0])
    if get_transitions:
        # For every row copy the intermediate predictions produced under that
        # row's own best combination.
        opt_transitions = {k: torch.zeros(v.shape).type(torch.FloatTensor) for k,v in transitions['000'].items()}
        for i,od in enumerate(optimal_decisions):
            key = ''.join([str(o) for o in od])
            entry = transitions[key]
            for kk,vv in entry.items():
                opt_transitions[kk][i,:] = vv[i,:]
        return result, opt_transitions
    return result

# get_tt_split() is unpacked as (train ids, test ids) elsewhere in this
# notebook, so `test` holds the first element of the split and `testtest`
# the second one, which is what is evaluated here.
test, testtest = get_tt_split()
# All binary-outcome weights are zero, so only the time-to-event terms
# (tweights) drive the selected decisions in this run.
calc_optimal_decisions(DTDataset(),
                       testtest,model1,model2,model3,smodel3,
                       threshold_temporal_loss=False,
                       maxtime=48,
                       weights=[0,0,0,0],
                       tweights=[2,0.1,0,0],
                      )

tensor([108.,   1., 137.]) 147


(tensor([[1., 0., 1.],
         [1., 0., 0.],
         [1., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [1., 0., 1.],
         [0., 0., 0.],
         [1., 0., 1.],
         [1., 0., 1.],
         [0., 0., 1.],
         [1., 0., 1.],
         [1., 0., 1.],
         [1., 0., 1.],
         [1., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [1., 0., 1.],
         [1., 0., 1.],
         [0., 0., 1.],
         [1., 0., 1.],
         [1., 0., 1.],
         [0., 0., 1.],
         [1., 0., 1.],
         [1., 0., 1.],
         [1., 0., 1.],
         [1., 0., 1.],
         [1., 0., 1.],
         [1., 0., 1.],
         [1., 0., 1.],
         [1., 0., 1.],
         [0., 0., 1.],
         [1., 0., 1.],
         [1., 0., 1.],
         [1., 0., 1.],
         [1., 0., 1.],
         [1., 0., 1.],
         [1., 0., 1.],
         [1., 0., 1.],
         [1., 0., 1.],
         [0., 0., 1.],
         [1., 0., 0.],
         [0., 0., 1.],
         [1

In [31]:
def get_unique_sequence(array):
    """Encode each row of 0/1 values as a single number.

    Position ``k`` of a row is multiplied by ``10**k`` and the products are
    summed, so the first element becomes the ones digit:
    [1,1,0] => 11, [0,0,1] => 100.

    The per-row encoder is handed to ``torch_apply_along_axis`` together with
    ``array``; the shape of the result is whatever that helper returns.
    """
    def encode_row(row):
        digits = [flag*(10**place) for place,flag in enumerate(row)]
        return torch.sum(torch.stack(digits))
    return torch_apply_along_axis(encode_row,array)

def train_decision_model_triplet(
    tmodel1,
    tmodel2,
    tmodel3,
    smodel3,
    use_default_split=True,
    use_bagging_split=False,
    use_attention=True,
    lr=.001,
    epochs=10000,
    patience=5,
    weights=[0,.5,.5,0], #relative weight of survival, feeding tube, aspiration, and lrc
    tweights=[1,1,1,0], #weight for OS, LRC, FDM, and event (any + FT or AS at 6m) as time to event in weeks
    opt_weights=[1,1,1], #weights for policy model for optimal decisions
    imitation_weights=[.5,1,1],#weights of imitation decisions, because ic overtrains too quickly
    imitation_weight=1,
    reward_weight=1,
    imitation_triplet_weight=2,
    reward_triplet_weight = 2,
    shufflecol_chance = 0.2,
    split=.7,
    resample_training=False,
    save_path='../data/models/',
    file_suffix='',
    verbose=True,
    use_gpu=False,
    **model_kwargs,
):
    """Train a decision (policy) model on imitation, "optimal decision" and triplet losses.

    The model is trained to (a) imitate the decisions recorded in the data and
    (b) reproduce the decisions returned by ``calc_optimal_decisions`` for the
    given transition models. Output columns 0-2 of the model are read as the
    optimal-decision heads and columns 3-5 as the imitation heads for
    decisions 1-3. A triplet margin loss on the model's embeddings is added
    for both targets.

    Args:
        tmodel1, tmodel2, tmodel3, smodel3: pre-trained models forwarded to
            ``calc_optimal_decisions``. They are moved to the training device
            but are not part of the optimizer.
        use_default_split, use_bagging_split, resample_training: forwarded to
            ``get_tt_split``.
        use_attention: build a ``DecisionAttentionModel`` (and store the full
            dataset through ``model.save_memory``) instead of a
            ``DecisionModel``.
        lr: Adam learning rate.
        epochs: maximum number of train/validate iterations.
        patience: training stops once more than this many consecutive epochs
            pass without a better validation score.
        weights, tweights: forwarded to ``calc_optimal_decisions``.
        opt_weights: per-decision multipliers for the optimal-decision BCE terms.
        imitation_weights: per-decision multipliers for the imitation BCE terms.
        imitation_weight, reward_weight: multipliers for the summed imitation /
            optimal losses. They also weight the AUC terms of the validation score.
        imitation_triplet_weight, reward_triplet_weight: multipliers for the
            triplet losses (averaged over rows). A value <= .0001 disables the
            corresponding term.
        shufflecol_chance, split: accepted for interface compatibility; not
            read anywhere in this function.
        save_path, file_suffix: used to build the checkpoint file name.
        verbose: print validation diagnostics every epoch.
        use_gpu: train on CUDA when available.
        **model_kwargs: forwarded to the model constructor. Without attention,
            keys containing 'attention' or 'embed' are dropped first.

    Returns:
        ``(model, best_val_score, best_val_loss, best_val_distributions)``.
        The model carries the weights of the best validation epoch. If no epoch
        ever improved on the initial score, the last-epoch weights are kept,
        ``best_val_score`` is ``{}`` and ``best_val_distributions`` is ``None``.
    """
    tmodel1.eval()
    tmodel2.eval()
    tmodel3.eval()

    train_ids, test_ids = get_tt_split(use_default_split=use_default_split,use_bagging_split=use_bagging_split,resample_training=resample_training)
    true_ids = train_ids + test_ids #for saving memory without upsampling

    dataset = DTDataset()
    data = dataset.processed_df.copy()
    
    # The get_* helpers return the observed columns for a given state index
    # (0, 1 or 2); where a block is zeroed, it is zeroed on a copy.
    def get_dlt(state):
        if state == 2:
            return data[Const.dlt2].copy()
        d = data[Const.dlt1].copy()
        if state < 1:
            d.values[:,:] = 0
        return d
    
    def get_pd(state):
        if state == 2:
            return data[Const.primary_disease_states2].copy()
        d = data[Const.primary_disease_states].copy()
        if state < 1:
            d.values[:,:] = 0
        return d
    
    def get_nd(state):
        if state == 2:
            return data[Const.nodal_disease_states2].copy()
        d = data[Const.nodal_disease_states].copy()
        if state < 1:
            d.values[:,:] = 0
        return d
    
    def get_cc(state):
        res = data[Const.ccs].copy()
        if state == 1:
            res.values[:,:] = np.zeros(res.values.shape)
        return res
    
    def get_mod(state):
        res = data[Const.modifications].copy()
        #this should have an ic condition but we don't use it anymore anyway
        return res
        
    outcomedf = data[Const.outcomes]
    baseline = dataset.get_state('baseline')
    
    def formatdf(d,dids=train_ids):
        # `model` is resolved at call time; it is assigned further below.
        d = df_to_torch(d.loc[dids]).to(model.get_device())
        return d
    
    def makegrad(v):
        if not v.requires_grad:
            v.requires_grad=True
        return v
    
    if use_attention:
        model = DecisionAttentionModel(baseline.shape[1],**model_kwargs)
    else:
        model_kwargs = {k:v for k,v in model_kwargs.items() if 'attention' not in k and 'embed' not in k}
        model = DecisionModel(baseline.shape[1],**model_kwargs)
        
    device = 'cpu'
    if use_gpu and torch.cuda.is_available():
        device = 'cuda'
        
    model.set_device(device)

    tmodel1.set_device(device)
    tmodel2.set_device(device)
    tmodel3.set_device(device)
    smodel3.set_device(device)
    # NOTE(review): hash() of a str is salted per interpreter process unless
    # PYTHONHASHSEED is fixed, so this file name is only stable within one
    # kernel session.
    hashcode = str(hash(','.join([str(i) for i in train_ids])))
    
    save_file = save_path + 'model_' + model.identifier +'_hash' + hashcode + file_suffix + '.tar'
    model.fit_normalizer(df_to_torch(baseline.loc[train_ids]))
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)
    
    
    optimal_train,transitions_train = calc_optimal_decisions(dataset,train_ids,tmodel1,tmodel2,tmodel3,smodel3,
                                           weights=weights,tweights=tweights,
                                          )
    optimal_test,transitions_test = calc_optimal_decisions(dataset,test_ids,tmodel1,tmodel2,tmodel3,smodel3,
                                          weights=weights,tweights=tweights,
                                         )
    optimal_train = optimal_train.to(model.get_device())
    optimal_test = optimal_test.to(model.get_device())
    mse = torch.nn.MSELoss()
    bce = torch.nn.BCELoss()
    
    def compare_decisions(d1,d2,d3,ids):
        # Summed BCE between three predicted decisions and the recorded ones.
        # Currently not called from within this function.
        ytrue = df_to_torch(outcomedf.loc[ids])
        dloss = bce(d1.view(-1),ytrue[:,0])
        dloss += bce(d2.view(-1),ytrue[:,1])
        # fixed: was `d3,view(-1)`, which raised NameError on `view`
        dloss += bce(d3.view(-1),ytrue[:,2])
        return dloss
        
    def remove_decisions(df):
        cols = [c for c in df.columns if c not in Const.decisions ]
        ddf = df[cols]
        return ddf
    
    makeinput = lambda step,dids: df_to_torch(remove_decisions(dataset.get_input_state(step=step,ids=dids)))
    threshold = lambda x: torch.gt(x,torch.rand(x.shape[0])).type(torch.FloatTensor)

    randchoice = lambda x: x[torch.randint(len(x),(1,))[0]]
    tloss_func = torch.nn.TripletMarginLoss()
    def get_tloss(row,step,yt,x,imitation=True):
        # Triplet loss for one anchor row: the positive is a random other row
        # with the same label at this step, the negative a random row with a
        # different label. Returns a zero tensor when no valid triplet exists.
        if yt[:,step].std() < .001:
            return torch.tensor([0]).to(device)
        positive_idx= torch.nonzero(yt[:,step] == yt[row,step])
        if len(positive_idx) <= 1:
            print('no losses','n positive',len(positive_idx),'yt',yt[row,step],'row',row,'step',step,'imitation',imitation,end='\r')
            return torch.tensor([0]).to(device)
        positive_idx = torch.stack([ii for ii in positive_idx if ii != row]).view(-1)
        negative_idx = torch.tensor([ii for ii in range(x.shape[0]) if ii not in positive_idx and ii != row])
        if len(positive_idx) < 1 or len(negative_idx) < 1:
            print('no losses','n positive',len(positive_idx),'n negative',len(negative_idx),'yt',yt[row,step],'row',row,'step',step,'imitation',imitation,end='\r')
            return torch.tensor([0]).to(device)
        positive = x[randchoice(positive_idx)]
        negative = x[randchoice(negative_idx)]
        anchor = x[row]
        if use_attention:
            [anchor_embedding,pos_embedding,neg_embedding] = [model.get_embedding(xx.view(1,-1),position=step,use_saved_memory=True) for xx in [anchor,positive,negative]]
        else:    
            [anchor_embedding,pos_embedding,neg_embedding] = [model.get_embedding(xx.view(1,-1),position=step,concatenate=False)[int(imitation)] for xx in [anchor,positive,negative]]
        tloss = tloss_func(anchor_embedding,pos_embedding,neg_embedding)
        return tloss
    #save the inputs from the whole dataset for future reference
    if use_attention:
        full_data = []
        for mstep in [0,1,2]:
            full_data_step = [baseline, get_dlt(min(mstep,1)),
                         get_dlt(mstep),get_pd(mstep),get_nd(mstep),get_cc(mstep),get_mod(mstep)]
            full_data_step = torch.cat([formatdf(fd,true_ids) for fd in full_data_step],axis=1)
            full_data.append(full_data_step)
        full_data = torch.stack(full_data)
        model.save_memory(full_data)
        print(full_data.shape)
        
    def step(train=True):
        # One full pass over the train ids (with an optimizer update) or the
        # test ids (metrics only).
        if train:
            model.train(True)
            tmodel1.train(True)
            tmodel2.train(True)
            tmodel3.train(True)
            optimizer.zero_grad()
            ids = train_ids
            y_opt = makegrad(optimal_train)
            transition_dict = {k: torch.clone(v).detach() for k,v in transitions_train.items()}
        else:
            ids = test_ids
            model.eval()
            tmodel1.eval()
            tmodel2.eval()
            tmodel3.eval()
            y_opt = makegrad(optimal_test)
            print(y_opt.mean(axis=0))
            transition_dict = {k: torch.clone(v).detach() for k,v in transitions_test.items()}
        model.set_device(device)
        ytrain = df_to_torch(outcomedf.loc[ids]).to(device)
        #imitation losses and decision 1
        xxtrained = [baseline, get_dlt(0),get_dlt(0),get_pd(0),get_nd(0),get_cc(0),get_mod(0)]
        xxtrain = [formatdf(xx,ids) for xx in xxtrained]
        xxtrain = torch.cat(xxtrain,axis=1).to(device)
        o1 = model(xxtrain,position=0,use_saved_memory= (not train))
        decision1_imitation = o1[:,3]
        decision1_opt = o1[:,0]
    
        imitation_loss1 = bce(decision1_imitation,ytrain[:,0])
        imitation_loss1 = torch.mul(imitation_loss1,imitation_weights[0])
        opt_loss1 = bce(decision1_opt,y_opt[:,0])
        opt_loss1 = torch.mul(opt_loss1,opt_weights[0])
        
        # Imitation heads for decisions 2 and 3 see the observed states.
        x1_imitation = [baseline, get_dlt(1),get_dlt(0),get_pd(1),get_nd(1),get_cc(1),get_mod(1)]
        x1_imitation = [formatdf(xx1,ids) for xx1 in x1_imitation]
        x1_imitation = torch.cat(x1_imitation,axis=1).to(device)
        
        o2 = model(x1_imitation,position=1,use_saved_memory= (not train))
            
        decision2_imitation = o2[:,4]
            
        imitation_loss2 =  bce(decision2_imitation,ytrain[:,1])
        imitation_loss2 = torch.mul(imitation_loss2,imitation_weights[1])
        
        
        x2_imitation = [baseline, get_dlt(1),get_dlt(2),get_pd(2),get_nd(2),get_cc(2),get_mod(2)]
        
        x2_imitation = [formatdf(xx2,ids) for xx2 in x2_imitation]
        x2_imitation = torch.cat(x2_imitation,axis=1).to(device)
        
        
        o3 = model(x2_imitation,position=2,use_saved_memory= (not train))
        
        decision3_imitation = o3[:,5]
        
        imitation_loss3 = bce(decision3_imitation,ytrain[:,2])
        imitation_loss3 = torch.mul(imitation_loss3,imitation_weights[2])
        
        # Optimal-decision heads for decisions 2 and 3 see the states
        # simulated under the optimal decisions (transition_dict).
        opt_input2 = [
            formatdf(baseline,ids), 
            transition_dict['dlt1'],
            formatdf(get_dlt(0),ids),
            transition_dict['pd1'],
            transition_dict['nd1'], 
            formatdf(get_cc(0),ids),
            transition_dict['mod']
                 ]
        opt_input2 = [o.to(device) for o in opt_input2]

        opt_input2 = torch.cat(opt_input2,axis=1).to(device)
        decision2_opt = model(opt_input2,position=1,use_saved_memory= (not train))[:,1]
        
        opt_loss2 = bce(decision2_opt,y_opt[:,1])
        opt_loss2 = torch.mul(opt_loss2,opt_weights[1])
        
        opt_input3 = [
            formatdf(baseline,ids),
            transition_dict['dlt1'],
            transition_dict['dlt2'],
            transition_dict['pd2'],
            transition_dict['nd2'],
            transition_dict['cc'],
            transition_dict['mod'],
        ]
        opt_input3 = [o.to(device) for o in opt_input3]
        opt_input3 = torch.cat(opt_input3,axis=1).to(device)
        decision3_opt = model(opt_input3,position=2,use_saved_memory= (not train))[:,2]
        
        opt_loss3 = bce(decision3_opt,y_opt[:,2])
        opt_loss3 = torch.mul(opt_loss3,opt_weights[2])
        
        iloss = torch.add(torch.add(imitation_loss1,imitation_loss2),imitation_loss3)
        iloss = torch.mul(iloss,imitation_weight)
        
        reward_loss = torch.add(torch.add(opt_loss1,opt_loss2),opt_loss3)
        reward_loss =torch.mul(reward_loss,reward_weight)
        
        loss = torch.add(iloss,reward_loss)
        
        imitation_tloss = torch.FloatTensor([0]).to(device)
        opt_tloss = torch.FloatTensor([0]).to(device)
        n_rows = xxtrain.shape[0]
        if reward_triplet_weight + imitation_triplet_weight > 0.0001:
            # One sampled triplet per row, per decision and per target.
            for i in range(n_rows):
                
                if imitation_triplet_weight > .0001:
                    imitation_tloss += get_tloss(i,0,ytrain,xxtrain,True)
                    imitation_tloss += get_tloss(i,1,ytrain,x1_imitation,True)
                    imitation_tloss += get_tloss(i,2,ytrain,x2_imitation,True)
                if reward_triplet_weight > .0001:
                    opt_tloss += get_tloss(i,0,y_opt,xxtrain,False)
                    opt_tloss += get_tloss(i,1,y_opt,opt_input2,False)
                    opt_tloss += get_tloss(i,2,y_opt,opt_input3,False)
            loss += torch.mul(imitation_tloss[0],imitation_triplet_weight/n_rows)
            loss += torch.mul(opt_tloss[0],reward_triplet_weight/n_rows)
        losses = [iloss,reward_loss,imitation_tloss*imitation_triplet_weight/n_rows,opt_tloss*reward_triplet_weight/n_rows]
        
        if train:
            loss.backward()
            optimizer.step()
            return losses
        else:
            scores = []
            distributions = [decision1_opt.mean().item(),decision2_opt.mean().item(),decision3_opt.mean().item()]
            imitation = [decision1_imitation,decision2_imitation,decision3_imitation]
            optimal = [decision1_opt,decision2_opt,decision3_opt]
            for i,decision_im in enumerate(imitation):
                deci = decision_im.cpu().detach().numpy()
                deci0 = (deci > .5).astype(int)
                iout = ytrain[:,i].cpu().detach().numpy()
                acci = accuracy_score(iout,deci0)
                # roc_auc_score raises ValueError when the AUC is undefined
                # (e.g. only one class among the targets); -1 marks that case.
                # A bare except would also swallow KeyboardInterrupt here.
                try:
                    auci = roc_auc_score(iout,deci)
                except ValueError:
                    auci = -1
                
                deco = optimal[i].cpu().detach().numpy()
                deci0 = (deco > .5).astype(int)
                oout = y_opt[:,i].cpu().detach().numpy()
                acco = accuracy_score(oout,deci0)
                try:
                    auco = roc_auc_score(oout,deco)
                except ValueError:
                    auco=-1
                scores.append({'decision': i,'optimal_auc': auco,'imitation_auc': auci,'optimal_acc': acco,'imitation_acc': acci})
            return losses, scores, distributions
        
    best_val_loss = torch.tensor(1000000000.0)
    steps_since_improvement = 0
    best_val_score = {}
    # Initialised so the return statement is valid even if no epoch improves
    # on the starting score (NaN loss, epochs == 0).
    best_val_distributions = None
    checkpoint_saved = False
    
    for epoch in range(epochs):
        losses = step(True)
        val_losses,val_metrics,val_distributions = step(False)
        # Validation score: imitation + optimal BCE minus a tenth of the
        # weighted AUCs. The triplet terms are reported but not included.
        vl = val_losses[0] + val_losses[1]
        for vm in val_metrics:
            vl += (-((vm['optimal_auc']*reward_weight) + (vm['imitation_auc']*imitation_weight)))/10
        if verbose:
            print('______epoch',str(epoch),'_____')
            print('val reward',val_losses[1].item())
            print('imitation reward', val_losses[0].item())
            print('distance losses',val_losses[2].item(),val_losses[-1].item())
            print('distributions',val_distributions)
            print(val_metrics)
        if vl < best_val_loss:
            best_val_loss = vl
            best_val_score = val_metrics
            best_val_distributions = val_distributions
            steps_since_improvement = 0
            torch.save(model.state_dict(),save_file)
            checkpoint_saved = True
        else:
            steps_since_improvement += 1
        if steps_since_improvement > patience:
            break
    print('++++++++++Final+++++++++++')
    print('best',best_val_loss)
    print(best_val_score)
    # Only restore a checkpoint written during this call; otherwise a stale
    # file from an earlier run (or a missing file) would be loaded.
    if checkpoint_saved:
        model.load_state_dict(torch.load(save_file))
    model.eval()
    return model, best_val_score, best_val_loss, best_val_distributions

# Re-import so edits to Models are visible in this cell.
from Models import *
# Extra keyword arguments for train_decision_model_triplet.
# NOTE(review): 'shufflecol_chance' matches a named parameter of
# train_decision_model_triplet, so it is bound there rather than ending up in
# **model_kwargs; the trainer body does not read that parameter. Confirm
# whether it was meant to reach the model constructor.
args = {
    'hidden_layers': [1000,1000], 
    'opt_layer_size': 20, 
    'imitation_layer_size': 20, 
    'dropout': 0.25, 
    'input_dropout': 0.1, 
    'shufflecol_chance': 0.5
}

#1.8424
# Train the attention-based decision model; the last element of the returned
# tuple (best validation distributions) is discarded.
decision_model, decision_score, decision_loss, _ = train_decision_model_triplet(
    model1,model2,model3,smodel3,
    lr=.01,
    imitation_weight=1,
    reward_weight=1,
    patience=100,
    imitation_triplet_weight=0.1,
    reward_triplet_weight =0.1, 
    verbose=True,
    weights=[1,1,1,1], #relative weight of survival, feeding tube, aspiration, and lrc
    tweights=[1,1,1,0], #weight for OS, LRC, FDM, and event (any + FT or AS at 6m) as time to event in weeks
    use_attention=True,
    **args)

decision_model

tensor([201.,  20.,   3.]) 389
tensor([68., 14.,  3.]) 147
torch.Size([3, 536, 86])
tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 0 _____
val reward 1.3812651634216309
imitation reward 2.2770237922668457
distance losses 0.3070072829723358 0.26523053646087646
distributions [0.6608559489250183, 0.0013942900113761425, 0.0014365359675139189]
[{'decision': 0, 'optimal_auc': 0.860945644080417, 'imitation_auc': 0.7461832061068702, 'optimal_acc': 0.6326530612244898, 'imitation_acc': 0.891156462585034}, {'decision': 1, 'optimal_auc': 0.7400644468313642, 'imitation_auc': 0.6491847826086956, 'optimal_acc': 0.9047619047619048, 'imitation_acc': 0.782312925170068}, {'decision': 2, 'optimal_auc': 0.5694444444444443, 'imitation_auc': 0.7339478703115068, 'optimal_acc': 0.9795918367346939, 'imitation_acc': 0.8231292517006803}]
tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 1 _____
val reward 1.5529539585113525
imitation reward 1.4362828731536865
distance lo

tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 11 _____
val reward 0.8791236877441406
imitation reward 1.1967928409576416
distance losses 0.31566545367240906 0.14824198186397552
distributions [0.4095092713832855, 0.023562829941511154, 0.00031847116770222783]
[{'decision': 0, 'optimal_auc': 0.9050632911392404, 'imitation_auc': 0.6183206106870229, 'optimal_acc': 0.8095238095238095, 'imitation_acc': 0.891156462585034}, {'decision': 1, 'optimal_auc': 0.9801288936627283, 'imitation_auc': 0.6581521739130436, 'optimal_acc': 0.9047619047619048, 'imitation_acc': 0.782312925170068}, {'decision': 2, 'optimal_auc': 0.9652777777777778, 'imitation_auc': 0.8041958041958042, 'optimal_acc': 0.9795918367346939, 'imitation_acc': 0.8231292517006803}]
tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 12 _____
val reward 0.8810580372810364
imitation reward 1.1940629482269287
distance losses 0.3171918988227844 0.1387118548154831
distributions [0.484589546918869, 0.0

tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 22 _____
val reward 0.8904733657836914
imitation reward 1.4018275737762451
distance losses 0.28110724687576294 0.1115868091583252
distributions [0.588032603263855, 0.016389036551117897, 0.006613300181925297]
[{'decision': 0, 'optimal_auc': 0.9195830230826507, 'imitation_auc': 0.4971374045801527, 'optimal_acc': 0.7210884353741497, 'imitation_acc': 0.891156462585034}, {'decision': 1, 'optimal_auc': 0.9871106337271751, 'imitation_auc': 0.6902173913043479, 'optimal_acc': 0.9047619047619048, 'imitation_acc': 0.782312925170068}, {'decision': 2, 'optimal_auc': 1.0, 'imitation_auc': 0.8165924984106803, 'optimal_acc': 0.9795918367346939, 'imitation_acc': 0.8231292517006803}]
tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 23 _____
val reward 0.9138354659080505
imitation reward 1.4510201215744019
distance losses 0.30525776743888855 0.12184251099824905
distributions [0.5702221989631653, 0.01417624391615390

tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 33 _____
val reward 0.9018863439559937
imitation reward 1.177304744720459
distance losses 0.2850901782512665 0.1240999773144722
distributions [0.5839257836341858, 0.008829777128994465, 0.006210726220160723]
[{'decision': 0, 'optimal_auc': 0.9348473566641846, 'imitation_auc': 0.5033396946564885, 'optimal_acc': 0.8163265306122449, 'imitation_acc': 0.891156462585034}, {'decision': 1, 'optimal_auc': 0.9849624060150377, 'imitation_auc': 0.7353260869565217, 'optimal_acc': 0.9047619047619048, 'imitation_acc': 0.782312925170068}, {'decision': 2, 'optimal_auc': 1.0, 'imitation_auc': 0.7870311506675143, 'optimal_acc': 0.9795918367346939, 'imitation_acc': 0.8231292517006803}]
tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 34 _____
val reward 0.8888338208198547
imitation reward 1.1864805221557617
distance losses 0.2901221811771393 0.1295384168624878
distributions [0.577458381652832, 0.00908784568309784, 0.

tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 44 _____
val reward 1.1159098148345947
imitation reward 1.1364083290100098
distance losses 0.2695266604423523 0.13182352483272552
distributions [0.7634029388427734, 0.006473764311522245, 0.0029707918874919415]
[{'decision': 0, 'optimal_auc': 0.9143708116157855, 'imitation_auc': 0.5424618320610687, 'optimal_acc': 0.54421768707483, 'imitation_acc': 0.891156462585034}, {'decision': 1, 'optimal_auc': 0.9849624060150377, 'imitation_auc': 0.6915760869565217, 'optimal_acc': 0.9047619047619048, 'imitation_acc': 0.782312925170068}, {'decision': 2, 'optimal_auc': 1.0, 'imitation_auc': 0.7781309599491418, 'optimal_acc': 0.9795918367346939, 'imitation_acc': 0.8299319727891157}]
tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 45 _____
val reward 1.0195633172988892
imitation reward 1.1211339235305786
distance losses 0.28122010827064514 0.1242663711309433
distributions [0.7225973010063171, 0.007789247669279575

tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 55 _____
val reward 0.8980113863945007
imitation reward 1.1283552646636963
distance losses 0.3042174279689789 0.12132095545530319
distributions [0.705113410949707, 0.02212892472743988, 0.004563449416309595]
[{'decision': 0, 'optimal_auc': 0.8985480268056589, 'imitation_auc': 0.5854007633587786, 'optimal_acc': 0.5918367346938775, 'imitation_acc': 0.891156462585034}, {'decision': 1, 'optimal_auc': 0.9876476906552094, 'imitation_auc': 0.689945652173913, 'optimal_acc': 0.8979591836734694, 'imitation_acc': 0.7959183673469388}, {'decision': 2, 'optimal_auc': 1.0, 'imitation_auc': 0.770184361093452, 'optimal_acc': 0.9795918367346939, 'imitation_acc': 0.8367346938775511}]
tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 56 _____
val reward 0.8281864523887634
imitation reward 1.1775517463684082
distance losses 0.30676475167274475 0.12882724404335022
distributions [0.6700541973114014, 0.03263627365231514, 

tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 66 _____
val reward 0.7992141246795654
imitation reward 1.1310806274414062
distance losses 0.3061271607875824 0.14379945397377014
distributions [0.6479529142379761, 0.0185040645301342, 0.01696191541850567]
[{'decision': 0, 'optimal_auc': 0.8957557706626955, 'imitation_auc': 0.5834923664122137, 'optimal_acc': 0.7482993197278912, 'imitation_acc': 0.8979591836734694}, {'decision': 1, 'optimal_auc': 0.989795918367347, 'imitation_auc': 0.6894021739130435, 'optimal_acc': 0.9047619047619048, 'imitation_acc': 0.7346938775510204}, {'decision': 2, 'optimal_auc': 1.0, 'imitation_auc': 0.7937062937062938, 'optimal_acc': 0.9931972789115646, 'imitation_acc': 0.8163265306122449}]
tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 67 _____
val reward 0.7096480131149292
imitation reward 1.1401312351226807
distance losses 0.3463682532310486 0.13381615281105042
distributions [0.6141204833984375, 0.041295815259218216,

tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 77 _____
val reward 0.6172623634338379
imitation reward 1.2377021312713623
distance losses 0.30820712447166443 0.11522190272808075
distributions [0.5834125876426697, 0.08263914287090302, 0.01528183277696371]
[{'decision': 0, 'optimal_auc': 0.8879374534623976, 'imitation_auc': 0.5844465648854962, 'optimal_acc': 0.7755102040816326, 'imitation_acc': 0.891156462585034}, {'decision': 1, 'optimal_auc': 0.9860365198711063, 'imitation_auc': 0.6625, 'optimal_acc': 0.9523809523809523, 'imitation_acc': 0.6938775510204082}, {'decision': 2, 'optimal_auc': 1.0, 'imitation_auc': 0.7577876668785759, 'optimal_acc': 0.9931972789115646, 'imitation_acc': 0.8299319727891157}]
tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 78 _____
val reward 0.6023326516151428
imitation reward 1.23009192943573
distance losses 0.3191601634025574 0.11162585020065308
distributions [0.5555200576782227, 0.07848711311817169, 0.0159342475

tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 88 _____
val reward 0.6630216240882874
imitation reward 1.2478861808776855
distance losses 0.27409058809280396 0.12709058821201324
distributions [0.6041780710220337, 0.07407663017511368, 0.014067531563341618]
[{'decision': 0, 'optimal_auc': 0.8884959046909903, 'imitation_auc': 0.6059160305343512, 'optimal_acc': 0.7891156462585034, 'imitation_acc': 0.8979591836734694}, {'decision': 1, 'optimal_auc': 0.9871106337271751, 'imitation_auc': 0.6350543478260869, 'optimal_acc': 0.9455782312925171, 'imitation_acc': 0.782312925170068}, {'decision': 2, 'optimal_auc': 0.9953703703703705, 'imitation_auc': 0.7253655435473617, 'optimal_acc': 0.9863945578231292, 'imitation_acc': 0.8163265306122449}]
tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 89 _____
val reward 0.6345511674880981
imitation reward 1.2493547201156616
distance losses 0.29062482714653015 0.1512180119752884
distributions [0.5849603414535522, 0.0

tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 99 _____
val reward 0.6271182894706726
imitation reward 1.3261520862579346
distance losses 0.2937755286693573 0.11962384730577469
distributions [0.4684772491455078, 0.06483998894691467, 0.01654449850320816]
[{'decision': 0, 'optimal_auc': 0.9013402829486225, 'imitation_auc': 0.6063931297709924, 'optimal_acc': 0.8027210884353742, 'imitation_acc': 0.8843537414965986}, {'decision': 1, 'optimal_auc': 0.9785177228786252, 'imitation_auc': 0.6603260869565217, 'optimal_acc': 0.9319727891156463, 'imitation_acc': 0.7891156462585034}, {'decision': 2, 'optimal_auc': 1.0, 'imitation_auc': 0.7418944691671964, 'optimal_acc': 0.9931972789115646, 'imitation_acc': 0.8435374149659864}]
tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 100 _____
val reward 0.6170574426651001
imitation reward 1.2978616952896118
distance losses 0.3059794306755066 0.12351895123720169
distributions [0.4630011320114136, 0.0696127936244010

tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 110 _____
val reward 0.6155935525894165
imitation reward 1.40386962890625
distance losses 0.29833221435546875 0.12069926410913467
distributions [0.5363115668296814, 0.07960920035839081, 0.014504714868962765]
[{'decision': 0, 'optimal_auc': 0.8946388682055101, 'imitation_auc': 0.6035305343511451, 'optimal_acc': 0.8095238095238095, 'imitation_acc': 0.8979591836734694}, {'decision': 1, 'optimal_auc': 0.9822771213748658, 'imitation_auc': 0.6350543478260868, 'optimal_acc': 0.9455782312925171, 'imitation_acc': 0.7687074829931972}, {'decision': 2, 'optimal_auc': 0.9976851851851852, 'imitation_auc': 0.742212333121424, 'optimal_acc': 0.9863945578231292, 'imitation_acc': 0.7414965986394558}]
tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 111 _____
val reward 0.6159168481826782
imitation reward 1.500162124633789
distance losses 0.3182598948478699 0.13164684176445007
distributions [0.5565698146820068, 0.07

tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 121 _____
val reward 0.5985997319221497
imitation reward 1.525499939918518
distance losses 0.3214562237262726 0.11402206867933273
distributions [0.43890178203582764, 0.10138414800167084, 0.017555363476276398]
[{'decision': 0, 'optimal_auc': 0.8914743112434846, 'imitation_auc': 0.5663167938931297, 'optimal_acc': 0.8163265306122449, 'imitation_acc': 0.8979591836734694}, {'decision': 1, 'optimal_auc': 0.9779806659505907, 'imitation_auc': 0.6320652173913044, 'optimal_acc': 0.9659863945578231, 'imitation_acc': 0.7619047619047619}, {'decision': 2, 'optimal_auc': 0.9953703703703705, 'imitation_auc': 0.7616020343293071, 'optimal_acc': 0.9931972789115646, 'imitation_acc': 0.8231292517006803}]
tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 122 _____
val reward 0.6066659688949585
imitation reward 1.5294181108474731
distance losses 0.2903165817260742 0.10768583416938782
distributions [0.43991410732269287, 

tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 132 _____
val reward 0.5973051190376282
imitation reward 1.4189292192459106
distance losses 0.31826677918434143 0.10679791867733002
distributions [0.5151381492614746, 0.07181533426046371, 0.01768389344215393]
[{'decision': 0, 'optimal_auc': 0.8991064780342517, 'imitation_auc': 0.6159351145038168, 'optimal_acc': 0.8163265306122449, 'imitation_acc': 0.8979591836734694}, {'decision': 1, 'optimal_auc': 0.9790547798066596, 'imitation_auc': 0.648641304347826, 'optimal_acc': 0.9523809523809523, 'imitation_acc': 0.782312925170068}, {'decision': 2, 'optimal_auc': 0.9976851851851852, 'imitation_auc': 0.7666878575969486, 'optimal_acc': 0.9931972789115646, 'imitation_acc': 0.8095238095238095}]
tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 133 _____
val reward 0.5644837021827698
imitation reward 1.4795217514038086
distance losses 0.3014083206653595 0.12117301672697067
distributions [0.4863700568675995, 0.0

tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 143 _____
val reward 0.5704048275947571
imitation reward 1.5037356615066528
distance losses 0.3293847441673279 0.1083701103925705
distributions [0.49506741762161255, 0.09959698468446732, 0.012629622593522072]
[{'decision': 0, 'optimal_auc': 0.9024571854058079, 'imitation_auc': 0.603530534351145, 'optimal_acc': 0.8163265306122449, 'imitation_acc': 0.891156462585034}, {'decision': 1, 'optimal_auc': 0.981203007518797, 'imitation_auc': 0.6046195652173914, 'optimal_acc': 0.9659863945578231, 'imitation_acc': 0.7619047619047619}, {'decision': 2, 'optimal_auc': 1.0, 'imitation_auc': 0.7225047679593134, 'optimal_acc': 0.9863945578231292, 'imitation_acc': 0.7891156462585034}]
tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 144 _____
val reward 0.5913370251655579
imitation reward 1.447631597518921
distance losses 0.30481791496276855 0.11312922090291977
distributions [0.5331571698188782, 0.07525189220905304

tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 154 _____
val reward 0.6525501608848572
imitation reward 1.4187020063400269
distance losses 0.321209192276001 0.14503581821918488
distributions [0.43197667598724365, 0.12321449816226959, 0.03163734823465347]
[{'decision': 0, 'optimal_auc': 0.8845867460908414, 'imitation_auc': 0.5305343511450382, 'optimal_acc': 0.8095238095238095, 'imitation_acc': 0.8843537414965986}, {'decision': 1, 'optimal_auc': 0.9709989258861439, 'imitation_auc': 0.5983695652173914, 'optimal_acc': 0.9523809523809523, 'imitation_acc': 0.7006802721088435}, {'decision': 2, 'optimal_auc': 0.9953703703703705, 'imitation_auc': 0.7774952320406865, 'optimal_acc': 0.9931972789115646, 'imitation_acc': 0.8299319727891157}]
tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 155 _____
val reward 0.6496842503547668
imitation reward 1.3265411853790283
distance losses 0.3342142105102539 0.13281628489494324
distributions [0.4474368989467621, 0.

tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 165 _____
val reward 0.6154537200927734
imitation reward 1.6241118907928467
distance losses 0.33017733693122864 0.1390685886144638
distributions [0.46317410469055176, 0.0951215848326683, 0.024412168189883232]
[{'decision': 0, 'optimal_auc': 0.8959419210722264, 'imitation_auc': 0.5768129770992366, 'optimal_acc': 0.7959183673469388, 'imitation_acc': 0.8843537414965986}, {'decision': 1, 'optimal_auc': 0.9865735767991407, 'imitation_auc': 0.5774456521739131, 'optimal_acc': 0.9523809523809523, 'imitation_acc': 0.6938775510204082}, {'decision': 2, 'optimal_auc': 0.9976851851851852, 'imitation_auc': 0.7727272727272727, 'optimal_acc': 0.9863945578231292, 'imitation_acc': 0.7891156462585034}]
tensor([0.4626, 0.0952, 0.0204], grad_fn=<MeanBackward1>)
______epoch 166 _____
val reward 0.6076890826225281
imitation reward 1.7246925830841064
distance losses 0.36433151364326477 0.16672579944133759
distributions [0.46101903915405273,

DecisionAttentionModel(
  (input_dropout): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0): Linear(in_features=100, out_features=1000, bias=True)
  )
  (batchnorm): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.25, inplace=False)
  (relu): Softplus(beta=1, threshold=20)
  (sigmoid): Sigmoid()
  (softmax): Softmax(dim=1)
  (final_opt_layer): Linear(in_features=1000, out_features=100, bias=True)
  (final_imitation_layer): Linear(in_features=1000, out_features=100, bias=True)
  (final_layer): Linear(in_features=1000, out_features=6, bias=True)
  (resize_layer): Linear(in_features=90, out_features=100, bias=True)
  (attentions): ModuleList(
    (0): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=100, out_features=100, bias=True)
    )
  )
  (norms): ModuleList(
    (0): LayerNorm((100,), eps=1e-05, elementwise_affine=True)
  )
  (activation): ReLU()
)

In [23]:
# Show the best validation metrics as a table. The training loop appends one dict per
# decision with the keys 'decision', 'optimal_auc', 'imitation_auc', 'optimal_acc' and
# 'imitation_acc', so this renders one row per decision.
pd.DataFrame(decision_score)

Unnamed: 0,decision,optimal_auc,imitation_auc,optimal_acc,imitation_acc
0,0,0.9086,0.581107,0.843537,0.891156
1,1,0.964017,0.687228,0.904762,0.768707
2,2,1.0,0.807057,0.979592,0.823129


In [13]:
# Persist the trained policy model and its validation scores to disk.
# set_device is a project-defined method; presumably it moves the model (and any
# device bookkeeping it keeps) to CPU so the saved file loads without a GPU -- confirm.
decision_model.set_device('cpu')
# Saves the whole module object via pickle (not just a state_dict), so loading it
# later requires the model class definition to be importable.
torch.save(decision_model,'../resources/decision_model.pt')
# Written with the default index, so the CSV carries an extra leading index column.
pd.DataFrame(decision_score).to_csv('../results/policy_model_score.csv')

Unnamed: 0,decision,optimal_auc,imitation_auc,optimal_acc,imitation_acc
0,0,0.841137,0.582061,0.789116,0.877551
1,1,0.972805,0.730978,0.92517,0.77551
2,2,0.945078,0.793388,0.897959,0.809524
