# Machine Learning using Sentinel-2 Data

This example uses training data from the
[Coast Train](https://github.com/nick-murray/coastTrain) dataset
together with Sentinel-2 imagery to demonstrate how to use a
machine learning classifier (in this case, Random Forest) to
assign a class to each pixel.

This notebook combines lessons from previous notebooks into
a comprehensive worked example.

## Getting started

First we load the required Python libraries and tools.

In [1]:
from pystac_client import Client
from dask.distributed import Client as DaskClient
from odc.stac import load, configure_rio
import geopandas as gpd
import pandas as pd
import numpy as np
import xarray as xr
import folium
from ipyleaflet import basemaps

from sklearn.ensemble import RandomForestClassifier

## Study site configuration

Here we set up the STAC catalog we're using and define the spatial
and temporal extent of the study. The area can be anywhere; this
location east of Zamboanga City, Philippines, was chosen because the
training data includes several classes there.

In [2]:
# STAC Catalog URL
catalog = "https://earth-search.aws.element84.com/v1"

# Create a STAC Client
client = Client.open(catalog)

In [18]:
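# Location is east of Zamboanga City, Philippines.
# Lower-left (ll) and upper-right (ur) corners, in (lat, lon) order
ll = (6.87774, 122.07482)
ur = (7.00761, 122.28968)
# STAC and GeoPandas expect (min_lon, min_lat, max_lon, max_lat)
bbox = (ll[1], ll[0], ur[1], ur[0])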

# Four months of data
datetime = "2024-06/2024-09"
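The corners above are written in (lat, lon) order, but STAC searches and `geopandas.read_file` expect a bbox in (min_lon, min_lat, max_lon, max_lat) order, which is easy to get backwards. The sketch below wraps the reordering in a hypothetical helper, `corners_to_bbox` (not part of any library), that also catches swapped or inverted corners.

```python
def corners_to_bbox(lower_left, upper_right):
    """Convert (lat, lon) corner tuples to a (min_lon, min_lat, max_lon, max_lat) bbox.

    Hypothetical helper: assumes both corners are given in (lat, lon) order.
    """
    ll_lat, ll_lon = lower_left
    ur_lat, ur_lon = upper_right
    # A latitude outside +/-90 usually means the tuple was given as (lon, lat)
    if not (-90 <= ll_lat <= 90 and -90 <= ur_lat <= 90):
        raise ValueError("Latitude out of range; are the corners in (lat, lon) order?")
    if ll_lat >= ur_lat or ll_lon >= ur_lon:
        raise ValueError("Lower-left corner must be south-west of the upper-right corner")
    return (ll_lon, ll_lat, ur_lon, ur_lat)


example_bbox = corners_to_bbox((6.87774, 122.07482), (7.00761, 122.28968))
print(example_bbox)  # (122.07482, 6.87774, 122.28968, 7.00761)
```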

## Configure the environment

This cell sets up Dask, which we use for parallel computing, and configures
Rasterio (GDAL) with settings suited to reading cloud-hosted data. The
Sentinel-2 data is publicly accessible, so no credentials are required.


In [None]:
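# Create a local Dask cluster to speed up data loading.
# memory_limit applies per worker, so this cluster can use up to 32 GB in total.
dask_client = DaskClient(n_workers=2, threads_per_worker=16, memory_limit='16GB')

# Configure Rasterio/GDAL with defaults suited to reading cloud-hosted files
_ = configure_rio(cloud_defaults=True)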

dask_client

## Training data

Next, we gather the training data. This could be any geospatial point dataset
that has a numeric column holding the class label.

To explore the structure of this data, run `gdf.head()` to see the first few
rows. The `explore()` function with the `column` argument shows the points on
an interactive map, coloured by the values in that column.

In [None]:
# Get the training data
data_url = "https://raw.githubusercontent.com/nick-murray/coastTrain/main/data/coastTrain_v1_0.geojson"
gdf = gpd.read_file(data_url, bbox=bbox)

# Alternatively, use your revised training data if you have it.
# gdf = gpd.read_file("data/training_revised.geojson", bbox=bbox)

gdf.explore(column="Ecosys_Typ", legend=True, tiles=basemaps.Esri.WorldImagery, style_kwds={"radius":5})
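The classifier needs a numeric class label, and this notebook uses the `Class` column for that. If your own point data only has text labels, one option (an assumption, not a step the Coast Train data needs) is to derive integer codes with `pandas.factorize` and keep the lookup table so predicted codes can be mapped back to names. The labels below are made up for illustration.

```python
import pandas as pd

# Made-up text labels; replace with your own class column
points = pd.DataFrame({"label": ["mangrove", "water", "urban", "water", "mangrove"]})

# factorize returns integer codes plus the unique labels they index into;
# sort=True makes the codes follow alphabetical order of the labels
codes, names = pd.factorize(points["label"], sort=True)
points["Class"] = codes

# Lookup table for turning predicted codes back into readable names
code_to_name = dict(enumerate(names))
print(points)
print(code_to_name)  # {0: 'mangrove', 1: 'urban', 2: 'water'}
```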

## Find and load Sentinel-2 data

Here we search for Sentinel-2 scenes over our study area and use
Dask to lazy-load them. We only load the red, green, blue, near-infrared
(`nir08`) and shortwave infrared (`swir16`) bands, along with the scene
classification (`scl`) band.

In [None]:
# Search for Sentinel-2 data
items = client.search(
    collections=["sentinel-2-c1-l2a"],
    bbox=bbox,
    datetime=datetime,
    query={"eo:cloud_cover": {"lt": 90}},  # Remove scenes completely if they have a cloud cover of 90% or more
).item_collection()

print(f"Found {len(items)} items")
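The `query` argument asks the STAC API to filter on item properties on the server. Conceptually, `{"eo:cloud_cover": {"lt": 90}}` keeps only items whose scene-level cloud cover is strictly below 90%. The plain-Python sketch below mimics that filter on a few made-up property dicts, purely to illustrate the semantics; the real filtering happens in the API.

```python
# Made-up item properties, standing in for STAC item metadata
items_props = [
    {"id": "scene-a", "eo:cloud_cover": 12.5},
    {"id": "scene-b", "eo:cloud_cover": 90.0},
    {"id": "scene-c", "eo:cloud_cover": 97.3},
    {"id": "scene-d", "eo:cloud_cover": 45.0},
]

# Equivalent of query={"eo:cloud_cover": {"lt": 90}}: strictly less than 90
kept = [p["id"] for p in items_props if p["eo:cloud_cover"] < 90]
print(kept)  # ['scene-a', 'scene-d']
```

Because `eo:cloud_cover` describes the whole scene rather than just our bbox, a scene can pass or fail this filter regardless of conditions over the study area, which is one reason per-pixel cloud masking is still needed later.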

In [None]:
# Load the data into an xarray Dataset
data = load(
    items,
    measurements=["red", "green", "blue", "nir08", "swir16", "scl"],
    bbox=bbox,
    chunks={"x": 2048, "y": 2048},
    groupby="solar_day",
    skip_failures=True,
)

data
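`groupby="solar_day"` merges scenes captured on the same local solar day, so neighbouring tiles from one satellite pass become a single time step even if their UTC dates differ. The sketch below illustrates the idea under a stated assumption: local solar time is approximated as UTC plus longitude/15 hours. odc-stac's exact implementation may differ in detail, and the acquisition times are made up.

```python
import datetime as dt


def solar_date(utc_time, longitude):
    """Approximate local solar date: shift UTC by 1 hour per 15 degrees of longitude."""
    return (utc_time + dt.timedelta(hours=longitude / 15.0)).date()


# Made-up acquisition times (UTC); the first two straddle midnight UTC
acquisitions = [
    dt.datetime(2024, 6, 30, 23, 50, tzinfo=dt.timezone.utc),
    dt.datetime(2024, 7, 1, 0, 5, tzinfo=dt.timezone.utc),
    dt.datetime(2024, 7, 5, 2, 20, tzinfo=dt.timezone.utc),
]

# Group by approximate solar date at the study area's longitude (about 122.2 E)
groups = {}
for t in acquisitions:
    groups.setdefault(solar_date(t, 122.2), []).append(t)

for day, times in groups.items():
    print(day, len(times))
```

The first two acquisitions fall on different UTC dates but on the same solar date, so they end up in one group.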

## Data preparation

Now that we have data, we need to clean it up: we mask out clouds
and scale the values to surface reflectance, which should fall
between 0 and 1.

We also add three spectral indices (NDVI, NDWI and MVI), which help
the machine learning algorithm separate the classes.
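For reference, the indices computed in the next cell are written below in terms of the band names used in this notebook, with $\rho$ denoting scaled reflectance. The MVI line reflects the scaled and clipped variant the code uses, not the original Baloloy et al. (2020) definition.

```latex
\mathrm{NDVI} = \frac{\rho_{\mathrm{nir08}} - \rho_{\mathrm{red}}}{\rho_{\mathrm{nir08}} + \rho_{\mathrm{red}}}
\qquad
\mathrm{NDWI} = \frac{\rho_{\mathrm{green}} - \rho_{\mathrm{nir08}}}{\rho_{\mathrm{green}} + \rho_{\mathrm{nir08}}}

\mathrm{MVI}' = \operatorname{clip}_{[-1,\,1]}\!\left(0.1 \cdot \frac{\rho_{\mathrm{nir08}} - \rho_{\mathrm{green}}}{\rho_{\mathrm{swir16}} - \rho_{\mathrm{green}}}\right)
```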

Note that we still have a lazy-loaded array, and haven't transferred
any data over the network.

In [None]:
# Mask out clouds and scale values

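# Apply the Sentinel-2 scene classification (SCL) cloud mask
# 1: saturated or defective, 3: cloud shadow, 8: cloud medium probability,
# 9: cloud high probability, 10: thin cirrus
mask_flags = [1, 3, 8, 9, 10]

cloud_mask = ~data.scl.isin(mask_flags)
# SCL is only needed for masking, so drop it to keep it out of the classifier inputs
masked = data.where(cloud_mask).drop_vars("scl")

# Treat 0 as no-data, apply the 0.0001 scale factor and clip to the valid range of 0 to 1
scaled = (masked.where(masked != 0) * 0.0001).clip(0, 1)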

# NDVI for vegetation density
scaled["ndvi"] = (scaled.nir08 - scaled.red) / (scaled.nir08 + scaled.red)
# NDWI for distinguishing between vegetation and water
scaled["ndwi"] = (scaled.green - scaled.nir08) / (scaled.green + scaled.nir08)
# Modified form of MVI (Baloloy et al. 2020) for distinguishing between mangroves and other vegetation types
scaled["mvi"] = (((scaled.nir08 - scaled.green) / (scaled.swir16 - scaled.green)) * 0.1).clip(-1,1)

scaled
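To see what the masking and scaling cell does to individual pixels, here is a NumPy-only sketch on a tiny made-up array. It mirrors the same steps: drop pixels whose SCL value is in `mask_flags`, treat 0 as no-data, multiply by 0.0001 and clip to the range 0 to 1. Like the cell above, it assumes a plain 0.0001 scale factor and applies no offset.

```python
import numpy as np

# SCL classes to mask: defective, cloud shadow, medium/high-probability cloud, thin cirrus
mask_flags = [1, 3, 8, 9, 10]

# Made-up 2x3 scene: raw red digital numbers and matching SCL classes
red_dn = np.array([[0, 1200, 15000], [800, 2500, 11000]], dtype="uint16")
scl = np.array([[4, 8, 4], [3, 5, 6]], dtype="uint8")

# Cloud mask: True where the pixel is usable
clear = ~np.isin(scl, mask_flags)

# Convert to float so masked pixels can become NaN, then treat 0 as no-data
red = np.where(clear, red_dn.astype("float64"), np.nan)
red = np.where(red == 0, np.nan, red)

# Scale to reflectance and clip to the valid 0-1 range (NaN stays NaN)
reflectance = np.clip(red * 0.0001, 0, 1)
print(reflectance)
```

The cloud (SCL 8) and shadow (SCL 3) pixels and the zero pixel become NaN, while the two very bright values are clipped to 1.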

In [None]:
# Visualise one date, to make sure it looks good.
# This example shows empty areas where we've masked out clouds.

# This process of loading should take less than a minute
scaled.isel(time=0).odc.explore(vmin=0, vmax=0.3, tiles=basemaps.CartoDB.DarkMatter)

## Create a cloud-free composite

The final data preparation step is to create a per-pixel temporal
median of the bands and indices. Here we use `compute()` to run the
processing and bring the result into memory.

We preview the result in the second cell below.

In [None]:
# Create a median composite, which should get rid of most of the remaining clouds
# Note that this will take a few minutes to complete

median = scaled.median("time").compute()

median
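The median is taken per pixel across the time dimension, skipping the NaNs left by cloud masking (xarray's `median` skips NaN by default for floating-point data). The NumPy sketch below, on made-up values for a single pixel, shows why this removes most leftover clouds: one bright, unflagged cloud value is outvoted by the clear observations.

```python
import numpy as np

# Made-up red reflectance for one pixel over five dates:
# two dates masked (NaN) and one bright cloud the SCL band missed (0.45)
pixel_series = np.array([0.04, np.nan, 0.05, 0.45, np.nan])

# Median over time, ignoring masked observations
composite = np.nanmedian(pixel_series)
print(composite)  # 0.05

# A mean, by contrast, is pulled upwards by the cloud
print(np.nanmean(pixel_series))
```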

In [None]:
median.odc.explore(vmin=0, vmax=0.3, tiles=basemaps.CartoDB.DarkMatter)

## Prepare training data array

This next step extracts the values observed in the median composite at each
training point and combines them with the point's class, giving a table like this:

`class, red, green, blue ...`

This structure is then fed into the machine learning classifier.

In [None]:
# First transform the training points to the same CRS as the data
training = gdf.to_crs(median.odc.geobox.crs)

# Next get the X and Y values out of the point geometries
training_da = training.assign(x=training.geometry.x, y=training.geometry.y).to_xarray()

# Now use the x and y values (in the data's projected CRS) to extract the nearest pixel values from the median composite
training_values = (
    median.sel(training_da[["x", "y"]], method="nearest").squeeze().compute().to_pandas()
)

# Join the training data with the extracted values and remove unnecessary columns
training_array = pd.concat([training["Class"], training_values], axis=1)
training_array = training_array.drop(
    columns=[
        "y",
        "x",
        "spatial_ref",
    ]
)

# Drop rows where there was no data available
training_array = training_array.dropna()

# Preview our resulting training array
training_array.head()
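`median.sel(..., method="nearest")` picks, for each training point, the pixel whose centre is closest along x and along y, and because both indexers share one point dimension it returns one value per point rather than a grid. The NumPy sketch below reproduces that lookup on a made-up 3x3 grid of 10 m pixels, assuming 1-D, regularly spaced coordinate arrays as in the loaded data.

```python
import numpy as np

# Made-up pixel-centre coordinates (metres); y decreases downwards as in most rasters
xs = np.array([500005.0, 500015.0, 500025.0])
ys = np.array([770025.0, 770015.0, 770005.0])
band = np.arange(9, dtype=float).reshape(3, 3)  # values indexed as band[y, x]

# Made-up training point locations in the same CRS
point_x = np.array([500012.0, 500029.0])
point_y = np.array([770024.0, 770003.0])

# For each point, the index of the nearest pixel centre along each axis
ix = np.abs(xs[None, :] - point_x[:, None]).argmin(axis=1)
iy = np.abs(ys[None, :] - point_y[:, None]).argmin(axis=1)

# Pointwise lookup: one value per point, not a 2x2 grid
values = band[iy, ix]
print(values)  # [1. 8.]
```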

## Create a classifier and fit a model

We pass two plain NumPy arrays to the classifier: one holds the
observations (the red, green, blue and other band and index values)
and the other holds the matching classes.

In [25]:
# The classes are the first column
classes = np.array(training_array)[:, 0]

# The observation data is everything after the first column
observations = np.array(training_array)[:, 1:]

# Create a model...
classifier = RandomForestClassifier(class_weight='balanced')

# ...and fit it to the data
model = classifier.fit(observations, classes)
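`class_weight='balanced'` tells scikit-learn to weight each class inversely to how often it appears, so rare classes are not swamped by common ones. scikit-learn documents these weights as `n_samples / (n_classes * np.bincount(y))`; the sketch below evaluates the same formula with NumPy on made-up labels, using `np.unique` so the class codes do not need to run from 0 upwards.

```python
import numpy as np

# Made-up training labels: class 1 is common, class 3 is rare
toy_labels = np.array([1, 1, 1, 1, 1, 1, 2, 2, 3])

class_ids, counts = np.unique(toy_labels, return_counts=True)
weights = len(toy_labels) / (len(class_ids) * counts)

for c, w in zip(class_ids, weights):
    print(f"class {c}: weight {w:.2f}")
```

With 9 samples and 3 classes, the rare class gets a weight of 3.00 and the common class 0.50.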

## Prediction

Next, we predict. Again, we need a plain array, this time containing
only the observations. It must be a long, two-dimensional array with
one row per pixel in the original raster and one column per band or
index, in the same order as the training data.

In [26]:
# Convert to a stacked array of observations
stacked_arrays = median.to_array().stack(dims=["y", "x"]).transpose()

# Replace any NaN values with 0
stacked_arrays = stacked_arrays.fillna(0)

# Predict the classes
predicted = model.predict(stacked_arrays)

# Reshape back to the original 2D array
array = predicted.reshape(len(median.y), len(median.x))

# Convert to an xarray again, because it's easier to work with.
# Reusing the composite's coordinates keeps the CRS (spatial_ref) needed for mapping.
predicted_da = xr.DataArray(array, coords=median.coords, dims=["y", "x"])
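The stack, predict and reshape steps are easy to get wrong, so it helps to check the shapes on a toy example. The NumPy sketch below starts from a made-up (band, y, x) cube, flattens it to one row per pixel and one column per band, and uses a stand-in "classifier" (a simple threshold, assumed purely for illustration) before reshaping the predictions back to (y, x).

```python
import numpy as np

# Made-up cube: 3 bands over a 2x4 grid, shaped (band, y, x) like median.to_array()
n_bands, ny, nx = 3, 2, 4
cube = np.arange(n_bands * ny * nx, dtype=float).reshape(n_bands, ny, nx)

# Flatten to (pixels, bands): move bands last, then merge y and x in row-major order
table = cube.transpose(1, 2, 0).reshape(ny * nx, n_bands)

# Stand-in for model.predict: class 1 where the first band exceeds 3, else 0
toy_predicted = (table[:, 0] > 3).astype(int)

# Reshape predictions back onto the grid; y varies slowest, matching the stacking order
class_map = toy_predicted.reshape(ny, nx)
print(class_map)
```

If the flattening and reshaping orders disagree, the class map comes back scrambled, so comparing it with a per-pixel rule like the threshold above is a cheap sanity check.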

## Visualise our results

Here we visualise the results alongside the RGB composite
and the original training points, using Folium, a Python
library that wraps the Leaflet JavaScript library.

In [None]:
# Put it all on a single interactive map
center = [np.mean([ll[0], ur[0]]), np.mean([ll[1], ur[1]])]
m = folium.Map(location=center, zoom_start=11, tiles=basemaps.Esri.WorldImagery)

# RGB for the median
median.odc.to_rgba(vmin=0, vmax=0.3).odc.add_to(m, name="Median Composite")

# Categorical for the predicted classes and for the training data
predicted_da.odc.add_to(m, name="Predicted")
gdf.explore(m=m, column="Class", name="Training Data", style_kwds={"radius":5, "color":"white", "fillOpacity": 0.9})

# Layer control
folium.LayerControl().add_to(m)

m

In [43]:
# Export the RGB median to use as a base for updating training data
# median[["red", "green", "blue"]].odc.to_rgba(vmin=0, vmax=0.3).odc.write_cog("median.tif", overwrite=True)

# Export the training data so it can be updated with more points
# gdf[["Class", "Ecosys_Typ", "geometry"]].to_file("data/training_revised.geojson", driver="GeoJSON")

## Conclusions

Do the results make sense?

What are some of the limitations of the visualisation?

### Next steps and opportunities

The obvious next step is to refine the training data. For example, export the
training points for this region of interest along with the RGB median image,
then add and remove points until the training dataset is more representative.
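One quick way to judge whether the training data is representative is to count the points per class and flag classes with very few samples. The pandas sketch below does this on a made-up `Class` column; the 10-point threshold is an arbitrary assumption, not a recommendation from the Coast Train authors. In the notebook, the same logic could be applied to `training_array["Class"]`.

```python
import pandas as pd

# Made-up class labels standing in for training_array["Class"]
labels = pd.Series([1] * 3 + [2] * 12 + [3] * 2, name="Class")

# Number of training points per class, ordered by class code
counts = labels.value_counts().sort_index()
print(counts)

# Hypothetical minimum number of points per class
min_points = 10
sparse = counts[counts < min_points].index.tolist()
print("Classes that may need more points:", sparse)
```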