# Calculate Growth and Mortality


## Setup


In [None]:
import sys

# Make the project's ``src`` folder importable. The star import below is what
# brings pd, np, plt, chime, stats, gridspec and the project helpers
# (get_final_nfi_data_for_analysis, init_notebook, ...) into scope —
# NOTE(review): confirm the exact export list in ../../src/imports.py.
sys.path.insert(0, "../../src")
from imports import *

init_notebook()

# Load NFI data
nfi_raw = get_final_nfi_data_for_analysis()
nfi_raw.shape  # not the last expression of the cell, so nothing is displayed
# Overwrite species_lat with species_lat2 so that all code below that filters
# or groups on "species_lat" uses the species_lat2 names.
nfi_raw["species_lat"] = nfi_raw["species_lat2"]

from IPython.display import clear_output

# Set figure font
plt.rcParams["font.family"] = "Arial"
from matplotlib.ticker import FormatStrFormatter

In [None]:
# Get all analyzed species
species_in_final_anlysis = get_species_with_models("list")

# Get species occurrence data (trees that were alive_alive or alive_dead).
# .copy() so that adding the "species" column below does not write into a
# slice of nfi_raw (SettingWithCopyWarning).
top_species_all = nfi_raw.query(
    "tree_state_change in ['alive_alive', 'alive_dead']"
).copy()

# Absolute and relative counts per (species, genus) pair, most frequent first
top_species_all["species"] = top_species_all["species_lat2"]
tmp_abs = top_species_all[["species", "genus_lat"]].value_counts()
tmp_norm = top_species_all[["species", "genus_lat"]].value_counts(normalize=True)

# Concat data
top_species_all = pd.concat(
    [tmp_abs, tmp_norm], axis=1, keys=["count", "percent"]
).reset_index()

# Reduce to the 52 most frequent species and recompute the percentage share
# relative to these 52 only (rounded to one decimal)
top52_species = top_species_all.head(52).copy()
top52_species["percent"] = (
    top52_species["count"] / top52_species["count"].sum() * 100
).round(1)

# Attach species title with percentage, e.g. "Abies alba (4.2%)"
top52_species["title"] = (
    top52_species["species"].astype(str)
    + " ("
    + top52_species["percent"].astype(str)
    + "%)"
)

# Reduce to top 9
top9_species = top52_species.head(9).copy()
top9 = top9_species.copy()

# Show
display(top9)
display(top52_species)

# Number of trees per tree_state_change group (alive_alive, alive_dead, alive_cut)
nfi_raw.query("tree_state_change in ['alive_alive', 'alive_dead', 'alive_cut']")[
    "tree_state_change"
].value_counts()

## Settings


In [None]:
# Settings
# Keyword arguments passed to start_to_finish_temporal_plot(kwargs, ...).
# Later cells overwrite "df", "my_grouping" and "load_from_file" and add a
# "suffix" entry before each call.
kwargs = {
    # ! General
    # NOTE(review): later cells set kwargs["suffix"], not "file_suffix" —
    # confirm which key start_to_finish_temporal_plot actually reads.
    "file_suffix": None,
    # ! Metric
    "my_metric": "mort_nat_stems_prc_yr",  # "mort_nat_stems_prc_yr" mort_nat_vol_yr, mort_nat_vol_prc_yr
    # ! Data Wrangling
    "df": nfi_raw.copy(),
    # "my_grouping": ["gre", "species_lat"],
    "my_grouping": ["gre"],
    "my_method": "direct_bs",
    "load_from_file": True,
    "top_n_groups": None,
    "n_bootstraps_samples": 100,
    # ! Data Filter
    "min_trees_per_site": 0,
    "min_sites_per_group_year": 0,
    "reduce_to_dominant_sites": False,
    "weigh_by_sites_or_trees": "none",  # none, sites, trees
    # ! Plotting
    "plot_type": "facet",  # all or facet
    "save_plot": True,
    "center_to_first_year": False,
    "normalize_to_first_year": False,
    "ylim": None,  # [-200, 400],  # None or [min, max]
    "facet_label_trees_or_sites": "trees",  # trees or sites
    "uncertainty_representation": "band",  # band, bar, none
    "uncertainty_variable": "std",
    "aggregation_variable": "mean",
    "top_n_metric": "most_observations",  # most_observations or highest_final_mortality
    "top_n_groups_plot": 15,
    # "genus_filter": None,  # None or list of genus_lat values to filter for # ! Not implemented yet, needs a species-genus dictionary!
}


def custom_sort_key(element):
    """Sort key placing known region levels first, in a fixed priority order.

    Elements listed in ``priority_order`` get their position in that list as
    rank; every other element (e.g. "species_lat") gets a rank after all of
    them. Ties within a rank are broken alphabetically by the element itself.

    Returns:
        A ``(rank, element)`` tuple suitable for ``sorted(..., key=...)``.
    """
    priority_order = ["gre", "ser", "hex", "reg", "dep"]
    rank = (
        priority_order.index(element)
        if element in priority_order
        else len(priority_order)
    )
    return (rank, element)

In [None]:
# Input
# NOTE: what_species and by_species_or_direct set here only feed the suffix of
# the temporal-trend cell; the spatial-trend cell sets its own values.
iregion = "gre"  # gre, ser, hex, reg, dep | Region to plot
what_species = "top9"  # top9, all, final_analysis | Relevant whether map trend is based on all or top9 species
by_species_or_direct = "by_species"  # by_species or directly | Relevant whether map trend is weighted by species or calculated directly
# Folder holding the cached data files of this notebook (created if missing)
fig1_folder = "../../data/final/mortality_trends"
os.makedirs(fig1_folder, exist_ok=True)

## Data for temporal trends


In [None]:
# Cached per-species time series for the top-9 facet grid
file_temp = f"{fig1_folder}/df_top9_temporal_data.feather"

if os.path.exists(file_temp):
    df_top9 = pd.read_feather(file_temp)
    print(f"Loaded data from file: {file_temp}")
    chime.info()
else:
    # Only need the top9 species for the facet grid
    kwargs["suffix"] = (
        f"temporal_trend-region_{iregion}-species_{what_species}-calculation_{by_species_or_direct}"
    )
    kwargs["df"] = nfi_raw.copy().query("species_lat in @top9_species['species']")
    kwargs["my_grouping"] = ["species_lat"]
    kwargs["load_from_file"] = False

    print(kwargs["suffix"])

    # Run function
    kwargs["my_grouping"] = sorted(kwargs["my_grouping"], key=custom_sort_key)
    df_counts, df_gm, df_grouped = start_to_finish_temporal_plot(
        kwargs,
        return_before_plotting=True,
    )
    chime.success()

    df_top9_raw = df_grouped.copy()

    # ! Wrangling
    df_top9 = df_top9_raw.copy()

    # "group_year" is "<group>_<year>": split on the LAST underscore only, so a
    # group label that itself contains "_" does not break the year parsing.
    group_year_parts = df_top9["group_year"].str.rsplit("_", n=1, expand=True)
    df_top9["group"] = group_year_parts[0]
    df_top9["year"] = group_year_parts[1].astype(int)

    # Extract mean and std of the selected metric
    df_top9["mean"] = df_top9[f"{kwargs['my_metric']}_mean"]
    df_top9["std"] = df_top9[f"{kwargs['my_metric']}_std"]

    # Reduce df size; the region is the part of the group before "&" (if any)
    df_top9 = df_top9[["group", "year", "group_year", "mean", "std"]].copy()
    df_top9["region"] = df_top9["group"].str.split("&", expand=True)[0]

    # Remember which metric the values refer to
    df_top9["target"] = kwargs["my_metric"]

    # Cache to the same path that is checked above
    df_top9.to_feather(file_temp)

df_top9.head()

## Data for spatial trends


In [None]:
# ! Input
by_species_or_direct = "directly"  # by_species or directly
what_species = "all"  # top9 or all (= all 52 species in final analysis)

# Setup: select the trees entering the map trend
if what_species == "all":
    kwargs["df"] = nfi_raw.copy().query("species_lat in @species_in_final_anlysis")
elif what_species == "top9":
    kwargs["df"] = nfi_raw.copy().query("species_lat in @top9_species['species']")
else:
    chime.error()
    raise ValueError("what_species must be 'all' or 'top9'")

# Grouping: per region (directly) or per region & species (by_species)
if by_species_or_direct == "by_species":
    kwargs["my_grouping"] = [iregion, "species_lat"]
elif by_species_or_direct == "directly":
    kwargs["my_grouping"] = [iregion]
else:
    chime.error()
    raise ValueError("by_species_or_direct must be 'by_species' or 'directly'")

kwargs["suffix"] = (
    f"map_trend-region_{iregion}-species_{what_species}-calculation_{by_species_or_direct}"
)
kwargs["load_from_file"] = True
# Cache file for the map data. Built from fig1_folder and iregion so that it
# always matches the grouping above (it used to be hard-coded to "gre").
filename_df_map = f"{fig1_folder}/df_map_trend-region_{iregion}-species_{what_species}-calculation_{by_species_or_direct}.feather"
print(filename_df_map)

In [None]:
if os.path.exists(filename_df_map):
    # Load the cached map data and re-attach the region geometry
    print(f"Loading: {filename_df_map}")
    df_map = (
        pd.read_feather(filename_df_map)
        .drop(columns=["gre_num", "gre_name", "geometry"])
        .rename(columns={kwargs["my_grouping"][0]: "region"})
    )

    # ! Turn into gdf
    # Attach geometry
    shp_region = get_shp_of_region(
        kwargs["my_grouping"][0], make_per_year=None, make_per_group=None
    ).rename(columns={kwargs["my_grouping"][0]: "region"})
    df_map = shp_region.merge(df_map, on="region")

    chime.info()

else:
    # ! Run function -----------------------------------
    kwargs["my_grouping"] = sorted(kwargs["my_grouping"], key=custom_sort_key)
    kwargs["load_from_file"] = False
    df_counts, df_gm, df_grouped = start_to_finish_temporal_plot(
        kwargs,
        return_before_plotting=True,
    )
    chime.success()

    # Region column used in the NFI data and the region shapefile
    region_col = kwargs["my_grouping"][0]

    # Save raw data
    df_map_raw = df_grouped.copy()
    df_map_raw.to_feather(filename_df_map.replace("df_map_trend", "df_map_trend_raw"))

    # ! Wrangling (identical for both calculation modes) ---------------
    df_map = df_map_raw.copy()

    # "group_year" is "<group>_<year>": split on the LAST underscore only
    group_year_parts = df_map["group_year"].str.rsplit("_", n=1, expand=True)
    df_map["group"] = group_year_parts[0]
    df_map["year"] = group_year_parts[1].astype(int)

    # Extract mean and std of the selected metric
    df_map["mean"] = df_map[f"{kwargs['my_metric']}_mean"]
    df_map["std"] = df_map[f"{kwargs['my_metric']}_std"]

    # Reduce df size; groups are "<region>" or "<region>&<species>"
    df_map = df_map[["group", "year", "group_year", "mean", "std"]].copy()
    df_map["region"] = df_map["group"].str.split("&", expand=True)[0]

    if by_species_or_direct == "directly":
        # Per region: slope of a linear fit of the mean value over year
        slopes = df_map.groupby("group").apply(
            lambda x: np.polyfit(x["year"], x["mean"], 1)[0], include_groups=False
        )
        slopes = slopes.reset_index().rename({0: "slope"}, axis=1)

        # Merge the slopes onto the region shapefile
        shp = get_shp_of_region(region_col, make_per_year=None, make_per_group=None)
        shp = shp.merge(slopes, left_on=region_col, right_on="group")

    elif by_species_or_direct == "by_species":
        # Per region&species group: OLS regression of mean mortality on year
        records = []
        for name, group in df_map.groupby("group"):
            fit = stats.linregress(group["year"].values, group["mean"].values)
            records.append(
                {
                    "group": name,
                    "slope": fit.slope,
                    "intercept": fit.intercept,
                    "p_value": fit.pvalue,
                }
            )
        results = pd.DataFrame(records)

        # ! Weights = share of a species' trees among all trees of its region
        dfw = nfi_raw.copy()
        dfw["group"] = (
            dfw[region_col].astype(str) + "&" + dfw["species_lat"].astype(str)
        )
        dfw = (
            dfw.query("tree_state_change in ['alive_alive', 'alive_dead']")
            .groupby("group", observed=True)
            .agg(weights=("tree_id", "count"))
            .reset_index()
        )
        dfw["region"] = dfw["group"].str.split("&", expand=True)[0]
        dfw["weights"] = dfw["weights"] / dfw.groupby("region")["weights"].transform(
            "sum"
        )

        # ! Merge with results (inner join: only groups that have a fitted trend)
        results = results.merge(dfw, on="group")
        results["region"] = results["group"].str.split("&", expand=True)[0]

        # Tree-weighted mean slope per region: sum(w * s) / sum(w).
        # A plain mean of w * s would divide by the number of species a second
        # time (the weights already sum to 1 per region), shrinking regions with
        # many species. Dividing by sum(w) renormalises when some species of a
        # region have no fitted trend and were dropped by the merge above.
        # (Filtering on p_value < 0.05 was considered but is not applied.)
        results["weighted_slope"] = results["slope"] * results["weights"]
        results = results.groupby("region").agg(
            weighted_slope_sum=("weighted_slope", "sum"),
            weight_sum=("weights", "sum"),
        )
        results["slope"] = results["weighted_slope_sum"] / results["weight_sum"]
        results = results[["slope"]].reset_index()

        # ! Turn into gdf
        shp_region = get_shp_of_region(
            region_col, make_per_year=None, make_per_group=None
        ).rename(columns={region_col: "region"})
        shp = shp_region.merge(results, on="region")
    else:
        chime.error()
        raise ValueError("by_species_or_direct must be 'by_species' or 'directly'")

    shp.to_feather(filename_df_map)
    # Same "region" column name as when the map is loaded from file
    df_map = shp.rename(columns={region_col: "region"})

df_map

## Plotting functions


In [None]:
from matplotlib.ticker import FormatStrFormatter


def plot_species_trend(species, df, ax=None, font_scaler=1, year_offset=5):
    """Plot the mean time series of one group with a +/- 1 std band.

    Parameters
    ----------
    species : str
        Value of ``df["group"]`` to plot.
    df : pandas.DataFrame
        Long table with at least the columns ``group``, ``year``, ``mean`` and
        ``std`` (as built by the temporal-trend cell of this notebook).
    ax : matplotlib Axes, optional
        Axes to draw on; a new 6x6 figure is created when None.
    font_scaler : float
        Multiplier for the base font size of 12.
    year_offset : int
        Added to ``year`` before plotting (the former hard-coded "+5 bugfix").
        NOTE(review): the reason for the shift is not visible here — confirm.

    Returns
    -------
    matplotlib Axes
        The axes that was drawn on.
    """
    # .copy(): the year shift below must not write into a slice of ``df``
    df_species = df.query("group == @species").sort_values("year").copy()

    if ax is None:
        _, ax = plt.subplots(figsize=(6, 6))

    # ! Bugfix for the year
    df_species["year"] = df_species["year"] + year_offset
    ax.plot(df_species["year"], df_species["mean"], color="black")
    ax.fill_between(
        df_species["year"],
        df_species["mean"] - df_species["std"],
        df_species["mean"] + df_species["std"],
        alpha=0.3,
        color="black",
        edgecolor=None,
    )

    # Labels
    ax.set_xlabel("Year", fontsize=12 * font_scaler)
    ax.set_ylabel("Mortality Rate (%-stems yr$^{-1}$)", fontsize=12 * font_scaler)

    # Reduce y-ticks
    ax.yaxis.set_major_locator(plt.MaxNLocator(3))

    # Reduce x-ticks: 9 distinct years in ``df`` means the 2023 census is included
    nxticks = 5 if df.year.nunique() == 9 else 4

    # Format ticks
    ax.xaxis.set_major_locator(plt.MaxNLocator(nxticks))
    ax.yaxis.set_major_formatter(FormatStrFormatter("%.1f"))
    ax.tick_params(axis="both", which="major", labelsize=12 * font_scaler)

    # Remove upper and right spines
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)

    return ax


# plot_species_trend("Abies alba", df_top9, font_scaler=1)
# plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import geopandas as gpd
from shapely.geometry import box, Polygon


def _jittered_points(trees, jitter_amount):
    """Return ``trees`` as a point GeoDataFrame with +/- jitter_amount uniform noise on lat_fr/lon_fr."""
    trees = trees.copy()
    trees["lat_fr"] = trees["lat_fr"] + np.random.uniform(
        -jitter_amount, jitter_amount, len(trees)
    )
    trees["lon_fr"] = trees["lon_fr"] + np.random.uniform(
        -jitter_amount, jitter_amount, len(trees)
    )
    return gpd.GeoDataFrame(
        trees, geometry=gpd.points_from_xy(trees.lon_fr, trees.lat_fr)
    )


def get_species_dist(
    species,
    nfi_data,
    jitter_amount=1000,
    dot_size=5,
    ax=None,
    bold_names=False,
    font_scaler=1,
    with_title=True,
):
    """Map the NFI trees of one species in France: alive in grey, dead in red.

    Parameters
    ----------
    species : str
        Value of ``nfi_data["species_lat2"]`` to show.
    nfi_data : pandas.DataFrame
        Tree table with ``species_lat2``, ``tree_state_change``, ``lat_fr`` and
        ``lon_fr`` columns.
    jitter_amount : float
        Half-width of the uniform noise added to both coordinates (same units
        as lat_fr/lon_fr) to reduce overplotting.
    dot_size : float
        Marker size of the points.
    ax : matplotlib Axes, optional
        Axes to draw on. When None a 4x4 figure is created and shown.
    bold_names : bool
        Bold title font.
    font_scaler : float
        Multiplier for the base title font size of 12.
    with_title : bool
        Whether to write the species name as title.
    """
    # Get data
    qr_alive = nfi_data.copy().query(
        "species_lat2 == @species and tree_state_change == 'alive_alive'"
    )
    qr_dead = nfi_data.copy().query(
        "species_lat2 == @species and tree_state_change == 'alive_dead'"
    )

    shp_france = get_shp_of_region("cty", make_per_year=None, make_per_group=None)

    # Jitter and convert to points (alive first, then dead)
    qr_alive = _jittered_points(qr_alive, jitter_amount)
    qr_dead = _jittered_points(qr_dead, jitter_amount)

    show_plot = ax is None
    if show_plot:
        _, ax = plt.subplots(figsize=(4, 4))

    # Mask covering everything outside France. The margin includes the jitter
    # so that points pushed beyond the national bounding box are covered too.
    margin = 2 + jitter_amount
    xmin, ymin, xmax, ymax = shp_france.total_bounds
    bounding_box = box(xmin - margin, ymin - margin, xmax + margin, ymax + margin)
    mask = gpd.GeoDataFrame(
        geometry=[bounding_box.difference(shp_france.unary_union)], crs=shp_france.crs
    )

    # Draw order matters (later layers are drawn on top):
    # white France background -> points -> white outside mask -> border.
    shp_france.plot(ax=ax, color="white", edgecolor="none")
    qr_alive.plot(
        ax=ax,
        color="lightgrey",
        markersize=dot_size,
        alpha=1,
        label="Alive",
        legend=True,
    )
    qr_dead.plot(
        ax=ax,
        color="red",
        markersize=dot_size,
        alpha=0.5,
        label="Dead",
        legend=True,
    )
    # Drawn after the points so that it actually hides points outside France
    mask.plot(ax=ax, color="white", edgecolor="none")
    shp_france.boundary.plot(
        ax=ax,
        edgecolor="black",
        linewidth=0.5,
    )

    ax.axis("off")

    if with_title:
        title_weight = "bold" if bold_names else "normal"
        ax.set_title(
            f"{species}",
            fontdict={"weight": title_weight, "fontsize": 12 * font_scaler},
            loc="left",
            pad=-10,
        )

    # Show the plot only if we created the figure
    if show_plot:
        plt.show()


# get_species_dist("Quercus robur", nfi_raw, jitter_amount=1000, font_scaler=2)

In [None]:
def plot_trend_with_inset(
    species,
    df_trend,
    df_nfi,
    mytitle=None,
    jitter_amount=1000,
    dot_size=5,
    ax=None,
    bold_names=False,
    font_scaler_graph=1,
    font_scaler_inset=1,
):
    """Draw a species' mortality trend and inset its distribution map.

    The trend is drawn by plot_species_trend on ``ax``. The alive/dead map from
    get_species_dist goes into an inset in the upper-left corner (40% x 40% of
    the axes). The title defaults to the species name.

    Returns
    -------
    matplotlib Axes
        The trend axes (created as a 6x6 figure when ``ax`` is None).
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(6, 6))

    plot_species_trend(species, df_trend, ax=ax, font_scaler=font_scaler_graph)

    map_ax = ax.inset_axes([0.02, 0.55, 0.4, 0.4])
    get_species_dist(
        species,
        df_nfi,
        jitter_amount=jitter_amount,
        dot_size=dot_size,
        ax=map_ax,
        bold_names=bold_names,
        font_scaler=font_scaler_inset,
        with_title=False,
    )

    title_text = f"{species}" if mytitle is None else mytitle
    title_weight = "bold" if bold_names else "normal"
    ax.set_title(
        title_text,
        fontdict={"weight": title_weight, "fontsize": 18 * font_scaler_graph},
        loc="left",
        pad=-10,
        position=(0.025, 0),
    )
    return ax


# plot_trend_with_inset(
#     "Abies alba",
#     df_top9,
#     nfi_raw,
#     dot_size=2,
#     font_scaler_inset=1.2,
#     font_scaler_graph=1.2,
# )

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors


def make_map(shp, filepath=None, ts_legend=18, ts_ticks=14, ax=None, add_letters=False):
    """Plot per-region slopes as a choropleth with a colour scale symmetric around 0.

    Parameters
    ----------
    shp : GeoDataFrame
        Regions with a numeric ``slope`` column, plus a ``region`` column when
        ``add_letters`` is True.
    filepath : str, optional
        Where to save the map (300 dpi); only used when ``ax`` is None.
    ts_legend, ts_ticks : int
        Font sizes of the colorbar label and of its tick labels.
    ax : matplotlib Axes, optional
        Axes to draw on. When None a new 8x8 figure is created, optionally
        saved to ``filepath``, and shown.
    add_letters : bool
        Write each row's ``region`` value at a point inside its polygon.

    Returns
    -------
    matplotlib Axes or None
        The axes when ``ax`` was given, otherwise None.
    """
    # Explicit import: the submodule is not guaranteed to be loaded by pyplot,
    # so ``plt.matplotlib.patheffects`` may not exist.
    import matplotlib.patheffects as patheffects

    no_input = ax is None
    if no_input:
        _, ax = plt.subplots(1, 1, figsize=(8, 8))
    fig = ax.get_figure()
    # Remember the existing axes so the colorbar added by shp.plot can be found
    n_axes_before = len(fig.axes)

    # Colorbar settings (symmetric around 0)
    vmin = min(shp["slope"].min(), -shp["slope"].max())
    vmax = max(shp["slope"].max(), -shp["slope"].min())
    ticks = np.linspace(vmin, vmax, 5)

    # Plot regions
    shp.plot(
        column="slope",
        legend=True,
        cmap="coolwarm",
        edgecolor="black",
        linewidth=0.5,
        legend_kwds={
            "label": "Change in Mortality Rate (%-stems yr$^{-2}$)",
            "orientation": "horizontal",
            "shrink": 0.8,
            "pad": -0.08,
            "fraction": 0.03,
            "ticks": ticks,
            "format": "%.2f",
            "extend": "both",
        },
        ax=ax,
        vmin=vmin,
        vmax=vmax,
    )

    ax.axis("off")

    # Region labels at representative points (always inside the polygon,
    # unlike centroids of complex shapes), with a white outline for contrast
    if add_letters:
        for _, row in shp.iterrows():
            point = row["geometry"].representative_point()
            ax.text(
                point.x,
                point.y,
                row["region"],
                ha="center",
                va="center",
                fontsize=14,
                fontweight="bold",
                color="black",
                path_effects=[patheffects.withStroke(linewidth=2, foreground="white")],
            )

    # Colorbar font sizes: use the axes created by shp.plot; fall back to the
    # last axes of the figure if none was added
    new_axes = fig.axes[n_axes_before:]
    cbar = new_axes[-1] if new_axes else fig.axes[-1]
    cbar.set_xlabel(cbar.get_xlabel(), fontsize=ts_legend)
    cbar.tick_params(labelsize=ts_ticks)

    # Layout of the figure that owns ``ax`` (not whatever figure is current)
    fig.tight_layout(rect=[0, 0, 1, 0.95])

    if no_input:
        if filepath is not None:
            fig.savefig(filepath, dpi=300, bbox_inches="tight")
            print(f"Saved map to {filepath}")
        plt.show()
    else:
        return ax


# ifolder = "specific_runs/fig_1"
# if not os.path.exists(ifolder):
#     os.makedirs(ifolder)
# filepath = f"{ifolder}/{kwargs['suffix']}.png"
# make_map(df_map, filepath=filepath, ts_legend=16, ts_ticks=14)

In [None]:
def climate_evolution_plot(
    ax, dataset, season="13", first_year=1980, contour_levels=7, final_file=False
):
    """Draw the per-pixel trend map of a climate variable onto ``ax``.

    The trend table for ``first_year``..2020 is built by
    produce_dfs_for_climate_evolution (agg_factor_km=1, load_file=True) and
    rendered by make_map_for_temp_prec_cover.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    dataset : str
        Climate dataset name (this notebook uses "tmoy" and "prec").
    season : str
        Season code forwarded to both helpers.
    first_year : int
        First year of the trend period; the last year is fixed at 2020.
    contour_levels : int
        Number of contour levels of the map.
    final_file : bool
        Forwarded to both helpers.

    Returns
    -------
    Whatever make_map_for_temp_prec_cover returns (used as the axes).
    """
    trend_table = produce_dfs_for_climate_evolution(
        agg_factor_km=1,
        dataset=dataset,
        first_year=first_year,
        last_year=2020,
        season=season,
        load_file=True,
        final_file=final_file,
    )

    colorbar_layout = {
        "cbar_pad": -0.02,
        "cbar_fraction": 0.2,
        "cbar_shrink": 0.6,
        "cbar_aspect": 20,
    }
    return make_map_for_temp_prec_cover(
        trend_table,
        dataset,
        season=season,
        pixel_res=500j,
        textsize=16,
        contour_levels=contour_levels,
        ax=ax,
        final_file=final_file,
        **colorbar_layout,
    )


def forest_cover_plot(ax, final_file=False):
    """Draw the forest-cover map onto ``ax``.

    The Hansen (2013) forest-cover raster is aggregated with
    ``agg_factor_m=2000`` (kept in memory, nothing saved). The result is
    rendered by make_map_for_temp_prec_cover as "treecover" with 11 contour
    levels.

    Returns
    -------
    Whatever make_map_for_temp_prec_cover returns (used as the axes).
    """
    raster_path = "../../data/final/forest_cover/hansen2013highresolution-forestcover.tif"
    cover_table = aggregate_raster_to_csv(
        input_raster_path=raster_path,
        output_csv_path=None,
        agg_factor_m=2 * 1000,
        save_file=False,
        verbose=False,
    )

    return make_map_for_temp_prec_cover(
        cover_table,
        "treecover",
        pixel_res=500j,
        textsize=16,
        contour_levels=11,
        tick_interval=1,
        ax=ax,
        final_file=final_file,
    )


def add_text_to_plot(fig=None):
    """Add the bold panel letters A-E to Figure 1 and set the grid spacing.

    Parameters
    ----------
    fig : matplotlib Figure, optional
        Figure to annotate. Defaults to the current figure (``plt.gcf()``).
        The previous version silently required a global variable named ``fig``.
    """
    if fig is None:
        fig = plt.gcf()

    # (letter, x, y) in figure coordinates
    panel_letters = [
        ("A", 0.022, 0.91),
        ("B", 0.022, 0.61),
        ("C", 0.022, 0.335),
        ("D", 0.2, 0.91),
        ("E", 0.58, 0.91),
    ]
    for letter, x, y in panel_letters:
        # Figure.text places text in figure coordinates (transFigure) directly
        fig.text(x, y, letter, fontsize=24, fontweight="bold")

    fig.subplots_adjust(
        wspace=0.15,  # Space between columns
        hspace=0.1,  # Space between rows
    )

## Figure Examples


In [None]:
# # Note: Cell takes up to xxx minutes to run
# fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# climate_evolution_plot(
#     ax=ax[0], dataset="prec", season="13", first_year=1980, final_file=True
# )
# climate_evolution_plot(
#     ax=ax[1], dataset="tmoy", season="13", first_year=1980, final_file=True
# )
# plt.show()

In [None]:
# # 3x3 trend plot
# fig, ax = plt.subplots(3, 3, figsize=(12, 12))
# axs = ax.flatten()
# iax = 0
# # Plot the trends for the top 9 species

# for i in top9.species:
#     plot_species_trend(
#         i,
#         df_top9.query("group == @i"),
#         ax=axs[iax],
#         font_scaler=1,
#     )
#     iax += 1
#     # Re

In [None]:
# # Load map data
# filename_df_map = "../../data/final/mortality_trends/df_map_trend-region_gre-species_top9-calculation_directly.feather"
# df_map = (
#     pd.read_feather(filename_df_map)
#     .drop(columns=["gre_num", "gre_name", "geometry"])
#     .rename(columns={kwargs["my_grouping"][0]: "region"})
# )
# # Attach geometry
# shp_region = get_shp_of_region(
#     kwargs["my_grouping"][0], make_per_year=None, make_per_group=None
# ).rename(columns={kwargs["my_grouping"][0]: "region"})
# df_map = shp_region.merge(df_map, on="region")

# # Make map
# make_map(df_map, ax=None, ts_legend=20, ts_ticks=14)

## Figure 1


In [None]:
# ! Load map data
# (alternative: ...-species_top9-calculation_by_species.feather)
filename_df_map = "../../data/final/mortality_trends/df_map_trend-region_gre-species_top9-calculation_directly.feather"
print(f"Loading: {filename_df_map}")
df_map = (
    pd.read_feather(filename_df_map)
    .drop(columns=["gre_num", "gre_name", "geometry"])
    .rename(columns={kwargs["my_grouping"][0]: "region"})
)

# Attach geometry
shp_region = get_shp_of_region(
    kwargs["my_grouping"][0], make_per_year=None, make_per_group=None
).rename(columns={kwargs["my_grouping"][0]: "region"})
df_map = shp_region.merge(df_map, on="region")

# ! Run the figure
# first_years = [1960, 1980, 2000]
first_years = [1980]

for first_year in first_years:
    # Layout: 3 rows x 8 columns
    #   col 0    : forest cover, temperature change, precipitation change
    #   cols 1-3 : mortality-trend map spanning all rows
    #   col 4    : narrow spacer
    #   cols 5-7 : 3x3 grid of species trend plots
    fig = plt.figure(figsize=(26, 10))
    gs = gridspec.GridSpec(
        3,
        8,
        figure=fig,
        height_ratios=[1, 1, 1],
        width_ratios=[1, 1.2, 1.2, 1.2, 0.05, 1, 1, 1],
    )

    # Left column axes
    ax_left1 = fig.add_subplot(gs[0, 0])
    ax_left2 = fig.add_subplot(gs[1, 0])
    ax_left3 = fig.add_subplot(gs[2, 0])

    # Mortality map spanning columns 1-3
    ax_map = fig.add_subplot(gs[:, 1:4])

    # 3x3 trend axes in row-major order (panel 1 = top-left, 9 = bottom-right)
    trend_axes = [
        fig.add_subplot(gs[row, col]) for row in range(3) for col in (5, 6, 7)
    ]

    # Only the middle-left (y) and bottom-middle (x) panels keep axis labels
    ax_central_y = trend_axes[3]
    ax_central_x = trend_axes[7]

    # ! Left Column
    ax_left1 = forest_cover_plot(ax_left1, final_file=True)
    ax_left2 = climate_evolution_plot(
        ax_left2, "tmoy", season="13", first_year=first_year, final_file=True
    )
    ax_left3 = climate_evolution_plot(
        ax_left3, "prec", season="13", first_year=first_year, final_file=True
    )

    # ! Mortality Map
    ax_map = make_map(df_map, ax=ax_map, ts_legend=20, ts_ticks=14)

    # ! Trends
    for i, (ax, species, title) in enumerate(
        zip(trend_axes, top9_species["species"], top9_species["title"]), start=1
    ):
        plot_trend_with_inset(
            species,
            df_top9,
            nfi_raw,
            mytitle=title,
            jitter_amount=1000,
            dot_size=1,
            ax=ax,
            font_scaler_graph=1,
        )
        ax.tick_params(axis="both", which="major", labelsize=14)

        if ax is not ax_central_y:
            ax.set_ylabel("")
        if ax is not ax_central_x:
            ax.set_xlabel("")

        # Only the bottom row (panels 7-9) keeps its x ticks
        if i not in (7, 8, 9):
            ax.set_xticklabels([])
            ax.set_xticks([])
            ax.spines["bottom"].set_visible(False)

    # Larger shared axis labels on the central panels
    ax_central_y.set_ylabel(
        "Mortality Rate (%-stems yr$^{-1}$)",
        fontsize=22,
        labelpad=15,
    )
    ax_central_x.set_xlabel(
        "Year",
        fontsize=22,
        labelpad=15,
    )
    ax_central_y.yaxis.set_label_position("left")
    ax_central_x.xaxis.set_label_position("bottom")

    # ! Add letters to the subplots
    add_text_to_plot()

    # One file per period; keep the historic name when only one period is drawn
    # (previously every iteration overwrote the same file)
    if len(first_years) == 1:
        out_file = "overview-climate_change-mortality.png"
    else:
        out_file = f"overview-climate_change-mortality-{first_year}_period.png"
    plt.savefig(out_file, dpi=300)
    plt.show()
    plt.close()

chime.success()

## Figure S1: Species Trends


In [None]:
# Settings
# Resets kwargs (mutated by the cells above) to the same defaults as in the
# "Settings" section before the Figure S1 runs. Later cells overwrite "df",
# "my_grouping" and "load_from_file" and add "suffix" before each call.
kwargs = {
    # ! General
    # NOTE(review): later cells set kwargs["suffix"], not "file_suffix" —
    # confirm which key start_to_finish_temporal_plot actually reads.
    "file_suffix": None,
    # ! Metric
    "my_metric": "mort_nat_stems_prc_yr",  # "mort_nat_stems_prc_yr" mort_nat_vol_yr, mort_nat_vol_prc_yr
    # ! Data Wrangling
    "df": nfi_raw.copy(),
    # "my_grouping": ["gre", "species_lat"],
    "my_grouping": ["gre"],
    "my_method": "direct_bs",
    "load_from_file": True,
    "top_n_groups": None,
    "n_bootstraps_samples": 100,
    # ! Data Filter
    "min_trees_per_site": 0,
    "min_sites_per_group_year": 0,
    "reduce_to_dominant_sites": False,
    "weigh_by_sites_or_trees": "none",  # none, sites, trees
    # ! Plotting
    "plot_type": "facet",  # all or facet
    "save_plot": True,
    "center_to_first_year": False,
    "normalize_to_first_year": False,
    "ylim": None,  # [-200, 400],  # None or [min, max]
    "facet_label_trees_or_sites": "trees",  # trees or sites
    "uncertainty_representation": "band",  # band, bar, none
    "uncertainty_variable": "std",
    "aggregation_variable": "mean",
    "top_n_metric": "most_observations",  # most_observations or highest_final_mortality
    "top_n_groups_plot": 15,
    # "genus_filter": None,  # None or list of genus_lat values to filter for # ! Not implemented yet, needs a species-genus dictionary!
}


def custom_sort_key(element):
    """Sort key that puts the known region-level grouping columns first.

    Elements among ("gre", "ser", "hex", "reg", "dep") rank by their position
    in that sequence. Every other element ranks after all of them. Ties within
    a rank are broken by the element itself.

    Returns:
        tuple: (rank, element)
    """
    region_levels = ("gre", "ser", "hex", "reg", "dep")
    rank = (
        region_levels.index(element)
        if element in region_levels
        else len(region_levels)
    )
    return (rank, element)

In [None]:
# ! Settings for Figure S1
# All species that are part of the final analysis
species52 = get_species_with_models("list")

# Region level to plot: gre, ser, hex, reg, dep
iregion = "gre"

# Species set the map trend is based on: top9, all, final_analysis
what_species = "final_analysis"

# "by_species": trend averaged over species; "directly": trend calculated directly
by_species_or_direct = "by_species"

# Folder where the figure data is stored
# (older location: ./specific_runs/fig_si-52_species_temporal/data)
fig1_folder = "../../data/final/mortality_trends/"
os.makedirs(fig1_folder, exist_ok=True)

In [None]:
# Load the per-species temporal mortality table from cache, or compute and cache it.
# The same path (file_temp) is used for both the existence check and the save.
file_temp = f"{fig1_folder}/df_species52_temporal_data.feather"

if os.path.exists(file_temp):
    df_species52 = pd.read_feather(file_temp)
    print(f"Loaded data from file: {file_temp}")
    chime.info()
else:
    # Only the species52 species are needed for the facet grid
    kwargs["suffix"] = (
        f"temporal_trend-region_{iregion}-species_{what_species}-calculation_{by_species_or_direct}"
    )
    kwargs["df"] = nfi_raw.copy().query("species_lat2 in @species52")
    kwargs["my_grouping"] = ["species_lat2"]
    kwargs["load_from_file"] = False

    print(kwargs["suffix"])

    # Run the trend pipeline, returning the grouped data before any plotting
    kwargs["my_grouping"] = sorted(kwargs["my_grouping"], key=custom_sort_key)
    df_counts, df_gm, df_grouped = start_to_finish_temporal_plot(
        kwargs,
        return_before_plotting=True,
    )
    chime.success()

    df_species52_raw = df_grouped.copy()

    # ! Wrangling
    df_species52 = df_species52_raw.copy()

    # Group is the part of "group_year" before the first "_", year the part after it.
    # Split once and reuse both pieces.
    group_year_parts = df_species52["group_year"].str.split("_", expand=True)
    df_species52["group"] = group_year_parts[0]
    df_species52["year"] = group_year_parts[1].astype(int)

    # Extract mean and std of the selected metric
    df_species52["mean"] = df_species52[f"{kwargs['my_metric']}_mean"]
    df_species52["std"] = df_species52[f"{kwargs['my_metric']}_std"]

    # Reduce df size
    df_species52 = df_species52[["group", "year", "group_year", "mean", "std"]].copy()
    df_species52["region"] = df_species52["group"].str.split("&", expand=True)[0]

    # Save to the same path that is checked above, so the next run hits the cache
    df_species52.to_feather(file_temp)

df_species52.head(20)

In [None]:
# Note: Cell takes up to 5 minutes to run


def _plot_species_trend_grid(species_list, nrows, ncols, outfile):
    """Plot one mortality-trend panel per species on a grid and save the figure.

    The grid has at least ``nrows`` rows. More rows are added when the species
    do not fit, and panels without a species are hidden.

    Parameters:
        species_list (list): Species to plot, in panel order.
        nrows (int): Minimum number of grid rows.
        ncols (int): Number of grid columns.
        outfile (str): Path of the saved PNG.
    """
    # Ceil division: make sure every species gets a panel
    nrows = max(nrows, -(-len(species_list) // ncols))
    # squeeze=False always returns a 2D array, so flatten() works for any grid
    fig, axs = plt.subplots(nrows, ncols, figsize=(20, 20), squeeze=False)
    axs = axs.flatten()

    for ax, species in zip(axs, species_list):
        plot_trend_with_inset(
            species,
            df_species52,
            nfi_raw,
            ax=ax,
            font_scaler_graph=0.8,
            font_scaler_inset=0.5,
        )
        # Per-panel axis labels are replaced by the figure-wide labels below
        ax.set_xlabel("")
        ax.set_ylabel("")

    # Hide panels that did not receive a species
    for ax in axs[len(species_list):]:
        ax.axis("off")

    # Add universal y-axis label
    fig.text(
        0.085,
        0.5,
        "Mortality Rate (%-stems yr$^{-1}$)",
        va="center",
        rotation="vertical",
        fontsize=20,
    )
    # Add universal x-axis label
    fig.text(
        0.5,
        0.07,
        "Year",
        ha="center",
        fontsize=20,
    )

    fig.savefig(outfile)
    plt.show()


# ! Species order by occurrence (use sorted(species_in_final_anlysis) for alphabetical)
myorder = species_in_final_anlysis.copy()

# First 30 species go into part 1, the remainder into part 2
part1 = myorder[:30]
part2 = myorder[30:]

# ! First part: 6x5 grid
_plot_species_trend_grid(
    part1, nrows=6, ncols=5, outfile="./mortality_trends-all_species-1.png"
)

# ! Second part: 5x5 grid
_plot_species_trend_grid(
    part2, nrows=5, ncols=5, outfile="./mortality_trends-all_species-2.png"
)

## Figure S2: Regional Trends


In [None]:
# Build a lookup from region code ("gre") to a short display name:
# hyphens become spaces and only the first two words of the name are kept.
gre_dictionary = get_shp_of_region("gre")
gre_dictionary["gre_name"] = (
    gre_dictionary["gre_name"]
    .str.replace("-", " ")
    .apply(lambda name: " ".join(name.split()[:2]))
)
gre_dictionary = dict(zip(gre_dictionary["gre"], gre_dictionary["gre_name"]))
gre_dictionary

In [None]:
# User input: regional trend tables for all species and for the top-9 species
all_species_path = (
    "../../data/final/mortality_trends/"
    "df_map_trend_raw-region_gre-species_all-calculation_directly.feather"
)
top9_species_path = (
    "../../data/final/mortality_trends/"
    "df_map_trend_raw-region_gre-species_top9-calculation_directly.feather"
)

# Load both tables (all species first, then top 9)
df_all, df_top9 = pd.read_feather(all_species_path), pd.read_feather(top9_species_path)


def preprocess(df, metric=None):
    """Reduce a trend table to the columns group, year, group_year, mean and std.

    Parameters
    ----------
    df : pd.DataFrame
        Table with a ``group_year`` column and ``<metric>_mean`` /
        ``<metric>_std`` columns. The group is taken as the text before the
        first "_" in ``group_year`` and the year as the text after it.
        ``df`` itself is not modified.
    metric : str, optional
        Metric name used to select the mean/std columns. Defaults to the global
        ``kwargs["my_metric"]``, which keeps existing calls working.

    Returns
    -------
    pd.DataFrame
        New frame with columns group, year (int), group_year, mean and std,
        sharing the index of ``df``.
    """
    if metric is None:
        metric = kwargs["my_metric"]

    # Split once and reuse both pieces
    parts = df["group_year"].str.split("_", expand=True)
    return pd.DataFrame(
        {
            "group": parts[0],
            "year": parts[1].astype(int),
            "group_year": df["group_year"],
            "mean": df[f"{metric}_mean"],
            "std": df[f"{metric}_std"],
        }
    )


# Reduce both regional tables to the plotting columns
df_all, df_top9 = preprocess(df_all), preprocess(df_top9)

In [None]:
def plot_regional_mortality_trends_grid(ax, df_all, df_top9, gre_dictionary=None):
    """
    Plot regional mortality trends from all species and top 9 species onto provided axes.

    For every region, both datasets are drawn as a mean line with a
    mean +/- std band. Axes beyond the number of regions are switched off.

    Parameters
    ----------
    ax : np.ndarray of Axes
        Flat array of axes to plot into (e.g. from GridSpec or subplots).
        Needs at least one axis per region.
    df_all : pd.DataFrame
        DataFrame of all-species mortality trends (preprocessed; uses the
        columns ``group``, ``year``, ``mean`` and ``std``).
    df_top9 : pd.DataFrame
        DataFrame of top-9-species mortality trends (same columns).
    gre_dictionary : dict, optional
        Dictionary mapping region codes to human-readable names. If None, the
        panel titles show only the region code.

    Returns
    -------
    list
        Sorted region codes found in either DataFrame, in plotting order.

    Raises
    ------
    ValueError
        If there are more regions than axes.
    """
    regions = sorted(set(df_all["group"]) | set(df_top9["group"]))
    if len(regions) > len(ax):
        raise ValueError(f"Got {len(ax)} axes but {len(regions)} regions to plot.")

    # (label, data, color) of each trend line, drawn in this order
    series = [
        ("All species", df_all, "black"),
        ("Top 9 species", df_top9, "C1"),
    ]

    for i, region in enumerate(regions):
        for label, df, color in series:
            df_region = df.query("group == @region").sort_values("year")
            ax[i].plot(
                df_region["year"],
                df_region["mean"],
                label=label,
                color=color,
                linestyle="-",
            )
            ax[i].fill_between(
                df_region["year"],
                df_region["mean"] - df_region["std"],
                df_region["mean"] + df_region["std"],
                alpha=0.2,
                color=color,
            )

        # Title: region code, plus its readable name when a dictionary is given
        if gre_dictionary is not None:
            title = f"{region}. {gre_dictionary.get(region, region)}"
        else:
            title = f"{region}"
        ax[i].set_title(title, fontsize=16, loc="left")

        ax[i].set_xlabel("Year")
        ax[i].set_ylabel("Mortality Rate (%-stems yr$^{-1}$)")
        ax[i].yaxis.set_major_formatter(FormatStrFormatter("%.2f"))
        ax[i].spines["right"].set_visible(False)
        ax[i].spines["top"].set_visible(False)

    # Turn off unused subplots. Counting from len(regions) instead of the loop
    # variable keeps this valid when there are no regions at all.
    for j in range(len(regions), len(ax)):
        ax[j].axis("off")

    # Remove any per-panel legends
    for a in ax[: len(regions)]:
        if a.get_legend() is not None:
            a.get_legend().remove()

    # Return region labels if needed later
    return regions

In [None]:
from matplotlib import gridspec

# Figure layout: map on the left, 3x4 grid of regional trend panels on the right.
# Both halves currently get the same width (width_ratios 3.0 : 3.0).
fig = plt.figure(figsize=(22, 9))
gs = gridspec.GridSpec(nrows=1, ncols=2, width_ratios=[3.0, 3.0], figure=fig)

# ! Map
ax_map = fig.add_subplot(gs[0, 0])

# User input: map trend table (all species, calculated directly)
filename_df_map = "../../data/final/mortality_trends/df_map_trend-region_gre-species_all-calculation_directly.feather"
print(f"Loading: {filename_df_map}")

# Load the trend table and join it onto the region geometries
df_map = (
    pd.read_feather(filename_df_map)
    .drop(columns=["gre_num", "gre_name", "geometry"])
    .rename(columns={"gre": "region"})
)
shp_region = get_shp_of_region("gre", make_per_year=None, make_per_group=None).rename(
    columns={"gre": "region"}
)
df_map = shp_region.merge(df_map, on="region")

# Plot the map
make_map(df_map, ax=ax_map, ts_legend=16, ts_ticks=14, add_letters=True)

# ! Trend grid (3 rows x 4 columns inside the right half)
gs_right = gridspec.GridSpecFromSubplotSpec(
    3,
    4,
    subplot_spec=gs[0, 1],
    hspace=0.35,  # tighter vertical space
    wspace=0.3,  # tighter horizontal space
)

axes_trends = np.array(
    [fig.add_subplot(gs_right[i, j]) for i in range(3) for j in range(4)]
)

# Plot trends. The +5 year shift is applied to copies, so re-running this cell
# does not shift the global df_all / df_top9 a second time.
# NOTE(review): the meaning of the +5 offset is not stated here; confirm.
df_all_plot = df_all.assign(year=df_all["year"].astype(int) + 5)
df_top9_plot = df_top9.assign(year=df_top9["year"].astype(int) + 5)
plot_regional_mortality_trends_grid(
    axes_trends, df_all_plot, df_top9_plot, gre_dictionary
)

# Remove per-panel axis labels (figure-wide labels are added below)
for ax in axes_trends:
    ax.set_xlabel("")
    ax.set_ylabel("")

# Add universal y-axis label
fig.text(
    0.45,
    0.5,
    "Mortality Rate (%-stems yr$^{-1}$)",
    va="center",
    rotation="vertical",
    fontsize=16,
)

# Add universal x-axis label
fig.text(
    0.725,
    0.02,
    "Year",
    ha="center",
    fontsize=16,
)

# ! Global legend (handles taken from the first trend panel)
handles, labels = axes_trends[0].get_legend_handles_labels()
fig.legend(
    handles,
    labels,
    title="Regional \nMortality Trends\n",
    title_fontsize=16,
    ncol=1,
    fontsize=14,
    bbox_to_anchor=(0.98, 0.25),
    frameon=False,
)

# ! Layout: tight_layout() is applied. Manual alternative:
# fig.subplots_adjust(left=0.03, right=0.9, top=0.98, bottom=0.05)
plt.tight_layout()
fig.savefig("./regional_mortality_trends.png", dpi=300, bbox_inches="tight")

## Figure S3: Harvest Bias


In [None]:
# Settings for the harvest comparison (Figure S3).
# Some entries (suffix, df, my_grouping, load_from_file) are overridden in the
# calculation loop below when data has to be recomputed.
kwargs = dict(
    # --- General
    file_suffix=None,
    # --- Metric (alternatives: mort_nat_vol_yr, mort_nat_vol_prc_yr)
    my_metric="mort_nat_stems_prc_yr",
    # --- Data wrangling
    df=nfi_raw.copy(),
    my_grouping=["gre"],  # e.g. ["gre", "species_lat"]
    my_method="direct_bs",
    load_from_file=True,
    top_n_groups=None,
    n_bootstraps_samples=100,
    # --- Data filters
    min_trees_per_site=0,
    min_sites_per_group_year=0,
    reduce_to_dominant_sites=False,
    weigh_by_sites_or_trees="none",  # none, sites, trees
    # --- Plotting
    plot_type="facet",  # all or facet
    save_plot=True,
    center_to_first_year=False,
    normalize_to_first_year=False,
    ylim=None,  # None or [min, max], e.g. [-200, 400]
    facet_label_trees_or_sites="trees",  # trees or sites
    uncertainty_representation="band",  # band, bar, none
    uncertainty_variable="std",
    aggregation_variable="mean",
    top_n_metric="most_observations",  # most_observations or highest_final_mortality
    top_n_groups_plot=15,
    # genus_filter (None or list of genus_lat values) is not implemented yet;
    # it would need a species-genus dictionary.
)


def custom_sort_key(element):
    """Return a (rank, element) key that orders region-level columns first.

    "gre", "ser", "hex", "reg" and "dep" get ranks 0-4 in that order. Anything
    else gets rank 5. The element itself breaks ties.
    """
    region_levels = ["gre", "ser", "hex", "reg", "dep"]
    try:
        rank = region_levels.index(element)
    except ValueError:
        # Not a known region level: sort after all of them
        rank = len(region_levels)
    return (rank, element)

In [None]:
# Input settings for the harvest comparison (Figure S3)
# Region level to plot: gre, ser, hex, reg, dep
iregion = "gre"
# Species set the trend is based on: top9 or 52
what_species = "52"
# "by_species": trend weighted by species; "directly": trend calculated directly
by_species_or_direct = "by_species"
# Output folder for the cached harvest-comparison tables
fig1_folder = "../../data/final/mortality_trends/harvest_comparison/"
os.makedirs(fig1_folder, exist_ok=True)

In [None]:
# Compute per-species temporal trends for three ways of treating harvested trees
# ("alive_cut"), caching each result as a feather file.
for harvest_method in ["as_survivor", "as_mortality", "excluded"]:

    file_temp = f"{fig1_folder}/df_{what_species}_temporal_data-harvest_{harvest_method}.feather"

    print(f" --- Processing file: {file_temp} ---")

    if os.path.exists(file_temp):
        df_top9 = pd.read_feather(file_temp)
        print(f"Loaded data from file: {file_temp}")
        chime.info()
        continue

    kwargs["suffix"] = (
        f"temporal_trend-region_{iregion}-species_{what_species}-calculation_{by_species_or_direct}"
    )

    # Filter for the required species set
    if what_species == "top9":
        df_temp = nfi_raw.copy().query("species_lat in @top9_species['species']")
    elif what_species == "52":
        df_temp = nfi_raw.copy().query("species_lat in @species_in_final_anlysis")
    else:
        # Otherwise df_temp would be undefined, or stale from an earlier run
        raise ValueError(f"what_species must be 'top9' or '52', got {what_species!r}")

    # Select the tree state transitions depending on how harvest is treated
    if harvest_method == "as_survivor":
        # Counting harvest as survivor (default assumption)
        keep_these = ["alive_alive", "alive_dead", "new_alive", "alive_cut"]
    elif harvest_method == "excluded":
        # Exclude harvest from the analysis by removing "alive_cut"
        keep_these = ["alive_alive", "alive_dead", "new_alive"]
    elif harvest_method == "as_mortality":
        # Counting harvest as mortality: relabel "alive_cut" as "alive_dead"
        keep_these = ["alive_alive", "alive_dead", "new_alive"]
        df_temp["tree_state_change"] = df_temp["tree_state_change"].replace(
            {
                "alive_cut": "alive_dead",
            }
        )
    else:
        raise ValueError(f"Unknown harvest_method: {harvest_method!r}")

    df_temp = df_temp.query("tree_state_change in @keep_these").copy()

    kwargs["df"] = df_temp.copy()
    kwargs["my_grouping"] = ["species_lat"]
    kwargs["load_from_file"] = False

    print(kwargs["suffix"])

    # Run the trend pipeline, returning the grouped data before any plotting
    kwargs["my_grouping"] = sorted(kwargs["my_grouping"], key=custom_sort_key)
    df_counts, df_gm, df_grouped = start_to_finish_temporal_plot(
        kwargs,
        return_before_plotting=True,
    )
    chime.success()

    df_top9_raw = df_grouped.copy()

    # ! Wrangling
    df_top9 = df_top9_raw.copy()

    # Group is the part of "group_year" before the first "_", year the part after it
    group_year_parts = df_top9["group_year"].str.split("_", expand=True)
    df_top9["group"] = group_year_parts[0]
    df_top9["year"] = group_year_parts[1].astype(int)

    # Extract mean and std of the selected metric
    df_top9["mean"] = df_top9[f"{kwargs['my_metric']}_mean"]
    df_top9["std"] = df_top9[f"{kwargs['my_metric']}_std"]

    # Reduce df size
    df_top9 = df_top9[["group", "year", "group_year", "mean", "std"]].copy()
    df_top9["region"] = df_top9["group"].str.split("&", expand=True)[0]

    # Record which metric the table holds
    df_top9["target"] = kwargs["my_metric"]

    # Save it
    df_top9.to_feather(file_temp)

In [None]:
# Get numbers on distribution among survivors, mortality and harvest
x = (
    nfi_raw.query("tree_state_change in ['alive_alive', 'alive_dead', 'alive_cut']")
    # .query("species_lat in @top9_species['species']")
    .query("species_lat in @species_in_final_anlysis")
    .groupby(["species_lat", "tree_state_change"], observed=False)
    .size()
    .reset_index(name="count")
    .query("count > 0")
    .reset_index(drop=True)
    .copy()
)

# Calculate percentages per species
x["total"] = x.groupby("species_lat", observed=False)["count"].transform("sum")
x["percentage"] = x["count"] / x["total"] * 100
x["percentage"] = x["percentage"].round(0).astype(int).astype(str)

# Compile dictionary for plotting
harvest_dictionary = {}
for species in x["species_lat"].unique():
    xtmp = x.query("species_lat == @species").copy()
    aa = xtmp.query("tree_state_change == 'alive_alive'")["percentage"].values[0]
    ac = xtmp.query("tree_state_change == 'alive_cut'")["percentage"].values[0]
    ad = xtmp.query("tree_state_change == 'alive_dead'")["percentage"].values[0]

    # string = f"{species} ({aa}, {ac}, {ad})"
    string = f"{species}\n({aa}% surv., {ac}% harv., {ad}% mort.)"

    harvest_dictionary[species] = string

harvest_dictionary

In [None]:
# Load the cached trend tables for the three harvest treatments
_harvest_tables = {
    method: pd.read_feather(
        f"{fig1_folder}/df_{what_species}_temporal_data-harvest_{method}.feather"
    )
    for method in ("as_survivor", "as_mortality", "excluded")
}
df_harvest_as_survivor = _harvest_tables["as_survivor"]
df_harvest_as_mortality = _harvest_tables["as_mortality"]
df_harvest_excluded = _harvest_tables["excluded"]

# Replace species names with labels that include survival/harvest/mortality shares
for _table in _harvest_tables.values():
    _table["group"] = _table["group"].map(harvest_dictionary)

In [None]:
from matplotlib.ticker import FormatStrFormatter
from matplotlib.lines import Line2D


def plot_harvest_comparison(ax, df1, df2, df3, keep_legend_in_first_subplot=False):
    """Plot mortality trends under three harvest treatments, one panel per group.

    Parameters
    ----------
    ax : sequence of matplotlib Axes
        Flat array of axes. Needs at least one axis per group in ``df1``.
    df1, df2, df3 : pd.DataFrame
        Trends with harvest counted as survivor, counted as mortality, and
        excluded, respectively. Uses the columns ``group``, ``year``, ``mean``
        and ``std``. The frames are not modified.
    keep_legend_in_first_subplot : bool
        If True, draw a "Harvest Method" legend in the first panel.

    Returns
    -------
    list
        Groups in plotting order (order of first appearance in ``df1``).

    Raises
    ------
    ValueError
        If there are more groups than axes.
    """
    # Shift years by +5 for display. Work on copies, so repeated calls with the
    # same frames do not accumulate the shift.
    # NOTE(review): the meaning of the +5 offset is not stated here; confirm.
    df1, df2, df3 = (
        df.assign(year=df["year"].astype(int) + 5) for df in (df1, df2, df3)
    )

    # Define label and color per dataset
    datasets = [
        ("As survivor", df1, "black"),
        ("As mortality", df2, "C1"),
        ("Excluded", df3, "C2"),
    ]

    # Panel order follows the order of appearance in df1
    regions = df1["group"].unique().tolist()
    if len(regions) > len(ax):
        raise ValueError(f"Got {len(ax)} axes but {len(regions)} groups to plot.")

    for i, region in enumerate(regions):
        for label, df, color in datasets:
            df_region = df.query("group == @region").sort_values("year")

            ax[i].plot(
                df_region["year"],
                df_region["mean"],
                label=label,
                color=color,
                linestyle="-",
            )
            ax[i].fill_between(
                df_region["year"],
                df_region["mean"] - df_region["std"],
                df_region["mean"] + df_region["std"],
                alpha=0.2,
                color=color,
            )

        # Title setup
        ax[i].set_title(f"{region}", fontsize=16, loc="left")
        ax[i].yaxis.set_major_formatter(FormatStrFormatter("%.2f"))
        ax[i].spines["right"].set_visible(False)
        ax[i].spines["top"].set_visible(False)

    # Turn off unused subplots. Counting from len(regions) keeps this valid
    # when there are no groups at all.
    for j in range(len(regions), len(ax)):
        ax[j].axis("off")

    if keep_legend_in_first_subplot:
        # Custom legend handles with the same colors as the plotted lines
        custom_handles = [
            Line2D([0], [0], color=color, lw=2, label=label)
            for label, _, color in datasets
        ]
        ax[0].legend(
            handles=custom_handles,
            loc="upper left",
            fontsize=10,
            title="Harvest Method",
            title_fontsize=12,
        )

    return regions

In [None]:
from matplotlib.ticker import FormatStrFormatter

# Harvest comparison for the top 9 species (uses plot_harvest_comparison() from above)
fig, axes_trends = plt.subplots(2, 5, figsize=(16, 6))
axes_trends = axes_trends.flatten()

plot_harvest_comparison(
    axes_trends,
    df_harvest_as_survivor.query("region in @top9_species['species']").copy(),
    df_harvest_as_mortality.query("region in @top9_species['species']").copy(),
    df_harvest_excluded.query("region in @top9_species['species']").copy(),
    keep_legend_in_first_subplot=False,
)

# Remove per-panel axis labels (figure-wide labels are added below)
for ax in axes_trends:
    ax.set_xlabel("")
    ax.set_ylabel("")

# Add universal y-axis label
fig.text(
    -0.01,
    0.5,
    "Mortality Rate (%-stems yr$^{-1}$)",
    va="center",
    rotation="vertical",
    fontsize=16,
)
# Add universal x-axis label
fig.text(
    0.5,
    -0.02,
    "Year",
    ha="center",
    fontsize=16,
)

# Legend handles use the same colors as the lines plot_harvest_comparison()
# draws ("black", "C1", "C2"). The named colors "orange"/"green" are different
# from the default-cycle colors "C1"/"C2".
custom_lines = [
    Line2D([0], [0], color="black", label="As survivor"),
    Line2D([0], [0], color="C1", label="As mortality"),
    Line2D([0], [0], color="C2", label="Excluded"),
]

# Draw the legend on the last panel, which is the current axes after
# plt.subplots and is normally unused when 9 species are plotted.
axes_trends[-1].legend(
    handles=custom_lines,
    loc="upper left",
    title="Harvest Calculation:\n",
    title_fontsize=16,
    ncol=1,
    fontsize=14,
    bbox_to_anchor=(0.125, 0.75),
    frameon=False,
)

plt.tight_layout()

In [None]:
from matplotlib.ticker import FormatStrFormatter

# Species order for the harvest panels: by occurrence.
# (An alphabetical order would be sorted(species_in_final_anlysis).)
myorder = list(species_in_final_anlysis)

def order_df_by(df, myorder, variable):
    """Sort ``df`` by ``variable`` following the custom order ``myorder``.

    ``df[variable]`` is converted in place to an ordered categorical with
    ``myorder`` as its categories. The sorted frame is returned with a fresh
    0..n-1 index.
    """
    ranked = pd.Categorical(df[variable], categories=myorder, ordered=True)
    df[variable] = ranked
    ordered = df.sort_values(variable)
    return ordered.reset_index(drop=True)


# Sort the three harvest variants by the species order in `myorder`
df_harvest_as_survivor = order_df_by(df_harvest_as_survivor, myorder, "region")
df_harvest_as_mortality = order_df_by(df_harvest_as_mortality, myorder, "region")
df_harvest_excluded = order_df_by(df_harvest_excluded, myorder, "region")

# First 30 species go into figure part 1, the remainder into part 2
part1, part2 = myorder[:30], myorder[30:]

# ! Figure part 1 (uses plot_harvest_comparison() from above)
fig, axes_trends = plt.subplots(6, 5, figsize=(20, 20))
axes_trends = axes_trends.flatten()

plot_harvest_comparison(
    axes_trends,
    df_harvest_as_survivor.query("region in @part1").copy(),
    df_harvest_as_mortality.query("region in @part1").copy(),
    df_harvest_excluded.query("region in @part1").copy(),
    keep_legend_in_first_subplot=False,
)

# Drop per-panel axis labels (figure-wide labels follow) and enlarge tick labels
for ax in axes_trends:
    ax.set(xlabel="", ylabel="")
    ax.tick_params(axis="both", which="major", labelsize=14)

# Figure-wide axis labels
fig.text(
    -0.01, 0.5, "Mortality Rate (%-stems yr$^{-1}$)",
    va="center", rotation="vertical", fontsize=20,
)
fig.text(0.5, -0.02, "Year", ha="center", fontsize=20)

# Shared legend below the figure. Handles come from the first panel's lines,
# labels are replaced with more descriptive ones.
handles, labels = axes_trends[0].get_legend_handles_labels()
labels = ["Harvest as survival", "Harvest as mortality", "Excluding Harvest"]
fig.legend(
    handles, labels,
    loc="center", bbox_to_anchor=(0.5, -0.035),
    fontsize=16, title="", title_fontsize=16,
    ncol=3, frameon=False,
)

plt.tight_layout()
plt.savefig("./trends_harvest_comparison-1.png", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
# ! Figure part 2: remaining species on a 6x5 grid
fig, axes_trends = plt.subplots(6, 5, figsize=(20, 20))
axes_trends = axes_trends.flatten()

plot_harvest_comparison(
    axes_trends,
    df_harvest_as_survivor.query("region in @part2").copy(),
    df_harvest_as_mortality.query("region in @part2").copy(),
    df_harvest_excluded.query("region in @part2").copy(),
    keep_legend_in_first_subplot=True,
)

# Drop per-panel axis labels (figure-wide labels follow) and enlarge tick labels
for ax in axes_trends:
    ax.set(xlabel="", ylabel="")
    ax.tick_params(axis="both", which="major", labelsize=14)

# Figure-wide axis labels (x label raised because the lower rows are empty)
fig.text(
    -0.01, 0.5, "Mortality Rate (%-stems yr$^{-1}$)",
    va="center", rotation="vertical", fontsize=20,
)
fig.text(0.5, 0.135, "Year", ha="center", fontsize=20)

# Take line handles from the first panel, then drop its in-panel legend in
# favour of a shared figure legend.
handles, labels = axes_trends[0].get_legend_handles_labels()
labels = ["Harvest as survival", "Harvest as mortality", "Excluding Harvest"]
axes_trends[0].get_legend().remove()

fig.legend(
    handles, labels,
    loc="center", bbox_to_anchor=(0.5, 0.12),
    fontsize=16, title="", title_fontsize=16,
    ncol=3, frameon=False,
)

plt.tight_layout()
plt.savefig("./trends_harvest_comparison-2.png", dpi=300, bbox_inches="tight")
plt.show()

## Figure S4: Basal Area Trends


In [None]:
# Settings for the basal-area trends (Figure S4).
# my_metric, suffix, df, my_grouping and load_from_file are overridden in the
# calculation cell below when data has to be recomputed.
kwargs = dict(
    # --- General
    file_suffix=None,
    # --- Metric (alternatives: mort_nat_vol_yr, mort_nat_vol_prc_yr)
    my_metric="mort_nat_stems_prc_yr",
    # --- Data wrangling
    df=nfi_raw.copy(),
    my_grouping=["gre"],  # e.g. ["gre", "species_lat"]
    my_method="direct_bs",
    load_from_file=True,
    top_n_groups=None,
    n_bootstraps_samples=100,
    # --- Data filters
    min_trees_per_site=0,
    min_sites_per_group_year=0,
    reduce_to_dominant_sites=False,
    weigh_by_sites_or_trees="none",  # none, sites, trees
    # --- Plotting (plots are not saved for this figure)
    plot_type="facet",  # all or facet
    save_plot=False,
    center_to_first_year=False,
    normalize_to_first_year=False,
    ylim=None,  # None or [min, max], e.g. [-200, 400]
    facet_label_trees_or_sites="trees",  # trees or sites
    uncertainty_representation="band",  # band, bar, none
    uncertainty_variable="std",
    aggregation_variable="mean",
    top_n_metric="most_observations",  # most_observations or highest_final_mortality
    top_n_groups_plot=15,
    # genus_filter (None or list of genus_lat values) is not implemented yet;
    # it would need a species-genus dictionary.
)


def custom_sort_key(element):
    """Key function ranking region-level columns ahead of all other elements.

    Returns ``(rank, element)``. The rank is the position in
    ("gre", "ser", "hex", "reg", "dep"), or 5 for any other element.
    """
    levels = ("gre", "ser", "hex", "reg", "dep")
    if element not in levels:
        return (len(levels), element)
    return (levels.index(element), element)

In [None]:
# Input settings for the basal-area trends (Figure S4)
iregion = "gre"  # region level to plot: gre, ser, hex, reg, dep
what_species = "52"  # top9 or 52: trend based on all (52) or the top9 species
by_species_or_direct = "by_species"  # by_species (weighted by species) or directly

# Setup: output folder for the cached basal-area trend tables
fig1_folder = "../../data/final/ba_trends"
os.makedirs(fig1_folder, exist_ok=True)

In [None]:
# Calculate (or load cached) per-species basal-area trends for absolute and
# relative targets; each target type is cached in its own feather file.
load_from_file = True  # Set to False to recalculate data

# Trend variables per target type
ba_targets_by_type = {
    "absolute": [
        "mort_nat_ba_yr",
        "mort_cut_ba_yr",
        "grwt_tot_ba_yr",
        "change_tot_ba_yr",
    ],
    "relative": [
        "mort_nat_ba_prc_yr",
        "mort_cut_ba_prc_yr",
        "grwt_tot_ba_prc_yr",
        "change_tot_ba_prc_yr",
    ],
}

for target_type in ["absolute", "relative"]:
    if target_type not in ba_targets_by_type:
        raise ValueError("target_type must be 'absolute' or 'relative'")
    targets = ba_targets_by_type[target_type]

    # Set file name for the cached table
    file_temp = f"{fig1_folder}/df_{what_species}_temporal_data_{target_type}.feather"

    if load_from_file and os.path.exists(file_temp):
        df_top9 = pd.read_feather(file_temp)
        print(f"Loaded data from file: {file_temp}")
        chime.info()
    else:
        # Filter the species once; every target below gets its own fresh copy
        if what_species == "top9":
            df_species = nfi_raw.query("species_lat in @top9_species['species']")
        elif what_species == "52":
            df_species = nfi_raw.query("species_lat2 in @species_in_final_anlysis")
        else:
            # Otherwise kwargs["df"] would silently keep the data set in an earlier cell
            raise ValueError(
                f"what_species must be 'top9' or '52', got {what_species!r}"
            )

        # Collect all targets into one list
        df_all_targets = []
        for target in targets:
            print(" --- Processing target:", target)
            kwargs["my_metric"] = target
            kwargs["suffix"] = (
                f"temporal_trend-region_{iregion}-species_{what_species}-calculation_{by_species_or_direct}"
            )
            kwargs["df"] = df_species.copy()
            kwargs["my_grouping"] = ["species_lat"]
            kwargs["load_from_file"] = False

            print(kwargs["suffix"])

            # Run the trend pipeline, returning the grouped data before any plotting
            kwargs["my_grouping"] = sorted(kwargs["my_grouping"], key=custom_sort_key)
            df_counts, df_gm, df_grouped = start_to_finish_temporal_plot(
                kwargs,
                return_before_plotting=True,
            )
            chime.success()
            df_top9_raw = df_grouped.copy()

            # ! Wrangling
            df_top9 = df_top9_raw.copy()

            # Group is the part of "group_year" before the first "_", year the part after it
            group_year_parts = df_top9["group_year"].str.split("_", expand=True)
            df_top9["group"] = group_year_parts[0]
            df_top9["year"] = group_year_parts[1].astype(int)

            # Extract mean and std of the current target
            df_top9["mean"] = df_top9[f"{kwargs['my_metric']}_mean"]
            df_top9["std"] = df_top9[f"{kwargs['my_metric']}_std"]

            # Reduce df size
            df_top9 = df_top9[["group", "year", "group_year", "mean", "std"]].copy()
            df_top9["region"] = df_top9["group"].str.split("&", expand=True)[0]

            # Record which target the rows belong to
            df_top9["target"] = kwargs["my_metric"]

            df_all_targets.append(df_top9)

        # Concatenate all targets into one DataFrame
        df_top9 = pd.concat(df_all_targets, ignore_index=True)

        # Save it
        df_top9.to_feather(file_temp)

    # Preview. A bare `df_top9.head()` inside a loop is not rendered by the notebook.
    display(df_top9.head())

In [None]:
# Define plotting function
def plot_trend_ba(
    species,
    target_type="absolute",
    df=None,
    ax=None,
    font_scaler=1,
):
    """
    Plot basal-area trends of one species: one line per target with a +/- std band.

    Parameters:
        species (str): Value of ``df["group"]`` to plot.
        target_type (str): 'absolute' (targets ending in ``_ba_yr``) or
            'relative' (targets ending in ``_ba_prc_yr``). Only the four
            targets of this type (natural mortality, harvest, growth, net
            change) are drawn; rows with other targets are ignored.
        df (pd.DataFrame): Long-format data with the columns ``group``,
            ``year``, ``target``, ``mean`` and ``std``. Required despite the
            ``None`` default.
        ax (matplotlib.axes.Axes, optional): The axes to plot on. If None, a
            new 6x6 inch figure is created.
        font_scaler (float): Scaling factor for font sizes in the plot.

    Returns:
        matplotlib.axes.Axes: The axes the trends were drawn on.

    Raises:
        TypeError: If ``df``, ``species``, ``font_scaler`` or ``ax`` has an
            invalid type.
        ValueError: If ``target_type`` is unknown or ``species`` is not a
            group in ``df``.
    """
    # --- Checks and validations ---
    # Type checks run first: with the default df=None the membership check
    # below would otherwise fail with an unrelated error before reaching them.
    if not isinstance(df, pd.DataFrame):
        raise TypeError("df must be a pandas DataFrame.")
    if not isinstance(species, str):
        raise TypeError("species must be a string.")
    if not isinstance(font_scaler, (int, float)):
        raise TypeError("font_scaler must be a number (int or float).")
    if ax is not None and not hasattr(ax, "plot"):
        raise TypeError("ax must be a matplotlib Axes object or None.")

    # Line colour per target; the keys are also the targets drawn for this type.
    if target_type == "absolute":
        colors = {
            "mort_nat_ba_yr": "orange",
            "mort_cut_ba_yr": "brown",
            "grwt_tot_ba_yr": "green",
            "change_tot_ba_yr": "black",
        }
    elif target_type == "relative":
        colors = {
            "mort_nat_ba_prc_yr": "orange",
            "mort_cut_ba_prc_yr": "brown",
            "grwt_tot_ba_prc_yr": "green",
            "change_tot_ba_prc_yr": "black",
        }
    else:
        raise ValueError("target_type must be 'absolute' or 'relative'")

    # Check if the species is in the DataFrame
    if species not in df["group"].unique():
        raise ValueError(f"Species '{species}' not found in the DataFrame.")

    # ------------------------------------------------------------------

    # Rows of this species for the requested target type only, so a frame that
    # mixes absolute and relative targets cannot hit a missing colour.
    # .copy() because the year column is modified below.
    df_species = (
        df[(df["group"] == species) & df["target"].isin(list(colors))]
        .sort_values("year")
        .copy()
    )

    # Values are plotted as stored: mortality and harvest are not sign-inverted,
    # so no symmetric y-limits around 0 are applied.

    if ax is None:
        _, ax = plt.subplots(figsize=(6, 6))

    # Shift every point by 5 years (original note: "Fix years").
    # NOTE(review): presumably maps the census-interval start year to its
    # midpoint/end - confirm against how `year` is built upstream.
    df_species["year"] = df_species["year"] + 5

    # Add dotted 0 line
    ax.axhline(0, color="black", linestyle="dotted", linewidth=0.5)

    line_alpha = 1
    band_alpha = 0.2

    # One line plus shaded band per target variable
    for target in df_species["target"].unique():
        df_target = df_species[df_species["target"] == target]
        ax.plot(
            df_target["year"],
            df_target["mean"],
            label=target,
            alpha=line_alpha,
            color=colors[target],
        )

        # Fill between mean +/- std
        ax.fill_between(
            df_target["year"],
            df_target["mean"] - df_target["std"],
            df_target["mean"] + df_target["std"],
            alpha=band_alpha,
            edgecolor=None,
            color=colors[target],
        )

    ax.set_xlabel("Year", fontsize=12 * font_scaler)

    # Replace raw target names with readable legend labels (first match wins)
    readable_labels = (
        ("nat", "Mortality"),
        ("cut", "Harvest"),
        ("grwt", "Growth"),
        ("change", "Net Change"),
    )
    for line in ax.get_lines():
        label = line.get_label()
        for key, readable in readable_labels:
            if key in label:
                line.set_label(readable)
                break

    ax.legend(
        loc="upper left",
        fontsize=10 * font_scaler,
        title="",
        title_fontsize=12 * font_scaler,
    )

    # Reduce y-ticks
    ax.yaxis.set_major_locator(plt.MaxNLocator(3))

    # Reduce x-ticks; 9 distinct years in the full frame includes the 2023 census
    nxticks = 5 if df["year"].nunique() == 9 else 4

    # Format ticks
    ax.xaxis.set_major_locator(plt.MaxNLocator(nxticks))
    ax.yaxis.set_major_formatter(FormatStrFormatter("%.1f"))

    # Set font size for ticks
    ax.tick_params(axis="both", which="major", labelsize=12 * font_scaler)

    # Remove upper and right spines
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)

    # Set y-axis label
    ylabel = "Basal Area Rate"
    if target_type == "absolute":
        ylabel += " (m$^2$ yr$^{-1}$)"
    else:
        ylabel += " (%-m$^2$ yr$^{-1}$)"

    ax.set_ylabel(
        ylabel,
        fontsize=12 * font_scaler,
    )

    return ax

In [None]:
# Make Figures for top 9 species

# Set target type
target_type = "absolute"  # relative or absolute

# Targets shown in the figure (net change is deliberately left out)
if target_type == "absolute":
    targets = [
        "mort_nat_ba_yr",
        "mort_cut_ba_yr",
        "grwt_tot_ba_yr",
        # "change_tot_ba_yr",
    ]
elif target_type == "relative":
    targets = [
        "mort_nat_ba_prc_yr",
        "mort_cut_ba_prc_yr",
        "grwt_tot_ba_prc_yr",
        # "change_tot_ba_prc_yr",
    ]
else:
    raise ValueError("target_type must be 'absolute' or 'relative'")

# Load data and keep only the selected targets of the top 9 species
df_species52 = pd.read_feather(
    f"{fig1_folder}/df_52_temporal_data_{target_type}.feather"
)

df_species52 = df_species52.query("target in @targets").copy()
df_species52 = df_species52.query("group in @top9_species['species']").copy()

# ! Start Plot
# 2x5 grid: nine species panels plus one spare panel that hosts the legend
fig, axs = plt.subplots(2, 5, figsize=(16, 6))
axs = axs.flatten()

# Legend entries come from the first species that is actually plotted, so a
# missing first species no longer leaves `handles` undefined.
handles, labels = None, None
available_species = set(df_species52["group"].unique())

# Loop over each species and plot the trend
for i, species in enumerate(top9_species["species"]):
    if species not in available_species:
        print(f"Species '{species}' not found in the DataFrame. Skipping.")
        # Hide the panel of a skipped species instead of leaving an empty frame
        axs[i].axis("off")
        continue

    ax = plot_trend_ba(
        species,
        target_type=target_type,
        df=df_species52,
        ax=axs[i],
        font_scaler=0.8,
    )

    # ax.set_title(top9_species["title"].iloc[i], fontsize=14, loc="left") # with occurrence percentage
    ax.set_title(species, fontsize=14, loc="left")  # without occurrence percentage

    if handles is None:
        handles, labels = ax.get_legend_handles_labels()

    # Per-panel legends are replaced by one shared legend below
    ax.get_legend().remove()

    # Per-panel axis labels are replaced by figure-wide labels below
    ax.set_xlabel("")
    ax.set_ylabel("")

    # Integer tick labels on both axes
    ax.yaxis.set_major_formatter(FormatStrFormatter("%.0f"))
    ax.xaxis.set_major_formatter(FormatStrFormatter("%.0f"))


# Add universal y-axis label
ylabel = "Change in Basal Area"
if target_type == "absolute":
    ylabel += " (m$^2$ yr$^{-1}$)"
else:
    ylabel += " (%-m$^2$ yr$^{-1}$)"

fig.text(
    -0.01,
    0.5,
    ylabel,
    va="center",
    rotation="vertical",
    fontsize=16,
)

# Add universal x-axis label
fig.text(
    0.5,
    -0.01,
    "Year",
    ha="center",
    fontsize=16,
)


# Remove the panels after the last species
for i in range(len(top9_species), len(axs)):
    axs[i].axis("off")

# Shared legend in the spare last panel, in reversed order for a nicer layout
if handles is not None:
    axs[-1].legend(
        handles=handles[::-1],
        labels=labels[::-1],
        title="Basal Area Trends:\n",
        title_fontsize=16,
        ncol=1,
        fontsize=14,
        bbox_to_anchor=(0.9, 0.75),
        frameon=False,
    )
else:
    print("No top 9 species found in the data; figure has no legend.")


# Save the plot
plt.tight_layout()
plt.savefig(f"./trends_ba_{target_type}-top9_species.png", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
# Set target type
target_type = "absolute"  # relative or absolute

# Targets shown in the figure (net change is deliberately left out)
if target_type == "absolute":
    targets = [
        "mort_nat_ba_yr",
        "mort_cut_ba_yr",
        "grwt_tot_ba_yr",
        # "change_tot_ba_yr",
    ]
elif target_type == "relative":
    targets = [
        "mort_nat_ba_prc_yr",
        "mort_cut_ba_prc_yr",
        "grwt_tot_ba_prc_yr",
        # "change_tot_ba_prc_yr",
    ]
else:
    raise ValueError("target_type must be 'absolute' or 'relative'")

# Load data, THEN keep only the selected targets. (Filtering before re-reading
# the file would silently discard the filter.)
df_species52 = pd.read_feather(
    f"{fig1_folder}/df_52_temporal_data_{target_type}.feather"
)
df_species52 = df_species52.query("target in @targets").copy()

# Panel order: species list as returned (original note: by occurrence);
# set to True for alphabetical order.
order_alphabetically = False
if order_alphabetically:
    myorder = sorted(species_in_final_anlysis.copy())
else:
    myorder = species_in_final_anlysis.copy()

# Ensure 'group' is an ordered categorical following `myorder`
df_species52["group"] = pd.Categorical(
    df_species52["group"], categories=myorder, ordered=True
)
# Groups outside `myorder` became NaN; drop them so they are not plotted
df_species52 = df_species52.dropna(subset=["group"])

# Sort the DataFrame according to the custom group order
df_species52 = df_species52.sort_values("group").reset_index(drop=True)

# Two figure pages: first 30 species, then the rest
part1 = df_species52.group.unique()[:30]
part2 = df_species52.group.unique()[30:]

# Optional y-limits (only used by the commented-out set_ylim below)
if target_type == "absolute":
    ylims = [0, 600]  # m^2 yr^-1
elif target_type == "relative":
    ylims = [-50, 50]  # %-m^2 yr^-1
else:
    raise ValueError("target_type must be 'absolute' or 'relative'")

# ! First Part  -----------------------------------
# 6x5 grid: room for the (at most 30) species of part1
fig, axs = plt.subplots(6, 5, figsize=(20, 20))
axs = axs.flatten()

# Loop over each species and plot the trend
for i, species in enumerate(part1):
    plot_trend_ba(
        species,
        target_type=target_type,
        df=df_species52,
        ax=axs[i],
        font_scaler=1,
    )
    # Per-panel axis labels are replaced by figure-wide labels below
    axs[i].set_xlabel("")
    axs[i].set_ylabel("")

    # Species name as left-aligned panel title
    # (switch fontweight to "bold" to highlight e.g. the top 9 species)
    axs[i].set_title(
        " " + species.replace("_", " ").title(),
        fontsize=16,
        fontweight="normal",
        pad=10,
        loc="left",
        y=0.975,
    )

    # Integer y tick labels
    axs[i].yaxis.set_major_formatter(FormatStrFormatter("%.0f"))
    axs[i].tick_params(axis="both", which="major", labelsize=14)

    # Set y-axis limits
    # axs[i].set_ylim(ylims)

# Add universal y-axis label
ylabel = "Change in Basal Area"
if target_type == "absolute":
    ylabel += " (m$^2$ yr$^{-1}$)"
else:
    ylabel += " (%-m$^2$ yr$^{-1}$)"

fig.text(
    -0.01,
    0.5,
    ylabel,
    va="center",
    rotation="vertical",
    fontsize=20,
)

# Add universal x-axis label
fig.text(
    0.5,
    -0.01,
    "Year",
    ha="center",
    fontsize=20,
)

# Shared legend below the figure, built from the first panel
handles, labels = axs[0].get_legend_handles_labels()
fig.legend(
    handles,
    labels,
    loc="center",
    bbox_to_anchor=(0.5, -0.03),
    fontsize=16,
    title="",
    title_fontsize=16,
    ncol=4,
    frameon=False,
)

# Remove the per-panel legends (empty panels have none)
for ax in axs:
    panel_legend = ax.get_legend()
    if panel_legend is not None:
        panel_legend.remove()

# Remove empty subplots
for i in range(len(part1), len(axs)):
    axs[i].axis("off")

# Tighten layout
plt.tight_layout()

# Save the plot
# plt.savefig(f"{fig1_folder}/fig-52_species_temporal_trend-1.png")
plt.savefig(
    f"./trends_ba_{target_type}-all_species-1.png", dpi=300, bbox_inches="tight"
)
plt.show()

In [None]:
# ! Second Part  -----------------------------------
# One panel per species of `part2` on a 5-column grid. Keeps the original
# 5 rows; rows are added automatically if there are more than 25 species, so
# the loop below can never run out of panels.
n_cols = 5
n_rows = max(5, -(-len(part2) // n_cols))  # ceiling division
fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 20))
axs = axs.flatten()

# Loop over each species and plot the trend
for i, species in enumerate(part2):
    plot_trend_ba(
        species,
        target_type=target_type,
        df=df_species52,
        ax=axs[i],
        font_scaler=1,
    )
    # Per-panel axis labels are replaced by figure-wide labels below
    axs[i].set_xlabel("")
    axs[i].set_ylabel("")

    # Species name as left-aligned panel title
    # (switch fontweight to "bold" to highlight e.g. the top 9 species)
    axs[i].set_title(
        " " + species.replace("_", " ").title(),
        fontsize=16,
        fontweight="normal",
        pad=10,
        loc="left",
        y=0.975,
    )

    # Integer y tick labels
    axs[i].yaxis.set_major_formatter(FormatStrFormatter("%.0f"))
    axs[i].tick_params(axis="both", which="major", labelsize=14)
    # Set y-axis limits
    # axs[i].set_ylim(ylims)

    # Per-panel legends are replaced by one figure legend below
    axs[i].get_legend().remove()

# Add universal y-axis label
ylabel = "Change in Basal Area"
if target_type == "absolute":
    ylabel += " (m$^2$ yr$^{-1}$)"
else:
    ylabel += " (%-m$^2$ yr$^{-1}$)"

fig.text(
    -0.01,
    0.5,
    ylabel,
    va="center",
    rotation="vertical",
    fontsize=20,
)

# Add universal x-axis label
fig.text(
    0.5,
    -0.01,
    "Year",
    ha="center",
    fontsize=20,
)

# Shared legend below the figure; handles are read from the first panel's
# artists (still available after its legend was removed)
handles, labels = axs[0].get_legend_handles_labels()
fig.legend(
    handles,
    labels,
    loc="center",
    bbox_to_anchor=(0.5, -0.03),
    fontsize=16,
    title="",
    title_fontsize=16,
    ncol=4,
    frameon=False,
)

# Remove empty subplots
for i in range(len(part2), len(axs)):
    axs[i].axis("off")

# Tighten layout
plt.tight_layout()

# Save the plot
# plt.savefig(f"{fig1_folder}/fig-52_species_temporal_trend-1.png")
plt.savefig(
    f"./trends_ba_{target_type}-all_species-2.png", dpi=300, bbox_inches="tight"
)
plt.show()

## Figure S5: Trends per Size Class


In [None]:
# Work on a copy so nfi_raw stays untouched
nfi_new = nfi_raw.copy()

# DBH classes at the 7.5 / 22.5 / 37.5 cm thresholds (left-closed bins).
# The column keeps its historical name "tree_height_class".
# NOTE(review): dbh_1 * 100 assumes dbh_1 is stored in metres - confirm.
dbh_class_bins = [7.5, 22.5, 37.5, float("inf")]
dbh_class_labels = ["7.5-22.5 cm", "22.5-37.5 cm", "≥ 37.5 cm"]
dbh_classes = pd.cut(
    nfi_new["dbh_1"] * 100,
    bins=dbh_class_bins,
    labels=dbh_class_labels,
    right=False,
)

# Values below 7.5 cm or without a DBH fall outside every bin (NaN);
# the category must exist before NaNs can be filled with "Missing".
nfi_new["tree_height_class"] = dbh_classes.cat.add_categories(["Missing"]).fillna(
    "Missing"
)

In [None]:
# Settings for start_to_finish_temporal_plot (size-class analysis).
# "my_grouping" and "df" are overwritten per run further below.
kwargs = dict(
    # --- General ---
    file_suffix=None,
    # --- Metric ---
    # alternatives: "mort_nat_vol_yr", "mort_nat_vol_prc_yr"
    my_metric="mort_nat_stems_prc_yr",
    # --- Data wrangling ---
    df=nfi_new.copy(),
    # my_grouping=["tree_height_class", "species_lat"],
    my_grouping=["tree_height_class"],
    my_method="direct_bs",
    load_from_file=False,
    top_n_groups=None,
    n_bootstraps_samples=100,
    # --- Data filter ---
    min_trees_per_site=0,
    min_sites_per_group_year=0,
    reduce_to_dominant_sites=False,
    weigh_by_sites_or_trees="none",  # none, sites, trees
    # --- Plotting ---
    plot_type="facet",  # all or facet
    save_plot=True,
    center_to_first_year=False,
    normalize_to_first_year=False,
    ylim=None,  # None or [min, max], e.g. [-200, 400]
    facet_label_trees_or_sites="trees",  # trees or sites
    uncertainty_representation="band",  # band, bar, none
    uncertainty_variable="std",
    aggregation_variable="mean",
    top_n_metric="most_observations",  # most_observations or highest_final_mortality
    top_n_groups_plot=15,
    # genus_filter=None,  # not implemented yet; needs a species-genus dictionary
)


def custom_sort_key(element):
    """
    Sort key that puts region codes first, in a fixed order, then everything else.

    Elements in ("gre", "ser", "hex", "reg", "dep") get their position in that
    sequence as rank; any other element gets rank 5. Ties are broken by the
    element itself.

    Returns:
        tuple: ``(rank, element)``.
    """
    region_priority = ("gre", "ser", "hex", "reg", "dep")
    rank = (
        region_priority.index(element)
        if element in region_priority
        else len(region_priority)
    )
    return (rank, element)

In [None]:
# Input
# Region to plot; one of gre, ser, hex, reg, dep
iregion = "gre"

# Species set (top9, all, 52); selects the species filter and output file name
what_species = "52"

# by_species or directly: whether a trend is weighted by species or calculated directly
by_species_or_direct = "by_species"

# Output folder for the size-class trend data, created if missing
fig1_folder = "../../data/final/mortality_trends/per_size_class/"
os.makedirs(fig1_folder, exist_ok=True)

In [None]:
load_existing = False

# Groupings to process; enable the single-element grouping to also compute
# trends by DBH class only:
# for g in [["tree_height_class"], ["tree_height_class", "species_lat"]]:
for g in [["tree_height_class", "species_lat"]]:
    kwargs["my_grouping"] = g
    if len(g) == 1:
        suf = "by_height_only"
    else:
        suf = f"{what_species}_by_height_and_species"

    file_temp = f"{fig1_folder}/df_{suf}.feather"

    print(f" --- Processing file: {file_temp} ---")

    if load_existing and os.path.exists(file_temp):
        df_out = pd.read_feather(file_temp)
        print(f"Loaded data from file: {file_temp}")
        chime.info()
        continue

    kwargs["suffix"] = suf

    # Select the trees to analyse. Every case assigns kwargs["df"] so that a
    # frame filtered in an earlier iteration or cell run is never silently
    # reused (previously e.g. what_species == "all" matched no branch).
    if suf == "top9_by_height_and_species":
        kwargs["df"] = nfi_new.copy().query("species_lat in @top9_species['species']")
    elif suf in ("by_height_only", "52_by_height_and_species"):
        kwargs["df"] = nfi_new.copy().query("species_lat in @species_in_final_anlysis")
    else:
        # e.g. what_species == "all": keep every tree
        kwargs["df"] = nfi_new.copy()

    # Run function (groupings ordered by custom_sort_key)
    kwargs["my_grouping"] = sorted(kwargs["my_grouping"], key=custom_sort_key)
    df_counts, df_gm, df_grouped = start_to_finish_temporal_plot(
        kwargs,
        return_before_plotting=True,
    )
    chime.success()

    df_out_raw = df_grouped.copy()

    # ! Wrangling
    df_out = df_out_raw.copy()

    # Split "group_year" at "_" into group (part 0) and year (part 1)
    group_year_parts = df_out["group_year"].str.split("_", expand=True)
    df_out["group"] = group_year_parts[0]
    df_out["year"] = group_year_parts[1].astype(int)

    # Mean and std of the selected metric
    metric = kwargs["my_metric"]
    df_out["mean"] = df_out[f"{metric}_mean"]
    df_out["std"] = df_out[f"{metric}_std"]

    # Reduce df size; the first "&"-separated part of the group is kept as
    # "region" (presumably the species, given the sorted grouping - confirm)
    df_out = df_out[["group", "year", "group_year", "mean", "std"]].copy()
    df_out["region"] = df_out["group"].str.split("&", expand=True)[0]

    # Record which metric these rows hold
    df_out["target"] = metric

    # Save it
    df_out.to_feather(file_temp)

In [None]:
# Share of each DBH class among living trees of the analysed species,
# formatted as percentage strings such as "12.3%"
# (for the top 9 only, use: species_lat in @top9_species['species'])
txt_heightpercentages = {
    size_class: f"{(share * 100):.1f}%"
    for size_class, share in (
        nfi_new.query("species_lat in @species_in_final_anlysis")
        .query("tree_state_1 == 'alive'")
        .value_counts("tree_height_class", normalize=True)
        .sort_index()
        .to_dict()
        .items()
    )
}
txt_heightpercentages

In [None]:
# Load per-species, per-DBH-class trends
df_byspec = pd.read_feather(
    f"{fig1_folder}/df_{what_species}_by_height_and_species.feather"
)

# Keep only the second "&"-separated part of "group" (the DBH class; the
# first part is already stored in "region"). `.str[1]` yields NaN instead of
# raising a KeyError when no group contains "&"; those rows become "Missing".
df_byspec["group"] = (
    df_byspec["group"]
    .str.split("&")
    .str[1]
    .fillna("Missing")
    .astype(str)
    .replace("nan", "Missing")
)

# Fixed class order for plotting and legends
group_order = ["7.5-22.5 cm", "22.5-37.5 cm", "≥ 37.5 cm", "Missing"]
df_byspec["group"] = pd.Categorical(
    df_byspec["group"], categories=group_order, ordered=True
)

# Shift years by 5 (original note: "Fix years"); same shift as in plot_trend_ba
df_byspec["year"] = df_byspec["year"].astype(int) + 5
df_byspec

In [None]:
# Top 9 species plot: mortality trend per DBH class, one panel per species
size_classes = ["7.5-22.5 cm", "22.5-37.5 cm", "≥ 37.5 cm"]
df_plot = df_byspec.query("region in @top9_species['species']").copy()

# One viridis colour per observed class, but never fewer than the classes
# drawn below (colors[g] would otherwise raise an IndexError)
colors = plt.cm.viridis(
    np.linspace(0, 1, max(df_plot["group"].nunique(), len(size_classes)))
)

# Plot
fig, axs = plt.subplots(2, 5, figsize=(16, 6))
axs = axs.flatten()

# Species names are stored in the "region" column
species_names = df_plot["region"].unique()
for i, r in enumerate(species_names):
    ax = axs[i]
    idf = df_plot.query("region == @r")

    for g, group in enumerate(size_classes):
        idf_group = idf.query("group == @group")

        # Mean line plus +/- std band for this class
        ax.plot(
            idf_group["year"],
            idf_group["mean"],
            label=group,
            color=colors[g],
            linestyle="-",
        )
        ax.fill_between(
            idf_group["year"],
            idf_group["mean"] - idf_group["std"],
            idf_group["mean"] + idf_group["std"],
            alpha=0.2,
            color=colors[g],
        )

    # Remove top and right spines
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)

    ax.set_title("All Species" if r == "all" else f"{r}", fontsize=14, loc="left")

# Turn off unused subplots (works also when no species is plotted)
for j in range(len(species_names), len(axs)):
    axs[j].axis("off")

# Legend entries from the first panel (every panel draws all classes)
handles, labels = axs[0].get_legend_handles_labels()

# Map labels to height percentages
# labels = [f"{label} m ({txt_heightpercentages.get(label, '')})" for label in labels] # Uncomment to display percentages

# Legend in the last (spare) panel
axs[-1].legend(
    handles=handles,
    labels=labels,
    title="DBH Classes",
    title_fontsize=16,
    ncol=1,
    fontsize=14,
    # Shorten line length
    handlelength=0.8,
    bbox_to_anchor=(0.9, 0.925),
    frameon=False,
)

fig.text(
    -0.01,
    0.5,
    "Mortality Rate (%-stems yr$^{-1}$)",
    va="center",
    rotation="vertical",
    fontsize=16,
)

# Add universal x-axis label
fig.text(
    0.5,
    -0.01,
    "Year",
    ha="center",
    fontsize=16,
)

# Save the plot
plt.tight_layout()
plt.savefig("./trends_by_height_class.png", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
from matplotlib.ticker import FormatStrFormatter

# Species order for the multi-panel figures: as returned by
# get_species_with_models (original note: by occurrence).
# Set to True for alphabetical order instead.
order_alphabetically = False
myorder = (
    sorted(species_in_final_anlysis.copy())
    if order_alphabetically
    else species_in_final_anlysis.copy()
)


def order_df_by(df, myorder, variable):
    """
    Return a copy of ``df`` sorted by ``variable`` in a custom order.

    ``variable`` is converted to an ordered categorical with categories
    ``myorder``; values not listed in ``myorder`` become NaN and end up last
    (pandas' default ``na_position="last"``). The input frame is not modified.

    Parameters:
        df (pd.DataFrame): Frame to order.
        myorder (list): Category values in the desired order.
        variable (str): Column to convert and sort by.

    Returns:
        pd.DataFrame: Sorted copy with a fresh RangeIndex.
    """
    # Work on a copy so the caller's frame keeps its original dtype
    ordered = df.copy()
    ordered[variable] = pd.Categorical(
        ordered[variable], categories=myorder, ordered=True
    )

    # Sort the DataFrame according to the custom group order
    return ordered.sort_values(variable).reset_index(drop=True)


# Two figure pages: first 30 species, then the rest
part1 = myorder[:30]
part2 = myorder[30:]

# DBH classes drawn in every panel, in legend order
size_classes = ["7.5-22.5 cm", "22.5-37.5 cm", "≥ 37.5 cm"]

# First page: species of part1, ordered like `myorder`
df_plot = order_df_by(df_byspec.copy(), myorder, "region")
df_plot = df_plot.query("region in @part1").copy()

# One viridis colour per observed class, but never fewer than the classes
# drawn below (colors[g] would otherwise raise an IndexError)
colors = plt.cm.viridis(
    np.linspace(0, 1, max(df_plot["group"].nunique(), len(size_classes)))
)

# Plot
fig, axs = plt.subplots(6, 5, figsize=(20, 20))
axs = axs.flatten()

# First line drawn for each class in ANY panel; taking the legend from the
# first panel alone would miss classes skipped there.
legend_lines = {}

# Species names are stored in the "region" column
species_names = df_plot["region"].unique()
for i, r in enumerate(species_names):
    ax = axs[i]
    idf = df_plot.query("region == @r")

    for g, group in enumerate(size_classes):
        # Class rows in year order, without zero-mortality years
        idf_group = idf.query("group == @group").sort_values("year")
        idf_group = idf_group.query("mean != 0")

        # Have at least 4 years of data to show
        if idf_group.shape[0] < 4:
            continue

        (line,) = ax.plot(
            idf_group["year"],
            idf_group["mean"],
            label=group,
            color=colors[g],
            linestyle="-",
        )
        legend_lines.setdefault(group, line)

        ax.fill_between(
            idf_group["year"],
            idf_group["mean"] - idf_group["std"],
            idf_group["mean"] + idf_group["std"],
            alpha=0.2,
            color=colors[g],
        )

    # Styling and title apply even when no class had enough data
    ax.tick_params(axis="both", which="major", labelsize=14)
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_title("All Species" if r == "all" else f"{r}", fontsize=16, loc="left")

# Turn off unused subplots (works also when no species is plotted)
for j in range(len(species_names), len(axs)):
    axs[j].axis("off")

# Legend entries in fixed class order
labels = [c for c in size_classes if c in legend_lines]
handles = [legend_lines[c] for c in labels]

# Map labels to height percentages
# labels = [f"{label} m ({txt_heightpercentages.get(label, '')})" for label in labels] # Uncomment to display percentages

fig.legend(
    handles=handles,
    labels=labels,
    title="",
    loc="center",
    bbox_to_anchor=(0.5, -0.035),
    title_fontsize=16,
    fontsize=16,
    ncol=3,
    handlelength=0.8,
    frameon=False,
)

fig.text(
    -0.02,
    0.5,
    "Mortality Rate (%-stems yr$^{-1}$)",
    va="center",
    rotation="vertical",
    fontsize=20,
)

# Add universal x-axis label
fig.text(
    0.5,
    -0.01,
    "Year",
    ha="center",
    fontsize=20,
)

# Save the plot
plt.tight_layout()
plt.savefig("./trends_by_height_class-1.png", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
# DBH classes drawn in every panel, in legend order
size_classes = ["7.5-22.5 cm", "22.5-37.5 cm", "≥ 37.5 cm"]

# Second page: species of part2, ordered like `myorder`
df_plot = order_df_by(df_byspec.copy(), myorder, "region")
df_plot = df_plot.query("region in @part2").copy()

# One viridis colour per observed class, but never fewer than the classes
# drawn below (colors[g] would otherwise raise an IndexError)
colors = plt.cm.viridis(
    np.linspace(0, 1, max(df_plot["group"].nunique(), len(size_classes)))
)

# Plot
fig, axs = plt.subplots(6, 5, figsize=(20, 20))
axs = axs.flatten()

# First line drawn for each class in ANY panel; taking the legend from the
# first panel alone would miss classes skipped there.
legend_lines = {}

# Species names are stored in the "region" column
species_names = df_plot["region"].unique()
for i, r in enumerate(species_names):
    ax = axs[i]
    idf = df_plot.query("region == @r")

    for g, group in enumerate(size_classes):
        # Class rows in year order, without zero-mortality years
        idf_group = idf.query("group == @group").sort_values("year")
        idf_group = idf_group.query("mean != 0")

        # Have at least 4 years of data to show
        if idf_group.shape[0] < 4:
            continue

        (line,) = ax.plot(
            idf_group["year"],
            idf_group["mean"],
            label=group,
            color=colors[g],
            linestyle="-",
        )
        legend_lines.setdefault(group, line)

        ax.fill_between(
            idf_group["year"],
            idf_group["mean"] - idf_group["std"],
            idf_group["mean"] + idf_group["std"],
            alpha=0.2,
            color=colors[g],
        )

    # Styling and title apply even when no class had enough data
    ax.tick_params(axis="both", which="major", labelsize=14)
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_title("All Species" if r == "all" else f"{r}", fontsize=16, loc="left")

# Turn off unused subplots (works also when no species is plotted)
for j in range(len(species_names), len(axs)):
    axs[j].axis("off")

# Legend entries in fixed class order
labels = [c for c in size_classes if c in legend_lines]
handles = [legend_lines[c] for c in labels]

# Map labels to height percentages
# labels = [f"{label} m ({txt_heightpercentages.get(label, '')})" for label in labels] # Uncomment to display percentages

# Legend and x-label are placed by hand inside the unused bottom grid row
fig.legend(
    handles=handles,
    labels=labels,
    title="",
    loc="center",
    bbox_to_anchor=(0.5, 0.12),
    title_fontsize=16,
    fontsize=16,
    ncol=3,
    handlelength=0.8,
    frameon=False,
)

fig.text(
    -0.02,
    0.5,
    "Mortality Rate (%-stems yr$^{-1}$)",
    va="center",
    rotation="vertical",
    fontsize=20,
)

# Add universal x-axis label
fig.text(
    0.5,
    0.135,
    "Year",
    ha="center",
    fontsize=20,
)

# Save the plot
plt.tight_layout()
plt.savefig("./trends_by_height_class-2.png", dpi=300, bbox_inches="tight")
plt.show()