# Wrangle SPEI data (SPEI values computed in R)


In [None]:
import sys

# Make the project's src/ directory importable. The check keeps re-runs of this
# setup (it is repeated at the start of later sections) from prepending
# duplicate entries to sys.path.
SRC_DIR = "../../src"
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

# NOTE(review): star import — presumably provides pd, plt, sns, init_notebook and
# the calc_* helpers used below; the exact names are defined in src/imports.py.
from imports import *

init_notebook()

In [None]:
# Load both R-exported SPEI splits and stack them into one DataFrame.
# ignore_index=True gives df_spei a unique RangeIndex instead of two overlapping
# 0..n ranges. Duplicate labels make .loc lookups ambiguous, and DataFrame.to_feather
# rejects a non-default index.
spei_dir = "/Volumes/SAMSUNG 1TB/IFNA/digitalis_v3/processed/1km/all/merged/spei"
spei_split_paths = [f"{spei_dir}/R-spei_split_1.csv", f"{spei_dir}/R-spei_split_2.csv"]

df_spei = pd.concat(
    [pd.read_csv(path) for path in spei_split_paths], ignore_index=True
)

In [None]:
# SPEI trend predictors (runtime ~10 minutes).
# Rows with a missing value in any column are removed before the computation.
spei_complete = df_spei.dropna()
df_spei_trend = calc_spei_trend_loop_mp(spei_complete)
df_spei_trend

In [None]:
# SPEI min/mean predictors (runtime ~10 minutes).
# years_before_second_visit=7 presumably sets the look-back window before the
# second tree visit — see calc_spei_min_mean_loop_mp for the exact definition.
spei_no_missing = df_spei.dropna()
df_spei_minmean = calc_spei_min_mean_loop_mp(spei_no_missing, years_before_second_visit=7)
df_spei_minmean

In [None]:
# Persist both predictor tables as Feather files.
# NOTE: the min/mean table is stored under the file name "spei_anom".
predictor_dir = "../../data/final/predictor_datasets"
for frame, stem in [(df_spei_trend, "spei_trend"), (df_spei_minmean, "spei_anom")]:
    frame.to_feather(f"{predictor_dir}/{stem}.feather")

In [None]:
# Reload the saved predictor tables so the cells below can run without
# repeating the ~10-minute computations.
predictor_dir = "../../data/final/predictor_datasets"
df_spei_trend = pd.read_feather(f"{predictor_dir}/spei_trend.feather")
df_spei_minmean = pd.read_feather(f"{predictor_dir}/spei_anom.feather")

display(df_spei_trend, df_spei_minmean)

## Heatmap example: SPEI features of a single site


In [None]:
# Metric suffix shown in the heatmap (e.g. "min")
metric_to_plot = "min"

# One site: join both feature tables, then melt so every feature column becomes
# a row. Feature names are parsed as "spei<#>-<month>_<metric>".
site_features_long = (
    df_spei_minmean.merge(df_spei_trend, how="left", on="idp")
    .query("idp == 500008")
    .melt(id_vars=["idp"], var_name="spei_metric", value_name="spei_value")
)

# Split each feature name into aggregation window, month and metric
name_after_prefix = site_features_long["spei_metric"].str.split("spei").str[1]
scale_and_month = name_after_prefix.str.split("_").str[0]
site_features_long = site_features_long.assign(
    **{
        "spei_#": scale_and_month.str.split("-").str[0].astype("int"),
        "spei_metric": name_after_prefix.str.split("_").str[-1],
        "spei_month": scale_and_month.str.split("-").str[1].astype("int"),
    }
)

# Keep only the requested metric
df_spei_minmean_long = site_features_long.query(f"spei_metric == '{metric_to_plot}'")
display(df_spei_minmean_long)

# Wide matrix: one row per spei_#, one column per spei_month
df_spei_minmean_long = df_spei_minmean_long.pivot_table(
    index=["idp", "spei_#"], columns="spei_month", values="spei_value"
).reset_index()
display(df_spei_minmean_long)

fig, ax = plt.subplots(1, 1, figsize=(15, 5))
heatmap_matrix = df_spei_minmean_long.drop("idp", axis=1).set_index("spei_#")
sns.heatmap(heatmap_matrix, ax=ax, cmap="RdBu", center=0)
site_id = df_spei_minmean_long["idp"].values[0]
ax.set_title(f"Metric: {metric_to_plot} | Site: {site_id}")
# Flip the vertical axis so the row order is reversed on screen
ax.invert_yaxis()

- spei\_#: Duration over which SPEI was aggregated (1 = preceding month, 10 = preceding 10 months, etc.)
- spei_month: Month of the year. 13 means that all months were considered together when taking the minimum, rather than only one calendar month (e.g. only Januaries for spei_month = 1).
- color: Minimum SPEI value for the given spei\_# / spei_month pair between the two tree visits (for metric_to_plot = "min").


## Visualize SPEI time series


In [None]:
import sys

# Make the project's src/ directory importable. The check keeps re-runs of this
# setup (it appears at the start of several sections) from prepending duplicate
# entries to sys.path.
SRC_DIR = "../../src"
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

# NOTE(review): star import — presumably provides pd, plt, sns, init_notebook and
# other project helpers; the exact names are defined in src/imports.py.
from imports import *

init_notebook()

In [None]:
# --- Parameters for this section ---
iseed = 1  # random_state used to draw the example site
spei_timescale = 6  # SPEI aggregation window (months) shown in the example plots
myseason = "JJA"  # season highlighted in the feature example
source = "all/merged"  # sub-directory under processed/1km/ to read inputs from

In [None]:
# Root of the processed 1 km data for the chosen `source` sub-directory
data_root = f"/Volumes/WD - ExFAT/IFNA/digitalis_v3/processed/1km/{source}"

# Load precipitation (prec) and evapotranspiration (etp) inputs
df_precetp = pd.read_csv(f"{data_root}/data_to_calculate_spei.csv")

# Stored values are scaled by 10; convert back to mm
df_precetp["prec"] = df_precetp["prec"] / 10
df_precetp["etp"] = df_precetp["etp"] / 10
display(df_precetp.head())

# Load the R-computed SPEI splits. The paths are built from `source` so that
# SPEI and prec/etp always come from the same sub-directory (previously they
# were hard-coded to "all/merged" regardless of `source`).
spei_split_paths = [f"{data_root}/spei/R-spei_split_{k}.csv" for k in (1, 2)]
df_spei = pd.concat([pd.read_csv(p) for p in spei_split_paths], ignore_index=True)
df = df_spei.copy()
display(df.head())

In [None]:
import matplotlib.dates as mdates

# Plot the SPEI1/3/6/12/24 time series of one randomly drawn site (from 2000 on).
# NOTE(review): random_state=11 is hard-coded here, unlike the `iseed` parameter
# used for the example site further below.
site = pd.Series(df_spei.idp.unique()).sample(1, random_state=11).values[0]
df_x = (
    df_spei[df_spei["idp"] == site]
    .query("date >= '2000-01-01'")
    # Real datetimes give a time axis (instead of one categorical tick per date
    # string), so ticks can be placed on calendar years below.
    .assign(date=lambda d: pd.to_datetime(d["date"]))
)

fig, ax = plt.subplots(5, 1, figsize=(10, 12.5))
for i, col in enumerate(["spei1", "spei3", "spei6", "spei12", "spei24"]):
    ax[i].axhline(0, color="black", linestyle="-")
    # Reference line at SPEI = -1.6
    ax[i].axhline(-1.6, color="red", linestyle="dotted")
    ax[i].plot(df_x["date"], df_x[col], label=col)
    ax[i].set_title(col)
    ax[i].legend()
    # One tick per year, labelled with the year only
    ax[i].xaxis.set_major_locator(mdates.YearLocator())
    ax[i].xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
    ax[i].tick_params(axis="x", rotation=90)

# Show x tick labels only on the bottom panel
for a in ax.flat:
    a.label_outer()

fig.suptitle(f"SPEI timeseries of site: {site}\n", fontsize=16, fontweight="bold")
plt.tight_layout()
plt.show()

In [None]:
df_spei.iloc[:3]  # first three rows of the stacked SPEI table

In [None]:
# Mean and standard deviation of every SPEI column across all sites, per date.
# The bookkeeping columns are dropped after aggregation.
id_cols = ["idp", "first_year", "year", "month"]
df_mean_org = df_spei.groupby(["date"]).mean().reset_index().drop(columns=id_cols)
df_sd_org = df_spei.groupby(["date"]).std().reset_index().drop(columns=id_cols)

# Meteorological season of every calendar month
month_to_season = {
    12: "DJF", 1: "DJF", 2: "DJF",
    3: "MAM", 4: "MAM", 5: "MAM",
    6: "JJA", 7: "JJA", 8: "JJA",
    9: "SON", 10: "SON", 11: "SON",
}

# Attach the season to the mean table. The column keeps the name "month"
# because the plotting cells read the season from df_mean_org["month"].
df_mean_org["date"] = pd.to_datetime(df_mean_org["date"])
df_mean_org["month"] = df_mean_org["date"].dt.month.map(month_to_season)

df_mean_org = move_vars_to_front(df_mean_org, ["date", "month"])

In [None]:
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib import dates as mdates
import matplotlib.ticker as ticker


# --- Functions ---
def label_every_second_year(x, pos):
    """Tick formatter for Matplotlib date axes.

    Converts the date number ``x`` to a datetime and returns its year as a
    string for even years, or an empty string for odd years. ``pos`` (the tick
    index Matplotlib passes to formatters) is not used.
    """
    tick_year = mdates.num2date(x).year
    if tick_year % 2:
        return ""
    return str(tick_year)


import os

# --- User input ---
spei_timescales = [1, 3, 6, 12, 24]
first_year = 2000
# Seasons ending in these years are shaded more strongly
highlight_years = [2003, 2011, 2015, 2016, 2018]
out_dir = "./climate_evolution_figures"

# --- Load data ---
df_mean = df_mean_org.query(f"date >= '{first_year}-01-01'").copy()
df_sd = df_sd_org.query(f"date >= '{first_year}-01-01'").copy()

# Ensure datetime format
df_mean["date"] = pd.to_datetime(df_mean["date"])
df_sd["date"] = pd.to_datetime(df_sd["date"])

# Merge mean and SD; SD columns get the suffix "_sd"
df_plot = pd.merge(df_mean, df_sd, on="date", suffixes=("", "_sd"))

# --- Define season colors ---
season_colors = {
    "DJF": "blue",
    "MAM": "green",
    "JJA": "orange",
    "SON": "brown",
}


def _shade_seasons(ax, df, colors, highlight_years, pad_days=11):
    """Shade each run of consecutive rows with the same season on ``ax``.

    ``df`` needs a datetime "date" column and a season-string "month" column.
    Every run, including the last one, is widened by ``pad_days`` on both sides
    and drawn with alpha 0.25 if it ends in one of ``highlight_years``, else 0.05.
    (Previously the final run was always drawn at 0.25 without padding, which
    made the last season look like a highlighted year.)
    """
    run_id = df["month"].ne(df["month"].shift()).cumsum()
    pad = pd.Timedelta(days=pad_days)
    for _, run in df.groupby(run_id, sort=False):
        start, end = run["date"].iloc[0], run["date"].iloc[-1]
        ax.axvspan(
            start - pad,
            end + pad,
            color=colors.get(run["month"].iloc[0], "gray"),
            alpha=0.25 if end.year in highlight_years else 0.05,
        )


# --- Plot setup ---
fig, axarr = plt.subplots(len(spei_timescales), 1, figsize=(10, 7), sharex=True)

for ax, scale in zip(axarr, spei_timescales):
    col = f"spei{scale}"
    col_sd = f"{col}_sd"

    # --- Seasonal background shading ---
    _shade_seasons(ax, df_plot, season_colors, highlight_years)

    # --- SPEI line with ±1 SD ribbon and reference lines at 0 and ±1 ---
    ax.axhline(0, color="black", linestyle="-", linewidth=0.8)
    ax.axhline(-1, color="black", linestyle="dotted", linewidth=0.8)
    ax.axhline(1, color="black", linestyle="dotted", linewidth=0.8)

    ax.plot(df_plot["date"], df_plot[col], color="black", label=f"SPEI{scale}")
    ax.fill_between(
        df_plot["date"],
        df_plot[col] - df_plot[col_sd],
        df_plot[col] + df_plot[col_sd],
        color="gray",
        alpha=0.5,
        label="±1 SD",
    )

    ax.set_ylabel(f"SPEI {scale}", fontweight="bold")
    ax.set_xlim(df_plot["date"].min(), df_plot["date"].max())

# --- X-axis formatting ---
# With sharex=True all panels share one tick locator/formatter, so setting them
# once applies everywhere: a tick every year, labelled every second year.
top_ax = axarr[0]
top_ax.xaxis.set_major_locator(mdates.YearLocator(1))
top_ax.xaxis.set_major_formatter(ticker.FuncFormatter(label_every_second_year))
# Additionally show the tick labels above the topmost panel
top_ax.xaxis.set_ticks_position("top")
top_ax.xaxis.set_label_position("top")

plt.tight_layout()
# savefig fails if the output directory does not exist yet
os.makedirs(out_dir, exist_ok=True)
for ext in ("png", "pdf"):
    plt.savefig(f"{out_dir}/S_spei_timeseries.{ext}", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
# Same Figure but with linear regression lines per season
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import dates as mdates
import matplotlib.ticker as ticker
from sklearn.linear_model import LinearRegression


# --- Functions ---
def label_every_second_year(x, pos):
    """Return the tick label for Matplotlib date number ``x``: its year if even, else "".

    ``pos`` is accepted for the formatter signature and ignored.
    NOTE(review): identical to the definition in the previous cell; presumably
    repeated so this cell can run on its own.
    """
    year = mdates.num2date(x).year
    return "" if year % 2 else str(year)


# --- User input ---
spei_timescales = [1, 3, 6, 12, 24]
first_year = 2000

# --- Load data: mean and SD from `first_year` on, joined on date ---
start_filter = f"date >= '{first_year}-01-01'"
df_mean = df_mean_org.query(start_filter).copy()
df_sd = df_sd_org.query(start_filter).copy()
for frame in (df_mean, df_sd):
    frame["date"] = pd.to_datetime(frame["date"])
df_plot = df_mean.merge(df_sd, on="date", suffixes=("", "_sd"))

# --- Season colours for shading and trend lines ---
season_colors = dict(DJF="blue", MAM="green", JJA="orange", SON="brown")

# # --- Plot setup ---
# fig, axarr = plt.subplots(len(spei_timescales), 1, figsize=(10, 7), sharex=True)

# for i, scale in enumerate(spei_timescales):
#     ax = axarr[i]
#     col = f"spei{scale}"
#     col_sd = f"{col}_sd"

#     # --- Add seasonal shading by transitions ---
#     current_season = df_plot["month"].iloc[0]
#     start_date = df_plot["date"].iloc[0]

#     for j in range(1, len(df_plot)):
#         if df_plot["date"].iloc[j].year in [2003, 2011, 2015, 2016, 2018]:
#             season_alpha = 0.25
#         else:
#             season_alpha = 0.05

#         if df_plot["month"].iloc[j] != current_season:
#             end_date = df_plot["date"].iloc[j - 1]
#             ax.axvspan(
#                 start_date - pd.Timedelta(days=11),
#                 end_date + pd.Timedelta(days=11),
#                 color=season_colors.get(current_season, "gray"),
#                 alpha=season_alpha,
#             )
#             current_season = df_plot["month"].iloc[j]
#             start_date = df_plot["date"].iloc[j]

#     # Final shading
#     ax.axvspan(
#         start_date,
#         df_plot["date"].iloc[-1],
#         color=season_colors.get(current_season, "gray"),
#         alpha=0.25,
#     )

#     # --- Plot SPEI line and ribbon ---
#     ax.axhline(0, color="black", linestyle="-", linewidth=0.8)
#     ax.axhline(-1, color="black", linestyle="dotted", linewidth=0.8)
#     ax.axhline(1, color="black", linestyle="dotted", linewidth=0.8)

#     ax.plot(df_plot["date"], df_plot[col], color="black", label=f"SPEI{scale}")
#     ax.fill_between(
#         df_plot["date"],
#         df_plot[col] - df_plot[col_sd],
#         df_plot[col] + df_plot[col_sd],
#         color="gray",
#         alpha=0.5,
#         label="±1 SD",
#     )

#     # --- Add linear regression lines per season ---
#     for season, color in season_colors.items():
#         df_season = df_plot[df_plot["month"] == season]
#         x = mdates.date2num(df_season["date"]).reshape(-1, 1)
#         y = df_season[col].values

#         mask = ~np.isnan(y)
#         if np.sum(mask) < 2:
#             continue  # Skip if not enough valid data

#         x_clean = x[mask]
#         y_clean = y[mask]

#         model = LinearRegression()
#         model.fit(x_clean, y_clean)
#         y_pred = model.predict(x_clean)

#         ax.plot(
#             mdates.num2date(x_clean),
#             y_pred,
#             color=color,
#             linewidth=1.5,
#             linestyle="--",
#             alpha=0.9,
#             label=f"{season} trend" if i == 0 else None,
#         )

#     ax.set_ylabel(f"SPEI {scale}", fontweight="bold")
#     ax.set_xlim(df_plot["date"].min(), df_plot["date"].max())

# # --- X-axis formatting ---
# axarr[-1].xaxis.set_major_locator(mdates.YearLocator(2))
# axarr[-1].xaxis.set_major_formatter(mdates.DateFormatter("%Y"))

# # --- Move ticks/labels to topmost plot ---
# top_ax = axarr[0]
# top_ax.xaxis.set_ticks_position("top")
# top_ax.xaxis.set_label_position("top")
# top_ax.xaxis.set_major_locator(mdates.YearLocator(1))
# top_ax.xaxis.set_major_formatter(ticker.FuncFormatter(label_every_second_year))

# # --- Save and show ---
# # NOTE: distinct filenames, so re-enabling this cell does not overwrite the
# # figure saved by the previous cell.
# plt.tight_layout()
# plt.savefig(
#     "./climate_evolution_figures/S_spei_timeseries_trends.png", dpi=300, bbox_inches="tight"
# )
# plt.savefig(
#     "./climate_evolution_figures/S_spei_timeseries_trends.pdf", dpi=300, bbox_inches="tight"
# )
# plt.show()

## Example: how SPEI features are derived


In [None]:
# Draw the example site reproducibly from the unique site ids
unique_sites = df["idp"].unique()
example_site = pd.Series(unique_sites).sample(1, random_state=iseed).values[0]

# ! Filter for example site
ex_precetp = df_precetp.query(f"idp == {example_site}").copy()
ex_spei = df.query(f"idp == {example_site}").copy()

# ! OR Take the mean across all sites
# ex_precetp = df_precetp.groupby(["date"]).mean().reset_index()
# ex_spei = df.groupby(["date", "year", "month"]).mean().reset_index()

# Combine inputs and SPEI for the site; date becomes a real datetime
ex = pd.merge(ex_precetp, ex_spei, on=["date", "idp"], how="left")
ex["date"] = pd.to_datetime(ex["date"])

# Replace the month number by its meteorological season (column name kept as "month")
season_by_month = {
    month: season
    for season, months in {
        "DJF": (12, 1, 2),
        "MAM": (3, 4, 5),
        "JJA": (6, 7, 8),
        "SON": (9, 10, 11),
    }.items()
    for month in months
}
ex["month"] = ex["month"].map(season_by_month)

# Zoom into 2013-2020, then keep only the season of interest
ex_zoomed = ex.query("date >= '2013-01-01' & date <= '2020-12-31'")
ex_zoomed_season = ex_zoomed[ex_zoomed["month"] == myseason]
ex_zoomed_season

In [None]:
# Set season colors
season_colors = {
    "DJF": "skyblue",
    "MAM": "green",
    "JJA": "yellow",
    "SON": "orange",
}


def _shade_season_runs(ax, df, colors, alpha=0.25):
    """Shade each run of consecutive rows that share a season on ``ax``.

    ``df`` needs a datetime "date" column and a season-string "month" column.
    Each span runs from the first to the last date of its run. Every span,
    including the final one, uses the same ``alpha``; previously the last span
    was drawn at 0.5, which emphasised it by accident.
    """
    run_id = df["month"].ne(df["month"].shift()).cumsum()
    for _, run in df.groupby(run_id, sort=False):
        ax.axvspan(
            run["date"].iloc[0],
            run["date"].iloc[-1],
            color=colors[run["month"].iloc[0]],
            alpha=alpha,
        )


spei_col = f"spei{spei_timescale}"

fig, ax = plt.subplots(2, 2, figsize=(15, 5))
ax = ax.flatten()

# 1. Precipitation (left y-axis) and ETP (right y-axis), full time series
ax[0].plot(ex["date"], ex["prec"], label="Precipitation", color="blue")
ax[0].set_ylabel("Precipitation (mm)", color="blue")
ax[0].tick_params(axis="y", labelcolor="blue")

ax0_twin = ax[0].twinx()
ax0_twin.plot(ex["date"], ex["etp"], label="ETP", color="red")
ax0_twin.set_ylabel("ETP (mm)", color="red")
ax0_twin.tick_params(axis="y", labelcolor="red")

ax[0].set_title(
    "1.) Full time series of precipitation and ETP", fontweight="bold", loc="left"
)

# 2. Full SPEI time series
ax[1].axhline(0, color="black", linestyle="-")
ax[1].plot(ex["date"], ex[spei_col], label=f"SPEI{spei_timescale}", color="black")
ax[1].set_title(
    f"2.) Get full time series of SPEI{spei_timescale}", fontweight="bold", loc="left"
)
ax[1].set_ylabel(f"SPEI{spei_timescale}")

# 3. Zoomed SPEI with the background coloured by season
ax[2].axhline(0, color="black", linestyle="-")
_shade_season_runs(ax[2], ex_zoomed, season_colors)
ax[2].plot(
    ex_zoomed["date"], ex_zoomed[spei_col], label=f"SPEI{spei_timescale}", color="black"
)
ax[2].scatter(ex_zoomed["date"], ex_zoomed[spei_col], edgecolor="black", color="white")
# NOTE(review): ex_zoomed spans 2013-01-01..2020-12-31 (eight calendar years),
# while the title says "7-year window" — confirm which is intended.
ax[2].set_title(
    f"3.) Get SPEI{spei_timescale} time series within 7-year window",
    fontweight="bold",
    loc="left",
)
ax[2].set_ylabel(f"SPEI{spei_timescale}")
ax[2].set_xlabel("Date")

# 4. Zoomed SPEI with min / max / mean of the chosen season marked
ax[3].axhline(0, color="black", linestyle="-")
_shade_season_runs(ax[3], ex_zoomed, season_colors)
ax[3].scatter(
    ex_zoomed_season["date"],
    ex_zoomed_season[spei_col],
    label=f"SPEI{spei_timescale}",
    color="white",
    edgecolor="black",
)

season_spei = ex_zoomed_season[spei_col]
min_spei = season_spei.min()
max_spei = season_spei.max()
mean_spei = season_spei.mean()
# idxmin/idxmax pick exactly one date per extreme. Selecting rows with
# `== min_spei` returns several dates on ties, and scatter() rejects several x
# values paired with a single y value.
min_date = ex_zoomed_season.loc[season_spei.idxmin(), "date"]
max_date = ex_zoomed_season.loc[season_spei.idxmax(), "date"]

ax[3].scatter(
    [min_date],
    [min_spei],
    color="brown",
    edgecolor="black",
    s=100,
    marker="o",
    label=f"Min: {min_spei:.2f}",
)
ax[3].scatter(
    [max_date],
    [max_spei],
    color="green",
    edgecolor="black",
    s=100,
    marker="o",
    label=f"Max: {max_spei:.2f}",
)
ax[3].axhline(
    mean_spei,
    color="brown",
    linestyle="--",
    label=f"Mean: {mean_spei:.2f}",
    linewidth=2,
)

ax[3].legend(loc="lower left")
ax[3].set_title(
    f"4.) Extract SPEI{spei_timescale} anomaly metrics for season {myseason}",
    fontweight="bold",
    loc="left",
)
ax[3].set_ylabel(f"SPEI{spei_timescale}")
ax[3].set_xlabel("Date")

plt.tight_layout()
plt.savefig("../../notebooks/02_collect_features/example_spei_features.png")
plt.show()

In [None]:
# Look up the example site's coordinates in the raw NFI data and wrap them as points.
# NOTE(review): no CRS is set on the GeoDataFrame — confirm the lat_fr/lon_fr system.
nfi = get_latest_nfi_raw_data()
site_coords = nfi.loc[nfi["idp"] == example_site, ["idp", "lat_fr", "lon_fr"]].drop_duplicates()
site_points = gpd.points_from_xy(x=site_coords["lon_fr"], y=site_coords["lat_fr"])
gdf = gpd.GeoDataFrame(site_coords, geometry=site_points)
gdf

In [None]:
# Country outline with the example site drawn on top
fr = get_shp_of_region("cty")

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
fr.boundary.plot(ax=ax, color="black", linewidth=1)
gdf.plot(ax=ax, color="red", markersize=100)
ax.set_title(f"Site: {example_site}")
plt.show()

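## Seasonal SPEI trends for an example site

The cells below fit linear SPEI trends per season and timescale for one randomly drawn site. Averaging across all sites is available as a commented-out alternative.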


In [None]:
import sys

# Make the project's src/ directory importable. The check keeps re-runs of this
# setup (it appears at the start of several sections) from prepending duplicate
# entries to sys.path.
SRC_DIR = "../../src"
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

# NOTE(review): star import — presumably provides pd, plt, mdates, init_notebook
# and other project helpers; the exact names are defined in src/imports.py.
from imports import *

init_notebook()

In [None]:
# --- Parameters for the trend section ---
iseed = 1  # random_state used to draw the example site
spei_timescale = 3  # SPEI aggregation window (months)
myseason = "JJA"  # season of interest

In [None]:
# No reload here: the cells below reuse `df` (SPEI) and `df_precetp` (prec/etp) loaded earlier in this notebook — run those cells first.

In [None]:
# Draw the example site reproducibly
unique_sites = df["idp"].unique()
example_site = pd.Series(unique_sites).sample(1, random_state=iseed).values[0]

# ! Filter for example site
ex_precetp = df_precetp.query(f"idp == {example_site}").copy()
ex_spei = df.query(f"idp == {example_site}").copy()

# ! OR Take the mean across all sites
# ex_precetp = df_precetp.groupby(["date"]).mean().reset_index()
# ex_spei = df.groupby(["date", "year", "month"]).mean().reset_index()

# Merge inputs and SPEI; add Matplotlib date numbers (float days) for regression
ex = pd.merge(ex_precetp, ex_spei, on=["date", "idp"], how="left")
ex["date"] = pd.to_datetime(ex["date"])
ex["date_num"] = mdates.date2num(ex["date"])

# Month number -> meteorological season (column name kept as "month")
seasons = {"DJF": (12, 1, 2), "MAM": (3, 4, 5), "JJA": (6, 7, 8), "SON": (9, 10, 11)}
ex["month"] = ex["month"].map({m: s for s, months in seasons.items() for m in months})

ex

In [None]:
from sklearn.linear_model import LinearRegression

# SPEI timescales and seasons to fit
spei_to_plot = ["spei3", "spei6", "spei12", "spei24"]
seasons_to_plot = ["DJF", "MAM", "JJA", "SON"]

# Fit SPEI ~ date separately for every (timescale, season) pair.
# The slope is in SPEI units per day, because the predictor is a date number.
# NOTE: dropna() removes rows with a missing value in ANY column, so every fit
# only uses dates where all columns (including the longest timescale) exist.
fit_rows = []
fitted_lines = []
for spei in spei_to_plot:
    for season in seasons_to_plot:
        idf = ex.query(f"month == '{season}'").dropna()
        X = idf[["date_num"]]
        lm = LinearRegression().fit(X, idf[spei])
        fit_rows.append(
            {"spei": spei, "season": season, "slope": lm.coef_[0], "intercept": lm.intercept_}
        )
        # Fitted values at the observed dates, for drawing the trend line
        fitted_lines.append(
            pd.DataFrame(
                {
                    "date_num": idf["date_num"],
                    "spei": spei,
                    "season": season,
                    "value": lm.predict(X),
                }
            )
        )

# One row per fit (all rows share index label 0, as in the per-fit frames)
df_lms = pd.DataFrame(fit_rows, index=[0] * len(fit_rows))
df_values = pd.concat(fitted_lines)
df_values

In [None]:
tuple(df_values["date_num"].agg(["min", "max"]))  # date-number range of the fitted lines

In [None]:
from matplotlib.lines import Line2D
import matplotlib.lines as mlines


# Trend lines: the line style encodes the SPEI timescale, the colour the season
spei_linetypes = {
    "spei3": "-",
    "spei6": "--",
    "spei12": "-.",
    "spei24": ":",
}
season_colors = {
    "DJF": "darkblue",
    "MAM": "green",
    "JJA": "gold",
    "SON": "darkorange",
}

fig, ax = plt.subplots(1, 1, figsize=(3.3, 3))
ax.axhline(0, color="black", linestyle="-")

for spei in spei_to_plot:
    for season in seasons_to_plot:
        is_pair = (df_values["spei"] == spei) & (df_values["season"] == season)
        idf = df_values[is_pair]
        ax.plot(
            idf["date_num"],
            idf["value"],
            linestyle=spei_linetypes[spei],
            color=season_colors[season],
        )

ax.set_ylabel("SPEI Value")
ax.set_xlabel("Date")
ax.set_ylim(-0.4, 0.4)  # symmetric around zero

# x-limits are Matplotlib date numbers (days), presumably picked by hand from
# the date_num range printed by the previous cell
ax.set_xlim(-3500.0, 19000)
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
ax.xaxis.set_major_locator(mdates.YearLocator())
# Keep every 20th yearly tick to avoid overlapping labels
ax.set_xticks(ax.get_xticks()[::20])

# Legend: season colours first, then SPEI line styles
legend_handles = [
    mlines.Line2D([], [], color=color, linestyle="-", label=season)
    for season, color in season_colors.items()
] + [
    mlines.Line2D([], [], color="black", linestyle=linetype, label=spei)
    for spei, linetype in spei_linetypes.items()
]
fig.legend(
    handles=legend_handles,
    loc="upper right",
    bbox_to_anchor=(1.65, 0.95),  # outside the axes, to the right of the plot
    ncol=2,  # two columns
    title="Season and SPEI Timescale",
    frameon=False,
)

plt.tight_layout()
plt.show()