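# Live Coding an LSTM in TensorFlow 2.x

Implements the LSTM gate equations step by step with raw tensor ops, then wraps them in a custom Keras layer and trains a small classifier on random data.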

In [1]:
import tensorflow as tf

In [50]:
batch_size = 4
sequence_length = 5
input_size = 30
output_size = 20
seed = 42

# Fix TensorFlow's global seed so the tf.random draws in this notebook are
# repeatable across Restart & Run All.
tf.random.set_seed(seed)

# Random input batch: [batch_size, sequence_length, input_size]
x = tf.random.uniform((batch_size, sequence_length, input_size))

In [51]:
x.get_shape()  # expected [batch_size, sequence_length, input_size]

TensorShape([4, 5, 30])

In [52]:
# LSTM input:  [batch_size, sequence_length, input_size]
# LSTM output: [batch_size, sequence_length, output_size]  (hidden state at every step)
#          or: [batch_size, output_size]                   (hidden state at the last step only)

In [54]:
# A single time step of the input batch: [batch_size, input_size].
# Defined here so this cell does not depend on a variable from a later cell.
xt = x[:, 0, :]
xt.shape

TensorShape([4, 30])

In [55]:
# Randomly initialised LSTM parameters, one per gate:
#   f = forget, i = input, o = output, c = candidate cell state.
# w*: input -> gate weights   [input_size, output_size]
wf, wi, wo, wc = [tf.random.uniform((input_size, output_size)) for _ in range(4)]
# u*: hidden -> gate weights  [output_size, output_size]
uf, ui, uo, uc = [tf.random.uniform((output_size, output_size)) for _ in range(4)]
# b*: gate biases, broadcast over the batch  [1, output_size]
bf, bi, bo, bc = [tf.random.uniform((1, output_size)) for _ in range(4)]

In [62]:
# Unrolled LSTM forward pass with plain tensor ops.
# Starting from all-zero hidden (h) and cell (c) states turns step 0 into an
# ordinary step: h @ U contributes nothing and f * c vanishes, leaving c = i * c~.
ht = tf.zeros((x.shape[0], output_size))
ct = tf.zeros((x.shape[0], output_size))
step_outputs = []
for step in range(sequence_length):
    xt = x[:, step, :]                            # [batch_size, input_size]
    ft = tf.sigmoid(xt @ wf + ht @ uf + bf)       # forget gate
    it = tf.sigmoid(xt @ wi + ht @ ui + bi)       # input gate
    ot = tf.sigmoid(xt @ wo + ht @ uo + bo)       # output gate
    cht = tf.tanh(xt @ wc + ht @ uc + bc)         # candidate cell state
    ct = ft * ct + it * cht                       # new cell state
    ht = ot * tf.tanh(ct)                         # new hidden state
    step_outputs.append(ht)

# Stack the per-step hidden states along a new time axis:
# [batch_size, sequence_length, output_size]
sequence_outputs = tf.stack(step_outputs, axis=1)

In [76]:
class CustomLSTM(tf.keras.layers.Layer):
    """LSTM layer built from explicit gate equations.

    For each time step t (``@`` is matrix multiplication, h_0 = c_0 = 0):
        f_t  = sigmoid(x_t @ wf + h_{t-1} @ uf + bf)   # forget gate
        i_t  = sigmoid(x_t @ wi + h_{t-1} @ ui + bi)   # input gate
        o_t  = sigmoid(x_t @ wo + h_{t-1} @ uo + bo)   # output gate
        c~_t = tanh(x_t @ wc + h_{t-1} @ uc + bc)      # candidate cell state
        c_t  = f_t * c_{t-1} + i_t * c~_t
        h_t  = o_t * tanh(c_t)

    Input:  [batch_size, sequence_length, input_size]
    Output: [batch_size, sequence_length, output_size] when ``return_sequences``
            is True, otherwise [batch_size, output_size] (the last h_t).

    Args:
        output_size: number of hidden units (width of h_t and c_t).
        return_sequences: return h_t for every step instead of only the last one.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, output_size, return_sequences=False, **kwargs):
        super().__init__(**kwargs)
        self.output_size = output_size
        self.return_sequences = return_sequences

    def build(self, input_shape):
        """Create the 12 gate parameters once the input width is known."""
        input_size = int(input_shape[-1])

        # Creation order is kept (wf first), so model.variables[0] is still wf.
        # Names are passed by keyword: in newer Keras the first positional
        # argument of add_weight is the shape, not the name.
        # w*: input -> gate weights
        self.wf = self.add_weight(name='wf', shape=(input_size, self.output_size))
        self.wi = self.add_weight(name='wi', shape=(input_size, self.output_size))
        self.wo = self.add_weight(name='wo', shape=(input_size, self.output_size))
        self.wc = self.add_weight(name='wc', shape=(input_size, self.output_size))

        # u*: hidden -> gate weights
        self.uf = self.add_weight(name='uf', shape=(self.output_size, self.output_size))
        self.ui = self.add_weight(name='ui', shape=(self.output_size, self.output_size))
        self.uo = self.add_weight(name='uo', shape=(self.output_size, self.output_size))
        self.uc = self.add_weight(name='uc', shape=(self.output_size, self.output_size))

        # b*: gate biases, broadcast over the batch
        self.bf = self.add_weight(name='bf', shape=(1, self.output_size))
        self.bi = self.add_weight(name='bi', shape=(1, self.output_size))
        self.bo = self.add_weight(name='bo', shape=(1, self.output_size))
        self.bc = self.add_weight(name='bc', shape=(1, self.output_size))

        super().build(input_shape)

    def _step(self, xt, ht, ct):
        """Advance the recurrence one time step; returns the new (h_t, c_t)."""
        ft = tf.sigmoid(tf.matmul(xt, self.wf) + tf.matmul(ht, self.uf) + self.bf)
        it = tf.sigmoid(tf.matmul(xt, self.wi) + tf.matmul(ht, self.ui) + self.bi)
        ot = tf.sigmoid(tf.matmul(xt, self.wo) + tf.matmul(ht, self.uo) + self.bo)
        cht = tf.tanh(tf.matmul(xt, self.wc) + tf.matmul(ht, self.uc) + self.bc)
        ct = ft * ct + it * cht
        ht = ot * tf.tanh(ct)
        return ht, ct

    def call(self, x):
        """Run the LSTM over the time axis (axis 1) of ``x``.

        Raises:
            ValueError: if the sequence length (x.shape[1]) is not statically
                known or is zero; the Python loop below needs a fixed length.
        """
        # Use the input's own length rather than any notebook-level variable.
        num_steps = x.shape[1]
        if not num_steps:
            raise ValueError(
                'CustomLSTM needs a statically known, non-zero sequence length; '
                'got input shape {}.'.format(x.shape))

        # Zero initial states make step 0 an ordinary step (h @ U and f * c vanish).
        ht = tf.zeros([tf.shape(x)[0], self.output_size], dtype=x.dtype)
        ct = tf.zeros_like(ht)

        sequence_outputs = []
        for i in range(num_steps):
            ht, ct = self._step(x[:, i, :], ht, ct)
            sequence_outputs.append(ht)

        if self.return_sequences:
            # [batch_size, sequence_length, output_size]
            return tf.stack(sequence_outputs, axis=1)
        # Last hidden state: [batch_size, output_size]
        return ht

    def get_config(self):
        """Serializable config so the layer can be saved, cloned and reloaded."""
        config = super().get_config()
        config.update({
            'output_size': self.output_size,
            'return_sequences': self.return_sequences,
        })
        return config

In [73]:
# Fresh random input batch: [batch_size, sequence_length, input_size]
x = tf.random.uniform(shape=[batch_size, sequence_length, input_size])

In [77]:
# Layer that returns only the last hidden state.
lstm = CustomLSTM(output_size, return_sequences=False)

In [79]:
lstm(x)  # return_sequences=False, so only the last step's hidden state: [batch_size, output_size]

<tf.Tensor: shape=(4, 20), dtype=float32, numpy=
array([[ 0.30852655, -0.21745056,  0.10692555,  0.11597968,  0.04500321,
        -0.37364006, -0.00705841, -0.03053965,  0.14263612,  0.04566151,
        -0.09337032,  0.00208508, -0.22490782,  0.0173543 ,  0.01354474,
         0.07829211,  0.07962556,  0.34995866, -0.27240828, -0.09687887],
       [ 0.1493266 , -0.36569503,  0.01223969,  0.00075458, -0.01858684,
        -0.33774638,  0.04054861, -0.10191229,  0.23654637,  0.04943463,
         0.12264866,  0.01052341, -0.20794514,  0.12941153,  0.09601486,
         0.17277044,  0.17709456,  0.27750972, -0.24403755, -0.10510138],
       [ 0.16262454, -0.06469484, -0.02235889, -0.01085542, -0.27742815,
        -0.3049222 , -0.02507741, -0.11945379,  0.19139498,  0.04464089,
         0.04118394,  0.17346984, -0.19876944,  0.09176239,  0.0059094 ,
         0.21667516,  0.20115265,  0.2884535 , -0.21941884, -0.15643002],
       [ 0.21765962, -0.4575573 ,  0.14168835, -0.19112496,  0.059621  ,

In [80]:
# Binary classifier: custom LSTM encoder (32 units) -> 2-way softmax.
model = tf.keras.Sequential()
model.add(CustomLSTM(output_size=32))
model.add(tf.keras.layers.Dense(2, activation='softmax'))

# Integer class labels, so use the sparse variant of cross-entropy.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
)

In [81]:
# One random batch: float inputs and integer class labels in {0, 1}.
x_batch = tf.random.uniform(shape=[batch_size, sequence_length, input_size])
y_batch = tf.random.uniform(shape=[batch_size], maxval=2, dtype=tf.int32)

In [85]:
# Single gradient step on the random batch; returns the training loss.
model.train_on_batch(x=x_batch, y=y_batch)

0.614313542842865

In [87]:
# First model variable: the first weight CustomLSTM.build creates ('wf', input -> forget gate).
# Snapshot taken before the fit() calls below, for comparison afterwards.
model.variables[0]

<tf.Variable 'custom_lstm_3/wf:0' shape=(30, 32) dtype=float32, numpy=
array([[-0.25377417,  0.1692809 , -0.11199328,  0.20707953, -0.13842864,
        -0.11177021, -0.21409385,  0.09249356,  0.11510239, -0.03803512,
         0.24835561,  0.22213687,  0.03385825,  0.23015162, -0.00721113,
         0.09036292, -0.08538435, -0.30503562, -0.19663627,  0.25471726,
        -0.04732769,  0.10959946,  0.03748435,  0.05714989, -0.01627726,
        -0.0712833 , -0.10199174, -0.24439244,  0.18501475,  0.1931884 ,
        -0.14674157,  0.2976991 ],
       [ 0.05983765, -0.30675378, -0.15261371,  0.29464987, -0.30049977,
        -0.03310706, -0.06632823, -0.21576302, -0.2265486 ,  0.07686362,
        -0.21351254,  0.10314063, -0.05508944, -0.12001041, -0.30482033,
         0.1528233 , -0.17421727,  0.18781576, -0.02743551, -0.12909791,
        -0.25794446,  0.09066619, -0.02472133,  0.21310787, -0.09038031,
         0.09551661, -0.06014815, -0.15544894, -0.06312753, -0.27809623,
         0.0925460

In [88]:
# 1000 batches' worth of random data. The labels are drawn independently of the
# inputs, so this only exercises the training mechanics; there is nothing to learn.
num_samples = batch_size * 1000
x_data = tf.random.uniform((num_samples, sequence_length, input_size))
y_data = tf.random.uniform((num_samples,), maxval=2, dtype=tf.int32)

In [89]:
# One epoch (fit's default) over x_data, reusing the notebook's batch_size
# instead of a duplicated literal.
model.fit(x_data, y_data, batch_size=batch_size)



<tensorflow.python.keras.callbacks.History at 0x7f6bfcc757f0>

In [90]:
# Another epoch, continuing from the current weights of the same model.
model.fit(x_data, y_data, batch_size=batch_size)



<tensorflow.python.keras.callbacks.History at 0x7f6bfc42d1f0>

In [91]:
# Another epoch, continuing from the current weights of the same model.
model.fit(x_data, y_data, batch_size=batch_size)



<tensorflow.python.keras.callbacks.History at 0x7f6bfcc777c0>

In [92]:
# The same 'wf' variable after the fit() calls above; compare with the earlier snapshot.
model.variables[0]

<tf.Variable 'custom_lstm_3/wf:0' shape=(30, 32) dtype=float32, numpy=
array([[-0.3787613 ,  0.13099906, -0.19011065,  0.20315205, -0.17168456,
        -0.15743671, -0.2487079 ,  0.06384558,  0.08235257, -0.11096092,
         0.08081412,  0.20376322, -0.0493223 ,  0.18933591, -0.0403847 ,
         0.10564984, -0.17890741, -0.35125497, -0.2018877 ,  0.22142582,
        -0.10761657,  0.07634663,  0.05353941,  0.0397033 , -0.05405599,
        -0.06989753, -0.11654466, -0.34448513,  0.10717814,  0.14334589,
        -0.16894601,  0.21672042],
       [ 0.00998725, -0.32890332, -0.24561387,  0.29237375, -0.3823997 ,
        -0.04255546, -0.12086585, -0.21965371, -0.17488421,  0.00110312,
        -0.39397702,  0.1439992 , -0.12985045, -0.14595012, -0.3079283 ,
         0.22138529, -0.28265324,  0.20770982, -0.08039226, -0.18642919,
        -0.34764668,  0.02575475, -0.01740637,  0.2522215 , -0.13832994,
         0.0817381 , -0.08812083, -0.2853523 , -0.13098139, -0.3534873 ,
         0.0628147