# Image registration
## Introduction
Image registration is the process of aligning two images.  It involves moving (or deforming) one image, the floating image, so that it best matches a reference (or fixed) image.  In this notebook, we will see how to perform image registration using the Insight Toolkit (ITK).  We will use its simplified interface, SimpleITK.

# Preparatory steps
Let's start by setting up the notebook:

In [None]:
## install SimpleITK [if needed]
!pip install simpleitk

## download the data and extract it on the cloud machine
import zipfile
!wget -nc https://github.com/rrr-uom-projects/ManchesterBioinformaticsCourse_colab_student/tree/main/Day4/data.zip -O ./data.zip
%matplotlib inline

zfile = zipfile.ZipFile('./data.zip', 'r')   # open the downloaded archive
zfile.extractall()                           # unzip into the current directory
for d in zfile.namelist():   # d = path of a directory or file in the archive
    print('directory and files: ', d)


# Let's python!

Let's import all libraries we need:

In [None]:
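import SimpleITK as sitk        # for the registration
import os                       # for file management
import os.path                  # for file management
import sys                      # to detect whether we are running in Colab
import matplotlib.pyplot as plt # for plotting the metric
from ipywidgets import interact, fixed  # for plotting the registered image
from numpy import sign, zeros, max      # note: this max replaces Python's built-in max
from IPython.display import clear_output

# The plotting functions below check runningInColab, which is not defined anywhere else.
# Assumption: we detect Colab by looking for the google.colab module.
runningInColab = 'google.colab' in sys.modules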

## Write some auxiliary functions

We will plot the value of the metric while the registration runs. For that we need to set up some functions:

In [None]:
def start_plot():
    # Reset the metric history and, outside Colab, open a figure that is updated live
    global metric_values, optimiser_iterations, fig, ax
    optimiser_iterations = []
    metric_values = []
    if not runningInColab:
        fig = plt.figure()
        ax = fig.add_subplot(111)
        plt.ion()
        fig.show()
        fig.canvas.draw()
    
def plot_values(registration_method):
    # Record the current metric value and redraw the convergence plot
    global metric_values, optimiser_iterations, fig, ax
    optimiser_iterations.append(registration_method.GetOptimizerIteration())
    metric_values.append(registration_method.GetMetricValue())
    if runningInColab:
        # In Colab we clear the cell output and draw a new figure at every iteration
        clear_output(wait=True)
        fig = plt.figure()
        ax = fig.add_subplot(111)

    # Clear and plot the similarity metric values, in the order they were recorded
    ax.clear()
    ax.plot(metric_values, 'b.')
    ax.set_xlabel('Iteration Number', fontsize=12)
    ax.set_ylabel('Metric Value', fontsize=12)
    ax.set_ylim([-1, 0])
    if runningInColab:
        plt.show()
    else:
        fig.canvas.draw()
    
def command_multires_iterations():
    print("    > ---- Resolution change ----")


Let's also define a function to visualize the images after registration.

In [None]:
# Show the reference and floating images overlaid in three orthogonal slices.
# The reference goes in the red and blue channels (magenta) and the floating image
# in the green channel, so well-aligned structures of equal intensity appear grey.
# Both images must occupy the same physical space (same grid).
def display_images( referenceImage, floatingAfterResample):
    ref = sitk.GetArrayFromImage(referenceImage)        # array axes are (z, y, x)
    flo = sitk.GetArrayFromImage(floatingAfterResample)

    # all channels are divided by the reference maximum to keep a common intensity scale
    rgbimg = zeros((*ref.shape,3))
    rgbimg[...,0]=ref/max(ref)
    rgbimg[...,1]=flo/max(ref)
    rgbimg[...,2]=ref/max(ref)
    
    figi = plt.figure(num=None, figsize=(12, 4), dpi=80)
    axi = figi.add_subplot(131)
    axi.imshow(rgbimg[:,64,:,:]) # change 64 to show a different coronal slice
    axi.axis('off'); axi.invert_yaxis()
    axi = figi.add_subplot(132)
    axi.imshow(rgbimg[50,:,:,:]) # change 50 to show a different axial slice
    axi.axis('off')
    axi = figi.add_subplot(133)
    axi.imshow(rgbimg[:,:,64,:]) # change 64 to show a different sagittal slice
    axi.axis('off'); axi.invert_yaxis()
    if not runningInColab:
        plt.ion()
    figi.show()

Last, but not least, let's define the function that applies a given transformation (affine or non-rigid) to an image. 

In [None]:
def resample_image_with_Tx(referenceImage, Tx, iimg):
    # Resample iimg onto the grid of referenceImage using the transform Tx
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(referenceImage)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(0)   # value for voxels that fall outside iimg
    resampler.SetTransform(Tx)
    oimg = resampler.Execute(iimg)
    return oimg
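
To see what the resampler does, here is a minimal 1-D sketch in plain Python (it is not the SimpleITK implementation).  It assumes that the transform maps each point of the output (reference) grid to the position in the floating signal where the value is read, with linear interpolation and a default value outside the signal.  The function resample_1d and the toy signal are made up for illustration.

```python
import math

def resample_1d(floating, transform, n_out, default=0.0):
    """Resample a 1-D signal onto an output grid of n_out points.

    transform maps an output (reference) coordinate to the coordinate
    in the floating signal where the value should be read.
    """
    out = []
    for i in range(n_out):
        x = transform(float(i))            # where to look in the floating signal
        if x < 0 or x > len(floating) - 1:
            out.append(default)            # outside the floating signal
            continue
        lo = int(math.floor(x))
        hi = min(lo + 1, len(floating) - 1)
        w = x - lo                         # linear interpolation weight
        out.append((1 - w) * floating[lo] + w * floating[hi])
    return out

floating_signal = [0.0, 10.0, 20.0, 30.0, 40.0]
shifted = resample_1d(floating_signal, lambda x: x + 1.5, n_out=5)
print(shifted)  # [15.0, 25.0, 35.0, 0.0, 0.0]
```

The last two output points fall outside the floating signal, so they receive the default value, just like the voxels set by SetDefaultPixelValue(0) above.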

# Let's set the ball rolling.  

Let's read the reference image.  In this case, it is the CT of the reference patient.  We chose 0002 and made a mask for it.

In [None]:
imgfolder = "./data/DownsampledImages/"  
dataDir = imgfolder
nrrOutput = imgfolder + "../warpedImgs/"

referenceImage = sitk.ReadImage(imgfolder + "../0002.nii", sitk.sitkFloat32)
referenceMask = sitk.ReadImage(imgfolder + "../0002_mask_ds.nii")

# visualise the reference image! 
display_images(referenceImage,referenceImage)

Now we need to read the floating image. If it was acquired in the opposite direction to the reference image (head first versus feet first), we flip it along the z axis.  

In [None]:
img = "0004.nii"

In [None]:
floating = sitk.ReadImage(os.path.join(dataDir, img), sitk.sitkFloat32)

# Flip if needed: head-first supine (HFS) vs feet-first supine (FFS)
flipped = False
if sign(floating.GetDirection()[-1]) != sign(referenceImage.GetDirection()[-1]):
    print("(i) Floating image was flipped")
    floating = sitk.Flip(floating, [False, False, True])
    flipped = True
display_images(floating,floating)
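
The flip test above can be tried without any image.  The sketch below assumes that GetDirection() returns the 3x3 direction matrix flattened into a tuple of 9 values, so that the last value describes the z axis, and that the images are axis-aligned.  The helper needs_flip and the two example tuples are hypothetical.

```python
def needs_flip(direction_a, direction_b):
    """Return True if the last direction cosine of the two images differs in sign."""
    def sign_of(v):
        return (v > 0) - (v < 0)
    return sign_of(direction_a[-1]) != sign_of(direction_b[-1])

# Hypothetical direction matrices, flattened row by row
identity = (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
z_reversed = (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, -1.0)

print(needs_flip(identity, identity))    # False
print(needs_flip(identity, z_reversed))  # True
```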

## Affine registration
The first step in registering any pair of images is to align them globally. This is achieved with a rigid or affine registration.

The function below sets up all the elements of an affine registration. The elements include:
-  The **metric**:  This element is used to determine how similar two images are.  In this example, we will use the (negative) normalized cross-correlation.
-  The **interpolator**: we use a linear interpolator, a good balance between speed and accuracy. 
-  The **reference mask**: selects the voxels in the image that matter for the registration.  The mask we will use was created semi-automatically and focuses on the head and neck area, ignoring the shoulders and thorax.
-  The **optimiser**: we chose regular step gradient descent.  Change the values of the arguments to see how fast or slowly the optimiser converges (if it does!)
-  The **transform**:  we chose the 'Similarity 3D Transform'.  This transform allows translation, rotation and isotropic scaling (strictly speaking, it is a restricted affine transform). The choice of the transform defines whether the registration is affine, rigid or non-rigid!  Notice that we are optimising 7 parameters: 3 translations, 3 rotations and 1 scale factor.

This function returns the transform whose parameters gave the lowest metric value!

Check here for extra info: https://simpleitk.readthedocs.io/en/master/registrationOverview.html
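
To get a feeling for the metric, here is a minimal sketch in plain Python.  It assumes that the correlation metric is the negative squared normalized cross-correlation of the mean-subtracted intensities, which is our reading of the ITK documentation and is consistent with the [-1, 0] range used in the metric plot.  The function name and the toy "images" (short lists of intensities) are made up for illustration.

```python
def negative_squared_ncc(fixed, moving):
    """Negative squared normalized cross-correlation of two intensity lists."""
    n = len(fixed)
    mean_f = sum(fixed) / n
    mean_m = sum(moving) / n
    cov = sum((f - mean_f) * (m - mean_m) for f, m in zip(fixed, moving))
    var_f = sum((f - mean_f) ** 2 for f in fixed)
    var_m = sum((m - mean_m) ** 2 for m in moving)
    return -(cov ** 2) / (var_f * var_m)

fixed = [1.0, 2.0, 3.0, 4.0]
rescaled = [10.0, 20.0, 30.0, 40.0]   # same pattern, different intensity scale
unrelated = [1.0, -1.0, -1.0, 1.0]    # pattern with no linear relation to fixed

print(negative_squared_ncc(fixed, rescaled))   # best possible value
print(negative_squared_ncc(fixed, unrelated))  # worst possible value
```

The first value is -1 and the second is 0 (up to rounding): the metric ignores a global change in brightness or contrast, and the optimiser tries to push it towards -1.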


In [None]:
def run_affine_registration(referenceImage, referenceMask, floatingImage, printInfo=True):
    R = sitk.ImageRegistrationMethod()
    R.SetMetricAsCorrelation()
    R.SetInterpolator(sitk.sitkLinear)
    R.SetMetricFixedMask(referenceMask)
    
    R.SetOptimizerAsRegularStepGradientDescent(learningRate=2.0,
                                               minStep=1e-2,
                                               numberOfIterations=200,
                                               gradientMagnitudeTolerance=1e-2,
                                               maximumStepSizeInPhysicalUnits = 10)
    R.SetOptimizerScalesFromIndexShift()
    # initialise with the floatingImage argument (not the global variable floating)
    tx = sitk.CenteredTransformInitializer(referenceImage, floatingImage, sitk.Similarity3DTransform() )
    print("Initial Number of Parameters: {0}".format(tx.GetNumberOfParameters()))
    R.SetInitialTransform(tx)
    R.AddCommand(sitk.sitkIterationEvent, lambda: plot_values(R))
    outTx = R.Execute(referenceImage, floatingImage)
    if printInfo:
        print("    >> Optimizer stop condition: {0}".format(R.GetOptimizerStopConditionDescription()))
        print("    >> Iteration: {0}".format(R.GetOptimizerIteration()))
        print("    >> Metric value: {0}".format(R.GetMetricValue()))
    return outTx

Let's run the function with the floating and reference images, and then resample the floating image onto the same grid as the reference image (floatingAfterAffine).

In [None]:
# STEP 2: register the images allowing only rotations, translations and scaling
print("---- Affine part ----")
start_plot()
affineTx = run_affine_registration(referenceImage, referenceMask, floating )

In [None]:
print("---- Resampling image ----")
floatingAfterAffine = resample_image_with_Tx(referenceImage, affineTx, floating)
# visualize the images
display_images( referenceImage, floatingAfterAffine)

## Non-rigid Registration
Non-rigid registration, also known as deformable registration, helps to fine-tune the image alignment by allowing the image to be deformed. 

As with the affine registration, all elements are defined in this function.  Additionally, we allow for a multi-resolution approach to speed up the registration and improve its results.

The elements include:
-  The **metric** [same as in affine]: This element is used to determine how similar two images are.  We will use the (negative) normalized cross-correlation again.
-  The **interpolator** [same as in affine]: we use a linear interpolator, a good balance between speed and accuracy. 
-  The **reference mask** [same as in affine]: selects the voxels in the image that matter for the registration.  The mask we will use was created semi-automatically and focuses on the head and neck area, ignoring the shoulders and thorax.
-  The **optimiser**: with so many parameters, we need an optimiser that converges in a reasonable time.  The code uses gradient descent with a line search; the limited-memory Broyden-Fletcher-Goldfarb-Shanno optimiser with bounds (LBFGSB) is included as a commented-out alternative to try.
-  The **transform**:  we chose the BSpline transform. It uses a sparse set of control points to control a free-form deformation. This choice of transform makes the registration non-rigid!  Notice that we are optimising many more parameters than in the affine registration... How many?
- The **multi-resolution scheme**: Multi-resolution constructs an 'image pyramid', where each level is smaller than the next.  This helps the registration avoid local minima and, at least in theory, lets it run faster because the optimiser converges more quickly. It is defined by SetShrinkFactorsPerLevel and SetSmoothingSigmasPerLevel.  The length of the argument lists defines how many levels are used. Note that this only happens if the argument useMultiResolution is set to True.

This function returns the transform whose parameters gave the lowest metric value!  

Check here for extra info: https://simpleitk.readthedocs.io/en/master/registrationOverview.html
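
If you want to check your answer to "How many?", the sketch below counts the B-spline parameters by hand.  It assumes cubic B-splines (order 3), for which the number of control points along each axis is the mesh size plus the spline order, and one displacement component per image dimension at every control point.  The helper bspline_parameter_count is hypothetical; in the notebook, the value is printed by tx.GetNumberOfParameters().

```python
def bspline_parameter_count(mesh_size, spline_order=3):
    """Number of B-spline parameters for a given transform domain mesh size."""
    count = len(mesh_size)              # one displacement component per dimension
    for m in mesh_size:
        count *= m + spline_order       # control points along this axis
    return count

# Mesh size used in this notebook: 8 elements along each of the 3 axes
n_params = bspline_parameter_count([8, 8, 8])
print(n_params)  # 3 * 11**3 = 3993
```

Compare this with the 7 parameters of the similarity transform used in the affine step.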

In [None]:
def run_nonrigid_registration(referenceImage, referenceMask, floatingImage, useMultiResolution=True, printInfo=True):
    R = sitk.ImageRegistrationMethod()
    R.SetMetricAsCorrelation()
    R.SetInterpolator(sitk.sitkLinear)
    R.SetMetricFixedMask(referenceMask)
    

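    # Gradient descent with a line search at every iteration
    R.SetOptimizerAsGradientDescentLineSearch(learningRate=5.0,
                                              numberOfIterations=100,
                                              convergenceMinimumValue=1e-4,
                                              convergenceWindowSize=5)
    # Alternative optimiser to try: LBFGSB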
    #R.SetOptimizerAsLBFGSB(gradientConvergenceTolerance=1e-3,
    #                   numberOfIterations=100,
    #                   maximumNumberOfCorrections=5,
    #                   maximumNumberOfFunctionEvaluations=1000,
    #                   costFunctionConvergenceFactor=1e+7)

    
    transformDomainMeshSize=[8]*floatingImage.GetDimension()
    tx = sitk.BSplineTransformInitializer(referenceImage, transformDomainMeshSize )
    print("Initial Number of Parameters: {0}".format(tx.GetNumberOfParameters()))
    R.SetInitialTransform(tx, True)
    R.AddCommand(sitk.sitkIterationEvent, lambda: plot_values(R))

    if useMultiResolution:
        R.SetShrinkFactorsPerLevel([8,4,2,1])
        R.SetSmoothingSigmasPerLevel([4,2,1,0])
        R.AddCommand(sitk.sitkMultiResolutionIterationEvent, lambda: command_multires_iterations() )
    
    outTx = R.Execute(referenceImage, floatingImage)

    if printInfo:
        print("    >> Optimizer stop condition: {0}".format(R.GetOptimizerStopConditionDescription()))
        print("    >> Iteration: {0}".format(R.GetOptimizerIteration()))
        print("    >> Metric value: {0}".format(R.GetMetricValue()))

    return outTx
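
To make the image pyramid concrete, the sketch below lists the grid size at each level for the shrink factors used above.  The image size (128, 128, 96) is a made-up example, and we assume that each axis is divided by the shrink factor and rounded down; the exact rounding inside ITK may differ.

```python
def pyramid_sizes(image_size, shrink_factors):
    """Approximate grid size at every pyramid level (assumes rounding down)."""
    levels = []
    for factor in shrink_factors:
        levels.append(tuple(max(1, n // factor) for n in image_size))
    return levels

shrink_factors = [8, 4, 2, 1]
sizes = pyramid_sizes((128, 128, 96), shrink_factors)  # hypothetical image size
for factor, size in zip(shrink_factors, sizes):
    print("shrink factor", factor, "-> grid size", size)
```

The registration starts at the coarsest level, here (16, 16, 12), and finishes on the full-resolution grid.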

Let's choose whether we want to use multi-resolution and run the function with the reference image and the floating image (after affine registration, floatingAfterAffine).  Then, let's resample it with the new transform (nrrTx) and display it.

In [None]:
# STEP 3: register the images non-rigidly
useMultiRes = False
print("---- Non-Rigid part ----")
start_plot()
nrrTx = run_nonrigid_registration(referenceImage, referenceMask, floatingAfterAffine, useMultiRes )
# print(nrrTx)

In [None]:
# Let's resample the image with the B-Spline found
print("---- Resampling image ----")
out = resample_image_with_Tx(referenceImage, nrrTx, floatingAfterAffine)
display_images( referenceImage, out)

# Try:
1. Try using multi-resolution in the non-rigid registration part.  Why does the metric 'jump' up? Are the results better?

2. The shoulders are not looking great.  Why?  Try the non-rigid part without the mask. Why is the initial metric not the same as the final metric of the affine registration?

3. Could you add multi-resolution to the affine registration?