# Skill profiling

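This notebook profiles the skill of the analog forecast. For a random sample of reference dates, and for every combination of variable, spatial domain, and analog-search mode (anomalies vs. raw values), it compares the analog forecast's mean squared error on each forecast day against the 2.5th–97.5th percentile range of the same error from 500 naive forecasts built from randomly chosen dates.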

In [1]:
from itertools import product

import dask
import numpy as np
import pandas as pd
import tqdm
import tqdm.auto
import xarray as xr
from dask.distributed import Client

# local
from analog_forecast import find_analogs, read_subset_era5, make_forecast
from config import data_dir
import luts

In [2]:
# Spin up a local dask cluster with 16 workers; its dashboard is served
# at localhost:33338 for monitoring the naive-forecast simulations.
client = Client(
    n_workers=16,
    dashboard_address="localhost:33338",
)

In [21]:
# Seed numpy's global RNG so that, on a fresh kernel, both the sampled
# reference dates and the np.random.choice draws inside run_profile are
# reproducible.
SEED = 42
np.random.seed(SEED)
N_REF_DATES = 50
# Length (days) of the verification window after each reference date; matches
# the 14 forecast days scored in run_profile.
FORECAST_DAYS = 14

# choose a sample of random reference dates for simulating
sub_da = read_subset_era5("alaska", data_dir, "t2m", use_anom=False)
# Drop the last FORECAST_DAYS timestamps so every reference date has a full
# verification window after it; otherwise selecting the forecast times from the
# data fails. (Assumes the forecast covers the FORECAST_DAYS days following
# ref_date — TODO confirm against make_forecast.)
candidate_times = sub_da.time.values[:-FORECAST_DAYS]
ref_times = np.random.choice(candidate_times, N_REF_DATES, replace=False)
ref_dates = [pd.to_datetime(t).strftime("%Y-%m-%d") for t in ref_times]

varnames = list(luts.varnames_lu.keys())
spatial_domains = list(luts.spatial_domains.keys())
use_anom = [True, False]

args = list(product(varnames, spatial_domains, ref_dates, use_anom))

In [5]:
def get_possible_naive_timestamps(times, ref_date, n_days=14):
    """Return the candidate dates from which naive forecasts are drawn.

    Candidates are two daily ranges (timestamps at 12:00, matching ``ref_date``):

    * from ``times[0]`` up to the day before ``ref_date``, and
    * from ``n_days + 1`` days after ``ref_date`` up to ``times[-(n_days + 1)]``,
      so every candidate in this range still has ``n_days`` timestamps of
      ``times`` after it.

    ``ref_date`` itself and the ``n_days`` days after it are never candidates.

    Parameters
    ----------
    times : array-like of datetime64
        Timestamps of the data being forecast (e.g. ``sub_da.time.values``).
        Candidates are generated with a daily frequency, so gaps in ``times``
        are not reflected in the output.
    ref_date : str
        Reference date formatted as "%Y-%m-%d"; " 12:00:00" is appended to it.
    n_days : int, default 14
        Forecast length in days.

    Returns
    -------
    numpy.ndarray
        datetime64 candidate timestamps, the "before" range followed by the
        "after" range.
    """
    ref_time = pd.to_datetime(ref_date + " 12:00:00")
    one_day = pd.to_timedelta(1, unit="d")
    # NOTE(review): dates just before ref_date yield naive forecasts whose
    # window overlaps the verification window (ref_date .. ref_date + n_days).
    # Ending this range n_days + 1 days earlier would mirror the "after"
    # range — confirm the intended exclusion.
    # Use the start of the passed-in ``times`` (not a notebook-level global).
    before_forecast = pd.date_range(times[0], ref_time - one_day)
    after_forecast = pd.date_range(
        ref_time + (n_days + 1) * one_day,
        times[-(n_days + 1)],
    )
    return np.concatenate([before_forecast, after_forecast])


@dask.delayed
def forecast_and_error(sub_da, times, ref_date):
    """Build a naive forecast from ``times`` and score it against ``sub_da``.

    Because of the ``dask.delayed`` decorator, calling this returns a Delayed
    object; the work runs when it is passed to ``dask.compute``.

    The score is the squared error averaged over axes 1 and 2 (presumably the
    spatial dims, leaving one value per forecast timestamp — TODO confirm).
    No square root is taken, so this is a mean squared error, not an RMSE.
    """
    naive = make_forecast(sub_da, times, ref_date)
    observed = sub_da.sel(time=naive.time.values)
    squared_error = (observed - naive) ** 2
    return squared_error.mean(axis=(1, 2))


def run_profile(varname, spatial_domain, ref_date, use_anom, n_sim=500):
    """Compare one analog forecast's error with a spread of naive-forecast errors.

    The analog forecast is built from the dates returned by ``find_analogs``.
    Its error is benchmarked against ``n_sim`` naive forecasts, each built from
    5 dates drawn at random (without replacement) from
    ``get_possible_naive_timestamps``. Both are scored against the data read
    with ``use_anom=False``.

    The error metric is the squared error averaged over axes 1 and 2 of the
    data (presumably the spatial dims — TODO confirm). It is a mean squared
    error; no square root is taken.

    Parameters
    ----------
    varname : str
        Variable key (the notebook passes keys of ``luts.varnames_lu``).
    spatial_domain : str
        Domain key (the notebook passes keys of ``luts.spatial_domains``).
    ref_date : str
        Reference date formatted as "%Y-%m-%d".
    use_anom : bool
        Passed to ``find_analogs``; recorded as the "anomaly_search" column.
    n_sim : int, default 500
        Number of naive forecasts to simulate.

    Returns
    -------
    pandas.DataFrame
        One row per forecast day. Besides the run identifiers, it holds the
        analog forecast's error ("analog") and the 2.5th / 97.5th percentiles
        of the naive forecasts' error ("naive_2.5", "naive_97.5").

    Raises
    ------
    ValueError
        If ``n_sim`` is less than 1 (there would be no naive errors to combine).
    """
    if n_sim < 1:
        raise ValueError(f"n_sim must be >= 1, got {n_sim}")

    analogs = find_analogs(varname, ref_date, spatial_domain, data_dir, 16, use_anom)
    sub_da = read_subset_era5(spatial_domain, data_dir, varname, use_anom=False)
    forecast = make_forecast(sub_da, analogs.time.values, ref_date)

    possible_naive_times = get_possible_naive_timestamps(sub_da.time.values, ref_date)
    # Wrap the data in a single Delayed so dask hashes and ships it once,
    # instead of once for each of the n_sim delayed calls below.
    sub_da_delayed = dask.delayed(sub_da)
    naive_tasks = []
    for _ in range(n_sim):
        # NOTE(review): 5 naive dates per forecast; presumably meant to match
        # the number of analogs behind the analog forecast — confirm.
        naive_times = np.random.choice(possible_naive_times, 5, replace=False)
        naive_tasks.append(forecast_and_error(sub_da_delayed, naive_times, ref_date))

    naive_mse = xr.concat(
        dask.compute(*naive_tasks), pd.Index(range(n_sim), name="sim")
    )
    analog_err = sub_da.sel(time=forecast.time.values) - forecast
    analog_mse = (analog_err ** 2).mean(axis=(1, 2)).values

    return pd.DataFrame({
        "variable": varname,
        "spatial_domain": spatial_domain,
        "anomaly_search": use_anom,
        "reference_date": ref_date,
        # derive day numbers from the forecast itself rather than assuming 14
        "forecast_day_number": np.arange(1, analog_mse.size + 1),
        "analog": analog_mse,
        "naive_2.5": naive_mse.reduce(np.percentile, dim="sim", q=2.5).values,
        "naive_97.5": naive_mse.reduce(np.percentile, dim="sim", q=97.5).values,
    })

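Run the simulation: profile every combination in `args` (variable × spatial domain × reference date × anomaly flag). Each combination computes 500 naive forecasts on the dask cluster, so this cell is long-running.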

In [None]:
# Profile every (variable, domain, reference date, anomaly flag) combination.
# tqdm.auto renders a notebook progress widget instead of console escape codes.
results = [run_profile(*arg) for arg in tqdm.auto.tqdm(args, desc="profiling")]
# Stack the per-run tables once into a single frame for analysis.
profile_df = pd.concat(results, ignore_index=True)
profile_df.head()


  0%|                                                                                                                                          | 0/1600 [00:00<?, ?it/s][A