# Lesson 6 Exercise: Visualize Attention Patterns in Pretrained Models

**Estimated Time:** 18 minutes

## Overview

Dive deep into what pretrained Transformers have learned! You'll extract attention weights from multiple layers and heads, visualize them with the provided plotting utilities, and look for evidence that **different attention heads specialize in different linguistic phenomena** (syntax, coreference, semantic similarity).

## Scenario

Before deploying BERT for content moderation, your ML team wants to understand what it's actually learning. Visualize attention patterns to check whether the model focuses on contextually relevant words when detecting toxic language. This **interpretability analysis can help build trust** in the model's decision-making.

---

## What You'll Do

**Part A:** Extract and Visualize Basic Attention (6 min)
- Load BERT and tokenizer
- Process example text with ambiguous "bank"
- Examine position IDs (how BERT knows word order)
- Extract attention weights from all layers
- Plot attention heatmap for Layer 6, Head 0

**Part B:** Compare Multiple Attention Heads (6 min)
- Visualize 4 different attention heads from same layer
- Create 2×2 subplot grid
- Observe: Do different heads show different patterns?
- Hypothesis: Some focus on nearby words, others on syntax

**Part C:** Analyze Layer-by-Layer Attention (6 min)
- Pick one attention head (Head 0)
- Visualize across Layers 1, 4, 8, 12
- Compare: Do early layers differ from late layers?
- Hypothesis to test: early layers = syntax and local patterns, late layers = semantics

---

## Setup

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoModel, AutoTokenizer
import warnings
warnings.filterwarnings('ignore')

# Import utilities
from data import load_jigsaw_dataset
from attention_utils import (
    plot_attention_heatmap,
    plot_multihead_attention,
    plot_layer_progression,
    extract_attention_patterns,
    print_attention_patterns
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print("\n✓ Setup complete!")

---

# Part A: Extract and Visualize Basic Attention (6 minutes)

## Step 1: Load BERT and Tokenizer

**TODO:** Load the pretrained BERT model and tokenizer.

In [None]:
print("Loading BERT model and tokenizer...\n")

# TODO: Load BERT model and tokenizer
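# Hint: Use 'bert-base-uncased'
# Hint: Set output_attentions=True to get attention weights
# Hint: If outputs.attentions later comes back empty or None, also pass
#       attn_implementation="eager" to from_pretrained()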

model_name = 'bert-base-uncased'

# TODO: Load tokenizer
# tokenizer = ???

# TODO: Load model with output_attentions=True
# model = ???

# TODO: Move model to device and set to eval mode
# model.to(device)
# model.eval()

print("✓ BERT loaded successfully!")
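Once the model is loaded, it helps to know what sizes to expect before looking at any attention maps. The values below are the published `bert-base-uncased` configuration (12 layers, 12 heads per layer, hidden size 768); on your loaded model you can confirm them via `model.config.num_hidden_layers`, `model.config.num_attention_heads` and `model.config.hidden_size`. `expected_attention_shapes` is a hypothetical helper for this sketch, not part of `attention_utils`.

```python
# Published bert-base-uncased sizes (compare with model.config on your loaded model)
N_LAYERS = 12    # model.config.num_hidden_layers
N_HEADS = 12     # model.config.num_attention_heads
HIDDEN = 768     # model.config.hidden_size

def expected_attention_shapes(seq_len, batch_size=1,
                              n_layers=N_LAYERS, n_heads=N_HEADS):
    """Hypothetical helper: the shapes outputs.attentions should contain."""
    return [(batch_size, n_heads, seq_len, seq_len) for _ in range(n_layers)]

head_dim = HIDDEN // N_HEADS          # each head works in a 64-dim subspace
total_heads = N_LAYERS * N_HEADS      # separate attention maps per input

print(f"Head dimension: {head_dim}")
print(f"Attention maps per input: {total_heads}")
print(expected_attention_shapes(seq_len=17)[0])
```

With 144 attention maps per sentence, this exercise only samples a few of them, so any pattern you see is a hint rather than a full picture.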

## Step 2: Process Example Text

We'll use the classic example: **"The bank can refuse to lend money to the person by the river bank."**

**TODO:** Tokenize the example and prepare inputs.

In [None]:
print("Loading example text for attention visualization...\n")

# TODO: Load the toxic comment dataset
# Hint: Use load_jigsaw_dataset(n_samples=1000)
# comments = ???

# TODO: Select a good example for visualization
# Hint: Find comments with good length (10-30 tokens) for attention visualization
# good_examples = [c for c in comments if 10 < c['length'] < 30]
# OR use this fallback example:
example_text = "The bank can refuse to lend money to the person by the river bank."

print("Example Text:")
print(f"  '{example_text}'\n")

# TODO: Tokenize the text
# Hint: Use tokenizer(example_text, return_tensors='pt', add_special_tokens=True)
# inputs = ???

# TODO: Move inputs to device
# inputs = {k: v.to(device) for k, v in inputs.items()}

# TODO: Get token strings
# Hint: Use tokenizer.convert_ids_to_tokens() on inputs['input_ids'][0]
# tokens = ???

print("\nTokenization:")
print(f"  Tokens: {tokens}")
print(f"  Number of tokens: {len(tokens)}")
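To answer the "bank" questions later, you need the index of each occurrence in `tokens`. A minimal sketch follows; the token list is what we expect `bert-base-uncased` to produce for the example sentence (assuming every word is a single WordPiece), so compare it with your printed tokens before relying on the indices. `find_token_positions` is a hypothetical helper.

```python
def find_token_positions(tokens, word):
    """Return every index where `word` appears as a whole token."""
    return [i for i, tok in enumerate(tokens) if tok == word]

# Expected bert-base-uncased tokens for the example (verify against your output)
tokens_example = [
    "[CLS]", "the", "bank", "can", "refuse", "to", "lend", "money", "to",
    "the", "person", "by", "the", "river", "bank", ".", "[SEP]",
]

bank_positions = find_token_positions(tokens_example, "bank")
print(bank_positions)  # [2, 14]
```

Note that whole-token matching misses words split into `##` sub-pieces, so for other sentences you may need to match the first piece of a word instead.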

## Step 3: Examine Position IDs

**TODO:** Display the position IDs. BERT maps each ID to a learned position embedding, which is how the model knows word order.

In [None]:
print("Positional Encoding:")

# TODO: Extract or create position_ids
# Hint: If 'position_ids' not in inputs, create with torch.arange()
if 'position_ids' in inputs:
    position_ids = inputs['position_ids'][0].tolist()
else:
    # TODO: Create position_ids
    seq_len = inputs['input_ids'].shape[1]
    # position_ids = ???
    pass

print(f"  Position IDs: {position_ids}")
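print("\n💡 Key Insight:")
print("   These position IDs index BERT's learned position embeddings,")
print("   which are added to the token embeddings to encode word order.")
print("   Transformers need this because, unlike RNNs, they process all tokens in parallel!")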
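Why does BERT need position IDs at all? Self-attention on its own is permutation-equivariant: shuffle the input vectors and the outputs are shuffled in exactly the same way, so word order is invisible to it. The NumPy sketch below checks this for a single head with random weights and no position information; it illustrates the property and is not BERT's actual layer.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, Wq, Wk, Wv):
    """Single-head scaled dot-product self-attention with no position information."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    weights = softmax(Q @ K.T / np.sqrt(K.shape[-1]))
    return weights @ V

rng = np.random.default_rng(0)
n_tokens, d_model, d_head = 6, 8, 4
X = rng.normal(size=(n_tokens, d_model))          # token embeddings, no positions
Wq, Wk, Wv = (rng.normal(size=(d_model, d_head)) for _ in range(3))

perm = rng.permutation(n_tokens)                  # shuffle the "words"
out = self_attention(X, Wq, Wk, Wv)
out_shuffled = self_attention(X[perm], Wq, Wk, Wv)

# Shuffling the input only shuffles the output rows the same way
print(np.allclose(out_shuffled, out[perm]))       # True

# Position IDs are simply 0..n_tokens-1; BERT looks each one up in a learned table
demo_position_ids = np.arange(n_tokens)
print(demo_position_ids)
```

Adding a distinct position vector to each row of `X` breaks this symmetry, which is what BERT's position embeddings do.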

## Step 4: Extract Attention Weights

**TODO:** Run the model and extract attention weights from all layers.

In [None]:
print("Running model to extract attention weights...\n")

# TODO: Run model forward pass
# Hint: Use torch.no_grad() and model(**inputs)
with torch.no_grad():
    # outputs = ???
    pass

# TODO: Extract attention weights
# Hint: outputs.attentions is a tuple of attention tensors
# attention_weights = ???

print("Attention Structure:")
print(f"  Number of layers: {len(attention_weights)}")
print(f"  Shape per layer: {attention_weights[0].shape}")
print("  Format: (batch_size, n_heads, seq_len, seq_len)")
print("\n✓ Attention weights extracted!")
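Each tensor in `outputs.attentions` holds the post-softmax weights $\mathrm{softmax}(QK^\top/\sqrt{d_k})$ for every head. This NumPy sketch builds weights in the same `(batch_size, n_heads, seq_len, seq_len)` layout from random query/key projections (all sizes and weights are made up for illustration) and checks a property you should also see in BERT's tensors: every row sums to 1, because each query's attention is a probability distribution over keys.

```python
import numpy as np

def multihead_attention_weights(X, n_heads, rng):
    """Return attention weights shaped (batch, n_heads, seq_len, seq_len).

    X has shape (batch, seq_len, d_model); the projections are random
    stand-ins for the learned query/key matrices of a real model.
    """
    batch, seq_len, d_model = X.shape
    d_head = d_model // n_heads
    Wq = rng.normal(scale=d_model ** -0.5, size=(d_model, d_model))
    Wk = rng.normal(scale=d_model ** -0.5, size=(d_model, d_model))
    # Project, then split the last dimension into heads: (batch, n_heads, seq_len, d_head)
    Q = (X @ Wq).reshape(batch, seq_len, n_heads, d_head).transpose(0, 2, 1, 3)
    K = (X @ Wk).reshape(batch, seq_len, n_heads, d_head).transpose(0, 2, 1, 3)
    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_head)  # (batch, heads, seq, seq)
    scores -= scores.max(axis=-1, keepdims=True)            # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)    # softmax over keys

rng = np.random.default_rng(42)
X = rng.normal(size=(1, 17, 48))        # 1 sentence, 17 tokens, toy d_model=48
attn = multihead_attention_weights(X, n_heads=12, rng=rng)

print(attn.shape)                              # (1, 12, 17, 17)
print(np.allclose(attn.sum(axis=-1), 1.0))     # True: each row is a distribution
```

Running `attention_weights[0][0].sum(dim=-1)` on the real output should likewise give values close to 1.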

## Step 5: Plot Attention Heatmap for Layer 6, Head 0

**TODO:** Visualize attention weights for a specific layer and head.

In [None]:
# Select Layer 6, Head 0
layer_idx = 5  # 0-indexed, so layer 6 = index 5
head_idx = 0

print(f"Visualizing Layer {layer_idx+1}, Head {head_idx}...\n")

# TODO: Extract attention matrix for this layer and head
# Hint: attention_weights[layer_idx][0] gives (n_heads, seq_len, seq_len)
# Hint: Then select head_idx and convert with .cpu().numpy()
# attn_layer = ???
# attn_head = ???

# TODO: Plot attention heatmap
# Hint: Use plot_attention_heatmap() from attention_utils
# fig = plot_attention_heatmap(
#     attn_head,
#     tokens,
#     title="Your Attention Visualization",
#     layer=layer_idx+1,
#     head=head_idx,
#     figsize=(12, 10)
# )
# plt.show()
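To answer "Is there a diagonal pattern?" with a number rather than by eye, you can average the weight each token puts on itself (offset 0), on the previous token (offset -1) and on the next token (offset +1). The sketch below uses a random stand-in named `fake_attentions` with the same structure as `outputs.attentions` (a tuple of 12 arrays shaped `(1, 12, seq_len, seq_len)`); with the real model, convert the selected head with `.cpu().numpy()` first. `offset_profile` is a hypothetical helper, not part of `attention_utils`.

```python
import numpy as np

def offset_profile(attn_head, offsets=(-1, 0, 1)):
    """Mean attention weight on the token at each relative offset.

    attn_head: (seq_len, seq_len) array, rows = queries, columns = keys.
    np.diagonal(..., offset=k) picks the entries [i, i + k].
    """
    return {k: float(np.diagonal(attn_head, offset=k).mean()) for k in offsets}

# Random stand-in with the same structure as outputs.attentions
rng = np.random.default_rng(0)
n_tokens = 17
raw = rng.random((12, 1, 12, n_tokens, n_tokens))
fake_attentions = tuple(layer / layer.sum(axis=-1, keepdims=True) for layer in raw)

demo_layer_idx, demo_head_idx = 5, 0                      # Layer 6, Head 0
demo_head = fake_attentions[demo_layer_idx][0, demo_head_idx]  # drop batch, pick head
print(demo_head.shape)                                    # (17, 17)
print(offset_profile(demo_head))
```

On random data all three values are similar; a head with a strong diagonal or "previous token" habit will show one offset clearly dominating.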

### Analysis Questions (Part A)

Look at your attention heatmap and answer:

1. **What patterns do you see?**
   - Is there a diagonal pattern (tokens attending to themselves)?
   - Which tokens attend strongly to each other?

2. **"Bank" disambiguation:**
   - Where does the first "bank" (financial) attend?
   - Where does the second "bank" (river) attend?
   - Are they attending to different words?
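For the "bank" questions, a ranked list of the most-attended tokens is often easier to read than a heatmap. Many BERT heads are known to place a large share of their weight on `[CLS]`, `[SEP]` or punctuation, so this sketch skips special tokens by default. It assumes `attn_head` is a `(seq_len, seq_len)` NumPy array for one layer and head and `tokens` is the matching token list; `top_attended` is a hypothetical helper, and the tiny hand-written matrix only demonstrates the call.

```python
import numpy as np

def top_attended(attn_head, tokens, query_idx, k=3, skip=("[CLS]", "[SEP]")):
    """Return the k (token, weight) pairs that query_idx attends to most.

    Tokens listed in `skip` are ignored so special tokens don't crowd the list.
    """
    row = attn_head[query_idx]
    order = np.argsort(row)[::-1]                  # key indices, highest weight first
    picked = [(tokens[j], float(row[j])) for j in order if tokens[j] not in skip]
    return picked[:k]

# Toy 4-token example: "bank" (index 2) attends mostly to "river"
demo_tokens = ["[CLS]", "river", "bank", "[SEP]"]
demo_attn = np.array([
    [0.70, 0.10, 0.10, 0.10],
    [0.20, 0.40, 0.30, 0.10],
    [0.05, 0.60, 0.25, 0.10],
    [0.25, 0.25, 0.25, 0.25],
])
print(top_attended(demo_attn, demo_tokens, query_idx=2, k=2))
# [('river', 0.6), ('bank', 0.25)]
```

With your real `attn_head` and `tokens`, call it once for each position where `tokens` contains "bank" and compare the two lists.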

**TODO:** Write your observations (2-3 sentences):

**YOUR ANALYSIS HERE:**

I observe that...




---

# Part B: Compare Multiple Attention Heads (6 minutes)

Now let's see if different attention heads learn different patterns!

## TODO: Visualize 4 Different Heads

Create a 2×2 grid comparing Heads 0, 3, 7, and 11 from Layer 6.

In [None]:
layer_idx = 5  # Layer 6
heads_to_compare = [0, 3, 7, 11]

print(f"Comparing {len(heads_to_compare)} attention heads from Layer {layer_idx+1}...\n")

# TODO: Extract attention for all heads in this layer
# Hint: attention_weights[layer_idx][0].cpu().numpy() gives (n_heads, seq_len, seq_len)
# attn_layer = ???

# TODO: Plot multi-head comparison
# Hint: Use plot_multihead_attention() from attention_utils
# fig = ???
# plt.show()

### Analysis Questions (Part B)

Compare the 4 attention heads and answer:

1. **Do all heads show the same pattern?**
   - Or do you see diversity?

2. **Can you characterize each head?**
   - Head 0: Local or long-range?
   - Head 3: What relationships does it capture?
   - Head 7: Semantic or syntactic?
   - Head 11: Dense or sparse connections?

3. **Hypothesis:**
   - Do you think different heads specialize in different linguistic phenomena?
   - What evidence supports this?
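One way to put a number on "local or long-range?" is the mean attention distance: for each query, average $|i - j|$ weighted by the attention it pays to key $j$, then average over queries. Low values mean a head looks at neighbors; high values mean it reaches across the sentence. A minimal sketch, assuming `attn_layer` is the `(n_heads, seq_len, seq_len)` NumPy array from the cell above; `mean_attention_distance` is a hypothetical helper, and the two toy heads just show the extremes.

```python
import numpy as np

def mean_attention_distance(attn_layer):
    """Per-head mean |query - key| distance, weighted by attention.

    attn_layer: (n_heads, seq_len, seq_len) with rows summing to 1.
    Returns an array of shape (n_heads,).
    """
    seq_len = attn_layer.shape[-1]
    idx = np.arange(seq_len)
    dist = np.abs(idx[:, None] - idx[None, :])      # (seq_len, seq_len) matrix of |i - j|
    per_query = (attn_layer * dist).sum(axis=-1)    # (n_heads, seq_len)
    return per_query.mean(axis=-1)

n = 5
local_head = np.eye(n)                     # every token attends only to itself
uniform_head = np.full((n, n), 1.0 / n)    # every token spreads attention evenly
toy_layer = np.stack([local_head, uniform_head])
print(mean_attention_distance(toy_layer))
```

On the real layer, `mean_attention_distance(attn_layer)[heads_to_compare]` gives one number per head you plotted, which you can set beside your visual impressions.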

**TODO:** Write your observations (3-4 sentences):

**YOUR ANALYSIS HERE:**

Comparing the four heads, I notice...




---

# Part C: Analyze Layer-by-Layer Attention (6 minutes)

Do early layers learn different patterns than late layers?

**Hypothesis:** Early layers capture syntax and local patterns, late layers capture semantics and long-range dependencies.

## TODO: Visualize Head 0 Across Multiple Layers

Compare Layers 1, 4, 8, and 12 (all using Head 0).

In [None]:
head_idx = 0
layers_to_compare = [0, 3, 7, 11]  # Layers 1, 4, 8, 12 (0-indexed)

print(f"Comparing Head {head_idx} across multiple layers...\n")

# TODO: Extract attention weights for each layer (Head 0 only)
# Hint: Build a list with one (seq_len, seq_len) Head-0 matrix per layer (all 12 layers)
# attention_by_layer = []
# for layer_idx in range(len(attention_weights)):
#     attn_head = ???  # select batch 0 and head_idx, then .cpu().numpy()
#     attention_by_layer.append(attn_head)

# TODO: Plot layer progression
# Hint: Use plot_layer_progression() from attention_utils
# fig = ???
# plt.show()

### Analysis Questions (Part C)

Compare attention patterns across layers and answer:

1. **Early layers (Layer 1, 4):**
   - Do they focus on nearby tokens (local patterns)?
   - Or long-range connections?
   - What might they be capturing? (syntax? word order?)

2. **Late layers (Layer 8, 12):**
   - Do they show more diffuse attention?
   - Do they connect distant tokens?
   - What might they be capturing? (semantics? meaning?)

3. **Hierarchical learning:**
   - Does your observation support the hypothesis that Transformers learn hierarchically?
   - Early = low-level patterns, Late = high-level meaning?
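To check "more diffuse attention?" across layers, you can compute the entropy of each query's attention distribution: near 0 means the query focuses on a single token, while $\log(\text{seq\_len})$ means it spreads evenly. The sketch assumes `attention_by_layer` is a list of `(seq_len, seq_len)` NumPy arrays (Head 0 for each layer); the random stand-in `demo_by_layer` is there only so the block runs on its own. `attention_entropy` is a hypothetical helper. Compare the numbers against the hypothesis rather than assuming it holds.

```python
import numpy as np

def attention_entropy(attn_head, eps=1e-12):
    """Mean entropy (in nats) of the rows of a (seq_len, seq_len) attention matrix."""
    p = np.clip(attn_head, eps, 1.0)                 # avoid log(0)
    return float(-(attn_head * np.log(p)).sum(axis=-1).mean())

# Stand-in for attention_by_layer: 12 random row-normalized (17, 17) matrices
rng = np.random.default_rng(1)
demo_by_layer = []
for _ in range(12):
    m = rng.random((17, 17)) ** 8                    # skewed so rows are more peaked than uniform
    demo_by_layer.append(m / m.sum(axis=-1, keepdims=True))

demo_layers = [0, 3, 7, 11]                          # Layers 1, 4, 8, 12 (0-indexed)
max_entropy = np.log(17)                             # entropy of perfectly uniform attention
for layer in demo_layers:
    h = attention_entropy(demo_by_layer[layer])
    print(f"Layer {layer + 1:2d}: mean entropy {h:.2f} nats (max {max_entropy:.2f})")
```

Swap `demo_by_layer` for your real `attention_by_layer` and `max_entropy` for `np.log(len(tokens))` to get the numbers for the example sentence.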

**TODO:** Write your observations (3-4 sentences):

**YOUR ANALYSIS HERE:**

Looking at the layer progression, I observe...




---

# Summary: What Did You Discover?

## Key Findings

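From your visualizations, you may have found evidence for:

### 1. **Multi-Head Diversity**

Heads in the same layer can show clearly different patterns, such as attending to the token itself, to the previous or next token, or broadly across the sentence.

### 2. **Hierarchical Learning**

Attention patterns change with depth. Compare what you saw against the hypothesis that early layers capture local, syntactic structure and later layers capture longer-range, semantic relations.

### 3. **Contextualized Understanding**

The two occurrences of "bank" can attend to different context words (for example "lend" and "money" versus "river"), one hint that BERT builds context-dependent representations rather than a single meaning per word.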

---

## Why This Matters

**For Content Moderation:**
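- You can check whether the model attends to relevant words rather than arbitrary patterns (keeping in mind that attention weights are only a partial explanation of a prediction)
- Interpretability analysis can help build trust in the system
- You can spot problematic attention patterns that are worth investigating further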

**For Deep Learning:**
- Attention visualization reveals what models learn
- Different layers capture different levels of abstraction
- Multi-head attention provides multiple perspectives

**For Modern NLP:**
- This architecture powers GPT, BERT, ChatGPT, Claude
- Understanding attention is key to understanding LLMs
- Interpretability is critical for responsible AI