In [1]:
# Notebook setup: select the GPU, import cellpose and check that the GPU is
# usable, then import the remaining analysis / visualization libraries.
import numpy as np
import time, os, sys
# Enumerate GPUs in PCI bus order and expose only device index 1 to this process.
# NOTE(review): these are set before cellpose is imported, presumably so that the
# CUDA runtime sees them when it is first initialized — confirm before reordering.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="1"

from cellpose import models, core

# Ask cellpose whether a GPU can be used; the flag is passed to the model constructor later.
use_GPU = core.use_gpu()
print('>>> GPU activated? %d'%use_GPU)

# call logger_setup to have output of cellpose written
# NOTE(review): logger_setup is only imported here; it is not called in the visible cells.
from cellpose.io import logger_setup
from cellpose import utils

import random
from skimage import io
import tqdm
import napari
import pandas as pd
# Project package; mFISH3D.segment provides the normalization helpers used below.
import mFISH3D
import mFISH3D.segment

# zarr + dask are used to open the image pyramids as dask arrays.
import zarr
from dask import array as da

>>> GPU activated? 1


In [2]:
# Locate the trained Cellpose model for this dataset and load it.
dataset_folder = "/mnt/ampa_data01/tmurakami/conf_proc/human_ish_training_dataset/slc17a7_double"
train_folder = os.path.join(dataset_folder,'training')

# Directory holding the trained model files.
models_path = os.path.join(train_folder,'models')

# Use the entry whose name sorts last.
# NOTE(review): this assumes the file names sort chronologically (e.g. a
# timestamp suffix) so that "last" means "most recent" — confirm.
models_file = sorted(os.listdir(models_path))
if not models_file:
    # Fail with an explicit message instead of an opaque IndexError on [-1].
    raise FileNotFoundError('No model files found in %s' % models_path)
model_path = os.path.join(models_path, models_file[-1])

model = models.CellposeModel(gpu=use_GPU, pretrained_model=model_path, net_avg=False)
# Other constructor options, kept for reference from the original cell:
#torch=True, diam_mean=10, do_3D=True, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)

In [3]:
# Open the multiscale zarr store of every channel and display each one in napari.
# `pyramids` collects, per channel, the list of resolution levels as dask arrays.
zarr_paths = [
    '/mnt/ampa_data01/tmurakami/220806_visual_02_R01/ch640.zarr',
    '/mnt/ampa_data01/tmurakami/220806_visual_02_R01/ch488.zarr',
]

viewer = napari.Viewer()
pyramids = []
for store_path in zarr_paths:
    store = zarr.open(store_path, mode='r')

    # Levels 0..4 of the store, each wrapped as a dask array.
    levels = []
    for level in range(5):
        levels.append(da.from_zarr(store[level]))

    viewer.add_image(
        levels,
        contrast_limits=[0,10000],
        rgb=False,
        name='img',
        colormap='gray',
        blending='additive',
        multiscale=True,
    )
    pyramids.append(levels)

In [12]:
# Choose the sub-volume to analyse, build a (z, y, x, channel) array from it,
# and preview a footprint-based percentile normalization in napari.

# Alternative input kept from the original cell:
#img_3D = io.imread('/mnt/ampa_data01/tmurakami/conf_proc/human_ish_training_dataset/validation/slc17a7_02.tif')

# Extent of the crop: `depth` along z, `width` along both y and x.
depth = 256
width = 256
# Origin of the crop. Other origins tried previously:
# (636,3546,4376) (636,3558,6352) (629,2752,2068) (629,3860,5662) (629,3832,3030) (636,5177,5345)
z, y, x = 623, 4346, 3911

# Cut the same region out of level 0 of every channel pyramid and put the
# channels on a new trailing axis.
crop = (slice(z, z+depth), slice(y, y+width), slice(x, x+width))
channel_crops = [levels[0][crop].compute() for levels in pyramids]
img_3D = np.stack(channel_crops, axis=-1)

# Zero-pad the channel axis so that the array always has three channels.
n_missing_channels = 3 - len(pyramids)
img_3D = np.pad(img_3D, [(0, 0)] * 3 + [(0, n_missing_channels)], mode='constant')

# Normalize each real channel independently; padded channels stay at zero.
my_img_3D_norm = np.zeros(img_3D.shape, dtype=float)
for ch in range(len(zarr_paths)):
    my_img_3D_norm[..., ch] = mFISH3D.segment.gpu_percentile_normalization(
        img_3D[..., ch],
        footprint=np.ones((1,50,50)),
    )

# Show the raw first channel, the normalized stack as RGB, and the normalized first channel.
viewer = napari.Viewer()
viewer.add_image(img_3D[..., 0], rgb=False, name='image01', colormap='gray', blending='additive')
viewer.add_image(my_img_3D_norm, rgb=True, name='image02', colormap='gray', blending='additive')
viewer.add_image(my_img_3D_norm[..., 0], rgb=False, name='image03', colormap='gray', blending='additive')

<Image layer 'image03' at 0x7f4d0e3aa850>

In [13]:
%%time
# Two-pass segmentation:
#   1. a first pass on the raw crop, run section by section (do_3D=False) with
#      stitching between sections;
#   2. local normalization of every channel, driven by the first-pass masks;
#   3. a second pass with do_3D=True on the normalized volume.
# Original author's tip for the first pass: for faster segmentation adjust the
# diameter — the larger the faster.
print('Start the first segmentation.')
pre_masks, _, _ = model.eval(
    img_3D,
    channels=[1,2],
    z_axis=0,
    do_3D=False,
    min_size=100,
    cellprob_threshold=0.0,
    stitch_threshold=0.3,
    tile=False,
)


print('Start the local normalization.')
img_3D_norm = np.zeros(img_3D.shape, dtype=float)
for ch in range(len(zarr_paths)):
    channel_img = img_3D[..., ch]

    # Local maximum image for this channel, built from the first-pass masks.
    local_max = mFISH3D.segment.local_max_with_interpolator(
        mFISH3D.segment.get_cellular_intensity_interpolator(pre_masks, channel_img),
        channel_img.shape,
    )

    # Percentile normalization that uses the local maximum as the high reference.
    img_3D_norm[..., ch] = mFISH3D.segment.gpu_percentile_normalization(
        channel_img,
        footprint=np.ones((1,5,100)),
        img_high=local_max,
    )


print('Start the second segmentation.')
masks, flows, styles = model.eval(
    img_3D_norm,
    channels=[1,2],
    z_axis=0,
    diameter=10,
    do_3D=True,
    min_size=100,
    cellprob_threshold=0.0,
    tile=False,
)


Start the first segmentation.
Start the local normalization.
Start the second segmentation.
CPU times: user 3min 12s, sys: 52.3 s, total: 4min 5s
Wall time: 2min 14s


In [14]:
# Inspect the result: both raw channels overlaid with the first-pass and the
# final label volumes.
viewer = napari.Viewer()
for ch, (layer_name, cmap) in enumerate([('image01', 'magenta'), ('image02', 'green')]):
    viewer.add_image(img_3D[..., ch], rgb=False, name=layer_name, colormap=cmap, blending='additive')

# The label layers are added with the bare variable names on purpose: the
# recorded output shows the layer being named after the variable ('masks').
viewer.add_labels(pre_masks)
viewer.add_labels(masks)

<Labels layer 'masks' at 0x7f4cdba63fa0>