In [None]:
import numpy as np
import time, os, sys
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"


from skimage import io
import tqdm
import napari
import pandas as pd

import zarr
from dask import array as da

import mFISHwarp.morphology

from skimage.filters import threshold_otsu

In [None]:
# --- Per-channel intensity bounds for normalization ---------------------------
# For every fused N5 dataset: threshold the reference channel at a coarse pyramid
# level with Otsu, build a foreground mask with mFISHwarp.morphology.mask_maker,
# then record the lower/upper intensity percentiles of EVERY channel inside that
# mask. Result layout:
#     lookup[n5_path][channel_index] = {'lower': <float>, 'upper': <float>}
# where channel_index is the position of the setup in img_zarr.keys() order.
# The crop-export cell below uses these bounds to normalize crops.

# image path
sources = ['/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/fused/fused.n5',
          '/mnt/ampa02_data01/tmurakami/240425_whole_4color_2nd_M037-3pb/fused/fused.n5']

voxel_size = (2.0,1.3,1.3)
reference_chan = 3          # setup index whose Otsu mask defines the foreground
upper_percentile = 99
lower_percentile = 30
downsample_level = 's4'     # pyramid level used for thresholding and percentiles
lookup = {}

for n5_path in sources:
    # create Zarr file object (read-only)
    img_zarr = zarr.open(store=zarr.N5Store(n5_path), mode='r')
    n5_setups = list(img_zarr.keys())

    img_down_ref = img_zarr[n5_setups[reference_chan]]['timepoint0'][downsample_level][:]
    global_thresh = threshold_otsu(img_down_ref)
    img_mask = mFISHwarp.morphology.mask_maker(img_down_ref, global_thresh)

    # The mask is shared by all channels of this dataset, so resolve its voxel
    # indices once instead of twice per channel.
    mask_index = np.nonzero(img_mask)
    if mask_index[0].size == 0:
        raise ValueError(
            f'Foreground mask is empty for {n5_path}; cannot compute '
            f'normalization percentiles (Otsu threshold={global_thresh}).')

    sub_lookup = {}
    for i, setup in enumerate(n5_setups):
        img_down = img_zarr[setup]['timepoint0'][downsample_level][:]
        # get values at both percentiles within the mask in a single pass
        lower, upper = np.percentile(img_down[mask_index],
                                     [lower_percentile, upper_percentile])
        sub_lookup[i] = {'lower': lower, 'upper': upper}

    lookup[n5_path] = sub_lookup

In [None]:
# import pickle
# # open a file where you want to store the data
# file = open('/mnt/ampa02_data01/tmurakami/model_training/norm_values.pkl', 'wb')
# # dump information to that file
# pickle.dump(lookup, file)
# # close the file
# file.close()

# # open a file, where you stored the pickled data
# file = open('/mnt/ampa02_data01/tmurakami/model_training/norm_values.pkl', 'rb')
# # load the information from that file
# data = pickle.load(file)
# # close the file
# file.close()

In [None]:
# load information
info_path = '/mnt/ampa02_data01/tmurakami/model_training/info.pkl'
obj = pd.read_pickle(info_path)

# image path
save_path = '/mnt/ampa02_data01/tmurakami/model_training/crops'

for i in range(obj.shape[0]):
    print(i)
    n5_path = obj.loc[i,'source']

    # create Zarr file object
    img_zarr = zarr.open(store=zarr.N5Store(n5_path), mode='r')

    corner_positions = obj.loc[i,'corner']
    crop_size = obj.loc[i,'crop_size']
    segment_chan = obj.loc[i,'channel']
    reference_chan = obj.loc[i,'ref_channel']
    plane_position = int(obj.loc[i,'plane_position'])

    # load images according to the input parameters.
    n5_setups = list(img_zarr.keys())
    img_ref = img_zarr[n5_setups[reference_chan]]['timepoint0']['s0']
    img_ref_ = img_ref[tuple(slice(i,i+j) for i,j in zip(corner_positions, crop_size))][plane_position,...]
    img_ref_norm = (img_ref_.astype(float) - lookup[n5_path][reference_chan]['lower']) / (lookup[n5_path][reference_chan]['upper'] - lookup[n5_path][reference_chan]['lower'])

    img = img_zarr[n5_setups[segment_chan]]['timepoint0']['s0']
    img_ = img[tuple(slice(i,i+j) for i,j in zip(corner_positions, crop_size))][plane_position,...]
    img_norm = (img_.astype(float) - lookup[n5_path][segment_chan]['lower']) / (lookup[n5_path][segment_chan]['upper'] - lookup[n5_path][segment_chan]['lower'])


    imgs = np.stack([img_ref_norm,img_norm])
    prefix = str(i)
    while len(prefix) < 4:
        prefix = '0' + prefix
        
        
    img_path = os.path.join(save_path, prefix+'_img_norm.tif')
    io.imsave(img_path, imgs, plugin='tifffile', metadata={'axes': 'CYX'})