# DEA Intertidal Elevation

This notebook demonstrates how to run and customise a DEA Intertidal Elevation analysis.

The notebook is adapted from the [Geoscience Australia example](https://github.com/GeoscienceAustralia/dea-intertidal/blob/rbt/notebooks/Intertidal_elevation.ipynb).

## Getting started

### Load packages

In [None]:
import os
import dask.config

import matplotlib.pyplot as plt
import xarray as xr
from dask.distributed import Client, LocalCluster
from datacube import Datacube
from datacube.utils.masking import create_mask_value, valid_data_mask
from eo_tides.eo import pixel_tides
from intertidal.elevation import (
    clean_edge_pixels,
    ds_to_flat,
    flat_to_ds,
    pixel_dem,
    pixel_rolling_median,
    pixel_uncertainty,
)
from intertidal.io import extract_geobox
from ipyleaflet import basemaps
from odc.geo.geom import point
from odc.stac import configure_s3_access
from urllib.parse import urlparse

## Configure the environment

In [None]:
# Use the us-west-2 region for AWS access
os.environ["AWS_DEFAULT_REGION"] = "us-west-2"

# Make sure AWS_NO_SIGN_REQUEST is not set in this session
# (presumably because requester-pays access needs signed requests - confirm)
os.environ.pop("AWS_NO_SIGN_REQUEST", None)

# Configure S3 access in requester-pays mode
configure_s3_access(requester_pays=True)

# Datacube connection used to find and load datasets
dc = Datacube()

In [None]:
# Set up a local Dask cluster: 2 workers x 2 threads, 10GB memory limit per worker
cluster = LocalCluster(
    n_workers=2,
    threads_per_worker=2,
    memory_limit="10GB",
)

# Port the dashboard is being served on, taken from the cluster's own link
dashboard_url = cluster.dashboard_link
port = urlparse(dashboard_url).port

# Under JupyterHub, route the dashboard link through the per-user proxy path.
# JUPYTERHUB_USER is unset outside JupyterHub; in that case keep Dask's
# default link rather than building a broken "/user/None/proxy/..." URL.
jupyterhub_user = os.environ.get("JUPYTERHUB_USER")
if jupyterhub_user:
    dask.config.set(
        {
            "distributed.dashboard.link": (
                f"/user/{jupyterhub_user}/proxy/{port}/status"
            )
        }
    )

# Connect a client to the cluster and display it as the cell output
client = Client(cluster)
client

## Setup


### Set analysis parameters

In [None]:
# --- Intertidal Elevation analysis parameters ---

# Analysis period
start_date = "2018"  # First year of the analysis period
end_date = "2020"  # Last year of the analysis period

# Output grid
resolution = 20  # Spatial resolution of the output files
crs = "EPSG:6933"  # Coordinate Reference System (CRS) of the output files

# Pixel selection thresholds
min_freq = 0.01  # Lowest wetness frequency for a pixel to be analysed
max_freq = 0.99  # Highest wetness frequency for a pixel to be analysed
min_correlation = 0.15  # Lowest water index / tide height correlation to keep
ndwi_thresh = 0.1  # Water index value marking the dry/wet transition

# Satellite data sources
include_s2 = True  # Use Sentinel-2 data in the analysis?
include_ls = True  # Use Landsat data in the analysis?
max_cloud_cover = 60  # Highest cloud cover percentage for a dataset to be used

# Tide modelling
tide_model = "EOT20"  # Tide model used in the analysis
tide_model_dir = "~/jovyan/data/coastlines/tide_models/"  # Tide model directory

#### Set study area

In [None]:
# Label for this study area
study_area = "testing"

# Centre of the area of interest; the first value is passed to `point` as y
# and the second as x (i.e. latitude, longitude in EPSG:4326)
coords = (-7.135908786754243, 114.5000202988149)
lat, lon = coords

# Buffer the centre point by 0.02 (in the units of EPSG:4326) and take the
# bounding box of the buffered shape
aoi_point = point(lon, lat, crs="EPSG:4326")
buffered_point = aoi_point.buffer(0.02)
area = buffered_point.boundingbox

# Build the analysis grid over that box at the configured resolution and CRS,
# then show it on a satellite imagery basemap
geobox = extract_geobox(geom=area.polygon, resolution=resolution, crs=crs)
geobox.explore(tiles=basemaps.Esri.WorldImagery)

## Load data

In [None]:
# QA flag categories that mark a Landsat pixel for masking
categories_to_mask_landsat = dict(
    cloud="high_confidence",
    cloud_shadow="high_confidence",
)

# Products to search and bands to load (qa_pixel is only needed for masking)
landsat_products = ["ls8_c2l2_sr", "ls9_c2l2_sr"]
landsat_bands = ["red", "green", "blue", "nir08", "swir16", "swir22", "qa_pixel"]

# Find Landsat datasets over the study area within the analysis period,
# keeping only T1 quality data under the cloud cover limit
landsat_datasets = dc.find_datasets(
    product=landsat_products,
    time=(start_date, end_date),
    like=geobox.compat,
    cloud_cover=(0, max_cloud_cover),
    collection_category="T1",
)
print(f"Found {len(landsat_datasets)} Landsat datasets")

# Lazily load the datasets onto the analysis grid; the QA band is resampled
# with "nearest", every other band with "cubic"
ls_ds = dc.load(
    datasets=landsat_datasets,
    like=geobox.compat,
    measurements=landsat_bands,
    group_by="solar_day",
    dask_chunks={"time": 1, "x": 3200, "y": 3200},
    resampling={"*": "cubic", "qa_pixel": "nearest"},
    # skip_broken_datasets=True,
    driver="rio",
)

# Build the bit mask for the chosen QA categories from the product's
# flag definition (only the first returned value is used)
qa_flags_definition = (
    landsat_datasets[0].product.measurements["qa_pixel"].flags_definition
)
mask_value, _ = create_mask_value(
    qa_flags_definition,
    **categories_to_mask_landsat,
)

# A pixel is masked if any of the selected QA bits is set, or if it is
# not valid data
cloud_mask = (ls_ds.qa_pixel & mask_value) != 0
valid_data = valid_data_mask(ls_ds)
mask = cloud_mask | ~valid_data

# Apply the mask and discard the QA band, which is no longer needed
ls_ds_masked = ls_ds.where(~mask).drop_vars("qa_pixel")

# Apply the Landsat scale and offset, then clip so values lie between 0 and 1
ls_ds_scaled = (ls_ds_masked * 0.0000275 - 0.2).clip(0, 1)

ls_ds_scaled

In [None]:
# ls_ds_scaled[["red", "green", "blue"]].isel(time=slice(0, 6)).to_array().plot.imshow(
#     col="time", col_wrap=3, size=4, vmin=0, vmax=0.3
# )

In [None]:
# Bands to load (scl is only needed for masking)
sentinel2_bands = ["red", "green", "blue", "nir08", "swir16", "swir22", "scl"]

# Find Sentinel-2 datasets over the study area within the analysis period
# and under the cloud cover limit
sentinel2_datasets = dc.find_datasets(
    product=["s2_l2a"],
    time=(start_date, end_date),
    like=geobox.compat,
    cloud_cover=(0, max_cloud_cover),
)
print(f"Found {len(sentinel2_datasets)} Sentinel-2 datasets")

# Lazily load the datasets onto the analysis grid; the scene classification
# band is resampled with "nearest", every other band with "cubic"
s2_ds = dc.load(
    datasets=sentinel2_datasets,
    like=geobox.compat,
    measurements=sentinel2_bands,
    group_by="solar_day",
    dask_chunks={"time": 1, "x": 3200, "y": 3200},
    resampling={"*": "cubic", "scl": "nearest"},
    driver="rio",
)

# Scene classification values to mask:
# 3 is cloud shadow, 8 is medium probability cloud, 9 is high probability cloud
scl_classes_to_mask = [3, 8, 9]

# A pixel is masked if it falls in one of those classes, or if it is not
# valid data
cloud_mask = s2_ds.scl.isin(scl_classes_to_mask)
valid_data = valid_data_mask(s2_ds)
mask = cloud_mask | ~valid_data

# Apply the mask and discard the classification band
s2_ds_masked = s2_ds.where(~mask).drop_vars("scl")

# Scale Sentinel-2 data to match the Landsat scale, clipped to between 0 and 1
s2_ds_scaled = (s2_ds_masked * 0.0001).clip(0, 1)

s2_ds_scaled

In [None]:
# s2_ds_scaled[["red", "green", "blue"]].isel(time=slice(0, 6)).to_array().plot.imshow(
#     col="time", col_wrap=3, size=4, vmin=0, vmax=0.3
# )

In [None]:
# Combine the Landsat and Sentinel-2 datasets along time, honouring the
# `include_ls` / `include_s2` analysis parameters so that a sensor switched
# off there is actually left out of the analysis
sensor_datasets = []
if include_ls:
    sensor_datasets.append(ls_ds_scaled)
if include_s2:
    sensor_datasets.append(s2_ds_scaled)

if not sensor_datasets:
    raise ValueError("At least one of `include_ls` or `include_s2` must be True")

data = xr.concat(sensor_datasets, dim="time")

# Create the water index used by the rest of the workflow.
# NOTE(review): this is computed from the NIR and SWIR bands, whereas NDWI is
# commonly defined from green and NIR as (green - nir) / (green + nir).
# Confirm the band choice is intentional, since `ndwi_thresh` is applied to
# this index.
data["ndwi"] = (data.nir08 - data.swir16) / (data.nir08 + data.swir16)
data

In [None]:
%%time

# Load data into memory. Watch Dask for this, it'll do a bunch of work!
# The datasets above were loaded with `dask_chunks`, so nothing has been
# read yet; `compute()` runs the load, masking, scaling and index steps.
data = data.compute()

## Pixel-based tides

In [None]:
# Tide models offered to the ensemble
ensemble_models = [
    "EOT20",
    "FES2012",
    "FES2014_extrapolated",
    "FES2022_extrapolated",
    "INATIDES",
    "GOT4.10",
    "GOT5.6_extrapolated",
    "TPXO10-atlas-v2-nc",
    "TPXO8-atlas-nc",
    "TPXO9-atlas-v5-nc",
]

# Ensemble weighting: ranks 1, 2 and 3 get weights 3, 2 and 1, and any rank
# of 4 or more gets weight 0
ensemble_func = {
    "ensemble-mean-weightedtop3": lambda x: (4 - x["rank"]).clip(0, 3),
}

# Model tides into every pixel in the three-dimensional (x, y, time)
# satellite dataset. If `model` is "ensemble" this generates optimised
# tide modelling by combining the best local tide models.
tide_m = pixel_tides(
    data=data,
    model=tide_model,
    directory=tide_model_dir,
    ensemble_models=ensemble_models,
    ensemble_func=ensemble_func,
)

In [None]:
# Set tide array pixels to nodata if the satellite data array pixels contain
# nodata. This ensures that we ignore any tide observations where we don't
# have matching satellite imagery.
# The first data variable is used as the reference for missing pixels; its
# leftover scalar "variable" coordinate is removed with `drop_vars` (the
# non-deprecated replacement for `drop`) so it is not carried onto `tide_m`.
satellite_nodata = data.to_array().isel(variable=0).isnull().drop_vars("variable")
data["tide_m"] = tide_m.where(~satellite_nodata)

## Pixel-based DEM creation

### Flatten array from 3D to 2D and drop pixels with no correlation with tide
Flatten array to only pixels with positive correlations between water observations and tide height. This greatly improves processing time by ensuring only a narrow strip of pixels along the coastline are analysed, rather than the entire x * y array:

In [None]:
# Flatten the dataset, keeping pixels according to the wetness frequency and
# correlation thresholds set in the analysis parameters. Besides the
# flattened dataset, this returns the frequency, correlation and clear count
# layers that are merged back in as QA layers before export.
flat_ds, freq, corr, clear = ds_to_flat(
    data,
    min_freq=min_freq,
    max_freq=max_freq,
    min_correlation=min_correlation,
    # valid_mask=topobathy_mask & coastal_mask,
)

# The lines below appear to be example output recorded from an earlier run:
# Reducing analysed pixels from 10240000 to 1002452 (9.79%)
# CPU times: user 3min, sys: 1min 25s, total: 4min 25s
# Wall time: 4min 13s

### Pixel-wise rolling median
This function performs a rolling median calculation along the tide heights of our satellite images. 
It breaks our tide range into `windows_n` individual rolling windows, each of which covers `window_prop_tide` of the full tidal range. 
For each window, the function returns the median of all tide heights and NDWI index values within the window, and returns an array with a new "interval" dimension that summarises these values from low to high tide.

More windows (e.g. `windows_n=100`) produce detailed elevation maps that can capture small differences in intertidal morphology - at the expense of slower run times.
Fewer windows (e.g. `windows_n=50`) will run faster, but produce less smooth, less detailed elevation maps.

In [None]:
# Rolling median along tide height for every pixel in the flattened dataset.
# The result is the input to the elevation modelling step.
interval_ds = pixel_rolling_median(
    flat_ds,
    windows_n=100,  # Number of rolling windows across the tide range
    window_prop_tide=0.15,  # Proportion of the tide range covered per window
    max_workers=None,
    min_count=5,
)

### Model intertidal elevation and uncertainty

Test our workflow by plotting an example elevation extraction for a single pixel:

In [None]:
# from intertidal.elevation import pixel_dem_debug

# center = data.odc.geobox.center_pixel.coords

# x, y = center["x"].values[0], center["y"].values[0]

# interval_pixel, interval_smoothed_pixel = pixel_dem_debug(
#     x,
#     y,
#     flat_ds,
#     interval_ds,
#     interp_intervals=200,
#     smooth_radius=20,
#     min_periods=5,
#     plot_style="season"
# )

Now model our full elevation raster:

In [None]:
# Model elevation from the rolling median intervals, using `ndwi_thresh`
# from the analysis parameters as the dry/wet transition threshold
flat_dem = pixel_dem(
    interval_ds,
    ndwi_thresh=ndwi_thresh,
    interp_intervals=200,
    smooth_radius=20,
    min_periods=5,
)

In [None]:
# Model uncertainty around the modelled elevation using the "mad" method
uncertainty_outputs = pixel_uncertainty(
    flat_ds,
    flat_dem,
    ndwi_thresh,
    method="mad",
)
low, high, uncertainty, misclassified = uncertainty_outputs

# Add the lower bound, upper bound and uncertainty arrays to the DEM dataset
flat_dem["elevation_low"] = low
flat_dem["elevation_high"] = high
flat_dem["elevation_uncertainty"] = uncertainty

## Unstack outputs and export

In [None]:
# Merge the elevation layers with the QA layers: DEM data, frequency,
# correlation and clear count
layers_to_combine = [flat_dem, freq, corr, clear]
flat_combined = xr.combine_by_coords(layers_to_combine)

# Unstack the flattened layers back into their original spatial dimensions
ds = flat_to_ds(flat_combined, data)

# Clean the upper edge of the intertidal zone in every elevation layer
# (these edge pixels are likely to be inaccurate)
elevation_bands = [name for name in ds.data_vars if "elevation" in name]
cleaned_elevation = clean_edge_pixels(ds[elevation_bands])
ds[elevation_bands] = cleaned_elevation


In [None]:
# One panel per output layer: elevation, elevation uncertainty, and the
# correlation, frequency and clear count QA layers
fig, axes = plt.subplots(1, 5, figsize=(12, 3))
ds.elevation.plot.imshow(cmap="viridis", ax=axes[0])
ds.elevation_uncertainty.plot.imshow(cmap="inferno", vmin=0, vmax=0.5, ax=axes[1])
ds.qa_ndwi_corr.plot.imshow(cmap="RdBu", vmin=-0.7, vmax=0.7, ax=axes[2])
ds.qa_ndwi_freq.plot.imshow(cmap="Blues", vmin=0, vmax=1, ax=axes[3])
ds.qa_count_clear.plot.imshow(cmap="Greys", ax=axes[4])

# Five panels share a narrow figure, so tighten the layout to stop labels
# and colourbars from overlapping
fig.tight_layout()

In [None]:
# Show the modelled elevation layer on an interactive map over a satellite
# imagery basemap
ds.elevation.odc.explore(tiles=basemaps.Esri.WorldImagery)

## Close Dask client

In [None]:
# Shut down Dask once the analysis is finished.
# NOTE(review): `shutdown()` presumably also stops the scheduler and workers
# of the local cluster, not just this client connection - confirm that is
# the intent.
client.shutdown()