In [1]:
import os
import zarr
from glob import glob

import numpy as np
import xarray as xr
import pandas as pd 

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
# def detrend_linear(da, dim="time"):
#     """
#     Remove a best-fit linear trend along `dim` for each grid point.
#     Uses an index-based time axis (0..N-1) to avoid datetime scaling issues.
#     """
#     t = xr.DataArray(np.arange(da.sizes[dim]), dims=dim, coords={dim: da[dim]})

#     valid = np.isfinite(da)
#     t_valid = t.where(valid)
#     da_valid = da.where(valid)

#     t_mean = t_valid.mean(dim, skipna=True)
#     y_mean = da_valid.mean(dim, skipna=True)

#     cov = ((t_valid - t_mean) * (da_valid - y_mean)).mean(dim, skipna=True)
#     var = ((t_valid - t_mean) ** 2).mean(dim, skipna=True)

#     slope = cov / var
#     intercept = y_mean - slope * t_mean

#     trend = slope * t + intercept
#     return da - trend

def detrend_linear_doy(da: xr.DataArray, dim: str = "time", keep_mean: bool = True) -> xr.DataArray:
    """
    Remove a linear trend from `da`, fitted separately for each day of year.

    All samples that share a calendar day-of-year form one group. Within each
    group the values are regressed against the calendar year (ordinary least
    squares, slope = cov(year, value) / var(year)) and the fitted trend is
    subtracted. Missing values in `da` are excluded from the year statistics
    as well, so the x and y moments are computed over the same samples.

    Parameters
    ----------
    da : xr.DataArray
        Input data. `da[dim]` must be datetime-like (it is accessed through
        the `.dt` accessor for `dayofyear` and `year`).
    dim : str, default "time"
        Name of the time dimension along which the trend is removed.
    keep_mean : bool, default True
        If True, only the slope term is removed, so each day-of-year group
        keeps its own mean. If False, the full fitted line (slope and
        intercept) is removed, leaving residuals around zero.

    Returns
    -------
    xr.DataArray
        Detrended data aligned with `da`.

    Notes
    -----
    A group whose valid samples all belong to a single year has zero variance
    in the predictor, so no slope can be estimated. The slope is taken as 0
    for such groups (nothing to detrend) instead of letting 0/0 = NaN wipe out
    otherwise valid values.
    """
    t = da[dim]
    doy = t.dt.dayofyear
    x = t.dt.year.astype("float64")  # predictor: calendar year

    # Restrict the predictor statistics to samples where `da` is present,
    # so x and y moments are computed over the same set of points.
    mask = da.notnull()

    # Group means per day of year. `x.where(mask)` broadcasts against `da`,
    # so x_mean carries dayofyear plus any other dims of `da`.
    x_mean = x.where(mask).groupby(doy).mean(dim, skipna=True)
    y_mean = da.groupby(doy).mean(dim, skipna=True)

    # Map the group means back onto the time axis (vectorized lookup by DOY).
    x_anom = x - x_mean.sel(dayofyear=doy)
    y_anom = da - y_mean.sel(dayofyear=doy)

    # Least-squares slope per DOY: slope = cov(x, y) / var(x)
    numer = (x_anom * y_anom).where(mask).groupby(doy).mean(dim, skipna=True)
    denom = (x_anom ** 2).where(mask).groupby(doy).mean(dim, skipna=True)

    # Zero predictor variance (all valid samples in one year) gives 0/0.
    # Mask those denominators and fall back to a zero slope so valid data
    # in that group passes through instead of becoming NaN.
    slope = (numer / denom.where(denom > 0)).fillna(0.0)

    if keep_mean:
        # Remove slope only; preserves the DOY mean exactly
        return da - slope.sel(dayofyear=doy) * x_anom

    # Remove full fitted line (slope + intercept)
    intercept = y_mean - slope * x_mean
    fit = slope.sel(dayofyear=doy) * x + intercept.sel(dayofyear=doy)
    return da - fit

## ERA5 daily data to annual metrics

In [4]:
# station_names = ['Pituffik', 'Fairbanks', 'Guam', 'Yuma_PG' ,'Fort_Bragg'] # 

# for stn in station_names:
#     base_dir = f'/glade/derecho/scratch/ksha/EPRI_data/METRICS_STN/{stn}/'
    
#     # ========================== #
#     # get data
#     list_ds = []
#     for year in range(1958, 2025):
#         fn = f'/glade/campaign/ral/hap/ksha/EPRI_data/ERA5_daily/{stn}_{year}.zarr'
#         ds = xr.open_zarr(fn)
#         list_ds.append(ds)
        
#     ds_all = xr.concat(list_ds, dim='time')
#     ds_all['PRECT'] = ds_all['PRECT'] * 1000 # mm per day
    
#     # ========================== #
#     # get anomaly
#     ds_all_anom = ds_all.copy()
#     vars_ = list(ds_all.keys())
#     for v in vars_:
#         clim = ds_all[v].groupby("time.dayofyear").mean("time", keep_attrs=True)
#         ds_all_anom[v] = ds_all[v].groupby("time.dayofyear") - clim
#     ds_all_anom = ds_all_anom[vars_]
    
#     # ========================== #
#     # get detrend data
#     ds_all_detrend = ds_all.copy()
#     vars_ = list(ds_all.keys())
#     for v in vars_:
#         ds_all_detrend[v] = detrend_linear_doy(ds_all[v], dim="time", keep_mean=False)
#         # detrend_linear(ds_all[v], dim="time")
#     ds_all_detrend = ds_all_detrend[vars_]
    
#     # ======================= #
#     # metrics
#     ds_group = ds_all.groupby("time.year")
#     ds_max  = ds_group.max(dim="time",  skipna=True)
#     ds_min  = ds_group.min(dim="time",  skipna=True)
#     ds_mean  = ds_group.mean(dim="time",  skipna=True)
#     ds_30d = ds_group.map(
#         lambda x: x.rolling(time=30, min_periods=30).mean().max(dim="time", skipna=True)
#     )
#     # ds_min = ds_min.rename({'TREFHTMN': 'TREFHTMN_min'})[['TREFHTMN_min',]]
#     # ds_max = ds_max.rename({'PRECT': 'PRECT_max', 'TREFHTMX': 'TREFHTMX_max'})[['PRECT_max', 'TREFHTMX_max']]
#     # ds_30d = ds_30d.rename({'TREFHTMX': 'TREFHTMX_30d', 'PRECT': 'PRECT_30d'})[['TREFHTMX_30d', 'PRECT_30d']]
#     # ds_mean = ds_mean.rename({'PRECT': 'PRECT_mean', 'TREFHT': 'TREFHT_mean'})[['PRECT_mean', 'TREFHT_mean']]
#     ds_min = ds_min.rename({'TREFHTMN': 'TREFHTMN_min', 'TREFHT': 'TREFHT_min'})[['TREFHTMN_min', 'TREFHT_min']]
#     ds_max = ds_max.rename({'PRECT': 'PRECT_max', 'TREFHTMX': 'TREFHTMX_max', 'TREFHT': 'TREFHT_max'})[['PRECT_max', 'TREFHTMX_max', 'TREFHT_max']]
#     ds_30d = ds_30d.rename({'TREFHT': 'TREFHT_30d', 'PRECT': 'PRECT_30d'})[['TREFHT_30d', 'PRECT_30d']]
#     ds_mean = ds_mean.rename({'PRECT': 'PRECT_mean', 'TREFHT': 'TREFHT_mean'})[['PRECT_mean', 'TREFHT_mean']]
#     ds_metrics = xr.merge([ds_min, ds_max, ds_30d, ds_mean])
#     ds_metrics = ds_metrics.rename({v: f"{v}_default" for v in ds_metrics.data_vars})

#     # ========================== #
#     # anomaly metrics
#     ds_group = ds_all_anom.groupby("time.year")
#     ds_max  = ds_group.max(dim="time",  skipna=True)
#     ds_min  = ds_group.min(dim="time",  skipna=True)
#     ds_mean  = ds_group.mean(dim="time",  skipna=True)
#     ds_30d = ds_group.map(
#         lambda x: x.rolling(time=30, min_periods=30).mean().max(dim="time", skipna=True)
#     )
#     ds_min = ds_min.rename({'TREFHTMN': 'TREFHTMN_min', 'TREFHT': 'TREFHT_min'})[['TREFHTMN_min', 'TREFHT_min']]
#     ds_max = ds_max.rename({'PRECT': 'PRECT_max', 'TREFHTMX': 'TREFHTMX_max', 'TREFHT': 'TREFHT_max'})[['PRECT_max', 'TREFHTMX_max', 'TREFHT_max']]
#     ds_30d = ds_30d.rename({'TREFHT': 'TREFHT_30d', 'PRECT': 'PRECT_30d'})[['TREFHT_30d', 'PRECT_30d']]
#     ds_mean = ds_mean.rename({'PRECT': 'PRECT_mean', 'TREFHT': 'TREFHT_mean'})[['PRECT_mean', 'TREFHT_mean']]
#     ds_metrics_anom = xr.merge([ds_min, ds_max, ds_30d, ds_mean])
#     ds_metrics_anom = ds_metrics_anom.rename({v: f"{v}_anom" for v in ds_metrics_anom.data_vars})

#     # ========================== #
#     # detrended metrics
#     ds_group = ds_all_detrend.groupby("time.year")
#     ds_max  = ds_group.max(dim="time",  skipna=True)
#     ds_min  = ds_group.min(dim="time",  skipna=True)
#     ds_mean  = ds_group.mean(dim="time",  skipna=True)
#     ds_30d = ds_group.map(
#         lambda x: x.rolling(time=30, min_periods=30).mean().max(dim="time", skipna=True)
#     )
#     ds_min = ds_min.rename({'TREFHTMN': 'TREFHTMN_min', 'TREFHT': 'TREFHT_min'})[['TREFHTMN_min', 'TREFHT_min']]
#     ds_max = ds_max.rename({'PRECT': 'PRECT_max', 'TREFHTMX': 'TREFHTMX_max', 'TREFHT': 'TREFHT_max'})[['PRECT_max', 'TREFHTMX_max', 'TREFHT_max']]
#     ds_30d = ds_30d.rename({'TREFHT': 'TREFHT_30d', 'PRECT': 'PRECT_30d'})[['TREFHT_30d', 'PRECT_30d']]
#     ds_mean = ds_mean.rename({'PRECT': 'PRECT_mean', 'TREFHT': 'TREFHT_mean'})[['PRECT_mean', 'TREFHT_mean']]
#     ds_metrics_detrend = xr.merge([ds_min, ds_max, ds_30d, ds_mean])
#     ds_metrics_detrend = ds_metrics_detrend.rename({v: f"{v}_detrend" for v in ds_metrics_detrend.data_vars})

#     # ========================== #
#     # save
#     ds_final = xr.merge([ds_metrics, ds_metrics_anom, ds_metrics_detrend])
#     save_name = base_dir + 'metrics.zarr'
#     ds_final.to_zarr(save_name, mode='w')
#     print(save_name)

In [6]:
# Station identifiers. Each name is substituted into the per-year input zarr
# path and into the output directory by the per-station loop in this cell.
station_names = ['Pituffik', 'Fairbanks', 'Guam', 'Yuma_PG' ,'Fort_Bragg']

def annual_metrics(ds_in, suffix):
    """
    Summarise a time series dataset into one value per calendar year.

    Four reductions are computed and merged into a single dataset:

    * yearly minimum of ``TREFHTMN`` and ``TREFHT``
    * yearly maximum of ``PRECT``, ``TREFHTMX`` and ``TREFHT``
    * yearly maximum of the 30-step rolling mean of ``TREFHT`` and ``PRECT``
      (30 days when the input is daily)
    * yearly mean of ``PRECT`` and ``TREFHT``

    The rolling mean is taken over the whole series before grouping by year,
    so a window may include steps from the preceding year. Windows with fewer
    than 30 steps are NaN and are skipped by the yearly maximum.

    Every output variable is named ``<VAR>_<stat>_<suffix>``, e.g.
    ``TREFHT_30d_default``.

    Parameters
    ----------
    ds_in : xarray Dataset with a ``time`` dimension containing the four
        variables listed above.
    suffix : str appended (after an underscore) to every output variable name.
    """
    yearly = ds_in.groupby("time.year")
    running_30 = ds_in.rolling(time=30, min_periods=30).mean()

    # (reduced dataset, {input variable: output name}) in merge order.
    recipes = (
        (
            yearly.min("time", skipna=True),
            {'TREFHTMN': 'TREFHTMN_min', 'TREFHT': 'TREFHT_min'},
        ),
        (
            yearly.max("time", skipna=True),
            {'PRECT': 'PRECT_max', 'TREFHTMX': 'TREFHTMX_max', 'TREFHT': 'TREFHT_max'},
        ),
        (
            running_30.groupby("time.year").max("time", skipna=True),
            {'TREFHT': 'TREFHT_30d', 'PRECT': 'PRECT_30d'},
        ),
        (
            yearly.mean("time", skipna=True),
            {'PRECT': 'PRECT_mean', 'TREFHT': 'TREFHT_mean'},
        ),
    )

    # Rename, then keep only the renamed variables of each reduction.
    pieces = []
    for reduced, names in recipes:
        pieces.append(reduced.rename(names)[list(names.values())])

    merged = xr.merge(pieces)

    with_suffix = {}
    for var in merged.data_vars:
        with_suffix[var] = f"{var}_{suffix}"
    return merged.rename(with_suffix)

# For each station: load the per-year ERA5 daily zarr stores, compute the
# annual metrics, and write them to <base_dir>/metrics.zarr.
# NOTE: the loop variables (ds_all, ds_final, save_name, ...) stay in the
# notebook namespace afterwards and hold the values from the last station.
for stn in station_names:
    # Output directory for this station (trailing slash is relied on below).
    base_dir = f'/glade/derecho/scratch/ksha/EPRI_data/METRICS_STN/{stn}/'
    
    # Open one zarr store per year, 1958 through 2024 inclusive.
    list_ds = []
    for year in range(1958, 2025):
        fn = f'/glade/campaign/ral/hap/ksha/EPRI_data/ERA5_daily/{stn}_{year}.zarr'
        ds = xr.open_zarr(fn)
        list_ds.append(ds)
        
    # Stitch the yearly pieces into one series and keep only the variables
    # that annual_metrics reads.
    ds_all = xr.concat(list_ds, dim='time')
    ds_all = ds_all[['PRECT', 'TREFHT', 'TREFHTMX', 'TREFHTMN']]

    # Scale PRECT by 1000; presumably the stored unit is m/day -- confirm.
    ds_all['PRECT'] = ds_all['PRECT'] * 1000  # mm per day
    # Use a single chunk along time (-1 means the whole dimension).
    # NOTE(review): presumably to keep the rolling/groupby work in
    # annual_metrics from being split across chunks -- confirm.
    ds_all = ds_all.chunk({"time": -1})

    # ======================= #
    # metrics (compute via helper; avoids repeating logic)
    # Output variables get the "_default" suffix.
    ds_metrics_default = annual_metrics(ds_all, "default")

    # ========================== #
    # save
    # Only the default metrics are written. The anomaly and detrended
    # variants are not computed in this cell; the disabled merge is kept
    # below for reference.
    ds_final = ds_metrics_default 
    #xr.merge([ds_metrics_default, ds_metrics_anom, ds_metrics_detrend])
    
    # mode='w' replaces any existing store at this path.
    save_name = base_dir + 'metrics.zarr'
    ds_final.to_zarr(save_name, mode='w')
    print(save_name)

/glade/derecho/scratch/ksha/EPRI_data/METRICS_STN/Pituffik/metrics.zarr
/glade/derecho/scratch/ksha/EPRI_data/METRICS_STN/Fairbanks/metrics.zarr
/glade/derecho/scratch/ksha/EPRI_data/METRICS_STN/Guam/metrics.zarr
/glade/derecho/scratch/ksha/EPRI_data/METRICS_STN/Yuma_PG/metrics.zarr
/glade/derecho/scratch/ksha/EPRI_data/METRICS_STN/Fort_Bragg/metrics.zarr


In [7]:
# Display the path left over from the final loop iteration, i.e. the store
# written for the last station processed.
save_name

'/glade/derecho/scratch/ksha/EPRI_data/METRICS_STN/Fort_Bragg/metrics.zarr'