<a href="https://colab.research.google.com/github/pndang/Everything-PyTorch/blob/master/Understanding_RNN_Shapes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# N = number of samples
# T = sequence length
# D = number of input features
# M = number of hidden units
# K = number of output units

In [3]:
# Make some data
N = 1
T = 10
D = 3
M = 5
K = 2
X = np.random.randn(N, T, D)

In [4]:
X

array([[[ 0.53520312,  0.20970803, -0.50321226],
        [-1.09008526, -0.00332943, -1.04472737],
        [-2.31402019, -0.82112938,  0.7578737 ],
        [-0.50152503, -0.04526435, -0.10851617],
        [ 0.10997366, -1.15137782,  0.12738908],
        [-0.14816413, -1.01722317, -1.32630286],
        [-0.20901049,  0.11876847,  0.95649862],
        [ 0.74308443, -0.00485867, -0.2736849 ],
        [ 0.36260891,  1.14192522, -0.40897029],
        [ 0.43905177,  0.65212037,  0.31803129]]])

In [11]:
## Define simple RNN

class SimpleRNN(nn.Module):
  def __init__(self, n_inputs, n_hidden, n_outputs):
    super(SimpleRNN, self).__init__()
    self.D = n_inputs
    self.M = n_hidden
    self.K = n_outputs

    # Note: batch_first=True
    # applies the convention that our data will be of shape:
    # (num_samples, sequence_lenght, num_features)
    # rather than:
    # (sequence_length, num_samples, num_features)
    self.rnn = nn.RNN(
        input_size=self.D,
        hidden_size=self.M,
        nonlinearity='tanh',
        batch_first=True
    )

    self.fc = nn.Linear(self.M, self.K)

  def forward(self, X):
    # Initial hidden state
    h0 = torch.zeros(1, X.size(0), self.M)

    # Get RNN unit output
    # out is of size (N, T, M)
    # 2nd return val is hidden state at each hidden layer (no need now)
    # 1st return val is hidden state at each time instance
    out, _ = self.rnn(X, h0)

    out = self.fc(out)

    return out  # shape: N x T x K

In [12]:
# Instantiate model

model = SimpleRNN(n_inputs=D, n_hidden=M, n_outputs=K)

In [14]:
# Get output

inputs = torch.from_numpy(X.astype(np.float32))
out = model(inputs)
out

tensor([[[ 0.1429, -0.0677],
         [ 0.1511, -0.0783],
         [ 0.0952,  0.0529],
         [-0.0661,  0.2883],
         [ 0.0557, -0.1118],
         [ 0.0827, -0.2274],
         [-0.0364,  0.2808],
         [ 0.1148, -0.0902],
         [ 0.1539,  0.1009],
         [ 0.1964,  0.0308]]], grad_fn=<ViewBackward0>)

In [15]:
out.shape

torch.Size([1, 10, 2])

In [16]:
# Save for later

Yhats_torch = out.detach().numpy()

In [19]:
W_xh, W_hh, b_xh, b_hh = model.rnn.parameters()

In [20]:
W_xh.shape

torch.Size([5, 3])

In [21]:
W_xh

Parameter containing:
tensor([[ 0.4272,  0.3821,  0.0461],
        [ 0.1543, -0.4468, -0.0435],
        [-0.2653, -0.1527,  0.1049],
        [ 0.1487,  0.0510,  0.1613],
        [ 0.1439,  0.3714,  0.3141]], requires_grad=True)

In [22]:
W_xh = W_xh.data.numpy()
W_xh

array([[ 0.42720553,  0.38207307,  0.04611127],
       [ 0.15429308, -0.44676226, -0.04346886],
       [-0.26530138, -0.15267031,  0.10485897],
       [ 0.14873274,  0.05099438,  0.16125834],
       [ 0.1439435 ,  0.3714183 ,  0.3140993 ]], dtype=float32)

In [23]:
b_xh = b_xh.data.numpy()
W_hh = W_hh.data.numpy()
b_hh = b_hh.data.numpy()

In [24]:
W_xh.shape, b_xh.shape, W_hh.shape, b_hh.shape

((5, 3), (5,), (5, 5), (5,))

In [25]:
# Now get the FC layer weights
Wo, bo = model.fc.parameters()

In [26]:
Wo = Wo.data.numpy()
bo = bo.data.numpy()
Wo.shape, bo.shape

((2, 5), (2,))

In [27]:
# See if we can replicate the output
h_last = np.zeros(M) # initial hidden state
x = X[0] # the one and only sample
Yhats = np.zeros((T, K)) # where we store the outputs

for t in range(T):
  h = np.tanh(x[t].dot(W_xh.T) + b_xh + h_last.dot(W_hh.T) + b_hh)
  y = h.dot(Wo.T) + bo # we only care about this value on the last iteration
  Yhats[t] = y

  # important: assign h to h_last
  h_last = h

# print the final output
print(Yhats)

[[ 0.14285763 -0.06773254]
 [ 0.15106135 -0.07828531]
 [ 0.09524156  0.05289371]
 [-0.06609863  0.28833235]
 [ 0.0557301  -0.11175401]
 [ 0.08265598 -0.22739134]
 [-0.03637174  0.28077986]
 [ 0.11481806 -0.09021091]
 [ 0.15394555  0.10088641]
 [ 0.19636961  0.03084321]]


In [28]:
# Check
np.allclose(Yhats, Yhats_torch)

True

In [30]:
# Bonus: calculate output for multiple samples at once (N > 1)