In [1]:
!pip install rasterio
!pip install geopandas
!pip install tqdm
!pip install shapely

Collecting rasterio
  Downloading rasterio-1.4.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting affine (from rasterio)
  Downloading affine-2.4.0-py3-none-any.whl.metadata (4.0 kB)
Collecting cligj>=0.5 (from rasterio)
  Downloading cligj-0.7.2-py3-none-any.whl.metadata (5.0 kB)
Collecting click-plugins (from rasterio)
  Downloading click_plugins-1.1.1.2-py2.py3-none-any.whl.metadata (6.5 kB)
Downloading rasterio-1.4.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.3/22.3 MB[0m [31m116.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading cligj-0.7.2-py3-none-any.whl (7.1 kB)
Downloading affine-2.4.0-py3-none-any.whl (15 kB)
Downloading click_plugins-1.1.1.2-py2.py3-none-any.whl (11 kB)
Installing collected packages: cligj, click-plugins, affine, rasterio
Successfully installed affine-2.4.0 click-plugins-1.1.1.2 cligj-0.7.2 rasterio-1.4.3


In [18]:
import os
import cv2
import rasterio
import numpy as np
import seaborn as sns

from PIL import Image
from tqdm import tqdm
from pathlib import Path
from tifffile import imread
from rasterio.windows import Window
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader

from scipy.ndimage import gaussian_filter

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering

In [19]:
def load_orthomosaic_tiff(tiff_path, max_size=2048):
    """
    Load an orthomosaic TIFF as an 8-bit image array.

    Readers are tried in order: rasterio, tifffile, then PIL. A failure of
    one of the first two readers is reported and the next one is tried; an
    error raised by the last reader (PIL) propagates to the caller.

    Args:
        tiff_path: Path to the TIFF file.
        max_size: If the longer image side exceeds this value, the image is
            downscaled (aspect ratio preserved) so that side equals max_size.

    Returns:
        np.ndarray of dtype uint8 with shape (H, W, C), C in {2, 3}, or
        (H, W) for single-plane images. Multi-band images keep only their
        first three bands.
    """
    print(f"Loading orthomosaic: {tiff_path}")

    try:
        # Method 1: rasterio (preferred for GeoTIFF); src.read() is (bands, H, W)
        with rasterio.open(tiff_path) as src:
            image = np.transpose(src.read(), (1, 2, 0))  # (H, W, C)
        print(f"Loaded with rasterio - Shape: {image.shape}, dtype: {image.dtype}")
    except Exception as rasterio_err:
        print(f"rasterio failed ({rasterio_err!r}); trying tifffile")
        try:
            # Method 2: tifffile
            image = imread(tiff_path)
            # Band-first 3/4-channel data is moved to channels-last.
            if image.ndim == 3 and image.shape[0] in (3, 4):
                image = np.transpose(image, (1, 2, 0))  # (C, H, W) -> (H, W, C)
            print(f"Loaded with tifffile - Shape: {image.shape}, dtype: {image.dtype}")
        except Exception as tifffile_err:
            print(f"tifffile failed ({tifffile_err!r}); trying PIL")
            # Method 3: PIL (context manager closes the file handle)
            with Image.open(tiff_path) as pil_image:
                image = np.array(pil_image)
            print(f"Loaded with PIL - Shape: {image.shape}, dtype: {image.dtype}")

    # Channel handling only applies to 3-D arrays: for a 2-D grayscale image
    # shape[-1] is the image width, not a channel count.
    if image.ndim == 3:
        if image.shape[-1] > 3:
            print(f"Keeping first 3 channels from {image.shape[-1]} channels")
            image = image[:, :, :3]
        elif image.shape[-1] == 1:
            # Single-band (e.g. rasterio grayscale) -> plain 2-D array, which
            # is what cv2.resize / imshow / the feature extractor expect.
            image = image[:, :, 0]

    # Convert to uint8 if needed
    if image.dtype == np.uint16:
        # Assumes the data spans the full 16-bit range.
        image = (image / 256).astype(np.uint8)
    elif np.issubdtype(image.dtype, np.floating):
        # Assumes values in [0, 1]. NaN (typical nodata) becomes 0 and
        # out-of-range values are clipped rather than wrapping on the cast.
        image = np.clip(np.nan_to_num(image, nan=0.0) * 255, 0, 255).astype(np.uint8)

    # Resize if too large (for memory efficiency)
    h, w = image.shape[:2]
    if max(h, w) > max_size:
        scale = max_size / max(h, w)
        # Guard against a zero-sized side for extreme aspect ratios.
        new_h, new_w = max(1, int(h * scale)), max(1, int(w * scale))
        image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
        print(f"Resized to: {image.shape}")

    return image

In [12]:
# ============================================================================
# DINOv2 Feature Extractor
# ============================================================================

class DINOv2FeatureExtractor:
    """Extract dense patch features using a pre-trained DINOv2 ViT (patch size 14)."""

    def __init__(self, model_size='small', device='cuda'):
        """
        Load the DINOv2 backbone from torch.hub and put it in eval mode.

        Args:
            model_size: 'small', 'base', 'large', or 'giant'
            device: torch device such as 'cuda' or 'cpu'. A CUDA device is
                replaced by 'cpu' when CUDA is not available.
        """
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            print("CUDA is not available - falling back to CPU")
            device = 'cpu'
        self.device = device
        self.model_size = model_size

        print(f"\n{'='*60}")
        print(f"Loading DINOv2 ({model_size}) model...")
        print(f"{'='*60}")

        # torch.hub entry points for each model size
        model_map = {
            'small': 'dinov2_vits14',
            'base': 'dinov2_vitb14',
            'large': 'dinov2_vitl14',
            'giant': 'dinov2_vitg14'
        }

        self.model = torch.hub.load('facebookresearch/dinov2', model_map[model_size])
        self.model = self.model.to(device)
        self.model.eval()

        # Dimension of each patch token
        self.feature_dim = self.model.embed_dim

        print(f"✓ Model loaded: {model_map[model_size]}")
        print(f"  Feature dimension: {self.feature_dim}")
        print(f"  Patch size: 14x14")

    def preprocess_image(self, image):
        """
        Convert an image into a normalized (3, H, W) float32 tensor.

        uint8 input is scaled to [0, 1]; other dtypes are assumed to already
        be in [0, 1]. A 2-D (grayscale) image is replicated to 3 channels,
        then ImageNet mean/std normalization is applied per channel.
        """
        image = np.asarray(image)
        if image.dtype == np.uint8:
            image = image.astype(np.float32) / 255.0

        # Handle grayscale
        if image.ndim == 2:
            image = np.stack([image] * 3, axis=-1)

        # ImageNet normalization
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        image = (image - mean) / std

        # (H, W, C) -> (C, H, W)
        return torch.from_numpy(np.ascontiguousarray(image, dtype=np.float32)).permute(2, 0, 1)

    def extract_features(self, image, stride=14):
        """
        Extract dense patch features from an image.

        DINOv2's patch embedding only accepts inputs whose height and width
        are multiples of the 14-pixel patch size. The image is therefore
        bilinearly resampled to (patch_h * 14, patch_w * 14), with
        patch_h = max(1, H // 14) and patch_w = max(1, W // 14). The
        resampling covers the whole image, so the feature grid stays aligned
        with it when it is later upsampled back to (H, W).

        Args:
            image: Input image (H, W, C) or (H, W)
            stride: Unused; kept for backward compatibility. Features are
                always sampled on the 14-pixel patch grid.

        Returns:
            features: Dense feature map (patch_h, patch_w, feature_dim) as np.ndarray
            feature_spatial_size: (patch_h, patch_w)
        """
        print(f"\n{'='*60}")
        print("Extracting DINOv2 Features")
        print(f"{'='*60}")

        h, w = image.shape[:2]
        print(f"Image size: {h} x {w}")

        patch = 14
        patch_h = max(1, h // patch)
        patch_w = max(1, w // patch)

        # Preprocess and snap the spatial size onto the patch grid
        image_tensor = self.preprocess_image(image).unsqueeze(0).to(self.device)
        target_size = (patch_h * patch, patch_w * patch)
        if target_size != (h, w):
            image_tensor = F.interpolate(image_tensor, size=target_size,
                                         mode='bilinear', align_corners=False)

        with torch.no_grad():
            features = self.model.forward_features(image_tensor)

            # Patch tokens only (CLS token excluded): (1, N_patches, feature_dim)
            patch_features = features['x_norm_patchtokens']

            # Reshape the token sequence to the spatial patch grid
            patch_features = patch_features.squeeze(0)
            patch_features = patch_features.reshape(patch_h, patch_w, self.feature_dim)

            print(f"✓ Features extracted")
            print(f"  Feature map size: {patch_h} x {patch_w} x {self.feature_dim}")
            print(f"  Reduction factor: {14}x")

        return patch_features.cpu().numpy(), (patch_h, patch_w)

In [13]:
# ============================================================================
# Spectral Clustering on DINOv2 Features
# ============================================================================

def spectral_clustering_segmentation(features, n_clusters=5,
                                    sigma=0.5, normalize=True,
                                    random_state=None):
    """
    Segment a DINOv2 feature map with spectral clustering.

    Only feature similarity is used (there is no spatial term).

    - Maps with more than 10,000 positions: spectral clustering runs on a
      random subsample of 5,000 feature vectors, using a precomputed Gaussian
      (RBF) affinity exp(-||x - y||^2 / (2 * sigma^2)). Every position then
      receives the label of its nearest sampled vector (1-NN).
    - Smaller maps: clustered directly with a k-nearest-neighbour affinity
      graph; sigma is not used in that path.

    Args:
        features: Feature map (H, W, feature_dim)
        n_clusters: Number of segments
        sigma: Gaussian kernel bandwidth (subsampled path only)
        normalize: L2-normalize each feature vector before clustering
        random_state: Seed for the subsample and for SpectralClustering.
            None keeps the previous behaviour (NumPy's global RNG).

    Returns:
        segmentation_map: Segmentation labels (H, W)
    """
    from sklearn.cluster import SpectralClustering
    from sklearn.metrics.pairwise import euclidean_distances
    from sklearn.neighbors import KNeighborsClassifier

    print(f"\n{'='*60}")
    print(f"Spectral Clustering (n_clusters={n_clusters})")
    print(f"{'='*60}")

    h, w, d = features.shape

    # Reshape features
    features_flat = features.reshape(-1, d)

    # Normalize features
    if normalize:
        features_flat = features_flat / (np.linalg.norm(features_flat, axis=1, keepdims=True) + 1e-8)

    print(f"Computing affinity matrix...")

    n_samples = features_flat.shape[0]

    if n_samples > 10000:
        print(f"  Large image ({n_samples} pixels), using approximate clustering...")
        # Dense affinity on all positions would be O(n^2) memory; subsample.
        rng = np.random if random_state is None else np.random.RandomState(random_state)
        indices = rng.choice(n_samples, size=5000, replace=False)
        features_sample = features_flat[indices]

        # Gaussian kernel on *squared* Euclidean distances
        sq_dists = euclidean_distances(features_sample, features_sample, squared=True)
        affinity_sample = np.exp(-sq_dists / (2 * sigma ** 2))

        sc = SpectralClustering(n_clusters=n_clusters, affinity='precomputed',
                                n_init=10, assign_labels='kmeans',
                                random_state=random_state)
        labels_sample = sc.fit_predict(affinity_sample)

        # Propagate labels: each position takes the label of its nearest
        # sampled feature vector (1-NN), not of a cluster centre.
        knn = KNeighborsClassifier(n_neighbors=1)
        knn.fit(features_sample, labels_sample)
        labels = knn.predict(features_flat)
    else:
        # Full spectral clustering for smaller maps; the kNN graph needs
        # n_neighbors < n_samples.
        sc = SpectralClustering(n_clusters=n_clusters, affinity='nearest_neighbors',
                                n_neighbors=min(10, max(1, n_samples - 1)),
                                n_init=10, assign_labels='kmeans',
                                random_state=random_state)
        labels = sc.fit_predict(features_flat)

    # Reshape to image
    segmentation_map = labels.reshape(h, w)

    print(f"✓ Clustering complete")
    for label in np.unique(labels):
        count = np.sum(labels == label)
        print(f"  Segment {label}: {count} pixels ({count/len(labels)*100:.1f}%)")

    return segmentation_map


In [None]:
# ============================================================================
# K-Means Clustering on DINOv2 Features
# (the function keeps its historical name `knn_segmentation`)
# ============================================================================

def knn_segmentation(features, n_clusters=5, normalize=True):
    """
    Segment a DINOv2 feature map with mini-batch K-Means.

    Despite its name, this function performs K-Means clustering, not
    k-nearest-neighbour classification. The clusterer is seeded
    (random_state=42), so results are repeatable for the same input.

    Args:
        features: Feature map (H, W, feature_dim)
        n_clusters: Number of segments
        normalize: L2-normalize each feature vector before clustering

    Returns:
        segmentation_map: Segmentation labels (H, W)
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"K-Means Clustering (n_clusters={n_clusters})")
    print(f"{banner}")

    height, width, dim = features.shape
    flat = features.reshape(-1, dim)

    if normalize:
        row_norms = np.linalg.norm(flat, axis=1, keepdims=True)
        flat = flat / (row_norms + 1e-8)

    from sklearn.cluster import MiniBatchKMeans

    print(f"Running K-Means...")
    clusterer = MiniBatchKMeans(
        n_clusters=n_clusters,
        random_state=42,
        batch_size=1000,
        n_init=10,
    )
    labels = clusterer.fit_predict(flat)

    print(f"✓ Clustering complete")
    total = len(labels)
    values, counts = np.unique(labels, return_counts=True)
    for value, count in zip(values, counts):
        print(f"  Segment {value}: {count} pixels ({count/total*100:.1f}%)")

    return labels.reshape(height, width)

In [14]:
# ============================================================================
# Post-processing: Upsampling and Refinement
# ============================================================================

def upsample_segmentation(segmentation_map, original_size, method='nearest'):
    """
    Upsample a low-resolution label map to a target size.

    Args:
        segmentation_map: Low-res label map (H', W') of integer labels
        original_size: Target size (H, W)
        method: 'nearest' (default) or 'bilinear'.
            'nearest' replicates labels using
            src_index = floor(dst_index * src_size / dst_size), the rule
            OpenCV's INTER_NEAREST uses. It is done in NumPy, so any integer
            dtype works (cv2.resize cannot handle int64).
            'bilinear' bilinearly interpolates one binary mask per label and
            assigns each pixel the label with the highest weight. Boundaries
            are smoothed without creating label values that were not in the
            input.

    Returns:
        upsampled_map: (H, W). 'nearest' keeps the input dtype; 'bilinear'
        returns int32.

    Raises:
        ValueError: if method is not 'nearest' or 'bilinear'.
    """
    if method not in ('nearest', 'bilinear'):
        raise ValueError(
            f"Unsupported upsampling method: {method!r} (expected 'nearest' or 'bilinear')"
        )

    print(f"\n{'='*60}")
    print("Upsampling Segmentation")
    print(f"{'='*60}")

    segmentation_map = np.asarray(segmentation_map)
    h_target, w_target = original_size
    h_low, w_low = segmentation_map.shape

    print(f"Upsampling: {h_low}x{w_low} -> {h_target}x{w_target}")

    if method == 'nearest':
        # Exact integer arithmetic for the source row/column of each output pixel
        rows = np.arange(h_target) * h_low // h_target
        cols = np.arange(w_target) * w_low // w_target
        upsampled = segmentation_map[rows[:, None], cols[None, :]]
    else:
        # Interpolate per-label masks instead of label ids: rounding an
        # interpolated id would invent labels between two neighbouring classes.
        best_score = np.full((h_target, w_target), -np.inf, dtype=np.float32)
        upsampled = np.zeros((h_target, w_target), dtype=np.int32)
        for label in np.unique(segmentation_map):
            mask = (segmentation_map == label).astype(np.float32)
            score = cv2.resize(mask, (w_target, h_target), interpolation=cv2.INTER_LINEAR)
            better = score > best_score
            best_score[better] = score[better]
            upsampled[better] = label

    print(f"✓ Upsampling complete")

    return upsampled


def refine_segmentation(segmentation_map, image, iterations=3, radius=4, eps=1e-3):
    """
    Refine a label map with an edge-aware guided filter driven by the image.

    Each iteration applies a guided filter (He et al.) to the binary mask of
    every label, using the grayscale image as guide. Each pixel is then
    reassigned to the label with the highest filtered response. This removes
    isolated label noise inside homogeneous image regions while keeping
    label boundaries that coincide with image edges. Only labels already
    present in the input can appear in the output.

    Args:
        segmentation_map: Label map (H, W)
        image: Guide image (H, W) or (H, W, C) with the same H, W. uint8 is
            scaled by 1/255; other dtypes are assumed to be in [0, 1].
        iterations: Number of refinement passes
        radius: Guided-filter window radius in pixels (window = 2*radius + 1)
        eps: Guided-filter regularization; larger values smooth across
            weaker edges

    Returns:
        refined_map: int32 label map (H, W)

    Raises:
        ValueError: if the image and label map sizes differ.
    """
    from scipy.ndimage import uniform_filter

    print(f"\n{'='*60}")
    print(f"Refining Segmentation ({iterations} iterations)")
    print(f"{'='*60}")

    # Grayscale guide in [0, 1]
    image = np.asarray(image)
    guide = image.astype(np.float64)
    if image.dtype == np.uint8:
        guide = guide / 255.0
    if guide.ndim == 3:
        if guide.shape[-1] >= 3:
            # Same luma weights as cv2.COLOR_RGB2GRAY
            guide = guide[..., :3] @ np.array([0.299, 0.587, 0.114])
        else:
            guide = guide.mean(axis=-1)
    if guide.shape != np.shape(segmentation_map):
        raise ValueError(
            f"Image size {guide.shape} does not match segmentation size {np.shape(segmentation_map)}"
        )

    size = 2 * radius + 1

    def _box(arr):
        # Mean over a (2r+1) x (2r+1) window
        return uniform_filter(arr, size=size, mode='reflect')

    # Guide statistics do not change between iterations
    mean_i = _box(guide)
    var_i = _box(guide * guide) - mean_i * mean_i
    denom = var_i + eps

    refined = np.asarray(segmentation_map).copy()
    labels = np.unique(refined)

    for _ in range(iterations):
        best_score = np.full(guide.shape, -np.inf)
        # Start from the current labels so a pixel with a non-finite score
        # (e.g. NaN in the guide) keeps its label.
        best_label = refined.copy()
        for label in labels:
            p = (refined == label).astype(np.float64)
            mean_p = _box(p)
            a = (_box(guide * p) - mean_i * mean_p) / denom
            b = mean_p - a * mean_i
            score = _box(a) * guide + _box(b)
            better = score > best_score
            best_score[better] = score[better]
            best_label[better] = label
        refined = best_label

    print(f"✓ Refinement complete")

    return refined.astype(np.int32)


In [15]:
# ============================================================================
# Visualization
# ============================================================================

def visualize_segmentation(image, segmentation_map, save_path=None):
    """
    Show the image, the label map and a blended overlay side by side.

    Labels are re-indexed to 0..K-1 (in sorted order) before coloring. The
    segmentation panel and the overlay therefore use the same color for the
    same label, even when the label values are not contiguous. The colorbar
    is annotated with the original label values.

    Args:
        image: (H, W) or (H, W, 3) image. Float images are assumed to be in
            [0, 1]; other dtypes are divided by 255 for the overlay.
        segmentation_map: (H, W) label map matching the image size
        save_path: Optional path; when given, the figure is saved there
    """
    labels, label_index = np.unique(segmentation_map, return_inverse=True)
    label_index = label_index.reshape(np.shape(segmentation_map))
    n_clusters = len(labels)

    # One color per distinct label
    colors = plt.cm.tab20(np.linspace(0, 1, n_clusters))
    cmap = ListedColormap(colors)

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Original image
    if len(image.shape) == 2:
        axes[0].imshow(image, cmap='gray')
    else:
        if image.dtype == np.float32 or image.dtype == np.float64:
            axes[0].imshow(np.clip(image, 0, 1))
        else:
            axes[0].imshow(image)
    axes[0].set_title('Original Image', fontsize=14, fontweight='bold')
    axes[0].axis('off')

    # Segmentation map: fixed vmin/vmax so index k maps exactly to colors[k]
    im = axes[1].imshow(label_index, cmap=cmap, interpolation='nearest',
                        vmin=-0.5, vmax=n_clusters - 0.5)
    axes[1].set_title(f'DINOv2 Segmentation ({n_clusters} segments)',
                      fontsize=14, fontweight='bold')
    axes[1].axis('off')
    cbar = plt.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)
    cbar.set_ticks(np.arange(n_clusters))
    cbar.set_ticklabels([str(label) for label in labels])

    # Overlay
    if len(image.shape) == 2:
        overlay = np.stack([image] * 3, axis=-1)
    else:
        overlay = image.copy()

    if overlay.dtype != np.float32 and overlay.dtype != np.float64:
        overlay = overlay.astype(np.float32) / 255.0

    # Same palette as the segmentation panel, looked up by label index
    seg_colored = colors[label_index, :3]
    overlay_blend = 0.6 * overlay + 0.4 * seg_colored

    axes[2].imshow(np.clip(overlay_blend, 0, 1))
    axes[2].set_title('Overlay', fontsize=14, fontweight='bold')
    axes[2].axis('off')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"✓ Visualization saved to {save_path}")

    plt.show()


def visualize_features(features, method='pca', save_path=None):
    """
    Render a (H, W, D) feature map as an RGB image by projecting it to 3 components.

    Each projected component is min-max normalized to [0, 1] independently.
    With a shared min/max, the dominant first component would wash out the
    other two. A constant component maps to 0 instead of dividing by zero.

    Args:
        features: Feature map (H, W, D)
        method: 'pca' or 'tsne'
        save_path: Optional path; when given, the figure is saved there

    Raises:
        ValueError: if method is not 'pca' or 'tsne'.
    """
    if method not in ('pca', 'tsne'):
        raise ValueError(f"Unsupported method: {method!r} (expected 'pca' or 'tsne')")

    print(f"\n{'='*60}")
    print(f"Visualizing Features ({method.upper()})")
    print(f"{'='*60}")

    h, w, d = features.shape
    features_flat = features.reshape(-1, d)

    if method == 'pca':
        from sklearn.decomposition import PCA
        pca = PCA(n_components=3)
        features_reduced = pca.fit_transform(features_flat)
        print(f"  Explained variance: {pca.explained_variance_ratio_.sum():.2%}")
    else:
        from sklearn.manifold import TSNE
        tsne = TSNE(n_components=3, random_state=42)
        features_reduced = tsne.fit_transform(features_flat)

    # Reshape and normalize each component to [0, 1]
    features_rgb = features_reduced.reshape(h, w, 3)
    mins = features_rgb.min(axis=(0, 1), keepdims=True)
    ranges = features_rgb.max(axis=(0, 1), keepdims=True) - mins
    features_rgb = (features_rgb - mins) / np.where(ranges > 0, ranges, 1)

    plt.figure(figsize=(10, 10))
    plt.imshow(features_rgb)
    plt.title(f'DINOv2 Features ({method.upper()} visualization)',
              fontsize=14, fontweight='bold')
    plt.axis('off')

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')

    plt.show()

    print(f"✓ Feature visualization complete")

In [16]:
# ============================================================================
# Complete Pipeline with DINOv2
# ============================================================================

def _hierarchical_segmentation(features, n_clusters=5, normalize=True):
    """
    Ward agglomerative clustering of a (H, W, D) feature map into an (H, W) label map.

    Only feature similarity is used (no spatial connectivity). Agglomerative
    clustering can be slow and memory-hungry for large feature maps, so
    prefer 'kmeans' for big images.
    """
    from sklearn.cluster import AgglomerativeClustering

    print(f"\n{'='*60}")
    print(f"Hierarchical Clustering (n_clusters={n_clusters})")
    print(f"{'='*60}")

    h, w, d = features.shape
    features_flat = features.reshape(-1, d)
    if normalize:
        features_flat = features_flat / (np.linalg.norm(features_flat, axis=1, keepdims=True) + 1e-8)

    labels = AgglomerativeClustering(n_clusters=n_clusters).fit_predict(features_flat)

    print(f"✓ Clustering complete")
    for label in np.unique(labels):
        count = np.sum(labels == label)
        print(f"  Segment {label}: {count} pixels ({count/len(labels)*100:.1f}%)")

    return labels.reshape(h, w)


def dinov2_segmentation_pipeline(image, n_clusters=5,
                                 model_size='small',
                                 refine=True,
                                 device='cuda',
                                 clustering_method='kmeans'):
    """
    Complete unsupervised segmentation pipeline using DINOv2.

    Steps: extract patch features, cluster them, upsample the label grid to
    the image size, then optionally refine it against the image.

    Args:
        image: Input drone image (H, W, C) or (H, W)
        n_clusters: Number of segments
        model_size: 'small', 'base', 'large', or 'giant'
        refine: Apply post-processing refinement
        device: 'cuda' or 'cpu'
        clustering_method: 'kmeans' (default, mini-batch K-Means),
            'spectral', or 'hierarchical' (Ward agglomerative)

    Returns:
        segmentation_map: Pixel-wise segmentation labels
        features: DINOv2 features
        extractor: DINOv2 feature extractor

    Raises:
        ValueError: if clustering_method is not supported (checked before
            the model is loaded).
    """
    supported_methods = ('kmeans', 'spectral', 'hierarchical')
    if clustering_method not in supported_methods:
        raise ValueError(
            f"Unsupported clustering_method: {clustering_method!r} (expected one of {supported_methods})"
        )

    print("\n" + "="*60)
    print("DINOv2 UNSUPERVISED SEGMENTATION PIPELINE")
    print("="*60)
    print(f"Clustering method: {clustering_method.upper()}")
    print(f"Number of clusters: {n_clusters}")

    # Float [0, 1] copy used as the refinement guide
    if image.dtype == np.uint8:
        image_display = image.astype(np.float32) / 255.0
    else:
        image_display = image.copy()

    # Step 1: Extract DINOv2 features
    print("\n[1/4] Extracting DINOv2 features...")
    extractor = DINOv2FeatureExtractor(model_size=model_size, device=device)
    features, feature_size = extractor.extract_features(image)

    # Step 2: Cluster features
    print(f"\n[2/4] Clustering features...")
    if clustering_method == 'kmeans':
        segmentation_low = knn_segmentation(features, n_clusters=n_clusters)
    elif clustering_method == 'spectral':
        segmentation_low = spectral_clustering_segmentation(features, n_clusters=n_clusters)
    else:
        segmentation_low = _hierarchical_segmentation(features, n_clusters=n_clusters)

    # Step 3: Upsample to original resolution
    print(f"\n[3/4] Upsampling to original resolution...")
    segmentation_map = upsample_segmentation(segmentation_low, image.shape[:2])

    # Step 4: Optional refinement
    if refine:
        print(f"\n[4/4] Refining segmentation...")
        segmentation_map = refine_segmentation(segmentation_map, image_display)
    else:
        print(f"\n[4/4] Skipping refinement")

    print("\n" + "="*60)
    print("✓ PIPELINE COMPLETE!")
    print("="*60)

    return segmentation_map, features, extractor

In [None]:
if __name__ == "__main__":
    print("DINOv2-based Unsupervised Drone Image Segmentation")
    print("\nThis pipeline:")
    print("1. Extracts dense DINOv2 features (no training needed!)")
    print("2. Clusters features using spatial-aware methods")
    print("3. Produces high-quality segmentation")
    print("\nAdvantages over traditional methods:")
    print("✓ No training required - uses pre-trained DINOv2")
    print("✓ Better semantic understanding")
    print("✓ Robust to lighting/viewpoint changes")
    print("✓ Works on diverse drone imagery")
    print("\nExample usage:")

    # Orthomosaic location: set the ORTHOMOSAIC_TIFF environment variable
    # instead of editing the notebook. "orthomosaic.tif" is only a placeholder.
    tiff_path = os.environ.get("ORTHOMOSAIC_TIFF", "orthomosaic.tif")
    if not Path(tiff_path).is_file():
        raise FileNotFoundError(
            f"Orthomosaic not found: {tiff_path!r}. "
            "Set the ORTHOMOSAIC_TIFF environment variable to a valid TIFF path."
        )

    # Load drone image
    image = load_orthomosaic_tiff(tiff_path, max_size=512)

    # Use the GPU only when one is actually available
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Run DINOv2 segmentation
    seg_map, features, extractor = dinov2_segmentation_pipeline(
        image=image,
        n_clusters=6,
        model_size='small',  # 'small', 'base', 'large', 'giant'
        clustering_method='hierarchical',  # 'kmeans', 'spectral', 'hierarchical'
        refine=True,
        device=device,
    )

    # Visualize
    visualize_segmentation(image, seg_map)
    visualize_features(features, method='pca')