## Brief theory of LSTM (Long Short-Term Memory), GRU (Gated Recurrent Unit) and simple RNNs, with from-scratch implementations in PyTorch

### LSTM

LSTMs are a special kind of RNN capable of learning long-term dependencies. They were introduced by Hochreiter & Schmidhuber (1997) and were refined and popularized by many people in later work. Simple RNNs suffer from vanishing or exploding gradients on long sequences: backpropagation through time multiplies the gradient by the recurrent weights (and the activation derivatives) once per time step, so it shrinks or grows exponentially with the sequence length. LSTMs are explicitly designed to avoid this long-term dependency problem; remembering information for long periods of time is practically their default behavior.
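The exponential effect can be seen on a toy example. The sketch below is a deliberate simplification (a scalar, linear recurrence $h_t = w \cdot h_{t-1}$ with no tanh and no input; the helper name `grad_through_time` is hypothetical): backpropagating through $T$ steps multiplies the gradient by $w$ once per step, so $\partial h_T / \partial h_0 = w^T$.

```python
import torch

def grad_through_time(w: float, steps: int) -> float:
    """Return d h_T / d h_0 for the toy linear recurrence h_t = w * h_{t-1}."""
    h0 = torch.tensor(1.0, requires_grad=True)
    h = h0
    for _ in range(steps):
        h = w * h  # one "time step": the same weight is applied again
    h.backward()  # backpropagation through time
    return h0.grad.item()

vanishing = grad_through_time(0.5, 50)  # 0.5 ** 50, about 8.9e-16
exploding = grad_through_time(1.5, 50)  # 1.5 ** 50, about 6.4e8
print(vanishing, exploding)
```

With $|w| < 1$ the gradient from 50 steps back is essentially zero, and with $|w| > 1$ it blows up; a real RNN behaves similarly through its recurrent weight matrix.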

LSTMs function much like traditional RNNs; their gating mechanism is what sets them apart. An LSTM cell has three gates that control a cell state:

- Input gate: adds information to the cell state. It decides what new information to store in the cell state.
- Forget gate: removes information from the cell state that the model no longer needs.
- Output gate: selects which information from the cell state is exposed as the output (the new hidden state).
- Cell state: carries information across time steps; the gates read from it and write to it.

In a plain RNN cell, the input at the current time step and the hidden state from the previous time step are passed through a tanh activation to produce the new hidden state and output.

An LSTM cell, by contrast, takes in three pieces of information at each time step: the current input, the short-term memory from the previous step (similar to the hidden state of an RNN), and the long-term memory.
The short-term memory is commonly referred to as the hidden state, and the long-term memory is usually known as the cell state.

**Input gate**:
Input gate functionality is achieved using two layers. The first layer acts as a filter that selects which information may pass and which is discarded. To create this layer, we pass the short-term memory and the current input through a sigmoid function. The sigmoid maps the values to the range 0 to 1, where 0 means that part of the information is unimportant and 1 means that it will be used.


The second layer passes the short-term memory and the current input through a tanh activation, which maps the values to the range -1 to 1. These are the candidate values that could be added to the cell state; the sigmoid layer decides how much of each candidate value is let through. The output of the input gate is the element-wise product of the two layers.

Mathematically, the two layers of the input gate are:

$i_t = \sigma(W_{xi}x_t + W_{hi}h_{t-1} + b_i)$ (sigmoid filter)

$g_t = \tanh(W_{xg}x_t + W_{hg}h_{t-1} + b_g)$ (tanh candidate values)

The contribution of the input gate to the cell state is the product of the two layers, $i_t * g_t$, where $*$ denotes element-wise multiplication here and below.


**Forget gate**:
The forget gate decides what information from the cell state we are going to throw away. Like the first layer of the input gate, it is a sigmoid layer over the current input and the incoming short-term memory. Its outputs lie between 0 and 1: 0 means "discard this component of the cell state entirely", whereas 1 means "keep it entirely".

Forget gate functionality is achieved by multiplying the incoming long-term memory element-wise by this forget vector.

Mathematical representation of the forget gate:
$f_t = \sigma(W_{xf}x_t + W_{hf}h_{t-1} + b_f)$

The new long-term memory (cell state) is the incoming long-term memory scaled by the forget vector, plus the contribution of the input gate:
$c_t = f_t * c_{t-1} + i_t * g_t$
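A worked example of the cell-state update $c_t = f_t * c_{t-1} + i_t * g_t$ with two cell components. The gate values are hand-picked for illustration (a real sigmoid never outputs exactly 0 or 1): the forget gate keeps the first component and erases the second, and the input gate then writes half of the first candidate value and all of the second.

```python
import torch

c_prev = torch.tensor([2.0, 3.0])  # previous cell state c_{t-1}
f = torch.tensor([1.0, 0.0])       # forget gate: keep component 0, erase component 1
i = torch.tensor([0.5, 1.0])       # input gate: how much of each candidate to write
g = torch.tensor([1.0, -1.0])      # tanh candidate values

c = f * c_prev + i * g             # element-wise update
print(c.tolist())                  # [2.5, -1.0]
```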


**Output gate**:
The output gate is used to decide what information from the cell state we are going to output. The output gate will take the current input, the previous short-term memory, and the newly computed long-term memory to produce the new short-term memory/hidden state which will be passed on to the cell in the next time step.

Mathematical representation of the output gate:
$o_t = \sigma(W_{xo}x_t + W_{ho}h_{t-1} + b_o)$

The new short-term memory (hidden state) is the output vector multiplied element-wise by the tanh of the new long-term memory:
$h_t = o_t * \tanh(c_t)$
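The equations above can be checked numerically against PyTorch's built-in `nn.LSTMCell`, which stores the weights of the four layers stacked in the order i, f, g, o inside `weight_ih` and `weight_hh` (with two bias vectors, `bias_ih` and `bias_hh`, whose sum plays the role of each $b$). The helper `lstm_step` below is a hypothetical name for a direct transcription of the equations; this is a sketch for verification, not a replacement for the module.

```python
import torch
from torch import nn

torch.manual_seed(0)
input_size, hidden_size, batch = 3, 4, 2
ref = nn.LSTMCell(input_size, hidden_size)

def lstm_step(x, h_prev, c_prev, cell: nn.LSTMCell):
    # All four layers at once: rows of the weights are stacked as [i; f; g; o]
    gates = x @ cell.weight_ih.T + cell.bias_ih + h_prev @ cell.weight_hh.T + cell.bias_hh
    i, f, g, o = gates.chunk(4, dim=1)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)  # input, forget, output gates
    g = torch.tanh(g)                                               # candidate values
    c = f * c_prev + i * g                                          # new cell state
    h = o * torch.tanh(c)                                           # new hidden state
    return h, c

x = torch.randn(batch, input_size)
h0 = torch.randn(batch, hidden_size)
c0 = torch.randn(batch, hidden_size)

with torch.no_grad():
    h1, c1 = lstm_step(x, h0, c0, ref)
    h_ref, c_ref = ref(x, (h0, c0))
print(torch.allclose(h1, h_ref, atol=1e-5), torch.allclose(c1, c_ref, atol=1e-5))
```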



```python
from typing import Optional, Tuple

import torch
from torch import nn
from torch.nn import functional as F
```

```python
class LSTMCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int):
        """
        Args:
            input_size (int): The number of expected features in the input x.
                    For example, if the input is a one-hot vector, the input_size is the size of the vocabulary.
            hidden_size (int): The number of features in the hidden state h
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.W_ih = nn.Linear(input_size, hidden_size * 4)  # 4 for i, f, g, o
        self.W_hh = nn.Linear(hidden_size, hidden_size * 4)

    def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # x: (batch, input_size); h, c: (batch, hidden_size)
        gates = self.W_ih(x) + self.W_hh(h)
        i, f, g, o = gates.chunk(4, 1)  # split the feature dimension (dim 1) into 4 tensors
        i = torch.sigmoid(i)  # input gate
        f = torch.sigmoid(f)  # forget gate
        g = torch.tanh(g)  # candidate values ("cell gate")
        o = torch.sigmoid(o)  # output gate
        c = f * c + i * g  # keep part of the old cell state, add the gated candidate values
        h = o * torch.tanh(c)  # new hidden state
        return h, c

class LSTM(nn.Module):
    """A stack of LSTMCells that advances every layer by ONE time step.

    To process a sequence, call it once per time step and feed back the returned h and c.
    """
    def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers  # number of layers
        self.cells = nn.ModuleList([LSTMCell(input_size, hidden_size)])  # the first layer
        for _ in range(num_layers - 1):  # the rest of the layers
            self.cells.append(LSTMCell(hidden_size, hidden_size))

    def forward(self, x: torch.Tensor, h: Optional[torch.Tensor] = None, c: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # x: (batch, input_size); h, c: (num_layers, batch, hidden_size)
        if h is None:  # if the hidden state is not provided, initialize it to 0 (same device/dtype as x)
            h = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
        if c is None:  # if the cell state is not provided, initialize it to 0
            c = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
        # Collect the new states in lists instead of writing into h and c in place:
        # in-place writes modify tensors autograd has saved and make backward() fail.
        new_h, new_c = [], []
        for i, cell in enumerate(self.cells):
            h_i, c_i = cell(x, h[i], c[i])
            new_h.append(h_i)
            new_c.append(c_i)
            x = h_i  # the output of layer i is the input of layer i + 1
        return torch.stack(new_h), torch.stack(new_c), x
```
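The `LSTM` class above advances one time step per call, so a whole sequence is processed by looping over time and feeding the states back in. PyTorch's built-in `nn.LSTM` runs that loop internally. The sketch below (sizes are arbitrary) unrolls a built-in `nn.LSTMCell` manually, copies the first-layer weights of an `nn.LSTM` into it (`weight_ih_l0`, `weight_hh_l0`, `bias_ih_l0`, `bias_hh_l0`), and checks that both produce the same outputs for a sequence shaped `(seq_len, batch, input_size)`.

```python
import torch
from torch import nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size = 6, 2, 3, 4
lstm = nn.LSTM(input_size, hidden_size)  # expects (seq_len, batch, input_size) by default
cell = nn.LSTMCell(input_size, hidden_size)
with torch.no_grad():  # share weights so both modules compute the same function
    cell.weight_ih.copy_(lstm.weight_ih_l0)
    cell.weight_hh.copy_(lstm.weight_hh_l0)
    cell.bias_ih.copy_(lstm.bias_ih_l0)
    cell.bias_hh.copy_(lstm.bias_hh_l0)

x = torch.randn(seq_len, batch, input_size)
h = torch.zeros(batch, hidden_size)
c = torch.zeros(batch, hidden_size)
outputs = []
with torch.no_grad():
    for t in range(seq_len):  # unroll over time: one cell call per step
        h, c = cell(x[t], (h, c))
        outputs.append(h)
    manual_out = torch.stack(outputs)  # (seq_len, batch, hidden_size)
    ref_out, (h_n, c_n) = lstm(x)

print(torch.allclose(manual_out, ref_out, atol=1e-5))
```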

#### BiLSTM: Bidirectional LSTM

A BiLSTM runs two LSTMs over the same input sequence: one processes it in the forward direction (first to last element) and the other in the backward direction (last to first). Their outputs are concatenated at each time step, so the representation of every position can draw on context from both earlier and later positions in the sequence.
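For comparison, PyTorch's built-in `nn.LSTM` supports this directly via `bidirectional=True`. A minimal sketch with arbitrary sizes: the forward and backward halves of the output are concatenated along the feature dimension, and because the backward direction reads the sequence in reverse, its final hidden state corresponds to the first time step.

```python
import torch
from torch import nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size = 5, 2, 3, 4
bilstm = nn.LSTM(input_size, hidden_size, bidirectional=True)  # input: (seq_len, batch, input_size)
x = torch.randn(seq_len, batch, input_size)
output, (h_n, c_n) = bilstm(x)

print(output.shape)  # torch.Size([5, 2, 8]): forward and backward features concatenated
print(h_n.shape)     # torch.Size([2, 2, 4]): one final hidden state per direction
# The forward direction finishes at the last time step, the backward direction at the first.
print(torch.allclose(output[-1, :, :hidden_size], h_n[0]))
print(torch.allclose(output[0, :, hidden_size:], h_n[1]))
```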


```python
# BI-LSTM
class BiLSTM(nn.Module):
    """Runs one stacked LSTM forward in time and another backward in time.

    Simplification: the two directions are independent stacks. (PyTorch's nn.LSTM with
    bidirectional=True instead feeds both directions of layer l into layer l + 1.)
    """
    def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.forward_lstm = LSTM(input_size, hidden_size, num_layers)  # forward LSTM
        self.backward_lstm = LSTM(input_size, hidden_size, num_layers)  # backward LSTM

    def forward(self, x: torch.Tensor, h: Optional[torch.Tensor] = None, c: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # x: (seq_len, batch, input_size); h, c: (num_layers * 2, batch, hidden_size)
        seq_len, batch = x.size(0), x.size(1)
        n = self.num_layers
        if h is None:
            h = x.new_zeros(n * 2, batch, self.hidden_size)
        if c is None:
            c = x.new_zeros(n * 2, batch, self.hidden_size)
        h_f, c_f = h[:n], c[:n]  # states of the forward direction
        h_b, c_b = h[n:], c[n:]  # states of the backward direction
        forward_out, backward_out = [], []
        for t in range(seq_len):  # first -> last time step
            h_f, c_f, out = self.forward_lstm(x[t], h_f, c_f)
            forward_out.append(out)
        for t in reversed(range(seq_len)):  # last -> first time step
            h_b, c_b, out = self.backward_lstm(x[t], h_b, c_b)
            backward_out.append(out)
        backward_out.reverse()  # realign backward outputs with the original time order
        h = torch.cat([h_f, h_b], dim=0)  # concatenate final hidden states: (num_layers * 2, batch, hidden_size)
        c = torch.cat([c_f, c_b], dim=0)  # concatenate final cell states
        # concatenate the per-step outputs along the feature dimension: (seq_len, batch, 2 * hidden_size)
        x = torch.cat([torch.stack(forward_out), torch.stack(backward_out)], dim=2)
        return h, c, x
```

### GRU

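GRU stands for Gated Recurrent Unit. It is a variant of RNN introduced by Cho et al. in 2014. It is simpler than an LSTM: it has two gates instead of three and no separate cell state. As a result, a GRU layer has three blocks of weights where an LSTM layer has four, i.e. about 75% of the LSTM's parameters for the same input and hidden sizes, and it is usually faster to compute. Whether it also trains better than an LSTM depends on the task.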
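The parameter counts can be checked with PyTorch's built-in layers (input size 10 and hidden size 20 are arbitrary choices; `count_params` is a hypothetical helper). Each weight block holds an input matrix, a hidden matrix and two bias vectors, so the counts scale as 1 : 3 : 4 for RNN, GRU and LSTM.

```python
from torch import nn

def count_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

input_size, hidden_size = 10, 20
models = (nn.RNN(input_size, hidden_size), nn.GRU(input_size, hidden_size), nn.LSTM(input_size, hidden_size))
counts = tuple(count_params(m) for m in models)
print(counts)  # (640, 1920, 2560)
```

One block is `20*10 + 20*20 + 20 + 20 = 640` parameters.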

It has two gates:

- Update gate: decides how much past information to keep. It plays a role similar to the LSTM's forget and input gates combined. The update gate is a sigmoid of the current input and the previous hidden state, so its output is between 0 and 1. A value of 0 means that the GRU replaces the past state with the new candidate state, and a value of 1 means that it keeps all of the past state.
- Reset gate: decides how much of the previous hidden state is used when computing the new candidate state. It is also a sigmoid of the current input and the previous hidden state, with outputs between 0 and 1. If the reset gate outputs 0, the candidate ignores the previous state and depends only on the current input; if it outputs 1, the previous state is used in full.


The GRU cell is represented as:

$z_t = \sigma(W_{xz}x_t + W_{hz}h_{t-1} + b_z)$

$r_t = \sigma(W_{xr}x_t + W_{hr}h_{t-1} + b_r)$

$g_t = \tanh(W_{xg}x_t + W_{hg}(r_t * h_{t-1}) + b_g)$

$h_t = (1 - z_t) * g_t + z_t * h_{t-1}$

These equations allow the GRU to selectively update its hidden state by learning to gate the flow of information through its reset and update gates.
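A direct transcription of the GRU equations, using randomly initialised tensors as hypothetical parameters (the names mirror the symbols $W_{xz}$, $W_{hz}$, $b_z$, …). It uses the convention in which $z_t$ close to 1 keeps the previous state, $h_t = (1 - z_t) * g_t + z_t * h_{t-1}$, which is also the convention of PyTorch's `nn.GRU`. Pushing the update gate towards 1 with a large bias makes the cell copy $h_{t-1}$ forward unchanged.

```python
import torch

torch.manual_seed(0)
input_size, hidden_size = 3, 4

# Hypothetical parameters, one set per equation: W_x* is (hidden, input), W_h* is (hidden, hidden).
params = {name: torch.randn(hidden_size, input_size) for name in ("W_xz", "W_xr", "W_xg")}
params.update({name: torch.randn(hidden_size, hidden_size) for name in ("W_hz", "W_hr", "W_hg")})
params.update({name: torch.zeros(hidden_size) for name in ("b_z", "b_r", "b_g")})

def gru_step(x, h_prev, p):
    z = torch.sigmoid(x @ p["W_xz"].T + h_prev @ p["W_hz"].T + p["b_z"])         # update gate
    r = torch.sigmoid(x @ p["W_xr"].T + h_prev @ p["W_hr"].T + p["b_r"])         # reset gate
    g = torch.tanh(x @ p["W_xg"].T + (r * h_prev) @ p["W_hg"].T + p["b_g"])      # candidate state
    return (1 - z) * g + z * h_prev                                              # new hidden state

x = torch.randn(2, input_size)       # batch of 2
h_prev = torch.randn(2, hidden_size)
h_t = gru_step(x, h_prev, params)

# Saturate the update gate at 1: the new hidden state equals the previous one.
params["b_z"] = torch.full((hidden_size,), 50.0)
h_keep = gru_step(x, h_prev, params)
print(torch.allclose(h_keep, h_prev, atol=1e-6))
```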


```python
# GRU Cell (Gated Recurrent Unit) from scratch

class GRUcell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.W_ih = nn.Linear(input_size, hidden_size * 3)  # input projections for z, r, g
        self.W_hh = nn.Linear(hidden_size, hidden_size * 2)  # hidden projections for z, r
        self.W_hg = nn.Linear(hidden_size, hidden_size)  # hidden projection for g, applied to r * h

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        # x: (batch, input_size); h: (batch, hidden_size)
        x_z, x_r, x_g = self.W_ih(x).chunk(3, 1)  # split the feature dimension into 3 tensors
        h_z, h_r = self.W_hh(h).chunk(2, 1)
        z = torch.sigmoid(x_z + h_z)  # update gate
        r = torch.sigmoid(x_r + h_r)  # reset gate
        g = torch.tanh(x_g + self.W_hg(r * h))  # candidate state: the reset gate scales h before W_hg
        h = (1 - z) * g + z * h  # z close to 1 keeps the previous hidden state
        return h


# Another way to implement the GRU cell: one Linear layer per gate over the
# concatenation [x, h]. It computes the same equations, with one bias per gate.
class GRUCell2(nn.Module):
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.update = nn.Linear(input_size + hidden_size, hidden_size)  # z_t
        self.reset = nn.Linear(input_size + hidden_size, hidden_size)  # r_t
        self.candidate = nn.Linear(input_size + hidden_size, hidden_size)  # g_t

    def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> torch.Tensor:
        xh = torch.cat([input, hidden], dim=1)
        # Calculate the update gate
        update_gate = torch.sigmoid(self.update(xh))
        # Calculate the reset gate
        reset_gate = torch.sigmoid(self.reset(xh))
        # Calculate the candidate state: the reset gate decides how much of the old state is used
        candidate_state = torch.tanh(self.candidate(torch.cat([input, reset_gate * hidden], dim=1)))
        # Calculate the new hidden state: interpolate between the old state and the candidate
        new_hidden = (1 - update_gate) * candidate_state + update_gate * hidden
        return new_hidden
```

```python
class GRU(nn.Module):
    """A stack of GRUcells that advances every layer by ONE time step."""
    def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.cells = nn.ModuleList([GRUcell(input_size, hidden_size)])
        for _ in range(num_layers - 1):
            self.cells.append(GRUcell(hidden_size, hidden_size))

    def forward(self, x: torch.Tensor, h: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        # x: (batch, input_size); h: (num_layers, batch, hidden_size)
        if h is None:
            h = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
        new_h = []  # build a new tensor instead of writing into h in place (that breaks autograd)
        for i, cell in enumerate(self.cells):
            x = cell(x, h[i])  # the output of layer i is the input of layer i + 1
            new_h.append(x)
        return torch.stack(new_h), x
```

### Simple RNN

RNNs model sequential data by feeding the hidden state produced at one time step back into the network at the next time step. This lets the RNN carry information from previous inputs, which is important for tasks such as machine translation and speech recognition. In practice, however, a simple RNN struggles to retain information over long spans because of vanishing and exploding gradients, which is what motivates the LSTM and GRU.

The RNN cell is represented as:

$h_t = \tanh(W_{xh}x_t + W_{hh}h_{t-1} + b_h)$

$x_t$ is the input at time step $t$, $h_t$ is the hidden state at time step $t$, $h_{t-1}$ is the hidden state from the previous time step, $W_{xh}$ and $W_{hh}$ are weight matrices, and $b_h$ is the bias.

The output is computed from the hidden state with a linear layer:

$y_t = W_{hy}h_t + b_y$
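Both equations can be checked against PyTorch's built-in `nn.RNNCell` (tanh by default), which splits $b_h$ into two vectors `bias_ih` and `bias_hh`. The output layer $W_{hy}, b_y$ is not part of `nn.RNNCell`; in this sketch it is assumed to be a separate `nn.Linear` (called `readout` here, a hypothetical name), and the sizes are arbitrary.

```python
import torch
from torch import nn

torch.manual_seed(0)
input_size, hidden_size, output_size, batch = 3, 5, 2, 4
cell = nn.RNNCell(input_size, hidden_size)        # h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h)
readout = nn.Linear(hidden_size, output_size)     # plays the role of W_hy and b_y

x_t = torch.randn(batch, input_size)
h_prev = torch.randn(batch, hidden_size)

with torch.no_grad():
    h_manual = torch.tanh(x_t @ cell.weight_ih.T + cell.bias_ih + h_prev @ cell.weight_hh.T + cell.bias_hh)
    h_t = cell(x_t, h_prev)
    y_t = readout(h_t)                            # y_t = W_hy h_t + b_y

print(torch.allclose(h_manual, h_t, atol=1e-5), y_t.shape)
```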


```python
# RNN from scratch

class RNNcell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.W_ih = nn.Linear(input_size, hidden_size)
        self.W_hh = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        # x: (batch, input_size); h: (batch, hidden_size)
        h = torch.tanh(self.W_ih(x) + self.W_hh(h))
        return h

class RNN(nn.Module):
    """A stack of RNNcells that advances every layer by ONE time step."""
    def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.cells = nn.ModuleList([RNNcell(input_size, hidden_size)])
        for _ in range(num_layers - 1):
            self.cells.append(RNNcell(hidden_size, hidden_size))

    def forward(self, x: torch.Tensor, h: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        # x: (batch, input_size); h: (num_layers, batch, hidden_size)
        if h is None:
            h = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
        new_h = []  # build a new tensor instead of writing into h in place (that breaks autograd)
        for i, cell in enumerate(self.cells):
            x = cell(x, h[i])  # the output of layer i is the input of layer i + 1
            new_h.append(x)
        return torch.stack(new_h), x
```