# Quality Control

This notebook is for validating the drought indices dataset produced using the `scripts/process.py` script.

In [1]:
from pathlib import Path
import numpy as np
import xarray as xr
import pandas as pd
from config import INDICES_DIR, DOWNLOAD_DIR, CLIM_DIR
import luts
from scripts.snap.process_calibration_params import estimate_params
from xclim.indices.stats import dist_method

## Index validation

For each drought index, re-compute a value manually and compare with the indices dataset.

We will be working with the climatologies, the downloaded ERA5 data, and of course the computed indices data. Set up connections to these datasets.

Indices dataset:

In [2]:
# The indices are written as one file per accumulation interval (in days);
# stack them along a new "interval" dimension labelled by `intervals`.
intervals = pd.Index([1, 7, 30, 60, 90, 180, 365], name="interval")
fps = [
    INDICES_DIR.joinpath(f"nws_drought_indices_{n_days}day.nc")
    for n_days in intervals
]
indices_ds = xr.open_mfdataset(
    fps,
    combine="nested",
    concat_dim=[intervals],
)

# reference date stored as a dataset attribute; used below as the last day
# of every accumulation window
ref_date = pd.to_datetime(indices_ds.attrs["reference_date"])

Define a function to help with extracting data from grid cells in ERA5 downloads, since we will be doing this for every index:

In [28]:
def extract_era5(index, time_slice, lat, lon):
    """Extract a single-grid-cell time series from the downloaded ERA5 files.

    Every file in DOWNLOAD_DIR whose name starts with the prefix registered
    (in ``luts.varname_prefix_lu``) for the ERA5 variable behind `index` is
    opened. The grid cell nearest to (`lat`, `lon`) is selected and subset to
    `time_slice`, and the pieces are concatenated into one time-sorted series.

    Args:
        index (str): drought index (or raw variable) name, mapped to the ERA5
            variable via ``varname_lu`` below (e.g. "smd1" -> "swvl1").
        time_slice (slice): time selection passed to ``.sel(time=...)``.
        lat (float): latitude of the point of interest.
        lon (float): longitude of the point of interest.

    Returns:
        xarray.DataArray: the extracted values along "time", loaded in memory.

    Raises:
        KeyError: if `index` is not in ``varname_lu``.
        FileNotFoundError: if no download file matches the variable prefix.
    """
    varname_lu = {
        "tp": "tp",
        "pntp": "tp",
        "swe": "sd",
        "pnswe": "sd",
        "spi": "tp",
        "smd1": "swvl1",
        "smd2": "swvl2",
        "pev": "pev"
    }
    varname = varname_lu[index]
    latlon_sel_di = {"latitude": lat, "longitude": lon}
    pattern = f"{luts.varname_prefix_lu[varname]}*.nc"

    da_list = []
    for fp in DOWNLOAD_DIR.glob(pattern):
        with xr.open_dataset(fp) as ds:
            point_da = ds[varname].sel(latlon_sel_di, method="nearest")
            if "expver" in point_da.dims:
                # some files carry an "expver" dimension (values 1 and 5 here;
                # presumably final vs. preliminary ERA5 data) - merge both
                # into a single series and drop the dimension
                point_da = xr.merge([
                    point_da.sel(expver=1).drop_vars("expver"),
                    point_da.sel(expver=5).drop_vars("expver"),
                ])[varname]
            # read the (small) single-cell subset before the file is closed
            da_list.append(point_da.sel(time=time_slice).load())

    if not da_list:
        raise FileNotFoundError(
            f"No ERA5 files matching {pattern!r} found in {DOWNLOAD_DIR}"
        )

    return xr.concat(da_list, dim="time").sortby("time")


def extract_clim(varname, doy_slice, lat, lon):
    """Extract a day-of-year climatology window for the grid cell nearest a point.

    Args:
        varname (str): climatology variable, one of "tp", "swe", "swvl".
        doy_slice (slice): integer day-of-year window (e.g. from
            ``get_doy_slice``), inclusive of both ends. A start below 1 means
            the window begins in the previous year; those days are wrapped onto
            the end of the climatological year (the largest day-of-year in the
            file) instead of being silently dropped.
        lat (float): latitude of the point of interest.
        lon (float): longitude of the point of interest.

    Returns:
        xarray.DataArray: climatology values for the requested days, loaded
        into memory.
    """
    clim_lu = {
        "tp": "era5_daily_tp_climatology_1981_2020_leap.nc",
        "swe": "era5_swe_climo_81-20.nc",
        "swvl": "era5_daily_swvl_1981_2020.nc"
    }
    with xr.open_dataset(CLIM_DIR.joinpath(clim_lu[varname])) as clim_ds:

        # need to re-index longitude in some cases, not consistent across clim datasets
        if clim_ds.longitude.values[0] == 180:
            clim_ds = clim_ds.assign_coords(
                longitude=(clim_ds.longitude.values) - 360
            )
        # same reason as above, inconsistency between clims... reindex coords with dayofyear
        try:
            if clim_ds.time.dt.year.values[0] == 1980:
                clim_ds = clim_ds.assign_coords(
                    time=clim_ds.time.dt.dayofyear
                )
        except TypeError:
            # presumably the `.dt` accessor failing because "time" is not
            # datetime-like (already day-of-year); nothing to re-index then
            pass

        point_da = clim_ds[varname].sel(
            latitude=lat, longitude=lon, method="nearest"
        )

        start, stop = doy_slice.start, doy_slice.stop
        if start is not None and stop is not None and start < 1:
            # window crosses the year boundary: label slicing would clip it at
            # day 1, so wrap the leading days onto the end of the clim year
            n_doy = int(point_da.time.max())
            doys = np.arange(start, stop + 1)
            doys = np.where(doys < 1, doys + n_doy, doys)
            clim_da = point_da.sel(time=doys)
        else:
            clim_da = point_da.sel(time=doy_slice)

        # read the values before the file is closed
        clim_da = clim_da.load()

    return clim_da

    
def get_time_slice(ref_date, interval):
    """Build a date-string slice for the `interval` days ending on `ref_date`.

    With inclusive label-based time selection the slice covers `interval`
    calendar days, `ref_date` included.

    Args:
        ref_date (pandas.Timestamp): last day of the window.
        interval (int): window length in days.

    Returns:
        slice: ``slice("YYYY-MM-DD", "YYYY-MM-DD")`` (first day, last day).
    """
    date_fmt = "%Y-%m-%d"
    first_day = ref_date - pd.Timedelta(days=interval - 1)
    return slice(first_day.strftime(date_fmt), ref_date.strftime(date_fmt))


def get_doy_slice(ref_date, interval):
    """Build an integer day-of-year slice for the `interval` days ending on `ref_date`.

    The start is obtained by plain subtraction and is not wrapped, so it is
    zero or negative when the window reaches back into the previous year
    (i.e. when the day-of-year of `ref_date` is smaller than `interval`).

    Args:
        ref_date: date-like object providing ``timetuple()``
            (e.g. pandas.Timestamp); last day of the window.
        interval (int): window length in days.

    Returns:
        slice: ``slice(start_doy, end_doy)``.
    """
    end_doy = ref_date.timetuple().tm_yday
    first_doy = end_doy - (interval - 1)
    return slice(first_doy, end_doy)

Now work through each index and test the existing values against independently re-computed ones for the following interval and location:

In [4]:
# Location (the nearest grid cell is used) and window length, in days, shared
# by every check below.
lat = 65
lon = -148
interval = 30

# keyword arguments for a nearest-neighbour point selection with `.sel`
latlon_sel_di = dict(latitude=lat, longitude=lon, method="nearest")
# day-of-year of the reference date, used to pick the gamma parameters
ref_doy = ref_date.timetuple().tm_yday

#### Total precip

Total precip should be the sum of the precip values over the specified interval (converted from m to cm).

In [20]:
index = "tp"
test = (
    indices_ds[index]
    .sel(interval=interval)
    .sel(latitude=lat, longitude=lon, method="nearest")
    .compute()
)
time_slice = get_time_slice(ref_date, interval)
raw = extract_era5(index, time_slice, lat, lon)
# window total converted from m to cm, rounded to one decimal
raw_sum_cm = (raw.sum() * 100).astype("float32")
check = np.round(raw_sum_cm, 1)
# keep the total for the % of normal check below
tp_check = check
assert check == test

#### Total precip % of normal

In [24]:
index = "pntp"
test = (
    indices_ds[index]
    .sel(interval=interval)
    .sel(latitude=lat, longitude=lon, method="nearest")
    .compute()
)
doy_slice = get_doy_slice(ref_date, interval)
clim = extract_clim("tp", doy_slice, lat, lon)
# percent of normal: window total (from the tp check) over the climatological
# total for the same days, rounded to a whole percent
# NOTE(review): tp_check is in cm; confirm the climatology uses matching units
ratio = (tp_check / clim.sum()).astype("float32")
check = np.round(ratio * 100)
assert check == test

#### Snow water equivalent

In [26]:
index = "swe"
test = (
    indices_ds[index]
    .sel(interval=interval)
    .sel(latitude=lat, longitude=lon, method="nearest")
    .compute()
)
time_slice = get_time_slice(ref_date, interval)
raw = extract_era5(index, time_slice, lat, lon)
# window mean of ERA5 "sd", scaled by 100 and rounded to one decimal
raw_mean_scaled = (raw.mean() * 100).astype("float32")
check = np.round(raw_mean_scaled, 1)
# keep the mean for the % of normal check below
swe_check = check
assert check == test

#### SWE % of normal

In [32]:
index = "pnswe"
test = (
    indices_ds[index]
    .sel(interval=interval)
    .sel(latitude=lat, longitude=lon, method="nearest")
    .compute()
)
doy_slice = get_doy_slice(ref_date, interval)
clim = extract_clim("swe", doy_slice, lat, lon)
# ratio of the window mean (from the swe check) to the climatological mean
# NOTE(review): unlike pntp there is no * 100 here; confirm units/scaling
swe_ratio = swe_check / clim.mean()
check = np.round(swe_ratio, 1).astype("float32")
assert check == test

#### Soil moisture deficit

In [71]:
index = "smd"
test = (
    indices_ds[index]
    .sel(interval=interval)
    .sel(latitude=lat, longitude=lon, method="nearest")
    .compute()
)
time_slice = get_time_slice(ref_date, interval)
# the two soil levels (ERA5 swvl1 / swvl2) are extracted separately
raw1 = extract_era5(f"{index}1", time_slice, lat, lon)
raw2 = extract_era5(f"{index}2", time_slice, lat, lon)
# weighted combination of the levels: 25% level 1, 75% level 2
raw = (raw1 * 0.25) + (raw2 * 0.75)
doy_slice = get_doy_slice(ref_date, interval)
clim = extract_clim("swvl", doy_slice, lat, lon)
# deficit as a percent of the climatological mean, rounded to one decimal
clim_mean = clim.mean()
deficit_pct = ((clim_mean - raw.mean()) / clim_mean) * 100
check = np.round(deficit_pct.astype("float32"), 1)
assert check == test

#### SPI & SPEI gamma parameters

These two indices need a bit more checking. The daily precip and pev files are at `/workspace/Shared/Tech_Projects/NWS_Drought_Indicators/project_data/calibration/`. We will first validate the estimated parameters of the gamma distributions fit to these daily data by re-estimating them from the daily data.

##### SPI

In [10]:
# copied from /workspace/Shared/Tech_Projects/NWS_Drought_Indicators/project_data/calibration
# NOTE(review): absolute, user-specific scratch path; adjust for your environment
cal_dir = Path("/atlas_scratch/kmredilla/nws_drought_indicators/calibration/")

tp_daily_fp = cal_dir.joinpath("era5_daily_tp_1981_2020.nc")
daily_tp_ds = xr.open_dataset(tp_daily_fp)
daily_tp = daily_tp_ds["tp"].sel(latitude=lat, longitude=lon, method="nearest")
# re-estimate the gamma parameters for this grid cell and interval
spi_params = estimate_params(daily_tp, interval)
check = spi_params.sel(dayofyear=ref_doy).astype("float32")

with xr.open_dataset(CLIM_DIR.joinpath("spi_gamma_parameters.nc")) as spi_ds:
    # select lat/lon by nearest neighbour, as for the re-estimation above, so
    # an off-grid point does not raise a KeyError; read before closing the file
    test = (
        spi_ds["params"]
        .sel(interval=interval, dayofyear=ref_doy)
        .sel(latitude=lat, longitude=lon, method="nearest")
        .load()
    )

assert np.all(check == test)

##### SPEI

In [11]:
pev_daily_fp = cal_dir.joinpath("era5_daily_pev_1981_2020.nc")
daily_pev_ds = xr.open_dataset(pev_daily_fp)
daily_pev = daily_pev_ds["pev"].sel(latitude=lat, longitude=lon, method="nearest")
# daily water balance: tp + pev, shifted by a constant 0.002
# NOTE(review): presumably the offset keeps values positive for the gamma fit;
# it must match the offset used when the indices were produced
daily_wb = daily_tp + daily_pev
daily_wb += 0.002
spei_params = estimate_params(daily_wb, interval)
check = spei_params.sel(dayofyear=ref_doy).astype("float32")

with xr.open_dataset(CLIM_DIR.joinpath("spei_gamma_parameters.nc")) as spei_ds:
    # select lat/lon by nearest neighbour, as for the re-estimation above, so
    # an off-grid point does not raise a KeyError; read before closing the file
    test = (
        spei_ds["params"]
        .sel(interval=interval, dayofyear=ref_doy)
        .sel(latitude=lat, longitude=lon, method="nearest")
        .load()
    )

assert np.all(check == test)

CPU times: user 768 ms, sys: 1.26 s, total: 2.03 s
Wall time: 6.97 s


#### SPI & SPEI

Now we can use the gamma parameters we just validated to compute the indices.

##### SPI

In [30]:
index = "spi"
test = indices_ds[index].sel(interval=interval).sel(**latlon_sel_di).compute()
time_slice = get_time_slice(ref_date, interval)
raw_tp = extract_era5("tp", time_slice, lat, lon)
# mean of the daily precip totals over the window
pr = raw_tp.resample(time="1D").sum().mean()
# dist_method reads the distribution name from the "scipy_dist" attribute
spi_params.attrs["scipy_dist"] = "gamma"
gamma_params = spi_params.sel(dayofyear=ref_doy).astype("float32")

# probability of the window mean under the fitted gamma; non-positive values
# are masked to NaN first
prob = dist_method("cdf", gamma_params, pr.where(pr > 0))

# map the probability through a standard normal (loc=0, scale=1)
params_norm = xr.DataArray(
    [0, 1],
    dims=["dparams"],
    coords={"dparams": ["loc", "scale"]},
    attrs={"scipy_dist": "norm"},
)
check = np.round(dist_method("ppf", params_norm, prob), 2)

assert check == test

##### SPEI

In [33]:
index = "spei"
test = indices_ds[index].sel(interval=interval).sel(
    **latlon_sel_di
).compute()
time_slice = get_time_slice(ref_date, interval)
# extract precip here as well instead of relying on `raw_tp` left over from
# the SPI cell, so this cell does not depend on hidden kernel state
raw_tp = extract_era5("tp", time_slice, lat, lon)
raw_pev = extract_era5("pev", time_slice, lat, lon)
# mean daily water balance (tp + pev) over the window, with the same 0.002
# offset applied to the calibration data
wb = (raw_tp + raw_pev).resample(time="1D").sum().mean()
wb += 0.002
# dist_method reads the distribution name from the "scipy_dist" attribute
spei_params.attrs["scipy_dist"] = "gamma"

# probability of the window mean under the fitted gamma; non-positive values
# are masked to NaN first
prob = dist_method(
    "cdf",
    spei_params.sel(dayofyear=ref_doy).astype("float32"),
    wb.where(wb > 0)
)
# map the probability through a standard normal (loc=0, scale=1)
params_norm = xr.DataArray(
    [0, 1],
    dims=["dparams"],
    coords=dict(dparams=(["loc", "scale"])),
    attrs=dict(scipy_dist="norm"),
)
check = np.round(dist_method("ppf", params_norm, prob), 2)

assert check == test