# Implement conservation schemes in Pytorch

This notebook moves numpy-based conservation schemes (in other notebooks) to Pytorch.

In [1]:
import torch
import numpy as np
import xarray as xr

In [2]:
from typing import Dict, Any, Optional

### Load data

ERA5 pressure level data:

* `/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_1deg/all_in_one/ERA5_plevel_1deg_6h_1993_bilinear.zarr`
* `/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_1deg/static/ERA5_plevel_1deg_6h_static.zarr`

In [6]:
base_dir = '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_1deg/'
filename = base_dir + 'all_in_one/ERA5_plevel_1deg_6h_1993_conserve.zarr'

ds_surf = xr.open_zarr(filename)
ds_accum = xr.open_zarr(filename)
ds_upper = xr.open_zarr(filename)
ds_static = xr.open_zarr(base_dir + 'static/ERA5_plevel_1deg_6h_conserve_static.zarr')


# base_dir = '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_base/'
# ds_surf = xr.open_zarr(base_dir + 'surf/ERA5_plevel_6h_surf_1979.zarr')
# ds_accum = xr.open_zarr(base_dir + 'accum/ERA5_plevel_6h_accum_1979.zarr')
# # ds_upper = xr.open_zarr(base_dir + 'upper_air/ERA5_plevel_6h_upper_air_1979.zarr')
# ds_upper = xr.open_zarr(base_dir + 'upper_subset/ERA5_subset_6h_upper_air_1979.zarr')
# ds_static = xr.open_zarr(base_dir + 'static/ERA5_plevel_6h_static.zarr')

In [7]:
x = ds_surf['longitude']
y = ds_surf['latitude']

lon, lat = np.meshgrid(x, y)
level_p = 100*np.array(ds_upper['level'])
tensor_shape = (len(level_p),) + lon.shape
# level_p = 100*np.array([1, 50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000])

### Convert data to `torch.Tensor`

In [8]:
batch_size = 64
target_shape_4D = (batch_size, 2)+tensor_shape
target_shape_3D = (batch_size, 2,)+tensor_shape[1:]

t_slice = np.arange(batch_size+1)

In [9]:
# (batch, time, level, lat, lon) version
def time_series_to_batch(q, target_shape):
    q_batch = torch.as_strided(
        q, size=target_shape, 
        stride=(q.stride(0), q.stride(0), *q.stride()[1:]))
    return q_batch

In [10]:
q = torch.from_numpy(np.array(ds_upper['Q'].isel(time=t_slice))) # kg/kg
T = torch.from_numpy(np.array(ds_upper['T'].isel(time=t_slice))) # kg/kg
u = torch.from_numpy(np.array(ds_upper['U'].isel(time=t_slice))) # m/s
v = torch.from_numpy(np.array(ds_upper['V'].isel(time=t_slice)))
precip = torch.from_numpy(np.array(ds_accum['total_precipitation'].isel(time=t_slice)))
evapor = torch.from_numpy(np.array(ds_accum['evaporation'].isel(time=t_slice)))

GPH_surf = torch.from_numpy(np.array(ds_static['geopotential_at_surface'])) # J/m2
TOA_net = torch.from_numpy(np.array(ds_accum['top_net_solar_radiation'].isel(time=t_slice))) # J/m2
OLR = torch.from_numpy(np.array(ds_accum['top_net_thermal_radiation'].isel(time=t_slice))) # J/m2
R_short = torch.from_numpy(np.array(ds_accum['surface_net_solar_radiation'].isel(time=t_slice))) # J/m2
R_long = torch.from_numpy(np.array(ds_accum['surface_net_thermal_radiation'].isel(time=t_slice))) # J/m2
LH = torch.from_numpy(np.array(ds_accum['surface_latent_heat_flux'].isel(time=t_slice))) # J/m2
SH = torch.from_numpy(np.array(ds_accum['surface_sensible_heat_flux'].isel(time=t_slice))) # J/m2

In [11]:
q_batch = time_series_to_batch(q, target_shape_4D).permute(0, 2, 1, 3, 4)
T_batch = time_series_to_batch(T, target_shape_4D).permute(0, 2, 1, 3, 4)
u_batch = time_series_to_batch(u, target_shape_4D).permute(0, 2, 1, 3, 4)
v_batch = time_series_to_batch(v, target_shape_4D).permute(0, 2, 1, 3, 4)
precip_batch = time_series_to_batch(precip, target_shape_3D)
evapor_batch = time_series_to_batch(evapor, target_shape_3D)

GPH_surf_batch = GPH_surf.unsqueeze(0).unsqueeze(0).unsqueeze(0)
TOA_net_batch = time_series_to_batch(TOA_net, target_shape_3D)
OLR_batch = time_series_to_batch(OLR, target_shape_3D)
R_short_batch = time_series_to_batch(R_short, target_shape_3D)
R_long_batch = time_series_to_batch(R_long, target_shape_3D)
LH_batch = time_series_to_batch(LH, target_shape_3D)
SH_batch = time_series_to_batch(SH, target_shape_3D)

In [12]:
longitude = torch.from_numpy(lon)
latitude = torch.from_numpy(lat)
upper_air_pressure = torch.from_numpy(level_p)

### `credit.physics_core` pressure level class

In [13]:
# Earth's radius
RAD_EARTH = 6371000 # m
# ideal gas constant of water vapor
RVGAS = 461.5 # J/kg/K
# ideal gas constant of dry air
RDGAS = 287.05 # J/kg/K
# gravity
GRAVITY = 9.80665 # m/s^2
# density of water
RHO_WATER = 1000.0 # kg/m^3
# latent heat caused by the phase change of water
LH_WATER = 2.26e6  # J/kg
# heat capacity on constant pressure for dry air
CP_DRY = 1005 # J/kg K
# heat capacity on constant pressure for water vapor
CP_VAPOR = 1846 # J/kg K

In [14]:
class physics_pressure_level:
    '''
    Pressure level physics

    Attributes:
        upper_air_pressure (torch.Tensor): pressure levels in Pa.
        lon (torch.Tensor): longitude in degrees.
        lat (torch.Tensor): latitude in degrees.
        pressure_thickness (torch.Tensor): pressure thickness between levels.
        dx, dy (torch.Tensor): grid spacings in longitude and latitude.
        area (torch.Tensor): area of grid cells.
        integral (function): vertical integration method (midpoint or trapezoidal).
    '''
    
    def __init__(self,
                 lon: torch.Tensor,
                 lat: torch.Tensor,
                 upper_air_pressure: torch.Tensor,
                 midpoint: bool = False):
        '''
        Initialize the class with longitude, latitude, and pressure levels.

        All inputs must be in the same torch device.

        Full order of dimensions:  (batch, time, level, latitude, longitude)
        
        Args:
            lon (torch.Tensor): Longitude in degrees.
            lat (torch.Tensor): Latitude in degrees.
            upper_air_pressure (torch.Tensor): Pressure levels in Pa.
            midpoint (bool): True if vertical level quantities are midpoint values
                      otherwise False
            
        '''
        self.lon = lon
        self.lat = lat
        self.upper_air_pressure = upper_air_pressure
        
        # ========================================================================= #
        # compute pressure level thickness
        self.pressure_thickness = self.upper_air_pressure.diff(dim=-1)
        
        # # ========================================================================= #
        # # compute grid spacings
        # lat_rad = torch.deg2rad(self.lat)
        # lon_rad = torch.deg2rad(self.lon)
        # self.dy = torch.gradient(lat_rad * RAD_EARTH, dim=0)[0]
        # self.dx = torch.gradient(lon_rad * RAD_EARTH, dim=1)[0] * torch.cos(lat_rad)

        # ========================================================================= #
        # compute gtid area
        # area = R^2 * d_sin(lat) * d_lon
        lat_rad = torch.deg2rad(self.lat)
        lon_rad = torch.deg2rad(self.lon)
        sin_lat_rad = torch.sin(lat_rad)
        d_phi = torch.gradient(sin_lat_rad, dim=0, edge_order=2)[0]
        d_lambda = torch.gradient(lon_rad, dim=1, edge_order=2)[0]
        d_lambda = (d_lambda + torch.pi) % (2 * torch.pi) - torch.pi
        self.area = torch.abs(RAD_EARTH**2 * d_phi * d_lambda)
        
        # ========================================================================== #
        # vertical integration method
        if midpoint:
            self.integral = self.pressure_integral_midpoint
            self.integral_sliced = self.pressure_integral_midpoint_sliced
        else:
            self.integral = self.pressure_integral_trapz
            self.integral_sliced = self.pressure_integral_trapz_sliced
            
    def pressure_integral_midpoint(self, q_mid: torch.Tensor) -> torch.Tensor:
        '''
        Compute the pressure level integral of a given quantity; assuming its mid point
        values are pre-computed
        
        Args:
            q_mid: the quantity with dims of (batch_size, time, level-1, latitude, longitude)
    
        Returns:
            Pressure level integrals of q
        '''
        num_dims = len(q_mid.shape)
        delta_p = self.pressure_thickness.to(q_mid.device)
        
        if num_dims == 5:  # (batch_size, level, time, latitude, longitude)
            delta_p = delta_p.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            q_area = q_mid * delta_p
            q_trapz = torch.sum(q_area, dim=1)
        
        elif num_dims == 4:  # (batch_size, level, latitude, longitude) or (time, level, latitude, longitude)
            delta_p = delta_p.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            q_area = q_mid * delta_p
            q_trapz = torch.sum(q_area, dim=1)
        
        elif num_dims == 3:  # (level, latitude, longitude)
            delta_p = delta_p.unsqueeze(-1).unsqueeze(-1)  # Expand for broadcasting
            q_area = q_mid * delta_p
            q_trapz = torch.sum(q_area, dim=0)
        
        else:
            raise ValueError(f"Unsupported tensor dimensions: {q_mid.shape}")
        
        return q_trapz

    def pressure_integral_midpoint_sliced(self, 
                                          q_mid: torch.Tensor,                                        
                                          ind_start: int,
                                          ind_end: int) -> torch.Tensor:
        '''
        As in `pressure_integral_midpoint`, but supports pressure level indexing,
        so it can calculate integrals of a subset of levels
        '''
        num_dims = len(q_mid.shape)
        delta_p = self.pressure_thickness[ind_start:ind_end].to(q_mid.device)
        
        if num_dims == 5:  # (batch_size, time, level, latitude, longitude)
            delta_p = delta_p.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            q_mid = q_mid[:, ind_start:ind_end, ...]
            q_area = q_mid * delta_p
            q_trapz = torch.sum(q_area, dim=1)
        
        elif num_dims == 4:  # (batch_size, level, latitude, longitude) or (time, level, latitude, longitude)
            delta_p = delta_p.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            q_mid = q_mid[:, ind_start:ind_end, ...]
            q_area = q_mid * delta_p  # Trapezoidal rule
            q_trapz = torch.sum(q_area, dim=1)
        
        elif num_dims == 3:  # (level, latitude, longitude)
            delta_p = delta_p.unsqueeze(-1).unsqueeze(-1)  # Expand for broadcasting
            q_mid = q_mid[ind_start:ind_end, ...]
            q_area = q_mid * delta_p
            q_trapz = torch.sum(q_area, dim=0)
        
        else:
            raise ValueError(f"Unsupported tensor dimensions: {q_mid.shape}")
        
        return q_trapz
        
    def pressure_integral_trapz(self, q: torch.Tensor) -> torch.Tensor:
        '''
        Compute the pressure level integral of a given quantity using the trapezoidal rule.
        
        Args:
            q: the quantity with dims of (batch_size, time, level, latitude, longitude)
    
        Returns:
            Pressure level integrals of q
        '''
        num_dims = len(q.shape)
        delta_p = self.pressure_thickness.to(q.device)
        
        if num_dims == 5:  # (batch_size, level, time, latitude, longitude)
            delta_p = delta_p.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            q_area = 0.5 * (q[:, :-1, :, :, :] + q[:, 1:, :, :, :]) * delta_p
            q_trapz = torch.sum(q_area, dim=1)
        
        elif num_dims == 4:  # (batch_size, level, latitude, longitude) or (time, level, latitude, longitude)
            delta_p = delta_p.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            q_area = 0.5 * (q[:, :-1, :, :] + q[:, 1:, :, :]) * delta_p  # Trapezoidal rule
            q_trapz = torch.sum(q_area, dim=1)
        
        elif num_dims == 3:  # (level, latitude, longitude)
            delta_p = delta_p.unsqueeze(-1).unsqueeze(-1)  # Expand for broadcasting
            q_area = 0.5 * (q[:-1, :, :] + q[1:, :, :]) * delta_p
            q_trapz = torch.sum(q_area, dim=0)
        
        else:
            raise ValueError(f"Unsupported tensor dimensions: {q.shape}")
        
        return q_trapz

    def pressure_integral_trapz_sliced(self, 
                                       q: torch.Tensor,
                                       ind_start: int,
                                       ind_end: int) -> torch.Tensor:
        '''
        As in `pressure_integral_trapz`, but supports pressure level indexing,
        so it can calculate integrals of a subset of levels
        '''
        num_dims = len(q.shape)
        delta_p = self.upper_air_pressure[ind_start:ind_end].diff(dim=-1).to(q.device)
        
        if num_dims == 5:  # (batch_size, level, time, latitude, longitude)
            delta_p = delta_p.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            q_slice = q[:, ind_start:ind_end, ...]
            q_area = 0.5 * (q_slice[:, :-1, :, :, :] + q_slice[:, 1:, :, :, :]) * delta_p
            q_trapz = torch.sum(q_area, dim=1)
        
        elif num_dims == 4:  # (batch_size, level, latitude, longitude) or (time, level, latitude, longitude)
            delta_p = delta_p.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            q_slice = q[:, ind_start:ind_end, ...]
            q_area = 0.5 * (q_slice[:, :-1, :, :] + q_slice[:, 1:, :, :]) * delta_p  # Trapezoidal rule
            q_trapz = torch.sum(q_area, dim=1)
        
        elif num_dims == 3:  # (level, latitude, longitude)
            delta_p = delta_p.unsqueeze(-1).unsqueeze(-1)  # Expand for broadcasting
            q_slice = q[ind_start:ind_end, ...]
            q_area = 0.5 * (q_slice[:-1, :, :] + q_slice[1:, :, :]) * delta_p
            q_trapz = torch.sum(q_area, dim=0)
        
        else:
            raise ValueError(f"Unsupported tensor dimensions: {q.shape}")
        
        return q_trapz
    

    def weighted_sum(self,
                     q: torch.Tensor, 
                     axis: Dict[tuple, None] = None, 
                     keepdims: bool = False) -> torch.Tensor:
        '''
        Compute the weighted sum of a given quantity for PyTorch tensors.
        
        Args:
            data: the quantity to be summed (PyTorch tensor)
            axis: dims to compute the sum (can be int or tuple of ints)
            keepdims: whether to keep the reduced dimensions or not
    
        Returns:
            Weighted sum (PyTorch tensor)
        '''
        q_w = q * self.area.to(q.device)
        q_sum = torch.sum(q_w, dim=axis, keepdim=keepdims)
        return q_sum

    def total_dry_air_mass(self, 
                           q: torch.Tensor) -> torch.Tensor:
        '''
        Compute the total mass of dry air over the entire globe [kg]
        '''
        mass_dry_per_area = self.integral(1-q) / GRAVITY # kg/m^2
        # weighted sum on latitude and longitude dimensions
        mass_dry_sum = self.weighted_sum(mass_dry_per_area, axis=(-2, -1)) # kg
        
        return mass_dry_sum

    def total_column_water(self, 
                           q: torch.Tensor) -> torch.Tensor:
        '''
        Compute total column water (TCW) per air column [kg/m2]
        '''
        TWC = self.integral(q) / GRAVITY # kg/m^2
        
        return TWC

In [15]:
physics_core = physics_pressure_level(longitude, latitude, upper_air_pressure, midpoint=False)

## Conservation of total dry air mass

In [16]:
ind_fix = 24
N_levels = len(upper_air_pressure)

In [17]:
q_batch_correct = q_batch.clone()

mass_dry_per_area_hold = physics_core.integral_sliced(1-q_batch_correct, 0, ind_fix) / GRAVITY
mass_dry_sum_hold = physics_core.weighted_sum(mass_dry_per_area_hold, axis=(-2, -1))

mass_dry_per_area_fix = physics_core.integral_sliced(1-q_batch_correct, ind_fix-1, N_levels) / GRAVITY
mass_dry_sum_fix = physics_core.weighted_sum(mass_dry_per_area_fix, axis=(-2, -1))

mass_dry_sum = mass_dry_sum_hold + mass_dry_sum_fix
# ------------------------------------------------------------------------------ #
# check residual term
mass_dry_res = mass_dry_sum[:, 1] - mass_dry_sum[:, 0]
print('Residual to conserve the dry air mass [kg]: {}'.format(mass_dry_res))
# ------------------------------------------------------------------------------ #

#mass_dry_sum = physics_core.total_dry_air_mass(q_batch_correct)

mass_residual_on_fix = mass_dry_sum[:, 0] - mass_dry_sum_hold[:, 1]

# Compute the ratio
q_correct_ratio = mass_residual_on_fix / mass_dry_sum_fix[:, 1]
q_correct_ratio = q_correct_ratio.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

q_batch_correct[:, ind_fix-1:, 1, ...] = 1 - (1 - q_batch_correct[:, ind_fix-1:, 1, ...]) * q_correct_ratio

Residual to conserve the dry air mass [kg]: tensor([-1.2095e+13, -4.9478e+12, -2.1990e+13, -1.0995e+13, -1.6493e+13,
         1.0995e+12, -1.8692e+13, -3.2985e+12, -5.4976e+12,  8.2463e+12,
        -1.2644e+13, -3.5734e+13, -1.2644e+13, -1.4843e+13, -9.3458e+12,
         1.8692e+13, -2.6938e+13,  1.2095e+13, -2.4739e+13,  5.4976e+12,
        -1.0445e+13,  1.6493e+13,  3.8483e+12,  7.6966e+12,  1.6493e+13,
         4.2331e+13,  1.4843e+13, -4.3980e+12,  3.0237e+13,  3.4635e+13,
         1.0995e+12, -9.3458e+12,  1.7042e+13,  2.3639e+13, -3.2985e+12,
         1.3744e+13,  3.6284e+13,  4.6179e+13,  5.4976e+12, -8.7961e+12,
         2.1990e+12,  2.6938e+13, -1.0995e+12,  1.1545e+13,  1.5393e+13,
         3.7383e+13, -1.6493e+12, -1.4843e+13,  1.3194e+13,  4.7279e+13,
         1.2644e+13,  6.0473e+12,  6.5971e+12,  2.5289e+13, -8.7961e+12,
         1.3744e+13, -5.4976e+11,  1.3194e+13, -1.6493e+13,  2.7488e+12,
        -1.6493e+13,  1.0995e+12, -1.0445e+13,  1.8692e+13])


In [18]:
# ------------------------------------------------------------------------------ #
mass_dry_sum = physics_core.total_dry_air_mass(q_batch_correct)
mass_dry_res = mass_dry_sum[:, 1] - mass_dry_sum[:, 0]
print('Residual to conserve the dry air mass [kg]: {}'.format(mass_dry_res))
# ------------------------------------------------------------------------------ #

Residual to conserve the dry air mass [kg]: tensor([ 5.4976e+11,  0.0000e+00,  1.6493e+12,  1.6493e+12,  5.4976e+11,
         0.0000e+00,  5.4976e+11,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         5.4976e+11,  2.1990e+12,  5.4976e+11,  1.0995e+12,  5.4976e+11,
        -1.0995e+12,  1.6493e+12, -1.0995e+12,  1.6493e+12,  0.0000e+00,
         0.0000e+00, -5.4976e+11,  0.0000e+00, -5.4976e+11, -1.0995e+12,
        -1.6493e+12, -5.4976e+11,  0.0000e+00, -1.6493e+12, -2.1990e+12,
        -5.4976e+11,  0.0000e+00, -5.4976e+11, -1.0995e+12,  0.0000e+00,
        -1.0995e+12, -1.6493e+12, -2.7488e+12, -5.4976e+11,  1.6493e+12,
        -5.4976e+11, -1.6493e+12, -5.4976e+11, -1.0995e+12, -1.0995e+12,
        -2.1990e+12,  5.4976e+11,  5.4976e+11, -5.4976e+11, -3.2985e+12,
        -5.4976e+11, -5.4976e+11, -1.0995e+12, -1.0995e+12,  0.0000e+00,
        -1.6493e+12,  1.6493e+12, -1.0995e+12,  1.0995e+12, -5.4976e+11,
         5.4976e+11, -5.4976e+11,  1.0995e+12, -1.6493e+12])


In [19]:
(q_batch_correct - q_batch).max()

tensor(2.2739e-05)

**Old**

In [20]:
# q_batch_correct = q_batch.clone()

# correction_cycle_num = 1 # iterative to handle numrical precision

# for i in range(correction_cycle_num):
#     mass_dry_sum = physics_core.total_dry_air_mass(q_batch_correct)
    
#     # ------------------------------------------------------------------------------ #
#     # check residual term
#     mass_dry_res = mass_dry_sum[:, 1] - mass_dry_sum[:, 0]
#     print('Residual to conserve the dry air mass [kg]: {}'.format(mass_dry_res))
#     # ------------------------------------------------------------------------------ #
    
#     q_correct_ratio = mass_dry_sum[:, 0] / mass_dry_sum[:, 1]
#     q_correct_ratio = q_correct_ratio.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
#     q_batch_correct[:, :, 1, ...] = 1 - (1 - q_batch_correct[:, :, 1, ...]) * q_correct_ratio

## Conservation of moisture

In [21]:
N_seconds = 3600 * 6 # 6 hourly data

precip_batch_flux = precip_batch[:, 1, ...] * RHO_WATER / N_seconds # m/hour --> kg/m^2/s, positive
evapor_batch_flux = evapor_batch[:, 1, ...] * RHO_WATER / N_seconds # kg/m^2/s, negative

precip_batch_correct = precip_batch_flux.clone()

# pre-compute TWC
TWC = physics_core.total_column_water(q_batch_correct)
dTWC_dt = (TWC[:, 1, ...] - TWC[:, 0, ...]) / N_seconds # kg/m^2/s
TWC_sum = physics_core.weighted_sum(dTWC_dt, axis=(-2, -1)) # kg/s

# pre-compute evaporation
E_sum = physics_core.weighted_sum(evapor_batch_flux, axis=(-2, -1)) # kg/s

correction_cycle_num = 1

for i in range(correction_cycle_num):
    P_sum = physics_core.weighted_sum(precip_batch_correct, axis=(-2, -1)) # kg/s
    residual = -TWC_sum - E_sum - P_sum # kg/s

    # ------------------------------------------------------------------------------ #
    print('Residual to conserve moisture budge [kg/s]: {}'.format(residual))
    # ------------------------------------------------------------------------------ #
    
    # P_correct = P_sum + residual # kg/s
    P_correct_ratio = (P_sum + residual) / P_sum
    P_correct_ratio = P_correct_ratio.unsqueeze(-1).unsqueeze(-1)
    precip_batch_correct = precip_batch_correct * P_correct_ratio

Residual to conserve moisture budge [kg/s]: tensor([-1.4254e+08,  7.5257e+08,  9.9556e+07, -4.9864e+08,  2.7928e+07,
         3.9238e+08, -9.3704e+07, -3.2787e+08,  4.7290e+07,  5.2874e+08,
        -1.8125e+08, -4.7772e+08,  5.6629e+08,  5.7078e+08, -5.8838e+08,
        -5.1810e+08,  5.4958e+08,  1.1177e+09,  4.6432e+08, -3.7339e+08,
         4.5476e+08,  8.7583e+06, -6.9696e+08, -1.1930e+09, -4.8718e+08,
        -9.3419e+08, -1.5579e+09, -1.8793e+09, -1.1421e+09, -6.3806e+08,
        -7.0396e+08, -1.0798e+09, -7.7447e+08, -4.5485e+08, -6.7336e+08,
        -1.3868e+09, -1.6368e+09, -8.9995e+08, -9.4251e+08, -1.0389e+09,
        -2.4913e+08, -3.7275e+08, -6.7839e+08, -1.1786e+09, -9.3631e+08,
        -9.6865e+08, -9.7149e+08, -9.4713e+08, -8.3010e+08, -1.2971e+09,
        -1.2306e+09, -1.3645e+09, -6.6267e+08, -4.5411e+08, -3.7219e+08,
        -7.8534e+08, -6.8249e+07,  2.4143e+08, -6.6547e+07, -6.1234e+08,
        -1.6346e+07,  4.8045e+08, -3.4416e+08, -1.3659e+09])


In [22]:
# ------------------------------------------------------------------------------ #
P_sum = physics_core.weighted_sum(precip_batch_correct, axis=(-2, -1)) # kg/s
residual = -TWC_sum - E_sum - P_sum # kg/s
print('Residual to conserve moisture budge [kg/s]: {}'.format(residual))
# ------------------------------------------------------------------------------ #

Residual to conserve moisture budge [kg/s]: tensor([ 3072.,  1024.,     0., -1024.,     0., -1024.,     0.,  2048.,     0.,
        -1024.,     0.,  1024.,  1024.,     0., -2048.,  1024.,  1024.,  2048.,
         2048.,     0.,     0., -2048.,     0., -2048., -1024.,  1024., -1024.,
         2048.,     0.,     0.,  1024.,     0., -2048., -1024.,     0., -2048.,
         1024.,     0.,     0.,  1024., -1024., -1024.,     0., -1024., -1024.,
            0.,  1024.,  2048.,  1024., -1024.,     0., -1024.,  2048.,     0.,
            0., -2048., -1024.,     0.,     0.,  1024.,     0.,  1024.,     0.,
        -1024.])


In [23]:
(precip_batch_correct - precip_batch_flux).mean()

tensor(-7.6981e-07)

In [24]:
(precip_batch_correct - precip_batch_flux).max()

tensor(0.0002)

### Conservation of energy

In [46]:
TOA_net_batch.shape

torch.Size([64, 2, 181, 360])

In [45]:
N_seconds = 3600 * 6 # 6 hourly data

# C_p (batch, time, level, lat, lon)
C_p = (1 - q_batch_correct) * CP_DRY + q_batch_correct * CP_VAPOR
# kinetic energy (batch, time, level, lat, lon)
ken = 0.5 * (u_batch ** 2 + v_batch ** 2)

# initialize T_correct
T_batch_correct = T_batch.clone()

# layer-wise atmospheric energy, but without thermal energy 
# (batch, time, level, lat, lon)
E_qgk = LH_WATER * q_batch_correct + GPH_surf_batch + ken

# TOA net energy flux (batch, time, lat, lon)
R_T = (TOA_net_batch + OLR_batch) / N_seconds
R_T = R_T[:, 1, :, :]
# R_T global sum
R_T_sum = physics_core.weighted_sum(R_T, axis=(-2, -1))

# surface net energy flux (batch, time, lat, lon)
F_S = (R_short_batch + R_long_batch + LH_batch + SH_batch) / N_seconds
F_S = F_S[:, 1, :, :]  # Extract time index 1
# F_S global sum
F_S_sum = physics_core.weighted_sum(F_S, axis=(-2, -1))

correction_cycle_num = 1

for i in range(correction_cycle_num):

    # layer-wise atmospheric energy (sensible heat + others)
    #  (batch, time, level, lat, lon)
    E_level = C_p * T_batch_correct + E_qgk

    # total atmospheric energy (TE) of an air column
    # (batch, time, lat, lon)
    TE = physics_core.integral(E_level) / GRAVITY

    # ---------------------------------------------------------------------------- #
    # tendency of TE (batch, lat, lon)
    dTE_dt = (TE[:, 1, :, :] - TE[:, 0, :, :]) / N_seconds
    # global sum of TE tendency (batch,)
    dTE_sum = physics_core.weighted_sum(dTE_dt, axis=(1, 2), keepdims=False)
    # compute the residual (batch,)
    delta_dTE_sum = R_T_sum - F_S_sum - dTE_sum
    print('Residual to conserve energy budget [Watts]: {}'.format(delta_dTE_sum))
    print('Sources & sinks [Watts]: {}'.format(R_T_sum - F_S_sum))
    print('Tendency [Watts]: {}'.format(dTE_sum))
    # ---------------------------------------------------------------------------- #

    # TE at t0 and t1 (batch,)
    total_weighted_TE_t0 = physics_core.weighted_sum(TE[:, 0, :, :], axis=(-2, -1)) 
    total_weighted_TE_t1 = physics_core.weighted_sum(TE[:, 1, :, :], axis=(-2, -1))

    # calculate the correction ratio for E_t1 (batch,) --> (batch, 1, 1, 1)
    E_correct_ratio = (N_seconds * (R_T_sum - F_S_sum) + total_weighted_TE_t0) / total_weighted_TE_t1
    E_correct_ratio = E_correct_ratio.view(-1, 1, 1, 1)

    # Apply the correction to layer-wise atmospheric energy at t1
    # (batch, level, lat, lon)
    E_t1_correct = E_level[:, :, 1, :, :] * E_correct_ratio

    # barotropic correction of T at t1
    T_batch_correct[:, :, 1, :, :] = (E_t1_correct - E_qgk[:, :, 1, :, :]) / C_p[:, :, 1, :, :]

Residual to conserve energy budget [Watts]: tensor([-5.2883e+15, -1.7804e+16, -2.5578e+15, -8.3312e+15, -5.1158e+15,
        -2.0988e+16, -1.2584e+15, -5.8478e+15, -7.5238e+15, -1.7156e+16,
        -2.1534e+15, -1.1279e+16, -8.0018e+15, -2.4406e+16, -4.5671e+15,
        -1.3407e+16, -5.8175e+15, -1.8521e+16, -3.8354e+15, -1.1517e+16,
        -6.4357e+15, -1.6664e+16, -8.6003e+15, -1.3433e+16, -1.1580e+16,
        -2.5848e+16, -1.0536e+16, -1.5758e+16, -1.3896e+16, -2.2500e+16,
        -4.8670e+15, -1.4867e+16, -1.3130e+16, -1.8924e+16, -3.7133e+15,
        -1.6368e+16, -1.6980e+16, -2.1180e+16, -7.5134e+15, -2.2942e+16,
        -1.5146e+16, -1.8013e+16, -7.9557e+15, -2.3404e+16, -1.6152e+16,
        -2.0337e+16, -9.2454e+15, -2.2301e+16, -1.4037e+16, -2.2594e+16,
        -1.3392e+16, -2.7264e+16, -1.2570e+16, -2.4419e+16, -1.2373e+16,
        -2.1940e+16, -1.1714e+16, -1.9049e+16, -9.4465e+15, -2.0854e+16,
        -8.1213e+15, -2.2210e+16, -1.0753e+16, -1.9913e+16])
Sources & sinks [Wa

In [42]:
# ---------------------------------------------------------------------------- #
E_level = C_p * T_batch_correct + E_qgk
TE = physics_core.integral(E_level) / GRAVITY
dTE_dt = (TE[:, 1, :, :] - TE[:, 0, :, :]) / N_seconds
dTE_sum = physics_core.weighted_sum(dTE_dt, axis=(-2, -1), keepdims=False)
energy_residual = R_T_sum - F_S_sum - dTE_sum
print('Residual to conserve energy budget [Watts]: {}'.format(energy_residual))
# ---------------------------------------------------------------------------- #

Residual to conserve energy budget [Watts]: tensor([ 4.0856e+12, -5.4585e+12, -8.2251e+12,  2.8562e+12,  3.4532e+12,
         2.3757e+10,  1.7193e+11, -3.4736e+12, -3.2078e+12, -1.9209e+12,
        -4.9642e+12, -4.2896e+11,  2.1684e+12, -2.3946e+12,  1.4678e+12,
        -8.4106e+12,  4.7229e+12, -5.0702e+12,  3.2921e+12,  2.9013e+12,
         1.5698e+12,  7.0156e+12, -6.0409e+12,  6.2127e+12,  6.7109e+11,
        -8.7837e+12,  4.5607e+11,  7.9457e+11, -1.0329e+12, -4.4775e+11,
         8.4568e+12, -3.3136e+12,  3.7581e+11, -4.0332e+12, -3.8523e+12,
         6.9149e+11, -4.4147e+12, -5.9420e+12, -4.9922e+12, -4.6278e+11,
         3.7919e+12,  3.6335e+12, -8.9861e+12,  5.6264e+12,  1.2133e+12,
         5.9120e+12, -2.6554e+12,  2.5995e+12, -2.2602e+12, -6.3472e+12,
         4.7953e+12,  1.0226e+13, -2.0766e+12,  3.5557e+12,  2.4669e+12,
         3.9728e+10, -3.0859e+12, -7.3312e+12, -9.1268e+09, -7.6075e+12,
        -1.0609e+12, -6.3340e+12,  9.5815e+12, -9.6626e+12])


In [43]:
(T_batch_correct - T_batch).mean()

tensor(-0.0100)

In [44]:
(T_batch_correct - T_batch).abs().max()

tensor(0.1024)