## PyTorch LSTM implementation
#### Reference
- http://pytorch.org/docs/master/nn.html#torch.nn.LSTM

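LSTM math in PyTorch
![pytorch_lstm_math.png](resources/pytorch_lstm_math.png)
PyTorch stores the weights and biases in these equations as four tensors per layer and direction (`weight_ih`, `weight_hh`, `bias_ih`, `bias_hh`), each stacking the gates in the order i, f, g, o:
![pytorch_matrices.png](resources/pytorch_matrices.png)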
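
As a minimal NumPy sketch of the PyTorch convention (not PyTorch's own code), the snippet below runs one LSTM step directly from stacked tensors shaped like `weight_ih_l0` and `weight_hh_l0`, then checks that multiplying by the stacked matrix equals multiplying by the four gate blocks W_ii, W_if, W_ig, W_io one after another. The function name `torch_lstm_step` and the sizes `H`, `D` are illustrative assumptions.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def torch_lstm_step(x, h_prev, c_prev, W_ih, W_hh, b_ih, b_hh):
    """One LSTM step in PyTorch's convention (hypothetical helper).

    W_ih: (4*H, input_size), W_hh: (4*H, H), b_ih and b_hh: (4*H,),
    with the gates stacked in the order (i, f, g, o).
    """
    gates = W_ih @ x + b_ih + W_hh @ h_prev + b_hh
    i, f, g, o = np.split(gates, 4)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
    g = np.tanh(g)
    c = f * c_prev + i * g
    h = o * np.tanh(c)
    return h, c


rng = np.random.default_rng(0)
H, D = 3, 5  # hypothetical hidden and input sizes
W_ih = rng.standard_normal((4 * H, D))
W_hh = rng.standard_normal((4 * H, H))
b_ih = rng.standard_normal(4 * H)
b_hh = rng.standard_normal(4 * H)
x = rng.standard_normal(D)

h1, c1 = torch_lstm_step(x, np.zeros(H), np.zeros(H), W_ih, W_hh, b_ih, b_hh)

# The stacked matrix is the four gate matrices placed one above the other
W_ii, W_if, W_ig, W_io = np.split(W_ih, 4)
stacked = W_ih @ x
per_gate = np.concatenate([W_ii @ x, W_if @ x, W_ig @ x, W_io @ x])
print(h1.shape, np.allclose(stacked, per_gate))
```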

## Keras LSTM implementation
#### Reference
- https://keras.io/layers/recurrent/#lstm
- https://github.com/keras-team/keras/blob/2.1.3/keras/layers/recurrent.py

LSTM math in Keras
![keras_lstm_math.png](resources/keras_lstm_math.png)
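
Keras's default recurrent activation, `hard_sigmoid`, is a piecewise-linear approximation of the logistic sigmoid. The sketch below assumes the definition used by the Keras 2.x TensorFlow backend, `clip(0.2 * x + 0.5, 0, 1)`, and compares it with `sigmoid` at a few points: the two agree at 0 but drift apart elsewhere, so leaving the Keras default in place would change the ported model's outputs.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def hard_sigmoid(z):
    # Piecewise-linear approximation, assumed to match the Keras 2.x backend
    return np.clip(0.2 * z + 0.5, 0.0, 1.0)


z = np.array([-3.0, -1.0, 0.0, 1.0, 3.0])
for zi, s, hs in zip(z, sigmoid(z), hard_sigmoid(z)):
    print(f"z={zi:+.1f}  sigmoid={s:.4f}  hard_sigmoid={hs:.4f}")

# Largest disagreement over the sampled points
max_gap = np.max(np.abs(sigmoid(z) - hard_sigmoid(z)))
print(f"max gap: {max_gap:.4f}")
```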

#### Note
* Keras has a single bias per gate, whereas PyTorch has two (an input-hidden and a hidden-hidden bias, e.g. $b_{ig}$ and $b_{hg}$ for the cell gate); when porting, the two PyTorch biases are summed
* The kernel weights are transposed relative to PyTorch: Keras stores `kernel` as `(input_dim, 4 * units)`, whereas PyTorch stores `weight_ih` as `(4 * hidden_size, input_size)`
* Keras defaults to `recurrent_activation='hard_sigmoid'`, while PyTorch uses `sigmoid`, so the Keras default must be overridden with `recurrent_activation='sigmoid'`
* The PyTorch implementation masks inputs equal to 0.0 by default (though this behavior is not consistent); in the Keras port, padding is masked explicitly with a `Masking` layer
* At the time of porting, Keras had an issue combining `Masking` with a `Bidirectional` layer (https://github.com/keras-team/keras/issues/3086). As a shortcut fix, the output of the final Bi-LSTM is zeroed out at padded timesteps. The figure below shows the behaviour before the shortcut fix:
![keras_bidirectional_masking_issue.png](resources/keras_bidirectional_masking_issue.png)
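
The first three notes can be checked numerically. The NumPy sketch below uses hypothetical helpers `torch_step`, `keras_step` and `torch_to_keras`: it converts PyTorch-style parameters by transposing both weight matrices and summing the two biases, then confirms that one Keras-convention step reproduces the PyTorch step when the recurrent activation is `sigmoid`, and does not when it is `hard_sigmoid` (assumed here to be `clip(0.2 * x + 0.5, 0, 1)`). Sizes and values are random stand-ins.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def hard_sigmoid(z):
    return np.clip(0.2 * z + 0.5, 0.0, 1.0)


def torch_step(x, h, c, W_ih, W_hh, b_ih, b_hh):
    # PyTorch convention: W_ih (4H, D), W_hh (4H, H), two biases, gates (i, f, g, o)
    i, f, g, o = np.split(W_ih @ x + b_ih + W_hh @ h + b_hh, 4)
    c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    return sigmoid(o) * np.tanh(c), c


def keras_step(x, h, c, kernel, recurrent_kernel, bias, recurrent_activation):
    # Keras convention: kernel (D, 4H), recurrent_kernel (H, 4H), one bias, gates (i, f, c, o)
    i, f, g, o = np.split(x @ kernel + h @ recurrent_kernel + bias, 4)
    act = recurrent_activation
    c = act(f) * c + act(i) * np.tanh(g)
    return act(o) * np.tanh(c), c


def torch_to_keras(W_ih, W_hh, b_ih, b_hh):
    # Transpose both kernels and fold the two PyTorch biases into one
    return W_ih.T, W_hh.T, b_ih + b_hh


rng = np.random.default_rng(42)
H, D = 4, 6  # hypothetical sizes
W_ih, W_hh = rng.standard_normal((4 * H, D)), rng.standard_normal((4 * H, H))
b_ih, b_hh = rng.standard_normal(4 * H), rng.standard_normal(4 * H)
x, h, c = rng.standard_normal(D), rng.standard_normal(H), rng.standard_normal(H)

kernel, recurrent_kernel, bias = torch_to_keras(W_ih, W_hh, b_ih, b_hh)
h_torch, c_torch = torch_step(x, h, c, W_ih, W_hh, b_ih, b_hh)
h_sig, c_sig = keras_step(x, h, c, kernel, recurrent_kernel, bias, sigmoid)
h_hard, _ = keras_step(x, h, c, kernel, recurrent_kernel, bias, hard_sigmoid)

print(np.allclose(h_torch, h_sig), np.allclose(c_torch, c_sig))  # sigmoid: outputs match
print(np.abs(h_torch - h_hard).max())                            # hard_sigmoid: outputs differ
```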

In [1]:
```python
import os
import numpy as np
import re
```

In [2]:
```python
from keras.models import Sequential, Model, Input
from keras.layers import Dense, Activation, Bidirectional
from keras.layers import LSTM, Multiply, Lambda
from keras.layers.core import Masking
from keras import backend as K
```

```
Using TensorFlow backend.
```


In [3]:
```python
from cove import MTLSTM
```

In [4]:
```python
# Load the PyTorch CoVe model
pytorch_network = MTLSTM()
```

In [5]:
```python
# Build the Keras MTLSTM model without the shortcut fix for the Masking + Bidirectional issue
keras_model = Sequential()
keras_model.add(Masking(mask_value=0., input_shape=(None, 300)))
keras_model.add(Bidirectional(LSTM(300, return_sequences=True, recurrent_activation='sigmoid', name='lstm1'), name='bidir_1'))
keras_model.add(Bidirectional(LSTM(300, return_sequences=True, recurrent_activation='sigmoid', name='lstm2'), name='bidir_2'))
```

In [6]:
```python
# Build the Keras MTLSTM model with the shortcut fix for the Masking + Bidirectional issue
x = Input(shape=(None, 300))
y = Masking(mask_value=0., input_shape=(None, 300))(x)
y = Bidirectional(LSTM(300, return_sequences=True, recurrent_activation='sigmoid', name='lstm1'), name='bidir_1')(y)
y = Bidirectional(LSTM(300, return_sequences=True, recurrent_activation='sigmoid', name='lstm2'), name='bidir_2')(y)

# These two layers are the shortcut fix: build a (batch, timesteps, 1) mask that is
# 1.0 where the input timestep has any nonzero feature and 0.0 at padded timesteps,
# then multiply it into the Bi-LSTM output so padded timesteps are zeroed out.
y_rev_mask_fix = Lambda(lambda t: K.cast(K.any(K.not_equal(t, 0.), axis=-1, keepdims=True), K.floatx()))(x)
y = Multiply()([y, y_rev_mask_fix])

keras_model = Model(inputs=x, outputs=y)
```
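
To see what the two shortcut-fix layers compute, here is a NumPy re-implementation of the same expression on a toy padded batch. The batch, its sizes, and the random stand-in for the Bi-LSTM output are made up for illustration; the stand-in represents an output that, per the issue above, may be nonzero at padded timesteps.

```python
import numpy as np

# Toy batch: 2 sequences, 4 timesteps, 3 features; trailing all-zero rows are padding
x = np.array([
    [[0.1, 0.2, 0.3], [0.4, 0.0, 0.1], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
    [[0.5, 0.5, 0.5], [0.2, 0.1, 0.0], [0.3, 0.0, 0.0], [0.0, 0.0, 0.0]],
])

# Same expression as the Lambda layer: K.any(K.not_equal(x, 0.), axis=-1, keepdims=True)
mask = np.any(x != 0.0, axis=-1, keepdims=True).astype(np.float32)  # shape (2, 4, 1)

# Stand-in for the Bi-LSTM output (hypothetical values)
rng = np.random.default_rng(0)
y = rng.standard_normal((2, 4, 6)).astype(np.float32)

# Same as Multiply()([y, mask]): broadcasting zeroes every feature at padded timesteps
y_fixed = y * mask
print(mask[..., 0])
```

A timestep counts as real if any of its features is nonzero, so a partially zero vector such as `[0.3, 0.0, 0.0]` is kept.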

In [7]:
```python
keras_model.summary()
```

```
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
input_1 (InputLayer)            (None, None, 300)    0
__________________________________________________________________________________________________
masking_2 (Masking)             (None, None, 300)    0           input_1[0][0]
__________________________________________________________________________________________________
bidir_1 (Bidirectional)         (None, None, 600)    1442400     masking_2[0][0]
__________________________________________________________________________________________________
bidir_2 (Bidirectional)         (None, None, 600)    2162400     bidir_1[0][0]
__________________________________________________________________________________________________
lambda_1 (
```
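
The parameter counts in the summary can be checked by hand. A Keras LSTM with `units` hidden units and input size `input_dim` has `4 * (units * (input_dim + units) + units)` parameters (four gates, each with a kernel, a recurrent kernel and one bias), and `Bidirectional` doubles that. The helper names below are illustrative; the second-bias comparison follows from the `bias_ih`/`bias_hh` shapes in the PyTorch listing.

```python
def lstm_params(input_dim, units):
    # Keras: four gates, each with kernel (input_dim x units), recurrent kernel (units x units), one bias
    return 4 * (units * (input_dim + units) + units)


def bilstm_params(input_dim, units):
    # Forward and backward LSTMs have separate weights
    return 2 * lstm_params(input_dim, units)


def torch_lstm_params(input_dim, hidden):
    # PyTorch keeps two bias vectors (bias_ih and bias_hh) per layer and direction
    return 4 * hidden * (input_dim + hidden) + 2 * 4 * hidden


layer1 = bilstm_params(300, 300)  # input: 300-d CoVe input vectors
layer2 = bilstm_params(600, 300)  # input: concatenated 2 x 300 outputs of layer 1
print(layer1, layer2, layer1 + layer2)
print(torch_lstm_params(300, 300) - lstm_params(300, 300))  # extra bias values per direction
```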

In [8]:
```python
wt_dict = pytorch_network.state_dict()
```

In [9]:
```python
print('---Pytorch params to be ported---\n')
for param in wt_dict.keys():
    print(param + ':' + str(wt_dict[param].numpy().shape))
```

```
---Pytorch params to be ported---

rnn.weight_ih_l0:(1200, 300)
rnn.weight_hh_l0:(1200, 300)
rnn.bias_ih_l0:(1200,)
rnn.bias_hh_l0:(1200,)
rnn.weight_ih_l0_reverse:(1200, 300)
rnn.weight_hh_l0_reverse:(1200, 300)
rnn.bias_ih_l0_reverse:(1200,)
rnn.bias_hh_l0_reverse:(1200,)
rnn.weight_ih_l1:(1200, 600)
rnn.weight_hh_l1:(1200, 300)
rnn.bias_ih_l1:(1200,)
rnn.bias_hh_l1:(1200,)
rnn.weight_ih_l1_reverse:(1200, 600)
rnn.weight_hh_l1_reverse:(1200, 300)
rnn.bias_ih_l1_reverse:(1200,)
rnn.bias_hh_l1_reverse:(1200,)
```


In [10]:
```python
print('---Keras params to be set---\n')
for layer in keras_model.layers:
    for var, value in zip(layer.trainable_weights, layer.get_weights()):
        print('"%s" :%s' % (var.name, value.shape))
```

```
---Keras params to be set---

"bidir_1_1/forward_lstm1/kernel:0" :(300, 1200)
"bidir_1_1/forward_lstm1/recurrent_kernel:0" :(300, 1200)
"bidir_1_1/forward_lstm1/bias:0" :(1200,)
"bidir_1_1/backward_lstm1/kernel:0" :(300, 1200)
"bidir_1_1/backward_lstm1/recurrent_kernel:0" :(300, 1200)
"bidir_1_1/backward_lstm1/bias:0" :(1200,)
"bidir_2_1/forward_lstm2/kernel:0" :(600, 1200)
"bidir_2_1/forward_lstm2/recurrent_kernel:0" :(300, 1200)
"bidir_2_1/forward_lstm2/bias:0" :(1200,)
"bidir_2_1/backward_lstm2/kernel:0" :(600, 1200)
"bidir_2_1/backward_lstm2/recurrent_kernel:0" :(300, 1200)
"bidir_2_1/backward_lstm2/bias:0" :(1200,)
```

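
Comparing the two listings, each PyTorch tensor maps to one Keras variable, with weight matrices transposed and the two biases merged into one. The hypothetical helpers `keras_var_name` and `keras_shape` below (not part of either library) encode that correspondence; they assume the layer names used in this notebook, including the `_1` suffix Keras appended to the scopes in this session.

```python
import re


def keras_var_name(torch_key):
    """Map a PyTorch LSTM state_dict key to the Keras variable it is ported into."""
    m = re.fullmatch(r'rnn\.(weight|bias)_(ih|hh)_l(\d)(_reverse)?', torch_key)
    if m is None:
        raise ValueError('unexpected key: ' + torch_key)
    kind, source, layer, reverse = m.groups()
    n = int(layer) + 1  # PyTorch layers are 0-indexed, Keras names here are 1-indexed
    direction = 'backward' if reverse else 'forward'
    if kind == 'bias':
        var = 'bias'  # bias_ih and bias_hh are summed into a single Keras bias
    else:
        var = 'kernel' if source == 'ih' else 'recurrent_kernel'
    return 'bidir_{0}_1/{1}_lstm{0}/{2}:0'.format(n, direction, var)


def keras_shape(torch_shape):
    # Weight matrices are transposed; 1-D biases keep their shape
    return tuple(reversed(torch_shape))


# A few entries copied from the PyTorch listing above
torch_params = {
    'rnn.weight_ih_l0': (1200, 300),
    'rnn.bias_hh_l0': (1200,),
    'rnn.weight_ih_l1_reverse': (1200, 600),
    'rnn.weight_hh_l1_reverse': (1200, 300),
}
for key, shape in torch_params.items():
    print(key, '->', keras_var_name(key), keras_shape(shape))
```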

In [11]:
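```python
def get_wt(wt_dict, key, no_splits):
    """Split a stacked PyTorch LSTM tensor into `no_splits` equal blocks along axis 0.

    PyTorch stacks the four gates in the order (i, f, g, o), so with no_splits=4
    this returns the per-gate weight matrices (or per-gate biases).
    """
    wt = wt_dict[key].numpy()
    block = wt.shape[0] // no_splits  # rows per gate, i.e. hidden_size
    return tuple(wt[block * i:block * (i + 1)] for i in range(no_splits))
```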

In [12]:
```python
def set_keras_wt(model, node_name, new_wt):
    """Assign `new_wt` to the first weight in each layer whose name matches the regex `node_name`."""
    pattern = re.compile(node_name)
    for layer in model.layers:
        layer_wts = layer.get_weights()
        for i, var in enumerate(layer.trainable_weights):
            if pattern.search(var.name):
                print('setting weigths for:' + var.name)
                layer_wts[i] = new_wt
                layer.set_weights(layer_wts)
                break
```
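
`set_keras_wt` relies on `re.search` and scans every layer, so each pattern should match exactly one variable name across the whole model. A quick check against the names printed in the Keras listing, using the patterns `port_wts` builds from `'bidir_1.*/{}_lstm1/'` and `'bidir_2.*/{}_lstm2/'`, shows that `.../kernel:0` does not accidentally hit `recurrent_kernel:0` and that `.*` absorbs the `_1` suffix. The `matches` helper is hypothetical.

```python
import re

# Variable names as printed in the Keras listing above
names = [
    'bidir_1_1/forward_lstm1/kernel:0', 'bidir_1_1/forward_lstm1/recurrent_kernel:0',
    'bidir_1_1/forward_lstm1/bias:0', 'bidir_1_1/backward_lstm1/kernel:0',
    'bidir_1_1/backward_lstm1/recurrent_kernel:0', 'bidir_1_1/backward_lstm1/bias:0',
    'bidir_2_1/forward_lstm2/kernel:0', 'bidir_2_1/forward_lstm2/recurrent_kernel:0',
    'bidir_2_1/forward_lstm2/bias:0', 'bidir_2_1/backward_lstm2/kernel:0',
    'bidir_2_1/backward_lstm2/recurrent_kernel:0', 'bidir_2_1/backward_lstm2/bias:0',
]


def matches(pattern):
    # All names the regex would hit, in the same way set_keras_wt tests them
    return [n for n in names if re.search(pattern, n)]


for suffix in ('kernel:0', 'recurrent_kernel:0', 'bias:0'):
    pattern = 'bidir_1.*/{}_lstm1/'.format('forward') + suffix
    print(pattern, '->', matches(pattern))
```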

In [13]:
```python
def port_wts(torch_wts, torch_node_name, pytorch_layer, keras_model, keras_node_name, reverse=False):
    """Copy one direction of one PyTorch LSTM layer into the matching Keras LSTM."""
    suffix = str(pytorch_layer) + ('_reverse' if reverse else '')
    prefix = torch_node_name + '.'

    # Split the stacked PyTorch tensors into per-gate blocks, gate order (i, f, g, o)
    W_ii, W_if, W_ig, W_io = get_wt(torch_wts, prefix + 'weight_ih_l' + suffix, 4)
    W_hi, W_hf, W_hg, W_ho = get_wt(torch_wts, prefix + 'weight_hh_l' + suffix, 4)
    b_ii, b_if, b_ig, b_io = get_wt(torch_wts, prefix + 'bias_ih_l' + suffix, 4)
    b_hi, b_hf, b_hg, b_ho = get_wt(torch_wts, prefix + 'bias_hh_l' + suffix, 4)

    # Keras has a single bias per gate: sum PyTorch's input-hidden and hidden-hidden biases
    b_i = b_ii + b_hi
    b_f = b_if + b_hf
    b_g = b_ig + b_hg
    b_o = b_io + b_ho

    # Keras uses the same gate order (i, f, c, o) but stores the kernels transposed:
    # (input_dim, 4 * units) instead of (4 * hidden_size, input_size)
    kernel = np.concatenate([W_ii, W_if, W_ig, W_io], axis=0).transpose()
    recurrent_kernel = np.concatenate([W_hi, W_hf, W_hg, W_ho], axis=0).transpose()
    bias = np.concatenate([b_i, b_f, b_g, b_o], axis=0)

    # Fill the '{}' placeholder in the Keras name pattern with the direction
    keras_direction = 'backward' if reverse else 'forward'
    keras_node_name = keras_node_name.format(keras_direction)

    set_keras_wt(keras_model, keras_node_name + 'kernel:0', kernel)
    set_keras_wt(keras_model, keras_node_name + 'bias:0', bias)
    set_keras_wt(keras_model, keras_node_name + 'recurrent_kernel:0', recurrent_kernel)
```
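
Because PyTorch's (i, f, g, o) stacking and Keras's (i, f, c, o) column order coincide, splitting into gates and concatenating them back in the same order is an identity, so the per-gate steps in `port_wts` could in principle be replaced by a transpose and a bias sum; keeping them explicit documents the gate correspondence. A NumPy check with random stand-in tensors (shapes taken from layer 2, values made up) illustrates this:

```python
import numpy as np

rng = np.random.default_rng(1)
weight_ih = rng.standard_normal((1200, 600)).astype(np.float32)  # stand-in for rnn.weight_ih_l1
bias_ih = rng.standard_normal(1200).astype(np.float32)           # stand-in for rnn.bias_ih_l1
bias_hh = rng.standard_normal(1200).astype(np.float32)           # stand-in for rnn.bias_hh_l1

# Gate by gate, as port_wts does
W_ii, W_if, W_ig, W_io = np.split(weight_ih, 4)
kernel_per_gate = np.concatenate([W_ii, W_if, W_ig, W_io], axis=0).T
bias_per_gate = np.concatenate([a + b for a, b in zip(np.split(bias_ih, 4), np.split(bias_hh, 4))])

# Direct equivalent: plain transpose and plain sum
kernel_direct = weight_ih.T
bias_direct = bias_ih + bias_hh

print(kernel_direct.shape,
      np.array_equal(kernel_per_gate, kernel_direct),
      np.array_equal(bias_per_gate, bias_direct))
```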

In [14]:
```python
# Port Bi-LSTM layer 1 (PyTorch layer index 0): forward, then backward direction
port_wts(wt_dict, 'rnn', 0, keras_model, 'bidir_1.*/{}_lstm1/', reverse=False)
port_wts(wt_dict, 'rnn', 0, keras_model, 'bidir_1.*/{}_lstm1/', reverse=True)
```

```
setting weigths for:bidir_1_1/forward_lstm1/kernel:0
setting weigths for:bidir_1_1/forward_lstm1/bias:0
setting weigths for:bidir_1_1/forward_lstm1/recurrent_kernel:0
setting weigths for:bidir_1_1/backward_lstm1/kernel:0
setting weigths for:bidir_1_1/backward_lstm1/bias:0
setting weigths for:bidir_1_1/backward_lstm1/recurrent_kernel:0
```


In [15]:
```python
# Port Bi-LSTM layer 2 (PyTorch layer index 1): forward, then backward direction
port_wts(wt_dict, 'rnn', 1, keras_model, 'bidir_2.*/{}_lstm2/', reverse=False)
port_wts(wt_dict, 'rnn', 1, keras_model, 'bidir_2.*/{}_lstm2/', reverse=True)
```

```
setting weigths for:bidir_2_1/forward_lstm2/kernel:0
setting weigths for:bidir_2_1/forward_lstm2/bias:0
setting weigths for:bidir_2_1/forward_lstm2/recurrent_kernel:0
setting weigths for:bidir_2_1/backward_lstm2/kernel:0
setting weigths for:bidir_2_1/backward_lstm2/bias:0
setting weigths for:bidir_2_1/backward_lstm2/recurrent_kernel:0
```


In [16]:
```python
# Save the ported model
keras_model.save('Keras_CoVe.h5')
```