In [1]:
import numpy as np
import time, os, sys
# Number GPUs by PCI bus id and expose only device "0" to this process.
# NOTE(review): presumably placed before the remaining imports so the setting is
# in effect before any CUDA-using library initialises — confirm.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import skimage.io
import zarr

import my_distributed_segmentation

import dask
from dask import array as da
import dask_image.imread

import mFISH3D.segment_utils

In [2]:
from dask.distributed import Client

# Start a local distributed cluster with a single one-threaded worker and pin
# the dashboard to port 8787.
# Reference: https://docs.dask.org/en/latest/how-to/deploy-dask/single-distributed.html
client = Client(
    n_workers=1,
    threads_per_worker=1,
    dashboard_address='localhost:8787',
)
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 1
Total threads: 1,Total memory: 503.52 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:38883,Workers: 1
Dashboard: http://127.0.0.1:8787/status,Total threads: 1
Started: Just now,Total memory: 503.52 GiB

0,1
Comm: tcp://127.0.0.1:45943,Total threads: 1
Dashboard: http://127.0.0.1:41763/status,Memory: 503.52 GiB
Nanny: tcp://127.0.0.1:45161,
Local directory: /tmp/dask-worker-space/worker-5npgmm0u,Local directory: /tmp/dask-worker-space/worker-5npgmm0u


In [3]:
# Open the stored segmentation as a dask array. Nothing is subset here: the
# whole zarr store found under `outprefix` is referenced.
outprefix = '/mnt/ampa_data01/tmurakami/220715_prefrontal_q2_R01/R02_R01/R02ch640_to_R01_segmentation'
segment_overlap = da.from_zarr(
    os.path.join(outprefix, 'segmented_overlap.zarr')
)

In [4]:
# Strip (16, 32, 32) elements of overlap from the blocks of the stored array.
# NOTE: this cell rebinds `segment_overlap` to its own result, so running it a
# second time without re-running the loading cell trims the array again.
segment_overlap = da.overlap.trim_overlap(
    segment_overlap,
    depth=(16, 32, 32),
    boundary='reflect',
)
segment_overlap

Unnamed: 0,Array,Chunk
Bytes,773.22 GiB,364.50 MiB
Shape,"(1440, 14720, 9792)","(288, 576, 576)"
Count,6631 Tasks,2210 Chunks
Type,int32,numpy.ndarray
"Array Chunk Bytes 773.22 GiB 364.50 MiB Shape (1440, 14720, 9792) (288, 576, 576) Count 6631 Tasks 2210 Chunks Type int32 numpy.ndarray",9792  14720  1440,

Unnamed: 0,Array,Chunk
Bytes,773.22 GiB,364.50 MiB
Shape,"(1440, 14720, 9792)","(288, 576, 576)"
Count,6631 Tasks,2210 Chunks
Type,int32,numpy.ndarray


In [5]:
# Stitch the per-block segmentations into one label volume:
#   1. shift each block's labels by the number of labels already assigned to
#      earlier blocks, so ids are unique across the whole array;
#   2. trim the overlap down to `iou_depth` and hand the result to
#      `link_labels` together with the total label count;
#   3. trim the remaining `iou_depth` of overlap.
diameter = (16, 32, 32)
depth = tuple(np.ceil(diameter).astype(np.int64))
boundary = "reflect"
iou_depth = 2
iou_threshold = 0.7

# `get_block_iter` yields (block index, dask block) pairs, as unpacked below.
block_iter = mFISH3D.segment_utils.get_block_iter(segment_overlap)

labeled_blocks = np.empty(segment_overlap.numblocks, dtype=object)

# Number of label ids consumed by the blocks visited so far. Starts at zero so
# the first block keeps its ids unchanged.
total = np.int32(0)
for index, input_block in block_iter:
    # The largest label in the block is used as the block's label count.
    # NOTE(review): assumes labels inside a block run 1..max with 0 as
    # background — confirm against the segmentation step.
    n = dask.delayed(np.int32)(input_block.max())
    n = da.from_delayed(n, shape=(), dtype=np.int32)

    # Shift foreground labels past every id used by earlier blocks; background
    # (0) is left untouched.
    block_label_offset = da.where(input_block > 0, total, np.int32(0))
    labeled_block = input_block + block_label_offset
    labeled_blocks[index] = labeled_block

    # Count this block exactly once. Adding `n` both before and after the
    # offset (as this loop used to) doubled the id range and the `total`
    # passed to `link_labels`, without making ids any more unique.
    total = total + n

# Put all the blocks together
block_labeled = da.block(labeled_blocks.tolist())

# Normalise `depth` to the {axis: depth} mapping used by dask's overlap helpers.
depth = da.overlap.coerce_depth(len(depth), depth)

if np.prod(block_labeled.numblocks) > 1:
    iou_depth = da.overlap.coerce_depth(len(depth), iou_depth)

    # `DistSegError` is not defined in this notebook, so raise a built-in
    # exception instead of failing with a NameError.
    if any(iou_depth[ax] > depth[ax] for ax in depth.keys()):
        raise ValueError("iou_depth (%s) > depth (%s)" % (iou_depth, depth))

    # Keep only `iou_depth` of overlap per side for the linking step.
    trim_depth = {k: depth[k] - iou_depth[k] for k in depth.keys()}
    block_labeled = da.overlap.trim_internal(
        block_labeled, trim_depth, boundary=boundary
    )
    block_labeled = my_distributed_segmentation.link_labels(
        block_labeled,
        total,
        iou_depth,
        iou_threshold=iou_threshold,
    )

    # Drop the overlap that was kept for linking.
    block_labeled = da.overlap.trim_internal(
        block_labeled, iou_depth, boundary=boundary
    )

else:
    # Single block: nothing to link, just remove the full overlap.
    block_labeled = da.overlap.trim_internal(
        block_labeled, depth, boundary=boundary
    )


In [6]:
# Write the linked label volume to a zarr store under `outprefix`; this call
# executes the graph built in the previous cell.
label_zarr = os.path.join(outprefix, 'segmented_temp.zarr')
da.to_zarr(block_labeled, label_zarr)

In [7]:
# For quality checking the labels are re-saved as a resolution pyramid.
# Re-open the saved label volume and switch to 256^3 chunks first.
# To inspect another result, point `label_zarr` at it instead, e.g.:
# label_zarr = '/mnt/ampa_data01/tmurakami/220310_0004_R01/R02_R01/R02ch561_to_R01_segmentation/segmented_fin.zarr'
label = da.from_zarr(label_zarr).rechunk((256, 256, 256))

In [8]:
def save_pyramid(img, file_name_base, iteration=5):
    """
    Write `img` as a multi-resolution pyramid of zarr arrays.

    Level 0 (full resolution) is stored at ``<file_name_base>/0``. Every
    further level ``k`` is built by re-opening level ``k - 1`` from disk,
    keeping every second element along each of the three axes, and storing
    the result at ``<file_name_base>/<k>``.

    img: dask array, indexed with three axes
    file_name_base: path under which the numbered levels are written
    iteration: number of downsampled levels written after level 0
    """
    def level_path(level):
        # Location of one pyramid level inside the destination.
        return os.path.join(file_name_base, str(level))

    # Full-resolution level.
    da.to_zarr(img, level_path(0))

    # Each coarser level is derived from the previous one as stored on disk.
    for level in range(1, iteration + 1):
        previous = da.from_zarr(level_path(level - 1))
        da.to_zarr(previous[::2, ::2, ::2], level_path(level))
        print('done:' + str(level))

In [9]:
# Create the destination group (mode 'w' replaces any existing content at
# that path), then write the label pyramid into it.
dest_name = os.path.join(outprefix, 'segmented_pyramid.zarr')
root = zarr.open_group(dest_name, mode='w')
save_pyramid(label, dest_name)

done:1
done:2
done:3
done:4
done:5


In [10]:
#