In [1]:
import torch
import torchvision
from torchvision import transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np

In [2]:
class CIFAR10Detection(Dataset):
    def __init__(self, root, train=True, transform=None):
        self.cifar10 = torchvision.datasets.CIFAR10(root=root, train=train, download=True, transform=transform)
        self.classes = self.cifar10.classes

    def __len__(self):
        return len(self.cifar10)

    def __getitem__(self, idx):
        img, label = self.cifar10[idx]
        # Create a bounding box covering the entire image
        boxes = torch.tensor([[0, 0, img.size(1), img.size(2)]], dtype=torch.float32)
        labels = torch.tensor([label + 1], dtype=torch.int64)  # +1 because 0 is background
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        return img, target

In [3]:
# Data transforms
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Datasets
train_dataset = CIFAR10Detection(root='./data', train=True, transform=transform)
test_dataset = CIFAR10Detection(root='./data', train=False, transform=transform)

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=lambda x: tuple(zip(*x)))

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 30327225.61it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [4]:
# Load pre-trained Faster R-CNN
model = fasterrcnn_resnet50_fpn(pretrained=True)

# Get the number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features

# Replace the pre-trained head with a new one for CIFAR-10 (10 classes + background)
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=11)  # 10 classes + background

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /home/ymj68520/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
100%|██████████| 160M/160M [00:04<00:00, 34.0MB/s] 


FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(

In [5]:
# Optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

# Learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

In [None]:
# Training function
def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()
    total_loss = 0
    for images, targets in data_loader:
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        total_loss += losses.item()

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

    return total_loss / len(data_loader)

# Training loop
num_epochs = 1  # Reduced for quick test
for epoch in range(num_epochs):
    loss = train_one_epoch(model, optimizer, train_loader, device, epoch)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss:.4f}")
    lr_scheduler.step()

In [None]:
# Evaluation function
def evaluate(model, data_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, targets in data_loader:
            images = list(image.to(device) for image in images)
            outputs = model(images)
            for output, target in zip(outputs, targets):
                pred_labels = output['labels'].cpu().numpy()
                true_labels = target['labels'].cpu().numpy()
                if len(pred_labels) > 0:
                    pred = pred_labels[0] - 1  # -1 to match CIFAR classes
                    true = true_labels[0] - 1
                    if pred == true:
                        correct += 1
                total += 1
    accuracy = correct / total
    return accuracy

# Evaluate on test set
accuracy = evaluate(model, test_loader, device)
print(f"Test Accuracy: {accuracy:.4f}")

In [None]:
# Visualize some predictions
model.eval()
with torch.no_grad():
    for images, targets in test_loader:
        images = list(image.to(device) for image in images)
        outputs = model(images)
        break  # Only first batch

# Plot
fig, axes = plt.subplots(1, 4, figsize=(12, 3))
for i in range(4):
    img = images[i].cpu().permute(1, 2, 0).numpy()
    axes[i].imshow(img)
    pred_label = outputs[i]['labels'][0].item() - 1 if len(outputs[i]['labels']) > 0 else -1
    true_label = targets[i]['labels'][0].item() - 1
    axes[i].set_title(f"Pred: {train_dataset.classes[pred_label] if pred_label >= 0 else 'None'}\nTrue: {train_dataset.classes[true_label]}")
    axes[i].axis('off')
plt.show()