In [1]:
import os
import zarr
from glob import glob

import numpy as np
import xarray as xr

In [2]:
import time

In [3]:
import pandas as pd 

In [4]:
import matplotlib.pyplot as plt
%matplotlib inline

In [5]:
def time_to_lead_and_stack(ds_list, new_dim="member", labels=None, ref="first"):
    """
    Re-index each dataset from absolute 'time' to relative 'lead_time' and stack them.

    Parameters
    ----------
    ds_list : list[xr.Dataset]
        Datasets that each have a 'time' coordinate (daily values). Each one is
        sorted by time before the lead times are computed.
    new_dim : str
        Name of the new dimension the datasets are concatenated along.
    labels : sequence, optional
        Coordinate labels for `new_dim`; must have len(ds_list) entries.
        Defaults to 0 .. len(ds_list) - 1.
    ref : "first" or datetime-like scalar
        "first": lead time is measured from each dataset's own first time step.
        Anything else is used as one global reference time shared by all datasets.

    Returns
    -------
    xr.Dataset
        The datasets concatenated along `new_dim`, with dimension 'lead_time'
        (whole-day timedelta values) in place of 'time'. The original 'time'
        coordinate is dropped. Datasets of different lengths are aligned on
        lead_time by xr.concat (NOTE(review): outer join by default, so shorter
        ones are NaN-padded; confirm for the installed xarray version).

    Raises
    ------
    ValueError
        If `labels` is given and its length differs from len(ds_list).
    """
    # Check the type before comparing: comparing e.g. an np.datetime64 with the
    # string "first" can make numpy try to parse "first" as a date (raising) or
    # warn, depending on the numpy version.
    per_dataset_ref = isinstance(ref, str) and ref == "first"
    # The global reference is the same for every dataset, so build it once.
    global_t0 = None if per_dataset_ref else xr.DataArray(ref)

    if labels is None:
        labels = list(range(len(ds_list)))
    elif len(labels) != len(ds_list):
        raise ValueError(
            f"labels has {len(labels)} entries but ds_list has {len(ds_list)} datasets"
        )

    out = []
    for ds in ds_list:
        ds = ds.sortby("time")
        t0 = ds["time"].isel(time=0) if per_dataset_ref else global_t0

        lead = (ds["time"] - t0).astype("timedelta64[D]")  # daily deltas
        ds2 = ds.assign_coords(lead_time=("time", lead.data)).swap_dims({"time": "lead_time"})
        out.append(ds2.drop_vars("time"))  # remove the original absolute-time coordinate

    return xr.concat(out, dim=xr.IndexVariable(new_dim, labels))

In [6]:
def detrend_linear(da, dim="time"):
    """
    Remove a best-fit (least-squares) linear trend along `dim` for each grid point.

    The regressor is the integer position 0..N-1 along `dim`, not the coordinate
    values. This avoids datetime scaling issues, and it means samples are treated
    as equally spaced. Non-finite samples are excluded from the fit and passed
    through unchanged (NaN stays NaN).

    Parameters
    ----------
    da : xr.DataArray
        Data to detrend.
    dim : str
        Dimension along which the trend is fitted.

    Returns
    -------
    xr.DataArray
        `da` minus the fitted line. The intercept also removes the mean, so the
        result is a detrended anomaly. Where only one finite sample exists along
        `dim`, no trend is defined and the slope is taken as 0, so that sample
        becomes 0 (mean removed). Where no finite sample exists, the result is NaN.
    """
    t = xr.DataArray(np.arange(da.sizes[dim]), dims=dim, coords={dim: da[dim]})

    valid = np.isfinite(da)
    # Mask t with the same pattern as da so the means and covariance use matching samples.
    t_valid = t.where(valid)
    da_valid = da.where(valid)

    t_mean = t_valid.mean(dim, skipna=True)
    y_mean = da_valid.mean(dim, skipna=True)

    t_anom = t_valid - t_mean
    cov = (t_anom * (da_valid - y_mean)).mean(dim, skipna=True)
    var = (t_anom ** 2).mean(dim, skipna=True)

    # var == 0 when there is at most one finite sample. Use slope 0 there instead of
    # 0/0 = NaN, which would otherwise also blank out that single valid sample.
    slope = (cov / var.where(var > 0)).fillna(0.0)
    intercept = y_mean - slope * t_mean

    trend = slope * t + intercept
    return da - trend

## CESM metrics

In [7]:
station_names = ['Pituffik', 'Fairbanks', 'Guam', 'Yuma_PG', 'Fort_Bragg']

# Input / output locations and analysis settings.
CESM_DIR = '/glade/campaign/ral/hap/ksha/EPRI_data/CESM_SMYLE_daily'
METRICS_DIR = '/glade/derecho/scratch/ksha/EPRI_data/METRICS_STN'
INIT_YEARS = range(1958, 2020)   # years appearing in the input file names
DAYS_PER_LEAD_YEAR = 365         # assumes a 365-day (noleap) model calendar -- TODO confirm
ROLLING_WINDOW_DAYS = 30

# Metric -> variables it is reported for. Dict/list order sets the output variable order.
METRIC_VARS = {
    'min': ['TREFHTMN'],
    'max': ['PRECT', 'TREFHTMX'],
    '30d': ['TREFHTMX', 'PRECT'],
    'mean': ['PRECT', 'TREFHT'],
}


def _max_rolling_mean(x, window=ROLLING_WINDOW_DAYS):
    """Largest `window`-day running mean along lead_time (only complete windows count)."""
    return x.rolling(lead_time=window, min_periods=window).mean().max(dim="lead_time", skipna=True)


def _anomaly(ds, dim="init_time"):
    """Return the data variables of `ds`, each minus its mean over `dim`."""
    names = list(ds.data_vars)
    return ds.assign({v: ds[v] - ds[v].mean([dim]) for v in names})[names]


def _detrended(ds, dim="init_time"):
    """Return the data variables of `ds`, each with a linear trend along `dim` removed."""
    names = list(ds.data_vars)
    return ds.assign({v: detrend_linear(ds[v], dim=dim) for v in names})[names]


def _lead_year_metrics(ds, suffix):
    """
    Per-lead-year summary metrics of `ds`; the lead_time dimension is reduced
    within each 'lead_year' group.

    Output variables are named '<VAR>_<metric>_<suffix>', e.g. 'PRECT_max_anom'.
    """
    pieces = []
    for metric, names in METRIC_VARS.items():
        grouped = ds[names].groupby("lead_year")
        if metric == '30d':
            # Rolling windows are computed inside each lead year, so they never
            # straddle a lead-year boundary.
            reduced = grouped.map(_max_rolling_mean)
        else:
            # 'min' / 'max' / 'mean' map directly onto the groupby reductions.
            reduced = getattr(grouped, metric)(dim="lead_time", skipna=True)
        pieces.append(reduced.rename({v: f"{v}_{metric}_{suffix}" for v in names}))
    return xr.merge(pieces)


for stn in station_names:

    t0 = time.perf_counter()

    # ========================== #
    # get data: one store per initialization year, stacked along init_time.
    # init_time labels are file year + 1 (1958 -> 1959), derived from INIT_YEARS so
    # they cannot drift from the files. NOTE(review): presumably the label is the
    # first full target year of each initialization -- confirm.
    list_ds = [xr.open_zarr(f'{CESM_DIR}/{stn}_{year}.zarr') for year in INIT_YEARS]
    ds_all = time_to_lead_and_stack(
        list_ds, new_dim="init_time", labels=np.asarray(INIT_YEARS) + 1
    )
    lead_year = (ds_all["lead_time"] / np.timedelta64(DAYS_PER_LEAD_YEAR, "D")).astype(int)
    ds_all = ds_all.assign_coords(lead_year=("lead_time", lead_year.data))

    # ========================== #
    # anomaly and detrended versions of the data
    ds_all_anom = _anomaly(ds_all)
    ds_all_detrend = _detrended(ds_all)

    # ========================== #
    # metrics for each version
    ds_metrics = _lead_year_metrics(ds_all, 'default')
    ds_metrics_anom = _lead_year_metrics(ds_all_anom, 'anom')
    ds_metrics_detrend = _lead_year_metrics(ds_all_detrend, 'detrend')

    # ========================== #
    # save
    ds_final = xr.merge([ds_metrics, ds_metrics_anom, ds_metrics_detrend])
    save_name = f'{METRICS_DIR}/{stn}/CESM_metrics.zarr'
    ds_final.to_zarr(save_name, mode='w')
    print(save_name)

    t1 = time.perf_counter()
    print(f"Elapsed: {t1 - t0:.6f} s")

/glade/derecho/scratch/ksha/EPRI_data/METRICS_STN/Pituffik/CESM_metrics.zarr
Elapsed: 43.043062 s
/glade/derecho/scratch/ksha/EPRI_data/METRICS_STN/Fairbanks/CESM_metrics.zarr
Elapsed: 41.059032 s
/glade/derecho/scratch/ksha/EPRI_data/METRICS_STN/Guam/CESM_metrics.zarr
Elapsed: 40.405697 s
/glade/derecho/scratch/ksha/EPRI_data/METRICS_STN/Yuma_PG/CESM_metrics.zarr
Elapsed: 41.526583 s
/glade/derecho/scratch/ksha/EPRI_data/METRICS_STN/Fort_Bragg/CESM_metrics.zarr
Elapsed: 41.141597 s
