
# Faster R-CNN Fine-Tuning on VOC-style Annotations (PyTorch)

This notebook provides an **end-to-end pipeline**:
- Load VOC-style data (`JPEGImages/`, `Annotations/*.xml`, and the split files `ImageSets/Main/train.txt` and `ImageSets/Main/val.txt`).
- Build and modify Faster R-CNN head for your classes.
- Train, validate (loss-based), and save weights.
- Run inference and visualize detections.


In [None]:

from pathlib import Path
import os, xml.etree.ElementTree as ET
import torch, torchvision
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt

# ---- User-editable settings ----
# Root of the VOC-style dataset (JPEGImages/, Annotations/, ImageSets/Main/).
DATA_ROOT = Path("dataset")
# Object class names; change these to match your annotations.
CLASSES = ["person", "car", "dog"]
BATCH_SIZE = 2
NUM_EPOCHS = 10
LR = 0.005  # learning rate handed to the optimizer
# Output directory; created up front (with parents) if it does not exist.
OUTPUT_DIR = Path("outputs")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
# ---- End of settings ----


## Dataset (VOC XML)

In [None]:

class VOCDataset(Dataset):
    """Detection dataset over a VOC-style directory tree.

    Reads ``root/JPEGImages/<id>.jpg`` and ``root/Annotations/<id>.xml`` for
    every id listed (one per line) in ``root/ImageSets/Main/<image_set>.txt``.

    Args:
        root: Dataset root directory.
        image_set: Name of the split file (without ``.txt``) to read ids from.
        classes: Ordered class names; the i-th name is mapped to label
            ``i + 1``. When omitted (or empty), every object gets label 1.
        transforms: Optional callable applied to the PIL image only (boxes are
            not touched, so geometric transforms would desynchronise them).
            When omitted the image is converted with
            ``torchvision.transforms.ToTensor``.
    """

    def __init__(self, root: Path, image_set="train", classes=None, transforms=None):
        root = Path(root)  # also accept a plain string path
        self.root = root
        self.transforms = transforms
        self.img_dir = root / "JPEGImages"
        self.ann_dir = root / "Annotations"
        ids_file = root / "ImageSets" / "Main" / f"{image_set}.txt"
        with open(ids_file) as f:
            # One image id per line; blank lines are skipped.
            self.ids = [line.strip() for line in f if line.strip()]
        self.class_to_idx = {c: i + 1 for i, c in enumerate(classes)} if classes else None

    def __len__(self):
        """Return the number of image ids in the split."""
        return len(self.ids)

    def _parse_xml(self, xml_path: Path):
        """Parse one annotation file into ``(boxes, labels)`` Python lists.

        ``boxes`` holds ``[xmin, ymin, xmax, ymax]`` floats per ``<object>``;
        ``labels`` holds the matching integer labels.
        """
        root = ET.parse(xml_path).getroot()
        boxes, labels = [], []
        for obj in root.findall("object"):
            # Strip so that names padded with whitespace/newlines still match.
            name = obj.find("name").text.strip()
            bnd = obj.find("bndbox")
            boxes.append([float(bnd.find(tag).text) for tag in ("xmin", "ymin", "xmax", "ymax")])
            # NOTE: names missing from the class map fall back to label 1.
            labels.append(self.class_to_idx.get(name, 1) if self.class_to_idx else 1)
        return boxes, labels

    @staticmethod
    def _build_target(boxes, labels, index):
        """Pack parsed annotations into the target dict used for training.

        The boxes tensor is always reshaped to ``(N, 4)`` so that an image
        without any objects yields shape ``(0, 4)`` rather than ``(0,)``.
        """
        return {
            "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.as_tensor(labels, dtype=torch.int64),
            "image_id": torch.tensor([index]),
        }

    def __getitem__(self, i):
        """Return ``(image, target)`` for the i-th id of the split."""
        img_id = self.ids[i]
        img = Image.open(self.img_dir / f"{img_id}.jpg").convert("RGB")
        boxes, labels = self._parse_xml(self.ann_dir / f"{img_id}.xml")
        target = self._build_target(boxes, labels, i)
        if self.transforms:
            img = self.transforms(img)
        else:
            img = torchvision.transforms.ToTensor()(img)
        return img, target

def collate_fn(batch):
    """Transpose a list of samples into a tuple of per-field tuples.

    A batch ``[(img0, tgt0), (img1, tgt1)]`` becomes
    ``((img0, img1), (tgt0, tgt1))``; nothing is stacked.
    """
    per_field = zip(*batch)
    return tuple(per_field)


In [None]:

# Build the train and val datasets from the split files of the same name.
train_set, val_set = (
    VOCDataset(DATA_ROOT, split, classes=CLASSES) for split in ("train", "val")
)

# Only the training loader shuffles. Both go through collate_fn, which keeps
# the samples of a batch as tuples instead of stacking them.
train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)
# Cell output: (number of training samples, number of validation samples).
len(train_set), len(val_set)


## Model: Faster R-CNN + FPN (modify head)

In [None]:

# Faster R-CNN (ResNet-50 FPN v2) initialised with torchvision's default weights.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(weights="DEFAULT")

# Swap in a fresh box predictor sized for our label set. Labels produced by
# VOCDataset start at 1, hence len(CLASSES) + 1 outputs.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
    in_features,
    len(CLASSES) + 1,
)
model.to(DEVICE)

# (Optional) backbone warmup switch: pass False here to freeze the backbone
# body for the first few epochs.
# NOTE(review): torchvision may build this model with some backbone layers
# frozen by default; this loop marks every body parameter trainable — confirm
# that overriding the default is intended.
for p in model.backbone.body.parameters():
    p.requires_grad_(True)


## Train

In [None]:

# Hand the optimizer only the parameters that are currently trainable.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=LR,
    momentum=0.9,
    weight_decay=1e-4,
)
# Scale the learning rate by 0.1 after every 5 scheduler steps.
lr_sched = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=5)

def to_device(batch):
    """Move one ``(images, targets)`` batch onto ``DEVICE``.

    ``images`` is an iterable of tensors and ``targets`` an iterable of dicts
    whose values are tensors. Returns a list of moved images and a list of
    dicts with every value moved.
    """
    images, targets = batch
    moved_images = [image.to(DEVICE) for image in images]
    moved_targets = []
    for target in targets:
        moved_targets.append({key: value.to(DEVICE) for key, value in target.items()})
    return moved_images, moved_targets

# Train for NUM_EPOCHS epochs, reporting the mean per-batch loss on the train
# and val loaders after each epoch, then save the final weights.
for epoch in range(NUM_EPOCHS):
    model.train()
    train_loss_sum = 0.0
    for batch in train_loader:
        imgs, tgts = to_device(batch)
        # In training mode the model returns a dict of loss tensors.
        loss_dict = model(imgs, tgts)
        loss = sum(loss_dict.values())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss_sum += loss.item()
    lr_sched.step()  # the scheduler advances once per epoch

    # Simple val loss. torchvision detection models only return the loss dict
    # in training mode (eval mode returns predictions), so the model is kept
    # in train mode here and only gradient tracking is switched off.
    model.train()
    val_loss_sum = 0.0
    with torch.no_grad():
        for batch in val_loader:
            imgs, tgts = to_device(batch)
            ld = model(imgs, tgts)
            val_loss_sum += sum(ld.values()).item()

    # Report per-batch means so train and val numbers are comparable even
    # when the two splits differ in size; max(..., 1) guards empty loaders.
    total_loss = train_loss_sum / max(len(train_loader), 1)
    val_loss = val_loss_sum / max(len(val_loader), 1)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} | train_loss={total_loss:.3f} | val_loss={val_loss:.3f}")

weights_path = OUTPUT_DIR/"frcnn_voc.pt"
torch.save(model.state_dict(), weights_path)
print("Saved:", weights_path)


## Inference & Visualization

In [None]:

# Pick the first image id from the val set file. Ids are stripped and blank
# lines dropped, the same way VOCDataset reads its split file.
val_ids = [
    line.strip()
    for line in (DATA_ROOT/"ImageSets"/"Main"/"val.txt").read_text().splitlines()
    if line.strip()
]
if not val_ids:
    print("No val ids. Please populate ImageSets/Main/val.txt")
else:
    sample_id = val_ids[0]
    img_path = DATA_ROOT/"JPEGImages"/f"{sample_id}.jpg"
    img = Image.open(img_path).convert("RGB")
    tensor = torchvision.transforms.ToTensor()(img).to(DEVICE)
    model.eval()  # eval mode: the model returns predictions, not losses
    with torch.no_grad():
        out = model([tensor])[0]

    # Draw on a copy so the loaded image itself stays untouched.
    draw = img.copy()
    d = ImageDraw.Draw(draw)
    # Plain Python numbers are safe both for PIL coordinates and list indexing.
    boxes = out["boxes"].cpu().tolist()
    labels = out["labels"].cpu().tolist()
    scores = out["scores"].cpu().tolist()

    for box, label, score in zip(boxes, labels, scores):
        if score < 0.5:  # skip low-confidence detections
            continue
        x1, y1, x2, y2 = box
        # An integer outline such as 1 renders near-black on an RGB image;
        # use a named colour so the boxes are actually visible.
        d.rectangle([x1, y1, x2, y2], outline="red", width=3)
        # Labels are 1-based indices into CLASSES; anything else is shown raw.
        name = CLASSES[label - 1] if 1 <= label <= len(CLASSES) else str(label)
        d.text((x1, y1), f"{name}:{score:.2f}")
    display(draw)
