# Attention Ablation

This notebook explores two aspects of transformer model training: optimizer performance and the role of the self-attention mechanism. We will:
1. Compare the `AdamW` and `SGD` optimizers on a small fine-tuning task.
2. Implement and evaluate an "attention ablation" to understand the impact of learned attention on model performance.

## Setup

First, let's import the necessary libraries and functions from our `finetune.py` script.

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os

# Assuming finetune.py is in the same directory or in the python path
from finetune import (
    build_tiny_dataset,
    collate_pad,
    train_short_run,
    eval_perplexity,
)

In [None]:


# --- Attention Ablation Code ---
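def attention_ablation_hook(module, inputs, output):
    # Forward hook: runs after the attention module, and its return value
    # replaces the module's output. GPT-2's attention returns a tuple whose
    # first element is the attention output; the rest is passed through.
    if isinstance(output, tuple):
        return (torch.zeros_like(output[0]),) + tuple(output[1:])
    return torch.zeros_like(output)

def apply_attention_ablation(model):
    # Zero the output of every attention block, leaving only the residual
    # path and the MLPs. Returns the hook handles so the ablation can be undone.
    handles = []
    for layer in model.transformer.h:
        # The actual attribute name may vary depending on the model architecture.
        # For GPT-2, it's layer.attn.
        # You may need to inspect the model to find the correct attribute.
        handles.append(layer.attn.register_forward_hook(attention_ablation_hook))
    return handles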

def train_ablation_run(
    base_model_name: str,
    tokenizer: AutoTokenizer,
    device: str,
    tiny_dataset,
    val_loader,
    run_name: str = "ablation",
    epochs: int = 5,
    batch_size: int = 4,
    lr: float = 1e-4,
    save_dir: str = "outputs",
):
    os.makedirs(save_dir, exist_ok=True)
    model = AutoModelForCausalLM.from_pretrained(base_model_name).to(device)
    apply_attention_ablation(model)
    model.train()

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.resize_token_embeddings(len(tokenizer))

    loader = DataLoader(tiny_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda b: collate_pad(b, pad_id=tokenizer.pad_token_id))
    optim = torch.optim.AdamW(model.parameters(), lr=lr)

    losses = []
    perplexities = []
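    for epoch in range(epochs):
        # Re-enable training mode each epoch in case eval_perplexity left the model in eval mode.
        model.train()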
        for input_ids, attention_mask in loader:
            input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
            loss = outputs.loss
            loss.backward()
            optim.step()
            optim.zero_grad()
            losses.append(loss.item())
        
        epoch_perplexity = eval_perplexity(model, val_loader, device, tokenizer.pad_token_id)
        perplexities.append(epoch_perplexity)
        print(f"{run_name} epoch {epoch}: val_perplexity={epoch_perplexity:.4f}")
    
    return {"loss_values": losses, "perplexities": perplexities}


# --- Main Execution ---
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "gpt2"
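tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 ships without a pad token; reuse EOS so collate_pad always gets a valid pad_id.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token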

tiny_dataset, _ = build_tiny_dataset(tokenizer, split="train")
val_dataset, _ = build_tiny_dataset(tokenizer, split="validation")
val_loader = DataLoader(val_dataset, batch_size=4, collate_fn=lambda b: collate_pad(b, pad_id=tokenizer.pad_token_id))

# --- Optimizer Comparison ---
print("Running optimizer comparison...")
adamw_res = train_short_run(
    base_model_name=model_name,
    tokenizer=tokenizer,
    device=device,
    tiny_dataset=tiny_dataset,
    val_loader=val_loader,
    run_name="adamw_notebook",
    epochs=5,
    optimizer_name="AdamW",
)

sgd_res = train_short_run(
    base_model_name=model_name,
    tokenizer=tokenizer,
    device=device,
    tiny_dataset=tiny_dataset,
    val_loader=val_loader,
    run_name="sgd_notebook",
    epochs=5,
    optimizer_name="SGD",
)

# Plotting optimizer results
plt.figure(figsize=(10, 5))
plt.plot(adamw_res["loss_values"], label="AdamW")
plt.plot(sgd_res["loss_values"], label="SGD")
plt.xlabel("Iteration")
plt.ylabel("Training Loss")
plt.title("Optimizer Comparison")
plt.legend()
plt.grid(True)
plt.show()


# --- Attention Ablation ---
print("\nRunning attention ablation...")
ablation_res = train_ablation_run(
    base_model_name=model_name,
    tokenizer=tokenizer,
    device=device,
    tiny_dataset=tiny_dataset,
    val_loader=val_loader,
)

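# Compare with a normal (non-ablated) run. The comparison is only fair if
# train_short_run uses the same settings as train_ablation_run (AdamW, lr=1e-4).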
print("\nRunning baseline for ablation comparison...")
baseline_res = train_short_run(
    base_model_name=model_name,
    tokenizer=tokenizer,
    device=device,
    tiny_dataset=tiny_dataset,
    val_loader=val_loader,
    run_name="baseline_notebook",
    epochs=5,
)


# Plotting ablation results
plt.figure(figsize=(10, 5))
plt.plot(baseline_res["perplexities"], label="Baseline (with attention)")
plt.plot(ablation_res["perplexities"], label="Attention Ablated")
plt.xlabel("Epoch")
plt.ylabel("Validation Perplexity")
plt.title("Attention Ablation Study")
plt.legend()
plt.grid(True)
plt.show()

## Analysis and Conclusion

**Optimizer Comparison:**
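- *Observations:* Over a run this short, we expect AdamW to reduce the training loss faster and more smoothly than SGD, whose loss curve may fluctuate more. Check this expectation against the "Optimizer Comparison" plot.
- *Conclusion:* AdamW is generally the safer default for fine-tuning transformer language models. It scales each parameter's step using running estimates of the gradient's first and second moments, which makes it less sensitive to the choice of learning rate than plain SGD.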
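
To make the adaptive-learning-rate argument concrete, here is a minimal sketch on a toy problem rather than on the language model. The quadratic loss, the learning rates, and the helper name `run_optimizer` are all illustrative assumptions, not part of `finetune.py`. The loss is steep along one coordinate and nearly flat along the other, so SGD's single learning rate has to stay small for the steep direction and barely moves along the flat one, while AdamW rescales each coordinate separately.

```python
import torch


def run_optimizer(make_optimizer, steps=200):
    """Minimize a toy ill-conditioned quadratic and return the final point."""
    # Steep along x[0] (curvature 100), nearly flat along x[1] (curvature 0.01).
    curvature = torch.tensor([100.0, 0.01])
    x = torch.nn.Parameter(torch.tensor([1.0, 1.0]))
    optimizer = make_optimizer([x])
    for _ in range(steps):
        optimizer.zero_grad()
        loss = 0.5 * (curvature * x ** 2).sum()
        loss.backward()
        optimizer.step()
    return x.detach().clone()


# SGD: lr must stay below 2 / 100 to remain stable along the steep direction.
sgd_x = run_optimizer(lambda params: torch.optim.SGD(params, lr=0.005))
# AdamW: weight decay disabled so only the adaptive step size is compared.
adamw_x = run_optimizer(
    lambda params: torch.optim.AdamW(params, lr=0.05, weight_decay=0.0)
)

print("SGD   final x:", sgd_x.tolist())
print("AdamW final x:", adamw_x.tolist())
```

With these settings SGD leaves `x[1]` almost at its starting value, because its per-step shrink factor along that coordinate is `1 - 0.005 * 0.01`. A real fine-tuning run is far noisier than this toy, so treat it as intuition only.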

**Attention Ablation:**
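- *Observations:* We expect the ablated model to reach a significantly higher validation perplexity than the baseline. With attention ablated, each position's prediction can depend only on its own token and position embeddings, passed through the residual path and the MLPs, so the model cannot use any context from other tokens.
- *Conclusion:* If the plot confirms this, the experiment demonstrates the critical role of the self-attention mechanism in language modeling: without it, fine-tuning alone does not recover the lost performance.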
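
The claim that an ablated model cannot use context can be checked directly on a small stand-in. The sketch below is an assumption-laden illustration: `ToyBlock` is a hypothetical pre-norm transformer block built from `torch.nn` modules, not GPT-2, and `zero_attention_output` is a hypothetical hook name. It replaces the token at position 0 and measures how much the outputs at the other positions change, with and without a forward hook that zeroes the attention output.

```python
import torch
from torch import nn


class ToyBlock(nn.Module):
    """Hypothetical pre-norm block: x + attn(ln(x)), then x + mlp(ln(x))."""

    def __init__(self, dim=16, heads=2):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x):
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        return x + self.mlp(self.ln2(x))


def zero_attention_output(module, inputs, output):
    # A forward hook's return value replaces the module's output.
    return (torch.zeros_like(output[0]),) + tuple(output[1:])


torch.manual_seed(0)
block = ToyBlock().eval()
x = torch.randn(1, 4, 16)          # (batch, positions, features)
x_changed = x.clone()
x_changed[0, 0] = torch.randn(16)  # replace only the token at position 0

with torch.no_grad():
    # Largest change at positions 1..3 caused by editing position 0.
    diff_with_attention = (block(x) - block(x_changed))[0, 1:].abs().max().item()
    handle = block.attn.register_forward_hook(zero_attention_output)
    diff_ablated = (block(x) - block(x_changed))[0, 1:].abs().max().item()
    handle.remove()

print("with attention:", diff_with_attention)
print("ablated:       ", diff_ablated)
```

With attention active, editing position 0 changes the outputs at the other positions. With the hook installed, every remaining operation acts on each position independently, so those outputs do not change at all.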