# objdetect

In [None]:
!pip3 install git+https://github.com/rpmcruz/objdetect.git

In [None]:
import torchvision
import torch
import numpy as np
import objdetect as od
import albumentations as A
from albumentations.pytorch import ToTensorV2
from time import time
import matplotlib.pyplot as plt
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Passed as the `download` argument when the VOC datasets are created.
download = False
# Passed as the dataset root directory when the VOC datasets are created.
data_path = '/data'

## Data

Let's use PASCAL VOC, which already comes with `torchvision`.

In [None]:
class VOC(torch.utils.data.Dataset):
    """PASCAL VOC detection samples exposed as plain dictionaries.

    Wraps `torchvision.datasets.VOCDetection` and turns every sample into a
    dict with the keys 'image', 'bboxes' and 'labels'.
    """

    def __init__(self, root, fold, transform=None, download=True):
        """Open the underlying torchvision dataset.

        Args:
            root: dataset root, forwarded to `VOCDetection`.
            fold: image-set name; 'val' is replaced by 'test' before it is
                forwarded. NOTE(review): 'test' may not be a valid image set
                for every VOC year -- confirm against torchvision.
            transform: optional callable invoked with the keyword arguments
                `image`, `bboxes` and `labels`; its result is returned as-is.
            download: forwarded to `VOCDetection`.
        """
        super().__init__()
        if fold == 'val':
            fold = 'test'
        self.ds = torchvision.datasets.VOCDetection(root, image_set=fold, download=download)
        self.transform = transform

    def __len__(self):
        """Return the number of samples of the wrapped dataset."""
        return len(self.ds)

    def __getitem__(self, i):
        """Return sample `i` as a dict (after `transform`, when one is set)."""
        picture, annotation = self.ds[i]
        pixels = np.array(picture)
        labels = []
        bboxes = []
        for obj in annotation['annotation']['object']:
            labels.append(obj['name'])
            box = obj['bndbox']
            # Coordinates are parsed with float(), in (xmin, ymin, xmax, ymax) order.
            bboxes.append((
                float(box['xmin']), float(box['ymin']),
                float(box['xmax']), float(box['ymax']),
            ))
        sample = {'image': pixels, 'bboxes': bboxes, 'labels': labels}
        if not self.transform:
            return sample
        return self.transform(**sample)

Let's detect only certain classes, such as animals.

In [None]:
class FilterClass(torch.utils.data.Dataset):
    """View of a detection dataset restricted to a whitelist of classes.

    Samples that contain no whitelisted label are skipped entirely. In the
    remaining samples, objects of other classes are dropped and each kept
    label is replaced by its position in `whitelist`.
    """

    def __init__(self, ds, whitelist):
        """Index the samples of `ds` that contain a whitelisted label.

        Args:
            ds: indexable dataset whose items are mappings with the keys
                'labels' and 'bboxes'. Every item is read once here.
            whitelist: sequence of label names to keep.
        """
        super().__init__()
        self.ds = ds
        self.ix = [i for i in range(len(ds)) if any(label in whitelist for label in ds[i]['labels'])]
        self.whitelist = whitelist

    def __len__(self):
        """Return the number of samples that passed the filter."""
        return len(self.ix)

    def __getitem__(self, i):
        """Return the `i`-th kept sample with filtered boxes and integer labels."""
        # Work on a shallow copy: overwriting the keys of the wrapped
        # dataset's own dict would corrupt it when that dataset returns
        # stored objects (a second access would then find no matching label).
        d = dict(self.ds[self.ix[i]])
        # Filter labels and boxes in a single paired pass so that both
        # lists always stay aligned.
        kept = [
            (bbox, self.whitelist.index(label))
            for label, bbox in zip(d['labels'], d['bboxes'])
            if label in self.whitelist
        ]
        d['bboxes'] = [bbox for bbox, _ in kept]
        d['labels'] = [label_id for _, label_id in kept]
        return d

Testing...

In [None]:
import matplotlib.pyplot as plt
# Classes to keep; a label's position in this list becomes its integer id.
animals = ['bird', 'cat', 'cow', 'dog', 'horse', 'sheep']
# No transform is given here, so samples come back un-augmented.
voc = VOC(data_path, 'train', download=download)
# Keep only the samples that contain at least one of the listed classes.
voc = FilterClass(voc, animals)
# Display the first kept sample with its bounding boxes drawn on top.
data = voc[0]
plt.imshow(data['image'])
od.draw.bboxes(data['bboxes'], labels=data['labels'])
plt.show()

## Model

We will implement the following model, which is based on [FCOS](https://arxiv.org/abs/1904.01355). Although we are not going to support multi-scale grids or anchors here, we split the model into a `Grid` class and a `Model` class so that you can more easily add multi-scale grids or anchors later if needed. (You may find implementations using multi-scale grids and anchors under the folder `implementations`.)

![Model diagram](model.svg)

In [None]:
# Loss functions used by `Grid.compute_loss`, one per prediction head.
# Bounding boxes: generalized IoU loss.
bboxes_loss = torchvision.ops.generalized_box_iou_loss
# Centerness: binary cross-entropy applied to raw logits.
centerness_loss = torch.nn.BCEWithLogitsLoss()
# Class labels: sigmoid focal loss.
labels_loss = torchvision.ops.sigmoid_focal_loss

In [None]:
class Grid(torch.nn.Module):
    """Single-scale, anchor-free prediction head in the spirit of FCOS.

    Three 1x1 convolutions turn a feature map into class logits, bounding
    boxes and centerness logits. The class also converts these grids into
    detections (`post_process`) and computes the training loss
    (`compute_loss`).
    """

    def __init__(self, in_channels, nclasses, img_size):
        """Create the three prediction heads.

        Args:
            in_channels: number of channels of the incoming feature map.
            nclasses: number of classes (one logit per class).
            img_size: image size handed to the `od` helpers.
        """
        super().__init__()
        self.img_size = img_size
        # As in FCOS, there is no dedicated 'scores' output: the score is
        # taken from the strongest class prediction.
        self.classes = torch.nn.Conv2d(in_channels, nclasses, kernel_size=1)
        self.bboxes = torch.nn.Conv2d(in_channels, 4, kernel_size=1)
        self.centerness = torch.nn.Conv2d(in_channels, 1, kernel_size=1)

    def forward(self, x):
        """Return a dict with the 'labels', 'bboxes' and 'centerness' grids."""
        # As in FCOS, the network predicts boxes in relative terms; they are
        # converted to absolute boxes here because the loss requires them.
        # exp() keeps the raw regression output positive.
        positive_offsets = torch.exp(self.bboxes(x))
        decoded_bboxes = od.transforms.rel_bboxes(positive_offsets, self.img_size)
        outputs = {}
        outputs['labels'] = self.classes(x)
        outputs['bboxes'] = decoded_bboxes
        outputs['centerness'] = self.centerness(x)
        return outputs

    def post_process(self, preds, threshold=0.05):
        """Select the grid cells whose best class probability reaches `threshold`.

        Args:
            preds: dict produced by `forward`.
            threshold: minimum class probability for a cell to be kept.

        Returns:
            Dict with 'scores', 'bboxes' and 'labels' for the kept cells.
        """
        class_probs = torch.sigmoid(preds['labels'])
        best_prob, best_class = class_probs.max(dim=1, keepdim=True)
        boxes = preds['bboxes']
        center_prob = torch.sigmoid(preds['centerness'])
        # The threshold is applied to the class probability alone ...
        keep = best_prob[:, 0] >= threshold
        # ... while, as in FCOS, centerness weighs the final score so that NMS
        # prefers well-centered boxes.
        ranked = best_prob * center_prob
        selected = {}
        for key, grid in (('scores', ranked), ('bboxes', boxes), ('labels', best_class)):
            selected[key] = od.grid.mask_select(keep, grid, True)
        return selected

    def compute_loss(self, preds, targets):
        """Return the sum of the bounding-box, label and centerness losses.

        Args:
            preds: dict produced by `forward`.
            targets: dict holding the ground-truth 'bboxes' and 'labels'.
        """
        # NOTE(review): assumes the two trailing axes of the bboxes grid are
        # its spatial size -- confirm against od.transforms.rel_bboxes.
        grid_hw = preds['bboxes'].shape[2:]
        nclasses = preds['labels'].shape[1]
        mask, indices = od.grid.where(od.grid.slice_all_center, targets['bboxes'], grid_hw, self.img_size)

        # Predictions: grid -> list.
        pred_boxes = od.grid.mask_select(mask, preds['bboxes'])
        pred_logits = od.grid.mask_select(mask, preds['labels'])
        pred_center = od.grid.mask_select(mask, preds['centerness'])

        # Targets: list -> list.
        gt_boxes = od.grid.indices_select(indices, targets['bboxes'])
        gt_classes = od.grid.indices_select(indices, targets['labels'])
        # One-hot labels, since every class has its own independent classifier.
        gt_onehot = torch.nn.functional.one_hot(gt_classes.long(), nclasses).float()

        # Centerness has to be derived in grid space.
        gt_boxes_grid = od.grid.to_grid(mask, indices, targets['bboxes'])
        gt_rel_boxes = od.transforms.rel_bboxes(gt_boxes_grid, self.img_size)
        gt_center = od.grid.mask_select(mask, od.transforms.centerness(gt_rel_boxes))

        box_term = bboxes_loss(pred_boxes, gt_boxes).mean()
        label_term = labels_loss(pred_logits, gt_onehot).mean()
        center_term = centerness_loss(pred_center, gt_center)
        return box_term + label_term + center_term

In [None]:
class Model(torch.nn.Module):
    """ResNet-50 feature extractor followed by a single `Grid` head."""

    def __init__(self, nclasses, img_size):
        """Build the backbone and the prediction head.

        Args:
            nclasses: number of classes predicted by the head.
            img_size: image size forwarded to the head.
        """
        super().__init__()
        trunk = torchvision.models.resnet50(weights='DEFAULT')
        # Drop the last two children of the network (presumably its pooling
        # and classification layers) so that a feature map is returned.
        feature_layers = list(trunk.children())[:-2]
        self.backbone = torch.nn.Sequential(*feature_layers)
        # The head is built for 2048 input channels.
        self.grid = Grid(2048, nclasses, img_size)

    def forward(self, x):
        """Run the backbone and return the head's prediction grids."""
        features = self.backbone(x)
        predictions = self.grid(features)
        return predictions

    def post_process(self, x):
        """Delegate to the head's `post_process`."""
        head = self.grid
        return head.post_process(x)

    def compute_loss(self, preds, targets):
        """Delegate to the head's `compute_loss`."""
        head = self.grid
        return head.compute_loss(preds, targets)

## Data Augmentation

In [None]:
# Size of the training crops; the same value is later passed to the model.
img_size = (256, 256)
# Augmentation pipeline. Bounding boxes are declared in 'pascal_voc' format,
# matching the (xmin, ymin, xmax, ymax) tuples built by the VOC dataset, and
# the 'labels' field is transformed together with them.
transform = A.Compose([
    # Resize to 10% more than the target size so that the random crop varies.
    A.Resize(int(img_size[0]*1.1), int(img_size[1]*1.1)),
    A.RandomCrop(*img_size),
    A.HorizontalFlip(p=0.5),
    # p=1: brightness/contrast jitter is applied to every sample.
    A.RandomBrightnessContrast(p=1),
    # No arguments are given, so the library's default normalization is used.
    A.Normalize(),
    ToTensorV2(),
], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['labels']))

## Training loop

In [None]:
# Classes to keep; a label's position in this list becomes its integer id.
animals = ['bird', 'cat', 'cow', 'dog', 'horse', 'sheep']
tr = VOC(data_path, 'train', transform, download)
# NOTE(review): FilterClass reads every sample once while building its index,
# so the whole dataset goes through the (random) augmentation pipeline here,
# and the index reflects the labels that survived that particular pass.
tr = FilterClass(tr, animals)
# Shuffled batches of 4 samples; od.utils.collate_fn assembles each batch.
tr = torch.utils.data.DataLoader(tr, 4, True, collate_fn=od.utils.collate_fn, num_workers=2, pin_memory=True)

In [None]:
# One output class per whitelisted animal.
K = len(animals)
model = Model(K, img_size).to(device)
# Adam optimizer with a learning rate of 1e-4.
opt = torch.optim.Adam(model.parameters(), 1e-4)

In [None]:
# Put the model in training mode before optimizing.
model.train()
epochs = 100
for epoch in range(epochs):
    tic = time()
    # Accumulates the mean of the per-batch losses of this epoch.
    avg_loss = 0
    for images, targets in tr:
        # Cast the boxes to floating point, then move every target to the
        # training device; `targets` maps each key to a list of tensors.
        targets['bboxes'] = [bb.float() for bb in targets['bboxes']]
        targets = {k: [v.to(device) for v in l] for k, l in targets.items()}
        preds = model(images.to(device))
        loss_value = model.compute_loss(preds, targets)
        # Usual update: clear old gradients, backpropagate, take a step.
        opt.zero_grad()
        loss_value.backward()
        opt.step()
        # len(tr) is the number of batches, so the running sum ends up being
        # the mean batch loss of the epoch.
        avg_loss += float(loss_value) / len(tr)
    toc = time()
    print(f'Epoch {epoch+1}/{epochs} - {toc-tic:.0f}s - Avg loss: {avg_loss}')

If you wish to evaluate the results, you may do so after the model is trained, or even inside the training loop...

In [None]:
# Switch the model to evaluation mode for inference.
# NOTE(review): nothing in this cell switches back to training mode; call
# model.train() afterwards if training continues.
model.eval()
# Take a single batch from the training loader.
images, targets = next(iter(tr))
preds_grid = model(images.to(device))
# Turn the raw prediction grids into detections, then apply NMS.
preds = model.post_process(preds_grid)
preds = od.post.NMS(preds)
# Index, within the batch, of the image to display.
i = 0
# Per-channel constants used to undo the input normalization for display
# (presumably the defaults of A.Normalize() -- confirm). The two leading axes
# let them broadcast over a channels-last image.
mean = torch.tensor([0.485, 0.456, 0.406])[None, None]
std = torch.tensor([0.229, 0.224, 0.225])[None, None]
plt.clf()
# Channels-first -> channels-last, then de-normalize.
plt.imshow(images[i].permute(1, 2, 0)*std+mean)
# Predictions in cyan, annotated as "label (score in percent)".
od.draw.bboxes(preds['bboxes'][i].detach().cpu(), labels=[f'{int(l)} ({int(s*100)})' for l, s in zip(preds['labels'][i], preds['scores'][i])], color='cyan')
# Ground-truth boxes with their integer labels.
od.draw.bboxes(targets['bboxes'][i], labels=[int(l) for l in targets['labels'][i]])
# `epoch` and `avg_loss` are the values left over from the training loop.
plt.suptitle(f'Epoch {epoch+1} - Avg loss: {avg_loss}')
plt.show()

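If you like, you can move the previous code block inside the training loop itself. In that case, call `model.train()` again after the block, because the block switches the model to evaluation mode with `model.eval()`. We also recommend either saving each figure with `plt.savefig()` or replacing `plt.show()` with the following code, which displays the figure in a non-blocking fashion: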

In [None]:
# Show the figure without blocking, then pause briefly so that the GUI event
# loop gets a chance to draw it.
plt.show(block=False)
plt.pause(0.1)