In [None]:
import numpy as np
import time, os, sys
# Number the CUDA devices by PCI bus ID and expose only device 0 to this process.
# These variables are set before cellpose is imported, presumably so that the
# GPU selection is already in place when the GPU backend initializes - TODO confirm.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

from cellpose import models, core

# Ask cellpose whether a usable GPU is available and report the answer (printed as 0/1).
use_GPU = core.use_gpu()
print('>>> GPU activated? %d'%use_GPU)

# logger_setup is meant to make cellpose write its log output.
# NOTE(review): it is only imported here; none of the visible cells calls it.
from cellpose.io import logger_setup
from cellpose import utils

from skimage import io
import pandas as pd

import zarr
import napari

import json
import shutil

In [None]:
def get_file_prefix(directory,postfix):
    """
    List the name prefixes of the files in `directory` whose name contains `postfix`.

    The prefix of a file is the part of its name that precedes the first
    occurrence of `postfix`. `postfix` is matched as a literal string, so
    characters such as '.' or '(' have no regular-expression meaning.

    Parameters
    ----------
    directory : path of the directory to scan (its sub-directories are not visited).
    postfix : literal substring to look for in each entry name.

    Returns
    -------
    list of str, ordered by entry name; empty when nothing matches.
    """
    prefixes = []
    for name in sorted(os.listdir(directory)):
        cut = name.find(postfix)
        # find() returns -1 for names that do not contain the postfix at all
        if cut != -1:
            prefixes.append(name[:cut])

    return prefixes

def augmenter(x):
    """
    Generate the seven flipped/rotated variants of a single image or label array.

    The rotations are by 90, 180 and 270 degrees in the plane of the first two
    axes (the np.rot90 default). The mirror image is taken along axis 1 only, so
    it is a true reflection rather than a second 180 degree rotation, and any
    further axis (e.g. a trailing channel axis) keeps its order. Applying the
    function to an image and to its mask therefore yields matching pairs.

    x is an input array with at least two dimensions.

    Returns the tuple
    (90rot, 180rot, 270rot, flip, flip90rot, flip180rot, flip270rot).
    """
    # Only the plane of the first two axes is transformed. For the channel-last
    # images and the 2D labels this notebook passes in, that is the yx plane.
    plane = (0, 1)
    x90rot, x180rot, x270rot = (np.rot90(x, k, axes=plane) for k in (1, 2, 3))

    # np.flip without an axis would reverse every axis, channels included.
    xflip = np.flip(x, axis=1)
    xflip90rot, xflip180rot, xflip270rot = (np.rot90(xflip, k, axes=plane) for k in (1, 2, 3))

    return (x90rot, x180rot, x270rot, xflip, xflip90rot, xflip180rot, xflip270rot)

def add_last_size_to3(array, size=3):
    """
    Zero-pad the last axis of `array` so that it has at least `size` entries.

    The notebook uses this to bring channel-last images up to three channels
    before saving them (per the original author's note, Cellpose otherwise gets
    confused).

    Parameters
    ----------
    array : numpy array with any number of dimensions (at least one).
    size : minimum length of the last axis; defaults to 3.

    Returns
    -------
    `array` itself when its last axis already holds `size` or more entries,
    otherwise a new array with zeros appended at the end of the last axis.
    """
    missing = size - array.shape[-1]
    if missing <= 0:
        return array

    # pad only at the end of the last axis, whatever the number of dimensions
    pad_width = [(0, 0)] * (array.ndim - 1) + [(0, missing)]
    return np.pad(array, pad_width, 'constant', constant_values=0)

In [None]:
# Where the dataset lives, and the datacard (JSON) that lists its train/test files.
dataset_folder = "/mnt/ampa02_data01/tmurakami/model_training/trial_only_cells"
datacard_path = "/mnt/ampa02_data01/tmurakami/model_training/tatz_datacard.json"

# Training schedule.
number_of_epochs =  5000
batch_size =  8
initial_learning_rate = 0.0002

# Channels handed to Cellpose. Cellpose numbers them 0=gray, 1=red, 2=green, 3=blue,
# so a green signal is selected with 2.
Training_channel = 2
Second_training_channel = 1

# Train on the pre-normalized images (True) or on the raw ones (False).
pre_norm = True

# Pretrained model to start from, if there is any.
model_to_load = None
# The diameter is a very important hyperparameter that can affect both speed and accuracy.
# None is accepted here, but giving a value is highly recommended.
diameter = 10
# Zero keeps images without any mask in the training set, to teach what is not a cell.
min_train_masks = 0

# File-name parts that tell images and masks apart.
img_postfix = 'img_norm' if pre_norm else 'img'
mask_postfix = 'mask'
extension = '.tif'

# Read the datacard.
with open(datacard_path) as datacard_file:
    datacard = json.load(datacard_file)

In [None]:
### Separate the training and test datasets according to the datacard.
# Folders that will receive the training and the test files.
train_dir_name = 'training'
test_dir_name = 'test'
train_folder = os.path.join(dataset_folder, train_dir_name)
test_folder = os.path.join(dataset_folder, test_dir_name)

# exist_ok leaves folders that are already present untouched, so the cell can be re-run.
for split_folder in (train_folder, test_folder):
    os.makedirs(split_folder, exist_ok=True)


In [None]:
def _copy_datacard_files(entries, source_dir, target_dir):
    """
    Copy every file named in the datacard `entries` from `source_dir` to `target_dir`.

    Parameters
    ----------
    entries : list of dicts whose values are file names relative to `source_dir`.
    source_dir : folder the files are read from.
    target_dir : existing folder the files are copied into.

    Returns
    -------
    list of the file names that were not found in `source_dir` and were skipped.
    """
    missing = []
    for entry in entries:
        for file_name in entry.values():
            source = os.path.join(source_dir, file_name)
            if os.path.isfile(source):
                # copied rather than moved, so the dataset folder stays complete
                shutil.copyfile(source, os.path.join(target_dir, file_name))
            else:
                missing.append(file_name)
    return missing


# Same procedure for both splits; report the entries that could not be copied
# instead of dropping them silently.
for split_name, destination_folder in (('train', train_folder), ('test', test_folder)):
    skipped = _copy_datacard_files(datacard['datasets'][split_name], dataset_folder, destination_folder)
    if skipped:
        print('>>> %s: %d file(s) listed in the datacard were not found and were skipped: %s'
              % (split_name, len(skipped), ', '.join(skipped)))

In [None]:
### data augmentation
# Write the flipped/rotated variants of every training image and label into the training folder.
def _save_augmented(array, file_id, kind, file_extension):
    """Save each augmenter() variant of `array` as <file_id>_<i>_<kind><file_extension> in train_folder."""
    for i, variant in enumerate(augmenter(array)):
        name = os.path.join(train_folder, file_id+"_"+str(i)+"_"+kind+file_extension)
        io.imsave(name, variant, check_contrast=False)


for train_data in datacard['datasets']['train']:
    for key, file_name in train_data.items():
        source = os.path.join(train_folder, file_name)
        if not os.path.isfile(source):
            # nothing was copied for this entry, so there is nothing to augment
            continue
        base = os.path.basename(source)
        file_extension = os.path.splitext(source)[1]
        # The first four characters of the file name are reused as the name of the
        # augmented files - assumes they form the sample ID, TODO confirm.
        file_id = base[:4]
        if key in ("img", "img_norm"):
            x = io.imread(source)
            if x.ndim > 2:
                # Assumes multi-channel images are stored channel-first: move the
                # channels last and pad them up to three.
                x = add_last_size_to3(np.moveaxis(x, 0, -1), 3)
            # A 2D image has no channel axis; moving its first axis would transpose
            # it and break the correspondence with its mask, so it is used as is.
            _save_augmented(x, file_id, key, file_extension)
        elif key == "label":
            # labels are saved under the mask postfix
            _save_augmented(io.imread(source), file_id, "mask", file_extension)

In [None]:
### run the training
if pre_norm:
    # no_norm with diameter parameter
    !/home/tmurakami/app/miniconda3/envs/cellpose/bin/python -m cellpose --train --use_gpu --dir $train_folder --test_dir $test_folder --pretrained_model $model_to_load --diam_mean $diameter --chan $Training_channel --chan2 $Second_training_channel --n_epochs $number_of_epochs --learning_rate $initial_learning_rate --batch_size $batch_size --img_filter $img_postfix --mask_filter $mask_postfix  --verbose --min_train_masks $min_train_masks --no_norm
else:
    # with diameter parameter
    !/home/tmurakami/app/miniconda3/envs/cellpose/bin/python -m cellpose --train --use_gpu --dir $train_folder --test_dir $test_folder --pretrained_model $model_to_load --diam_mean $diameter --chan $Training_channel --chan2 $Second_training_channel --n_epochs $number_of_epochs --learning_rate $initial_learning_rate --batch_size $batch_size --img_filter $img_postfix --mask_filter $mask_postfix  --verbose --min_train_masks $min_train_masks
    # without diameter
    # !/home/tmurakami/app/miniconda3/envs/cellpose/bin/python -m cellpose --train --use_gpu --dir $train_folder --test_dir $test_folder --pretrained_model $model_to_load --chan $Training_channel --chan2 $Second_training_channel --n_epochs $number_of_epochs --learning_rate $initial_learning_rate --batch_size $batch_size --img_filter $img_postfix --mask_filter $mask_postfix  --verbose --min_train_masks $min_train_masks