# Running cellpose with GPUs

In [1]:
import numpy as np
import time, os, sys
# These two CUDA variables are set before cellpose is imported below.
# PCI_BUS_ID makes CUDA number devices by PCI bus order, and
# CUDA_VISIBLE_DEVICES restricts this process to devices 0 and 1.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0,1"

import skimage.io
import zarr

from cellpose import models, core

# Ask cellpose whether GPU 0 (of the visible devices) is usable and report it.
use_GPU = core.use_gpu(gpu_number=0)
print('>>> GPU activated? %d'%use_GPU)

from cellpose import utils
from cellpose import models
# from cellpose.contrib import distributed_segmentation
# import my_distributed_segmentation

# call logger_setup to have output of cellpose written
from cellpose.io import logger_setup
logger_setup();

# ray runs the per-block segmentation tasks defined further down.
import ray

import sys
# pylsfm is imported from a local checkout: the parent directory of SCRIPT_DIR
# ('/home/tmurakami/src/pylsfm') is appended to sys.path so `import pylsfm` resolves.
SCRIPT_DIR = '/home/tmurakami/src/pylsfm/pylsfm'
sys.path.append(os.path.dirname(SCRIPT_DIR))

import pylsfm.morphology
import pylsfm.utils

from dask import array as da

import functools
import operator

# NOTE(review): mFISH3D is imported without a sys.path entry here, so it is
# presumably installed in the environment - confirm.
import mFISH3D.segment_utils
import mFISH3D.segment

>>> GPU activated? 1
2022-09-03 16:11:15,816 [INFO] WRITING LOG OUTPUT TO /home/tmurakami/.cellpose/run.log


## Create mask

In [2]:
# Open the ch488 store read-only and pull entry 4 fully into memory.
# NOTE(review): entry 4 is presumably a down-sampled level of the image
# (the full-resolution data is read from entry 0 later on) - confirm.
nuc_store = zarr.open(
    '/mnt/ampa_data01/tmurakami/220715_prefrontal_q2_R01/ch488.zarr',
    mode='r',
)
nuc_img = nuc_store[4][:]

# Build the mask from the loaded image; 300 is the second argument of
# mask_maker (presumably an intensity threshold - verify in pylsfm.morphology).
mask = pylsfm.morphology.mask_maker(nuc_img, 300)
# skimage.io.imsave('/mnt/ampa_data01/tmurakami/220715_prefrontal_q2_R01/mask.tif',mask)
# skimage.io.imsave('/mnt/ampa_data01/tmurakami/220715_prefrontal_q2_R01/nuc_img.tif',nuc_img)

In [3]:
# import napari
# viewer = napari.Viewer()
# viewer.add_image(mask)
# viewer.add_image(nuc_img)


## make overlapped images


In [3]:
# Read every channel as a dask array, rechunk it to `chunk_size`, and add an
# overlap of `depth` voxels per axis (filled by reflection at the borders).
zarr_paths = [
    '/mnt/ampa_data01/tmurakami/220715_prefrontal_q2_R01/ch640.zarr',
    '/mnt/ampa_data01/tmurakami/220715_prefrontal_q2_R01/ch488.zarr'
]

chunk_size = (256,512,512)
depth = (32,64,64)
boundary = "reflect"

overlap_imgs = []
img_shapes = []
for zarr_path in zarr_paths:
    # entry 0 of each store is the image that gets segmented
    img_zarr = zarr.open(zarr_path, mode='r')[0]
    img_shapes.append(img_zarr.shape)

    img_da = da.rechunk(da.from_zarr(img_zarr),chunks=chunk_size)

    overlap_imgs.append(da.overlap.overlap(img_da, depth, boundary))

# All channels are later stacked block by block, so they must share one shape.
# Checking here also avoids relying on `img_zarr` leaking out of the loop.
assert len(set(img_shapes)) == 1, (
    f'expected at least one channel and identical shapes for all channels, got {img_shapes}'
)
img_shape = img_shapes[0]

# One boolean per block: only flagged blocks are segmented later.
# NOTE(review): the flags are presumably derived from where `mask` is non-empty -
# confirm in pylsfm.utils.flag_array_generator.
flag_array = pylsfm.utils.flag_array_generator(chunk_size, img_shape, mask)
print(f'{flag_array.sum()} blocks of {flag_array.shape[0]}*{flag_array.shape[1]}*{flag_array.shape[2]}={flag_array.size} blocks will be calculated')

1006 blocks of 5*26*17=2210 blocks will be calculated


## Load the trained model

In [4]:
# Load the custom-trained cellpose model from the training folder.
train_folder = '/mnt/ampa_data01/tmurakami/conf_proc/human_ish_training_dataset/slc17a7_double/training'

models_path = os.path.join(train_folder, 'models')
models_file = sorted(os.listdir(models_path))
# Take the last file name in sorted order.
# NOTE(review): the names end in a timestamp, so this is presumably the most
# recently trained model - confirm.
model_path = os.path.join(models_path, models_file[-1])

my_model = models.CellposeModel(
    gpu=True,
    pretrained_model=model_path,
    net_avg=False,
)


2022-09-03 16:12:47,429 [INFO] >>>> loading model /mnt/ampa_data01/tmurakami/conf_proc/human_ish_training_dataset/slc17a7_double/training/models/cellpose_residual_on_style_on_concatenation_off_training_2022_08_27_16_18_25.738022
2022-09-03 16:12:47,494 [INFO] ** TORCH CUDA version installed and working. **
2022-09-03 16:12:47,495 [INFO] >>>> using GPU
2022-09-03 16:12:48,383 [INFO] >>>> model diam_mean =  30.000 (ROIs rescaled to this size during training)
2022-09-03 16:12:48,383 [INFO] >>>> model diam_labels =  12.851 (mean diameter of training ROIs)


## prepare zarr container to save segmentation

In [5]:
# Output store for the per-block label images. It is opened in append mode
# ('a'), so an existing store at this path is reused rather than overwritten.
block_labeled_zarr_path = '/mnt/ampa_data01/tmurakami/220715_prefrontal_q2_R01/R01_R01/R01ch640_to_R01_segmentation/segmented_overlap.zarr'

# The first overlapped channel defines the geometry of the output array.
overlap_reference = overlap_imgs[0]
block_labeled_zarr = zarr.open(
    block_labeled_zarr_path,
    mode='a',
    shape=overlap_reference.shape,
    chunks=pylsfm.utils.chunks_from_dask(overlap_reference),
    dtype=np.int32,
)
# block_labeled_zarr = zarr.open(block_labeled_zarr_path,mode='a')

In [6]:
@ray.remote(num_gpus=0.5,max_calls=1)
def segmentor(
    chunks,
    channels,
    model_type,
    diameter_yx,
    anisotropy,
    fast_mode,
    index,
    min_size,
    chunk_info,
    zarr_file,
):
    """Segment one block in two passes and write the labels into ``zarr_file``.

    The per-channel blocks are computed, stacked along a new last axis and
    zero-padded to three channels. A first pass (``do_3D=False`` with
    ``stitch_threshold=0.3``) gives coarse masks. If it finds anything, every
    channel is normalized with helpers from ``mFISH3D.segment`` and a second
    pass (``do_3D=True``) produces the labels that are stored.

    Parameters
    ----------
    chunks : sequence
        One block per channel; every item must provide ``.compute()``
        (the calling cell passes dask arrays). The padding is computed as
        ``3 - len(chunks)``, so more than three channels are not supported.
    channels
        Forwarded unchanged to ``model_type.eval(channels=...)``.
    model_type
        Object exposing ``eval``; despite the name, the notebook passes a
        model instance here.
    diameter_yx
        Forwarded as ``diameter`` to the second ``eval`` call only.
    anisotropy, fast_mode
        Accepted to keep the call signature stable; not used in the body.
    index : sequence of 3 ints
        Position of this block along each of the three axes.
    min_size
        Forwarded as ``min_size`` to both ``eval`` calls.
    chunk_info : sequence of 3 sequences of ints
        Block sizes along each axis (the caller passes ``.chunks`` of the
        overlapped dask array); used to turn ``index`` into slices.
    zarr_file
        Array-like target supporting 3-D slice assignment.

    Returns
    -------
    None
        Nothing is written when the first pass detects no masks.
    """
    # Materialize every channel and arrange them as (..., channel).
    computed = [block.compute() for block in chunks]
    n_channels = len(computed)
    stacked = np.moveaxis(np.asarray(computed), 0, -1)
    channel_padding = ((0, 0), (0, 0), (0, 0), (0, 3 - n_channels))
    stacked = np.pad(stacked, channel_padding, 'constant')

    # First pass: coarse estimate of where the cells are.
    coarse_masks, _, _ = model_type.eval(
        stacked,
        channels=channels,
        z_axis=0,
        do_3D=False,
        min_size=min_size,
        stitch_threshold=0.3,
        tile=False,
    )

    # Nothing detected in this block: leave the output store untouched.
    if not np.any(coarse_masks != 0):
        return

    # Normalize each real channel against its local maximum; the padded
    # channels stay at zero.
    normalized = np.zeros(stacked.shape, dtype=float)
    for ch in range(n_channels):
        channel_img = stacked[..., ch]
        intensity_interp = mFISH3D.segment.get_cellular_intensity_interpolator(coarse_masks, channel_img)
        channel_local_max = mFISH3D.segment.local_max_with_interpolator(intensity_interp, channel_img.shape)
        normalized[..., ch] = mFISH3D.segment.gpu_percentile_normalization(
            channel_img,
            footprint=np.ones((1,5,100)),
            img_high=channel_local_max,
        )

    # Second pass: full 3D segmentation on the normalized data.
    labels, _, _ = model_type.eval(
        normalized,
        channels=channels,
        z_axis=0,
        diameter=diameter_yx,
        do_3D=True,
        min_size=min_size,
        tile=False,
    )
    labels = labels.astype(np.int32)

    # Translate the block index into slices: along each axis the block starts
    # after the summed sizes of all preceding blocks.
    region = []
    for axis in range(3):
        start = sum(chunk_info[axis][:index[axis]])
        stop = start + chunk_info[axis][index[axis]]
        region.append(slice(start, stop))

    zarr_file[region[0], region[1], region[2]] = labels

In [None]:
# idxs = mFISH3D.segment_utils.get_dask_index(overlap_imgs[0])
# chunk_info = overlap_imgs[0].chunks

# diameter_yx = 10
# anisotropy = None
# fast_mode = True
# min_size = 100
# model_type=my_model
# channels=[1,2]
# zarr_file = block_labeled_zarr

# for index in idxs:
#     if flag_array[index[0],index[1],index[2]]:
#         input_blocks = [mFISH3D.segment_utils.slicing_with_chunkidx(i, index) for i in overlap_imgs]
#         segmentor.remote(
#             input_blocks,
#             channels,
#             model_type,
#             diameter_yx,
#             anisotropy,
#             fast_mode,
#             index,
#             min_size,
#             chunk_info,
#             zarr_file,
#         )

In [None]:
#

In [None]:
### Submit the segmentation tasks; safe to re-run to resume an interrupted computation.
# Blocks whose chunk file already exists in the output store are skipped.
# NOTE(review): this assumes the store names chunk files "<i>.<j>.<k>" and that
# its chunks line up with the dask blocks - confirm for the zarr version in use.
# A set gives constant-time membership tests for the lookups in the loop below.
stored_chunks = set(os.listdir(block_labeled_zarr_path))

# last_chunk = (1,22,3)

idxs = mFISH3D.segment_utils.get_dask_index(overlap_imgs[0])
chunk_info = overlap_imgs[0].chunks

diameter_yx = 10
anisotropy = None
fast_mode = True
min_size = 100
model_type=my_model
channels=[1,2]
zarr_file = block_labeled_zarr

# Store the model in the Ray object store once and hand every task a reference,
# instead of passing the same large object by value on each call. Ray resolves
# the reference to the model before the task body runs.
model_ref = ray.put(model_type)

# Keep the ObjectRefs: without them there is no way to wait for the tasks or to
# see an exception raised inside a worker.
segment_refs = []

# flag = 0
for index in idxs:
    # only blocks flagged for computation
    if not flag_array[index[0],index[1],index[2]]:
        continue
    # if index == last_chunk:
    #     flag = 1
    # already present in the output store from a previous run
    if '.'.join([str(i) for i in index]) in stored_chunks:# flag == 1:
        continue
    input_blocks = [mFISH3D.segment_utils.slicing_with_chunkidx(i, index) for i in overlap_imgs]
    segment_refs.append(
        segmentor.remote(
            input_blocks,
            channels,
            model_ref,
            diameter_yx,
            anisotropy,
            fast_mode,
            index,
            min_size,
            chunk_info,
            zarr_file,
        )
    )

# Submission does not block. To wait for all blocks and re-raise any worker
# error, run `ray.get(segment_refs)` in a later cell.

In [None]:
# 

In [None]:
'''Those are great resource for the scaling up the segmentation'''
# https://github.com/MouseLand/cellpose/issues/244
# https://github.com/MouseLand/cellpose/pull/356
# https://github.com/MouseLand/cellpose/blob/master/cellpose/contrib/distributed_segmentation.py
# https://github.com/MouseLand/cellpose/pull/408/commits/359e480335d68caa04c7caa6ff66a089c767b63f#diff-a3c93a2a4ec84f6c33dc52001df50cd7e5afbc78482858fe54a67987772464da