In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Dropout, BatchNormalization, Lambda
from tensorflow.keras.models import Model
from tensorflow.keras.losses import MeanSquaredError
# from tensorflow.keras import models
import pickle
import numpy as np
from tqdm import tqdm
tqdm.pandas()
# from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras.optimizers import Adam
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import Callback, EarlyStopping
import pandas as pd
from collections import Counter

In [2]:
pickle.format_version

'4.0'

## Load Data

### no_text_lab

In [3]:
lab = pd.read_csv("patients_no_text_sepsis_labels.csv")

In [4]:
hours = pd.read_csv("lab_hours.csv")

In [5]:
def missing_values_table(df):
    mis_val = df.isnull().sum()
    mis_val_percent = 100 * df.isnull().sum() / len(df)
    mis_val_table = pd.concat([mis_val, mis_val_percent], axis=1)
    mis_val_table_ren_columns = mis_val_table.rename(
    columns = {0 : 'Missing Values', 1 : '% of Total Values'})
    mis_val_table_ren_columns = mis_val_table_ren_columns[
        mis_val_table_ren_columns.iloc[:,1] != 0].sort_values(
    '% of Total Values', ascending=False).round(1)
    print ("Your selected dataframe has " + str(df.shape[1]) + " columns.\n"      
        "There are " + str(mis_val_table_ren_columns.shape[0]) +
            " columns that have missing values.")
    return mis_val_table_ren_columns

In [6]:
missing_values_table(lab)

Your selected dataframe has 57 columns.
There are 3 columns that have missing values.


Unnamed: 0,Missing Values,% of Total Values
o:output_4hourly,1102706,55.4
o:output_total,1102706,55.4
o:PaO2_FiO2,889233,44.6


### transform into strats format + charttime -> hour (computed once on laptop to save time)

In [7]:
lab["hour"] = hours["hour"]

In [8]:
rows = lab.drop(columns=["m:charttime", "traj", "step", "r:reward", "a:action","Unnamed: 0", "Unnamed: 0.1", "m:presumed_onset", "o:cumulated_balance", "o:re_admission", "o:output_4hourly", "o:output_total", "o:PaO2_FiO2"])
rows

Unnamed: 0,m:icustayid,o:Arterial_BE,o:Arterial_lactate,o:Arterial_pH,o:BUN,o:Calcium,o:Chloride,o:Creatinine,o:DiaBP,o:FiO2_1,...,o:age,o:gender,o:input_4hourly,o:input_total,o:max_dose_vaso,o:mechvent,o:paCO2,o:paO2,sepsis_label,hour
0,200003.0,1.071661,-0.560359,1.069870,-0.380498,0.094949,-0.115152,-0.151959,-0.098663,0.270372,...,-0.963438,-0.5,0.347576,0.692126,-2.302585,-0.5,-0.374712,-0.689066,1,0.0
1,200003.0,1.071661,-0.560359,1.069870,-0.380498,0.094949,-0.115152,-0.151959,-0.093255,0.270372,...,-0.963438,-0.5,0.347576,0.694322,-2.302585,-0.5,-0.374712,-0.689066,1,1.0
2,200003.0,1.071661,-0.560359,1.069870,-0.380498,0.094949,-0.115152,-0.151959,0.014916,0.270372,...,-0.963438,-0.5,0.347576,0.696511,-2.302585,-0.5,-0.374712,-0.689066,1,2.0
3,200003.0,1.071661,-0.560359,1.069870,-0.380498,0.094949,-0.115152,-0.151959,0.112270,0.270372,...,-0.963438,-0.5,0.347576,0.698694,-2.302585,-0.5,-0.374712,-0.689066,1,3.0
4,200003.0,1.071661,-0.560359,1.069870,-0.380498,0.094949,-0.115152,-0.151959,0.004099,0.270372,...,-0.963438,-0.5,0.347576,0.700870,-2.302585,-0.5,-0.374712,-0.689066,1,4.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1991934,299999.0,-0.019540,-0.744383,-0.112773,-0.519833,0.095193,0.191409,-0.065567,-0.062706,-0.340459,...,-0.849940,-0.5,-0.329454,0.983842,-2.302585,-0.5,-0.076793,-0.483898,0,21.0
1991935,299999.0,-0.019540,-0.744383,-0.112773,-0.519833,0.095193,0.191409,-0.065567,-0.033424,-0.340459,...,-0.849940,-0.5,-1.056430,0.983842,-2.302585,-0.5,-0.076793,-0.483898,0,22.0
1991936,299999.0,-0.019540,-0.744383,-0.112773,-0.519833,0.095193,0.191409,-0.065567,-0.040745,-0.340459,...,-0.849940,-0.5,-1.056430,0.983842,-2.302585,-0.5,-0.076793,-0.483898,0,23.0
1991937,299999.0,-0.019540,-0.744383,-0.112773,-0.519833,0.095193,0.191409,-0.065567,-0.037084,-0.340459,...,-0.849940,-0.5,-1.056430,0.983842,-2.302585,-0.5,-0.076793,-0.483898,0,24.0


In [9]:
strats_data = (rows.melt(id_vars=['hour', 'm:icustayid', 'sepsis_label'], var_name='variable',value_name='value', ignore_index=False)
       .sort_values(['m:icustayid', 'hour'])
       .reset_index(drop=True))


In [10]:
rows.describe()

Unnamed: 0,m:icustayid,o:Arterial_BE,o:Arterial_lactate,o:Arterial_pH,o:BUN,o:Calcium,o:Chloride,o:Creatinine,o:DiaBP,o:FiO2_1,...,o:age,o:gender,o:input_4hourly,o:input_total,o:max_dose_vaso,o:mechvent,o:paCO2,o:paO2,sepsis_label,hour
count,1991939.0,1991939.0,1991939.0,1991939.0,1991939.0,1991939.0,1991939.0,1991939.0,1991939.0,1991939.0,...,1991939.0,1991939.0,1991939.0,1991939.0,1991939.0,1991939.0,1991939.0,1991939.0,1991939.0,1991939.0
mean,249996.4,-2.397085e-18,1.410463e-16,4.104135e-14,-1.618603e-16,-6.130402e-16,-6.76466e-15,1.9619e-17,1.221929e-16,1.678851e-16,...,-1.332836e-15,-0.06825335,-1.2684580000000002e-17,2.321748e-16,-2.170558,-0.2202595,-2.29284e-15,4.179489e-16,0.4464158,31.24133
std,28905.76,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,0.4953197,1.0,2.000001,0.4301755,0.4488718,1.0,1.0,0.4971206,20.14708
min,200003.0,-10.90518,-1.550087,-71.80108,-7.955582,-11.45663,-16.79163,-4.16345,-2.689363,-2.47969,...,-2.775468,-0.5,-1.417028,-7.225433,-2.302585,-0.5,-4.374527,-1.491731,0.0,0.0
25%,224907.0,-0.5379041,-0.5935605,-0.4847101,-0.6566401,-0.5843257,-0.5962046,-0.6097886,-0.06636667,-0.565261,...,-0.6533086,-0.5,-1.05643,-0.1454999,-2.302585,-0.5,-0.6009071,-0.729459,0.0,15.0
50%,250007.0,-0.01954019,-0.2304499,0.09825751,-0.07960458,-0.02944955,0.01226456,-0.2284475,-0.007801794,-0.2796406,...,0.09459865,-0.5,0.2880607,0.6013366,-2.302585,-0.5,-0.1816161,-0.4063115,0.0,28.0
75%,275036.0,0.4988237,0.3113721,0.5840638,0.6441497,0.5925442,0.5496967,0.3465666,0.05442339,0.2703717,...,0.7930249,0.5,0.8958044,1.076243,-2.302585,0.5,0.3429908,0.4829659,1.0,46.0
max,299999.0,25.89865,20.37071,6.553564,4.095757,15.45232,7.715458,8.335615,659.0247,3.481176,...,1.616612,0.5,2.433977,3.366127,5.242678,0.5,16.38038,4.153773,1.0,78.0


In [11]:
mean_stds = strats_data.groupby('variable').agg({'value':['mean', 'std']})

In [12]:
mean_stds.columns = [col[1] for col in mean_stds.columns]
mean_stds

Unnamed: 0_level_0,mean,std
variable,Unnamed: 1_level_1,Unnamed: 2_level_1
o:Arterial_BE,5.458763999999999e-19,1.0
o:Arterial_lactate,1.483356e-16,1.0
o:Arterial_pH,4.10403e-14,1.0
o:BUN,-1.568973e-16,1.0
o:Calcium,-6.190385e-16,1.0
o:Chloride,-6.769658e-15,1.0
o:Creatinine,1.7620400000000002e-17,1.0
o:DiaBP,1.228216e-16,1.0
o:FiO2_1,1.673816e-16,1.0
o:GCS,-2.877988e-16,1.0


In [13]:
strats_data = strats_data.merge(mean_stds.reset_index(), on='variable', how='left')
strats_data['ts_ind'] = strats_data.groupby(['m:icustayid']).ngroup()
strats_data

Unnamed: 0,hour,m:icustayid,sepsis_label,variable,value,mean,std,ts_ind
0,0.0,200003.0,1,o:Arterial_BE,1.071661,5.458764e-19,1.000000,0
1,0.0,200003.0,1,o:Arterial_lactate,-0.560359,1.483356e-16,1.000000,0
2,0.0,200003.0,1,o:Arterial_pH,1.069870,4.104030e-14,1.000000,0
3,0.0,200003.0,1,o:BUN,-0.380498,-1.568973e-16,1.000000,0
4,0.0,200003.0,1,o:Calcium,0.094949,-6.190385e-16,1.000000,0
...,...,...,...,...,...,...,...,...
83661433,38.0,299999.0,0,o:input_total,0.983842,2.268117e-16,2.000001,43387
83661434,38.0,299999.0,0,o:max_dose_vaso,-2.302585,-2.170558e+00,0.430175,43387
83661435,38.0,299999.0,0,o:mechvent,-0.500000,-2.202595e-01,0.448872,43387
83661436,38.0,299999.0,0,o:paCO2,-0.705730,-2.292463e-15,1.000000,43387


In [14]:
IDs = pd.read_csv("IDs.csv")
train_IDs = np.unique(IDs["train_ids"][~np.isnan(IDs["train_ids"])])
test_IDs = np.unique(IDs["test_ids"][~np.isnan(IDs["test_ids"])])
val_IDs = np.unique(IDs["val_ids"][~np.isnan(IDs["val_ids"])])
print(np.intersect1d(train_IDs,test_IDs))
print(np.intersect1d(train_IDs,val_IDs))
print(np.intersect1d(test_IDs,val_IDs))

[]
[]
[]


In [15]:
targets = strats_data[["m:icustayid", "ts_ind", "sepsis_label"]]
strats_data = strats_data.drop(columns="sepsis_label")

In [16]:
train_val_IDs = np.concatenate((train_IDs,val_IDs))
print((len(train_IDs)+len(val_IDs)), len(train_val_IDs))

25418 25418


In [17]:
len(val_IDs)

2542

train and val data for training

In [18]:
train_data = strats_data.loc[strats_data["m:icustayid"].isin(train_val_IDs)]
train_data = train_data.drop(columns="m:icustayid")
train_target = targets.loc[targets["m:icustayid"].isin(train_val_IDs)]
test_data = strats_data.loc[strats_data["m:icustayid"].isin(test_IDs)]
test_data = test_data.drop(columns="m:icustayid")
test_target = targets.loc[targets["m:icustayid"].isin(test_IDs)]
#val_data = strats_data.loc[strats_data["m:icustayid"].isin(val_IDs)]
#val_data = val_data.drop(columns="m:icustayid")
#val_target = targets.loc[targets["m:icustayid"].isin(val_IDs)]

In [19]:
print(len(train_data["ts_ind"]))

49159824


## Load and preprocess data for forecasting

In [20]:
train_ind = np.unique(strats_data["ts_ind"].loc[strats_data["m:icustayid"].isin(train_IDs)])
test_ind = np.unique(strats_data["ts_ind"].loc[strats_data["m:icustayid"].isin(test_IDs)])
valid_ind = np.unique(strats_data["ts_ind"].loc[strats_data["m:icustayid"].isin(val_IDs)])


train_ind

array([    1,     4,     8, ..., 43379, 43382, 43386])

In [21]:
data = train_data[["ts_ind", "hour","variable", "value", "mean", "std"]]
data

Unnamed: 0,ts_ind,hour,variable,value,mean,std
924,1,0.0,o:Arterial_BE,-0.019540,5.458764e-19,1.000000
925,1,0.0,o:Arterial_lactate,0.839249,1.483356e-16,1.000000
926,1,0.0,o:Arterial_pH,0.177068,4.104030e-14,1.000000
927,1,0.0,o:BUN,-0.519833,-1.568973e-16,1.000000
928,1,0.0,o:Calcium,-0.856133,-6.190385e-16,1.000000
...,...,...,...,...,...,...
83660383,43386,50.0,o:input_total,-0.110171,2.268117e-16,2.000001
83660384,43386,50.0,o:max_dose_vaso,-2.302585,-2.170558e+00,0.430175
83660385,43386,50.0,o:mechvent,-0.500000,-2.202595e-01,0.448872
83660386,43386,50.0,o:paCO2,0.163565,-2.292463e-15,1.000000


In [22]:
pred_window = 4 # hours
obs_windows = [4,8,12]

In [23]:
# Fix age.
data.loc[(data.variable=='Age')&(data.value>200), 'value'] = 91.4
# Get static data with mean fill and missingness indicator.
static_varis = ['o:age', 'o:gender']
ii = data.variable.isin(static_varis)
static_data = data.loc[ii]
data = data.loc[~ii]
def inv_list(l, start=0):
    d = {}
    for i in range(len(l)):
        d[l[i]] = i+start
    return d
static_var_to_ind = inv_list(static_varis)
D = len(static_varis)
N = data.ts_ind.max()+1
demo = np.zeros((N, D))
for row in tqdm(static_data.itertuples()):
    demo[row.ts_ind, static_var_to_ind[row.variable]] = row.value
# Normalize static data.
means = demo.mean(axis=0, keepdims=True)
stds = demo.std(axis=0, keepdims=True)
stds = (stds==0)*1 + (stds!=0)*stds
demo = (demo-means)/stds
# Get variable indices.
varis = sorted(list(set(data.variable)))
V = len(varis)
var_to_ind = inv_list(varis, start=1)
data['vind'] = data.variable.map(var_to_ind)
data = data[['ts_ind', 'vind', 'hour', 'value']].sort_values(by=['ts_ind', 'vind', 'hour'])
# Find max_len.
fore_max_len = 880
# Get forecast inputs and outputs.
fore_times_ip = []
fore_values_ip = []
fore_varis_ip = []
fore_op = []
fore_inds = []
def f(x):
    mask = [0 for i in range(V)]
    values = [0 for i in range(V)]
    for vv in x:
        v = int(vv[0])-1
        mask[v] = 1
        values[v] = vv[1]
    return values+mask
def pad(x):
    return x+[0]*(fore_max_len-len(x))
for w in tqdm(obs_windows):
    pred_data = data.loc[(data.hour>=w)&(data.hour<=w+pred_window)]
    pred_data = pred_data.groupby(['ts_ind', 'vind']).agg({'value':'first'}).reset_index()
    pred_data['vind_value'] = pred_data[['vind', 'value']].values.tolist()
    pred_data = pred_data.groupby('ts_ind').agg({'vind_value':list}).reset_index()
    pred_data['vind_value'] = pred_data['vind_value'].apply(f)    
    obs_data = data.loc[(data.hour<w)&(data.hour>=w-24)]
    obs_data = obs_data.loc[obs_data.ts_ind.isin(pred_data.ts_ind)]
    obs_data = obs_data.groupby('ts_ind').head(fore_max_len)
    obs_data = obs_data.groupby('ts_ind').agg({'vind':list, 'hour':list, 'value':list}).reset_index()
    obs_data = obs_data.merge(pred_data, on='ts_ind')
    for col in ['vind', 'hour', 'value']:
        obs_data[col] = obs_data[col].apply(pad)
    fore_op.append(np.array(list(obs_data.vind_value)))
    fore_inds.append(np.array(list(obs_data.ts_ind)))
    fore_times_ip.append(np.array(list(obs_data.hour)))
    fore_values_ip.append(np.array(list(obs_data.value)))
    fore_varis_ip.append(np.array(list(obs_data.vind)))
del data
fore_times_ip = np.concatenate(fore_times_ip, axis=0)
fore_values_ip = np.concatenate(fore_values_ip, axis=0)
fore_varis_ip = np.concatenate(fore_varis_ip, axis=0)
fore_op = np.concatenate(fore_op, axis=0)
fore_inds = np.concatenate(fore_inds, axis=0)
fore_demo = demo[fore_inds]

2340944it [00:01, 1731856.54it/s]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data['vind'] = data.variable.map(var_to_ind)
100%|██████████| 3/3 [00:19<00:00,  6.40s/it]


In [24]:
len(fore_times_ip)

69451

In [25]:
# data
targets["SUBJECT_ID"] = targets["m:icustayid"]


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  targets["SUBJECT_ID"] = targets["m:icustayid"]


In [26]:
# Get train and valid ts_ind for forecast task.
train_sub = targets.loc[targets.ts_ind.isin(train_ind)].SUBJECT_ID.unique()
valid_sub = targets.loc[targets.ts_ind.isin(valid_ind)].SUBJECT_ID.unique()
rem_sub = targets.loc[~targets.SUBJECT_ID.isin(np.concatenate((train_ind, valid_ind)))].SUBJECT_ID.unique()
bp = int(0.8*len(rem_sub))
train_sub = np.concatenate((train_sub, rem_sub[:bp]))
valid_sub = np.concatenate((valid_sub, rem_sub[bp:]))
train_ind = targets.loc[targets.SUBJECT_ID.isin(train_sub)].ts_ind.unique() # Add remaining ts_ind s of train subjects.
valid_ind = targets.loc[targets.SUBJECT_ID.isin(valid_sub)].ts_ind.unique() # Add remaining ts_ind s of train subjects.
# Generate 3 sets of inputs and outputs.^
train_ind = np.argwhere(np.in1d(fore_inds, train_ind)).flatten()
valid_ind = np.argwhere(np.in1d(fore_inds, valid_ind)).flatten()
fore_train_ip = [ip[train_ind] for ip in [fore_demo, fore_times_ip, fore_values_ip, fore_varis_ip]]
fore_valid_ip = [ip[valid_ind] for ip in [fore_demo, fore_times_ip, fore_values_ip, fore_varis_ip]]
del fore_times_ip, fore_values_ip, fore_varis_ip, demo, fore_demo
fore_train_op = fore_op[train_ind]
fore_valid_op = fore_op[valid_ind]
del fore_op

In [27]:
fore_train_op[0]

array([-1.95401924e-02,  8.39249405e-01,  3.21988390e-01, -5.19833204e-01,
        7.74712400e-01, -2.31660792e+00,  8.31308583e-02,  1.02007349e-01,
        1.09145199e-01,  5.09087999e-01,  4.73788556e-01,  1.69643842e+00,
       -1.23758231e-01,  6.58007841e-01, -2.61916044e-01,  8.19715611e-01,
        9.60098312e-02, -2.61171730e-01, -6.31365300e-02,  1.25626717e+00,
       -1.58956946e-02,  1.01326390e-01, -1.49042417e-01, -5.42738701e-01,
       -3.28233379e-01, -2.38325849e-01, -8.96031829e-02, -8.48922257e-01,
        1.09403131e+00, -1.68473530e-01,  3.46548285e-01, -5.96517147e-01,
        5.69720331e-01, -2.23258552e-03, -1.05643043e+00,  3.67967868e-01,
       -2.30258509e+00,  5.00000000e-01, -6.00907129e-01, -2.95441734e-01,
        1.00000000e+00,  1.00000000e+00,  1.00000000e+00,  1.00000000e+00,
        1.00000000e+00,  1.00000000e+00,  1.00000000e+00,  1.00000000e+00,
        1.00000000e+00,  1.00000000e+00,  1.00000000e+00,  1.00000000e+00,
        1.00000000e+00,  

## Matrics, Losses, Model Architecture

In [28]:
from sklearn.metrics import mean_squared_error
mse = MeanSquaredError()
def get_res(y_true, y_pred):
    precision, recall, thresholds = precision_recall_curve(y_true, y_pred)
    pr_auc = auc(recall, precision)
    minrp = np.minimum(precision, recall).max()
    roc_auc = roc_auc_score(y_true, y_pred)
    return [roc_auc, pr_auc, minrp]

# class_weights = compute_class_weight(class_weight='balanced', classes=[0,1], y=train_op)
# def mortality_loss(y_true, y_pred):
#     sample_weights = (1-y_true)*class_weights[0] + y_true*class_weights[1]
#     bce = K.binary_crossentropy(y_true, y_pred)
#     return K.mean(sample_weights*bce, axis=-1)

# var_weights = np.sum(fore_train_op[:, V:], axis=0)
# var_weights[var_weights==0] = var_weights.max()
# var_weights = var_weights.max()/var_weights
# var_weights = var_weights.reshape((1, V))
def forecast_loss(y_true, y_pred):
    return K.sum(y_true[:,V:]*(y_true[:,:V]-y_pred)**2, axis=-1)

def forecast_mse(y_true, y_pred):
    if type(y_pred)==type([]):
            y_pred = y_pred[0]
    y_ = (y_true[:,V:])*(y_true[:,:V]) #y_true[:,V:] = actual values, y_true[:,:V] = masked values (always 1).
    return mse(y_, y_pred)

def get_min_loss(weight):
    def min_loss(y_true, y_pred):
        return weight*y_pred
    return min_loss

class CustomCallback(Callback):
    def __init__(self, validation_data, batch_size):
        self.val_x, self.val_y = validation_data
        self.batch_size = batch_size
        super(Callback, self).__init__()

    """def on_epoch_end(self, epoch, logs={}):
        y_pred = self.model.predict(self.val_x, verbose=0, batch_size=self.batch_size)
        if type(y_pred)==type([]):
            y_pred = y_pred[0]
        precision, recall, thresholds = precision_recall_curve(self.val_y, y_pred)
        pr_auc = auc(recall, precision)
        roc_auc = roc_auc_score(self.val_y, y_pred)
        logs['custom_metric'] = pr_auc + roc_auc
        print ('val_aucs:', pr_auc, roc_auc)"""

    def on_epoch_end(self, epoch, logs={}):
        y_pred = self.model.predict(self.val_x, verbose=1, batch_size=self.batch_size)
        if type(y_pred)==type([]):
            y_pred = y_pred[0]
        mse_raw = mean_squared_error(self.val_y, y_pred, multioutput='raw_values')
        mse_loss = mse(self.val_y, y_pred)
        logs['mse'] = mse_loss
        logs['mse_raw'] = mse_raw
        print('val_mse:', mse_loss)

In [29]:
import tensorflow as tf
import numpy as np
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Embedding, Activation, Dropout, Softmax, Layer, InputSpec, Input, Dense, Lambda, TimeDistributed, Concatenate, Add
from tensorflow.keras import initializers, regularizers, constraints, Model
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow import nn
import smart_cond_mod as sc



class CVE(Layer):
    def __init__(self, hid_units, output_dim):
        self.hid_units = hid_units
        self.output_dim = output_dim
        super(CVE, self).__init__()

    def build(self, input_shape):
        self.W1 = self.add_weight(name='CVE_W1',
                            shape=(1, self.hid_units),
                            initializer='glorot_uniform',
                            trainable=True)
        self.b1 = self.add_weight(name='CVE_b1',
                            shape=(self.hid_units,),
                            initializer='zeros',
                            trainable=True)
        self.W2 = self.add_weight(name='CVE_W2',
                            shape=(self.hid_units, self.output_dim),
                            initializer='glorot_uniform',
                            trainable=True)
        super(CVE, self).build(input_shape)

    def call(self, x):
        x = K.expand_dims(x, axis=-1)
        x = K.dot(K.tanh(K.bias_add(K.dot(x, self.W1), self.b1)), self.W2)
        return x

    def compute_output_shape(self, input_shape):
        return input_shape + (self.output_dim,)


class Attention(Layer):

    def __init__(self, hid_dim):
        self.hid_dim = hid_dim
        super(Attention, self).__init__()

    def build(self, input_shape):
        d = input_shape.as_list()[-1]
        self.W = self.add_weight(shape=(d, self.hid_dim), name='Att_W',
                                 initializer='glorot_uniform',
                                 trainable=True)
        self.b = self.add_weight(shape=(self.hid_dim,), name='Att_b',
                                 initializer='zeros',
                                 trainable=True)
        self.u = self.add_weight(shape=(self.hid_dim,1), name='Att_u',
                                 initializer='glorot_uniform',
                                 trainable=True)
        super(Attention, self).build(input_shape)

    def call(self, x, mask, mask_value=-1e30):
        attn_weights = K.dot(K.tanh(K.bias_add(K.dot(x,self.W), self.b)), self.u)
        mask = K.expand_dims(mask, axis=-1)
        attn_weights = mask*attn_weights + (1-mask)*mask_value
        attn_weights = K.softmax(attn_weights, axis=-2)
        return attn_weights

    def compute_output_shape(self, input_shape):
        return input_shape[:-1] + (1,)


class Transformer(Layer):

    def __init__(self, N=2, h=8, dk=None, dv=None, dff=None, dropout=0):
        self.N, self.h, self.dk, self.dv, self.dff, self.dropout = N, h, dk, dv, dff, dropout
        self.epsilon = K.epsilon() * K.epsilon()
        super(Transformer, self).__init__()

    def build(self, input_shape):
        d = input_shape.as_list()[-1]
        if self.dk==None:
            self.dk = d//self.h
        if self.dv==None:
            self.dv = d//self.h
        if self.dff==None:
            self.dff = 2*d
        self.Wq = self.add_weight(shape=(self.N, self.h, d, self.dk), name='Wq',
                                 initializer='glorot_uniform', trainable=True)
        self.Wk = self.add_weight(shape=(self.N, self.h, d, self.dk), name='Wk',
                                 initializer='glorot_uniform', trainable=True)
        self.Wv = self.add_weight(shape=(self.N, self.h, d, self.dv), name='Wv',
                                 initializer='glorot_uniform', trainable=True)
        self.Wo = self.add_weight(shape=(self.N, self.dv*self.h, d), name='Wo',
                                 initializer='glorot_uniform', trainable=True)
        self.W1 = self.add_weight(shape=(self.N, d, self.dff), name='W1',
                                 initializer='glorot_uniform', trainable=True)
        self.b1 = self.add_weight(shape=(self.N, self.dff), name='b1',
                                 initializer='zeros', trainable=True)
        self.W2 = self.add_weight(shape=(self.N, self.dff, d), name='W2',
                                 initializer='glorot_uniform', trainable=True)
        self.b2 = self.add_weight(shape=(self.N, d), name='b2',
                                 initializer='zeros', trainable=True)
        self.gamma = self.add_weight(shape=(2*self.N,), name='gamma',
                                 initializer='ones', trainable=True)
        self.beta = self.add_weight(shape=(2*self.N,), name='beta',
                                 initializer='zeros', trainable=True)
        super(Transformer, self).build(input_shape)

    def call(self, x, mask, mask_value=-1e-30):
        mask = K.expand_dims(mask, axis=-2)
        for i in range(self.N):
            # MHA
            mha_ops = []
            for j in range(self.h):
                q = K.dot(x, self.Wq[i,j,:,:])
                k = K.permute_dimensions(K.dot(x, self.Wk[i,j,:,:]), (0,2,1))
                v = K.dot(x, self.Wv[i,j,:,:])
                A = K.batch_dot(q,k)
                # Mask unobserved steps.
                A = mask*A + (1-mask)*mask_value
                # Mask for attention dropout.
                def dropped_A():
                    dp_mask = K.cast((K.random_uniform(shape=array_ops.shape(A))>=self.dropout), K.floatx())
                    return A*dp_mask + (1-dp_mask)*mask_value
                A = sc.smart_cond(K.learning_phase(), dropped_A, lambda: array_ops.identity(A))
                A = K.softmax(A, axis=-1)
                mha_ops.append(K.batch_dot(A,v))
            conc = K.concatenate(mha_ops, axis=-1)
            proj = K.dot(conc, self.Wo[i,:,:])
            # Dropout.
            proj = sc.smart_cond(K.learning_phase(), lambda: array_ops.identity(nn.dropout(proj, rate=self.dropout)),\
                                       lambda: array_ops.identity(proj))
            # Add & LN
            x = x+proj
            mean = K.mean(x, axis=-1, keepdims=True)
            variance = K.mean(K.square(x - mean), axis=-1, keepdims=True)
            std = K.sqrt(variance + self.epsilon)
            x = (x - mean) / std
            x = x*self.gamma[2*i] + self.beta[2*i]
            # FFN
            ffn_op = K.bias_add(K.dot(K.relu(K.bias_add(K.dot(x, self.W1[i,:,:]), self.b1[i,:])),
                           self.W2[i,:,:]), self.b2[i,:,])
            # Dropout.
            ffn_op = sc.smart_cond(K.learning_phase(), lambda: array_ops.identity(nn.dropout(ffn_op, rate=self.dropout)),\
                                       lambda: array_ops.identity(ffn_op))
            # Add & LN
            x = x+ffn_op
            mean = K.mean(x, axis=-1, keepdims=True)
            variance = K.mean(K.square(x - mean), axis=-1, keepdims=True)
            std = K.sqrt(variance + self.epsilon)
            x = (x - mean) / std
            x = x*self.gamma[2*i+1] + self.beta[2*i+1]
        return x

    def compute_output_shape(self, input_shape):
        return input_shape


def build_strats(D, max_len, V, d, N, he, dropout, forecast=False):
    demo = Input(shape=(D,))
    demo_enc = Dense(2*d, activation='tanh')(demo)
    demo_enc = Dense(d, activation='tanh')(demo_enc)
    varis = Input(shape=(max_len,))
    values = Input(shape=(max_len,))
    times = Input(shape=(max_len,))
    varis_emb = Embedding(V+1, d)(varis)
    cve_units = int(np.sqrt(d))
    values_emb = CVE(cve_units, d)(values)
    times_emb = CVE(cve_units, d)(times)
    comb_emb = Add()([varis_emb, values_emb, times_emb]) # b, L, d
#     demo_enc = Lambda(lambda x:K.expand_dims(x, axis=-2))(demo_enc) # b, 1, d
#     comb_emb = Concatenate(axis=-2)([demo_enc, comb_emb]) # b, L+1, d
    mask = Lambda(lambda x:K.clip(x,0,1))(varis) # b, L
#     mask = Lambda(lambda x:K.concatenate((K.ones_like(x)[:,0:1], x), axis=-1))(mask) # b, L+1
    cont_emb = Transformer(N, he, dk=None, dv=None, dff=None, dropout=dropout)(comb_emb, mask=mask)
    attn_weights = Attention(2*d)(cont_emb, mask=mask)
    fused_emb = Lambda(lambda x:K.sum(x[0]*x[1], axis=-2))([cont_emb, attn_weights])
    conc = Concatenate(axis=-1)([fused_emb, demo_enc])
    fore_op = Dense(V)(conc)
    op = Dense(1, activation='sigmoid')(fore_op)
    model = Model([demo, times, values, varis], op)
    if forecast:
        fore_model = Model([demo, times, values, varis], fore_op)
        return [model, fore_model]
    return model

# To tune:
# 1. Transformer parameters. (N, h, dropout)
# 2. Normalization

In [30]:
# lr, batch_size, samples_per_epoch, patience = 0.0005, 32, 102400, 5
lr, batch_size, samples_per_epoch, patience = 0.0005, 128, 1024, 5
d, N, he, dropout = 50, 2, 4, 0.5
model, fore_model =  build_strats(D, fore_max_len, V, d, N, he, dropout, forecast=True)
print (fore_model.summary())
# fore_model.compile(loss=forecast_loss, optimizer=Adam(lr))

fore_model.compile(loss=forecast_loss, optimizer=Adam(lr))

# Pretrain fore_model.
best_val_loss = np.inf
N_fore = len(fore_train_op)
fore_savepath = 'test.h5'

for e in range(10):
    e_indices = np.random.choice(range(N_fore), size=samples_per_epoch, replace=False)
    e_loss = 0
    pbar = tqdm(range(0, len(e_indices), batch_size))
    for start in pbar:
        ind = e_indices[start:start+batch_size]
        # pre-train data
        e_loss += fore_model.train_on_batch([ip[ind] for ip in fore_train_ip], fore_train_op[ind])
        pbar.set_description('%f'%(e_loss/(start+1)))
    val_loss = fore_model.evaluate(fore_valid_ip, fore_valid_op, batch_size=batch_size, verbose=1)
    print ('Epoch', e, 'loss', e_loss*batch_size/samples_per_epoch, 'val loss', val_loss)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        fore_model.save_weights(fore_savepath)
        best_epoch = e
    if (e-best_epoch)>patience:
        break

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_2 (InputLayer)        [(None, 880)]                0         []                            
                                                                                                  
 input_3 (InputLayer)        [(None, 880)]                0         []                            
                                                                                                  
 input_4 (InputLayer)        [(None, 880)]                0         []                            
                                                                                                  
 embedding (Embedding)       (None, 880, 50)              2050      ['input_2[0][0]']             
                                                                                            



None


0.490349: 100%|██████████| 8/8 [00:31<00:00,  3.89s/it] 


Epoch 0 loss 54.980403423309326 val loss 93.58240509033203


0.431659: 100%|██████████| 8/8 [00:31<00:00,  3.95s/it] 


Epoch 1 loss 48.399733543395996 val loss 90.29300689697266


0.581912: 100%|██████████| 8/8 [00:30<00:00,  3.87s/it] 


Epoch 2 loss 65.24684810638428 val loss 88.76493835449219


0.555476: 100%|██████████| 8/8 [00:32<00:00,  4.07s/it] 


Epoch 3 loss 62.28277540206909 val loss 88.06761932373047


0.408634: 100%|██████████| 8/8 [00:33<00:00,  4.18s/it] 


Epoch 4 loss 45.81804275512695 val loss 87.7301254272461


0.431091: 100%|██████████| 8/8 [00:31<00:00,  3.96s/it] 


Epoch 5 loss 48.3360276222229 val loss 87.48186492919922


0.568964: 100%|██████████| 8/8 [00:31<00:00,  3.90s/it] 


Epoch 6 loss 63.79509401321411 val loss 87.41021728515625


0.401457: 100%|██████████| 8/8 [00:30<00:00,  3.84s/it] 


Epoch 7 loss 45.01338243484497 val loss 87.22779083251953


0.398299: 100%|██████████| 8/8 [00:30<00:00,  3.82s/it] 


Epoch 8 loss 44.65925741195679 val loss 87.08069610595703


0.400158: 100%|██████████| 8/8 [00:33<00:00,  4.20s/it] 


  4/153 [..............................] - ETA: 3:22 - loss: 52.8826

KeyboardInterrupt: 

## Pretraining for forecasting

In [None]:
lr, batch_size, samples_per_epoch, patience = 0.00003, 32, len(fore_train_op), 5
d, N, he, dropout = 75, 4, 4, 0.2
model, fore_model =  build_strats(D, fore_max_len, V, d, N, he, dropout, forecast=True)
print (fore_model.summary())
loss_function = forecast_mse
fore_model.compile(loss=loss_function, optimizer=Adam(learning_rate=lr))

# Pretrain fore_model.
best_val_loss = np.inf
N_fore = len(fore_train_op)
fore_savepath = 'mimic_iii_strats_all_feats.h5'
train_losses = []
val_losses = []
for e in range(10):
    e_indices = np.random.choice(range(N_fore), size=samples_per_epoch, replace=False)
    e_loss = 0
    pbar = tqdm(range(0, len(e_indices), batch_size))
    for start in pbar:
        ind = e_indices[start:start+batch_size]
        # pre-train data
        e_loss += fore_model.train_on_batch([ip[ind] for ip in fore_train_ip], fore_train_op[ind])
        pbar.set_description('%f'%(e_loss/(start+1)))
       # else:
        #    e_loss += fore_model.train_on_batch([ip[ind] for ip in fore_train_ip], fore_train_op[ind])
         #   pbar.set_description('%f'%(e_loss/(start+1*batch_size)))
    val_loss = fore_model.evaluate(fore_valid_ip, fore_valid_op, batch_size=batch_size, verbose=1)
    if loss_function == forecast_loss:
        print ('Epoch', e, 'loss', e_loss*batch_size/samples_per_epoch, 'val loss', val_loss)
        train_losses.append(e_loss*batch_size/samples_per_epoch)
    #else:
     #   print ('Epoch', e, 'mse_loss', e_loss/(samples_per_epoch/batch_size), 'val mse_loss', val_loss)
      #  train_losses.append(e_loss/(samples_per_epoch/batch_size))
    val_losses.append(val_loss)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        fore_model.save_weights(fore_savepath)
        best_epoch = e
    if (e-best_epoch)>patience:
        losses = pd.DataFrame({'train_loss': train_losses,'val_loss': val_losses})
        losses.to_csv(f"losses_{fore_savepath}.csv")
        break

Model: "model_5"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_10 (InputLayer)       [(None, 880)]                0         []                            
                                                                                                  
 input_11 (InputLayer)       [(None, 880)]                0         []                            
                                                                                                  
 input_12 (InputLayer)       [(None, 880)]                0         []                            
                                                                                                  
 embedding_2 (Embedding)     (None, 880, 75)              3000      ['input_10[0][0]']            
                                                                                            



None


nan:   1%|          | 5/673 [00:12<27:58,  2.51s/it]


KeyboardInterrupt: 