In [None]:
# import required libraries
import os
import numpy as np
from tqdm import tqdm
from glob import glob
from numpy import zeros
from numpy.random import randint
import torch
import os
import cv2
from statistics import mean
from torch.nn.functional import threshold, normalize

# Data Viz
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import cv2

In [None]:
! pip install torch torchvision &> /dev/null
! pip install opencv-python pycocotools matplotlib onnxruntime onnx &> /dev/null

# ==== Download Pretrained SAM Model Weights ====
# Download the ViT versions of the SAM model weights from Facebook's public storage.
! pip install git+https://github.com/facebookresearch/segment-anything.git &> /dev/null
! wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth &> /dev/null
! wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth &> /dev/null
! wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth &> /dev/null

In [None]:
# === CONFIGURATION ===
# Set the path to your training images and labels
image_path = "/content/drive/MyDrive/canal_project/27_11_2025/train/images"   # <-- Update this path
label_path = "/content/drive/MyDrive/canal_project/27_11_2025/train/masks"    # <-- Update this path
# === Load Image Paths ===
# Count total number of image files (e.g., .jpg format)
all_image_paths = sorted(glob(os.path.join(image_path, "*.jpg")))  # Use .png if needed
total_images = len(all_image_paths)
print(f"Total Number of Images: {total_images}")

# === Load Label Paths ===
# Count total number of label files (e.g., .png format for segmentation masks)
all_label_paths = sorted(glob(os.path.join(label_path, "*.png")))
total_labels = len(all_label_paths)
print(f"Total Number of Labels: {total_labels}")

# === Match Images and Labels ===
# Downstream cells pair image i with mask i purely by sorted position, so check that
# position i really holds the same sample in both lists. Masks are expected to be named
# after their image plus a "_mask" suffix ("X.jpg" <-> "X_mask.png"); a mask without
# that suffix must share the image's stem exactly. Adapt this check if your naming differs.
train_image_paths = all_image_paths[:total_images]
train_label_paths = all_label_paths[:total_labels]

train_image_stems = [os.path.splitext(os.path.basename(p))[0] for p in train_image_paths]
train_label_stems = [
    stem[:-len("_mask")] if stem.endswith("_mask") else stem
    for stem in (os.path.splitext(os.path.basename(p))[0] for p in train_label_paths)
]
if train_image_stems != train_label_stems:
    unmatched = sorted(set(train_image_stems) ^ set(train_label_stems))
    raise ValueError(
        f"Training images and masks do not pair up by name "
        f"({total_images} images, {total_labels} masks); "
        f"unmatched names include: {unmatched[:5]}"
    )

# Preview label paths (for verification)
print("Sample label paths:")
for path in train_label_paths[:5]:
    print(path)

Total Number of Images: 186
Total Number of Labels: 186
Sample label paths:
/content/drive/MyDrive/canal_project/27_11_2025/train/masks/image_capture_2025-04-05_11-45-56_jpg.rf.3d31effc5cd44f2feb5ca1c2c219346b_mask.png
/content/drive/MyDrive/canal_project/27_11_2025/train/masks/image_capture_2025-04-06_10-10-41_jpg.rf.2d25833c6629824eb13665eb4b7a794c_mask.png
/content/drive/MyDrive/canal_project/27_11_2025/train/masks/image_capture_2025-04-06_11-10-42_jpg.rf.ddc1a4d0ad9d75799e5114aa78022e19_mask.png
/content/drive/MyDrive/canal_project/27_11_2025/train/masks/image_capture_2025-04-06_11-11-48_jpg.rf.6b1a6629a791c2c0c47e5ed49cca16ee_mask.png
/content/drive/MyDrive/canal_project/27_11_2025/train/masks/image_capture_2025-04-06_12-12-56_jpg.rf.a087e1490f708cdb45e28555f3786aaf_mask.png


In [None]:
# === CONFIGURATION ===
# Set the path to your validation images and labels
val_image_path = "/content/drive/MyDrive/canal_project/27_11_2025/valid/images"    # <-- Update this path
val_label_path = "/content/drive/MyDrive/canal_project/27_11_2025/valid/masks"      # <-- Update this path

# === Load Validation Image Paths ===
# Collect and sort all .jpg image files in the validation folder
val_all_image_paths = sorted(glob(os.path.join(val_image_path, "*.jpg")))
val_total_images = len(val_all_image_paths)
print(f"Total Number of Validation Images: {val_total_images}")

# === Load Validation Label Paths ===
# Collect and sort all .png label files in the validation folder
val_all_label_paths = sorted(glob(os.path.join(val_label_path, "*.png")))
val_total_labels = len(val_all_label_paths)
print(f"Total Number of Validation Labels: {val_total_labels}")

# === Match Images and Labels (by order) ===
# Image i is paired with mask i by sorted position, so verify the names line up.
# Masks are expected to be "<image stem>_mask.png"; a mask without the "_mask" suffix
# must share the image's stem exactly. Adapt this check if your naming differs.
Val1_image_paths = val_all_image_paths[:val_total_images]
Val1_label_paths = val_all_label_paths[:val_total_labels]

val_image_stems = [os.path.splitext(os.path.basename(p))[0] for p in Val1_image_paths]
val_label_stems = [
    stem[:-len("_mask")] if stem.endswith("_mask") else stem
    for stem in (os.path.splitext(os.path.basename(p))[0] for p in Val1_label_paths)
]
if val_image_stems != val_label_stems:
    unmatched = sorted(set(val_image_stems) ^ set(val_label_stems))
    raise ValueError(
        f"Validation images and masks do not pair up by name "
        f"({val_total_images} images, {val_total_labels} masks); "
        f"unmatched names include: {unmatched[:5]}"
    )

# Preview a few label paths to confirm loading
print("Sample validation label paths:")
for path in Val1_label_paths[:5]:
    print(path)

Total Number of Validation Images: 39
Total Number of Validation Labels: 39
Sample validation label paths:
/content/drive/MyDrive/canal_project/27_11_2025/valid/masks/image_capture_2025-04-05_11-44-49_jpg.rf.b321be12d6ab3078ae5064300cfb22ac_mask.png
/content/drive/MyDrive/canal_project/27_11_2025/valid/masks/image_capture_2025-04-07_11-37-39_jpg.rf.2fabed3881bc90ecc487ee025e0509cb_mask.png
/content/drive/MyDrive/canal_project/27_11_2025/valid/masks/image_capture_2025-04-08_11-04-35_jpg.rf.da8fbcd31f962ae95b3d5301f544de7e_mask.png
/content/drive/MyDrive/canal_project/27_11_2025/valid/masks/image_capture_2025-04-09_11-31-33_jpg.rf.f831b52a22d28ffa7e8e351d3e28381b_mask.png
/content/drive/MyDrive/canal_project/27_11_2025/valid/masks/image_capture_2025-05-19_12-47-22_jpg.rf.7deac398fbf7675b3cb73df0620b94d2_mask.png


In [None]:
# Target (width, height) passed to cv2.resize for every image and mask below.
# To keep the original input size, set this to None instead of skipping the cell:
# every later cell reads `desired_size` and checks `is not None`.
desired_size = (640, 640)

In [None]:
# === Load and Process Ground Truth Masks ===
# Maps sample index -> boolean mask (True where the label pixel is > 0). The index is
# the position in train_label_paths, matching the training image index.
ground_truth_masks = {}

for idx, label_file in enumerate(train_label_paths):
    # Read the label mask in grayscale
    gt_grayscale = cv2.imread(label_file, cv2.IMREAD_GRAYSCALE)
    if gt_grayscale is None:
        # cv2.imread returns None (instead of raising) for missing/unreadable files.
        raise FileNotFoundError(f"Could not read training mask: {label_file}")

    # Resize the mask if desired_size is specified. Nearest-neighbour keeps label values
    # intact; linear interpolation blends boundary pixels into small non-zero values
    # that the `> 0` test below would turn into foreground, dilating the mask.
    if desired_size is not None:
        gt_grayscale = cv2.resize(gt_grayscale, desired_size, interpolation=cv2.INTER_NEAREST)

    # Convert to binary mask (True where pixel > 0)
    ground_truth_masks[idx] = (gt_grayscale > 0)

# Optional: Print number of masks and preview a sample
print(f"Total ground truth masks loaded: {len(ground_truth_masks)}")
if ground_truth_masks:
    print("Example binary mask shape:", ground_truth_masks[0].shape)


Total ground truth masks loaded: 186
Example binary mask shape: (640, 640)


In [None]:
# === Load and Process Validation Ground Truth Masks ===
# Maps validation index -> boolean mask (True where the label pixel is > 0), in the
# order of Val1_label_paths.
ground_truth_masksv = {}

for idx, label_file in enumerate(Val1_label_paths):
    # Read the validation label mask in grayscale
    gt_grayscale = cv2.imread(label_file, cv2.IMREAD_GRAYSCALE)
    if gt_grayscale is None:
        # cv2.imread returns None (instead of raising) for missing/unreadable files.
        raise FileNotFoundError(f"Could not read validation mask: {label_file}")

    # Resize with nearest-neighbour so boundary pixels are not blended into small
    # non-zero values that `> 0` would count as foreground.
    if desired_size is not None:
        gt_grayscale = cv2.resize(gt_grayscale, desired_size, interpolation=cv2.INTER_NEAREST)

    # Convert to binary mask: True where pixel > 0
    ground_truth_masksv[idx] = (gt_grayscale > 0)

# Print summary
print(f"Total validation ground truth masks loaded: {len(ground_truth_masksv)}")
if ground_truth_masksv:
    print("Example validation mask shape:", ground_truth_masksv[0].shape)

Total validation ground truth masks loaded: 39
Example validation mask shape: (640, 640)


In [None]:
# Candidate configuration: SAM ViT-B backbone.
# NOTE(review): the model-loading cell below reassigns model_type, checkpoint and
# device unconditionally, so these values do not reach the loaded model.
model_type, checkpoint, device = 'vit_b', 'sam_vit_b_01ec64.pth', 'cuda:0'

In [None]:
# Candidate configuration: SAM ViT-L backbone.
# NOTE(review): overwritten by the model-loading cell below, which reassigns all three.
model_type, checkpoint, device = 'vit_l', 'sam_vit_l_0b3195.pth', 'cuda:0'

In [None]:
# Candidate configuration: SAM ViT-H backbone.
# NOTE(review): overwritten by the model-loading cell below, which reassigns all three.
model_type, checkpoint, device = "vit_h", "sam_vit_h_4b8939.pth", 'cuda:0'

In [None]:
# === Import Required SAM Modules ===
# Requires the segment-anything package installed by the setup cell.
from segment_anything import SamPredictor, sam_model_registry
import torch

# === Configuration ===
# Registry key of the architecture; it must match the checkpoint below
# ("vit_b", "vit_l" or "vit_h").
model_type = "vit_b"
# Pretrained SAM weights (.pth) for that architecture.
checkpoint = "/content/sam_vit_b_01ec64.pth"  # <-- Update with actual path

# === Load Model ===
# Pick the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the chosen architecture with the checkpoint weights and move it to `device`.
sam_model = sam_model_registry[model_type](checkpoint=checkpoint)
sam_model.to(device)

# Training mode for fine-tuning; inference code should call .eval() instead.
sam_model.train()

print(f"SAM model ({model_type}) loaded on {device} and set to training mode.")

SAM model (vit_b) loaded on cuda and set to training mode.


In [None]:
from collections import defaultdict
from segment_anything.utils.transforms import ResizeLongestSide

# Preprocessed training inputs, keyed by the same index as ground_truth_masks:
#   'image'               -> output of sam_model.preprocess, stored on the CPU
#   'input_size'          -> (H, W) after ResizeLongestSide, before SAM's padding
#   'original_image_size' -> (H, W) of the RGB image handed to ResizeLongestSide
transformed_data = defaultdict(dict)

# Transformer that resizes image while preserving aspect ratio
resize_transform = ResizeLongestSide(sam_model.image_encoder.img_size)

# === Image Preprocessing Loop ===
for idx, image_file in enumerate(train_image_paths):
    # Load image from path (OpenCV returns BGR, or None if the file can't be read)
    image = cv2.imread(image_file)
    if image is None:
        raise FileNotFoundError(f"Could not read training image: {image_file}")

    # Resize if a fixed input size is specified (e.g., for training consistency)
    if desired_size is not None:
        image = cv2.resize(image, desired_size, interpolation=cv2.INTER_LINEAR)

    # Convert BGR (OpenCV default) to RGB (SAM model expects RGB)
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Apply SAM's resizing transformation to match its input constraints
    input_image_np = resize_transform.apply_image(image_rgb)

    # HWC NumPy array -> 1x3xHxW tensor on the model's device
    input_image_tensor = torch.as_tensor(input_image_np, device=device)
    input_image_tensor = input_image_tensor.permute(2, 0, 1).contiguous()[None, :, :, :]

    # Normalize/pad with SAM's own preprocessing on the device, then keep the result on
    # the CPU so the whole training set does not occupy GPU memory; the training step
    # moves each image back with `.to(device)` when it is used.
    input_tensor = sam_model.preprocess(input_image_tensor).cpu()

    # Store processed data
    transformed_data[idx]['image'] = input_tensor
    transformed_data[idx]['input_size'] = input_image_tensor.shape[-2:]
    transformed_data[idx]['original_image_size'] = image_rgb.shape[:2]

print(f"Processed {len(transformed_data)} training images for SAM input.")

Processed 186 training images for SAM input.


In [None]:
# === Training Hyperparameters ===
num_epochs = 5      # passes over the training set
batch_size = 32     # samples grouped per logged batch; train_on_batch still steps the optimizer once per image
lr = 1e-5           # Adam learning rate
wd = 0              # weight decay (L2 penalty); 0 disables it

# === Optimizer Setup ===
# Only the mask decoder's parameters are given to Adam. The image and prompt encoders
# are not frozen via requires_grad; they stay unchanged because the training step runs
# them under torch.no_grad() and the optimizer never receives their parameters.
optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=lr, weight_decay=wd)

# === Loss Function ===
# Binary cross-entropy on raw logits (the sigmoid is applied inside the loss).
loss_fn = torch.nn.BCEWithLogitsLoss()

# === Device Setup ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === Ground Truth Mask Keys ===
# Sample indices used to look up masks and preprocessed images.
keys_train, keys_valid = list(ground_truth_masks.keys()), list(ground_truth_masksv.keys())

print(f"Using device: {device}")
print(f"Training on {len(keys_train)} images, validating on {len(keys_valid)} images")

Using device: cuda
Training on 186 images, validating on 39 images


In [None]:

from torch.utils.data import DataLoader

# === Validation DataLoader Setup ===
# NOTE(review): the dataset is the bare list of validation image paths, so each batch
# is a list of path strings; it needs a proper Dataset before it can feed the model.
# The visible training loop does not use this loader.
val_loader = DataLoader(Val1_image_paths, batch_size=batch_size, shuffle=False)

# === Basic Validation Dataset Checks ===
num_val_examples = len(Val1_image_paths)

for line in (
    f"Number of validation examples: {num_val_examples}",
    f"Number of examples in validation dataset (via DataLoader): {len(val_loader.dataset)}",
    f"Number of batches in validation loader: {len(val_loader)}",
):
    print(line)

# Stop early if there is nothing to validate on.
if not num_val_examples:
    raise ValueError("The validation dataset is empty. Please check your data paths.")

Number of validation examples: 39
Number of examples in validation dataset (via DataLoader): 39
Number of batches in validation loader: 2


In [None]:
# Release unoccupied cached GPU memory held by PyTorch's allocator before training starts.
torch.cuda.empty_cache()


In [None]:
# === Accuracy Calculation Function ===
def calculate_accuracy(predictions, targets):
    """
    Pixel accuracy of thresholded predictions.

    Args:
        predictions: tensor of probabilities; values > 0.5 count as foreground.
        targets: tensor of 0/1 labels, comparable element-wise with `predictions`.

    Returns:
        Python float: the fraction of elements where the thresholded prediction equals the target.
    """
    matches = (predictions > 0.5).float().eq(targets)
    return matches.float().mean().item()

# === Batch Training Function ===
def train_on_batch(keys, batch_start, batch_end):
    """
    Fine-tune the SAM mask decoder on the samples keys[batch_start:batch_end].

    Every sample gets its own forward pass, loss, backward pass and optimizer step;
    the "batch" only groups samples for reporting.

    Returns:
        (losses, accuracies): per-sample lists of BCE loss values and pixel accuracies.
    """
    losses_out, accuracies_out = [], []

    for key in keys[batch_start:batch_end]:
        sample = transformed_data[key]
        image_tensor = sample['image'].to(device)
        input_size = sample['input_size']
        original_size = sample['original_image_size']

        # Encoders run without gradient tracking; only the decoder is trained.
        with torch.no_grad():
            embedding = sam_model.image_encoder(image_tensor)
            sparse_emb, dense_emb = sam_model.prompt_encoder(
                points=None, boxes=None, masks=None
            )

        low_res_logits, _ = sam_model.mask_decoder(
            image_embeddings=embedding,
            image_pe=sam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_emb,
            dense_prompt_embeddings=dense_emb,
            multimask_output=False
        )

        # Map the decoder logits back to the pre-transform image size.
        logits = sam_model.postprocess_masks(low_res_logits, input_size, original_size).to(device)

        # Bring the ground truth to the logits' (H, W); cv2.resize takes (W, H).
        out_h, out_w = logits.shape[-2:]
        gt_resized = cv2.resize(
            ground_truth_masks[key].astype(np.uint8),
            (out_w, out_h),
            interpolation=cv2.INTER_NEAREST,
        )
        target = torch.tensor(gt_resized, dtype=torch.float32, device=device)[None, None]

        loss = loss_fn(logits, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses_out.append(loss.item())
        accuracies_out.append(calculate_accuracy(torch.sigmoid(logits), target))

    return losses_out, accuracies_out

# === Epoch Training Loop ===
# One mean training loss / accuracy per epoch.
losses, accuracies = [], []

# keys_train holds the training sample indices (built from ground_truth_masks).
if not keys_train:
    raise ValueError("No training samples found; check the training data paths.")

for epoch in range(num_epochs):
    # Per-sample values collected over the whole epoch.
    epoch_losses = []
    epoch_accuracies = []

    print(f"\n--- EPOCH {epoch + 1}/{num_epochs} ---")

    for batch_start in range(0, len(keys_train), batch_size):
        batch_end = min(batch_start + batch_size, len(keys_train))
        batch_losses, batch_accuracies = train_on_batch(keys_train, batch_start, batch_end)

        batch_loss = mean(batch_losses)
        batch_accuracy = mean(batch_accuracies)
        # Keep per-sample values for both metrics so the epoch means weight every image
        # equally (a shorter final batch is not over-weighted in the loss).
        epoch_losses.extend(batch_losses)
        epoch_accuracies.extend(batch_accuracies)

        print(f'Batch [{batch_start + 1}–{batch_end}] | Loss: {batch_loss:.6f} | Accuracy: {batch_accuracy:.4f}')

    # === End of Epoch ===
    mean_train_loss = mean(epoch_losses)
    mean_train_accuracy = mean(epoch_accuracies)
    losses.append(mean_train_loss)
    accuracies.append(mean_train_accuracy)

    print(f'\nEpoch {epoch + 1} Summary:')
    print(f'➤ Mean Training Loss: {mean_train_loss:.6f}')
    print(f'➤ Mean Training Accuracy: {mean_train_accuracy:.4f}')

    # TODO(review): keys_valid / ground_truth_masksv are loaded but no validation pass runs here.

    # Clear cache to manage memory
    torch.cuda.empty_cache()


--- EPOCH 1/5 ---
Batch [1–32] | Loss: 0.983773 | Accuracy: 0.9191
Batch [33–64] | Loss: 0.037950 | Accuracy: 0.9905
Batch [65–96] | Loss: 0.042169 | Accuracy: 0.9853
Batch [97–128] | Loss: 0.031729 | Accuracy: 0.9930
Batch [129–160] | Loss: 0.026552 | Accuracy: 0.9918
Batch [161–186] | Loss: 0.015162 | Accuracy: 0.9956

Epoch 1 Summary:
➤ Mean Training Loss: 0.189556
➤ Mean Training Accuracy: 0.9787

--- EPOCH 2/5 ---
Batch [1–32] | Loss: 0.015650 | Accuracy: 0.9955
Batch [33–64] | Loss: 0.014300 | Accuracy: 0.9957
Batch [65–96] | Loss: 0.017627 | Accuracy: 0.9941
Batch [97–128] | Loss: 0.014464 | Accuracy: 0.9953
Batch [129–160] | Loss: 0.019018 | Accuracy: 0.9925
Batch [161–186] | Loss: 0.012470 | Accuracy: 0.9957

Epoch 2 Summary:
➤ Mean Training Loss: 0.015588
➤ Mean Training Accuracy: 0.9948

--- EPOCH 3/5 ---
Batch [1–32] | Loss: 0.012553 | Accuracy: 0.9959
Batch [33–64] | Loss: 0.011836 | Accuracy: 0.9959
Batch [65–96] | Loss: 0.014567 | Accuracy: 0.9948
Batch [97–128] | Loss:

In [None]:
# Quick sanity check on the validation ground-truth container.
print(type(ground_truth_masksv))
if isinstance(ground_truth_masksv, dict):
    print(ground_truth_masksv.keys())
else:
    print(len(ground_truth_masksv))

<class 'dict'>
dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38])


In [None]:

# Set the paths to your test images and labels
test_image_dir = "/content/drive/MyDrive/canal_project/27_11_2025/test/images"  # <-- Update this path
test_label_dir = "/content/drive/MyDrive/canal_project/27_11_2025/test/masks"   # <-- Update this path

# === Load and Sort Test Image Paths ===
# Collect all test image files (e.g., .jpg)
all_test_image_paths = sorted(glob(os.path.join(test_image_dir, "*.jpg")))
test_total_images = len(all_test_image_paths)
print(f"Total Number of Test Images: {test_total_images}")

# === Load and Sort Test Label Paths ===
# Collect all test label files (e.g., .png masks)
all_test_label_paths = sorted(glob(os.path.join(test_label_dir, "*.png")))
test_total_labels = len(all_test_label_paths)
print(f"Total Number of Test Labels: {test_total_labels}")

# === Match Image and Label Paths ===
# Evaluation pairs prediction i with mask i by sorted position, so verify the names
# line up ("X.jpg" <-> "X_mask.png", or identical stems when there is no "_mask" suffix).
Test_image_paths = all_test_image_paths[:test_total_images]
Test_label_paths = all_test_label_paths[:test_total_labels]

test_image_stems = [os.path.splitext(os.path.basename(p))[0] for p in Test_image_paths]
test_label_stems = [
    stem[:-len("_mask")] if stem.endswith("_mask") else stem
    for stem in (os.path.splitext(os.path.basename(p))[0] for p in Test_label_paths)
]
if test_image_stems != test_label_stems:
    unmatched = sorted(set(test_image_stems) ^ set(test_label_stems))
    raise ValueError(
        f"Test images and masks do not pair up by name "
        f"({test_total_images} images, {test_total_labels} masks); "
        f"unmatched names include: {unmatched[:5]}"
    )

# Optional: Print a few samples to verify
print("Sample test image path:", Test_image_paths[0] if Test_image_paths else "No images found")
print("Sample test label path:", Test_label_paths[0] if Test_label_paths else "No labels found")

Total Number of Test Images: 40
Total Number of Test Labels: 40
Sample test image path: /content/drive/MyDrive/canal_project/27_11_2025/test/images/image_capture_2025-04-06_12-11-49_jpg.rf.826be63e91371c207d97f6472d9d495b.jpg
Sample test label path: /content/drive/MyDrive/canal_project/27_11_2025/test/masks/image_capture_2025-04-06_12-11-49_jpg.rf.826be63e91371c207d97f6472d9d495b_mask.png


In [None]:

# Ground-truth test masks as float32 arrays of 0.0 / 1.0, keyed by test-sample index.
# NOTE(review): the train/validation loaders read masks in grayscale (any non-zero
# pixel is foreground), whereas this reads only the red channel — confirm the test
# masks are encoded so that both interpretations agree.
ground_truth_test_masks = {}

for idx, label_file in enumerate(Test_label_paths):
    # OpenCV loads channels as BGR, so index 2 is red.
    gt_color = cv2.imread(label_file)
    binary_mask = np.asarray(gt_color[:, :, 2] > 0, dtype=np.float32)

    # Nearest-neighbour keeps the mask strictly 0/1 when resizing.
    if desired_size is not None:
        binary_mask = cv2.resize(binary_mask, desired_size, interpolation=cv2.INTER_NEAREST)

    ground_truth_test_masks[idx] = binary_mask

print(f"Loaded {len(ground_truth_test_masks)} ground truth test masks.")

Loaded 40 ground truth test masks.


In [None]:

# === Inference with SAM Predictor on Test Set ===
from segment_anything import SamPredictor

# Wrap the fine-tuned model in a predictor; no earlier cell defines `predictor_tuned`.
# The model is switched to eval mode for inference, as the model-loading cell advises.
sam_model.eval()
predictor_tuned = SamPredictor(sam_model)

masks_tuned_list = {}   # test index -> predicted mask, float32 0.0 / 1.0
images_tuned_list = {}  # test index -> RGB image fed to the predictor

for idx, image_file in enumerate(Test_image_paths):
    # === Load and Preprocess Image ===
    image = cv2.imread(image_file)
    if image is None:
        raise FileNotFoundError(f"Could not read test image: {image_file}")
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    if desired_size is not None:
        image_rgb = cv2.resize(image_rgb, desired_size, interpolation=cv2.INTER_LINEAR)

    # === Set the image for SAM predictor (expects RGB by default) ===
    predictor_tuned.set_image(image_rgb)

    # === Predict segmentation mask without prompts, matching how the decoder was trained ===
    masks_tuned, _, _ = predictor_tuned.predict(
        point_coords=None,
        box=None,
        multimask_output=False,  # Only get the most confident mask
    )

    # === Extract and post-process the first predicted mask ===
    mask_np = masks_tuned[0, :, :]
    binary_mask = (mask_np > 0).astype(np.float32)

    # === Store results ===
    images_tuned_list[idx] = image_rgb
    masks_tuned_list[idx] = binary_mask

print(f"Inference complete on {len(Test_image_paths)} test images.")


In [None]:

import matplotlib.pyplot as plt
import numpy as np

# === Grid Configuration ===
n_images = len(images_tuned_list)
n_cols = 4  # Number of images per row
# Ceil-divide into rows, but keep at least one row: plt.subplots(0, n_cols) raises.
n_rows = max(1, -(-n_images // n_cols))

# Create a figure with subplots
fig, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))

# If axs is 1D (e.g., only 1 row), convert to 2D for consistency
axs = np.atleast_2d(axs)

# === Iterate and Plot ===
for i in range(n_rows):
    for j in range(n_cols):
        index = i * n_cols + j
        ax = axs[i, j]

        if index < n_images:
            # Display the RGB image
            ax.imshow(images_tuned_list[index], interpolation='none')

            # RGBA overlay: pure blue with alpha 0.5 on mask pixels and alpha 0 elsewhere,
            # so only the predicted region is tinted. (An RGB overlay drawn with a global
            # alpha would also blend black over every unmasked pixel, darkening the image.)
            mask = masks_tuned_list[index]
            blue_mask_rgba = np.zeros((*mask.shape, 4), dtype=np.float32)
            blue_mask_rgba[..., 2] = 1.0
            blue_mask_rgba[..., 3] = 0.5 * mask
            ax.imshow(blue_mask_rgba, interpolation='none')

        # Remove axes ticks
        ax.axis('off')

# === Final Layout ===
plt.subplots_adjust(wspace=0.03, hspace=0.03)
plt.tight_layout()
plt.show()

In [None]:
import torch
import numpy as np
from sklearn.metrics import auc, roc_curve

# === Binary Metrics for One Prediction ===
def binary_segmentation_metrics(predictions, targets):
    """
    Computes binary segmentation metrics for a single predicted mask vs ground truth.

    Inputs:
        predictions (numpy array): predicted mask; values > 0.5 count as foreground.
        targets (numpy array): ground truth mask with 0/1 values.
    Returns:
        Tuple in this order: accuracy, precision, recall, F1-score, IoU, Cohen's kappa,
        FP, FN, TP, TN, dice. Ratio metrics add eps=1e-5 to avoid division by zero.
    """
    # Drop singleton dims and binarize
    predictions = predictions.squeeze()
    targets = targets.squeeze()

    predictions_binary = (predictions > 0.5).astype(int)
    targets_binary = targets.astype(int)

    # Confusion matrix components (pixel counts)
    TP = np.sum((predictions_binary == 1) & (targets_binary == 1))
    FP = np.sum((predictions_binary == 1) & (targets_binary == 0))
    FN = np.sum((predictions_binary == 0) & (targets_binary == 1))
    TN = np.sum((predictions_binary == 0) & (targets_binary == 0))

    # Metrics with small epsilon to avoid division by zero
    eps = 1e-5
    accuracy = (TP + TN + eps) / (TP + FP + FN + TN + eps)
    precision = (TP + eps) / (TP + FP + eps)
    recall = (TP + eps) / (TP + FN + eps)
    f_score = 2 * (precision * recall) / (precision + recall + eps)
    dice = (2 * TP + eps) / (2 * TP + FP + FN + eps)
    iou = (TP + eps) / (TP + FP + FN + eps)

    # Cohen's kappa: observed agreement vs agreement expected from the marginals
    total = TP + FP + FN + TN
    p_o = (TP + TN) / total
    p_e = ((TP + FP) * (TP + FN) + (FN + TN) * (FP + TN)) / (total ** 2)
    kappa = (p_o - p_e) / (1 - p_e + eps)

    return accuracy, precision, recall, f_score, iou, kappa, FP, FN, TP, TN, dice


def calculate_average_metrics(predictions_list, targets_list):
    """
    Computes per-image binary segmentation metrics and averages them over the dataset.

    Inputs:
        predictions_list: dict keyed 0..n-1 (or list) of predicted masks
        targets_list: dict keyed 0..n-1 (or list) of ground truth masks, same indexing
    Returns:
        Dict with the mean of: accuracy, precision, recall, f_score, iou, kappa,
        FP, FN (mean pixel counts per image), MAR (missed alarm rate, FN / (FN + TP)),
        FAR (false alarm rate, FP / (FP + TN)) and dice.
    Raises:
        ValueError: if predictions_list is empty.
    """
    num_masks = len(predictions_list)
    if num_masks == 0:
        raise ValueError("calculate_average_metrics: no predicted masks to evaluate.")

    total_metrics = {
        'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f_score': 0.0,
        'iou': 0.0, 'kappa': 0.0, 'FP': 0, 'FN': 0, 'MAR': 0.0, 'FAR': 0.0, 'dice': 0.0
    }

    for i in range(num_masks):
        # Unpack by name: the metric tuple's order (…, FP, FN, TP, TN, dice) differs from
        # the dict's key order, so zipping the two would add TP/TN into MAR/FAR.
        (accuracy, precision, recall, f_score, iou, kappa,
         FP, FN, TP, TN, dice) = binary_segmentation_metrics(predictions_list[i], targets_list[i])

        total_metrics['accuracy'] += accuracy
        total_metrics['precision'] += precision
        total_metrics['recall'] += recall
        total_metrics['f_score'] += f_score
        total_metrics['iou'] += iou
        total_metrics['kappa'] += kappa
        total_metrics['FP'] += FP
        total_metrics['FN'] += FN
        total_metrics['dice'] += dice

        # Missed Alarm Rate (false negative rate) and False Alarm Rate (false positive rate)
        total_metrics['MAR'] += FN / (FN + TP + 1e-5)
        total_metrics['FAR'] += FP / (FP + TN + 1e-5)

    # Compute mean for each metric
    avg_metrics = {k: v / num_masks for k, v in total_metrics.items()}

    return avg_metrics

# === Evaluate the fine-tuned predictions against the test ground truth ===
avg_metrics = calculate_average_metrics(masks_tuned_list, ground_truth_test_masks)

# One line per metric, name upper-cased and left-aligned in an 8-character column.
print("\n=== Average Metrics on Test Set ===")
for name in avg_metrics:
    print(f"{name.upper():<8}: {avg_metrics[name]:.4f}")