[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](YOUR_COLAB_LINK_HERE)

# 02 Vision Transformers Fundamentals
## Objectives
- Understand patch embeddings, positional encoding, and self-attention.
- Fine-tune a pretrained ViT on CIFAR-10 (the same recipe carries over to Tiny ImageNet).
- Visualize attention maps and a confusion matrix.


In [None]:
# Install dependencies quietly (-q). Note: torchvision and evaluate are not used by the cells below.
!pip -q install torch torchvision transformers datasets evaluate matplotlib

## Quick refresher: attention in one cell

In [None]:
import torch

# Toy scaled dot-product attention: 1 image, 4 heads, 8 tokens, 16 dims per head.
# A private generator keeps the demo reproducible without touching the global RNG.
_gen = torch.Generator().manual_seed(0)
batch_size, num_heads, num_tokens, head_dim = 1, 4, 8, 16
q = torch.randn(batch_size, num_heads, num_tokens, head_dim, generator=_gen)
k = torch.randn(batch_size, num_heads, num_tokens, head_dim, generator=_gen)
v = torch.randn(batch_size, num_heads, num_tokens, head_dim, generator=_gen)

# Every query scores every key; dividing by sqrt(head_dim) keeps the logits in a
# range where softmax does not saturate.
scores = q @ k.transpose(-2, -1) / head_dim ** 0.5
attn = scores.softmax(dim=-1)  # each row sums to 1: how much a token attends to every token
out = attn @ v                 # attention-weighted mix of the value vectors
print('Attention shape:', attn.shape)
print('Output shape:', out.shape)


## Fine-tune ViT with Hugging Face

In [None]:
from datasets import load_dataset
from transformers import AutoImageProcessor, ViTForImageClassification, Trainer, TrainingArguments
import numpy as np
import matplotlib.pyplot as plt
import torch

MODEL_NAME = 'google/vit-base-patch16-224-in21k'

dataset = load_dataset('cifar10')
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)


def preprocess(batch):
    """Convert a batch of raw CIFAR-10 rows into ViT inputs.

    `Dataset.with_transform` calls this lazily on *batches* (a dict mapping each
    column name to a list of values), even when a single row is indexed, so every
    returned value must also be a per-example sequence.

    Returns a dict with:
      - 'pixel_values': processor output tensor, one (3, H, W) image per example
      - 'labels': the class ids, under the key the model's loss reads
      - 'image': the RGB PIL images, kept so plotting code can show the originals
    """
    # The Hub's CIFAR-10 stores images under 'img'; accept 'image' for mirrors.
    images = batch['img'] if 'img' in batch else batch['image']
    images = [img.convert('RGB') for img in images]
    inputs = processor(images, return_tensors='pt')
    return {
        'pixel_values': inputs['pixel_values'],
        'labels': batch['label'],
        'image': images,
    }


dataset = dataset.with_transform(preprocess)
labels = dataset['train'].features['label'].names

model = ViTForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(labels),
    id2label=dict(enumerate(labels)),
    label2id={name: idx for idx, name in enumerate(labels)},
)

def collate_fn(batch):
    """Assemble a model-ready batch from a list of preprocessed examples.

    Each example must provide a 'pixel_values' tensor (all of identical shape)
    and a class-id 'labels' entry. Returns a dict whose 'pixel_values' are the
    images stacked along a new leading dimension and whose 'labels' is a 1-D
    tensor of the class ids.
    """
    images, targets = [], []
    for example in batch:
        images.append(example['pixel_values'])
        targets.append(example['labels'])
    return {'pixel_values': torch.stack(images), 'labels': torch.tensor(targets)}

def compute_metrics(eval_pred):
    """Report classification accuracy for the Trainer's evaluation loop.

    `eval_pred` unpacks into (logits, labels): a (n, num_classes) array of
    scores and the matching class ids. The predicted class is the arg-max score.
    """
    scores, references = eval_pred
    accuracy = np.mean(np.argmax(scores, axis=1) == references)
    return {'accuracy': accuracy}

import inspect

# `evaluation_strategy` was renamed to `eval_strategy` in recent transformers
# releases and the old name was later removed; use whichever keyword the
# installed version accepts so this cell runs on both old and new versions.
_eval_strategy_kw = (
    'eval_strategy'
    if 'eval_strategy' in inspect.signature(TrainingArguments).parameters
    else 'evaluation_strategy'
)

args = TrainingArguments(
    output_dir='../outputs/vit_finetune',
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    save_strategy='epoch',
    logging_steps=50,
    fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
    # Trainer otherwise drops dataset columns not named in model.forward, which
    # would remove the raw columns the on-the-fly transform reads.
    remove_unused_columns=False,
    **{_eval_strategy_kw: 'epoch'},
)

trainer = Trainer(
    model=model,
    args=args,
    # Small subsets keep the demo fast; see "Scale Up" for full-size training.
    train_dataset=dataset['train'].select(range(2000)),
    eval_dataset=dataset['test'].select(range(500)),
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)

trainer.train()


### Visualizations
We plot training curves, sample predictions, a confusion matrix, and an attention overlay.


In [None]:
# Keep only the periodic training-loss log entries (those with a 'loss' key) and
# plot them against the optimizer step at which they were recorded.
log_history = [entry for entry in trainer.state.log_history if 'loss' in entry]
steps = [entry['step'] for entry in log_history]
losses = [entry['loss'] for entry in log_history]
plt.figure(figsize=(5, 3))
plt.plot(steps, losses, marker='o')  # markers: only a handful of points are logged
plt.title('Training Loss (HF Trainer)')
plt.xlabel('Optimizer step')
plt.ylabel('Loss')
plt.show()

# Sample predictions on the first 10 test images.
model.eval()  # inference mode (disables dropout)
sample = dataset['test'].select(range(10))
# Index each row once: every access re-runs the image processor via the transform.
items = [sample[i] for i in range(len(sample))]
batch = collate_fn(items)
with torch.no_grad():
    # After training the model may live on the GPU; move the inputs to match it.
    logits = model(batch['pixel_values'].to(model.device))
preds = logits.logits.argmax(dim=1).cpu()
fig, axes = plt.subplots(2, 5, figsize=(8, 4))
for ax, item, pred in zip(axes.flat, items, preds.tolist()):
    ax.imshow(item['image'])
    ax.set_title(f'T:{labels[item["labels"]]} / P:{labels[pred]}', fontsize=8)
    ax.axis('off')
plt.suptitle('ViT Sample Predictions')
plt.tight_layout()
plt.show()

# Confusion matrix on a small test subset. trainer.predict batches the data with
# the configured collate_fn, places it on the model's device and runs without
# gradients, so no manual batching or device handling is needed here.
small = dataset['test'].select(range(200))
pred_output = trainer.predict(small)
y_pred = pred_output.predictions.argmax(axis=-1)
y_true = pred_output.label_ids
cm = np.zeros((len(labels), len(labels)), dtype=np.int64)
np.add.at(cm, (y_true, y_pred), 1)  # rows = true class, columns = predicted class
plt.figure(figsize=(6, 5))
plt.imshow(cm, cmap='Blues')
plt.title('Confusion Matrix (subset)')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.xticks(range(len(labels)), labels, rotation=45, ha='right')
plt.yticks(range(len(labels)), labels)
plt.colorbar()
plt.tight_layout()
plt.show()

# Attention overlay: how strongly the [CLS] token attends to each image patch in
# the last encoder layer, averaged over heads and upsampled to the image size.
model.eval()
image = sample[0]['image']
# Put the processor output on the same device as the (possibly GPU) model.
inputs = processor(image, return_tensors='pt').to(model.device)
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)
if not outputs.attentions:
    raise RuntimeError(
        'Model returned no attention weights; reload it with '
        "attn_implementation='eager' to visualize attention."
    )
attn = outputs.attentions[-1].mean(dim=1)[0]  # last layer, averaged over heads
attn_map = attn[0, 1:]  # [CLS] query row, dropping its attention to itself
side = int(attn_map.numel() ** 0.5)  # patch tokens form a square grid
attn_map = attn_map.reshape(side, side)
# PIL's size is (W, H) while interpolate expects (H, W); move to CPU for matplotlib.
attn_map = torch.nn.functional.interpolate(
    attn_map.unsqueeze(0).unsqueeze(0),
    size=image.size[::-1],
    mode='bilinear',
    align_corners=False,
)[0, 0].cpu()
plt.figure(figsize=(4, 4))
plt.imshow(image)
plt.imshow(attn_map, cmap='inferno', alpha=0.5)
plt.title('Attention Overlay')
plt.axis('off')
plt.show()


### Scale Up
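- Train for 5–20 epochs on the full 50,000-image training split (this demo uses 2,000 images for one epoch). All layers are already trainable here, so the whole network is fine-tuned.
- Add stronger augmentation (e.g., random resized crops and horizontal flips). The processor already upsamples CIFAR-10's 32×32 images to 224×224, the resolution this checkpoint was pretrained at; going higher requires interpolating the positional embeddings.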


### Summary
- ViTs tokenize images into patches and use self-attention.
- Positional embeddings encode spatial order.
- Pretrained models adapt quickly with fine-tuning.

### Exercises
1. Compare frozen backbone vs full fine-tune.
2. Try different learning rates and batch sizes.
3. Visualize attention rollout using model outputs.

### Further Reading
- https://arxiv.org/abs/2010.11929 (ViT)
- https://arxiv.org/abs/2012.12877 (DeiT)
- https://arxiv.org/abs/2103.14030 (Swin)
