In [None]:
import os
# GPU selection: enumerate devices in PCI bus order and expose only devices 0 and 1.
# These are set ahead of the remaining imports -- presumably so that GPU-using
# libraries see them when they initialise (TODO confirm which ones depend on it).
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0,1"

import numpy as np
import napari

import dask_image
import dask.array as da
import dask
import zarr

import mFISHwarp.morphology
import mFISHwarp.utils
import mFISHwarp.transform
import mFISHwarp.register

import ray

In [None]:
# set path
# NOTE(review): these are absolute, machine-specific paths; adjust them per environment.
io_path = '/mnt/ampa02_data01/tmurakami/kidney/registration'
fix_n5_path = '/mnt/ampa02_data01/tmurakami/kidney/240614_kidney_finescan/fused/fused.n5' # zarr with pyramid resolution
mov_n5_path = '/mnt/ampa02_data01/tmurakami/kidney/240709_kidney-2nd-round-finescan/fused/fused.n5'

# create IO folder
# exist_ok=True makes this a no-op when the folder already exists, without a
# separate check-then-create step.
os.makedirs(io_path, exist_ok=True)

# create Zarr file object
# Both N5 containers are opened read-only (mode='r').
fix_zarr = zarr.open(store=zarr.N5Store(fix_n5_path), mode='r')
mov_zarr = zarr.open(store=zarr.N5Store(mov_n5_path), mode='r')

# load the displacement field
# 'global_displacement.npy' is read from io_path, so it must already exist there.
relative_displacement = np.load(os.path.join(io_path, 'global_displacement.npy'))
# convert the relative displacement array to scaled positional displacement array
positional_displacement = mFISHwarp.transform.relative2positional(relative_displacement)

# get high resolution images as zarr object
# think twice if you are loading the highest resolution.
# Indexing without [:] keeps these as lazy zarr arrays; nothing is read into memory here.
fix_l = fix_zarr['setup3']['timepoint0']['s1'] # two fold downsampling
mov_l = mov_zarr['setup2']['timepoint0']['s1']

## Make masks to avoid registration of the blank space

In [None]:
# Read the whole 's4' pyramid level of the fixed image into memory and build a mask from it.
# NOTE(review): the second argument (1000) is presumably an intensity threshold --
# confirm against mFISHwarp.morphology.mask_maker.
fix_s = fix_zarr['setup3']['timepoint0']['s4'][:]
fix_mask = mFISHwarp.morphology.mask_maker(fix_s, 1000)

# Overlay the image (green) and the mask (magenta) in napari for visual inspection.
viewer = napari.Viewer()
viewer.add_image(
    fix_s,
    colormap='green',
    blending='additive',
    contrast_limits=[0, 10000],
    rgb=False,
)
viewer.add_image(
    fix_mask,
    colormap='magenta',
    blending='additive',
    contrast_limits=[0, 5],
    rgb=False,
)

In [None]:
# Write the mask as a TIFF into the IO folder.
from skimage import io
fix_mask_path = os.path.join(io_path, 'fix_mask.tif')
io.imsave(fix_mask_path, fix_mask)

## Upsampling of positional displacement

In [None]:
# Output block geometry: chunk size and per-axis overlap depth.
out_chunk_size = (128, 128, 128)
out_overlap = (32, 32, 32)

# Per-axis integer upsampling ratio (zyx): shape of fix_l divided by the shape of
# the displacement grid (its last axis is excluded), rounded to the nearest integer.
rescale_constant = tuple(
    np.round(
        np.divide(fix_l.shape, relative_displacement.shape[:-1])
    ).astype(int)
)

# Wrap the fixed image as a dask array chunked by out_chunk_size, then add
# overlapping margins of out_overlap around the chunks (boundary=0).
fix_da = da.from_zarr(fix_l, chunks=out_chunk_size)
fix_overlap = da.overlap.overlap(fix_da, boundary=0, depth=out_overlap)

# Chunk size of the fixed image once the overlaps are included.
overlap_chunk_size = mFISHwarp.utils.chunks_from_dask(fix_overlap)

# Upscale the displacement field with the same chunking and overlap settings.
displacement_overlap = mFISHwarp.transform.upscale_displacement_overlap(
    positional_displacement,
    rescale_constant,
    out_overlap=out_overlap,
    out_chunk_size=out_chunk_size,
)

# Optional step, currently disabled:
# Trim or pad the upscaled displacement array to fit the size of fix_overlap
# displacement_overlap = mFISHwarp.transform.pad_trim_array_to_size(displacement_overlap, fix_overlap.shape + (fix_da.ndim,))
# Clean the redundant chunks caused by padding
# displacement_overlap = da.rechunk(displacement_overlap, overlap_chunk_size + (fix_da.ndim,))

## Preparation for registration

In [None]:
# get shape of the target.
# The target is fix_l, so outputs are laid out on the same grid as the fixed image.
target_shape = fix_l.shape

# make flag array from masks
# NOTE(review): presumably one flag per out_chunk_size block of the target, set where
# fix_mask marks tissue -- confirm against mFISHwarp.utils.flag_array_generator.
flag_array = mFISHwarp.utils.flag_array_generator(out_chunk_size, target_shape, fix_mask)
# Report how many blocks are flagged versus the total number of blocks.
print(f'{flag_array.sum()} blocks of {flag_array.shape[0]}*{flag_array.shape[1]}*{flag_array.shape[2]}={flag_array.size} blocks will be calculated')

In [None]:
# Output store for the registered moving image.
# mode='w-' creates the store and fails if it already exists, so a previous
# result is never overwritten; remove or rename the old store before re-running.
registered_mov_zarr = zarr.open(
    os.path.join(io_path, 'transformed_midres.zarr'),
    mode='w-',
    dtype=np.uint16,
    shape=target_shape,
    chunks=out_chunk_size,
)
# Or, registered_mov_zarr = None

# Output store for the displacement map. Its shape is the overlapped fixed
# image shape plus a trailing axis of length fix_da.ndim.
displacement_shape = fix_overlap.shape + (fix_da.ndim,)
displacement_zarr = zarr.open(
    os.path.join(io_path, 'displacements_overlap.zarr'),
    mode='w-',
    dtype=displacement_overlap.dtype,
    shape=displacement_shape,
    chunks=overlap_chunk_size + (fix_da.ndim,),
)

# Record the overlap size as an attribute so it can be recovered later;
# the trailing 0 is the entry for the last (vector) axis.
displacement_zarr.attrs["overlap_size"] = out_overlap + (0,)

In [None]:
# Parameters for the local registration, passed unchanged to the chunk-wise
# registration tasks.
settings = dict(
    pyramid_levels=2,
    pyramid_stop_level=1,
    step_size=[1.0, 1.0, 1.0],  # [1.0,1.0,1.0] or [0.5,0.5,0.5] is recommended.
    block_size=[16, 16, 16],  # [32,32,32] or [16,16,16] is recommended.
    block_energy_epsilon=1e-6,
    max_iteration_count=100,  # -1 is the best but takes a long time to converge. 100 is enough in most cases.
    constraints_weight=1000.0,
    regularization_weight=0.15,  # reduce for more flexibility at high resolution. default 0.25
    regularization_scale=1.0,  # default 1.0
    regularization_exponent=2.0,  # default 2.0
    image_slots=[
        dict(
            resampler='gaussian',
            normalize=True,
            cost_function=[
                dict(function='ncc', weight=1.0, radius=3),
            ],
        ),
    ],
)

In [None]:
# Ray task wrappers so that blocks can be processed in parallel.
# num_gpus=0.5 reserves half a GPU per task, i.e. Ray may schedule two tasks per GPU.
@ray.remote(num_gpus=0.5)
def chunk_wise_registration(
    chunk_position,
    displacement_overlap,
    fix_overlap,
    mov_l,
    settings,
    displacement_zarr,
    registered_mov_zarr,
    num_threads,
    use_gpu,
):
    """Forward all arguments unchanged to mFISHwarp.register.chunk_wise_registration.

    The helper's return value is discarded, so the task itself returns None.
    """
    mFISHwarp.register.chunk_wise_registration(
        chunk_position,
        displacement_overlap,
        fix_overlap,
        mov_l,
        settings,
        displacement_zarr,
        registered_mov_zarr,
        num_threads,
        use_gpu,
    )


@ray.remote(num_gpus=0.5)
def chunk_wise_no_registration(
    chunk_position,
    displacement_overlap,
    fix_overlap,
    displacement_zarr,
    registered_mov_zarr,
):
    """Forward all arguments unchanged to mFISHwarp.register.chunk_wise_no_registration.

    The helper's return value is discarded, so the task itself returns None.
    """
    mFISHwarp.register.chunk_wise_no_registration(
        chunk_position,
        displacement_overlap,
        fix_overlap,
        displacement_zarr,
        registered_mov_zarr,
    )


# Place the large displacement array in Ray's object store once; the returned
# reference is what gets passed to the tasks.
displacement_overlap_id = ray.put(displacement_overlap)

In [None]:
# run chunk-wise registration
num_threads=24 # more is faster
use_gpu=True

# Submit one Ray task per block of fix_da: flagged blocks are registered, the
# others go through the no-registration path. The object refs are kept so the
# tasks can be awaited below instead of being fired and forgotten.
pending_tasks = []
for index in np.ndindex(*fix_da.numblocks):
    if flag_array[index]:
        ref = chunk_wise_registration.remote(index, displacement_overlap_id, fix_overlap, mov_l, settings, displacement_zarr, registered_mov_zarr, num_threads, use_gpu)
    else:
        ref = chunk_wise_no_registration.remote(index, displacement_overlap_id, fix_overlap, displacement_zarr, registered_mov_zarr)
    pending_tasks.append(ref)

# Block until every task has finished. ray.get re-raises an exception from a
# failed task here, so errors surface in the notebook and this cell only
# completes once the output stores have been written.
# The trailing semicolon suppresses the list of return values as cell output.
ray.get(pending_tasks);

In [None]:
#