In [2]:
import numpy as np

import keras.backend as K
from keras.engine.topology import Layer
from keras.models import Model
from keras.layers import Input, Dense, LSTM, Bidirectional, Dropout, concatenate, multiply, Lambda, Reshape

Using TensorFlow backend.


In [5]:
# --- Model hyperparameters -------------------------------------------------

# Size of each per-token input vector (last axis of both model inputs).
WORD_EMBEDDING_DIM = 300

# Hidden units per LSTM direction; the bidirectional wrapper concatenates
# both directions, so each encoder emits 2 * LSTM_UNITS features per step.
LSTM_UNITS = 512

# Width of each fully connected layer in the classifier head.
FC_DIM = 128

In [7]:
# Sentence-pair classifier: each sentence is encoded by its own BiLSTM,
# max-pooled over time into a fixed-size vector, and the two vectors are
# combined as [u, v, |u - v|, u * v] before a small fully connected head.

# Inputs are sequences of pre-computed word vectors with variable length:
# (batch, timesteps, WORD_EMBEDDING_DIM).
premise_input = Input(shape=(None, WORD_EMBEDDING_DIM))
hypothesis_input = Input(shape=(None, WORD_EMBEDDING_DIM))

# Two independent (non-weight-sharing) encoders, one per sentence.
# Output of each: (batch, timesteps, 2 * LSTM_UNITS).
l_lstm1 = Bidirectional(LSTM(LSTM_UNITS, return_sequences=True))(premise_input)
l_lstm2 = Bidirectional(LSTM(LSTM_UNITS, return_sequences=True))(hypothesis_input)

# Max-pool over the TIME axis (axis=1) so every sample gets its own sentence
# vector. Pooling over axis=0 would take the maximum across the batch, mixing
# unrelated samples together and leaving the time axis in place.
l_max1 = Lambda(lambda x: K.max(x, axis=1))(l_lstm1)
l_max2 = Lambda(lambda x: K.max(x, axis=1))(l_lstm2)
# Shape-preserving after the pooling above; kept to pin the static feature
# dimension to 2 * LSTM_UNITS for the downstream layers.
l_max1 = Reshape((2 * LSTM_UNITS,))(l_max1)
l_max2 = Reshape((2 * LSTM_UNITS,))(l_max2)

# Pairwise interaction features between the two sentence vectors.
l_abssub = Lambda(lambda x: K.abs(x[0] - x[1]))([l_max1, l_max2])
l_mul = multiply([l_max1, l_max2])

# Combined representation: (batch, 4 * 2 * LSTM_UNITS).
x = concatenate([l_max1, l_max2, l_abssub, l_mul])

# Classifier head: two ReLU layers with dropout, then a single sigmoid unit.
x = Dropout(0.2)(x)
x = Dense(FC_DIM, activation='relu')(x)
x = Dropout(0.2)(x)
x = Dense(FC_DIM, activation='relu')(x)
x = Dropout(0.2)(x)
pred = Dense(1, activation='sigmoid')(x)

model = Model(inputs=[premise_input, hypothesis_input], outputs=pred)

# Binary target, matching the single sigmoid output.
model.compile(optimizer='adam', loss='binary_crossentropy')

model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_5 (InputLayer)            (None, None, 300)    0                                            
__________________________________________________________________________________________________
input_6 (InputLayer)            (None, None, 300)    0                                            
__________________________________________________________________________________________________
bidirectional_5 (Bidirectional) (None, None, 1024)   3330048     input_5[0][0]                    
__________________________________________________________________________________________________
bidirectional_6 (Bidirectional) (None, None, 1024)   3330048     input_6[0][0]                    
__________________________________________________________________________________________________
lambda_7 (

In [8]:
from keras.utils import plot_model
# pydot is not referenced directly in this cell; presumably it is imported so
# a missing graph-rendering dependency fails here with a clear ImportError
# (TODO confirm it is actually required in this environment).
import pydot

# Write a diagram of the model graph, annotated with tensor shapes, to a PNG
# two directories above the notebook's working directory.
plot_model(
    model,
    to_file='../../model.png',
    show_shapes=True,
)