# Validate Forecast Skill vs Observations

This example shows how to evaluate Salient's probabilistic forecasts against observations and calculate meaningful metrics. It demonstrates [validation best practices](https://salientpredictions.notion.site/Validation-0220c48b9460429fa86f577914ea5248) such as:

- Proper scoring using the Continuous Ranked Probability Score (CRPS)
  - Considers the full forecast distribution to reward both accuracy and precision
  - Less sensitive to climatology decisions than metrics like Anomaly Correlation
- A long backtesting period (2015-2022)
  - Short evaluation periods are subject to noise
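
To make the first point concrete, here is a minimal sketch of an empirical CRPS computed from ensemble samples, using the standard identity CRPS = E|X - y| - 0.5 E|X - X'|. The function name `crps_ensemble` and the sample values are illustrative assumptions. The SDK's `sk.skill.crps` works on quantile forecast files and may be implemented differently.

```python
import numpy as np


def crps_ensemble(samples, observation: float) -> float:
    """Empirical CRPS of an ensemble forecast vs one observation (hypothetical helper)."""
    samples = np.asarray(samples, dtype=float)
    # Accuracy term: mean absolute distance between members and the observation
    accuracy = np.mean(np.abs(samples - observation))
    # Spread term: half the mean absolute distance between all pairs of members
    spread = 0.5 * np.mean(np.abs(samples[:, None] - samples[None, :]))
    return float(accuracy - spread)


obs_value = 20.0  # toy observation, degC
sharp_and_right = crps_ensemble([19.5, 20.0, 20.5], obs_value)
wide_but_centered = crps_ensemble([15.0, 20.0, 25.0], obs_value)
sharp_but_wrong = crps_ensemble([24.5, 25.0, 25.5], obs_value)
print(sharp_and_right, wide_but_centered, sharp_but_wrong)
```

Lower is better. The narrow, well-centered ensemble scores best, the wide but centered ensemble is penalized for its lack of precision, and the narrow but misplaced ensemble scores worst.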


In [None]:
# Initialize the environment:
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

try:
    import salientsdk as sk
except ModuleNotFoundError as e:
    if os.path.exists("../salientsdk"):
        sys.path.append(os.path.abspath(".."))
        import salientsdk as sk
    else:
        raise ModuleNotFoundError("Install salient SDK with: pip install salientsdk")

# Prevent wrapping on tables for readability
pd.set_option("display.width", None)
pd.set_option("display.max_columns", None)
pd.set_option("display.expand_frame_repr", False)

sk.set_file_destination("validation_example")
sk.login("username", "password")

<requests.sessions.Session at 0x7f7fabf06fd0>

## Customize The Validation

This notebook is written flexibly so that you can validate Salient and other forecasts in multiple ways. These variables control what, when, and how the validation proceeds:


In [None]:
# 1. Set the meteorological variable that we'll be evaluating:
var = "temp"
# var = "precip"

# 2. Set the forecast look-ahead amount:
timescale = "sub-seasonal"  # weeks 1-5
# timescale = "seasonal"  # months 1-3
# timescale = "long-range" # quarters 1-4

# 3. Set the number of hindcasts to download for validation:
split_set = "sample"  # fast demonstration of mechanics
# split_set = "test"  # recommended to validate out-of-sample with hindcast_summary
# split_set = "all"  # download hindcasts since 2000

(start_date, end_date) = {
    "sample": ("2021-04-01", "2021-08-31"),
    "test": ("2015-01-01", "2022-12-31"),
    "all": ("2000-01-01", "2022-12-31"),
}[split_set]

# 4. Decide whether or not to validate vs station observations:
validate_obs = True  # Calculate skill of debiased forecast vs met stations
# validate_obs = False  # Calculate skill of undebiased forecast vs ERA5


fld = "vals"
model = "blend"  # Validate the primary Salient blend model
ref_model = "clim"  # Works across all timescale values.
force = False  # If "False", cache data calls.  Set to "True" to overwrite caches

## Set the Area of Interest

The Salient SDK uses a "Location" object to specify the geographic bounds of a request. In this case, we will be validating against a vector of airport locations that are used to settle the Chicago Mercantile Exchange's Cooling and Heating Degree Day contracts. The SDK contains an example file with all of the locations pre-defined, as well as additional useful information. With `load_location_file` we can see that the file contains:

- `lat` / `lon`: latitude and longitude of the met station, standard for a `location_file`
- `name`: the 3-letter IATA airport code of the location, also `location_file` standard
- `ghcnd`: the Global Historical Climatology Network daily (GHCNd) ID of the station
- `cme`: the CME code for the location used to create CDD/HDD strips
- `description`: full name of the airport


In [None]:
# fmt: off
loc = sk.Location(location_file=sk.upload_location_file(
    lats =[33.62972     ,      42.36057,      34.19966,      41.96017,      39.04443,      32.89744,      29.98438,      36.07190,      44.88523,      40.77945,      39.87326,      45.59578,      38.50659],
    lons =[-84.44224    ,     -71.00975,    -118.36543,     -87.93164,     -84.67241,     -97.02196,     -95.36072,    -115.16343,     -93.23133,     -73.88027,     -75.22681,    -122.60919,    -121.49604],
    names=["ATL"        ,         "BOS",         "BUR",         "ORD",         "CVG",         "DFW",         "IAH",         "LAS",         "MSP",         "LGA",         "PHL",         "PDX",         "SAC"],
    ghcnd=["USW00013874", "USW00014739", "USW00023152", "USW00094846", "USW00093814", "USW00003927", "USW00012960", "USW00023169", "USW00014922", "USW00014732", "USW00013739", "USW00024229", "USW00023232"],
    cme  =["1"          ,           "W",           "P",           "2",           "3",           "5",           "R",           "0",           "Q",           "4",           "6",           "7",           "S"],
    geoname="cmeus",
    force=force,
    description=["Atlanta Hartsfield", "Boston Logan", "Burbank-Glendale-Pasadena", "Chicago O'Hare", "Cincinnati (Covington)","Dallas-Fort Worth", "Houston-George Bush", "Las Vegas McCarran", "Minneapolis-StPaul", "New York La Guardia","Philadelphia", "Portland", "Sacramento Exec"],
))
# fmt: on
stations = loc.load_location_file()
print(stations)

         lat        lon name        ghcnd cme                description                     geometry
0   33.62972  -84.44224  ATL  USW00013874   1         Atlanta Hartsfield   POINT (-84.44224 33.62972)
1   42.36057  -71.00975  BOS  USW00014739   W               Boston Logan   POINT (-71.00975 42.36057)
2   34.19966 -118.36543  BUR  USW00023152   P  Burbank-Glendale-Pasadena  POINT (-118.36543 34.19966)
3   41.96017  -87.93164  ORD  USW00094846   2             Chicago O'Hare   POINT (-87.93164 41.96017)
4   39.04443  -84.67241  CVG  USW00093814   3     Cincinnati (Covington)   POINT (-84.67241 39.04443)
5   32.89744  -97.02196  DFW  USW00003927   5          Dallas-Fort Worth   POINT (-97.02196 32.89744)
6   29.98438  -95.36072  IAH  USW00012960   R        Houston-George Bush   POINT (-95.36072 29.98438)
7   36.07190 -115.16343  LAS  USW00023169   0         Las Vegas McCarran   POINT (-115.16343 36.0719)
8   44.88523  -93.23133  MSP  USW00014922   Q         Minneapolis-StPaul   POINT (

## Pull precomputed skill

Salient pre-calculates skill metrics as part of our internal validation and model improvement process. Use the `hindcast_summary` API endpoint to request these pre-calculated skill scores. This is the "easy" way to validate Salient's forecasts: we've already done all the work for you.

The remainder of this notebook will show you how to reproduce this skill calculation by requesting historical forecasts, historical data, and calculating a skill score.


In [None]:
# hindcast_summary prefers single lat/lon values, so we'll iterate over each geo-pair.
skill_summ = pd.concat(
    [
        pd.read_csv(
            sk.hindcast_summary(
                loc=sk.Location(lat=row["lat"], lon=row["lon"]),
                interp_method="linear",
                metric="crps",
                variable=var,
                timescale=timescale,
                reference=ref_model,
                split_set="test" if split_set == "sample" else split_set,
                verbose=False,
                force=force,
            )
        )
        .assign(Location=row["name"])
        .drop(columns=["Reference Model"])
        .set_index(["Location", "Lead"])
        for _, row in stations.iterrows()
    ],
    ignore_index=False,
)
print(skill_summ)

                 Reference CRPS  Salient CRPS  Salient CRPS Skill Score (%)
Location Lead                                                              
ATL      Week 1            1.52          0.48                          68.7
         Week 2            1.52          1.06                          30.1
         Week 3            1.53          1.36                          11.0
         Week 4            1.52          1.41                           7.5
         Week 5            1.52          1.41                           7.5
...                         ...           ...                           ...
SAC      Week 1            1.22          0.49                          60.0
         Week 2            1.22          0.89                          27.4
         Week 3            1.23          1.09                          11.3
         Week 4            1.23          1.12                           8.9
         Week 5            1.23          1.12                           9.2

[65 rows x 

In [None]:
def plot_skill(
    df: pd.DataFrame, col: str = "Salient CRPS Skill Score (%)", title: str = "Skill"
) -> None:
    """Plot skill scores by lead time, with one line per location."""
    df = df.reset_index()

    plt.figure(figsize=(12, 6))
    for location in df["Location"].unique():
        subset = df[df["Location"] == location]
        plt.plot(subset["Lead"], subset[col], label=location)

    plt.title(title)
    plt.xlabel("Lead")
    plt.ylabel(col)
    plt.legend(title="Location", bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.tight_layout()
    plt.show()


plot_skill(skill_summ, title=f"hindcast_summary crps {model} vs {ref_model}")

## Historical Data

To calculate forecast skill, we will want to compare forecasts made in the past with actuals. There are two flavors of actual data: (1) The ERA5 reanalysis dataset and (2) point weather station observations.

Salient's forecast natively predicts (1) ERA5, but contains a debiasing function to remove bias between ERA5 and (2) station observations.


### Historical ERA5 Data

Download daily historical values from [`data_timeseries`](https://sdk.salientpredictions.com/api/#salientsdk.data_timeseries) and then aggregate to match the forecasts, so that we can ensure that all forecasts use the same dates.

Also, get observed weather station data in the same format by downloading directly from NCEI.
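
As a rough illustration of the aggregation step, the sketch below averages a synthetic daily series over consecutive 7-day windows, which is one way a daily history could be lined up with weekly lead times. The helper `weekly_means`, the toy data, and the assumption that week 1 starts on the forecast date are all illustrative. This is not the SDK's implementation.

```python
import numpy as np
import pandas as pd


def weekly_means(daily: pd.Series, forecast_date: str, n_leads: int = 5) -> pd.Series:
    """Average a daily series over consecutive 7-day windows (hypothetical helper)."""
    start = pd.Timestamp(forecast_date)
    means = {}
    for lead in range(1, n_leads + 1):
        window_start = start + pd.Timedelta(days=7 * (lead - 1))
        window_end = window_start + pd.Timedelta(days=6)  # label slicing is inclusive
        means[lead] = daily.loc[window_start:window_end].mean()
    return pd.Series(means, name="weekly_mean").rename_axis("lead_weekly")


# Synthetic daily values 0, 1, 2, ... starting on the forecast date
days = pd.date_range("2021-04-01", periods=35, freq="D")
daily = pd.Series(np.arange(35, dtype=float), index=days)
result = weekly_means(daily, "2021-04-01")
print(result)
```

Aggregating observations with the same windows as the forecasts ensures that each lead time is scored against the matching days.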


In [None]:
# Get additional historical data beyond end_date to make sure we have enough
# observed days to compare with the final forecast.
duration = {"sub-seasonal": 8 * 5, "seasonal": 31 * 3, "long-range": 95 * 4}[timescale]
hist = sk.data_timeseries(
    loc=loc,
    variable=var,
    field=fld,
    start=np.datetime64(start_date) - np.timedelta64(5, "D"),
    end=np.datetime64(end_date) + np.timedelta64(duration, "D"),
    frequency="daily",
    # reference_clim="30_yr",  implicitly uses 30 yr climatology
    verbose=False,
    force=force,
)
print(xr.load_dataset(hist))

<xarray.Dataset> Size: 22kB
Dimensions:   (time: 198, location: 13)
Coordinates:
  * time      (time) datetime64[ns] 2kB 2021-03-27 2021-03-28 ... 2021-10-10
  * location  (location) object 104B 'ATL' 'BOS' 'BUR' ... 'PHL' 'PDX' 'SAC'
    lat       (location) float64 104B 33.63 42.36 34.2 ... 39.87 45.6 38.51
    lon       (location) float64 104B -84.44 -71.01 -118.4 ... -122.6 -121.5
Data variables:
    vals      (time, location) float64 21kB 21.77 10.66 14.5 ... 12.31 16.97
Attributes:
    long_name:   2 metre temperature
    units:       degC
    clim_start:  1990-01-01
    clim_end:    2019-12-31


### Historical Observed Data

Get observed historical data from meteorological stations. In this case, we'll write a `get_ghcnd` function that downloads observed station data and returns a list of pandas `DataFrame`s. If you have observed data from a proprietary source, you can substitute a function here that returns a vector of `DataFrame`s. Make sure that the `DataFrame`s have a column that matches the `variable` of interest (such as `temp`).
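
For example, a drop-in replacement for a proprietary source might look like the sketch below. The function name `get_my_obs`, the raw column names `obs_date` and `t_mean_c`, and the in-memory toy data are all hypothetical. The only requirements taken from this notebook are a `time` column and a column named after `var`; `make_observed_ds` may also use station metadata columns such as `lat`, `lon`, and `elev`, so check its documentation.

```python
import pandas as pd


def get_my_obs(raw_tables: list[pd.DataFrame], variable: str = "temp") -> list[pd.DataFrame]:
    """Convert hypothetical proprietary tables into one DataFrame per station."""
    out = []
    for raw in raw_tables:
        # Map the (hypothetical) source column names onto the names used in this notebook
        df = raw.rename(columns={"obs_date": "time", "t_mean_c": variable})
        df["time"] = pd.to_datetime(df["time"])
        out.append(df[["time", variable]].sort_values("time").reset_index(drop=True))
    return out


# Stand-in for data that would normally come from a private file or database
raw_a = pd.DataFrame({"obs_date": ["2021-04-02", "2021-04-01"], "t_mean_c": [6.7, 8.4]})
raw_b = pd.DataFrame({"obs_date": ["2021-04-01", "2021-04-02"], "t_mean_c": [10.3, 9.8]})
my_obs_df = get_my_obs([raw_a, raw_b])
print(my_obs_df[0])
```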


#### Example Function: GHCNd Observed


In [None]:
from collections.abc import Iterable
import requests


def get_ghcnd(
    ghcnd_id: str | Iterable[str],
    start_date: str = "2000-01-01",
    xdd: float = (65 - 32) * 5 / 9,
    destination: str = "-default",
    force: bool = False,
) -> pd.DataFrame | list[pd.DataFrame]:
    """Get GHCNd observed data timeseries for one or more stations.

    Global Historical Climatology Network - Daily.

    Args:
        ghcnd_id (str | list[str]): GHCND station ID or list of IDs
        start_date (str): omit data before this date
        xdd (float): base temperature for heating/cooling degree days, in degC.
        destination (str): The directory to download the data to
        force (bool): if True (default False), force update of observed data

    Returns:
        pd.DataFrame | list[pd.DataFrame]: observed data timeseries with
        columns `time`, `precip`, `wspd`, `tmin`, `tmax`, `temp`, `hdd` and `cdd`
        (temperatures in `degC`).  Also meta-data columns for `ghcnd_id`, `lat`, `lon`, `elev`, and `name`.
        Will return a list of DataFrames if ghcnd_id is iterable.


    Examples:
        >>> obs_single = get_ghcnd("USW00013874")
        >>> obs_vector = get_ghcnd(["USW00013874", "USW00014739"])
    """
    if isinstance(ghcnd_id, Iterable) and not isinstance(ghcnd_id, str):
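        # Pass arguments by keyword so that `force` is not mistaken for `destination`
        return [
            get_ghcnd(single_id, start_date=start_date, xdd=xdd, destination=destination, force=force)
            for single_id in ghcnd_id
        ]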

    file_name = os.path.join(sk.get_file_destination(), f"{ghcnd_id}.csv")
    if force or not os.path.exists(file_name):
        GHCND_ROOT = "https://www.ncei.noaa.gov"
        GHCND_DIR = "data/global-historical-climatology-network-daily/access"
        ghcnd_url = f"{GHCND_ROOT}/{GHCND_DIR}/{ghcnd_id}.csv"
        r = requests.get(ghcnd_url)
        r.raise_for_status()
        with open(file_name, "w") as f:
            f.write(r.text)

    # Gusts: WSF1, WSF2, WSF5 are fastest 1-min, 2-min, 5-sec wind speed
    # Humidity: RHAV, RHMN, RHMX
    keep_cols = {
        "STATION": "ghcnd_id",
        "DATE": "time",
        "LATITUDE": "lat",
        "LONGITUDE": "lon",
        "ELEVATION": "elev",  # meters
        "NAME": "name",
        "TMAX": "tmax",
        "TMIN": "tmin",
        "PRCP": "precip",
        "TAVG": "temp",
        "AWND": "wspd",
    }

    obs = pd.read_csv(file_name, usecols=keep_cols.keys())
    obs.rename(columns=keep_cols, inplace=True)

    # ncei uses 9999 for missing data
    INVALID_NUMBER = 9999
    obs.replace(INVALID_NUMBER, pd.NA, inplace=True)

    # NCEI stores these values in tenths of a unit (e.g. tenths of degC, tenths of mm)
    columns_to_decimalize = ["precip", "tmax", "tmin", "wspd", "temp"]
    for col in columns_to_decimalize:
        if col in obs.columns:
            obs[col] = obs[col] / 10.0

    obs["time"] = pd.to_datetime(obs["time"])
    # Copy after filtering so the column assignments below don't act on a view
    obs = obs[obs["time"] >= pd.Timestamp(start_date)].copy()

    # Only calculate degree days if both tmin and tmax are available
    if "tmin" in obs.columns and "tmax" in obs.columns:
        # Note that these long_names are identical to what data_timeseries returns:
        obs["tmax"].attrs["units"] = "degC"
        obs["tmax"].attrs["long_name"] = "2 metre daily temperature"

        obs["tmin"].attrs["units"] = "degC"
        obs["tmin"].attrs["long_name"] = "2 metre daily temperature"

        # HDD & CDD settle off the mean of tmin/tmax, not the daily mean
        temp = pd.concat([obs["tmin"], obs["tmax"]], axis=1).mean(axis=1)
        obs["cdd"] = (temp - xdd).clip(lower=0).round(1)
        obs["hdd"] = (xdd - temp).clip(lower=0).round(1)

        # Some stations such as USW00023152 don't record temp but do have tmin/tmax.
        # Replace missing temp with mean(tmin,tmax) as an approximation.
        temp_na_mask = obs["temp"].isna()
        obs.loc[temp_na_mask, "temp"] = temp[temp_na_mask]

        obs["hdd"].attrs["units"] = "HDD day-1 (degC)"
        obs["hdd"].attrs["long_name"] = "Heating Degree Days"

        obs["cdd"].attrs["units"] = "CDD day-1 (degC)"
        obs["cdd"].attrs["long_name"] = "Cooling Degree Days"

    obs["temp"].attrs["units"] = "degC"
    obs["temp"].attrs["long_name"] = "2 metre temperature"

    obs["precip"].attrs["units"] = "mm day-1"
    obs["precip"].attrs["long_name"] = "Total precipitation"

    obs["wspd"].attrs["units"] = "m s**-1"
    obs["wspd"].attrs["long_name"] = "Wind Speed"

    return obs

In [None]:
if validate_obs:
    obs_df = get_ghcnd(stations.ghcnd, start_date=start_date, force=force)
    # Output is a vector of DataFrames, one per station.  Let's inspect the first:
    print(obs_df[0])
else:
    print("skipped: not comparing to observed data")
    obs_df = None

          ghcnd_id       time       lat       lon   elev                                               name  precip  tmax  tmin  wspd  temp  cdd   hdd
33327  USW00013874 2021-04-01  33.62972 -84.44224  308.2  ATLANTA HARTSFIELD JACKSON INTERNATIONAL AIRPO...     0.0  11.7   4.4   8.8   8.4  0.0  10.3
33328  USW00013874 2021-04-02  33.62972 -84.44224  308.2  ATLANTA HARTSFIELD JACKSON INTERNATIONAL AIRPO...     0.0  13.9   0.6   4.0   6.7  0.0  11.1
33329  USW00013874 2021-04-03  33.62972 -84.44224  308.2  ATLANTA HARTSFIELD JACKSON INTERNATIONAL AIRPO...     0.0  17.8   1.7   1.3   9.1  0.0   8.6
33330  USW00013874 2021-04-04  33.62972 -84.44224  308.2  ATLANTA HARTSFIELD JACKSON INTERNATIONAL AIRPO...     0.0  22.8   4.4   2.5  13.8  0.0   4.7
33331  USW00013874 2021-04-05  33.62972 -84.44224  308.2  ATLANTA HARTSFIELD JACKSON INTERNATIONAL AIRPO...     0.0  25.6   9.4   2.5  17.5  0.0   0.8
...            ...        ...       ...       ...    ...                                      
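
The `hdd` values in the first two rows above can be checked by hand. The sketch below repeats the degree-day arithmetic from `get_ghcnd` on those rows: average `tmin` and `tmax`, then measure the distance from the 65 degF base (about 18.33 degC). The helper name `degree_days` is illustrative.

```python
def degree_days(tmin_c: float, tmax_c: float, base_c: float = (65 - 32) * 5 / 9):
    """Return (hdd, cdd) in degC-days from the mean of daily tmin and tmax."""
    temp = (tmin_c + tmax_c) / 2
    hdd = round(max(base_c - temp, 0.0), 1)
    cdd = round(max(temp - base_c, 0.0), 1)
    return hdd, cdd


# ATL on 2021-04-01 and 2021-04-02, taken from the printed DataFrame
first_day = degree_days(tmin_c=4.4, tmax_c=11.7)
second_day = degree_days(tmin_c=0.6, tmax_c=13.9)
print(first_day, second_day)
```

Both calls reproduce the 10.3 and 11.1 HDD shown in the table, with zero CDD on each day.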

Use the `make_observed_ds` function to reformat the `DataFrame`s of observed data into an `xarray.Dataset` with the same format as the historical timeseries returned by `data_timeseries`.


In [None]:
if validate_obs:
    obs = sk.observed.make_observed_ds(
        obs_df=obs_df,  # a DataFrame or vector of DataFrames
        name=stations.name,  # this will populate the location coordinate
        variable=var,  # make sure that the DataFrame(s) in obs_df contain this column name
        time_col="time",  # the name of the column in obs_df containing the date
    )
    print(obs)
else:
    print("skipped: not comparing to observed data")
    obs = None

<xarray.Dataset> Size: 146kB
Dimensions:       (time: 1303, location: 13)
Coordinates:
  * time          (time) datetime64[ns] 10kB 2021-04-01 ... 2024-10-24
  * location      (location) <U3 156B 'ATL' 'BOS' 'BUR' ... 'PHL' 'PDX' 'SAC'
    lat_station   (location) float64 104B 33.63 42.36 34.2 ... 39.87 45.6 38.51
    lon_station   (location) float64 104B -84.44 -71.01 -118.4 ... -122.6 -121.5
    elev_station  (location) float64 104B 308.2 3.2 222.7 204.8 ... 2.1 6.7 5.9
Data variables:
    vals          (time, location) float64 136kB 8.4 10.3 22.5 ... 17.5 8.1 16.8
Attributes:
    short_name:  temp
    units:       degC
    long_name:   2 metre temperature


### Compare observed and ERA5 datasets

Via `make_observed_ds`, the observed station data (`obs`) is formatted the same as the ERA5 historical data (`hist`). This means we can easily compare one to the other and see the degree of bias that exists between the two.
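
The comparison boils down to a difference followed by a centered rolling mean. The pandas sketch below reproduces that logic on synthetic data, where the toy station runs exactly 1 degC warmer than the toy reanalysis. The variable names and values are made up for illustration.

```python
import numpy as np
import pandas as pd

# Synthetic daily series: station observations run 1 degC warmer than reanalysis
time = pd.date_range("2021-04-01", periods=30, freq="D")
hist_toy = pd.Series(10.0 + np.sin(np.arange(30) / 5.0), index=time)
obs_toy = hist_toy + 1.0

delta_raw = obs_toy - hist_toy
# Centered rolling mean over ~10% of the record, mirroring the notebook's smoothing
window = max(1, int(len(delta_raw) * 0.1))
delta = delta_raw.rolling(window, center=True).mean()

print(f"window = {window} days, mean bias = {delta.mean():.2f} degC")
print(f"NaNs introduced at the edges: {int(delta.isna().sum())}")
```

A centered window cannot be filled at the very start and end of the record, which is why the smoothed `delta` begins and ends with NaN values.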


In [None]:
if validate_obs:
    # Pull observed and ERA5 into a single dataset for easy comparison:
    merged = xr.merge(
        [
            obs.rename({"vals": "obs"}),
            xr.load_dataset(hist).rename({"vals": "hist"}),
        ],
        join="inner",
    )
    merged["delta_raw"] = merged["obs"] - merged["hist"]
    # Daily bias is noisy.  Smooth it out to make trends clearer.
    # This will induce nans at the beginning and end of the timeseries.
    merged["delta"] = (
        merged["delta_raw"]
        .rolling(time=max(1, int(len(merged["time"]) * 0.1)), center=True)
        .mean()
    )

    print(merged)
else:
    print("skipped: not comparing to observed data")
    merged = None

<xarray.Dataset> Size: 83kB
Dimensions:       (time: 193, location: 13)
Coordinates:
  * time          (time) datetime64[ns] 2kB 2021-04-01 2021-04-02 ... 2021-10-10
  * location      (location) <U3 156B 'ATL' 'BOS' 'BUR' ... 'PHL' 'PDX' 'SAC'
    lat_station   (location) float64 104B 33.63 42.36 34.2 ... 39.87 45.6 38.51
    lon_station   (location) float64 104B -84.44 -71.01 -118.4 ... -122.6 -121.5
    elev_station  (location) float64 104B 308.2 3.2 222.7 204.8 ... 2.1 6.7 5.9
    lat           (location) float64 104B 33.63 42.36 34.2 ... 39.87 45.6 38.51
    lon           (location) float64 104B -84.44 -71.01 -118.4 ... -122.6 -121.5
Data variables:
    obs           (time, location) float64 20kB 8.4 10.3 22.5 ... 18.8 14.0 15.7
    hist          (time, location) float64 20kB 8.25 8.943 21.13 ... 12.31 16.97
    delta_raw     (time, location) float64 20kB 0.1502 1.357 ... 1.688 -1.273
    delta         (time, location) float64 20kB nan nan nan nan ... nan nan nan
Attributes:
    sh

In [None]:
if validate_obs:
    # Visualize obs-era5 bias over time at each station
    merged["delta"].plot.line(x="time", hue="location")
    plt.axhline(y=0, color="k", linestyle="--", alpha=0.5)
    plt.title("Delta of Observed to Historical Values")
    plt.ylabel(f'{merged.attrs["long_name"]} obs - hist [{merged.attrs["units"]}]')
    plt.xlabel("")
    plt.grid(True, alpha=0.3)
    plt.show()

## Forecast

The [`forecast_timeseries`](https://sdk.salientpredictions.com/api/#salientsdk.forecast_timeseries) API endpoint and SDK function returns Salient's native temporally granular weekly/monthly/quarterly forecasts.

This is the most heavyweight call in the notebook, since it retrieves many historical forecasts. In the customization cell above, we set a `split_set` variable that controls the amount of data to request via the `start_date` and `end_date` variables:

- `sample` - April through August 2021, good for quickly making sure that the mechanics of the process work.
- `test` - gets data from 2015-2022, which is guaranteed out-of-sample from the training process. This requests a medium amount of data and is recommended for most validation processes.
- `all` - gets data from 2000-2022, representing the full historical evaluation record. This will download quite a bit of data and is not recommended for most applications.


In [None]:
# get_hindcast_dates is a utility that returns all valid hindcast initializations.
date_range = sk.get_hindcast_dates(start_date=start_date, end_date=end_date, timescale=timescale)

fcst = sk.forecast_timeseries(
    loc=loc,
    variable=var,
    field=fld,
    date=date_range,
    timescale=timescale,
    model=[model, ref_model],
    reference_clim="30_yr",  # this is the same climatology used by data_timeseries
    debias=False,
    verbose=False,
    force=force,
    strict=False,  # There is missing data in 2020.  Work around it.
)
fcst["debias"] = False

In [None]:
# Get debiased hindcasts, if we are validating vs observations
if validate_obs:
    fcst_debias = sk.forecast_timeseries(
        loc=loc,
        variable=var,
        field=fld,
        date=date_range,
        timescale=timescale,
        model=model,  # debias only available for model blend
        reference_clim="30_yr",
        debias=True,  # Note debias
        verbose=False,
        force=force,
        strict=False,
    )
    fcst_debias["model"] = model
    fcst_debias["debias"] = True
    fcst = pd.concat([fcst, fcst_debias], axis=0)

print(fcst)

                                            file_name        date  model  debias
0   validation_example/forecast_timeseries_2080b77...  2021-04-01  blend   False
1   validation_example/forecast_timeseries_d2f085b...  2021-04-01   clim   False
2   validation_example/forecast_timeseries_3e8ae2d...  2021-04-04  blend   False
3   validation_example/forecast_timeseries_cbd0d20...  2021-04-04   clim   False
4   validation_example/forecast_timeseries_f8f2b29...  2021-04-07  blend   False
..                                                ...         ...    ...     ...
62  validation_example/forecast_timeseries_c53d5ea...  2021-08-21  blend    True
63  validation_example/forecast_timeseries_6edae86...  2021-08-24  blend    True
64  validation_example/forecast_timeseries_86981d1...  2021-08-25  blend    True
65  validation_example/forecast_timeseries_ea48660...  2021-08-28  blend    True
66  validation_example/forecast_timeseries_07800f3...  2021-08-31  blend    True

[201 rows x 4 columns]


In [None]:
# Check to see if there are any missing forecasts:
fcst_na = fcst[fcst["file_name"].isna()]
if not fcst_na.empty:
    print("Missing forecast dates:")
    print(fcst_na)

# Example forecast file is for a single model and a single forecast_date
print(xr.load_dataset(fcst["file_name"].values[0]))

<xarray.Dataset> Size: 13kB
Dimensions:                 (quantiles: 23, location: 13, lead_weekly: 5,
                             nbnds: 2)
Coordinates:
  * quantiles               (quantiles) float64 184B 0.01 0.025 ... 0.975 0.99
  * location                (location) object 104B 'ATL' 'BOS' ... 'PDX' 'SAC'
    lat                     (location) float64 104B 33.63 42.36 ... 45.6 38.51
    lon                     (location) float64 104B -84.44 -71.01 ... -121.5
    forecast_period_weekly  (lead_weekly, nbnds) datetime64[ns] 80B 2021-04-0...
  * lead_weekly             (lead_weekly) int32 20B 1 2 3 4 5
    forecast_date_weekly    datetime64[ns] 8B 2021-04-01
Dimensions without coordinates: nbnds
Data variables:
    vals_weekly             (lead_weekly, location, quantiles) float64 12kB 9...
Attributes:
    clim_period:  ['1990-01-01', '2019-12-31']
    region:       north-america
    short_name:   temp
    timescale:    sub-seasonal


## Calculate Skill Metrics

Compare the forecast and observed datasets to see how well they match.
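
Because each forecast file stores a set of quantiles rather than ensemble members, one common way to score it is through the pinball (quantile) loss, using the identity CRPS = 2 x (integral of the pinball loss over quantile levels from 0 to 1). The sketch below approximates that integral with the trapezoid rule over the available levels. It is an illustration only: `crps_from_quantiles` is a hypothetical helper, the quantile values are made up, and `sk.skill.crps` may use a different numerical scheme.

```python
import numpy as np


def crps_from_quantiles(levels, values, observation: float) -> float:
    """Approximate CRPS from forecast quantiles via the pinball loss (hypothetical helper)."""
    levels = np.asarray(levels, dtype=float)
    values = np.asarray(values, dtype=float)
    # Pinball loss at each quantile level tau: (1[obs < q] - tau) * (q - obs)
    indicator = (observation < values).astype(float)
    pinball = (indicator - levels) * (values - observation)
    # Trapezoid rule over the quantile levels; the tails outside the first and
    # last level are omitted from the integral.
    integral = np.sum(0.5 * (pinball[1:] + pinball[:-1]) * np.diff(levels))
    return float(2.0 * integral)


levels = [0.1, 0.25, 0.5, 0.75, 0.9]
centered = crps_from_quantiles(levels, [18.0, 19.0, 20.0, 21.0, 22.0], observation=20.0)
shifted = crps_from_quantiles(levels, [21.0, 22.0, 23.0, 24.0, 25.0], observation=20.0)
print(f"centered: {centered:.3f}, shifted: {shifted:.3f}")
```

The forecast whose median sits on the observation scores lower (better) than the same distribution shifted 3 degrees away.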


In [None]:
skill_fcst = sk.skill.crps(
    observations=hist,
    forecasts=fcst[(fcst["model"] == model) & (fcst["debias"] == False)],
)

In [None]:
skill_ref = sk.skill.crps(
    observations=hist,
    forecasts=fcst[(fcst["model"] == ref_model) & (fcst["debias"] == False)],
)
skills = [
    skill_ref.assign_coords(source=ref_model),
    skill_fcst.assign_coords(source=model),
]

In [None]:
if validate_obs:
    skill_obs = sk.skill.crps(
        observations=obs,
        forecasts=fcst[(fcst["model"] == model) & (fcst["debias"] == True)],
    )
    skills.append(skill_obs.assign_coords(source="debiased"))

skills = xr.concat(skills, dim="source").round(2)
print(skills)

<xarray.Dataset> Size: 2kB
Dimensions:       (location: 13, lead_weekly: 5, source: 3)
Coordinates:
  * location      (location) object 104B 'ATL' 'BOS' 'BUR' ... 'PHL' 'PDX' 'SAC'
    lat           (location) float64 104B 33.63 42.36 34.2 ... 39.87 45.6 38.51
    lon           (location) float64 104B -84.44 -71.01 -118.4 ... -122.6 -121.5
  * lead_weekly   (lead_weekly) int32 20B 1 2 3 4 5
  * source        (source) <U8 96B 'clim' 'blend' 'debiased'
    lat_station   (location) float64 104B 33.63 42.36 34.2 ... 39.87 45.6 38.51
    lon_station   (location) float64 104B -84.44 -71.01 -118.4 ... -122.6 -121.5
    elev_station  (location) float64 104B 308.2 3.2 222.7 204.8 ... 2.1 6.7 5.9
Data variables:
    crps          (source, lead_weekly, location) float64 2kB 0.79 1.2 ... 1.02
Attributes:
    clim_period:  ['1990-01-01', '2019-12-31']
    region:       north-america
    short_name:   crps
    timescale:    sub-seasonal
    long_name:    CRPS


In [None]:
skills["crps"].plot.line(
    x="lead_weekly", hue="source", col="location", col_wrap=3, figsize=(10, 10)
)
plt.suptitle(f"{var} CRPS", fontsize=16)
plt.subplots_adjust(top=0.9)
plt.show()

### Calculate Relative Skill

CRPS shows skill without context. A "skill score" will compare two different skills to generate relative value. In the example below, we will compare the Salient blend with climatology (historical averages).
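
The conventional definition of a CRPS skill score is CRPSS = 1 - CRPS_forecast / CRPS_reference, so 0 means no improvement over the reference and 1 means a perfect forecast. The sketch below assumes that `sk.skill.crpss` follows this convention and applies it to the rounded ATL week 1 and week 3 values from the manually-calculated table in this notebook.

```python
def crpss(crps_forecast: float, crps_reference: float) -> float:
    """Relative skill of a forecast vs a reference; positive means the forecast is better."""
    return 1.0 - crps_forecast / crps_reference


# Rounded CRPS values for ATL from the manually-calculated table in this notebook
week1 = crpss(crps_forecast=0.36, crps_reference=0.79)
week3 = crpss(crps_forecast=0.75, crps_reference=0.69)
print(f"Week 1: {week1 * 100:.1f}%  Week 3: {week3 * 100:.1f}%")
```

The results land close to the 54.6 and -8.2 reported in the table; the small gap comes from starting with CRPS values that were already rounded to two decimals. A negative score means the forecast did worse than climatology for that lead.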


In [None]:
skill_score = sk.skill.crpss(forecast=skill_fcst, reference=skill_ref)

# Represent the skill scores as a human-readable table of the same format as we generated
# for the hindcast_summary results.
skill_table = (
    xr.merge(
        [
            (skill_ref.rename({"crps": "Reference CRPS"})).round(2),
            skill_fcst.rename({"crps": "Salient CRPS"}).round(2),
            (skill_score * 100).rename({"crpss": "Salient CRPS Skill Score (%)"}).round(1),
        ]
    )
    .to_dataframe()
    .reset_index()
    .dropna(how="any")
    .drop(columns=["lat", "lon"])
    .rename(columns={"location": "Location", "lead_weekly": "Lead"})
)
skill_table["Lead"] = "Week " + skill_table["Lead"].astype(str)
skill_table.set_index(["Location", "Lead"], inplace=True)

print(skill_table)

                 Reference CRPS  Salient CRPS  Salient CRPS Skill Score (%)
Location Lead                                                              
ATL      Week 1            0.79          0.36                          54.6
         Week 2            0.73          0.69                           6.2
         Week 3            0.69          0.75                          -8.2
         Week 4            0.66          0.73                          -9.4
         Week 5            0.66          0.77                         -15.7
...                         ...           ...                           ...
SAC      Week 1            1.13          0.44                          60.7
         Week 2            1.19          0.91                          23.8
         Week 3            1.18          1.03                          12.6
         Week 4            1.17          0.97                          17.1
         Week 5            1.15          0.97                          15.8

[65 rows x 

In [None]:
plot_skill(skill_table, title=f"manually-calculated crps {model} vs {ref_model}")

### Compare manually-calculated to pre-computed skill

Now that we have a CRPS calculated manually as well as downloaded from `hindcast_summary` we can evaluate how close the two values are.


In [None]:
skill_merge = pd.merge(
    skill_summ.add_prefix("Summary "),
    skill_table.add_prefix("Manual "),
    left_index=True,
    right_index=True,
)

print(skill_merge)

Now let's plot the manually-calculated skill scores against the precomputed skill scores published by `hindcast_summary`.

Note that when using `split_set = "sample"` the values won't match exactly. In that case we are plotting skill scores calculated from five months of 2021 forecasts against skill scores from the full `test` set (2015-2022).
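
One simple way to put a number on the mismatch is the mean absolute difference between the two columns. The sketch below rebuilds the index-aligned merge on two ATL rows taken from the tables printed earlier in this notebook. The names `summary_toy`, `manual_toy`, and `merged_toy` are illustrative.

```python
import pandas as pd

index = pd.MultiIndex.from_tuples(
    [("ATL", "Week 1"), ("ATL", "Week 2")], names=["Location", "Lead"]
)
col = "Salient CRPS Skill Score (%)"
# Values copied from the two ATL tables printed earlier in this notebook
summary_toy = pd.DataFrame({col: [68.7, 30.1]}, index=index)
manual_toy = pd.DataFrame({col: [54.6, 6.2]}, index=index)

# Merge on the shared (Location, Lead) index, prefixing columns to keep them distinct
merged_toy = pd.merge(
    summary_toy.add_prefix("Summary "),
    manual_toy.add_prefix("Manual "),
    left_index=True,
    right_index=True,
)
merged_toy["Difference"] = merged_toy[f"Manual {col}"] - merged_toy[f"Summary {col}"]
mean_abs_diff = merged_toy["Difference"].abs().mean()
print(f"Mean absolute difference: {mean_abs_diff:.1f} percentage points")
```

A large gap like this is expected for the `sample` split; with `split_set = "test"` the two sources cover the same period and should agree much more closely.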


In [None]:
def compare_cols(col_name: str) -> None:
    """Plot manual and precalculated skill columns."""
    summary_col = f"Summary {col_name}"
    manual_col = f"Manual {col_name}"

    df = skill_merge.reset_index()

    plt.figure(figsize=(10, 6))

    for location in df["Location"].unique():
        subset = df[df["Location"] == location]
        plt.scatter(subset[summary_col], subset[manual_col], label=location, s=100)

    # Same limits for both axes
    min_limit = min(df[summary_col].min(), df[manual_col].min())
    max_limit = max(df[summary_col].max(), df[manual_col].max())
    plt.xlim(min_limit, max_limit)
    plt.ylim(min_limit, max_limit)
    plt.plot(
        [min_limit, max_limit], [min_limit, max_limit], color="gray", linestyle="--", linewidth=1
    )
    plt.gca().set_aspect("equal", adjustable="box")

    plt.title(f"Summary vs Manual {col_name}")
    plt.xlabel(summary_col)
    plt.ylabel(manual_col)
    plt.legend(title="Location", bbox_to_anchor=(1.05, 1), loc="upper left")

    if split_set == "sample":
        plt.text(
            min_limit,
            max_limit,
            "Results not expected to match for split_set = 'sample'.\nUse split_set = 'test' for a full comparison.",
            fontsize=10,
            verticalalignment="top",
            horizontalalignment="left",
            color="red",
        )

    plt.tight_layout()

    # Show the plot
    plt.show()


compare_cols("Salient CRPS Skill Score (%)")
# compare_cols("Salient CRPS")
# compare_cols("Reference CRPS")