In [1]:
import tensorflow as tf
from tensorflow import keras
from keras.layers import Dense, Concatenate, Input, Lambda, Reshape
from keras.layers import Embedding, LSTM, TimeDistributed
from keras.utils.vis_utils import plot_model
from keras import backend as K

def rescale(x):
	"""Scale an embedding by its paired input value, leaving it unscaled at 0.

	Args:
		x: Two-element sequence ``[input_value, embedding]`` of tensors;
			``input_value`` must broadcast against ``embedding``.

	Returns:
		A tensor equal to ``embedding`` wherever ``input_value`` is exactly 0
		and to ``embedding * input_value`` everywhere else.
	"""
	input_value, embedding = x[0], x[1]
	# Per the original note, a value of 0 stands for missing data: in that
	# case the raw embedding is passed through instead of being zeroed out.
	# tf.equal is always element-wise; a bare `==` on a tensor degrades to a
	# Python identity check when tensor equality is disabled (TF1 behaviour),
	# which would silently turn the mask into a constant 0.
	mask = K.cast(tf.equal(input_value, 0), dtype=K.floatx())
	return embedding * mask + tf.multiply(embedding, input_value) * (1 - mask)

### Embedding + LSTM seq2seq learning (assuming hypothetical data)

In [18]:
# Hypothetical data layout.
# Maps each categorical feature name to the vocabulary size of its embedding.
categorical_vars = {"admitType": 4, "sex": 3, "race": 10}
time_dim = None        # time axis length is left unspecified
continuous_dim = 4     # feature width of the "cont_var" input
cont_indct_dim = 2     # feature width of the "cont_indct_var" input

embedding_size = 4

ins = []          # every model Input, in the order the model expects them
sub_models = []   # per-feature tensors that get concatenated before the LSTMs

# Build one Input plus a per-timestep embedding for each categorical feature.
for cat_var, k in categorical_vars.items():
	_input = Input(shape=(time_dim, 1), name=cat_var)
	ins.append(_input)

	# Embed the single code at each timestep ...
	_embed_layer = TimeDistributed(
		Embedding(k, embedding_size, input_length=1),
		name=cat_var + "_embedded",
	)
	_cat_embed = _embed_layer(_input)

	# ... then reshape each timestep's embedding to a flat (embedding_size,) vector.
	_reshape_layer = TimeDistributed(
		Reshape(target_shape=(embedding_size,)),
		name=cat_var + "_embedding_reshape",
	)
	_cat_embed = _reshape_layer(_cat_embed)
	sub_models.append(_cat_embed)

# Continuous measurements and their indicator codes, one row per timestep.
cont_input = Input(shape=(time_dim, continuous_dim,), name="cont_var")
cont_input_indct = Input(shape=(time_dim, cont_indct_dim,), name="cont_indct_var")
ins.extend([cont_input, cont_input_indct])

scaled_embeds = []
for i in range(cont_indct_dim):
	# The column index is bound as a lambda default argument (`col=...`) so
	# each Lambda layer keeps its own index. A plain `lambda x: x[:, i]`
	# looks up the loop variable `i` lazily, so whenever Keras re-invokes the
	# function after the loop (e.g. when the model is called or reloaded)
	# every layer would slice the column of the final `i`.

	# Continuous column paired with indicator i, as a trailing-1 tensor.
	# NOTE(review): the fixed offset of 2 presumably skips continuous columns
	# that have no indicator -- confirm against the data layout.
	_input = TimeDistributed(Lambda(lambda x, col=i + 2: x[:, col]), name="slice_cont_" + str(i))(cont_input)
	_input = Lambda(lambda x: tf.expand_dims(x, -1), name="reshape_" + str(i))(_input)

	# Indicator column i, as a trailing-1 tensor.
	_input_indct = TimeDistributed(Lambda(lambda x, col=i: x[:, col]), name="slice_cont_indct_" + str(i))(cont_input_indct)
	_input_indct = Lambda(lambda x: tf.expand_dims(x, -1), name="reshape_indct_" + str(i))(_input_indct)

	# Embed the two-valued indicator and flatten to (embedding_size,) per timestep.
	_cont_indct_embed = TimeDistributed(Embedding(2, embedding_size, input_length=1), name="continuous_indct_embedded_" + str(i))(_input_indct)
	_cont_indct_embed = TimeDistributed(Reshape(target_shape=(embedding_size,)), name="continuous_indct_embedding_reshape_" + str(i))(_cont_indct_embed)

	# Scale the indicator embedding by the continuous value (see `rescale`).
	_scaled_embed = TimeDistributed(Lambda(rescale), name="cont_rescale_" + str(i))([_input, _cont_indct_embed])
	scaled_embeds.append(_scaled_embed)

# Join the scaled indicator embeddings timestep by timestep and register the
# result as one more feature group.
scaled_concat = TimeDistributed(Concatenate(name="concat_scaled_embedding"))
sub_cont = scaled_concat(scaled_embeds)
sub_models.append(sub_cont)

# Join every feature group (categorical embeddings + scaled continuous block).
feature_concat = TimeDistributed(Concatenate(name="concat_embedding"))
x = feature_concat(sub_models)

# Two stacked LSTMs; both return the full sequence so the output head can
# produce one prediction per timestep.
lstm_options = {"return_sequences": True}
x = LSTM(32, **lstm_options)(x)
x = LSTM(32, **lstm_options)(x)

# Per-timestep sigmoid head.
output_head = TimeDistributed(Dense(1, activation="sigmoid"), name="output")
outs = output_head(x)

ehr_model = keras.Model(inputs=ins, outputs=outs)

In [19]:
# Write the assembled model to an HDF5 file in the working directory.
ehr_model.save(filepath="ehr_model.h5")

