In [None]:
import SimpleITK as sitk
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import util
import registration_callbacks
import os
from IPython.display import display, HTML 

In [None]:
# Relative path to this acquisition's input data directory (note the trailing separator).
dirname_data = "./datasets/2019-03-05/"

### Load the images

In [None]:
# Read both views as float32 volumes. os.path.join (rather than string
# concatenation) keeps this working whether or not dirname_data ends with a
# path separator; normpath then normalizes the result for the current OS.
view0 = sitk.ReadImage(os.path.normpath(os.path.join(dirname_data, 'head_view0_cropped.mhd')), sitk.sitkFloat32)
view1 = sitk.ReadImage(os.path.normpath(os.path.join(dirname_data, 'head_view1_cropped.mhd')), sitk.sitkFloat32)

# The names view0 and view1 have a meaning in some context and they play the role of fixed and moving images
# in the registration, so we use both naming conventions and they alias the relevant images.
fixed_image = view0
moving_image = view1

In [None]:
# Sanity-check the inputs before registering: voxel spacing, per-view sizes,
# and per-view intensity range (via util.get_minmax).
print(f"Voxel size, um: {view0.GetSpacing()}")

for view_label, view_img in (("View0", view0), ("View1", view1)):
    print(f"{view_label} size, px: {view_img.GetSize()}")

for view_img in (view0, view1):
    print(util.get_minmax(view_img))

# Overlay the two views as an RGB composite (color mapping as stated in the title below).
img_merge = util.merge_images_rgb(view0, view1)

fig = plt.figure(figsize=(10, 10))
# NOTE(review): this title says "After rigid registration" although no registration
# has run yet in this notebook. Presumably the input volumes were rigidly
# pre-aligned upstream; confirm, otherwise retitle (e.g. "Before registration").
util.show_mips(img_merge, "After rigid registration: view0 (fixed, green), view1 (moving, magenta), overlap (white).")

## Registration flow

We register in three steps:

0. Basic initialization, centering the two volumes.
1. Affine registration, a global non-rigid transformation.
2. FFD (free-form deformation) registration, a local non-rigid transformation.

### Initialization

In [None]:
# Step 0: seed a 3D affine transform that centers the moving volume on the fixed
# one using the images' geometry (GEOMETRY mode, not intensity moments).
affine_template = sitk.AffineTransform(3)
global_transform = sitk.CenteredTransformInitializer(
    fixed_image,
    moving_image,
    affine_template,
    sitk.CenteredTransformInitializerFilter.GEOMETRY,
)

### Affine registration

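We use a single resolution level; a multi-resolution pyramid is not needed here (Occam's razor: prefer the simplest solution that works). The registration is also done in place, so the `global_transform` variable is updated directly rather than a new transform being returned.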

In [None]:
# Step 1: single-resolution affine registration.
# Metric: correlation, evaluated on a random 1% sample of voxels.
# Optimizer: gradient descent; global_transform is optimized in place.

# RANDOM sampling is seeded from the wall clock by default, so every run would
# draw a different voxel sample. A fixed seed makes the sample reproducible.
affine_sampling_seed = 42

registration_method = sitk.ImageRegistrationMethod()
registration_method.SetMetricAsCorrelation()
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.01, affine_sampling_seed)
registration_method.SetInterpolator(sitk.sitkLinear)

registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=300)
# Scale the step size differently for each parameter, this is critical!!!
registration_method.SetOptimizerScalesFromPhysicalShift()

registration_method.SetInitialTransform(global_transform, inPlace=True)

# Live metric plot; the callbacks live in the local registration_callbacks module.
registration_method.AddCommand(sitk.sitkStartEvent, registration_callbacks.metric_start_plot)
registration_method.AddCommand(sitk.sitkEndEvent, registration_callbacks.metric_end_plot)
registration_method.AddCommand(sitk.sitkMultiResolutionIterationEvent,
                               registration_callbacks.metric_update_multires_iterations)
registration_method.AddCommand(sitk.sitkIterationEvent,
                               lambda: registration_callbacks.metric_plot_values(registration_method))

registration_method.Execute(fixed=fixed_image, moving=moving_image)
print('Optimizer\'s stopping condition, {0}'.format(registration_method.GetOptimizerStopConditionDescription()))
print('Final metric value: {0}'.format(registration_method.GetMetricValue()))

In [None]:
# Map view1 onto view0's grid with the optimized affine transform
# (sitkLinear interpolation, 0.0 default value, view0's pixel type), then overlay the pair.
view1_registered = sitk.Resample(
    view1,
    view0,
    global_transform,
    sitk.sitkLinear,
    0.0,
    view0.GetPixelID(),
)
img_merge = util.merge_images_rgb(view0, view1_registered)

affine_title = "Affine registration: view0 (fixed, green), view1 (moving, magenta), overlap (white)."
fig = plt.figure(figsize=(10, 10))
util.show_mips(img_merge, affine_title, scalebar=10)

### FFD

We use the result of the affine registration as the moving initial transformation, which is never modified during optimization. The FFD transformation is initialized with all parameters set to zero and is optimized in place. Afterwards, we compose the optimized FFD and the affine transformation into a composite transform, which is the final result of our three-step registration.

In [None]:
# Step 2: FFD (B-spline) registration on top of the affine result.
registration_method = sitk.ImageRegistrationMethod()

# Determine the number of BSpline control points using the physical spacing we want for the control grid:
# mesh size per axis = physical extent / grid spacing, rounded to the nearest integer.
grid_physical_spacing = [10.0, 10.0, 10.0]  # A control point every 10 um
image_physical_size = [size * spacing for size, spacing in zip(fixed_image.GetSize(), fixed_image.GetSpacing())]
# Clamp to at least 1: an axis shorter than half a grid spacing would otherwise round to a 0-size mesh.
mesh_size = [max(1, int(image_size / grid_spacing + 0.5))
             for image_size, grid_spacing in zip(image_physical_size, grid_physical_spacing)]

local_transform = sitk.BSplineTransformInitializer(image1=fixed_image, transformDomainMeshSize=mesh_size, order=3)
# Only the B-spline parameters are optimized (in place); the affine result is applied
# to the moving image as an initial transform that the optimizer does not modify.
registration_method.SetInitialTransform(local_transform, inPlace=True)
registration_method.SetMovingInitialTransform(global_transform)

# RANDOM sampling is wall-clock seeded by default; fix the seed so the voxel sample is reproducible.
ffd_sampling_seed = 42
registration_method.SetMetricAsCorrelation()
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.01, ffd_sampling_seed)
registration_method.SetInterpolator(sitk.sitkLinear)

registration_method.SetOptimizerAsLBFGSB(gradientConvergenceTolerance=1e-5, numberOfIterations=50)

# Live metric plot; the callbacks live in the local registration_callbacks module.
registration_method.AddCommand(sitk.sitkStartEvent, registration_callbacks.metric_start_plot)
registration_method.AddCommand(sitk.sitkEndEvent, registration_callbacks.metric_end_plot)
registration_method.AddCommand(sitk.sitkMultiResolutionIterationEvent,
                               registration_callbacks.metric_update_multires_iterations)
registration_method.AddCommand(sitk.sitkIterationEvent,
                               lambda: registration_callbacks.metric_plot_values(registration_method))
registration_method.Execute(fixed_image, moving_image)

# The optimizer only updated local_transform, so compose it with the affine
# transform to get the full mapping. Per the SimpleITK docs, CompositeTransform
# applies the last-added transform first (local, then global), matching the
# SetMovingInitialTransform(global_transform) setup above — confirm if reordering.
final_transform = sitk.CompositeTransform([global_transform, local_transform])
print('Optimizer\'s stopping condition, {0}'.format(registration_method.GetOptimizerStopConditionDescription()))
print('Final metric value: {0}'.format(registration_method.GetMetricValue()))

In [None]:
# Warp view1 into view0's grid with the full affine + FFD composite transform.
# (Resampling with global_transform alone would only reproduce the affine result
# and hide the FFD refinement this figure is meant to show.)
view1_registered = sitk.Resample(view1, view0, final_transform, sitk.sitkLinear, 0.0, view0.GetPixelID())
img_merge = util.merge_images_rgb(view0, view1_registered)

fig = plt.figure(figsize=(10, 10))
util.show_mips(img_merge, "FFD registration: view0 (fixed, green), view1 (moving, magenta), overlap (white).", scalebar=10)