# Leaf segmentation inference

## 1. Installing and importing dependencies

In [None]:
!pip install loguru
!pip install gdown
!pip install ftfy
!pip install ultralytics scikit-learn opencv-python
!pip install filterpy
!python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'

In [None]:
import os
import sys
import cv2
import numpy as np
from tqdm import tqdm
import imageio.v3 as iio
from typing import List
import torch, detectron2
from datetime import datetime
from detectron2.data.datasets import register_coco_instances
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.utils.visualizer import Visualizer
from detectron2.utils.visualizer import ColorMode
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.engine import DefaultTrainer
import matplotlib.pyplot as plt
import pandas as pd
from typing import Dict, Any

### Inference parameters

In [None]:
# Inference parameters.
PARAMS = {
    "device": "cuda:1",  # which device to use for inference
    "out_dir": "../results_stage4/",  # directory to save results
    "save_leafs": True,  # save each separate leaf image; disable to make processing slightly faster
    "method": "final",  # which method to use. Options: "yolo", "final" (yolo+sam)
}

# Data parameters.
# Two alternative configurations follow. Both assign DATA_PARAMS, so when this
# cell runs as-is the second one (option 2) overrides option 1 -- comment out
# the option you do not want.

# Option 1: no ds.csv (it defaults to None later); the dataset is presumably
# built from images_root/masks_root by load_or_build_dataset.
DATA_PARAMS = {
    "images_root": "",  # path to images
    "masks_root": None,  # path to labels. Can be None, but without labels no metrics can be calculated
}

# Option 2: use a prebuilt dataset index and select sequences by role.
DATA_PARAMS = {
    "images_root": "",
    "masks_root": "",
    "ds.csv": "../data_meta/ds.csv",  # dataset index passed to load_or_build_dataset
    "nn_roles": ["test"],  # nn_role values whose sequences are processed
    "share": 1.0,  # what share of sequences to process, use 1. to process all sequences
}

## 2. Utilities: imports and helper functions

### imports

In [None]:
# Run from the notebook directory so the relative paths used below resolve,
# and make the project's ../src modules importable.
# NOTE(review): machine-specific absolute path; adjust it for your checkout.
%cd /home/rsaric/Desktop/leaf_cv/notebooks
sys.path.insert(0, "../src")

In [None]:
import os
import random
import traceback

import imageio
import matplotlib.pyplot as plt
import pandas as pd
import tqdm as orig_tqdm
from tqdm.auto import tqdm

from dataset import load_or_build_dataset
from metrics import (
    AbstractMetric,
    FrameBasedIOU,
    MultiObjectTrackingAccuracy,
    MultiObjectTrackingPrecision,
)
from saveload import EmptySaver, Saver, read_image, read_masks
from masks import draw_joined_masks_on_image, mask_joined_to_masks_dict
from model.tracking import ensure_same_image_sizes, change_mask_resolution


class NoTqdm(orig_tqdm.tqdm):
    """A tqdm progress bar that is always constructed with ``disable=True``."""

    def __init__(self, *args, **kwargs):
        # Force the bar off regardless of what the caller asked for.
        kwargs.update(disable=True)
        super().__init__(*args, **kwargs)


# Swap the silent bar into the `tqdm` module before the model modules below are
# imported -- presumably so progress bars created inside them stay quiet.
orig_tqdm.tqdm = NoTqdm
from model.yolo_models import (  # noqa E402,E501 # pylint: disable=wrong-import-position
    YoloTrackerModel,
    AbstractModel,
)
from model.final_model import (  # noqa E402,E501 # pylint: disable=wrong-import-position
    VideoSAMFinal,
)

In [None]:
# Smoke test for the detectron2 install: run the model-zoo COCO Mask R-CNN
# (R50-FPN) on a sample COCO image and print its predicted classes and boxes.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # test-time score threshold for keeping detections
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
!wget http://images.cocodataset.org/val2017/000000439715.jpg -q -O input.jpg
# NOTE: cv2.imread returns None (not an error) if the download above failed.
image = cv2.imread("./input.jpg")
predictor = DefaultPredictor(cfg)
outputs = predictor(image)
print(outputs["instances"].pred_classes)
print(outputs["instances"].pred_boxes)
print(outputs["instances"].pred_boxes)

In [None]:
def subselect_sequences(
    ds: pd.DataFrame,
    nn_roles: list[str],
    share: float = 1.0,
    seed: int = 1,
):
    """Keep rows of the requested roles, optionally subsampling whole sequences.

    A sequence is identified by its ("plant", "rep") pair; sampling keeps or
    drops entire sequences, never individual frames.

    Args:
        ds (pd.DataFrame): dataset with "nn_role", "plant" and "rep" columns
        nn_roles (list[str]): roles to keep; every one must occur in ``ds``
        share (float): fraction of the role-filtered sequences to keep
        seed (int): ``random_state`` used for the sequence sampling

    Returns:
        pd.DataFrame: matching rows in their original order, original index kept.

    Raises:
        AssertionError: if a requested role does not occur in ``ds``.
    """
    not_in_roles = set(nn_roles) - set(ds["nn_role"].unique())
    assert not not_in_roles, f"Roles {not_in_roles} are not in the dataset"

    sequence_key = ["plant", "rep"]
    role_rows = ds.loc[ds["nn_role"].isin(nn_roles)].copy()
    chosen = (
        role_rows[sequence_key]
        .drop_duplicates()
        .sample(frac=share, random_state=seed)
    )
    if share < 1:
        # Log which sequences were drawn.
        print(chosen.sort_values(by=["plant"]))

    chosen_keys = chosen.set_index(sequence_key).index
    keep = role_rows.set_index(sequence_key).index.isin(chosen_keys)
    return role_rows[keep]

def _model_inference_subdataset(
    ds: pd.DataFrame,
    model: AbstractModel,
    metrics: list[AbstractMetric],
    saver: Saver,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Run ``model`` over every (plant, rep) sequence of a single-role dataset.

    Frames of each sequence are read in ``image_num`` order, predicted
    together, saved through ``saver`` and, when ``metrics`` is non-empty,
    scored against the label masks. A failure in one sequence is printed with
    its traceback and that sequence is skipped; ``KeyboardInterrupt`` is
    re-raised.

    Args:
        ds (pd.DataFrame): rows of exactly one ``nn_role`` with "plant", "rep"
            and "image_num" columns.
        model (AbstractModel): provides ``predict_masks(images)``.
        metrics (list[AbstractMetric]): metrics to accumulate; they are reset
            before and after the run.
        saver (Saver): receives predicted masks and per-sequence finalization.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: per-sequence and aggregate metrics,
        each the column-wise concatenation of all metrics' results (empty
        DataFrames when ``metrics`` is empty).

    Raises:
        AssertionError: if ``ds`` contains more than one ``nn_role``.
    """
    for metric in metrics:
        metric.reset()

    nn_roles = ds["nn_role"].unique()
    assert len(nn_roles) == 1, f"Multiple nn_roles in the dataset: {nn_roles}"
    description = f"Inference for {nn_roles[0]} sequences ({len(ds)} frames)"

    for group, rep in tqdm(ds.groupby(["plant", "rep"]), desc=description):
        # Build the name with an f-string: plant/rep need not be strings, and
        # "/".join(group) raises TypeError for e.g. integer rep ids.
        sequence_name = f"{group[0]}/{group[1]}"
        try:
            rep = rep.sort_values("image_num")
            rows = [row for _, row in rep.iterrows()]
            images = []
            masks_labels = []
            for row in rows:
                images.append(read_image(row))
                if metrics:
                    # Label masks are only needed to compute metrics.
                    masks_labels.append(read_masks(row))

            # Unify frame sizes, then bring label masks to the same resolution.
            images = ensure_same_image_sizes(images, sequence_name)
            for img, masks in zip(images, masks_labels):
                for v in masks.values():
                    if v["segmentation"].shape[:2] != img.shape[:2]:
                        v["segmentation"] = change_mask_resolution(
                            v["segmentation"], img.shape[:2]
                        )

            masks_predicted = model.predict_masks(images)
            for i, row in enumerate(rows):
                saver.save_masks(images[i], row, masks_predicted[i])

            for metric in metrics:
                # TODO: make it frame-based to avoid storing a lot in memory
                metric.add_sequence(masks_labels, masks_predicted, name=sequence_name)

            # Finalize with the sequence's last row (highest image_num).
            saver.finalize_sequence(rows[-1])

        except KeyboardInterrupt:
            print("Inference interrupted")
            raise

        except Exception:
            # Best-effort: report and continue with the next sequence.
            traceback.print_exc()
            print(f"Error during inference for sequence {group}")

    seq_metrics = []
    total_metrics = []
    for metric in metrics:
        seq_metrics.append(metric.get_aggregate_metrics(per_seq=True))
        total_metrics.append(metric.get_aggregate_metrics(per_seq=False))
        metric.reset()

    if seq_metrics:
        df_seq_metrics = pd.concat(seq_metrics, axis=1)
        df_total_metrics = pd.concat(total_metrics, axis=1)
    else:
        df_seq_metrics = pd.DataFrame()
        df_total_metrics = pd.DataFrame()
    return df_seq_metrics, df_total_metrics

def model_inference(
    ds: pd.DataFrame,
    model: AbstractModel,
    metrics: list[AbstractMetric],
    saver: Saver,
):
    """Run inference over the train/val/test splits of ``ds`` and save metrics.

    Rows whose ``nn_role`` is not "train", "val" or "test" are ignored. The
    model config is saved first; per-sequence and aggregate metrics of all
    processed roles are concatenated (with an "nn_role" column) and saved.

    Args:
        ds (pd.DataFrame): dataset with an "nn_role" column plus the columns
            read by ``_model_inference_subdataset``.
        model (AbstractModel): model to run; its ``config`` is saved.
        metrics (list[AbstractMetric]): metrics to calculate (may be empty).
        saver (Saver): persists configs, predictions and metrics.

    Returns:
        pd.DataFrame: aggregate metrics for every processed role, or an empty
        DataFrame (nothing saved) when ``ds`` has no train/val/test rows.
    """
    saver.save_configs(model.config)
    seq_results = []
    total_results = []
    for nn_role in ["train", "val", "test"]:
        nn_role_ds = ds[ds["nn_role"] == nn_role]
        if len(nn_role_ds) == 0:
            continue

        seq_results_df, total_results_df = _model_inference_subdataset(
            nn_role_ds, model, metrics, saver
        )
        seq_results_df["nn_role"] = nn_role
        seq_results.append(seq_results_df)
        total_results_df["nn_role"] = nn_role
        total_results.append(total_results_df)

    if not seq_results:
        # pd.concat raises on an empty list; there is nothing to report.
        print("No rows with nn_role in ['train', 'val', 'test']; nothing to evaluate.")
        return pd.DataFrame()

    seq_dfs_joined = pd.concat(seq_results, axis=0)
    total_dfs_joined = pd.concat(total_results, axis=0)
    saver.save_metrics(seq_dfs_joined, per_seq=True)
    saver.save_metrics(total_dfs_joined, per_seq=False)
    return total_dfs_joined
    
def show_image_masks(image, masks, descriptions):
    """Plot ``image`` and, when ``masks`` is given, its mask overlay side by side.

    Args:
        image: image accepted by ``plt.imshow``.
        masks: masks for ``draw_joined_masks_on_image``, or None to show only
            the image.
        descriptions: panel titles; ``descriptions[0]`` for the image and
            ``descriptions[1]`` for the overlay (read only when masks is given).
    """

    def _draw_panel(slot, render):
        # Select cell `slot` of a 1x2 grid, show render()'s output, title it.
        plt.subplot(1, 2, slot)
        plt.imshow(render())
        plt.title(descriptions[slot - 1])

    plt.figure(figsize=(10, 5))
    _draw_panel(1, lambda: image)
    if masks is not None:
        _draw_panel(
            2, lambda: draw_joined_masks_on_image(image, masks, not_on_image=False)
        )
    plt.show()

## 3. Prepare dataset

In [None]:
# Fill in defaults for optional parameters that were not set above
# (setdefault leaves already-present values untouched).
DATA_PARAMS.setdefault("ds.csv", None)
DATA_PARAMS.setdefault("nn_roles", ["test"])  # "test" is the fallback role when none is specified
DATA_PARAMS.setdefault("share", 1.0)
PARAMS.setdefault("save_results", True)
PARAMS.setdefault("seed", 1)  # seed for sequence subselection (and `random` below)

ds = load_or_build_dataset(
    DATA_PARAMS["ds.csv"],
    DATA_PARAMS["images_root"],
    DATA_PARAMS["masks_root"],
)
subselected_ds = subselect_sequences(
    ds,
    DATA_PARAMS["nn_roles"],
    DATA_PARAMS["share"],
    PARAMS["seed"],
)

def show_ds(name: str, ds: pd.DataFrame, *, nn_roles=("test",)) -> None:
    """Print image and (plant, rep) sequence counts for each requested role.

    Args:
        name: label printed on the header line.
        ds: dataset with "nn_role", "plant" and "rep" columns.
        nn_roles: roles to report, in order; defaults to ("test",).
    """
    print(f"{name} has")
    for nn_role in nn_roles:
        role_ds = ds[ds["nn_role"] == nn_role]
        n_sequences = len(role_ds.groupby(["plant", "rep"]))
        print(f"    {nn_role}  images: {len(role_ds)}, sequences: {n_sequences}")
    print()


# Print dataset sizes before and after subselection.
show_ds("Dataset", ds)
print(
    f"Subselected {DATA_PARAMS['share']} of {DATA_PARAMS['nn_roles']} sequences for inference"
)
print()
show_ds("Subselected dataset", subselected_ds)
# Seed Python's global `random` module with the same seed used for subselection.
random.seed(PARAMS["seed"])

In [None]:
import imageio.v3 as iio
from ultralytics import YOLO
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import jaccard_score
from boxmot.appearance.reid_auto_backend import ReidAutoBackend	
from typing import Any, Tuple
from collections import namedtuple

# Load a single YOLO model.
# NOTE(review): model_path is an empty placeholder -- set it to the weights
# file before running this cell.
model_path = ""
model = YOLO(model_path)

In [None]:
def load_yolo_model(model_path: str) -> YOLO:
    """
    Load an ultralytics YOLO (segmentation) model from ``model_path``.

    Prints a success message on success. On failure, prints an error message
    and re-raises whatever ``YOLO`` raised.

    Args:
        model_path (str): Path to the model weights file.

    Returns:
        YOLO: The loaded model.
    """
    try:
        loaded_model = YOLO(model_path)
    except Exception as exc:
        print(f"Error loading the model from {model_path}: {exc}")
        raise
    print(f"Model loaded successfully from: {model_path}")
    return loaded_model
        
# Weights file for each YOLO variant to compare; fill in before running.
model_paths = {
    "yolo_v8": "",
    "yolo_v11": "",
}

# Load every configured model, keyed by its short name.
models = {
    label: load_yolo_model(weights_path)
    for label, weights_path in model_paths.items()
}

In [None]:
# Make the project sources and the vendored SAM2 checkout importable.
# ("../src" is also inserted in an earlier cell; the duplicate entry is harmless.)
sys.path.insert(0, "../src")
SAM_PATH = "../thirdparty/segment-anything-2/"
sys.path.insert(1, SAM_PATH)
# Have Hydra print full stack traces -- presumably for errors raised while
# SAM2 builds its model from Hydra configs.
os.environ["HYDRA_FULL_ERROR"] = "1"

In [None]:
import pathlib
import tempfile
import torch
from PIL import Image
from ultralytics import YOLO
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from pathlib import Path
from trackers.strongsort.strongsort import StrongSort
from trackers.deepocsort.deepocsort import DeepOcSort
from trackers.bytetrack.bytetrack import ByteTrack

class SAM2Inference:
    """Two-stage segmenter: YOLO proposes objects, SAM2 segments each of them.

    Every YOLO box whose confidence exceeds ``config["yolo_threshold"]``
    contributes its centre as one positive point prompt; SAM2 returns one mask
    per prompt.
    """

    def __init__(self, config: dict, device: str):
        """
        Build the YOLO detector and the SAM2 image predictor.

        Args:
            config (dict): Must provide "yolo_model" (YOLO weights), "sam2_cfg"
                and "sam2_checkpoint" (passed to ``build_sam2``) and
                "yolo_threshold" (read by ``infer_masks``).
            device (str): Device to run inference (e.g. 'cuda' or 'cpu').
        """
        self.config = config
        self.device = torch.device(device)
        self.yolo_model = YOLO(config["yolo_model"]).to(self.device)
        self.sam2_model = build_sam2(
            config["sam2_cfg"], config["sam2_checkpoint"], device=self.device
        )
        self.sam2_predictor = SAM2ImagePredictor(self.sam2_model)

    def infer_masks(self, image: np.ndarray) -> dict[int, dict]:
        """
        Detect objects with YOLO and segment each one with SAM2.

        Args:
            image (np.ndarray): Input image (H, W, C).

        Returns:
            dict[int, dict]: ``{idx: {"segmentation": bool (H, W) mask}}``, one
            entry per detection above the threshold; ``{}`` if there are none.
        """
        boxes = self.yolo_model(image, verbose=False)[0].boxes
        over_threshold = boxes.xywh[boxes.conf > self.config["yolo_threshold"]]

        if over_threshold.shape[0] == 0:
            return {}  # No detections

        # xywh rows start with the box centre; use it as a single point per
        # prompt, giving point coords of shape (N, 1, 2) and labels (N, 1).
        points = over_threshold.cpu().detach().numpy()[:, None, :2]
        labels = np.ones([len(points), 1])

        self.sam2_predictor.set_image(image)
        masks, _scores, _ = self.sam2_predictor.predict(
            point_coords=points,
            point_labels=labels,
            box=None,
            multimask_output=False,
        )
        # SAM2ImagePredictor.predict squeezes a batch axis of size 1, so a
        # single prompt comes back as (1, H, W) instead of (N, 1, H, W).
        # Restore the batch axis so indexing below works for any N.
        if masks.ndim == 3:
            masks = masks[None]
        return {
            idx: {"segmentation": masks[idx, 0, :, :] > 0.5}
            for idx in range(masks.shape[0])
        }

def sam2_model(config: dict, device: str):
    """
    Build a ``SAM2Inference`` model (YOLO detector + SAM2 segmenter).

    Args:
        config (dict): Configuration for SAM2 and YOLO, forwarded to
            ``SAM2Inference``.
        device (str): Device to run inference (e.g., 'cuda' or 'cpu').

    Returns:
        SAM2Inference: The constructed model. Call ``infer_masks(image)`` on it
        to get a ``{idx: {"segmentation": mask}}`` dict for one image.
    """
    return SAM2Inference(config, device)

In [None]:
# Define the list of trackers and their parameters.
# Each "name" matches one accepted by initialize_tracker; "params" are
# presumably passed as its keyword arguments, e.g.
# initialize_tracker(cfg["name"], **cfg["params"]) -- confirm at the call site.
trackers_config = [
    {
        "name": "ByteTrack",
        "params": {
            "track_thresh": 0.1,
            "match_thresh": 0.8,
            "track_buffer": 30,
            "frame_rate": 30
        }
    },
    {
        "name": "DeepOcSort",
        "params": {
            # Relative path: resolved against the current working directory.
            "reid_weights": Path("weights/lmbn_n_cuhk03_d.pt"),
            # NOTE(review): hard-coded; keep in sync with PARAMS["device"].
            "device": "cuda:1",
            "half": False,
            "per_class": False,
            "det_thresh": 0.2,
            "max_age": 50,
            "min_hits": 1,
            "iou_threshold": 0.2,
            "delta_t": 3,
            "asso_func": "iou",
            "inertia": 0.4,
            "w_association_emb": 0.5,
            "alpha_fixed_emb": 0.99
        }
    },
    {
        "name": "StrongSort",
        "params": {
            # Relative path: resolved against the current working directory.
            "reid_weights": Path("weights/lmbn_n_cuhk03_d.pt"),
            # NOTE(review): hard-coded; keep in sync with PARAMS["device"].
            "device": "cuda:1",
            "half": False,
            "max_cos_dist": 0.3,
            "max_iou_dist": 0.9,
            "max_age": 50,
            "n_init": 1,
            "nn_budget": 70,
            "mc_lambda": 0.9,
            "ema_alpha": 0.95
        }
    }
]

In [None]:
def initialize_tracker(tracker_name, **kwargs):
    """
    Construct a tracker by name, forwarding ``kwargs`` to its constructor.

    Args:
        tracker_name (str): One of 'ByteTrack', 'DeepOcSort', 'StrongSort'.
        **kwargs: Tracker-specific constructor parameters.

    Returns:
        The newly created tracker instance.

    Raises:
        ValueError: If ``tracker_name`` is not one of the supported names.
    """
    # Lambdas defer the class lookup until a name actually matches.
    builders = (
        ("ByteTrack", lambda: ByteTrack(**kwargs)),
        ("DeepOcSort", lambda: DeepOcSort(**kwargs)),
        ("StrongSort", lambda: StrongSort(**kwargs)),
    )
    for known_name, build in builders:
        if tracker_name == known_name:
            return build()
    raise ValueError(f"Unknown tracker name: {tracker_name}")
        
def _segment_with_yolo(model, image, conf):
    """Segment one frame with an ultralytics YOLO seg model.

    Returns ``(masks, dets)``: int32 masks of shape (n, H, W) and a list of
    ``[x1, y1, x2, y2, conf, cls]`` rows for the tracker.
    """
    # Pass the threshold to the call itself: ultralytics takes `conf` from the
    # predict kwargs, so assigning `model.conf` does not change detections.
    result = model(image, conf=conf, verbose=False)[0]
    if result.masks is None:
        # ultralytics leaves `masks` as None when nothing is detected.
        return np.zeros((0, *image.shape[:2]), dtype=np.int32), []
    masks = result.masks.data.cpu().numpy().astype(np.int32)
    dets = [
        [*det.xyxy[0].tolist(), det.conf.item(), int(det.cls.item())]
        for det in result.boxes
    ]
    return masks, dets


def _segment_with_sam(model, image):
    """Segment one frame with a model exposing ``infer_masks`` (e.g. SAM2Inference).

    Boxes are the extents of each non-empty mask, with fixed confidence 0.9
    and class 1.
    """
    mask_dict = model.infer_masks(image)
    if not mask_dict:
        return np.zeros((0, *image.shape[:2]), dtype=np.int32), []
    masks = np.stack(
        [entry["segmentation"] for entry in mask_dict.values()], axis=0
    ).astype(np.int32)
    dets = []
    for mask in masks:
        if mask.any():
            ys, xs = np.nonzero(mask)
            dets.append([xs.min(), ys.min(), xs.max(), ys.max(), 0.9, 1])
    return masks, dets


def _segment_with_detectron2(model, image):
    """Segment one frame with a detectron2 predictor returning {"instances": ...}.

    Predicted boxes are forwarded with fixed confidence 0.9 and class 1, the
    same convention as SAM mode.
    """
    # NOTE(review): iio.imread yields RGB, while detectron2's DefaultPredictor
    # expects BGR unless cfg.INPUT.FORMAT says otherwise -- confirm the config.
    instances = model(image)["instances"].to("cpu")
    masks = instances.pred_masks.numpy().astype(np.int32)
    boxes = instances.pred_boxes.tensor.numpy().astype(np.int32)
    dets = [[x1, y1, x2, y2, 0.9, 1] for x1, y1, x2, y2 in boxes] if masks.any() else []
    return masks, dets


def _assign_track_ids(masks, tracks):
    """Write each track's id into the first mask layer containing its box centre."""
    tracking_seg = np.zeros_like(masks, dtype=np.int32)
    if masks.shape[0] == 0:
        return tracking_seg
    height, width = masks.shape[1:]
    for track in tracks:
        x1, y1, x2, y2, track_id = track[:5]
        # Track boxes are not guaranteed to lie inside the frame; clamp the
        # centre so the lookup neither raises nor wraps via negative indices.
        center_x = min(max(int((x1 + x2) / 2), 0), width - 1)
        center_y = min(max(int((y1 + y2) / 2), 0), height - 1)
        hit_layers = np.flatnonzero(masks[:, center_y, center_x] > 0)
        if hit_layers.size:
            layer = hit_layers[0]
            tracking_seg[layer, masks[layer] > 0] = track_id
    return tracking_seg


def _resize_layers(layers, target_size):
    """Nearest-neighbour resize of an (n, H, W) stack to ``target_size`` (width, height)."""
    width, height = target_size
    if layers.shape[0] == 0:
        # Keep the stack 3-D so padding/stacking later still works.
        return np.zeros((0, height, width), dtype=np.int32)
    return np.stack(
        [cv2.resize(layer, target_size, interpolation=cv2.INTER_NEAREST) for layer in layers],
        axis=0,
    )


def _pad_layers(per_frame):
    """Zero-pad every (n_i, H, W) stack to the largest n_i."""
    max_layers = max(frame.shape[0] for frame in per_frame)
    return [
        np.concatenate(
            [frame, np.zeros((max_layers - frame.shape[0], *frame.shape[1:]), dtype=frame.dtype)],
            axis=0,
        )
        for frame in per_frame
    ]


def process_tracking_seg_masks(
    image_files: List[str],
    model,
    tracker,
    generate_unique_color,
    mode: str = "yolo",
    input_size: tuple = (640, 640),
    target_size: tuple = (533, 517),
    yolo_conf: float = 0.15,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Segment and track objects through a frame sequence, building per-frame id masks.

    For each frame:
    1. Resize it to ``input_size`` and segment it with ``model`` according to ``mode``.
    2. Feed the resulting boxes to ``tracker.update``.
    3. Write each returned track's id into the first mask layer containing its box centre.

    Masks and frames are then resized to ``target_size``, and the per-frame
    layer stacks are zero-padded to a common layer count.

    Args:
        image_files (List[str]): Paths of the frames, in temporal order.
        model: ultralytics YOLO seg model ("yolo"), an object with
            ``infer_masks(image)`` such as SAM2Inference ("sam"), or a
            detectron2 predictor returning {"instances": ...} ("detectron2").
        tracker: Object whose ``update(dets, image)`` takes an (N, 6) array of
            ``[x1, y1, x2, y2, conf, cls]`` and returns rows starting with
            ``[x1, y1, x2, y2, track_id, ...]``.
        generate_unique_color: Called once for each new track id.
        mode (str): "yolo", "sam" or "detectron2".
        input_size (tuple): (width, height) frames are resized to before inference.
        target_size (tuple): (width, height) of the returned masks and frames.
        yolo_conf (float): Detection confidence threshold used in "yolo" mode.

    Returns:
        tuple:
            - np.ndarray: int32 track-id masks, (num_frames, max_layers, height, width).
            - np.ndarray: Frames resized to ``target_size``, stacked on axis 0.

    Raises:
        ValueError: Unsupported ``mode`` or empty ``image_files``.
    """
    if mode not in ("yolo", "sam", "detectron2"):
        raise ValueError("Unsupported mode. Use 'yolo', 'sam', or 'detectron2'.")
    if not image_files:
        raise ValueError("image_files is empty; nothing to process.")

    tracking_seg_masks_list = []
    resized_images_list = []
    # Colours per track id; generated for parity with callers' expectations,
    # currently not returned.
    track_colors: dict = {}

    for image_file in image_files:
        image = cv2.resize(iio.imread(image_file), input_size)
        resized_images_list.append(cv2.resize(image, target_size))

        if mode == "yolo":
            segmentation_mask, dets = _segment_with_yolo(model, image, yolo_conf)
        elif mode == "sam":
            segmentation_mask, dets = _segment_with_sam(model, image)
        else:
            segmentation_mask, dets = _segment_with_detectron2(model, image)

        # The tracker is updated every frame, even without detections.
        dets = np.array(dets) if dets else np.empty((0, 6))
        tracks = tracker.update(dets, image)

        for track in tracks:
            track_id = track[4]
            if track_id not in track_colors:
                track_colors[track_id] = generate_unique_color(track_id)

        tracking_seg = _assign_track_ids(segmentation_mask, tracks)
        tracking_seg_masks_list.append(_resize_layers(tracking_seg, target_size))

    tracking_seg_masks = np.stack(_pad_layers(tracking_seg_masks_list), axis=0)
    resized_images = np.stack(resized_images_list, axis=0)
    return tracking_seg_masks, resized_images

In [None]:
def process_ground_truth_masks(
    mask_files: List[str],
    tracking_seg_masks: np.ndarray,
    target_size: tuple = (533, 517)
) -> np.ndarray:
    """
    Read ground-truth instance masks and lay them out like the tracking masks.

    Processing steps:
    - Each mask image is resized with nearest-neighbour interpolation, so ids
      are never blended.
    - Instance ids come from the first (R) channel of multi-channel images, or
      from the image itself if it is single-channel.
    - Every non-zero id in a frame becomes its own layer holding that id.
    - Layers are ordered by ascending id and zero-padded or truncated to the
      layer count of ``tracking_seg_masks``.
    - A background-only frame yields all-zero layers.

    Args:
        mask_files (List[str]): Ground-truth mask paths, one per frame.
        tracking_seg_masks (np.ndarray): Tracking masks shaped
            (num_frames, num_layers, H, W); only ``num_layers`` is used.
        target_size (tuple): (width, height) passed to ``cv2.resize``.

    Returns:
        np.ndarray: int32 array of shape (num_frames, num_layers, height, width).

    Raises:
        ValueError: If ``mask_files`` is empty (nothing to stack).

    Warns:
        UserWarning: When a frame has more instances than ``num_layers``; the
            highest ids are dropped, which hides them from evaluation.
    """
    import warnings

    num_layers = tracking_seg_masks.shape[1]

    def read_id_plane(mask_file):
        """Read one mask, resize it, and return its (H, W) instance-id plane."""
        mask = cv2.resize(iio.imread(mask_file), target_size, interpolation=cv2.INTER_NEAREST)
        return mask[:, :, 0] if mask.ndim == 3 else mask

    def create_unique_value_layers(gt_frame, frame_idx):
        """Split an id plane into one layer per id, padded/truncated to num_layers."""
        ids = np.unique(gt_frame)
        ids = ids[ids > 0]  # Exclude background (0)
        if ids.size > num_layers:
            warnings.warn(
                f"Frame {frame_idx}: {ids.size} ground-truth instances but only "
                f"{num_layers} layers; dropping ids {ids[num_layers:].tolist()}."
            )
        layers = np.zeros((num_layers, *gt_frame.shape), dtype=np.int32)
        for slot, val in enumerate(ids[:num_layers]):
            layers[slot] = np.where(gt_frame == val, val, 0)  # keep the id itself
        return layers

    id_planes = [
        read_id_plane(mask_file)
        for mask_file in tqdm(mask_files, desc="Reading and Resizing Ground Truth Masks")
    ]
    aligned = [
        create_unique_value_layers(plane, frame_idx)
        for frame_idx, plane in enumerate(tqdm(id_planes, desc="Aligning Ground Truth Layers"))
    ]
    return np.stack(aligned, axis=0)

In [None]:
def evaluate_tracking_performance_with_id_tracking(
    ground_truth_aligned: np.ndarray, tracking_seg_masks: np.ndarray, overlap_threshold: float = 0.1
) -> Dict[str, Any]:
    """Compute CLEAR-MOT style tracking metrics from instance-id masks.

    Both inputs are (num_frames, num_layers, H, W) arrays whose non-zero
    pixels hold an instance id. Layers are flattened with a per-pixel max, so
    where layers overlap, the larger id wins.

    Matching and counting rules:
    - In each frame, ground-truth and predicted instances are matched
      one-to-one by maximising total IoU (Hungarian assignment).
    - A pair counts as a match only if its IoU is > 0 and >= ``overlap_threshold``.
    - An ID switch is counted when a ground-truth instance is matched to a
      different predicted id than at its most recent earlier match, even if
      it went unmatched in between.

    Args:
        ground_truth_aligned: ground-truth id masks.
        tracking_seg_masks: predicted id masks; indexed per ground-truth frame.
        overlap_threshold: minimum IoU for a match.

    Returns:
        Dict[str, Any]: "GroundTruthMasksCount", "MultiObjectTrackingAccuracy",
        "IDSwitches", "MultiObjectTrackingPrecision" (mean IoU of matches),
        "FalseNegatives", "FalsePositives", "FrameBasedIOU" (mean foreground
        IoU per frame; 0.0 for frames where both masks are empty),
        "Precision" and "Recall".
    """
    from scipy.optimize import linear_sum_assignment

    def calculate_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
        """Calculate Intersection over Union (IoU) between two masks."""
        intersection = np.logical_and(mask1, mask2).sum()
        union = np.logical_or(mask1, mask2).sum()
        return intersection / union if union > 0 else 0.0

    def match_ids(gt_mask: np.ndarray, pred_mask: np.ndarray, threshold: float):
        """One-to-one IoU matching; returns (gt->pred map, matched IoUs, gt_ids, pred_ids)."""
        gt_ids = np.unique(gt_mask[gt_mask > 0])
        pred_ids = np.unique(pred_mask[pred_mask > 0])
        if gt_ids.size == 0 or pred_ids.size == 0:
            return {}, [], gt_ids, pred_ids

        pred_instances = [pred_mask == pred_id for pred_id in pred_ids]
        iou_matrix = np.zeros((gt_ids.size, pred_ids.size))
        for row, gt_id in enumerate(gt_ids):
            gt_instance = gt_mask == gt_id
            for col, pred_instance in enumerate(pred_instances):
                iou_matrix[row, col] = calculate_iou(gt_instance, pred_instance)

        # Hungarian assignment: each prediction may explain at most one GT
        # instance (CLEAR MOT requires a one-to-one correspondence).
        rows, cols = linear_sum_assignment(iou_matrix, maximize=True)
        id_mapping, ious = {}, []
        for row, col in zip(rows, cols):
            iou = iou_matrix[row, col]
            if iou > 0 and iou >= threshold:
                id_mapping[gt_ids[row]] = pred_ids[col]
                ious.append(iou)
        return id_mapping, ious, gt_ids, pred_ids

    # Flatten the masks by taking the max ID across the instance dimension
    gt_masks = np.max(ground_truth_aligned, axis=1)
    pred_masks = np.max(tracking_seg_masks, axis=1)

    total_tp = total_fp = total_fn = total_id_switches = 0
    total_ious = []
    frame_ious = []
    ground_truth_count = 0
    last_match = {}  # gt_id -> pred_id at its most recent match

    for frame_idx in range(gt_masks.shape[0]):
        gt_mask = gt_masks[frame_idx]
        pred_mask = pred_masks[frame_idx]

        id_mapping, ious, gt_ids, pred_ids = match_ids(gt_mask, pred_mask, overlap_threshold)

        for gt_id, pred_id in id_mapping.items():
            if gt_id in last_match and last_match[gt_id] != pred_id:
                total_id_switches += 1
            last_match[gt_id] = pred_id

        tp = len(id_mapping)
        total_tp += tp
        total_fp += len(pred_ids) - tp  # non-negative: matching is one-to-one
        total_fn += len(gt_ids) - tp
        total_ious.extend(ious)
        ground_truth_count += len(gt_ids)

        # Compute IoU for the entire frame (foreground vs. foreground)
        frame_ious.append(calculate_iou(gt_mask > 0, pred_mask > 0))

    precision = total_tp / (total_tp + total_fp) if total_tp + total_fp > 0 else 0.0
    recall = total_tp / (total_tp + total_fn) if total_tp + total_fn > 0 else 0.0
    mota = 1 - (total_fn + total_fp + total_id_switches) / ground_truth_count if ground_truth_count > 0 else 0.0
    motp = sum(total_ious) / len(total_ious) if total_ious else 0.0
    frame_based_iou = sum(frame_ious) / len(frame_ious) if frame_ious else 0.0

    return {
        "GroundTruthMasksCount": ground_truth_count,
        "MultiObjectTrackingAccuracy": mota,
        "IDSwitches": total_id_switches,
        "MultiObjectTrackingPrecision": motp,
        "FalseNegatives": total_fn,
        "FalsePositives": total_fp,
        "FrameBasedIOU": frame_based_iou,
        "Precision": precision,
        "Recall": recall,
    }

In [None]:
def get_image_and_mask_files_from_dataset(image_paths, mask_paths):
    """
    Return the dataset's image and mask paths, each sorted lexicographically.

    Args:
        image_paths (list): Image paths taken from the dataset.
        mask_paths (list): Mask paths taken from the dataset.

    Returns:
        tuple: ``(sorted image paths, sorted mask paths)`` as new lists; the
        inputs are left unchanged.
    """
    return tuple(sorted(paths) for paths in (image_paths, mask_paths))

# Iterate over the processed data
def iterate_and_return_files(processed_data):
    """
    Walk ``processed_data`` and yield one record per (plant, rep) sequence.

    Args:
        processed_data (dict): ``{plant: {rep: {"image_paths": [...],
            "mask_paths": [...]}}}``.

    Yields:
        dict: ``plant_name``, ``rep_name`` and the sorted ``image_files`` and
        ``mask_files`` of that sequence.
    """
    for plant, reps in processed_data.items():
        for rep, paths in reps.items():
            images, masks = get_image_and_mask_files_from_dataset(
                paths["image_paths"], paths["mask_paths"]
            )
            yield dict(
                plant_name=plant,
                rep_name=rep,
                image_files=images,
                mask_files=masks,
            )

def iterate_and_return_files_with_validation(processed_data):
    """
    Iterate over processed data and yield sorted, validated image/mask file lists.

    Only ``.png`` paths are kept. After sorting, images and masks are paired
    positionally and must have identical basenames.

    Args:
        processed_data (dict): ``{plant: {rep: {"image_paths": [...], "mask_paths": [...]}}}``.
    Yields:
        dict: ``plant_name``, ``rep_name``, ``image_files``, ``mask_files``.
    Raises:
        ValueError: If the number of images and masks differs, or a pair's basenames differ.
    """
    for plant, reps in processed_data.items():
        for rep, paths in reps.items():
            # Use the paths supplied for this rep (there is no per-rep directory variable in scope).
            image_files = sorted(p for p in paths.get("image_paths", []) if p.endswith(".png"))
            mask_files = sorted(p for p in paths.get("mask_paths", []) if p.endswith(".png"))

            # Explicit raises instead of `assert`, which is stripped under `python -O`.
            if len(image_files) != len(mask_files):
                raise ValueError(f"Number of images and masks do not match for plant {plant}, rep {rep}!")
            for img, mask in zip(image_files, mask_files):
                if os.path.basename(img) != os.path.basename(mask):
                    raise ValueError(f"Image and mask filenames do not align for plant {plant}, rep {rep}!")

            yield {
                "plant_name": plant,
                "rep_name": rep,
                "image_files": image_files,
                "mask_files": mask_files,
            }

# Iterate over the plants
def process_plants(dataframe):
    """
    Group a dataset table into ``{plant: {rep: {"image_paths": [...], "mask_paths": [...]}}}``.

    Within each (plant, rep) group the rows are ordered by ``image_num`` before
    the ``image_path`` / ``mask_path`` columns are collected into lists.

    Args:
        dataframe (pandas.DataFrame): Needs the columns ``plant``, ``rep``,
            ``image_num``, ``image_path`` and ``mask_path``.
    Returns:
        dict: Nested plant -> rep -> paths mapping.
    """
    def _paths_in_frame_order(rep_rows):
        # Frame order comes from image_num, not from the path strings.
        ordered = rep_rows.sort_values(by='image_num')
        return {
            "image_paths": ordered['image_path'].tolist(),
            "mask_paths": ordered['mask_path'].tolist(),
        }

    return {
        plant: {rep: _paths_in_frame_order(rep_rows) for rep, rep_rows in plant_rows.groupby('rep')}
        for plant, plant_rows in dataframe.groupby('plant')
    }

def iterate_and_return_files_with_validation(processed_data):
    """
    Iterate over processed data, yielding sorted and validated image/mask path lists.

    Both lists are sorted as plain strings and then paired positionally; each
    pair must share the same basename.

    Args:
        processed_data (dict): ``{plant: {rep: {"image_paths": [...], "mask_paths": [...]}}}``.
            Missing keys are treated as empty lists.
    Yields:
        dict: ``plant_name``, ``rep_name``, ``image_files``, ``mask_files``.
    Raises:
        ValueError: If the image and mask counts differ, or a pair's basenames differ.
    """
    for plant, reps in processed_data.items():
        for rep, data in reps.items():
            # NOTE(review): string sorting overrides any image_num ordering from the caller;
            # frame order is only correct if filenames sort chronologically (e.g. zero-padded).
            image_paths = sorted(data.get('image_paths', []))
            mask_paths = sorted(data.get('mask_paths', []))

            # Explicit raises rather than `assert`, which is stripped under `python -O`.
            if len(image_paths) != len(mask_paths):
                raise ValueError(
                    f"Mismatch in the number of images and masks for plant '{plant}', rep '{rep}'!"
                )

            for img, mask in zip(image_paths, mask_paths):
                if os.path.basename(img) != os.path.basename(mask):
                    raise ValueError(
                        f"Image and mask filenames do not match for plant '{plant}', rep '{rep}': "
                        f"{os.path.basename(img)} vs {os.path.basename(mask)}"
                    )

            yield {
                "plant_name": plant,
                "rep_name": rep,
                "image_files": image_paths,
                "mask_files": mask_paths
            }

In [None]:
import os
import pathlib
import random
import sys
import tempfile
from pathlib import Path  # `Path` is used by later cells; only `pathlib` was imported before

import numpy as np
import torch
from PIL import Image

# Extend sys.path BEFORE importing `sam2`: the vendored SAM 2 checkout under
# SAM_PATH can only be picked up if it is on the path when the import runs.
sys.path.insert(0, "../src")
SAM_PATH = "../thirdparty/segment-anything-2/"
sys.path.insert(1, SAM_PATH)
os.environ["HYDRA_FULL_ERROR"] = "1"  # full Hydra tracebacks when SAM 2 configs fail to load

from ultralytics import YOLO
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor

In [None]:
def load_yolo_model(model_path: str) -> YOLO:
    """
    Construct an Ultralytics YOLO segmentation model from a weights file.

    A success or failure message is printed either way. Any exception raised
    by the ``YOLO`` constructor is re-raised unchanged after it is reported.

    Args:
        model_path (str): Path to the YOLO weights file.

    Returns:
        YOLO: The loaded model.
    """
    try:
        yolo = YOLO(model_path)
        print(f"Model loaded successfully from: {model_path}")
    except Exception as err:
        print(f"Error loading the model from {model_path}: {err}")
        raise
    return yolo

In [None]:
def load_detectron2_model(model_path: str, num_classes: int = 3, score_thresh: float = 0.7):
    """
    Build a Detectron2 ``DefaultPredictor`` for the fine-tuned leaf Mask R-CNN (R101-FPN, 3x schedule).

    The architecture comes from the model-zoo config
    ``COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml``. Weights come from
    ``model_path``, or from the bundled fine-tuned checkpoint when ``model_path``
    is empty or ``None``.

    Args:
        model_path (str): Path to the trained checkpoint (.pth). An empty string or
            ``None`` selects ``./leaf/mask_rcnn_R_101_FPN_3x/2025-02-01-10-12-17/model_final.pth``.
        num_classes (int): Currently NOT applied. The trained checkpoint has a single
            class, so ``ROI_HEADS.NUM_CLASSES`` is fixed to 1.
            TODO: honour this once the default matches the checkpoint.
        score_thresh (float): Currently NOT applied. ``SCORE_THRESH_TEST`` is fixed to 0.5.
            TODO: honour this once callers pass it explicitly.

    Returns:
        DefaultPredictor: A Detectron2 predictor running on "cuda" when available, else "cpu".
    """
    default_weights = "./leaf/mask_rcnn_R_101_FPN_3x/2025-02-01-10-12-17/model_final.pth"

    cfg = get_cfg()
    # merge_from_file mutates cfg in place (its return value is not a path).
    cfg.merge_from_file(
        model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml")
    )
    cfg.MODEL.WEIGHTS = model_path or default_weights  # honour an explicit path, else the trained model
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # adjust threshold as needed
    cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # must match the single-class fine-tuned checkpoint

    return DefaultPredictor(cfg)

In [None]:
# Make the project sources and the vendored SAM 2 checkout importable
# (same path setup as the earlier import cell).
sys.path.insert(0, "../src")
SAM_PATH = "../thirdparty/segment-anything-2/"
sys.path.insert(1, SAM_PATH)
os.environ["HYDRA_FULL_ERROR"] = "1"  # full Hydra tracebacks for SAM 2 config errors

# Define model paths
# NOTE(review): every path below is an empty string; load_yolo_model("") will
# presumably fail when YOLO tries to open it. Fill these in before running.
model_paths = {
    "yolo_v8": "",
    "yolo_v11": "",
    "detectron2": ""
}

# Define SAM2 configuration
# Configuration for YOLO and SAM2, passed to sam2_model below
# (sam2_model is presumably defined in an earlier cell; it is not visible here).
# NOTE(review): the config name is "hiera_b+" while the checkpoint file is
# "sam2.1l_..." (large?); confirm they describe the same architecture.
config = {
    "yolo_model": "../models/yolo_11mseg_finetuned_stage4.pt",
    "sam2_cfg": "configs/sam2.1/sam2.1_hiera_b+.yaml",
    "sam2_checkpoint": "../models/sam2.1l_finetuned.pt",
    "yolo_threshold": 0.5,
}

# Hard-coded to the second GPU; PARAMS["device"] holds the same value.
device = "cuda:1" if torch.cuda.is_available() else "cpu"
# Load models and store them in a dictionary.
# The keys matter: the main processing loop derives the inference mode from them
# ("sam2" -> "sam", "detectron2" -> "detectron2", anything else -> "yolo").
models = {
    "yolo_v8": load_yolo_model(model_paths["yolo_v8"]),
    "yolo_v11": load_yolo_model(model_paths["yolo_v11"]),
    "sam2": sam2_model(config, device),
    "detectron2": load_detectron2_model(model_path=model_paths["detectron2"])
}

# Print loaded models
print("Loaded models:")
for model_name in models:
    print(f"- {model_name}")

In [None]:
# Predefined color mapping: fixed, hand-picked colors for track ids 1-34
# (returned as-is by generate_unique_color).
PREDEFINED_COLORS = {
    1:  (244, 64, 14),
    2:  (48, 57, 249),
    3:  (234, 250, 37),
    4:  (24, 193, 65),
    5:  (245, 130, 49),
    6:  (231, 80, 219),
    7:  (0, 182, 173),
    8:  (115, 0, 218),
    9:  (191, 239, 69),
    10: (255, 250, 200),
    11: (250, 190, 212),
    12: (66, 212, 244),
    13: (155, 99, 36),
    14: (220, 190, 255),
    15: (69, 158, 220),
    16: (255, 216, 177),
    17: (98, 2, 37),
    18: (227, 213, 12),
    19: (79, 159, 83),
    20: (170, 23, 101),
    21: (170, 255, 195),
    22: (169, 169, 169),
    23: (181, 111, 119),
    24: (144, 121, 171),
    25: (9, 125, 244),
    26: (184, 70, 30),
    27: (154, 35, 246),
    28: (229, 225, 238),
    29: (141, 254, 82),
    30: (31, 200, 209),
    31: (194, 217, 105),
    32: (91, 20, 124),
    33: (181, 220, 171),
    34: (37, 3, 193),
}

# Cache for dynamically generated colors (ids not in PREDEFINED_COLORS)
unique_color_mapping = {}

def generate_unique_color(track_id):
    """
    Return a stable color for ``track_id``.

    Ids in ``PREDEFINED_COLORS`` get their fixed color. Any other id gets a color
    derived deterministically from the id itself: each channel is drawn from
    [100, 254], then the channel order is reversed (documented as RGB -> BGR for
    OpenCV). The result is cached in ``unique_color_mapping``.

    Args:
        track_id (int): Track id to color.

    Returns:
        tuple: Three Python ints.
    """
    if track_id in PREDEFINED_COLORS:
        return PREDEFINED_COLORS[track_id]

    if track_id in unique_color_mapping:
        return unique_color_mapping[track_id]

    # A dedicated generator seeded from the id makes the color reproducible across
    # runs and leaves the global NumPy / `random` state untouched.
    rng = np.random.default_rng(abs(int(track_id)))
    color_rgb = rng.integers(100, 255, 3)
    color_bgr = tuple(int(channel) for channel in color_rgb[::-1])
    unique_color_mapping[track_id] = color_bgr

    return color_bgr

def save_visualization(image_path, mask_path, frame_index, max_mask, unique_colors, original_image):
    """
    Write a colored label image and an overlay of it on the original frame.

    Parameters:
        image_path (Path | str): Where the overlay image is written.
        mask_path (Path | str): Where the colored label image (black background) is written.
        frame_index (int): Frame index; accepted for reference but not used.
        max_mask (ndarray): 2-D label image; pixel value = instance id, 0 = background.
        unique_colors (dict): Instance id -> 3-channel color, written as given to ``mask_path``.
        original_image (ndarray): 3-channel frame with the same height and width as ``max_mask``.

    The overlay swaps channel order before pasting the colors and swaps the
    whole result again before writing. The net effect appears to assume an RGB
    ``original_image``, which then lands on disk correctly for ``cv2.imwrite``,
    with the same leaf colors as the label image.
    """
    rows, cols = max_mask.shape
    colored = np.zeros((rows, cols, 3), dtype=np.uint8)
    for inst_id, inst_color in unique_colors.items():
        colored[max_mask == inst_id] = inst_color

    cv2.imwrite(str(mask_path), colored)
    colored_swapped = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

    # Paste the colors over every labelled pixel of a copy of the frame.
    foreground = max_mask > 0
    blended = original_image.copy()
    blended[foreground] = colored_swapped[foreground]

    cv2.imwrite(str(image_path), cv2.cvtColor(blended, cv2.COLOR_BGR2RGB))


def save_instance_masks(tracking_seg_masks, original_images, output_dir):
    """
    Export every tracked instance as per-frame binary masks and cut-out images.

    Non-zero ids found anywhere in ``tracking_seg_masks`` are sorted and renumbered
    1..K. Files are written as::

        output_dir/leaf_binary_masks/leaf_{k}/frame_{i:04d}.png      (0/255 mask)
        output_dir/leaf_extracted_images/leaf_{k}/frame_{i:04d}.png  (frame pixels, black elsewhere)

    A frame in which an instance is absent produces no file for that instance.

    Args:
        tracking_seg_masks (numpy.ndarray): Labels shaped (frames, layers, H, W); layers are
            merged per pixel with ``max``.
        original_images (numpy.ndarray): Frames shaped (frames, H, W, 3).
        output_dir (str | Path): Parent output directory (created if missing).

    NOTE(review): cv2.imwrite treats 3-channel data as BGR. If ``original_images``
    are RGB (e.g. read with imageio), the saved cut-outs will have swapped channels.
    """
    root = Path(output_dir)
    root.mkdir(parents=True, exist_ok=True)
    binary_root = root / "leaf_binary_masks"
    extracted_root = root / "leaf_extracted_images"

    # Background (0) is excluded; the remaining ids are renumbered in ascending order.
    old_ids = sorted(set(np.unique(tracking_seg_masks)) - {0})
    remap = {old: new for new, old in enumerate(old_ids, start=1)}

    for new in remap.values():
        for parent in (binary_root, extracted_root):
            (parent / f"leaf_{new}").mkdir(parents=True, exist_ok=True)

    for frame_idx, (layers, frame) in enumerate(zip(tracking_seg_masks, original_images)):
        flattened = layers.max(axis=0)  # (H, W)
        file_name = f"frame_{frame_idx:04d}.png"

        for old, new in remap.items():
            binary = (flattened == old).astype(np.uint8)
            if not binary.sum():
                continue

            cv2.imwrite(str(binary_root / f"leaf_{new}" / file_name), binary * 255)

            inside = binary > 0
            cutout = np.zeros_like(frame)
            cutout[inside] = frame[inside]
            cv2.imwrite(str(extracted_root / f"leaf_{new}" / file_name), cutout)

def _detectron2_instance_masks(predictor, image: np.ndarray) -> np.ndarray:
    """Run a Detectron2 predictor on one image; return its instance masks as an (n, H, W) int32 array."""
    instances = predictor(image)["instances"].to("cpu")
    return instances.pred_masks.numpy().astype(np.int32)


def _tracked_boxes(tracked_objects):
    """Yield ``(track_id, x1, y1, x2, y2)`` for each tracked object whose mask has foreground pixels."""
    for obj in tracked_objects.values():
        non_zero_pixels = cv2.findNonZero(obj["mask"].astype(np.uint8))
        if non_zero_pixels is None:
            continue
        x, y, w, h = cv2.boundingRect(non_zero_pixels)
        yield obj["id"], x, y, x + w, y + h


def _resize_label_layers(layers: np.ndarray, size_wh: tuple[int, int]) -> np.ndarray:
    """Nearest-neighbour resize each label layer to ``size_wh`` (width, height); empty input stays (0, H, W)."""
    if len(layers) == 0:
        # Frames without detections must still be 3-D so they can be padded and stacked.
        width, height = size_wh
        return np.zeros((0, height, width), dtype=np.int32)
    return np.array([cv2.resize(layer, size_wh, interpolation=cv2.INTER_NEAREST) for layer in layers])


def _stack_padded_layers(per_frame_layers) -> np.ndarray:
    """Zero-pad each (layers, H, W) array to the largest layer count, then stack along a new frame axis."""
    max_layers = max(layers.shape[0] for layers in per_frame_layers)
    padded = []
    for layers in per_frame_layers:
        missing = max_layers - layers.shape[0]
        if missing > 0:
            filler = np.zeros((missing, *layers.shape[1:]), dtype=layers.dtype)
            layers = np.concatenate([layers, filler], axis=0)
        padded.append(layers)
    return np.stack(padded, axis=0)


def process_tracking_custom_seg_masks(
    image_files: List[str], model, tracker, generate_unique_color, mode: str = "detectron2"
) -> tuple[np.ndarray, np.ndarray]:
    """
    Segment every frame with a Detectron2 predictor, track the instances and build per-frame label masks.

    Each frame is resized to 640x640 for inference. Its instance masks are passed
    to ``tracker.update`` as ``[{"mask": m}, ...]``. For each track, the instance
    layer that contains the center of the track's bounding box is painted with
    the track id. Label layers and frames are then resized to 533x517 (W x H).

    Args:
        image_files (List[str]): Frame paths in processing order.
        model: Detectron2 predictor (e.g. ``DefaultPredictor``) called as ``model(image)``.
        tracker: Object whose ``update(detections)`` returns a dict whose values carry
            ``"mask"`` and ``"id"`` keys (as ``ObjectTracker`` does).
        generate_unique_color: Callable id -> color; called once per new track id.
        mode (str): Must be ``"detectron2"``; any other value raises ``ValueError``.

    Returns:
        tuple:
            - np.ndarray: Label masks shaped (frames, max_layers, 517, 533), zero-padded per frame.
            - np.ndarray: Resized frames shaped (frames, 517, 533, C).

    Raises:
        ValueError: Unsupported ``mode`` or empty ``image_files``.

    Note: the label arrays start at 0 (background), so a track id of 0 would be invisible.
    NOTE(review): Detectron2's DefaultPredictor expects BGR input by default, while
    imageio returns RGB. Confirm the channel order matches what the model was trained on.
    """
    if mode != "detectron2":
        raise ValueError(
            f"Unsupported mode {mode!r}: process_tracking_custom_seg_masks only supports 'detectron2'."
        )
    if not image_files:
        raise ValueError("image_files is empty; nothing to track.")

    inference_size_wh = (640, 640)
    output_size_wh = (533, 517)

    track_colors: Dict[int, tuple] = {}  # track id -> color (filled for side effect / consistency)
    tracking_seg_masks_list = []
    resized_images_list = []

    for image_file in image_files:
        image = cv2.resize(iio.imread(image_file), inference_size_wh)
        resized_images_list.append(cv2.resize(image, output_size_wh))

        # Single inference pass per frame with the predictor we were given.
        segmentation_mask = _detectron2_instance_masks(model, image)  # (n, H, W)
        tracking_seg = np.zeros_like(segmentation_mask, dtype=np.int32)

        detections = np.array([{"mask": mask} for mask in segmentation_mask])
        tracked_objects = tracker.update(detections)

        for track_id, x1, y1, x2, y2 in _tracked_boxes(tracked_objects):
            if track_id not in track_colors:
                track_colors[track_id] = generate_unique_color(track_id)

            center_x = int((x1 + x2) / 2)
            center_y = int((y1 + y2) / 2)
            # First instance layer covering the box center receives this track id.
            layer_idx = next(
                (idx for idx, mask in enumerate(segmentation_mask) if mask[center_y, center_x] > 0),
                None,
            )
            if layer_idx is not None:
                tracking_seg[layer_idx][segmentation_mask[layer_idx] > 0] = track_id

        tracking_seg_masks_list.append(_resize_label_layers(tracking_seg, output_size_wh))

    tracking_seg_masks = _stack_padded_layers(tracking_seg_masks_list)
    resized_images = np.stack(resized_images_list, axis=0)
    return tracking_seg_masks, resized_images

In [None]:
import warnings
from collections import Counter
import numpy as np
from PIL import Image

def resize_mask(mask: np.ndarray, new_size: tuple[int, int]) -> np.ndarray:
    """Resize a 2-D mask to ``new_size`` = (height, width) with nearest-neighbour sampling, keeping label values intact."""
    assert mask.ndim == 2, "Mask should be 2D"
    pil_size = new_size[::-1]  # PIL wants (width, height)
    resized = Image.fromarray(mask).resize(pil_size, resample=Image.Resampling.NEAREST)
    return np.array(resized)

def resize_image(image: np.ndarray, new_size: tuple[int, int]) -> np.ndarray:
    """Resize a 3-D image to ``new_size`` = (height, width) using bicubic resampling."""
    assert image.ndim == 3, "Image should be 3D"
    pil_size = new_size[::-1]  # PIL wants (width, height)
    resized = Image.fromarray(image).resize(pil_size, resample=Image.Resampling.BICUBIC)
    return np.array(resized)

def ensure_same_image_sizes(images: list[np.ndarray]) -> list[np.ndarray]:
    """
    Make all images share one (height, width).

    If the images already agree (or the list is empty), the input list itself is
    returned. Otherwise every image is resized to the most frequent size and a
    new list is returned.
    """
    size_counts = Counter(img.shape[:2] for img in images)
    if len(size_counts) > 1:
        target_hw, _ = size_counts.most_common(1)[0]
        return [resize_image(img, target_hw) for img in images]
    return images

class ObjectTracker:
    """Greedy mask-IoU tracker usable with any detector that produces instance masks.

    ``update`` is called once per frame with detection dicts that each hold a
    ``"mask"``. Each previous-frame detection, in order, claims the unclaimed new
    detection it overlaps most. If that IoU reaches ``iou_threshold``, the new
    detection inherits the previous track id. Unmatched detections get a fresh id.

    Track ids start at 1: the label masks built from this tracker's output are
    zero-initialised with 0 meaning background, so an id of 0 would be lost.
    """

    def __init__(self, iou_threshold: float = 0.3):
        self.prev_shape = None  # mask shape of the previous non-empty frame
        self.prev_detections = []  # previous frame's detection dicts (each has "mask" and "id")
        self.id_counter = 0  # last track id handed out; 0 means none issued yet
        self.iou_threshold = iou_threshold  # minimum IoU for a detection to inherit a track id

    def _compute_iou(self, mask1: np.ndarray, mask2: np.ndarray) -> float:
        """Return the IoU of two masks, treating any non-zero pixel as foreground (0 if both are empty)."""
        # Boolean ops rather than bitwise &/| on ints, which break for labels other than 0/1 (e.g. 2 & 1 == 0).
        fg1 = np.asarray(mask1).astype(bool)
        fg2 = np.asarray(mask2).astype(bool)
        union = np.logical_or(fg1, fg2).sum()
        if union == 0:
            return 0
        return np.logical_and(fg1, fg2).sum() / union

    def _match_detections(self, new_detections: list[dict]) -> dict[int, dict]:
        """Greedily map indices of ``new_detections`` to previous track ids using mask IoU."""
        matches = {}
        used_indices = set()

        for prev in self.prev_detections:
            best_iou, best_idx = 0, None
            for i, new_det in enumerate(new_detections):
                if i in used_indices:
                    continue
                iou = self._compute_iou(prev["mask"], new_det["mask"])
                if iou > best_iou:
                    best_iou, best_idx = iou, i

            if best_idx is not None and best_iou >= self.iou_threshold:
                matches[best_idx] = prev["id"]
                used_indices.add(best_idx)

        return matches

    def update(self, detections: list[dict]) -> dict[int, dict]:
        """Assign track ids to one frame's detections.

        Args:
            detections: Detection dicts with a ``"mask"`` entry. An ``"id"`` key is
                written into each dict in place.

        Returns:
            ``{track_id: detection}`` for this frame. An empty frame returns ``{}``
            and keeps the previous frame's detections for the next match.
        """
        if len(detections) == 0:
            return {}

        new_size = detections[0]["mask"].shape
        if self.prev_shape is not None and self.prev_shape != new_size:
            warnings.warn("Resolution changed, resizing previous detections.")
            self.prev_detections = [
                {**d, "mask": resize_mask(d["mask"], new_size)} for d in self.prev_detections
            ]
        self.prev_shape = new_size

        matches = self._match_detections(detections)
        for i, det in enumerate(detections):
            if i in matches:
                det["id"] = matches[i]
            else:
                # Ids are issued as 1, 2, 3, ... (0 stays free for background).
                self.id_counter += 1
                det["id"] = self.id_counter

        self.prev_detections = detections
        return {d["id"]: d for d in self.prev_detections}

    def reset(self):
        """Forget all tracks and restart id numbering."""
        self.prev_shape = None
        self.prev_detections = []
        self.id_counter = 0

## Apply Tracking

In [None]:
from tqdm import tqdm
import csv
from collections import defaultdict
import cv2
import numpy as np

# Configuration for YOLO and SAM2
# NOTE(review): the config name is "hiera_b+" while the checkpoint file is "sam2.1l_..."; confirm they match.
config = {
    "yolo_model": "../models/yolo_11mseg_finetuned_stage4.pt",
    "sam2_cfg": "configs/sam2.1/sam2.1_hiera_b+.yaml",
    "sam2_checkpoint": "../models/sam2.1l_finetuned.pt",
    "yolo_threshold": 0.5,
}
# Use the device configured in PARAMS instead of repeating the literal here.
device = PARAMS["device"] if torch.cuda.is_available() else "cpu"

# Run inference
model_sam2 = sam2_model(config, device)

# Base columns of the per-run CSV; the main loop appends metric columns as they appear.
csv_headers = ["Plant", "Rep", "Model", "Tracker"]

# Output folders
results_tracking_dir = Path("results_tracking")
results_tracking_dir.mkdir(exist_ok=True)

csv_dir = results_tracking_dir / "results_csv"
csv_dir.mkdir(parents=True, exist_ok=True)

# Per-run rows and the per-(model, tracker) summary
csv_file_path = csv_dir / "tracking_results.csv"
total_csv_file_path = csv_dir / "total_tracking.csv"

# Start a fresh per-run CSV containing only the header row (written once).
with open(csv_file_path, mode='w', newline='') as file:
    csv.DictWriter(file, fieldnames=csv_headers).writeheader()

# Process data
processed_data = process_plants(subselected_ds)

# (model_name, tracker_name) -> metric name -> list of per-rep values
aggregated_results = defaultdict(lambda: defaultdict(list))

# Reset the generated-color cache used by generate_unique_color
unique_color_mapping = {}

In [None]:
import traceback

# Main processing loop: for every plant/rep, run every (model, tracker) pair,
# evaluate the tracked masks against the ground truth and record the metrics.
for data in tqdm(iterate_and_return_files_with_validation(processed_data), desc="Processing plants"):
    plant_name = data['plant_name']
    rep_name = data['rep_name']
    image_files = data['image_files']
    mask_files = data['mask_files']
    # One result row per successful (model, tracker) run for this rep
    rep_results = []

    # Iterate over all models (YOLO, SAM2, Detectron2)
    for model_name, model in tqdm(models.items(), desc=f"Models for {rep_name}", leave=False):
        # The inference mode is derived from the model's key in `models`
        mode = "sam" if model_name == "sam2" else "detectron2" if model_name == "detectron2" else "yolo"
        # Named `tracker_cfg` (not `config`) so the global model `config` dict is not overwritten.
        for tracker_cfg in tqdm(trackers_config, desc=f"Trackers for {model_name}", leave=False):
            tracker_name = "<unknown>"  # keeps the error message below valid even if "name" is missing
            try:
                tracker_name = tracker_cfg["name"]
                if tracker_name != "botsort":
                    tracker_params = tracker_cfg["params"]

                tracker_dir = results_tracking_dir / model_name / tracker_name / plant_name / rep_name
                sub_mask_black = tracker_dir / "all_masks_on_black"
                sub_mask_image = tracker_dir / "all_masks_on_image"
                for directory in (tracker_dir, sub_mask_black, sub_mask_image):
                    directory.mkdir(parents=True, exist_ok=True)

                # "botsort" is served by the in-notebook IoU ObjectTracker; other names via initialize_tracker
                if tracker_name != "botsort":
                    tracker = initialize_tracker(tracker_name, **tracker_params)
                else:
                    tracker = ObjectTracker(iou_threshold=0.3)

                # Process tracking segmentation masks
                if mode == "detectron2":
                    tracking_seg_masks, original_images = process_tracking_custom_seg_masks(
                        image_files, model, tracker, generate_unique_color, mode=mode
                    )
                else:
                    tracking_seg_masks, original_images = process_tracking_seg_masks(
                        image_files, model, tracker, generate_unique_color, mode=mode
                    )

                aligned_ground_truth = process_ground_truth_masks(mask_files, tracking_seg_masks)

                # Evaluate tracking performance
                metrics = evaluate_tracking_performance_with_id_tracking(
                    aligned_ground_truth, tracking_seg_masks, overlap_threshold=0.1
                )

                # Log and store results
                print(f"Plant: {plant_name} | Rep: {rep_name} | Model: {model_name} | Tracker: {tracker_name}")
                result_row = {
                    "Plant": plant_name,
                    "Rep": rep_name,
                    "Model": model_name,
                    "Tracker": tracker_name
                }
                for metric_name, value in metrics.items():
                    result_row[metric_name] = value
                    aggregated_results[(model_name, tracker_name)][metric_name].append(value)
                rep_results.append(result_row)

            except Exception as e:
                print(f"Error processing tracker {tracker_name} for plant {plant_name}, rep {rep_name}: {e}")
                traceback.print_exc()

    if not rep_results:
        # Nothing succeeded for this rep; don't reuse a row left over from a previous rep.
        print(f"No successful runs for rep {rep_name}; nothing written to {csv_file_path}")
        continue

    # Metric columns not yet in the header, in first-seen order (a set would make column order arbitrary).
    new_headers = list(dict.fromkeys(
        key for row in rep_results for key in row if key not in csv_headers
    ))
    if new_headers:
        # The header grows: rewrite the file, carrying over rows already written for earlier reps.
        previous_rows = []
        if Path(csv_file_path).exists():
            with open(csv_file_path, mode='r', newline='') as file:
                previous_rows = list(csv.DictReader(file))
        csv_headers.extend(new_headers)
        with open(csv_file_path, mode='w', newline='') as file:
            writer = csv.DictWriter(file, fieldnames=csv_headers)
            writer.writeheader()
            writer.writerows(previous_rows)
            writer.writerows(rep_results)
    else:
        with open(csv_file_path, mode='a', newline='') as file:
            writer = csv.DictWriter(file, fieldnames=csv_headers)
            writer.writerows(rep_results)

    print(f"Results for rep {rep_name} saved to {csv_file_path}")

In [None]:
import csv
# Step 1: drop exact duplicate rows from the per-run CSV and save a cleaned copy
df = pd.read_csv("results_tracking/results_csv/tracking_results.csv")
df_cleaned = df.drop_duplicates()
cleaned_csv_path = "results_tracking/results_csv/tracking_results_cleaned.csv"
df_cleaned.to_csv(cleaned_csv_path, index=False)
print(f"Duplicate rows removed and saved as '{cleaned_csv_path}'")

# Step 2: one summary row per (model, tracker) from the in-memory aggregated_results
if not aggregated_results:
    raise RuntimeError("aggregated_results is empty; run the main processing loop before summarising.")

# Define headers for the CSV file
average_headers = ["Model", "Tracker"] + list(next(iter(aggregated_results.values())).keys())
average_rows = []

# Count-like metrics (rounded to int when averaged)
count_metrics = {"idswitches", "groundtruthmaskscount", "falsenegatives", "falsepositives"}

for (model_name, tracker_name), metrics in aggregated_results.items():
    avg_row = {
        "Model": model_name,
        "Tracker": tracker_name
    }

    # Default for every metric: its mean over runs
    for metric_name, values in metrics.items():
        avg_value = sum(values) / len(values)
        if metric_name.lower() in count_metrics:
            avg_row[metric_name] = int(round(avg_value))
        else:
            avg_row[metric_name] = round(avg_value, 3)

    # Pooled quantities, looked up case-insensitively
    by_lower = {name.lower(): values for name, values in metrics.items()}
    total_id_switches = sum(by_lower.get("idswitches", []))
    total_fn = sum(by_lower.get("falsenegatives", []))
    total_fp = sum(by_lower.get("falsepositives", []))
    ground_truth_count = sum(by_lower.get("groundtruthmaskscount", []))
    framebasedious = by_lower.get("framebasediou", [])
    multi_object_tracking_precisions = by_lower.get("multiobjecttrackingprecision", [])

    # MOTA is recomputed from the pooled counts rather than averaging per-run MOTA values
    multi_object_tracking_accuracy = (
        1 - (total_fn + total_fp + total_id_switches) / ground_truth_count
        if ground_truth_count > 0 else 0.0
    )
    multi_object_tracking_precision = (
        sum(multi_object_tracking_precisions) / len(multi_object_tracking_precisions)
        if multi_object_tracking_precisions else 0.0
    )
    frame_based_iou = (
        sum(framebasedious) / len(framebasedious) if framebasedious else 0.0
    )

    # Counts are reported as totals; MOTA/MOTP/IoU as pooled values (these override the means above)
    avg_row["GroundTruthMasksCount"] = round(ground_truth_count, 3)
    avg_row["IDSwitches"] = round(total_id_switches, 3)
    avg_row["MultiObjectTrackingAccuracy"] = round(multi_object_tracking_accuracy, 3)
    avg_row["MultiObjectTrackingPrecision"] = round(multi_object_tracking_precision, 3)
    avg_row["FalseNegatives"] = round(total_fn, 3)
    avg_row["FalsePositives"] = round(total_fp, 3)
    avg_row["FrameBasedIOU"] = round(frame_based_iou, 3)

    average_rows.append(avg_row)

# Write the results to the CSV file
with open(total_csv_file_path, mode='w', newline='') as file:
    writer = csv.DictWriter(file, fieldnames=average_headers)
    writer.writeheader()
    writer.writerows(average_rows)

print(f"Averages saved to {total_csv_file_path}")

In [None]:
import csv
from collections import defaultdict

# Input: per-rep tracking results, one row per plant/rep/model/tracker.
total_csv_file_path = "results_tracking/results_csv/tracking_results_cleaned.csv"


def _summarize_tracking_metrics(metrics):
    """Reduce the per-rep metric lists of one (model, tracker) pair to summary values.

    Args:
        metrics: mapping of metric name -> list of per-rep float values.

    Returns:
        dict keyed by output column name, every value rounded to 3 decimals:
        - GroundTruthMasksCount, IDSwitches, FalseNegatives, FalsePositives:
          summed over reps.
        - MultiObjectTrackingPrecision, FrameBasedIOU: mean over reps.
        - MultiObjectTrackingAccuracy: recomputed from the summed counts as
          1 - (FN + FP + IDSW) / GT, or 0.0 when GT is 0.
        Metric names are matched case-insensitively; absent metrics count as 0.
    """
    total_id_switches = 0
    total_fn = 0
    total_fp = 0
    ground_truth_count = 0
    framebasedious = []
    multi_object_tracking_precisions = []

    for metric_name, values in metrics.items():
        key = metric_name.lower()
        if key == "idswitches":
            total_id_switches = sum(values)
        elif key == "falsenegatives":
            total_fn = sum(values)
        elif key == "falsepositives":
            total_fp = sum(values)
        elif key == "groundtruthmaskscount":
            ground_truth_count = sum(values)
        elif key == "framebasediou":
            framebasedious.extend(values)
        elif key == "multiobjecttrackingprecision":
            multi_object_tracking_precisions.extend(values)

    # MOTA is derived from the pooled counts rather than averaged per rep.
    multi_object_tracking_accuracy = (
        1 - (total_fn + total_fp + total_id_switches) / ground_truth_count
        if ground_truth_count > 0 else 0.0
    )
    multi_object_tracking_precision = (
        sum(multi_object_tracking_precisions) / len(multi_object_tracking_precisions)
        if multi_object_tracking_precisions else 0.0
    )
    frame_based_iou = (
        sum(framebasedious) / len(framebasedious) if framebasedious else 0.0
    )

    return {
        "GroundTruthMasksCount": round(ground_truth_count, 3),
        "MultiObjectTrackingAccuracy": round(multi_object_tracking_accuracy, 3),
        "IDSwitches": round(total_id_switches, 3),
        "MultiObjectTrackingPrecision": round(multi_object_tracking_precision, 3),
        "FalseNegatives": round(total_fn, 3),
        "FalsePositives": round(total_fp, 3),
        "FrameBasedIOU": round(frame_based_iou, 3),
    }


# Read the CSV and collect every numeric value per (model, tracker) and metric column.
aggregated_results = defaultdict(lambda: defaultdict(list))
with open(total_csv_file_path, mode='r', newline='') as file:
    reader = csv.DictReader(file)
    for row in reader:
        model_name = row["Model"]
        tracker_name = row["Tracker"]
        # Touch the group up front so every (model, tracker) pair in the file
        # gets an output row, even if none of its cells are numeric.
        group = aggregated_results[(model_name, tracker_name)]

        for metric_name, value in row.items():
            if metric_name in ("Model", "Tracker"):
                continue
            try:
                numeric_value = float(value)
            except (TypeError, ValueError):
                # ValueError: non-numeric text (e.g. the Plant/Rep columns, empty cells).
                # TypeError: DictReader fills missing trailing fields with None and
                # puts surplus fields in a list under the None key; float() rejects
                # both. Converting before indexing keeps such columns out of `group`.
                continue
            group[metric_name].append(numeric_value)

# Output columns, in the exact order they are written.
average_headers = [
    "Model",
    "Tracker",
    "GroundTruthMasksCount",
    "MultiObjectTrackingAccuracy",
    "IDSwitches",
    "MultiObjectTrackingPrecision",
    "FalseNegatives",
    "FalsePositives",
    "FrameBasedIOU"
]
average_rows = []

for (model_name, tracker_name), metrics in aggregated_results.items():
    avg_row = {
        "Model": model_name,
        "Tracker": tracker_name,
        **_summarize_tracking_metrics(metrics),
    }
    # Append only the specified fields, in header order.
    average_rows.append({key: avg_row.get(key, "") for key in average_headers})

# Write the results to the output CSV file
output_csv_file_path = "results_tracking/results_csv/total_tracking_results.csv"

with open(output_csv_file_path, mode='w', newline='') as file:
    writer = csv.DictWriter(file, fieldnames=average_headers)
    writer.writeheader()
    writer.writerows(average_rows)

print(f"Averages saved to {output_csv_file_path}")

In [None]:
import pandas as pd

# Load the per-(model, tracker) averaged tracking metrics.
total_csv_file_path = "results_tracking/results_csv/total_tracking.csv"
data = pd.read_csv(total_csv_file_path)

# For each model keep the row of the tracker with the highest MultiObjectTrackingAccuracy.
best_trackers = data.loc[data.groupby("Model")["MultiObjectTrackingAccuracy"].idxmax()]

# Save the selection for later inspection.
best_trackers_csv_path = "results_tracking/results_csv/best_trackers.csv"
best_trackers.to_csv(best_trackers_csv_path, index=False)

# The selection is already in memory, so it is not read back from disk. These
# names are kept bound as before; the reset index matches what re-reading the
# CSV (written with index=False) would have produced.
best_models_csv_path = best_trackers_csv_path
best_models_df = best_trackers.reset_index(drop=True)

# Map each model to its best tracker (used to filter trackers in the tracking loop).
best_tracker_per_model = dict(zip(best_models_df["Model"], best_models_df["Tracker"]))
print("Best trackers for each model:")
print(best_trackers)

## Select Best Combination

In [None]:
# Re-run tracking on every plant/rep, using for each model in `models` only the
# tracker selected in `best_tracker_per_model`. Per-rep metrics are written to
# results_tracking/results_csv/tracking_results.csv and mask visualizations to
# results_tracking/<model>/<tracker>/<plant>/<rep>/.

# Configuration for YOLO and SAM2
config = {
    "yolo_model": "../models/yolo_11mseg_finetuned_stage4.pt",
    "sam2_cfg": "configs/sam2.1/sam2.1_hiera_b+.yaml",
    "sam2_checkpoint": "../models/sam2.1l_finetuned.pt",
    "yolo_threshold": 0.5,
}
device = "cuda:1" if torch.cuda.is_available() else "cpu"

# Build the SAM2 model.
# NOTE(review): the loop below iterates `models` (defined elsewhere); confirm
# whether model_sam2 is meant to be part of that mapping.
model_sam2 = sam2_model(config, device)

csv_headers = ["Plant", "Rep", "Model", "Tracker"]

# Create results_tracking folder
results_tracking_dir = Path("results_tracking")
results_tracking_dir.mkdir(exist_ok=True)

# Ensure the 'results_csv' directory exists
csv_dir = Path("results_tracking/results_csv")
csv_dir.mkdir(parents=True, exist_ok=True)

# Define the CSV file paths
csv_file_path = csv_dir / "tracking_results.csv"
total_csv_file_path = csv_dir / "total_tracking.csv"

# Start the per-rep CSV with only the identifying columns; metric columns are
# added when the first results arrive.
with open(csv_file_path, mode='w', newline='') as file:
    writer = csv.DictWriter(file, fieldnames=csv_headers)
    writer.writeheader()

# Process data
processed_data = process_plants(subselected_ds)

# Initialize dictionary for aggregating metrics
aggregated_results = defaultdict(lambda: defaultdict(list))

unique_color_mapping = {}

# Every row written so far. When a new metric column appears the whole file is
# rewritten, and it must include earlier reps, not just the current one.
all_results = []

# Main processing loop
for data in tqdm(iterate_and_return_files_with_validation(processed_data), desc="Processing plants"):
    plant_name = data['plant_name']
    rep_name = data['rep_name']
    image_files = data['image_files']
    mask_files = data['mask_files']
    # Collect results for this rep: one row per (model, tracker) run.
    rep_results = []

    # Iterate over all models
    for model_name, model in tqdm(models.items(), desc=f"Models for {rep_name}", leave=False):
        print("model_name", model_name)
        # Set mode based on model type
        mode = "sam" if model_name == "sam2" else "detectron2" if model_name == "detectron2" else "yolo"

        # `tracker_config` (not `config`) so the model configuration above is not clobbered.
        for tracker_config in tqdm(trackers_config, desc=f"Trackers for {model_name}", leave=False):
            tracker_name = tracker_config["name"]
            # Only evaluate the tracker chosen as best for this model.
            if best_tracker_per_model[model_name] != tracker_name:
                continue

            tracker_dir = results_tracking_dir / model_name / tracker_name / plant_name / rep_name
            sub_mask_black = tracker_dir / "all_masks_on_black"
            sub_mask_image = tracker_dir / "all_masks_on_image"
            for directory in (tracker_dir, sub_mask_black, sub_mask_image):
                directory.mkdir(parents=True, exist_ok=True)

            # Initialize the tracker; botsort uses ObjectTracker and takes no config params.
            if tracker_name == "botsort":
                tracker = ObjectTracker(iou_threshold=0.3)
            else:
                tracker = initialize_tracker(tracker_name, **tracker_config["params"])

            # Process tracking segmentation masks
            if mode == "detectron2":
                tracking_seg_masks, original_images = process_tracking_custom_seg_masks(image_files, model, tracker, generate_unique_color, mode=mode)
            else:
                tracking_seg_masks, original_images = process_tracking_seg_masks(image_files, model, tracker, generate_unique_color, mode=mode)

            aligned_ground_truth = process_ground_truth_masks(mask_files, tracking_seg_masks)

            # Evaluate tracking performance
            metrics = evaluate_tracking_performance_with_id_tracking(aligned_ground_truth, tracking_seg_masks, overlap_threshold=0.1)

            # Log and store results
            print(f"Plant: {plant_name} | Rep: {rep_name} | Model: {model_name} | Tracker: {tracker_name}")
            result_row = {
                "Plant": plant_name,
                "Rep": rep_name,
                "Model": model_name,
                "Tracker": tracker_name,
            }
            for metric_name, value in metrics.items():
                result_row[metric_name] = value
                aggregated_results[(model_name, tracker_name)][metric_name].append(value)
            # Exactly one row per run. Appending inside the frame loop, as before,
            # duplicated the row once per frame and inflated downstream sums.
            rep_results.append(result_row)

            save_instance_masks(tracking_seg_masks, original_images, tracker_dir)
            # Process each frame and save visualizations
            for frame_index in range(tracking_seg_masks.shape[0]):
                instance_layers = tracking_seg_masks[frame_index]
                original_image = original_images[frame_index]
                max_mask = instance_layers.max(axis=0)

                # Save visualization for the current frame
                image_mask_path = sub_mask_image / f"frame_{frame_index:04d}.png"
                black_mask_path = sub_mask_black / f"frame_{frame_index:04d}.png"

                # Remap non-zero instance IDs to 1, 2, 3, ...; background (0) stays 0.
                # Zero is filtered before enumerating so numbering really starts at 1.
                # NOTE(review): the mapping is rebuilt per frame, so an instance's new ID
                # (and color) depends on which instances are present in that frame.
                # Confirm this is intended for tracking visualizations.
                instance_ids = [old_id for old_id in np.unique(max_mask) if old_id != 0]
                unique_id_mapping = {old_id: new_id for new_id, old_id in enumerate(instance_ids, start=1)}
                remapped_mask = np.zeros_like(max_mask)
                for old_id, new_id in unique_id_mapping.items():
                    remapped_mask[max_mask == old_id] = new_id

                # Generate unique colors for remapped IDs
                unique_colors = {instance_id: generate_unique_color(instance_id) for instance_id in unique_id_mapping.values()}

                save_visualization(image_mask_path, black_mask_path, frame_index, remapped_mask, unique_colors, original_image)

    # Nothing ran for this rep, so there is nothing to write.
    if not rep_results:
        continue

    all_results.extend(rep_results)

    # Add metric columns not seen before, in first-seen order (deterministic).
    new_headers = list(dict.fromkeys(
        key for row in rep_results for key in row if key not in csv_headers
    ))
    if new_headers:
        csv_headers.extend(new_headers)
        # The header changed: rewrite the file with every row so far. Older rows
        # lacking the new columns get DictWriter's default restval ("").
        with open(csv_file_path, mode='w', newline='') as file:
            writer = csv.DictWriter(file, fieldnames=csv_headers)
            writer.writeheader()
            writer.writerows(all_results)
    else:
        with open(csv_file_path, mode='a', newline='') as file:
            writer = csv.DictWriter(file, fieldnames=csv_headers)
            writer.writerows(rep_results)

    print(f"Results for rep {rep_name} saved to {csv_file_path}")