# Validate Downscale Skill

This example shows how to evaluate the skill of Salient's daily downscaled forecasts and compare it against a reference model. It demonstrates [validation best practices](https://salientpredictions.notion.site/Validation-0220c48b9460429fa86f577914ea5248) such as:

- Proper scoring using the Ensemble Continuous Ranked Probability Score (CRPS)
  - Scores the full forecast distribution, rewarding forecasts that are both accurate and sharp (precise)
  - Less sensitive to the choice of climatology than metrics like Anomaly Correlation
- A long backtesting period (2015-2022)
  - Short evaluation periods produce noisy skill estimates
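For intuition, the ensemble CRPS can be written as the mean absolute error of the members minus half the mean absolute difference between pairs of members. The sketch below implements that textbook estimator in NumPy. `sk.skill.crps_ensemble` may use a different variant (for example the "fair" CRPS), so treat this as an illustration of the idea rather than the SDK's implementation.

```python
import numpy as np


def crps_ensemble(members, obs: float) -> float:
    """Textbook ensemble CRPS: E|X - y| - 0.5 * E|X - X'| over the members.

    Illustrative only; not necessarily the estimator used by the Salient SDK.
    """
    x = np.asarray(members, dtype=float)
    # Accuracy term: average distance from each member to the observation
    accuracy = np.mean(np.abs(x - obs))
    # Spread term: average distance between every pair of members
    spread = np.mean(np.abs(x[:, None] - x[None, :]))
    return float(accuracy - 0.5 * spread)


obs = 2.0
print(crps_ensemble([1.0, 2.0, 3.0], obs))  # wide ensemble centered on obs
print(crps_ensemble([1.9, 2.0, 2.1], obs))  # sharp ensemble centered on obs
print(crps_ensemble([5.0], obs))            # single member: reduces to |x - y|
```

With a single member the score reduces to the absolute error, and of two ensembles centered on the observation, the sharper one scores lower (better).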


In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

try:
    import salientsdk as sk
except ModuleNotFoundError as e:
    if os.path.exists("../salientsdk"):
        sys.path.append(os.path.abspath(".."))
        import salientsdk as sk
    else:
        raise ModuleNotFoundError("Install salient SDK with: pip install salientsdk") from e

# Prevent wrapping on tables for readability
pd.set_option("display.width", None)
pd.set_option("display.max_columns", None)
pd.set_option("display.expand_frame_repr", False)

sk.set_file_destination("validate_ensemble_example")
sk.login("SALIENT_USERNAME", "SALIENT_PASSWORD")


## Customize The Validation

This notebook is configurable, so you can validate Salient and reference forecasts in several ways. The variables below control what is validated (variables and reference model), when (the forecast dates), and how (forecast length and the source of truth data).


In [None]:
# 1. The meteorological variables to evaluate:
vars = ["temp", "precip"]  # wspd, tsi

# 2. Number of days in the downscale:
length = 35  # fast 35-day evaluation vs gefs (also 35 days)
# length = 366  # comprehensive full-year evaluation

# 3. Whether to debias temp, precip, and wind toward station observations
debias = False  # Evaluate vs ERA5
# debias = True  # Evaluate vs observations from GHCNd

# 4. Set the date range to test over.
(start_date, end_date) = ("2021-04-01", "2021-07-31")  # fast "sample" dataset
# (start_date, end_date) = ("2015-01-01", "2022-12-31")  # out-of-sample "test" set
# (start_date, end_date) = ("2000-01-01", "2022-12-31")  # comprehensive "all-history" set

# 5. The reference model to compare the Salient blend against
# ref_model = "none"  # skip the reference model comparison
ref_model = "noaa_gefs"  # good for daily-frequency 35-day comparisons
# ref_model = "noaa_gfs"  # hourly-frequency comparisons

# 6. Variable to focus on for plots
plot_var = "temp"
# plot_var = "precip"
assert plot_var in vars


# ===== Additional shared variables ==========================
# Not recommended to change these.

# For specialized use only, Salient can manually provide a "bulk downscale" zarr
# bulk = os.path.join(sk.get_file_destination(), "bulk_downscale")  # zarr directory
bulk = None  # don't use bulk downscale (default)

# Temporal resolution of the downscaled & historical timeseries:
freq = "daily"
# freq = "hourly"

# Caching strategy:
force = False  # Cache data to save on repeat API calls
# force = True  # Repeat API calls, even if data exists

# Make all figures have a consistent size:
figsize = (8, 5)

# Determine the dates for which to request forecasts:
date_range = pd.date_range(start=start_date, end=end_date, freq="D")

# GEFS is unavailable for 2020 January-September
if "gefs" in ref_model:
    date_range = date_range[~((date_range.year == 2020) & (date_range.month <= 9))]

# Find first Wednesday on or after 16th of each month
date_range = (
    date_range[(date_range.dayofweek == 2) & (date_range.day >= 16)]
    .to_series()
    .groupby([lambda x: x.year, lambda x: x.month])
    .first()
    .dt.strftime("%Y-%m-%d")
    .tolist()
)

print("Forecast dates to test:")
print(date_range)

Forecast dates to test:
['2021-04-21', '2021-05-19', '2021-06-16', '2021-07-21']
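The pandas chain in the cell above keeps, for each month, the first Wednesday that falls on or after the 16th and inside the requested range. The same rule can be written with only the standard library, which makes it easy to check by hand. `monthly_forecast_dates` is a hypothetical helper for illustration, not part of the SDK.

```python
from datetime import date, timedelta


def monthly_forecast_dates(start: str, end: str, skip_gefs_gap: bool = False) -> list:
    """Hypothetical helper: first Wednesday on or after the 16th of each month in [start, end]."""
    day, last = date.fromisoformat(start), date.fromisoformat(end)
    chosen = {}  # (year, month) -> first qualifying date
    while day <= last:
        # Mirror the notebook's exclusion of January-September 2020 for GEFS
        in_gefs_gap = skip_gefs_gap and day.year == 2020 and day.month <= 9
        is_candidate = day.weekday() == 2 and day.day >= 16  # weekday(): Monday == 0
        if is_candidate and not in_gefs_gap:
            chosen.setdefault((day.year, day.month), day)
        day += timedelta(days=1)
    return [d.isoformat() for d in sorted(chosen.values())]


print(monthly_forecast_dates("2021-04-01", "2021-07-31"))
# ['2021-04-21', '2021-05-19', '2021-06-16', '2021-07-21'], matching the cell output above
```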


## Set the Area of Interest

The Salient SDK uses a "Location" object to specify the geographic bounds of a request. In this case, we validate at the set of airport locations used to settle the Chicago Mercantile Exchange's Cooling and Heating Degree Day (CDD/HDD) contracts. Loading the file back with `load_location_file` shows that it contains:

- `lat` / `lon`: latitude and longitude of the met station, standard for a `location_file`
- `name`: the 3-letter IATA airport code of the location, also `location_file` standard
- `ghcnd`: the Global Historical Climatology Network daily (GHCNd) station ID, used to validate against observations. To customize this analysis for a different set of stations, look up their IDs in the NCEI [stations list](https://www.ncei.noaa.gov/pub/data/ghcn/daily/ghcnd-stations.txt).
- `cme`: the CME code for the location used to create CDD/HDD strip codes.
- `description`: full name of the airport

If you have a list of locations already defined in a separate CSV file, you can use [`upload_file`](https://sdk.salientpredictions.com/api/#salientsdk.upload_file) to upload the file directly without building it in code via `upload_location_file`.
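As a sketch of preparing such a CSV, the standard library can serialize station records with the columns listed above. That an uploaded file needs exactly these column names is an assumption based on that list (`lat`, `lon`, and `name` are the standard `location_file` columns; the rest are extra metadata). `location_csv` is a hypothetical helper, and the upload step itself is left as a comment rather than guessing `upload_file`'s arguments.

```python
import csv
import io

# Columns described above; lat/lon/name are the standard location_file columns,
# the others are extra per-station metadata (assumed to be passed through as-is).
FIELDS = ["lat", "lon", "name", "ghcnd", "cme", "description"]


def location_csv(rows: list) -> str:
    """Hypothetical helper: serialize station records (dicts keyed by FIELDS) to CSV text."""
    buf = io.StringIO()
    writer = csv.DictWriter(buf, fieldnames=FIELDS)
    writer.writeheader()
    writer.writerows(rows)
    return buf.getvalue()


rows = [
    {"lat": 33.62972, "lon": -84.44224, "name": "ATL", "ghcnd": "USW00013874",
     "cme": "1", "description": "Atlanta Hartsfield"},
    {"lat": 42.36057, "lon": -71.00975, "name": "BOS", "ghcnd": "USW00014739",
     "cme": "W", "description": "Boston Logan"},
]
text = location_csv(rows)
print(text)
# Save `text` to a .csv file, then upload that file with sk.upload_file
# (not called here; check the SDK documentation for its exact arguments).
```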


In [None]:
# fmt: off
loc = sk.Location(location_file=sk.upload_location_file(
    lats =[33.62972     ,      42.36057,      34.19966,      41.96017,      39.04443,      32.89744,      29.98438,      36.07190,      44.88523,      40.77945,      39.87326,      45.59578,      38.50659],
    lons =[-84.44224    ,     -71.00975,    -118.36543,     -87.93164,     -84.67241,     -97.02196,     -95.36072,    -115.16343,     -93.23133,     -73.88027,     -75.22681,    -122.60919,    -121.49604],
    names=["ATL"        ,         "BOS",         "BUR",         "ORD",         "CVG",         "DFW",         "IAH",         "LAS",         "MSP",         "LGA",         "PHL",         "PDX",         "SAC"],
    ghcnd=["USW00013874", "USW00014739", "USW00023152", "USW00094846", "USW00093814", "USW00003927", "USW00012960", "USW00023169", "USW00014922", "USW00014732", "USW00013739", "USW00024229", "USW00023232"],
    cme  =["1"          ,           "W",           "P",           "2",           "3",           "5",           "R",           "0",           "Q",           "4",           "6",           "7",           "S"],
    geoname="cmeus",
    force=force,
    description=["Atlanta Hartsfield", "Boston Logan", "Burbank-Glendale-Pasadena", "Chicago O'Hare", "Cincinnati (Covington)","Dallas-Fort Worth", "Houston-George Bush", "Las Vegas McCarran", "Minneapolis-StPaul", "New York La Guardia","Philadelphia", "Portland", "Sacramento Exec"],
))
# fmt: on
stations = loc.load_location_file()
print(stations)

         lat        lon name        ghcnd cme                description                     geometry
0   33.62972  -84.44224  ATL  USW00013874   1         Atlanta Hartsfield   POINT (-84.44224 33.62972)
1   42.36057  -71.00975  BOS  USW00014739   W               Boston Logan   POINT (-71.00975 42.36057)
2   34.19966 -118.36543  BUR  USW00023152   P  Burbank-Glendale-Pasadena  POINT (-118.36543 34.19966)
3   41.96017  -87.93164  ORD  USW00094846   2             Chicago O'Hare   POINT (-87.93164 41.96017)
4   39.04443  -84.67241  CVG  USW00093814   3     Cincinnati (Covington)   POINT (-84.67241 39.04443)
5   32.89744  -97.02196  DFW  USW00003927   5          Dallas-Fort Worth   POINT (-97.02196 32.89744)
6   29.98438  -95.36072  IAH  USW00012960   R        Houston-George Bush   POINT (-95.36072 29.98438)
7   36.07190 -115.16343  LAS  USW00023169   0         Las Vegas McCarran   POINT (-115.16343 36.0719)
8   44.88523  -93.23133  MSP  USW00014922   Q         Minneapolis-StPaul   POINT (

## Historical Data

To calculate forecast skill, we compare forecasts made in the past with what actually happened. There are two sources of truth data: (1) the ERA5 reanalysis dataset and (2) point weather station observations.

Salient's forecasts natively predict (1) ERA5. Setting `debias=True` removes the bias between ERA5 and (2) station observations, so the forecasts can be validated directly against the stations.
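Salient's debiasing method is not described in this notebook. As a generic illustration of what removing a systematic ERA5-versus-station offset means, the sketch below applies a simple additive mean-bias correction. This is a swapped-in textbook technique, not the algorithm behind `debias=True`, and precipitation is often better handled with a multiplicative or quantile-based correction.

```python
import numpy as np


def additive_bias_correction(era5_hist, station_hist, era5_forecast):
    """Generic additive bias correction (illustrative, not Salient's method).

    Arrays are shaped (time, station). The per-station offset is the mean of
    (station - ERA5) over the historical period and is added to every forecast value.
    """
    offset = np.mean(np.asarray(station_hist, float) - np.asarray(era5_hist, float), axis=0)
    return np.asarray(era5_forecast, float) + offset


# Two stations: the first reads 1 degree warmer than ERA5, the second 0.5 degrees cooler
era5_hist = np.array([[10.0, 20.0], [12.0, 22.0]])
station_hist = np.array([[11.0, 19.5], [13.0, 21.5]])
era5_forecast = np.array([[15.0, 25.0]])
print(additive_bias_correction(era5_hist, station_hist, era5_forecast))
```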


### Historical ERA5 or Observed Data

Download daily or hourly historical values with [`data_timeseries`](https://sdk.salientpredictions.com/api/#salientsdk.data_timeseries) (ERA5) or `met_observations` (station observations, when `debias=True`). The request window starts 5 days before `start_date` and ends `length + 1` days after `end_date`, so every day of every forecast has a matching historical value.


#### Request Historical Data


In [None]:
if debias:
    # Use historical observations (not ERA5) as the truth.
    hist = xr.load_dataset(
        sk.met_observations(
            loc=loc,
            variables=vars,
            start=np.datetime64(start_date) - np.timedelta64(5, "D"),
            end=np.datetime64(end_date) + np.timedelta64(length + 1, "D"),
            verbose=False,
            force=force,
        )
    )
    # Make sure that we found the expected set of stations:
    assert hist.station.values.tolist() == stations.ghcnd.to_list()
    # Remove lat-lons to prevent merge conflicts later:
    hist = hist.reset_coords(drop=True)
elif bulk:
    # Use a Salient-provided bulk downscale zarr
    hist = xr.open_zarr(bulk)
    truth_vars = [var for var in hist.data_vars if "truth" in var]
    hist = hist[truth_vars].rename({var: var.replace("_truth", "") for var in truth_vars})
else:
    # Validate vs ERA5
    hist = sk.load_multihistory(
        sk.data_timeseries(
            loc=loc,
            variable=vars,
            field="vals",
            start=np.datetime64(start_date) - np.timedelta64(5, "D"),
            end=np.datetime64(end_date) + np.timedelta64(length + 1, "D"),
            frequency=freq,
            verbose=False,
            force=force,
        )
    )
print(hist)

<xarray.Dataset> Size: 36kB
Dimensions:   (time: 163, location: 13)
Coordinates:
  * time      (time) datetime64[ns] 1kB 2021-03-27 2021-03-28 ... 2021-09-05
    lat       (location) float64 104B 33.63 42.36 34.2 ... 39.87 45.6 38.51
    lon       (location) float64 104B -84.44 -71.01 -118.4 ... -122.6 -121.5
  * location  (location) <U3 156B 'ATL' 'BOS' 'BUR' 'ORD' ... 'PHL' 'PDX' 'SAC'
Data variables:
    temp      (time, location) float64 17kB 21.77 10.66 14.5 ... 21.44 24.16
    precip    (time, location) float64 17kB 4.033 0.0 0.0 ... 3.343 0.01152 0.0
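The padding in the request above can be checked with plain NumPy date arithmetic. For the default sample settings it reproduces the bounds of the `time` coordinate in the output (163 days from 2021-03-27 to 2021-09-05); treating the end bound as inclusive is an assumption consistent with that output.

```python
import numpy as np

start_date, end_date, length = "2021-04-01", "2021-07-31", 35

# Same padding as the data_timeseries / met_observations request
start = np.datetime64(start_date) - np.timedelta64(5, "D")
end = np.datetime64(end_date) + np.timedelta64(length + 1, "D")
n_days = int((end - start) / np.timedelta64(1, "D")) + 1  # inclusive of both ends

# Latest day any forecast can need: issued on end_date, final day at lead length - 1
last_needed = np.datetime64(end_date) + np.timedelta64(length - 1, "D")
print(start, end, n_days, last_needed <= end)
```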


## Downscale the Salient Forecast

The [`downscale`](https://sdk.salientpredictions.com/api/#salientsdk.downscale) API endpoint and SDK function convert Salient's native weekly/monthly/quarterly probabilistic forecasts into a daily or hourly ensemble timeseries.

This is the most time-consuming call in the notebook, since it requests a downscaled ensemble for every forecast date in `date_range`.


In [None]:
if bulk:
    # Use a Salient-provided bulk downscale zarr
    fcst = xr.open_zarr(bulk)
    fcst_vars = [var for var in fcst.data_vars if "truth" not in var]
    fcst = fcst[fcst_vars]
else:
    fcst = sk.downscale(
        loc=loc,
        variables=vars,
        debias=debias,
        date=date_range,
        frequency=freq,
        length=length,
        verbose=False,
        force=force,
        strict=False,  # if one downscale call fails, proceed with others
    )
    # Check to see if there are any missing forecasts:
    fcst_na = fcst[fcst["file_name"].isna()]
    if not fcst_na.empty:
        print("Missing forecast dates:")
        print(fcst_na)
print(fcst)

                                           file_name        date
0  validate_ensemble_example/downscale_d1f38fc0e6...  2021-04-21
1  validate_ensemble_example/downscale_47d128c2a7...  2021-05-19
2  validate_ensemble_example/downscale_3156cfdf5f...  2021-06-16
3  validate_ensemble_example/downscale_f83414a94d...  2021-07-21


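## Reference Forecast

To put Salient's skill in context, request ensemble forecasts (`field="vals_ens"`) from the reference model `ref_model` for the same locations and dates with `forecast_timeseries`. That function returns one file per date and variable, indexed by `lead`, so `format_as_downscale` reshapes the files to match the downscale output before scoring.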


In [None]:
ref_source = (
    None
    if ref_model == "none"  # skip reference comparisons
    else sk.forecast_timeseries(
        loc=loc,
        variable=vars,
        date=date_range,
        field="vals_ens",
        model=ref_model,
        timescale=freq,
        strict=False,
        force=force,
        verbose=False,
    )
)

In [None]:
def format_as_downscale(ref: pd.DataFrame) -> pd.DataFrame:
    """Reformat forecast_timeseries output to match downscale's format.

    forecast_timeseries writes one file per (date, variable). For each date,
    this combines the per-variable files into a single netCDF file and:
    - Reorders dimensions to (ensemble, lead, location)
    - Renames the 'lead' dimension to 'forecast_day'
    - Converts forecast_day values to datetimes using forecast_date + lead
    """
    if ref is None:
        return None

    def process_date(date, group):
        out_file = os.path.join(sk.get_file_destination(), f"forecast_timeseries_ref_{date}.nc")

        # Load each variable's file once
        sources = {
            row.variable: xr.load_dataset(row.file_name, decode_timedelta=True)
            for _, row in group.iterrows()
        }
        # All files for a date share the same forecast_date
        forecast_date = next(iter(sources.values())).forecast_date

        # One data variable per met variable, carrying its file's attributes
        ds = xr.Dataset(
            {
                var: src.vals_ens.transpose("ensemble", "lead", "location").assign_attrs(src.attrs)
                for var, src in sources.items()
            }
        )

        # Convert lead (timedelta) to forecast_day (datetime)
        ds = ds.rename({"lead": "forecast_day"})
        ds["forecast_day"] = forecast_date + ds.forecast_day

        ds.to_netcdf(out_file, encoding={"location": {"dtype": str}})
        return {"file_name": out_file, "date": date}

    result_files = [process_date(date, group) for date, group in ref.groupby("date")]
    return pd.DataFrame(result_files)


ref = format_as_downscale(ref_source)
print(ref)

                                           file_name        date
0  validate_ensemble_example/forecast_timeseries_...  2021-04-21
1  validate_ensemble_example/forecast_timeseries_...  2021-05-19
2  validate_ensemble_example/forecast_timeseries_...  2021-06-16
3  validate_ensemble_example/forecast_timeseries_...  2021-07-21
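`format_as_downscale` turns the reference model's `lead` offsets into calendar dates with `forecast_date + lead`, which is the `forecast_day` convention of the downscale files. A minimal NumPy illustration of that conversion (and its inverse) for the first sample forecast date:

```python
import numpy as np

forecast_date = np.datetime64("2021-04-21")
lead = np.arange(35).astype("timedelta64[D]")  # 0 to 34 days, as in a 35-day forecast

forecast_day = forecast_date + lead            # lead offsets -> calendar dates
recovered_lead = forecast_day - forecast_date  # calendar dates -> lead offsets

print(forecast_day[0], forecast_day[-1])
```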


## Calculate Skill Metrics

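Compare the forecasts with the historical truth (ERA5, or station observations when `debias=True`) using the ensemble Continuous Ranked Probability Score (CRPS), where lower is better. `crps_ensemble` returns a score for every forecast date, lead, and location (`crps_<var>_all`) plus a summary over forecast dates (`crps_<var>`).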


In [None]:
skill = sk.skill.crps_ensemble(observations=hist, forecasts=fcst)
print(skill)

<xarray.Dataset> Size: 37kB
Dimensions:          (forecast_date: 4, lead: 35, location: 13)
Coordinates:
  * location         (location) <U3 156B 'ATL' 'BOS' 'BUR' ... 'PHL' 'PDX' 'SAC'
    lat              (location) float64 104B 33.63 42.36 34.2 ... 45.6 38.51
    lon              (location) float64 104B -84.44 -71.01 ... -122.6 -121.5
  * lead             (lead) timedelta64[ns] 280B 0 days 1 days ... 34 days
  * forecast_date    (forecast_date) datetime64[ns] 32B 2021-04-21 ... 2021-0...
Data variables:
    crps_temp_all    (forecast_date, lead, location) float64 15kB 0.7206 ... ...
    crps_precip_all  (forecast_date, lead, location) float64 15kB 0.0002451 ....
    crps_temp        (lead, location) float64 4kB 0.5861 0.2641 ... 1.534 1.364
    crps_precip      (lead, location) float64 4kB 0.408 1.316 ... 1.646 0.007467
Attributes:
    short_name:  crps
    long_name:   CRPS
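The output pairs per-forecast-date scores (`crps_<var>_all`, dimensions `forecast_date, lead, location`) with a version that drops the `forecast_date` dimension (`crps_<var>`). The sketch below shows how arrays with those shapes can be produced from an ensemble in plain NumPy, assuming the summary is a mean over forecast dates and using the textbook ensemble CRPS estimator. The SDK's exact estimator and aggregation are not specified here, and the data are synthetic.

```python
import numpy as np


def crps_by_date_and_mean(fcst, obs):
    """Textbook ensemble CRPS for every (forecast_date, lead, location) cell.

    fcst: (ensemble, forecast_date, lead, location)
    obs:  (forecast_date, lead, location)
    Returns (per-date scores, scores averaged over forecast_date).
    """
    fcst = np.asarray(fcst, dtype=float)
    obs = np.asarray(obs, dtype=float)
    # Mean absolute error of the members vs the observation
    accuracy = np.abs(fcst - obs[None]).mean(axis=0)
    # Mean absolute difference over all member pairs (ensemble spread)
    spread = np.abs(fcst[:, None] - fcst[None, :]).mean(axis=(0, 1))
    crps_all = accuracy - 0.5 * spread
    return crps_all, crps_all.mean(axis=0)


rng = np.random.default_rng(0)
obs = rng.normal(20.0, 5.0, size=(4, 35, 13))                  # 4 dates, 35 leads, 13 sites
fcst = obs[None] + rng.normal(0.0, 2.0, size=(10, 4, 35, 13))  # synthetic 10-member ensemble
crps_all, crps = crps_by_date_and_mean(fcst, obs)
print(crps_all.shape, crps.shape)
```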


In [None]:
fig, ax = plt.subplots(figsize=figsize)
skill[f"crps_{plot_var}"].plot.line(x="lead", hue="location", add_legend=True, alpha=0.5, ax=ax)
skill[f"crps_{plot_var}"].mean("location").plot(ax=ax, color="black", linewidth=2, label="Mean")
# Lead is a timedelta axis in nanoseconds; label the ticks in whole days
ax.xaxis.set_major_formatter(lambda x, pos: f"{x/1e9/86400:.0f}")
ax.set_xlabel("Lead Time (days)")
plot_name = skill[f"crps_{plot_var}"].attrs.get("long_name", plot_var.title())
plot_units = (
    f'[{skill[f"crps_{plot_var}"].attrs["units"]}]'
    if "units" in skill[f"crps_{plot_var}"].attrs
    else ""
)
ax.set_ylabel(f"CRPS {plot_name} {plot_units}".strip())
ax.set_title("Salient downscale error by location (lower is better)")

plt.tight_layout()

### Calculate Skill Relative to Reference


In [None]:
if ref is None:
    print("No reference model, skipping relative comparison")
    skill_ref = None
    skill_rel = None
else:
    # Skill of the reference model:
    skill_ref = sk.skill.crps_ensemble(observations=hist, forecasts=ref)

    # Relative skill score of Salient downscale vs the reference model:
    skill_rel = sk.skill.crpss(forecast=skill, reference=skill_ref)
    print(skill_rel)

<xarray.Dataset> Size: 36kB
Dimensions:           (location: 13, lead: 34, forecast_date: 4)
Coordinates:
  * location          (location) <U3 156B 'ATL' 'BOS' 'BUR' ... 'PDX' 'SAC'
  * lead              (lead) timedelta64[ns] 272B 1 days 2 days ... 34 days
  * forecast_date     (forecast_date) datetime64[ns] 32B 2021-04-21 ... 2021-...
    lat               (location) float64 104B 33.63 42.36 34.2 ... 45.6 38.51
    lon               (location) float64 104B -84.44 -71.01 ... -122.6 -121.5
Data variables:
    crpss_temp_all    (forecast_date, lead, location) float64 14kB 0.8007 ......
    crpss_precip_all  (forecast_date, lead, location) float64 14kB 1.0 ... 0....
    crpss_temp        (lead, location) float64 4kB 0.5471 0.8113 ... 0.2946
    crpss_precip      (lead, location) float64 4kB 0.08583 0.9439 ... -0.09593
Attributes:
    short_name:  crpss
    long_name:   CRPSS
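Under the standard definition, CRPSS = 1 - CRPS_forecast / CRPS_reference: 1 is a perfect forecast, 0 matches the reference, and negative values are worse than the reference. That `sk.skill.crpss` uses exactly this definition, and how it aggregates across forecast dates, are assumptions here. The sketch also shows why aggregation order matters: the mean of per-date skill scores is not the same as the skill score of the mean CRPS. The CRPS values are hypothetical.

```python
import numpy as np


def crpss(crps_forecast, crps_reference):
    """Standard CRPS skill score: 1 - CRPS_forecast / CRPS_reference (higher is better)."""
    crps_forecast = np.asarray(crps_forecast, dtype=float)
    crps_reference = np.asarray(crps_reference, dtype=float)
    # Undefined (division by zero) where the reference CRPS is exactly zero
    return 1.0 - crps_forecast / crps_reference


fcst_by_date = np.array([0.2, 0.9])  # hypothetical CRPS of the forecast on two dates
ref_by_date = np.array([0.4, 0.6])   # hypothetical CRPS of the reference on the same dates

per_date = crpss(fcst_by_date, ref_by_date)                      # approximately [0.5, -0.5]
mean_of_scores = per_date.mean()                                 # approximately 0.0
score_of_means = crpss(fcst_by_date.mean(), ref_by_date.mean())  # approximately -0.1
print(per_date, mean_of_scores, score_of_means)
```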


In [None]:
if skill_ref is None:
    print("Skipping relative skill plotting")
else:
    fig, ax = plt.subplots(figsize=figsize)

    skill_ref[f"crps_{plot_var}"].mean("location").plot(
        ax=ax,
        color="#FF7F00",
        linewidth=2,
        label=ref_model.replace("_", " ").upper(),
    )
    skill[f"crps_{plot_var}"].mean("location").plot(
        ax=ax,
        color="dodgerblue",
        linewidth=2,
        label="Salient downscale",
    )

    ax.xaxis.set_major_formatter(lambda x, pos: f"{x/1e9/86400:.0f}")
    ax.set_xlabel("Lead Time (days)")
    crps_attrs = skill[f"crps_{plot_var}"].attrs
    plot_name = crps_attrs.get("long_name", plot_var.title())
    plot_units = f"[{crps_attrs['units']}]" if "units" in crps_attrs else ""
    ax.set_ylabel(f"CRPS {plot_name} {plot_units}".strip())
    ax.set_title("All-locations Mean Error (lower is better)")
    plt.legend()
    plt.tight_layout()
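The tick formatter used above assumes the `lead` axis reaches Matplotlib as timedelta values in nanoseconds, so dividing by 1e9 (nanoseconds per second) and 86400 (seconds per day) yields whole days. A standalone check of that conversion, using a hypothetical named formatter in place of the inline lambda:

```python
import numpy as np

NS_PER_DAY = 86400 * 10**9  # nanoseconds in one day


def lead_days_formatter(x, pos):
    """Hypothetical tick formatter: nanosecond tick value -> whole days."""
    return f"{x / NS_PER_DAY:.0f}"


# Lead times as in the skill datasets: timedelta64[ns], 0 to 34 days
lead = np.arange(35).astype("timedelta64[D]").astype("timedelta64[ns]")
tick_values = lead.astype("int64").astype(float)  # the plain numbers the formatter receives
labels = [lead_days_formatter(v, None) for v in tick_values]
print(labels[:3], labels[-1])
```

A function like this can be passed to `ax.xaxis.set_major_formatter`, which accepts a plain callable taking `(x, pos)`.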

In [None]:
if skill_rel is None:
    print("Skipping relative skill timeseries plot")
else:
    fig, ax = plt.subplots(figsize=figsize)
    skill_rel[f"crpss_{plot_var}"].plot.line(
        x="lead", hue="location", add_legend=True, alpha=0.5, ax=ax
    )
    skill_rel[f"crpss_{plot_var}"].mean("location").plot(
        ax=ax, color="black", linewidth=2, label="Mean"
    )
    ax.xaxis.set_major_formatter(lambda x, pos: f"{x/1e9/86400:.0f}")
    ax.axhline(y=0, color="grey", linestyle=":", zorder=0)
    ax.set_title(f"Relative skill: Salient downscale vs {ref_model} (higher is better)")
    ax.set_xlabel("Lead Time (days)")
    ax.set_ylabel(
        f'CRPSS ({skill_rel[f"crpss_{plot_var}"].attrs.get("long_name", plot_var.title())})'
    )
    ymin, ymax = ax.get_ylim()
    ax.set_ylim(max(-0.7, ymin), min(0.7, ymax))
    plt.tight_layout()

In [None]:
if skill_rel is None:
    print("Skipping relative skill boxplot")
else:
    medians = skill_rel[f"crpss_{plot_var}"].median("lead")
    sorted_locations = medians.sortby(medians, ascending=False).location.values
    fig, ax = plt.subplots(figsize=figsize)
    df = skill_rel[f"crpss_{plot_var}"].to_pandas().melt(ignore_index=False)
    ax.boxplot(
        [df[df["location"] == name]["value"] for name in sorted_locations],
        tick_labels=sorted_locations,  # Matplotlib >= 3.9 (formerly `labels`)
        patch_artist=True,
        showfliers=False,
        medianprops=dict(color="black"),
        boxprops=dict(facecolor="dodgerblue"),
    )
    ax.axhline(y=0, color="grey", linestyle=":", zorder=0)
    ax.set_xlabel("Location")
    ax.set_ylabel(
        f'CRPSS ({skill_rel[f"crpss_{plot_var}"].attrs.get("long_name", plot_var.title())})'
    )
    ax.set_title(f"Relative skill: Salient downscale vs {ref_model} (higher is better)")
    ymin, ymax = ax.get_ylim()
    ax.set_ylim(max(-1, ymin), min(1, ymax))
    plt.tight_layout()