# DINOv3 Lake Detection - SIMPLE Inference

**Super simple workflow:**
1. Load your complete trained model
2. Point it at a new satellite image  
3. Get lake predictions!

**No need to redefine the model architecture - the checkpoint stores the complete model object along with its patch size, stride, and boundary shapefile path.**

In [None]:
# Step 1: Install packages and setup
# NOTE: the `!pip` lines below are IPython/Colab shell magics, not Python
# syntax, so this cell only runs inside a notebook kernel.
print("üì¶ Installing packages...")
# Deep-learning stack
!pip install torch torchvision transformers
# Geospatial raster/vector I/O, image resizing, and plotting
!pip install rasterio geopandas opencv-python matplotlib

import torch
import numpy as np
import matplotlib.pyplot as plt
import rasterio
import cv2
import geopandas as gpd
from rasterio.features import geometry_mask
import torchvision.transforms as transforms
import os
from google.colab import drive

# Mount Google Drive so files under /content/drive become reachable
drive.mount('/content/drive')

# Pick the compute device: CUDA when available, otherwise CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"‚úÖ Using device: {device}")

In [None]:
# Step 2: Set your file paths (ONLY THING YOU NEED TO UPDATE!)
# All three paths live on the Google Drive mounted at /content/drive in Step 1.

# Your saved complete model (checkpoint read by torch.load in Step 3)
MODEL_PATH = '/content/drive/MyDrive/superlakes/models/dinov3_lake_detection_complete.pth'

# New image to process (opened with rasterio in Steps 5 and 7)
NEW_IMAGE_PATH = '/content/drive/MyDrive/superlakes/new_satellite_image.tif'

# Where to save results (created in Step 7 if it does not exist yet)
OUTPUT_DIR = '/content/drive/MyDrive/superlakes/simple_results/'

# Echo the configuration so a wrong path is easy to spot before running on
print(f"üìÅ Model: {MODEL_PATH}")
print(f"üñºÔ∏è Image: {NEW_IMAGE_PATH}")
print(f"üíæ Output: {OUTPUT_DIR}")

In [None]:
# Step 3: Load your trained model (SUPER SIMPLE!)
print("üîÑ Loading your trained model...")

# Load everything we saved.
# The checkpoint contains the whole pickled model object (not just a
# state_dict), so it has to be unpickled with weights_only=False. Recent
# PyTorch releases default to weights_only=True, which refuses to unpickle
# arbitrary classes and would make this load fail.
# SECURITY: full unpickling can execute arbitrary code - only load checkpoint
# files that you created yourself or otherwise trust.
try:
    checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=False)
except TypeError:
    # Older PyTorch versions do not know the weights_only keyword and always
    # perform a full unpickle.
    checkpoint = torch.load(MODEL_PATH, map_location=device)

# Extract the complete model
model = checkpoint['complete_model'].to(device)
model.eval()  # Set to evaluation mode

# Get training configuration stored next to the model in the checkpoint
patch_size = checkpoint['patch_size']
stride = checkpoint['stride']
shapefile_path = checkpoint['shapefile_path']

print(f"‚úÖ Model loaded successfully!")
print(f"   Training patch size: {patch_size}")
print(f"   Training stride: {stride}")
print(f"   Boundary shapefile: {os.path.basename(shapefile_path)}")
print(f"   Training info: {checkpoint['training_info']}")

In [None]:
# Step 4: Simple helper functions

def create_boundary_mask(image_path, shapefile_path):
    """Rasterize the boundary shapefile onto the pixel grid of an image.

    Args:
        image_path: Raster whose CRS, affine transform and height/width
            define the output grid.
        shapefile_path: Vector file with the boundary geometries.

    Returns:
        A uint8 array of shape (height, width) that is 1 inside the
        geometries and 0 elsewhere, or None when the shapefile path does
        not exist.
    """
    if not os.path.exists(shapefile_path):
        print(f"‚ö†Ô∏è Shapefile not found, processing entire image")
        return None

    boundaries = gpd.read_file(shapefile_path)

    # Only the georeferencing metadata is needed, not the pixel data
    with rasterio.open(image_path) as raster:
        raster_crs = raster.crs
        raster_transform = raster.transform
        raster_shape = (raster.height, raster.width)

    # Bring the geometries into the raster's coordinate system if they differ
    if boundaries.crs != raster_crs:
        boundaries = boundaries.to_crs(raster_crs)

    # invert=True marks pixels covered by the geometries as True
    inside = geometry_mask(
        boundaries.geometry,
        out_shape=raster_shape,
        transform=raster_transform,
        invert=True,
    )

    inside_count = inside.sum()
    inside_share = inside_count / inside.size * 100
    print(f"   Glacier area: {inside_count:,} pixels ({inside_share:.1f}%)")

    return inside.astype(np.uint8)

def setup_transforms():
    """Build the per-patch preprocessing pipeline fed to the model.

    Returns:
        A torchvision Compose that converts an array to a PIL image, then to
        a tensor, and finally normalizes the three channels.
    """
    # Per-channel statistics (the usual ImageNet mean/std values)
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)

print("‚úÖ Helper functions ready!")

In [None]:
# Step 5: THE MAIN FUNCTION - Detect lakes in new image

def _patch_origins(length, patch_size, stride):
    """Return patch start offsets along one axis, including an edge-aligned last patch."""
    last_origin = length - patch_size
    if last_origin < 0:
        # The axis is shorter than a single patch: no window fits
        return []
    origins = list(range(0, last_origin + 1, stride))
    # When (length - patch_size) is not a multiple of stride the regular grid
    # stops short of the far edge; add one final window flush with the edge.
    if origins[-1] != last_origin:
        origins.append(last_origin)
    return origins


def detect_lakes(image_path, model, patch_size, stride, shapefile_path=None):
    """Detect lakes in a satellite image with a sliding-window model.

    The image is cut into patch_size x patch_size windows spaced by stride.
    Each window is resized to 224x224, passed through the model, resized back
    and accumulated; overlapping predictions are averaged. Windows flush with
    the bottom and right edges are always included so the whole image is
    covered even when the stride does not divide the image size evenly.

    Args:
        image_path: Path of the raster to process (read with rasterio).
        model: Callable model; it is invoked on a (1, 3, 224, 224) tensor
            placed on the module-level `device`.
        patch_size: Side length, in pixels, of each window.
        stride: Step, in pixels, between consecutive windows.
        shapefile_path: Optional boundary shapefile. When given and found,
            windows whose centre lies outside the boundary are skipped and
            the final mask is zeroed outside the boundary.

    Returns:
        Tuple (final_mask, water_pixels, coverage, profile): the float32
        averaged prediction map, the number of pixels above 0.5, that number
        as a percentage of all pixels, and a copy of the source raster
        profile.
    """

    print(f"üñºÔ∏è Processing: {os.path.basename(image_path)}")

    # Load image
    with rasterio.open(image_path) as src:
        image = src.read()
        profile = src.profile.copy()

    # rasterio returns (bands, rows, cols); the rest of the code wants (rows, cols, bands)
    image = np.transpose(image, (1, 2, 0))
    # NOTE(review): astype(np.uint8) wraps values above 255 - this assumes
    # 8-bit imagery; confirm for 16-bit sensors.
    image_rgb = image[:, :, :3].astype(np.uint8)

    height, width = image_rgb.shape[:2]
    print(f"   Image size: {width} x {height}")

    # Create boundary mask
    boundary_mask = None
    if shapefile_path:
        boundary_mask = create_boundary_mask(image_path, shapefile_path)

    # Setup for processing: running sum of predictions and number of hits per pixel
    full_mask = np.zeros((height, width), dtype=np.float32)
    count_mask = np.zeros((height, width), dtype=np.float32)
    transform = setup_transforms()

    patches_processed = 0
    patches_skipped = 0

    print(f"   Processing with {patch_size}x{patch_size} patches...")

    row_origins = _patch_origins(height, patch_size, stride)
    col_origins = _patch_origins(width, patch_size, stride)
    if not row_origins or not col_origins:
        print(f"   Warning: image is smaller than one {patch_size}x{patch_size} patch, nothing to process")

    half_patch = patch_size // 2

    # Process patches
    with torch.no_grad():
        for y in row_origins:
            for x in col_origins:

                # Check boundary constraint: skip windows centred outside the boundary
                if boundary_mask is not None:
                    if boundary_mask[y + half_patch, x + half_patch] == 0:
                        patches_skipped += 1
                        continue

                # Extract and process patch
                small_patch = image_rgb[y:y+patch_size, x:x+patch_size, :3]
                patch_224 = cv2.resize(small_patch, (224, 224))
                patch_tensor = transform(patch_224).unsqueeze(0).to(device)

                # Model prediction
                # assumes the model returns one single-channel map per patch - TODO confirm
                pred_224 = model(patch_tensor).squeeze().cpu().numpy()
                pred_small = cv2.resize(pred_224, (patch_size, patch_size))

                # Accumulate
                full_mask[y:y+patch_size, x:x+patch_size] += pred_small
                count_mask[y:y+patch_size, x:x+patch_size] += 1

                patches_processed += 1
                if patches_processed % 1000 == 0:
                    print(f"      {patches_processed} patches processed...")

    # Average overlapping predictions; pixels never covered stay 0
    final_mask = np.divide(full_mask, count_mask, out=np.zeros_like(full_mask), where=count_mask!=0)

    # Apply boundary constraint
    if boundary_mask is not None:
        final_mask = final_mask * boundary_mask

    # Statistics
    water_pixels = (final_mask > 0.5).sum()
    coverage = water_pixels / final_mask.size * 100

    print(f"   ‚úÖ Complete! {patches_processed} patches processed, {patches_skipped} skipped")
    print(f"   üìä Water detected: {water_pixels:,} pixels ({coverage:.2f}%)")

    return final_mask, water_pixels, coverage, profile

print("‚úÖ Main detection function ready!")

In [None]:
# Step 6: RUN LAKE DETECTION! üöÄ

print("üöÄ Starting lake detection...")
print("=" * 50)

# Only run when the input image is actually present
if os.path.exists(NEW_IMAGE_PATH):
    # patch_size, stride and shapefile_path all come from the checkpoint loaded in Step 3
    predicted_mask, water_pixels, coverage, profile = detect_lakes(
        NEW_IMAGE_PATH,
        model,
        patch_size,
        stride,
        shapefile_path,
    )

    print("=" * 50)
    print("üéâ DETECTION COMPLETE!")
    print(f"‚úÖ Found {water_pixels:,} water pixels ({coverage:.2f}% coverage)")
else:
    print(f"‚ùå Image not found: {NEW_IMAGE_PATH}")

In [None]:
# Step 7: Save results and create visualization

print("üíæ Saving results...")

# Make sure the output folder exists before writing into it
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Output GeoTIFF is named after the input image
image_name = os.path.splitext(os.path.basename(NEW_IMAGE_PATH))[0]
output_path = os.path.join(OUTPUT_DIR, f"{image_name}_lakes.tif")

# Reuse the source georeferencing, but describe a single float32 band
profile.update(dtype=rasterio.float32, count=1, nodata=0)

with rasterio.open(output_path, 'w', **profile) as dst:
    dst.write(predicted_mask.astype('float32'), 1)

print(f"‚úÖ Saved prediction: {output_path}")

# Build the figure
print("üé® Creating visualization...")

# Re-read the original image for display, as (rows, cols, bands)
with rasterio.open(NEW_IMAGE_PATH) as src:
    image = np.transpose(src.read(), (1, 2, 0))
image_rgb = image[:, :, :3].astype(np.uint8)

# Three panels side by side: original, probability, binary overlay
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
ax_original, ax_probability, ax_overlay = axes

# Panel 1: original image
ax_original.imshow(image_rgb)
ax_original.set_title('Original Satellite Image', fontsize=14)
ax_original.axis('off')

# Panel 2: probability map with its colorbar
im1 = ax_probability.imshow(predicted_mask, cmap='Blues', vmin=0, vmax=1)
ax_probability.set_title('Lake Probability\n(0=No Water, 1=Water)', fontsize=14)
ax_probability.axis('off')
plt.colorbar(im1, ax=ax_probability, fraction=0.046, pad=0.04)

# Panel 3: thresholded detections drawn over the original image
ax_overlay.imshow(image_rgb)
binary_mask = predicted_mask > 0.5
ax_overlay.imshow(binary_mask, cmap='Reds', alpha=0.6)
ax_overlay.set_title('Detected Lakes (Red)', fontsize=14)
ax_overlay.axis('off')

plt.tight_layout()

# Write the figure to disk before showing it
plot_path = os.path.join(OUTPUT_DIR, f"{image_name}_visualization.png")
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"‚úÖ Saved visualization: {plot_path}")

# Final summary
print("\nüìä FINAL SUMMARY:")
print(f"   üñºÔ∏è Image processed: {os.path.basename(NEW_IMAGE_PATH)}")
print(f"   üíß Water pixels: {water_pixels:,}")
print(f"   üìä Coverage: {coverage:.2f}%")
print(f"   üéØ Confidence range: {predicted_mask.min():.3f} to {predicted_mask.max():.3f}")
print(f"   üíæ Results saved to: {OUTPUT_DIR}")
print("\nüéâ All done!")

In [None]:
# Step 8: (Optional) Process multiple images at once
# This cell only prints instructions when run as-is. The batch-processing
# example lives inside the triple-quoted string below, which is a bare string
# expression and is therefore never executed.

print("üìÅ OPTIONAL: Batch process multiple images")
print("Uncomment the code below to process a folder of images:")

# To enable batch mode, remove the opening and closing triple quotes.
# The example calls detect_lakes() on every *.tif in batch_input_dir, writes
# one "<name>_lakes.tif" per image, and saves a batch_summary.csv via pandas.
# NOTE(review): batch_output_dir is only created inside the loop, so with an
# empty input folder the final to_csv() would target a missing directory.
"""
# Batch processing example
import glob

# Directory with many images
batch_input_dir = '/content/drive/MyDrive/superlakes/many_images/'
batch_output_dir = '/content/drive/MyDrive/superlakes/batch_results/'

# Find all images
image_files = glob.glob(os.path.join(batch_input_dir, '*.tif'))
print(f"Found {len(image_files)} images to process")

# Process each one
results = []
for i, img_path in enumerate(image_files, 1):
    print(f"\n--- {i}/{len(image_files)}: {os.path.basename(img_path)} ---")
    
    try:
        # Detect lakes
        pred_mask, water_pix, cov, prof = detect_lakes(
            img_path, model, patch_size, stride, shapefile_path
        )
        
        # Save result
        name = os.path.splitext(os.path.basename(img_path))[0]
        out_path = os.path.join(batch_output_dir, f"{name}_lakes.tif")
        
        os.makedirs(batch_output_dir, exist_ok=True)
        prof.update({'dtype': rasterio.float32, 'count': 1, 'nodata': 0})
        
        with rasterio.open(out_path, 'w', **prof) as dst:
            dst.write(pred_mask.astype('float32'), 1)
        
        results.append({
            'image': name,
            'water_pixels': water_pix,
            'coverage_percent': cov,
            'status': 'success'
        })
        
    except Exception as e:
        print(f"‚ùå Error: {e}")
        results.append({
            'image': os.path.basename(img_path),
            'error': str(e),
            'status': 'failed'
        })

# Save summary
import pandas as pd
df = pd.DataFrame(results)
summary_path = os.path.join(batch_output_dir, 'batch_summary.csv')
df.to_csv(summary_path, index=False)

print(f"\nüéâ Batch complete! {len(df[df['status'] == 'success'])}/{len(df)} successful")
print(f"Summary: {summary_path}")
"""

# Closing recap of the notebook workflow
print("\nüí° Usage Summary:")
print("1. Update paths in Step 2")
print("2. Run all cells")
print("3. Get results!")
print("\n‚ú® That's it - super simple! ‚ú®")