Many thanks to:
    
    https://medium.com/analytics-vidhya/understanding-rnn-implementation-in-pytorch-eefdfdb4afdb
    
for a very good introduction to RNNs


In [None]:
import torch

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence
from torch import nn

from torch import Tensor, dot, matmul

import torch.nn.functional as F

## **Basic Example**

In [None]:
seq = torch.FloatTensor([[3, 4, 5]])  

In [None]:
# Define a basic RNN layer: one input feature, one hidden unit, one layer, no bias terms
rnn = nn.RNN(input_size=1, hidden_size=1, num_layers=1, bias=False, batch_first=True)

`nn.RNN` expects its input in a particular layout. Setting `batch_first=True` makes that layout (batch size, sequence length, number of input features); the default layout is (sequence length, batch size, number of input features).

In [None]:
# Add a trailing feature dimension: (1, 3) -> (1, 3, 1)
seq = seq.unsqueeze(2)
print(seq.shape)

print(seq)

With the input in the correct format, we can now pass it to the RNN layer. The layer returns two outputs:


1.   The hidden states at every time step, for all sequences in the batch, with shape (batch size, sequence length, hidden size) when `batch_first=True`
2.   Only the final hidden state of each sequence, for all sequences in the batch, with shape (number of layers, batch size, hidden size)
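
The relationship between the two outputs can be checked directly. The sketch below is a minimal illustration, not part of the original notebook: `demo_rnn` and `demo_seq` are hypothetical names, and the seed and the second sequence are arbitrary choices. For a single-layer, unidirectional RNN, the last time step of the first output should match the second output.

```python
import torch
from torch import nn

torch.manual_seed(0)  # arbitrary seed, only so repeated runs give the same weights

# Hypothetical toy layer and batch: 2 sequences, 3 time steps, 1 feature each
demo_rnn = nn.RNN(input_size=1, hidden_size=1, num_layers=1, bias=False, batch_first=True)
demo_seq = torch.tensor([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]).unsqueeze(2)  # (2, 3, 1)

demo_all, demo_last = demo_rnn(demo_seq)

print(demo_all.shape)   # (batch size, sequence length, hidden size)
print(demo_last.shape)  # (number of layers, batch size, hidden size)

# Compare the final time step of every sequence with the "last hidden state" output
same = torch.allclose(demo_all[:, -1, :], demo_last[0])
print(same)
```

Note that the second output keeps the layer dimension first even when `batch_first=True`.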



In [None]:
out_all, out_last = rnn(seq)

In [None]:
print(f"Out all shape : {out_all.shape}")

print(f"Out last shape : {out_last.shape}")


There are two ways to access the weights of the RNN layer:

1.   Accessing individual parameters by name: `weight_ih_l0`, `weight_hh_l0` and so on (the suffix is the letter `l` followed by the layer index).
2.   Calling the `state_dict()` method to get all weights at once
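
Both routes can be compared side by side. The snippet below is a small sketch on a freshly created layer (`demo_rnn` is a hypothetical name, not the `rnn` object used in this notebook); it assumes the same bias-free, single-layer configuration as above.

```python
import torch
from torch import nn

# Hypothetical layer with the same configuration as the one defined earlier
demo_rnn = nn.RNN(input_size=1, hidden_size=1, num_layers=1, bias=False, batch_first=True)

# Route 1: attribute access by parameter name
print(demo_rnn.weight_ih_l0.shape)  # input-to-hidden weights, (hidden size, input size)
print(demo_rnn.weight_hh_l0.shape)  # hidden-to-hidden weights, (hidden size, hidden size)

# Route 2: state_dict() maps the same names to the weight tensors
weights = demo_rnn.state_dict()
print(list(weights.keys()))

# Both routes should expose identical values
match = torch.equal(weights["weight_ih_l0"], demo_rnn.weight_ih_l0)
print(match)
```

With `bias=False` there are no `bias_ih_l0` or `bias_hh_l0` entries, so only the two weight matrices appear.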





In [None]:
rnn.weight_ih_l0

In [None]:
rnn.weight_hh_l0

In [None]:
rnn.state_dict()

### **Computing the output**

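An RNN layer takes in a sequence and computes an output for each time step of that sequence. The same weights are used at every time step.

The basic equation governing the computation is:
$h_t = \text{tanh}(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh})$

where 
$h_{t}$ represents the hidden state at time $t$, $x_t$ is the input at time $t$, $W_{ih}$ and $W_{hh}$ are the input-to-hidden and hidden-to-hidden weights, and $b_{ih}$ and $b_{hh}$ are the corresponding biases. Since the layer above was created with `bias=False`, both bias terms are absent here.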








In [None]:
# Output states computed by the RNN layer
out_all, out_last

#### Hidden State 1

Note: since this is the very first state (time = 1) and there is no hidden state preceding it, we assume the initial hidden state to be zero. Therefore, $h_{0}$ is taken to be 0.

In [None]:
seq[0][0]    # The first input feature of the first sequence

In [None]:
wih = rnn.weight_ih_l0
whh = rnn.weight_hh_l0

x = seq[0][0] 

# Computing the hidden state for time = 1 (h_0 = 0, so the recurrent term vanishes)
h1 = torch.tanh(x * wih + whh * 0)
h1

#### Hidden State 2

In [None]:
x = seq[0][1] # The second input feature of the first sequence

h2 = torch.tanh(x * wih + whh * h1)
h2

#### Hidden State 3

In [None]:
x = seq[0][2] # The third and last input feature of the first sequence

h3 = torch.tanh(x * wih + whh * h2)
h3
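
The three hand-computed steps can be rolled into a loop and checked against the layer itself. The following is a minimal sketch under stated assumptions: `manual_rnn`, `demo_rnn` and `demo_seq` are hypothetical names introduced for illustration, and the helper only covers a single-layer, unidirectional, bias-free `nn.RNN` with `batch_first=True`.

```python
import torch
from torch import nn

torch.manual_seed(0)  # arbitrary seed, only so repeated runs give the same weights

demo_rnn = nn.RNN(input_size=1, hidden_size=1, num_layers=1, bias=False, batch_first=True)
demo_seq = torch.tensor([[3.0, 4.0, 5.0]]).unsqueeze(2)  # (1, 3, 1)


def manual_rnn(layer, batch):
    """Hypothetical helper: unroll a single-layer, bias-free nn.RNN by hand."""
    w_ih, w_hh = layer.weight_ih_l0, layer.weight_hh_l0
    h = torch.zeros(batch.shape[0], layer.hidden_size)  # h_0 = 0
    states = []
    for t in range(batch.shape[1]):
        x_t = batch[:, t, :]                       # (batch size, input size)
        h = torch.tanh(x_t @ w_ih.T + h @ w_hh.T)  # (batch size, hidden size)
        states.append(h)
    # Stack along the time axis; add a leading layer axis to the final state
    return torch.stack(states, dim=1), h.unsqueeze(0)


with torch.no_grad():
    ref_all, ref_last = demo_rnn(demo_seq)
    my_all, my_last = manual_rnn(demo_rnn, demo_seq)

print(torch.allclose(ref_all, my_all, atol=1e-6))
print(torch.allclose(ref_last, my_last, atol=1e-6))
```

Writing the update as matrix products (`x_t @ w_ih.T`) rather than elementwise products keeps the same loop valid when the input or hidden size is larger than 1.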