In [1]:
import os
import glob
from joblib import Parallel, delayed
import pandas as pd
import numpy as np
import scipy as sc
from sklearn.model_selection import KFold
import lightgbm as lgb
import warnings
# Silence all warnings to keep the notebook output compact.
# NOTE(review): this also hides pandas FutureWarning / DeprecationWarning messages.
warnings.filterwarnings('ignore')
# Use the fully qualified option name: the short pattern 'max_columns' also
# matches 'styler.render.max_columns' on pandas >= 1.4 and raises OptionError.
pd.set_option('display.max_columns', 300)

In [2]:
# Data directory (relative path, resolved against the current working directory).
# Keep the trailing slash: file paths are built from it by plain string concatenation.
data_dir = '../../../data/'

# Function to calculate first WAP
def calc_wap1(df):
    """Return the level-1 weighted average price (WAP).

    Each price is weighted by the size quoted on the opposite side:
    (bid_price1 * ask_size1 + ask_price1 * bid_size1) / (bid_size1 + ask_size1).
    ``df`` must provide those four columns.
    """
    cross_weighted = df['bid_price1'] * df['ask_size1'] + df['ask_price1'] * df['bid_size1']
    depth = df['bid_size1'] + df['ask_size1']
    return cross_weighted / depth

# Function to calculate second WAP
def calc_wap2(df):
    """Return the level-2 weighted average price (WAP).

    Each price is weighted by the size quoted on the opposite side:
    (bid_price2 * ask_size2 + ask_price2 * bid_size2) / (bid_size2 + ask_size2).
    ``df`` must provide those four columns.
    """
    cross_weighted = df['bid_price2'] * df['ask_size2'] + df['ask_price2'] * df['bid_size2']
    depth = df['bid_size2'] + df['ask_size2']
    return cross_weighted / depth

# Function to calculate third WAP
def calc_wap3(df):
    """Return the weighted average price pooled over book levels 1 and 2.

    The numerator sums each price times the opposite side's size for both
    levels; the denominator is the total size across all four quotes.
    """
    numerator = (
        df['bid_price1'] * df['ask_size1']
        + df['ask_price1'] * df['bid_size1']
        + df['bid_price2'] * df['ask_size2']
        + df['ask_price2'] * df['bid_size2']
    )
    denominator = df['ask_size1'] + df['bid_size1'] + df['ask_size2'] + df['bid_size2']
    return numerator / denominator

# Function to calculate the log of the return
# Remember that logb(x / y) = logb(x) - logb(y)
def log_return(series):
    """Return the first difference of the natural log of ``series``.

    Element i equals log(series[i] / series[i - 1]); the first element has no
    predecessor and comes back as NaN.  ``series`` must yield an object with a
    ``.diff()`` method after ``np.log`` (e.g. a pandas Series).
    """
    return np.log(series).diff()


# Calculate the realized volatility
def realized_volatility(series):
    """Return the square root of the sum of the squared values of ``series``."""
    return np.sqrt(np.sum(series**2))

# Function to count unique elements of a series
def count_unique(series):
    """Return how many distinct values ``series`` contains."""
    distinct_values = np.unique(series)
    return distinct_values.size

# Function to read our base train and test set
def read_train_test():
    """Load ``train.csv`` and ``test.csv`` from ``data_dir`` and add a ``row_id`` key.

    ``row_id`` is ``"<stock_id>-<time_id>"`` and is the key used to merge with
    the book and trade features.  Prints the number of training rows.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(train, test)``.
    """
    train = pd.read_csv(data_dir + 'train.csv')
    test = pd.read_csv(data_dir + 'test.csv')
    # Both frames get the same merge key, built the same way
    for frame in (train, test):
        frame['row_id'] = frame['stock_id'].astype(str) + '-' + frame['time_id'].astype(str)
    print(f'Our training set has {train.shape[0]} rows')
    return train, test

def book_preprocessor(file_path):
    """Build per-``time_id`` aggregate features from one stock's order-book data.

    Parameters
    ----------
    file_path : str
        Path to a parquet partition whose name contains ``=<stock_id>``
        (e.g. ``.../book_train.parquet/stock_id=0``).  The text after the last
        ``=`` is used as the stock id when building ``row_id``.

    Returns
    -------
    pandas.DataFrame
        One row per ``time_id``.  Columns are named ``<column>_<stat>`` for the
        whole bucket, plus ``_450`` / ``_300`` / ``_150`` suffixed versions
        computed only on rows with ``seconds_in_bucket`` >= that offset, and a
        ``row_id`` column formatted ``<stock_id>-<time_id>``.
    """
    df = pd.read_parquet(file_path)
    # Calculate Wap
    df['wap1'] = calc_wap1(df)
    df['wap2'] = calc_wap2(df)
    df['wap3'] = calc_wap3(df)
    # Calculate log returns within each time_id.  transform() keeps the result
    # aligned with df's index; groupby(...).apply(...) prepends the group key to
    # the index on pandas >= 2.0, which makes the column assignment fail.
    df['log_return1'] = df.groupby(['time_id'])['wap1'].transform(log_return)
    df['log_return2'] = df.groupby(['time_id'])['wap2'].transform(log_return)
    df['log_return3'] = df.groupby(['time_id'])['wap3'].transform(log_return)
    # Calculate wap balance
    df['wap_balance12'] = abs(df['wap1'] - df['wap2'])
    df['wap_balance13'] = abs(df['wap1'] - df['wap3'])
    df['wap_balance23'] = abs(df['wap2'] - df['wap3'])
    # Calculate spread
    df['price_spread'] = (df['ask_price1'] - df['bid_price1']) / ((df['ask_price1'] + df['bid_price1']) / 2)
    df['bid_spread'] = df['bid_price1'] - df['bid_price2']
    df['ask_spread'] = df['ask_price1'] - df['ask_price2']
    df['total_volume'] = (df['ask_size1'] + df['ask_size2']) + (df['bid_size1'] + df['bid_size2'])
    df['volume_imbalance'] = abs((df['ask_size1'] + df['ask_size2']) - (df['bid_size1'] + df['bid_size2']))

    # Aggregations per column.  String names select pandas' own groupby
    # reductions; passing np.sum / np.mean / np.std is a deprecated alias for
    # the same thing.  The resulting column names are unchanged.
    base_stats = ['sum', 'mean', 'std']
    return_stats = ['sum', realized_volatility, 'mean', 'std']
    create_feature_dict = {
        'wap1': base_stats,
        'wap2': base_stats,
        'wap3': base_stats,
        'log_return1': return_stats,
        'log_return2': return_stats,
        'log_return3': return_stats,
        'wap_balance12': base_stats,
        'wap_balance13': base_stats,
        'wap_balance23': base_stats,
        'price_spread': base_stats,
        'bid_spread': base_stats,
        'ask_spread': base_stats,
        'total_volume': base_stats,
        'volume_imbalance': base_stats,
    }

    # Function to get group stats for different windows (seconds in bucket)
    def get_stats_window(seconds_in_bucket, add_suffix = False):
        """Aggregate rows with ``seconds_in_bucket`` >= the given offset per time_id.

        Column names are flattened to ``<column>_<stat>`` (the group key becomes
        ``time_id_``); with ``add_suffix`` every column, including the key,
        additionally gets ``_<seconds_in_bucket>`` appended.
        """
        # Group by the window
        in_window = df[df['seconds_in_bucket'] >= seconds_in_bucket]
        df_feature = in_window.groupby(['time_id']).agg(create_feature_dict).reset_index()
        # Rename columns joining suffix
        df_feature.columns = ['_'.join(col) for col in df_feature.columns]
        # Add a suffix to differentiate windows
        if add_suffix:
            df_feature = df_feature.add_suffix('_' + str(seconds_in_bucket))
        return df_feature

    # Full-bucket stats, then left-merge each later window onto them.  A
    # time_id with no rows in a window gets NaN for that window's columns.
    df_feature = get_stats_window(seconds_in_bucket = 0, add_suffix = False)
    for window_start in (450, 300, 150):
        window_feature = get_stats_window(seconds_in_bucket = window_start, add_suffix = True)
        window_key = f'time_id__{window_start}'
        df_feature = df_feature.merge(window_feature, how = 'left', left_on = 'time_id_', right_on = window_key)
        # Drop the window's duplicate time_id key
        df_feature = df_feature.drop(columns = [window_key])

    # Create row_id so we can merge.  rsplit takes the text after the LAST '=',
    # so an '=' in a parent directory name cannot be mistaken for the stock id.
    stock_id = file_path.rsplit('=', 1)[1]
    df_feature['row_id'] = stock_id + '-' + df_feature['time_id_'].astype(str)
    df_feature = df_feature.drop(columns = ['time_id_'])
    return df_feature

In [3]:
# Run the book feature pipeline on a single stock and preview the result
stock_id = 0
file_path_book = f"{data_dir}book_train.parquet/stock_id={stock_id}"
df = book_preprocessor(file_path_book)
df.head()

Unnamed: 0,wap1_sum,wap1_mean,wap1_std,wap2_sum,wap2_mean,wap2_std,wap3_sum,wap3_mean,wap3_std,log_return1_sum,log_return1_realized_volatility,log_return1_mean,log_return1_std,log_return2_sum,log_return2_realized_volatility,log_return2_mean,log_return2_std,log_return3_sum,log_return3_realized_volatility,log_return3_mean,log_return3_std,wap_balance12_sum,wap_balance12_mean,wap_balance12_std,wap_balance13_sum,wap_balance13_mean,wap_balance13_std,wap_balance23_sum,wap_balance23_mean,wap_balance23_std,price_spread_sum,price_spread_mean,price_spread_std,bid_spread_sum,bid_spread_mean,bid_spread_std,ask_spread_sum,ask_spread_mean,ask_spread_std,total_volume_sum,total_volume_mean,total_volume_std,volume_imbalance_sum,volume_imbalance_mean,volume_imbalance_std,wap1_sum_450,wap1_mean_450,wap1_std_450,wap2_sum_450,wap2_mean_450,wap2_std_450,wap3_sum_450,wap3_mean_450,wap3_std_450,log_return1_sum_450,log_return1_realized_volatility_450,log_return1_mean_450,log_return1_std_450,log_return2_sum_450,log_return2_realized_volatility_450,log_return2_mean_450,log_return2_std_450,log_return3_sum_450,log_return3_realized_volatility_450,log_return3_mean_450,log_return3_std_450,wap_balance12_sum_450,wap_balance12_mean_450,wap_balance12_std_450,wap_balance13_sum_450,wap_balance13_mean_450,wap_balance13_std_450,wap_balance23_sum_450,wap_balance23_mean_450,wap_balance23_std_450,price_spread_sum_450,price_spread_mean_450,price_spread_std_450,bid_spread_sum_450,bid_spread_mean_450,bid_spread_std_450,ask_spread_sum_450,ask_spread_mean_450,ask_spread_std_450,total_volume_sum_450,total_volume_mean_450,total_volume_std_450,volume_imbalance_sum_450,volume_imbalance_mean_450,volume_imbalance_std_450,wap1_sum_300,wap1_mean_300,wap1_std_300,wap2_sum_300,wap2_mean_300,wap2_std_300,wap3_sum_300,wap3_mean_300,wap3_std_300,log_return1_sum_300,log_return1_realized_volatility_300,log_return1_mean_300,log_return1_std_300,log_return2_sum_300,log_return2_realized_volatility_300,log_return2_mean_300,log_return2
_std_300,log_return3_sum_300,log_return3_realized_volatility_300,log_return3_mean_300,log_return3_std_300,wap_balance12_sum_300,wap_balance12_mean_300,wap_balance12_std_300,wap_balance13_sum_300,wap_balance13_mean_300,wap_balance13_std_300,wap_balance23_sum_300,wap_balance23_mean_300,wap_balance23_std_300,price_spread_sum_300,price_spread_mean_300,price_spread_std_300,bid_spread_sum_300,bid_spread_mean_300,bid_spread_std_300,ask_spread_sum_300,ask_spread_mean_300,ask_spread_std_300,total_volume_sum_300,total_volume_mean_300,total_volume_std_300,volume_imbalance_sum_300,volume_imbalance_mean_300,volume_imbalance_std_300,wap1_sum_150,wap1_mean_150,wap1_std_150,wap2_sum_150,wap2_mean_150,wap2_std_150,wap3_sum_150,wap3_mean_150,wap3_std_150,log_return1_sum_150,log_return1_realized_volatility_150,log_return1_mean_150,log_return1_std_150,log_return2_sum_150,log_return2_realized_volatility_150,log_return2_mean_150,log_return2_std_150,log_return3_sum_150,log_return3_realized_volatility_150,log_return3_mean_150,log_return3_std_150,wap_balance12_sum_150,wap_balance12_mean_150,wap_balance12_std_150,wap_balance13_sum_150,wap_balance13_mean_150,wap_balance13_std_150,wap_balance23_sum_150,wap_balance23_mean_150,wap_balance23_std_150,price_spread_sum_150,price_spread_mean_150,price_spread_std_150,bid_spread_sum_150,bid_spread_mean_150,bid_spread_std_150,ask_spread_sum_150,ask_spread_mean_150,ask_spread_std_150,total_volume_sum_150,total_volume_mean_150,total_volume_std_150,volume_imbalance_sum_150,volume_imbalance_mean_150,volume_imbalance_std_150,row_id
0,303.125061,1.003725,0.000693,303.105539,1.003661,0.000781,303.125613,1.003727,0.000714,0.002292,0.004499,7.613599e-06,0.00026,0.002325,0.006999,8e-06,0.000404,0.002303,0.004106,7.650981e-06,0.000237,0.117051,0.000388,0.000295,0.059363,0.000197,0.000184,0.057688,0.000191,0.0002,0.257255,0.000852,0.000211,0.053006,0.000176,0.000162,-0.045557,-0.000151,0.000126,97696,323.496689,138.101214,40738,134.89404,107.260583,68.236749,1.003482,0.000514,68.231672,1.003407,0.00064,68.235321,1.003461,0.000548,-0.000361,0.001721,-5e-06,0.00021,6.8e-05,0.004114,1e-06,0.000503,7e-05,0.002296,1e-06,0.000281,0.024868,0.000366,0.000277,0.014048,0.000207,0.000187,0.010819,0.000159,0.000189,0.053236,0.000783,0.000181,0.01779,0.000262,0.000178,-0.011274,-0.000166,0.000126,17948,263.941176,116.940077,9620,141.470588,84.467864,139.521722,1.003753,0.000487,139.509756,1.003667,0.000585,139.520718,1.003746,0.000509,0.000157,0.002953,1.131529e-06,0.000251,0.000274,0.004863,2e-06,0.000414,0.000213,0.002849,2e-06,0.000243,0.051757,0.000372,0.000273,0.027256,0.000196,0.0002,0.024502,0.000176,0.000194,0.114272,0.000822,0.000237,0.030976,0.000223,0.000173,-0.022548,-0.000162,0.000131,40995,294.928058,136.527199,19065,137.158273,97.898813,232.888919,1.003832,0.000445,232.870736,1.003753,0.000519,232.8878,1.003827,0.000431,0.000276,0.003796,1e-06,0.00025,3e-06,0.006087,1.295471e-08,0.0004,8.4e-05,0.003401,3.616409e-07,0.000224,0.091997,0.000397,0.000281,0.046882,0.000202,0.000185,0.045115,0.000194,0.000192,0.199058,0.000858,0.000221,0.043697,0.000188,0.000165,-0.034024,-0.000147,0.00012,75964,327.431034,142.761068,28672,123.586207,103.533216,0-5
1,200.047768,1.000239,0.000262,200.041171,1.000206,0.000272,200.047245,1.000236,0.000227,0.00036,0.001204,1.810239e-06,8.6e-05,0.000801,0.002476,4e-06,0.000176,0.00063,0.001507,3.166541e-06,0.000107,0.042312,0.000212,0.000155,0.021278,0.000106,0.0001,0.021034,0.000105,8.9e-05,0.078836,0.000394,0.000157,0.028358,0.000142,0.000148,-0.027001,-0.000135,6.5e-05,82290,411.45,172.263581,28410,142.05,102.139758,54.027991,1.000518,0.000235,54.021532,1.000399,0.000287,54.024997,1.000463,0.000213,-5.9e-05,0.000918,-1e-06,0.000126,0.000488,0.001883,9e-06,0.000258,0.000235,0.00112,4e-06,0.000154,0.014524,0.000269,0.000175,0.008022,0.000149,0.000126,0.006502,0.00012,7.5e-05,0.018812,0.000348,0.000144,0.012598,0.000233,0.000239,-0.007729,-0.000143,6.6e-05,24191,447.981481,177.264272,5275,97.685185,88.144569,115.045656,1.000397,0.000207,115.039774,1.000346,0.000241,115.043208,1.000376,0.000176,9.6e-05,0.000981,8.383753e-07,9.2e-05,0.000413,0.002009,4e-06,0.000188,0.00036,0.00121,3e-06,0.000113,0.027445,0.000239,0.000158,0.014315,0.000124,0.000106,0.013129,0.000114,8.2e-05,0.040589,0.000353,0.000121,0.018873,0.000164,0.00018,-0.014153,-0.000123,5.9e-05,55720,484.521739,168.586713,15584,135.513043,110.256349,173.052001,1.000301,0.000221,173.042301,1.000245,0.000266,173.049229,1.000285,0.000202,0.000298,0.001058,2e-06,8.1e-05,0.000873,0.002262,5.044579e-06,0.000172,0.000764,0.00138,4.418278e-06,0.000105,0.035454,0.000205,0.000158,0.017757,0.000103,9.8e-05,0.017697,0.000102,9e-05,0.061017,0.000353,0.000112,0.024394,0.000141,0.000154,-0.022032,-0.000127,5.8e-05,72535,419.277457,178.652395,26221,151.566474,104.576846,0-11
2,187.913849,0.999542,0.000864,187.939824,0.99968,0.000862,187.925251,0.999602,0.000848,-0.002074,0.002369,-1.109201e-05,0.000173,-0.001493,0.004801,-8e-06,0.000352,-0.001827,0.002469,-9.770073e-06,0.000181,0.062228,0.000331,0.000246,0.025361,0.000135,0.000115,0.036867,0.000196,0.000178,0.13633,0.000725,0.000164,0.036955,0.000197,0.00017,-0.037243,-0.000198,0.000171,78274,416.351064,138.433034,26586,141.414894,108.891243,43.922425,0.998237,0.000541,43.933158,0.998481,0.000766,43.926436,0.998328,0.000622,-0.001469,0.001158,-3.3e-05,0.000173,-0.001831,0.002972,-4.2e-05,0.000451,-0.001632,0.001386,-3.7e-05,0.000208,0.016055,0.000365,0.000282,0.006117,0.000139,0.0001,0.009938,0.000226,0.00021,0.026608,0.000605,0.000105,0.008186,0.000186,0.000217,-0.009143,-0.000208,0.000168,20201,459.113636,116.212559,6869,156.113636,102.02467,67.910601,0.998685,0.000779,67.92755,0.998935,0.000891,67.918931,0.998808,0.000838,-0.002591,0.001295,-3.81056e-05,0.000153,-0.001549,0.003196,-2.3e-05,0.00039,-0.002421,0.001519,-3.6e-05,0.000182,0.029308,0.000431,0.000294,0.011383,0.000167,0.000134,0.017925,0.000264,0.000218,0.046866,0.000689,0.000162,0.009622,0.000141,0.000185,-0.016945,-0.000249,0.00019,30956,455.235294,120.920736,9802,144.147059,101.873534,118.896016,0.999126,0.000829,118.918175,0.999312,0.000853,118.905969,0.99921,0.000826,-0.002854,0.002138,-2.4e-05,0.000195,-0.002986,0.004019,-2.50905e-05,0.000369,-0.002899,0.00197,-2.436141e-05,0.00018,0.044347,0.000373,0.000276,0.016837,0.000141,0.000122,0.02751,0.000231,0.000204,0.080811,0.000679,0.000163,0.0191,0.000161,0.000155,-0.028626,-0.000241,0.000195,50996,428.537815,135.376048,15718,132.084034,114.924631,0-16
3,119.859781,0.998832,0.000757,119.835941,0.998633,0.000656,119.849183,0.998743,0.000697,-0.002828,0.002574,-2.376661e-05,0.000236,-0.002053,0.003637,-1.7e-05,0.000334,-0.002562,0.002709,-2.152617e-05,0.000248,0.045611,0.00038,0.000248,0.021332,0.000178,0.000142,0.024279,0.000202,0.000138,0.103252,0.00086,0.00028,0.022764,0.00019,0.000199,-0.013001,-0.000108,9.1e-05,52232,435.266667,156.120334,17546,146.216667,121.533215,17.965415,0.998079,0.00043,17.97163,0.998424,0.000544,17.96865,0.998258,0.000495,-0.000526,0.000993,-2.9e-05,0.000239,-0.000882,0.001424,-4.9e-05,0.000342,-0.000711,0.001145,-3.9e-05,0.000275,0.006441,0.000358,0.000253,0.003339,0.000185,0.000178,0.003102,0.000172,9e-05,0.019047,0.001058,7.4e-05,0.002082,0.000116,4.6e-05,-0.000879,-4.9e-05,1.1e-05,9720,540.0,153.413704,2628,146.0,106.693624,52.91711,0.998436,0.000504,52.918125,0.998455,0.000513,52.91738,0.998441,0.000464,-0.001179,0.001776,-2.224226e-05,0.000245,-0.00044,0.002713,-8e-06,0.000376,-0.000836,0.002019,-1.6e-05,0.000279,0.017525,0.000331,0.000228,0.00883,0.000167,0.00014,0.008695,0.000164,0.000114,0.044159,0.000833,0.000278,0.008375,0.000158,0.000165,-0.005043,-9.5e-05,7.6e-05,22163,418.169811,146.485459,7669,144.698113,101.135778,80.875601,0.998464,0.000432,80.866621,0.998353,0.000477,80.871208,0.99841,0.000401,-0.00129,0.002196,-1.6e-05,0.000245,-0.001112,0.003273,-1.372564e-05,0.000366,-0.001242,0.002393,-1.532987e-05,0.000267,0.029323,0.000362,0.000247,0.01415,0.000175,0.000145,0.015173,0.000187,0.000131,0.074552,0.00092,0.000296,0.013789,0.00017,0.000191,-0.008745,-0.000108,8.5e-05,34363,424.234568,156.628404,12293,151.765432,124.293028,0-31
4,175.932865,0.999619,0.000258,175.934256,0.999626,0.000317,175.939558,0.999657,0.00025,-2e-06,0.001894,-1.057099e-08,0.000144,-0.000281,0.003257,-2e-06,0.000247,2e-06,0.001932,9.52927e-09,0.000146,0.044783,0.000254,0.000188,0.019143,0.000109,0.000118,0.02564,0.000146,0.000139,0.069901,0.000397,0.00013,0.033565,0.000191,8.3e-05,-0.019206,-0.000109,7.6e-05,60407,343.221591,158.054066,21797,123.846591,102.407501,35.982653,0.999518,0.000257,35.991844,0.999773,0.000212,35.988354,0.999677,0.00018,0.000397,0.001378,1.1e-05,0.000233,-0.000298,0.000966,-8e-06,0.000163,0.000264,0.000899,7e-06,0.000152,0.013087,0.000364,0.000203,0.007179,0.000199,0.000139,0.005909,0.000164,0.000148,0.0187,0.000519,0.000138,0.00704,0.000196,2.4e-05,-0.004895,-0.000136,6.6e-05,14110,391.944444,123.180227,4212,117.0,99.328028,88.954468,0.999488,0.000205,88.965742,0.999615,0.000272,88.96172,0.99957,0.000201,0.000645,0.00152,7.24993e-06,0.000162,-0.000201,0.002188,-2e-06,0.000233,0.000492,0.00133,6e-06,0.000142,0.022397,0.000252,0.000188,0.01017,0.000114,0.000129,0.012227,0.000137,0.000126,0.03782,0.000425,0.00014,0.017016,0.000191,7.3e-05,-0.010722,-0.00012,7.6e-05,36275,407.58427,165.851509,8851,99.449438,93.029811,134.948413,0.999618,0.000259,134.955499,0.99967,0.000293,134.95542,0.99967,0.000243,0.000491,0.001609,4e-06,0.000139,0.000299,0.002927,2.213193e-06,0.000253,0.000478,0.00164,3.53874e-06,0.000142,0.032718,0.000242,0.000193,0.014195,0.000105,0.000119,0.018523,0.000137,0.000143,0.053347,0.000395,0.000137,0.02522,0.000187,8.8e-05,-0.015757,-0.000117,8e-05,50121,371.266667,162.610706,17749,131.474074,109.275622,0-62


In [9]:
# Window offsets used as feature-name suffixes.  The value 1 marks the
# columns that carry no window suffix.
time_variation = [1, 150, 300, 450]
vol_cols = []

# feature names using wap1, wap2, wap3
for i in range(1, 4):
    for t in time_variation:
        suffix = '' if t == 1 else f'_{t}'
        vol_cols.append(f'log_return{i}_realized_volatility{suffix}')

# trade_features
for t in time_variation:
    suffix = '' if t == 1 else f'_{t}'
    vol_cols.append(f'trade_log_return_realized_volatility{suffix}')

# Show the collected names, one line per feature family
for i in range(1, 4):
    print([col for col in vol_cols if f'log_return{i}' in col])

print([col for col in vol_cols if 'trade' in col])

['log_return1_realized_volatility', 'log_return1_realized_volatility_150', 'log_return1_realized_volatility_300', 'log_return1_realized_volatility_450']
['log_return2_realized_volatility', 'log_return2_realized_volatility_150', 'log_return2_realized_volatility_300', 'log_return2_realized_volatility_450']
['log_return3_realized_volatility', 'log_return3_realized_volatility_150', 'log_return3_realized_volatility_300', 'log_return3_realized_volatility_450']
['trade_log_return_realized_volatility', 'trade_log_return_realized_volatility_150', 'trade_log_return_realized_volatility_300', 'trade_log_return_realized_volatility_450']
