In [2]:
%load_ext autoreload
%autoreload 2

import torch
from torchinfo import summary
from src.architectures.feature_extractors.resnet import ResNet

count_params = lambda model: sum(p.numel() for p in model.parameters() if p.requires_grad)

img = torch.rand(1, 3, 224, 224)

# summary(net, input_data=img)

In [59]:
net_local = ResNet(3, "resnet152", use_global_pool=True, load_from_torch=False, pretrained=False, freeze_extractor=False)
net_torch = ResNet(3, "resnet152", use_global_pool=True, load_from_torch=True, pretrained=False, freeze_extractor=False)


In [60]:
net_local(img).shape, net_torch(img).shape

(torch.Size([1, 2048]), torch.Size([1, 2048]))

In [61]:
count_params(net_local), count_params(net_torch)

(58143808, 58143808)

In [62]:
summary(net_local, input_data=img, depth=4, col_names=["input_size", "output_size", "num_params", "kernel_size"])

Layer (type:depth-idx)                                  Input Shape               Output Shape              Param #                   Kernel Shape
ResNet152                                               [1, 3, 224, 224]          [1, 2048]                 --                        --
├─Sequential: 1-1                                       --                        --                        --                        --
│    └─CNNBlock: 2-1                                    [1, 3, 224, 224]          [1, 64, 55, 55]           --                        --
│    │    └─Conv2d: 3-1                                 [1, 3, 224, 224]          [1, 64, 112, 112]         9,408                     [7, 7]
│    │    └─BatchNorm2d: 3-2                            [1, 64, 112, 112]         [1, 64, 112, 112]         128                       --
│    │    └─ReLU: 3-3                                   [1, 64, 112, 112]         [1, 64, 112, 112]         --                        --
│    │    └─MaxPool2d: 3-4 

In [63]:
summary(net_torch, input_data=img, col_names=["input_size", "output_size", "num_params", "kernel_size"])

Layer (type:depth-idx)                             Input Shape               Output Shape              Param #                   Kernel Shape
ExternalFeatureExtractor                           [1, 3, 224, 224]          [1, 2048]                 --                        --
├─Sequential: 1-1                                  --                        --                        --                        --
│    └─ResNet: 2-1                                 [1, 3, 224, 224]          [1, 2048]                 --                        --
│    │    └─Conv2d: 3-1                            [1, 3, 224, 224]          [1, 64, 112, 112]         9,408                     [7, 7]
│    │    └─BatchNorm2d: 3-2                       [1, 64, 112, 112]         [1, 64, 112, 112]         128                       --
│    │    └─ReLU: 3-3                              [1, 64, 112, 112]         [1, 64, 112, 112]         --                        --
│    │    └─MaxPool2d: 3-4                         [1, 64, 112