In [1]:
import numpy as np
import os

from skimage import io
from tqdm import tqdm
import napari
import pandas as pd
from skimage import exposure

import zarr
import dask.array as da

import pywt

In [2]:
# Annotation table: one row per crop. The extraction loop reads the columns
# 'source', 'corner', 'crop_size', 'channel', 'ref_channel' and
# 'plane_position' from it.
# NOTE: unpickling executes arbitrary code, so only load files you trust.
info_path = '/mnt/ampa02_data01/tmurakami/model_training/info.pkl'
obj = pd.read_pickle(info_path)

# Lookup from each fused source volume (the value stored in the 'source'
# column) to the zarr array used as its histogram-matching reference.
normalization_references = {
    "/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/fused/fused.n5":
        "/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/fused/fft_norm_2_99p8.zarr",
    "/mnt/ampa02_data01/tmurakami/240425_whole_4color_2nd_M037-3pb/fused/fused.n5":
        "/mnt/ampa02_data01/tmurakami/240425_whole_4color_2nd_M037-3pb/fused/fft_norm_2_99p8.zarr",
}

# Directory that receives the normalized crops.
save_path = '/mnt/ampa02_data01/tmurakami/model_training/crops'

In [3]:
# Index of the resolution level to read; 0 because the annotation was made
# on the highest-resolution level.
resolution = 0
# Per-axis downsampling factor of the normalization reference relative to the
# full-resolution image. Only factors[1:] are applied to the crop coordinates
# below (the leading entry presumably belongs to the channel axis -- confirm).
factors = (1.0, 4.0, 4.0, 4.0)


def _crop_slices(corner, size):
    """Return one slice per axis covering `size` elements starting at `corner`."""
    return tuple(slice(start, start + length) for start, length in zip(corner, size))


# Make sure the output directory exists before the (long) loop starts writing.
os.makedirs(save_path, exist_ok=True)

for i in tqdm(range(obj.shape[0])):
    data_path = obj.loc[i, 'source']
    _, ext = os.path.splitext(data_path)

    if ext == '.n5':  # n5 is assumed to be in BigStitcher (BigDataViewer) layout
        # Layout read here: <setup>/timepoint0/<resolution level>; every setup
        # becomes one entry along the first axis of `imgs`.
        img_zarr = zarr.open(store=zarr.N5Store(data_path), mode='r')
        n5_setups = list(img_zarr.keys())
        res_list = list(img_zarr[n5_setups[0]]['timepoint0'].keys())
        imgs = da.stack([
            da.from_zarr(img_zarr[n5_setup]['timepoint0'][res_list[resolution]])
            for n5_setup in n5_setups
        ])

    elif ext == '.zarr':  # zarr is assumed to be OME-Zarr
        # Imported here so that .n5-only runs do not require ome-zarr; these
        # names were previously used without any import (NameError).
        from ome_zarr.io import parse_url
        from ome_zarr.reader import Reader

        reader = Reader(parse_url(data_path))
        # Nodes may include images, labels etc.; the first node is the image
        # pixel data, stored as a list of arrays, one per resolution level.
        nodes = list(reader())
        image_node = nodes[0]
        imgs = image_node.data[resolution]

    else:
        raise ValueError("the extension should be .n5 or .zarr")

    # Per-crop parameters from the annotation table.
    corner_positions = obj.loc[i, 'corner']
    crop_size = obj.loc[i, 'crop_size']
    segment_chan = obj.loc[i, 'channel']
    reference_chan = obj.loc[i, 'ref_channel']
    plane_position = int(obj.loc[i, 'plane_position'])

    # Load the normalization reference that belongs to this source volume.
    img_norm_zarr = zarr.open(normalization_references[data_path], mode='r')
    fft = da.from_zarr(img_norm_zarr[0])

    # Map the full-resolution crop window onto the downsampled reference.
    fft_corner_positions = [int(pos // f) for pos, f in zip(corner_positions, factors[1:])]
    # Keep at least one element per axis: a crop thinner than the downsampling
    # factor would otherwise floor to 0 and give an empty reference, which
    # histogram matching cannot handle.
    fft_crop_size = [max(1, int(x // f)) for x, f in zip(crop_size, factors[1:])]
    fft_window = _crop_slices(fft_corner_positions, fft_crop_size)

    fft_ref_img = fft[reference_chan][fft_window].compute()
    fft_img = fft[segment_chan][fft_window].compute()

    # Crop both channels at full resolution, match each crop's histogram to
    # its reference crop, then keep only the annotated plane.
    full_window = _crop_slices(corner_positions, crop_size)

    img_ref_ = imgs[reference_chan].squeeze()[full_window].compute()
    img_ref_norm = exposure.match_histograms(img_ref_.astype(np.float32), fft_ref_img)[plane_position, ...]

    img_ = imgs[segment_chan].squeeze()[full_window].compute()
    img_norm = exposure.match_histograms(img_.astype(np.float32), fft_img)[plane_position, ...]

    # Channel 0 is the reference channel, channel 1 the segmentation channel.
    img_stack = np.stack([img_ref_norm, img_norm])

    # File names carry the row index zero-padded to at least four digits.
    prefix = str(i).zfill(4)
    img_path = os.path.join(save_path, prefix + '_img_norm.tif')
    io.imsave(img_path, img_stack, plugin='tifffile', metadata={'axes': 'CYX'}, check_contrast=False)

    # # for debug
    # img_path_fft = os.path.join(save_path, prefix+'_img_fft.tif')
    # io.imsave(img_path_fft, np.stack([fft_ref_img,fft_img]), plugin='tifffile', metadata={'axes': 'CYX'}, check_contrast=False)

100%|████████████████████████████████████████████████████████████████████| 321/321 [30:49<00:00,  5.76s/it]


In [None]:
#