# Segment A549 cells with a re-fitted (fine-tuned) Cellpose model
Author: Ke  
Data source: Dr. Weikang Wang

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from cellpose import models
from cellpose.io import imread
from pathlib import Path
from livecellx.preprocess.utils import normalize_img_to_uint8
from livecellx.segment.cellpose_utils import segment_single_images_by_cellpose, segment_single_image_by_cellpose

In [None]:
# Load a previously re-fitted Cellpose checkpoint (this cell only reads the
# weights; nothing is saved here). The directory name suggests a cyto2-based
# model fine-tuned on A549 cell-body masks — confirm against the training run.
#
# Earlier checkpoints that were tried (swap one in to compare):
#   ./notebook_results/cellpose/cellpose_A549_cyto2/models/cellpose_residual_on_style_on_concatenation_off_cellpose_A549_cyto2_2023_03_07_01_07_22.191293
#   cellpose_residual_on_style_on_concatenation_off_cellpose_A549_cyto2_cellbody_2023_04_17_21_49_50.313712
#   /home/ken67/LiveCellTracker-dev/notebooks/notebook_results/cellpose/cellpose_A549_cyto2_cellbody_bg_corrected/models/cellpose_residual_on_style_on_concatenation_off_cellpose_A549_cyto2_cellbody_bg_corrected_2023_04_19_12_03_03.872596
# NOTE(review): absolute, user-specific path — not portable to other machines.
pretrained_model_path = (
    "/home/ken67/LiveCellTracker-dev/notebooks/notebook_results/cellpose/"
    "cellpose_A549_cyto2_cellbody/models/"
    "cellpose_residual_on_style_on_concatenation_off_cellpose_A549_cyto2_cellbody_2023_04_17_21_49_50.313712"
)

# Alternative: a built-in model via
#   models.Cellpose(gpu=True, model_type="cyto2", pretrained_model=pretrained_model_path)
# where model_type may be "cyto", "nuclei" or "cyto2".
model = models.CellposeModel(gpu=True, pretrained_model=pretrained_model_path)

In [None]:
from livecellx.core.datasets import LiveCellImageDataset, SingleImageDataset

# STAV-A549 test set: DIC frames are read from DIC_data (".tif" files).
# mask_data presumably holds reference masks; it is not used by the cells
# visible here — verify before relying on it.
dataset_dir_path = Path("../datasets/test_data_STAV-A549") / "DIC_data"
mask_dataset_path = Path("../datasets/test_data_STAV-A549") / "mask_data"
dic_dataset = LiveCellImageDataset(dataset_dir_path, ext="tif")

# Default cell diameter handed to Cellpose; later cells reassign it.
diameter = 80

In [None]:
from livecellx.preprocess.utils import enhance_contrast, standard_preprocess
import random

# Visual sanity check: segment a few randomly chosen frames with the loaded
# model and show each preprocessed frame next to its Cellpose label mask.
num_img_to_viz = 2
times = dic_dataset.times
diameter = 80  # expected cell diameter forwarded to Cellpose
# Draw *distinct* frames so the same time point is never shown twice, and cap
# the sample size so random.sample cannot raise on a short dataset.
# list() is used in case `times` is not a plain Sequence (e.g. a numpy array),
# which random.sample rejects.
viz_times = random.sample(list(times), k=min(num_img_to_viz, len(times)))
for viz_time in viz_times:
    img = dic_dataset[viz_time]
    # Alternative preprocessing that was tried: img = normalize_img_to_uint8(img)
    img = standard_preprocess(img)
    mask = segment_single_image_by_cellpose(img, model, channels=[[0, 0]], diameter=diameter)

    # Left panel: contrast-enhanced input (already passed through
    # standard_preprocess, so not strictly raw); right panel: label mask.
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(enhance_contrast(img))
    axes[0].set_title("raw image")
    axes[1].imshow(mask)
    axes[1].set_title("cellpose mask")
    plt.show()

In [None]:
# Same visual check, with a smaller diameter for comparison against the
# diameter=80 run.
num_img_to_viz = 2
times = dic_dataset.times
diameter = 50  # expected cell diameter forwarded to Cellpose
# Draw *distinct* frames so the same time point is never shown twice, and cap
# the sample size so random.sample cannot raise on a short dataset.
# list() is used in case `times` is not a plain Sequence (e.g. a numpy array),
# which random.sample rejects.
viz_times = random.sample(list(times), k=min(num_img_to_viz, len(times)))
for viz_time in viz_times:
    img = dic_dataset[viz_time]
    # Alternative preprocessing that was tried: img = normalize_img_to_uint8(img)
    img = standard_preprocess(img)
    mask = segment_single_image_by_cellpose(img, model, channels=[[0, 0]], diameter=diameter)

    # Left panel: contrast-enhanced input (already passed through
    # standard_preprocess, so not strictly raw); right panel: label mask.
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(enhance_contrast(img))
    axes[0].set_title("raw image")
    axes[1].imshow(mask)
    axes[1].set_title("cellpose mask")
    plt.show()

In [None]:
from livecellx.core.io_utils import save_png

# Segment every frame of the DIC dataset and write one mask PNG per time point
# to out_dir as mask_<time>.png.
diameter = 80
out_dir = Path("./notebook_results/cellpose/test_outputs")
out_dir.mkdir(parents=True, exist_ok=True)

# Loop variable is not called `time` so the stdlib `time` module is not
# shadowed for the rest of the notebook session.
for frame_time in dic_dataset.times:
    # NOTE(review): this cell preprocesses with normalize_img_to_uint8, while
    # the visualization cells in this notebook use standard_preprocess, so the
    # saved masks may differ from the ones inspected — confirm which
    # preprocessing the model expects.
    img = normalize_img_to_uint8(dic_dataset[frame_time])
    mask = segment_single_image_by_cellpose(img, model, channels=[[0, 0]], diameter=diameter)

    # mode="I" is passed through to save_png; presumably a 32-bit integer
    # image mode so label ids above 255 are preserved — confirm in save_png.
    save_png(out_dir / f"mask_{frame_time}.png", mask, mode="I")