In [1]:
import numpy as np
import matplotlib.pyplot as plt
import SimpleITK as sitk
from shubow_tools import imreadseq_multithread, imsaveseq
from ipywidgets import interact, fixed, interact_manual
%matplotlib inline

In [2]:
class Registration:
    def __init__(self, ref_dir, tar_dir, ref_z_range=None, tar_z_range = None, ref_flip=False, tar_flip=False):
        
        if not ref_flip:
            self.ref_img = sitk.Cast(self.imread(ref_dir, z_range = ref_z_range), sitk.sitkFloat32)
        else:
            self.ref_img = sitk.Cast(sitk.GetImageFromArray(
                np.flip(self.imread(ref_dir, z_range = ref_z_range, sitkimg = False), axis = 2)
            ), sitk.sitkFloat32)
            
        if not tar_flip:    
            self.tar_img = sitk.Cast(self.imread(tar_dir, z_range = tar_z_range), sitk.sitkFloat32)
        else:
            self.tar_img = sitk.Cast(sitk.GetImageFromArray(
                np.flip(self.imread(tar_dir, z_range = tar_z_range, sitkimg = False), axis = 2)
            ), sitk.sitkFloat32)
        
        self.ref_shape = self.ref_img.GetSize()
        self.tar_shape = self.tar_img.GetSize()
        self.reg_img = None
        self.transformation = self.__center_initialization__()
        self.reg_transform = None
        
    def imread(self, *args, **kwds):
        return imreadseq_multithread(*args,**kwds)
    
    def __center_initialization__(self):
        initial_transform = sitk.Euler3DTransform(sitk.CenteredTransformInitializer(self.ref_img, 
                                                      self.tar_img, 
                                                      sitk.Euler3DTransform(), 
                                                      sitk.CenteredTransformInitializerFilter.GEOMETRY))
        
        
        return initial_transform
    
    def manual_initialization(self):
                
        def display( 
                    img_x, img_y, img_z, 
                    translation_x, translation_y, translation_z, 
                    rotation_x, rotation_y, rotation_z
                   ):
            
            translation=(-translation_x, -translation_y, -translation_z)
            center=[0.5*i for i in self.ref_shape]
            rotation=(rotation_x, rotation_y, rotation_z)
            
            self.transformation.SetCenter(center)
            self.transformation.SetRotation(*rotation)
            self.transformation.SetTranslation(translation)
            
            self.reg_img = sitk.Resample(self.tar_img, self.ref_img, self.transformation, sitk.sitkLinear, 0.0, sitk.sitkFloat32)  
            #yz_img = sitk.GetArrayFromImage(self.reg_img[img_x, :, :]) * 0.5 + sitk.GetArrayFromImage(self.ref_img[img_x, :, :])*0.5 
            #xz_img = sitk.GetArrayFromImage(self.reg_img[:, img_y, :]) * 0.5 + sitk.GetArrayFromImage(self.ref_img[:, img_y, :])*0.5 
            #xy_img = sitk.GetArrayFromImage(self.reg_img[:, :, img_z]) * 0.5 + sitk.GetArrayFromImage(self.ref_img[:, :, img_z])*0.5 
            
            plt.subplots(2,2,figsize=(10,8))
            plt.subplot(2,2,1)
            plt.imshow(sitk.GetArrayFromImage(self.ref_img)[::-1, :, img_x])
            plt.imshow(sitk.GetArrayFromImage(self.reg_img)[::-1, :, img_x], cmap=plt.cm.Greys_r, alpha=0.5)
            plt.title("Y-Z plane")
            plt.axis('off')
            
            plt.subplot(2,2,2)
            plt.imshow(sitk.GetArrayFromImage(self.ref_img)[img_z, :, :])
            plt.imshow(sitk.GetArrayFromImage(self.reg_img)[img_z, :, :], cmap=plt.cm.Greys_r, alpha=0.5)
            plt.title("X-Y plane")
            plt.axis('off') 
            
            plt.subplot(2,2,3)
            plt.imshow(sitk.GetArrayFromImage(self.ref_img)[::-1, img_y, :])
            plt.imshow(sitk.GetArrayFromImage(self.reg_img)[::-1, img_y, :], cmap=plt.cm.Greys_r, alpha=0.5)
            plt.title("X-Z plane")
            plt.axis('off')   
            
            plt.subplot(2,2,4)
            plt.axis('off') 
            
            
        interact_manual(display, img_x=(0, self.ref_shape[0]-1), img_y=(0, self.ref_shape[1]-1), img_z=(0, self.ref_shape[2]-1), 
                 translation_x=(-self.ref_shape[0]-self.transformation.GetTranslation()[0], self.ref_shape[0]-self.transformation.GetTranslation()[0], 2.0), 
                 translation_y=(-self.ref_shape[1]-self.transformation.GetTranslation()[1], self.ref_shape[1]-self.transformation.GetTranslation()[1], 2.0), 
                 translation_z=(-self.ref_shape[2]-self.transformation.GetTranslation()[2], self.ref_shape[2]-self.transformation.GetTranslation()[2], 2.0),
                 rotation_x=(-np.pi+self.transformation.GetAngleX(), np.pi+self.transformation.GetAngleX(), 0.05), 
                 rotation_y=(-np.pi+self.transformation.GetAngleY(), np.pi+self.transformation.GetAngleY(), 0.05), 
                 rotation_z=(-np.pi+self.transformation.GetAngleZ(), np.pi+self.transformation.GetAngleZ(), 0.05)
                )
        
    def registration(self):
        registration_method = sitk.ImageRegistrationMethod()
        registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
        registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
        registration_method.SetMetricSamplingPercentage(0.20)
        registration_method.SetInterpolator(sitk.sitkLinear)
        registration_method.SetOptimizerAsGradientDescentLineSearch(learningRate=1.4,
                                                                numberOfIterations=200,
                                                                convergenceMinimumValue=1e-5,
                                                                convergenceWindowSize=5)
        registration_method.SetOptimizerScalesFromPhysicalShift()
        registration_method.SetShrinkFactorsPerLevel(shrinkFactors = [4,2,1])
        registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2,2,1])
        registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
        registration_method.SetInitialTransform(self.transformation, inPlace=False)
        self.reg_transform = registration_method.Execute(self.ref_img, self.tar_img)
        self.reg_img = sitk.Resample(self.tar_img, self.ref_img, self.reg_transform, sitk.sitkLinear, 0.0, sitk.sitkFloat32)
        
    
    def display_result(self):
        
        def display(img_x, img_y, img_z):
            
            plt.subplots(2,2,figsize=(10,8))
            plt.subplot(2,2,1)
            plt.imshow(sitk.GetArrayFromImage(self.ref_img)[::-1, :, img_x])
            plt.imshow(sitk.GetArrayFromImage(self.reg_img)[::-1, :, img_x], cmap=plt.cm.Greys_r, alpha=0.5)
            plt.title("Y-Z plane")
            plt.axis('off')
            
            plt.subplot(2,2,2)
            plt.imshow(sitk.GetArrayFromImage(self.ref_img)[img_z, :, :])
            plt.imshow(sitk.GetArrayFromImage(self.reg_img)[img_z, :, :], cmap=plt.cm.Greys_r, alpha=0.5)
            plt.title("X-Y plane")
            plt.axis('off') 
            
            plt.subplot(2,2,3)
            plt.imshow(sitk.GetArrayFromImage(self.ref_img)[::-1, img_y, :])
            plt.imshow(sitk.GetArrayFromImage(self.reg_img)[::-1, img_y, :], cmap=plt.cm.Greys_r, alpha=0.5)
            plt.title("X-Z plane")
            plt.axis('off')   
            
            plt.subplot(2,2,4)
            plt.axis('off') 
        
        
        interact(display, img_x=(0, self.ref_shape[0]-1), img_y=(0, self.ref_shape[1]-1), img_z=(0, self.ref_shape[2]-1))
        
    def save(self, *args, **kwds):
        imsaveseq(sitk.Cast(self.reg_img, sitk.sitkUInt8), *args, **kwds)

In [37]:
import os
import re
import shutil
refdir = r"/media/spl/D/MicroCT_data/Machine learning/SITK_reg_7um"
tardir = r"/media/spl/D/MicroCT_data/Machine learning/Treadmill running 35n tibia"
img_title = r"372 week 5 left tibia"
fd_path = os.path.join(refdir, img_title+" registered")

ref = os.path.join(refdir, re.sub("week \d", "week 1", img_title+" registered"))
tar = os.path.join(tardir, img_title)

if os.path.exists(fd_path):
            shutil.rmtree(fd_path)
        
os.mkdir(fd_path)

##############################
reg = Registration(ref, tar, tar_z_range=[-550, -50], tar_flip = False)
#reg = Registration(ref, tar)
reg.manual_initialization()

interactive(children=(IntSlider(value=221, description='img_x', max=443), IntSlider(value=269, description='im…

In [38]:
reg.registration()

In [39]:
reg.display_result()

interactive(children=(IntSlider(value=221, description='img_x', max=443), IntSlider(value=269, description='im…

In [40]:
reg.save(fd_path, img_title+"_Reg")

In [34]:
print(reg.transformation.GetCenter())
print(reg.transformation.GetTranslation())
print(reg.transformation.GetAngleX())
print(reg.transformation.GetAngleY())
print(reg.transformation.GetAngleZ())

(390.0, 303.0, 381.0)
(-0.0, -0.0, -0.0)
-0.04159265358979303
-0.04159265358979303
-0.04159265358979303


In [38]:
print(reg.reg_transform)

itk::simple::Transform
 CompositeTransform (00000204B250B420)
   RTTI typeinfo:   class itk::CompositeTransform<double,3>
   Reference Count: 1
   Modified Time: 451657532
   Debug: Off
   Object Name: 
   Observers: 
     none
   Transforms in queue, from begin to end:
   >>>>>>>>>
   Euler3DTransform (00000204E256EDE0)
     RTTI typeinfo:   class itk::Euler3DTransform<double>
     Reference Count: 1
     Modified Time: 451657524
     Debug: Off
     Object Name: 
     Observers: 
       none
     Matrix: 
       0.999466 -0.0310407 0.0102505 
       0.0311343 0.999474 -0.00909449 
       -0.00996285 0.00940877 0.999906 
     Offset: [2.91215, -13.6611, -20.3215]
     Center: [390, 303, 381]
     Translation: [-2.79617, -5.14315, -21.392]
     Inverse: 
       0.999466 0.0311343 -0.00996285 
       -0.0310407 0.999474 0.00940877 
       0.0102505 -0.00909449 0.999906 
     Singular: 0
     Euler's angles: AngleX=0.00940891 AngleY=0.00996346 AngleZ=0.0310471
     m_ComputeZYX = 0
   En