[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](YOUR_COLAB_LINK_HERE)

# 03 Hugging Face Vision Fine-Tuning
## Objectives
- Use datasets + transformers + evaluate.
- Fine-tune Swin with `Trainer` (swapping in a ViT checkpoint is left as an exercise).
- Save checkpoints and visualize predictions.


In [None]:
# Install the packages this notebook imports; -q keeps pip's output quiet.
!pip -q install torch torchvision transformers datasets evaluate accelerate matplotlib

## GPU + mixed precision tips

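Enable a GPU runtime in Colab (Runtime → Change runtime type). The training cell below passes `fp16=torch.cuda.is_available()` to `TrainingArguments`, so mixed-precision training is switched on automatically whenever a GPU is present.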

In [None]:
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification, Trainer, TrainingArguments
import numpy as np
import torch
import matplotlib.pyplot as plt

# Load the dataset registered under the name 'cifar10'.
dataset = load_dataset('cifar10')
# Image processor for the same checkpoint that the classifier below is loaded
# from; `preprocess` uses it to turn images into `pixel_values`.
processor = AutoImageProcessor.from_pretrained('microsoft/swin-tiny-patch4-window7-224')

def preprocess(example):
    """Turn raw images into model-ready ``pixel_values`` and ``labels``.

    ``Dataset.with_transform`` invokes this lazily with a *batch*: a dict
    mapping each column name to a list of values (a list of length one when
    a single row is requested). A plain single-example dict is handled too.

    Args:
        example: Mapping that holds a ``label`` column and an image column
            named ``img`` (as in the Hub ``cifar10`` dataset) or ``image``
            (as produced by the ``imagefolder`` loader).

    Returns:
        The same mapping with two keys added: ``pixel_values`` (one entry
        per input image for a batch, or a single image tensor for a single
        example) and ``labels`` (the value of ``label``).
    """
    # CIFAR-10 on the Hub stores images under 'img'; most other image
    # datasets use 'image'. Accept either so the function stays reusable.
    image_key = 'img' if 'img' in example else 'image'
    images = example[image_key]

    is_batch = isinstance(images, (list, tuple))
    encoded = processor(images if is_batch else [images], return_tensors='pt')
    pixel_values = encoded['pixel_values']

    # Keep one entry per image for a batch; indexing [0] on a batch would
    # silently discard every image but the first.
    example['pixel_values'] = pixel_values if is_batch else pixel_values[0]
    example['labels'] = example['label']
    return example

# Attach the preprocessing as an on-the-fly transform: it is applied when rows
# are read rather than being run over the whole dataset up front.
dataset = dataset.with_transform(preprocess)
labels = dataset['train'].features['label'].names

# The checkpoint ships with a 1000-class ImageNet classification head. Asking
# for a different `num_labels` only works with `ignore_mismatched_sizes=True`,
# which drops the old head and initialises a fresh one of the right size.
model = AutoModelForImageClassification.from_pretrained(
    'microsoft/swin-tiny-patch4-window7-224',
    num_labels=len(labels),
    id2label=dict(enumerate(labels)),
    label2id={name: idx for idx, name in enumerate(labels)},
    ignore_mismatched_sizes=True,
)

def collate_fn(batch):
    """Assemble a list of per-example dicts into one batch for the model.

    Args:
        batch: Sequence of dicts, each with a ``pixel_values`` tensor and a
            ``labels`` value.

    Returns:
        Dict with ``pixel_values`` (the image tensors stacked along a new
        first dimension) and ``labels`` (a tensor built from the labels).
    """
    images, targets = [], []
    for sample in batch:
        images.append(sample['pixel_values'])
        targets.append(sample['labels'])
    return {
        'pixel_values': torch.stack(images),
        'labels': torch.tensor(targets),
    }

def compute_metrics(eval_pred):
    """Compute top-1 accuracy from an evaluation result.

    Args:
        eval_pred: Pair ``(logits, labels)``; the predicted class of each
            row is the arg-max of ``logits`` along axis 1.

    Returns:
        Dict with a single key, ``accuracy``: the fraction of rows whose
        predicted class equals the label.
    """
    logits, labels = eval_pred
    hits = np.argmax(logits, axis=1) == labels
    accuracy = hits.mean()
    return {'accuracy': accuracy}

import inspect

output_dir = '../outputs/swin_finetune'

# Newer transformers releases renamed `evaluation_strategy` to `eval_strategy`
# (the old spelling is deprecated). Pick whichever keyword the installed
# version accepts so the cell runs on both sides of the rename.
eval_strategy_arg = (
    'eval_strategy'
    if 'eval_strategy' in inspect.signature(TrainingArguments.__init__).parameters
    else 'evaluation_strategy'
)

args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    save_strategy='epoch',
    # Mixed precision is only enabled when a CUDA device is available.
    fp16=torch.cuda.is_available(),
    # Keep every column: the on-the-fly transform needs the raw image column,
    # which Trainer would otherwise drop as unused by the model.
    remove_unused_columns=False,
    **{eval_strategy_arg: 'epoch'},
)

# Small subsets keep the demo quick; raise the ranges for a real run.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset['train'].select(range(2000)),
    eval_dataset=dataset['test'].select(range(500)),
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)

trainer.train()
trainer.save_model(output_dir)
# Store the preprocessing config next to the weights so the directory can be
# reloaded for inference without the original checkpoint name.
processor.save_pretrained(output_dir)


## How to adapt to your own dataset

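- Use `datasets.load_dataset('imagefolder', data_dir=...)`; it exposes each picture in an `image` column and its class (taken from the sub-folder name) in a `label` column.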
- Update label mappings from dataset features.
- Adjust image processor size or normalization.

## Detection demo: Faster R-CNN (CNN) vs DETR (Transformer)
We'll run inference on a Penn-Fudan Pedestrian sample and draw boxes.


In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torchvision
from PIL import Image
from torchvision.datasets.utils import download_and_extract_archive
from torchvision.transforms import functional as F
from transformers import DetrForObjectDetection, DetrImageProcessor

# torchvision does not provide a Penn-Fudan dataset class, so download and
# unpack the official archive once and read a sample image straight from it.
PENNFUDAN_URL = 'https://www.cis.upenn.edu/~jshi/ped_html/PennFudanPed.zip'
pennfudan_images = Path('../data') / 'PennFudanPed' / 'PNGImages'
if not pennfudan_images.is_dir():
    download_and_extract_archive(PENNFUDAN_URL, download_root='../data')
img = Image.open(sorted(pennfudan_images.glob('*.png'))[0]).convert('RGB')
img_tensor = F.to_tensor(img)

# Same confidence cut-off for both detectors so the two plots are comparable.
SCORE_THRESHOLD = 0.7

# Faster R-CNN (CNN-based)
# `weights=` replaces the deprecated `pretrained=True`; the weights object
# also carries the category names that the predicted label ids refer to.
frcnn_weights = torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT
frcnn = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=frcnn_weights).eval()
frcnn_categories = frcnn_weights.meta['categories']
with torch.no_grad():
    preds = frcnn([img_tensor])[0]

# Keep confident detections only and label each with its predicted class
# instead of assuming that every box is a person.
keep = preds['scores'] >= SCORE_THRESHOLD
boxes = preds['boxes'][keep].cpu().numpy().tolist()
box_scores = preds['scores'][keep].cpu().tolist()
box_names = [frcnn_categories[idx] for idx in preds['labels'][keep].cpu().tolist()]

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.imshow(img)
for box, name, score in zip(boxes, box_names, box_scores):
    x0, y0, x1, y1 = box
    rect = plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, color='lime', linewidth=2)
    ax.add_patch(rect)
    ax.text(x0, y0, f'{name}:{score:.2f}', color='white', fontsize=8, bbox=dict(facecolor='black', alpha=0.6))
ax.set_title('Faster R-CNN predictions')
ax.axis('off')
plt.show()

# DETR (Transformer-based) via Hugging Face
# A dedicated name keeps the global `processor` (read lazily by the
# classification `preprocess` transform) pointing at the Swin processor.
# The 'no_timm' revision loads the backbone without needing `timm` installed.
detr_processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-50', revision='no_timm')
detr = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-50', revision='no_timm').eval()
inputs = detr_processor(images=img, return_tensors='pt')
with torch.no_grad():
    outputs = detr(**inputs)
# PIL reports (width, height); post-processing expects (height, width).
target_sizes = torch.tensor([img.size[::-1]])
results = detr_processor.post_process_object_detection(
    outputs, target_sizes=target_sizes, threshold=SCORE_THRESHOLD
)[0]
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.imshow(img)
for score, label, box in zip(results['scores'], results['labels'], results['boxes']):
    x0, y0, x1, y1 = box.tolist()
    rect = plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, color='cyan', linewidth=2)
    ax.add_patch(rect)
    ax.text(x0, y0, f'{detr.config.id2label[label.item()]}:{score:.2f}', color='white', fontsize=8, bbox=dict(facecolor='black', alpha=0.6))
ax.set_title('DETR predictions')
ax.axis('off')
plt.show()


## Segmentation demo: DeepLabV3 (CNN) vs SegFormer (Transformer)


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torchvision.datasets import OxfordIIITPet
from torchvision.transforms import functional as F
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

dataset_seg = OxfordIIITPet(root='../data', download=True, target_types='segmentation')
img, mask = dataset_seg[0]
# torchvision's pretrained segmentation models expect ImageNet-normalised
# input; feeding raw [0, 1] pixels degrades the prediction.
img_tensor = F.normalize(
    F.to_tensor(img),
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# DeepLabV3 (CNN-based). `weights=` replaces the deprecated `pretrained=True`.
deeplab = torchvision.models.segmentation.deeplabv3_resnet50(weights='DEFAULT').eval()
with torch.no_grad():
    out = deeplab(img_tensor.unsqueeze(0))['out']
# Per-pixel class id = arg-max over the class dimension.
pred_mask = out.argmax(dim=1)[0].cpu().numpy()

plt.figure(figsize=(10, 3))
plt.subplot(1, 3, 1)
plt.imshow(img)
plt.title('Image')
plt.axis('off')
plt.subplot(1, 3, 2)
plt.imshow(mask, cmap='gray')
plt.title('GT Mask')
plt.axis('off')
plt.subplot(1, 3, 3)
plt.imshow(pred_mask, cmap='gray')
plt.title('DeepLabV3 Pred')
plt.axis('off')
plt.tight_layout()
plt.show()

# SegFormer (Transformer-based) via Hugging Face.
# A dedicated name keeps the global `processor` (read lazily by the
# classification `preprocess` transform) pointing at the Swin processor.
segformer_processor = SegformerImageProcessor.from_pretrained('nvidia/segformer-b0-finetuned-ade-512-512')
segformer = SegformerForSemanticSegmentation.from_pretrained('nvidia/segformer-b0-finetuned-ade-512-512').eval()
inputs = segformer_processor(images=img, return_tensors='pt')
with torch.no_grad():
    outputs = segformer(**inputs)
logits = outputs.logits
# The logits come out at a lower resolution than the picture. Resize them to
# the original (height, width) before the arg-max so the overlay lines up
# with the image instead of covering only its top-left corner.
pred = segformer_processor.post_process_semantic_segmentation(
    outputs, target_sizes=[img.size[::-1]]
)[0].cpu().numpy()
plt.figure(figsize=(5, 4))
plt.imshow(img)
plt.imshow(pred, alpha=0.5, cmap='jet')
plt.title('SegFormer Overlay')
plt.axis('off')
plt.show()


### Scale Up
- Fine-tune for 5+ epochs on a larger subset.
- Try mixed precision and gradient accumulation.


### Summary
- Hugging Face Trainer handles logging, eval, and checkpoints.
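- Swin builds hierarchical feature maps with shifted-window attention, so its compute cost grows linearly with image size and it scales well to high-resolution images.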

### Exercises
1. Fine-tune ViT vs Swin on the same subset.
2. Try fp16 vs full precision and compare speed.
3. Load SegFormer and run inference on a pet image.

### Further Reading
- https://huggingface.co/docs/transformers
- https://huggingface.co/docs/datasets
