In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
from ipywidgets import fixed, interact, interact_manual

from shubow_tools import imreadseq_multithread, imsaveseq
%matplotlib inline

In [None]:
class Registration2D:
    """Interactive rigid (rotation + translation) 2-D alignment of a target image onto a reference.

    The target is resampled onto the reference image's pixel grid through a
    ``sitk.Euler2DTransform`` whose angle and translation are chosen with ipywidgets
    sliders. The aligned result is kept in ``reg_img``. It can be overlaid on the
    reference, saved, and its transform written to or read from a file.

    NOTE(review): inputs are assumed to be channel-last arrays (H, W, C) as returned by
    ``sitk.GetArrayFromImage`` for colour files. The grey conversion averages the last
    axis and channels are indexed as ``[:, :, i]``. Confirm before using other inputs.
    """

    def __init__(self, ref_img_file, tar_img_file):
        """Read both images and start from an identity Euler2D transform.

        Args:
            ref_img_file: path of the reference (fixed) image.
            tar_img_file: path of the target (moving) image that is aligned onto it.
        """
        self.ref_img = sitk.GetArrayFromImage(sitk.ReadImage(ref_img_file))
        self.tar_img = sitk.GetArrayFromImage(sitk.ReadImage(tar_img_file))
        # Single-channel float32 copies. Here only their geometry is used: sizes, and the
        # reference grid for resampling.
        self._ref_img_grey = sitk.Cast(sitk.GetImageFromArray(self.ref_img.mean(axis=-1)), sitk.sitkFloat32)
        self._tar_img_grey = sitk.Cast(sitk.GetImageFromArray(self.tar_img.mean(axis=-1)), sitk.sitkFloat32)

        # SimpleITK sizes are (width, height), i.e. reversed with respect to NumPy shapes.
        self.ref_shape = self._ref_img_grey.GetSize()
        self.tar_shape = self._tar_img_grey.GetSize()

        # Aligned target on the reference grid; filled by _apply_transform().
        self.reg_img = np.zeros(shape=self.ref_img.shape, dtype=np.uint8)
        self.transformation = self.__center_initialization__()
        self.reg_transform = None

    def __center_initialization__(self):
        """Return the starting transform: an identity ``sitk.Euler2DTransform``.

        Alternative (currently unused) geometry-centred start:
        ``sitk.Euler2DTransform(sitk.CenteredTransformInitializer(self._ref_img_grey,
        self._tar_img_grey, sitk.Euler2DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY))``
        """
        return sitk.Euler2DTransform()

    def _apply_transform(self):
        """Resample each target channel through ``self.transformation`` into ``self.reg_img``.

        The output always lives on the reference grid. Reference pixels that map
        outside the target are set to 0.
        """
        # Only channels present in both images can be filled. For RGB inputs this is 3,
        # as before; for RGBA inputs the alpha channel is resampled too instead of
        # being left at 0 (fully transparent).
        n_channels = min(self.ref_img.shape[-1], self.tar_img.shape[-1])
        for i in range(n_channels):
            moving = sitk.Cast(sitk.GetImageFromArray(self.tar_img[:, :, i]), sitk.sitkFloat32)
            # Passing the reference image fixes the output size/spacing/origin to the
            # reference grid, so the result always fits reg_img even when the two
            # input images differ in size.
            self.reg_img[:, :, i] = sitk.GetArrayFromImage(
                sitk.Resample(moving, self._ref_img_grey, self.transformation,
                              sitk.sitkLinear, 0.0, sitk.sitkUInt8)
            )

    def manual_initialization(self):
        """Show sliders to align the target onto the reference by eye.

        Every slider change does three things:
        - rotates ``self.transformation`` about the reference centre and translates it;
        - resamples the target into ``self.reg_img``;
        - redraws the overlay.

        Sliders start from the current transform, so one loaded with
        ``load_transform`` can be refined further.
        """
        current_tx, current_ty = self.transformation.GetTranslation()
        current_angle = float(self.transformation.GetAngle() * 180 / np.pi)
        width, height = self.ref_shape[0], self.ref_shape[1]
        center = [0.5 * s for s in self.ref_shape]

        # Default argument values become the sliders' initial positions (ipywidgets).
        def update(translation_x=-current_tx, translation_y=-current_ty, angle=current_angle):
            self.transformation.SetCenter(center)
            self.transformation.SetAngle(angle * np.pi / 180)
            # The sliders hold the negated transform translation.
            self.transformation.SetTranslation((-translation_x, -translation_y))
            self._apply_transform()
            self.display_results()

        interact(update,
                 translation_x=(-width - current_tx, width - current_tx, 1),
                 translation_y=(-height - current_ty, height - current_ty, 1),
                 angle=(-180, 180, 0.25),
                 )

    def save(self, *args, **kwds):
        """Write ``reg_img`` with ``sitk.WriteImage(image, *args, **kwds)``.

        ``isVector=True`` makes the channel-last array a 2-D multi-channel image.
        Without it, it would be written as a 3-D scalar volume. ``reg_img`` is already
        uint8, so no cast is needed; casting a vector image to a scalar type would be wrong.
        """
        image = sitk.GetImageFromArray(self.reg_img, isVector=True)
        sitk.WriteImage(image, *args, **kwds)

    def save_transform(self, *args, **kwds):
        """Write ``self.transformation`` with ``sitk.WriteTransform(transform, *args, **kwds)``."""
        sitk.WriteTransform(self.transformation, *args, **kwds)

    def load_transform(self, *args, **kwds):
        """Read a transform with ``sitk.ReadTransform(*args, **kwds)`` and apply it.

        ``reg_img`` is recomputed, so ``save`` and ``display_results`` reflect the
        loaded transform rather than a stale (or all-zero) image.
        """
        loaded = sitk.ReadTransform(*args, **kwds)
        try:
            # ReadTransform returns a generic Transform. Recover the concrete Euler2D
            # type so manual_initialization can call GetAngle/SetAngle/SetCenter.
            loaded = sitk.Euler2DTransform(loaded)
        except (RuntimeError, TypeError):
            # NOTE(review): not an Euler2D transform. It is kept as a generic Transform,
            # which can still be applied and saved but not refined with the sliders.
            pass
        self.transformation = loaded
        self._apply_transform()

    def display_results(self):
        """Overlay the registered target (alpha 0.7) on the reference image."""
        fig, ax = plt.subplots(1, 1, figsize=(8, 6), dpi=100)
        ax.imshow(self.ref_img, alpha=1.0)
        ax.imshow(self.reg_img, alpha=0.7)
        ax.axis('off')

In [None]:
# Input files: `ref_silver` is the reference (fixed) image and `tar_tunel` the target
# aligned onto it. They are defined here, before first use, so the notebook also
# runs top-to-bottom on a fresh kernel.
tar_tunel = r"E:\DATA\01.14.21_mouse Sample_RCC_Texas Group\1T TUNEL\Image_20514-tip 2 (endo).tif"
ref_silver = r"E:\DATA\01.14.21_mouse Sample_RCC_Texas Group\1T\1T_Image_20288.tif"
reg = Registration2D(ref_silver, tar_tunel)

In [None]:
# Open the alignment sliders; each change updates reg.transformation and reg.reg_img.
Registration2D.manual_initialization(reg)

In [None]:
# Paths of the target (`tar_tunel`) and reference (`ref_silver`) images.
tar_tunel = r"E:\DATA\01.14.21_mouse Sample_RCC_Texas Group\1T TUNEL\Image_20514-tip 2 (endo).tif"
ref_silver = r"E:\DATA\01.14.21_mouse Sample_RCC_Texas Group\1T\1T_Image_20288.tif"
# Read straight to NumPy. A GetImageFromArray/GetArrayFromImage round trip would
# return an identical array and only cost an extra copy.
tar_img = sitk.GetArrayFromImage(sitk.ReadImage(tar_tunel))
ref_img = sitk.GetArrayFromImage(sitk.ReadImage(ref_silver))

In [None]:
# Save the aligned image next to the target, reusing the target's file stem.
reg.save(os.path.splitext(tar_tunel)[0] + "__transformed.tif")

In [None]:
# Store the chosen transform next to the target image (same stem, .tfm extension).
reg.save_transform(os.path.splitext(tar_tunel)[0] + ".tfm")

In [None]:
# Reload the transform stored next to the target image (same stem, .tfm extension).
reg.load_transform(os.path.splitext(tar_tunel)[0] + ".tfm")