<a href="https://colab.research.google.com/github/rhiosutoyo/Teaching-Deep-Learning-and-Its-Applications/blob/main/8_3_lstm_variants.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LSTM Variants: Peephole Connections, GRU, and biLSTM

## Definition
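LSTM (Long Short-Term Memory) networks are a type of recurrent neural network (RNN) designed to capture long-term dependencies in sequential data better than a plain RNN. They do this with a separate cell state whose contents are regulated by input, forget, and output gates.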


In [1]:
!pip install torch


In [2]:
import torch
import torch.nn as nn

In [3]:
# Generate some random data
batch_size = 32
seq_length = 10
input_dim = 8
hidden_dim = 16

x = torch.randn(batch_size, seq_length, input_dim)

In [4]:
# Standard LSTM
class StandardLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(StandardLSTM, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)

    def forward(self, x):
        out, _ = self.lstm(x)
        return out

## LSTM Variants
There are several variants of the LSTM, each modifying the architecture to improve performance or to suit specific tasks.
1. **Peephole Connections**: Let the gates read the cell state directly, which can help on tasks that depend on precise timing.
2. **GRU (Gated Recurrent Unit)**: A simplified variant that merges the cell and hidden states into one state and uses only two gates (reset and update).
3. **biLSTM (Bidirectional LSTM)**: Processes the sequence in both forward and backward directions, so each position can use both past and future context.

# 1. Peephole Connections

## Definition
Peephole connections are an enhancement to the traditional LSTM architecture. In a standard LSTM, the gates (input gate, forget gate, and output gate) are controlled by the current input and the previous hidden state. Peephole connections add connections from the cell state to the gates, allowing the gates to also access the cell state. This can improve the LSTM’s ability to learn timing tasks where the precise intervals are important.
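
As a sketch, the definition above can be written out as follows, using the row-vector convention and parameter names of this notebook's `PeepholeLSTM` class, with bias terms omitted as in that class. Here all three gates read the previous cell state $c_{t-1}$ through an elementwise weight vector. This is one common simplification; some formulations instead let the output gate read the updated $c_t$.

$$
\begin{aligned}
i_t &= \sigma\left(x_t W_i + h_{t-1} U_i + c_{t-1} \odot C_i\right) \\
f_t &= \sigma\left(x_t W_f + h_{t-1} U_f + c_{t-1} \odot C_f\right) \\
o_t &= \sigma\left(x_t W_o + h_{t-1} U_o + c_{t-1} \odot C_o\right) \\
\hat{c}_t &= \tanh\left(x_t W_c + h_{t-1} U_c\right) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \hat{c}_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$

Setting $C_i = C_f = C_o = 0$ removes the peephole terms and recovers the gate equations of a standard LSTM without biases.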

In [5]:
# Peephole LSTM
class PeepholeLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(PeepholeLSTM, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # W_*: input-to-hidden, U_*: hidden-to-hidden, C_*: peephole (elementwise) weights.
        # torch.empty allocates uninitialized storage; reset_parameters() fills it in.
        # Bias terms are omitted in this simplified implementation.
        self.W_i = nn.Parameter(torch.empty(input_dim, hidden_dim))
        self.U_i = nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        self.C_i = nn.Parameter(torch.empty(hidden_dim))

        self.W_f = nn.Parameter(torch.empty(input_dim, hidden_dim))
        self.U_f = nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        self.C_f = nn.Parameter(torch.empty(hidden_dim))

        self.W_o = nn.Parameter(torch.empty(input_dim, hidden_dim))
        self.U_o = nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        self.C_o = nn.Parameter(torch.empty(hidden_dim))

        # The candidate cell state has no peephole weight
        self.W_c = nn.Parameter(torch.empty(input_dim, hidden_dim))
        self.U_c = nn.Parameter(torch.empty(hidden_dim, hidden_dim))

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.W_i)
        nn.init.xavier_uniform_(self.U_i)
        nn.init.zeros_(self.C_i)

        nn.init.xavier_uniform_(self.W_f)
        nn.init.xavier_uniform_(self.U_f)
        nn.init.zeros_(self.C_f)

        nn.init.xavier_uniform_(self.W_o)
        nn.init.xavier_uniform_(self.U_o)
        nn.init.zeros_(self.C_o)

        nn.init.xavier_uniform_(self.W_c)
        nn.init.xavier_uniform_(self.U_c)

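    def forward(self, x):
        # x: (batch, seq_len, input_dim). States start at zero on x's device and dtype.
        h_t = torch.zeros(x.size(0), self.hidden_dim, device=x.device, dtype=x.dtype)
        c_t = torch.zeros(x.size(0), self.hidden_dim, device=x.device, dtype=x.dtype)

        outputs = []
        for t in range(x.size(1)):
            x_t = x[:, t, :]
            # Peephole terms: each gate also sees the cell state through an elementwise
            # weight C_*. At this point c_t still holds the previous step's cell state.
            i_t = torch.sigmoid(x_t @ self.W_i + h_t @ self.U_i + c_t * self.C_i)
            f_t = torch.sigmoid(x_t @ self.W_f + h_t @ self.U_f + c_t * self.C_f)
            o_t = torch.sigmoid(x_t @ self.W_o + h_t @ self.U_o + c_t * self.C_o)
            # Candidate cell state, then the usual LSTM state updates
            c_hat_t = torch.tanh(x_t @ self.W_c + h_t @ self.U_c)
            c_t = f_t * c_t + i_t * c_hat_t
            h_t = o_t * torch.tanh(c_t)
            outputs.append(h_t)
        # Stack per-step hidden states into (batch, seq_len, hidden_dim)
        return torch.stack(outputs, dim=1)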

# 2. GRU (Gated Recurrent Unit)

## Definition
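The GRU is another gated RNN designed to mitigate the vanishing gradient problem and capture long-term dependencies. It is similar to the LSTM but has a simpler architecture. A GRU merges the cell state and hidden state into a single state vector and has only two gates: a reset gate, which controls how much of the previous state feeds into the candidate state, and an update gate, which controls how much of the previous state is carried over to the next step.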
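
To make the two gates concrete, here is a minimal sketch of a single GRU time step written out by hand and checked against PyTorch's `nn.GRUCell`. The helper name `manual_gru_step` and the tensor sizes are illustrative assumptions; the gate ordering (reset, update, candidate) follows the layout PyTorch documents for the stacked `weight_ih` and `weight_hh` matrices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

cell = nn.GRUCell(input_size=8, hidden_size=16)
x_t = torch.randn(4, 8)       # one time step for a batch of 4
h_prev = torch.randn(4, 16)   # previous hidden state


def manual_gru_step(cell, x_t, h_prev):
    """Hypothetical helper: one GRU step using the weights stored in `cell`."""
    # Stacked projections, each of shape (batch, 3 * hidden)
    gi = x_t @ cell.weight_ih.T + cell.bias_ih
    gh = h_prev @ cell.weight_hh.T + cell.bias_hh
    # PyTorch stacks the blocks in the order: reset, update, candidate
    i_r, i_z, i_n = gi.chunk(3, dim=1)
    h_r, h_z, h_n = gh.chunk(3, dim=1)

    r = torch.sigmoid(i_r + h_r)      # reset gate
    z = torch.sigmoid(i_z + h_z)      # update gate
    n = torch.tanh(i_n + r * h_n)     # candidate state, with the reset gate applied
    # In PyTorch's convention, z weights the old state and (1 - z) the candidate
    return (1 - z) * n + z * h_prev


with torch.no_grad():
    h_manual = manual_gru_step(cell, x_t, h_prev)
    h_ref = cell(x_t, h_prev)

print("Matches nn.GRUCell:", torch.allclose(h_manual, h_ref, atol=1e-5))
```

Note that there is no separate cell state: the single vector `h_prev` plays both roles, which is the main structural difference from the LSTM.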

In [6]:
# GRU
class GRU(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(GRU, self).__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True)

    def forward(self, x):
        out, _ = self.gru(x)
        return out

# 3. biLSTM (Bidirectional LSTM)

## Definition
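A biLSTM runs two LSTMs over the input, one in the forward direction and one in the backward direction, and concatenates their hidden states at each time step, so every output position has access to both past and future context. This is particularly useful when the whole sequence is available up front and context from both directions matters, such as in sequence tagging or text classification. It is not suited to left-to-right (autoregressive) language modeling, where future tokens must stay hidden from the model.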
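
The following sketch illustrates how the two directions are laid out in the output of a bidirectional `nn.LSTM`. The sizes are arbitrary assumptions chosen for illustration. The last dimension of the output is `2 * hidden_dim`: the first half holds the forward states and the second half the backward states.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

hidden_dim = 16
lstm = nn.LSTM(input_size=8, hidden_size=hidden_dim,
               batch_first=True, bidirectional=True)
x = torch.randn(4, 10, 8)  # (batch, seq_len, input_dim)

with torch.no_grad():
    out, (h_n, c_n) = lstm(x)

# Split the concatenated output into its two directions
fwd = out[..., :hidden_dim]   # forward LSTM states, read left to right
bwd = out[..., hidden_dim:]   # backward LSTM states, read right to left

print("out shape:", tuple(out.shape))
print("h_n shape:", tuple(h_n.shape))  # (num_directions, batch, hidden_dim)

# The forward direction finishes at the last time step,
# the backward direction finishes at the first time step.
fwd_matches = torch.allclose(fwd[:, -1], h_n[0])
bwd_matches = torch.allclose(bwd[:, 0], h_n[1])
print("forward final state at t = last:", fwd_matches)
print("backward final state at t = 0:", bwd_matches)
```

When a single summary vector per sequence is needed, a common choice is to concatenate `h_n[0]` and `h_n[1]` rather than take `out[:, -1]`, because the backward half of `out[:, -1]` has only seen one time step.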

In [7]:
# BiLSTM
class BiLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(BiLSTM, self).__init__()
        self.bilstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=True)

    def forward(self, x):
        out, _ = self.bilstm(x)
        return out

# Instantiate and run the models

In [8]:
input_dim = 8
hidden_dim = 16

standard_lstm = StandardLSTM(input_dim, hidden_dim)
peephole_lstm = PeepholeLSTM(input_dim, hidden_dim)
gru = GRU(input_dim, hidden_dim)
bilstm = BiLSTM(input_dim, hidden_dim)

x = torch.randn(batch_size, seq_length, input_dim)

print("Standard LSTM output shape:", standard_lstm(x).shape)
print("Peephole LSTM output shape:", peephole_lstm(x).shape)
print("GRU output shape:", gru(x).shape)
print("BiLSTM output shape:", bilstm(x).shape)

Standard LSTM output shape: torch.Size([32, 10, 16])
Peephole LSTM output shape: torch.Size([32, 10, 16])
GRU output shape: torch.Size([32, 10, 16])
BiLSTM output shape: torch.Size([32, 10, 32])
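
Output shapes alone do not show how the variants differ in size. As a rough sketch, the snippet below counts trainable parameters for the built-in layers with the same `input_dim` and `hidden_dim` used above. The helper `count_params` is a hypothetical name introduced for this example.

```python
import torch.nn as nn


def count_params(module):
    """Hypothetical helper: total number of trainable parameters."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


input_dim, hidden_dim = 8, 16
layers = {
    "LSTM": nn.LSTM(input_dim, hidden_dim, batch_first=True),
    "GRU": nn.GRU(input_dim, hidden_dim, batch_first=True),
    "BiLSTM": nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=True),
}

counts = {name: count_params(layer) for name, layer in layers.items()}
for name, n in counts.items():
    print(f"{name}: {n} parameters")
```

The GRU has three weight blocks (reset, update, candidate) where the LSTM has four, so it ends up with three quarters of the LSTM's parameters. The biLSTM holds two independent LSTMs and therefore has twice as many.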
