In [1]:
# CONTENTS

# load test data (messages & book)
# select (random) test sequence
# encode msg and book sequence for model

# get raw book data (L2) at the start of the sequence
# init simulator with initial book
# replay sequence in simulator (actual)

# load trained model
# predict next message
# map message to one the simulator understands & is valid
# apply message to simulator (predicted)
# get L2 representation and encode it for model

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
#os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".25"
import torch
torch.multiprocessing.set_start_method('spawn')

In [4]:
from argparse import Namespace
from glob import glob
import numpy as onp
import pandas as pd
from functools import partial
from typing import Union, Optional

import jax
import jax.numpy as jnp
from jax.nn import one_hot
from jax import random
from jax.scipy.linalg import block_diag
from flax.training import checkpoints
import orbax

#from lob.lob_seq_model import BatchLobPredModel
from lob.train_helpers import create_train_state, eval_step, prep_batch, cross_entropy_loss, compute_accuracy
from s5.ssm import init_S5SSM
from s5.ssm_init import make_DPLR_HiPPO
from s5.dataloading import make_data_loader
from lob_seq_model import LobPredModel
from encoding import Vocab, Message_Tokenizer
from lobster_dataloader import LOBSTER_Dataset, LOBSTER_Subset, LOBSTER_Sampler, LOBSTER

import preproc
import validation_helpers as valh
from lob.init_train import init_train_state, load_checkpoint, load_args_from_checkpoint

In [5]:
data_dir = '/nfs/home/peern/LOBS5/data/raw/'
save_dir = '/nfs/home/peern/LOBS5/data/'

In [6]:
message_files = sorted(glob(data_dir + '*message*.csv'))
book_files = sorted(glob(data_dir + '*orderbook*.csv'))

In [7]:
# load test data (last day)

m = pd.read_csv(
    message_files[-1],
    names=['time', 'event_type', 'order_id', 'size', 'price', 'direction'],
    index_col=False)

b = pd.read_csv(
    book_files[-1],
    index_col=False,
    header=None
)

# remove diallowed order types
m = m.loc[m.event_type.isin([1, 2, 3, 4])]
b = b.loc[m.index]

  exec(code_obj, self.user_global_ns, self.user_ns)
  return func(*args, **kwargs)


## Encoding

In [8]:
# Book encoding
price_levels = 40  # how many ticks to represent

In [9]:
# Message encoding
v = Vocab()
tok = Message_Tokenizer()

In [10]:
# encode from raw data

# print('<< pre processing >>')
# m_proc = tok.preproc(m, b)
# print('<< encoding >>')
# m_enc = tok.encode(m_proc, v)

# instead load from file:
msg_enc_file = sorted(glob(save_dir + '*message*.npy'))[-1]
m_enc = onp.load(msg_enc_file)

In [11]:
# remove first message from raw data as well
m = m.iloc[1:]

In [12]:
m_enc

array([[    3,     3,     8, ...,     2,     2,     2],
       [    3,     3,     8, ...,     2,     2,     2],
       [    3,     3,     9, ...,     2,     2,     2],
       ...,
       [   26,   402,   997, ..., 11108, 11007, 11110],
       [   26,   402,   985, ..., 11108, 11010, 11110],
       [   26,   402,  1002, ...,     2,     2,     2]])

In [13]:
m.shape

(1829106, 6)

In [14]:
# m_proc.shape

In [15]:
m_enc.shape

(1829106, 20)

In [16]:
# encode from raw data:
# b_enc = preproc.process_book(b, price_levels=price_levels)

# instead load from file:
book_enc_file = sorted(glob(save_dir + '*book*.npy'))[-1]
b_enc = onp.load(book_enc_file)

In [17]:
#b_enc_ = preproc.process_book(b, price_levels=price_levels)

In [18]:
b.shape

(1829107, 40)

In [19]:
b_enc.shape

(1829107, 41)

In [20]:
#b_enc_.shape

## Data Initialisation

In [21]:
n_messages = 500

In [22]:
# when to start the prediction
# convert time into seconds after midnight
start_time = (pd.to_datetime('11:00') - pd.to_datetime('00:00')).total_seconds()
# get seq end index
end_i = len(m.loc[m.time < start_time])

In [23]:
m_seq = m_enc[end_i - n_messages : end_i].reshape(-1)  # (n_messages [500] * levels [20], )
# book state: we already include the book state after the last message
# (different to training where we only have the book state before the first message
# and mask part of the last message)
# for message seq, we first need to append an empty message
b_seq = b_enc[end_i - n_messages + 1 : end_i + 1]      # (n_messages [500], price_levels + 1 [41])

## Simulator

In [24]:
import os
import sys

# add git submodule to path to allow imports to work
submodule_name = 'AlphaTrade'
(parent_folder_path, current_dir) = os.path.split(os.path.abspath(''))
sys.path.append(os.path.join(parent_folder_path, submodule_name))

In [25]:
from gymnax_exchange.jaxob.jorderbook import OrderBook
import gymnax_exchange.jaxob.JaxOrderbook as job

In [26]:
# TODO: integrate this into simualtor: OrderBook

def init_msgs_from_l2(book: Union[pd.Series, onp.ndarray]) -> jnp.ndarray:
    orderbookLevels = len(book) // 4  # price/quantity for bid/ask
    data = jnp.array(book).reshape(int(orderbookLevels*2),2)
    newarr = jnp.zeros((int(orderbookLevels*2),8))
    initOB = newarr \
        .at[:,3].set(data[:,0]) \
        .at[:,2].set(data[:,1]) \
        .at[:,0].set(1) \
        .at[0:orderbookLevels*4:2,1].set(-1) \
        .at[1:orderbookLevels*4:2,1].set(1) \
        .at[:,4].set(0) \
        .at[:,5].set(job.INITID) \
        .at[:,6].set(34200) \
        .at[:,7].set(0).astype('int32')
    return initOB

def msgs_to_jnp(m_df: pd.DataFrame) -> jnp.ndarray:
    m_df = m_df.copy()
    cols = ['Time', 'Type', 'OrderID', 'Quantity', 'Price', 'Side']
    if m_df.shape[1] == 7:
        cols += ["TradeID"]
    m_df.columns = cols
    m_df['TradeID'] = 0  #  TODO: should be TraderID for multi-agent support
    col_order=['Type','Side','Quantity','Price','TradeID','OrderID','Time']
    m_df = m_df[col_order]
    m_df = m_df[(m_df['Type'] != 6) & (m_df['Type'] != 7) & (m_df['Type'] != 5)]
    time = m_df["Time"].astype('string').str.split('.',expand=True)
    m_df[["TimeWhole","TimeDec"]] = time.astype('int32')
    m_df = m_df.drop("Time", axis=1)
    mJNP = jnp.array(m_df)
    return mJNP

def reset_orderbook(
        b: OrderBook,
        l2_book: Optional[Union[pd.Series, onp.ndarray]] = None,
    ) -> None:
    b.orderbook_array = b.orderbook_array.at[:].set(-1)
    if l2_book is not None:
        msgs = init_msgs_from_l2(l2_book)
        b.process_orders_array(msgs)

In [27]:
sim = OrderBook(price_levels=10, orderQueueLen=20)
sim

<gymnax_exchange.jaxob.jorderbook.OrderBook at 0x7fbe21de0100>

In [28]:
# init simulator at the start of the sequence
reset_orderbook(sim, b.iloc[end_i - n_messages])

In [29]:
# replay sequence in simulator (actual)
# so that sim is at the same state as the model
replay = msgs_to_jnp(m.iloc[end_i - n_messages : end_i])
trades = sim.process_orders_array(replay)

In [30]:
sim.get_L2_state()

Array([988000,    100, 987900,    802, 988100,    182, 987800,    782,
       988200,   1056, 987700,    600, 988300,    706, 987600,    500,
       988400,   1100, 987500,   1250, 988500,   1012, 987400,    850,
       988600,    468, 987300,    490, 988700,   1615, 987200,     50,
       988800,    431, 987100,    750, 989000,     50, 987000,     50],      dtype=int32)

## Model

In [31]:
# necessary for checkpoints to be loaded in jupyter notebook

import nest_asyncio
nest_asyncio.apply()

In [32]:
args = load_args_from_checkpoint('../checkpoints_book_causal_2/')

In [33]:
args

Namespace(C_init='trunc_standard_normal', USE_WANDB=True, activation_fn='half_glu1', batchnorm=True, bidirectional=False, blocks=8, bn_momentum=0.95, bsz=16, clip_eigs=False, conj_sym=True, cosine_anneal=True, d_model=32, dataset='lobster-prediction', dir_name='./data', discretization='zoh', dt_global=False, dt_max=0.1, dt_min=0.001, early_stop_patience=1000, epochs=100, jax_seed=42, lr_factor=1.0, lr_min=0, lr_patience=1000000, masking='causal', mode='pool', n_layers=6, opt_config='standard', p_dropout=0.0, prenorm=True, reduce_factor=1.0, ssm_lr_base=0.0005, ssm_size_base=32, use_book_data=True, wandb_entity='peer-nagy', wandb_project='LOBS5', warmup_end=1, weight_decay=0.05)

In [34]:
v = Vocab()
n_classes = len(v)
seq_len = n_messages * Message_Tokenizer.MSG_LEN
book_dim = b_enc.shape[1]
book_seq_len = n_messages

In [35]:
new_train_state, model_cls = init_train_state(
    args,
    n_classes=n_classes,
    seq_len=seq_len,
    book_dim=book_dim,
    book_seq_len=book_seq_len,
)

configuring standard optimization setup
[*] Trainable Parameters: 1094460


In [36]:
ckpt = load_checkpoint(
    new_train_state,
    '../checkpoints_book_causal_2/',
    args.__dict__)
state = ckpt['model']

In [37]:
model = model_cls(training=False, step_rescale=1.0)

In [38]:
# TODO from above:
# x load trained model
#   predict next message
#   map message to one the simulator understands & is valid
#   apply message to simulator (predicted)
#   get L2 representation and encode it for model

In [39]:
# append new HID message (and next LOB state if not already in seq)
# loop: predict next token until full message is generated
# map message to one the simulator understands & is valid
# feed message to simulator (predicted) --> next book state
# encode next book state for model and append to book sequence

In [40]:
m_seq

array([    8,   402,    38, ..., 11107, 11009, 11109])

In [234]:
vocab_len = len(v)
batchnorm = args.batchnorm
sample_top_n = 1
rng = jax.random.PRNGKey(42)
rng, rng_ = jax.random.split(rng)

### Model Validation (optional)

In [42]:
dataset_obj = LOBSTER(
    'lobster',
    data_dir='/nfs/home/peern/LOBS5/data/',
    mask_fn=LOBSTER_Dataset.causal_mask,
    use_book_data=True,
)
dataset_obj.setup()

In [43]:
test_loader = make_data_loader(
    dataset_obj.dataset_test,
    dataset_obj,
    seed=args.jax_seed,
    batch_size=args.bsz,
    drop_last=True,
    shuffle=False,
    num_workers=0
)

In [44]:
rng = jax.random.PRNGKey(42)
tok = Message_Tokenizer()

all_pred_toks = []
all_labels = []

losses = []
accuracy = []
ranks = []
valid_mass = []
valid_mass_n5 = []
valid_pred = []
losses_baseline = []

VALID_MATRIX = valh.syntax_validation_matrix()

for batch_idx, batch in enumerate(test_loader):
    
    # PREPARE BATCH
    inputs, labels, integration_timesteps = prep_batch(batch, seq_len, n_classes)
    # INFERENCE STEP
    loss, acc, pred = eval_step(
        inputs, labels, integration_timesteps, state, model, args.batchnorm)
    
    # STORE RESULTS
    pred_toks = pred.argmax(axis=-1)
    all_labels += labels.tolist()
    all_pred_toks += pred_toks.tolist()
    
    # STATS
    losses.append(cross_entropy_loss(pred, labels))
    accuracy.append(compute_accuracy(pred, labels))
    
    # where does the correct label rank in the predicted distribution?
    ranks.append(valh.pred_rank(pred, labels))
    # how much of the predicted distribution is valid?
    masked_fields = valh.get_masked_fields(batch[0])
    valid_mass.append(valh.valid_prediction_mass(pred, masked_fields))
    valid_mass_n5.append(valh.valid_prediction_mass(pred, masked_fields, top_n=5))

    # check if argmax prediction is valid token for masked fields
    valid_pred.append(valh.is_tok_valid(pred_toks, masked_fields, v))

    # benchmark: uniform prediction over syntactically valid tokens
    pos = valh.get_masked_idx(batch[0])[..., -1]
    baseline_distr = VALID_MATRIX[pos] / VALID_MATRIX[pos].sum(axis=-1, keepdims=True)
    losses_baseline.append(cross_entropy_loss(jnp.log(
            jnp.where(baseline_distr==0, 1e-10, baseline_distr)
        ), labels)
    )

all_labels = jnp.array(all_labels)
all_pred_toks = jnp.array(all_pred_toks)
losses = jnp.array(losses)
accuracy = jnp.array(accuracy)
ranks = jnp.array(ranks)
valid_mass = jnp.array(valid_mass)
valid_mass_n5 = jnp.array(valid_mass_n5)
valid_pred = jnp.array(valid_pred)
losses_baseline = jnp.array(losses_baseline)

In [245]:
print('mean loss', losses.mean())
print('mean accuracy', accuracy.mean())
print('mean rank', ranks.mean())
print('median rank', np.median(ranks))
print('mean valid mass', valid_mass.mean())
print('mean valid mass (top 5)', valid_mass_n5.mean())
print('mean valid prediction', valid_pred.mean())
print('mean baseline loss (uniform over valid syntax)', losses_baseline.mean())

AttributeError: 'list' object has no attribute 'mean'

In [247]:
from sklearn.metrics import precision_recall_fscore_support
import numpy as onp
import pandas as pd


precision, recall, fscore, support = precision_recall_fscore_support(
    all_labels.astype(int),
    all_pred_toks,
    labels=range(len(v)),
    zero_division=0,
    average=None
)

'''
print('precision: {}'.format(precision))
print('recall: {}'.format(recall))
print('fscore: {}'.format(fscore))
print('support: {}'.format(support))
'''

"\nprint('precision: {}'.format(precision))\nprint('recall: {}'.format(recall))\nprint('fscore: {}'.format(fscore))\nprint('support: {}'.format(support))\n"

In [248]:
field_dec = onp.array([(field, dec) for tok, (field, dec) in sorted(v.DECODING_GLOBAL.items())])

scores_df = pd.DataFrame({
    'field': field_dec[:, 0],
    'decoded': field_dec[:, 1],
    'precision': precision,
    'recall': recall,
    'fscore': fscore,
    'support': support,
})
#scores_df

In [249]:
scores_df.loc[scores_df.support > 0].groupby('field').agg(
    precision=('precision', 'mean'),
    recall=('recall', 'mean'),
    fscore=('fscore', 'mean'),
    support=('support', 'sum'),
)

Unnamed: 0_level_0,precision,recall,fscore,support
field,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
direction,0.756392,0.509702,0.607549,531
event_type,0.615834,0.59723,0.605116,520
generic,0.713011,0.996881,0.831383,962
price,0.401648,0.383298,0.377822,553
size,0.122625,0.066499,0.078992,536
time,0.050324,0.061767,0.052441,546


### Forecasting

In [65]:
#m_seq_start = m_seq.copy()

In [235]:
m_seq = m_seq_start.copy()

In [236]:
m_seq[-40:]

array([    8,   402,   871,   741,   414,  1003,  1032, 11107, 11010,
       11109,     2,     2,     2,     2,     2,     2,     2,     2,
           2,     2,     8,   402,   426,   552,   519,  1003,  1107,
       11107, 11009, 11109,     8,   402,   938,   659,   860,  1005,
        1107, 11107, 11009, 11109])

In [237]:
losses = []
accs = []

valid_mask_array = valh.syntax_validation_matrix(v)

m_seq = valh.append_hid_msg(m_seq)

#idx = range(Message_Tokenizer.MSG_LEN)
reversed_idx = [i \
   for field_i in reversed(list(range(len(Message_Tokenizer.FIELDS)))) \
   for i in range(*LOBSTER_Dataset._get_tok_slice_i(field_i))]

for mask_i in reversed_idx:
    # syntactically valid tokens for current message position
    valid_mask = valid_mask_array[mask_i]

    m_seq, _ = valh.mask_last_msg_in_seq(m_seq, mask_i)
    # inference
    input = (
        one_hot(
            jnp.expand_dims(m_seq, axis=0), vocab_len
        ).astype(float),
        jnp.expand_dims(b_seq, axis=0)
    )
    integration_timesteps = (
        jnp.ones((1, len(m_seq))), 
        jnp.ones((1, len(b_seq)))
    )
    logits = valh.predict(
        input,
        integration_timesteps, state, model, batchnorm)
    if valid_mask is not None:
        logits = valh.filter_valid_pred(logits, valid_mask)
    # TODO: remove - just for debugging
    label = m_enc[end_i][mask_i]
    losses.append(cross_entropy_loss(logits, label))
    accs.append(compute_accuracy(logits, label))

    #print(m_seq[-20:])
    # update sequence
    # note: rng arg expects one element per batch element
    rng, rng_ = jax.random.split(rng)
    m_seq = valh.fill_predicted_toks(m_seq, logits, sample_top_n, jnp.array([rng_]))

In [238]:
m_seq[-20:]

Array([    8,    22,    22,     2,     2,  1003,  1107, 11107, 11008,
       11109,     2,     2,     2,     2,     2,     2,     2,     2,
           2,     2], dtype=int32)

In [239]:
accs

[Array([False], dtype=bool),
 Array([False], dtype=bool),
 Array([False], dtype=bool),
 Array([False], dtype=bool),
 Array([False], dtype=bool),
 Array([False], dtype=bool),
 Array([False], dtype=bool),
 Array([False], dtype=bool),
 Array([False], dtype=bool),
 Array([False], dtype=bool),
 Array([ True], dtype=bool),
 Array([ True], dtype=bool),
 Array([False], dtype=bool),
 Array([False], dtype=bool),
 Array([ True], dtype=bool),
 Array([ True], dtype=bool),
 Array([False], dtype=bool),
 Array([False], dtype=bool),
 Array([False], dtype=bool),
 Array([False], dtype=bool)]

In [240]:
losses

[Array([1.1188538], dtype=float32),
 Array([25.409428], dtype=float32),
 Array([31.802822], dtype=float32),
 Array([15.232959], dtype=float32),
 Array([22.315346], dtype=float32),
 Array([23.502268], dtype=float32),
 Array([28.578278], dtype=float32),
 Array([19.202412], dtype=float32),
 Array([19.152643], dtype=float32),
 Array([24.333796], dtype=float32),
 Array([0.4000675], dtype=float32),
 Array([8.457518e-05], dtype=float32),
 Array([3.1385012], dtype=float32),
 Array([2.9693208], dtype=float32),
 Array([0.05330859], dtype=float32),
 Array([0.00211861], dtype=float32),
 Array([6.8914833], dtype=float32),
 Array([7.080786], dtype=float32),
 Array([6.7107654], dtype=float32),
 Array([6.88228], dtype=float32)]

In [241]:
m_seq[-20:]

Array([    8,    22,    22,     2,     2,  1003,  1107, 11107, 11008,
       11109,     2,     2,     2,     2,     2,     2,     2,     2,
           2,     2], dtype=int32)

In [None]:
# try multiple rolls to get valid message

In [None]:
# TODO:

# sim_msg = get_sim_msg(
#     m_enc[end_i],
#     m_seq,
#     sim,
#     tok,
#     v,
#     new_order_id=42, tick_size=100
# )
# sim_msg

# sim.process_order(sim_msg)

In [242]:
tok.decode_to_str(m_seq[-20:], v)

array([['005', '019', '019', 'NAN', 'NAN', '1', '010', '+', '01', '0',
        'NAN', 'NAN', 'NAN', 'NAN', 'NAN', 'NAN', 'NAN', 'NAN', 'NAN',
        'NAN']], dtype='<U3')

In [243]:
# decode predicted message
pred_msg = tok.decode(m_seq[-20:], v).flatten()
pred_msg

array([nan,  1., 10.,  1.,  0., nan, nan, nan, nan, nan])

In [232]:
v.DECODING_GLOBAL[22]

('time', '019')

In [None]:
valh.validate_msg(m_seq[-20:], tok, v)

True

In [None]:
sim.get_best_bid()

Array(987900, dtype=int32)

In [None]:
# TODO: get new order ID from simulator
new_order_id = 42
tick_size = 100

def get_sim_msg(pred_msg_enc, m_seq, sim, tok, v, new_order_id, tick_size):
    # decoded predicted message
    pred_msg = tok.decode(pred_msg_enc, v).flatten()
    
    orig_part = pred_msg[: len(pred_msg) // 2]
    modif_part = pred_msg[len(pred_msg) // 2:]

    # new order: no modification values present (all NA)
    # should be new limit order (1) or execution (4)
    if onp.isnan(modif_part).all():
        # convert relative to absolute price
        price = sim.get_best_bid() + int(orig_part[3]) * tick_size
        order_dict = {
            'timestamp': str(orig_part[0] * 1e-9 + 9.5 * 3600),
            'type': int(orig_part[1]),
            'order_id': new_order_id, 
            'quantity': int(orig_part[2]),
            'price': price,
            'side': 'ask' if orig_part[4] == 0 else 'bid',  # TODO: should be 'buy' or 'sell'
            'trade_id': 0  # should be trader_id in future
        }
    # modification of existing order
    else:
        # original part is only needed to match to an order ID
        # find original msg index location in the sequence (if it exists)
        orig_enc = pred_msg_enc[: len(pred_msg_enc) // 2]
        orig_i = valh.find_orig_msg(orig_enc, m_seq)
        if orig_i is not None:
            # get order ID from raw data for simulator
            order_id = int(m.iloc[orig_i].order_id)
        else:
            # TODO: fuzzy match??
            #       or just assume ID unknown and match to price level only...
            #       ... in which case - which order in the queue?
            #       perhaps last one that matches size? (>=)
            order_id = -1  # TODO: use simulator initial ID?

        # convert relative to absolute price
        price = sim.get_best_bid() + int(modif_part[3]) * tick_size
        order_dict = {
            'timestamp': str(modif_part[0] * 1e-9 + 9.5 * 3600),
            'type': int(modif_part[1]),
            'order_id': order_id, 
            'quantity': int(modif_part[2]),
            'price': price,
            'side': 'ask' if modif_part[4] == 0 else 'bid',  # TODO: should be 'buy' or 'sell'
            'trade_id': 0  # should be trader_id in future
        }

    return order_dict

In [None]:
def msg_to_raw(msg, bid_price, tick_size):
    """Convert message to raw data format."""
    assert len(msg) == 5
    # time
    msg[0] = msg[0] * 1e-9 + 9.5 * 3600
    # price
    msg[3] = bid_price + int(msg[3]) * tick_size
    # direction
    msg[4] = msg[4] * 2 - 1
    return msg

In [None]:
# actual next message (not predicted and not part of seq)
m_enc[end_i]

array([    8,   402,   328,   183,   770,  1003,  1207, 11107, 11011,
       11109,     8,   403,     9,   666,   752,  1005,  1107, 11107,
       11010, 11109])

In [None]:
#raw_seq = m.iloc[end_i - n_messages: end_i].copy()
#raw_seq.drop('order_id', inplace=True, axis=1)

In [None]:
sim_msg = get_sim_msg(
    m_enc[end_i],
    m_seq,
    sim,
    tok,
    v,
    new_order_id=42, tick_size=100
)
sim_msg

{'timestamp': '39600.006663749',
 'type': 3,
 'order_id': 32429970,
 'quantity': 10,
 'price': Array(988200, dtype=int32),
 'side': 'ask',
 'trade_id': 0}

In [None]:
sim.process_order(sim_msg)

[       5       -1       10   988200        0 32429970    39600  6663749]
[[[[      100    988000         0 276499246     39599 322500767]
   [       -1        -1        -1        -1        -1        -1]
   [       -1        -1        -1        -1        -1        -1]
   ...
   [       -1        -1        -1        -1        -1        -1]
   [       -1        -1        -1        -1        -1        -1]
   [       -1        -1        -1        -1        -1        -1]]

  [[      100    988100         0 276492538     39599 264867549]
   [       82    988100         0 276511822     39599 685781583]
   [       -1        -1        -1        -1        -1        -1]
   ...
   [       -1        -1        -1        -1        -1        -1]
   [       -1        -1        -1        -1        -1        -1]
   [       -1        -1        -1        -1        -1        -1]]

  [[       81    988200         0 276491886     39599  26281485]
   [      100    988200         0 276500518     39599  33606933

(Array([[-1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1]], dtype=int32),
 Array([       5,       -1,       10,   988200,        0, 32429970,
           39600,  6663749], dtype=int32))

In [None]:
input[0].shape

(1, 10000, 11111)

In [None]:
input[1].shape

(1, 500, 41)

In [None]:
# TODO: refactor slightly and work in simulation step
#       and simulator matching orders

pred_n_messages = 1
valid_mask_array = valh.syntax_validation_matrix()
inf_seq = valh.pred_msg(
    start_seq,
    pred_n_messages,
    state,
    model,
    args.batchnorm,
    rng,
    valid_mask_array
)