In [None]:
import os
import json
import warnings
import numpy as np
import xarray as xr
import proplot as pplt
# Install a blanket 'ignore' filter so no warnings are shown in later cells.
warnings.filterwarnings(action='ignore')
# Apply the proplot rc overrides one key at a time.
pplt.rc['reso'] = 'xx-hi'
pplt.rc['figure.dpi'] = 100

In [None]:
# Path of the project configuration file. The default is the original shared
# location; set the MONSOON_KERNELS_CONFIGS environment variable to point the
# notebook at a different checkout without editing this cell.
CONFIGPATH = os.environ.get(
    'MONSOON_KERNELS_CONFIGS',
    '/global/cfs/cdirs/m4334/sferrett/monsoon-kernels/scripts/configs.json',
)
with open(CONFIGPATH,'r',encoding='utf-8') as f:
    CONFIGS = json.load(f)
SPLITSDIR  = CONFIGS['filepaths']['splits']   # directory containing '{SPLIT}.h5'
WEIGHTSDIR = CONFIGS['filepaths']['weights']  # directory containing '{name}_{SPLIT}_weights.nc'
MODELS     = CONFIGS['models']                # iterable of dicts with 'name' and 'description'
SPLIT      = 'valid'                          # split whose grid and weights are loaded below

In [None]:
# Load pressure grid and level thickness weights
with xr.open_dataset(f'{SPLITSDIR}/{SPLIT}.h5',engine='h5netcdf') as ds:
    dlev = ds['dlev'].load()
    lev  = ds['lev'].load()

# Load all kernel weight files.
# `results` maps model name -> {'description', 'mean', 'std'}; `missing` collects
# kernel models whose weight file is absent so the skip is reported, not silent.
results = {}
missing = []
for model in MODELS:
    name = model['name']
    if 'kernel' not in name:
        continue
    description = model['description']
    filepath = os.path.join(WEIGHTSDIR,f'{name}_{SPLIT}_weights.nc')
    if not os.path.exists(filepath):
        missing.append(name)
        continue
    with xr.open_dataset(filepath,engine='h5netcdf') as ds:
        weights = ds['weights'].load()

    # Compute Δp-weighted kernel (shows contribution in pressure space)
    weighted = weights * dlev
    # Reduce over 'member' when that dimension exists; otherwise keep the
    # weighted kernel as-is and record that no spread is available.
    has_members = 'member' in weighted.dims
    mean = weighted.mean(dim='member') if has_members else weighted
    std  = weighted.std(dim='member') if has_members else None
    results[name] = {'description': description, 'mean': mean, 'std': std}

print(f'Found {len(results)} kernel models with saved weights')
if missing:
    print(f'Skipped {len(missing)} kernel models with no weights file: {missing}')

In [None]:
# Plot: 4x3 grid showing all models
# One panel per kernel model; each panel overlays one curve per field, with a
# ±1 std band when a spread is available. Curves are labelled so the shared
# legend at the bottom identifies the fields.
field_labels = ['RH', r'$\theta_e$', r'$\theta_e^*$']
nrows, ncols = 4, 3

fig, axs = pplt.subplots(nrows=nrows, ncols=ncols, refwidth=2, refheight=2, share=True)
axs.format(
    xlabel=r'$\Delta p$-weighted kernel',
    xlim=(-2, 2), xticks=1, xminorticks='none',
    ylabel='Pressure (hPa)', yminorticks='none', yreverse=True
)

names = list(results.keys())
npanels = nrows * ncols
if len(names) > npanels:
    # The grid is fixed at nrows x ncols; say so instead of dropping models silently.
    print(f'Showing the first {npanels} of {len(names)} kernel models')

for k, name in enumerate(names[:npanels]):
    ax = axs[k]
    mean = results[name]['mean']
    std  = results[name]['std']

    # Determine field dimension
    field_dim = 'field' if 'field' in mean.dims else next(d for d in mean.dims if d != 'lev')
    nfields = mean.sizes[field_dim]

    # Use the named labels only when the field count matches them; otherwise
    # fall back to generic names.
    if nfields == len(field_labels):
        curve_labels = field_labels
    else:
        curve_labels = [f'Field {i}' for i in range(nfields)]

    # Plot each field
    for i in range(nfields):
        field_mean = mean.isel({field_dim: i})
        ax.plot(field_mean, lev, linewidth=1, label=curve_labels[i])

        if std is not None:
            field_std = std.isel({field_dim: i})
            ax.fill_betweenx(lev, field_mean - field_std, field_mean + field_std, alpha=0.2)

    ax.format(title=f'{name}\n{results[name]["description"]}')

# Shared legend built from the labelled curves of the first panel (the
# unlabelled std bands are not returned by get_legend_handles_labels).
if names:
    legend_handles, legend_labels = axs[0].get_legend_handles_labels()
    if legend_handles:
        fig.legend(legend_handles, legend_labels, loc='b', ncols=len(legend_handles))

pplt.show()

In [None]:
# Plot: Organized 3x4 grid (fields × kernel types)
# Rows are fields, columns are kernel families; each panel shows the mean
# Δp-weighted kernel of the model picked for that family, with a ±1 std band
# when a spread is available.
kernel_families = [
    ('nonparametric', 'Nonparametric'),
    ('gaussian', 'Gaussian'),
    ('exponential', 'Exponential'),
    ('tophat', 'Top-Hat')
]


def _field_dim(da):
    """Return the dimension of `da` that indexes fields.

    This is 'field' when present, otherwise the first dimension that is not 'lev'.
    """
    return 'field' if 'field' in da.dims else next(d for d in da.dims if d != 'lev')


if not results:
    raise RuntimeError('No kernel weights were loaded; nothing to plot')

# Find one model per kernel family: prioritize vertical models, but fall back
# to the first model of the family so the column is not left blank.
picked = {}
for family, _ in kernel_families:
    candidates = [name for name in results if family in name]
    if not candidates:
        continue
    vertical = [name for name in candidates if 'vertical' in name]
    picked[family] = (vertical or candidates)[0]

print(f'Selected models: {picked}')

# Infer the number of rows from a model that will actually be drawn (first
# picked model), falling back to the first loaded model if none was picked.
ref_name = next(iter(picked.values()), next(iter(results)))
ref_mean = results[ref_name]['mean']
nfields = ref_mean.sizes[_field_dim(ref_mean)]
nkernels = len(kernel_families)

field_labels = ['RH', r'$\theta_e$', r'$\theta_e^*$'] if nfields == 3 else [f'Field {i}' for i in range(nfields)]

# Create grid: fields (rows) × kernel types (cols)
fig, axs = pplt.subplots(nrows=nfields, ncols=nkernels, refwidth=2, refheight=2.5, share=True)
axs.format(
    xlabel=r'$\Delta p$-weighted kernel',
    xlim=(-2, 2), xticks=1, xminorticks='none',
    ylabel='Pressure (hPa)', yminorticks='none', yreverse=True
)

for c, (family, fam_title) in enumerate(kernel_families):
    name = picked.get(family)
    if name is None:
        mean = std = model_dim = None
    else:
        mean = results[name]['mean']
        std  = results[name]['std']
        # Each model's own field dimension is used, not one inferred elsewhere.
        model_dim = _field_dim(mean)

    for r in range(nfields):
        ax = axs[r, c]

        # Column titles only on top row
        if r == 0:
            ax.format(title=fam_title)

        # Row labels on left column (set even when this family has no model)
        if c == 0:
            ax.format(ylabel=f'{field_labels[r]}\nPressure (hPa)')

        # Leave the panel empty when there is no model for this family or the
        # model has fewer fields than the grid has rows.
        if mean is None or r >= mean.sizes[model_dim]:
            continue

        field_mean = mean.isel({model_dim: r})
        ax.plot(field_mean, lev, linewidth=1.5, color='black')

        if std is not None:
            field_std = std.isel({model_dim: r})
            ax.fill_betweenx(lev, field_mean - field_std, field_mean + field_std,
                            alpha=0.3, color='gray6')

pplt.show()