In [1]:
import os
import torch
import torchvision

import torchvision.transforms as transforms

from torch.utils.data import DataLoader

from cl_explain.encoders.simclr.resnet_wider import resnet50x1, resnet50x2, resnet50x4

In [2]:
# Inference device. Prefer GPU index 6, but fall back to CPU when that GPU does
# not exist (torch.cuda.device_count() is 0 without CUDA). This lets the
# notebook survive Restart & Run All on other machines.
device = "cuda:6" if torch.cuda.device_count() > 6 else "cpu"
# Directory holding the SimCLR checkpoints loaded by `load_model`.
model_path = "/projects/leelab/cl-explainability/models/simclr/"

In [3]:
def load_model(resnet_version, model_path, device):
    """Build a SimCLR ResNet-50 variant, load its checkpoint, and prepare it for inference.

    Args:
        resnet_version: Which variant to build; one of 1, 2 or 4, selecting
            ``resnet50x1`` / ``resnet50x2`` / ``resnet50x4``.
        model_path: Directory containing the ``resnet50-{1,2,4}x.pth`` checkpoints.
        device: Device the model is moved to (e.g. ``"cuda:0"`` or ``"cpu"``).

    Returns:
        The model with its weights loaded, moved to ``device``, and in eval mode.

    Raises:
        NotImplementedError: If ``resnet_version`` is not 1, 2 or 4.
    """
    checkpoint_names = {
        1: "resnet50-1x.pth",
        2: "resnet50-2x.pth",
        4: "resnet50-4x.pth",
    }
    if resnet_version not in checkpoint_names:
        raise NotImplementedError(
            f"ResNet50({resnet_version}x) is not implemented!"
        )
    # Look up constructors only after validation, so an unsupported version
    # fails fast without touching any model code.
    builders = {1: resnet50x1, 2: resnet50x2, 4: resnet50x4}
    model = builders[resnet_version]()
    sd_path = os.path.join(model_path, checkpoint_names[resnet_version])
    # Load onto CPU first so the device the checkpoint was saved from does not matter.
    sd = torch.load(sd_path, map_location="cpu")
    model.load_state_dict(sd["state_dict"])
    model.to(device)
    # The model is only used for inference here. Switch to eval mode so that any
    # BatchNorm/Dropout layers (presumably present in the ResNet backbone) use
    # inference behavior and forward passes do not update running statistics.
    model.eval()
    return model

In [4]:
def test_model(model, img, device):
    img = img.to(device)
    output = model(img, apply_eval_head=True)
    rep = model(img, apply_eval_head=False)
    
    print(f"Output dim = {output.shape[-1]}")
    print(f"Representation dim = {rep.shape[-1]}")

In [5]:
# Load the Imagenette validation split and take one example to probe the models with.
imagenette_path = "/projects/leelab/data/image/imagenette2/val"

# Preprocessing: resize the shorter side to 256, center-crop to 224x224, and
# convert to a tensor. Mean/std normalization is deliberately omitted because
# the original model does not have normalization.
dataset = torchvision.datasets.ImageFolder(
    imagenette_path,
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]),
)

# Only the first (unshuffled) batch of size 1 is needed.
data_loader = DataLoader(dataset, batch_size=1)
img, target = next(iter(data_loader))

In [6]:
# Probe each SimCLR width in turn. Each model prints a header line, its two
# dimension lines, and then a blank separator line.
for resnet_version in (1, 2, 4):
    print(f"ResNet50({resnet_version}x)")
    model = load_model(resnet_version, model_path, device)
    test_model(model, img, device)
    print()

ResNet50(1x)
Output dim = 1000
Representation dim = 2048

ResNet50(2x)
Output dim = 1000
Representation dim = 4096

ResNet50(4x)
Output dim = 1000
Representation dim = 8192

