In [None]:
import os
import pickle
import tempfile

import numpy as np
import tensorflow as tf

# NOTE(review): star import presumably supplies read_data, preprocess_traindata,
# MAX_SEQ_LEN and EMBEDDING_DIM used below — prefer importing those names explicitly.
from get_data import *

In [None]:
# Load the raw data; the first three entries become the training split, the rest go to test_data.
# Presumably each split is (premise, hypothesis, labels) — confirm against get_data.read_data.
data = read_data()
train_data, test_data = data[:3], data[3:]

In [None]:
# Run the project preprocessing on the training split; it returns the processed
# (premise, hypothesis, label) arrays and the embedding matrix for the Embedding layer.
# Preprocess from `data[:3]` (the same slice the loading cell assigns to `train_data`)
# rather than from `train_data` itself, so re-running this cell never feeds
# already-preprocessed arrays back into preprocess_traindata.
train_data, embed_matrix = preprocess_traindata(data[:3])

# model.fit needs both inputs and the labels to share the sample axis, and the frozen
# Embedding layer needs the matrix width to equal EMBEDDING_DIM.
assert train_data[0].shape[0] == train_data[1].shape[0] == train_data[2].shape[0], \
    "premise, hypothesis and label arrays have different sample counts"
assert embed_matrix.shape[1] == EMBEDDING_DIM, "embedding matrix width != EMBEDDING_DIM"

In [None]:
# Premise array (first model input in fit); must be (n_samples, MAX_SEQ_LEN) to match the premise Input
train_data[0].shape

In [None]:
# Hypothesis array (second model input in fit); must be (n_samples, MAX_SEQ_LEN) to match the hypothesis Input
train_data[1].shape

In [None]:
# Label array (fit target); categorical_crossentropy with a 3-way softmax expects one-hot (n_samples, 3)
train_data[2].shape

In [None]:
# Frozen embedding layer initialised from embed_matrix (vocabulary size = number of matrix rows).
# A single instance is shared by the premise and hypothesis branches.
embedding = tf.keras.layers.Embedding(
    input_dim=embed_matrix.shape[0],
    output_dim=EMBEDDING_DIM,
    weights=[embed_matrix],
    input_length=MAX_SEQ_LEN,
    trainable=False,
)

In [None]:
# Two integer token-id inputs of length MAX_SEQ_LEN: first the premise, then the hypothesis
premise, hypothesis = (
    tf.keras.layers.Input(shape=(MAX_SEQ_LEN,), dtype='int32') for _ in range(2)
)

In [None]:
# Both symbolic inputs are (None, MAX_SEQ_LEN); the batch dimension is left unspecified
premise.shape, hypothesis.shape

In [None]:
# Look up both sentences with the same shared, frozen embedding layer (premise first)
premise_embedded, hypothesis_embedded = map(embedding, (premise, hypothesis))

In [None]:
# Each is (None, MAX_SEQ_LEN, EMBEDDING_DIM): one embedding vector per token position
premise_embedded.shape, hypothesis_embedded.shape

In [None]:
# Per-token projection: TimeDistributed applies the same 100-unit ReLU Dense weights at every
# time step. This single instance is applied to both sentences, so its weights are shared.
translation = tf.keras.layers.TimeDistributed(
    tf.keras.layers.Dense(units=100, activation='relu')
)

In [None]:
# Push both embedded sequences through the shared translation layer (premise first)
premise_translated, hypothesis_translated = (
    translation(seq) for seq in (premise_embedded, hypothesis_embedded)
)

In [None]:
# Each is (None, MAX_SEQ_LEN, 100): the Dense(100) projection applied at every time step
premise_translated.shape, hypothesis_translated.shape

In [None]:
# Shared sentence encoder: a 100-unit LSTM run forwards and backwards. With Keras defaults
# only the final states are returned and the two directions are concatenated (200 dims).
BiLSTM = tf.keras.layers.Bidirectional(
    layer=tf.keras.layers.LSTM(units=100),
)

In [None]:
# Encode both translated sequences with the same BiLSTM weights (premise first)
premise_BiLSTM, hypothesis_BiLSTM = map(BiLSTM, (premise_translated, hypothesis_translated))

In [None]:
# Each is (None, 200): final states of both LSTM directions (2 x 100 units), concatenated
premise_BiLSTM.shape, hypothesis_BiLSTM.shape

In [None]:
# Batch-normalise each sentence encoding. Unlike the shared layers above, every branch
# gets its own BatchNormalization instance (premise's is created and applied first).
premise_normalized, hypothesis_normalized = (
    tf.keras.layers.BatchNormalization()(encoded)
    for encoded in (premise_BiLSTM, hypothesis_BiLSTM)
)

In [None]:
# Join the two normalised encodings along the last axis, then apply 20% dropout
train_input = tf.keras.layers.Concatenate(axis=-1)(
    [premise_normalized, hypothesis_normalized]
)
train_input = tf.keras.layers.Dropout(rate=0.2)(train_input)

In [None]:
# (None, 400): the two 200-dim encodings side by side; dropout does not change the shape
train_input.shape

In [None]:
# Fully-connected head on top of the concatenated sentence encodings: HEAD_DEPTH identical
# Dense(tanh) -> Dropout -> BatchNormalization stacks. The knobs are named here so depth,
# width and regularisation can be tuned in one place.
HEAD_DEPTH = 3       # number of Dense -> Dropout -> BatchNormalization stacks
HEAD_UNITS = 200     # width of each hidden Dense layer
HEAD_DROPOUT = 0.2   # dropout rate after each hidden Dense layer
L2_PENALTY = 0       # L2 weight on the hidden Dense kernels; 0 makes the regularizer a no-op

# One regularizer object shared by every hidden Dense layer
lam = tf.keras.regularizers.l2(l2=L2_PENALTY)

# NOTE(review): Dropout is applied before BatchNormalization here; BN-then-dropout is the more
# common ordering — worth evaluating, but kept as-is to preserve the current architecture.
for _head_layer in range(HEAD_DEPTH):
    train_input = tf.keras.layers.Dense(HEAD_UNITS, activation='tanh', kernel_regularizer=lam)(train_input)
    train_input = tf.keras.layers.Dropout(HEAD_DROPOUT)(train_input)
    train_input = tf.keras.layers.BatchNormalization()(train_input)

In [None]:
# Output layer: softmax over 3 classes (presumably entailment / neutral / contradiction —
# confirm the label order produced by get_data)
prediction = tf.keras.layers.Dense(units=3, activation='softmax')(train_input)

In [None]:
# Assemble the two-input functional model: (premise ids, hypothesis ids) -> class probabilities
model = tf.keras.models.Model(inputs=[premise, hypothesis], outputs=prediction)

# RMSprop with an initial learning rate of 0.01 (the ReduceLROnPlateau callback may lower it).
# `learning_rate` is the supported keyword: the legacy `lr` alias is rejected by the new
# Keras optimizers (TF >= 2.11) and by Keras 3.
optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.01)

# categorical_crossentropy expects one-hot labels; accuracy is reported each epoch
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

In [None]:
# Halve the learning rate when validation accuracy has not improved for 4 epochs (floor 1e-5)
learning_rate_reduction = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_accuracy', patience=4, verbose=1, factor=0.5, min_lr=0.00001)

# Stop training once validation loss has not improved for 4 epochs.
# NOTE(review): restore_best_weights is not set, so the in-memory model keeps the LAST epoch's
# weights; the best weights are only in the checkpoint file written below.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=4)

# ModelCheckpoint keeps the weights of the best epoch (by val_loss, its default monitor) in a
# temporary file. mkstemp returns an open OS-level descriptor that Keras never uses, so close it
# right away instead of leaking it. The `.weights.h5` suffix selects the HDF5 weights format
# (required for save_weights_only in Keras 3) rather than TF-checkpoint shards plus a shared
# `checkpoint` index file written into the temp directory.
tmp_fd, tmpfn = tempfile.mkstemp(suffix='.weights.h5')
os.close(tmp_fd)
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(tmpfn, save_best_only=True, save_weights_only=True)

callbacks = [early_stopping, model_checkpoint, learning_rate_reduction]

In [None]:
# Train on (premise, hypothesis) -> one-hot label. Keras takes the last 2% of the arrays
# (before shuffling) as the validation set that the callbacks monitor.
# NOTE(review): with epochs=5 and patience=4, EarlyStopping / ReduceLROnPlateau can trigger at
# the earliest after the 5th (final) epoch, so they effectively never affect this run.
history = model.fit(
    x=[train_data[0], train_data[1]],
    y=train_data[2],
    batch_size=256,
    epochs=5,
    validation_split=0.02,
    callbacks=callbacks,
)