In [1]:
import os
import re

import dask.array as da
import napari
import numpy as np
import ray
import zarr
from ome_zarr.io import parse_url
from ome_zarr.reader import Reader
from ome_zarr.writer import write_multiscales_metadata
from skimage.exposure import match_histograms
from tqdm import tqdm

import mFISHwarp.utils
import mFISHwarp.zarr

2024-07-23 14:30:14,062	INFO util.py:154 -- Outdated packages:
  ipywidgets==7.6.5 found, needs ipywidgets>=8
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [2]:
# ## Load the source image lazily with dask
# Other inputs used for this dataset:
#   '/mnt/ampa02_data01/tmurakami/240425_whole_4color_2nd_M037-3pb/fused/fused.n5'
#   '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/fused/fused.n5'
data_path = '/mnt/ampa02_data01/tmurakami/240425_whole_4color_2nd_M037-3pb/fused/fused.n5'
# Index of the pyramid level to load (presumably 0 = full resolution; confirm for new inputs).
resolution = 0
# Per-axis scale (c, z, y, x) written into the output OME-Zarr metadata.
physical_scale = (1.0,2.0,1.3,1.3)


def _natural_key(name):
    """Sort key that orders embedded integers numerically, e.g. 'setup2' < 'setup10'."""
    return [int(part) if part.isdigit() else part for part in re.split(r'(\d+)', name)]


# Strip a trailing separator so '.../fused.n5/' still yields the '.n5' extension.
_, ext = os.path.splitext(data_path.rstrip('/'))

if ext == '.n5':
    # N5 is assumed to be in BigStitcher / BigDataViewer layout:
    # <setup>/timepoint0/<resolution level>, one setup per channel.
    img_zarr = zarr.open(store=zarr.N5Store(data_path), mode='r')
    # Group keys may be listed lexicographically ('setup10' before 'setup2'); sort them
    # naturally so channel order and the `resolution` index stay correct beyond 10 entries.
    n5_setups = sorted(img_zarr.keys(), key=_natural_key)
    res_list = sorted(img_zarr[n5_setups[0]]['timepoint0'].keys(), key=_natural_key)
    # Stack the per-setup volumes along a new leading (channel) axis.
    imgs = da.stack([
        da.from_zarr(img_zarr[n5_setup]['timepoint0'][res_list[resolution]])
        for n5_setup in n5_setups
    ])
elif ext == '.zarr':
    # .zarr is assumed to be OME-Zarr. The reader yields nodes (images, labels, ...);
    # the first one holds the image pixel data.
    location = parse_url(data_path, mode="r")
    if location is None:  # parse_url returns None when the location does not exist
        raise FileNotFoundError(f'Could not open OME-Zarr at {data_path}')
    image_node = list(Reader(location)())[0]
    imgs = image_node.data[resolution]
else:
    # Fail here rather than carrying an empty placeholder into later cells.
    raise ValueError(f'Unsupported image format {ext!r}; expected .n5 or .zarr')

In [3]:
# ## Load the template (histogram reference)
target_path = '/mnt/ampa02_data01/tmurakami/240425_whole_4color_2nd_M037-3pb/fused/fft_norm_3_99p8.zarr'

# The template is an OME-Zarr input, so open it strictly read-only.
target_location = parse_url(target_path, mode="r")
if target_location is None:  # parse_url returns None when the location does not exist
    raise FileNotFoundError(f'Could not open OME-Zarr at {target_path}')
# Group handle used later to copy the template's normalization attributes.
# open_group(mode='r') never tries to initialise a group, unlike zarr.group().
root_target = zarr.open_group(store=target_location.store, mode='r')

# The reader yields nodes (images, labels, ...); the first one holds the image pixel data.
image_node = list(Reader(target_location)())[0]
# Level 0 of the template's pyramid is used as the reference.
imgs_target = image_node.data[0]



In [4]:
# ## Align source chunking with the template
# The template may be smaller than the source (downsampled). Scale the template's chunk
# shape by the per-axis size ratio so that block `index` of `imgs` covers the same region
# as block `index` of `imgs_target`; the per-block matching below relies on this.
# Use round(i / j), not i // j: floor division undercounts the factor when the downsampled
# size was rounded up (e.g. 1001 -> 501 gives 1 instead of 2). Never let a factor hit 0.
down_factors = [max(1, round(i / j)) for i, j in zip(imgs.shape, imgs_target.shape)]
# Assumed to be the per-axis chunk shape of the template (one entry per axis) - TODO confirm.
target_chunk_size = mFISHwarp.utils.chunks_from_dask(imgs_target)
chunk_size = [i * j for i, j in zip(down_factors, target_chunk_size)]

# Rechunk for the analysis and switch to float32.
imgs = imgs.rechunk(chunk_size).astype(np.float32)

assert imgs.numblocks == imgs_target.numblocks, (
    f'source block grid {imgs.numblocks} does not match template block grid {imgs_target.numblocks}'
)

In [5]:
# ## Create the output OME-Zarr for the histogram-matched image
save_zarr_path = '/mnt/ampa02_data01/tmurakami/240425_whole_4color_2nd_M037-3pb/fused/hist_matched.zarr'
downscale_factor_pyramid = (1,2,2,2)
pyramid_level = 1
axes_info = ['c','z','y','x']

store = zarr.DirectoryStore(save_zarr_path, dimension_separator='/')
root = zarr.group(store=store)

# int16 keeps the output small, but it may not cover every matched value.
# Output chunks equal the analysis chunks, so each block-writing task touches whole chunks only.
# require_dataset (instead of create_dataset) lets this cell be re-run: it reopens an
# existing '0' array with the same shape instead of raising because it already exists.
data_zarr = root.require_dataset('0', shape=imgs.shape, chunks=chunk_size, dtype=np.int16)

# Prepare OME-Zarr multiscale metadata.
datasets = mFISHwarp.zarr.datasets_metadata_generator(physical_scale, downscale_factor=downscale_factor_pyramid, pyramid_level=pyramid_level)

# Custom (non-OME) attributes: record the template used for matching and copy its
# normalization parameters from the template's level-0 dataset metadata.
target_level0_meta = root_target.attrs['multiscales'][0]['datasets'][0]
datasets[0]['target'] = target_path
datasets[0]['norm_upper_percentile'] = target_level0_meta['norm_upper_percentile']
datasets[0]['norm_upper_values'] = target_level0_meta['norm_upper_values']

# Write the metadata for OME-Zarr.
write_multiscales_metadata(root, datasets=datasets, axes=axes_info)

In [6]:
# Block layout of the rechunked source: per-axis chunk lengths and every block index to process.
chunk_info = imgs.chunks
index_list = [*np.ndindex(*imgs.numblocks)]

In [7]:
def _block_slicer(chunks, index):
    """Return the tuple of slices covering block `index` of an array with dask-style `chunks`.

    `chunks` has one entry per axis listing that axis' chunk lengths (e.g. ((2, 2, 1),)
    for a length-5 axis); `index` gives the block position along each axis.
    """
    slices = []
    for axis_chunks, block in zip(chunks, index):
        start = sum(axis_chunks[:block])
        slices.append(slice(start, start + axis_chunks[block]))
    return tuple(slices)


@ray.remote(num_cpus=1)
def histmatch_to_target(index):
    """Histogram-match one source block to its template block and write it to the output.

    Loads block `index` of `imgs` and of `imgs_target`, matches the source histogram to the
    template with `match_histograms`, converts the result to int16 and writes it into
    `root['0']` at the block's position. Returns `index` so callers can track completion.
    """
    target = imgs_target.blocks[index].compute()
    source = imgs.blocks[index].compute()

    matched = match_histograms(source, target)
    # Convert to int16 to save space. Clip first: NumPy float->int casts are undefined for
    # out-of-range values and typically wrap around instead of saturating.
    int16_range = np.iinfo(np.int16)
    matched = np.clip(np.round(matched), int16_range.min, int16_range.max).astype(np.int16)

    root['0'][_block_slicer(chunk_info, index)] = matched
    return index

In [8]:
# Launch one Ray task per block, then wait for all of them. Keeping the ObjectRefs matters:
# fire-and-forget submission silently drops worker exceptions and lets this cell return
# before the output has been written.
pending = [histmatch_to_target.remote(index) for index in index_list]
with tqdm(total=len(pending), desc='histogram matching') as progress:
    while pending:
        # Collect whatever has finished within ~1 s (batching keeps the wait loop cheap).
        done, pending = ray.wait(pending, num_returns=len(pending), timeout=1.0)
        ray.get(done)  # re-raises any exception raised inside a finished task
        progress.update(len(done))

2024-07-23 14:30:22,915	INFO worker.py:1743 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m


