In [None]:
import torch
import torchvision
import objdetect as od
from tqdm import tqdm

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

## Data augmentation

In [None]:
# The image is first resized 10% larger than the final crop, so that the
# random crop has some room to move around.
crop_size = 256
resize_size = int(crop_size*1.1)

transformations = od.aug.Compose([
    od.aug.Resize(resize_size, resize_size),
    od.aug.RandomCrop(crop_size, crop_size),
    od.aug.RandomHflip(),
    od.aug.RandomBrightnessContrast(0.1, 0.1),
    od.aug.Normalize(),
])

## Dataset loader

In [None]:
# Dataset that applies the augmentation pipeline defined above.
# NOTE(review): the positional arguments are presumably (root directory, split,
# transform) -- confirm against the objdetect signature.
ds = od.data.VOCDetection('data', 'train', transformations, download=True)

Let's look at the first sample:

In [None]:
# Fetch the first sample and list the fields it provides.
d = ds[0]
print(d.keys())

Each sample consists of an 'image' and its objects, which are represented by their 'bboxes' and 'classes'. The 'bboxes' are in x1y1x2y2 format and are normalized to the 0-1 range.

In [None]:
# Draw the sample image, then overlay an 8x8 grid, the bounding boxes and the
# class labels, in that order, before showing the figure.
sample_image = d['image']
sample_bboxes = d['bboxes']
od.plot.image(sample_image)
od.plot.grid_lines(sample_image, 8, 8)
od.plot.bboxes(sample_image, sample_bboxes)
od.plot.classes(sample_image, sample_bboxes, d['classes'], ds.labels)
od.plot.show()

Naturally, the number of bounding boxes varies from image to image, so they cannot be stacked into a single tensor. We therefore need to specify a `collate_fn` function that defines how the batches should be created.

In [None]:
# Training loader: batches of 16, reshuffled at every epoch, assembled by the
# library's collate function.
tr = torch.utils.data.DataLoader(
    dataset=ds,
    batch_size=16,
    shuffle=True,
    collate_fn=od.data.collate_fn,
)

## Model

We will prepare a one-stage model that, for each location in the grid, predicts whether there is an object (score) and, if so, the object's class and bounding box. Like the object detection models that come with torchvision (see e.g. [FCOS](https://pytorch.org/vision/stable/models/generated/torchvision.models.detection.fcos_resnet50_fpn.html#torchvision.models.detection.fcos_resnet50_fpn)), the behavior depends on whether the model is in `train` or `eval` mode, although we do not do exactly what they do. In `train` mode, we return the *unprocessed* scores/bboxes/classes grids. In `eval` mode, we return the *processed* bboxes/classes in the form of a list.

![](model.svg)

In [None]:
class MyModel(torch.nn.Module):
    """One-stage grid detector: a VGG16 backbone followed by three 1x1 conv heads.

    For every location of the backbone's output feature map the heads predict
    an objectness score (1 channel), class logits (20 channels) and a bounding
    box (4 channels).

    In ``train`` mode, ``forward`` returns the raw ``(scores, bboxes, classes)``
    grids (the scores are logits; no sigmoid is applied). In ``eval`` mode, it
    returns the post-processed ``(bboxes, classes)`` after thresholding and NMS.
    """

    def __init__(self):
        super().__init__()
        # Convolutional part of a pretrained VGG16 (torchvision's 'DEFAULT'
        # weights); the heads below expect its 512-channel output.
        self.backbone = torchvision.models.vgg16(weights='DEFAULT').features
        self.scores = torch.nn.Conv2d(512, 1, 1)
        self.classes = torch.nn.Conv2d(512, 20, 1)
        self.bboxes = torch.nn.Conv2d(512, 4, 1)

    def forward(self, x):
        """Run the detector on a batch of images.

        Returns ``(scores, bboxes, classes)`` grids when training, or the
        post-processed ``(bboxes, classes)`` when in evaluation mode.
        """
        x = self.backbone(x)
        scores = self.scores(x)
        classes = self.classes(x)
        bboxes = self.bboxes(x)
        if not self.training:
            # when in evaluation mode, convert the output grid into a list of bboxes/classes
            scores = torch.sigmoid(scores)
            # keep only the grid locations whose objectness probability is >= 0.5
            hasobjs = scores >= 0.5
            # NOTE(review): `inv_scores` is assumed to live in `od.grid` next
            # to the other `inv_*` helpers used below -- confirm. A bare
            # `inv_scores` is not defined or imported anywhere in this notebook.
            scores = od.grid.inv_scores(hasobjs, scores)
            bboxes = od.grid.inv_offset_logsize_bboxes(hasobjs, bboxes)
            classes = od.grid.inv_classes(hasobjs, classes)
            # NMS is given the per-detection scores computed above.
            bboxes, classes = od.post.NMS(scores, bboxes, classes)
            return bboxes, classes
        return scores, bboxes, classes

## Training

In [None]:
# Model and optimizer.
model = MyModel()
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters())

# The bboxes/classes losses are left unreduced (one value per element) because
# the training loop weights them by the ground-truth scores before averaging.
scores_loss = torch.nn.BCEWithLogitsLoss()
bboxes_loss = torch.nn.MSELoss(reduction='none')
classes_loss = torch.nn.CrossEntropyLoss(reduction='none')

epochs = 10

In [None]:
model.train()
for epoch in range(epochs):
    avg_loss = 0
    for imgs, targets in tqdm(tr, leave=False):
        imgs = imgs.to(device)
        preds_scores, preds_bboxes, preds_classes = model(imgs)

        # Convert the ground-truth lists into 8x8 target grids.
        # NOTE(review): 8x8 has to match the spatial size of the model output
        # for the 256x256 crops -- confirm if the input size or backbone changes.
        slices = od.grid.slices_center_locations(8, 8, targets['bboxes'])
        scores = od.grid.scores(8, 8, slices).to(device)
        bboxes = od.grid.offset_logsize_bboxes(8, 8, slices, targets['bboxes']).to(device)
        classes = od.grid.classes(8, 8, slices, targets['classes']).to(device)

        # The bboxes/classes losses are weighted by the ground-truth scores so
        # that only locations containing an object contribute.
        # `scores` has the same (N, 1, H, W) shape as `preds_scores` (required
        # by BCEWithLogitsLoss), but CrossEntropyLoss(reduction='none') drops
        # the channel axis and returns (N, H, W). The channel is indexed out so
        # the mask lines up sample by sample; multiplying the two directly
        # would silently broadcast to (N, N, H, W) and mix samples.
        loss_value = \
            scores_loss(preds_scores, scores) + \
            (scores * bboxes_loss(preds_bboxes, bboxes)).mean() + \
            (scores[:, 0] * classes_loss(preds_classes, classes)).mean()
        optimizer.zero_grad()
        loss_value.backward()
        optimizer.step()
        # running mean of the batch losses over the epoch
        avg_loss += float(loss_value) / len(tr)
    print(f'Epoch {epoch+1}/{epochs} - Avg loss: {avg_loss}')

In [None]:
model.eval()
# Samples are dicts (see `d.keys()` earlier), so they cannot be indexed with 0.
# Build a batch of one with the same collate function used by the training
# loader, so the model receives its input in the same layout as in training.
imgs, _ = od.data.collate_fn([ds[0]])
# Inference only: no need to track gradients.
with torch.no_grad():
    detections = model(imgs.to(device))
detections