# Zonal energy spectrum

In [1]:
import os
import sys
import yaml
from glob import glob
from datetime import datetime, timedelta

import numpy as np
import xarray as xr

In [2]:
sys.path.insert(0, os.path.realpath('../libs/'))
import verif_utils as vu
import score_utils as su

In [3]:
# Absolute path of the verification config expected in the current working directory.
config_name = os.path.realpath('verif_config.yml')

# Parse the YAML config into a plain dict (safe_load: no arbitrary object construction).
with open(config_name, 'r') as config_file:
    conf = yaml.safe_load(config_file)

In [4]:
model_name = 'fuxi_dry'
model_conf = conf[model_name]
lead_range = model_conf['lead_range']
verif_lead_range = model_conf['verif_lead_range']

# Leads stored in each forecast file: lead_step, 2*lead_step, ..., lead_range[-1].
# NOTE(review): this assumes the first available lead equals the spacing between leads.
lead_step = lead_range[0]
leads_exist = list(np.arange(lead_step, lead_range[-1] + lead_step, lead_step))

# Hand-picked leads to verify (presumably hours -- confirm against the config).
# Config-driven alternative:
#   list(np.arange(verif_lead_range[0], verif_lead_range[-1]+verif_lead_range[0], verif_lead_range[0]))
leads_verif = [24, 120, 240]
ind_lead = vu.lead_to_index(leads_exist, leads_verif)

print(f'Verifying lead times: {leads_verif}')
print(f'Verifying lead indices: {ind_lead}')

Verifying lead times: [24, 120, 240]
Verifying lead indices: [3, 19, 39]


In [5]:
# Half-open range [verif_ind_start, verif_ind_end) of positions in the sorted,
# year-filtered forecast file list that will be verified.
verif_ind_start = 0
verif_ind_end = 3

# Output path prefix; '_lead<N>.nc' is appended per lead time in the loop below.
save_loc_verif = conf[model_name]['save_loc_verif']
path_verif = f'{save_loc_verif}combined_zes_{verif_ind_start}_{verif_ind_end}_{model_name}'

## Gather forecasts

In [6]:
# ---------------------------------------------------------------------------------------- #
# forecast: all gathered forecast files, in sorted (filename) order
filename_OURS = sorted(glob(conf[model_name]['save_loc_gather']+'*.nc'))

# pick years: keep files whose *file name* contains one of the requested years.
# Match on the basename only -- every file shares the save_loc_gather directory prefix,
# so a year string appearing in that path would otherwise make the filter keep every file.
year_range = conf[model_name]['year_range']
years_pick = np.arange(year_range[0], year_range[1]+1, 1).astype(str)
filename_OURS = [
    fn for fn in filename_OURS
    if any(year in os.path.basename(fn) for year in years_pick)
]

L_max = len(filename_OURS)
# Explicit check rather than `assert` (which is stripped under `python -O`); also rejects
# an empty or reversed range, which would otherwise fail later inside xr.concat.
if not 0 <= verif_ind_start < verif_ind_end <= L_max:
    raise ValueError(
        'verification index range [{}, {}) is invalid for {} available forecast files'.format(
            verif_ind_start, verif_ind_end, L_max))

filename_OURS = filename_OURS[verif_ind_start:verif_ind_end]

In [11]:
# Variables to verify, each mapped to the pressure level(s) to keep;
# passed to vu.ds_subset_everything in the spectrum loop.
variables_levels = {
    'T': [500],
    'U': [500],
    'V': [500],
}

# Values written onto the 'level' coordinate of every forecast file before subsetting
# (presumably pressure in hPa -- confirm against the model output).
levels = np.array([
    1., 50., 150., 200., 250., 300., 400.,
    500., 600., 700., 850., 925., 1000.,
])

In [20]:
# Zonal energy spectra of U, V and potential temperature for each selected lead time.
# CP_DRY = 1005 # J/kg K
# CP_VAPOR = 1846 # J/kg K
R_OVER_CP = 287.0 / 1004  # R_d / c_p exponent of the potential-temperature formula
P_REF = 1000.0            # reference pressure, in the same units as `levels`

# theta is computed from a single T level; take that level from variables_levels so the
# formula and the subset cannot silently disagree (it used to be hard-coded as 500).
if len(variables_levels['T']) != 1:
    raise ValueError('potential temperature expects exactly one T level, got {}'.format(
        variables_levels['T']))
T_LEVEL = variables_levels['T'][0]

# Set True to write each lead's spectra to `save_name` (off by default, as before).
save_results = False
# Keep every lead's result in memory; `ds_verif` alone only holds the last lead.
zes_by_lead = {}

for ind_pick, lead in zip(ind_lead, leads_verif):
    # per-file spectra for the current lead time
    verif_results = []

    for fn_ours in filename_OURS:
        # Load only the needed slice into memory, then release the file handle.
        with xr.open_dataset(fn_ours) as ds_file:
            ds_file['level'] = levels
            ds_ours = vu.ds_subset_everything(ds_file, variables_levels)
            ds_ours = ds_ours.isel(time=ind_pick).compute()

        # -------------------------------------------------------------- #
        # potential temperature at the single T level
        ds_ours['theta'] = ds_ours['T'] * (P_REF / T_LEVEL) ** R_OVER_CP

        # -------------------------------------------------------------- #
        # Drop the first latitude row before the spectral transform
        # (NOTE(review): presumably a pole row -- confirm against the grid).
        ds_spec = ds_ours.isel(latitude=slice(1, None))
        zes_temp = [su.zonal_energy_spectrum_sph(ds_spec, var) for var in ['U', 'V', 'theta']]

        verif_results.append(xr.merge(zes_temp))

    # After isel(time=...) 'time' is a scalar coordinate; concat stacks files along it.
    ds_verif = xr.concat(verif_results, dim='time')
    zes_by_lead[lead] = ds_verif

    save_name = path_verif+'_lead{}.nc'.format(lead)
    if save_results:
        ds_verif.to_netcdf(save_name)

In [11]:
# # weatherbench2 alternative
# NOTE(review): 'Z500' is not among variables_levels (T/U/V only), so ds_ours as built
# above has no geopotential; this snippet would need Z added to the subset to run.
# from weatherbench2.derived_variables import ZonalEnergySpectrum
# calc = ZonalEnergySpectrum('Z500')
# result_fft = calc.compute(ds_ours)
# result_fft = np.mean(np.array(result_fft), axis=0)
# result_fft = result_fft[:320]