This notebook generates the train/val/test splits for the DHS survey-based datasets included in SustainBench.

In [None]:
import copy
import os

import numpy as np
import pandas as pd

In [None]:
# DHS labels table, read from a path relative to the notebook's working directory
df = pd.read_csv('output_labels/dhs_final_labels.csv')
display(df.head())

# DHS clusters CSV from Yeh et al. (2020, Nature Communications)
df2020 = pd.read_csv(
    'https://github.com/chrisyeh96/africa_poverty_clean'
    '/raw/main/data/dhs_clusters.csv'
)
display(df2020.head())

# DHS Program REST API: country name plus DHS / ISO2 / ISO3 country codes, returned as CSV
dhs_countries_api = (
    "http://api.dhsprogram.com/rest/dhs/countries"
    "?returnFields=CountryName,DHS_CountryCode,ISO2_countryCode,ISO3_countryCode"
    "&f=csv"
)
dhs_countries_crosswalk = pd.read_csv(dhs_countries_api)
display(dhs_countries_crosswalk.head())

In [None]:
# Add float32 copies of the coordinates to both tables; the later merge joins on
# these 'lat32' / 'lon32' columns.
for frame in (df, df2020):
    frame['lat32'] = frame['lat'].astype(np.float32)
    frame['lon32'] = frame['lon'].astype(np.float32)

In [None]:
# The first 3 characters of GID_1 are used as the ISO3 country code, which is then
# joined against the DHS crosswalk to attach each row's DHS country code as 'dhs_cc'.
# NOTE: this cell reassigns df2020, so re-running it without reloading df2020 first
# would merge the crosswalk columns a second time.
df2020['iso3'] = df2020['GID_1'].str.slice(0, 3)
crosswalk_codes = dhs_countries_crosswalk[['DHS_CountryCode', 'ISO3_CountryCode']]
df2020 = (
    df2020
    .merge(crosswalk_codes, left_on='iso3', right_on='ISO3_CountryCode')
    .rename(columns={'DHS_CountryCode': 'dhs_cc'})
)

In [None]:
# Restrict df to rows whose country code also occurs in df2020 and whose year
# lies in 2009-2017 (both endpoints included).
in_2020_countries = df['cname'].isin(df2020['dhs_cc'].unique())
in_year_range = df['year'].between(2009, 2017)
df_subset = df.loc[in_2020_countries & in_year_range]

In [None]:
# Inner-join the two tables on the float32 coordinates; `validate` makes pandas
# raise if a (lat32, lon32) pair is duplicated on either side.
merged = pd.merge(
    df_subset,
    df2020,
    how='inner',
    on=['lat32', 'lon32'],
    validate='one_to_one',
)
merged

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Matched rows that have no asset_index value
merged.loc[merged['asset_index'].isnull()]

In [None]:
# Matched rows with fewer than 5 households
merged.loc[merged['households'].lt(5)]

In [None]:
# Drop matched rows whose asset_index is missing
merged_notna = merged.dropna(subset=['asset_index'])

In [None]:
import scipy.stats

# Pearson correlation between the 'asset_index' and 'wealthpooled' columns of the
# matched rows; element 0 of the result is the correlation coefficient.
pearson_result = scipy.stats.pearsonr(
    merged_notna['asset_index'],
    merged_notna['wealthpooled'],
)
r = pearson_result[0]
print('r:', r, 'r^2:', r**2)

In [None]:
# Scatter of the two asset-index columns against each other, one small dot per matched row
fig, ax = plt.subplots()
ax.scatter(
    x=merged_notna['asset_index'],
    y=merged_notna['wealthpooled'],
    s=1,
)
ax.set_xlabel('SustainBench asset index')
ax.set_ylabel('NatComms asset index')

In [None]:
# Fold label -> country names. The names are later looked up against the 'country'
# column of df2020, so they must match its spelling exactly.
FOLDS_2020 = dict(
    A=['angola', 'cote_d_ivoire', 'ethiopia', 'mali', 'rwanda'],
    B=['benin', 'burkina_faso', 'guinea', 'sierra_leone', 'tanzania'],
    C=['cameroon', 'ghana', 'malawi', 'zimbabwe'],
    D=['democratic_republic_of_congo', 'mozambique', 'nigeria', 'togo', 'uganda'],
    E=['kenya', 'lesotho', 'senegal', 'zambia'],
)

In [None]:
# Each group key is a (country, dhs_cc) pair; turn the pairs into a
# country-name -> DHS country code lookup.
cname2020_to_dhscc = {
    country: dhs_cc
    for country, dhs_cc in df2020.groupby(['country', 'dhs_cc']).groups
}
display(cname2020_to_dhscc)

In [None]:
# Same folds as FOLDS_2020, with each country name replaced by its DHS country code
FOLDS_2020_dhscc = {}
for fold, country_names in FOLDS_2020.items():
    FOLDS_2020_dhscc[fold] = [cname2020_to_dhscc[name] for name in country_names]
display(FOLDS_2020_dhscc)

In [None]:
# Number of rows of df (the full table, not df_subset) that fall into each existing fold
cname_col = df['cname']
FOLDS_2020_sizes = {}
for fold, codes in FOLDS_2020_dhscc.items():
    FOLDS_2020_sizes[fold] = cname_col.isin(codes).sum()
print(FOLDS_2020_sizes)

In [None]:
# Independent working copies: later cells append to the lists in FOLDS and add to the
# counts in FOLDS_sizes, and the FOLDS_2020_* originals must stay untouched.
FOLDS = {fold: list(codes) for fold, codes in FOLDS_2020_dhscc.items()}
FOLDS_sizes = dict(FOLDS_2020_sizes)

In [None]:
# Country codes present in df that are not yet assigned to any fold
remaining_dhscc = set(df['cname'].unique()).difference(*FOLDS_2020_dhscc.values())
print(remaining_dhscc)

In [None]:
# Row count per country code, then a peek at the first five unassigned codes
sizes = df.groupby(by='cname').size()
sizes.loc[sorted(remaining_dhscc)].iloc[:5]

In [None]:
# Greedy balancing: repeatedly take the largest still-unassigned country and put it
# into whichever fold currently has the fewest rows. Mutates FOLDS, FOLDS_sizes and
# remaining_dhscc in place; once remaining_dhscc is empty, re-running is a no-op.
while remaining_dhscc:
    # Index in sorted order so that ties on size are broken alphabetically.
    # idxmax() returns the first label holding the maximum, and iterating a set of
    # strings directly has no guaranteed order across Python processes, which would
    # make the generated splits irreproducible whenever two countries tie.
    top_cc = sizes.loc[sorted(remaining_dhscc)].idxmax()
    top_size = sizes.loc[top_cc]
    # min() over a dict walks keys in insertion order, so fold ties resolve to the
    # earliest-inserted fold deterministically.
    smallest_fold = min(FOLDS_sizes, key=FOLDS_sizes.get)
    FOLDS[smallest_fold].append(top_cc)
    FOLDS_sizes[smallest_fold] += top_size
    remaining_dhscc.remove(top_cc)

In [None]:
# Final fold -> list of country codes, after the greedy assignment above
print(FOLDS)

In [None]:
# Final number of rows of df per fold
print(FOLDS_sizes)

In [None]:
# Folds C, D and E form the training split; B is validation; A is test.
# Each split is an alphabetically sorted list of country codes.
train_folds = ('C', 'D', 'E')
SPLITS = dict(
    train=sorted(cc for fold in train_folds for cc in FOLDS[fold]),
    val=sorted(FOLDS['B']),
    test=sorted(FOLDS['A']),
)
print(SPLITS)

In [None]:
# For every label column, count the rows with a non-missing value in each split
for label in ['asset_index', 'under5_mort', 'women_bmi', 'women_edu', 'water_index', 'sanitation_index']:
    has_label = df[label].notna()
    counts = {}
    for split in ['train', 'val', 'test']:
        in_split = df['cname'].isin(SPLITS[split])
        counts[split] = len(df[has_label & in_split])
    print(f'{label:17s}', counts)

In [None]:
# For every label column, the share of its non-missing rows that lands in each split,
# printed as a fraction between 0 and 1 rounded to two decimals
for label in ['asset_index', 'under5_mort', 'women_bmi', 'women_edu', 'water_index', 'sanitation_index']:
    labeled = df[df[label].notna()]
    n_labeled = len(labeled)
    fractions = {}
    for split in ['train', 'val', 'test']:
        n_in_split = len(labeled[labeled['cname'].isin(SPLITS[split])])
        fractions[split] = round(n_in_split / n_labeled, 2)
    print(f'{label:17s}', fractions)