### Imports

In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

In [2]:
# remember
# N = number of samples
# T = sequence length
# D = number of input features
# M = number of input units
# K = number of output units

### Functions

In [38]:
class SimpleRNN(nn.Module):
    def __init__(self,n_inputs, n_hidden, n_rnnlayers, n_outputs):
        super(SimpleRNN,self).__init__()
        self.D = n_inputs
        self.M = n_hidden
        self.K = n_outputs
        self.L = n_rnnlayers
        
        self.rnn = nn.RNN(
            input_size = self.D,
            hidden_size = self.M,
            num_layers = self.L,
            nonlinearity = 'tanh',
            batch_first=True
        )
        self.fc = nn.Linear(self.M, self.K)
        
    def forward(self, X):
        h0 = torch.zeros(self.L, X.size(0), self.M)
        
        out,_ = self.rnn(X, h0)
        
#         out = self.fc(out[:, -1, :])
        out = self.fc(out)
        return out

### Create data

In [39]:
N=1
T=10
D=3
M=5
K=2
L=1

X = np.random.randn(N,T,D)

In [40]:
X.shape

(1, 10, 3)

### Create and train model

In [41]:
model = SimpleRNN(n_inputs=D, n_hidden=M, n_rnnlayers=L, n_outputs=K)
    
inputs = torch.from_numpy(X.astype(np.float32))
out = model(inputs)

In [42]:
out.shape

torch.Size([1, 10, 2])

In [43]:
# out

In [44]:
Yhats_torch = out.detach().numpy()

In [45]:
Yhats_torch

array([[[ 0.37914371,  0.45375326],
        [ 0.24274841, -0.24644494],
        [ 0.2960476 ,  0.0516105 ],
        [ 0.48041022,  0.34107348],
        [ 0.3452308 ,  0.34777388],
        [ 0.2709691 , -0.17157316],
        [ 0.39632678,  0.18019316],
        [ 0.21948442, -0.12511605],
        [ 0.18743178, -0.35153332],
        [ 0.26475424,  0.21748504]]], dtype=float32)

In [46]:
W_xh, W_hh, b_xh, b_hh = model.rnn.parameters()

In [47]:
Wo, bo = model.fc.parameters()

In [48]:
W_xh.shape

torch.Size([5, 3])

In [49]:
W_xh

Parameter containing:
tensor([[-0.3127, -0.1878, -0.3873],
        [-0.1836, -0.3284,  0.3695],
        [ 0.3084,  0.0173,  0.1247],
        [-0.1339, -0.3704, -0.3500],
        [ 0.2168, -0.1162, -0.1928]], requires_grad=True)

In [50]:
W_xh = W_xh.data.numpy()
W_hh = W_hh.data.numpy()
b_xh = b_xh.data.numpy()
b_hh = b_hh.data.numpy()

Wo = Wo.data.numpy()
bo = bo.data.numpy()


In [51]:
W_xh.shape, W_hh.shape, b_xh.shape, b_hh.shape, Wo.shape, bo.shape

((5, 3), (5, 5), (5,), (5,), (2, 5), (2,))

In [52]:
h_last = np.zeros(M)

x = X[0]
Yhats = np.zeros((T,K))

for t in range(T):
    h = np.tanh(x[t].dot(W_xh.T)+ b_xh + h_last.dot(W_hh.T)+ b_hh)
    y = h.dot(Wo.T) + bo
    Yhats[t] = y
    
    h_last = h
print(Yhats)

[[ 0.37914368  0.45375327]
 [ 0.24274842 -0.24644494]
 [ 0.29604762  0.05161049]
 [ 0.48041021  0.3410735 ]
 [ 0.3452308   0.34777389]
 [ 0.2709691  -0.17157314]
 [ 0.39632677  0.18019319]
 [ 0.21948441 -0.12511606]
 [ 0.18743178 -0.35153334]
 [ 0.26475423  0.21748502]]


In [53]:
np.allclose(Yhats,Yhats_torch)

True