## Import modules

In [None]:
from matplotlib import pyplot as plt
import numpy as np
import SimpleITK as sitk

# Part I: Basic operations with SimpleITK

## Load an image using SimpleITK and display the image information

In [None]:
# Read the example image and print its physical-space metadata:
# origin, spacing and direction, one per line.
fn_img0 = 'img0.png'
img0 = sitk.ReadImage(fn_img0)
print(img0.GetOrigin(), img0.GetSpacing(), img0.GetDirection(), sep='\n')

## Display the image

In [None]:
# %% Get numpy array from image

img0_array = sitk.GetArrayFromImage(img0)

plt.imshow(img0_array, cmap='gray')
plt.show()

# Part II: Image Registration Example

## 1. Read Image for Registration

In [None]:
# Load the fixed (reference) image and the moving image, both as 32-bit
# float; img0.png is read first, then img1.png.
fixed_image, moving_image = (
    sitk.ReadImage(path, sitk.sitkFloat32) for path in ('img0.png', 'img1.png')
)

## 2. Initialization

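Use `CenteredTransformInitializer` to align the geometric centers of the two images and set the center of rotation to the center of the fixed image. For comparison, the cell also builds a manually initialized transform (`my_initialization`) with a fixed translation.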

In [None]:
# Automatic initialization: a 2D Euler (rotation + translation) transform
# positioned from the geometric centers of the fixed and moving images.
initial_transform = sitk.CenteredTransformInitializer(
    fixed_image, moving_image, sitk.Euler2DTransform(),
    sitk.CenteredTransformInitializerFilter.GEOMETRY,
)
print(initial_transform)

# Manual initialization: a 2D Euler transform whose only explicitly set
# parameter is the translation (7.2, 8.4).
my_initialization = sitk.Euler2DTransform()
my_initialization.SetTranslation((7.2, 8.4))
print(my_initialization)

## 3. Resample the moving image

In [None]:
# Resample the moving image onto the fixed image's grid using the initial
# transform, keeping the moving image's pixel type.
moving_resampled = sitk.Resample(
    moving_image,               # image to resample
    fixed_image,                # reference image defining the output grid
    initial_transform,          # transform applied during resampling
    sitk.sitkLinear,            # interpolator
    0.0,                        # default pixel value
    moving_image.GetPixelID(),  # output pixel type
)

## Exercise I

1. Replace the initial transform with `my_initialization` in the above code and show the resampled image.
2. Replace `sitk.sitkLinear` with other interpolators (e.g. `sitk.sitkNearestNeighbor` or `sitk.sitkBSpline`) and show the results. Do you see any differences?

## 4. Start Image Registration

In [None]:
registration_method = sitk.ImageRegistrationMethod()

# Similarity metric settings. Mean squares is active; the commented-out
# correlation metric is the alternative to try in Exercise II.
#registration_method.SetMetricAsCorrelation()
registration_method.SetMetricAsMeanSquares()

# Interpolator used to sample the moving image while evaluating the metric.
registration_method.SetInterpolator(sitk.sitkLinear)

# Optimizer settings.
registration_method.SetOptimizerAsGradientDescent(
    learningRate=0.1,
    numberOfIterations=100,
    convergenceMinimumValue=1e-6,
    convergenceWindowSize=10)

# numberOfIterations is the upper bound on iterations. The optimizer can stop
# earlier when the convergence value, computed over the last
# convergenceWindowSize metric values, drops below convergenceMinimumValue.

# Estimate the scales of the transform parameters (used as step sizes) from
# the maximum pixel shift in physical space caused by a parameter change.
registration_method.SetOptimizerScalesFromPhysicalShift()

# Initialize registration. inPlace=False asks SimpleITK to optimize a copy,
# so initial_transform itself is not modified by Execute.
registration_method.SetInitialTransform(initial_transform, inPlace=False)

## 5. Monitor the registration process

In [None]:
def clear_values():
    """Reset the module-level metric history to an empty list.

    Used as the ``sitkStartEvent`` callback so that each registration run
    records its metric values from scratch.
    """
    global metric_values
    metric_values = list()


# Register clear_values for the start event so metric_values is
# (re)initialized to an empty list at the beginning of each registration run.
registration_method.AddCommand(sitk.sitkStartEvent,
                               clear_values)
    
def save_values(registration_method):
    """Record and print the current metric value (IterationEvent callback).

    Appends the value returned by ``registration_method.GetMetricValue()`` to
    the module-level ``metric_values`` list, then prints it together with the
    number of values recorded so far. ``metric_values`` must already exist;
    in this notebook it is created by ``clear_values`` on the start event.

    Args:
        registration_method: object exposing ``GetMetricValue()``; here the
            ``sitk.ImageRegistrationMethod`` being observed.
    """
    value = registration_method.GetMetricValue()
    # Appending mutates the module-level list, so no ``global`` is needed.
    metric_values.append(value)
    # After appending, len() is the 1-based index of this iteration.
    print(f'Iteration {len(metric_values)}: metric value {value:.4f}')
    
    
# Record (and print) the metric value at every optimizer iteration.
# AddCommand invokes its callback with no arguments, so a lambda is used to
# pass the registration method to save_values.
registration_method.AddCommand(sitk.sitkIterationEvent,
                               lambda: save_values(registration_method))


## 6. Get the final registration result

In [None]:
# Run the registration and keep the optimized transform.
final_transform = registration_method.Execute(fixed_image, moving_image)

# Report the final metric value and why the optimizer stopped.
print("Final metric value:", registration_method.GetMetricValue())
print("Optimizer's stopping condition,",
      registration_method.GetOptimizerStopConditionDescription())

# Resample the moving image onto the fixed image's grid with the final transform.
moving_resampled = sitk.Resample(
    moving_image,               # image to resample
    fixed_image,                # reference image defining the output grid
    final_transform,            # optimized transform
    sitk.sitkLinear,            # interpolator
    0.0,                        # default pixel value
    moving_image.GetPixelID(),  # output pixel type
)

# Optional: uncomment to save the resampled image and the transform to disk.
#moving_resampled = sitk.Cast(moving_resampled, sitk.sitkUInt8)
#sitk.WriteImage(moving_resampled, 'moving_resampled.png')
#sitk.WriteTransform(final_transform, 'final_transform.txt')

## Exercise II

3. Use `matplotlib.pyplot.plot` to plot the metric values over iterations.
4. Show the initial difference image and final difference image after registration.
5. Use correlation as the image registration metric (`registration_method.SetMetricAsCorrelation()`) and repeat steps 3 and 4.