# Machine Learning using Sentinel-2 Data

This example uses training data from the
[Coast Train](https://github.com/nick-murray/coastTrain) dataset
along with Sentinel-2 data to demonstrate how to use a
machine learning classifier, in this case, Random Forest, to
assign a class to each pixel.

This notebook combines lessons from previous notebooks into
a comprehensive worked example.

## Getting started

First we load the required Python libraries and tools.

In [None]:
import os
import dask.config

import folium
import geopandas as gpd
import numpy as np
import pandas as pd
import xarray as xr
from dask.distributed import Client, LocalCluster
from datacube import Datacube
from datacube.utils.masking import valid_data_mask
from ipyleaflet import basemaps
from odc.algo import mask_cleanup
from odc.geo.geom import point
from odc.stac import configure_s3_access
from sklearn.ensemble import RandomForestClassifier
from urllib.parse import urlparse

## Study site configuration

Here we set the spatial and temporal extent for the analysis.
This can be anywhere, but this location around Pulau Sepanjang,
Indonesia, was chosen because the training data has several
classes available there.
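
As a rough sketch of what the buffer step in the cell below produces: buffering a lon/lat point by 0.1 degrees and taking its bounding box gives a box of about 0.2 degrees on each side. The helper `bbox_around` is hypothetical and written in plain Python; it only approximates what `point(...).buffer(0.1).boundingbox` returns.

```python
def bbox_around(lat, lon, buffer_deg):
    """Return (left, bottom, right, top) for a square box around a lon/lat point."""
    return (lon - buffer_deg, lat - buffer_deg, lon + buffer_deg, lat + buffer_deg)


# Same centre point and buffer distance as the notebook
coords = -7.12, 115.82
left, bottom, right, top = bbox_around(coords[0], coords[1], 0.1)

print(f"longitude: {left:.2f} to {right:.2f}")
print(f"latitude:  {bottom:.2f} to {top:.2f}")
```

Near the equator 0.1 degrees is roughly 11 km, so the study area is about 22 km across.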

In [None]:
start_date = "2025-01"
end_date = "2025-06"

# Location is around Pulau Sepanjang, Indonesia
coords = -7.12, 115.82

aoi_point = point(coords[1], coords[0], crs="EPSG:4326")
area = aoi_point.buffer(0.1).boundingbox

area.explore(tiles=basemaps.Esri.WorldImagery)

## Configure our environment

This cell sets up Dask, which we use for parallel computing, and configures
AWS credentials for "unsigned" (public) data access.
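
The Dask cell further below rewrites the dashboard link so that it resolves through the JupyterHub proxy. A minimal sketch of that rewriting, using only the standard library, is shown here. The helper `proxied_dashboard_link`, the user name `alice` and the example URL are all hypothetical; the fallback for a missing user is an assumption and is not part of the notebook.

```python
from urllib.parse import urlparse


def proxied_dashboard_link(dashboard_url, user):
    """Build a JupyterHub proxy path for a Dask dashboard URL.

    Returns the original URL unchanged when no JupyterHub user is known.
    """
    if not user:
        return dashboard_url
    port = urlparse(dashboard_url).port
    return f"/user/{user}/proxy/{port}/status"


# Example values only; a real link comes from cluster.dashboard_link
local_link = "http://127.0.0.1:8787/status"
hub_link = proxied_dashboard_link(local_link, "alice")
fallback_link = proxied_dashboard_link(local_link, None)

print(hub_link)
print(fallback_link)
```

Outside JupyterHub the `JUPYTERHUB_USER` variable is usually unset, which is why the sketch falls back to the local link in that case.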


In [None]:
os.environ["AWS_DEFAULT_REGION"] = "us-west-2"

# Configure S3 access for datacube
configure_s3_access(no_sign_request=True)

# Set up the Datacube
dc = Datacube(app="Sentinel-2_MachineLearning")

In [None]:
# Set up Dask
cluster = LocalCluster(
    n_workers=2,
    threads_per_worker=2,
    memory_limit='10GB'
)

dashboard_url = cluster.dashboard_link
port = urlparse(dashboard_url).port

jupyterhub_user = os.environ.get('JUPYTERHUB_USER')
dask.config.set(**{
    "distributed.dashboard.link": f"/user/{jupyterhub_user}/proxy/{port}/status"
})

client = Client(cluster)
client

## Training data

Next we gather training data. This could be any geospatial point dataset
with a numeric column that holds the class.

If you'd like to explore the structure of this data, you can run `gdf.head()`
to see the first few rows. The `explore()` function with the `column` argument
will show the data on the map, and change the colour based on that column.
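
Before training, it can also help to check how many points each class has, since rare classes are harder to learn. The following is a minimal sketch with pandas that uses invented rows in place of the real `gdf`. The column names `Class` and `Ecosys_Typ` match those used elsewhere in this notebook, but the values are made up.

```python
import pandas as pd

# Invented example rows standing in for the training GeoDataFrame
points = pd.DataFrame(
    {
        "Class": [1, 1, 1, 2, 2, 3],
        "Ecosys_Typ": ["Mangrove", "Mangrove", "Mangrove", "Seagrass", "Seagrass", "Coral"],
    }
)

# Count the number of training points in each class
counts = points.groupby(["Class", "Ecosys_Typ"]).size()
print(counts)

# Share of all points that belong to the most common class
largest_share = counts.max() / counts.sum()
print(f"Largest class share: {largest_share:.0%}")
```

With the real data the equivalent call would be `gdf.groupby(["Class", "Ecosys_Typ"]).size()`.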

In [None]:
# Get the training data
data_url = "data/NusaTenggaraBarat_tdata.geojson"
gdf = gpd.read_file(data_url, bbox=tuple(area))

# Alternatively, use your updated data if you have it.
# gdf = gpd.read_file("data/training_revised.geojson", bbox=tuple(area))

gdf.explore(
    column="Ecosys_Typ",
    legend=True,
    tiles=basemaps.Esri.WorldImagery,
    style_kwds={"radius": 5},
)

## Find and load Sentinel-2 data

Here we search for Sentinel-2 scenes with up to 80% cloud cover over our
study area and use Dask to lazy-load them. We only load the red, green, blue,
near-infrared (`nir08`) and short-wave infrared (`swir16`, `swir22`) bands,
along with the scene classification (`scl`) band.

In [None]:
# Find and load Sentinel-2 datasets
sentinel2_datasets = dc.find_datasets(
    product=["s2_l2a"],
    time=(start_date, end_date),
    longitude=(area.left, area.right),
    latitude=(area.bottom, area.top),
    cloud_cover=(0, 80),
)

print(f"Found {len(sentinel2_datasets)} Sentinel-2 datasets")

ds = dc.load(
    datasets=sentinel2_datasets,
    longitude=(area.left, area.right),
    latitude=(area.bottom, area.top),
    output_crs="EPSG:6933",
    resolution=10,
    measurements=["red", "green", "blue", "nir08", "swir16", "swir22", "scl"],
    group_by="solar_day",
    dask_chunks={"time": 1, "x": 3200, "y": 3200},
    resampling={
        "*": "cubic",
        "scl": "nearest",
    },
    driver="rio",
)

# Mask Sentinel-2 data
# 3 is cloud shadow, 8 is medium probability cloud, 9 is high probability cloud
cloud_mask = ds.scl.isin([3, 8, 9])
cloud_mask = mask_cleanup(cloud_mask, (("dilation", 10), ("erosion", 5)))
valid_data = valid_data_mask(ds)
mask = cloud_mask | ~valid_data

data_masked = ds.where(~mask).drop_vars("scl")

# Scale Sentinel-2 data so that values are between 0 and 1
data = (data_masked * 0.0001).clip(0, 1)

# Create some indices
# MNDWI for water
data["mndwi"] = (data.green - data.swir16) / (data.green + data.swir16)

# NDVI for vegetation
data["ndvi"] = (data.nir08 - data.red) / (data.nir08 + data.red)

# Natural log of blue/green for shallow water definition
data["ln_bg"] = np.log(data.blue / data.green)

# Modified form of MVI (Baloloy et al. 2020) for distinguishing between mangroves and other vegetation types
data["mvi"] = (((data.nir08 - data.green) / (data.swir16 - data.green)) * 0.1).clip(
    -1, 1
)

data

## Data preparation

The cell above also cleans up the loaded data: it masks out clouds
and cloud shadow, and scales values to between 0 and 1, the valid
range for reflectance.
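
The masking and scaling logic can be sketched on a small NumPy array. The pixel values below are invented, and the sketch leaves out the `mask_cleanup` and `valid_data_mask` steps that the notebook applies.

```python
import numpy as np

# Toy 2x3 scene: raw red-band digital numbers and the matching SCL classes
red_raw = np.array([[500, 1200, 9000], [800, 11000, 300]], dtype="float32")
scl = np.array([[4, 5, 9], [3, 6, 6]])

# 3 is cloud shadow, 8 and 9 are medium and high probability cloud
cloud_mask = np.isin(scl, [3, 8, 9])

# Replace masked pixels with NaN, then scale and clip to the 0-1 range
red = np.where(cloud_mask, np.nan, red_raw)
red = np.clip(red * 0.0001, 0, 1)

print(red)
```

Masked pixels become NaN rather than 0, so later steps can tell "no observation" apart from a genuinely dark pixel.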

We also add four indices (MNDWI, NDVI, the log ratio of blue to green,
and a modified MVI), which help the machine learning algorithm.
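
To make two of the index formulas concrete, here is a small sketch with invented reflectance values for three pixels. The values are illustrative only and are not taken from the scene.

```python
import numpy as np

# Invented reflectance values for three pixels: vegetation, water, bare sand
green = np.array([0.08, 0.06, 0.30])
red = np.array([0.05, 0.04, 0.35])
nir08 = np.array([0.45, 0.02, 0.40])
swir16 = np.array([0.20, 0.01, 0.45])

# Normalised difference indices range from -1 to 1
ndvi = (nir08 - red) / (nir08 + red)
mndwi = (green - swir16) / (green + swir16)

print("NDVI: ", ndvi.round(2))
print("MNDWI:", mndwi.round(2))
```

In this toy case the vegetation pixel has the highest NDVI and the water pixel has the highest MNDWI, which is the kind of separation the indices are meant to give the classifier.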

Note that we still have a lazy-loaded array, and haven't transferred
any data over the network.

In [None]:
# Visualise one date, to make sure it looks good.
# This example shows empty areas where we've masked out clouds.

# This process of loading should take less than a minute
data.isel(time=0).odc.explore(vmin=0, vmax=0.3, tiles=basemaps.CartoDB.DarkMatter)

## Create a cloud-free composite

The final data preparation step involves creating a temporal
median of the data bands. Here we use `compute()` to process
the data and bring it into memory.

We preview the data in the second cell below.
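
The effect of a temporal median can be illustrated on a small NumPy stack. The values are invented, and `np.nanmedian` is used to mirror the way xarray's `median` skips NaN values by default for floating point data.

```python
import numpy as np

# Toy stack with dimensions (time, y, x); NaN marks pixels masked as cloud
stack = np.array(
    [
        [[0.10, np.nan], [0.30, 0.20]],
        [[0.12, 0.40], [np.nan, 0.22]],
        [[0.90, 0.42], [0.32, np.nan]],
    ]
)

# Median through time, ignoring NaN values
composite = np.nanmedian(stack, axis=0)
print(composite)
```

The bright value of 0.90 in the top-left pixel stands in for a cloud that the mask missed. The median ignores it, which is why this kind of composite removes most leftover cloud.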

In [None]:
# Create a median composite, which should get rid of most of the remaining clouds
# Note that this will take a few minutes to complete

median = data.median("time").compute()

median

In [None]:
median.odc.explore(vmin=0, vmax=0.3, tiles=basemaps.CartoDB.DarkMatter)

## Prepare training data array

This next step involves extracting observed values from the satellite data
and combining them with our point data, resulting in something like this:

`class, red, green, blue ...`

This structure is then fed into the machine learning classifier.
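
The extraction step amounts to a nearest-pixel lookup for each training point. The cell below does this with xarray's `sel(..., method="nearest")`. The following is a plain NumPy sketch of the same idea on an invented single-band raster with invented points.

```python
import numpy as np

# Toy raster: one band on a 3x4 grid, with pixel-centre coordinates
x = np.array([5.0, 15.0, 25.0, 35.0])
y = np.array([25.0, 15.0, 5.0])
red = np.arange(12).reshape(3, 4) / 100

# Invented training points: (x, y, class)
points = np.array([[14.0, 16.0, 1], [33.0, 6.0, 2]])

# Index of the nearest pixel centre for each point
col = np.abs(x[None, :] - points[:, [0]]).argmin(axis=1)
row = np.abs(y[None, :] - points[:, [1]]).argmin(axis=1)

# One row per point: class, red
table = np.column_stack([points[:, 2], red[row, col]])
print(table)
```

The point coordinates must be in the same CRS as the raster for this lookup to be meaningful, which is why the notebook reprojects the training points first.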

In [None]:
# First transform the training points to the same CRS as the data
training = gdf.to_crs(median.odc.geobox.crs)

# Next get the X and Y values out of the point geometries
training_da = training.assign(x=training.geometry.x, y=training.geometry.y).to_xarray()

# Now we can use the x and y values (in the CRS of the composite) to extract values from the median composite
training_values = (
    median.sel(training_da[["x", "y"]], method="nearest")
    .squeeze()
    .compute()
    .to_pandas()
)

# Join the training data with the extracted values and remove unnecessary columns
training_array = pd.concat([training["Class"], training_values], axis=1)
training_array = training_array.drop(
    columns=[
        "y",
        "x",
        "spatial_ref",
    ]
)

# Drop rows where there was no data available
training_array = training_array.dropna()

# Preview our resulting training array
training_array.head()

## Create a classifier and fit a model

We pass two simple NumPy arrays to the classifier: one holds the
observations (the values of red, green, blue and so on),
and the other holds the classes.
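
The split into classes and observations, followed by a fit, can be tried on a tiny invented training array. This is a sketch only: the values are made up, there are just two features, and `random_state=0` is an addition to make the result repeatable (the notebook does not set it).

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Invented training array: class in the first column, then two band values
training_array = np.array(
    [
        [1, 0.05, 0.45],
        [1, 0.06, 0.40],
        [1, 0.04, 0.50],
        [2, 0.30, 0.02],
        [2, 0.28, 0.03],
        [2, 0.32, 0.01],
    ]
)

classes = training_array[:, 0]
observations = training_array[:, 1:]

# random_state makes the result repeatable; it is not set in the notebook
classifier = RandomForestClassifier(class_weight="balanced", random_state=0)
model = classifier.fit(observations, classes)

# Predict two new pixels, one resembling each class
new_pixels = np.array([[0.05, 0.44], [0.31, 0.02]])
predicted = model.predict(new_pixels)
print(predicted)
```

Note that the columns passed to `predict` must be in the same order as the columns used for `fit`.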

In [None]:
# The classes are the first column
classes = np.array(training_array)[:, 0]

# The observation data is everything after the first column
observations = np.array(training_array)[:, 1:]

# Create a model...
classifier = RandomForestClassifier(class_weight="balanced")

# ...and fit it to the data
model = classifier.fit(observations, classes)

## Prediction

Next we predict. Again, we need a simple NumPy array, this time
with only the observations. It needs to be a long 2D array in
which each row is one cell of the original raster and each column
is one observation value (a band or an index).
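
The reshaping can be sketched with NumPy alone. The composite below is invented, and `toy_predict` is a hypothetical stand-in for `model.predict`, so that the example runs without a trained model.

```python
import numpy as np

# Toy composite with dimensions (band, y, x): 2 bands on a 2x3 grid
bands, ny, nx = 2, 2, 3
cube = np.arange(bands * ny * nx, dtype=float).reshape(bands, ny, nx)

# Flatten to one row per pixel and one column per band
stacked = cube.reshape(bands, ny * nx).T
print(stacked.shape)


def toy_predict(observations):
    """Stand-in for model.predict: class 1 where band 0 exceeds 2, else 0."""
    return (observations[:, 0] > 2).astype(int)


# Predict per pixel, then fold the flat result back into the (y, x) grid
predicted = toy_predict(stacked)
class_map = predicted.reshape(ny, nx)
print(class_map)
```

Because the pixels are flattened row by row, reshaping the flat predictions to `(ny, nx)` puts each class back at its original location.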

In [None]:
# Convert to a stacked array of observations
stacked_arrays = median.to_array().stack(dims=["y", "x"]).transpose()

# Replace any NaN values with 0
stacked_arrays = stacked_arrays.fillna(0)

# Predict the classes
predicted = model.predict(stacked_arrays)

# Reshape back to the original 2D array
array = predicted.reshape(len(median.y), len(median.x))

# Convert to an xarray again, because it's easier to work with
predicted_da = xr.DataArray(array, coords={"y": median.y, "x": median.x}, dims=["y", "x"])

## Visualise our results

Here we're visualising the results along with the RGB image
and the original training data points. We're doing this using
a Python library called Folium, which wraps up the Leaflet
JavaScript library.

In [None]:
# Put it all on a single interactive map
m = folium.Map(
    location=tuple(aoi_point.coords[0][::-1]),
    zoom_start=11,
    tiles=basemaps.Esri.WorldImagery,
)

# RGB for the median
median.odc.to_rgba(vmin=0, vmax=0.3).odc.add_to(m, name="Median Composite")

# Categorical for the predicted classes and for the training data
# Note: no categorical colormap is applied to the predicted layer, so its colours are arbitrary
predicted_da.odc.add_to(m, name="Predicted", opacity=0.7)

gdf.explore(
    m=m,
    column="Ecosys_Typ",
    name="Training Data",
    style_kwds={
        "radius": 5,
        "stroke": True,
        "opacity": 0.7,
        "color": "white",
        "weight": 1,
    },
    categorical=True,
)

# Layer control
folium.LayerControl().add_to(m)

m

In [None]:
# Export the RGB median to use as a base for updating training data
# median[["red", "green", "blue"]].odc.to_rgba(vmin=0, vmax=0.3).odc.write_cog(
#     data_url.replace("_tdata.geojson", "_median.tif"), overwrite=True
# )

# Export the training data, so that it can be updated with more points
# gdf[["Class", "Ecosys_Typ", "geometry"]].to_file(
#     data_url.replace("_tdata.geojson", "_revised.geojson"), driver="GeoJSON"
# )

## Conclusions

Do the results make sense?

What are some of the limitations of the visualisation?

### Next steps and opportunities

The obvious next step is to fine-tune the training data. For example, download
the points for this region of interest along with the RGB image, then add and
remove points until the training dataset is more representative.