# Understanding LSTM's Output

Author: Pierre Nugues

## Imports

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

2022-10-05 21:13:10.873225: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


A simple dataset

In [2]:
inputs = tf.random.normal([16, 8, 4]) # 16 sequences of 8 words, each word a vector of 4 values
                                    # (an approximation: random values stand in for one-hot vectors)



And an LSTM with a 2D output

In [3]:
lstm = tf.keras.layers.LSTM(2) # Two units
output = lstm(inputs)

In [4]:
output.shape

TensorShape([16, 2])

The layer outputs one 2-D vector for each sentence: the hidden state after the last word.

In [5]:
output

<tf.Tensor: shape=(16, 2), dtype=float32, numpy=
array([[ 0.04333418,  0.08327197],
       [-0.00737492,  0.33590046],
       [ 0.36525294,  0.28010353],
       [-0.03309057,  0.18950672],
       [ 0.24618253,  0.5596625 ],
       [ 0.08247661, -0.07042927],
       [ 0.05496754,  0.05022008],
       [ 0.13326074, -0.10119801],
       [-0.14637274,  0.23294334],
       [-0.18787374,  0.02777418],
       [ 0.0635227 ,  0.24398568],
       [-0.12099309,  0.17926204],
       [ 0.03342689,  0.49483967],
       [-0.09702206, -0.20789722],
       [-0.08253886, -0.04978706],
       [ 0.21518752,  0.14452478]], dtype=float32)>

Now we want the sequence of outputs (one for each of the 8 words) and the final states. We request them with the `return_sequences` and `return_state` arguments, and the layer then returns three tensors instead of one. See https://keras.io/api/layers/recurrent_layers/lstm/

In [6]:
lstm = tf.keras.layers.LSTM(2, return_sequences=True, return_state=True)
whole_seq_output, final_memory_state, final_carry_state = lstm(inputs)

In [7]:
whole_seq_output.shape

TensorShape([16, 8, 2])

In [8]:
final_memory_state.shape

TensorShape([16, 2])

In [10]:
whole_seq_output[:, -1]

<tf.Tensor: shape=(16, 2), dtype=float32, numpy=
array([[ 0.13787897, -0.13928951],
       [ 0.01530278,  0.09404946],
       [ 0.39629328,  0.26815152],
       [ 0.06475419,  0.19927594],
       [ 0.47169614, -0.19335218],
       [ 0.63190985,  0.3149604 ],
       [ 0.14613482,  0.0071037 ],
       [ 0.2547831 ,  0.00311853],
       [-0.3286329 , -0.32885313],
       [-0.05864318,  0.33845478],
       [ 0.1869785 , -0.17483842],
       [ 0.06752494, -0.02351512],
       [ 0.36729264,  0.1313507 ],
       [ 0.00150263, -0.1020644 ],
       [ 0.14447618,  0.26301634],
       [ 0.28054118, -0.16965719]], dtype=float32)>

In [11]:
final_memory_state

<tf.Tensor: shape=(16, 2), dtype=float32, numpy=
array([[ 0.13787897, -0.13928951],
       [ 0.01530278,  0.09404946],
       [ 0.39629328,  0.26815152],
       [ 0.06475419,  0.19927594],
       [ 0.47169614, -0.19335218],
       [ 0.63190985,  0.3149604 ],
       [ 0.14613482,  0.0071037 ],
       [ 0.2547831 ,  0.00311853],
       [-0.3286329 , -0.32885313],
       [-0.05864318,  0.33845478],
       [ 0.1869785 , -0.17483842],
       [ 0.06752494, -0.02351512],
       [ 0.36729264,  0.1313507 ],
       [ 0.00150263, -0.1020644 ],
       [ 0.14447618,  0.26301634],
       [ 0.28054118, -0.16965719]], dtype=float32)>
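
The two printouts above are identical: the output at the last time step is the final hidden state, which the Keras documentation calls the memory state. A minimal check of this, assuming the same layer configuration as in cell 6 (the block recreates the inputs and the layer so that it runs on its own, so the values will differ from those printed above):

```python
import numpy as np
import tensorflow as tf

# Same setup as in the notebook: 16 sequences, 8 words, 4 values per word
inputs = tf.random.normal([16, 8, 4])
lstm = tf.keras.layers.LSTM(2, return_sequences=True, return_state=True)
whole_seq_output, final_memory_state, final_carry_state = lstm(inputs)

# Compare the output at the last time step with the final memory state
last_step = np.asarray(whole_seq_output[:, -1])
same_as_last_step = np.allclose(last_step, np.asarray(final_memory_state))
print(same_as_last_step)
```

If the check holds, `return_state=True` adds only one new piece of information when `return_sequences=True` is already set: the carry state.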

In [12]:
final_carry_state

<tf.Tensor: shape=(16, 2), dtype=float32, numpy=
array([[ 0.2409493 , -0.22636688],
       [ 0.0236536 ,  0.20192736],
       [ 1.1535597 ,  0.44116223],
       [ 0.08893625,  0.29256758],
       [ 1.091116  , -0.29991293],
       [ 1.4106864 ,  0.58626485],
       [ 0.42405617,  0.01437281],
       [ 0.9223573 ,  0.00620121],
       [-0.4871088 , -0.5425641 ],
       [-0.06042434,  0.3887447 ],
       [ 0.34991127, -0.43809798],
       [ 0.15503758, -0.05240416],
       [ 0.8398984 ,  0.27844083],
       [ 0.00580859, -0.36735687],
       [ 0.52805126,  0.8296853 ],
       [ 0.61242664, -0.26442707]], dtype=float32)>
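
The carry state differs from the memory state because, in a standard LSTM cell, the hidden state is the cell state passed through a tanh and scaled by the output gate: h = o * tanh(c). The sketch below is a from-scratch NumPy forward pass, not the Keras implementation. The function `lstm_forward`, the weight names `W`, `U`, `b`, and the random weights are all hypothetical and chosen for illustration; the gate order is an assumption and does not matter for the point made here.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_forward(x, W, U, b):
    """Run a plain LSTM over x of shape (batch, time, features).

    W: (features, 4 * units), U: (units, 4 * units), b: (4 * units,)
    Returns the sequence of hidden states, the final hidden state,
    and the final cell (carry) state.
    """
    batch, steps, _ = x.shape
    units = U.shape[0]
    h = np.zeros((batch, units))  # hidden (memory) state
    c = np.zeros((batch, units))  # cell (carry) state
    outputs = []
    for t in range(steps):
        z = x[:, t] @ W + h @ U + b
        i, f, g, o = np.split(z, 4, axis=1)  # input, forget, candidate, output
        i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
        g = np.tanh(g)
        c = f * c + i * g        # update the carry state
        h = o * np.tanh(c)       # the hidden state is a gated view of the carry state
        outputs.append(h)
    return np.stack(outputs, axis=1), h, c

# Hypothetical random weights: 4 input values, 2 units, as in the notebook
rng = np.random.default_rng(0)
x = rng.normal(size=(16, 8, 4))
W = rng.normal(scale=0.5, size=(4, 8))
U = rng.normal(scale=0.5, size=(2, 8))
b = np.zeros(8)

seq, h_final, c_final = lstm_forward(x, W, U, b)
print(seq.shape, h_final.shape, c_final.shape)  # (16, 8, 2) (16, 2) (16, 2)
```

Because the output gate lies between 0 and 1, each hidden-state value has the same sign as the corresponding carry-state value and a magnitude no larger than tanh of it, which is consistent with the `final_memory_state` and `final_carry_state` values printed above.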