In [None]:
import os

import dask.array as da
import napari
import numpy as np
import zarr
from ome_zarr.io import parse_url
from ome_zarr.reader import Reader
from ome_zarr.writer import write_multiscales_metadata
from scipy import fft
from skimage.exposure import match_histograms
from skimage.filters import threshold_otsu
from tqdm import tqdm

import mFISHwarp.fft
import mFISHwarp.morphology
import mFISHwarp.zarr

In [None]:
## set parameters for fft
# path of the input image
data_path = '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/fused/fused.n5'
# the chunk size used for analysis. This is optional and can be None.
original_chunk_size = (1,256,256,256)
# per-axis physical scale of the original image; axis order follows axes_info below
original_physical_scale = (1.0,2.0,1.3,1.3)
reference_chan = 3  # channel used to build the tissue mask
upper_percentile = 99
lower_percentile = 30
resolution = 3 # s2 is the maximum considering the RAM size.
hann_window_shrink_factor = 1

## set parameters for saving zarr of fft filtered image
save_zarr_path = '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/registration/fft_norm.zarr'
downscale_factor = (1,2,2,2)
pyramid_level = 1
axes_info = ['c','z','y','x']

# set physical_scale and chunk size considering the downsampling rate
# derive the cumulative per-axis factor from downscale_factor rather than a
# second hard-coded copy of the same tuple, so the two cannot drift apart
factors = tuple(i**resolution for i in downscale_factor)
physical_scale = tuple(i*j for i,j in zip(original_physical_scale, factors))
if original_chunk_size is not None:
    chunk_size = tuple(i//j for i,j in zip(original_chunk_size, factors))
else:
    # no analysis chunk size given; bind the name explicitly so a stale value
    # from a previous run of this cell cannot leak through
    chunk_size = None

In [None]:
## set parameters for fft
# NOTE(review): this cell repeats the configuration of the previous cell; the
# values assigned here are the ones in effect for the rest of the notebook.
# image path
# '/mnt/ampa02_data01/tmurakami/240425_whole_4color_2nd_M037-3pb/fused/fused.n5'
# '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/fused/fused.n5'
# '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/registration/round02_3.zarr'
data_path = '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/fused/fused.n5'
# chunk size used for analysis; may be None, in which case the chunk size of
# the loaded array is used for saving instead
original_chunk_size = (1,256,256,256)
original_physical_scale = (1.0,2.0,1.3,1.3)
reference_chan = 3
upper_percentile = 99
lower_percentile = 30
resolution = 3 # s2 is the maximum considering the RAM size.
hann_window_shrink_factor = 1

## set parameters for saving zarr of fft filtered image
save_zarr_path = '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/registration/fft_norm.zarr'
downscale_factor = (1,2,2,2)
pyramid_level = 1
axes_info = ['c','z','y','x']

# set physical_scale and chunk size considering the downsampling rate
factors = tuple(i**resolution for i in downscale_factor)
physical_scale = tuple(i*j for i,j in zip(original_physical_scale, factors))
# original_chunk_size is optional; zipping over None would raise TypeError,
# so only rescale it when it was actually provided
if original_chunk_size is not None:
    chunk_size = tuple(i//j for i,j in zip(original_chunk_size, factors))

# lazily load the data of the targeted resolution using dask
_, ext = os.path.splitext(data_path)

if ext == '.n5': # n5 assume bigstitcher (bigdataviewer) format
    # open the N5 container read-only; its top-level keys are indexed by channel below
    img_zarr = zarr.open(store=zarr.N5Store(data_path), mode='r')
    n5_setups = list(img_zarr.keys())
    # names of the resolution levels, taken from the reference channel
    res_list = list(img_zarr[n5_setups[reference_chan]]['timepoint0'].keys())

    # one lazy array per setup at the requested resolution, stacked on a new leading axis
    imgs = da.stack([
        da.from_zarr(img_zarr[n5_setup]['timepoint0'][res_list[resolution]])
        for n5_setup in n5_setups
    ])

elif ext == '.zarr': # zarr assumes ome-zarr
    # parse_url returns None when the path cannot be opened as a zarr location;
    # fail with a clear message instead of an obscure error inside Reader
    location = parse_url(data_path, mode="r")
    if location is None:
        raise ValueError(f"could not open an ome-zarr at {data_path}")
    reader = Reader(location)
    # nodes may include images, labels etc
    nodes = list(reader())
    # first node will be the image pixel data
    image_node = nodes[0]

    # image_node.data is indexed by resolution level
    dask_data = image_node.data
    imgs = dask_data[resolution]
    # img = dask_data[resolution][segment_chan,...].compute()

else:
    raise ValueError("the extension should be .n5 or .zarr")

if original_chunk_size is None:
    # set the chunk size for saving. this can be arbitrary.
    # here: the size of the first chunk along every axis of the loaded array
    chunk_size = tuple(i[0] for i in imgs.chunks)

## Create tissue mask

In [None]:
# Build a mask that covers only the tissue, so the area outside it is excluded.
# The mask is used further down to pick the voxels that enter the percentile
# calculation for the normalization.
img_down_ref = imgs[reference_chan, ...].compute()

# Otsu threshold over the whole reference-channel volume
global_thresh = threshold_otsu(img_down_ref)

img_mask = mFISHwarp.morphology.mask_maker(img_down_ref, global_thresh)

In [None]:
# confirm the tissue masking is correct by overlaying the mask on the reference image
viewer = napari.Viewer()

# napari's add_image selects the lookup table with `colormap`; `color` is not
# an accepted keyword for image layers.
viewer.add_image(img_down_ref, contrast_limits=[0,20000],blending='additive',name='img',colormap='magenta')
viewer.add_image(img_mask, contrast_limits=[0,50],blending='additive',name='mask',colormap='green')

## FFT filter and percentile normalization within the tissue
The upper and lower bounds for the normalization are taken from the intensity values inside the tissue mask, at the percentiles given by `upper_percentile` and `lower_percentile`.

In [None]:
### create zarr to save the FFT image.
# directory-backed store whose chunk keys are nested with '/'
store = zarr.DirectoryStore(save_zarr_path, dimension_separator='/')
root = zarr.group(store=store)

# array '0' has the same shape as the loaded image stack and stores float32 values
dataset = root.create_dataset(
    '0',
    shape=imgs.shape,
    chunks=chunk_size,
    dtype=np.float32,
)

In [None]:
# Indices of the voxels inside the tissue mask. The mask is the same for every
# channel, so they are computed once here instead of four times per iteration.
tissue_idx = np.nonzero(img_mask)
# both percentiles are requested in a single call per image
percentile_pair = [lower_percentile, upper_percentile]

for chan in tqdm(range(imgs.shape[0])):
    # get downsampled image
    img_down = imgs[chan].compute()
    shape = img_down.shape

    # make mask in Fourier space to remove low frequency component.
    hann_3d = mFISHwarp.fft.create_3d_hann_window(tuple(i//hann_window_shrink_factor for i in shape), shape)
    mask = 1 - hann_3d

    filtered_image = mFISHwarp.fft.fft_filter(img_down, mask)

    # percentile bounds measured only on voxels inside the tissue mask
    lower, upper = np.percentile(img_down[tissue_idx], percentile_pair)
    lower_f, upper_f = np.percentile(filtered_image[tissue_idx], percentile_pair)

    # normalization: map [lower, upper] to [0, 1] (values outside are not clipped)
    # norm_img is not saved; it is kept because the scratch cells at the end of
    # the notebook read it for the last processed channel.
    norm_img = (img_down.astype(float)-lower)/(upper-lower)
    norm_img_fft = (filtered_image.astype(float)-lower_f)/(upper_f-lower_f)

    # only the FFT-filtered, normalized image is written to the output zarr
    dataset[chan,...] = norm_img_fft

In [None]:
### write metadata for ome-zarr
# build the per-level dataset entries from the physical scale and pyramid settings
datasets = mFISHwarp.zarr.datasets_metadata_generator(
    physical_scale,
    downscale_factor=downscale_factor,
    pyramid_level=pyramid_level,
)
# attach the multiscales metadata to the root group of the output zarr
write_multiscales_metadata(
    root,
    datasets=datasets,
    axes=axes_info,
)

In [None]:
# open the saved zarr in a fresh viewer through the napari-ome-zarr plugin
viewer = napari.Viewer()
viewer.open(
    save_zarr_path,
    plugin="napari-ome-zarr",
    contrast_limits=[0,1],
)

In [None]:
### The remaining cells are leftover scratch work and are not part of the pipeline.

In [None]:
    # viewer = napari.Viewer()
    # viewer.add_image(norm_img, contrast_limits=[0,1],blending='translucent',name='img')
    # viewer.add_image(norm_img_fft, contrast_limits=[0,1],blending='translucent',name='fft')

In [None]:
# edge length of the cubic crop, in voxels of the downsampled image
box_size = 32
# corner of the crop in the downsampled image
z = 85; y = 775; x = 575
# z = 160; y = 584; x = 478
# z = 160; y = 300; x = 233
# z = 100; y = 741; x = 671
# z = 100*2; y = 742*2; x = 638*2
# z = 164; y = 1442; x = 454
# factors = {'s0':1, 's1':2, 's2':4, 's3':8}
factor = 2 ** resolution  # factors[resolution]

# crop in the downsampled image
my_range = tuple(slice(start, start + box_size) for start in (z, y, x))
# the same crop with start and stop multiplied by `factor`
my_range_in_high = tuple(
    slice(start * factor, (start + box_size) * factor) for start in (z, y, x)
)

# match the histogram of the plain normalized crop to the FFT-filtered crop
matched = match_histograms(norm_img[my_range], norm_img_fft[my_range])

In [None]:
# NOTE(review): the original indexed with `segment_chan`, which is never defined
# in this notebook (NameError on a fresh kernel). `chan` is used instead: after
# the normalization loop it is the last processed channel, the one that
# norm_img / norm_img_fft belong to — confirm this is the intended channel.
# Only valid for the .n5 input branch, where img_zarr and n5_setups are defined.
img_high = da.from_zarr(img_zarr[n5_setups[chan]]['timepoint0']['s0'])
block = img_high[my_range_in_high].compute().astype(float)
# match the histogram of the 's0' block to each of the normalized crops
matched_fft = match_histograms(block, norm_img_fft[my_range])
matched_reg = match_histograms(block, norm_img[my_range])

In [None]:
# compare the two histogram-matched blocks in a fresh viewer
viewer = napari.Viewer()

for layer_name, layer_data in (
    ('matched_reg', matched_reg),
    ('matched_fft', matched_fft),
):
    viewer.add_image(layer_data, contrast_limits=[0,1], blending='translucent', name=layer_name)

In [None]:
# show the plain crop, the FFT-filtered crop and the histogram-matched crop together
viewer = napari.Viewer()

for layer_name, layer_data in (
    ('img', norm_img[my_range]),
    ('fft', norm_img_fft[my_range]),
    ('matched', matched),
):
    viewer.add_image(layer_data, contrast_limits=[0,1], blending='translucent', name=layer_name)

In [None]:
#