# **A better Detection Head**
* This notebook follows the video "A Better Detection Head | Essentials of Object Detection".<br>Reference: https://www.youtube.com/watch?v=szu30l7wQ_Y&list=PLivJwLo9VCUJXdO8SiOjZTWr_fXrAy4OQ&index=6
* Extended by **Vigyannveshi** 

In [1]:
from __future__ import annotations

import torch as tr
import torch.nn as nn

from typing import NamedTuple
import einops
from einops.layers.torch import Rearrange

**Creating classes to simulate the backbone**

In [2]:
class FakeBackboneResult(NamedTuple):
    """Bundle of the three feature maps returned by ``FakeBackbone``.

    Being a ``NamedTuple``, the result can be read by field name or
    unpacked positionally in the order ``hl``, ``ml``, ``ll``.
    """

    # Filled by FakeBackbone with a 512-channel, 13x13 map.
    hl_features: tr.Tensor
    # Filled by FakeBackbone with a 256-channel, 26x26 map.
    ml_features: tr.Tensor
    # Filled by FakeBackbone with a 128-channel, 52x52 map.
    ll_features: tr.Tensor

In [3]:
class FakeBackbone(nn.Module):
    """Stand-in backbone that returns random feature maps at three scales.

    No computation is performed on the input image; the module only exists
    so that detection heads can be exercised with realistically shaped
    tensors.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: tr.Tensor) -> FakeBackboneResult:
        """Return freshly sampled random feature maps.

        Args:
            x: Input tensor. Only its batch size (when it is 4-D) and its
                device are used; its values are ignored.

        Returns:
            A ``FakeBackboneResult`` holding tensors of shape
            ``(B, 512, 13, 13)``, ``(B, 256, 26, 26)`` and
            ``(B, 128, 52, 52)``, where ``B`` is the batch size of ``x``
            for a 4-D input and 1 otherwise.
        """
        # Follow the batch size of a batched (4-D) input; fall back to the
        # previously hard-coded batch size of 1 for any other input rank.
        batch_size = x.shape[0] if x.dim() == 4 else 1
        device = x.device

        # Call randn on the torch module directly instead of going through
        # the accidental ``torch.torch`` self-reference.
        hl_fm = tr.randn(size=(batch_size, 512, 13, 13), device=device)
        ml_fm = tr.randn(size=(batch_size, 256, 26, 26), device=device)
        ll_fm = tr.randn(size=(batch_size, 128, 52, 52), device=device)

        return FakeBackboneResult(
            hl_features=hl_fm,
            ml_features=ml_fm,
            ll_features=ll_fm,
        )


In [4]:
# Instantiate the fake backbone and run it once on a random 1x3x416x416 input;
# `backbone_output` is reused by the cells below.
backbone=FakeBackbone()
backbone_output=backbone(tr.rand(size=(1,3,416,416)))

In [5]:
# Display the shape of the `hl_features` map.
backbone_output.hl_features.shape

torch.Size([1, 512, 13, 13])

In [6]:
# Display the shape of the `ml_features` map.
backbone_output.ml_features.shape

torch.Size([1, 256, 26, 26])

In [7]:
# Display the shape of the `ll_features` map.
backbone_output.ll_features.shape

torch.Size([1, 128, 52, 52])

**Splitting the Detection Head into BoxHead, ObjectnessHead and ClassificationHead (combined by DetectionHead)**
* Every cell can predict at most `num_boxes_per_cell` objects

In [8]:
class BoxHead(nn.Module):
    """Head that predicts 4 box values per box for every feature-map cell.

    A 1x1 convolution maps the input to ``num_boxes_per_cell * 4`` channels,
    which are then regrouped so that the 4 values of each box lie on the
    last axis.
    """

    def __init__(self, in_channels: int, num_boxes_per_cell: int):
        """Build the 1x1 convolution and the output rearrangement.

        Args:
            in_channels: Number of channels of the incoming feature map.
            num_boxes_per_cell: Number of boxes predicted by each cell.
        """
        super().__init__()

        # Four output channels for each box a cell predicts.
        out_channels = 4 * num_boxes_per_cell
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)

        # (B, boxes * 4, H, W) -> (B, boxes, H, W, 4)
        self.pred_reshape = Rearrange(
            "batchsize (num_boxes_per_cell prediction_per_cell) height width -> batchsize num_boxes_per_cell height width prediction_per_cell",
            num_boxes_per_cell=num_boxes_per_cell,
        )

    def forward(self, x: tr.Tensor) -> tr.Tensor:
        """Return box predictions of shape ``(B, num_boxes_per_cell, H, W, 4)``."""
        return self.pred_reshape(self.conv(x))

In [9]:
# Apply a BoxHead (3 boxes per cell) to the 512-channel `hl_features` map.
hl_box_detector = BoxHead(in_channels=512,num_boxes_per_cell=3)
hl_box_detections=hl_box_detector(backbone_output.hl_features)
# Print the shape for the first batch element, then display the batched shape.
print(hl_box_detections[0].shape)
hl_box_detections.shape

torch.Size([3, 13, 13, 4])


torch.Size([1, 3, 13, 13, 4])

In [10]:
class ClassificationHead(nn.Module):
    """Head that predicts ``num_classes`` values per box for every cell.

    The values are the raw outputs of a 1x1 convolution; no activation is
    applied inside this module.
    """

    def __init__(self, in_channels: int, num_boxes_per_cell: int, num_classes: int):
        """Build the 1x1 convolution and the output rearrangement.

        Args:
            in_channels: Number of channels of the incoming feature map.
            num_boxes_per_cell: Number of boxes predicted by each cell.
            num_classes: Number of class values predicted for each box.
        """
        super().__init__()

        # One group of `num_classes` channels per box.
        total_channels = num_classes * num_boxes_per_cell
        self.conv = nn.Conv2d(in_channels, total_channels, kernel_size=1, stride=1)

        # (B, boxes * classes, H, W) -> (B, boxes, H, W, classes)
        pattern = "batchsize (num_boxes_per_cell predictions_per_cell) height width -> batchsize num_boxes_per_cell height width predictions_per_cell"
        self.pred_reshape = Rearrange(pattern, num_boxes_per_cell=num_boxes_per_cell)

    def forward(self, x: tr.Tensor) -> tr.Tensor:
        """Return predictions of shape ``(B, num_boxes_per_cell, H, W, num_classes)``."""
        features = self.conv(x)
        return self.pred_reshape(features)


In [11]:
# Apply a ClassificationHead (3 boxes per cell, 3 classes) to `hl_features`.
hl_class_detector = ClassificationHead(in_channels=512, num_boxes_per_cell=3, num_classes=3)

hl_class_detections = hl_class_detector(backbone_output.hl_features)
# Display the shape of the class predictions.
hl_class_detections.shape

torch.Size([1, 3, 13, 13, 3])

In [12]:
class ObjectnessHead(nn.Module):
    """Head that predicts a single value per box for every feature-map cell.

    A 1x1 convolution maps the input to ``num_boxes_per_cell`` channels, which
    are then regrouped so that each box's single value lies on a trailing
    axis of size 1. The value is the raw convolution output; no activation is
    applied inside this module.
    """

    def __init__(self, in_channels: int, num_boxes_per_cell: int):
        """Build the 1x1 convolution and the output rearrangement.

        Args:
            in_channels: Number of channels of the incoming feature map.
            num_boxes_per_cell: Number of boxes predicted by each cell.
        """
        super().__init__()
        # One output channel for each box a cell predicts.
        num_predicted_channels = num_boxes_per_cell * 1

        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=num_predicted_channels,
            kernel_size=1,
            stride=1,
        )

        # (B, boxes * 1, H, W) -> (B, boxes, H, W, 1)
        self.pred_reshape = Rearrange(
            "batchsize (num_boxes_per_cell predictions_per_cell) height width -> batchsize num_boxes_per_cell height width predictions_per_cell",
            num_boxes_per_cell=num_boxes_per_cell,
        )

    def forward(self, x: tr.Tensor) -> tr.Tensor:
        """Return predictions of shape ``(B, num_boxes_per_cell, H, W, 1)``."""
        x = self.conv(x)
        x = self.pred_reshape(x)
        return x


In [13]:
# Apply an ObjectnessHead (3 boxes per cell) to the 512-channel `hl_features` map.
hl_obj_detector = ObjectnessHead(in_channels=512, num_boxes_per_cell=3)

hl_obj_detections = hl_obj_detector(backbone_output.hl_features)
# Display the shape of the objectness predictions.
hl_obj_detections.shape

torch.Size([1, 3, 13, 13, 1])

In [14]:
class DetectionResult(NamedTuple):
    """Bundle of the three prediction tensors returned by ``DetectionHead``.

    Being a ``NamedTuple``, the result can be read by field name or
    unpacked positionally in the order box, objectness, class.
    """

    # Output of the box sub-head.
    predictions_box: tr.Tensor
    # Output of the objectness sub-head.
    predictions_obj: tr.Tensor
    # Output of the classification sub-head.
    predictions_cls: tr.Tensor

In [17]:
class DetectionHead(nn.Module):
    """Detection head composed of separate box, objectness and class heads.

    All three sub-heads receive the same feature map, and their outputs are
    returned together as a ``DetectionResult``.
    """

    def __init__(self, in_channels: int, num_boxes_per_cell: int, num_classes: int):
        """Create the three sub-heads.

        Args:
            in_channels: Number of channels of the incoming feature map.
            num_boxes_per_cell: Number of boxes predicted by each cell.
            num_classes: Number of class values predicted for each box.
        """
        super().__init__()

        self.box_head = BoxHead(
            in_channels=in_channels, num_boxes_per_cell=num_boxes_per_cell
        )
        self.obj_head = ObjectnessHead(
            in_channels=in_channels, num_boxes_per_cell=num_boxes_per_cell
        )
        self.cls_head = ClassificationHead(
            in_channels=in_channels,
            num_boxes_per_cell=num_boxes_per_cell,
            num_classes=num_classes,
        )

    def forward(self, x: tr.Tensor) -> DetectionResult:
        """Run every sub-head on ``x`` and bundle the outputs.

        Args:
            x: Feature map with ``in_channels`` channels.

        Returns:
            A ``DetectionResult`` with the box, objectness and class
            predictions, in that order.
        """
        predictions_box = self.box_head(x)
        predictions_obj = self.obj_head(x)
        predictions_cls = self.cls_head(x)
        return DetectionResult(predictions_box, predictions_obj, predictions_cls)

In [18]:
# Apply the combined DetectionHead (3 boxes per cell, 3 classes) to `hl_features`.
hl_detector = DetectionHead(in_channels=512, num_boxes_per_cell=3, num_classes=3)

# The result is a NamedTuple, so it unpacks in field order: box, objectness, class.
hl_box_detections, hl_obj_detections, hl_class_detections = hl_detector(backbone_output.hl_features)

# Display the shapes of the three prediction tensors.
hl_box_detections.shape, hl_obj_detections.shape, hl_class_detections.shape

(torch.Size([1, 3, 13, 13, 4]),
 torch.Size([1, 3, 13, 13, 1]),
 torch.Size([1, 3, 13, 13, 3]))