# Create residual norm files for CREDIT

In [None]:
import os
import yaml
import copy
import numpy as np
import xarray as xr

In [None]:
from scipy.stats import gmean

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

## File creation

### 6 hourly mean std files

In [None]:
# Load the data_preprocessing config; `conf` provides the variable names
# and the file locations read by the cells below
config_name = os.path.realpath('data_config_6h.yml')

with open(config_name, mode='r') as yaml_file:
    conf = yaml.safe_load(yaml_file)

In [None]:
# number of vertical levels per variable in varname_upper; it sizes the
# per-level arrays allocated below and must match the length of the
# `level` coordinate used when the dataset is built
N_levels = 15

In [None]:
# get variable names
varnames = list(conf['residual'].keys())
# NOTE(review): this assumes the last five keys of conf['residual'] are
# settings (save_loc, prefix, ...) and not variables -- confirm against the yml
varnames = varnames[:-5] # remove save_loc and others

varname_upper = ['U', 'V', 'T', 'Q']
# everything that is not an upper-air variable is handled as a single value.
# Filter in config order: a set difference would return the names in an
# arbitrary order that can change from one kernel session to the next, and
# that order is reused for the dataset variables and the plot labels.
varname_surf = [name for name in varnames if name not in varname_upper]

# collect computed mean and variance values
# See "qsub_STEP01_compute_mean_std.ipynb"
MEAN_values = {}
STD_values = {}

# one file per surface variable; row 0 -> MEAN_values, row 1 -> STD_values
for varname in varname_surf:
    save_name = conf['residual']['save_loc'] + '{}_mean_std_{}.npy'.format(conf['residual']['prefix'], varname)
    mean_std = np.load(save_name)
    MEAN_values[varname] = mean_std[0]
    STD_values[varname] = mean_std[1]

# one file per (upper-air variable, level); gather the levels into one array
for varname in varname_upper:

    # -------------------------------------------- #
    # allocate all levels (NaN-filled so a missed level stays visible)
    mean_std_all_levels = np.full((2, N_levels), np.nan)

    # file names count levels from 1, array columns from 0
    for i_level in range(N_levels):
        save_name = conf['residual']['save_loc'] + '{}_level{}_mean_std_{}.npy'.format(conf['residual']['prefix'], i_level+1, varname)
        mean_std = np.load(save_name)
        mean_std_all_levels[:, i_level] = mean_std

    # -------------------------------------------- #
    # save
    MEAN_values[varname] = np.copy(mean_std_all_levels[0, :])
    STD_values[varname] = np.copy(mean_std_all_levels[1, :])

In [None]:
# STD_values was filled with the surface variables first and the upper-air
# variables second, so its values can be split by position
n_surf = len(varname_surf)
std_val_all = [*STD_values.values()]
std_val_surf = np.array(std_val_all[:n_surf])
std_val_upper = std_val_all[n_surf:]

In [None]:
# one flat vector: the surface values followed by every upper-air level
std_concat = np.concatenate([std_val_surf, *std_val_upper])
# geometric mean of the square-rooted values; the common scaling factor
std_g = gmean(np.sqrt(std_concat))

In [None]:
# ------------------------------------------------------- #
# Build the xr.Dataset holding the scaled std of every variable
# level coordinate shared by the per-level variables
level = np.array([ 10,  30,  40,  50,  60,  70,  80,  90,  95, 100, 105, 110, 120, 130, 136])

ds_std_6h = xr.Dataset(coords={"level": level})

for varname, variance in STD_values.items():
    # var to std, then divided by std_g
    scaled = np.sqrt(variance) / std_g

    if scaled.ndim == 1:
        # per-level values: attach the level dimension and coordinate
        ds_std_6h[varname] = xr.DataArray(
            scaled,
            dims=["level",],
            coords={"level": level},
            name=varname,
        )
    else:
        # anything else is stored without explicit dims/coords
        ds_std_6h[varname] = xr.DataArray(scaled, name=varname)

In [None]:
# ds_std_6h.to_netcdf('/glade/campaign/cisl/aiml/ksha/CREDIT/residual_6h_1979_2018_0.25deg.nc')

In [None]:
# ------------------------------------------------------- #
# Print the previously saved residual norms for a visual comparison
my_std = xr.open_dataset('/glade/campaign/cisl/aiml/ksha/CREDIT/residual_6h_1979_2018_0.25deg.nc')
#DJ_std = xr.open_dataset('xxxxx')

banner = '=============== {} ================='
for varname in varnames:
    print(banner.format(varname))
    print(np.asarray(my_std[varname]))
    #print(np.array(DJ_std[varname]))

## Plot

In [None]:
# bar labels: surface variable names first, then one label per
# (upper-air variable, level index), in the same order as std_concat
varname_plot = list(varname_surf)

for varname in varname_upper:
    for i_level in range(N_levels):
        varname_plot.append('{}_lev{}'.format(varname, i_level))

# std_concat holds variances (std_g is the geometric mean of their square
# roots), so take the square root before scaling; this matches the
# sqrt(var) / std_g values stored in ds_std_6h
residual_vals = np.sqrt(std_concat) / std_g

In [None]:
# horizontal bar chart of the residual norm constants; both sequences are
# reversed so the first variable is drawn at the top
fig, ax = plt.subplots(figsize=(8, 15))
ax.grid(linestyle=':')

labels_top_down = varname_plot[::-1]
values_top_down = residual_vals[::-1]
ax.barh(labels_top_down, values_top_down, color='skyblue', edgecolor='k')

# remove the padding above and below the bars
ax.autoscale(enable=True, axis='y', tight=True)
ax.set_title('Residual norm constants (larger means higher penalty)', fontsize=14)