In [44]:
import torch
import torchvision

from collections import OrderedDict
from typing import Tuple, List

from torch import nn
from torchvision.models import mobilenet_v2
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

import torchvision.transforms as transforms
from torchvision.utils import draw_bounding_boxes

In [35]:
backbone = mobilenet_v2(pretrained=True).features
backbone.out_channels = 1280

In [36]:
anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
                                   aspect_ratios=((0.5, 1.0, 2.0),))

In [37]:
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
                                                output_size=7,
                                                sampling_ratio=2)

In [38]:
model = FasterRCNN(backbone,
                   num_classes=2,
                   rpn_anchor_generator=anchor_generator,
                   box_roi_pool=roi_pooler)

In [39]:
model.eval()

FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): Sequential(
    (0): ConvNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): ConvNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(


In [40]:
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400), torch.rand(3, 350, 350)]

In [41]:
predictions = model(x)

In [48]:
print(type(predictions[0]["boxes"]))

<class 'torch.Tensor'>


In [62]:
torch.manual_seed(17)
batch_size = 64

all_transforms = transforms.Compose([transforms.ToTensor()])

train_dataset = torchvision.datasets.Kitti(
    root="./data", train=True, transform=all_transforms, download=True
)
test_dataset = torchvision.datasets.Kitti(
    root="./data", train=False, transform=all_transforms, download=True
)

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=batch_size, shuffle=True
)

sample_idx = torch.randint(len(train_dataset), size=(1,)).item()
img, labels = train_dataset[sample_idx]
pred = model([img])

#boxes = [label["boxes"] for label in labels]
#boxes = torch.tensor(boxes)
boxes = pred[0]["boxes"]
img = (255 * img).to(torch.uint8)

print(boxes, pred[0]['scores'])

score_threshold = .53
dogs_with_boxes = draw_bounding_boxes(img, boxes=boxes[pred[0]['scores'] > score_threshold], width=4)
dogs_with_boxes = torchvision.transforms.ToPILImage()(dogs_with_boxes)
dogs_with_boxes.show()

img = draw_bounding_boxes(
    img, boxes, width=5, colors="green", fill=False
)
img = torchvision.transforms.ToPILImage()(img)
img.show()

tensor([[8.4520e+01, 5.4986e-01, 2.0325e+02, 5.3996e+01],
        [1.0615e+02, 4.0572e+01, 1.3134e+02, 8.0014e+01],
        [1.9400e+02, 1.6337e+02, 5.2052e+02, 2.9013e+02],
        [1.0000e+02, 1.2962e-01, 1.4296e+02, 5.9150e+01],
        [8.7342e-01, 2.5173e+02, 1.0467e+02, 3.3650e+02],
        [3.2612e+02, 2.0096e+02, 3.7103e+02, 2.6013e+02],
        [7.2778e+01, 5.3747e+01, 9.6161e+01, 8.6613e+01],
        [2.6633e+02, 1.0475e+02, 6.3786e+02, 3.2472e+02],
        [1.0686e+02, 6.4214e+01, 1.2821e+02, 1.0165e+02],
        [7.4210e+01, 1.1022e+01, 1.9231e+02, 8.4872e+01],
        [5.8113e-01, 2.1905e+02, 5.4160e+01, 3.7213e+02],
        [7.3252e-01, 1.1208e+02, 1.5157e+02, 3.7110e+02],
        [3.3499e+01, 2.6042e+02, 1.5693e+02, 3.4248e+02],
        [3.6281e+01, 2.6644e+02, 7.9418e+01, 3.2292e+02],
        [7.3467e+01, 3.1986e+01, 1.2257e+02, 9.6408e+01],
        [2.8327e+02, 0.0000e+00, 4.3762e+02, 3.2873e+02],
        [5.1800e+01, 2.4334e+02, 1.5309e+02, 3.1402e+02],
        [1.680