In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0,1"

import numpy as np
import napari

import dask_image
import dask.array as da
import dask
import zarr

import mFISHwarp.transform
import mFISHwarp.utils
import mFISHwarp.dask_stitch
import mFISHwarp.zarr

import ray

from ome_zarr.writer import write_multiscales_metadata
from itertools import compress

### Lazily load images as dask arrays

In [2]:
# set paths
data_root = '/mnt/ampa02_data01/tmurakami/MK_administration'
io_path = os.path.join(data_root, '240620_03_MX002-1', 'registration_R03')
fix_n5_path = os.path.join(data_root, '240620_03_MX002-1', 'fused', 'fused.n5') # fixed (reference) image: N5 container with a resolution pyramid
mov_n5_path = os.path.join(data_root, '240726_03_MX002-3', 'fused', 'fused.n5') # moving image: N5 container, one group per setup

# displacement field computed on overlapping chunks (overlap size is stored in its attributes)
displacement_overlap_path = os.path.join(io_path,'displacements_overlap.zarr')

# set saving path
moved_path = os.path.join(io_path, 'R03_to_R01ch488.zarr')

# create Zarr file objects (read-only)
fix_zarr = zarr.open(store=zarr.N5Store(fix_n5_path), mode='r')
mov_zarr = zarr.open(store=zarr.N5Store(mov_n5_path), mode='r')
displacement_overlap_zarr = zarr.open(displacement_overlap_path, mode='r')

# zarr to dask: keep the spatial chunking, but hold the 3 vector components in a single chunk
displacement_overlap = da.from_zarr(displacement_overlap_zarr, chunks=displacement_overlap_zarr.chunks[:-1]+(3,))

# fixed (reference) image: level 's0' of the reference setup
fix_setup = 'setup3'
fix = fix_zarr[fix_setup]['timepoint0']['s0']
fix_da = da.from_zarr(fix)

# moving-image setups; each one becomes one channel of the output, in this order.
# zarr (v2) lists group keys lexicographically ('setup10' before 'setup2'); sorting by
# length first restores numeric order for 'setupN'-style names.
mov_n5_setups = sorted(mov_zarr.keys(), key=lambda k: (len(k), k))

### Make fused displacement field

In [3]:
# get the overlap (halo) size stored with the displacement field; the last (vector) axis has no overlap.
overlap = displacement_overlap_zarr.attrs['overlap_size'][:-1]
# displacement_nooverlap is used only to get chunk size information.
displacement_nooverlap = da.overlap.trim_overlap(displacement_overlap, tuple(overlap)+(0,), boundary='reflect')
blocksize = mFISHwarp.utils.chunks_from_dask(displacement_nooverlap)[:-1]

# trim the outside of the overlapping regions to remove the erroneous outer edge.
trimming_factor = 0.75
trimming_range = tuple(round(i*trimming_factor) for i in overlap)

# only the remaining sub-overlap area is used to fuse the displacement
trimmed_displacement_overlap = da.overlap.trim_overlap(displacement_overlap, trimming_range+(0,), boundary='reflect')
suboverlap = tuple((x-y)//2 for x,y in zip(mFISHwarp.utils.chunks_from_dask(trimmed_displacement_overlap)[:-1] ,blocksize))

# stitch displacement. Note the shape of the displacement is an integer multiple of the chunk size.
displacement = mFISHwarp.dask_stitch.stitch_blocks(
    trimmed_displacement_overlap,
    blocksize,
    suboverlap,
    displacement_nooverlap.chunks # need full chunk information
)

### upsample the displacement field because the displacement might be calculated using downsampled image
# set rescale factors
rescale_constant = tuple((np.array(fix.shape) / np.array(displacement.shape[:-1])).round().astype(int)) # upsampling ratio in zyx
out_chunk_size = (256,256,256) # too small makes the calculation very slow

# prepare dask array for upsampling
upsampled_displacement = mFISHwarp.transform.upscale_displacement_gpu(
    displacement,
    rescale_constant,
    out_chunk_size=out_chunk_size
)

# get chunk info to save in zarr
chunk_info = (da.rechunk(fix_da, out_chunk_size)).chunks

# The warping step pairs upsampled_displacement.blocks[index] with the output region taken
# from chunk_info[axis][index]; both grids must therefore have the same number of blocks
# per axis, otherwise output chunks are silently skipped or indexing fails inside a task.
n_disp_blocks = upsampled_displacement[...,0].numblocks
n_out_blocks = tuple(len(c) for c in chunk_info)
assert n_disp_blocks == n_out_blocks, (
    f"upsampled displacement has {n_disp_blocks} blocks but the output grid has {n_out_blocks}; "
    f"check rescale_constant={rescale_constant} and out_chunk_size={out_chunk_size}"
)

# get ray id
upsampled_displacement_id = ray.put(upsampled_displacement)

2024-08-12 09:03:30,334	INFO services.py:1374 -- View the Ray dashboard at http://127.0.0.1:8265


### Create the OME-Zarr store for the warped images

In [4]:
# OME-Zarr layout of the output: axes, per-axis scale and pyramid settings.
physical_scale = (1.0,2.0,1.3,1.3) # per-axis scale for (c, z, y, x); NOTE(review): units are not stated here — confirm
downscale_factor = (1,2,2,2) # per-axis downsampling between pyramid levels (channel axis is not downsampled)
pyramid_level = 5 # passed to both the metadata generator and the pyramid builder below
axes_info = ['c','z','y','x']

# Create the zarr group, or reopen it if it already exists (zarr.group does not overwrite by default).
# dimension_separator='/' stores chunks as nested directories, which is what is_chunk_missing checks.
store = zarr.DirectoryStore(moved_path, dimension_separator='/')
root = zarr.group(store=store)
# dataset for the highest resolution; the name should be '0' according to the ome-zarr spec.
# require_dataset reuses an existing compatible array, so re-running this cell resumes an
# interrupted run instead of failing (create_dataset raises when '0' already exists).
root.require_dataset('0', shape=(len(mov_n5_setups),)+fix.shape, chunks=(1,)+out_chunk_size, dtype=fix.dtype)

# write metadata for the zarr using ome-zarr library
datasets = mFISHwarp.zarr.datasets_metadata_generator(physical_scale, downscale_factor, pyramid_level)
write_multiscales_metadata(root, datasets=datasets, axes=axes_info)

### Warp images

In [5]:
# find out the missing chunks.
# this is to resume the process once the warping was stalled in the middle.
def is_chunk_missing(chunk_idx, store_path=None, array_name='0'):
    """Return True if the chunk file for `chunk_idx` does not exist on disk yet.

    The output store uses dimension_separator='/', so chunk (a, b, c, ...) is the
    file <store_path>/<array_name>/a/b/c/... .

    chunk_idx: tuple of integer chunk indices, one per array dimension.
    store_path: root directory of the zarr store; defaults to the global `moved_path`.
    array_name: array inside the store; defaults to '0' (the highest-resolution level).
    """
    if store_path is None:
        store_path = moved_path
    chunk_key = '/'.join(str(c) for c in chunk_idx)
    return not os.path.exists(os.path.join(store_path, array_name, chunk_key))

In [6]:
# Warp every moving-image setup onto the fixed grid, one Ray task per output chunk.
# Chunks already present on disk are skipped, so re-running this cell resumes the work.
index_list = list(np.ndindex(*upsampled_displacement[...,0].numblocks))

# Defined once (not per loop iteration): redefining a remote function re-exports it each time.
@ray.remote(num_gpus=0.25)
def warp_block(index, upsampled_displacement, mov, channel):
    """Warp one output chunk of `mov` and write it to root['0'][channel, <chunk region>].

    index: block index into the upsampled displacement; also selects the output region via chunk_info.
    upsampled_displacement: upsampled displacement dask array (Ray resolves the ObjectRef).
    mov: moving-image array for this channel (Ray resolves the ObjectRef).
    channel: position along the first (channel) axis of root['0'].
    """
    disp = upsampled_displacement.blocks[index]

    # chunk-aligned output region for `index`, computed from the per-axis chunk sizes
    starts = tuple(sum(sizes[:b]) for sizes, b in zip(chunk_info, index))
    chunk_shape = tuple(sizes[b] for sizes, b in zip(chunk_info, index))
    slicer = tuple(slice(start, start + size) for start, size in zip(starts, chunk_shape))

    # at the edge, displacement blocks do not always have the same shape as the output slice
    disp = mFISHwarp.transform.pad_trim_array_to_size(disp, chunk_shape+(3,), mode='edge')

    root['0'][(channel,)+slicer] = mFISHwarp.transform.transform_block_gpu(disp, mov, size_limit=1024*1024*1024)

for i, setup in enumerate(mov_n5_setups):
    mov = mov_zarr[setup]['timepoint0']['s0']
    mov_id = ray.put(mov)

    sub_index_list = [idx for idx in index_list if is_chunk_missing((i,)+idx)]
    tasks = [warp_block.remote(index, upsampled_displacement_id, mov_id, i) for index in sub_index_list]
    results = ray.get(tasks)

In [8]:
# Build the downsampled pyramid levels from the full-resolution array '0' written above;
# levels are written into the same group starting at resolution 1.
data = da.from_zarr(root['0'])
mFISHwarp.zarr.pyramid_from_dask_to_zarr(
    data,
    root,
    downscale_factor=downscale_factor,
    resolution_start=1,
    pyramid_level=pyramid_level,
    chunk=(1,) + out_chunk_size,
)

In [None]:
# Open the warped OME-Zarr in napari for visual inspection (uses the napari-ome-zarr plugin).
viewer = napari.Viewer()
viewer.open(moved_path, plugin="napari-ome-zarr")

# When running as a plain Python script (not IPython/Jupyter), start the event loop:
# napari.run()

In [7]:
#