In [None]:
# Reload imported modules automatically when their source files change, so
# edits to helper modules take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

Here we import key libraries used in the notebook (run `uv sync` and select the corresponding kernel if some of them are missing).

In [None]:
# Main libraries
import shap
import numpy as np
import os

# Plotting
import matplotlib.pyplot as plt
import seaborn as sns

# We will use Polars for data manipulation
import polars as pl
import polars.selectors as cs

# Casting types from time to time to have a better autocompletion
from typing import cast

from models import train_and_explain, ExperimentResults, Species, ModelType, ALL_SPECIES
from analysis import summarize_performance
from config import Ablation

# Output directories used by the notebook: "./cache" holds intermediate results
# (e.g. feature-importance tables) and "./figures" holds the exported plots that
# several cells below save with savefig. exist_ok makes this safe to re-run.
os.makedirs("./cache", exist_ok=True)
os.makedirs("./figures", exist_ok=True)

Next, we optimize hyperparameters and train one model per species. The training cell also selects the model type, the grouping column used by the K-fold validation, and the feature ablation.

## Feature importance by species

Here we show how feature importance varies by species.

In [None]:
# Experiment configuration
group_col = "tree_id"  # "tree_id" or "plot_id"
model_type: ModelType = "lgbm"  # "lgbm" or "lasso"
ablation: Ablation = "all"  # "all", "tree-level-only", "plot-level-only", "no-defoliation", "max-defoliation"

# Train (and explain) one model per species with the configuration above
all_results: dict[Species, ExperimentResults] = {}
for species in ALL_SPECIES:
    all_results[species] = train_and_explain(
        species,
        model_type=model_type,
        group_by=group_col,
        ablation=ablation,
    )

# Performance summary of the trained models (displayed as the cell output)
summarize_performance(
    all_results,
    ablation=ablation,
    model_type=model_type,
    group_col=group_col,
)

In [None]:
# Selection thresholds for the bar plot: a feature is shown if its mean |SHAP|
# exceeds `min_shap` for some (species, fold), or if it ranks among the top
# `max_rank` features of some (species, fold).
min_shap = 0.015
max_rank = 3

# Long-format table with one row per (species, fold, feature) holding the mean
# absolute SHAP value of that feature on that fold.
feature_importances = (
    pl.from_dicts(
        [
            {
                "species": species,
                "fold": fold,
                **dict(
                    zip(
                        results.features,
                        np.absolute(results.shap_values[fold].values).mean(axis=0),
                    )
                ),
            }
            for species, results in all_results.items()
            # Iterate over the folds actually trained rather than assuming 5
            for fold in range(results.num_folds)
        ]
    )
    .unpivot(
        on=cs.exclude("species", "fold"),  # type: ignore
        index=["species", "fold"],
        variable_name="feature",
        value_name="shap",
    )
    # Rank features within each (species, fold). Ranking over "species" alone
    # pools all folds together, so a single feature can take several of the top
    # ranks and `rank` no longer means "position among the features".
    .with_columns(rank=pl.col("shap").rank(descending=True).over("species", "fold"))
    .with_columns(
        # Largest mean |SHAP| of the feature over all species and folds
        shap_max=pl.col("shap").max().over("feature"),
        # Best rank reached by the feature over all species and folds
        min_rank=pl.col("rank").min().over("feature"),
    )
)

# Cache the table so other ablations can be compared against this configuration
feature_importances.write_parquet(
    f"./cache/feature_importances-{ablation}-{model_type}-{group_col}.parquet",
)

# Keep the selected features; rows are sorted by species, then by the feature's
# mean rank, which fixes the order of the features on the y axis.
top_importances = feature_importances.filter(
    (pl.col("shap_max") > min_shap) | (pl.col("min_rank") <= max_rank)
).sort("species", pl.col("rank").mean().over("feature"), descending=[False, False])

grid = sns.catplot(top_importances, x="shap", y="feature", hue="species", kind="bar")
grid.set_axis_labels("Mean SHAP absolute value", "Feature")

model_names = {"lgbm": "GBDT", "lasso": "Lasso"}
if model_type in model_names:
    grid.ax.set_title(
        f"Feature importance for {model_names[model_type]} (grouped by {group_col})"
    )

# The figures directory may not exist on a fresh checkout
os.makedirs("./figures", exist_ok=True)
grid.savefig(
    f"./figures/importance-{model_type}-{group_col}-{ablation}.pdf",
    bbox_inches="tight",
)

In [None]:
# Compare this run's feature importances with the cached "no-defoliation"
# ablation (same model type and grouping): for each species, list the `top_n`
# features whose mean |SHAP| grows the most once defoliation features are removed.
top_n = 3

no_defoliation_path = (
    f"./cache/feature_importances-no-defoliation-{model_type}-{group_col}.parquet"
)

if os.path.exists(no_defoliation_path):
    importances_comparison = (
        feature_importances.join(
            pl.read_parquet(no_defoliation_path),
            # Pair each fold with the same fold of the other run; joining on
            # (species, feature) only builds an n_folds x n_folds cross product.
            on=["species", "fold", "feature"],
            suffix="-no-defoliation",
        )
        .group_by("species", "feature")
        .agg(
            # Fold-averaged ranks, rounded to the nearest integer (a bare cast truncates)
            pl.col("rank").mean().round().cast(pl.Int32).alias("rank-all-features"),
            pl.col("rank-no-defoliation")
            .mean()
            .round()
            .cast(pl.Int32)
            .alias("rank-no-defoliation"),
            pl.col("shap").mean().alias("shap-all-features"),
            pl.col("shap-no-defoliation").mean().alias("shap-no-defoliation"),
        )
        .with_columns(
            # Positive deltas: the feature matters more without defoliation features
            shap_delta=pl.col("shap-no-defoliation") - pl.col("shap-all-features"),
            rank_delta=pl.col("rank-all-features") - pl.col("rank-no-defoliation"),
        )
        .filter(
            (
                pl.col("shap_delta").rank("dense", descending=True).over("species")
                <= top_n
            )
            & (pl.col("shap_delta") > 0)
        )
        .sort(["species", "shap_delta"], descending=[False, True])
    )

    # Print as a Markdown table and copy the data to the clipboard for reports
    with pl.Config() as cfg:
        cfg.set_tbl_formatting("ASCII_MARKDOWN")
        cfg.set_float_precision(3)
        cfg.set_tbl_rows(100)
        cfg.set_tbl_hide_column_data_types(True)

        print(importances_comparison)
        importances_comparison.write_clipboard(float_precision=3)
else:
    print(
        f"No cached importances found at {no_defoliation_path}; "
        'run the notebook with ablation = "no-defoliation" first.'
    )

In [None]:
# SHAP bar plot (10 features at most) of each species, for a single fold
fold = 0  # Change this to plot for a different fold

for species, results in all_results.items():
    plt.figure()  # a fresh figure per species
    ax = shap.plots.bar(
        results.shap_values[fold],
        ax=plt.gca(),
        show=False,
        max_display=10,
    )
    ax.set_title(f"Feature importance for {species.capitalize()}")

## Feature dependence plots

These plots show how the Shapley (SHAP) value of a feature varies as a function of that feature's value.

In [None]:
from explain import plot_dependence

# Dependence plots for a few key (species, feature) pairs. Each plot is drawn on
# its own explicitly created axes, and all use the same fold.
fold = 0  # Change this to plot for a different fold


def _label_dependence(ax, title: str, xlabel: str) -> None:
    """Set the title, the x label and the shared "SHAP value" y label of `ax`."""
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("SHAP value")


# Spruce: dependence on lead deposition
fig, ax = plt.subplots()
plot_dependence(all_results["spruce"], feature="dep_pb", fold=fold, ax=ax)
# Optional reference line; if enabled, also call ax.legend() to show its label.
# NOTE(review): the label says mg/l while the axis is in µg/l — confirm units.
# ax.axvline(0.6, color="red", linestyle="--", label="0.6 mg/l threshold")
_label_dependence(
    ax,
    "Spruce: Dependence of growth rate on lead deposition",
    "Lead deposition [µg/l]",
)

# Pine: dependence on iron deposition
fig, ax = plt.subplots()
plot_dependence(all_results["pine"], feature="dep_fe", fold=fold, alpha=0.4, ax=ax)
_label_dependence(
    ax,
    "Pine: Dependence of growth rate on iron deposition",
    "Iron deposition [mg/l]",
)

# Spruce: dependence on plot orientation (codes 0-8, labelled below)
fig, ax = plt.subplots()
plot_dependence(
    all_results["spruce"],
    feature="plot_orientation",
    fold=fold,
    alpha=0.4,
    ax=ax,
)
_label_dependence(
    ax,
    "Spruce: Dependence of growth rate on plot orientation",
    "Plot orientation",
)
ax.set_xticks(np.arange(9))  # 8 orientations + flat
ax.set_xticklabels(
    [
        "North",
        "North-east",
        "East",
        "South-east",
        "South",
        "South-west",
        "West",
        "North-west",
        "Flat",
    ],
    rotation=90,
)  # Adjust labels for better readability
ax.set_xlim(-0.5, 8.5)

# Oak: dependence on nitrate deposition (disabled)
# ax = plot_dependence(
#     all_results["oak"], feature="dep_n_no3", fold=fold, show_interaction=True, alpha=0.4
# )
# ax.set_title("Oak: Dependence of growth rate on deposition of nitrate")
# ax.set_xlabel("Deposition of nitrate [mg/l]")
# ax.set_ylabel("SHAP value")

# Beech: dependence on plot slope
fig, ax = plt.subplots()
plot_dependence(
    all_results["beech"],
    feature="plot_slope",
    fold=fold,
    alpha=0.4,
    ax=ax,
)
_label_dependence(
    ax,
    "Beech: Dependence of growth rate on plot slope",
    "Plot slope",
)

In [None]:
# Dependence on total nitrogen deposition: a 2 x 2 grid, one panel per species.
# `fold` is the value set in the previous cell.
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

for ax, species in zip(axes.flat, ALL_SPECIES):
    plot_dependence(
        all_results[species],
        ax=ax,
        feature="dep_n_tot",
        fold=fold,
        ylim=(-0.05, 0.07),
        alpha=0.4,
    )

In [None]:
num_bins = 20


def _stack_over_folds(results, feature: str, attribute: str) -> np.ndarray:
    """Concatenate one attribute of `feature`'s SHAP explanation over all folds.

    `attribute` is "values" (the SHAP values) or "data" (the feature values the
    explanation was computed on).
    """
    return np.concatenate(
        [
            getattr(results.shap_values[k][:, feature], attribute)  # type: ignore
            for k in range(results.num_folds)
        ]
    )


def _binned_profile(
    x: np.ndarray, y: np.ndarray, n_bins: int, min_count: int = 3
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Bin `x` into `n_bins` equal-width bins and summarize `y` in each bin.

    NaN entries of `x` are dropped first: `np.digitize` would otherwise put them
    in the last bin and turn its mean into NaN. Empty bins are left out of the
    result instead of producing 0/0.

    Returns `(x_mean, y_mean, y_q05, y_q95)` for the non-empty bins; the 5th and
    95th percentiles of `y` are NaN for bins with fewer than `min_count` samples.
    """
    valid = ~np.isnan(x)
    x, y = x[valid], y[valid]
    if x.size == 0:
        empty = np.empty(0)
        return empty, empty, empty, empty

    bins = np.linspace(x.min(), x.max(), n_bins + 1)
    # digitize returns 1..n_bins+1; shift to 0-based and fold x == max into the last bin
    bin_idx = np.clip(np.digitize(x, bins) - 1, 0, n_bins - 1)

    counts = np.bincount(bin_idx, minlength=n_bins)
    occupied = counts > 0
    x_mean = np.bincount(bin_idx, weights=x, minlength=n_bins)[occupied] / counts[occupied]
    y_mean = np.bincount(bin_idx, weights=y, minlength=n_bins)[occupied] / counts[occupied]

    y_q05 = np.full(n_bins, np.nan)
    y_q95 = np.full(n_bins, np.nan)
    for b in np.flatnonzero(counts >= min_count):
        y_q05[b], y_q95[b] = np.percentile(y[bin_idx == b], [5, 95])

    return x_mean, y_mean, y_q05[occupied], y_q95[occupied]


# One plot per species (4 in total)
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

for species, ax in zip(ALL_SPECIES, axes.flatten()):
    results = all_results[species]

    for feature in ["dep_n_nh4", "dep_n_no3", "dep_n_total"]:
        if feature == "dep_n_total":
            # Total N = NH4 + NO3; its contribution is taken as the sum of the
            # two components' SHAP values.
            shap_values = _stack_over_folds(
                results, "dep_n_nh4", "values"
            ) + _stack_over_folds(results, "dep_n_no3", "values")
            feature_values = _stack_over_folds(
                results, "dep_n_nh4", "data"
            ) + _stack_over_folds(results, "dep_n_no3", "data")
        else:
            shap_values = _stack_over_folds(results, feature, "values")
            feature_values = _stack_over_folds(results, feature, "data")

        valid_indices = ~np.isnan(feature_values)  # type: ignore
        if not valid_indices.any():
            continue  # nothing to plot for this feature

        # Mean SHAP value per bin with a 5-95 % band
        x_mean, y_mean, q_low, q_high = _binned_profile(
            feature_values, shap_values, num_bins
        )
        (line,) = ax.plot(x_mean, y_mean, label=feature)
        ax.fill_between(x_mean, q_low, q_high, alpha=0.2, color=line.get_color())

        # Overlaid inset axes for histogram with the same x-axis limits
        ax2 = ax.inset_axes(
            bounds=(0, 0, 1.0, 0.2),
            zorder=0,
            sharex=ax,
            frame_on=False,
        )

        # Remove xticks/yticks from the inset axes
        ax2.tick_params(
            axis="x", which="both", bottom=False, top=False, labelbottom=False
        )
        ax2.tick_params(
            axis="y",
            which="both",
            left=False,
            right=False,
            labelleft=False,
            labelright=False,
        )

        # Overlaid histogram of point density, in the line's color so the three
        # overlapping histograms can be told apart
        sns.histplot(
            x=feature_values[valid_indices],
            legend=False,
            ax=ax2,
            bins=50,
            stat="density",
            color=line.get_color(),
            alpha=0.3,
            edgecolor=None,
        )
        ax2.set_ylabel("")  # histplot's "Density" label would clutter the panel

    ax.set_title(species.capitalize())
    ax.set_xlabel("Deposition of total nitrogen [kg N / ha / year]")
    ax.set_ylabel("SHAP value")
    ax.set_xlim(0, 50)
    ax.set_ylim(-0.05, 0.10)
    # The lines carry labels, but no legend was drawn before
    ax.legend(loc="upper right")

fig.suptitle("Dependence of growth rate on total nitrogen deposition")

plt.tight_layout()

In [None]:
feature = "defoliation_mean"  # Change this to the feature you want to plot
fold = 0  # Fold whose SHAP values are plotted

# One plot per species (4 in total)
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

for species, ax in zip(ALL_SPECIES, axes.flatten()):
    plot_dependence(
        all_results[species],
        feature=feature,
        fold=fold,  # previously defined but never passed
        ax=ax,
        alpha=0.1,
        xlim=(0, 100),
        ylim=(-0.15, 0.15),
    )
    ax.set_title(species.capitalize())
    # NOTE: the label assumes `feature` is mean defoliation; update it with `feature`
    ax.set_xlabel("Mean defoliation [%]")
    ax.set_ylabel("SHAP value")

# Set the figure title once rather than once per panel
fig.suptitle("Dependence of growth rate on mean defoliation")

plt.tight_layout()

In [None]:
from matplotlib.colors import TwoSlopeNorm
import matplotlib.cm as cm

# Median vs. max defoliation for each species, colored by the sum of the SHAP
# values of the two features on one fold.
fold = 0

# One diverging color scale (asymmetric around 0), shared by every panel and by
# the colorbar, so equal colors mean equal SHAP sums across species.
norm = TwoSlopeNorm(vmin=-0.30, vcenter=0, vmax=0.2)
cmap = plt.get_cmap("coolwarm")

# One plot per species (4 in total)
fig, axes = plt.subplots(2, 2, figsize=(12, 9))

for species, ax in zip(ALL_SPECIES, axes.flatten()):
    results = all_results[species]
    # The first element returned by get_data is the feature table
    defoliation_data = results.get_data(fold=fold, split="all")[0]
    median_defoliation = defoliation_data["defoliation_median"].to_numpy()
    max_defoliation = defoliation_data["defoliation_max"].to_numpy()

    # NOTE(review): assumes the rows of shap_values[fold] line up with the
    # split="all" rows returned by get_data — confirm.
    shap_sum = (
        results.shap_values[fold][:, "defoliation_median"].values
        + results.shap_values[fold][:, "defoliation_max"].values
    )

    # Reference y = x line and its label
    ax.plot([0, 100], [0, 100], color="grey", linestyle="--", linewidth=1)
    ax.text(5, 5, "y = x", color="grey", fontsize=10, ha="left", va="top")

    # `norm` must be passed here: without it each panel is colored on its own
    # min/max range and does not match the shared colorbar.
    ax.scatter(
        median_defoliation,
        max_defoliation,
        c=shap_sum,
        cmap=cmap,
        norm=norm,
        alpha=0.8,
        s=15,
    )

    ax.set_xlabel("Median defoliation")
    ax.set_ylabel("Max defoliation")
    ax.set_xlim(0, 100)
    ax.set_ylim(0, 100)

    # Squared Pearson correlation between median and max defoliation; rows with
    # a missing value in either column are ignored (they would make it NaN).
    finite = np.isfinite(median_defoliation) & np.isfinite(max_defoliation)
    r2 = np.corrcoef(median_defoliation[finite], max_defoliation[finite])[0, 1] ** 2

    ax.set_title(f"{species.capitalize()} (R² = {r2:.2f})", fontsize=12)

fig.suptitle(
    f"Relation between median and max defoliation (fold {fold})",
    fontsize=16,
)

# Lay out the 2x2 grid first; the colorbar axes added next sits outside it
plt.tight_layout()

cbar = fig.colorbar(
    cm.ScalarMappable(norm=norm, cmap=cmap),
    # Draw the colorbar on the right side of the plot
    cax=fig.add_axes((1.01, 0.15, 0.02, 0.7)),
    label="Sum of SHAP values of median and max defoliation",
    orientation="vertical",
)

# Species-specific plots

Hereafter we investigate a specific species and feature.

In [None]:
# Species whose results are examined by the species-specific cells below
species: Species = "oak"
results: ExperimentResults = all_results[species]

## Feature interactions

We visualize the interaction between two features by plotting the SHAP values of one feature against its values, with the points colored by the values of the other feature.

In [None]:
feature = "defoliation_mean"  # Feature to plot
interacting = "dep_n_tot"  # Feature whose values color the points
fold = 0  # Fold whose SHAP values are plotted

fig, ax = plt.subplots(figsize=(6, 4))

# Dependence plot of `feature`, colored by the interacting feature
shap.plots.scatter(
    results.shap_values[fold][:, feature],
    color=results.shap_values[fold][:, interacting],
    alpha=0.4,
    ax=ax,
    show=False,
)
ax.set_xlabel("Average defoliation [%]")
ax.set_ylabel("SHAP value")

# Relabel the colorbar attached to the scatter with the interacting feature's units
ax.collections[0].colorbar.set_label("Deposition of nitrogen [mg / l]")

In [None]:
from explain import plot_ceteris_paribus_profile

fold = 0
n_profiles = 5  # Profiles drawn per panel
rng = np.random.default_rng(0)  # Seeded so the sampled rows are reproducible

# `results` and `feature` come from the species-selection and interaction cells above
X, _, y_pred = results.get_data(fold, "test")
y_vec = y_pred.to_numpy()

# Percentile bounds of the predicted growth rate for each panel; 0 and 100 are
# the minimum and maximum predictions.
percentile_ranges = [(0, 5), (20, 35), (70, 80), (95, 100)]

# Plot 4 panels of profiles for the selected feature
fig, axes = plt.subplots(2, 2, figsize=(10, 8))

for ax, (p_low, p_high) in zip(axes.flat, percentile_ranges):
    low, high = np.percentile(y_vec, [p_low, p_high])
    title = f"Growth rates between [{low:.2f}, {high:.2f}]"

    candidates = np.flatnonzero((y_vec >= low) & (y_vec < high))
    if candidates.size == 0:
        ax.set_title(f"{title}: no samples")
        continue

    # Draw distinct rows (np.random.choice samples with replacement by default)
    chosen = rng.choice(
        candidates, size=min(n_profiles, candidates.size), replace=False
    )
    for idx in chosen:
        # The helper's return values are not needed; unpacking them into
        # `y_pred` used to overwrite the test-set predictions loaded above.
        plot_ceteris_paribus_profile(results.estimators[fold], X, idx, feature, ax=ax)
    ax.set_title(title)

plt.tight_layout()

In [None]:
# Cluster the samples by their Shapley-value vectors (t-SNE embedding in the cells below)

In [None]:
# Visualize Shapley values using t-SNE
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE

n_tsne_samples = 20_000  # Rows kept for the t-SNE embedding
seed = 0  # Shared by the subsampling and t-SNE, for reproducibility

# Collect the SHAP values of every fold of every species, with their labels
shap_blocks = []
species_col = []
plot_id_col = []
for species, results in all_results.items():
    for fold in range(results.num_folds):
        fold_shap = results.shap_values[fold].values
        shap_blocks.append(fold_shap)
        species_col.extend([species] * len(fold_shap))
        # NOTE(review): assumes each fold's SHAP rows align with the rows of
        # results.metadata — confirm.
        plot_id_col.extend(results.metadata["plot_id"].to_numpy())

assert len(species_col) == len(plot_id_col), "SHAP rows and plot_id metadata are misaligned"

# NOTE(review): column names come from the last species in the loop; this
# assumes all species share the same feature list.
shap_all = (
    pl.from_numpy(np.concatenate(shap_blocks, axis=0), schema=results.features)
    .with_columns(pl.Series("species", species_col), pl.Series("plot_id", plot_id_col))
    .select("species", "plot_id", pl.exclude("species", "plot_id"))
)
# Downsample (without failing when fewer rows are available)
df_shap = shap_all.sample(n=min(n_tsne_samples, shap_all.height), seed=seed)

# Standardize the SHAP columns only: "plot_id" is an identifier, not a feature
scaler = StandardScaler()
X_shap = np.nan_to_num(
    scaler.fit_transform(df_shap.select(pl.exclude("species", "plot_id")).to_numpy())
)

# Train the t-SNE model
tsne = TSNE(n_components=2, perplexity=100, early_exaggeration=20, random_state=seed)
X_tsne = tsne.fit_transform(X_shap)

In [None]:
# 2-D t-SNE embedding of the SHAP vectors, one color per species
ax = sns.scatterplot(
    data=df_shap.hstack(
        [pl.Series("tsne_x", X_tsne[:, 0]), pl.Series("tsne_y", X_tsne[:, 1])]
    ).to_pandas(),
    x="tsne_x",
    y="tsne_y",
    hue="species",
    palette="muted",
    alpha=0.5,
    s=5,
    legend="brief",
)
ax.set(xlabel="Component 1", ylabel="Component 2")

plt.tight_layout()

In [None]:
# Zoom into one region of the t-SNE embedding and color it two ways:
# by species (left) and by plot_id (right).
x_bounds = [0, 40]
y_bounds = [-40, 0]

# Build the plotting table once and reuse it for both panels
tsne_points = df_shap.with_columns(
    pl.Series("tsne_x", X_tsne[:, 0]), pl.Series("tsne_y", X_tsne[:, 1])
).to_pandas()

fig, axes = plt.subplots(1, 2, figsize=(12, 6))

for ax, col in zip(axes, ["species", "plot_id"]):
    sns.scatterplot(
        data=tsne_points,
        x="tsne_x",
        y="tsne_y",
        hue=col,
        alpha=0.9,
        palette="muted" if col == "species" else "dark",
        legend=False,
        s=10,
        ax=ax,
    )

    ax.set_xlim(x_bounds)
    ax.set_ylim(y_bounds)

    ax.set_title(f"Colored by {col}")
    ax.set_xlabel("Component 1")
    ax.set_ylabel("Component 2")

# Lay out once, after both panels are drawn
plt.tight_layout()

## SHAP interaction values

We can compute the Shapley interaction values using `get_shap_interactions`, which returns a tensor of shape `(# samples, # features, # features)`. Each slice along the first axis is a symmetric matrix of interaction values whose rows each sum to the Shapley value of the corresponding feature. The diagonal entries represent the "main effect" attributed to that feature, whereas the off-diagonal entries represent its pairwise interactions with every other feature.

In [None]:
from explain import compute_interaction_matrix

# The 20 features with the largest mean |SHAP| over all species and folds
top_n_features = (
    feature_importances.group_by("feature")
    .agg(importance=pl.col("shap").mean())
    .sort("importance", descending=True)
    .head(20)["feature"]
    .to_list()
)

interactions = {}
indices = {}

# The figures directory may not exist on a fresh checkout
os.makedirs("figures", exist_ok=True)

# One interaction heatmap per species
for species in ALL_SPECIES:
    results = all_results[species]

    fig, ax = plt.subplots(figsize=(10, 8))

    # Draws the interaction heatmap for the top features on `ax`
    interactions[species], indices[species] = compute_interaction_matrix(
        results, top_n=top_n_features, ax=ax, vmax=0.006
    )

    # Set the title before tight_layout so the layout leaves room for it
    ax.set_title(f"Interactions for {species}")
    fig.tight_layout()

    fig.savefig(f"figures/{species}-interactions-mean.png")

In [None]:
import networkx as nx

# Graph of feature pairs whose mean |interaction| for one species exceeds `cutoff`
species = "oak"
cutoff = 0.002

results = all_results[species]
interactions_matrix = interactions[species]

# Average |interaction| over axis 0 (samples), keep the strict upper triangle so
# each pair appears once, and zero out weak pairs (zero entries are not edges).
# NOTE(review): `nodelist` below requires the matrix side to equal
# len(results.features) — confirm this holds when compute_interaction_matrix
# was called with a top-N feature subset.
interactions_matrix = np.abs(interactions_matrix).mean(axis=0)
adjacency_matrix = np.triu(interactions_matrix, k=1)
adjacency_matrix = np.where(adjacency_matrix < cutoff, 0.0, adjacency_matrix)

G = cast(
    nx.Graph,
    nx.from_numpy_array(
        adjacency_matrix,
        edge_attr="interaction",
        nodelist=results.features,
    ),
)
# Drop features without any interaction above the cutoff
G.remove_nodes_from(list(nx.isolates(G)))

plt.figure(figsize=(12, 12))
pos = nx.circular_layout(G)

# Edge width proportional to the interaction strength (scaled x1000 for visibility)
nx.draw_networkx_edges(
    G,
    pos,
    width=[attrs["interaction"] * 1000 for _, _, attrs in G.edges(data=True)],
)

# Edge labels: interaction strength in scientific notation with 2 decimals
edge_labels = {
    (u, v): f"{attrs['interaction']:.2e}" for u, v, attrs in G.edges(data=True)
}
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)

# Node labels drawn in rounded boxes
nx.draw_networkx_labels(
    G,
    pos,
    font_size=12,
    font_color="black",
    bbox={"facecolor": "lightblue", "boxstyle": "round,pad=0.5,rounding_size=0.5"},
)
plt.title(f"Interaction graph for {species}")