# DEA Intertidal Elevation validation

This notebook calculates validation statistics for DEA Intertidal Elevation by comparing elevation values against external validation LiDAR and multibeam datasets.

<div class="alert alert-info">

**Note:** This is an experimental notebook containing preliminary validation results. These results will be updated upon publication of the DEA Intertidal Elevation scientific paper.

</div>

In [None]:
cd ../..

In [None]:
pip install -r dev-requirements.in --quiet

In [None]:
%load_ext autoreload
%autoreload 2

import warnings

import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from rasterio.errors import RasterioIOError
from tqdm import tqdm

warnings.filterwarnings("ignore")

import datacube
import odc.geo.xr
from dea_tools.dask import create_local_dask_cluster
from dea_tools.datahandling import load_reproject
from dea_tools.spatial import xr_interpolate
from dea_tools.validation import eval_metrics
from odc.geo.geom import Geometry

from intertidal.validation import preprocess_validation


def add_substrate(gdf_lines, gdf_poly):
    """
    Attach a simplified substrate class to each polygon.

    Line features are spatially joined to the polygons, the most
    frequent `INTERTD1_V` value among the lines intersecting each
    polygon is selected, and that value is collapsed into one of the
    broad classes "Sandy beach", "Tidal flats", "Rocky" or "Other".

    Parameters
    ----------
    gdf_lines : geopandas.GeoDataFrame
        Line features containing an `INTERTD1_V` column. Reprojected
        to the CRS of `gdf_poly` before joining.
    gdf_poly : geopandas.GeoDataFrame
        Polygons to annotate.

    Returns
    -------
    geopandas.GeoDataFrame
        `gdf_poly` joined (on its index) with a new "substrate" column.
        Polygons that intersect no lines, or whose most common value is
        not in the lookup, receive a missing value.
    """
    # Detailed `INTERTD1_V` values grouped by the broad class they map to
    classes = {
        "Sandy beach": [
            "Sandy beach undiff",
            "Mixed sandy shore undiff",
            "Sandy shore undiff",
            "Fine-medium sand beach",
            "Mixed sand and shell beach",
            "Coarse sand beach",
        ],
        "Tidal flats": [
            "Mixed sand tidal flats undiff",
            "Sandy tidal flats",
            "Tidal sediment flats (inferred from mangroves)",
            "Tidal flats (sediment undiff)",
            "Muddy shore undiff",
            "Muddy tidal flats",
            "Sandy-mud tidal flats",
        ],
        "Rocky": [
            "Flat boulder deposit (rock) undiff",
            "Pebble/cobble (rock shingle) beach",
            "Sloping hard rock shore",
            "Hard rocky shore platform",
            "Hard bedrock shore",
            "Rocky shore platform (undiff)",
            "Sloping rocky shore (undiff)",
            "Boulder seawall",
            "Boulder or shingle-grade shore undiff",
            "Soft `bedrock¿ shore platform",
            "Boulder (rock) beach",
            "Sandy beach with cobbles/pebbles (rock)",
            "Boulder/cobble (rock) beach",
            "Flat pebble/cobble deposit (rock) undiff",
            "Hard bedrock shore inferred",
        ],
        "Other": ["Unclassified"],
    }
    lookup = {detail: broad for broad, details in classes.items() for detail in details}

    def _most_common(values):
        """Return the first modal value of `values`, or None if there is none."""
        modes = values.mode()
        return None if modes.empty else modes.iloc[0]

    # Put the lines in the polygons' CRS, then find which lines intersect
    # which polygons (`index_right` holds the index of the matched polygon)
    lines_reprojected = gdf_lines.to_crs(gdf_poly.crs)
    matches = gpd.sjoin(lines_reprojected, gdf_poly, how="inner", predicate="intersects")

    # Most common detailed value per polygon, collapsed to its broad class
    dominant = matches.groupby("index_right")["INTERTD1_V"].agg(_most_common)
    substrate = dominant.rename("substrate").map(lookup)

    # Join back to original polygons
    return gdf_poly.join(substrate)


# Connect to the datacube; `dc` is used below to find and load the
# intertidal datasets for each validation site
dc = datacube.Datacube()

# Start a local Dask cluster and keep the returned client
# (presumably a `dask.distributed.Client` - confirm against dea_tools)
client = create_local_dask_cluster(return_client=True)

## Initial setup

In [None]:
# Input data locations. The absolute paths are hard-coded for the environment
# this notebook was run in; edit them here to run the analysis elsewhere.
DATUM_PATH = "/gdata1/data/tide_datums/AHD_to_MSL.gpkg"
VALIDATION_SITES_PATH = "data/raw/validation_sites.geojson"
VALIDATION_DEM_DIR = "/gdata1/projects/coastal/intertidal/Elevation_data/Processed/"
SMARTLINE_PATH = "/gdata1/data/smartline/Smartline.gpkg"

# Resampling method and pixel size used when loading modelled and validation
# data onto each site's grid
resampling = "average"
resolution = 10

# Read datum transformation geopackage
datum_gdf = gpd.read_file(DATUM_PATH, engine="pyogrio", use_arrow=True)

# Load validation site polygons, and build the path to each site's validation
# raster from its `year` value
validation_sites_gdf = gpd.read_file(VALIDATION_SITES_PATH)
validation_sites_gdf["val_path"] = VALIDATION_DEM_DIR + validation_sites_gdf.year + "_combined.tif"

# Load smartline and add each site's substrate class as a `substrate` column
smartline_gdf = gpd.read_file(SMARTLINE_PATH, engine="pyogrio", use_arrow=True)
validation_sites_gdf = add_substrate(smartline_gdf, validation_sites_gdf)

## Run validation analysis for each polygon

In [None]:
outputs = []

# Sites that produced no output, stored as (site index, reason), so that
# skipped polygons are reported rather than silently dropped
skipped_sites = []

for i, row in tqdm(validation_sites_gdf.iterrows(), total=len(validation_sites_gdf.index)):
    try:
        # Convert to Geometry and create GeoBox to load data into
        poly = Geometry(row.geometry, crs="EPSG:4326")
        poly_geobox = odc.geo.geobox.GeoBox.from_geopolygon(poly, crs="EPSG:3577", resolution=resolution)

        # Find the intertidal datasets overlapping this site for the site's year
        dss = dc.find_datasets(
            product="ga_s2ls_intertidal_cyear_3",
            time=row.year,
            like=poly_geobox,
        )
        if not dss:
            skipped_sites.append((i, "no intertidal datasets found"))
            continue

        # Obtain tide metadata from intertidal datasets, and summarise
        # if multiple datasets are returned per polygon site (largest
        # tr/otr/hat, smallest lat)
        tr = np.max([ds.metadata.search_fields["intertidal_tr"] for ds in dss])
        otr = np.max([ds.metadata.search_fields["intertidal_otr"] for ds in dss])
        lat = np.min([ds.metadata.search_fields["intertidal_lat"] for ds in dss])
        hat = np.max([ds.metadata.search_fields["intertidal_hat"] for ds in dss])

        # Load elevation and uncertainty data from datacube
        modelled_ds = (
            dc.load(
                datasets=dss,
                like=poly_geobox,
                measurements=["elevation", "uncertainty"],
                dask_chunks={"x": 2048, "y": 2048},
                resampling=resampling,
                driver="rio",
            )
            .squeeze("time")
            .compute()
        )
        modelled_da = modelled_ds.elevation
        uncertainty_da = modelled_ds.uncertainty

        # Disabled alternative input (loads the "nidem" product in place of
        # the elevation/uncertainty layers above):
        # modelled_ds = dc.load(
        #             product="nidem",
        #             like=poly_geobox,
        #             dask_chunks={"x": 2048, "y": 2048},
        #             resampling=resampling,
        #         ).squeeze("time").compute()
        # modelled_da = modelled_ds.nidem
        # uncertainty_da = modelled_ds.nidem

        # Skip polygon if no modelled data available
        if not modelled_da.notnull().any().item():
            skipped_sites.append((i, "no valid modelled elevation pixels"))
            continue

        # Mask our data by our input polygon
        modelled_da = modelled_da.odc.mask(poly=poly)

        # Load validation data into polygon GeoBox
        validation_da = load_reproject(path=row.val_path, how=poly_geobox, resampling=resampling).compute()

        # Preprocess AHD data
        validation_m_ahd, modelled_m, uncertainty_m = preprocess_validation(
            validation_da,
            modelled_da,
            uncertainty_da,
            lat=lat,
            hat=hat,
        )

        # Interpolate AHD to MSL correction and apply to data; a single
        # site-wide mean correction is added to every validation value
        ahd_to_msl = xr_interpolate(ds=validation_da, gdf=datum_gdf, columns=["ahd_to_msl"]).ahd_to_msl
        validation_m_msl = validation_m_ahd + ahd_to_msl.mean().item()

        # One row per compared value; the scalar site attributes are
        # broadcast across all rows of the site
        output_df = pd.DataFrame({
            "i": i,
            "year": row.year,
            "tr": tr,
            "otr": otr,
            "lat": lat,
            "hat": hat,
            "validation_m_ahd": validation_m_ahd,
            "validation_m_msl": validation_m_msl,
            "modelled_m": modelled_m,
            "uncertainty_m": uncertainty_m,
            "substrate": row.substrate,
        })
        outputs.append(output_df)

    except (KeyError, RasterioIOError, IndexError, AssertionError) as err:
        # Best-effort: keep processing the remaining sites, but record why
        # this one was dropped
        skipped_sites.append((i, f"{type(err).__name__}: {err}"))

# Report skipped sites (printed rather than warned, because warnings are
# globally suppressed in this notebook)
if skipped_sites:
    print(f"Skipped {len(skipped_sites)} of {len(validation_sites_gdf.index)} sites:")
    for skipped_i, reason in skipped_sites:
        print(f"  site {skipped_i}: {reason}")

if not outputs:
    raise ValueError("No validation outputs were produced: every site was skipped")

# Combine and add additional columns
outputs_all_df = pd.concat(outputs)

# Uncertainty expressed as a fraction of the site's tide range
outputs_all_df["uncertainty_perc"] = outputs_all_df.uncertainty_m / outputs_all_df.tr

# Tide range classes: (0, 2] microtidal, (2, 4] mesotidal, (4, inf) macrotidal
outputs_all_df["category"] = pd.cut(
    outputs_all_df.tr,
    bins=(0, 2, 4, np.inf),
    labels=["microtidal", "mesotidal", "macrotidal"],
)

## Results

### Overall

In [None]:
# Plot and compare - heatmap
plt.figure(figsize=(5, 5))
lim_min, lim_max = np.nanpercentile(
    np.concatenate([outputs_all_df.validation_m_ahd, outputs_all_df.modelled_m]), [1, 99]
)
lim_min -= 0.5
lim_max += 0.5
plt.hexbin(
    x=outputs_all_df.validation_m_ahd,
    y=outputs_all_df.modelled_m,
    extent=(lim_min, lim_max, lim_min, lim_max),
    cmap="inferno",
)
plt.gca().set_facecolor("#0C0C0C")
plt.plot([lim_min, lim_max], [lim_min, lim_max], "--", c="white")
plt.margins(x=0, y=0)
plt.xlabel("LiDAR DEM (m AHD)")
plt.ylabel("Intertidal DEM (m MSL)")

# Accuracy statistics
print(eval_metrics(x=outputs_all_df.validation_m_ahd, y=outputs_all_df.modelled_m, round=3))
print(f"n                   {len(outputs_all_df.modelled_m)}")
print(f"area                {len(outputs_all_df.modelled_m) * (10 * 10) * 0.000001:.2f} km sq")

In [None]:
plt.gcf().savefig("DEAIntertidal_validation_all_ahd.png", dpi=200, bbox_inches="tight")

In [None]:
# Plot and compare - heatmap
plt.figure(figsize=(5, 5))
lim_min, lim_max = np.nanpercentile(
    np.concatenate([outputs_all_df.validation_m_msl, outputs_all_df.modelled_m]), [1, 99]
)
lim_min -= 0.5
lim_max += 0.5
plt.hexbin(
    x=outputs_all_df.validation_m_msl,
    y=outputs_all_df.modelled_m,
    extent=(lim_min, lim_max, lim_min, lim_max),
    cmap="inferno",
)
plt.gca().set_facecolor("#0C0C0C")
plt.plot([lim_min, lim_max], [lim_min, lim_max], "--", c="white")
plt.margins(x=0, y=0)
plt.xlabel("LiDAR DEM (m MSL)")
plt.ylabel("Intertidal DEM (m MSL)")

# Accuracy statistics
print(eval_metrics(x=outputs_all_df.validation_m_msl, y=outputs_all_df.modelled_m, round=3))
print(f"n                   {len(outputs_all_df.modelled_m)}")
print(f"area                {len(outputs_all_df.modelled_m) * (10 * 10) * 0.000001:.2f} km sq")

In [None]:
plt.savefig("DEAIntertidal_validation_all_msl.png", dpi=200, bbox_inches="tight")

#### Assorted plots

In [None]:
# Compute accuracy statistics (validation MSL vs modelled) separately for
# each year, then plot the result
yearly_stats = outputs_all_df.groupby("year").apply(
    lambda year_df: eval_metrics(x=year_df.validation_m_msl, y=year_df.modelled_m, round=3)
)
yearly_stats.plot()

In [None]:
# Count non-null validation values per year and tide range category, then
# plot them as stacked bars with one bar per year
counts_by_year_category = outputs_all_df.groupby(["year", "category"]).validation_m_msl.count()
counts_by_year_category.unstack(level=-1).plot.bar(stacked=True)

### By certainty

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 5.5))

# Uncertainty ranges to compare (lower bound inclusive, upper bound
# exclusive), with the axis limits and title used for each panel
cert_ranges = [(0, 0.3), (0.3, 10)]
scale_limits = [(-2.0, 1.0), (-2.0, 1.0)]
titles = ["Uncertainty (< ±0.30 m)", "Uncertainty (> ±0.30 m)"]

out = {}
for i, (min_thresh, max_thresh) in enumerate(cert_ranges):
    # Subset by absolute uncertainty. Disabled alternative, kept for
    # reference, subsets on the `uncertainty_perc` column instead:
    # outputs_subset_df = outputs_all_df.query(
    #     "(uncertainty_perc >= @min_thresh) & (uncertainty_perc < @max_thresh)"
    # )
    outputs_subset_df = outputs_all_df.query("(uncertainty_m >= @min_thresh) & (uncertainty_m < @max_thresh)")

    # Fixed axis limits for this panel
    lim_min, lim_max = scale_limits[i]
    ax = axes[i]

    # Plot and compare - heatmap
    ax.hexbin(
        x=outputs_subset_df.validation_m_msl,
        y=outputs_subset_df.modelled_m,
        extent=(lim_min, lim_max, lim_min, lim_max),
        cmap="inferno",
        bins=100,
        vmin=0,
        vmax=50,
    )

    # 1:1 reference line
    ax.plot([lim_min, lim_max], [lim_min, lim_max], "--", c="white")
    ax.margins(x=0, y=0)
    ax.set_title(titles[i])
    if i == 0:
        ax.set_ylabel("Intertidal DEM (m MSL)")
    elif i == 1:
        ax.set_xlabel("LiDAR DEM (m MSL)")

    # Accuracy statistics, plus the number of rows in the subset
    stats = eval_metrics(x=outputs_subset_df.validation_m_msl, y=outputs_subset_df.modelled_m, round=3)
    stats["n"] = len(outputs_subset_df.modelled_m)
    out[titles[i]] = stats

# Summary table: one column per uncertainty range, one row per statistic
pd.DataFrame(out).reindex(
    index=["n", "Correlation", "R-squared", "RMSE", "MAE", "Bias"],
    columns=titles,
).round(2)

In [None]:
# Export the two-panel uncertainty figure held in `fig` from the previous cell
fig.savefig("DEAIntertidal_validation_uncertainty.png", dpi=200, bbox_inches="tight")

### By tide range class

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(12, 3.6))

# Tide range classes to compare, and the axis limits used for each panel
cat_ranges = ["microtidal", "mesotidal", "macrotidal"]
scale_dict = {"microtidal": (-0.9, 0.6), "mesotidal": (-2, 1.2), "macrotidal": (-2, 2.5)}

out = {}
for i, (ax, cat) in enumerate(zip(axes, cat_ranges)):
    lim_min, lim_max = scale_dict[cat]
    outputs_subset_df = outputs_all_df.query("category == @cat")

    # Plot and compare - heatmap
    ax.hexbin(
        x=outputs_subset_df.validation_m_ahd,
        y=outputs_subset_df.modelled_m,
        extent=(lim_min, lim_max, lim_min, lim_max),
        cmap="inferno",
        bins=100,
        vmin=0,
        vmax=50,
    )
    ax.set_facecolor("#0C0C0C")

    # 1:1 reference line
    ax.plot([lim_min, lim_max], [lim_min, lim_max], "--", c="white", linewidth=0.5)
    ax.margins(x=0, y=0)
    ax.set_title(cat)

    # Label the y axis on the first panel and the x axis on the middle panel
    if i == 0:
        ax.set_ylabel("DEA Intertidal Elevation (m MSL)")
    if i == 1:
        ax.set_xlabel("Validation DEM (m AHD)")

    # Accuracy statistics, plus the number of rows in the subset
    stats_df = eval_metrics(x=outputs_subset_df.validation_m_ahd, y=outputs_subset_df.modelled_m, round=3)
    stats_df["n"] = len(outputs_subset_df.modelled_m)
    out[cat] = stats_df


# Summary table: one column per tide range class, one row per statistic
pd.DataFrame(out).reindex(
    index=["n", "Correlation", "R-squared", "RMSE", "MAE", "Bias"],
    columns=cat_ranges,
).round(2)

In [None]:
# Export the three-panel tide range (AHD) figure held in `fig` from the previous cell
fig.savefig("DEAIntertidal_validation_micromesomacro_ahd.png", dpi=200, bbox_inches="tight")

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(12, 3.6))

# Tide range classes to compare, and the axis limits used for each panel
cat_ranges = ["microtidal", "mesotidal", "macrotidal"]

scale_dict = {"microtidal": (-0.9, 0.8), "mesotidal": (-2, 1.2), "macrotidal": (-2, 2.5)}

out = {}
for i, cat in enumerate(cat_ranges):
    # Subset to this class and treat -9999 in the modelled values as missing
    # (presumably a nodata sentinel - confirm). `.assign` builds a new frame
    # instead of writing into the result of `.query()`, which avoids
    # chained-assignment (SettingWithCopy) behaviour.
    outputs_subset_df = outputs_all_df.query("category == @cat").assign(
        modelled_m=lambda subset_df: subset_df.modelled_m.replace(-9999, np.nan)
    )

    lim_min, lim_max = scale_dict[cat]

    # Plot and compare - heatmap
    axes[i].hexbin(
        x=outputs_subset_df.validation_m_msl,
        y=outputs_subset_df.modelled_m,
        extent=(lim_min, lim_max, lim_min, lim_max),
        cmap="inferno",
        bins=100,
        vmin=0,
        vmax=50,
    )
    axes[i].set_facecolor("#0C0C0C")

    # 1:1 reference line
    axes[i].plot([lim_min, lim_max], [lim_min, lim_max], "--", c="white", linewidth=0.5)
    axes[i].margins(x=0, y=0)
    axes[i].set_title(cat)
    if i == 0:
        axes[i].set_ylabel("DEA Intertidal Elevation (m MSL)")
    elif i == 1:
        axes[i].set_xlabel("Validation DEM (m MSL)")

    # Accuracy statistics, plus the number of rows in the subset
    stats_df = eval_metrics(x=outputs_subset_df.validation_m_msl, y=outputs_subset_df.modelled_m, round=3)
    stats_df["n"] = len(outputs_subset_df.modelled_m)
    out[cat] = stats_df


# Summary table: one column per tide range class, one row per statistic
pd.DataFrame(out).reindex(
    index=["n", "Correlation", "R-squared", "RMSE", "MAE", "Bias"],
    columns=cat_ranges,
).round(2)

In [None]:
# Export the three-panel tide range (MSL) figure held in `fig` from the previous cell
fig.savefig("DEAIntertidal_validation_micromesomacro_msl.png", dpi=200, bbox_inches="tight")

### By substrate

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(12, 3.6))

# Substrate classes to compare, and the axis limits used for each panel
cat_ranges = ["Sandy beach", "Tidal flats", "Rocky"]
scale_dict = {"Sandy beach": (-2.0, 1.1), "Tidal flats": (-2.0, 1.1), "Rocky": (-1.5, 1.1)}

out = {}
for i, (ax, cat) in enumerate(zip(axes, cat_ranges)):
    lim_min, lim_max = scale_dict[cat]
    outputs_subset_df = outputs_all_df.query("substrate == @cat")

    # Plot and compare - heatmap
    ax.hexbin(
        x=outputs_subset_df.validation_m_msl,
        y=outputs_subset_df.modelled_m,
        extent=(lim_min, lim_max, lim_min, lim_max),
        cmap="inferno",
        bins=100,
        vmin=0,
        vmax=50,
    )
    ax.set_facecolor("#0C0C0C")

    # 1:1 reference line
    ax.plot([lim_min, lim_max], [lim_min, lim_max], "--", c="white", linewidth=0.5)
    ax.margins(x=0, y=0)
    ax.set_title(cat)

    # Label the y axis on the first panel and the x axis on the middle panel
    if i == 0:
        ax.set_ylabel("DEA Intertidal Elevation (m MSL)")
    if i == 1:
        ax.set_xlabel("Validation DEM (m MSL)")

    # Accuracy statistics, plus the number of rows in the subset
    stats_df = eval_metrics(x=outputs_subset_df.validation_m_msl, y=outputs_subset_df.modelled_m, round=3)
    stats_df["n"] = len(outputs_subset_df.modelled_m)
    out[cat] = stats_df

# Summary table: one column per substrate class, one row per statistic
pd.DataFrame(out).reindex(
    index=["n", "Correlation", "R-squared", "RMSE", "MAE", "Bias", "Regression slope"],
    columns=cat_ranges,
).round(2)

## 