# Description
*author:* Vina My Pham<br>
*supervisor:* Robin van der Weide<br>
*project:* MSc internship project<br>
<br>
*date:* January 15 - July 19, 2024<br>
*host:* Kind group, Hubrecht Institute<br>
*university:* Bioinformatics, Wageningen University & Research<br>

---

Notebook to finetune a pre-trained *cellpose* model with a human-in-the-loop approach (HITL) [1].

The method consists of the following steps:
  1. Running the model on an image
  2. Refining the predictions to match the required segmentation style.
  3. Adding the image and refined masks to the training data.
  4. Retraining the model on the training data
  5. Repeating steps 1-4 until the retrained model reaches the required performance.

![](https://drive.google.com/uc?export=view&id=1dUD8rGO2QQnSqrhVBRCREPPpfHZ1xBAy)

**References**
1. Pachitariu, M., Stringer, C. Cellpose 2.0: how to train your own model. Nat Methods 19, 1634–1641 (2022). https://doi.org/10.1038/s41592-022-01663-4

# Notebook initialisation
**Description:** This block contains the code for the set-up of the notebook.

0. mount the notebook to the drive
1. install required dependencies with a `requirements.txt` file
2. import python modules
3. define custom classes and functions
4. check GPU connection

**Instructions:**
1. Select the desired initialisation settings.
2. Execute the code at the beginning of the run.




In [3]:
# Switches read by the set-up and GPU-status cells of this notebook:
#   mount_drive           - mount Google Drive at /content/gdrive
#   pip_install           - install libcairo2-dev (apt-get) and the pip requirements
#   pip_requirements_path - requirements file handed to `pip install -r`
#   import_pkg            - import the python modules used by the notebook
#   use_gpu               - expected GPU state, passed to check_gpu_connection()
#@markdown ## Initialisation settings
mount_drive = True #@param {type:"boolean"}
#@markdown <br>
pip_install = True #@param {type:"boolean"}
pip_requirements_path = "/content/gdrive/MyDrive/msc-internship_HI_2024_vmp/01_notebooks/colab_requirements.txt" #@param {type:"string"}
#@markdown <br>
import_pkg = True #@param {type:"boolean"}
#@markdown <br>
use_gpu = False #@param {type:"boolean"}

In [4]:
#@markdown [code: imports - runtime: ~3m23s]
# Set-up cell. Each stage is gated by a switch from the initialisation settings:
#   mount_drive -> mount Google Drive
#   pip_install -> install the system library and the pip requirements
#   import_pkg  -> import the modules used by the rest of the notebook
if mount_drive:
  from google.colab import drive
  drive.mount('/content/gdrive', force_remount=True)

if pip_install:
  import subprocess
  import sys
  # check=True aborts the cell when an installation step fails
  subprocess.run(['apt-get', 'install', '-y', 'libcairo2-dev'], check=True)
  # `<interpreter> -m pip` installs into the environment of the interpreter that
  # runs this kernel; a bare `pip` on PATH may belong to another interpreter
  subprocess.run([sys.executable, '-m', 'pip', 'install', '-r', pip_requirements_path],
                 check=True)
  print("Succesfully installed requirements with pip")

if import_pkg:
  import os
  import json
  import copy
  from datetime import datetime
  from collections import defaultdict, Counter
  from tifffile import imwrite
  from cellpose.io import imread, masks_flows_to_seg, save_masks
  from cellpose import core, utils, plot, models
  import matplotlib.pyplot as plt
  import numpy as np

  print("Succesfully imported packages")

Mounted at /content/gdrive
Succesfully installed requirements with pip
Succesfully imported packages


In [7]:
#@markdown [code: classes]
class Slice:
  """Image slice and its segmentation mask, referenced by file path.

  Only the paths are stored; the pixel data is read from disk on every call
  to `matrix()`.

  Attributes:
    name (str): Name of the slice
    img_path (str): Absolute path to the image file
    mask_path (str): Absolute path to the mask file

  Methods:
    matrix(matrix_type: str) -> np.ndarray:
        Retrieve the image or mask of the slice as a matrix

    show(mode: str = "",
         seg_channel: int = 2,
         subplot_size: tuple = (5,5)) -> None:
        Display the image and/or mask of the slice using matplotlib
  """
  def __init__(self, name: str, img_path: str, mask_path: str):
    """Initialise object

    Args:
      name (str): Name of the slice
      img_path (str): Absolute path to the image file
      mask_path (str): Absolute path to the mask file
    """
    self.name = name
    self.img_path = img_path
    self.mask_path = mask_path

  @staticmethod
  def _channels_first(matrix):
    """Return `matrix` with its channel axis moved to the front.

    The channel axis is taken to be the smallest axis (first one on ties).
    Arrays with fewer than three dimensions have no channel axis and are
    returned unchanged.
    """
    if matrix.ndim < 3:
      return matrix
    channel_idx = matrix.shape.index(min(matrix.shape))
    return np.moveaxis(matrix, channel_idx, 0)

  @staticmethod
  def _merged_rgb(img_matrix):
    """Stack the first three channels of a channel-first image as (nX, nY, 3).

    Images with fewer than three channels are padded with all-zero channels
    so that the merged view can still be drawn.
    """
    planes = [img_matrix[i, :, :] for i in range(min(3, img_matrix.shape[0]))]
    while len(planes) < 3:
      planes.append(np.zeros_like(planes[0]))
    return np.dstack(planes)

  def matrix(self, matrix_type: str) -> "np.ndarray":
    """Retrieve the image or masks of the slice as a matrix

    Args:
        matrix_type (str): "img" or "mask"

    Returns:
        np.ndarray: Matrix representing the image or mask
                    if image with channels: shape = (n channels x nX x nY)
                    if image without channel axis (2-D): returned as read

    Raises:
        ValueError: If `matrix_type` is neither "img" nor "mask"
    """
    if matrix_type not in ("img", "mask"):
      raise ValueError(f"matrix type `{matrix_type}` not recognised. ")

    if matrix_type == "img":
      return self._channels_first(imread(self.img_path))

    # NOTE(review): `reassign_ids` is not defined in this class; it is looked up
    # in the notebook namespace at call time - confirm it is defined before
    # masks are loaded.
    return reassign_ids(imread(self.mask_path))

  def show(self,
           mode: str = "",
           seg_channel: int = 2,
           subplot_size: tuple = (5,5)) -> None:
    """Display the image and mask using matplotlib.

    Args:
        mode (str): Visualisation mode. Valid options:
                    "img": the original image (separate channels + merged)
                    "mask": the original image and mask outlines.
                    "": both "img" and "mask"
                    Default: ""
        seg_channel (int): channel used for segmentation - R:1, G:2, B:3.
                           Default: 2
        subplot_size (tuple): Size of the subplot_size (width, height).
                              Default: (5, 5)

    Raises:
        ValueError: If `mode` is not one of the valid options
    """
    if mode not in ("", "mask", "img"):
      error_msg = f"Visualisation mode `{mode}` not recognised. Valid " +\
                  "options are ['', 'img', 'mask']"
      raise ValueError(error_msg)

    img_matrix = self.matrix("img")
    if img_matrix.ndim == 2:
      # single-channel image: add a channel axis for the plotting code below
      img_matrix = img_matrix[np.newaxis, :, :]
    merged_matrix = self._merged_rgb(img_matrix)

    #plot image (separate channels, merged)
    if mode in ("", "img"):
      n_channels = img_matrix.shape[0]
      n_subplots = n_channels + 1

      figsize = (subplot_size[0]*n_subplots, subplot_size[1])
      fig, axes = plt.subplots(1, n_subplots, figsize=figsize)

      #separate channels
      for i in range(n_channels):
        axes[i].imshow(img_matrix[i,:,:], cmap=plt.cm.gray)
        axes[i].axis("off")
        axes[i].set_title(f"channel {i+1}")

      #merged (last panel)
      axes[-1].imshow(merged_matrix)
      axes[-1].axis("off")
      axes[-1].set_title("Merged")

      fig.suptitle(f"{self.name} - channels")
      fig.tight_layout()
      plt.show()

    #plot mask (original image, outlines)
    if mode in ("", "mask"):
      figsize = (subplot_size[0]*3, subplot_size[1])
      fig, axes = plt.subplots(1, 3, figsize=figsize)
      seg_plane = img_matrix[seg_channel-1,:,:]  #seg_channel is 1-based

      #original image
      axes[0].imshow(merged_matrix)
      axes[0].set_title("Merged")
      axes[0].axis("off")

      #channel used to segment
      axes[1].imshow(seg_plane, cmap=plt.cm.gray)
      axes[1].set_title(f"Channel {seg_channel}")
      axes[1].axis("off")

      #outlines drawn on top of the segmentation channel
      outlines = utils.outlines_list(self.matrix("mask"))
      axes[2].imshow(seg_plane, cmap=plt.cm.gray)
      for o in outlines:
          axes[2].plot(o[:,0], o[:,1], color='r', linewidth=0.7)
      axes[2].axis("off")
      axes[2].set_title("Masks")

      fig.suptitle(f"{self.name} - annotation")
      fig.tight_layout()
      plt.show()

    return None

In [16]:
#@markdown [code: functions]
#_output
def write_json(parameters: dict, save_dir: str,
               output_name: str = ".model_params.JSON",
               overwrite=False, verbose=True) -> str:
    """Write settings to a JSON file

    Args:
        parameters (dict): Settings to be written to the JSON file
        save_dir (str): The directory path where the JSON file will be saved.
                        Created (including parents) if it does not exist; an
                        empty string means the current working directory
        output_name (str): name of output file. Default: ".model_params.JSON"
        overwrite (bool, optional): Overwrite if file exists. Default: False
        verbose (bool, optional): Print verbose. Default: True

    Returns:
        str: The path where the JSON file was saved

    Raises:
        FileExistsError: If a file with `output_name` in `save_dir` already
                         exists, and `overwrite` is set to False
        TypeError: If `parameters` contains a value that cannot be serialised
                   to JSON (nothing is written in that case)
    """
    # Serialise before touching the file system: a value that cannot be
    # serialised would otherwise leave a truncated file behind, which blocks
    # every later call that does not set `overwrite`.
    payload = json.dumps(parameters, indent=4)

    if save_dir:
        # exist_ok avoids the check-then-create race of exists() + makedirs()
        os.makedirs(save_dir, exist_ok=True)

    json_path = os.path.join(save_dir, output_name)

    if os.path.exists(json_path) and not overwrite:
        raise FileExistsError(f"File '{json_path}' exists and `overwrite` has" +
                              f" been set to {overwrite}")

    with open(json_path, 'w', encoding="utf-8") as outfile_obj:
        outfile_obj.write(payload)

    if verbose:
        print(f"All settings written to {json_path}")

    return json_path

#_utilities
class GPUConnectionError(Exception):
    """Raised when the GPU availability does not match the requested setting.

    Defined at module level so that callers of `check_gpu_connection` can
    catch it.
    """
    def __init__(self, message):
        super().__init__(message)
        self.message = message


def check_gpu_connection(use_gpu: bool) -> None:
    """Reports the details on GPU connection

    Args:
        use_gpu (bool): Whether to use the GPU for the script

    Returns:
        None

    Raises:
        GPUConnectionError: If the runtime type does not match the connection
        settings
    """
    # local import: the notebook only imports subprocess when `pip_install` is set
    import subprocess

    # probe once and reuse the answer for the check, the message and the report
    gpu_available = core.use_gpu()

    if gpu_available != use_gpu:
      raise GPUConnectionError(f"Connection type (`{gpu_available}`) does " +
                               f"not match connection settings (`{use_gpu}`)."+
                               "\nPlease check the hardware type in the Colab"+
                               " Notebook settings.")

    if gpu_available:
      # plain-Python replacement for the IPython shell escape `!nvidia-smi`;
      # the output is captured and printed so that it ends up in the cell output
      try:
        report = subprocess.run(["nvidia-smi"], capture_output=True,
                                text=True, check=False)
      except FileNotFoundError:
        print("nvidia-smi is not available; cannot report GPU details")
      else:
        print(report.stdout or report.stderr)

    return None

#_segmentation
def run_cellpose_model(slice_obj: Slice,
                       save_dir: str,
                       model: models.CellposeModel,
                       run_args: dict = None,
                       verbose: bool = True) -> str:
    """Wrapper function for CellposeModel.eval()

    Runs the model on the image of `slice_obj` and writes the results to
    `save_dir`: the segmentation file via `masks_flows_to_seg` and the
    predicted masks as a .tif file.

    Args:
      slice_obj (Slice): slice whose image is segmented. `slice_obj.name` is
                         used as the base name of the output files
      save_dir (str): directory the output files are written to. Created if it
                      does not exist
      model (cellpose.models.CellposeModel): the cellpose model
      run_args (dict): arguments to give to model.eval(). {param (str) : arg}
                       Recognised keys: 'diameter' (default 30.0),
                       'flow_threshold' (default 0.4), 'cellprob_threshold'
                       (default 0.0; the key 'mask_threshold' is accepted as
                       an alias) and 'channels' (default [0, 0]).
                       If None, the defaults are used
      verbose (bool): print the paths of the written files. Default: True

    Returns:
      str - path to masks .tif
    """
    #init
    # a None default avoids sharing one mutable dict between calls
    run_args = {} if run_args is None else run_args
    diameter = run_args.get('diameter', 30.0)
    channels = run_args.get('channels', [0, 0])
    # the notebook stores this threshold as 'mask_threshold'; without the alias
    # the user's setting would silently be replaced by the default
    cellprob_threshold = run_args.get('cellprob_threshold',
                                      run_args.get('mask_threshold', 0.0))

    img = slice_obj.matrix("img")
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    file_name = os.path.join(save_dir, slice_obj.name)

    #run model
    masks, flows, styles = model.eval(
        img,
        diameter=diameter,
        flow_threshold=run_args.get('flow_threshold', 0.4),
        cellprob_threshold=cellprob_threshold,
        channels=channels
        )

    #save segmentation (_seg.npy)
    masks_flows_to_seg(img, masks, flows, diameter, file_name, channels)
    # NOTE(review): the path printed below assumes a `_predicted_seg.npy` suffix;
    # confirm it against the file name masks_flows_to_seg actually writes.
    if verbose: print(f"_seg.npy file saved as {file_name}_predicted_seg.npy")

    #save masks
    # `file_name` already contains `save_dir`; joining it with `save_dir` again
    # duplicated the directory whenever `save_dir` was a relative path
    mask_path = f"{file_name}_predicted_masks.tif"
    imwrite(mask_path, masks)
    if verbose: print(f"Predicted masks have been saved in {mask_path}")

    return mask_path


In [None]:
#@markdown ###GPU status
#@markdown *Note: To change hardware type: `Runtime` >> `Change runtime type` >> `Hardware accelerator`*
# Report the requested GPU setting, then verify it against the runtime
gpu_status_line = f"GPU usage enabled: {use_gpu}"
print(gpu_status_line)
check_gpu_connection(use_gpu=use_gpu)

## Input/output path specification

In [19]:
#@markdown **Test data - settings**
# Image path, mask path and name of the test slice; they are wrapped in a
# Slice object (`test_obj`) in the inspection cell.
test_img_path = "" #@param {type:"string"}
test_mask_path = "" #@param {type:"string"}
test_img_name = "" #@param {type:"string"}

#@markdown **Path to the main output directory**
main_save_dir = "" #@param {type:"string"}
# joining with "" appends a trailing path separator to a non-empty path
main_save_dir = os.path.join(main_save_dir, "")

# Main script

## Inspection of test data

In [None]:
#@markdown [code block]
#@markdown 1. store test data in `test_obj` (Slice object)
#@markdown 2. display the test data (original image and masks)

# Wrap the test image and its reference mask in a Slice object
test_obj = Slice(name=test_img_name,
                 img_path=test_img_path,
                 mask_path=test_mask_path)

# Echo the stored metadata (adjacent literals are concatenated into one message)
print(
    f"Test data has been stored with metadata:\n"
    f"\tname: {test_obj.name}\n"
    f"\timg_path: {test_obj.img_path}\n"
    f"\tmask_path: {test_obj.mask_path}"
)

# Show the channels and the mask outlines of the test slice
test_obj.show(seg_channel=2, subplot_size=(5,5))

## Model finetuning

### 0. init

In [35]:
#@markdown
#@markdown - (if first time run) initialise `train_objs`, `save_dirs` and `model_names`
#@markdown - show pre-trained model names
# Bookkeeping lists must survive re-running this cell, so each one is only
# created when it does not exist yet.
try:
    train_objs
except NameError:
    train_objs, save_dirs, model_names = [], [], []
    print("Initialised train_objs, save_dirs, and model_names")

# `test_results` collects the predictions on the test slice (step 4)
try:
    test_results
except NameError:
    test_results = []
    print("Initialised test_results")

# Bound on every run: previously it was only assigned together with the lists,
# so the print below depended on the first-run branch having been executed
cellpose_models = models.MODEL_NAMES

print("Available pre-trained models:\n\t", "\n\t ".join(cellpose_models))

Available pre-trained models:
	 cyto
	 nuclei
	 tissuenet
	 livecell
	 cyto2
	 general
	 CP
	 CPx
	 TN1
	 TN2
	 TN3
	 LC1
	 LC2
	 LC3
	 LC4


In [33]:
#@markdown **I/O set-up**
# Form cell for one finetuning run: collects the I/O, model and run settings,
# derives the run directory and stores all settings in a JSON file inside it.
train_img_path = "" #@param {type:"string"}
train_img_name = "" #@param {type:"string"}
run_id = "" #@param {type:"string"}
overwrite = False #@param {type:"boolean"}

#@markdown **Model settings: load**
#@markdown - `model_type`: pre-trained model name (str). Set to `None` if a custom model is used.
#@markdown - `pretrained_model`: path to custom model (str). Set to `False` if a pre-trained model is used.
model_type = None #@param {type:"raw"}
pretrained_model = False #@param {type:"raw"}

gpu = False #@param {type:"boolean"}
net_avg = True #@param {type:"boolean"}
diam_mean = 30.0 #@param {type:"number"}
device = None #@param {type:"raw"}
residual_on = True #@param {type:"boolean"}
style_on = True #@param {type:"boolean"}
concatenation = False #@param {type:"boolean"}
nchan = 2 #@param {type:"integer"}

#@markdown **Model settings: run**
channels = [2,0] #@param {type:"raw"}
flow_threshold = 0.4 #@param  {type:"number"}
mask_threshold = 0.0 #@param {type:"number"}

#initialising run directory name: run<run_id>_<YYYYmmdd-HHMMSS>
curr_datetime = datetime.now()
run_name = f"run{run_id}_{curr_datetime.strftime('%Y%m%d-%H%M%S')}"
save_dir = os.path.join(main_save_dir, run_name)
print(f"Output will be saved in {save_dir}")

#writing the settings to a JSON file
# model_args: keyword arguments read back when models.CellposeModel is created
model_args = {
        "model_type": model_type,
        "pretrained_model": pretrained_model,
        "gpu": gpu,
        "net_avg": net_avg,
        "diam_mean": diam_mean,
        "device": device,
        "residual_on": residual_on,
        "style_on": style_on,
        "concatenation": concatenation,
        "nchan": nchan  # forwarded as `nchan` to models.CellposeModel
    }

# run_args: settings handed to run_cellpose_model
# NOTE(review): the threshold is stored under the key "mask_threshold"; verify
# that the consumer of run_args maps it onto the matching eval() argument.
run_args = {"channels":channels,
            "flow_threshold":flow_threshold,
            "mask_threshold":mask_threshold}
# mask_path is left empty: no mask exists for the training image at this point
parameters = {
    "Slice": {
        "name": train_img_name,
        "img_path": train_img_path,
        "mask_path": ""
    },
    "Model parameters": model_args,
    "Run parameters": run_args
}

# creates save_dir if needed; raises FileExistsError if the file exists and
# `overwrite` is False
write_json(parameters, save_dir, overwrite=overwrite, verbose=True)

Output will be saved in test_dir/run_20240226-125259
{'Slice': {'name': '', 'img_path': '', 'mask_path': ''}, 'Model parameters': {'model_type': None, 'pretrained_model': False, 'gpu': False, 'net_avg': True, 'diam_mean': 30.0, 'device': None, 'residual_on': True, 'style_on': True, 'concatenation': False, 'nchan': 2}, 'Run parameters': {'channels': [2, 0], 'flow_threshold': 0.4, 'mask_threshold': 0.0}}


### 0. Load the model

In [None]:
#@markdown [code: models.CellposeModel]
# Fallback used for every constructor argument that is missing in `model_args`
_cellpose_model_defaults = {
    'gpu': False,
    'pretrained_model': False,
    'model_type': None,
    'net_avg': True,
    'diam_mean': 30.0,
    'device': None,
    'residual_on': True,
    'style_on': True,
    'concatenation': False,
    'nchan': 2,
}

# Only the keys listed above are forwarded to the constructor
_cellpose_model_kwargs = {
    key: model_args.get(key, fallback)
    for key, fallback in _cellpose_model_defaults.items()
}

model = models.CellposeModel(**_cellpose_model_kwargs)

### 1. Run the model on an input slice

In [None]:
# The training image has no mask yet, hence the empty mask path
next_train_obj = Slice(train_img_name, train_img_path, "")

# Segment the image with the loaded model; returns the path of the masks .tif
mask_path = run_cellpose_model(next_train_obj, save_dir, model,
                               run_args=run_args)

#@markdown **5. Store results (SliceObject, model)**
# Same image, paired with the predicted masks
slice_predicted = Slice(
    f"{next_train_obj.name}_predicted",
    next_train_obj.img_path,
    mask_path,
)

### 2. Manually refining the masks
1. load _seg.npy in the GUI
2. manually refine
3. save masks
4. convert to tif

### 3. Compare predictions with manually refined masks.

In [None]:
#@markdown **Input**
img_mask_refined = "" #@param {type:"string"}

#@markdown <hr>

#@markdown 1. Add manually refined masks to the Slice and `train_objs`
# Validate before mutating any state: an empty path would otherwise be stored
# in `train_objs` and only fail later when the mask is read
if not img_mask_refined:
    raise ValueError("`img_mask_refined` is empty: provide the path to the "
                     "manually refined masks")

next_train_obj.mask_path = img_mask_refined
# Re-running this cell must not register the same slice a second time
if not any(obj is next_train_obj for obj in train_objs):
    train_objs.append(next_train_obj)

#@markdown 2. Compare predicted with manual refined masks
slice_predicted.show('mask')
next_train_obj.show('mask')

### 4. Run the model on the golden standard
- compare with golden standard segmentation

In [None]:
#@markdown run on golden standard
#run models on the golden standard
mask_path = run_cellpose_model(slice_obj = test_obj,
                  save_dir = save_dir,
                  model = model,
                  run_args = run_args)

#store as a Slice object
test_slice_predicted = Slice(name=f"{test_obj.name}_predicted",
                        img_path=test_obj.img_path,
                        mask_path=mask_path)

#create the results list on first use so the append cannot raise a NameError
try:
    test_results
except NameError:
    test_results = []
test_results.append(test_slice_predicted)

#prediction first, reference segmentation second
test_slice_predicted.show('mask')
test_obj.show('mask')

### Train the model
- https://cellpose.readthedocs.io/en/latest/api.html#cellpose.models.CellposeModel.train

- retrain the previously loaded model

In [None]:
#@markdown Available training images:
# Tab-separated overview of the collected training slices:
# list index, slice name and file name of the mask
print("id\tslice_name\tmask_path")
for idx, train_obj in enumerate(train_objs):
  mask_file = train_obj.mask_path.rsplit('/', 1)[-1]
  print(f"{idx}\t{train_obj.name}\t{mask_file}")


In [None]:
#current_train_obj = train_objs[0]
#current_test_obj = test_objs[0]
#@markdown <br>**Data settings**<hr>
# Form cell for a training run: selects the training slices, collects the
# training settings, loads the training data and stores the settings as JSON.
# NOTE(review): no call that trains the model is visible in this cell - confirm
# where the actual training is started.
selected_train_obj_ids = [0] #@param {type:"raw"}

#@markdown <br>**Training settings**<hr>
#pretrained_model_path = "/content/gdrive/MyDrive/msc-internship_HI_2024_vmp/02_results/03.1_finetuning/iteration_01/models/cyto2_retrained" #@param {type:"raw"}
n_epochs = 100  #@param {type:"integer"}
learning_rate = 0.1  #@param {type:"number"}
momentum = 0.9  #@param {type:"number"}
sgd = True  #@param {type:"boolean"}
weight_decay = 0.0001  #@param {type:"number"}
batch_size = 8  #@param {type:"integer"}
nimg_per_epoch = None  #@param {type:"raw"}
rescale = True  #@param {type:"boolean"}
min_train_masks = 1  #@param {type:"integer"}
normalize = True  #@param {type:"boolean"}

#@markdown <br>**Output settings**<hr>
save_every = 100  #@param {type:"integer"}
save_each = False #@param {type:"boolean"}
model_name = "retrained" #@param {type:"string"}

#Data loading
# ids index into `train_objs`; images and masks are read from disk here
selected_train_objs = [train_objs[idx] for idx in selected_train_obj_ids]
train_data = [sliceobj.matrix("img") for sliceobj in selected_train_objs]
train_labels = [sliceobj.matrix("mask") for sliceobj in selected_train_objs]
train_files = [sliceobj.name for sliceobj in selected_train_objs]
# no test set is used in this cell
test_data = None
test_labels = None
test_files = None

#Write parameters to a file
# `channels` and `save_dir` are taken from the run set-up cell
parameters = {
    "n_epochs": n_epochs, "train_files": train_files,
    "test_files": test_files, "channels": channels, "normalize": normalize,
    "save_path": save_dir, "save_every": save_every, "save_each": save_each,
    "learning_rate": learning_rate, "momentum": momentum,
    "sgd": sgd, "weight_decay": weight_decay, "batch_size": batch_size,
    "nimg_per_epoch": nimg_per_epoch, "rescale": rescale,
    "min_train_masks": min_train_masks, "model_name": model_name
    }

# called without `overwrite`, so write_json's default (False) applies: running
# this cell twice for the same save_dir raises FileExistsError
write_json(parameters, save_dir, output_name=".retrained_model_params.JSON")

# record the run directory and model name
save_dirs.append(save_dir)
model_names.append(model_name)


# Visualise all segmentations on golden standard (compilation)

In [None]:
#loop through runs
# The test image is the same for every run, so it is read once
# NOTE(review): imshow expects (X, Y) or (X, Y, 3/4) data - confirm the layout
# of the raw test image.
test_img = imread(test_obj.img_path)
mask_file_name = f"{test_obj.name}_predicted_masks.tif"

# Keep only the run directories that contain a prediction for the test slice;
# other entries of main_save_dir are skipped instead of crashing the loop
runs_with_masks = []
for run_dir in sorted(os.listdir(main_save_dir)):
  test_pred_mask_path = os.path.join(main_save_dir, run_dir, mask_file_name)
  if os.path.isfile(test_pred_mask_path):
    runs_with_masks.append((run_dir, test_pred_mask_path))

if not runs_with_masks:
  print(f"No '{mask_file_name}' found in the run directories of '{main_save_dir}'")
else:
  # one panel per run, so the outlines of different runs do not overlap
  n_runs = len(runs_with_masks)
  fig, axes = plt.subplots(1, n_runs, figsize=(5*n_runs, 5), squeeze=False)

  for ax, (run_dir, test_pred_mask_path) in zip(axes[0], runs_with_masks):
    ax.imshow(test_img)

    # outlines_list comes from cellpose.utils, as used in Slice.show
    outlines_pred = utils.outlines_list(imread(test_pred_mask_path))
    for o in outlines_pred:
      ax.plot(o[:,0], o[:,1], color='r')

    ax.set_title(run_dir)
    ax.axis("off")

  fig.tight_layout()
  plt.show()