# Global-scale atmospheric mass and energy conservation on ERA5 pressure level data

In [1]:
import torch
import numpy as np
import xarray as xr

In [2]:
from typing import Any, Dict, Optional, Tuple, Union

### Load data

ERA5 pressure level data:

* `/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_base/upper_air/*.zarr`
* `/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_base/surf/*.zarr`
* `/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_base/accum/*.zarr`
* `/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_base/static/*.zarr`

In [3]:
# Root directory of the ERA5 pressure-level archive used throughout this notebook.
base_dir = '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_base/'
# Open the 1979 stores for surface, accumulated and upper-air variables.
ds_surf = xr.open_zarr(base_dir + 'surf/ERA5_plevel_6h_surf_1979.zarr')
ds_accum = xr.open_zarr(base_dir + 'accum/ERA5_plevel_6h_accum_1979.zarr')
ds_upper = xr.open_zarr(base_dir + 'upper_air/ERA5_plevel_6h_upper_air_1979.zarr')
# Static store: no year in its file name; it is read below without any time selection.
ds_static = xr.open_zarr(base_dir + 'static/ERA5_plevel_6h_static.zarr')

In [4]:
# 1-D coordinate axes taken from the surface dataset.
x = ds_surf['longitude']
y = ds_surf['latitude']
# 2-D coordinate grids; with numpy's default 'xy' indexing both are (len(y), len(x)).
lon, lat = np.meshgrid(x, y)
# Level coordinate scaled by 100; presumably hPa --> Pa (the physics class documents Pa) -- TODO confirm source units.
level_p = 100*np.array(ds_upper['level'])
# Shape of one upper-air time step: (level,) + the 2-D grid shape.
tensor_shape = (len(level_p),) + lon.shape

### Convert data to `torch.Tensor`

In [5]:
# Number of two-step samples in the batch.
batch_size = 32
# (batch, 2, level, lat, lon); the "4D"/"3D" suffixes do not count the batch and time axes.
target_shape_4D = (batch_size, 2,)+tensor_shape
# (batch, 2, lat, lon) for single-level variables.
target_shape_3D = (batch_size, 2,)+tensor_shape[1:]

# batch_size + 1 consecutive time indices, so that every sample k can pair step k with step k + 1.
t_slice = np.arange(batch_size+1)

In [6]:
def time_series_to_batch(q, target_shape):
    '''
    Re-index a time series as a batch of overlapping windows without copying.

    The first two output axes both advance along the leading (time) axis of
    ``q``, so ``out[b, k] == q[b + k]``. The result is a strided view that
    shares storage with ``q``.

    Args:
        q (torch.Tensor): time series whose first axis is time.
        target_shape (tuple): shape of the view, (batch, window, *q.shape[1:]).

    Returns:
        torch.Tensor: view of ``q`` with shape ``target_shape``.
    '''
    time_stride = q.stride(0)
    inner_strides = q.stride()[1:]
    # batch axis and window axis both step through time
    window_strides = (time_stride, time_stride, *inner_strides)
    return torch.as_strided(q, size=target_shape, stride=window_strides)

In [7]:
# Upper-air and accumulated variables for the selected time steps, converted to torch tensors.
q = torch.from_numpy(np.array(ds_upper['Q'].isel(time=t_slice))) # kg/kg
T = torch.from_numpy(np.array(ds_upper['T'].isel(time=t_slice))) # K (multiplied by C_p in J/kg K in the energy budget)
u = torch.from_numpy(np.array(ds_upper['U'].isel(time=t_slice))) # m/s
v = torch.from_numpy(np.array(ds_upper['V'].isel(time=t_slice))) # m/s
precip = torch.from_numpy(np.array(ds_accum['total_precipitation'].isel(time=t_slice)))
evapor = torch.from_numpy(np.array(ds_accum['evaporation'].isel(time=t_slice)))

# Static field: read without time selection.
GPH_surf = torch.from_numpy(np.array(ds_static['geopotential_at_surface'])) # m^2/s^2 (J/kg); added to the kinetic energy term below
TOA_net = torch.from_numpy(np.array(ds_accum['top_net_solar_radiation'].isel(time=t_slice))) # J/m2
OLR = torch.from_numpy(np.array(ds_accum['top_net_thermal_radiation'].isel(time=t_slice))) # J/m2
R_short = torch.from_numpy(np.array(ds_accum['surface_net_solar_radiation'].isel(time=t_slice))) # J/m2
R_long = torch.from_numpy(np.array(ds_accum['surface_net_thermal_radiation'].isel(time=t_slice))) # J/m2
LH = torch.from_numpy(np.array(ds_accum['surface_latent_heat_flux'].isel(time=t_slice))) # J/m2
SH = torch.from_numpy(np.array(ds_accum['surface_sensible_heat_flux'].isel(time=t_slice))) # J/m2

In [8]:
# Re-index every time series as (batch, 2, ...) windows via time_series_to_batch.
q_batch = time_series_to_batch(q, target_shape_4D)
T_batch = time_series_to_batch(T, target_shape_4D)
u_batch = time_series_to_batch(u, target_shape_4D)
v_batch = time_series_to_batch(v, target_shape_4D)
precip_batch = time_series_to_batch(precip, target_shape_3D)
evapor_batch = time_series_to_batch(evapor, target_shape_3D)

# Static field: prepend three singleton axes so it broadcasts against the batched upper-air tensors
# (assumes GPH_surf is (lat, lon) -- TODO confirm).
GPH_surf_batch = GPH_surf.unsqueeze(0).unsqueeze(0).unsqueeze(0)
TOA_net_batch = time_series_to_batch(TOA_net, target_shape_3D)
OLR_batch = time_series_to_batch(OLR, target_shape_3D)
R_short_batch = time_series_to_batch(R_short, target_shape_3D)
R_long_batch = time_series_to_batch(R_long, target_shape_3D)
LH_batch = time_series_to_batch(LH, target_shape_3D)
SH_batch = time_series_to_batch(SH, target_shape_3D)

In [9]:
# Coordinate grids and pressure levels as torch tensors, the inputs of physics_pressure_level.
longitude = torch.from_numpy(lon)
latitude = torch.from_numpy(lat)
upper_air_pressure = torch.from_numpy(level_p)

### `credit.physics_core` pressure level class

In [10]:
GRAVITY = 9.80665  # m/s^2
R = 6371000  # m; same value as RAD_EARTH, not used by the class below
RHO_WATER = 1000.0  # kg/m^3
RAD_EARTH = 6371000  # m
LH_WATER = 2.26e6  # J/kg
CP_DRY = 1005  # J/kg K
CP_VAPOR = 1846  # J/kg K


class physics_pressure_level:
    '''
    Pressure level physics

    Attributes:
        upper_air_pressure (torch.Tensor): pressure levels in Pa.
        lon (torch.Tensor): longitude in degrees.
        lat (torch.Tensor): latitude in degrees.
        pressure_thickness (torch.Tensor): pressure thickness between levels.
        dx, dy (torch.Tensor): grid spacings in longitude and latitude.
        area (torch.Tensor): area of grid cells.
        integral (function): vertical integration method (midpoint or trapezoidal).
    '''

    def __init__(self,
                 lon: torch.Tensor,
                 lat: torch.Tensor,
                 upper_air_pressure: torch.Tensor,
                 midpoint: bool = False):
        '''
        Initialize the class with longitude, latitude, and pressure levels.

        All inputs must be in the same torch device.

        Full order of dimensions:  (batch, time, level, latitude, longitude)

        Args:
            lon (torch.Tensor): Longitude in degrees, 2-D (latitude, longitude) grid.
            lat (torch.Tensor): Latitude in degrees, 2-D (latitude, longitude) grid.
            upper_air_pressure (torch.Tensor): Pressure levels in Pa.
            midpoint (bool): True if vertical level quantities are midpoint values
                      otherwise False

        '''
        self.lon = lon
        self.lat = lat
        self.upper_air_pressure = upper_air_pressure

        # ========================================================================= #
        # compute pressure level thickness: (level,) --> (level-1,)
        self.pressure_thickness = self.upper_air_pressure.diff(dim=-1)

        # ========================================================================= #
        # compute grid spacings
        lat_rad = torch.deg2rad(self.lat)
        lon_rad = torch.deg2rad(self.lon)
        # dim=0 is the latitude axis and dim=1 the longitude axis of the 2-D grids
        self.dy = torch.gradient(lat_rad * RAD_EARTH, dim=0)[0]
        self.dx = torch.gradient(lon_rad * RAD_EARTH, dim=1)[0] * torch.cos(lat_rad)
        # abs(): the sign of dx * dy depends on the ordering of the coordinate axes
        self.area = (self.dx * self.dy).abs()

        # ========================================================================== #
        # vertical integration method
        if midpoint:
            self.integral = self.pressure_integral_midpoint
        else:
            self.integral = self.pressure_integral_trapz

    @staticmethod
    def _check_num_dims(q: torch.Tensor) -> None:
        '''
        Raise ValueError unless q has 3, 4 or 5 dimensions.
        '''
        if q.dim() not in (3, 4, 5):
            raise ValueError(f"Unsupported tensor dimensions: {q.shape}")

    def _sum_over_layers(self, q_layer: torch.Tensor) -> torch.Tensor:
        '''
        Weight per-layer values by the pressure thickness and sum over the level axis.

        The level axis is the third from last for every supported layout, so a
        single broadcast covers the 3-D, 4-D and 5-D cases.
        '''
        # (level-1,) --> (level-1, 1, 1); broadcasts against (..., level-1, latitude, longitude)
        delta_p = self.pressure_thickness.unsqueeze(-1).unsqueeze(-1)
        return torch.sum(q_layer * delta_p, dim=-3)

    def pressure_integral_midpoint(self, q_mid: torch.Tensor) -> torch.Tensor:
        '''
        Compute the pressure level integral of a given quantity; assuming its mid point
        values are pre-computed

        Args:
            q_mid: the quantity with dims of (batch_size, time, level-1, latitude, longitude),
                   (batch_size or time, level-1, latitude, longitude) or
                   (level-1, latitude, longitude)

        Returns:
            Pressure level integrals of q_mid; the level axis is removed.

        Raises:
            ValueError: if q_mid does not have 3, 4 or 5 dimensions.
        '''
        self._check_num_dims(q_mid)
        return self._sum_over_layers(q_mid)

    def pressure_integral_trapz(self, q: torch.Tensor) -> torch.Tensor:
        '''
        Compute the pressure level integral of a given quantity using the trapezoidal rule.

        Args:
            q: the quantity with dims of (batch_size, time, level, latitude, longitude),
               (batch_size or time, level, latitude, longitude) or
               (level, latitude, longitude)

        Returns:
            Pressure level integrals of q; the level axis is removed.

        Raises:
            ValueError: if q does not have 3, 4 or 5 dimensions.
        '''
        self._check_num_dims(q)
        # trapezoidal rule: average of the two levels bounding each layer
        q_layer = 0.5 * (q[..., :-1, :, :] + q[..., 1:, :, :])
        return self._sum_over_layers(q_layer)

    def weighted_sum(self,
                     q: torch.Tensor,
                     axis: Optional[Union[int, Tuple[int, ...]]] = None,
                     keepdims: bool = False) -> torch.Tensor:
        '''
        Compute the grid-cell-area weighted sum of a given quantity.

        Args:
            q: the quantity to be summed; its trailing dims must broadcast
               against the (latitude, longitude) area grid.
            axis: dims to compute the sum (int or tuple of ints); None sums
                  over all dims.
            keepdims: whether to keep the reduced dimensions or not

        Returns:
            Weighted sum (PyTorch tensor)
        '''
        q_w = q * self.area
        if axis is None:
            # spell out "all dims" rather than relying on torch.sum(dim=None)
            axis = tuple(range(q_w.dim()))
        return torch.sum(q_w, dim=axis, keepdim=keepdims)

    def total_dry_air_mass(self, q: torch.Tensor) -> torch.Tensor:
        '''
        Compute the total mass of dry air over the entire globe [kg]

        Args:
            q: specific humidity [kg/kg] with a level axis third from last.

        Returns:
            Area-weighted sum over latitude and longitude of the dry air mass per unit area.
        '''
        mass_dry_per_area = self.integral(1-q) / GRAVITY  # kg/m^2
        # weighted sum on latitude and longitude dimensions
        mass_dry_sum = self.weighted_sum(mass_dry_per_area, axis=(-2, -1))  # kg

        return mass_dry_sum

    def total_column_water(self, q: torch.Tensor) -> torch.Tensor:
        '''
        Compute total column water (TCW) per air column [kg/m2]

        Args:
            q: specific humidity [kg/kg] with a level axis third from last.

        Returns:
            Vertical integral of q divided by gravity; the level axis is removed.
        '''
        TCW = self.integral(q) / GRAVITY  # kg/m^2

        return TCW

In [11]:
# Physics helper on the notebook grid; midpoint is left at its default (False), so .integral is the trapezoidal rule.
physics_core = physics_pressure_level(longitude, latitude, upper_air_pressure)

## Conservation of total dry air mass

In [12]:
# q_batch is an overlapping strided view of the time series; clone() gives an
# independent tensor so the in-place update below cannot leak into other samples.
q_batch_correct = q_batch.clone()

correction_cycle_num = 1 # number of correction passes; more passes reduce the residual left by numerical precision

for i in range(correction_cycle_num):
    # global dry air mass at both time steps, (batch, time)
    mass_dry_sum = physics_core.total_dry_air_mass(q_batch_correct)
    mass_dry_t0 = mass_dry_sum[:, 0]
    mass_dry_t1 = mass_dry_sum[:, 1]

    # ------------------------------------------------------------------------------ #
    # residual before this pass is applied
    mass_dry_res = mass_dry_t1 - mass_dry_t0
    print(f'Residual to conserve the dry air mass [kg]: {mass_dry_res}')
    # ------------------------------------------------------------------------------ #

    # (batch,) --> (batch, 1, 1, 1) so the ratio broadcasts over (level, lat, lon)
    q_correct_ratio = (mass_dry_t0 / mass_dry_t1)[:, None, None, None]
    # rescale the dry fraction (1 - q) at t1 so its global mass matches t0
    q_batch_correct[:, 1] = 1 - (1 - q_batch_correct[:, 1]) * q_correct_ratio

Residual to conserve the dry air mass [kg]: tensor([ 4.8998e+12,  2.3332e+13,  1.4195e+13,  2.4289e+13,  5.1000e+12,
         1.4327e+13, -6.2937e+12,  5.2303e+12, -1.4075e+13, -2.0736e+13,
        -3.4998e+12, -9.2598e+12,  3.3425e+11,  7.4894e+12, -6.0463e+12,
        -2.4065e+13, -1.9103e+13, -4.1000e+13, -2.3711e+13, -1.7346e+13,
        -2.3995e+13, -3.3204e+13, -7.7813e+12,  6.0171e+12,  1.7287e+12,
        -1.6342e+13,  1.4258e+13,  9.6297e+12,  1.6811e+13, -4.1555e+12,
         9.1050e+12,  2.2959e+13], dtype=torch.float64)


In [13]:
# ------------------------------------------------------------------------------ #
# Recompute the t1 - t0 dry air mass difference on the corrected humidity.
mass_dry_sum = physics_core.total_dry_air_mass(q_batch_correct)
mass_dry_res = mass_dry_sum[:, 1] - mass_dry_sum[:, 0]
print('Residual to conserve the dry air mass [kg]: {}'.format(mass_dry_res))
# ------------------------------------------------------------------------------ #

Residual to conserve the dry air mass [kg]: tensor([-5.5535e+10,  2.3331e+10, -2.7973e+10, -9.3268e+10,  6.8353e+09,
         8.1128e+10, -9.9454e+10, -1.8999e+10,  6.6406e+10,  3.0085e+10,
        -1.0031e+11,  3.0784e+10, -2.1513e+10,  5.6420e+10,  7.4013e+10,
         2.0000e+10,  2.4711e+10,  4.1241e+10,  5.6080e+10, -4.7367e+09,
         5.2450e+10, -7.3481e+10, -1.5247e+10,  1.6357e+11, -1.2895e+11,
         3.8222e+10,  1.7374e+10,  2.8830e+10,  8.9226e+10, -1.6810e+11,
         1.5477e+11,  5.9000e+10], dtype=torch.float64)


## Conservation of moisture

In [14]:
N_seconds = 6 * 3600 # seconds in one 6-hourly step

# accumulated amount at t1 --> mean mass flux over the step, kg/m^2/s
# (assumes the accumulations are depths of water in m -- TODO confirm)
precip_batch_flux = precip_batch[:, 1] * RHO_WATER / N_seconds # positive
evapor_batch_flux = evapor_batch[:, 1] * RHO_WATER / N_seconds # negative

precip_batch_correct = precip_batch_flux.clone()

# total column water at both steps and its global-sum tendency
TWC = physics_core.total_column_water(q_batch_correct) # kg/m^2
dTWC_dt = (TWC[:, 1] - TWC[:, 0]) / N_seconds # kg/m^2/s
TWC_sum = physics_core.weighted_sum(dTWC_dt, axis=(-2, -1)) # kg/s

# global-sum evaporation
E_sum = physics_core.weighted_sum(evapor_batch_flux, axis=(-2, -1)) # kg/s

correction_cycle_num = 1

for i in range(correction_cycle_num):
    P_sum = physics_core.weighted_sum(precip_batch_correct, axis=(-2, -1)) # kg/s
    # what is left of the budget: tendency + evaporation + precipitation should sum to zero
    residual = -TWC_sum - E_sum - P_sum # kg/s

    # ------------------------------------------------------------------------------ #
    print(f'Residual to conserve moisture budge [kg/s]: {residual}')
    # ------------------------------------------------------------------------------ #

    # global precipitation that closes the budget, and the matching scale factor
    P_correct = P_sum + residual # kg/s
    # (batch,) --> (batch, 1, 1) so the ratio broadcasts over (lat, lon)
    P_correct_ratio = (P_correct / P_sum)[:, None, None]
    precip_batch_correct = precip_batch_correct * P_correct_ratio

Residual to conserve moisture budge [kg/s]: tensor([-8.4651e+08, -3.4187e+08, -1.8498e+09, -1.6545e+09, -7.3230e+08,
         9.6724e+07, -8.0649e+08, -9.7128e+08,  9.2233e+07,  4.4473e+08,
        -6.3185e+08, -8.3198e+08, -1.7069e+08,  7.5317e+07, -4.8904e+08,
        -1.0326e+08,  6.6847e+08,  9.9884e+08,  3.3981e+08, -1.5201e+08,
         9.2791e+08,  1.0165e+09, -4.5664e+08, -1.1849e+09, -6.0900e+08,
        -1.8667e+08, -1.5532e+09, -1.6841e+09, -1.1457e+09, -7.0225e+08,
        -1.4512e+09, -1.7518e+09], dtype=torch.float64)


In [15]:
# ------------------------------------------------------------------------------ #
# Recompute the moisture budget residual with the corrected precipitation.
P_sum = physics_core.weighted_sum(precip_batch_correct, axis=(-2, -1)) # kg/s
residual = -TWC_sum - E_sum - P_sum # kg/s
print('Residual to conserve moisture budge [kg/s]: {}'.format(residual))
# ------------------------------------------------------------------------------ #

Residual to conserve moisture budge [kg/s]: tensor([ 1.9073e-06, -1.9073e-06,  3.8147e-06, -1.9073e-06,  1.9073e-06,
         1.9073e-06,  3.8147e-06,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         3.8147e-06, -1.9073e-06,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  1.9073e-06,  3.8147e-06,  0.0000e+00,  0.0000e+00,
         0.0000e+00, -3.8147e-06,  0.0000e+00, -1.9073e-06,  0.0000e+00,
         1.9073e-06, -1.9073e-06, -1.9073e-06,  0.0000e+00, -1.9073e-06,
         1.9073e-06,  0.0000e+00], dtype=torch.float64)


## Conservation of energy

In [16]:
# specific heat of moist air, (batch, time, level, lat, lon)
dry_fraction = 1 - q_batch_correct
C_p = dry_fraction * CP_DRY + q_batch_correct * CP_VAPOR
# kinetic energy per unit mass, (batch, time, level, lat, lon)
ken = 0.5 * (u_batch ** 2 + v_batch ** 2)

# temperature that will be corrected in place at t1
T_batch_correct = T_batch.clone()

# every energy term except the thermal one: latent + surface geopotential + kinetic
# (batch, time, level, lat, lon)
E_qgk = LH_WATER * q_batch_correct + GPH_surf_batch + ken

# TOA net energy flux at t1 (batch, lat, lon) and its global sum (batch,)
R_T = ((TOA_net_batch + OLR_batch) / N_seconds)[:, 1]
R_T_sum = physics_core.weighted_sum(R_T, axis=(-2, -1))

# surface net energy flux at t1 (batch, lat, lon) and its global sum (batch,)
F_S = ((R_short_batch + R_long_batch + LH_batch + SH_batch) / N_seconds)[:, 1]
F_S_sum = physics_core.weighted_sum(F_S, axis=(-2, -1))

correction_cycle_num = 1

for i in range(correction_cycle_num):

    # layer-wise atmospheric energy: thermal term plus everything in E_qgk
    # (batch, time, level, lat, lon)
    E_level = C_p * T_batch_correct + E_qgk

    # total atmospheric energy (TE) of an air column, (batch, time, lat, lon)
    TE = physics_core.integral(E_level) / GRAVITY
    TE_t0 = TE[:, 0]
    TE_t1 = TE[:, 1]

    # ---------------------------------------------------------------------------- #
    # TE tendency (batch, lat, lon) and its global sum (batch,)
    dTE_dt = (TE_t1 - TE_t0) / N_seconds
    dTE_sum = physics_core.weighted_sum(dTE_dt, axis=(-2, -1))
    # residual: net flux into the atmosphere minus the TE tendency (batch,)
    delta_dTE_sum = (R_T_sum - F_S_sum) - dTE_sum
    print(f'Residual to conserve energy budget [Watts]: {delta_dTE_sum}')
    # ---------------------------------------------------------------------------- #

    # globally summed TE at t0 and t1 (batch,)
    total_weighted_TE_t0 = physics_core.weighted_sum(TE_t0, axis=(-2, -1))
    total_weighted_TE_t1 = physics_core.weighted_sum(TE_t1, axis=(-2, -1))

    # ratio between the TE that closes the budget and the TE found at t1
    E_correct_ratio = (N_seconds * (R_T_sum - F_S_sum) + total_weighted_TE_t0) / total_weighted_TE_t1
    # (batch,) --> (batch, 1, 1, 1) so the ratio broadcasts over (level, lat, lon)
    E_correct_ratio = E_correct_ratio[:, None, None, None]

    # scale the layer-wise energy at t1 by the same ratio everywhere, (batch, level, lat, lon)
    E_t1_correct = E_level[:, 1] * E_correct_ratio

    # barotropic correction of T at t1: invert E = C_p * T + E_qgk
    T_batch_correct[:, 1] = (E_t1_correct - E_qgk[:, 1]) / C_p[:, 1]

Residual to conserve energy budget [Watts]: tensor([-2.1879e+15, -1.1417e+16, -1.1753e+15, -2.4352e+15, -3.2384e+15,
        -9.8366e+15,  9.0865e+14, -5.5544e+15,  5.2073e+14, -8.9519e+15,
         2.0783e+15, -5.0938e+15, -1.5662e+14, -5.8486e+15,  2.4310e+15,
        -3.4802e+15,  8.4460e+14, -1.1415e+16,  4.7944e+15,  1.1075e+15,
         1.1121e+14, -1.2772e+16,  1.1447e+15, -5.1273e+15, -2.2700e+15,
        -1.6989e+16,  1.7992e+14, -5.3378e+15, -3.7386e+15, -1.5961e+16,
        -2.1321e+14, -4.7991e+15], dtype=torch.float64)


In [17]:
# ---------------------------------------------------------------------------- #
# Recompute the energy budget residual with the corrected temperature.
E_level = C_p * T_batch_correct + E_qgk
TE = physics_core.integral(E_level) / GRAVITY
dTE_dt = (TE[:, 1, :, :] - TE[:, 0, :, :]) / N_seconds
dTE_sum = physics_core.weighted_sum(dTE_dt, axis=(-2, -1), keepdims=False)
# NOTE: sign convention here is tendency minus net flux, the opposite of the
# residual printed inside the correction loop (net flux minus tendency).
energy_residual = dTE_sum - (R_T_sum - F_S_sum)
print('Residual to conserve energy budget [Watts]: {}'.format(energy_residual))
# ---------------------------------------------------------------------------- #

Residual to conserve energy budget [Watts]: tensor([-2.1642e+09, -9.9808e+08, -4.5335e+09,  1.7122e+09,  2.0903e+09,
         2.6986e+08, -8.8201e+09, -1.1782e+08, -1.2469e+10, -1.8253e+09,
         8.5489e+08, -1.4252e+08,  1.7130e+10, -4.8771e+07, -1.9931e+09,
         7.0583e+08, -1.3315e+09, -1.7214e+08, -4.7963e+08,  3.3855e+09,
        -1.3779e+10, -1.9998e+07,  5.8746e+09,  4.1342e+08, -2.5294e+09,
         9.9548e+08, -1.4509e+10, -2.1733e+09, -2.3611e+08,  6.2275e+08,
        -2.3722e+10,  1.4820e+09], dtype=torch.float64)
