### Imports

In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

In [2]:
# remember
# N = number of samples
# T = sequence length
# D = number of input features
# M = number of hidden units
# K = number of output units

### Functions

In [3]:
class SimpleRNN(nn.Module):
    """Tanh RNN followed by a linear read-out applied at every time step.

    Args:
        n_inputs: number of input features per time step (D).
        n_hidden: size of the hidden state (M).
        n_rnnlayers: number of stacked RNN layers (L).
        n_outputs: number of output units per time step (K).
    """

    def __init__(self,n_inputs, n_hidden, n_rnnlayers, n_outputs):
        super().__init__()
        self.D = n_inputs
        self.M = n_hidden
        self.K = n_outputs
        self.L = n_rnnlayers

        self.rnn = nn.RNN(
            input_size = self.D,
            hidden_size = self.M,
            num_layers = self.L,
            nonlinearity = 'tanh',
            batch_first=True
        )
        self.fc = nn.Linear(self.M, self.K)

    def forward(self, X):
        """Run the RNN over the whole sequence and project every hidden state.

        Args:
            X: batched input tensor of shape (N, T, D) (batch-first layout).

        Returns:
            Tensor of shape (N, T, K): one prediction per time step.
        """
        # Allocate the initial hidden state from X so it always shares X's
        # device and dtype; a bare torch.zeros(...) would break as soon as the
        # module is moved to another device or cast with .double()/.half().
        h0 = X.new_zeros(self.L, X.size(0), self.M)

        out, _ = self.rnn(X, h0)

        # The read-out is applied to all T hidden states. To predict from the
        # final time step only, project out[:, -1, :] instead.
        return self.fc(out)

### Create data

In [21]:
# Problem dimensions.
N = 1   # number of samples
T = 10  # sequence length
D = 7   # number of input features
M = 30  # number of hidden units
K = 7   # number of output units
L = 1   # number of stacked RNN layers

# Draw the inputs from a seeded generator so the notebook reproduces the same
# data after Restart Kernel -> Run All.
SEED = 0
rng = np.random.default_rng(SEED)
X = rng.standard_normal((N, T, D))

In [22]:
# Sanity check: X was created with shape (N, T, D).
X.shape

(1, 10, 7)

### Create model and run a forward pass

In [23]:
# Seed PyTorch before building the model so the randomly initialised weights,
# and therefore every value printed below, are reproducible on a fresh kernel.
torch.manual_seed(0)
model = SimpleRNN(n_inputs=D, n_hidden=M, n_rnnlayers=L, n_outputs=K)

# X is a float64 NumPy array; cast to float32 to match the model's parameters.
inputs = torch.from_numpy(X.astype(np.float32))
out = model(inputs)

In [24]:
# The read-out is applied at every time step, so the expected shape is (N, T, K).
out.shape

torch.Size([1, 10, 7])

In [25]:
# out

In [26]:
# Detach from the autograd graph before converting to NumPy; this array is the
# PyTorch reference that the manual forward pass is compared against.
Yhats_torch = out.detach().numpy()

In [27]:
# Display the PyTorch predictions (one row of K outputs per time step).
Yhats_torch

array([[[-0.23914573, -0.00561276,  0.27054527,  0.6078329 ,
          0.255951  , -0.02782328,  0.39205444],
        [ 0.05487784,  0.431114  ,  0.34497124,  0.3886513 ,
          0.16479775, -0.24054877, -0.01547538],
        [-0.02050099,  0.30601448,  0.34104246,  0.30659038,
          0.29468   ,  0.13449825,  0.09009729],
        [ 0.1727648 ,  0.30750796,  0.02667528,  0.00740698,
          0.04949115,  0.32272577, -0.17799291],
        [-0.4581695 ,  0.16674955,  0.35903767,  0.38640803,
          0.4121445 ,  0.3173287 ,  0.23487648],
        [-0.24155822,  0.2632517 ,  0.2728587 ,  0.2666041 ,
          0.27923316,  0.37510353,  0.14208375],
        [-0.05837795,  0.1799213 ,  0.3014036 , -0.12022686,
          0.21474952,  0.23753487,  0.2382418 ],
        [ 0.31464928, -0.36120874,  0.04347937, -0.20549104,
         -0.19457723,  0.07116429,  0.48618156],
        [ 0.18523552,  0.04194547,  0.11283469, -0.08910662,
         -0.12034521, -0.22207288,  0.31435058],
        [ 

In [28]:
# Read the first-layer RNN parameters by name instead of unpacking
# model.rnn.parameters(): positional unpacking relies on registration order
# and raises ValueError as soon as the RNN has more than one layer.
W_xh = model.rnn.weight_ih_l0  # input-to-hidden weights
W_hh = model.rnn.weight_hh_l0  # hidden-to-hidden weights
b_xh = model.rnn.bias_ih_l0    # input-to-hidden bias
b_hh = model.rnn.bias_hh_l0    # hidden-to-hidden bias

In [29]:
# Parameters of the linear read-out layer.
Wo = model.fc.weight
bo = model.fc.bias

In [30]:
# Input-to-hidden weights: one row per hidden unit, one column per input feature.
W_xh.shape

torch.Size([30, 7])

In [31]:
# Inspect the (still randomly initialised) input-to-hidden weights.
W_xh

Parameter containing:
tensor([[-0.1693, -0.0584,  0.1315,  0.1492,  0.1774,  0.0179, -0.0825],
        [ 0.0838,  0.1285, -0.0256, -0.1429, -0.1723, -0.0243,  0.1590],
        [ 0.0921,  0.1646, -0.1057,  0.1039, -0.0035,  0.1150, -0.0543],
        [-0.1292, -0.0429, -0.0068,  0.0958, -0.1379,  0.1303, -0.0438],
        [-0.1154, -0.1209,  0.0018, -0.0026, -0.1512, -0.0026, -0.0537],
        [-0.1702,  0.1213,  0.1310,  0.1497,  0.0071, -0.1064, -0.1695],
        [-0.1366, -0.1477,  0.1425, -0.0587,  0.1581,  0.1487, -0.0217],
        [-0.1035, -0.0943,  0.1173, -0.1525,  0.0705,  0.0579,  0.1400],
        [ 0.1338, -0.0316, -0.1282,  0.1580, -0.0074,  0.1405, -0.1400],
        [-0.1004,  0.0520,  0.0720,  0.0084,  0.0132,  0.0737,  0.1262],
        [-0.0939,  0.0527,  0.0434, -0.1265,  0.1596,  0.0971,  0.1022],
        [ 0.1368, -0.1394, -0.0703, -0.0129,  0.1570,  0.0897, -0.0517],
        [-0.1013, -0.1281,  0.0026,  0.1707,  0.0760,  0.1045, -0.1293],
        [ 0.0103, -0.1284, -0

In [32]:
# Convert the model parameters to NumPy arrays for the manual forward pass.
# They are read straight from the model rather than by rebinding the tensor
# variables in place, so this cell is idempotent: re-running it no longer fails
# on names that already hold ndarrays. `.detach()` replaces the legacy `.data`
# attribute.
W_xh = model.rnn.weight_ih_l0.detach().numpy()
W_hh = model.rnn.weight_hh_l0.detach().numpy()
b_xh = model.rnn.bias_ih_l0.detach().numpy()
b_hh = model.rnn.bias_hh_l0.detach().numpy()

Wo = model.fc.weight.detach().numpy()
bo = model.fc.bias.detach().numpy()


In [33]:
# Shapes of every array used by the manual forward pass:
# W_xh (M, D), W_hh (M, M), b_xh (M,), b_hh (M,), Wo (K, M), bo (K,).
W_xh.shape, W_hh.shape, b_xh.shape, b_hh.shape, Wo.shape, bo.shape

((30, 7), (30, 30), (30,), (30,), (7, 30), (7,))

In [34]:
# Manual NumPy re-implementation of the single-layer tanh RNN forward pass,
# evaluated on the first sample of X.
x = X[0]
Yhats = np.zeros((T, K))

# The recurrence starts from an all-zero hidden state, like h0 in
# SimpleRNN.forward.
h_last = np.zeros(M)

for t, x_t in enumerate(x):
    # h_t = tanh(x_t W_xh^T + b_xh + h_{t-1} W_hh^T + b_hh)
    pre_activation = x_t.dot(W_xh.T) + b_xh + h_last.dot(W_hh.T) + b_hh
    h = np.tanh(pre_activation)

    # Linear read-out of the current hidden state.
    y = h.dot(Wo.T) + bo
    Yhats[t] = y

    h_last = h
print(Yhats)

[[-0.23914574 -0.00561275  0.27054528  0.60783289  0.25595098 -0.02782329
   0.39205441]
 [ 0.05487785  0.43111399  0.34497122  0.3886513   0.16479776 -0.24054875
  -0.01547537]
 [-0.020501    0.30601449  0.34104243  0.30659039  0.29468006  0.13449826
   0.09009729]
 [ 0.17276479  0.30750795  0.0266753   0.007407    0.04949116  0.32272579
  -0.17799288]
 [-0.45816947  0.16674956  0.3590377   0.38640803  0.41214451  0.31732871
   0.23487646]
 [-0.24155823  0.26325167  0.27285873  0.26660407  0.27923314  0.37510352
   0.14208372]
 [-0.05837795  0.1799213   0.30140358 -0.12022682  0.2147495   0.23753485
   0.23824182]
 [ 0.31464929 -0.36120873  0.0434794  -0.20549106 -0.19457724  0.07116428
   0.48618155]
 [ 0.18523554  0.04194546  0.11283467 -0.08910665 -0.1203452  -0.22207288
   0.31435056]
 [ 0.20495101  0.00465598  0.02589001  0.01049995  0.09764157  0.14536115
   0.09077454]]


In [35]:
# Compare the manual NumPy result with the PyTorch output. Yhats is (T, K) and
# Yhats_torch is (1, T, K), so the comparison broadcasts over the batch axis.
np.allclose(Yhats,Yhats_torch)

True