In [96]:
from concurrent.futures import ThreadPoolExecutor, as_completed
from logging import INFO, basicConfig, getLogger

import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import rasterio as rio
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
from xgboost import XGBRFRegressor
import pandas as pd

In [39]:
basicConfig(
    level=INFO,
    format="%(asctime)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = getLogger(__name__)

In [None]:
# Max worker
WORKERS = 16

# image paths
satellite_embedding = "satellite_embedding.tif"

# bands
bands = [f"A{no:02d}" for no in range(64)]

# Sample paths
samples = [
    dict(name="CHM", path="GEDI_GEDI02_A.fgb", unit="meter"),
    dict(name="treecover", path="GEDI_GEDI02_B.fgb", unit="%"),
    dict(name="AGB", path="GEDI_GEDI04_A.fgb", unit="Ton/Ha"),
]


In [None]:
# Function to extract
def extract_sample(x):
    sample_dict = samples[x]

    print(f"Extracting {sample_dict['name']} sample")

    path = sample_dict["path"]
    extract = gpd.read_file(path)
    coords = [coord for coord in zip(extract.geometry.x, extract.geometry.y)]

    with rio.open(satellite_embedding) as src:
        extract[bands] = [data for data in src.sample(coords)]

    del extract["geometry"]

    samples[x]["extract"] = extract


# Run concurrent
with ThreadPoolExecutor(WORKERS) as executor:
    jobs = [executor.submit(extract_sample, x) for x in range(len(samples))]
    as_completed(jobs)

Extracting CHM sample
Extracting treecover sample
Extracting AGB sample


In [None]:
plt.figure(figsize=(20, 60))


# Function to check correlation between predictor and labels per sample
def correlation(x):
    sample_dict = samples[x]
    name = sample_dict["name"]
    corr = sample_dict["extract"].corr()
    highest = corr[name].sort_values(ascending=False)
    top_ten = highest.iloc[1:11]
    print(top_ten)

    samples[x]["top_ten"] = list(top_ten.index)


# Run concurrent
[correlation(x) for x in range(len(samples))]

A03    0.324368
A04    0.300962
A62    0.261663
A08    0.251686
A52    0.239404
A50    0.237373
A58    0.233496
A56    0.220947
A26    0.190208
A30    0.181636
Name: CHM, dtype: float64
A03    0.206534
A04    0.198425
A50    0.182283
A30    0.167549
A56    0.165789
A08    0.165247
A26    0.163799
A47    0.159441
A58    0.156091
A52    0.149760
Name: treecover, dtype: float64
A03    0.389995
A50    0.288297
A52    0.288027
A04    0.271961
A58    0.265380
A47    0.256163
A62    0.244133
A56    0.238786
A24    0.205751
A08    0.202145
Name: AGB, dtype: float64


[None, None, None]

<Figure size 2000x6000 with 0 Axes>

In [95]:
# function to do modelling
def modelling(x):
    sample_dict = samples[x]
    name = sample_dict["name"]
    extract = sample_dict["extract"]
    predictors = sample_dict["top_ten"]

    # split train and test
    train, test = train_test_split(extract, test_size=0.3, random_state=1)

    # train model
    model = XGBRFRegressor(n_estimator=500)
    model.fit(train[predictors], train[name])

    # test the model
    test_prediction = model.predict(test[predictors])
    mae = mean_absolute_error(test[name], test_prediction)
    r2 = np.corrcoef(test[name], test_prediction)[0, 1]
    print(f"{name} MAE={round(mae, 3)} {sample_dict["unit"]}, R2={round(r2, 3)}")

    samples[x]["model"] = model

with ThreadPoolExecutor(WORKERS) as executor:
    jobs = [executor.submit(modelling, x) for x in range(len(samples))]
    for job in jobs:
        job.result()

AGB MAE=59.23 Ton/Ha, R2=0.545
treecover MAE=22.294 %, R2=0.388
CHM MAE=6.798 meter, R2=0.56


In [98]:
# load the image
with rio.open(satellite_embedding) as src:
    images = src.read()
    table_images = images.transpose(1, 2, 0)
    table_images = pd.DataFrame(table_images.reshape(-1, src.count), columns=bands)
    valid_mask = table_images[bands[0]] != src.nodata
    valid_table = table_images[valid_mask]

In [None]:
# Apply the model
def apply_model(x):
    sample_dict = samples[x]
    name = sample_dict["name"]
    extract = sample_dict["extract"]
    predictors = bands
    model = sample_dict["model"]

    prediction = model.predict(valid_table[predictors])
    table_images = table_images.loc[valid_mask, name] = prediction
    table_images = table_images.loc[~valid_mask, name] = -9999
