In [None]:
import os
# Enumerate GPUs in PCI bus order and expose only GPU 0 to this process.
# These are set before the libraries below are imported so that any CUDA
# initialisation they (or Ray, started later) perform sees only this device.
# NOTE(review): confirm that Ray's GPU detection honours this in your Ray version.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import numpy as np
# NOTE(review): napari is not referenced in the cells shown — remove if unused.
import napari

# NOTE(review): dask_image and the bare `dask` module are not referenced in
# the cells shown (only `da` is used) — candidates for removal.
import dask_image
import dask.array as da
import dask
import zarr

# Stitching of overlapping displacement blocks (stitch_blocks).
import dask_stitch.stitch

# Displacement upsampling / block warping (transform) and chunk helpers (utils).
import mFISHwarp.transform
import mFISHwarp.utils

# Used to dispatch the per-block warping as GPU tasks.
import ray

In [None]:
# --- Input / output locations ----------------------------------------------
# `io_path` is a placeholder: point it at the directory that holds the
# registration outputs before running the notebook.
io_path = 'path_to_directory'
fix_zarr_path = '/mnt/ampa_data01/tmurakami/brain01/ch488.zarr'  # multi-resolution (pyramid) zarr
mov_zarr_path = '/mnt/ampa_data01/tmurakami/brain02/ch488.zarr'
displacement_overlap_path = os.path.join(io_path, 'displacements_overlap.zarr')

# Destination of the warped moving image.
moved_path = os.path.join(io_path, 'R02ch488_to_R01ch488.zarr')

# --- Open every input read-only ----------------------------------------------
fix_zarr = zarr.open(fix_zarr_path, mode='r')
mov_zarr = zarr.open(mov_zarr_path, mode='r')
displacement_overlap_zarr = zarr.open(displacement_overlap_path, mode='r')

# Keep the stored spatial chunking but put the last axis (presumably the
# 3 displacement components) into a single chunk of size 3.
displacement_chunks = displacement_overlap_zarr.chunks[:-1] + (3,)
displacement_overlap = da.from_zarr(displacement_overlap_zarr, chunks=displacement_chunks)

# Pyramid levels: '0' is the level that gets warped, '1' is the level whose
# chunking defines the displacement blocks.
fix = fix_zarr['0']
fix_l = fix_zarr['1']
mov = mov_zarr['0']
fix_da = da.from_zarr(fix)

# Each displacement block is a level-1 chunk padded by `overlap` voxels on
# every side, so half the size difference is the per-side overlap.
blocksize = fix_l.chunks
displacement_block_shape = mFISHwarp.utils.chunks_from_dask(displacement_overlap)[:-1]
overlap = tuple((padded - core) // 2 for padded, core in zip(displacement_block_shape, blocksize))

In [None]:
# The outer rim of each overlapping displacement block is unreliable, so cut
# away `trimming_factor` of the overlap on every side before stitching.
trimming_factor = 0.75
trimming_range = tuple(round(depth * trimming_factor) for depth in overlap)

# Depth 0 on the last axis: the vector-component axis is never trimmed.
trimmed_displacement_overlap = da.overlap.trim_overlap(
    displacement_overlap,
    trimming_range + (0,),
    boundary=None,
)

# Overlap that is left on each side after trimming, per spatial axis.
trimmed_block_shape = mFISHwarp.utils.chunks_from_dask(trimmed_displacement_overlap)[:-1]
suboverlap = tuple((padded - core) // 2 for padded, core in zip(trimmed_block_shape, fix_l.chunks))

# Stitch the overlapping blocks into a single displacement field.
displacement = dask_stitch.stitch.stitch_blocks(trimmed_displacement_overlap, blocksize, suboverlap)

In [None]:
# Integer upsampling factor per spatial axis (z, y, x): the full-size fixed
# image shape divided by the stitched displacement-field shape.
fixed_shape = np.array(fix.shape)
field_shape = np.array(displacement.shape[:-1])
rescale_constant = tuple((fixed_shape / field_shape).round().astype(int))

# Prepare the upsampled displacement as a dask array whose blocks match the
# chunking of the output zarr.
out_chunk_size = fix.chunks
upsampled_displacement = mFISHwarp.transform.upscale_displacement_gpu(
    displacement,
    rescale_constant,
    out_chunk_size=out_chunk_size,
)

In [None]:
# Open (or create) the output group and its level-0 array, laid out exactly
# like the fixed image so every warped block maps onto one output chunk.
# require_dataset (instead of create_dataset) keeps this cell re-runnable:
# with mode='a' an existing '0' is reused, whereas create_dataset raises.
# It still raises if the stored shape/dtype are incompatible with the fixed image.
root = zarr.open_group(moved_path, mode='a')
root.require_dataset(
    '0',
    shape=fix.shape,
    chunks=fix.chunks,
    dtype=fix.dtype,
)

In [None]:
# Chunk layout of the output array; warp_block uses it to turn a block index
# into the matching region of root['0'].
chunk_info = fix_da.chunks

# Start Ray explicitly (no-op if it is already running) rather than relying
# on implicit auto-initialisation, which older Ray versions do not perform.
if not ray.is_initialized():
    ray.init()

# Put the displacement graph in the object store once and pass the reference
# to every task, instead of re-serialising the same argument for each call.
upsampled_displacement_id = ray.put(upsampled_displacement)

In [None]:
@ray.remote(num_gpus=0.5)
def warp_block(index, upsampled_displacement):
    """Warp one output block of the moving image and store it in root['0'].

    Runs as a Ray task that requests half a GPU, so two blocks can share a
    device. `index` is a block index in the chunk grid described by the
    module-level `chunk_info`. The matching block of `upsampled_displacement`
    is used to warp `mov`, and the result is written into the region of
    root['0'] that the block covers. Reads the module-level globals
    `chunk_info`, `mov` and `root`; returns None.
    """
    # Translate the block index into absolute array coordinates.
    region = []
    for dim_chunks, block_idx in zip(chunk_info, index):
        start = sum(dim_chunks[:block_idx])
        region.append(slice(start, start + dim_chunks[block_idx]))

    block_displacement = upsampled_displacement.blocks[index]
    warped = mFISHwarp.transform.transform_block_gpu(
        block_displacement,
        mov,
        size_limit=1024 ** 3,
    )
    root['0'][tuple(region)] = warped

In [None]:
# Launch one Ray task per output block (chunk grid of the fixed image).
# Chunks must be small enough for the per-block upsampling to fit on the GPU.
warp_refs = [
    warp_block.remote(index, upsampled_displacement_id)
    for index in np.ndindex(*fix_da.numblocks)
]

# `.remote()` returns immediately, so wait for every task here. Otherwise the
# cell finishes while blocks are still being written, later cells (e.g. the
# pyramid) read a partially written '0', and a task failure never surfaces.
# ray.get raises if any task failed.
ray.get(warp_refs);

## Optional: save a multi-resolution pyramid

In [None]:
# optional. Save pyramid resolution
from skimage.transform import downscale_local_mean
def rescale_chunk(chunks, rescale_constant):
    """Return the dask chunks of an array after block-wise downsampling.

    Parameters
    ----------
    chunks : tuple of tuple of int
        Dask-style chunks: one tuple of block sizes per axis (any number of axes).
    rescale_constant : sequence of int
        Downsampling factor per axis; must provide at least one factor per
        axis of `chunks` (extra trailing factors are ignored).

    Returns
    -------
    tuple of tuple of int
        ceil(block_size / factor) for every block on every axis, i.e. the block
        sizes produced when each block is downsampled independently with
        partial windows padded.

    Raises
    ------
    ValueError
        If `rescale_constant` has fewer entries than `chunks` has axes.
    """
    if len(rescale_constant) < len(chunks):
        raise ValueError(
            f"need one rescale factor per axis: got {len(rescale_constant)} "
            f"factors for {len(chunks)} axes"
        )
    # Previously hard-coded to 3 axes; iterate over however many axes exist,
    # and return plain Python ints rather than numpy integers.
    return tuple(
        tuple(int(size) for size in np.ceil(np.asarray(axis_chunks) / factor))
        for axis_chunks, factor in zip(chunks, rescale_constant)
    )

def _downscale_block(block, factors, out_dtype):
    """Local-mean downsample one block and return it with dtype `out_dtype`.

    downscale_local_mean returns floating-point means. Without an explicit
    cast, the blocks would disagree with the dtype declared to dask and would
    be cast implicitly (truncating) when written to an integer zarr array.
    Integer outputs are rounded to the nearest value instead.
    """
    reduced = downscale_local_mean(block, factors)
    if np.issubdtype(out_dtype, np.integer):
        reduced = np.rint(reduced)
    return reduced.astype(out_dtype, copy=False)


def save_pyramid(file_name_base, downscale_constants, iteration=5, overwrite=False):
    """Write `iteration` successively downsampled levels into a zarr group.

    Level i+1 is produced from level i (read back from disk) by block-wise
    local-mean downsampling with `downscale_constants`, keeping level i's dtype,
    and is written to `<file_name_base>/<i+1>`.

    Parameters
    ----------
    file_name_base : str
        Path of the zarr group that already contains level '0'.
    downscale_constants : tuple of int
        Per-axis downsampling factor.
    iteration : int
        Number of levels to write (levels '1' .. str(iteration)).
    overwrite : bool
        Forwarded to dask.array.to_zarr. Leave False to fail if a level already
        exists; set True to regenerate levels left over from an earlier run.
    """
    for level in range(iteration):
        img = da.from_zarr(os.path.join(file_name_base, str(level)))
        down_img = da.map_blocks(
            _downscale_block,
            img,
            downscale_constants,
            out_dtype=img.dtype,
            dtype=img.dtype,
            chunks=rescale_chunk(img.chunks, downscale_constants),
        )
        da.to_zarr(down_img, os.path.join(file_name_base, str(level + 1)), overwrite=overwrite)
        print('done:' + str(level + 1))

In [None]:
# Halve z, y and x at every step, writing levels '1' through '5'.
downscale_constants = (2, 2, 2)
save_pyramid(moved_path, downscale_constants, iteration=5)