In [1]:
import numpy as np
import matplotlib.pyplot as plt
import SimpleITK as sitk
from shubow_tools import imreadseq_multithread, imsaveseq
from ipywidgets import interact, fixed, interact_manual
%matplotlib inline

In [2]:
class Registration:
    def __init__(self, ref_dir, tar_dir, ref_z_range=None, tar_z_range = None, ref_flip=False, tar_flip=False):
        
        if not ref_flip:
            self.ref_img = sitk.Cast(self.imread(ref_dir, z_range = ref_z_range), sitk.sitkFloat32)
        else:
            self.ref_img = sitk.Cast(sitk.GetImageFromArray(
                np.flip(self.imread(ref_dir, z_range = ref_z_range, sitkimg = False), axis = 2)
            ), sitk.sitkFloat32)
            
        if not tar_flip:    
            self.tar_img = sitk.Cast(self.imread(tar_dir, z_range = tar_z_range), sitk.sitkFloat32)
        else:
            self.tar_img = sitk.Cast(sitk.GetImageFromArray(
                np.flip(self.imread(tar_dir, z_range = tar_z_range, sitkimg = False), axis = 2)
            ), sitk.sitkFloat32)
        
        self.ref_shape = self.ref_img.GetSize()
        self.tar_shape = self.tar_img.GetSize()
        self.reg_img = None
        self.transformation = self.__center_initialization__()
        self.reg_transform = None
        
    def imread(self, *args, **kwds):
        return imreadseq_multithread(*args,**kwds)
    
    def __center_initialization__(self):
        initial_transform = sitk.Euler3DTransform(sitk.CenteredTransformInitializer(self.ref_img, 
                                                      self.tar_img, 
                                                      sitk.Euler3DTransform(), 
                                                      sitk.CenteredTransformInitializerFilter.GEOMETRY))
        
        
        return initial_transform
    
    def manual_initialization(self):
                
        def display( 
                    img_x, img_y, img_z, 
                    translation_x, translation_y, translation_z, 
                    rotation_x, rotation_y, rotation_z
                   ):
            
            translation=(-translation_x, -translation_y, -translation_z)
            center=[0.5*i for i in self.ref_shape]
            rotation=(rotation_x, -rotation_y, -rotation_z)
            
            self.transformation.SetCenter(center)
            self.transformation.SetRotation(*rotation)
            self.transformation.SetTranslation(translation)
            
            self.reg_img = sitk.Resample(self.tar_img, self.ref_img, self.transformation, sitk.sitkLinear, 0.0, sitk.sitkFloat32)  
            #yz_img = sitk.GetArrayFromImage(self.reg_img[img_x, :, :]) * 0.5 + sitk.GetArrayFromImage(self.ref_img[img_x, :, :])*0.5 
            #xz_img = sitk.GetArrayFromImage(self.reg_img[:, img_y, :]) * 0.5 + sitk.GetArrayFromImage(self.ref_img[:, img_y, :])*0.5 
            #xy_img = sitk.GetArrayFromImage(self.reg_img[:, :, img_z]) * 0.5 + sitk.GetArrayFromImage(self.ref_img[:, :, img_z])*0.5 
            
            plt.subplots(2,2,figsize=(10,8))
            plt.subplot(2,2,1)
            plt.imshow(sitk.GetArrayFromImage(self.ref_img)[::-1, :, img_x])
            plt.imshow(sitk.GetArrayFromImage(self.reg_img)[::-1, :, img_x], cmap=plt.cm.Greys_r, alpha=0.5)
            plt.title("Y-Z plane")
            plt.axis('off')
            
            plt.subplot(2,2,2)
            plt.imshow(sitk.GetArrayFromImage(self.ref_img)[img_z, :, :])
            plt.imshow(sitk.GetArrayFromImage(self.reg_img)[img_z, :, :], cmap=plt.cm.Greys_r, alpha=0.5)
            plt.title("X-Y plane")
            plt.axis('off') 
            
            plt.subplot(2,2,3)
            plt.imshow(sitk.GetArrayFromImage(self.ref_img)[::-1, img_y, :])
            plt.imshow(sitk.GetArrayFromImage(self.reg_img)[::-1, img_y, :], cmap=plt.cm.Greys_r, alpha=0.5)
            plt.title("X-Z plane")
            plt.axis('off')   
            
            plt.subplot(2,2,4)
            plt.axis('off') 
            
            
        interact_manual(display, img_x=(0, self.ref_shape[0]-1), img_y=(0, self.ref_shape[1]-1), img_z=(0, self.ref_shape[2]-1), 
                 translation_x=(-self.ref_shape[0]-self.transformation.GetTranslation()[0], self.ref_shape[0]-self.transformation.GetTranslation()[0], 1), 
                 translation_y=(-self.ref_shape[1]-self.transformation.GetTranslation()[1], self.ref_shape[1]-self.transformation.GetTranslation()[1], 1), 
                 translation_z=(-self.ref_shape[2]-self.transformation.GetTranslation()[2], self.ref_shape[2]-self.transformation.GetTranslation()[2], 1),
                 rotation_x=(-np.pi+self.transformation.GetAngleX(), np.pi+self.transformation.GetAngleX(), 0.05), 
                 rotation_y=(-np.pi-self.transformation.GetAngleY(), np.pi-self.transformation.GetAngleY(), 0.05), 
                 rotation_z=(-np.pi-self.transformation.GetAngleZ(), np.pi-self.transformation.GetAngleZ(), 0.05)
                )
        
    def registration(self):
        registration_method = sitk.ImageRegistrationMethod()
        registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
        registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
        registration_method.SetMetricSamplingPercentage(0.50)
        registration_method.SetInterpolator(sitk.sitkLinear)
        registration_method.SetOptimizerAsGradientDescentLineSearch(learningRate=1.4,
                                                                numberOfIterations=200,
                                                                convergenceMinimumValue=1e-5,
                                                                convergenceWindowSize=5)
        registration_method.SetOptimizerScalesFromPhysicalShift()
        registration_method.SetShrinkFactorsPerLevel(shrinkFactors = [4,2,1])
        registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2,2,1])
        registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
        registration_method.SetInitialTransform(self.transformation, inPlace=False)
        self.reg_transform = registration_method.Execute(self.ref_img, self.tar_img)
        self.reg_img = sitk.Resample(self.tar_img, self.ref_img, self.reg_transform, sitk.sitkLinear, 0.0, sitk.sitkFloat32)
        
        self.transformation.SetCenter(self.reg_transform.GetFixedParameters()[:3])
        self.transformation.SetRotation(*self.reg_transform.GetParameters()[:3])
        self.transformation.SetTranslation(self.reg_transform.GetParameters()[3:])
    
    def display_result(self):
        
        def display(img_x, img_y, img_z):
            
            plt.subplots(2,2,figsize=(10,8))
            plt.subplot(2,2,1)
            plt.imshow(sitk.GetArrayFromImage(self.ref_img)[::-1, :, img_x])
            plt.imshow(sitk.GetArrayFromImage(self.reg_img)[::-1, :, img_x], cmap=plt.cm.Greys_r, alpha=0.5)
            plt.title("Y-Z plane")
            plt.axis('off')
            
            plt.subplot(2,2,2)
            plt.imshow(sitk.GetArrayFromImage(self.ref_img)[img_z, :, :])
            plt.imshow(sitk.GetArrayFromImage(self.reg_img)[img_z, :, :], cmap=plt.cm.Greys_r, alpha=0.5)
            plt.title("X-Y plane")
            plt.axis('off') 
            
            plt.subplot(2,2,3)
            plt.imshow(sitk.GetArrayFromImage(self.ref_img)[::-1, img_y, :])
            plt.imshow(sitk.GetArrayFromImage(self.reg_img)[::-1, img_y, :], cmap=plt.cm.Greys_r, alpha=0.5)
            plt.title("X-Z plane")
            plt.axis('off')   
            
            plt.subplot(2,2,4)
            plt.axis('off') 
        
        
        interact(display, img_x=(0, self.ref_shape[0]-1), img_y=(0, self.ref_shape[1]-1), img_z=(0, self.ref_shape[2]-1))
        
    def save(self, *args, **kwds):
        imsaveseq(sitk.Cast(self.reg_img, sitk.sitkUInt8), *args, **kwds)
        
    def save_transform(self, *args, **kwds):
        sitk.WriteTransform(self.transformation, *args, **kwds)
        
    def load_transform(self, *args, **kwds):
        self.transformation = sitk.ReadTransform(*args, **kwds)

In [4]:
import os
import re
import shutil
refdir = r"E:\Yoda1-tumor-loading 2.26.2021\Registration week 0"
tardir = r"E:\Yoda1-tumor-loading 2.26.2021\Tibia & femur week 2"
img_title = r"537 week 2 left tibia"
fd_path = os.path.join(refdir, img_title+" registered")

ref = os.path.join(refdir, re.sub("week \d", "week 0", img_title+" registered"))

tar = os.path.join(tardir, img_title)


##############################
reg = Registration(ref, tar, tar_z_range=[-550, -20], tar_flip = False)
#reg = Registration(ref, tar)

In [5]:
reg.manual_initialization()

interactive(children=(IntSlider(value=275, description='img_x', max=551), IntSlider(value=274, description='im…

In [26]:
reg.save_transform(os.path.join(tar, img_title+".tfm"))

In [6]:
reg.registration()

In [7]:
reg.display_result()

interactive(children=(IntSlider(value=275, description='img_x', max=551), IntSlider(value=274, description='im…

In [8]:
'''if os.path.exists(fd_path):
    shutil.rmtree(fd_path)

os.mkdir(fd_path)'''

reg.save(r"E:\Yoda1-tumor-loading 2.26.2021\Registration week 2\537 week 2 left tibia registered", img_title+"_Reg")

In [34]:
print(reg.transformation.GetCenter())
print(reg.transformation.GetTranslation())
print(reg.transformation.GetAngleX())
print(reg.transformation.GetAngleY())
print(reg.transformation.GetAngleZ())

(390.0, 303.0, 381.0)
(-0.0, -0.0, -0.0)
-0.04159265358979303
-0.04159265358979303
-0.04159265358979303


In [12]:
print(type(reg.reg_transform))

<class 'SimpleITK.SimpleITK.Transform'>


In [15]:
print(type(reg.transformation))

<class 'SimpleITK.SimpleITK.Euler3DTransform'>


In [28]:
reg.transformation.GetCenter()

(222.0, 269.5, 175.0)

In [22]:
dir(reg.reg_transform)

['AddTransform',
 'FlattenTransform',
 'GetDimension',
 'GetFixedParameters',
 'GetITKBase',
 'GetInverse',
 'GetName',
 'GetNumberOfFixedParameters',
 'GetNumberOfParameters',
 'GetParameters',
 'IsLinear',
 'MakeUnique',
 'SetFixedParameters',
 'SetIdentity',
 'SetInverse',
 'SetParameters',
 'TransformPoint',
 'TransformVector',
 'WriteTransform',
 '__class__',
 '__del__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__swig_destroy__',
 '__swig_getmethods__',
 '__swig_setmethods__',
 '__weakref__',
 'this']

In [27]:
print(reg.reg_transform.GetFixedParameters())
print(reg.reg_transform.GetParameters())

(222.0, 269.5, 175.0, 0.0)
(0.03241928462907902, 0.014786274483489823, -0.004661059801034991, 68.62378224287195, 56.36650938888136, 138.81392240269452)


In [24]:
print(reg.reg_transform)

itk::simple::Transform
 CompositeTransform (0x3a7d730)
   RTTI typeinfo:   itk::CompositeTransform<double, 3u>
   Reference Count: 1
   Modified Time: 381919458
   Debug: Off
   Object Name: 
   Observers: 
     none
   Transforms in queue, from begin to end:
   >>>>>>>>>
   Euler3DTransform (0x3968aa0)
     RTTI typeinfo:   itk::Euler3DTransform<double>
     Reference Count: 2
     Modified Time: 381919450
     Debug: Off
     Object Name: 
     Observers: 
       none
     Matrix: 
       0.999882 0.00465859 0.0146345 
       -0.00418128 0.999464 -0.0324786 
       -0.014778 0.0324136 0.999365 
     Offset: [64.8334, 63.1231, 133.47]
     Center: [222, 269.5, 175]
     Translation: [68.6238, 56.3665, 138.814]
     Inverse: 
       0.999882 -0.00418128 -0.014778 
       0.00465859 0.999464 0.0324136 
       0.0146345 -0.0324786 0.999365 
     Singular: 0
     Euler's angles: AngleX=0.0324193 AngleY=0.0147863 AngleZ=-0.00466106
     m_ComputeZYX = 0
   End of MultiTransform.
<<<<<<<<<<