In [None]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import numpy as np
import napari

import pydeform.sitk_api as pydeform
import SimpleITK as sitk

import zarr

import mFISHwarp.register
import mFISHwarp.transform

from skimage import io

In [None]:
# Input/output locations for this registration run.
io_path = '/mnt/ampa02_data01/tmurakami/MK_administration/240620_03_MX002-1/registration'
fix_n5_path = '/mnt/ampa02_data01/tmurakami/MK_administration/240620_03_MX002-1/fused/fused.n5' # zarr with pyramid resolution
mov_n5_path = '/mnt/ampa02_data01/tmurakami/MK_administration/240710_03_MX002-2/fused/fused.n5'

# Create the IO directory (including parents). exist_ok=True makes this cell
# safe to re-run and avoids the check-then-create race of a separate isdir test.
os.makedirs(io_path, exist_ok=True)

# Open the fixed and moving N5 containers read-only.
fix_zarr = zarr.open(store=zarr.N5Store(fix_n5_path), mode='r')
mov_zarr = zarr.open(store=zarr.N5Store(mov_n5_path), mode='r')

# Load the low-resolution level ('s3') of the same setup/timepoint from both
# containers fully into memory; the [:] slice materializes the whole array.
fix_l = fix_zarr['setup3']['timepoint0']['s3'][:] # eight times downsampling
mov_l = mov_zarr['setup3']['timepoint0']['s3'][:]

## Affine registration

In [None]:
%%time
# Estimate the affine transform between the low-resolution fixed and moving
# volumes. No initial rotation or scaling guess is supplied.
affine_transform = mFISHwarp.register.affine_registration(
    fix_l,
    mov_l,
    initial_rotation=None,
    initial_scaling=None,
    # Author's note: the shrinking factors determine the resolution at which
    # the registration is run.
    shrinking_factors=(32, 16, 8, 4),
    smoothing=(4, 4, 2, 1),
    model='affine',
)

# Apply the estimated transform to the moving volume.
# NOTE(review): presumably fix_l serves as the reference grid for resampling —
# confirm in mFISHwarp.register.affine_warping.
mov_affine = mFISHwarp.register.affine_warping(
    fix_l,
    mov_l,
    affine_transform,
)

In [None]:
# Visual check of the affine result: overlay the fixed, raw moving, and
# affine-warped moving volumes as additive layers with shared contrast limits.
viewer = napari.Viewer()
for layer_name, volume, cmap in (
    ('fix', fix_l, 'green'),
    ('mov', mov_l, 'yellow'),
    ('mov_affine', mov_affine, 'magenta'),
):
    viewer.add_image(
        volume,
        contrast_limits=[0,5000],
        rgb=False,
        name=layer_name,
        colormap=cmap,
        blending='additive',
    )

In [None]:
# Save the low-resolution volumes used for registration as TIFF files in the
# IO directory so the result can be inspected outside the notebook.
fix_tif_path = os.path.join(io_path, 'fix_l.tif')
mov_tif_path = os.path.join(io_path, 'mov_l.tif')
mov_affine_tif_path = os.path.join(io_path, 'mov_affine.tif')

io.imsave(fix_tif_path, fix_l, check_contrast=False)
io.imsave(mov_tif_path, mov_l, check_contrast=False)
# The warped volume is cast to uint16 before writing.
# NOTE(review): astype does not clip, so out-of-range values would wrap —
# confirm the warped intensity range stays within uint16.
io.imsave(mov_affine_tif_path, mov_affine.astype(np.uint16), check_contrast=False)

## Non-linear registration

In [None]:
# Log file for the non-linear registration, kept in the IO directory.
log_file = os.path.join(io_path, 'deform.log')

# Parameters for the non-linear registration, assembled from the innermost
# piece outwards. The inline notes are the author's empirical observations.

# Single cost term: normalized cross-correlation ('ncc').
ncc_cost = {
    'function': 'ncc',
    'weight': 1.0,
    'radius': 3,  # a radius of 3 worked better than 7 (reason unknown)
}

# One image slot, carrying the cost term above.
image_slot = {
    'resampler': 'gaussian',
    'normalize': True,
    'cost_function': [ncc_cost],
}

settings = {
    'pyramid_levels': 1,
    # If the computation takes too long, increase this number to stop at a
    # lower resolution.
    'pyramid_stop_level': 0,
    # [1.0, 1.0, 1.0] seems sufficient; larger steps degrade the quality.
    'step_size': [1.0, 1.0, 1.0],
    'block_size': [32, 32, 32],
    'block_energy_epsilon': 1e-7,
    'max_iteration_count': -1,
    'constraints_weight': 1000.0,
    'regularization_weight': 0.25,   # default 0.25
    'regularization_scale': 1.0,     # default 1.0
    'regularization_exponent': 2.0,  # default 2.0
    'image_slots': [image_slot],
}

In [None]:
%%time
# Non-linear registration of the low-resolution volumes, using the settings
# dict and the previously estimated affine transform. Progress is logged to
# log_file.
displacement = mFISHwarp.register.deform_registration(
    fix_l,
    mov_l,
    settings,
    affine_transform=affine_transform,
    log_path=log_file,
    num_threads=-1,  # NOTE(review): presumably "use all available threads" — confirm
    use_gpu=True,
)

# Apply the resulting displacement field to the original moving volume.
mov_deformed = mFISHwarp.register.deform_warping(
    mov_l,
    displacement,
)

In [None]:
# Visual check of the non-linear result: overlay the fixed, affine-warped, and
# deformed volumes as additive layers with shared contrast limits.
# This rebinds `viewer` to a new napari window.
viewer = napari.Viewer()
for layer_name, volume, cmap in (
    ('fix', fix_l, 'green'),
    ('mov_affine', mov_affine, 'blue'),
    ('mov_deform', mov_deformed, 'magenta'),
):
    viewer.add_image(
        volume,
        contrast_limits=[0,5000],
        rgb=False,
        name=layer_name,
        colormap=cmap,
        blending='additive',
    )

In [None]:
# Convert the displacement field to a numpy array and save it in the IO
# directory.
disp_arr = mFISHwarp.transform.displacement_itk2numpy(displacement)
displacement_npy_path = os.path.join(io_path, 'global_displacement.npy')
np.save(displacement_npy_path, disp_arr)

# Save the deformed moving volume as a uint16 TIFF next to it.
mov_deformed_tif_path = os.path.join(io_path, 'mov_deformed.tif')
io.imsave(
    mov_deformed_tif_path,
    mov_deformed.astype(np.uint16),
    check_contrast=False,
)

In [None]:
#