In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

import numpy as np
import pandas as pd
import joblib

import os
from tqdm import tqdm
from glob import glob
import matplotlib.pyplot as plt


import pathlib
DATA_DIR = pathlib.Path.cwd()/'data/input'
OUT_DIR = pathlib.Path.cwd()/'data/output'

import sys 
sys.path.append(str(pathlib.Path.cwd()/'utils'))
from utils.misc_utils import fullrange, realized_volatility, log_return, rmspe, get_stock_path, load_parquet_file, load_parquet_files, load_train_test
from utils.feature_engineering_utils import full_feature_engineering, groupby_and_aggregate, generate_features_book_data, generate_features_trade_data, full_feature_engineering_by_cutoff

book_aggregation = {
    'wap1': [np.mean, np.std, fullrange], 
    'wap2': [np.mean, np.std, fullrange], 
    'log_return_1': [fullrange, np.sum, np.mean, realized_volatility], 
    'log_return_2': [fullrange, np.sum, np.mean, realized_volatility], 
    'bid_ask_price_spread_1': [np.mean, np.std, fullrange],
    'bid_ask_price_spread_2': [np.mean, np.std, fullrange],
    'bid_ask_size_spread_1': [np.mean, np.std, fullrange],
    'bid_ask_size_spread_2': [np.mean, np.std, fullrange]
    }

trade_aggregation = {
    'volume': [np.mean, np.sum, np.std], 
    'price': [np.mean, np.std], 
    'order_count': [np.mean, np.sum, np.std]
    }

time_agg_trade = {
    'volume_mean': [np.mean, np.sum, np.std], 
    'price_mean': [np.mean, np.std], 
    'order_count_mean': [np.mean, np.sum, np.std]
    }

time_agg_book = {
    'wap1_std': [np.mean], 
    'wap2_std': [np.mean], 
    'log_return_1_realized_volatility': [np.mean], 
    'log_return_2_realized_volatility': [np.mean], 
    'log_return_1_sum': [np.mean], 
    'log_return_2_sum': [np.mean] 
    }

stock_agg_trade = {
    'volume_mean': [np.mean,np.std], 
    'price_mean': [np.mean, np.std], 
    'order_count_mean': [np.mean, np.std]
    }

stock_agg_book = {
    'wap1_std': [np.std], 
    'wap2_std': [np.std], 
    'log_return_1_realized_volatility': [np.std], 
    'log_return_2_realized_volatility': [np.std], 
    'log_return_1_sum': [np.std], 
    'log_return_2_sum': [np.std] 
    }

pd.set_option('display.max_columns', None)

In [3]:
training_target = pd.read_csv(DATA_DIR/'train.csv')
book = load_parquet_files(stock_ids=[0, 126], file_type='book')
trade = load_parquet_files(stock_ids=[0, 126], file_type='trade')

100%|██████████| 2/2 [00:02<00:00,  1.18s/it]
100%|██████████| 2/2 [00:00<00:00,  5.93it/s]


In [4]:
book = generate_features_book_data(book)
book.head()

Unnamed: 0,time_id,seconds_in_bucket,bid_price1,ask_price1,bid_price2,ask_price2,bid_size1,ask_size1,bid_size2,ask_size2,stock_id,id,wap1,wap2,log_return_1,log_return_2,bid_ask_price_spread_1,bid_ask_price_spread_2,bid_ask_size_spread_1,bid_ask_size_spread_2
0,5,0,1.001422,1.002301,1.00137,1.002353,3,226,2,100,0,0-5,1.001434,1.00139,0.0,0.0,0.000879,0.000983,223,98
1,5,1,1.001422,1.002301,1.00137,1.002353,3,100,2,100,0,0-5,1.001448,1.00139,1.4e-05,0.0,0.000879,0.000983,97,98
2,5,5,1.001422,1.002301,1.00137,1.002405,3,100,2,100,0,0-5,1.001448,1.001391,0.0,1e-06,0.000879,0.001034,97,98
3,5,6,1.001422,1.002301,1.00137,1.002405,3,126,2,100,0,0-5,1.001443,1.001391,-5e-06,0.0,0.000879,0.001034,123,98
4,5,7,1.001422,1.002301,1.00137,1.002405,3,126,2,100,0,0-5,1.001443,1.001391,0.0,0.0,0.000879,0.001034,123,98


In [5]:
agg_book_data = groupby_and_aggregate(book, agg_col = 'id', agg_dict=book_aggregation)
agg_book_data.head()

Unnamed: 0_level_0,wap1_mean,wap1_std,wap1_fullrange,wap2_mean,wap2_std,wap2_fullrange,log_return_1_fullrange,log_return_1_sum,log_return_1_mean,log_return_1_realized_volatility,log_return_2_fullrange,log_return_2_sum,log_return_2_mean,log_return_2_realized_volatility,bid_ask_price_spread_1_mean,bid_ask_price_spread_1_std,bid_ask_price_spread_1_fullrange,bid_ask_price_spread_2_mean,bid_ask_price_spread_2_std,bid_ask_price_spread_2_fullrange,bid_ask_size_spread_1_mean,bid_ask_size_spread_1_std,bid_ask_size_spread_1_fullrange,bid_ask_size_spread_2_mean,bid_ask_size_spread_2_std,bid_ask_size_spread_2_fullrange
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1
0-1000,0.9988,0.000531,0.001829,0.998738,0.000538,0.002868,0.001286,-0.000944,-6e-06,0.001781,0.002234,-0.001198,-7e-06,0.004389,0.000443,0.000173,0.000878,0.000765,0.000249,0.001478,11.695122,171.482507,1116,28.609756,83.79463,599
0-10000,0.999884,0.000395,0.001409,0.999912,0.000488,0.00208,0.001733,0.000546,2e-06,0.00289,0.001641,0.000686,2e-06,0.004111,0.00046,0.000232,0.001022,0.000739,0.000239,0.001314,10.143357,152.601616,1045,-17.912587,130.656572,747
0-10005,1.001301,0.001809,0.00713,1.00116,0.001676,0.007296,0.004685,0.002516,1.4e-05,0.008674,0.006666,0.000928,5e-06,0.013725,0.002107,0.000602,0.003449,0.002743,0.000602,0.003025,-11.918478,103.390706,532,28.766304,107.369872,694
0-10017,0.996141,0.004763,0.0173,0.99603,0.004678,0.01804,0.016086,-0.002971,-1.3e-05,0.017629,0.01851,-0.003512,-1.5e-05,0.021224,0.003538,0.001323,0.007071,0.004959,0.002062,0.010697,60.934211,158.634259,860,-0.675439,151.850436,700
0-10030,0.999464,0.000433,0.001729,0.999477,0.000498,0.002206,0.001503,0.00058,3e-06,0.002551,0.004305,0.003158,1.6e-05,0.005463,0.000623,0.000246,0.001085,0.00104,0.000305,0.001463,33.57732,83.867072,364,6.876289,85.745528,416


In [6]:
trade = generate_features_trade_data(trade)
trade.head()

Unnamed: 0,time_id,seconds_in_bucket,price,size,order_count,stock_id,id,volume
0,5,21,1.002301,326,12,0,0-5,326.750244
1,5,46,1.002778,128,4,0,0-5,128.355591
2,5,50,1.002818,55,1,0,0-5,55.155014
3,5,57,1.003155,121,5,0,0-5,121.381798
4,5,68,1.003646,4,1,0,0-5,4.014584


In [7]:
agg_trade_data =  groupby_and_aggregate(trade, agg_col = 'id', agg_dict=trade_aggregation)
agg_trade_data.head()

Unnamed: 0_level_0,volume_mean,volume_sum,volume_std,price_mean,price_std,order_count_mean,order_count_sum,order_count_std
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
0-1000,29.961796,898.853882,45.738499,0.998908,0.000554,2.133333,64,2.029665
0-10000,80.825371,1939.80896,132.549789,0.999931,0.000373,2.916667,70,2.500725
0-10005,81.59082,2366.133789,119.623756,1.001232,0.001756,3.137931,91,3.502286
0-10017,89.54705,3760.976074,133.678314,0.996374,0.004358,3.166667,133,3.875606
0-10030,98.03035,2940.910645,103.179214,0.999314,0.000493,2.333333,70,1.667816


In [8]:
output_df = agg_book_data.merge(agg_trade_data, left_index=True, right_index=True).reset_index()
output_df['stock_id'] = output_df['id'].apply(lambda x: int(x.split('-')[0]))
output_df['time_id'] = output_df['id'].apply(lambda x: int(x.split('-')[1]))
output_df.head()

Unnamed: 0,id,wap1_mean,wap1_std,wap1_fullrange,wap2_mean,wap2_std,wap2_fullrange,log_return_1_fullrange,log_return_1_sum,log_return_1_mean,log_return_1_realized_volatility,log_return_2_fullrange,log_return_2_sum,log_return_2_mean,log_return_2_realized_volatility,bid_ask_price_spread_1_mean,bid_ask_price_spread_1_std,bid_ask_price_spread_1_fullrange,bid_ask_price_spread_2_mean,bid_ask_price_spread_2_std,bid_ask_price_spread_2_fullrange,bid_ask_size_spread_1_mean,bid_ask_size_spread_1_std,bid_ask_size_spread_1_fullrange,bid_ask_size_spread_2_mean,bid_ask_size_spread_2_std,bid_ask_size_spread_2_fullrange,volume_mean,volume_sum,volume_std,price_mean,price_std,order_count_mean,order_count_sum,order_count_std,stock_id,time_id
0,0-1000,0.9988,0.000531,0.001829,0.998738,0.000538,0.002868,0.001286,-0.000944,-6e-06,0.001781,0.002234,-0.001198,-7e-06,0.004389,0.000443,0.000173,0.000878,0.000765,0.000249,0.001478,11.695122,171.482507,1116,28.609756,83.79463,599,29.961796,898.853882,45.738499,0.998908,0.000554,2.133333,64,2.029665,0,1000
1,0-10000,0.999884,0.000395,0.001409,0.999912,0.000488,0.00208,0.001733,0.000546,2e-06,0.00289,0.001641,0.000686,2e-06,0.004111,0.00046,0.000232,0.001022,0.000739,0.000239,0.001314,10.143357,152.601616,1045,-17.912587,130.656572,747,80.825371,1939.80896,132.549789,0.999931,0.000373,2.916667,70,2.500725,0,10000
2,0-10005,1.001301,0.001809,0.00713,1.00116,0.001676,0.007296,0.004685,0.002516,1.4e-05,0.008674,0.006666,0.000928,5e-06,0.013725,0.002107,0.000602,0.003449,0.002743,0.000602,0.003025,-11.918478,103.390706,532,28.766304,107.369872,694,81.59082,2366.133789,119.623756,1.001232,0.001756,3.137931,91,3.502286,0,10005
3,0-10017,0.996141,0.004763,0.0173,0.99603,0.004678,0.01804,0.016086,-0.002971,-1.3e-05,0.017629,0.01851,-0.003512,-1.5e-05,0.021224,0.003538,0.001323,0.007071,0.004959,0.002062,0.010697,60.934211,158.634259,860,-0.675439,151.850436,700,89.54705,3760.976074,133.678314,0.996374,0.004358,3.166667,133,3.875606,0,10017
4,0-10030,0.999464,0.000433,0.001729,0.999477,0.000498,0.002206,0.001503,0.00058,3e-06,0.002551,0.004305,0.003158,1.6e-05,0.005463,0.000623,0.000246,0.001085,0.00104,0.000305,0.001463,33.57732,83.867072,364,6.876289,85.745528,416,98.03035,2940.910645,103.179214,0.999314,0.000493,2.333333,70,1.667816,0,10030


## Time Period Aggregation

In [9]:
time_agg_trade_data = groupby_and_aggregate(output_df, agg_col = 'time_id', agg_dict=time_agg_trade, suffix='_period')
time_agg_trade_data.head()

Unnamed: 0_level_0,volume_mean_mean_period,volume_mean_sum_period,volume_mean_std_period,price_mean_mean_period,price_mean_std_period,order_count_mean_mean_period,order_count_mean_sum_period,order_count_mean_std_period
time_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
5,63.489216,126.978432,23.001141,1.002913,0.001144,2.669118,5.338235,0.114385
11,75.123787,150.247574,45.460377,1.000473,0.000377,2.709259,5.418519,1.144465
16,109.012894,218.025787,32.058147,0.999548,0.000486,2.744615,5.489231,0.034811
31,81.574921,163.149841,69.389015,0.998861,0.000225,3.05,6.1,1.249222
62,67.259079,134.518158,19.963213,0.999277,0.000483,3.320025,6.640049,1.025913


In [10]:
time_agg_book_data = groupby_and_aggregate(output_df, agg_col = 'time_id', agg_dict=time_agg_book, suffix='_period')
time_agg_book_data.head()

Unnamed: 0_level_0,wap1_std_mean_period,wap2_std_mean_period,log_return_1_realized_volatility_mean_period,log_return_2_realized_volatility_mean_period,log_return_1_sum_mean_period,log_return_2_sum_mean_period
time_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
5,0.000876,0.00099,0.006064,0.009029,0.002922,0.00275
11,0.000503,0.000533,0.003823,0.00454,-0.00206,-0.001627
16,0.000802,0.00084,0.002699,0.00501,-0.002588,-0.002912
31,0.000809,0.000708,0.004144,0.004617,-0.000918,-0.000466
62,0.000413,0.000478,0.003319,0.004168,0.001852,0.001475


In [11]:
time_agg = time_agg_book_data.merge(time_agg_trade_data, left_index=True, right_index=True)
output_df_with_time = output_df.merge(time_agg, on='time_id')
output_df_with_time = output_df_with_time.merge(training_target, on=['stock_id', 'time_id']).head()

In [12]:
train, test = load_train_test()
train_id = train.stock_id.unique()
test_id = test.stock_id.unique()

In [13]:
#final_training_data = full_feature_engineering(stock_ids = [], train=True, training_target= training_target)
#final_test_data = full_feature_engineering(stock_ids = [], train=False)

cutoffs =  [(0,150), (150, 300), (300, 450), (450, 600)]
final_training_data, datalist_train = full_feature_engineering_by_cutoff(
    cutoffs = cutoffs,
    stock_ids = train_id,
    train=True)

final_test_data, datalist_test = full_feature_engineering_by_cutoff(
    cutoffs = cutoffs,
    stock_ids = test_id,
    train=False)

#final_training_data = train.merge(final_training_data, on = ['id', 'stock_id', 'time_id'], how='left')
#final_test_data = test.merge(final_test_data, on = ['id', 'stock_id', 'time_id'], how='left')

final_training_data.to_pickle(OUT_DIR/'final_training_data.pkl')
final_test_data.to_pickle(OUT_DIR/'final_test_data.pkl')


100%|██████████| 2/2 [00:09<00:00,  4.53s/it]
100%|██████████| 2/2 [00:09<00:00,  4.51s/it]
100%|██████████| 2/2 [00:08<00:00,  4.49s/it]
100%|██████████| 2/2 [00:09<00:00,  4.53s/it]
100%|██████████| 1/1 [00:00<00:00, 20.41it/s]
100%|██████████| 1/1 [00:00<00:00, 22.22it/s]
100%|██████████| 1/1 [00:00<00:00, 22.22it/s]
100%|██████████| 1/1 [00:00<00:00, 22.07it/s]


In [None]:
[len(dataset) for dataset in datalist_train]

In [None]:
final_training_data.head()

Unnamed: 0,id,time_id,stock_id,wap1_mean_0_300,wap1_std_0_300,wap1_fullrange_0_300,wap2_mean_0_300,wap2_std_0_300,wap2_fullrange_0_300,log_return_1_fullrange_0_300,log_return_1_sum_0_300,log_return_1_mean_0_300,log_return_1_realized_volatility_0_300,log_return_2_fullrange_0_300,log_return_2_sum_0_300,log_return_2_mean_0_300,log_return_2_realized_volatility_0_300,bid_ask_price_spread_1_mean,bid_ask_price_spread_1_std,bid_ask_price_spread_1_fullrange,bid_ask_price_spread_2_mean,bid_ask_price_spread_2_std,bid_ask_price_spread_2_fullrange,bid_ask_size_spread_1_mean,bid_ask_size_spread_1_std,bid_ask_size_spread_1_fullrange,bid_ask_size_spread_2_mean,bid_ask_size_spread_2_std,bid_ask_size_spread_2_fullrange,volume_mean_0_300,volume_sum_0_300,volume_std_0_300,price_mean_0_300,price_std_0_300,order_count_mean_0_300,order_count_sum_0_300,order_count_std_0_300,wap1_std_mean_period_0_300,wap2_std_mean_period_0_300,log_return_1_realized_volatility_mean_period_0_300,log_return_2_realized_volatility_mean_period_0_300,log_return_1_sum_mean_period_0_300,log_return_2_sum_mean_period_0_300,volume_mean_mean_period_0_300,volume_mean_sum_period_0_300,volume_mean_std_period_0_300,price_mean_mean_period_0_300,price_mean_std_period_0_300,order_count_mean_mean_period_0_300,order_count_mean_sum_period_0_300,order_count_mean_std_period_0_300,wap1_std_std_stock_0_300,wap2_std_std_stock_0_300,log_return_1_realized_volatility_std_stock_0_300,log_return_2_realized_volatility_std_stock_0_300,log_return_1_sum_std_stock_0_300,log_return_2_sum_std_stock_0_300,volume_mean_mean_stock_0_300,volume_mean_std_stock_0_300,price_mean_mean_stock_0_300,price_mean_std_stock_0_300,order_count_mean_mean_stock_0_300,order_count_mean_std_stock_0_300,wap1_mean_300_600,wap1_std_300_600,wap1_fullrange_300_600,wap2_mean_300_600,wap2_std_300_600,wap2_fullrange_300_600,log_return_1_fullrange_300_600,log_return_1_sum_300_600,log_return_1_mean_300_600,log_return_1_realized_volatility_300_600,log_return_2_fullrange_300_600,log_return_2_sum_300_600,log_return_2_mean_300_600,log_return_2_realized_volatility_300_600,bid_ask_price_spread_1_mean.1,bid_ask_price_spread_1_std.1,bid_ask_price_spread_1_fullrange.1,bid_ask_price_spread_2_mean.1,bid_ask_price_spread_2_std.1,bid_ask_price_spread_2_fullrange.1,bid_ask_size_spread_1_mean.1,bid_ask_size_spread_1_std.1,bid_ask_size_spread_1_fullrange.1,bid_ask_size_spread_2_mean.1,bid_ask_size_spread_2_std.1,bid_ask_size_spread_2_fullrange.1,volume_mean_300_600,volume_sum_300_600,volume_std_300_600,price_mean_300_600,price_std_300_600,order_count_mean_300_600,order_count_sum_300_600,order_count_std_300_600,wap1_std_mean_period_300_600,wap2_std_mean_period_300_600,log_return_1_realized_volatility_mean_period_300_600,log_return_2_realized_volatility_mean_period_300_600,log_return_1_sum_mean_period_300_600,log_return_2_sum_mean_period_300_600,volume_mean_mean_period_300_600,volume_mean_sum_period_300_600,volume_mean_std_period_300_600,price_mean_mean_period_300_600,price_mean_std_period_300_600,order_count_mean_mean_period_300_600,order_count_mean_sum_period_300_600,order_count_mean_std_period_300_600,wap1_std_std_stock_300_600,wap2_std_std_stock_300_600,log_return_1_realized_volatility_std_stock_300_600,log_return_2_realized_volatility_std_stock_300_600,log_return_1_sum_std_stock_300_600,log_return_2_sum_std_stock_300_600,volume_mean_mean_stock_300_600,volume_mean_std_stock_300_600,price_mean_mean_stock_300_600,price_mean_std_stock_300_600,order_count_mean_mean_stock_300_600,order_count_mean_std_stock_300_600,target
0,0-1000,1000,0,0.999208,0.000427,0.001128,0.999107,0.000398,0.001455,0.000991,-0.001171,-1.5e-05,0.001184,0.001793,-0.001089,-1.4e-05,0.002684,0.000422,0.000136,0.000601,0.000744,0.000202,0.000785,58.772152,129.514572,482,42.506329,91.762496,599,26.356178,421.698853,37.846413,0.999351,0.000284,2.125,34.0,1.408309,0.000367,0.000402,0.001448,0.002109,-0.000198,-0.000207,334.647034,37480.46875,1207.417847,1.00015,0.000896,4.118892,461.315938,2.6526,0.000716,0.000787,0.002897,0.004439,0.002512,0.002677,106.565811,63.514313,1.000043,0.003232,3.328446,1.267156,0.998406,0.000235,0.000974,0.99838,0.000386,0.001953,0.001077,-0.000194,-2e-06,0.001263,0.002234,-0.000519,-6e-06,0.003449,0.000465,0.000201,0.000832,0.000785,0.000288,0.001478,-30.261905,194.006031,1013,16.119048,74.202492,230,34.0825,477.155029,54.587475,0.998402,0.000272,2.142857,30.0,2.626994,0.000398,0.000436,0.00151,0.002207,-0.000249,-0.000236,246.684113,27381.9375,374.158386,0.999757,0.001217,4.079591,452.83458,2.69336,0.000599,0.000642,0.002406,0.003411,0.002105,0.002143,101.976982,58.440357,1.000028,0.00384,3.243385,1.198809,0.001348
1,0-10000,10000,0,0.999706,0.00031,0.001236,0.999732,0.000453,0.00208,0.001075,0.000374,2e-06,0.001921,0.001641,0.000894,5e-06,0.003385,0.000404,0.000213,0.000924,0.000731,0.000252,0.001314,48.048913,87.931747,597,2.88587,114.083868,572,78.393791,1332.694458,148.739655,0.999786,0.000297,2.705882,46.0,2.823223,0.000512,0.000542,0.001978,0.002608,-0.000285,-0.000286,409.834778,45901.496094,1277.457642,0.999976,0.001214,3.842042,430.308752,1.719662,0.000716,0.000787,0.002897,0.004439,0.002512,0.002677,106.565811,63.514313,1.000043,0.003232,3.328446,1.267156,1.000209,0.000323,0.001329,1.00024,0.000366,0.001359,0.001733,-0.000221,-2e-06,0.002124,0.001253,-0.000695,-7e-06,0.002281,0.000565,0.000231,0.001022,0.000754,0.000215,0.00107,-58.267327,211.674509,944,-55.980198,150.240739,644,86.730652,607.114563,91.083176,1.000283,0.000305,3.428571,24.0,1.511858,0.000525,0.000561,0.001795,0.002445,3e-05,-4e-06,306.332672,34309.257812,528.679932,1.000028,0.001713,3.587736,401.826426,1.137513,0.000599,0.000642,0.002406,0.003411,0.002105,0.002143,101.976982,58.440357,1.000028,0.00384,3.243385,1.198809,0.001805
2,0-10005,10005,0,1.000671,0.002085,0.00593,1.000502,0.001673,0.006349,0.004657,0.001638,1.7e-05,0.006434,0.005195,0.0026,2.7e-05,0.009354,0.002086,0.000579,0.002866,0.002667,0.000702,0.003025,-26.105263,94.669354,440,26.957895,105.811245,648,82.635696,1404.806885,118.957565,1.000655,0.001913,2.941176,50.0,2.703484,0.001173,0.001215,0.004297,0.006057,0.00187,0.001768,282.409454,31629.859375,379.877075,1.001123,0.002449,3.6901,413.291207,1.693999,0.000716,0.000787,0.002897,0.004439,0.002512,0.002677,106.565811,63.514313,1.000043,0.003232,3.328446,1.267156,1.001995,0.001117,0.004183,1.001894,0.001349,0.004966,0.004055,0.000734,8e-06,0.005815,0.006499,-0.000731,-8e-06,0.009999,0.002141,0.000622,0.003237,0.002833,0.000455,0.0026,3.261364,111.118089,524,30.511364,110.190013,497,80.11058,961.326965,125.848511,1.00205,0.001133,3.416667,41.0,4.521833,0.00103,0.001076,0.004125,0.005606,0.000172,0.000239,294.831665,33021.148438,603.385681,1.001939,0.003773,3.525943,394.905618,1.401347,0.000599,0.000642,0.002406,0.003411,0.002105,0.002143,101.976982,58.440357,1.000028,0.00384,3.243385,1.198809,0.007544
3,0-10017,10017,0,0.992769,0.00263,0.012138,0.99292,0.002265,0.01245,0.007942,0.010507,7.9e-05,0.008735,0.011743,0.004862,3.7e-05,0.012851,0.003482,0.001295,0.00695,0.004393,0.002236,0.010697,83.639098,140.409592,784,3.654135,118.318717,699,72.568245,1596.501465,144.283875,0.992861,0.002041,2.454545,54.0,3.158168,0.003173,0.003296,0.010398,0.013775,0.004919,0.00475,348.265839,38657.507812,273.327026,0.997095,0.006774,4.649639,516.109974,2.981141,0.000716,0.000787,0.002897,0.004439,0.002512,0.002677,106.565811,63.514313,1.000043,0.003232,3.328446,1.267156,1.000952,0.002397,0.012924,1.000474,0.003488,0.013315,0.010333,-0.003168,-3.4e-05,0.011322,0.014066,0.000668,7e-06,0.014268,0.003624,0.001371,0.00689,0.005771,0.001465,0.007434,28.680851,177.907806,860,-7.861702,190.333584,647,108.223732,2164.474609,121.86541,1.00024,0.002523,3.95,79.0,4.489461,0.003545,0.003684,0.011788,0.015887,-0.003682,-0.003748,394.402893,43778.71875,387.723358,0.998773,0.00889,5.00763,555.846893,3.872129,0.000599,0.000642,0.002406,0.003411,0.002105,0.002143,101.976982,58.440357,1.000028,0.00384,3.243385,1.198809,0.011218
4,0-10030,10030,0,0.999657,0.000297,0.001556,0.999645,0.000328,0.001726,0.001021,-0.001273,-1.1e-05,0.001535,0.002165,-0.001844,-1.6e-05,0.003117,0.000597,0.000227,0.000849,0.001013,0.000327,0.001368,33.034188,87.391404,364,25.316239,88.005158,392,105.400513,1581.00769,107.743195,0.999531,0.000467,2.666667,40.0,2.058663,0.000569,0.000616,0.002002,0.002809,-0.000814,-0.00083,256.73111,28753.882812,350.382935,0.999946,0.00139,3.22984,361.742054,1.232023,0.000716,0.000787,0.002897,0.004439,0.002512,0.002677,106.565811,63.514313,1.000043,0.003232,3.328446,1.267156,0.999159,0.000435,0.001729,0.999203,0.00058,0.002206,0.001503,0.001363,1.8e-05,0.001977,0.001942,0.001874,2.5e-05,0.003216,0.000664,0.00027,0.000991,0.00108,0.000268,0.00118,34.855263,79.186986,327,-20.421053,74.592272,342,89.993896,1259.914551,105.425652,0.999041,0.000386,2.0,28.0,1.176697,0.000514,0.000558,0.001781,0.002661,0.000324,0.000375,262.555603,29406.226562,437.999481,0.999672,0.001687,3.182764,356.469517,1.436575,0.000599,0.000642,0.002406,0.003411,0.002105,0.002143,101.976982,58.440357,1.000028,0.00384,3.243385,1.198809,0.002854


In [None]:
final_training_data.isna().sum()

id                                     0
time_id                                0
stock_id                               0
wap1_mean_0_300                        0
wap1_std_0_300                         0
                                      ..
price_mean_mean_stock_300_600          0
price_mean_std_stock_300_600           0
order_count_mean_mean_stock_300_600    0
order_count_mean_std_stock_300_600     0
target                                 0
Length: 124, dtype: int64

## Model Training

In [None]:
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import make_scorer, r2_score, mean_absolute_error
from sklearn.model_selection import train_test_split, GridSearchCV, RandomizedSearchCV, StratifiedKFold
from sklearn.neural_network import MLPRegressor
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder
from sklearn.svm import SVR
from sklearn.ensemble import RandomForestRegressor

from xgboost import XGBRegressor
from skopt import BayesSearchCV
from skopt.space import Real, Categorical, Integer
from skopt.plots import plot_objective, plot_histogram, plot_convergence

rmspe_scorer = make_scorer(rmspe, greater_is_better=False)

In [None]:
final_training_data = pd.read_pickle(OUT_DIR/'final_training_data.pkl')
final_test_data = pd.read_pickle(OUT_DIR/'final_test_data.pkl')
#final_training_data = final_training_data.dropna(axis=1)

In [None]:
from sklearn.utils import resample

model_col = [col for col in final_training_data.columns if ('id' not in col) & ('target' not in col)]

X_train, X_test, y_train, y_test = train_test_split(
                                        final_training_data.drop('target', axis=1)[model_col],
                                        final_training_data['target'],
                                        test_size=0.1
                                        )

X_train_sample = resample(X_train, replace=False, n_samples=5000)
y_train_sample = y_train.loc[X_train_sample.index]

In [None]:
final_training_data

Unnamed: 0,id,time_id,stock_id,wap1_mean_0_300,wap1_std_0_300,wap1_fullrange_0_300,wap2_mean_0_300,wap2_std_0_300,wap2_fullrange_0_300,log_return_1_fullrange_0_300,log_return_1_sum_0_300,log_return_1_mean_0_300,log_return_1_realized_volatility_0_300,log_return_2_fullrange_0_300,log_return_2_sum_0_300,log_return_2_mean_0_300,log_return_2_realized_volatility_0_300,bid_ask_price_spread_1_mean,bid_ask_price_spread_1_std,bid_ask_price_spread_1_fullrange,bid_ask_price_spread_2_mean,bid_ask_price_spread_2_std,bid_ask_price_spread_2_fullrange,bid_ask_size_spread_1_mean,bid_ask_size_spread_1_std,bid_ask_size_spread_1_fullrange,bid_ask_size_spread_2_mean,bid_ask_size_spread_2_std,bid_ask_size_spread_2_fullrange,volume_mean_0_300,volume_sum_0_300,volume_std_0_300,price_mean_0_300,price_std_0_300,order_count_mean_0_300,order_count_sum_0_300,order_count_std_0_300,wap1_std_mean_period_0_300,wap2_std_mean_period_0_300,log_return_1_realized_volatility_mean_period_0_300,log_return_2_realized_volatility_mean_period_0_300,log_return_1_sum_mean_period_0_300,log_return_2_sum_mean_period_0_300,volume_mean_mean_period_0_300,volume_mean_sum_period_0_300,volume_mean_std_period_0_300,price_mean_mean_period_0_300,price_mean_std_period_0_300,order_count_mean_mean_period_0_300,order_count_mean_sum_period_0_300,order_count_mean_std_period_0_300,wap1_std_std_stock_0_300,wap2_std_std_stock_0_300,log_return_1_realized_volatility_std_stock_0_300,log_return_2_realized_volatility_std_stock_0_300,log_return_1_sum_std_stock_0_300,log_return_2_sum_std_stock_0_300,volume_mean_mean_stock_0_300,volume_mean_std_stock_0_300,price_mean_mean_stock_0_300,price_mean_std_stock_0_300,order_count_mean_mean_stock_0_300,order_count_mean_std_stock_0_300,wap1_mean_300_600,wap1_std_300_600,wap1_fullrange_300_600,wap2_mean_300_600,wap2_std_300_600,wap2_fullrange_300_600,log_return_1_fullrange_300_600,log_return_1_sum_300_600,log_return_1_mean_300_600,log_return_1_realized_volatility_300_600,log_return_2_fullrange_300_600,log_return_2_sum_300_600,log_return_2_mean_300_600,log_return_2_realized_volatility_300_600,bid_ask_price_spread_1_mean.1,bid_ask_price_spread_1_std.1,bid_ask_price_spread_1_fullrange.1,bid_ask_price_spread_2_mean.1,bid_ask_price_spread_2_std.1,bid_ask_price_spread_2_fullrange.1,bid_ask_size_spread_1_mean.1,bid_ask_size_spread_1_std.1,bid_ask_size_spread_1_fullrange.1,bid_ask_size_spread_2_mean.1,bid_ask_size_spread_2_std.1,bid_ask_size_spread_2_fullrange.1,volume_mean_300_600,volume_sum_300_600,volume_std_300_600,price_mean_300_600,price_std_300_600,order_count_mean_300_600,order_count_sum_300_600,order_count_std_300_600,wap1_std_mean_period_300_600,wap2_std_mean_period_300_600,log_return_1_realized_volatility_mean_period_300_600,log_return_2_realized_volatility_mean_period_300_600,log_return_1_sum_mean_period_300_600,log_return_2_sum_mean_period_300_600,volume_mean_mean_period_300_600,volume_mean_sum_period_300_600,volume_mean_std_period_300_600,price_mean_mean_period_300_600,price_mean_std_period_300_600,order_count_mean_mean_period_300_600,order_count_mean_sum_period_300_600,order_count_mean_std_period_300_600,wap1_std_std_stock_300_600,wap2_std_std_stock_300_600,log_return_1_realized_volatility_std_stock_300_600,log_return_2_realized_volatility_std_stock_300_600,log_return_1_sum_std_stock_300_600,log_return_2_sum_std_stock_300_600,volume_mean_mean_stock_300_600,volume_mean_std_stock_300_600,price_mean_mean_stock_300_600,price_mean_std_stock_300_600,order_count_mean_mean_stock_300_600,order_count_mean_std_stock_300_600,target
0,0-1000,1000,0,0.999208,0.000427,0.001128,0.999107,0.000398,0.001455,0.000991,-0.001171,-0.000015,0.001184,0.001793,-0.001089,-0.000014,0.002684,0.000422,0.000136,0.000601,0.000744,0.000202,0.000785,58.772152,129.514572,482,42.506329,91.762496,599,26.356178,421.698853,37.846413,0.999351,0.000284,2.125000,34.0,1.408309,0.000367,0.000402,0.001448,0.002109,-0.000198,-0.000207,334.647034,37480.468750,1207.417847,1.000150,0.000896,4.118892,461.315938,2.652600,0.000716,0.000787,0.002897,0.004439,0.002512,0.002677,106.565811,63.514313,1.000043,0.003232,3.328446,1.267156,0.998406,0.000235,0.000974,0.998380,0.000386,0.001953,0.001077,-0.000194,-0.000002,0.001263,0.002234,-0.000519,-6.182011e-06,0.003449,0.000465,0.000201,0.000832,0.000785,0.000288,0.001478,-30.261905,194.006031,1013,16.119048,74.202492,230,34.082500,477.155029,54.587475,0.998402,0.000272,2.142857,30.0,2.626994,0.000398,0.000436,0.001510,0.002207,-0.000249,-0.000236,246.684113,27381.937500,374.158386,0.999757,0.001217,4.079591,452.834580,2.693360,0.000599,0.000642,0.002406,0.003411,0.002105,0.002143,101.976982,58.440357,1.000028,0.003840,3.243385,1.198809,0.001348
1,0-10000,10000,0,0.999706,0.000310,0.001236,0.999732,0.000453,0.002080,0.001075,0.000374,0.000002,0.001921,0.001641,0.000894,0.000005,0.003385,0.000404,0.000213,0.000924,0.000731,0.000252,0.001314,48.048913,87.931747,597,2.885870,114.083868,572,78.393791,1332.694458,148.739655,0.999786,0.000297,2.705882,46.0,2.823223,0.000512,0.000542,0.001978,0.002608,-0.000285,-0.000286,409.834778,45901.496094,1277.457642,0.999976,0.001214,3.842042,430.308752,1.719662,0.000716,0.000787,0.002897,0.004439,0.002512,0.002677,106.565811,63.514313,1.000043,0.003232,3.328446,1.267156,1.000209,0.000323,0.001329,1.000240,0.000366,0.001359,0.001733,-0.000221,-0.000002,0.002124,0.001253,-0.000695,-6.885418e-06,0.002281,0.000565,0.000231,0.001022,0.000754,0.000215,0.001070,-58.267327,211.674509,944,-55.980198,150.240739,644,86.730652,607.114563,91.083176,1.000283,0.000305,3.428571,24.0,1.511858,0.000525,0.000561,0.001795,0.002445,0.000030,-0.000004,306.332672,34309.257812,528.679932,1.000028,0.001713,3.587736,401.826426,1.137513,0.000599,0.000642,0.002406,0.003411,0.002105,0.002143,101.976982,58.440357,1.000028,0.003840,3.243385,1.198809,0.001805
2,0-10005,10005,0,1.000671,0.002085,0.005930,1.000502,0.001673,0.006349,0.004657,0.001638,0.000017,0.006434,0.005195,0.002600,0.000027,0.009354,0.002086,0.000579,0.002866,0.002667,0.000702,0.003025,-26.105263,94.669354,440,26.957895,105.811245,648,82.635696,1404.806885,118.957565,1.000655,0.001913,2.941176,50.0,2.703484,0.001173,0.001215,0.004297,0.006057,0.001870,0.001768,282.409454,31629.859375,379.877075,1.001123,0.002449,3.690100,413.291207,1.693999,0.000716,0.000787,0.002897,0.004439,0.002512,0.002677,106.565811,63.514313,1.000043,0.003232,3.328446,1.267156,1.001995,0.001117,0.004183,1.001894,0.001349,0.004966,0.004055,0.000734,0.000008,0.005815,0.006499,-0.000731,-8.306375e-06,0.009999,0.002141,0.000622,0.003237,0.002833,0.000455,0.002600,3.261364,111.118089,524,30.511364,110.190013,497,80.110580,961.326965,125.848511,1.002050,0.001133,3.416667,41.0,4.521833,0.001030,0.001076,0.004125,0.005606,0.000172,0.000239,294.831665,33021.148438,603.385681,1.001939,0.003773,3.525943,394.905618,1.401347,0.000599,0.000642,0.002406,0.003411,0.002105,0.002143,101.976982,58.440357,1.000028,0.003840,3.243385,1.198809,0.007544
3,0-10017,10017,0,0.992769,0.002630,0.012138,0.992920,0.002265,0.012450,0.007942,0.010507,0.000079,0.008735,0.011743,0.004862,0.000037,0.012851,0.003482,0.001295,0.006950,0.004393,0.002236,0.010697,83.639098,140.409592,784,3.654135,118.318717,699,72.568245,1596.501465,144.283875,0.992861,0.002041,2.454545,54.0,3.158168,0.003173,0.003296,0.010398,0.013775,0.004919,0.004750,348.265839,38657.507812,273.327026,0.997095,0.006774,4.649639,516.109974,2.981141,0.000716,0.000787,0.002897,0.004439,0.002512,0.002677,106.565811,63.514313,1.000043,0.003232,3.328446,1.267156,1.000952,0.002397,0.012924,1.000474,0.003488,0.013315,0.010333,-0.003168,-0.000034,0.011322,0.014066,0.000668,7.101216e-06,0.014268,0.003624,0.001371,0.006890,0.005771,0.001465,0.007434,28.680851,177.907806,860,-7.861702,190.333584,647,108.223732,2164.474609,121.865410,1.000240,0.002523,3.950000,79.0,4.489461,0.003545,0.003684,0.011788,0.015887,-0.003682,-0.003748,394.402893,43778.718750,387.723358,0.998773,0.008890,5.007630,555.846893,3.872129,0.000599,0.000642,0.002406,0.003411,0.002105,0.002143,101.976982,58.440357,1.000028,0.003840,3.243385,1.198809,0.011218
4,0-10030,10030,0,0.999657,0.000297,0.001556,0.999645,0.000328,0.001726,0.001021,-0.001273,-0.000011,0.001535,0.002165,-0.001844,-0.000016,0.003117,0.000597,0.000227,0.000849,0.001013,0.000327,0.001368,33.034188,87.391404,364,25.316239,88.005158,392,105.400513,1581.007690,107.743195,0.999531,0.000467,2.666667,40.0,2.058663,0.000569,0.000616,0.002002,0.002809,-0.000814,-0.000830,256.731110,28753.882812,350.382935,0.999946,0.001390,3.229840,361.742054,1.232023,0.000716,0.000787,0.002897,0.004439,0.002512,0.002677,106.565811,63.514313,1.000043,0.003232,3.328446,1.267156,0.999159,0.000435,0.001729,0.999203,0.000580,0.002206,0.001503,0.001363,0.000018,0.001977,0.001942,0.001874,2.465259e-05,0.003216,0.000664,0.000270,0.000991,0.001080,0.000268,0.001180,34.855263,79.186986,327,-20.421053,74.592272,342,89.993896,1259.914551,105.425652,0.999041,0.000386,2.000000,28.0,1.176697,0.000514,0.000558,0.001781,0.002661,0.000324,0.000375,262.555603,29406.226562,437.999481,0.999672,0.001687,3.182764,356.469517,1.436575,0.000599,0.000642,0.002406,0.003411,0.002105,0.002143,101.976982,58.440357,1.000028,0.003840,3.243385,1.198809,0.002854
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
428927,126-9972,9972,126,1.000025,0.000656,0.002407,1.000198,0.000643,0.002575,0.002501,-0.000912,-0.000008,0.003510,0.003024,-0.000795,-0.000007,0.004716,0.000961,0.000302,0.001463,0.001270,0.000299,0.001312,-19.562500,137.126415,821,-50.589286,105.874527,496,74.560196,1267.523315,96.122559,0.999743,0.000585,2.705882,46.0,2.172691,0.000597,0.000621,0.002288,0.003022,-0.000287,-0.000367,256.068604,28679.683594,447.448425,0.999685,0.001138,3.651820,409.003883,1.810729,0.000793,0.000821,0.003229,0.004550,0.002868,0.002951,124.057465,68.361176,1.000002,0.003331,3.408063,1.172630,1.001136,0.000735,0.002995,1.001074,0.000710,0.003128,0.002737,0.000545,0.000006,0.003494,0.003292,0.000642,6.614236e-06,0.005733,0.000941,0.000281,0.000958,0.001400,0.000334,0.001564,-112.969072,192.365508,823,-28.907216,109.716688,561,126.885010,2030.160156,174.859329,1.001057,0.000587,3.375000,54.0,2.941088,0.000606,0.000655,0.002299,0.003110,0.000692,0.000784,388.994354,43567.367188,1754.175049,1.000015,0.001608,3.731581,417.937060,2.977252,0.000691,0.000713,0.002588,0.003539,0.002729,0.002849,120.838669,66.311111,0.999995,0.004342,3.355169,1.187193,0.004435
428928,126-9973,9973,126,1.003724,0.002118,0.008004,1.003531,0.002121,0.007998,0.005079,0.000891,0.000007,0.008098,0.005308,0.000265,0.000002,0.010492,0.002460,0.000881,0.003742,0.003097,0.000945,0.003742,-40.565574,122.747674,681,-26.942623,115.981496,600,104.858994,2726.333984,130.432297,1.003577,0.001598,2.538462,66.0,2.082897,0.002115,0.002269,0.007771,0.010909,-0.002787,-0.002751,312.977570,35053.488281,325.233734,0.999877,0.004160,4.337428,485.791956,1.828659,0.000793,0.000821,0.003229,0.004550,0.002868,0.002951,124.057465,68.361176,1.000002,0.003331,3.408063,1.172630,0.999369,0.002965,0.011018,0.998545,0.003273,0.011724,0.009330,-0.007783,-0.000051,0.012827,0.012139,-0.008934,-5.801068e-05,0.017765,0.002887,0.001264,0.005225,0.003689,0.001456,0.006001,-52.000000,114.905474,675,20.785714,99.577471,600,256.054199,4608.975586,292.445374,0.999839,0.003171,5.000000,90.0,4.934870,0.002180,0.002312,0.007133,0.010104,-0.004159,-0.004398,296.141144,33167.808594,336.651794,0.995722,0.006563,4.246012,475.553306,1.901205,0.000691,0.000713,0.002588,0.003539,0.002729,0.002849,120.838669,66.311111,0.999995,0.004342,3.355169,1.187193,0.010697
428929,126-9976,9976,126,1.006518,0.001554,0.006587,1.006516,0.001551,0.007799,0.006725,-0.002616,-0.000018,0.010820,0.008707,-0.001864,-0.000013,0.015023,0.003288,0.000654,0.003227,0.003816,0.000675,0.003824,-31.204225,108.329190,798,-6.577465,92.584125,616,242.553970,2182.985840,213.524887,1.006290,0.001299,5.111111,46.0,4.136558,0.002427,0.002520,0.008846,0.011893,-0.001005,-0.000625,452.170166,50643.058594,613.799377,0.998689,0.010483,5.314826,595.260539,1.819623,0.000793,0.000821,0.003229,0.004550,0.002868,0.002951,124.057465,68.361176,1.000002,0.003331,3.408063,1.172630,1.006438,0.001213,0.005695,1.006081,0.001582,0.006279,0.005747,0.004258,0.000032,0.009904,0.006442,0.004476,3.315644e-05,0.012886,0.002812,0.000829,0.003585,0.003488,0.000855,0.003944,-41.133333,162.859329,1082,-73.711111,332.768058,1211,59.821888,598.218872,80.918709,1.006508,0.000884,1.400000,14.0,0.516398,0.001725,0.001770,0.006679,0.008431,-0.000405,-0.000546,399.418182,44734.835938,729.220276,0.998365,0.011334,4.433914,496.598403,1.558667,0.000691,0.000713,0.002588,0.003539,0.002729,0.002849,120.838669,66.311111,0.999995,0.004342,3.355169,1.187193,0.008868
428930,126-9988,9988,126,1.000654,0.000458,0.001941,1.000710,0.000644,0.002103,0.001509,-0.000182,-0.000002,0.002249,0.002590,0.000258,0.000002,0.003870,0.000986,0.000242,0.001090,0.001343,0.000397,0.001652,82.372727,164.408300,817,40.790909,133.050249,701,124.989098,999.912781,163.734283,1.001124,0.000356,3.875000,31.0,4.580627,0.000406,0.000440,0.001777,0.002450,0.000180,0.000131,240.407089,26925.593750,270.857758,1.000677,0.000949,3.866301,433.025687,1.603631,0.000793,0.000821,0.003229,0.004550,0.002868,0.002951,124.057465,68.361176,1.000002,0.003331,3.408063,1.172630,1.000620,0.000500,0.001754,1.000609,0.000495,0.001879,0.001826,0.000173,0.000001,0.002584,0.001915,0.000101,7.992957e-07,0.003581,0.000710,0.000166,0.000793,0.000988,0.000198,0.001024,38.714286,116.794168,642,18.833333,126.357699,715,95.740288,1819.065430,128.847351,1.000756,0.000500,3.894737,74.0,4.605489,0.000396,0.000435,0.001671,0.002313,0.000114,0.000102,226.638443,25383.505859,242.282501,1.000716,0.001445,3.648706,408.655086,1.472526,0.000691,0.000713,0.002588,0.003539,0.002729,0.002849,120.838669,66.311111,0.999995,0.004342,3.355169,1.187193,0.004089


In [None]:
model_pipeline = Pipeline(
    steps=[
        ('scaler',   StandardScaler()),
        ('xgboost',   XGBRegressor())
        ]
    )

In [None]:
xgb = XGBRegressor()
xgb.fit(X_train, y_train)
rmspe(y_test, xgb.predict(X_test))

0.23809433938933214

In [None]:
xg_grid = {'learning_rate':Real(0.01, 0.4, prior='uniform'),
        'gamma': Real(0, 10, prior='uniform'),
        'reg_alpha':  Real(0, 10, prior='uniform'),
        'colsample_bytree': Real(0.6, 1, prior='uniform'),
        'max_depth':  Integer(5,18), 
        'n_estimators':  Integer(5,250)
        }

cv_xg = BayesSearchCV(
    XGBRegressor(),
    xg_grid,
    n_iter=100,
    random_state=123, 
    verbose=0, 
    cv=3, 
    refit=True, 
    return_train_score=True
    )

cv_xg.fit(X_train, y_train)
joblib.dump(cv_xg, OUT_DIR/'models/xgboost.pkl')

KeyboardInterrupt: 

In [None]:
rmspe(y_test, cv_xg.best_estimator_.predict(X_test))

In [None]:
rmspe(y_test, xgb.predict(X_test))

In [None]:
cv_xg = joblib.load(OUT_DIR/'models/xgboost.pkl')

In [None]:
cv_xg.best_params_

In [None]:
xg_tuned = XGBRegressor(
    gamma = 0, 
    learning_rate = 0.148727, 
    max_depth = 12,
    n_estimators = 191,
    reg_alpha = 3.53674,
    subsample = 1
    )

#xg_tuned = XGBRegressor(**cv_xg.best_params_)
xg_tuned.fit(X_train, y_train)

In [None]:
rmspe(y_train, xg_tuned.predict(X_train))

In [None]:
rmspe(y_test, xg_tuned.predict(X_test))

In [None]:
submission = final_test_data[['id']].rename(columns = {'id': 'row_id'})
submission['target'] = cv_xg.best_estimator_.predict(final_test_data[model_col])

In [None]:
submission.info()

In [None]:
submission = final_test_data[['id']].rename(columns = {'id': 'row_id'})
submission['target'] = xg_tuned.predict(final_test_data[model_col])
submission.to_csv('submission.csv',index = False)

In [None]:
submission_sample = pd.read_csv(DATA_DIR/'sample_submission.csv')
submission_sample.info()