# 45-minute hands-on: Fine-tune a Transformer with 🤗 Trainer + peek at Attention

<a target="_blank" href="https://colab.research.google.com/github/uitml/GenerativeAI_course/blob/main/tasks/Train_attention.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

**Goal (what you'll do):**
1. Load a dataset (GLUE MRPC)  
2. Tokenize it, build a data collator  
3. Fine-tune a pretrained Transformer using the **Trainer API**  
4. Evaluate accuracy/F1  
5. **Optional:** visualize one attention head for a single example

**Assumptions:** You have already had a lecture introducing Transformers/attention.  
**Run on:** Colab with GPU (T4/L4/A100).  

> Tip: If something fails, restart runtime and rerun from the top.


## Instructor notes (suggested timing)

- **0–5 min:** Setup + GPU check  
- **5–12 min:** Load dataset + create subset (discuss why subsets + epochs matter)  
- **12–20 min:** Tokenization + padding/collator  
- **20–35 min:** Trainer fine-tune (students watch loss/metrics)  
- **35–40 min:** Quick predictions + interpretation  
- **40–45 min:** Optional attention heatmap (connect back to lecture)

**Common pitfalls**
- Forgot GPU → training is slow  
- Library version conflicts → restart runtime, rerun install cell  
- If `evaluate.load` downloads slowly, wait once; it caches afterward


## 0) Setup

### (A) Turn on GPU
In Colab: **Runtime → Change runtime type → Hardware accelerator → GPU**

### (B) Install libraries
We set *reasonable* minimum versions for a small set of libraries and keep the install light.


In [None]:
# Install (quiet-ish). If you see version conflicts, restart runtime and rerun this cell.
!pip -q install -U "transformers>=4.41" "datasets>=2.20" "evaluate>=0.4.2" "accelerate>=0.33" "torch>=2.1" "numpy" "matplotlib"


In [None]:
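import os, random, numpy as np, torch

# Seed Python, NumPy, and PyTorch so the randomly initialized classifier head is repeatable
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)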

print("torch:", torch.__version__)
import transformers, datasets, evaluate, accelerate, matplotlib
print("transformers:", transformers.__version__)
print("datasets:", datasets.__version__)
print("evaluate:", evaluate.__version__)
print("accelerate:", accelerate.__version__)

# Sanity check: GPU available?
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)
if device != "cuda":
    print("⚠️ No GPU detected. Training will be slow. Turn on GPU in Colab and restart runtime.")


## 1) Load a dataset (GLUE MRPC)

MRPC (Microsoft Research Paraphrase Corpus): sentence-pair paraphrase classification. Small enough for a short class, but "real".


In [None]:
from datasets import load_dataset

raw = load_dataset("glue", "mrpc")
raw


### Keep it fast: use a small subset for training

For a 45-minute class, we'll train on a subset so **everyone finishes**.
You can increase `train_size` (up to the full 3,668 MRPC training pairs) if you have stronger GPUs or more time.


In [None]:
train_size = 2000   # MRPC train has 3,668 pairs; raise toward that if you have time
eval_size  = 500    # MRPC validation has only 408 pairs, so min() below caps this at 408

train_ds = raw["train"].shuffle(seed=42).select(range(min(train_size, len(raw["train"]))))
eval_ds  = raw["validation"].shuffle(seed=42).select(range(min(eval_size, len(raw["validation"]))))

len(train_ds), len(eval_ds), train_ds[0]
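
To make the subset-size and epoch discussion concrete, here is a rough sketch of how many optimizer steps a run takes. The helper name `optimizer_steps` and the batch size of 16 are assumptions for illustration, and the arithmetic assumes a single device with no gradient accumulation.

```python
import math

def optimizer_steps(num_examples: int, batch_size: int, epochs: int = 1) -> int:
    """Optimizer steps for a run, assuming one device and no gradient accumulation."""
    steps_per_epoch = math.ceil(num_examples / batch_size)  # the last batch may be smaller
    return steps_per_epoch * epochs

subset_steps = optimizer_steps(2000, batch_size=16)  # the 2,000-pair subset
full_steps = optimizer_steps(3668, batch_size=16)    # all 3,668 MRPC training pairs
print(subset_steps, full_steps)
```

Doubling the number of epochs doubles the step count, so subset size and epoch count trade off directly against class time.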


## 2) Tokenize + data collator

We'll use a compact model for speed. `distilbert-base-uncased` is a good default.
(You can swap to `bert-base-uncased` if you want a heavier baseline.)


In [None]:
from transformers import AutoTokenizer, DataCollatorWithPadding

checkpoint = "distilbert-base-uncased"   # fast
# checkpoint = "bert-base-uncased"      # slower, but classic

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

def tokenize_batch(batch):
    return tokenizer(batch["sentence1"], batch["sentence2"], truncation=True)

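# Drop only the raw text/index columns; keep "label" so Trainer can compute the loss
drop_cols = ["sentence1", "sentence2", "idx"]
tok_train = train_ds.map(tokenize_batch, batched=True, remove_columns=drop_cols)
tok_eval  = eval_ds.map(tokenize_batch, batched=True, remove_columns=drop_cols)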

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
tok_train, tok_eval
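
`DataCollatorWithPadding` pads each batch only up to the length of its longest sequence (dynamic padding) rather than to one global maximum. The pure-Python sketch below illustrates that idea only; `pad_batch` is a hypothetical helper, not the library's implementation, and the token IDs and `pad_id=0` are assumed example values.

```python
def pad_batch(features, pad_id=0):
    """Pad every sequence in the batch to the length of the longest one."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": []}
    for f in features:
        n_pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks padding the model should ignore
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_pad)
    return batch

# Made-up token IDs for two sequences of different lengths
features = [{"input_ids": [101, 7592, 102]}, {"input_ids": [101, 7592, 2088, 999, 102]}]
padded = pad_batch(features)
print(padded["input_ids"])
print(padded["attention_mask"])
```

Because padding is decided per batch, short batches stay short, which saves compute compared with padding everything to the model's maximum length.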


## 3) Define model + metrics

We'll use a sequence classification head (2 labels).
Metrics: **accuracy** and **F1** (standard for MRPC).


In [None]:
from transformers import AutoModelForSequenceClassification
import evaluate

model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

# Load metric once (faster + avoids repeated downloads inside compute_metrics)
mrpc_metric = evaluate.load("glue", "mrpc")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return mrpc_metric.compute(predictions=preds, references=labels)

model
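
For intuition, the two numbers reported for MRPC can be computed by hand. The sketch below uses made-up predictions and labels and treats label 1 (paraphrase) as the positive class; the helper name `accuracy_and_f1` is hypothetical and is not part of `evaluate`.

```python
def accuracy_and_f1(preds, labels):
    """Accuracy and binary F1, with label 1 as the positive class."""
    tp = sum(p == 1 and y == 1 for p, y in zip(preds, labels))
    fp = sum(p == 1 and y == 0 for p, y in zip(preds, labels))
    fn = sum(p == 0 and y == 1 for p, y in zip(preds, labels))
    accuracy = sum(p == y for p, y in zip(preds, labels)) / len(labels)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "f1": f1}

# Made-up example: 6 sentence pairs, 4 of them predicted correctly
preds  = [1, 1, 0, 1, 0, 1]
labels = [1, 0, 0, 1, 1, 1]
scores = accuracy_and_f1(preds, labels)
print(scores)
```

F1 is reported alongside accuracy because the two classes are not balanced, so a model that always predicts "paraphrase" can look better on accuracy than it deserves.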


## 4) Fine-tune with the Trainer API

This cell runs the full training loop: batching, forward, loss, backward, optimizer, evaluation.

We set conservative hyperparameters so it finishes quickly.


In [None]:
from transformers import Trainer, TrainingArguments

out_dir = "mrpc-distilbert-trainer"

training_args = TrainingArguments(
    output_dir=out_dir,
    eval_strategy="epoch",           # called evaluation_strategy before transformers 4.41
    save_strategy="no",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=1,
    weight_decay=0.01,
    logging_steps=50,
    report_to="none",
    fp16=torch.cuda.is_available(),  # mixed precision on GPU
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tok_train,
    eval_dataset=tok_eval,
    data_collator=data_collator,     # handles padding, so the deprecated tokenizer= argument is not needed
    compute_metrics=compute_metrics,
)

trainer.train()
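
To see what the stages of a training loop look like without the library, here is a toy sketch: logistic regression on eight made-up 1-D points, trained with plain mini-batch gradient descent. It is not what `Trainer` does internally (which adds AdamW, learning-rate scheduling, mixed precision, and more); every name and number in it is an illustrative assumption.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def mean_loss(w, b, xs, ys):
    """Average binary cross-entropy over a dataset."""
    total = 0.0
    for x, y in zip(xs, ys):
        p = sigmoid(w * x + b)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(xs)

# Made-up data: the label is 1 when x is positive
xs = [-2.0, -1.5, -1.0, -0.5, 0.5, 1.0, 1.5, 2.0]
ys = [0, 0, 0, 0, 1, 1, 1, 1]

w, b = 0.0, 0.0
lr, batch_size, epochs = 0.5, 4, 20
initial_loss = mean_loss(w, b, xs, ys)

for epoch in range(epochs):
    for start in range(0, len(xs), batch_size):           # batching
        bx, by = xs[start:start + batch_size], ys[start:start + batch_size]
        probs = [sigmoid(w * x + b) for x in bx]          # forward pass
        # backward pass: gradient of the mean cross-entropy w.r.t. w and b
        grad_w = sum((p - y) * x for p, y, x in zip(probs, by, bx)) / len(bx)
        grad_b = sum(p - y for p, y in zip(probs, by)) / len(bx)
        w -= lr * grad_w                                  # optimizer step (plain SGD)
        b -= lr * grad_b

final_loss = mean_loss(w, b, xs, ys)                      # evaluation
accuracy = sum((sigmoid(w * x + b) > 0.5) == bool(y) for x, y in zip(xs, ys)) / len(xs)
print(f"loss {initial_loss:.3f} -> {final_loss:.3f}, accuracy {accuracy:.2f}")
```

The same stages (batching, forward, loss, backward, optimizer step, evaluation) are what the Trainer cell automates for a model with millions of parameters.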


## 5) Evaluate and try the model

We'll compute metrics on the eval subset and run a few predictions.


In [None]:
eval_metrics = trainer.evaluate()
eval_metrics


In [None]:
import torch.nn.functional as F

def predict_pair(s1, s2):
    inputs = tokenizer(s1, s2, return_tensors="pt", truncation=True).to(trainer.model.device)
    with torch.no_grad():
        logits = trainer.model(**inputs).logits
        probs = F.softmax(logits, dim=-1).cpu().numpy().round(4)[0]
    label = int(np.argmax(probs))
    return {"pred_label": label, "p(not paraphrase)": float(probs[0]), "p(paraphrase)": float(probs[1])}

examples = [
    ("The company said it will cut costs.", "The firm announced cost reductions."),
    ("He likes pizza.", "The capital of France is Paris."),
]
for s1, s2 in examples:
    print(s1)
    print(s2)
    print(predict_pair(s1, s2))
    print("-" * 60)
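
`predict_pair` turns the model's two logits into probabilities with a softmax and then takes the argmax. A small pure-Python sketch of that step with made-up logits follows; the `softmax` helper is for illustration only (the notebook itself uses `torch.nn.functional.softmax`).

```python
import math

def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [-1.2, 2.3]  # made-up values: [not paraphrase, paraphrase]
probs = softmax(logits)
pred_label = max(range(len(probs)), key=probs.__getitem__)  # argmax
print(probs, pred_label)
```

The probabilities always sum to 1, so with two labels the prediction flips exactly when the two logits cross.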


## 6) Optional: Peek at attention (single example)

This is **not** how you "explain attention" (you already did that in lecture), but it helps students connect the idea
to something tangible: attention matrices per layer/head.

We'll:
1. Run a forward pass with `output_attentions=True`
2. Pick a layer/head
3. Plot attention weights

> Note: Attention matrices are over **token positions** (including special tokens).


In [None]:
import matplotlib.pyplot as plt

# Pick one example from eval set
ex = eval_ds[0]
s1, s2 = ex["sentence1"], ex["sentence2"]
print("Sentence1:", s1)
print("Sentence2:", s2)

inputs = tokenizer(s1, s2, return_tensors="pt", truncation=True)
inputs = {k: v.to(trainer.model.device) for k, v in inputs.items()}

# Force attentions
with torch.no_grad():
    outputs = trainer.model(**inputs, output_attentions=True)
attentions = outputs.attentions  # tuple of length num_layers; each tensor has shape (batch, num_heads, seq, seq)

len(attentions), attentions[0].shape


In [None]:
# Choose layer/head
layer_idx = 0
head_idx = 0

att = attentions[layer_idx][0, head_idx].detach().cpu().numpy()  # (seq, seq)
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].detach().cpu().tolist())

print("num_tokens:", len(tokens))
print(tokens[:20], "..." if len(tokens) > 20 else "")

plt.figure(figsize=(7, 6))
plt.imshow(att)
plt.title(f"Attention heatmap: layer {layer_idx}, head {head_idx}")
plt.xlabel("Key position")
plt.ylabel("Query position")
plt.tight_layout()
plt.show()
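
Each row of the heatmap is a softmax over key positions, so every row sums to 1. The pure-Python sketch below computes single-head scaled dot-product attention weights, softmax(QK^T / sqrt(d)), to show where such a matrix comes from. The vectors are made up, and real models add learned projections, multiple heads, and masking.

```python
import math

def attention_weights(queries, keys):
    """Return the (num_queries x num_keys) matrix softmax(Q K^T / sqrt(d))."""
    d = len(queries[0])
    weights = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        m = max(scores)                                  # stabilize the softmax
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

# Made-up 2-D vectors for three token positions (self-attention: Q and K from the same tokens)
vectors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
att_toy = attention_weights(vectors, vectors)
for row in att_toy:
    print([round(a, 3) for a in row], "row sum:", round(sum(row), 3))
```

Reading a row left to right tells you how much that query position draws from each key position, which is the same reading that applies to the plotted heatmap.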


## 7) Optional extension (if time): make it "more LLM"

If you have extra time, you can adapt this notebook to **causal language modeling** (next-token prediction) and train with
TRL's SFTTrainer + LoRA (parameter-efficient fine-tuning).  
This is a bigger jump in compute + dependencies, so we keep it as an after-class exercise.

- TRL notebooks: https://huggingface.co/docs/trl/en/example_overview  
- "SFT with LoRA/QLoRA" Colab: https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_trl_lora_qlora.ipynb
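
As a rough illustration of why LoRA counts as parameter-efficient: it freezes a weight matrix of shape d x k and trains two low-rank factors of shapes d x r and r x k instead. The sketch below only does the arithmetic; the 768 x 768 shape and rank 8 are assumed example values, not settings taken from the TRL notebooks, and `lora_param_counts` is a hypothetical helper.

```python
def lora_param_counts(d: int, k: int, r: int):
    """Trainable parameters for full fine-tuning vs. LoRA factors B (d x r) and A (r x k)."""
    full = d * k
    lora = r * (d + k)
    return full, lora

full, lora = lora_param_counts(d=768, k=768, r=8)
print(f"full: {full:,}  LoRA: {lora:,}  ratio: {lora / full:.2%}")
```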
