# Tutorial 8-3: The Memory Test – "LSTM vs. The Vanishing Gradient"

**Course:** CSEN 342: Deep Learning  
**Topic:** Long-Term Memory, Vanishing Gradients, and Gating Mechanisms

## The Problem: Vanishing Gradients
In the lecture, we learned that vanilla RNNs struggle with long sequences. When you backpropagate through time (BPTT) over many steps, the gradient signal is repeatedly multiplied by the recurrent weight matrix (and by the derivative of the activation function).
* If these factors are consistently smaller than 1 in magnitude, the gradient **vanishes** toward zero (the model stops learning long-range dependencies).
* If they are consistently larger than 1, the gradient **explodes** (training becomes numerically unstable).

This makes it nearly impossible for a standard RNN to connect an input at $t=0$ to an output at $t=100$.
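The effect of that repeated multiplication can be sketched numerically. The snippet below is a simplified illustration, not the lecture's derivation: it ignores the activation derivative and uses a hypothetical helper, `backprop_norms`, with a scaled orthogonal matrix so that every singular value equals `scale`.

```python
import numpy as np

def backprop_norms(scale, steps=100, size=16, seed=0):
    """Track the norm of a vector repeatedly multiplied by W^T (linearized BPTT)."""
    rng = np.random.default_rng(seed)
    # Random orthogonal matrix, scaled so every singular value equals `scale`.
    q, _ = np.linalg.qr(rng.standard_normal((size, size)))
    w = scale * q
    grad = np.ones(size) / np.sqrt(size)  # unit-norm starting gradient
    norms = []
    for _ in range(steps):
        grad = w.T @ grad
        norms.append(np.linalg.norm(grad))
    return norms

shrinking = backprop_norms(scale=0.9)
growing = backprop_norms(scale=1.1)
print(f"scale 0.9 after 100 steps: {shrinking[-1]:.3e}")
print(f"scale 1.1 after 100 steps: {growing[-1]:.3e}")
```

Because the matrix is orthogonal up to scaling, the norm after `k` steps is `scale**k`, so even a modest deviation from 1 compounds into many orders of magnitude over 100 steps.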

## The Solution: LSTMs
Long Short-Term Memory (LSTM) networks were designed to address this. They introduce a **Cell State** ($C_t$) that acts as a superhighway for gradients, allowing information to flow largely unchanged over long distances.
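To see why the cell state helps, it is useful to write out one common formulation of the LSTM update. The weight and bias symbols ($W_f$, $b_f$, and so on) are assumed notation for this sketch and may differ from the lecture slides.

$$
\begin{aligned}
f_t &= \sigma(W_f [h_{t-1}, x_t] + b_f) \\
i_t &= \sigma(W_i [h_{t-1}, x_t] + b_i) \\
\tilde{C}_t &= \tanh(W_C [h_{t-1}, x_t] + b_C) \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t
\end{aligned}
$$

Along the direct path through the cell state (ignoring the indirect dependence through $h_{t-1}$), the derivative is $\partial C_t / \partial C_{t-1} \approx \mathrm{diag}(f_t)$. The gradient is scaled by the forget gate at each step rather than by a full weight matrix, so when $f_t$ stays close to 1 the signal survives many steps.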

---

## Part 1: The "Adding Problem" Task

To demonstrate this, we will use a classic benchmark called **The Adding Problem** (Hochreiter & Schmidhuber, 1997).

**The Setup:**
We feed the network a long sequence (length **100**). Each time step has 2 input numbers:
1.  **Signal:** A random number between 0 and 1.
2.  **Marker:** A binary flag (0 or 1). 

**The Rules:**
* In the entire sequence of 100 steps, exactly **two** steps will have the Marker set to `1.0`.
* All other steps have the Marker set to `0.0`.
* The Goal: **Predict the sum of the two marked numbers.**

**Why is this hard?**
The first number might appear at $t=5$ and the second at $t=95$. The model must hold the first number in memory for 90 steps, ignoring all the noise in between, to perform the addition at the end.
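A single sample matching that example can be built by hand. This is a minimal sketch using NumPy rather than the dataset class; the marker positions 5 and 95 are simply the ones named in the text.

```python
import numpy as np

seq_len = 100
rng = np.random.default_rng(0)

# Channel 0: random signal values; channel 1: marker flags (all zero to start).
sample = np.zeros((seq_len, 2))
sample[:, 0] = rng.random(seq_len)

# Marker positions taken from the example in the text.
first, second = 5, 95
sample[first, 1] = 1.0
sample[second, 1] = 1.0

# The target is the dot product of the signal and marker channels.
target = float(sample[:, 0] @ sample[:, 1])
print(f"marked values: {sample[first, 0]:.3f}, {sample[second, 0]:.3f}")
print(f"target sum:    {target:.3f}")
```

Writing the target as a dot product makes the task explicit: the 98 unmarked signal values contribute nothing, yet the network still has to read past all of them.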

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt

# --- 1. DATA GENERATION ---
class AddingProblemDataset(Dataset):
    def __init__(self, seq_len=100, num_samples=10000):
        self.num_samples = num_samples
        self.seq_len = seq_len
        
        # Features: [Batch, Seq, 2]
        # Channel 0: Random values in [0, 1]
        self.x = torch.rand((num_samples, seq_len, 2))
        
        # Channel 1: The markers. All 0s initially.
        self.x[:, :, 1] = 0.
        
        # Target: [Batch, 1]
        self.y = torch.zeros((num_samples, 1))
        
        for i in range(num_samples):
            # Pick two distinct random indices to mark
            indices = np.random.choice(seq_len, size=2, replace=False)
            
            # Mark them with 1.0
            self.x[i, indices[0], 1] = 1.0
            self.x[i, indices[1], 1] = 1.0
            
            # Calculate sum of the marked values
            val1 = self.x[i, indices[0], 0]
            val2 = self.x[i, indices[1], 0]
            self.y[i] = val1 + val2

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

# Settings
SEQ_LEN = 100
BATCH_SIZE = 64

dataset = AddingProblemDataset(seq_len=SEQ_LEN, num_samples=10000)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

print(f"Dataset Created. Sequence Length: {SEQ_LEN}")

## Part 2: The Baseline (Guessing the Mean)

If the model cannot learn to remember, what will it do?

It will try to minimize the MSE (Mean Squared Error) by predicting the **expected value** (average) of the sum.
* Each number is uniform in $[0, 1]$, so its mean is $0.5$ and its variance is $1/12$.
* The sum of two independent numbers has mean $1.0$ and variance $2 \times 1/12 = 1/6$.
* The MSE of constantly guessing $1.0$ equals that variance: $1/6 \approx$ **0.167**.

If our model's loss gets stuck near **0.167**, it has failed to learn the task.
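The baseline figure can be checked with a quick simulation. This sketch draws pairs of uniform values directly instead of going through the dataset class, and compares the empirical MSE of the constant guess against the analytic value $1/6$.

```python
import numpy as np

rng = np.random.default_rng(0)
# Draw many pairs of uniform values and form their sums, like the task targets.
sums = rng.random((1_000_000, 2)).sum(axis=1)

constant_guess = 1.0
mse = np.mean((sums - constant_guess) ** 2)
print(f"empirical MSE: {mse:.4f}")
print(f"analytic 1/6:  {1 / 6:.4f}")
```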

## Part 3: The Model

We use a generic `RecurrentModel` wrapper so we can easily swap between `nn.RNN` and `nn.LSTM`.

### Critical Detail: Initialization
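By default, PyTorch draws every LSTM weight and bias, including the "Forget Gate" bias, from a small uniform range centered on 0. The forget gate therefore starts near $\sigma(0) = 0.5$, so roughly half of the cell state is discarded at every step. Over a sequence of 100 steps, the memory (and the gradient flowing through it) shrinks by a factor on the order of $0.5^{100}$, so gradients vanish just as they do in an RNN.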
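A rough calculation shows how quickly this compounds. The sketch below assumes the forget gate's pre-activation is just its bias (ignoring the input and hidden-state contributions) and that nothing new is written to the cell; `retention` is a hypothetical helper, and the larger bias values are included only for comparison.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def retention(bias, steps=100):
    """Fraction of the cell state left after `steps` updates at gate value sigmoid(bias)."""
    return sigmoid(bias) ** steps

for bias in (0.0, 1.0, 2.0):
    print(f"bias={bias:.1f}  gate={sigmoid(bias):.3f}  kept after 100 steps={retention(bias):.2e}")
```

Under these assumptions, a larger forget-gate bias retains many orders of magnitude more of the cell state, although the retained fraction is still small until the gate learns to sit closer to 1.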

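To give the LSTM a fighting chance, we perform **Forget Gate Bias Initialization**: we set the forget-gate bias to `1.0`. This tells the untrained LSTM: *"By default, keep memory intact."* (PyTorch stores two bias vectors per layer, `bias_ih` and `bias_hh`, which are added together. The code fills the forget slice of both, so the effective starting bias is 2.0.)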
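To make the slicing concrete, the following sketch inspects the parameters of a tiny `nn.LSTM` and sets the forget-gate slice directly. The hidden size of 4 is an arbitrary choice to keep the output readable.

```python
import torch
import torch.nn as nn

hidden_size = 4  # small, arbitrary size so the printout is easy to read
lstm = nn.LSTM(input_size=2, hidden_size=hidden_size, batch_first=True)

for name, param in lstm.named_parameters():
    print(f"{name:15s} shape={tuple(param.shape)}")

# PyTorch packs the gates as (input, forget, cell, output), so the forget-gate
# slice is the second quarter of each bias vector.
forget = slice(hidden_size, 2 * hidden_size)
with torch.no_grad():
    lstm.bias_ih_l0[forget].fill_(1.0)
    lstm.bias_hh_l0[forget].fill_(1.0)

# The two biases are summed inside the cell, so the effective bias is their sum.
effective = (lstm.bias_ih_l0 + lstm.bias_hh_l0)[forget]
print("effective forget bias:", effective.tolist())
```

To get an effective bias of exactly 1.0, one option would be to fill only one of the two vectors and zero the forget slice of the other.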

In [None]:
class RecurrentModel(nn.Module):
    def __init__(self, model_type, input_size=2, hidden_size=64):
        super().__init__()
        self.model_type = model_type
        
        if model_type == 'RNN':
            self.rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, batch_first=True)
        elif model_type == 'LSTM':
            self.rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
            self._init_weights()
        else:
            raise ValueError(f"Unknown model_type: {model_type!r} (expected 'RNN' or 'LSTM')")
            
        # Regression output (predicting a single number)
        self.fc = nn.Linear(hidden_size, 1)

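    def _init_weights(self):
        # Raise the forget-gate bias so the untrained LSTM retains its cell state.
        # PyTorch packs each bias as [input | forget | cell | output], so the
        # forget gate is the second quarter. Both bias_ih and bias_hh are filled,
        # and they are summed inside the cell, giving an effective bias of 2.0.
        for name, param in self.rnn.named_parameters():
            if "bias" in name:
                n = param.shape[0]
                start, end = n // 4, n // 2
                with torch.no_grad():
                    param[start:end].fill_(1.0)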

    def forward(self, x):
        # x: [Batch, Seq, 2]
        out, _ = self.rnn(x)
        
        # We only care about the prediction at the FINAL time step
        last_step = out[:, -1, :]
        return self.fc(last_step)

## Part 4: The Race

We will now train both models side-by-side.

* **Hyperparameters:** We use a slightly higher learning rate (0.005) than Adam's default of 0.001 to help the LSTM escape the "average guessing" plateau.
* **Clipping:** We clip the gradient norm to 1.0 for both models so that exploding gradients cannot cause numerical instability. With explosions ruled out, any failure of the RNN can be attributed to vanishing gradients.
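Before running the race, it may help to see what norm clipping does in isolation. This sketch uses a throwaway `nn.Linear` layer and an artificially scaled loss (both hypothetical, chosen only to produce a large gradient) and measures the total gradient norm before and after `clip_grad_norm_`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(2, 1)

# Produce a deliberately large gradient by scaling the loss.
x = torch.randn(8, 2)
loss = 1000.0 * layer(x).pow(2).mean()
loss.backward()

def total_grad_norm(module):
    """L2 norm over all parameter gradients, the quantity clip_grad_norm_ limits."""
    return torch.sqrt(sum(p.grad.pow(2).sum() for p in module.parameters()))

before = total_grad_norm(layer)
# clip_grad_norm_ rescales the gradients in place and returns the pre-clip norm.
returned = torch.nn.utils.clip_grad_norm_(layer.parameters(), max_norm=1.0)
after = total_grad_norm(layer)

print(f"norm before: {before.item():.3f}, after: {after.item():.3f}")
```

Clipping rescales the whole gradient vector, so its direction is preserved and only its length is capped. It does nothing for gradients that are already too small, which is why it cannot rescue the RNN from vanishing gradients.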

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train_model(model_type, epochs=10):
    model = RecurrentModel(model_type).to(device)
    # Learning Rate 0.005 helps punch through the initial plateau
    optimizer = optim.Adam(model.parameters(), lr=0.005)
    criterion = nn.MSELoss()
    
    loss_history = []
    
    print(f"\n--- Training {model_type} ---")
    for epoch in range(epochs):
        epoch_loss = 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            
            optimizer.zero_grad()
            preds = model(x)
            loss = criterion(preds, y)
            loss.backward()
            
            # Clip gradients to keep training stable
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            optimizer.step()
            epoch_loss += loss.item()
            
        avg_loss = epoch_loss / len(loader)
        loss_history.append(avg_loss)
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}")
        
    return loss_history

# Run the Race
rnn_losses = train_model('RNN', epochs=10)
lstm_losses = train_model('LSTM', epochs=10)

## Part 5: Visualization

Let's compare the learning curves.

In [None]:
plt.figure(figsize=(10, 6))

# Plot RNN
plt.plot(rnn_losses, label='Vanilla RNN', color='red', linestyle='--', linewidth=2)

# Plot LSTM
plt.plot(lstm_losses, label='LSTM', color='green', linewidth=3)

# Add baseline line (Random Guessing MSE approx 0.167)
plt.axhline(y=0.167, color='gray', linestyle=':', label='Baseline (Guessing)')

plt.title(f"Adding Problem Loss (Sequence Length = {SEQ_LEN})")
plt.xlabel("Epoch")
plt.ylabel("MSE Loss")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

### Conclusion

**1. The Vanilla RNN (Red Dashed Line):** 
Notice how the RNN loss flattens out around **0.167**. It has failed to find the markers. Because the sequence is so long (100 steps), the gradient signal vanishes before it reaches the start of the sequence. The RNN gives up and simply guesses the average ($1.0$), achieving the baseline loss.

**2. The LSTM (Green Solid Line):** 
The LSTM breaks through the baseline floor!  

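This indicates that the LSTM successfully learned to:
1.  Identify the first marker.
2.  Store its value in the **Cell State**.
3.  Maintain that memory across the gap between the markers (90 steps in the earlier example, and up to 99).
4.  Add it to the second marked value when it appears, and carry the sum to the final time step.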

This ability to mitigate the "Vanishing Gradient" problem is why LSTMs, and later Transformers (which avoid recurrence altogether), replaced simple RNNs in modern deep learning.
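One way to probe gradient flow directly is to measure how strongly the final output depends on each input step. The sketch below is a diagnostic under stated assumptions, not part of the experiment above: it uses freshly constructed, untrained `nn.RNN` and `nn.LSTM` modules with PyTorch's default initialization, and `input_grad_norms` is a hypothetical helper.

```python
import torch
import torch.nn as nn

def input_grad_norms(rnn, seq_len=100, input_size=2):
    """Norm of d(sum of final hidden state)/d(input) at each step, for one random sequence."""
    x = torch.rand(1, seq_len, input_size, requires_grad=True)
    out, _ = rnn(x)
    out[:, -1, :].sum().backward()
    return x.grad[0].norm(dim=1)  # one norm per time step

torch.manual_seed(0)
rnn = nn.RNN(input_size=2, hidden_size=64, batch_first=True)
lstm = nn.LSTM(input_size=2, hidden_size=64, batch_first=True)

for name, module in (("RNN", rnn), ("LSTM", lstm)):
    norms = input_grad_norms(module)
    print(f"{name:4s} grad norm at t=0: {norms[0].item():.3e}, at t=99: {norms[-1].item():.3e}")
```

Running the same helper on the trained models would be more informative, but `train_model` as written returns only the loss history, so it would first need to return the model as well.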