# Compute residual norm coefficients

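This notebook computes the residual-norm coefficients used as part of the variable weights. For each variable (and, for upper-air variables, each vertical level), the coefficient is the square root of the pre-computed residual statistic divided by the geometric mean of all such square roots. The ERA5 section is kept for reference but is commented out; only the WRF (C404) section is active.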

In [1]:
import os
import yaml
import copy
import numpy as np
import xarray as xr

In [2]:
from scipy.stats import gmean

In [3]:
import matplotlib.pyplot as plt
%matplotlib inline

## ERA5 (x)

In [4]:
# # get variable information from data_preprocessing/config
# config_name = os.path.realpath('data_config_ERA5.yml')

# with open(config_name, 'r') as stream:
#     conf = yaml.safe_load(stream)

In [5]:
# N_levels = 11

# base_dir = '/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/dscale_1h/'
# ds_example = xr.open_zarr(base_dir+'ERA5_GP_1h_2021.zarr')
# level = np.array(ds_example['level'])

In [6]:
# # get variable names
# varnames = list(conf['residual'].keys())
# varnames = varnames[:-5] # remove save_loc and others

# varname_upper = ['U', 'V', 'T', 'Q']
# varname_surf = list(set(varnames) - set(varname_upper))

In [7]:
# # collect computed mean and variance values
# # See "qsub_STEP01_compute_mean_std.ipynb"
# MEAN_values = {}
# STD_values = {}

# for varname in varname_surf:
#     save_name = conf['residual']['save_loc'] + '{}_mean_std_{}.npy'.format(
#         conf['residual']['prefix'], varname)
#     mean_std = np.load(save_name)
#     MEAN_values[varname] = mean_std[0]
#     STD_values[varname] = mean_std[1]

# for varname in varname_upper:

#     # -------------------------------------------- #
#     # allocate all levels
#     mean_std_all_levels = np.empty((2, N_levels))
#     mean_std_all_levels[...] = np.nan
    
#     for i_level in range(N_levels):
#         save_name = conf['residual']['save_loc'] + '{}_level{}_mean_std_{}.npy'.format(
#             conf['residual']['prefix'], i_level, varname)
#         mean_std = np.load(save_name)
#         mean_std_all_levels[:, i_level] = mean_std

#     # -------------------------------------------- #
#     # save
#     MEAN_values[varname] = np.copy(mean_std_all_levels[0, :])
#     STD_values[varname] = np.copy(mean_std_all_levels[1, :])

# keys_to_drop = ['TCC', 'SKT', 'land_sea_CI_mask'] # <---------------- some variables are not used in the paper
# MEAN_values = {k: v for k, v in MEAN_values.items() if k not in keys_to_drop}
# STD_values = {k: v for k, v in STD_values.items() if k not in keys_to_drop}

In [8]:
# # separate upper air (list) and surf (float) std values
# std_val_all = list(STD_values.values())
# std_val_surf = np.array(std_val_all[:-4])
# std_val_upper = std_val_all[-4:]

# # combine
# std_concat = np.concatenate([std_val_surf]+ std_val_upper)

# # geometrical mean (not used)
# std_g = gmean(np.sqrt(std_concat))

### Save residual coef as a file

In [9]:
# # ------------------------------------------------------- #
# # create xr.DataArray for std
# ds_std_6h = xr.Dataset(coords={"level": level})

# for varname, data in STD_values.items():
#     data = np.sqrt(data) / std_g # <--- var to std and divided by std_g
#     if len(data.shape) == 1:
#         data_array = xr.DataArray(
#             data,
#             dims=["level",],
#             coords={"level": level},
#             name=varname,
#         )
#         ds_std_6h[varname] = data_array
#     else:
#         data_array = xr.DataArray(
#             data,
#             name=varname,
#         )
#         ds_std_6h[varname] = data_array

In [10]:
# ds_std_6h.to_netcdf('/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/mean_std/ERA5_1h_residual_1980_2019.nc')

In [11]:
# ds_1h = xr.open_dataset('/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/mean_std/ERA5_1h_residual_1980_2019.nc')
# ds_6h = xr.open_dataset('/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/mean_std/ERA5_6h_residual_1980_2019.nc')

# for varname in ds_1h.keys():
#     print(f'=================== {varname} ===================')
#     print(ds_1h[varname].values)
#     print(ds_6h[varname].values)

## WRF

In [5]:
# Load the WRF preprocessing config (from data_preprocessing/config). Its 'residual'
# section supplies 'save_loc' and 'prefix', which locate the pre-computed statistics.
config_name = os.path.realpath('data_config_WRF.yml')
with open(config_name, 'r') as config_file:
    conf = yaml.safe_load(config_file)

In [6]:
# Number of vertical levels for which per-level residual statistics were computed.
N_levels = 12

base_dir = '/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/C404_land/'
# One year of C404 data, opened only to read the vertical ('bottom_top') coordinate.
ds_example = xr.open_zarr(base_dir+'C404_GP_1980.zarr')
level = np.array(ds_example['bottom_top'])

# The per-level statistics loaded later are indexed 0..N_levels-1 and attached to this
# coordinate, so both must agree. Fail here with a clear message instead of later,
# inside xr.DataArray construction.
if level.shape != (N_levels,):
    raise ValueError(
        f"expected {N_levels} 'bottom_top' levels in {base_dir}C404_GP_1980.zarr, "
        f"got coordinate of shape {level.shape}"
    )

In [7]:
# # get variable names
# varnames = list(conf['residual'].keys())
# varnames = varnames[:-5] # remove save_loc and others

# varname_upper = ['WRF_P', 'WRF_U', 'WRF_V', 'WRF_T', 'WRF_Q_tot', 'WRF_Q_tot_05']
# varname_surf = list(set(varnames) - set(varname_upper))

In [8]:
# Upper-air variables: statistics are stored per vertical level.
varname_upper = [
    'WRF_P',
    'WRF_U',
    'WRF_V',
    'WRF_T',
    'WRF_Q_tot_05',
    'WRF_W',
]

# Surface / single-level variables: one scalar statistic each.
varname_surf = [
    'WRF_SP',
    'WRF_MSLP',
    'WRF_T2',
    'WRF_TD2',
    'WRF_U10',
    'WRF_V10',
    'WRF_PWAT_05',
    'WRF_SMOIS',
    'WRF_TSLB',
]

# Not included in this run:
# 'WRF_precip_025', 'WRF_radar_composite_025', 'WRF_OLR', 'WRF_TCC', 'WRF_GLW', 'WRF_SWDOWN'

In [9]:
# Collect the statistics written by the mean/std computation step.
# Each .npy file holds [mean, stat]. The second entry is stored in STD_values; the
# coefficient cell square-roots it before use.
residual_conf = conf['residual']
MEAN_values = {}
STD_values = {}

# surface variables: one file per variable, one scalar pair each
for name in varname_surf:
    stats = np.load(residual_conf['save_loc'] + '{}_mean_std_{}.npy'.format(
        residual_conf['prefix'], name))
    MEAN_values[name], STD_values[name] = stats[0], stats[1]

# upper-air variables: one file per (level, variable); gather into a (2, N_levels) table
for name in varname_upper:
    per_level = np.full((2, N_levels), np.nan)

    for lev in range(N_levels):
        level_path = residual_conf['save_loc'] + '{}_level{}_mean_std_{}.npy'.format(
            residual_conf['prefix'], lev, name)
        per_level[:, lev] = np.load(level_path)

    # row 0 -> mean per level, row 1 -> second statistic per level
    MEAN_values[name] = per_level[0].copy()
    STD_values[name] = per_level[1].copy()

In [10]:
# Pool the residual statistics of all variables: one scalar per surface variable and
# one value per level for each upper-air variable. Then take the geometric mean of
# their square roots. The next cell divides every coefficient by std_g, so the pooled
# coefficients end up with a geometric mean of 1.
#
# Values are selected by variable name, not by position in STD_values. The split then
# does not depend on dict insertion order, and it still works when varname_upper is
# empty (a positional `[-0:]` slice would return the whole list).
N_upper = len(varname_upper)

assert set(STD_values) == set(varname_surf) | set(varname_upper), \
    'STD_values keys do not match varname_surf + varname_upper'

# surface statistics (scalars) and upper-air statistics (one 1-D array per variable)
std_val_surf = np.array([STD_values[v] for v in varname_surf])
std_val_upper = [STD_values[v] for v in varname_upper]

# combine into one flat array of every surface and per-level statistic
std_concat = np.concatenate([std_val_surf] + std_val_upper)

# gmean needs finite, strictly positive inputs. A zero would make std_g zero and the
# next cell would divide by it; a NaN would propagate into every coefficient.
assert np.all(np.isfinite(std_concat)) and np.all(std_concat > 0), \
    'residual statistics must be finite and positive'

# geometric mean of the square-rooted statistics; used to normalize the coefficients
std_g = gmean(np.sqrt(std_concat))

In [11]:
# ------------------------------------------------------- #
# Build one xr.Dataset of residual-norm coefficients:
#     coefficient = sqrt(stat) / std_g
# 1-D (per-level) variables get the 'bottom_top' dimension, using the same coordinate
# values read from the example C404 file. Anything else (the scalar surface values)
# is stored as a dimensionless variable.
# (The "_6h" suffix is inherited from the ERA5 version of this cell.)
ds_std_6h = xr.Dataset(coords={'bottom_top': level})

for var_key, raw_stat in STD_values.items():
    coef = np.sqrt(raw_stat) / std_g

    if coef.ndim != 1:
        ds_std_6h[var_key] = xr.DataArray(coef, name=var_key)
        continue

    ds_std_6h[var_key] = xr.DataArray(
        coef,
        dims=['bottom_top'],
        coords={'bottom_top': level},
        name=var_key,
    )

In [12]:
# ds_std_6h.to_netcdf(
#     '/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/mean_std/C404_residual_1980_2019_12lev.nc'
# )

In [15]:
# Compare residual coefficients across three saved files, variable by variable,
# iterating over the variables of the 12-level file (ds_new). A variable may be absent
# from some files. Report that explicitly instead of a bare `except: pass`, which hid
# genuine errors and skipped the remaining files' values after the first missing one.
mean_std_dir = '/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/mean_std/'
ds_new = xr.open_dataset(mean_std_dir + 'C404_residual_1980_2019_12lev.nc')
ds_W = xr.open_dataset(mean_std_dir + 'C404_residual_1980_2019_12lev_clean.nc')
ds_old = xr.open_dataset(mean_std_dir + 'C404_residual_1980_2019_15lev_20250629.nc')

# print order within each variable: ds_new, ds_W, ds_old
compare_order = [('12lev', ds_new), ('12lev_clean', ds_W), ('15lev_20250629', ds_old)]

for varname in ds_new.keys():
    print(f'=================== {varname} ===================')
    for label, ds in compare_order:
        if varname in ds.data_vars:
            print(f'{label}:', ds[varname].values)
        else:
            print(f'{label}: <missing>')

0.09639225375117795
0.09474170269892271
0.11926121541140546
0.10737514260281934
0.10553652852635395
0.7739778027954418
0.7722965511266701
0.7590723050331437
0.955523102335158
0.6472721789529152
0.6361887595442969
0.8008368282702392
2.627299790180062
2.5823118138177117
3.250623924987081
2.223484072487494
2.1854107436389905
2.7510033494731885
0.6945971632891261
0.6827033851674961
0.8593896157777442
0.3556348084151752
0.7184900598262294
[0.09628631 0.09636575 0.09661374 0.09715853 0.09817187 0.09939734
 0.10030108 0.10070841 0.10078935 0.10127266 0.10228783 0.1072464 ]
[0.09463757 0.09471565 0.09495939 0.09549486 0.09649085 0.09769533
 0.0985836  0.09898395 0.09906351 0.09953854 0.10053633 0.10540999]
[0.11913013 0.11917258 0.11930666 0.11953525 0.11992822 0.12055753
 0.12146304 0.12249296 0.12341966 0.12409741 0.12460137 0.12470152
 0.12529949 0.12655551 0.13269049]
[2.65847761 2.2142442  1.94368222 1.80943606 1.7533852  1.53844313
 1.2755708  1.09455102 0.97479134 0.77802218 0.58267895 

In [20]:
# Compare residual coefficients across three saved files, variable by variable,
# iterating over the variables of the "clean" 12-level file (ds_W). A variable may be
# absent from some files. Report that explicitly instead of a bare `except: pass`, which
# hid genuine errors and skipped the remaining files' values after the first missing one.
mean_std_dir = '/glade/derecho/scratch/ksha/DWC_data/CONUS_domain_GP/mean_std/'
ds_new = xr.open_dataset(mean_std_dir + 'C404_residual_1980_2019_12lev.nc')
ds_W = xr.open_dataset(mean_std_dir + 'C404_residual_1980_2019_12lev_clean.nc')
ds_old = xr.open_dataset(mean_std_dir + 'C404_residual_1980_2019_15lev_20250629.nc')

# print order within each variable: ds_W, ds_new, ds_old
compare_order = [('12lev_clean', ds_W), ('12lev', ds_new), ('15lev_20250629', ds_old)]

for varname in ds_W.keys():
    print(f'=================== {varname} ===================')
    for label, ds in compare_order:
        if varname in ds.data_vars:
            print(f'{label}:', ds[varname].values)
        else:
            print(f'{label}: <missing>')

0.09474170269892271
0.14131897851483102
0.11926121541140546
0.10553652852635395
0.15742079762637887
0.7739778027954418
0.7590723050331437
1.132250315440044
0.955523102335158
0.6361887595442969
0.9489542944686785
0.8008368282702392
2.5823118138177117
3.8518377582383168
3.250623924987081
2.1854107436389905
3.2598106760636796
2.7510033494731885
0.6827033851674961
1.018336617055383
0.8593896157777442
[0.09463757 0.09471565 0.09495939 0.09549486 0.09649085 0.09769533
 0.0985836  0.09898395 0.09906351 0.09953854 0.10053633 0.10540999]
[0.14116365 0.14128012 0.14164369 0.14244241 0.14392804 0.14572469
 0.14704965 0.14764682 0.14776549 0.14847405 0.14996238 0.15723204]
[0.11913013 0.11917258 0.11930666 0.11953525 0.11992822 0.12055753
 0.12146304 0.12249296 0.12341966 0.12409741 0.12460137 0.12470152
 0.12529949 0.12655551 0.13269049]
[2.61295577 2.17632909 1.91040001 1.77845259 1.72336151 1.51209995
 1.25372885 1.07580872 0.95809972 0.76469989 0.57270158 0.697794  ]
[3.89754702 3.24626427 2.8