<a href="https://colab.research.google.com/github/tahmidjamal12/231N_Final_Code/blob/main/sam/evaluation_sam/sam_eval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive into the Colab runtime. Every path used later in this
# notebook (dataset, SAM checkpoint, results) lives under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install -q git+https://github.com/facebookresearch/segment-anything.git
!pip install -q opencv-python h5py scikit-learn tqdm

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for segment_anything (setup.py) ... [?25l[?25hdone


In [None]:
import h5py
import numpy as np

# Peek into a single annotation file to see how it is organised.
annotation_path = '/content/drive/MyDrive/sam_eval/dataset/annotation/1048806_annotation.hdf5'
with h5py.File(annotation_path, 'r') as f:
    # Names of the top-level groups/datasets, followed by the 'segments' entry
    print("Keys in the HDF5 file:")
    top_level_keys = list(f.keys())
    print(top_level_keys)
    print(f['segments'])

Keys in the HDF5 file:
['segments']
<HDF5 group "/segments" (3 members)>


In [None]:
# ─────────────────────────────────────────────────────────────────────────────
# 0) Mount Drive (if not already done) and install dependencies
# ─────────────────────────────────────────────────────────────────────────────
# The mount is repeated here so this cell does not rely on the first cell for
# Drive access; when the drive is already mounted Colab only prints a notice.
from google.colab import drive
drive.mount('/content/drive')

# Dependency installation is left commented out; uncomment if the packages
# imported below are missing from the runtime.
# If SAM is not installed yet, run:
# !pip install git+https://github.com/facebookresearch/segment-anything.git
# !pip install h5py opencv-python matplotlib tqdm scipy

# ─────────────────────────────────────────────────────────────────────────────
# 1) Imports
# ─────────────────────────────────────────────────────────────────────────────
import os
import cv2
import h5py
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from tqdm import tqdm
from scipy.optimize import linear_sum_assignment

import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

# ─────────────────────────────────────────────────────────────────────────────
# 2) Paths and SAM Setup
# ─────────────────────────────────────────────────────────────────────────────
# Root folder on Drive that holds the checkpoint, dataset and results.
base_path = "/content/drive/MyDrive/sam_eval"
checkpoint_path = f"{base_path}/checkpoints/sam_vit_h_4b8939.pth"

# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on a runtime without CUDA instead of crashing while moving the model.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load SAM (ViT-H) and wrap it in the automatic (prompt-free) mask generator
sam = sam_model_registry["vit_h"](checkpoint=checkpoint_path).to(device)
mask_generator = SamAutomaticMaskGenerator(sam)

# Directories for images, HDF5 annotations, and where to write results
image_dir = f"{base_path}/dataset/images"
anno_dir  = f"{base_path}/dataset/annotation"

# HDF5 file where we will dump (RGB, GT masks, SAM masks) for each image
output_h5 = f"{base_path}/results/sam_results.h5"
os.makedirs(os.path.dirname(output_h5), exist_ok=True)

# Visualization output folder (for plot_segments_autoseg)
vis_folder = f"{base_path}/results/vis_sam"
os.makedirs(vis_folder, exist_ok=True)

# Gather the sorted list of image IDs: the file name with the RGB suffix removed.
# The suffix length is derived from the suffix itself rather than hard-coded.
rgb_suffix = "_rgb.png"
image_ids = sorted(
    fname[:-len(rgb_suffix)]
    for fname in os.listdir(image_dir)
    if fname.endswith(rgb_suffix)
)

# ─────────────────────────────────────────────────────────────────────────────
# 3) IoU and AP/AR functions (exactly as in evaluate_autoseg.py)
# ─────────────────────────────────────────────────────────────────────────────
def batched_iou(x, y=None):
    """
    Pairwise intersection-over-union between two stacks of masks.

    x: (N_gt, H, W) masks; any non-zero value counts as foreground
    y: (N_pred, H, W) masks (if None, uses x for self-IoU)
    returns: (N_gt, N_pred) float IoU matrix; pairs whose union is empty get 0
    """
    # Cast to bool before the bitwise ops. On raw integer masks `&`/`|` act on
    # the bit patterns (e.g. 255 & 1 == 1 but 255 | 1 == 255), which corrupts
    # the pixel counts, and on float masks they raise a TypeError.
    x = np.asarray(x).astype(bool)
    y = x if y is None else np.asarray(y).astype(bool)

    xp = x[:, None]      # shape (N_gt, 1, H, W)
    yp = y[None]         # shape (1, N_pred, H, W)
    intersection = (xp & yp).sum(axis=(-1, -2))
    union = (xp | yp).sum(axis=(-1, -2))
    # Avoid division by zero: whenever union == 0, set union = 1 so IoU = 0
    union = np.where(union == 0, 1, union)
    return intersection / union

def evaluate_AP_AR_single_image(pred_segments, gt_segments):
    """
    Score one image's predicted masks against its ground-truth masks.

    GT and predicted masks are paired one-to-one by solving a linear
    assignment problem on the cost ``1 - IoU``. At every IoU threshold a
    matched pair counts as a true positive when its IoU reaches the threshold.

    pred_segments: (N_pred, H, W) binary masks
    gt_segments:   (N_gt,   H, W) binary masks

    Returns a dict with:
      'AP': mean over thresholds of TP / N_pred (0.0 when N_pred == 0)
      'AR': mean over thresholds of TP / N_gt   (0.0 when N_gt == 0)
      'assignments': [gt_inds, pred_inds] index arrays of the matched pairs
      'iou_mat': (N_gt, N_pred) IoU matrix
      'thresholds': the IoU thresholds that were averaged over
    """
    iou_mat = batched_iou(gt_segments, pred_segments)  # rows: GT, cols: predictions
    gt_inds, pred_inds = linear_sum_assignment(1.0 - iou_mat)
    matched_ious = iou_mat[gt_inds, pred_inds]

    n_gt_masks = gt_segments.shape[0]
    n_pred_masks = pred_segments.shape[0]

    # IoU thresholds starting at 0.50 in steps of 0.05; `stop` is exclusive
    thresholds = np.arange(start=0.50, stop=0.95, step=0.05)

    # Number of matched pairs that clear each threshold
    tp_counts = [np.count_nonzero(matched_ious >= thr) for thr in thresholds]

    precisions = [
        (tp / n_pred_masks) if n_pred_masks > 0 else 0.0
        for tp in tp_counts
    ]
    recalls = [
        (tp / n_gt_masks) if n_gt_masks > 0 else 0.0
        for tp in tp_counts
    ]

    result = {
        'AP': np.mean(precisions),
        'AR': np.mean(recalls),
        'assignments': [gt_inds, pred_inds],
        'iou_mat': iou_mat,
        'thresholds': thresholds,
    }
    return result

# ─────────────────────────────────────────────────────────────────────────────
# 4) plot_segments_autoseg (exactly as you provided)
# ─────────────────────────────────────────────────────────────────────────────
import matplotlib
def plot_segments_autoseg(data_path: str, out_dir: str):
    """
    Render one comparison figure per image group stored in an HDF5 file.

    Reads an HDF5 file with groups:
      img0/, img1/, ..., each containing:
        image_rgb     (H, W, 3)
        segments_gt   (n_gt, H, W)
        segments_pred (n_pred, H, W)
    For each image group, it makes a single PNG showing:
      Row 1: RGB image (repeated in every column)
      Row 2: GT masks (one column per GT)
      Row 3: Predicted masks (aligned to GT via the AP/AR assignment;
             left blank when a GT mask has no matched prediction)
    Saves to out_dir/<idx>.png, where <idx> is the position of the group in
    natural (numeric-aware) key order, so for keys "img0", "img1", ... the
    file <N>.png shows group "img<N>".

    Groups without any GT mask are skipped (a figure needs at least one column).

    data_path: path of the HDF5 file to read
    out_dir:   existing directory the PNG files are written to
    """
    import re  # local import: only needed for the natural key ordering below

    def _natural_key(name):
        # "img10" -> ["img", 10, ""]: digit runs (odd positions of the split)
        # compare as numbers, everything else as text.
        parts = re.split(r'(\d+)', name)
        return [int(p) if i % 2 else p for i, p in enumerate(parts)]

    with h5py.File(data_path, 'r') as f:
        # h5py lists group names alphabetically ("img0", "img1", "img10", ...),
        # so enumerating the raw keys would label img10 as "2.png". Sorting
        # naturally keeps the output index in step with the image number.
        ordered_keys = sorted(f.keys(), key=_natural_key)

        for idx, img_key in enumerate(ordered_keys):
            img_group = f[img_key]
            rgb       = img_group['image_rgb'][:]        # (H, W, 3)
            gt_masks  = img_group['segments_gt'][:]      # (n_gt, H, W)
            pred_masks= img_group['segments_pred'][:].astype(np.uint8)  # (n_pred, H, W)

            n_gt = gt_masks.shape[0]
            if n_gt == 0:
                # plt.subplots cannot build a grid with zero columns
                print(f"Skipping {img_key}: no GT segments to visualize")
                continue

            # Compute assignments via AP/AR helper
            assignments = evaluate_AP_AR_single_image(pred_masks, gt_masks)['assignments']
            gt_inds, pred_inds = assignments

            # Create a 3×n_gt grid; squeeze=False keeps `axes` 2-D even for n_gt == 1
            fig, axes = plt.subplots(3, n_gt, figsize=(4 * n_gt, 12), squeeze=False)

            for i in range(n_gt):
                # Row 1: RGB image
                axes[0, i].imshow(rgb)
                axes[0, i].set_title("RGB Image")
                axes[0, i].axis('off')

                # Row 2: GT mask i
                axes[1, i].imshow(gt_masks[i], cmap='gray')
                axes[1, i].set_title("GT Segment")
                axes[1, i].axis('off')

                # Row 3: corresponding predicted mask (if matched)
                if i in gt_inds:
                    vis_idx = np.where(gt_inds == i)[0][0]
                    pred_idx = pred_inds[vis_idx]
                    axes[2, i].imshow(pred_masks[pred_idx], cmap='gray')
                axes[2, i].set_title("Predicted Segment")
                axes[2, i].axis('off')

            plt.tight_layout(rect=[0, 0, 1, 0.96])
            out_file = os.path.join(out_dir, f"{idx}.png")
            fig.savefig(out_file, bbox_inches='tight')
            # Close explicitly: many figures are created in this loop
            plt.close(fig)
            print(f"Saved visualization to {out_file}")

# ─────────────────────────────────────────────────────────────────────────────
# 5) Build a new HDF5 ("sam_results.h5") containing:
#     /img0/
#       image_rgb     (H, W, 3)
#       segments_gt   (n_gt, H, W)
#       segments_pred (n_pred, H, W)
#     /img1/  ...
# ─────────────────────────────────────────────────────────────────────────────
# Cap on how many SAM masks are kept per image (highest predicted_iou first).
MAX_PRED_MASKS = 4

with h5py.File(output_h5, 'w') as h5f:
    for idx, img_id in enumerate(tqdm(image_ids, desc="Writing SAM results to HDF5")):
        # 5.1 Load RGB image
        img_path = os.path.join(image_dir, f"{img_id}_rgb.png")
        img_bgr  = cv2.imread(img_path)
        if img_bgr is None:
            # cv2.imread reports a missing/unreadable file by returning None
            # rather than raising; fail here with the offending path instead of
            # with an opaque cvtColor error.
            raise FileNotFoundError(f"Could not read image file: {img_path}")
        image_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        H, W = image_rgb.shape[:2]

        # 5.2 Load GT masks
        h5_path = os.path.join(anno_dir, f"{img_id}_annotation.hdf5")
        with h5py.File(h5_path, 'r') as hf:
            gt_group = hf['segments']
            gt_list = []
            # Segment keys are parsed with int() so they are visited in numeric order
            for key in sorted(gt_group.keys(), key=lambda x: int(x)):
                arr = np.array(gt_group[key][:], dtype=np.uint8)
                # Binarize: any non-zero value becomes 1
                gt_list.append((arr > 0).astype(np.uint8))
        if gt_list:
            gt_stack = np.stack(gt_list, axis=0)  # (n_gt, H, W)
        else:
            # No annotated segments: keep an empty stack with the image's H, W
            gt_stack = np.zeros((0, H, W), dtype=np.uint8)

        # ─────────────────────────────────────────────────────────────────────────────
        # 5.3) Run SAM to generate predicted masks (max MAX_PRED_MASKS per image)
        # ─────────────────────────────────────────────────────────────────────────────
        sam_outputs = mask_generator.generate(image_rgb)
        # Each element of sam_outputs is a dict with keys like:
        #   'segmentation'   → (H, W) binary mask (bool or 0/1)
        #   'predicted_iou'  → float
        #   'stability_score'→ float
        #   …etc.

        # Sort by predicted_iou (descending) so highest‐quality masks come first
        sam_outputs_sorted = sorted(
            sam_outputs,
            key=lambda x: x.get("predicted_iou", 0.0),
            reverse=True
        )

        # Keep only the best MAX_PRED_MASKS masks (or fewer if SAM produced fewer)
        top_masks = sam_outputs_sorted[:MAX_PRED_MASKS]

        # Build pred_stack from the retained masks
        pred_list = [m["segmentation"].astype(np.uint8) for m in top_masks]
        if pred_list:
            pred_stack = np.stack(pred_list, axis=0)  # shape = (≤MAX_PRED_MASKS, H, W)
        else:
            pred_stack = np.zeros((0, H, W), dtype=np.uint8)


        # 5.4 Create group "/img{idx}" and write datasets
        grp = h5f.create_group(f"img{idx}")
        grp.create_dataset("image_rgb", data=image_rgb,       compression="gzip")
        grp.create_dataset("segments_gt", data=gt_stack,      compression="gzip")
        grp.create_dataset("segments_pred", data=pred_stack,  compression="gzip")

    print(f"Done writing all images to '{output_h5}'")

# ─────────────────────────────────────────────────────────────────────────────
# 6) Use plot_segments_autoseg to save visualizations into vis_folder
# ─────────────────────────────────────────────────────────────────────────────
print("Generating visualizations from HDF5 ...")
plot_segments_autoseg(output_h5, vis_folder)

# ─────────────────────────────────────────────────────────────────────────────
# 7) Compute and print summary metrics (mean AP, mean AR, mean IoU)
# ─────────────────────────────────────────────────────────────────────────────
all_APs = []        # one AP value per image
all_ARs = []        # one AR value per image
all_IoU_vals = []   # IoU of every matched (GT, prediction) pair, pooled over images

with h5py.File(output_h5, 'r') as f:
    for img_key in f.keys():
        img_grp = f[img_key]
        gt_stack   = img_grp['segments_gt'][:].astype(np.uint8)
        pred_stack = img_grp['segments_pred'][:].astype(np.uint8)

        res = evaluate_AP_AR_single_image(pred_stack, gt_stack)
        all_APs.append(res['AP'])
        all_ARs.append(res['AR'])

        # Reuse the assignment already computed by the helper instead of
        # solving the same matching problem a second time. With no GT or no
        # predictions the index arrays are empty and nothing is added.
        gt_inds, pred_inds = res['assignments']
        matched = res['iou_mat'][gt_inds, pred_inds]
        all_IoU_vals.extend(matched.tolist())

mean_AP = np.mean(all_APs) if all_APs else 0.0
mean_AR = np.mean(all_ARs) if all_ARs else 0.0
mean_IoU = np.nanmean(all_IoU_vals) if all_IoU_vals else 0.0

print("\n=== FINAL SUMMARY METRICS ===")
# Report the number of images actually evaluated from the results file
print(f"Number of images: {len(all_APs)}")
print(f"Mean Average Precision (mAP):  {mean_AP:.4f}")
print(f"Mean Average Recall   (mAR):  {mean_AR:.4f}")
print(f"Mean IoU across matches:    {mean_IoU:.4f}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Writing SAM results to HDF5: 100%|██████████| 95/95 [03:08<00:00,  1.98s/it]


Done writing all images to '/content/drive/MyDrive/sam_eval/results/sam_results.h5'
Generating visualizations from HDF5 ...
Saved visualization to /content/drive/MyDrive/sam_eval/results/vis_sam/0.png
Saved visualization to /content/drive/MyDrive/sam_eval/results/vis_sam/1.png
Saved visualization to /content/drive/MyDrive/sam_eval/results/vis_sam/2.png
Saved visualization to /content/drive/MyDrive/sam_eval/results/vis_sam/3.png
Saved visualization to /content/drive/MyDrive/sam_eval/results/vis_sam/4.png
Saved visualization to /content/drive/MyDrive/sam_eval/results/vis_sam/5.png
Saved visualization to /content/drive/MyDrive/sam_eval/results/vis_sam/6.png
Saved visualization to /content/drive/MyDrive/sam_eval/results/vis_sam/7.png
Saved visualization to /content/drive/MyDrive/sam_eval/results/vis_sam/8.png
Saved visualization to /content/drive/MyDrive/sam_eval/results/vis_sam/9.png
Saved visualization to /content/drive/MyDrive/sam_eval/results/vis_sam/10.png
Saved visualization to /cont

In [None]:
%cd MyDrive/sam_eval/results/
!pwd

[Errno 2] No such file or directory: 'MyDrive/sam_eval/results/'
/content/drive/.shortcut-targets-by-id/1TkGNfiPiYZ6IIA9l9g3-JL1J49XbhYaP/sam_eval/results
/content/drive/.shortcut-targets-by-id/1TkGNfiPiYZ6IIA9l9g3-JL1J49XbhYaP/sam_eval/results


In [None]:
import h5py
import numpy as np
import matplotlib.pyplot as plt

def inspect_h5(file_path):
    """
    Open an HDF5 file, list its structure, and display the first mask if one exists.

    Prints the top-level keys, then for every top-level group its direct
    children (shape/dtype for datasets, a marker for nested groups). If the
    file has a top-level dataset named "masks", the first mask is shown:
    a 2-D dataset is treated as a single mask, a dataset with 3+ dims as a
    stack whose first entry is displayed.

    file_path: path of the HDF5 file to inspect.
    A missing or unreadable file is reported with a printed message, not raised.
    """
    try:
        with h5py.File(file_path, "r") as f:
            print(f"Opened file: {file_path}\n")

            # 1) List top‐level groups/datasets
            print("Top-level keys:")
            for key in f.keys():
                print(f"  - {key}")

            # 2) For each group, list its children (and shape/dtype)
            for key in f.keys():
                item = f[key]
                if isinstance(item, h5py.Group):
                    print(f"\nGroup '{key}' contains:")
                    for subkey in item.keys():
                        child = item[subkey]
                        if isinstance(child, h5py.Dataset):
                            print(f"    - {subkey}  (shape: {child.shape}, dtype: {child.dtype})")
                        else:
                            # Nested groups have no shape/dtype to report
                            print(f"    - {subkey}  (group)")
                else:
                    print(f"\nDataset '{key}': shape = {item.shape}, dtype = {item.dtype}")

            # 3) If there is a dataset literally named "masks", plot the first one:
            masks = f["masks"] if "masks" in f else None
            if not isinstance(masks, h5py.Dataset):
                print("\nNo 'masks' dataset found at top level.")
            elif masks.ndim < 2:
                print("\n'Masks' dataset exists but does not have 2+ dims.")
            elif masks.ndim >= 3 and masks.shape[0] == 0:
                print("\n'masks' dataset exists but contains no masks.")
            else:
                # A 2-D dataset is one (H, W) mask; indexing it with [0] would
                # yield a 1-D row that imshow cannot draw. For a stack, read
                # only the first slice instead of loading the whole dataset.
                first_mask = masks[...] if masks.ndim == 2 else masks[0]
                plt.figure(figsize=(4,4))
                plt.imshow(first_mask, cmap="gray")
                plt.title("First Mask")
                plt.axis("off")
                plt.show()

    except FileNotFoundError:
        print(f"Error: File not found → {file_path}")
    except OSError as e:
        print(f"Error opening file: {e}")

# ── UPDATE this path to wherever your sam_results.h5 actually is: ──
# If you uploaded directly via Colab’s Files pane, it’s probably:
#     "/content/sam_results.h5"
# If you mounted Drive, it might be:
#     "/content/drive/MyDrive/<your-folder>/sam_results.h5"
# The value below is the results file written under the Drive sam_eval folder.
file_path = "/content/drive/MyDrive/sam_eval/results/sam_results.h5"

# Print the file's structure (and show the first mask if a "masks" dataset exists)
inspect_h5(file_path)


Opened file: /content/drive/MyDrive/sam_eval/results/sam_results.h5

Top-level keys:
  - img0
  - img1
  - img10
  - img11
  - img12
  - img13
  - img14
  - img15
  - img16
  - img17
  - img18
  - img19
  - img2
  - img20
  - img21
  - img22
  - img23
  - img24
  - img25
  - img26
  - img27
  - img28
  - img29
  - img3
  - img30
  - img31
  - img32
  - img33
  - img34
  - img35
  - img36
  - img37
  - img38
  - img39
  - img4
  - img40
  - img41
  - img42
  - img43
  - img44
  - img45
  - img46
  - img47
  - img48
  - img49
  - img5
  - img50
  - img51
  - img52
  - img53
  - img54
  - img55
  - img56
  - img57
  - img58
  - img59
  - img6
  - img60
  - img61
  - img62
  - img63
  - img64
  - img65
  - img66
  - img67
  - img68
  - img69
  - img7
  - img70
  - img71
  - img72
  - img73
  - img74
  - img75
  - img76
  - img77
  - img78
  - img79
  - img8
  - img80
  - img81
  - img82
  - img83
  - img84
  - img85
  - img86
  - img87
  - img88
  - img89
  - img9
  - img90
  - img91
  - i