In [None]:
from sklearn.metrics import make_scorer
from statsmodels.stats.weightstats import DescrStatsW
from sklearn.metrics import r2_score, mean_squared_error, d2_tweedie_score
from sklearn.metrics import mean_absolute_error
from sklearn.ensemble import RandomForestRegressor
from sklearn.utils.validation import check_is_fitted
from joblib import Parallel, delayed
import threading

import joblib
import numpy as np
import pandas as pd
import seaborn as sns

import bottleneck as bn

def NormRootMeanSqrtErr(y_true, y_pred, type = "minmax"):
    """
    Normalized Root Mean Square Error.

    interpretation: smaller is better.

    Args:
        y_true ([np.array]): test samples
        y_pred ([np.array]): predicted samples
        type (str): type of normalization. default is "minmax"
        - sd (standard deviation) : divides the RMSE by np.std(y_true).
        - mean (mean) : divides the RMSE by the mean of y_true.
        - minmax (min-max) : divides the RMSE by the range (max - min) of y_true.
        - iqr (interquartile range) : divides the RMSE by the interquartile range (Q3 - Q1) of y_true.

    Returns:
        [float]: normalized root mean square error. The denominator is not guarded,
        so a constant y_true (or a zero mean for "mean") gives inf/nan.

    Raises:
        ValueError: if ``type`` is not one of 'sd', 'mean', 'minmax', 'iqr'
        (including the empty string, which previously returned None silently).
    """
    # Denominator for each supported normalization, all computed from y_true only.
    normalizers = {
        "sd": np.std,
        "mean": np.mean,
        "minmax": lambda values: np.max(values) - np.min(values),
        "iqr": lambda values: np.quantile(values, 0.75) - np.quantile(values, 0.25),
    }
    # Validate up front so an unsupported type can never fall through to an implicit None.
    if type not in normalizers:
        raise ValueError("type must be either 'sd', 'mean', 'minmax', or 'iqr'")

    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    return rmse / normalizers[type](y_true)

def concordance_correlation_coefficient(y_true, y_pred, wei=None):
    """Weighted Lin's concordance correlation coefficient (CCC).

    CCC = 2 * cov(t, p) / (var(t) + var(p) + (mean(t) - mean(p)) ** 2)

    All moments are weighted population moments: they are normalised by the sum of
    weights, with no degrees-of-freedom correction. These are the same quantities
    ``statsmodels.DescrStatsW`` (ddof=0) yields, since ``cor * sd_t * sd_p == cov``.

    Args:
        y_true (array-like): observed values.
        y_pred (array-like): predicted values, same length as ``y_true``
            (matched by position).
        wei (array-like, optional): sample weights, same length as ``y_true``.
            ``None`` (default) means equal weights. This also lets the function be
            wrapped by ``sklearn.metrics.make_scorer``, which passes only
            (y_true, y_pred).

    Returns:
        float: the CCC. 1 means perfect agreement.

    Raises:
        ValueError: if the inputs do not have the same length.
        ZeroDivisionError: (from ``np.average``) if no valid pair remains or the
            remaining weights sum to zero.

    Pairs where y_true or y_pred is NaN are dropped together with their weights.
    """
    t = np.asarray(y_true, dtype=float).ravel()
    p = np.asarray(y_pred, dtype=float).ravel()
    w = np.ones_like(t) if wei is None else np.asarray(wei, dtype=float).ravel()
    if not (t.shape == p.shape == w.shape):
        raise ValueError("y_true, y_pred and wei must have the same length")

    # Drop NaN pairs and keep the weights aligned with the surviving rows
    # (the previous version dropped rows but not weights, which broke on any NaN).
    valid = ~(np.isnan(t) | np.isnan(p))
    t, p, w = t[valid], p[valid], w[valid]

    mean_t = np.average(t, weights=w)
    mean_p = np.average(p, weights=w)
    dev_t = t - mean_t
    dev_p = p - mean_p
    var_t = np.average(dev_t ** 2, weights=w)
    var_p = np.average(dev_p ** 2, weights=w)
    cov = np.average(dev_t * dev_p, weights=w)

    return 2 * cov / (var_t + var_p + (mean_t - mean_p) ** 2)

def _single_prediction(predict, X, out, i, lock):
    prediction = predict(X, check_input=False)
    with lock:
        out[i, :] = prediction

def cast_tree_rf(model):
    """Re-tag ``model`` in place as a TreesRandomForestRegressor and return it.

    Only the object's class is swapped; there is no copy and no refit. Afterwards,
    ``model.predict`` returns one row of predictions per tree. Meant for
    already-fitted RandomForestRegressor instances.
    """
    setattr(model, '__class__', TreesRandomForestRegressor)
    return model

class TreesRandomForestRegressor(RandomForestRegressor):
    """RandomForestRegressor whose ``predict`` returns every tree's prediction.

    Instead of the forest average, ``predict`` returns an ``(n_trees, n_samples)``
    array. Callers can then derive their own aggregates (mean, percentiles, spread).
    Fitted ``RandomForestRegressor`` objects are re-tagged to this class by
    assigning ``__class__`` (see ``cast_tree_rf``).
    """

    def predict(self, X):
        """
        Predict the regression target for X with every single tree of the forest.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples. Internally, its dtype will be converted to
            ``dtype=np.float32``. If a sparse matrix is provided, it will be
            converted into a sparse ``csr_matrix``.

        Returns
        -------
        s : ndarray of shape (n_trees, n_samples), dtype float32
            The predicted values of each single tree; row ``i`` holds the
            predictions of ``self.estimators_[i]``.

        Raises
        ------
        ValueError
            If the forest was fitted on more than one output.
        """
        check_is_fitted(self)
        # Check data
        X = self._validate_X_predict(X)

        # Each tree must yield exactly one value per sample to fill one row of pred_t.
        # (Explicit raise instead of `assert`, which is stripped under `python -O`.)
        if self.n_outputs_ != 1:
            raise ValueError(
                f"TreesRandomForestRegressor supports single-output forests only, got n_outputs_={self.n_outputs_}"
            )

        n_trees = len(self.estimators_)
        # store the output of every estimator
        pred_t = np.empty((n_trees, X.shape[0]), dtype=np.float32)

        # self.n_jobs may be None (the sklearn default) or negative. min(int, None)
        # raises TypeError, so resolve it to a concrete worker count before capping
        # it by the number of trees.
        n_jobs = min(n_trees, joblib.effective_n_jobs(self.n_jobs))

        # Threads share pred_t (require="sharedmem"); each task writes its own row.
        lock = threading.Lock()
        Parallel(n_jobs=n_jobs, verbose=self.verbose, require="sharedmem")(
            delayed(_single_prediction)(tree.predict, X, pred_t, i, lock)
            for i, tree in enumerate(self.estimators_)
        )
        return pred_t

# make_scorer forwards only (y_true, y_pred) plus the extra kwargs given here. Without
# an explicit wei=None, the metric's required `wei` argument is missing at scoring time
# (TypeError). wei=None means equal weights.
ccc_scorer = make_scorer(concordance_correlation_coefficient, greater_is_better=True, wei=None)

sns.set_theme(style="whitegrid", palette="pastel")
sns.set_context("notebook")

wd = '/mnt/tupi/WRI/livestock_global_modeling/livestock_census_ard'

#sample_fn = f'{wd}/gpw_livestock.animals_gpw.fao.malek.2024_zonal.samples_20000101_20231231_go_epsg.4326_v1.1.pq'
sample_fn = f'{wd}/gpw_livestock.animals_gpw.fao.faostat.malek.2024_zonal.samples_20000101_20231231_go_epsg.4326_v1.pq'
##sample_fn = f'{wd}/gpw_livestock.animals_faostat_zonal.samples_20000101_20211231_go_epsg.4326_v1.pq'
samples = pd.read_parquet(sample_fn)

animals = ['cattle', 'sheep', 'goat', 'horse', 'buffalo']
model_names = ['lgb', 'rf']
# Density cap (heads per km2) per animal; samples above it are treated as outliers below.
max_density = {
    'cattle': 1511,
    'sheep': 713,
    'goat': 832,
    'horse': 83,
    'buffalo': 490
}

## Number of samples

In [None]:
animals = ['cattle', 'horse', 'goat', 'sheep', 'buffalo']
stats = []
for a in animals:
    is_animal = samples[f'ind_{a}'] == 1
    density = samples[f'{a}_density']
    mask = is_animal & (density <= max_density[a])   # samples kept (at or below the density cap)
    omask = is_animal & (density > max_density[a])   # outliers above the cap
    zmask = is_animal & (density == 0.001)           # presumably zeros encoded as 0.001 - TODO confirm
    # Counts come straight from the boolean masks; no filtered copies of `samples` are built.
    stats.append({
        'animal': a,
        'source': ', '.join(samples.loc[mask, 'source'].unique()),
        'q98': max_density[a],
        'n_outlier': int(omask.sum()),
        'n_zeros': int(zmask.sum()),
        'n_samples': int(mask.sum())
    })

pd.DataFrame(stats)#.to_csv('samples_stats.csv')

## Reading models

In [None]:
models = {}

model_names = ['lgb', 'rf']

wei = 'nowei'
model_fn = f'zonal_models_zeros_{wei}_prod_v20250924'

suffs = ['', '_boxcox']

wd = '/mnt/tupi/WRI/livestock_global_modeling/livestock_census_ard'

for animal in animals:
    for mod in model_names:
        for s in suffs:

            fn_model = f'{animal}.{animal}_density{s}'

            # Load both artifact dicts into one explicit mapping instead of injecting them
            # into the notebook namespace with locals().update()/eval(). Keys of the RFECV file
            # take precedence, as with the previous sequential updates. A missing key now
            # fails loudly instead of silently reusing a value left over from an earlier iteration.
            # NOTE: joblib.load unpickles, so only load trusted local model artifacts.
            artifacts = {
                **joblib.load(f'{wd}/{model_fn}/{fn_model}.{mod}_prod.lz4'),
                **joblib.load(f'{wd}/{model_fn}/{fn_model}_rfecv.lz4'),
            }

            prod_mod = artifacts[f'prod_{mod}']
            if 'rf' in mod:
                # Per-tree predictions are needed later for the RF prediction intervals.
                prod_mod = cast_tree_rf(prod_mod)

            # The Box-Cox transformer is required for the '_boxcox' targets; elsewhere it is optional.
            target_pt = artifacts['target_pt'] if 'boxcox' in s else artifacts.get('target_pt')

            models[f'{animal}_{mod}{s}'] = {
                'rfe': artifacts['covs_rfe'],
                'prod': prod_mod,
                'pt': target_pt
            }

## Predict test set

In [None]:
import geopandas as gpd

In [None]:
import math

test_frames = []

for animal in animals:
    # The test split and density cap depend only on the animal, so select the
    # evaluation rows once rather than once per (model, suffix) combination.
    samples_test = samples[
        (samples[f'ind_{animal}'] == 1)
        & (samples[f'{animal}_ml_type'] == 'testing')
    ]
    mask = samples_test[f'{animal}_density'] <= max_density[animal]
    test_rows = samples_test[mask]

    for mod in model_names:  # (model_names + ['eml'])
        for s in suffs:
            model_key = f'{animal}_{mod}{s}'
            print(f"Runing model {model_key}")

            covs_rfe = models[model_key]['rfe']
            prod_mod = models[model_key]['prod']
            pt = models[model_key]['pt']

            # For RF (TreesRandomForestRegressor) this is (n_trees, n_rows);
            # other models presumably return one value per row.
            pred = prod_mod.predict(test_rows[covs_rfe])

            if 'boxcox' in s:
                # Back-transform from the Box-Cox scale; reshape keeps the per-tree layout intact.
                pred = pt.inverse_transform(pred.reshape(-1, 1)).reshape(pred.shape)

            if 'rf' in mod:
                # Spread across trees per row: 2.5/97.5 percentiles, plus an SD proxy
                # (central 95% width / 4, i.e. assuming roughly normal tree predictions).
                perc = np.nanpercentile(pred, [2.5, 97.5], axis=0)
                std = (perc[1, :] - perc[0, :]) / 4
                pred = bn.nanmean(pred, axis=0)

            weight = test_rows['weight'] if wei == 'wei' else 1

            mod_row = pd.DataFrame({
                'livestock_area_km': test_rows['mask_km2'],
                'country': test_rows['country'],
                'source': test_rows['source'],
                'predicted': pred,
                'expected': test_rows[f'{animal}_density'],
                'weight': weight,
                'gazName': test_rows['gazName'],
                'year': test_rows['year'],
                'model': f'{mod}{s}',
                'animal': animal,
                'strategy': model_fn
            })

            if 'rf' in mod:
                mod_row['predicted_p025'] = perc[0, :]
                mod_row['predicted_p975'] = perc[1, :]
                mod_row['predicted_lower'] = pred - std
                mod_row['predicted_upper'] = pred + std

            test_frames.append(mod_row)

livestock_test = pd.concat(test_frames)

In [None]:
# Inspect the stacked per-model test-set predictions (one block of rows per animal/model combination).
livestock_test

In [None]:
# Flag rows whose observed density lies strictly inside the RF 2.5-97.5 percentile band.
# Rows without a band (NaN bounds, i.e. non-RF models) compare False.
above_p025 = livestock_test['predicted_p025'] < livestock_test['expected']
below_p975 = livestock_test['predicted_p975'] > livestock_test['expected']
livestock_test['rf_inside_pr'] = (above_p025 & below_p975).to_numpy()

In [None]:
# Flag rows whose observed density lies strictly inside the RF mean +/- 2*SD-proxy band.
# Rows without a band (NaN bounds, i.e. non-RF models) compare False.
observed_density = livestock_test['expected']
livestock_test['rf_inside_sd'] = np.logical_and(
    (livestock_test['predicted_lower'] < observed_density).to_numpy(),
    (livestock_test['predicted_upper'] > observed_density).to_numpy(),
)

In [None]:
# PICP of the mean +/- 2*SD-proxy interval per animal. Only RF rows carry interval bounds.
# Other models have NaN bounds, so their flag is always False and would wrongly
# deflate the mean; restrict to rows that actually have an interval.
# NOTE: every RF variant present (e.g. 'rf' and 'rf_boxcox') is pooled into the 'RF' column.
rf_interval_rows = livestock_test[livestock_test['predicted_lower'].notna()]
rf_interval_rows[['rf_inside_sd','animal']].rename(columns={
    'animal': 'Animals',
    'rf_inside_sd': 'RF'
}).groupby('Animals').mean().sort_values('RF')
#.plot(kind='barh', xlabel="Prediction Interval Coverage Probability (PICP)", xlim=(0, 1)).legend(loc='best', bbox_to_anchor=(1, 1))

In [None]:
# PICP of the 2.5-97.5 percentile interval per animal. Only RF rows carry percentile
# bounds. Other models have NaN bounds, so their flag is always False and would wrongly
# deflate the mean; restrict to rows that actually have an interval.
# NOTE: every RF variant present (e.g. 'rf' and 'rf_boxcox') is pooled into the 'RF' column.
rf_interval_rows = livestock_test[livestock_test['predicted_p025'].notna()]
rf_interval_rows[['rf_inside_pr','animal']].rename(columns={
    'animal': 'Animals',
    'rf_inside_pr': 'RF'
}).groupby('Animals').mean().sort_values('RF')
#.plot(kind='barh', xlabel="Prediction Interval Coverage Probability (PICP)", xlim=(0, 1)).legend(loc='best', bbox_to_anchor=(1, 1))

### Temporal validation

In [None]:
# Test-set RMSE per year, animal and model. Grouping by model keeps the different
# models (lgb/rf, with/without Box-Cox) from being pooled into a single number.
temporal_records = []
for (year, animal, model_name), rows in livestock_test.groupby(['year', 'animal', 'model']):
    pred_label, expe_label = 'predicted', 'expected'
    # sqrt(MSE) instead of squared=False, which was deprecated in scikit-learn 1.4 and removed in 1.6.
    rmse = np.sqrt(mean_squared_error(rows[expe_label], rows[pred_label]))
    temporal_records.append({'year': year, 'animal': animal, 'model': model_name, 'rmse': rmse})

temporal_rmse = pd.DataFrame(temporal_records)
temporal_rmse

In [None]:
# Persist the per-model test-set predictions to a parquet file named after the model strategy.
livestock_test.to_parquet(f'{wd}/test_acc_{model_fn}.pq')

## Accuracy plots

In [None]:
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.metrics import mean_absolute_error
from matplotlib import colors
import math

import math
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.metrics import mean_absolute_error
from matplotlib.ticker import LogFormatterExponent
from matplotlib import colors
import math

def plot_hexbin(livestock_rows_df, models, animals, pos, n_col_rows, figsize=(10,8)):
    """Draw one log-log hexbin of observed vs. predicted density per panel.

    Args:
        livestock_rows_df (pd.DataFrame): test-set predictions with at least the columns
            'model', 'animal', 'livestock_area_km', 'expected' and 'predicted'.
        models (list): one ``[model_key, title]`` pair per panel. ``model_key`` is matched
            against the 'model' column; ``title`` appears in the panel title.
        animals (list[str]): animal per panel, matched against the 'animal' column.
        pos (list): ``[row, col]`` subplot position per panel.
        n_col_rows (list[int]): ``[n_cols, n_rows]`` of the subplot grid.
        figsize (tuple): figure size in inches.

    Only rows with strictly positive area, observed and predicted density are used,
    because both axes are logarithmic. Each panel is annotated with CCC, D2 (Tweedie,
    power=1) and RMSE computed on those rows. Panels with no rows show "No data".
    Grid cells not referenced by ``pos`` are hidden. Global matplotlib rc font sizes
    are set as a side effect.
    """
    fontsize = 14

    plt.rc('font', size=fontsize)          # controls default text sizes
    plt.rc('axes', titlesize=fontsize)     # fontsize of the axes title
    plt.rc('axes', labelsize=fontsize)     # fontsize of the x and y labels
    plt.rc('xtick', labelsize=fontsize)    # fontsize of the tick labels
    plt.rc('ytick', labelsize=fontsize)    # fontsize of the tick labels
    plt.rc('legend', fontsize=fontsize)    # legend fontsize
    plt.rc('figure', titlesize=fontsize)   # fontsize of the figure title

    # squeeze=False keeps `axs` 2-D even for a single row or column, so the
    # axs[row, col] indexing below works for any grid shape.
    fig, axs = plt.subplots(ncols=n_col_rows[0], nrows=n_col_rows[1], figsize=figsize, squeeze=False)
    plt.tight_layout(pad=3, h_pad=4)

    pred_label = 'Predicted density (heads per km2)'
    expe_label = 'Observed density (heads per km2)'
    col_remap = {'predicted': pred_label, 'expected': expe_label}

    used_cells = set()
    for animal, p, (model, title) in zip(animals, pos, models):
        row, col = p[0], p[1]
        ax = axs[row, col]
        used_cells.add((row, col))
        panel_title = f'{animal.capitalize()} model ({title})'

        # Log-log axes: keep only strictly positive areas, observations and predictions.
        livestock_rows_hexbin = livestock_rows_df[
            (livestock_rows_df['model'] == model)
            & (livestock_rows_df['animal'] == animal)
            & (livestock_rows_df['livestock_area_km'] > 0)
            & (livestock_rows_df['expected'] > 0)
            & (livestock_rows_df['predicted'] > 0)
        ].rename(columns=col_remap)[[expe_label, pred_label]]

        if livestock_rows_hexbin.empty:
            # np.max below fails on an empty selection; mark the panel instead of crashing.
            ax.set_title(panel_title, fontweight='bold')
            ax.text(0.5, 0.5, 'No data', transform=ax.transAxes, ha='center', va='center')
            continue

        observed = livestock_rows_hexbin[expe_label]
        predicted = livestock_rows_hexbin[pred_label]

        # Shared upper axis limit in log10 units: ceiling of the largest value's order of magnitude.
        max_log = math.ceil(np.max(np.log10(livestock_rows_hexbin.to_numpy())))

        d2 = d2_tweedie_score(observed, predicted, power=1)
        # sqrt(MSE) rather than mean_squared_error(..., squared=False): the `squared`
        # argument was deprecated in scikit-learn 1.4 and removed in 1.6.
        rmse = np.sqrt(mean_squared_error(observed, predicted))
        ccc = concordance_correlation_coefficient(observed, predicted,
                                                  np.ones(livestock_rows_hexbin.shape[0]))
        stats = f'CCC={ccc:.3f}\nD2={d2:.3f}\nRMSE={rmse:.2f}'

        livestock_rows_hexbin.plot.hexbin(x=pred_label, y=expe_label, mincnt=1, gridsize=32, cmap="copper_r",
                                          yscale='log', xscale='log', bins='log',
                                          extent=[-0.2, max_log, -0.2, max_log], ax=ax)
        ax.set_title(panel_title, fontweight='bold')
        ax.axline([0, 0], [1, 1], color='silver')
        ax.text(0.05, 0.95, stats, transform=ax.transAxes, fontsize=11,
                verticalalignment='top', bbox=dict(boxstyle='square,pad=.6',
                facecolor=colors.to_rgba('white', alpha=0.8)))

    # Hide grid cells that received no panel (e.g. the sixth cell of a 2x3 grid holding five panels).
    for (row, col), ax in np.ndenumerate(axs):
        if (row, col) not in used_cells:
            ax.set_visible(False)

# Panel configuration for the accuracy figure: five animals, each shown with the 'rf'
# rows of livestock_test, on a 2-column x 3-row grid.
# NOTE(review): the label says "BX" (Box-Cox?) but the panels use 'rf', not 'rf_boxcox' - confirm.
model_label = 'Random Forest BX'
title = model_label
animals = ['cattle', 'horse', 'goat', 'sheep', 'buffalo']
model_types = [['rf', 'RF'] for _ in animals]
# Row-major positions [0,0], [0,1], [1,0], [1,1], [2,0].
pos = [[row, col] for row in range(3) for col in range(2)][:len(animals)]
# Two-animal variant: animals = ['cattle', 'sheep']; pos = [[0,0],[0,1]]

plot_hexbin(livestock_test, model_types, animals, pos, [2,3], figsize=(12,15))

In [None]:
# Ten administrative units (gazName) with the largest livestock area. Each unit's area is
# taken from its first non-null row; presumably all rows of a unit share the same area.
area_by_unit = livestock_test.groupby('gazName')['livestock_area_km'].first()
top10 = area_by_unit.sort_values(ascending=False).head(10).index.tolist()
top10

## Variable importance

In [None]:
import matplotlib.pyplot as plt 
import seaborn as sns
# Switch seaborn to the "deep" palette (whitegrid style) for the variable-importance figures.
sns.set_theme(style="whitegrid", palette="deep")
sns.set_context("notebook")

In [None]:
#set(sum([ list(models[f'{a}_rf']['rfe']) for a in animals ],[]))
feat_labels = {
    "filtered.dtm_edtm_m_960m_s_20000101_20221231_go_epsg.4326_v20240528": "Static elevation (GEDTM)",
    "dfme_edtm_m_960m_s_20000101_20221231_go_epsg.4326_v20241230": "DFME (GEDTM)",
    "geomorphon_edtm_m_960m_s_20000101_20221231_go_epsg.4326_v20241230": "Geomorphon (GEDTM)",
    "hillshade_edtm_m_960m_s_20000101_20221231_go_epsg.4326_v20241230": "Hillshade (GEDTM)",
    "ls.factor_edtm_m_960m_s_20000101_20221231_go_epsg.4326_v20241230": "LS Factor (GEDTM)",
    "maxic_edtm_m_960m_s_20000101_20221231_go_epsg.4326_v20241230": "Maxic (GEDTM)",
    "minic_edtm_m_960m_s_20000101_20221231_go_epsg.4326_v20241230": "Minic (GEDTM)",
    "neg.openness_edtm_m_960m_s_20000101_20221231_go_epsg.4326_v20241230": "Neg. Openess (GEDTM)",
    "nodepress.dtm_edtm_m_960m_s_20000101_20221231_go_epsg.4326_v20241230": "Nodepress (GEDTM)",
    "pos.openness_edtm_m_960m_s_20000101_20221231_go_epsg.4326_v20241230": "Pos. Openess (GEDTM)",
    "pro.curv_edtm_m_960m_s_20000101_20221231_go_epsg.4326_v20241230": "Prov. curv. (GEDTM)",
    "ring.curv_edtm_m_960m_s_20000101_20221231_go_epsg.4326_v20241230": "Ring curv. (GEDTM)",
    "shpindx_edtm_m_960m_s_20000101_20221231_go_epsg.4326_v20241230": "Shpindx (GEDTM)",
    "slope.in.degree_edtm_m_960m_s_20000101_20221231_go_epsg.4326_v20241230": "Slope (GEDTM)",
    "spec.catch_edtm_m_960m_s_20000101_20221231_go_epsg.4326_v20241230": "Spec. catch (GEDTM)",
    "ssdon_edtm_m_960m_s_20000101_20221231_go_epsg.4326_v20241230": "SSDON (GEDTM)",
    "tan.curv_edtm_m_960m_s_20000101_20221231_go_epsg.4326_v20241230": "Tan. Cuvr. (GEDTM)",
    "twi_edtm_m_960m_s_20000101_20221231_go_epsg.4326_v20241230": "Topo. Weteness index (GEDTM)",
    "lcv_accessibility.to.cities_map.ox.var10_m_1km_s0..0cm_2015_v14052019": "Travel time to near. city (20k to 110 mi. pop., Nelson et al., 2019)",
    "lcv_accessibility.to.cities_map.ox.var11_m_1km_s0..0cm_2015_v14052019": "Travel time to near. city (50k to 50 mi. pop., Nelson et al., 2019)",
    "lcv_accessibility.to.cities_map.ox.var12_m_1km_s0..0cm_2015_v14052019": "Travel time to near. city (5k to 110 mi. pop., Nelson et al., 2019)",
    "lcv_accessibility.to.cities_map.ox.var1_m_1km_s0..0cm_2015_v14052019": "Travel time to near. city (5—50 mi. pop., Nelson et al., 2019)",
    "lcv_accessibility.to.cities_map.ox.var2_m_1km_s0..0cm_2015_v14052019": "Travel time to near. city (1—5 mi. pop - Nelson et al., 2019)",
    "lcv_accessibility.to.cities_map.ox.var3_m_1km_s0..0cm_2015_v14052019": "Travel time to near. city (500—1000k pop., Nelson et al., 2019)",
    "lcv_accessibility.to.cities_map.ox.var4_m_1km_s0..0cm_2015_v14052019": "Travel time to near. city (200—500k pop., Nelson et al., 2019)",
    "lcv_accessibility.to.cities_map.ox.var5_m_1km_s0..0cm_2015_v14052019": "Travel time to near. city (100—200k pop., Nelson et al., 2019)",
    "lcv_accessibility.to.cities_map.ox.var6_m_1km_s0..0cm_2015_v14052019": "Travel time to near. city (50—100k pop., Nelson et al., 2019)",
    "lcv_accessibility.to.cities_map.ox.var7_m_1km_s0..0cm_2015_v14052019": "Travel time to near. city (20—50k pop., Nelson et al., 2019)",
    "lcv_accessibility.to.cities_map.ox.var8_m_1km_s0..0cm_2015_v14052019": "Travel time to near. city (10—20k pop., Nelson et al., 2019)",
    "lcv_accessibility.to.cities_map.ox.var9_m_1km_s0..0cm_2015_v14052019": "Travel time to near. city (5—10k pop., Nelson et al., 2019)",
    "lcv_accessibility.to.ports_map.ox.var1_m_1km_s0..0cm_2015_v14052019": "Travel time to near. large port (Nelson et al., 2019)",
    "lcv_accessibility.to.ports_map.ox.var2_m_1km_s0..0cm_2015_v14052019": "Travel time to near. medium port (Nelson et al. 2019)",
    "lcv_accessibility.to.ports_map.ox.var3_m_1km_s0..0cm_2015_v14052019": "Travel time to near. small port (Nelson et al. 2019)",
    "lcv_accessibility.to.ports_map.ox.var4_m_1km_s0..0cm_2015_v14052019": "Travel time to near. very small port (Nelson et al. 2019)",
    "lcv_accessibility.to.ports_map.ox.var5_m_1km_s0..0cm_2015_v14052019": "Travel time to near. any port (Nelson et al. 2019)",
    'lcv_water.occurance_jrc.surfacewater_p_250m_b0..200cm_1984..2018_v1.1': 'Long-term water occur. (JRC Global Surface Water)',
    "peatland.extent_wri.gfw.peatgrids_p_1km_s_2000_2020_go_epsg4326_v20241017": "Long-term peatland extent (WRI/GFW)",
    "buddhism.pct_world.religion_m_1km_s_20100101_20101231_go_epsg4326_v20241107": 'Buddhism population pct. (World Religion Data)',
    'christianity.pct_world.religion_m_1km_s_20100101_20101231_go_epsg4326_v20241107': 'Christian popuplation. pct. (World Religion Data)',
    'islam.pct_world.religion_m_1km_s_20100101_20101231_go_epsg4326_v20241107': 'Muslim population pct. (World Religion Data)',
    'judaism.pct_world.religion_m_1km_s_20100101_20101231_go_epsg4326_v20241107': 'Jewish population pct. (World Religion Data)',
    'judaism.islam.pct_world.religion_m_1km_s_20100101_20101231_go_epsg4326_v20241107': 'Judaism & Islamic population pct. (World Religion Data)',
    'other.religions.pct_world.religion_m_1km_s_20100101_20101231_go_epsg4326_v20241107': 'Other religions population pct. (World Religion Data)',
    "clm_lst_mod11a2.daytime.m02_p50_1km_s0..0cm_2000..2021_v1.2": "Long-term daytime land surf. temp. - Feb. (MOD11A2)",
    "clm_lst_mod11a2.daytime.m03_p50_1km_s0..0cm_2000..2021_v1.2": "Long-term daytime land surf. temp. - Mar. (MOD11A2)",
    "clm_lst_mod11a2.daytime.m04_p50_1km_s0..0cm_2000..2021_v1.2": "Long-term daytime land surf. temp. - Apr. (MOD11A2)",
    "clm_lst_mod11a2.daytime.m05_p50_1km_s0..0cm_2000..2021_v1.2": "Long-term daytime land surf. temp. - May (MOD11A2)",
    "clm_lst_mod11a2.daytime.m06_p50_1km_s0..0cm_2000..2021_v1.2": "Long-term daytime land surf. temp. - Jun. (MOD11A2)",
    "clm_lst_mod11a2.daytime.m07_p50_1km_s0..0cm_2000..2021_v1.2": "Long-term daytime land surf. temp. - Jul. (MOD11A2)",
    "clm_lst_mod11a2.daytime.m08_p50_1km_s0..0cm_2000..2021_v1.2": "Long-term daytime land surf. temp. - Aug. (MOD11A2)",
    "clm_lst_mod11a2.daytime.m09_p50_1km_s0..0cm_2000..2021_v1.2": "Long-term daytime land surf. temp. - Sep. (MOD11A2)",
    "clm_lst_mod11a2.daytime.m10_p50_1km_s0..0cm_2000..2021_v1.2": "Long-term daytime land surf. temp. - Oct. (MOD11A2)",
    "clm_lst_mod11a2.daytime.m11_p50_1km_s0..0cm_2000..2021_v1.2": "Long-term daytime land surf. temp. - Nov. (MOD11A2)",
    "clm_lst_mod11a2.daytime.m12_p50_1km_s0..0cm_2000..2021_v1.2": "Long-term daytime land surf. temp. - Dec. (MOD11A2)",
    "clm_lst_mod11a2.nighttime.m02_p50_1km_s0..0cm_2000..2021_v1.2": "Long-term nightime land surf. temp. - Feb. (MOD11A2)",
    "clm_lst_mod11a2.nighttime.m03_p50_1km_s0..0cm_2000..2021_v1.2": "Long-term nightime land surf. temp. - Mar. (MOD11A2)",
    "clm_lst_mod11a2.nighttime.m04_p50_1km_s0..0cm_2000..2021_v1.2": "Long-term nightime land surf. temp. - Apr. (MOD11A2)",
    "clm_lst_mod11a2.nighttime.m05_p50_1km_s0..0cm_2000..2021_v1.2": "Long-term nightime land surf. temp. - May (MOD11A2)",
    "clm_lst_mod11a2.nighttime.m06_p50_1km_s0..0cm_2000..2021_v1.2": "Long-term nightime land surf. temp. - Jun. (MOD11A2)",
    "clm_lst_mod11a2.nighttime.m07_p50_1km_s0..0cm_2000..2021_v1.2": "Long-term nightime land surf. temp. - Jul. (MOD11A2)",
    "clm_lst_mod11a2.nighttime.m08_p50_1km_s0..0cm_2000..2021_v1.2": "Long-term nightime land surf. temp. - Aug. (MOD11A2)",
    "clm_lst_mod11a2.nighttime.m09_p50_1km_s0..0cm_2000..2021_v1.2": "Long-term nightime land surf. temp. - Sep. (MOD11A2)",
    "clm_lst_mod11a2.nighttime.m10_p50_1km_s0..0cm_2000..2021_v1.2": "Long-term nightime land surf. temp. - Oct. (MOD11A2)",
    "clm_lst_mod11a2.nighttime.m11_p50_1km_s0..0cm_2000..2021_v1.2": "Long-term nightime land surf. temp. - Nov. (MOD11A2)",
    "clm_lst_mod11a2.nighttime.m12_p50_1km_s0..0cm_2000..2021_v1.2": "Long-term nightime land surf. temp. - Dec. (MOD11A2)",
    "wv_mcd19a2v061.seasconv.m.m01_p50_1km_s_20000101_20221231_go_epsg.4326_v20230619": "Long-term water vapour - Jan. (MCD19A2)",
    "wv_mcd19a2v061.seasconv.m.m02_p50_1km_s_20000101_20221231_go_epsg.4326_v20230619": "Long-term water vapour - Feb. (MCD19A2)",
    "wv_mcd19a2v061.seasconv.m.m03_p50_1km_s_20000101_20221231_go_epsg.4326_v20230619": "Long-term water vapour - Mar. (MCD19A2)",
    "wv_mcd19a2v061.seasconv.m.m04_p50_1km_s_20000101_20221231_go_epsg.4326_v20230619": "Long-term water vapour - Apr. (MCD19A2)",
    "wv_mcd19a2v061.seasconv.m.m05_p50_1km_s_20000101_20221231_go_epsg.4326_v20230619": "Long-term water vapour - May (MCD19A2)",
    "wv_mcd19a2v061.seasconv.m.m06_p50_1km_s_20000101_20221231_go_epsg.4326_v20230619": "Long-term water vapour - Jun. (MCD19A2)",
    "wv_mcd19a2v061.seasconv.m.m07_p50_1km_s_20000101_20221231_go_epsg.4326_v20230619": "Long-term water vapour - Jul. (MCD19A2)",
    "wv_mcd19a2v061.seasconv.m.m08_p50_1km_s_20000101_20221231_go_epsg.4326_v20230619": "Long-term water vapour - Aug. (MCD19A2)",
    "wv_mcd19a2v061.seasconv.m.m09_p50_1km_s_20000101_20221231_go_epsg.4326_v20230619": "Long-term water vapour - Sep. (MCD19A2)",
    "wv_mcd19a2v061.seasconv.m.m10_p50_1km_s_20000101_20221231_go_epsg.4326_v20230619": "Long-term water vapour - Oct. (MCD19A2)",
    "wv_mcd19a2v061.seasconv.m.m11_p50_1km_s_20000101_20221231_go_epsg.4326_v20230619": "Long-term water vapour - Nov. (MCD19A2)",
    "wv_mcd19a2v061.seasconv.m.m12_p50_1km_s_20000101_20221231_go_epsg.4326_v20230619": "Long-term water vapour - Dec. (MCD19A2)",
    "aridity.index_gai.pet.m01_m_1km_s_19700101_20001231_go_epsg4326_v3": "Long-term aridity Index Jan. (Zomer et al., 2022)",
    "aridity.index_gai.pet.m02_m_1km_s_19700101_20001231_go_epsg4326_v3": "Long-term aridity Index Feb. (Zomer et al., 2022)",
    "aridity.index_gai.pet.m03_m_1km_s_19700101_20001231_go_epsg4326_v3": "Long-term aridity Index Mar. (Zomer et al., 2022)",
    "aridity.index_gai.pet.m04_m_1km_s_19700101_20001231_go_epsg4326_v3": "Long-term aridity Index Apr. (Zomer et al., 2022)",
    "aridity.index_gai.pet.m05_m_1km_s_19700101_20001231_go_epsg4326_v3": "Long-term aridity Index May  (Zomer et al., 2022)",
    "aridity.index_gai.pet.m06_m_1km_s_19700101_20001231_go_epsg4326_v3": "Long-term aridity Index Jun. (Zomer et al., 2022)",
    "aridity.index_gai.pet.m07_m_1km_s_19700101_20001231_go_epsg4326_v3": "Long-term aridity Index Jul. (Zomer et al., 2022)",
    "aridity.index_gai.pet.m08_m_1km_s_19700101_20001231_go_epsg4326_v3": "Long-term aridity Index Aug. (Zomer et al., 2022)",
    "aridity.index_gai.pet.m09_m_1km_s_19700101_20001231_go_epsg4326_v3": "Long-term aridity Index Sep. (Zomer et al., 2022)",
    "aridity.index_gai.pet.m10_m_1km_s_19700101_20001231_go_epsg4326_v3": "Long-term aridity Index Oct. (Zomer et al., 2022)",
    "aridity.index_gai.pet.m11_m_1km_s_19700101_20001231_go_epsg4326_v3": "Long-term aridity Index Nov. (Zomer et al., 2022)",
    "aridity.index_gai.pet.m12_m_1km_s_19700101_20001231_go_epsg4326_v3": "Long-term aridity Index Dec. (Zomer et al., 2022)",
    "hdi_sdata.kummu.2018_m_10km_s_year0101_year1231_go_epsg4326_v20241107": "Annual Human Develop. Index (Kummu et al., 2018)",
    "wilderness_li2022.human.footprint_p_1km_s_year0101_year1231_go_epsg.4326_v16022022": 'Annual wilderness, Human Footprint (Mu et al., 2022)',
    "real.gdp_sdata.chen.2022_m_1km_s_year0101_year1231_go_epsg4326_v1": "Annual sub-national real GDP (Chen et al., 2022)",
    "rural.pop.dist_worldpop_m_1km_year0101_year1231_go_epsg4326_v20241129": 'Annual rural population distance (WorldPop/GHS-SMOD)',
    "clm_lst_mod11a2.nighttime_p50_1km_s0..0cm_year.01.01..year.01.31_v1.2": "Monthly nighttime land surf. temp. - Jan. (MOD11A2)",
    "clm_lst_mod11a2.nighttime_p50_1km_s0..0cm_year.02.01..year.02.28_v1.2": "Monthly nighttime land surf. temp. - Feb. (MOD11A2)",
    "clm_lst_mod11a2.nighttime_p50_1km_s0..0cm_year.03.01..year.03.31_v1.2": "Monthly nighttime land surf. temp. - Mar. (MOD11A2)",
    "clm_lst_mod11a2.nighttime_p50_1km_s0..0cm_year.04.01..year.04.30_v1.2": "Monthly nighttime land surf. temp. - Apr. (MOD11A2)",
    "clm_lst_mod11a2.nighttime_p50_1km_s0..0cm_year.05.01..year.05.31_v1.2": "Monthly nighttime land surf. temp. - May  (MOD11A2)",
    "clm_lst_mod11a2.nighttime_p50_1km_s0..0cm_year.06.01..year.06.30_v1.2": "Monthly nighttime land surf. temp. - Jun. (MOD11A2)",
    "clm_lst_mod11a2.nighttime_p50_1km_s0..0cm_year.07.01..year.07.31_v1.2": "Monthly nighttime land surf. temp. - Jul. (MOD11A2)",
    "clm_lst_mod11a2.nighttime_p50_1km_s0..0cm_year.08.01..year.08.31_v1.2": "Monthly nighttime land surf. temp. - Aug. (MOD11A2)",
    "clm_lst_mod11a2.nighttime_p50_1km_s0..0cm_year.09.01..year.09.30_v1.2": "Monthly nighttime land surf. temp. - Sep. (MOD11A2)",
    "clm_lst_mod11a2.nighttime_p50_1km_s0..0cm_year.10.01..year.10.31_v1.2": "Monthly nighttime land surf. temp. - Oct. (MOD11A2)",
    "clm_lst_mod11a2.nighttime_p50_1km_s0..0cm_year.11.01..year.11.30_v1.2": "Monthly nighttime land surf. temp. - Nov. (MOD11A2)",
    "clm_lst_mod11a2.nighttime_p50_1km_s0..0cm_year.12.01..year.12.31_v1.2": "Monthly nighttime land surf. temp. - Dec. (MOD11A2)",
    "clm_lst_mod11a2.daytime_p50_1km_s0..0cm_year.01.01..year.01.31_v1.2": "Monthly daytime land surf. temp. - Jan. (MOD11A2)",
    "clm_lst_mod11a2.daytime_p50_1km_s0..0cm_year.02.01..year.02.28_v1.2": "Monthly daytime land surf. temp. - Feb. (MOD11A2)",
    "clm_lst_mod11a2.daytime_p50_1km_s0..0cm_year.03.01..year.03.31_v1.2": "Monthly daytime land surf. temp. - Mar. (MOD11A2)",
    "clm_lst_mod11a2.daytime_p50_1km_s0..0cm_year.04.01..year.04.30_v1.2": "Monthly daytime land surf. temp. - Apr. (MOD11A2)",
    "clm_lst_mod11a2.daytime_p50_1km_s0..0cm_year.05.01..year.05.31_v1.2": "Monthly daytime land surf. temp. - May  (MOD11A2)",
    "clm_lst_mod11a2.daytime_p50_1km_s0..0cm_year.06.01..year.06.30_v1.2": "Monthly daytime land surf. temp. - Jun. (MOD11A2)",
    "clm_lst_mod11a2.daytime_p50_1km_s0..0cm_year.07.01..year.07.31_v1.2": "Monthly daytime land surf. temp. - Jul. (MOD11A2)",
    "clm_lst_mod11a2.daytime_p50_1km_s0..0cm_year.08.01..year.08.31_v1.2": "Monthly daytime land surf. temp. - Aug. (MOD11A2)",
    "clm_lst_mod11a2.daytime_p50_1km_s0..0cm_year.09.01..year.09.30_v1.2": "Monthly daytime land surf. temp. - Sep. (MOD11A2)",
    "clm_lst_mod11a2.daytime_p50_1km_s0..0cm_year.10.01..year.10.31_v1.2": "Monthly daytime land surf. temp. - Oct. (MOD11A2)",
    "clm_lst_mod11a2.daytime_p50_1km_s0..0cm_year.11.01..year.11.30_v1.2": "Monthly daytime land surf. temp. - Nov. (MOD11A2)",
    "clm_lst_mod11a2.daytime_p50_1km_s0..0cm_year.12.01..year.12.31_v1.2": "Monthly daytime land surf. temp. - Dec. (MOD11A2)",
    "veg_ndvi_mod13q1.v061_p50_250m_s0..0cm_year.01.01..year.01.31_v2": "Monthly veget. index - Jan. (MOD13Q1/NDVI)",
    "veg_ndvi_mod13q1.v061_p50_250m_s0..0cm_year.02.01..year.02.28_v2": "Monthly veget. index - Feb. (MOD13Q1/NDVI)",
    "veg_ndvi_mod13q1.v061_p50_250m_s0..0cm_year.03.01..year.03.31_v2": "Monthly veget. index - Mar. (MOD13Q1/NDVI)",
    "veg_ndvi_mod13q1.v061_p50_250m_s0..0cm_year.04.01..year.04.30_v2": "Monthly veget. index - Apr. (MOD13Q1/NDVI)",
    "veg_ndvi_mod13q1.v061_p50_250m_s0..0cm_year.05.01..year.05.31_v2": "Monthly veget. index - May  (MOD13Q1/NDVI",
    "veg_ndvi_mod13q1.v061_p50_250m_s0..0cm_year.06.01..year.06.30_v2": "Monthly veget. index - Jun. (MOD13Q1/NDVI)",
    "veg_ndvi_mod13q1.v061_p50_250m_s0..0cm_year.07.01..year.07.31_v2": "Monthly veget. index - Jul. (MOD13Q1/NDVI)",
    "veg_ndvi_mod13q1.v061_p50_250m_s0..0cm_year.08.01..year.08.31_v2": "Monthly veget. index - Aug. (MOD13Q1/NDVI)",
    "veg_ndvi_mod13q1.v061_p50_250m_s0..0cm_year.09.01..year.09.30_v2": "Monthly veget. index - Sep. (MOD13Q1/NDVI)",
    "veg_ndvi_mod13q1.v061_p50_250m_s0..0cm_year.10.01..year.10.31_v2": "Monthly veget. index - Oct. (MOD13Q1/NDVI)",
    "veg_ndvi_mod13q1.v061_p50_250m_s0..0cm_year.11.01..year.11.30_v2": "Monthly veget. index - Nov. (MOD13Q1/NDVI)",
    "veg_ndvi_mod13q1.v061_p50_250m_s0..0cm_year.12.01..year.12.31_v2": "Monthly veget. index - Dec. (MOD13Q1/NDVI)",
    "night.lights_dmsp.viirs_m_1km_s_year0101_year1231_go_epsg4326_v1": "Annual harm. night lights (DMSP/VIIRS)",
    'clm_lst_max.geom.temp_m_30m_s_m1': 'Long-term geometric max. temperature - Jan. (Kilibarda et al., 2014)',
    'clm_lst_max.geom.temp_m_30m_s_m10': 'Long-term geometric max. temperature - Oct. (Kilibarda et al., 2014)',
    'clm_lst_max.geom.temp_m_30m_s_m11': 'Long-term geometric max. temperature - Nov. (Kilibarda et al., 2014)',
    'clm_lst_max.geom.temp_m_30m_s_m12': 'Long-term geometric max. temperature - Dec. (Kilibarda et al., 2014)',
    'clm_lst_max.geom.temp_m_30m_s_m2': 'Long-term geometric max. temperature - Feb. (Kilibarda et al., 2014)',
    'clm_lst_max.geom.temp_m_30m_s_m3': 'Long-term geometric max. temperature - Mar. (Kilibarda et al., 2014)',
    'clm_lst_max.geom.temp_m_30m_s_m4': 'Long-term geometric max. temperature - Apr. (Kilibarda et al., 2014)',
    'clm_lst_max.geom.temp_m_30m_s_m5': 'Long-term geometric max. temperature - May. (Kilibarda et al., 2014)',
    'clm_lst_max.geom.temp_m_30m_s_m6': 'Long-term geometric max. temperature - Jun. (Kilibarda et al., 2014)',
    'clm_lst_max.geom.temp_m_30m_s_m7': 'Long-term geometric max. temperature - Jul. (Kilibarda et al., 2014)',
    'clm_lst_max.geom.temp_m_30m_s_m8': 'Long-term geometric max. temperature - Aug. (Kilibarda et al., 2014)',
    'clm_lst_max.geom.temp_m_30m_s_m9': 'Long-term geometric max. temperature - Sep. (Kilibarda et al., 2014)',
    'clm_lst_min.geom.temp_m_30m_s_m1': 'Long-term geometric min. temperature - Jan. (Kilibarda et al., 2014)',
    'clm_lst_min.geom.temp_m_30m_s_m10': 'Long-term geometric min. temperature - Oct. (Kilibarda et al., 2014)',
    'clm_lst_min.geom.temp_m_30m_s_m11': 'Long-term geometric min. temperature - Nov. (Kilibarda et al., 2014)',
    'clm_lst_min.geom.temp_m_30m_s_m12': 'Long-term geometric min. temperature - Dec. (Kilibarda et al., 2014)',
    'clm_lst_min.geom.temp_m_30m_s_m2': 'Long-term geometric min. temperature - Feb. (Kilibarda et al., 2014)',
    'clm_lst_min.geom.temp_m_30m_s_m3': 'Long-term geometric min. temperature - Mar. (Kilibarda et al., 2014)',
    'clm_lst_min.geom.temp_m_30m_s_m4': 'Long-term geometric min. temperature - Apr. (Kilibarda et al., 2014)',
    'clm_lst_min.geom.temp_m_30m_s_m5': 'Long-term geometric min. temperature - May. (Kilibarda et al., 2014)',
    'clm_lst_min.geom.temp_m_30m_s_m6': 'Long-term geometric min. temperature - Jun. (Kilibarda et al., 2014)',
    'clm_lst_min.geom.temp_m_30m_s_m7': 'Long-term geometric min. temperature - Jul. (Kilibarda et al., 2014)',
    'clm_lst_min.geom.temp_m_30m_s_m8': 'Long-term geometric min. temperature - Aug. (Kilibarda et al., 2014)',
    'clm_lst_min.geom.temp_m_30m_s_m9': 'Long-term geometric min. temperature - Sep. (Kilibarda et al., 2014)',
    'precipitation.liquid.rate_imerg.3b_m_10km_s_year0101_year0131_go_epsg.4326_v20250820': 'Monthly precipitation Jan. (IMERG)',
    'precipitation.liquid.rate_imerg.3b_m_10km_s_year0201_year0228_go_epsg.4326_v20250820': 'Monthly precipitation Feb. (IMERG)',
    'precipitation.liquid.rate_imerg.3b_m_10km_s_year0301_year0331_go_epsg.4326_v20250820': 'Monthly precipitation Mar. (IMERG)',
    'precipitation.liquid.rate_imerg.3b_m_10km_s_year0401_year0430_go_epsg.4326_v20250820': 'Monthly precipitation Apr. (IMERG)',
    'precipitation.liquid.rate_imerg.3b_m_10km_s_year0501_year0531_go_epsg.4326_v20250820': 'Monthly precipitation May (IMERG)',
    'precipitation.liquid.rate_imerg.3b_m_10km_s_year0601_year0630_go_epsg.4326_v20250820': 'Monthly precipitation Jun. (IMERG)',
    'precipitation.liquid.rate_imerg.3b_m_10km_s_year0701_year0731_go_epsg.4326_v20250820': 'Monthly precipitation Jul. (IMERG)',
    'precipitation.liquid.rate_imerg.3b_m_10km_s_year0801_year0831_go_epsg.4326_v20250820': 'Monthly precipitation Aug. (IMERG)',
    'precipitation.liquid.rate_imerg.3b_m_10km_s_year0901_year0930_go_epsg.4326_v20250820': 'Monthly precipitation Sep. (IMERG)',
    'precipitation.liquid.rate_imerg.3b_m_10km_s_year1001_year1031_go_epsg.4326_v20250820': 'Monthly precipitation Oct. (IMERG)',
    'precipitation.liquid.rate_imerg.3b_m_10km_s_year1101_year1130_go_epsg.4326_v20250820': 'Monthly precipitation Nov. (IMERG)',
    'precipitation.liquid.rate_imerg.3b_m_10km_s_year1201_year1231_go_epsg.4326_v20250820': 'Monthly precipitation Dec. (IMERG)',
    'gpw_short.veg.height_egbt_m_30m_s_year0101_year1231_go_epsg.4326_v1': "Short vegetation height (GPW)"
}

In [None]:
def plot_var_imp(feature_cols, estimator, title, top_n = 20, figsize=(4,7), color = 'blue', model='rf'):
    """Plot the ``top_n`` most important features of a fitted model as horizontal bars.

    Args:
        feature_cols (list[str]): feature names, in the same order as the
            estimator's importance vector.
        estimator: fitted model. With ``model == 'rf'`` its ``feature_importances_``
            attribute is used; otherwise ``estimator.booster_.feature_importance(
            importance_type='gain')`` is used (LightGBM-style booster).
        title (str): plot title, rendered in bold.
        top_n (int): number of features to show.
        figsize (tuple): figure size passed to ``DataFrame.plot``.
        color (str): bar color.
        model (str): ``'rf'`` selects ``feature_importances_``; any other value
            selects the booster gain importance.

    Returns:
        matplotlib Axes holding the bar chart.
    """
    if model == 'rf':
        importance = estimator.feature_importances_
    else:
        importance = estimator.booster_.feature_importance(importance_type='gain')

    var_imp = pd.DataFrame({'name': feature_cols, 'importance': importance})

    # Express importances as fractions of the total. When every importance is 0
    # (e.g. a booster without splits) skip the division to avoid a 0/0 -> NaN chart.
    total = var_imp['importance'].sum()
    if total > 0:
        var_imp['importance'] = var_imp['importance'] / total
    var_imp.index = var_imp['name']

    # Take the top_n largest, then re-sort ascending so the most important bar is on top.
    top = var_imp.sort_values('importance', ascending=False)[0:top_n].sort_values('importance')
    ax = top.plot(kind='barh', ylabel="", xlabel="Normalized importance",
                  figsize=figsize, title=title, color=color)
    # Style the title on the Axes just drawn rather than pyplot's implicit current axes.
    ax.set_title(title, fontweight='bold')
    return ax

### Feature importance

In [None]:
# Top-4 raw covariates of the RF model for one animal, evaluated on its test split.
animal = 'cattle'
s = ''  # suffix of the model key, i.e. models[f'{animal}_rf{s}']

covs_rfe_rf = models[f'{animal}_rf{s}']['rfe']
samples_test = samples[np.logical_and.reduce([
    samples[f'ind_{animal}'] == 1,
    samples[f'{animal}_ml_type'] == 'testing'
])]
mask = (samples_test[f'{animal}_density'] <= max_density[animal])
X = samples_test[mask][covs_rfe_rf]

# Importances keyed by the raw covariate names (not the human-readable feat_labels).
feature_cols = list(covs_rfe_rf)
estimator = models[f'{animal}_rf{s}']['prod']
var_imp = pd.DataFrame({'name': feature_cols, 'importance': estimator.feature_importances_})

top_feat = list(var_imp.sort_values('importance', ascending=False).head(4)['name'])

In [None]:
# Sanity check: every RFE-selected covariate of each RF model needs an entry in
# `feat_labels`, because the plotting cells look labels up with `feat_labels[f]`.
# Self-contained on purpose: it must not depend on `pos`/`animals`, which are
# only defined further down in the notebook.
animals_to_check = ['cattle', 'horse', 'goat', 'sheep', 'buffalo']
missing_labels = {
    a: [f for f in models[f'{a}_rf']['rfe'] if f not in feat_labels]
    for a in animals_to_check
}
missing_labels = {a: fs for a, fs in missing_labels.items() if fs}
assert not missing_labels, f"covariates without a feat_labels entry: {missing_labels}"

In [None]:
fontsize = 15

plt.rc('font', size=fontsize)          # controls default text sizes
plt.rc('axes', titlesize=fontsize)     # fontsize of the axes title
plt.rc('axes', labelsize=fontsize)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=fontsize)    # fontsize of the tick labels
plt.rc('ytick', labelsize=fontsize)    # fontsize of the tick labels
plt.rc('legend', fontsize=fontsize)    # legend fontsize
plt.rc('figure', titlesize=fontsize)  # fontsize of the figure title

fig, axs = plt.subplots(ncols=2, nrows=3, figsize=(20,22))

# Grid position (row, col) of each animal's panel.
pos = [[0,0],[0,1],[1,0],[1,1],[2,0]]
animals = ['cattle', 'horse', 'goat', 'sheep', 'buffalo']

# Settings shared by every panel.
color = '#bf794d'
top_n = 25
xlabel = "Normalized importance"

for p, a in zip(pos, animals):
    ax = axs[p[0], p[1]]
    feature_cols = [feat_labels[f] for f in models[f'{a}_rf']['rfe']]
    estimator = models[f'{a}_rf']['prod']
    title = f'{a.capitalize()} model (RF)'

    var_imp = pd.DataFrame({'name': feature_cols, 'importance': estimator.feature_importances_})
    var_imp['importance'] = var_imp['importance'] / var_imp['importance'].sum()
    var_imp.index = var_imp['name']

    # Top-n bars, re-sorted ascending so the most important one is drawn on top.
    result = (var_imp.sort_values('importance', ascending=False)[0:top_n]
              .sort_values('importance')
              .plot(kind='barh', ylabel="", xlabel=xlabel, color=color,
                    legend=False, ax=ax))
    ax.set_title(title, fontweight='bold')

# The 3x2 grid has more slots than animals: hide the panels that received no
# plot instead of rendering empty axes frames.
used_cells = {tuple(p) for p in pos[:len(animals)]}
for cell, ax in np.ndenumerate(axs):
    if cell not in used_cells:
        ax.set_visible(False)

plt.tight_layout(pad=2, h_pad=1)

### Partial dependence plots

In [None]:
# Test-split inputs and human-readable feature names for the partial-dependence plots.
animal = 'cattle'

covs_rfe_rf = models[f'{animal}_rf']['rfe']
samples_test = samples[
    (samples[f'ind_{animal}'] == 1).to_numpy()
    & (samples[f'{animal}_ml_type'] == 'testing').to_numpy()
]
mask = (samples_test[f'{animal}_density'] <= max_density[animal])
X = samples_test.loc[mask, covs_rfe_rf]

# Labels with the citation pushed to its own line, so PDP axis labels stay narrow.
feature_cols = [feat_labels[name].replace(' (', ' \n(') for name in models[f'{animal}_rf']['rfe']]
estimator = models[f'{animal}_rf']['prod']
var_imp = pd.DataFrame({'name': feature_cols, 'importance': estimator.feature_importances_})

In [None]:
# Labels available for `top_feat` below (copy exact strings from here).
var_imp['name'].tolist()

In [None]:
# Hand-picked covariates shown in the partial-dependence plots, per animal.
# Each string must match an entry of `feature_cols` exactly (a `feat_labels`
# value with ' (' replaced by ' \n('), because the features are passed by name
# together with `feature_names=feature_cols` to PartialDependenceDisplay.
# NOTE(review): spellings such as "popuplation"/"nightime" presumably mirror
# `feat_labels`; if corrected, change both places together.
top_feat = {
    'cattle': [
        'Long-term water vapour - Oct. \n(MCD19A2)', 
        'Long-term aridity Index Aug. \n(Zomer et al., 2022)',
        "Christian popuplation. pct. \n(World Religion Data)",
        "Long-term nightime land surf. temp. - Mar. \n(MOD11A2)"
    ],
    'horse': [
        'Long-term daytime land surf. temp. - Nov. \n(MOD11A2)',
        'Long-term aridity Index Sep. \n(Zomer et al., 2022)',
        'Monthly veget. index - Oct. \n(MOD13Q1/NDVI)',
        'Long-term water vapour - Oct. \n(MCD19A2)'
    ],
    'goat': [
        'Muslim population pct. \n(World Religion Data)',
        'Long-term aridity Index Sep. \n(Zomer et al., 2022)',
        "Long-term nightime land surf. temp. - Jun. \n(MOD11A2)",
        "Long-term water vapour - Feb. \n(MCD19A2)"
    ],
    'sheep': [
        'Jewish population pct. \n(World Religion Data)',
        'Long-term aridity Index Sep. \n(Zomer et al., 2022)',
        'Long-term nightime land surf. temp. - Oct. \n(MOD11A2)',
        'Long-term water vapour - Jan. \n(MCD19A2)'
    ],
    'buffalo': [
        'Muslim population pct. \n(World Religion Data)',
        'Long-term daytime land surf. temp. - Oct. \n(MOD11A2)',
        'Long-term peatland extent \n(WRI/GFW)',
        'Long-term water vapour - Aug. \n(MCD19A2)'
    ]
}

In [None]:
from sklearn.inspection import PartialDependenceDisplay

fontsize = 12

plt.rc('font', size=fontsize)          # controls default text sizes
plt.rc('axes', titlesize=fontsize)     # fontsize of the axes title
plt.rc('axes', labelsize=fontsize)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=fontsize)    # fontsize of the tick labels
plt.rc('ytick', labelsize=fontsize)    # fontsize of the tick labels
plt.rc('legend', fontsize=fontsize)    # legend fontsize
plt.rc('figure', titlesize=fontsize)  # fontsize of the figure title

title = f'{animal.capitalize()} model (RF)'

fig, ax = plt.subplots(figsize=(18, 4))
ax.set_title(title, fontweight='bold')

# A single Axes is used by scikit-learn as the bounding box for a grid of
# per-feature panels (n_cols per row). Features are given by label, resolved
# through feature_names=feature_cols.
pdd = PartialDependenceDisplay.from_estimator(models[f'{animal}_rf']['prod'],
                                              X,
                                              top_feat[animal],
                                              n_cols=5,
                                              subsample=None, feature_names=feature_cols, ax=ax)
# Lay out only once the panels exist; on the empty figure it has nothing to adjust.
fig.tight_layout()
fig.savefig(f"{animal}_rf_pda_top4.pdf", bbox_inches='tight')

## Model benchmark

In [None]:
from skmap.misc import find_files

In [None]:
from skmap.misc import find_files

v = 'v20250924'

# One parquet of test predictions per run of this version.
bm_files = list(find_files(wd, f'test*{v}.pq'))
if not bm_files:
    raise FileNotFoundError(f"no benchmark files matching 'test*{v}.pq' under {wd}")

# ignore_index: rows from different files would otherwise share index labels.
df_bm = pd.concat([pd.read_parquet(f) for f in bm_files], ignore_index=True)
df_bm = df_bm[df_bm['model'].isin(['lgb','lgb_boxcox','rf','rf_boxcox'])]

# Suffix appended to the model name for each training strategy of this version.
strategy_suffix = {
    f'zonal_models_zeros_nowei_prod_{v}': '',
    f'zonal_models_zeros_wei_prod_{v}': '_weight'
}
# Keep only known strategies explicitly (others would map to NaN and be dropped
# silently by the later groupby), and copy so the column assignments below act
# on an independent frame rather than on a slice of the concatenated one.
df_bm = df_bm[df_bm['strategy'].isin(list(strategy_suffix))].copy()
df_bm['model'] = df_bm['model'] + df_bm['strategy'].map(strategy_suffix)
df_bm['label'] = df_bm['model'].map({
    'lgb': 'GBT',
    'lgb_boxcox': 'GBT / BoxCox trans.',
    'rf': 'RF',
    'rf_boxcox': 'RF / BoxCox trans.',
    'lgb_weight': 'GBT / Sample weights',
    'lgb_boxcox_weight': 'GBT / BoxCox trans. / Sample weights',
    'rf_weight': 'RF / Sample weights',
    'rf_boxcox_weight': 'RF / BoxCox trans. / Sample weights'
})
df_bm.head()

In [None]:
# Rows with strictly positive area, expected and predicted values (NaN compares
# False, so incomplete rows are excluded). Kept as a plain boolean ndarray.
mask = (
    (df_bm['livestock_area_km'] > 0)
    & (df_bm['expected'] > 0)
    & (df_bm['predicted'] > 0)
).to_numpy()

In [None]:
# Goodness-of-fit metrics per (model label, animal) on the masked test rows.
expe_label, pred_label = 'expected', 'predicted'
results = []
for (label, animal), frows in df_bm[mask].groupby(['label','animal']):
    y_true, y_pred = frows[expe_label], frows[pred_label]
    results.append({
        'label': label,
        'animal': animal,
        'd2': d2_tweedie_score(y_true, y_pred, power=1),
        # RMSE as sqrt(MSE): `mean_squared_error(..., squared=False)` is deprecated
        # in scikit-learn 1.4 and removed in 1.6; this form works on every version.
        'rmse': np.sqrt(mean_squared_error(y_true, y_pred)),
        'ccc': concordance_correlation_coefficient(y_true, y_pred, np.ones(frows.shape[0])),
        'r2': r2_score(y_true, y_pred)
    })

results = pd.DataFrame(results)

In [None]:
# Final display labels for the benchmark table: spell out "weighted samples" and
# use lower-case "trans." for every Box-Cox model. Labels not listed are kept.
results['label'] = results['label'].replace({
    'GBT / BoxCox trans. / Sample weights': 'GBT / BoxCox trans. / weighted samples',
    'GBT / Sample weights': 'GBT / weighted samples',
    'RF / BoxCox trans. / Sample weights': 'RF / BoxCox trans. / weighted samples',
    'RF / Sample weights': 'RF / weighted samples',
    'RF / BoxCox Trans.': 'RF / BoxCox trans.',
})

In [None]:
# Benchmark table, 3 decimals, ordered by animal then model label.
# Append .to_csv('benchmark_results.csv') to export it.
(
    results[['animal', 'label', 'd2', 'r2', 'ccc', 'rmse']]
    .round(3)
    .sort_values(by=['animal', 'label'])
)

In [None]:
# Number of test rows per model.
livestock_test.loc[:, 'model'].value_counts()

In [None]:
# Expected vs. predicted distribution of the plain RF model, restricted to
# rows with 1 < expected <= max_density[animal].
animal = 'cattle'
livestock_test[np.logical_and.reduce([
        livestock_test['animal'] == animal,
        livestock_test['model'] == 'rf',
        livestock_test['expected'] > 1,
        livestock_test['expected'] <= max_density[animal]
])][['expected','predicted']]\
.plot.hist(bins=128, histtype='step', linewidth=1.5, title=f'{animal} / rf');

In [None]:
# Same histogram for the Box-Cox RF model, titled so it can be told apart
# from the plain RF one.
animal = 'cattle'
livestock_test[np.logical_and.reduce([
        livestock_test['animal'] == animal,
        livestock_test['model'] == 'rf_boxcox',
        livestock_test['expected'] > 1,
        livestock_test['expected'] <= max_density[animal]
])][['expected','predicted']]\
.plot.hist(bins=128, histtype='step', linewidth=1.5, title=f'{animal} / rf_boxcox');

In [None]:
# Strategies present among the cattle / GPW / RF test rows.
livestock_test[
    (livestock_test['animal'] == 'cattle').to_numpy()
    & (livestock_test['source'] == 'GPW').to_numpy()
    & (livestock_test['model'] == 'rf').to_numpy()
]['strategy'].value_counts()

In [None]:
pred_label = 'predicted'
expe_label = 'expected'

# Goodness of fit of the cattle RF model, per country.
for country, livestock_rows_hexbin in livestock_test[np.logical_and.reduce([
        livestock_test['animal'] == 'cattle',
        livestock_test['model'] == 'rf',
    ])].groupby('country'):
    print(country, livestock_rows_hexbin.shape)

    y_true = livestock_rows_hexbin[expe_label]
    y_pred = livestock_rows_hexbin[pred_label]

    d2 = d2_tweedie_score(y_true, y_pred, power=1)
    nrmse = NormRootMeanSqrtErr(y_true, y_pred, 'sd')
    # CCC on log1p values, computed as new Series. Writing them into
    # `pd.DataFrame(livestock_rows_hexbin)` is unsafe: that constructor does not
    # copy DataFrame input, so the assignment could leak into the raw rows.
    ccc = concordance_correlation_coefficient(np.log1p(y_true), np.log1p(y_pred),
                                              np.ones(livestock_rows_hexbin.shape[0]))
    stats = f'D2={d2:.3f} NRMSE={nrmse:.3f} CCC(log1p)={ccc:.3f}'
    print(stats)

In [None]:
import math

# Root of the model artifacts. NOTE(review): absolute, machine-specific path.
wd = '/mnt/tupi/WRI/livestock_global_modeling/livestock_census_ard'
model_dir = f'{wd}/zonal_models_prod_v20250203'
# Lower clip applied to Box-Cox-space predictions before the inverse transform.
# NOTE(review): magic constant; presumably keeps inputs inside the valid domain
# of the inverse Box-Cox transform -- confirm its origin.
min_transformed_val = -2.65610344

for animal in animals:
    for model in models:
        fn_model = f'{animal}.{animal}_density'

        # Load the production model and RFE artifacts. They are also exported to the
        # notebook namespace (as before) so `covs_rfe`, `target_pt`, ... stay available.
        artifacts = joblib.load(f'{model_dir}/{fn_model}.{model}_prod.lz4')
        artifacts.update(joblib.load(f'{model_dir}/{fn_model}_rfecv.lz4'))
        globals().update(artifacts)

        prod_mod = globals()[f'prod_{model}']  # same lookup as eval(f"prod_{model}"), without eval

        samples_test = samples[np.logical_and(samples[f'ind_{animal}'] == 1, samples[f'{animal}_ml_type'] == 'testing')]
        mask = (samples_test[f'{animal}_density'] <= max_density[animal])
        test_rows = samples_test[mask]

        pred_density = prod_mod.predict(test_rows[covs_rfe])
        # The Box-Cox flag must come from the model name: `fn_model` never contains
        # 'boxcox'. Clip and back-transform in the model's target space, i.e. before
        # scaling densities by area.
        if 'boxcox' in model:
            pred_density = np.maximum(pred_density, min_transformed_val)
            pred_density = target_pt.inverse_transform(np.asarray(pred_density).reshape(-1, 1)).ravel()

        area = test_rows['livestock_area_km']
        df_test = pd.DataFrame({
            'predicted': pred_density * area,
            'expected': test_rows[f'{animal}_density'] * area,
            'weight': test_rows['weight']
        })

        # Histogram up to the 99th percentile of both columns, log-scaled counts.
        max_hist = np.percentile(df_test[['expected','predicted']].to_numpy(), [99])[0]
        mask_hist = np.logical_and(df_test['expected'] <= max_hist, df_test['predicted'] <= max_hist)
        df_test[mask_hist][['predicted','expected']].plot.hist(bins=128, title=f"{fn_model}.{model}", histtype='step', linewidth=1.5, log=True)

        # NOTE(review): overrides the per-sample weights loaded above, so the
        # metrics are unweighted -- confirm this is intended.
        df_test['weight'] = 1
        d2 = d2_tweedie_score(df_test['expected'], df_test['predicted'], power=1)
        # sqrt(MSE): `squared=False` is deprecated (sklearn 1.4) and removed in 1.6.
        rmse = np.sqrt(mean_squared_error(df_test['expected'], df_test['predicted'], sample_weight=df_test['weight']))
        print(f'{fn_model}.{model}: D2={d2:.3f} RMSE={rmse:.3f}')

        # plot_var_imp only recognises 'rf'; map RF variants (e.g. 'rf_boxcox') to it
        # so they use `feature_importances_` instead of a LightGBM booster.
        plot = plot_var_imp(covs_rfe, prod_mod, f'{fn_model}.{model}', 40, color='#6fa8dd',
                            model='rf' if str(model).startswith('rf') else model)

## Point time-series

In [None]:
# Point of interest as latitude (y) / longitude (x).
# Alternative location: -9.04264281, -63.45872163
y = -8.96941598
x = -64.01426078
# Year span, both ends included.
y1, y2 = 2000, 2022

In [None]:
# One row per year at the same point, reprojected from lon/lat to the
# Interrupted Goode Homolosine projection; `sample` below reads rasters at these
# projected x/y values. NOTE(review): the raster filenames say epsg.4326 --
# confirm the rasters are really in IGH.
df = pd.DataFrame({'x': x, 'y': y, 'year': range(y1, y2 + 1)})
gdf = (
    gpd.GeoDataFrame(df, geometry=gpd.points_from_xy(df.x, df.y), crs="EPSG:4326")
    .to_crs('+proj=igh +lon_0=0 +x_0=0 +y_0=0 +datum=WGS84 +units=m +no_defs +type=crs')
)

In [None]:
import rasterio

def sample(row):
    """Read cattle density and its 95% prediction interval for one point-year.

    Args:
        row: a row of `gdf` with a 'year' and a point 'geometry'; the geometry's
            x/y are used directly as raster coordinates.

    Returns:
        The row with 'density', 'density_p025', 'density_p975' and a +/- 1 SD
        band ('lower', 'upper') added.
    """
    year = row['year']
    point = [(row["geometry"].x, row["geometry"].y)]
    base = 'https://s3.opengeohub.org/gpw/livestock_v2/gpw_cattle.density_rf'
    # Output column -> statistic tag in the file name.
    layers = {'density': 'm', 'density_p025': 'p.025', 'density_p975': 'p.975'}
    for column, stat in layers.items():
        url = f'{base}_{stat}_1km_s_{year}0101_{year}1231_go_epsg.4326_v1.tif'
        # `with` closes each remote dataset right after reading one pixel.
        with rasterio.open(url) as ds:
            # * 0.1: NOTE(review) presumably values are stored scaled by 10 -- confirm.
            row[column] = next(ds.sample(point))[0] * 0.1

    # A 95% interval spans about 2 * 1.96 ~= 4 standard deviations of a normal,
    # so a quarter of its width approximates one SD.
    std = (row['density_p975'] - row['density_p025']) / 4
    row['lower'] = row['density'] - std
    row['upper'] = row['density'] + std

    return row

result = gdf.apply(sample, axis=1)

In [None]:
# Average approximate SD (a quarter of the 95% interval width) over all years.
result['density_p975'].sub(result['density_p025']).div(4).mean()

In [None]:
import seaborn 
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import seaborn as sns

sns.set_style("ticks")
FONT_SIZE = 25
plt.rc('font', size=FONT_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=FONT_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=FONT_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=FONT_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=FONT_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=FONT_SIZE)    # legend fontsize
plt.rc('figure', titlesize=FONT_SIZE)   # fontsize of the figure title

myfig, myax = plt.subplots(figsize=(24, 6))

# Yearly cattle density at the point, with a +/- 1 SD band and the 95% prediction interval.
myax.plot(result['year'], result['density'], color='tab:red', linestyle='-', label='Cattle density')
myax.fill_between(result['year'], result['lower'], result['upper'], color='tab:red', alpha=0.2, label='Prediction error (1 SD)')
myax.fill_between(result['year'], result['density_p025'], result['density_p975'], color='tab:blue', alpha=0.2, label='Prediction interval (95% prob.)')

# NOTE(review): fixed y range; values above 900 heads/km2 would be clipped.
myax.set_ylim([0,900])
myax.set_xlabel('Year')
myax.set_ylabel('Heads per km-square')
myax.set_title(f'Coordinate: {x:.4f}, {y:.4f}; average cattle density: {result["density"].mean():.1f}')
myax.grid(True)

myax.legend(loc='upper right');