[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/shang-vikas/series1-coding-exercises/blob/main/exercises/blog-05/exercise-02.ipynb)

# Visualizing Vanishing Gradients

Now we make vanishing gradients visible instead of theoretical.

We'll instrument your IMDB RNN so you can see how gradient strength decays through time.

This turns "vanishing gradient" from a sentence into a measurement.

## 🔬 Goal

We want to answer:

**When training on a long review, does the gradient reaching early words become smaller than the gradient for recent words?**

If yes → the RNN forgets early information.

## 🧠 Strategy

We will:

1. Take one batch of long reviews
2. Run a forward pass
3. Backpropagate once
4. Measure the gradient magnitude with respect to the hidden state at each timestep

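This works because gradients must flow backward through time: the gradient reaching an early hidden state has first passed through every later timestep.

If it shrinks along the way → vanishing gradient.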
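To see the mechanism before touching IMDB, here is a minimal sketch on a hypothetical scalar RNN, `h_t = tanh(w * h_{t-1} + x_t)`, with the loss taken to be the final hidden state. It runs the forward pass, then walks backward through time to get `|dL/dh_t|` for every timestep, which is the quantity Step 4 measures. The function name `scalar_bptt`, the weight `w = 0.5`, and the constant inputs are illustrative assumptions, not part of the exercise.

```python
import math

def scalar_bptt(w, inputs):
    """Hypothetical scalar RNN: h_t = tanh(w * h_{t-1} + x_t), loss L = h_T.

    Returns |dL/dh_t| for every timestep t, computed by backprop through time.
    """
    # Forward pass: keep every hidden state
    h, states = 0.0, []
    for x in inputs:
        h = math.tanh(w * h + x)
        states.append(h)

    # Backward pass: dL/dh_T = 1, then dL/dh_t = dL/dh_{t+1} * dh_{t+1}/dh_t
    grads = [0.0] * len(states)
    g = 1.0
    grads[-1] = g
    for t in range(len(states) - 2, -1, -1):
        # d/dh_t of tanh(w * h_t + x_{t+1}) is (1 - h_{t+1}^2) * w
        g *= (1.0 - states[t + 1] ** 2) * w
        grads[t] = abs(g)
    return grads

grads = scalar_bptt(w=0.5, inputs=[0.1] * 20)
for t in (19, 15, 10, 5, 0):
    print(f"t={t:2d}  |dL/dh_t| = {grads[t]:.2e}")
```

With these assumed values every backward step multiplies by a factor just under 0.5, so after 19 steps the gradient at `t=0` is roughly a millionth of the gradient at the last step.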

## 🧪 Step 1: Modify Forward to Return All Hidden States

First, rebuild the data pipeline from exercise-01. Then change your model so it exposes every hidden state:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchtext.datasets import IMDB
from torchtext.data.utils import get_tokenizer
from torch.nn.utils.rnn import pad_sequence
from collections import Counter
import matplotlib.pyplot as plt

# Rebuild the data and vocabulary (same pipeline as exercise-01)
train_iter, test_iter = IMDB(split=('train', 'test'))
train_data = list(train_iter)
test_data = list(test_iter)

tokenizer = get_tokenizer("basic_english")

# Count token frequencies over the training set
counter = Counter()
for label, text in train_data:
    tokens = tokenizer(text)
    counter.update(tokens)

# Reserve index 0 for padding and 1 for unknown words
vocab_size = 20000
most_common = counter.most_common(vocab_size - 2)
vocab = {word: idx + 2 for idx, (word, _) in enumerate(most_common)}
vocab["<pad>"] = 0
vocab["<unk>"] = 1

def encode(text):
    # Map each token to its vocabulary index, falling back to <unk>
    tokens = tokenizer(text)
    return [vocab.get(token, vocab["<unk>"]) for token in tokens]

device = "cuda" if torch.cuda.is_available() else "cpu"
```

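```python
class VanillaRNN(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # nn.RNNCell (tanh) lets us unroll time ourselves instead of using the fused nn.RNN
        self.cell = nn.RNNCell(embed_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = self.embedding(x)                        # (batch, seq_len, embed_dim)
        batch_size, seq_len, _ = x.shape
        hidden_dim = self.cell.hidden_size

        # Zero-valued leaf added to every hidden state: the forward pass is unchanged,
        # but after backward, saved_outputs.grad[:, t] holds dLoss/dh_t
        self.saved_outputs = torch.zeros(batch_size, seq_len, hidden_dim,
                                         device=x.device, requires_grad=True)

        # Unbind once (rather than slicing x[:, t] in the loop) so backward
        # gathers all per-step gradients in a single op
        inputs = x.unbind(dim=1)
        probes = self.saved_outputs.unbind(dim=1)

        h = x.new_zeros(batch_size, hidden_dim)
        for t in range(seq_len):
            h = self.cell(inputs[t], h) + probes[t]  # h_t depends on h_{t-1}

        return self.fc(h)                            # predict from the final hidden state
```

Why not keep `nn.RNN` and call `out.retain_grad()`? The prediction uses only `hidden`, so the loss never depends on `out` and no gradient reaches it. Even if you predicted from `out[:, -1]`, `out.grad[:, t]` would only capture how the loss uses `out` directly (zero for every earlier `t`), because the recurrence runs inside one fused op. Unrolling with `nn.RNNCell` makes every hidden state a separate node in the autograd graph.

After `loss.backward()`, `model.saved_outputs.grad` contains:

- Shape: `(batch_size, sequence_length, hidden_dim)`
- The gradient of the loss with respect to the hidden state at every timestep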
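One way to measure `dL/dh_t` for an intermediate hidden state is to add a zero-valued perturbation `eps` to `h_t` and differentiate the loss with respect to `eps` at `eps = 0`: since `h_t + eps` feeds every later step exactly as `h_t` does, that derivative equals `dL/dh_t`. Below is a minimal numerical check on a hypothetical scalar RNN, `h_t = tanh(w * h_{t-1} + x_t)` with loss `L = h_T`. The names `loss_with_probe`, `probe_gradient`, and `backprop_gradient`, the weight `w = 0.5`, and the constant inputs are illustrative assumptions. It compares a central finite difference against the backward recursion.

```python
import math

def loss_with_probe(w, inputs, probe_t, eps):
    """Hypothetical scalar RNN; eps is added to the hidden state at step probe_t.

    The loss is the final hidden state. With eps = 0 the forward pass is unchanged.
    """
    h = 0.0
    for t, x in enumerate(inputs):
        h = math.tanh(w * h + x)
        if t == probe_t:
            h += eps  # zero-valued "probe" on the hidden state at step probe_t
    return h

def probe_gradient(w, inputs, probe_t, delta=1e-6):
    # Central finite difference of the loss w.r.t. the probe, evaluated at eps = 0
    up = loss_with_probe(w, inputs, probe_t, delta)
    down = loss_with_probe(w, inputs, probe_t, -delta)
    return (up - down) / (2 * delta)

def backprop_gradient(w, inputs, probe_t):
    # Analytic dL/dh_{probe_t}: product of (1 - h_s^2) * w over every later step s
    h, states = 0.0, []
    for x in inputs:
        h = math.tanh(w * h + x)
        states.append(h)
    g = 1.0
    for s in range(probe_t + 1, len(states)):
        g *= (1.0 - states[s] ** 2) * w
    return g

inputs = [0.1] * 10
for t in (9, 5, 0):
    fd = probe_gradient(0.5, inputs, t)
    bp = backprop_gradient(0.5, inputs, t)
    print(f"t={t}: finite difference {fd:.6e}, backprop {bp:.6e}")
```

The two columns should agree to several significant digits at every timestep, which is what justifies reading the perturbation's gradient as the hidden-state gradient.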

## 🧪 Step 2: Gradient Inspection Function

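```python
def label_to_int(label):
    # Older torchtext yields "pos"/"neg"; newer releases may yield ints (2 = pos, 1 = neg)
    return 1 if label in ("pos", 2) else 0

def collate_batch(batch):
    texts, labels = [], []

    for label, text in batch:
        encoded = torch.tensor(encode(text), dtype=torch.long)
        texts.append(encoded)
        labels.append(label_to_int(label))

    # Pad every review to the longest one in the batch (padding goes at the end)
    texts = pad_sequence(texts, batch_first=True, padding_value=vocab["<pad>"])
    labels = torch.tensor(labels)

    return texts, labels

train_loader = DataLoader(train_data, batch_size=64, shuffle=True, collate_fn=collate_batch)
criterion = nn.BCEWithLogitsLoss()
```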

```python
def visualize_gradient_decay(model, loader):
    model.train()

    texts, labels = next(iter(loader))
    texts, labels = texts.to(device), labels.to(device).float()

    # (batch, 1) -> (batch,); squeeze(1) stays correct even for a batch of one
    outputs = model(texts).squeeze(1)
    loss = criterion(outputs, labels)

    model.zero_grad()
    loss.backward()

    # Gradient of the loss w.r.t. the hidden state at every timestep
    grads = model.saved_outputs.grad  # shape: (batch, seq_len, hidden)

    if grads is None:
        print("No gradient reached model.saved_outputs; check that forward() uses it")
        return None

    # Average gradient magnitude per timestep (over the batch and hidden units)
    grad_magnitudes = grads.abs().mean(dim=(0, 2)).detach().cpu().numpy()

    plt.figure(figsize=(10, 6))
    plt.plot(grad_magnitudes)
    plt.title("Gradient Magnitude Across Time Steps")
    plt.xlabel("Time Step")
    plt.ylabel("Average Gradient Magnitude")
    plt.grid(True)
    plt.show()

    return grad_magnitudes
```
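The reduction `grads.abs().mean(dim=(0, 2))` collapses the `(batch, seq_len, hidden)` gradient tensor to one number per timestep: the mean absolute gradient over all examples and hidden units. A pure-Python sketch of the same reduction on nested lists makes the axes explicit; the helper name `mean_abs_per_timestep` and the toy values are hypothetical.

```python
def mean_abs_per_timestep(grads):
    """grads[b][t][k]: gradient for example b, timestep t, hidden unit k.

    Returns one value per timestep t: the mean of |grads[b][t][k]| over b and k,
    mirroring grads.abs().mean(dim=(0, 2)) on a (batch, seq_len, hidden) tensor.
    """
    batch, seq_len, hidden = len(grads), len(grads[0]), len(grads[0][0])
    result = []
    for t in range(seq_len):
        total = sum(abs(grads[b][t][k]) for b in range(batch) for k in range(hidden))
        result.append(total / (batch * hidden))
    return result

# Toy gradients: 2 examples, 3 timesteps, 2 hidden units (signs cancel nothing: abs first)
toy = [
    [[0.01, -0.01], [0.1, -0.1], [1.0, -1.0]],
    [[0.03, -0.03], [0.3, -0.3], [3.0, -3.0]],
]
print(mean_abs_per_timestep(toy))
```

Taking `abs` before averaging matters: positive and negative gradient components would otherwise cancel and hide how much signal actually reaches each timestep. Here the result is approximately `[0.02, 0.2, 2.0]`, a tenfold drop per step back.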

## 🧪 Step 3: Run Visualization

```python
model = VanillaRNN(vocab_size, embed_dim=100, hidden_dim=128).to(device)
grad_magnitudes = visualize_gradient_decay(model, train_loader)
```

## 📉 What You'll See

A plot like:

```
|\
| \
|  \
|   \
|    \
|     \______
|
+------------------>
```

- Large gradients near the end.
- Tiny gradients at early timesteps.

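That's the vanishing gradient problem.

Backprop multiplies through one Jacobian per timestep. Because the loss depends on $h_t$ only through $h_{t+1}$:

$$
\frac{\partial L}{\partial h_t} = \frac{\partial L}{\partial h_{t+1}} \, \frac{\partial h_{t+1}}{\partial h_t}
$$

For the tanh RNN step $h_{t+1} = \tanh(W_{ih} x_{t+1} + W_{hh} h_t + b)$, the local derivative is

$$
\frac{\partial h_{t+1}}{\partial h_t} = \operatorname{diag}\left(1 - h_{t+1}^{2}\right) W_{hh}
$$

Every diagonal entry $1 - h_{t+1}^2$ lies in $(0, 1]$. If this Jacobian shrinks the gradient on average (norm below 1) → exponential decay: $k$ steps back, the gradient has been multiplied by $k$ such factors.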
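To put numbers on the exponential decay: if each step back multiplies the gradient by roughly the same factor `r`, then after `k` steps it is scaled by `r ** k`. The factors 0.9 and 1.1 below are illustrative assumptions, not values measured from the IMDB model; the second shows the opposite failure, exploding gradients. `decay_table` is a hypothetical helper.

```python
def decay_table(factors, distances):
    """Return {(r, k): r ** k} for each assumed per-step factor r and distance k."""
    return {(r, k): r ** k for r in factors for k in distances}

# Assumed per-step factors: r < 1 vanishes, r > 1 explodes
table = decay_table(factors=(0.9, 1.1), distances=(10, 50, 100, 400))
for (r, k), value in table.items():
    print(f"r={r}, k={k:3d}: r**k = {value:.2e}")
```

With `r = 0.9`, a gradient that travels back 400 steps is scaled by about 5e-19, negligible next to the gradient at the final step, while `r = 1.1` grows by a factor of roughly 3.6e16 over the same distance.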

## 🧪 Make It Worse (Increase Sequence Length)

Keep only reviews with at least 400 tokens, so every example in the batch is long and its earliest words sit hundreds of steps before the final hidden state:

```python
MIN_LEN = 400  # keep only reviews with at least this many tokens

# Filter the dataset rather than each batch, so batches stay full-size
long_train_data = [(label, text) for label, text in train_data
                   if len(tokenizer(text)) >= MIN_LEN]

# Reuse collate_batch: each batch is padded to its longest review
train_loader_long = DataLoader(long_train_data, batch_size=64, shuffle=True,
                               collate_fn=collate_batch)
```

```python
# Reinitialize the model
model_long = VanillaRNN(vocab_size, embed_dim=100, hidden_dim=128).to(device)

# Visualize with longer sequences
grad_magnitudes_long = visualize_gradient_decay(model_long, train_loader_long)
```

You'll see:

- Steeper decay.
- Even smaller gradients at early timesteps.

## 🧠 What This Proves

- Early tokens barely receive gradient updates.
- Model struggles to learn dependencies from far past.
- Memory fragility is structural.

This is not tuning.

This is architecture.