# Create final output datasets

In [None]:
# Libraries
import os, time, datetime
import xarray as xr

In [None]:
# Input/output directories of the processing steps (all below one output root)
dir01, dir02, dir03, dir04 = (
    f'../paper_deficit/output/{step}/'
    for step in ('01_prep', '02_dbase', '03_rf', '04_out')
)

---

In [None]:
# Libraries
from dask_jobqueue import SLURMCluster
from dask.distributed import Client
import dask

# Initialize a dask cluster whose workers are submitted as SLURM jobs
cluster = SLURMCluster(
    queue='compute',                      # SLURM queue to use
    cores=48,                             # Number of CPU cores per job
    memory='256 GB',                      # Memory per job
    account='bm0891',                     # Account allocation
    interface="ib0",                      # Network interface for communication
    walltime='02:00:00',                  # Maximum runtime per job
    local_directory='../dask/',           # Directory for local storage
    job_extra_directives=[                # Additional SLURM directives for logging
        '-o ../dask/LOG_worker_%j.o',     # Output log
        '-e ../dask/LOG_worker_%j.e'      # Error log
    ]
)

# Scale the dask cluster to five SLURM jobs
cluster.scale(jobs=5)

# Configure the dashboard URL so the link goes through the JupyterHub proxy
dask.config.config.get('distributed').get('dashboard').update(
    {'link': '{JUPYTERHUB_SERVICE_PREFIX}/proxy/{port}/status'}
)

# Create a client connected to the cluster
client = Client(cluster)

# Show the client summary as the cell output
client

In [None]:
def get_data_rf(var_tar, scen, rf):
    """
    Retrieve predicted carbon data from the zarr prediction stores in `dir03`.

    Args:
        var_tar (str): The target variable (e.g., 'agbc_min').
        scen (str): The scenario ('prim' or 'secd').
        rf (str): The prediction variable to read. 'rfaaa_mean' is read from
            the rfaaa-file in 'files_adjusted'; 'qrfr_005' and 'qrfr_095'
            are read from the file in 'files_predicted'.

    Returns:
        xr.DataArray: DataArray containing modeled carbon data.

    Raises:
        ValueError: If `rf` is not one of the supported prediction variables.
    """
    # Map each supported prediction variable to (sub-directory, store name)
    stem = f'ds_rfpred_{var_tar}_{scen}'
    sources = {
        'rfaaa_mean': ('files_adjusted', f'{stem}_rfaaa.zarr'),
        'qrfr_005': ('files_predicted', f'{stem}.zarr'),
        'qrfr_095': ('files_predicted', f'{stem}.zarr'),
    }

    # Fail early with a clear message instead of an unbound `file_in`
    if rf not in sources:
        raise ValueError(
            f"Unknown rf {rf!r}; expected one of {sorted(sources)}"
        )

    file_in = os.path.join(dir03, *sources[rf])

    da = xr.open_zarr(file_in)[rf]
    # All existing attributes are replaced; only the fill value is kept
    da.attrs = dict(_FillValue=-32768)
    return da


def get_data_orig(var_tar, dir02):
    """
    Load the actual carbon data of one target variable as int16.

    The carbon type is taken from the part of `var_tar` before the first
    underscore and selects the store 'ds_prep_<ctype>.zarr' in `dir02`.

    Args:
        var_tar (str): The target variable (e.g., 'agbc_mean').
        dir02 (str): Directory path where actual carbon data files are located.

    Returns:
        xr.DataArray: The original carbon data, with missing values set to
            -32768, rounded and cast to int16.
    """
    carbon_type = var_tar.split('_')[0]
    store = os.path.join(dir02, f'ds_prep_{carbon_type}.zarr')
    raw = xr.open_zarr(store)[var_tar]

    # Missing values become the fill value before rounding and casting
    filled = raw.fillna(-32768)
    out = filled.round(0).astype('int16')

    # All existing attributes are replaced; only the fill value is kept
    out.attrs = {'_FillValue': -32768}
    return out


def get_data_area(dir01):
    """
    Load the grid cell area from the store 'ds_prep_area_ha.zarr'.

    Args:
        dir01 (str): Directory path where area data files are located.

    Returns:
        xr.DataArray: The 'area_ha' variable; its attributes are replaced
            by a single '_FillValue' entry.
    """
    store = os.path.join(dir01, 'ds_prep_area_ha.zarr')
    area = xr.open_zarr(store)['area_ha']

    # NOTE(review): the fill value is recorded as the string 'NaN'
    area.attrs = {'_FillValue': 'NaN'}
    return area

In [None]:
def create_ds_ctype(ctype, complevel):
    """
    Create and export the output datasets for a specific carbon type.

    The dataset combines the actual carbon data, the modeled carbon data of
    both scenarios (mean prediction plus 0.05 and 0.95 quantiles) and the
    grid cell area. After the metadata is attached, two NetCDF files are
    written to `dir04`: '<ctype>.nc' with the main variables and
    '<ctype>_unc.nc' with the quantile variables.

    Args:
        ctype (str): The carbon type ('agbc', 'bgbc', or 'soc').
        complevel (int): Compression level (0-9): 1=best speed; 9=best compression

    Raises:
        ValueError: If `ctype` is not one of the supported carbon types.
    """
    # Long-name components, keyed by the parts of the variable names
    ctype_labels = {
        'agbc': 'Above ground biomass carbon',
        'bgbc': 'Below ground biomass carbon',
        'soc': 'Soil organic carbon 0-30 cm',
    }
    metric_labels = {
        'min': 'Minimum scenario',
        'mean': 'Mean scenario',
        'max': 'Maximum scenario',
    }
    scen_labels = {
        'act': 'Actual',
        'prim': 'Pristine land assumption',
        'secd': 'Low human influence assumption',
    }
    quantile_labels = {
        'q005': '0.05 quantile',
        'q095': '0.95 quantile',
    }

    if ctype not in ctype_labels:
        raise ValueError(
            f"Unknown carbon type {ctype!r}; expected one of "
            f"{sorted(ctype_labels)}"
        )
    str_ln_ctype = ctype_labels[ctype]
    vars_tar = [f'{ctype}_min', f'{ctype}_mean', f'{ctype}_max']

    ds = xr.Dataset()  # Create an empty xarray Dataset

    # Add actual carbon data variables to the dataset
    for var_tar in vars_tar:
        ds[f'{var_tar}_act'] = get_data_orig(var_tar, dir02)

    # Add modeled carbon data variables for each scenario to the dataset
    for scen in ['prim', 'secd']:
        for var_tar in vars_tar:
            ds[f'{var_tar}_{scen}'] = get_data_rf(var_tar, scen, 'rfaaa_mean')
            ds[f'{var_tar}_{scen}_q005'] = get_data_rf(var_tar, scen, 'qrfr_005')
            ds[f'{var_tar}_{scen}_q095'] = get_data_rf(var_tar, scen, 'qrfr_095')

    # Add area data to the dataset
    ds['area_ha'] = get_data_area(dir01).astype('float32')

    # Update latitude attributes (latitude is the Y axis)
    ds['lat'].attrs = dict(long_name="latitude",
                           units="degrees_north",
                           standard_name="latitude",
                           axis='Y')

    # Update longitude attributes (longitude is the X axis)
    ds['lon'].attrs = dict(long_name="longitude",
                           units="degrees_east",
                           standard_name="longitude",
                           axis='X')

    # Update attributes of carbon variables based on type, metric, and scenario
    for name in list(ds.data_vars):
        if not name.startswith(ctype):
            continue

        # Variable names follow '<ctype>_<metric>_<scenario>[_<quantile>]'
        parts = name.split('_')
        labels = [str_ln_ctype, metric_labels[parts[1]], scen_labels[parts[2]]]
        if len(parts) > 3:
            labels.append(quantile_labels[parts[3]])

        # NOTE(review): the attribute key is 'unit' here, while the
        # coordinates use 'units' -- kept unchanged; confirm which is intended.
        ds[name].attrs.update(dict(standard_name=name,
                                   long_name='; '.join(labels),
                                   unit='Ct ha-1'))

    # Set dataset title to the carbon type description
    ds.attrs.update(dict(title=str_ln_ctype))

    # Add attributes for the area variable
    ds['area_ha'].attrs.update(dict(standard_name='area_ha',
                                    long_name='Grid cell area',
                                    unit='ha'))

    # Load dataset to memory before computing the value ranges, so the
    # min/max reductions below reuse the loaded data instead of re-reading
    # the inputs for every single reduction
    ds = ds.persist()

    # Add valid_min and valid_max attributes to each data variable
    for name in ds.data_vars:
        da = ds[name]
        da_sel = da.where(da != da.attrs.get('_FillValue'))
        da.attrs.update(dict(valid_min=da_sel.min().compute().item(),
                             valid_max=da_sel.max().compute().item()))

    # Add general dataset attributes; the timestamp is made timezone-aware
    # so that the '%z' directive actually emits the UTC offset
    ds.attrs.update(dict(
        institution='Department of Geography, Ludwig-Maximilians-Universität München, Munich, Germany',
        author='Raphael Ganzenmüller',
        version='1.0',
        date=f'{datetime.datetime.now().astimezone():%Y-%m-%d %H:%M:%S%z}'
    ))

    def export_data(ds_out, dir_out, file_out, complevel=4):
        """Write `ds_out` to a zlib-compressed NetCDF file in `dir_out`."""
        # Same compression settings for every data variable
        encoding = {var: dict(zlib=True, complevel=complevel) for
                    var in ds_out.data_vars}
        ds_out.to_netcdf(os.path.join(dir_out, file_out),
                         engine='netcdf4', encoding=encoding, mode='w')

    # Get quantile variables
    l_q0xx = [i for i in ds.data_vars if i.endswith(('q005', 'q095'))]

    # Export main file, using the compression level requested by the caller
    export_data(ds.drop_vars(l_q0xx), dir04, f'{ctype}.nc', complevel)

    # Export uncertainty file
    export_data(ds[l_q0xx], dir04, f'{ctype}_unc.nc', complevel)

In [None]:
# Create and export the datasets of each carbon type; %time reports the run time
%time create_ds_ctype('agbc', complevel=1)
%time create_ds_ctype('bgbc', complevel=1)
%time create_ds_ctype('soc', complevel=1)

---

### Check

In [None]:
import matplotlib.pyplot as plt

In [None]:
def get_ds_out(file_out):
    """
    Open an exported file from `dir04` without CF decoding.

    Args:
        file_out (str): File name inside `dir04` (e.g., 'agbc.nc').

    Returns:
        xr.Dataset: The dataset, opened with automatic chunking.
    """
    path_out = os.path.join(dir04, file_out)
    return xr.open_dataset(path_out, decode_cf=False, chunks='auto')

def get_carbon_sum(ds):
    """
    Sum every variable, weighted by grid cell area, over 'lat' and 'lon'.

    Values equal to -32768 are masked first and the sums are scaled by 1E-09.
    The product is taken over the whole dataset, so 'area_ha' itself is
    multiplied by the area as well.

    Args:
        ds (xr.Dataset): Dataset holding the variables and 'area_ha'.

    Returns:
        xr.Dataset: The computed sums per variable.
    """
    masked = ds.where(ds != -32768)
    weighted = masked * ds.area_ha
    totals = weighted.sum(['lat', 'lon']) * 1E-09
    return totals.compute()

In [None]:
# Open the exported agbc files (main and uncertainty)
ds_agbc = get_ds_out('agbc.nc')
ds_agbc_unc = get_ds_out('agbc_unc.nc')

# Print both dataset summaries and the area-weighted sums of the main file
print(ds_agbc)
print(get_carbon_sum(ds_agbc))
print(ds_agbc_unc)

In [None]:
# Open the exported bgbc files (main and uncertainty)
ds_bgbc = get_ds_out('bgbc.nc')
ds_bgbc_unc = get_ds_out('bgbc_unc.nc')

# Print both dataset summaries and the area-weighted sums of the main file
print(ds_bgbc)
print(get_carbon_sum(ds_bgbc))
print(ds_bgbc_unc)

In [None]:
# Open the exported soc files (main and uncertainty)
ds_soc = get_ds_out('soc.nc')
ds_soc_unc = get_ds_out('soc_unc.nc')

# Print both dataset summaries and the area-weighted sums of the main file
print(ds_soc)
print(get_carbon_sum(ds_soc))
print(ds_soc_unc)

In [None]:
def plot_main(ds):
    """
    Plot the carbon variables of a main output dataset on a 3x3 panel grid.

    Values equal to -32768 are masked before plotting and 'area_ha' is
    skipped. At most nine variables are drawn; if the dataset holds fewer,
    the remaining panels stay empty.

    Args:
        ds (xr.Dataset): Main output dataset (e.g., from `get_ds_out`).
    """
    ds = ds.where(ds != -32768).persist()

    _, axs = plt.subplots(figsize=(20, 12), ncols=3, nrows=3)

    # Build the variable list once; zip stops at the shorter of panels/variables
    vars_plot = [v for v in ds.data_vars if v != 'area_ha']
    for ax, var_tar in zip(axs.ravel(), vars_plot):
        ds[var_tar].plot.imshow(ax=ax, robust=True)
        ax.set_title(var_tar)

    plt.tight_layout()


def plot_unc(ds):
    """
    Plot the variables of an uncertainty dataset on a 4x3 panel grid.

    Values equal to -32768 are masked before plotting. At most twelve
    variables are drawn; if the dataset holds fewer, the remaining panels
    stay empty.

    Args:
        ds (xr.Dataset): Uncertainty output dataset (e.g., from `get_ds_out`).
    """
    ds = ds.where(ds != -32768).persist()

    _, axs = plt.subplots(figsize=(20, 15), ncols=3, nrows=4)

    # zip stops at the shorter of panels/variables, so no index can overrun
    for ax, v in zip(axs.ravel(), ds.data_vars):
        ds[v].plot.imshow(ax=ax, robust=True)
        ax.set_title(v)

    plt.tight_layout()

In [None]:
# Visual check of the main agbc file
plot_main(ds_agbc)

In [None]:
# Visual check of the agbc uncertainty file
plot_unc(ds_agbc_unc)

In [None]:
# Visual check of the main bgbc file
plot_main(ds_bgbc)

In [None]:
# Visual check of the bgbc uncertainty file
plot_unc(ds_bgbc_unc)

In [None]:
# Visual check of the main soc file
plot_main(ds_soc)

In [None]:
# Visual check of the soc uncertainty file
plot_unc(ds_soc_unc)

In [None]:
# Shut down the dask cluster created at the top of the notebook
cluster.close()