# Train a Cellpose model to segment A549 cells  
Author: Ke  
Data source: Dr. Weikang Wang

In [None]:
# !pip install omnipose

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from cellpose import models
from cellpose.io import imread
from pathlib import Path


In [None]:
# Output folder for the re-fitted Cellpose model; created (with parents) if missing.
model_save_path = Path("results") / "cellpose" / "cellpose_A549_cyto2_cellbody_bg_corrected"
model_save_path.mkdir(exist_ok=True, parents=True)


## Load the training data for Cellpose

In [None]:
# Root folders of the annotated training sets. Each root is expected to hold
# "Img" (raw images), "Bwdist" and "Interior" (masks) subfolders.
data_dirs = [
    r"D:\LiveCellTracker-dev\datasets\nidhi_data_8-7-2023\nidhi-training-9-28"
]
raw_img_dir, dist_img_dir, mask_img_dir = (
    [Path(root) / sub for root in data_dirs] for sub in ("Img", "Bwdist", "Interior")
)

# Fail fast if any expected folder is missing (checked root by root).
for folder_triplet in zip(raw_img_dir, dist_img_dir, mask_img_dir):
    for folder in folder_triplet:
        assert folder.exists(), f"{folder} does not exist"

In [None]:
# Inspect the resolved raw-image directories (displayed as the cell output).
raw_img_dir

In [None]:
import itertools

# Collect image files per data directory. Raw and mask files are paired purely
# by position after sorting, so they must sort into the same order
# (NOTE(review): presumably they share file stems — confirm).
raw_img_paths = [sorted(path.glob("*.tif")) for path in raw_img_dir]
dist_img_paths = [sorted(path.glob("*.tif")) for path in dist_img_dir]
mask_img_paths = [sorted(path.glob("*.png")) for path in mask_img_dir]

# Every raw image needs a mask. Only raw vs. mask counts are enforced; the
# distance-image count is reported in the message for context only.
for raw_dir, mask_dir, raw_paths, dist_paths, mask_paths in zip(
    raw_img_dir, mask_img_dir, raw_img_paths, dist_img_paths, mask_img_paths
):
    assert len(raw_paths) == len(mask_paths), (
        f"Number of raw and mask images do not match in {raw_dir} and {mask_dir}: "
        f"{len(raw_paths)} raw, {len(mask_paths)} masks "
        f"({len(dist_paths)} distance images)"
    )

# Flatten the per-directory lists into single, index-aligned lists.
raw_img_paths = list(itertools.chain.from_iterable(raw_img_paths))
dist_img_paths = list(itertools.chain.from_iterable(dist_img_paths))
mask_img_paths = list(itertools.chain.from_iterable(mask_img_paths))
    

In [None]:
# read images
raw_imgs = [imread(str(path)) for path in raw_img_paths]
dist_imgs = [imread(str(path)) for path in dist_img_paths]
mask_imgs = [imread(str(path)) for path in mask_img_paths]

In [None]:
# Sanity check: number of loaded raw / distance / mask images.
len(raw_imgs), len(dist_imgs), len(mask_imgs)

In [None]:
# squeeze images
raw_imgs = [img.squeeze() for img in raw_imgs]
dist_imgs = [img.squeeze() for img in dist_imgs]
mask_imgs = [img.squeeze() for img in mask_imgs]


In [None]:
# Inspect shape and container type of the first squeezed raw image.
raw_imgs[0].shape, type(raw_imgs[0])

In [None]:
import cv2

# Convert raw RGB images to grayscale via OpenCV. Images that are already
# single-channel (2-D after squeezing) are passed through unchanged, because
# cv2.cvtColor with COLOR_RGB2GRAY rejects single-channel input.
raw_imgs = [
    cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) if img.ndim == 3 else img
    for img in raw_imgs
]

### Check that image and mask shapes match

In [None]:
# Every grayscale raw image must be pixel-aligned with its mask.
for i, raw_img in enumerate(raw_imgs):
    shapes_match = raw_img.shape == mask_imgs[i].shape
    assert shapes_match, f"Image shapes do not match for image {i}, {raw_imgs[i].shape}, {mask_imgs[i].shape}"

In [None]:
# Re-check counts after grayscale conversion (should be unchanged).
len(raw_imgs), len(dist_imgs), len(mask_imgs)

Note the following assumption: when Dr. Weikang Wang annotated the datasets, he intentionally avoided overlapping masks, so instance label masks can be obtained simply with `skimage.measure.label()`.

In [None]:
from livecellx.preprocess.utils import normalize_img_to_uint8, standard_preprocess

# Apply livecellx's standard preprocessing to every grayscale raw image.
# NOTE(review): presumably this is the background correction the model name
# ("..._bg_corrected") refers to — confirm in livecellx.
raw_imgs = list(map(standard_preprocess, raw_imgs))

In [None]:
import skimage

# Annotations have no overlapping cells, so connected-component labelling turns
# each binary interior mask into an instance label image.
label_mask_imgs = list(map(skimage.measure.label, mask_imgs))

# Count label images holding at most one distinct value (typically nothing
# but background) — such images contribute no cell instances.
empty_mask_label_count = sum(
    1 for label_img in label_mask_imgs if len(np.unique(label_img)) <= 1
)
print(f"Number of empty mask labels: {empty_mask_label_count}", "total number of images:", len(label_mask_imgs))

## Omnipose prediction before training

In [None]:
# !pip install --upgrade mahotas==1.4.13

In [None]:
# NOTE(review): this rebinds `models`, shadowing the earlier
# `from cellpose import models`. Every later `models.CellposeModel(...)` in the
# notebook (including the fine-tuning cell) therefore uses cellpose_omni's
# class — confirm that is intended.
from cellpose_omni import models
from cellpose_omni.models import MODEL_NAMES

# Display MODEL_NAMES (presumably the built-in `model_type` choices).
MODEL_NAMES

In [None]:
# Pretrained Omnipose model, evaluated in the next cells as a before-training baseline.
model_name = 'bact_phase_omni'
use_GPU = True
model = models.CellposeModel(model_type=model_name, gpu=use_GPU)

In [None]:
import time

chans = [0, 0]  # segment based on the first channel only; no second channel

# Indices into raw_imgs of the images to segment; -1 selects the last image.
n = [-1]
# n = range(len(raw_imgs))  # or segment every image

# Evaluation parameters. This dict is also reused later when the fine-tuned
# model is evaluated on the test images.
params = {'channels': chans,  # always define this with the model
          'rescale': None,  # upscale or downscale your images, None = no rescaling
          'mask_threshold': -1,  # erode or dilate masks with higher or lower values
          'flow_threshold': 0,  # default is .4, but only needed if there are spurious masks to clean up; slows down output
          'transparency': True,  # transparency in flow output
          'omni': True,  # we can turn off Omnipose mask reconstruction, not advised
          'cluster': True,  # use DBSCAN clustering
          'resample': True,  # whether or not to run dynamics on rescaled grid or original grid
          # 'verbose': False,  # turn on if you want to see more output
          'tile': False,  # NOTE(review): Cellpose docs describe `tile` as tiled evaluation to limit memory use, not flip-averaging — confirm for cellpose_omni
          'niter': None,  # None lets Omnipose calculate # of Euler iterations (usually <20) but you can tune it for over/under segmentation
          'augment': False,  # optional test-time augmentation, usually not needed
          'affinity_seg': False,  # new feature, stay tuned...
          }

# perf_counter is monotonic and high-resolution, unlike time.time, so it is
# the right clock for measuring an elapsed interval.
tic = time.perf_counter()
masks, flows, styles = model.eval([raw_imgs[i] for i in n], **params)

net_time = time.perf_counter() - tic
print('total segmentation time: {}s'.format(net_time))

In [None]:
from cellpose_omni import plot
import omnipose
import matplotlib as mpl


# Render Omnipose's segmentation overview for each evaluated image.
# `idx` indexes the model outputs; `i` indexes raw_imgs (outputs follow `n`).
for idx, i in enumerate(n):
    maski = masks[idx]      # label mask
    flowi = flows[idx][0]   # RGB flow rendering
    bdi = flows[idx][-1]    # boundaries (last flow entry)

    # Scale the figure from the image size so panels roughly track the image
    # resolution.
    f = 10
    dpi = mpl.rcParams['figure.dpi']
    szX = maski.shape[-1] / dpi * f
    szY = maski.shape[-2] / dpi * f
    fig = plt.figure(figsize=(szY, szX * 4))
    fig.patch.set_facecolor([0] * 4)  # RGBA all zeros: transparent background

    plot.show_segmentation(
        fig,
        omnipose.utils.normalize99(raw_imgs[i]),
        maski,
        flowi,
        bdi,
        channels=chans,
        omni=True,
        interpolation=None,
    )

    plt.tight_layout()
    plt.show()

### Fine-tune a Cellpose model

In [None]:
# Checkpoint to resume from. It appears to have been written to this notebook's
# results folder by an earlier training run.
model_path = r"D:\LiveCellTracker-dev\notebooks\application_nidhi_JC\results\cellpose\cellpose_A549_cyto2_cellbody_bg_corrected\models\cellpose_residual_on_style_on_concatenation_off_cellpose_A549_cyto2_cellbody_bg_corrected_2023_09_28_05_00_13.696883"
# NOTE(review): `models` here is whichever module was bound last in the
# notebook (cellpose_omni after the Omnipose cell) — confirm intended.
model = models.CellposeModel(pretrained_model=model_path, gpu=True)
# Train on the preprocessed grayscale images and their instance label masks.
model.train(
    raw_imgs,
    label_mask_imgs,
    channels=[0, 0],
    n_epochs=10000,
    save_path=model_save_path,
)

Show predictions for 5 randomly chosen training images

In [None]:
# Inspect the first preprocessed training image (array shown as cell output).
raw_imgs[0]

In [None]:
from livecellx.segment.cellpose_utils import segment_single_image_by_cellpose

# How many randomly chosen training images to visualise.
n_samples = 5

for _ in range(n_samples):
    index = np.random.randint(0, len(raw_imgs))
    # raw_imgs hold 2-D grayscale images after the cv2.cvtColor conversion
    # (assuming standard_preprocess keeps the shape), so pass the whole image.
    # Indexing with [0] would select only its first row.
    sample_img = raw_imgs[index]

    # Separate names so this cell does not clobber the notebook-level
    # `masks`/`flows` from the baseline evaluation.
    sample_masks, sample_flows, _styles = model.eval(
        [sample_img], diameter=55, channels=[[0, 0]]
    )
    assert len(sample_masks) == 1
    mask = sample_masks[0]
    flow = sample_flows[0]
    print("masks shape: ", mask.shape)
    print("flows length: ", len(flow))
    print("flows shape: ", flow[0].shape)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(sample_img)
    axes[0].set_title("raw image")
    axes[1].imshow(mask)
    axes[1].set_title("cellpose mask")
    axes[2].imshow(label_mask_imgs[index])
    axes[2].set_title("label mask")
    plt.show()

    # Per the panel titles: flow[0] is the HSV rendering, flow[1] holds two
    # flow channels and flow[2] is the cell-probability map.
    flow_fig, flow_axes = plt.subplots(1, 5, figsize=(20, 5))
    flow_axes[0].imshow(flow[0])
    flow_axes[0].set_title("hsv")
    flow_axes[1].imshow(flow[1][0])
    flow_axes[1].set_title("flows ch0")
    flow_axes[2].imshow(flow[1][1])
    flow_axes[2].set_title("flows ch1")
    flow_axes[3].imshow(flow[2])
    flow_axes[3].set_title("flows cell prob ch0")
    flow_axes[4].imshow(flow[2] > 0.7)
    flow_axes[4].set_title("cell prob > 0.7")

    plt.show()


## Predict with the model trained on your own data

In [None]:
from livecellx.preprocess.utils import standard_preprocess

# Held-out images to segment with the fine-tuned model.
test_img_dir = [Path(r"D:\LiveCellTracker-dev\datasets\nidhi_data_8-7-2023\data\images")]
# Flat, sorted list of .tif paths across all listed directories.
test_img_paths = [p for path in test_img_dir for p in sorted(path.glob("*.tif"))]

test_imgs = [imread(str(path)).squeeze() for path in test_img_paths]

# Nidhi's images are RGB with all channels the same, so we can just take the
# first channel (images that are already 2-D are kept as-is).
test_imgs = [img[:, :, 0] if img.ndim == 3 else img for img in test_imgs]

# Apply the same preprocessing as the training images, so the fine-tuned
# ("..._bg_corrected") model sees inputs from its training distribution.
test_imgs = [standard_preprocess(img) for img in test_imgs]

# NOTE(review): `params` was written for the Omnipose baseline (omni=True,
# cluster=True); confirm these settings suit the fine-tuned checkpoint.
masks, flows, styles = model.eval(test_imgs, **params)



In [None]:
# Inspect the shape of the first preprocessed test image. An explicit index is
# used so the cell does not depend on a loop variable left over from another cell.
test_imgs[0].shape

In [None]:
pred_out_dir = Path("./pred_outs")
pred_out_dir.mkdir(parents=True, exist_ok=True)

# For every test image, save an overview figure, a colour-mapped mask preview
# and a lossless copy of the integer label mask.
for idx, (maski, flows_i, test_img) in enumerate(zip(masks, flows, test_imgs)):
    bdi = flows_i[-1]    # boundaries (last flow entry)
    flowi = flows_i[0]   # RGB flow rendering

    fig = plt.figure(figsize=(100, 25))
    fig.patch.set_facecolor([0] * 4)  # transparent background
    plot.show_segmentation(
        fig,
        omnipose.utils.normalize99(test_img),
        maski,
        flowi,
        bdi,
        channels=chans,
        omni=True,
        interpolation=None,
    )
    plt.tight_layout()

    plt.savefig(pred_out_dir / f"pred_{idx}.png")
    # plt.imsave maps the 2-D label array through a colormap, so this PNG is
    # a preview only: label ids cannot be recovered from it.
    plt.imsave(pred_out_dir / f"mask_{idx}.png", maski)
    # Lossless label image (original dtype) for downstream analysis.
    np.save(pred_out_dir / f"mask_{idx}.npy", maski)

    plt.show()
    plt.close(fig)