# Global-scale atmospheric moisture and mass budgets on ERA5 pressure level data

In [1]:
import numpy as np
import xarray as xr

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline

In [17]:
dt_str = '2020-01-01'

base_dir = '/glade/derecho/scratch/ksha/CREDIT/GATHER/fuxi_dry_1deg_raw/'
filename = base_dir + f'{dt_str}T00Z.nc'
ds_rollout = xr.open_dataset(filename)

base_dir = '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_1deg/'
filename = base_dir + 'static/ERA5_plevel_1deg_6h_conserve_static.zarr'
ds_static = xr.open_zarr(filename)

In [18]:
base_dir = '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_1deg/'
filename = base_dir + 'all_in_one/ERA5_plevel_1deg_6h_2020_conserve.zarr'
ds_ERA5 = xr.open_zarr(filename)

In [19]:
R = 6371000  # m
GRAVITY = 9.80665
RHO_WATER = 1000.0 # kg/m^3
RAD_EARTH = 6371000 # m
LH_WATER = 2.26e6  # J/kg
CP_DRY = 1005 # J/kg K
CP_VAPOR = 1846 # J/kg K

In [20]:
x = ds_rollout['longitude']
y = ds_rollout['latitude']
lon, lat = np.meshgrid(x, y)
level_p = 100*np.array(ds_static['level'])

In [21]:
level_p # Pa or kg/m/s2

array([   100.,    200.,    300.,    500.,    700.,   1000.,   2000.,
         3000.,   5000.,   7000.,  10000.,  12500.,  15000.,  17500.,
        20000.,  22500.,  25000.,  30000.,  35000.,  40000.,  45000.,
        50000.,  55000.,  60000.,  65000.,  70000.,  75000.,  77500.,
        80000.,  82500.,  85000.,  87500.,  90000.,  92500.,  95000.,
        97500., 100000.], dtype=float32)

In [22]:
# level_diff = np.diff(level_p)
# level_diff_cumsum = np.concatenate(([0], np.cumsum(level_diff)))

In [23]:
q = ds_rollout['specific_total_water'].values
T = ds_rollout['T'].values
u = ds_rollout['U'].values
v = ds_rollout['V'].values
precip = ds_rollout['total_precipitation'].values
evapor = ds_rollout['evaporation'].values
GPH_surf = ds_static['geopotential_at_surface'].values
TOA_net = ds_rollout['top_net_solar_radiation'].values
OLR = ds_rollout['top_net_thermal_radiation'].values
R_short = ds_rollout['surface_net_solar_radiation'].values
R_long = ds_rollout['surface_net_thermal_radiation'].values
LH = ds_rollout['surface_latent_heat_flux'].values
SH = ds_rollout['surface_sensible_heat_flux'].values

# replace to ERA5
q_ERA5 = ds_ERA5['specific_total_water'].sel(time=dt_str).values
T_ERA5 = ds_ERA5['T'].sel(time=dt_str).values
u_ERA5 = ds_ERA5['U'].sel(time=dt_str).values
v_ERA5 = ds_ERA5['V'].sel(time=dt_str).values
precip_ERA5 = ds_ERA5['total_precipitation'].sel(time=dt_str).values
evapor_ERA5 = ds_ERA5['evaporation'].sel(time=dt_str).values
TOA_net_ERA5 = ds_ERA5['top_net_solar_radiation'].sel(time=dt_str).values
OLR_ERA5 = ds_ERA5['top_net_thermal_radiation'].sel(time=dt_str).values
R_short_ERA5 = ds_ERA5['surface_net_solar_radiation'].sel(time=dt_str).values
R_long_ERA5 = ds_ERA5['surface_net_thermal_radiation'].sel(time=dt_str).values
LH_ERA5 = ds_ERA5['surface_latent_heat_flux'].sel(time=dt_str).values
SH_ERA5 = ds_ERA5['surface_sensible_heat_flux'].sel(time=dt_str).values

q = np.concatenate((q_ERA5[0, ...][None, ...], q), axis=0)
precip = np.concatenate((precip_ERA5[0, ...][None, ...], precip), axis=0)
evapor = np.concatenate((evapor_ERA5[0, ...][None, ...], evapor), axis=0)

In [24]:
N_seconds = 3600 * 6 # 6 hourly data

In [25]:
def weighted_sum(data, weights, axis, keepdims=False):
    '''
    Compute the weighted sum of a given quantity

    Args:
        data: the quantity to be sum-ed
        weights: weights that can be broadcasted to the shape of data
        axis: dims to compute the sum
        keepdims: keepdims

    Returns:
        weighted sum
    '''
    expanded_weights = np.broadcast_to(weights, data.shape)
    return np.sum(data * expanded_weights, axis=axis, keepdims=keepdims)

def pressure_integral(q, level_p, output_shape):
    '''
    Compute the pressure level integral of a given quantity using np.trapz

    Args:
        q: the quantity with dims of (level, lat, lon) or (time, level, lat, lon)
        level_p: the pressure level of q as [Pa] and with dims of (level,)
        output_shape: either (lat, lon) or (time, lat, lon)

    Returns:
        Pressure level integrals of q
    '''
    # Ensure level_p is a NumPy array
    level_p = np.asarray(level_p)

    # (level, lat, lon) --> (lat, lon)
    if len(output_shape) == 2 and q.ndim == 3:
        Q = np.trapz(q, x=level_p, axis=0)

    # (time, level, lat, lon) --> (time, lat, lon)
    elif len(output_shape) == 3 and q.ndim == 4:
        Q = np.trapz(q, x=level_p, axis=1)

    else:
        raise ValueError('Invalid output_shape or dimensions of q.')

    return Q

In [26]:
def grid_area(lat, lon):
    '''
    Compute grid cell areas using the exact formula for spherical quadrilaterals.

    Args:
        lat, lon: 2D arrays of latitude and longitude in degrees.

    Return:
        area: 2D array of grid cell areas in square meters.
    '''
    # Convert latitude and longitude to radians
    lat_rad = np.deg2rad(lat)
    lon_rad = np.deg2rad(lon)
    
    # Compute sine of latitude
    sin_lat_rad = np.sin(lat_rad)
    
    # Compute gradient of sine of latitude (d_phi)
    d_phi = np.gradient(sin_lat_rad, axis=0, edge_order=2)
    
    # Compute gradient of longitude (d_lambda)
    d_lambda = np.gradient(lon_rad, axis=1, edge_order=2)
    
    # Adjust d_lambda to be within -π and π
    d_lambda = (d_lambda + np.pi) % (2 * np.pi) - np.pi
    
    # Compute grid cell area
    area = np.abs(RAD_EARTH**2 * d_phi * d_lambda)
    
    return area

area = grid_area(lat, lon)
w_lat = area #/ np.sum(area)

In [27]:
output_shape = (q.shape[0],)+lon.shape

In [28]:
def mass_residual_compute(q, level_p, output_shape, w_lat):
    mass_dry_per_area = pressure_integral(1-q, level_p, output_shape) / GRAVITY
    mass_dry_sum = weighted_sum(mass_dry_per_area, w_lat, axis=(1, 2), keepdims=False)
    # ----------------------------------------------------------------------- #
    # check residual term
    mass_dry_res = np.diff(mass_dry_sum)
    return mass_dry_res


In [29]:
mass_residual = mass_residual_compute(q, level_p, output_shape, w_lat)

In [30]:
mass_residual

array([ 3.55132168e+13,  3.43767759e+13, -1.69713975e+13,  2.67831969e+13,
       -1.04526350e+13,  4.09086716e+13,  8.31145651e+12,  4.30334176e+13,
        4.64466790e+12,  5.19408096e+13,  1.91620339e+13,  4.29138605e+13,
        1.16382680e+13,  5.26801867e+13,  8.00381868e+12,  3.14225162e+13,
       -5.75310545e+12,  3.37995621e+13, -1.55694893e+13,  1.60719024e+12,
       -2.32985689e+13,  1.31689192e+13, -2.42015852e+13, -1.12547935e+13,
       -3.44378873e+13,  1.22492105e+13, -2.99203591e+13,  2.49098977e+11,
       -2.31524477e+13,  2.50031652e+13, -2.86271609e+13, -1.10697024e+12,
       -1.90030900e+13,  3.71240275e+13, -1.13206348e+13,  1.94843150e+13,
       -8.47395844e+11,  4.08886114e+13, -1.70374542e+12,  2.77828278e+13])

In [31]:
#plt.plot(mass_residual)

In [32]:
#def water_budget_compute(q, precip, evapor, N_seconds, output_shape, w_lat):

precip_flux = precip[1:, ...] * RHO_WATER / N_seconds # m/hour --> kg/m^2/s, positive
evapor_flux = evapor[1:, ...] * RHO_WATER / N_seconds # kg/m^2/s, negative


# pre-compute TWC
TWC = pressure_integral(q, level_p, output_shape) / GRAVITY # kg/m^2
dTWC_dt = np.diff(TWC, axis=0) / N_seconds # kg/m^2/s
TWC_sum = weighted_sum(dTWC_dt, w_lat, axis=(1, 2), keepdims=False) # kg/s

# pre-compute evaporation
E_sum = weighted_sum(evapor_flux, w_lat, axis=(1, 2), keepdims=False) # kg/s


P_sum = weighted_sum(precip_flux, w_lat, axis=(1, 2), keepdims=False) # kg/s

residual = -TWC_sum - E_sum - P_sum
print('Residual to conserve moisture budge [kg/s]: {}'.format(residual))


Residual to conserve moisture budge [kg/s]: [ 6.89155080e+08  1.18932877e+09 -2.32658457e+09 -4.55295532e+08
 -1.32646831e+09  1.48316487e+09 -1.51274105e+09  1.75815756e+08
 -9.83049848e+08  1.39501288e+09 -1.57119160e+09 -3.62471038e+08
 -1.02451936e+09  1.29536047e+09 -1.88611670e+09 -1.06368156e+09
 -1.64573645e+09  1.02135545e+09 -2.26559515e+09 -1.36102377e+09
 -1.46075137e+09  6.64544635e+08 -1.93632121e+09 -1.46154493e+09
 -1.67079974e+09  5.57177329e+08 -2.04946776e+09 -1.16379617e+09
 -1.62896008e+09  7.86659179e+08 -2.45432269e+09 -1.31692339e+09
 -1.60106097e+09  4.16465611e+08 -2.51916669e+09 -1.06430417e+09
 -1.34180995e+09  7.49861273e+08 -2.10526055e+09 -6.46093099e+08]
