### Imports

In [11]:
import numpy as np
import pandas as pd
from numpy import array
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error

In [12]:
population_df = pd.read_csv('world_population.csv', index_col='Country Code')
meta_df = pd.read_csv('metadata.csv', index_col='Country Code')

In [13]:
population_df.head()

Unnamed: 0_level_0,1960,1961,1962,1963,1964,1965,1966,1967,1968,1969,...,2008,2009,2010,2011,2012,2013,2014,2015,2016,2017
Country Code,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ABW,54211.0,55438.0,56225.0,56695.0,57032.0,57360.0,57715.0,58055.0,58386.0,58726.0,...,101353.0,101453.0,101669.0,102053.0,102577.0,103187.0,103795.0,104341.0,104822.0,105264.0
AFG,8996351.0,9166764.0,9345868.0,9533954.0,9731361.0,9938414.0,10152331.0,10372630.0,10604346.0,10854428.0,...,27294031.0,28004331.0,28803167.0,29708599.0,30696958.0,31731688.0,32758020.0,33736494.0,34656032.0,35530081.0
AGO,5643182.0,5753024.0,5866061.0,5980417.0,6093321.0,6203299.0,6309770.0,6414995.0,6523791.0,6642632.0,...,21759420.0,22549547.0,23369131.0,24218565.0,25096150.0,25998340.0,26920466.0,27859305.0,28813463.0,29784193.0
ALB,1608800.0,1659800.0,1711319.0,1762621.0,1814135.0,1864791.0,1914573.0,1965598.0,2022272.0,2081695.0,...,2947314.0,2927519.0,2913021.0,2905195.0,2900401.0,2895092.0,2889104.0,2880703.0,2876101.0,2873457.0
AND,13411.0,14375.0,15370.0,16412.0,17469.0,18549.0,19647.0,20758.0,21890.0,23058.0,...,83861.0,84462.0,84449.0,83751.0,82431.0,80788.0,79223.0,78014.0,77281.0,76965.0


In [14]:
meta_df.head()

Unnamed: 0_level_0,Region,Income Group,Special Notes
Country Code,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
ABW,Latin America & Caribbean,High income,Mining is included in agriculture\r\r\r\nElect...
AFG,South Asia,Low income,Fiscal year end: March 20; reporting period fo...
AGO,Sub-Saharan Africa,Lower middle income,
ALB,Europe & Central Asia,Upper middle income,
AND,Europe & Central Asia,High income,WB-3 code changed from ADO to AND to align wit...


### Function for countries by income group

In [15]:
def get_total_pop_by_income(income_group_name = 'Low income'):
    if income_group_name not in meta_df.values:
        raise ValueError
        
    df = pd.merge(population_df, meta_df, on = 'Country Code', how = 'inner')
    df = df.loc[df['Income Group'] == income_group_name].drop(['Region', 'Special Notes', 'Income Group'], axis = 1).sum(axis = 0).astype(np.int64)
    dates = population_df.columns.values.astype(np.int64)
    pop = df.to_numpy()
    output = np.vstack((dates, pop)).T
    return output

In [16]:
data = get_total_pop_by_income('Low income')
print(get_total_pop_by_income('High income'))

[[      1960  769889923]
 [      1961  781225329]
 [      1962  791207437]
 [      1963  801108277]
 [      1964  810900987]
 [      1965  820309686]
 [      1966  829088382]
 [      1967  837479954]
 [      1968  844905494]
 [      1969  854059674]
 [      1970  862276721]
 [      1971  871169187]
 [      1972  880246152]
 [      1973  888486025]
 [      1974  897803169]
 [      1975  906573084]
 [      1976  913843314]
 [      1977  921330504]
 [      1978  928906293]
 [      1979  936836246]
 [      1980  944587066]
 [      1981  952368316]
 [      1982  959759971]
 [      1983  966754949]
 [      1984  973423742]
 [      1985  980143630]
 [      1986  987194728]
 [      1987  994242786]
 [      1988 1001421456]
 [      1989 1009036892]
 [      1990 1017092667]
 [      1991 1025345408]
 [      1992 1031949811]
 [      1993 1040349480]
 [      1994 1048121445]
 [      1995 1057290586]
 [      1996 1064630661]
 [      1997 1071969568]
 [      1998 1078927765]
 [      1999 1085992668]


### Function for k-fold cross validation

In [17]:
def sklearn_kfold_split(data,K):
    X, y = data.T
    result = []
    kfolds = KFold(n_splits = K, shuffle = False)
    for train_index, test_index in kfolds.split(X):
        result.append((train_index, test_index))
    return result

In [18]:
data = get_total_pop_by_income('High income');
sklearn_kfold_split(data,4)

[(array([15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
         32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
         49, 50, 51, 52, 53, 54, 55, 56, 57]),
  array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14])),
 (array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 30, 31,
         32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
         49, 50, 51, 52, 53, 54, 55, 56, 57]),
  array([15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])),
 (array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
         17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 44, 45, 46, 47,
         48, 49, 50, 51, 52, 53, 54, 55, 56, 57]),
  array([30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43])),
 (array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
         17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
         34, 35, 36, 37, 38

### Function to find best testing set performance

In [23]:
def best_k_model(data,data_indices):
    mse_list = []
    for train_indices, test_indices in data_indices:
        X_train, y_train = data[train_indices,0],data[train_indices,1]
        X_test, y_test = data[test_indices,0],data[test_indices,1]
        r_forest = RandomForestRegressor(n_estimators = 100, random_state = 42)
        r_forest.fit(X_train.reshape(-1, 1), y_train)
    return r_forest

In [24]:
data = get_total_pop_by_income('High income')
data_indices = sklearn_kfold_split(data,5)

best_model = best_k_model(data,data_indices)
best_model.predict([[1960]])

array([7.76089629e+08])