## Model

In [1]:
import torch
import torch.nn as nn

""" 
Information about architecture config:
Tuple is structured by (filters, kernel_size, stride) 
Every conv is a same convolution. 
List is structured by "B" indicating a residual block followed by the number of repeats
"S" is for scale prediction block and computing the yolo loss
"U" is for upsampling the feature map and concatenating with a previous layer
"""
# Ordered layer specification consumed by YOLOv3._create_conv_layers:
#   (out_channels, kernel_size, stride) -> one CNNBlock (padding 1 for 3x3, 0 otherwise)
#   ["B", n]                            -> ResidualBlock with n repeats
#   "S"                                 -> non-residual block + 1x1 conv + ScalePrediction head
#   "U"                                 -> 2x nn.Upsample, then concat with a saved route in forward()
config = [
    (32, 3, 1),
    (64, 3, 2),
    ["B", 1],
    (128, 3, 2),
    ["B", 2],
    (256, 3, 2),
    ["B", 8],  # forward() saves the output of every 8-repeat block as a route connection
    (512, 3, 2),
    ["B", 8],  # second saved route connection
    (1024, 3, 2),
    ["B", 4],  # Everything up to this point is the Darknet-53 backbone
    (512, 1, 1),
    (1024, 3, 1),
    "S",  # first prediction head
    (256, 1, 1),
    "U",  # upsample; concatenated with the most recently saved route
    (256, 1, 1),
    (512, 3, 1),
    "S",  # second prediction head
    (128, 1, 1),
    "U",  # upsample; concatenated with the remaining saved route
    (128, 1, 1),
    (256, 3, 1),
    "S",  # third prediction head
]


class CNNBlock(nn.Module):
    """2-D convolution, optionally followed by BatchNorm2d and LeakyReLU(0.1).

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution.
        bn_act: when True the block computes leaky(bn(conv(x))) and the
            convolution has no bias; when False it computes conv(x) only and
            the convolution has a bias.
        **kwargs: forwarded unchanged to nn.Conv2d (kernel_size, stride, ...).

    The batch-norm and activation modules are always constructed, even when
    bn_act is False, so the set of state_dict keys does not depend on bn_act.
    """

    def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
        super().__init__()
        self.use_bn_act = bn_act
        has_bias = not bn_act
        self.conv = nn.Conv2d(in_channels, out_channels, bias=has_bias, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)
        self.leaky = nn.LeakyReLU(0.1)

    def forward(self, x):
        """Apply the convolution, then batch norm + activation if enabled."""
        features = self.conv(x)
        if not self.use_bn_act:
            return features
        return self.leaky(self.bn(features))


class ResidualBlock(nn.Module):
    """Stack of (1x1 squeeze -> 3x3 expand) CNNBlock pairs, optionally residual.

    Args:
        channels: channel count of both the input and the output.
        use_residual: when True each pair's output is added to its input;
            when False the pairs are simply applied one after another.
        num_repeats: how many pairs to stack. Also stored on the instance
            because YOLOv3.forward inspects it.
    """

    def __init__(self, channels, use_residual=True, num_repeats=1):
        super().__init__()
        self.use_residual = use_residual
        self.num_repeats = num_repeats

        half = channels // 2
        pairs = [
            nn.Sequential(
                CNNBlock(channels, half, kernel_size=1),
                CNNBlock(half, channels, kernel_size=3, padding=1),
            )
            for _ in range(num_repeats)
        ]
        self.layers = nn.ModuleList(pairs)

    def forward(self, x):
        """Run x through every pair, adding the skip connection if enabled."""
        for pair in self.layers:
            out = pair(x)
            x = x + out if self.use_residual else out
        return x


class ScalePrediction(nn.Module):
    """Prediction head for one scale: 3 anchors x (num_classes + 5) values per cell.

    Args:
        in_channels: channel count of the incoming feature map.
        num_classes: number of object classes.

    forward() maps an input of shape (N, in_channels, H, W) to
    (N, 3, H, W, num_classes + 5).
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.num_classes = num_classes

        hidden_channels = 2 * in_channels
        out_channels = (num_classes + 5) * 3
        self.pred = nn.Sequential(
            CNNBlock(in_channels, hidden_channels, kernel_size=3, padding=1),
            # Final 1x1 conv emits raw values: no batch norm, no activation.
            CNNBlock(hidden_channels, out_channels, bn_act=False, kernel_size=1),
        )

    def forward(self, x):
        """Return raw predictions with the per-anchor values on the last axis."""
        raw = self.pred(x)
        batch, height, width = x.shape[0], x.shape[2], x.shape[3]
        # (N, 3 * (C + 5), H, W) -> (N, 3, C + 5, H, W) -> (N, 3, H, W, C + 5)
        per_anchor = raw.reshape(batch, 3, self.num_classes + 5, height, width)
        return per_anchor.permute(0, 1, 3, 4, 2)


class YOLOv3(nn.Module):
    """YOLOv3 network assembled from the module-level `config` list.

    Args:
        in_channels: channel count of the input images.
        num_classes: number of object classes predicted by every head.

    forward() returns a list with one tensor per ScalePrediction layer, in the
    order the heads appear in `config`.
    """

    def __init__(self, in_channels=3, num_classes=80):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.layers = self._create_conv_layers()

    def forward(self, x):
        """Run the network and collect the output of every prediction head."""
        scale_outputs = []  # one entry per ScalePrediction layer
        saved_routes = []  # feature maps kept for later concatenation

        for layer in self.layers:
            if isinstance(layer, ScalePrediction):
                # Heads branch off the trunk; x itself is left untouched.
                scale_outputs.append(layer(x))
                continue

            x = layer(x)

            if isinstance(layer, ResidualBlock):
                # Only the 8-repeat residual blocks feed the upsampling path.
                if layer.num_repeats == 8:
                    saved_routes.append(x)
            elif isinstance(layer, nn.Upsample):
                # Merge with the most recently saved route (last in, first out).
                x = torch.cat([x, saved_routes.pop()], dim=1)

        return scale_outputs

    def _create_conv_layers(self):
        """Translate `config` into an nn.ModuleList, tracking the channel count."""
        modules = nn.ModuleList()
        channels = self.in_channels

        for spec in config:
            if isinstance(spec, tuple):
                out_channels, kernel_size, stride = spec
                padding = 1 if kernel_size == 3 else 0
                block = CNNBlock(
                    channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                )
                modules.append(block)
                channels = out_channels

            elif isinstance(spec, list):
                repeats = spec[1]
                modules.append(ResidualBlock(channels, num_repeats=repeats))

            elif spec == "S":
                halved = channels // 2
                modules.extend(
                    [
                        ResidualBlock(channels, use_residual=False, num_repeats=1),
                        CNNBlock(channels, halved, kernel_size=1),
                        ScalePrediction(halved, num_classes=self.num_classes),
                    ]
                )
                channels = halved

            elif spec == "U":
                modules.append(nn.Upsample(scale_factor=2))
                # After upsampling, forward() concatenates a saved route that
                # has twice as many channels, hence the factor of three.
                channels = channels * 3

        return modules

In [2]:
# test model
# Smoke test: build a 20-class model and check the output shape of each head.
model = YOLOv3(num_classes=20)

# Batch of 2 random 3-channel 416 x 416 inputs.
dummy_tensor = torch.randn((2, 3, 416, 416))

outputs = model(dummy_tensor)

# Expected layout per head: (batch, 3 anchors, grid, grid, num_classes + 5),
# with grid sizes 416 / 32, 416 / 16 and 416 / 8.
assert outputs[0].shape == (2, 3, 416 // 32, 416 // 32, 20 + 5)
assert outputs[1].shape == (2, 3, 416 // 16, 416 // 16, 20 + 5)
assert outputs[2].shape == (2, 3, 416 // 8, 416 // 8, 20 + 5)
print("Success!")

Success!


# Inference

In [3]:
import cv2
import torch
import numpy as np

def _resize(image: np.ndarray, imsize=416) -> np.ndarray:
    """Rescale `image` so that its longer spatial side is about `imsize` pixels.

    Args:
        image: array whose first two axes are (height, width).
        imsize: target length of the longer side.

    Returns:
        The resized image; the aspect ratio is preserved because the same
        factor is applied to both axes.
    """
    # Use only height and width. Taking the max over the full shape would let
    # the channel axis win for images smaller than 3 px and would disagree
    # with the scale computed in preprocess(), which uses image.shape[:2].
    ratio = imsize / max(image.shape[:2])
    return cv2.resize(image, (0, 0), fx=ratio, fy=ratio)

def _pad_to_square(image: np.ndarray) -> np.ndarray:
    """Zero-pad `image` at the bottom and right so height equals width.

    Args:
        image: array whose first two axes are (height, width); any further
            axes (e.g. channels) are left unpadded. Plain 2-D (grayscale)
            arrays are accepted as well.

    Returns:
        A new array with shape (side, side, ...) where side = max(height, width).
        The original content stays in the top-left corner.
    """
    height, width = image.shape[:2]
    side = max(height, width)
    # One (before, after) pair per axis; only the two spatial axes grow.
    pad_width = [(0, side - height), (0, side - width)]
    pad_width += [(0, 0)] * (image.ndim - 2)
    return np.pad(image, pad_width)

def preprocess(images, imsize=416, mean=[0, 0, 0], std=[1, 1, 1], device='cpu'):
    """Turn a list of images into a normalised batch tensor.

    Each image is resized (longer side -> imsize), zero-padded to a square,
    converted with cv2.COLOR_BGR2RGB, scaled to [0, 1] and normalised with
    `mean` / `std`.

    Args:
        images: list of image arrays; assumed to be H x W x 3 BGR as returned
            by cv2.imread, given the BGR2RGB conversion applied here.
        imsize: target side length passed to the resize step.
        mean: per-channel mean subtracted after division by 255.
        std: per-channel divisor applied after mean subtraction.
        device: device on which the batch tensor is created.

    Returns:
        (images, scales, samples): the untouched input list, the per-image
        factor max(height, width) / imsize, and the float batch tensor in
        (N, 3, H, W) layout.
    """
    mean = torch.tensor(mean, dtype=torch.float, device=device).view(1, 3, 1, 1)
    std = torch.tensor(std, dtype=torch.float, device=device).view(1, 3, 1, 1)

    resized = [_resize(image, imsize=imsize) for image in images]
    squared = [_pad_to_square(sample) for sample in resized]
    tensors = []
    for sample in squared:
        rgb = cv2.cvtColor(sample, cv2.COLOR_BGR2RGB)
        tensors.append(torch.from_numpy(rgb))

    # (N, H, W, C) -> (N, C, H, W)
    batch = torch.stack(tensors, dim=0).to(device)
    batch = batch.permute(0, 3, 1, 2).contiguous()
    samples = (batch.float().div(255.) - mean) / std

    # Factor that maps network-input coordinates back to the original image.
    scales = []
    for image in images:
        longer_side = max(image.shape[:2])
        scales.append(longer_side / imsize)

    return images, scales, samples

In [14]:
from torchvision import ops

def _empty_prediction(device):
    '''Placeholder returned for an image without any detection above threshold.

    The label is an int64 -1 so that it has the same dtype as real labels, and
    all tensors live on `device` like the real detections.
    '''
    return {'boxes': torch.tensor([[0., 0., 1., 1.]], device=device),
            'labels': torch.tensor([-1], dtype=torch.long, device=device),
            'scores': torch.tensor([0.], device=device)}


def _decode_scale(pred, anchor, imsize):
    '''Decode one scale: [N x A x S x S x (5 + C)] -> (labels, scores, boxes).'''
    N, A, S, device = pred.shape[0], pred.shape[1], pred.shape[2], pred.device
    num_boxes = A * S * S

    # anchor: 1 x A x 1 x 1 x 2
    anchor = torch.as_tensor(anchor, dtype=torch.float, device=device)
    anchor = anchor.reshape(1, A, 1, 1, 2)

    # Cell offsets; broadcasting replaces a materialised N x A x S x S grid.
    # The last grid axis is x (columns), the one before it is y (rows).
    cols = torch.arange(S, device=device).view(1, 1, 1, S, 1)
    rows = torch.arange(S, device=device).view(1, 1, S, 1, 1)

    # N x A x S x S -> N x (A * S * S)
    scores = torch.sigmoid(pred[..., 0]).reshape(N, num_boxes)
    labels = torch.argmax(pred[..., 5:], dim=-1).reshape(N, num_boxes)

    # Box centres in input pixels: (sigmoid(offset) + cell index) * cell size.
    cell_size = imsize / S
    x = (torch.sigmoid(pred[..., 1:2]) + cols) * cell_size
    y = (torch.sigmoid(pred[..., 2:3]) + rows) * cell_size
    xy = torch.cat([x, y], dim=-1).reshape(N, num_boxes, 2)

    # Box sizes in input pixels: anchor (fraction of imsize) * exp(raw) * imsize.
    wh = (anchor * torch.exp(pred[..., 3:5]) * imsize).reshape(N, num_boxes, 2)

    # N x (A * S * S) x 4 as (x1, y1, x2, y2). All four coordinates are
    # clamped to the image; previously only x1 and x2 were, so boxes could
    # extend above or below the image.
    boxes = torch.cat([xy - wh / 2, xy + wh / 2], dim=-1)
    boxes = boxes.clamp(min=0, max=imsize)

    return labels, scores, boxes


def postprocess(preds, anchors, imsize=416, iou_threshold=0.5, score_threshold=0.05):
    '''Decode raw head outputs into per-image detections and apply class-wise NMS.

    Args:
        preds: sequence with one tensor per scale, each
            [N x A x S x S x (5 + C)] laid out as (to, tx, ty, tw, th, classes...).
        anchors: per-scale anchor sizes, [num_scales x A x 2] as (pw, ph),
            expressed as a fraction of imsize.
        imsize: side length in pixels of the square network input.
        iou_threshold: IoU threshold handed to batched NMS.
        score_threshold: detections whose objectness score is not strictly
            greater than this value are discarded.

    Returns:
        A list with one dict per image:
            'boxes':  [K x 4] float tensor, (x1, y1, x2, y2) in network-input
                      pixels, clamped to [0, imsize]
            'labels': [K] int64 tensor of class indices
            'scores': [K] float tensor of objectness scores
        An image without any detection gets a single placeholder entry with
        box [0, 0, 1, 1], label -1 and score 0.
    '''
    decoded = [_decode_scale(pred, anchor, imsize) for pred, anchor in zip(preds, anchors)]

    # Concatenate every scale along the box axis.
    batch_labels = torch.cat([item[0] for item in decoded], dim=1)  # N x total
    batch_scores = torch.cat([item[1] for item in decoded], dim=1)  # N x total
    batch_boxes = torch.cat([item[2] for item in decoded], dim=1)   # N x total x 4
    device = batch_scores.device

    predictions = []
    for i in range(batch_scores.shape[0]):
        keep = batch_scores[i] > score_threshold
        if not bool(keep.any()):
            predictions.append(_empty_prediction(device))
            continue

        labels = batch_labels[i, keep]
        boxes = batch_boxes[i, keep, :]
        scores = batch_scores[i, keep]

        # NMS is performed independently per class label.
        nms_idx = ops.boxes.batched_nms(boxes, scores, labels, iou_threshold=iou_threshold)
        if nms_idx.shape[0] == 0:
            predictions.append(_empty_prediction(device))
            continue

        predictions.append({'boxes': boxes[nms_idx, :],
                            'labels': labels[nms_idx],
                            'scores': scores[nms_idx]})

    return predictions

In [15]:
# load pretrained weight
# Build a 20-class model and load the checkpoint onto the CPU.
model = YOLOv3(num_classes=20)
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
state_dict = torch.load(f='checkpoint/pretrained_weight/78.1map_0.2threshold_PASCAL.tar', map_location='cpu')
# The checkpoint is a dict; the model weights are stored under its 'state_dict' key.
model.load_state_dict(state_dict=state_dict['state_dict'])
# Switch to evaluation mode for inference.
model = model.eval()

In [45]:
# preprocessing
image_paths = [
    '../efficient_det_pytorch/dataset/PASCALVOC2007/VOCtrainval_06-Nov-2007/VOCdevkit/VOC2007/JPEGImages/000005.jpg',
    '../efficient_det_pytorch/dataset/PASCALVOC2007/VOCtrainval_06-Nov-2007/VOCdevkit/VOC2007/JPEGImages/000012.jpg',
    '../efficient_det_pytorch/dataset/PASCALVOC2007/VOCtrainval_06-Nov-2007/VOCdevkit/VOC2007/JPEGImages/000016.jpg',
]

images = [cv2.imread(image_path) for image_path in image_paths]

# cv2.imread returns None instead of raising when a file is missing or cannot
# be decoded; fail here with a clear message rather than later inside preprocess.
unreadable_paths = [path for path, image in zip(image_paths, images) if image is None]
if unreadable_paths:
    raise FileNotFoundError(f'Could not read image file(s): {unreadable_paths}')

# images: original images, scales: factors back to original pixels, samples: batch tensor
images, scales, samples = preprocess(images)

In [46]:
# prediction
# Inference only: run the forward pass without building the autograd graph.
with torch.no_grad():
    preds = model(samples)

print(f'Input Shape: {samples.shape}')
# One line per prediction head; axis 2 of each output is the grid size S.
for scale_pred in preds:
    print(f'Output Shape at S={scale_pred.shape[2]}: {scale_pred.shape}')

Input Shape: torch.Size([3, 3, 416, 416])
Output Shape at S=13: torch.Size([3, 3, 13, 13, 25])
Output Shape at S=26: torch.Size([3, 3, 26, 26, 25])
Output Shape at S=52: torch.Size([3, 3, 52, 52, 25])


In [47]:
# postprocessing
# NOTE(review): the postprocess(...) call in the next cell passes literal
# thresholds (0.5 / 0.5) instead of these two variables — confirm which values
# are intended.
iou_threshold = 0.45
score_threshold = 0.05

# Side length of the square network input, in pixels.
imsize = 416
# Grid sizes of the three heads; not referenced by the postprocess call below
# (postprocess reads the grid size from the prediction tensors).
S = [13, 26, 52]
# Three (width, height) anchors per scale, as a fraction of imsize.
anchors = [[(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],    # S = 13
           [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],   # S = 26
           [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],]  # S = 52

In [48]:
# Decode the raw head outputs into one {'boxes', 'labels', 'scores'} dict per image.
# NOTE(review): both thresholds are passed as literals here, so the
# `iou_threshold` (0.45) and `score_threshold` (0.05) variables defined in the
# previous cell are not used — confirm which values are intended.
predictions = postprocess(preds=preds, anchors=anchors, imsize=imsize, iou_threshold=0.5, score_threshold=0.5)

In [49]:
# Display the decoded detections (one dict per image); boxes are in network-input pixels.
predictions

[{'boxes': tensor([[135.5854, 210.2635, 222.2740, 310.1468],
          [204.4330, 161.3284, 261.8168, 271.7595],
          [  4.1186, 195.3734,  63.7083, 307.3726],
          [229.5404, 154.1949, 261.5659, 182.4224],
          [261.1333, 149.7752, 289.5691, 179.2450],
          [194.2008, 159.0178, 245.7994, 249.7052],
          [179.5570, 156.9499, 226.7502, 219.2694],
          [216.1679, 159.5498, 269.2601, 233.0432],
          [231.1824, 154.0324, 254.3342, 173.7911],
          [279.5289, 152.8962, 305.2201, 177.5389],
          [377.7064, 150.8492, 413.7950, 320.2241],
          [253.7750, 155.2566, 282.9843, 186.1714],
          [378.3737, 141.9657, 413.5199, 313.8051],
          [220.0476, 154.6475, 266.1848, 203.0758]]),
  'labels': tensor([ 8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8, 14,  8]),
  'scores': tensor([0.8458, 0.8184, 0.8114, 0.7521, 0.7256, 0.6656, 0.6551, 0.6264, 0.5843,
          0.5616, 0.5572, 0.5501, 0.5443, 0.5120])},
 {'boxes': tensor([[126.4862,  70.2691

In [50]:
# Class names, ordered so that the position in the list is the class index
# produced by the model.
classes = [
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
]
# Reverse lookup: class name -> class index.
classes2idx = {name: index for index, name in enumerate(classes)}

In [51]:
# Draw every detection on its original image and show the result.
for image, scale, pred in zip(images, scales, predictions):
    longer_side = max(image.shape[:2])
    # Integer division gives 0 for images under 600 px; keep lines at least 1 px thick.
    thickness = max(1, longer_side // 600)
    fontscale = longer_side / 700

    boxes = pred['boxes'].cpu().numpy()
    labels = pred['labels'].cpu().numpy().astype(np.int64)
    scores = pred['scores'].cpu().numpy()

    # postprocess marks "no detection" with label -1; skip those entries
    # instead of letting classes[-1] silently pick the last class name.
    detected = labels >= 0
    boxes, labels, scores = boxes[detected], labels[detected], scores[detected]

    # Map boxes from network-input pixels back to original-image pixels.
    # A new array is created on purpose: the NumPy view shares memory with the
    # tensor in `predictions`, so scaling in place would rescale the boxes
    # again every time this cell is re-run.
    boxes = (boxes * scale).astype(np.int32)

    for box, score, label in zip(boxes, scores, labels):
        class_name = classes[label]
        color = (np.random.randint(200, 255),
                 np.random.randint(50, 200),
                 np.random.randint(0, 150))
        # Plain Python ints are accepted as point coordinates by every OpenCV version.
        x1, y1, x2, y2 = box.tolist()

        cv2.rectangle(
            img=image,
            pt1=(x1, y1),
            pt2=(x2, y2),
            color=color,
            thickness=thickness
        )

        cv2.putText(
            img=image,
            text=f'{class_name}: {score: .4f}',
            org=(x1, y1),
            fontFace=cv2.FONT_HERSHEY_PLAIN,
            fontScale=fontscale,
            color=color,
            thickness=thickness,
            lineType=cv2.LINE_AA)

        # The image is shown after every box, in a window titled with the
        # class name, and waits for a key press each time.
        # NOTE(review): confirm whether showing once per image was intended.
        cv2.imshow(class_name, image)
        cv2.waitKey()
        cv2.destroyAllWindows()