In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Image-only preprocessing.
# NOTE: no Resize here. VOCDetection applies `transform` to the image only and
# leaves the annotation dict untouched, so resizing the image would leave the
# bounding boxes (stored in the original image's pixel coordinates) pointing at
# the wrong regions. Faster R-CNN accepts a list of differently sized images
# and rescales images and boxes together internally.
transform = transforms.Compose([
    transforms.ToTensor(),
])


def collate_detection_batch(batch):
    """Collate (image, target) samples into a tuple of images and a tuple of targets.

    Images are kept as separate tensors instead of being stacked, so images of
    different sizes can share a batch. A named function (rather than a lambda)
    stays picklable if the loaders are later given worker processes.
    """
    return tuple(zip(*batch))


# Load PASCAL VOC 2012 detection splits (downloaded into ./data/pascal if missing)
train_dataset = torchvision.datasets.VOCDetection(
    root="./data/pascal",
    year="2012",
    image_set="train",
    download=True,
    transform=transform
)
val_dataset = torchvision.datasets.VOCDetection(
    root="./data/pascal",
    year="2012",
    image_set="val",
    download=True,
    transform=transform
)

batch_size = 32

# Shuffle only the training split; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_detection_batch)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_detection_batch)



100%|██████████| 2.00G/2.00G [15:05<00:00, 2.21MB/s] 


In [None]:
# Faster R-CNN detector (ResNet-50 + FPN variant) created with pretrained
# weights and placed on the selected device; Module.to() returns the module.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True).to(device)

# Adam over all of the detector's parameters.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# Stand-alone box-regression and classification criteria.
criterion_bbox = nn.SmoothL1Loss()
criterion_cls = nn.CrossEntropyLoss()

# The 20 PASCAL VOC object categories in alphabetical order. Label 0 is the
# background class in torchvision detection models, so the i-th name in this
# tuple (0-based) is assigned label i + 1.
VOC_CLASSES = (
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
)
VOC_CLASS_TO_IDX = {name: idx for idx, name in enumerate(VOC_CLASSES, start=1)}


def process_targets(targets, class_to_idx=None):
    """Convert VOC-style annotation dicts into per-image box and label tensors.

    Args:
        targets: iterable of dicts shaped like ``{"annotation": {"object": ...}}``
            where ``object`` is either one object dict or a list of them. Each
            object dict provides ``name`` and a ``bndbox`` dict with ``xmin``,
            ``ymin``, ``xmax`` and ``ymax`` entries.
        class_to_idx: optional mapping from class name to integer label.
            Defaults to the PASCAL VOC mapping (1..20, 0 left for background).

    Returns:
        ``(bbox_list, label_list)``: two lists with one entry per target.
        ``bbox_list[i]`` is a float32 tensor of shape (N, 4) holding
        ``[xmin, ymin, xmax, ymax]`` rows exactly as given in the annotation
        (no rescaling is applied); ``label_list[i]`` is an int64 tensor of
        shape (N,).

    Raises:
        ValueError: if an object's name is neither numeric nor present in
            ``class_to_idx``.
    """
    if class_to_idx is None:
        class_to_idx = VOC_CLASS_TO_IDX

    bbox_list, label_list = [], []
    for target in targets:
        objects = target["annotation"]["object"]
        # A lone object may arrive as a bare dict rather than a one-item list.
        if not isinstance(objects, list):
            objects = [objects]

        boxes = []
        labels = []
        for obj in objects:
            bndbox = obj["bndbox"]
            boxes.append([float(bndbox[key]) for key in ("xmin", "ymin", "xmax", "ymax")])

            name = obj["name"]
            if name.isdigit():
                # Already a numeric class id stored as text.
                labels.append(int(name))
            elif name in class_to_idx:
                labels.append(class_to_idx[name])
            else:
                # Failing loudly beats silently training on a wrong label.
                raise ValueError(f"Unknown object class name: {name!r}")

        # reshape keeps the (N, 4) layout even when an image has no objects.
        bbox_list.append(torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4))
        label_list.append(torch.tensor(labels, dtype=torch.long))

    return bbox_list, label_list



In [None]:
# Training loop
NUM_EPOCHS = 5


def _prepare_batch(images, targets):
    """Move one collated batch to `device` in the format the detector expects.

    Returns a list of image tensors and a list of ``{"boxes", "labels"}`` dicts
    (one per image), all on `device`.
    """
    images = [img.to(device) for img in images]
    bbox_targets, label_targets = process_targets(targets)
    targets = [
        {"boxes": boxes.to(device), "labels": labels.to(device)}
        for boxes, labels in zip(bbox_targets, label_targets)
    ]
    return images, targets


for epoch in range(NUM_EPOCHS):
    # ---- Training phase ----
    model.train()
    train_loss = 0.0
    for images, targets in train_loader:
        images, targets = _prepare_batch(images, targets)

        # In training mode the detector returns a dict of loss components.
        loss_dict = model(images, targets)
        loss = sum(loss_dict.values())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # max(..., 1) guards against an empty loader.
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] - Training Loss: {train_loss / max(len(train_loader), 1):.4f}")

    # ---- Validation phase ----
    # The model is deliberately left in train mode: torchvision detection
    # models only return the loss dict in train mode (in eval mode they return
    # predictions). torch.no_grad() disables gradient tracking, and no
    # backward/optimizer step is taken, so validation data never updates the
    # weights.
    val_loss = 0.0
    with torch.no_grad():
        for images, targets in val_loader:
            images, targets = _prepare_batch(images, targets)
            loss_dict = model(images, targets)
            val_loss += sum(loss_dict.values()).item()

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] - Validation Loss: {val_loss / max(len(val_loader), 1):.4f}")

# Leave the model in inference mode for any cells that follow.
model.eval()