In [1]:
import os
import yaml
import copy
import time
import numpy as np
import pandas as pd
import xarray as xr

In [2]:
from xclim.indices import (
    standardized_precipitation_evapotranspiration_index,
    water_budget,
)

from xclim.indices.stats import (
    standardized_index_fit_params, 
    standardized_index
)

In [3]:
import warnings
warnings.filterwarnings("ignore", message="Converting a CFTimeIndex.*noleap.*")

In [4]:
from scipy.interpolate import griddata
from scipy.spatial import Delaunay
from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator

In [5]:
import matplotlib.pyplot as plt
%matplotlib inline

In [6]:
def detrend_linear(da, dim="time"):
    """
    Subtract an ordinary-least-squares straight line fitted along `dim`.

    The fit is done independently for every position of the other dimensions.
    The regressor is the integer position 0..N-1 along `dim` (not the
    coordinate values), which sidesteps datetime scaling. Non-finite samples
    are excluded from the fit; the fitted line is evaluated at every position
    and subtracted from `da`, so the result also has its mean removed.
    """
    position = xr.DataArray(
        np.arange(da.sizes[dim]), dims=dim, coords={dim: da[dim]}
    )

    finite = np.isfinite(da)
    x = position.where(finite)
    y = da.where(finite)

    x_bar = x.mean(dim, skipna=True)
    y_bar = y.mean(dim, skipna=True)
    dx = x - x_bar

    covariance = (dx * (y - y_bar)).mean(dim, skipna=True)
    variance = (dx ** 2).mean(dim, skipna=True)

    slope = covariance / variance
    offset = y_bar - slope * x_bar

    fitted = slope * position + offset
    return da - fitted

In [7]:
def process_vars(data):
    """
    Preprocess data variables for SPEI calculation.

    Floors the ``time`` coordinate to whole days. If the record does not start
    on January 1st, the partial first calendar year is dropped so the series
    starts on Jan 1 of the following year.

    Parameters
    ----------
    data : xr.Dataset
        Must hold a ``time`` coordinate convertible by ``pd.to_datetime`` and
        the variables ``precip``, ``tmin`` and ``tmax``.

    Returns
    -------
    tuple of xr.DataArray
        ``(precip, tmin, tmax)`` on the processed time axis. The input
        Dataset is not modified.
    """
    # Floor timestamps to calendar days.
    days = pd.to_datetime(data["time"].values).to_numpy().astype("datetime64[D]")
    # assign_coords returns a new Dataset, so the caller's object keeps its
    # original time axis (item assignment would mutate it in place).
    data = data.assign_coords(time=days)

    first = pd.Timestamp(data["time"].min().values)
    if (first.month, first.day) != (1, 1):
        # drop the incomplete first year
        data = data.sel(time=slice(str(first.year + 1), None))

    return data["precip"], data["tmin"], data["tmax"]
    
def calc_spei_and_params(
    precip, tmin, tmax,
    agg_freq, start_date, end_date,
    lat=None, dist="fisk", method="ML"):
    """
    Compute SPEI from daily precip/tmin/tmax using calibration-period parameters.

    Steps:
      1. Daily water budget via xclim ``water_budget`` (PET method "HG85");
         its ``units`` attribute is overwritten with "kg m-2 s-1".
      2. Distribution parameters are fitted on ``start_date..end_date`` only
         (monthly "MS" resampling, rolling ``window=agg_freq``).
      3. Those fixed parameters are applied to the full record.
      4. If any SPEI value is below -3, such values are set to NaN and filled by
         linear interpolation along ``time``.

    Returns
    -------
    (params, spei)
    """
    water_balance = water_budget(pr=precip, tasmin=tmin, tasmax=tmax, method="HG85", lat=lat)
    water_balance.attrs["units"] = "kg m-2 s-1"

    # options shared by the fit and the application step
    shared = dict(
        freq="MS",           # aggregate daily WB to monthly
        window=agg_freq,     # e.g. 12 for SPEI-12
        dist=dist,           # "fisk" = 3-parameter log-logistic
        method=method,       # "ML" (MLE) or "PWM" (L-moments)
    )

    calibration = water_balance.sel(time=slice(start_date, end_date))
    params = standardized_index_fit_params(calibration, **shared)

    spei = standardized_index(
        water_balance,
        params=params,          # fixed calibration parameters
        zero_inflated=False,    # water balance can be negative
        fitkwargs=None,
        cal_start=None,         # unused when params are supplied
        cal_end=None,
        **shared,
    )

    floor = -3
    if np.nanmin(spei.values) < floor:
        spei = spei.where(spei >= floor, np.nan).interpolate_na("time")

    return params, spei

def noleap_to_gregorian_add_leap(ds: xr.Dataset, time_dim: str = "time") -> xr.Dataset:
    """
    Convert a cftime.DatetimeNoLeap time coordinate to a Gregorian pandas
    DatetimeIndex and insert the missing Feb 29 of leap years.

    The Dataset is reindexed onto a complete daily index spanning its first to
    last timestamp. Only the newly inserted dates are filled, by linear
    interpolation in time. NaNs that already existed in the input are left
    untouched.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset whose ``time_dim`` index is a CFTimeIndex
        (``to_datetimeindex`` is called on it).
    time_dim : str
        Name of the time dimension.

    Returns
    -------
    xr.Dataset
        Dataset on a gap-free daily Gregorian index. Non-numeric variables keep
        missing values at the inserted dates.
    """
    # 1) CFTimeIndex -> pandas DatetimeIndex (a noleap axis has no Feb 29)
    cft = ds.indexes[time_dim]
    pd_idx = cft.to_datetimeindex()
    ds = ds.assign_coords({time_dim: pd_idx})
    if len(pd_idx) == 0:
        return ds

    # 2) Full daily Gregorian index (includes Feb 29 when applicable)
    full_idx = pd.date_range(pd_idx[0], pd_idx[-1], freq="D")

    # 3) Reindex: inserted days become NaN rows
    ds2 = ds.reindex({time_dim: full_idx})

    # True only at timestamps that were absent from the input
    inserted = ~ds2[time_dim].isin(pd_idx.values)

    # 4) Interpolate, but write the result only into inserted dates so genuine
    #    gaps in the source data are not silently filled.
    for v in ds2.data_vars:
        if ds2[v].dtype.kind not in "fiu" or time_dim not in ds2[v].dims:
            continue
        interp = ds2[v].interpolate_na(time_dim, method="linear")
        ds2[v] = ds2[v].where(~inserted, interp)

    return ds2

In [8]:
def time_to_lead_and_stack(ds_list, new_dim="member", labels=None, ref="first"):
    """
    Replace each dataset's absolute ``time`` axis with whole-day offsets and stack them.

    Parameters
    ----------
    ds_list : list[xr.Dataset]
        Datasets with a daily ``time`` coordinate.
    new_dim : str
        Name of the dimension the datasets are concatenated along.
    labels : sequence, optional
        Coordinate values for ``new_dim`` (one per dataset); defaults to 0..len-1.
    ref : "first" or datetime-like scalar
        "first": offsets are measured from each dataset's own earliest time.
        Otherwise: offsets are measured from this common reference for all datasets.

    Returns
    -------
    xr.Dataset
        Concatenation along ``new_dim``. ``lead_time`` (timedelta, whole days)
        replaces ``time``, which is dropped.
    """
    def _as_lead(ds):
        ds = ds.sortby("time")
        origin = ds["time"].isel(time=0) if ref == "first" else xr.DataArray(ref)
        offsets = (ds["time"] - origin).astype("timedelta64[D]")
        return (
            ds.assign_coords(lead_time=("time", offsets.data))
            .swap_dims({"time": "lead_time"})
            .drop_vars("time")
        )

    converted = [_as_lead(ds) for ds in ds_list]

    if labels is None:
        labels = list(range(len(converted)))

    return xr.concat(converted, dim=xr.IndexVariable(new_dim, labels))

In [9]:
def fill_nan_linear_2d(a):
    """
    Fill the non-finite cells of a 2-D array from its finite cells.

    Filling strategy:
      - Cells inside the convex hull of the finite cells are filled by piecewise-
        linear interpolation in (row, col) index space
        (``griddata(method="linear")``).
      - Cells outside the hull are filled with the nearest finite value.
      - If no triangulation exists (fewer than 3 points, or all points
        collinear), every cell is filled with the nearest finite value.

    Parameters
    ----------
    a : array_like, 2-D
        Input field; it is not modified.

    Returns
    -------
    numpy.ndarray
        Float copy of ``a`` with non-finite cells filled. If ``a`` has no finite
        cells (nothing to interpolate from) or no non-finite cells, an
        unmodified float copy is returned.
    """
    a = np.asarray(a, dtype=float)
    ny, nx = a.shape
    yy, xx = np.mgrid[0:ny, 0:nx]

    valid = np.isfinite(a)
    filled = a.copy()
    if valid.all() or not valid.any():
        return filled

    pts = np.column_stack((yy[valid], xx[valid]))   # (row, col) of valid cells
    vals = a[valid]
    missing = ~valid

    try:
        filled[missing] = griddata(pts, vals, (yy[missing], xx[missing]), method="linear")
    except RuntimeError:
        # scipy's QhullError subclasses RuntimeError: raised when the valid
        # points cannot be triangulated. Leave cells unfilled; the nearest
        # pass below handles them.
        filled[missing] = np.nan

    still_missing = ~np.isfinite(filled)
    if still_missing.any():
        filled[still_missing] = griddata(
            pts, vals, (yy[still_missing], xx[still_missing]), method="nearest"
        )
    return filled

In [10]:
# fn_ERA5 = f'/glade/derecho/scratch/ksha/EPRI_data/ERA5_grid/ERA5_{year}.zarr'
# ds_ERA5 = xr.open_zarr(fn_ERA5)
# fn_CESM = f'/glade/derecho/scratch/ksha/EPRI_data/CESM2_SMYLE/SMYLE_{year-1}-11-01_daily_ensemble.zarr'
# ds_CESM = xr.open_zarr(fn_CESM)

### Preparing gridded yearly metrics

In [11]:
# Site name -> coordinate pair; values appear to be (lat, lon) in degrees — confirm.
dict_loc = dict(
    Pituffik=(76.4, -68.575),
    Fairbanks=(64.75, -147.4),
    Guam=(13.475, 144.75),
    Yuma_PG=(33.125, -114.125),
    Fort_Bragg=(35.05, -79.115),
)
keys = [*dict_loc]

### Fort_Bragg: max total precip

In [12]:
key = 'Fort_Bragg'
# Root of the EPRI data tree. Set the EPRI_DATA_ROOT environment variable to
# run somewhere other than the original machine; the default is unchanged.
data_root = os.environ.get('EPRI_DATA_ROOT', '/glade/derecho/scratch/ksha/EPRI_data').rstrip('/')
dir_stn = f'{data_root}/METRICS/{key}/'
base_dir = f'{data_root}/CESM2_SMYLE/'

In [13]:
# For every SMYLE initialization year, take forecast years year+1 .. year+10
# and compute per-calendar-year maxima of daily PRECT and of its 3-day and
# 5-day running means. Results are stacked on a new `init_time` dimension.
init_years = np.arange(1958, 2020)
n_lead_years = 10
rolling_windows = (3, 5)   # days
ds_collection = []

for year in init_years:

    # get data and variable
    fn_CESM = base_dir + f'/{key}/SMYLE_{key}_{year}.zarr'
    ds_CESM = xr.open_zarr(fn_CESM)[["PRECT"]]
    # x 86400 s/day x 1000 mm/m: m/s -> mm/day (assumes PRECT is in m/s — confirm)
    ds_CESM['PRECT'] = ds_CESM['PRECT'] * 60*60*24 * 1000
    ds_CESM = ds_CESM.sel(time=slice(f'{year+1}-01-01T00', f'{year+n_lead_years}-12-31T00'))

    # ============ #
    # annual max of daily PRECT and of its k-day running means
    ds_TP_group = ds_CESM.groupby("time.year")
    ds_TP_max = ds_TP_group.max(dim="time", skipna=True)
    metrics = [ds_TP_max.rename({v: f"{v}_max" for v in ds_TP_max.data_vars})]
    for window in rolling_windows:
        ds_roll = ds_TP_group.map(
            # `w=window` binds the current window; min_periods=w leaves the
            # first w-1 days of each year NaN, which max(skipna) ignores
            lambda x, w=window: x.rolling(time=w, min_periods=w).mean().max(dim="time", skipna=True)
        )
        metrics.append(ds_roll.rename({v: f"{v}_{window}d_max" for v in ds_roll.data_vars}))

    # ============ #
    ds_merge = xr.merge(metrics)
    # replace calendar years with a lead-year index 0 .. n_lead_years-1
    ds_merge = ds_merge.assign_coords({'year': np.arange(n_lead_years)})
    ds_collection.append(ds_merge)

ds_all = xr.concat(ds_collection, dim='init_time')
# init_time is labelled by the first forecast year (init year + 1)
ds_all = ds_all.assign_coords({'init_time': init_years + 1})
ds_all = ds_all.chunk({'init_time': len(init_years), 'year': n_lead_years, 'lat': 21, 'lon': 16})

In [14]:
save_name = f'{dir_stn}CESM_TP_max.zarr'
# ds_all.to_zarr(save_name, mode='w')   # uncomment to (over)write the store
print(save_name)

### Fort_Bragg: max of detrended total precip

In [15]:
# Same precipitation extremes as the previous section, but computed after
# removing a linear trend across initializations. All inits are aligned on a
# common lead-time axis (days since each init's first forecast day) and
# detrended along init_time at every (lead_time, lat, lon) before the
# per-lead-year maxima are taken.
init_years = np.arange(1958, 2020)
n_lead_years = 10
rolling_windows = (3, 5)   # days
list_ds = []

for year in init_years:

    # get data and variable
    fn_CESM = base_dir + f'/{key}/SMYLE_{key}_{year}.zarr'
    ds_CESM = xr.open_zarr(fn_CESM)[["PRECT"]]
    # x 86400 s/day x 1000 mm/m: m/s -> mm/day (assumes PRECT is in m/s — confirm)
    ds_CESM['PRECT'] = ds_CESM['PRECT'] * 60*60*24 * 1000
    ds_CESM = ds_CESM.sel(time=slice(f'{year+1}-01-01T00', f'{year+n_lead_years}-12-31T00'))
    list_ds.append(ds_CESM)

ds_all = time_to_lead_and_stack(list_ds, new_dim="init_time")
# init_time is labelled by the first forecast year (init year + 1)
ds_all = ds_all.assign_coords({'init_time': init_years + 1})
# Integer lead year = floor(days / 365); this maps days to years exactly only
# for a 365-day (noleap) calendar — assumption, confirm for the SMYLE files.
lead_year = (ds_all["lead_time"] / np.timedelta64(365, "D")).astype(int)
ds_all = ds_all.assign_coords(lead_year=("lead_time", lead_year.data))

# ========================== #
# get detrend data
vars_ = list(ds_all.data_vars)
ds_all_detrend = ds_all.copy()
for v in vars_:
    ds_all_detrend[v] = detrend_linear(ds_all[v], dim="init_time")
ds_all_detrend = ds_all_detrend[vars_]

ds_all_detrend = ds_all_detrend.rename({'lead_time': 'time'})
ds_TP_group = ds_all_detrend.groupby("lead_year")

# max of daily values and of k-day running means within each lead year
ds_TP_max = ds_TP_group.max(dim="time", skipna=True)
metrics = [ds_TP_max.rename({v: f"{v}_max" for v in ds_TP_max.data_vars})]
for window in rolling_windows:
    ds_roll = ds_TP_group.map(
        # `w=window` binds the current window for this lambda
        lambda x, w=window: x.rolling(time=w, min_periods=w).mean().max(dim="time", skipna=True)
    )
    metrics.append(ds_roll.rename({v: f"{v}_{window}d_max" for v in ds_roll.data_vars}))

# ============ #
ds_merge = xr.merge(metrics)
ds_merge = ds_merge.rename({'lead_year': 'year'})
ds_merge = ds_merge.chunk({'init_time': len(init_years), 'year': n_lead_years, 'lat': 21, 'lon': 16})

In [None]:
save_name = f'{dir_stn}CESM_TP_detrend_max.zarr'
# ds_merge.to_zarr(save_name, mode='w')   # uncomment to (over)write the store
print(save_name)

In [None]:
# Re-prints the same path as the previous cell (duplicate output; this cell can be removed).
print(save_name)

### Fort_Bragg: SPEI

In [20]:
# Grid coordinates taken from one SMYLE file; later cells size their arrays
# from these and assume every init file shares the same grid.
ds_example = xr.open_zarr(f'{base_dir}/{key}/SMYLE_{key}_2019.zarr')
lat, lon = ds_example['lat'].values, ds_example['lon'].values

In [21]:
# Nx = len(lat)
# Ny = len(lon)

# SPEI_09 = np.empty((10, 12*(2020-1959+1), Nx, Ny))
# SPEI_09[...] = np.nan

# SPEI_48 = np.empty((10, 12*(2020-1959+1), Nx, Ny))
# SPEI_48[...] = np.nan

# # 62 years x 12 month

# for ind_lat in range(Nx):
#     for ind_lon in range(Ny):

#         t0 = time.process_time()
#         print(f'({ind_lat}, {ind_lon})')
        
#         for lead_year in range(0, 10):
              
#             start_date = f"{1959+lead_year}-01-01T00"
#             end_date = f"{2020+lead_year}-12-31T23"
            
#             ds_collection = []
            
#             for year_init in range(1958, 2020, 1):
                
#                 year_start = year_init + 1 + lead_year
#                 time_start = f'{year_start}-01-01T00'
#                 time_end = f'{year_start}-12-31T00'
                
#                 fn_CESM = base_dir + f'/{key}/SMYLE_{key}_{year_init}.zarr'
#                 ds_CESM = xr.open_zarr(fn_CESM)[['TREFHTMN', 'TREFHTMX', 'PRECT']].isel(lon=ind_lon, lat=ind_lat)
                
#                 ds_CESM = ds_CESM.sel(time=slice(time_start, time_end))
#                 ds_collection.append(ds_CESM)
            
#             ds_all = xr.concat(ds_collection, dim='time')
#             ds_all = ds_all.load()
            
#             cft = ds_all.indexes['time']
#             pd_idx = cft.to_datetimeindex()
#             ds_all = ds_all.assign_coords({'time': pd_idx})
            
#             lat_ref = ds_all['lat'].values
#             lat_mid = lat_ref # lat_ref[ind_lat]
#             time_vals = ds_all['time']
            
#             tmin = ds_all['TREFHTMN'].values
#             tmax = ds_all['TREFHTMX'].values
#             precip = ds_all['PRECT'].values
            
#             ds = xr.Dataset(
#                 {
#                     "precip": (("time",), precip*1e3, {"units": "kg m-2 s-1"}),
#                     "tmin":   (("time",), tmin-273.15, {"units": "degC"}),
#                     "tmax":   (("time",), tmax-273.15, {"units": "degC"}),
#                 },
#                 coords={"time": time_vals, "lat": lat_mid}
#             )
            
#             for v in ("precip", "tmin", "tmax"):
                
#                 ds[v] = ds[v].assign_coords(lat=lat_mid)
                
#                 ds[v]["lat"].attrs = {
#                     "standard_name": "latitude",
#                     "units": "degrees_north", "axis": "Y"
#                 }
            
#             precip, tmin, tmax = process_vars(ds)

#             # ---------------------------------- #
#             # 9-month SPEI (agg_freq=9 accumulation window)
#             params, spei = calc_spei_and_params(
#                 precip, tmin, tmax, 
#                 agg_freq=9, 
#                 start_date=start_date,
#                 end_date=end_date,
#                 lat=precip["lat"],
#                 dist="fisk", method="ML"
#             )
            
#             SPEI_09[lead_year, :, ind_lat, ind_lon] = spei.values

#             # ---------------------------------- #
#             # 48-month SPEI (agg_freq=48 accumulation window)
#             params, spei = calc_spei_and_params(
#                 precip, tmin, tmax, 
#                 agg_freq=48, 
#                 start_date=start_date,
#                 end_date=end_date,
#                 lat=precip["lat"],
#                 dist="fisk", method="ML"
#             )
            
#             SPEI_48[lead_year, :, ind_lat, ind_lon] = spei.values
            
#         t1 = time.process_time()
#         print(f"time: {t1 - t0:.6f} s")

In [22]:
Nx = len(lat)
Ny = len(lon)
n_months = 12 * (2020 - 1959 + 1)   # 62 years x 12 months per lead year

# (lead_year, month, lat, lon), NaN where no result was stored
SPEI_09 = np.full((10, n_months, Nx, Ny), np.nan)
SPEI_48 = np.full((10, n_months, Nx, Ny), np.nan)

# One precomputed file per latitude row. Each is presumably shaped
# (10, n_months, Ny, 2) with [..., 0] = SPEI-9 and [..., 1] = SPEI-48 — TODO confirm.
# Iterate over the actual grid size rather than a hard-coded row count.
for ind_lat in range(Nx):
    fn = dir_stn + f'temp_np/SPEI_{ind_lat}.npy'
    data_ = np.load(fn)
    SPEI_09[:, :, ind_lat, :] = data_[..., 0]
    SPEI_48[:, :, ind_lat, :] = data_[..., 1]

In [23]:
# Display the last file path loaded by the loop in the previous cell (quick sanity check).
fn

'/glade/derecho/scratch/ksha/EPRI_data/METRICS/Fort_Bragg/temp_np/SPEI_20.npy'

In [24]:
# ======================================= #
# Fill NaNs in each (lead_year, month) lat/lon field by 2-D interpolation.
# A field is filled only if it has at least one NaN and more than 4 non-NaN
# cells (n_nan < Nx*Ny - 4); otherwise it is left unchanged.

n_cells = Nx * Ny
for spei_arr in (SPEI_09, SPEI_48):
    for lead_year in range(spei_arr.shape[0]):
        for i_time in range(spei_arr.shape[1]):
            field = spei_arr[lead_year, i_time, :, :]
            n_nan = np.sum(np.isnan(field))
            if 0 < n_nan < n_cells - 4:
                spei_arr[lead_year, i_time, :, :] = fill_nan_linear_2d(field)

# ======================================= #
# numpy to ds

m_per_year = 12
n_lead = SPEI_09.shape[0]                  # lead years (10)
n_init = SPEI_09.shape[1] // m_per_year    # initializations (62)
spatial_shape = SPEI_09.shape[2:]
assert SPEI_09.shape[1] == n_init * m_per_year, "month axis is not a whole number of years"

# (lead, init*month, Nx, Ny) -> (lead, init, month, Nx, Ny)
# -> (init, lead, month, Nx, Ny) -> (init, lead*month, Nx, Ny),
# so lead_time_month = 12 * lead_year + month
SPEI_init_09 = (
    SPEI_09.reshape(n_lead, n_init, m_per_year, *spatial_shape)
    .transpose(1, 0, 2, 3, 4)
    .reshape(n_init, n_lead * m_per_year, *spatial_shape)
)
SPEI_init_48 = (
    SPEI_48.reshape(n_lead, n_init, m_per_year, *spatial_shape)
    .transpose(1, 0, 2, 3, 4)
    .reshape(n_init, n_lead * m_per_year, *spatial_shape)
)

# Per the (commented-out) generation loop, init index k is SMYLE init year
# 1958 + k. Label init_time by the first forecast year (init year + 1), the
# same convention used for CESM_TP_max.zarr and CESM_TP_detrend_max.zarr.
init_years = np.arange(1958, 1958 + n_init)

ds_SPEI = xr.Dataset(
    data_vars={
        "SPEI_09": (("init_time", "lead_time_month", "lat", "lon"), SPEI_init_09),
        "SPEI_48": (("init_time", "lead_time_month", "lat", "lon"), SPEI_init_48)
    },
    coords={
        "init_time": init_years + 1,
        "lead_time_month": np.arange(n_lead * m_per_year),
        "lat": lat,
        "lon": lon,
    },
)

# optional metadata (the window is the accumulation period, not a lag)
ds_SPEI["SPEI_09"].attrs["long_name"] = "SPEI with 9-month accumulation window"
ds_SPEI["SPEI_48"].attrs["long_name"] = "SPEI with 48-month accumulation window"

In [25]:
save_name = f'{dir_stn}CESM_SPEI.zarr'
# ds_SPEI.to_zarr(save_name)   # NOTE: no mode='w' here, unlike the other save cells; xarray's default fails if the store already exists
print(save_name)

/glade/derecho/scratch/ksha/EPRI_data/METRICS/Fort_Bragg/CESM_SPEI.zarr


In [26]:
# plt.plot(SPEI_09[lead_year, :, 10, 11])
# plt.plot(SPEI_48[lead_year, :, 10, 11])