# Calculate RMSE

In [1]:
import os
import sys
import yaml
from glob import glob
from datetime import datetime, timedelta

import numpy as np
import xarray as xr

In [2]:
# Put the project-local ../libs/ directory at the front of the module search
# path so that the verif_utils helper module can be imported below.
sys.path.insert(0, os.path.realpath('../libs/'))
import verif_utils as vu

In [3]:
# Absolute path of the verification config, parsed once into the `conf` dict.
config_name = os.path.realpath('verif_config_6h.yml')
with open(config_name, 'r') as cfg_file:
    conf = yaml.safe_load(cfg_file)

In [4]:
model_name = 'fuxi'
model_conf = conf[model_name]
lead_range = model_conf['lead_range']
verif_lead_range = model_conf['verif_lead_range']


def _expand_leads(bounds):
    """Return every lead from bounds[0] through bounds[-1], stepping by bounds[0]."""
    step = bounds[0]
    return list(np.arange(step, bounds[-1] + step, step))


# leads present in the forecast files vs. the subset that gets verified
leads_exist = _expand_leads(lead_range)
leads_verif = _expand_leads(verif_lead_range)

# positions of the verified leads within the available leads
ind_lead = vu.lead_to_index(leads_exist, leads_verif)

print('Verifying lead times: {}'.format(leads_verif))
print('Verifying lead indices: {}'.format(ind_lead))

Verifying lead times: [6, 12, 18, 24]
Verifying lead indices: [0, 1, 2, 3]


In [5]:
# Index window applied to the sorted list of forecast files.
verif_ind_start = 0
verif_ind_end = 366*2

# The output file name encodes the index window, the first and last verified
# lead, and the model name.
rmse_basename = 'combined_rmse_{:04d}_{:04d}_{:03d}h_{:03d}h_{}.nc'.format(
    verif_ind_start,
    verif_ind_end,
    verif_lead_range[0],
    verif_lead_range[-1],
    model_name,
)
path_verif = conf[model_name]['save_loc_verif'] + rmse_basename

## Verification setup

The following inputs are needed:

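(1) the range of forecast file indices to verify (`verif_ind_start`, `verif_ind_end`)

(2) the lead times to verify (`verif_lead_range` in the config)

(3) the save location of the output file (`save_loc_verif` in the config)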

In [6]:
# ---------------------------------------------------------------------------------------- #
# ERA5 verif target
era5_conf = conf['ERA5_ours']

# years to keep (inclusive range), as strings for matching against file paths
year_range = era5_conf['year_range']
years_pick = np.arange(year_range[0], year_range[1]+1, 1).astype(str)

# sorted ERA5 files whose path contains any of the picked years
filename_ERA5 = [
    fn
    for fn in sorted(glob(era5_conf['save_loc']))
    if any(year in fn for year in years_pick)
]

# load every selected file through vu.get_forward_data and join them along time
ds_ERA5 = [vu.get_forward_data(fn) for fn in filename_ERA5]
ds_ERA5_merge = xr.concat(ds_ERA5, dim='time')

# variables (and their levels) to verify
variables_levels = era5_conf['verif_variables']

# subset the merged ERA5 data, then unify coord names with the forecasts
ds_ERA5_merge = vu.ds_subset_everything(ds_ERA5_merge, variables_levels).rename(
    {'latitude': 'lat', 'longitude': 'lon'}
)

# ---------------------------------------------------------------------------------------- #
# forecast
filename_OURS = sorted(glob(conf[model_name]['save_loc_gather']+'*.nc'))

# pick years: keep files whose path contains any year of the configured range
year_range = conf[model_name]['year_range']
years_pick = np.arange(year_range[0], year_range[1]+1, 1).astype(str)
filename_OURS = [fn for fn in filename_OURS if any(year in fn for year in years_pick)]
# filename_OURS = [fn for fn in filename_OURS if '00Z' in fn]

# the requested index window must fit inside the available files
L_max = len(filename_OURS)
assert verif_ind_end <= L_max, 'verified indices (days) exceeds the max index available'

filename_OURS = filename_OURS[verif_ind_start:verif_ind_end]

# latitude weighting: cos(lat), normalised to a mean of one.
# The first forecast file is only needed for its latitude coordinate, so it is
# opened in a context manager and the values are loaded before the file closes
# (previously the file handle was left open).
with xr.open_dataset(filename_OURS[0]) as ds_first:
    lat = ds_first["lat"].load()
w_lat = np.cos(np.deg2rad(lat))
w_lat = w_lat / w_lat.mean()

In [7]:
# Some forecast files store lat/lon as masked arrays, which may not match the
# plain lat/lon arrays of the weatherbench clim. Plain coordinate arrays are
# read from the geo file here so they can be assigned onto each forecast.
# Only some CREDIT rollouts are affected, but the same step is applied to IFS.
geo_file_nc = conf['geo']['geo_file_nc']
OURS_dataset = xr.open_dataset(geo_file_nc)
y_OURS = np.array(OURS_dataset['latitude'])
x_OURS = np.array(OURS_dataset['longitude'])

In [8]:
# ---------------------------------------------------------------------------------------- #
# RMSE compute
# One RMSE dataset per forecast file; they are concatenated along 'days' below.
verif_results = []

for fn_ours in filename_OURS:
    # Open each forecast file in a context manager so its handle is released as
    # soon as the needed subset is in memory (previously every file in the
    # loop stayed open).
    with xr.open_dataset(fn_ours) as ds_file:
        ds_ours = vu.ds_subset_everything(ds_file, variables_levels)
        ds_ours = ds_ours.isel(time=ind_lead)

        # ============================== #
        # resolve the masked array issue
        ds_ours['lon'] = x_OURS
        ds_ours['lat'] = y_OURS
        # ============================== #

        # load into memory before the file is closed
        ds_ours = ds_ours.compute()

    # ERA5 target at the forecast's valid times
    ds_target = ds_ERA5_merge.sel(time=ds_ours['time']).compute()

    # ds_ours = ds_ours.isel(time=slice(1, 4))
    # ds_target = ds_target.isel(time=slice(0, 3))
    # ds_target['time'] = ds_ours['time']

    # RMSE with latitude-based cosine weighting (check w_lat)
    RMSE = np.sqrt(
        (w_lat * (ds_ours - ds_target)**2).mean(['lat', 'lon'])
    )

    # drop 'time' before the per-file results are concatenated along 'days'
    verif_results.append(RMSE.drop_vars('time'))

    #print('Completed: {}'.format(fn_ours))

# Combine verif results
ds_verif = xr.concat(verif_results, dim='days')

# Save the combined dataset
# NOTE: the write below is currently disabled; only the target path is printed.
print('Save to {}'.format(path_verif))
#ds_verif.to_netcdf(path_verif)

Save to /glade/derecho/scratch/ksha/CREDIT/VERIF/verif_6h/fuxi_6h/combined_rmse_0000_0732_006h_024h_fuxi.nc


## Old verification results (single-step, residual norm)

In [9]:
# Print one comma-separated line per variable: the RMSE averaged over axis 0
# (the concatenated 'days' dimension), one value per verified lead.
# Joining over all entries handles any number of leads; indexing rmse[0..3]
# fails with fewer than four leads and drops the rest with more.
for varname in ['SP','t2m','V500','U500','T500','Z500','Q500']:
    rmse = np.array(ds_verif[varname]).mean(axis=0)
    print(','.join('{}'.format(value) for value in rmse))

44.89194647235368,124.12228642976429,167.88251456538043,267.6491895646606
0.6021144938259202,3.033289150523198,5.52782525655867,5.95657405501232
0.8679783348843955,1.386355206606685,1.6430430881838236,2.0499758110055555
0.8462169885676298,1.3822952134342896,1.5970537931084605,2.071430997386118
0.2617489141229444,0.45643758673410567,0.6895272985499132,0.8724526957762478
24.08903710016006,84.98020261516959,90.91937449171202,152.06612912237242
0.00014816306476049917,0.0002611510997947689,0.00034045483322229754,0.00044690745538375986


In [9]:
# Print one comma-separated line per variable: the RMSE averaged over axis 0
# (the concatenated 'days' dimension), one value per verified lead.
# Joining over all entries handles any number of leads; indexing rmse[0..3]
# fails with fewer than four leads and drops the rest with more.
for varname in ['SP','t2m','V500','U500','T500','Z500','Q500']:
    rmse = np.array(ds_verif[varname]).mean(axis=0)
    print(','.join('{}'.format(value) for value in rmse))

46.33947626060378,127.01140203726739,170.1665673369648,273.10359272513153
0.6037021332255142,3.0321878637458446,5.52590204041381,5.943812610244141
0.871669062616841,1.3919048154836826,1.6507573208734623,2.0612509418034977
0.8502413675393385,1.3877147086279689,1.6082647749530479,2.086214078449288
0.26315508302165713,0.4587248051943241,0.6951465421046444,0.8810033104018331
24.892314914048402,87.31859187904014,93.21636212166541,156.0275666504848
0.0001484549276142416,0.00026144736374523976,0.00034103279487856505,0.0004475836960581638


## Old verification results (single-step, without residual norm)

In [10]:
# Show, for each variable, the RMSE averaged over axis 0 as a plain array.
verif_varnames = ['SP','t2m','V500','U500','T500','Z500','Q500']
for varname in verif_varnames:
    mean_rmse = np.mean(np.array(ds_verif[varname]), axis=0)
    print(mean_rmse)

[184.2498327  271.33226649 347.06797096 567.51448427]
[0.7371168  3.15104505 5.50106412 5.67369036]
[0.77032073 1.35613833 1.65020844 2.07473593]
[0.78886652 1.36594245 1.61001187 2.11477036]
[0.28820774 0.47826239 0.70598127 0.87292093]
[ 35.47029769 130.63376207 117.17512662 227.15196566]
[0.00013844 0.00026169 0.00034507 0.00045379]


array([ 36.00949756, 132.96644512, 120.6203009 , 229.15796213])

In [11]:
# Q500 RMSE averaged over axis 0, displayed as the cell output
np.mean(np.array(ds_verif['Q500']), axis=0)

array([0.00013877, 0.00026235, 0.00034532, 0.00045397])

In [12]:
# t2m RMSE averaged over axis 0, displayed as the cell output
np.mean(np.array(ds_verif['t2m']), axis=0)

array([0.73620854, 3.13571614, 5.46001092, 5.61953639])