In [1]:
import os
import zarr
import time
import numpy as np
import xarray as xr
import pandas as pd 
from glob import glob

In [2]:
from xclim.indices import (
    standardized_precipitation_evapotranspiration_index,
    water_budget,
)

from xclim.indices.stats import (
    standardized_index_fit_params, 
    standardized_index
)

In [3]:
import warnings
warnings.filterwarnings("ignore", message="Converting a CFTimeIndex.*noleap.*")

In [4]:
import matplotlib.pyplot as plt
%matplotlib inline

In [5]:
def time_to_lead_and_stack(ds_list, new_dim="member", labels=None, ref="first"):
    """
    ds_list: list[xr.Dataset], each has a 'time' coord with daily values
    new_dim: name of the new dimension to stack on
    labels: optional labels for new_dim (len == len(ds_list))
    ref: "first" (per-dataset first time) or a numpy/pandas datetime-like scalar
    """
    out = []
    for i, ds in enumerate(ds_list):
        ds = ds.sortby("time")

        if ref == "first":
            t0 = ds["time"].isel(time=0)
        else:
            # global reference (same for all ds)
            t0 = xr.DataArray(ref)

        lead = (ds["time"] - t0).astype("timedelta64[D]")  # daily deltas
        ds2 = ds.assign_coords(lead_time=("time", lead.data)).swap_dims({"time": "lead_time"})
        ds2 = ds2.drop_vars("time")  # optional: remove original coordinate

        out.append(ds2)

    if labels is None:
        labels = list(range(len(out)))

    stacked = xr.concat(out, dim=xr.IndexVariable(new_dim, labels))
    return stacked

In [6]:
def process_vars(data):
    """
    Preprocess data variables for SPEI calculation.
    """
    
    data["time"] = (pd.to_datetime(data["time"]).to_numpy().astype("datetime64[D]"))
    first_date = data.time.min().dt.strftime("%Y-%m-%d").values.item()
    
    if not first_date.endswith("-01-01"):
        first_year = int(first_date[:4])
        data = data.sel(time=slice(str(first_year + 1), None))

    precip = data["precip"]
    tmin = data["tmin"]
    tmax = data["tmax"]
    
    return precip, tmin, tmax
    
def calc_spei_and_params(
    precip, tmin, tmax, 
    agg_freq, start_date, end_date, 
    lat=None, dist="fisk", method="ML"):
    
    # --- Daily water budget for full period
    wb = water_budget(pr=precip, tasmin=tmin, tasmax=tmax, method="HG85", lat=lat)
    wb.attrs["units"] = "kg m-2 s-1"

    # --- Fit params on the calibration period (monthly aggregation + rolling handled inside)
    params = standardized_index_fit_params(
        wb.sel(time=slice(start_date, end_date)),
        freq="MS",           # aggregate daily WB to monthly
        window=agg_freq,     # e.g., 12 for SPEI-12
        dist=dist,           # "fisk" = 3-parameter log-logistic
        method=method,       # "PWM" (L-moments) is robust; "ML" for MLE
    )

    # --- Apply those params to the full record to get SPEI (ensures consistency)
    spei = standardized_index(
        wb,
        freq="MS",
        window=agg_freq,
        dist=dist,
        method=method,
        params=params,       # << use fixed calibration parameters
        zero_inflated=False,    # for water balance, not zero-inflated
        fitkwargs=None,         # or {}
        cal_start=None,         # not needed when params are supplied
        cal_end=None,   
    )

    if np.nanmin(spei.values) < -3:
        spei = spei.where(spei >= -3, np.nan).interpolate_na("time")

    return params, spei

def noleap_to_gregorian_add_leap(ds: xr.Dataset, time_dim: str = "time") -> xr.Dataset:
    """
    Convert cftime.DatetimeNoLeap time coord to pandas DatetimeIndex (Gregorian)
    and add Feb 29 for leap years by reindexing to a complete daily index and
    linearly filling inserted dates.
    """
    # 1) CFTimeIndex -> pandas DatetimeIndex (drops Feb 29 by definition)
    cft = ds.indexes[time_dim]                 # xarray.coding.cftimeindex.CFTimeIndex
    pd_idx = cft.to_datetimeindex()           # pandas.DatetimeIndex
    ds = ds.assign_coords({time_dim: pd_idx})

    # 2) Build full daily Gregorian index (includes Feb 29 when applicable)
    full_idx = pd.date_range(pd_idx[0], pd_idx[-1], freq="D")

    # 3) Reindex to insert missing days (Feb 29 becomes NaN rows)
    ds2 = ds.reindex({time_dim: full_idx})

    # 4) Fill inserted NaNs by linear interpolation in time
    #    (works for numeric variables; keeps non-numeric as-is)
    num_vars = [v for v in ds2.data_vars if ds2[v].dtype.kind in "fiu"]
    ds2[num_vars] = ds2[num_vars].interpolate_na(time_dim, method="linear")

    return ds2

In [9]:
station_names = ['Pituffik', 'Fairbanks', 'Guam', 'Yuma_PG' ,'Fort_Bragg'] # 
SPEI_48 = np.empty((len(station_names), 10, 12*(2020-1959+1)))

for i_stn, stn in enumerate(station_names):
    
    for lead_year in range(0, 10):
         
        start_date = f"{1959+lead_year}-01-01T00"
        end_date = f"{2020+lead_year}-12-31T23"
        
        ds_collection = []
        
        # ========================== #
        # get data
        list_ds = []
        for year_init in range(1958, 2020, 1):
            
            year_start = year_init + 1 + lead_year
            time_start = f'{year_start}-01-01T00'
            time_end = f'{year_start}-12-31T00'
            
            fn = f'/glade/campaign/ral/hap/ksha/EPRI_data/CESM_SMYLE_STN/{stn}_{year_init}.zarr'
            ds = xr.open_zarr(fn)[['TREFHTMN', 'TREFHTMX', 'PRECT']]
            ds = ds.sel(time=slice(time_start, time_end))
            list_ds.append(ds)
        
        ds_all = xr.concat(list_ds, dim='time')
        ds_all = ds_all.load()
        
        cft = ds_all.indexes['time']
        pd_idx = cft.to_datetimeindex()
        ds_all = ds_all.assign_coords({'time': pd_idx})
        
        lat_ref = ds_all['lat'].values
        lat_mid = lat_ref # lat_ref[ind_lat]
        time_vals = ds_all['time']
        
        tmin = ds_all['TREFHTMN'].values
        tmax = ds_all['TREFHTMX'].values
        precip = ds_all['PRECT'].values
        
        ds = xr.Dataset(
            {
                "precip": (("time",), precip*1e3, {"units": "kg m-2 s-1"}),
                "tmin":   (("time",), tmin-273.15, {"units": "degC"}),
                "tmax":   (("time",), tmax-273.15, {"units": "degC"}),
            },
            coords={"time": time_vals, "lat": lat_mid}
        )
        
        for v in ("precip", "tmin", "tmax"):
            
            ds[v] = ds[v].assign_coords(lat=lat_mid)
            
            ds[v]["lat"].attrs = {
                "standard_name": "latitude",
                "units": "degrees_north", "axis": "Y"
            }
        
        precip, tmin, tmax = process_vars(ds)
        # ---------------------------------- #
        # 48 month lagged SPEI
        params, spei = calc_spei_and_params(
            precip, tmin, tmax, 
            agg_freq=1, 
            start_date=start_date,
            end_date=end_date,
            lat=precip["lat"],
            dist="fisk", method="ML"
        )
        
        SPEI_48[i_stn, lead_year, :,] = spei.values

In [11]:
n_lead = 10
n_init = 62
m_per_year = 12

# (10, 744, Nx, Ny) -> (10, 62, 12, Nx, Ny)
tmp_48 = SPEI_48.reshape(n_lead, n_init, m_per_year)

# (10, 62, 12) -> (62, 10, 12)
tmp_48 = tmp_48.transpose(1, 0, 2)

# (62, 10, 12) -> (62, 120)
SPEI_init_48 = tmp_48.reshape(n_init, n_lead * m_per_year)

ds_SPEI = xr.Dataset(
    data_vars={
        "SPEI": (("init_time", "lead_time_month"), SPEI_init_48)
    },
    coords={
        "init_time": np.arange(1958, 2020),
        "lead_time_month": np.arange(120),
    },
)

ValueError: cannot reshape array of size 37200 into shape (10,62,12)

## CESM metrics

In [7]:
station_names = ['Pituffik', 'Fairbanks', 'Guam', 'Yuma_PG' ,'Fort_Bragg'] # 

for stn in station_names:
    
    t0 = time.perf_counter()
    
    base_dir = f'/glade/derecho/scratch/ksha/EPRI_data/METRICS_STN/{stn}/'
    
    # ========================== #
    # get data
    list_ds = []
    for year in range(1958, 2020):
        fn = f'/glade/campaign/ral/hap/ksha/EPRI_data/CESM_SMYLE_STN/{stn}_{year}.zarr'
        ds = xr.open_zarr(fn)
        list_ds.append(ds)
        
    ds_all = time_to_lead_and_stack(list_ds, new_dim="init_time")
    
    ds_all['PRECT'] = ds_all['PRECT'] * 60*60*24 * 1000 # mm per day
    
    ds_all = ds_all.assign_coords({'init_time': np.arange(1959, 2021)})
    lead_year = (ds_all["lead_time"] / np.timedelta64(365, "D")).astype(int)
    ds_all = ds_all.assign_coords(lead_year=("lead_time", lead_year.data))
    
    # ========================== #
    # get anomaly
    ds_all_anom = ds_all.copy()
    vars_ = list(ds_all.keys())
    for v in vars_:
        ds_all_anom[v] = ds_all_anom[v] - ds_all_anom[v].mean(["init_time"])
    ds_all_anom = ds_all_anom[vars_]
    
    # ========================== #
    # get detrend data
    ds_all_detrend = ds_all.copy()
    vars_ = list(ds_all.keys())
    for v in vars_:
        ds_all_detrend[v] = detrend_linear(ds_all[v], dim="init_time")
    ds_all_detrend = ds_all_detrend[vars_]
    
    # ======================= #
    # metrics
    ds_group = ds_all.groupby("lead_year")
    ds_max  = ds_group.max(dim="lead_time",  skipna=True)
    ds_min  = ds_group.min(dim="lead_time",  skipna=True)
    ds_mean  = ds_group.mean(dim="lead_time",  skipna=True)
    ds_30d = ds_group.map(
        lambda x: x.rolling(lead_time=30, min_periods=30).mean().max(dim="lead_time", skipna=True)
    )
    ds_min = ds_min.rename({'TREFHTMN': 'TREFHTMN_min', 'TREFHT': 'TREFHT_min'})[['TREFHTMN_min', 'TREFHT_min']]
    ds_max = ds_max.rename({'PRECT': 'PRECT_max', 'TREFHTMX': 'TREFHTMX_max', 'TREFHT': 'TREFHT_max'})[['PRECT_max', 'TREFHTMX_max', 'TREFHT_max']]
    ds_30d = ds_30d.rename({'TREFHT': 'TREFHT_30d', 'PRECT': 'PRECT_30d'})[['TREFHT_30d', 'PRECT_30d']]
    ds_mean = ds_mean.rename({'PRECT': 'PRECT_mean', 'TREFHT': 'TREFHT_mean'})[['PRECT_mean', 'TREFHT_mean']]
    ds_metrics = xr.merge([ds_min, ds_max, ds_30d, ds_mean])
    ds_metrics = ds_metrics.rename({v: f"{v}_default" for v in ds_metrics.data_vars})
    
    # ========================== #
    # anomaly metrics
    ds_group = ds_all_anom.groupby("lead_year")
    ds_max  = ds_group.max(dim="lead_time",  skipna=True)
    ds_min  = ds_group.min(dim="lead_time",  skipna=True)
    ds_mean  = ds_group.mean(dim="lead_time",  skipna=True)
    ds_30d = ds_group.map(
        lambda x: x.rolling(lead_time=30, min_periods=30).mean().max(dim="lead_time", skipna=True)
    )
    # ds_min = ds_min.rename({'TREFHTMN': 'TREFHTMN_min'})[['TREFHTMN_min',]]
    # ds_max = ds_max.rename({'PRECT': 'PRECT_max', 'TREFHTMX': 'TREFHTMX_max'})[['PRECT_max', 'TREFHTMX_max']]
    # ds_30d = ds_30d.rename({'TREFHTMX': 'TREFHTMX_30d', 'PRECT': 'PRECT_30d'})[['TREFHTMX_30d', 'PRECT_30d']]
    # ds_mean = ds_mean.rename({'PRECT': 'PRECT_mean', 'TREFHT': 'TREFHT_mean'})[['PRECT_mean', 'TREFHT_mean']]
    ds_min = ds_min.rename({'TREFHTMN': 'TREFHTMN_min', 'TREFHT': 'TREFHT_min'})[['TREFHTMN_min', 'TREFHT_min']]
    ds_max = ds_max.rename({'PRECT': 'PRECT_max', 'TREFHTMX': 'TREFHTMX_max', 'TREFHT': 'TREFHT_max'})[['PRECT_max', 'TREFHTMX_max', 'TREFHT_max']]
    ds_30d = ds_30d.rename({'TREFHT': 'TREFHT_30d', 'PRECT': 'PRECT_30d'})[['TREFHT_30d', 'PRECT_30d']]
    ds_mean = ds_mean.rename({'PRECT': 'PRECT_mean', 'TREFHT': 'TREFHT_mean'})[['PRECT_mean', 'TREFHT_mean']]
    ds_metrics_anom = xr.merge([ds_min, ds_max, ds_30d, ds_mean])
    ds_metrics_anom = ds_metrics_anom.rename({v: f"{v}_anom" for v in ds_metrics_anom.data_vars})
    
    # ========================== #
    # detrended metrics
    ds_group = ds_all_detrend.groupby("lead_year")
    ds_max  = ds_group.max(dim="lead_time",  skipna=True)
    ds_min  = ds_group.min(dim="lead_time",  skipna=True)
    ds_mean  = ds_group.mean(dim="lead_time",  skipna=True)
    ds_30d = ds_group.map(
        lambda x: x.rolling(lead_time=30, min_periods=30).mean().max(dim="lead_time", skipna=True)
    )
    ds_min = ds_min.rename({'TREFHTMN': 'TREFHTMN_min', 'TREFHT': 'TREFHT_min'})[['TREFHTMN_min', 'TREFHT_min']]
    ds_max = ds_max.rename({'PRECT': 'PRECT_max', 'TREFHTMX': 'TREFHTMX_max', 'TREFHT': 'TREFHT_max'})[['PRECT_max', 'TREFHTMX_max', 'TREFHT_max']]
    ds_30d = ds_30d.rename({'TREFHT': 'TREFHT_30d', 'PRECT': 'PRECT_30d'})[['TREFHT_30d', 'PRECT_30d']]
    ds_mean = ds_mean.rename({'PRECT': 'PRECT_mean', 'TREFHT': 'TREFHT_mean'})[['PRECT_mean', 'TREFHT_mean']]
    ds_metrics_detrend = xr.merge([ds_min, ds_max, ds_30d, ds_mean])
    ds_metrics_detrend = ds_metrics_detrend.rename({v: f"{v}_detrend" for v in ds_metrics_detrend.data_vars})
    
    # ========================== #
    # save
    ds_final = xr.merge([ds_metrics, ds_metrics_anom, ds_metrics_detrend])
    save_name = base_dir + 'CESM_metrics.zarr'
    ds_final.to_zarr(save_name, mode='w')
    print(save_name)
    
    t1 = time.perf_counter()
    print(f"Elapsed: {t1 - t0:.6f} s")

/glade/derecho/scratch/ksha/EPRI_data/METRICS_STN/Pituffik/CESM_metrics.zarr
Elapsed: 58.345111 s
/glade/derecho/scratch/ksha/EPRI_data/METRICS_STN/Fairbanks/CESM_metrics.zarr
Elapsed: 57.277154 s
/glade/derecho/scratch/ksha/EPRI_data/METRICS_STN/Guam/CESM_metrics.zarr
Elapsed: 59.746165 s
/glade/derecho/scratch/ksha/EPRI_data/METRICS_STN/Yuma_PG/CESM_metrics.zarr
Elapsed: 59.272855 s
/glade/derecho/scratch/ksha/EPRI_data/METRICS_STN/Fort_Bragg/CESM_metrics.zarr
Elapsed: 59.190298 s
