# RNN Implementation testing

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [3]:
class BasicRNNCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(BasicRNNCell, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        # weight input -> hidden layer
        self.W_ih = nn.Parameter(torch.randn(hidden_size, input_size)) #hidden_size X input_size
        self.b_ih = nn.Parameter(torch.zeros(hidden_size)) # hidden_size 

        self.W_hh = nn.Parameter(torch.randn(hidden_size, hidden_size)) # hidden_size X hidden_size 
        self.b_hh = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x, h_prev): # X -> (num_batch X input_size), h_prev -> hidden_size
        hx = torch.matmul(x, self.W_ih.T) + self.b_ih # dim -> 
        hh = torch.matmul(h_prev, self.W_hh.T) + self.b_hh
        opt_add = hx + hh
        h_new = torch.tanh(opt_add)
        return h_new

# Simple RNN

In [36]:
class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        
        self.rnn_cell = BasicRNNCell(input_size, hidden_size)

        # fcc
        self.W_o = nn.Parameter(torch.randn(self.hidden_size, self.output_size))
        self.b_o = nn.Parameter(torch.zeros(self.output_size))
        
    
    def forward(self, input_seq): #input x -> (batch_size, seq_length, input_length)
        outputs = [] #all intermediate outputs

        self.batch_size = input_seq.size(0)
        self.seq_length = input_seq.size(1)

        # initialize hidden state for the first time step
        h_prev = torch.zeros(self.batch_size, self.hidden_size).to(input_seq.device)

        for t in range(0, self.seq_length):
            x = input_seq[:, t, :]
            h_prev = self.rnn_cell(x, h_prev)
            y = h_prev @ self.W_o + self.b_o #batch X 1
            outputs.append(y.unsqueeze(1))

        return torch.cat(outputs, axis=1), y

## Train simple RNN to predict next number of the sequence

We’ll train it on a toy task: predict the next number in a sequence.
Example:

Input sequence: [0, 1, 2, 3]

Target sequence: [1, 2, 3, 4]

In [37]:
seq_len = 5
batch_size = 16
input_size = 1
hidden_size = 16
epochs = 2000

In [38]:
# seq = torch.arange(0, seq_len + 1).float().unsqueeze(0).repeat(batch_size, 1)  # (batch, seq_len+1)
# x = seq[:, :-1].unsqueeze(-1) # input: (batch, seq_len, 1)
# y = seq[:, 1:].unsqueeze(-1)   # target: (batch, seq_len, 1)
# x.shape

In [39]:
torch.arange(5, dtype=torch.float).view(1, 5, 1)

tensor([[[0.],
         [1.],
         [2.],
         [3.],
         [4.]]])

In [40]:
# Pick random starting numbers for each batch (e.g. 0–99)
starts = torch.randint(0, 10, (batch_size, 1, 1), dtype=torch.float)

# Create offset sequence [0, 1, 2, ..., seq_len-1]
offsets = torch.arange(5, dtype=torch.float).view(1, 5, 1)
# Add start + offsets to form sequences

x = starts + offsets
y = x + 1

In [41]:
x.shape

torch.Size([16, 5, 1])

In [42]:
y.shape

torch.Size([16, 5, 1])

In [43]:
print(x[0])

tensor([[5.],
        [6.],
        [7.],
        [8.],
        [9.]])


In [44]:
print(y[0])

tensor([[ 6.],
        [ 7.],
        [ 8.],
        [ 9.],
        [10.]])


In [45]:
for i, j in zip(x[:2].flatten(), y[:2].flatten()):
    print(f" {i} -> {j}")

 5.0 -> 6.0
 6.0 -> 7.0
 7.0 -> 8.0
 8.0 -> 9.0
 9.0 -> 10.0
 8.0 -> 9.0
 9.0 -> 10.0
 10.0 -> 11.0
 11.0 -> 12.0
 12.0 -> 13.0


In [46]:
model = SimpleRNN(input_size, hidden_size, input_size)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

In [47]:
for epoch in range(epochs):
    # forward pass
    seq_out, _ = model(x)
    
    # calculate loss
    loss = loss_fn(seq_out, y)
    
    if epoch % 100 == 0:
        print(f"epoch: {epoch} - loss: {loss.item()}")
    
    # back propagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    

epoch: 0 - loss: 44.27752685546875
epoch: 100 - loss: 0.6626996397972107
epoch: 200 - loss: 0.11533312499523163
epoch: 300 - loss: 0.04268794506788254
epoch: 400 - loss: 0.01995529979467392
epoch: 500 - loss: 0.009928187355399132
epoch: 600 - loss: 0.005761100444942713
epoch: 700 - loss: 0.003685793373733759
epoch: 800 - loss: 0.002435525180771947
epoch: 900 - loss: 0.0016005728393793106
epoch: 1000 - loss: 0.0010358832078054547
epoch: 1100 - loss: 0.0006946452776901424
epoch: 1200 - loss: 0.0004893009318038821
epoch: 1300 - loss: 0.00037968310061842203
epoch: 1400 - loss: 0.0003187202091794461
epoch: 1500 - loss: 0.0002757919719442725
epoch: 1600 - loss: 0.0006003559101372957
epoch: 1700 - loss: 0.00022667348093818873
epoch: 1800 - loss: 0.0002157936105504632
epoch: 1900 - loss: 0.004146499093621969


In [48]:
test_x = torch.tensor([[[0.], [1], [2], [3], [4]]])
test_x.shape

torch.Size([1, 5, 1])

In [49]:
y_pred, _ = model(test_x)

In [50]:
test_x, y_pred

(tensor([[[0.],
          [1.],
          [2.],
          [3.],
          [4.]]]),
 tensor([[[0.9955],
          [2.0001],
          [3.0002],
          [4.0002],
          [5.0001]]], grad_fn=<CatBackward0>))