# Combining verification results on individual days to a single file

In [None]:
import os
import sys
import yaml
import argparse
from glob import glob
from datetime import datetime, timedelta

import numpy as np
import xarray as xr

In [None]:
# Put the sibling ../libs directory at the front of the module search path
# so that the project helper module can be imported from this notebook.
libs_dir = os.path.realpath('../libs/')
sys.path.insert(0, libs_dir)
import verif_utils as vu

In [None]:
# Resolve the YAML configuration file relative to the working directory and
# parse it into `conf`, which the cells below index by model name.
config_name = os.path.realpath('verif_config.yml')

with open(config_name, 'r') as config_file:
    conf = yaml.safe_load(config_file)

In [None]:
# models whose verification results are merged
model_names = ['wxformer', 'IFS']

# the dict that collects all results for this notebook only;
# the merged datasets are also saved as netCDF
VERIF = {}

# file name indices
IND_max = 2192  # the ind of the last day

# The qsub script creates one file per 40 days, so the chunk edges are
# 0, 40, 80, ... with the final edge set to IND_max.
INDs = np.append(np.arange(0, IND_max, 40), IND_max)

In [None]:
# For every model: load the per-chunk ACC and RMSE verification files,
# concatenate them along 'days', keep the merged datasets in VERIF for the
# checks below, and write each merged dataset to a single netCDF file.
for model_name in model_names:
    # first and last entries of the lead range; both appear in every file name
    verif_lead_range = conf[model_name]['verif_lead_range']
    lead_first, lead_last = verif_lead_range[0], verif_lead_range[-1]

    # file names to load
    path_ACC_verif = conf[model_name]['save_loc_verif']+'combined_acc_{:04d}_{:04d}_{:03d}h_{:03d}h_{}.nc'
    path_RMSE_verif = conf[model_name]['save_loc_verif']+'combined_rmse_{:04d}_{:04d}_{:03d}h_{:03d}h_{}.nc'

    # file names to save
    path_ACC_save = conf[model_name]['save_loc_verif']+'ACC_{:03d}h_{:03d}h_{}.nc'
    path_RMSE_save = conf[model_name]['save_loc_verif']+'RMSE_{:03d}h_{:03d}h_{}.nc'

    # load xarray.Dataset and merge all verified days
    ACC_verif = []
    RMSE_verif = []

    # consecutive entries of INDs are the start / end indices of one chunk file
    for ind_start, ind_end in zip(INDs[:-1], INDs[1:]):
        filename = path_ACC_verif.format(ind_start, ind_end, lead_first, lead_last, model_name)
        # read the values into memory and close the file right away, so the
        # loop does not leave one open handle per chunk
        with xr.open_dataset(filename) as ds_verf_temp:
            ACC_verif.append(ds_verf_temp.load())

        filename = path_RMSE_verif.format(ind_start, ind_end, lead_first, lead_last, model_name)
        with xr.open_dataset(filename) as ds_verf_temp:
            RMSE_verif.append(ds_verf_temp.load())

    # merge by concat
    ds_ACC_verif = xr.concat(ACC_verif, dim='days')
    ds_RMSE_verif = xr.concat(RMSE_verif, dim='days')

    # save to one dictionary for some checking
    VERIF['{}_ACC'.format(model_name)] = ds_ACC_verif
    VERIF['{}_RMSE'.format(model_name)] = ds_RMSE_verif

    # save to nc
    save_name_ACC = path_ACC_save.format(lead_first, lead_last, model_name)
    ds_ACC_verif.to_netcdf(save_name_ACC)
    print('Save to {}'.format(save_name_ACC))

    # the RMSE file must receive the RMSE dataset, not the ACC one
    save_name_RMSE = path_RMSE_save.format(lead_first, lead_last, model_name)
    ds_RMSE_verif.to_netcdf(save_name_RMSE)
    print('Save to {}'.format(save_name_RMSE))

## Check NaNs

In [None]:
# Quick look at V500 as an example: print the mean over axis 0 of each
# model's RMSE first, then of each model's ACC. A NaN in the output means
# at least one entry along that axis is NaN.
for score_name in ['RMSE', 'ACC']:
    for checked_model in ['IFS', 'wxformer']:
        score_values = np.array(VERIF['{}_{}'.format(checked_model, score_name)]['V500'])
        print(np.mean(score_values, axis=0))

    # blank-line separation between the RMSE and ACC groups is cosmetic only

In [None]:
# If the means above contain NaN, find the indices of the affected entries
# along axis 0 for every variable of the wxformer RMSE.
# NOTE(review): axis=1 is presumably the lead-time axis, since the merged
# datasets are concatenated along 'days' — confirm against the file layout.
for var_name in ['U500', 'V500', 'T500', 'Q500', 'Z500', 't2m', 'SP']:
    # index with the loop variable so each variable is actually checked
    test = np.mean(np.array(VERIF['wxformer_RMSE'][var_name]), axis=1)
    ind_found = np.argwhere(np.isnan(test))
    print(var_name)
    print(ind_found)

# After the loop `ind_found` holds the result for the last variable in the
# list; the cell that reports the bad file reads it.

In [None]:
# # # if you see an abnormally large RMSE, find its indices; the file may be corrupted
# for var_name in ['U500', 'V500', 'T500', 'Q500', 't2m', 'SP']:
#     test = np.mean(np.array(VERIF['wxformer_RMSE']['t2m']), axis=1)
#     ind_found = np.argwhere(test>50)
#     print(ind_found)

In [None]:
# Map the first NaN index found above back to the forecast file it came from.
# forecast
model_name = 'wxformer'
filename_OURS = sorted(glob(conf[model_name]['save_loc_gather']+'*.nc'))

# pick years: keep only files whose path contains one of the configured years
year_range = conf[model_name]['year_range']
years_pick = np.arange(year_range[0], year_range[1]+1, 1).astype(str)
filename_OURS = [fn for fn in filename_OURS if any(year in fn for year in years_pick)]

# NOTE(review): this assumes the position of a day in the merged 'days'
# dimension equals the position of its file in the sorted, year-filtered
# list — confirm against the verification script.
if len(ind_found) > 0:
    # look up the index that was actually found instead of a fixed position
    ind_check = ind_found[0][0]
    print('bad file: {}'.format(filename_OURS[ind_check]))
else:
    # nothing to report; avoids an IndexError on an empty search result
    print('no NaN indices found')

## Get ready for data visualization

In [None]:
# Reduce the merged scores along axis 0 into the arrays used for plotting:
# the NaN-ignoring mean plus the 95th and 5th percentiles, for every
# variable / model / score combination. Everything goes into one dictionary.
model_names = ['wxformer', 'IFS']
varnames_plot = ['U500', 'V500', 'T500', 'Q500', 'Z500', 't2m', 'SP']

PLOT_data = {}

# (key suffix, reduction) pairs, in the order the entries are stored
reductions = [
    ('mean', lambda scores: np.nanmean(scores, axis=0)),            # mean scores
    ('95p', lambda scores: np.nanquantile(scores, 0.95, axis=0)),   # upper bound
    ('05p', lambda scores: np.nanquantile(scores, 0.05, axis=0)),   # lower bound
]

for var in varnames_plot:
    for model_name in model_names:
        # pull both scores out of the merged datasets as plain numpy arrays
        score_arrays = [
            ('RMSE', np.array(VERIF['{}_RMSE'.format(model_name)][var])),
            ('ACC', np.array(VERIF['{}_ACC'.format(model_name)][var])),
        ]

        for suffix, reduce_scores in reductions:
            for score_name, score_values in score_arrays:
                plot_key = '{}_{}_{}_{}'.format(score_name, model_name, var, suffix)
                PLOT_data[plot_key] = reduce_scores(score_values)

# Save (the dictionary is stored as a pickled object inside the .npy file)
np.save('/glade/derecho/scratch/ksha/CREDIT/verif/PLOT_data/scores_CREDIT_arXiv_2024.npy', PLOT_data)