In [1]:
# TODO:

# load test data (messages & book)
# select (random) test sequence
# encode msg and book sequence for model

# get raw book data (L2) at the start of the sequence
# init simulator with initial book
# replay sequence in simulator (actual)

# load trained model
# predict next message
# map message to one that the simulator understands & that is valid
# apply message to simulator (predicted)
# get L2 representation and encode it for model

In [2]:
# reload edited project modules (encoding, preproc, ...) automatically before each cell runs
%load_ext autoreload
%autoreload 2

In [30]:
import numpy as onp
import jax.numpy as jnp
from jax.nn import one_hot
import pandas as pd
from glob import glob
from functools import partial
from typing import Union
import time
from typing import Optional, Union

from lob_seq_model import LobPredModel
from encoding import Vocab, Message_Tokenizer
from lobster_dataloader import LOBSTER_Dataset, LOBSTER_Subset, LOBSTER_Sampler, LOBSTER
import preproc

In [31]:
import os

# Directory with the raw LOBSTER csv files; can be overridden via the
# LOBS5_RAW_DATA_DIR environment variable instead of editing the notebook.
# os.path.join(..., '') guarantees a trailing separator, which the glob
# patterns below rely on (they concatenate data_dir + pattern).
data_dir = os.path.join(
    os.environ.get('LOBS5_RAW_DATA_DIR', '/nfs/home/peern/LOBS5/data/raw/'), '')
#save_dir = '/nfs/home/peern/LOBS5/data/'

In [32]:
# all raw LOBSTER message / orderbook files, sorted so that index -1 is the last day
message_files = sorted(glob(data_dir + '*message*.csv'))
book_files = sorted(glob(data_dir + '*orderbook*.csv'))
# fail early with a clear message instead of an IndexError on `[-1]` later,
# and make sure every message file has a matching orderbook file
assert message_files, f'no message files found in {data_dir}'
assert len(message_files) == len(book_files), (
    f'{len(message_files)} message files vs {len(book_files)} orderbook files')

In [33]:
# load test data (last day)

m = pd.read_csv(
    message_files[-1],
    names=['time', 'event_type', 'order_id', 'size', 'price', 'direction'],
    index_col=False)

b = pd.read_csv(
    book_files[-1],
    index_col=False,
    header=None
)
# the two files are paired row by row; selecting book rows by the message
# index below is only meaningful if both have the same number of rows
assert len(m) == len(b), f'{len(m)} messages vs {len(b)} orderbook rows'

# remove disallowed order types (keep event types 1-4 only)
m = m.loc[m.event_type.isin([1, 2, 3, 4])]
# keep the book rows of the retained messages (both frames use a default RangeIndex)
b = b.loc[m.index]

  exec(code_obj, self.user_global_ns, self.user_ns)
  return func(*args, **kwargs)


## Encoding

In [34]:
# Book encoding
price_levels = 40  # how many ticks to represent

In [35]:
# Message encoding
# token vocabulary, passed to tok.encode below
v = Vocab()
# turns raw message DataFrames into token sequences (tok.preproc / tok.encode)
tok = Message_Tokenizer()

In [36]:
# pre-process the raw messages (preproc takes both messages and book) and
# encode them into tokens; preproc prints the price-truncation statistics
print('<< pre processing >>')
m_proc = tok.preproc(m, b)
print('<< encoding >>')
m_enc = tok.encode(m_proc, v)

<< pre processing >>
truncating 0.0000% of prices > 9900
truncating 0.0000% of prices < -9900
<< encoding >>


In [37]:
# tok.preproc returns one row fewer than its input (m_enc has one row less
# than m had), presumably because it drops the first message: remove the
# first message from the raw data as well so that rows stay aligned.
# Guarded so that re-running this cell does not drop further rows.
if len(m) == len(m_enc) + 1:
    m = m.iloc[1:]
assert len(m) == len(m_enc), 'raw and encoded messages are misaligned'

In [38]:
# encoded messages: one row of tokens per message
m_enc

array([[    3,     3,     8, ...,     2,     2,     2],
       [    3,     3,     8, ...,     2,     2,     2],
       [    3,     3,     9, ...,     2,     2,     2],
       ...,
       [   26,   402,   997, ..., 11108, 11007, 11110],
       [   26,   402,   985, ..., 11108, 11010, 11110],
       [   26,   402,  1002, ...,     2,     2,     2]])

In [39]:
# raw messages after type filtering and dropping the first row
m.shape

(1829106, 6)

In [40]:
# pre-processed messages (same row count as m)
m_proc.shape

(1829106, 10)

In [41]:
# encoded messages: (number of messages, tokens per message)
m_enc.shape

(1829106, 20)

In [42]:
# encode the raw L2 book for the model; the result has price_levels + 1 columns
# NOTE(review): b was not shifted like m, so b_enc keeps one more row than m_enc
# TODO: load from saved file instead
b_enc = preproc.process_book(b, price_levels=price_levels)

In [43]:
# raw book: one more row than m, because only m had its first row dropped
b.shape

(1829107, 40)

In [44]:
# encoded book: (rows of b, price_levels + 1)
b_enc.shape

(1829107, 41)

## Data Initialisation

In [25]:
n_messages: int = 500  # length (in messages) of the conditioning window fed to the model

In [19]:
# when to start the prediction, in seconds after midnight (39600.0)
# pd.Timedelta avoids parsing two wall-clock times against "today", which
# could land on different dates if the cell runs across midnight.
start_time = pd.Timedelta(hours=11).total_seconds()
# seq end index: number of messages strictly before start_time
# (used as a position, so this assumes m is sorted by time)
end_i = int((m.time < start_time).sum())

In [20]:
# conditioning window: the n_messages messages (and book rows) before end_i.
# Guard the bounds: with end_i < n_messages the start index would be
# negative and numpy slicing would silently wrap around to the array end.
assert n_messages <= end_i <= len(m_enc), (
    f'cannot take {n_messages} messages ending at index {end_i}')
m_seq = m_enc[end_i - n_messages : end_i].reshape(-1)  # (n_messages [500] * tokens per message [20], )
b_seq = b_enc[end_i - n_messages : end_i]              # (n_messages [500], price_levels + 1 [41])

## Simulator

In [21]:
import os
import sys

# add git submodule to path to allow imports to work
# (the submodule is expected next to the notebook's working directory)
submodule_name = 'AlphaTrade'
(parent_folder_path, current_dir) = os.path.split(os.path.abspath(''))
submodule_path = os.path.join(parent_folder_path, submodule_name)
# only append once, so that re-running the cell does not grow sys.path
if submodule_path not in sys.path:
    sys.path.append(submodule_path)

In [22]:
from gymnax_exchange.jaxob.jorderbook import OrderBook
import gymnax_exchange.jaxob.JaxOrderbook as job

In [23]:
# TODO: integrate this into simulator: OrderBook

def init_msgs_from_l2(book: Union[pd.Series, onp.ndarray]) -> jnp.ndarray:
    """Build limit-order messages that recreate an L2 book snapshot in the simulator.

    Parameters
    ----------
    book : pd.Series or onp.ndarray
        One flat L2 snapshot of ``4 * n_levels`` values, read as consecutive
        (price, quantity) pairs. Presumably in LOBSTER orderbook order
        (ask price, ask size, bid price, bid size, ... per level), so even
        pairs are asks and odd pairs are bids -- TODO confirm.

    Returns
    -------
    jnp.ndarray
        int32 array of shape ``(2 * n_levels, 8)`` with one limit order
        (Type 1) per side and level, in the column layout produced by
        ``msgs_to_jnp``: [Type, Side, Quantity, Price, TradeID, OrderID,
        TimeWhole, TimeDec].

    Raises
    ------
    ValueError
        If the snapshot length is not a multiple of 4.
    """
    if len(book) % 4 != 0:
        raise ValueError(
            f'L2 snapshot length must be a multiple of 4, got {len(book)}')
    orderbookLevels = len(book) // 4  # price/quantity for bid/ask
    n_rows = 2 * orderbookLevels      # one order per side per level
    data = jnp.array(book).reshape(n_rows, 2)  # rows of (price, quantity)
    # Build directly in int32: the default float32 buffer cannot represent
    # integer prices above 2**24 exactly, so they were rounded before the
    # final int cast.
    initOB = jnp.zeros((n_rows, 8), dtype=jnp.int32) \
        .at[:, 0].set(1) \
        .at[0::2, 1].set(-1) \
        .at[1::2, 1].set(1) \
        .at[:, 2].set(data[:, 1]) \
        .at[:, 3].set(data[:, 0]) \
        .at[:, 4].set(0) \
        .at[:, 5].set(job.INITID) \
        .at[:, 6].set(34200) \
        .at[:, 7].set(0)
    # columns 6/7: fixed timestamp 34200 s after midnight (= 09:30:00), 0 decimals
    return initOB

def msgs_to_jnp(m_df: pd.DataFrame) -> jnp.ndarray:
    """Convert LOBSTER-style messages into the array layout used by the simulator.

    Parameters
    ----------
    m_df : pd.DataFrame
        Messages with 6 columns in LOBSTER order (time, event type, order id,
        size, price, direction), optionally with a 7th trader-id column.
        Column names are ignored; only their position matters. The frame is
        not modified.

    Returns
    -------
    jnp.ndarray
        One row per message with columns
        [Type, Side, Quantity, Price, TradeID, OrderID, TimeWhole, TimeDec].
        Messages of type 5, 6 and 7 are dropped. TimeWhole is the integer
        seconds part of the timestamp, TimeDec its fractional part in
        nanoseconds (decimal digits right-padded / truncated to 9).
    """
    cols = ['Time', 'Type', 'OrderID', 'Quantity', 'Price', 'Side']
    if m_df.shape[1] == 7:
        cols += ["TradeID"]
    # set_axis returns a new frame, so the caller's DataFrame is never mutated
    m_df = m_df.set_axis(cols, axis=1)
    m_df = m_df.assign(TradeID=0)  # TODO: should be TraderID for multi-agent support
    col_order = ['Type', 'Side', 'Quantity', 'Price', 'TradeID', 'OrderID', 'Time']
    m_df = m_df.loc[~m_df['Type'].isin([5, 6, 7]), col_order]
    # Split the timestamp on its decimal point. The fractional digits are
    # right-padded to 9 so that e.g. "34200.5" gives 500000000 ns (not 5), and
    # a float repr that dropped trailing zeros keeps its magnitude.
    # str.partition (unlike split) always yields 3 columns, even without a '.'.
    time_parts = m_df['Time'].astype('string').str.partition('.')
    time_whole = time_parts[0].astype('int32')
    time_dec = time_parts[2].str.ljust(9, '0').str.slice(0, 9).astype('int32')
    m_df = m_df.drop(columns='Time').assign(TimeWhole=time_whole, TimeDec=time_dec)
    mJNP = jnp.array(m_df)
    return mJNP

def reset_orderbook(
        b: OrderBook,
        l2_book: Optional[Union[pd.Series, onp.ndarray]] = None,
    ) -> None:
    """Clear the simulator's book in place, optionally re-seeding it from a snapshot.

    Every entry of ``b.orderbook_array`` is overwritten with -1 (presumably
    the simulator's empty-slot marker -- confirm in jorderbook). When
    ``l2_book`` is given, the limit orders rebuilding that L2 snapshot (see
    ``init_msgs_from_l2``) are then processed; the trades returned by
    ``process_orders_array`` are discarded.
    """
    empty_book = b.orderbook_array.at[:].set(-1)
    b.orderbook_array = empty_book
    if l2_book is None:
        return
    b.process_orders_array(init_msgs_from_l2(l2_book))

In [24]:
# JAX order book simulator; price_levels=10 matches the L2 snapshots in b
# (40 columns = 10 levels x price/size for ask/bid)
sim = OrderBook(price_levels=10, orderQueueLen=20)
sim

<gymnax_exchange.jaxob.jorderbook.OrderBook at 0x7f216409d700>

In [25]:
# init simulator at the start of the sequence
# NOTE(review): m had its first row dropped but b did not, so b.iloc[k]
# presumably is the book state right before message m.iloc[k] (LOBSTER book
# rows record the state after each message), i.e. the state just before the
# replayed window -- confirm.
reset_orderbook(sim, b.iloc[end_i - n_messages])

In [26]:
# replay sequence in simulator (actual)
# so that sim is at the same state as the model
# (the same window of raw messages that m_seq encodes)
replay = msgs_to_jnp(m.iloc[end_i - n_messages : end_i])
# trades executed during the replay; not used further in the cells shown
trades = sim.process_orders_array(replay)

In [27]:
# L2 state after the replay, as alternating price / quantity entries
sim.get_L2_state()

Array([988000,    100, 987900,    802, 988100,    182, 987800,    782,
       988200,   1056, 987700,    600, 988300,    706, 987600,    500,
       988400,   1100, 987500,   1250, 988500,   1012, 987400,    850,
       988600,    468, 987300,    490, 988700,   1615, 987200,     50,
       988800,    431, 987100,    750, 989000,     50, 987000,     50],      dtype=int32)

## Model

In [2]:
import os
# NOTE: presumably XLA reads this only when the JAX backend starts up, so it
# has no effect if JAX code already ran in this kernel (the simulator cells
# above do) -- run this section in a fresh kernel for it to apply.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
#os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".25"
import torch
# set_start_method raises RuntimeError if the context is already set, so only
# set it when needed; this keeps the cell safe to re-run.
if torch.multiprocessing.get_start_method(allow_none=True) != 'spawn':
    torch.multiprocessing.set_start_method('spawn')

In [56]:
from argparse import Namespace
from glob import glob

import jax
from jax import random
from jax.scipy.linalg import block_diag
from flax.training import checkpoints
import orbax

#from lob.lob_seq_model import BatchLobPredModel
from lob.train_helpers import create_train_state, eval_step, prep_batch, cross_entropy_loss, compute_accuracy
from s5.ssm import init_S5SSM
from s5.ssm_init import make_DPLR_HiPPO
from s5.dataloading import make_data_loader
from lob.lobster_dataloader import LOBSTER_Dataset, LOBSTER

import validation_helpers as valh
from lob.init_train import init_train_state, load_checkpoint, load_args_from_checkpoint

In [4]:
# necessary for checkpoints to be loaded in jupyter notebook
# (nest_asyncio patches asyncio so code can run an event loop while the
# notebook's own loop is already running)

import nest_asyncio
nest_asyncio.apply()

In [57]:
# training arguments stored alongside the checkpoint
args = load_args_from_checkpoint('../checkpoints_book_causal/')

In [58]:
# inspect the restored training configuration
args

Namespace(C_init='trunc_standard_normal', USE_WANDB=True, activation_fn='half_glu1', batchnorm=True, bidirectional=False, blocks=8, bn_momentum=0.95, bsz=16, clip_eigs=False, conj_sym=True, cosine_anneal=True, d_model=32, dataset='lobster-prediction', dir_name='./data', discretization='zoh', dt_global=False, dt_max=0.1, dt_min=0.001, early_stop_patience=1000, epochs=100, jax_seed=1919, lr_factor=1.0, lr_min=0, lr_patience=1000000, masking='causal', mode='pool', n_layers=6, opt_config='standard', p_dropout=0.0, prenorm=True, reduce_factor=1.0, ssm_lr_base=0.0005, ssm_size_base=32, use_book_data=True, wandb_entity='peer-nagy', wandb_project='LOBS5', warmup_end=1, weight_decay=0.05)

In [45]:
# model dimensions for this inference setup
v = Vocab()
n_classes = len(v)  # one output class per vocabulary entry
seq_len = n_messages * Message_Tokenizer.MSG_LEN  # total message tokens in the context window
book_dim = b_enc.shape[1]  # features per encoded book row
book_seq_len = n_messages  # one book row per message in the window

In [63]:
# fresh train state with the checkpoint's architecture; it is passed to
# load_checkpoint below, presumably as the template the weights restore into
new_train_state = init_train_state(
    args,
    n_classes=n_classes,
    seq_len=seq_len,
    book_dim=book_dim,
    book_seq_len=book_seq_len,
)

configuring standard optimization setup
[*] Trainable Parameters: 1094460


In [64]:
# restore the trained checkpoint; the model state is read from its 'model' entry
ckpt = load_checkpoint(
    new_train_state,
    '../checkpoints_book_causal/',
    args.__dict__)
state = ckpt['model']

In [None]:
# FIXME: `model_cls` is not defined anywhere in this notebook, so this cell
#        raises NameError on a fresh kernel; it presumably has to be returned
#        by the init function (see question below) -- confirm.
# return model_cls from init fn?
model = model_cls(training=False, step_rescale=1.0)

In [None]:
# TODO from above:
# x load trained model
#   predict next message
#   map message to one that the simulator understands & that is valid
#   apply message to simulator (predicted)
#   get L2 representation and encode it for model

In [None]:
# TODO: refactor slightly and work in simulation step
#       and simulator matching orders

# number of messages to generate
pred_n_messages = 1
# mask from valh.syntax_validation_matrix (presumably restricts sampled
# tokens to syntactically valid messages)
valid_mask_array = valh.syntax_validation_matrix()
# `rng` was never defined in this notebook: derive it from the checkpoint's
# training seed so that sampling is reproducible.
# NOTE(review): assumes pred_msg expects a jax PRNG key -- confirm
rng = random.PRNGKey(args.jax_seed)
# FIXME: `start_seq` is not defined anywhere in this notebook; presumably it
#        should be built from m_seq / b_seq -- confirm the expected format
inf_seq = valh.pred_msg(
    start_seq,
    pred_n_messages,
    state,
    model,
    args.batchnorm,
    rng,
    valid_mask_array
)