In [1]:
import numpy as np
import matplotlib.pyplot as plt
import SimpleITK as sitk
from shubow_tools import imreadseq_multithread, imsaveseq
from ipywidgets import interact, fixed
%matplotlib inline

In [13]:
class Registration:
    """Rigid (Euler 3D) registration of a target image stack onto a reference stack.

    Both stacks are read with ``imreadseq_multithread`` and cast to float32.
    Typical workflow:

    1. ``manual_initialization()`` -- interactive sliders that rebuild
       ``self.transformation`` and show overlays of reference and target.
    2. ``registration()`` -- multi-resolution mutual-information refinement,
       starting by default from ``self.transformation``; stores the result in
       ``self.transformation`` and the resampled target in ``self.reg_img``.

    Parameters
    ----------
    ref_dir, tar_dir : str
        Directories of the reference and target image sequences.
    ref_z_range, tar_z_range : sequence or None
        Forwarded as ``z_range`` to ``imreadseq_multithread``; its exact
        meaning (e.g. negative indices) is defined by that helper.
    ref_flip, tar_flip : bool
        If True, the loaded volume is mirrored along numpy axis 2 (the last
        array axis, i.e. SimpleITK's x axis) before casting.
    """

    def __init__(self, ref_dir, tar_dir, ref_z_range=None, tar_z_range=None, ref_flip=False, tar_flip=False):
        self.ref_img = self._load_image(ref_dir, ref_z_range, ref_flip)
        self.tar_img = self._load_image(tar_dir, tar_z_range, tar_flip)

        # SimpleITK sizes are (x, y, z); arrays from GetArrayFromImage are (z, y, x).
        self.ref_shape = self.ref_img.GetSize()
        self.tar_shape = self.tar_img.GetSize()
        self.reg_img = None
        # Current transform; rebuilt by manual_initialization(), replaced by registration().
        self.transformation = sitk.Euler3DTransform()

    def _load_image(self, img_dir, z_range, flip):
        """Read one image sequence as a float32 SimpleITK image, optionally mirrored along x."""
        if not flip:
            img = self.imread(img_dir, z_range=z_range)
        else:
            # NOTE(review): the flipped path rebuilds the image from a bare array,
            # so any spacing/origin/direction set by the reader is replaced by
            # SimpleITK defaults -- confirm this matches the un-flipped path.
            arr = self.imread(img_dir, z_range=z_range, sitkimg=False)
            img = sitk.GetImageFromArray(np.flip(arr, axis=2))
        return sitk.Cast(img, sitk.sitkFloat32)

    def imread(self, *args, **kwds):
        """Thin wrapper: forward all arguments unchanged to ``imreadseq_multithread``."""
        return imreadseq_multithread(*args, **kwds)

    def center_initialization(self):
        """Return an Euler3DTransform that aligns the geometric centres of both images.

        The transform is only returned; ``self.transformation`` is not modified.
        It can be passed to ``registration(initial_transform=...)``.
        """
        initial_transform = sitk.CenteredTransformInitializer(self.ref_img,
                                                              self.tar_img,
                                                              sitk.Euler3DTransform(),
                                                              sitk.CenteredTransformInitializerFilter.GEOMETRY)
        return initial_transform

    def manual_initialization(self):
        """Show interactive sliders for hand-aligning the target to the reference.

        Every slider change does the following:

        * rebuilds ``self.transformation``, with its centre at the middle of the
          reference grid, Euler angles in radians, and the negated slider translation;
        * resamples the target into ``self.reg_img`` (float32);
        * redraws three orthogonal overlays of the reference and resampled target.

        The centre and translations are computed in index units. They presumably
        equal physical units only when the image spacing is 1 -- TODO confirm.
        """
        # The rotation centre depends only on the reference grid; compute it once.
        center = [0.5 * s for s in self.ref_shape]

        def _update(img_x, img_y, img_z,
                    translation_x, translation_y, translation_z,
                    rotation_x, rotation_y, rotation_z):
            self.transformation.SetCenter(center)
            self.transformation.SetRotation(rotation_x, rotation_y, rotation_z)
            # Translation is negated: Resample maps reference points into the
            # target, so this presumably makes a slider move the target image
            # in the slider's direction.
            self.transformation.SetTranslation((-translation_x, -translation_y, -translation_z))

            self.reg_img = sitk.Resample(self.tar_img, self.ref_img, self.transformation,
                                         sitk.sitkLinear, 0.0, sitk.sitkFloat32)

            # Convert each volume once per redraw; arrays are indexed (z, y, x).
            ref_arr = sitk.GetArrayFromImage(self.ref_img)
            reg_arr = sitk.GetArrayFromImage(self.reg_img)

            fig, axes = plt.subplots(2, 2, figsize=(10, 8))
            # In the side views z is reversed, so slice 0 is drawn at the bottom
            # (imshow places row 0 at the top by default).
            panels = [
                (axes[0, 0], ref_arr[::-1, :, img_x], reg_arr[::-1, :, img_x], "Y-Z plane"),
                (axes[0, 1], ref_arr[img_z, :, :], reg_arr[img_z, :, :], "X-Y plane"),
                (axes[1, 0], ref_arr[::-1, img_y, :], reg_arr[::-1, img_y, :], "X-Z plane"),
            ]
            for ax, ref_slice, reg_slice, title in panels:
                ax.imshow(ref_slice)
                ax.imshow(reg_slice, cmap="Greys_r", alpha=0.5)
                ax.set_title(title)
                ax.axis('off')
            axes[1, 1].axis('off')
            plt.show()

        interact(_update,
                 img_x=(0, self.ref_shape[0] - 1),
                 img_y=(0, self.ref_shape[1] - 1),
                 img_z=(0, self.ref_shape[2] - 1),
                 translation_x=(-self.ref_shape[0], self.ref_shape[0]),
                 translation_y=(-self.ref_shape[1], self.ref_shape[1]),
                 translation_z=(-self.ref_shape[2], self.ref_shape[2]),
                 rotation_x=(-np.pi, np.pi, 0.1),
                 rotation_y=(-np.pi, np.pi, 0.1),
                 rotation_z=(-np.pi, np.pi, 0.1))

    def registration(self, initial_transform=None, output_pixel_type=None, seed=None):
        """Refine the alignment with multi-resolution Mattes mutual information.

        On return, ``self.transformation`` holds the optimised transform and
        ``self.reg_img`` holds the target resampled onto the reference grid.

        Parameters
        ----------
        initial_transform : SimpleITK transform, optional
            Starting transform. Defaults to ``self.transformation``, which is
            the identity unless ``manual_initialization()`` was used.
            ``center_initialization()`` supplies a centre-aligned alternative.
            The transform is not modified, because ``inPlace=False`` is used.
        output_pixel_type : SimpleITK pixel ID, optional
            Pixel type of ``self.reg_img``. Defaults to ``sitk.sitkUInt8``.
            NOTE: the inputs are float32, so values outside 0-255 do not fit UInt8.
        seed : int, optional
            Seed for the random metric sampling. When None, SimpleITK's default
            seed is used, and results may differ between runs.

        NOTE(review): the transform returned by ``Execute`` may not be an
        ``Euler3DTransform`` instance, so ``GetAngleX()`` etc. might not be
        available on ``self.transformation`` afterwards -- verify.
        """
        if initial_transform is None:
            initial_transform = self.transformation
        if output_pixel_type is None:
            output_pixel_type = sitk.sitkUInt8

        registration_method = sitk.ImageRegistrationMethod()
        registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
        registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
        if seed is None:
            registration_method.SetMetricSamplingPercentage(0.40)
        else:
            registration_method.SetMetricSamplingPercentage(0.40, seed)
        registration_method.SetInterpolator(sitk.sitkLinear)
        registration_method.SetOptimizerAsGradientDescentLineSearch(learningRate=1.3,
                                                                    numberOfIterations=200,
                                                                    convergenceMinimumValue=1e-5,
                                                                    convergenceWindowSize=5)
        registration_method.SetOptimizerScalesFromPhysicalShift()
        # Coarse-to-fine pyramid: three levels of shrink factors and smoothing sigmas.
        registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
        registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 2, 1])
        registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
        registration_method.SetInitialTransform(initial_transform, inPlace=False)

        self.transformation = registration_method.Execute(self.ref_img, self.tar_img)
        self.reg_img = sitk.Resample(self.tar_img, self.ref_img, self.transformation,
                                     sitk.sitkLinear, 0.0, output_pixel_type)
        return

    def display(self, sitkimage):
        """Placeholder: currently does nothing and returns None."""
        return
        
        

In [14]:
# Root of the micro-CT data; change this one constant to run on another machine.
DATA_ROOT = r"/media/spl/D/MicroCT_data/Machine learning"
# Slice window of the target stack, forwarded as z_range to imreadseq_multithread.
TAR_Z_RANGE = [-550, -100]

# Reference: the already-registered week-2 scan. Target: the week-3 scan of the same tibia.
ref = f"{DATA_ROOT}/SITK_reg_7um/339 week 2 right tibia registered"
tar = f"{DATA_ROOT}/Treadmill running 35n tibia/339 week 3 right tibia"

reg = Registration(ref, tar, tar_z_range=TAR_Z_RANGE, tar_flip=True)
reg.manual_initialization()

interactive(children=(IntSlider(value=221, description='img_x', max=443), IntSlider(value=269, description='im…

In [12]:
# Report the current rigid transform: centre, translation, then the three Euler angles (one per line).
print(*(getter() for getter in (reg.transformation.GetCenter,
                                reg.transformation.GetTranslation,
                                reg.transformation.GetAngleX,
                                reg.transformation.GetAngleY,
                                reg.transformation.GetAngleZ)),
      sep="\n")

(222.0, 269.5, 150.0)
(176.0, 0.0, 126.0)
-0.24159
-0.04159265358979303
-0.64159
