In [None]:
import numpy as np
import os
# Select the GPU by PCI bus order and expose only device 0. CUDA reads these
# variables when a CUDA context is first created, so they are set here, above the
# cellpose import and the core.use_gpu() probe below.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "0"})

from cellpose import models, core

use_GPU = core.use_gpu()  # cellpose's check for a usable GPU
print('>>> GPU activated? %d' % use_GPU)

from skimage import io
from skimage import exposure
import tqdm
import napari
import pandas as pd

import zarr
from dask import array as da

from ome_zarr.io import parse_url
from ome_zarr.reader import Reader

In [None]:
def normalization_two_values(arr, lower, upper):
    """
    Linearly rescale ``arr`` so that ``lower`` maps to 0 and ``upper`` maps to 1.

    Values outside ``[lower, upper]`` are not clipped, so results may fall below 0
    or above 1.

    Parameters
    ----------
    arr : array-like
        Input intensities. Callers in this notebook pass ``astype(float)`` arrays so
        the arithmetic happens in floating point.
    lower, upper : scalar or array broadcastable against ``arr``
        Intensities that should become 0 and 1, respectively.

    Returns
    -------
    Same kind as ``(arr - lower) / (upper - lower)`` — an ndarray for ndarray input.

    Raises
    ------
    ValueError
        If ``upper == lower`` (anywhere, for array bounds). Without this check the
        division would silently produce inf/nan (or ZeroDivisionError for plain
        Python scalars).
    """
    span = upper - lower
    if np.any(np.asarray(span) == 0):
        raise ValueError(
            f"upper ({upper!r}) and lower ({lower!r}) must differ to define a normalization range"
        )
    return (arr - lower) / span

In [None]:
# path to dataset and model
dataset_folder = "/mnt/ampa02_data01/gabacoll/shared/Yuchen/model_training/crops"
aug_folder = os.path.join(dataset_folder, 'augment')
train_folder = os.path.join(aug_folder, 'training')
models_path = os.path.join(train_folder, 'models')

# Use the last model file in name order. NOTE(review): this assumes file names sort
# in training order (e.g. they embed a timestamp) — confirm for this models folder.
# Hidden entries and sub-directories are skipped so they can never be picked as the model.
models_file = sorted(
    name for name in os.listdir(models_path)
    if not name.startswith('.') and os.path.isfile(os.path.join(models_path, name))
)
if not models_file:
    raise FileNotFoundError(f"no model files found in {models_path}")
model_path = os.path.join(models_path, models_file[-1])

model = models.CellposeModel(gpu=use_GPU, pretrained_model=model_path)

In [None]:
### parameters
data_path = '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/fused/fused.n5' # .n5 (BigStitcher) or .zarr (OME-Zarr) with pyramid resolutions
normalization_metadata = None #'/mnt/ampa02_data01/tmurakami/model_training/norm_values.pkl'
normalization_reference = "/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/fused/fft_norm_2_99p8.zarr"

voxel_size = (1.0,2.0,1.3,1.3) # CZYX

corner_positions = [1783,2093,2955] # ZYX start of the crop, in full-resolution pixels
crop_size = [128,512,512]           # ZYX size of the crop, in full-resolution pixels
segment_chan = 1
reference_chan = 3
auto_diam = False # Cellpose automatic diameter estimation.
min_size = 40

# Cellpose's `anisotropy` is the factor by which Z is up-sampled relative to XY
# ("set to 2.0 if Z is sampled half as dense as X or Y"), i.e. z_spacing / xy_spacing.
# In practice the exact value may not change accuracy much, possibly because of the
# anisotropic light-sheet PSF.
anisotropy = voxel_size[1] / voxel_size[-1]

# Channel parameters which were used during training.
# Cellpose channel indices: 0 = grayscale, 1 = red, 2 = green, 3 = blue.
Training_channel = 2
Second_training_channel = 1


### lazily load images using dask
# Produces `imgs`, a lazy dask array whose first axis is treated as the channel
# (downstream cells index it with reference_chan / segment_chan). Only slicing plus
# .compute() later actually reads data from disk.
_, ext = os.path.splitext(data_path)
ext = ext.lower()  # also accept '.N5' / '.ZARR'

if ext == '.n5':  # n5 is assumed to be BigStitcher (BigDataViewer) layout: <setup>/timepoint0/<level>
    img_zarr = zarr.open(store=zarr.N5Store(data_path), mode='r')
    n5_setups = list(img_zarr.keys())
    res_list = list(img_zarr[n5_setups[reference_chan]]['timepoint0'].keys())
    # NOTE(review): assumes res_list[0] is the full-resolution level (e.g. 's0') — confirm.
    full_res = res_list[0]
    # one lazily-loaded volume per setup, stacked along a new leading (channel) axis
    imgs = da.stack([
        da.from_zarr(img_zarr[n5_setup]['timepoint0'][full_res])
        for n5_setup in n5_setups
    ])

elif ext == '.zarr':  # .zarr is assumed to be OME-Zarr
    # parse_url returns None when the location does not exist in read mode; fail here
    # with a clear message rather than deep inside Reader.
    location = parse_url(data_path, mode="r")
    if location is None:
        raise FileNotFoundError(f"OME-Zarr not found or unreadable: {data_path}")
    # nodes may include images, labels etc.; the first node holds the image pixel data
    image_node = list(Reader(location)())[0]
    # .data is the list of pyramid levels; take the first (presumably full-resolution) one
    imgs = image_node.data[0]

else:
    raise ValueError("the extension should be .n5 or .zarr")


# One ZYX window, [corner, corner + size) along each axis, shared by both channels.
crop_window = tuple(
    slice(start, start + extent)
    for start, extent in zip(corner_positions, crop_size)
)

# Reference channel: lazy full volume, then the materialized crop.
img_ref = imgs[reference_chan, ...].squeeze()
img_ref_ = img_ref[crop_window].compute()

# Channel to segment: same window.
img = imgs[segment_chan, ...].squeeze()
img_ = img[crop_window].compute()

# Raw (pre-normalization) two-channel crop: (reference, segment) stacked on axis 0.
img_stack_original = np.stack([img_ref_, img_])

# Normalize the two crops (reference, segment) before stacking them for Cellpose.
if normalization_metadata is not None:
    # Precomputed per-dataset, per-channel intensity bounds.
    # NOTE: pd.read_pickle can execute arbitrary code; only point this at trusted files.
    norm_info = pd.read_pickle(normalization_metadata)
    ref_bounds = norm_info[data_path][reference_chan]
    seg_bounds = norm_info[data_path][segment_chan]
    img_ref_ = normalization_two_values(img_ref_.astype(float), ref_bounds['lower'], ref_bounds['upper'])
    img_ = normalization_two_values(img_.astype(float), seg_bounds['lower'], seg_bounds['upper'])

elif normalization_reference is not None:
    # Histogram-match each crop to the same region of a reference OME-Zarr volume
    # (presumably already percentile-normalized, judging by its file name).
    # A separate name is used so the n5 handle `img_zarr` from the loading cell is not clobbered.
    norm_zarr = zarr.open(normalization_reference, mode='r')
    scale = norm_zarr.attrs['multiscales'][0]['datasets'][0]['coordinateTransformations'][0]['scale']
    # Per-axis (Z, Y, X) ratio of reference voxel size to raw voxel size. Axes are
    # aligned from the end so a leading T and/or C entry in `scale` cannot shift them.
    zyx_factors = [s / v for s, v in zip(scale[-3:], voxel_size[-3:])]

    fft = da.from_zarr(norm_zarr[0])
    # Map the raw crop window into reference-pixel coordinates. `//` with a float
    # factor yields a float, so cast explicitly to get integer slice bounds.
    fft_window = tuple(
        slice(int(pos // f), int(pos // f) + int(size // f))
        for pos, size, f in zip(corner_positions, crop_size, zyx_factors)
    )

    fft_ref_img = fft[reference_chan][fft_window].compute()
    fft_img = fft[segment_chan][fft_window].compute()

    img_ref_ = exposure.match_histograms(img_ref_.astype(np.float32), fft_ref_img)
    img_ = exposure.match_histograms(img_.astype(np.float32), fft_img)

# (reference, segment) stacked on axis 0 -> shape (2, Z, Y, X)
img_stack = np.stack([img_ref_, img_])

In [None]:
# Prediction
if (normalization_metadata is not None) or (normalization_reference is not None):
    if ~auto_diam:
        # with diameter parameter provided without Cellpose normalization
        %time masks, flows, styles  = model.eval(img_stack, channels=[Training_channel,Second_training_channel], normalize=False, z_axis=1, diameter=model.diam_mean, do_3D=True, min_size=min_size, progress=True, anisotropy=anisotropy)
    else:
        %time masks, flows, styles  = model.eval(img_stack, channels=[Training_channel,Second_training_channel], normalize=False, z_axis=1, diameter=None, do_3D=True, min_size=min_size, progress=True, anisotropy=anisotropy)
else:
    if ~auto_diam:
        # with diameter parameter provided 
        %time masks, flows, styles  = model.eval(img_stack, channels=[Training_channel,Second_training_channel], z_axis=1, diameter=model.diam_mean, do_3D=True, min_size=min_size, progress=True, anisotropy=anisotropy)
    else:
        # without diameter
        %time masks, flows, styles = model.eval(img_stack, channels=[Training_channel,Second_training_channel], z_axis=1, diameter=None, do_3D=True, min_size=min_size, progress=True, anisotropy=anisotropy)

In [None]:
# Inspect the two-channel crop (split into one layer per channel along axis 0,
# additively blended) with the Cellpose label image added after it.
viewer = napari.Viewer()
image_layer_options = {'channel_axis': 0, 'name': 'image01', 'blending': 'additive'}
viewer.add_image(img_stack, **image_layer_options)
viewer.add_labels(masks)