# Tree Mortality Predictions


In [None]:
import sys

sys.path.insert(0, "../../src")
from imports import *

init_notebook()

from IPython.display import clear_output  # For clearing the output of a cell
import json


# List the available predictor datasets. NOTE(review): with return_list=False
# this presumably displays the list itself; the return value is not used here.
list_predictor_datasets(return_list=False)
display("--------")
print("\nList of available species and their percentages")

# Restrict the NFI data to trees alive at the first visit
# (state change alive -> alive or alive -> dead)
tmp = get_final_nfi_data_for_analysis(verbose=False).query(
    "tree_state_change in ['alive_alive', 'alive_dead']"
)

# Tree counts per species (absolute) and the same counts as fractions of all trees.
# The names `tmp`, `species` and `species_norm` are kept because later cells may use them.
species = tmp["species_lat2"].value_counts()
species_norm = tmp["species_lat2"].value_counts(normalize=True)

# One line per species: name, absolute count, percentage share
for species_name, n_trees in species.items():
    share_pct = species_norm[species_name] * 100
    print(f"{species_name:25} {n_trees:<30} {share_pct:.2f}%")

## Function Definition


In [None]:
def run_all(species, user_input, base_dir=None):
    # NFI file name
    # Folder prefix
    folder_suffix = ""  # None for no prefix
    user_input["subset_group"] = [species]
    user_input["predictor_datasets"] = [
        # "agroparistech_soil",
        # "apt",
        # "cs_disturbances",
        "digitalis_tmax",
        "digitalis_tmin",
        "digitalis_tmoy",
        # "edo",
        "esa_landcover_percentages",
        "forest_biodiversity",
        "forest_competition",
        "forest_gini",
        # "forest_health",
        # "forest_structure_idp",
        "france_dem",
        "human_activity",
        # "ndvi",
        # "nfi_site_information",
        # "safran",
        # "soil",
        "spei_minmean",
        "spei_trend",
        # "treecover",
        # "metrics_of_change",  # > Growth / Mortality / Recruitment / Logging -> Specify variables below
    ]

    # ! Get current directory
    if base_dir is None:
        user_input["current_dir"] = create_new_run_folder_treemort(
            user_input["subset_group"][0]
        )

    else:
        # Attach / if not present
        if base_dir[-1] != "/":
            base_dir += "/"

        user_input["current_dir"] = base_dir + species + "/"
        os.makedirs(user_input["current_dir"], exist_ok=True)

    current_dir = user_input["current_dir"]

    # Write to file
    file_path = f"{current_dir}/__user_input.txt"
    with open(file_path, "w") as file:
        for key, value in user_input.items():
            if isinstance(value, list):
                file.write(f"{key}:")
                for v in value:
                    file.write(f"\n - {v}")
                file.write("\n\n")
            else:
                file.write(f"{key}:\n - {value}\n\n")

    # Add a few quick_see_files
    write_txt(
        f"{current_dir}/_{user_input['subset_group'][0]}_{user_input['subset'][0]}.txt"
    )

    if user_input["do_smote_test_validation"]:
        write_txt(f"{user_input['current_dir']}/游뚿 SMOTE for RFE-CV val set.txt")

    if user_input["do_smote_test_final"]:
        write_txt(f"{user_input['current_dir']}/游뚿 SMOTE for final test set.txt")

    write_txt(
        f"{user_input['current_dir']}/游뚿 best model based on {user_input['best_model_decision']}.txt"
    )

    if user_input["description_file"] is not None:
        write_txt(f"{user_input['current_dir']}/_{user_input['description_file']}.txt")

    # write_txt(
    #     f"{current_dir}/_SUBSET-REGION_{user_input['region_subset']}-{user_input['region_subset_group']}.txt"
    # )
    ## Load Data

    # ! Load Cleaned Dataset
    df_raw = pd.read_feather(here("data/final/nfi/nfi_ready_for_analysis.feather"))
    ## Apriori Filter

    df_filter = df_raw.copy()
    print(f" - Initial shape of nfi data: {df_filter.shape}")

    # # Keep only alive_alive and alive_dead
    # df_filter = df_filter.query("tree_state_change in ['alive_alive', 'alive_dead']")

    # ## Filters following work by A. Taccoen
    # # No forest edges: plisi 1,2 indicate forest edges in the plot
    # df_filter = df_filter.query("plisi not in [1, 2]")

    # # No groves
    # df_filter = df_filter.query(
    #     "utip_1 != 'V' and utip_2 != 'V' and csa_1 != 2 and csa_2 != 2"
    # )

    # # No natural incidences
    # df_filter = df_filter.query(
    #     "nincid_1 in ['0', 'Missing'] and nincid_2 in ['0', 'Missing']"
    # )

    # # No broken [T] or fallen/windthrown [Z, A, 1] trees (or cut [6, 7])
    # df_filter = df_filter.query("veget5 not in ['Z', '1', 'T', 'A', '6', '7']")

    # # No sites with less than 10% of trees alive at first visit
    # df_filter = df_filter.query("share_alive > 0.1")

    # # No sites with less than 10% of trees DBH >= 7.5cm
    # df_filter = df_filter.query("share_larger75dbh > 0.1")
    # # df_filter = df_filter.query("share_smaller75dbh < 0.9")

    ## Verbose
    # filter_report("", df_raw, df_filter)
    ### Target

    # Load Tree Level Data
    df_subset = df_filter.copy()

    # Filter out trees that do not belong to the desired group
    # Check if subset variables are in the dataset
    for subset in user_input["subset"]:
        if subset not in df_subset.columns:
            raise KeyError(f"{subset} not in columns")

    df_before = df_subset.copy()

    for subset in user_input["subset"]:
        df_subset = df_subset[df_subset[subset].isin(user_input["subset_group"])]
        # print(
        #     f"Kept {len(df_subset)} trees based on {subset}: {user_input['subset_group']}"
        # )

    # Filter out irrelevant trees
    df_subset = df_subset.query(
        "tree_state_change == 'alive_alive' or tree_state_change == 'alive_dead'"
    )
    # Encode target
    df_subset["target"] = (
        df_subset["tree_state_change"]
        .copy()
        .apply(lambda x: 1 if x == "alive_dead" else 0)
    )

    # Clean df a bit
    df_subset = move_vars_to_front(df_subset, ["idp", "tree_id", "target"])

    # Keep target dataset separately
    df_target = df_subset[["idp", "tree_id", "target"]]
    # display(df_target)
    # display(df_target.target.value_counts())
    # display(df_target.target.value_counts(normalize=True))
    # df_target.target.value_counts(normalize=True).plot(kind="bar")
    # plt.show()

    # Break function if only alive trees
    if df_target.target.value_counts().shape[0] == 1:
        display(df_target.target.value_counts())
        print(f" - Skipping because too few dead trees")
        write_txt(f"{current_dir}/丘멆잺 too few dead trees.txt")
        return None

    # Break function if too little dead trees
    if df_target.target.value_counts()[1] < 35:
        display(df_target.target.value_counts())
        print(f" - Skipping because too few dead trees")
        write_txt(f"{current_dir}/丘멆잺 too few dead trees.txt")
        return None

    print(f" - Shape of target dataset: {df_target.shape}")

    ### ! Predictors ----------------------------------------------------------------
    # #! Initiate dictionary and df
    dict_preds = {}
    df_preds = df_subset.copy()[["idp", "tree_id"]]

    #! Tree Properties
    # Using df_subset from above to pick variables
    voi = ["htot_final", "c13_rel", "c13_1"]
    df_tree = df_subset[["idp", "tree_id"] + voi]
    df_preds = df_preds.merge(df_tree, on=["idp", "tree_id"], how="left")
    dict_preds = add_vars_to_dict("Tree", df_tree, dict_preds)

    #! Stand Properties
    # Using df_subset from above to pick variables
    df_stand = df_subset[["idp", "tree_id", "social_status"]]

    # Using separately calculated metrics
    df_stand = (
        df_stand.merge(
            attach_or_load_predictor_dataset("forest_competition"),
            on=["idp", "tree_id"],
            how="left",
        )
        .merge(
            attach_or_load_predictor_dataset("forest_biodiversity"),
            on=["idp"],
            how="left",
        )
        .merge(
            attach_or_load_predictor_dataset("forest_gini"),
            on=["idp"],
            how="left",
        )
    )
    df_preds = df_preds.merge(df_stand, on=["idp", "tree_id"], how="left")
    dict_preds = add_vars_to_dict("Stand", df_stand, dict_preds)

    #! Carrying Capacity
    df_cc = attach_or_load_predictor_dataset("forest_carrying_capacity")
    df_preds = df_preds.merge(df_cc, on="idp", how="left")
    dict_preds = add_vars_to_dict("Carrying Capacity", df_cc, dict_preds)

    #! Topography
    df_topo = attach_or_load_predictor_dataset("france_dem")
    # Keep only variables at 1000m resolution (we will use this as the main resolution)
    df_topo = df_topo[["idp"] + [var for var in df_topo.columns if "1000" in var]]
    # Remove dem1000_ and _mean from variable names
    df_topo.columns = ["idp"] + [
        var.replace("dem1000_", "").replace("_mean", "") for var in df_topo.columns[1:]
    ]
    # Attach to df_preds
    df_preds = df_preds.merge(df_topo, on="idp", how="left")
    # Save variables to dictionary
    dict_preds = add_vars_to_dict("Topography", df_topo, dict_preds)

    #! Soil Conditions
    df_soil = attach_or_load_predictor_dataset("agroparistech_soil")
    # Clean variable names
    df_soil.columns = [var.replace("soil_", "") for var in df_soil.columns]
    df_soil = df_soil.drop(columns=["first_year"])
    # Attach to df_preds
    df_preds = df_preds.merge(df_soil, on="idp", how="left")
    # Save variables to dictionary
    dict_preds = add_vars_to_dict("Soil", df_soil, dict_preds)

    #! Temperature
    drop_cols = ["idp", "first_year", "yrs_before_second_visit"]
    df_temp = pd.concat(
        [
            attach_or_load_predictor_dataset("digitalis_tmoy"),
            attach_or_load_predictor_dataset("digitalis_tmin").drop(columns=drop_cols),
            attach_or_load_predictor_dataset("digitalis_tmax").drop(columns=drop_cols),
        ],
        axis=1,
    )

    # > Note: Removed annual metrics because they are dominated by summer or winter temperatures anyways
    for col in df_temp.columns:
        if "ann" in col:
            df_temp = df_temp.drop(columns=col)
        elif "trend" in col:
            df_temp = df_temp.drop(columns=col)
        elif "slope" in col:
            df_temp = df_temp.drop(columns=col)
    display("游릮游릮游릮 CURRENTLY NOT USING Temperature TREND VARIABLES 游릮游릮游릮")

    # Attach to df_preds
    df_preds = df_preds.merge(df_temp, on="idp", how="left")

    # Save variables to dictionary
    dict_preds = add_vars_to_dict("Temperature", df_temp, dict_preds)

    #! SPEI
    df_spei = pd.merge(
        attach_or_load_predictor_dataset("spei_trend"),
        attach_or_load_predictor_dataset("spei_minmean"),
        on="idp",
        how="left",
    )

    df_spei = attach_or_load_predictor_dataset("spei_minmean")
    display("游릮游릮游릮 CURRENTLY NOT USING SPEI TREND VARIABLES 游릮游릮游릮")

    # Rename columns from numbers to months
    df_spei.columns = [
        var.replace("-1_", "-jan_")
        .replace("-2_", "-feb_")
        .replace("-3_", "-mar_")
        .replace("-4_", "-apr_")
        .replace("-5_", "-may_")
        .replace("-6_", "-jun_")
        .replace("-7_", "-jul_")
        .replace("-8_", "-aug_")
        .replace("-9_", "-sep_")
        .replace("-10_", "-oct_")
        .replace("-11_", "-nov_")
        .replace("-12_", "-dec_")
        .replace("-13_", "-ann_")
        for var in df_spei.columns
    ]

    # From df_spei select only variables with the following patterns:
    # > Note: Removed "ann" to focus on seasons
    spei_durations = [f"spei{i}-" for i in [1, 3, 6, 9, 12, 15, 18, 21, 24]]
    spei_months = [f"*-{i}_*" for i in ["feb", "may", "aug", "nov"]]
    spei_subset = match_variables(df_spei, spei_durations)
    spei_subset = match_variables(df_spei[spei_subset], spei_months)

    df_spei = df_spei[["idp"] + spei_subset]

    df_preds = df_preds.merge(df_spei, on="idp", how="left")
    dict_preds = add_vars_to_dict("SPEI", df_spei, dict_preds)

    #! Management
    df_human = attach_or_load_predictor_dataset("human_activity")
    df_preds = df_preds.merge(df_human, on="idp", how="left")
    dict_preds = add_vars_to_dict("Management", df_human, dict_preds)

    #! Landcover
    # df_lc = attach_or_load_predictor_dataset("esa_landcover_percentages")
    # df_preds = df_preds.merge(df_lc, on="idp", how="left")
    # dict_preds = add_vars_to_dict("Land Cover", df_lc, dict_preds)

    #! NDVI
    df_ndvi = attach_or_load_predictor_dataset("ndvi")
    df_preds = df_preds.merge(df_ndvi, on="idp", how="left")
    dict_preds = add_vars_to_dict("NDVI", df_ndvi, dict_preds)

    # ! Align direction of variables
    # Increasing distance to road should mean more management
    df_preds.dist_road = df_preds.dist_road.replace({0: 4, 1: 3, 3: 1, 4: 0})

    # ! Update dictionary --------------------------------------------------------------------------------
    dict_preds_org = dict_preds.copy()
    dict_preds_org
    dict_preds = dict_preds_org.copy()
    dict_preds.pop("Tree", None)
    dict_preds.pop("Stand", None)
    dict_preds.pop("Soil", None)
    dict_preds.pop("Carrying Capacity", None)

    dict_preds["Tree Size"] = [
        "htot_final",
        "c13_1",
    ]

    dict_preds["Light Competition"] = [
        "c13_rel",
        "social_status",
        "competition_larger",
        "competition_larger_rel",
    ]

    dict_preds["Species Competition"] = [
        "competition_same_species",
        "competition_same_species_rel",
        "competition_other_species",
        "competition_other_species_rel",
        "belongs_to_dom_spec",
        "num_species",
        "simpson_species",
        "shannon_species",
    ]

    dict_preds["Stand Structure"] = [
        "num_trees",
        "gini_ba_1",
        "mean_dbh",
        "carrying_capacity",
        "competition_total",
    ]
    # dict_preds["stand_diversity"] = ["num_species", "simpson_species", "shannon_species"]
    dict_preds["Soil Fertility"] = [
        # "waterlogging_temp",
        # "waterlogging_perm",
        # "swhc",
        "CN",
        "pH",
    ]

    dict_preds["Soil Water Conditions"] = [
        "waterlogging_temp",
        "waterlogging_perm",
        "swhc",
    ]

    # dict_preds["soil_conditions"] = []

    # Save dictionary to file
    with open(f"{current_dir}/dict_preds.json", "w") as f:
        json.dump(dict_preds, f)

    # ! DATA PREPARATION --------------------------------------------------------------------------------

    #### Removing NAs
    df_before = df_preds.copy()

    # Get Na percentages per columns
    l_nas = round(
        df_before.drop(columns=["idp", "tree_id"])
        .isna()
        .sum()
        .sort_values(ascending=False)
        / len(df_before)
        * 100,
        2,
    )

    # print("Percentages of NA values in the predictor dataset:")
    # for i in range(0, len(l_nas)):
    # print(f"- {l_nas.index[i]:25}: {l_nas[i]}%")
    # display()

    # Drop where more than 10% of values are missing
    dropped_vars = []
    for var in l_nas[l_nas > 10].index:
        # print(f"Dropping '{var}'")
        dropped_vars = dropped_vars.append(var)
        df_before = df_before.drop(columns=var)
    # print(f"\nDropping variables with more than 10% missing values: {dropped_vars}")

    # Impute mean values where less than 10% of values are missing
    l_imp = l_nas.copy()
    l_imp = l_imp[l_imp > 00]
    l_imp = l_imp[l_imp <= 10]
    imp_vars = l_imp.index
    # display()
    # print(
    #     f"\nImputing mean values for variables with less than 10% missing values: {imp_vars}"
    # )
    for var in l_imp.index:
        # print(f"Imputing mean for '{var}'")
        df_before[var] = df_before[var].fillna(df_before[var].mean())

    df_nonas = df_before.copy()
    #### OHE

    # Get temporary df
    df_ohe = df_nonas.copy()

    # Get all variables names before one-hot encoding
    all_var_names_before_ohe = sorted(df_ohe.columns.to_list())

    # Set variables to not ohe:
    my_vars_not_to_ohe = ["test_train_strata", "target", "idp", "tree_id"]

    # Do the OHE
    df_ohe = do_ohe(df_ohe, my_vars_not_to_ohe, verbose=False)

    # Get all variables names after one-hot encoding
    all_var_names_after_ohe = sorted(df_ohe.columns.to_list())

    # Get variable dictionary
    var_ohe_dict = {}
    for var in all_var_names_before_ohe:
        sub_vars = []

        if var in all_var_names_after_ohe:
            # If the variable was not ohe, it stays the same
            var_ohe_dict[var] = [var]
            continue
        else:
            # If the variable was ohe, search for pattern and add it
            pattern = r"^" + var + r"_.*"
            for sub_var in all_var_names_after_ohe:
                # print(pattern, sub_var, re.match(pattern, sub_var))
                if re.match(pattern, sub_var):
                    sub_vars.append(sub_var)
        var_ohe_dict[var] = sub_vars

    # Print which variable has how many sub-variables
    # print("\n---- Variable OHE Count ----")
    # for k, v in var_ohe_dict.items():
    #     if len(v) > 1:
    # print(f" - {k}: {len(v)}")
    ### Final Dataset

    # Get final dataset from above
    df_predictors_final = df_ohe.copy()

    # Raise error if target and predictor df have not same number of rows
    if df_target.shape[0] != df_predictors_final.shape[0]:
        raise ValueError(
            f"Target and predictor datasets have different number of rows: {df_target.shape[0]} vs {df_predictors_final.shape[0]}"
        )

    # Merge to get correct order
    df_target_pred_final = pd.merge(
        df_target, df_predictors_final, on=["idp", "tree_id"], how="left"
    )

    df_target_pred_final = df_target_pred_final.drop(
        columns=["idp", "tree_id", "first_year"], errors="ignore"
    )
    # df_target_pred_final.to_csv("df_final_target_predictors.csv", index=False)
    #### Test/Train Split

    # Get df
    df_for_splitting = df_target_pred_final.copy()
    print(f" - Shape of df before splitting: \t {df_for_splitting.shape}")

    X = df_for_splitting.drop("target", axis=1)
    y = df_for_splitting["target"]

    # Split the data into train and test sets
    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        test_size=user_input["test_split"],
        random_state=user_input["seed_nr"],
        stratify=y,
    )

    Xy_train = pd.concat([y_train, X_train], axis=1).reset_index(drop=True)
    Xy_test = pd.concat([y_test, X_test], axis=1).reset_index(drop=True)

    # print(f"Shape of Xy_train:\t\t {Xy_train.shape}")
    # print(f"Shape of Xy_test:\t\t {Xy_test.shape}")
    Xy_train.target.value_counts()
    Xy_test.target.value_counts()
    #### Debug for removing vars

    # print(f"Shape of Xy_train before removal: \t {Xy_train.shape}")
    # print(f"Shape of Xy_test before removal: \t {Xy_test.shape}")

    # SPEI
    if "SPEI" in dict_preds:

        spei_train = Xy_train[
            [var for var in dict_preds["SPEI"] if var in Xy_train.columns]
        ]
        spei_test = Xy_test[
            [var for var in dict_preds["SPEI"] if var in Xy_test.columns]
        ]

        # Xy_train = Xy_train.drop(columns=dict_preds["SPEI"], errors="ignore")
        # Xy_test = Xy_test.drop(columns=dict_preds["SPEI"], errors="ignore")

        # dict_preds.pop("SPEI", None)

    # TEMPERATURE
    # if "Temperature" in dict_preds:
    #     Xy_train = Xy_train.drop(columns=dict_preds["Temperature"], errors="ignore")
    #     Xy_test = Xy_test.drop(columns=dict_preds["Temperature"], errors="ignore")
    #     dict_preds.pop("Temperature", None)

    # print(f"Shape of Xy_train:\t\t\t {Xy_train.shape}")
    # print(f"Shape of Xy_test:\t\t\t {Xy_test.shape}")

    print(f" - Shape of Xy_train:\t\t\t {Xy_train.shape}")
    print(f" - Shape of Xy_test:\t\t\t {Xy_test.shape}")

    # Keep original dfs for saving tree ID further below
    df_target_for_treeid = df_target.copy()
    df_predictors_final_for_treeid = df_predictors_final.copy()

    # ! RFE ------------------------------------------------------------------------------
    display(" --- FEATURE ELIMINATION ---")
    rfecv_params = {
        "n_estimators": 100,
        "max_depth": 8,
        "max_features": 0.01,
        "bootstrap": True,
        "criterion": "gini",
    }

    df_cvmetrics_per_nfeatures = run_rfecv_treemort(
        dict_categories=dict_preds.copy(),
        var_ohe_dict=var_ohe_dict.copy(),
        Xy_train_for_rfe=Xy_train.copy(),
        user_input=user_input,
        rfecv_params=rfecv_params,
        debug_stop=False,
        debug_stop_after_n_iterations=10,
        verbose=False,
    )

    #! Report best variables ----------------------------------------------------------------
    display(" --- BEST FEATURES ---")
    if user_input["method_validation"] == "oob":
        user_input["best_model_metric"] = "oob"

    if user_input["best_model_decision"] == "best_metric":

        ohed_variables_in_final_model = (
            df_cvmetrics_per_nfeatures.sort_values(
                by=user_input["best_model_metric"], ascending=False
            )
            .head(1)["ohe_vars_in_model"]
            .values[0]
        )

        non_ohed_variables_in_final_model = (
            df_cvmetrics_per_nfeatures.sort_values(
                by=user_input["best_model_metric"], ascending=False
            )
            .head(1)["non_ohe_vars_in_model"]
            .values[0]
        )

        best_score = (
            df_cvmetrics_per_nfeatures.sort_values(
                by=user_input["best_model_metric"], ascending=False
            )
            .head(1)[user_input["best_model_metric"]]
            .values[0]
        )

    elif user_input["best_model_decision"] == "best_per_category":
        dict_len = len(dict_preds)

        ohed_variables_in_final_model = df_cvmetrics_per_nfeatures.query(
            "n_features == @dict_len"
        )["ohe_vars_in_model"].values[0]

        non_ohed_variables_in_final_model = df_cvmetrics_per_nfeatures.query(
            "n_features == @dict_len"
        )["non_ohe_vars_in_model"].values[0]

        best_score = df_cvmetrics_per_nfeatures.query("n_features == @dict_len")[
            user_input["best_model_metric"]
        ].values[0]

    elif user_input["best_model_decision"] == "best_metric_max1":
        dict_len = len(dict_preds)

        max1cat = df_cvmetrics_per_nfeatures.query("n_features <= @dict_len")

        non_ohed_variables_in_final_model = (
            max1cat.sort_values(by=user_input["best_model_metric"], ascending=False)
            .head(1)["non_ohe_vars_in_model"]
            .values[0]
        )

        best_score = (
            max1cat.sort_values(by=user_input["best_model_metric"], ascending=False)
            .head(1)[user_input["best_model_metric"]]
            .values[0]
        )

    else:
        raise ValueError(
            f"Invalid selection for final model decision!: {user_input['best_model_decision']}"
        )

    txt_best_var = f"""
    - Best score: {user_input['best_model_metric']} = {round(best_score,3)} based on model selecting by '{user_input['best_model_decision']}
    
    - Variables in best model (ohe):\t{ohed_variables_in_final_model}
    
    - Variables in best model (non-ohe):\t{sorted(non_ohed_variables_in_final_model)}
        """

    # print(txt_best_var)
    with open(f"{current_dir}/final_model_variables.txt", "w") as f:
        f.write(txt_best_var)

    # ! Select variables of best model
    Xy_train_best_preds = Xy_train.copy()[["target"] + ohed_variables_in_final_model]

    #! Plot ----------------------------------------------------------------
    # display()
    # Get max number of features
    x_max = df_cvmetrics_per_nfeatures["n_features"].max()

    # Get list of variables to plot
    all_metrics = [
        # "balanced_accuracy",
        "roc_auc",
        "accuracy",
        "f1",
        "recall",
        "precision",
    ]

    if user_input["method_validation"] == "oob":
        all_metrics = ["oob"] + all_metrics

    # Start figure
    fig, axs = plt.subplots(len(all_metrics), 1, figsize=(7, 10))

    # Loop over every variable
    for i, metric in enumerate(all_metrics):

        # Get max score
        max_score = df_cvmetrics_per_nfeatures.sort_values(by=metric, ascending=False)[
            metric
        ].iloc[0]

        # Get n_features at max accuracy
        max_score_features = df_cvmetrics_per_nfeatures.sort_values(
            by=metric, ascending=False
        ).iloc[0, 0]

        # Add lines
        axs[i].plot(
            df_cvmetrics_per_nfeatures["n_features"],
            df_cvmetrics_per_nfeatures[metric],
        )

        # Add error bands
        axs[i].fill_between(
            df_cvmetrics_per_nfeatures["n_features"],
            df_cvmetrics_per_nfeatures[metric]
            - df_cvmetrics_per_nfeatures[f"{metric}_sd"],
            df_cvmetrics_per_nfeatures[metric]
            + df_cvmetrics_per_nfeatures[f"{metric}_sd"],
            alpha=0.3,
        )

        if i == len(all_metrics) - 1:
            axs[i].set_xlabel("Number of Features")
        else:
            axs[i].set_xlabel("")

        # axs[i].set_ylabel("Accuracy Score")
        axs[i].set_ylabel(metric)
        axs[i].set_title("")
        axs[i].set_xlim(x_max, 0)
        axs[i].set_ylim((max_score * 0.75), (max_score * 1.15))

        # Add red vertical line for highest accuracy score
        axs[i].axvline(x=max_score_features, color="red")

        # Add text of max_score_features to axs in red
        axs[i].text(
            x_max - 5,
            max_score * 0.875,
            f"Optimal Nr. of Features: {int(max_score_features)} at {metric} = {round(max_score,3)}",
            color="red",
        )

    # LAYOUT
    plt.tight_layout()
    plt.savefig(f"{current_dir}/fig_refcv_results.png")
    plt.close()
    # plt.show()

    # ! Correlation Removal ----------------------------------------------------------------
    # First get feature importance of the best model
    if user_input["method_validation"] == "cv":
        rf, sco, rf_vi = SMOTE_cv(
            Xy_all=Xy_train_best_preds,
            var_ohe_dict=var_ohe_dict,
            rf_params=rfecv_params,
            method_importance=user_input["method_importance"],
            smote_on_test=user_input["do_smote_test_validation"],
            rnd_seed=user_input["seed_nr"],
            verbose=False,
            save_directory=None,
        )
    elif user_input["method_validation"] == "oob":
        rf, sco, rf_vi = SMOTE_oob(
            Xy_all=Xy_train_best_preds,
            var_ohe_dict=var_ohe_dict,
            rf_params=rfecv_params,
            method_importance=user_input["method_importance"],
            smote_on_test=user_input["do_smote_test_validation"],
            rnd_seed=user_input["seed_nr"],
            verbose=False,
            save_directory=None,
        )
    else:
        raise ValueError(
            f"Failed during RFE - Invalid method_validation! Got: {user_input['method_validation']}"
        )

    # Get order of features (note that they are NOT ohe'd, so I have to first decode the dataframe, before selection. As done below.)
    order_of_features = rf_vi.Feature.to_list()
    final_vars = remove_correlation_based_on_vi(
        Xy_train_best_preds,
        var_ohe_dict,
        rf_vi,
        threshold=user_input["correlation_threshold"],
        make_heatmaps=False,
        return_only_top_n=user_input["n_features_in_final_model"],
        save_directory=current_dir,
    )

    # ! SET FINAL FEATURES ----------------------------------------------------------------
    Xy_train_final = Xy_train_best_preds.copy()[["target"] + final_vars]
    Xy_test_final = Xy_test.copy()[["target"] + final_vars]

    # ! TUNING -----------------------------------------------------------------------
    # ! Prescribed Gridsearch
    display(" --- GRID SEARCH ---")
    # Get dataframe
    Xy_train_for_tuning = Xy_train_final.copy()

    # Split into response and predictors
    Xy = Xy_train_for_tuning.copy()
    X = Xy.drop(
        columns=["target", "test_train_strata", "tree_id", "idp"], errors="ignore"
    )
    y = Xy["target"]

    # Build model
    oversample = SMOTE(random_state=user_input["seed_nr"])
    model = RandomForestClassifier(random_state=user_input["seed_nr"], n_jobs=-1)

    # Apply oversampling to train set
    X_train_over, y_train_over = oversample.fit_resample(X, y)

    # Create Stratified K-fold cross validation
    cv = RepeatedStratifiedKFold(
        n_splits=3, n_repeats=1, random_state=user_input["seed_nr"]
    )

    # Get parameter grid
    # param_grid = get_tune_grid_classification()
    param_grid = {
        "n_estimators": [100, 300],  # Has minor influence
        "max_depth": [1, 3, 12, 18],  # [1, 5, 8, 12, 15, 18]
        # 'min_samples_split': [2, 5, 10],
        # 'min_samples_leaf': [1, 2, 4],
        "max_features": [0.01, 0.1, "sqrt"],  # Minor influence
        # 'bootstrap': [True],
        "criterion": ["gini"],  # 'gini',  entropy worked better on test
    }

    # Set the grid search model
    grid_search = GridSearchCV(
        estimator=model,
        param_grid=param_grid,
        cv=cv,
        n_jobs=-1,
        verbose=0,
        return_train_score=True,
        scoring=user_input["gsc_metric"],
    )

    # Fit the grid search to the data
    grid_search.fit(
        X,
        y,
    )

    # Print results
    display("")
    print("--- FINAL RESULTS ---")
    print("Parameter grid:")
    for key, value in param_grid.items():
        print(f" - {key}: {value}")

    print("\nBest parameters:")
    for key, value in grid_search.best_params_.items():
        print(f" - {key}: {value}")
    print(
        f"\nBest {user_input['best_model_metric']}: {round(grid_search.best_score_, 2)}"
    )

    # Get best parameters
    best_params = grid_search.best_params_
    # Visualize tuning
    plot_grid_search_results(
        grid_search, "prescribed", save_directory=current_dir, show=False
    )

    # ! Final Model ------------------------------------------------------------------------
    display(" --- FINAL MODEL RUN ---")

    # Setup model
    rf_model = RandomForestClassifier(
        random_state=user_input["seed_nr"],
        n_jobs=-1,
        **best_params,
    )

    # Split response and predictors
    X_train_final = Xy_train_final.drop(columns=["target"], errors="ignore")
    y_train_final = Xy_train_final["target"]

    X_test_final = Xy_test_final.drop(columns=["target"], errors="ignore")
    y_test_final = Xy_test_final["target"]

    # Apply SMOTE to train data
    sm = SMOTE(random_state=user_input["seed_nr"])
    X_train_final, y_train_final = sm.fit_resample(X_train_final, y_train_final)
    if user_input["do_smote_test_final"]:
        X_test_final, y_test_final = sm.fit_resample(X_test_final, y_test_final)

    # Fit model
    rf_model.fit(X_train_final, y_train_final)

    # Feature importance
    # * 2024-10-20: Disabled permutation
    # rf_vi = assessing_top_predictors(
    #     vi_method="permutation",
    #     rf_in=rf_model,
    #     X_train_in=X_train_final,
    #     X_test_in=X_test_final,
    #     y_test_in=y_test_final,
    #     dict_ohe_in=var_ohe_dict,
    #     with_aggregation=True,
    #     n_predictors=20,
    #     random_state=user_input["seed_nr"],
    #     verbose=False,
    #     save_directory=user_input["current_dir"],
    # )

    # Feature importance
    rf_vi = assessing_top_predictors(
        vi_method="impurity",
        rf_in=rf_model,
        X_train_in=X_train_final,
        X_test_in=X_test_final,
        y_test_in=y_test_final,
        dict_ohe_in=var_ohe_dict,
        with_aggregation=True,
        n_predictors=20,
        random_state=user_input["seed_nr"],
        verbose=False,
        save_directory=user_input["current_dir"],
    )

    # Evaluate model
    model_evaluation_classification(
        rf_model=rf_model,
        X_train=X_train_final,
        y_train=y_train_final,
        X_test=X_test_final,
        y_test=y_test_final,
        prob_threshold=user_input["prob_threshold"],
        save_directory=user_input["current_dir"],
        metric="f1-score",
        verbose=False,
    )

    # y_pred = rf_model.predict(X_test_final)
    # y_pred = pd.Series(y_pred, index=y_test_final.index)
    # rf_score = bootstrap_classification_metric(
    #     y_test_final,
    #     y_pred,
    #     metrics=["accuracy", "precision", "recall", "roc_auc"],
    #     n_bootstraps=100,
    # )

    # rf_score.to_csv(
    #     f"{current_dir}/final_model_scores_from_binary_values.csv", index=False
    # )

    # display(rf_score)

    # ! DEBUG
    # print("\n\n\n xxxxxxxxxxxx QUICK EXIT: NOT RUNNING SHAP! xxxxxxxxxxxx")
    # chime.warning()
    # return None

    # ! DEBUG TREE_ID: JUST TO SAVE LINK BETWEEN TREE_ID AND PREDICTORS
    # * 2024-10-20: Moved tree_id debugging down here, after the final dataset is saved in the evaluation_classification function

    # display(" --- DEBUG: SAVING TREE_ID SEPARATELY ---")
    # Get final predictors and save a reduced X_test for safety check
    final_predictors = (
        pd.read_csv(f"{current_dir}/final_model/X_test.csv")
        .drop(columns=["Unnamed: 0"])
        .columns.to_list()
    )

    final_predictors = []

    df_targted_treeid = pd.merge(
        df_target_for_treeid,
        df_predictors_final_for_treeid,
        on=["idp", "tree_id"],
        how="left",
    )

    print(f" - Shape of df_targted_treeid: {df_targted_treeid.shape}")
    print(f" - Shape of df_targted_treeid target: {df_targted_treeid['target'].shape}")

    X_train_treeid, X_test_treeid, y_train_treeid, y_test_treeid = train_test_split(
        df_targted_treeid,
        df_targted_treeid["target"],
        test_size=user_input["test_split"],
        random_state=user_input["seed_nr"],
        stratify=df_targted_treeid["target"],
    )

    dir_treeid = f"{current_dir}/treeid"
    os.makedirs(dir_treeid, exist_ok=True)

    X_train_treeid[["tree_id"] + final_predictors].to_csv(
        f"{dir_treeid}/X_train_treeid.csv", index=True
    )
    X_test_treeid[["tree_id"] + final_predictors].to_csv(
        f"{dir_treeid}/X_test_treeid.csv", index=True
    )

    y_train_treeid.to_csv(f"{dir_treeid}/y_train_treeid.csv", index=True)
    y_test_treeid.to_csv(f"{dir_treeid}/y_test_treeid.csv", index=True)

    # display("Saved treeid files to:", dir_treeid)
    # return
    # ! DEBUG TREE_ID ---

    # * ----------------------------------------------------------------------------------------------------------------
    # * 2024-10-20: Uncommented creation of SHAP plots and PDP plots because they are not used in the final report

    # # ! SHAP ANALYSIS ------------------------------------------------------------------------
    # display(" --- SHAP ANALYSIS ---")

    # # Get list of features
    # X_train_shap = Xy_train_final.drop(columns="target")
    # X_test_shap = Xy_test_final.drop(columns="target")

    # features = list(X_train_shap)

    # # ! Pick which dataset to use
    # test_or_train = user_input["shap_on_test_or_train"]
    # print(f" - Using {test_or_train} set for SHAP analysis")

    # if test_or_train == "train":
    #     X_shap = X_train_shap.copy()
    # else:
    #     X_shap = X_test_shap.copy()

    # # ! Take at least 100 samples but max x% of the dataset
    # max_perc = 0.15
    # min_samples = 200
    # max_samples = 800

    # X_shap_len = X_shap.shape[0]
    # # min_samples = min(min_samples, X_shap_len)
    # # max_samples = max(max_samples, int(round(X_shap_len * max_perc)))
    # # n_shap_samples = max(min_samples, max_samples)

    # if int(round(X_shap_len * max_perc)) > max_samples:
    #     n_shap_samples = max_samples
    # else:
    #     n_shap_samples = min(min_samples, X_shap_len)

    # # > DEBUG option to use all data available! (takes much longer when SMOTE on test set is on!)
    # if user_input["use_all_shap_data"]:
    #     n_shap_samples = X_shap_len
    # else:
    #     print("游댮游댮游댮 NOT USING ALL DATA TO RUN SHAP 游댮游댮游댮")

    # print(
    #     f" - Using {n_shap_samples} samples for SHAP analysis ({max_perc*100}% of {X_shap_len} samples = {int(round(X_shap_len * max_perc))}, min = {min_samples}, max = {max_samples})"
    # )

    # # Run explainer
    # # Subset and save dataset used for SHAP
    # np.random.seed(user_input["seed_nr"])
    # X_shap = X_shap.sample(n_shap_samples, random_state=user_input["seed_nr"])
    # X_shap.to_feather(f"{current_dir}/X_shap_{test_or_train}.feather")

    # # Get one-way SHAP values and save them
    # print(" - Calculating SHAP values")
    # explainer = shap.TreeExplainer(rf_model, X_shap)
    # shap_values_extended = explainer(X_shap, check_additivity=False)

    # # Save values
    # with open(f"{current_dir}/shap_values_{test_or_train}.pkl", "wb") as file:
    #     pickle.dump(shap_values_extended, file)

    # # Get interaction values and save them for later
    # if user_input["run_shap_interaction"]:
    #     print(" - Calculating SHAP interaction values")
    #     interaction_values = shap.TreeExplainer(rf_model).shap_interaction_values(
    #         X_shap
    #     )
    #     with open(
    #         f"{current_dir}/shap_values_interaction_{test_or_train}.pkl", "wb"
    #     ) as file:
    #         pickle.dump(interaction_values, file)

    # print(" - Creating final plots")

    # # Keep values leading to mortality (1)
    # shap_values = shap_values_extended.values[:, :, 1]
    # # Model importances
    # feature_importances = rf_model.feature_importances_
    # importances = pd.DataFrame(index=features)
    # importances["importance"] = feature_importances
    # importances["rank"] = importances["importance"].rank(ascending=False).values
    # # display(importances.sort_values("rank").head())

    # # Calculate mean Shapley value for each feature in training set
    # importances["mean_shapley_values"] = np.mean(shap_values, axis=0)

    # # Calculate mean absolute Shapley value for each feature in training set
    # # This will give us the average importance of each feature
    # importances["mean_abs_shapley_values"] = np.mean(np.abs(shap_values), axis=0)

    # # Add Shapley values to coefficient table.
    # importances.sort_values(by="importance", ascending=False).head()

    # # Get top n features
    # top_n = len(final_vars)

    # importance_top_n = (
    #     importances.sort_values(by="importance", ascending=False).head(top_n).index
    # )
    # shapley_top_n = (
    #     importances.sort_values(by="mean_abs_shapley_values", ascending=False)
    #     .head(top_n)
    #     .index
    # )

    # # Add to DataFrame
    # top_n_features = pd.DataFrame()
    # top_n_features["importances"] = importance_top_n.values
    # top_n_features["Shapley"] = shapley_top_n.values

    # # ! Wrangle dataset for bar plots
    # Take mean of absolute shap values
    # xxx = pd.DataFrame(shap_values[0].tolist()).T
    # xxx.columns = X_shap.columns

    # for i in range(1, len(shap_values)):
    #     iii = pd.DataFrame(shap_values[i].tolist()).T
    #     iii.columns = X_shap.columns
    #     xxx = pd.concat([xxx, iii], axis=0, ignore_index=True)

    # # Take mean of all variables
    # rrr = xxx.abs().mean().sort_values(ascending=False)
    # rrr = rrr / rrr.sum()
    # rrr = pd.DataFrame(rrr)
    # rrr.columns = ["Importance"]
    # rrr.Importance = rrr.Importance * 100
    # rrr["Feature"] = rrr.index
    # rrr.reset_index(drop=True, inplace=True)
    # # Link feature variable to predictor dataset in new column
    # for f in rrr.Feature:
    #     for key, value in dict_preds.items():
    #         if f in value:
    #             rrr.loc[rrr.Feature == f, "dataset"] = key

    # # Sum up the VI for each dataset
    # rrr_of_dataset = (
    #     rrr[["Importance", "dataset"]]
    #     .groupby("dataset")
    #     .sum()
    #     .reset_index()
    #     .rename({"Importance": "dataset_imp"}, axis=1)
    # )

    # rrr_of_dataset.dataset_imp = (
    #     rrr_of_dataset.dataset_imp / rrr_of_dataset.dataset_imp.sum() * 100
    # )

    # # Attach dataset label with percentages
    # for i, row in rrr_of_dataset.iterrows():
    #     # rrr_of_dataset.loc[i, "dataset_label"] = (
    #     #     rrr_of_dataset.loc[i, "dataset"]
    #     #     + "  ("
    #     #     + str(round(rrr_of_dataset.loc[i, "dataset_imp"]))
    #     #     + "%)"
    #     # )
    #     rrr_of_dataset.loc[i, "dataset_label"] = (
    #         str(round(rrr_of_dataset.loc[i, "dataset_imp"]))
    #         + "%: "
    #         + rrr_of_dataset.loc[i, "dataset"]
    #     )

    # rrr = rrr.merge(rrr_of_dataset, on="dataset", how="left")

    # # Display
    # # display(importances, rrr)
    # # #! Figures --------------------------------------------------------

    # # > Beeswarm plot
    # fig = plt.figure(figsize=(6, 6))
    # shap.summary_plot(
    #     shap_values=shap_values,
    #     features=X_shap.values,
    #     feature_names=X_shap.columns.values,
    #     plot_type="violin",
    #     max_display=15,
    #     show=False,
    # )
    # plt.tight_layout()
    # plt.savefig(f"{current_dir}/fig_shap_beeswarm.png")
    # plt.close()
    # # plt.show()

    # # ! PDP plot ------------------------------------------------------------
    # feat_to_show = shapley_top_n[0:]
    # num_features = len(feat_to_show)
    # num_cols = 7  # Number of subplots per row
    # num_rows = 2  # Calculate number of rows
    # if n_shap_samples > 5000:
    #     point_density = 0.25
    # elif n_shap_samples > 2500:
    #     point_density = 0.5
    # elif n_shap_samples > 2500:
    #     point_density = 0.5
    # else:
    #     point_density = 0.75

    # fig, axes = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=(25, 8))

    # for i, feat in enumerate(feat_to_show):
    #     row = i // num_cols  # Calculate row index
    #     col = i % num_cols  # Calculate column index

    #     shap.plots.scatter(
    #         shap_values_extended[:, feat][:, 1],
    #         x_jitter=0,
    #         alpha=point_density,
    #         ax=axes[row, col],
    #         show=False,
    #     )
    #     axes[row, col].set_xlabel(f"{feat}")
    #     if col == 0:  # For the first column
    #         axes[row, col].set_ylabel("SHAP Value")
    #     else:
    #         axes[row, col].set_ylabel("SHAP VAlue")
    #     # axes[row, col].set_ylim(-0.2, 0.2)

    # # Remove extra subplots
    # for i in range(num_features, num_rows * num_cols):
    #     fig.delaxes(axes.flatten()[i])

    # fig.set_tight_layout(True)
    # fig.savefig(f"{current_dir}/fig_shap_scatter.png")
    # plt.close()

    # # ! My own bar plot --------------------------------------------------------
    # # Make a barplot of feature against importance, color by dataset, add percentage of total importance of dataset in the legend
    # # Get palette first
    # palette = sns.color_palette("tab20", len(rrr_of_dataset))
    # dict_hue = dict(zip(rrr_of_dataset.dataset, palette))
    # rrr_of_dataset["color"] = rrr_of_dataset.dataset.map(dict_hue)

    # # * By Feature Alone --------------------------------------------------------
    # plt.figure(figsize=(8, max(math.ceil(rrr.shape[0] / 4), 4)))
    # sns.barplot(
    #     x="Importance",
    #     y="Feature",
    #     data=rrr.sort_values(by="Importance", ascending=False),
    #     hue="dataset_label",
    #     dodge=False,
    #     palette=palette,
    # )
    # plt.xlabel("Relative Importance")
    # plt.tight_layout()
    # plt.legend(loc="upper left", bbox_to_anchor=(1, 1), frameon=False)
    # plt.title(
    #     f"Feature Importance for {user_input['subset_group'][0]}\n",
    #     fontsize=12,
    #     fontdict={"fontweight": "bold"},
    # )
    # plt.savefig(user_input["current_dir"] + "/fig-vip-shap-by_feature.png")
    # plt.close()
    # # plt.show()

    # # * By Dataset Alone --------------------------------------------------------
    # plt.figure(figsize=(8, max(math.ceil(rrr.shape[0] / 4), 4)))
    # sns.barplot(
    #     x="dataset_imp",
    #     y="dataset_label",
    #     data=rrr_of_dataset.sort_values(by="dataset_imp", ascending=False),
    #     hue="dataset_label",
    #     dodge=False,
    #     palette=palette,
    # )

    # plt.xlabel("Relative Importance")
    # plt.tight_layout()
    # plt.legend(loc="upper left", bbox_to_anchor=(1, 1), frameon=False)
    # plt.title(
    #     f"Dataset Importance for {user_input['subset_group'][0]}\n",
    #     fontsize=12,
    #     fontdict={"fontweight": "bold"},
    # )
    # plt.savefig(user_input["current_dir"] + "/fig-vip-shap-by_dataset.png")
    # plt.close()
    # # plt.show()

    # # * Both --------------------------------------------------------
    # # Create a figure with two side-by-side subplots
    # fig, axs = plt.subplots(1, 2, figsize=(16, max(math.ceil(rrr.shape[0] / 4), 4)))
    # palette = sns.color_palette("tab20", len(rrr["dataset_label"].unique()))

    # # Plot the first barplot
    # sns.barplot(
    #     x="Importance",
    #     y="Feature",
    #     data=rrr.sort_values(by="Importance", ascending=False),
    #     hue="dataset_label",
    #     dodge=False,
    #     palette=palette,
    #     ax=axs[0],  # Plot on the first subplot
    # )
    # axs[0].set_xlabel("Relative Importance")
    # axs[0].set_ylabel("")
    # # axs[0].legend(loc="upper left", bbox_to_anchor=(1, 1))
    # axs[0].legend([], [], frameon=False)
    # axs[0].set_title(
    #     f"Final Predictors",
    #     fontsize=12,
    #     # fontweight="bold",
    # )

    # # Plot the second barplot
    # sns.barplot(
    #     x="dataset_imp",
    #     y="dataset_label",
    #     data=rrr_of_dataset.sort_values(by="dataset_imp", ascending=False),
    #     hue="color",
    #     dodge=False,
    #     palette=palette,
    #     ax=axs[1],  # Plot on the second subplot
    # )

    # axs[1].set_xlabel("Relative Importance")
    # axs[1].set_ylabel("")
    # axs[1].legend([], [], frameon=False)
    # axs[1].set_title(
    #     f"Datasets",
    #     fontsize=12,
    #     # fontweight="bold",
    # )

    # # Adjust layout and display the figure
    # fig.suptitle(f"{user_input['subset_group'][0]}", fontsize=14, fontweight="bold")
    # plt.tight_layout()
    # plt.savefig(user_input["current_dir"] + "/fig-vip-shap-both.png")
    # plt.close()
    # # plt.show()
    # * ----------------------------------------------------------------------------------------------------------------

    # ! SAVE IT ALL --------------------------------------------------------
    # General information
    df_save = pd.DataFrame(
        {
            "subset": [user_input["subset"][0]],
            "subset_group": [user_input["subset_group"][0]],
            "created": [datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")],
            "best_model_decision": [user_input["best_model_decision"]],
            "N_died": [df_target.target.sum()],
            "N_surv": [df_target.shape[0] - df_target.target.sum()],
            "dir": [user_input["current_dir"]],
            "oversampled_cv": [user_input["do_smote_test_validation"]],
            "oversampled_test": [user_input["do_smote_test_final"]],
        }
    )

    # Final model metrics
    df_save = pd.concat(
        [df_save, pd.read_csv(f"{current_dir}/classification_metrics.csv")], axis=1
    )

    # Feature information
    # for dataset in rrr.dataset.unique():
    #     df_save[f"{dataset} - Importance"] = rrr.loc[
    #         rrr.dataset == dataset, "dataset_imp"
    #     ].values[0]
    #     df_save[f"{dataset} - Metrics"] = [
    #         rrr.loc[rrr.dataset == dataset, "Feature"].values
    #     ]
    #     df_save[f"{dataset} - Values"] = [
    #         rrr.loc[rrr.dataset == dataset, "Importance"].values
    #     ]

    df_save.to_csv(f"{current_dir}/final_model_performance.csv", index=False)

## Loop


In [None]:
# Settings read by `run_all`; some entries (e.g. "seed_nr") are overwritten per run in
# the loop cells below. Key order is kept identical to the original step-by-step setup.
user_input = {
    # ! TRAINING -----------------------------------------------------------------------
    # General
    "seed_nr": 42,
    "test_split": 0.2,
    "current_dir": None,
    # Feature Elimination
    "do_ref": False,
    "cv_folds": 5,
    "method_validation": "oob",  # none | cv | oob
    "method_importance": "impurity",  # permutation | impurity
    "do_tuning": False,  # Tune during rfe validation?
    # Correlation threshold for predictor removal; a value of 1 means no removal.
    # NOTE(review): currently 0.8, i.e. correlation removal is ACTIVE.
    "correlation_threshold": 0.8,
    "n_features_in_final_model": 15,
    # SMOTE
    "do_smote_test_validation": False,
    "do_smote_test_final": False,
    # Tuning
    "do_prescribed_search": True,
    "do_random_search": False,
    "gsc_metric": "roc_auc",  # grid search metric
    # ! Final model ---------------------------------------------------------------------
    # best_per_category | best_metric | best_metric_max1
    "best_model_decision": "best_per_category",
    "best_model_metric": "roc_auc",
    "prob_threshold": 0.4,  # Classification threshold
    # ! SHAP ----------------------------------------------------------------------------
    "shap_on_test_or_train": "test",  # test | train
    "use_all_shap_data": False,  # Whether to use all shap data
    "run_shap_interaction": True,  # Whether to run shap interactions
    # ! Output --------------------------------------------------------------------------
    "dir_suffix": None,  # None or string
    "description_file": "ADD_DESCRIPTION",  # None or string
    "subset": ["species_lat2"],
}

In [None]:
# # # Code to find where an error occurred in the loop
# for ii, ss in enumerate(species.index):
#     if ss in [
#         # "Fagus sylvatica",
#         "Alnus incana",
#         # "Buxus sempervirens",
#         # "Pinus sylvestris",
#         # "Fraxinus excelsior",
#         # "Betula pendula",
#         # "Quercus robur",
#         # "Quercus pubescens",
#     ]:
#         display(f"🟡🟡🟡 Running {ii}: {ss} 🟡🟡🟡")
#         run_all(ss, user_input)
#         # chime.success()
#         break
#     # clear_output(wait=True)
# # # chime.success()
# # # chime.success()
# # # chime.success()

In [None]:
# # Settings
# run_name = "oob + impurity + seed 24"

# user_input["seed_nr"] = 24
# user_input["method_importance"] = "impurity"  # permutation | impurity
# user_input["shap_on_test_or_train"] = "test"  # test | train
# user_input["use_all_shap_data"] = False

# # Run loop
# base_dir = create_new_run_folder_treemort_fullrun(run_name)
# st = start_time()
# for i, s in enumerate(species.index):
#     display(f"🟡🟡🟡 {i}/{len(species)}: {s} 🟡🟡🟡")
#     if i > 20:
#         print(f" - Skipping {i}")
#         continue
#     ist = start_time()
#     run_all(s, user_input, base_dir=base_dir)
#     clear_output(wait=True)
#     end_time(ist, None)

# end_time(st, user_input["current_dir"])

In [None]:
# Get all runs: every (seed, species) combination that should eventually be computed
all_seeds = pd.read_csv("allOriginalSeeds.csv").seed.tolist()
all_species = species.index.tolist()
all_runs = pd.DataFrame(
    list(itertools.product(all_seeds, all_species)),
    columns=["seed", "species"],
)

# Resolve the run folder of each seed once (instead of once per species row).
# Folder names contain " <seed> +" (see `run_name` below). glob() returns matches in
# arbitrary order, so sort them to make the chosen folder reproducible.
seed_dirs = {}
for seed in all_seeds:
    matches = sorted(glob.glob(f"./model_runs/_fullruns/* {seed} +*"))
    if not matches:
        print(f" - No folder found for seed: {seed}")
        continue
    if len(matches) > 1:
        print(f" - Several folders found for seed {seed}, using: {matches[0]}")
    seed_dirs[seed] = matches[0]


def _run_is_done(run_dir, species_name):
    """Return True if the species sub-folder of `run_dir` holds a completed run.

    A run counts as completed when `final_model_performance.csv` exists, or when the
    "... too few dead trees.txt" marker file exists. The marker is matched on its ASCII
    suffix so the check does not depend on how the emoji prefix of its name is encoded.
    """
    species_dir = os.path.join(run_dir, species_name)
    if os.path.isfile(os.path.join(species_dir, "final_model_performance.csv")):
        return True
    marker_pattern = os.path.join(glob.escape(species_dir), "*too few dead trees.txt")
    return any(os.path.isfile(p) for p in glob.glob(marker_pattern))


# "" marks seeds without a run folder yet
all_runs["dir"] = [seed_dirs.get(s, "") for s in all_runs["seed"]]
all_runs["done"] = [
    run_dir != "" and _run_is_done(run_dir, species_name)
    for run_dir, species_name in zip(all_runs["dir"], all_runs["species"])
]

# Get missing runs
runs_to_run = (
    all_runs.query("done == False")
    .sort_values(["species", "seed"])
    .reset_index(drop=True)
)

# ! Debug for running on multiple notebooks: keep only the first chunk returned by the
# splitter (presumably chunks grouped by seed — confirm in the helper).
if runs_to_run.empty:
    print(" - All runs are done, nothing left to run.")
else:
    l_runs_to_run = split_df_into_list_of_group_or_ns(runs_to_run, "seed", 5)
    runs_to_run = l_runs_to_run[0].reset_index(drop=True)

# A seed without a folder gets exactly ONE new folder, shared by all of its species.
# (Previously a new folder was created for every species row of such a seed.)
created_dirs = {}
for i, row in runs_to_run.iterrows():

    iseed = row.seed
    ispecies = row.species
    idir = row.dir

    if idir == "":
        if iseed not in created_dirs:
            # Create folder for run
            run_name = f"impurity + slow selection + {iseed} + correlation removal"
            created_dirs[iseed] = create_new_run_folder_treemort_fullrun(run_name)
        idir = created_dirs[iseed]
        # Select by seed: `i` indexes runs_to_run, not all_runs
        all_runs.loc[all_runs["seed"] == iseed, "dir"] = idir

    # Start run
    user_input["seed_nr"] = iseed
    display("")
    # i is 0-based; show a 1-based progress counter
    print(
        f"""
        --------------------------------------------------------------------------------
        Run {i + 1}/{runs_to_run.shape[0]}
        Seed: {iseed}
        Species: {ispecies}
        Dir: {idir}
        Started: {datetime.datetime.now().strftime('%Y-%m-%d @ %H:%M:%S')}
        --------------------------------------------------------------------------------
        """
    )
    ist = start_time(False)
    run_all(ispecies, user_input, base_dir=idir)
    clear_output(wait=True)
    end_time(ist, None, ring=False)

In [None]:
# ! osascript -e 'tell app "System Events" to shut down'

In [None]:
# # ! Code snippet for parallel processing. Reduce allOriginals to non-completed seeds
# # ! --------------------------------------------
# # Remove completed seeds
# completed_seeds = [
#     1220,
#     1221,
#     1222,
#     1223,
#     1224,
#     1225,
#     1226,
#     1227,
#     1228,
#     19991,
#     19992,
#     19993,
#     19994,
#     542,
#     569,
#     61,
#     610000,
#     612,
#     642,
#     669,
#     91,
#     910000,
#     912,
#     91996,
#     942,
#     969,
#     9991,
#     9992,
#     9993,
#     9994,
#     9995,
#     9996,
#     9997,
#     9998,
# ]
# allOriginalSeeds = [x for x in allOriginalSeeds if x not in completed_seeds]
# allOriginalSeeds = allOriginalSeeds[0:3]
# # ! --------------------------------------------

---


## Calculate SHAP Values


In [None]:
import sys

sys.path.insert(0, "../../src")
from imports import *

init_notebook()

In [None]:
# Species that have a trained model (plus the first nine of them)
final_species = get_species_with_models("list")
top9 = final_species[:9]

# Every full-run folder whose name mentions "impurity", in name order
base_dir = "./model_runs/_fullruns/"
models_dir = sorted(name for name in os.listdir(base_dir) if "impurity" in name)

# One row per (run folder, species) pair. The seed is read from the third
# " + "-separated field of the run folder name.
models_species = list(itertools.product(models_dir, final_species))
df_in = pd.DataFrame(models_species, columns=["model", "species"])
df_in["seed"] = df_in["model"].map(
    lambda run_name: run_name.split(" + ")[2].split(" +")[0]
)
display(df_in)

print(f"Number of models: {df_in['model'].nunique()}")
print(f"Number of species: {df_in['species'].nunique()}")
print(f"Number of seeds: {df_in['seed'].nunique()}")

In [None]:
# Check which runs are actually missing
missing_runs = []
for i, row in df_in.iterrows():
    idir = f"./model_runs/_fullruns/{row['model']}/{row['species']}/final_model_performance.csv"
    if not os.path.isfile(idir):
        # print(idir)
        missing_runs.append(row)

df_missing = pd.DataFrame(missing_runs)
print(f"{df_missing.shape[0]} missing runs")
df_missing

In [None]:
# Compute approximated SHAP values on the test split for every (model, species) row of
# df_in. num_cores presumably sets the number of parallel workers — confirm in helper.
shap_settings = dict(
    run_interaction=False,
    approximate=True,
    test_or_train="test",
    force_run=True,
    verbose=False,
    num_cores=9,
)
shap_run_new_loop_mp(df_in, **shap_settings)

In [None]:
# WARNING: IPython shell escape that runs `osascript` to tell macOS "System Events" to
# shut the computer down as soon as execution reaches this cell. Presumably meant for
# unattended long runs — comment it out before using "Run All" interactively.
! osascript -e 'tell app "System Events" to shut down'

## SHAP - Variable Importance


In [None]:
# Todos
# - ADD SKIP IF MODEL PERFORMANCE FILE IS NOT AVAILABLE! SIMPLE CHECK IF MODEL HAS BEEN RUN
# - Could add possibility to use train data instead of test but we are generally focusing on test anyways...

# Imports
import json
import shutil

# Species that have a trained model (plus the first nine of them)
final_species = get_species_with_models("list")
top9 = final_species[:9]

# Every full-run folder whose name mentions "impurity", in name order
base_dir = "./model_runs/_fullruns/"
models_dir = sorted(filter(lambda name: "impurity" in name, os.listdir(base_dir)))

# Cartesian product: one row per (run folder, species) pair
models_species = list(itertools.product(models_dir, final_species))
df_in = pd.DataFrame(models_species, columns=["model", "species"])

df_in

In [None]:
# Loop over runs and species and calculate mean absolute SHAP values per predictor,
# aggregate them per predictor dataset, save them, and attach them to the run's
# final_model_performance.csv (a one-time backup "_org" copy of that file is kept).

# Feature -> dataset dictionary: identical for every run, so load it once (and close it).
with open("./model_analysis/dict_preds.json") as fh:
    dict_preds = json.load(fh)


def _dataset_of(feature):
    """Return the key of `dict_preds` whose entry contains `feature`.

    If several entries contain it, the last one in dict order wins (as in the former
    nested loop). Returns None when no entry contains it.
    """
    match = None
    for key, value in dict_preds.items():
        if feature in value:
            match = key
    return match


for _, row in tqdm(df_in.iterrows(), total=df_in.shape[0]):
    run_dir = f"./model_runs/_fullruns/{row.model}/{row.species}"
    ifinalOrg = f"{run_dir}/final_model_performance_org.csv"
    ifinalNew = f"{run_dir}/final_model_performance.csv"

    # Skip runs whose model has not been run (no performance file to extend)
    if not os.path.isfile(ifinalNew):
        print(f" - Skipping {row.model}/{row.species}: no final_model_performance.csv")
        continue

    # Get predictor data
    ipreds = pd.read_csv(f"{run_dir}/final_model/X_test.csv", index_col=[0])

    # Get SHAP data
    ishap = f"{run_dir}/new_shap/approximated/shap_values_test.pkl"
    if not os.path.exists(ishap):
        raise ValueError(
            f" 游뚿 Skipping {row.model}/{row.species} because no SHAP values calculated yet!"
        )
    # Keep class index 1 of the third axis -> 2-D array: one row per prediction,
    # one column per predictor
    ishap = load_shap(ishap).values[:, :, 1]

    # Build the SHAP table in one step (float64, RangeIndex) instead of concatenating
    # one row at a time, which was quadratic in the number of predictions
    ishapAll = pd.DataFrame(ishap, columns=ipreds.columns, dtype="float64")

    # Safety check: Shape of predictors should be the same as for SHAP values
    if ipreds.shape != ishapAll.shape:
        print(
            f" - Issue: The shape of the predictor data should equal the shape of the concatenated SHAP values!"
        )

    # Mean absolute SHAP value per predictor, and its share (%) of the total
    ishapMean_org = ishapAll.abs().mean().sort_values(ascending=False)
    ishapMean = (ishapMean_org / ishapMean_org.sum()).to_frame("Importance")
    ishapMean["Importance"] = ishapMean["Importance"] * 100
    ishapMean["Feature"] = ishapMean.index
    ishapMean = ishapMean.reset_index(drop=True)

    # Link each feature to its predictor dataset (None if it is in no dataset; such
    # features are left out of the per-dataset sums below)
    ishapMean["dataset"] = ishapMean["Feature"].map(_dataset_of)

    # Sum up the VI for each dataset and rescale to percent of the mapped total
    ishapMean_of_dataset = (
        ishapMean[["Importance", "dataset"]]
        .groupby("dataset")
        .sum()
        .reset_index()
        .rename({"Importance": "dataset_imp"}, axis=1)
    )
    ishapMean_of_dataset["dataset_imp"] = (
        ishapMean_of_dataset["dataset_imp"]
        / ishapMean_of_dataset["dataset_imp"].sum()
        * 100
    )

    # Rows of ishapMean follow the sorted order of ishapMean_org, so values align
    ishapMean["mean_abs_shap_org"] = ishapMean_org.values

    # Attach dataset label with percentages, e.g. "42%: <dataset>"
    ishapMean_of_dataset["dataset_label"] = [
        str(round(imp)) + "%: " + name
        for imp, name in zip(
            ishapMean_of_dataset["dataset_imp"].to_numpy(),
            ishapMean_of_dataset["dataset"].to_numpy(),
        )
    ]

    ishapMean = ishapMean.merge(ishapMean_of_dataset, on="dataset", how="left")

    # Save SHAP data
    ishapMean.to_csv(f"{run_dir}/shap_variable_importance.csv")

    # If the original file has not yet been backed up, save it!
    if not os.path.exists(ifinalOrg):
        shutil.copy2(ifinalNew, ifinalOrg)

    # Load model performance file (a single row), attach SHAP information and save it.
    # Unmapped features (dataset None) are skipped: selecting them would be empty and
    # `.values[0]` would raise IndexError.
    ifinalNewDf = pd.read_csv(ifinalNew)

    for dataset in ishapMean["dataset"].dropna().unique():
        in_dataset = ishapMean["dataset"] == dataset
        ifinalNewDf[f"{dataset} - Importance"] = ishapMean.loc[
            in_dataset, "dataset_imp"
        ].values[0]
        ifinalNewDf[f"{dataset} - Metrics"] = [
            ishapMean.loc[in_dataset, "Feature"].values
        ]
        ifinalNewDf[f"{dataset} - Values"] = [
            ishapMean.loc[in_dataset, "Importance"].values
        ]

    ifinalNewDf.to_csv(ifinalNew, index=False)

In [None]:
# WARNING: IPython shell escape that runs `osascript` to tell macOS "System Events" to
# shut the computer down as soon as execution reaches this cell. Presumably meant for
# unattended long runs — comment it out before using "Run All" interactively.
! osascript -e 'tell app "System Events" to shut down'

# EOS
