# Train a CNN for anatomy-guided PET denoising and deblurring

Import all modules needed in this tutorial.

In [None]:
import tensorflow as tf
import nibabel as nib
import numpy as np
import pathlib

import matplotlib.pyplot as plt
import pymirc.viewer as pv

from datetime import datetime

Define a few helper functions to set up the model and to load and preprocess the data.

In [None]:
def simple_model(nfeat = 30, kernel_shape = (3,3,3), nhidden_layers = 6,
                 batch_norm = True, add_final_relu = True):
  """Build a residual 3D CNN that maps a 2-channel volume to a 1-channel volume.

  The network takes a batch of 3D volumes with 2 channels (e.g. PET and MR) and
  predicts a batch of 3D volumes with 1 channel (e.g. denoised PET). The output
  of the convolutional stack is added to the first input channel, so the
  convolutions learn a correction of the input PET.

  Parameters
  ----------
  nfeat : int
    number of features (filters) of every hidden Conv3D layer
  kernel_shape : tuple of int
    kernel shape of the hidden Conv3D layers
  nhidden_layers : int
    number of hidden Conv3D -> [BatchNorm] -> PReLU stages
  batch_norm : bool
    insert batch normalization between each hidden Conv3D and its activation
  add_final_relu : bool
    append a ReLU after the residual addition to clip negative values

  Returns
  -------
  tf.keras.Model
  """
  layers = tf.keras.layers

  # input: batches of 3D volumes of arbitrary spatial size with two channels
  inputs = layers.Input(shape = (None, None, None, 2), name = 'input_layer')

  # split the channels so that the first one (PET) can be added to the output
  channels = layers.Lambda(lambda t: tf.split(t, num_or_size_splits = 2, axis = -1),
                           name = 'split')(inputs)

  # hidden stages; PReLU shares its slope over the 3 spatial axes
  features = inputs
  for layer_no in range(1, nhidden_layers + 1):
    features = layers.Conv3D(nfeat, kernel_shape, padding = 'same',
                             kernel_initializer = 'glorot_uniform',
                             name = f'conv3d_{layer_no}')(features)
    if batch_norm:
      features = layers.BatchNormalization(name = f'batchnorm_{layer_no}')(features)
    features = layers.PReLU(shared_axes = [1,2,3], name = f'prelu_{layer_no}')(features)

  # (1,1,1) convolution collapsing the feature dimension to a single channel
  residual = layers.Conv3D(1, (1,1,1), padding = 'same', name = 'conv_final',
                           kernel_initializer = 'glorot_uniform')(features)

  # residual connection with the first input channel (PET)
  output = layers.Add(name = 'add')([residual, channels[0]])

  # optionally clip negative values
  if add_final_relu:
    output = layers.ReLU(name = 'final_relu')(output)

  return tf.keras.Model(inputs = inputs, outputs = output)


In [None]:
def load_nii_in_lps(fname):
  """Load a NIfTI file and return its voxel array in LPS orientation plus a matching affine.

  The image is first reoriented to the closest canonical (RAS+) orientation with
  nibabel. Then the first two array axes are flipped, so the array axes point to
  Left, Posterior and Superior.

  Parameters
  ----------
  fname : str or path-like
    NIfTI file to load

  Returns
  -------
  vol : numpy.ndarray
    image data from ``get_fdata`` (float64 by default) in LPS voxel order
  affine : numpy.ndarray
    4x4 affine mapping voxel indices of ``vol`` (i.e. AFTER the flip) to
    nibabel's RAS+ world coordinates
  """
  nii = nib.as_closest_canonical(nib.load(fname))
  vol = np.flip(nii.get_fdata(), (0,1))

  # Flipping maps the new index i to the old index (n - 1 - i) along axes 0 and 1.
  # Fold that index transform into the canonical affine. The returned affine then
  # describes the returned (flipped) array instead of the unflipped canonical one.
  flip       = np.eye(4)
  flip[0,0]  = -1
  flip[1,1]  = -1
  flip[0,3]  = vol.shape[0] - 1
  flip[1,3]  = vol.shape[1] - 1

  return vol, nii.affine @ flip

In [None]:
def robust_max(volume, n = 7):
  """Return the maximum of a heavily smoothed copy of ``volume``.

  Smoothing uses TensorFlow's strided average pooling, which is fast. The
  cubic window is 2*n + 1 voxels wide, moves with stride n, and uses 'SAME'
  padding. The max of the pooled volume is less sensitive to single outlier
  voxels than a plain max.

  Parameters
  ----------
  volume : array-like
    3D volume
  n : int
    pooling stride; the pooling window is 2*n + 1 voxels wide

  Returns
  -------
  numpy float32 scalar
  """
  # tf.nn.avg_pool needs a rank-5 float tensor of shape (1, n0, n1, n2, 1)
  batch  = tf.convert_to_tensor(np.asanyarray(volume)[np.newaxis, ..., np.newaxis].astype(np.float32))
  window = 2*n + 1
  pooled = tf.nn.avg_pool(batch, window, n, 'SAME')

  return pooled.numpy().max()

In [None]:
def load_data_set(subject_path,    # subject path
                  sim = 0,         # acquisition number (0,1,2)
                  counts = 1e7):   # count level of PET (1e7 or 5e8)
  """Load the OSEM PET, MR and target PET volumes of one simulated brainweb subject.

  All volumes are loaded in LPS orientation. Their intensities are normalized by
  a robust max. The target PET is divided by the OSEM scale factor, so input
  and target PET remain on the same intensity scale.

  Parameters
  ----------
  subject_path : str or path-like
    directory of the subject
  sim : int
    acquisition number; the files are read from the ``sim_{sim}`` subdirectory
  counts : float
    count level of the PET, used in the OSEM file name (e.g. 1e7 -> '1.0E+07')

  Returns
  -------
  osem, mr, target : numpy.ndarray
    normalized OSEM PET, MR and target PET volumes
  osem_scale, mr_scale : float
    the robust max values used to normalize the PET and MR volumes
  """
  # accept str as well as pathlib.Path input
  subject_path = pathlib.Path(subject_path)
  sim_path     = subject_path / f'sim_{sim}'

  # setup the file names
  mr_file     = subject_path / 't1.nii.gz'
  osem_file   = sim_path / f'osem_psf_counts_{counts:0.1E}.nii.gz'
  target_file = sim_path / 'true_pet.nii.gz'

  # load nifti files in LPS orientation (the affines are not needed here)
  mr, _     = load_nii_in_lps(mr_file)
  osem, _   = load_nii_in_lps(osem_file)
  target, _ = load_nii_in_lps(target_file)

  # normalize the intensities of the MR and PET volumes
  mr_scale   = robust_max(mr)
  osem_scale = robust_max(osem)

  mr     /= mr_scale
  osem   /= osem_scale
  # the target uses the OSEM scale so that input and target PET stay comparable
  target /= osem_scale

  return osem, mr, target, osem_scale, mr_scale

In [None]:
def train_augmentation(x, y, s0 = 64, s1 = 64, s2 = 64, contrast_aug = True):
  """Data augmentation: identical random crop of input and target, plus optional contrast jitter.

  Parameters
  ----------
  x : tensor of shape (n0, n1, n2, c)
    input volume; channel 1 (the second channel, e.g. MR) is contrast augmented
  y : tensor of shape (n0, n1, n2, 1)
    target volume
  s0, s1, s2 : int
    spatial shape of the random crop
  contrast_aug : bool
    apply random contrast and brightness changes to the second input channel

  Returns
  -------
  x_crop : tensor of shape (s0, s1, s2, c)
  y_crop : tensor of shape (s0, s1, s2, 1)
  """
  # number of input channels (static, like the channel size used for the crop)
  num_x_channels = x.shape[-1]

  # concatenate input and target along the channels to get the same random crop
  z      = tf.concat([x, y], axis = -1)
  z_crop = tf.image.random_crop(z, [s0, s1, s2, z.shape[-1]])

  x_crop = z_crop[..., :num_x_channels]
  # slice (instead of index) so that the target keeps its channel axis and
  # matches the 1-channel output of the model
  y_crop = z_crop[..., num_x_channels:]

  # random contrast augmentation of the second input channel
  if contrast_aug:
    x_chans = tf.unstack(x_crop, axis = -1)
    # NOTE(review): tf.image treats the (s0, s1, s2) channel as an (H, W, C)
    # image. random_contrast therefore uses a separate mean for every slice
    # along the last axis (with one random factor for all slices).
    x_chans[1] = tf.image.random_contrast(x_chans[1], 0.1, 1)
    x_chans[1] = tf.image.random_brightness(x_chans[1], 0.5)
    x_crop     = tf.stack(x_chans, axis = -1)

  return x_crop, y_crop

Load all data into host memory and set up the training and validation data loaders.

In [None]:
# adjust this variable to the path where the simulated PET/MR data from zenodo was unzipped
data_dir    = pathlib.Path('brainweb_petmr')
batch_size  = 10
n_train_sub = 16
patch_shape = (45,45,45)

# number of simulated acquisitions per subject (sim_0, sim_1, sim_2) and the
# crop applied to every loaded volume, giving volumes of shape (176,196,178)
n_sim = 3
crop  = (slice(40,-40), slice(30,-30), slice(40,-40))

# get all the subjects paths
subject_paths = sorted(data_dir.glob('subject??'))

# fail early with a clear message instead of an obscure tf.data error later on
if len(subject_paths) <= n_train_sub:
  raise ValueError(f'found {len(subject_paths)} subject(s) in {data_dir.resolve()}, but need more '
                   f'than n_train_sub = {n_train_sub} to keep some for validation')

x = np.zeros((n_sim*len(subject_paths),176,196,178,2), dtype = np.float32)
y = np.zeros((n_sim*len(subject_paths),176,196,178,1), dtype = np.float32)

# load all the data sets and sort them into the x and y numpy arrays
for i,subject_path in enumerate(subject_paths):
  for sim in range(n_sim):
    print(f'loading {subject_path} simulation {sim}')
    data = load_data_set(subject_path, sim = sim, counts = 1e7)
    x[n_sim*i + sim,...,0] = data[0][crop]   # OSEM PET
    x[n_sim*i + sim,...,1] = data[1][crop]   # MR
    y[n_sim*i + sim,...,0] = data[2][crop]   # target PET

# split the data in training and validation data (by subject)
n_train = n_sim*n_train_sub

x_train = x[:n_train]
y_train = y[:n_train]
x_val   = x[n_train:]
y_val   = y[n_train:]

# x_train, x_val, ... are views of x and y, so this only removes the names
del x
del y

# training: random patches with contrast augmentation of the MR channel
train_loader  = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_loader.shuffle(x_train.shape[0]).map(
                  lambda inp, tgt: train_augmentation(inp, tgt, s0 = patch_shape[0], s1 = patch_shape[1],
                                                      s2 = patch_shape[2])).batch(batch_size).prefetch(4)

# validation: random patches WITHOUT contrast augmentation. val_loss drives
# checkpointing and learning-rate reduction, so it must see undistorted MR inputs.
val_loader  = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_loader.shuffle(x_val.shape[0]).map(
                lambda inp, tgt: train_augmentation(inp, tgt, s0 = patch_shape[0], s1 = patch_shape[1],
                                                    s2 = patch_shape[2], contrast_aug = False)).batch(x_val.shape[0]).prefetch(2)

# draw one fixed set of validation patches (all validation samples in one batch)
xv, yv = next(iter(val_dataset))

Set up the model and start training. For decent convergence, **1000–2000 epochs** should be used, which takes a few hours on a modern GPU. To just check whether everything is working, we only use **10 epochs, which is far too few.**

In [None]:
nepochs       = 10                                 # number of training epochs
learning_rate = 1e-3                               # initial learning rate
loss          = tf.keras.losses.MeanSquaredError() # loss function to use

# build and compile the CNN
model = simple_model()
model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate),
              loss      = loss)

# time-stamped directory that collects all output of this training run
logdir = pathlib.Path(f'model_{datetime.now().strftime("%Y%m%d_%H%M%S")}')
logdir.mkdir(exist_ok = True)

# callbacks
# (1) keep the full model with the lowest validation loss seen so far
checkpoint = tf.keras.callbacks.ModelCheckpoint(logdir / 'trained_model',
                                                monitor = 'val_loss', mode ='min',
                                                save_best_only = True, save_weights_only = False,
                                                verbose = 1)

# (2) csv file with training and validation loss of every epoch
csvlog = tf.keras.callbacks.CSVLogger(logdir / 'log.csv')

# (3) halve the learning rate when the validation loss has not improved for
#     100 epochs, but never go below 1e-4
lr_reduce = tf.keras.callbacks.ReduceLROnPlateau(monitor = 'val_loss', factor = 0.5, patience = 100,
                                                 min_lr = 1e-4, verbose = 1)

# (4) TensorBoard logs; histogram_freq = 5 writes weight histograms every 5 epochs
tb = tf.keras.callbacks.TensorBoard(log_dir = logdir / 'tensor_board', histogram_freq = 5,
                                    write_graph = False)

# train the model
history = model.fit(train_dataset,
                    epochs          = nepochs,
                    validation_data = (xv, yv),
                    callbacks       = [checkpoint, csvlog, lr_reduce, tb])

Plot the training and validation loss.

In [None]:
epoch_numbers = np.arange(1, nepochs + 1)

# training and validation loss per epoch on a logarithmic y axis
fig, ax = plt.subplots()
for key, label in (('loss', 'loss'), ('val_loss', 'validation loss')):
  ax.semilogy(epoch_numbers, history.history[key], label = label)
ax.set_xlabel('epoch')
ax.legend()
ax.grid(ls = ':')

Use the trained model to make predictions for all validation data sets.

In [None]:
# the checkpoint callback only kept the model with the lowest validation loss,
# which is not necessarily the model after the last epoch -> reload that one
trained_model = tf.keras.models.load_model(logdir.joinpath('trained_model'))

# predict the full validation volumes (not patches), one volume per batch
p = trained_model.predict(x_val, batch_size = 1)

Show the predictions. You can click in the plots and use your arrow keys to move through the slices / batch.

In [None]:
# enable interactive plots with the ipympl package
%matplotlib widget

ims = 4*[{'vmin':0,'vmax':1.5}] + [{'vmin':-0.4,'vmax':0.4, 'cmap':plt.cm.bwr}]
vi = pv.ThreeAxisViewer([x_val[...,0].squeeze(),x_val[...,1].squeeze(),p.squeeze(),y_val.squeeze(), 
                         p.squeeze() - y_val.squeeze()], 
                         imshow_kwargs = ims, 
                         rowlabels = ['input PET', 'input MR', 'prediction','target','absolute bias'])