# **Detection Head**
* The content follows the video "Detection Head | Essentials of Object Detection".<br>Reference: https://www.youtube.com/watch?v=U6rpkdVm21E&list=PLivJwLo9VCUJXdO8SiOjZTWr_fXrAy4OQ&index=5
* Extended by **Vigyannveshi** 

In [1]:
import torch as tr
import torch.nn as nn

from typing import NamedTuple
import einops

**Creating classes to simulate a backbone**

In [2]:
class FakeBackboneResult(NamedTuple):
    """Container for the three feature maps returned by ``FakeBackbone``.

    Fields are ordered from the coarsest (high-level) to the finest
    (low-level) feature map.
    """

    # Coarsest map; (1, 512, 13, 13) in this notebook.
    hl_features: tr.Tensor
    # Medium-resolution map; (1, 256, 26, 26) in this notebook.
    ml_features: tr.Tensor
    # Finest map; (1, 128, 52, 52) in this notebook.
    ll_features: tr.Tensor

In [3]:
class FakeBackbone(nn.Module):
    """Stand-in backbone that returns random multi-scale feature maps.

    No real computation is performed: each feature map is a tensor of
    standard-normal noise whose shape mimics a backbone with output strides
    32, 16 and 8. The batch size, spatial size and device are taken from the
    input image batch, so a (1, 3, 416, 416) input yields maps of shape
    (1, 512, 13, 13), (1, 256, 26, 26) and (1, 128, 52, 52).
    """

    # (channels, output stride) for the high-, mid- and low-level maps, in
    # that order. The order also fixes the order of the random draws.
    _LEVELS = ((512, 32), (256, 16), (128, 8))

    def __init__(self):
        super().__init__()

    def forward(self, x: tr.Tensor) -> FakeBackboneResult:
        """Return random feature maps sized to match the image batch ``x``.

        Args:
            x: image batch, assumed to be laid out as (B, C, H, W).

        Returns:
            FakeBackboneResult with maps of shape (B, channels, H // stride,
            W // stride) for each entry of ``_LEVELS``.
        """
        batch_size, height, width = x.shape[0], x.shape[-2], x.shape[-1]
        # The generator is consumed in order: high-, then mid-, then low-level.
        hl_fm, ml_fm, ll_fm = (
            tr.randn(
                batch_size,
                channels,
                height // stride,
                width // stride,
                device=x.device,
            )
            for channels, stride in self._LEVELS
        )

        return FakeBackboneResult(
            hl_features=hl_fm,
            ml_features=ml_fm,
            ll_features=ll_fm,
        )


In [4]:
# Run the fake backbone on a random 3-channel 416x416 image batch of size 1.
backbone = FakeBackbone()
backbone_output = backbone(tr.rand(1, 3, 416, 416))

In [5]:
# Shape of the high-level (coarsest) feature map.
backbone_output.hl_features.size()

torch.Size([1, 512, 13, 13])

In [6]:
# Shape of the mid-level feature map.
backbone_output.ml_features.size()

torch.Size([1, 256, 26, 26])

In [7]:
# Shape of the low-level (finest) feature map.
backbone_output.ll_features.size()

torch.Size([1, 128, 52, 52])

**Creating a class to simulate a detection head**

In [8]:
class DetectionHead(nn.Module):
    """Detection head: a 1x1 convolution mapping features to per-box predictions.

    For every spatial position the head predicts ``num_boxes_per_cell``
    boxes. Each box is described by 4 box coordinates, 1 objectness score
    and ``num_classes`` class scores. The output therefore has
    ``num_boxes_per_cell * (4 + 1 + num_classes)`` channels and the same
    spatial size as the input.

    Args:
        in_channels: number of channels of the incoming feature map.
        num_boxes_per_cell: number of boxes predicted at each position.
        num_classes: number of object classes.
    """

    def __init__(self, in_channels: int, num_boxes_per_cell: int, num_classes: int):
        super().__init__()
        # Keep the prediction layout on the module so callers can split or
        # rearrange the output without hard-coding these numbers.
        self.num_boxes_per_cell = num_boxes_per_cell
        self.num_classes = num_classes
        self.num_predictions_per_box = 4 + 1 + num_classes
        num_predicted_channels = num_boxes_per_cell * self.num_predictions_per_box

        # kernel_size=1, stride=1 and no padding: a per-position linear
        # projection, so height and width are preserved.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=num_predicted_channels,
            kernel_size=1,
            stride=1,
        )

    def forward(self, x: tr.Tensor) -> tr.Tensor:
        """Return raw (un-activated) predictions for feature map ``x``.

        For an input of shape (B, in_channels, H, W) the result has shape
        (B, num_boxes_per_cell * (5 + num_classes), H, W).
        """
        return self.conv(x)

In [9]:
# High-level head: 3 boxes per cell x (4 coords + 1 objectness + 3 classes) = 24 channels.
hl_detector = DetectionHead(in_channels=512, num_boxes_per_cell=3, num_classes=3)
# Call the module itself (not .forward) so any registered hooks run.
hl_detections = hl_detector(backbone_output.hl_features)
print(hl_detections.shape)

torch.Size([1, 24, 13, 13])


In [10]:
# Mid-level head: same prediction layout, applied to the 256-channel map.
ml_detector = DetectionHead(in_channels=256, num_boxes_per_cell=3, num_classes=3)
# Call the module itself (not .forward) so any registered hooks run.
ml_detections = ml_detector(backbone_output.ml_features)
print(ml_detections.shape)

torch.Size([1, 24, 26, 26])


In [11]:
# Low-level head: same prediction layout, applied to the 128-channel map.
ll_detector = DetectionHead(in_channels=128, num_boxes_per_cell=3, num_classes=3)
# Call the module itself (not .forward) so any registered hooks run.
ll_detections = ll_detector(backbone_output.ll_features)
print(ll_detections.shape)

torch.Size([1, 24, 52, 52])


In [12]:
# Split the channel axis into (anchor, per-box prediction) and move the
# prediction vector last: (B, A*P, H, W) -> (B, A, H, W, P).
hl_detections_for_training = einops.rearrange(
    hl_detections,
    "batchsize (num_anchors_per_cell predictions_per_object) height width"
    " -> batchsize num_anchors_per_cell height width predictions_per_object",
    num_anchors_per_cell=3,
)
hl_detections_for_training.shape

torch.Size([1, 3, 13, 13, 8])

In [13]:
# Predictions of each of the 3 anchors at grid row 1, column 5 of batch item 0
# (axes: batch, anchor, height, width, prediction).
pred_for_box0_at_cell_15, pred_for_box1_at_cell_15, pred_for_box2_at_cell_15 = (
    hl_detections_for_training[0, anchor, 1, 5] for anchor in range(3)
)

pred_for_box0_at_cell_15.shape

torch.Size([8])

In [14]:
# Layout of one prediction vector: 4 box coordinates | 1 objectness | class scores.
box_coordinates, box_objectness, box_classes = (
    pred_for_box0_at_cell_15[:4],
    pred_for_box0_at_cell_15[4],
    pred_for_box0_at_cell_15[5:],
)

box_coordinates, box_objectness, box_classes

(tensor([-1.0813,  0.8398,  0.9141,  1.3997], grad_fn=<SliceBackward0>),
 tensor(-0.3770, grad_fn=<SelectBackward0>),
 tensor([0.2549, 0.3565, 0.5067], grad_fn=<SliceBackward0>))

In [15]:
# Box coordinates for every anchor and cell: the first 4 prediction values.
pred_for_coordinates = hl_detections_for_training.narrow(-1, 0, 4)

pred_for_coordinates.shape

torch.Size([1, 3, 13, 13, 4])

In [16]:
# Objectness score for every anchor and cell: prediction value at index 4.
pred_for_objectness = hl_detections_for_training.select(-1, 4)
pred_for_objectness.shape

torch.Size([1, 3, 13, 13])

In [17]:
# Class scores for every anchor and cell: everything after index 4.
pred_for_classes = hl_detections_for_training.narrow(
    -1, 5, hl_detections_for_training.size(-1) - 5
)

pred_for_classes.shape

torch.Size([1, 3, 13, 13, 3])

In [18]:
# Flatten anchors and grid cells into one candidate-box axis:
# (B, A*P, H, W) -> (B, A*H*W, P); here 3 * 13 * 13 = 507 candidates.
# height/width are passed only so einops checks the grid size.
hl_detections_for_final_prediction = einops.rearrange(
    hl_detections,
    "batchsize (num_anchors_per_cell predictions_per_object) height width"
    " -> batchsize (num_anchors_per_cell height width) predictions_per_object",
    num_anchors_per_cell=3,
    height=13,
    width=13,
)

print(hl_detections_for_final_prediction)
hl_detections_for_final_prediction.shape

tensor([[[-0.6068,  0.3264, -0.0026,  ...,  0.6653,  0.5690, -0.6186],
         [ 0.3252,  0.5246, -0.8246,  ..., -0.9340, -0.8551, -0.2306],
         [-0.1303, -0.6322,  0.2768,  ..., -0.8081,  0.2712, -0.2689],
         ...,
         [-0.1936, -0.5268, -0.1613,  ..., -0.6694,  0.2115,  0.7215],
         [-0.4344,  0.0244, -0.2799,  ...,  0.8163, -0.1796, -0.5253],
         [ 0.8956,  0.3108,  0.0817,  ...,  0.0982, -0.1074, -0.8840]]],
       grad_fn=<UnsafeViewBackward0>)


torch.Size([1, 507, 8])