In [None]:
import math
import os

# GPU selection; kept ahead of the remaining imports, as in the original cell order
# (presumably so CUDA-using libraries see these values when they load -- confirm).
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import numpy as np
import napari
import dask.array as da
import pydeform.sitk_api as pydeform
import SimpleITK as sitk

import zarr

import mFISHwarp.register
import mFISHwarp.transform
import mFISHwarp.utils

from skimage import io

In [None]:
# set path
io_path = '/mnt/ampa02_data01/tmurakami/MK_administration/240620_03_MX002-1/registration'
fix_n5_path = '/mnt/ampa02_data01/tmurakami/MK_administration/240620_03_MX002-1/fused/fused.n5' # zarr with pyramid resolution
mov_n5_path = '/mnt/ampa02_data01/tmurakami/MK_administration/240710_03_MX002-2/fused/fused.n5'

# create IO directory
if not os.path.isdir(io_path):
    os.makedirs(io_path)

# create Zarr file object
fix_zarr = zarr.open(store=zarr.N5Store(fix_n5_path), mode='r')
mov_zarr = zarr.open(store=zarr.N5Store(mov_n5_path), mode='r')

# load low resolution images as numpy
fix_s = fix_zarr['setup3']['timepoint0']['s3'][:] # eight times downsampling
mov_s = mov_zarr['setup3']['timepoint0']['s3'][:]

# load higher resolution images as dask
fix_l = da.from_zarr(fix_zarr['setup3']['timepoint0']['s2'])
mov_l = mov_zarr['setup3']['timepoint0']['s2']

## Global affine registration

In [None]:
%%time
# affine registration
affine_transform = mFISHwarp.register.affine_registration(
    fix_s, mov_s,
    initial_rotation=None,
    initial_scaling=None,
    shrinking_factors=(32, 16, 8, 4), # shrinking factor determine the resolution of the registration.
    smoothing=(4, 4, 2, 1),
    model='affine'
)

# apply transformation
mov_affine = mFISHwarp.register.affine_warping(fix_s, mov_s, affine_transform)

In [None]:
viewer = napari.Viewer()
viewer.add_image(fix_s, contrast_limits=[0,5000], rgb=False, name='fix', colormap='green', blending='additive')
viewer.add_image(mov_s, contrast_limits=[0,5000], rgb=False, name='mov', colormap='yellow', blending='additive')
viewer.add_image(mov_affine, contrast_limits=[0,5000], rgb=False, name='mov_affine', colormap='magenta', blending='additive')

In [None]:
io.imsave(os.path.join(io_path,'fix_s.tif'),fix_s, check_contrast=False)
io.imsave(os.path.join(io_path,'mov_s.tif'),mov_s, check_contrast=False)
io.imsave(os.path.join(io_path,'mov_affine.tif'), mov_affine.astype(np.uint16), check_contrast=False)

In [None]:
# convert affine to displacement field 

size = fix_s.shape[::-1]
spacing = [1.0, 1.0, 1.0]
origin = [0.0, 0.0, 0.0]
direction = [1, 0, 0, 0, 1, 0, 0, 0, 1]

# Convert the affine transform to a displacement field
displacement_field = sitk.TransformToDisplacementField(affine_transform, sitk.sitkVectorFloat64, size, origin, spacing, direction)
# convert itk to numpy array.
relative_displacement = mFISHwarp.transform.displacement_itk2numpy(displacement_field)
# convert the relative displacement array to scaled positional displacement array
positional_displacement = mFISHwarp.transform.relative2positional(relative_displacement)

### Upsampling the low-resolution affine displacement to the higher-resolution grid

In [None]:
import math
# set rescale factors
target_shape = fix_l.shape
rescale_constant = tuple((np.array(target_shape) / np.array(relative_displacement.shape[:-1])).round().astype(int)) # upsampling ratio in zyx
# specify the chunk size and overlaps
out_chunk_size = [math.ceil(i/2) for i in target_shape]
out_chunk_size = [int(math.ceil(i/j)*j) for i,j in zip(out_chunk_size, rescale_constant)]
out_overlap=(64, 64, 64)


fix_da = da.rechunk(fix_l,chunks=out_chunk_size)
fix_overlap = da.overlap.overlap(fix_da,depth=out_overlap,boundary=0)

# get chunk size of fixed with overlaps
overlap_chunk_size = mFISHwarp.utils.chunks_from_dask(fix_overlap)

# upscale the displacement field and make overlaps
displacement_overlap = mFISHwarp.transform.upscale_displacement_overlap(
    positional_displacement,
    rescale_constant,
    out_chunk_size=out_chunk_size,
    out_overlap=out_overlap
)

In [None]:
# Set zarr path for registered moving image.
registered_mov_zarr = zarr.open(
    os.path.join(io_path, 'transformed_lowres.zarr'), 
    mode='w-', 
    shape=target_shape, 
    chunks=out_chunk_size, 
    dtype=np.uint16
)
# Or, registered_mov_zarr = None

# Set zarr path for displacement map.
displacement_shape = fix_overlap.shape + (fix_da.ndim,)
displacement_zarr = zarr.open(
    os.path.join(io_path, 'displacements_lowres_overlap.zarr'), 
    mode='w-', 
    shape=displacement_shape, 
    chunks=overlap_chunk_size+(fix_da.ndim,), 
    dtype=displacement_overlap.dtype
)

# add attribute so that I can know the overlap size later.
displacement_zarr.attrs.update({"overlaps": out_overlap+(0,)})

In [None]:
# set parameters for non linear registration
settings = {
    'pyramid_levels':1, 
    'pyramid_stop_level': 0, # if the computation takes too long time, reduce the resolution by increasing the number.
    'step_size': [1.0, 1.0, 1.0], # [1.0,1.0,1.0] seems enough. more than that will degrade the quality
    'block_size': [32,32,32],
    'block_energy_epsilon':1e-7,
    'max_iteration_count':-1,
    'constraints_weight':1000.0,
    'regularization_weight': 0.25, # default 0.25
    'regularization_scale': 1.0, # default 1.0
    'regularization_exponent': 2.0, # default 2.0
    'image_slots': 
    [
        {
            'resampler': 'gaussian',
            'normalize': True,
            'cost_function':
            [
                {
                    'function':'ncc',
                    'weight':1.0,
                    'radius':3 # 3 was better than 7. do not know why.
                }
            ]
        }
    ]
}

In [None]:
# run chunk-wise registration
num_threads=32 # more is faster
use_gpu=True

for index in list(np.ndindex(*fix_da.numblocks)):
    mFISHwarp.register.chunk_wise_affine_deform_registration(index, displacement_overlap, fix_overlap, mov_l, settings, displacement_zarr, registered_mov_zarr, num_threads, use_gpu)