In [None]:
import numpy as np
import matplotlib.pyplot as plt
from cellpose import models
from cellpose.io import imread
import glob
from pathlib import Path
from PIL import Image, ImageSequence
from tqdm import tqdm
import os
import os.path
from livecell_tracker import segment
from livecell_tracker.segment import datasets
from skimage import measure
from livecell_tracker.core import SingleCellTrajectory, SingleCellStatic

In [None]:
# Cellpose model setup. The stock alternative is models.Cellpose with
# model_type='cyto' or 'nuclei' or 'cyto2':
# model = models.Cellpose(gpu=True, model_type="cyto")
# Here custom pretrained weights are loaded through CellposeModel instead.
pretrained_model_path = "../livecell_tracker/models/yuzhong_cyto2.pth"
model = models.CellposeModel(pretrained_model=pretrained_model_path, gpu=True)
# list of files
# PUT PATH TO YOUR FILES HERE!
# dir_path = Path(
#     r"D:/xing-vimentin-dic-pipeline/src/livecell_dev/cxa-data/june_2022_data_8bit_png/day0_Notreat_Group1_wellA1_RI_MIP_stitched"
# )
dir_path = Path(
    r"../cxa-data/june_2022_data_8bit_png/restart_day0_Group 1_wellA1_RI_MIP_stitched"
)
# dir_path = Path(
#     r"../cxa-data/june_2022_data_8bit_png/restart_day2_Group 1_wellA1_RI_MIP_stitched"
# )
# imgs = segment.datasets.LiveCellImageDataset(dir_path, ext="tif")
# imgs = segment.datasets.LiveCellImageDataset(dir_path, ext="png")

# Single source of truth for the image extension: the dataset and the raw file
# listing previously disagreed ("png" vs "*tif"), which left `img_list` out of
# sync with `imgs`.
img_ext = "png"
# num_imgs=3 presumably caps the dataset at three images for a quick run - TODO confirm.
imgs = segment.datasets.LiveCellImageDataset(dir_path, ext=img_ext, num_imgs=3)
img_list = sorted(glob.glob(str(dir_path / f"*.{img_ext}")))

In [None]:
# Show the first few matched file paths as a quick sanity check on the glob above.
img_list[:4]


In [None]:
def segment_single_image_by_cellpose(image, model, channels=[[0, 0]], diameter=150):
    """Segment one image with ``model`` and return the mask for that image.

    The image is wrapped in a one-element list before being handed to
    ``model.eval``; the first entry of the first returned value is the result.

    Parameters
    ----------
    image :
        The image to segment (passed through to ``model.eval`` unchanged).
    model :
        Object exposing ``eval(images, diameter=..., channels=...)``.
    channels :
        Forwarded to ``model.eval`` as ``channels``.
    diameter :
        Forwarded to ``model.eval`` as ``diameter``.
    """
    eval_outputs = model.eval([image], channels=channels, diameter=diameter)
    return eval_outputs[0][0]


def segment_single_images_by_cellpose(images, model, channels=[[0, 0]], diameter=150):
    """Segment a batch of images with ``model`` and return the masks.

    Parameters
    ----------
    images :
        Batch of images, passed to ``model.eval`` unchanged.
    model :
        Object exposing ``eval(images, diameter=..., channels=...)`` whose first
        return value is the collection of masks.
    channels :
        Forwarded to ``model.eval`` as ``channels``.
    diameter :
        Forwarded to ``model.eval`` as ``diameter``.

    Returns
    -------
    The first element returned by ``model.eval`` (the masks).
    """
    # Index rather than unpack a fixed number of values: `models.CellposeModel.eval`
    # returns three values (masks, flows, styles) while `models.Cellpose.eval`
    # returns four (adds diams). A 4-way unpack fails for the former.
    eval_outputs = model.eval(images, diameter=diameter, channels=channels)
    return eval_outputs[0]


Define the Detectron2 segmentation predictor

In [None]:
import detectron2
from detectron2.utils.logger import setup_logger

# Initialize detectron2's logger (imported above from detectron2.utils.logger).
setup_logger()

# import some common libraries
import numpy as np
import os, json, cv2, random
import cv2

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from livecell_tracker.segment.detectron_utils import gen_cfg


# Build the Detectron2 config through the project helper `gen_cfg`, pointing it at
# the trained weights file and at the training-output directory that contains it.
DETECTRON_CFG = gen_cfg(
    model_path=r"./notebook_results/train_log/detectron_train_output__ver0.0.2/model_final.pth",
    output_dir=r"""./notebook_results/train_log/detectron_train_output__ver0.0.2/""",
)
# Module-level predictor, constructed once here and used by `segment_by_detectron`.
DETECTRON_PREDICTOR = DefaultPredictor(DETECTRON_CFG)
def segment_by_detectron(img):
    """Run the module-level ``DETECTRON_PREDICTOR`` on ``img`` and return its raw output."""
    return DETECTRON_PREDICTOR(img)

In [None]:
# Load the first frame once and segment that same array, so the mask is guaranteed
# to correspond to `test_img` and the dataset is not read a second time.
test_img = imgs[0]
test_mask = segment_single_image_by_cellpose(test_img, model)


## Create overlay images

In [None]:
def normalize_img_by_zscore(img: np.ndarray) -> np.ndarray:
    """Z-score an image, then shift and rescale it to the range [0, 255].

    Parameters
    ----------
    img : np.ndarray
        Input image of any shape.

    Returns
    -------
    np.ndarray
        Floating-point array with the same shape as ``img`` whose minimum is 0
        and whose maximum is 255. A constant image (zero standard deviation)
        has no contrast to stretch and is returned as all zeros instead of NaN.
    """
    flat = img.ravel()
    std = np.std(flat)
    if std == 0:
        # Guard the division below: a flat image would otherwise become NaN everywhere.
        return np.zeros(np.shape(img), dtype=float)
    zscores = (img - np.mean(flat)) / std
    # The z-score minimum is negative here, so adding its magnitude moves it to 0.
    shifted = zscores + abs(np.min(zscores))
    return shifted / np.max(shifted) * 255


def overlay(image, mask, mask_channel_rgb_val=100, img_channel_rgb_val_factor=1):
    """Compose an RGB overlay with the mask in green and the image in blue.

    Parameters
    ----------
    image : np.ndarray
        2-D image; normalized to [0, 255] with ``normalize_img_by_zscore``.
    mask : np.ndarray
        2-D mask with the same shape as ``image``; every value > 0 is foreground.
    mask_channel_rgb_val : int
        Green-channel value written at foreground pixels.
    img_channel_rgb_val_factor : number
        Multiplier applied to the normalized image; results are clipped to 255.

    Returns
    -------
    np.ndarray
        ``uint8`` array of shape ``mask.shape + (3,)``; the red channel is zero.
    """
    # Threshold on the original labels BEFORE narrowing to uint8: casting first
    # wraps labels that are multiples of 256 to 0 and silently drops those cells.
    mask_layer = np.where(np.asarray(mask) > 0, mask_channel_rgb_val, 0).astype(np.uint8)

    image_u8 = normalize_img_by_zscore(image).astype(np.uint8)
    # Scale in float and clip; multiplying a uint8 array wraps values above 255,
    # turning the brightest pixels dark.
    image_layer = np.clip(
        image_u8.astype(np.float64) * img_channel_rgb_val_factor, 0, 255
    ).astype(np.uint8)

    # uint8 output so plt.imshow reads the values as 0-255 RGB; a float array in
    # that range would be clipped to [0, 1] and render saturated.
    res = np.zeros(list(mask_layer.shape) + [3], dtype=np.uint8)
    res[:, :, 2] = image_layer
    res[:, :, 1] = mask_layer
    return res


# Overlay the Cellpose test mask on the test image; a factor of 2 doubles the
# image-channel values relative to the default.
overlayed_img = overlay(
    test_img, test_mask, mask_channel_rgb_val=100, img_channel_rgb_val_factor=2
)

# 8x6 inch figure at 80 dpi for the overlay preview.
figure = plt.figure(figsize=(8, 6), dpi=80)

plt.imshow(overlayed_img)


## Segment all the cells

In [None]:
from livecell_tracker.segment.detectron_utils import convert_detectron_instances_to_binary_masks, convert_detectron_instances_to_label_masks
# def convert_detectron_instances_to_label_masks(instance_pred_masks):
#     res_mask = np.zeros(instance_pred_masks.shape[1:])
#     for idx in range(instance_pred_masks.shape[0]):
#         res_mask[instance_pred_masks[idx, :, :]] = idx + 1
#     return res_mask


# def convert_detectron_instances_to_binary_masks(instance_pred_masks):
#     label_mask = convert_detectron_instances_to_label_masks(instance_pred_masks)
#     label_mask[label_mask > 0] = 1
#     return label_mask

# outputs = segment_by_detectron(normalize_img_by_zscore(imgs[0][:, :, np.newaxis]))["instances"].to("cpu")
# mask = outputs.pred_masks.numpy()
# combine_detectron_instances_masks(outputs.pred_masks.numpy())


## Visualize segmentation by overlay movies

In [None]:
from livecell_tracker.segment.datasets import LiveCellImageDataset
from livecell_tracker.segment.detectron_utils import detectron_visualize_img
def gen_T_filename(frame, ndigits):
    """Return ``T<frame>.tif`` with the frame number zero-padded to ``ndigits`` characters."""
    padded_frame = str(frame).zfill(ndigits)
    return f"T{padded_frame}.tif"
# Number of characters needed to print the dataset size (padding width for frame names).
ndigits = len(str(len(imgs)))
print("ndigits:", ndigits)

# Scratch directory for segmentation outputs; created if missing.
out_dir = Path("./seg_test_tmp")
out_dir.mkdir(parents=True, exist_ok=True)

def segment_cellpose_wrapper(img):
    """Segment ``img`` with the notebook-level Cellpose ``model`` using a fixed diameter of 173."""
    return segment_single_image_by_cellpose(image=img, model=model, diameter=173)

def segment_detectron_wrapper(img):
    """Z-score-normalize ``img``, run Detectron2 on it and return a binary mask.

    A 2-D image gets a trailing channel axis before normalization. The predicted
    instance masks are moved to CPU and merged by
    ``convert_detectron_instances_to_binary_masks``.
    """
    model_input = img[:, :, np.newaxis] if img.ndim == 2 else img
    detectron_output = segment_by_detectron(normalize_img_by_zscore(model_input))
    pred_masks = detectron_output["instances"].to("cpu").pred_masks.numpy()
    return convert_detectron_instances_to_binary_masks(pred_masks)

def segment_raw_img_by_detectron_wrapper(img, return_detectron_results=False):
    """Run Detectron2 on the un-normalized ``img`` and return a binary mask.

    A 2-D image gets a trailing channel axis first. When
    ``return_detectron_results`` is truthy the raw predictor output is returned
    as well, as ``(mask, results)``; otherwise only the mask is returned.
    """
    model_input = img if img.ndim != 2 else img[:, :, np.newaxis]
    detectron_output = segment_by_detectron(model_input)
    binary_mask = convert_detectron_instances_to_binary_masks(
        detectron_output["instances"].to("cpu").pred_masks.numpy()
    )
    return (binary_mask, detectron_output) if return_detectron_results else binary_mask

# Alias for the active segmentation backend (raw-image Detectron2 wrapper).
# NOTE(review): within this notebook it is only referenced from commented-out code.
segment_func_wrapper = segment_raw_img_by_detectron_wrapper
# segment_func = segment_cellpose_wrapper
# segment_func = segment_detectron_wrapper

# overlay_img_generator = ...
# overlay_img_generator = overlay
# overlay_img_generator = detectron_visualize_img

# def cellpose_segment_imgs(img, out_dir, frame, ndigits):
#     for idx in tqdm(range(len(imgs))):
#         img_path = imgs.get_img_path(idx)
#         img = imgs[idx]
#         file_name = os.path.basename(img_path)
#         output_filename = file_name.split(".")[0] + ".png" # change extension to PNG
#         mask = segment_func_wrapper(img)

#         # convert mask to 8-bit binary mask
#         assert mask.max() < 2**8, "more than 256 instances predicted?"
#         mask = mask.astype(np.uint8)
#         temp_img = Image.fromarray(mask)
#         temp_img.save(out_dir / output_filename)

#         overlay_output_filename = "overlay_" + file_name.split(".")[0] + ".png" # change extension to PNG
#         # overlayed_img = overlay(img, mask, mask_channel_rgb_val=100, img_channel_rgb_val_factor=2)
#         overlayed_img = detectron_visualize_img(img, DETECTRON_CFG, )
#         overlayed_img.save(out_dir / overlay_output_filename)


def _save_instance_masks(pred_masks, out_dir, img_name):
    """Write each predicted instance mask to ``<img_name>_instance_<k>.png`` in ``out_dir``."""
    for instance_idx in range(pred_masks.shape[0]):
        instance_mask_img = Image.fromarray(pred_masks[instance_idx, :, :])
        instance_mask_img.save(out_dir / f"{img_name}_instance_{instance_idx}.png")


def _instance_masks_to_contours(pred_masks, img_name):
    """Trace the outline(s) of every instance mask and return them as nested lists."""
    contours = []
    for instance_mask in pred_masks:
        tmp_contours = measure.find_contours(
            instance_mask, level=0.5, fully_connected="low", positive_orientation="low"
        )
        if len(tmp_contours) == 0:
            print("no contours found for image:", img_name)
        elif len(tmp_contours) != 1:
            print("[WARN] more than 1 contour found in instance mask, num_contours:", len(tmp_contours))
        # Plain nested lists so the result can be written to JSON.
        contours.extend(coord_arr.tolist() for coord_arr in tmp_contours)
    return contours


def detectron_segment_imgs(imgs: LiveCellImageDataset, out_dir: Path, save_instance_masks: bool = False):
    """Segment every image in ``imgs`` with Detectron2 and save masks and overlays.

    For each image this writes ``<name>.png`` (the binary mask as uint8) and
    ``overlay_<name>.png`` (the Detectron2 visualization) into ``out_dir``,
    where ``<name>`` is the image file's base name up to its first dot.

    Parameters
    ----------
    imgs : LiveCellImageDataset
        Dataset supporting ``len``, integer indexing and ``get_img_path(idx)``.
    out_dir : Path
        Existing directory that receives the PNG outputs.
    save_instance_masks : bool, optional
        When true, also write one PNG per predicted instance mask. Defaults to
        False, which matches the previous behavior (instance masks not saved).

    Returns
    -------
    dict
        Maps each image path to ``{"contours": [...]}``, where each contour is a
        list of coordinate pairs as returned by ``measure.find_contours``.

    Raises
    ------
    AssertionError
        If two images share the same output base name, or a mask value is >= 256.
    """
    segmentation_results = {}
    # Output files are keyed by base name while results are keyed by path, so name
    # collisions are tracked separately; checking the results dict for a base name
    # could never match.
    seen_filenames = set()
    for idx in tqdm(range(len(imgs))):
        img_path = imgs.get_img_path(idx)
        img = imgs[idx]
        original_img_filename = os.path.basename(img_path).split(".")[0]
        # Checked before anything is written so a collision cannot overwrite outputs.
        assert original_img_filename not in seen_filenames, "duplicate image filename?"
        seen_filenames.add(original_img_filename)

        # save binary mask
        mask, predictor_results = segment_raw_img_by_detectron_wrapper(img, return_detectron_results=True)
        # convert mask to 8-bit binary mask
        assert mask.max() < 2**8, "more than 256 instances predicted?"
        binary_mask_img = Image.fromarray(mask.astype(np.uint8))
        binary_mask_img.save(out_dir / (original_img_filename + ".png"))  # change extension to PNG

        # save overlayed image; add the channel axis only for 2-D images, matching
        # the guard used by the segmentation wrappers
        overlay_input = img[:, :, np.newaxis] if img.ndim == 2 else img
        overlayed_arr = detectron_visualize_img(overlay_input, DETECTRON_CFG, predictor_results)
        overlayed_img = Image.fromarray(overlayed_arr)
        overlayed_img.save(out_dir / ("overlay_" + original_img_filename + ".png"))
        del overlayed_img, overlayed_arr, mask, binary_mask_img

        # Move the predicted instance masks to CPU once and reuse them below.
        pred_masks = predictor_results["instances"].to("cpu").pred_masks.numpy()
        if save_instance_masks:
            _save_instance_masks(pred_masks, out_dir, original_img_filename)

        # generate contours (later dumped to json by the caller)
        segmentation_results[img_path] = {
            "contours": _instance_masks_to_contours(pred_masks, original_img_filename)
        }
    return segmentation_results


# import torch.utils.data as data_utils
# import torch
# indices = torch.arange(1, 10)
# Segment the whole dataset, then persist the per-image contours as JSON.
segmentation_results = detectron_segment_imgs(imgs, out_dir)
with open(out_dir / "segmentation_results.json", mode="w") as json_file:
    json.dump(segmentation_results, fp=json_file)

In [None]:
# NOTE(review): this cell repeats the segmentation + JSON dump performed at the
# end of the cell that defines `detectron_segment_imgs`.
segmentation_results = detectron_segment_imgs(imgs, out_dir)
with open(out_dir / "segmentation_results.json", mode="w") as json_file:
    json.dump(segmentation_results, fp=json_file)

In [None]:
# Run Detectron2 on the first image and keep the raw predictor output (`results`)
# alongside the binary mask; the contour cells below read `results`.
mask, results = segment_raw_img_by_detectron_wrapper(imgs[0], return_detectron_results=True)

## Generate contour points

In [None]:
# for instance in results["instances"].to("cpu").pred_masks.numpy():
#     print(instance)
# NOTE(review): despite its name, `label_mask` is produced by the *binary*
# conversion helper, so all instances presumably share one foreground value -
# confirm against detectron_utils.
label_mask = convert_detectron_instances_to_binary_masks(results["instances"].to("cpu").pred_masks.numpy())
plt.imshow(label_mask)

Find contours in the combined label mask

In [None]:
# Trace the 0.5 iso-level outlines of the combined mask.
contours = measure.find_contours(
    label_mask,
    level=0.5,
    fully_connected="low",
    positive_orientation="low",
)

# Flip the y-axis so row 0 is at the top, as in image coordinates.
fig, ax = plt.subplots(figsize=(5, 5))
ax.invert_yaxis()
for contour in contours:
    # Contour points are (row, col); plot columns as x and rows as y.
    ax.plot(contour[:, 1], contour[:, 0], linewidth=2)


Find contours for each predicted instance

In [None]:
# Collect the first contour of every predicted instance mask.
contours = []
for instance_mask in results["instances"].to("cpu").pred_masks.numpy():
    tmp_contours = measure.find_contours(
        instance_mask, level=0.5, fully_connected="low", positive_orientation="low"
    )
    if len(tmp_contours) == 0:
        # An instance mask can yield no contour; indexing [0] would raise IndexError.
        print("no contours found for instance mask, skipping")
        continue
    contours.append(tmp_contours[0])

# NOTE (original author): "Following code not correct" - the reason was not recorded.
fig = plt.figure(figsize=(5, 5))
ax = plt.subplot()
# Flip the y-axis so row 0 is at the top, as in image coordinates.
ax.invert_yaxis()
for contour in contours:
    # Contour points are (row, col); plot columns as x and rows as y.
    ax.plot(contour[:, 1], contour[:, 0], linewidth=2)
