In [1]:
import xgboost
import numpy as np
import pandas as pd
from sklearn.model_selection import GridSearchCV, KFold

In [2]:
df = pd.read_csv("../data/beijingb_scaled.csv", index_col = 0)

In [3]:
df.head(10)

Unnamed: 0,latitude,longitude,ts,temperature,pressure,humidity,WSX,WSY,weather,PM2.5,station_id
0,0.58292,0.17557,0.0,0.718851,0.653601,0.521719,0.587483,0.417198,0.0,156.25,1001.0
0,0.58292,0.17557,0.002747,0.622896,0.738081,0.231674,0.34857,0.672356,0.0,24.833333,1001.0
0,0.58292,0.17557,0.005495,0.628654,0.706512,0.271946,0.477721,0.503845,0.0,72.583333,1001.0
0,0.58292,0.17557,0.008242,0.556687,0.758978,0.122624,0.335436,0.687828,0.0,30.043478,1001.0
0,0.58292,0.17557,0.010989,0.573,0.7572,0.169683,0.624696,0.445893,0.0,23.083333,1001.0
0,0.58292,0.17557,0.013736,0.535577,0.698064,0.370588,0.703812,0.44955,0.0,67.75,1001.0
0,0.58292,0.17557,0.019231,0.730558,0.706957,0.297285,0.761101,0.1985,0.0625,67.909091,1001.0
0,0.58292,0.17557,0.021978,0.652642,0.723853,0.415385,0.825457,0.281039,0.125,79.541667,1001.0
0,0.58292,0.17557,0.024725,0.638249,0.737636,0.426244,0.802629,0.396662,0.125,61.875,1001.0
0,0.58292,0.17557,0.027473,0.523896,0.644786,0.824926,0.641541,0.531182,0.3125,28.666667,1001.0


In [4]:
df = df.rename(columns={'ts': 'Time', 'station_id': 'Station'})
stations = df['Station'].unique()

In [11]:
kf = KFold(n_splits=6, random_state=0, shuffle=True)
kf1 = KFold(n_splits=5, random_state= 0, shuffle=True)

In [13]:
train_df['Station'].unique()

array([1002., 1006., 1003., 1004., 1010., 1011., 1016., 1036., 1007.,
       1012., 1014., 1022., 1009., 1015., 1013., 1020., 1017., 1035.,
       1019., 1034., 1023., 1024., 1030., 1025., 1026., 1033., 1027.,
       1028., 1032., 1031.])

In [27]:
depths = [1, 10, 50, 100, 300]
lrs = [ 0.01, 0.1, 1]
estimators = [10, 20, 80, 160]

In [28]:
from sklearn.metrics import mean_squared_error

In [29]:
import operator


In [30]:
for train, test in kf.split(stations):
    
    train_df = pd.concat([df.groupby('Station').get_group(stn) for stn in stations[train]])
    test_df = pd.concat([df.groupby('Station').get_group(stn) for stn in stations[test]])
    result = {}

    for dp in depths:
        for lr in lrs:
            for est in estimators:
    
                for train_1, valid in kf1.split(train_df['Station'].unique()):

                    train_1_df = pd.concat([df.groupby('Station').get_group(stn) for stn in stations[train_1]])
                    valid_1_df = pd.concat([df.groupby('Station').get_group(stn) for stn in stations[valid]])
                    
                    xgb = xgboost.XGBRegressor(max_depth=dp, learning_rate=lr, n_estimators=est)
                    xgb.fit(train_1_df.drop(columns = ['Station', 'PM2.5']), train_1_df[['PM2.5']])
                    pred = xgb.predict(valid_1_df.drop(columns = ['Station', 'PM2.5']))
                    
                    gt = valid_1_df[['PM2.5']]
                    
                    if (dp, lr, est) not in result:
                        result[(dp, lr, est)] = np.sqrt(mean_squared_error(gt, pred))
                    else:
                        result[(dp, lr, est)]+= np.sqrt(mean_squared_error(gt, pred))
                        
    best_hyper_parm = min(result.items(), key=operator.itemgetter(1))[0]
    final = xgboost.XGBRegressor(max_depth= best_hyper_parm[0], learning_rate=best_hyper_parm[1], n_estimators=best_hyper_parm[2])
    
    final.fit(train_df.drop(columns = ['Station', 'PM2.5']), train_df[['PM2.5']])
    pred = final.predict(test_df.drop(columns = ['Station', 'PM2.5']))
    gt = test_df[['PM2.5']]
    print(result)
    print(best_hyper_parm)
    print("*"*50)
    print("Error on test")
    print(np.sqrt(mean_squared_error(gt, pred)))

                        
        









{(1, 0.01, 10): 534.1418084837003, (1, 0.01, 20): 499.32591782765746, (1, 0.01, 80): 368.7650305085574, (1, 0.01, 160): 305.40183581765757, (1, 0.1, 10): 340.75941337653876, (1, 0.1, 20): 290.20985571139374, (1, 0.1, 80): 252.79606541319873, (1, 0.1, 160): 241.6760344357054, (1, 1, 10): 251.5332180236784, (1, 1, 20): 239.574269012148, (1, 1, 80): 217.12737101992775, (1, 1, 160): 206.23542125025352, (10, 0.01, 10): 524.75230626377, (10, 0.01, 20): 480.5775261701589, (10, 0.01, 80): 293.9083107979988, (10, 0.01, 160): 177.92442690443676, (10, 0.1, 10): 246.21555182316544, (10, 0.1, 20): 146.90800950662032, (10, 0.1, 80): 116.83611575674327, (10, 0.1, 160): 115.9523586969531, (10, 1, 10): 148.06301994281947, (10, 1, 20): 148.45760333718675, (10, 1, 80): 148.8155910187528, (10, 1, 160): 148.8099294643226, (50, 0.01, 10): 524.2378412945005, (50, 0.01, 20): 479.4286062332617, (50, 0.01, 80): 290.3755285639777, (50, 0.01, 160): 174.01828724761418, (50, 0.1, 10): 242.05522961320523, (50, 0.1, 









{(1, 0.01, 10): 534.1418084837003, (1, 0.01, 20): 499.32591782765746, (1, 0.01, 80): 368.7650305085574, (1, 0.01, 160): 305.40183581765757, (1, 0.1, 10): 340.75941337653876, (1, 0.1, 20): 290.20985571139374, (1, 0.1, 80): 252.79606541319873, (1, 0.1, 160): 241.6760344357054, (1, 1, 10): 251.5332180236784, (1, 1, 20): 239.574269012148, (1, 1, 80): 217.12737101992775, (1, 1, 160): 206.23542125025352, (10, 0.01, 10): 524.75230626377, (10, 0.01, 20): 480.5775261701589, (10, 0.01, 80): 293.9083107979988, (10, 0.01, 160): 177.92442690443676, (10, 0.1, 10): 246.21555182316544, (10, 0.1, 20): 146.90800950662032, (10, 0.1, 80): 116.83611575674327, (10, 0.1, 160): 115.9523586969531, (10, 1, 10): 148.06301994281947, (10, 1, 20): 148.45760333718675, (10, 1, 80): 148.8155910187528, (10, 1, 160): 148.8099294643226, (50, 0.01, 10): 524.2378412945005, (50, 0.01, 20): 479.4286062332617, (50, 0.01, 80): 290.3755285639777, (50, 0.01, 160): 174.01828724761418, (50, 0.1, 10): 242.05522961320523, (50, 0.1, 









{(1, 0.01, 10): 534.1418084837003, (1, 0.01, 20): 499.32591782765746, (1, 0.01, 80): 368.7650305085574, (1, 0.01, 160): 305.40183581765757, (1, 0.1, 10): 340.75941337653876, (1, 0.1, 20): 290.20985571139374, (1, 0.1, 80): 252.79606541319873, (1, 0.1, 160): 241.6760344357054, (1, 1, 10): 251.5332180236784, (1, 1, 20): 239.574269012148, (1, 1, 80): 217.12737101992775, (1, 1, 160): 206.23542125025352, (10, 0.01, 10): 524.75230626377, (10, 0.01, 20): 480.5775261701589, (10, 0.01, 80): 293.9083107979988, (10, 0.01, 160): 177.92442690443676, (10, 0.1, 10): 246.21555182316544, (10, 0.1, 20): 146.90800950662032, (10, 0.1, 80): 116.83611575674327, (10, 0.1, 160): 115.9523586969531, (10, 1, 10): 148.06301994281947, (10, 1, 20): 148.45760333718675, (10, 1, 80): 148.8155910187528, (10, 1, 160): 148.8099294643226, (50, 0.01, 10): 524.2378412945005, (50, 0.01, 20): 479.4286062332617, (50, 0.01, 80): 290.3755285639777, (50, 0.01, 160): 174.01828724761418, (50, 0.1, 10): 242.05522961320523, (50, 0.1, 









{(1, 0.01, 10): 534.1418084837003, (1, 0.01, 20): 499.32591782765746, (1, 0.01, 80): 368.7650305085574, (1, 0.01, 160): 305.40183581765757, (1, 0.1, 10): 340.75941337653876, (1, 0.1, 20): 290.20985571139374, (1, 0.1, 80): 252.79606541319873, (1, 0.1, 160): 241.6760344357054, (1, 1, 10): 251.5332180236784, (1, 1, 20): 239.574269012148, (1, 1, 80): 217.12737101992775, (1, 1, 160): 206.23542125025352, (10, 0.01, 10): 524.75230626377, (10, 0.01, 20): 480.5775261701589, (10, 0.01, 80): 293.9083107979988, (10, 0.01, 160): 177.92442690443676, (10, 0.1, 10): 246.21555182316544, (10, 0.1, 20): 146.90800950662032, (10, 0.1, 80): 116.83611575674327, (10, 0.1, 160): 115.9523586969531, (10, 1, 10): 148.06301994281947, (10, 1, 20): 148.45760333718675, (10, 1, 80): 148.8155910187528, (10, 1, 160): 148.8099294643226, (50, 0.01, 10): 524.2378412945005, (50, 0.01, 20): 479.4286062332617, (50, 0.01, 80): 290.3755285639777, (50, 0.01, 160): 174.01828724761418, (50, 0.1, 10): 242.05522961320523, (50, 0.1, 









{(1, 0.01, 10): 534.1418084837003, (1, 0.01, 20): 499.32591782765746, (1, 0.01, 80): 368.7650305085574, (1, 0.01, 160): 305.40183581765757, (1, 0.1, 10): 340.75941337653876, (1, 0.1, 20): 290.20985571139374, (1, 0.1, 80): 252.79606541319873, (1, 0.1, 160): 241.6760344357054, (1, 1, 10): 251.5332180236784, (1, 1, 20): 239.574269012148, (1, 1, 80): 217.12737101992775, (1, 1, 160): 206.23542125025352, (10, 0.01, 10): 524.75230626377, (10, 0.01, 20): 480.5775261701589, (10, 0.01, 80): 293.9083107979988, (10, 0.01, 160): 177.92442690443676, (10, 0.1, 10): 246.21555182316544, (10, 0.1, 20): 146.90800950662032, (10, 0.1, 80): 116.83611575674327, (10, 0.1, 160): 115.9523586969531, (10, 1, 10): 148.06301994281947, (10, 1, 20): 148.45760333718675, (10, 1, 80): 148.8155910187528, (10, 1, 160): 148.8099294643226, (50, 0.01, 10): 524.2378412945005, (50, 0.01, 20): 479.4286062332617, (50, 0.01, 80): 290.3755285639777, (50, 0.01, 160): 174.01828724761418, (50, 0.1, 10): 242.05522961320523, (50, 0.1, 









{(1, 0.01, 10): 534.1418084837003, (1, 0.01, 20): 499.32591782765746, (1, 0.01, 80): 368.7650305085574, (1, 0.01, 160): 305.40183581765757, (1, 0.1, 10): 340.75941337653876, (1, 0.1, 20): 290.20985571139374, (1, 0.1, 80): 252.79606541319873, (1, 0.1, 160): 241.6760344357054, (1, 1, 10): 251.5332180236784, (1, 1, 20): 239.574269012148, (1, 1, 80): 217.12737101992775, (1, 1, 160): 206.23542125025352, (10, 0.01, 10): 524.75230626377, (10, 0.01, 20): 480.5775261701589, (10, 0.01, 80): 293.9083107979988, (10, 0.01, 160): 177.92442690443676, (10, 0.1, 10): 246.21555182316544, (10, 0.1, 20): 146.90800950662032, (10, 0.1, 80): 116.83611575674327, (10, 0.1, 160): 115.9523586969531, (10, 1, 10): 148.06301994281947, (10, 1, 20): 148.45760333718675, (10, 1, 80): 148.8155910187528, (10, 1, 160): 148.8099294643226, (50, 0.01, 10): 524.2378412945005, (50, 0.01, 20): 479.4286062332617, (50, 0.01, 80): 290.3755285639777, (50, 0.01, 160): 174.01828724761418, (50, 0.1, 10): 242.05522961320523, (50, 0.1, 

In [32]:
np.mean([25.834337552066895,
 22.172174357503465,
 14.436417376275056,
 26.57864143604879,
 23.56558398371505])
 
 

22.51743094112185