# GLMM Analysis

**Important 🚨:** If only the demo data for the nine most common species is used, the results will naturally differ from the display items in the publication, which are based on all 52 species.


## Setup


In [None]:
import sys

sys.path.insert(0, "../../src")
from imports import *
from datetime import datetime

init_notebook()

from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler

In [None]:
def fx(df, change):
    """Plot the p-value distribution of all models with a given response direction.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``change_temp_all``/``pval_temp`` (used when ``change``
        is "warmer" or "cooler") or ``change_spei_all``/``pval_spei`` (used
        for any other value of ``change``).
    change : str
        Response direction to select, e.g. "warmer", "cooler", "wetter", "drier".

    The histogram (with KDE) is shown directly; nothing is returned.
    """
    if change in ["warmer", "cooler"]:
        values = df.query("change_temp_all == @change")["pval_temp"]
    else:
        values = df.query("change_spei_all == @change")["pval_spei"]
    mean = values.mean()

    # plt.subplots returns (figure, axes) -- in that order
    fig, ax = plt.subplots(figsize=(3, 3))
    sns.histplot(values, kde=True, bins=20, ax=ax)
    # The legend doubles as an annotation of the mean p-value
    ax.legend([f"{change} mean: {mean:.2f}"])
    plt.show()


def get_var_and_val(df_tmp, response_var, var_pval_threshold=0.05, group_threshold=0.6):
    """Return the dominant response pattern of a group and its relative frequency.

    Parameters
    ----------
    df_tmp : pandas.DataFrame
        Must contain the column ``response_<response_var>``.
    response_var : str
        Suffix of the response column, e.g. "spei" or "temp".
    var_pval_threshold : float
        Currently unused (the p-value filter is disabled). It is kept for
        backward compatibility. NOTE: because it is the third positional
        parameter, callers that want to set ``group_threshold`` must pass it
        by keyword.
    group_threshold : float
        Minimum share the most frequent pattern must reach; below it the
        pattern is reported as "ns".

    Returns
    -------
    tuple
        ``(pattern, share)`` where ``share`` is the relative frequency of the
        most frequent pattern (also returned when the pattern is set to "ns").

    Raises
    ------
    IndexError
        If the response column has no non-missing values.
    """
    shares = (
        df_tmp[f"response_{response_var}"]
        .value_counts(normalize=True)
        .sort_values(ascending=False)
    )
    ipattern = shares.index[0]
    ivalue = shares.values[0]
    if ivalue < group_threshold:
        ipattern = "ns"

    return ipattern, ivalue


def plot_pattern_dist(df_in, var_in, dir_patterns=None):
    """Summarise how the runs are distributed over the categories of ``var_in``.

    ``group_size`` is summed per category, converted to an integer percentage
    of the total, and the table is returned sorted by that percentage in
    descending order. The original category is kept in ``change_simple``,
    while ``var_in`` itself is turned into a label of the form
    "<category> (<percentage>%)".

    ``dir_patterns`` is not used; the bar plot that once consumed it is
    disabled, so despite its name this function only returns the table.
    """
    shares = (
        df_in.groupby(var_in)
        .agg(group_size_rel=("group_size", "sum"))
        .reset_index()
    )

    # Relative share in percent, rounded to whole numbers
    total_runs = shares["group_size_rel"].sum()
    shares["group_size_rel"] = (
        (shares["group_size_rel"] / total_runs * 100).round(0).astype(int)
    )
    shares = shares.sort_values("group_size_rel", ascending=True)

    # Keep the plain category and build the annotated label
    shares["change_simple"] = shares[var_in]
    shares[var_in] = (
        shares[var_in] + " (" + shares["group_size_rel"].astype(str) + "%)"
    )

    # Largest share first
    return shares.sort_values("group_size_rel", ascending=False).reset_index(
        drop=True
    )


def plot_bars_dataset_pattern(
    patterns_merged,
    all_dfs,
    all_or_top9="all",
    color_temp="#77422C",
    color_spei="#D1A289",
    color_rest="lightgrey",
    color_wd="#B2182B",  # Original: "#B2182B"
    color_ww="#2166AC",  # Original: "#2166AC"
    color_other="lightgrey",
    color_cd="lightgrey",  # Original: "#EF8A62",
    color_cw="lightgrey",  # Original: "#67A9CF",
    ytick_labels=None,
    left_ylim=60,
    base_fontsize=12,
    filepath=None,
):
    """Draw a two-panel figure: dataset panel (A) and response-pattern bars (B).

    Parameters
    ----------
    patterns_merged : pandas.DataFrame
        Needs the columns ``group_size_rel``, ``change_simple`` and
        ``perc_sign`` and an index labelled 0..n-1 (rows are accessed with
        ``.loc[i, ...]``). The palette, separator lines, text positions and
        y-limits below are laid out for 11 rows -- presumably 3 temperature,
        3 CWB and 5 combined patterns; verify against the caller.
    all_dfs : pandas.DataFrame
        Passed to ``ax_dataset_boxplot``; all columns from the third one
        onwards are handed over as the variables to plot.
    all_or_top9 : str
        "all" or "top9"; selects the hard-coded positions of the temperature
        and CWB entries. Any other value leaves ``pos_temp``/``pos_spei``
        unassigned, so the ``ax_dataset_boxplot`` call fails.
    color_* : str
        Colours for the individual bar groups.
    ytick_labels : list or None
        Optional replacement labels for the y-axis of panel B.
    left_ylim : float
        Upper x-limit of panel A.
    base_fontsize : float
        Reference font size; other sizes are scaled from it.
    filepath : str or None
        If given, the figure is saved there (600 dpi) before being shown.
    """
    # Plot
    fig, axs = plt.subplots(1, 2, figsize=(16, 5))
    axs = axs.flatten()

    # ! Position of temp and cwb in the dataframe
    if all_or_top9 == "top9":
        pos_temp = 5
        pos_spei = 6
    elif all_or_top9 == "all":
        pos_temp = 5
        pos_spei = 3

    # Panel A: delegated to the project helper ax_dataset_boxplot
    ax_dataset_boxplot(
        axs[0],
        all_dfs,
        all_dfs.columns[2:].tolist(),
        base_fontsize,
        pos_spei=pos_spei,
        pos_temp=pos_temp,
        color_spei=color_spei,
        color_temp=color_temp,
        color_rest=color_rest,
        all_or_top9=all_or_top9,
    )
    axs[0].set_xlim(0, left_ylim)

    # ax_dataset_boxplot(
    #     axs[0],
    #     all_dfs,
    #     imps,
    #     base_fontsize,
    #     color_spei=color_spei,
    #     color_temp=color_temp,
    #     color_rest=color_rest,
    #     all_or_top9=all_or_top9,
    # )

    # Barplot for patterns
    # The palette is positional: one colour per row of patterns_merged, in row order
    sns.barplot(
        data=patterns_merged,
        x="group_size_rel",
        y="change_simple",
        hue="change_simple",
        palette=[
            color_temp,
            color_temp,
            color_temp,
            color_spei,
            color_spei,
            color_spei,
            color_wd,
            color_ww,
            color_other,
            color_cd,
            color_cw,
        ],
        orient="h",
        height=0.5,
        dodge=False,
        edgecolor="black",
        # hue="response_spei",
        # palette=[
        #     "#B2182B",
        #     "#2166AC",
        #     "grey",
        # ],
        ax=axs[1],
    )

    # Add values to end of bars
    # (share of models plus, in brackets, the rounded perc_sign value)
    for i in range(len(patterns_merged)):
        axs[1].text(
            patterns_merged.loc[i, "group_size_rel"] + 1,
            i + -0.1,
            # f"{patterns_merged.loc[i, 'group_size_rel']} %",
            f"{patterns_merged.loc[i, 'group_size_rel']} % (sign: {patterns_merged.loc[i, 'perc_sign'].round(0).astype(int)}%)",
            va="center",
            fontsize=base_fontsize * 0.9,
        )

    # Add horizontal lines
    # (separators after the 3rd and the 6th bar)
    axs[1].axhline(2.35, color="black", linewidth=1)
    axs[1].axhline(5.35, color="black", linewidth=1)

    # Add text
    # (section titles, right-aligned at x = 95)
    axs[1].text(
        95,
        2,
        "Temperature\nanomaly",
        ha="right",
        fontweight="bold",
        fontsize=base_fontsize * 1,
    )
    axs[1].text(
        95,
        5,
        "CWB anomaly",
        ha="right",
        fontweight="bold",
        fontsize=base_fontsize * 1,
    )
    axs[1].text(
        95,
        10,
        "Combined",
        ha="right",
        fontweight="bold",
        fontsize=base_fontsize * 1,
    )

    # Add labels
    axs[1].set_xlabel(
        "Model frequency (%)",
        fontweight="bold",
        labelpad=10,
        fontsize=base_fontsize * 1.2,
    )
    axs[1].set_ylabel(
        # "Climatic conditions before 2$^{\\text{nd}}$ visit",
        "Short-term climatic condition\npromoting mortality",
        labelpad=10,
        fontweight="bold",
        fontsize=base_fontsize * 1.2,
    )

    # Fix y-ticks
    # NOTE(review): the labels are set before the tick positions -- confirm the
    # labels survive the subsequent set_yticks call with the installed matplotlib
    if ytick_labels is not None:
        axs[1].set_yticklabels(ytick_labels, fontsize=base_fontsize * 1)

    axs[1].set_yticks(range(len(patterns_merged["change_simple"])))
    axs[1].tick_params(axis="y", which="both", left=False)

    # Fix axis limits
    # (y-axis is inverted so that the first row is drawn at the top)
    axs[1].set_xlim(0, 100)
    axs[1].set_ylim(10.4, -0.4)

    # Remove top and right axis
    axs[1].spines["top"].set_visible(False)
    axs[1].spines["right"].set_visible(False)

    # Add letters
    # (panel labels, placed left of each axes in axes coordinates)
    letters = ["A", "B"]
    for i, ax in enumerate(axs):
        ax.text(
            -0.5,
            0.99,
            letters[i],
            transform=ax.transAxes,
            fontsize=base_fontsize * 1.3,
            fontweight="bold",
        )

    # Fix layout
    # Fix x-tick size
    axs[0].tick_params(axis="x", which="both", labelsize=base_fontsize * 0.8)
    axs[1].tick_params(axis="x", which="both", labelsize=base_fontsize * 0.8)
    plt.tight_layout(w_pad=2, h_pad=1)

    # Save (optional), show, and release the figure
    if filepath is not None:
        plt.savefig(filepath, dpi=600, bbox_inches="tight")
    plt.show()
    plt.close()


def get_ns_per_pattern(df_in, pattern):
    """Return the share of non-significant ("ns") runs among runs with ``pattern``.

    Parameters
    ----------
    df_in : pandas.DataFrame
        Per-model table. Depending on ``pattern`` it must contain
        ``change_temp_all``/``change_temp`` ("warmer", "cooler"),
        ``change_spei_all``/``change_spei`` ("wetter", "drier") or
        ``change_both_all``/``change_both_sign`` (any other pattern).
    pattern : str
        Response pattern, matched against the corresponding ``*_all`` column.

    Returns
    -------
    float
        Fraction of the selected runs whose significance-aware label is "ns".
        ``0.0`` is returned when none of the selected runs is "ns" (this
        previously raised a ``KeyError``) or when no run matches ``pattern``.
    """
    if pattern in ("warmer", "cooler"):
        var_all = "change_temp_all"
        var_sign = "change_temp"
    elif pattern in ("wetter", "drier"):
        var_all = "change_spei_all"
        var_sign = "change_spei"
    else:
        var_all = "change_both_all"
        var_sign = "change_both_sign"

    # Remove NAs in the pattern column to avoid confusing "ns" meaning "variable
    # was not contained in the model" with "ns" meaning "p-value not significant"
    df_in = df_in.dropna(subset=[var_all])

    # A boolean mask avoids building a query string from `pattern`
    shares = df_in.loc[df_in[var_all] == pattern, var_sign].value_counts(
        normalize=True
    )

    # "ns" is absent from the index when every selected run is significant
    return shares.get("ns", 0.0)


def ax_performance(ax, roc_or_auc, all_or_top9):
    """Draw per-species boxplots comparing RF and GLMM test performance on ``ax``.

    The function reads the notebook-level variables ``i_success``,
    ``path_prefix``, ``path_suffix``, ``top9`` and ``roc_threshold``; they must
    be defined before it is called.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    roc_or_auc : str
        "roc" (reads ``roc_auc.csv``) or "pr" (reads ``pr_auc.csv``).
    all_or_top9 : str
        "top9" restricts the plot to the species in ``top9``; any other value
        keeps all species.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that was passed in.

    Raises
    ------
    ValueError
        If ``roc_or_auc`` is invalid, if an RF file is missing for a run that
        has a GLMM file, or if no performance file could be loaded at all.
    """

    if roc_or_auc == "roc":
        metric = "ROC AUC"
        file_name = "roc_auc"
    elif roc_or_auc == "pr":
        metric = "PR AUC"
        file_name = "pr_auc"
    else:
        raise ValueError("roc_or_auc must be 'roc' or 'pr'")

    df_comp_perf = []

    for i, row in i_success.iterrows():

        # Check if roc auc files are there
        path_rf = (
            f"{path_prefix}/{row.model}/{row.species}/rf_performance/{file_name}.csv"
        )
        path_glmm = (
            f"{path_prefix}/{row.model}/{row.species}/{path_suffix}/{file_name}.csv"
        )

        # A missing GLMM file is tolerated, a missing RF file is an error
        if not os.path.isfile(path_glmm):
            print(f"GLMM file missing for {row.species} - {row.model}")
            continue
        if not os.path.isfile(path_rf):
            raise ValueError(f"RF file missing: {path_rf}")

        # Attach files
        irf = pd.read_csv(path_rf)
        irf["species"] = row.species
        irf["model"] = "rf"
        irf["run"] = row.model

        iglmm = pd.read_csv(path_glmm)
        iglmm["species"] = row.species
        iglmm["model"] = "glmm"
        iglmm["run"] = row.model

        df_comp_perf.append(irf)
        df_comp_perf.append(iglmm)
        display(path_glmm)

    # Fail with a clear message instead of pandas' "No objects to concatenate"
    if not df_comp_perf:
        raise ValueError(
            f"No '{file_name}.csv' performance files found for any successful model"
        )

    # Get data
    df_plot = pd.concat(df_comp_perf)

    if all_or_top9 == "top9":
        df_plot = df_plot[df_plot["species"].isin(top9)]

    # The performance threshold is only applied to the ROC panel
    if roc_or_auc == "roc":
        df_plot = df_plot[df_plot["test_mean"] > roc_threshold]

    # Work on an explicit copy: the frame may be a filtered slice, and the
    # column assignment below would otherwise trigger SettingWithCopyWarning
    df_plot = df_plot.copy()

    # Get mean and std for each model
    metric_table = df_plot.groupby(["model"]).agg(
        {"test_mean": "mean", "test_sd": "mean"}
    )
    metric_rf = metric_table.loc["rf"]
    metric_rf = f"RF: {metric_rf.test_mean:.2f} ± {metric_rf.test_sd:.2f}"
    metric_glmm = metric_table.loc["glmm"]
    metric_glmm = f"GLMM: {metric_glmm.test_mean:.2f} ± {metric_glmm.test_sd:.2f}"

    # Replace model names with metrics (they become the legend entries)
    df_plot["model"] = df_plot["model"].replace({"rf": metric_rf, "glmm": metric_glmm})

    sns.boxplot(
        data=df_plot,
        y="species",
        x="test_mean",
        hue="model",  # Different colors for each model
        palette="Set2",  # Change color palette as desired
        ax=ax,
        # more space between species
        width=0.7,
    )

    # Add titles and labels to the given ax
    ax.set_xlabel(f"{metric} on Test Set", fontsize=10, fontweight="bold")
    if roc_or_auc == "roc":
        ax.set_ylabel("Species", fontsize=10, fontweight="bold")
    else:
        # The PR panel carries no y-label and has its ticks on the right side
        ax.set_ylabel("")
        ax.yaxis.tick_right()

    # Rotate x-axis labels
    ax.tick_params(axis="x", labelrotation=0, labelsize=10)

    # Show legend on the ax
    pos = "upper right"  # if roc_or_auc == "pr" else "upper left"
    ax.legend(title=None, loc=pos, fontsize=8, handlelength=0.8)

    return ax

## Settings


In [None]:
# Date-stamped output folder for this analysis run (created if it does not exist)
from datetime import datetime

dir_today = f"./model_analysis/{datetime.now():%Y-%m-%d}/glmm_analysis"
os.makedirs(dir_today, exist_ok=True)
dir_today

In [None]:
# Settings
all_or_top9 = "all"
pval_threshold = 0.05  # For defining significant response
roc_threshold = 0.6  # For defining successful models
min_group_percentage = 0.6  # For aggregating response of same feature
ns_for_insignificant = False  # Wording for insignificant p-values

# Get paths: use the first existing candidate; if none exists, the error
# message names the last candidate that was tried
for path_prefix in ("./model_runs/all_runs", "/Volumes/SAMSUNG 1TB/all_runs"):
    if os.path.exists(path_prefix):
        break
else:
    raise ValueError(f"Path '{path_prefix}' does not exist. Please check the path.")
path_suffix = "glmm/"

print(f"GLMM path: '{path_prefix}/RUN/SPECIES/{path_suffix}'")

## List of Species and Models


In [None]:
# Species to analyse; `top9` is simply the first nine entries of that list
# (presumably the most common species -- verify against get_species_with_models)
final_species = get_species_with_models("list")
top9 = final_species[:9]

# Model run folders: every non-hidden entry in the run directory, sorted by name
base_dir = path_prefix
models_dir = sorted(m for m in os.listdir(base_dir) if not m.startswith("."))

# Full grid of (model run, species) combinations, one row each
models_species = list(itertools.product(models_dir, final_species))
df_in = pd.DataFrame(models_species, columns=["model", "species"])
df_in

## Single Runs


In [None]:
# ispecies = "Abies grandis"
# imodel = "run_51"

# glmm_run_per_species_and_model(
#     ispecies,
#     imodel,
#     verbose=True,
#     path_prefix=path_prefix,
#     path_suffix=path_suffix,
#     return_all=False,
#     skip_if_exists=False,
# )

# ! Loop
# for i, row in tqdm(df_in.reset_index(drop=True).iterrows(), total=len(df_in)):

#     glmm_run_per_species_and_model(
#         ispecies=row.species,
#         imodel=row.model,
#         path_prefix=path_prefix,
#         path_suffix=path_suffix,
#         return_all=False,
#         verbose=False,
#         skip_if_exists=True,
#     )
#     # clear_output()

## All runs


In [None]:
# from random_forest_utils import glmm_wrapper_loop

# # Run glmm fitting in parallel:
# out = run_mp(
#     glmm_wrapper_loop,
#     arg_list=split_df_into_list_of_group_or_ns(df_in, "model"),
#     num_cores=10,
#     progress_bar=True,
#     verbose=False,
#     path_prefix=path_prefix,
#     path_suffix=path_suffix,
#     return_all=False,
#     skip_if_exists=True,
# )

In [None]:
# ! osascript -e 'tell app "System Events" to shut down'

## Check GLMM Runs


In [None]:
# Split all (model, species) combinations into runs with and without a GLMM.
# The summary file is checked because it is the last file saved per run.
_summary_found = np.array(
    [
        os.path.isfile(
            f"{path_prefix}/{row.model}/{row.species}/{path_suffix}/summary.csv"
        )
        for row in df_in.itertuples(index=False)
    ],
    dtype=bool,
)

# Both results are plain DataFrames with the columns "model" and "species" and
# a fresh 0..n-1 index, even when one of them is empty. (Previously the rows
# were glued together with pd.concat, which failed when no run was successful.)
_all_runs = df_in[["model", "species"]].reset_index(drop=True)
i_success = _all_runs[_summary_found].reset_index(drop=True)
i_missing = _all_runs[~_summary_found].reset_index(drop=True)

if len(i_missing) > 0:
    print(f"Missing {len(i_missing)} models")
    print(
        f"Missing models: {i_missing.shape[0]} from {i_missing.model.nunique()} seeds and {i_missing.species.nunique()} species"
    )
else:
    print("All models are present")

print(
    f"Successful models: {i_success.shape[0]} from {i_success.model.nunique()} seeds and {i_success.species.nunique()} species"
)

In [None]:
# # Loop over missing models to see which cause errors
# verbose = True
# for i, row in i_missing.reset_index(drop=True).iterrows():
#     if verbose:
#         print(f"🟡 Species: {row.species}\t | Model: {row.model}")

#     if "61" in row.model:
#         print("Skipping: Alnus incana 61")
#         continue

#     glmm_wrapper(
#         ispecies=row.species,
#         imodel=row.model,
#         base_dir=path_prefix,
#         return_all=False,
#         verbose=False,
#     )
#     clear_output()
#     if verbose:
#         print(f"🟡 Species: {row.species}\t | Model: {row.model}")

### Calculate Model Performance


In [None]:
# Collect every run that has GLMM training predictions on disk
df_available = pd.DataFrame(
    glob.glob(f"{path_prefix}/run_*/*/glmm/y_train_pred.csv"), columns=["file"]
)

# Decode species and run from the path: .../<run>/<species>/glmm/y_train_pred.csv
# (assumes "/" as path separator)
_path_parts = df_available["file"].str.split("/")
df_available["species"] = _path_parts.str[-3]
df_available["model"] = _path_parts.str[-4]
df_available["base_dir"] = df_available["file"].str.split("/run").str[0] + "/"

# Report what was found
print(" --- The following species do not have their 50 seed runs yet: ---")
print(f" - Species found: {df_available['species'].nunique()}")
print(f" - Seeds found: {df_available['model'].nunique()}")

# Keep only the runs whose classification metrics have not been written yet
df_todo = pd.DataFrame(
    [
        row
        for _, row in df_available.iterrows()
        if not os.path.isfile(
            row.file.replace(
                "y_train_pred.csv",
                "classification_metrics_fixed_threshold.csv",
            )
        )
    ]
)

In [None]:
# Compute the performance metrics for all runs that still lack them
from random_forest_utils import calculate_glmm_performance

if df_todo.shape[0] > 0:
    # Distribute the outstanding runs over 10 worker processes
    run_mp(
        calculate_glmm_performance,
        split_df_into_list_of_group_or_ns(df_todo, 10),
        skip_if_csv_exists=False,
        progress_bar=True,
        num_cores=10,
    )
else:
    print("✅ All model have been run!")

In [None]:
# Locate the per-run metric files of all GLMMs.
# NOTE: glob is called without recursive=True, so "**" behaves like "*" and
# matches exactly one directory level (the species folder).
roc_paths = glob.glob(f"{path_prefix}/run_*/**/glmm/roc_auc.csv")
pr_paths = glob.glob(f"{path_prefix}/run_*/**/glmm/pr_auc.csv")
clf_paths = glob.glob(
    f"{path_prefix}/run_*/**/glmm/classification_metrics_fixed_threshold.csv"
)

# Build the summary tables for both statistics, "mean" first and "sd" second
df_mean, df_sd = (
    get_metrics_for_all_models_and_species(
        stat,
        roc_threshold,
        roc_paths,
        pr_paths,
        clf_paths,
    )
    for stat in ("mean", "sd")
)

# Persist both tables in today's analysis folder
df_mean.to_csv(f"{dir_today}/glmm_model_performance_summary.csv", index=False)
df_sd.to_csv(f"{dir_today}/glmm_model_performance_summary_std.csv", index=False)

In [None]:
# Compare the mean train/test performance of the final RF and GLMM models.
# NOTE: this cell rebinds `df_mean` and `df_sd` to the "Mean"/"SD" rows of the
# combined RF + GLMM table.
from matplotlib.lines import Line2D

# ! Comparing performances of final models across different analyses
# Load and compare final model performances
# The RF summary is taken from the last analysis folder in sorted order
_rf_summary_files = sorted(
    glob.glob(
        "./model_analysis/*/pattern_analysis/by_mk/roc_0.6-min_group_share_0.6/tables/model_performance_summary.csv"
    )
)
if not _rf_summary_files:
    # Previously this surfaced as an opaque IndexError from `[-1]`
    raise FileNotFoundError(
        "No RF model_performance_summary.csv found under ./model_analysis/*/pattern_analysis/"
    )

files = [
    _rf_summary_files[-1],
    f"{dir_today}/glmm_model_performance_summary.csv",
]
analysis_type = ["RF", "GLMM"]

files = pd.DataFrame(
    {
        "file": files,
        "analysis_type": analysis_type,
    }
)

df = []
for i, row in files.iterrows():
    tmp = pd.read_csv(row["file"])
    # Copy the filtered rows so that adding a column does not operate on a slice
    tmp = tmp.query("Species in ['Mean', 'SD']").copy()
    tmp["analysis_type"] = row["analysis_type"]
    # Harmonise the metric column names
    tmp = tmp.rename(
        columns={
            "Train Roc Auc": "Train ROC-AUC",
            "Test Roc Auc": "Test ROC-AUC",
            "Train Pr Auc": "Train PR-AUC",
            "Test Pr Auc": "Test PR-AUC",
        }
    )
    df.append(tmp)

df = pd.concat(df, ignore_index=True)

# Filter only 'Mean' rows
df_mean = df[df["Species"] == "Mean"]

# Metrics and styles
metrics = ["ROC-AUC", "PR-AUC", "Precision", "Recall", "F1"]
linestyles = {"Train": "--", "Test": "-"}
colors = {
    "ROC-AUC": "tab:blue",
    "PR-AUC": "tab:orange",
    "Precision": "tab:green",
    "Recall": "tab:red",
    "F1": "tab:purple",
}
markers = {
    "ROC-AUC": "o",
    "PR-AUC": "s",
    "Precision": "D",
    "Recall": "^",
    "F1": "v",
}

# Melt to long format
train_cols = [f"Train {m}" for m in metrics]
test_cols = [f"Test {m}" for m in metrics]

df_train = df_mean[["analysis_type"] + train_cols].melt(
    id_vars="analysis_type", var_name="Metric", value_name="Score"
)
df_train["Set"] = "Train"
df_train["Metric"] = df_train["Metric"].str.replace("Train ", "")

df_test = df_mean[["analysis_type"] + test_cols].melt(
    id_vars="analysis_type", var_name="Metric", value_name="Score"
)
df_test["Set"] = "Test"
df_test["Metric"] = df_test["Metric"].str.replace("Test ", "")

df_long = pd.concat([df_train, df_test])

# Melt SD data
df_sd = df[df["Species"] == "SD"]

df_train_sd = df_sd[["analysis_type"] + train_cols].melt(
    id_vars="analysis_type", var_name="Metric", value_name="SD"
)
df_train_sd["Set"] = "Train"
df_train_sd["Metric"] = df_train_sd["Metric"].str.replace("Train ", "")

df_test_sd = df_sd[["analysis_type"] + test_cols].melt(
    id_vars="analysis_type", var_name="Metric", value_name="SD"
)
df_test_sd["Set"] = "Test"
df_test_sd["Metric"] = df_test_sd["Metric"].str.replace("Test ", "")

df_sd_long = pd.concat([df_train_sd, df_test_sd])

# Merge SDs with main data
df_long = pd.merge(
    df_long, df_sd_long, on=["analysis_type", "Metric", "Set"], how="left"
)

# Plot one line per metric and data set (SDs are merged but currently not drawn)
fig, ax = plt.subplots(figsize=(5, 5))
for metric in metrics:
    for set_type in ["Train", "Test"]:
        subset = df_long[(df_long["Metric"] == metric) & (df_long["Set"] == set_type)]
        ax.plot(
            subset["analysis_type"],
            subset["Score"],
            color=colors[metric],
            linestyle=linestyles[set_type],
            marker=markers[metric],
            label=f"{metric} ({set_type})",  # Not used for legend directly
            markersize=8,
        )
        # # With Error Bars
        # ax.errorbar(
        #     subset["analysis_type"],
        #     subset["Score"],
        #     yerr=subset["SD"],
        #     color=colors[metric],
        #     linestyle=linestyles[set_type],
        #     marker=markers[metric],
        #     markersize=8,
        #     capsize=4,
        # )

# Labels and style
ax.set_title("Train vs Test Performance between RF and GLMM Models")
ax.set_xlabel("Analysis Type")
ax.set_ylabel("Score")
ax.set_ylim(0, 1.05)
ax.grid(True)
plt.tight_layout()

# Custom legend: marker/colour encodes the metric, line style the data set
metric_handles = [
    Line2D(
        [0],
        [0],
        color=colors[m],
        marker=markers[m],
        linestyle="",
        markersize=8,
        label=m,
    )
    for m in metrics
]

set_handles = [
    Line2D([0], [0], color="gray", linestyle=ls, lw=2, label=st)
    for st, ls in linestyles.items()
]

# The legend sits outside the axes; bbox_inches="tight" keeps it in the file
ax.legend(
    handles=metric_handles + set_handles,
    title="",
    loc="center left",
    bbox_to_anchor=(1.01, 0.5),
    frameon=False,
    borderaxespad=0,
)

plt.savefig(
    f"{dir_today}/glmm_rf_model_performance_comparison.png",
    dpi=300,
    bbox_inches="tight",
)

plt.show()

## Compare RF and GLMM performance


In [None]:
# # Create figure
# fig, axs = plt.subplots(1, 2, figsize=(10, 10))
# axs = axs.flatten()
# ax_performance(axs[0], "roc", "all")
# ax_performance(axs[1], "pr", "all")
# fig.tight_layout()
# save_path = f"{dir_today}/performance_comparison-all_species.png"
# fig.savefig(save_path, dpi=300, bbox_inches="tight")

In [None]:
# # Create figure
# fig, axs = plt.subplots(1, 2, figsize=(10, 5))
# axs = axs.flatten()
# ax_performance(axs[0], "roc", "top9")
# ax_performance(axs[1], "pr", "top9")
# fig.tight_layout()
# # save_path = f"{dir_today}/performance_comparison-top9_species.png"
# # fig.savefig(save_path, dpi=300, bbox_inches="tight")
# plt.show()

## Variable importance comparison


In [None]:
# Gather the GLMM coefficient tables of all species/run combinations

df_all = []
for ispecies, imodel in itertools.product(final_species, models_dir):

    path = f"{path_prefix}/{imodel}/{ispecies}/{path_suffix}/summary.csv"

    # Runs without a summary file are reported and skipped
    if not os.path.isfile(path):
        print(f"Missing: {ispecies} - {imodel}")
        continue

    idf = pd.read_csv(
        path,
        index_col=0,
    )[["Estimate", "P-val", "2.5_ci", "97.5_ci", "SE"]]

    # Map every variable (the row index) to its dataset category
    idf["dataset"] = idf.index.map(get_category_from_var_wrapper)
    idf = idf.reset_index().rename({"index": "variable"}, axis=1)

    # Identify the run and put the identifier columns first
    idf["species"] = ispecies
    idf["model"] = imodel
    idf = move_vars_to_front(idf, ["species", "model", "dataset", "variable"])
    df_all.append(idf)


# One long table, reduced to the columns needed for the importance proxy
df_vimp = pd.concat(df_all)[["species", "model", "dataset", "Estimate"]]

# The absolute estimate serves as the importance proxy
df_vimp["Estimate"] = df_vimp["Estimate"].abs() * 100

# Average per model-species-dataset combination (e.g. there are two temperature
# variables when linear and quadratic terms are included), spread the datasets
# into columns, and remove the intercept
df_vimp = (
    df_vimp.groupby(["species", "model", "dataset"])
    .mean()
    .reset_index()
    .pivot(index=["species", "model"], columns="dataset", values="Estimate")
    .reset_index()
    .drop(columns="Intercept")
)

# Express every dataset as a percentage of its row total
df_vimp.iloc[:, 2:] = (
    df_vimp.iloc[:, 2:].div(df_vimp.iloc[:, 2:].sum(axis=1), axis=0) * 100
)

# Display
df_vimp.head(5)

In [None]:
# Plot the relative importance per dataset category (columns 3+ of df_vimp)
fig, ax = plt.subplots(figsize=(10, 5))
# NOTE(review): pos_spei/pos_temp and all_or_top9="top9" are hard-coded here
# although df_vimp is built from all species in `final_species` -- confirm
# that these positions match the column order of df_vimp.
ax_dataset_boxplot(
    ax,
    df_vimp,
    df_vimp.columns[2:].tolist(),
    base_fontsize=11,
    pos_spei=4,
    pos_temp=5,
    all_or_top9="top9",
)
# Values are percentages of the row total, hence the 0-100 range
ax.set_xlim(0, 100)

## Investigate climate effect


### At model-level


In [None]:
def _climate_effect(summary, var, label_positive, label_negative):
    """Read estimate, significance flag, sign, p-value and response label of ``var``.

    Returns five ``None`` values when the model has no such variable. The label
    is ``label_positive`` for a significant estimate > 0, ``label_negative``
    for any other significant estimate, and "ns" when the p-value is not below
    the notebook-level ``pval_threshold``.
    """
    if var is None:
        return None, None, None, None, None

    entry = summary.loc[[var]]
    estimate = entry["Estimate"].values[0]
    significance = entry["Sig"].values[0]
    direction = np.sign(estimate)
    p_value = entry["P-val"].values[0]

    if p_value < pval_threshold:
        change = label_positive if direction > 0 else label_negative
    else:
        change = "ns"

    return estimate, significance, direction, p_value, change


df_speitemp_list = []

for i, row in i_success.iterrows():

    # Both the coefficient summary and the ROC-AUC file are required
    path_glmm = f"{path_prefix}/{row.model}/{row.species}/{path_suffix}/summary.csv"
    path_perf = f"{path_prefix}/{row.model}/{row.species}/{path_suffix}/roc_auc.csv"

    if not (os.path.isfile(path_glmm) and os.path.isfile(path_perf)):
        print(f" - 🚨 GLMM files incomplete for {row.species} - {row.model}")
        continue

    # Load data
    iper = pd.read_csv(path_perf)
    isry = pd.read_csv(path_glmm, index_col=0)
    var_spei = glmm_get_spei_var(isry.index)
    var_temp = glmm_get_temp_var(isry.index)

    # Classify the response to the SPEI and the temperature variable
    esti_spei, sign_spei, dire_spei, pval_spei, change_spei = _climate_effect(
        isry, var_spei, "wetter", "drier"
    )
    esti_temp, sign_temp, dire_temp, pval_temp, change_temp = _climate_effect(
        isry, var_temp, "warmer", "cooler"
    )

    # Attach temp and spei information
    # (the run id is the part after the first "_" of the run folder name)
    df_speitemp_list.append(
        pd.DataFrame(
            {
                "species": row.species,
                "run": row.model.split(" -")[0].split("_")[1],
                "spei": var_spei,
                "sign_spei": sign_spei,
                "dire_spei": dire_spei,
                "pval_spei": pval_spei,
                "change_spei": change_spei,
                "temp": var_temp,
                "sign_temp": sign_temp,
                "dire_temp": dire_temp,
                "pval_temp": pval_temp,
                "change_temp": change_temp,
                "test_roc_auc": iper["test_mean"].values[0],
            },
            index=[0],
        )
    )

# Lookup tables from estimate sign to response label
# NOTE(review): dict_spei maps a positive estimate to "drier", whereas the
# significance-aware label assigned in the loop above maps a positive estimate
# to "wetter". One of the two conventions is probably inverted -- confirm
# which one is intended before interpreting the SPEI patterns.
dict_spei = {
    +1: "drier",
    -1: "wetter",
}

dict_temp = {
    +1: "warmer",
    -1: "cooler",
}

# A combined pattern counts as "ns" as soon as one of its two parts is "ns"
dict_ns = {
    "ns_ns": "ns",
    "ns_wetter": "ns",
    "ns_drier": "ns",
    "warmer_ns": "ns",
    "cooler_ns": "ns",
    "warmer_wetter": "warmer_wetter",
    "warmer_drier": "warmer_drier",
    "cooler_wetter": "cooler_wetter",
    "cooler_drier": "cooler_drier",
}

In [None]:
# Combine the per-model records and derive the response labels.
# "*_all" labels depend on the estimate sign only (via dict_spei/dict_temp),
# the labels without suffix additionally require p < pval_threshold.
df_pattern_per_model = pd.concat(df_speitemp_list)
df_pattern_per_model["change_spei_all"] = df_pattern_per_model["dire_spei"].map(
    dict_spei
)
df_pattern_per_model["change_temp_all"] = df_pattern_per_model["dire_temp"].map(
    dict_temp
)
df_pattern_per_model["change_both_all"] = (
    df_pattern_per_model["change_temp_all"]
    + "_"
    + df_pattern_per_model["change_spei_all"]
)
df_pattern_per_model["change_both_sign"] = (
    df_pattern_per_model["change_temp"] + "_" + df_pattern_per_model["change_spei"]
)
# Collapse every combination containing an "ns" part to plain "ns"
df_pattern_per_model["change_both_sign"] = df_pattern_per_model["change_both_sign"].map(
    dict_ns
)


# Print the pattern frequencies for all species and for the top-9 subset.
# The loop variable is deliberately NOT called `all_or_top9`: that name holds
# the global setting defined in the Settings cell and must not be overwritten.
for subset_name in ["all", "top9"]:
    if subset_name == "top9":
        text = "Top 9 species"
        df_tmp = df_pattern_per_model.query("species in @top9").copy()
    else:
        df_tmp = df_pattern_per_model.copy()
        text = "All species"

    display(f"----- 🚨 {text} ------")

    print("--- Change patterns all ---")
    display(df_tmp.change_temp_all.value_counts(normalize=True).sort_values().round(2))
    display(df_tmp.change_spei_all.value_counts(normalize=True).sort_values().round(2))
    display(df_tmp.change_both_all.value_counts(normalize=True).sort_values().round(2))

    print(f"--- Change patterns significant at p = {pval_threshold}---")
    display(df_tmp.change_temp.value_counts(normalize=True).sort_values().round(2))
    display(df_tmp.change_spei.value_counts(normalize=True).sort_values().round(2))
    print("- Both with ns")
    display(df_tmp.change_both_sign.value_counts(normalize=True).sort_values().round(2))
    print("- Both without ns")
    display(
        df_tmp.query("change_both_sign != 'ns'")
        .change_both_sign.value_counts(normalize=True)
        .sort_values()
        .round(2)
    )

    # fx(df_tmp, "wetter")
    # fx(df_tmp, "drier")
    # fx(df_tmp, "warmer")
    # fx(df_tmp, "cooler")

In [None]:
# Aggregate the per-model responses into one dominant response per
# species x SPEI-variable x temperature-variable group.
keep_nonsignificant = True

# "_all" selects the sign-only labels, "" the significance-aware labels
tmp_suffix = "_all" if keep_nonsignificant else ""

df_res_lm = (
    df_pattern_per_model.sort_values("species")
    .rename(
        columns={
            f"change_spei{tmp_suffix}": "response_spei",
            f"change_temp{tmp_suffix}": "response_temp",
        }
    )
    # Only models above the performance threshold are considered
    .query("test_roc_auc > @roc_threshold")
    .reset_index(drop=True)
)

# display(df_res_lm)

# Group by species and spei/temp variables
df_res_lm_group = (
    df_res_lm.dropna(subset=["response_spei", "response_temp"])
    # .query("dire_spei == 0 or dire_temp == 0")
    .groupby(["species", "spei", "temp"])
)

df_list = []

ispecies = ""

for group in df_res_lm_group.groups:

    # Increment group counter (it restarts at 1 for every new species)
    if ispecies == group[0]:
        group_counter = group_counter + 1
    else:
        group_counter = 1
        ispecies = group[0]

    # Get group
    df_group = df_res_lm_group.get_group(group)
    # Get group size
    group_size = df_group.shape[0]
    # Get pattern percentages.
    # The threshold has to be passed by keyword: the third positional parameter
    # of get_var_and_val is the unused `var_pval_threshold`, so a positional
    # `min_group_percentage` was silently ignored.
    spei_pattern, spei_value = get_var_and_val(
        df_group, "spei", group_threshold=min_group_percentage
    )
    temp_pattern, temp_value = get_var_and_val(
        df_group, "temp", group_threshold=min_group_percentage
    )

    idf = pd.DataFrame(
        {
            "species": group[0],
            "spei": group[1],
            "temp": group[2],
            "group": group_counter,
            "group_size": group_size,
            "response_spei": spei_pattern,
            "response_temp": temp_pattern,
            "perc_spei": spei_value,
            "perc_temp": temp_value,
        },
        index=[0],
    )

    df_list.append(idf)
    # display(idf)
    # display(df_group)


# pd.concat(df_list)

In [None]:
# Stack the group-level results and derive the combined temperature/SPEI label.
# A combination is reduced to plain "ns" as soon as "ns" occurs in it.
df_patterns = pd.concat(df_list)
df_patterns["change"] = [
    "ns" if "ns" in label else label
    for label in df_patterns["response_temp"] + "_" + df_patterns["response_spei"]
]
df_patterns.head(3)

In [None]:
# For all species
# Pattern shares for the combined, the SPEI-only and the temperature-only response
pattern_both = plot_pattern_dist(df_patterns, "change")
pattern_spei = plot_pattern_dist(df_patterns, "response_spei")
pattern_temp = plot_pattern_dist(df_patterns, "response_temp")

# Stack temperature, SPEI and combined patterns (in that order); cells that are
# exactly "ns" are renamed so the temperature and SPEI entries stay distinguishable
patterns_merged_allspecies = pd.concat(
    [
        pattern_temp.replace({"ns": "ns (temp)"}),
        pattern_spei.replace({"ns": "ns (spei)"}),
        pattern_both,
    ],
    axis=0,
).reset_index(drop=True)

# For the top-9 species only
# (note: this rebinds pattern_both / pattern_spei / pattern_temp)
pattern_both = plot_pattern_dist(df_patterns.query("species in @top9"), "change")
pattern_spei = plot_pattern_dist(df_patterns.query("species in @top9"), "response_spei")
pattern_temp = plot_pattern_dist(df_patterns.query("species in @top9"), "response_temp")

patterns_merged_top9 = pd.concat(
    [
        pattern_temp.replace({"ns": "ns (temp)"}),
        pattern_spei.replace({"ns": "ns (spei)"}),
        pattern_both,
    ],
    axis=0,
).reset_index(drop=True)

In [None]:
# Attach the group-level size and response labels to every individual run
# (left join on the species / spei / temp combination).
_run_columns = ["species", "spei", "temp", "run", "pval_spei", "pval_temp"]
_group_columns = [
    "species",
    "spei",
    "temp",
    "group_size",
    "response_temp",
    "response_spei",
    "change",
]
df_models_with_grouped_response = df_res_lm[_run_columns].merge(
    df_patterns[_group_columns],
    how="left",
    on=["species", "spei", "temp"],
)

# Keep the original p-values (rounded to 3 decimals) for checking later on
for _clim_var in ("spei", "temp"):
    df_models_with_grouped_response[f"pval_{_clim_var}_org"] = (
        df_models_with_grouped_response[f"pval_{_clim_var}"].round(3)
    )

# Turn the p-values into significance flags; a missing p-value is treated
# as 1 before the comparison against pval_threshold.
for _clim_var in ("spei", "temp"):
    _filled_pval = df_models_with_grouped_response[f"pval_{_clim_var}"].fillna(1)
    df_models_with_grouped_response[f"pval_{_clim_var}"] = (
        _filled_pval < pval_threshold
    )

# A run is flagged for both variables only if each single flag is set
df_models_with_grouped_response["pval_both"] = (
    df_models_with_grouped_response["pval_spei"]
    & df_models_with_grouped_response["pval_temp"]
)

# Attach spei_temp pair
_spei_prefix = df_models_with_grouped_response["spei"] + "-"
df_models_with_grouped_response["spei_temp"] = (
    _spei_prefix + df_models_with_grouped_response["temp"]
)

# Preview: per species, largest groups first
_sort_keys = ["species", "group_size", "spei_temp"]
df_models_with_grouped_response.sort_values(
    _sort_keys, ascending=[True, False, True]
).head(10)

In [None]:
#
# ! Assess significance of features
# For every response pattern, compute the share of runs whose significance
# flag is False / True, once over all species and once over top9 only.
list_all = []
list_top9 = []

_temp_patterns = ("warmer", "cooler", "ns (temp)")
_spei_patterns = ("wetter", "drier", "ns (spei)")

for pattern in [
    "warmer",
    "cooler",
    "wetter",
    "drier",
    "warmer_wetter",
    "warmer_drier",
    "cooler_wetter",
    "cooler_drier",
    "ns (temp)",
    "ns (spei)",
    "ns",
]:
    # Pick the label column to filter on and the matching significance flag
    if pattern in _temp_patterns:
        var_all, var_sign = "response_temp", "pval_temp"
    elif pattern in _spei_patterns:
        var_all, var_sign = "response_spei", "pval_spei"
    else:
        var_all, var_sign = "change", "pval_both"

    # Every "ns ..." variant is stored as plain "ns" in the label columns
    search_pattern = "ns" if "ns" in pattern else pattern

    # Rows matching the pattern: all species first, then top9 only
    rows_all = df_models_with_grouped_response.query(
        f"{var_all} == '{search_pattern}'"
    )
    rows_top9 = df_models_with_grouped_response.query(
        f"{var_all} == '{search_pattern}' and species in @top9"
    )

    for matching_rows, collector in ((rows_all, list_all), (rows_top9, list_top9)):
        # Relative frequency of the flag values, ordered False before True
        shares = matching_rows[var_sign].value_counts(normalize=True).sort_index()
        shares["response"] = var_all
        shares["pattern"] = pattern
        # One row per pattern
        collector.append(shares.to_frame().T)


# Combine the rows, drop the helper column and name the flag columns
df_significance_all_species = (
    pd.concat(list_all)
    .reset_index(drop=True)
    .drop(columns=["response"])
    .rename(columns={False: "ns", True: "sig"})
)

df_significance_top9 = (
    pd.concat(list_top9)
    .reset_index(drop=True)
    .drop(columns=["response"])
    .rename(columns={False: "ns", True: "sig"})
)

# Show
display(df_significance_all_species)
display(df_significance_top9)

In [None]:
# ! Get final df for plotting
# Settings shared by both passes. They are defined once, ahead of the loop:
# ytick_labels used to be assigned only while processing "all", so the "top9"
# pass silently depended on a leftover from the previous iteration.
left_ylim = 60
ytick_labels = [
    "Warmer",
    "Cooler",
    "Other",
    "Drier",
    "Wetter",
    "Other",
    "Warmer + Drier",
    "Warmer + Wetter",
    "Cooler + Drier",
    "Cooler + Wetter",
    "Other",
]

# Plot colours (identical for both passes)
color_temp = "#77422C"
color_spei = "#D1A289"
color_rest = "lightgrey"
color_wd = sns.color_palette("Reds", 3)[-1]
color_ww = sns.color_palette("Blues", 3)[-1]
color_other = "lightgrey"

# Pattern-distribution table and significance table to merge for each pass
plot_inputs = {
    "all": (patterns_merged_allspecies, df_significance_all_species),
    "top9": (patterns_merged_top9, df_significance_top9),
}

for all_or_top9, (patterns_merged, df_significance) in plot_inputs.items():
    print(f" --- {all_or_top9} ---")

    # Variable importance table, restricted to top9 species where requested
    df_tmp = df_vimp.copy()
    if all_or_top9 == "top9":
        df_tmp = df_tmp[df_tmp["species"].isin(top9)]

    # Attach the ns / sig shares to each pattern's relative group size
    df_plot = pd.merge(
        patterns_merged,
        df_significance,
        left_on="change_simple",
        right_on="pattern",
        how="left",
    )

    # Split group size into ns and sig
    df_plot["perc_sign"] = df_plot["group_size_rel"] * (1 - df_plot["ns"])
    df_plot["perc_ns"] = df_plot["group_size_rel"] * df_plot["ns"]

    # Where the ns share is missing, the whole group size goes to perc_sign
    # and perc_ns is set to 0.
    df_plot["perc_sign"] = df_plot["perc_sign"].fillna(df_plot["group_size_rel"])
    df_plot["perc_ns"] = df_plot["perc_ns"].fillna(0)

    # Switch ns for unclear
    df_plot["change_simple"] = df_plot["change_simple"].replace(
        {"ns": "unclear", "ns (spei)": "unclear (spei)", "ns (temp)": "unclear (temp)"}
    )

    # Show
    display(
        df_plot[
            [
                "change_simple",
                "group_size_rel",
                "pattern",
                "ns",
                "sig",
                "perc_sign",
                "perc_ns",
            ]
        ]
    )

    # ! Make Plot
    plot_bars_dataset_pattern(
        df_plot,
        df_tmp,
        all_or_top9=all_or_top9,
        color_temp=color_temp,
        color_spei=color_spei,
        color_rest=color_rest,
        color_wd=color_wd,
        color_ww=color_ww,
        color_other=color_other,
        filepath=f"{dir_today}/{all_or_top9}.png",
        ytick_labels=ytick_labels,
        left_ylim=left_ylim,
    )

    # ! Importance distribution
    # From here on df_tmp holds the column means of the importance table
    df_tmp = df_tmp.drop(columns=["species", "model"]).mean()
    imp_stand = (
        df_tmp["Light Competition"]
        + df_tmp["Species Competition"]
        + df_tmp["Stand Structure"]
        + df_tmp["Tree Size"]
    ).round(0)

    # NOTE(review): the climate sum is rounded to 2 decimals while the other
    # sums are rounded to 0 - confirm this difference is intended.
    imp_climate = (df_tmp["Temperature"] + df_tmp["SPEI"]).round(2)
    imp_soil = (df_tmp["Soil Fertility"] + df_tmp["Soil Water Conditions"]).round(0)
    imp_ndvi = (df_tmp["NDVI"]).round(0)

    display(df_tmp.round(2).sort_values(ascending=False))
    print(f"Sum of Stand-describing variables: {imp_stand}")
    print(f"Sum of Climate-describing variables: {imp_climate}")
    print(f"Sum of Soil-describing variables: {imp_soil}")
    print(f"Sum of NDVI variables: {imp_ndvi}\n\n ")

In [None]:
from matplotlib.gridspec import GridSpec
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Paths of the two input images and of the combined output figure
fig_top = f"{dir_today}/top9.png"
fig_bot = f"{dir_today}/all.png"
output_file = f"{dir_today}/si_summary_importance.png"

# Panel letters with their (x, y) anchor in figure coordinates
panel_labels = [
    ("A", 0.025, 0.865),
    ("B", 0.515, 0.865),
    ("C", 0.025, 0.445),
    ("D", 0.515, 0.445),
]

# Square figure (side length in inches) with two equally tall rows
fig_size = 8
fig = plt.figure(figsize=(fig_size, fig_size))
gs = GridSpec(nrows=2, ncols=1, height_ratios=[1, 1], figure=fig)

# One axis per row, each showing one image without axis decorations
axs = []
for row, img_path in enumerate([fig_top, fig_bot]):
    ax = fig.add_subplot(gs[row, 0])
    ax.imshow(mpimg.imread(img_path))
    ax.axis("off")
    axs.append(ax)

# Draw the panel letters on a white box at their fixed figure positions
label_style = {
    "fontsize": 12,
    "fontweight": "bold",
    "ha": "left",
    "va": "top",
    "bbox": {"facecolor": "white", "edgecolor": "none", "pad": 2},
}
for letter, x_pos, y_pos in panel_labels:
    fig.text(x_pos, y_pos, letter, **label_style)

# Save and show
fig.tight_layout(h_pad=0)
fig.savefig(output_file, dpi=300, bbox_inches="tight")
plt.show()
print(f"Figure saved to {output_file}")

---
