In [2]:
import glob
import json
import os
import sys
import warnings

import numpy as np
import proplot as pplt
import xarray as xr
# Silence all Python warnings for the rest of the session.
warnings.simplefilter('ignore')
# Proplot defaults: 'xx-hi' `reso` setting (presumably geographic-feature resolution) and 100-dpi figures.
pplt.rc.update({'reso':'xx-hi','figure.dpi':100})

In [3]:
# Project configuration: data/model/prediction directories and the list of model specs.
with open('/global/cfs/cdirs/m4334/sferrett/monsoon-kernels/scripts/configs.json',encoding='utf-8') as configfile:
    CONFIGS = json.load(configfile)
SPLITSDIR,MODELSDIR,PREDSDIR = (CONFIGS['filepaths'][key] for key in ('splits','models','predictions'))
MODELS = CONFIGS['models']
# Which data split to evaluate against.
SPLIT = 'valid'

In [None]:
# Ground-truth precipitation for the evaluation split, loaded eagerly so the file can be closed.
with xr.open_dataset(os.path.join(SPLITSDIR,f'{SPLIT}.h5'),engine='h5netcdf') as ds:
    truepr = ds.pr.load()

# results[name] = {'description': str, 'seeds': [{'ytrue', 'ypred', 'filepath'}, ...]}
results = {}
missing = []
for model in MODELS:
    name,description = model['name'],model['description']
    # One prediction file per training seed. glob's order is arbitrary, so sort to make
    # the seed order (and anything derived from it) reproducible across runs/filesystems.
    pattern = f'{name}_seed*_{SPLIT}_predictions.nc'
    filepaths = sorted(glob.glob(os.path.join(PREDSDIR,pattern)))
    if not filepaths:
        missing.append(name)
        continue
    seed_results = []
    for filepath in filepaths:
        with xr.open_dataset(filepath,engine='h5netcdf') as ds:
            predpr = ds.pr.load()
        # Keep only coordinates present in both truth and prediction.
        ytrue,ypred = xr.align(truepr,predpr,join='inner')
        if ytrue.size==0:
            # An empty overlap would otherwise surface later as a silent NaN R².
            raise ValueError(f'{filepath} shares no coordinates with the `{SPLIT}` truth data')
        seed_results.append(dict(ytrue=ytrue,ypred=ypred,filepath=filepath))
    results[name] = dict(description=description,seeds=seed_results)
print(f'Found {len(results)} completed models for `{SPLIT}`')
if missing:
    print(f'No `{SPLIT}` predictions found for: {", ".join(missing)}')

In [7]:
def get_r2(ytrue,ypred,dims=None):
    '''
    Purpose: Compute the coefficient of determination R² = 1 - SS_res/SS_tot of `ypred` against `ytrue`.
    Args:
    - ytrue (xr.DataArray): reference values
    - ypred (xr.DataArray): predictions, assumed already aligned with `ytrue`
    - dims (str | list[str] | None): dimension(s) to reduce over; None reduces over all of `ytrue`'s dims (default None)
    Returns:
    - xr.DataArray: R² over the remaining dimensions (0-d when reducing over every dimension)
    Notes:
    - Points where either input is NaN are excluded from BOTH sums, so SS_res and SS_tot cover the same samples.
    - A constant truth (SS_tot == 0) gives inf/NaN rather than raising.
    '''
    dims  = list(ytrue.dims) if dims is None else dims
    # One shared validity mask: without it a NaN prediction drops a point from SS_res
    # but keeps it in SS_tot (and the mean), which inflates R².
    valid = ytrue.notnull()&ypred.notnull()
    ytrue = ytrue.where(valid)
    ypred = ypred.where(valid)
    ssres = ((ytrue-ypred)**2).sum(dim=dims,skipna=True)
    sstot = ((ytrue-ytrue.mean(dim=dims,skipna=True))**2).sum(dim=dims,skipna=True)
    return 1-ssres/sstot

In [None]:
def _bar_color(runname):
    '''Bar colour by model family: nonparametric kernel -> yellow3, other kernel -> red4, anything else -> blue4.'''
    if 'kernel' not in runname:
        return 'blue4'
    return 'yellow3' if 'nonparametric' in runname else 'red4'

# One (runname, description, mean R², std R², colour) tuple per model.
barsdata = []
for runname,result in results.items():
    # R² pooled over every dimension, one value per seed.
    r2_values = [float(get_r2(seed['ytrue'],seed['ypred'],dims=None)) for seed in result['seeds']]
    r2_mean = float(np.mean(r2_values))
    # Sample std (ddof=1) across seeds; it is undefined for a single seed, so report 0 there.
    r2_std = float(np.std(r2_values,ddof=1)) if len(r2_values)>1 else 0.0
    barsdata.append((runname,result['description'],r2_mean,r2_std,_bar_color(runname)))

# Ascending mean R², so bars read worst -> best left to right.
barsdata = sorted(barsdata,key=lambda item:item[2])

labels   = [item[1] for item in barsdata]
r2s_mean = [item[2] for item in barsdata]
r2s_std  = [item[3] for item in barsdata]
colors   = [item[4] for item in barsdata]

In [None]:
# Readable split name for the axis label; the default 'valid' split keeps the original "Validation" text.
splitname = {'train':'Training','valid':'Validation','test':'Test'}.get(SPLIT,SPLIT.title())

fig,ax = pplt.subplots(nrows=1,ncols=1,refwidth=8,refheight=2)
ax.format(xlabel='Model',ylabel=f'{splitname} Set R$^2$',yticks=0.1,yminorticks=0.05,grid=False)
bars = ax.bar(labels,r2s_mean,yerr=r2s_std,color=colors,capsize=3)
# Annotate each bar just below its top with the mean R² (plus the seed std when it is nonzero).
for bar,r2_m,r2_s in zip(bars,r2s_mean,r2s_std):
    label_text = f'{r2_m:.3f}' if r2_s==0 else f'{r2_m:.3f}±{r2_s:.3f}'
    ax.text(bar.get_x()+bar.get_width()/2,bar.get_height()-0.02,label_text,ha='center',va='top',fontsize=8)

# Save before showing: some backends close the figure once it has been shown.
os.makedirs('../figs',exist_ok=True)
fig.save('../figs/bars.jpeg',dpi=300)
pplt.show()