# Pre-trained weight usage

A minimal example of how to load and use the CEM500K pre-trained weights for classification or segmentation tasks.

Before getting started, download the CEM500K data and models from EMPIAR:
- EMPIAR entry: https://www.ebi.ac.uk/pdbe/emdb/empiar/entry/10592/
- Download help: https://www.ebi.ac.uk/pdbe/emdb/empiar/faq#question_Download


In [1]:
import os

## PyTorch

First, let's consider a simple binary classification model.

In [2]:
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from copy import deepcopy
from torchvision.models import resnet50

In [3]:
path_to_empiar_download = '' #fill this in
state_path = os.path.join(path_to_empiar_download, 'pretrained_models/cem500k_mocov2_resnet50_200ep_pth.tar')
state = torch.load(state_path, map_location='cpu')

In [4]:
#take a look at what's inside the state
print(list(state.keys()))

['epoch', 'arch', 'state_dict', 'optimizer', 'norms']


- Epoch: the training epoch at which the state was saved
- Arch: the model architecture ("resnet50")
- State_dict: state dict for the complete pre-trained model (both query and key encoders)
- Optimizer: state of the optimizer at save time (useful for resuming training)
- Norms: the mean and std pixel values used during training

In [5]:
state_dict = state['state_dict']

In [6]:
#format the parameter names to match torchvision resnet50
resnet50_state_dict = deepcopy(state_dict)
for k in list(resnet50_state_dict.keys()):
    #only keep query encoder parameters; discard the fc projection head
    if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
        resnet50_state_dict[k[len("module.encoder_q."):]] = resnet50_state_dict[k]

    #delete renamed or unused k
    del resnet50_state_dict[k]
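
The renaming step above can also be written as a small reusable function. The sketch below is one possible way to do it, not part of the CEM500K code: `extract_query_encoder` is a hypothetical helper name, and the toy dictionary (with illustrative key names and strings in place of tensors) stands in for the real `state['state_dict']` so that the example runs without the download.

```python
def extract_query_encoder(state_dict, prefix="module.encoder_q."):
    """Return a new dict holding only the query encoder backbone parameters.

    Hypothetical helper: strips `prefix` from matching keys and drops the
    fc projection head (keys starting with `prefix + "fc"`).
    """
    head_prefix = prefix + "fc"
    return {
        k[len(prefix):]: v
        for k, v in state_dict.items()
        if k.startswith(prefix) and not k.startswith(head_prefix)
    }


# toy stand-in for state['state_dict']; values would be tensors in practice
toy_state_dict = {
    "module.encoder_q.conv1.weight": "q_conv1",
    "module.encoder_q.layer1.0.conv1.weight": "q_layer1",
    "module.encoder_q.fc.0.weight": "q_head",
    "module.encoder_k.conv1.weight": "k_conv1",
    "module.queue": "queue",
}

renamed = extract_query_encoder(toy_state_dict)
print(sorted(renamed))  # ['conv1.weight', 'layer1.0.conv1.weight']
```

Unlike the deepcopy-and-delete loop, this variant builds a new dictionary and leaves the original untouched. It shares the stored values rather than copying them, which should be acceptable when the result is only passed to `load_state_dict`.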

In [7]:
#create model and load the pretrained weights
model = resnet50()

#overwrite the first conv layer to accept single channel grayscale image
#overwrite the fc layer for binary classification
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.fc = nn.Linear(2048, 1, bias=True)

#loads all parameters but those for the fc head
#those parameters need to be trained
model.load_state_dict(resnet50_state_dict, strict=False)

_IncompatibleKeys(missing_keys=['fc.weight', 'fc.bias'], unexpected_keys=[])
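
With `strict=False`, parameter names that do not match are skipped rather than raising an error, so it can be worth checking the returned result in code instead of reading it by eye. A minimal sketch follows. `assert_only_missing` is a hypothetical helper, and the `LoadResult` named tuple only mimics the `missing_keys` and `unexpected_keys` fields of the object printed above, so that the example runs on its own.

```python
from collections import namedtuple

# stand-in with the same two fields as the result of load_state_dict
LoadResult = namedtuple("LoadResult", ["missing_keys", "unexpected_keys"])


def assert_only_missing(result, allowed_prefixes):
    """Raise if any key is unexpected, or missing outside allowed_prefixes."""
    if result.unexpected_keys:
        raise ValueError(f"unexpected keys: {result.unexpected_keys}")
    stray = [
        k for k in result.missing_keys
        if not k.startswith(tuple(allowed_prefixes))
    ]
    if stray:
        raise ValueError(f"missing keys outside {allowed_prefixes}: {stray}")


# the result shown above for the classification model passes the check
result = LoadResult(missing_keys=["fc.weight", "fc.bias"], unexpected_keys=[])
assert_only_missing(result, allowed_prefixes=["fc."])
print("only the fc head is left to train")
```

In the notebook itself, the value returned by `model.load_state_dict(resnet50_state_dict, strict=False)` could be passed to the helper directly in place of `LoadResult`.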

Now let's load the parameters into a simple binary segmentation UNet.

In [8]:
#as before, we need to update the parameter names to match the UNet model
#for segmentation_models_pytorch we simply add the prefix "encoder."
unet_state_dict = deepcopy(resnet50_state_dict)
for k in list(unet_state_dict.keys()):
    unet_state_dict['encoder.' + k] = unet_state_dict[k]
    del unet_state_dict[k]

In [9]:
model = smp.Unet('resnet50', in_channels=1, encoder_weights=None, classes=1)
#all encoder parameters are loaded
#parameters in the decoder must be trained on task data
model.load_state_dict(unet_state_dict, strict=False)

_IncompatibleKeys(missing_keys=['decoder.blocks.0.conv1.0.weight', 'decoder.blocks.0.conv1.1.weight', 'decoder.blocks.0.conv1.1.bias', 'decoder.blocks.0.conv1.1.running_mean', 'decoder.blocks.0.conv1.1.running_var', 'decoder.blocks.0.conv2.0.weight', 'decoder.blocks.0.conv2.1.weight', 'decoder.blocks.0.conv2.1.bias', 'decoder.blocks.0.conv2.1.running_mean', 'decoder.blocks.0.conv2.1.running_var', 'decoder.blocks.1.conv1.0.weight', 'decoder.blocks.1.conv1.1.weight', 'decoder.blocks.1.conv1.1.bias', 'decoder.blocks.1.conv1.1.running_mean', 'decoder.blocks.1.conv1.1.running_var', 'decoder.blocks.1.conv2.0.weight', 'decoder.blocks.1.conv2.1.weight', 'decoder.blocks.1.conv2.1.bias', 'decoder.blocks.1.conv2.1.running_mean', 'decoder.blocks.1.conv2.1.running_var', 'decoder.blocks.2.conv1.0.weight', 'decoder.blocks.2.conv1.1.weight', 'decoder.blocks.2.conv1.1.bias', 'decoder.blocks.2.conv1.1.running_mean', 'decoder.blocks.2.conv1.1.running_var', 'decoder.blocks.2.conv2.0.weight', 'decoder.bloc

The segmentation_models_pytorch module comes with a selection of state-of-the-art semantic segmentation models. The weight loading procedure is the same for all of these architectures. For a full list, see https://github.com/qubvel/segmentation_models.pytorch#models.
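
Because the procedure is the same across architectures, the prefixing and loading steps can be wrapped in one function. The sketch below is an illustration, not part of the CEM500K code: `load_encoder_weights` is a hypothetical helper, it assumes the model keeps its backbone under the `encoder.` prefix as the UNet above does, and `DummyModel` exists only so that the example runs without PyTorch installed.

```python
def load_encoder_weights(model, backbone_state_dict, prefix="encoder."):
    """Prefix backbone parameter names and load them non-strictly.

    Hypothetical helper; returns whatever model.load_state_dict returns.
    """
    prefixed = {prefix + k: v for k, v in backbone_state_dict.items()}
    return model.load_state_dict(prefixed, strict=False)


class DummyModel:
    """Stand-in that records the keys it receives (not a real nn.Module)."""

    def load_state_dict(self, state_dict, strict=True):
        self.loaded_keys = sorted(state_dict)
        return None


dummy = DummyModel()
load_encoder_weights(dummy, {"conv1.weight": 0, "bn1.weight": 1})
print(dummy.loaded_keys)  # ['encoder.bn1.weight', 'encoder.conv1.weight']
```

With the real libraries, a call might look like `load_encoder_weights(smp.FPN('resnet50', in_channels=1, encoder_weights=None, classes=1), resnet50_state_dict)`, assuming the chosen architecture takes the same encoder arguments as `smp.Unet`.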