# Calculate the anomaly correlation coefficient (ACC)

In [1]:
import os
import sys
import yaml
from glob import glob
from datetime import datetime, timedelta

import numpy as np
import xarray as xr

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
sys.path.insert(0, os.path.realpath('../libs/'))
import verif_utils as vu

In [4]:
# Resolve the verification config (relative to the working directory) to an
# absolute path, then parse it with the safe YAML loader.
config_name = os.path.realpath('verif_config.yml')

with open(config_name, 'r') as config_file:
    conf = yaml.safe_load(config_file)

In [5]:
def _lead_hours(bounds):
    """Expand a lead-range entry into the list of lead times it covers.

    The first element of ``bounds`` is used both as the first lead time and
    as the step; the last element is the final lead time (inclusive).
    """
    step = bounds[0]
    return list(np.arange(step, bounds[-1] + step, step))


# Model whose forecasts are verified; also the key into the config sections.
model_name = 'fuxi_dry'
lead_range = conf[model_name]['lead_range']
verif_lead_range = conf[model_name]['verif_lead_range']

# Lead times available in the forecast files vs. the ones to be verified,
# and the positions of the latter within the former.
leads_exist = _lead_hours(lead_range)
leads_verif = _lead_hours(verif_lead_range)
ind_lead = vu.lead_to_index(leads_exist, leads_verif)

print('Verifying lead times: {}'.format(leads_verif))
print('Verifying lead indices: {}'.format(ind_lead))

Verifying lead times: [6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72, 78, 84, 90, 96, 102, 108, 114, 120, 126, 132, 138, 144, 150, 156, 162, 168, 174, 180, 186, 192, 198, 204, 210, 216, 222, 228, 234, 240, 246, 252, 258, 264, 270, 276, 282, 288, 294, 300, 306, 312, 318, 324, 330, 336, 342, 348, 354, 360]
Verifying lead indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59]


In [6]:
# Start (inclusive) and end (exclusive) positions used to slice the sorted
# list of forecast files that get verified.
verif_ind_start = 0
verif_ind_end = 3

# Output file name encodes the file-index window, the verified lead-time
# bounds, and the model name.
fname_verif = 'combined_acc_{}_{}_{}h_{}h_{}.nc'.format(
    verif_ind_start,
    verif_ind_end,
    verif_lead_range[0],
    verif_lead_range[-1],
    model_name,
)
path_verif = conf[model_name]['save_loc_verif'] + fname_verif

In [7]:
# Show the resolved output path as the cell result (sanity check only).
path_verif

'/glade/campaign/cisl/aiml/ksha/CREDIT_physics/VERIF/fuxi_plevel_dry/combined_acc_0_3_6h_360h_fuxi_dry.nc'

## Functions

In [27]:
def sp_avg(DS, wlat):
    """Weighted spatial mean over the 'latitude' and 'longitude' dimensions.

    Parameters
    ----------
    DS : object exposing ``.weighted(weights)``
        Data to average (an xarray Dataset/DataArray in this notebook).
    wlat : weights
        Passed unchanged to ``DS.weighted``.

    Returns
    -------
    The result of the weighted mean over ['latitude', 'longitude'],
    computed with ``skipna=True`` so missing values are ignored.
    """
    spatial_dims = ['latitude', 'longitude']
    weighted_view = DS.weighted(wlat)
    return weighted_view.mean(spatial_dims, skipna=True)

In [10]:
# ERA5 climatology: the file lives in the directory given by the
# 'ERA5_weatherbench' section of the config (the config value is expected to
# end with a path separator, since the file name is appended directly).
clim_dir = conf['ERA5_weatherbench']['save_loc_clim']
ERA5_path_string = clim_dir + 'ERA5_clim_1990_2019_6h_1deg_interp.nc'
ds_ERA5_clim = xr.open_dataset(ERA5_path_string)

In [23]:
# Mapping from long variable names to the short names used for verification.
# It is applied to the ERA5 climatology dataset (via Dataset.rename) in the
# ACC loop so that its variables line up with the entries of `varname_verif`.
rename_IFS_to_ERA5 = {
    '10m_u_component_of_wind': 'VAR_10U',
    '10m_v_component_of_wind': 'VAR_10V',
    '2m_temperature': 'VAR_2T',
    'geopotential': 'Z',
    'mean_sea_level_pressure': 'MSL',
    'specific_humidity': 'Q',
    'surface_pressure': 'SP',
    'temperature': 'T',
    'u_component_of_wind': 'U',
    'v_component_of_wind': 'V'
}

# Variables for which ACC is computed (subset of the renamed names above).
varname_verif = ['MSL', 'Q', 'T', 'U', 'V', 'VAR_2T', 'Z']
# Values of the 'level' coordinate selected from forecast, target and
# climatology before computing anomalies.
# NOTE(review): presumably pressure levels in hPa -- confirm against the data.
level_pick = np.array([  50,  150,  200,  250,  300,  400,  500,  600,  700,  850, 925, 1000])

In [24]:
# ---------------------------------------------------------------------------------------- #
# ERA5 verif target

# years to keep, as strings so they can be matched against file paths
year_range = conf['ERA5_ours']['year_range']
years_pick = np.arange(year_range[0], year_range[1]+1, 1).astype(str)

# all files matching the configured glob pattern, in sorted order;
# keep a file when any selected year appears somewhere in its path
era5_candidates = sorted(glob(conf['ERA5_ours']['save_loc']))
filename_ERA5 = [
    path for path in era5_candidates
    if any(yr in path for yr in years_pick)
]

# open each yearly file and concatenate them along 'time'
ds_ERA5 = [vu.get_forward_data(path) for path in filename_ERA5]
ds_ERA5_merge = xr.concat(ds_ERA5, dim='time')

In [25]:
# ---------------------------------------------------------------------------------------- #
# forecast
# All gathered forecast files of the selected model, in sorted order.
filename_OURS = sorted(glob(conf[model_name]['save_loc_gather']+'*.nc'))

# pick years: keep a file when any selected year (as a string) appears in its path
year_range = conf[model_name]['year_range']
years_pick = np.arange(year_range[0], year_range[1]+1, 1).astype(str)
filename_OURS = [fn for fn in filename_OURS if any(year in fn for year in years_pick)]

L_max = len(filename_OURS)
assert verif_ind_end <= L_max, 'verified indices (days) exceeds the max index available'

# Restrict to the [verif_ind_start, verif_ind_end) window of files.
filename_OURS = filename_OURS[verif_ind_start:verif_ind_end]

# Fail with a clear message instead of an IndexError on filename_OURS[0] below
# when the window (or the year filter) leaves nothing to verify.
if not filename_OURS:
    raise ValueError(
        'no forecast files selected for verif indices [{}, {})'.format(
            verif_ind_start, verif_ind_end))

# latitude weighting
# Read the latitude coordinate from the first forecast file; load it into
# memory and close the file so the handle is not left open.
with xr.open_dataset(filename_OURS[0]) as ds_first:
    lat = ds_first["latitude"].load()

# cos(latitude) weights, normalised to have a mean of one
w_lat = np.cos(np.deg2rad(lat))
w_lat = w_lat / w_lat.mean()

In [42]:
# Loop-invariant preparation: restrict the ERA5 target and the ERA5
# climatology to the verified variables once, before any data is loaded, so
# that each iteration only reads the variables it actually needs.
ds_ERA5_verif = ds_ERA5_merge[varname_verif]
ds_clim_verif = ds_ERA5_clim.rename(rename_IFS_to_ERA5)[varname_verif]

# One ACC dataset per forecast file.
acc_results = []

for fn_ours in filename_OURS:
    # Load the verified lead times into memory, then close the file so
    # handles do not accumulate over the loop.
    with xr.open_dataset(fn_ours) as ds_file:
        ds_ours = ds_file.isel(time=ind_lead).compute()

    # Keys used to look up the matching climatology entries.
    dayofyear = ds_ours['time.dayofyear']
    hourofday = ds_ours['time'].dt.hour

    # --------------------------------------------------------------- #
    # get ERA5 verification target (valid times of the forecast)
    ds_target = ds_ERA5_verif.sel(time=ds_ours['time']).compute()

    # update level coord vals based on target
    ds_ours['level'] = ds_target['level']

    # --------------------------------------------------------------- #
    # get ERA5 climatology for each forecast valid time
    ds_clim_target = ds_clim_verif.sel(dayofyear=dayofyear, hour=hourofday).compute()

    # keep only the levels that are verified
    ds_ours = ds_ours.sel(level=level_pick)
    ds_target = ds_target.sel(level=level_pick)
    ds_clim_target = ds_clim_target.sel(level=level_pick)

    # ========================================== #
    # ERA5 anomaly
    ds_anomaly_ERA5 = ds_target - ds_clim_target

    # fcst anomaly
    ds_anomaly_OURS = ds_ours - ds_clim_target

    # ========================================== #
    # anomalies --> ACC with latitude-based cosine weights (check sp_avg and w_lat)
    # numerator: weighted spatial mean of the anomaly product
    top = sp_avg(ds_anomaly_OURS*ds_anomaly_ERA5, w_lat)

    # denominator: sqrt of the product of the weighted mean squared anomalies
    bottom = np.sqrt(
        sp_avg(ds_anomaly_OURS**2, w_lat) * sp_avg(ds_anomaly_ERA5**2, w_lat))

    # Drop the 'time' coordinate values (they differ per file) so the
    # per-file results can be stacked along a new dimension below.
    acc_results.append((top/bottom).drop_vars('time'))

    print('ACC completed: {}'.format(fn_ours))

# Combine ACC results: one entry per forecast file along the new 'days' dim
ds_acc = xr.concat(acc_results, dim='days')

# Save
# NOTE(review): the write below is commented out, so nothing is written to
# disk even though the message is printed -- re-enable it to actually save.
print('Save to {}'.format(path_verif))
#ds_acc.to_netcdf(path_verif)

ACC completed: /glade/campaign/cisl/aiml/ksha/CREDIT_physics/GATHER/fuxi_plevel_dry/2020-01-01T00Z.nc
ACC completed: /glade/campaign/cisl/aiml/ksha/CREDIT_physics/GATHER/fuxi_plevel_dry/2020-01-01T12Z.nc
ACC completed: /glade/campaign/cisl/aiml/ksha/CREDIT_physics/GATHER/fuxi_plevel_dry/2020-01-02T00Z.nc
Save to /glade/campaign/cisl/aiml/ksha/CREDIT_physics/VERIF/fuxi_plevel_dry/combined_acc_0_3_6h_360h_fuxi_dry.nc


In [47]:
# plt.plot(ds_acc['Z'].values[:, :, 6].mean(axis=0))