In [None]:
# Reload imported modules automatically before running each cell, so that edits to
# the local modules (e.g. `models`, `explain`, `data`) are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

# A pronounced loss of tree crown foliation results in multi-year growth reduction

We train models to predict the growth of trees over a given period of 4–6 years, based on the following features:

| feature            | description                                  | level |
|--------------------|----------------------------------------------|-------|
| diameter_end       | Diameter at the end of the period            | tree  |
| defoliation_max    | Maximum defoliation of the growth period     | tree  |
| defoliation_min    | Minimum defoliation of the growth period     | tree  |
| defoliation_mean   | Mean defoliation of the growth period        | tree  |
| defoliation_median | Median defoliation of the growth period      | tree  |
| social_class_min   | Minimum social class of the growth period    | tree  |
| plot_latitude      | Latitude of the plot                         | plot  |
| plot_longitude     | Longitude of the plot                        | plot  |
| plot_slope         | Slope of the plot                            | plot  |
| plot_orientation   | Orientation of the plot                      | plot  |
| plot_altitude      | Altitude of the plot                         | plot  |
| dep_ph             | Deposition pH                                | plot  |
| dep_cond           | Deposition conductivity                      | plot  |
| dep_k              | Deposition potassium (K)                     | plot  |
| dep_ca             | Deposition calcium (Ca)                      | plot  |
| dep_mg             | Deposition magnesium (Mg)                    | plot  |
| dep_na             | Deposition sodium (Na)                       | plot  |
| dep_n_nh4          | Deposition ammonium (NH4)                    | plot  |
| dep_cl             | Deposition chloride (Cl)                     | plot  |
| dep_n_no3          | Deposition nitrate (NO3)                     | plot  |
| dep_s_so4          | Deposition sulfate (SO4)                     | plot  |
| dep_alk            | Deposition alkalinity                        | plot  |
| dep_n_tot          | Deposition total nitrogen (N)                | plot  |
| dep_doc            | Deposition dissolved organic carbon (DOC)    | plot  |
| dep_al             | Deposition aluminium (Al)                    | plot  |
| dep_mn             | Deposition manganese (Mn)                    | plot  |
| dep_fe             | Deposition iron (Fe)                         | plot  |
| dep_p_po4          | Deposition phosphate (PO4)                   | plot  |
| dep_cu             | Deposition copper (Cu)                       | plot  |
| dep_zn             | Deposition zinc (Zn)                         | plot  |
| dep_hg             | Deposition mercury (Hg)                      | plot  |
| dep_pb             | Deposition lead (Pb)                         | plot  |
| dep_co             | Deposition cobalt (Co)                       | plot  |
| dep_ni             | Deposition nickel (Ni)                       | plot  |
| dep_cd             | Deposition cadmium (Cd)                      | plot  |
| dep_s_tot          | Deposition total sulfur (S)                  | plot  |
| dep_c_tot          | Deposition total carbon (C)                  | plot  |
| dep_n_org          | Deposition organic nitrogen (N)              | plot  |
| dep_p_tot          | Deposition total phosphorus (P)              | plot  |
| dep_cr             | Deposition chromium (Cr)                     | plot  |
| dep_n_no2          | Deposition nitrite (NO2)                     | plot  |
| ss_ph              | Soil solution pH                             | plot  |
| ss_cond            | Soil solution conductivity                   | plot  |
| ss_k               | Soil solution potassium (K)                  | plot  |
| ss_ca              | Soil solution calcium (Ca)                   | plot  |
| ss_mg              | Soil solution magnesium (Mg)                 | plot  |
| ss_n_no3           | Soil solution nitrate (NO3)                  | plot  |
| ss_s_so4           | Soil solution sulfate (SO4)                  | plot  |
| ss_alk             | Soil solution alkalinity                     | plot  |
| ss_al              | Soil solution aluminium (Al)                 | plot  |
| ss_doc             | Soil solution dissolved organic carbon (DOC) | plot  |
| ss_na              | Soil solution sodium (Na)                    | plot  |
| ss_n_nh4           | Soil solution ammonium (NH4)                 | plot  |
| ss_cl              | Soil solution chloride (Cl)                  | plot  |
| ss_n_tot           | Soil solution total nitrogen (N)             | plot  |
| ss_fe              | Soil solution iron (Fe)                      | plot  |
| ss_mn              | Soil solution manganese (Mn)                 | plot  |
| ss_al_labile       | Soil solution labile aluminium (Al)          | plot  |
| ss_p               | Soil solution phosphorus (P)                 | plot  |
| ss_cr              | Soil solution chromium (Cr)                  | plot  |
| ss_ni              | Soil solution nickel (Ni)                    | plot  |
| ss_zn              | Soil solution zinc (Zn)                      | plot  |
| ss_cu              | Soil solution copper (Cu)                    | plot  |
| ss_pb              | Soil solution lead (Pb)                      | plot  |
| ss_cd              | Soil solution cadmium (Cd)                   | plot  |
| ss_si              | Soil solution silicon (Si)                   | plot  |
| soph_avg_sdi       | Average species diversity index              | plot  |
| soph_avg_age       | Average age of the trees                     | plot  |
| soph_avg_temp      | Average temperature                          | plot  |
| soph_avg_precip    | Average precipitation                        | plot  |

In [None]:
## Use the code below to produce the above table in markdown format

# from config import FEATURES_DESCRIPTION
# import polars as pl
#
# with pl.Config(
#     tbl_formatting="MARKDOWN",
#     tbl_hide_column_data_types=True,
#     tbl_rows=-1,
#     tbl_width_chars=200,
#     fmt_str_lengths=200
# ) as cfg:
#     print(
#         pl.from_dicts(
#             [
#                 {**{"feature": feature}, **descr}
#                 for feature, descr in FEATURES_DESCRIPTION.items()
#             ]
#         )
#     )

Here we import key libraries used in the notebook (run `uv sync` and select the corresponding kernel if some of them are missing).

In [None]:
# Main libraries
import shap
import numpy as np

# Plotting
import matplotlib.pyplot as plt
import seaborn as sns

# We will use Polars for data manipulation
import polars as pl

# Casting types from time to time to have a better autocompletion
from typing import cast

from models import (
    train_and_explain,
    optimize_hyperparameters,
    ExperimentResults,
    Species,
)

We then optimize the hyperparameters and train a model for every species. The `group_col` variable below selects the grouping used by the K-fold validation.

In [None]:
# Column used to group samples in the K-fold splits (passed to both tuning and training)
group_col = "tree_id"
all_species: list[Species] = ["spruce", "pine", "beech", "oak"]

# Experiment results (models, performances, SHAP explanations, ...) keyed by species
all_results: dict[str, ExperimentResults] = {}
# Best objective value reported by the hyperparameter search, keyed by species
best_scores = {}

for species in all_species:
    # Tune the hyperparameters first, then train/explain with the best ones found
    best_params, best_value = optimize_hyperparameters(species, group_col=group_col)
    best_scores[species] = best_value
    all_results[species] = train_and_explain(species, best_params, group_col=group_col)
    print(f"{species}: best objective value = {best_value}")

## Evaluate model performance for each species

Here we compute, for each species, the mean and standard deviation of the $R^2$ (coefficient of determination) across folds.

In [None]:
for species, results in all_results.items():
    # One row per fold; `with_row_index` yields a 0-based fold number, matching the
    # `range(5)` fold loops used in the other cells
    perf = (
        pl.from_dicts(results.performances)
        .with_row_index("fold")
        .select("fold", "test_r2", "train_r2")
    )

    # `describe` reports, among other statistics, the mean and standard deviation
    # of the train/test R2 across folds
    print(f"Results for {species}")
    print(perf.select(pl.selectors.contains("r2")).describe(percentiles=None))

## Plot/period-wise R2 test scores

Here we compute plot/period-wise $R^2$ test scores to check that the models do not merely predict the plot average.

In [None]:
from sklearn.metrics import r2_score

# Minimum number of test samples for a (plot, period) pair to be scored
min_samples = 25
# Maximum number of SHAP decision plots drawn per species
max_plots = 3
# Only pairs scoring above this R2 get a decision plot
min_plot_r2 = 0.5
# Number of cross-validation folds stored in each result
# NOTE(review): hard-coded, as in the other cells — confirm it matches the CV setup
n_folds = 5

scores = {}
for species, results in all_results.items():
    num_plotted = 0

    # Add a unique identifier for the period and a positional row index. This does not
    # depend on the fold, so it is computed once per species. The separator avoids
    # collisions such as ("1", "23") vs ("12", "3").
    metadata_all = results.metadata.with_columns(
        pl.concat_str(pl.col("period_start"), pl.col("period_end"), separator="-")
        .hash()
        .alias("period_id")
    ).with_row_index()

    scores[species] = []
    for fold in range(n_folds):
        # Get the results for the current fold
        _, y_true, y_pred = results.get_data(fold, "test")
        indices = results.get_indices(fold, "test")

        # Keep the metadata rows of this fold's test set (in ascending row-index order)
        # NOTE(review): assumes y_true / y_pred follow the same row order — confirm in `models`
        metadata = metadata_all.filter(pl.col("index").is_in(indices))

        # Unique (plot_id, period_id) pairs with at least `min_samples` test samples.
        # `maintain_order` makes the iteration (and thus which pairs get plotted) deterministic.
        plot_period_pairs = (
            metadata.group_by("plot_id", "period_id", maintain_order=True)
            .agg(pl.len().alias("count"))
            .filter(pl.col("count") >= min_samples)
        )

        for plot_id, period_id in plot_period_pairs.select(
            "plot_id", "period_id"
        ).iter_rows():
            filter_expr = (metadata["plot_id"] == plot_id) & (
                metadata["period_id"] == period_id
            )

            r2 = r2_score(
                y_true.filter(filter_expr),
                y_pred.filter(filter_expr),
            )

            scores[species].append(r2)

            if num_plotted < max_plots and r2 > min_plot_r2:
                print(f"Decision plot for {species} fold {fold} with R2={r2:.3f}")

                # Positions of the current plot-period pair within the fold's test rows
                plot_period_indices = filter_expr.arg_true().to_numpy()

                # NOTE(review): these are positions within the test subset; confirm that
                # `results.shap_values[fold]` is indexed the same way
                shap.decision_plot(
                    results.explainers[fold].expected_value,
                    results.shap_values[fold][plot_period_indices].values,
                    feature_names=results.features,
                )

                num_plotted += 1

    scores[species] = pl.Series(species, scores[species])


for species, r2_scores in scores.items():
    print(f"Results for {species}")
    print(f" - {r2_scores.len()} plot-period pairs")
    # mean() is None for an empty Series and std() is None for a single value
    if r2_scores.len() >= 2:
        print(f" - R2: {r2_scores.mean():.3f} ± {r2_scores.std():.3f}")
    elif r2_scores.len() == 1:
        print(f" - R2: {r2_scores.mean():.3f}")
    print()
    print()

## Feature importance by species

Here we show how feature importance varies by species.

In [None]:
top_n = 15
# Number of cross-validation folds per species (hard-coded, as in the other cells)
n_folds = 5

# Mean |SHAP| per (species, fold, feature). Each species' SHAP values are paired with
# that species' own columns, not with the columns of whichever `results` object was
# left over from a previous cell.
feature_importances = (
    pl.from_dicts(
        [
            {
                "species": species,
                "fold": fold,
                **dict(
                    zip(
                        results.X.columns,
                        np.absolute(results.shap_values[fold].values).mean(axis=0),
                    )
                ),
            }
            for species, results in all_results.items()
            for fold in range(n_folds)
        ]
    )
    .unpivot(
        on=pl.selectors.exclude("species", "fold"),  # type: ignore
        index=["species", "fold"],
        variable_name="feature",
        value_name="shap",
    )
    # Global importance of each feature: mean over all species and folds
    .with_columns(pl.col("shap").mean().over("feature").alias("importance"))
)

# Features sorted by decreasing global importance
feature_order = (
    feature_importances.group_by("feature")
    .agg(pl.col("importance").mean().alias("importance"))
    .sort("importance", descending=True)["feature"]
    .to_list()
)

sns.catplot(
    feature_importances,
    x="shap",
    y="feature",
    hue="species",
    kind="bar",
    order=feature_order[:top_n],
)
plt.xlabel("Mean SHAP absolute value")
plt.ylabel("Feature");

In [None]:
# Plot the top N features by species
for species, results in all_results.items():
    plt.figure()
    ax = shap.plots.bar(
        results.shap_values[4],
        max_display=10,
        ax=plt.gca(),
        show=False,
    )
    ax.set_title(f"Feature importance for {species.capitalize()}")

## Feature interactions

We can compute the Shapley interaction values using `get_shap_interactions`, which returns a tensor of shape `(# samples, # features, # features)`. Each slice along the first axis is a symmetric matrix of interaction values, and each row of that matrix sums to the Shapley value of the corresponding feature. The diagonal entries represent the "main effect" attributed to that feature, whereas the other entries represent its pairwise (first-order) interactions with every other feature.

In [None]:
import os

from explain import plot_interaction_matrix

# Number of globally most important features (over all species) shown in each matrix
n_interaction_features = 20

top_n_features = (
    feature_importances.select("feature", "importance")
    .unique()
    .sort("importance", descending=True)
    .head(n_interaction_features)["feature"]
    .to_list()
)

# `savefig` fails if the output directory does not exist yet
os.makedirs("figures", exist_ok=True)

# Loop over the species
for species in all_species:
    results = all_results[species]

    # Interaction heatmap restricted to the selected top features
    # NOTE(review): presumably the mean absolute interaction values (cf. the output
    # file name) — confirm in `explain.plot_interaction_matrix`
    fig, ax = plt.subplots(figsize=(10, 8))
    interactions = plot_interaction_matrix(results, top_n=top_n_features, ax=ax)

    # Title the heatmap axes explicitly, and before `tight_layout` so it is laid out too
    ax.set_title(f"Interactions for {species}")
    fig.tight_layout()

    fig.savefig(f"figures/{species}-interactions-mean.png")

In [None]:
from scipy.optimize import curve_fit

feature = "defoliation_mean"
fit_curve = True
fold = 0
# Standard deviation of the horizontal jitter that limits overplotting
jitter_std = 0.5
# Seeded generator so that the jitter is reproducible across runs
jitter_rng = np.random.default_rng(0)


def func(x, a, b, c):
    """Power law with vertical offset: ``a * x**b + c``."""
    return a * x**b + c


# One plot per species (4 in total)
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Set global xlim and ylim based on the data
xlim = (
    min(
        cast(float, results.X[results.get_indices(fold, "all"), feature].min())
        for results in all_results.values()
    ),
    max(
        cast(float, results.X[results.get_indices(fold, "all"), feature].max())
        for results in all_results.values()
    ),
)

ylim = (
    min(
        cast(float, results.shap_values[fold][:, feature].values.min())
        for results in all_results.values()
    ),
    max(
        cast(float, results.shap_values[fold][:, feature].values.max())
        for results in all_results.values()
    ),
)

for species, ax in zip(all_species, axes.flatten()):
    results = all_results[species]
    fold_indices = results.get_indices(fold, "all")

    # SHAP values and raw values of the selected feature, paired sample by sample
    shapley_values = cast(np.ndarray, results.shap_values[fold][:, feature].values)
    feature_values = results.X[fold_indices, feature].to_numpy()

    # Order dataset by feature values
    order_idx = np.argsort(feature_values)
    feature_values = feature_values[order_idx]
    shapley_values = shapley_values[order_idx]

    # Draw the line that indicates no effect
    ax.axhline(0, color="grey", linestyle="--")
    ax.text(feature_values.max(), 0.005, "No effect", color="grey", ha="right")

    ax.scatter(
        # Add a bit of noise to the x-axis to avoid overlapping points
        feature_values + jitter_rng.normal(0, jitter_std, len(feature_values)),
        shapley_values,
        s=10,
        alpha=0.4,
    )

    if fit_curve:
        try:
            # Least-squares fit of the power law to the (feature, SHAP) pairs
            popt, _ = curve_fit(func, feature_values, shapley_values)
        except RuntimeError as err:
            # curve_fit raises RuntimeError when the optimisation does not converge;
            # report it and keep the scatter plot instead of aborting the whole figure
            print(f"Power-law fit failed for {species}: {err}")
        else:
            # Plot the fitted curve over the observed feature range
            x = np.linspace(feature_values.min(), feature_values.max(), 100)
            ax.plot(
                x,
                func(x, *popt),
                color="red",
                label=f"y = {popt[0]:.2e} x^{popt[1]:.2f} + {popt[2]:.2f}",
            )

    ax.set_title(species.capitalize())
    ax.set_xlabel(feature)
    ax.set_ylabel("SHAP value")
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    # Call ax.legend() here to display the fitted equation stored in the curve label


if fit_curve:
    fig.suptitle(f"Dependence plots for {feature} with power law fit (fold {fold})")
else:
    fig.suptitle(f"Dependence plots for {feature} (fold {fold})")

plt.tight_layout()

# Species-specific plots

Hereafter we investigate a specific species and feature.

In [None]:
# Species investigated in the species-specific cells below
species: Species = "spruce"
# Experiment results (models, SHAP values, ...) for that species
results = all_results[species]

In [None]:
feature = "defoliation_mean"
# CV fold used for the scatter plot (set explicitly instead of relying on an earlier cell)
fold = 0
# Number of strongest interactions listed for the selected feature
n_top_interactions = 10

results = all_results[species]

# Fetch the interaction tensor without drawing it
interactions = plot_interaction_matrix(results, no_plotting=True)

# Construct a DataFrame of the interaction values between `feature` and every feature
# NOTE(review): assumes the last two axes of `interactions` follow `results.X.columns`
feature_idx = results.X.columns.index(feature)

df_interactions = pl.from_numpy(
    interactions[:, feature_idx, :],
    schema=results.X.columns,
)

# Rank the other features by their mean absolute interaction with `feature`
top_interactions = (
    df_interactions.unpivot(
        pl.selectors.exclude(feature),
        variable_name="feature",
        value_name="interaction",
    )
    .group_by("feature")
    .agg(pl.col("interaction").abs().mean())
    .sort("interaction", descending=True)
    .head(n_top_interactions)
)
print(f"Top-{n_top_interactions} interactions with {feature}:")
print(top_interactions)

top_interacting_feature = top_interactions.item(0, "feature")

# Dependence plot of `feature`, coloured by its strongest interacting feature
shap.plots.scatter(
    results.shap_values[fold][:, feature],
    color=results.shap_values[fold][:, top_interacting_feature],
    alpha=0.6,
)

In [None]:
from explain import plot_ceteris_paribus_profile

fold = 0
# Number of ceteris-paribus profiles drawn in each panel
n_profiles = 5
# Seeded generator so that the sampled profiles are reproducible across runs
profile_rng = np.random.default_rng(0)

# Test features and predictions of the selected fold
X, _, y_pred = results.get_data(fold, "test")
y_vec = y_pred.to_numpy()

# Percentile bands of the predicted growth, one per panel (0 = min, 100 = max)
percentile_bands = [(0, 5), (20, 30), (70, 80), (95, 100)]

# Plot 4 panels of profiles for the selected feature
fig, axes = plt.subplots(2, 2, figsize=(10, 8))

for ax, (p_low, p_high) in zip(axes.flat, percentile_bands):
    low, high = np.percentile(y_vec, [p_low, p_high])
    title = f"Growth rates between [{low:.2f}, {high:.2f}]"

    # Samples whose predicted growth lies in the band; the upper bound is inclusive
    # so that the maximum (shown in the title) can be drawn in the last band
    candidates = np.flatnonzero((y_vec >= low) & (y_vec <= high))
    if candidates.size == 0:
        ax.set_title(title)
        continue

    # Sample without replacement so that no profile is drawn twice
    sampled_indices = profile_rng.choice(
        candidates, size=min(n_profiles, candidates.size), replace=False
    )

    for idx in sampled_indices:
        # Discard the returned profile values instead of overwriting `y_pred`
        feature_range, _ = plot_ceteris_paribus_profile(
            results.estimators[fold], X, idx, feature, ax=ax
        )
    ax.set_title(title)

plt.tight_layout()

## Explain a subset of the data

We explain only data within:

- A given climatic area (`boreal`, `high_altitude`, `dry`).

In [None]:
from data import load_data

# Minimum number of samples a condition needs to be included
min_samples = 250
# CV fold whose SHAP values are explained (set explicitly instead of relying on an earlier cell)
fold = 0
# Conditions compared in the bar plot below
plotted_conditions = ["boreal", "high_altitude", "dry", "temperate"]

# Load the data and the corresponding results
data = load_data(species)
results = all_results[species]

# Map each condition to a Polars expression returning the indices of the matching rows.
# NOTE(review): the thresholds suggest latitude encoded as (+/-)DDMMSS (500000 = 50°N),
# altitude in metres and precipitation in mm — confirm the units against the data source.
category_idx_expr = {
    "boreal": pl.arg_where(pl.col("plot_latitude") >= 500000),
    "temperate": pl.arg_where(pl.col("plot_latitude") < 500000),
    "high_altitude": pl.arg_where(pl.col("plot_altitude") >= 1000),
    "low_altitude": pl.arg_where(pl.col("plot_altitude") <= 200),
    "flat": pl.arg_where(pl.col("plot_slope") <= 5),
    "hilly": pl.arg_where(pl.col("plot_slope") > 20),
    "dry": pl.arg_where(pl.col("soph_avg_precip") < 400),
    "humid": pl.arg_where(pl.col("soph_avg_precip") > 800),
}

# One (n_samples, n_features) block of SHAP values per retained condition
shap_blocks = []
categories = []

for label, idx_expr in category_idx_expr.items():
    indices = data.select(idx_expr).to_series().to_numpy()

    if len(indices) < min_samples:
        print(f"Not enough samples for {label}")
        continue

    # NOTE(review): assumes the rows of `data`, `results.X` and `results.shap_values[fold]`
    # are aligned (same samples in the same order) — confirm in `models` / `data`
    shap_values = cast(np.ndarray, results.shap_values[fold][indices].values)
    shap_blocks.append(shap_values)
    categories.extend([label] * shap_values.shape[0])

if not shap_blocks:
    raise ValueError(f"No condition has at least {min_samples} samples for {species}")

# Only the column names are needed downstream (the PCA cell reads `features.columns`)
features = results.X.head(0).to_pandas()

# Plot the feature importances by condition
df = (
    pl.from_numpy(
        np.absolute(np.concatenate(shap_blocks, axis=0)), schema=results.features
    )
    .with_columns(pl.Series("condition", categories))
    .unpivot(
        on=pl.selectors.exclude("condition"),
        variable_name="feature",
        value_name="importance",
        index=["condition"],
    )
    .filter(pl.col("condition").is_in(plotted_conditions))
)

sns.catplot(
    df,
    x="importance",
    y="feature",
    hue="condition",
    kind="bar",
    order=feature_order[:10],
)
plt.xlabel("Feature importance (mean absolute SHAP value)")
plt.ylabel("Feature")
plt.title(f"Feature importances by condition (species = {species.capitalize()})");

In [None]:
# Plot dependence plots for the selected feature by category
feature = "defoliation_mean"
interacting = "defoliation_max"

# 2x2 grid for the 4 categories
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
axes = axes.flatten()

# Define x- and y-axis limits based on global min and max values
x_min = data[feature].min()
x_max = data[feature].max()
y_min = np.min(results.shap_values[fold][:, feature].values)
y_max = np.max(results.shap_values[fold][:, feature].values)

for idx, conditions in enumerate(["boreal", "hilly", "dry", "temperate"]):
    indices = data.select(category_idx_expr[conditions]).to_series().to_numpy()

    if len(indices) < 25:
        print(f"Not enough samples for {conditions}")
        continue

    ax = axes[idx]

    ax.scatter(
        results.X[indices, feature].to_numpy(),
        results.shap_values[fold][indices, feature].values,
        # color=results.shap_values[fold][:, top_interacting_feature],
        alpha=0.2,
        c=results.y_pred[fold][indices],
    )

    ax.set_title(conditions.capitalize())
    ax.set_xlabel(feature)
    ax.set_ylabel("SHAP values")
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)

plt.tight_layout()

In [None]:
# Visualize Shapley values using a PCA
from sklearn.decomposition import PCA

# Perform PCA for fold 0
fold = 0
# Fixed random_state so the projection is reproducible if a randomized solver is selected
pca = PCA(n_components=2, random_state=0)

# Standardize the Shapley values feature-wise. Constant features give 0/0, which
# `nan_to_num` maps to 0; the corresponding floating-point warnings are silenced.
X_shap_raw = results.shap_values[fold].values
with np.errstate(divide="ignore", invalid="ignore"):
    X_shap = np.nan_to_num(
        (X_shap_raw - X_shap_raw.mean(axis=0)) / X_shap_raw.std(axis=0)
    )

# Fit the PCA model and project the samples onto the first two components
X_pca = pca.fit_transform(X_shap)

In [None]:
# Scatter plot of the PCA with the top-5 features projected
# Scatter plot of the PCA with the top features projected as arrows
# Number of most important features drawn as arrows
n_arrow_features = 10
# Arbitrary scale factor so that the arrows are visible on the scatter plot
arrow_scale = 400

fig, ax = plt.subplots(figsize=(10, 8))
ax.scatter(X_pca[:, 0], X_pca[:, 1], alpha=0.1)

# Highlight a few conditions on top of the full point cloud
for label, expr in category_idx_expr.items():
    if label not in ["boreal", "dry", "high_altitude"]:
        continue

    indices = data.select(expr).to_series().to_numpy()

    ax.scatter(
        X_pca[indices, 0],
        X_pca[indices, 1],
        label=label,
        alpha=0.6,
    )

ax.legend()

# Compute the top features (mean |SHAP| over this fold). A dedicated name is used so the
# `feature_importances` DataFrame built in an earlier cell is not overwritten.
feature_names = np.array(results.X.columns)
pca_feature_importances = np.absolute(results.shap_values[fold].values).mean(axis=0)
top_features = np.argsort(pca_feature_importances)[::-1][:n_arrow_features]
print("Top features:", feature_names[top_features].tolist())

# Loadings of each standardized feature on the two principal components, scaled by the
# feature importance (equivalent to transforming unit vectors, as the data is centred)
X_features = (
    arrow_scale * pca.components_.T * pca_feature_importances[:, np.newaxis]
)
for idx in top_features:
    ax.arrow(0, 0, X_features[idx, 0], X_features[idx, 1], color="black", width=0.05)
    ax.text(
        X_features[idx, 0] + 0.1,
        X_features[idx, 1] + 0.1,
        feature_names[idx],
        fontsize=12,
    )

ax.set_xlabel("PC1")
ax.set_ylabel("PC2")
ax.set_title("PCA of standardized SHAP values");