In [None]:
from pystac_client import Client
from odc.stac import load
from dea_tools.coastal import pixel_tides

from dask.distributed import Client as DaskClient

import numpy as np
from pathlib import Path
import odc.geo.xr  # noqa
import folium

In [None]:
# Earth Search (Element 84) STAC API endpoint that hosts the Sentinel-2 items
catalog = "https://earth-search.aws.element84.com/v1"

# Open a client against that endpoint; used below to search for imagery
client = Client.open(catalog)

# Spin up a local Dask cluster (4 workers x 4 threads each) for lazy loading/compute
dask_client = DaskClient(threads_per_worker=4, n_workers=4)

# Show the cluster summary (includes the dashboard link)
dask_client

In [None]:
# Pick a location on Google Maps and copy its coordinates: right-click on the
# map, then click the coordinates in the context menu to copy them.

# These coords are in the order Y then X, i.e. latitude then longitude
coords = 20.7748, 106.7785  # Near Haiphong
buffer = 0.1  # half-width of the search box, in degrees

# Bounding box as (min_lon, min_lat, max_lon, max_lat), i.e. X/Y order
bbox = (coords[1] - buffer, coords[0] - buffer, coords[1] + buffer, coords[0] + buffer)

# Search date range as a "start/end" interval.
# NOTE: this name shadows the stdlib `datetime` module in the notebook namespace.
datetime = "2023/2024"

# Tide model data lives in <home>/tide_models. Path("~") is NOT expanded
# automatically (it is a literal "~" directory name), so resolve the real
# home directory explicitly. Kept as a str for the tide-modelling call.
home = Path.home()
tide_data_location = str(home / "tide_models")

In [None]:
# Query the catalog for Sentinel-2 Collection 1 L2A scenes intersecting the
# bounding box within the date range, then materialise the results.
search = client.search(
    datetime=datetime,
    bbox=bbox,
    collections=["sentinel-2-c1-l2a"],
)
items = search.item_collection()

print(f"Found {len(items)} STAC items")

In [None]:
# Bands to read: RGB for display, green + swir16 for MNDWI, and "scl" (the
# Scene Classification Layer) for masking.
bands_to_load = ["red", "green", "blue", "swir16", "cloud", "scl"]

# Build a dask-backed (lazy) dataset clipped to the bbox; scenes captured on
# the same solar day are merged into a single timestep.
data = load(
    items,
    bbox=bbox,
    bands=bands_to_load,
    chunks={"x": 2048, "y": 2048},
    groupby="solar_day",
)

data

In [None]:
# Sentinel-2 Scene Classification Layer (SCL) classes treated as invalid.
NODATA = 0
SATURATED_OR_DEFECTIVE = 1
CLOUD_SHADOW = 3
CLOUD_MEDIUM_PROBABILITY = 8
CLOUD_HIGH_PROBABILITY = 9
THIN_CIRRUS = 10

INVALID_SCL_CLASSES = [
    NODATA,
    SATURATED_OR_DEFECTIVE,  # sensor-saturated/defective pixels are not usable reflectance
    CLOUD_SHADOW,
    CLOUD_MEDIUM_PROBABILITY,
    CLOUD_HIGH_PROBABILITY,
    THIN_CIRRUS,  # remove from this list to keep more (slightly hazy) observations
]

# True where the pixel should be discarded
mask = data.scl.isin(INVALID_SCL_CLASSES)

# `where` keeps values where the condition holds, so invert the mask; discarded
# pixels become NaN. The SCL band itself is no longer needed after masking.
masked = data.where(~mask, other=np.nan)
masked = masked.drop_vars("scl")

In [None]:
# Only the low-resolution tide grid is needed (resample=False): tides are used
# to rank whole scenes, not to produce per-pixel tide heights.
tides_lowres = pixel_tides(
    masked,
    model="FES2022",
    directory=tide_data_location,
    resample=False,
    dask_compute=True,
)

In [None]:
# Classify scenes by tide height: below the LOW quantile = "low tide",
# above the HIGH quantile = "high tide"; scenes in between are ignored.
LOW_TIDE_QUANTILE = 0.3
HIGH_TIDE_QUANTILE = 0.7

# With resample=False the tides can still carry low-resolution spatial dims
# (presumably y/x — NOTE(review): confirm for the dea_tools version in use).
# Filtering that array with drop=True keeps a timestep if *any* low-res pixel
# matches, so one scene could end up in both groups. Reduce to a single mean
# tide height per scene first.
spatial_dims = [dim for dim in tides_lowres.dims if dim != "time"]
tide_per_scene = tides_lowres.mean(dim=spatial_dims) if spatial_dims else tides_lowres

# Threshold values (tide heights) at the chosen quantiles across all scenes
lowest, highest = tide_per_scene.quantile([LOW_TIDE_QUANTILE, HIGH_TIDE_QUANTILE]).values

low_scenes = tide_per_scene.where(tide_per_scene < lowest, drop=True)
high_scenes = tide_per_scene.where(tide_per_scene > highest, drop=True)

data_low = masked.sel(time=low_scenes.time)
data_high = masked.sel(time=high_scenes.time)

print(f"Found {len(data_low.time)} low tide days and {len(data_high.time)} high tide days out of {len(data.time)} days")

In [None]:
# Per-pixel median composites of the low- and high-tide scene sets;
# .compute() executes the dask graph for each in turn.
median_low, median_high = (
    subset.median("time").compute() for subset in (data_low, data_high)
)

In [None]:
# Median composite over every masked scene, regardless of tide, for comparison
median = masked.median(dim="time").compute()
median.odc.explore(vmin=1000, vmax=4000)

In [None]:
# Interactive view of the high-tide median composite (same stretch as the others)
median_high.odc.explore(vmax=4000, vmin=1000)

In [None]:
# Interactive view of the low-tide median composite (same stretch as the others)
median_low.odc.explore(vmax=4000, vmin=1000)

In [None]:
def _mndwi(ds):
    """Return the Modified Normalised Difference Water Index for `ds`.

    MNDWI = (green - swir16) / (green + swir16). Pixels where the denominator
    is 0 yield NaN instead of +/-inf.

    NOTE(review): bands are used exactly as loaded (no scale/offset applied).
    If the loaded values include the Sentinel-2 BOA offset of +1000 (the
    explore cells stretching from vmin=1000 hint at this), the denominator is
    inflated and MNDWI magnitudes are compressed towards 0 — confirm before
    interpreting absolute values.
    """
    green = ds.green
    swir = ds.swir16
    denominator = green + swir
    return (green - swir) / denominator.where(denominator != 0)


# Calculate MNDWI for the low- and high-tide composites
median_low["mndwi"] = _mndwi(median_low)
median_high["mndwi"] = _mndwi(median_high)

In [None]:
# Overlay the low- and high-tide MNDWI composites on one interactive map;
# use the layer control to toggle between them.
m = folium.Map()

# Shared styling for both layers: diverging RdBu colormap, symmetric about 0
arguments = dict(cmap="RdBu", vmin=-0.5, vmax=0.5)

for layer_name, composite in (("low", median_low), ("high", median_high)):
    composite.mndwi.odc.add_to(m, name=layer_name, **arguments)

folium.LayerControl().add_to(m)
m.fit_bounds(median_low.odc.map_bounds())
m