In [None]:
import os
import shutil

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyproj
import rasterio as rio
import xarray as xr
from matplotlib.colors import BoundaryNorm
from scipy.interpolate import griddata

# Surface-data files under $E3SM_ROOT: the ERW-off original, a scratch copy,
# and the ERW-on destination file written at the end of this notebook.
path_root = os.path.join(os.environ['E3SM_ROOT'], 'inputdata', 'lnd', 'clm2', 
                         'surfdata_map')
file_orig = os.path.join(path_root, 'surfdata_conus_erw_off_simyr1850_c211019.nc')
file_temp = os.path.join(path_root, 'surfdata_conus_erw.temp.nc')
file_dest = os.path.join(path_root, 'surfdata_conus_erw_on_simyr1850_c211019.nc')

# copy from the non-ERW file
# shutil.copy raises if the copy fails (the exit status of a shelled-out `cp`
# was silently ignored) and does not break on paths containing spaces.
shutil.copy(file_orig, file_temp)

In [None]:
# Read soil grids pH and CEC_TOT and interpolate every band onto the target grid
path_soilgrids = os.path.join(os.environ['PROJDIR'], 'ERW_LDRD', 'data')

# Target grid (0.5 degree spacing, longitudes shifted by -360). It is the same
# for every file and band, so it is built once instead of once per file.
lats_target = np.arange(23.25, 54.26, 0.5)
lons_target = np.arange(234.75, 293.26, 0.5) - 360.
lons_target, lats_target = np.meshgrid(lons_target, lats_target)

soilgrids_map = {}
for col, newname in zip(
    ['cec_mean.tif', 'phh2o_mean.tif'], 
    ['CEC_TOT', 'SOIL_PH']
):
    # the context manager closes the raster even if reading/interpolation raises
    with rio.open(os.path.join(path_soilgrids, col)) as h:
        # google earth engine exported file already in EPSG 4326, so the pixel
        # coordinates are used directly as lon/lat
        rows, cols = np.indices(h.shape)
        src_x, src_y = rio.transform.xy(h.transform, rows, cols)
        lons = np.array(src_x)
        lats = np.array(src_y)
        src_coords = np.array([lons.flatten(), lats.flatten()]).T

        # loop through the layers (raster bands are 1-indexed)
        soilgrids_map[newname] = []
        for i in range(1, h.count+1):
            # read data
            data = h.read(i, masked = True)

            # scale CEC from mmol/kg to mmol 100g-1 dry soil
            # scale pHx10 to pH
            data = data / 10.

            # NOTE(review): griddata may drop the mask of a masked array and
            # interpolate nodata cells as ordinary values - TODO confirm.
            src_data_flat = data.flatten()
            data_reproj = griddata(src_coords, src_data_flat, (lons_target, lats_target),
                                   method='linear')

            # save transformed data
            soilgrids_map[newname].append(data_reproj)

In [None]:
# Stack and convert to xarray format
soilgrids_top = np.array([0, 5, 15, 30, 60, 100]) # cm
soilgrids_bot = np.array([5, 15, 30, 60, 100, 200])
soilgrids_node = 0.5 * (soilgrids_top + soilgrids_bot) / 100 # layer mid-points, cm -> m
elm_nodes = np.array([0.0071, 0.0279, 0.0623, 0.1189, 0.2122, 0.3661, 0.6198, 1.0380, 1.7276, 2.8646])

soilgrids_data = {}
for col in soilgrids_map.keys():
    layers = soilgrids_map[col]

    # Layer mapping: weighted average
    # The first destination layer copies the first source layer.
    temp = [layers[0]]
    for node in elm_nodes[1:-2]:
        # Indices of the source nodes just below / just above this node value.
        # Named `shallower`/`deeper` rather than `prev`/`next` so the builtin
        # `next` is not shadowed for the rest of the notebook.
        shallower = np.where(soilgrids_node < node)[0][-1]
        deeper = np.where(soilgrids_node > node)[0][0]

        # Linear interpolation: each neighbour is weighted by the distance
        # from the node to the *other* neighbour.
        deeper_weight = (node - soilgrids_node[shallower])
        shallower_weight = (soilgrids_node[deeper] - node)

        temp.append( (layers[shallower]*shallower_weight + layers[deeper]*deeper_weight) / (shallower_weight + deeper_weight) )
    # The last two destination layers both copy the last source layer.
    temp.extend([layers[-1], layers[-1]])

    # layer axis first: (nlevsoi, lsmlat, lsmlon)
    data = np.stack(temp, axis = 0)
    soilgrids_data[col] = xr.DataArray(data, dims = ['nlevsoi','lsmlat','lsmlon']) # no coords in surfdata file
soilgrids_data = xr.Dataset(soilgrids_data)

In [None]:
# sanity check: contour map of the first layer of each regridded variable
fig, axes = plt.subplots(1, 2, figsize = (20, 6), subplot_kw = {'projection': ccrs.PlateCarree()})
fig.subplots_adjust(wspace = 0.01)

# 1-D axes of the target grid (first row / first column of the meshgrid)
grid_lons = lons_target[0,:]
grid_lats = lats_target[:,0]

for i, var in enumerate(soilgrids_data.data_vars):
    ax = axes.flat[i]
    ax.coastlines()

    # second panel is contoured over 0-10, every other panel over 0-80
    level_max = 10 if i == 1 else 80
    levels = np.linspace(0, level_max, 21)

    cf = ax.contourf(grid_lons, grid_lats, soilgrids_data[var][0, :, :], 
                     cmap = 'jet', levels = levels)
    plt.colorbar(cf, ax = ax, orientation = 'horizontal', shrink = 0.7)

    # only the first panel carries a unit suffix in its title
    ax.set_title((var + ' (cmolc/kg)') if i == 0 else var)

In [None]:
# Read NCSS maps and interpolate to 0.5 degree maps
elm_bot = 100 * np.array([ 0.0906, 0.2891, 0.4929, 0.8289, 1.3828, 3.8019])
elm_top = np.insert(elm_bot[:-1], 0, 0)
path_ncss = os.path.join(os.environ['PROJDIR'], 'ERW_LDRD', 'results')

# Target grid (0.5 degree spacing, longitudes shifted by -360); identical for
# every file, so it is built once up front.
lats_target = np.arange(23.25, 54.26, 0.5)
lons_target = np.arange(234.75, 293.26, 0.5) - 360.
lons_target, lats_target = np.meshgrid(lons_target, lats_target)

CEC_maps = {}
for col, newname in zip(
    ['beta_Ca', 'beta_Mg', 'beta_Na', 'beta_K', 'beta_Al'], 
    ['CEC_EFF_1', 'CEC_EFF_2', 'CEC_EFF_3', 'CEC_EFF_4', 'CEC_EFF_5']
):
    CEC_maps[newname] = []
    # one GeoTIFF per layer; there are as many layers as entries in elm_bot
    for i in range(len(elm_bot)):

        # the context manager closes the raster even if a later step raises
        with rio.open(os.path.join(path_ncss, f'interp_NCSS_layer_{i}_{col}.tif')) as h:
            # read data
            data = h.read(1, masked = True)

            # projected X-Y coordinate of every pixel
            rows, cols = np.indices(data.shape)
            src_x, src_y = rio.transform.xy(h.transform, rows, cols)
            src_x = np.array(src_x)
            src_y = np.array(src_y)

            # convert X-Y to lat-lon
            # Transformer replaces the deprecated Proj(init=...) + pyproj.transform
            # pair; always_xy=True keeps the (x, y) -> (lon, lat) axis order that
            # the old init= syntax used.
            to_lonlat = pyproj.Transformer.from_crs(
                "EPSG:{}".format(h.crs.to_epsg()), "EPSG:4326", always_xy = True
            )
        lons, lats = to_lonlat.transform(src_x, src_y)

        # conduct reprojection
        src_coords = np.array([lons.flatten(), lats.flatten()]).T
        src_data_flat = data.flatten()
        data_reproj = griddata(src_coords, src_data_flat, (lons_target, lats_target),
                               method='linear')

        # save transformed data
        CEC_maps[newname].append(data_reproj)

# CEC_ACID is whatever remains after subtracting the five CEC_EFF_* maps from 1
CEC_maps['CEC_ACID'] = []
for i in range(len(elm_bot)):
    data_reproj = 1 - CEC_maps['CEC_EFF_1'][i]
    for k in range(2, 6):
        data_reproj = data_reproj - CEC_maps[f'CEC_EFF_{k}'][i]
    CEC_maps['CEC_ACID'].append(data_reproj)

In [None]:
# Debug check: print the six CEC_maps values at one grid cell, one line per layer
row = 30
col = 6
debug_names = ['CEC_ACID', 'CEC_EFF_1', 'CEC_EFF_2', 'CEC_EFF_3',
               'CEC_EFF_4', 'CEC_EFF_5']
for i in range(6):
    print(*[CEC_maps[name][i][row,col] for name in debug_names])

In [None]:
# Some interpolated values are greater than 1 or smaller than 0.
# When that happens in a grid cell, set each negative component to `hard_min`
# and scale the non-negative components so that all six components again add
# up to 1. Grid cells without any negative component are left unchanged
# (scale factor 1); previously they were scaled by 0.95, and re-running the
# cell shrank them again.
hard_min = 0.05
cec_names = ['CEC_ACID', 'CEC_EFF_1', 'CEC_EFF_2', 'CEC_EFF_3', 
             'CEC_EFF_4', 'CEC_EFF_5']
for i in range(len(CEC_maps['CEC_ACID'])):
    # (component, lat, lon) stack of this layer; shape taken from the data
    layer_stack = np.stack([CEC_maps[col][i] for col in cec_names], axis = 0)

    is_negative = layer_stack < 0
    n_negative = is_negative.sum(axis = 0)
    # Sum of the negative components. Multiplying (rather than masking) keeps
    # NaN propagating to the whole grid cell, as before.
    sum_negative = (is_negative * layer_stack).sum(axis = 0)

    # The six components sum to 1 by construction of CEC_ACID, so the
    # non-negative ones sum to (1 - sum_negative). After clamping they must
    # share what is left once every clamped component takes hard_min.
    scale_positive = (1 - n_negative * hard_min) / (1 - sum_negative)

    for j, col in enumerate(cec_names):
        CEC_maps[col][i] = np.where( 
            is_negative[j, :, :], 
            hard_min, 
            layer_stack[j, :, :] * scale_positive
        )

In [None]:
# Debug check: same grid cell as before, printed again after the adjustment
row = 30
col = 6
debug_names = ['CEC_ACID', 'CEC_EFF_1', 'CEC_EFF_2', 'CEC_EFF_3',
               'CEC_EFF_4', 'CEC_EFF_5']
for i in range(6):
    values_at_cell = [CEC_maps[name][i][row,col] for name in debug_names]
    print(*values_at_cell)

In [None]:
# Stack and convert to xarray format
# Layer mapping: 0 => 0:3, 1 => 3:5, 2 => 5, 3 => 6, 4 => 7, 5 => 8:10
# i.e. for each of the 10 output layers, the index of the source layer it copies
source_layer = [0, 0, 0, 1, 1, 2, 3, 4, 5, 5]
cec_tot_values = soilgrids_data['CEC_TOT'].values

CEC_data = {}
for col in CEC_maps.keys():
    # layer axis first: (nlevsoi, lsmlat, lsmlon)
    data = np.stack([CEC_maps[col][k] for k in source_layer], axis = 0)

    # multiply beta with total CEC
    CEC_data[col] = xr.DataArray(data * cec_tot_values,
                                 dims = ['nlevsoi','lsmlat','lsmlon']) # no coords in surfdata file
CEC_data = xr.Dataset(CEC_data)

In [None]:
# sanity check: one 2x3 figure per layer (first six layers), one panel per variable
# Every row/column of the meshgrid holds the same 1-D axis, so row 0 / column 0
# give the same coordinates as indexing by `layer`.
grid_lons = lons_target[0,:]
grid_lats = lats_target[:,0]
for layer in range(6):
    fig, axes = plt.subplots(2, 3, figsize = (20, 6), subplot_kw = {'projection': ccrs.PlateCarree()})
    panels = axes.ravel()
    for i, var in enumerate(CEC_data.data_vars):
        ax = panels[i]
        ax.coastlines()
        cf = ax.contourf(grid_lons, grid_lats, CEC_data[var][layer, :, :], 
                         cmap = 'RdYlBu')
        plt.colorbar(cf, ax = ax)
        ax.set_title(var)

In [None]:
# percentage kaolinite and calcite in soil
# The layers are: elm_bot = np.array([0.0906,  0.8289, 1.3828])
path_mineral = os.path.join(os.environ['PROJDIR'], 'ERW_LDRD', 'results')

# Target grid (0.5 degree spacing, longitudes shifted by -360); identical for
# every file, so it is built once up front.
lats_target = np.arange(23.25, 54.26, 0.5)
lons_target = np.arange(234.75, 293.26, 0.5) - 360.
lons_target, lats_target = np.meshgrid(lons_target, lats_target)

mineral_map = {}
for col, newname in zip(['Kaolinit', 'Calcite'], ['KAOLINITE', 'CALCITE']):
    mineral_map[f'PCT_{newname}'] = []
    for i in range(3):
        # the context manager closes the raster even if a later step raises
        with rio.open(os.path.join(path_mineral, f'interp_soilMineral_layer_{i}_{col}.tif')) as h:
            # tif file is in EPSG:5070
            rows, cols = np.indices(h.shape)
            src_x, src_y = rio.transform.xy(h.transform, rows, cols)
            src_x = np.array(src_x)
            src_y = np.array(src_y)

            # convert X-Y to lat-lon
            # Transformer replaces the deprecated Proj(init=...) + pyproj.transform
            # pair; always_xy=True keeps the (x, y) -> (lon, lat) axis order that
            # the old init= syntax used.
            to_lonlat = pyproj.Transformer.from_crs(
                "EPSG:{}".format(h.crs.to_epsg()), "EPSG:4326", always_xy = True
            )

            # read data
            data = h.read(1, masked = True)

        lons, lats = to_lonlat.transform(src_x, src_y)

        # conduct reprojection
        src_coords = np.array([lons.flatten(), lats.flatten()]).T
        src_data_flat = data.flatten()
        data_reproj = griddata(src_coords, src_data_flat, (lons_target, lats_target),
                               method='linear')

        # save transformed data
        mineral_map[f'PCT_{newname}'].append(data_reproj)

In [None]:
# Stack and convert to xarray format
# The three layers in the interpolated data: [0.0906,  0.8289, 1.3828]
# Compared to ELM:
# Merge the top 10cm into one because too few observations
# 0.0175, 0.0451, 0.0906, 
# Merge the 10-80cm into one because too few observations
# 0.1655, 0.2891, 0.4929, 0.8289,
# Below 1.38m doesn't have much data, merge those layers too
# 1.3828, 2.2961, 3.8019

# For each of the 10 output layers, the index of the interpolated layer it copies
source_layer = [0, 0, 0, 1, 1, 1, 1, 2, 2, 2]

mineral_data = {}
for col in mineral_map.keys():
    # layer axis first: (nlevsoi, lsmlat, lsmlon); values are stored as read,
    # without any further scaling
    data = np.stack([mineral_map[col][k] for k in source_layer], axis = 0)

    mineral_data[col] = xr.DataArray(data,
        dims = ['nlevsoi','lsmlat','lsmlon']) # no coords in surfdata file
mineral_data = xr.Dataset(mineral_data)

In [None]:
# sanity check: columns are the variables, rows are output layers 0, 3 and 7
fig, axes = plt.subplots(3, 2, figsize = (16, 12), subplot_kw = {'projection': ccrs.PlateCarree()})

# contour boundaries: one set for the first variable, another for the rest
levels_first = np.array([0, 0.6, 1.5, 2.5, 3.9, 6.8, 43.7])
levels_other = np.array([0., 0.4, 1.7, 4.1, 5.6, 7.9, 12.2, 20.7, 69.8])

grid_lons = lons_target[0,:]
grid_lats = lats_target[:,0]

for i, var in enumerate(mineral_data.data_vars):
    levels = levels_first if i == 0 else levels_other
    for j, mapped_layer in enumerate([0, 3, 7]):
        ax = axes[j,i]
        ax.coastlines()
        norm = BoundaryNorm(levels, 256, clip = False, extend = 'both')
        cf = ax.contourf(grid_lons, grid_lats, mineral_data[var][mapped_layer, :, :],
                         cmap = 'RdYlBu', levels = levels, 
                         norm = norm)
        plt.colorbar(cf, ax = ax)
        ax.set_title(var)

In [None]:
# put into the NetCDF
# Every new variable is written as float32 with 1e20 as the fill value.
encoding = {
    col: {'dtype': np.float32, '_FillValue': 1e20}
    for col in [*soilgrids_map.keys(), *CEC_maps.keys(), *mineral_map.keys()]
}

soilgrids_data['SOIL_PH'].attrs = {'long_name': 'soil pH', 'units': ''}
soilgrids_data['CEC_TOT'].attrs = {'long_name': 'total cation exchange capacity', 
                                   'units': 'meq 100g-1 dry soil'}
CEC_data['CEC_ACID'].attrs = {'long_name': 'acid cation exchange capacity', 
                              'units': 'meq 100g-1 dry soil'}
# CEC_EFF_1 .. CEC_EFF_5: the upper bound is 6 so that CEC_EFF_5 also gets
# its attributes (range(1, 5) stopped at CEC_EFF_4).
for i in range(1,6):
    CEC_data[f'CEC_EFF_{i}'].attrs = {'long_name': 'individual cation exchange capacity', 
                                      'units': 'meq 100g-1 dry soil'}
mineral_data['PCT_KAOLINITE'].attrs = {'long_name': 'percentage naturally occuring kaolinite in soil mineral', 'units': 'g 100 g-1 soil'}
mineral_data['PCT_CALCITE'].attrs = {'long_name': 'percentage naturally occuring CaCO3 in soil mineral', 'units': 'g 100 g-1 soil'}

# Add the new variables to the scratch copy and write the destination file.
# The context manager closes the scratch file even if writing fails.
with xr.open_dataset(file_temp) as hr:
    hr.update(CEC_data)
    hr.update(soilgrids_data)
    hr.update(mineral_data)
    hr.to_netcdf(file_dest, encoding = encoding, format = 'NETCDF4_CLASSIC')

# remove the scratch copy; os.remove raises if the file cannot be deleted
os.remove(file_temp)

<span style="background-color: #fff7bc;font-size: 24px;color: black"><strong>Final quality check - grid cells that are within the domain mask, but have problematic data values</strong></span>

Set both mask = 0 and landfrac = 0 in nodata regions

In [None]:
# Domain file that accompanies the surface dataset
file_domain = os.path.join(
    os.environ['E3SM_ROOT'], 'inputdata', 'share', 'domains', 'domain.clm',
    'domain.lnd.conus_erw_jra.240712.nc'
)
hr_domain = xr.open_dataset(file_domain)

# integer [0,1] indicates where the land is; we only need this one
mask = hr_domain['mask'].values > 0
# hr_domain is left open: its 'xc' / 'yc' variables are read in the next step

# the surface dataset written above
hr = xr.open_dataset(file_dest)

In [None]:
# (1) edit the mask to exclude lat lon ranges actually outside CONUS
xc = hr_domain['xc']
yc = hr_domain['yc']
inside_box = (xc >= 233) & (xc <= 294.5) & (yc >= 25) & (yc <= 50)
mask = mask & inside_box.values

In [None]:
# (2) check if it is feasible to exclude the few grid cells with near-zero CEC_TOT
# Yes!
fig, axes = plt.subplots(4, 3, figsize = (20, 15))

# Panel 0: the starting mask, reversed along its first axis for display.
mask_temp = mask[::-1, :]
ax = axes.flat[0]
ax.imshow(mask_temp)
ax.set_title('mask')

# Panels 1-10: the mask after cumulatively requiring CEC_TOT > 5 in layers
# 0..i. The panel index is shifted by one so the starting mask in panel 0 is
# not drawn over by the first iteration.
for i in range(10):
    ax = axes.flat[i + 1]
    mask_temp = mask_temp & (hr['CEC_TOT'][i,::-1,:].values > 5)
    ax.imshow(mask_temp)
    ax.set_title(f'mask & CEC_TOT > 5, layers 0-{i}')

# the twelfth panel is unused
axes.flat[11].axis('off')

In [None]:
# (3) check if after the exclusion based on CEC_TOT, other CEC variables behave okay
# Yes!
# keep only cells inside the mask whose CEC_TOT exceeds 5 in every layer
cectot_ok = np.all(hr['CEC_TOT'].values > 5, axis = 0)
mask_cectot = mask & cectot_ok

# print "min max" over the retained cells, one variable per line
checked_vars = ['CEC_ACID', 'CEC_TOT'] + [f'CEC_EFF_{k}' for k in range(1,6)]
for name in checked_vars:
    retained = hr[name].where(mask_cectot)
    print(retained.min().values, retained.max().values)

In [None]:
# (4) check if after the exclusion based on CEC_TOT, soil pH behaves okay
# Yes!
# One panel per soil layer; each title shows the min-max over the retained cells.
# (The former leading `ax.imshow(mask_temp)` drew on an `ax` left over from an
# earlier cell's figure, not on this one, so it was removed.)
fig, axes = plt.subplots(4, 3, figsize = (20, 15))
for i in range(10):
    ax = axes.flat[i]
    ax.imshow(mask_cectot, cmap = 'Reds')
    data = hr['SOIL_PH'][i,:,:].where(mask_cectot)
    data.plot(ax = ax, cmap = 'Blues')
    ax.set_title(
        f'{np.nanmin(data.values):.4f}-{np.nanmax(data.values):.4f}'
    )
# the last two panels are unused
for spare in axes.flat[10:]:
    spare.axis('off')

In [None]:
# (5) check if after the exclusion based on CEC_TOT, PCT_CALCITE behaves okay
# No!
# One panel per soil layer; each title shows the min-max over the retained cells.
# (The former leading `ax.imshow(mask_temp)` drew on an `ax` left over from an
# earlier cell's figure, not on this one, so it was removed.)
fig, axes = plt.subplots(4, 3, figsize = (20, 15))
for i in range(10):
    ax = axes.flat[i]
    ax.imshow(mask_cectot, cmap = 'Reds')
    data = hr['PCT_CALCITE'][i,:,:].where(mask_cectot)
    data.plot(ax = ax, cmap = 'Blues')
    ax.set_title(
        f'{np.nanmin(data.values):.4f}-{np.nanmax(data.values):.4f}'
    )
# the last two panels are unused
for spare in axes.flat[10:]:
    spare.axis('off')

In [None]:
# Bare expression: display the 'TOPO' variable of `hr` as the cell output
hr['TOPO']

In [None]:
# (6) check if after the exclusion based on CEC_TOT, PCT_KAOLINITE behaves okay
# No!
# One panel per soil layer; each title shows the min-max over the retained cells.
# (The former leading `ax.imshow(mask_temp)` drew on an `ax` left over from an
# earlier cell's figure, not on this one, so it was removed.)
fig, axes = plt.subplots(4, 3, figsize = (20, 15))
for i in range(10):
    ax = axes.flat[i]
    ax.imshow(mask_cectot, cmap = 'Reds')
    data = hr['PCT_KAOLINITE'][i,:,:].where(mask_cectot)
    data.plot(ax = ax, cmap = 'Blues')
    ax.set_title(
        f'{np.nanmin(data.values):.4f}-{np.nanmax(data.values):.4f}'
    )
# the last two panels are unused
for spare in axes.flat[10:]:
    spare.axis('off')