In [None]:
import livecell_tracker.sample_data
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt


In [None]:
# Sample data for the STAV-A549 test set: one directory of DIC images and one of mask images.
test_data_dir = Path("../datasets/test_data_STAV-A549")
dic_dataset_path = test_data_dir / "DIC_data"
mask_dataset_path = test_data_dir / "mask_data"
dic_dataset, mask_dataset = livecell_tracker.sample_data.tutorial_three_image_sys(
    dic_dataset_path,
    mask_dataset_path,
)


In [None]:
import torch
from segment_anything import SamPredictor, sam_model_registry

# Prefer the GPU, but fall back to CPU so the notebook still runs on machines without CUDA
# (the original hardcoded "cuda" and crashed there).
device = "cuda" if torch.cuda.is_available() else "cpu"
sam_checkpoint_path = "./segment-anything/sam_vit_h_4b8939.pth"
sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint_path)
# Trailing `;` keeps Jupyter from displaying the (very long) module repr returned by `.to()`.
sam.to(device=device);


In [None]:
from livecell_tracker.preprocess.utils import (
    normalize_img_to_uint8,
    correct_background_polyfit,
    standard_preprocess,
    enhance_contrast,
)

# Background-correct the first DIC frame, then replicate the single gray channel
# three times so the SAM calls below receive an H x W x 3 image.
img = standard_preprocess(dic_dataset.get_img_by_time(0), correct_background_polyfit)
img = np.stack((img,) * 3, axis=2)

# Prompt-based prediction with SamPredictor is done in a later cell (predictor.set_image / predict).


In [None]:
from segment_anything import SamAutomaticMaskGenerator

## Whole DIC image prediction with SAM's automatic mask generator

In [None]:
# Preprocess the first DIC frame and replicate the gray channel to 3 channels for SAM.
img = dic_dataset.get_img_by_time(0)
img = standard_preprocess(img, correct_background_polyfit)
img = np.stack([img, img, img], axis=2)
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(img)

# Paint every generated mask into one label image: 0 = background, i + 1 = i-th mask.
# Later masks overwrite earlier ones where they overlap.
# int32 instead of uint8: a whole DIC frame can produce more than 255 masks, and uint8
# labels would then wrap around or raise OverflowError (depending on the NumPy version).
all_seg_mask = np.zeros(img.shape[:2], dtype=np.int32)
for label, mask_record in enumerate(masks, start=1):
    all_seg_mask[mask_record["segmentation"]] = label

fig, axes = plt.subplots(1, 3, figsize=(10, 5))
axes[0].imshow(enhance_contrast(img[..., 0]))
axes[1].imshow(all_seg_mask)
axes[2].imshow(mask_dataset.get_img_by_time(0))
axes[0].set_title("DIC t=0 (contrast enhanced)")
axes[1].set_title(f"SAM automatic masks (n={len(masks)})")
axes[2].set_title("mask_dataset t=0")
plt.show()


In [None]:
import os

# Import the submodule explicitly instead of relying on `livecell_tracker.sample_data`
# having imported it as a side effect.
import livecell_tracker.core.datasets

# Directory of mitosis .tif images. Set MITOSIS_DATA_DIR to point elsewhere; the default
# is the original author's local absolute path.
mitosis_data_dir = os.environ.get(
    "MITOSIS_DATA_DIR",
    "/home/ken67/LiveCellTracker-dev/datasets/wwk_train/A549_icnn_am_train/mitosis",
)
mitosis_dataset = livecell_tracker.core.datasets.LiveCellImageDataset(
    dir_path=mitosis_data_dir,
    ext="tif",
    index_by_time=False,
)


In [None]:
# Preprocess the first mitosis image and replicate the gray channel to 3 channels for SAM.
img = mitosis_dataset[0]
img = standard_preprocess(img, correct_background_polyfit)
img = np.stack([img, img, img], axis=2)
# Reuses the `mask_generator` created in the whole-DIC-image cell.
masks = mask_generator.generate(img)

# Label image: 0 = background, i + 1 = i-th generated mask (later masks win on overlap).
# int32 instead of uint8 so more than 255 masks cannot overflow the labels.
all_seg_mask = np.zeros(img.shape[:2], dtype=np.int32)
for label, mask_record in enumerate(masks, start=1):
    all_seg_mask[mask_record["segmentation"]] = label

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(enhance_contrast(img[..., 0]))
axes[1].imshow(all_seg_mask)
axes[0].set_title("mitosis image 0 (contrast enhanced)")
axes[1].set_title(f"SAM automatic masks (n={len(masks)})")
plt.show()


In [None]:
# Prompt-free SamPredictor run on the current `img` (the first mitosis image).
predictor = SamPredictor(sam)
predictor.set_image(img)
masks, _, _ = predictor.predict()
# Candidate masks are stacked along axis 0; moving that axis last lets the three
# candidates be viewed together as the R, G and B planes of one image.
mask_rgb = np.moveaxis(masks, 0, -1).astype(np.uint8) * 255
plt.imshow(mask_rgb)


In [None]:
import os

out_dir = "./tmp/outputs"
os.makedirs(out_dir, exist_ok=True)

for index in range(len(mitosis_dataset)):
    img = mitosis_dataset[index]
    img = standard_preprocess(img, bg_correct_func=correct_background_polyfit)
    # Replicate the single gray channel so SAM receives an H x W x 3 image.
    img = np.stack([img, img, img], axis=2)
    predictor.set_image(img)
    # predict() returns (masks, scores, logits); the candidate masks are stacked along
    # axis 0, each with a predicted quality score.
    masks, scores, _ = predictor.predict()

    fig, axes = plt.subplots(1, 1 + len(masks), figsize=(10, 5))
    axes[0].imshow(img)
    axes[0].set_title("original image")
    # The outputs are alternative candidate masks for the same image, not the R/G/B
    # channels of one mask, so label them by candidate index and score.
    for mask_idx, (ax, mask, score) in enumerate(zip(axes[1:], masks, scores)):
        ax.imshow(mask.astype(np.uint8) * 255)
        ax.set_title(f"SAM mask {mask_idx} (score={score:.2f})")
    fig.savefig(os.path.join(out_dir, f"mitosis_{index}.png"))
    # Close each figure after saving; otherwise one open figure per image accumulates in memory.
    plt.close(fig)


In [None]:
# Shape of the `masks` array left over from the last predictor.predict() call.
np.shape(masks)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from cellpose import models
from cellpose.io import imread
from pathlib import Path
from livecell_tracker.preprocess.utils import normalize_img_to_uint8
from livecell_tracker.segment.cellpose_utils import segment_single_images_by_cellpose, segment_single_image_by_cellpose

# Alternative (disabled): a custom A549 cell-body model fine-tuned from cyto2, loaded with
#   models.Cellpose(gpu=True, model_type="cyto2", pretrained_model=<path>)
# where <path> was
#   /home/ken67/LiveCellTracker-dev/notebooks/notebook_results/cellpose/cellpose_A549_cyto2_cellbody/models/cellpose_residual_on_style_on_concatenation_off_cellpose_A549_cyto2_cellbody_2023_04_17_21_49_50.313712
# Built-in model_type choices: 'cyto', 'nuclei', 'cyto2'.
# NOTE(review): with pretrained_model=None and no model_type, which weights CellposeModel
# loads depends on the installed cellpose version — confirm this is the intended model.
pretrained_model_path = None
model = models.CellposeModel(gpu=True, pretrained_model=pretrained_model_path)


In [None]:
# Number of images in the mitosis dataset.
n_mitosis_images = len(mitosis_dataset)
print(n_mitosis_images)


In [None]:
# Diameter passed to cellpose; it is the same for every image, so set it once outside the loop.
diameter = 50

for index in range(len(mitosis_dataset)):
    img = mitosis_dataset[index]
    img = standard_preprocess(img, bg_correct_func=correct_background_polyfit)
    mask = segment_single_image_by_cellpose(img, model, channels=[[0, 0]], diameter=diameter)

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(img)
    axes[1].imshow(mask)
    axes[0].set_title("original image")
    axes[1].set_title("cellpose mask")
    plt.show()
    # Release the figure explicitly; the loop creates one per image.
    plt.close(fig)
