In [12]:
import torch
from torchvision import models
import os

# Linear-probe setup: build a ResNet-50 with torchvision's default arguments
# and leave only the final fully-connected layer trainable.
model = models.resnet50()
for name, param in model.named_parameters():
    # The classifier head keeps its gradients; everything else is frozen.
    if name in ('fc.weight', 'fc.bias'):
        continue
    param.requires_grad_(False)
# Re-initialise the head in place: N(0, 0.01) weights and an all-zero bias.
torch.nn.init.normal_(model.fc.weight, mean=0.0, std=0.01)
torch.nn.init.zeros_(model.fc.bias)

# Checkpoint file written by SimSiam pre-training.
pretrained_simsiam = './simsiam/checkpoint_0099.pth.tar'

# Copy the SimSiam encoder weights into the ResNet-50 backbone.
if not os.path.isfile(pretrained_simsiam):
    print("=> no checkpoint found at '{}'".format(pretrained_simsiam))
else:
    print("=> loading checkpoint '{}'".format(pretrained_simsiam))
    # NOTE: torch.load unpickles the file, so only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(pretrained_simsiam, map_location="cpu")

    # Rewrite the checkpoint's state dict in place so that only the encoder
    # backbone remains, stored under the key names this model uses.
    state_dict = checkpoint['state_dict']
    for k in tuple(state_dict):
        # Anything outside the encoder, and the encoder's own fc entries,
        # is discarded.
        if not k.startswith('module.encoder') or k.startswith('module.encoder.fc'):
            del state_dict[k]
            continue
        # Re-insert the tensor under its name without the wrapper prefix.
        state_dict[k[len("module.encoder."):]] = state_dict.pop(k)

    # strict=False because the classifier head is intentionally absent from
    # the filtered dict; the assert checks that nothing else is missing.
    msg = model.load_state_dict(state_dict, strict=False)
    assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
    print("=> loaded pre-trained model '{}'".format(pretrained_simsiam))


=> loading checkpoint './simsiam/checkpoint_0099.pth.tar'
=> loaded pre-trained model './simsiam/checkpoint_0099.pth.tar'


In [35]:
def encode(model, x):
    """Print each sub-module of ``model`` together with its traversal index.

    The first entry yielded by ``model.modules()`` (index 0) is skipped; every
    later entry is printed as ``"<index>: <module>"``.

    Args:
        model: object exposing a ``modules()`` iterable.
        x: accepted but currently unused.

    Returns:
        None. The function only prints.
    """
    remaining = iter(model.modules())
    # Consume index 0 so that numbering of the printed entries starts at 1.
    next(remaining, None)
    for position, layer in enumerate(remaining, start=1):
        print('{}: {}'.format(position, layer))

import numpy as np

# Side length, in pixels, of the square dummy image.
input_size = 224
# Random stand-in batch laid out as (batch, height, width, channels).
x = np.random.rand(1, input_size, input_size, 3)
# A bare `model.forward()` call used to sit here. It passed no input, so it
# raised a TypeError and stopped the cell before `encode` could run; it has
# been removed.
# NOTE(review): for a real forward pass, call `model(...)` with a float32
# torch tensor in (N, 3, H, W) layout. `x` above is a channels-last float64
# NumPy array and would need converting first.
# `encode` currently only prints the sub-modules and returns None, so `y` is
# None after this line.
y = encode(model, x)


1: Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
2: BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
3: ReLU(inplace=True)
4: MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
5: Sequential(
  (0): Bottleneck(
    (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1,