# Orthomosaic Registration Pipeline Demo

This notebook demonstrates the complete hierarchical orthomosaic registration pipeline:
1. Downloading a basemap covering the area defined by a set of H3 cells
2. Preprocessing and overlap computation
3. Feature matching at multiple scales
4. Transformation computation and application
5. Final registered orthomosaic


## Setup and Imports


In [None]:
import sys
from pathlib import Path
import json
import xml.etree.ElementTree as ET
import numpy as np
import matplotlib.pyplot as plt
import rasterio
from rasterio.plot import show
import warnings
warnings.filterwarnings('ignore')  # NOTE(review): this silences every warning (rasterio, numpy, torch, ...) for the rest of the kernel session; consider narrowing to specific categories/modules

# Add parent directory to path to import modules.
# Assumes the kernel's working directory is a direct subdirectory of the
# repository (e.g. the notebooks folder) — TODO confirm when launching elsewhere.
notebook_dir = Path.cwd()
repo_root = notebook_dir.parent
# Only prepend once: re-running this cell would otherwise stack duplicate
# entries at the front of sys.path.
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

from defaults import DEFAULT_SCALES, DEFAULT_ALGORITHMS, DEFAULT_MATCHER
from basemap_downloader import (
    download_basemap, h3_cells_to_bbox, load_h3_cells_from_file,
    parse_bbox_string
)
from preprocessing import ImagePreprocessor
from matching import match_lightglue, visualize_matches, create_mask, LIGHTGLUE_AVAILABLE
from transformations import (
    load_matches, remove_gross_outliers, compute_2d_shift, compute_homography
)
from register_orthomosaic import OrthomosaicRegistration

for config_label, config_value in (
    ("Working directory", notebook_dir),
    ("Repository root", repo_root),
    ("Default scales", DEFAULT_SCALES),
    ("Default matcher", DEFAULT_MATCHER),
):
    # Echo the resolved paths and pipeline defaults so the run context is visible
    print(f"{config_label}: {config_value}")


## Step 1: Parse H3 Cells and Download Basemap


In [None]:
# Parse H3 cells from XML file
def parse_h3_cells_from_xml(xml_path: Path) -> list:
    """Extract unique H3 cell IDs from a Word-style XML document.

    Every element whose stripped text is exactly 15 hexadecimal characters
    is treated as an H3 cell ID (the usual length of an H3 index rendered
    as a hex string).

    Args:
        xml_path: Path to the XML file to scan.

    Returns:
        Lower-cased H3 cell ID strings, de-duplicated and kept in order of
        first appearance in the document.
    """
    hex_digits = set('0123456789abcdef')
    tree = ET.parse(xml_path)

    # dict preserves insertion order; list(set(...)) would reorder the cells
    # between kernel sessions because str hashing is randomized.
    unique_cells = {}
    for elem in tree.getroot().iter():
        if not elem.text:
            continue
        # Normalize case so 'ABC...' and 'abc...' collapse to one cell.
        text = elem.text.strip().lower()
        if len(text) == 15 and set(text) <= hex_digits:
            unique_cells[text] = None

    return list(unique_cells)

# Load H3 cells from the project inputs and preview the first few IDs
h3_xml_path = repo_root / "inputs" / "qualicum_beach" / "h3_cells.xml"
h3_cells = parse_h3_cells_from_xml(h3_xml_path)

print(f"Found {len(h3_cells)} unique H3 cells:")
for position, cell_id in enumerate(h3_cells[:5], start=1):
    print(f"  {position}. {cell_id}")

hidden_count = len(h3_cells) - 5
if hidden_count > 0:
    print(f"  ... and {hidden_count} more")


In [None]:
# Convert H3 cells to bounding box.
# The except branch assumes an ImportError here means the optional `h3`
# package is missing; it installs it into this kernel's interpreter and
# retries once.
try:
    bbox = h3_cells_to_bbox(h3_cells)
except ImportError as e:
    print(f"Error: {e}")
    print("Installing h3 package...")
    import importlib
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "h3"])
    # The import system caches directory listings; a package installed after
    # the interpreter started may not be found without invalidating them.
    importlib.invalidate_caches()
    import basemap_downloader
    importlib.reload(basemap_downloader)
    from basemap_downloader import h3_cells_to_bbox
    bbox = h3_cells_to_bbox(h3_cells)

# Summarize the box on both the direct path and the install-and-retry path
min_lat, min_lon, max_lat, max_lon = bbox
print(f"Bounding box: ({min_lat:.6f}, {min_lon:.6f}, {max_lat:.6f}, {max_lon:.6f})")
print(f"  Latitude range: {max_lat - min_lat:.6f} degrees")
print(f"  Longitude range: {max_lon - min_lon:.6f} degrees")


In [None]:
# Download the ESRI basemap once; later runs reuse the cached GeoTIFF
output_dir = repo_root / "outputs" / "notebook_demo"
output_dir.mkdir(parents=True, exist_ok=True)

basemap_path = output_dir / "downloaded_basemap_esri.tif"

if basemap_path.exists():
    print(f"✓ Basemap already exists: {basemap_path}")
    downloaded_path = str(basemap_path)
else:
    print("Downloading basemap from ESRI World Imagery...")
    downloaded_path = download_basemap(
        bbox=bbox,
        output_path=str(basemap_path),
        source="esri",
        target_resolution=0.5,  # 0.5 meters per pixel
    )
    print(f"✓ Basemap downloaded to: {downloaded_path}")


In [None]:
# Render the downloaded basemap, then summarize its size and georeferencing
with rasterio.open(downloaded_path) as basemap_src:
    fig, ax = plt.subplots(1, 1, figsize=(12, 10))
    show(basemap_src, ax=ax, title="Downloaded Basemap (ESRI World Imagery)")
    plt.tight_layout()
    plt.show()

    summary_lines = [
        "Basemap info:",
        f"  Size: {basemap_src.width} x {basemap_src.height} pixels",
        f"  CRS: {basemap_src.crs}",
        f"  Bounds: {basemap_src.bounds}",
    ]
    print("\n".join(summary_lines))


## Step 2: Load Source Orthomosaic and Initialize Preprocessor


In [None]:
# Load source orthomosaic
source_path = repo_root / "inputs" / "qualicum_beach" / "orthomosaic_no_gcps.tif"

print(f"Source orthomosaic: {source_path}")
print(f"  Exists: {source_path.exists()}")
# Fail fast with a clear message instead of an opaque error deeper in the
# preprocessor when the input is missing.
if not source_path.exists():
    raise FileNotFoundError(f"Source orthomosaic not found: {source_path}")

# Initialize preprocessor: source is the orthomosaic to register, target is
# the downloaded basemap it will be aligned to.
preprocessor = ImagePreprocessor(
    source_path=str(source_path),
    target_path=downloaded_path,
    output_dir=output_dir
)

# Log metadata
preprocessor.log_metadata()


In [None]:
# Render the un-registered source orthomosaic, then summarize its georeferencing
with rasterio.open(source_path) as ortho_src:
    fig, ax = plt.subplots(1, 1, figsize=(12, 10))
    show(ortho_src, ax=ax, title="Source Orthomosaic")
    plt.tight_layout()
    plt.show()

    summary_lines = [
        "Source orthomosaic info:",
        f"  Size: {ortho_src.width} x {ortho_src.height} pixels",
        f"  CRS: {ortho_src.crs}",
        f"  Bounds: {ortho_src.bounds}",
    ]
    print("\n".join(summary_lines))


## Step 3: Preprocessing - Create Resolution Pyramid


In [None]:
# Walk the scale pyramid (defaults from the defaults module) and report, per
# scale, the downsampled image shapes and the shapes after cropping to overlap
scales = DEFAULT_SCALES.copy()
print(f"Processing scales: {scales}")

for scale in scales:
    print(f"\n--- Scale {scale:.3f} ---")

    source_img, target_img = preprocessor.load_downsampled(scale)
    for shape_label, shape_img in (("Source", source_img), ("Target", target_img)):
        print(f"  {shape_label} shape: {shape_img.shape}")

    overlap_info = preprocessor.compute_overlap_region(scale)
    if not overlap_info:
        print(f"  ⚠ No overlap found at scale {scale}")
        continue

    print(f"  Overlap region: {overlap_info['source']}")
    source_overlap, target_overlap = preprocessor.crop_to_overlap(
        source_img, target_img, overlap_info
    )
    for shape_label, shape_img in (
        ("Source overlap", source_overlap),
        ("Target overlap", target_overlap),
    ):
        print(f"  {shape_label} shape: {shape_img.shape}")


In [None]:
# Visualize overlap regions at different scales: one column per scale,
# source crop on the top row, target crop on the bottom row.
# squeeze=False keeps `axes` 2-D even when there is only a single scale.
fig, axes = plt.subplots(2, len(scales), figsize=(20, 10), squeeze=False)

for idx, scale in enumerate(scales):
    source_img, target_img = preprocessor.load_downsampled(scale)
    overlap_info = preprocessor.compute_overlap_region(scale)
    source_ax, target_ax = axes[0, idx], axes[1, idx]

    if overlap_info:
        source_overlap, target_overlap = preprocessor.crop_to_overlap(
            source_img, target_img, overlap_info
        )
        source_ax.imshow(source_overlap, cmap='gray')
        target_ax.imshow(target_overlap, cmap='gray')
    else:
        # Label empty panels instead of leaving blank axes with ticks.
        for panel_ax in (source_ax, target_ax):
            panel_ax.text(0.5, 0.5, "No overlap", ha='center', va='center',
                          transform=panel_ax.transAxes)

    source_ax.set_title(f"Source Overlap\nScale {scale:.3f}")
    target_ax.set_title(f"Target Overlap\nScale {scale:.3f}")
    source_ax.axis('off')
    target_ax.axis('off')

plt.tight_layout()
plt.show()


## Step 4: Feature Matching at Multiple Scales


In [None]:
# Report which matcher the matching loop below will use
if LIGHTGLUE_AVAILABLE:
    print("✓ LightGlue is available")
else:
    print("⚠ LightGlue not available. Install with: pip install lightglue")
    print("Falling back to SIFT matching...")


In [None]:
# Match features at each scale
# Assumed native ground sample distance (meters/pixel) at scale 1.0 — the
# value the original pipeline hard-coded. TODO confirm / derive from the raster.
SOURCE_GSD_METERS = 0.02

matches_by_scale = {}
match_visualizations = {}

if not LIGHTGLUE_AVAILABLE:
    # Fallback matcher; imported once instead of on every loop iteration.
    from matching import match_sift

for scale in scales:
    print(f"\n{'='*60}")
    print(f"Matching at scale {scale:.3f}")
    print(f"{'='*60}")

    # Load downsampled images
    source_img, target_img = preprocessor.load_downsampled(scale)

    # Compute overlap
    overlap_info = preprocessor.compute_overlap_region(scale)
    if not overlap_info:
        print(f"  ⚠ No overlap at scale {scale}, skipping...")
        continue

    # Crop to overlap
    source_overlap, target_overlap = preprocessor.crop_to_overlap(
        source_img, target_img, overlap_info
    )

    # Create masks
    source_mask = create_mask(source_overlap)
    target_mask = create_mask(target_overlap)

    # Effective ground resolution at this scale (coarser as scale shrinks)
    pixel_resolution = SOURCE_GSD_METERS / scale  # meters per pixel
    print(f"  Pixel resolution: {pixel_resolution:.4f} m/pixel")

    # Compute matches with LightGlue when available, otherwise SIFT
    if LIGHTGLUE_AVAILABLE:
        matches_result = match_lightglue(
            source_overlap, target_overlap, source_mask, target_mask,
            use_tiles=True,
            pixel_resolution_meters=pixel_resolution
        )
    else:
        matches_result = match_sift(source_overlap, target_overlap, source_mask, target_mask)

    if matches_result and 'matches' in matches_result and len(matches_result['matches']) > 0:
        num_matches = len(matches_result['matches'])
        print(f"  ✓ Found {num_matches} matches")
        matches_by_scale[scale] = matches_result

        # Visualize matches
        viz_path = output_dir / f"matches_scale{scale:.3f}.png"
        visualize_matches(
            source_overlap, target_overlap, matches_result, viz_path,
            source_name=f"source_scale{scale:.3f}",
            target_name=f"target_scale{scale:.3f}",
            skip_json=True
        )
        match_visualizations[scale] = viz_path
        print(f"  ✓ Saved visualization: {viz_path.name}")
    else:
        print(f"  ✗ No matches found at scale {scale}")


In [None]:
# Display match visualizations saved by the matching loop (one panel per scale).
# Guard the empty case: plt.subplots(1, 0) raises when no scale produced matches.
if not match_visualizations:
    print("No match visualizations to display")
else:
    # Size the grid by the dict we iterate; squeeze=False keeps axes 2-D.
    fig, axes = plt.subplots(1, len(match_visualizations), figsize=(20, 5), squeeze=False)

    for panel_ax, (scale, viz_path) in zip(axes[0], match_visualizations.items()):
        panel_ax.imshow(plt.imread(viz_path))
        num_matches = len(matches_by_scale[scale]['matches'])
        panel_ax.set_title(f"Matches at Scale {scale:.3f}\n({num_matches} matches)")
        panel_ax.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
# Run the complete registration pipeline
print("Running full hierarchical registration pipeline...")
print(f"Source: {source_path}")
print(f"Target: {downloaded_path}")
print(f"Output: {output_dir}")
print(f"Scales: {scales}")

# Pair each scale actually used with its transform algorithm. Keying by
# `scales` (not DEFAULT_SCALES) keeps the mapping in sync if scales is edited.
transform_types = dict(zip(scales, DEFAULT_ALGORITHMS))
# zip silently drops scales beyond the end of DEFAULT_ALGORITHMS; surface that.
unassigned_scales = [s for s in scales if s not in transform_types]
if unassigned_scales:
    print(f"⚠ No transform type configured for scales: {unassigned_scales}")

# Create registration instance (using defaults from defaults module)
registration = OrthomosaicRegistration(
    source_path=str(source_path),
    target_path=downloaded_path,
    output_dir=str(output_dir),
    scales=scales,
    matcher=DEFAULT_MATCHER,
    transform_types=transform_types,
    debug_level='high'  # Get all intermediate files
)

# Run registration
result = registration.register()

if result:
    print(f"\n✓ Registration complete!")
    print(f"Final output: {result}")
else:
    print("\n✗ Registration failed")


## Step 6: Visualize Results


In [None]:
# Load and visualize final registered orthomosaic next to the original source.
# Wrap in Path so this works whether register() returned a str or a Path.
result_path = Path(result) if result else None
if result_path is not None and result_path.exists():
    with rasterio.open(result_path) as src:
        fig, axes = plt.subplots(1, 2, figsize=(20, 10))

        # Original source
        with rasterio.open(source_path) as orig:
            show(orig, ax=axes[0], title="Original Source Orthomosaic")

        # Registered result
        show(src, ax=axes[1], title="Registered Orthomosaic")

        plt.tight_layout()
        plt.show()

        print(f"Registered orthomosaic info:")
        print(f"  Size: {src.width} x {src.height} pixels")
        print(f"  CRS: {src.crs}")
        print(f"  Bounds: {src.bounds}")
else:
    print("Final registered orthomosaic not found")


In [None]:
# Display error histograms written by the registration run, if any
matching_dir = output_dir / "matching_and_transformations"
histograms = (
    sorted(matching_dir.glob("error_histogram_scale*.png"))
    if matching_dir.exists()
    else []
)

if histograms:
    # squeeze=False keeps axes 2-D even for a single histogram
    fig, axes = plt.subplots(1, len(histograms), figsize=(20, 5), squeeze=False)

    for panel_ax, hist_path in zip(axes[0], histograms):
        panel_ax.imshow(plt.imread(hist_path))
        panel_ax.set_title(hist_path.stem)
        panel_ax.axis('off')

    plt.tight_layout()
    plt.show()
else:
    print(f"No error histograms found in {matching_dir}")
