In [1]:
import os
import glob
import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from sklearn.metrics import jaccard_score
import pandas as pd
import re
from tqdm import tqdm

# PlantSeg imports
import plantseg
from plantseg.core.zoo import ModelZoo
from plantseg.tasks import import_image_task, unet_prediction_task

# CellPose imports
from cellpose import models, io

import napari


from pathlib import Path
from plantseg.tasks import import_image_task, unet_prediction_task
from plantseg.core import PlantSegImage, ImageProperties, image
import cv2
from plantseg.core.zoo import ModelZoo

# Build a PlantSeg ModelZoo from plantseg.PATH_MODEL_ZOO and
# plantseg.PATH_MODEL_ZOO_CUSTOM and keep the model names it reports.
# NOTE(review): plantseg_model_names is not read again in the visible cells; the
# models that actually get evaluated come from the hand-written ps_models list.
mz = ModelZoo(plantseg.PATH_MODEL_ZOO, plantseg.PATH_MODEL_ZOO_CUSTOM)
plantseg_model_names = mz.get_model_names()

from cellpose import models, utils, io
import matplotlib.pyplot as plt
import cellpose.models

# Cellpose model names to evaluate. Each name becomes one result column of the
# dataset DataFrame, so every entry must be unique: a repeated name would create
# duplicate column labels and break per-cell reads/writes of the scores.
cp_models = [
    "cyto3",
    "nuclei",
    "cyto2_cp3",
    "tissuenet_cp3",
    "livecell_cp3",
    "yeast_PhC_cp3",
    "yeast_BF_cp3",
    "bact_phase_cp3",
    "bact_fluor_cp3",
    "deepbacs_cp3",
    "cyto2",
    "cyto",
    "CPx",
    "transformer_cp3",
    "neurips_cellpose_default",
    "neurips_cellpose_transformer",
    "neurips_grayscale_cyto2",
    "CP",
    "TN1",
    "TN2",
    "TN3",
    "LC1",
    "LC2",
    "LC3",
    "LC4",
]

# PlantSeg model names to evaluate; each one is also used as a result column of
# the dataset DataFrame.
# segment() decides between the 2D and 3D PlantSeg code paths from the "2D"/"3D"
# substring of the name, so an entry containing neither is rejected there as an
# unknown model.
ps_models = [
    "generic_confocal_3D_unet",
    "generic_light_sheet_3D_unet",
    "confocal_3D_unet_ovules_ds1x",
    "confocal_3D_unet_ovules_ds2x",
    "confocal_3D_unet_ovules_ds3x",
    "confocal_2D_unet_ovules_ds2x",
    "lightsheet_3D_unet_root_ds1x",
    "lightsheet_3D_unet_root_ds2x",
    "lightsheet_3D_unet_root_ds3x",
    "lightsheet_2D_unet_root_ds1x",
    "lightsheet_3D_unet_root_nuclei_ds1x",
    "lightsheet_2D_unet_root_nuclei_ds1x",
    "confocal_2D_unet_sa_meristem_cells",
    "confocal_3D_unet_sa_meristem_cells",
    "lightsheet_3D_unet_mouse_embryo_cells",
    "confocal_3D_unet_mouse_embryo_nuclei",
    "PlantSeg_3Dnuc_platinum",
]

import warnings

# Suppress PyTorch warnings
warnings.filterwarnings("ignore", message="You are using `torch.load`", category=FutureWarning)

INFO: P [MainThread] 2025-01-28 13:05:58,276 plantseg - Logger configured at initialisation. PlantSeg logger name: plantseg


In [2]:
# viewer = napari.Viewer()

In [31]:
# Define the function to construct the dataset DataFrame
def construct_dataset_dataframe(base_dir, model_names=None):
    """Index every ``.tif`` frame below ``base_dir`` together with its mask files.

    Directory layout read by this function::

        <base_dir>/<dataset>/<sequence>/<frame>.tif            raw frames
        <base_dir>/<dataset>/<sequence>_ERR_SEG/mask<id>.tif   -> "mask" column
        <base_dir>/<dataset>/<sequence>_GT/SEG/man_seg<id>.tif -> "gold_mask" column

    ``<id>`` is the frame file name with its first character removed
    (``t0001.tif`` -> ``0001.tif``). Sub-directories whose name contains ``_``
    are not treated as sequences, which is what keeps the ``*_ERR_SEG`` and
    ``*_GT`` folders out of the sequence loop.

    Parameters
    ----------
    base_dir : str
        Root folder containing one sub-folder per dataset.
    model_names : iterable of str, optional
        Names of the per-model result columns. Defaults to the module-level
        ``cp_models + ps_models``. Repeated names are dropped (first occurrence
        wins) because duplicate column labels make single-cell reads and writes
        of the scores ambiguous.

    Returns
    -------
    pandas.DataFrame
        One row per frame with columns ``dataset_name``, ``sequence_name``,
        ``image_path``, ``mask``, ``gold_mask`` followed by one column per model.
        ``mask`` / ``gold_mask`` are ``None`` when the file does not exist and
        every model column starts out as ``None``.
    """
    if model_names is None:
        model_names = cp_models + ps_models
    # dict.fromkeys removes repeats while preserving the original order.
    model_columns = list(dict.fromkeys(model_names))
    columns = ["dataset_name", "sequence_name", "image_path", "mask", "gold_mask"] + model_columns

    data = []

    # os.listdir returns entries in arbitrary order; sorting makes the row order
    # (and therefore any seeded sampling done on the result) reproducible.
    for dataset_name in sorted(os.listdir(base_dir)):
        dataset_path = os.path.join(base_dir, dataset_name)
        if not os.path.isdir(dataset_path):
            continue

        for sequence_name in sorted(os.listdir(dataset_path)):
            sequence_path = os.path.join(dataset_path, sequence_name)
            if not os.path.isdir(sequence_path) or "_" in sequence_name:
                continue

            # Folders holding the masks and gold masks of this sequence
            err_seg_path = os.path.join(dataset_path, f"{sequence_name}_ERR_SEG")
            gt_seg_path = os.path.join(dataset_path, f"{sequence_name}_GT", "SEG")

            for image_file in sorted(os.listdir(sequence_path)):
                if not image_file.endswith(".tif"):
                    continue
                image_path = os.path.join(sequence_path, image_file)

                # Drop the one-character frame prefix to get the shared id part.
                frame_id = image_file[1:]

                mask_path = os.path.join(err_seg_path, f"mask{frame_id}")
                if not os.path.exists(mask_path):
                    mask_path = None

                gold_mask_path = os.path.join(gt_seg_path, f"man_seg{frame_id}")
                if not os.path.exists(gold_mask_path):
                    gold_mask_path = None

                # Model columns are filled in later; start them empty.
                data.append([dataset_name, sequence_name, image_path, mask_path, gold_mask_path] + [None] * len(model_columns))

    return pd.DataFrame(data, columns=columns)


# Root folder scanned by construct_dataset_dataframe(); change it to wherever the datasets live
base_dir = "./datasets"

# Build the per-frame index (one row per .tif image found under base_dir)
dataset_df = construct_dataset_dataframe(base_dir)

# Preview the first rows as plain text
print(dataset_df.head())
# Write the full index to disk. This is unconditional: every run of the cell overwrites dataset_summary.csv
dataset_df.to_csv("dataset_summary.csv", index=False)

  dataset_name sequence_name                           image_path  \
0  BF-C2DL-HSC            01  ./datasets\BF-C2DL-HSC\01\t0000.tif   
1  BF-C2DL-HSC            01  ./datasets\BF-C2DL-HSC\01\t0001.tif   
2  BF-C2DL-HSC            01  ./datasets\BF-C2DL-HSC\01\t0002.tif   
3  BF-C2DL-HSC            01  ./datasets\BF-C2DL-HSC\01\t0003.tif   
4  BF-C2DL-HSC            01  ./datasets\BF-C2DL-HSC\01\t0004.tif   

                                             mask gold_mask cyto3 nuclei  \
0  ./datasets\BF-C2DL-HSC\01_ERR_SEG\mask0000.tif      None  None   None   
1  ./datasets\BF-C2DL-HSC\01_ERR_SEG\mask0001.tif      None  None   None   
2  ./datasets\BF-C2DL-HSC\01_ERR_SEG\mask0002.tif      None  None   None   
3  ./datasets\BF-C2DL-HSC\01_ERR_SEG\mask0003.tif      None  None   None   
4  ./datasets\BF-C2DL-HSC\01_ERR_SEG\mask0004.tif      None  None   None   

  cyto2_cp3 tissuenet_cp3 livecell_cp3  ... lightsheet_3D_unet_root_ds2x  \
0      None          None         None  ...         

In [4]:
dataset_df.head(3)

Unnamed: 0,dataset_name,sequence_name,image_path,mask,gold_mask,cyto3,nuclei,cyto2_cp3,tissuenet_cp3,livecell_cp3,...,lightsheet_3D_unet_root_ds2x,lightsheet_3D_unet_root_ds3x,lightsheet_2D_unet_root_ds1x,lightsheet_3D_unet_root_nuclei_ds1x,lightsheet_2D_unet_root_nuclei_ds1x,confocal_2D_unet_sa_meristem_cells,confocal_3D_unet_sa_meristem_cells,lightsheet_3D_unet_mouse_embryo_cells,confocal_3D_unet_mouse_embryo_nuclei,PlantSeg_3Dnuc_platinum
0,BF-C2DL-HSC,1,./datasets\BF-C2DL-HSC\01\t0000.tif,./datasets\BF-C2DL-HSC\01_ERR_SEG\mask0000.tif,,,,,,,...,,,,,,,,,,
1,BF-C2DL-HSC,1,./datasets\BF-C2DL-HSC\01\t0001.tif,./datasets\BF-C2DL-HSC\01_ERR_SEG\mask0001.tif,,,,,,,...,,,,,,,,,,
2,BF-C2DL-HSC,1,./datasets\BF-C2DL-HSC\01\t0002.tif,./datasets\BF-C2DL-HSC\01_ERR_SEG\mask0002.tif,,,,,,,...,,,,,,,,,,


In [30]:
# Helper function for loading images
def load_image(file_path):
    """Read ``file_path`` with OpenCV using the ``IMREAD_UNCHANGED`` flag.

    The result of ``cv2.imread`` is returned as-is.
    NOTE(review): OpenCV signals an unreadable file by returning ``None`` rather
    than raising, so callers should check the result before using it.
    """
    read_mode = cv2.IMREAD_UNCHANGED
    return cv2.imread(file_path, read_mode)


# # Function to perform segmentation using PlantSeg
# def segment_with_plantseg(image_path, model_name="lightsheet_2D_unet_root_ds1x", patch=(1, 64, 64), device="cuda"):
#     plantseg_image = import_image_task(input_path=Path(image_path), semantic_type="raw", stack_layout="YX")
#     predicted_images = unet_prediction_task(image=plantseg_image, model_name=model_name, patch=patch, device=device)
#     # Return the first predicted image
#     return predicted_images[0].get_data()


# # Function to perform segmentation using CellPose
# def segment_with_cellpose(image_path, model_type="cyto2", flow_threshold=0.4, cellprob_threshold=0):
#     image = io.imread(image_path)
#     model = models.Cellpose(gpu=True, model_type=model_type)
#     masks, flows, styles, diams = model.eval(image, diameter=None, channels=[0, 0], flow_threshold=flow_threshold, cellprob_threshold=cellprob_threshold)
#     return masks


# Function for calculating IoU with ground truth
def calculate_iou(pred, gt):
    """Binary Jaccard index between a predicted mask and its ground truth.

    Both arrays are flattened and passed to scikit-learn's ``jaccard_score``
    with ``average="binary"``; ``gt`` is the reference (``y_true``) and ``pred``
    the prediction (``y_pred``).
    NOTE(review): presumably both masks contain only two label values, since
    binary averaging is requested - confirm against the callers.
    """
    predicted_labels = pred.flatten()
    reference_labels = gt.flatten()
    return jaccard_score(y_true=reference_labels, y_pred=predicted_labels, average="binary")


def get_images_for_masks(gt_gold_paths, image_paths):
    """Pick the image that corresponds to each gold-mask file.

    For every path in ``gt_gold_paths`` the first run of digits in its file name
    (directories are ignored) is parsed as an integer and used as a position in
    ``image_paths``. The selected images are returned in the order of
    ``gt_gold_paths``.

    Raises ``IndexError`` when a file name contains no digits or the parsed
    number is outside ``image_paths``.
    NOTE(review): this assumes ``image_paths`` is ordered so that position ``i``
    holds frame ``i`` - confirm against the callers.
    """

    def _frame_index(mask_path):
        # Only the file name is inspected, so digits in folder names are ignored.
        file_name = os.path.basename(mask_path)
        return int(re.findall(r"\d+", file_name)[0])

    return [image_paths[_frame_index(mask_path)] for mask_path in gt_gold_paths]


# def segment(image_path, model_name):
#     if model_name in cp_models:
#         return segment_with_cellpose(image_path, model_type=model_name)
#     elif model_name in ps_models:
#         return segment_with_plantseg(image_path, model_name=model_name)
#     else:
#         raise ValueError(f"Unknown model name: {model_name}")


# def seg_with_CellPose(image_path, model_type="cyto2", diameter=None):
#     image = io.imread(image_path)
#     model = models.CellposeModel(gpu=True, model_type=model_type)
#     masks, flows, styles = model.eval(image, diameter=diameter, channels=[0, 0], flow_threshold=0.4, cellprob_threshold=0)
#     return masks


# def seg_with_PlantSeg(image_path, model_name="lightsheet_2D_unet_root_ds1x", patch=(1, 64, 64), device="cuda"):
#     plantseg_image = import_image_task(input_path=Path(image_path), semantic_type="raw", stack_layout="YX")
#     predicted_images = unet_prediction_task(image=plantseg_image, model_name=model_name, patch=patch, device=device)
#     return predicted_images[0].get_data()


def ins_to_sem(mask):
    """Collapse an instance-label mask into a binary semantic mask.

    Parameters
    ----------
    mask : array-like
        Label image in which 0 is background and every positive value marks a
        pixel that belongs to some object. Any number of dimensions is accepted.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array with the same shape as ``mask``: 255 where ``mask > 0``
        and 0 elsewhere.
    """
    # Threshold BEFORE narrowing to uint8. Casting the labels first wraps ids
    # modulo 256 (label 256 becomes background) and multiplying the cast labels
    # by 255 overflows for every id above 1, so the result was not a clean
    # foreground mask.
    foreground = np.asarray(mask) > 0
    # A gray -> BGR -> gray colour round trip leaves a single-channel image
    # unchanged, so it is omitted; this also lets 3D stacks pass through.
    return foreground.astype(np.uint8) * 255


# def calc_jaccard_score(pred, gt):
#     pred_flat = pred.flatten()
#     gt_flat = gt.flatten()
#     return jaccard_score(gt_flat, pred_flat, average="micro")


# def calc_jaccard_score_per_object(pred, gt):
#     # Unique labels for reference objects (ground truth) and predicted objects
#     gt_labels = np.unique(gt)
#     pred_labels = np.unique(pred)

#     # Remove background label (assumed to be 0)
#     gt_labels = gt_labels[gt_labels > 0]
#     pred_labels = pred_labels[pred_labels > 0]

#     jaccard_scores = []

#     for gt_label in gt_labels:
#         # Extract pixels for the current reference object
#         gt_object = gt == gt_label

#         # Find the matching predicted object with maximum overlap
#         best_iou = 0
#         for pred_label in pred_labels:
#             pred_object = pred == pred_label
#             intersection = np.logical_and(gt_object, pred_object).sum()
#             union = np.logical_or(gt_object, pred_object).sum()

#             if union > 0:
#                 iou = intersection / union
#                 best_iou = max(best_iou, iou)

#         # Append the best IoU for this ground truth object
#         jaccard_scores.append(best_iou)

#     # Mean IoU across all ground truth objects
#     return np.mean(jaccard_scores) if jaccard_scores else 0

In [6]:
# for index, row in dataset_df.iterrows():

#     for model in cp_models + ps_models:
#         image_path = row["image_path"][0]  # Assuming you want to use the first image in the list
#         gt_path = row["gt_path"][0]  # Assuming you want to use the first ground truth in the list
#         gt_gold_path = row["gt_gold_path"][0]  # Assuming you want to use the first ground truth gold in the list

#         # Perform segmentation
#         segmented_image = segment(image_path, model_name=model)

#         # Convert instance segmentation to semantic segmentation
#         semantic_mask = ins_to_sem(segmented_image)

#         # Load ground truth
#         gt_image = load_image(gt_gold_path)
#         gt_semantic = ins_to_sem(gt_image)

#         # Calculate IoU score
#         iou_score = calc_jaccard_score_per_object(semantic_mask, gt_semantic)
#         print(f"Model: {model}, IoU Score: {iou_score}")

In [28]:
# Function to perform segmentation
def segment(image_path, model_name, dimension="2D"):
    """Run one pretrained model on one image file and return its output array.

    Parameters
    ----------
    image_path : str
        File to segment.
    model_name : str
        An entry of ``cp_models`` (Cellpose) or ``ps_models`` (PlantSeg).
        Cellpose names are checked first; a PlantSeg name is routed by the
        ``"2D"`` / ``"3D"`` substring it contains, with ``"2D"`` checked first.
    dimension : str
        Accepted for interface compatibility; the function never reads it.

    Returns
    -------
    The masks returned by ``CellposeModel.eval`` for Cellpose models, or
    ``get_data()`` of the first image returned by ``unet_prediction_task`` for
    PlantSeg models.

    Raises
    ------
    ValueError
        If ``model_name`` is in neither list, or is a PlantSeg name containing
        neither ``"2D"`` nor ``"3D"``.

    Notes
    -----
    Both back ends are hard-wired to the GPU (``gpu=True`` / ``device="cuda"``),
    and a new Cellpose model object is constructed on every call.
    """
    if model_name in cp_models:
        cellpose_model = models.CellposeModel(gpu=True, model_type=model_name)
        pixels = io.imread(image_path)
        label_mask, _flows, _styles = cellpose_model.eval(pixels, diameter=None, channels=[0, 0], flow_threshold=0.4, cellprob_threshold=0)
        return label_mask

    # PlantSeg: the only differences between the 2D and 3D paths are the stack
    # layout used to import the file and the patch shape used for prediction.
    plantseg_settings = None
    if model_name in ps_models:
        if "2D" in model_name:
            plantseg_settings = ("YX", (1, 64, 64))
        elif "3D" in model_name:
            plantseg_settings = ("ZYX", (4, 64, 64))

    if plantseg_settings is None:
        raise ValueError(f"Unknown model name: {model_name}")

    layout, patch_shape = plantseg_settings
    raw_image = import_image_task(input_path=Path(image_path), semantic_type="raw", stack_layout=layout)
    predictions = unet_prediction_task(image=raw_image, model_name=model_name, model_id=None, patch=patch_shape, device="cuda")
    return predictions[0].get_data()


# # Function to calculate Jaccard score
# def calculate_jaccard_score(predicted_mask, ground_truth_mask):
#     # Ensure both masks are the same data type
#     pred_flat = predicted_mask.flatten().astype(np.int32)
#     gt_flat = ground_truth_mask.flatten().astype(np.int32)
#     return jaccard_score(gt_flat, pred_flat, average="micro")


def calculate_jaccard_score(predicted_mask, ground_truth_mask):
    """
    Micro-averaged Jaccard score between two masks of identical shape.

    Parameters:
    - predicted_mask: np.array with 2 or 3 dimensions.
    - ground_truth_mask: np.array with the same shape as ``predicted_mask``.

    Returns:
    - For 2D input, the score of the two flattened masks.
    - For 3D input, the mean of the scores computed separately for every index
      along the first axis.

    Raises:
    - AssertionError if the shapes differ.
    - ValueError if the masks are neither 2D nor 3D.

    NOTE(review): values are cast to int32 and compared as class labels. A
    floating-point prediction is truncated by that cast and instance ids only
    count as a match when the numeric ids are equal - confirm this is the
    intended metric.
    """
    assert predicted_mask.shape == ground_truth_mask.shape, "Masks must have the same shape."

    def _micro_jaccard(pred_plane, gt_plane):
        # scikit-learn takes 1-D label vectors, reference labels first.
        pred_labels = pred_plane.flatten().astype(np.int32)
        gt_labels = gt_plane.flatten().astype(np.int32)
        return jaccard_score(gt_labels, pred_labels, average="micro")

    n_dims = predicted_mask.ndim
    if n_dims not in (2, 3):
        raise ValueError("Input masks must be 2D or 3D arrays.")

    if n_dims == 2:
        return _micro_jaccard(predicted_mask, ground_truth_mask)

    # 3D: score each slice along axis 0 independently, then average.
    slice_scores = [
        _micro_jaccard(predicted_mask[z], ground_truth_mask[z])
        for z in range(predicted_mask.shape[0])
    ]
    return np.mean(slice_scores)


# Main processing loop
def process_dataset(df, models):
    """Score every requested model on every row of ``df``, writing results in place.

    For each row the reference mask is taken from the ``mask`` column, falling
    back to ``gold_mask`` when ``mask`` is missing. Each model in ``models`` is
    run on ``image_path`` via ``segment`` and scored against the reference via
    ``calculate_jaccard_score``. The score is stored in ``df`` at the row's
    index under the column named after the model.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``image_path``, ``mask`` and ``gold_mask`` columns plus one
        column per model name. Modified in place.
    models : iterable of str
        Model names to evaluate. (Inside this function the name shadows the
        imported ``cellpose.models`` module; ``segment`` is unaffected because it
        resolves ``models`` in module scope.)

    Returns
    -------
    None

    Notes
    -----
    Processing is best-effort. Rows without a usable image or mask path are
    skipped with a message, and any failure is printed together with the
    exception before moving on to the next model or row. Cells that already
    hold a score are left untouched, so an interrupted run can be resumed.
    """
    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Processing dataset", position=0):
        # Bound before the try block so the row-level handler can always report
        # them, even when the failure happens while they are being looked up.
        image_path = mask_path = None
        try:
            image_path = row["image_path"]
            mask_path = row["mask"] if pd.notna(row["mask"]) else row["gold_mask"]

            # Missing values may be None (fresh index) or NaN (index read from CSV).
            if pd.isna(image_path) or pd.isna(mask_path):
                print(f"Skipping row {idx}: missing image or mask path")
                continue
            if not os.path.exists(mask_path):
                print(f"Skipping row {idx}: mask file not found: {mask_path}")
                continue

            ground_truth_mask = _read_mask(mask_path)

            for model_name in tqdm(models, desc=f"Row {idx}", leave=False, position=1):
                try:
                    # Checked inside the per-model try: an ambiguous lookup
                    # (e.g. a repeated column label) must only skip this model,
                    # not abort the remaining models of the row.
                    if pd.notna(row.get(model_name)):
                        continue  # Score already present
                    predicted_mask = segment(image_path, model_name)
                    iou_score = calculate_jaccard_score(predicted_mask, ground_truth_mask)
                    df.at[idx, model_name] = iou_score
                except Exception as e:
                    print(f"Error processing model {model_name} for image {image_path}: {e}")
        except Exception as e:
            print(f"Error processing row {idx} (image: {image_path}, mask: {mask_path}): {e}")


def _read_mask(mask_path):
    """Load a mask file, keeping every plane of a multi-page TIFF stack.

    ``cv2.imread`` decodes only the first page of a multi-page TIFF, so a 3D
    reference mask read through it comes back 2D and can never match the shape
    of a 3D prediction. TIFF files are therefore read with ``tifffile`` (the
    reader this notebook already uses for 3D masks); OpenCV is used for other
    formats and as a fallback.

    Raises ``ValueError`` when the file cannot be decoded.
    """
    if str(mask_path).lower().endswith((".tif", ".tiff")):
        try:
            import tifffile  # local import: the notebook only imports it in a later cell

            return tifffile.imread(mask_path)
        except Exception as e:
            print(f"tifffile could not read {mask_path} ({e}); falling back to OpenCV")

    mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
    if mask is None:
        # OpenCV reports decode failures by returning None instead of raising.
        raise ValueError(f"Could not read mask file: {mask_path}")
    return mask

In [8]:
# dataset_df = pd.read_csv("dataset_summary.csv")

In [9]:
sampled_df= pd.read_csv("sample_summary.csv")

In [10]:
filter_df = dataset_df.dropna(subset=["mask", "gold_mask"], how="all")
# sampled_df = filter_df.groupby(["dataset_name", "sequence_name"], group_keys=False).apply(lambda x: x.sample(n=min(5, len(x)), random_state=42)).reset_index(drop=True)

In [11]:
the_models = [f for f in ps_models if "2D" in f]
the_models

['confocal_2D_unet_ovules_ds2x',
 'lightsheet_2D_unet_root_ds1x',
 'lightsheet_2D_unet_root_nuclei_ds1x',
 'confocal_2D_unet_sa_meristem_cells']

In [12]:
# process_dataset(sampled_df, the_models)

In [13]:
sampled_df

Unnamed: 0.1,Unnamed: 0,dataset_name,sequence_name,image_path,mask,gold_mask,cyto3,nuclei,cyto2_cp3,tissuenet_cp3,...,lightsheet_3D_unet_root_ds2x,lightsheet_3D_unet_root_ds3x,lightsheet_2D_unet_root_ds1x,lightsheet_3D_unet_root_nuclei_ds1x,lightsheet_2D_unet_root_nuclei_ds1x,confocal_2D_unet_sa_meristem_cells,confocal_3D_unet_sa_meristem_cells,lightsheet_3D_unet_mouse_embryo_cells,confocal_3D_unet_mouse_embryo_nuclei,PlantSeg_3Dnuc_platinum
0,0,BF-C2DL-HSC,1,./datasets/BF-C2DL-HSC/01/t0930.tif,./datasets/BF-C2DL-HSC/01_ERR_SEG/mask0930.tif,./datasets/BF-C2DL-HSC/01_GT/SEG/man_seg0930.tif,0.985805,0.994375,0.912181,0.997136,...,,,0.997134,,0.997134,0.997134,,,,
1,1,BF-C2DL-HSC,1,./datasets/BF-C2DL-HSC/01/t1378.tif,./datasets/BF-C2DL-HSC/01_ERR_SEG/mask1378.tif,,0.985861,0.992334,0.925176,0.994843,...,,,0.994949,,0.994949,0.994949,,,,
2,2,BF-C2DL-HSC,1,./datasets/BF-C2DL-HSC/01/t0589.tif,./datasets/BF-C2DL-HSC/01_ERR_SEG/mask0589.tif,,0.987407,0.995452,0.931505,0.998057,...,,,0.998333,,0.998333,0.998333,,,,
3,3,BF-C2DL-HSC,1,./datasets/BF-C2DL-HSC/01/t0097.tif,./datasets/BF-C2DL-HSC/01_ERR_SEG/mask0097.tif,,0.987503,0.996772,0.944727,0.998656,...,,,0.999383,,0.999383,0.999383,,,,
4,4,BF-C2DL-HSC,1,./datasets/BF-C2DL-HSC/01/t0367.tif,./datasets/BF-C2DL-HSC/01_ERR_SEG/mask0367.tif,,0.985604,0.994917,0.945221,0.998284,...,,,0.998441,,0.998441,0.998441,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,95,PhC-C2DL-PSC,2,./datasets/PhC-C2DL-PSC/02/t125.tif,./datasets/PhC-C2DL-PSC/02_ERR_SEG/mask125.tif,,0.874047,0.902686,0.871489,0.910589,...,,,0.910584,,0.910589,0.910589,,,,
96,96,PhC-C2DL-PSC,2,./datasets/PhC-C2DL-PSC/02/t109.tif,./datasets/PhC-C2DL-PSC/02_ERR_SEG/mask109.tif,,0.879738,0.910641,0.889712,0.920151,...,,,0.920151,,0.920151,0.920147,,,,
97,97,PhC-C2DL-PSC,2,./datasets/PhC-C2DL-PSC/02/t148.tif,./datasets/PhC-C2DL-PSC/02_ERR_SEG/mask148.tif,,0.845328,0.892968,0.860594,0.896322,...,,,0.896322,,0.896322,0.896322,,,,
98,98,PhC-C2DL-PSC,2,./datasets/PhC-C2DL-PSC/02/t134.tif,./datasets/PhC-C2DL-PSC/02_ERR_SEG/mask134.tif,,0.861710,0.895638,0.871214,0.902861,...,,,0.905182,,0.905182,0.905182,,,,


In [90]:
# sampled_df.to_csv("sample3D_summary.csv", index=True)

In [8]:
# dataset_df.to_csv("dataset_summary.csv", index=False)

In [None]:
image_path = sampled_df["image_path"][0]
image_path

In [None]:
plantseg_image = import_image_task(input_path=Path(image_path), semantic_type="raw", stack_layout="ZYX")

In [None]:
predicted_images = unet_prediction_task(image=plantseg_image, model_name=the_models[0], model_id=None, patch=(1, 64, 64), device="cuda")
# Return the first predicted image
plt.imshow(predicted_images[0].get_data())

In [None]:
# Result columns to aggregate: every Cellpose and PlantSeg model name.
# NOTE(review): a name that appears more than once in cp_models/ps_models is
# selected more than once here.
model_columns = cp_models + ps_models

# Mean score per (dataset, sequence) for each model column
average_iou = (
    sampled_df.groupby(["dataset_name", "sequence_name"])[model_columns]
    .mean()
    .reset_index()
)

# Mean of the per-sequence means for each model: every sequence carries the
# same weight, however many frames it contributed.
overall_average_iou = average_iou[model_columns].mean().to_dict()

# Show both results as the cell output
average_iou, overall_average_iou

In [None]:
overall_average_iou

# 3D

In [14]:
filter_df = dataset_df.dropna(subset=["mask", "gold_mask"], how="all")

In [15]:
filtered_3d_df = filter_df[filter_df['dataset_name'].str.contains("3D")]
filtered_3d_df.head(1)

Unnamed: 0,dataset_name,sequence_name,image_path,mask,gold_mask,cyto3,nuclei,cyto2_cp3,tissuenet_cp3,livecell_cp3,...,lightsheet_3D_unet_root_ds2x,lightsheet_3D_unet_root_ds3x,lightsheet_2D_unet_root_ds1x,lightsheet_3D_unet_root_nuclei_ds1x,lightsheet_2D_unet_root_nuclei_ds1x,confocal_2D_unet_sa_meristem_cells,confocal_3D_unet_sa_meristem_cells,lightsheet_3D_unet_mouse_embryo_cells,confocal_3D_unet_mouse_embryo_nuclei,PlantSeg_3Dnuc_platinum
6604,Fluo-C3DH-A549,1,./datasets\Fluo-C3DH-A549\01\t000.tif,./datasets\Fluo-C3DH-A549\01_ERR_SEG\mask000.tif,./datasets\Fluo-C3DH-A549\01_GT\SEG\man_seg000...,,,,,,...,,,,,,,,,,


In [16]:

# Draw up to 5 rows per (dataset, sequence) with a fixed seed; min() caps the
# request at the group size so groups with fewer than 5 rows are kept whole.
# NOTE(review): the saved output of this cell looks like the tail of a warning
# that echoes this line - check which warning it is and whether the
# groupby(...).apply(...) call needs adjusting.
sampled_3D_df = filtered_3d_df.groupby(["dataset_name", "sequence_name"], group_keys=False).apply(lambda x: x.sample(n=min(5, len(x)), random_state=42)).reset_index(drop=True)

  sampled_3D_df = filtered_3d_df.groupby(["dataset_name", "sequence_name"], group_keys=False).apply(lambda x: x.sample(n=min(5, len(x)), random_state=42)).reset_index(drop=True)


In [17]:
# Check if the values under dataset_name in sampled_3D_df and sampled_df are the same
same_values = sampled_3D_df['dataset_name'].equals(sampled_df['dataset_name'])
print(same_values)

False


In [33]:
# Smoke test: run the first PlantSeg model in ps_models on the first sampled 3D
# image, using the same layout ("ZYX") and patch shape (4, 64, 64) that segment()
# uses for 3D PlantSeg models.
# NOTE(review): the recorded run of this cell ends in KeyboardInterrupt, so the
# predicted_images used by later cells may come from an earlier execution.
image_path = sampled_3D_df["image_path"][0]
image_path  # bare expression in the middle of a cell: evaluated but not displayed
plantseg_image = import_image_task(input_path=Path(image_path), semantic_type="raw", stack_layout="ZYX")
predicted_images = unet_prediction_task(image=plantseg_image, model_name=ps_models[0], model_id=None, patch=(4, 64, 64), device="cuda")
# Return the first predicted image
# plt.imshow(predicted_images[0].get_data())

INFO: P [MainThread] 2025-01-28 13:27:22,966 plantseg.functionals.prediction.prediction - Zoo prediction: Running model from PlantSeg official zoo.
INFO: P [MainThread] 2025-01-28 13:27:23,027 plantseg.functionals.prediction.prediction - Computing theoretical minimum halo from model.
INFO: P [MainThread] 2025-01-28 13:27:23,028 plantseg.functionals.prediction.prediction - For raw in shape (29, 300, 350): set patch shape (4, 64, 64), set halo shape (44, 44, 44)




INFO: P [MainThread] 2025-01-28 13:27:23,496 plantseg.functionals.prediction.utils.array_predictor - Using batch size of 1 for prediction


 82%|████████▏ | 344/420 [01:09<00:15,  4.97it/s]


KeyboardInterrupt: 

In [21]:
pred_mask = predicted_images[0].get_data()


In [19]:
viewer = napari.Viewer()


In [20]:
viewer.add_image(predicted_images[0].get_data())

<Image layer 'Image' at 0x1c317090650>

In [22]:
mask_path = sampled_3D_df["mask"][0]
import tifffile
mask = tifffile.imread(mask_path)
# viewer.add_image(mask)


(dtype('float32'), dtype('uint16'))

In [None]:
# import numpy as np


# def calculate_jaccard_index(predicted_mask, ground_truth_mask):
#     """
#     Calculate the Jaccard Index (IoU) for 3D masks.

#     Parameters:
#     - predicted_mask: np.array, predicted mask of shape (Z, X, Y). Can be binary or multi-class.
#     - ground_truth_mask: np.array, ground truth mask of shape (Z, X, Y). Can be binary or multi-class.

#     Returns:
#     - jaccard_index: float or dict, IoU score(s) for binary or per class if multi-class.
#     """
#     # Ensure masks are the same shape
#     assert predicted_mask.shape == ground_truth_mask.shape, "Masks must have the same shape."

#     # Convert masks to a consistent binary format (np.bool_)
#     predicted_mask = np.asarray(predicted_mask, dtype=bool)
#     ground_truth_mask = np.asarray(ground_truth_mask, dtype=bool)

#     if np.array_equal(np.unique(predicted_mask), [False, True]) and np.array_equal(np.unique(ground_truth_mask), [False, True]):
#         # Binary case
#         intersection = np.logical_and(predicted_mask, ground_truth_mask)
#         union = np.logical_or(predicted_mask, ground_truth_mask)
#         return np.sum(intersection) / np.sum(union) if np.sum(union) > 0 else 1.0  # Handle zero union case

#     else:
#         # Multi-class case
#         classes = np.unique(ground_truth_mask)
#         iou_scores = {}
#         for cls in classes:
#             pred_binary = predicted_mask == cls
#             gt_binary = ground_truth_mask == cls
#             intersection = np.logical_and(pred_binary, gt_binary)
#             union = np.logical_or(pred_binary, gt_binary)
#             iou_scores[cls] = np.sum(intersection) / np.sum(union) if np.sum(union) > 0 else 1.0
#         return iou_scores



In [32]:

# Example usage

jaccard_index = calculate_jaccard_score(pred_mask, mask)
print("Jaccard Index:", jaccard_index)

Jaccard Index: 0.9537619265721315


In [None]:
sampled_3D_df.head(1)

In [28]:
ps_3d_models = [ps for ps in ps_models if "3D" in ps]

In [None]:
process_dataset(sampled_3D_df, [ps_3d_models[0]])