# Validate GEM Skill

This example shows how to evaluate Salient's native daily GEM forecasts and calculate meaningful metrics. It demonstrates [validation best practices](https://salientpredictions.notion.site/Validation-0220c48b9460429fa86f577914ea5248) such as:

- Proper scoring using the Ensemble Continuous Ranked Probability Score (CRPS)
- Considers the full forecast distribution to reward both accuracy and precision
- Less sensitive to climatology decisions than metrics like Anomaly Correlation


In [None]:
import os
import sys
from reprlib import repr as rrepr

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

try:
    import salientsdk as sk
except ModuleNotFoundError as e:
    # Fall back to a copy of the package sitting one directory above the working directory.
    if os.path.exists("../salientsdk"):
        sys.path.append(os.path.abspath(".."))
        import salientsdk as sk
    else:
        # Chain the original error: the missing module may be one of salientsdk's own
        # dependencies, and the traceback should still show which one it was.
        raise ModuleNotFoundError("Install salient SDK with: pip install salientsdk") from e

# Keep printed tables on one line per row so they stay readable.
pd.set_option(
    "display.width", None,
    "display.max_columns", None,
    "display.expand_frame_repr", False,
)  # fmt: skip

# Point the SDK's file destination at "validate_gem_example".
sk.set_file_destination("validate_gem_example")
# NOTE(review): the two arguments look like names of environment variables holding the
# credentials rather than literal credentials - confirm against the SDK documentation.
sk.login("SALIENT_USERNAME", "SALIENT_PASSWORD")

<requests.sessions.Session at 0x7fab7a6ff2d0>

## Customize The Validation

This notebook is written flexibly so you can validate Salient and other forecasts in multiple ways. The variables below control what, when, and how the validation proceeds.

More information on available hindcast dates is available in the [salient documentation](https://salientpredictions.notion.site/Hindcasts-18fc9d5a921b8073a781e599e6d46be3).

In [None]:
# --- 1. Meteorological variable to evaluate ---
var = "tmax"
# var = "tmin"

# --- 2. Forecast dates to sample ---
start_date, end_date = "2020-10-16", "2025-04-04"
count_date = 4  # A few date samples from the range: fast proof-of-concept
# count_date = 128  # A healthy number of samples: good but quick test
# count_date = None  # Every available date in the range (N=1632): comprehensive test

# --- 3. Reference model to compare Salient GEM against ---
# ref_model = None  # Skip the reference model comparison
ref_model = "noaa_gefs"
# ref_model = "ecmwf_ens"
# ref_model = "gem"  # Good for comparing GEM debiased to GEM native

# --- 4. Source of "truth" ---
# debias = False  # Native ERA5-based model predictions
debias = True  # Debias to station data

In [None]:
# ===== Additional shared variables ==========================
# Not recommended to change these.

force = False  # Cache data to save on repeat API calls
verbose = False  # Show diagnostic details
gem_model = "gem"
# Value passed as `leads=` to the forecast requests: 35 for GEFS, 46 for ECMWF ENS, else 50
# (including when ref_model is None or "gem").
leads = {"noaa_gefs": 35, "ecmwf_ens": 46}.get(ref_model, 50)
freq = "daily"
figsize = (8, 5)  # Make all figures have a consistent size

# Plot-title prefix that flags small samples. count_date=None means "use every date",
# which is never a small sample; test for it first because `None < 128` raises TypeError.
poc_warn = (
    f"INDICATIVE (N={count_date}) " if count_date is not None and count_date < 128 else ""
)
gem_name = ("Debiased " if debias else "") + gem_model.replace("_", " ").upper()
ref_name = "" if ref_model is None else ref_model.replace("_", " ").upper()

# Determine which dates for which to request forecasts:
if ref_model == "ecmwf_ens":
    # For ECMWF ENS, ask the SDK which hindcast dates exist instead of assuming every day.
    date_range = sk.get_hindcast_dates(
        start_date=start_date, end_date=end_date, timescale=ref_model, extend=True
    )
    date_range = pd.to_datetime(date_range)
else:
    date_range = pd.date_range(start=start_date, end=end_date, freq="D")

# Thin the candidates to `count_date` evenly spaced dates, always keeping the first and last.
if count_date is not None and len(date_range) > count_date:
    date_range = date_range[np.linspace(0, len(date_range) - 1, count_date, dtype=int)]

date_range = date_range.strftime("%Y-%m-%d").tolist()

print(f"Forecast dates to test [{len(date_range)}]: {rrepr(date_range)}")

Forecast dates to test [4]: ['2020-10-16', '2022-04-12', '2023-10-08', '2025-04-04']


## Set the Area of Interest

The Salient SDK uses a "Location" object to specify the geographic bounds of a request. In this case, we will be validating against the vector of airport locations.  If you have a list of locations already defined in a separate CSV file, you can use [`upload_file`](https://sdk.salientpredictions.com/api/#salientsdk.upload_file) to upload the file directly without building it in code via `upload_location_file`.


In [None]:
# Upload the station table first, then wrap the returned location file in a Location.
# fmt: off
station_file = sk.upload_location_file(
    lats =[   40.77945,    36.07190,   29.98438,    38.50659,    37.61970,  25.78810,    40.77060,    21.32390],
    lons =[  -73.88027,  -115.16343,  -95.36072,  -121.49604,  -122.36560, -80.31690,  -111.96500,  -157.93940],
    names=[      "LGA",       "LAS",      "IAH",       "SAC",       "SFO",     "MIA",       "SLC",       "HNL"],
    ghcnd=["USW00014732", "USW00023169", "USW00012960", "USW00023232", "USW00023234", "USW00012839", "USW00024127", "USW00022521"],
    geoname="validate_stations",
    force=True,
    description=["New York La Guardia", "Las Vegas McCarran", "Houston-George Bush", "Sacramento Exec", "San Francisco Intl", "Miami Intl", "Salt Lake City Intl", "Honolulu Intl"],
)
# fmt: on
loc = sk.Location(location_file=station_file)

# Read the uploaded table back and show it as a sanity check.
stations = loc.load_location_file()
print(stations)

        lat        lon name        ghcnd          description                     geometry
0  40.77945  -73.88027  LGA  USW00014732  New York La Guardia   POINT (-73.88027 40.77945)
1  36.07190 -115.16343  LAS  USW00023169   Las Vegas McCarran   POINT (-115.16343 36.0719)
2  29.98438  -95.36072  IAH  USW00012960  Houston-George Bush   POINT (-95.36072 29.98438)
3  38.50659 -121.49604  SAC  USW00023232      Sacramento Exec  POINT (-121.49604 38.50659)
4  37.61970 -122.36560  SFO  USW00023234   San Francisco Intl    POINT (-122.3656 37.6197)
5  25.78810  -80.31690  MIA  USW00012839           Miami Intl     POINT (-80.3169 25.7881)
6  40.77060 -111.96500  SLC  USW00024127  Salt Lake City Intl     POINT (-111.965 40.7706)
7  21.32390 -157.93940  HNL  USW00022521        Honolulu Intl    POINT (-157.9394 21.3239)


### Inspect station locations

If we're testing debiasing, we want to validate that the debiasing will use the expected stations.
This is optional.  If you have modified the location list and not supplied a list of reference stations (the `ghcnd` column), skip this check.

In [None]:
if debias:
    # cdd/hdd requests are looked up as "tmax" stations; any other variable is used as-is.
    station_var = var if var not in ("cdd", "hdd") else "tmax"
    stations_csv = sk.met_stations(
        loc=loc,
        variables=station_var,
        force=force,
        max_distance=40,
    )
    found_stations = pd.read_csv(stations_csv)
    # Fail loudly if the stations found differ from the GHCND ids we uploaded.
    pd.testing.assert_series_equal(found_stations.Station, stations.ghcnd, check_names=False)
    print(found_stations)

       Name Source      Station                  Description  Latitude  Longitude       Start         End Location ID  Requested Latitude  Requested Longitude  Distance (km)
0  efa5deba  ghcnd  USW00014732                 LAGUARDIA AP   40.7794   -73.8803  1990-01-01  2025-08-12         LGA            40.77945            -73.88027       0.006103
1  3b734ff2  ghcnd  USW00023169             MCCARRAN INTL AP   36.0719  -115.1633  1990-01-01  2025-08-12         LAS            36.07190           -115.16343       0.011711
2  d192c788  ghcnd  USW00012960  HOUSTON INTERCONTINENTAL AP   29.9844   -95.3608  1990-01-01  2025-08-12         IAH            29.98438            -95.36072       0.008032
3  a6fe33bd  ghcnd  USW00023232           SACRAMENTO AP ASOS   38.5067  -121.4961  1990-01-01  2025-08-12         SAC            38.50659           -121.49604       0.013285
4  197b383e  ghcnd  USW00023234        SAN FRANCISCO INTL AP   37.6197  -122.3656  1990-01-01  2025-08-12         SFO            3

## Download the Forecasts

The [`forecast_timeseries`](https://api.salientpredictions.com/v2/documentation/api/#/Forecasts/forecast_timeseries) API endpoint and SDK function gets a forecast from a particular forecast date.

In [None]:
# Request arguments shared by every forecast model downloaded in this notebook.
fcst_args = {
    "loc": loc,
    "variable": var,
    "debias": debias,
    "field": "vals_ens",
    "date": date_range,
    "timescale": freq,
    "leads": leads,
    "verbose": verbose,
    "force": force,
    "strict": False,  # if one downscale call fails, proceed with others
}

# Download the GEM forecasts, then stack them into a single dataset.
gem_src = sk.forecast_timeseries(model=gem_model, **fcst_args)
gem = sk.stack_forecast(gem_src, compute=False)
print(gem)

<xarray.Dataset> Size: 899kB
Dimensions:        (forecast_date: 4, lead: 35, location: 8, ensemble: 200)
Coordinates:
  * forecast_date  (forecast_date) datetime64[ns] 32B 2020-10-16 ... 2025-04-04
  * ensemble       (ensemble) int64 2kB 0 1 2 3 4 5 ... 194 195 196 197 198 199
    lat            (location) float64 64B dask.array<chunksize=(8,), meta=np.ndarray>
    lon            (location) float64 64B dask.array<chunksize=(8,), meta=np.ndarray>
  * location       (location) <U3 96B 'LGA' 'LAS' 'IAH' ... 'MIA' 'SLC' 'HNL'
  * lead           (lead) timedelta64[ns] 280B 1 days 2 days ... 34 days 35 days
    time           (forecast_date, lead) datetime64[ns] 1kB 2020-10-16 ... 20...
Data variables:
    vals_ens       (forecast_date, lead, location, ensemble) float32 896kB dask.array<chunksize=(1, 35, 8, 200), meta=np.ndarray>
Attributes:
    debias:   true


In [None]:
if ref_model is None:
    ref = None
else:
    # Request the reference without debiasing in two cases:
    #  - "gem": the point is to compare debiased GEM against native GEM
    #  - "ecmwf_ens": does not support debiasing
    if ref_model in ("gem", "ecmwf_ens"):
        fcst_args["debias"] = False

    ref_src = sk.forecast_timeseries(model=ref_model, **fcst_args)
    ref = sk.stack_forecast(ref_src, compute=False)
    print(ref)

<xarray.Dataset> Size: 4MB
Dimensions:        (forecast_date: 128, lead: 35, location: 8, ensemble: 31)
Coordinates:
  * forecast_date  (forecast_date) datetime64[ns] 1kB 2020-10-16 ... 2025-04-04
  * ensemble       (ensemble) int64 248B 0 1 2 3 4 5 6 ... 24 25 26 27 28 29 30
  * lead           (lead) timedelta64[ns] 280B 1 days 2 days ... 34 days 35 days
    lat            (location) float64 64B dask.array<chunksize=(8,), meta=np.ndarray>
    lon            (location) float64 64B dask.array<chunksize=(8,), meta=np.ndarray>
  * location       (location) <U3 96B 'LGA' 'LAS' 'IAH' ... 'MIA' 'SLC' 'HNL'
    time           (forecast_date, lead) datetime64[ns] 36kB 2020-10-16 ... 2...
Data variables:
    vals_ens       (forecast_date, lead, location, ensemble) float32 4MB dask.array<chunksize=(1, 35, 8, 31), meta=np.ndarray>


### Download Historical Truth

Download daily historical values from [`data_timeseries`](https://sdk.salientpredictions.com/api/#salientsdk.data_timeseries) and [`met_observations`](https://api.salientpredictions.com/v2/documentation/api/#/Meteorological%20Stations/met_observations). When `debias=True` we will use met station observations as ground truth.  When `debias=False`, ERA5 historicals are considered truth.


In [None]:
# Pad the history window: 5 days before the first forecast date and `leads + 1` days
# after the last one.
hist_args = dict(
    loc=loc,
    start=np.datetime64(start_date) - np.timedelta64(5, "D"),
    end=np.datetime64(end_date) + np.timedelta64(leads + 1, "D"),
    verbose=verbose,
    force=force,
)
obs_src = sk.met_observations(variables=var, **hist_args)
era_src = sk.data_timeseries(variable=var, field="vals", **hist_args)

# Stack both histories onto the GEM forecast_date/lead coordinates and give them the same
# variable name as the forecasts ("vals_ens").
obs = sk.stack_history(obs_src, forecast_date=gem.forecast_date, lead=gem.lead, compute=False)
obs = obs.rename({var: "vals_ens"}).reset_coords(drop=True)

era = sk.stack_history(era_src, forecast_date=gem.forecast_date, lead=gem.lead, compute=False)
era = era.rename({"vals": "vals_ens"}).reset_coords(drop=True)

print(obs if debias else era)

<xarray.Dataset> Size: 5kB
Dimensions:        (forecast_date: 4, lead: 35, location: 8)
Coordinates:
  * forecast_date  (forecast_date) datetime64[ns] 32B 2020-10-16 ... 2025-04-04
  * location       (location) <U3 96B 'LGA' 'LAS' 'IAH' ... 'MIA' 'SLC' 'HNL'
  * lead           (lead) timedelta64[ns] 280B 1 days 2 days ... 34 days 35 days
Data variables:
    vals_ens       (forecast_date, lead, location) float32 4kB dask.array<chunksize=(1, 35, 8), meta=np.ndarray>


Some locations (like HNL and SAC) can have significant disagreements between ERA5 and station observations.  The more consistent the bias, the more headroom the debiaser has to improve forecasts.

In [None]:
if debias:
    # Mean ERA5-minus-station difference per location, largest first.
    err = (era - obs)["vals_ens"].mean(dim=["forecast_date", "lead"])
    err = err.sortby(err, ascending=False)

    fig, ax = plt.subplots(figsize=figsize)
    ax.bar(err.location.values, err.values)
    ax.set_title("Bias: ERA5 - Station Observations")
    ax.set_ylabel(f"{era.attrs['long_name']} [{era.attrs['units']}]")
    plt.show()

## Calculate Forecast Skill Metrics

Compare the forecasts against the truth dataset (station observations when `debias=True`, ERA5 otherwise) to see how well they match. Here we will calculate the "Continuous Ranked Probability Score" (CRPS), the same metric that a call to `hindcast_summary` reports.


In [None]:
# Score GEM against station observations when debiasing, otherwise against ERA5.
skill_gem = sk.skill.crps_ensemble(
    observations=(obs if debias else era),
    forecasts=gem,
)
print(skill_gem)

<xarray.Dataset> Size: 7kB
Dimensions:        (forecast_date: 4, lead: 35, location: 8)
Coordinates:
  * forecast_date  (forecast_date) datetime64[ns] 32B 2020-10-16 ... 2025-04-04
  * location       (location) <U3 96B 'LGA' 'LAS' 'IAH' ... 'MIA' 'SLC' 'HNL'
  * lead           (lead) timedelta64[ns] 280B 1 days 2 days ... 34 days 35 days
    lat            (location) float64 64B 40.78 36.07 29.98 ... 40.77 21.32
    lon            (location) float64 64B -73.88 -115.2 -95.36 ... -112.0 -157.9
    time           (forecast_date, lead) datetime64[ns] 1kB 2020-10-16 ... 20...
Data variables:
    crps_ens_all   (forecast_date, lead, location) float32 4kB 0.7379 ... 1.757
    crps_ens       (lead, location) float32 1kB 1.408 1.804 1.929 ... 3.311 1.53
Attributes:
    debias:      true
    short_name:  crps
    long_name:   CRPS


In [None]:
if ref is not None:
    # Score the reference model against the same truth used for GEM.
    skill_ref = sk.skill.crps_ensemble(
        observations=(obs if debias else era),
        forecasts=ref,
    )
    print(skill_ref)
else:
    print("No reference model, skipping relative skill calculation")
    skill_ref = None

<xarray.Dataset> Size: 7kB
Dimensions:        (forecast_date: 4, lead: 35, location: 8)
Coordinates:
  * forecast_date  (forecast_date) datetime64[ns] 32B 2020-10-16 ... 2025-04-04
  * location       (location) <U3 96B 'LGA' 'LAS' 'IAH' ... 'MIA' 'SLC' 'HNL'
  * lead           (lead) timedelta64[ns] 280B 1 days 2 days ... 34 days 35 days
    lat            (location) float64 64B 40.78 36.07 29.98 ... 40.77 21.32
    lon            (location) float64 64B -73.88 -115.2 -95.36 ... -112.0 -157.9
    time           (forecast_date, lead) datetime64[ns] 1kB 2020-10-16 ... 20...
Data variables:
    crps_ens_all   (forecast_date, lead, location) float32 4kB 0.03331 ... 4.795
    crps_ens       (lead, location) float32 1kB 1.257 2.78 2.136 ... 4.047 3.521
Attributes:
    short_name:  crps
    long_name:   CRPS


In [None]:
if ref is not None:
    # Relative skill score (CRPSS) of Salient GEM versus the reference model.
    skill_rel = sk.skill.crpss(forecast=skill_gem, reference=skill_ref)
    print(skill_rel)
else:
    print("No reference model, skipping relative comparison")
    skill_rel = None

<xarray.Dataset> Size: 7kB
Dimensions:        (forecast_date: 4, location: 8, lead: 35)
Coordinates:
  * forecast_date  (forecast_date) datetime64[ns] 32B 2020-10-16 ... 2025-04-04
  * location       (location) <U3 96B 'LGA' 'LAS' 'IAH' ... 'MIA' 'SLC' 'HNL'
  * lead           (lead) timedelta64[ns] 280B 1 days 2 days ... 34 days 35 days
    lat            (location) float64 64B 40.78 36.07 29.98 ... 40.77 21.32
    lon            (location) float64 64B -73.88 -115.2 -95.36 ... -112.0 -157.9
    time           (forecast_date, lead) datetime64[ns] 1kB 2020-10-16 ... 20...
Data variables:
    crpss_ens_all  (forecast_date, lead, location) float32 4kB -21.15 ... 0.6335
    crpss_ens      (lead, location) float32 1kB -0.1208 0.3511 ... 0.1817 0.5654
Attributes:
    short_name:  crpss
    long_name:   CRPSS


In [None]:
fig, ax = plt.subplots(figsize=figsize)

# Draw the reference curve first (when there is one), then the GEM curve.
crps_curves = [(skill_gem, "dodgerblue", gem_name)]
if skill_ref is not None:
    crps_curves.insert(0, (skill_ref, "#FF7F00", ref_name))

for crps_ds, line_color, line_label in crps_curves:
    crps_ds["crps_ens"].mean("location", keep_attrs=True).plot(
        ax=ax,
        color=line_color,
        linewidth=2,
        label=line_label,
    )

# Tick values are divided by 1e9 and then 86400 (nanoseconds -> seconds -> days) so the
# axis reads in whole days.
ax.xaxis.set_major_formatter(lambda x, pos: f"{x / 1e9 / 86400:.0f}")
ax.set_xlabel("Lead Time (days)")
ax.set_ylabel(f"CRPS {gem.vals_ens.attrs['long_name']} [{gem.vals_ens.attrs['units']}]")
ax.set_title(f"{poc_warn} All-locations Mean CRPS (lower is better)")
plt.legend()
plt.tight_layout()

In [None]:
if skill_rel is not None:
    # Order the boxes by median CRPSS across leads, best location first.
    medians = skill_rel["crpss_ens"].median("lead")
    sorted_locations = medians.sortby(medians, ascending=False).location.values

    fig, ax = plt.subplots(figsize=figsize)
    # Long format: one row per (lead, location) pair, with columns "location" and "value".
    df = skill_rel["crpss_ens"].to_pandas().melt(ignore_index=False)
    per_location = [df.loc[df["location"] == station, "value"] for station in sorted_locations]
    ax.boxplot(
        per_location,
        tick_labels=sorted_locations,
        patch_artist=True,
        showfliers=False,
        medianprops={"color": "black"},
        boxprops={"facecolor": "dodgerblue"},
    )
    # Dotted zero line: values above it mean GEM scored better than the reference.
    ax.axhline(y=0, color="grey", linestyle=":", zorder=0)
    ax.set_xlabel("Location")
    ax.set_ylabel(f"CRPSS {gem.vals_ens.attrs['long_name']}")
    ax.set_title(f"{poc_warn} Relative Skill {gem_name} vs {ref_name} (higher is better)")
    # Clamp the y-axis to [-1, 1] at most so extreme whiskers don't flatten the boxes.
    ymin, ymax = ax.get_ylim()
    ax.set_ylim(max(-1, ymin), min(1, ymax))
    plt.tight_layout()
else:
    print("Skipping relative skill boxplot")