In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, GRU, Bidirectional

In [2]:
T = 8
D = 2
M = 3

In [3]:
X = np.random.randn(1, T, D)

In [4]:
print(X)

[[[-1.87171987 -0.08259338]
  [-0.14598867  1.97546405]
  [ 1.09927105 -0.20543904]
  [-0.32963817  0.40001905]
  [ 0.65010311  1.828023  ]
  [-0.02007692  0.2562293 ]
  [-0.93879494 -0.96806222]
  [ 0.62621956 -1.59947435]]]


## With Return Statees

In [5]:
def lstm1():
    input_ = Input(shape=(T,D))
    rnn = Bidirectional(LSTM(M, return_state = True))
    x = rnn(input_)
    model = Model(inputs=input_, outputs=x)
    o,h1,c1, h2, c2 = model.predict(X)
    print("Output State is : ", o)
    print("Hidden State Forward is : ", h1)
    print("Cell State Forward is : ", c1)
    print("Hidden State Backward is : ", h2)
    print("Cell State Backward is : ", c2)

In [6]:
lstm1()

Output State is :  [[ 0.06571475 -0.02487938 -0.09495559 -0.57755727 -0.05710207  0.21104956]]
Hidden State Forward is :  [[ 0.06571475 -0.02487938 -0.09495559]]
Cell State Forward is :  [[ 0.0891399  -0.07864401 -0.28848863]]
Hidden State Backward is :  [[-0.57755727 -0.05710207  0.21104956]]
Cell State Backward is :  [[-1.0757315  -0.17955284  0.32412967]]


## With Return States and Return Sequences

In [7]:
def lstm2():
    input_ = Input(shape=(T,D))
    rnn = Bidirectional(LSTM(M, return_state = True, return_sequences = True))
    x = rnn(input_)
    model = Model(inputs=input_, outputs=x)
    o,h1,c1, h2, c2 = model.predict(X)
    print("Output State is : ", o)
    print("Hidden State Forward is : ", h1)
    print("Cell State Forward is : ", c1)
    print("Hidden State Backward is : ", h2)
    print("Cell State Backward is : ", c2)

In [8]:
lstm2()

Output State is :  [[[ 0.01372506 -0.03622774  0.10887402 -0.1403001   0.07373115
    0.1629204 ]
  [-0.16764866  0.03792358 -0.0562262  -0.10975386 -0.06439384
   -0.12988992]
  [-0.04136924  0.1373438  -0.10176217 -0.04157387 -0.0303855
   -0.06267925]
  [-0.07859652  0.09187915 -0.08307807 -0.10378414  0.01536882
    0.00533444]
  [-0.18750283  0.15367079 -0.23146608 -0.05611564  0.02278277
    0.03266577]
  [-0.11024841  0.20375456 -0.13425054  0.06845403  0.08618236
    0.22069016]
  [ 0.04263048  0.15216765  0.01079943  0.2008963   0.10140317
    0.3885452 ]
  [ 0.12698233  0.13972294  0.0417225   0.16115035  0.02694408
    0.18056981]]]
Hidden State Forward is :  [[0.12698233 0.13972294 0.0417225 ]]
Cell State Forward is :  [[0.42005524 0.20415619 0.1350259 ]]
Hidden State Backward is :  [[-0.1403001   0.07373115  0.1629204 ]]
Cell State Backward is :  [[-0.19727977  0.2600824   0.25173226]]


## With Return States

In [9]:
def gru1():
    input_ = Input(shape=(T,D))
    rnn = Bidirectional(GRU(M, return_state = True))
    x = rnn(input_)
    model = Model(inputs=input_, outputs=x)
    o,h1, h2 = model.predict(X)
    print("Output State is : ", o)
    print("Hidden State Forward is : ", h1)
    print("Hidden State Backward is : ", h2)

In [10]:
gru1()

Output State is :  [[-0.20102169 -0.49340385 -0.2412385  -0.02453041  0.1660212  -0.14241311]]
Hidden State Forward is :  [[-0.20102169 -0.49340385 -0.2412385 ]]
Hidden State Backward is :  [[-0.02453041  0.1660212  -0.14241311]]


## With Return States and Sequences

In [11]:
def gru2():
    input_ = Input(shape=(T,D))
    rnn = Bidirectional(GRU(M, return_state = True, return_sequences = True))
    x = rnn(input_)
    model = Model(inputs=input_, outputs=x)
    o,h1, h2 = model.predict(X)
    print("Output State is : ", o)
    print("Hidden State Forward is : ", h1)
    print("Hidden State Backward is : ", h2)

In [12]:
gru2()

Output State is :  [[[ 0.36779603 -0.05892747  0.37499514  0.00782293 -0.3093549
   -0.13124232]
  [ 0.3498089   0.3648405   0.49861535  0.01299037 -0.7440536
   -0.08855169]
  [ 0.0563151   0.22221264 -0.06354837  0.02157985 -0.12533581
   -0.10200525]
  [ 0.0804134   0.18652844  0.13313028 -0.00210095 -0.38661045
   -0.09765919]
  [ 0.05825904  0.4557331   0.24822682 -0.01069621 -0.650729
   -0.0178827 ]
  [-0.00227847  0.3034046   0.25155288 -0.00326257  0.01776735
    0.04827017]
  [ 0.07764572  0.13250211  0.24791597  0.01268075  0.31943598
    0.03315102]
  [-0.22444445  0.00141366 -0.47074276  0.02759653  0.17409304
    0.00352484]]]
Hidden State Forward is :  [[-0.22444445  0.00141366 -0.47074276]]
Hidden State Backward is :  [[ 0.00782293 -0.3093549  -0.13124232]]
