In [1]:

import numpy as np
import timm
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader, Subset

from layers import ConditionalBatchNorm2d
from utils import create_subset


  from .autonotebook import tqdm as notebook_tqdm


In [28]:
# List every timm model whose name matches "resnet*" and that has pretrained
# weights available, printing one name per line.
model_names = timm.list_models("resnet*", pretrained=True)
print(*model_names, sep="\n")

resnet10t.c3_in1k
resnet14t.c3_in1k
resnet18.a1_in1k
resnet18.a2_in1k
resnet18.a3_in1k
resnet18.fb_ssl_yfcc100m_ft_in1k
resnet18.fb_swsl_ig1b_ft_in1k
resnet18.gluon_in1k
resnet18.tv_in1k
resnet18d.ra2_in1k
resnet26.bt_in1k
resnet26d.bt_in1k
resnet26t.ra2_in1k
resnet32ts.ra2_in1k
resnet33ts.ra2_in1k
resnet34.a1_in1k
resnet34.a2_in1k
resnet34.a3_in1k
resnet34.bt_in1k
resnet34.gluon_in1k
resnet34.tv_in1k
resnet34d.ra2_in1k
resnet50.a1_in1k
resnet50.a1h_in1k
resnet50.a2_in1k
resnet50.a3_in1k
resnet50.am_in1k
resnet50.b1k_in1k
resnet50.b2k_in1k
resnet50.bt_in1k
resnet50.c1_in1k
resnet50.c2_in1k
resnet50.d_in1k
resnet50.fb_ssl_yfcc100m_ft_in1k
resnet50.fb_swsl_ig1b_ft_in1k
resnet50.gluon_in1k
resnet50.ra_in1k
resnet50.ram_in1k
resnet50.tv2_in1k
resnet50.tv_in1k
resnet50_gn.a1h_in1k
resnet50c.gluon_in1k
resnet50d.a1_in1k
resnet50d.a2_in1k
resnet50d.a3_in1k
resnet50d.gluon_in1k
resnet50d.ra2_in1k
resnet50s.gluon_in1k
resnet51q.ra2_in1k
resnet61q.ra2_in1k
resnet101.a1_in1k
resnet101.a1h_in1k
re

In [30]:

# TODO: timm ships several pretrained weight sets per architecture, selected
# with a ".<tag>" suffix (e.g. resnet18.a1_in1k, resnet18.tv_in1k, etc.).
# Figure out what the tags mean and how to choose the right one. No tag is
# given below, so timm falls back to the architecture's default weights.

# Number of output classes for the new classification head
# (the Oxford-IIIT Pet dataset used in this notebook has 37 categories).
NUM_CLASSES = 37

model = timm.create_model("resnet18", pretrained=True, num_classes=NUM_CLASSES)
# model.eval()

In [40]:

# Create model-specific transform settings.
# resolve_data_config expects a dict of user overrides as its first argument
# and reads the defaults from the model passed via `model=`. Pass an empty
# override dict so the values come from the model's pretrained_cfg.
data_config = timm.data.resolve_data_config({}, model=model)

# `is_training` is not part of the resolved config; it is added here so that
# create_transform(**data_config) builds the training-time pipeline.
data_config["is_training"] = True
data_config

{'input_size': (3, 224, 224),
 'interpolation': 'bicubic',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'crop_pct': 0.95,
 'crop_mode': 'center',
 'is_training': True}

In [42]:
# Default pipeline for a 224-pixel input when no model-specific settings are
# supplied, shown for comparison with the model-specific transform.
timm.data.create_transform(input_size=224)

Compose(
    Resize(size=256, interpolation=bilinear, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)

In [39]:

# Evaluation-time pipeline (resize + center crop, no augmentation).
# data_config carries "is_training": True, so override it explicitly here
# instead of relying on the order in which cells were executed.
transform_eval = timm.data.create_transform(**{**data_config, "is_training": False})
transform_eval

Compose(
    Resize(size=235, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)

In [41]:

# Build the transform pipeline from the resolved data config. Because
# data_config contains "is_training": True, this yields the training-time
# augmentation pipeline (random resized crop, flip, colour jitter).
transform = timm.data.create_transform(**data_config)
transform

Compose(
    RandomResizedCropAndInterpolation(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=bicubic)
    RandomHorizontalFlip(p=0.5)
    ColorJitter(brightness=(0.6, 1.4), contrast=(0.6, 1.4), saturation=(0.6, 1.4), hue=None)
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)

{'input_size': (3, 224, 224),
 'interpolation': 'bicubic',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'crop_pct': 0.95,
 'crop_mode': 'center'}

In [None]:

# Freeze all layers of the pretrained network.
model.requires_grad_(False)

# Unfreeze the final layer. get_classifier() returns the classification head
# without hard-coding the attribute name (`fc` on this ResNet).
model.get_classifier().requires_grad_(True)

# Debugging aid: uncomment to print the bias shape of every BatchNorm2d layer
# before the layers are replaced below.
# for m in model.modules():
#     if isinstance(m, nn.BatchNorm2d):
#         print(m.bias.shape)

# Add conditional batch norm layers.
# NOTE(review): this runs after the freeze above, so any parameters created by
# replace_bn2d keep whatever requires_grad value the new layers assign --
# presumably trainable; confirm against layers.ConditionalBatchNorm2d.
ConditionalBatchNorm2d.replace_bn2d(model)

# Report how much of the model will actually be optimised.
num_params = sum(p.numel() for p in model.parameters())
num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
percent_trainable = num_trainable_params / num_params * 100

print(f"Number of parameters: {num_params:,}")
print(f"Number of trainable parameters: {num_trainable_params:,} ({percent_trainable:.2f}%)")


In [None]:

# Smoke test: push a small random batch through the model and print the
# shape of the output.
x = torch.randn(2, 3, 224, 224)

# Run in eval mode and without autograd so that this random batch neither
# updates normalization running statistics in the pretrained backbone nor
# builds a computation graph. The previous train/eval mode is restored
# afterwards, even if the forward pass raises.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        out = model(x)
finally:
    model.train(was_training)

assert out.shape[0] == x.shape[0], "batch dimension changed in forward pass"
print(out.shape)

In [6]:
# Set up data: load the "trainval" and "test" splits of Oxford-IIIT Pet from
# the local "data" directory (the first call downloads it if necessary).
# NOTE(review): no transform is passed, so samples come back as raw PIL
# images -- confirm the preprocessing transform is applied elsewhere before
# these datasets are batched.
dataset_train = torchvision.datasets.OxfordIIITPet(
    root="data",
    split="trainval",
    download=True,
)
dataset_test = torchvision.datasets.OxfordIIITPet(root="data", split="test")

# Alternative loader via timm, kept for reference:
# from timm.data import create_dataset
# ds = create_dataset("torch/oxford_iiit_pet", root="data", split="trainval", download=True)

# Reduce the training set to a subset with N images per class.
dataset_train = create_subset(dataset_train, n_img_per_class=10, random_seed=0)

In [21]:
# Check whether the classifier head is currently in training mode.
model.get_classifier().training

True

In [29]:
# Inspect the first sample of the training subset; indexing yields an
# (image, label) pair.
dataset_train[0]

(<PIL.Image.Image image mode=RGB size=333x500>, 0)