In [None]:
import numpy as np
import os

import zarr

import dask
from dask import array as da

import mFISHwarp.utils
import mFISHwarp.distributed_segmentation

In [None]:
from dask.distributed import Client
# Start a local dask.distributed client with one single-threaded worker; the dashboard is
# requested at localhost:8787.
# https://docs.dask.org/en/latest/how-to/deploy-dask/single-distributed.html
client = Client(
    n_workers=1,
    threads_per_worker=1,
    dashboard_address='localhost:8787',
)
client

In [None]:
# Locate the segmentation results: define where the fused label volume will be written
# and open the segmentation stored in segmented_overlap.zarr as a dask array.
outprefix = '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/segmentation'
label_zarr = os.path.join(outprefix, 'segmented_fused.zarr')
segment_overlap = da.from_zarr(os.path.join(outprefix, 'segmented_overlap.zarr'))

In [None]:
# If the stored overlap is (32,64,64): trim (16,32,32) of it here and leave the remaining
# (16,32,32) for the fusing step, which uses that as its diameter.
# NOTE: this cell rebinds segment_overlap, so running it a second time trims again;
# reload segment_overlap from zarr before re-running.
segment_overlap = da.overlap.trim_overlap(
    segment_overlap,
    depth=(16, 32, 32),
    boundary='reflect',
)
segment_overlap

In [None]:
# Parameters for fusing labels across block borders.
diameter = (16, 32, 32)
boundary = "reflect"
iou_depth = 2
iou_threshold = 0.7

# Per-axis overlap depth: the diameter rounded up to integers.
depth = tuple(np.ceil(diameter).astype(np.int64))

# Iterator of (block index, block) pairs consumed by the offset loop below.
block_iter = mFISHwarp.utils.get_block_iter(segment_overlap)

# One slot per block; each slot receives that block's labels after the offset is applied,
# and the grid is reassembled into a single array at the end.
labeled_blocks = np.empty(segment_overlap.numblocks, dtype=object)

# Running sum of the per-block label maxima (a 0-d dask array after the first iteration).
total = None
for index, input_block in block_iter:
    labeled_block = input_block #.astype(np.int64)
    # Largest label id in this block, used as the amount of label space it needs.
    n = input_block.max()

    # Wrap the maximum as an int32 scalar dask array.
    n = dask.delayed(np.int32)(n)
    n = da.from_delayed(n, shape=(), dtype=np.int32)

    total = n if total is None else total + n

    # Shift only foreground voxels (label > 0) by the running total; background stays 0.
    block_label_offset = da.where(labeled_block > 0, total, np.int32(0))
    labeled_block += block_label_offset

    labeled_blocks[index] = labeled_block
    # NOTE(review): total is advanced by n both before the offset is taken and again here,
    # so each block consumes 2*n ids. The offsets still do not collide, but `total`
    # (passed to link_labels further down) is about twice the sum of the block maxima --
    # confirm this is intended before changing it.
    total += n

# Put all the blocks together
block_labeled = da.block(labeled_blocks.tolist())

# Normalise the per-axis overlap into the {axis: depth} mapping used by the checks below.
depth = da.overlap.coerce_depth(len(depth), depth)

if np.prod(block_labeled.numblocks) > 1:
    iou_depth = da.overlap.coerce_depth(len(depth), iou_depth)

    if any(iou_depth[ax] > depth[ax] for ax in depth.keys()):
        # DistSegError is not defined or imported anywhere in this notebook, so raising it
        # would fail with a NameError instead of reporting the problem; use ValueError.
        raise ValueError("iou_depth (%s) > depth (%s)" % (iou_depth, depth))

    # First trim: remove everything except an iou_depth-wide overlap on each axis,
    # which is what link_labels is given to work with.
    trim_depth = {k: depth[k] - iou_depth[k] for k in depth.keys()}
    block_labeled = da.overlap.trim_internal(
        block_labeled, trim_depth, boundary=boundary
    )
    block_labeled = mFISHwarp.distributed_segmentation.link_labels(
        block_labeled,
        total,
        iou_depth,
        iou_threshold=iou_threshold,
    )

    # Second trim: remove the remaining iou_depth overlap.
    block_labeled = da.overlap.trim_internal(
        block_labeled, iou_depth, boundary=boundary
    )

else:
    # Single block: there is nothing to link, so only the overlap is removed.
    block_labeled = da.overlap.trim_internal(
        block_labeled, depth, boundary=boundary
    )

In [None]:
### Warning: this step needs a lot of RAM, and a lot of disk space for spilling.
# There is no way to mitigate this, because the re-labeling requires all the information at once.
# Unless we have a better format than a label image (such as YOLO-style JSON in 3D),
# we have to accept this awkward process.
da.to_zarr(block_labeled, label_zarr)

### Optional pyramid formatting

In [None]:
# The OME-NGFF format did not seem well suited to label-ID images (still to be tested),
# so the labels are simply kept as plain zarr.
chunk_size = (256, 256, 256)
label = da.from_zarr(label_zarr).rechunk(chunk_size)

In [None]:
def save_pyramid(img, file_name_base, iteration=5, chunk_size=(256,256,256), down_scale=2):
    """
    Write a multi-resolution pyramid of `img` as numbered zarr arrays.

    Level 0 is `img` rechunked to `chunk_size`. Each following level is made by
    re-opening the previous level from disk and keeping every `down_scale`-th
    element along every axis (plain striding, no averaging, so stored values are
    never blended).

    img: dask array
    file_name_base: directory under which the levels are written as '0', '1', ...
    iteration: number of downsampled levels written after level 0
    chunk_size: chunk shape passed to da.rechunk for every level
    down_scale: stride applied to every axis between consecutive levels
    """
    def level_path(level):
        # Location of one pyramid level below file_name_base.
        return os.path.join(file_name_base, str(level))

    # save initial resolution
    da.to_zarr(da.rechunk(img, chunk_size), level_path(0))

    # save downsampled resolution
    for i in range(iteration):
        # Re-open the level just written, so this step is computed from the stored array.
        img = da.from_zarr(level_path(i))
        # Stride every axis, whatever the dimensionality (previously hard-coded to 3 axes).
        stride = (slice(None, None, down_scale),) * img.ndim
        down_img = da.rechunk(img[stride], chunk_size)
        da.to_zarr(down_img, level_path(i + 1))
        print('done:' + str(i+1))

In [None]:
# Create the top-level zarr group for the pyramid (mode='w' replaces an existing one),
# then write the resolution levels beneath it.
dest_name = os.path.join(outprefix, 'segmented_pyramid.zarr')
root = zarr.open_group(dest_name, mode='w')
save_pyramid(img=label, file_name_base=dest_name)

In [None]:
#