# Module 08: Bidirectional & Deep RNNs

**Building More Powerful Sequence Models**

---

## 1. Objectives

- ✅ Understand bidirectional RNNs
- ✅ Know when to use (and NOT use) bidirectional
- ✅ Build deep/stacked RNNs
- ✅ Implement residual connections and layer norm

## 2. Prerequisites

- [Module 06: LSTM](../06_lstm/06_lstm.ipynb)
- [Module 07: GRU](../07_gru/07_gru.ipynb)

## 3. Intuition & Motivation

### Why Bidirectional?

For classification and sequence labeling, each token's representation benefits from context on **both** sides:

```
"The movie was not very good but the acting was superb"

Forward only  → at "good", the state has seen only "The movie was not very good"
Bidirectional → at "good", it also sees "but the acting was superb"
```

### When You CAN'T Use Bidirectional

- **Streaming/Online**: Can't wait for future tokens
- **Generation**: Don't have future tokens yet
- **Real-time**: Latency constraints
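The streaming constraint can be checked directly. In a unidirectional LSTM the output at step t depends only on inputs 0..t, so perturbing a future token leaves earlier outputs untouched; in a bidirectional LSTM the backward half of every earlier output changes. A minimal sketch; sizes and the seed are arbitrary choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

uni = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
bi = nn.LSTM(input_size=8, hidden_size=16, batch_first=True, bidirectional=True)

seq = torch.randn(1, 10, 8)       # (batch, seq, features)
seq_changed = seq.clone()
seq_changed[:, -1] += 1.0         # perturb only the LAST time step

with torch.no_grad():
    uni_a, _ = uni(seq)
    uni_b, _ = uni(seq_changed)
    bi_a, _ = bi(seq)
    bi_b, _ = bi(seq_changed)

# Unidirectional: outputs before the last step cannot see the future token
uni_past_unchanged = torch.allclose(uni_a[:, :-1], uni_b[:, :-1])
# Bidirectional: the backward half (features 16:) of earlier outputs does see it
bi_past_unchanged = torch.allclose(bi_a[:, :-1], bi_b[:, :-1])

print("unidirectional past outputs unchanged:", uni_past_unchanged)
print("bidirectional past outputs unchanged: ", bi_past_unchanged)
```

This is why a bidirectional model must wait for the end of the input before emitting anything, which rules it out for streaming and generation.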

```python
import torch
import torch.nn as nn
import numpy as np

print("Setup complete!")
```

## 4. Bidirectional RNNs

```python
# PyTorch bidirectional LSTM
bilstm = nn.LSTM(
    input_size=128,
    hidden_size=256,
    num_layers=1,
    batch_first=True,
    bidirectional=True  # <-- Key parameter
)

x = torch.randn(32, 50, 128)  # (batch, seq, features)
output, (h_n, c_n) = bilstm(x)

print(f"Input: {x.shape}")
print(f"Output: {output.shape}")
print("  → Note: hidden_size * 2 = 512 (forward + backward concatenated)")
print(f"h_n: {h_n.shape}")
print("  → Note: num_layers * 2 = 2 final states (forward and backward)")
```
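How `output` and `h_n` relate is easy to get wrong. Per the PyTorch `nn.LSTM` documentation, `h_n` has shape `(num_layers * 2, batch, hidden)` for a bidirectional LSTM and can be split with `.view(num_layers, 2, batch, hidden)`. The sketch below (sizes arbitrary, unpadded input assumed) checks that the top layer's forward final state equals the forward half of the *last* time step of `output`, while the backward final state equals the backward half of the *first* time step, because the backward direction finishes at t = 0.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_layers, hidden = 2, 32
lstm2 = nn.LSTM(input_size=16, hidden_size=hidden, num_layers=num_layers,
                batch_first=True, bidirectional=True)

x_small = torch.randn(4, 10, 16)  # (batch, seq, features), no padding
with torch.no_grad():
    out_small, (h_small, _) = lstm2(x_small)

# (num_layers * 2, batch, hidden) -> (num_layers, directions, batch, hidden)
h_split = h_small.view(num_layers, 2, x_small.size(0), hidden)
top_fwd = h_split[-1, 0]  # top layer, forward direction
top_bwd = h_split[-1, 1]  # top layer, backward direction

# out_small holds only the top layer: [..., :hidden] forward, [..., hidden:] backward
fwd_matches = torch.allclose(top_fwd, out_small[:, -1, :hidden])  # forward ends at t = T-1
bwd_matches = torch.allclose(top_bwd, out_small[:, 0, hidden:])   # backward ends at t = 0
print(out_small.shape, h_small.shape, fwd_matches, bwd_matches)
```

So `output[:, -1, :]` alone is *not* a full summary for a bidirectional model: its backward half has only seen the last token.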

```python
# Manual bidirectional implementation
class BidirectionalLSTM(nn.Module):
    """Manual single-layer bidirectional LSTM.

    Assumes every sequence in the batch is full length (no padding):
    with padding, torch.flip would move the pad tokens to the front.
    """
    
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.forward_lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.backward_lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
    
    def forward(self, x):
        # Forward direction: reads t = 0 .. T-1
        out_fwd, (h_fwd, _) = self.forward_lstm(x)
        
        # Backward direction: flip time, process, flip back so out_bwd[:, t] aligns with x[:, t]
        x_rev = torch.flip(x, [1])  # reverse the time dimension
        out_bwd, (h_bwd, _) = self.backward_lstm(x_rev)
        out_bwd = torch.flip(out_bwd, [1])
        
        # Concatenate features per step; stack final states the way nn.LSTM does
        output = torch.cat([out_fwd, out_bwd], dim=-1)  # (batch, seq, 2 * hidden)
        h_n = torch.cat([h_fwd, h_bwd], dim=0)          # (2, batch, hidden)
        
        return output, h_n

# Test (reuses x of shape (32, 50, 128) from the previous cell)
model = BidirectionalLSTM(128, 256)
out, h = model(x)
print(f"Manual BiLSTM Output: {out.shape}")
```
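To check that the flip-process-flip construction is what `bidirectional=True` computes, one can copy the weights of a built-in bidirectional `nn.LSTM` into two unidirectional LSTMs and compare outputs. PyTorch names the reverse-direction parameters with a `_reverse` suffix (e.g. `weight_ih_l0_reverse`). The class is redefined compactly so the block runs on its own; sizes are arbitrary and inputs are unpadded.

```python
import torch
import torch.nn as nn

class BidirectionalLSTM(nn.Module):
    """Manual single-layer BiLSTM: forward LSTM + LSTM over the time-reversed input."""
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.forward_lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.backward_lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

    def forward(self, x):
        out_fwd, (h_fwd, _) = self.forward_lstm(x)
        out_bwd, (h_bwd, _) = self.backward_lstm(torch.flip(x, [1]))
        out_bwd = torch.flip(out_bwd, [1])  # realign with original time order
        return torch.cat([out_fwd, out_bwd], dim=-1), torch.cat([h_fwd, h_bwd], dim=0)

torch.manual_seed(0)
builtin = nn.LSTM(8, 16, batch_first=True, bidirectional=True)
manual = BidirectionalLSTM(8, 16)

with torch.no_grad():
    # "<name>" -> forward LSTM, "<name>_reverse" -> backward LSTM
    for name in ["weight_ih_l0", "weight_hh_l0", "bias_ih_l0", "bias_hh_l0"]:
        getattr(manual.forward_lstm, name).copy_(getattr(builtin, name))
        getattr(manual.backward_lstm, name).copy_(getattr(builtin, name + "_reverse"))

    x_demo = torch.randn(3, 7, 8)
    out_builtin, (h_builtin, _) = builtin(x_demo)
    out_manual, h_manual = manual(x_demo)

print("outputs match:", torch.allclose(out_builtin, out_manual, atol=1e-5))
print("final states match:", torch.allclose(h_builtin, h_manual, atol=1e-5))
```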

## 5. Combining Bidirectional Outputs

```python
class BiLSTMClassifier(nn.Module):
    """Bidirectional LSTM for classification."""
    
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes, combine='concat'):
        super().__init__()
        if combine not in ('concat', 'sum', 'mean'):
            raise ValueError(f"combine must be 'concat', 'sum' or 'mean', got {combine!r}")
        self.combine = combine
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        
        # Output dimensions depend on combine method
        if combine == 'concat':
            fc_input = hidden_dim * 2
        else:  # 'sum' or 'mean'
            fc_input = hidden_dim
        
        self.fc = nn.Linear(fc_input, num_classes)
    
    def forward(self, x):
        embedded = self.embedding(x)
        # Note: without pack_padded_sequence, padded positions (index 0) are read too
        _, (h_n, _) = self.lstm(embedded)
        
        # h_n: (num_layers * 2, batch, hidden); the last two entries are the
        # top layer's forward and backward final states
        h_fwd = h_n[-2]  # (batch, hidden)
        h_bwd = h_n[-1]  # (batch, hidden)
        
        if self.combine == 'concat':
            combined = torch.cat([h_fwd, h_bwd], dim=-1)
        elif self.combine == 'sum':
            combined = h_fwd + h_bwd
        else:  # 'mean'
            combined = (h_fwd + h_bwd) / 2
        
        return self.fc(combined)

# Test different combine methods
for method in ['concat', 'sum', 'mean']:
    m = BiLSTMClassifier(5000, 100, 128, 2, combine=method)
    tokens = torch.randint(0, 5000, (32, 50))  # random token ids (batch, seq)
    out = m(tokens)
    print(f"{method}: {out.shape}")
```
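The classifier above feeds padded batches straight into the LSTM, so for a shorter sequence the forward final state is computed after reading its padding tokens. A common fix is `torch.nn.utils.rnn.pack_padded_sequence`, after which `h_n` holds each sequence's forward state at its true last token. A minimal sketch, assuming right-padded batches with index 0 as padding and a `lengths` tensor supplied by the data pipeline (hypothetical here); it checks that a padded sequence yields the same final states as the same sequence run alone.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
embedding = nn.Embedding(100, 8, padding_idx=0)
lstm = nn.LSTM(8, 16, batch_first=True, bidirectional=True)

def final_states(tokens, lengths):
    """Return (batch, 2 * hidden): top layer's forward and backward final states."""
    packed = pack_padded_sequence(embedding(tokens), lengths.cpu(),
                                  batch_first=True, enforce_sorted=False)
    _, (h_n, _) = lstm(packed)  # padding steps are skipped; h_n is in original batch order
    return torch.cat([h_n[-2], h_n[-1]], dim=-1)

# Batch of two right-padded sequences (0 = padding)
tokens = torch.tensor([[5, 8, 2, 9, 4],
                       [7, 3, 6, 0, 0]])
lengths = torch.tensor([5, 3])

with torch.no_grad():
    batched = final_states(tokens, lengths)
    alone = final_states(tokens[1:, :3], torch.tensor([3]))  # 2nd sequence, unpadded

print("padded == unpadded:", torch.allclose(batched[1], alone[0], atol=1e-5))
```

With `enforce_sorted=False` the batch does not need to be sorted by length; PyTorch sorts internally and restores the original order in `h_n`.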

## 6. Deep/Stacked RNNs

```python
# Stacked LSTM
stacked_lstm = nn.LSTM(
    input_size=128,
    hidden_size=256,
    num_layers=3,  # <-- 3 layers stacked
    batch_first=True,
    dropout=0.3  # dropout on the outputs of every layer except the last
)

x = torch.randn(32, 50, 128)
output, (h_n, c_n) = stacked_lstm(x)

print(f"Output: {output.shape}")  # top layer only
print(f"h_n: {h_n.shape} (one final state per layer)")
print(f"\nTotal parameters: {sum(p.numel() for p in stacked_lstm.parameters()):,}")
```
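The parameter count printed above follows from the LSTM gate equations. Each layer has four gates, each with an input weight matrix, a recurrent weight matrix and (in PyTorch) two bias vectors, so one direction of a layer holds 4 * (H*I + H*H + 2*H) parameters, where I is that layer's input size. Layers above the first receive H inputs, or 2H if bidirectional. A small helper (the name `lstm_param_count` is hypothetical) checks this against `nn.LSTM`:

```python
import torch.nn as nn

def lstm_param_count(input_size, hidden_size, num_layers=1, bidirectional=False):
    """Parameters of a PyTorch-style LSTM (4 gates, two bias vectors per gate)."""
    directions = 2 if bidirectional else 1
    total = 0
    for layer in range(num_layers):
        # Layer 0 reads the input; deeper layers read the (concatenated) hidden states
        layer_input = input_size if layer == 0 else hidden_size * directions
        per_direction = 4 * (hidden_size * layer_input    # W_ih
                             + hidden_size * hidden_size  # W_hh
                             + 2 * hidden_size)           # b_ih + b_hh
        total += directions * per_direction
    return total

for bidir in (False, True):
    lstm = nn.LSTM(128, 256, num_layers=3, batch_first=True, bidirectional=bidir)
    actual = sum(p.numel() for p in lstm.parameters())
    print(f"bidirectional={bidir}: formula={lstm_param_count(128, 256, 3, bidir):,}  actual={actual:,}")
```

For the 3-layer stack above this gives 1,447,936 parameters; making it bidirectional more than doubles that (3,944,448), because layers 2 and 3 also receive twice as many inputs.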

```python
# Deep LSTM with residual connections
class ResidualLSTM(nn.Module):
    """LSTM layer with residual connection and layer normalization."""
    
    def __init__(self, input_size, hidden_size, dropout=0.1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)
        
        # Project input if sizes don't match, so the residual can be added
        self.proj = nn.Linear(input_size, hidden_size) if input_size != hidden_size else nn.Identity()
    
    def forward(self, x):
        residual = self.proj(x)
        out, _ = self.lstm(x)
        out = self.dropout(out)
        return self.layer_norm(out + residual)  # add residual, then normalize (post-norm)

class DeepResidualLSTM(nn.Module):
    """Stack of LSTM layers with residual connections."""
    
    def __init__(self, input_size, hidden_size, num_layers, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            ResidualLSTM(
                input_size if i == 0 else hidden_size,
                hidden_size,
                dropout
            ) for i in range(num_layers)
        ])
    
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

# Test
deep_lstm = DeepResidualLSTM(128, 256, num_layers=4)
out = deep_lstm(torch.randn(32, 50, 128))
print(f"Deep Residual LSTM Output: {out.shape}")
```

## 7. 🔥 Real-World Usage

### BiLSTM-CRF for NER

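A classic sequence-labeling architecture and still a strong baseline. The BiLSTM produces per-token tag scores, and the CRF layer scores whole tag sequences, so it can learn to penalize invalid transitions such as `O → I-PER`:
```
Embeddings → BiLSTM → CRF → Entity Tags
```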
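A linear-chain CRF scores a whole tag sequence as the sum of per-token emission scores (from the BiLSTM) and tag-to-tag transition scores, and decoding picks the highest-scoring sequence with the Viterbi algorithm. Below is a minimal inference-only sketch, assuming a single unpadded sentence, random (untrained) weights and a hypothetical `BiLSTMCRFTagger` class. Training would also need the CRF log-likelihood (forward algorithm), and many implementations add start/stop transition scores; both are omitted here. A brute-force search confirms that Viterbi returns a best-scoring path.

```python
import itertools
import torch
import torch.nn as nn

class BiLSTMCRFTagger(nn.Module):
    """Hypothetical BiLSTM-CRF: BiLSTM emission scores + learned tag transition scores."""
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_tags):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.emit = nn.Linear(2 * hidden_dim, num_tags)                   # per-token tag scores
        self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))  # [prev_tag, next_tag]

    def emissions(self, tokens):
        """tokens: (seq,) -> emission scores (seq, num_tags)."""
        out, _ = self.lstm(self.embedding(tokens).unsqueeze(0))
        return self.emit(out).squeeze(0)

    def score(self, emissions, tags):
        """Total score of one tag sequence: emissions + transitions."""
        s = emissions[0, tags[0]]
        for t in range(1, len(tags)):
            s = s + self.transitions[tags[t - 1], tags[t]] + emissions[t, tags[t]]
        return s

    def viterbi(self, emissions):
        """Highest-scoring tag sequence via dynamic programming."""
        best = emissions[0]  # best score of any path ending in each tag at t = 0
        backpointers = []
        for t in range(1, emissions.size(0)):
            # cand[i, j]: best path ending in tag i at t-1, then moving to tag j at t
            cand = best.unsqueeze(1) + self.transitions + emissions[t].unsqueeze(0)
            best, idx = cand.max(dim=0)
            backpointers.append(idx)
        tags = [int(best.argmax())]
        for idx in reversed(backpointers):  # follow pointers back to t = 0
            tags.append(int(idx[tags[-1]]))
        return tags[::-1]

torch.manual_seed(0)
model = BiLSTMCRFTagger(vocab_size=50, embed_dim=8, hidden_dim=16, num_tags=4)
sentence = torch.tensor([3, 17, 42, 8, 5])

with torch.no_grad():
    em = model.emissions(sentence)
    path = model.viterbi(em)
    # Brute force over all 4**5 tag sequences to confirm Viterbi's answer
    brute = max(itertools.product(range(4), repeat=5),
                key=lambda tags: model.score(em, tags).item())

print("Viterbi:", path, " brute force:", list(brute))
```

Viterbi costs O(T * K^2) for T tokens and K tags, versus O(K^T) for the brute-force search, which is only feasible here because the example is tiny.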

### Stack Depth Guidelines

| Layers | Use Case |
|--------|----------|
| 1 | Simple tasks, limited data |
| 2-3 | Most NLP tasks |
| 4+ | Large data, add residuals |

### Tips
- Use **dropout between layers** (on each layer's outputs, not on the recurrent hidden-to-hidden connections)
- Add **residual connections** for deep stacks
- **Layer normalization** helps stability
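The dropout tip maps directly onto `nn.LSTM`'s `dropout` argument. Per the PyTorch documentation, it applies dropout to the outputs of every layer except the last and never to the recurrent path, and PyTorch warns when it is set with `num_layers=1` because it would then have no effect. A small check; sizes and the dropout rate are arbitrary.

```python
import warnings
import torch
import torch.nn as nn

# 1) With a single layer, dropout does nothing, and PyTorch warns about it
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    nn.LSTM(16, 32, num_layers=1, dropout=0.3)
print("warning raised:", any("dropout" in str(w.message) for w in caught))

# 2) With several layers, dropout between layers is active only in training mode
torch.manual_seed(0)
lstm = nn.LSTM(16, 32, num_layers=2, dropout=0.5, batch_first=True)
seq = torch.randn(2, 6, 16)

lstm.train()
with torch.no_grad():
    train_a, _ = lstm(seq)
    train_b, _ = lstm(seq)  # a fresh dropout mask on each call

lstm.eval()
with torch.no_grad():
    eval_a, _ = lstm(seq)
    eval_b, _ = lstm(seq)   # dropout disabled: deterministic

print("train outputs differ:", not torch.allclose(train_a, train_b))
print("eval outputs identical:", torch.allclose(eval_a, eval_b))
```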

## 8. Interview Questions

**Q1: When would you use a bidirectional RNN?**
<details><summary>Answer</summary>

- Classification tasks (have full sequence)
- Sequence labeling (NER, POS)
- NOT for generation (don't have future tokens)
- NOT for streaming (can't wait for full input)
</details>

**Q2: Why use residual connections in deep RNNs?**
<details><summary>Answer</summary>

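- Gives gradients a direct path through the addition, so they reach lower layers without passing through every LSTM
- A layer that is not useful can learn to contribute little, leaving the residual path to carry the signal
- Makes optimization of deep stacks easier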
</details>

## 9. Summary

- **Bidirectional**: Forward + backward passes; outputs concatenated (or summed/averaged)
- **Use bidirectional**: Classification, labeling, understanding
- **Don't use bidirectional**: Generation, streaming
- **Deep RNNs**: Stack layers with dropout
- **Residual connections**: Strongly recommended for 4+ layers

## 10. References

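- [Graves, Jaitly & Mohamed (2013): Hybrid Speech Recognition with Deep Bidirectional LSTM](https://www.cs.toronto.edu/~graves/asru_2013.pdf)
- [He et al. (2015): Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)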

---
**Next:** [Module 09: Text Classification with RNNs](../09_text_classification_rnns/09_text_classification_rnns.ipynb)