Useful articles:

- [The short-term predictability of returns in order book markets: A deep learning perspective](https://www.sciencedirect.com/science/article/pii/S0169207024000062)

In [1]:
import plotnine as pn
import polars as pl
import polars.selectors as cs


In [2]:
# Load the L2 order book and arrange the columns in a predictable order.
# (CSV alternative: pl.read_csv("data/data.csv.gz"))
# Nulls in the Rate/Size columns are left untouched at this stage.
df = pl.read_parquet("data/data.parquet").with_row_index(name="time")
# Row order is assumed to reflect time; keep the index as a signed integer
df = df.with_columns(pl.col("time").cast(pl.Int64))

cols = ["time", "y"]

# Book columns ordered side -> level -> (price, size):
# askRate0, askSize0, askRate1, ..., bidRate14, bidSize14
order_book_feat = [
    f"{side}{field}{level}"
    for side in ("ask", "bid")
    for level in range(15)
    for field in ("Rate", "Size")
]

df = df.select([*cols, *order_book_feat])

# Sanity check: no NaN anywhere (nulls are expected and are not NaN)
assert (
    df.with_columns(pl.all().is_nan())
    .sum_horizontal()
    .sum()
) == 0

df.shape

(3497666, 62)

In [3]:
# Compute the tick size as the smallest absolute gap between adjacent levels
# on the same side of the book.
exps = [
    (pl.col(f"{side}Rate{i}") - pl.col(f"{side}Rate{i-1}")).abs()
    for i in range(1, 15)
    for side in ["ask", "bid"]
]

tick_size = (
    # Select ONLY the level-to-level gaps. Using with_columns here would keep
    # the raw askRate0/bidRate0 prices in the frame and let them take part in
    # the minimum alongside the gaps.
    df.select(exps)
    # get the minimum per row and then overall
    .min_horizontal().min()
)
print(f"{tick_size=}")

tick_size=0.5


In [4]:
df

time,y,askRate0,askSize0,askRate1,askSize1,askRate2,askSize2,askRate3,askSize3,askRate4,askSize4,askRate5,askSize5,askRate6,askSize6,askRate7,askSize7,askRate8,askSize8,askRate9,askSize9,askRate10,askSize10,askRate11,askSize11,askRate12,askSize12,askRate13,askSize13,askRate14,askSize14,bidRate0,bidSize0,bidRate1,bidSize1,bidRate2,bidSize2,bidRate3,bidSize3,bidRate4,bidSize4,bidRate5,bidSize5,bidRate6,bidSize6,bidRate7,bidSize7,bidRate8,bidSize8,bidRate9,bidSize9,bidRate10,bidSize10,bidRate11,bidSize11,bidRate12,bidSize12,bidRate13,bidSize13,bidRate14,bidSize14
i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
0,-0.5,1619.5,1.0,1620.0,10.0,1621.0,24.0,,,,,,,,,,,,,,,,,,,,,,,,,1615.0,7.0,1614.0,10.0,1613.0,1.0,1612.0,10.0,1611.0,20.0,1610.0,3.0,1607.0,20.0,1606.0,27.0,1605.0,11.0,1604.0,14.0,1603.0,35.0,1602.0,10.0,1601.5,1.0,1601.0,10.0,1600.0,13.0
1,-0.5,1619.5,1.0,1620.0,10.0,1621.0,24.0,1621.5,5.0,,,,,,,,,,,,,,,,,,,,,,,1615.0,7.0,1614.0,10.0,1613.0,1.0,1612.0,10.0,1611.0,20.0,1610.0,3.0,1607.0,20.0,1606.0,27.0,1605.0,11.0,1604.0,14.0,1603.0,35.0,1602.0,10.0,1601.5,1.0,1601.0,10.0,1600.0,13.0
2,-0.5,1619.5,1.0,1620.0,10.0,1621.0,24.0,1621.5,5.0,1622.0,2.0,,,,,,,,,,,,,,,,,,,,,1615.0,7.0,1614.0,10.0,1613.0,1.0,1612.0,10.0,1611.0,20.0,1610.0,3.0,1607.0,20.0,1606.0,27.0,1605.0,11.0,1604.0,14.0,1603.0,35.0,1602.0,10.0,1601.5,1.0,1601.0,10.0,1600.0,13.0
3,-0.5,1619.5,1.0,1620.0,10.0,1621.0,24.0,1621.5,5.0,1622.0,22.0,,,,,,,,,,,,,,,,,,,,,1615.0,7.0,1614.0,10.0,1613.0,1.0,1612.0,10.0,1611.0,20.0,1610.0,3.0,1607.0,20.0,1606.0,27.0,1605.0,11.0,1604.0,14.0,1603.0,35.0,1602.0,10.0,1601.5,1.0,1601.0,10.0,1600.0,13.0
4,-0.5,1619.5,1.0,1620.0,10.0,1621.0,24.0,1621.5,5.0,1622.0,32.0,,,,,,,,,,,,,,,,,,,,,1615.0,7.0,1614.0,10.0,1613.0,1.0,1612.0,10.0,1611.0,20.0,1610.0,3.0,1607.0,20.0,1606.0,27.0,1605.0,11.0,1604.0,14.0,1603.0,35.0,1602.0,10.0,1601.5,1.0,1601.0,10.0,1600.0,13.0
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
3497661,0.0,1576.0,3.0,1576.5,1.0,1577.0,10.0,1577.5,1.0,1578.0,3.0,1579.0,1.0,1579.5,8.0,1581.5,10.0,1582.0,10.0,1582.5,41.0,1583.0,25.0,1584.0,20.0,1585.0,14.0,1585.5,1.0,1587.5,2.0,1575.0,10.0,1574.0,4.0,1573.0,31.0,1571.0,5.0,1570.5,2.0,1570.0,104.0,1569.5,1.0,1567.0,25.0,1566.0,61.0,1565.5,26.0,1565.0,7.0,1564.0,202.0,1563.0,2.0,1562.5,2.0,1562.0,172.0
3497662,0.0,1576.0,2.0,1576.5,1.0,1577.0,10.0,1577.5,1.0,1578.0,3.0,1579.0,1.0,1579.5,8.0,1581.5,10.0,1582.0,10.0,1582.5,41.0,1583.0,25.0,1584.0,20.0,1585.0,14.0,1585.5,1.0,1587.5,2.0,1575.0,10.0,1574.0,4.0,1573.0,31.0,1571.0,5.0,1570.5,2.0,1570.0,104.0,1569.5,1.0,1567.0,25.0,1566.0,61.0,1565.5,26.0,1565.0,7.0,1564.0,202.0,1563.0,2.0,1562.5,2.0,1562.0,172.0
3497663,0.0,1576.0,3.0,1576.5,1.0,1577.0,10.0,1577.5,1.0,1578.0,3.0,1579.0,1.0,1579.5,8.0,1581.5,10.0,1582.0,10.0,1582.5,41.0,1583.0,25.0,1584.0,20.0,1585.0,14.0,1585.5,1.0,1587.5,2.0,1575.0,10.0,1574.0,4.0,1573.0,31.0,1571.0,5.0,1570.5,2.0,1570.0,104.0,1569.5,1.0,1567.0,25.0,1566.0,61.0,1565.5,26.0,1565.0,7.0,1564.0,202.0,1563.0,2.0,1562.5,2.0,1562.0,172.0
3497664,0.0,1576.0,3.0,1576.5,1.0,1577.0,10.0,1578.0,3.0,1579.0,1.0,1579.5,8.0,1581.5,10.0,1582.0,10.0,1582.5,41.0,1583.0,25.0,1584.0,20.0,1585.0,14.0,1585.5,1.0,1587.5,2.0,1588.0,20.0,1575.0,10.0,1574.0,4.0,1573.0,31.0,1571.0,5.0,1570.5,2.0,1570.0,104.0,1569.5,1.0,1567.0,25.0,1566.0,61.0,1565.5,26.0,1565.0,7.0,1564.0,202.0,1563.0,2.0,1562.5,2.0,1562.0,172.0


In [3]:
# Summary statistics of (askRate0 - midprice) expressed in units of 0.5
(
    df
    .with_columns((cs.contains("Rate0") - pl.col("midprice")) / 0.5)
    .select(cs.contains("Rate0"))
    .get_column("askRate0")
    .describe()
)

statistic,value
str,f64
"""count""",3497666.0
"""null_count""",0.0
"""mean""",0.687398
"""std""",0.334971
"""min""",0.5
"""25%""",0.5
"""50%""",0.5
"""75%""",1.0
"""max""",9.0


In [5]:
# Compute order flow
def order_flow(price_col: str, vol_col: str, side: str) -> pl.Expr:
    """Build the order-flow expression for one level of the book.

    Each row is compared with the previous one:

    * price moved up:   bid -> current volume, ask -> minus previous volume
    * price moved down: bid -> minus previous volume, ask -> current volume
    * otherwise:        change in volume with respect to the previous row

    Args:
        price_col: Name of the price column of the level.
        vol_col: Name of the volume column of the level.
        side: Either "ask" or "bid".

    Returns:
        A polars expression evaluating to the order flow of the level.

    Raises:
        ValueError: If ``side`` is neither "ask" nor "bid".
    """
    # Reject unknown sides: they would otherwise silently get the
    # "current volume" branch for both price directions.
    if side not in ("ask", "bid"):
        raise ValueError(f"side must be 'ask' or 'bid', got {side!r}")

    vol = pl.col(vol_col)
    price_diff = pl.col(price_col).diff(1)
    vol_lag = vol.shift(1)

    up_case = vol_lag.neg() if side == "ask" else vol
    down_case = vol_lag.neg() if side == "bid" else vol

    return (
        pl
        .when(price_diff > 0).then(up_case)
        .when(price_diff < 0).then(down_case)
        # unchanged price: the flow is the change in resting volume
        .otherwise(vol.diff(1))
    )

# The order flow was checked to be correct, so the raw prices and volumes are
# dropped: only the base columns and the OrderFlow features are kept.
df_of = (
    df
    # missing levels (nulls) are treated as price 0 / size 0
    .with_columns((cs.contains("Rate") | cs.contains("Size")).fill_null(0))
    .with_columns(
        [
            order_flow(f"{side}Rate{level}", f"{side}Size{level}", side).alias(f"{side}OrderFlow{level}")
            for level in range(15)
            for side in ("bid", "ask")
        ]
    )
    .select(cs.by_name(cols) | cs.contains("OrderFlow"))
)

In [None]:
# Level-0 ask order flow over time
pn.ggplot(data=df_of, mapping=pn.aes(x="time", y="askOrderFlow0")) + pn.geom_line()

In [None]:
num_ticks = 10
tick_size = 0.5

# Collect the bid sizes of all levels into one list column per row.
# NOTE(review): the expression originally stopped at a dangling `.list.`,
# which does not parse. The intended list operation is unknown (possibly
# something using num_ticks), so none is applied here.
(
    df
    .with_columns(bid_vol=pl.concat_list(cs.contains("bidSize")))
    .select("bid_vol")
)


In [None]:
# compute cumulative sum of volume
(
    df
    # missing levels (nulls) are treated as price 0 / size 0
    .with_columns((cs.contains("Rate") | cs.contains("Size")).fill_null(0))
    # running total of the bid sizes from level 0 to level 14; the result is
    # picked up below under the name "cum_sum"
    .with_columns(pl.cum_sum_horizontal(f"bidSize{i}" for i in range(15)))
    .select(["time", "cum_sum"])
    # unnest the struct fields and pack them into a single list column
    .with_columns(cum_sum=pl.concat_list(pl.col("cum_sum").struct.unnest()))
)



In [45]:
# Standardise order flow
window_size = 100
min_periods = 5
def standard_scaler(col, **kwargs) -> pl.Expr:
    """Rolling z-score of ``col``.

    ``kwargs`` are forwarded unchanged to both ``rolling_mean`` and
    ``rolling_std``.
    """
    centred = col - col.rolling_mean(**kwargs)
    # the small constant keeps the denominator away from zero
    scale = col.rolling_std(**kwargs) + 1e-8
    return centred / scale

def min_max_scaler(col, feature_range=(0, 1), **kwargs) -> pl.Expr:
    """Rolling min-max scaling of ``col`` into ``feature_range``.

    ``kwargs`` are forwarded unchanged to both ``rolling_min`` and
    ``rolling_max``.
    """
    lower, upper = feature_range
    window_min = col.rolling_min(**kwargs)
    # the small constant keeps the denominator away from zero
    window_span = col.rolling_max(**kwargs) - window_min + 1e-8
    unit_scaled = (col - window_min) / window_span
    # stretch/shift the unit interval onto the requested range
    return unit_scaled * (upper - lower) + lower

# Overwrite every OrderFlow column with its rolling min-max scaled version
# (window_size / min_periods are forwarded to the rolling functions).
# The rolling z-score alternative is kept here, disabled:
# df_of = df_of.with_columns(standard_scaler(cs.contains("OrderFlow"), window_size=window_size, min_periods=min_periods))
df_of = df_of.with_columns(min_max_scaler(cs.contains("OrderFlow"), window_size=window_size, min_periods=min_periods))

In [None]:
df_of["askOrderFlow0"].describe()

In [None]:
df_of["askOrderFlow0"].describe()

In [None]:
# Last 2000 rows: level-0 ask order flow (default colour) against bid (red)
(
    pn.ggplot(data=df_of.tail(2000), mapping=pn.aes(x="time", y="askOrderFlow0"))
    + pn.geom_line()
    + pn.geom_line(mapping=pn.aes(y="bidOrderFlow0"), color="red")
)

In [None]:
# Midprice over time
pn.ggplot(data=df, mapping=pn.aes(x="time", y="midprice")) + pn.geom_line()

In [None]:
# Target y over time
pn.ggplot(data=df, mapping=pn.aes(x="time", y="y")) + pn.geom_line()

In [4]:
window_size = 100

def normalise_exp(col_name: str, window_size: int, eps: float = 1e-8) -> pl.Expr:
    """Rolling z-score of column ``col_name``.

    Args:
        col_name: Name of the column to normalise.
        window_size: Number of rows in the rolling window.
        eps: Small constant added to the rolling standard deviation so that a
            window of identical values (standard deviation 0) does not cause
            a division by zero.

    Returns:
        A polars expression: (value - rolling mean) / (rolling std + eps).
    """
    col = pl.col(col_name)
    return (col - col.rolling_mean(window_size)) / (col.rolling_std(window_size) + eps)

# Normalise volumes within each timestep wrt total depth
norm_df = df.with_columns(cs.contains("Size") / pl.col("total_depth"))
# Normalise prices within each timestep wrt midprice and spread
norm_df = norm_df.with_columns((cs.contains("Rate") - pl.col("midprice")) / pl.col("spread"))
# Normalise derived features across timesteps
norm_df = norm_df.with_columns([normalise_exp(c, window_size) for c in derived_feat])

# Exactly window_size - 1 rows are expected to have null derived features;
# verify that, then drop those rows.
rows_with_features = norm_df.drop_nulls(subset=derived_feat)
assert len(norm_df) - len(rows_with_features) == window_size - 1
norm_df = rows_with_features

In [None]:
norm_df

In [37]:
window_size = 100

def normalise_exp(col_name: str, window_size: int, eps: float = 1e-8) -> pl.Expr:
    """Rolling z-score of column ``col_name``.

    Args:
        col_name: Name of the column to normalise.
        window_size: Number of rows in the rolling window.
        eps: Small constant added to the rolling standard deviation so that a
            window of identical values (standard deviation 0) does not cause
            a division by zero.

    Returns:
        A polars expression: (value - rolling mean) / (rolling std + eps).
    """
    col = pl.col(col_name)
    return (col - col.rolling_mean(window_size)) / (col.rolling_std(window_size) + eps)

# All three normalisations are evaluated in a single with_columns call
norm_df = df.with_columns(
    [
        # prices: relative to the midprice, in units of the spread
        (cs.contains("Rate") - pl.col("midprice")) / pl.col("spread"),
        # volumes: as a fraction of the total depth
        cs.contains("Size") / pl.col("total_depth"),
        # derived features: rolling z-score across timesteps
        *(normalise_exp(c, window_size) for c in derived_feat),
    ]
)


In [92]:
# Label runs of constant midprice: p is the running count of rows whose
# midprice differs from the previous row.
# (run lengths can be inspected with .group_by("p").len().describe())
p = (
    norm_df
    .with_columns(p=pl.col("midprice").ne(pl.col("midprice").shift(1)).cum_sum())
    .select(["time", "p"])
)

In [None]:
# Runs of constant midprice that last more than 500 rows, with their length
(
    p
    .group_by("p")
    .agg(pl.col("time"))
    .filter(pl.col("time").list.len() > 500)
    .with_columns(n_periods=pl.col("time").list.len())
)

In [None]:
df.filter((pl.col("time") > 178225) & (pl.col("time") < 178864))

In [None]:
norm_df["spread"].describe()

In [73]:
from statsmodels.tsa.stattools import acf

In [None]:
# STL decomposition of the midprice series with period=2, followed by a plot
# of the fitted result.
from statsmodels.tsa.seasonal import STL
res = STL(df["midprice"].to_numpy(), period=2).fit()
res.plot()


In [None]:
# Autocorrelation function of the midprice.
# NOTE(review): the call originally ended in a dangling `nlags=`, which does
# not parse. The argument is left out so statsmodels uses its default number
# of lags; pass nlags=<int> once the intended horizon is decided.
acf_values = acf(df["midprice"].to_numpy())
lags = list(range(len(acf_values)))

(
    pn.ggplot(pl.DataFrame({'lags': lags, 'acf_values': acf_values}), pn.aes(x='lags', y='acf_values'))
    + pn.geom_bar(stat='identity')
    + pn.ggtitle('Autocorrelation Function')
    + pn.xlab('Lags')
    + pn.ylab('ACF Values')
)

In [None]:
# Preview of the normalised frame (result is displayed, not stored)
df.with_columns(
    [
        # prices: relative to the midprice, in units of the spread
        (cs.contains("Rate") - pl.col("midprice")) / pl.col("spread"),
        # volumes: as a fraction of the total depth
        cs.contains("Size") / pl.col("total_depth"),
        # derived features: rolling z-score across timesteps
        *(normalise_exp(c, window_size) for c in derived_feat),
    ]
)

In [None]:
df

In [None]:
norm_df.filter(pl.col("midprice").is_nan()).head()

In [None]:
norm_df["midprice"].describe()

In [None]:
# Normalised midprice over time
pn.ggplot(data=norm_df, mapping=pn.aes(x="time", y="midprice")) + pn.geom_line()

In [85]:
window_size = 100

def normalise_exp(col_name: str, window_size: int, eps: float = 1e-8) -> pl.Expr:
    """Rolling z-score of column ``col_name``.

    Args:
        col_name: Name of the column to normalise.
        window_size: Number of rows in the rolling window.
        eps: Small constant added to the rolling standard deviation so that a
            window of identical values (standard deviation 0) does not cause
            a division by zero.

    Returns:
        A polars expression: (value - rolling mean) / (rolling std + eps).
    """
    col = pl.col(col_name)
    return (col - col.rolling_mean(window_size)) / (col.rolling_std(window_size) + eps)

# Rolling z-score of every column except the time index and the target
norm_df = df.with_columns(
    normalise_exp(name, window_size)
    for name in df.columns
    if name not in ("time", "y")
)

# Exactly window_size - 1 rows are expected to contain nulls; verify, then drop
complete_rows = norm_df.drop_nulls()
assert len(norm_df) - len(complete_rows) == window_size - 1

norm_df = complete_rows

In [None]:
norm_df

In [None]:
norm_df["midprice"].drop_nulls().describe()

In [None]:
# Verify that the target equals midprice(t+87) - midprice(t), clipped to [-5, 5]:
# after rebuilding it as y_rec, no non-null row may differ from y.
assert (
    df.select("midprice", "y")
    .with_columns(y_rec=(pl.col("midprice").shift(-87) - pl.col("midprice")).clip(-5, 5))
    .with_columns(diff=pl.col("y") - pl.col("y_rec"))
    .drop_nulls()
    .filter(pl.col("diff") != 0)
    .is_empty()
)

In [None]:
# Compute the tick size as the smallest absolute gap between adjacent levels
# on the same side of the book.
exps = [
    (pl.col(f"{side}Rate{i}") - pl.col(f"{side}Rate{i-1}")).abs()
    for i in range(1, 15)
    for side in ["ask", "bid"]
]

tick_size = (
    # Select ONLY the level-to-level gaps. Using with_columns here would keep
    # the raw askRate0/bidRate0 prices in the frame and let them take part in
    # the minimum alongside the gaps.
    df.select(exps)
    # get the minimum per row and then overall
    .min_horizontal().min()
)
print(f"{tick_size=}")

In [68]:
# Distance of every quoted price from the midprice, in ticks
exps = [
    pl.col("midprice")
    .sub(pl.col(f"{side}Rate{level}"))
    .truediv(tick_size)
    .alias(f"{side}Rate{level}_from_mid")
    for level in range(15)
    for side in ("ask", "bid")
]
rel_df = (
    df
    .with_columns(
        *exps,
        # spread expressed as a number of ticks
        pl.col("spread") / tick_size,
    )
    # keep the default columns plus the tick distances
    .select(cs.by_name(default_cols) | cs.contains("_from_mid"))
)

In [None]:
rel_df["spread"].describe()

In [None]:
df.filter(pl.col("askRate1").is_null())

In [None]:
# askRate0 - askRate1 over time
(
    pn.ggplot(
        data=df.with_columns(diff=pl.col("askRate0") - pl.col("askRate1")),
        mapping=pn.aes(x="time", y="diff"),
    )
    + pn.geom_line()
)

In [None]:
# Midprice over time (from rel_df)
pn.ggplot(data=rel_df, mapping=pn.aes(x="time", y="midprice")) + pn.geom_line()

In [None]:
# Target y over time (from rel_df)
pn.ggplot(data=rel_df, mapping=pn.aes(x="time", y="y")) + pn.geom_line()

In [None]:
# Distribution of the spread (in ticks), 20 bins
pn.ggplot(data=rel_df, mapping=pn.aes(x="spread")) + pn.geom_histogram(bins=20)

In [None]:
df

In [5]:
# # Basic features
# df = df.with_columns(
#     midprice=(pl.col("askRate0") + pl.col("bidRate0")) / 2,
#     spread=pl.col("askRate0") - pl.col("bidRate0"),
#     # skew=pl.col("askSize0").log() - pl.col("bidSize0").log(),
#     # total_ask_size=pl.sum_horizontal(cs.contains("askSize")),
#     # total_bid_size=pl.sum_horizontal(cs.contains("bidSize")),
# )

# # # Volume-Weighted Average Price (VWAP)
# # NOTE(review): before re-enabling - VWAP is sum(price * size) / sum(size),
# # but the ratio below is the inverse (total size / sum(price * size)). The
# # columns are also spelled "*_vmap" while a later cell asks for "*_vwap".
# # df = df.with_columns(
# #     ask_vmap=pl.col('total_ask_size') / pl.sum_horizontal(pl.col(f'askRate{i}') * pl.col(f'askSize{i}') for i in range(15)),
# #     bid_vmap=pl.col('total_bid_size') / pl.sum_horizontal(pl.col(f'bidRate{i}') * pl.col(f'bidSize{i}') for i in range(15)),
# # )

In [None]:
# Long-format view of the first 200k snapshots: one row per (time, level, side)
# with the price and the volume of that level.
pdata = (
    df[:200_000]
    .select(cs.contains("Rate") | cs.contains("Size") | cs.contains("time"))
    .unpivot(index="time")
    .with_columns(
        # raw strings: "\d" is an invalid escape sequence in a normal literal
        level=pl.col("variable").str.extract(r"(\d+)").cast(pl.Int16),
        side=pl.col("variable").str.extract(r"([a-z]+)"),
        is_volume=pl.col("variable").str.contains("Size"),
    )
    .drop("variable")
    # the boolean is_volume values become the column names "false" / "true"
    .pivot(index=["time", "level", "side"], on="is_volume", values="value")
    .rename({"false": "price", "true": "volume"})
    .filter(pl.col("level") < 10)
    # .filter(pl.col("volume") > 0)
    # .pivot(index=["time", "level"], on=["side", "is_volume"], values="value")
    # .rename(
    #     {
    #         "{\"ask\",false}": "ask_price",
    #         "{\"ask\",true}": "ask_volume",
    #         "{\"bid\",false}": "bid_price",
    #         "{\"bid\",true}": "bid_volume",
    #     }
    # )
)
pdata

In [None]:
# Mean volume per level, one column per side
(
    pdata
    .group_by(["level", "side"])
    .agg(pl.col("volume").mean())
    .pivot(on="side", index="level")
    .sort("level")
)

In [36]:
# (
#     pn.ggplot(pdata.filter(pl.col("level") < 5), pn.aes("time", "price", alpha="volume", colour="side"))
#     + pn.geom_point()
#     + pn.scale_alpha_continuous(range=(0.01, 1), guide=None, limits=(80, None))
#     + pn.scale_colour_manual(cmap)
#     + pn.theme_bw()
#     + pn.theme(legend_position="none")
# )

In [None]:
import lightgbm as lgb
from sklearn.linear_model import LinearRegression

# Candidate regressors: a LightGBM model seeded with random_state=0 and a
# plain linear regression.
models = [
    lgb.LGBMRegressor(random_state=0, verbosity=-1),
    LinearRegression(),
]

In [None]:
from mlforecast import MLForecast
from mlforecast.lag_transforms import ExpandingMean, RollingMean
from mlforecast.target_transforms import Differences

# Forecasting wrapper around `models`. Lag features, lag transforms and target
# transforms are all left disabled (commented out) for now.
# NOTE(review): freq=1 presumably matches the integer `time` index - confirm.
fcst = MLForecast(
    models=models,
    freq=1,
    # lags=[7, 14],
    # lag_transforms={
    #     1: [ExpandingMean()],
    #     7: [RollingMean(window_size=28)]
    # },
    # target_transforms=[],
)


In [None]:
# Regression table as a pandas frame: every ask*/bid* column plus the target,
# the time index and the series id.
regdata = (
    # the second selector repeated "ask", which left the whole bid side out
    df.select(cs.starts_with("ask") | cs.starts_with("bid") | cs.by_name(["y", "time", "uid"]))
    .drop("ask_vmap")
    # NOTE(review): bid counterpart of the column dropped above, now that bid
    # columns are selected; strict=False because it may not exist - confirm
    .drop("bid_vmap", strict=False)
    .to_pandas()
)

In [None]:
fcst.fit(regdata, id_col="uid", time_col="time")


In [None]:
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score, mean_squared_error

# Step 1: Define indices for a time-ordered split with a gap between the sets
n = len(regdata)
train_end = int(n * 0.7)  # first 70% for training
val_start = int(n * 0.8)  # the 10% between train_end and val_start is not used
val_end = n  # last 20% for validation

y = regdata["y"].values
X = regdata.drop(columns=["y", "uid", "time"])

# Step 2: Create the time-based splits (explicit positional row slicing)
X_train, y_train = X.iloc[:train_end], y[:train_end]
X_val, y_val = X.iloc[val_start:val_end], y[val_start:val_end]

# Step 3: Fit a random forest on the training split; seeded so that reruns
# give the same scores (same seed as the LightGBM model)
model = RandomForestRegressor(random_state=0)
model.fit(X_train, y_train)

# Step 4: Make predictions on the validation set
y_pred = model.predict(X_val)

# Step 5: Evaluate the model on the validation split
r2_time_split = r2_score(y_val, y_pred)
mse_time_split = mean_squared_error(y_val, y_pred)

r2_time_split, mse_time_split


In [None]:

# Drop unnecessary columns for simplicity in the baseline model
# NOTE(review): `data` is not defined in any visible cell, and the only VWAP
# columns created elsewhere (commented out) are spelled "ask_vmap"/"bid_vmap",
# not "ask_vwap"/"bid_vwap" - confirm both before running.
features = ['spread', 'total_ask_size', 'total_bid_size', 'ask_vwap', 'bid_vwap']
X = data.select(features)
y = data['y']

X.head(), y.head()


In [None]:
# # Remove trading halts (order book does not change)
# prev_len = len(df)
# df = df.unique().sort("time")
# prev_len - len(df)

In [None]:
# Work on the first 500k rows only (this is enough to have similar statistics
# to the original size)
sdf = df.head(500_000)

In [None]:
# Friendlier names for the best bid/ask price and size
# (idea kept for later: pack each side's Rate/Size columns into list columns
#  with pl.concat_list(cs.contains("askRate")), etc.)
sdf.rename(
    {
        "askRate0": "ask_price",
        "bidRate0": "bid_price",
        "askSize0": "ask_vol",
        "bidSize0": "bid_vol",
    }
)

In [None]:
df["y"].describe(), df["y"].unique().to_numpy()

In [None]:
# Best ask (red) and best bid (green) over time
(
    pn.ggplot(data=df, mapping=pn.aes(x="time"))
    + pn.geom_line(mapping=pn.aes(y="askRate0"), colour="red")
    + pn.geom_line(mapping=pn.aes(y="bidRate0"), colour="green")
)

In [None]:
# skew column over time
pn.ggplot(data=df, mapping=pn.aes(x="time", y="skew")) + pn.geom_line()

In [None]:
df["askRate0", "bidRate0", "y"].corr()

In [None]:
# Long-format sizes with a signed level: ask levels are positive, all other
# (bid) levels are negated.
pdata = (
    df.select(cs.contains("Size") | cs.contains("time"))
    .unpivot(index="time")
    # raw string: "\d" is an invalid escape sequence in a normal literal
    .with_columns(level=pl.col("variable").str.extract(r"(\d+)").cast(pl.Int16))
    .with_columns(level=pl.when(pl.col("variable").str.starts_with("ask")).then(pl.col("level")).otherwise(pl.col("level").neg()))
)

In [None]:
# Size per signed level at time == 49
(
    pn.ggplot(data=pdata.filter(pl.col("time") == 49), mapping=pn.aes(x="factor(level)", y="value"))
    + pn.geom_bar(stat="identity")
)

In [None]:
# Size per signed level at time == 50
(
    pn.ggplot(data=pdata.filter(pl.col("time") == 50), mapping=pn.aes(x="factor(level)", y="value"))
    + pn.geom_bar(stat="identity")
)

In [None]:
# Size per signed level at time == 51
(
    pn.ggplot(data=pdata.filter(pl.col("time") == 51), mapping=pn.aes(x="factor(level)", y="value"))
    + pn.geom_bar(stat="identity")
)