In [1]:
import os
import torch
import torch.cuda as cuda
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from ImageNetKaggle import ImageNetKaggle
from eval_linear import RegLog
import src.resnet50 as resnet_models
import torchvision.models as models
from tqdm import tqdm

Loading validation data

In [2]:
import json

samples = []
targets = []
syn_to_class = {}

root = "/scratch/sl636/"
split = "val"
samples_dir = os.path.join(root, "ILSVRC/Data/CLS-LOC", split, "val_subdir")


with open(os.path.join(root, "imagenet_class_index.json"), "rb") as f:
    json_file = json.load(f)
    for class_id, v in json_file.items():
        syn_to_class[v[0]] = int(class_id)
with open(os.path.join(root, "ILSVRC2012_val_labels.json"), "rb") as f:
        val_to_syn = json.load(f)


with open(os.path.join(root, "ILSVRC2012_val_labels.json"), "rb") as f:
    val_to_syn = json.load(f)

for entry in os.listdir(samples_dir):
    syn_id = val_to_syn[entry]
    target = syn_to_class[syn_id]
    sample_path = os.path.join(samples_dir, entry)
    samples.append(sample_path)
    targets.append(target)

print(len(samples))

50000


In [3]:
normalize = transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])
dataset = ImageNetKaggle("/scratch/sl636/", "val", val_transform)

dataloader = DataLoader(
        dataset,
        batch_size=256,
        num_workers=10, 
        shuffle=False,
        drop_last=False,
        pin_memory=True
    )
device = "cuda:0" if cuda.is_available() else "cpu"


Defining the models as recommended here: https://github.com/facebookresearch/swav/issues/74#issuecomment-883325920

In [4]:
print(len(dataset))

50000


In [5]:
#model = torch.hub.load("facebookresearch/swav", "resnet50", pretrained=True)
model = resnet_models.__dict__["resnet50"](output_dim=0, eval_mode=True)
linear_classifier = RegLog(1000, global_avg=True, use_bn=False)


In [6]:
state_dict = torch.load("/home/sl636/swav/swav_800ep_eval_linear.pth.tar")["state_dict"]
# remove prefixe "module."
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
for k, v in linear_classifier.state_dict().items():
    if k not in list(state_dict):
        print('key "{}" could not be found in provided state dict'.format(k))
    elif state_dict[k].shape != v.shape:
        print('key "{}" is of different shape in model and provided state dict'.format(k))
        state_dict[k] = v
msg = linear_classifier.load_state_dict(state_dict, strict=False)
print(msg)
print(state_dict.keys())


<All keys matched successfully>
dict_keys(['linear.weight', 'linear.bias'])


I was hoping that in the code block below, I could replace state_dict with torch.load(PATH_TO_MY_OWN_TRAINED_RESNET) and get the accuracy, but it only work with the 800ep checkpoint

In [7]:
model = resnet_models.__dict__["resnet50"](output_dim=0, eval_mode=True)
state_dict = torch.load("/home/sl636/swav/swav_800ep_pretrain.pth.tar")
#state_dict = torch.load("/home/sl636/swav/experiments/indep/swav/imagenet_from_scratch_400_wrongniters/checkpoints/ckp-399.pth")["state_dict"]
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
for k, v in model.state_dict().items():
    if k not in list(state_dict):
        print('key "{}" could not be found in provided state dict'.format(k))
    elif state_dict[k].shape != v.shape:
        print('key "{}" is of different shape in model and provided state dict'.format(k))
        state_dict[k] = v
msg = model.load_state_dict(state_dict, strict=False)
print(msg)
print(state_dict.keys())

_IncompatibleKeys(missing_keys=[], unexpected_keys=['projection_head.0.weight', 'projection_head.0.bias', 'projection_head.1.weight', 'projection_head.1.bias', 'projection_head.1.running_mean', 'projection_head.1.running_var', 'projection_head.1.num_batches_tracked', 'projection_head.3.weight', 'projection_head.3.bias', 'prototypes.weight'])
dict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn1.num_batches_tracked', 'layer1.0.conv2.weight', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn2.num_batches_tracked', 'layer1.0.conv3.weight', 'layer1.0.bn3.weight', 'layer1.0.bn3.bias', 'layer1.0.bn3.running_mean', 'layer1.0.bn3.running_var', 'layer1.0.bn3.num_batches_tracked', 'layer1.0.downsample.0.weight', 'layer1.0.downsampl

In [8]:
model = model.to(device)
linear_classifier = linear_classifier.to(device)

model.eval()
linear_classifier.eval()
correct = 0
total = 0

In [9]:
invTrans = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ],
                                                     std = [ 1/0.229, 1/0.224, 1/0.225 ]),
                                transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ],
                                                     std = [ 1., 1., 1. ]),
                               ])

We get 75.3% as shown here https://github.com/facebookresearch/swav/issues/74#issuecomment-883325920

In [10]:
with torch.no_grad():
    for x, y in tqdm(dataloader):
        y_pred = linear_classifier(model(x.to(device)))
        correct += (y_pred.argmax(axis=1) == y.to(device)).sum().item()
        total += len(y)
print(correct / total)

100%|██████████| 196/196 [00:59<00:00,  3.29it/s]

0.75268





when loading my own model and/or another facebook-released checkpoint from that same table, the accuracy is < 1%

need to train a linear classifier for this specific backbone

In [11]:
model = resnet_models.__dict__["resnet50"](output_dim=0, eval_mode=True)
#state_dict = torch.load("/home/sl636/swav/swav_800ep_pretrain.pth.tar")
state_dict = torch.load("/home/sl636/swav/experiments/indep/swav/imagenet_from_scratch_400_wrongniters/checkpoints/ckp-399.pth")["state_dict"]
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
for k, v in model.state_dict().items():
    if k not in list(state_dict):
        print('key "{}" could not be found in provided state dict'.format(k))
    elif state_dict[k].shape != v.shape:
        print('key "{}" is of different shape in model and provided state dict'.format(k))
        state_dict[k] = v
msg = model.load_state_dict(state_dict, strict=False)
#model.fc = torch.nn.Linear(2048, 1000)
print(msg)
print(state_dict.keys())

_IncompatibleKeys(missing_keys=[], unexpected_keys=['projection_head.0.weight', 'projection_head.0.bias', 'projection_head.1.weight', 'projection_head.1.bias', 'projection_head.1.running_mean', 'projection_head.1.running_var', 'projection_head.1.num_batches_tracked', 'projection_head.3.weight', 'projection_head.3.bias', 'prototypes.weight'])
dict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn1.num_batches_tracked', 'layer1.0.conv2.weight', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn2.num_batches_tracked', 'layer1.0.conv3.weight', 'layer1.0.bn3.weight', 'layer1.0.bn3.bias', 'layer1.0.bn3.running_mean', 'layer1.0.bn3.running_var', 'layer1.0.bn3.num_batches_tracked', 'layer1.0.downsample.0.weight', 'layer1.0.downsampl

load new linear classifier

In [None]:
linear_classifier = RegLog(1000, global_avg=True, use_bn=False)
state_dict = torch.load("PATH_TO_PTH_FILE")["state_dict"]
# remove prefixe "module."
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
for k, v in linear_classifier.state_dict().items():
    if k not in list(state_dict):
        print('key "{}" could not be found in provided state dict'.format(k))
    elif state_dict[k].shape != v.shape:
        print('key "{}" is of different shape in model and provided state dict'.format(k))
        state_dict[k] = v
msg = linear_classifier.load_state_dict(state_dict, strict=False)
print(msg)
print(state_dict.keys())



In [25]:
model = model.to(device)
linear_classifier = linear_classifier.to(device)

model.eval()
linear_classifier.eval()
correct = 0
total = 0

In [29]:
sample_input = torch.randn(1, 3, 224, 224)  # Example: 1 batch, 3 channels, 224x224 image

# Step 3: Pass the input through the model
output = model(sample_input)

# Print the output shape
print("Output shape:", output.shape)

Output shape: torch.Size([1, 2048, 7, 7])


In [12]:
import torch.nn as nn

In [32]:
decoder = nn.Sequential(nn.AdaptiveAvgPool2d(1),
                        nn.Flatten(),
                        nn.Linear(in_features = 2048,out_features=1000,bias=True))
decoder.to(device)

output = decoder(output)
print(output.shape)

torch.Size([1, 1000])


In [33]:
correct, total = 0, 0

In [38]:
with torch.no_grad():
    for x, y in tqdm(dataloader):
        outputs = decoder(model(x.to(device)))
        y_pred = torch.argmax(outputs, dim=1)
        correct += (y_pred == y.to(device)).sum().item()
        total += len(y)
print(correct / total)

100%|██████████| 196/196 [01:20<00:00,  2.43it/s]

0.001187160811196801





: 