In [1]:
import os
import yaml
import copy

import numpy as np
import pandas as pd
import xarray as xr

In [2]:
import xesmf as xe

In [3]:
import matplotlib.pyplot as plt
%matplotlib inline

### Yuma_PG: maximum total precipitation (ACC per lead year)

In [4]:
# Site key and the directory that holds this site's pre-computed metric stores.
key = 'Yuma_PG'
base_dir = '/glade/derecho/scratch/ksha/EPRI_data/METRICS/' + key + '/'

In [5]:
# Max-total-precipitation stores for the CESM forecasts and the ERA5 reference.
fn_CESM = f'{base_dir}CESM_TP_max.zarr'
fn_ERA5 = f'{base_dir}ERA5_TP_max.zarr'

ds_CESM, ds_ERA5 = (xr.open_zarr(path) for path in (fn_CESM, fn_ERA5))

In [6]:
# Anomaly correlation (ACC) between CESM and ERA5, per gridcell and variable,
# for each lead year.
#
# CESM is indexed by initialisation year ('init_time') and lead ('year');
# shifting init_time by the lead gives the verifying (valid) year, which is
# paired with the ERA5 'year' axis. New names are used instead of overwriting
# ds_CESM / ds_ERA5, so re-running this cell does not fail on an
# already-renamed dataset.
ds_CESM_valid = ds_CESM.rename({'init_time': 'year_valid'})
ds_ERA5_valid = ds_ERA5.rename({'year': 'year_valid'})
varnames = list(ds_ERA5_valid.data_vars)

MIN_VALID_YEARS = 10  # minimum overlapping valid years needed for an ACC
N_LEADS = 10

ds_collection = []
leads_used = []

for i_lead in range(N_LEADS):
    ds_CESM_slice = ds_CESM_valid.isel(year=i_lead)
    # convert init_time to valid_time by adding the lead
    ds_CESM_slice['year_valid'] = ds_CESM_slice['year_valid'] + i_lead

    # keep only valid years present in both datasets
    ds_CESM_slice, ds_ERA5_slice = xr.align(ds_CESM_slice, ds_ERA5_valid, join="inner")

    # skip leads without enough overlapping years
    if ds_CESM_slice.sizes['year_valid'] < MIN_VALID_YEARS:
        continue

    # name each correlation explicitly so xr.merge never sees an unnamed array
    ds_ACC = xr.merge([
        xr.corr(ds_CESM_slice[v], ds_ERA5_slice[v], dim="year_valid").rename(v)
        for v in varnames
    ])
    ds_collection.append(ds_ACC)
    leads_used.append(i_lead)

if not ds_collection:
    raise ValueError(
        f"No lead has at least {MIN_VALID_YEARS} overlapping valid years; cannot compute ACC."
    )

# NOTE: the stacking dim keeps its historical name 'year_valid' for downstream
# compatibility, but each entry is one lead. The 'lead' coordinate records
# which one, so a skipped lead cannot silently shift positions.
ds_ACC_all = xr.concat(ds_collection, dim='year_valid').assign_coords(
    lead=('year_valid', leads_used)
)

In [7]:
# Destination store for the TP-max ACC; the write itself is left disabled.
save_name = f'{base_dir}ACC_TP_max.zarr'
# ds_ACC_all.to_zarr(save_name)

### Yuma_PG: extreme precipitation (p99.9) event counts (ACC and SEDI per lead year)

In [12]:
def _sedi_from_hits_falsealarms(H, F, eps=1e-6):
    """
    H, F are DataArrays in [0,1]. Apply clipping to avoid log(0) and log(1-0).
    """
    Hc = H.clip(eps, 1 - eps)
    Fc = F.clip(eps, 1 - eps)

    num = (np.log(Fc) - np.log(Hc) - np.log(1 - Fc) + np.log(1 - Hc))
    den = (np.log(Fc) + np.log(Hc) + np.log(1 - Fc) + np.log(1 - Hc))
    return num / den

def _compute_sedi_binary(fcst, obs, dim, event_quantile=2/3):
    """
    Compute SEDI for deterministic forecasts, using an observation-based threshold.

    An "event" is a value strictly above the ``event_quantile`` quantile of
    ``obs`` along ``dim``, computed separately at every other coordinate (e.g.
    per gridcell). The same threshold is applied to ``fcst``. Forecast and
    observed events are tallied into a 2x2 contingency table along ``dim``. The
    resulting hit rate H and false-alarm rate F are converted to SEDI by
    ``_sedi_from_hits_falsealarms``.

    Parameters
    ----------
    fcst, obs : xarray.DataArray
        Forecast and observed values, aligned on ``dim`` and spatial dims.
    dim : str
        Dimension over which the threshold and the contingency counts are taken.
    event_quantile : float, default 2/3
        2/3 = upper tercile (paper); 0.8 = 80th percentile (paper's ENSO-removed case).

    Returns
    -------
    xarray.DataArray
        SEDI with ``dim`` reduced. NaN where H or F is undefined (no observed
        events, or no observed non-events, among the valid samples).

    Notes
    -----
    Samples where ``fcst`` or ``obs`` is missing are excluded from all four
    contingency counts. NaN compares False against the threshold, so otherwise
    a missing value would be tallied as a "non-event" and bias H and F.
    """
    # threshold from observations, per-gridcell
    thr = obs.quantile(event_quantile, dim=dim, skipna=True)

    # only samples where both forecast and observation exist enter the table
    valid = fcst.notnull() & obs.notnull()

    obs_event  = obs  > thr
    fcst_event = fcst > thr

    # contingency components across the time dim
    hits        = (valid &  fcst_event &  obs_event).sum(dim=dim)
    misses      = (valid & ~fcst_event &  obs_event).sum(dim=dim)
    falsealarms = (valid &  fcst_event & ~obs_event).sum(dim=dim)
    correctneg  = (valid & ~fcst_event & ~obs_event).sum(dim=dim)

    # rates (0/0 -> NaN where a column has no events / no non-events)
    H = hits / (hits + misses)
    F = falsealarms / (falsealarms + correctneg)

    return _sedi_from_hits_falsealarms(H, F)

In [20]:
# ACC and SEDI between CESM and ERA5 counts of p99.9 precipitation events,
# per gridcell and variable, for each lead year.
fn_CESM = base_dir + 'CESM_p999_even_counts.zarr'
fn_ERA5 = base_dir + 'ERA5_p999_even_counts.zarr'

ds_CESM = xr.open_zarr(fn_CESM)
ds_ERA5 = xr.open_zarr(fn_ERA5)

ds_CESM = ds_CESM.rename({'init_time': 'year_valid'})
ds_ERA5 = ds_ERA5.rename({'year': 'year_valid'})

# single chunk along year_valid -- presumably so the per-gridcell quantile in
# the SEDI threshold sees the whole series in one dask chunk (confirm)
ds_CESM = ds_CESM.chunk(dict(year_valid=-1))
ds_ERA5 = ds_ERA5.chunk(dict(year_valid=-1))

varnames = list(ds_ERA5.data_vars)

MIN_VALID_YEARS = 10  # minimum overlapping valid years needed per lead
N_LEADS = 10

acc_list_all_leads = []  # one merged ACC + SEDI Dataset per usable lead

for i_lead in range(N_LEADS):
    ds_CESM_slice = ds_CESM.isel(year=i_lead).copy()

    # convert init_time to valid_time by adding the lead
    ds_CESM_slice["year_valid"] = ds_CESM_slice["year_valid"] + i_lead

    # align (ensure year_valid is the shared dim)
    ds_CESM_slice, ds_ERA5_slice = xr.align(ds_CESM_slice, ds_ERA5, join="inner")

    # Dataset.sizes is the name->length mapping; using Dataset.dims as a
    # mapping is deprecated in recent xarray
    if ds_CESM_slice.sizes.get("year_valid", 0) < MIN_VALID_YEARS:
        continue

    acc_vars = []
    sedi_vars = []

    for v in varnames:
        fcst = ds_CESM_slice[v]
        obs  = ds_ERA5_slice[v]

        # ACC
        acc_vars.append(xr.corr(fcst, obs, dim="year_valid").rename(f"{v}_ACC"))

        # SEDI (upper tercile events, as in the paper's categorical setup)
        sedi = _compute_sedi_binary(fcst, obs, dim="year_valid", event_quantile=2/3)
        sedi_vars.append(sedi.rename(f"{v}_SEDI"))

    ds_metrics = xr.merge(acc_vars + sedi_vars)

    # keep a lead coordinate so concat is clean
    acc_list_all_leads.append(ds_metrics.expand_dims(lead=[i_lead]))

if not acc_list_all_leads:
    raise ValueError(
        f"No lead has at least {MIN_VALID_YEARS} overlapping valid years; nothing to concatenate."
    )

ds_ACC_SEDI_all = xr.concat(acc_list_all_leads, dim="lead")

In [24]:
# Destination store for the p99.9-event ACC/SEDI; the write is left disabled.
save_name = f'{base_dir}ACC_p999_event.zarr'
# ds_ACC_SEDI_all.to_zarr(save_name)

<xarray.backends.zarr.ZarrStore at 0x152ccc65ba70>

### Yuma_PG: SPEI (annual mean and annual minimum; ACC per lead year)

In [10]:
# SPEI stores for the CESM forecasts and the ERA5 reference.
fn_CESM = f'{base_dir}CESM_SPEI.zarr'
fn_ERA5 = f'{base_dir}ERA5_SPEI.zarr'

ds_CESM, ds_ERA5 = (xr.open_zarr(path) for path in (fn_CESM, fn_ERA5))

In [11]:
# Collapse ERA5 SPEI to one value per calendar year: annual mean and annual
# minimum. skipna=False, so any missing value within a year makes that year NaN.
ds_group = ds_ERA5.groupby("time.year")
ds_mean, ds_min = (
    getattr(ds_group, stat)(dim="time", skipna=False) for stat in ("mean", "min")
)
ds_mean = ds_mean.rename({name: name + "_mean" for name in ds_mean.data_vars})
ds_min = ds_min.rename({name: name + "_min" for name in ds_min.data_vars})
ds_ERA5_merge = xr.merge([ds_mean, ds_min])

In [12]:
# Collapse CESM lead months into lead years using half-open bins
# [0,12), [12,24), ..., [108,120) of lead_time_month. Each bin is reduced to its
# mean and its minimum, and the bin axis is relabelled 'year' = 0..9.
# NOTE(review): assumes lead_time_month runs 0..119; a value of 120 would fall
# outside the last half-open bin -- confirm against the CESM store.
bins = np.arange(0, 121, 12)
gb = ds_CESM.groupby_bins("lead_time_month", bins=bins, right=False)

lead_year_index = np.arange(len(bins) - 1)  # 11 edges -> 10 lead years

reduced = {}
for stat in ("mean", "min"):
    ds_stat = getattr(gb, stat)("lead_time_month", skipna=False)
    ds_stat = ds_stat.rename({"lead_time_month_bins": "year"}).assign_coords(year=lead_year_index)
    reduced[stat] = ds_stat.rename({name: f"{name}_{stat}" for name in ds_stat.data_vars})

ds_mean, ds_min = reduced["mean"], reduced["min"]

ds_CESM_merge = xr.merge([ds_mean, ds_min], compat="override")

In [13]:
# Put both datasets on a shared 'year_valid' axis name (ERA5 calendar year,
# CESM initialisation year). Each rename is skipped only when it has already
# been applied, so re-running this cell is a no-op instead of failing. A missing
# source dimension still raises as before.
if 'year_valid' not in ds_ERA5_merge.dims:
    ds_ERA5_merge = ds_ERA5_merge.rename({'year': 'year_valid'})
if 'year_valid' not in ds_CESM_merge.dims:
    ds_CESM_merge = ds_CESM_merge.rename({'init_time': 'year_valid'})

In [19]:
# ACC between CESM and ERA5 annual SPEI statistics (the *_mean and *_min
# variables), per gridcell, for each lead year.
varnames = list(ds_ERA5_merge.data_vars)

MIN_VALID_YEARS = 10  # minimum overlapping valid years needed for an ACC
N_LEADS = 10

ds_collection = []
leads_used = []

for i_lead in range(N_LEADS):
    ds_CESM_slice = ds_CESM_merge.isel(year=i_lead)
    # convert init_time to valid_time by adding the lead
    ds_CESM_slice['year_valid'] = ds_CESM_slice['year_valid'] + i_lead

    # keep only valid years present in both datasets
    ds_CESM_slice, ds_ERA5_slice = xr.align(ds_CESM_slice, ds_ERA5_merge, join="inner")

    # skip leads without enough overlapping years
    if ds_CESM_slice.sizes['year_valid'] < MIN_VALID_YEARS:
        continue

    # name each correlation explicitly so xr.merge never sees an unnamed array
    ds_ACC = xr.merge([
        xr.corr(ds_CESM_slice[v], ds_ERA5_slice[v], dim="year_valid").rename(v)
        for v in varnames
    ])
    ds_collection.append(ds_ACC)
    leads_used.append(i_lead)

if not ds_collection:
    raise ValueError(
        f"No lead has at least {MIN_VALID_YEARS} overlapping valid years; cannot compute ACC."
    )

# NOTE: the stacking dim keeps its historical name 'year_valid' for downstream
# compatibility, but each entry is one lead. The 'lead' coordinate records
# which one, so a skipped lead cannot silently shift positions.
ds_ACC_all = xr.concat(ds_collection, dim='year_valid').assign_coords(
    lead=('year_valid', leads_used)
)

In [20]:
# Destination store for the SPEI ACC; the write itself is left disabled.
save_name = f'{base_dir}ACC_SPEI.zarr'
# ds_ACC_all.to_zarr(save_name)

<xarray.backends.zarr.ZarrStore at 0x14ad6791d5b0>