# Goal of the study

This notebook explores two aspects of transformer model training: optimizer performance and the role of the self-attention mechanism. We will:
1. Compare the `AdamW` and `SGD` optimizers on a small fine-tuning task.
2. Implement and evaluate an "attention ablation", using PyTorch hooks, to measure how much learned attention contributes to model performance.

## Setup

First, let's import the necessary libraries and functions from our `src/finetune.py` script.

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os

from src.finetune import (
    build_tiny_dataset,
    collate_pad,
    train_short_run,
    eval_perplexity,
)
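`collate_pad` comes from `src/finetune.py`, whose source is not shown here. The sketch below is a rough, hypothetical picture of what a padding collate function for causal-LM fine-tuning usually does. It right-pads token-id lists, builds an attention mask, and sets padded label positions to `-100`, the value that PyTorch's cross-entropy loss and Hugging Face models ignore by default. It assumes each dataset item is a plain list of token ids. `pad_collate` is our own name, not the script's.

```python
import torch

def pad_collate(batch, pad_id):
    """Hypothetical helper: right-pad lists of token ids into a batch dict."""
    max_len = max(len(ids) for ids in batch)
    input_ids = torch.full((len(batch), max_len), pad_id, dtype=torch.long)
    attention_mask = torch.zeros((len(batch), max_len), dtype=torch.long)
    for i, ids in enumerate(batch):
        input_ids[i, : len(ids)] = torch.tensor(ids, dtype=torch.long)
        attention_mask[i, : len(ids)] = 1  # 1 = real token, 0 = padding
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100     # padded targets are ignored by the loss
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

# 50256 is GPT-2's EOS id, commonly reused as the padding id
batch = pad_collate([[5, 6, 7], [8]], pad_id=50256)
print(batch["input_ids"])
print(batch["labels"])
```

Masking the labels matters with GPT-2 in particular. Its tokenizer has no dedicated padding token, and the usual workaround reuses EOS as padding, so unmasked pads would train the model to emit EOS over and over.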

## Attention Ablation

### Hooks and their types

| Hook Function                 | Target        | Execution Time                 | Signature                   | Can Modify                                   |
| ----------------------------- | ------------- | ------------------------------ | --------------------------- | -------------------------------------------- |
| `register_forward_pre_hook`   | `nn.Module`   | Before `forward()`             | `(module, input)`           | Module's input                               |
| `register_forward_hook`       | `nn.Module`   | After `forward()`              | `(module, input, output)`   | Module's output                              |
| `register_full_backward_hook` | `nn.Module`   | During backward pass           | `(module, grad_in, grad_out)` | Gradient w.r.t. module's input (`grad_in`)   |
| `register_hook`               | `torch.Tensor`| When grad for that tensor is computed | `(grad)`                    | The tensor's gradient (`grad`)                |
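The following sketch makes the table concrete. It registers all four hook types on a single `nn.Linear` layer, a toy stand-in rather than GPT-2, and records the order in which they fire. Each hook also changes something: the pre-hook doubles the input, the forward hook adds 1 to the output, and the tensor hook halves the incoming gradient. Together they exercise every entry in the "Can Modify" column.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(3, 2)
events = []  # order in which the hooks fire

def pre_hook(module, args):
    # args is the tuple of positional inputs; returning a tuple replaces it
    events.append("forward_pre")
    (x,) = args
    return (x * 2.0,)

def fwd_hook(module, args, output):
    # returning a tensor replaces the module's output
    events.append("forward")
    return output + 1.0

def bwd_hook(module, grad_input, grad_output):
    # returning None leaves grad_input unchanged
    events.append("full_backward")
    return None

def tensor_hook(grad):
    # runs when the gradient w.r.t. this tensor is computed; the return value replaces it
    events.append("tensor")
    return grad * 0.5

handles = [
    layer.register_forward_pre_hook(pre_hook),
    layer.register_forward_hook(fwd_hook),
    layer.register_full_backward_hook(bwd_hook),
]

x = torch.ones(1, 3, requires_grad=True)
out = layer(x)               # (2 * x) @ W.T + b + 1
out.register_hook(tensor_hook)
out.sum().backward()         # x.grad = 0.5 * 2 * (ones @ W)

for h in handles:
    h.remove()
print(events)
```

Each `register_*` call returns a handle, and `handle.remove()` detaches the hook. This matters in a notebook, because re-running a cell that registers hooks would otherwise stack duplicate hooks on the same module.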

### Usefulness of Hooks
1. Inspecting intermediate activations and gradients while debugging
2. Logging a model's internals
3. Modifying internal state (such as activations and gradients) without changing the model's source code
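The sketch below illustrates the logging use case by capturing every `nn.Linear` output, keyed by module name. `capture_activations` is a hypothetical helper written for this example, and the toy `nn.Sequential` model stands in for GPT-2.

```python
import torch
import torch.nn as nn

def capture_activations(model: nn.Module, layer_types=(nn.Linear,)):
    """Hypothetical helper: store each matching submodule's output by name."""
    store, handles = {}, []
    for name, module in model.named_modules():
        if isinstance(module, layer_types):
            def hook(mod, args, output, name=name):  # bind this layer's name now
                store[name] = output.detach()
            handles.append(module.register_forward_hook(hook))
    return store, handles

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
acts, handles = capture_activations(model)
with torch.no_grad():
    model(torch.randn(5, 4))
for h in handles:
    h.remove()  # stop logging once done

for name, a in acts.items():
    print(name, tuple(a.shape))
```

Binding `name=name` as a default argument gives each closure its own layer name. Without it, every hook would write under the last name seen in the loop.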

In [None]:
from src.attention_ablation import (
    attention_ablation_hook,
    apply_attention_ablation,
    train_ablation_run
)
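The source of `attention_ablation_hook` is not shown in this notebook. The sketch below shows one possible way to ablate attention with a forward hook: it zeroes the output of the attention sublayer in a toy pre-LayerNorm block. `ToyBlock` and `zero_attention_hook` are hypothetical names, and zeroing is our assumption; forcing uniform attention weights is another common ablation. With the hook attached, a token's output should no longer depend on any other token. The check at the end tests this by replacing only the first token.

```python
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Minimal pre-LayerNorm transformer block (illustrative, not GPT-2)."""
    def __init__(self, d_model: int = 16, n_heads: int = 2):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x):
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out                  # residual around attention
        return x + self.mlp(self.ln2(x))  # residual around the MLP

def zero_attention_hook(module, args, output):
    # nn.MultiheadAttention returns (attn_output, attn_weights): keep the tuple
    # shape but replace the token-mixing output with zeros
    attn_output, attn_weights = output
    return torch.zeros_like(attn_output), attn_weights

torch.manual_seed(0)
block = ToyBlock().eval()
x = torch.randn(1, 4, 16)
x_perturbed = x.clone()
x_perturbed[:, 0] = torch.randn(16)  # change only the first token

handle = block.attn.register_forward_hook(zero_attention_hook)
with torch.no_grad():
    ablated_delta = (block(x)[:, 1:] - block(x_perturbed)[:, 1:]).abs().max().item()
handle.remove()
with torch.no_grad():
    normal_delta = (block(x)[:, 1:] - block(x_perturbed)[:, 1:]).abs().max().item()

print(f"effect of token 0 on later tokens: ablated={ablated_delta:.2e}, normal={normal_delta:.2e}")
```

In Hugging Face's GPT-2, the corresponding submodules are `model.transformer.h[i].attn`. The same idea applies there, but the attention module returns a tuple whose layout varies across `transformers` versions, so a hook should replace only the first element (the attention output) and pass the rest through unchanged.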

In [None]:
# --- Main Execution ---
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "gpt2"
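tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2's tokenizer defines no padding token; reuse EOS so batches can be padded
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

tiny_dataset, _ = build_tiny_dataset(tokenizer, split="train")
val_dataset, _ = build_tiny_dataset(tokenizer, split="validation")
val_loader = DataLoader(
    val_dataset,
    batch_size=4,
    collate_fn=lambda b: collate_pad(b, pad_id=tokenizer.pad_token_id),
)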

# --- Optimizer Comparison ---
print("Running optimizer comparison...")
adamw_res = train_short_run(
    base_model_name=model_name,
    tokenizer=tokenizer,
    device=device,
    tiny_dataset=tiny_dataset,
    val_loader=val_loader,
    run_name="adamw_notebook",
    epochs=5,
    optimizer_name="AdamW",
)

sgd_res = train_short_run(
    base_model_name=model_name,
    tokenizer=tokenizer,
    device=device,
    tiny_dataset=tiny_dataset,
    val_loader=val_loader,
    run_name="sgd_notebook",
    epochs=5,
    optimizer_name="SGD",
)

# Plotting optimizer results
plt.figure(figsize=(10, 5))
plt.plot(adamw_res["loss_values"], label="AdamW")
plt.plot(sgd_res["loss_values"], label="SGD")
plt.xlabel("Iteration")
plt.ylabel("Training Loss")
plt.title("Optimizer Comparison")
plt.legend()
plt.grid(True)
plt.show()
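`train_short_run` picks its optimizer from `optimizer_name`, but its internals are not shown here. The mapping below is a minimal hypothetical sketch. `make_optimizer` is our own name, and the SGD momentum and the weight-decay default are our assumptions.

```python
import torch
import torch.nn as nn

def make_optimizer(model: nn.Module, name: str, lr: float, weight_decay: float = 0.01):
    """Hypothetical helper: map an optimizer name to a torch.optim instance."""
    params = [p for p in model.parameters() if p.requires_grad]
    if name == "AdamW":
        # decoupled weight decay, applied directly to the weights
        return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
    if name == "SGD":
        # weight_decay here is an L2 term added to the gradient; momentum is assumed
        return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=weight_decay)
    raise ValueError(f"unknown optimizer: {name}")

toy = nn.Linear(4, 1)
adamw = make_optimizer(toy, "AdamW", lr=5e-5)
sgd = make_optimizer(toy, "SGD", lr=1e-2)
print(type(adamw).__name__, type(sgd).__name__)
```

As the comments note, `weight_decay` means different things in the two optimizers. The two also typically need very different learning rates, so a fair comparison tunes `lr` separately for each.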

In [None]:
# --- Attention Ablation ---
print("\nRunning attention ablation...")
ablation_res = train_ablation_run(
    base_model_name=model_name,
    tokenizer=tokenizer,
    device=device,
    tiny_dataset=tiny_dataset,
    val_loader=val_loader,
)

# Compare with a normal run
print("\nRunning baseline for ablation comparison...")
baseline_res = train_short_run(
    base_model_name=model_name,
    tokenizer=tokenizer,
    device=device,
    tiny_dataset=tiny_dataset,
    val_loader=val_loader,
    run_name="baseline_notebook",
    epochs=5,
)


# Plotting ablation results
plt.figure(figsize=(10, 5))
plt.plot(baseline_res["perplexities"], label="Baseline (with attention)")
plt.plot(ablation_res["perplexities"], label="Attention Ablated")
plt.xlabel("Epoch")
plt.ylabel("Validation Perplexity")
plt.title("Attention Ablation Study")
plt.legend()
plt.grid(True)
plt.show()
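The ablation plot compares validation perplexity, which is the exponential of the mean per-token negative log-likelihood. The sketch below is a hypothetical version of that computation. `perplexity_from_logits` is our own name, and `eval_perplexity` may differ, for example in how it aggregates across batches. The function shifts logits and labels so that position t predicts token t+1 and skips targets labelled `-100`. Hugging Face causal-LM heads such as GPT-2's perform this shift internally when given `labels`.

```python
import math
import torch
import torch.nn.functional as F

def perplexity_from_logits(logits, labels, ignore_index=-100):
    """exp(mean next-token NLL), counting only non-ignored target tokens."""
    # position t predicts token t+1, as in causal LM training
    shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
    shift_labels = labels[:, 1:].reshape(-1)
    nll_sum = F.cross_entropy(
        shift_logits, shift_labels, ignore_index=ignore_index, reduction="sum"
    )
    n_tokens = (shift_labels != ignore_index).sum()
    return math.exp(nll_sum.item() / n_tokens.item())

# Uniform logits over a 5-token vocabulary should give perplexity 5
torch.manual_seed(0)
labels = torch.randint(0, 5, (2, 4))
labels[0, -1] = -100  # e.g. a padded position
ppl = perplexity_from_logits(torch.zeros(2, 4, 5), labels)
print(round(ppl, 4))
```

Across a full validation set, sum the NLL and the token counts over all batches before exponentiating. Averaging per-batch perplexities instead weights every batch equally, regardless of how many tokens it contains.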

## Analysis and Conclusion

**Optimizer Comparison:**
- *Observations:* In our short run, AdamW converges faster and more stably than SGD, whose training loss may fluctuate more.
- *Conclusion:* AdamW is generally a better choice than SGD for training and fine-tuning transformer language models, largely because it adapts the step size of each parameter to the history of its gradients rather than using one global learning rate.
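AdamW's adaptivity comes from its per-parameter step size. Consider a parameter $\theta$ with gradient $g_t$ at step $t$, learning rate $\eta$, decay rates $\beta_1, \beta_2$, a small constant $\epsilon$ and weight decay $\lambda$. Without AMSGrad, PyTorch's `torch.optim.AdamW` performs the following update, with all operations elementwise:

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 \\
\hat m_t &= \frac{m_t}{1-\beta_1^t}, \qquad \hat v_t = \frac{v_t}{1-\beta_2^t} \\
\theta_t &= \theta_{t-1} - \eta\left(\frac{\hat m_t}{\sqrt{\hat v_t}+\epsilon} + \lambda\,\theta_{t-1}\right)
\end{aligned}
$$

Plain SGD, by contrast, applies $\theta_t = \theta_{t-1} - \eta\, g_t$ with the same $\eta$ for every parameter. Dividing by $\sqrt{\hat v_t}$ rescales each coordinate by its recent gradient magnitude. As a result, parameters with small but consistent gradients still move at a useful rate, while parameters with large gradients do not overshoot.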

**Attention Ablation:**
- *Observations:* The model with ablated attention reaches a substantially higher perplexity. This is expected: without attention, the model cannot learn context-dependent relationships between tokens. However, even without self-attention, the ablated model's perplexity still decreases over training, indicating that the token representations themselves continue to improve.
- *Conclusion:* This experiment demonstrates the critical role of the self-attention mechanism in language modeling. Without it, the model's performance degrades substantially.