# Calculate RMSE

In [1]:
import os
import sys
import yaml
from glob import glob
from datetime import datetime, timedelta

import numpy as np
import xarray as xr

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
# Put the resolved '../libs/' directory (relative to the working directory) at the
# front of the module search path so that verif_utils, expected to live there,
# can be imported below.
sys.path.insert(0, os.path.realpath('../libs/'))
import verif_utils as vu

In [6]:
# Resolve the verification config to an absolute path and parse it with the
# safe YAML loader; `conf` is indexed by model / dataset name in later cells.
config_name = os.path.realpath('verif_config.yml')

with open(config_name, 'r') as config_file:
    conf = yaml.safe_load(config_file)

In [7]:
model_name = 'fuxi_physics'
lead_range = conf[model_name]['lead_range']
verif_lead_range = conf[model_name]['verif_lead_range']


def _expand_leads(bounds):
    """Expand a lead range into an explicit list of lead times.

    The first entry of `bounds` is used both as the starting lead and as the
    step; the last entry is the final lead (inclusive).
    """
    step = bounds[0]
    return list(np.arange(step, bounds[-1] + step, step))


# Lead times present in the forecast files vs. the subset that is verified
leads_exist = _expand_leads(lead_range)
leads_verif = _expand_leads(verif_lead_range)

# Positions of the verified leads within the available leads
ind_lead = vu.lead_to_index(leads_exist, leads_verif)

print('Verifying lead times: {}'.format(leads_verif))
print('Verifying lead indices: {}'.format(ind_lead))

Verifying lead times: [6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72, 78, 84, 90, 96, 102, 108, 114, 120, 126, 132, 138, 144, 150, 156, 162, 168, 174, 180, 186, 192, 198, 204, 210, 216, 222, 228, 234, 240]
Verifying lead indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39]


In [8]:
# Slice of the sorted forecast file list to verify: [verif_ind_start, verif_ind_end)
verif_ind_start = 0
verif_ind_end = 4

# Output file name encodes the file-index window, the verified lead range and the model
_verif_fname = 'combined_rmse_{:04d}_{:04d}_{:03d}h_{:03d}h_{}.nc'.format(
    verif_ind_start,
    verif_ind_end,
    verif_lead_range[0],
    verif_lead_range[-1],
    model_name,
)

# Plain string concatenation: the configured save location is used as a prefix as-is
path_verif = conf[model_name]['save_loc_verif'] + _verif_fname

## Verification setup

In [9]:
# ---------------------------------------------------------------------------------------- #
# ERA5 verif target
# The configured ERA5 location is used directly as a glob pattern
filename_ERA5 = sorted(glob(conf['ERA5_ours']['save_loc']))

# pick years
# NOTE(review): the year is matched as a substring of the whole path, so a year
# string appearing in a directory name also counts as a match -- confirm this
# is intended for the configured locations.
year_range = conf['ERA5_ours']['year_range']
years_pick = np.arange(year_range[0], year_range[1]+1, 1).astype(str)
filename_ERA5 = [fn for fn in filename_ERA5 if any(year in fn for year in years_pick)]

# merge yearly ERA5 as one
ds_ERA5 = [vu.get_forward_data(fn) for fn in filename_ERA5]
ds_ERA5_merge = xr.concat(ds_ERA5, dim='time')

# Select the specified variables and their levels
variables_levels = conf['ERA5_ours']['verif_variables']

# subset merged ERA5 and unify coord names
# ds_ERA5_merge = vu.ds_subset_everything(ds_ERA5_merge, variables_levels)

# ---------------------------------------------------------------------------------------- #
# forecast
filename_OURS = sorted(glob(conf[model_name]['save_loc_gather']+'*.nc'))

# pick years (same substring matching as for ERA5 above)
year_range = conf[model_name]['year_range']
years_pick = np.arange(year_range[0], year_range[1]+1, 1).astype(str)
filename_OURS = [fn for fn in filename_OURS if any(year in fn for year in years_pick)]
# filename_OURS = [fn for fn in filename_OURS if '00Z' in fn]

L_max = len(filename_OURS)
assert verif_ind_end <= L_max, 'verified indices (days) exceeds the max index available'

# Keep only the requested window of forecast files
filename_OURS = filename_OURS[verif_ind_start:verif_ind_end]
if not filename_OURS:
    # Fail with a descriptive message instead of a bare IndexError further down
    raise ValueError(
        'no forecast files selected for indices [{}, {})'.format(verif_ind_start, verif_ind_end)
    )

# latitude weighting
# Open the first forecast file in a context manager so its handle is released
# right away; .load() keeps the latitude values usable after the file closes.
with xr.open_dataset(filename_OURS[0]) as ds_first:
    lat = ds_first["latitude"].load()

# cos(latitude) weights, normalised to have mean 1 over the latitude coordinate
w_lat = np.cos(np.deg2rad(lat))
w_lat = w_lat / w_lat.mean()

In [8]:
# ---------------------------------------------------------------------------------------- #
# RMSE compute
# One RMSE dataset is collected per forecast file and the results are stacked
# along a new 'days' dimension after the loop.
verif_results = []

for fn_ours in filename_OURS:
    # Open the forecast file in a context manager so its handle is released at
    # the end of each iteration; .compute() brings the selected lead times into
    # memory before the file is closed.
    with xr.open_dataset(fn_ours) as ds_file:
        #ds_file = vu.ds_subset_everything(ds_file, variables_levels)
        ds_ours = ds_file.isel(time=ind_lead).compute()

    # ERA5 target at the same 'time' values as the forecast
    ds_target = ds_ERA5_merge.sel(time=ds_ours['time']).compute()
    # Replace the forecast 'level' coordinate with the target's so both datasets
    # carry identical level labels when they are subtracted
    ds_ours['level'] = ds_target['level']

    # RMSE with latitude-based cosine weighting (check w_lat)
    RMSE = np.sqrt(
        (w_lat * (ds_ours - ds_target)**2).mean(['latitude', 'longitude'])
    )

    # Drop the 'time' coordinate variable so results from different files can be
    # concatenated along 'days'
    verif_results.append(RMSE.drop_vars('time'))

    #print('Completed: {}'.format(fn_ours))

# Combine verif results
ds_verif = xr.concat(verif_results, dim='days')

# # Save the combined dataset
# print('Save to {}'.format(path_verif))
# #ds_verif.to_netcdf(path_verif)

In [9]:
# tp = ds_ours['total_precipitation'].values
# tp_target = ds_target['total_precipitation'].values

In [10]:
# Total-precipitation RMSE as a plain array, averaged over the leading axis
# (the 'days' dimension introduced when the per-file results were concatenated)
RMSE_tp = ds_verif['total_precipitation'].values
np.mean(RMSE_tp, axis=0)

array([0.00108566, 0.00117439, 0.00124403, 0.00147704, 0.00149676,
       0.00169668, 0.00161795, 0.00181048, 0.00173478, 0.00183029,
       0.00179609, 0.00185533, 0.00178906, 0.00183535, 0.00183919,
       0.00184217, 0.00182637, 0.0019142 , 0.00185248, 0.00200094,
       0.00188207, 0.00204301, 0.00190229, 0.0020946 , 0.00193883,
       0.00211515, 0.0020048 , 0.00216498, 0.00205543, 0.00221879,
       0.00208395, 0.0022174 , 0.0020997 , 0.00229208, 0.00212421,
       0.00226463, 0.00215945, 0.0023003 , 0.00215761, 0.00234431,
       0.00226485, 0.00236864, 0.00230108, 0.002387  , 0.00231282,
       0.00247382, 0.00240186, 0.00249212, 0.002368  , 0.00248036,
       0.0023573 , 0.00251683, 0.00236785, 0.00244664, 0.00236935,
       0.00248309, 0.0023997 , 0.00253602, 0.00249012, 0.00257913,
       0.00253316, 0.00257705, 0.00254187, 0.00259573, 0.00262199,
       0.00265515, 0.00263887, 0.00265584, 0.00266816, 0.00269479,
       0.00266057, 0.00269963, 0.00273437, 0.00276045, 0.00275

In [11]:
# RMSE of 'Z' at level index 21, averaged over the leading ('days') axis.
# NOTE(review): the variable name suggests index 21 is the 500 hPa level -- confirm
# against the dataset's 'level' coordinate.
RMSE_Z500 = ds_verif['Z'].isel(level=21).values
np.mean(RMSE_Z500, axis=0)

array([  46.755135,   66.68367 ,   81.49125 ,  101.2041  ,  113.86123 ,
        132.82985 ,  146.0913  ,  164.80855 ,  179.7341  ,  200.20848 ,
        214.91391 ,  236.8743  ,  255.96068 ,  282.3226  ,  305.16238 ,
        334.54355 ,  359.03735 ,  386.3771  ,  408.479   ,  432.4906  ,
        451.5495  ,  470.99594 ,  487.3045  ,  506.58298 ,  522.7719  ,
        540.4766  ,  554.9831  ,  568.5557  ,  578.5769  ,  589.27576 ,
        599.3127  ,  613.79016 ,  627.5905  ,  644.2897  ,  661.7292  ,
        682.3252  ,  704.3015  ,  723.5621  ,  741.38416 ,  755.0564  ,
        764.9993  ,  770.1166  ,  774.11566 ,  778.45557 ,  780.6873  ,
        784.2377  ,  786.92896 ,  794.9525  ,  803.8307  ,  817.00476 ,
        829.4291  ,  845.7548  ,  859.5417  ,  878.0012  ,  894.0211  ,
        915.6959  ,  932.53955 ,  953.66736 ,  972.3029  ,  994.43823 ,
       1013.51935 , 1036.9778  , 1059.0139  , 1084.0133  , 1101.3414  ,
       1118.8447  , 1132.0436  , 1146.1436  , 1155.6957  , 1166.

In [12]:
# RMSE of 'T' at level index 21, averaged over the leading ('days') axis.
# NOTE(review): the variable name suggests index 21 is the 500 hPa level -- confirm
# against the dataset's 'level' coordinate.
RMSE_T500 = ds_verif['T'].isel(level=21).values
np.mean(RMSE_T500, axis=0)

array([0.37150496, 0.46372545, 0.5121753 , 0.58744705, 0.6403605 ,
       0.7233365 , 0.7928848 , 0.87714374, 0.9537309 , 1.058413  ,
       1.1533767 , 1.2746701 , 1.3706552 , 1.4897425 , 1.5767485 ,
       1.6824512 , 1.7721217 , 1.8679146 , 1.9413195 , 2.0092852 ,
       2.0705462 , 2.1117847 , 2.1675844 , 2.2443829 , 2.332168  ,
       2.432778  , 2.5427766 , 2.6469738 , 2.7453911 , 2.8208742 ,
       2.8928537 , 2.9529312 , 3.0141037 , 3.0746179 , 3.1402302 ,
       3.206091  , 3.282396  , 3.3333354 , 3.3755972 , 3.391325  ,
       3.4072065 , 3.393458  , 3.3960528 , 3.3854766 , 3.4002023 ,
       3.42171   , 3.4540842 , 3.4890072 , 3.5292947 , 3.5827444 ,
       3.630941  , 3.6731644 , 3.711898  , 3.767492  , 3.8324466 ,
       3.8956347 , 3.9736447 , 4.0525246 , 4.1177077 , 4.1585307 ,
       4.2033453 , 4.265438  , 4.3597355 , 4.4572945 , 4.559969  ,
       4.625142  , 4.6771317 , 4.6967993 , 4.704012  , 4.6897907 ,
       4.66107   , 4.6375513 , 4.627992  , 4.618984  , 4.60809

In [13]:
# RMSE of 'specific_total_water' at level index 36, averaged over the leading
# ('days') axis.
# NOTE(review): the variable name suggests index 36 is the 800 hPa level -- confirm
# against the dataset's 'level' coordinate.
RMSE_Q800 = ds_verif['specific_total_water'].isel(level=36).values
np.mean(RMSE_Q800, axis=0)

array([0.00028761, 0.00035484, 0.0003927 , 0.00042877, 0.00045652,
       0.00049614, 0.0005257 , 0.00055728, 0.00058514, 0.0006201 ,
       0.00065256, 0.00067799, 0.00070347, 0.00072426, 0.00074979,
       0.00076781, 0.00080479, 0.00082675, 0.0008618 , 0.0008798 ,
       0.000907  , 0.00092002, 0.00093861, 0.00096171, 0.00098981,
       0.00101572, 0.00103988, 0.0010651 , 0.00108993, 0.0011178 ,
       0.00115328, 0.00117673, 0.00120026, 0.00121786, 0.00125116,
       0.00126977, 0.00130146, 0.00131847, 0.00133979, 0.00135691,
       0.00139177, 0.00141544, 0.00145336, 0.00148518, 0.00153151,
       0.00155485, 0.0015856 , 0.00160872, 0.00163659, 0.00166033,
       0.0017112 , 0.00174258, 0.0017801 , 0.00180401, 0.00185714,
       0.00186309, 0.0018871 , 0.00188131, 0.00189975, 0.00186359,
       0.00186357, 0.00182899, 0.00183896, 0.00182073, 0.00184024,
       0.00184122, 0.00187849, 0.00189816, 0.0019318 , 0.0019536 ,
       0.00198697, 0.00200721, 0.00202996, 0.00204714, 0.00206