In [70]:
# Standard library
import os
import json
import warnings
# Numerics, labeled arrays (xarray), and plotting (proplot)
import numpy as np
import xarray as xr
import proplot as pplt
# NOTE(review): this silences *every* warning for the whole kernel session
# (e.g. numpy/xarray RuntimeWarnings from all-NaN reductions); consider a
# narrower filter so real problems stay visible.
warnings.filterwarnings('ignore')
# Global proplot rc settings applied to every figure in this notebook.
pplt.rc.update({'reso':'xx-hi','figure.dpi':100})

# Tick locators, used to build symmetric x-axis ticks for the weight profiles.
import matplotlib.ticker as mticker

In [64]:
# Load project configuration (data/weights directories and the model list).
# The default is an absolute path on the original system; set the
# MONSOON_KERNELS_CONFIGS environment variable to point at another copy.
CONFIGS_PATH = os.environ.get(
    'MONSOON_KERNELS_CONFIGS',
    '/global/cfs/cdirs/m4334/sferrett/monsoon-kernels/scripts/configs.json')
with open(CONFIGS_PATH,'r',encoding='utf-8') as f:
    CONFIGS = json.load(f)
SPLITSDIR  = CONFIGS['filepaths']['splits']
WEIGHTSDIR = CONFIGS['filepaths']['weights']
MODELS     = CONFIGS['models']
# Data split whose saved kernel weights are inspected below.
SPLIT      = 'valid'

In [None]:
def _member_mean_std(weights):
    '''Reduce kernel weights over the ensemble ``member`` dimension, if present.

    Parameters
    ----------
    weights : xr.DataArray
        Saved kernel weights, optionally carrying a ``member`` dimension.

    Returns
    -------
    tuple
        ``(mean, std)`` over ``member`` (xarray's default ddof=0 std) when the
        dimension exists; otherwise ``(weights, None)`` unchanged.
    '''
    if 'member' not in weights.dims:
        return weights, None
    return weights.mean(dim='member'), weights.std(dim='member')


# Vertical coordinates of the split; `lev` is the y-axis of the profiles plotted below.
# NOTE(review): `dlev` is not referenced by the plotting cell in this notebook.
with xr.open_dataset(f'{SPLITSDIR}/{SPLIT}.h5',engine='h5netcdf') as ds:
    dlev = ds['dlev'].load()
    lev  = ds['lev'].load()

# results[name] -> dict(description, mean, std, component_means, component_stds)
results = {}
for model in MODELS:
    name = model['name']
    if 'kernel' not in name:
        continue
    description = model['description']
    filepath = os.path.join(WEIGHTSDIR,f'{name}_{SPLIT}_weights.nc')
    if not os.path.exists(filepath):
        # Best-effort: kernel models without saved weights for this split are skipped.
        continue
    with xr.open_dataset(filepath,engine='h5netcdf') as ds:
        if 'k1' in ds and 'k2' in ds:
            # Mixture kernel stored as two components.
            weights_c1 = ds['k1'].load()
            weights_c2 = ds['k2'].load()
            mean_c1,std_c1 = _member_mean_std(weights_c1)
            mean_c2,std_c2 = _member_mean_std(weights_c2)
            component_means = (mean_c1,mean_c2)
            component_stds  = (std_c1,std_c2)
            # Combined kernel taken as k1 + k2 (assumption kept from the original
            # analysis; TODO confirm no separate mixing coefficients apply). Its
            # member mean is exactly mean_c1 + mean_c2, and reducing the per-member
            # sum also yields the combined spread instead of discarding it.
            mean,std = _member_mean_std(weights_c1+weights_c2)
        else:
            # Single (non-mixture) kernel stored as 'k'; legacy files use 'weights'.
            varname = 'k' if 'k' in ds else 'weights'
            mean,std = _member_mean_std(ds[varname].load())
            component_means = None
            component_stds  = None

    results[name] = dict(description=description,mean=mean,std=std,
                         component_means=component_means,component_stds=component_stds)

print(f'Found {len(results)} kernel models with saved weights')

In [None]:
# (family key, column title) pairs; a family key is matched as a substring of model names.
kernelfamilies = [
    ('nonparametric','Nonparametric'),
    ('gaussian','Gaussian'),
    ('mixture','Mixed Gaussian'),
    ('exponential','Exponential'),
    ('bidirectional','Bidirectional Exponential'),
    ('cosine','Cosine'),
]


def _match_family(family,names,familykeys):
    '''Pick the model name for ``family``, preferring the least ambiguous match.

    A name can contain several family keys (e.g. 'bidirectional' and
    'exponential'). Plain first-match substring lookup would then depend on
    the model order. Candidates are therefore ranked by how many family keys
    they contain. Ties keep the iteration order of ``names``, so unambiguous
    names resolve exactly as a plain first-match would.

    Returns
    -------
    str or None
        The chosen name, or None when no name contains ``family``.
    '''
    candidates = [name for name in names if family in name]
    if not candidates:
        return None
    return min(candidates,key=lambda name: sum(key in name for key in familykeys))


family_to_name = {
    family: _match_family(family,results,[key for key,_ in kernelfamilies])
    for family,_ in kernelfamilies}

def _isel_field(da,i):
    '''Select field index ``i`` from ``da``, passing None through unchanged.'''
    return None if da is None else da.isel(field=i)


def _plot_profile(ax,m,s,lev,alpha,**linekw):
    '''Plot a vertical weight profile and an optional +/- one-std band.

    Parameters
    ----------
    ax : proplot axes
        Axes to draw on.
    m : xr.DataArray
        Mean weights for one field, aligned with ``lev``.
    s : xr.DataArray or None
        Member standard deviation; no band is drawn when None.
    lev : xr.DataArray
        Pressure levels used as the y coordinate.
    alpha : float
        Opacity of the band.
    **linekw
        Extra ``ax.plot`` keywords; a ``color`` entry is reused for the band.

    Returns
    -------
    float
        Largest absolute x value drawn (line and band), for symmetric limits.
    '''
    ax.plot(m,lev,linewidth=1.5,**linekw)
    extent = float(np.nanmax(np.abs(m.values)))
    if s is not None:
        fillkw = {'color':linekw['color']} if 'color' in linekw else {}
        ax.fill_betweenx(lev,m-s,m+s,alpha=alpha,**fillkw)
        extent = max(extent,
                     float(np.nanmax(np.abs((m-s).values))),
                     float(np.nanmax(np.abs((m+s).values))))
    return extent


# One column per kernel family; row i shows field index i, labelled by `rowlabels`.
fig,axs = pplt.subplots(nrows=3,ncols=len(kernelfamilies),refwidth=2,refheight=2.5,sharex=False,sharey=True)
axs.format(suptitle='Raw Kernel Weights',
           xlabel='',xticks=0.01,xminorticks='none',
           ylabel='Pressure (hPa)',yminorticks='none',yreverse=True,
           rowlabels=['RH',r'$\theta_e$',r'$\theta_e^*$'])

for j,(family,title) in enumerate(kernelfamilies):
    axs[0,j].format(title=title)
    name = family_to_name[family]
    if name is None:
        continue
    mean = results[name]['mean']
    std  = results[name]['std']
    component_means = results[name]['component_means']
    component_stds  = results[name]['component_stds']

    for i in range(3):
        ax = axs[i,j]
        if component_means is not None:
            # Mixture kernels: draw each component separately (solid / dashed).
            mean_c1,mean_c2 = component_means
            std_c1,std_c2 = component_stds if component_stds is not None else (None,None)
            xmax = max(
                _plot_profile(ax,_isel_field(mean_c1,i),_isel_field(std_c1,i),lev,0.2,
                              color='C0',label='Component 1'),
                _plot_profile(ax,_isel_field(mean_c2,i),_isel_field(std_c2,i),lev,0.2,
                              color='C1',linestyle='--',label='Component 2'))
            if i == 0:
                ax.legend(loc='upper right',ncols=1,fontsize=7)
        else:
            xmax = _plot_profile(ax,_isel_field(mean,i),_isel_field(std,i),lev,0.3)

        # All-NaN or all-zero profiles give no usable extent; keep default limits.
        if not (np.isfinite(xmax) and xmax > 0):
            continue
        xlim = 1.05*xmax
        ax.format(xlim=(-xlim,xlim))
        # Symmetric ticks about zero; nbins is fixed, so the locator needs no axis.
        ticks = mticker.MaxNLocator(nbins=5,symmetric=True).tick_values(-xlim,xlim)
        ax.format(xticks=ticks)

# Save before showing so the file is written regardless of how the backend
# handles the figure after display; create the output directory if needed.
figpath = '../figs/weights.jpeg'
os.makedirs(os.path.dirname(figpath),exist_ok=True)
fig.save(figpath,dpi=300)
pplt.show()