In [None]:
import numpy as np
import napari
import zarr
from skimage.exposure import match_histograms
import dask.array as da
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0,1"

from ome_zarr.writer import write_multiscales_metadata
from ome_zarr.io import parse_url
from ome_zarr.reader import Reader
from tqdm import tqdm
import mFISHwarp.utils
import mFISHwarp.zarr

import ray

from cucim.skimage.morphology import white_tophat
from cucim.skimage.morphology import ball
import cupy as cp

In [None]:
# read the source
# image path (alternatives used for this sample):
# '/mnt/ampa02_data01/tmurakami/240425_whole_4color_2nd_M037-3pb/fused/fused.n5'
# '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/fused/fused.n5'
data_path = '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/fused/fused.n5'
resolution = 0  # index into the resolution levels listed by the container (0 = first listed level)
chunk_size = (1, 256,256,256)  # one channel per chunk; downstream per-block processing relies on this
physical_scale = (1.0,2.0,1.3,1.3)  # per-axis scale (c, z, y, x) written into the OME-Zarr metadata
# lazily load the data of the targeted resolution using dask
_, ext = os.path.splitext(data_path)

if ext == '.n5':
    # n5 assumes the BigStitcher (BigDataViewer) layout: <setup>/timepoint0/<resolution level>.
    # Every setup is loaded at the same level and stacked along a new leading (channel) axis.
    img_zarr = zarr.open(store=zarr.N5Store(data_path), mode='r')
    n5_setups = list(img_zarr.keys())
    # NOTE(review): level names are taken in the order the store lists them; if that order is
    # lexicographic, >= 10 levels (e.g. 's10' before 's2') would break `resolution` — confirm.
    res_list = list(img_zarr[n5_setups[0]]['timepoint0'].keys())
    imgs = da.stack([
        da.from_zarr(img_zarr[n5_setup]['timepoint0'][res_list[resolution]])
        for n5_setup in n5_setups
    ])

elif ext == '.zarr':
    # zarr assumes OME-Zarr. Nodes may include images, labels etc.; the first node is the
    # image pixel data, whose `.data` is indexed by resolution level.
    location = parse_url(data_path)
    if location is None:
        # parse_url returns None when nothing readable exists at the path; fail with a clear message
        raise FileNotFoundError(f"No readable OME-Zarr found at {data_path!r}")
    nodes = list(Reader(location)())
    imgs = nodes[0].data[resolution]

else:
    # Without this, `imgs` would stay undefined/empty and the rechunk below would fail obscurely.
    raise ValueError(f"Unsupported image format {ext!r} for {data_path!r}; expected '.n5' or '.zarr'")

# rechunk for analysis
imgs = imgs.rechunk(chunk_size)

In [None]:
save_zarr_path = '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/fused/top_hat_10.zarr'
downscale_factor_pyramid = (1,2,2,2)  # no downsampling along the channel axis
pyramid_level = 5
axes_info = ['c','z','y','x']

## create zarr to save the top-hat filtered image.
store = zarr.DirectoryStore(save_zarr_path, dimension_separator='/')
root = zarr.group(store=store)

# Full-resolution level '0' mirrors the input shape and uses the same chunking as `imgs`,
# so each dask block maps onto exactly one zarr chunk.
# uint16 keeps the data compact. A white top-hat never exceeds the input values, so this
# does not clip as long as the input is itself uint16 or narrower — TODO confirm input dtype.
# require_dataset reuses an existing compatible '0' array, so re-running this cell does not
# fail the way create_dataset would.
data_zarr = root.require_dataset('0', shape=imgs.shape, chunks=chunk_size, dtype=np.uint16, exact=True)

# prepare metadata to zarr
datasets = mFISHwarp.zarr.datasets_metadata_generator(physical_scale, downscale_factor=downscale_factor_pyramid, pyramid_level=pyramid_level)

### write metadata for ome-zarr
write_multiscales_metadata(root, datasets=datasets, axes=axes_info)

In [None]:
# Grid bookkeeping for the per-block GPU tasks: the chunk sizes along every axis
# (used to place a block in the output) and the coordinate of every block to process.
chunk_info = imgs.chunks
index_list = [block_idx for block_idx in np.ndindex(*imgs.numblocks)]

In [None]:
@ray.remote(num_gpus=0.1)
def top_hat_gpu(index, ballsize=10):
    """White top-hat filter one dask block of `imgs` on the GPU and store it in `root['0']`.

    Runs as a Ray task and reads the notebook globals `imgs`, `chunk_info` and `root`
    (presumably serialized together with the function by Ray — the driver's values are used).

    Parameters
    ----------
    index : tuple of int
        Block coordinate in the `imgs` chunk grid (one entry of `np.ndindex(*imgs.numblocks)`).
    ballsize : int
        Radius passed to `ball()` to build the 3-D structuring element.

    Returns
    -------
    None. The filtered block is written into `root['0']` at the block's location.
    """
    block = imgs.blocks[index].compute()
    # Drop only the leading channel axis (size 1 because chunks are (1, z, y, x)).
    # A bare .squeeze() would also drop a spatial axis of an edge block whose extent is 1,
    # so the image would no longer match the 3-D footprint or the target slice below.
    img_cp = cp.asarray(block.squeeze(axis=0))
    footprint = ball(ballsize)

    res = cp.asnumpy(white_tophat(img_cp, footprint))

    # Release GPU buffers before returning so many small tasks sharing a GPU do not pile up memory.
    del footprint, img_cp
    cp.get_default_memory_pool().free_all_blocks()

    # Absolute start/stop of this block along each axis, from the per-axis chunk sizes.
    slicer = tuple(
        slice(sum(axis_chunks[:pos]), sum(axis_chunks[:pos]) + axis_chunks[pos])
        for axis_chunks, pos in zip(chunk_info, index)
    )

    # Restore the channel axis removed above so the shape matches the target region.
    root['0'][slicer] = res[np.newaxis, ...]

In [None]:
# Submit one Ray task per block, then wait on ray.get until all of them have finished.
# top_hat_gpu has no return value, so `results` is just a list of None.
tasks = [top_hat_gpu.remote(index) for index in index_list]
results = ray.get(tasks)

In [None]:
# make pyramid images: read the full-resolution top-hat result back lazily and write the
# lower-resolution levels (starting at level 1) into the same group, reusing its chunking.
data = da.from_zarr(root['0'])
data_chunks = mFISHwarp.utils.chunks_from_dask(data)
mFISHwarp.zarr.pyramid_from_dask_to_zarr(
    data,
    root,
    downscale_factor=downscale_factor_pyramid,
    resolution_start=1,
    pyramid_level=pyramid_level,
    chunk=data_chunks,
)
# alternative (unused): pyramid = mFISHwarp.zarr.pyramid_generator_from_dask(data, downscale_factor=downscale_factor_pyramid, pyramid_level=5, chunk=mFISHwarp.utils.chunks_from_dask(data))

In [None]:
# for resolution in range(pyramid_level):
#     if resolution == 0:
#         pass
#     else:
#         arr = pyramid[resolution]
#         p = root.create_dataset(str(resolution),shape=arr.shape,chunks=mFISHwarp.utils.chunks_from_dask(data),dtype=arr.dtype)
#         arr.to_zarr(p,dimension_separator='/')