In [1]:
import os
from pathlib import Path

import torch
from PIL import Image
from torchvision import transforms
from torchvision.models import efficientnet_v2_s, inception_v3, resnet50, wide_resnet101_2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Stock torchvision ResNet-50 built with its "DEFAULT" weights.
# NOTE(review): the name `model` is rebound to a Resnet50Extractor in a later
# cell, so which object `model` refers to depends on how far the notebook has run.
model = resnet50(weights="DEFAULT")

In [3]:
# Root of the Algonauts 2023 challenge data. The default is the location used
# when this notebook was written; set the ALGONAUTS_DATA_DIR environment
# variable to run the notebook on another machine without editing this cell.
DATA_DIR = Path(
    os.environ.get(
        "ALGONAUTS_DATA_DIR",
        "/home/vislab-001/Documents/algonauts_2023_challenge_data",
    )
)

# Single training image used as a smoke-test input for the extractors below.
SAMPLE_IMAGE_PATH = (
    DATA_DIR / "subj01" / "training_split" / "training_images" / "train-0013_nsd-00140.png"
)

# PNG files may be stored as RGBA, greyscale or palette images; forcing RGB
# guarantees the three input channels the backbones' first convolution expects.
img = Image.open(SAMPLE_IMAGE_PATH).convert("RGB")

In [4]:
# Per-channel statistics the torchvision ImageNet-pretrained weights were
# trained with; feeding raw [0, 1] pixels instead shifts every activation.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# PIL image -> normalised float tensor (C x H x W) ready for a pretrained backbone.
tsfms = transforms.Compose([
    # An int size rescales the shorter side to 384 and keeps the aspect ratio.
    transforms.Resize(384),
    # PIL image -> float tensor with values in [0, 1].
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [6]:
# Display the layer structure of the network currently bound to `model`.
# (The previous `model.l` was a half-typed attribute access that raises
# AttributeError on a fresh run; the recorded output is the repr of `model`.)
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [26]:
# Flatten everything after the batch dimension and show the resulting shape.
# NOTE(review): `finalConvLayer` and `extracted` are not defined in any cell
# visible here (this cell's execution count is 26), so it presumably relies on
# state from cells that were since deleted or edited -- confirm, then define
# them above or drop this cell so Restart & Run All works.
torch.flatten(finalConvLayer(extracted), 1).shape

torch.Size([1, 184320])

In [5]:
class EfficientNetExtractor(torch.nn.Module):
    """Convolutional feature extractor built on torchvision's EfficientNetV2-S.

    ``forward`` runs every stage of the backbone's ``features`` stack except the
    last one, then applies only the *first child* of that last stage; any
    further children of that stage, and the rest of the backbone, are never
    called. (The attribute name suggests that first child is the stage's
    convolution -- confirm against the torchvision model definition.)

    The module is switched to eval mode on construction so that layers whose
    behaviour differs between training and inference (e.g. batch norm, dropout)
    behave deterministically; call ``.train()`` explicitly before fine-tuning.

    Args:
        weights: Forwarded to ``efficientnet_v2_s``. Defaults to ``"DEFAULT"``
            (pretrained), in line with the other extractors in this notebook.
            Pass ``None`` to get a randomly initialised backbone.
    """

    def __init__(self, weights="DEFAULT"):
        super().__init__()
        self.efficientNet = efficientnet_v2_s(weights=weights)

        stages = self.efficientNet.features
        # Slicing keeps the same sub-module objects, so these attributes share
        # their parameters with ``self.efficientNet``.
        self.features = stages[:-1]
        self.finalConvLayer = next(stages[-1].children())

        # Feature extraction is inference: freeze batch-norm statistics etc.
        self.eval()

    def forward(self, img):
        """Return the output of ``finalConvLayer`` applied to the trunk features of ``img``."""
        return self.finalConvLayer(self.features(img))
    
class InceptionEXtractor(torch.nn.Module):
    """Spatial feature extractor built on torchvision's Inception v3.

    ``forward`` applies the backbone's input transform and then its stages from
    ``Conv2d_1a_3x3`` through ``Mixed_7c``, stopping before the average pool, so
    the result is a spatial feature map rather than class scores.

    The module is switched to eval mode on construction so that layers whose
    behaviour differs between training and inference (e.g. batch norm, dropout)
    behave deterministically; call ``.train()`` explicitly before fine-tuning.

    Args:
        weights: Forwarded to ``inception_v3``. Defaults to ``"DEFAULT"``
            (pretrained); pass ``None`` for a randomly initialised backbone.
    """

    # Backbone stages in execution order. The trailing shape is the tensor
    # *after* that stage for an N x 3 x 299 x 299 input.
    _STAGES = (
        "Conv2d_1a_3x3",  # N x 32 x 149 x 149
        "Conv2d_2a_3x3",  # N x 32 x 147 x 147
        "Conv2d_2b_3x3",  # N x 64 x 147 x 147
        "maxpool1",       # N x 64 x 73 x 73
        "Conv2d_3b_1x1",  # N x 80 x 73 x 73
        "Conv2d_4a_3x3",  # N x 192 x 71 x 71
        "maxpool2",       # N x 192 x 35 x 35
        "Mixed_5b",       # N x 256 x 35 x 35
        "Mixed_5c",       # N x 288 x 35 x 35
        "Mixed_5d",       # N x 288 x 35 x 35
        "Mixed_6a",       # N x 768 x 17 x 17
        "Mixed_6b",       # N x 768 x 17 x 17
        "Mixed_6c",       # N x 768 x 17 x 17
        "Mixed_6d",       # N x 768 x 17 x 17
        "Mixed_6e",       # N x 768 x 17 x 17
        "Mixed_7a",       # N x 1280 x 8 x 8
        "Mixed_7b",       # N x 2048 x 8 x 8
        "Mixed_7c",       # N x 2048 x 8 x 8
    )

    def __init__(self, weights="DEFAULT"):
        super().__init__()
        self.inception = inception_v3(weights=weights)
        # Feature extraction is inference: freeze batch-norm statistics etc.
        self.eval()

    def forward(self, x):
        """Return the ``Mixed_7c`` feature map for the image batch ``x``."""
        # The stock Inception3.forward re-scales its input with this hook before
        # the first convolution (it is a no-op unless the backbone was built
        # with transform_input=True, which torchvision enables for pretrained
        # weights). Skipping it would feed the pretrained filters inputs on a
        # different scale from the one they were trained on.
        x = self.inception._transform_input(x)

        for stage_name in self._STAGES:
            x = getattr(self.inception, stage_name)(x)

        # Deliberately stop here: no adaptive average pooling, dropout or fc.
        return x
    

class Resnet50Extractor(torch.nn.Module):
    """Spatial feature extractor built on torchvision's ResNet-50.

    ``forward`` runs the stem (``conv1``, ``bn1``, ``relu``, ``maxpool``) and
    the four residual stages ``layer1`` .. ``layer4`` and returns the result of
    ``layer4``; nothing after ``layer4`` is called.

    The module is switched to eval mode on construction. Otherwise the
    batch-norm layers would normalise with the statistics of the current batch
    and update their running averages on every forward pass, so repeated
    feature extraction would silently drift. Call ``.train()`` explicitly
    before fine-tuning.

    Args:
        weights: Forwarded to ``resnet50``. Defaults to ``"DEFAULT"``
            (pretrained); pass ``None`` for a randomly initialised backbone.
    """

    # Backbone sub-modules applied by ``forward``, in order.
    _STAGES = (
        "conv1", "bn1", "relu", "maxpool",
        "layer1", "layer2", "layer3", "layer4",
    )

    def __init__(self, weights="DEFAULT"):
        super().__init__()
        self.resnet = resnet50(weights=weights)
        # Feature extraction is inference: freeze batch-norm statistics.
        self.eval()

    def forward(self, x):
        """Return the ``layer4`` feature map for the image batch ``x``."""
        for stage_name in self._STAGES:
            x = getattr(self.resnet, stage_name)(x)
        return x

class WideResnet1012(torch.nn.Module):
    """Spatial feature extractor built on torchvision's Wide ResNet-101-2.

    ``forward`` runs the stem (``conv1``, ``bn1``, ``relu``, ``maxpool``) and
    the four residual stages ``layer1`` .. ``layer4`` and returns the result of
    ``layer4``; nothing after ``layer4`` is called.

    The module is switched to eval mode on construction so that layers whose
    behaviour differs between training and inference (e.g. batch norm) use
    their stored statistics instead of per-batch ones and are not mutated by a
    forward pass. Call ``.train()`` explicitly before fine-tuning.

    Args:
        weights: Forwarded to ``wide_resnet101_2``. Defaults to ``"DEFAULT"``
            (pretrained); pass ``None`` for a randomly initialised backbone.
    """

    # Backbone sub-modules applied by ``forward``, in order.
    _STAGES = (
        "conv1", "bn1", "relu", "maxpool",
        "layer1", "layer2", "layer3", "layer4",
    )

    def __init__(self, weights="DEFAULT"):
        super().__init__()
        self.resnet = wide_resnet101_2(weights=weights)
        # Feature extraction is inference: freeze batch-norm statistics.
        self.eval()

    def forward(self, x):
        """Return the ``layer4`` feature map for the image batch ``x``."""
        for stage_name in self._STAGES:
            x = getattr(self.resnet, stage_name)(x)
        return x




In [6]:
# Rebinds `model` to the ResNet-50 feature extractor used by the cells below.
# eval() makes batch norm use its stored statistics instead of per-batch ones;
# the extractor is only used for inference in this notebook.
model = Resnet50Extractor().eval()

In [7]:
# Shape of the extractor's output for the sample image.
# unsqueeze(0) adds the batch dimension the network expects; no_grad() skips
# building an autograd graph that would be thrown away immediately.
with torch.no_grad():
    feature_map = model(tsfms(img).unsqueeze(0))
feature_map.shape

torch.Size([1, 2048, 12, 12])

In [8]:
# Length of the per-image feature vector: everything after the batch dimension
# is flattened. no_grad() avoids keeping an autograd graph for this large tensor.
with torch.no_grad():
    flat_features = torch.flatten(model(tsfms(img).unsqueeze(0)), start_dim=1)
flat_features.shape

torch.Size([1, 294912])