In [1]:
import numpy as np
import time, os, sys
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

from cellpose import models, core

use_GPU = core.use_gpu()
print('>>> GPU activated? %d'%use_GPU)

# call logger_setup to have output of cellpose written
from cellpose.io import logger_setup
from cellpose import utils

import random
from skimage import io
import pandas as pd

import zarr
import napari

>>> GPU activated? 1


In [2]:
def get_file_prefix(directory,postfix):
    files = os.listdir(directory)
    files.sort()
    files = pd.Series(files)
    selected_files = files[files.str.contains(postfix)]
    prefixes = selected_files.str.split(postfix,expand=True)[0].to_list()
    
    return prefixes

def augmenter(x):
    """
    Augmentation of a single input/label image pair. 90rot, 180rot, 270rot, flip, flip90rot, flip180rot, flip270rot will be generated
    x is an input image
    """
    # Note that we only use fliprots along axis=(1,2), i.e. the yx axis 
    # as 3D microscopy acquisitions are usually not axially symmetric
    x90rot = np.rot90(x)
    x180rot = np.rot90(x90rot)
    x270rot = np.rot90(x180rot)
    xflip = np.flip(x)
    xflip90rot = np.rot90(xflip)
    xflip180rot = np.rot90(xflip90rot)
    xflip270rot = np.rot90(xflip180rot)
    
    return (x90rot, x180rot, x270rot, xflip, xflip90rot, xflip180rot, xflip270rot)

def add_last_size_to3(array, size=3):
    """
    add the size of the last dimension up to 3. Otherwise Cellpose looks get confused.
    """
    if array.shape[-1] <3:
        array = np.pad(array, ((0, 0), (0, 0), (0, size-array.shape[-1])), 'constant' ,constant_values=0)
        return array
    else:
        return array

In [3]:
# path to dataset
dataset_folder = "/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/model_training/crops"

# other parameters for training.
number_of_epochs =  500

Training_channel = 2 # I do not know but the cellpose see the images as KRGB. If the color is green, set it to 2.
Second_training_channel = 1

batch_size =  8
initial_learning_rate = 0.0002

img_postfix = 'img'
mask_postfix = 'mask'
extension = '.tif'

In [4]:
# make folder to save augmented data
aug_folder = os.path.join(dataset_folder, 'augment')
if not os.path.isdir(aug_folder):
    os.makedirs(aug_folder)

In [5]:
### data augmentation # skip here for prediction
prefixes = get_file_prefix(dataset_folder, img_postfix)
for prefix in prefixes:
    x = np.moveaxis(io.imread(os.path.join(dataset_folder,prefix+img_postfix+extension)),0,-1)
    x = add_last_size_to3(x,3)
    y = io.imread(os.path.join(dataset_folder,prefix+mask_postfix+extension))
    x_args = augmenter(x)
    y_args = augmenter(y)
    counter = 0
    io.imsave(os.path.join(aug_folder,prefix+'_'+'arg_'+img_postfix+extension),x)
    io.imsave(os.path.join(aug_folder,prefix+'_'+'arg_'+mask_postfix+extension),y)
    for x_arg, y_arg in zip(x_args, y_args):
        io.imsave(os.path.join(aug_folder,prefix+str(counter)+'arg_'+img_postfix+extension),x_arg)
        io.imsave(os.path.join(aug_folder,prefix+str(counter)+'arg_'+mask_postfix+extension),y_arg)
        counter += 1

  io.imsave(os.path.join(aug_folder,prefix+'_'+'arg_'+img_postfix+extension),x)
  io.imsave(os.path.join(aug_folder,prefix+'_'+'arg_'+mask_postfix+extension),y)
  io.imsave(os.path.join(aug_folder,prefix+str(counter)+'arg_'+img_postfix+extension),x_arg)
  io.imsave(os.path.join(aug_folder,prefix+str(counter)+'arg_'+mask_postfix+extension),y_arg)
  io.imsave(os.path.join(aug_folder,prefix+str(counter)+'arg_'+img_postfix+extension),x_arg)
  io.imsave(os.path.join(aug_folder,prefix+str(counter)+'arg_'+mask_postfix+extension),y_arg)
  io.imsave(os.path.join(aug_folder,prefix+str(counter)+'arg_'+img_postfix+extension),x_arg)
  io.imsave(os.path.join(aug_folder,prefix+str(counter)+'arg_'+mask_postfix+extension),y_arg)
  io.imsave(os.path.join(aug_folder,prefix+str(counter)+'arg_'+img_postfix+extension),x_arg)
  io.imsave(os.path.join(aug_folder,prefix+str(counter)+'arg_'+mask_postfix+extension),y_arg)
  io.imsave(os.path.join(aug_folder,prefix+str(counter)+'arg_'+img_postfix+extension),x_arg

In [5]:
# make folders to save traning and test dataset
train_folder = os.path.join(aug_folder,'training')
test_folder = os.path.join(aug_folder,'test')

if not os.path.isdir(train_folder):
    os.makedirs(train_folder)
if not os.path.isdir(test_folder):
    os.makedirs(test_folder)

In [9]:
# divide dataset to traning and test dataset # skip here for prediction
prefixes = get_file_prefix(aug_folder, img_postfix)

rng = np.random.RandomState(42)
ind = rng.permutation(len(prefixes))
n_val = max(1, int(round(0.15 * len(ind))))
ind_train, ind_val = ind[:-n_val], ind[-n_val:]

for i in ind_val:
    os.rename(os.path.join(aug_folder,prefixes[i]+img_postfix+extension),os.path.join(test_folder,prefixes[i]+img_postfix+extension))
    os.rename(os.path.join(aug_folder,prefixes[i]+mask_postfix+extension),os.path.join(test_folder,prefixes[i]+mask_postfix+extension))
for i in ind_train:
    os.rename(os.path.join(aug_folder,prefixes[i]+img_postfix+extension),os.path.join(train_folder,prefixes[i]+img_postfix+extension))
    os.rename(os.path.join(aug_folder,prefixes[i]+mask_postfix+extension),os.path.join(train_folder,prefixes[i]+mask_postfix+extension))

In [6]:
### set pretrained model if there is any # skip here for prediction
model_to_load = None 
diameter = 10 # diameter is a very important hyperparameter, which could affect both the speed and accuracy 
min_train_masks = 0 # Cellpose seems to ignore the images with no labels even if this parameter sets to be zero.

# run the training
# no_norm will make the training more challenging. Skip this option.
# with diameter parameter
!/home/tmurakami/app/miniconda3/envs/cellpose/bin/python -m cellpose --train --use_gpu --dir $train_folder --test_dir $test_folder --pretrained_model $model_to_load --diam_mean $diameter --chan $Training_channel --chan2 $Second_training_channel --n_epochs $number_of_epochs --learning_rate $initial_learning_rate --batch_size $batch_size --img_filter $img_postfix --mask_filter $mask_postfix  --verbose --min_train_masks $min_train_masks
# without diameter
# !/home/tmurakami/app/miniconda3/envs/cellpose/bin/python -m cellpose --train --use_gpu --dir $train_folder --test_dir $test_folder --pretrained_model $model_to_load --chan $Training_channel --chan2 $Second_training_channel --n_epochs $number_of_epochs --learning_rate $initial_learning_rate --batch_size $batch_size --img_filter $img_postfix --mask_filter $mask_postfix  --verbose --min_train_masks 0

2024-05-31 14:24:18,871 [INFO] WRITING LOG OUTPUT TO /home/tmurakami/.cellpose/run.log
2024-05-31 14:24:18,871 [INFO] 
cellpose version: 	3.0.8 
platform:       	linux 
python version: 	3.8.19 
torch version:  	1.12.0
2024-05-31 14:24:19,599 [INFO] ** TORCH CUDA version installed and working. **
2024-05-31 14:24:19,599 [INFO] >>>> using GPU
2024-05-31 14:24:19,840 [INFO] 102 / 102 images in /mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/model_training/crops/augment/training folder have labels
2024-05-31 14:24:19,883 [INFO] 18 / 18 images in /mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/model_training/crops/augment/test folder have labels
2024-05-31 14:24:19,883 [INFO] >>>> during training rescaling images to fixed diameter of 10.0 pixels
2024-05-31 14:24:19,952 [INFO] >>>> no model weights loaded
2024-05-31 14:24:19,953 [INFO] flows precomputed
2024-05-31 14:24:19,981 [INFO] flows precomputed
2024-05-31 14:24:19,987 [INFO] >>> computing diameters
  return

In [7]:
# load trained model for predicition
models_path = os.path.join(train_folder,'models')
models_file = os.listdir(models_path); models_file.sort()
model_path = os.path.join(train_folder,'models',models_file[-1])

model = models.CellposeModel(gpu=use_GPU, pretrained_model=model_path) #torch=True, diam_mean=10, do_3D=True, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)

In [8]:
# image path
fix_n5_path = '/mnt/ampa02_data01/tmurakami/240417_whole_4color_1st_M037-3pb/fused/fused.n5' # zarr with pyramid resolution

# create Zarr file object
fix_zarr = zarr.open(store=zarr.N5Store(fix_n5_path), mode='r')
voxel_size = (2.0,1.3,1.3)

corner_positions = [1235,2510,776]
crop_size = [128,256,256]
segment_chan = 1
reference_chan = 3

# load images according to the input parameters.
n5_setups = list(fix_zarr.keys())
img_ref = fix_zarr[n5_setups[reference_chan]]['timepoint0']['s0']
img_ref_ = img_ref[tuple(slice(i,i+j) for i,j in zip(corner_positions, crop_size))]

img = fix_zarr[n5_setups[segment_chan]]['timepoint0']['s0']
img_ = img[tuple(slice(i,i+j) for i,j in zip(corner_positions, crop_size))]

imgs = np.stack([img_ref_,img_])

In [9]:
# theoretically, anisotropy parameter affects the accuracy. However in practice, changing this values to be the exact voxel ratio does not significantly add accuracy. 
# this may be because of the non-isotropic PSF of light-sheet. 
anisotropy = voxel_size[1]/voxel_size[0] 

# with diameter
%time masks, flows, styles  = model.eval(imgs, channels=[Training_channel,Second_training_channel], z_axis=1, diameter=diameter, do_3D=True, min_size=40, progress=True, anisotropy=anisotropy)
# without diameter
# %time masks, flows, styles = model.eval(imgs, channels=[Training_channel,Second_training_channel], z_axis=1, diameter=None, do_3D=True, min_size=40, progress=True, anisotropy=anisotropy)

CPU times: user 9.29 s, sys: 2.19 s, total: 11.5 s
Wall time: 10.4 s


In [10]:
viewer = napari.Viewer()
viewer.add_image(imgs, channel_axis=0, name='image01', blending='additive')
viewer.add_labels(masks)

<Labels layer 'masks' at 0x7ff55c315a60>

In [11]:
#