# Validate GEM Skill

This example shows how to evaluate Salient's native daily GEM forecasts and calculate meaningful metrics. It demonstrates [validation best practices](https://salientpredictions.notion.site/Validation-0220c48b9460429fa86f577914ea5248) such as:

- Proper scoring using the Ensemble Continuous Ranked Probability Score (CRPS)
- Considering the full forecast distribution to reward both accuracy and precision
- Preferring metrics that are less sensitive to climatology decisions than Anomaly Correlation


In [None]:
import os
import sys
from reprlib import repr as rrepr

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Import the Salient SDK.  If it is not installed, fall back to a source
# checkout located one directory above the working directory.
try:
    import salientsdk as sk
except ModuleNotFoundError as e:
    if not os.path.exists("../salientsdk"):
        # No local checkout either: surface an actionable hint while keeping
        # the original import failure attached as the explicit cause.
        raise ModuleNotFoundError("Install salient SDK with: pip install salientsdk") from e
    sys.path.append(os.path.abspath(".."))
    import salientsdk as sk

# Keep wide tables on a single line so printed DataFrames stay readable.
for _option, _value in (
    ("display.width", None),
    ("display.max_columns", None),
    ("display.expand_frame_repr", False),
):
    pd.set_option(_option, _value)

# Set the SDK file destination for this example, then authenticate.
# The login call stays last so the notebook displays its return value.
sk.set_file_destination("validate_gem_example")
sk.login("SALIENT_USERNAME", "SALIENT_PASSWORD")

<requests.sessions.Session at 0x7fccdac77cd0>

## Customize The Validation

This notebook is written flexibly so you have the option of validating Salient and other forecasts multiple ways. These variables will control what, when, and how the validation proceeds.

More information on available hindcast dates is available in the [salient documentation](https://salientpredictions.notion.site/Hindcasts-18fc9d5a921b8073a781e599e6d46be3).

In [None]:
# 1. The meteorological variable that we'll be evaluating:
var = "tmax"

# 2. Set the number of forecast samples to download:
(start_date, end_date) = ("2020-10-16", "2025-04-04")
count_date = 4  # Get a few date samples from the range for a fast proof-of-concept
# count_date = 256 # Get a healthy range of samples for a good but still quick test
# count_date = None # get all available date samples (N=1632) from the range for a comprehensive test

# 3. The reference model to compare Salient GEM to
# ref_model = None  # skip the reference model comparison
ref_model = "noaa_gefs"
# ref_model = "ecmwf_ens"
# ref_model = "gem"

# ===== Additional shared variables ==========================
# Not recommended to change these.


def model_label(model):
    """Return a display label for a model id.

    Underscores become spaces and the text is upper-cased, e.g.
    ``"noaa_gefs"`` -> ``"NOAA GEFS"``.  ``None`` (no model selected) yields
    an empty string so that ``ref_model = None`` remains a usable setting.
    """
    return "" if model is None else model.replace("_", " ").upper()


debias = False
gem_model = "gem"
gem_name = model_label(gem_model)
ref_name = model_label(ref_model)
# Number of leads to request, keyed by reference model (None = no reference).
leads = {"noaa_gefs": 35, "ecmwf_ens": 42, "gem": 50, None: None}[ref_model]
freq = "daily"
force = False  # Cache data to save on repeat API calls
verbose = False  # Show diagnostic details
figsize = (8, 5)  # Make all figures have a consistent size


# Determine which dates for which to request forecasts:
if ref_model == "ecmwf_ens":
    # For ECMWF ENS the candidate dates come from the hindcast-dates lookup.
    date_range = sk.get_hindcast_dates(
        start_date=start_date, end_date=end_date, timescale=ref_model, extend=True
    )
    date_range = pd.to_datetime(date_range)
else:
    date_range = pd.date_range(start=start_date, end=end_date, freq="D")

# Thin the candidates to `count_date` evenly spaced samples, always keeping
# the first and last date of the range.
if count_date is not None and len(date_range) > count_date:
    date_range = date_range[np.linspace(0, len(date_range) - 1, count_date, dtype=int)]

date_range = date_range.strftime("%Y-%m-%d").tolist()

# Title prefix flagging small-sample results.  It is derived from the number
# of dates actually selected, so it also works when count_date is None and
# stays accurate when fewer dates are available than were requested.
poc_warn = f"INDICATIVE (N={len(date_range)}) " if len(date_range) < 128 else ""

print(f"Forecast dates to test [{len(date_range)}]: {rrepr(date_range)}")

Forecast dates to test [4]: ['2020-10-16', '2022-04-12', '2023-10-08', '2025-04-04']


## Set the Area of Interest

The Salient SDK uses a "Location" object to specify the geographic bounds of a request. In this case, we will be validating against the vector of airport locations that are used to settle the Chicago Mercantile Exchange's Cooling and Heating Degree Day contracts. With `load_location_file` we can see that the file contains:

- `lat` / `lon`: latitude and longitude of the met station, standard for a `location_file`
- `name`: the 3-letter IATA airport code of the location, also `location_file` standard
- `ghcnd`: the global climate network ID of the station, used to validate against observations. To customize this analysis for any set of observation stations, use the NCEI [stations list](https://www.ncei.noaa.gov/pub/data/ghcn/daily/ghcnd-stations.txt).
- `cme`: the CME code for the location used to create CDD/HDD strip codes.
- `description`: full name of the airport

If you have a list of locations already defined in a separate CSV file, you can use [`upload_file`](https://sdk.salientpredictions.com/api/#salientsdk.upload_file) to upload the file directly without building it in code via `upload_location_file`.


In [None]:
# One row per station:
#   (lat, lon, IATA name, GHCND station id, CME code, airport description)
# fmt: off
station_rows = [
    (33.62972,  -84.44224, "ATL", "USW00013874", "1", "Atlanta Hartsfield"),
    (42.36057,  -71.00975, "BOS", "USW00014739", "W", "Boston Logan"),
    (34.19966, -118.36543, "BUR", "USW00023152", "P", "Burbank-Glendale-Pasadena"),
    (41.96017,  -87.93164, "ORD", "USW00094846", "2", "Chicago O'Hare"),
    (39.04443,  -84.67241, "CVG", "USW00093814", "3", "Cincinnati (Covington)"),
    (32.89744,  -97.02196, "DFW", "USW00003927", "5", "Dallas-Fort Worth"),
    (29.98438,  -95.36072, "IAH", "USW00012960", "R", "Houston-George Bush"),
    (36.07190, -115.16343, "LAS", "USW00023169", "0", "Las Vegas McCarran"),
    (44.88523,  -93.23133, "MSP", "USW00014922", "Q", "Minneapolis-StPaul"),
    (40.77945,  -73.88027, "LGA", "USW00014732", "4", "New York La Guardia"),
    (39.87326,  -75.22681, "PHL", "USW00013739", "6", "Philadelphia"),
    (45.59578, -122.60919, "PDX", "USW00024229", "7", "Portland"),
    (38.50659, -121.49604, "SAC", "USW00023232", "S", "Sacramento Exec"),
]
# fmt: on

# Transpose the rows into the per-column lists that upload_location_file takes.
station_columns = dict(
    zip(
        ("lats", "lons", "names", "ghcnd", "cme", "description"),
        (list(column) for column in zip(*station_rows)),
    )
)
location_file = sk.upload_location_file(geoname="cmeus", force=force, **station_columns)
loc = sk.Location(location_file=location_file)

stations = loc.load_location_file()
print(stations)

         lat        lon name        ghcnd cme                description                     geometry
0   33.62972  -84.44224  ATL  USW00013874   1         Atlanta Hartsfield   POINT (-84.44224 33.62972)
1   42.36057  -71.00975  BOS  USW00014739   W               Boston Logan   POINT (-71.00975 42.36057)
2   34.19966 -118.36543  BUR  USW00023152   P  Burbank-Glendale-Pasadena  POINT (-118.36543 34.19966)
3   41.96017  -87.93164  ORD  USW00094846   2             Chicago O'Hare   POINT (-87.93164 41.96017)
4   39.04443  -84.67241  CVG  USW00093814   3     Cincinnati (Covington)   POINT (-84.67241 39.04443)
5   32.89744  -97.02196  DFW  USW00003927   5          Dallas-Fort Worth   POINT (-97.02196 32.89744)
6   29.98438  -95.36072  IAH  USW00012960   R        Houston-George Bush   POINT (-95.36072 29.98438)
7   36.07190 -115.16343  LAS  USW00023169   0         Las Vegas McCarran   POINT (-115.16343 36.0719)
8   44.88523  -93.23133  MSP  USW00014922   Q         Minneapolis-StPaul   POINT (

## Download the Forecasts

The [`forecast_timeseries`](https://api.salientpredictions.com/v2/documentation/api/#/Forecasts/forecast_timeseries) API endpoint and SDK function gets a forecast from a particular forecast date.

In [None]:
# Request settings for the forecast download, kept in a named dict so the
# same settings can be reused for other forecast_timeseries calls.
fcst_args = {
    "loc": loc,
    "variable": var,
    "debias": debias,
    "field": "vals_ens",
    "date": date_range,
    "timescale": freq,
    "leads": leads,
    "verbose": verbose,
    "force": force,
    "strict": False,  # if one forecast call fails, proceed with others
}

# Download one GEM forecast per date, then stack the results into a single
# dataset (left lazy with compute=False).
gem_src = sk.forecast_timeseries(model=gem_model, **fcst_args)
gem = sk.stack_forecast(gem_src, compute=False)
print(gem)

<xarray.Dataset> Size: 1MB
Dimensions:        (forecast_date: 4, lead: 35, location: 13, ensemble: 200)
Coordinates:
  * forecast_date  (forecast_date) datetime64[ns] 32B 2020-10-16 ... 2025-04-04
  * ensemble       (ensemble) int64 2kB 0 1 2 3 4 5 ... 194 195 196 197 198 199
  * lead           (lead) timedelta64[ns] 280B 1 days 2 days ... 34 days 35 days
    lat            (location) float64 104B 33.63 42.36 34.2 ... 39.87 45.6 38.51
    lon            (location) float64 104B -84.44 -71.01 ... -122.6 -121.5
  * location       (location) <U3 156B 'ATL' 'BOS' 'BUR' ... 'PHL' 'PDX' 'SAC'
    time           (forecast_date, lead) datetime64[ns] 1kB 2020-10-16 ... 20...
Data variables:
    vals_ens       (forecast_date, lead, location, ensemble) float32 1MB 21.2...


In [None]:
# Download and stack the reference model's forecasts, if one was selected.
ref = None
if ref_model is not None:
    fcst_args.update(debias=False)  # GEFS & ENS don't support debiasing
    ref_src = sk.forecast_timeseries(model=ref_model, **fcst_args)
    ref = sk.stack_forecast(ref_src, compute=False)
    print(ref)

<xarray.Dataset> Size: 228kB
Dimensions:        (forecast_date: 4, lead: 35, location: 13, ensemble: 31)
Coordinates:
  * forecast_date  (forecast_date) datetime64[ns] 32B 2020-10-16 ... 2025-04-04
  * ensemble       (ensemble) int64 248B 0 1 2 3 4 5 6 ... 24 25 26 27 28 29 30
  * lead           (lead) timedelta64[ns] 280B 1 days 2 days ... 34 days 35 days
    lat            (location) float64 104B 33.63 42.36 34.2 ... 39.87 45.6 38.51
    lon            (location) float64 104B -84.44 -71.01 ... -122.6 -121.5
  * location       (location) <U3 156B 'ATL' 'BOS' 'BUR' ... 'PHL' 'PDX' 'SAC'
    time           (forecast_date, lead) datetime64[ns] 1kB 2020-10-16 ... 20...
Data variables:
    vals_ens       (forecast_date, lead, location, ensemble) float32 226kB 21...


### Download Historical Truth

Download daily historical values from [`data_timeseries`](https://sdk.salientpredictions.com/api/#salientsdk.data_timeseries) and [`met_observations`](https://api.salientpredictions.com/v2/documentation/api/#/Meteorological%20Stations/met_observations).


In [None]:
# History must reach past the final forecast date by the longest lead.
# `leads` is None when no reference model is selected; in that case use the
# longest lead actually present in the stacked GEM forecast.
if leads is not None:
    horizon_days = leads
else:
    horizon_days = int(gem.lead.values.max() / np.timedelta64(1, "D"))

# Arguments shared by the station-observation and gridded-history requests.
# The window is padded 5 days before the first forecast date and one day
# beyond the last forecast's final lead.
hist_args = {
    "loc": loc,
    "start": np.datetime64(start_date) - np.timedelta64(5, "D"),
    "end": np.datetime64(end_date) + np.timedelta64(horizon_days + 1, "D"),
    "verbose": verbose,
    "force": force,
}
obs_src = sk.met_observations(variables=var, **hist_args)
era_src = sk.data_timeseries(variable=var, field="vals", **hist_args)


# Re-index both histories onto the forecast's (forecast_date, lead) grid and
# rename the data variable to "vals_ens" so it matches the forecast datasets.
obs = (
    sk.stack_history(obs_src, forecast_date=gem.forecast_date, lead=gem.lead, compute=False)
    .rename({var: "vals_ens"})
    .reset_coords(drop=True)
)
era = (
    sk.stack_history(era_src, forecast_date=gem.forecast_date, lead=gem.lead, compute=False)
    .rename({"vals": "vals_ens"})
    .reset_coords(drop=True)
)
print(era)

<xarray.Dataset> Size: 15kB
Dimensions:        (forecast_date: 4, lead: 35, location: 13)
Coordinates:
  * forecast_date  (forecast_date) datetime64[ns] 32B 2020-10-16 ... 2025-04-04
  * location       (location) <U3 156B 'ATL' 'BOS' 'BUR' ... 'PHL' 'PDX' 'SAC'
  * lead           (lead) timedelta64[ns] 280B 1 days 2 days ... 34 days 35 days
Data variables:
    vals_ens       (forecast_date, lead, location) float64 15kB dask.array<chunksize=(1, 35, 13), meta=np.ndarray>
Attributes:
    long_name:   2 metre temperature
    units:       degC
    clim_start:  1990-01-01
    clim_end:    2019-12-31


## Calculate Skill Metrics

Compare the forecast and ERA5 datasets to see how well they match. Here we will calculate the "Continuous Ranked Probability Score" (CRPS), the same metric reported by `hindcast_summary`.


In [None]:
# Score the GEM ensemble against the ERA5 history with the ensemble CRPS.
skill_gem = sk.skill.crps_ensemble(
    forecasts=gem,
    observations=era,
)
print(skill_gem)

<xarray.Dataset> Size: 20kB
Dimensions:        (forecast_date: 4, lead: 35, location: 13)
Coordinates:
  * forecast_date  (forecast_date) datetime64[ns] 32B 2020-10-16 ... 2025-04-04
  * location       (location) <U3 156B 'ATL' 'BOS' 'BUR' ... 'PHL' 'PDX' 'SAC'
  * lead           (lead) timedelta64[ns] 280B 1 days 2 days ... 34 days 35 days
    lat            (location) float64 104B 33.63 42.36 34.2 ... 39.87 45.6 38.51
    lon            (location) float64 104B -84.44 -71.01 ... -122.6 -121.5
    time           (forecast_date, lead) datetime64[ns] 1kB 2020-10-16 ... 20...
Data variables:
    crps_ens_all   (forecast_date, lead, location) float64 15kB 0.473 ... 2.479
    crps_ens       (lead, location) float64 4kB 0.7765 1.101 ... 1.201 1.746
Attributes:
    short_name:  crps
    long_name:   CRPS


In [None]:
# Both results default to None so later cells can detect a skipped comparison.
skill_ref = None
skill_rel = None
if ref is not None:
    # CRPS of the reference model against the same history:
    skill_ref = sk.skill.crps_ensemble(forecasts=ref, observations=era)

    # Skill score of Salient GEM relative to the reference model:
    skill_rel = sk.skill.crpss(reference=skill_ref, forecast=skill_gem)
    print(skill_rel)
else:
    print("No reference model, skipping relative comparison")

<xarray.Dataset> Size: 20kB
Dimensions:        (forecast_date: 4, location: 13, lead: 35)
Coordinates:
  * forecast_date  (forecast_date) datetime64[ns] 32B 2020-10-16 ... 2025-04-04
  * location       (location) <U3 156B 'ATL' 'BOS' 'BUR' ... 'PHL' 'PDX' 'SAC'
  * lead           (lead) timedelta64[ns] 280B 1 days 2 days ... 34 days 35 days
    lat            (location) float64 104B 33.63 42.36 34.2 ... 39.87 45.6 38.51
    lon            (location) float64 104B -84.44 -71.01 ... -122.6 -121.5
    time           (forecast_date, lead) datetime64[ns] 1kB 2020-10-16 ... 20...
Data variables:
    crpss_ens_all  (forecast_date, lead, location) float64 15kB -9.386 ... -1...
    crpss_ens      (lead, location) float64 4kB -1.907 -1.268 ... -0.3851
Attributes:
    short_name:  crpss
    long_name:   CRPSS


In [None]:
# Line plot of all-locations mean CRPS by lead time, one line per model.
if skill_ref is None:
    print("Skipping relative skill plotting")
else:
    fig, ax = plt.subplots(figsize=figsize)

    # Reference model first, then GEM, so GEM is drawn on top.
    for skill, color, label in (
        (skill_ref, "#FF7F00", ref_name),
        (skill_gem, "dodgerblue", gem_name),
    ):
        skill["crps_ens"].mean("location", keep_attrs=True).plot(
            ax=ax,
            color=color,
            linewidth=2,
            label=label,
        )

    # The formatter treats tick values as nanoseconds and converts them to days.
    ax.xaxis.set_major_formatter(lambda x, pos: f"{x/1e9/86400:.0f}")
    ax.set_xlabel("Lead Time (days)")
    ax.set_ylabel(f"CRPS {gem.vals_ens.attrs['long_name']} [{gem.vals_ens.attrs['units']}]")
    # poc_warn is either empty or already ends with a space, so no separator
    # is added here (avoids a leading or doubled space in the title).
    ax.set_title(f"{poc_warn}All-locations Mean CRPS (lower is better)")
    ax.legend()
    fig.tight_layout()

In [None]:
# Box plot of relative skill (CRPSS) per location; each box spans all leads.
if skill_rel is None:
    print("Skipping relative skill boxplot")
else:
    crpss = skill_rel["crpss_ens"]

    # Order locations from best to worst median skill across leads.
    medians = crpss.median("lead")
    sorted_locations = medians.sortby(medians, ascending=False).location.values

    fig, ax = plt.subplots(figsize=figsize)  # same size as the other figures
    df = crpss.to_pandas().melt(ignore_index=False)
    ax.boxplot(
        # `name` rather than `loc`, so the notebook's Location object is not shadowed.
        [df[df["location"] == name]["value"] for name in sorted_locations],
        tick_labels=sorted_locations,  # Updated parameter name
        patch_artist=True,
        showfliers=False,
        medianprops=dict(color="black"),
        boxprops=dict(facecolor="dodgerblue"),
    )
    ax.axhline(y=0, color="grey", linestyle=":", zorder=0)
    ax.set_xlabel("Location")
    ax.set_ylabel(f"CRPSS {gem.vals_ens.attrs['long_name']}")
    # poc_warn is either empty or already ends with a space.
    ax.set_title(f"{poc_warn}Relative Skill {gem_name} vs {ref_name} (higher is better)")

    # Clamp the view to [-1, 1], but only when that leaves a valid range.
    # If the autoscaled limits lie entirely outside [-1, 1], clamping would
    # invert the axis, so the autoscaled limits are kept instead.
    ymin, ymax = ax.get_ylim()
    lower, upper = max(-1, ymin), min(1, ymax)
    if lower < upper:
        ax.set_ylim(lower, upper)
    fig.tight_layout()