In [1]:
import torch.nn as nn
from torchvision.models import densenet161, resnet152, vgg19
from torchvision.models import VGG19_Weights

# Build VGG19 using the DEFAULT entry of torchvision's VGG19_Weights enum.
net = vgg19(weights=VGG19_Weights.DEFAULT)
# Bare last expression: the notebook renders the module's layer listing as the cell output.
net

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padd

In [22]:
# Feature extractor: a freshly loaded VGG19 `features` stack with its final layer dropped.
net = nn.Sequential(*list(vgg19(weights=VGG19_Weights.DEFAULT).features.children())[:-1])

# freeze: echo each tensor's flag as found, then switch gradient tracking off for it
for param in net.parameters():
    print(param.requires_grad)
    param.requires_grad_(False)

# second pass: echo the flags again, one per line, to confirm the freeze took effect
print(*(param.requires_grad for param in net.parameters()), sep="\n")

True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False


In [16]:
from prettytable import PrettyTable
def count_parameters(model):
    """Print a table of the trainable parameters of ``model`` and return their total size.

    Parameters whose ``requires_grad`` flag is false are skipped: they get no
    table row and do not contribute to the total.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, e.g. a ``torch.nn.Module``.

    Returns:
        int: Sum of ``numel()`` over every trainable parameter (0 if there are none).
    """
    # Collect (name, element count) for trainable tensors only, in iteration order.
    trainable_sizes = [
        (param_name, tensor.numel())
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]

    table = PrettyTable(["Modules", "Parameters"])
    for param_name, size in trainable_sizes:
        table.add_row([param_name, size])
    total_params = sum(size for _, size in trainable_sizes)

    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

In [24]:
# Tabulate the trainable parameters of `net`. The call is the cell's last expression,
# so its returned total is also shown as the cell result.
# (A bare `net` expression that preceded this call was removed: it was not the last
# expression in the cell, so it was never displayed.)
count_parameters(net)

+---------+------------+
| Modules | Parameters |
+---------+------------+
+---------+------------+
Total Trainable Params: 0


0

In [None]:
# Using features from VGG19, excluding the last pooling layer.
# Build from a freshly loaded model instead of reading `net.features`: an earlier cell
# rebinds `net` to an nn.Sequential, which has no `.features` attribute, so the
# self-referential form raises AttributeError on Restart & Run All and cannot be re-run.
# NOTE: the freshly loaded parameters are trainable; re-apply the freeze if a frozen
# backbone is required downstream.
net = nn.Sequential(*list(vgg19(weights=VGG19_Weights.DEFAULT).features.children())[:-1])
dim = 512  # Dimension of feature vectors for VGG19

+---------------------+------------+
|       Modules       | Parameters |
+---------------------+------------+
|  features.0.weight  |    1728    |
|   features.0.bias   |     64     |
|  features.2.weight  |   36864    |
|   features.2.bias   |     64     |
|  features.5.weight  |   73728    |
|   features.5.bias   |    128     |
|  features.7.weight  |   147456   |
|   features.7.bias   |    128     |
|  features.10.weight |   294912   |
|   features.10.bias  |    256     |
|  features.12.weight |   589824   |
|   features.12.bias  |    256     |
|  features.14.weight |   589824   |
|   features.14.bias  |    256     |
|  features.16.weight |   589824   |
|   features.16.bias  |    256     |
|  features.19.weight |  1179648   |
|   features.19.bias  |    512     |
|  features.21.weight |  2359296   |
|   features.21.bias  |    512     |
|  features.23.weight |  2359296   |
|   features.23.bias  |    512     |
|  features.25.weight |  2359296   |
|   features.25.bias  |    512     |
|

143667240

In [11]:
import torch
import numpy as np

# Load the tensors.
# map_location='cpu' lets files that were saved from a CUDA device load on a
# machine without a GPU; without it torch.load fails in that situation.
# Both files share the "preds_scenario1_epoch1" prefix and differ only in the
# att / noatt suffix.
# NOTE(review): torch.load unpickles its input; only load files from a trusted source.
preds_scenario1 = torch.load('preds_scenario1_epoch1_att.pt', map_location='cpu')
preds_scenario2 = torch.load('preds_scenario1_epoch1_noatt.pt', map_location='cpu')

# Convert to numpy if required
preds_scenario1_np = preds_scenario1.cpu().detach().numpy()
preds_scenario2_np = preds_scenario2.cpu().detach().numpy()

# Compare magnitudes with a 10% relative tolerance.
# NOTE(review): np.abs discards the sign, so values of opposite sign but equal
# magnitude count as "close" - confirm this is intended.
# NOTE(review): atol keeps its default, so entries that are zero in the second
# array only match near-zero entries in the first.
comparison = np.allclose(np.abs(preds_scenario1_np), np.abs(preds_scenario2_np), rtol=1e-01)
print(f'Are the predictions close: {comparison}')


Are the predictions close: False


In [13]:
# Element-wise difference between the two prediction tensors (first minus second),
# shown as the cell output.
preds_scenario1 - preds_scenario2

tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0015,  0.0000, -0.0010],
         [ 0.0000,  0.0026,  0.0000,  ..., -0.0037, -0.0006,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000, -0.0110,  0.0000],
         ...,
         [ 0.0000, -0.0049, -0.0087,  ..., -0.0102, -0.0125,  0.0000],
         [ 0.0000,  0.0100,  0.0039,  ..., -0.0123,  0.0008,  0.0000],
         [ 0.0000, -0.0140,  0.0063,  ...,  0.0077, -0.0074,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  ..., -0.0094,  0.0000,  0.0076],
         [ 0.0000,  0.0000, -0.0056,  ...,  0.0000,  0.0000,  0.0102],
         [ 0.0028,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0019, -0.0301,  ..., -0.0345, -0.0337,  0.0000],
         [ 0.0000,  0.0258,  0.0064,  ..., -0.0299, -0.0355,  0.0000],
         [ 0.0000,  0.0220, -0.0114,  ..., -0.0150,  0.0189,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  ..., -0.0012,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0

In [10]:
# Element-wise absolute value of the first prediction array, shown as the cell output.
np.abs(preds_scenario1_np)

array([[[0.        , 0.        , 0.        , ..., 0.03335105,
         0.        , 0.03577995],
        [0.        , 0.34878963, 0.        , ..., 0.28793523,
         0.83124405, 0.        ],
        [0.        , 0.        , 0.        , ..., 0.        ,
         0.194206  , 0.        ],
        ...,
        [0.        , 0.13001278, 0.3545614 , ..., 0.7991455 ,
         0.49412656, 0.        ],
        [0.        , 0.04037298, 0.4229029 , ..., 0.90565455,
         0.5475647 , 0.        ],
        [0.        , 0.14127803, 0.4307167 , ..., 0.84165597,
         0.41852698, 0.        ]],

       [[0.        , 0.        , 0.        , ..., 0.03455406,
         0.        , 0.06838391],
        [0.        , 0.        , 0.49620795, ..., 0.        ,
         0.        , 0.12208754],
        [0.37670052, 0.        , 0.        , ..., 0.        ,
         0.        , 0.        ],
        ...,
        [0.        , 0.16886163, 0.41359276, ..., 0.8285238 ,
         0.32362026, 0.        ],
        [0. 