# Adjust random forest predictions to resolve internal inconsistencies

In [None]:
# Libraries
import os, shutil
import xarray as xr

In [None]:
# Directories (relative to the notebook's working directory).
# All stages live below one common output root.
_dir_out = '../paper_deficit/output/'
dir01 = _dir_out + '01_prep/'
dir02 = _dir_out + '02_dbase/'
dir03 = _dir_out + '03_rf/'
# Adjusted random-forest predictions are written below the stage-03 folder
dir03a = dir03 + 'files_adjusted/'

---

In [None]:
def rf_adjust_step1(var_tar, scen):

    """Raise random-forest predictions to at least the original values.

    Writes ``ds_rfpred_{var_tar}_{scen}_rfa.zarr`` (variable ``rfa_mean``)
    into ``dir03a``:

    * ``scen == 'prim'``: wherever the prediction is below the original
      value, the original value is used instead.
    * ``scen == 'secd'``: the same adjustment, but only in cells whose
      ``riggio_vlhi0`` flag is not ``True`` (i.e. not primary land).

    Parameters
    ----------
    var_tar : str
        Target variable such as ``'agbc_mean'``; the part before the first
        underscore selects the ``ds_prep_{ctype}.zarr`` store in ``dir02``.
    scen : {'prim', 'secd'}
        Scenario of the prediction store to adjust.

    Raises
    ------
    ValueError
        If ``scen`` is neither ``'prim'`` nor ``'secd'``.
    """

    # Fail early with a clear message instead of an UnboundLocalError later
    if scen not in ('prim', 'secd'):
        raise ValueError(f"scen must be 'prim' or 'secd', got {scen!r}")

    # Carbon type, e.g. 'agbc' from 'agbc_mean'
    ctype = var_tar.split('_')[0]

    # Original carbon values as int16. Missing cells become -32768 (the int16
    # minimum), so they cannot exceed a prediction and never replace it.
    ds_orig = xr.open_zarr(os.path.join(dir02, f"ds_prep_{ctype}.zarr"))
    da_orig = ds_orig[var_tar].fillna(-32768).round(0).astype('int16')

    # Raw random-forest prediction
    ds_rf = xr.open_zarr(
        os.path.join(dir03, f"files_predicted/ds_rfpred_{var_tar}_{scen}.zarr"))
    da_rf = ds_rf['rfr_mean']

    # Cells where the prediction falls below the original value
    raise_to_orig = da_rf < da_orig

    if scen == 'secd':
        # Restrict the adjustment to cells that are not primary land.
        # `!= True` (rather than `~`) also treats missing (NaN) flags, if
        # any, as "not primary".
        ds_riggio_vlhi0 = xr.open_zarr(
            os.path.join(dir01, "ds_prep_riggio_vlhi0.zarr"))
        da_prim = ds_riggio_vlhi0.riggio_vlhi0
        raise_to_orig = raise_to_orig & (da_prim != True)  # noqa: E712

    da_rfa = xr.where(raise_to_orig, da_orig, da_rf).rename('rfa_mean')

    # Export as zarr store (overwriting any previous run)
    da_rfa.to_zarr(os.path.join(dir03a, f"ds_rfpred_{var_tar}_{scen}_rfa.zarr"),
                   mode='w')

In [None]:
def rf_adjust_step2(var_tar, scen):

    """Cap secondary-scenario values at the primary-scenario values.

    Reads the step-1 outputs (``*_rfa.zarr``, variable ``rfa_mean``) from
    ``dir03a`` and writes ``ds_rfpred_{var_tar}_{scen}_rfaa.zarr``
    (variable ``rfaa_mean``):

    * ``scen == 'prim'``: the primary values are copied unchanged.
    * ``scen == 'secd'``: wherever the secondary value exceeds the primary
      value, the primary value is used instead.

    Parameters
    ----------
    var_tar : str
        Target variable such as ``'agbc_mean'``.
    scen : {'prim', 'secd'}
        Scenario to write.

    Raises
    ------
    ValueError
        If ``scen`` is neither ``'prim'`` nor ``'secd'``.
    """

    # Fail early with a clear message instead of an UnboundLocalError later
    if scen not in ('prim', 'secd'):
        raise ValueError(f"scen must be 'prim' or 'secd', got {scen!r}")

    # The primary layer is needed in both scenarios
    da_prim = xr.open_zarr(
        os.path.join(dir03a, f"ds_rfpred_{var_tar}_prim_rfa.zarr")).rfa_mean

    if scen == 'prim':
        # Primary is the reference: no change
        da_out = da_prim
    else:
        # Secondary store is only opened when it is actually used
        da_secd = xr.open_zarr(
            os.path.join(dir03a, f"ds_rfpred_{var_tar}_secd_rfa.zarr")).rfa_mean
        da_out = xr.where(da_secd > da_prim, da_prim, da_secd)

    # Rename and export as zarr store (overwriting any previous run)
    da_out \
        .rename('rfaa_mean') \
        .to_zarr(os.path.join(dir03a, f"ds_rfpred_{var_tar}_{scen}_rfaa.zarr"),
                 mode='w')

In [None]:
def rf_adjust_step3(var_tar, scen):

    """Enforce min <= mean <= max across the three statistics.

    Reads the step-2 outputs (``*_rfaa.zarr``, variable ``rfaa_mean``) from
    ``dir03a`` and writes ``ds_rfpred_{var_tar}_{scen}_rfaaa.zarr``
    (variable ``rfaaa_mean``):

    * ``min``: copied unchanged.
    * ``mean``: cell-wise maximum of min and mean.
    * ``max``: cell-wise maximum of min, mean and max.

    Parameters
    ----------
    var_tar : str
        ``'{ctype}_{stat}'`` with ``stat`` one of ``'min'``, ``'mean'``,
        ``'max'`` (e.g. ``'agbc_mean'``).
    scen : str
        Scenario suffix used in the file names (``'prim'`` or ``'secd'``
        in this notebook).

    Raises
    ------
    ValueError
        If the statistic part of ``var_tar`` is not min/mean/max.
    """

    # Parse 'agbc_mean' -> ('agbc', 'mean') once
    parts = var_tar.split('_')
    ctype, stat = parts[0], parts[1]

    stats = ('min', 'mean', 'max')
    if stat not in stats:
        raise ValueError(
            f"statistic in var_tar must be one of {stats}, got {stat!r}")

    # Only this statistic and those below it are needed: the adjusted value
    # is the cell-wise maximum over them
    layers = [
        xr.open_zarr(
            os.path.join(dir03a, f"ds_rfpred_{ctype}_{s}_{scen}_rfaa.zarr")
        ).rfaa_mean
        for s in stats[:stats.index(stat) + 1]
    ]

    if len(layers) == 1:
        da_out = layers[0]
    else:
        da_out = xr.concat(layers, dim='a').max('a')

    # Rename and export as zarr store (overwriting any previous run)
    da_out \
        .rename('rfaaa_mean') \
        .to_zarr(os.path.join(dir03a, f"ds_rfpred_{var_tar}_{scen}_rfaaa.zarr"),
                 mode='w')

In [None]:
%%time
# Adjust AGBC. Each step is run for every statistic and scenario before the
# next step starts, because every step reads files written by the previous
# step (including files of other statistics/scenarios).
for adjust in (rf_adjust_step1, rf_adjust_step2, rf_adjust_step3):
    for var_tar in ('agbc_min', 'agbc_mean', 'agbc_max'):
        for scen in ('prim', 'secd'):
            adjust(var_tar, scen)

In [None]:
%%time
# Adjust BGBC. Each step is run for every statistic and scenario before the
# next step starts, because every step reads files written by the previous
# step (including files of other statistics/scenarios).
for adjust in (rf_adjust_step1, rf_adjust_step2, rf_adjust_step3):
    for var_tar in ('bgbc_min', 'bgbc_mean', 'bgbc_max'):
        for scen in ('prim', 'secd'):
            adjust(var_tar, scen)

In [None]:
%%time
# Adjust SOC. Each step is run for every statistic and scenario before the
# next step starts, because every step reads files written by the previous
# step (including files of other statistics/scenarios).
for adjust in (rf_adjust_step1, rf_adjust_step2, rf_adjust_step3):
    for var_tar in ('soc_min', 'soc_mean', 'soc_max'):
        for scen in ('prim', 'secd'):
            adjust(var_tar, scen)

---

### Check: carbon totals at each prediction and adjustment stage

In [None]:
import pandas as pd

In [None]:
# Raw (unadjusted) random-forest predictions; trailing slash kept because
# paths below are built by string concatenation
dir03p = dir03 + 'files_predicted/'

In [None]:
# Grid-cell area layer (hectares, per the file/variable name); used as the
# weight when summing carbon layers into totals
ds_area = xr.open_zarr(os.path.join(dir01, 'ds_prep_area_ha.zarr'))
da_area = ds_area['area_ha']

In [None]:
def _open_masked(path, var):
    """Open ``var`` from the zarr store at ``path`` with -32768 cells set to NaN."""
    da = xr.open_zarr(path)[var]
    return da.where(da != -32768)


def get_carbon(var_tar, scen):
    """Return area-weighted totals of one carbon layer at every processing stage.

    Stages: original data, raw random-forest prediction and the outputs of
    adjustment steps 1-3. Prediction layers have their -32768 nodata value
    masked; the original layer is used as stored.

    Parameters
    ----------
    var_tar : str
        Target variable such as ``'agbc_mean'``.
    scen : str
        Scenario suffix, ``'prim'`` or ``'secd'`` in this notebook.

    Returns
    -------
    list
        ``[var_tar, scen, v_orig, v_rfr, v_rfa, v_rfaa, v_rfaaa]``. Each
        ``v_*`` is ``sum(layer * da_area)`` over ``lat``/``lon``, multiplied
        by 1e-9 and rounded to an int (presumably Pg C if layers are
        Mg C/ha -- TODO confirm units).
    """
    ctype = var_tar.split('_')[0]
    da_orig = xr.open_zarr(os.path.join(dir02, f"ds_prep_{ctype}.zarr"))[var_tar]

    stem = f"ds_rfpred_{var_tar}_{scen}"
    layers = [
        da_orig,
        _open_masked(os.path.join(dir03p, f"{stem}.zarr"), 'rfr_mean'),
        _open_masked(os.path.join(dir03a, f"{stem}_rfa.zarr"), 'rfa_mean'),
        _open_masked(os.path.join(dir03a, f"{stem}_rfaa.zarr"), 'rfaa_mean'),
        _open_masked(os.path.join(dir03a, f"{stem}_rfaaa.zarr"), 'rfaaa_mean'),
    ]

    # Area-weighted global sum, scaled by 1e-9 and rounded to an integer
    totals = [
        int(round(((da * da_area).sum(['lat', 'lon']).compute() * 1E-09).item(), 0))
        for da in layers
    ]
    return [var_tar, scen, *totals]


def get_carbon_sum_df(ctype):
    """Tabulate ``get_carbon`` totals for all statistics and scenarios of a carbon type.

    Parameters
    ----------
    ctype : str
        Carbon type prefix, e.g. ``'agbc'``.

    Returns
    -------
    pandas.DataFrame
        Six rows (min/mean/max x prim/secd, in that order) with the columns
        returned by ``get_carbon`` plus:

        * ``v_deficit_pgc``: ``v_rfaaa - v_orig``
        * ``v_deficit_percent``: ``(1 - v_orig / v_rfaaa) * 100``, 1 decimal
        * ``v_imp_a_percent``: ``(1 - v_rfr / v_rfaaa) * 100``, 1 decimal,
          i.e. the share of the final total added by the adjustment steps
    """
    columns = ["var_tar", "scen", "v_orig", "v_rfr", "v_rfa", "v_rfaa", "v_rfaaa"]

    # Collect all rows first and build the frame once: growing a DataFrame
    # with .loc[len(df)] is quadratic and its dtype inference varies
    # between pandas versions
    rows = [
        get_carbon(f"{ctype}_{stat}", scen)
        for stat in ('min', 'mean', 'max')
        for scen in ('prim', 'secd')
    ]
    df_carbon_sum = pd.DataFrame(rows, columns=columns)

    return df_carbon_sum.assign(
        v_deficit_pgc=df_carbon_sum.v_rfaaa - df_carbon_sum.v_orig,
        v_deficit_percent=(
            (1 - df_carbon_sum.v_orig / df_carbon_sum.v_rfaaa) * 100).round(1),
        v_imp_a_percent=(
            (1 - df_carbon_sum.v_rfr / df_carbon_sum.v_rfaaa) * 100).round(1),
    )

In [None]:
# Summary table for 'agbc' (all statistics and scenarios); shown as cell output
df_agbc = get_carbon_sum_df(ctype='agbc')
df_agbc

In [None]:
# Summary table for 'bgbc' (all statistics and scenarios); shown as cell output
df_bgbc = get_carbon_sum_df(ctype='bgbc')
df_bgbc

In [None]:
# Summary table for 'soc' (all statistics and scenarios); shown as cell output
df_soc = get_carbon_sum_df(ctype='soc')
df_soc

In [None]:
def get_value(var_tar, scen, value):
    """Look up one cell of the per-carbon-type summary tables.

    Parameters
    ----------
    var_tar : str
        Target variable such as ``'agbc_max'``; its carbon type selects
        ``df_agbc``, ``df_bgbc`` or ``df_soc``.
    scen : str
        Scenario, ``'prim'`` or ``'secd'``.
    value : str
        Column to read, e.g. ``'v_orig'``.

    Returns
    -------
    The single matching value.

    Raises
    ------
    ValueError
        If ``var_tar`` names no known carbon type, or if the
        (``var_tar``, ``scen``) selection is not exactly one row
        (raised by ``Series.item``).
    """
    # Checked from soc to agbc so that, for a name containing several type
    # tokens, the same table is chosen as by consecutive independent checks
    # agbc -> bgbc -> soc
    if 'soc' in var_tar:
        df = df_soc
    elif 'bgbc' in var_tar:
        df = df_bgbc
    elif 'agbc' in var_tar:
        df = df_agbc
    else:
        raise ValueError(f"cannot infer carbon type from var_tar={var_tar!r}")

    selected = df[(df.var_tar == var_tar) & (df.scen == scen)]
    return selected[value].item()

In [None]:
# Actual (original) totals of all three pools, one line pair per scenario.
# NOTE(review): biomass uses the *_max layers and SOC the *_mean layer --
# presumably intentional; confirm.
value = 'v_orig'
for scen in ['secd', 'prim']:
    # Label includes the scenario so the two printed numbers are identifiable
    print(f'Actual AGBC + BGBC + SOC ({scen})')
    print(get_value('agbc_max', scen, value) +
          get_value('bgbc_max', scen, value) +
          get_value('soc_mean', scen, value))

In [None]:
# Potential (fully adjusted prediction) totals of all three pools, per scenario.
# NOTE(review): biomass uses the *_max layers and SOC the *_mean layer --
# presumably intentional; confirm.
value = 'v_rfaaa'
for scen in ['secd', 'prim']:
    # Label includes the scenario so the two printed numbers are identifiable
    print(f'Potential AGBC + BGBC + SOC ({scen})')
    print(get_value('agbc_max', scen, value) +
          get_value('bgbc_max', scen, value) +
          get_value('soc_mean', scen, value))

In [None]:
# Deficit (potential minus actual) totals of all three pools, per scenario.
# NOTE(review): biomass uses the *_max layers and SOC the *_mean layer --
# presumably intentional; confirm.
value = 'v_deficit_pgc'
for scen in ['secd', 'prim']:
    # Label includes the scenario so the two printed numbers are identifiable
    print(f'Deficit AGBC + BGBC + SOC ({scen})')
    print(get_value('agbc_max', scen, value) +
          get_value('bgbc_max', scen, value) +
          get_value('soc_mean', scen, value))

In [None]:
# Actual (original) totals of the two biomass pools, per scenario
value = 'v_orig'
for scen in ['secd', 'prim']:
    # Label includes the scenario so the two printed numbers are identifiable
    print(f'Actual AGBC + BGBC ({scen})')
    print(get_value('agbc_max', scen, value) +
          get_value('bgbc_max', scen, value))

In [None]:
# Potential (fully adjusted prediction) totals of the two biomass pools, per scenario
value = 'v_rfaaa'
for scen in ['secd', 'prim']:
    # Label includes the scenario so the two printed numbers are identifiable
    print(f'Potential AGBC + BGBC ({scen})')
    print(get_value('agbc_max', scen, value) +
          get_value('bgbc_max', scen, value))

In [None]:
# Deficit (potential minus actual) totals of the two biomass pools, per scenario
value = 'v_deficit_pgc'
for scen in ['secd', 'prim']:
    # Label includes the scenario so the two printed numbers are identifiable
    print(f'Deficit AGBC + BGBC ({scen})')
    print(get_value('agbc_max', scen, value) +
          get_value('bgbc_max', scen, value))