<a href="https://colab.research.google.com/github/raynardj/python4ml/blob/main/experiments/disect_conv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# How to fine-tune only a part of an object detection head

In [59]:
import torch
from torch import nn

## The usual last layer of an object detection top

Assume this is the activation before the top layer, with shape batch size 16 x channel size 32 x 7 x 7, where 7 is both the grid height and the grid width.

In [60]:
x = torch.rand(16, 32, 7, 7)

Assume this is the top layer (the last layer):

In [61]:
last_layer = nn.Conv2d(32, 105, 3, padding=1)

In [62]:
with torch.no_grad():
    # forward pass
    y_ = last_layer(x)

It works like a single dense layer applied at each grid cell (taking that cell's 3 x 3 neighborhood as input), transforming the 32 channels into 105.

In [63]:
y_.shape

torch.Size([16, 105, 7, 7])
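
The dense-layer analogy can be checked directly. The following is a minimal sketch, not part of the original notebook: it assumes `torch.nn.functional.unfold` is used to extract each cell's 3 x 3 neighborhood, and the names `patches`, `w` and `y_dense` are introduced here only for illustration.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(16, 32, 7, 7)
conv = nn.Conv2d(32, 105, 3, padding=1)

with torch.no_grad():
    y_conv = conv(x)

    # Each column of `patches` is the flattened 3x3 neighborhood of one cell:
    # shape (16, 32 * 3 * 3, 7 * 7) = (16, 288, 49)
    patches = F.unfold(x, kernel_size=3, padding=1)

    # The conv kernel viewed as a dense weight matrix: (105, 288)
    w = conv.weight.view(105, -1)

    # The same dense layer applied to every cell, then reshaped back to the grid
    y_dense = (w @ patches + conv.bias.view(1, -1, 1)).view(16, 105, 7, 7)

# Compare with a small tolerance, since the two paths sum in a different order
same = torch.allclose(y_conv, y_dense, atol=1e-4)
print(y_dense.shape, same)
```

The comparison uses a tolerance rather than exact equality because the two computations may accumulate floating point error differently.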

For this part of the model, the weights have the shape shown below. Let's save them to the numpy variable `pretrain_weights`, and the bias to `pretrain_bias`.

In [64]:
pretrain_weights = last_layer.weight.data.numpy()
pretrain_weights.shape

(105, 32, 3, 3)

In [65]:
pretrain_bias = last_layer.bias.data.numpy()
pretrain_bias.shape

(105,)
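
One thing to watch for if `last_layer` is trained further afterwards: for a CPU tensor, `.numpy()` returns an array that shares memory with the tensor, so the saved array follows later in-place updates. The sketch below illustrates this under that assumption; the names `shared` and `snapshot` are hypothetical and not part of the original notebook.

```python
import torch
from torch import nn

torch.manual_seed(0)
last_layer = nn.Conv2d(32, 105, 3, padding=1)

# Same call as in the notebook: the array shares memory with the weight tensor
shared = last_layer.weight.data.numpy()

# An independent copy: clone first, then convert
snapshot = last_layer.weight.detach().clone().numpy()

# Simulate a later in-place update of the layer (e.g. an optimizer step)
with torch.no_grad():
    last_layer.weight.add_(1.0)

current = last_layer.weight.detach()
shared_follows = torch.equal(torch.from_numpy(shared), current)
snapshot_follows = torch.equal(torch.from_numpy(snapshot), current)
print(shared_follows, snapshot_follows)
```

In this notebook the layer is never updated after the weights are saved, so either form gives the same values.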

## Segregated weights

Let's break down the top layer into 3 conv modules, to see whether they can perform the same as the original.

In [66]:
conv_logit=nn.Conv2d(32, 1, 3, padding=1)
conv_bbox=nn.Conv2d(32, 4, 3, padding=1)
conv_classes=nn.Conv2d(32, 100, 3, padding=1)

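Now we load the pretrained weights of the single layer into the 3 conv modules: output channel 0 goes to `conv_logit`, channels 1 to 4 to `conv_bbox`, and the remaining 100 channels to `conv_classes`.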

In [69]:
conv_logit.weight.data = torch.Tensor(pretrain_weights[:1,...])
conv_bbox.weight.data = torch.Tensor(pretrain_weights[1:5,...])
conv_classes.weight.data = torch.Tensor(pretrain_weights[5:,...])

conv_logit.bias.data = torch.Tensor(pretrain_bias[:1])
conv_bbox.bias.data = torch.Tensor(pretrain_bias[1:5])
conv_classes.bias.data = torch.Tensor(pretrain_bias[5:])
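
The slicing above can be wrapped in a small helper. This is a minimal sketch only: `split_conv` is a hypothetical function, not part of the notebook or of PyTorch, and it assumes the fused layer has a bias and uses the default stride, dilation and groups.

```python
import torch
from torch import nn


def split_conv(fused, sizes):
    """Split a Conv2d along its output channels into several Conv2d modules.

    Hypothetical helper; assumes `fused` has a bias and default
    stride, dilation and groups.
    """
    assert sum(sizes) == fused.out_channels
    w_chunks = torch.split(fused.weight.detach(), sizes, dim=0)
    b_chunks = torch.split(fused.bias.detach(), sizes, dim=0)
    parts = []
    for size, w, b in zip(sizes, w_chunks, b_chunks):
        conv = nn.Conv2d(fused.in_channels, size, fused.kernel_size,
                         padding=fused.padding)
        with torch.no_grad():
            # Copy the matching slice of pretrained weights and bias
            conv.weight.copy_(w)
            conv.bias.copy_(b)
        parts.append(conv)
    return parts


torch.manual_seed(0)
last_layer = nn.Conv2d(32, 105, 3, padding=1)
conv_logit, conv_bbox, conv_classes = split_conv(last_layer, [1, 4, 100])

x = torch.rand(16, 32, 7, 7)
with torch.no_grad():
    y_fused = last_layer(x)
    y_split = torch.cat([conv_logit(x), conv_bbox(x), conv_classes(x)], dim=1)

matches = torch.allclose(y_fused, y_split, atol=1e-5)
print(matches)
```

Using `copy_` under `torch.no_grad()` keeps each module's own parameter objects, which avoids reassigning `.data`.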

In [71]:
with torch.no_grad():
    y_1 = conv_logit(x)
    y_2 = conv_bbox(x)
    y_3 = conv_classes(x)

In [72]:
y_1.shape, y_2.shape, y_3.shape

(torch.Size([16, 1, 7, 7]),
 torch.Size([16, 4, 7, 7]),
 torch.Size([16, 100, 7, 7]))

In [73]:
y_combined = torch.cat([y_1, y_2, y_3], dim=1)
y_combined.shape

torch.Size([16, 105, 7, 7])

The 3 modules reproduce the output of the original layer exactly. In an application, this means we can train the 3 conv parts separately, or train only 1 of them.
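
As an illustration of training only 1 part, here is a minimal sketch that freezes `conv_logit` and `conv_bbox` and updates only `conv_classes`. The optimizer, learning rate and loss are placeholders chosen for this example, not taken from the notebook; a real detector would use its own detection loss.

```python
import torch
from torch import nn

torch.manual_seed(0)
conv_logit = nn.Conv2d(32, 1, 3, padding=1)
conv_bbox = nn.Conv2d(32, 4, 3, padding=1)
conv_classes = nn.Conv2d(32, 100, 3, padding=1)

# Freeze the parts we do not want to fine-tune
for module in (conv_logit, conv_bbox):
    for p in module.parameters():
        p.requires_grad_(False)

# Only the class branch is handed to the optimizer (placeholder settings)
optimizer = torch.optim.SGD(conv_classes.parameters(), lr=0.1)

x = torch.rand(16, 32, 7, 7)
bbox_before = conv_bbox.weight.clone()
classes_before = conv_classes.weight.clone()

y = torch.cat([conv_logit(x), conv_bbox(x), conv_classes(x)], dim=1)
loss = y.pow(2).mean()  # placeholder loss, for illustration only

optimizer.zero_grad()
loss.backward()
optimizer.step()

bbox_unchanged = torch.equal(conv_bbox.weight, bbox_before)
classes_changed = not torch.equal(conv_classes.weight, classes_before)
print(bbox_unchanged, classes_changed)
```

Freezing with `requires_grad_(False)` also skips gradient computation for those parameters, while leaving them out of the optimizer alone would still compute their gradients.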

In [75]:
y_==y_combined

tensor([[[[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True]],

         [[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True]],

         [[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ...