# Combining verification results from individual days into a single file

In [None]:
import os
import sys
import yaml
import argparse
from glob import glob
from datetime import datetime, timedelta

import numpy as np
import xarray as xr

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
sys.path.insert(0, os.path.realpath('../libs/'))
import verif_utils as vu

In [None]:
# Read the verification settings from the YAML file next to this notebook
config_name = os.path.realpath('verif_config.yml')
with open(config_name, 'r') as config_file:
    conf = yaml.safe_load(config_file)

## Get metrics

In [None]:
# Models whose verification results are inspected below
model_names = ['wxformer', 'IFS']
# Collects every result loaded in this notebook, keyed '<model>_<metric>'
VERIF = dict()

In [None]:
for model_name in model_names:
    model_conf = conf[model_name]
    verif_lead_range = model_conf['verif_lead_range']
    save_dir = model_conf['save_loc_verif']

    # Result files are named by metric, first/last lead hour and model name
    path_ACC_save = save_dir + 'ACC_{:03d}h_{:03d}h_{}.nc'
    path_RMSE_save = save_dir + 'RMSE_{:03d}h_{:03d}h_{}.nc'
    fmt_args = (verif_lead_range[0], verif_lead_range[-1], model_name)

    # Keep the opened datasets in one dict for the checks below
    VERIF[f'{model_name}_ACC'] = xr.open_dataset(path_ACC_save.format(*fmt_args))
    VERIF[f'{model_name}_RMSE'] = xr.open_dataset(path_RMSE_save.format(*fmt_args))

### ERA5 target

In [None]:
# ---------------------------------------------------------------------------------------- #
# ERA5 verification target
era5_conf = conf['ERA5_ours']

# Yearly ERA5 files whose names contain one of the configured years
year_range = era5_conf['year_range']
years_pick = np.arange(year_range[0], year_range[1] + 1, 1).astype(str)
filename_ERA5 = [
    fn for fn in sorted(glob(era5_conf['save_loc']))
    if any(year in fn for year in years_pick)
]

# Open each yearly file and stack them along time
ds_ERA5 = list(map(vu.get_forward_data, filename_ERA5))
ds_ERA5_merge = xr.concat(ds_ERA5, dim='time')

# Keep only the verified variables/levels and unify coordinate names to lat/lon
variables_levels = era5_conf['verif_variables']
ds_ERA5_merge = vu.ds_subset_everything(ds_ERA5_merge, variables_levels)
ds_ERA5_merge = ds_ERA5_merge.rename({'latitude': 'lat', 'longitude': 'lon'})

# 2-D lon/lat mesh from the geo file, used by the maps below
OURS_dataset = xr.open_dataset(conf['geo']['geo_file_nc'])
x_OURS = np.array(OURS_dataset['longitude'])
y_OURS = np.array(OURS_dataset['latitude'])
lon_OURS, lat_OURS = np.meshgrid(x_OURS, y_OURS)

## Check for abnormal values

In [None]:
# An abnormally large RMSE may indicate a corrupted file: find where it occurs.
# NOTE(review): axis 1 is reduced here, presumably the lead-time axis, so each
# remaining entry corresponds to one verified forecast file — confirm.
rmse_t2m_all = np.array(VERIF['wxformer_RMSE']['t2m'])
test = rmse_t2m_all.max(axis=1)
ind_found = np.argwhere(test > 10)
print(ind_found)

In [None]:
# Forecast file behind the first abnormal RMSE entry.
# NOTE(review): assumes row i of the RMSE array matches the i-th file of the
# sorted, year-filtered file list built here — confirm against the verif script.
ind_check = ind_found[0][0]
model_name = 'wxformer'
model_conf = conf[model_name]

# Restrict the gathered forecast files to the configured years
year_range = model_conf['year_range']
years_pick = np.arange(year_range[0], year_range[1] + 1, 1).astype(str)
filename_OURS = [
    fn for fn in sorted(glob(model_conf['save_loc_gather'] + '*.nc'))
    if any(year in fn for year in years_pick)
]
print('bad file: {}'.format(filename_OURS[ind_check]))

In [None]:
# t2m of the suspect forecast and of ERA5 at the same time-step index
i_step = 239
ds_ours = xr.open_dataset(filename_OURS[ind_check])
t2m_test = ds_ours['t2m'].isel(time=i_step)

# ERA5 restricted to the forecast's valid times
ds_target = ds_ERA5_merge.sel(time=ds_ours['time']).compute()
t2m_ref = ds_target['t2m'].isel(time=i_step)

In [None]:
# Forecast-minus-ERA5 error as a plain NumPy array for plotting
t2m_error = t2m_test - t2m_ref
t2m_diff = np.array(t2m_error)

In [None]:
# Map of the forecast error, titled with the suspect file's name
bad_name = os.path.basename(filename_OURS[ind_check])
plt.pcolormesh(lon_OURS, lat_OURS, t2m_diff, cmap=plt.cm.nipy_spectral)
plt.colorbar()
plt.title('Ours minus ERA5\n{}, day-10'.format(bad_name), fontsize=14)

### RMSE histograms 

In [None]:
# wxformer t2m RMSE as a NumPy array for the histograms
ds_rmse_wxformer = VERIF['wxformer_RMSE']
rmse_t2m = np.array(ds_rmse_wxformer['t2m'])

In [None]:
fig, AX = plt.subplots(2, 2)

# One normalized histogram per window of 8 consecutive lead indices
hist_panels = [
    (AX[0][0], slice(0, 8), 't2m, day 0-2'),
    (AX[0][1], slice(8, 16), 't2m, day 2-4'),
    (AX[1][0], slice(16, 24), 't2m, day 4-6'),
    (AX[1][1], slice(24, 32), 't2m, day 6-8'),
]
for ax, lead_window, title in hist_panels:
    ax.hist(rmse_t2m[:, lead_window].ravel(), bins=20, density=True)
    ax.set_title(title)

plt.tight_layout()

## Forecast climatology at different lead times

In [None]:
# Collect t2m forecasts from every other file in filename_OURS at time indices
# 5, 11, ..., 239, tagging each with an integer 'ini_time' coordinate.
# NOTE(review): the 6-step spacing presumably selects 6-hourly leads from hourly
# output (index k = lead hour k+1) — confirm against the forecast files.
ind_lead = np.arange(6, 240 + 6, 6) - 1  # loop-invariant, computed once

ds_OURS_list = []
for i_fn, fn in enumerate(filename_OURS[::2]):
    # Load THIS initialization's file; reading a fixed file here would make every
    # member identical and the lead-time climatology meaningless.
    ds_t2m = vu.get_forward_data_netCDF4(fn)['t2m'].isel(time=ind_lead)
    ds_t2m = ds_t2m.assign_coords({'ini_time': i_fn})
    ds_OURS_list.append(ds_t2m)

In [None]:
# Stack the per-initialization t2m forecasts along a new 'ini_time' dimension
ds_OURS_merge = xr.concat(objs=ds_OURS_list, dim='ini_time')

In [None]:
# Average over initializations: mean t2m forecast at each lead time
ds_OURS_mean = ds_OURS_merge.mean(dim='ini_time')
# Optional: cache the result to disk
#ds_OURS_mean.to_netcdf('/glade/derecho/scratch/ksha/CREDIT/verif/wxformer_clim_lead.nc')

In [None]:
# # ERA5 climatology info
# ERA5_path_string = conf['ERA5_weatherbench']['save_loc_clim'] + 'ERA5_clim_1990_2019_6h_interp.nc'
# ds_ERA5_clim = xr.open_dataset(ERA5_path_string)
# ds_ERA5_t2m_clim = ds_ERA5_clim['t2m'].isel(hour=0).mean(['dayofyear'])

In [None]:
fig, AX = plt.subplots(2, 2, figsize=(11, 9))

# Mean t2m forecast at four time indices, all on a shared 220-320 color scale
clim_panels = [
    (AX[0][0], 0, 't2m, day 0'),
    (AX[0][1], 12, 't2m, day 3'),
    (AX[1][0], 24, 't2m, day 6'),
    (AX[1][1], 36, 't2m, day 9'),
]
for ax, i_time, title in clim_panels:
    ax.pcolormesh(lon_OURS, lat_OURS, ds_OURS_mean.isel(time=i_time),
                  cmap=plt.cm.nipy_spectral, vmin=220, vmax=320)
    ax.set_title(title)

plt.tight_layout()