In [1]:
import numpy as np
import time
import rasterio
from rasterio.windows import from_bounds
from rasterio.warp import reproject, Resampling
from skimage import exposure
from skimage.feature import SIFT, match_descriptors
from skimage.morphology import binary_erosion, footprint_rectangle

In [2]:
def _prepare_dem(img, nodata=-9999.0, min_elev=-5.0):
    """
    Returns:
      proc: float32 image in [0,1] for SIFT
      mask: True where invalid (nodata / nan / outlier)
    """
    img = img.astype(np.float32)

    mask = (img == nodata) | np.isnan(img) | (img < min_elev)
    if np.all(mask):
        raise ValueError("All pixels are masked (NoData/outliers) in the selected overlap.")

    valid = img[~mask]
    valid_mean = np.nanmean(valid)

    img_filled = img.copy()
    img_filled[mask] = valid_mean

    # Robust contrast stretch based on valid pixels only
    p2, p98 = np.percentile(valid, (2, 98))
    proc = exposure.rescale_intensity(
        img_filled, in_range=(p2, p98), out_range=(0.0, 1.0)
    ).astype(np.float32)

    return proc, mask


def _robust_translation_from_matches(src_xy, dst_xy, z=3.5):
    """
    Robust translation estimate using median + MAD filtering.

    src_xy, dst_xy: (N,2) arrays in (x,y) coords (col,row)
    Returns:
      tx_pix, ty_pix, keep_ratio
    """
    d = dst_xy - src_xy  # (N,2)
    med = np.median(d, axis=0)

    mad = np.median(np.abs(d - med), axis=0) + 1e-6
    scale = 1.4826 * mad

    keep = (np.abs(d - med) <= z * scale).all(axis=1)
    keep_ratio = float(np.mean(keep))

    if np.sum(keep) >= 3:
        d2 = d[keep]
        med2 = np.median(d2, axis=0)
        return float(med2[0]), float(med2[1]), keep_ratio

    # If too few kept, fall back to raw median
    return float(med[0]), float(med[1]), keep_ratio


def calculate_drift_sift_translation_only(
    path_ref,
    path_mov,
    nodata=-9999.0,
    min_elev=-5.0,
    max_ratio=0.8,
    erosion_iters=10,
    min_matches=8,
    debug=True,
):
    """
    Computes bulk drift between two DEMs (same CRS, e.g., EPSG:3413) using:
      1) overlap extraction
      2) resample mov DEM onto ref overlap grid
      3) SIFT matching
      4) robust translation (median + MAD), translation-only

    Convention:
      We estimate the shift that maps REF -> MOV:
        (x_mov, y_mov) â‰ˆ (x_ref + dE, y_ref + dN)

    Returns:
      tx_pix, ty_pix, dE_m, dN_m, match_count, keep_ratio

    Where:
      tx_pix, ty_pix are in pixel units on the ref overlap grid (x=col, y=row)
      dE_m, dN_m are in meters in the raster CRS (east, north)
    """
    with rasterio.open(path_ref) as src1, rasterio.open(path_mov) as src2:
        # Same-CRS assumption (you said EPSG:3413 for both)
        if src1.crs != src2.crs:
            raise ValueError(f"CRS mismatch: ref={src1.crs}, mov={src2.crs}")

        # Compute overlap bounds in map coordinates
        b1, b2 = src1.bounds, src2.bounds
        left   = max(b1.left,   b2.left)
        right  = min(b1.right,  b2.right)
        bottom = max(b1.bottom, b2.bottom)
        top    = min(b1.top,    b2.top)

        if (left >= right) or (bottom >= top):
            raise ValueError("No spatial overlap between the two DEMs.")

        # Read REF overlap (defines target grid)
        win1 = from_bounds(left, bottom, right, top, transform=src1.transform)
        win1 = win1.round_offsets().round_lengths()
        dem1 = src1.read(1, window=win1, boundless=False)
        trans1 = src1.window_transform(win1)

        # Read MOV overlap on its own grid
        win2 = from_bounds(left, bottom, right, top, transform=src2.transform)
        win2 = win2.round_offsets().round_lengths()
        dem2_raw = src2.read(1, window=win2, boundless=False)
        trans2 = src2.window_transform(win2)

        # Resample MOV overlap -> REF overlap grid
        dem2_on_1 = np.empty_like(dem1, dtype=np.float32)
        reproject(
            source=dem2_raw,
            destination=dem2_on_1,
            src_transform=trans2,
            src_crs=src2.crs,
            dst_transform=trans1,
            dst_crs=src1.crs,
            resampling=Resampling.bilinear,
            src_nodata=nodata,
            dst_nodata=nodata,
        )

    # Prepare + masks
    proc1, mask1 = _prepare_dem(dem1, nodata=nodata, min_elev=min_elev)
    proc2, mask2 = _prepare_dem(dem2_on_1, nodata=nodata, min_elev=min_elev)

    # Hard-mask invalid regions to suppress edge artifacts from filling
    proc1[mask1] = 0.0
    proc2[mask2] = 0.0

    # Optionally erode valid region in REF to avoid NoData boundaries
    valid1 = ~mask1
    if erosion_iters and erosion_iters > 0:
        # binary_erosion with a 3x3 footprint repeated erosion_iters times
        for _ in range(erosion_iters):
            valid1 = binary_erosion(valid1, footprint_rectangle((3, 3)))
        if debug:
            print(f"[DEBUG] erosion_iters={erosion_iters}, valid_ref% after erosion: {100*np.mean(valid1):.2f}%")

    if debug:
        h, w = proc1.shape
        print(f"[DEBUG] overlap size: (h={h}, w={w})")
        print(f"[DEBUG] nodata% ref: {100*np.mean(mask1):.2f}%, mov(resampled): {100*np.mean(mask2):.2f}%")
        print(f"[DEBUG] transform.a (xres): {trans1.a}, transform.e (yres): {trans1.e}")

    # SIFT
    sift1, sift2 = SIFT(), SIFT()
    sift1.detect_and_extract(proc1)
    sift2.detect_and_extract(proc2)

    if sift1.descriptors is None or sift2.descriptors is None:
        raise ValueError("SIFT failed: descriptors are None (overlap too flat / too masked / too small).")

    if debug:
        print(f"[DEBUG] keypoints: ref={len(sift1.keypoints)}, mov={len(sift2.keypoints)}")

    # Match with ratio + cross-check
    matches = match_descriptors(
        sift1.descriptors,
        sift2.descriptors,
        max_ratio=max_ratio,
        cross_check=True
    )

    if debug:
        print(f"[DEBUG] matches after ratio+crosscheck: {matches.shape[0]} (max_ratio={max_ratio})")

    if matches.shape[0] < 4:
        raise ValueError(f"Insufficient matches: {matches.shape[0]} (need >= 4).")

    # Coordinates: SIFT gives (row,col). Convert to (x,y)=(col,row)
    src_rc = sift1.keypoints[matches[:, 0]]
    dst_rc = sift2.keypoints[matches[:, 1]]
    src_xy = src_rc[:, ::-1]
    dst_xy = dst_rc[:, ::-1]

    # Filter matches to those whose REF keypoints are in the eroded valid area
    if erosion_iters and erosion_iters > 0:
        cols = np.round(src_xy[:, 0]).astype(int)
        rows = np.round(src_xy[:, 1]).astype(int)
        in_bounds = (rows >= 0) & (rows < valid1.shape[0]) & (cols >= 0) & (cols < valid1.shape[1])
        keep = np.zeros(len(rows), dtype=bool)
        keep[in_bounds] = valid1[rows[in_bounds], cols[in_bounds]]
        src_xy = src_xy[keep]
        dst_xy = dst_xy[keep]
        if debug:
            print(f"[DEBUG] matches after valid-area filter: {src_xy.shape[0]}")

    if src_xy.shape[0] < min_matches:
        raise ValueError(
            f"Too few matches after filtering: {src_xy.shape[0]} (min_matches={min_matches}). "
            f"Try reducing erosion_iters, increasing overlap, or switching to hillshade/gradient preprocessing."
        )

    # Translation-only robust estimate
    tx_pix, ty_pix, keep_ratio = _robust_translation_from_matches(src_xy, dst_xy, z=3.5)

    # Pixel -> map meters (EPSG:3413 is meters)
    # trans1.a is +x pixel size; trans1.e is typically negative for north-up rasters
    dE_m = tx_pix * trans1.a
    dN_m = ty_pix * trans1.e  # keep sign; positive means north, negative means south (usually)

    return tx_pix, ty_pix, dE_m, dN_m, int(src_xy.shape[0]), keep_ratio

In [3]:
img1 = 'C:/Users/yhqian/Downloads/QGIS_maps/PointClouds_2_Raster/ALS_L1B_20190410T164528_165720_4/ALS_L1B_20190410T164528_165720_4_1m_concave.tif'
img2 = 'C:/Users/yhqian/Downloads/QGIS_maps/PointClouds_2_Raster/ALS_L1B_20190410T174554_181213_3/ALS_L1B_20190410T174554_181213_3_1m_concave.tif'

In [4]:
# Example usage
t0 = time.time()

tx_pix, ty_pix, dE_m, dN_m, n_matches, keep_ratio = calculate_drift_sift_translation_only(
    img1, img2,
    nodata=-9999.0,
    min_elev=-50.0,
    max_ratio=0.8,
    erosion_iters=10,
    min_matches=8,
    debug=True
)
print("\n=== Drift (REF -> MOV) ===")
print("tx_pix, ty_pix:", tx_pix, ty_pix)
print("dE_m, dN_m (m):", dE_m, dN_m)
print("matches used:", n_matches)
print("keep_ratio (MAD):", keep_ratio)

t1 = time.time()
print("Time cost:", t1 - t0)

[DEBUG] erosion_iters=10, valid_ref% after erosion: 13.31%
[DEBUG] overlap size: (h=1007, w=3576)
[DEBUG] nodata% ref: 84.20%, mov(resampled): 46.75%
[DEBUG] transform.a (xres): 0.9999711068909402, transform.e (yres): -1.0001339622641554
[DEBUG] keypoints: ref=2509, mov=8899
[DEBUG] matches after ratio+crosscheck: 1252 (max_ratio=0.8)
[DEBUG] matches after valid-area filter: 1190

=== Drift (REF -> MOV) ===
tx_pix, ty_pix: -73.0 -13.0
dE_m, dN_m (m): -72.99789080303863 13.00174150943402
matches used: 1190
keep_ratio (MAD): 0.9907563025210084
Time cost: 8.718449354171753


In [5]:
1765.138731529644/12278.0

0.14376435343945626

In [6]:
203.48486794871602/12278.0

0.01657312819259782

In [5]:
-72.99789080303863/12278.0

-0.005945421958221097

In [6]:
13.00174150943402/12278.0

0.0010589462053619497