# SAM Lake Detection with Teaching (Improved)

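This notebook shows how to **guide SAM** (Segment Anything Model) toward lakes using manual annotations. Statistics of the annotated lakes tune SAM's automatic-mask settings, and the lake centers are reused as point prompts. The tuned setup is then applied to detect lakes in a new image automatically (SAM's model weights are not retrained).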

**Simple 3-step workflow:**
1. **🎓 Learn** from 1-2 manual examples 
2. **🎯 Apply** to new images with optimized settings
3. **📊 Compare** results and export

**Requirements:** GPU runtime recommended for faster processing.

## Install dependencies

In [ ]:
# Install required packages
! pip install segment-geospatial leafmap scikit-learn opencv-python -q

print("‚úÖ Installation complete!")

In [ ]:
# Import libraries
import leafmap                      # interactive image comparison used in Step 3
from samgeo import SamGeo           # SAM wrapper used by every detection method below
import numpy as np
import rasterio                     # reads the GeoTIFF imagery and masks
import cv2                          # connected-component labelling of the manual mask
import json                         # saves the learned configuration
import os
from sklearn.cluster import KMeans  # NOTE(review): not used in any cell of this notebook section
import matplotlib.pyplot as plt     # training-analysis figure

print("‚úÖ Libraries imported successfully")

In [None]:
# Mount Google Drive (Colab only). The training and target rasters configured
# below live under /content/drive/MyDrive, so the mount must exist first.
from google.colab import drive

drive.mount(mountpoint='/content/drive')
print("‚úÖ Google Drive mounted at /content/drive")

## Step 1: üéì Learning from Manual Examples

This section analyzes your manual lake annotations to extract knowledge that will guide SAM.

In [ ]:
def analyze_manual_lakes(mask_path):
    """Extract lake guidance points and size statistics from a manual mask.

    Every non-zero pixel of band 1 of ``mask_path`` that is valid according to
    the raster's nodata mask counts as lake. Individual lakes are the
    8-connected components of that mask.

    Args:
        mask_path: Path to a single-band raster of manual lake annotations.

    Returns:
        A ``(lake_centers, stats)`` tuple:

        * ``lake_centers`` - list of ``(x, y)`` integer pixel coordinates
          (column, row), one per lake. Each is the lake's centroid (truncated
          to int). If that pixel lies outside the lake (e.g. a crescent-shaped
          lake), it is replaced by the lake pixel nearest the centroid. Every
          point is therefore a valid positive prompt for SAM.
        * ``stats`` - dict with ``num_lakes``, ``avg_size`` (mean lake area in
          pixels), ``coverage_percent`` (lake pixels as a percentage of valid
          pixels) and ``size_range`` (smallest, largest lake area in pixels).
          All values are plain Python numbers, so the dict is JSON-serializable.
    """
    print(f"üîç Analyzing: {os.path.basename(mask_path)}")

    with rasterio.open(mask_path) as src:
        band = src.read(1)
        # read_masks() is 0 where the dataset's nodata value/mask marks a pixel
        # invalid. Without it, a nodata fill such as 255 or -9999 would be
        # counted as lake by astype(bool).
        valid = src.read_masks(1) > 0

    mask = band.astype(bool) & valid

    # One pass gives every component's area, bounding box and centroid, instead
    # of rescanning the whole label image once per lake.
    num_labels, labels, cc_stats, centroids = cv2.connectedComponentsWithStats(
        mask.astype(np.uint8), connectivity=8
    )

    lake_centers = []
    lake_sizes = []
    for label in range(1, num_labels):  # label 0 is the background
        lake_sizes.append(int(cc_stats[label, cv2.CC_STAT_AREA]))
        lake_centers.append(
            _lake_prompt_point(labels, label, cc_stats[label], centroids[label])
        )

    valid_count = int(valid.sum())
    stats = {
        'num_lakes': len(lake_sizes),
        'avg_size': float(np.mean(lake_sizes)) if lake_sizes else 0,
        'coverage_percent': (int(mask.sum()) / valid_count) * 100 if valid_count else 0.0,
        'size_range': (min(lake_sizes), max(lake_sizes)) if lake_sizes else (0, 0),
    }

    print(f"   üìä Found {stats['num_lakes']} lakes, avg size: {stats['avg_size']:.0f} pixels")
    print(f"   üìä Coverage: {stats['coverage_percent']:.2f}%")

    return lake_centers, stats


def _lake_prompt_point(labels, label, component_stats, centroid):
    """Return an ``(x, y)`` pixel inside component ``label`` nearest its centroid."""
    cx, cy = centroid
    x, y = int(cx), int(cy)  # truncation, like int(np.mean(...)) of the pixel coords
    if labels[y, x] == label:
        return (x, y)

    # Centroid falls outside the lake: search only the component's bounding box
    # for the lake pixel closest to the centroid.
    left = int(component_stats[cv2.CC_STAT_LEFT])
    top = int(component_stats[cv2.CC_STAT_TOP])
    width = int(component_stats[cv2.CC_STAT_WIDTH])
    height = int(component_stats[cv2.CC_STAT_HEIGHT])
    rows, cols = np.nonzero(labels[top:top + height, left:left + width] == label)
    nearest = np.argmin((cols + left - cx) ** 2 + (rows + top - cy) ** 2)
    return (int(cols[nearest] + left), int(rows[nearest] + top))

def create_smart_sam_config(stats):
    """Derive SAM automatic-mask-generator settings from learned lake statistics.

    A moderate baseline is adjusted by three independent rules:

    * mean lake area below 100 px -> denser point grid (64 per side) and a
      smaller minimum region (a quarter of the mean area, at least 10 px);
    * coverage below 2 % -> lower predicted-IoU and stability thresholds;
    * more than 10 lakes -> two crop layers instead of one.

    Args:
        stats: Mapping with ``avg_size``, ``coverage_percent`` and
            ``num_lakes`` (as returned by ``analyze_manual_lakes``).

    Returns:
        dict suitable for ``SamGeo(..., sam_kwargs=...)``.
    """
    print("‚öôÔ∏è Creating optimized SAM configuration...")

    config = dict(
        points_per_side=32,
        pred_iou_thresh=0.76,
        stability_score_thresh=0.62,
        crop_n_layers=1,
        min_mask_region_area=30,
    )

    mean_area = stats['avg_size']
    if mean_area < 100:
        config.update(
            points_per_side=64,
            min_mask_region_area=max(10, int(mean_area // 4)),
        )
        print("   ‚Üí Optimized for small lakes")

    if stats['coverage_percent'] < 2.0:
        config.update(pred_iou_thresh=0.6, stability_score_thresh=0.5)
        print("   ‚Üí Optimized for sparse coverage")

    if stats['num_lakes'] > 10:
        config['crop_n_layers'] = 2
        print("   ‚Üí Optimized for many lakes")

    return config

In [None]:
def create_optimized_sam_config(lake_stats):
    """
    Create SAM configuration optimized for the detected lake characteristics.

    Args:
        lake_stats: Lake statistics mapping, e.g. the dict returned by
            ``analyze_manual_lakes``. Reads ``avg_size`` (pixels),
            ``num_lakes`` and the coverage percentage. Coverage is taken from
            ``total_coverage`` if present, otherwise from ``coverage_percent``,
            which is the key ``analyze_manual_lakes`` produces.

    Returns:
        dict of SAM automatic-mask-generator settings, suitable for
        ``SamGeo(..., sam_kwargs=...)``.

    Raises:
        KeyError: if ``lake_stats`` lacks ``avg_size`` or ``num_lakes``, or
            has neither coverage key.
    """
    print("‚öôÔ∏è Creating optimized SAM configuration...")

    # Start with base configuration
    sam_kwargs = {
        "points_per_side": 32,
        "pred_iou_thresh": 0.76,
        "stability_score_thresh": 0.62,
        "crop_n_layers": 1,
        "crop_n_points_downscale_factor": 2,
        "min_mask_region_area": 30,
    }

    # Optimize based on lake characteristics
    avg_size = lake_stats['avg_size']
    # analyze_manual_lakes reports coverage as 'coverage_percent'. Looking up
    # only 'total_coverage' made this function raise KeyError on that output.
    if 'total_coverage' in lake_stats:
        coverage = lake_stats['total_coverage']
    elif 'coverage_percent' in lake_stats:
        coverage = lake_stats['coverage_percent']
    else:
        raise KeyError("lake_stats must contain 'coverage_percent' (or legacy 'total_coverage')")
    num_lakes = lake_stats['num_lakes']

    # If lakes are small, increase sampling and lower minimum area
    if avg_size < 100:
        sam_kwargs["points_per_side"] = 64
        sam_kwargs["min_mask_region_area"] = max(10, int(avg_size // 4))
        print("   ‚Üí Optimized for small lakes: increased sampling, lowered min area")

    # If lakes are sparse, be more aggressive
    if coverage < 2.0:
        sam_kwargs["pred_iou_thresh"] = 0.6
        sam_kwargs["stability_score_thresh"] = 0.5
        print("   ‚Üí Optimized for sparse lakes: lowered quality thresholds")

    # If many small lakes, use multi-scale approach
    if num_lakes > 10 and avg_size < 200:
        sam_kwargs["crop_n_layers"] = 2
        sam_kwargs["crop_n_points_downscale_factor"] = 1
        print("   ‚Üí Optimized for many small lakes: multi-scale processing")

    print(f"   ‚úÖ Optimized SAM configuration: {sam_kwargs}")
    return sam_kwargs

## Configuration: Your Training Data

**Update these paths** to point to your training image and manual lake annotations:

In [ ]:
# 🎯 YOUR TRAINING DATA - edit these three paths to point at your own files.
TRAINING_IMAGE = '/content/drive/MyDrive/superlakes/2021-09-04_fcc_testclip2.tif'
TRAINING_MASK = '/content/drive/MyDrive/superlakes/lake_mask_testclip.tif'

# Image the learned settings/guidance will be applied to.
TARGET_IMAGE = '/content/drive/MyDrive/superlakes/2021-09-04_fcc_blurred_medium_blur.tif'

# Report, in order, whether each configured input exists before any work starts.
for label, candidate in {
    "Training image": TRAINING_IMAGE,
    "Training mask": TRAINING_MASK,
    "Target image": TARGET_IMAGE,
}.items():
    if not os.path.exists(candidate):
        print(f"‚ùå {label} not found: {candidate}")
        print("   ‚Üí Update the path above!")
        continue
    print(f"‚úÖ {label}: {os.path.basename(candidate)}")

In [ ]:
# üéì LEARN from manual annotations
print("üéì LEARNING PHASE")
print("=" * 40)

# Analyze training data to extract lake knowledge
lake_centers, lake_stats = analyze_manual_lakes(TRAINING_MASK)
optimized_config = create_smart_sam_config(lake_stats)

print(f"\n‚úÖ Learning complete!")
print(f"   ‚Üí Extracted {len(lake_centers)} lake center points")
print(f"   ‚Üí Created optimized SAM configuration")
print(f"   ‚Üí Ready to apply to new images")

## Step 2: üéØ Apply Learning to New Image

Now we'll apply the learned knowledge to detect lakes in a new image using three approaches.

In [ ]:
# Quick visualization of what we learned
def show_training_analysis(image_path, mask_path, centers):
    """Plot the training image, its manual lake mask, and the guidance points.

    Args:
        image_path: Raster with the training imagery (any number of bands).
        mask_path: Single-band manual lake mask; non-zero pixels are lakes.
        centers: Iterable of ``(x, y)`` pixel coordinates, e.g. the
            ``lake_centers`` returned by ``analyze_manual_lakes``.
    """
    with rasterio.open(image_path) as src:
        img = _to_display_rgb(src.read())

    with rasterio.open(mask_path) as src:
        mask = src.read(1).astype(bool)

    centers = list(centers)

    _, axes = plt.subplots(1, 3, figsize=(15, 5))

    # cmap only affects the grayscale (1-2 band) case; RGB data ignores it.
    axes[0].imshow(img, cmap='gray')
    axes[0].set_title('Training Image')
    axes[0].axis('off')

    axes[1].imshow(img, cmap='gray')
    axes[1].imshow(mask, alpha=0.6, cmap='Blues')
    axes[1].set_title(f'Manual Annotations\n{mask.sum():,} pixels')
    axes[1].axis('off')

    axes[2].imshow(img, cmap='gray')
    axes[2].imshow(mask, alpha=0.3, cmap='Blues')
    if centers:
        x, y = zip(*centers)
        axes[2].scatter(x, y, c='red', s=80, marker='+', linewidth=2)
    axes[2].set_title(f'Extracted Centers\n{len(centers)} points')
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()


def _to_display_rgb(bands):
    """Convert a rasterio band-first array into an array ``imshow`` can draw."""
    # rasterio's read() always returns (bands, rows, cols), so transpose
    # unconditionally. A `shape[0] <= 4` guard would leave >4-band arrays
    # untransposed and then slice image columns instead of bands.
    img = np.transpose(bands, (1, 2, 0))
    # imshow accepts (rows, cols) or (rows, cols, 3/4). With fewer than three
    # bands, show band 1 as grayscale; otherwise use the first three bands.
    img = img[:, :, 0] if img.shape[2] < 3 else img[:, :, :3]
    if img.dtype == np.uint8:
        return img
    # matplotlib clips integer RGB above 255 and expects float RGB in [0, 1],
    # so stretch other dtypes between the 2nd and 98th percentiles.
    img = img.astype(np.float64)
    finite = np.isfinite(img)
    if not finite.any():
        return np.zeros(img.shape)
    low, high = np.percentile(img[finite], (2, 98))
    span = high - low if high > low else 1.0
    return np.clip(np.nan_to_num((img - low) / span), 0.0, 1.0)


show_training_analysis(TRAINING_IMAGE, TRAINING_MASK, lake_centers)

## Three Detection Methods: Default → Optimized → Guided

In [None]:
# 🎯 TARGET IMAGE
# This is the new image where you want to detect lakes.
# Every detection cell below reads TARGET_IMAGE (set in the "Configuration:
# Your Training Data" cell). `target_image` is therefore bound to that value
# rather than holding a second hard-coded path, which would not change the
# image the methods actually process. To switch images, edit TARGET_IMAGE
# in the configuration cell.
target_image = TARGET_IMAGE

print(f"üéØ Target image: {os.path.basename(target_image)}")

if os.path.exists(target_image):
    print("‚úÖ Target image found!")
else:
    print("‚ùå Target image not found! Please update TARGET_IMAGE in the configuration cell.")

## 🆚 Method Comparison: Default vs. Learned

Let's compare three approaches:
1. **Default SAM** - No optimization
2. **Optimized SAM** - Using learned configuration but no guidance
3. **Guided SAM** - Using learned configuration + positive point guidance

In [ ]:
# Method 1: ü§ñ Default SAM (baseline)
print("ü§ñ Method 1: Default SAM")

sam_default = SamGeo(model_type="vit_l")
sam_default.generate(TARGET_IMAGE, output="default_masks.tif", foreground=True, unique=True)
sam_default.show_anns(axis="off", alpha=1, output="default_result.tif")

print("‚úÖ Default SAM complete")

In [ ]:
# Method 2: ⚙️ Optimized SAM (learned configuration)
# Same backbone as Method 1, but configured with sam_kwargs=optimized_config,
# the settings derived from the training mask in the learning phase.
print("‚öôÔ∏è Method 2: Optimized SAM (using learned config)")

sam_optimized = SamGeo(model_type="vit_l", sam_kwargs=optimized_config)
# Segment TARGET_IMAGE and write the mask raster to optimized_masks.tif.
sam_optimized.generate(TARGET_IMAGE, output="optimized_masks.tif", foreground=True, unique=True)
# Render the annotations to optimized_result.tif (read by the comparison cells).
sam_optimized.show_anns(axis="off", alpha=1, output="optimized_result.tif")

print("‚úÖ Optimized SAM complete")

In [ ]:
# Method 3: üéØ Guided SAM (learned config + point guidance) 
print("üéØ Method 3: Guided SAM (using learned config + guidance points)")

# For guided detection, we need to use a different approach
# Initialize SAM and set the image first
sam_guided = SamGeo(model_type="vit_l", sam_kwargs=optimized_config)

# Use the generate method but with a different approach for guidance
# Since point guidance requires predictor, we'll use a workaround
try:
    # Try the direct approach first
    sam_guided.set_image(TARGET_IMAGE)
    
    # Use first 10 guidance points from training
    guidance_points = lake_centers[:min(10, len(lake_centers))]
    point_labels = [1] * len(guidance_points)  # All positive
    
    print(f"   Using {len(guidance_points)} guidance points from training")
    
    # Predict with guidance
    masks = sam_guided.predict(
        point_coords=guidance_points,
        point_labels=point_labels,
        multimask_output=True
    )
    
    sam_guided.save_prediction("guided_masks.tif")
    sam_guided.show_anns(axis="off", alpha=1, output="guided_result.tif")
    
    print("‚úÖ Guided SAM complete (with point guidance)")
    
except AttributeError as e:
    print(f"‚ö†Ô∏è  Point guidance not available, using optimized config only")
    # Fallback: just use the optimized config without point guidance
    sam_guided.generate(TARGET_IMAGE, output="guided_masks.tif", foreground=True, unique=True)
    sam_guided.show_anns(axis="off", alpha=1, output="guided_result.tif")
    print("‚úÖ Guided SAM complete (config optimization only)")
    
except Exception as e:
    print(f"‚ùå Error in guided method: {e}")
    print("   Falling back to optimized config only")
    sam_guided.generate(TARGET_IMAGE, output="guided_masks.tif", foreground=True, unique=True)
    sam_guided.show_anns(axis="off", alpha=1, output="guided_result.tif")
    print("‚úÖ Guided SAM complete (fallback mode)")

## Step 3: üìä Compare Results

Let's see how the three methods performed:

In [ ]:
# Compare Default vs Optimized
# Interactive side-by-side comparison of the renderings written by Methods 1 and 2.
print("üÜö Default vs Optimized SAM:")
leafmap.image_comparison(
    "default_result.tif",
    "optimized_result.tif", 
    label1="Default SAM",
    label2="Optimized SAM",
)

In [ ]:
# Compare Optimized vs Guided
# Interactive side-by-side comparison of the outputs written by Methods 2 and 3.
print("üÜö Optimized vs Guided SAM:")
leafmap.image_comparison(
    "optimized_result.tif",
    "guided_result.tif",
    label1="Optimized SAM", 
    label2="Guided SAM",
)

In [ ]:
# Original image vs best result
# Interactive comparison of the raw target image with Method 3's output.
print("üÜö Original vs Final Result:")
leafmap.image_comparison(
    TARGET_IMAGE,
    "guided_result.tif",
    label1="Original Image",
    label2="Lake Detection Result",
)

## Export Results

Save the best results for further analysis:

In [ ]:
# Export best results to vector format
print("üíæ Exporting results...")

# Convert best results to vector format for GIS. This step is best-effort:
# a failure is reported with its cause, and the raster outputs remain usable.
# `except Exception` (not a bare `except:`) so Ctrl-C / SystemExit still work.
try:
    sam_guided.tiff_to_vector("guided_masks.tif", "lake_detection_results.gpkg")
    print("‚úÖ Vector results: lake_detection_results.gpkg")
except Exception as exc:
    print(f"‚ö†Ô∏è Vector export failed ({type(exc).__name__}: {exc}), but raster results available")


def _to_json_compatible(value):
    """``json.dump`` fallback converting NumPy scalars/arrays to Python values.

    Statistics computed with NumPy (e.g. ``np.int64`` sizes) are not
    JSON-serializable on their own; any other unsupported type still raises.
    """
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


# Save the learned configuration for future use
learned_config = {
    'sam_config': optimized_config,
    'lake_stats': lake_stats,
    'guidance_points': lake_centers[:10],  # same first 10 points Method 3 uses
    'training_files': {
        'image': os.path.basename(TRAINING_IMAGE),
        'mask': os.path.basename(TRAINING_MASK)
    }
}

with open('learned_lake_detection_config.json', 'w', encoding='utf-8') as f:
    json.dump(learned_config, f, indent=2, default=_to_json_compatible)

print("‚úÖ Configuration saved: learned_lake_detection_config.json")
print("‚úÖ Results saved: guided_masks.tif & guided_result.tif")

## 🎉 Summary & Next Steps

**What we accomplished:**
1. **🎓 Learned** from manual lake annotations (lake count, sizes, and center points — the number of lakes found is printed in the learning phase)
2. **⚙️ Optimized** the SAM configuration for your lake characteristics  
3. **🎯 Applied** the learned settings to detect lakes in a new image
4. **📊 Compared** Default vs Optimized vs Guided approaches
5. **💾 Exported** results in raster and vector formats

**Files created:**
- `guided_masks.tif` - Binary lake detection mask
- `guided_result.tif` - Colored visualization 
- `lake_detection_results.gpkg` - Vector format for GIS
- `learned_lake_detection_config.json` - Configuration for future use

**Next steps:**
- Apply the saved configuration to more images
- Fine-tune parameters if needed based on visual inspection
- Use batch processing for large datasets

## Batch Processing (Optional)

Apply learned configuration to multiple images:

In [ ]:
# Batch processing function - apply learned config to many images
def batch_process_images(image_paths, sam_config, output_dir="batch_results"):
    """Apply a learned SAM configuration to multiple images.

    For each image, writes ``<stem>_masks.tif`` (from ``SamGeo.generate``) and
    ``<stem>_result.tif`` (from ``show_anns``) into ``output_dir``. If two
    inputs share a file stem (e.g. ``a/scene.tif`` and ``b/scene.tif``), later
    ones get a ``_<index>`` suffix so earlier outputs are not overwritten.
    A failing image is recorded and processing continues with the next one.

    Args:
        image_paths: Iterable of image file paths.
        sam_config: ``sam_kwargs`` for SamGeo, e.g. the learned ``optimized_config``.
        output_dir: Directory for outputs; created if missing.

    Returns:
        One dict per image, either
        ``{"image", "status": "success", "output"}`` or
        ``{"image", "status": "error", "error"}``.
    """
    image_paths = list(image_paths)
    os.makedirs(output_dir, exist_ok=True)
    if not image_paths:
        # Nothing to do: avoid loading the SAM model at all.
        print("No images to process.")
        return []

    sam = SamGeo(model_type="vit_l", sam_kwargs=sam_config)
    results = []
    used_names = set()

    for i, image_path in enumerate(image_paths, 1):
        print(f"Processing {i}/{len(image_paths)}: {os.path.basename(image_path)}")

        output_name = os.path.splitext(os.path.basename(image_path))[0]
        while output_name in used_names:
            output_name = f"{output_name}_{i}"
        used_names.add(output_name)
        mask_output = os.path.join(output_dir, f"{output_name}_masks.tif")
        result_output = os.path.join(output_dir, f"{output_name}_result.tif")

        try:
            sam.generate(image_path, output=mask_output, foreground=True, unique=True)
            sam.show_anns(axis="off", alpha=1, output=result_output)
        except Exception as e:
            # Per-image boundary: record the failure and keep going.
            results.append({"image": image_path, "status": "error", "error": str(e)})
            print(f"   ‚ùå Error: {e}")
            continue

        results.append({"image": image_path, "status": "success", "output": mask_output})
        print(f"   ‚úÖ Saved: {mask_output}")

    successes = sum(r["status"] == "success" for r in results)
    print("\nüéâ Batch processing complete!")
    print(f"   Successful: {successes}/{len(image_paths)}")
    return results

# Example usage (uncomment and update paths to use):
# image_list = [
#     '/content/drive/MyDrive/superlakes/image1.tif',
#     '/content/drive/MyDrive/superlakes/image2.tif',
#     # Add more images...
# ]
# 
# batch_results = batch_process_images(image_list, optimized_config)

print("Batch processing function ready! Update image_list above to use it.")