# Feature Extraction
* **Products used:** 
[dem_cop_30](https://explorer.digitalearth.africa/products/dem_cop_30), [s2_l2a](https://explorer.digitalearth.africa/products/s2_l2a), [dem_srtm](https://explorer.digitalearth.africa/products/dem_srtm), [dem_srtm_deriv](https://explorer.digitalearth.africa/products/dem_srtm_deriv)

## Background:

Training data extraction plays a crucial role in training machine learning models. The process involves extracting relevant feature layers from a geospatial dataset based on predefined geometries or regions of interest. This enables the creation of accurate and reliable classification models for various applications such as land cover mapping, crop monitoring, and environmental analysis.

To facilitate this task, the `deafrica_tools.classification` module provides a function called `collect_training_data`. It extracts training data from the Open Data Cube using geometries defined in a GeoJSON file. The GeoJSON file contains the points or polygons that delineate the regions of interest for which training data needs to be extracted.

## Description:

This notebook focuses on the extraction of training data (feature layers) from the open-data-cube using geometries defined within a GeoJSON file. It follows a step-by-step approach to guide users in utilizing the "collect_training_data" function effectively. The goal is to enable users to extract the appropriate training data for their specific use case.

The main steps in this notebook are as follows:

1. **Previewing the Training Data:** The notebook starts by plotting the training geometries (points or polygons) on a basemap. This gives users a visual check of the regions of interest for which training data will be extracted.

2. **Defining the Feature Layer Function:** Next, a feature layer function is defined. This function specifies the set of feature layers to be extracted from the open-data-cube. These layers are carefully selected based on their relevance to the classification task at hand.

3. **Extracting Training Data:** The "collect_training_data" function is then employed to extract the training data from the datacube. It utilizes the predefined geometries from the GeoJSON file and retrieves the corresponding feature layers. This step ensures that the extracted data aligns precisely with the defined regions of interest.

4. **Exporting Training Data:** Finally, the extracted training data is exported and saved to disk. This facilitates its subsequent use in other scripts or machine learning workflows for training classification models.

By following the steps outlined in this notebook, users can leverage the "collect_training_data" function to efficiently extract training data from the open-data-cube. 

## Getting started
To run this analysis, run all the cells in the notebook, starting with the "Load packages" cell.

### Load packages

In [1]:
%matplotlib inline
import io
import os
import math
import datacube
import warnings
from pathlib import Path
import rioxarray
import rasterio
import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt

from datacube.testutils.io import rio_slurp_xarray


from deafrica_tools.datahandling import load_ard
from deafrica_tools.plotting import map_shapefile
from deafrica_tools.bandindices import calculate_indices
from deafrica_tools.classification import collect_training_data


## Analysis parameters
 * `path`: The path to the input vector file from which training data will be extracted. A default GeoJSON file is provided.
 * `field`: The name of the column in the vector file's attribute table that contains the class labels. The class labels must be integers.
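
Because the class labels must be integers, it can help to validate the label column before extracting features. The following is a minimal sketch on a toy table; `check_integer_labels` is a hypothetical helper written for illustration and is not part of `deafrica_tools`.

```python
import pandas as pd


def check_integer_labels(df: pd.DataFrame, field: str) -> pd.DataFrame:
    """Return a copy of df whose label column is integer-typed, or raise."""
    if field not in df.columns:
        raise KeyError(f"column {field!r} not found")
    labels = df[field]
    if pd.api.types.is_integer_dtype(labels):
        return df.copy()
    # Accept values such as 3.0 that hold whole numbers; reject anything else
    numeric = pd.to_numeric(labels, errors="raise")
    if not (numeric == numeric.round()).all():
        raise ValueError(f"column {field!r} contains non-integer labels")
    out = df.copy()
    out[field] = numeric.astype("int64")
    return out


# Toy data (hypothetical): labels stored as floats, as can happen after a file round trip
toy = pd.DataFrame({"class_id": [0.0, 3.0, 5.0], "class_name": ["a", "b", "c"]})
checked = check_integer_labels(toy, "class_id")
print(checked["class_id"].dtype)
```

In this notebook the equivalent call would be applied to `training_points` with `field` right after the file is read.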

In [2]:
# Specify a prefix to identify the area of interest in the saved outputs
# By assigning the desired prefix, you can easily identify the outputs associated with the specific area of interest.
SKYCOMISH_HUC8_ID = "17110009"
prefix = SKYCOMISH_HUC8_ID
field = "class_id"
path = f"/data/{prefix}_training_samples.geojson"

# Load the input vector file (GeoJSON)
training_points = gpd.read_file(path)
training_points.head()

Unnamed: 0,class_id,class_name,geometry
0,0,Non-wetland,POINT (-121.37522 47.86603)
1,0,Non-wetland,POINT (-121.21535 47.74082)
2,3,Freshwater Emergent Wetland,POINT (-121.26909 47.64271)
3,0,Non-wetland,POINT (-121.17138 47.65332)
4,0,Non-wetland,POINT (-121.13877 47.73572)


In [3]:
# Set a flag to convert to polygons:
use_polygons = False

if use_polygons:
    # Convert from lat,lon to EPSG:6933 (projection in metres)
    training_points = training_points.to_crs("EPSG:6933")

    # Buffer geometry to get a square - only if trying to sample multiple pixels
    buffer_radius_m = 10
    training_points.geometry = training_points.geometry.buffer(
        buffer_radius_m, cap_style=3
    )

### Plot on an interactive map

In [5]:
points = training_points
training_points.explore(
    tiles="https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}",
    attr="Imagery @2022 Landsat/Copernicus, Map data @2022 Google",
    popup=True,
    cmap="viridis",
    style_kwds=dict(radius=5, color="red", fillOpacity=0.8, fillColor="red", weight=3),
)

In [6]:
# dem_file = "../../../data/processed/17110009_Skykomish_HE_DEM_3m.tif"
dem_file = Path("/data/terrain_attributes/dem.tif")
CRS = rioxarray.open_rasterio(dem_file, chunks=True).squeeze(drop=True).odc.crs

In [7]:
from pathlib import Path
import rioxarray as rx


def load_training_data():
    return xr.merge(
        [
            rx.open_rasterio(f, chunks=True).squeeze(drop=True).to_dataset(name=f.stem)
            for f in Path("/data/terrain_attributes").glob("*.tif")
        ]
    )


training_xr = load_training_data().rio.write_crs(CRS)
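
`load_training_data` names each variable after the file stem, which is why feature columns such as `Slope_90.0m` keep the dot in their name: `Path.stem` only removes the final suffix. A small sketch using the standard library follows; apart from `dem.tif`, the file names are hypothetical and modelled on the column names in the output table.

```python
from pathlib import PurePosixPath

# Hypothetical raster paths; only dem.tif is taken from this notebook
files = [
    "/data/terrain_attributes/dem.tif",
    "/data/terrain_attributes/Slope_90.0m.tif",
    "/data/terrain_attributes/TWI.tif",
]

# Map each variable name (the file stem) to the file it would be read from
layer_names = {PurePosixPath(f).stem: f for f in files}

for name, source in layer_names.items():
    print(f"{name} <- {source}")
```

Two files with the same stem in different folders would map to the same variable name, so keeping all rasters in one directory, as done here, avoids collisions.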

In [8]:
import pandas as pd
import xarray as xr


def sample_xr(xr_ds: xr.Dataset, points: gpd.GeoDataFrame):
    points_proj = points.to_crs(xr_ds.odc.crs)
    pts_da = points_proj.assign(
        x=points_proj.geometry.x, y=points_proj.geometry.y
    ).to_xarray()

    # a dataframe or series (for a single point)
    pt_values_i = (
        xr_ds.sel(pts_da[["x", "y"]], method="nearest").squeeze().compute().to_pandas()
    )

    if isinstance(pt_values_i, pd.Series):
        pt_values_i = pt_values_i.to_frame().transpose()
        pt_values_i.index = points.index

    return pd.concat([points_proj, pt_values_i], axis=1)


pd_training_features = sample_xr(training_xr, training_points)
pd_training_features

Unnamed: 0,class_id,class_name,geometry,x,y,Planform_curvature_90.0m,Curvature_450.0m,Curvature_270.0m,TPI_450.0m,Slope_90.0m,...,TPI_90.0m,TPI_270.0m,TWI,Curvature_90.0m,dem,Elevation,Profile_curvature_450.0m,mrvbf,Slope_270.0m,spatial_ref
0,0,Non-wetland,POINT (-11711043.627 5433848.061),-11711045.0,5433845.0,-9.034044,0.236544,0.180498,2.600647,82.447052,...,2.595947,1.074829,-0.602729,-0.615586,976.126648,976.126648,-3.421940,0,80.998924,0
1,0,Non-wetland,POINT (-11695618.383 5423039.778),-11695615.0,5423035.0,0.118914,-0.339459,-0.448346,7.162598,72.145485,...,0.860352,4.122314,2.700622,-0.160886,1125.915283,1125.915283,-0.516363,0,73.678680,0
2,3,Freshwater Emergent Wetland,POINT (-11700802.947 5414552.963),-11700805.0,5414555.0,0.723456,-0.359268,0.374796,-15.088562,17.400249,...,-0.573608,-9.183105,4.848710,0.025842,591.859009,591.859009,-3.201056,2,61.709118,0
3,0,Non-wetland,POINT (-11691375.627 5415470.963),-11691375.0,5415475.0,5.250413,0.499063,0.232729,-2.092651,76.622498,...,-0.416748,-1.945190,2.179628,-0.112407,1102.812378,1102.812378,-3.500315,0,75.425468,0
4,0,Non-wetland,POINT (-11688229.595 5422598.953),-11688225.0,5422595.0,-0.234889,0.220975,0.127447,-13.449036,67.774742,...,-0.431335,-4.176697,3.072900,-0.263022,891.196655,891.196655,0.905681,0,69.929688,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3995,0,Non-wetland,POINT (-11720260.551 5417154.306),-11720265.0,5417155.0,4.837530,0.566681,0.015662,0.844116,79.333138,...,-0.920532,1.438721,-0.218655,0.028020,1483.663818,1483.663818,-0.098083,0,79.030769,0
3996,5,Lake,POINT (-11692545.117 5413039.214),-11692545.0,5413035.0,-0.009208,-0.473309,0.372760,-21.504028,1.227523,...,0.008301,-9.408569,3.036996,-0.022927,1381.800171,1381.800171,-0.988381,3,16.410891,0
3997,0,Non-wetland,POINT (-11737191.404 5445188.107),-11737195.0,5445185.0,-1.549930,0.941419,0.405933,24.865601,69.629227,...,4.210693,14.229004,-1.586322,-0.032829,1436.465454,1436.465454,-2.684710,0,74.242851,0
3998,0,Non-wetland,POINT (-11710617.607 5409718.002),-11710615.0,5409715.0,-5.099864,0.209731,0.156389,4.527100,75.367500,...,-0.675049,-0.203003,0.665912,-0.014237,1267.716431,1267.716431,2.411661,0,73.168686,0
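
`sample_xr` relies on nearest-neighbour selection: each point receives the values of the pixel whose centre is closest. This is why the `x` and `y` columns in the table above are pixel-centre coordinates that differ slightly from the point geometry. The sketch below reproduces the idea in plain NumPy on a toy grid. `sample_nearest` is a hypothetical helper, and it illustrates the concept rather than how xarray implements it.

```python
import numpy as np


def sample_nearest(grid, xs, ys, px, py):
    """For each point, return the value of the grid cell whose centre is closest.

    grid has shape (len(ys), len(xs)); xs and ys hold cell-centre coordinates.
    """
    px = np.atleast_1d(np.asarray(px, dtype=float))
    py = np.atleast_1d(np.asarray(py, dtype=float))
    # Index of the closest cell centre along each axis, one entry per point
    ix = np.abs(np.asarray(xs)[None, :] - px[:, None]).argmin(axis=1)
    iy = np.abs(np.asarray(ys)[None, :] - py[:, None]).argmin(axis=1)
    # Pointwise indexing: one value per point, not an outer product
    return grid[iy, ix]


xs = np.array([0.0, 10.0, 20.0, 30.0])
ys = np.array([30.0, 20.0, 10.0])  # descending, as in most north-up rasters
grid = np.arange(12).reshape(3, 4)

values = sample_nearest(grid, xs, ys, px=[1.0, 29.0], py=[28.0, 11.0])
print(values)
```

Passing a single coordinate pair also works and returns a one-element array, which mirrors the single-point branch that `sample_xr` handles separately.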


### Export training features

In [16]:
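# Set the name and location of the output file
output_file = Path(f"/data/results/{prefix}_training_features.csv")
# Make sure the output directory exists before writing
output_file.parent.mkdir(parents=True, exist_ok=True)
# Export to disk; note that the file is space-delimited (sep=" ") despite the .csv extension
pd_training_features.to_csv(output_file, header=True, index=False, sep=" ")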
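
The export above passes `sep=" "`, so the file is space-delimited and any script that reads it must use the same separator. Fields that contain spaces, such as the WKT text of the geometry column, are quoted by pandas on write and unquoted on read. The following sketch checks that round trip with an in-memory buffer; the toy values are hypothetical.

```python
import io

import pandas as pd

# Toy table (hypothetical values) with a WKT-like field that contains spaces
df = pd.DataFrame(
    {
        "class_id": [0, 3],
        "geometry": ["POINT (1 2)", "POINT (3 4)"],
        "TWI": [-0.5, 4.25],
    }
)

# Write with the same options as the export cell, but to memory instead of disk
buffer = io.StringIO()
df.to_csv(buffer, header=True, index=False, sep=" ")
text = buffer.getvalue()

# Read it back using the matching separator
buffer.seek(0)
restored = pd.read_csv(buffer, sep=" ")
print(restored)
```

Reading the same file with the default comma separator would instead collapse each row into a single column.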

In [17]:
# create geopandas dataframe
gpd_training_features = gpd.GeoDataFrame(
    pd_training_features,
    geometry="geometry",
)

### Add a column for binary (wetland/non-wetland) classification
This block ensures that both binary and multi-class classification labels are properly set up from the original `class_id` field:

- If `class_id` contains only two values (0 and 1), it is assumed to be binary, and is renamed to `class_id_binary`.
- If `class_id` includes additional wetland types, a new binary column `class_id_binary` is created:
  - `1` for any wetland type (i.e., values not equal to 0)
  - `0` for non-wetland (i.e., value equal to 0)
  - The original column is then renamed to `class_id_type` for use in multi-class classification.

In [18]:
# Check if unique values in 'class_id' are only 0 and 1
unique_values = gpd_training_features["class_id"].unique()
if len(unique_values) == 2 and set(unique_values) == {0, 1}:
    # Replace 'class_id' with 'class_id_binary'
    gpd_training_features.rename(columns={"class_id": "class_id_binary"}, inplace=True)
else:
    # Create 'class_id_binary' column based on condition
    gpd_training_features["class_id_binary"] = gpd_training_features["class_id"].apply(
        lambda x: 1 if x != 0 else 0
    )
    gpd_training_features.rename(columns={"class_id": "class_id_type"}, inplace=True)

# Move the binary column to the first position
gpd_training_features.insert(
    0, "class_id_binary", gpd_training_features.pop("class_id_binary")
)
print(gpd_training_features.columns)

Index(['class_id_binary', 'class_id_type', 'class_name', 'geometry', 'x', 'y',
       'Planform_curvature_90.0m', 'Curvature_450.0m', 'Curvature_270.0m',
       'TPI_450.0m', 'Slope_90.0m', 'Planform_curvature_270.0m',
       'Profile_curvature_270.0m', 'DTW', 'mrrtf', 'Profile_curvature_90.0m',
       'percent_slope', 'Slope_450.0m', 'Planform_curvature_450.0m',
       'TPI_90.0m', 'TPI_270.0m', 'TWI', 'Curvature_90.0m', 'dem', 'Elevation',
       'Profile_curvature_450.0m', 'mrvbf', 'Slope_270.0m', 'spatial_ref'],
      dtype='object')
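
The same labelling logic can be wrapped in a function and checked on a small table before it is applied to the full dataset. This is a sketch under the rules described above; `add_binary_label` is a hypothetical helper, and it uses a vectorised comparison in place of `apply`.

```python
import pandas as pd


def add_binary_label(df: pd.DataFrame, field: str = "class_id") -> pd.DataFrame:
    """Return a copy of df with a leading 'class_id_binary' column."""
    out = df.copy()
    if set(out[field].unique()) == {0, 1}:
        # Labels are already binary: just rename the column
        out = out.rename(columns={field: "class_id_binary"})
    else:
        # Any non-zero class counts as wetland (1); zero stays non-wetland (0)
        out["class_id_binary"] = (out[field] != 0).astype(int)
        out = out.rename(columns={field: "class_id_type"})
    # Move the binary column to the first position
    out.insert(0, "class_id_binary", out.pop("class_id_binary"))
    return out


# Toy multi-class example (hypothetical rows)
toy = pd.DataFrame(
    {"class_id": [0, 3, 0, 5], "class_name": ["Non-wetland", "Emergent", "Non-wetland", "Lake"]}
)
labelled = add_binary_label(toy)
print(labelled)
```

Because the function works on a copy, the input table is left unchanged and the cell can be re-run safely.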


In [12]:
# Derive the binary label from 'class_id_type': any non-zero class becomes 1
gpd_training_features["class_id_binary"] = gpd_training_features["class_id_type"].apply(
    lambda x: 1 if x != 0 else 0
)
# Insert the new column at the second position
gpd_training_features.insert(
    1, "class_id_binary", gpd_training_features.pop("class_id_binary")
)
gpd_training_features

Unnamed: 0,class_id_type,class_id_binary,class_name,geometry,x,y,Planform_curvature_90.0m,Curvature_450.0m,Curvature_270.0m,TPI_450.0m,...,TPI_90.0m,TPI_270.0m,TWI,Curvature_90.0m,dem,Elevation,Profile_curvature_450.0m,mrvbf,Slope_270.0m,spatial_ref
0,0,0,Non-wetland,POINT (-11711043.627 5433848.061),-11711045.0,5433845.0,-9.034044,0.236544,0.180498,2.600647,...,2.595947,1.074829,-0.602729,-0.615586,976.126648,976.126648,-3.421940,0,80.998924,0
1,0,0,Non-wetland,POINT (-11695618.383 5423039.778),-11695615.0,5423035.0,0.118914,-0.339459,-0.448346,7.162598,...,0.860352,4.122314,2.700622,-0.160886,1125.915283,1125.915283,-0.516363,0,73.678680,0
2,3,1,Freshwater Emergent Wetland,POINT (-11700802.947 5414552.963),-11700805.0,5414555.0,0.723456,-0.359268,0.374796,-15.088562,...,-0.573608,-9.183105,4.848710,0.025842,591.859009,591.859009,-3.201056,2,61.709118,0
3,0,0,Non-wetland,POINT (-11691375.627 5415470.963),-11691375.0,5415475.0,5.250413,0.499063,0.232729,-2.092651,...,-0.416748,-1.945190,2.179628,-0.112407,1102.812378,1102.812378,-3.500315,0,75.425468,0
4,0,0,Non-wetland,POINT (-11688229.595 5422598.953),-11688225.0,5422595.0,-0.234889,0.220975,0.127447,-13.449036,...,-0.431335,-4.176697,3.072900,-0.263022,891.196655,891.196655,0.905681,0,69.929688,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3995,0,0,Non-wetland,POINT (-11720260.551 5417154.306),-11720265.0,5417155.0,4.837530,0.566681,0.015662,0.844116,...,-0.920532,1.438721,-0.218655,0.028020,1483.663818,1483.663818,-0.098083,0,79.030769,0
3996,5,1,Lake,POINT (-11692545.117 5413039.214),-11692545.0,5413035.0,-0.009208,-0.473309,0.372760,-21.504028,...,0.008301,-9.408569,3.036996,-0.022927,1381.800171,1381.800171,-0.988381,3,16.410891,0
3997,0,0,Non-wetland,POINT (-11737191.404 5445188.107),-11737195.0,5445185.0,-1.549930,0.941419,0.405933,24.865601,...,4.210693,14.229004,-1.586322,-0.032829,1436.465454,1436.465454,-2.684710,0,74.242851,0
3998,0,0,Non-wetland,POINT (-11710617.607 5409718.002),-11710615.0,5409715.0,-5.099864,0.209731,0.156389,4.527100,...,-0.675049,-0.203003,0.665912,-0.014237,1267.716431,1267.716431,2.411661,0,73.168686,0


In [19]:
# Save as a GeoJSON file, creating the output directory first if needed
results_dir = Path("/data/results")
results_dir.mkdir(parents=True, exist_ok=True)
geojson_file = results_dir / f"{prefix}_training_features.geojson"
gpd_training_features.to_file(geojson_file, driver="GeoJSON")

***

## Additional information

**License:** The code in this notebook is licensed under the [Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0). 
Digital Earth Africa data is licensed under the [Creative Commons by Attribution 4.0](https://creativecommons.org/licenses/by/4.0/) license.

**Contact:** If you need assistance, please post a question on the [Open Data Cube Slack channel](http://slack.opendatacube.org/) or on the [GIS Stack Exchange](https://gis.stackexchange.com/questions/ask?tags=open-data-cube) using the `open-data-cube` tag (you can view previously asked questions [here](https://gis.stackexchange.com/questions/tagged/open-data-cube)).
If you would like to report an issue with this notebook, you can file one on [GitHub](https://github.com/digitalearthafrica/deafrica-sandbox-notebooks).

**Last tested:**

In [14]:
from datetime import datetime

datetime.today().strftime("%Y-%m-%d")

'2025-07-09'