# Calculate RMSE

In [1]:
import os
import sys
import yaml
from glob import glob
from datetime import datetime, timedelta

import numpy as np
import xarray as xr

In [2]:
# Put ../libs/ (resolved to an absolute path) at the front of the module search
# path, then import the verification helpers from it.
# NOTE(review): presumably verif_utils lives in ../libs/ -- the relative path is
# resolved against the kernel's working directory, so run this notebook from its own folder.
sys.path.insert(0, os.path.realpath('../libs/'))
import verif_utils as vu

In [3]:
# Absolute path of the verification config that sits next to this notebook
config_name = os.path.realpath('verif_config_6h.yml')

# Parse the YAML config into `conf`; safe_load only builds plain Python objects
with open(config_name, mode='r') as config_file:
    conf = yaml.safe_load(config_file)

In [4]:
# Model to verify; its lead-time settings are read from the config
model_name = 'wxformer'
lead_range = conf[model_name]['lead_range']
verif_lead_range = conf[model_name]['verif_lead_range']

# The first entry of each range is used both as the start and as the step;
# the stop is pushed one step past the last entry so the last lead is included.
step_exist = lead_range[0]
step_verif = verif_lead_range[0]
leads_exist = list(np.arange(step_exist, lead_range[-1] + step_exist, step_exist))
leads_verif = list(np.arange(step_verif, verif_lead_range[-1] + step_verif, step_verif))

# Map the leads to verify onto indices of the available leads
# (presumably positions within leads_exist -- see vu.lead_to_index)
ind_lead = vu.lead_to_index(leads_exist, leads_verif)

print('Verifying lead times: {}'.format(leads_verif))
print('Verifying lead indices: {}'.format(ind_lead))

Verifying lead times: [6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72, 78, 84, 90, 96, 102, 108, 114, 120, 126, 132, 138, 144, 150, 156, 162, 168, 174, 180, 186, 192, 198, 204, 210, 216, 222, 228, 234, 240]
Verifying lead indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39]


In [5]:
# Start/end indices used to slice the sorted list of forecast files
verif_ind_start = 360*2
verif_ind_end = 366*2

# Output file name encodes the index window, the first/last verified lead, and the model
verif_filename = 'combined_rmse_{:04d}_{:04d}_{:03d}h_{:03d}h_{}.nc'.format(
    verif_ind_start,
    verif_ind_end,
    verif_lead_range[0],
    verif_lead_range[-1],
    model_name,
)

# save_loc_verif is concatenated as-is, so it is expected to end with a path separator
path_verif = conf[model_name]['save_loc_verif'] + verif_filename

## Verification setup

The following inputs are needed:

(1) range of indices

(2) required lead times to verify

(3) save location

In [6]:
# ---------------------------------------------------------------------------------------- #
# ERA5 verif target

# years to keep; both ends of the configured range are included
year_range = conf['ERA5_ours']['year_range']
years_pick = np.arange(year_range[0], year_range[1]+1, 1).astype(str)

# sorted candidate files, kept when any picked year appears anywhere in the path string
filename_ERA5 = [
    fn
    for fn in sorted(glob(conf['ERA5_ours']['save_loc']))
    if any(year in fn for year in years_pick)
]

# variables (and their levels) to verify
variables_levels = conf['ERA5_ours']['verif_variables']

# open every selected file, join them along time, subset to the verified
# variables/levels, and rename the coords to the short lat/lon names
ds_ERA5 = list(map(vu.get_forward_data, filename_ERA5))
ds_ERA5_merge = vu.ds_subset_everything(
    xr.concat(ds_ERA5, dim='time'),
    variables_levels,
).rename({'latitude': 'lat', 'longitude': 'lon'})

# ---------------------------------------------------------------------------------------- #
# forecast
filename_OURS = sorted(glob(conf[model_name]['save_loc_gather']+'*.nc'))

# pick years (a file is kept when any picked year appears anywhere in its path string)
year_range = conf[model_name]['year_range']
years_pick = np.arange(year_range[0], year_range[1]+1, 1).astype(str)
filename_OURS = [fn for fn in filename_OURS if any(year in fn for year in years_pick)]
# filename_OURS = [fn for fn in filename_OURS if '00Z' in fn]

# the requested index window must fit inside the files that survived the year filter
L_max = len(filename_OURS)
assert verif_ind_end <= L_max, f'verified indices (days) exceeds the max index available, max is {L_max}'

filename_OURS = filename_OURS[verif_ind_start:verif_ind_end]

# latitude weighting
# read 'lat' from the first selected file inside a context manager and load it
# into memory, so the file handle is released instead of staying open
with xr.open_dataset(filename_OURS[0]) as ds_first:
    lat = ds_first["lat"].load()

# cos(latitude) weights, normalized so that their mean is 1
w_lat = np.cos(np.deg2rad(lat))
w_lat = w_lat / w_lat.mean()

In [8]:
# some of the forecast files have lat/lon as masked arrays
# this may result in a mismatch between the weatherbench clim (lat/lon arrays) and the fcst (masked arrays)
# it only happens to some CREDIT rollouts but will be applied to IFS as well
#
# Only plain numpy copies of the two coordinate arrays are needed, so the geo
# file is opened in a context manager and closed as soon as they are extracted.
with xr.open_dataset(conf['geo']['geo_file_nc']) as OURS_dataset:
    x_OURS = np.array(OURS_dataset['longitude'])
    y_OURS = np.array(OURS_dataset['latitude'])

In [9]:
# ---------------------------------------------------------------------------------------- #
# RMSE compute
# One RMSE dataset per forecast file; they are stacked along a new 'days' dim below.
verif_results = []

for fn_ours in filename_OURS:
    # Open each forecast file in a context manager so its handle is released at
    # the end of the iteration instead of accumulating one open file per forecast.
    with xr.open_dataset(fn_ours) as ds_raw:
        ds_ours = vu.ds_subset_everything(ds_raw, variables_levels)
        # keep only the lead times selected for verification
        ds_ours = ds_ours.isel(time=ind_lead)

        # ============================== #
        # resolve the masked array issue
        ds_ours['lon'] = x_OURS
        ds_ours['lat'] = y_OURS
        # ============================== #

        # pull the data into memory before the file is closed
        ds_ours = ds_ours.compute()

    # ERA5 target at exactly the forecast valid times
    ds_target = ds_ERA5_merge.sel(time=ds_ours['time']).compute()

    # RMSE with latitude-based cosine weighting (check w_lat)
    # NOTE(review): w_lat carries the 'lat' coordinate read from the first forecast
    # file, while ds_ours['lat'] was replaced with y_OURS above; xarray aligns
    # operands on coordinate values, so confirm the two latitude arrays are identical.
    RMSE = np.sqrt(
        (w_lat * (ds_ours - ds_target)**2).mean(['lat', 'lon'])
    )

    # drop the absolute valid times so results from different initializations
    # can be concatenated along 'days'
    verif_results.append(RMSE.drop_vars('time'))

# Combine verif results
ds_verif = xr.concat(verif_results, dim='days')

# Save the combined dataset
# NOTE: the write below is currently disabled, so this cell only reports the
# target path; uncomment to_netcdf to actually persist ds_verif.
print('Save to {}'.format(path_verif))
#ds_verif.to_netcdf(path_verif)

Save to /glade/derecho/scratch/ksha/CREDIT/VERIF/verif_6h/wxformer_6h/combined_rmse_0720_0732_006h_240h_wxformer.nc


### Print out some metrics

In [11]:
# for varname in ['SP','t2m','V500','U500','T500','Z500','Q500']:
#     rmse = np.array(ds_verif[varname]).mean(axis=0)
#     print(','.join(map(str, rmse)))