In [1]:
import os
import zarr
from glob import glob

import numpy as np
import xarray as xr
import pandas as pd

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
# fn_area = '/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/static/C404_GP_grid_area.zarr'
# ds_area = xr.open_zarr(fn_area)
# area_da = ds_area['c404_area']

### CESM-SSP

In [4]:
# Variable(s) selected from each yearly LAM store by the loading cell that follows.
# NOTE(review): the "_05" suffix presumably marks a square-root-transformed field,
# since the loader squares it to build WRF_PWAT — confirm.
varname_verif = ['WRF_PWAT_05',]

In [5]:
# LAM output driven by CESM-SSP, years 2070-2074: write a domain-mean time series
# (zarr) and a time-mean map (npy) of WRF_PWAT.
ds_collection = []

for year in range(2070, 2075):
    
    fn = f'/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/opt_init_ERA5/full_output/opt_CESM_SSP_{year}_full.zarr'
    # Open one yearly store and keep only the variable(s) listed in varname_verif
    ds = xr.open_zarr(fn)[varname_verif]
    # WRF_PWAT is built by squaring the stored field
    # NOTE(review): presumably WRF_PWAT_05 holds sqrt(PWAT) — confirm against the writer
    ds['WRF_PWAT'] = ds['WRF_PWAT_05']**2
    ds = ds.drop_vars(['WRF_PWAT_05'])
    ds_collection.append(ds)

# Stitch the yearly datasets together along time
ds_cesm = xr.concat(ds_collection, dim='time')
# "tmean": average over both horizontal dims, leaving a time series
ds_cesm_tmean = ds_cesm.mean(('south_north', 'west_east'))
# Single chunk along time before writing
ds_cesm_tmean = ds_cesm_tmean.chunk({'time': -1})
# "smean": average over time, leaving a 2-D map
ds_cesm_smean = ds_cesm.mean(('time',))
save_name = '/glade/campaign/ral/hap/ksha/GWC_results/CESM_scores/LAM_SSP_PWAT_tmean.zarr'
# mode='w' overwrites any existing store at this path
ds_cesm_tmean.to_zarr(save_name, mode='w')

# .values materialises the time-mean map as a NumPy array so it can be saved as .npy
np_cesm_smean = ds_cesm_smean['WRF_PWAT'].values
save_name = '/glade/campaign/ral/hap/ksha/GWC_results/CESM_scores/LAM_SSP_PWAT_smean.npy'
np.save(save_name, np_cesm_smean)

### CESM-HIST

In [None]:
# Variable(s) selected from each yearly LAM store by the loading cell that follows.
# Same value as in the CESM-SSP section; repeated so this section reads on its own.
varname_verif = ['WRF_PWAT_05',]

In [6]:
# LAM output driven by CESM-HIST, years 1980-1984: write a domain-mean time series
# (zarr) and a time-mean map (npy) of WRF_PWAT.
ds_collection = []

for year in range(1980, 1985):

    fn = f'/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/opt_init_ERA5/full_output/opt_CESM_HIST_{year}_full.zarr'
    # Open one yearly store and keep only the variable(s) listed in varname_verif
    ds = xr.open_zarr(fn)[varname_verif]
    # WRF_PWAT is built by squaring the stored field
    # NOTE(review): presumably WRF_PWAT_05 holds sqrt(PWAT) — confirm against the writer
    ds['WRF_PWAT'] = ds['WRF_PWAT_05']**2
    ds = ds.drop_vars(['WRF_PWAT_05'])
    ds_collection.append(ds)

# Stitch the yearly datasets together along time
ds_cesm = xr.concat(ds_collection, dim='time')
# "tmean": average over both horizontal dims, leaving a time series
ds_cesm_tmean = ds_cesm.mean(('south_north', 'west_east'))
# Single chunk along time before writing
ds_cesm_tmean = ds_cesm_tmean.chunk({'time': -1})
# "smean": average over time, leaving a 2-D map
ds_cesm_smean = ds_cesm.mean(('time',))
save_name = '/glade/campaign/ral/hap/ksha/GWC_results/CESM_scores/LAM_HIST_PWAT_tmean.zarr'
# mode='w' overwrites any existing store at this path
ds_cesm_tmean.to_zarr(save_name, mode='w')

# .values materialises the time-mean map as a NumPy array so it can be saved as .npy
np_cesm_smean = ds_cesm_smean['WRF_PWAT'].values
save_name = '/glade/campaign/ral/hap/ksha/GWC_results/CESM_scores/LAM_HIST_PWAT_smean.npy'
np.save(save_name, np_cesm_smean)

In [7]:
#

### CESM-LENS-SSP

In [8]:
# Variable(s) selected from each yearly CESM_GP store by the loading cell that follows.
# NOTE(review): TMQ is presumably the CESM counterpart of WRF_PWAT (outputs are
# saved under *_PWAT_* names) — confirm.
varname_verif = ['TMQ',]

In [9]:
# CESM-LENS SSP reference, years 2070-2099: write a domain-mean time series (zarr)
# and a time-mean map (npy) of the variable in varname_verif.
# NOTE(review): the CESM-LENS-SSP diurnal/annual cycle cells use 2080-2099 instead —
# confirm the differing periods are intended.
fn_fmt = '/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/dscale_CESM_SSP_verif/CESM_GP_{}.zarr'

ds_collection = []
for year in range(2070, 2100):

    fn = fn_fmt.format(year)
    # Open one yearly store and keep only the requested variable(s)
    ds_year = xr.open_zarr(fn)[varname_verif]
    ds_collection.append(ds_year)

# Stitch the yearly datasets together along time
ds_cesm = xr.concat(ds_collection, dim='time')

# "tmean": average over both horizontal dims, leaving a time series
ds_cesm_tmean = ds_cesm.mean(('south_north', 'west_east'))
# Single chunk along time before writing
ds_cesm_tmean = ds_cesm_tmean.chunk({'time': -1})
# "smean": average over time, leaving a 2-D map
ds_cesm_smean = ds_cesm.mean(('time',))
save_name = '/glade/campaign/ral/hap/ksha/GWC_results/CESM_scores/CESM_SSP_PWAT_tmean.zarr'
# mode='w' overwrites any existing store at this path
ds_cesm_tmean.to_zarr(save_name, mode='w')

# .values materialises the time-mean map as a NumPy array so it can be saved as .npy
np_cesm_smean = ds_cesm_smean[varname_verif[0]].values
save_name = '/glade/campaign/ral/hap/ksha/GWC_results/CESM_scores/CESM_SSP_PWAT_smean.npy'
np.save(save_name, np_cesm_smean)

### CESM-LENS-HIST

In [None]:
# Variable(s) selected from each yearly CESM_GP store by the loading cell that follows.
# Same value as in the CESM-LENS-SSP section; repeated so this section reads on its own.
varname_verif = ['TMQ',]

In [10]:
# CESM-LENS HIST reference, years 1980-2009: write a domain-mean time series (zarr)
# and a time-mean map (npy) of the variable in varname_verif.
fn_fmt = '/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/dscale_CESM_HIST_verif/CESM_GP_{}.zarr'

ds_collection = []
for year in range(1980, 2010):

    fn = fn_fmt.format(year)
    # Open one yearly store and keep only the requested variable(s)
    ds_year = xr.open_zarr(fn)[varname_verif]
    ds_collection.append(ds_year)

# Stitch the yearly datasets together along time
ds_cesm = xr.concat(ds_collection, dim='time')

# "tmean": average over both horizontal dims, leaving a time series
ds_cesm_tmean = ds_cesm.mean(('south_north', 'west_east'))
# Single chunk along time before writing
ds_cesm_tmean = ds_cesm_tmean.chunk({'time': -1})
# "smean": average over time, leaving a 2-D map
ds_cesm_smean = ds_cesm.mean(('time',))
save_name = '/glade/campaign/ral/hap/ksha/GWC_results/CESM_scores/CESM_HIST_PWAT_tmean.zarr'
# mode='w' overwrites any existing store at this path
ds_cesm_tmean.to_zarr(save_name, mode='w')

# .values materialises the time-mean map as a NumPy array so it can be saved as .npy
np_cesm_smean = ds_cesm_smean[varname_verif[0]].values
save_name = '/glade/campaign/ral/hap/ksha/GWC_results/CESM_scores/CESM_HIST_PWAT_smean.npy'
np.save(save_name, np_cesm_smean)

In [11]:
#

## Diurnal cycle

In [12]:
def diurnal_cycle_stats(
    rmse_ds,
    time_dim,
    tz = None,
    q = (0.10, 0.90),
) -> xr.Dataset:
    """
    Compute diurnal (hour-of-day) stats for each numeric var in rmse_ds that has `time_dim`.

    Returns a Dataset with dims:
      - hour: 0..23 (hours with no samples are filled with NaN)
      - stat: ["mean", "std", "pXX", ...] (one entry per requested quantile)
    Other, non-time dims are preserved (e.g., vertical levels).
    The result is dask-chunked with a single chunk along `hour` and `stat`.

    Parameters
    ----------
    rmse_ds : xr.Dataset
        Dataset with a datetime-like coordinate along `time_dim`.
    time_dim : str
        Name of the time dimension (required; there is no default).
    tz : str | None
        If provided, convert times to this timezone before grouping by hour.
        Naive timestamps are treated as UTC. If None, use the dataset's time as-is.
    q : float or sequence of floats in [0, 1]
        Quantiles to compute (e.g., 0.10, 0.90). Order is preserved.

    Raises
    ------
    ValueError
        If no numeric variable has data along `time_dim`, or if any quantile
        lies outside [0, 1].
    TypeError
        If `tz` is given but the time coordinate is not a pandas DatetimeIndex
        (e.g. a cftime calendar), for which timezone conversion is undefined.
    """
    # Keep only numeric variables that depend on time
    time_vars = [
        name for name, da in rmse_ds.data_vars.items()
        if np.issubdtype(da.dtype, np.number) and (time_dim in da.dims)
    ]
    ds = rmse_ds[time_vars]

    if ds.sizes.get(time_dim, 0) == 0:
        raise ValueError(f"No data along time dim '{time_dim}' to compute diurnal stats.")

    # Normalize/validate quantiles before any reduction is set up
    q_arr = np.atleast_1d(q).astype(float)
    if np.any((q_arr < 0) | (q_arr > 1)):
        raise ValueError("All quantiles in `q` must be within [0, 1].")

    # Build the hour-of-day grouping key (optionally in local time).
    # xarray's `.dt` accessor has no tz_localize/tz_convert, so the timezone
    # conversion is done on the pandas index instead.
    t = ds[time_dim]
    if tz is None:
        hour_key = t.dt.hour
    else:
        time_index = t.to_index()
        if not isinstance(time_index, pd.DatetimeIndex):
            raise TypeError(
                f"Timezone conversion needs a pandas DatetimeIndex along '{time_dim}', "
                f"got {type(time_index).__name__}."
            )
        # Naive timestamps are interpreted as UTC; tz-aware ones are converted directly
        if time_index.tz is None:
            time_index = time_index.tz_localize("UTC")
        local_index = time_index.tz_convert(tz)
        hour_key = xr.DataArray(
            np.asarray(local_index.hour, dtype="int64"),
            dims=[time_dim],
            coords={time_dim: t.values},
            name="hour",
        )

    # The quantile reduction needs the reduced dimension in a single dask chunk on
    # some xarray versions; only touch data that is already dask-backed.
    if any(da.chunks is not None for da in ds.data_vars.values()):
        ds = ds.chunk({time_dim: -1})

    # Group by hour-of-day
    gb = ds.groupby(hour_key)

    # Mean and std over time
    mean_hour = gb.mean(dim=time_dim, skipna=True)
    std_hour  = gb.std(dim=time_dim, skipna=True)

    # Compute quantiles (Dataset with 'quantile' dim)
    q_da = xr.DataArray(q_arr, dims="quantile")
    try:
        q_hour = gb.quantile(q=q_da, dim=time_dim, skipna=True, method="linear")
    except TypeError:  # older xarray releases spell this keyword 'interpolation'
        q_hour = gb.quantile(q=q_da, dim=time_dim, skipna=True, interpolation="linear")

    # Build readable labels for the quantiles: e.g., 0.10 -> "p10", 0.025 -> "p2.5"
    def _q_label(qq: float) -> str:
        val = qq * 100.0
        if np.isclose(val, round(val)):      # integer percent
            return f"p{int(round(val))}"
        # otherwise keep one or two decimals, trimming trailing zeros
        s = f"{val:.2f}".rstrip("0").rstrip(".")
        return f"p{s}"

    q_labels = [_q_label(qq) for qq in q_arr]

    # Rename 'quantile' -> 'stat' and attach labels so we can concat with mean/std
    q_part = q_hour.rename({"quantile": "stat"}).assign_coords(stat=q_labels)

    # Put mean/std into a 'stat' dimension too
    mean_std = xr.concat(
        [mean_hour, std_hour],
        dim=xr.IndexVariable("stat", ["mean", "std"])
    )

    # Concatenate along 'stat'
    out = xr.concat([mean_std, q_part], dim="stat")

    # Ensure hours 0..23 are present/in order even if some are missing
    out = out.reindex(hour=np.arange(24))

    # Rechunk last so the reindex above cannot leave non-uniform chunks behind
    out = out.chunk({"stat": -1, "hour": -1})

    return out

### CESM-SSP

In [21]:
# Variable(s) selected from each yearly LAM store by the diurnal-cycle cell that follows.
# NOTE(review): the "_05" suffix presumably marks a square-root-transformed field,
# since the loader squares it to build WRF_PWAT — confirm.
varname_diurnal = ['WRF_PWAT_05',]

In [22]:
# Diurnal cycle of domain-mean WRF_PWAT for the CESM-SSP-driven LAM run, 2070-2074.
ds_collection = []
for year in range(2070, 2075):
    fn = f'/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/opt_init_ERA5/full_output/opt_CESM_SSP_{year}_full.zarr'
    ds = xr.open_zarr(fn)[varname_diurnal]
    # WRF_PWAT is built by squaring the stored field
    # NOTE(review): presumably WRF_PWAT_05 holds sqrt(PWAT) — confirm against the writer
    ds['WRF_PWAT'] = ds['WRF_PWAT_05']**2
    ds = ds.drop_vars(['WRF_PWAT_05'])
    # Reduce to a domain-mean time series before concatenating
    ds_year = ds.mean(('south_north', 'west_east'))
    ds_year = ds_year.chunk(dict(time=-1))
    ds_collection.append(ds_year)
    
# Rechunk again after concat so the whole record is one chunk along time
ds_CLIM = xr.concat(ds_collection, dim='time').chunk(dict(time=-1))
# Per-hour mean, std, p10 and p90
ds_diurnal = diurnal_cycle_stats(ds_CLIM, time_dim='time', q=[0.10, 0.90])

# NOTE(review): under CESM_scores/ the "CESM_SSP" prefix is used for the CESM-LENS
# data, while here it names the LAM result — confirm the naming is intended.
save_name = f'/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_SSP_PWAT_cycle.zarr'
# mode='w' overwrites any existing store at this path
ds_diurnal.to_zarr(save_name, mode='w', consolidated=True, compute=True)
print(save_name)

/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_SSP_PWAT_cycle.zarr


### CESM-HIST

In [23]:
# Variable(s) selected from each yearly LAM store by the diurnal-cycle cell that follows.
# Same value as in the CESM-SSP section; repeated so this section reads on its own.
varname_diurnal = ['WRF_PWAT_05',]

In [24]:
# Diurnal cycle of domain-mean WRF_PWAT for the CESM-HIST-driven LAM run, 1980-1984.
ds_collection = []
for year in range(1980, 1985):
    fn = f'/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/opt_init_ERA5/full_output/opt_CESM_HIST_{year}_full.zarr'
    ds = xr.open_zarr(fn)[varname_diurnal]
    # WRF_PWAT is built by squaring the stored field
    # NOTE(review): presumably WRF_PWAT_05 holds sqrt(PWAT) — confirm against the writer
    ds['WRF_PWAT'] = ds['WRF_PWAT_05']**2
    ds = ds.drop_vars(['WRF_PWAT_05'])
    # Reduce to a domain-mean time series before concatenating
    ds_year = ds.mean(('south_north', 'west_east'))
    ds_year = ds_year.chunk(dict(time=-1))
    ds_collection.append(ds_year)
    
# Rechunk again after concat so the whole record is one chunk along time
ds_CLIM = xr.concat(ds_collection, dim='time').chunk(dict(time=-1))
# Per-hour mean, std, p10 and p90
ds_diurnal = diurnal_cycle_stats(ds_CLIM, time_dim='time', q=[0.10, 0.90])

# NOTE(review): under CESM_scores/ the "CESM_HIST" prefix is used for the CESM-LENS
# data, while here it names the LAM result — confirm the naming is intended.
save_name = f'/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_HIST_PWAT_cycle.zarr'
# mode='w' overwrites any existing store at this path
ds_diurnal.to_zarr(save_name, mode='w', consolidated=True, compute=True)
print(save_name)

/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_HIST_PWAT_cycle.zarr


### CESM-LENS-SSP

In [17]:
# Variable(s) selected from each yearly CESM_GP store by the diurnal-cycle cell that follows.
# NOTE(review): TMQ is presumably the CESM counterpart of WRF_PWAT (outputs are
# saved under *_PWAT_* names) — confirm.
varname_diurnal = ['TMQ',]

In [18]:
# Diurnal cycle of the domain-mean CESM-LENS SSP variable, 2080-2099.
# NOTE(review): the CESM-LENS-SSP tmean/smean cell covers 2070-2099 — confirm the
# shorter period here is intended.
ds_collection = []
for year in range(2080, 2100):
    fn = f'/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/dscale_CESM_SSP_verif/CESM_GP_{year}.zarr'
    # Open one year, keep the requested variable(s), reduce to a domain-mean series
    ds_year = xr.open_zarr(fn)[varname_diurnal].mean(('south_north', 'west_east'))
    ds_year = ds_year.chunk(dict(time=-1))
    ds_collection.append(ds_year)
    
# Rechunk again after concat so the whole record is one chunk along time
ds_CLIM = xr.concat(ds_collection, dim='time').chunk(dict(time=-1))
# Per-hour mean, std, p10 and p90
ds_diurnal = diurnal_cycle_stats(ds_CLIM, time_dim='time', q=[0.10, 0.90])

save_name = f'/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_LENS_SSP_PWAT_cycle.zarr'
# mode='w' overwrites any existing store at this path
ds_diurnal.to_zarr(save_name, mode='w', consolidated=True, compute=True)
print(save_name)

/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_LENS_SSP_PWAT_cycle.zarr


### CESM-LENS-HIST

In [19]:
# Variable(s) selected from each yearly CESM_GP store by the diurnal-cycle cell that follows.
# Same value as in the CESM-LENS-SSP section; repeated so this section reads on its own.
varname_diurnal = ['TMQ',]

In [20]:
# Diurnal cycle of the domain-mean CESM-LENS HIST variable, 1980-2009.
ds_collection = []
for year in range(1980, 2010):
    fn = f'/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/dscale_CESM_HIST_verif/CESM_GP_{year}.zarr'
    # Open one year, keep the requested variable(s), reduce to a domain-mean series
    ds_year = xr.open_zarr(fn)[varname_diurnal].mean(('south_north', 'west_east'))
    ds_year = ds_year.chunk(dict(time=-1))
    ds_collection.append(ds_year)

# Rechunk again after concat so the whole record is one chunk along time
ds_CLIM = xr.concat(ds_collection, dim='time').chunk(dict(time=-1))
# Per-hour mean, std, p10 and p90
ds_diurnal = diurnal_cycle_stats(ds_CLIM, time_dim='time', q=[0.10, 0.90])

save_name = f'/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_LENS_HIST_PWAT_cycle.zarr'
# mode='w' overwrites any existing store at this path
ds_diurnal.to_zarr(save_name, mode='w', consolidated=True, compute=True)
print(save_name)

/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_LENS_HIST_PWAT_cycle.zarr


## Annual cycle

In [25]:
def annual_cycle_stats(
    rmse_ds, time_dim,
    tz=None, q=(0.10, 0.90),
) -> xr.Dataset:
    """
    Compute monthly (month-of-year) statistics for each numeric var in rmse_ds
    that has `time_dim`.

    Returns a Dataset with dims:
      - month: 1..12 (months with no samples are filled with NaN)
      - stat: ["mean", "std", "p10", "p90"]  (or more if you pass more quantiles)
    Other, non-time dims are preserved (e.g., bottom_top, south_north, west_east).
    The result is dask-chunked with a single chunk along `month` and `stat`.

    Parameters
    ----------
    rmse_ds : xr.Dataset
        Dataset with a datetime-like coordinate along `time_dim`.
    time_dim : str
        Name of the time dimension.
    tz : str | None
        If provided, convert times to this timezone before grouping by month.
        Naive timestamps are treated as UTC. If None, use the dataset's time as-is.
    q : float or sequence of floats in [0, 1]
        Quantiles to compute. Order is preserved.

    Raises
    ------
    ValueError
        If no numeric variable has data along `time_dim`, or if any quantile
        lies outside [0, 1].
    TypeError
        If `tz` is given but the time coordinate is not a pandas DatetimeIndex
        (e.g. a cftime calendar), for which timezone conversion is undefined.
    """
    # Keep only numeric vars that depend on time
    time_vars = [
        name for name, da in rmse_ds.data_vars.items()
        if np.issubdtype(da.dtype, np.number) and (time_dim in da.dims)
    ]
    ds = rmse_ds[time_vars]

    if ds.sizes.get(time_dim, 0) == 0:
        raise ValueError(f"No data along time dim '{time_dim}' to compute monthly stats.")

    # Normalize/validate quantiles before any reduction is set up
    q_arr = np.atleast_1d(q).astype(float)
    if np.any((q_arr < 0) | (q_arr > 1)):
        raise ValueError("All quantiles in `q` must be within [0, 1].")

    # Build the month-of-year grouping key (optionally in local time).
    # xarray's `.dt` accessor has no tz_localize/tz_convert, so the timezone
    # conversion is done on the pandas index instead.
    t = ds[time_dim]
    if tz is None:
        month_key = t.dt.month
    else:
        time_index = t.to_index()
        if not isinstance(time_index, pd.DatetimeIndex):
            raise TypeError(
                f"Timezone conversion needs a pandas DatetimeIndex along '{time_dim}', "
                f"got {type(time_index).__name__}."
            )
        # Naive timestamps are interpreted as UTC; tz-aware ones are converted directly
        if time_index.tz is None:
            time_index = time_index.tz_localize("UTC")
        local_index = time_index.tz_convert(tz)
        month_key = xr.DataArray(
            np.asarray(local_index.month, dtype="int64"),
            dims=[time_dim],
            coords={time_dim: t.values},
            name="month",
        )

    # The quantile reduction needs the reduced dimension in a single dask chunk on
    # some xarray versions; only touch data that is already dask-backed.
    if any(da.chunks is not None for da in ds.data_vars.values()):
        ds = ds.chunk({time_dim: -1})

    # Group by month-of-year
    gb = ds.groupby(month_key)

    # Mean & std over time
    mean_m = gb.mean(dim=time_dim, skipna=True)
    std_m  = gb.std(dim=time_dim, skipna=True)

    # Quantiles (Dataset with a 'quantile' dim)
    q_da = xr.DataArray(q_arr, dims="quantile")
    try:
        q_m = gb.quantile(q=q_da, dim=time_dim, skipna=True, method="linear")
    except TypeError:  # older xarray releases spell this keyword 'interpolation'
        q_m = gb.quantile(q=q_da, dim=time_dim, skipna=True, interpolation="linear")

    # Build readable labels, e.g. 0.10 -> "p10", 0.025 -> "p2.5"
    def _q_label(qq: float) -> str:
        val = qq * 100.0
        if np.isclose(val, round(val)):
            return f"p{int(round(val))}"
        s = f"{val:.2f}".rstrip("0").rstrip(".")
        return f"p{s}"

    q_labels = [_q_label(qq) for qq in q_arr]

    # Put quantiles under a 'stat' coordinate
    q_part = q_m.rename({"quantile": "stat"}).assign_coords(stat=q_labels)

    # Put mean/std under 'stat' too and concatenate
    mean_std = xr.concat(
        [mean_m, std_m],
        dim=xr.IndexVariable("stat", ["mean", "std"])
    )
    out = xr.concat([mean_std, q_part], dim="stat")

    # Ensure months 1..12 in order (missing months become NaN)
    out = out.reindex(month=np.arange(1, 13))

    # (Optional) make Zarr-friendly small chunks on the new dims
    out = out.chunk({"stat": -1, "month": -1})

    return out

### CESM-SSP

In [28]:
# Variable(s) selected from each yearly LAM store by the annual-cycle cell that follows.
# NOTE(review): the "_05" suffix presumably marks a square-root-transformed field,
# since the loader squares it to build WRF_PWAT — confirm.
varname_annual = ['WRF_PWAT_05',]

In [29]:
# Annual cycle of domain-mean WRF_PWAT for the CESM-SSP-driven LAM run, 2070-2074.
ds_collection = []
for year in range(2070, 2075):
    fn = f'/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/opt_init_ERA5/full_output/opt_CESM_SSP_{year}_full.zarr'
    ds = xr.open_zarr(fn)[varname_annual]
    # WRF_PWAT is built by squaring the stored field
    # NOTE(review): presumably WRF_PWAT_05 holds sqrt(PWAT) — confirm against the writer
    ds['WRF_PWAT'] = ds['WRF_PWAT_05']**2
    ds = ds.drop_vars(['WRF_PWAT_05'])
    # Reduce to a domain-mean time series before concatenating
    ds_year = ds.mean(('south_north', 'west_east'))
    ds_year = ds_year.chunk(dict(time=-1))
    ds_collection.append(ds_year)

# Concatenation leaves one chunk per year; merge them into a single chunk along time
# (as the other cycle cells do) so the quantile reduction sees the whole record.
ds_CLIM = xr.concat(ds_collection, dim='time').chunk(dict(time=-1))

# Per-month mean, std, p10 and p90
ds_month = annual_cycle_stats(ds_CLIM, time_dim='time', q=[0.10, 0.90])
# NOTE(review): under CESM_scores/ the "CESM_SSP" prefix is used for the CESM-LENS
# data, while here it names the LAM result — confirm the naming is intended.
save_name = '/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_SSP_PWAT_annual.zarr'
# mode='w' overwrites any existing store at this path
ds_month.to_zarr(save_name, mode='w', consolidated=True, compute=True)
print(save_name)

/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_SSP_PWAT_annual.zarr


### CESM-HIST

In [30]:
# Variable(s) selected from each yearly LAM store by the annual-cycle cell that follows.
# Same value as in the CESM-SSP section; repeated so this section reads on its own.
varname_annual = ['WRF_PWAT_05',]

In [32]:
# Annual cycle of domain-mean WRF_PWAT for the CESM-HIST-driven LAM run, 1980-1984.
ds_collection = []
for year in range(1980, 1985):
    fn = f'/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/opt_init_ERA5/full_output/opt_CESM_HIST_{year}_full.zarr'
    ds = xr.open_zarr(fn)[varname_annual]
    # WRF_PWAT is built by squaring the stored field
    # NOTE(review): presumably WRF_PWAT_05 holds sqrt(PWAT) — confirm against the writer
    ds['WRF_PWAT'] = ds['WRF_PWAT_05']**2
    ds = ds.drop_vars(['WRF_PWAT_05'])
    # Reduce to a domain-mean time series before concatenating
    ds_year = ds.mean(('south_north', 'west_east'))
    ds_year = ds_year.chunk(dict(time=-1))
    ds_collection.append(ds_year)

# Concatenation leaves one chunk per year; merge them into a single chunk along time
# (as the other cycle cells do) so the quantile reduction sees the whole record.
ds_CLIM = xr.concat(ds_collection, dim='time').chunk(dict(time=-1))
# Per-month mean, std, p10 and p90
ds_month = annual_cycle_stats(ds_CLIM, time_dim='time', q=[0.10, 0.90])

# NOTE(review): under CESM_scores/ the "CESM_HIST" prefix is used for the CESM-LENS
# data, while here it names the LAM result — confirm the naming is intended.
save_name = '/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_HIST_PWAT_annual.zarr'
# mode='w' overwrites any existing store at this path
ds_month.to_zarr(save_name, mode='w', consolidated=True, compute=True)
print(save_name)

/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_HIST_PWAT_annual.zarr


### CESM-LENS-SSP

In [36]:
# Variable(s) selected from each yearly CESM_GP store by the annual-cycle cell that follows.
# NOTE(review): TMQ is presumably the CESM counterpart of WRF_PWAT (outputs are
# saved under *_PWAT_* names) — confirm.
varname_annual = ['TMQ',]

In [40]:
# Annual cycle of the domain-mean CESM-LENS SSP variable, 2080-2099.
# NOTE(review): the CESM-LENS-SSP tmean/smean cell covers 2070-2099 — confirm the
# shorter period here is intended.
ds_collection = []
for year in range(2080, 2100):
    fn = f'/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/dscale_CESM_SSP_verif/CESM_GP_{year}.zarr'
    # Open one year, keep the requested variable(s), reduce to a domain-mean series
    ds_year = xr.open_zarr(fn)[varname_annual].mean(('south_north', 'west_east'))
    ds_year = ds_year.chunk(dict(time=-1))
    ds_collection.append(ds_year)
    
# Rechunk again after concat so the whole record is one chunk along time
ds_CLIM = xr.concat(ds_collection, dim='time').chunk(dict(time=-1))
# Per-month mean, std, p10 and p90
ds_month = annual_cycle_stats(ds_CLIM, time_dim='time', q=[0.10, 0.90])

save_name = f'/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_LENS_SSP_PWAT_annual.zarr'
# mode='w' overwrites any existing store at this path
ds_month.to_zarr(save_name, mode='w', consolidated=True, compute=True)
print(save_name)

/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_LENS_SSP_PWAT_annual.zarr


### CESM-LENS-HIST

In [41]:
# Variable(s) selected from each yearly CESM_GP store by the annual-cycle cell that follows.
# Same value as in the CESM-LENS-SSP section; repeated so this section reads on its own.
varname_annual = ['TMQ',]

In [42]:
# Annual cycle of the domain-mean CESM-LENS HIST variable, 1980-2009.
ds_collection = []
for year in range(1980, 2010):
    fn = f'/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/dscale_CESM_HIST_verif/CESM_GP_{year}.zarr'
    # Open one year, keep the requested variable(s), reduce to a domain-mean series
    ds_year = xr.open_zarr(fn)[varname_annual].mean(('south_north', 'west_east'))
    ds_year = ds_year.chunk(dict(time=-1))
    ds_collection.append(ds_year)
    
# Rechunk again after concat so the whole record is one chunk along time
ds_CLIM = xr.concat(ds_collection, dim='time').chunk(dict(time=-1))
# Per-month mean, std, p10 and p90
ds_month = annual_cycle_stats(ds_CLIM, time_dim='time', q=[0.10, 0.90])

save_name = f'/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_LENS_HIST_PWAT_annual.zarr'
# mode='w' overwrites any existing store at this path
ds_month.to_zarr(save_name, mode='w', consolidated=True, compute=True)
print(save_name)

/glade/campaign/ral/hap/ksha/GWC_results/clim/CESM_LENS_HIST_PWAT_annual.zarr
