In [2]:
%load_ext autoreload
%autoreload 2

import torch
from torchinfo import summary
from src.architectures.feature_extractors.resnext import ResNext
from src.architectures.feature_extractors.resnet import ResNet


def count_params(model: torch.nn.Module) -> int:
    """Return the number of trainable parameters in ``model``.

    Only parameters with ``requires_grad=True`` are counted, so frozen layers are excluded.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Fixed seed so the random probe input (and anything derived from it) is reproducible
# under Restart & Run All.
torch.manual_seed(0)
# Probe batch in PyTorch's (N, C, H, W) layout: one 3-channel 224x224 image with values
# in [0, 1). Used only to check output shapes, not real data.
img = torch.rand(1, 3, 224, 224)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [57]:
# Build "resnext50_32x4d" twice with identical settings except for `load_from_torch`:
# once from the local implementation (False) and once through the torch-loading path
# (True), so the two can be compared structurally below. The leading 3 is presumably the
# number of input channels (matching `img`) — confirm against ResNext's signature.
resnext_kwargs = dict(pretrained=False, freeze_extractor=False)
net_local = ResNext(3, "resnext50_32x4d", load_from_torch=False, **resnext_kwargs)
net_torch = ResNext(3, "resnext50_32x4d", load_from_torch=True, **resnext_kwargs)


In [58]:
# Shape sanity check: both implementations should map the probe image to the same
# feature shape. eval() + no_grad() keep the forward pass from updating BatchNorm running
# statistics or recording an autograd graph that is never used.
net_local.eval()
net_torch.eval()
with torch.no_grad():
    out_local, out_torch = net_local(img), net_torch(img)
assert out_local.shape == out_torch.shape, f"shape mismatch: {out_local.shape} vs {out_torch.shape}"
out_local.shape, out_torch.shape


(torch.Size([1, 2048]), torch.Size([1, 2048]))

In [59]:
# Trainable-parameter counts must agree if the local ResNeXt mirrors the torch build;
# assert it rather than relying on eyeballing the displayed tuple.
params_local, params_torch = count_params(net_local), count_params(net_torch)
assert params_local == params_torch, f"parameter count mismatch: local={params_local:,} torch={params_torch:,}"
params_local, params_torch


(22979904, 22979904)

In [33]:
# Layer-by-layer torchinfo summary of the local ResNeXt. Off by default to keep the
# notebook output short; set the flag to True to render the table (torchinfo's result
# is shown as the cell's last expression).
SHOW_LOCAL_SUMMARY = False
summary(
    net_local, input_data=img, depth=4, col_names=["input_size", "output_size", "num_params", "kernel_size"]
) if SHOW_LOCAL_SUMMARY else None


In [34]:
# torchinfo summary of the torch-built ResNeXt. Uses depth=4, the same depth and columns
# as the local summary, so the two tables line up row for row (torchinfo's default depth
# differs). Off by default to keep the notebook output short.
SHOW_TORCH_SUMMARY = False
summary(
    net_torch, input_data=img, depth=4, col_names=["input_size", "output_size", "num_params", "kernel_size"]
) if SHOW_TORCH_SUMMARY else None


In [39]:
# Display one bottleneck block of the local ResNeXt for a side-by-side look with the
# torch block.
# NOTE(review): the saved output predates the model cell (In[39] < In[57]); its
# groups=64 does not match "resnext50_32x4d" — re-run to refresh. Also confirm the index:
# this block shows stride 1 and no downsample, whereas torch `layer2[0]` has stride 2 and
# a downsample branch, so they are not the same position in the network.
local_stage = net_local.net[1].net.stage_1
local_stage[1]


BottleneckBlock(
  (conv1): CNNBlock(
    (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (batch_norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation_fn): ReLU()
  )
  (conv2): CNNBlock(
    (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
    (batch_norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation_fn): ReLU()
  )
  (conv3): CNNBlock(
    (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (batch_norm): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (relu): ReLU()
)

In [38]:
# Display the first bottleneck of `layer2` in the torch-built ResNeXt as the reference
# block.
# NOTE(review): the saved output predates the model cell (In[38] < In[57]); groups=64
# does not match the 32 groups implied by "resnext50_32x4d" — re-run to refresh.
torch_stage = net_torch.net[0].layer2
torch_stage[0]


Bottleneck(
  (conv1): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64, bias=False)
  (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (downsample): Sequential(
    (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [56]:
# Same parameter-parity check for ResNet-34. Distinct names keep `net_local`/`net_torch`
# bound to the ResNeXt models, so re-running the earlier cells after this one still
# inspects ResNeXt rather than silently switching to ResNet.
resnet_local = ResNet(3, "resnet34", load_from_torch=False, pretrained=False, freeze_extractor=False)
resnet_torch = ResNet(3, "resnet34", load_from_torch=True, pretrained=False, freeze_extractor=False)
resnet_params = count_params(resnet_local), count_params(resnet_torch)
assert resnet_params[0] == resnet_params[1], f"parameter count mismatch (local, torch): {resnet_params}"
resnet_params


(21284672, 21284672)