In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0,1"

import numpy as np
import napari

import dask_image
import dask.array as da
import dask
import zarr

# import dask_stitch.stitch

import mFISHwarp.transform
import mFISHwarp.utils
import mFISHwarp.dask_stitck

import ray

from ome_zarr.io import parse_url
from ome_zarr.writer import write_image
from ome_zarr.reader import Reader

In [2]:
# ---- input locations ----
io_path = '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/registration'
# fused N5 containers of the fixed and the moving round (zarr with pyramid resolution)
fix_n5_path = '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/fused/fused.n5'
mov_n5_path = '/mnt/ampa02_data01/tmurakami/240425_whole_4color_2nd_M037-3pb/fused/fused.n5'
displacement_overlap_path = os.path.join(io_path, 'displacements_overlap.zarr')
# opened only to recall the overlap (its chunking is read in a later cell)
mov_transformed_path = os.path.join(io_path, 'transformed_midres.zarr')

# ---- output location ----
moved_path = os.path.join(io_path, 'R02ch561_to_R01ch488.zarr')

# ---- open every container read-only ----
fix_zarr = zarr.open(store=zarr.N5Store(fix_n5_path), mode='r')
mov_zarr = zarr.open(store=zarr.N5Store(mov_n5_path), mode='r')
displacement_overlap_zarr = zarr.open(displacement_overlap_path, mode='r')
mov_transformed_zarr = zarr.open(mov_transformed_path, mode='r')

# ---- image datasets, kept as zarr arrays ----
fix = fix_zarr['setup3']['timepoint0']['s0']
fix_l = fix_zarr['setup3']['timepoint0']['s1']
mov = mov_zarr['setup2']['timepoint0']['s0']

# ---- lazy dask views ----
# The displacement keeps its on-disk chunking on every axis but the last,
# whose chunk length is set to 3.
displacement_overlap = da.from_zarr(
    displacement_overlap_zarr,
    chunks=(*displacement_overlap_zarr.chunks[:-1], 3),
)
mov_transformed = da.from_zarr(mov_transformed_zarr)
fix_da = da.from_zarr(fix)

In [6]:
# Block size of the transformed moving image (no overlap).
blocksize = mFISHwarp.utils.chunks_from_dask(mov_transformed)

# Per-axis overlap: half of the size difference between an overlapped
# displacement chunk and a plain block. (I should use attributes from zarr.)
padded_chunks = mFISHwarp.utils.chunks_from_dask(displacement_overlap)[:-1]
overlap = tuple((padded - core) // 2 for padded, core in zip(padded_chunks, blocksize))

# Trim the outer part of the overlapping regions to remove the erroneous outer edge.
trimming_factor = 0.75
trimming_range = tuple(round(margin * trimming_factor) for margin in overlap)

# Only the remaining sub-overlap area is used to fuse the displacement;
# the trailing (vector component) axis is not trimmed.
trimmed_displacement_overlap = da.overlap.trim_overlap(
    displacement_overlap,
    trimming_range + (0,),
    boundary='reflect',
)
trimmed_chunks = mFISHwarp.utils.chunks_from_dask(trimmed_displacement_overlap)[:-1]
suboverlap = tuple((padded - core) // 2 for padded, core in zip(trimmed_chunks, blocksize))

# Stitch the displacement. Note the shape of the displacement is an integer
# multiple of the chunk size.
displacement = mFISHwarp.dask_stitck.stitch_blocks(
    trimmed_displacement_overlap,
    blocksize,
    suboverlap,
    mov_transformed.chunks + (3,),  # need full chunk information
)

In [4]:
# Upsampling ratio in zyx: full-resolution fixed shape over the spatial shape
# of the stitched displacement, rounded to the nearest integer.
rescale_constant = tuple(
    np.round(np.asarray(fix.shape) / np.asarray(displacement.shape[:-1])).astype(int)
)
out_chunk_size = (256,) * 3  # too small makes the calculation very slow

# Lazy dask array for the upsampled displacement.
upsampled_displacement = mFISHwarp.transform.upscale_displacement_gpu(
    displacement,
    rescale_constant,
    out_chunk_size=out_chunk_size,
)

# Give the fixed image the same chunking as the output.
fix_da = fix_da.rechunk(out_chunk_size)

In [5]:
# make zarr to save
root = zarr.open_group(moved_path, mode='a')
# Create the output array '0' on the first run and reuse it on later runs.
# The warp task writes into root['0'] and the submission loop lists its
# directory, so the array must exist before either cell is executed;
# require_dataset keeps this cell safe to re-run when resuming.
root.require_dataset(
    '0',
    shape=fix.shape,
    chunks=out_chunk_size,
    dtype=fix.dtype
)

In [6]:
# Hand the lazy upsampled displacement to Ray once; tasks receive this
# reference as their argument.
upsampled_displacement_id = ray.put(upsampled_displacement)

# Per-axis chunk sizes of the (rechunked) fixed image; the warp task uses
# them to locate each block inside the output zarr.
chunk_info = fix_da.chunks

2024-05-03 16:43:44,658	INFO services.py:1374 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m


In [7]:
@ray.remote(num_gpus=0.2)
def warp_block(index, upsampled_displacement):
    """Warp one output block and write it into ``root['0']``.

    Parameters
    ----------
    index : tuple of int
        Block index, one entry per axis of ``chunk_info``.
    upsampled_displacement : dask array
        Upsampled displacement field; the block at ``index`` is used.

    Notes
    -----
    Relies on the notebook globals ``chunk_info``, ``mov`` and ``root``.
    Returns nothing; the result is stored in the output zarr. Each task
    requests 0.2 GPU from Ray.
    """
    # Start offset and extent of this block along every axis.
    starts = [sum(axis_chunks[:pos]) for axis_chunks, pos in zip(chunk_info, index)]
    extents = [axis_chunks[pos] for axis_chunks, pos in zip(chunk_info, index)]
    target_region = tuple(slice(start, start + size) for start, size in zip(starts, extents))
    target_shape = tuple(extents)

    # Edge chunks of the displacement do not always have the same shape as the
    # target slice, so pad/trim (edge mode) to the exact block shape.
    block_disp = mFISHwarp.transform.pad_trim_array_to_size(
        upsampled_displacement.blocks[index],
        target_shape + (3,),
        mode='edge',
    )

    warped = mFISHwarp.transform.transform_block_gpu(block_disp, mov, size_limit=1024 ** 3)
    root['0'][target_region] = warped

In [8]:
# loop over blocks. Upsampling should be fit to gpu
index_list = list(np.ndindex(*fix_da.numblocks))
# index_list.reverse()

# List the output directory once instead of once per block: the listing is
# loop-invariant (each block owns a unique file name) and repeated scans of a
# large chunk directory are slow. A set gives O(1) membership tests.
existing_files = set(os.listdir(os.path.join(moved_path, '0')))
sep = '.'

# Keep the task handles so that worker errors can be inspected afterwards,
# e.g. with ray.get(pending_tasks); submission itself stays non-blocking.
pending_tasks = []
for index in index_list:
    # chunk file name of this block, e.g. "0.3.1"
    file_name = sep.join(str(i) for i in index)
    if file_name not in existing_files:
        pending_tasks.append(warp_block.remote(index, upsampled_displacement_id))

In [None]:
### Test: saving directly to OME-Zarr

In [None]:
ome_path = '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/registration/round02_02.zarr'


store = parse_url(ome_path, mode="w").store
# Use a dedicated name for this test group. `root` is the output group that
# warp_block writes into; rebinding it here would silently redirect the warp
# output if the warp cells were re-run after this one.
ome_root = zarr.group(store=store)

In [10]:

# # Define the shape and chunk size of the array
# shape = (10, 512, 512)
# chunks = (1, 512, 512)

# # Create an empty Zarr array
# store = zarr.DirectoryStore('/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/registration/round02_02.zarr')
# root = zarr.group(store=store)
# zarr_array = root.create_dataset('image', shape=shape, chunks=chunks, dtype='float32')

# # Function to generate chunks of data (replace with your actual data)
# def generate_chunk(chunk_index):
#     # Simulate data generation for a chunk
#     return np.random.random(chunks)

# # Write data chunk by chunk
# for i in range(shape[0]):
#     chunk_index = (i, 0, 0)
#     data_chunk = generate_chunk(chunk_index)
#     zarr_array[chunk_index[0]:chunk_index[0]+1, :, :] = data_chunk

# # Write OME-Zarr metadata
# write_image(root, zarr_array)

AttributeError: 

In [3]:
# Open the OME-Zarr and collect every node the reader yields.
url = parse_url("/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/registration/round02.zarr")
reader = Reader(url)
nodes = [node for node in reader()]

# The first node carries the image pixel data.
image_node = nodes[0]
dask_data = image_node.data

In [9]:
# Display element 1 of the image node's data (shown as a lazy dask repr).
# NOTE(review): ome-zarr's Node.data is presumably a list with one dask array
# per resolution level, so this would be the second level - confirm.
dask_data[1]

Unnamed: 0,Array,Chunk
Bytes,252.75 GiB,32.00 MiB
Shape,"(4, 2475, 4274, 3207)","(1, 256, 256, 256)"
Dask graph,8840 chunks in 2 graph layers,8840 chunks in 2 graph layers
Data type,uint16 numpy.ndarray,uint16 numpy.ndarray
"Array Chunk Bytes 252.75 GiB 32.00 MiB Shape (4, 2475, 4274, 3207) (1, 256, 256, 256) Dask graph 8840 chunks in 2 graph layers Data type uint16 numpy.ndarray",4  1  3207  4274  2475,

Unnamed: 0,Array,Chunk
Bytes,252.75 GiB,32.00 MiB
Shape,"(4, 2475, 4274, 3207)","(1, 256, 256, 256)"
Dask graph,8840 chunks in 2 graph layers,8840 chunks in 2 graph layers
Data type,uint16 numpy.ndarray,uint16 numpy.ndarray


## Optional: save the resolution pyramid

In [None]:
# optional. Save pyramid resolution
from skimage.transform import downscale_local_mean
def rescale_chunk(chunks, rescale_constant):
    """Return the chunk layout after downscaling every axis by its factor.

    Parameters
    ----------
    chunks : sequence of sequences of int
        Per-axis chunk sizes, e.g. ``dask_array.chunks``.
    rescale_constant : sequence of numbers
        Per-axis downscale factor.

    Returns
    -------
    tuple of tuples of int
        For every axis, each chunk size divided by the axis factor and rounded
        up. Sizes are plain Python ints.

    Notes
    -----
    Works for any number of axes (previously hard-coded to three). Axes are
    paired with ``zip``, so only as many axes as the shorter argument are
    returned.
    """
    return tuple(
        tuple(int(size) for size in np.ceil(np.asarray(axis_chunks) / factor))
        for axis_chunks, factor in zip(chunks, rescale_constant)
    )

def save_pyramid(file_name_base, downscale_constants, iteration=5):
    """Write successively downscaled copies of a zarr image next to it.

    Level ``i`` is read from ``<file_name_base>/<i>``, downscaled block-wise
    with ``downscale_local_mean`` and written to ``<file_name_base>/<i+1>``,
    for ``i`` in ``0 .. iteration-1``.

    Parameters
    ----------
    file_name_base : str
        Directory holding the levels; level ``0`` must already exist.
    downscale_constants : tuple of int
        Per-axis downscale factor applied at every level.
    iteration : int, optional
        Number of additional levels to write (default 5).
    """
    def _downscale_keep_dtype(block):
        # downscale_local_mean returns floating-point means. Convert back to
        # the input dtype explicitly so the blocks match the dtype declared to
        # dask; integer images are rounded to nearest instead of being
        # truncated by an implicit cast when the block is written.
        reduced = downscale_local_mean(block, downscale_constants)
        if np.issubdtype(block.dtype, np.integer):
            reduced = np.rint(reduced)
        return reduced.astype(block.dtype)

    # save downsampled resolution
    for i in range(iteration):
        img = da.from_zarr(os.path.join(file_name_base, str(i)))
        down_img = da.map_blocks(
            _downscale_keep_dtype,
            img,
            dtype=img.dtype,
            chunks=rescale_chunk(img.chunks, downscale_constants)
        )
        da.to_zarr(down_img, os.path.join(file_name_base, str(i+1)))
        print('done:' + str(i+1))

In [None]:
# Halve every axis at each pyramid level.
downscale_constants = (2,) * 3
save_pyramid(file_name_base=moved_path, downscale_constants=downscale_constants)