# Prepare soilgrids data (Poggio, Hengl)

In [None]:
# Libraries
import os, time, shutil, rioxarray
import numpy as np
import xarray as xr
from scipy import ndimage
import matplotlib.pyplot as plt

In [None]:
# Directories (relative paths, resolved against the notebook's working directory)
# dir_data: root folder of the raw input data read by the prep functions below
dir_data =  '../data/'
# dir01: folder that receives the prepared GeoTIFF and Zarr outputs of this step
dir01 = '../paper_deficit/output/01_prep/'

---

### Pre-processing

In [None]:
# Libraries
from dask_jobqueue import SLURMCluster
from dask.distributed import Client
import dask

# Initialize dask on SLURM
cluster = SLURMCluster(
    queue='compute',                      # SLURM queue to use
    cores=24,                             # Number of CPU cores per job
    memory='256 GB',                      # Memory per job
    account='bm0891',                     # Account allocation
    interface="ib0",                      # Network interface for communication
    walltime='01:00:00',                  # Maximum runtime per job
    local_directory='../dask/',           # Directory for local storage
    job_extra_directives=[                # Additional SLURM directives for logging
        '-o ../dask/LOG_worker_%j.o',     # Output log
        '-e ../dask/LOG_worker_%j.e'      # Error log
    ]
)

# Scale dask cluster to 10 SLURM jobs
cluster.scale(jobs=10)

# Configure the dashboard url through the public dask config API instead of
# mutating the internal ``dask.config.config`` dictionary in place (which
# fails with an AttributeError if one of the nested keys is missing)
dask.config.set(
    {'distributed.dashboard.link':
         '{JUPYTERHUB_SERVICE_PREFIX}/proxy/{port}/status'}
)

# Create client
client = Client(cluster)

# Last expression of the cell: displays the client summary in the notebook
client

In [None]:
# Function to trim memory of workers
def trim_memory() -> int:
    """Ask the C library of the calling process to release free heap memory.

    Loads ``libc.so.6`` and calls ``malloc_trim(0)``; the integer returned by
    that call is passed through unchanged.
    """
    from ctypes import CDLL

    return CDLL("libc.so.6").malloc_trim(0)

In [None]:
def prep_soilgrids2020():
    """Load and prepare soilgrids 2020 data for regridding.

    For each variable the 0-5cm, 5-15cm and 15-30cm mean layers are combined
    into a thickness-weighted 0-30cm mean, rounded, stored as int16 and
    written to ``dir01`` as ``soilgrids2020_<var>.tif``. Cells that are
    missing in any of the three layers receive the fill value of the source
    file.

    Uses the notebook-level names ``dir_data``, ``dir01``, ``client`` and
    ``trim_memory``.
    """

    def _func_soilgrids2020_load(var, depth):
        """Load one soilgrids layer with fill-value cells set to nan"""
        # File
        file_path = (dir_data + 'soilgrids/soilgrids2020/' + var + '/' + var +
                     '_' + depth + '_mean.tif')
        # Load file lazily in large chunks
        da = rioxarray.open_rasterio(file_path,
                                     chunks=dict(y=10000, x=10000))
        # Set non-land to nan and return
        return da.where(da != da.attrs.get('_FillValue'))

    # Prepare soilgrids 2020 data
    for var in ['bdod', 'cec', 'cfvo', 'clay', 'nitrogen', 'ocd', 'phh2o',
                'sand', 'silt', 'soc']:
        # Read data
        a = _func_soilgrids2020_load(var, '0-5cm')
        b = _func_soilgrids2020_load(var, '5-15cm')
        c = _func_soilgrids2020_load(var, '15-30cm')
        fill_value = a.attrs.get('_FillValue')

        # Average over the layers (weights = layer thickness in cm) and round
        da = ((a*5 + b*10 + c*15) / 30).round(0)
        # Replace nan by the fill value BEFORE the integer cast: casting nan
        # to int16 is undefined, and nan can originate from any of the three
        # layers, not only from the top one
        da_out = da.fillna(fill_value).astype('int16')
        # Add attributes (only the nodata/scaling attributes are carried over)
        da_out.spatial_ref.attrs = a.spatial_ref.attrs
        da_out.attrs = {attr: a.attrs.get(attr)
                        for attr in ['_FillValue', 'scale_factor',
                                     'add_offset']}
        # Export
        da_out.rio.to_raster(dir01 + 'soilgrids2020_' + var + '.tif')
        # Release memory on the dask workers
        client.run(trim_memory)


%time prep_soilgrids2020()

In [None]:
def prep_soilgrids2020_ocs():
    """Prepare the soilgrids 2020 organic carbon stock rasters for regridding.

    The 0-30cm mean and uncertainty files are opened lazily, renamed, stripped
    of their ``spatial_ref`` coordinate and written to ``dir01``.
    """
    # in soilgrids:
    # ocs = organic carbon stocks (t/ha);
    # soc = Soil organic carbon content in the fine earth fraction (dg/kg)
    src_dir = dir_data + 'soilgrids/soilgrids2020/ocs/'
    tile = dict(y=10000, x=10000)

    # (source file, name given to the data array, output file)
    jobs = [
        ('ocs_0-30cm_mean.tif',
         'sgrids_ocs_2020_30cm',
         'soilgrids2020_ocs_0-30cm.tif'),
        ('ocs_0-30cm_uncertainty.tif',
         'sgrids_ocsunc_2020_30cm',
         'soilgrids2020_ocsunc_0-30cm.tif'),
    ]

    for src_name, da_name, out_name in jobs:
        raster = rioxarray.open_rasterio(src_dir + src_name, chunks=tile)
        raster = raster.rename(da_name).drop_vars('spatial_ref')
        raster.rio.to_raster(dir01 + out_name)


%time prep_soilgrids2020_ocs()

In [None]:
def prep_soilgrids2017_ocs():
    """Prepare the soilgrids 2017 organic carbon stock layers for regridding.

    The six ``OCSTHA_M_sd*`` layers are summed into three outputs
    (sd1+sd2+sd3 -> 0-30cm, sd4+sd5 -> 30-100cm, sd6 -> 100-200cm), converted
    to int16 with -32768 as fill value and written to ``dir01``.
    """
    nodata = -32768

    def _load_layer(sdx):
        """Open the first band of one layer lazily; fill value becomes nan"""
        path = (dir_data + 'soilgrids/soilgrids2017/OCSTHA_M_' + sdx +
                '_250m_ll.tif')
        band = rioxarray.open_rasterio(
            path, chunks=dict(y=10000, x=10000))[0]
        return band.where(band != band.attrs.get('_FillValue'))

    def _to_int16(layer):
        """Replace nan by the fill value, round and cast to int16"""
        out = layer.fillna(nodata).round(0).astype('int16')
        out.attrs = dict(_FillValue=nodata,
                         scale_factor=1.0,
                         add_offset=0.0)
        return out

    # Load all six layers
    sd = {key: _load_layer(key)
          for key in ('sd1', 'sd2', 'sd3', 'sd4', 'sd5', 'sd6')}

    # Output name -> (lazy) sum of the contributing layers
    outputs = {
        'soilgrids2017_ocs_0-30cm': sd['sd1'] + sd['sd2'] + sd['sd3'],
        'soilgrids2017_ocs_30-100cm': sd['sd4'] + sd['sd5'],
        'soilgrids2017_ocs_100-200cm': sd['sd6'],
    }

    # Prepare and export
    for out_name, layer_sum in outputs.items():
        prepared = _to_int16(layer_sum).rename(out_name)
        prepared.rio.to_raster(dir01 + out_name + '.tif')


%time prep_soilgrids2017_ocs()

In [None]:
def prep_soilgrids2017_other():
    """Prepare the soilgrids 2017 bdricm, bdrlog and bdticm rasters for
    regridding.

    Each file is opened lazily, renamed to ``sgrids_<name>``, stripped of its
    ``spatial_ref`` coordinate and written to ``dir01``.
    """
    src_dir = dir_data + 'soilgrids/soilgrids2017/'

    for name in ('bdricm', 'bdrlog', 'bdticm'):
        # Source files carry the upper-case variable name
        raster = rioxarray.open_rasterio(
            f'{src_dir}{name.upper()}_M_250m_ll.tif',
            chunks=dict(y=10000, x=10000))
        raster = raster.rename('sgrids_' + name).drop_vars('spatial_ref')
        raster.rio.to_raster(f'{dir01}soilgrids2017_{name}.tif')


%time prep_soilgrids2017_other()

In [None]:
# Close dask client and cluster
# (disconnect the client first so it does not keep trying to reach a
# scheduler that has already been shut down)
client.close()
cluster.close()

In [None]:
# Wait for 60s for dask client to completely disconnect
# NOTE(review): fixed pause; presumably needed because regrid_high_res (called
# further down with account/partition arguments) sets up its own workers -
# confirm
time.sleep(60)

---

In [None]:
# Import regridding function
from regrid_high_res_v1_01 import regrid_high_res, prep_tif

In [None]:
def regrid_da(f_source, dir_target, dir_source, dir_out,
              size_tiles, fill_value=None, olap=1):
    """Regrid a large source raster onto the target grid.

    Thin wrapper around ``prep_tif`` and ``regrid_high_res``: both rasters are
    prepared from their .tif files and handed to ``regrid_high_res`` with a
    fixed account ('bm0891'), partition ('compute'), zarr export and
    intermediate files kept.

    Args:
        f_source (str): Filename (without extension) of the source .tif file
            to be regridded.
        dir_target (str): Directory containing the target grid file
            'target_grid.tif'.
        dir_source (str): Directory containing the source .tif file.
        dir_out (str): Directory to store the output and intermediate files.
        size_tiles (int): Size of the regridding tiles in degrees.
        fill_value (float, optional): Fill value to use in the regridding
            process. Defaults to None.
        olap (int, optional): Overlap size in degrees for regridding tiles.
            Defaults to 1.

    Returns:
        None. Whatever ``regrid_high_res`` returns is not passed on; the
        result is only available through the files it exports.
    """
    target_path = dir_target + 'target_grid.tif'
    source_path = dir_source + f_source + '.tif'

    # Prepare the target and source data arrays from the TIFF files
    da_target = prep_tif(target_path, 'target_grid')
    da_source = prep_tif(source_path, f_source)

    # Regrid source array to target grid
    regrid_high_res(
        da_target, da_source, dir_out,
        account='bm0891',
        partition='compute',
        size_tiles=size_tiles,
        olap=olap,
        fill_value=fill_value,
        type_export='zarr',
        del_interm=False,
    )

In [None]:
# Regridding soilgrids 2020
# Positional arguments of regrid_da: source name, target dir, source dir,
# output dir, size_tiles (20), fill_value, olap (0.1); all directories are
# dir01
for i in ['soilgrids2020_cec', 'soilgrids2020_phh2o', 'soilgrids2020_sand',
          'soilgrids2020_soc', 'soilgrids2020_clay', 'soilgrids2020_ocs_0-30cm',
          'soilgrids2020_nitrogen', 'soilgrids2020_cfvo', 'soilgrids2020_ocd',
          'soilgrids2020_bdod', 'soilgrids2020_silt']:
    %time regrid_da(i, dir01, dir01, dir01, 20, -32768, 0.1)

# The uncertainty layer is regridded separately because it is passed a
# different fill value (65535 instead of -32768)
%time regrid_da('soilgrids2020_ocsunc_0-30cm', dir01, dir01, dir01, \
                20, 65535, 0.1)
# Takes about 9-13 min for one file

In [None]:
# Regridding soilgrids 2017
# Same call pattern as regrid_da(f_source, dir_target, dir_source, dir_out,
# size_tiles, fill_value, olap); the two loops differ only in the fill value
# Layers regridded with fill value -32768
for i in ['soilgrids2017_ocs_0-30cm', 'soilgrids2017_ocs_30-100cm',
          'soilgrids2017_ocs_100-200cm', 'soilgrids2017_bdticm']:
    %time regrid_da(i, dir01, dir01, dir01, 20, -32768, 0.1)

# Layers regridded with fill value 255
for i in ['soilgrids2017_bdricm', 'soilgrids2017_bdrlog']:
    %time regrid_da(i, dir01, dir01, dir01, 20, 255, 0.1)
# Takes about 9-13 min for one file

---

### Fill nans

In [None]:
def fill_nans(var, dir_out):
    """
    Fill NaN cells of a regridded variable with the nearest valid value,
    apply the land mask and export the result to a new Zarr dataset.

    Reads ``ds_regridded_<var>.zarr`` (data variable ``regridded_<var>``) and
    ``ds_prep_copernicus_land_mask.zarr`` from ``dir_out`` and writes
    ``ds_prep_<var>.zarr`` (data variable ``<var>``) to the same directory,
    overwriting an existing store.

    Args:
        var (str): Name of the variable to process
            (e.g., 'soilgrids2020_cec').
        dir_out (str): Directory where the regridded data and the land mask
            are stored and where the filled dataset will be exported.

    Returns:
        None
    """

    def fill_nans_array(data, invalid):
        """
        Replace invalid (NaN) data cells by the value of the nearest valid
        data cell (Euclidean distance measured in array index space).
        """
        ind = ndimage.distance_transform_edt(invalid,
                                             return_distances=False,
                                             return_indices=True)
        return data[tuple(ind)]

    # Paths for input and output
    land_mask_path = os.path.join(dir_out, 'ds_prep_copernicus_land_mask.zarr')
    var_data_path = os.path.join(dir_out, f'ds_regridded_{var}.zarr')
    output_path = os.path.join(dir_out, f'ds_prep_{var}.zarr')

    # Read land mask data into memory (no rechunking needed: the array is
    # computed right away)
    da_land = xr.open_zarr(land_mask_path) \
                .copernicus_land_mask \
                .compute()

    # Read variable data into memory ONCE; the nan check, the mask and the
    # values below would otherwise each trigger a separate read of the store
    da_var = xr.open_zarr(var_data_path)['regridded_' + var].load()
    values = da_var.values
    invalid = da_var.isnull().values

    # Fill nan using function fill_nans_array
    # If there are no NaNs, skip filling process
    if invalid.any():
        values = fill_nans_array(values, invalid)

    # Create a new Dataset with filled data
    ds_filled = xr.Dataset(dict(lat=da_var.lat, lon=da_var.lon))
    ds_filled[var] = (('lat', 'lon'), values)

    # Apply land mask to the filled data
    ds_filled = ds_filled.where(da_land)

    # Export the filled dataset to Zarr format
    ds_filled.chunk(dict(lat=5000, lon=5000)) \
             .to_zarr(output_path, mode='w')

In [None]:
%%time
# Fill nans and apply the land mask for every regridded soilgrids 2020
# variable; input and output stores both live in dir01
for i in ['soilgrids2020_cec', 'soilgrids2020_phh2o', 'soilgrids2020_sand',
          'soilgrids2020_soc', 'soilgrids2020_clay', 'soilgrids2020_ocs_0-30cm',
          'soilgrids2020_nitrogen', 'soilgrids2020_cfvo', 'soilgrids2020_ocd',
          'soilgrids2020_bdod', 'soilgrids2020_silt', 
          'soilgrids2020_ocsunc_0-30cm']:
    %time fill_nans(i, dir01)

In [None]:
%%time
# Fill nans and apply the land mask for every regridded soilgrids 2017
# variable; input and output stores both live in dir01
for i in ['soilgrids2017_ocs_0-30cm', 'soilgrids2017_ocs_30-100cm',
          'soilgrids2017_ocs_100-200cm', 'soilgrids2017_bdticm',
          'soilgrids2017_bdricm', 'soilgrids2017_bdrlog']:
    %time fill_nans(i, dir01)

---

### Check

In [None]:
# Visual check: one map per prepared soilgrids 2020 variable
for i in (
    'soilgrids2020_cec', 'soilgrids2020_phh2o', 'soilgrids2020_sand',
    'soilgrids2020_soc', 'soilgrids2020_clay', 'soilgrids2020_ocs_0-30cm',
    'soilgrids2020_nitrogen', 'soilgrids2020_cfvo', 'soilgrids2020_ocd',
    'soilgrids2020_bdod', 'soilgrids2020_silt',
    'soilgrids2020_ocsunc_0-30cm',
):
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))
    # Open the prepared store and draw the variable with a robust color range
    xr.open_zarr(f'{dir01}ds_prep_{i}.zarr')[i].plot.imshow(
        ax=ax, robust=True)
    ax.set_title(i)

In [None]:
%%time
# Visual check: one map per prepared soilgrids 2017 variable
for i in (
    'soilgrids2017_ocs_0-30cm', 'soilgrids2017_ocs_30-100cm',
    'soilgrids2017_ocs_100-200cm', 'soilgrids2017_bdticm',
    'soilgrids2017_bdricm', 'soilgrids2017_bdrlog',
):
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))
    # Open the prepared store and draw the variable with a robust color range
    xr.open_zarr(f'{dir01}ds_prep_{i}.zarr')[i].plot.imshow(
        ax=ax, robust=True)
    ax.set_title(i)