# DINOv3 + U-Net for Pixel-Level Lake Segmentation

This notebook shows how to:
1. Use DINOv3 as a feature extractor (frozen backbone)
2. Add a U-Net decoder for pixel-level segmentation
3. Train on your manual lake masks
4. Get precise lake boundaries (not just patch-level predictions)

**Key difference from previous approach:**
- Previous: 224x224 patch → single prediction ("has lake")
- This: 224x224 patch → 224x224 mask (pixel-level "which pixels are lake")

## Step 1: Setup and Imports

In [None]:
# Core libraries
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import rasterio

# Deep learning
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

# DINOv3
from transformers import Dinov2Model, Dinov2Config
from huggingface_hub import login

# Login (you already have this)
login(token="my-login-token")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu


## Step 2: Load Your Data

In [2]:
# Load your satellite image and mask
image_path = "/Users/varyabazilova/Desktop/glacial_lakes/super_lakes/dinov3_tryout/test_data/2021-09-04_fcc_testclip.tif"
mask_path = "/Users/varyabazilova/Desktop/glacial_lakes/super_lakes/dinov3_tryout/test_data/lake_mask_testclip.tif"

with rasterio.open(image_path) as src:
    image = src.read()  # Shape: (channels, height, width)
    image = np.transpose(image, (1, 2, 0))  # Change to (height, width, channels)
    
with rasterio.open(mask_path) as src:
    mask = src.read(1)  # Read first band

print(f"Image shape: {image.shape}")
print(f"Mask shape: {mask.shape}")

# Convert to 0-255 range for neural networks
image_rgb = image[:,:,:3].astype(np.uint8)
mask_binary = (mask > 0).astype(np.float32)  # Binary mask for training

Image shape: (657, 653, 4)
Mask shape: (657, 653)


## Step 3: Create Training Dataset

In [3]:
class LakeSegmentationDataset(Dataset):
    """
    Dataset that creates image patches and corresponding mask patches
    for training pixel-level segmentation
    """
    def __init__(self, image, mask, patch_size=224, max_patches=100):
        self.image = image
        self.mask = mask
        self.patch_size = patch_size
        
        # Create patches
        self.patches = []
        self.mask_patches = []
        
        height, width = image.shape[:2]
        patch_count = 0
        
        # Sample patches across the image
        for y in range(0, height-patch_size, patch_size//2):
            for x in range(0, width-patch_size, patch_size//2):
                if patch_count >= max_patches:
                    break
                    
                img_patch = image[y:y+patch_size, x:x+patch_size]
                mask_patch = mask[y:y+patch_size, x:x+patch_size]
                
                if img_patch.shape[:2] == (patch_size, patch_size):
                    self.patches.append(img_patch)
                    self.mask_patches.append(mask_patch)
                    patch_count += 1
            
            if patch_count >= max_patches:
                break
        
        print(f"Created dataset with {len(self.patches)} patches")
        
        # Data transforms
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet normalization
        ])
    
    def __len__(self):
        return len(self.patches)
    
    def __getitem__(self, idx):
        image_patch = self.patches[idx]
        mask_patch = self.mask_patches[idx]
        
        # Transform image
        image_tensor = self.transform(image_patch)
        
        # Convert mask to tensor
        mask_tensor = torch.from_numpy(mask_patch).float().unsqueeze(0)  # Add channel dimension
        
        return image_tensor, mask_tensor

# Create dataset
dataset = LakeSegmentationDataset(image_rgb, mask_binary, patch_size=224, max_patches=50)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

Created dataset with 16 patches


## Step 4: Define DINOv3 + U-Net Model

In [4]:
class UNetDecoder(nn.Module):
    """
    Simple U-Net decoder that takes DINOv3 features and outputs pixel-level predictions
    """
    def __init__(self, feature_dim=768, num_classes=1):
        super(UNetDecoder, self).__init__()
        
        # Decoder layers - upsample from 14x14 to 224x224
        self.decoder1 = nn.Sequential(
            nn.ConvTranspose2d(feature_dim, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )  # 14x14 -> 28x28
        
        self.decoder2 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )  # 28x28 -> 56x56
        
        self.decoder3 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )  # 56x56 -> 112x112
        
        self.decoder4 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )  # 112x112 -> 224x224
        
        # Final prediction layer
        self.final = nn.Conv2d(64, num_classes, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        x = self.decoder1(x)
        x = self.decoder2(x)
        x = self.decoder3(x)
        x = self.decoder4(x)
        x = self.final(x)
        x = self.sigmoid(x)
        return x


class DINOv3UNet(nn.Module):
    """
    Complete model: DINOv3 backbone + U-Net decoder for segmentation
    """
    def __init__(self):
        super(DINOv3UNet, self).__init__()
        
        # Load DINOv3 model (smaller version for learning)
        self.dinov3 = Dinov2Model.from_pretrained("facebook/dinov2-base")
        
        # Freeze DINOv3 parameters (use as feature extractor only)
        for param in self.dinov3.parameters():
            param.requires_grad = False
        
        # U-Net decoder
        self.decoder = UNetDecoder(feature_dim=768)  # DINOv2-base has 768 features
        
        print("Model created: DINOv3 (frozen) + U-Net decoder (trainable)")
    
    def forward(self, x):
        # Extract features with DINOv3
        with torch.no_grad():  # Don't compute gradients for DINOv3
            features = self.dinov3(x).last_hidden_state  # Shape: (batch, 257, 768)
            
            # Remove CLS token and reshape to spatial format
            patch_features = features[:, 1:]  # Remove first token (CLS), shape: (batch, 256, 768)
            
            # Reshape to 2D feature map: 256 patches = 16x16 grid for 224x224 input
            batch_size = patch_features.shape[0]
            feature_map = patch_features.reshape(batch_size, 16, 16, 768)
            feature_map = feature_map.permute(0, 3, 1, 2)  # (batch, 768, 16, 16)
        
        # Generate segmentation mask
        mask = self.decoder(feature_map)
        
        return mask

# Create model
model = DINOv3UNet().to(device)

# Loss and optimizer
criterion = nn.BCELoss()  # Binary cross-entropy for lake/no-lake
optimizer = optim.Adam(model.decoder.parameters(), lr=0.001)  # Only train decoder

print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

Model created: DINOv3 (frozen) + U-Net decoder (trainable)
Model parameters: 5,090,177


## Step 5: Training Loop

In [5]:
# Training function
def train_model(model, dataloader, criterion, optimizer, num_epochs=10):
    model.train()
    
    for epoch in range(num_epochs):
        running_loss = 0.0
        
        for batch_idx, (images, masks) in enumerate(dataloader):
            images = images.to(device)
            masks = masks.to(device)
            
            # Forward pass
            optimizer.zero_grad()
            outputs = model(images)
            
            # Calculate loss
            loss = criterion(outputs, masks)
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            
            if batch_idx % 5 == 0:
                print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}/{len(dataloader)}, Loss: {loss.item():.4f}')
        
        avg_loss = running_loss / len(dataloader)
        print(f'Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}\n')

# Train the model
print("Starting training...")
train_model(model, dataloader, criterion, optimizer, num_epochs=5)
print("Training completed!")

Starting training...


ValueError: Using a target size (torch.Size([4, 1, 224, 224])) that is different to the input size (torch.Size([4, 1, 256, 256])) is deprecated. Please ensure they have the same size.

## Step 6: Test and Visualize Results

In [None]:
# Test the model on a few examples
model.eval()

# Get a batch of test data
test_images, test_masks = next(iter(dataloader))
test_images = test_images.to(device)

# Generate predictions
with torch.no_grad():
    predicted_masks = model(test_images)

# Move to CPU for visualization
test_images = test_images.cpu()
test_masks = test_masks.cpu()
predicted_masks = predicted_masks.cpu()

# Visualize results
fig, axes = plt.subplots(3, 4, figsize=(16, 12))

for i in range(min(4, len(test_images))):
    # Original image (denormalize for display)
    img = test_images[i].permute(1, 2, 0)
    img = img * torch.tensor([0.229, 0.224, 0.225]) + torch.tensor([0.485, 0.456, 0.406])
    img = torch.clamp(img, 0, 1)
    
    axes[0, i].imshow(img)
    axes[0, i].set_title(f'Input Image {i+1}')
    axes[0, i].axis('off')
    
    # Ground truth mask
    axes[1, i].imshow(test_masks[i].squeeze(), cmap='Blues')
    axes[1, i].set_title(f'Ground Truth {i+1}')
    axes[1, i].axis('off')
    
    # Predicted mask
    pred_mask = predicted_masks[i].squeeze()
    axes[2, i].imshow(pred_mask, cmap='Reds')
    axes[2, i].set_title(f'Prediction {i+1}')
    axes[2, i].axis('off')

plt.tight_layout()
plt.show()

# Calculate accuracy metrics
def calculate_metrics(pred_masks, true_masks, threshold=0.5):
    pred_binary = (pred_masks > threshold).float()
    
    intersection = (pred_binary * true_masks).sum()
    union = pred_binary.sum() + true_masks.sum() - intersection
    
    iou = intersection / (union + 1e-8)
    accuracy = ((pred_binary == true_masks).float().mean())
    
    return iou.item(), accuracy.item()

iou, accuracy = calculate_metrics(predicted_masks, test_masks)
print(f"\nModel Performance:")
print(f"IoU (Intersection over Union): {iou:.3f}")
print(f"Pixel Accuracy: {accuracy:.3f}")

## Step 7: Apply to Full Image

In [None]:
def predict_full_image(model, image, patch_size=224, stride=112):
    """
    Apply the trained model to a full large image using sliding window
    """
    model.eval()
    height, width = image.shape[:2]
    
    # Create output mask
    full_mask = np.zeros((height, width), dtype=np.float32)
    count_mask = np.zeros((height, width), dtype=np.float32)
    
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    print("Applying model to full image...")
    
    patches_processed = 0
    with torch.no_grad():
        for y in range(0, height - patch_size + 1, stride):
            for x in range(0, width - patch_size + 1, stride):
                # Extract patch
                patch = image[y:y+patch_size, x:x+patch_size]
                
                # Transform and predict
                patch_tensor = transform(patch).unsqueeze(0).to(device)
                pred_mask = model(patch_tensor).squeeze().cpu().numpy()
                
                # Add to full mask
                full_mask[y:y+patch_size, x:x+patch_size] += pred_mask
                count_mask[y:y+patch_size, x:x+patch_size] += 1
                
                patches_processed += 1
                if patches_processed % 20 == 0:
                    print(f"  Processed {patches_processed} patches...")
    
    # Average overlapping predictions
    full_mask = np.divide(full_mask, count_mask, out=np.zeros_like(full_mask), where=count_mask!=0)
    
    return full_mask

# Apply to a smaller region first (full image might be too large)
# Take a 1000x1000 subset for demonstration
subset_size = 1000
y_start, x_start = 1000, 1000  # Adjust these coordinates
image_subset = image_rgb[y_start:y_start+subset_size, x_start:x_start+subset_size]
mask_subset = mask_binary[y_start:y_start+subset_size, x_start:x_start+subset_size]

# Predict on subset
predicted_full_mask = predict_full_image(model, image_subset, patch_size=224, stride=112)

# Visualize full prediction
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

axes[0].imshow(image_subset)
axes[0].set_title('Original Image (Subset)')
axes[0].axis('off')

axes[1].imshow(mask_subset, cmap='Blues')
axes[1].set_title('Ground Truth Mask')
axes[1].axis('off')

axes[2].imshow(predicted_full_mask, cmap='Reds')
axes[2].set_title('DINOv3+UNet Prediction')
axes[2].axis('off')

plt.tight_layout()
plt.show()

print("\nPixel-level segmentation complete!")
print("This approach gives you precise lake boundaries, not just patch-level predictions.")

## Summary

**What this notebook achieved:**

1. **True pixel-level segmentation** - each pixel gets classified as lake/not lake
2. **DINOv3 features** - powerful satellite-trained representations 
3. **U-Net decoder** - specialized for precise boundary detection
4. **Training on your data** - learns from your manual lake masks

**Key advantages over patch classification:**
- Precise lake boundaries (not rectangular patches)
- Pixel-level accuracy
- Better for tracking lake changes over time
- Scalable to any image size

**Next steps:**
- Train on more data for better accuracy
- Apply to multiple time periods to track changes
- Fine-tune hyperparameters
- Use larger DINOv3 models for better features