# Day 16: Attention Mechanisms

Learn attention mechanisms that revolutionized NLP and paved the way for Transformers.

## Learning Objectives
- Understand the information bottleneck problem in sequence models
- Implement attention mechanisms (Bahdanau and Luong)
- Apply attention to seq2seq models
- Visualize and analyze attention patterns

## Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

## Part 1: Understanding Attention

### The Problem: Information Bottleneck
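In a plain seq2seq model, the encoder compresses the entire input sequence into a single fixed-size context vector, and the decoder sees nothing else.
This becomes problematic for long sequences: the vector's capacity does not grow with the input, so details of the input (especially early tokens) are easily lost.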

In [None]:
# Visualize the bottleneck problem: one fixed context vector vs. attention
def _draw_box_row(ax, x_positions, y, facecolor, label_prefix):
    """Draw a row of 0.8 x 0.8 outlined boxes labelled ``<prefix><index>``.

    Args:
        ax: matplotlib Axes to draw on.
        x_positions: left edge of each box.
        y: bottom edge shared by every box in the row.
        facecolor: fill colour of the boxes.
        label_prefix: text placed before the box index in each label.
    """
    for idx, x in enumerate(x_positions):
        # facecolor, not color: a Patch's `color` sets BOTH face and edge
        # colours, which would discard edgecolor='black' (and emit a warning).
        ax.add_patch(plt.Rectangle((x, y), 0.8, 0.8, facecolor=facecolor, edgecolor='black'))
        ax.text(x + 0.4, y + 0.4, f'{label_prefix}{idx}', ha='center', va='center')


encoder_xs = range(5)
# Unit spacing keeps the 0.8-wide decoder boxes from overlapping each other.
decoder_xs = (1, 2, 3)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Without attention
ax = axes[0]
ax.set_xlim(-1, 10)
ax.set_ylim(-1, 5)
# Draw encoder boxes
_draw_box_row(ax, encoder_xs, 3, 'lightblue', 'x')
# Draw bottleneck
ax.add_patch(plt.Rectangle((2, 1.5), 1, 0.8, facecolor='red', edgecolor='black', linewidth=2))
ax.text(2.5, 1.9, 'c', ha='center', va='center', fontsize=12, fontweight='bold')
# Draw decoder boxes
_draw_box_row(ax, decoder_xs, 0, 'lightgreen', 'y')
ax.set_title('Without Attention (Bottleneck)', fontweight='bold')
ax.axis('off')

# With attention
ax = axes[1]
ax.set_xlim(-1, 10)
ax.set_ylim(-1, 5)
# Draw encoder boxes
_draw_box_row(ax, encoder_xs, 3, 'lightblue', 'x')
# Draw attention arrows: every encoder box points at the top centre of the
# decoder row, so the arrows converge instead of running parallel.
target_x, target_y = 2.4, 0.9
for x in encoder_xs:
    start_x, start_y = x + 0.4, 2.9
    ax.arrow(start_x, start_y, target_x - start_x, target_y - start_y,
             head_width=0.2, head_length=0.1, fc='orange', ec='orange', alpha=0.6,
             length_includes_head=True)  # arrow tip stops exactly at the target
# Draw decoder boxes
_draw_box_row(ax, decoder_xs, 0, 'lightgreen', 'y')
ax.set_title('With Attention (Dynamic Context)', fontweight='bold')
ax.axis('off')

plt.tight_layout()
plt.savefig('attention_bottleneck.png', dpi=100, bbox_inches='tight')
plt.show()

print('Attention allows the decoder to focus on different parts of the input at each step!')

## Part 2: Implement Attention Mechanisms

In [None]:
class BahdanauAttention(nn.Module):
    """Additive (Bahdanau) attention.

    Each encoder position is scored with ``v^T tanh(W_q q + W_k k)``; the
    scores are softmax-normalised over the sequence and used to average
    ``values`` into a context vector.
    """

    def __init__(self, hidden_size):
        """Build the bias-free projections used by the additive score.

        Args:
            hidden_size: feature size shared by query, keys and values.
        """
        super().__init__()
        self.hidden_size = hidden_size

        # Linear layers for attention
        self.W_q = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_k = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, query, keys, values, mask=None):
        """Attend over ``keys`` with a single query per batch element.

        Args:
            query: (batch_size, 1, hidden_size) - decoder hidden state
            keys: (batch_size, seq_len, hidden_size) - encoder hidden states
            values: (batch_size, seq_len, hidden_size) - encoder hidden states
            mask: optional (batch_size, seq_len) padding mask; positions where
                ``mask == 0`` receive zero attention weight. A row that is
                masked everywhere produces NaN weights (softmax over all -inf).

        Returns:
            context: (batch_size, 1, hidden_size) - context vector
            attention_weights: (batch_size, seq_len) - weights that sum to 1
                along dim 1
        """
        # Project the query and the keys into the shared scoring space
        Q = self.W_q(query)  # (batch_size, 1, hidden_size)
        K = self.W_k(keys)   # (batch_size, seq_len, hidden_size)

        # Score = v^T * tanh(Q + K); Q broadcasts across the seq_len axis
        scores = self.v(torch.tanh(Q + K))  # (batch_size, seq_len, 1)
        scores = scores.squeeze(-1)  # (batch_size, seq_len)

        # -inf scores become exactly zero weight after the softmax
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Normalise over the source positions
        attention_weights = torch.softmax(scores, dim=1)  # (batch_size, seq_len)

        # Weighted sum of values
        context = torch.bmm(attention_weights.unsqueeze(1), values)  # (batch_size, 1, hidden_size)

        return context, attention_weights

print('✓ BahdanauAttention (Additive) implemented')

In [None]:
class LuongAttention(nn.Module):
    """Multiplicative (Luong "general") attention.

    Each encoder position is scored with ``q^T W k``; the scores are
    softmax-normalised over the sequence and used to average ``values``
    into a context vector.
    """

    def __init__(self, hidden_size):
        """Build the bias-free bilinear projection ``W``.

        Args:
            hidden_size: feature size shared by query, keys and values.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.W = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, query, keys, values, mask=None):
        """Attend over ``keys`` with a single query per batch element.

        Args:
            query: (batch_size, 1, hidden_size)
            keys: (batch_size, seq_len, hidden_size)
            values: (batch_size, seq_len, hidden_size)
            mask: optional (batch_size, seq_len) padding mask; positions where
                ``mask == 0`` receive zero attention weight. A row that is
                masked everywhere produces NaN weights (softmax over all -inf).

        Returns:
            context: (batch_size, 1, hidden_size)
            attention_weights: (batch_size, seq_len) - weights that sum to 1
                along dim 1
        """
        # Score = Q * W * K^T
        Q = self.W(query)  # (batch_size, 1, hidden_size)
        scores = torch.bmm(Q, keys.transpose(1, 2))  # (batch_size, 1, seq_len)
        scores = scores.squeeze(1)  # (batch_size, seq_len)

        # -inf scores become exactly zero weight after the softmax
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Normalise over the source positions
        attention_weights = torch.softmax(scores, dim=1)  # (batch_size, seq_len)

        # Weighted sum of values
        context = torch.bmm(attention_weights.unsqueeze(1), values)  # (batch_size, 1, hidden_size)

        return context, attention_weights

print('✓ LuongAttention (Multiplicative) implemented')

In [ ]:
class ScaledDotProductAttention(nn.Module):
    """Scaled Dot-Product Attention (used in Transformers).

    Scores are ``q k^T / sqrt(hidden_size)``, softmax-normalised over the key
    positions and used to average ``values``. The module has no learnable
    parameters.
    """

    def __init__(self, hidden_size):
        """Store the scaling factor ``sqrt(hidden_size)``.

        Args:
            hidden_size: feature size of the query/key vectors.
        """
        super().__init__()
        self.hidden_size = hidden_size
        # Kept as a plain Python float rather than a NumPy scalar
        self.scale = float(np.sqrt(hidden_size))

    @staticmethod
    def _align_mask(mask, scores):
        """Make a per-batch ``(batch_size, seq_len)`` mask broadcast correctly.

        A 2-D mask is right-aligned by broadcasting, so against
        ``(batch, q_len, seq_len)`` scores its first axis would be matched with
        ``q_len`` instead of ``batch``. When the first axis equals the batch
        size and cannot be a valid ``q_len`` axis, singleton axes are inserted
        so the mask lines up with the batch dimension. Masks that already
        broadcast as ``(q_len, seq_len)`` or have full rank are returned
        unchanged.
        """
        if mask.dim() == 2 and scores.dim() >= 3:
            lead = mask.shape[0]
            if lead == scores.shape[0] and lead not in (1, scores.shape[-2]):
                middle = [1] * (scores.dim() - 2)
                return mask.reshape(lead, *middle, mask.shape[1])
        return mask

    def forward(self, query, keys, values, mask=None):
        """Attend over ``keys`` with one or more queries.

        Args:
            query: (batch_size, ..., hidden_size)
            keys: (batch_size, seq_len, hidden_size)
            values: (batch_size, seq_len, hidden_size)
            mask: optional mask; positions where ``mask == 0`` receive zero
                attention weight. Accepts anything broadcastable to the score
                shape, plus a ``(batch_size, seq_len)`` padding mask. If
                ``batch_size`` equals the number of queries a 2-D mask is
                ambiguous and is read as ``(q_len, seq_len)``; pass
                ``(batch_size, 1, seq_len)`` to be explicit.

        Returns:
            context: (batch_size, ..., hidden_size)
            attention_weights: (batch_size, ..., seq_len) - for a
                (batch_size, 1, hidden_size) query this is
                (batch_size, 1, seq_len)
        """
        # Compute scaled dot product
        scores = torch.matmul(query, keys.transpose(-2, -1)) / self.scale

        # -inf scores become exactly zero weight after the softmax
        if mask is not None:
            mask = self._align_mask(mask, scores)
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Normalise over the key positions
        attention_weights = torch.softmax(scores, dim=-1)

        # Weighted sum of values
        context = torch.matmul(attention_weights, values)

        return context, attention_weights

print('✓ ScaledDotProductAttention implemented')

## Part 3: Test Attention Mechanisms

In [None]:
# Toy problem dimensions
batch_size, seq_len, hidden_size = 4, 10, 64

# Random stand-ins: one decoder query per example, plus encoder states that
# serve as keys and values (drawn in that order).
query = torch.randn(batch_size, 1, hidden_size).to(device)
keys, values = (
    torch.randn(batch_size, seq_len, hidden_size).to(device)
    for _ in range(2)
)

# Instantiate one module per attention variant
bahdanau, luong, scaled = (
    attention_cls(hidden_size).to(device)
    for attention_cls in (BahdanauAttention, LuongAttention, ScaledDotProductAttention)
)

# Run every variant on the same inputs
(context_b, weights_b), (context_l, weights_l), (context_s, weights_s) = (
    module(query, keys, values) for module in (bahdanau, luong, scaled)
)

# Report the Bahdanau outputs as a sanity check
print(f'Context shape: {context_b.shape}')
print(f'Attention weights shape: {weights_b.shape}')
print(f'Attention weights sum: {weights_b.sum(dim=1)}')
print('✓ All attention mechanisms working!')

## Part 4: Visualize Attention Patterns

In [None]:
# Side-by-side heatmaps of the first example's attention weights, one panel
# per mechanism.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# (weights tensor, panel title, colour map) for each mechanism, left to right
panels = [
    (weights_b, 'Bahdanau Attention', 'Blues'),
    (weights_l, 'Luong Attention', 'Greens'),
    (weights_s, 'Scaled Dot-Product Attention', 'Reds'),
]

for panel_ax, (weights, title, cmap) in zip(axes, panels):
    # Flatten the first batch element to a single heatmap row
    first_example = weights[0].detach().cpu().numpy().reshape(1, -1)
    sns.heatmap(first_example, ax=panel_ax, cbar=True, cmap=cmap)
    panel_ax.set_title(title)
    panel_ax.set_xlabel('Source Sequence Position')

plt.tight_layout()
plt.savefig('attention_comparison.png', dpi=100, bbox_inches='tight')
plt.show()

print('✓ Attention patterns visualized')

## Part 5: Analysis and Conclusions

In [None]:
# Print a comparison of the three mechanisms followed by the key takeaways.
divider = "=" * 60

# (heading, bullet lines) for each mechanism, in display order
mechanism_sections = (
    ("\n1. BAHDANAU (ADDITIVE) ATTENTION", (
        "   - Score = v^T * tanh(W_q*q + W_k*k)",
        "   - Computational cost: O(seq_len * hidden_size)",
        "   - Good for shorter sequences",
    )),
    ("\n2. LUONG (MULTIPLICATIVE) ATTENTION", (
        "   - Score = q^T * W * k",
        "   - Computational cost: O(seq_len * hidden_size)",
        "   - More efficient matrix operations",
    )),
    ("\n3. SCALED DOT-PRODUCT ATTENTION", (
        "   - Score = (q * k^T) / sqrt(d_k)",
        "   - Scaling prevents gradient vanishing",
        "   - Foundation of Transformer attention",
    )),
)

key_insights = (
    "✓ Attention solves the information bottleneck problem",
    "✓ Decoder can focus on relevant input parts",
    "✓ Different mechanisms have different computational costs",
    "✓ Scaled dot-product is most efficient (matrix operations)",
    "✓ Attention weights are interpretable and visualizable",
)

print(divider)
print("ATTENTION MECHANISMS COMPARISON")
print(divider)

for heading, bullets in mechanism_sections:
    print(heading)
    for bullet in bullets:
        print(bullet)

print("\n" + divider)
print("KEY INSIGHTS")
print(divider)
for insight in key_insights:
    print(insight)