In [2]:
import pytorch_lightning as pl
import cellseg_models_pytorch as csmp
from pathlib import Path
from cellseg_models_pytorch.datamodules import LizardDataModule
import matplotlib.pyplot as plt
from skimage.color import label2rgb
from cellseg_models_pytorch.training.lit import SegmentationExperiment
import warnings
import torch
import numpy as np

In [2]:
# Notebook-wide configuration.
# Silence every warning category for the remainder of the session.
warnings.filterwarnings('ignore')
# Root folder of the benchmark datasets, and the Lizard dataset folder inside it.
DATA_PATH = Path('./data/benchmarks/')
LIZARD_PATH = DATA_PATH.joinpath('lizard')

In [3]:
def prepare_lizard_data():
    """Build the Lizard datamodule and run its data preparation step.

    The datamodule is rooted at LIZARD_PATH and uses folds 1/2/3 for
    train/valid/test.

    Returns:
        The `LizardDataModule` instance, after `prepare_data()` has been
        called on it.
    """
    module_config = dict(
        save_dir=LIZARD_PATH,
        fold_split={"train": 1, "valid": 2, "test": 3},
        inst_transforms=["cellpose"],
        img_transforms=["blur", "hue_sat"],
        patch_size=(320, 320),
        stride=220,
        normalization="minmax",
    )
    datamodule = LizardDataModule(**module_config)
    datamodule.prepare_data()
    return datamodule

In [2]:
def plot_lizard_data_samples(img_dir, mask_dir, indices):
    """Plot image, instance map and type map for selected Lizard samples.

    The figure has one row per requested sample and three columns:
    the image, the colourised "inst_map" and the colourised "type_map".

    img_dir (Path): directory of image files; listed with glob("*") and sorted.
    mask_dir (Path): directory of .mat mask files; listed and sorted the same way.
    indices (iterable of int): positions in the sorted listings to plot.

    Returns the matplotlib figure.
    Raises ValueError if `indices` is empty.
    """
    indices = list(indices)
    if not indices:
        raise ValueError("indices must contain at least one sample index")

    # NOTE(review): image i and mask i are paired purely by sorted position —
    # assumes both folders sort into the same order; confirm for your data.
    imgs = sorted(img_dir.glob("*"))
    masks = sorted(mask_dir.glob("*"))

    # Rows are samples and columns are the three views. The grid used to be
    # created as (3, len(indices)) while being indexed as [sample, view], which
    # only worked for exactly three samples. squeeze=False keeps `axes` 2-D
    # even when a single sample is requested.
    fig, axes = plt.subplots(len(indices), 3, squeeze=False)
    for row, sample_idx in enumerate(indices):
        img = csmp.utils.FileHandler.read_img(imgs[sample_idx])
        mask = csmp.utils.FileHandler.read_mat(masks[sample_idx], return_all=True)
        axes[row, 0].imshow(img)
        axes[row, 1].imshow(label2rgb(mask["inst_map"], bg_label=0))
        axes[row, 2].imshow(label2rgb(mask["type_map"], bg_label=0))

    for ax in axes.flatten():
        ax.set_axis_off()
    return fig

def test_plot_lizard_data_samples():
    """Smoke check: plot the training patches at positions 0, 50 and 300."""
    train_dir = LIZARD_PATH / "train"
    sample_positions = [0, 50, 300]
    plot_lizard_data_samples(
        train_dir / "train_im_patches",
        train_dir / "train_mask_patches",
        sample_positions,
    )

# test_plot_lizard_data_samples()

In [5]:
def get_cellpose_model(enc_name, num_classes):
    """Create a model with `csmp.models.cellpose_base`.

    enc_name (str): name of encoder. e.g. -> "tf_efficientnetv2_s"
    num_classes (int): number of classes, forwarded as `type_classes`.
        e.g. -> len(lizard_module.type_classes)

    Returns whatever `csmp.models.cellpose_base` builds for these settings.
    """
    build_model = csmp.models.cellpose_base
    return build_model(enc_name=enc_name, type_classes=num_classes)

def get_seg_experiment(model):
    """Wrap `model` in a `SegmentationExperiment`.

    The "cellpose" branch uses the "ssim_mse" loss and no metric; the "type"
    branch uses the "tversky_focal" loss and the "miou" metric. The optimizer
    is "adamw".

    model: the segmentation model to wrap.

    Returns the configured `SegmentationExperiment`.
    """
    losses = {"cellpose": "ssim_mse", "type": "tversky_focal"}
    metrics = {"cellpose": [None], "type": ["miou"]}
    return SegmentationExperiment(
        model=model,
        branch_losses=losses,
        branch_metrics=metrics,
        optimizer="adamw",
    )

def get_trainer(max_epochs=10, accelerator="gpu"):
    """Create the `pl.Trainer` used for training.

    max_epochs (int): value forwarded as the trainer's `max_epochs`.
    accelerator (str): value forwarded as the trainer's `accelerator`.
        Defaults to "gpu" (the value that used to be hard-coded); pass e.g.
        "cpu" to run on a machine without a GPU.

    Returns the configured `pl.Trainer`.
    """
    # NOTE(review): `move_metrics_to_cpu` appears to have been removed from
    # `Trainer` in Lightning 2.0 — confirm against the installed version.
    trainer = pl.Trainer(
        accelerator=accelerator,
        max_epochs=max_epochs,
        move_metrics_to_cpu=True,
    )
    return trainer

def train_model(experiment, trainer, datamodule, ckpt_path=None):
    """Fit `experiment` on `datamodule` using `trainer`.

    experiment: the module passed to `trainer.fit`.
    trainer: object exposing `fit(experiment, datamodule=..., ckpt_path=...)`.
    datamodule: forwarded to `trainer.fit` as `datamodule`.
    ckpt_path: optional checkpoint path; forwarded as `ckpt_path` only when
        it is not None.

    Returns None.
    """
    fit_kwargs = {"datamodule": datamodule}
    if ckpt_path is not None:
        fit_kwargs["ckpt_path"] = ckpt_path
    trainer.fit(experiment, **fit_kwargs)

In [6]:
def lizard_cellpose_pipline(ckpt_path=None):
    """Run the full Lizard training pipeline with a Cellpose-style model.

    Prepares the data, builds a "tf_efficientnetv2_s" model with one class per
    entry of the datamodule's `type_classes`, wraps it in an experiment and
    trains it with the default trainer.

    ckpt_path: optional checkpoint path handed on to `train_model`.
    """
    # Data first: the model's class count is read from the datamodule.
    datamodule = prepare_lizard_data()
    n_type_classes = len(datamodule.type_classes)

    net = get_cellpose_model("tf_efficientnetv2_s", n_type_classes)
    seg_experiment = get_seg_experiment(net)
    train_model(seg_experiment, get_trainer(), datamodule, ckpt_path)

# lizard_cellpose_pipline()

In [16]:
def load_lizard_cellpose(
    ckpt_path,
    enc_name="tf_efficientnetv2_s",
    num_classes=7,
    map_location="cpu",
):
    """Rebuild the training experiment and load weights from a checkpoint.

    ckpt_path: path of the checkpoint file; its "state_dict" entry is loaded.
    enc_name (str): encoder name used to rebuild the model. Must match the
        encoder the checkpoint was trained with.
    num_classes (int): number of type classes used to rebuild the model.
        Defaults to 7, the value previously hard-coded here — presumably
        len(lizard_module.type_classes); confirm for other datasets.
    map_location: forwarded to `torch.load`. Defaults to "cpu" so that a
        checkpoint saved on a GPU can be read on a machine without CUDA; the
        tensors are copied into the freshly built model by `load_state_dict`.

    Returns the experiment with the checkpoint weights loaded.
    Raises KeyError if the checkpoint has no "state_dict" entry.
    """
    model = get_cellpose_model(enc_name, num_classes)
    experiment = get_seg_experiment(model)
    # SECURITY: torch.load unpickles the file — only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location=map_location)
    experiment.load_state_dict(ckpt["state_dict"])
    return experiment

In [28]:
def infer(experiment, input_path):
    """Run sliding-window inference over the images found in `input_path`.

    experiment: the trained experiment, passed to the inferer as `model`.
    input_path: folder of input images.

    Returns the `SlidingWindowInferer` after `infer()` has been called on it.
    """
    inferer_config = {
        "model": experiment,
        "input_path": input_path,
        "out_activations": {"cellpose": None, "type": "softmax"},
        "out_boundary_weights": {"cellpose": True, "type": False},
        "padding": 16,
        "stride": 48,
        "patch_size": (64, 64),
        "instance_postproc": "cellpose",
        # Same normalization as during training.
        "normalization": "minmax",
        # One image per batch, since the input images have different shapes.
        "batch_size": 1,
        # Only the first 3 images of the folder are processed.
        "n_images": 3,
    }
    sliding_inferer = csmp.inference.SlidingWindowInferer(**inferer_config)
    sliding_inferer.infer()
    return sliding_inferer

In [34]:
def plot_contour(inferer, input_path, idx):
    """Overlay the predicted contours of one inferred sample on its image.

    Also prints the number of distinct values in the sample's "inst" mask and
    the distinct values of its "type" mask.

    inferer: object whose `out_masks` maps sample names to mask dicts with
        "inst" and "type" entries.
    input_path: folder holding the input images; the image is read from
        "<input_path>/<sample name>.png".
    idx (int): position of the sample among the keys of `out_masks`.

    Returns the matplotlib figure showing the overlay.
    """
    sample_name = list(inferer.out_masks.keys())[idx]
    sample_masks = inferer.out_masks[sample_name]
    inst_map = sample_masks["inst"]
    type_map = sample_masks["type"]

    print(len(np.unique(inst_map)))
    print(np.unique(type_map))

    image = csmp.utils.FileHandler.read_img(str(input_path) + f"/{sample_name}.png")
    overlay = csmp.utils.draw_thing_contours(inst_map, image, type_map)

    fig, ax = plt.subplots()
    ax.imshow(overlay)
    return fig

In [1]:
def test_cellpose(
    ckpt_path='./lightning_logs/version_0/checkpoints/epoch=9-step=1540.ckpt',
):
    """Load a trained checkpoint and plot predictions on two image folders.

    Inference is run on the Lizard training images and on the images under
    DATA_PATH / "QC" / "rgb"; for each folder the first inferred sample is
    plotted with its predicted contours.

    ckpt_path: checkpoint to load. Defaults to the path that used to be
        hard-coded here, which is specific to one training run
        (version_0, epoch 9) — pass your own path for other runs.
    """
    cellpose_experiment = load_lizard_cellpose(ckpt_path)

    print("Infering on lizard data")
    lizard_images = LIZARD_PATH / "train" / "images"
    lizard_inferer = infer(cellpose_experiment, lizard_images)
    _ = plot_contour(lizard_inferer, lizard_images, 0)

    print("Infering on TCGA data")
    # Derived from DATA_PATH instead of repeating the "./data/benchmarks"
    # literal, so both folders follow the notebook's data root.
    tcga_images = DATA_PATH / "QC" / "rgb"
    tcga_inferer = infer(cellpose_experiment, tcga_images)
    _ = plot_contour(tcga_inferer, tcga_images, 0)

# test_cellpose()

In [3]:
# import scipy.io

# masks = scipy.io.loadmat('./data/benchmarks/lizard/train/labels/crag_2.mat')
# print(masks.keys())
# masks["inst_map"], masks["id"], masks['class'], masks['bbox']

In [4]:
# masks['centroid']