# Global mass and energy conservations in CREDIT

In [1]:
# system
import os
import sys
import logging
from glob import glob
from typing import Dict

# others
import yaml
import numpy as np

# torch
import torch
from torch import nn
from torchvision import transforms as tforms

# credit
from credit.data import (
    Sample,
    concat_and_reshape,
    reshape_only,
    ERA5_and_Forcing_Dataset,
    get_forward_data
)

from credit.transforms import (
    Normalize_ERA5_and_Forcing,
    ToTensor_ERA5_and_Forcing,
    load_transforms
)

from credit.parser import (
    CREDIT_main_parser,
    training_data_check
)

from credit.physics_core import physics_pressure_level

from credit.physics_constants import (RAD_EARTH, GRAVITY, 
                                      RHO_WATER, LH_WATER, 
                                      RVGAS, RDGAS, CP_DRY, CP_VAPOR)

from credit.postblock import (
    PostBlock,
    SKEBS,
    tracer_fixer,
    global_mass_fixer,
    global_energy_fixer
)

In [2]:
# plot
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
# Logging setup
logger = logging.getLogger(__name__)

# single node steup
rank = 0
world_size = 1

## Load yaml

In [4]:
# old rollout config
# config_name = '/glade/work/ksha/CREDIT_runs/wxformer_6h/model_single.yml'
config_name = '/glade/u/home/ksha/miles-physics/config/model_single.yml'
# Read YAML file
with open(config_name, 'r') as stream:
    conf = yaml.safe_load(stream)

In [5]:
conf = CREDIT_main_parser(conf, parse_training=True, parse_predict=False, print_summary=False)

## Data workflow

### Gather data information

In [6]:
# pick a year
train_years_range = [2018, 2020]
valid_years_range = [2018, 2020]

In [7]:
all_ERA_files = sorted(glob(conf["data"]["save_loc"]))

# check and glob surface files
if ('surface_variables' in conf['data']) and (len(conf['data']['surface_variables']) > 0):
    surface_files = sorted(glob(conf["data"]["save_loc_surface"]))

else:
    surface_files = None

# check and glob dyn forcing files
if ('dynamic_forcing_variables' in conf['data']) and (len(conf['data']['dynamic_forcing_variables']) > 0):
    dyn_forcing_files = sorted(glob(conf["data"]["save_loc_dynamic_forcing"]))

else:
    dyn_forcing_files = None

# check and glob diagnostic files
if ('diagnostic_variables' in conf['data']) and (len(conf['data']['diagnostic_variables']) > 0):
    diagnostic_files = sorted(glob(conf["data"]["save_loc_diagnostic"]))

else:
    diagnostic_files = None

# convert year info to str for file name search
train_years = [str(year) for year in range(train_years_range[0], train_years_range[1])]
valid_years = [str(year) for year in range(valid_years_range[0], valid_years_range[1])]

# Filter the files for training / validation
train_files = [file for file in all_ERA_files if any(year in file for year in train_years)]
valid_files = [file for file in all_ERA_files if any(year in file for year in valid_years)]

if surface_files is not None:

    train_surface_files = [file for file in surface_files if any(year in file for year in train_years)]
    valid_surface_files = [file for file in surface_files if any(year in file for year in valid_years)]
    
else:
    train_surface_files = None
    valid_surface_files = None

if dyn_forcing_files is not None:

    train_dyn_forcing_files = [file for file in dyn_forcing_files if any(year in file for year in train_years)]
    valid_dyn_forcing_files = [file for file in dyn_forcing_files if any(year in file for year in valid_years)]

else:
    train_dyn_forcing_files = None
    valid_dyn_forcing_files = None

if diagnostic_files is not None:

    train_diagnostic_files = [file for file in diagnostic_files if any(year in file for year in train_years)]
    valid_diagnostic_files = [file for file in diagnostic_files if any(year in file for year in valid_years)]

else:
    train_diagnostic_files = None
    valid_diagnostic_files = None

# --------------------------------------------------- #
is_train = False
# separate training set and validation set cases
if is_train:
    history_len = conf["data"]["history_len"]
    forecast_len = conf["data"]["forecast_len"]
    name = "training"
else:
    history_len = conf["data"]["valid_history_len"]
    forecast_len = conf["data"]["valid_forecast_len"]
    name = 'validation'

### Dataset

In [8]:
# # transforms
transforms = load_transforms(conf)

if conf['data']['sst_forcing']['activate']:
    sst_forcing = {'varname_skt': conf['data']['sst_forcing']['varname_skt'], 
                   'varname_ocean_mask': conf['data']['sst_forcing']['varname_ocean_mask']}
else:
    sst_forcing = None

# Z-score
dataset = ERA5_and_Forcing_Dataset(
    varname_upper_air=conf['data']['variables'],
    varname_surface=conf['data']['surface_variables'],
    varname_dyn_forcing=conf['data']['dynamic_forcing_variables'],
    varname_forcing=conf['data']['forcing_variables'],
    varname_static=conf['data']['static_variables'],
    varname_diagnostic=conf['data']['diagnostic_variables'],
    filenames=train_files,
    filename_surface=train_surface_files,
    filename_dyn_forcing=train_dyn_forcing_files,
    filename_forcing=conf['data']['save_loc_forcing'],
    filename_static=conf['data']['save_loc_static'],
    filename_diagnostic=train_diagnostic_files,
    history_len=history_len,
    forecast_len=forecast_len,
    skip_periods=conf["data"]["skip_periods"],
    one_shot=conf['data']['one_shot'],
    max_forecast_len=conf["data"]["max_forecast_len"],
    transform=transforms,
    sst_forcing=sst_forcing
)

# # sampler
# sampler = DistributedSampler(
#     dataset,
#     num_replicas=world_size,
#     rank=rank,
#     seed=seed,
#     shuffle=is_train,
#     drop_last=True
# )

### An example training batch

In [9]:
batch_single = dataset.__getitem__(999)

In [10]:
batch = {}
keys = list(batch_single.keys())
keys = keys[:-1]
for var in keys:
    batch[var] = batch_single[var].unsqueeze(0) # give a single sample batch dimension

# ------------------------- #
# base trainer workflow

if "x_surf" in batch:
    # combine x and x_surf
    # input: (batch_num, time, var, level, lat, lon), (batch_num, time, var, lat, lon)
    # output: (batch_num, var, time, lat, lon), 'x' first and then 'x_surf'
    x = concat_and_reshape(batch["x"], batch["x_surf"])
else:
    # no x_surf
    x = reshape_only(batch["x"]).to(self.device).float()

# --------------------------------------------------------------------------------- #
# add forcing and static variables
if 'x_forcing_static' in batch:

    # (batch_num, time, var, lat, lon) --> (batch_num, var, time, lat, lon)
    x_forcing_batch = batch['x_forcing_static'].permute(0, 2, 1, 3, 4)

    # concat on var dimension
    x = torch.cat((x, x_forcing_batch), dim=1)

# --------------------------------------------------------------------------------- #
# combine y and y_surf
if "y_surf" in batch:
    y = concat_and_reshape(batch["y"], batch["y_surf"])
else:
    y = reshape_only(batch["y"])

if 'y_diag' in batch:

    # (batch_num, time, var, lat, lon) --> (batch_num, var, time, lat, lon)
    y_diag_batch = batch['y_diag'].permute(0, 2, 1, 3, 4).float()

    # concat on var dimension
    y = torch.cat((y, y_diag_batch), dim=1)

In [11]:
y_original = y.clone()
x_original = x.clone()

## postblock tests

### global energy fixer

In [12]:
input_dict = {"y_pred": y, "x": x,}

post_conf = conf['model']['post_conf']
opt = global_energy_fixer(post_conf)

In [13]:
output_dict = opt(input_dict)

In [14]:
y_pred = output_dict['y_pred']

In [15]:
y_pred.dtype

torch.float32

In [16]:
output_dict['x'].dtype

torch.float32

In [17]:
def test_global_energy_fixer_rand():
    
    # turn-off other blocks
    conf = {'post_conf': {'skebs': {'activate': False}}}
    conf['post_conf']['tracer_fixer'] = {'activate': False}
    conf['post_conf']['global_mass_fixer'] = {'activate': False}
    
    # energy fixer specs
    conf['post_conf']['global_energy_fixer'] = {
        'activate': True,
        'simple_demo': True,
        'denorm': False,
        'midpoint': False,
        'T_inds': [0, 1, 2, 3, 4, 5, 6],
        'q_inds': [0, 1, 2, 3, 4, 5, 6],
        'U_inds': [0, 1, 2, 3, 4, 5, 6],
        'V_inds': [0, 1, 2, 3, 4, 5, 6],
        'TOA_rad_inds': [7, 8],
        'surf_rad_inds': [7, 8],
        'surf_flux_inds': [7, 8]}
    
    conf['post_conf']['data'] = {'lead_time_periods': 6}
    
    # initialize postblock
    postblock = PostBlock(**conf)

    # verify that global_max_fixer is registered in the postblock
    assert any([isinstance(module, global_energy_fixer) for module in postblock.modules()])
    
    # input tensor
    x = torch.randn((1, 7, 2, 10, 18))
    # output tensor
    y_pred = torch.randn((1, 9, 1, 10, 18))
    
    input_dict = {"y_pred": y_pred, "x": x}
    # corrected output
    y_pred_fix = postblock(input_dict)
    
    assert y_pred_fix.shape == y_pred.shape

In [18]:
test_global_energy_fixer_rand()

**Check before & after**

In [19]:
def energy_residual_verif(x, y_pred):

    state_trans = load_transforms(post_conf, scaler_only=True)
    
    x = state_trans.inverse_transform_input(x)
    y_pred = state_trans.inverse_transform(y_pred)
    
    N_seconds = 3600 * 6
    
    T_ind_start = opt.T_ind_start
    T_ind_end = opt.T_ind_end
    
    q_ind_start = opt.q_ind_start
    q_ind_end = opt.q_ind_end
    
    U_ind_start = opt.U_ind_start
    U_ind_end = opt.U_ind_end
    
    V_ind_start = opt.V_ind_start
    V_ind_end = opt.V_ind_end
    
    TOA_solar_ind = opt.TOA_solar_ind
    TOA_OLR_ind = opt.TOA_OLR_ind
    
    surf_solar_ind = opt.surf_solar_ind
    surf_LR_ind = opt.surf_LR_ind
    
    surf_SH_ind = opt.surf_SH_ind
    surf_LH_ind = opt.surf_LH_ind

    ds_physics = get_forward_data(post_conf['data']['save_loc_physics'])        
    lon2d = torch.from_numpy(ds_physics['lon2d'].values)
    lat2d = torch.from_numpy(ds_physics['lat2d'].values)
    p_level = torch.from_numpy(ds_physics['p_level'].values)
    GPH_surf = torch.from_numpy(ds_physics['geopotential_at_surface'].values)
    
    core_compute = physics_pressure_level(lon2d, lat2d, p_level, midpoint=False)
    
    T_input = x[:, T_ind_start:T_ind_end, -1, ...]
    q_input = x[:, q_ind_start:q_ind_end, -1, ...]
    U_input = x[:, U_ind_start:U_ind_end, -1, ...]
    V_input = x[:, V_ind_start:V_ind_end, -1, ...]
    
    T_pred = y_pred[:, T_ind_start:T_ind_end, 0, ...]
    q_pred = y_pred[:, q_ind_start:q_ind_end, 0, ...]
    U_pred = y_pred[:, U_ind_start:U_ind_end, 0, ...]
    V_pred = y_pred[:, V_ind_start:V_ind_end, 0, ...]
            
    TOA_solar_pred = y_pred[:, TOA_solar_ind, 0, ...]
    TOA_OLR_pred = y_pred[:, TOA_OLR_ind, 0, ...]
            
    surf_solar_pred = y_pred[:, surf_solar_ind, 0, ...]
    surf_LR_pred = y_pred[:, surf_LR_ind, 0, ...]
    surf_SH_pred = y_pred[:, surf_SH_ind, 0, ...]
    surf_LH_pred = y_pred[:, surf_LH_ind, 0, ...]
    
    CP_t0 = (1 - q_input) * CP_DRY + q_input * CP_VAPOR
    CP_t1 = (1 - q_pred) * CP_DRY + q_pred * CP_VAPOR
    
    # kinetic energy
    ken_t0 = 0.5 * (U_input ** 2 + V_input ** 2)
    ken_t1 = 0.5 * (U_pred ** 2 + V_pred ** 2)
    
    # packing latent heat + potential energy + kinetic energy
    E_qgk_t0 = LH_WATER * q_input + GPH_surf + ken_t0
    E_qgk_t1 = LH_WATER * q_input + GPH_surf + ken_t1
    
    # TOA energy flux
    R_T = (TOA_solar_pred + TOA_OLR_pred) / N_seconds
    R_T_sum = core_compute.weighted_sum(R_T, axis=(-2, -1))
    
    # surface net energy flux
    F_S = (surf_solar_pred + surf_LR_pred + surf_SH_pred + surf_LH_pred) / N_seconds
    F_S_sum = core_compute.weighted_sum(F_S, axis=(-2, -1))

    E_level_t0 = CP_t0 * T_input + E_qgk_t0
    E_level_t1 = CP_t1 * T_pred + E_qgk_t1

    # column integrated total energy
    TE_t0 = core_compute.integral(E_level_t0) / GRAVITY
    TE_t1 = core_compute.integral(E_level_t1) / GRAVITY
    
    dTE_dt = (TE_t1 - TE_t0) / N_seconds
    
    dTE_sum = core_compute.weighted_sum(dTE_dt, axis=(1, 2), keepdims=False)
    
    delta_dTE_sum = (R_T_sum - F_S_sum) - dTE_sum
    
    print('Residual to conserve energy budget [Watts]: {}'.format(delta_dTE_sum))
    return delta_dTE_sum, dTE_sum, (R_T_sum - F_S_sum)

In [20]:
residual_, tendency_, source_sinks_ = energy_residual_verif(x, y_original)
print(f'Tendency of atmos total energy [Watts]: {tendency_}')
print(f'Sources and sinks [Watts]: {source_sinks_}')

Residual to conserve energy budget [Watts]: tensor([-8.7528e+14], dtype=torch.float64)
Tendency of atmos total energy [Watts]: tensor([-5.5654e+15], dtype=torch.float64)
Sources and sinks [Watts]: tensor([-6.4407e+15])


In [21]:
residual_, tendency_, source_sinks_ = energy_residual_verif(x, y_pred)
print(f'Tendency of atmos total energy [Watts]: {tendency_}')
print(f'Sources and sinks [Watts]: {source_sinks_}')

Residual to conserve energy budget [Watts]: tensor([-6.8107e+11], dtype=torch.float64)
Tendency of atmos total energy [Watts]: tensor([-6.4400e+15], dtype=torch.float64)
Sources and sinks [Watts]: tensor([-6.4407e+15])


In [22]:
y_pred_np = np.array(y_pred)
y_original_np = np.array(y_original)

In [23]:
T_ind_start = opt.T_ind_start
T_ind_end = opt.T_ind_end

ds_physics = get_forward_data(post_conf['data']['save_loc_physics'])        
lon2d = torch.from_numpy(ds_physics['lon2d'].values)
lat2d = torch.from_numpy(ds_physics['lat2d'].values)
p_level = torch.from_numpy(ds_physics['p_level'].values)

for i in range(37):
    print(f'{p_level[i]/100} hPa largest modified amount: {np.abs(y_pred_np[0, T_ind_start+i, ...] - y_original_np[0, T_ind_start+i, ...]).max()}')

1.0 hPa largest modified amount: 0.00025260448455810547
2.0 hPa largest modified amount: 0.00023890286684036255
3.0 hPa largest modified amount: 0.0002582967281341553
5.0 hPa largest modified amount: 0.0003235340118408203
7.0 hPa largest modified amount: 0.00036485493183135986
10.0 hPa largest modified amount: 0.0004243701696395874
20.0 hPa largest modified amount: 0.0005273222923278809
30.0 hPa largest modified amount: 0.0005829334259033203
50.0 hPa largest modified amount: 0.0006197523325681686
70.0 hPa largest modified amount: 0.0005974769592285156
100.0 hPa largest modified amount: 0.0006909370422363281
125.0 hPa largest modified amount: 0.0007288455963134766
150.0 hPa largest modified amount: 0.0006852149963378906
175.0 hPa largest modified amount: 0.0005910396575927734
200.0 hPa largest modified amount: 0.0005325078964233398
225.0 hPa largest modified amount: 0.0005257129669189453
250.0 hPa largest modified amount: 0.0005657672882080078
300.0 hPa largest modified amount: 0.000717

In [24]:
# for i in range(37):
#     plt.figure()
#     plt.pcolormesh(y_pred_np[0, T_ind_start+i, 0, ...], cmap=plt.cm.nipy_spectral_r)
#     plt.title('level {}'.format(i))
#     plt.colorbar()

### global mass fixer

In [25]:
input_dict = {"y_pred": y, "x": x,}
post_conf = conf['model']['post_conf']
opt = global_mass_fixer(post_conf)

In [26]:
output_dict = opt(input_dict)

In [27]:
y_pred = output_dict['y_pred']

In [28]:
y_pred.dtype

torch.float32

In [29]:
output_dict['x'].dtype

torch.float32

**test module develop**

In [30]:
def test_global_mass_fixer_rand():
    '''
    This function provides a I/O size test on 
    global_mass_fixer at credit.postblock
    '''
    # initialize post_conf, turn-off other blocks
    conf = {'post_conf': {'skebs': {'activate': False}}}
    conf['post_conf']['tracer_fixer'] = {'activate': False}
    conf['post_conf']['global_energy_fixer'] = {'activate': False}
    
    # global mass fixer specs
    conf['post_conf']['global_mass_fixer'] = {
        'activate': True, 
        'denorm': False, 
        'midpoint': False,
        'simple_demo': True, 
        'fix_level_num': 3,
        'q_inds': [0, 1, 2, 3, 4, 5, 6],
        'precip_ind': 7,
        'evapor_ind': 8
    }
    
    # data specs
    conf['post_conf']['data'] = {'lead_time_periods': 6}
    
    # initialize postblock
    postblock = PostBlock(**conf)

    # verify that global_mass_fixer is registered in the postblock
    assert any([isinstance(module, global_mass_fixer) for module in postblock.modules()])
    
    # input tensor
    x = torch.randn((1, 7, 2, 10, 18))
    # output tensor
    y_pred = torch.randn((1, 9, 1, 10, 18))
    
    input_dict = {"y_pred": y_pred, "x": x}
    
    # corrected output
    y_pred_fix = postblock(input_dict)

    # verify `y_pred_fix` and `y_pred` has the same size
    assert y_pred_fix.shape == y_pred.shape

In [31]:
test_global_mass_fixer_rand()

**Check before & after**

In [32]:
def mass_residual_verif(x, y_pred):

    state_trans = load_transforms(post_conf, scaler_only=True)
    
    x = state_trans.inverse_transform_input(x)
    y_pred = state_trans.inverse_transform(y_pred)
    
    precip_ind  = opt.precip_ind
    q_ind_start = opt.q_ind_start
    q_ind_end = opt.q_ind_end
    
    ds_physics = get_forward_data(post_conf['data']['save_loc_physics'])        
    lon2d = torch.from_numpy(ds_physics['lon2d'].values)
    lat2d = torch.from_numpy(ds_physics['lat2d'].values)
    p_level = torch.from_numpy(ds_physics['p_level'].values)
    core_compute = physics_pressure_level(lon2d, lat2d, p_level, midpoint=False)
    
    mass_dry_sum_t0 = core_compute.total_dry_air_mass(x[:, q_ind_start:q_ind_end, -1, ...].unsqueeze(2))
    mass_dry_sum_t1 = core_compute.total_dry_air_mass(y_pred[:, q_ind_start:q_ind_end, ...])
    mass_residual = mass_dry_sum_t1 - mass_dry_sum_t0
    print(f'Residual to conserve energy budget [kg]: {mass_residual}')
    return mass_residual, mass_dry_sum_t1, mass_dry_sum_t0

In [33]:
residual_, M_t1, M_t0 = mass_residual_verif(x, y_original)
print(f'Input state total air mass [kg]: {M_t0}')
print(f'Output state total air mass [kg]: {M_t1}')

Residual to conserve energy budget [kg]: tensor([[-1.9034e+13]], dtype=torch.float64)
Input state total air mass [kg]: tensor([[5.1824e+18]], dtype=torch.float64)
Output state total air mass [kg]: tensor([[5.1824e+18]], dtype=torch.float64)


In [34]:
residual_, M_t1, M_t0 = mass_residual_verif(x, y_pred)
print(f'Input state total air mass [kg]: {M_t0}')
print(f'Output state total air mass [kg]: {M_t1}')

Residual to conserve energy budget [kg]: tensor([[1.1610e+12]], dtype=torch.float64)
Input state total air mass [kg]: tensor([[5.1824e+18]], dtype=torch.float64)
Output state total air mass [kg]: tensor([[5.1824e+18]], dtype=torch.float64)


**Check modified amount after normalization**

In [35]:
y_pred_np = np.array(y_pred)
y_original_np = np.array(y_original)

In [36]:
q_ind_start = opt.q_ind_start
q_ind_end = opt.q_ind_end

ds_physics = get_forward_data(post_conf['data']['save_loc_physics'])        
lon2d = torch.from_numpy(ds_physics['lon2d'].values)
lat2d = torch.from_numpy(ds_physics['lat2d'].values)
p_level = torch.from_numpy(ds_physics['p_level'].values)

for i in range(37):
    print(f'{p_level[i]/100} hPa largest modified amount: {np.abs(y_pred_np[0, q_ind_start+i, ...] - y_original_np[0, q_ind_start+i, ...]).max()}')

1.0 hPa largest modified amount: 5.364418029785156e-07
2.0 hPa largest modified amount: 4.76837158203125e-07
3.0 hPa largest modified amount: 4.76837158203125e-07
5.0 hPa largest modified amount: 4.76837158203125e-07
7.0 hPa largest modified amount: 7.152557373046875e-07
10.0 hPa largest modified amount: 4.76837158203125e-07
20.0 hPa largest modified amount: 4.76837158203125e-07
30.0 hPa largest modified amount: 5.960464477539062e-07
50.0 hPa largest modified amount: 7.152557373046875e-07
70.0 hPa largest modified amount: 4.76837158203125e-07
100.0 hPa largest modified amount: 9.5367431640625e-07
125.0 hPa largest modified amount: 4.76837158203125e-07
150.0 hPa largest modified amount: 4.76837158203125e-07
175.0 hPa largest modified amount: 2.384185791015625e-07
200.0 hPa largest modified amount: 4.76837158203125e-07
225.0 hPa largest modified amount: 2.384185791015625e-07
250.0 hPa largest modified amount: 2.384185791015625e-07
300.0 hPa largest modified amount: 2.384185791015625e-07


**Check modified q**

In [37]:
# for i in range(37):
#     plt.figure()
#     plt.pcolormesh(y_pred_np[0, q_ind_start+i, 0, ...], cmap=plt.cm.nipy_spectral_r)
#     plt.title('level {}'.format(i))
#     plt.colorbar()

### tracer fixer

In [None]:
## see test_postblock.py on how it work