# Convert predictions to dataarray

In [None]:
# Libraries
import os, shutil
import dask.dataframe as dd
import xarray as xr

In [None]:
# Directories (all relative to the common output root of the project)
_dir_output = '../paper_deficit/output/'
# Prepared input data
dir01 = os.path.join(_dir_output, '01_prep/')
# Random forest results and the predicted files within them
dir03 = os.path.join(_dir_output, '03_rf/')
dir03p = os.path.join(dir03, 'files_predicted/')

---

In [None]:
# Reference dataset whose original lat/lon coordinates are used for reindexing
file_worldclim_elev = os.path.join(dir01, 'ds_prep_worldclim_elev.zarr')
ds_worldclim_elev = xr.open_zarr(file_worldclim_elev)

In [None]:
def df_rfpred2ds(var_tar, prim):

    """
    Convert the predictions of one target variable and scenario from
    parquet files to a single zarr dataset.

    Reads ``df_rfpred_{var_tar}_{prim}.parquet`` (random forest) and
    ``df_rfqpred_{var_tar}_{prim}.parquet`` (quantile random forest) from
    ``dir03p``, calculates the row-wise mean of the columns ``rfr_1`` to
    ``rfr_10``, writes every output variable to an intermediate int16 zarr
    store reindexed to the lat/lon of ``ds_worldclim_elev``, merges these
    stores into ``ds_rfpred_{var_tar}_{prim}.zarr`` and finally deletes the
    intermediate stores.

    var_tar: target variable used in the file names (e.g. 'agbc_mean')
    prim: scenario used in the file names (the notebook passes 'prim'
        or 'secd')

    Returns None; the result is written to ``dir03p``.
    """

    # Use the scenario passed by the caller. Without this binding the name
    # `scen` is looked up in the global scope, so the argument is ignored.
    scen = prim

    # Value marking missing cells in the int16 output
    nodata = -32768

    # Output variables: mean of rf predictions plus quantile rf predictions
    vars_qrf = ['qrfr_005', 'qrfr_010', 'qrfr_090', 'qrfr_095']
    vars_out = ['rfr_mean'] + vars_qrf

    # Get rf data
    file_rfpred = os.path.join(dir03p, f"df_rfpred_{var_tar}_{scen}.parquet")
    df_rfpred = dd.read_parquet(file_rfpred)
    # Get quantile rf data
    file_rfqpred = os.path.join(dir03p, f"df_rfqpred_{var_tar}_{scen}.parquet")
    df_rfqpred = dd.read_parquet(file_rfqpred)

    # Calculate mean values for each grid cell from rf prediction
    rfr_cols = [f"rfr_{i}" for i in range(1, 11)]
    df_rfpred = df_rfpred.assign(rfr_mean=df_rfpred[rfr_cols].mean(axis=1))

    def file_interm(var):
        """Path of the intermediate zarr store of one output variable."""
        return os.path.join(dir03p, f"ds_interm_{var_tar}_{scen}_{var}.zarr")

    # Transform to dataset and export as intermediate file
    def df2ds(df, var):
        """Transform one variable of the dataframe to a dataset variable,
        ensure correct lat and lon (reindex_like) and export as intermediate
        file
        """

        # Transform and export
        (
            df[['lat', 'lon', var]]
            .compute()
            .set_index(['lat', 'lon'])
            .to_xarray()
            .chunk(dict(lat=5000, lon=5000))
            .sortby('lat', ascending=False)
            # Missing values have to be replaced before the cast to int16
            .fillna(nodata)
            .round(0)
            .astype('int16')
            # Cells absent from the predictions get the nodata value
            .reindex_like(ds_worldclim_elev, method=None, fill_value=nodata)
            .chunk(dict(lat=5000, lon=5000))
            .to_zarr(file_interm(var), mode='w')
        )

    # Transform Random Forest predictions
    df2ds(df_rfpred, 'rfr_mean')
    # Transform Quantile Random Forest predictions
    for var in vars_qrf:
        df2ds(df_rfqpred, var)

    # Import datasets, merge and export as zarr
    ds_merged = xr.open_mfdataset(
        os.path.join(dir03p, f"ds_interm_{var_tar}_{scen}_*.zarr"),
        engine='zarr')
    ds_merged.to_zarr(
        os.path.join(dir03p, f"ds_rfpred_{var_tar}_{scen}.zarr"),
        mode='w')

    # Delete intermediate datasets
    for var in vars_out:
        shutil.rmtree(file_interm(var))

In [None]:
# Convert the predictions of the agbc targets (agbc_min, agbc_mean, agbc_max)
# for both scenarios ('prim', 'secd') and export each one as zarr file.
# `%time` is an IPython magic that reports the run time of each call, so this
# cell only runs in IPython/Jupyter.
for var_tar in ['agbc_min', 'agbc_mean', 'agbc_max']:
    for scen in ['prim', 'secd']:
        %time df_rfpred2ds(var_tar, scen)

In [None]:
# Convert the predictions of the bgbc targets (bgbc_min, bgbc_mean, bgbc_max)
# for both scenarios ('prim', 'secd') and export each one as zarr file.
# `%time` is an IPython magic that reports the run time of each call, so this
# cell only runs in IPython/Jupyter.
for var_tar in ['bgbc_min', 'bgbc_mean', 'bgbc_max']:
    for scen in ['prim', 'secd']:
        %time df_rfpred2ds(var_tar, scen)

In [None]:
# Convert the predictions of the soc targets (soc_min, soc_mean, soc_max)
# for both scenarios ('prim', 'secd') and export each one as zarr file.
# `%time` is an IPython magic that reports the run time of each call, so this
# cell only runs in IPython/Jupyter.
for var_tar in ['soc_min', 'soc_mean', 'soc_max']:
    for scen in ['prim', 'secd']:
        %time df_rfpred2ds(var_tar, scen)

---

### Check

In [None]:
def plot_check(var_tar, scen):
    """Plot rfr_mean of the exported zarr dataset of one target variable
    and scenario as a quick visual check; cells equal to -32768 are masked.
    """
    # Open the merged prediction dataset
    file_rfpred = os.path.join(dir03p, f"ds_rfpred_{var_tar}_{scen}.zarr")
    ds_pred = xr.open_zarr(file_rfpred)

    # Mask the nodata value so it does not distort the plot
    ds_masked = ds_pred.where(ds_pred != -32768)

    # robust=True is passed through to xarray's imshow unchanged
    ds_masked.rfr_mean.plot.imshow(robust=True)

In [None]:
# Visual check of rfr_mean for target 'agbc_max', scenario 'prim'
plot_check('agbc_max', 'prim')

In [None]:
# Visual check of rfr_mean for target 'agbc_max', scenario 'secd'
plot_check('agbc_max', 'secd')

In [None]:
# Visual check of rfr_mean for target 'bgbc_max', scenario 'prim'
plot_check('bgbc_max', 'prim')

In [None]:
# Visual check of rfr_mean for target 'bgbc_max', scenario 'secd'
plot_check('bgbc_max', 'secd')

In [None]:
# Visual check of rfr_mean for target 'soc_mean', scenario 'prim'
plot_check('soc_mean', 'prim')

In [None]:
# Visual check of rfr_mean for target 'soc_mean', scenario 'secd'
plot_check('soc_mean', 'secd')