Setting up a training loop with validation.

This demo is a Jupyter notebook, i.e. it is intended to be run step by step.

Author: Imraj Singh

First version: 13th of May 2022

CCP SyneRBI Synergistic Image Reconstruction Framework (SIRF).
Copyright 2022 University College London.

This is software developed for the Collaborative Computational Project in Synergistic Reconstruction for Biomedical Imaging (http://www.ccpsynerbi.ac.uk/).

SPDX-License-Identifier: Apache-2.0

# Setting up the training

This follows the standard PyTorch training pattern.

The components:
* Set up the acquisition model
* Set up the dataset
* Set up the model
* Set up the training loop with validation

First we import the prerequisite packages, set up the forward operator, dataset and model.

In [1]:
# Import the PET reconstruction engine
import sirf.STIR as pet
# Set the verbosity
pet.set_verbosity(0)
# Store temporary sinograms in RAM
pet.AcquisitionData.set_storage_scheme("memory")
# SIRF STIR message redirector
import sirf
msg = sirf.STIR.MessageRedirector(info=None, warn=None, errr=None)
# Load dataset and model
from odl_funcs.ellipses import EllipsesDataset
from lpd_net import LearnedPrimalDual
# Import standard extra packages
import matplotlib.pyplot as plt
import os
import numpy as np
import time
import torch
from tqdm.notebook import trange, tqdm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

size_xy = 128
mini_batch = 1
n_samples = 10

from sirf.Utilities import examples_data_path
sinogram_template = pet.AcquisitionData(
    examples_data_path('PET') + '/thorax_single_slice/template_sinogram.hs')
# Create the acquisition model
acq_model = pet.AcquisitionModelUsingParallelproj()
image_template = sinogram_template.create_uniform_image(1.0, size_xy)
acq_model.set_up(sinogram_template, image_template)
# Training and validation data loaders
train_dataloader = torch.utils.data.DataLoader(
    EllipsesDataset(acq_model.forward, image_template, mode="train", n_samples=n_samples),
    batch_size=mini_batch, shuffle=True)
valid_dataloader = torch.utils.data.DataLoader(
    EllipsesDataset(acq_model.forward, image_template, mode="valid"),
    batch_size=1, shuffle=True)
# Learned primal-dual network
model = LearnedPrimalDual(image_template, sinogram_template, acq_model,
                          n_iter=2, n_primal=5, n_dual=5, n_layers=5,
                          n_feature_channels=128).to(device)
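
The data loaders in this notebook yield (ground truth image, measured data) pairs, with the measured data produced by the forward operator. The sketch below shows that pattern on a toy problem. `ToyPairsDataset` and `toy_forward` are hypothetical stand-ins invented for illustration; they are not the actual `EllipsesDataset` or the SIRF acquisition model.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


def toy_forward(image):
    """Hypothetical forward operator: sum the image along each axis."""
    return np.concatenate([image.sum(axis=0), image.sum(axis=1)])


class ToyPairsDataset(Dataset):
    """Yields (ground truth, measurement) pairs from random images."""

    def __init__(self, forward, size=8, n_samples=10, seed=0):
        self.forward = forward
        self.size = size
        self.n_samples = n_samples
        self.seed = seed

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        # Seed per index so that each sample is reproducible
        rng = np.random.default_rng(self.seed + idx)
        x_gt = rng.random((self.size, self.size))
        y = self.forward(x_gt)
        return x_gt, y


loader = DataLoader(ToyPairsDataset(toy_forward), batch_size=2, shuffle=True)
x_gt, y = next(iter(loader))
# The default collate function stacks the NumPy arrays into float64 tensors,
# so cast to float32 before passing them to a network
x_gt, y = x_gt.float(), y.float()
print(x_gt.shape, y.shape)
```

The number of iterations per epoch is `len(loader)`, i.e. the number of samples divided by the batch size, rounded up.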

Download a pretrained model and its training log (trained.torch_model and results.npy)

In [2]:
device

device(type='cpu')

In [3]:
# Trained network hyperparameters:
# lr = 5e-4
# betas = (0.99, 0.999)
# batch = 10
# size_xy = 128
# n_samples = 50
if not (os.path.exists('trained.torch_model') and os.path.exists('results.npy')):
    !curl -L -o "trained.torch_model" "https://zenodo.org/record/7316151/files/trained.torch_model"
    !curl -L -o "results.npy" "https://zenodo.org/record/7316151/files/results.npy"




Let's set up the training loop with validation and a very simple "data logger"
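
The validate-then-train pattern, with a log of the losses and a snapshot of the best weights, can be sketched on a toy regression problem. Everything here (the `Linear` model, the random data, the learning rate) is an assumption chosen to keep the example small; it is not the network or the data used in this notebook.

```python
import copy
import torch

torch.manual_seed(0)

# Toy regression problem standing in for the reconstruction task
x_train, x_valid = torch.randn(32, 4), torch.randn(8, 4)
w_true = torch.tensor([[1.0], [-2.0], [0.5], [3.0]])
y_train, y_valid = x_train @ w_true, x_valid @ w_true

model = torch.nn.Linear(4, 1)
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-1)

data_log = {"valid_loss": [], "loss": []}
min_valid_loss = float("inf")
best_model = None

for epoch in range(50):
    # Validate first, without tracking gradients
    model.eval()
    with torch.no_grad():
        loss_valid = criterion(model(x_valid), y_valid).item()
    data_log["valid_loss"].append(loss_valid)
    if loss_valid < min_valid_loss:
        min_valid_loss = loss_valid
        # state_dict() returns references to the live tensors, so copy them
        best_model = copy.deepcopy(model.state_dict())

    # Then take one training step
    model.train()
    loss = criterion(model(x_train), y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    data_log["loss"].append(loss.item())

print(min_valid_loss, len(data_log["loss"]))
```

Two details matter here: the running minimum has to be updated whenever a better validation loss is found, and the snapshot has to be a deep copy, otherwise it keeps changing as training continues.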

In [4]:
import copy

lr = 5e-8
total_epochs = 1000

criterion = torch.nn.MSELoss(reduction='sum').to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.99, 0.999))

data_log = {}
data_log["valid_loss"] = []
data_log["valid_image"] = []
data_log["loss"] = []

min_valid_loss = 1e9

# Use a single fixed validation sample throughout training
x_gt_valid, y_valid = next(iter(valid_dataloader))
x_gt_valid, y_valid = x_gt_valid.float().to(device), y_valid.float().to(device)

if os.path.exists('trained.torch_model'):
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
    checkpoint = torch.load('trained.torch_model', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

pbar1 = trange(total_epochs, position=0, leave=True, desc='Epochs')
pbar2 = tqdm(train_dataloader, position=1, leave=True, desc='Iterations')
for i in pbar1:
    # Validate at the start of each epoch, without tracking gradients
    model.eval()
    with torch.no_grad():
        x_valid = model(y_valid)
        loss_valid = criterion(x_gt_valid, x_valid)
    pbar1.set_description("Epoch, validation loss {:10.2f}".format(loss_valid.item()))
    data_log["valid_loss"].append(loss_valid.item())
    data_log["valid_image"].append(x_valid[0,0,...].detach().cpu().numpy())
    if min_valid_loss > loss_valid.item():
        # Record the new minimum and snapshot the state; state_dict() returns
        # references to the live tensors, so a deep copy is needed
        min_valid_loss = loss_valid.item()
        best_model = copy.deepcopy(model.state_dict())
        best_optim = copy.deepcopy(optimizer.state_dict())
    pbar2.reset(int(np.ceil(n_samples/mini_batch)))
    for ii, (x_gt, y) in enumerate(train_dataloader):
        x_gt, y = x_gt.float().to(device), y.float().to(device)
        model.train()
        # Forward pass: reconstruct the image x from the measured data y
        x = model(y)
        # Compute the loss
        loss = criterion(x_gt, x)
        # Zero gradients, perform a backward pass, and update the weights.
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()
        data_log["loss"].append(loss.item())
        pbar2.update()
        pbar2.set_description("Batch sample, training loss {:10.2f}".format(loss.item()))

# Save the final model and optimizer state, and the data log
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, 'trained_extra.torch_model')
np.save('results',data_log)
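
A checkpoint stores the device each tensor lived on, so a file written on a GPU machine cannot be loaded as-is on a CPU-only one. Below is a minimal sketch of the `map_location` pattern. The toy `Linear` model and the in-memory buffer are assumptions standing in for the real network and the checkpoint file.

```python
import io
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Save a checkpoint with the same two keys used in this notebook
buffer = io.BytesIO()
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()}, buffer)
buffer.seek(0)

# map_location remaps every stored tensor onto the device that is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load(buffer, map_location=device)

# Restore the weights into a freshly constructed model on that device
restored = torch.nn.Linear(4, 1).to(device)
restored.load_state_dict(checkpoint['model_state_dict'])
```

Passing the same `device` object that the model was moved to keeps the loading code identical on GPU and CPU-only machines.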

Let's look at some of the results

In [None]:
if os.path.exists('results.npy'):
    r = np.load('results.npy',allow_pickle=True).item()
    min_loss = np.argmin(r['valid_loss'])
    print('Minimum validation loss is: {} at epoch {}'.format(r['valid_loss'][min_loss],min_loss))
    fig, (ax1, ax2, ax3) = plt.subplots(1,3,figsize=(15,5))
    ax1.plot(np.log(r['loss']))
    ax1.set_title('log(Training MSE)')
    ax2.plot(np.log(r['valid_loss']))
    ax2.set_title('log(Validation MSE)')
    ax3.imshow(np.flipud(r['valid_image'][min_loss].T))
    ax3.set_title('Best validation image with MSE: ' + str(r['valid_loss'][min_loss]))
    ax3.set_axis_off()
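
The data log is a plain dictionary, so `np.save` stores it as a pickled object and `np.load(...).item()` is needed to get the dictionary back. The sketch below shows the round trip with made-up loss values (the numbers are hypothetical, not results from this notebook).

```python
import os
import tempfile
import numpy as np

# Hypothetical log with the same keys as data_log in this notebook
data_log = {"valid_loss": [9.0, 4.0, 2.5, 3.0], "loss": [10.0, 5.0, 3.0, 2.0]}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "results.npy")
    # np.save pickles the dict into a zero-dimensional object array
    np.save(path, data_log)
    # allow_pickle=True is required to read it back; .item() unwraps the dict
    r = np.load(path, allow_pickle=True).item()

# Index of the epoch with the lowest validation loss
best_epoch = int(np.argmin(r["valid_loss"]))
print("Minimum validation loss is {} at epoch {}".format(
    r["valid_loss"][best_epoch], best_epoch))
```

Since `allow_pickle=True` can execute arbitrary code while unpickling, it should only be used on files from a trusted source.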

# Exercises

* Run inference on BrainWeb data (i.e. add a test set)
* Inspect the reconstruction at various points in the network
* Compare with an OSEM reconstruction
* Attempt to improve the model (change training parameters?)
* Use a "realistic" acquisition model for training data generation, but a "simple" acquisition model for reconstruction
* Use a more "realistic" acquisition model for reconstruction