# SAM 2 Parameter Tuning for Lake Detection

This notebook shows how to tune SAM 2 automatic mask generation parameters using ground truth lake data.

**Workflow:**
1. Load RGB satellite image
2. Load ground truth lake mask (binary raster: 1=lake, 0=not lake)
3. Test different SAM 2 parameter combinations
4. Rank the parameter combinations by IoU, reporting precision, recall, and F1 for each

In [None]:
# Install required packages (run this in Colab)
!pip install torchgeo segment-anything-2
!pip install rasterio matplotlib scikit-learn

In [None]:
# Import libraries
import numpy as np
import matplotlib.pyplot as plt
import rasterio
import torch
from sklearn.metrics import jaccard_score, precision_score, recall_score
import pandas as pd

# SAM 2 imports
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# Select the compute device: use the GPU when PyTorch reports CUDA as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

In [None]:
# Configuration - UPDATE THESE PATHS

# SAM 2 model files (adjust if your checkout is laid out differently)
MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
SAM2_CHECKPOINT = "../checkpoints/sam2.1_hiera_large.pt"

# Input rasters
# - RGB satellite image
RGB_IMAGE_PATH = "2021-09-04_rgb.tif"
# - binary lake mask (1=lake, 0=not lake)
GROUND_TRUTH_PATH = "2021-09-04_lake_mask.tif"

# Echo the inputs so the run log records which files were used
print(
    f"RGB image: {RGB_IMAGE_PATH}",
    f"Ground truth: {GROUND_TRUTH_PATH}",
    sep="\n",
)

In [None]:
# Load RGB satellite image
print("Loading RGB satellite image...")
with rasterio.open(RGB_IMAGE_PATH) as src:
    # Bands 1-3 are read as R, G, B - assumes the file stores them in that order
    rgb_data = src.read([1, 2, 3])

# Reorder from band-first to channel-last: (height, width, 3)
rgb_data = rgb_data.transpose(1, 2, 0)

# Convert to a uint8 image in the 0-255 range
data_max = rgb_data.max()
if data_max > 255:
    # Wider-than-8-bit data: rescale so the brightest value maps to 255
    rgb_image = ((rgb_data / data_max) * 255).astype(np.uint8)
elif np.issubdtype(rgb_data.dtype, np.floating) and data_max <= 1.0:
    # Float imagery scaled to [0, 1]: a bare uint8 cast would truncate every
    # pixel to 0 or 1 and produce an almost black image, so stretch to 0-255.
    rgb_image = (np.clip(rgb_data, 0.0, 1.0) * 255).astype(np.uint8)
else:
    # Values already fit the 0-255 range
    rgb_image = rgb_data.astype(np.uint8)

print(f"RGB image shape: {rgb_image.shape}")
print(f"RGB value range: {rgb_image.min()} - {rgb_image.max()}")

# Visualize RGB image
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(rgb_image)
ax.set_title('RGB Satellite Image')
ax.axis('off')
plt.show()

In [None]:
# Load ground truth lake mask
print("Loading ground truth lake mask...")
with rasterio.open(GROUND_TRUTH_PATH) as src:
    # Every non-zero pixel of band 1 becomes True (lake)
    ground_truth = src.read(1).astype(bool)

print(f"Ground truth shape: {ground_truth.shape}")
print(f"Lake pixels: {ground_truth.sum():,} ({ground_truth.mean()*100:.2f}% of image)")

# Check if dimensions match: the mask is compared pixel-by-pixel with
# predictions made on rgb_image, so both must share the same height and width.
if ground_truth.shape != rgb_image.shape[:2]:
    print("⚠️  WARNING: Dimension mismatch!")
    print(f"RGB: {rgb_image.shape[:2]}, Ground truth: {ground_truth.shape}")
else:
    print("✅ Dimensions match!")

# Visualize ground truth next to the image it annotates
fig, (ax_rgb, ax_mask) = plt.subplots(1, 2, figsize=(12, 5))

ax_rgb.imshow(rgb_image)
ax_rgb.set_title('RGB Image')
ax_rgb.axis('off')

ax_mask.imshow(ground_truth, cmap='Blues')
ax_mask.set_title('Ground Truth Lakes (Blue = Lake)')
ax_mask.axis('off')

fig.tight_layout()
plt.show()

In [None]:
# Initialize SAM 2 model
# Built once from the config/checkpoint paths set in the configuration cell and
# placed on the selected device; every mask generator below reuses this instance.
print("Loading SAM 2 model...")
sam2 = build_sam2(MODEL_CFG, SAM2_CHECKPOINT, device=device, apply_postprocessing=False)
print("✅ SAM 2 model loaded successfully!")

In [None]:
# Define parameter combinations to test
# One-at-a-time sweep: start from a baseline setting and override a single
# parameter per configuration, keeping the other three at their baseline value.
_baseline = {
    "points_per_side": 32,
    "pred_iou_thresh": 0.8,
    "min_mask_region_area": 100,
    "stability_score_thresh": 0.95,
}

# (parameter, values to try) - the order here fixes the order of param_configs.
# The baseline itself appears once, as the middle entry of the points_per_side sweep.
_sweeps = (
    ("points_per_side", (16, 32, 64)),
    ("pred_iou_thresh", (0.7, 0.85, 0.9)),
    ("min_mask_region_area", (50, 200, 500)),
    ("stability_score_thresh", (0.9, 0.98)),
)

# Overriding an existing key keeps its position, so every dict has the same key order
param_configs = [
    {**_baseline, name: value}
    for name, values in _sweeps
    for value in values
]

print(f"Will test {len(param_configs)} parameter combinations")

In [None]:
# Run parameter tuning
# For every configuration: generate masks, merge them into one binary
# prediction, score it against the ground truth and record the metrics.
print("Starting parameter tuning...\n")

results = []

# The ground truth does not change between configurations, so flatten it once
gt_flat = ground_truth.ravel()

for i, config in enumerate(param_configs):
    print(f"Testing config {i+1}/{len(param_configs)}:")
    print(f"  {config}")

    try:
        # Create mask generator with current config
        mask_generator = SAM2AutomaticMaskGenerator(sam2, **config)

        # Generate masks on RGB image
        masks = mask_generator.generate(rgb_image)

        # Combine all masks into single binary mask
        # NOTE(review): the generator is not told what a lake is, so this union
        # presumably also contains non-lake segments - confirm that scoring the
        # full union against the lake mask is the intended evaluation.
        combined_mask = np.zeros_like(ground_truth, dtype=bool)
        for mask in masks:
            combined_mask |= mask['segmentation']
        pred_flat = combined_mask.ravel()
        predicted_pixels = int(combined_mask.sum())

        # Pixel-wise metrics, with "lake" as the positive class
        iou = jaccard_score(gt_flat, pred_flat)
        precision = precision_score(gt_flat, pred_flat, zero_division=0)
        recall = recall_score(gt_flat, pred_flat, zero_division=0)

        # F1 score (0 when both precision and recall are 0)
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        # Store results
        result = {
            'config_id': i,
            'points_per_side': config['points_per_side'],
            'pred_iou_thresh': config['pred_iou_thresh'],
            'min_mask_region_area': config['min_mask_region_area'],
            'stability_score_thresh': config['stability_score_thresh'],
            'iou': iou,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'num_masks': len(masks),
            'predicted_lake_pixels': predicted_pixels
        }
        results.append(result)

        print(f"  ✅ IoU: {iou:.3f}, Precision: {precision:.3f}, Recall: {recall:.3f}, F1: {f1:.3f}")
        print(f"     Masks: {len(masks)}, Lake pixels: {predicted_pixels:,}\n")

    except Exception as e:
        # Best effort: report the failure and move on so that one bad
        # configuration does not abort the whole sweep.
        print(f"  ❌ Error: {e}\n")

print("✅ Parameter tuning complete!")

In [None]:
# Analyze results
# Rank the successful configurations by IoU and report the best ones.
if results:
    # Sort by IoU (best first)
    df_results = pd.DataFrame(results).sort_values('iou', ascending=False)

    print("🏆 TOP 5 CONFIGURATIONS (by IoU):")
    print("=" * 80)

    top_5 = df_results.head(5)
    # itertuples keeps each column's own dtype; iterrows would upcast the integer
    # columns to float and print e.g. "32.0" points per side or "1,234.0" pixels.
    for rank, row in enumerate(top_5.itertuples(index=False), start=1):
        print(f"Rank {rank}:")
        print(f"  Points/side: {row.points_per_side}, IoU thresh: {row.pred_iou_thresh}, ")
        print(f"  Min area: {row.min_mask_region_area}, Stability: {row.stability_score_thresh}")
        print(f"  📊 IoU: {row.iou:.3f}, Precision: {row.precision:.3f}, Recall: {row.recall:.3f}, F1: {row.f1:.3f}")
        print(f"  🎯 Masks: {row.num_masks}, Lake pixels: {row.predicted_lake_pixels:,}")
        print()

    # Best configuration (first row after sorting). As a single row of a mixed
    # int/float frame its values are floats, hence the int() casts for display.
    best_config = df_results.iloc[0]
    print("🎯 BEST CONFIGURATION:")
    print(f"   points_per_side: {int(best_config['points_per_side'])}")
    print(f"   pred_iou_thresh: {best_config['pred_iou_thresh']}")
    print(f"   min_mask_region_area: {int(best_config['min_mask_region_area'])}")
    print(f"   stability_score_thresh: {best_config['stability_score_thresh']}")
    print(f"   Best IoU: {best_config['iou']:.3f}")

else:
    print("❌ No successful results to analyze")

In [None]:
# Visualize results
# Three bar panels show the best IoU reached for each value of a swept
# parameter; the fourth panel shows the precision/recall trade-off.
if results:
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # (axes, results column, panel title, x-axis label)
    bar_panels = [
        (axes[0,0], 'points_per_side', 'Best IoU vs Points Per Side', 'Points Per Side'),
        (axes[0,1], 'pred_iou_thresh', 'Best IoU vs Prediction IoU Threshold', 'Prediction IoU Threshold'),
        (axes[1,0], 'min_mask_region_area', 'Best IoU vs Min Mask Region Area', 'Min Mask Region Area'),
    ]
    for ax, column, title, xlabel in bar_panels:
        best_iou = df_results.groupby(column)['iou'].max()
        # Use the parameter values as category labels. Placed at their numeric
        # positions with the default bar width, thresholds 0.05-0.1 apart would
        # overlap and areas 50-500 would shrink to hairlines.
        ax.bar([str(value) for value in best_iou.index], best_iou.values)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Best IoU')

    # Precision vs Recall scatter
    scatter = axes[1,1].scatter(df_results['recall'], df_results['precision'],
                                c=df_results['iou'], cmap='viridis', s=50)
    axes[1,1].set_title('Precision vs Recall (colored by IoU)')
    axes[1,1].set_xlabel('Recall')
    axes[1,1].set_ylabel('Precision')
    plt.colorbar(scatter, ax=axes[1,1], label='IoU')

    plt.tight_layout()
    plt.show()

    # Save results
    df_results.to_csv('sam2_parameter_tuning_results.csv', index=False)
    print("📁 Results saved to: sam2_parameter_tuning_results.csv")

In [None]:
# Generate masks with best configuration for visualization
# The masks from the sweep are not kept, so the best configuration is run again.
if results:
    print("Generating final result with best configuration...")

    # best_config is a pandas row whose values are floats; cast back to the
    # types used in param_configs before passing them to the generator.
    best_params = {
        'points_per_side': int(best_config['points_per_side']),
        'pred_iou_thresh': float(best_config['pred_iou_thresh']),
        'min_mask_region_area': int(best_config['min_mask_region_area']),
        'stability_score_thresh': float(best_config['stability_score_thresh'])
    }

    # Create mask generator with best config
    best_mask_generator = SAM2AutomaticMaskGenerator(sam2, **best_params)

    # Generate final masks
    final_masks = best_mask_generator.generate(rgb_image)

    # Combine masks into one binary prediction, as in the tuning loop
    final_combined_mask = np.zeros_like(ground_truth, dtype=bool)
    for mask in final_masks:
        final_combined_mask |= mask['segmentation']

    # Visualize final result
    fig, axes = plt.subplots(2, 2, figsize=(16, 16))

    # Original RGB
    axes[0,0].imshow(rgb_image)
    axes[0,0].set_title('Original RGB Image')
    axes[0,0].axis('off')

    # Ground truth
    axes[0,1].imshow(rgb_image)
    axes[0,1].imshow(ground_truth, alpha=0.6, cmap='Blues')
    axes[0,1].set_title('Ground Truth Lakes')
    axes[0,1].axis('off')

    # SAM 2 result
    axes[1,0].imshow(rgb_image)
    axes[1,0].imshow(final_combined_mask, alpha=0.6, cmap='Reds')
    axes[1,0].set_title(f'SAM 2 Result (IoU: {best_config["iou"]:.3f})')
    axes[1,0].axis('off')

    # Comparison
    # RGBA overlay: only the three error classes are coloured (alpha 153/255 = 0.6).
    # Pixels in neither mask stay fully transparent, so the image underneath is
    # not darkened by an opaque black layer.
    overlay_alpha = 153
    comparison = np.zeros((*ground_truth.shape, 4), dtype=np.uint8)
    comparison[ground_truth & final_combined_mask] = [0, 255, 0, overlay_alpha]      # True positive (green)
    comparison[ground_truth & ~final_combined_mask] = [0, 0, 255, overlay_alpha]     # False negative (blue)
    comparison[~ground_truth & final_combined_mask] = [255, 0, 0, overlay_alpha]     # False positive (red)

    axes[1,1].imshow(rgb_image)
    axes[1,1].imshow(comparison)
    axes[1,1].set_title('Comparison: Green=Correct, Blue=Missed, Red=False+')
    axes[1,1].axis('off')

    plt.tight_layout()
    plt.savefig('sam2_best_result.png', dpi=300, bbox_inches='tight')
    plt.show()

    print(f"\n🎉 Best configuration achieved:")
    print(f"   IoU: {best_config['iou']:.3f}")
    print(f"   Precision: {best_config['precision']:.3f}")
    print(f"   Recall: {best_config['recall']:.3f}")
    print(f"   F1 Score: {best_config['f1']:.3f}")
    print(f"📁 Visualization saved to: sam2_best_result.png")

In [None]:
# Code to use the best configuration on new images
# Prints a ready-to-paste snippet with the winning parameter values filled in.
print("\n🚀 CODE TO USE BEST CONFIGURATION ON NEW IMAGES:")
print("=" * 60)

if results:
    # The int()/float() casts turn the pandas values back into plain literals
    code_template = f"""
# Use this configuration for new images:
best_mask_generator = SAM2AutomaticMaskGenerator(
    sam2,
    points_per_side={int(best_config['points_per_side'])},
    pred_iou_thresh={float(best_config['pred_iou_thresh'])},
    min_mask_region_area={int(best_config['min_mask_region_area'])},
    stability_score_thresh={float(best_config['stability_score_thresh'])}
)

# Apply to new image:
new_masks = best_mask_generator.generate(new_rgb_image)

# Combine masks:
lake_mask = np.zeros((new_rgb_image.shape[0], new_rgb_image.shape[1]), dtype=bool)
for mask in new_masks:
    lake_mask |= mask['segmentation']
"""
    print(code_template)
else:
    print("No successful results to provide code template")