# Machine Learning using Sentinel-2 Data

This example uses training data from the
[Coast Train](https://github.com/nick-murray/coastTrain) dataset
along with Sentinel-2 data to demonstrate how to use a
machine learning classifier, in this case, Random Forest, to
assign a class to each pixel.

This notebook combines lessons from previous notebooks into
a comprehensive worked example.

## Getting started

First we load the required Python libraries and tools.

In [None]:
import os

import folium
import geopandas as gpd
import numpy as np
import pandas as pd
import xarray as xr
from datacube import Datacube
from ipyleaflet import basemaps
from odc.geo.geom import point
from odc.stac import configure_s3_access
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

## Study site configuration

Here we set the year and the spatial extent of the study site. This can be
anywhere; the island of Sepanjang, Indonesia, is used below as an example.

In [None]:
# Choose a year
year = "2025"

# Location is the island of Sepanjang, Indonesia
aoi_coords = -7.115, 115.83

aoi_point = point(aoi_coords[1], aoi_coords[0], crs="EPSG:4326")
aoi = aoi_point.buffer(0.1).boundingbox

aoi.explore(tiles=basemaps.Esri.WorldImagery)

## Configure our environment

This cell configures AWS for "unsigned" (public) data access and connects to
the Datacube. Dask, which we use for parallel computing, comes into play later when we load the data.


In [None]:
os.environ["AWS_DEFAULT_REGION"] = "us-west-2"

# Configure S3 access for datacube
configure_s3_access(no_sign_request=True)

# Set up the Datacube
dc = Datacube(app="Sentinel-2_MachineLearning")

## Add a function to prepare data

In [None]:
def prepare_data(ds):
    # Scale Sentinel-2 data to reflectance
    data = (ds * 0.0001).clip(0, 1)

    # Add some indices to improve classification results
    # MNDWI for water
    data["mndwi"] = (data.green - data.swir16) / (data.green + data.swir16)

    # NDVI for vegetation
    data["ndvi"] = (data.nir08 - data.red) / (data.nir08 + data.red)

    # NDTI for additional distinction of shallow benthic types
    data["ndti"] = (data.red - data.green) / (data.red + data.green)

    # Natural log of blue/green for shallow water definition
    data["ln_bg"] = np.log(data.blue / data.green)

    # Modified form of MVI (Baloloy et al. 2020) for distinguishing between mangroves and other vegetation types
    data["mvi"] = (((data.nir08 - data.green) / (data.swir16 - data.green)) * 0.1).clip(
        -1, 1
    )

    # Remove the time dimension
    return data.squeeze()
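
Most of the indices above are normalised differences of the form `(a - b) / (a + b)`. Because reflectance is clipped to the range 0 to 1, a pixel where both bands are 0 has a zero denominator. The sketch below uses plain NumPy and synthetic values (not real Sentinel-2 data) to show one way of returning NaN in that case; the helper name `normalised_difference` is an assumption for illustration and is not part of this notebook.

```python
import numpy as np


def normalised_difference(a, b):
    """Return (a - b) / (a + b), with NaN where the denominator is zero."""
    a = np.asarray(a, dtype="float64")
    b = np.asarray(b, dtype="float64")
    denominator = a + b
    # Start from NaN, then divide only where the denominator is non-zero
    result = np.full(denominator.shape, np.nan)
    np.divide(a - b, denominator, out=result, where=denominator != 0)
    return result


# Synthetic reflectance values (not real Sentinel-2 data)
green = np.array([0.10, 0.05, 0.00])
swir16 = np.array([0.02, 0.20, 0.00])

mndwi = normalised_difference(green, swir16)
print(mndwi)
```

The last pixel comes out as NaN rather than raising a warning or producing an infinite value, so it can be dropped or masked before training.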

## Training data

Next up we gather training data. This could be any geospatial point dataset
with a numeric column that holds the class.

If you'd like to explore the structure of this data, you can run `gdf.head()`
to see the first few rows. The `explore()` function with the `column` argument
will show the data on the map, and change the colour based on that column.

In [None]:
# Get the training data
data_url = "./NusaTenggaraBarat_tdata.geojson"
gdf = gpd.read_file(data_url, bbox=tuple(aoi))

# Alternatively, use your updated data if you have it.
# gdf = gpd.read_file("data/training_revised.geojson", bbox=tuple(aoi))
gdf.explore(
    column="Ecosys_Typ",
    legend=True,
    tiles=basemaps.Esri.WorldImagery,
    style_kwds={"radius": 5},
)

## Find and load Sentinel-2 data

Here we search for Sentinel-2 scenes over our study area and use
Dask to lazy-load them. We load only the blue, green, red, near-infrared
(`nir08`) and shortwave-infrared (`swir16`, `swir22`) bands.

In [None]:
# Load data from the Sentinel-2 GeoMAD
ds = dc.load(
    product="s2_l2a",
    longitude=(aoi.left, aoi.right),
    latitude=(aoi.bottom, aoi.top),
    time=(f"{year}-01-01/{year}-12-31"),
    # output_crs="EPSG:6933",
    # resolution=10,
    measurements=["blue", "green", "red", "nir08", "swir16", "swir22"],
    group_by="solar_day",
    dask_chunks={"time": 1, "x": 3200, "y": 3200},
    resampling={
        "*": "cubic",
        "scl": "nearest",
    },
    driver="rio",
)

# Scale the loaded dataset and add the indices
data = prepare_data(ds)
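
Later cells refer to a median composite. Whether the loaded product is already a composite depends on the catalogue, which this notebook does not confirm. As a sketch of what a per-pixel temporal median does, assuming masked (for example cloudy) observations are stored as NaN, here is a NumPy example on a synthetic stack:

```python
import warnings

import numpy as np

# Synthetic stack with shape (time, y, x); NaN marks masked observations
stack = np.array(
    [
        [[0.10, 0.20], [0.30, np.nan]],
        [[0.12, np.nan], [0.90, np.nan]],
        [[0.11, 0.22], [0.31, np.nan]],
    ]
)

with warnings.catch_warnings():
    # Pixels with no valid observation trigger an "All-NaN slice" warning
    warnings.simplefilter("ignore", RuntimeWarning)
    # Per-pixel median over the time axis, ignoring NaN
    composite = np.nanmedian(stack, axis=0)

print(composite)
```

The bright outlier (0.90) in the bottom-left pixel does not affect the result, which is why medians are a common choice for compositing. The pixel with no valid observations stays NaN.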

## Prepare training data array

This next step involves extracting observed values from the satellite data
and combining them with our cover class data, resulting in something like this:

`class, red, green, blue ...`

This structure establishes a set of classes and the spectral information associated with them. This is the groundwork for our machine learning classifier, which will identify (hidden) statistical relationships between spectral variables and cover classes.

In [None]:
# First transform the training points to the same CRS as the data
training = gdf.to_crs(data.odc.geobox.crs)

# Next get the X and Y values out of the point geometries
training_da = training.assign(x=training.geometry.x, y=training.geometry.y).to_xarray()

# Now we can use the x and y values (in the CRS of the data) to extract values from the median composite
training_values = (
    data.sel(training_da[["x", "y"]], method="nearest")
    .squeeze()
    .compute()
    .to_pandas()
)

# Join the training data with the extracted values and remove unnecessary columns
training_array = pd.concat([training["Class"], training_values], axis=1)
training_array = training_array.drop(
    columns=[
        "y",
        "x",
        "spatial_ref",
    ]
)

# Drop rows where there was no data available
training_array = training_array.dropna()

# Preview our resulting training array
training_array.head()
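
The extraction step above relies on nearest-neighbour selection: each training point takes the values of the pixel whose centre is closest to it. The sketch below reproduces that idea with plain NumPy on a synthetic 4 x 4 grid. All coordinates and values are made up for illustration.

```python
import numpy as np

# Synthetic grid: pixel-centre coordinates and one band of values
x_coords = np.array([0.0, 10.0, 20.0, 30.0])
y_coords = np.array([30.0, 20.0, 10.0, 0.0])  # y usually decreases down the rows
band = np.arange(16, dtype="float64").reshape(4, 4)  # shape (y, x)

# Hypothetical training points in the same CRS as the grid
points_x = np.array([1.0, 19.0, 28.0])
points_y = np.array([29.0, 11.0, 4.0])

# Index of the nearest pixel centre along each axis, for every point
col = np.abs(x_coords[None, :] - points_x[:, None]).argmin(axis=1)
row = np.abs(y_coords[None, :] - points_y[:, None]).argmin(axis=1)

# One sampled value per training point
sampled = band[row, col]
print(sampled)
```

Repeating this for every band and index gives one row of features per training point, which is the table structure described above.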

## Create a classifier and fit a model

We pass two NumPy arrays to the classifier: one holds the observations (reflectance values, vegetation indices, and so on), and the other holds the cover classes.

First, we need to split our point data into training and test sets. The training set is used to train the machine learning model, while the test set is used to measure the accuracy of the classification. This is very important, as it helps with refining the model and communicating reliability of the outputs.

We can also produce an initial report of classification accuracy. This uses only the data in our training and testing arrays. The classifier, developed using the training subset, is applied to the testing subset. This provides us class-by-class measures of accuracy, showing which classes are more or less accurately classified. 

Note that interpreting these accuracy figures is not as simple as it seems, and it is easy to get artificially high values, especially with low sample numbers or spatially clustered samples.
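
Spatially clustered samples inflate accuracy because neighbouring points often end up on both sides of a random split. One common mitigation is to hold out whole spatial blocks rather than individual points. The sketch below is a minimal NumPy version on synthetic coordinates. The block size and the 30% hold-out share are assumed tuning choices, not values taken from this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical point coordinates in metres (synthetic, not the Coast Train points)
x = rng.uniform(0, 10_000, size=200)
y = rng.uniform(0, 10_000, size=200)

# Assign each point to a square block; block_size is an assumed tuning choice
block_size = 2_500
n_blocks_x = int(np.ceil(10_000 / block_size))
block_id = (y // block_size).astype(int) * n_blocks_x + (x // block_size).astype(int)

# Hold out roughly 30% of the blocks, so test points are spatially separate
blocks = np.unique(block_id)
test_blocks = rng.choice(blocks, size=max(1, int(0.3 * len(blocks))), replace=False)
is_test = np.isin(block_id, test_blocks)

train_idx = np.flatnonzero(~is_test)
test_idx = np.flatnonzero(is_test)
print(len(train_idx), len(test_idx))
```

This split does not stratify by class, so rare classes can end up entirely in one set. It is worth checking the class counts on each side before relying on the result.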

In [None]:
# The classes are the first column
classes = np.array(training_array)[:, 0]
# The observation data is everything after the first column
observations = np.array(training_array)[:, 1:]

# Split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(
    observations, classes, test_size=0.3, stratify=classes
)

# Report the size of each set
print(X_train.shape, X_test.shape)

# Create a model...
classifier = RandomForestClassifier(n_estimators=100, class_weight="balanced")
# Consider other hyperparameters - max_features, max_depth, min_samples_split, min_samples_leaf, bootstrap
# Consider also k-fold cross-validation at this stage

# ...and fit it to the training data
model = classifier.fit(X_train, y_train)

y_pred = model.predict(X_test)

class_names = gdf.groupby("Class")["Ecosys_Typ"].first().sort_index().values
print(
    f"\nClassification report:\n{classification_report(y_test, y_pred, target_names=class_names)}"
)
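
The comments in the cell above suggest k-fold cross-validation. As a minimal sketch of what that could look like with scikit-learn, the example below uses synthetic data from `make_classification` as a stand-in for `observations` and `classes`. The number of folds and trees are assumptions for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import StratifiedKFold, cross_val_predict, cross_val_score

# Synthetic stand-in for `observations` and `classes`
X, y = make_classification(
    n_samples=300,
    n_features=8,
    n_informative=5,
    n_classes=3,
    random_state=0,
)

clf = RandomForestClassifier(n_estimators=50, class_weight="balanced", random_state=0)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

# One accuracy score per fold
scores = cross_val_score(clf, X, y, cv=cv, scoring="accuracy")
print(f"Accuracy per fold: {np.round(scores, 3)}")
print(f"Mean accuracy: {scores.mean():.3f} (std {scores.std():.3f})")

# Out-of-fold predictions give a confusion matrix that covers every sample
y_oof = cross_val_predict(clf, X, y, cv=cv)
cm = confusion_matrix(y, y_oof)
print(cm)
```

The spread of the per-fold scores gives a rough sense of how sensitive the accuracy estimate is to the particular split, which a single train/test split cannot show.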

## Prediction

Next we apply our machine learning model to the entire image to predict the class of each pixel. Again, we need a simple NumPy array, this time
with just the observations. It must be a long 2D array with one row
per cell of the original raster and one column per observation
value (band or index).

In [None]:
# Convert to a stacked array of observations
stacked_arrays = data.to_array().stack(dims=["y", "x"]).transpose()

# Replace any NaN values with 0
stacked_arrays = stacked_arrays.fillna(0)

# Predict the classes using our trained model
predicted = model.predict(stacked_arrays)

# Reshape back to the original 2D array
array = predicted.reshape(len(data.y), len(data.x))

# Convert to an xarray again, because it's easier to work with
predicted_da = xr.DataArray(array, coords={"y": data.y, "x": data.x}, dims=["y", "x"])
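
The stack, predict, reshape pattern above can be checked on a tiny synthetic image. The sketch below uses plain NumPy, with a stub function standing in for `model.predict`. It also shows an alternative to filling NaN with 0: predicting only for pixels with complete data and labelling the rest with a no-data value (here -1, an assumed choice).

```python
import numpy as np

n_bands, n_rows, n_cols = 3, 2, 4

# Synthetic image with shape (band, y, x)
image = np.arange(n_bands * n_rows * n_cols, dtype="float64").reshape(
    n_bands, n_rows, n_cols
)
image[0, 0, 0] = np.nan  # one pixel with missing data

# Flatten to (pixel, band): one row per pixel, one column per feature
pixels = image.reshape(n_bands, -1).T
print(pixels.shape)


def predict_stub(features):
    """Stand-in for model.predict: labels each pixel by its brightest band."""
    return features.argmax(axis=1)


# Predict only where every feature is valid; -1 marks no-data pixels
valid = ~np.isnan(pixels).any(axis=1)
labels = np.full(pixels.shape[0], -1)
labels[valid] = predict_stub(pixels[valid])

# Reshape the flat predictions back to the (y, x) grid
label_map = labels.reshape(n_rows, n_cols)
print(label_map)
```

Row `r * n_cols + c` of `pixels` holds the band values of the pixel at row `r`, column `c`, which is why a plain reshape restores the original layout.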

## Visualise our results

Here we're visualising the results along with the RGB image
and the original training data points. We're doing this using
a Python library called Folium, which wraps up the Leaflet
JavaScript library.

In [None]:
# Put it all on a single interactive map
m = folium.Map(
    location=tuple(aoi_point.coords[0][::-1]),
    zoom_start=11,
    tiles=basemaps.Esri.WorldImagery,
)

# RGB for the median
data.odc.to_rgba(bands=["red", "green", "blue"], vmin=0, vmax=0.3).odc.add_to(
    m, name="Median Composite"
)

# Categorical for the predicted classes and for the training data
# Note: no colormap is applied here, so the class colours are arbitrary
predicted_da.astype(int).odc.add_to(m, name="Predicted", opacity=0.7)

gdf.explore(
    m=m,
    column="Ecosys_Typ",
    name="Training Data",
    style_kwds={
        "radius": 5,
        "stroke": "white",
        "opacity": 1,
        "color": "white",
        "stroke-weight": 0.8,
    },
    categorical=True,
)

# Layer control
folium.LayerControl().add_to(m)

m
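
The predicted layer above is drawn without a class colour map. One possible workaround is to convert the class map to an RGBA image yourself using a lookup table. The sketch below shows only the NumPy part. The class ids and colours are hypothetical, and how to pass the result to `odc.add_to` is not covered here.

```python
import numpy as np

# Hypothetical class ids and colours (RGBA, 0-255); not from the Coast Train legend
class_colours = {
    0: (0, 0, 255, 255),
    1: (0, 128, 0, 255),
    2: (255, 255, 0, 255),
}

# Lookup table indexed by class id
lut = np.zeros((max(class_colours) + 1, 4), dtype="uint8")
for class_id, rgba in class_colours.items():
    lut[class_id] = rgba

# Synthetic class map with shape (y, x)
class_map = np.array([[0, 1], [2, 1]])

# Indexing the table with the class map gives a (y, x, 4) RGBA image
rgba_image = lut[class_map]
print(rgba_image.shape)
```

Using the same lookup table for every map keeps the symbology consistent between sites.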

## Test importance of model variables

Random Forests can provide some useful explanatory information. Each tree in the forest makes a series of splits on band/index values. Feature importance, as reported by scikit-learn's `feature_importances_`, measures how much each band/index reduces impurity across those splits, averaged over all trees and normalised to sum to 1.

Feature importance can be useful in a number of ways. If efficiency is important, a model can sometimes be simplified by removing low-importance features. Feature importance can also help to refine flawed models, such as those that often confuse classes.

Feature importance can also inform sensor design for related missions. For example, if we were planning to expand a project using UAV data, we might select the UAV sensor partly based on band importance.

It is also worth noting that feature importance depends partly upon the classes themselves. For instance, in this model, MNDWI has high importance. How would this change if we masked out land areas, and only classified benthic cover?

In [None]:
# Get the feature names (band and index names from the prepared data)
# Note that the order must match the columns used to train the model
feature_names = list(data.data_vars)

importance_df = pd.DataFrame(
    {"Feature": feature_names, "Importance": model.feature_importances_}
).sort_values("Importance", ascending=False)

print(importance_df)
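
The `feature_importances_` attribute is impurity-based and computed from the training data, so it can favour features with many distinct values. Permutation importance on the held-out test set is a common cross-check. The sketch below uses scikit-learn's `permutation_importance` on synthetic data; the feature names and model settings are assumptions for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the training array
X_demo, y_demo = make_classification(
    n_samples=300, n_features=6, n_informative=3, n_redundant=0, random_state=0
)
demo_feature_names = [f"feature_{i}" for i in range(X_demo.shape[1])]  # hypothetical

X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.3, stratify=y_demo, random_state=0
)

demo_model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X_tr, y_tr)

# Shuffle each feature in turn and measure the drop in test accuracy
result = permutation_importance(demo_model, X_te, y_te, n_repeats=5, random_state=0)

for i in result.importances_mean.argsort()[::-1]:
    print(
        f"{demo_feature_names[i]}: "
        f"{result.importances_mean[i]:.3f} +/- {result.importances_std[i]:.3f}"
    )
```

Features whose permutation importance is close to zero contribute little to test accuracy and are candidates for removal.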

## Apply the model to a different area

We will now apply the same machine learning model to a different location. This can be done without retraining the model. The cells below define the new area, then load Sentinel-2 data for it and apply the same `prepare_data` step we used for our original site.

In [None]:
# We can transfer the classification to the similar nearby island of Raas, using the same model
raas_coords = -7.135, 114.58

new_point = point(raas_coords[1], raas_coords[0], crs="EPSG:4326")
new_aoi = new_point.buffer(0.11).boundingbox
new_aoi.explore(tiles=basemaps.Esri.WorldImagery)

In [None]:
# Load data from the Sentinel-2 GeoMAD
new_ds = dc.load(
    product="s2_l2a",
    longitude=(new_aoi.left, new_aoi.right),
    latitude=(new_aoi.bottom, new_aoi.top),
    time=(f"{year}-01-01/{year}-12-31"),
    # output_crs="EPSG:6933",
    # resolution=10,
    measurements=["blue", "green", "red", "nir08", "swir16", "swir22"],
    group_by="solar_day",
    dask_chunks={"time": 1, "x": 3200, "y": 3200},
    resampling={
        "*": "cubic",
        "scl": "nearest",
    },
    driver="rio",
)

new_data = prepare_data(new_ds)
new_data

The new area is the nearby island of Raas, which has similar land and benthic cover.

The next cell applies our machine learning model, developed using data from Sepanjang, to Raas. This can be carried out without changing the model at all, and we can expect it to work fairly well. Spatially transferring a Random Forest model like this tends to work well only if the areas have similar land cover types and the images come from the same sensor and are processed identically.

Some very advanced deep learning models can operate across domains - for instance, being generalisable across different satellite sensors.

In [None]:
# Convert to a stacked array of observations
stacked_arrays_new = new_data.to_array().stack(dims=["y", "x"]).transpose()

# Replace any NaN values with 0
stacked_arrays_new = stacked_arrays_new.fillna(0)

# Predict the classes
predicted_new = model.predict(stacked_arrays_new)

# Reshape back to the original 2D array
array_new = predicted_new.reshape(len(new_data.y), len(new_data.x))

# Convert to an xarray again, because it's easier to work with
predicted_da_new = xr.DataArray(
    array_new, coords={"y": new_data.y, "x": new_data.x}, dims=["y", "x"]
)
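
A quick, rough sanity check on a transferred classification is to compare the share of pixels assigned to each class at the two sites. Large differences do not prove the model has failed, but they can point to classes worth inspecting on the map. The sketch below uses synthetic class maps with hypothetical class ids.

```python
import numpy as np

# Synthetic predicted class maps for two sites (hypothetical class ids 1-3)
site_a = np.array([[1, 1, 2], [3, 3, 3]])
site_b = np.array([[1, 2, 2], [2, 2, 3]])


def class_proportions(class_map):
    """Return {class_id: fraction of pixels} for a 2D class map."""
    ids, counts = np.unique(class_map, return_counts=True)
    return {int(i): float(c) / class_map.size for i, c in zip(ids, counts)}


props_a = class_proportions(site_a)
props_b = class_proportions(site_b)
print(props_a)
print(props_b)
```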

## Visualise the newly classified area

Now we can visualise the classification outputs for Raas. The classification symbology will be the same as the previous map. We can see that, even without modifying the model at all, it works quite well.

In [None]:
m = folium.Map(
    location=tuple(new_point.coords[0][::-1]),
    zoom_start=11,
    tiles=basemaps.Esri.WorldImagery,
)

# RGB for the median
new_data.odc.to_rgba(bands=["red", "green", "blue"], vmin=0, vmax=0.3).odc.add_to(
    m, name="Median Composite"
)

# Categorical for the predicted classes (cast to int, as on the first map)
predicted_da_new.astype(int).odc.add_to(m, name="Predicted_raas", opacity=0.7)

folium.LayerControl().add_to(m)

m


In [None]:
# # Export the RGB median to use as a base for updating training data
# median.to_array(dim='band').odc.write_cog(
#     data_url.replace("_tdata.geojson", "_median.tif"),
#     overwrite=True
# )
# # Export the training data, to update with more data
# gdf[["Class", "Ecosys_Typ", "geometry"]].to_file(
#     data_url.replace("_tdata.geojson", "_revised.geojson"), driver="GeoJSON"
# )

## Conclusions

Do the results make sense?

Are there any issues with the training and testing data? If so, how could they be fixed?

There are limitations with the visualisation used here, so consider exporting the outputs to view them in QGIS or similar.

### Fine tuning

The obvious next step is to fine-tune the model. There are a few ways you can do this:
- Adding more input features to the model (e.g. vegetation indices)
- Modifying model training data
- Training the model on data from both sites