In [None]:
import os
import torch

from torchvision.io import read_image
from torchvision.ops.boxes import masks_to_boxes
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F
from torchvision.ops import box_convert
from einops import rearrange

In [None]:
# Pick the GPU when available. A later cell re-assigns `device` (adding an MPS check),
# so that later value is the one the rest of the notebook ends up using.
device = "cuda" if torch.cuda.is_available() else "cpu"
device

# DataLoader

In [None]:
from torch.utils.data import DataLoader
from torchvision.transforms import v2

In [None]:
# Earlier Hugging Face `datasets` experiment, kept for reference (superseded by the VOC loader below):
# from datasets import load_dataset
# train_dataset = load_dataset(...)
import torch
import torchvision.transforms.v2
from torch.utils.data.dataset import Dataset
import xml.etree.ElementTree as ET
from torchvision import tv_tensors
from torchvision.io import read_image


def load_images_and_anns(im_sets, label2idx, ann_fname, split, max_detections=25):
    r"""
    Collect image paths and ground-truth objects from Pascal-VOC style XML annotations.

    For every root in ``im_sets`` the image ids are read from
    ``<root>/ImageSets/Main/<ann_fname>.txt`` and each ``<root>/Annotations/<id>.xml``
    is parsed for the image size and its annotated objects.

    :param im_sets: VOC root directories to scan (e.g. ``.../VOCdevkit/VOC2007``).
    :param label2idx: Mapping from class name (the ``<name>`` tag) to integer label.
    :param ann_fname: Image-set file stem, e.g. ``'trainval'`` or ``'test'``; ``.txt`` is appended.
    :param split: Not used by this function; kept so existing callers keep working.
    :param max_detections: Images with more objects than this are skipped (default 25).
    :return: List of dicts with keys ``img_id``, ``filename`` (JPEG path), ``width``,
        ``height`` and ``detections``; each detection is
        ``{'label': int, 'bbox': [xmin, ymin, xmax, ymax], 'difficult': int}`` with every
        coordinate shifted down by 1 (VOC's 1-based pixels to 0-based).
    :raises KeyError: If an object's class name is not in ``label2idx``.
    """
    im_infos = []
    for im_set in im_sets:
        list_path = os.path.join(im_set, 'ImageSets', 'Main', '{}.txt'.format(ann_fname))
        # The with-block closes the list file; blank lines (e.g. a trailing newline) are
        # skipped instead of being turned into a non-existent '.xml' annotation path.
        with open(list_path) as list_file:
            im_names = [line.strip() for line in list_file if line.strip()]

        ann_dir = os.path.join(im_set, 'Annotations')
        im_dir = os.path.join(im_set, 'JPEGImages')

        for im_name in im_names:
            ann_file = os.path.join(ann_dir, '{}.xml'.format(im_name))
            ann_info = ET.parse(ann_file)
            size = ann_info.getroot().find('size')

            detections = []
            for obj in ann_info.findall('object'):
                bbox_info = obj.find('bndbox')
                # The <difficult> tag may be absent in some VOC-style files; treat that as 0.
                difficult_tag = obj.find('difficult')
                detections.append({
                    'label': label2idx[obj.find('name').text],
                    # float() first so coordinates written as decimals (e.g. "48.0") also parse.
                    'bbox': [int(float(bbox_info.find(tag).text)) - 1
                             for tag in ('xmin', 'ymin', 'xmax', 'ymax')],
                    'difficult': int(difficult_tag.text) if difficult_tag is not None else 0,
                })

            img_id = os.path.basename(ann_file).split('.xml')[0]
            im_info = {
                'img_id': img_id,
                'filename': os.path.join(im_dir, '{}.jpg'.format(img_id)),
                'width': int(size.find('width').text),
                'height': int(size.find('height').text),
                'detections': detections,
            }
            # The detector has a fixed number of query slots, so images with more objects
            # than that are skipped (with the default of 25 this drops only ~15 VOC images).
            # NOTE(review): SimpleDETR in this notebook builds 10 decoder queries, not 25.
            if len(detections) <= max_detections:
                im_infos.append(im_info)
    print('Total {} images found'.format(len(im_infos)))
    return im_infos


class VOCDataset(Dataset):
    """Pascal VOC detection dataset yielding resized, ImageNet-normalised images.

    ``__getitem__(index)`` returns ``(image, targets, filename)``:

    * ``image`` -- float32 tensor resized to ``(im_size, im_size)`` and normalised with
      ImageNet mean/std;
    * ``targets`` -- dict with ``boxes`` (XYXY, divided by the output width/height so they
      lie in 0-1), ``labels`` and ``difficult``;
    * ``filename`` -- path of the source JPEG.

    Label indices: the 20 VOC class names sorted alphabetically take 0-19 and
    ``'background'`` is appended last (index 20).
    """

    def __init__(self, split, im_sets, im_size=640):
        """
        :param split: ``'train'`` reads ``trainval.txt``; any other value reads ``test.txt``.
        :param im_sets: VOC root directories (VOC2007 / VOC2007+VOC2012 / VOC2007-test).
        :param im_size: Side length of the square output images.
        """
        self.split = split
        self.im_sets = im_sets
        self.fname = 'trainval' if self.split == 'train' else 'test'
        self.im_size = im_size
        # 0-255 pixel mean: RandomZoomOut runs while the image is still uint8.
        self.im_mean = [123.0, 117.0, 104.0]
        self.imagenet_mean = [0.485, 0.456, 0.406]
        self.imagenet_std = [0.229, 0.224, 0.225]

        T = torchvision.transforms.v2
        self.transforms = {
            'train': T.Compose([
                T.RandomHorizontalFlip(p=0.5),
                T.RandomZoomOut(fill=self.im_mean),
                T.RandomIoUCrop(),
                T.RandomPhotometricDistort(),
                T.Resize(size=(self.im_size, self.im_size)),
                # Drop degenerate boxes and keep the labels/difficult flags in sync with them.
                T.SanitizeBoundingBoxes(
                    labels_getter=lambda transform_input:
                    (transform_input[1]["labels"], transform_input[1]["difficult"])),
                T.ToPureTensor(),
                T.ToDtype(torch.float32, scale=True),
                T.Normalize(mean=self.imagenet_mean, std=self.imagenet_std),
            ]),
            'test': T.Compose([
                T.Resize(size=(self.im_size, self.im_size)),
                T.ToPureTensor(),
                T.ToDtype(torch.float32, scale=True),
                T.Normalize(mean=self.imagenet_mean, std=self.imagenet_std),
            ]),
        }

        classes = sorted([
            'person', 'bird', 'cat', 'cow', 'dog', 'horse', 'sheep',
            'aeroplane', 'bicycle', 'boat', 'bus', 'car', 'motorbike', 'train',
            'bottle', 'chair', 'diningtable', 'pottedplant', 'sofa', 'tvmonitor'
        ])
        # 'background' goes LAST (index 20) so it lines up with the "no object" target id
        # used by the loss, and the real classes keep indices 0-19.
        classes = classes + ['background']

        self.label2idx = {name: idx for idx, name in enumerate(classes)}
        self.idx2label = {idx: name for idx, name in enumerate(classes)}
        print(self.idx2label)
        self.images_info = load_images_and_anns(self.im_sets,
                                                self.label2idx,
                                                self.fname,
                                                self.split)

    def __len__(self):
        return len(self.images_info)

    def __getitem__(self, index):
        im_info = self.images_info[index]
        # Force 3-channel RGB so Normalize's 3-element mean/std always matches,
        # even for a grayscale JPEG.
        im = read_image(im_info['filename'], mode=torchvision.io.ImageReadMode.RGB)

        detections = im_info['detections']
        targets = {}
        # reshape(-1, 4) keeps an image without objects as an empty (0, 4) box tensor.
        targets['boxes'] = tv_tensors.BoundingBoxes(
            torch.as_tensor([d['bbox'] for d in detections], dtype=torch.int64).reshape(-1, 4),
            format='XYXY', canvas_size=im.shape[-2:])
        targets['labels'] = torch.as_tensor(
            [d['label'] for d in detections], dtype=torch.int64)
        targets['difficult'] = torch.as_tensor(
            [d['difficult'] for d in detections], dtype=torch.int64)

        # Always the deterministic 'test' pipeline: random augmentation is applied later,
        # per batch, by the notebook's collate_fn, so the 'train' pipeline is currently unused.
        im_tensor, targets = self.transforms["test"](im, targets)

        h, w = im_tensor.shape[-2:]
        # Divide XYXY pixel coordinates by (w, h, w, h) so boxes are normalised to 0-1.
        wh_tensor = torch.as_tensor([[w, h, w, h]]).expand_as(targets['boxes'])
        targets["boxes"] = targets["boxes"] / wh_tensor
        return im_tensor, targets, im_info['filename']
# test_dataset = load_dataset("detection-datasets/coco", split='val', streaming=True)

In [None]:
# dataset

In [None]:
import os
import torch
import torchvision.transforms.v2
from torch.utils.data.dataset import Dataset
import xml.etree.ElementTree as ET
from torchvision import tv_tensors
from torchvision.io import read_image


def load_images_and_anns(im_sets, label2idx, ann_fname, split, max_detections=25):
    r"""
    Collect image paths and ground-truth objects from Pascal-VOC style XML annotations.

    For every root in ``im_sets`` the image ids are read from
    ``<root>/ImageSets/Main/<ann_fname>.txt`` and each ``<root>/Annotations/<id>.xml``
    is parsed for the image size and its annotated objects.

    :param im_sets: VOC root directories to scan (e.g. ``.../VOCdevkit/VOC2007``).
    :param label2idx: Mapping from class name (the ``<name>`` tag) to integer label.
    :param ann_fname: Image-set file stem, e.g. ``'trainval'`` or ``'test'``; ``.txt`` is appended.
    :param split: Not used by this function; kept so existing callers keep working.
    :param max_detections: Images with more objects than this are skipped (default 25).
    :return: List of dicts with keys ``img_id``, ``filename`` (JPEG path), ``width``,
        ``height`` and ``detections``; each detection is
        ``{'label': int, 'bbox': [xmin, ymin, xmax, ymax], 'difficult': int}`` with every
        coordinate shifted down by 1 (VOC's 1-based pixels to 0-based).
    :raises KeyError: If an object's class name is not in ``label2idx``.
    """
    im_infos = []
    for im_set in im_sets:
        list_path = os.path.join(im_set, 'ImageSets', 'Main', '{}.txt'.format(ann_fname))
        # The with-block closes the list file; blank lines (e.g. a trailing newline) are
        # skipped instead of being turned into a non-existent '.xml' annotation path.
        with open(list_path) as list_file:
            im_names = [line.strip() for line in list_file if line.strip()]

        ann_dir = os.path.join(im_set, 'Annotations')
        im_dir = os.path.join(im_set, 'JPEGImages')

        for im_name in im_names:
            ann_file = os.path.join(ann_dir, '{}.xml'.format(im_name))
            ann_info = ET.parse(ann_file)
            size = ann_info.getroot().find('size')

            detections = []
            for obj in ann_info.findall('object'):
                bbox_info = obj.find('bndbox')
                # The <difficult> tag may be absent in some VOC-style files; treat that as 0.
                difficult_tag = obj.find('difficult')
                detections.append({
                    'label': label2idx[obj.find('name').text],
                    # float() first so coordinates written as decimals (e.g. "48.0") also parse.
                    'bbox': [int(float(bbox_info.find(tag).text)) - 1
                             for tag in ('xmin', 'ymin', 'xmax', 'ymax')],
                    'difficult': int(difficult_tag.text) if difficult_tag is not None else 0,
                })

            img_id = os.path.basename(ann_file).split('.xml')[0]
            im_info = {
                'img_id': img_id,
                'filename': os.path.join(im_dir, '{}.jpg'.format(img_id)),
                'width': int(size.find('width').text),
                'height': int(size.find('height').text),
                'detections': detections,
            }
            # The detector has a fixed number of query slots, so images with more objects
            # than that are skipped (with the default of 25 this drops only ~15 VOC images).
            # NOTE(review): SimpleDETR in this notebook builds 10 decoder queries, not 25.
            if len(detections) <= max_detections:
                im_infos.append(im_info)
    print('Total {} images found'.format(len(im_infos)))
    return im_infos


class VOCDataset(Dataset):
    """Pascal VOC detection dataset yielding resized, ImageNet-normalised images.

    ``__getitem__(index)`` returns ``(image, targets, filename)``:

    * ``image`` -- float32 tensor resized to ``(im_size, im_size)`` and normalised with
      ImageNet mean/std;
    * ``targets`` -- dict with ``boxes`` (XYXY, divided by the output width/height so they
      lie in 0-1), ``labels`` and ``difficult``;
    * ``filename`` -- path of the source JPEG.

    Label indices: the 20 VOC class names sorted alphabetically take 0-19 and
    ``'background'`` is appended last (index 20).
    """

    def __init__(self, split, im_sets, im_size=640):
        """
        :param split: ``'train'`` reads ``trainval.txt``; any other value reads ``test.txt``.
        :param im_sets: VOC root directories (VOC2007 / VOC2007+VOC2012 / VOC2007-test).
        :param im_size: Side length of the square output images.
        """
        self.split = split
        self.im_sets = im_sets
        self.fname = 'trainval' if self.split == 'train' else 'test'
        self.im_size = im_size
        # 0-255 pixel mean: RandomZoomOut runs while the image is still uint8.
        self.im_mean = [123.0, 117.0, 104.0]
        self.imagenet_mean = [0.485, 0.456, 0.406]
        self.imagenet_std = [0.229, 0.224, 0.225]

        T = torchvision.transforms.v2
        self.transforms = {
            'train': T.Compose([
                T.RandomHorizontalFlip(p=0.5),
                T.RandomZoomOut(fill=self.im_mean),
                T.RandomIoUCrop(),
                T.RandomPhotometricDistort(),
                T.Resize(size=(self.im_size, self.im_size)),
                # Drop degenerate boxes and keep the labels/difficult flags in sync with them.
                T.SanitizeBoundingBoxes(
                    labels_getter=lambda transform_input:
                    (transform_input[1]["labels"], transform_input[1]["difficult"])),
                T.ToPureTensor(),
                T.ToDtype(torch.float32, scale=True),
                T.Normalize(mean=self.imagenet_mean, std=self.imagenet_std),
            ]),
            'test': T.Compose([
                T.Resize(size=(self.im_size, self.im_size)),
                T.ToPureTensor(),
                T.ToDtype(torch.float32, scale=True),
                T.Normalize(mean=self.imagenet_mean, std=self.imagenet_std),
            ]),
        }

        classes = sorted([
            'person', 'bird', 'cat', 'cow', 'dog', 'horse', 'sheep',
            'aeroplane', 'bicycle', 'boat', 'bus', 'car', 'motorbike', 'train',
            'bottle', 'chair', 'diningtable', 'pottedplant', 'sofa', 'tvmonitor'
        ])
        # 'background' goes LAST (index 20) so it lines up with the "no object" target id
        # used by the loss, and the real classes keep indices 0-19.
        classes = classes + ['background']

        self.label2idx = {name: idx for idx, name in enumerate(classes)}
        self.idx2label = {idx: name for idx, name in enumerate(classes)}
        print(self.idx2label)
        self.images_info = load_images_and_anns(self.im_sets,
                                                self.label2idx,
                                                self.fname,
                                                self.split)

    def __len__(self):
        return len(self.images_info)

    def __getitem__(self, index):
        im_info = self.images_info[index]
        # Force 3-channel RGB so Normalize's 3-element mean/std always matches,
        # even for a grayscale JPEG.
        im = read_image(im_info['filename'], mode=torchvision.io.ImageReadMode.RGB)

        detections = im_info['detections']
        targets = {}
        # reshape(-1, 4) keeps an image without objects as an empty (0, 4) box tensor.
        targets['boxes'] = tv_tensors.BoundingBoxes(
            torch.as_tensor([d['bbox'] for d in detections], dtype=torch.int64).reshape(-1, 4),
            format='XYXY', canvas_size=im.shape[-2:])
        targets['labels'] = torch.as_tensor(
            [d['label'] for d in detections], dtype=torch.int64)
        targets['difficult'] = torch.as_tensor(
            [d['difficult'] for d in detections], dtype=torch.int64)

        # Always the deterministic 'test' pipeline: random augmentation is applied later,
        # per batch, by the notebook's collate_fn, so the 'train' pipeline is currently unused.
        im_tensor, targets = self.transforms["test"](im, targets)

        h, w = im_tensor.shape[-2:]
        # Divide XYXY pixel coordinates by (w, h, w, h) so boxes are normalised to 0-1.
        wh_tensor = torch.as_tensor([[w, h, w, h]]).expand_as(targets['boxes'])
        targets["boxes"] = targets["boxes"] / wh_tensor
        return im_tensor, targets, im_info['filename']

In [None]:
# Normalisation-only pipeline: uint8 [0, 255] -> float32 [0, 1] -> ImageNet mean/std.
# Only the commented-out preprocess_img below refers to it.
transforms = v2.Compose(transforms=[
    v2.ToDtype(torch.float32, scale=True),  # Normalize needs floating-point input
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
# def preprocess_img(batch):
#     img = batch['image']
#     img = transforms(img)
#     w = batch['width']
#     h = batch['height']
#     obj = batch['objects']
#     boxes_xywh = torch.tensor(obj['bbox'])
#     boxes_xyxy = box_convert(boxes_xywh, 'xywh', 'xyxy')
#     boxes_xyxy[:, [0, 2]] = boxes_xyxy[:, [0, 2]] / w
#     boxes_xyxy[:, [1, 3]] = boxes_xyxy[:, [1, 3]] / h
#     boxes_xyxy = torch.clamp(boxes_xyxy, 0, 1)
#     batch['new_image'] = img
#     batch['boxes_xyxy'] = boxes_xyxy
#     batch['labels'] = torch.tensor(obj["category"])
#     obj_mask = torch.zeros(320, 320)
#     for obx in boxes_xyxy:
#         x1, y1, x2, y2 = tuple(torch.round(obx * 320).to(torch.int32))
#         # print(x1, y1, x2, y2)
#         obj_mask[x1:x2, y1:y2] = 1
#     batch['mask'] = obj_mask
#     batch['image'] = -1
#     return batch

In [None]:
# import random

# # train_ds = dataset['train'].shuffle().select(range(0, 20000))
# train_ds = list(train_dataset.take(10000))
# # test_ds = dataset['val'].shuffle().select(range(0, 1000))
# test_ds = list(test_dataset.take(1000))
# len(train_ds), len(test_ds)

In [None]:
from torchvision import tv_tensors
def preprocess_ds(batch):
    """Convert one ``VOCDataset`` sample into the dict consumed by ``collate_fn``.

    :param batch: ``(image, targets, filename)`` as returned by ``VOCDataset.__getitem__``;
        the image is assumed to be ``(C, H, W)`` and ``targets['boxes']`` XYXY boxes
        normalised to [0, 1].
    :return: dict with ``new_image``, ``boxes_xyxy`` (``BoundingBoxes`` on a 1x1 canvas,
        because the coordinates are normalised), ``labels`` and ``mask`` -- an ``(H, W)``
        float tensor that is 1 inside any ground-truth box, indexed ``mask[y, x]`` so it
        overlays the image.
    """
    img = batch[0]
    targets = batch[1]
    # Dim -2 is the row (y) axis and dim -1 the column (x) axis of a (C, H, W) image.
    height, width = img.shape[-2:]
    boxes_xyxy = targets['boxes'].to(torch.float32)

    output = {}
    output['new_image'] = img
    output['boxes_xyxy'] = tv_tensors.BoundingBoxes(
        boxes_xyxy,
        format="XYXY", canvas_size=(1, 1)
    )
    output['labels'] = targets["labels"]

    # Scale normalised x coordinates by the width and y coordinates by the height.
    scale = torch.tensor([width, height, width, height], dtype=torch.float32)
    obj_mask = torch.zeros(height, width)
    for box in boxes_xyxy:
        x1, y1, x2, y2 = torch.round(box * scale).to(torch.int32).tolist()
        # Rows are y and columns are x; the previous [x1:x2, y1:y2] indexing
        # produced a mask transposed relative to the image.
        obj_mask[y1:y2, x1:x2] = 1
    output['mask'] = obj_mask
    return output

In [None]:
import torch
import argparse
import os
import numpy as np
import yaml
import random
from tqdm import tqdm
from torch.utils.data.dataloader import DataLoader
from torch.optim.lr_scheduler import MultiStepLR

# Prefer Apple's MPS backend when present, otherwise CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print('Using mps')
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def collate_function(data):
    """Transpose a batch of samples into per-field tuples.

    ``[(a1, b1), (a2, b2)]`` becomes ``((a1, a2), (b1, b2))``; nothing is stacked.
    """
    per_field = zip(*data)
    return tuple(per_field)


voc = VOCDataset('train',
                 im_sets=["/kaggle/input/pascal-voc-2007/VOCtrainval_06-Nov-2007/VOCdevkit/VOC2007"],
                 im_size=320)
# Materialise every preprocessed sample up front; random augmentation happens per batch in collate_fn.
train_ds = [preprocess_ds(voc[idx]) for idx in tqdm(range(len(voc)))]

voc2 = VOCDataset('test',
                 im_sets=["/kaggle/input/pascal-voc-2007/VOCtest_06-Nov-2007/VOCdevkit/VOC2007"],
                 im_size=320)
test_ds = [preprocess_ds(voc2[idx]) for idx in tqdm(range(len(voc2)))]

In [None]:
# Peek at one preprocessed sample (dict with new_image, boxes_xyxy, labels, mask).
train_ds[0]

In [None]:
# Number of preprocessed train / test samples.
len(train_ds), len(test_ds)

In [None]:
# Per-sample augmentation applied inside collate_fn. Boxes live on a normalised 1x1 canvas
# (see preprocess_ds); ClampBoundingBoxes keeps them inside that canvas after rotation.
train_augs = v2.Compose(transforms=[
    v2.RandomRotation(degrees=30),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ClampBoundingBoxes(),
])

In [None]:
# Random Resize
# 5x5 grid of (h, w) resize targets from 256 to 384 in 32-pixel steps; row-major over h.
r_resize_list = []
for i in range(-2, 3):
    r_resize_list.extend(v2.Resize((320 + i * 32, 320 + j * 32)) for j in range(-2, 3))
len(r_resize_list)

In [None]:
# Index 12 is the centre of the 5x5 grid, i.e. the unchanged 320x320 resize.
r_resize_list[12]

In [None]:
# Target size of grid entry 1: (h, w) = (256, 288).
r_resize_list[1].size

In [None]:
import random

def collate_fn(in_batch):
    """Randomly augment and batch samples produced by ``preprocess_ds``.

    :param in_batch: list of dicts with keys ``new_image``, ``labels``, ``boxes_xyxy``, ``mask``.
    :return: ``(images, targets)`` -- the stacked images and a list of
        ``{"labels", "boxes", "mask"}`` dicts, one per sample.
    """
    images = []
    targets = []
    # Multi-scale training (a random pick from r_resize_list) is currently disabled.
    for sample in in_batch:
        # The mask goes through the SAME train_augs call as the image and boxes, so one set
        # of random flip/rotation parameters is applied to all three and they stay aligned.
        img, bbox, mask = train_augs(
            sample['new_image'], sample['boxes_xyxy'], tv_tensors.Mask(sample['mask']))
        images.append(img)
        targets.append({
            "labels": sample['labels'],
            "boxes": bbox,
            # Unwrap to a plain tensor, the same type the mask had before augmentation.
            "mask": mask.as_subclass(torch.Tensor),
        })
    return torch.stack(images), targets

In [None]:
def eval_collate_fn(in_batch):
    """Batch preprocessed samples for evaluation, without any random augmentation.

    Returns the same ``(images, targets)`` structure as ``collate_fn``.
    """
    images = torch.stack([sample['new_image'] for sample in in_batch])
    targets = [
        {"labels": sample['labels'], "boxes": sample['boxes_xyxy'], "mask": sample['mask']}
        for sample in in_batch
    ]
    return images, targets


train_loader = DataLoader(
    train_ds,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)
# Evaluation data must not be randomly rotated/flipped, and a fixed order keeps runs comparable.
test_loader = DataLoader(
    test_ds,
    batch_size=32,
    shuffle=False,
    num_workers=4,
    collate_fn=eval_collate_fn,
)

In [None]:
%%time
# Time fetching one augmented training batch; show the image batch shape and the first target.
for batch in train_loader:
    print(batch[0].shape)
    print(batch[1][0])
    break

In [None]:
import cProfile

In [None]:
# cProfile.run(
# """
# for batch in train_loader:
#     print(batch[0].shape)
#     print(batch[1][0])
#     break
# """
# , sort='tottime')

# Naive DETR Model

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
class ResNetBlock(nn.Module):
    """Two-path residual block: a strided 1x1 projection added to a 3x3 conv stack.

    Both paths downsample by ``stride``, so their outputs have the same shape and can be summed.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Shortcut path: a single 1x1 conv handles both the channel change and the stride.
        self.path1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=1, stride=stride, padding=0)
        # Main path: conv -> GELU -> GroupNorm with one group -> strided conv.
        self.path2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1),
        )

    def forward(self, x):
        return self.path1(x) + self.path2(x)

In [None]:
class SimpleBackBone(nn.Module):
    """Small from-scratch CNN: four stride-2 ResNetBlocks plus one global shortcut.

    The overall downsampling factor is 16 (2 ** 4), matched by the stride-16 1x1 skip conv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.resnet1 = ResNetBlock(in_channels, 16, 2)
        self.resnet2 = ResNetBlock(16, 32, 2)
        self.resnet3 = ResNetBlock(32, 64, 2)
        self.resnet4 = ResNetBlock(64, out_channels, 2)
        # Shortcut straight from the input to the output resolution.
        self.whole_skip = nn.Conv2d(in_channels, out_channels,
                                    kernel_size=1, stride=16, padding=0)

    def forward(self, x):
        features = x
        for stage in (self.resnet1, self.resnet2, self.resnet3, self.resnet4):
            features = stage(features)
        return features + self.whole_skip(x)

In [None]:
class BBoxHead(nn.Module):
    """Residual MLP head whose outputs are squashed into (0, 1) with a sigmoid.

    ``in_proj`` -> activation -> ``num_layers - 2`` residual (Linear + activation) blocks
    -> ``out_proj`` -> sigmoid. Used for normalised box coordinates / sizes.
    """

    def __init__(self, in_shape, out_shape, middle_dim=256, num_layers=4, act_fn=nn.GELU):
        super().__init__()
        self.in_proj = nn.Linear(in_shape, middle_dim)
        self.out_proj = nn.Linear(middle_dim, out_shape)
        hidden_blocks = []
        for _ in range(num_layers - 2):
            hidden_blocks.append(nn.Sequential(nn.Linear(middle_dim, middle_dim), act_fn()))
        self.middle = nn.ModuleList(hidden_blocks)
        # Stored as the activation *class*; forward instantiates it for the input projection.
        self.act_fn = act_fn

    def forward(self, x):
        hidden = self.act_fn()(self.in_proj(x))
        for block in self.middle:
            # Residual connection around every hidden block.
            hidden = block(hidden) + hidden
        return torch.sigmoid(self.out_proj(hidden))

In [None]:
import torchvision.ops as ops

class DeformConvBlock(nn.Module):
    """3x3 deformable convolution with input-predicted sampling offsets, followed by GELU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = ops.DeformConv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # 2 coordinates x 1 offset group x 3 x 3 kernel taps = 18 offset channels.
        offset_channels = 2 * 1 * 3 * 3
        self.offset = nn.Sequential(
            nn.Conv2d(in_channels, offset_channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )
        # tanh bounds offsets to (-1, 1); this learnable factor (initialised to 3) rescales them.
        self.offset_scale = nn.Parameter(torch.tensor(3.0))

    def forward(self, x):
        offsets = self.offset(x) * self.offset_scale
        return nn.GELU()(self.conv(x, offsets))

class Projection(nn.Module):
    """Three stacked DeformConvBlocks: the first changes channels, the other two are residual."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.dconv1 = DeformConvBlock(in_channels, out_channels)
        self.dconv2 = DeformConvBlock(out_channels, out_channels)
        self.dconv3 = DeformConvBlock(out_channels, out_channels)

    def forward(self, x):
        out = self.dconv1(x)
        for block in (self.dconv2, self.dconv3):
            out = block(out) + out
        return out

In [None]:
from torchvision.models import *

class SimpleDETR(nn.Module):
    """DETR-style detector: ResNet-18 backbone -> deformable-conv projection -> nn.Transformer.

    Decoder queries (``box_Q``) are a weighted sum (scalars in ``query_weights``) of: a
    linear pooling of the flattened stride-32 features, a learned query embedding, and
    linear poolings of ResNet layer1/2/3 feature maps. Several layer sizes are hard-coded
    for 320x320 inputs (layer1/2/3 maps of 80x80, 40x40, 20x20 and a 10x10 final map,
    i.e. 100 tokens), and the decoder always produces 10 queries.

    ``forward`` returns ``(class_pred, box_pred, impt_pixels, bbox_size)``.
    """
    def __init__(self, hidden_dim, ffn_dim, n_heads, num_layers, num_classes):
        """
        :param hidden_dim: Transformer width (d_model) and projection output channels.
        :param ffn_dim: Transformer feed-forward width.
        :param n_heads: Number of attention heads.
        :param num_layers: ``(num_encoder_layers, num_decoder_layers)``.
        :param num_classes: Number of object classes; ``fc_class`` outputs ``num_classes + 1``.
        """
        super().__init__()

        # self.backbone = SimpleBackBone(3, 64)
        # Pretrained ResNet-18; avgpool/fc are disabled because forward() calls the
        # individual stages directly to get intermediate feature maps.
        self.backbone = resnet18(weights=ResNet18_Weights.DEFAULT)
        self.backbone.avgpool = nn.Identity()
        self.backbone.fc = nn.Identity()

        # 512 = channels of ResNet-18 layer4.
        self.proj = Projection(512, hidden_dim)

        # self.proj = nn.Sequential(
        #     nn.Conv2d(512, 128, kernel_size=3, padding=1, stride=1),
        #     nn.GELU(),
        #     nn.Conv2d(128, hidden_dim, kernel_size=1, padding=0, stride=1),
        #     nn.GELU(),
        # )
        # Positional embedding initialised to sin(channel + row + col) on a 32x32 grid,
        # then learned (it is an nn.Parameter) and scaled by pos_embed_weight in forward.
        sine_embed = torch.zeros(hidden_dim, 32, 32)
        sine_embed = sine_embed + torch.arange(32).expand(hidden_dim, 32, 32)
        sine_embed = sine_embed + torch.arange(32).expand(hidden_dim, 32, 32).permute(0, 2, 1)
        sine_embed = sine_embed + torch.arange(hidden_dim).expand(32, 32, hidden_dim).permute(2, 0, 1)
        sine_embed = torch.sin(sine_embed).unsqueeze(0)
        self.pos_embed = nn.Parameter(sine_embed) # 32 is max image width/height after backbone
        self.pos_embed_weight = nn.Parameter(torch.tensor(0.1))
        self.transformer = nn.Transformer(
            hidden_dim,
            n_heads,
            num_layers[0],
            num_layers[1],
            ffn_dim,
            activation = "gelu",
            dropout=0.1,
            batch_first=True
        )

        # Query sources from ResNet layer1/2/3: Conv1d lifts channels (64/128/256) to
        # hidden_dim, then a Linear pools the flattened spatial positions
        # (80*80 / 40*40 / 20*20 for a 320x320 input) down to 10 query slots.
        self.query_1 = nn.Sequential(
            nn.Conv1d(64, hidden_dim, kernel_size=1, stride=1, padding=0),
            nn.GELU(),
            nn.Linear(80 * 80, 10)
        )
        self.query_2 = nn.Sequential(
            nn.Conv1d(128, hidden_dim, kernel_size=1, stride=1, padding=0),
            nn.GELU(),
            nn.Linear(40 * 40, 10)
        )
        self.query_3 = nn.Sequential(
            nn.Conv1d(256, hidden_dim, kernel_size=1, stride=1, padding=0),
            nn.GELU(),
            nn.Linear(20 * 20, 10)
        )
        # Pools the 10*10 = 100 encoder tokens to 10 query slots.
        self.box_query = nn.Linear(10 * 10, 10)
        self.box_embed = nn.Parameter(torch.rand(1, 10, hidden_dim))
        # Mixing weights: [0..2] layer1/2/3 queries, [3] box_query, [4] box_embed.
        self.query_weights = nn.Parameter(torch.tensor([0.05, 0.05, 0.2, 1.0, 0.1]))

        self.fc_class = nn.Linear(hidden_dim, num_classes + 1) # +1 for the <no object> class
        self.fc_bbox = BBoxHead(hidden_dim, 4, num_layers=3)

        # NOTE(review): this head already ends in Softmax; if the auxiliary loss feeds its
        # output to nn.CrossEntropyLoss (e.g. pixel_criteria), softmax is applied twice —
        # confirm against the loss code.
        self.impt_pts = nn.Sequential(
            nn.Conv2d(hidden_dim, 128, kernel_size=3, padding=1, stride=1),
            nn.GELU(),
            nn.Conv2d(128, 2, kernel_size=1, padding=0, stride=1),
            nn.Softmax(dim=1),
        ) # custom auxiliary loss to predict whether a point is important

        self.bbox_sizes = BBoxHead(hidden_dim, 2, num_layers=3) # custom auxiliary loss to predict size of bboxes
        # hopefully help with small ones

    def forward(self, x):
        """
        :param x: image batch, assumed (batch, 3, H, W) with H = W = 320 (several layers
            above are sized for 320x320 — TODO confirm before using other sizes).
        :return: ``class_pred`` (batch, 10, num_classes + 1) logits, ``box_pred``
            (batch, 10, 4) in (0, 1), ``impt_pixels`` (batch, 2, H/32, W/32) softmax
            probabilities, ``bbox_size`` (batch, 10, 2) in (0, 1).
        """
        batch_size = x.shape[0]
        # NOTE: for (B, C, H, W) input, dim 2 is the height, so the names w/h are swapped;
        # every use below indexes the dims in the same order, so results are unaffected.
        w = x.shape[2]
        h = x.shape[3]

        # BACKBONE PORTION
        # x = self.backbone(x)
        x = self.backbone.relu(self.backbone.bn1(self.backbone.conv1(x)))
        x = self.backbone.maxpool(x)
        # print(x.shape)
        # 80 by 80
        x1 = self.backbone.layer1(x)
        # print(x1.shape)
        # 80 by 80
        x2 = self.backbone.layer2(x1)
        # print(x2.shape)
        # 40 by 40
        x3 = self.backbone.layer3(x2)
        # print(x3.shape)
        # 20 by 20
        x = self.backbone.layer4(x3)
        # print(x.shape)
        # END BACKBONE PORTION

        x = x.reshape(batch_size, 512, w//32, h//32)
        x = self.proj(x)
        batch_size = x.shape[0]
        hidden = x.shape[1]
        w = x.shape[2]
        h = x.shape[3]
        x = x + self.pos_embed[:, :, :w, :h] * self.pos_embed_weight
        # [batch size, hidden_dim, H/32, W/32]
        impt_pixels = self.impt_pts(x)
        # [batch size, 2, H/32, W/32]
        x = x.reshape(batch_size, hidden, -1)
        seq_len = x.shape[2]
        # [batch size, hidden_dim, seq len]  (seq len = 100 for a 320x320 input)
        # Treat it like [batch size, hidden_dim, sequence len] WILL PERMUTE LATER

        box_Q = self.query_weights[3] * self.box_query(x).permute(0, 2, 1)
        # [batch size, 10, hidden_dim]
        box_Q = box_Q + self.query_weights[4] * self.box_embed
        # [batch size, 10, hidden_dim]
        # x1 = model.backbone.layer1[1](x1)
        x1 = rearrange(x1, 'b c w h -> b c (w h)')
        box_Q = box_Q + self.query_weights[0] * self.query_1(x1).permute(0, 2, 1)
        # x2 = model.backbone.layer2[1](x2)
        x2 = rearrange(x2, 'b c w h -> b c (w h)')
        box_Q = box_Q + self.query_weights[1] * self.query_2(x2).permute(0, 2, 1)
        # x3 = model.backbone.layer3[1](x3)
        x3 = rearrange(x3, 'b c w h -> b c (w h)')
        box_Q = box_Q + self.query_weights[2] * self.query_3(x3).permute(0, 2, 1)
        # [batch size, 10, hidden_dim]
        x = x.permute(0, 2, 1)
        # [batch size, seq len, hidden_dim]
        x = self.transformer(x, box_Q) # encoder input and decoder input
        # [batch size, 10, hidden_dim] — length comes from box_Q, the decoder input

        class_pred = self.fc_class(x)
        # [batch size, 10, num classes + 1]
        box_pred = self.fc_bbox(x)
        # [batch size, 10, 4]
        bbox_size = self.bbox_sizes(x)
        # [batch size, 10, 2]
        return class_pred, box_pred, impt_pixels, bbox_size

# Loss Computation

In [None]:
# One weight per VOC class (20) plus the "no object" slot (last index), starting at 1.
class_loss_weights = torch.ones(20 + 1).to(device)

In [None]:
from tqdm import tqdm

In [None]:
# cnt = 0
# for _, target in tqdm(train_loader):
#     cnt += 1
#     if cnt >= 100:
#         break
#     for dictionary in target:
#         class_loss_weights[dictionary['labels']] += 1

In [None]:
# Set the "no object" slot (last index) to the sum of all 21 entries divided by 20.
# NOTE(review): the counting cell above is commented out, so every count is still 1 here.
# Not idempotent: re-running this cell changes the last entry again.
class_loss_weights[-1] = torch.sum(class_loss_weights) / 20 

In [None]:
# Convert to inverse-frequency weights, shrink the "no object" slot by a further 20x,
# then normalise so the weights sum to 1.
# NOTE(review): this reads the current class_loss_weights, so re-running the cell inverts
# them again (not idempotent). class_criteria in the next section is built with
# weight=None, so these weights are not passed to it — confirm whether they are used later.
class_loss_weights = 1 / class_loss_weights
class_loss_weights[-1] = class_loss_weights[-1] / 20
class_loss_weights = class_loss_weights / torch.sum(class_loss_weights)
class_loss_weights

In [None]:
from scipy.optimize import linear_sum_assignment
# just import package to do bipartite matching
from torchvision.ops import generalized_box_iou, box_iou, generalized_box_iou_loss, box_convert

# Classification cost with reduction="none": one loss value per (query, target) pair, which
# match_bbox_class reshapes into a [num_queries, num_targets] cost matrix.
class_criteria = nn.CrossEntropyLoss(weight=None, reduction="none", label_smoothing=0.05)
# Target index for "no object": the last of the num_classes + 1 class outputs.
no_labels_target_id = 20
# Presumably for the auxiliary heads (important-pixel map and box-size prediction).
pixel_criteria = nn.CrossEntropyLoss()
bbox_size_criteria = nn.SmoothL1Loss()

def match_bbox_class(
    class_pred_in,
    box_pred_in,
    pred_pixels,
    bbox_size_in,
    target_classes_in,
    target_bbox_in,
    target_pixels,
    weights=(2, 5, 0),
):
    """Hungarian-match each image's predictions to its targets and compute the loss.

    For every image, a [num_queries, n_targets] cost matrix is built from three terms:
    a per-pair class cross-entropy, an L1 box distance and a (1 - GIoU) term, weighted
    by ``weights``. ``linear_sum_assignment`` then picks the matching. The loss sums
    four parts: the normalised cost of the matched pairs, a "no object"
    classification term for unmatched queries, a box-size term and an extra IoU term.

    Args:
        class_pred_in: class logits, [batch, num_queries, num_classes + 1].
        box_pred_in: predicted boxes, [batch, num_queries, 4]. They are fed to
            ``generalized_box_iou``/``box_iou``, so xyxy is assumed (presumably
            normalised to [0, 1] -- TODO confirm).
        pred_pixels: unused while the pixel/mask loss is disabled.
        bbox_size_in: predicted box sizes, [batch, num_queries, 2]. They are compared
            against target (x2 - x1, y2 - y1).
        target_classes_in: per-image LongTensors of target class ids (not batched).
        target_bbox_in: per-image [n_targets, 4] target boxes (not batched).
        target_pixels: unused while the pixel/mask loss is disabled.
        weights: (class, L1 box, GIoU) weights, used for both the matching cost and
            the matched loss.

    Returns:
        ``(final_loss, mean_iou, exact_losses)``:
        - ``final_loss``: the differentiable summed loss.
        - ``mean_iou``: the matched-pair IoU averaged over the batch, detached and
          only for logging.
        - ``exact_losses``: a CPU tensor of the logged components
          [class, bbox, no_label, mask, size], averaged over the batch. The mask
          component is always 0 while the pixel loss is disabled.
    """
    batch_size, num_bb, num_class = class_pred_in.shape  # num_class already includes the +1
    pred_device = class_pred_in.device
    weight_sum = weights[0] + weights[1] + weights[2]

    final_loss = 0
    total_iou = 0

    # Logged per-component losses
    total_class_loss = 0
    total_bbox_loss = 0
    total_no_label_loss = 0
    total_mask_loss = 0  # pixel loss disabled; slot kept so the logged vector has 5 entries
    total_size_loss = 0

    # Each image has a different number of targets, so matching is done per image.
    for i in range(batch_size):
        class_pred = class_pred_in[i]  # [num_bb, num_class]
        box_pred = box_pred_in[i]  # [num_bb, 4]
        bbox_size = bbox_size_in[i]  # [num_bb, 2]
        target_classes = target_classes_in[i]
        target_bbox = target_bbox_in[i]
        target_len = len(target_classes)

        # Queries that end up unmatched are trained towards the "no object" class.
        unmatched = torch.ones(num_bb, dtype=torch.bool, device=pred_device)

        if target_len > 0:
            # CLASS COST: cross-entropy of every query against every target class.
            # It is evaluated on num_bb * target_len pairs, then folded back to a matrix.
            pair_logits = class_pred.unsqueeze(1).expand(-1, target_len, -1)
            pair_targets = target_classes.unsqueeze(0).expand(num_bb, -1)
            class_loss = class_criteria(
                pair_logits.reshape(-1, num_class), pair_targets.reshape(-1)
            ).reshape(num_bb, target_len)

            # BOX L1 COST: [num_bb, target_len]
            bbox_loss = torch.cdist(box_pred, target_bbox, p=1.0)

            # GIoU COST: [num_bb, target_len]
            giou_loss = 1 - generalized_box_iou(box_pred, target_bbox)

            total_cost = weights[0] * class_loss + weights[1] * bbox_loss + weights[2] * giou_loss
            # linear_sum_assignment minimises the total cost, so lower = better match.
            row_idx, col_idx = linear_sum_assignment(total_cost.detach().cpu().numpy())

            # Matched-pair loss, normalised by the sum of all three weights.
            final_loss += torch.mean(total_cost[row_idx, col_idx] / weight_sum)
            total_class_loss += torch.mean(weights[0] * class_loss[row_idx, col_idx]) / (weights[0] + weights[1])
            total_bbox_loss += torch.mean(weights[1] * bbox_loss[row_idx, col_idx]) / (weights[0] + weights[1])

            unmatched[torch.as_tensor(row_idx, device=pred_device)] = False

        # Force unmatched queries towards the "no object" class. This is computed
        # once, on the un-expanded logits; the mean equals the mean over the
        # per-target-repeated rows.
        if unmatched.any():
            bad_class = class_pred[unmatched]
            no_label_target = torch.full(
                (bad_class.shape[0],), no_labels_target_id, dtype=torch.long, device=pred_device
            )
            no_label_loss = class_criteria(bad_class, no_label_target).mean()
            final_loss += no_label_loss
            total_no_label_loss += no_label_loss

        if target_len > 0:
            # BOX SIZE LOSS on matched pairs (intended to help with smaller objects).
            target_bbox_size = target_bbox[:, [2, 3]] - target_bbox[:, [0, 1]]
            size_loss = bbox_size_criteria(
                bbox_size[row_idx].reshape(-1), target_bbox_size[col_idx].reshape(-1)
            ) * 10
            final_loss += size_loss
            total_size_loss += size_loss

            # IoU of matched pairs, detached: it is only reported, never backpropagated.
            iou_matrix = box_iou(box_pred, target_bbox)
            total_iou += torch.mean(iou_matrix[row_idx, col_idx]).detach()

            # Extra term over pairs with IoU > 0.6 or IoU < 0.3. It is guarded
            # because an empty selection would make torch.mean return NaN.
            # NOTE(review): adding mean IoU to a minimised loss pushes these boxes
            # *apart*, which looks opposite to the "encourage matching" intent for
            # the IoU > 0.6 pairs -- confirm the sign.
            iou_mask = (iou_matrix > 0.6) | (iou_matrix < 0.3)
            if iou_mask.any():
                final_loss += torch.mean(iou_matrix[iou_mask]) * 0.5

    exact_losses = torch.tensor([
        float(total_class_loss),
        float(total_bbox_loss),
        float(total_no_label_loss),
        float(total_mask_loss),
        float(total_size_loss),
    ]) / batch_size
    return final_loss, total_iou / batch_size, exact_losses

# Model, Optimizers, Schedulers

In [None]:
# Positional constructor arguments for SimpleDETR (defined earlier in the notebook).
# The trailing 20 is presumably the number of object classes (VOC) -- TODO confirm.
_detr_args = (32, 32, 2, (1, 1), 20)
model = SimpleDETR(*_detr_args).to(device)

In [None]:
# Total number of parameters in the model (trainable or not).
sum(map(torch.Tensor.numel, model.parameters()))

In [None]:
# model

In [None]:
# optimiser = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
# `backbone` and `pos_embed` train at a reduced LR of 1e-5. Every other group uses
# the AdamW default LR of 2e-4. Groups are listed explicitly, so any parameter of
# `model` not named here is NOT optimised.
_slow_lr = 1e-5
_param_groups = [
    {'params': model.backbone.parameters(), 'lr': _slow_lr},
    {'params': model.proj.parameters()},
    {'params': model.pos_embed, 'lr': _slow_lr},
    {'params': model.pos_embed_weight},
    {'params': model.transformer.parameters()},
]
_param_groups += [
    {'params': getattr(model, name).parameters()}
    for name in ('query_1', 'query_2', 'query_3', 'box_query')
]
_param_groups += [
    {'params': model.box_embed},
    {'params': model.query_weights},
]
_param_groups += [
    {'params': getattr(model, name).parameters()}
    for name in ('fc_class', 'fc_bbox', 'impt_pts', 'bbox_sizes')
]
optimiser = optim.AdamW(_param_groups, lr=2e-4, weight_decay=1e-4)

# NOTE(review): eta_min=1e-4 is larger than the 1e-5 base LR of the slow groups,
# so stepping this scheduler would *raise* their LR towards 1e-4. scheduler.step()
# is also not called by train() or the epoch loop, so LRs currently stay constant.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimiser, T_max=10, eta_min=1e-4)

# Training and Evaluation Loops

In [None]:
from tqdm import tqdm

In [None]:
from torchmetrics.detection.mean_ap import MeanAveragePrecision

In [None]:
# for batch in (pbar := tqdm(train_loader)):
#     image, targets = batch
#     # image = image.to(device)        
#     all_labels = []
#     all_boxes  = []
#     all_masks = []
#     for t in targets:
#         t["boxes"] = t["boxes"].resources
#         t["boxes"] = torch.cat([t["boxes"], t["boxes"] * (1 + torch.randn(1) / 100)], dim=0)
#         t["boxes"] = torch.cat([t["boxes"], t["boxes"] * (1 + torch.randn(1) / 100)], dim=0)
#         t["boxes"] = torch.cat([t["boxes"], t["boxes"] * (1 + torch.randn(1) / 100)], dim=0)
#         t["boxes"] = torch.cat([t["boxes"], t["boxes"] * (1 + torch.randn(1) / 100)], dim=0)
#         print(t["boxes"])
#         1/0

In [None]:
def train(weights=(2, 5, 1)):
    """Run one training epoch over ``train_loader`` and print averaged loss components.

    The progress bar shows the running mean loss and the running mean matched-pair
    IoU over the batches seen so far.

    Args:
        weights: (class, L1 box, GIoU) weights forwarded to ``match_bbox_class``.
    """
    model.train()
    total_loss = 0.0
    total_iou = 0.0
    cnt = 0
    # Running sum of the 5 logged components: class, bbox, no_label, mask, size
    exact_loss = torch.zeros(5)
    for batch in (pbar := tqdm(train_loader)):
        image, targets = batch
        image = image.to(device)
        all_labels = []
        all_boxes = []
        all_masks = []
        for t in targets:
            # `.data` kept from the original; presumably unwraps the tv_tensors box
            # type into a plain tensor -- TODO confirm it is still needed.
            t["boxes"] = t["boxes"].data
            all_labels.append(t["labels"].to(device))
            all_boxes.append(t["boxes"].to(device))
            all_masks.append(t["mask"].to(torch.long).to(device))

        class_pred, box_pred, mask_pred, size_pred = model(image)

        loss, iou, exact_losses = match_bbox_class(class_pred, box_pred, mask_pred, size_pred, all_labels, all_boxes, all_masks, weights)
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        total_loss += loss.item()
        total_iou += float(iou)  # also drops the reference to the per-batch tensor
        exact_loss += exact_losses
        cnt += 1
        pbar.set_description(f"Average Loss: {total_loss / cnt:6f} | Average IoU: {total_iou / cnt:6f}")

    # max(..., 1) avoids 0/0 NaNs when the loader yields no batches.
    exact_loss = exact_loss / max(cnt, 1)
    loss_type_dict = {
        'total_class_loss': exact_loss[0],
        'total_bbox_loss': exact_loss[1],
        'total_no_label_loss': exact_loss[2],
        'total_mask_loss': exact_loss[3],
        'total_size_loss': exact_loss[4],
    }
    print(loss_type_dict)

In [None]:
import torchvision.ops as ops

def train_map():
    """Compute and print bbox mAP of the model on ``train_loader``.

    Per image, predictions are post-processed in three steps:
    1. Drop queries whose argmax is the "no object" class.
    2. Drop boxes whose area after scaling by 320 is <= 4 px^2.
    3. Apply class-agnostic NMS at IoU 0.5.
    Predicted and target boxes are both scaled by 320 before being passed to
    torchmetrics.
    """
    model.eval()
    # A fresh metric per call, so the result covers exactly one pass over the loader.
    map_metric = MeanAveragePrecision(box_format="xyxy", iou_type="bbox").to(device)
    with torch.no_grad():
        for image, targets in tqdm(train_loader):
            image = image.to(device)
            all_labels = [t["labels"].to(device) for t in targets]
            all_boxes = [t["boxes"].to(device) for t in targets]

            class_pred, box_pred, _, _ = model(image)

            # AP Metrics Computation
            for i in range(image.shape[0]):
                logits = class_pred[i]
                boxes_i = box_pred[i]
                pred_labels = torch.argmax(logits, dim=-1)
                has_labels = pred_labels != no_labels_target_id
                # Clamp each side at 0. Otherwise an inverted box (x2 < x1 and y2 < y1)
                # gets a positive width*height and slips past the size filter.
                widths = (boxes_i[:, 2] - boxes_i[:, 0]).clamp(min=0)
                heights = (boxes_i[:, 3] - boxes_i[:, 1]).clamp(min=0)
                box_area = widths * heights * 320 * 320
                final_mask = torch.logical_and(has_labels, box_area > 4)
                boxes = boxes_i[final_mask] * 320
                scores = logits.softmax(dim=-1).max(dim=-1).values[final_mask]
                labels = pred_labels[final_mask]
                # NOTE(review): class-agnostic NMS; ops.batched_nms would keep
                # overlapping boxes of different classes.
                idx = ops.nms(boxes, scores, 0.5)
                preds = [{
                    "boxes": boxes[idx],
                    "scores": scores[idx],
                    "labels": labels[idx],
                }]
                target = [{
                    "boxes": all_boxes[i] * 320,
                    "labels": all_labels[i],
                }]
                map_metric.update(preds, target)
    print("Train Set Results:")
    print(map_metric.compute())

In [None]:
def test(weights=(2, 5, 1)):
    """Evaluate the matching loss on ``test_loader`` without gradients and print averaged components.

    The progress bar shows the running mean loss and the running mean matched-pair
    IoU over the batches seen so far.

    Args:
        weights: (class, L1 box, GIoU) weights forwarded to ``match_bbox_class``.
    """
    model.eval()
    total_loss = 0.0
    total_iou = 0.0
    cnt = 0
    # Running sum of the 5 logged components: class, bbox, no_label, mask, size
    exact_loss = torch.zeros(5)
    with torch.no_grad():
        for batch in (pbar := tqdm(test_loader)):
            image, targets = batch
            image = image.to(device)
            all_labels = []
            all_boxes = []
            all_masks = []
            for t in targets:
                all_labels.append(t["labels"].to(device))
                all_boxes.append(t["boxes"].to(device))
                all_masks.append(t["mask"].to(torch.long).to(device))

            class_pred, box_pred, mask_pred, size_pred = model(image)

            loss, iou, exact_losses = match_bbox_class(class_pred, box_pred, mask_pred, size_pred, all_labels, all_boxes, all_masks, weights)
            total_loss += loss.item()
            total_iou += float(iou)
            exact_loss += exact_losses
            cnt += 1
            pbar.set_description(f"Testing: Average Loss: {total_loss / cnt:6f} | Average IoU: {total_iou / cnt:6f}")
    # max(..., 1) avoids 0/0 NaNs when the loader yields no batches.
    exact_loss = exact_loss / max(cnt, 1)
    loss_type_dict = {
        'total_class_loss': exact_loss[0],
        'total_bbox_loss': exact_loss[1],
        'total_no_label_loss': exact_loss[2],
        'total_mask_loss': exact_loss[3],
        'total_size_loss': exact_loss[4],
    }
    print(loss_type_dict)

In [None]:
def test_map():
    """Compute, print and return bbox mAP of the model on ``test_loader``.

    Per image, predictions are post-processed in three steps:
    1. Drop queries whose argmax is the "no object" class.
    2. Drop boxes whose area after scaling by 320 is <= 4 px^2.
    3. Apply class-agnostic NMS at IoU 0.5.
    Predicted and target boxes are both scaled by 320 before being passed to
    torchmetrics.

    Returns:
        The dict produced by ``MeanAveragePrecision.compute()``.
    """
    model.eval()
    # A fresh metric per call, so the result covers exactly one pass over the loader.
    map_metric = MeanAveragePrecision(box_format="xyxy", iou_type="bbox").to(device)
    with torch.no_grad():
        for image, targets in tqdm(test_loader):
            image = image.to(device)
            all_labels = [t["labels"].to(device) for t in targets]
            all_boxes = [t["boxes"].to(device) for t in targets]

            class_pred, box_pred, _, _ = model(image)

            # AP Metrics Computation
            for i in range(image.shape[0]):
                logits = class_pred[i]
                boxes_i = box_pred[i]
                pred_labels = torch.argmax(logits, dim=-1)
                has_labels = pred_labels != no_labels_target_id
                # Clamp each side at 0. Otherwise an inverted box (x2 < x1 and y2 < y1)
                # gets a positive width*height and slips past the size filter.
                widths = (boxes_i[:, 2] - boxes_i[:, 0]).clamp(min=0)
                heights = (boxes_i[:, 3] - boxes_i[:, 1]).clamp(min=0)
                box_area = widths * heights * 320 * 320
                final_mask = torch.logical_and(has_labels, box_area > 4)
                boxes = boxes_i[final_mask] * 320
                scores = logits.softmax(dim=-1).max(dim=-1).values[final_mask]
                labels = pred_labels[final_mask]
                # NOTE(review): class-agnostic NMS; ops.batched_nms would keep
                # overlapping boxes of different classes.
                idx = ops.nms(boxes, scores, 0.5)
                preds = [{
                    "boxes": boxes[idx],
                    "scores": scores[idx],
                    "labels": labels[idx],
                }]
                target = [{
                    "boxes": all_boxes[i] * 320,
                    "labels": all_labels[i],
                }]
                map_metric.update(preds, target)
    results = map_metric.compute()
    print(results)
    return results

In [None]:
# 10 epochs of train -> test loss -> test-set mAP, keeping each epoch's mAP dict.
# Training and evaluation share the same (class, L1 box, GIoU) weights.
# NOTE(review): `scheduler` is never stepped here, so learning rates stay constant.
loss_weights = [2, 6, 0]
all_results = []
for i in range(10):
    train(loss_weights)
    test(loss_weights)
    results = test_map()
    all_results.append(results)

In [None]:
# Training-set mAP (printed), e.g. to compare against the per-epoch test-set mAP.
train_map()

In [None]:
# test_map()

In [None]:
# import gc
# gc.collect()
# torch.cuda.empty_cache()

In [None]:
# cProfile.run("train()", sort='tottime')

# Visualise

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Run the model on the first test batch only, for inspection and plotting below.
with torch.no_grad():
    for image, targets in test_loader:  # swap in train_loader to inspect training images
        image = image.to(device)
        all_labels = [t["labels"].to(device) for t in targets]
        all_boxes = [t["boxes"].to(device) for t in targets]
        all_masks = [t["mask"].to(torch.long).to(device) for t in targets]

        class_pred, box_pred, _, _ = model(image)
        break

In [None]:
# Ground-truth class ids of the first image in the batch.
all_labels[0]

In [None]:
# class_pred[0].shape

In [None]:
# Predicted class id for each query of the first image (20 = "no object").
torch.argmax(class_pred[0], dim=-1)

In [None]:
# Shape of the first image's predicted boxes: (num_queries, 4).
box_pred[0].size()

In [None]:
# Scale predicted boxes to 320-px image coordinates for plotting. box_pred is
# [batch, num_queries, 4], so the coordinate columns must be selected on the LAST
# dim. Indexing dim 1 instead would pick query rows and leave most boxes unscaled.
# NOTE: in-place and not idempotent -- re-running this cell rescales again.
box_pred[..., [0, 2]] *= 320
box_pred[..., [1, 3]] *= 320

In [None]:
# First image's predicted boxes, rounded, as a NumPy array.
box_pred[0].cpu().round().numpy()

In [None]:
# Shape of the batched box predictions: (batch, num_queries, 4).
box_pred.size()

In [None]:
# Ground-truth boxes of the first image, in their stored (unscaled) coordinates.
targets[0]['boxes'].cpu().numpy()

In [None]:
# Class-index -> name for the 20 VOC classes, plus index 20 for the "no object" slot.
_voc_class_names = [
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
    'background',
]
idx2labels = dict(enumerate(_voc_class_names))

In [None]:
import matplotlib.patches as patches


def _draw_labelled_boxes(ax, boxes_xyxy, names, color):
    """Draw each xyxy box as an unfilled rectangle with its name just above the top-left corner."""
    for (x0, y0, x1, y1), name in zip(boxes_xyxy, names):
        ax.add_patch(patches.Rectangle(
            (x0, y0), x1 - x0, y1 - y0,
            linewidth=2, edgecolor=color, facecolor='none'
        ))
        ax.text(x0, y0 - 5, name,
                bbox=dict(facecolor=color, alpha=0.5, pad=0),
                color='white', fontsize=8)


plt.figure(figsize=(12, 8))
plt.imshow(image[0].cpu().permute(1, 2, 0))
ax = plt.gca()

# Predictions (red): every query of the first image, labelled with its argmax class
# (including "background"). Boxes are plotted as-is, so they are expected to be in
# pixels already (the rescaling cell above modifies box_pred in place).
# Example: keep only top-K or score-filtered indices
pred_names = [idx2labels[c] for c in class_pred[0].argmax(-1).detach().cpu().tolist()]
_draw_labelled_boxes(ax, box_pred[0].cpu().numpy(), pred_names, 'red')

# Ground truth (blue): target boxes are stored unscaled, so multiply by 320 here.
gt_names = [idx2labels[c] for c in all_labels[0].detach().cpu().tolist()]
_draw_labelled_boxes(ax, targets[0]['boxes'].cpu().numpy() * 320, gt_names, 'blue')

plt.axis('off')   # optional: hide axes
plt.show()