In [24]:
import torch
import torch.nn as nn
import torchvision
from torchvision import models
from torchvision import transforms
from tqdm import tqdm
from torch.utils.data import DataLoader

In [25]:
# Load an empty resnet from torchvision
model = models.resnet50(pretrained = False)

# Make the final hidden layer an identity
model.fc = nn.Identity()

In [26]:
# Path to downloaded model
path = "/home/alta/BLTSpeaking/exp-pr450/models/resnet50_byol_imagenet2012.pth.tar"

# Load pretrained byol model
checkpoint = torch.load(path, map_location = torch.device('cpu'))
checkpoint = checkpoint['online_backbone']

In [27]:
# Clean up model state dictionary
state_dict = {key[7:]: value for key, value in checkpoint.items()}

# Load model from pretrained byol
model.load_state_dict(state_dict, strict=True)
model = model.eval()

In [34]:
# Create custom transform
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.unsqueeze(0)),
    transforms.Lambda(lambda x: nn.functional.interpolate(x, size = 224, mode='bicubic', align_corners=True)),
    transforms.Lambda(lambda x: x.squeeze(0)),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

dataset = torchvision.datasets.CIFAR10(root = "/home/alta/BLTSpeaking/exp-pr450/data", train = True, transform = transform, download = True)

Files already downloaded and verified


In [56]:
N = len(dataset)
B = 128

embeddings = torch.empty(0, 2048)
labels = torch.empty(0)

trainloader = DataLoader(dataset, batch_size=B, drop_last=False, shuffle=False)

with torch.no_grad():
    for i, (x, y) in enumerate(trainloader):
        assert all(y == torch.tensor(dataset.targets[i*B:(i+1)*B]))
        embeddings = torch.cat([embeddings, model(x)])
        labels = torch.cat([labels, y])
        print(f'{i}/{len(trainloader)} batches complete')

0/391 batches complete
1/391 batches complete
2/391 batches complete
3/391 batches complete
4/391 batches complete
5/391 batches complete
6/391 batches complete
7/391 batches complete
8/391 batches complete
9/391 batches complete
10/391 batches complete
11/391 batches complete
12/391 batches complete
13/391 batches complete
14/391 batches complete
15/391 batches complete
16/391 batches complete
17/391 batches complete
18/391 batches complete
19/391 batches complete
20/391 batches complete
21/391 batches complete
22/391 batches complete
23/391 batches complete
24/391 batches complete
25/391 batches complete
26/391 batches complete
27/391 batches complete
28/391 batches complete
29/391 batches complete
30/391 batches complete
31/391 batches complete
32/391 batches complete
33/391 batches complete
34/391 batches complete
35/391 batches complete
36/391 batches complete
37/391 batches complete
38/391 batches complete
39/391 batches complete
40/391 batches complete
41/391 batches complete
42

In [59]:
torch.save(embeddings, '../data/byol_embeddings.pkl')