In [1]:
import glob
import os
import re
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import jaccard_score

# PlantSeg imports
import plantseg
from plantseg.core.zoo import ModelZoo
from plantseg.tasks import import_image_task, unet_prediction_task

# CellPose imports
from cellpose import models, io


INFO: P [MainThread] 2025-01-24 11:56:05,039 plantseg - Logger configured at initialisation. PlantSeg logger name: plantseg


In [2]:


# Helper function for loading images
def load_image(file_path):
    """Read an image from disk, keeping its original bit depth and channels.

    Args:
        file_path: path to the image, as ``str`` or any ``os.PathLike``
            (e.g. ``pathlib.Path``).

    Returns:
        The array returned by ``cv2.imread`` with ``cv2.IMREAD_UNCHANGED``.

    Raises:
        FileNotFoundError: if OpenCV could not read the file. ``cv2.imread``
            signals failure by returning ``None`` rather than raising, so the
            check is done explicitly here instead of failing later downstream.
    """
    # os.fspath: some OpenCV builds only accept str paths, not pathlib.Path.
    image = cv2.imread(os.fspath(file_path), cv2.IMREAD_UNCHANGED)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {file_path}")
    return image


# Function to perform segmentation using PlantSeg
def segment_with_plantseg(image_path, model_name="lightsheet_2D_unet_root_ds1x", patch=(1, 64, 64), device="cuda"):
    """Run a PlantSeg U-Net on a single image file.

    Args:
        image_path: location of the raw image (anything ``pathlib.Path`` accepts).
        model_name: PlantSeg model name forwarded to ``unet_prediction_task``.
        patch: patch shape forwarded to ``unet_prediction_task``.
        device: device string forwarded to ``unet_prediction_task`` (e.g. "cuda", "cpu").

    Returns:
        The array data of the first image returned by ``unet_prediction_task``
        (presumably a float prediction map rather than instance labels — TODO confirm
        for the chosen model).
    """
    raw_image = import_image_task(
        input_path=Path(image_path),
        semantic_type="raw",
        stack_layout="ZYX",
    )
    predictions = unet_prediction_task(
        image=raw_image,
        model_name=model_name,
        patch=patch,
        device=device,
    )
    # Only the first returned prediction is used.
    first_prediction = predictions[0]
    return first_prediction.get_data()


# Function to perform segmentation using CellPose
# Loaded CellPose models keyed by (model_type, gpu). Constructing a model loads its
# network weights, so reusing it avoids paying that cost for every image.
_CELLPOSE_MODELS = {}


def segment_with_cellpose(image_path, model_type="cyto2", flow_threshold=0.4, cellprob_threshold=0,
                          diameter=None, channels=(0, 0), gpu=True):
    """Segment one image with CellPose and return its mask array.

    Args:
        image_path: path of the image, read with ``cellpose.io.imread``.
        model_type: CellPose model name (e.g. "cyto2").
        flow_threshold: forwarded to ``model.eval``.
        cellprob_threshold: forwarded to ``model.eval``.
        diameter: forwarded to ``model.eval``; ``None`` keeps the original call's value.
        channels: forwarded to ``model.eval`` as a list; ``(0, 0)`` matches the
            original hard-coded ``[0, 0]``.
        gpu: forwarded to ``models.Cellpose``.

    Returns:
        The ``masks`` output of ``model.eval`` (an instance label image).
    """
    cache_key = (model_type, gpu)
    model = _CELLPOSE_MODELS.get(cache_key)
    if model is None:
        model = models.Cellpose(gpu=gpu, model_type=model_type)
        _CELLPOSE_MODELS[cache_key] = model

    image = io.imread(image_path)
    masks, _flows, _styles, _diams = model.eval(
        image,
        diameter=diameter,
        channels=list(channels),
        flow_threshold=flow_threshold,
        cellprob_threshold=cellprob_threshold,
    )
    return masks


# Function for calculating IoU with ground truth
def calculate_iou(pred, gt, pred_threshold=None, empty_value=0.0):
    """Foreground IoU (Jaccard index) between a prediction and a ground-truth mask.

    Both inputs are reduced to binary foreground masks before comparison, so label
    images (0 = background, 1..N = instances), probability maps and binary masks
    can be mixed. The previous ``jaccard_score(..., average="binary")`` raised
    ``ValueError`` on any input that was not two-valued with positive label 1. That
    covers CellPose label masks, PlantSeg predictions and CTC SEG masks.

    Args:
        pred: predicted array; singleton axes are squeezed away (e.g. a (1, H, W)
            PlantSeg output is compared as (H, W)).
        gt: ground-truth array; every non-zero pixel counts as foreground.
        pred_threshold: pixels of ``pred`` strictly greater than this are foreground.
            ``None`` selects 0.5 for floating-point input (probability maps) and 0
            otherwise (integer or boolean masks).
        empty_value: score returned when both masks are empty (zero union);
            0.0 matches scikit-learn's default ``zero_division`` result.

    Returns:
        IoU as a float in [0, 1].

    Raises:
        ValueError: if ``pred`` and ``gt`` have different shapes after squeezing.
    """
    pred_arr = np.squeeze(np.asarray(pred))
    gt_arr = np.squeeze(np.asarray(gt))
    if pred_arr.shape != gt_arr.shape:
        raise ValueError(
            f"Shape mismatch: prediction {pred_arr.shape} vs ground truth {gt_arr.shape}"
        )

    if pred_threshold is None:
        pred_threshold = 0.5 if np.issubdtype(pred_arr.dtype, np.floating) else 0
    pred_fg = pred_arr > pred_threshold
    gt_fg = gt_arr != 0

    union = np.count_nonzero(pred_fg | gt_fg)
    if union == 0:
        return float(empty_value)
    intersection = np.count_nonzero(pred_fg & gt_fg)
    return intersection / union


In [None]:


# Main function for processing datasets
def process_datasets(base_path, model="plantseg", plantseg_model_name="lightsheet_2D_unet_root_ds1x",
                     sequences=("01", "02"), truth="ST", show_plots=True):
    """Segment every annotated frame of every dataset under ``base_path`` and score it.

    Expected layout (Cell Tracking Challenge style)::

        base_path/<dataset>/<sequence>/*.tif               raw frames
        base_path/<dataset>/<sequence>_<truth>/SEG/*.tif   reference masks

    "0N" in the CTC documentation is a placeholder for the real sequence folders
    "01", "02", ... The previous hard-coded "0N" therefore never matched any files.

    Args:
        base_path: directory whose sub-directories are datasets.
        model: "plantseg" or "cellpose" (case-insensitive).
        plantseg_model_name: model name passed to ``segment_with_plantseg``.
        sequences: sequence folder names to process in each dataset; missing ones
            are skipped with a message.
        truth: reference kind, i.e. the folder suffix ("ST" silver, "GT" gold).
        show_plots: draw an image / truth / prediction panel for each frame.

    Returns:
        List of dicts with keys "dataset", "sequence", "image" and "iou".

    Raises:
        ValueError: if ``model`` is not a supported name (checked before any work).
    """
    model_key = model.lower()
    # Validate once, before any I/O, instead of on the first image reached.
    if model_key not in ("plantseg", "cellpose"):
        raise ValueError("Invalid model name. Use 'plantseg' or 'cellpose'.")

    results = []
    # Only directories are datasets; sort for a reproducible processing order.
    dataset_paths = sorted(p for p in glob.glob(os.path.join(base_path, "*")) if os.path.isdir(p))

    for dataset_path in dataset_paths:
        dataset_name = os.path.basename(dataset_path)
        print(f"Processing dataset: {dataset_name}")

        for sequence in sequences:
            images_dir = os.path.join(dataset_path, sequence)
            gt_dir = os.path.join(dataset_path, f"{sequence}_{truth}", "SEG")
            if not (os.path.isdir(images_dir) and os.path.isdir(gt_dir)):
                print(f"  Skipping sequence {sequence}: missing {images_dir} or {gt_dir}")
                continue

            image_files = glob.glob(os.path.join(images_dir, "*.tif"))
            gt_files = glob.glob(os.path.join(gt_dir, "*.tif"))

            for image_file, gt_file in _pair_images_with_truth(image_files, gt_files):
                image_name = os.path.basename(image_file)
                print(f"Processing image: {image_name}")

                if model_key == "plantseg":
                    segmented_image = segment_with_plantseg(image_file, model_name=plantseg_model_name)
                else:
                    segmented_image = segment_with_cellpose(image_file)

                ground_truth = load_image(gt_file)
                iou = calculate_iou(segmented_image, ground_truth)
                results.append(
                    {"dataset": dataset_name, "sequence": sequence, "image": image_name, "iou": iou}
                )

                if show_plots:
                    _show_comparison(
                        image_file, ground_truth, segmented_image,
                        title=f"{dataset_name}/{sequence}/{image_name}  IoU={iou:.3f}",
                    )

    return results


def _frame_index(path):
    """Return the first integer in the file stem, or None if the stem has no digits.

    For CTC-style names this is the time point: ``t013.tif`` -> 13 and
    ``man_seg013.tif`` -> 13. Comparing ints tolerates different zero-padding
    between raw and reference names.
    """
    match = re.search(r"\d+", Path(path).stem)
    return int(match.group()) if match else None


def _pair_images_with_truth(image_files, gt_files):
    """Pair each reference mask with the raw frame that has the same frame index.

    Replaces ``zip`` over two unsorted globs, which paired frames arbitrarily.
    Raw frames without a mask are skipped. Masks without a matching frame are
    reported and skipped. NOTE(review): slice-wise 3D masks (several masks per
    frame) all pair with the same raw volume; their shapes will then mismatch.
    """
    images_by_frame = {}
    for image_file in sorted(image_files):
        frame = _frame_index(image_file)
        if frame is not None:
            images_by_frame[frame] = image_file

    pairs = []
    for gt_file in sorted(gt_files):
        frame = _frame_index(gt_file)
        image_file = images_by_frame.get(frame)
        if image_file is None:
            print(f"  Skipping {os.path.basename(gt_file)}: no raw image for frame {frame}")
            continue
        pairs.append((image_file, gt_file))
    return pairs


def _show_comparison(image_file, ground_truth, segmented_image, title):
    """Show raw image, reference mask and prediction side by side, then free the figure."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        ("Original Image", cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)),
        ("Ground Truth", ground_truth),
        # Squeeze singleton axes (e.g. a (1, H, W) PlantSeg output) so imshow accepts it.
        ("Segmented Image", np.squeeze(segmented_image)),
    )
    for ax, (panel_title, data) in zip(axes, panels):
        ax.set_title(panel_title)
        ax.imshow(data, cmap="gray")
    fig.suptitle(title)
    plt.show()
    # Close explicitly so a long loop does not accumulate open figures.
    plt.close(fig)


# Define paths and execute.
# The dataset root can come from the CTC_DATASETS environment variable, so this cell
# does not need editing; the placeholder is only a fallback.
base_path = os.environ.get("CTC_DATASETS", "/path/to/datasets")
RESULTS_CSV = "segmentation_results.csv"

# Fail loudly rather than silently producing an empty results file.
if not os.path.isdir(base_path):
    raise FileNotFoundError(
        f"Dataset root {base_path!r} does not exist; set CTC_DATASETS or edit base_path."
    )

results = process_datasets(base_path, model="plantseg", plantseg_model_name="lightsheet_2D_unet_root_ds1x")

# Save results (pandas is imported in the first cell).
results_df = pd.DataFrame(results)
results_df.to_csv(RESULTS_CSV, index=False)
print(f"Results saved to {RESULTS_CSV}")