# DINOv3 Lake Detection - Inference Notebook for Google Colab

Use your trained model to detect lakes in new satellite images

**What this notebook does:**
1. Loads your trained DINOv3 + U-Net model
2. Processes new satellite images to detect lakes
3. Optionally restricts predictions to glacier areas using a boundary shapefile
4. Saves results and creates visualizations

**Requirements:**
- Your trained model file (saved by the training notebook)
- New satellite images to process (the first three bands are read as RGB and cast to 8-bit)
- Optional: Boundary shapefile for glacier areas

In [None]:
# Step 1: Install required packages
# The `!pip` lines are IPython shell escapes: this cell only runs inside a
# notebook (e.g. Colab), not as a plain Python script.
# NOTE(review): scikit-image is installed but none of the imports in Step 2
# use it; drop it if nothing else needs it.
# NOTE(review): loading DINOv3 via AutoModel needs a transformers release that
# includes DINOv3 support - confirm the installed version if loading fails.
print("üì¶ Installing required packages...")
!pip install torch torchvision transformers
!pip install rasterio geopandas opencv-python
!pip install scikit-image matplotlib

print("‚úÖ All packages installed!")

In [None]:
# Step 2: Import all necessary libraries
# (the import statements follow; `google.colab` makes this cell Colab-only)
print("üìö Importing libraries...")

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import rasterio
import cv2
import geopandas as gpd
from rasterio.features import geometry_mask
from transformers import AutoModel
import torchvision.transforms as transforms
import os
from google.colab import drive

# Pick the compute device used by every later cell: CUDA when a GPU is
# visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"üñ•Ô∏è Using device: {device}")

print("‚úÖ All libraries imported!")

In [None]:
# Step 3: Mount Google Drive to access your saved model and images
# (Colab-only: `drive` comes from google.colab, imported in Step 2)
print("üíæ Mounting Google Drive...")
drive.mount('/content/drive')
print("‚úÖ Google Drive mounted!")

# Set your file paths here (UPDATE THESE PATHS)
# Checkpoint loaded in Step 5; it is expected to hold a 'model_state_dict' entry.
MODEL_PATH = '/content/drive/MyDrive/superlakes/models/dinov3_lake_inference_ready.pth'
# Glacier boundary polygons. If this file does not exist, create_boundary_mask
# prints a warning and the whole image is processed.
SHAPEFILE_PATH = '/content/drive/MyDrive/superlakes/vectors/clip_by_glacier.shp'

# Path to the new image you want to process (its first three bands are used as RGB)
NEW_IMAGE_PATH = '/content/drive/MyDrive/superlakes/new_satellite_image.tif'

# Where to save results (prediction GeoTIFFs and PNG visualisations);
# predict_lakes_in_image creates this folder if it is missing.
OUTPUT_DIR = '/content/drive/MyDrive/superlakes/results/'

print(f"üìÅ Model path: {MODEL_PATH}")
print(f"üó∫Ô∏è Shapefile path: {SHAPEFILE_PATH}")
print(f"üñºÔ∏è New image path: {NEW_IMAGE_PATH}")
print(f"üíæ Output directory: {OUTPUT_DIR}")

In [None]:
# Step 4: Define the model architecture (must match the training notebook)
print("üèóÔ∏è Setting up model architecture...")

class DynamicUNetDecoder(nn.Module):
    """
    Convolutional head that turns a DINOv3 patch-feature grid into a water-probability map.

    Four 3x3 conv + ReLU stages shrink the channels (feature_dim -> 512 -> 256 -> 128 -> 64).
    A 1x1 conv then projects to ``num_classes`` logits. The map is bilinearly resized to
    ``target_size x target_size`` and squashed to [0, 1] with a sigmoid.

    Input:  (batch, feature_dim, h, w) feature map.
    Output: (batch, num_classes, target_size, target_size) probabilities.
    """
    def __init__(self, feature_dim=768, num_classes=1, target_size=224):
        super().__init__()
        self.target_size = target_size

        # Attribute names conv1..conv4 / final are the state_dict keys that
        # load_state_dict() matches against the saved checkpoint - keep them.
        self.conv1 = nn.Conv2d(feature_dim, 512, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.final = nn.Conv2d(64, num_classes, kernel_size=1)

        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Feature-reduction stages, each followed by ReLU.
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.relu(conv(x))
        logits = self.final(x)

        # Upsample the coarse logit map to the fixed output resolution.
        side = self.target_size
        logits = nn.functional.interpolate(
            logits, size=(side, side), mode='bilinear', align_corners=False
        )
        return self.sigmoid(logits)

class DINOv3UNet(nn.Module):
    """
    Complete model: frozen DINOv3 backbone + ``DynamicUNetDecoder`` head.

    The backbone's ``last_hidden_state`` is reshaped into a square grid of patch
    features. The decoder maps that grid to a (batch, 1, 224, 224) water-probability map.

    Parameters:
    - dinov3_model_name: Hugging Face model id passed to ``AutoModel.from_pretrained``.
    - skip_register_tokens: controls which leading non-patch tokens are dropped.
        False (default): drop only the first (CLS) token. This is the original
        behaviour, kept so decoders trained with it receive the same inputs.
        True: additionally drop ``config.num_register_tokens`` register tokens.

    NOTE(review): the Hugging Face DINOv3 documentation describes the ViT outputs as
    CLS + register tokens + patch tokens (4 registers for these checkpoints; confirm
    with ``model.dinov3.config.num_register_tokens``). If so, the default path has a flaw.
    A 224x224 input yields 196 patches, so 200 tokens remain after CLS is dropped.
    Truncating to 14*14 = 196 then places the 4 register tokens in the grid and discards
    the last 4 patch tokens. Only set ``skip_register_tokens=True`` together with a
    decoder trained that way.
    """
    def __init__(self, dinov3_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
                 skip_register_tokens=False):
        super(DINOv3UNet, self).__init__()

        # Backbone that extracts features from the satellite image patches.
        print(f"   Loading DINOv3 model: {dinov3_model_name}")
        self.dinov3 = AutoModel.from_pretrained(dinov3_model_name)

        # Freeze the backbone: only the decoder carries trained weights.
        for param in self.dinov3.parameters():
            param.requires_grad = False

        # Plain attribute (not a parameter), so checkpoints load unchanged.
        self.skip_register_tokens = skip_register_tokens

        # Trained decoder head (768 = ViT-B hidden size expected from the backbone).
        self.decoder = DynamicUNetDecoder(feature_dim=768, target_size=224)

        print("   ‚úÖ Model architecture created!")

    def _num_prefix_tokens(self):
        """Return how many leading tokens of ``last_hidden_state`` are not patch tokens."""
        if not self.skip_register_tokens:
            return 1  # CLS only (original behaviour)
        config = getattr(self.dinov3, "config", None)
        num_registers = getattr(config, "num_register_tokens", 0) or 0
        return 1 + num_registers

    def forward(self, x):
        """
        Image batch -> DINOv3 features -> decoder -> water probability.

        ``x`` is the normalised image batch; predict_lakes_in_image builds it as
        (batch, 3, 224, 224). Returns the decoder output, (batch, 1, 224, 224) in [0, 1].
        """
        with torch.no_grad():  # backbone is frozen; skip its autograd bookkeeping
            features = self.dinov3(x).last_hidden_state

            # Drop the leading non-patch tokens (see class docstring).
            patch_features = features[:, self._num_prefix_tokens():]

            batch_size, num_patches, feature_dim = patch_features.shape

            # The token sequence is assumed to come from a square patch grid.
            h = int(num_patches**0.5)
            w = h

            # If the count is not a perfect square, pad with zeros or truncate to h*w.
            if h * w != num_patches:
                needed_patches = h * w
                if needed_patches > num_patches:
                    padding = torch.zeros(batch_size, needed_patches - num_patches,
                                          feature_dim, device=patch_features.device)
                    patch_features = torch.cat([patch_features, padding], dim=1)
                else:
                    patch_features = patch_features[:, :needed_patches]

            # (batch, h*w, features) -> (batch, features, h, w) for the conv decoder.
            feature_map = patch_features.reshape(batch_size, h, w, feature_dim)
            feature_map = feature_map.permute(0, 3, 1, 2)

        # Decoder runs outside no_grad so it could still be fine-tuned.
        water_mask = self.decoder(feature_map)
        return water_mask

print("‚úÖ Model architecture defined!")

In [None]:
# Step 5: Load your trained model
print("üîÑ Loading your trained model...")

# Stop here if the checkpoint is missing. Every later cell uses `model`, so
# continuing would print "Model ready" and then fail with a NameError in Step 9.
if not os.path.exists(MODEL_PATH):
    print(f"‚ùå Model file not found: {MODEL_PATH}")
    print("   Please check the path and make sure you saved the model correctly!")
    raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")

print(f"   Loading from: {MODEL_PATH}")
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
# Recent PyTorch versions default to weights_only=True, which limits what can be loaded.
checkpoint = torch.load(MODEL_PATH, map_location=device)

# Create the model architecture (must match the one used for training).
model = DINOv3UNet().to(device)

# Checkpoints from the training notebook keep the weights under 'model_state_dict';
# fall back to treating the file itself as a bare state_dict.
state_dict = checkpoint.get('model_state_dict', checkpoint)
model.load_state_dict(state_dict)

# Evaluation mode for inference.
model.eval()

# Report the training configuration when it was saved with the weights.
if 'model_config' in checkpoint:
    config = checkpoint['model_config']
    print(f"   ‚úÖ Model loaded successfully!")
    print(f"   Training configuration:")
    print(f"     - Patch size: {config.get('patch_size', 'unknown')}")
    print(f"     - Stride: {config.get('stride', 'unknown')}")
    print(f"     - DINOv3 model: {config.get('dinov3_model', 'unknown')}")
else:
    print(f"   ‚úÖ Model loaded successfully (basic version)!")

print("‚úÖ Model ready for inference!")

In [None]:
# Step 6: Helper functions used by the prediction step
print("üõ†Ô∏è Setting up helper functions...")

def create_boundary_mask(image_path, shapefile_path):
    """
    Rasterise the polygons in ``shapefile_path`` onto the pixel grid of ``image_path``.

    The shapefile is reprojected to the raster's CRS when the two CRS differ.

    Returns:
    - a (height, width) uint8 array: 1 for pixels inside any polygon, 0 elsewhere;
    - None when the shapefile does not exist (the caller then processes the whole image).
    """
    print(f"   üìê Creating boundary mask from shapefile...")

    if not os.path.exists(shapefile_path):
        print(f"   ‚ö†Ô∏è Shapefile not found: {shapefile_path}")
        print("   Will process entire image without boundary constraint")
        return None

    polygons = gpd.read_file(shapefile_path)
    print(f"   Found {len(polygons)} polygon(s) in shapefile")

    # Only the raster's georeferencing and size are needed, not its pixels.
    with rasterio.open(image_path) as raster:
        raster_crs = raster.crs
        raster_transform = raster.transform
        raster_shape = (raster.height, raster.width)

    print(f"   Image CRS: {raster_crs}")
    print(f"   Shapefile CRS: {polygons.crs}")

    if polygons.crs != raster_crs:
        print(f"   üîÑ Reprojecting shapefile to match image...")
        polygons = polygons.to_crs(raster_crs)

    # invert=True makes geometry_mask mark pixels INSIDE the polygons as True,
    # i.e. the same result as negating its default (invert=False) output.
    inside = geometry_mask(
        polygons.geometry,
        out_shape=raster_shape,
        transform=raster_transform,
        invert=True,
    )

    pixels_inside = inside.sum()
    percentage = pixels_inside / inside.size * 100

    print(f"   ‚úÖ Boundary mask created!")
    print(f"   Glacier area: {pixels_inside:,} pixels ({percentage:.1f}% of image)")

    return inside.astype(np.uint8)

def setup_image_transform():
    """
    Build the per-patch preprocessing pipeline (must match training).

    Steps: (H, W, 3) uint8 array -> PIL image -> float tensor (3, H, W) scaled to
    [0, 1] -> per-channel normalisation with the ImageNet mean/std below.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return transforms.Compose(steps)

print("‚úÖ Helper functions ready!")

In [None]:
# Step 7: Main prediction function
print("üéØ Setting up main prediction function...")


def _patch_origins(length, patch_size, stride):
    """
    Return the patch start offsets along one image axis of size ``length``.

    Offsets are 0, stride, 2*stride, ... up to ``length - patch_size``. The final
    offset ``length - patch_size`` is appended when the stride does not land on it,
    so the trailing edge strip is covered as well. The list is empty when the axis
    is shorter than one patch.
    """
    last_start = length - patch_size
    if last_start < 0:
        return []
    starts = list(range(0, last_start + 1, stride))
    if starts[-1] != last_start:
        starts.append(last_start)
    return starts


def _read_rgb_image(image_path):
    """
    Read the first three bands of ``image_path``.

    Returns (image_rgb, profile): a C-contiguous (height, width, 3) uint8 array and a
    copy of the raster profile. Raises ValueError when the raster has fewer than 3 bands.
    """
    with rasterio.open(image_path) as src:
        if src.count < 3:
            raise ValueError(
                f"{image_path} has {src.count} band(s); at least 3 (RGB) are required"
            )
        bands = src.read([1, 2, 3])  # (3, height, width)
        profile = src.profile.copy()

    if bands.dtype != np.uint8:
        # NOTE(review): astype(uint8) wraps values outside 0-255 (e.g. 16-bit data).
        # Kept for parity with training - confirm training converted images the same way.
        print(f"   WARNING: image dtype is {bands.dtype}; values are cast with "
              "astype(uint8), which wraps anything outside 0-255")
    image_rgb = np.transpose(bands, (1, 2, 0)).astype(np.uint8, order='C')
    return image_rgb, profile


def _save_prediction(final_mask, profile, image_path):
    """
    Write ``final_mask`` to OUTPUT_DIR as ``<image name>_lake_prediction.tif``.

    The output is a single-band float32 GeoTIFF with the source georeferencing.
    Returns the output path.
    """
    image_name = os.path.splitext(os.path.basename(image_path))[0]
    output_path = os.path.join(OUTPUT_DIR, f"{image_name}_lake_prediction.tif")
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    out_profile = profile.copy()
    # The source profile describes an RGB raster. A photometric=RGB/YCbCr setting or
    # JPEG/WebP (8-bit-only) compression may be rejected for one float32 band.
    out_profile.pop('photometric', None)
    if str(out_profile.get('compress', '')).lower() in ('jpeg', 'webp'):
        out_profile['compress'] = 'lzw'
    out_profile.update({
        'driver': 'GTiff',  # the output file is always named .tif
        'dtype': rasterio.float32,
        'count': 1,
        # NOTE: 0 doubles as nodata. Pixels outside the boundary or never covered
        # by a patch are 0, and so is an exact "no water" probability.
        'nodata': 0,
    })

    with rasterio.open(output_path, 'w', **out_profile) as dst:
        dst.write(final_mask.astype('float32'), 1)
    return output_path


def predict_lakes_in_image(image_path, shapefile_path=None, patch_size=16, stride=8,
                           save_result=True, batch_size=32):
    """
    Apply the trained model (global ``model`` on global ``device``) to a satellite image.

    The image is tiled into ``patch_size`` squares every ``stride`` pixels. The last
    row/column of patches is aligned to the image edge, so every pixel of an image at
    least one patch large is covered. Each patch is resized to 224x224 and predicted in
    batches of ``batch_size``. The result is resized back and averaged where patches
    overlap. With a shapefile, only patches whose centre lies inside the boundary are
    predicted, and the result is zeroed outside it.

    Parameters:
    - image_path: path to the satellite image (first three bands used as RGB)
    - shapefile_path: path to the boundary shapefile (optional)
    - patch_size: size of patches to process (use same as training)
    - stride: step size between patches (use same as training)
    - save_result: write a GeoTIFF to OUTPUT_DIR
    - batch_size: number of patches per forward pass (affects speed/memory only)

    Returns (final_mask, water_pixels, coverage_percent). ``final_mask`` is a
    (height, width) float32 probability map; ``water_pixels`` counts pixels > 0.5.

    Raises ValueError for non-positive patch_size/stride or images with < 3 bands.
    """
    if patch_size < 1 or stride < 1:
        raise ValueError(f"patch_size and stride must be >= 1 (got {patch_size}, {stride})")
    batch_size = max(1, int(batch_size))

    print(f"üñºÔ∏è Processing image: {os.path.basename(image_path)}")

    # Step 7a: Load the satellite image
    print("   üìÇ Loading satellite image...")
    image_rgb, profile = _read_rgb_image(image_path)

    height, width = image_rgb.shape[:2]
    print(f"   Image size: {width} x {height} pixels")
    print(f"   Using RGB channels: {image_rgb.shape}")

    # Step 7b: Create boundary mask if shapefile provided
    boundary_mask = None
    if shapefile_path:
        boundary_mask = create_boundary_mask(image_path, shapefile_path)

    # Step 7c: Set up for patch processing
    print(f"   üîÑ Processing with {patch_size}x{patch_size} patches, stride {stride}...")

    full_mask = np.zeros((height, width), dtype=np.float32)   # summed predictions
    count_mask = np.zeros((height, width), dtype=np.float32)  # patches covering each pixel
    transform = setup_image_transform()
    half = patch_size // 2

    patches_processed = 0
    patches_skipped = 0
    batch_tensors = []
    batch_origins = []

    def _run_batch():
        """Predict the queued patches and add them to the accumulators."""
        nonlocal patches_processed
        batch = torch.stack(batch_tensors).to(device)
        preds = model(batch)[:, 0].cpu().numpy()  # (B, 224, 224)
        for (y, x), pred_224 in zip(batch_origins, preds):
            pred_small = cv2.resize(pred_224, (patch_size, patch_size))
            full_mask[y:y + patch_size, x:x + patch_size] += pred_small
            count_mask[y:y + patch_size, x:x + patch_size] += 1
            patches_processed += 1
            if patches_processed % 1000 == 0:
                print(f"      Processed {patches_processed} patches...")
        batch_tensors.clear()
        batch_origins.clear()

    row_starts = _patch_origins(height, patch_size, stride)
    col_starts = _patch_origins(width, patch_size, stride)

    # Step 7d: Process image patch by patch (batched forward passes)
    with torch.no_grad():
        for y in row_starts:
            for x in col_starts:
                # Skip patches whose centre is outside the glacier boundary.
                if boundary_mask is not None and boundary_mask[y + half, x + half] == 0:
                    patches_skipped += 1
                    continue
                patch_224 = cv2.resize(image_rgb[y:y + patch_size, x:x + patch_size], (224, 224))
                batch_tensors.append(transform(patch_224))
                batch_origins.append((y, x))
                if len(batch_tensors) == batch_size:
                    _run_batch()
        if batch_tensors:
            _run_batch()

    print(f"   ‚úÖ Patch processing complete!")
    print(f"   Processed: {patches_processed} patches")
    print(f"   Skipped: {patches_skipped} patches (outside boundary)")
    if patches_processed == 0:
        print("   WARNING: no patches were processed (image smaller than one patch, "
              "or no patch centre inside the boundary); the prediction is all zeros")

    # Step 7e: Average overlapping predictions (pixels never covered stay 0)
    print("   üßÆ Averaging overlapping predictions...")
    final_mask = np.divide(full_mask, count_mask, out=np.zeros_like(full_mask), where=count_mask != 0)

    if boundary_mask is not None:
        final_mask = final_mask * boundary_mask

    # Step 7f: Calculate statistics (> 0.5 probability counts as water)
    water_pixels = (final_mask > 0.5).sum()
    total_pixels = final_mask.size
    coverage_percent = water_pixels / total_pixels * 100

    print(f"   üìä Results:")
    print(f"     Water pixels: {water_pixels:,}")
    if boundary_mask is not None:
        analysis_pixels = boundary_mask.sum()
        # Guard against an empty boundary (no pixels inside any polygon).
        coverage_of_analysis_area = (water_pixels / analysis_pixels * 100) if analysis_pixels else 0.0
        print(f"     Coverage of total image: {coverage_percent:.2f}%")
        print(f"     Coverage of analysis area: {coverage_of_analysis_area:.2f}%")
    else:
        print(f"     Coverage: {coverage_percent:.2f}%")

    # Step 7g: Save result if requested
    if save_result:
        output_path = _save_prediction(final_mask, profile, image_path)
        print(f"   üíæ Saved result: {output_path}")

    return final_mask, water_pixels, coverage_percent

print("‚úÖ Prediction function ready!")

In [None]:
# Step 8: Visualize results
print("üé® Setting up visualization function...")

def visualize_prediction(image_path, predicted_mask, save_plot=True):
    """
    Plot the source image, the probability map and a thresholded overlay side by side.

    Parameters:
    - image_path: raster whose first (up to) three bands are shown as RGB, cast with
      astype(uint8) as in the prediction step.
    - predicted_mask: (height, width) probability array, e.g. the first value returned
      by predict_lakes_in_image.
    - save_plot: also write ``<image name>_visualization.png`` (300 dpi) to OUTPUT_DIR,
      creating the folder if needed.
    """
    print("   üñºÔ∏è Creating visualization...")

    # Read only the bands that are displayed.
    with rasterio.open(image_path) as src:
        band_indexes = list(range(1, min(src.count, 3) + 1))
        image_rgb = np.transpose(src.read(band_indexes), (1, 2, 0)).astype(np.uint8)

    binary_mask = predicted_mask > 0.5

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Original image
    axes[0].imshow(image_rgb)
    axes[0].set_title('Original Satellite Image', fontsize=14)
    axes[0].axis('off')

    # Prediction probabilities
    im1 = axes[1].imshow(predicted_mask, cmap='Blues', vmin=0, vmax=1)
    axes[1].set_title('Lake Probability\n(0=No Water, 1=Definitely Water)', fontsize=14)
    axes[1].axis('off')
    plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)

    # Binary prediction overlay. Non-water pixels are masked out so they stay
    # transparent; drawing the full boolean array would paint every False pixel with
    # the near-white low end of 'Reds' and wash out the whole image.
    axes[2].imshow(image_rgb)
    water_overlay = np.ma.masked_where(~binary_mask, binary_mask.astype(np.uint8))
    axes[2].imshow(water_overlay, cmap='Reds', alpha=0.6, vmin=0, vmax=1)
    axes[2].set_title('Detected Lakes (Red Overlay)\nThreshold: 50% Confidence', fontsize=14)
    axes[2].axis('off')

    plt.tight_layout()

    if save_plot:
        image_name = os.path.splitext(os.path.basename(image_path))[0]
        plot_path = os.path.join(OUTPUT_DIR, f"{image_name}_visualization.png")
        # savefig does not create folders (e.g. when predictions were not saved).
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        plt.savefig(plot_path, dpi=300, bbox_inches='tight')
        print(f"   üíæ Saved visualization: {plot_path}")

    plt.show()
    plt.close(fig)  # free the figure; matters when visualising many images in a loop

    # Print some statistics
    water_pixels = binary_mask.sum()
    total_pixels = predicted_mask.size
    coverage = water_pixels / total_pixels * 100

    print(f"   üìä Summary:")
    print(f"     - Water pixels detected: {water_pixels:,}")
    print(f"     - Coverage: {coverage:.2f}%")
    print(f"     - Confidence range: {predicted_mask.min():.3f} to {predicted_mask.max():.3f}")

print("‚úÖ Visualization function ready!")

In [None]:
# Step 9: Run the inference on NEW_IMAGE_PATH and visualise the result.
print("üöÄ Starting lake detection inference...")
print("=" * 60)

if os.path.exists(NEW_IMAGE_PATH):
    # patch_size / stride must match the values used during training.
    # Pass shapefile_path=None to process the image without the glacier boundary.
    predicted_mask, water_pixels, coverage = predict_lakes_in_image(
        image_path=NEW_IMAGE_PATH,
        shapefile_path=SHAPEFILE_PATH,
        patch_size=16,
        stride=8,
        save_result=True,
    )

    visualize_prediction(NEW_IMAGE_PATH, predicted_mask, save_plot=True)

    print("=" * 60)
    print("üéâ INFERENCE COMPLETE!")
    print(f"‚úÖ Successfully detected {water_pixels:,} water pixels ({coverage:.2f}% coverage)")
    print(f"üíæ Results saved to: {OUTPUT_DIR}")
else:
    print(f"‚ùå Image not found: {NEW_IMAGE_PATH}")
    print("   Please update NEW_IMAGE_PATH with the correct path to your satellite image")

In [None]:
# Step 10: (Optional) Process multiple images
print("üìÅ Optional: Process multiple images at once")
print("Uncomment and modify the code below to process a folder of images")

# The example below is wrapped in a triple-quoted string so it does not run.
# Delete the surrounding quotes (and adjust the folders) to use it.
"""
# Example: Process all .tif files in a directory
input_directory = '/content/drive/MyDrive/superlakes/many_images/'
output_directory = '/content/drive/MyDrive/superlakes/batch_results/'

# predict_lakes_in_image() writes its GeoTIFFs to the global OUTPUT_DIR,
# so point it at the batch folder and make sure that folder exists.
OUTPUT_DIR = output_directory
os.makedirs(output_directory, exist_ok=True)

# Get list of all TIFF files (sorted for a reproducible processing order)
import glob
image_files = sorted(glob.glob(os.path.join(input_directory, '*.tif')))

print(f"Found {len(image_files)} images to process")

# Process each image
results = []
for i, image_path in enumerate(image_files, 1):
    print(f"\n--- Processing {i}/{len(image_files)}: {os.path.basename(image_path)} ---")
    
    try:
        # Predict lakes
        pred_mask, water_pix, coverage = predict_lakes_in_image(
            image_path=image_path,
            shapefile_path=SHAPEFILE_PATH,
            patch_size=16,
            stride=8,
            save_result=True
        )
        
        # Record results
        results.append({
            'image': os.path.basename(image_path),
            'water_pixels': water_pix,
            'coverage_percent': coverage,
            'status': 'success'
        })
        
    except Exception as e:
        print(f"‚ùå Error processing {os.path.basename(image_path)}: {e}")
        results.append({
            'image': os.path.basename(image_path),
            'error': str(e),
            'status': 'failed'
        })

# Save summary
import pandas as pd
df = pd.DataFrame(results)
summary_path = os.path.join(output_directory, 'batch_processing_summary.csv')
df.to_csv(summary_path, index=False)

# Count successes from `results`: with no images, df has no 'status' column
# and df['status'] would raise KeyError.
n_success = sum(r['status'] == 'success' for r in results)
print(f"\nüéâ Batch processing complete!")
print(f"Successfully processed: {n_success}/{len(results)} images")
print(f"Summary saved: {summary_path}")
"""

# Closing usage summary, printed as one block (one line per entry).
print("\n".join([
    "‚úÖ Inference notebook ready!",
    "\nüí° To use this notebook:",
    "1. Update the file paths in Step 3",
    "2. Run all cells in order",
    "3. Your results will be saved to Google Drive",
    "4. Check the visualization to see how well it worked!",
]))