# CESPED training and inference example

This notebook illustrates how to train a model on a CESPED entry and how to use the trained model to infer and save particle poses. In particular, we will train the model on the test entry named "TEST", a reduced version of an actual entry.

You can use this notebook to generate results and then use the other notebook to evaluate them.

First, we import the modules we will use. As you can see, the `ParticlesDataset` class is all we need from cesped.

In [3]:
import os
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from torchvision.models.resnet import resnet18
from cesped.particlesDataset import ParticlesDataset

device = "cuda" if torch.cuda.is_available() else "cpu"  # Fall back to the CPU if no GPU is available


Now we list the available entries in the CESPED benchmark.

In [4]:
listOfEntries = ParticlesDataset.getCESPEDEntries()
print(listOfEntries)


[('TEST', 0), ('TEST', 1), ('10786', 0), ('10786', 1), ('11120', 0), ('11120', 1), ('10166', 0), ('10166', 1), ('10280', 0), ('10280', 1), ('10647', 0), ('10647', 1), ('10374', 0), ('10374', 1), ('10409', 0), ('10409', 1)]


Notice that each entry appears twice. This is because we have split each entry into two halves to carry out a gold-standard-like evaluation procedure. In the typical cryo-EM gold-standard process, the dataset is divided into two halves, and each half is processed independently until two 3D maps, generally called half-maps, are generated. Then, to assess the quality of the solution, the two half-maps are compared.

Here we will do something similar. We will train a model for the half-set 0, and we will use it to predict the poses of the half-set 1. We will also train a second model using as training data the half-set 1, and we will use it to infer the poses of the half-set 0. Finally, in the next notebook, we will use the predicted poses to generate the 3D maps and evaluate the quality of the predictions.
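The cross-half scheme described above can be sketched in plain Python. This is only an illustration: `plan_cross_half_jobs` is a hypothetical helper, not part of cesped, and it assumes the `(entry_name, halfset)` tuples returned by `ParticlesDataset.getCESPEDEntries()`, as printed earlier.

```python
from collections import defaultdict

def plan_cross_half_jobs(entries):
    """Hypothetical helper: turn (entry_name, halfset) pairs into
    (entry_name, train_halfset, infer_halfset) jobs, so that each model is
    trained on one half-set and predicts the poses of the other one."""
    halves = defaultdict(set)
    for name, halfset in entries:
        halves[name].add(halfset)
    jobs = []
    for name in sorted(halves):
        if halves[name] != {0, 1}:
            raise ValueError(f"Entry {name} does not have both half-sets")
        for train_half in (0, 1):
            jobs.append((name, train_half, 1 - train_half))
    return jobs

entries = [('TEST', 0), ('TEST', 1), ('10786', 0), ('10786', 1)]
for job in plan_cross_half_jobs(entries):
    print(job)
```

Each job corresponds to one run of this notebook with `halfset_to_use` set to the training half-set.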


We will work with the "TEST" entry and half-set 0. To get predictions for the other half of the dataset, set `halfset_to_use = 1` and rerun the notebook. We will save the predicted poses in the /tmp folder; you probably want to change this.

In [24]:
targetName = "TEST"
halfset_to_use = 0  # Set to 1 and rerun the notebook to process the other half-set

The default benchmark directory, where the datasets will be downloaded, is:

In [25]:
from cesped.constants import default_configs_dir, defaultBenchmarkDir
print(defaultBenchmarkDir)

/home/sanchezg/tmp/cryoSupervisedDataset


Create the folder if it does not exist, or change the configuration by editing the .yaml files within cesped/config, and rerun the notebook.

In [26]:
os.makedirs(defaultBenchmarkDir, exist_ok=True)

Now we instantiate the `ParticlesDataset` object. We use the default benchmark directory and enable per-image normalization and CTF correction (phase flipping). In addition, we crop the particles to the desired image size, removing 25% of the image side.

In [27]:
dataset = ParticlesDataset(targetName, halfset_to_use,
                           apply_perImg_normalization = True,
                           ctf_correction = "phase_flip",
                           image_size_factor_for_crop = 0.25)
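The exact preprocessing is implemented inside cesped and is not reproduced here. As a rough illustration only, the sketch below shows what per-image normalization and a centered crop could look like with NumPy, assuming that `image_size_factor_for_crop = 0.25` removes 25% of the side length in total (for example, 128 px becomes 96 px). `normalize_and_crop` is a hypothetical function, and phase flipping is omitted because it requires the CTF parameters of each particle.

```python
import numpy as np

def normalize_and_crop(img, crop_factor=0.25):
    """Hypothetical sketch, not cesped's implementation.
    Assumes crop_factor is the fraction of the side length removed in total,
    split evenly between opposite borders."""
    n = img.shape[-1]
    new_n = int(round(n * (1 - crop_factor)))
    start = (n - new_n) // 2
    cropped = img[..., start:start + new_n, start:start + new_n]
    # Per-image normalization: zero mean and unit standard deviation per image
    mean = cropped.mean(axis=(-2, -1), keepdims=True)
    std = cropped.std(axis=(-2, -1), keepdims=True)
    return (cropped - mean) / (std + 1e-8)

rng = np.random.default_rng(0)
img = rng.normal(loc=5.0, scale=3.0, size=(1, 128, 128))
out = normalize_and_crop(img)
print(out.shape)  # (1, 96, 96)
```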

Now we define a toy model and loss function for illustration purposes.

In [28]:
#This is a toy model for predicting rotation matrices. You probably want to use something such as Gram–Schmidt orthonormalization to obtain valid rotation matrices
model = torch.nn.Sequential(torch.nn.Conv2d(1, 3, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False), resnet18(),
                            torch.nn.Linear(1000, 9))
model = model.to(device)
def loss_function(predR, gtR): #This is a naive loss function
    return torch.nn.functional.mse_loss(predR.flatten(1), gtR.flatten(1))
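The comment in the toy-model cell mentions Gram–Schmidt orthonormalization. A minimal sketch of that idea (the so-called 6D rotation representation) is shown below, together with a geodesic angular distance that could replace the naive MSE loss. Both function names are hypothetical, and the sketch assumes the model head outputs at least 6 values per particle, of which only the first 6 are used.

```python
import torch
import torch.nn.functional as F

def gram_schmidt_rotmat(x):
    """Map a (B, >=6) tensor to a batch of valid (B, 3, 3) rotation matrices.
    Only the first 6 values of each row are used."""
    a1, a2 = x[:, 0:3], x[:, 3:6]
    b1 = F.normalize(a1, dim=-1)
    # Remove the component of a2 along b1, then normalize
    b2 = F.normalize(a2 - (b1 * a2).sum(dim=-1, keepdim=True) * b1, dim=-1)
    # The third basis vector is fixed by the right-hand rule
    b3 = torch.linalg.cross(b1, b2, dim=-1)
    return torch.stack([b1, b2, b3], dim=-2)  # basis vectors as rows

def geodesic_loss(predR, gtR, eps=1e-6):
    """Mean rotation angle (radians) between predicted and ground-truth matrices."""
    relative = predR.transpose(-1, -2) @ gtR
    trace = relative.diagonal(dim1=-2, dim2=-1).sum(-1)
    # Clamping keeps the gradient of acos finite near +-1
    cos = ((trace - 1) / 2).clamp(-1 + eps, 1 - eps)
    return torch.acos(cos).mean()

# Usage sketch with the toy model and batch variables of this notebook:
# predRot = gram_schmidt_rotmat(model(img))
# loss = geodesic_loss(predRot, rotMat)
```

Because every output of `gram_schmidt_rotmat` is a proper rotation, the matrices passed to `updateMd` during inference would also be valid rotations.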
    

The `ParticlesDataset` objects are compatible with torch `DataLoaders`.

In [29]:
dl = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0) # Shuffle the training data; increase num_workers for better speed


Here is an illustration of a training loop. The dataset behaves like any other PyTorch dataset, so there is nothing special in the loop.

In [30]:
n_batches = len(dl)
optimizer = Adam(model.parameters(), lr=1e-4)
for epoch in range(5):
    for i, batch in enumerate(dl):
        iid, img, (rotMat, xyShiftAngs, confidence), metadata = batch
        
        #iid is the list of ids of the particles (string)
        #img is a batch of Bx1xNxN images
        #rotMat is a batch of rotation matrices Bx3x3
        #xyShiftAngs is a batch of image shifts in Angstroms Bx2
        #confidence is a batch of numbers, between 0 and 1, Bx1
        #metadata is a dictionary of names: values with all the information about the particle
        img = img.to(device)
        rotMat = rotMat.to(device)
        predRot = model(img)
        loss = loss_function(predRot, rotMat)
        print(f'Epoch: {epoch + 1} Batch: {i + 1}/{n_batches}, Loss: {loss.item()}', end='\r')
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    print()

Epoch: 1 Batch: 157/157, Loss: 0.34019476175308233
Epoch: 2 Batch: 157/157, Loss: 0.10218479484319687
Epoch: 3 Batch: 157/157, Loss: 0.458162993192672734
Epoch: 4 Batch: 157/157, Loss: 0.379527360200881965
Epoch: 5 Batch: 157/157, Loss: 0.104808010160923355
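Because of `end='\r'`, each printed line shows only the loss of the last batch of the epoch, which is a noisy summary. A steadier one is the average loss over all particles in the epoch. The sketch below is a hypothetical plain-Python helper; inside the training loop you could call `tracker.update(loss.item(), img.shape[0])` after each batch and print `tracker.mean` at the end of the epoch.

```python
class RunningMean:
    """Hypothetical helper: particle-weighted average of per-batch losses."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, batch_loss, batch_size):
        # batch_loss is a mean over the batch, so weight it by the batch size
        self.total += batch_loss * batch_size
        self.count += batch_size

    @property
    def mean(self):
        return self.total / self.count if self.count else float("nan")

tracker = RunningMean()
for batch_loss, batch_size in [(0.5, 32), (0.25, 32), (1.0, 16)]:
    tracker.update(batch_loss, batch_size)
print(tracker.mean)  # 0.5
```

Weighting by batch size matters when the last batch is smaller than the others.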


Now we will perform inference on the other half-set of the data.

In [31]:
halfset_to_infer = (halfset_to_use + 1) %2 # This just maps 0 -> 1 and 1 -> 0, so that we use the complementary dataset

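#Predicted poses for the complementary dataset
outFnamePoses = f"/tmp/predicted_poses_{halfset_to_infer}.star" # They will be saved using the RELION star format
#Use the same preprocessing as in training, so that the model sees consistent images
toInferDataset = ParticlesDataset(targetName, halfset_to_infer,
                                  apply_perImg_normalization = True,
                                  ctf_correction = "phase_flip",
                                  image_size_factor_for_crop = 0.25)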
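Writing to a hard-coded /tmp path only works on systems that have that folder, and the file name does not include the entry, so predictions for different entries would overwrite each other. A slightly more portable variant is sketched below: it builds the file name from the entry name and half-set and places it in the system temporary directory given by `tempfile.gettempdir()`. The `poses_output_path` name and the directory layout are only suggestions.

```python
import os
import tempfile

def poses_output_path(target_name, halfset, out_dir=None):
    """Hypothetical helper: one star file per entry and half-set."""
    out_dir = out_dir or os.path.join(tempfile.gettempdir(), "cesped_predictions")
    os.makedirs(out_dir, exist_ok=True)
    return os.path.join(out_dir, f"predicted_poses_{target_name}_{halfset}.star")

print(poses_output_path("TEST", 0))
```

In the notebook, this would be used as `outFnamePoses = poses_output_path(targetName, halfset_to_infer)`.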


The inference loop is quite standard. The only difference is how we update the `ParticlesDataset` object with the predicted poses. To do so, we use the method `updateMd`, providing the particle ids (`iid`) and the predicted angles.

In [32]:
i_dl = DataLoader(toInferDataset, batch_size=32, num_workers=0)
n_batches = len(i_dl)
print(n_batches)
model = model.eval()
with torch.no_grad(): # Gradients are not needed for inference
    for i, batch in enumerate(i_dl):
        iid, img, (rotMat, xyShiftAngs, confidence), metadata = batch
        img = img.to(device)

        #The model predicts the 9 values of each rotation matrix in a single row, so we reshape them to Bx3x3
        predRot = model(img).reshape(-1,3,3)

        #We are not predicting shifts or confidence scores, so we create dummy values
        shifts = torch.zeros(predRot.shape[0], 2, device=predRot.device) #Or actual predictions if you have them
        confidence = torch.ones(predRot.shape[0])

        toInferDataset.updateMd(ids=iid, angles=predRot,
                                shifts=shifts,
                                confidence=confidence,
                                angles_format="rotmat")
        print(f'Batch: {i + 1}/{n_batches}', end='\r')
print()

157
Batch: 157/157


Finally, we will save the predicted poses

In [33]:
toInferDataset.saveMd(outFnamePoses)
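To inspect the saved file without extra dependencies, a minimal reader for the `loop_` tables of a STAR file is sketched below. It is a simplification, not a full STAR parser: it assumes one loop per `data_` block, whitespace-separated values without quotes, and no key-value (non-loop) blocks. The third-party `starfile` package is a more robust option. The example text is illustrative and not necessarily what `saveMd` writes.

```python
def read_star_loops(text):
    """Minimal sketch: map each data_ block name to the rows of its loop_ table.
    Values are kept as strings; convert them with float() where needed."""
    tables = {}
    block, columns, rows = None, [], []
    for raw in text.splitlines() + ["data_"]:  # sentinel flushes the last block
        line = raw.strip()
        if not line or line.startswith("#") or line == "loop_":
            continue
        if line.startswith("data_"):
            if block is not None and columns:
                tables[block] = [dict(zip(columns, r)) for r in rows]
            block, columns, rows = line[len("data_"):], [], []
        elif line.startswith("_"):
            columns.append(line.split()[0])  # drop the "#N" column index
        elif columns:
            rows.append(line.split())
    return tables

example = """
data_particles

loop_
_rlnImageName #1
_rlnAngleRot #2
_rlnAngleTilt #3
_rlnAnglePsi #4
000001@particles.mrcs 10.0 45.0 -30.0
000002@particles.mrcs 120.5 90.0 15.0
"""
particles = read_star_loops(example)["particles"]
print(len(particles), float(particles[1]["_rlnAngleRot"]))  # 2 120.5
```

With the file written above, this would be `read_star_loops(open(outFnamePoses).read())`.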


That is all. Rerun the notebook with the other half-set, and go to the next notebook to see how to evaluate the results.