# In [None]:
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.datasets import VOCDetection
from torch.utils.data import DataLoader
import torchvision.transforms as T
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Label vocabulary for the detector. Index 0 is the background slot; real
# object classes start at index 1. target_transform looks up each VOC <name>
# tag in this list, so the names must match the annotation files exactly.
# NOTE(review): the stock Pascal VOC 2012 data downloaded below uses its own
# 20-class vocabulary (aeroplane, bicycle, bird, ...), which contains none of
# these names -- confirm the intended dataset / class list.
classes = [
    "__background__",
    "apple",
    "banana",
    "orange",
]

def get_model(num_classes):
    """Build a pretrained Faster R-CNN (ResNet-50 FPN) with a freshly sized box predictor.

    The model is created with ``weights="DEFAULT"``. Its final box predictor is
    then replaced by a new ``FastRCNNPredictor`` with ``num_classes`` outputs.
    That count includes the background class, as ``len(classes)`` does.
    Everything except that final predictor keeps its pretrained weights.

    Args:
        num_classes: number of output classes, background included.

    Returns:
        The modified torchvision Faster R-CNN model.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
    # The replacement head must accept the same feature width as the original one.
    feature_dim = detector.roi_heads.box_predictor.cls_score.in_features
    new_head = FastRCNNPredictor(feature_dim, num_classes)
    detector.roi_heads.box_predictor = new_head
    return detector

def collate_fn(batch):
    """Transpose a batch of ``(image, target)`` samples into ``(images, targets)``.

    Each element of the result is a tuple holding one field from every sample.
    Nothing is stacked into a single tensor, so images of different sizes and
    targets with different box counts can share a batch.
    """
    columns = zip(*batch)
    return tuple(columns)

# Image preprocessing for the dataset: the only step converts each PIL image
# into a tensor via ToTensor.
transform = T.Compose(
    [
        T.ToTensor(),
    ]
)

def target_transform(target):
    """Convert one parsed VOC annotation dict into a torchvision detection target.

    Args:
        target: the nested dict that ``VOCDetection`` produces for one image.
            Only ``target["annotation"]["object"]`` is read. It may be a list of
            object dicts, a single object dict, or absent.

    Returns:
        A dict with two entries:

        - ``"boxes"``: float32 tensor of shape ``(N, 4)``, rows are
          ``(xmin, ymin, xmax, ymax)``.
        - ``"labels"``: int64 tensor of shape ``(N,)``, holding indices into
          ``classes``.

        Objects whose name is not a non-background entry of ``classes`` are
        skipped, so ``N`` may be 0.
    """
    objs = target["annotation"].get("object", [])
    if not isinstance(objs, list):
        # A single object may arrive as a bare dict rather than a list.
        objs = [objs]
    boxes = []
    labels = []
    for obj in objs:
        name = obj["name"]
        label = classes.index(name) if name in classes else 0
        if label == 0:
            # Label 0 is the background class. Tagging a real ground-truth box
            # with it gives contradictory supervision, so drop such objects.
            continue
        bbox = obj["bndbox"]
        # Coordinates are stored as strings; float() also accepts decimal
        # values such as "48.5", which int() would reject.
        boxes.append([float(bbox[key]) for key in ("xmin", "ymin", "xmax", "ymax")])
        labels.append(label)
    return {
        # reshape keeps the (0, 4) shape the detector requires when no boxes remain.
        "boxes": torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
        "labels": torch.tensor(labels, dtype=torch.int64),
    }

# Pascal VOC 2012 "train" split, stored under the current directory.
# NOTE(review): with download=True torchvision may re-verify / re-extract the
# archive on every run; consider download=False once the data is on disk.
train_dataset = VOCDetection(
    "./",
    year="2012",
    image_set="train",
    download=True,
    transform=transform,
    target_transform=target_transform,
)
# collate_fn keeps per-image tuples instead of stacking, since images and
# per-image box counts vary.
train_loader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=collate_fn,
)

model = get_model(len(classes))
# Prefer a GPU when one is present; training Faster R-CNN on the CPU is very slow.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Optimize only the trainable parameters.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

num_epochs = 1
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    for images, targets in train_loader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        # In train mode the detector returns a dict of loss tensors; optimize their sum.
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        epoch_loss += losses.item()
        num_batches += 1
    # Report the mean over the whole epoch rather than only the last batch.
    # An empty loader yields NaN instead of a NameError.
    mean_loss = epoch_loss / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch+1}, mean loss: {mean_loss:.4f}")

# Run inference on one (shuffled) training image and draw the confident detections.
model.eval()
images, _ = next(iter(train_loader))
img = images[0].to(device)
with torch.no_grad():
    prediction = model([img])

score_threshold = 0.5
# Matplotlib needs host-side NumPy data in (H, W, C) layout.
img_np = img.permute(1, 2, 0).cpu().numpy()
fig, ax = plt.subplots(1)
ax.imshow(img_np)
pred = prediction[0]
# tolist() turns the tensors into plain Python numbers (copying off the GPU if needed).
for box, label, score in zip(pred["boxes"].tolist(), pred["labels"].tolist(), pred["scores"].tolist()):
    if score > score_threshold:
        xmin, ymin, xmax, ymax = box
        rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
        ax.text(xmin, ymin, classes[label], bbox=dict(facecolor='yellow', alpha=0.5))
plt.show()
