# SAM2 Evaluation Pipeline - Colab Notebook

This notebook provides an interactive way to run the SAM2 evaluation pipeline, either in Google Colab or in a local Jupyter environment started from the repository root.

**Workflow:**
1.  **Setup:** Clone repository, install dependencies, and install the required SAM2 library.
2.  **Configuration:** Set parameters for the pipeline (model, data paths, etc.).
3.  **Data Preparation:** Generate the `degradation_map.json` (assumes image data exists).
4.  **(Optional) Visualization:** Inspect sample images and masks.
5.  **Run Pipeline:** Execute the evaluation using the configured settings.
6.  **View Results:** Load and display the output CSV.

In [None]:
# Optional cell: Google Drive data access (everything below is commented out by default)

# --- OPTIONAL: Mount Google Drive ---
# Uncomment and run this cell if your data directory (or parts of it)
# resides on Google Drive instead of being in the Git repo.
# from google.colab import drive
# drive.mount('/content/drive')

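# --- Optional: Create symlink if needed ---
# If your data is on Drive, e.g., at /content/drive/MyDrive/SAM2_data
# and your code expects it at ./data, you might create a symlink.
# NOTE: the snippet below uses os, IN_COLAB and PROJECT_ROOT, which are defined by the
# first cell of "1. Setup Environment" - run that cell before uncommenting and running this.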
# DRIVE_DATA_PATH = '/content/drive/MyDrive/SAM2_data' # <-- Adjust this path
# PROJECT_DATA_PATH = os.path.join(PROJECT_ROOT, 'data')
# if IN_COLAB and not os.path.exists(PROJECT_DATA_PATH):
#     print(f"Linking {DRIVE_DATA_PATH} to {PROJECT_DATA_PATH}...")
#     !ln -s "$DRIVE_DATA_PATH" "$PROJECT_DATA_PATH"
# else:
#    print("Skipping symlink creation (not in Colab or data path exists).")

## 1. Setup Environment

In [None]:
# Check if running in Colab
import sys
IN_COLAB = 'google.colab' in sys.modules

# Base directory for the project
# If in Colab, clone the repo. Otherwise, assume we are running from the repo root.
import os

if IN_COLAB:
    print('Running in Colab, cloning repository...')
    # !!! IMPORTANT: Replace the URL below with your actual repository URL !!!
    # If your repository is private, use a PAT (Personal Access Token) in the URL:
    # !git clone https://<YOUR_GITHUB_TOKEN>@github.com/<YOUR-USERNAME>/SAM2_analysis.git
    !git clone https://github.com/<YOUR-ORG-OR-USERNAME>/SAM2_analysis.git # <-- EDIT THIS LINE
    %cd SAM2_analysis
    PROJECT_ROOT = '/content/SAM2_analysis'
else:
    print('Running locally, assuming current directory is project root.')
    # Find the project root assuming this notebook is in the root
    PROJECT_ROOT = os.path.abspath('.')
    # Verify by checking for a known file/directory
    if not os.path.exists(os.path.join(PROJECT_ROOT, 'main.py')):
        print(f'Warning: Could not confirm project root at {PROJECT_ROOT}')

print(f'Project Root: {PROJECT_ROOT}')
os.chdir(PROJECT_ROOT) # Ensure we are in the project root directory

In [None]:
# Install dependencies from requirements.txt
print('\nInstalling dependencies...')
%pip install -r requirements.txt

In [None]:
# Install the SAM2 library
# Assumes the sam2 code is located in 'external/sam2' within the project
print('\nInstalling SAM2 library...')
SAM2_DIR = os.path.join(PROJECT_ROOT, 'external/sam2')

if not os.path.exists(SAM2_DIR):
    print(f'Error: SAM2 directory not found at {SAM2_DIR}')
    print('Please ensure you have cloned the SAM2 repository into external/sam2')
    # Optional: Add command to clone it if missing
    # print('Attempting to clone SAM2...')
    # !git clone <SAM2_REPO_URL> external/sam2 # <-- Add SAM2 repo URL if desired
else:
    # Editable install. The bash $SAM2_DIR expands correctly in a shell context. - makes code directly importable
    %pip install -e {SAM2_DIR}

# --- OPTIONAL (only if SAM2_DIR is missing): ---------
if not os.path.exists(SAM2_DIR):
    print('Cloning official SAM2 repo …')
    !git clone https://github.com/facebookresearch/sam2.git "$SAM2_DIR"
    %pip install -e "$SAM2_DIR"
# -----------------------------------------------------

## 2. Configuration

In [None]:
# --- Pipeline Configuration ---
# Same layout as sam2_eval_config.json. Paths are written relative to the
# project root here and converted to absolute paths at the end of the cell.
import json

config = dict(
    pipeline_name="sam2_eval",
    description="Evaluate SAM2 auto-mask generator on data map (Colab)",

    # --- Data ---
    # Location of the generated data map.
    data_path="data/degradation_map.json",
    # Directory that the filepaths stored inside degradation_map.json
    # (e.g., "images/gt_img/1.jpg") are resolved against.
    image_base_dir="data",

    # --- Model ---
    # Hugging Face identifier of the SAM2 checkpoint, e.g.
    # 'facebook/sam2-hiera-tiny', 'facebook/sam2-hiera-small',
    # 'facebook/sam2-hiera-base', 'facebook/sam2-hiera-large'.
    # The tiny variant is selected for faster testing.
    model_hf_id="facebook/sam2-hiera-tiny",

    # --- Mask generator ---
    # Options handed to SAM2AutomaticMaskGenerator (see the SAM2 library docs).
    generator_config=dict(
        points_per_side=16,            # Lower for faster processing
        pred_iou_thresh=0.80,          # Default: 0.88
        stability_score_thresh=0.90,   # Default: 0.95
        crop_n_layers=0,               # Default: 0 (no cropping)
        min_mask_region_area=10,       # Default: 0
    ),

    # --- Evaluation metrics ---
    # ("iou_threshold" is no longer used directly by the pipeline function and is omitted.)
    bf1_tolerance=2,                   # Tolerance in pixels for Boundary F1 score

    # --- Output ---
    # Template for the results CSV; the pipeline function will add a timestamp.
    output_path="output/results_colab.csv",
)

# Resolve every path entry against the project root so later cells can use them as-is.
config.update({
    path_key: os.path.join(PROJECT_ROOT, config[path_key])
    for path_key in ("data_path", "image_base_dir", "output_path")
})

print("Configuration set:")
print(json.dumps(config, indent=2))

## 3. Data Preparation

**IMPORTANT:** This section assumes the necessary image and annotation files are already present in the `data/images/gt_img/` directory and any corresponding degraded images are in `data/images/img_degraded/` within your Colab environment or mounted drive.

The `build_local_map.py` script will scan these directories to create the `degradation_map.json`.

If the source images/annotations are not present, you need to:
1. Generate or place the original images and annotations in `data/images/gt_img/`.
2. (Optional) Generate degraded images using `code_degradation.py` or place them manually into the correct subdirectories within `data/images/img_degraded/`.
3. Upload/sync the populated `data/images/` structure to Colab or Google Drive.

In [None]:
import os # Ensure os is imported if running cells independently

# Ensure data directories exist (though the script expects content within them)
DATA_DIR = os.path.join(PROJECT_ROOT, 'data')
IMAGES_GT_DIR = os.path.join(DATA_DIR, 'images', 'gt_img')
IMAGES_DEGRADED_DIR = os.path.join(DATA_DIR, 'images', 'img_degraded')
OUTPUT_DIR = os.path.dirname(config['output_path']) # absolute path

os.makedirs(IMAGES_GT_DIR, exist_ok=True) # Create base dirs if they don't exist
os.makedirs(IMAGES_DEGRADED_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f'Checking for required image directories:')
print(f'- Ground Truth Images: {IMAGES_GT_DIR}')
print(f'- Degraded Images: {IMAGES_DEGRADED_DIR}')
print(f'Expected output map path: {config["data_path"]}')

# Check if the essential input directories for build_local_map.py exist
if not os.path.exists(IMAGES_GT_DIR) or not os.listdir(IMAGES_GT_DIR):
     print(f"Warning: Ground truth image directory '{IMAGES_GT_DIR}' does not exist or is empty. "
           f"The 'build_local_map.py' script will likely fail or produce an empty map.")
# Note: img_degraded might be optional depending on use case, so we don't warn if it's missing/empty

# Run the script to generate the degradation_map.json
print('\nRunning script to generate degradation_map.json...')
# Assumes build_local_map.py reads from ../images/gt_img and ../images/img_degraded relative to its own location
!python data/data_scripts/build_local_map.py

# Verify the map was created
data_map_path = config['data_path'] # Use absolute path from config
if os.path.exists(data_map_path):
    # Optionally check if the map is non-empty
    try:
        with open(data_map_path, 'r') as f:
            map_content = json.load(f)
        if map_content:
             print(f'Successfully generated non-empty {data_map_path}')
        else:
             print(f'Successfully generated {data_map_path}, but it appears to be empty.')
    except Exception as e:
         print(f'Successfully generated {data_map_path}, but could not verify content: {e}')
else:
    print(f'Error: {data_map_path} was not generated. Check data availability and script output.')
    # Add more detailed error checking if the script provides specific logs

## 4. (Optional) Visualize Data Sample

In [None]:
import json
import random
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import pycocotools.mask as mask_util # Import for RLE decoding
import os # Ensure os is imported

def _decode_gt_mask(item_data):
    """Return the decoded ground-truth mask for one data-map entry, or None.

    Reads item_data['ground_truth_rle']. Only dict-form RLE is decoded (via
    pycocotools); a bare counts string cannot be decoded without size
    information, so it is reported and skipped. Decoding errors are printed
    rather than raised so that the image versions can still be shown.
    """
    gt_rle = item_data.get('ground_truth_rle')
    if not gt_rle:
        return None
    try:
        if isinstance(gt_rle, dict):
            return mask_util.decode(gt_rle)
        if isinstance(gt_rle, str):
            # Decoding would need {'size': [height, width], 'counts': gt_rle};
            # the size is not known to be stored in the entry.
            print("Warning: GT RLE is string, size info needed for decoding.")
        else:
            print(f"Warning: Unexpected GT RLE format: {type(gt_rle)}")
    except Exception as e:
        print(f'  Could not decode GT RLE: {e}')
    return None


def _collect_versions(item_data, image_base_dir):
    """Return a list of {'title', 'path'} dicts for the image versions of one entry.

    item_data['versions'] maps a degradation type either directly to a record
    containing 'filepath' (e.g. 'original') or to a nested mapping of
    level -> record. Filepaths are treated as relative to image_base_dir.
    Anything that does not match one of these two shapes is ignored.
    """
    versions = item_data.get('versions')
    if not isinstance(versions, dict):
        return []

    collected = []
    for degradation_type, levels_or_data in versions.items():
        if not isinstance(levels_or_data, dict):
            continue
        if 'filepath' in levels_or_data: # flat record, e.g. 'original'
            level = levels_or_data.get('level', 'N/A')
            collected.append({
                'title': f'{degradation_type}\n(Level: {level})',
                'path': os.path.join(image_base_dir, levels_or_data['filepath']),
            })
            continue
        # Nested levels like {'1': {...}, '2': {...}}
        for level, version_data in levels_or_data.items():
            if isinstance(version_data, dict) and 'filepath' in version_data:
                level_val = version_data.get('level', level) # Use nested level if available
                collected.append({
                    'title': f'{degradation_type}_{level}\n(Level: {level_val})',
                    'path': os.path.join(image_base_dir, version_data['filepath']),
                })
    return collected


def visualize_sample(data_map_path, image_base_dir):
    """Loads the data map, picks a random image, and displays its versions and GT mask.

    Args:
        data_map_path: Path to the JSON data map (a JSON object keyed by image_id).
        image_base_dir: Directory that the filepaths inside the map are relative to.

    Returns:
        None. The figure is shown with matplotlib; every problem (missing or
        malformed map, empty map, unreadable image, undecodable mask) is
        reported with print() instead of raising.
    """
    if not os.path.exists(data_map_path):
        print(f'Cannot visualize: {data_map_path} not found.')
        return

    with open(data_map_path, 'r', encoding='utf-8') as f:
        try:
            data_map = json.load(f)
        except json.JSONDecodeError as e:
            print(f"Error reading data map JSON: {e}")
            return

    if not data_map:
        print('Cannot visualize: Data map is empty.')
        return
    if not isinstance(data_map, dict):
        print('Cannot visualize: Data map must be a JSON object keyed by image_id.')
        return

    image_id = random.choice(list(data_map))
    print(f'Visualizing sample for image_id: {image_id}')
    item_data = data_map[image_id]
    if not isinstance(item_data, dict):
        print(f'Cannot visualize: entry for {image_id} is not a JSON object.')
        return

    gt_mask = _decode_gt_mask(item_data)
    versions_to_plot = _collect_versions(item_data, image_base_dir)

    # One column per image version, plus one for the mask when it was decoded.
    plot_cols = len(versions_to_plot) + (1 if gt_mask is not None else 0)
    if plot_cols == 0:
        print("No image versions or GT mask found to plot.")
        return

    # squeeze=False always yields a 2-D array of axes, so the single-column
    # case needs no special handling; take the only row.
    _, axes_grid = plt.subplots(1, plot_cols, figsize=(5 * plot_cols, 5), squeeze=False)
    axes = axes_grid[0]

    # Display versions
    for ax, version_info in zip(axes, versions_to_plot):
        img_path = version_info['path']
        title = version_info['title']
        try:
            # The context manager closes the underlying file once the pixels are converted.
            with Image.open(img_path) as image_file:
                ax.imshow(image_file.convert('RGB'))
            ax.set_title(title)
        except FileNotFoundError:
            print(f'  Image not found: {img_path}')
            ax.set_title(f'{title}\n(Not Found)')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            ax.set_title(f'{title}\n(Load Error)')
        finally:
            ax.axis('off')

    # Display GT mask in the last column, which was reserved for it above
    if gt_mask is not None:
        mask_ax = axes[-1]
        mask_ax.imshow(gt_mask, cmap='gray')
        mask_ax.set_title('Ground Truth Mask')
        mask_ax.axis('off')

    plt.tight_layout()
    plt.show()

# --- Run visualization ---
# Take the absolute paths prepared in the configuration cell and only attempt
# to plot when the data map file is actually on disk.
data_map_path, image_base_dir = config['data_path'], config['image_base_dir']
if not os.path.exists(data_map_path):
    print(f"Skipping visualization because data map not found: {data_map_path}")
else:
    visualize_sample(data_map_path, image_base_dir)


In [None]:
# Optional cell: Hugging Face authentication (everything below is commented out by default)

# --- OPTIONAL: Hugging Face Login ---
# Uncomment and run this cell if the model you specified in the config
# ("model_hf_id") is private and requires authentication.
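# Replace YOUR_HF_TOKEN with your actual Hugging Face access token, but do not save, share or
# commit the notebook with a real token in it; the interactive notebook_login() below avoids that.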
# from huggingface_hub import login
# login(token="YOUR_HF_TOKEN") # Or use !huggingface-cli login --token YOUR_HF_TOKEN

# Or, if using notebook_login:
# from huggingface_hub import notebook_login
# notebook_login() # Prompts for token interactively

## 5. Run Pipeline

In [None]:
import sys
import os # Ensure os is imported

# Make modules that live in the project root importable from this notebook
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Bring in the pipeline entry point. On any import problem the name stays None,
# which makes the execution step below skip itself instead of raising.
run_evaluation_pipeline = None
try:
    from sam2_eval_pipeline import run_evaluation_pipeline
except ImportError as import_err:
    print(f'Error importing pipeline function: {import_err}')
    print('Ensure installation steps completed correctly and you are in the project root.')
except Exception as unexpected_err:
    print(f"An unexpected error occurred during import: {unexpected_err}")
else:
    print('Imported run_evaluation_pipeline successfully.')


# Execute the pipeline
# results_df stays None: the call's return value is not used here, and the
# results are read back from the CSV file in the next section.
results_df = None
if not run_evaluation_pipeline:
    print("Skipping pipeline execution due to import failure.")
else:
    print('\nStarting evaluation pipeline...')
    try:
        # The whole config dictionary is handed to the pipeline function
        run_evaluation_pipeline(config)

        # Only the directory is reported; the final filename is expected to
        # differ from the template because the pipeline adds a timestamp.
        output_dir_for_log = os.path.dirname(config['output_path'])
        print(f'Pipeline finished. Results should be saved in {output_dir_for_log}')
    except Exception as run_err:
        import traceback

        print(f"An error occurred during pipeline execution: {run_err}")
        traceback.print_exc() # Full traceback for debugging



## 6. View Results

In [None]:
import pandas as pd
import os
import glob
import datetime # Needed if you want to parse timestamps, though not strictly required for just finding the latest

# --- Get the expected output pattern from the config ---
# config['output_path'] is like: /content/SAM2_analysis/output/results_colab.csv
output_path_template = config['output_path']
output_dir = os.path.dirname(output_path_template)
base_filename = os.path.basename(output_path_template)
filename_stem, filename_ext = os.path.splitext(base_filename) # e.g., "results_colab", ".csv"

# The pipeline likely creates filenames like: results_colab_YYYYMMDD_HHMMSS.csv
# We need to find the latest file matching this pattern.
# glob.escape keeps characters such as '[' in the directory or filename from
# being interpreted as wildcard syntax; only the '*' below is a wildcard.
search_pattern = os.path.join(
    glob.escape(output_dir),
    f"{glob.escape(filename_stem)}_*{glob.escape(filename_ext)}",
) # e.g., /content/.../output/results_colab_*.csv

print(f"Searching for results files matching: {search_pattern}")

try:
    # Find all matching files
    result_files = glob.glob(search_pattern)
    # Also accept a file written at exactly the configured path, in case the
    # pipeline did not append a timestamp to the name.
    if os.path.isfile(output_path_template):
        result_files.append(output_path_template)

    if not result_files:
        # glob returns an empty list for a missing directory instead of
        # raising, so the directory has to be checked explicitly.
        if not os.path.isdir(output_dir):
            print(f"Error: Output directory '{output_dir}' not found.")
        print("No results files found matching the pattern.")
        print("Ensure the pipeline ran successfully and created an output file.")
    else:
        # Find the most recently modified file
        latest_file = max(result_files, key=os.path.getmtime)
        print(f"Loading latest results file: {latest_file}")

        # Load the CSV
        results_df_loaded = pd.read_csv(latest_file)
        print("\nDisplaying loaded results DataFrame:")

        # Display using Colab's interactive table if available
        try:
            from google.colab.data_table import DataTable
            display(DataTable(results_df_loaded))
        except ImportError:
            display(results_df_loaded) # Fallback for non-Colab

except Exception as e:
     # Best-effort cell: report the problem (e.g. unreadable or empty CSV) without stopping the notebook.
     print(f"An error occurred while loading or displaying results: {e}")
