In [1]:
import os
# CUDA device selection only takes effect if set before any CUDA-using library
# initializes, so keep these assignments above the library imports below.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import dask.array as da
import napari
import numpy as np
import SimpleITK as sitk
import zarr
from skimage import io

import pydeform.sitk_api as pydeform

import mFISHwarp.register
import mFISHwarp.transform

In [2]:
# set paths (all data lives under one root; change DATA_ROOT to relocate)
DATA_ROOT = '/mnt/ampa02_data01/tmurakami/MK_administration'
io_path = os.path.join(DATA_ROOT, '240620_03_MX002-1', 'registration')
fix_n5_path = os.path.join(DATA_ROOT, '240620_03_MX002-1', 'fused', 'fused.n5')  # N5 container with a resolution pyramid
mov_n5_path = os.path.join(DATA_ROOT, '240710_03_MX002-2', 'fused', 'fused.n5')

# N5 group holding the pyramid levels (s2, s3, ...) that are registered
SETUP, TIMEPOINT = 'setup3', 'timepoint0'
LOW_RES_LEVEL = 's3'   # eight times downsampling; used for the affine fit
HIGH_RES_LEVEL = 's2'  # used for warping and visual inspection

# create IO directory (no error if it already exists)
os.makedirs(io_path, exist_ok=True)

# open both N5 containers read-only through zarr
fix_zarr = zarr.open(store=zarr.N5Store(fix_n5_path), mode='r')
mov_zarr = zarr.open(store=zarr.N5Store(mov_n5_path), mode='r')
fix_pyramid = fix_zarr[SETUP][TIMEPOINT]
mov_pyramid = mov_zarr[SETUP][TIMEPOINT]

# load low resolution images eagerly as numpy
fix_s = fix_pyramid[LOW_RES_LEVEL][:]
mov_s = mov_pyramid[LOW_RES_LEVEL][:]

# load higher resolution images lazily as dask arrays
fix_l = da.from_zarr(fix_pyramid[HIGH_RES_LEVEL])
mov_l = da.from_zarr(mov_pyramid[HIGH_RES_LEVEL])

## Global affine registration

In [3]:
%%time
# affine registration
affine_transform = mFISHwarp.register.affine_registration(
    fix_s, mov_s,
    initial_rotation=None,
    initial_scaling=None,
    shrinking_factors=(32, 16, 8, 4), # shrinking factor determine the resolution of the registration.
    smoothing=(4, 4, 2, 1),
    model='affine'
)

# apply transformation
mov_affine = mFISHwarp.register.affine_warping(fix_l, mov_l, affine_transform)

CPU times: user 1h 2min 6s, sys: 1min 10s, total: 1h 3min 16s
Wall time: 2min 16s


In [4]:
# Overlay fixed (green), raw moving (yellow) and affine-warped moving (magenta) volumes.
viewer = napari.Viewer()
overlay_kwargs = dict(contrast_limits=[0, 5000], rgb=False, blending='additive')
viewer.add_image(fix_l, name='fix', colormap='green', **overlay_kwargs)
viewer.add_image(mov_l, name='mov', colormap='yellow', **overlay_kwargs)
viewer.add_image(mov_affine, name='mov_affine', colormap='magenta', **overlay_kwargs)

<Image layer 'mov_affine' at 0x7fa7b3f5efa0>

In [None]:
# Export the s2 fixed/moving volumes and the affine result (as uint16) into io_path.
tiff_exports = {
    'fix_l.tif': fix_l,
    'mov_l.tif': mov_l,
    'mov_affine.tif': mov_affine.astype(np.uint16),
}
for tiff_name, volume in tiff_exports.items():
    io.imsave(os.path.join(io_path, tiff_name), volume, check_contrast=False)

In [62]:
# Reference grid for sampling the affine: the s2 fixed volume with unit spacing,
# zero origin and identity direction. SimpleITK sizes are (x, y, z), i.e. the
# numpy shape reversed.
size = fix_l.shape[::-1]
spacing = [1.0] * 3
origin = [0.0] * 3
direction = [1, 0, 0,
             0, 1, 0,
             0, 0, 1]  # row-major 3x3 identity

# Convert the affine transform to a dense displacement field on that grid
displacement_field = sitk.TransformToDisplacementField(
    affine_transform,
    sitk.sitkVectorFloat64,
    size,
    origin,
    spacing,
    direction,
)


In [63]:
# SimpleITK displacement image -> numpy array of relative displacements
relative_displacement = mFISHwarp.transform.displacement_itk2numpy(
    displacement_field
)
# relative displacements -> scaled positional displacements
positional_displacement = mFISHwarp.transform.relative2positional(
    relative_displacement
)

In [68]:
# Show the relative displacement field from the previous cell, one napari layer per component.
# NOTE(review): assumes relative_displacement is channel-last (..., 3), as the moveaxis
# implies -- confirm, and switch to positional_displacement if that field was intended.
viewer = napari.Viewer()
viewer.add_image(np.moveaxis(relative_displacement, -1, 0), channel_axis=0)

[<Image layer 'Image' at 0x7fa7bdca9c70>,
 <Image layer 'Image [1]' at 0x7fa9c676daf0>,
 <Image layer 'Image [2]' at 0x7fa9c64b6b50>]

In [67]:
#

In [5]:
def affine_to_displacement_field(affine_matrix, shape):
    """Sample an affine map on a voxel grid as a dense displacement field.

    For every integer index ``x`` of an array with the given ``shape`` the
    displacement is ``A @ x + b - x``, where ``A`` is the linear part and ``b``
    the translation column of ``affine_matrix``. Coordinates use the same axis
    order as ``shape`` (numpy order).

    Parameters
    ----------
    affine_matrix : array_like
        Homogeneous affine of shape ``(ndim + 1, ndim + 1)``, or just its top
        ``ndim`` rows ``(ndim, ndim + 1)``, with ``ndim == len(shape)``.
    shape : sequence of int
        Shape of the grid to sample on.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(ndim, *shape)``; component ``i`` is the
        displacement along array axis ``i``.

    Raises
    ------
    ValueError
        If ``affine_matrix`` does not match the dimensionality of ``shape``.
    """
    affine_matrix = np.asarray(affine_matrix, dtype=float)
    shape = tuple(int(s) for s in shape)
    ndim = len(shape)
    if (affine_matrix.ndim != 2
            or affine_matrix.shape[1] != ndim + 1
            or affine_matrix.shape[0] not in (ndim, ndim + 1)):
        raise ValueError(
            f"affine_matrix of shape {affine_matrix.shape} does not match a "
            f"{ndim}-D grid; expected ({ndim + 1}, {ndim + 1}) or ({ndim}, {ndim + 1})"
        )

    # (ndim, n_voxels) integer grid coordinates
    coords = np.indices(shape).reshape(ndim, -1)

    # A x + b - x == (A - I) x + b; this avoids building the homogeneous
    # (ndim + 1, n_voxels) copy and the full transformed-coordinate array.
    linear = affine_matrix[:ndim, :ndim] - np.eye(ndim)
    translation = affine_matrix[:ndim, ndim:]
    displacement = linear @ coords + translation

    return displacement.reshape(ndim, *shape)

In [24]:
def sitk_to_numpy_affine(sitk_transform):
    """Convert a SimpleITK matrix/offset transform to a numpy-ordered homogeneous affine.

    SimpleITK applies ``T(p) = A @ (p - c) + t + c`` with matrix ``A``,
    translation ``t`` and fixed center ``c``, all in (x, y, z) order. This returns
    the matrix ``M`` with ``M @ [p_zyx, 1] == [T(p)_zyx, 1]``: the center is folded
    into the translation and both axes are reversed to numpy (z, y, x) order.

    NOTE(review): ``M`` acts in the transform's physical coordinates; it equals a
    voxel-index mapping only if the registered images had unit spacing and zero
    origin -- confirm for mFISHwarp.register.affine_registration.

    Parameters
    ----------
    sitk_transform : SimpleITK transform
        Must expose ``GetMatrix``, ``GetTranslation`` and ``GetCenter``
        (e.g. ``sitk.AffineTransform``).

    Returns
    -------
    numpy.ndarray
        Homogeneous affine of shape ``(ndim + 1, ndim + 1)``.
    """
    translation = np.asarray(sitk_transform.GetTranslation(), dtype=float)
    ndim = translation.size
    matrix = np.asarray(sitk_transform.GetMatrix(), dtype=float).reshape((ndim, ndim))
    center = np.asarray(sitk_transform.GetCenter(), dtype=float)

    # Fold the center into the translation: A (p - c) + t + c == A p + (t + c - A c).
    offset = translation + center - matrix @ center

    # Reverse rows and columns to go from SimpleITK (x, y, z) to numpy (z, y, x).
    affine_matrix = np.eye(ndim + 1)
    affine_matrix[:ndim, :ndim] = matrix[::-1, ::-1]
    affine_matrix[:ndim, ndim] = offset[::-1]

    return affine_matrix

In [27]:
# Dense displacement field of the global affine on the s2 grid,
# shape (3, *fix_l.shape) in numpy axis order.
disp_global = affine_to_displacement_field(
    affine_matrix=sitk_to_numpy_affine(affine_transform),
    shape=fix_l.shape,
)

In [30]:
# Rebind `displacement_field` (earlier in the notebook the SimpleITK image returned by
# sitk.TransformToDisplacementField) to the numpy field built from the affine matrix;
# later cells convert this numpy version back to a SimpleITK transform.
# NOTE(review): reusing one name for two different types is fragile under
# out-of-order execution -- consider passing `disp_global` directly.
displacement_field = disp_global

In [28]:
# Inspect the numpy displacement field: channel_axis=0 splits the leading
# component axis into one napari layer per component.
viewer = napari.Viewer()
viewer.add_image(disp_global, channel_axis=0)

[<Image layer 'Image' at 0x7fa9dfd81f70>,
 <Image layer 'Image [1]' at 0x7fa9d77e4490>,
 <Image layer 'Image [2]' at 0x7fa9d6d23f40>]

In [15]:
def numpy_to_sitk_displacement_field(displacement_field, components_zyx=True):
    """Wrap a channel-first numpy displacement field as a SimpleITK vector image.

    Parameters
    ----------
    displacement_field : numpy.ndarray
        Array of shape ``(ndim, *spatial_shape)`` (e.g. ``(3, depth, height, width)``)
        in numpy axis order, such as the output of ``affine_to_displacement_field``.
    components_zyx : bool, default True
        If True, component ``i`` is the displacement along numpy axis ``i``
        ((z, y, x) order). ``sitk.GetImageFromArray`` reverses the spatial axes
        but not the vector components, while ``DisplacementFieldTransform``
        reads components in physical (x, y, z) order, so they are reversed here.
        Pass False if the components are already in (x, y, z) order.

    Returns
    -------
    SimpleITK.Image
        Vector image with float64 components, the pixel type
        ``sitk.DisplacementFieldTransform`` requires.
    """
    field = np.asarray(displacement_field, dtype=np.float64)
    if components_zyx:
        # numpy (z, y, x) components -> SimpleITK (x, y, z) components
        field = field[::-1]
    # channel-first -> channel-last, as GetImageFromArray(isVector=True) expects
    return sitk.GetImageFromArray(np.moveaxis(field, 0, -1), isVector=True)


In [31]:
# Wrap the numpy field as a SimpleITK vector image and build a dense transform from it.
# Per the SimpleITK docs, DisplacementFieldTransform(image) takes over the image's buffer,
# leaving sitk_displacement_field as an empty image afterwards.
sitk_displacement_field = numpy_to_sitk_displacement_field(displacement_field)
displacement_transform = sitk.DisplacementFieldTransform(
    sitk_displacement_field
)

In [18]:
def apply_displacement_field(image, displacement_transform):
    """Warp ``image`` through ``displacement_transform`` onto its own grid.

    The output uses the input image as reference (same size, spacing, origin,
    direction) and keeps its pixel type; samples are linearly interpolated and
    points that map outside the input are filled with 0.
    """
    return sitk.Resample(
        image,                  # image to warp
        image,                  # reference grid: the input image itself
        displacement_transform,
        sitk.sitkLinear,
        0,                      # default (outside) pixel value
        image.GetPixelID(),     # keep the input pixel type
    )


In [33]:
# Warp the s2 moving volume through the dense field, interpolating in float32.
# astype(np.uint16) truncates fractional values rather than rounding them.
mov_itk = sitk.Cast(sitk.GetImageFromArray(np.asarray(mov_l)), sitk.sitkFloat32)
transformed_image = apply_displacement_field(mov_itk, displacement_transform)
manual_mov = sitk.GetArrayViewFromImage(transformed_image).astype(np.uint16)

In [34]:
# Compare the manual displacement-field warp (yellow) with mFISHwarp's affine warp
# (magenta) over the fixed volume (green); where they agree the overlay is white-ish.
viewer = napari.Viewer()
overlay_kwargs = dict(contrast_limits=[0, 5000], rgb=False, blending='additive')
viewer.add_image(fix_l, name='fix', colormap='green', **overlay_kwargs)
# Name the layer after its data so it isn't mistaken for the raw moving volume ('mov').
viewer.add_image(manual_mov, name='manual_mov', colormap='yellow', **overlay_kwargs)
viewer.add_image(mov_affine, name='mov_affine', colormap='magenta', **overlay_kwargs)

<Image layer 'mov_affine' at 0x7fa9df917070>