In [1]:
from model import resnet_block, ConvBnAct, ResNet18
from utils import update_model_wt, set_seed
from collections import OrderedDict
from fastcore.utils import noop

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
from torch import nn
from torchvision.models import resnet18

In [3]:
# Seed the RNGs before any random tensors or model weights are created.
# NOTE(review): set_seed comes from the local utils module; which libraries it
# seeds and the default seed value are not visible here — confirm.
set_seed()

In [4]:
# Dummy input of shape (5, 3, 224, 224), drawn from a standard normal;
# presumably laid out as [batch, channels, height, width] — confirm against the model.
inp = torch.randn(5, 3, 224, 224)

In [5]:
# At this point ResNet18 is the class imported from the local model module
# (the notebook redefines the name further down), built with 10 output classes.
# NOTE(review): the variable name suggests GeM pooling in that version — confirm in model.py.
model_gem = ResNet18(n_cls=10)

In [6]:
# Forward pass on the dummy batch; the cell displays the raw output tensor
# (gradient tracking is on, hence the grad_fn in the repr).
model_gem(inp)

tensor([[-0.0732, -1.3212,  0.1588, -0.3880,  0.6878,  1.7276, -2.0731,  0.5310,
          0.2525,  0.1851],
        [-0.2888, -1.4897,  0.1330, -0.2777,  0.8336,  1.5032, -2.1334,  0.3064,
          0.5680,  0.3668],
        [-0.3781, -1.5451,  0.1156, -0.2519,  0.6058,  1.6319, -2.0664,  0.3481,
          0.4720,  0.1412],
        [-0.0925, -1.3294,  0.1398, -0.3579,  0.7980,  1.4543, -1.9582,  0.6035,
          0.4494,  0.2037],
        [-0.4258, -1.4692,  0.0358, -0.3029,  0.7881,  1.4945, -1.8997,  0.6500,
          0.3332,  0.2942]], grad_fn=<AddmmBackward0>)

In [7]:
# Sanity check: one row per input sample and one column per class is expected.
# Note this runs a second full forward pass just to read the shape.
model_gem(inp).shape

torch.Size([5, 10])

Redefine the `ResNet18` class with a modified pooling layer (`nn.AdaptiveAvgPool2d`).

In [8]:
class ResNet18(nn.Module):
    """ResNet-18 style classifier whose pooling layer is ``nn.AdaptiveAvgPool2d``.

    The network is a stem (7x7 conv block + 3x3 max-pool), four residual stages
    built by ``resnet_block``, an adaptive average pool to a 1x1 map, a flatten
    and a final linear layer. Because the pool always produces a 1x1 map, the
    linear head does not depend on the spatial size of the input.

    NOTE(review): this definition shadows the ``ResNet18`` imported from the
    local ``model`` module for every cell executed after it.

    Args:
        n_cls: number of output classes, i.e. the width of the final linear layer.

    The shape comments in ``__init__`` assume an input of [B, 3, 32, 32].
    """

    def __init__(self, n_cls=2):
        # Zero-argument super() does not look the class up by name, so it keeps
        # working when this notebook cell is re-run and ``ResNet18`` is rebound.
        super().__init__()
        params = dict(in_ch=3, out_ch=64, k=7, s=2, p=3)  # GoogLeNet-style stem conv
        # x = [B, 3, 32, 32]
        # Stem: the stride-2 conv halves 32 -> 16, the stride-2 max-pool halves 16 -> 8.
        self.l1 = nn.Sequential(ConvBnAct(**params),
                                nn.MaxPool2d(kernel_size=3, stride=2, padding=1))  # -> [B, 64, 8, 8]
        # Stage shapes below are the notebook's original annotations; resnet_block
        # lives in the local model module and is not visible here — TODO confirm.
        self.l2 = nn.Sequential(OrderedDict(resnet_block(64, 64, 2, first_block=True)))  # -> [B, 64, 8, 8]
        self.l3 = nn.Sequential(OrderedDict(resnet_block(64, 128, 2)))  # -> [B, 128, 4, 4]
        self.l4 = nn.Sequential(OrderedDict(resnet_block(128, 256, 2)))  # -> [B, 256, 2, 2]
        self.l5 = nn.Sequential(OrderedDict(resnet_block(256, 512, 2)))  # -> [B, 512, 1, 1]
        self.pool = nn.AdaptiveAvgPool2d((1, 1))  # -> [B, 512, 1, 1] for any spatial size
        self.flat = nn.Flatten()  # -> [B, 512]
        self.fc = nn.Linear(512, n_cls, bias=True)  # -> [B, n_cls]

    def forward(self, x):
        """Run the stem and the four stages, then pool, flatten and classify.

        Args:
            x: input batch; the stem expects 3 channels (``in_ch=3``).

        Returns:
            The output of the final linear layer, one row per sample and
            ``n_cls`` columns.
        """
        x = self.l5(self.l4(self.l3(self.l2(self.l1(x)))))
        return self.fc(self.flat(self.pool(x)))

In [9]:
# Uses the ResNet18 redefined in this notebook (adaptive average pooling),
# not the one imported from the local model module; 10 output classes.
model_adaptive = ResNet18(n_cls=10)

In [10]:
# Sanity check that the redefined model accepts the 224x224 dummy batch and
# returns one row per sample and one column per class.
model_adaptive(inp).shape

torch.Size([5, 10])

In [11]:
# Display the module repr to inspect the layer structure of the redefined model.
model_adaptive

ResNet18(
  (l1): Sequential(
    (0): ConvBnAct(
      (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act_fn): ReLU(inplace=True)
    )
    (1): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (l2): Sequential(
    (reg_blk0): BasicResBlock(
      (conv1): ConvBnAct(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act_fn): ReLU(inplace=True)
      )
      (conv2): ConvBnAct(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (reg_blk1): BasicResBlock(
      (conv1): ConvBnAct(
        (conv): Conv2d(64, 64, kernel_size=