In [None]:
import os
import re

import dask.array as da
import napari
import numpy as np
import zarr
from ome_zarr.io import parse_url
from ome_zarr.reader import Reader
from ome_zarr.writer import write_multiscales_metadata
from scipy import fft
from skimage.filters import threshold_otsu
from tqdm import tqdm

import mFISHwarp.fft
import mFISHwarp.morphology
import mFISHwarp.zarr

In [None]:
from dask.distributed import Client

# Start a local dask distributed cluster: 8 workers, one thread each,
# with the dashboard served at localhost:8787.
# https://docs.dask.org/en/latest/how-to/deploy-dask/single-distributed.html
client = Client(
    dashboard_address='localhost:8787',
    n_workers=8,
    threads_per_worker=1,
)
client

In [None]:
## set parameters for fft
# image path
# '/mnt/ampa02_data01/tmurakami/240425_whole_4color_2nd_M037-3pb/fused/fused.n5'
# '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/fused/fused.n5'
data_path = '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/fused/fused.n5'
original_chunk_size = (1,256,256,256) # full-resolution chunk size; set to None to reuse the loaded array's chunking
original_physical_scale = (1.0,2.0,1.3,1.3)
reference_chan = 3 # channel used for the tissue mask (and, for n5, for listing the resolution levels)
upper_percentile = 99.8
resolution = 2 # 2 is the maximum considering the RAM size.
hann_window_shrink_factor = 1

## set parameters for saving zarr of fft filtered image
# '/mnt/ampa02_data01/tmurakami/240425_whole_4color_2nd_M037-3pb/fused/fft_norm_2_99p8.zarr'
# '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/registration/round02_3.zarr'
save_zarr_path = '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/fused/fft_norm_2.zarr'
downscale_factor = (1,2,2,2)
pyramid_level = 1
axes_info = ['c','z','y','x']


def _natural_key(name):
    """Numeric-aware sort key: orders 'setup2' before 'setup10' and 's2' before 's10'.

    re.split with a capturing group alternates non-digit / digit tokens, so odd
    positions are always digit runs; converting only those keeps comparisons type-safe.
    """
    return [int(tok) if idx % 2 else tok for idx, tok in enumerate(re.split(r'(\d+)', name))]


# set physical_scale considering the downsampling rate
# (assumes each pyramid level is downsampled by `downscale_factor` relative to the previous one)
factors = tuple(i**resolution for i in downscale_factor)
physical_scale = tuple(i*j for i,j in zip(original_physical_scale, factors))

# lazily load the data of the targeted resolution using dask
# normpath strips a trailing '/', which would otherwise make the extension empty
_, ext = os.path.splitext(os.path.normpath(data_path))

if ext == '.n5': # n5 assume bigstitcher (bigdataviewer) format
    # create Zarr file object
    img_zarr = zarr.open(store=zarr.N5Store(data_path), mode='r')
    # Sort group keys numerically: plain string order would put 'setup10' before 'setup2'
    # (and 's10' before 's2'), silently misordering channels / resolution levels.
    n5_setups = sorted(img_zarr.keys(), key=_natural_key)
    res_list = sorted(img_zarr[n5_setups[reference_chan]]['timepoint0'].keys(), key=_natural_key)

    # load images according to the input parameters: one lazy array per setup (channel)
    channel_imgs = [
        da.from_zarr(img_zarr[n5_setup]['timepoint0'][res_list[resolution]])
        for n5_setup in n5_setups
    ]
    imgs = da.stack(channel_imgs)

elif ext == '.zarr': # zarr assumes ome-zarr
    # read the image data
    reader = Reader(parse_url(data_path))
    # nodes may include images, labels etc
    nodes = list(reader())
    # first node will be the image pixel data
    image_node = nodes[0]

    dask_data = image_node.data
    imgs = dask_data[resolution]

else:
    raise ValueError("the extension should be .n5 or .zarr")

# chunk size used when saving the FFT-filtered image
if original_chunk_size is None:
    chunk_size = tuple(c[0] for c in imgs.chunks) # set the chunk size for saving. this can be arbitrary.
else:
    # scale the full-resolution chunk size to the selected resolution;
    # clamp to >= 1 so a large downsampling factor never produces a zero-length chunk
    chunk_size = tuple(max(1, i//j) for i,j in zip(original_chunk_size, factors))

## Create tissue mask

In [None]:
# Build a tissue mask from the reference channel so that voxels outside the tissue
# can be excluded; the FFT cell below uses it to restrict the percentile computation
# to tissue voxels.
img_down_ref = imgs[reference_chan].compute()  # materialize the reference channel in memory
global_thresh = threshold_otsu(img_down_ref)  # one Otsu threshold for the whole volume
img_mask = mFISHwarp.morphology.mask_maker(
    img_down_ref,
    global_thresh,
)

In [None]:
# Visual check of the tissue mask: overlay the reference channel (magenta)
# and the mask (green) with additive blending.
viewer = napari.Viewer()

viewer.add_image(
    img_down_ref,
    name='img',
    colormap='magenta',
    blending='additive',
    contrast_limits=[0, 20000],  # display range for the reference-channel intensities
)
viewer.add_image(
    img_mask,
    name='mask',
    colormap='green',
    blending='additive',
    contrast_limits=[0, 10],
)

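## FFT filter and percentile normalization within the tissue
The normalization bound is determined from the FFT-filtered values inside the tissue mask, using the percentile parameter (`upper_percentile`). Currently only an upper bound is computed (no lower bound). It is stored per channel in the zarr metadata (`norm_upper_values`) and is not applied to the saved image.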

In [None]:
## create zarr to save the FFT image.
store = zarr.DirectoryStore(save_zarr_path, dimension_separator='/')
root = zarr.group(store=store)

# overwrite=True keeps this cell re-runnable (Restart & Run All): without it,
# create_dataset raises ContainsArrayError when array '0' already exists from a previous run.
# NOTE: an existing array '0' at save_zarr_path is replaced.
data_zarr = root.create_dataset(
    '0',
    shape=imgs.shape,
    chunks=chunk_size,
    dtype=np.float32,
    overwrite=True,
)

# prepare metadata to zarr
datasets = mFISHwarp.zarr.datasets_metadata_generator(physical_scale, downscale_factor=downscale_factor, pyramid_level=pyramid_level)
# create custom-made attribute to save normalization parameter
# (filled by the FFT cell: key = str(channel index), value = in-tissue upper percentile value)
datasets[0]['norm_upper_values'] = {}
datasets[0]['norm_upper_percentile'] = upper_percentile

In [None]:
# create Hann filter window for FFT filtering
shape = imgs.shape[1:]
mask = 1 - mFISHwarp.fft.create_3d_hann_window(tuple(i//hann_window_shrink_factor for i in shape), shape)

# Boolean view of the tissue mask, computed once outside the loop. Boolean indexing
# selects the same voxels (same order) as np.where(img_mask) without allocating
# one large integer index array per axis on every iteration.
tissue = np.asarray(img_mask, dtype=bool)
if not tissue.any():
    # fail before the expensive FFT instead of inside np.percentile on an empty array
    raise ValueError("tissue mask is empty; cannot compute the normalization percentile")

for chan in tqdm(range(imgs.shape[0])):
    # get downsampled image
    img_down = imgs[chan].compute()

    # get absolute values after FFT filter, and call it as filtered image
    filtered_image = mFISHwarp.fft.fft_filter(img_down, mask)
    masked_vals = filtered_image[tissue]

    # cast to builtin float so the value stays JSON-serializable in the zarr metadata
    # (np.percentile returns np.float32 for float32 input, which json cannot encode)
    upper_f = float(np.percentile(masked_vals, upper_percentile))

    # The image is saved un-normalized; the per-channel upper value is stored in the
    # metadata instead. To normalize later, divide channel `chan` by
    # norm_upper_values[str(chan)].
    data_zarr[chan,...] = filtered_image
    datasets[0]['norm_upper_values'][str(chan)] = upper_f

In [None]:
### write metadata for ome-zarr
# Write the OME-Zarr multiscales metadata for `root`, including the custom
# norm_upper_* entries attached to datasets[0] earlier.
write_multiscales_metadata(
    root,
    axes=axes_info,
    datasets=datasets,
)

In [None]:
# Reopen the saved OME-Zarr with the napari-ome-zarr plugin to inspect the written result.
viewer = napari.Viewer()
viewer.open(path=save_zarr_path, plugin="napari-ome-zarr")

In [None]:
###

In [None]:
### for EDA
# Disabled exploration of the in-tissue filtered values (histogram and percentile check).
# NOTE(review): `masked_vals_list` is not defined anywhere in this notebook; the FFT loop only
# keeps the most recent channel's `masked_vals`. Collect per-channel values there before
# re-enabling this cell. seaborn is also not imported in the top imports cell.

# import seaborn as sns

# i = 1
# test_percentile = 99.8

# dat = masked_vals_list[i][::10000]
# # sns.histplot(data=img_down[np.where(img_mask)][::10000])
# sns.histplot(data=np.abs(dat), log_scale=True)
# print(np.percentile(dat,test_percentile))
# print(np.log10(np.percentile(dat,test_percentile)))