In [1]:
import csv
import numpy as np
import pandas as pd
from tqdm import tqdm
tqdm.pandas()
import pickle
import random

#from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc

import tensorflow as tf
from tensorflow import nn
from tensorflow.keras import Model
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.layers import Input, Dense, Lambda
from tensorflow.keras.layers import Embedding, Layer, Input, Dense, Lambda, Concatenate, Add
from tensorflow.keras.models import Model
from tensorflow.python.ops import array_ops

# this file is needed for dropout operations, as the code used for STraTS is based on a very old Keras version
import smart_cond_mod as sc

from pathlib import Path
from tensorflow.keras.callbacks import CSVLogger

tf.keras.utils.set_random_seed(1)
random.seed(100)
np.random.seed(100)

In [2]:
with open("./data/mimic_and_sbert_for_thesis.pkl", "rb") as pfile:
    raw_data = pickle.load(pfile)
mimic = raw_data[0]
meta = raw_data[1]
train_ind = raw_data[2]
valid_ind = raw_data[3]
test_ind = raw_data[4]
data = mimic
oc = meta

In [3]:
ids = oc['SUBJECT_ID'].tolist()
labels = oc['in_hospital_sepsis'].tolist()

new_patient_ids = []
new_labels = []

for i in range(len(labels)):
  # print(i)
  if ids[i] in new_patient_ids:
    continue
  else:
    new_patient_ids.append(ids[i])
    new_labels.append(labels[i])

In [4]:
from collections import Counter
# data ratio
Counter(new_labels)

Counter({0: 33592, 1: 3263})

In [5]:

x, x_test, y, y_test = train_test_split(new_patient_ids, new_labels, test_size=0.2, random_state=1)
x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2, random_state=1)
# train
train_ind = []

ts_ind = oc['ts_ind'].tolist()
# ids = ids

for i in range(len(ts_ind)):
  if ids[i] in x_train:
    train_ind.append(ts_ind[i])

# number of train patients
print(len(x_train))
# number of train instances
print(len(train_ind))
# to np.array
train_ind = np.array(train_ind)

test_ind = []

for i in range(len(ts_ind)):
  if ids[i] in x_test:
    test_ind.append(ts_ind[i])

# number of test patients
print(len(x_test))
# number of test instances
print(len(test_ind))
# to np.array
test_ind = np.array(test_ind)

valid_ind = []

for i in range(len(ts_ind)):
  if ids[i] in x_val:
    valid_ind.append(ts_ind[i])

# number of test patients
print(len(x_val))
# number of test instances
print(len(valid_ind))
# to np.array
valid_ind = np.array(valid_ind)

23587
31708
7371
9894
5897
7803


In [6]:
pred_window = 1 # hours that the output vector represents. 1 because i want to learn to predict 1 hour many times
obs_windows = range(20, 124, 4)

In [7]:
# Remove test patients.
data = data.merge(oc[['ts_ind', 'SUBJECT_ID']], on='ts_ind', how='left')
test_sub = oc.loc[oc.ts_ind.isin(test_ind)].SUBJECT_ID.unique()
data = data.loc[~data.SUBJECT_ID.isin(test_sub)]
oc = oc.loc[~oc.SUBJECT_ID.isin(test_sub)]
data.drop(columns=['SUBJECT_ID'], inplace=True)
# Fix age.
data.loc[(data.variable=='Age')&(data.value>200), 'value'] = 91.4
# Get static data with mean fill and missingness indicator.
static_varis = ['Age', 'Gender']
ii = data.variable.isin(static_varis)
static_data = data.loc[ii]
data = data.loc[~ii]
def inv_list(l, start=0):
    d = {}
    for i in range(len(l)):
        d[l[i]] = i+start
    return d
static_var_to_ind = inv_list(static_varis)
D = len(static_varis)
N = data.ts_ind.max()+1
demo = np.zeros((N, D))
for row in tqdm(static_data.itertuples()):
    demo[row.ts_ind, static_var_to_ind[row.variable]] = row.value
# Normalize static data.
means = demo.mean(axis=0, keepdims=True)
stds = demo.std(axis=0, keepdims=True)
stds = (stds==0)*1 + (stds!=0)*stds
demo = (demo-means)/stds
# Get variable indices.
varis = sorted(list(set(data.variable)))
V = len(varis)
var_to_ind = inv_list(varis, start=1)
data['vind'] = data.variable.map(var_to_ind)
data = data[['ts_ind', 'vind', 'hour', 'value']].sort_values(by=['ts_ind', 'vind', 'hour'])
# Find max_len.
fore_max_len = 880
# Get forecast inputs and outputs.
fore_times_ip = []
fore_values_ip = []
fore_varis_ip = []
fore_op = []
fore_inds = []
def f(x):
    mask = [0 for i in range(V)]
    values = [0 for i in range(V)]
    for vv in x:
        v = int(vv[0])-1
        mask[v] = 1
        values[v] = vv[1]
    return values+mask
def pad(x):
    return x+[0]*(fore_max_len-len(x))
for w in tqdm(obs_windows):
    pred_data = data.loc[(data.hour>=w)&(data.hour<=w+pred_window)]
    pred_data = pred_data.groupby(['ts_ind', 'vind']).agg({'value':'first'}).reset_index()
    pred_data['vind_value'] = pred_data[['vind', 'value']].values.tolist()
    pred_data = pred_data.groupby('ts_ind').agg({'vind_value':list}).reset_index()
    pred_data['vind_value'] = pred_data['vind_value'].apply(f)
    obs_data = data.loc[(data.hour<w)&(data.hour>=w-24)]
    obs_data = obs_data.loc[obs_data.ts_ind.isin(pred_data.ts_ind)]
    obs_data = obs_data.groupby('ts_ind').head(fore_max_len)
    obs_data = obs_data.groupby('ts_ind').agg({'vind':list, 'hour':list, 'value':list}).reset_index()
    obs_data = obs_data.merge(pred_data, on='ts_ind')
    for col in ['vind', 'hour', 'value']:
        obs_data[col] = obs_data[col].apply(pad)
    fore_op.append(np.array(list(obs_data.vind_value)))
    fore_inds.append(np.array(list(obs_data.ts_ind)))
    fore_times_ip.append(np.array(list(obs_data.hour)))
    fore_values_ip.append(np.array(list(obs_data.value)))
    fore_varis_ip.append(np.array(list(obs_data.vind)))
del data
fore_times_ip = np.concatenate(fore_times_ip, axis=0)
fore_values_ip = np.concatenate(fore_values_ip, axis=0)
fore_varis_ip = np.concatenate(fore_varis_ip, axis=0)
fore_op = np.concatenate(fore_op, axis=0)
fore_inds = np.concatenate(fore_inds, axis=0)
fore_demo = demo[fore_inds]
# Get train and valid ts_ind for forecast task.
train_sub = oc.loc[oc.ts_ind.isin(train_ind)].SUBJECT_ID.unique()
valid_sub = oc.loc[oc.ts_ind.isin(valid_ind)].SUBJECT_ID.unique()
rem_sub = oc.loc[~oc.SUBJECT_ID.isin(np.concatenate((train_ind, valid_ind)))].SUBJECT_ID.unique()
bp = int(0.8*len(rem_sub))
train_sub = np.concatenate((train_sub, rem_sub[:bp]))
valid_sub = np.concatenate((valid_sub, rem_sub[bp:]))
train_ind = oc.loc[oc.SUBJECT_ID.isin(train_sub)].ts_ind.unique() # Add remaining ts_ind s of train subjects.
valid_ind = oc.loc[oc.SUBJECT_ID.isin(valid_sub)].ts_ind.unique() # Add remaining ts_ind s of train subjects.
# Generate 3 sets of inputs and outputs.
train_ind = np.argwhere(np.in1d(fore_inds, train_ind)).flatten()
valid_ind = np.argwhere(np.in1d(fore_inds, valid_ind)).flatten()

# INPUT
# 4*sequencenumber*880
# (demographics, hours, values, feature names) * (observation sequences) * 880
fore_train_ip = [ip[train_ind] for ip in [fore_demo, fore_times_ip, fore_values_ip, fore_varis_ip]]
fore_valid_ip = [ip[valid_ind] for ip in [fore_demo, fore_times_ip, fore_values_ip, fore_varis_ip]]
del fore_times_ip, fore_values_ip, fore_varis_ip, demo, fore_demo

# OUTPUT
# sequencenumber * (2 * # of features)
# sequencenumber * (feature values + feature mask)
# for SBERT and TF-IDF, noise: sequencenumber * (183*2)
# for BASE and TF-IDF: sequencenumber * (133*2)
fore_train_op = fore_op[train_ind]
fore_valid_op = fore_op[valid_ind]
del fore_op

79022it [00:00, 389272.63it/s]
100%|██████████| 26/26 [03:16<00:00,  7.56s/it]


In [8]:
# number of 24 hour sequences
len(fore_train_ip[0])

395850

In [9]:
fore_train_ip[3][19999]

array([  8,  10,  13,  21,  24,  24,  24,  26,  29,  29,  29,  29,  29,
        29,  29,  29,  29,  29,  29,  29,  29,  29,  29,  29,  29,  29,
        30,  43,  43,  43,  43,  43,  43,  43,  43,  43,  43,  43,  43,
        43,  43,  43,  43,  43,  43,  43,  43,  43,  44,  44,  44,  44,
        44,  44,  44,  44,  44,  44,  44,  44,  44,  44,  44,  44,  44,
        44,  44,  44,  44,  45,  45,  45,  45,  45,  45,  45,  45,  45,
        45,  45,  45,  45,  45,  45,  45,  45,  45,  45,  45,  45,  49,
        49,  49,  49,  50,  50,  50,  50,  52,  52,  52,  52,  52,  52,
        52,  52,  52,  52,  52,  52,  52,  52,  52,  52,  52,  52,  52,
        52,  54,  57,  58,  60,  61,  76,  76,  76,  76,  76,  76,  76,
        76,  76,  76,  76,  76,  76,  76,  76,  76,  76,  76,  77,  78,
        79,  80,  88,  94,  94,  94,  94,  94,  94,  94,  94,  95,  95,
        95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,
        95,  95,  95,  95, 100, 101, 104, 105, 105, 107, 108, 10

In [53]:
#print("Output length Full:", len(fore_train_op), "Output length Single", len(fore_train_op[0]))
#print("One Output Full:\n", fore_train_op[1])

#print("features length Full:", len(fore_train_ip[3]), "features length Single", len(fore_train_ip[3][0]))

print("One feature ID Full:", Counter(fore_train_ip[3][2]))
print("One feature ID Full:", fore_train_ip[3][2])

#print("One feature value Full:", fore_train_ip[2][0])

One feature ID Full: Counter({95: 25, 114: 25, 9: 24, 29: 24, 76: 24, 116: 24, 52: 23, 0: 20, 28: 16, 37: 16, 39: 16, 71: 16, 93: 16, 94: 16, 111: 16, 126: 16, 127: 16, 123: 12, 38: 9, 134: 9, 135: 9, 136: 9, 137: 9, 138: 9, 139: 9, 140: 9, 141: 9, 142: 9, 143: 9, 144: 9, 145: 9, 146: 9, 147: 9, 148: 9, 149: 9, 150: 9, 151: 9, 152: 9, 153: 9, 154: 9, 155: 9, 156: 9, 157: 9, 158: 9, 159: 9, 160: 9, 161: 9, 162: 9, 163: 9, 164: 9, 165: 9, 166: 9, 167: 9, 168: 9, 169: 9, 170: 9, 171: 9, 172: 9, 173: 9, 174: 9, 175: 9, 176: 9, 177: 9, 178: 9, 179: 9, 180: 9, 181: 9, 182: 9, 183: 9, 43: 7, 44: 7, 45: 7, 11: 4, 49: 4, 97: 4, 99: 4, 124: 4, 132: 4, 48: 3, 65: 3, 8: 1, 10: 1, 13: 1, 19: 1, 21: 1, 24: 1, 26: 1, 50: 1, 54: 1, 57: 1, 60: 1, 70: 1, 77: 1, 78: 1, 79: 1, 80: 1, 100: 1, 101: 1, 104: 1, 107: 1, 108: 1, 112: 1, 113: 1, 118: 1, 130: 1})
One feature ID Full: [  8   9   9   9   9   9   9   9   9   9   9   9   9   9   9   9   9   9
   9   9   9   9   9   9   9  10  11  11  11  11  13  19  

In [13]:
text_ids = set(range(134,184,1))
missing_text_list = []
all_seq = len(fore_train_ip[3])
# for each seq look at feature names
for i in range(len(fore_train_ip[3])):
    # count which feature names are in sequence
    counts = Counter(fore_train_ip[3][i])

    # append length of list containing all missing text features in sequence
    missing_text_list.append(len(list(text_ids - counts.keys())))

# for all 24 hour sequences:
# all text features present: no text features missing = Counter(missing_text_list)[0]
full_text_features = Counter(missing_text_list)[0]
# all sequences with some text missing: all seq - sequences with full text features
missing_some_text = all_seq-full_text_features
# full text missing: all 50 features missing
missing_50 = Counter(missing_text_list)[50]
# at least 1 text feature: all sequences - sequences where all 50 features missing
min_1_text =  all_seq - missing_50
d = {"All_Text_Features": [full_text_features / all_seq*100], "Min1_Text_Features": [min_1_text / all_seq*100], "SomeMiss_Text_Features": [missing_some_text / all_seq*100], "all_text_miss": [missing_50 / all_seq*100]}
text_percentages = pd.DataFrame.from_dict(d)
text_percentages

Unnamed: 0,All_Text_Features,Min1_Text_Features,SomeMiss_Text_Features,all_text_miss
0,74.199571,87.928508,25.800429,12.071492


In [122]:
# all text
full_text_features = Counter(missing_text_list)[0]
# all sequences with some text missing
missing_some_text = len(fore_train_ip[3])-Counter(missing_text_list)[0]
# full text missing
missing_50 = Counter(missing_text_list)[50]
# of sequences with out all text sequences where at least 1 text feature is in
min_1_text =  full_text_features - missing_50


In [123]:
missing_some_text

5326

In [124]:
missing_50

3241

In [125]:
min_1_text

13882

In [126]:
full_text_features

17123

In [127]:
all_seq = len(fore_train_ip[3])

22449

In [130]:
50 / 200 *100

25.0

In [69]:
counts = Counter(fore_train_ip[3][0])
if counts.keys().__contains__(100000):
    print("true")
else:
    print("false")


false


In [10]:
def get_res(y_true, y_pred):
    precision, recall, thresholds = precision_recall_curve(y_true, y_pred)
    pr_auc = auc(recall, precision)
    minrp = np.minimum(precision, recall).max()
    roc_auc = roc_auc_score(y_true, y_pred)
    return [roc_auc, pr_auc, minrp]


#class_weights = compute_class_weight(class_weight='balanced', classes=[0,1], y=train_op)
#def mortality_loss(y_true, y_pred):
    #sample_weights = (1-y_true)*class_weights[0] + y_true*class_weights[1]
    #bce = K.binary_crossentropy(y_true, y_pred)
    #return K.mean(sample_weights*bce, axis=-1)


# var_weights = np.sum(fore_train_op[:, V:], axis=0)
# var_weights[var_weights==0] = var_weights.max()
# var_weights = var_weights.max()/var_weights
# var_weights = var_weights.reshape((1, V))
def forecast_loss(y_true, y_pred):
    return K.sum(y_true[:,V:]*(y_true[:,:V]-y_pred)**2, axis=-1)
                                          
def get_min_loss(weight):
    def min_loss(y_true, y_pred):
        return weight*y_pred
    return min_loss

class CustomCallback(Callback):
    def __init__(self, validation_data, batch_size):
        self.val_x, self.val_y = validation_data
        self.batch_size = batch_size
        super(Callback, self).__init__()

    def on_epoch_end(self, epoch, logs={}):
        y_pred = self.model.predict(self.val_x, verbose=0, batch_size=self.batch_size)
        if type(y_pred)==type([]):
            y_pred = y_pred[0]
        precision, recall, thresholds = precision_recall_curve(self.val_y, y_pred)
        pr_auc = auc(recall, precision)
        roc_auc = roc_auc_score(self.val_y, y_pred)
        logs['custom_metric'] = pr_auc + roc_auc
        print ('val_aucs:', pr_auc, roc_auc)

In [11]:
class CVE(Layer):
    def __init__(self, hid_units, output_dim):
        self.hid_units = hid_units
        self.output_dim = output_dim
        super(CVE, self).__init__()

    def build(self, input_shape):
        self.W1 = self.add_weight(name='CVE_W1',
                            shape=(1, self.hid_units),
                            initializer='glorot_uniform',
                            trainable=True)
        self.b1 = self.add_weight(name='CVE_b1',
                            shape=(self.hid_units,),
                            initializer='zeros',
                            trainable=True)
        self.W2 = self.add_weight(name='CVE_W2',
                            shape=(self.hid_units, self.output_dim),
                            initializer='glorot_uniform',
                            trainable=True)
        super(CVE, self).build(input_shape)

    def call(self, x):
        x = K.expand_dims(x, axis=-1)
        x = K.dot(K.tanh(K.bias_add(K.dot(x, self.W1), self.b1)), self.W2)
        return x

    def compute_output_shape(self, input_shape):
        return input_shape + (self.output_dim,)


class Attention(Layer):

    def __init__(self, hid_dim):
        self.hid_dim = hid_dim
        super(Attention, self).__init__()

    def build(self, input_shape):
        d = input_shape.as_list()[-1]
        self.W = self.add_weight(shape=(d, self.hid_dim), name='Att_W',
                                 initializer='glorot_uniform',
                                 trainable=True)
        self.b = self.add_weight(shape=(self.hid_dim,), name='Att_b',
                                 initializer='zeros',
                                 trainable=True)
        self.u = self.add_weight(shape=(self.hid_dim,1), name='Att_u',
                                 initializer='glorot_uniform',
                                 trainable=True)
        super(Attention, self).build(input_shape)

    def call(self, x, mask, mask_value=-1e30):
        attn_weights = K.dot(K.tanh(K.bias_add(K.dot(x,self.W), self.b)), self.u)
        mask = K.expand_dims(mask, axis=-1)
        attn_weights = mask*attn_weights + (1-mask)*mask_value
        attn_weights = K.softmax(attn_weights, axis=-2)
        return attn_weights

    def compute_output_shape(self, input_shape):
        return input_shape[:-1] + (1,)


class Transformer(Layer):

    def __init__(self, N=2, h=8, dk=None, dv=None, dff=None, dropout=0):
        self.N, self.h, self.dk, self.dv, self.dff, self.dropout = N, h, dk, dv, dff, dropout
        self.epsilon = K.epsilon() * K.epsilon()
        super(Transformer, self).__init__()

    def build(self, input_shape):
        d = input_shape.as_list()[-1]
        if self.dk==None:
            self.dk = d//self.h
        if self.dv==None:
            self.dv = d//self.h
        if self.dff==None:
            self.dff = 2*d
        self.Wq = self.add_weight(shape=(self.N, self.h, d, self.dk), name='Wq',
                                 initializer='glorot_uniform', trainable=True)
        self.Wk = self.add_weight(shape=(self.N, self.h, d, self.dk), name='Wk',
                                 initializer='glorot_uniform', trainable=True)
        self.Wv = self.add_weight(shape=(self.N, self.h, d, self.dv), name='Wv',
                                 initializer='glorot_uniform', trainable=True)
        self.Wo = self.add_weight(shape=(self.N, self.dv*self.h, d), name='Wo',
                                 initializer='glorot_uniform', trainable=True)
        self.W1 = self.add_weight(shape=(self.N, d, self.dff), name='W1',
                                 initializer='glorot_uniform', trainable=True)
        self.b1 = self.add_weight(shape=(self.N, self.dff), name='b1',
                                 initializer='zeros', trainable=True)
        self.W2 = self.add_weight(shape=(self.N, self.dff, d), name='W2',
                                 initializer='glorot_uniform', trainable=True)
        self.b2 = self.add_weight(shape=(self.N, d), name='b2',
                                 initializer='zeros', trainable=True)
        self.gamma = self.add_weight(shape=(2*self.N,), name='gamma',
                                 initializer='ones', trainable=True)
        self.beta = self.add_weight(shape=(2*self.N,), name='beta',
                                 initializer='zeros', trainable=True)
        super(Transformer, self).build(input_shape)

    def call(self, x, mask, mask_value=-1e-30):
        mask = K.expand_dims(mask, axis=-2)
        for i in range(self.N):
            # MHA
            mha_ops = []
            for j in range(self.h):
                q = K.dot(x, self.Wq[i,j,:,:])
                k = K.permute_dimensions(K.dot(x, self.Wk[i,j,:,:]), (0,2,1))
                v = K.dot(x, self.Wv[i,j,:,:])
                A = K.batch_dot(q,k)
                # Mask unobserved steps.
                A = mask*A + (1-mask)*mask_value
                # Mask for attention dropout.
                def dropped_A():
                    dp_mask = K.cast((K.random_uniform(shape=array_ops.shape(A))>=self.dropout), K.floatx())
                    return A*dp_mask + (1-dp_mask)*mask_value
                A = sc.smart_cond(K.learning_phase(), dropped_A, lambda: array_ops.identity(A))
                A = K.softmax(A, axis=-1)
                mha_ops.append(K.batch_dot(A,v))
            conc = K.concatenate(mha_ops, axis=-1)
            proj = K.dot(conc, self.Wo[i,:,:])
            # Dropout.
            proj = sc.smart_cond(K.learning_phase(), lambda: array_ops.identity(nn.dropout(proj, rate=self.dropout)),\
                                       lambda: array_ops.identity(proj))
            # Add & LN
            x = x+proj
            mean = K.mean(x, axis=-1, keepdims=True)
            variance = K.mean(K.square(x - mean), axis=-1, keepdims=True)
            std = K.sqrt(variance + self.epsilon)
            x = (x - mean) / std
            x = x*self.gamma[2*i] + self.beta[2*i]
            # FFN
            ffn_op = K.bias_add(K.dot(K.relu(K.bias_add(K.dot(x, self.W1[i,:,:]), self.b1[i,:])),
                           self.W2[i,:,:]), self.b2[i,:,])
            # Dropout.
            ffn_op = sc.smart_cond(K.learning_phase(), lambda: array_ops.identity(nn.dropout(ffn_op, rate=self.dropout)),\
                                       lambda: array_ops.identity(ffn_op))
            # Add & LN
            x = x+ffn_op
            mean = K.mean(x, axis=-1, keepdims=True)
            variance = K.mean(K.square(x - mean), axis=-1, keepdims=True)
            std = K.sqrt(variance + self.epsilon)
            x = (x - mean) / std
            x = x*self.gamma[2*i+1] + self.beta[2*i+1]
        return x

    def compute_output_shape(self, input_shape):
        return input_shape


def build_strats(D, max_len, V, d, N, he, dropout, forecast=False):
    demo = Input(shape=(D,))
    demo_enc = Dense(2*d, activation='tanh')(demo)
    demo_enc = Dense(d, activation='tanh')(demo_enc)
    varis = Input(shape=(max_len,))
    values = Input(shape=(max_len,))
    times = Input(shape=(max_len,))
    varis_emb = Embedding(V+1, d)(varis)
    cve_units = int(np.sqrt(d))
    values_emb = CVE(cve_units, d)(values)
    times_emb = CVE(cve_units, d)(times)
    comb_emb = Add()([varis_emb, values_emb, times_emb]) # b, L, d
#     demo_enc = Lambda(lambda x:K.expand_dims(x, axis=-2))(demo_enc) # b, 1, d
#     comb_emb = Concatenate(axis=-2)([demo_enc, comb_emb]) # b, L+1, d
    mask = Lambda(lambda x:K.clip(x,0,1))(varis) # b, L
#     mask = Lambda(lambda x:K.concatenate((K.ones_like(x)[:,0:1], x), axis=-1))(mask) # b, L+1
    cont_emb = Transformer(N, he, dk=None, dv=None, dff=None, dropout=dropout)(comb_emb, mask=mask)
    attn_weights = Attention(2*d)(cont_emb, mask=mask)
    fused_emb = Lambda(lambda x:K.sum(x[0]*x[1], axis=-2))([cont_emb, attn_weights])
    conc = Concatenate(axis=-1)([fused_emb, demo_enc])
    fore_op = Dense(V)(conc)
    op = Dense(1, activation='sigmoid')(fore_op)
    model = Model([demo, times, values, varis], op)
    if forecast:
        fore_model = Model([demo, times, values, varis], fore_op)
        return [model, fore_model]
    return model

In [13]:
#surely can be done more elegant...
fore_savepath = './models/EXP3_STraTS_SBERT'

train_FILE_PATH = Path(f'{fore_savepath}/train_losses.csv')
val_FILE_PATH = Path(f'{fore_savepath}/val_losses.csv')

# initialize model parameters
lr, batch_size, samples_per_epoch, patience = 0.0005, 32, len(fore_train_op), 5
d, N, he, dropout = 50, 2, 4, 0.2
model, fore_model =  build_strats(D, fore_max_len, V, d, N, he, dropout, forecast=True)
print(fore_model.summary())
lossfunction = forecast_loss
opt = tf.keras.optimizers.Adam(lr)
fore_model.compile(loss=lossfunction, optimizer=opt)

# initialize checkpoint manager
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=fore_model)
manager = tf.train.CheckpointManager(ckpt, f'{fore_savepath}', max_to_keep=3)

# define training procedure
def train_and_checkpoint(net, manager):
  # initialize loss, etc
  best_val_loss = np.inf
  N_fore = len(fore_train_op)
  train_losses = []
  val_losses = []

  # load or create model
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  # training 
  for e in range(1000):
    np.random.seed(100)
    e_indices = np.random.choice(range(N_fore), size=samples_per_epoch, replace=False)
    e_loss = 0
    pbar = tqdm(range(0, len(e_indices), batch_size))
    for start in pbar:
        ind = e_indices[start:start+batch_size]
        # pre-train data
        e_loss += net.train_on_batch([ip[ind] for ip in fore_train_ip], fore_train_op[ind])
        pbar.set_description('%f'%(e_loss/(start+1)))
    
    # validate at end of epoch
    val_loss = net.evaluate(fore_valid_ip, fore_valid_op, batch_size=batch_size, verbose=1)
    print ('Epoch', e, 'loss', e_loss*batch_size/samples_per_epoch, 'val loss', val_loss)
    #train_losses.append(e_loss*batch_size/samples_per_epoch)
    #val_losses.append(val_loss)
    
    # save best checkpoint
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_path = manager.save()
        print("Saved new best checkpoint for step {}: {}".format(int(ckpt.step), save_path))
        best_epoch = e
    
      # save train and val losses for visualization
    if train_FILE_PATH.exists():
      with open(train_FILE_PATH, 'a') as lo:
          reader = csv.writer(lo)
          reader.writerow([str(e_loss*batch_size/samples_per_epoch)])
      with open(val_FILE_PATH, 'a') as val_lo:
              reader = csv.writer(val_lo)
              reader.writerow([val_loss])

    if not train_FILE_PATH.exists():
        with open(train_FILE_PATH, 'w') as lo:
            reader = csv.writer(lo)
            reader.writerow([e_loss*batch_size/samples_per_epoch])
        with open(val_FILE_PATH, 'w') as val_lo:
            reader = csv.writer(val_lo)
            reader.writerow([val_loss])  
    
    ckpt.step.assign_add(1)

    if (e-best_epoch)>patience:
        break

2024-03-23 15:16:57.274924: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1886] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 31133 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:b2:00.0, compute capability: 7.0


Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_2 (InputLayer)        [(None, 880)]                0         []                            
                                                                                                  
 input_3 (InputLayer)        [(None, 880)]                0         []                            
                                                                                                  
 input_4 (InputLayer)        [(None, 880)]                0         []                            
                                                                                                  
 embedding (Embedding)       (None, 880, 50)              9200      ['input_2[0][0]']             
                                                                                            

In [None]:
# training
train_and_checkpoint(fore_model, manager)

Initializing from scratch.


  0%|          | 0/12371 [00:00<?, ?it/s]2024-03-23 15:17:03.330634: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x1535e80d5330 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2024-03-23 15:17:03.330669: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Tesla V100-SXM2-32GB, Compute Capability 7.0
2024-03-23 15:17:03.335738: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2024-03-23 15:17:03.703032: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:442] Loaded cuDNN version 8700
2024-03-23 15:17:03.796295: I ./tensorflow/compiler/jit/device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
0.209229:  93%|█████████▎| 11523/12371 [08:57<00:48, 17.48it/s]

Epoch 0 loss 6.63968187382987 val loss 6.2129902839660645
Saved new best checkpoint for step 1: ./sentence_bert_STraTS_20-124_masked_transformer_predsize_1V2/ckpt-1


0.179497: 100%|██████████| 12371/12371 [09:22<00:00, 21.98it/s]


Epoch 1 loss 5.743760733344842 val loss 5.862745761871338
Saved new best checkpoint for step 2: ./sentence_bert_STraTS_20-124_masked_transformer_predsize_1V2/ckpt-2


0.169943: 100%|██████████| 12371/12371 [09:21<00:00, 22.03it/s]


Epoch 2 loss 5.438046396613738 val loss 5.591403961181641
Saved new best checkpoint for step 3: ./sentence_bert_STraTS_20-124_masked_transformer_predsize_1V2/ckpt-3


0.164330: 100%|██████████| 12371/12371 [09:21<00:00, 22.05it/s]


Epoch 3 loss 5.2584352629886695 val loss 5.4886393547058105
Saved new best checkpoint for step 4: ./sentence_bert_STraTS_20-124_masked_transformer_predsize_1V2/ckpt-4


0.160931: 100%|██████████| 12371/12371 [09:21<00:00, 22.04it/s]


Epoch 4 loss 5.149690664305115 val loss 5.380105018615723
Saved new best checkpoint for step 5: ./sentence_bert_STraTS_20-124_masked_transformer_predsize_1V2/ckpt-5


0.158213: 100%|██████████| 12371/12371 [09:21<00:00, 22.02it/s]


Epoch 5 loss 5.0627011319057225 val loss 5.298466682434082
Saved new best checkpoint for step 6: ./sentence_bert_STraTS_20-124_masked_transformer_predsize_1V2/ckpt-6


0.155947: 100%|██████████| 12371/12371 [09:30<00:00, 21.68it/s]


Epoch 6 loss 4.990175031541769 val loss 5.238565921783447
Saved new best checkpoint for step 7: ./sentence_bert_STraTS_20-124_masked_transformer_predsize_1V2/ckpt-7


0.153996: 100%|██████████| 12371/12371 [09:21<00:00, 22.02it/s]


Epoch 7 loss 4.927766759201962 val loss 5.163273334503174
Saved new best checkpoint for step 8: ./sentence_bert_STraTS_20-124_masked_transformer_predsize_1V2/ckpt-8


0.152718: 100%|██████████| 12371/12371 [09:20<00:00, 22.06it/s]


Epoch 8 loss 4.886874525930991 val loss 5.137864589691162
Saved new best checkpoint for step 9: ./sentence_bert_STraTS_20-124_masked_transformer_predsize_1V2/ckpt-9


0.151546: 100%|██████████| 12371/12371 [09:26<00:00, 21.84it/s]


