In [1]:
import os
import zarr
from glob import glob

import numpy as np
import xarray as xr
import pandas as pd

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
# fn_area = '/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/static/C404_GP_grid_area.zarr'
# ds_area = xr.open_zarr(fn_area)
# area_da = ds_area['c404_area']

### CESM-SSP

In [4]:
# Variable subset selected from each zarr store opened by the mean-state cells below.
varname_verif = ['WRF_T2']

In [8]:
# Collect the yearly SSP stores for 2070-2074, keeping only `varname_verif`.
ds_collection = []

for year in range(2070, 2075):

    fn = f'/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/opt_init_ERA5/prog_outputs/opt_CESM_SSP_{year}.zarr'
    ds = xr.open_zarr(fn)[varname_verif]

    ds_collection.append(ds)

ds_cesm = xr.concat(ds_collection, dim='time')

# Mean over both horizontal dims -> one value per time step, stored as a single time chunk.
ds_cesm_tmean = ds_cesm.mean(('south_north', 'west_east'))
ds_cesm_tmean = ds_cesm_tmean.chunk({'time': -1})
# Mean over time -> one value per grid cell.
ds_cesm_smean = ds_cesm.mean(('time',))
save_name = '/glade/campaign/ral/hap/ksha/GWC_results/CESM_scores/LAM_SSP_T2_tmean.zarr'
# mode='w' overwrites an existing store so the cell can be re-run; the default
# mode refuses to write when the target already exists.
ds_cesm_tmean.to_zarr(save_name, mode='w')

# The time-mean field is materialised as a NumPy array and saved as .npy.
np_cesm_smean = ds_cesm_smean['WRF_T2'].values
save_name = '/glade/campaign/ral/hap/ksha/GWC_results/CESM_scores/LAM_SSP_T2_smean.npy'
np.save(save_name, np_cesm_smean)

### CESM-HIST

In [9]:
# Collect the yearly HIST stores for 1980-1984, keeping only `varname_verif`.
ds_collection = []

for year in range(1980, 1985):

    fn = f'/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/opt_init_ERA5/prog_outputs/opt_CESM_HIST_{year}.zarr'
    ds = xr.open_zarr(fn)[varname_verif]

    ds_collection.append(ds)

ds_cesm = xr.concat(ds_collection, dim='time')

# Mean over both horizontal dims -> one value per time step, stored as a single time chunk.
ds_cesm_tmean = ds_cesm.mean(('south_north', 'west_east'))
ds_cesm_tmean = ds_cesm_tmean.chunk({'time': -1})
# Mean over time -> one value per grid cell.
ds_cesm_smean = ds_cesm.mean(('time',))
save_name = '/glade/campaign/ral/hap/ksha/GWC_results/CESM_scores/LAM_HIST_T2_tmean.zarr'
# mode='w' overwrites an existing store so the cell can be re-run; the default
# mode refuses to write when the target already exists.
ds_cesm_tmean.to_zarr(save_name, mode='w')

# The time-mean field is materialised as a NumPy array and saved as .npy.
np_cesm_smean = ds_cesm_smean['WRF_T2'].values
save_name = '/glade/campaign/ral/hap/ksha/GWC_results/CESM_scores/LAM_HIST_T2_smean.npy'
np.save(save_name, np_cesm_smean)

### CESM-LENS-SSP

In [11]:
# Variable subset selected from the dscale_CESM_* stores opened by the cells below.
varname_verif = ['VAR_2T']

In [19]:
# Yearly stores under dscale_CESM_SSP; the placeholder is filled with the year.
fn_fmt = '/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/dscale_CESM_SSP/CESM_GP_{}.zarr'

# Open 2070-2099 lazily, keep only `varname_verif`, then stitch the years along time.
ds_collection = [
    xr.open_zarr(fn_fmt.format(year))[varname_verif]
    for year in range(2070, 2100)
]
ds_cesm = xr.concat(ds_collection, dim='time')

# Domain-mean series (single time chunk) and time-mean field.
ds_cesm_tmean = ds_cesm.mean(('south_north', 'west_east')).chunk({'time': -1})
ds_cesm_smean = ds_cesm.mean(('time',))

save_name = '/glade/campaign/ral/hap/ksha/GWC_results/CESM_scores/CESM_SSP_T2_tmean.zarr'
ds_cesm_tmean.to_zarr(save_name, mode='w')

# The time-mean field is materialised as a NumPy array and saved as .npy.
np_cesm_smean = ds_cesm_smean['VAR_2T'].values
save_name = '/glade/campaign/ral/hap/ksha/GWC_results/CESM_scores/CESM_SSP_T2_smean.npy'
np.save(save_name, np_cesm_smean)

### CESM-LENS-HIST

In [20]:
# Yearly stores under dscale_CESM_HIST; the placeholder is filled with the year.
fn_fmt = '/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/dscale_CESM_HIST/CESM_GP_{}.zarr'

# Open 1980-2009 lazily, keep only `varname_verif`, then stitch the years along time.
ds_collection = [
    xr.open_zarr(fn_fmt.format(year))[varname_verif]
    for year in range(1980, 2010)
]
ds_cesm = xr.concat(ds_collection, dim='time')

# Domain-mean series (single time chunk) and time-mean field.
ds_cesm_tmean = ds_cesm.mean(('south_north', 'west_east')).chunk({'time': -1})
ds_cesm_smean = ds_cesm.mean(('time',))

save_name = '/glade/campaign/ral/hap/ksha/GWC_results/CESM_scores/CESM_HIST_T2_tmean.zarr'
ds_cesm_tmean.to_zarr(save_name, mode='w')

# The time-mean field is materialised as a NumPy array and saved as .npy.
np_cesm_smean = ds_cesm_smean['VAR_2T'].values
save_name = '/glade/campaign/ral/hap/ksha/GWC_results/CESM_scores/CESM_HIST_T2_smean.npy'
np.save(save_name, np_cesm_smean)

## Diurnal cycle

In [4]:
def diurnal_cycle_stats(
    rmse_ds,
    time_dim,
    tz = None,
    q = (0.10, 0.90),
) -> xr.Dataset:
    """
    Compute diurnal (hour-of-day) stats for each numeric var in rmse_ds that has `time_dim`.
    Returns a Dataset with dims:
      - hour: 0..23
      - stat: ["mean", "std", "pXX", ...] (one entry per requested quantile)
    Other, non-time dims are preserved (e.g., vertical levels).

    Parameters
    ----------
    rmse_ds : xr.Dataset
        Dataset with a datetime-like coordinate along `time_dim`.
    time_dim : str
        Name of the time dimension (required; there is no default).
    tz : str | None
        If provided, the time coordinate is treated as UTC and converted to
        this timezone before grouping by hour. Requires a numpy datetime64
        time coordinate. If None, the dataset's time is used as-is.
    q : float or sequence of floats in [0, 1]
        Quantiles to compute (e.g., 0.10, 0.90). Order is preserved.

    Raises
    ------
    ValueError
        If no numeric variable has data along `time_dim`, or if any quantile
        lies outside [0, 1].
    TypeError
        If `tz` is given but the time coordinate is not datetime64.
    """
    # Keep only numeric variables that depend on time
    keep = [name for name, da in rmse_ds.data_vars.items()
            if np.issubdtype(da.dtype, np.number) and (time_dim in da.dims)]
    ds = rmse_ds[keep]

    if ds.sizes.get(time_dim, 0) == 0:
        raise ValueError(f"No data along time dim '{time_dim}' to compute diurnal stats.")

    # Normalize/validate quantiles up front so bad input fails before any grouping
    q_arr = np.atleast_1d(q).astype(float)
    if np.any((q_arr < 0) | (q_arr > 1)):
        raise ValueError("All quantiles in `q` must be within [0, 1].")

    # Build the hour-of-day grouping key (optionally in local time)
    t = ds[time_dim]
    if tz is None:
        hour_key = t.dt.hour
    else:
        # The xarray `.dt` accessor has no tz_localize/tz_convert, so the
        # conversion goes through pandas. datetime64 values carry no timezone
        # and are interpreted as UTC here.
        if not np.issubdtype(t.dtype, np.datetime64):
            raise TypeError(
                f"Timezone conversion needs a datetime64 '{time_dim}' coordinate, "
                f"got dtype {t.dtype}."
            )
        local_index = pd.DatetimeIndex(t.values).tz_localize("UTC").tz_convert(tz)
        # Named "hour" so the grouped dimension matches the tz=None path.
        hour_key = xr.DataArray(
            np.asarray(local_index.hour),
            dims=[time_dim],
            coords={time_dim: t.values},
            name="hour",
        )

    # Group by hour-of-day
    gb = ds.groupby(hour_key)

    # Mean and std over time
    mean_hour = gb.mean(dim=time_dim, skipna=True)
    std_hour  = gb.std(dim=time_dim, skipna=True)

    # Compute quantiles (Dataset with 'quantile' dim)
    try:
        q_hour = gb.quantile(q=xr.DataArray(q_arr, dims="quantile"),
                             dim=time_dim, skipna=True, method="linear")
    except TypeError:  # fallback for xarray releases that name this keyword 'interpolation'
        q_hour = gb.quantile(q=xr.DataArray(q_arr, dims="quantile"),
                             dim=time_dim, skipna=True, interpolation="linear")

    # Build readable labels for the quantiles: e.g., 0.10 -> "p10", 0.025 -> "p2.5"
    def _q_label(qq: float) -> str:
        val = qq * 100.0
        if np.isclose(val, round(val)):      # integer percent
            return f"p{int(round(val))}"
        # otherwise keep up to two decimals, trimming trailing zeros
        s = f"{val:.2f}".rstrip("0").rstrip(".")
        return f"p{s}"

    q_labels = [_q_label(qq) for qq in q_arr]

    # Rename 'quantile' -> 'stat' and attach labels so we can concat with mean/std
    q_part = q_hour.rename({"quantile": "stat"}).assign_coords(stat=q_labels)

    # Put mean/std into a 'stat' dimension too
    mean_std = xr.concat(
        [mean_hour, std_hour],
        dim=xr.IndexVariable("stat", ["mean", "std"])
    )

    # Concatenate along 'stat'
    out = xr.concat([mean_std, q_part], dim="stat")

    # Ensure hours 0..23 are present/in order even if some are missing
    out = out.reindex(hour=np.arange(24))

    # Rechunk last so the reindex above cannot leave non-uniform chunks behind
    out = out.chunk({"stat": -1, "hour": -1})

    return out

### CESM-SSP

In [5]:
# Variable subset used by the diurnal-cycle cells that read the opt_CESM_* stores.
varname_diurnal = ['WRF_T2']

In [12]:
# One domain-mean series per year (2070-2074), each held as a single time chunk.
ds_collection = [
    xr.open_zarr(
        f'/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/opt_init_ERA5/prog_outputs/opt_CESM_SSP_{year}.zarr'
    )[varname_diurnal].mean(('south_north', 'west_east')).chunk({'time': -1})
    for year in range(2070, 2075)
]

# Stitch the years along time, re-chunk to one block, then compute hour-of-day stats.
ds_CLIM = xr.concat(ds_collection, dim='time').chunk({'time': -1})
ds_diurnal = diurnal_cycle_stats(ds_CLIM, time_dim='time', q=[0.10, 0.90])

save_name = '/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_SSP_cycle.zarr'
ds_diurnal.to_zarr(save_name, mode='w', consolidated=True, compute=True)
print(save_name)

/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_SSP_cycle.zarr


### CESM-HIST

In [14]:
# One domain-mean series per year, each held as a single time chunk.
# The period is 1980-1984, matching the five-year window of the SSP diurnal
# cell and the other HIST cells; the previous range(1980, 1981) loaded only
# a single year.
ds_collection = []
for year in range(1980, 1985):
    fn = f'/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/opt_init_ERA5/prog_outputs/opt_CESM_HIST_{year}.zarr'
    ds_year = xr.open_zarr(fn)[varname_diurnal].mean(('south_north', 'west_east'))
    ds_year = ds_year.chunk(dict(time=-1))
    ds_collection.append(ds_year)

# Stitch the years along time, re-chunk to one block, then compute hour-of-day stats.
ds_CLIM = xr.concat(ds_collection, dim='time').chunk(dict(time=-1))
ds_diurnal = diurnal_cycle_stats(ds_CLIM, time_dim='time', q=[0.10, 0.90])

save_name = '/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_HIST_cycle.zarr'
ds_diurnal.to_zarr(save_name, mode='w', consolidated=True, compute=True)
print(save_name)

/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_HIST_cycle.zarr


### CESM-LENS-SSP

In [19]:
# Variable subset used by the diurnal-cycle cells that read the dscale_CESM_* stores.
varname_diurnal = ['VAR_2T']

In [20]:
# One domain-mean series per year (2080-2099), each held as a single time chunk.
# NOTE(review): the other dscale_CESM_* cells use 30-year windows (2070-2099 /
# 1980-2009); confirm the 20-year window here is intentional.
ds_collection = [
    xr.open_zarr(
        f'/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/dscale_CESM_SSP/CESM_GP_{year}.zarr'
    )[varname_diurnal].mean(('south_north', 'west_east')).chunk({'time': -1})
    for year in range(2080, 2100)
]

# Stitch the years along time, re-chunk to one block, then compute hour-of-day stats.
ds_CLIM = xr.concat(ds_collection, dim='time').chunk({'time': -1})
ds_diurnal = diurnal_cycle_stats(ds_CLIM, time_dim='time', q=[0.10, 0.90])

save_name = '/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_LENS_SSP_cycle.zarr'
ds_diurnal.to_zarr(save_name, mode='w', consolidated=True, compute=True)
print(save_name)

/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_LENS_SSP_cycle.zarr


### CESM-LENS-HIST

In [21]:
# One domain-mean series per year (1980-2009), each held as a single time chunk.
ds_collection = [
    xr.open_zarr(
        f'/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/dscale_CESM_HIST/CESM_GP_{year}.zarr'
    )[varname_diurnal].mean(('south_north', 'west_east')).chunk({'time': -1})
    for year in range(1980, 2010)
]

# Stitch the years along time, re-chunk to one block, then compute hour-of-day stats.
ds_CLIM = xr.concat(ds_collection, dim='time').chunk({'time': -1})
ds_diurnal = diurnal_cycle_stats(ds_CLIM, time_dim='time', q=[0.10, 0.90])

save_name = '/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_LENS_HIST_cycle.zarr'
ds_diurnal.to_zarr(save_name, mode='w', consolidated=True, compute=True)
print(save_name)

/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_LENS_HIST_cycle.zarr


## Domain-averaged TMAX/TMIN

### CESM-SSP

In [6]:
# Variables read from the *_daily.zarr stores by the cell below.
varname_pick = [
    'WRF_T2_min',
    'WRF_T2_max',
]

In [10]:
# Domain-mean of the picked variables for each year 2070-2074, one time chunk per year.
ds_collection = [
    xr.open_zarr(
        f'/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/opt_init_ERA5/prog_outputs/opt_CESM_SSP_{year}_daily.zarr'
    )[varname_pick].mean(('south_north', 'west_east')).chunk({'time': -1})
    for year in range(2070, 2075)
]

# Join the years along time as a single chunk and write the result out.
ds_CLIM = xr.concat(ds_collection, dim='time').chunk({'time': -1})
save_name = '/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_SSP_daily_clean.zarr'
ds_CLIM.to_zarr(save_name, mode='w', consolidated=True, compute=True)
print(save_name)

### CESM-HIST

In [None]:
# Variables read from the *_daily.zarr stores by the cell below.
varname_pick = [
    'WRF_T2_min',
    'WRF_T2_max',
]

In [11]:
# Domain-mean of the picked variables for each year 1980-1984, one time chunk per year.
ds_collection = [
    xr.open_zarr(
        f'/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/opt_init_ERA5/prog_outputs/opt_CESM_HIST_{year}_daily.zarr'
    )[varname_pick].mean(('south_north', 'west_east')).chunk({'time': -1})
    for year in range(1980, 1985)
]

# Join the years along time as a single chunk and write the result out.
ds_CLIM = xr.concat(ds_collection, dim='time').chunk({'time': -1})
save_name = '/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_HIST_daily_clean.zarr'
ds_CLIM.to_zarr(save_name, mode='w', consolidated=True, compute=True)
print(save_name)

/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_HIST_daily_clean.zarr


### CESM-LENS2-SSP

In [12]:
# Variables read from the CESM_LENS2_*_daily.zarr stores by the cell below.
varname_pick = [
    "VAR_2T_min",
    "VAR_2T_max",
]

In [13]:
# Domain-mean of the picked variables for each year 2070-2074, one time chunk per year.
ds_collection = [
    xr.open_zarr(
        f'/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/opt_init_ERA5/prog_outputs/CESM_LENS2_SSP_{year}_daily.zarr'
    )[varname_pick].mean(('south_north', 'west_east')).chunk({'time': -1})
    for year in range(2070, 2075)
]

# Join the years along time as a single chunk and write the result out.
ds_CLIM = xr.concat(ds_collection, dim='time').chunk({'time': -1})
save_name = '/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_LENS_SSP_daily_clean.zarr'
ds_CLIM.to_zarr(save_name, mode='w', consolidated=True, compute=True)
print(save_name)

/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_LENS_SSP_daily_clean.zarr


### CESM-LENS2-HIST

In [14]:
# Variables read from the CESM_LENS2_*_daily.zarr stores by the cell below.
varname_pick = [
    "VAR_2T_min",
    "VAR_2T_max",
]

In [15]:
# Domain-mean of the picked variables for each year 1980-1984, one time chunk per year.
ds_collection = [
    xr.open_zarr(
        f'/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/opt_init_ERA5/prog_outputs/CESM_LENS2_HIST_{year}_daily.zarr'
    )[varname_pick].mean(('south_north', 'west_east')).chunk({'time': -1})
    for year in range(1980, 1985)
]

# Join the years along time as a single chunk and write the result out.
ds_CLIM = xr.concat(ds_collection, dim='time').chunk({'time': -1})
save_name = '/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_LENS_HIST_daily_clean.zarr'
ds_CLIM.to_zarr(save_name, mode='w', consolidated=True, compute=True)
print(save_name)

/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_LENS_HIST_daily_clean.zarr
