# Description
*author:* Vina My Pham<br>
*supervisor:* Robin van der Weide<br>
*project:* MSc internship project<br>
<br>
*date:* January 15 - July 19, 2024<br>
*host:* Kind group, Hubrecht Institute<br>
*university:* Bioinformatics, Wageningen University & Research<br>

---
Notebook to run Cellpose3D on a stack.

Input:
- (requirements.txt for installing Cellpose and dependencies using `pip`.)
- A TIFF file of the stack.
- Path to the output directory

Output:
- A TIFF file of the Cellpose predictions



# Notebook initialisation
**Description:** This block contains the code for the set-up of the notebook.
1. mounting the notebook to the Drive
2. installing Cellpose and its dependencies
3. imports and loading custom functions
4. checking GPU status




In [None]:
#@markdown [mounting the notebook to the drive]
# Mount Google Drive at /content/gdrive so the paths used in the cells below resolve.
# The commented-out `force_remount=True` argument can be re-enabled to force a remount.
from google.colab import drive
drive.mount('/content/gdrive')#, force_remount=True)


In [None]:
#@markdown [Installing Cellpose and dependencies using pip and a `requirements.txt` file]<br>
#@markdown [runtime: ~3min]

# Path (on the mounted Drive) of the requirements file that is handed to pip.
pip_requirements_path = "/content/gdrive/MyDrive/msc-internship_HI_2024_vmp/01_notebooks/colab_requirements.txt" #@param {type:"string"}

# IPython shell escape; `$pip_requirements_path` is expanded from the Python variable above.
!pip install -r "$pip_requirements_path"
# NOTE(review): this message is printed unconditionally; as far as the `!` escape goes, a
# failed installation does not raise, so check the pip output above before continuing.
print("Succesfully installed requirements with pip")




In [None]:
# #@markdown [imports]
from datetime import datetime
print(f"{datetime.now()}\tImporting packages")

import os, sys, json, copy
from pprint import pprint
import skimage.io
import tifffile as tif

import numpy as np
import matplotlib.pyplot as plt

from cellpose import models, plot, utils
from cellpose.io import masks_flows_to_seg

print(f"{datetime.now()}\tFinished importing packages")

# #@markdown [custom functions]
print(f"{datetime.now()}\tLoading functions")
#_output
def write_json(parameters: dict, save_dir: str,
               output_name: str = ".model_params.JSON",
               overwrite=False, verbose=True) -> str:
    """Store a settings dictionary as an indented JSON file in `save_dir`

    The directory is created (including parents) when it does not exist yet.

    Args:
        parameters (dict): Settings to serialise; must be JSON-compatible
        save_dir (str): Directory in which the JSON file is placed
        output_name (str): File name of the JSON file.
                           Default: ".model_params.JSON"
        overwrite (bool, optional): Replace an existing file. Default: False
        verbose (bool, optional): Report the written path. Default: True

    Returns:
        str: Path of the JSON file that was written

    Raises:
        FileExistsError: If `output_name` is already present in `save_dir`
                         while `overwrite` is False
    """
    # Make sure the target directory is available before touching the file
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    target_path = os.path.join(save_dir, output_name)

    # Never clobber an earlier settings file unless explicitly requested
    already_present = os.path.exists(target_path)
    if already_present and not overwrite:
        message = (f"File '{target_path}' exists and `overwrite` has"
                   f" been set to {overwrite}")
        raise FileExistsError(message)

    with open(target_path, 'w') as json_handle:
        json.dump(parameters, json_handle, indent=4)

    if verbose:
        print(f"All settings written to {target_path}")

    return target_path

#_ visualisation
def show_all_planes(stack_array: np.array, plot_channel: int, zstep: int = 1,
                    ncols: int = 7, masks: np.array = None,
                    outline_color: str = 'r', outline_width: float = 1.2,
                    figsize: tuple[int, int] = (15, 30)) -> None:
    """Plot all planes of a stack, one row of `ncols` planes per figure

    Wrapper for show_planes(): the stack is walked in rows of `ncols` planes
    (every `zstep`-th plane) and each row is handed to show_planes(). A row
    is only started while its first plane lies inside the stack, so the last
    row may contain trailing plane numbers beyond the stack (left to
    show_planes() to render) but never consists of such planes only.

    Args:
        stack_array (np.array): Stack array; the first axis holds the planes
        plot_channel (int): Channel to plot
        zstep (int): Step size along the z-axis (default: 1)
        ncols (int): Number of planes per row (default: 7)
        masks (np.array, optional): Mask array passed on to show_planes()
        outline_color (str): Color of the outlines (default: 'r'; red)
        outline_width (float): Width of the outline lines
        figsize (tuple of 2 ints): Figure size (width, height) of each row

    Returns:
        None

    Raises:
        ValueError: If `zstep` or `ncols` is smaller than 1
    """
    if zstep < 1 or ncols < 1:
        raise ValueError("`zstep` and `ncols` must both be at least 1 "
                         f"(got zstep={zstep}, ncols={ncols}).")

    nplanes = stack_array.shape[0]
    row_span = ncols * zstep  # number of z-positions covered by one row

    # range() stops before `nplanes`, so no row starts outside the stack
    for first_plane in range(0, nplanes, row_span):
        plotrange = np.arange(first_plane, first_plane + row_span, zstep, int)
        show_planes(stack_array, plot_channel, plotrange, zstep, ncols, masks,
                    outline_color, outline_width, figsize)

    return None

def show_planes(stack_array: np.array, plot_channel: int, planerange: list[int],
                zstep: int = 1, ncols: int = 7, masks: np.array = None,
                outline_color: str = 'r', outline_width: float = 1.2,
                figsize: tuple[int, int] = (15, 30)) -> plt.Figure:
    """Visualise the planes of a stack and their masks (optional)

    The planes in `planerange` are drawn next to each other in one row.
    Plane numbers that fall outside the stack are drawn as black placeholder
    panels so that a partially filled last row keeps its layout.

    Args:
        stack_array (np.array): Stack array, indexed as [plane, y, x, channel]
        plot_channel (int): Channel to plot
        planerange (list of ints): Planes to plot (at most `ncols` entries)
        zstep (int): Step size along the z-axis (default: 1). Not used by
                     this function; kept in the signature for its callers.
        ncols (int): Number of columns for subplots (default: 7)
        masks (np.array, optional): Mask array indexed as [plane, y, x]. If
                                    provided, outlines of the masks will be
                                    plotted.
        outline_color (str): Color of the outlines (default: 'r'; red)
        outline_width (float): Width of the outline lines
        figsize (tuple of 2 ints): Figure size (width, height)

    Returns:
        fig (matplotlib.figure.Figure): Matplotlib figure object

    Raises:
        ValueError: If `planerange` holds exactly one plane (use
                    show_single() instead) or more planes than `ncols`

    Notes:
        uses modules: numpy, matplotlib, cellpose.utils
    """
    if len(planerange) == 1:
        raise ValueError("use `show_single()` for one plane.")
    if len(planerange) > ncols:
        raise ValueError(f"`planerange` holds {len(planerange)} planes but "
                         f"only {ncols} columns are available.")

    img = stack_array  # assuming its in order plane-y-x-channel
    nplanes = img.shape[0]

    # squeeze=False keeps `axes` indexable when only one column is requested
    fig, axes = plt.subplots(1, ncols, figsize=figsize, squeeze=False)
    axes = axes[0]

    for ax, iplane in zip(axes, planerange):
        # Explicit bounds check instead of catching IndexError: a wrong
        # channel or a too-short mask array now surfaces as an error rather
        # than being silently drawn as an empty panel.
        if -nplanes <= iplane < nplanes:
            ax.imshow(img[iplane, :, :, plot_channel], cmap=plt.cm.gray)

            if isinstance(masks, np.ndarray):
                for outline in utils.outlines_list(masks[iplane, :, :]):
                    ax.plot(outline[:, 0], outline[:, 1], color=outline_color,
                            linewidth=outline_width)
        else:
            # plot black when all planes have been plotted
            ax.imshow(np.zeros(img.shape[1:3]), cmap=plt.cm.gray,
                      vmin=0, vmax=1)

        ax.set_title(f"plane {iplane}", size=10)
        ax.axis("off")

    # Hide the frames of columns that received no plane at all
    for ax in axes[len(planerange):]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()

    return fig

def show_single(stack_array: np.array, planerange: list[int],
                plot_channel: int,  figsize: tuple[int,int] = (5,5),
                masks: np.array = None, outline_color: str = 'r',
                outline_width: float = 1.2, save: "bool | str" = False) -> None:
    """Plot individual plane(s) with mask outlines

    Every plane in `planerange` gets its own figure.

    Args:
        stack_array (np.array): Stack array, indexed as [plane, y, x, channel]
        planerange (list of ints): Planes to plot
        plot_channel (int): Channel to plot (0-based)
        figsize (tuple of 2 ints): Figure size (width, height) per plane
        masks (np.array, optional): Mask array indexed as [plane, y, x]. If
                                    provided, mask outlines will be plotted.
        outline_color (str): Color of the outlines
        outline_width (float): Width of the outline lines
        save (bool or str): If True, save the image(s) in working directory.
                            If string, directory path to save images to; the
                            directory is created when it does not exist.
                            (default: False)

    Returns:
        None
    """
    img = stack_array  # assuming its in order plane-y-x-channel

    # An empty `save_dir` makes os.path.join() fall back to the working
    # directory, which covers the `save=True` case.
    save_dir = save if isinstance(save, str) else ""
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for iplane in planerange:
        plt.figure(figsize=figsize)
        plt.imshow(img[iplane, :, :, plot_channel], cmap=plt.cm.gray)

        if isinstance(masks, np.ndarray):
            for outline in utils.outlines_list(masks[iplane, :, :]):
                plt.plot(outline[:, 0], outline[:, 1], color=outline_color,
                         linewidth=outline_width)

        plt.title(f"plane {iplane}", size=10)
        plt.axis("off")

        # Save before plt.show(): the figure is written while it still
        # holds the plotted content.
        if save:
            plt.savefig(os.path.join(save_dir, f"plane{iplane}.png"))

        plt.show()

print(f"{datetime.now()}\tFinished loading functions")

In [None]:
#@markdown ###GPU status
#@markdown *Note: To change hardware type: `Runtime` >> `Change runtime type` >> `Hardware accelerator`*
def check_gpu_connection(use_gpu: bool) -> None:
    """Reports the details on GPU connection

    Args:
        use_gpu (bool): Whether to use the GPU for the script

    Returns:
        None

    Raises:
        GPUConnectionError: If the runtime type does not match the connection
        settings
    """
    from cellpose import core
    class GPUConnectionError(Exception):
        def __init__(self, message):
            self.message = message

    if core.use_gpu() != use_gpu:
      raise GPUConnectionError(f"Connection type (`{core.use_gpu()}`) does " +
                               f"not match connection settings (`{use_gpu}`)."+
                               "\nPlease check the hardware type in the Colab"+
                               " Notebook settings.")

    if core.use_gpu():
      !nvidia-smi

    return None

# Expected hardware setting of this runtime. check_gpu_connection() raises when the
# value differs from what cellpose reports for the current Colab runtime.
use_gpu = False #@param {type:"boolean"}
print(f"GPU usage enabled? {use_gpu}")
check_gpu_connection(use_gpu)

## Custom classes/functions

## (Re)start session

**Provide output directory**.

If the directory exists, the stored settings from the `.model_params.JSON` file will be loaded in.<br>
(If the JSON file does not exist, a new session will start.)

If the directory does not exist, it will be made and a new session will start.


In [None]:
session_dir = "/content/gdrive/MyDrive/msc-internship_HI_2024_vmp/02_results/04_cellpose3D/test" #@param {type:"string"}
# Joining with an empty component guarantees a trailing path separator
session_dir = os.path.join(session_dir, "")

# A previous session is resumed only when the directory already contains
# a stored settings file; in every other case a new session starts.
settings_path = os.path.join(session_dir, ".model_params.JSON")
dir_found = os.path.exists(session_dir)
new_session = not (dir_found and os.path.exists(settings_path))

if not dir_found:
    os.makedirs(session_dir)
    print(f"{datetime.now()}\tDirectory does not exist. New session will start.")
else:
    print(f"{datetime.now()}\tDirectory exists.")
    if new_session:
        print(f"{datetime.now()}\tNo .model_params.JSON found in the directory.\nA new session will start.")
    else:
        print(f"{datetime.now()}\t.model_params.JSON file found. Previous settings will be loaded in.")

# Image loading


In [None]:
#@markdown [code: loading in stack file]
stack_img_path = "/content/gdrive/MyDrive/msc-internship_HI_2024_vmp/00_raw_data/full_zstacks/sub/SPE_20230327_D25_x600-850_y550-800.tif" #@param {type:"string"}
# Read the TIFF stack into memory
stack_array = skimage.io.imread(stack_img_path)
# Last path component, i.e. the bare file name
file_name = stack_img_path.rsplit("/", 1)[-1]
print("file name: ", file_name)
print("Stack has shape: ", stack_array.shape)
# Tabulate axis index against axis length to help fill in the axis-id form below
print("\tid\tvalue")
for axis_id, axis_len in enumerate(stack_array.shape):
    print(f"\t{axis_id}\t{axis_len}")

**Indicate the planes (Z), channels, and X- and Y- data**

In [None]:
##@markdown Reshape the array if shape does not correspond to `nplanes x channels x nY x nX`<br>
##@markdown *note: if stack created using singletiffs2stack():*<br>
##@markdown &emsp;&emsp; `plane_id` = 0<br>
##@markdown &emsp;&emsp; `channel_id` = 3<br>
##@markdown &emsp;&emsp; `y_id` = 1<br>
##@markdown &emsp;&emsp; `x_id` = 2<br>

# Axis indices of the loaded stack (see the id/value table printed by the loading cell).
# `plane_id` and `channel_id` are stored in `run_args` as `z_axis` and `channel_axis`.
# NOTE(review): show_planes() and show_single() index the stack as [plane, y, x, channel],
# i.e. the singletiffs2stack() layout listed above — confirm before using another layout.
plane_id = 0 #@param {type:"integer"}
channel_id = 3 #@param {type:"integer"}
y_id = 1 #@param {type:"integer"}
x_id = 2 #@param {type:"integer"}

## Showing the middle plane

In [None]:
print("Showing middle plane...")
# Index of the plane halfway along the plane axis
iplane_middle = round(stack_array.shape[plane_id] / 2)
# Arguments in signature order: stack, planes to plot, channel to plot
show_single(stack_array, [iplane_middle], 1)

### Initialise a session with a model
This block can be skipped if a previous session is resumed.<br>
An error will be raised if a `.model_params.JSON` file from a previous session has been detected.

In [None]:
# A resumed session takes its settings from the stored JSON file, so this
# cell must not be run in that case.
if not new_session:
    raise IOError("Settings from a previous session are loaded in. "
                  "Any input provided here will be ignored.")

#@markdown **Model path input**
model_type = "" #@param {type:"string"}
pretrained_model = "/content/gdrive/MyDrive/msc-internship_HI_2024_vmp/02_results/03_finetuning/run_week07/run03_20240306-094709/models/cyto2_trained-on-d25z30-d30z43-d40z58-d35z15" #@param {type:"string"}

# Exactly one of the two model inputs has to be filled in; the one that is
# left empty is replaced by None.
builtin_given = len(model_type) > 0
custom_given = len(pretrained_model) > 0

if builtin_given and custom_given:
    raise IOError("Input for both `model_type` and `pretrained_model` are provided. Please choose one.")
if not (builtin_given or custom_given):
    raise IOError("No model input found. Please provide a name or path in `model_type` or `pretrained_model`.")

if builtin_given:
    pretrained_model = None
    print(f"Loading in the pre-trained model `{model_type}`")
else:
    model_type = None
    print(f"Loading in the custom model `{pretrained_model.split('/')[-1]}`")

#@markdown **Model settings:**
gpu = True #@param {type:"boolean"}
net_avg = True #@param {type:"boolean"}
diam_mean = 30.0 #@param {type:"number"}
device = None #@param {type:"raw"}
residual_on = True #@param {type:"boolean"}
style_on = True #@param {type:"boolean"}
concatenation = False #@param {type:"boolean"}
nchan = 2 #@param {type:"integer"}

# Bundle all settings and store them in the session directory, so that a
# later session can restore them.
model_args = dict(
    model_type=model_type,
    pretrained_model=pretrained_model,
    gpu=gpu,
    net_avg=net_avg,
    diam_mean=diam_mean,
    device=device,
    residual_on=residual_on,
    style_on=style_on,
    concatenation=concatenation,
    nchan=nchan,
)

_ = write_json(model_args, session_dir, verbose=True, overwrite=True)


### Loading model

In [None]:
#@markdown [code: displaying model settings]
print("The following model will loaded in.")

# When resuming, the settings come from the JSON file in the session directory
# instead of from the form cell above.
if not new_session:
    print("[from previous session]")
    stored_settings_path = os.path.join(session_dir, ".model_params.JSON")
    with open(stored_settings_path) as json_file:
        model_args = json.load(json_file)

pprint(model_args)

In [None]:
#@markdown [code: models.CellposeModel stored as `model`]
# Fallback per constructor argument, used when a key is missing from `model_args`
model_defaults = {
    'gpu': False,
    'pretrained_model': False,
    'model_type': None,
    'net_avg': True,
    'diam_mean': 30.0,
    'device': None,
    'residual_on': True,
    'style_on': True,
    'concatenation': False,
    'nchan': 2,
}
# Only the known keys are forwarded; extra entries in `model_args` are ignored
model_kwargs = {key: model_args.get(key, fallback)
                for key, fallback in model_defaults.items()}
model = models.CellposeModel(**model_kwargs)

# NOTE(review): the requested `gpu` and `diam_mean` are written onto the model again
# after construction, overriding whatever the constructor settled on — presumably
# intentional; confirm against the cellpose version in use.
model.gpu = model_kwargs['gpu']
model.diam_mean = model_kwargs['diam_mean']
print("loaded model", model.pretrained_model)

# Running Cellpose

In [None]:
#@markdown [**Showing the channels of the centre plane**]
# Pick the centre plane along the configured plane axis (not a fixed axis 0)
single_plane_id = round(stack_array.shape[plane_id] / 2)
single_plane = np.take(stack_array, single_plane_id, axis=plane_id)

# Taking one plane removes the plane axis, so every axis behind it moves
# one position forward.
plane_channel_axis = channel_id - 1 if channel_id > plane_id else channel_id
n_channels = single_plane.shape[plane_channel_axis]

# squeeze=False keeps `axes` an array, also for a single-channel stack
fig, axes = plt.subplots(1, n_channels, squeeze=False)
axes = axes.ravel()
for idx, ax in enumerate(axes):
    channel_img = np.take(single_plane, idx, axis=plane_channel_axis)
    ax.imshow(channel_img, cmap = plt.cm.gray)
    ax.set_title(f"Channel {idx+1}")
    ax.axis("off")
plt.suptitle(f"Plane {single_plane_id}", y=0.79, weight='bold')
plt.tight_layout()
plt.show();

# Running `model.eval` - 3D flows from 3D masks
- stack stored in `stack_array`

If `do_3D == True`, the following settings are not used:
- interp
- flow_threshold
- stitch_threshold

In [None]:
#@markdown **model settings**
# Form values for the segmentation run; see the `run_args` cell below for how
# the settings are handed to `model.eval`.
channels = [2,0] #@param {type:"raw"}
diameter = 30.0 #@param {type:"number"}
batch_size = 16 #@param {type:"number"}

#@markdown **3D segmentation settings**
do_3D = False #@param {type:"boolean"}
# NOTE(review): 0 looks like a placeholder for "not set" — confirm the intended value.
anisotropy = 0 #@param {type:"number"}
stitch_threshold = 0.5 #@param {type:"number"}

In [None]:
#@markdown [code: custom run_args]
save_as_segnpy=False #@param {type:"boolean"}

print("stack shape: ", stack_array.shape)
print()
# All run settings are taken from the form values of the settings cell, so
# that the run and the stored JSON reflect what was selected there
# (`do_3D`, `anisotropy` and `stitch_threshold` used to be hard-coded here).
run_args = {
    "batch_size": batch_size,
    "channels": channels,
    "channel_axis": channel_id,
    "z_axis": plane_id,
    "diameter" : diameter,
    "do_3D" : do_3D,
    # A value of 0 in the form is treated as "not set" and passed as None;
    # enter the z/xy sampling ratio of the data set (previously fixed at 0.59)
    # in the form when it is needed.
    "anisotropy": anisotropy if anisotropy else None,
    "stitch_threshold" : stitch_threshold
}
# Label used in the names of the output files
mode = "3dflows" if run_args["do_3D"] else "stitched2dmasks"
print("mode: ", mode)

pprint(run_args)
print()

_ = write_json(run_args, session_dir, output_name=f".model.eval_{mode}.json", verbose=True, overwrite=True)

In [None]:
#@markdown [code: model.eval]
print(datetime.now(),"\trunning model.eval")
# Every argument falls back to an explicit default when it is not present in
# `run_args`; the third return value of model.eval is not used.
masks, flows, _ = model.eval(
    stack_array,
    batch_size=run_args.get("batch_size", 8),
    channels=run_args.get("channels", None),
    channel_axis=run_args.get("channel_axis", None),
    z_axis=run_args.get("z_axis", None),
    normalize=run_args.get("normalize", True),
    invert=run_args.get("invert", False),
    rescale=run_args.get("rescale", None),
    diameter=run_args.get("diameter", None),
    do_3D=run_args.get("do_3D", False),
    anisotropy=run_args.get("anisotropy", None),
    stitch_threshold=run_args.get("stitch_threshold", None),
    net_avg=run_args.get("net_avg", True),
    augment=run_args.get("augment", False),
    tile=run_args.get("tile", True),
    tile_overlap=run_args.get("tile_overlap", 0.1),
    resample=run_args.get("resample", True),
    compute_masks=run_args.get("compute_masks", True),
    min_size=run_args.get("min_size", 15),
    progress=run_args.get("progress", True),
    loop_run=run_args.get("loop_run", False),
    model_loaded=run_args.get("model_loaded", True)
)

print(datetime.now(),"\tsaving segmentations...")
# Output path without extension; os.path.join avoids a doubled separator
file_name = os.path.join(session_dir, f"3D_masks_mode-{mode}")
if save_as_segnpy:
    masks_flows_to_seg(
        stack_array, masks, flows, run_args.get('diameter', None),
        file_name, run_args.get('channels', [0, 0])
        )
    # NOTE(review): cellpose is expected to append `_seg.npy` to the given
    # base name — confirm against the installed cellpose version.
    outfile_name = file_name+"_seg.npy"
else:
    outfile_name = file_name+".tiff"
    tif.imwrite(outfile_name, masks, bigtiff=True)

# Report the file that was actually written (including its extension)
print(datetime.now(),f"\t3D masks saved as {outfile_name}.")
print()

# Visualisation of segmentations

In [None]:
#@markdown **plot parameters**
plot_channel = 1 #@param {type:"number"}
figsize = (15, 30) #@param {type:"raw"}

# Plot every plane of the stack with the outlines of the predicted masks
show_all_planes(
    stack_array=stack_array,
    plot_channel=plot_channel,
    masks=masks,
    figsize=figsize,
)