In [None]:
import os
# Order GPUs by PCI bus id and expose only devices 0 and 1. These variables are
# set before the GPU-using libraries below (mFISHwarp, ray) are imported so they
# take effect for those libraries.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0,1"

import numpy as np
import napari

# NOTE(review): dask_image is not referenced in the visible cells.
import dask_image
import dask.array as da
import dask
import zarr

# Project-local registration helpers (GPU transforms, chunk utilities, block stitching).
import mFISHwarp.transform
import mFISHwarp.utils
import mFISHwarp.dask_stitch

import ray

# OME-Zarr I/O used for the final multiscale export.
from ome_zarr.io import parse_url
from ome_zarr.writer import write_image

In [None]:
# --- Paths -------------------------------------------------------------------
# Working directory holding the intermediate registration results.
io_path = '/mnt/ampa02_data01/tmurakami/kidney/registration'

# Fixed and moving fused volumes. Both are N5 containers holding a
# multiresolution pyramid (levels s0, s1, ...).
fix_n5_path = '/mnt/ampa02_data01/tmurakami/kidney/240614_kidney_finescan/fused/fused.n5'
mov_n5_path = '/mnt/ampa02_data01/tmurakami/kidney/240709_kidney-2nd-round-finescan/fused/fused.n5'

# Block-wise displacement field (stored with overlaps) from the mid-resolution step.
displacement_overlap_path = os.path.join(io_path, 'displacements_overlap.zarr')
# Mid-resolution transformed image; only its chunking is used, to recover the overlap.
mov_transformed_path = os.path.join(io_path, 'transformed_midres.zarr')

# Destination of the warped full-resolution moving image.
moved_path = os.path.join(io_path, 'R02ch561_to_R01ch488.zarr')

# --- Open stores (read-only) -------------------------------------------------
fix_zarr, mov_zarr = (
    zarr.open(store=zarr.N5Store(n5_path), mode='r')
    for n5_path in (fix_n5_path, mov_n5_path)
)
displacement_overlap_zarr, mov_transformed_zarr = (
    zarr.open(zarr_path, mode='r')
    for zarr_path in (displacement_overlap_path, mov_transformed_path)
)

# --- Wrap as dask arrays -----------------------------------------------------
# Keep the stored spatial chunking, but hold the trailing axis (the displacement
# vector components, presumably 3) in a single chunk.
displacement_overlap = da.from_zarr(
    displacement_overlap_zarr,
    chunks=(*displacement_overlap_zarr.chunks[:-1], 3),
)
mov_transformed = da.from_zarr(mov_transformed_zarr)

# --- Image levels ------------------------------------------------------------
fix_pyramid = fix_zarr['setup3']['timepoint0']
fix_l = fix_pyramid['s1']  # coarser pyramid level; NOTE(review): not used in the visible cells
fix = fix_pyramid['s0']    # full resolution
mov = mov_zarr['setup2']['timepoint0']['s0']
fix_da = da.from_zarr(fix)

In [None]:
# Recover the overlap used when the displacement field was computed. Each padded
# displacement block is the core block (a chunk of the mid-res transformed image)
# plus a margin on both sides, so margin = (padded - core) // 2 per axis.
# NOTE(review): this could be read from the zarr attributes instead of re-derived.
def _half_margin(padded_chunks, core_chunks):
    """Per-axis ``(padded - core) // 2`` for two chunk-size sequences."""
    return tuple((padded - core) // 2 for padded, core in zip(padded_chunks, core_chunks))

blocksize = mFISHwarp.utils.chunks_from_dask(mov_transformed)
overlap = _half_margin(mFISHwarp.utils.chunks_from_dask(displacement_overlap)[:-1], blocksize)

# Trim the outer part of every overlap to remove the erroneous outer edge;
# only the remaining inner part ("suboverlap") is used to fuse the displacement.
trimming_factor = 0.75
trimming_range = tuple(round(margin * trimming_factor) for margin in overlap)
trimmed_displacement_overlap = da.overlap.trim_overlap(
    displacement_overlap, trimming_range + (0,), boundary='reflect'
)
suboverlap = _half_margin(
    mFISHwarp.utils.chunks_from_dask(trimmed_displacement_overlap)[:-1], blocksize
)

# Stitch the blocks into one displacement field. Its shape is an integer multiple
# of the chunk size; the full chunk layout (plus the vector axis) must be passed.
displacement = mFISHwarp.dask_stitch.stitch_blocks(
    trimmed_displacement_overlap,
    blocksize,
    suboverlap,
    mov_transformed.chunks + (3,),
)

In [None]:
# Output chunking shared by the upsampled displacement and the fixed image.
# Too small a chunk makes the calculation very slow.
out_chunk_size = (256, 256, 256)

# Integer upsampling factor per spatial axis (zyx): full-resolution fixed shape
# divided by the stitched displacement grid shape, rounded to the nearest integer.
upsample_ratio = np.asarray(fix.shape) / np.asarray(displacement.shape[:-1])
rescale_constant = tuple(upsample_ratio.round().astype(int))

# Prepare the (dask) upsampled displacement field, chunked like the output.
upsampled_displacement = mFISHwarp.transform.upscale_displacement_gpu(
    displacement,
    rescale_constant,
    out_chunk_size=out_chunk_size,
)

# Give the fixed image the same chunk grid, so block indices of the two arrays line up.
fix_da = fix_da.rechunk(out_chunk_size)

In [None]:
# Open (or create) the output group. The full-resolution result goes into
# dataset '0', chunked like `fix_da`, so every warp task writes whole chunks.
root = zarr.open_group(moved_path, mode='a')

# Create the dataset only on the first run. Later runs reuse it, so chunks that
# were already written are kept and the launch loop can skip them (resume).
# NOTE(review): if the zarr version does not store chunks equal to the fill
# value, such chunks will be recomputed on resume; this is harmless but slower.
if '0' not in root:
    root.create_dataset(
        '0',
        shape=fix.shape,
        chunks=out_chunk_size,
        dtype=fix.dtype,
    )

In [None]:
# Place the (lazy) upsampled displacement in Ray's object store once; the tasks
# receive it through this reference.
upsampled_displacement_id = ray.put(upsampled_displacement)

# Per-axis chunk sizes of the output grid, used by `warp_block` to locate its block.
chunk_info = fix_da.chunks

In [None]:
@ray.remote(num_gpus=0.2)
def warp_block(index, upsampled_displacement):
    """Warp one output block of the moving image and store it in ``root['0']``.

    Runs as a Ray task reserving 0.2 of a GPU. It reads the globals
    ``chunk_info``, ``mov`` and ``root``.

    Args:
        index: block index (tuple of ints) into the output chunk grid ``chunk_info``.
        upsampled_displacement: dask array of the upsampled displacement field,
            expected to be chunked on the same spatial grid as ``chunk_info`` with
            a trailing vector axis (presumably 3 components; TODO confirm).
    """
    # Start offset along each axis = total length of the chunks before this block.
    offsets, sizes = [], []
    for axis_chunks, block in zip(chunk_info, index):
        offsets.append(sum(axis_chunks[:block]))
        sizes.append(axis_chunks[block])
    chunk_shape = tuple(sizes)
    slicer = tuple(slice(start, start + size) for start, size in zip(offsets, sizes))

    # Edge chunks do not always have the same shape as the target slice, so the
    # displacement block is padded/trimmed to match it exactly.
    disp = mFISHwarp.transform.pad_trim_array_to_size(
        upsampled_displacement.blocks[index],
        chunk_shape + (3,),
        mode='edge',
    )

    # Warp the moving image with this block's displacement and write the result.
    root['0'][slicer] = mFISHwarp.transform.transform_block_gpu(
        disp, mov, size_limit=1024 * 1024 * 1024
    )

In [None]:
# Launch one GPU task per output chunk. The upsampled block handled by each task
# should fit in GPU RAM. Chunks already on disk are skipped, so an interrupted
# run can be resumed.
index_list = list(np.ndindex(*fix_da.numblocks))
# index_list.reverse()  # optionally process the blocks in reverse order

# List the output directory once rather than once per block. Chunk files are
# expected to be named 'i.j.k' after their block index.
existing_files = set(os.listdir(os.path.join(moved_path, '0')))
futures = [
    warp_block.remote(index, upsampled_displacement_id)
    for index in index_list
    if '.'.join(str(i) for i in index) not in existing_files
]

# Wait for every task. ray.get re-raises a task's exception instead of silently
# dropping it, which would happen if the ObjectRefs were discarded.
ray.get(futures);

## Save images in OME-Zarr format

In [None]:
### For transformed Zarr
# Registered channels to combine, in output channel order.
# NOTE(review): these paths point to the 240417_whole_4color_1st_M037-3pb dataset,
# not to `moved_path` written above. Confirm this is intended.
zarr_paths = [
    '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/registration/R02ch785_to_R01ch488.zarr',
    '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/registration/R02ch640_to_R01ch488.zarr',
    '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/registration/R02ch561_to_R01ch488.zarr',
    '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/registration/R02ch488_to_R01ch488.zarr'
]

resolution = '0'  # dataset (resolution level) read from each group
voxel_ratio = (1.0, 2.0, 1.3, 1.3)  # per-axis scale multipliers, czyx

# Stack the channels lazily into one (c, z, y, x) dask array.
imgs = da.stack([
    da.from_zarr(zarr.open(channel_path, mode='r')[resolution])
    for channel_path in zarr_paths
])

In [None]:
# Write the stacked channels as a multiscale OME-Zarr image. Documentation:
# https://ome-zarr.readthedocs.io/en/stable/python.html#reading-ome-ngff-images

path = "/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/registration/round02_2.zarr"
# os.mkdir raises FileExistsError if the output already exists, so re-running this
# cell will not overwrite an earlier export. Pick a new path for a new export.
os.mkdir(path)

root = zarr.group(store=parse_url(path, mode="w").store)

# Write the pyramid. Downsampling is limited to xy (no downsampling in z);
# see https://github.com/ome/ome-zarr-py/issues/262 .
write_image(
    image=imgs,
    group=root,
    axes="czyx",
    storage_options=dict(chunks=(1, 256, 256, 256)),
)

In [None]:
# Rewrite the multiscale metadata to record the anisotropic voxel size. This is
# optional. Each dataset's scale is multiplied element-wise by `voxel_ratio` (czyx).
# NOTE: not idempotent. Re-running this cell multiplies the scales again.
my_attr = root.attrs["multiscales"].copy()

for dataset in my_attr[0]['datasets']:
    transform = dataset['coordinateTransformations'][0]
    transform['scale'] = [
        axis_scale * ratio for axis_scale, ratio in zip(transform['scale'], voxel_ratio)
    ]

# Remove the old entry, then store the updated one.
del root.attrs["multiscales"]
root.attrs["multiscales"] = my_attr

In [None]:
# Open the exported OME-Zarr in napari for visual inspection.
viewer = napari.Viewer()
viewer.open(path, plugin="napari-ome-zarr")

# napari.run()  # starts the event loop; needed when running as a plain script rather than in Jupyter