In [1]:
import findspark
findspark.init()

import pyspark # only run after findspark.init()
from pyspark.sql import SparkSession
from pyspark.sql import functions as func
from pyspark.sql.window import Window
from pyspark.sql.functions import array, col, explode, lit, struct, split, mean, stddev, lead, lag, concat, count, year, month, dayofmonth
from pyspark.sql import DataFrame
from typing import Iterable 
from pyspark.sql.types import IntegerType
from pyspark.sql.functions import broadcast

from datetime import datetime

# Spark session used by every cell below; caps driver memory and the size of
# results collected back to the driver.
spark = SparkSession.builder.config("spark.driver.memory", "16g")\
                            .config("spark.driver.maxResultSize", "1g")\
                            .getOrCreate()

from util import *  # project-local helpers; presumably where `resample` (used below) comes from
import pandas as pd
# Use the fully-qualified option name: the bare 'max_columns' pattern also matches
# 'styler.render.max_columns' in pandas >= 1.4 and raises OptionError
# ("Pattern matched multiple keys").
pd.set_option('display.max_columns', 999)

In [6]:
"""
Takes order-book features from the intermediate layer and aggregates them into time buckets of several sizes (see `resample_buckets`).
"""

'\nTakes order-book features from the intermediate layer and aggregates them into time buckets of several sizes (see `resample_buckets`).\n'

In [2]:
# Exchange / instrument whose intermediate order-book features are aggregated;
# both values also select the input path and are written as output columns.
exchange, symbol = 'FTX', 'BTC-PERP'

# Bucket widths in seconds; each width yields its own `dt_resampled_<n>s`
# column and its own `<n>_second_bucketing` output partition.
resample_buckets = [
    10,
    30,
    60,
]

In [3]:
# Load the intermediate order-book features for one exchange/symbol, restricted to
# the partitions under year=<source_year>. Read parquet explicitly (matching the
# parquet writer in the aggregation cell) instead of relying on the session's
# spark.sql.sources.default setting, as spark.read.load() does.
source_year = 2019
lob_path = f'data/02_intermediate/lob/exchange={exchange}/symbol={symbol}/year={source_year}/*/*'
df = spark.read.parquet(lob_path)


In [4]:
# Add one bucketed-timestamp column per bucket width (e.g. dt_resampled_10s);
# the aggregation cell groups by these columns.
# NOTE(review): `resample` is not defined in this notebook; presumably it comes
# from the star-imported `util` module and floors `timestamp` to `agg_interval`
# seconds — confirm there.
for bucket_seconds in resample_buckets:
    bucket_column = f'dt_resampled_{bucket_seconds}s'
    bucketed_ts = resample(df.timestamp, agg_interval=bucket_seconds)
    df = df.withColumn(bucket_column, bucketed_ts)
    

In [5]:
def aggregate_lob_features(lob_df: DataFrame, sec_bucket: int, exchange: str, symbol: str) -> DataFrame:
    """Aggregate per-event order-book features into ``sec_bucket``-second buckets.

    Args:
        lob_df: frame holding ``spread``, ``midprice``, ``bbo_imbalance``,
            ``book_imbalance`` and the ``dt_resampled_{sec_bucket}s`` column
            added by the resampling cell.
        sec_bucket: bucket width in seconds; selects the bucket column and
            labels the ``bucket_size`` partition value.
        exchange: value written to the ``exchange`` column.
        symbol: value written to the ``symbol`` column.

    Returns:
        One row per bucket, sorted by ``timestamp`` (the bucket key), with
        ``count_events``, ``<feature>_mean`` / ``<feature>_std`` for each feature,
        and the columns ``exchange``, ``symbol``, ``year``, ``month``, ``day``,
        ``bucket_size`` used for partitioning on write.
    """
    bucket_col = f'dt_resampled_{sec_bucket}s'

    # Name every aggregate explicitly instead of renaming Spark's generated names
    # such as "stddev_samp(spread)": withColumnRenamed is a silent no-op when the
    # name does not match, so a naming change would leak raw column names.
    feature_stats = []
    for feature in ('spread', 'midprice', 'bbo_imbalance', 'book_imbalance'):
        feature_stats.append(mean(feature).alias(f'{feature}_mean'))
        feature_stats.append(stddev(feature).alias(f'{feature}_std'))

    df_agg = (
        lob_df.groupby(bucket_col)
        # count(lit(1)) counts every row in the bucket; count(bucket_col) would
        # report 0 events for a null bucket key.
        .agg(count(lit(1)).alias('count_events'), *feature_stats)
        .sort(bucket_col)
        .withColumnRenamed(bucket_col, 'timestamp')
    )

    return (
        df_agg.withColumn('exchange', lit(exchange))
        .withColumn('symbol', lit(symbol))
        .withColumn('year', year(col('timestamp')))
        .withColumn('month', month(col('timestamp')))
        .withColumn('day', dayofmonth(col('timestamp')))
        .withColumn('bucket_size', lit(f'{sec_bucket}_second_bucketing'))
    )


# Loop-invariant: with dynamic mode, an 'overwrite' write only replaces the
# partitions it actually writes rather than the whole output directory.
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")

for sec_bucket in resample_buckets:
    df_agg = aggregate_lob_features(df, sec_bucket, exchange, symbol)
    df_agg.write.mode('overwrite') \
        .partitionBy('exchange', 'symbol', 'bucket_size', 'year', 'month', 'day') \
        .parquet("data/03_feature/lob")
