In [1]:
import torch
import torch.nn as nn
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
sns.set_style('whitegrid')

In [2]:
# N = number of samples
# T = sequence lenght
# D = number of input features
# M = number of hidden units
# K = number of output units

In [3]:
# Make some data
N = 1
T = 10
D = 3
M = 5
K = 2
X = np.random.randn(N, T, D)

In [10]:
# Make a RNN
# DEFINE SimpleRNN
class SimpleRNN(nn.Module):
    def __init__(self, n_inputs, n_hidden, n_outputs):
        super(SimpleRNN, self).__init__()
        self.D = n_inputs
        self.M = n_hidden
        self.K = n_outputs
        
        # note: batch_first=True
        # applies the convention that our data will be of shape: (num_samples, sequence_length, num_features)
        # rather than: (sequence_length, num_samples, num_features)
        self.rnn = nn.RNN(
            input_size = self.D,
            hidden_size = self.M, 
            nonlinearity = 'tanh',
            batch_first = True)
        self.fc = nn.Linear(self.M, self.K)
            
    def forward(self, X):
        # initial hidden states
        h0 = torch.zeros(1, X.size(0), self.M)
        
        # get RNN unit output
        # out is of size (N, T, M)
        # 2nd return value is hidden states at each hidden layer, we don't need those now
        out, _ = self.rnn(X, h0)
        
        # we only want h(T) at the final time step
        # N x M => N x K
        out = self.fc(out)
        return out

In [11]:
# Instantiate the model
model = SimpleRNN(n_inputs=D, n_hidden=M,  n_outputs=K)

In [12]:
inputs = torch.from_numpy(X.astype(np.float32))
out = model(inputs)
out

tensor([[[-0.6770,  0.1178],
         [-0.5568,  0.1199],
         [-0.7185,  0.1907],
         [-0.3239, -0.0437],
         [-0.5515,  0.3637],
         [-0.7874,  0.1112],
         [-0.5464,  0.0395],
         [-0.4709,  0.0393],
         [-0.5123,  0.3385],
         [-0.6788,  0.2819]]], grad_fn=<AddBackward0>)

In [13]:
out.shape

torch.Size([1, 10, 2])

In [14]:
# save for later
Yhats_batch = out.detach().numpy()

In [41]:
W_xh, W_hh, b_xh, b_hh = model.rnn.parameters()

In [42]:
W_xh.shape

torch.Size([5, 3])

In [43]:
W_xh

Parameter containing:
tensor([[-0.2837,  0.0281, -0.0205],
        [ 0.1065,  0.2600, -0.2244],
        [-0.4189,  0.4109, -0.3886],
        [ 0.1018, -0.3114,  0.0292],
        [-0.2252,  0.3979, -0.4233]], requires_grad=True)

In [44]:
W_xh = W_xh.detach().numpy()
W_hh = W_hh.detach().numpy()
b_xh = b_xh.detach().numpy()
b_hh = b_hh.detach().numpy()

In [45]:
Wo, bo = model.fc.parameters()

In [46]:
Wo = Wo.data.numpy()
bo = bo.data.numpy()

In [47]:
Wo.shape, bo.shape

((2, 5), (2,))

In [48]:
h_last = np.zeros(M)
x = X[0]
Yhats = np.zeros((T, K))

In [49]:
for t in range(T):
    h = np.tanh(x[t].dot(W_xh.T) + b_xh + h_last.dot(W_hh.T) + b_hh)
    y = h.dot(Wo.T)
    Yhats[t] = y
    
    h_last = h
    
print(Yhats)

[[-0.25046709  0.11572695]
 [-0.13028103  0.11788034]
 [-0.29200551  0.18869281]
 [ 0.10264479 -0.04577346]
 [-0.12499333  0.36167691]
 [-0.36083769  0.10918696]
 [-0.11992737  0.037405  ]
 [-0.04438106  0.03722227]
 [-0.0858161   0.33643836]
 [-0.25231524  0.27986977]]


In [50]:
# Check
np.allclose(Yhats, Yhats_batch)

False