# Global-scale atmospheric mass and energy conservations on ERA5 pressure level data

In [1]:
import torch
import numpy as np
import xarray as xr

In [2]:
from typing import Dict, Any, Optional

### Load data

ERA5 pressure level data:

* `/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_base/upper_air/*.zarr`
* `/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_base/surf/*.zarr`
* `/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_base/accum/*.zarr`

In [3]:
base_dir = '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_1deg/'
filename = base_dir + 'all_in_one/ERA5_plevel_1deg_6h_1993_bilinear.zarr'

ds_surf = xr.open_zarr(filename)
ds_accum = xr.open_zarr(filename)
ds_upper = xr.open_zarr(filename)
ds_static = xr.open_zarr(base_dir + 'static/ERA5_plevel_1deg_6h_static.zarr')


# base_dir = '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_base/'
# ds_surf = xr.open_zarr(base_dir + 'surf/ERA5_plevel_6h_surf_1979.zarr')
# ds_accum = xr.open_zarr(base_dir + 'accum/ERA5_plevel_6h_accum_1979.zarr')
# # ds_upper = xr.open_zarr(base_dir + 'upper_air/ERA5_plevel_6h_upper_air_1979.zarr')
# ds_upper = xr.open_zarr(base_dir + 'upper_subset/ERA5_subset_6h_upper_air_1979.zarr')
# ds_static = xr.open_zarr(base_dir + 'static/ERA5_plevel_6h_static.zarr')

In [4]:
x = ds_surf['longitude']
y = ds_surf['latitude']

lon, lat = np.meshgrid(x, y)
level_p = 100*np.array(ds_upper['level'])
tensor_shape = (len(level_p),) + lon.shape
# level_p = 100*np.array([1, 50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000])

### Convert data to `torch.Tensor`

In [5]:
batch_size = 64
target_shape_4D = (batch_size, 2)+tensor_shape
target_shape_3D = (batch_size, 2,)+tensor_shape[1:]

t_slice = np.arange(batch_size+1)

In [6]:
# (batch, time, level, lat, lon) version
def time_series_to_batch(q, target_shape):
    q_batch = torch.as_strided(
        q, size=target_shape, 
        stride=(q.stride(0), q.stride(0), *q.stride()[1:]))
    return q_batch

In [7]:
q = torch.from_numpy(np.array(ds_upper['Q'].isel(time=t_slice))) # kg/kg
T = torch.from_numpy(np.array(ds_upper['T'].isel(time=t_slice))) # kg/kg
u = torch.from_numpy(np.array(ds_upper['U'].isel(time=t_slice))) # m/s
v = torch.from_numpy(np.array(ds_upper['V'].isel(time=t_slice)))
precip = torch.from_numpy(np.array(ds_accum['total_precipitation'].isel(time=t_slice)))
evapor = torch.from_numpy(np.array(ds_accum['evaporation'].isel(time=t_slice)))

GPH_surf = torch.from_numpy(np.array(ds_static['geopotential_at_surface'])) # J/m2
TOA_net = torch.from_numpy(np.array(ds_accum['top_net_solar_radiation'].isel(time=t_slice))) # J/m2
OLR = torch.from_numpy(np.array(ds_accum['top_net_thermal_radiation'].isel(time=t_slice))) # J/m2
R_short = torch.from_numpy(np.array(ds_accum['surface_net_solar_radiation'].isel(time=t_slice))) # J/m2
R_long = torch.from_numpy(np.array(ds_accum['surface_net_thermal_radiation'].isel(time=t_slice))) # J/m2
LH = torch.from_numpy(np.array(ds_accum['surface_latent_heat_flux'].isel(time=t_slice))) # J/m2
SH = torch.from_numpy(np.array(ds_accum['surface_sensible_heat_flux'].isel(time=t_slice))) # J/m2

In [8]:
q_batch = time_series_to_batch(q, target_shape_4D).permute(0, 2, 1, 3, 4)
T_batch = time_series_to_batch(T, target_shape_4D).permute(0, 2, 1, 3, 4)
u_batch = time_series_to_batch(u, target_shape_4D).permute(0, 2, 1, 3, 4)
v_batch = time_series_to_batch(v, target_shape_4D).permute(0, 2, 1, 3, 4)
precip_batch = time_series_to_batch(precip, target_shape_3D)
evapor_batch = time_series_to_batch(evapor, target_shape_3D)

GPH_surf_batch = GPH_surf.unsqueeze(0).unsqueeze(0).unsqueeze(0)
TOA_net_batch = time_series_to_batch(TOA_net, target_shape_3D)
OLR_batch = time_series_to_batch(OLR, target_shape_3D)
R_short_batch = time_series_to_batch(R_short, target_shape_3D)
R_long_batch = time_series_to_batch(R_long, target_shape_3D)
LH_batch = time_series_to_batch(LH, target_shape_3D)
SH_batch = time_series_to_batch(SH, target_shape_3D)

In [9]:
longitude = torch.from_numpy(lon)
latitude = torch.from_numpy(lat)
upper_air_pressure = torch.from_numpy(level_p)

### `credit.physics_core` pressure level class

In [10]:
GRAVITY = 9.80665
RHO_WATER = 1000.0 # kg/m^3
RAD_EARTH = 6371000 # m
LH_WATER = 2.26e6  # J/kg
CP_DRY = 1005 # J/kg K
CP_VAPOR = 1846 # J/kg K

class physics_pressure_level:
    '''
    Pressure level physics

    Attributes:
        upper_air_pressure (torch.Tensor): pressure levels in Pa.
        lon (torch.Tensor): longitude in degrees.
        lat (torch.Tensor): latitude in degrees.
        pressure_thickness (torch.Tensor): pressure thickness between levels.
        dx, dy (torch.Tensor): grid spacings in longitude and latitude.
        area (torch.Tensor): area of grid cells.
        integral (function): vertical integration method (midpoint or trapezoidal).
    '''
    
    def __init__(self,
                 lon: torch.Tensor,
                 lat: torch.Tensor,
                 upper_air_pressure: torch.Tensor,
                 midpoint: bool = False):
        '''
        Initialize the class with longitude, latitude, and pressure levels.

        All inputs must be in the same torch device.

        Full order of dimensions:  (batch, time, level, latitude, longitude)
        
        Args:
            lon (torch.Tensor): Longitude in degrees.
            lat (torch.Tensor): Latitude in degrees.
            upper_air_pressure (torch.Tensor): Pressure levels in Pa.
            midpoint (bool): True if vertical level quantities are midpoint values
                      otherwise False
            
        '''
        self.lon = lon
        self.lat = lat
        self.upper_air_pressure = upper_air_pressure
        
        # ========================================================================= #
        # compute pressure level thickness
        self.pressure_thickness = self.upper_air_pressure.diff(dim=-1)
        
        # # ========================================================================= #
        # # compute grid spacings
        # lat_rad = torch.deg2rad(self.lat)
        # lon_rad = torch.deg2rad(self.lon)
        # self.dy = torch.gradient(lat_rad * RAD_EARTH, dim=0)[0]
        # self.dx = torch.gradient(lon_rad * RAD_EARTH, dim=1)[0] * torch.cos(lat_rad)

        # ========================================================================= #
        # compute gtid area
        # area = R^2 * d_sin(lat) * d_lon
        lat_rad = torch.deg2rad(self.lat)
        lon_rad = torch.deg2rad(self.lon)
        sin_lat_rad = torch.sin(lat_rad)
        d_phi = torch.gradient(sin_lat_rad, dim=0, edge_order=2)[0]
        d_lambda = torch.gradient(lon_rad, dim=1, edge_order=2)[0]
        d_lambda = (d_lambda + torch.pi) % (2 * torch.pi) - torch.pi
        self.area = torch.abs(RAD_EARTH**2 * d_phi * d_lambda)
        
        # ========================================================================== #
        # vertical integration method
        if midpoint:
            self.integral = self.pressure_integral_midpoint
            self.integral_sliced = self.pressure_integral_midpoint_sliced
        else:
            self.integral = self.pressure_integral_trapz
            self.integral_sliced = self.pressure_integral_trapz_sliced
            
    def pressure_integral_midpoint(self, q_mid: torch.Tensor) -> torch.Tensor:
        '''
        Compute the pressure level integral of a given quantity; assuming its mid point
        values are pre-computed
        
        Args:
            q_mid: the quantity with dims of (batch_size, time, level-1, latitude, longitude)
    
        Returns:
            Pressure level integrals of q
        '''
        num_dims = len(q_mid.shape)
        
        if num_dims == 5:  # (batch_size, level, time, latitude, longitude)
            delta_p = self.pressure_thickness.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            q_area = q_mid * delta_p
            q_trapz = torch.sum(q_area, dim=1)
        
        elif num_dims == 4:  # (batch_size, level, latitude, longitude) or (time, level, latitude, longitude)
            delta_p = self.pressure_thickness.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            q_area = q_mid * delta_p
            q_trapz = torch.sum(q_area, dim=1)
        
        elif num_dims == 3:  # (level, latitude, longitude)
            delta_p = self.pressure_thickness.unsqueeze(-1).unsqueeze(-1)  # Expand for broadcasting
            q_area = q_mid * delta_p
            q_trapz = torch.sum(q_area, dim=0)
        
        else:
            raise ValueError(f"Unsupported tensor dimensions: {q.shape}")
        
        return q_trapz

    def pressure_integral_midpoint_sliced(self, 
                                          q_mid: torch.Tensor,                                        
                                          ind_start: int,
                                          ind_end: int) -> torch.Tensor:
        '''
        As in `pressure_integral_midpoint`, but supports pressure level indexing,
        so it can calculate integrals of a subset of levels
        '''
        num_dims = len(q_mid.shape)

        delta_p = self.pressure_thickness[ind_start:ind_end]
        
        if num_dims == 5:  # (batch_size, time, level, latitude, longitude)
            delta_p = delta_p.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            q_mid = q_mid[:, ind_start:ind_end, ...]
            q_area = q_mid * delta_p
            q_trapz = torch.sum(q_area, dim=1)
        
        elif num_dims == 4:  # (batch_size, level, latitude, longitude) or (time, level, latitude, longitude)
            delta_p = delta_p.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            q_mid = q_mid[:, ind_start:ind_end, ...]
            q_area = q_mid * delta_p  # Trapezoidal rule
            q_trapz = torch.sum(q_area, dim=1)
        
        elif num_dims == 3:  # (level, latitude, longitude)
            delta_p = delta_p.unsqueeze(-1).unsqueeze(-1)  # Expand for broadcasting
            q_mid = q_mid[ind_start:ind_end, ...]
            q_area = q_mid * delta_p
            q_trapz = torch.sum(q_area, dim=0)
        
        else:
            raise ValueError(f"Unsupported tensor dimensions: {q.shape}")
        
        return q_trapz
        
    def pressure_integral_trapz(self, q: torch.Tensor) -> torch.Tensor:
        '''
        Compute the pressure level integral of a given quantity using the trapezoidal rule.
        
        Args:
            q: the quantity with dims of (batch_size, time, level, latitude, longitude)
    
        Returns:
            Pressure level integrals of q
        '''
        num_dims = len(q.shape)
        
        if num_dims == 5:  # (batch_size, level, time, latitude, longitude)
            delta_p = self.pressure_thickness.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            q_area = 0.5 * (q[:, :-1, :, :, :] + q[:, 1:, :, :, :]) * delta_p
            q_trapz = torch.sum(q_area, dim=1)
        
        elif num_dims == 4:  # (batch_size, level, latitude, longitude) or (time, level, latitude, longitude)
            delta_p = self.pressure_thickness.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            q_area = 0.5 * (q[:, :-1, :, :] + q[:, 1:, :, :]) * delta_p  # Trapezoidal rule
            q_trapz = torch.sum(q_area, dim=1)
        
        elif num_dims == 3:  # (level, latitude, longitude)
            delta_p = self.pressure_thickness.unsqueeze(-1).unsqueeze(-1)  # Expand for broadcasting
            q_area = 0.5 * (q[:-1, :, :] + q[1:, :, :]) * delta_p
            q_trapz = torch.sum(q_area, dim=0)
        
        else:
            raise ValueError(f"Unsupported tensor dimensions: {q.shape}")
        
        return q_trapz

    def pressure_integral_trapz_sliced(self, 
                                       q: torch.Tensor,
                                       ind_start: int,
                                       ind_end: int) -> torch.Tensor:
        '''
        As in `pressure_integral_trapz`, but supports pressure level indexing,
        so it can calculate integrals of a subset of levels
        '''
        num_dims = len(q.shape)

        delta_p = self.upper_air_pressure[ind_start:ind_end].diff(dim=-1)
        
        if num_dims == 5:  # (batch_size, level, time, latitude, longitude)
            delta_p = delta_p.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            q_slice = q[:, ind_start:ind_end, ...]
            q_area = 0.5 * (q_slice[:, :-1, :, :, :] + q_slice[:, 1:, :, :, :]) * delta_p
            q_trapz = torch.sum(q_area, dim=1)
        
        elif num_dims == 4:  # (batch_size, level, latitude, longitude) or (time, level, latitude, longitude)
            delta_p = delta_p.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            q_slice = q[:, ind_start:ind_end, ...]
            q_area = 0.5 * (q_slice[:, :-1, :, :] + q_slice[:, 1:, :, :]) * delta_p  # Trapezoidal rule
            q_trapz = torch.sum(q_area, dim=1)
        
        elif num_dims == 3:  # (level, latitude, longitude)
            delta_p = delta_p.unsqueeze(-1).unsqueeze(-1)  # Expand for broadcasting
            q_slice = q[ind_start:ind_end, ...]
            q_area = 0.5 * (q_slice[:-1, :, :] + q_slice[1:, :, :]) * delta_p
            q_trapz = torch.sum(q_area, dim=0)
        
        else:
            raise ValueError(f"Unsupported tensor dimensions: {q.shape}")
        
        return q_trapz
    

    def weighted_sum(self,
                     q: torch.Tensor, 
                     axis: Dict[tuple, None] = None, 
                     keepdims: bool = False) -> torch.Tensor:
        '''
        Compute the weighted sum of a given quantity for PyTorch tensors.
        
        Args:
            data: the quantity to be summed (PyTorch tensor)
            axis: dims to compute the sum (can be int or tuple of ints)
            keepdims: whether to keep the reduced dimensions or not
    
        Returns:
            Weighted sum (PyTorch tensor)
        '''
        q_w = q * self.area
        q_sum = torch.sum(q_w, dim=axis, keepdim=keepdims)
        return q_sum

    def total_dry_air_mass(self, 
                           q: torch.Tensor) -> torch.Tensor:
        '''
        Compute the total mass of dry air over the entire globe [kg]
        '''
        mass_dry_per_area = self.integral(1-q) / GRAVITY # kg/m^2
        # weighted sum on latitude and longitude dimensions
        mass_dry_sum = self.weighted_sum(mass_dry_per_area, axis=(-2, -1)) # kg
        
        return mass_dry_sum

    def total_column_water(self, 
                           q: torch.Tensor) -> torch.Tensor:
        '''
        Compute total column water (TCW) per air column [kg/m2]
        '''
        TWC = self.integral(q) / GRAVITY # kg/m^2
        
        return TWC

In [11]:
physics_core = physics_pressure_level(longitude, latitude, upper_air_pressure, midpoint=False)

## Conservation of total dry air mass

In [12]:
ind_fix = 24
N_levels = len(upper_air_pressure)

In [13]:
q_batch_correct = q_batch.clone()

mass_dry_per_area_hold = physics_core.integral_sliced(1-q_batch_correct, 0, ind_fix) / GRAVITY
mass_dry_sum_hold = physics_core.weighted_sum(mass_dry_per_area_hold, axis=(-2, -1))

mass_dry_per_area_fix = physics_core.integral_sliced(1-q_batch_correct, ind_fix-1, N_levels) / GRAVITY
mass_dry_sum_fix = physics_core.weighted_sum(mass_dry_per_area_fix, axis=(-2, -1))

mass_dry_sum = mass_dry_sum_hold + mass_dry_sum_fix
# ------------------------------------------------------------------------------ #
# check residual term
mass_dry_res = mass_dry_sum[:, 1] - mass_dry_sum[:, 0]
print('Residual to conserve the dry air mass [kg]: {}'.format(mass_dry_res))
# ------------------------------------------------------------------------------ #

#mass_dry_sum = physics_core.total_dry_air_mass(q_batch_correct)

mass_residual_on_fix = mass_dry_sum[:, 0] - mass_dry_sum_hold[:, 1]

# Compute the ratio
q_correct_ratio = mass_residual_on_fix / mass_dry_sum_fix[:, 1]
q_correct_ratio = q_correct_ratio.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

q_batch_correct[:, ind_fix-1:, 1, ...] = 1 - (1 - q_batch_correct[:, ind_fix-1:, 1, ...]) * q_correct_ratio

Residual to conserve the dry air mass [kg]: tensor([-1.1399e+13, -6.1822e+12, -2.2353e+13, -1.0393e+13, -1.6828e+13,
         5.1657e+11, -1.7665e+13, -3.6353e+12, -6.0755e+12,  9.3962e+12,
        -1.2338e+13, -3.5996e+13, -1.3081e+13, -1.5854e+13, -9.0417e+12,
         1.9265e+13, -2.7896e+13,  1.3063e+13, -2.4751e+13,  5.3644e+12,
        -9.9773e+12,  1.6798e+13,  2.8719e+12,  7.7928e+12,  1.5677e+13,
         4.3084e+13,  1.6181e+13, -5.7406e+12,  2.9898e+13,  3.4587e+13,
         6.4398e+11, -8.1429e+12,  1.7020e+13,  2.4333e+13, -5.4689e+12,
         1.3990e+13,  3.6388e+13,  4.6260e+13,  5.1963e+12, -7.7809e+12,
         1.7242e+12,  2.8496e+13, -1.3586e+12,  1.0690e+13,  1.5660e+13,
         3.6193e+13,  1.0461e+11, -1.6270e+13,  1.4462e+13,  4.7283e+13,
         1.2494e+13,  6.5931e+12,  5.9314e+12,  2.6126e+13, -9.1448e+12,
         1.3154e+13, -9.1608e+11,  1.3272e+13, -1.7012e+13,  2.2345e+12,
        -1.6404e+13,  1.6785e+12, -8.8135e+12,  1.8329e+13],
       dtype=torch.

In [14]:
# ------------------------------------------------------------------------------ #
mass_dry_sum = physics_core.total_dry_air_mass(q_batch_correct)
mass_dry_res = mass_dry_sum[:, 1] - mass_dry_sum[:, 0]
print('Residual to conserve the dry air mass [kg]: {}'.format(mass_dry_res))
# ------------------------------------------------------------------------------ #

Residual to conserve the dry air mass [kg]: tensor([ 6.9509e+11,  4.0885e+11,  1.4078e+12,  6.4591e+11,  1.0496e+12,
        -1.0420e+10,  1.0934e+12,  2.1636e+11,  3.5707e+11, -5.8398e+11,
         7.8352e+11,  2.2627e+12,  8.0991e+11,  9.6994e+11,  5.5208e+11,
        -1.2107e+12,  1.7561e+12, -8.3651e+11,  1.5445e+12, -2.8516e+11,
         6.2872e+11, -1.0571e+12, -1.7373e+11, -4.3924e+11, -1.0149e+12,
        -2.7005e+12, -1.0249e+12,  3.4124e+11, -1.8657e+12, -2.1568e+12,
         8.4724e+09,  5.2184e+11, -1.0760e+12, -1.5072e+12,  3.2792e+11,
        -8.7736e+11, -2.2855e+12, -2.9114e+12, -3.3312e+11,  4.6299e+11,
        -1.2021e+11, -1.7943e+12,  9.0686e+10, -6.7237e+11, -9.9338e+11,
        -2.2831e+12, -1.9362e+10,  1.0396e+12, -8.9415e+11, -2.9445e+12,
        -7.7169e+11, -4.0491e+11, -3.8082e+11, -1.6561e+12,  5.7175e+11,
        -8.1994e+11, -1.7098e+10, -8.4203e+11,  1.0766e+12, -1.3691e+11,
         1.0470e+12, -1.2581e+11,  5.5494e+11, -1.1431e+12],
       dtype=torch.

**Old**

In [15]:
# q_batch_correct = q_batch.clone()

# correction_cycle_num = 1 # iterative to handle numrical precision

# for i in range(correction_cycle_num):
#     mass_dry_sum = physics_core.total_dry_air_mass(q_batch_correct)
    
#     # ------------------------------------------------------------------------------ #
#     # check residual term
#     mass_dry_res = mass_dry_sum[:, 1] - mass_dry_sum[:, 0]
#     print('Residual to conserve the dry air mass [kg]: {}'.format(mass_dry_res))
#     # ------------------------------------------------------------------------------ #
    
#     q_correct_ratio = mass_dry_sum[:, 0] / mass_dry_sum[:, 1]
#     q_correct_ratio = q_correct_ratio.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
#     q_batch_correct[:, :, 1, ...] = 1 - (1 - q_batch_correct[:, :, 1, ...]) * q_correct_ratio

In [16]:
# ------------------------------------------------------------------------------ #
mass_dry_sum = physics_core.total_dry_air_mass(q_batch_correct)
# mass_dry_res = mass_dry_sum[:, 1] - mass_dry_sum[:, 0]
print('Residual to conserve the dry air mass [kg]: {}'.format(mass_dry_res))
# ------------------------------------------------------------------------------ #

Residual to conserve the dry air mass [kg]: tensor([ 6.9509e+11,  4.0885e+11,  1.4078e+12,  6.4591e+11,  1.0496e+12,
        -1.0420e+10,  1.0934e+12,  2.1636e+11,  3.5707e+11, -5.8398e+11,
         7.8352e+11,  2.2627e+12,  8.0991e+11,  9.6994e+11,  5.5208e+11,
        -1.2107e+12,  1.7561e+12, -8.3651e+11,  1.5445e+12, -2.8516e+11,
         6.2872e+11, -1.0571e+12, -1.7373e+11, -4.3924e+11, -1.0149e+12,
        -2.7005e+12, -1.0249e+12,  3.4124e+11, -1.8657e+12, -2.1568e+12,
         8.4724e+09,  5.2184e+11, -1.0760e+12, -1.5072e+12,  3.2792e+11,
        -8.7736e+11, -2.2855e+12, -2.9114e+12, -3.3312e+11,  4.6299e+11,
        -1.2021e+11, -1.7943e+12,  9.0686e+10, -6.7237e+11, -9.9338e+11,
        -2.2831e+12, -1.9362e+10,  1.0396e+12, -8.9415e+11, -2.9445e+12,
        -7.7169e+11, -4.0491e+11, -3.8082e+11, -1.6561e+12,  5.7175e+11,
        -8.1994e+11, -1.7098e+10, -8.4203e+11,  1.0766e+12, -1.3691e+11,
         1.0470e+12, -1.2581e+11,  5.5494e+11, -1.1431e+12],
       dtype=torch.

## Conservation of moisture

In [17]:
N_seconds = 3600 * 6 # 6 hourly data

precip_batch_flux = precip_batch[:, 1, ...] * RHO_WATER / N_seconds # m/hour --> kg/m^2/s, positive
evapor_batch_flux = evapor_batch[:, 1, ...] * RHO_WATER / N_seconds # kg/m^2/s, negative

precip_batch_correct = precip_batch_flux.clone()

# pre-compute TWC
TWC = physics_core.total_column_water(q_batch_correct)
dTWC_dt = (TWC[:, 1, ...] - TWC[:, 0, ...]) / N_seconds # kg/m^2/s
TWC_sum = physics_core.weighted_sum(dTWC_dt, axis=(-2, -1)) # kg/s

# pre-compute evaporation
E_sum = physics_core.weighted_sum(evapor_batch_flux, axis=(-2, -1)) # kg/s

correction_cycle_num = 1

for i in range(correction_cycle_num):
    P_sum = physics_core.weighted_sum(precip_batch_correct, axis=(-2, -1)) # kg/s
    residual = -TWC_sum - E_sum - P_sum # kg/s

    # ------------------------------------------------------------------------------ #
    print('Residual to conserve moisture budge [kg/s]: {}'.format(residual))
    # ------------------------------------------------------------------------------ #
    
    # P_correct = P_sum + residual # kg/s
    P_correct_ratio = (P_sum + residual) / P_sum
    P_correct_ratio = P_correct_ratio.unsqueeze(-1).unsqueeze(-1)
    precip_batch_correct = precip_batch_correct * P_correct_ratio

Residual to conserve moisture budge [kg/s]: tensor([-1.5168e+08,  8.1725e+08,  7.8938e+07, -4.6144e+08,  1.6145e+08,
         4.0507e+08,  2.4236e+07, -2.6004e+08,  6.1994e+07,  6.2030e+08,
        -1.0645e+08, -4.9454e+08,  6.1004e+08,  5.5018e+08, -6.7023e+08,
        -5.4988e+08,  5.8153e+08,  1.1151e+09,  4.1976e+08, -3.8566e+08,
         5.0965e+08,  4.2488e+07, -6.1936e+08, -1.1281e+09, -5.8318e+08,
        -9.2500e+08, -1.4729e+09, -1.7925e+09, -1.1098e+09, -7.3493e+08,
        -7.4473e+08, -1.0190e+09, -6.5689e+08, -4.1555e+08, -5.8804e+08,
        -1.2679e+09, -1.5663e+09, -9.3553e+08, -9.8787e+08, -1.0895e+09,
        -3.1924e+08, -3.3563e+08, -7.7815e+08, -1.1565e+09, -9.5750e+08,
        -9.4542e+08, -9.0858e+08, -7.8964e+08, -8.9020e+08, -1.2442e+09,
        -1.2491e+09, -1.2804e+09, -7.2070e+08, -4.3521e+08, -4.2332e+08,
        -6.4981e+08, -8.3582e+07,  3.3719e+08, -7.9500e+07, -5.8825e+08,
         2.7331e+07,  5.3509e+08, -3.8185e+08, -1.5456e+09],
       dtype=torch.

In [18]:
# ------------------------------------------------------------------------------ #
P_sum = physics_core.weighted_sum(precip_batch_correct, axis=(-2, -1)) # kg/s
residual = -TWC_sum - E_sum - P_sum # kg/s
print('Residual to conserve moisture budge [kg/s]: {}'.format(residual))
# ------------------------------------------------------------------------------ #

Residual to conserve moisture budge [kg/s]: tensor([ -874.7156,   -90.9392,  -501.8408,   435.4669,   336.4176,   636.4086,
         -732.4830,   575.5671,   450.8901, -1179.6511,   502.2926,  -927.9430,
          646.7278,  -342.5582,   820.5086,    88.2951,  -206.5913, -1568.8612,
          -95.0119,   627.5631,   152.7741,  -497.4650,   122.2431,   514.7252,
          334.7354, -1137.4798,  1284.5080,   123.6681,   414.2711,   398.7937,
          476.8464,  -622.6649,   332.2394, -1204.9807,   -74.0208,   407.3359,
         -982.1229,  -354.3493, -1055.6186,  -769.4539,  -465.1213,  -606.3862,
         -498.8248,   403.5292,   838.3971,   228.1954,  -279.3270,  -198.4896,
          423.5581,  -195.8684,  -393.9427,  -230.5711,  -753.5994,   310.2205,
         -509.0625,   368.2702,  -429.9662,   321.7530,   825.5231,   -89.6560,
          129.3203,   634.6681,   337.3184,  -707.9703], dtype=torch.float64)


In [19]:
(precip_batch_correct - precip_batch_flux).mean()

tensor(-7.4016e-07, dtype=torch.float64)

In [20]:
(precip_batch_correct - precip_batch_flux).max()

tensor(0.0003, dtype=torch.float64)

### Conservation of energy

In [21]:
N_seconds = 3600 * 6 # 6 hourly data

# C_p (batch, time, level, lat, lon)
C_p = (1 - q_batch_correct) * CP_DRY + q_batch_correct * CP_VAPOR
# kinetic energy (batch, time, level, lat, lon)
ken = 0.5 * (u_batch ** 2 + v_batch ** 2)

# initialize T_correct
T_batch_correct = T_batch.clone()

# layer-wise atmospheric energy, but without thermal energy 
# (batch, time, level, lat, lon)
E_qgk = LH_WATER * q_batch_correct + GPH_surf_batch + ken

# TOA net energy flux (batch, time, lat, lon)
R_T = (TOA_net_batch + OLR_batch) / N_seconds
R_T = R_T[:, 1, :, :]
# R_T global sum
R_T_sum = physics_core.weighted_sum(R_T, axis=(-2, -1))

# surface net energy flux (batch, time, lat, lon)
F_S = (R_short_batch + R_long_batch + LH_batch + SH_batch) / N_seconds
F_S = F_S[:, 1, :, :]  # Extract time index 1
# F_S global sum
F_S_sum = physics_core.weighted_sum(F_S, axis=(-2, -1))

correction_cycle_num = 1

for i in range(correction_cycle_num):

    # layer-wise atmospheric energy (sensible heat + others)
    #  (batch, time, level, lat, lon)
    E_level = C_p * T_batch_correct + E_qgk

    # total atmospheric energy (TE) of an air column
    # (batch, time, lat, lon)
    TE = physics_core.integral(E_level) / GRAVITY

    # ---------------------------------------------------------------------------- #
    # tendency of TE (batch, lat, lon)
    dTE_dt = (TE[:, 1, :, :] - TE[:, 0, :, :]) / N_seconds
    # global sum of TE tendency (batch,)
    dTE_sum = physics_core.weighted_sum(dTE_dt, axis=(1, 2), keepdims=False)
    # compute the residual (batch,)
    delta_dTE_sum = (R_T_sum - F_S_sum) - dTE_sum
    print('Residual to conserve energy budget [Watts]: {}'.format(delta_dTE_sum))
    print('Sources & sinks [Watts]: {}'.format(R_T_sum - F_S_sum))
    print('Tendency [Watts]: {}'.format(dTE_sum))
    # ---------------------------------------------------------------------------- #

    # TE at t0 and t1 (batch,)
    total_weighted_TE_t0 = physics_core.weighted_sum(TE[:, 0, :, :], axis=(-2, -1)) 
    total_weighted_TE_t1 = physics_core.weighted_sum(TE[:, 1, :, :], axis=(-2, -1))

    # calculate the correction ratio for E_t1 (batch,) --> (batch, 1, 1, 1)
    E_correct_ratio = (N_seconds * (R_T_sum - F_S_sum) + total_weighted_TE_t0) / total_weighted_TE_t1
    E_correct_ratio = E_correct_ratio.view(-1, 1, 1, 1)

    # Apply the correction to layer-wise atmospheric energy at t1
    # (batch, level, lat, lon)
    E_t1_correct = E_level[:, :, 1, :, :] * E_correct_ratio

    # barotropic correction of T at t1
    T_batch_correct[:, :, 1, :, :] = (E_t1_correct - E_qgk[:, :, 1, :, :]) / C_p[:, :, 1, :, :]

Residual to conserve energy budget [Watts]: tensor([-7.5009e+14, -1.1777e+16,  2.6745e+15, -4.6578e+15,  1.0571e+15,
        -1.4666e+16,  3.1927e+15, -1.6726e+15, -1.2179e+14, -1.0903e+16,
         2.6925e+15, -4.3855e+15,  1.5747e+15, -1.6774e+16,  2.2277e+15,
        -5.8720e+15,  2.8603e+15, -9.0812e+15,  5.8271e+15, -3.2539e+15,
         1.9158e+15, -9.1749e+15,  2.0047e+15, -4.2040e+15, -1.0004e+15,
        -1.7374e+16, -1.2836e+15, -7.0447e+15, -2.9478e+15, -1.4620e+16,
         1.7221e+15, -5.4929e+15, -2.7841e+15, -1.2605e+16,  2.0757e+15,
        -3.8119e+15, -6.0479e+15, -1.5652e+16,  2.3484e+14, -7.9958e+15,
        -1.9099e+15, -1.2560e+16,  1.6726e+15, -7.9714e+15, -3.2428e+15,
        -1.5958e+16,  1.2185e+15, -7.2020e+15, -2.1522e+15, -1.6731e+16,
        -1.2706e+15, -1.1454e+16, -2.0984e+15, -1.8182e+16,  1.0513e+15,
        -6.9333e+15, -1.3373e+15, -1.3688e+16,  1.7749e+15, -8.2970e+15,
        -2.1178e+13, -1.5914e+16,  6.6188e+14, -8.0890e+15],
       dtype=torch.

In [22]:
# ---------------------------------------------------------------------------- #
E_level = C_p * T_batch_correct + E_qgk
TE = physics_core.integral(E_level) / GRAVITY
dTE_dt = (TE[:, 1, :, :] - TE[:, 0, :, :]) / N_seconds
dTE_sum = physics_core.weighted_sum(dTE_dt, axis=(-2, -1), keepdims=False)
energy_residual = dTE_sum - (R_T_sum - F_S_sum)
print('Residual to conserve energy budget [Watts]: {}'.format(energy_residual))
# ---------------------------------------------------------------------------- #

Residual to conserve energy budget [Watts]: tensor([-1.0645e+10,  1.4958e+09,  7.9249e+08, -4.9176e+09,  5.3336e+09,
        -2.3437e+09,  1.3843e+09,  4.0844e+09,  1.3450e+10, -8.2858e+08,
        -1.4026e+09,  3.4404e+09, -2.1101e+09,  1.6056e+09, -4.6264e+09,
        -3.9229e+09, -1.1526e+09, -1.9281e+08,  4.8714e+09,  1.9771e+09,
        -2.9832e+09,  2.2839e+09, -4.6357e+08,  8.2582e+08, -2.8786e+09,
         6.9375e+09,  5.0283e+09,  9.5705e+08,  9.0029e+07, -5.1233e+08,
        -1.3577e+09, -1.3430e+09, -1.1327e+09,  3.7176e+09,  1.5401e+09,
         9.1192e+09,  2.9654e+09, -1.7639e+08,  1.6451e+10,  1.0168e+09,
         4.7398e+09,  2.2298e+09, -1.3036e+09,  3.5192e+09,  1.0512e+09,
        -6.6120e+08,  1.7799e+09, -4.2629e+09, -1.2831e+09,  4.7185e+09,
         2.6793e+09, -4.3820e+09,  2.9694e+09, -2.1660e+09,  4.1275e+09,
        -2.8971e+09,  1.1264e+09,  7.6727e+09,  7.7292e+08, -1.1620e+09,
        -1.6591e+10, -1.5810e+08,  1.1713e+10, -4.8974e+09],
       dtype=torch.

In [28]:
(T_batch_correct - T_batch).mean()

tensor(-0.0076)

In [29]:
(T_batch_correct - T_batch).abs().max()

tensor(0.0970)

In [30]:
R_T_sum.dtype

torch.float32