# Running cellpose with GPUs

In [None]:
import numpy as np
import os
# Number CUDA devices in PCI bus order and expose only GPUs 0 and 1 to this process.
# These are set before cellpose is imported, presumably so they take effect before
# the CUDA runtime is initialized.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0,1"

# NOTE(review): `io` appears unused in this notebook.
from skimage import io
from skimage.filters import threshold_otsu

import zarr
from cellpose import models, core

# Whether cellpose can use GPU 0 (printed as 1/0); passed as `gpu=` when the model is built.
use_GPU = core.use_gpu(gpu_number=0)
print('>>> GPU activated? %d'%use_GPU)

# NOTE(review): `utils` appears unused, and `models` is already imported above (harmless re-import).
from cellpose import utils
from cellpose import models
from cellpose.io import logger_setup
# Configure cellpose's logger.
logger_setup();

# Ray schedules the per-block GPU tasks. No explicit ray.init() appears in this notebook,
# so Ray presumably auto-initializes on the first `.remote()` call.
import ray

from dask import array as da

import mFISHwarp.morphology
import mFISHwarp.utils
import napari
import pandas as pd

## Load model

In [None]:
# Directory holding the trained cellpose model file(s).
model_dir = "/mnt/ampa02_data01/gabacoll/shared/Yuchen/model_training/crops/augment/training/models"

# Use the lexicographically last regular file in model_dir. This presumably selects the
# newest model because the file names carry a timestamp -- TODO confirm for this directory.
# Subdirectories and hidden entries (e.g. .ipynb_checkpoints) are skipped so they can
# never be picked as the "model".
models_file = sorted(
    name for name in os.listdir(model_dir)
    if not name.startswith('.') and os.path.isfile(os.path.join(model_dir, name))
)
if not models_file:
    # Fail with an actionable message instead of an IndexError from models_file[-1].
    raise FileNotFoundError(f'No model files found in {model_dir}')
model_path = os.path.join(model_dir, models_file[-1])

model = models.CellposeModel(gpu=use_GPU, pretrained_model=model_path)

## Load the N5 dataset (via zarr)

In [None]:
# Input locations: the fused N5 volume and the pickled normalization table.
n5_path = '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/fused/fused.n5'
normalization_metadata = '/mnt/ampa02_data01/tmurakami/model_training/norm_values.pkl'
# Positions in n5_setups (the list of top-level N5 groups built below).
segment_chan = 1    # channel to segment
reference_chan = 3  # reference channel (also used to build the foreground mask)

# Open the N5 container read-only through zarr and record its top-level group names.
n5_store = zarr.N5Store(n5_path)
img_zarr = zarr.open(store=n5_store, mode='r')
n5_setups = [group_name for group_name in img_zarr.keys()]

## Create mask

In [None]:
# Build a coarse foreground mask from the downsampled ('s4') reference channel.
ref_timepoint = img_zarr[n5_setups[reference_chan]]['timepoint0']
img_down_ref = ref_timepoint['s4'][:]  # read the whole s4 level into a NumPy array
global_thresh = threshold_otsu(img_down_ref)  # a single Otsu threshold for the whole volume
img_mask = mFISHwarp.morphology.mask_maker(img_down_ref, global_thresh)

# Show the mask and the downsampled reference for a visual sanity check.
viewer = napari.Viewer()
viewer.add_image(img_mask)
viewer.add_image(img_down_ref)

## Make overlapping image blocks

In [None]:
### Parameters
auto_diam = False # Cellpose automatic diameter estimation (not used below; the segmentor passes model.diam_mean).

# Voxel size, presumably (z, y, x) in micrometres -- the same axis order as chunk_size/depth below (TODO confirm).
voxel_size = (2.0,1.3,1.3)
# Cellpose's `anisotropy` is the factor by which z is rescaled, e.g. 2.0 when z is sampled half as
# densely as x/y, i.e. z-spacing / xy-spacing. The former voxel_size[1]/voxel_size[0] (= 0.65) was the
# inverse and shrank z instead of stretching it.
# In practice the exact value was reported to change accuracy only slightly, possibly because of the
# non-isotropic PSF of the light-sheet.
# NOTE(review): a factor > 1 presumably upsamples z inside Cellpose, raising per-block GPU memory;
# watch for OOM with two tasks per GPU.
anisotropy = voxel_size[0]/voxel_size[1]
min_size = 40  # minimum label size passed to Cellpose (presumably in voxels, since do_3D=True)

# Channel parameters which were used during the training.
# Cellpose channel codes: 0 = grayscale, 1 = red, 2 = green, 3 = blue. Each block is stacked as
# [reference, target], so the target (index 1) is read as green and the reference (index 0) as red:
# [2, 1] segments the target channel, with the reference as the second channel.
Training_channel = 2
Second_training_channel = 1

# lazily read the images as dask arrays, split into chunks of this size
chunk_size = (256,512,512)
depth = (32,64,64)    # overlap added on each side of every chunk, per axis
boundary = "reflect"  # padding mode for the outer edges of the volume

### Make overlapping images
# make overlapped images for both reference and target (order matters: [reference, target])
overlap_imgs = []
n5_setups = list(img_zarr.keys())

# reference
img_ref = da.from_zarr(img_zarr[n5_setups[reference_chan]]['timepoint0']['s0'])
img_ref = da.rechunk(img_ref,chunks=chunk_size)
overlap_imgs.append(da.overlap.overlap(img_ref, depth, boundary))
# target
img = da.from_zarr(img_zarr[n5_setups[segment_chan]]['timepoint0']['s0'])
img = da.rechunk(img,chunks=chunk_size)
overlap_imgs.append(da.overlap.overlap(img, depth, boundary))

# Flag which chunks will be segmented, presumably those touching foreground in img_mask.
flag_array = mFISHwarp.utils.flag_array_generator(chunk_size, img_ref.shape, img_mask)
print(f'{flag_array.sum()} blocks of {flag_array.shape[0]}*{flag_array.shape[1]}*{flag_array.shape[2]}={flag_array.size} blocks will be calculated')

# load normalization information (table keyed by dataset path, then channel index)
norm_values = {}
if normalization_metadata is not None:
    norm_info = pd.read_pickle(normalization_metadata)
    norm_values['ref_lower'] = norm_info[n5_path][reference_chan]['lower']
    norm_values['ref_upper'] = norm_info[n5_path][reference_chan]['upper']
    norm_values['tar_lower'] = norm_info[n5_path][segment_chan]['lower']
    norm_values['tar_upper'] = norm_info[n5_path][segment_chan]['upper']

## Prepare a zarr container to store the segmentation

In [None]:
# Output store for the per-block labels. Shape and chunking mirror the overlapped reference
# array; chunks_from_dask presumably yields one zarr chunk per dask block (TODO confirm),
# so each segmentation task writes its own chunk.
labeled_overlap_zarr_path = '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/segmentation/segmented_overlap.zarr'
overlap_template = overlap_imgs[0]
zarr_chunks = mFISHwarp.utils.chunks_from_dask(overlap_template)
# mode='a' opens the store read/write and creates it if missing, so rerunning this cell after
# an interruption keeps chunks already written. An existing store can also be reopened with
# just zarr.open(labeled_overlap_zarr_path, mode='a').
labeled_overlap_zarr = zarr.open(
    labeled_overlap_zarr_path,
    mode='a',
    shape=overlap_template.shape,
    chunks=zarr_chunks,
    dtype=np.int32,
)

In [None]:
@ray.remote(num_gpus=0.5,max_calls=1)
def segmentor(
    chunks,           # [reference, target] lazy blocks
    norm_values_ref,  # [lower, upper] for the reference block
    norm_values_tar,  # [lower, upper] for the target block
    channels,
    model,
    anisotropy,
    index,
    min_size,
    chunk_info,
    zarr_file,
):
    """Segment one overlapped block with Cellpose and write the labels to ``zarr_file``.

    Runs as a Ray task: ``num_gpus=0.5`` lets two tasks share one GPU, and
    ``max_calls=1`` makes the worker exit after a single call (presumably to
    release GPU memory between blocks).

    Parameters
    ----------
    chunks : sequence
        ``[reference, target]`` blocks; each must provide ``.compute()``.
    norm_values_ref, norm_values_tar : sequence
        ``[lower, upper]`` passed to ``normalization_two_values`` for the
        reference and target block respectively.
    channels : list
        Cellpose ``channels`` argument.
    model :
        Cellpose model exposing ``eval`` and ``diam_mean``.
    anisotropy : float
        Forwarded to ``model.eval``.
    index :
        Block index; with ``chunk_info`` it selects the destination region.
    min_size : int
        Forwarded to ``model.eval``.
    chunk_info :
        Chunk layout of the overlapped array (dask ``.chunks``).
    zarr_file :
        Writable array receiving the int32 label block.

    Returns
    -------
    None
        Results are only written into ``zarr_file``.
    """
    # Materialize the lazy blocks inside the worker process.
    materialized = [lazy_block.compute() for lazy_block in chunks]
    ref_lower, ref_upper = norm_values_ref[0], norm_values_ref[1]
    tar_lower, tar_upper = norm_values_tar[0], norm_values_tar[1]
    ref_scaled = mFISHwarp.utils.normalization_two_values(materialized[0], ref_lower, ref_upper)
    tar_scaled = mFISHwarp.utils.normalization_two_values(materialized[1], tar_lower, tar_upper)
    # Channel-first stack: axis 0 holds [reference, target], so z is on axis 1.
    block = np.stack([ref_scaled, tar_scaled])

    labels, _flows, _styles = model.eval(
        block,
        channels=channels,
        normalize=False,  # the blocks were normalized above
        z_axis=1,
        diameter=model.diam_mean,
        do_3D=True,
        min_size=min_size,
        progress=False,
        anisotropy=anisotropy,
        tile=False,
    )
    labels = labels.astype(np.int32)

    destination = mFISHwarp.utils.obtain_chunk_slicer(chunk_info, index)
    zarr_file[destination] = labels

In [None]:
### Resume support: skip blocks whose chunk file already exists in the output store.
# zarr v2 directory stores name chunk files "i.j.k", matching the '.'-joined block index below.
# A set gives O(1) membership tests instead of scanning a list for every block.
stored_chunks = set(os.listdir(labeled_overlap_zarr_path))

idxs = mFISHwarp.utils.get_dask_index(overlap_imgs[0])
chunk_info = overlap_imgs[0].chunks

# min_size and anisotropy come from the parameters cell (no silent re-assignment here).
channels = [Training_channel, Second_training_channel]
zarr_file = labeled_overlap_zarr
# Store the model in Ray's object store once; passing it by value would re-serialize it
# for every task. Ray resolves a top-level ObjectRef argument to the object inside the task.
model_ref = ray.put(model)

pending = []
for index in idxs:
    # only blocks flagged by the mask are segmented
    if not flag_array[index[0],index[1],index[2]]:
        continue
    # already written by a previous (interrupted) run
    if '.'.join(str(i) for i in index) in stored_chunks:
        continue
    input_blocks = [mFISHwarp.utils.slicing_with_chunkidx(img, index) for img in overlap_imgs]
    pending.append(segmentor.remote(
        input_blocks,
        [norm_values['ref_lower'],norm_values['ref_upper']],
        [norm_values['tar_lower'],norm_values['tar_upper']],
        channels,
        model_ref,
        anisotropy,
        index,
        min_size,
        chunk_info,
        zarr_file,
    ))

# Wait for all tasks; ray.get re-raises a failed task's exception here instead of the
# error being silently dropped along with a discarded ObjectRef.
ray.get(pending)
print(f'{len(pending)} blocks segmented')

In [None]:
# Background reading on scaling Cellpose to large volumes (blockwise / distributed segmentation).
'''Those are great resource for the scaling up the segmentation'''
# https://github.com/MouseLand/cellpose/issues/244
# https://github.com/MouseLand/cellpose/pull/356
# https://github.com/MouseLand/cellpose/blob/master/cellpose/contrib/distributed_segmentation.py
# https://github.com/MouseLand/cellpose/pull/408/commits/359e480335d68caa04c7caa6ff66a089c767b63f#diff-a3c93a2a4ec84f6c33dc52001df50cd7e5afbc78482858fe54a67987772464da