# Analyze Events (Prototype)

This example shows how to evaluate Salient's forecasts with an event-and-decision framework. It demonstrates [validation best practices](https://salientpredictions.notion.site/Validation-0220c48b9460429fa86f577914ea5248).

Status: under development. Not yet ready for external use.


In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from scipy.stats import spearmanr
from sklearn.metrics import average_precision_score, roc_auc_score

try:
    import salientsdk as sk
except ModuleNotFoundError as e:
    if os.path.exists("../salientsdk"):
        # Fall back to a local checkout of the SDK one directory up.
        sys.path.append(os.path.abspath(".."))
        import salientsdk as sk
    else:
        # Chain the original error so the real import failure stays visible.
        raise ModuleNotFoundError("Install salient SDK with: pip install salientsdk") from e

# Need Salient SDK v0.3.22 or later to use GEM.
# Explicit check rather than `assert`, which is stripped under `python -O` and has no message.
if "gem" not in sk.forecast_timeseries_api.MODELS:
    raise ImportError("Salient SDK v0.3.22 or later is required to use the GEM model")

# Prevent wrapping on tables for readability
pd.set_option("display.width", None)
pd.set_option("display.max_columns", None)
pd.set_option("display.max_rows", 10)
pd.set_option("display.expand_frame_repr", False)

sk.set_file_destination("analyze_events")
# Prefer credentials from the environment so real ones never need to be written into the
# notebook; fall back to the original placeholders.
sk.login(
    os.environ.get("SALIENT_USERNAME", "username"),
    os.environ.get("SALIENT_PASSWORD", "password"),
)

<requests.sessions.Session at 0x7f226f44e590>

## Customize The Validation

This notebook is written flexibly so you have the option of validating Salient and other forecasts multiple ways. These variables will control what, when, and how the validation proceeds.


In [None]:
# 1. The meteorological variable that we'll be evaluating:
var = "tmax"
# var = "tmin"

# 2. Number of standard deviations to be considered "extreme".
#    Negative for tmin, so that cold (below-normal) anomalies are the extreme tail.
ext_std = -2 if var == "tmin" else 2

# 3. Number of historical forecast dates to download.
if "beta" in sk.constants.URL:  # NOSHIP
    date_freq = 30  # Get about 1 forecast per month.  Fast and indicative.
else:
    date_freq = 1  # Get every available historical forecast.  Comprehensive.

# 4. Strategy for optimizing extreme thresholds.
# groupby = ["lead", "location"]
groupby = ["lead"]

# 5. Cost-loss framework payoff matrix (and the F-score beta used alongside it)
beta = 1
# beta = 2 # F2 weights recall higher than precision

# Payoff keys are <observed><GEM>[<GEFS>], where p = extreme and n = not extreme.
# fmt: off
payoff = {
    # Payoff coefficients for a 2x2 confusion matrix:
    "np":  -10,  "pp":  100, # Acting on a FP costs a little, Correctly calling a TP is worth a lot
    "nn":    0,  "pn":    0, # take no action, gain/lose nothing
    # Joint 4x2 confusion matrix strategy
    "npp": -10, "ppp":  100, # both agree. Same payout as 2x2 case.
    "npn": -20, "ppn":  200, # GEM extreme, GEFS normal.  Double position size.
    "nnp":  40, "pnp": -400, # GEFS extreme, GEM normal.  Short the market.  Lose big if an extreme happens.
    "nnn":   0, "pnn":   0,  # neither extreme, take no action.
    # For setting axis labels
    "units": "$M",
}
# fmt: on


# The quantity to optimize
objective = "payoff"
# objective = "f_score"

# Number of days to analyze
lead_days = 14

#### Constants and Settings

Changing these settings is not recommended.

In [None]:
# Print diagnostic information
verbose = False

# Caching strategy: False reuses data cached by earlier calls; True repeats every API call.
force = False
# force = True

# Primary forecast model under test, and the reference model it is compared against
gem_model = "gem"
ref_model = "noaa_gefs"

# Forecast dates to test over: every day in the window, thinned to every `date_freq`-th day
start_date, end_date = "2020-10-01", "2024-12-31"
date_range = pd.date_range(start=start_date, end=end_date, freq="D")[::date_freq]

# Restrict to the relevant season: Nov-Feb for tmin, May-Aug for tmax
if var == "tmin":
    winter_months = [11, 12, 1, 2]
    date_range = date_range[date_range.month.isin(winter_months)]
elif var == "tmax":
    summer_months = [5, 6, 7, 8]
    date_range = date_range[date_range.month.isin(summer_months)]

date_range = date_range.strftime("%Y-%m-%d").tolist()

# Temporal resolution.  "hourly" not yet supported
freq = "daily"

# Work in absolute (not anomaly) values; ensemble forecasts use the "<field>_ens" variable
field = "vals"
field_ens = field + "_ens"
field_ens = f"{field}_ens"

###  Set the Area of Interest

The Salient SDK uses a "Location" object to specify the geographic bounds of a request. In this case, we will be testing 11 cities representing ERCOT's 8 [weather zones](https://www.ercot.com/gridmktinfo/dashboards/weatherforecast). In the following code, only `lats`, `lons`, and `names` are required; the other inputs are for reference purposes only.




In [None]:
# fmt: off
# Define the 11 ERCOT city points, pass them to sk.upload_location_file, and wrap the
# result in an sk.Location used by every later request.
# Only lats / lons / names are required; region / weight / city are reference metadata.
# NOTE(review): `weight` looks like a population-style weighting (see the commented-out
# merge_location_data call later) -- confirm the source of these numbers.
loc = sk.Location(location_file=sk.upload_location_file(
    lats    =[  31.9686,   33.5779,  33.9137,  32.4487,    32.8972,  32.3513,    30.2672,    29.4241,  29.7604,  27.8006,  26.2034],
    lons    =[-102.0779, -101.8552, -98.4934, -99.7331,   -97.0403, -95.3011,   -97.7431,   -98.4936, -95.3698, -97.3964, -98.2300],
    names   =[    "MAF",     "LBB",    "SPS",    "ABI",      "DFW",    "TYR",      "AUS",      "SAT",    "IAH",    "CRP",    "MFE"],
    region  =["FarWest",   "North",  "North",   "West", "NCentral",   "East", "SCentral", "SCentral",  "Coast",  "South",  "South"],
    weight  =[  117.631,   264.662,  104.683,  124.407,   7637.387,  108.505,   2227.083,   1598.964, 7122.240,  326.586,  871.377],
    city    =["Midland", "Lubbock", "Wichita Falls", "Abilene", "Dallas/Fort Worth", "Tyler", "Austin", "San Antonio", "Houston", "Corpus Christi", "McAllen"],
    geoname ="ercot_weather_zones",
    force   =force
))
# fmt: on
# Show the resulting location table (lat, lon, name, region, weight, city, geometry)
print(loc.load_location_file())
# loc.plot_locations()

        lat       lon name    region    weight               city                   geometry
0   31.9686 -102.0779  MAF   FarWest   117.631            Midland  POINT (-102.0779 31.9686)
1   33.5779 -101.8552  LBB     North   264.662            Lubbock  POINT (-101.8552 33.5779)
2   33.9137  -98.4934  SPS     North   104.683      Wichita Falls   POINT (-98.4934 33.9137)
3   32.4487  -99.7331  ABI      West   124.407            Abilene   POINT (-99.7331 32.4487)
4   32.8972  -97.0403  DFW  NCentral  7637.387  Dallas/Fort Worth   POINT (-97.0403 32.8972)
..      ...       ...  ...       ...       ...                ...                        ...
6   30.2672  -97.7431  AUS  SCentral  2227.083             Austin   POINT (-97.7431 30.2672)
7   29.4241  -98.4936  SAT  SCentral  1598.964        San Antonio   POINT (-98.4936 29.4241)
8   29.7604  -95.3698  IAH     Coast  7122.240            Houston   POINT (-95.3698 29.7604)
9   27.8006  -97.3964  CRP     South   326.586     Corpus Christi   PO

## Get Forecasts & Observed

Use the Salient API to get the Salient GEM forecast, then compare it to historical observed as well as a reference model.


In [None]:
# Historical observed data - what really happened.
# The window starts 5 days before the first forecast date and ends lead_days + 1 days
# after the last, so observations cover every lead of the final forecast.
obs_start = np.datetime64(start_date) - np.timedelta64(5, "D")
obs_end = np.datetime64(end_date) + np.timedelta64(lead_days + 1, "D")
obs_src = sk.data_timeseries(
    loc=loc,
    variable=var,
    field=field,
    start=obs_start,
    end=obs_end,
    frequency=freq,
    verbose=verbose,
    force=force,
)

# Climatology: historical mean & standard deviation on this day
clim = sk.data_timeseries_api.extrapolate_trend(
    loc=loc, variable=var, start=obs_start, end=obs_end, verbose=verbose, force=force
)
# NOTE(review): with stdv_mult=1 this presumably returns mean + 1 sigma; subtracting
# the mean leaves sigma itself.
stdv = sk.data_timeseries_api.extrapolate_trend(
    loc=loc, variable=var, start=obs_start, end=obs_end, verbose=verbose, force=force, stdv_mult=1
)
stdv = stdv - clim

obs_ts = xr.merge(
    [xr.load_dataset(obs_src), clim.rename({var: "clim"}), stdv.rename({var: "stdv"})]
)

print(obs_ts)

# Count observed extremes in the direction of ext_std.  A negative threshold (tmin) counts
# anomalies *below* it; comparing with `>` alone would count nearly every normal day.
ext_sign = -1 if ext_std < 0 else 1
print((ext_sign * ((obs_ts[field] - obs_ts.clim) / obs_ts.stdv) > abs(ext_std)).sum())

<xarray.Dataset> Size: 428kB
Dimensions:   (time: 1573, location: 11)
Coordinates:
  * time      (time) datetime64[ns] 13kB 2020-09-26 2020-09-27 ... 2025-01-15
    lat       (location) float64 88B 31.97 33.58 33.91 32.45 ... 29.76 27.8 26.2
    lon       (location) float64 88B -102.1 -101.9 -98.49 ... -97.4 -98.23
  * location  (location) <U3 132B 'MAF' 'LBB' 'SPS' 'ABI' ... 'IAH' 'CRP' 'MFE'
Data variables:
    vals      (time, location) float64 138kB 36.24 36.18 32.55 ... 12.5 13.49
    clim      (time, location) float64 138kB 29.87 27.83 30.73 ... 18.95 22.51
    stdv      (time, location) float64 138kB 3.775 3.927 4.094 ... 4.292 5.463
Attributes:
    long_name:   2 metre temperature
    units:       degC
    clim_start:  1990-01-01
    clim_end:    2019-12-31
<xarray.DataArray ()> Size: 8B
array(256)


In [None]:
# Get the Salient GEM hindcasts
gem_src = sk.forecast_timeseries(
    loc=loc,
    variable=var,
    date=date_range,
    model=gem_model,
    field=field_ens,
    timescale=freq,
    strict=False,
    force=force,
    verbose=verbose,
)

In [None]:
# Reference model: the forecast we are comparing Salient GEM to
ref_src = sk.forecast_timeseries(
    loc=loc,
    variable=var,
    date=date_range,
    field=field_ens,
    model=ref_model,
    timescale=freq,
    strict=False,
    force=force,
    verbose=verbose,
)

### Package data into identical forecast_date & lead coordinates



In [None]:
# Stack the raw downloads into datasets indexed by forecast_date and lead
ref = sk.forecast_timeseries_api.stack_forecast(ref_src)
gem = sk.forecast_timeseries_api.stack_forecast(gem_src)

# The models forecast out to different horizons; keep only the first `lead_days` leads
ref, gem = (ds.isel(lead=slice(0, lead_days)) for ds in (ref, gem))

# Keep forecast dates where both models are complete: no missing value at any lead,
# location, or ensemble member
valid_dims = ["lead", "location", "ensemble"]
valid_dates = gem[field_ens].notnull().all(dim=valid_dims) & ref[field_ens].notnull().all(
    dim=valid_dims
)
gem, ref = (ds.sel(forecast_date=valid_dates) for ds in (gem, ref))

# Sanity check: both models must now share identical coordinates
for coord in ("time", "lead", "location"):
    xr.testing.assert_equal(gem[coord], ref[coord])

print(f"{gem_model} {gem.data_vars}")
print(f"{ref_model} {ref.data_vars}")

In [None]:
# Reshape historical obs match forecast dimensions
obs = sk.data_timeseries_api.stack_history(obs_ts, gem.forecast_date, gem.lead)

# Add population as a weighting factor
# obs = sk.merge_location_data(obs, loc, False)

xr.testing.assert_equal(gem.time, obs.time)
xr.testing.assert_equal(gem.lead, obs.lead)
xr.testing.assert_equal(gem.location, obs.location)

print("ERA5: " + str(obs.data_vars))

In [None]:
# Set some attributes for later plotting purposes

# Preserve model description for later plotting purposes
obs[field].attrs["model_name"] = "ERA5"
gem[field_ens].attrs["model_name"] = gem_model.split("_")[-1].upper()
ref[field_ens].attrs["model_name"] = ref_model.split("_")[-1].upper()

# Assign each model a distinct color
obs[field].attrs["color"] = "black"
gem[field_ens].attrs["color"] = "dodgerblue"
ref[field_ens].attrs["color"] = "#FF8C1E"

# Plotting can get confused by timedelta, so attach a numeric (whole-day) copy of lead.
# Use a distinct name: the settings cell defines `lead_days` as the integer horizon that
# sizes the observation window and trims the forecasts.  Overwriting it with a list here
# would break those cells if they are re-run.
lead_day_values = gem.lead.values.astype("timedelta64[D]").astype(int)
obs = obs.assign_coords(lead_days=("lead", lead_day_values))
gem = gem.assign_coords(lead_days=("lead", lead_day_values))
ref = ref.assign_coords(lead_days=("lead", lead_day_values))
obs.lead_days.attrs.update({"long_name": "Lead time", "units": "days"})
gem.lead_days.attrs.update({"long_name": "Lead time", "units": "days"})
ref.lead_days.attrs.update({"long_name": "Lead time", "units": "days"})

## Define Extremes

Define `obs.extreme` as true where the observed anomaly crosses the extreme threshold.
Define `gem.extreme_pct` and `ref.extreme_pct` as a sigmoid function of each ensemble's normalized anomaly relative to that threshold.

This sets us up to build a classifier that determines how many extreme ensemble members are required before we call an extreme likely.




In [None]:
ENS = "anom_ens_stdv"
with xr.set_options(keep_attrs=True):
    # Express observations and every ensemble member as standardized (sigma) anomalies
    # against the observed climatology
    obs["anom_stdv"] = ((obs[field] - obs["clim"]) / obs["stdv"]).assign_attrs(units="\u03c3")
    for model_ds in (gem, ref):
        model_ds[ENS] = ((model_ds[field_ens] - obs["clim"]) / obs["stdv"]).assign_attrs(
            units="\u03c3"
        )

# Observed extreme flag from the standardized anomaly and ext_std
obs["extreme"] = sk.event.classify_event(obs["anom_stdv"], ext_std, width=0)

# Per-model extreme score from sk.event.calibrate_event, using the observed flags and the
# `groupby` strategy.  Uncalibrated alternative:
#   sk.event.classify_event(model_ds[ENS], ext_std, dim="ensemble")
for model_ds in (gem, ref):
    model_ds["extreme_pct"] = sk.event.calibrate_event(obs["extreme"], model_ds[ENS], groupby)

print(f"{obs['extreme'].attrs['model_name']} {obs[['extreme']].data_vars}")
print(f"{ref['extreme_pct'].attrs['model_name']} {ref[['extreme_pct']].data_vars}")
print(f"{gem['extreme_pct'].attrs['model_name']} {gem[['extreme_pct']].data_vars}")

## Optimize Extremes Detection Threshold



In [None]:
# For each model, find the extreme_pct cut-off that optimizes the configured payoff / beta
for model_ds in (gem, ref):
    model_ds["extreme_threshold"] = sk.event.optimize_threshold(
        observed=obs.extreme,
        forecast=model_ds.extreme_pct,
        payoff=payoff,
        beta=beta,
        groupby=groupby,
    )

# Binary forecast: extreme whenever the score reaches that model's threshold
with xr.set_options(keep_attrs=True):
    for model_ds in (gem, ref):
        model_ds["extreme"] = model_ds["extreme_pct"] >= model_ds["extreme_threshold"]

print_vars = ["extreme_threshold", "extreme"]
for model_ds in (gem, ref):
    print(f"{model_ds.extreme.attrs['model_name']} {model_ds[print_vars].data_vars}")

## View Confusion Matrices

A confusion matrix compares predictions against reality in a simple grid format. For any prediction system, it tallies four possible outcomes: correct predictions of both events and non-events (true positives and true negatives), false alarms (false positives), and missed events (false negatives). The name comes from its ability to reveal when the system gets "confused" - either crying wolf or missing actual events.


In [None]:
# groupby location & lead
display(sk.event.style_confusion_matrix(obs.extreme, gem.extreme, payoff=payoff, beta=beta))
display(sk.event.style_confusion_matrix(obs.extreme, ref.extreme, payoff=payoff, beta=beta))

In [None]:
# Expected payoff per model, averaged over lead.  The reference comes first, so it is the
# left bar and the baseline for the lift label.
payoff_scores = [
    sk.event.calc_f_score(obs.extreme, ref.extreme, groupby="lead", payoff=payoff)["payoff"],
    sk.event.calc_f_score(obs.extreme, gem.extreme, groupby="lead", payoff=payoff)["payoff"],
]
# Look models up by their own names instead of hard-coding "GEFS" / "GEM"
ref_name, gem_name = (score.attrs["model_name"] for score in payoff_scores)

pcolors = [score.attrs["color"] for score in payoff_scores]
payoffs = {score.attrs["model_name"]: score.mean().item() for score in payoff_scores}
plt.figure(figsize=(8, 4))
bars = plt.bar(payoffs.keys(), payoffs.values(), color=pcolors)
plt.axhline(y=payoffs[ref_name], color="lightgrey", linestyle="--", alpha=0.5)

# Lift relative to the reference.  Dividing by |baseline| keeps the sign meaningful when
# the reference payoff is negative.  No lift is defined against a zero baseline.
baseline = payoffs[ref_name]
lifts = {}
if baseline != 0:
    lifts[gem_name] = (payoffs[gem_name] - baseline) / abs(baseline) * 100

for i, key in enumerate([gem_name]):
    if key not in lifts:
        continue
    lift = lifts[key]
    bar = bars[i + 1]
    plt.text(
        bar.get_x() + bar.get_width() / 2.0,
        (bar.get_height() + baseline) / 2,
        f"{lift:+.0f}%",  # explicit sign: "+12%" or "-12%" (never "+-12%")
        ha="center",
        va="center",
        color="white",
        fontweight="bold",
    )

plt.ylabel(f"Expected Payoff [{payoff.get('units', '$M')}]");

### Code - analysis and visualization

This section contains code that will be used later in the notebook for visualization and analysis.

In [None]:
def compare_crps(
    observations: xr.DataArray, forecast: xr.DataArray, reference: xr.DataArray, groupby: str
):
    """Calculate CRPS for forecast and reference models, then plot them grouped by one dimension.

    CRPS is averaged over every dimension except ``groupby``.  Number, datetime, and
    timedelta coordinates are drawn as lines (timedeltas in whole days).  Anything else,
    e.g. location names, is drawn as side-by-side bars sorted by reference CRPS minus
    forecast CRPS (largest first).

    Args:
        observations: DataArray containing observed values
        forecast: DataArray containing forecast ensemble values
        reference: DataArray containing reference ensemble values
        groupby: Dimension to group by (string only)

    Returns:
        xr.Dataset with variables ``forecast`` and ``reference`` holding mean CRPS along
        ``groupby`` (lower is better).  Each carries a ``model_name`` attribute.
    """
    fcst_crps = sk.skill._crps_ensemble_core(observations=observations, forecasts=forecast)
    ref_crps = sk.skill._crps_ensemble_core(observations=observations, forecasts=reference)

    # Calculate mean CRPS across all dimensions except the groupby dimension
    dims_to_mean = [dim for dim in fcst_crps.dims if dim != groupby]
    fcst_crps_mean = fcst_crps.mean(dim=dims_to_mean)
    ref_crps_mean = ref_crps.mean(dim=dims_to_mean)

    # Model names and colors, with defaults when the inputs carry no attributes
    fcst_model_name = forecast.attrs.get("model_name", "Forecast")
    ref_model_name = reference.attrs.get("model_name", "Reference")
    fcst_color = forecast.attrs.get("color", "dodgerblue")
    ref_color = reference.attrs.get("color", "#FF8C1E")

    # Take group labels from the reduced result so they match the values being plotted
    coord_values = fcst_crps_mean[groupby].values

    # Numerical (numbers, dates, time spans) -> line plot; otherwise categorical bars.
    # The size check avoids an IndexError on an empty coordinate.
    is_numerical = (
        np.issubdtype(coord_values.dtype, np.number)
        or np.issubdtype(coord_values.dtype, np.datetime64)
        or np.issubdtype(coord_values.dtype, np.timedelta64)
        or (
            coord_values.size > 0
            and isinstance(coord_values[0], (pd.Timestamp, np.timedelta64))
        )
    )

    fig, ax = plt.subplots(figsize=(9, 4))

    if is_numerical:
        if np.issubdtype(coord_values.dtype, np.timedelta64):
            # Timedeltas read better as whole days
            x_values = coord_values.astype("timedelta64[D]").astype(int)
            x_label = f"{groupby} (days)"
            marker = "."
        else:
            x_values = coord_values
            x_label = groupby
            marker = "None"

        for crps_mean, color, name in (
            (fcst_crps_mean, fcst_color, fcst_model_name),
            (ref_crps_mean, ref_color, ref_model_name),
        ):
            ax.plot(
                x_values,
                crps_mean.values,
                linestyle="-",
                marker=marker,
                color=color,
                linewidth=2,
                label=name,
            )

        # Format x-axis for dates
        if np.issubdtype(coord_values.dtype, np.datetime64):
            fig.autofmt_xdate()

    else:
        x_label = groupby
        df = pd.DataFrame(
            {
                fcst_model_name: fcst_crps_mean.values,
                ref_model_name: ref_crps_mean.values,
                "difference": ref_crps_mean.values - fcst_crps_mean.values,
                groupby: coord_values,
            }
        )

        # Sort by how much the forecast beats the reference (for CRPS, lower is better)
        df = df.sort_values("difference", ascending=False)

        x = np.arange(len(df))
        width = 0.35

        ax.bar(x - width / 2, df[fcst_model_name], width, label=fcst_model_name, color=fcst_color)
        ax.bar(x + width / 2, df[ref_model_name], width, label=ref_model_name, color=ref_color)

        ax.set_xticks(x)
        ax.set_xticklabels(df[groupby], rotation=90 if len(df) > 10 else 0)

    ax.set_ylabel("CRPS")
    ax.set_xlabel(x_label)
    ax.grid(True, alpha=0.3)
    ax.legend(loc="best")

    plt.tight_layout()

    results = xr.Dataset({"forecast": fcst_crps_mean, "reference": ref_crps_mean})

    # Add model names as attributes
    results.forecast.attrs["model_name"] = fcst_model_name
    results.reference.attrs["model_name"] = ref_model_name

    return results


if False:
    crps_date = compare_crps(obs[field], gem[field_ens], ref[field_ens], "lead")
    # crps_date = compare_crps(obs[field], gem[field_ens], ref[field_ens], "forecast_date")

In [None]:
def plot_extremes_timeseries(
    observed: xr.Dataset, forecast: xr.Dataset | None = None, reference: xr.Dataset | None = None
):
    """Plot observed anomalies against forecast ensembles for one location and forecast date.

    Each model gets a shaded p10-p90 band across ensemble members.  Members that cross the
    extreme threshold are drawn as filled markers on leads that were true positives, and
    hollow markers on leads that were false alarms.

    Args:
        observed: Observed data for a single location / forecast_date, with ``anom_stdv``,
            ``extreme`` and the ``lead_days`` coordinate.
        forecast: Optional forecast data (``anom_ens_stdv`` over an ``ensemble`` dim, and
            ``extreme``).
        reference: Optional reference data with the same variables as ``forecast``.

    Returns:
        fig, ax: The figure and axis objects

    Raises:
        ValueError: if neither ``forecast`` nor ``reference`` is provided.
    """
    if forecast is None and reference is None:
        raise ValueError("At least one of forecast or reference must be provided")

    # Everything comes from the `observed` argument, not the notebook-global `obs`
    lead_days = observed.lead_days.values
    target = observed.extreme.attrs.get("target", 1)
    sign = 1 if target > 0 else -1

    fig, ax = plt.subplots(figsize=(9, 4))

    ax.plot(lead_days, observed.anom_stdv.values, color="black", linewidth=3, label="Observed")
    ax.axhline(y=target, color="grey", linestyle="--", label=f"Extreme Threshold ({target}\u03c3)")
    ax.set_xlabel("Lead Time [days]")
    ax.set_ylabel("Anomaly [\u03c3]")

    def scatter_members(data, lead_mask, **style):
        """Scatter ensemble members beyond the threshold on leads where `lead_mask` is true."""
        members = data.anom_ens_stdv
        selected = (sign * members >= sign * target) & lead_mask
        y = members.where(selected)
        lead_axis = xr.DataArray(lead_days, dims=["lead"], coords={"lead": data.lead})
        # Broadcast lead days onto the member grid and force y's dimension order, so the
        # flattened x and y arrays pair up element by element.
        x = lead_axis.broadcast_like(y).where(selected).transpose(*y.dims)
        x_flat = x.values.ravel()
        y_flat = y.values.ravel()
        valid = ~np.isnan(x_flat) & ~np.isnan(y_flat)
        if not valid.any():
            return
        # Add small random jitter in x-direction (±0.2) so overlapping members stay visible
        jittered_x = x_flat[valid] + np.random.uniform(-0.2, 0.2, size=int(valid.sum()))
        ax.scatter(jittered_x, y_flat[valid], s=36, marker="o", alpha=0.7, **style)

    def plot_forecast_data(data, color, label):
        """Plot fcst/ref data: p10-p90 band, true positives and false alarms."""
        if data is None:
            return None
        label = data.anom_ens_stdv.attrs.get("model_name", label)
        color = data.anom_ens_stdv.attrs.get("color", color)

        # 10-90th percentile band across ensemble members
        lower_bound = data.anom_ens_stdv.quantile(0.1, dim="ensemble").values
        upper_bound = data.anom_ens_stdv.quantile(0.9, dim="ensemble").values
        ax.fill_between(
            lead_days, lower_bound, upper_bound, color=color, alpha=0.3, label=f"{label} p10-p90"
        )

        # Filled markers: model and observation both extreme
        scatter_members(
            data,
            (data.extreme == True) & (observed.extreme == True),  # noqa: E712
            color=color,
            label=f"{label} True Positive",
        )
        # Hollow markers: model extreme, observation not (false alarm)
        scatter_members(
            data,
            (data.extreme == True) & (observed.extreme == False),  # noqa: E712
            facecolors="white",
            edgecolors=color,
            linewidths=1.5,
            label=f"{label} False Alarm",
        )

    plot_forecast_data(forecast, "dodgerblue", "Forecast")
    plot_forecast_data(reference, "#FF8C1E", "Reference")

    location = observed.location.values
    forecast_date = pd.to_datetime(observed.forecast_date.values).strftime("%Y-%m-%d")
    ax.set_title(f"{location}: {forecast_date}")

    ax.legend(prop={"size": 8})

    return fig, ax


if False:
    dex = "2023-06-18"
    lex = "MAF"
    plot_extremes_timeseries(
        obs.sel(location=lex, forecast_date=dex),
        gem.sel(location=lex, forecast_date=dex),
        ref.sel(location=lex, forecast_date=dex),
    )


if False:
    # Requires find_confusion_examples, which is defined in the next cell: run that first.
    cex = find_confusion_examples(
        obs.extreme, gem.extreme, ref.extreme, ["location", "forecast_date"]
    )["ppp"]
    dex = cex["forecast_date"].values[0]
    lex = cex["location"].values[0]
    plot_extremes_timeseries(
        obs.sel(location=lex, forecast_date=dex),
        gem.sel(location=lex, forecast_date=dex),
        ref.sel(location=lex, forecast_date=dex),
    )

In [None]:
def find_confusion_examples(
    observations: xr.DataArray,
    forecast: xr.DataArray,
    reference: xr.DataArray | None = None,
    groupby: str | list[str] | None = None,
) -> dict[str, pd.DataFrame]:
    """Rank the most representative cases for every confusion matrix cell.

    Each cell gets a heuristic score: its own count, minus penalties for counts in competing
    cells, plus small bonuses for closely related ones.  Only positive scores are kept.

    Args:
        observations: Binary array of observed events.
        forecast: Binary array of forecast events.
        reference: Optional binary array of reference forecast events.  When given, the
            joint (4x2) cells such as ``"ppn"`` are scored; otherwise the simple (2x2) cells
            ``"pp"``, ``"nn"``, ``"np"``, ``"pn"``.
        groupby: Dimension(s) passed to ``build_confusion_matrix``.

    Returns:
        Mapping of cell name to a DataFrame of at most 10 rows, sorted by descending score.
        The columns are the score's dimensions plus ``score``.  Cells without any positive
        score are left out.  ``attrs["long_name"]`` holds the cell's long name when the
        matrix provides one.
    """
    matrix = sk.event.build_confusion_matrix(observations, forecast, reference, groupby=groupby)

    if reference is not None:
        false_pos = matrix.npp + matrix.npn + matrix.nnp  # every "cry wolf" cell
        false_neg = matrix.pnn + matrix.pnp + matrix.ppn  # every missed-event cell

        # Joint 4x2 scoring
        scores = {
            # both right
            "ppp": matrix.ppp + 0.1 * (matrix.pnp + matrix.ppn) - 0.5 * (matrix.pnn + false_pos),
            "nnn": matrix.nnn + 0.1 * (matrix.npn + matrix.nnp) - 0.5 * (matrix.npp + false_neg),
            # both wrong
            "npp": matrix.npp + 0.1 * (matrix.npn + matrix.nnp) - 0.5 * false_neg,
            "pnn": matrix.pnn + 0.1 * (matrix.pnp + matrix.ppn) - 0.5 * false_pos,
            # mixed, false negative
            "ppn": matrix.ppn
            + 0.2 * matrix.nnp
            - 0.5 * (matrix.npn + matrix.ppp + matrix.npp + matrix.pnn)
            - 2 * matrix.pnp,
            "pnp": matrix.pnp
            + 0.2 * matrix.npn
            - 0.5 * (matrix.nnp + matrix.ppp + matrix.npp + matrix.pnn)
            - 2 * matrix.ppn,
            # mixed, false positive
            "npn": matrix.npn
            + 0.2 * matrix.pnp
            - 0.5 * (matrix.ppn + matrix.ppp + matrix.pnn + matrix.npp)
            - 2 * matrix.nnp,
            "nnp": matrix.nnp
            + 0.2 * matrix.ppn
            - 0.5 * (matrix.pnp + matrix.ppp + matrix.pnn + matrix.npp)
            - 2 * matrix.npn,
        }
    else:
        # Simple 2x2 scoring
        scores = {
            "pp": matrix.pp - 0.5 * matrix.np - 0.5 * matrix.pn,
            "nn": matrix.nn - 0.5 * matrix.pp - 0.5 * matrix.pn,
            "np": matrix.np - 0.5 * matrix.pn - 0.2 * matrix.pp,
            "pn": matrix.pn - 0.5 * matrix.np - 0.2 * matrix.pp,
        }

    ranked: dict[str, pd.DataFrame] = {}
    for name, score in scores.items():
        columns = [*score.dims, "score"]
        table = (
            score.to_dataframe(name="score")
            .reset_index()
            .loc[:, columns]
            .reset_index(drop=True)
        )
        top = table.loc[table["score"] > 0].sort_values(by="score", ascending=False).head(10)
        if top.empty:
            continue
        top.attrs["long_name"] = matrix[name].attrs.get("long_name", name)
        ranked[name] = top

    return ranked


if False:
    cex_2x2 = find_confusion_examples(
        obs.extreme, gem.extreme, groupby=["location", "forecast_date"]
    )
    print(cex_2x2)

    cex_4x2 = find_confusion_examples(
        obs.extreme, gem.extreme, ref.extreme, groupby=["location", "forecast_date"]
    )
    print(cex_4x2)

In [None]:
def viz_confusion_examples(
    observations: xr.Dataset,
    forecast: xr.Dataset,
    reference: xr.Dataset | None = None,
    groupby: str | list[str] | None = None,
    type: str | None = None,
    count: int | None = 1,
):
    """Plot the most compelling examples for each confusion matrix category.

    Args:
        observations: Observed dataset with an ``extreme`` flag plus the variables
            ``plot_extremes_timeseries`` reads (``anom_stdv``, ``lead_days``).
        forecast: Forecast dataset with ``extreme`` and ``anom_ens_stdv``.
        reference: Optional reference dataset with the same variables as ``forecast``.
            If provided, examples come from the joint 4x2 confusion matrix.
        groupby: Dimension(s) that identify one example, e.g. ``["forecast_date", "location"]``.
        type: Optional confusion matrix category to show (e.g. 'np', 'ppp').  If None, show
            every category that has examples.  (Shadows the builtin; name kept for callers.)
        count: Number of examples to show per category.  If None, show every example
            returned by ``find_confusion_examples``.
    """
    examples = find_confusion_examples(
        observations.extreme,
        forecast.extreme,
        reference=None if reference is None else reference.extreme,
        groupby=groupby,
    )

    # Filter to specific type if requested
    if type is not None:
        if type not in examples:
            print(f"No examples found for type {type}")
            return
        examples = {type: examples[type]}

    for category, df in examples.items():
        rows = df.head(count) if count is not None else df
        long_name = df.attrs.get("long_name", category)
        for _, row in rows.iterrows():
            # Every non-score column is a coordinate that locates this example
            coords = {dim: row[dim] for dim in row.index if dim != "score"}
            _, ax = plot_extremes_timeseries(
                observations.sel(**coords),
                forecast.sel(**coords),
                reference.sel(**coords) if reference is not None else None,
            )
            ax.set_title(f"{long_name}\n{ax.get_title()}")


if False:
    # Uses the notebook's obs / gem / ref datasets (the former obsd/gemd/refd were undefined)
    viz_confusion_examples(
        obs, gem, reference=ref, groupby=["forecast_date", "location"], type="nnp", count=5
    )

In [None]:
def viz_search_threshold(
    observed: xr.DataArray,
    forecast: xr.DataArray,
    reference: xr.DataArray,
    objective: str = "payoff",
    beta: float = 1.0,
    payoff: dict = sk.event.PAYOFF,
) -> tuple[plt.Figure, list]:
    """Plot decision metrics against candidate thresholds for two models.

    Runs ``sk.event.search_threshold`` on the forecast and on the reference,
    then draws one panel per threshold-dependent metric (objective first).
    Each model's optimum is marked with a star and a dashed vertical line.

    Args:
        observed: Boolean DataArray indicating observed extreme events.
        forecast: Continuous DataArray with forecast values.
        reference: Continuous DataArray with reference values.
        objective: Metric to optimize ("f_score", "payoff", "precision", "recall").
        beta: Weight of recall in F-score calculation (default: 1.0).
        payoff: Payoff matrix for cost-loss calculation.

    Returns:
        The matplotlib Figure and the axes (one per plotted metric).
    """
    search_kwargs = dict(beta=beta, payoff=payoff, objective=objective)
    fcst = sk.event.search_threshold(observed, forecast, **search_kwargs)
    ref = sk.event.search_threshold(observed, reference, **search_kwargs)

    # Panels: every variable indexed by "thresholds", except the optimum outputs
    # ("threshold", "index"). The objective is moved to the top panel.
    metric_names = []
    for name in fcst.data_vars:
        if "thresholds" not in fcst[name].dims or name in ("threshold", "index"):
            continue
        metric_names.append(name)
    if objective in metric_names:
        metric_names.remove(objective)
        metric_names.insert(0, objective)

    n_panels = len(metric_names)
    fig, axes = plt.subplots(n_panels, 1, figsize=(8, 2 * n_panels), sharex=True)
    if n_panels == 1:
        axes = [axes]  # a single subplot comes back as a bare Axes
    panel_for = dict(zip(metric_names, axes))

    def _draw(results):
        """Draw one model's metric curves and optimum markers on every panel."""
        label = results.attrs.get("model_name", "Forecast")
        color = results.attrs.get("color", "grey")
        best_idx = results.index.item()
        best_threshold = results.threshold.item()
        grid = results.thresholds.values

        for name, panel in panel_for.items():
            curve = results[name].values
            suffix = f" (Best = {best_threshold:.4f})" if name == objective else ""
            panel.plot(grid, curve, "-", color=color, linewidth=2, label=f"{label}{suffix}")
            # The optimum is chosen on the objective, so every panel marks the same threshold.
            panel.plot(best_threshold, curve[best_idx], "*", color=color, markersize=10)
            panel.axvline(x=best_threshold, color=color, linestyle="--", alpha=0.4)

    _draw(fcst)
    _draw(ref)

    for name, panel in panel_for.items():
        panel.set_ylabel(fcst[name].attrs.get("long_name", name))
        panel.legend()
        panel.set_ylim(bottom=0)  # draw every metric from a zero baseline

    axes[-1].set_xlabel("Threshold")

    if "lead" in fcst.coords:
        axes[0].set_title(f"lead {fcst.lead.values}")

    plt.tight_layout()
    return fig, axes


# Disabled example: threshold search at a fixed 13-day lead. Flip to True to run.
if False:
    dat = np.timedelta64(13, "D")
    obss = obs.sel(lead=dat)
    refs = ref.sel(lead=dat)
    gems = gem.sel(lead=dat)

    search_kwargs = dict(objective=objective, payoff=payoff)
    fig, ax = viz_search_threshold(
        obss.extreme, gems.extreme_pct, refs.extreme_pct, **search_kwargs
    )
    ax[0].set_title(f"lead {dat}")

    # Print the reference payoff at its own optimal threshold.
    ref_search = sk.event.search_threshold(obss.extreme, refs.extreme_pct, **search_kwargs)
    print(ref_search.payoff.isel(thresholds=ref_search.index.values))

In [None]:
def compare_f_score(
    observed: xr.DataArray,
    forecast: xr.DataArray,
    reference: xr.DataArray,
    groupby: str,
    payoff: dict = sk.event.PAYOFF,
    beta: float = 1.0,
    figsize: tuple = (10, 10),
) -> xr.Dataset:
    """Compare F-score and related metrics between two forecasts along a dimension.

    Plots one panel per metric returned by ``sk.event.calc_f_score``, with one
    line per model. When ``groupby`` has string labels, its values are sorted
    by the forecast-minus-reference difference in the objective ("payoff", or
    "f_score" when ``payoff`` is None), largest advantage first.

    Args:
        observed: Boolean DataArray indicating observed extreme events
        forecast: Boolean DataArray indicating forecasted extreme events
        reference: Boolean DataArray indicating reference forecasted extreme events
        groupby: Dimension to group by for comparison
        payoff: Dictionary specifying the value for each confusion matrix element
        beta: Weight of recall in F-score calculation (default: 1.0 for F1 score)
        figsize: Figure size (width, height). Only the width is used; the height
            is max(5, 2 * number of metrics).

    Returns:
        Dataset of metrics with a ``model`` dimension holding the forecast, the
        reference, and "DIFF" (forecast minus reference).
    """
    # Calculate metrics for both forecast and reference
    fcst_metrics = sk.event.calc_f_score(
        observed, forecast, groupby=groupby, payoff=payoff, beta=beta
    )
    ref_metrics = sk.event.calc_f_score(
        observed, reference, groupby=groupby, payoff=payoff, beta=beta
    )
    dif_metrics = fcst_metrics - ref_metrics

    fcst_name = forecast.attrs.get("model_name", "Forecast")
    ref_name = reference.attrs.get("model_name", "Reference")
    models = [fcst_name, ref_name]
    # Colors in the same order as ``models`` (and therefore the hue lines), matched
    # by position rather than by name.
    model_colors = [
        forecast.attrs.get("color", "dodgerblue"),
        reference.attrs.get("color", "#ff7f0e"),
    ]

    # Staple the two metrics datasets together along a new "model" dimension.
    metrics = []
    for model, data in zip(models, [fcst_metrics, ref_metrics]):
        model_ds = data.expand_dims(dim={"model": [model]})
        # expand_dims drops attrs: preserve them.
        for name in model_ds.data_vars:
            model_ds[name].attrs = data[name].attrs
        metrics.append(model_ds)
    combined_metrics = xr.concat(metrics, dim="model")

    if isinstance(combined_metrics[groupby].values[0], str):  # only sort if categorical
        objective = "f_score" if payoff is None else "payoff"
        sort_order = dif_metrics[objective].values.argsort()[::-1]
        combined_metrics = combined_metrics.isel({groupby: sort_order})
        dif_metrics = dif_metrics.isel({groupby: sort_order})

    plot_vars = list(combined_metrics.data_vars)
    num_plots = len(plot_vars)
    fig_height = max(5, 2 * num_plots)
    # squeeze=False keeps a 2-D axes array even when there is a single metric.
    fig, axes = plt.subplots(num_plots, 1, figsize=(figsize[0], fig_height), squeeze=False)

    for ax, name in zip(axes[:, 0], plot_vars):
        combined_metrics[name].plot(
            ax=ax,
            x=groupby,
            hue="model",
            marker="o",
            markersize=4,
        )
        for line, color in zip(ax.lines, model_colors):
            line.set_color(color)

    fig.tight_layout()

    dif_metrics = dif_metrics.assign_coords(model=["DIFF"])
    combined_metrics = xr.concat([combined_metrics, dif_metrics], dim="model")

    return combined_metrics


# Disabled example: per-location comparison, keeping the returned metrics.
if False:
    cmp = compare_f_score(
        obs.extreme,
        gem.extreme,
        ref.extreme,
        groupby="location",
        payoff=payoff,
    )
# print(cmp.payoff.sel(model="DIFF"))

## Decompose Scores by Location and Lead

**Precision**: "When we predict an event will happen, how often are we right?" High precision means few false alarms. `pp/(pp + np)`: correct event predictions divided by total event predictions.

**Recall**: "Of all the events that actually happened, how many did we catch?"
High recall means we detect most events. `pp/(pp + pn)`: caught events divided by total actual events.

**False Positive Rate**: "Of all the non-events, what fraction did the forecast falsely flag as events?" In other words, when nothing extreme is happening, how often does the forecast cry wolf? `np/(np + nn)`: false alarms divided by total non-events.


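- `pp` = forecast predicted it and it happened
- `np` = forecast predicted it but it didn't happen (false alarm)
- `pn` = it happened but the forecast missed it
- `nn` = forecast did not predict it and it did not happen

The first letter describes what was observed and the second what the forecast predicted. Three-letter codes such as `nnp` or `ppn` add a third letter for the reference forecast.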

In [None]:
# Break the forecast-vs-reference comparison down by lead time, then by location.
for breakdown in ("lead_days", "location"):
    compare_f_score(
        obs.extreme, gem.extreme, ref.extreme, groupby=breakdown, payoff=payoff, beta=beta
    )

## When Forecasts Differ: Joint Strategy

A paired confusion matrix extends the standard format to compare two prediction systems against reality simultaneously. Instead of just "event" or "no event" for a single system, it shows all possible combinations: both systems correct, both wrong, and cases where one system gets it right while the other fails. This format is particularly powerful for understanding not just how often each system succeeds or fails, but whether they tend to make the same mistakes or complement each other's strengths.

In [None]:
# Paired confusion matrix: observations vs. GEM vs. the reference forecast.
paired_table = sk.event.style_confusion_matrix(
    obs.extreme, gem.extreme, ref.extreme, beta=beta, payoff=payoff
)
display(paired_table)

# Variant without a payoff matrix:
# display(sk.event.style_confusion_matrix(obs.extreme, gem.extreme, ref.extreme, beta=beta, payoff=None))

In [None]:
# Payoff per lead, one line per entry of the "forecast" dimension.
lead_scores = sk.event.calc_f_score(
    obs.extreme, gem.extreme, ref.extreme, groupby="lead", beta=beta, payoff=payoff
)
payoffs = lead_scores["payoff"]
lines = payoffs.plot(x="lead_days", hue="forecast")
print(lines)

In [None]:
# Overall payoff of the reference, GEM, and the paired strategy, side by side.
payoffs = sk.event.calc_f_score(
    obs.extreme, gem.extreme, ref.extreme, groupby=None, beta=beta, payoff=payoff
)["payoff"]
payoffs = payoffs.sel(forecast=["reference", "forecast", "paired"])
ref_payoff = payoffs.sel(forecast="reference").item()

bars = plt.bar(
    [ref.extreme.attrs["model_name"], gem.extreme.attrs["model_name"], "Paired"], payoffs.values
)

bars[0].set_color(ref.extreme.attrs["color"])
bars[1].set_color(gem.extreme.attrs["color"])
bars[2].set_color("#9467bd")
plt.axhline(y=ref_payoff, color="lightgrey", linestyle="--", alpha=0.5)

# Annotate GEM and Paired with their change relative to the reference bar.
ref_height = bars[0].get_height()
for bar in bars[1:]:
    height = bar.get_height()
    if ref_height == 0:
        continue  # relative change is undefined against a zero reference
    # Divide by |reference| so the sign reflects gain vs. loss even if the
    # reference payoff is negative; the "+" format flag prints "-12%" for a loss
    # instead of the previous "+-12%".
    change_pct = (height - ref_height) / abs(ref_height) * 100
    # The label sits halfway between the bar top and the reference line; it is
    # only on top of the bar (where white is readable) when the bar extends past it.
    inside_bar = height * ref_height > 0 and abs(height) >= abs(ref_height)
    plt.text(
        bar.get_x() + bar.get_width() / 2,
        (height + ref_height) / 2,
        f"{change_pct:+.0f}%",
        ha="center",
        va="center",
        fontweight="bold",
        color="white" if inside_bar else "black",
    )
plt.ylabel(f"{payoffs.attrs['long_name']} [{payoffs.attrs['units']}]");

## Visualize Threshold Optimization

Choosing a threshold is a balance between precision and recall. The plots below show how each metric changes as the threshold varies.

In [None]:
# Let's find a compelling example for comparing strategy.
gem_f = sk.event.calc_f_score(obs.extreme, gem.extreme, groupby=groupby, payoff=payoff, beta=beta)
ref_f = sk.event.calc_f_score(obs.extreme, ref.extreme, groupby=groupby, payoff=payoff, beta=beta)
obs_counts = obs.extreme.sum(dim=["forecast_date", "location"])

# Create DataFrame with f-scores and add observation counts
df = pd.DataFrame(
    {
        "gem": gem_f.f_score.to_dataframe()["f_score"],
        "ref": ref_f.f_score.to_dataframe()["f_score"],
        "obs_count": obs_counts.to_dataframe()["extreme"],
    }
).reset_index()
df["dif"] = df["gem"] - df["ref"]

result = (
    df[(df["gem"] > 0) & (df["ref"] > 0) & (df["obs_count"] > 2)]
    .sort_values("dif", ascending=False)
    .head()
)
idx = 1
row = result.iloc[idx]
dat = row["lead"]
# loc = row["location"]
# dat = np.timedelta64(13,"D")


obss = obs.sel(lead=dat)
refs = ref.sel(lead=dat)
gems = gem.sel(lead=dat)

fig, ax = viz_search_threshold(
    obss.extreme,
    gems.extreme_pct,
    refs.extreme_pct,
    objective=objective,
    payoff=payoff,
)
ax[0].set_title(f"lead {dat}")

## Appendix - Statistical Skill / CRPS

In [None]:
# CRPS comparison by lead and by location; the per-date breakdown is kept.
for crps_dim in ("lead", "location"):
    compare_crps(obs[field], gem[field_ens], ref[field_ens], crps_dim)
crps_date = compare_crps(obs[field], gem[field_ens], ref[field_ens], "forecast_date")

## Appendix - Classifier Diagnostics


In [None]:
# Diagnostic - forecast calibration: plot sk.event.classify_event output for the
# observations, GEM and the reference against lead. Disabled; flip to True to run.
if False:
    pooled_dims = ["forecast_date", "location"]
    pct = xr.Dataset(
        {
            "obs": sk.event.classify_event(
                obs["anom_stdv"], ext_std, width=0, dim=list(pooled_dims)
            ),
            "gem": sk.event.classify_event(
                gem["anom_ens_stdv"], ext_std, width=0, dim=pooled_dims + ["ensemble"]
            ),
            "ref": sk.event.classify_event(
                ref["anom_ens_stdv"], ext_std, width=0, dim=pooled_dims + ["ensemble"]
            ),
        }
    )

    fig, ax = plt.subplots(figsize=(8, 4))
    for key in ("obs", "gem", "ref"):
        curve = pct[key]
        line_color = "black" if key == "obs" else curve.attrs["color"]
        curve.plot.line(x="lead_days", label=curve.attrs["model_name"], color=line_color)
    ax.legend()
    plt.show()

In [None]:
def plot_extreme_scatter(dataset, obs, ax, ext_std):
    """Plot extreme probability vs mean standardized anomaly for a dataset.

    Observations are drawn in black (x = ``anom_stdv``) and the model in the
    color stored on ``anom_ens_stdv`` (x = ensemble-mean anomaly), with a dashed
    vertical line at ``ext_std``.

    Args:
        dataset: Model dataset with ``anom_ens_stdv`` (having an ``ensemble``
            dimension and a ``color`` attr) and ``extreme_pct`` variables.
        obs: Observation dataset with an ``anom_stdv`` variable.
        ax: Matplotlib axis to draw on.
        ext_std: Threshold passed to ``sk.event.classify_event`` and drawn as a line.

    Neither ``dataset`` nor ``obs`` is modified: the derived variables are added
    to shallow copies made with ``Dataset.assign``.
    """
    model_plot = dataset.assign(
        anom_mean=dataset.anom_ens_stdv.mean(dim="ensemble", keep_attrs=True)
    )
    obs_plot = obs.assign(extreme_pct=sk.event.classify_event(obs["anom_stdv"], ext_std))
    obs_plot.plot.scatter(
        x="anom_stdv", y="extreme_pct", alpha=0.05, s=5, color="black", rasterized=True, ax=ax
    )
    model_plot.plot.scatter(
        x="anom_mean",
        y="extreme_pct",
        alpha=0.05,
        s=5,
        color=dataset["anom_ens_stdv"].attrs["color"],
        rasterized=True,
        ax=ax,
    )
    ax.axvline(x=ext_std, color="gray", linestyle="--", label=f"Target threshold: {ext_std}")
    ax.set_title(dataset.extreme_pct.attrs.get("model_name", "Model"))


# Side-by-side scatter for GEM (left) and the reference (right).
if True:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 5), sharey=True)
    for model_ds, panel in ((gem, ax1), (ref, ax2)):
        plot_extreme_scatter(model_ds, obs, panel, ext_std)
    plt.show()

In [None]:
def plot_classifier_performance(forecast, observed, ax, color):
    """Plot classifier performance: extreme_pct (x) vs observed standardized anomaly (y).

    Also scores ``forecast.extreme_pct`` as a classifier of ``observed.extreme``
    and reports AUC-ROC, AUC-PR (average precision) and Spearman rank
    correlation in the axis title.

    Args:
        forecast: Dataset with an ``extreme_pct`` variable.
        observed: Dataset with ``anom_stdv`` (carrying a ``long_name`` attr) and
            ``extreme`` variables.
        ax: Matplotlib axis to draw on.
        color: Scatter point color.

    Notes:
        The horizontal line uses the notebook-level ``ext_std``. Neither input
        dataset is modified.
    """
    # Relabel the observed anomaly. assign_attrs works on a shallow copy, so
    # observed.anom_stdv keeps its own long_name across repeated calls.
    obs_anom = observed.anom_stdv.assign_attrs(
        long_name="Observed " + observed.anom_stdv.attrs["long_name"]
    )
    # Plot from a shallow copy so the caller's forecast dataset is left untouched.
    plot_ds = forecast.assign(obs_anom=obs_anom)

    # Fade points as the sample grows: alpha 0.3 at <=1000 points, 0.05 at >=10000.
    n_points = forecast.extreme_pct.size
    plot_ds.plot.scatter(
        x="extreme_pct",
        y="obs_anom",
        alpha=np.clip(0.3 - (0.25 * (n_points - 1000) / 9000), 0.05, 0.3),
        s=5,
        color=color,
        rasterized=True,
        ax=ax,
    )
    ax.axhline(y=ext_std, color="gray", linestyle=":", label=f"Extreme threshold: {ext_std}")

    # Pair predictions with observations by coordinate before flattening; raw
    # .values order would silently mispair points if the dimension order differs.
    # Cast to float so any gaps introduced by alignment become NaN and are masked.
    y_true_da, y_pred_da = xr.broadcast(observed.extreme.astype(float), forecast.extreme_pct)
    y_true_flat = y_true_da.values.ravel()
    y_pred_flat = y_pred_da.transpose(*y_true_da.dims).values.ravel()
    mask = ~np.isnan(y_true_flat) & ~np.isnan(y_pred_flat)
    y_true_flat = y_true_flat[mask]
    y_pred_flat = y_pred_flat[mask]

    # AUC-ROC (Area Under the Receiver Operating Characteristic curve)
    try:
        auc_roc = roc_auc_score(y_true_flat, y_pred_flat)
    except ValueError:
        # Undefined when only one class is present (e.g. no observed extremes).
        auc_roc = np.nan

    # AUC-PR (Area Under the Precision-Recall curve)
    auc_pr = average_precision_score(y_true_flat, y_pred_flat)

    # Rank correlation (Spearman)
    correlation, _ = spearmanr(y_true_flat, y_pred_flat)

    model_name = forecast.extreme_pct.attrs.get("model_name", "Model")
    ax.set_title(
        f"{model_name}\nAUC-ROC: {auc_roc:.3f}, AUC-PR: {auc_pr:.3f}\nRank Correlation: {correlation:.3f}"
    )


if True:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 5), sharey=True)
    plot_classifier_performance(gem, obs, ax1, "dodgerblue")
    plot_classifier_performance(ref, obs, ax2, "#FF8C1E")
    plt.tight_layout()
    plt.show()


# Disabled example: the same diagnostic for one location and one lead time.
if False:
    location = "DFW"
    lead = np.timedelta64(15, "D")
    gem_subset = gem.sel(location=location, lead=lead)
    ref_subset = ref.sel(location=location, lead=lead)
    obs_subset = obs.sel(location=location, lead=lead)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6), sharey=True)
    plot_classifier_performance(gem_subset, obs_subset, ax1, "dodgerblue")
    plot_classifier_performance(ref_subset, obs_subset, ax2, "#FF8C1E")

    # Convert the lead to whole days: formatting the timedelta64 itself renders
    # "15 days", which produced the title "... at 15 days-day Lead Time".
    lead_day_count = int(lead / np.timedelta64(1, "D"))
    fig.suptitle(
        f"Classifier Performance for {location} at {lead_day_count}-day Lead Time", fontsize=16
    )

    plt.tight_layout()
    plt.show()

## Appendix - Plot Extreme Thresholds

In [None]:
def plot_extreme_thresholds(model, ax, obs):
    """Plot a model's per-location extreme thresholds against lead time.

    Draws one line per location of ``model.extreme_threshold`` (x = ``lead_days``)
    in the color stored on that variable, and titles the axis with the model name.

    Args:
        model: Dataset with ``extreme_threshold`` (carrying a ``color`` attr) and
            ``extreme`` (carrying a ``model_name`` attr).
        ax: The matplotlib axis to plot on.
        obs: Observations. Currently unused; kept for the optional climatology
            overlay sketched in the comments below.
    """
    color = model.extreme_threshold.attrs["color"]

    # One line per location, all in the model's color.
    model.extreme_threshold.plot.line(
        ax=ax,
        x="lead_days",
        hue="location",
        add_legend=False,
        color=color,
        alpha=0.8,
        linewidth=1,
    )

    # Optional climatology overlay (disabled):
    # ext_std = obs.extreme.attrs["target"]
    # clim = sk.event.classify_event(obs["anom_stdv"], ext_std, dim=["forecast_date", "location"])
    # clim.plot.line(ax=ax, x="lead_days", color="grey", linestyle="dashed")

    ax.set_ylabel("Extreme Threshold")
    ax.grid(True, alpha=0.2)
    ax.set_title(f"{model.extreme.attrs['model_name']} Thresholds")


# Thresholds for GEM (left) and the reference (right) on a shared y-axis.
if True:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4), sharey=True)
    for model_ds, panel in ((gem, ax1), (ref, ax2)):
        plot_extreme_thresholds(model_ds, panel, obs)
    plt.tight_layout()

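## Appendix - Example Timeseries: False Positives

Highlight instances where the reference forecast flags an extreme that was not observed and that GEM did not predict.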

In [None]:
# Up to 10 examples from the "nnp" confusion category.
cex = viz_confusion_examples(
    obs,
    gem,
    reference=ref,
    groupby=["location", "forecast_date"],
    type="nnp",
    count=10,
)

### Appendix - Unique Alpha

Highlight instances where GEM finds extremes that GEFS doesn't.

In [None]:
# Up to 10 examples from the "ppn" confusion category.
cex = viz_confusion_examples(
    obs,
    gem,
    reference=ref,
    groupby=["location", "forecast_date"],
    type="ppn",
    count=10,
)