In [72]:
from pystac_client import Client
from odc.stac import load
import geopandas as gpd
import pandas as pd
import numpy as np
import xarray as xr
import folium

from sklearn.ensemble import RandomForestClassifier

import odc.geo.xr  # noqa: F401

In [2]:
# STAC Catalog URL (Element 84 Earth Search, v1 endpoint)
catalog = "https://earth-search.aws.element84.com/v1"

# Create a STAC Client; `client` is what the search cell further down queries
client = Client.open(catalog)

In [None]:
# Area of interest north of Kuching; each corner is a (lat, lon) pair
ll = (1.57247, 110.18403)  # lower-left corner
ur = (1.82405, 110.56703)  # upper-right corner

# Flip each corner to (lon, lat) and join them into one 4-tuple:
# (min_lon, min_lat, max_lon, max_lat)
bbox = (*reversed(ll), *reversed(ur))

# Three months of data
datetime = "2024-07/2024-09"

In [None]:
# Get the training data
data_url = "https://raw.githubusercontent.com/nick-murray/coastTrain/main/data/coastTrain_v1_0.geojson"

# Read the training points, filtered by the area-of-interest bounding box.
# This line must run: with it commented out, `gdf` is undefined on a fresh
# kernel and this cell (and every later cell that uses `gdf`) raises NameError.
gdf = gpd.read_file(data_url, bbox=bbox)

# Preview the points, coloured by the "Ecosys_Typ" column
gdf.explore(column="Ecosys_Typ", legend=True)

In [None]:
# Query the catalog for Sentinel-2 scenes over the area and period of
# interest, keeping only scenes whose reported cloud cover is below 50
search = client.search(
    collections=["sentinel-2-c1-l2a"],
    bbox=bbox,
    datetime=datetime,
    query={"eo:cloud_cover": {"lt": 50}},
)

# Materialise the matching items so they can be counted and loaded
items = search.item_collection()

print(f"Found {len(items)} items")

In [None]:
# Load the data into an xarray Dataset
data = load(
    items,
    # Spectral bands used for the composite and indices, plus "scl", which
    # the masking cell below reads to build the cloud mask
    measurements=["red", "green", "blue", "nir08", "swir16", "scl"],
    # Restrict the load to the area of interest
    bbox=bbox,
    # Chunk sizes along x and y; the actual read is deferred until a later
    # cell calls .compute()
    chunks={"x": 2048, "y": 2048},
    # Group items by solar day along the time axis
    groupby="solar_day",
)

# Show a summary of the Dataset
data

In [None]:
# Mask out clouds and scale values

# Apply Sentinel-2 cloud mask, built from the scene classification band (scl)
# 1: defective, 3: shadow, 9: high confidence cloud, 10: thin cirrus
mask_flags = [1, 3, 9, 10]

# True where a pixel is usable, i.e. its scl value is not one of mask_flags
cloud_mask = ~data.scl.isin(mask_flags)

# scl holds categorical class codes, not reflectance. Drop it once the mask
# has been built so it is not rescaled below, averaged into the median
# composite, or handed to the classifier as if it were a spectral band.
masked = data.drop_vars("scl").where(cloud_mask)

# Apply scaling: zeros are treated as missing, the remaining values are
# multiplied by 0.0001 and limited to the 0-1 range
# NOTE(review): only a multiplicative scale is applied here - confirm whether
# this collection also needs an additive offset.
scaled = (masked.where(masked != 0) * 0.0001).clip(0, 1)

# Add some indices (normalised differences of two bands)
scaled["ndvi"] = (scaled.nir08 - scaled.red) / (scaled.nir08 + scaled.red)
scaled["ndwi"] = (scaled.green - scaled.nir08) / (scaled.green + scaled.nir08)

scaled


In [None]:
# Visualise one date, to make sure it looks good. This example doesn't look very good,
# which highlights the difficulty of working with Sentinel-2!

# Show the first time step only, with the display range set to 0-0.3
scaled.isel(time=0).odc.explore(vmin=0, vmax=0.3)

In [None]:
# Collapse the time axis with a per-pixel median to build a composite, which
# should get rid of most of the remaining clouds
composite = scaled.median(dim="time")

# Run the computation now so later cells work on in-memory values
median = composite.compute()

median

In [None]:
# Preview the median composite, using the same 0-0.3 display range as the
# single-date preview
median.odc.explore(vmin=0, vmax=0.3)

In [None]:
# Sample the composite at every training point

# Reproject the points into the composite's CRS so coordinates line up
training = gdf.to_crs(median.odc.geobox.crs)

# Expose each point's x/y as columns, then convert to xarray so they can be
# used as indexers into the composite
points_xy = training.assign(x=training.geometry.x, y=training.geometry.y)
training_da = points_xy.to_xarray()

# Take the nearest composite pixel for each point and bring it back to pandas
sampled = median.sel(training_da[["x", "y"]], method="nearest")
training_values = sampled.squeeze().compute().to_pandas()

# Class label first, sampled values after it; then discard the coordinate
# columns and every row that contains a NaN
training_array = (
    pd.concat([training["Class"], training_values], axis=1)
    .drop(columns=["y", "x", "spatial_ref"])
    .dropna()
)

training_array.head()

In [64]:
# Create a random forest classifier and fit it to the training data

# The training data is every column except the class label. Selecting by
# name (rather than by position) keeps this correct if the column order changes.
training_data = training_array.drop(columns="Class").to_numpy()

# The classes are the "Class" column, taken directly so the labels keep
# their own dtype instead of being coerced along with the features
values = training_array["Class"].to_numpy()

# Fix the seed so that re-running the notebook produces the same forest,
# and therefore the same predicted map
classifier = RandomForestClassifier(random_state=42)

# fit() returns the fitted estimator itself
model = classifier.fit(training_data, values)

In [69]:
# Now run a prediction using the model

# Flatten the composite into a table with one row per pixel and one column
# per variable
stacked_arrays = median.to_array().stack(dims=["y", "x"]).transpose()
pixel_features = stacked_arrays.values

# Rows with a NaN in any variable (e.g. pixels masked in every scene) cannot
# be classified meaningfully, and may make predict() raise depending on the
# scikit-learn version. Predict only the complete rows and leave the rest NaN.
# Assumes the class labels are numeric - TODO confirm.
has_data = np.isfinite(pixel_features).all(axis=1)
predicted = np.full(pixel_features.shape[0], np.nan)
if has_data.any():
    predicted[has_data] = model.predict(pixel_features[has_data])

# Back to the 2-D (y, x) layout of the composite
array = predicted.reshape(len(median.y), len(median.x))

# Use the coordinates of `median`, the Dataset the features were taken from,
# so the x/y axes and the spatial_ref coordinate travel with the result
predicted_da = xr.DataArray(array, coords=median.coords, dims=["y", "x"])


In [None]:
# Put it all on a single interactive map

# Centre the map on the midpoint of the area of interest, as [lat, lon]
lat_mid = np.mean([ll[0], ur[0]])
lon_mid = np.mean([ll[1], ur[1]])
center = [lat_mid, lon_mid]
m = folium.Map(location=center, zoom_start=11)

# Layers are added in this order: composite, prediction, training points
rgba = median.odc.to_rgba(vmin=0, vmax=0.3)
rgba.odc.add_to(m, name="Median Composite")
predicted_da.odc.add_to(m, name="Predicted")
gdf.explore(m=m, column="Ecosys_Typ", legend=True, name="Training Data")

# Layer control
folium.LayerControl().add_to(m)

m

## Look at the results, do they make sense?

Next steps are to fine-tune the data. Perhaps download the points for this
region of interest, as well as the RGB image, and add and remove points until
the training dataset is more representative.