# GCP-Based Patch Matching for Orthomosaic Registration

This notebook performs patch matching to find ground control points (GCPs) in orthomosaics and uses them to register the orthos to the basemap.

## Approach:
1. Load GCPs from CSV file (UTM Zone 10N coordinates, converted to WGS84)
2. Extract patches from basemap centered on each GCP
3. Use template matching to find corresponding patches in orthomosaics
4. Compute 2D shift or affine transformation from matches
5. Apply transformation to register orthos to basemap
6. Evaluate accuracy improvement

## Inputs:
- **Basemap**: `TestsiteNewWest_Spexigeo_RTK.tiff`
- **GCPs**: `25-3288-CONTROL-NAD83-UTM10N-EGM2008.csv` (NAD83 / UTM Zone 10N; converted to WGS84 in Step 2)
- **Orthomosaics**: 
  - `outputs/orthomosaics/orthomosaic_no_gcps.tif`
  - `outputs/orthomosaics/orthomosaic_with_gcps.tif`

## Outputs:
- All outputs saved to `outputs/gcp_matching/`
- Patches extracted from basemap
- Matched GCP locations in orthos
- Registered orthomosaics
- Accuracy evaluation

## Setup: Install Dependencies

In [None]:
# Install required packages if needed
import importlib
import subprocess
import sys

# pip distribution names to make available
packages = [
    'rasterio',
    'numpy',
    'matplotlib',
    'opencv-python',
    'scipy',
    'utm',
    'pillow'
]

# Distributions whose importable module name differs from the pip name.
# Without this mapping 'opencv-python' and 'pillow' would be probed as
# 'opencv_python' / 'pillow', never be found, and get reinstalled on every run.
import_names = {
    'opencv-python': 'cv2',
    'pillow': 'PIL',
}

for package in packages:
    module_name = import_names.get(package, package.replace('-', '_'))
    try:
        importlib.import_module(module_name)
    except ImportError:
        print(f"Installing {package}...")
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])

print("✓ Dependencies installed")

## Step 1: Setup - Imports and Paths

In [None]:
import numpy as np
import rasterio
from rasterio.transform import xy
from rasterio.warp import transform as transform_coords
from pathlib import Path
import cv2
import matplotlib.pyplot as plt
from scipy import ndimage
import json
import csv
import utm
from typing import Dict, List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

# Root folders: raw survey data and everything this notebook generates
data_dir = Path("/Users/mauriciohessflores/Documents/Code/Data/New Westminster Oct _25")
output_dir = Path("outputs")

# Inputs: reference basemap, control points, and the two orthomosaics to register
basemap_path = data_dir / "Michael_RTK_orthos" / "TestsiteNewWest_Spexigeo_RTK.tiff"
gcp_csv_path = data_dir / "25-3288-CONTROL-NAD83-UTM10N-EGM2008.csv"
ortho_no_gcps_path = output_dir / "orthomosaics" / "orthomosaic_no_gcps.tif"
ortho_with_gcps_path = output_dir / "orthomosaics" / "orthomosaic_with_gcps.tif"

# Outputs: one parent folder with a sub-folder per pipeline stage
gcp_matching_dir = output_dir / "gcp_matching"
patches_dir = gcp_matching_dir / "patches"
matches_dir = gcp_matching_dir / "matches"
registered_dir = gcp_matching_dir / "registered"

# The parent may need intermediate folders; the children only need themselves
gcp_matching_dir.mkdir(parents=True, exist_ok=True)
for _stage_dir in (patches_dir, matches_dir, registered_dir):
    _stage_dir.mkdir(exist_ok=True)

print(f"✓ Output directory: {gcp_matching_dir}")
print(f"  - Patches: {patches_dir}")
print(f"  - Matches: {matches_dir}")
print(f"  - Registered: {registered_dir}")

In [None]:
# Mount Google Drive and switch to Colab paths (only when running in Colab)
try:
    from google.colab import drive
except ImportError:
    # Not running in Colab: keep the local paths configured in the previous cell
    print("⚠️  google.colab not available - keeping local paths")
else:
    drive.mount('/content/drive')

    # Update paths for Colab
    data_dir = Path("/content/drive/MyDrive/Data/New Westminster Oct _25")
    output_dir = Path("/content/drive/MyDrive/Code/MyCode/research-westminster_ground_truth_analysis/outputs")

    # Update input paths accordingly
    basemap_path = data_dir / "Michael_RTK_orthos" / "TestsiteNewWest_Spexigeo_RTK.tiff"
    gcp_csv_path = data_dir / "25-3288-CONTROL-NAD83-UTM10N-EGM2008.csv"
    ortho_no_gcps_path = output_dir / "orthomosaics" / "orthomosaic_no_gcps.tif"
    ortho_with_gcps_path = output_dir / "orthomosaics" / "orthomosaic_with_gcps.tif"

    # The output folders were derived from the previous output_dir, so they
    # must be rebuilt too; otherwise results keep going to the old location.
    gcp_matching_dir = output_dir / "gcp_matching"
    patches_dir = gcp_matching_dir / "patches"
    matches_dir = gcp_matching_dir / "matches"
    registered_dir = gcp_matching_dir / "registered"
    for _colab_dir in (gcp_matching_dir, patches_dir, matches_dir, registered_dir):
        _colab_dir.mkdir(parents=True, exist_ok=True)

    print("✓ Colab paths configured")

## Step 2: Load GCPs from CSV and Convert to WGS84

In [3]:
# Load GCPs from CSV file and convert to WGS84
def load_gcps_from_csv(csv_path: Path, zone_number: int = 10, zone_letter: str = 'N') -> List[Dict]:
    """
    Load GCPs from a CSV file of UTM coordinates and convert them to WGS84.

    The file may or may not start with a header line:
      * If the first row already holds numbers in columns 1 and 2, the file is
        treated as headerless and read positionally (name, x, y).
      * Otherwise the first row is a header and the name / x / y columns are
        located by keyword ('name'/'id'/'point'/'label', 'east'/'x',
        'north'/'y'). If they cannot be found a warning is printed and the
        positional layout is used as a fallback.

    Easting/northing may appear in either column order; the order is inferred
    from the value ranges (easting 100,000-999,999; northing 0-10,000,000).
    Rows that are too short, non-numeric, or not valid UTM coordinates are
    skipped with a warning.

    Args:
        csv_path: Path to the GCP CSV file.
        zone_number: UTM zone number passed to utm.to_latlon (default 10).
        zone_letter: UTM zone letter passed to utm.to_latlon (default 'N').

    Returns:
        List of GCP dictionaries with 'id', 'lat', 'lon', 'x_utm' (easting)
        and 'y_utm' (northing) keys.
    """
    import csv

    # newline='' is what the csv module expects for correct quoted-newline handling
    with open(csv_path, 'r', newline='') as f:
        rows = [row for row in csv.reader(f) if row]

    if not rows:
        return []

    if _row_has_coordinates(rows[0], 1, 2):
        # The first row is data, not a header: read every row positionally
        name_idx, x_idx, y_idx = 0, 1, 2
        data_rows = rows
    else:
        header, data_rows = rows[0], rows[1:]
        name_idx = x_idx = y_idx = None

        # Common column name variations (the last matching column wins)
        for idx, col in enumerate(header):
            col_lower = col.lower().strip()
            if 'name' in col_lower or 'id' in col_lower or 'point' in col_lower or 'label' in col_lower:
                name_idx = idx
            elif 'east' in col_lower or ('x' in col_lower and 'utm' not in col_lower):
                x_idx = idx
            elif 'north' in col_lower or ('y' in col_lower and 'utm' not in col_lower):
                y_idx = idx

        if name_idx is None or x_idx is None or y_idx is None:
            print(f"⚠️  Could not find required columns. Found: {header}")
            print(f"   Looking for: name/id, x/easting, y/northing")
            print(f"   Falling back to positional columns: name, x, y")
            name_idx, x_idx, y_idx = 0, 1, 2

    gcps = []
    for row in data_rows:
        try:
            name = str(row[name_idx]).strip()
            x_val = float(row[x_idx])
            y_val = float(row[y_idx])
            lat, lon, easting, northing = _utm_to_latlon(x_val, y_val, zone_number, zone_letter)
        except (ValueError, IndexError) as e:
            # Short rows, non-numeric cells and out-of-range coordinates all land here
            print(f"⚠️  Skipping row: {e}")
            continue

        gcps.append({
            'id': name,
            'lat': lat,
            'lon': lon,
            # Stored in canonical order even if the CSV columns were swapped
            'x_utm': easting,
            'y_utm': northing
        })

    return gcps


def _row_has_coordinates(row, x_idx, y_idx) -> bool:
    """Return True when the cells at x_idx and y_idx both parse as floats."""
    try:
        float(row[x_idx])
        float(row[y_idx])
    except (ValueError, IndexError):
        return False
    return True


def _utm_to_latlon(x_val: float, y_val: float, zone_number: int, zone_letter: str):
    """
    Convert a UTM pair whose column order is uncertain to (lat, lon, easting, northing).

    Raises ValueError when neither ordering converts.
    """
    def plausible(easting, northing):
        return 100000 <= easting <= 999999 and 0 <= northing <= 10000000

    if plausible(x_val, y_val):
        orderings = [(x_val, y_val)]
    elif plausible(y_val, x_val):
        orderings = [(y_val, x_val)]
    else:
        # Neither order looks right; let utm decide, trying both
        orderings = [(x_val, y_val), (y_val, x_val)]

    last_error = None
    for easting, northing in orderings:
        try:
            lat, lon = utm.to_latlon(easting, northing, zone_number, zone_letter)
        except ValueError as exc:
            # NOTE(review): relies on utm's OutOfRangeError deriving from ValueError - confirm
            last_error = exc
            continue
        return lat, lon, easting, northing

    raise ValueError(f"not a valid UTM coordinate pair: ({x_val}, {y_val})") from last_error

# Read the control points and report what was found
gcps = load_gcps_from_csv(gcp_csv_path)
print(f"✓ Loaded {len(gcps)} GCPs from CSV")
if len(gcps) == 0:
    # Nothing parsed: show the head of the file so the format can be inspected
    print(f"⚠️  No GCPs loaded! Check CSV format.")
    print(f"   CSV path: {gcp_csv_path}")
    if gcp_csv_path.exists():
        print(f"   File exists. Showing first few lines:")
        with open(gcp_csv_path, 'r') as f:
            # zip() stops after five lines without reading any further
            for i, line in zip(range(5), f):
                print(f"     {line.strip()}")
else:
    print(f"\nFirst few GCPs:")
    for gcp in gcps[:3]:
        print(f"  {gcp['id']}: lat={gcp['lat']:.6f}, lon={gcp['lon']:.6f}")


✓ Loaded 0 GCPs from CSV

First few GCPs:


## Step 3: Convert GCPs to Pixel Coordinates in Basemap

In [None]:
# Convert GCPs (WGS84) to pixel coordinates in basemap
def gcp_to_pixel_coords(gcp_lat: float, gcp_lon: float, raster_path: Path) -> Optional[Tuple[int, int]]:
    """
    Locate a WGS84 point on a raster's pixel grid.

    The point is reprojected from EPSG:4326 into the raster's CRS and then
    mapped through the raster's transform.

    Returns:
        (col, row) when the pixel lies inside the raster, otherwise None.
    """
    with rasterio.open(raster_path) as dataset:
        # transform_coords takes x (longitude) before y (latitude)
        xs, ys = transform_coords('EPSG:4326', dataset.crs, [gcp_lon], [gcp_lat])
        pixel_row, pixel_col = rasterio.transform.rowcol(dataset.transform, xs[0], ys[0])
        inside = 0 <= pixel_row < dataset.height and 0 <= pixel_col < dataset.width

    return (pixel_col, pixel_row) if inside else None

# Cache the basemap georeferencing for later cells
with rasterio.open(basemap_path) as basemap_src:
    basemap_crs, basemap_transform = basemap_src.crs, basemap_src.transform
    basemap_width, basemap_height = basemap_src.width, basemap_src.height

print(f"Basemap CRS: {basemap_crs}")
print(f"Basemap dimensions: {basemap_width}x{basemap_height}")

# Map every GCP onto the basemap grid, keeping only those that fall inside it
gcp_pixel_coords = {}
for gcp in gcps:
    pixel_coords = gcp_to_pixel_coords(gcp['lat'], gcp['lon'], basemap_path)
    if not pixel_coords:
        print(f"⚠️  GCP {gcp['id']} is outside basemap bounds")
        continue

    gcp_pixel_coords[gcp['id']] = {
        'gcp': gcp,
        'pixel_col': pixel_coords[0],
        'pixel_row': pixel_coords[1]
    }

print(f"\n✓ Found {len(gcp_pixel_coords)} GCPs within basemap bounds")
print(f"\nFirst few GCP pixel coordinates:")
for gcp_id, coords in list(gcp_pixel_coords.items())[:3]:
    print(f"  {gcp_id}: col={coords['pixel_col']}, row={coords['pixel_row']}")

## Step 4: Extract Patches from Basemap

In [None]:
# Extract patches from basemap centered on GCPs
def extract_patch(raster_path: Path, center_col: int, center_row: int, patch_size: int) -> Optional[np.ndarray]:
    """
    Extract a patch from raster centered on given pixel coordinates.

    Args:
        raster_path: Path to raster file
        center_col: Center column (x)
        center_row: Center row (y)
        patch_size: Size of patch (must be odd, e.g., 29, 39, 49); the window
            read is 2 * (patch_size // 2) + 1 pixels on each side

    Returns:
        Patch array (H, W, C), or None if the full window does not fit inside
        the raster. A single-band raster is returned as 3 identical channels.
    """
    half_size = patch_size // 2

    with rasterio.open(raster_path) as src:
        # Clamp the window to the raster extent
        col_start = max(0, center_col - half_size)
        col_end = min(src.width, center_col + half_size + 1)
        row_start = max(0, center_row - half_size)
        row_end = min(src.height, center_row + half_size + 1)

        # A clamped window is smaller than requested: the patch is cut off by the edge
        if col_end - col_start < patch_size or row_end - row_start < patch_size:
            return None

        patch = src.read(
            window=rasterio.windows.Window(col_start, row_start, col_end - col_start, row_end - row_start)
        )

    # Reading without band indexes is expected to give (bands, H, W);
    # promote a 2-D result defensively so the code below has one shape to handle.
    if patch.ndim == 2:
        patch = patch[np.newaxis, :, :]

    # Transpose to (H, W, C) format
    patch = np.transpose(patch, (1, 2, 0))

    # Single band: replicate into 3-channel grayscale. The check has to be on
    # the channel axis; testing for a 2-D array after the read never fires.
    if patch.shape[2] == 1:
        patch = np.repeat(patch, 3, axis=2)

    return patch

# Template sizes to try; one dict of patches is built per size
patch_sizes = [29, 39, 49, 59]  # Try different sizes
basemap_patches = {}

for patch_size in patch_sizes:
    patches_for_size = {}
    basemap_patches[patch_size] = patches_for_size

    for gcp_id, coords in gcp_pixel_coords.items():
        patch = extract_patch(basemap_path, coords['pixel_col'], coords['pixel_row'], patch_size)

        # GCPs too close to the basemap edge yield no patch at this size
        if patch is None:
            continue

        patches_for_size[gcp_id] = patch

        # Keep a PNG copy of every template for visual inspection
        patch_path = patches_dir / f"basemap_{gcp_id}_{patch_size}x{patch_size}.png"
        plt.imsave(patch_path, patch.astype(np.uint8))

    print(f"✓ Extracted {len(basemap_patches[patch_size])} patches of size {patch_size}x{patch_size}")

print(f"\n✓ Patch extraction complete!")

## Step 5: Reproject Orthomosaics to Match Basemap CRS

In [6]:
from rasterio.warp import calculate_default_transform, reproject, Resampling, transform_bounds
from rasterio.transform import from_bounds
from rasterio.enums import Resampling as RasterioResampling
from affine import Affine

# Reproject orthos to match basemap CRS and resolution
def reproject_ortho_to_basemap(ortho_path: Path, basemap_path: Path, output_path: Path) -> Path:
    """
    Reproject orthomosaic to match basemap CRS and pixel size.

    The output grid covers the intersection of the ortho and basemap extents,
    at the basemap's pixel size. Its origin is the intersection's top-left
    corner, so it is not necessarily the same pixel grid as the basemap.
    The output transform is built manually rather than derived automatically,
    to avoid CPLE_AppDefinedError.

    Args:
        ortho_path: Orthomosaic to reproject.
        basemap_path: Raster providing the target CRS, extent and pixel size.
        output_path: Destination GeoTIFF. If it already exists it is reused.

    Returns:
        output_path.

    Raises:
        ValueError: If the ortho and basemap extents do not overlap.
    """
    import shutil

    if output_path.exists():
        print(f"  ✓ Already reprojected: {output_path}")
        return output_path

    with rasterio.open(basemap_path) as basemap_src:
        target_crs = basemap_src.crs
        target_bounds = basemap_src.bounds
        target_transform = basemap_src.transform

    with rasterio.open(ortho_path) as ortho_src:
        source_crs = ortho_src.crs
        source_bounds = ortho_src.bounds

        if source_crs == target_crs:
            # Same CRS: the file is copied as-is (no resampling to the basemap pixel size)
            print(f"  ✓ Already in target CRS")
            shutil.copy(ortho_path, output_path)
            return output_path

        # Transform source bounds to target CRS
        print(f"  Transforming source bounds to target CRS...")
        src_bounds_target_crs = transform_bounds(
            source_crs, target_crs,
            source_bounds.left, source_bounds.bottom,
            source_bounds.right, source_bounds.top
        )

        print(f"  Source bounds in target CRS: {src_bounds_target_crs}")

        # Get target pixel size (transform[0] = x scale, transform[4] = y scale)
        target_pixel_size_x = abs(target_transform[0])
        target_pixel_size_y = abs(target_transform[4])

        # Use intersection of bounds
        output_left = max(src_bounds_target_crs[0], target_bounds.left)
        output_bottom = max(src_bounds_target_crs[1], target_bounds.bottom)
        output_right = min(src_bounds_target_crs[2], target_bounds.right)
        output_top = min(src_bounds_target_crs[3], target_bounds.top)

        print(f"  Output bounds (intersection): left={output_left:.2f}, bottom={output_bottom:.2f}, right={output_right:.2f}, top={output_top:.2f}")

        # Validate bounds
        if output_right <= output_left or output_top <= output_bottom:
            raise ValueError(f"Invalid output bounds: width={output_right-output_left}, height={output_top-output_bottom}")

        # Calculate dimensions using target pixel size
        width = int((output_right - output_left) / target_pixel_size_x)
        height = int((output_top - output_bottom) / target_pixel_size_y)

        # Validate dimensions (an overlap thinner than one pixel truncates to 0)
        if width <= 0 or height <= 0:
            raise ValueError(f"Invalid dimensions: width={width}, height={height}")

        # North-up transform anchored at the top-left corner of the intersection
        transform = Affine.translation(output_left, output_top) * Affine.scale(target_pixel_size_x, -target_pixel_size_y)

        print(f"  ✓ Transform calculated: {width}x{height} pixels")

        # Only the band count and dtype are needed from the source here;
        # reproject() reads the pixels itself, so the raster is not loaded up front.
        source_count = ortho_src.count
        source_dtype = ortho_src.dtypes[0]

        reprojected_data = np.zeros((source_count, height, width), dtype=source_dtype)

        for band_idx in range(1, source_count + 1):
            reproject(
                source=rasterio.band(ortho_src, band_idx),
                destination=reprojected_data[band_idx - 1],
                src_transform=ortho_src.transform,
                src_crs=source_crs,
                dst_transform=transform,
                dst_crs=target_crs,
                resampling=Resampling.bilinear
            )

        # Save
        with rasterio.open(
            output_path,
            'w',
            driver='GTiff',
            height=height,
            width=width,
            count=source_count,
            dtype=reprojected_data.dtype,
            crs=target_crs,
            transform=transform,
            compress='lzw',
            BIGTIFF='YES',
            tiled=True,
            blockxsize=512,
            blockysize=512
        ) as dst:
            dst.write(reprojected_data)

    return output_path

# Reuse reprojected files from the test_matching notebook when available,
# otherwise reproject each orthomosaic onto the basemap CRS
import shutil

existing_reprojected_dir = output_dir / "test_matching" / "reprojected"
reprojected_dir = gcp_matching_dir / "reprojected"
reprojected_dir.mkdir(exist_ok=True)

ortho_paths = {
    'no_gcps': ortho_no_gcps_path,
    'with_gcps': ortho_with_gcps_path
}

reprojected_paths = {}
for ortho_name, ortho_path in ortho_paths.items():
    if not ortho_path.exists():
        print(f"⚠️  Ortho not found: {ortho_path}")
        continue

    reprojected_path = reprojected_dir / f"{ortho_name}_reprojected.tif"
    existing_reprojected = existing_reprojected_dir / f"{ortho_name}_reprojected.tif"

    if existing_reprojected.exists():
        # A previous notebook already did the work: copy its result over once
        print(f"\nFound existing reprojected file: {existing_reprojected}")
        if reprojected_path.exists():
            print(f"  ✓ Already exists: {reprojected_path}")
        else:
            shutil.copy(existing_reprojected, reprojected_path)
            print(f"  ✓ Copied to: {reprojected_path}")
    else:
        print(f"\nReprojecting {ortho_name}...")
        reprojected_path = reproject_ortho_to_basemap(
            ortho_path,
            basemap_path,
            reprojected_path
        )

    reprojected_paths[ortho_name] = reprojected_path

print(f"\n✓ Reprojection complete!")


Reprojecting no_gcps...


CPLE_AppDefinedError: Too many points (10201 out of 10201) failed to transform, unable to compute output bounds.

## Step 6: Find GCP Patches in Orthomosaics Using Template Matching

In [None]:
# Find GCP patches in orthomosaics using template matching
def find_patch_in_ortho(
    template_patch: np.ndarray,
    ortho_path: Path,
    search_center_col: int,
    search_center_row: int,
    search_radius: int = 500,  # Search within this radius (pixels)
    min_confidence: float = 0.5
) -> Optional[Tuple[int, int, float]]:
    """
    Find template patch in orthomosaic using template matching.

    A window of +/- search_radius pixels around the search centre (clamped to
    the raster) is read from the ortho, converted to 8-bit grayscale, and
    matched against the template with normalised cross-correlation
    (cv2.TM_CCOEFF_NORMED).

    Args:
        template_patch: Template as (H, W) or (H, W, C); it is cast to uint8.
        ortho_path: Raster to search in.
        search_center_col: Column of the search window centre, in ortho pixels.
        search_center_row: Row of the search window centre, in ortho pixels.
        search_radius: Half-size of the search window in pixels.
        min_confidence: The best score must exceed this value to count as a match.

    Returns:
        (col, row, confidence) of the matched patch centre in ortho pixels, or
        None if no match exceeds min_confidence or the search window cannot
        contain the template.
    """
    # Convert template to 8-bit grayscale
    if template_patch.ndim == 3 and template_patch.shape[2] > 1:
        template_gray = cv2.cvtColor(template_patch.astype(np.uint8), cv2.COLOR_RGB2GRAY)
    elif template_patch.ndim == 3:
        # (H, W, 1): there is no colour to convert, just drop the channel axis
        template_gray = template_patch[:, :, 0].astype(np.uint8)
    else:
        template_gray = template_patch.astype(np.uint8)
    template_height, template_width = template_gray.shape

    with rasterio.open(ortho_path) as ortho_src:
        # Define search window, clamped to the raster
        search_col_start = max(0, search_center_col - search_radius)
        search_col_end = min(ortho_src.width, search_center_col + search_radius)
        search_row_start = max(0, search_center_row - search_radius)
        search_row_end = min(ortho_src.height, search_center_row + search_radius)

        window_width = search_col_end - search_col_start
        window_height = search_row_end - search_row_start

        # matchTemplate requires the search image to be at least as large as
        # the template; near (or beyond) the raster edge the clamped window
        # can be smaller, or even empty.
        if window_width < template_width or window_height < template_height:
            return None

        search_region = ortho_src.read(
            window=rasterio.windows.Window(
                search_col_start,
                search_row_start,
                window_width,
                window_height
            )
        )

    # Convert to (H, W, C) and then grayscale
    if search_region.ndim == 3:
        search_region = np.transpose(search_region, (1, 2, 0))
        if search_region.shape[2] == 1:
            search_gray = search_region[:, :, 0]
        else:
            search_gray = cv2.cvtColor(search_region.astype(np.uint8), cv2.COLOR_RGB2GRAY)
    else:
        search_gray = search_region

    # Stretch non-8-bit data to the full uint8 range
    if search_gray.dtype != np.uint8:
        search_min = search_gray.min()
        search_max = search_gray.max()
        if search_max > search_min:
            search_gray = ((search_gray - search_min) / (search_max - search_min) * 255).astype(np.uint8)
        else:
            search_gray = np.zeros_like(search_gray, dtype=np.uint8)

    # Template matching
    result = cv2.matchTemplate(search_gray, template_gray, cv2.TM_CCOEFF_NORMED)

    # Find best match; max_loc is the (x, y) of the template's top-left corner
    _, max_val, _, max_loc = cv2.minMaxLoc(result)

    # Reject weak matches, and non-finite scores (guards against degenerate,
    # e.g. perfectly flat, regions producing nan/inf)
    if not np.isfinite(max_val) or max_val <= min_confidence:
        return None

    # Convert the top-left corner back to the patch centre in raster coordinates
    match_col = search_col_start + max_loc[0] + template_width // 2
    match_row = search_row_start + max_loc[1] + template_height // 2

    return (match_col, match_row, float(max_val))

# Find GCPs in each orthomosaic
matching_results = {}

for ortho_name, reprojected_path in reprojected_paths.items():
    print(f"\n{'='*60}")
    print(f"Finding GCPs in {ortho_name} orthomosaic")
    print(f"{'='*60}")

    matching_results[ortho_name] = {}

    with rasterio.open(reprojected_path) as ortho_src:
        ortho_transform = ortho_src.transform

    # The ortho grid need not share the basemap's origin, so a basemap pixel
    # index is not a valid index into the ortho. Go basemap pixel -> world ->
    # ortho pixel to centre the search on the same ground location.
    # NOTE(review): assumes the reprojected ortho is in the basemap CRS - confirm
    basemap_to_ortho_px = ~ortho_transform * basemap_transform

    # Try different patch sizes
    best_patch_size = None
    best_matches = 0

    for patch_size in patch_sizes:
        matches_found = 0

        for gcp_id, coords in gcp_pixel_coords.items():
            if gcp_id not in basemap_patches[patch_size]:
                continue

            template = basemap_patches[patch_size][gcp_id]

            # Expected position, recorded in basemap pixel coordinates
            expected_col = coords['pixel_col']
            expected_row = coords['pixel_row']

            # Same ground location on the ortho grid (pixel centre, hence +0.5)
            search_x, search_y = basemap_to_ortho_px * (expected_col + 0.5, expected_row + 0.5)
            search_col = int(np.floor(search_x))
            search_row = int(np.floor(search_y))

            # Search for patch
            match = find_patch_in_ortho(
                template,
                reprojected_path,
                search_col,
                search_row,
                search_radius=500
            )

            if match:
                match_col, match_row, confidence = match
                matches_found += 1

                if gcp_id not in matching_results[ortho_name]:
                    matching_results[ortho_name][gcp_id] = {}

                # Offsets stay "ortho pixel minus basemap pixel", which is the
                # mapping the later registration step consumes.
                matching_results[ortho_name][gcp_id][patch_size] = {
                    'expected_col': expected_col,
                    'expected_row': expected_row,
                    'matched_col': match_col,
                    'matched_row': match_row,
                    'offset_col': match_col - expected_col,
                    'offset_row': match_row - expected_row,
                    'confidence': confidence
                }

        print(f"  Patch size {patch_size}x{patch_size}: {matches_found}/{len(gcp_pixel_coords)} matches")

        if matches_found > best_matches:
            best_matches = matches_found
            best_patch_size = patch_size

    if best_patch_size is None:
        print(f"\n  ⚠️  No matches found for any patch size")
    else:
        print(f"\n  ✓ Best patch size: {best_patch_size}x{best_patch_size} ({best_matches} matches)")

print(f"\n✓ Patch matching complete!")

## Step 7: Compute 2D Shift or Affine Transformation

In [None]:
from scipy.spatial.distance import cdist

# Compute transformation from matches
def compute_transformation(matches: Dict, use_affine: bool = False) -> Dict:
    """
    Compute 2D shift or affine transformation from GCP matches.

    The transformation maps expected positions to matched positions
    (matched ~= T(expected)).

    Args:
        matches: Dictionary {gcp_id: {patch_size: match}}, where each match
            has 'expected_col', 'expected_row', 'matched_col', 'matched_row'
            and (optionally) 'confidence'.
        use_affine: If True, compute affine transformation; otherwise 2D shift

    Returns:
        Dictionary with transformation parameters:
          * {'type': 'shift', 'shift_x', 'shift_y', 'rmse', 'num_points'}
          * {'type': 'affine', 'matrix' (2x3 nested list), 'rmse', 'num_points'}
          * {'type': 'insufficient_points', 'error'} with fewer than 2 matches
        'rmse' is the root-mean-square of the per-coordinate residuals.
        When an affine is requested but the points cannot determine one
        (fewer than 3, or all collinear), the 2D shift is returned instead.
    """
    # Collect source and destination points
    src_points = []
    dst_points = []

    for match_data in matches.values():
        if not match_data:
            continue

        # Use the most confident match for this GCP; ties (or matches without
        # a confidence value) fall back to the largest patch size.
        best_patch_size = max(
            match_data,
            key=lambda size: (match_data[size].get('confidence', 0.0), size)
        )
        match = match_data[best_patch_size]

        src_points.append([match['expected_col'], match['expected_row']])
        dst_points.append([match['matched_col'], match['matched_row']])

    # reshape keeps the arrays 2-D even when there are no points
    src_points = np.array(src_points, dtype=np.float64).reshape(-1, 2)
    dst_points = np.array(dst_points, dtype=np.float64).reshape(-1, 2)
    num_points = len(src_points)

    if num_points < 2:
        return {'type': 'insufficient_points', 'error': 'Need at least 2 matches'}

    if use_affine and num_points >= 3:
        # Least-squares affine over ALL points: [x, y, 1] @ solution = [x', y'].
        # With exactly 3 points this is the exact affine through them.
        design = np.hstack([src_points, np.ones((num_points, 1))])
        solution, _, rank, _ = np.linalg.lstsq(design, dst_points, rcond=None)

        # rank < 3 means the points are collinear and no unique affine exists
        if rank == 3:
            errors = dst_points - design @ solution
            rmse = float(np.sqrt(np.mean(errors**2)))

            return {
                'type': 'affine',
                'matrix': solution.T.tolist(),  # 2x3, the layout cv2.warpAffine expects
                'rmse': rmse,
                'num_points': num_points
            }

    # Compute 2D shift (mean offset)
    offsets = dst_points - src_points
    shift_x = float(np.mean(offsets[:, 0]))
    shift_y = float(np.mean(offsets[:, 1]))

    # Compute RMSE of the residuals left after removing the mean shift
    errors = offsets - np.array([shift_x, shift_y])
    rmse = float(np.sqrt(np.mean(errors**2)))

    return {
        'type': 'shift',
        'shift_x': shift_x,
        'shift_y': shift_y,
        'rmse': rmse,
        'num_points': num_points
    }

# Compute transformations for each ortho
transformations = {}

for ortho_name in matching_results.keys():
    print(f"\n{'='*60}")
    print(f"Computing transformation for {ortho_name}")
    print(f"{'='*60}")

    # Try 2D shift first
    shift_result = compute_transformation(matching_results[ortho_name], use_affine=False)

    if 'error' in shift_result:
        # Too few matches: there are no shift/RMSE numbers to report, and
        # formatting a placeholder string with ':.2f' would raise ValueError.
        # The result is still recorded so the registration step can skip it.
        print(f"\n⚠️  {shift_result['error']} (found {len(matching_results[ortho_name])})")
        transformations[ortho_name] = shift_result
        continue

    print(f"\n2D Shift:")
    print(f"  Shift X: {shift_result['shift_x']:.2f} px")
    print(f"  Shift Y: {shift_result['shift_y']:.2f} px")
    print(f"  RMSE: {shift_result['rmse']:.2f} px")
    print(f"  Points: {shift_result.get('num_points', 0)}")

    # Try affine if we have enough points
    if len(matching_results[ortho_name]) >= 3:
        affine_result = compute_transformation(matching_results[ortho_name], use_affine=True)
        affine_rmse = affine_result.get('rmse', float('inf'))
        print(f"\nAffine Transformation:")
        print(f"  RMSE: {affine_rmse:.2f} px")
        print(f"  Points: {affine_result.get('num_points', 0)}")

        # Use the one with lower RMSE
        if affine_rmse < shift_result['rmse']:
            transformations[ortho_name] = affine_result
            print(f"  ✓ Using affine transformation (lower RMSE)")
        else:
            transformations[ortho_name] = shift_result
            print(f"  ✓ Using 2D shift (lower RMSE)")
    else:
        transformations[ortho_name] = shift_result
        print(f"  ✓ Using 2D shift (insufficient points for affine)")

# Save transformations
transformations_file = matches_dir / "transformations.json"
with open(transformations_file, 'w') as f:
    json.dump(transformations, f, indent=2)

print(f"\n✓ Transformations saved to: {transformations_file}")

## Step 8: Apply Transformation and Register Orthomosaics

In [None]:
# Apply transformation to orthomosaic
def apply_transformation(
    ortho_path: Path,
    transformation: Dict,
    output_path: Path,
    basemap_path: Path
) -> Path:
    """
    Apply transformation to register orthomosaic to basemap.

    The transformation is the one produced by compute_transformation: it maps
    expected (basemap) pixel positions to matched (ortho) pixel positions.
    Registering the ortho therefore means resampling it so that
    output[expected] = ortho[matched], i.e. applying the transformation as an
    inverse (output -> source) mapping.

    Args:
        ortho_path: Orthomosaic to register.
        transformation: Dict with 'type' of 'shift' ('shift_x', 'shift_y') or
            'affine' ('matrix', 2x3).
        output_path: Destination GeoTIFF.
        basemap_path: Raster providing the output size, transform and CRS.

    Returns:
        output_path.

    Raises:
        ValueError: If the transformation type is not 'shift' or 'affine'.
    """
    transform_type = transformation['type']
    if transform_type not in ('shift', 'affine'):
        raise ValueError(f"Unsupported transformation type: {transform_type!r}")

    with rasterio.open(basemap_path) as basemap_src:
        target_width = basemap_src.width
        target_height = basemap_src.height
        target_transform = basemap_src.transform
        target_crs = basemap_src.crs

    with rasterio.open(ortho_path) as ortho_src:
        source_data = ortho_src.read()
        source_count = ortho_src.count

    registered_data = np.zeros((source_count, target_height, target_width), dtype=source_data.dtype)

    if transform_type == 'shift':
        shift_x = transformation['shift_x']
        shift_y = transformation['shift_y']

        for band_idx in range(source_count):
            # ndimage.shift moves content by +shift (output[i] = input[i - shift]).
            # Content found at "matched" must end up at "expected", which is a
            # move by expected - matched = -shift, hence the negation.
            shifted = ndimage.shift(
                source_data[band_idx],
                (-shift_y, -shift_x),
                mode='constant',
                cval=0,
                order=1
            )

            # Crop or zero-pad (from the top-left corner) to the target dimensions
            rows = min(shifted.shape[0], target_height)
            cols = min(shifted.shape[1], target_width)
            registered_data[band_idx, :rows, :cols] = shifted[:rows, :cols]

    else:
        transform_matrix = np.array(transformation['matrix'], dtype=np.float32)

        # Convert to (H, W, C) for OpenCV
        if source_count == 1:
            img = source_data[0]
        else:
            img = np.transpose(source_data, (1, 2, 0))

        # WARP_INVERSE_MAP makes OpenCV use the matrix as the output -> source
        # mapping (dst(x, y) = src(M @ [x, y, 1])), which is exactly the
        # expected -> matched mapping stored in the transformation.
        # NOTE: the image is cast to uint8 before warping, as in the rest of
        # this pipeline; higher bit-depth data is truncated.
        registered_img = cv2.warpAffine(
            img.astype(np.uint8),
            transform_matrix,
            (target_width, target_height),
            flags=cv2.INTER_LINEAR | cv2.WARP_INVERSE_MAP,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=0
        )

        # Convert back to (C, H, W)
        if source_count == 1:
            registered_data[0] = registered_img
        else:
            registered_data = np.transpose(registered_img, (2, 0, 1))

    # Save registered orthomosaic on the basemap grid
    with rasterio.open(
        output_path,
        'w',
        driver='GTiff',
        height=target_height,
        width=target_width,
        count=source_count,
        dtype=registered_data.dtype,
        crs=target_crs,
        transform=target_transform,
        compress='lzw',
        BIGTIFF='YES',
        tiled=True
    ) as dst:
        dst.write(registered_data)

    return output_path

# Write a registered copy of every ortho that has a usable transformation
registered_paths = {}

for ortho_name, transformation in transformations.items():
    # Entries carrying an 'error' key had too few matches to estimate anything
    if 'error' in transformation:
        print(f"⚠️  Skipping {ortho_name}: {transformation['error']}")
        continue

    print(f"\nRegistering {ortho_name}...")

    source_path = reprojected_paths[ortho_name]
    destination_path = registered_dir / f"{ortho_name}_registered.tif"
    registered_path = apply_transformation(source_path, transformation, destination_path, basemap_path)

    registered_paths[ortho_name] = registered_path
    print(f"  ✓ Saved: {registered_path}")

print(f"\n✓ Registration complete!")

## Step 9: Evaluate Accuracy Improvement

In [None]:
# Compare registered orthos to basemap
def evaluate_accuracy(ortho_path: Path, basemap_path: Path, gcps: List[Dict]) -> Dict:
    """
    Score how closely an ortho matches the basemap at the GCP pixels.

    Each GCP is converted to a basemap pixel position with
    ``gcp_to_pixel_coords``. The band values of both rasters are sampled at
    that same (row, col) index and the Euclidean norm of their difference is
    recorded. The metric is therefore a pixel-value difference across bands,
    not a positional offset.

    NOTE(review): the ortho is indexed with basemap pixel coordinates, so it
    presumably shares the basemap's pixel grid -- confirm against the caller.

    Args:
        ortho_path: Raster to evaluate.
        basemap_path: Reference raster; also used to map GCPs to pixels.
        gcps: GCP records; the 'lat' and 'lon' entries of each are read.

    Returns:
        Dict with 'mean_error', 'rmse', 'max_error' and 'num_points', or
        {'error': 'No valid GCPs found'} when no GCP could be sampled in
        both rasters.
    """
    with rasterio.open(basemap_path) as reference:
        reference_bands = reference.read()

    with rasterio.open(ortho_path) as candidate:
        candidate_bands = candidate.read()

    # Arrays are indexed as (band, row, col); keep the spatial extents handy.
    ref_rows, ref_cols = reference_bands.shape[1], reference_bands.shape[2]
    cand_rows, cand_cols = candidate_bands.shape[1], candidate_bands.shape[2]

    differences = []

    for point in gcps:
        location = gcp_to_pixel_coords(point['lat'], point['lon'], basemap_path)
        if not location:
            continue

        col, row = location

        # Skip GCPs that fall outside either raster.
        if not (0 <= row < ref_rows and 0 <= col < ref_cols):
            continue
        if not (0 <= row < cand_rows and 0 <= col < cand_cols):
            continue

        reference_pixel = reference_bands[:, row, col]
        candidate_pixel = candidate_bands[:, row, col]

        # Euclidean norm of the per-band value difference at this pixel.
        delta = reference_pixel.astype(float) - candidate_pixel.astype(float)
        differences.append(np.sqrt(np.sum(delta**2)))

    if not differences:
        return {'error': 'No valid GCPs found'}

    return {
        'mean_error': float(np.mean(differences)),
        'rmse': float(np.sqrt(np.mean(np.array(differences)**2))),
        'max_error': float(np.max(differences)),
        'num_points': len(differences),
    }

# Evaluate accuracy for original and registered orthos
def _format_metric(result: Dict, key: str) -> str:
    """Return ``result[key]`` formatted to two decimals, or 'N/A' if absent.

    evaluate_accuracy returns ``{'error': ...}`` (no metric keys) when no GCP
    could be sampled; applying a ``.2f`` format spec to a placeholder string
    would raise ValueError, so the fallback is chosen before formatting.
    """
    value = result.get(key)
    return f"{value:.2f}" if value is not None else "N/A"


accuracy_results = {}

for ortho_name in reprojected_paths:
    print(f"\n{'='*60}")
    print(f"Accuracy evaluation: {ortho_name}")
    print(f"{'='*60}")

    # Original (reprojected but not registered)
    original_accuracy = evaluate_accuracy(
        reprojected_paths[ortho_name],
        basemap_path,
        gcps
    )

    print(f"\nOriginal (reprojected):")
    print(f"  Mean error: {_format_metric(original_accuracy, 'mean_error')}")
    print(f"  RMSE: {_format_metric(original_accuracy, 'rmse')}")

    # Orthos that were skipped during registration only get the baseline entry.
    if ortho_name not in registered_paths:
        accuracy_results[ortho_name] = {
            'original': original_accuracy
        }
        continue

    # Registered
    registered_accuracy = evaluate_accuracy(
        registered_paths[ortho_name],
        basemap_path,
        gcps
    )

    print(f"\nRegistered:")
    print(f"  Mean error: {_format_metric(registered_accuracy, 'mean_error')}")
    print(f"  RMSE: {_format_metric(registered_accuracy, 'rmse')}")

    # A relative improvement is only meaningful when both evaluations
    # produced an RMSE and the baseline RMSE is non-zero (it is the divisor).
    original_rmse = original_accuracy.get('rmse')
    registered_rmse = registered_accuracy.get('rmse')
    if original_rmse and registered_rmse is not None:
        improvement = ((original_rmse - registered_rmse) / original_rmse) * 100
        print(f"\n  Improvement: {improvement:.1f}%")
    else:
        improvement = None  # serialized as null in the JSON report
        print("\n  Improvement: N/A")

    accuracy_results[ortho_name] = {
        'original': original_accuracy,
        'registered': registered_accuracy,
        'improvement_percent': improvement
    }

# Save results
results_file = gcp_matching_dir / "accuracy_results.json"
with open(results_file, 'w') as f:
    json.dump(accuracy_results, f, indent=2)

print(f"\n✓ Accuracy results saved to: {results_file}")