In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
#os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".25"
import torch
torch.multiprocessing.set_start_method('spawn')

In [283]:
from functools import partial
from argparse import Namespace
from glob import glob

import jax
import jax.numpy as np
from jax import random
from jax.scipy.linalg import block_diag
from flax.training import checkpoints
import orbax

from lob.lob_seq_model import BatchLobPredModel
from lob.train_helpers import create_train_state, eval_step, prep_batch, cross_entropy_loss, compute_accuracy
from s5.ssm import init_S5SSM
from s5.ssm_init import make_DPLR_HiPPO
from s5.dataloading import make_data_loader
from lob.lobster_dataloader import LOBSTER_Dataset, LOBSTER

from encoding import Vocab

from validation_helpers import *

In [4]:
# necessary for checkpoints to be loaded

import nest_asyncio
nest_asyncio.apply()

In [5]:
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()

raw_restored = checkpoints.restore_checkpoint(
    '../checkpoints/',
    None,
    #step=11,
    orbax_checkpointer=orbax_checkpointer
)# , )

In [6]:
args = Namespace(**raw_restored['config'])

In [7]:
args

Namespace(C_init='trunc_standard_normal', USE_WANDB=True, activation_fn='half_glu1', batchnorm=True, bidirectional=False, blocks=8, bn_momentum=0.95, bsz=16, clip_eigs=False, conj_sym=True, cosine_anneal=True, d_model=32, dataset='lobster-prediction', dir_name='./data', discretization='zoh', dt_global=False, dt_max=0.1, dt_min=0.001, early_stop_patience=1000, epochs=100, jax_seed=1919, lr_factor=1, lr_min=0, lr_patience=1000000, masking='random', mode='pool', n_layers=3, opt_config='standard', p_dropout=0.0, prenorm=True, reduce_factor=1.0, ssm_lr_base=0.001, ssm_size_base=32, wandb_entity='peer-nagy', wandb_project='LOBS5', warmup_end=1, weight_decay=0.05)

In [8]:
args.bsz

16

In [9]:
args.d_model

32

In [10]:
v = Vocab()
n_classes = len(v)
print('n_classes', n_classes)

seq_len = 10000

n_classes 11111


In [11]:
v = Vocab()

In [12]:
v = Vocab()
n_classes = len(v)
in_dim = n_classes

ssm_size = args.ssm_size_base
ssm_lr = args.ssm_lr_base

# Set global learning rate lr (e.g. encoders, etc.) as function of ssm_lr
lr = args.lr_factor * ssm_lr

# determine the size of initial blocks
block_size = int(ssm_size / args.blocks)

key = random.PRNGKey(args.jax_seed)
init_rng, train_rng = random.split(key, num=2)

# Initialize state matrix A using approximation to HiPPO-LegS matrix
Lambda, _, B, V, B_orig = make_DPLR_HiPPO(block_size)

if args.conj_sym:
    block_size = block_size // 2
    ssm_size = ssm_size // 2

Lambda = Lambda[:block_size]
V = V[:, :block_size]
Vc = V.conj().T

# If initializing state matrix A as block-diagonal, put HiPPO approximation
# on each block
Lambda = (Lambda * np.ones((args.blocks, block_size))).ravel()
V = block_diag(*([V] * args.blocks))
Vinv = block_diag(*([Vc] * args.blocks))

print("Lambda.shape={}".format(Lambda.shape))
print("V.shape={}".format(V.shape))
print("Vinv.shape={}".format(Vinv.shape))

padded = False

ssm_init_fn = init_S5SSM(
    H=args.d_model,
    P=ssm_size,
    Lambda_re_init=Lambda.real,
    Lambda_im_init=Lambda.imag,
    V=V,
    Vinv=Vinv,
    C_init=args.C_init,
    discretization=args.discretization,
    dt_min=args.dt_min,
    dt_max=args.dt_max,
    conj_sym=args.conj_sym,
    clip_eigs=args.clip_eigs,
    bidirectional=args.bidirectional
)

model_cls = partial(
    BatchLobPredModel,
    ssm=ssm_init_fn,
    d_output=n_classes,
    d_model=args.d_model,
    n_layers=args.n_layers,
    padded=padded,
    activation=args.activation_fn,
    dropout=args.p_dropout,
    mode=args.mode,
    prenorm=args.prenorm,
    batchnorm=args.batchnorm,
    bn_momentum=args.bn_momentum,
)

state = create_train_state(
    model_cls,
    init_rng,
    padded,  # padded
    False,  # retrieval
    in_dim=in_dim,
    bsz=args.bsz,
    seq_len=seq_len,
    weight_decay=args.weight_decay,
    batchnorm=args.batchnorm,
    opt_config=args.opt_config,
    ssm_lr=ssm_lr,
    lr=lr,
    dt_global=args.dt_global
)

Lambda.shape=(16,)
V.shape=(32, 16)
Vinv.shape=(16, 32)
configuring standard optimization setup
[*] Trainable Parameters: 731991


In [496]:
ckpt = {
    'model': state,
    'config': raw_restored['config'],
    'metrics': {
        'loss_train': np.nan,
        'loss_val': np.nan,
        'loss_test': np.nan,
        'acc_val': np.nan,
        'acc_test': np.nan,
    }
}

In [497]:
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()

restored = checkpoints.restore_checkpoint(
    '../checkpoints/',
    ckpt,
    #step=11,
    orbax_checkpointer=orbax_checkpointer
)

In [498]:
restored

{'config': {'C_init': 'trunc_standard_normal',
  'USE_WANDB': True,
  'activation_fn': 'half_glu1',
  'batchnorm': True,
  'bidirectional': False,
  'blocks': 8,
  'bn_momentum': 0.95,
  'bsz': 16,
  'clip_eigs': False,
  'conj_sym': True,
  'cosine_anneal': True,
  'd_model': 32,
  'dataset': 'lobster-prediction',
  'dir_name': './data',
  'discretization': 'zoh',
  'dt_global': False,
  'dt_max': 0.1,
  'dt_min': 0.001,
  'early_stop_patience': 1000,
  'epochs': 100,
  'jax_seed': 1919,
  'lr_factor': 1,
  'lr_min': 0,
  'lr_patience': 1000000,
  'masking': 'random',
  'mode': 'pool',
  'n_layers': 6,
  'opt_config': 'standard',
  'p_dropout': 0.0,
  'prenorm': True,
  'reduce_factor': 1.0,
  'ssm_lr_base': 0.001,
  'ssm_size_base': 32,
  'wandb_entity': 'peer-nagy',
  'wandb_project': 'LOBS5',
  'warmup_end': 1,
  'weight_decay': 0.05},
 'metrics': {'acc_test': array(0.6473482, dtype=float32),
  'acc_val': array(0.6329114, dtype=float32),
  'loss_test': array(2.245343, dtype=float32

In [16]:
#data_dir = '/nfs/home/peern/LOBS5/data/'
#messages = sorted(glob(data_dir + '*message*.npy'))

#d = LOBSTER_Dataset(
#    message_files=messages,
#    n_messages=500,
#    n_buffer_files=2,
#    mask_fn=LOBSTER_Dataset.random_mask,
#    randomize_offset=True,
#    seed=42
#)

In [499]:
dataset_obj = LOBSTER(
    'lobster',
    data_dir='/nfs/home/peern/LOBS5/data/',
    mask_fn=LOBSTER_Dataset.random_mask
)
dataset_obj.setup()

In [500]:
val_loader = make_data_loader(
    dataset_obj.dataset_val,
    dataset_obj,
    seed=args.jax_seed,
    batch_size=args.bsz,
    drop_last=True,
    shuffle=False,
    num_workers=0
)

In [501]:
model = model_cls(training=False, step_rescale=False)

In [33]:
for batch_idx, batch in enumerate(val_loader):
    #print(batch)
    
    inputs, labels, integration_timesteps = prep_batch(batch, seq_len, in_dim)
    #loss, acc, pred = eval_step(
    #    inputs, labels, integration_timesteps, state, model, args.batchnorm)
    
    print(len(batch[0][batch[0] == 0]))

    break

104


In [21]:
for batch_idx, batch in enumerate(val_loader):
    print(batch)
    
    inputs, labels, integration_timesteps = prep_batch(batch, seq_len, in_dim)
    loss, acc, pred = eval_step(
        inputs, labels, integration_timesteps, state, model, args.batchnorm)
    
    break

(Array([[  3,  37, 490, ...,   2,   2,   2],
       [  3,  52, 866, ...,   2,   2,   2],
       [  3,  73, 181, ...,   2,   2,   2],
       ...,
       [  3, 145, 680, ...,   2,   2,   2],
       [  3, 158,  36, ...,   2,   2,   2],
       [  3, 195, 533, ...,   2,   2,   2]], dtype=int32), Array([[    2],
       [   54],
       [    2],
       [    2],
       [ 1005],
       [    2],
       [11109],
       [    2]], dtype=int32), {})


In [22]:
labels.shape

(8,)

In [23]:
pred.shape

(8, 11111)

In [24]:
cross_entropy_loss(pred, labels)

Array([9.319307, 9.314024, 9.31922 , 9.318898, 9.315664, 9.318539,
       9.310003, 9.31841 ], dtype=float32)

In [25]:
compute_accuracy(pred, labels)

Array([False, False, False, False, False, False, False, False], dtype=bool)

In [26]:
loss

Array([9.319307, 9.314024, 9.31922 , 9.318898, 9.315664, 9.318539,
       9.310003, 9.31841 ], dtype=float32)

In [27]:
pred.shape

(8, 11111)

In [28]:
pred

Array([[-9.307722 , -9.315114 , -9.319307 , ..., -9.315296 , -9.309821 ,
        -9.314997 ],
       [-9.309109 , -9.315402 , -9.318858 , ..., -9.31534  , -9.310311 ,
        -9.315259 ],
       [-9.308548 , -9.314764 , -9.31922  , ..., -9.315253 , -9.31021  ,
        -9.3143635],
       ...,
       [-9.309304 , -9.315951 , -9.318539 , ..., -9.3156185, -9.310347 ,
        -9.315128 ],
       [-9.308162 , -9.315214 , -9.319619 , ..., -9.315283 , -9.310003 ,
        -9.315269 ],
       [-9.3088455, -9.314688 , -9.31841  , ..., -9.315585 , -9.310222 ,
        -9.314322 ]], dtype=float32)

In [29]:
# TODO: map to vocab 
#       get distribution (of n biggest) over vocab

In [30]:
pred.argmax(axis=-1)

Array([ 260, 5392,  260, 7353, 1401,  633,  260,  260], dtype=int32)

In [41]:
# TODO: pay attention to the field (location) in addition to the label
#       because generic labels can be used in different fields

for pred_tok, label in zip(pred.argmax(axis=-1).tolist(), labels.tolist()):
    field, label_dec = v.DECODING_GLOBAL[label]
    print(field)
    print(v.ENCODING[field])

generic
{'MSK': 0, 'HID': 1, 'NAN': 2}
time
{'000': 3, '001': 4, '002': 5, '003': 6, '004': 7, '005': 8, '006': 9, '007': 10, '008': 11, '009': 12, '010': 13, '011': 14, '012': 15, '013': 16, '014': 17, '015': 18, '016': 19, '017': 20, '018': 21, '019': 22, '020': 23, '021': 24, '022': 25, '023': 26, '024': 27, '025': 28, '026': 29, '027': 30, '028': 31, '029': 32, '030': 33, '031': 34, '032': 35, '033': 36, '034': 37, '035': 38, '036': 39, '037': 40, '038': 41, '039': 42, '040': 43, '041': 44, '042': 45, '043': 46, '044': 47, '045': 48, '046': 49, '047': 50, '048': 51, '049': 52, '050': 53, '051': 54, '052': 55, '053': 56, '054': 57, '055': 58, '056': 59, '057': 60, '058': 61, '059': 62, '060': 63, '061': 64, '062': 65, '063': 66, '064': 67, '065': 68, '066': 69, '067': 70, '068': 71, '069': 72, '070': 73, '071': 74, '072': 75, '073': 76, '074': 77, '075': 78, '076': 79, '077': 80, '078': 81, '079': 82, '080': 83, '081': 84, '082': 85, '083': 86, '084': 87, '085': 88, '086': 89, '087'

In [258]:
rng = jax.random.PRNGKey(42)

In [285]:
rng

Array([1158584513, 1255931741], dtype=uint32)

In [None]:
sample_pred(pred, 5, jax.random.split(rng, pred.shape[0]))

Array([1616, 1039, 1616, 1616, 3101, 6865, 1039, 1616, 1039,  986, 5077,
       6524, 5606, 1616, 5606, 5606], dtype=int32)

In [277]:
rng_ = jax.random.split(rng, batch[0].shape[0])
result = fill_predicted_toks(batch[0], pred, top_n=5, rng=rng_)
tok.invalid_toks_per_seq(result, v)
rng = rng_[-1]

array([1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [278]:
result2 = fill_predicted_toks(batch[0], pred)
tok.invalid_toks_per_seq(result2, v)

array([1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [132]:
tok.invalid_toks_per_msg(result, v).sum(axis=-1)

array([1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [123]:
pred.shape

(16, 11111)

In [95]:
batch[0][1][batch[0][1] == v.MASK_TOK]

Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)

In [96]:
batch[0].shape

(8, 10000)

In [97]:
# check error: are there multiple mask tokens?

len(batch[0][batch[0] == 0])

104

In [328]:
get_masked_idx(batch[0])

Array([[  0, 342,  17],
       [  0, 475,  14],
       [  1, 124,   1],
       [  1, 320,   1],
       [  2, 140,  10],
       [  2, 460,  19],
       [  3, 375,  17],
       [  3, 480,   6],
       [  4,   2,  15],
       [  4, 202,   9],
       [  5,  85,  19],
       [  5, 183,  18],
       [  6,  82,   9],
       [  6, 112,   4],
       [  7,  15,   8],
       [  7, 289,  12]], dtype=int32)

In [321]:
get_masked_fields(batch[0])

['price_new',
 'time_new',
 'time',
 'time',
 'time_new',
 'direction_new',
 'price_new',
 'size',
 'event_type_new',
 'direction',
 'direction_new',
 'price_new',
 'direction',
 'time',
 'price',
 'time_new']

In [327]:
get_masked_fields(batch[0])

['price_new',
 'time_new',
 'time',
 'time',
 'time_new',
 'direction_new',
 'price_new',
 'size',
 'event_type_new',
 'direction',
 'direction_new',
 'price_new',
 'direction',
 'time',
 'price',
 'time_new']

In [324]:
batch[1].shape

(8, 1)

In [492]:
def is_tok_valid(tok, field, vocab):
    tok = tok.tolist()
    if isinstance(field, str):
        return tok in vocab.DECODING[Message_Tokenizer.FIELD_ENC_TYPES[field]]
    else:
        return [t in vocab.DECODING[Message_Tokenizer.FIELD_ENC_TYPES[f]] 
                for t, f in zip(tok, field)]

In [491]:
is_tok_valid(
    pred.argmax(axis=-1),
    get_masked_fields(batch[0]),
    v
)

[False,
 False,
 False,
 False,
 False,
 False,
 True,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False]

In [503]:
rng = jax.random.PRNGKey(42)
tok = Message_Tokenizer()

ranks = []
valid_mass = []
valid_mass_n5 = []
valid_pred = []

for batch_idx, batch in enumerate(val_loader):
    #print(batch)
    
    inputs, labels, integration_timesteps = prep_batch(batch, seq_len, in_dim)
    loss, acc, pred = eval_step(
        inputs, labels, integration_timesteps, state, model, args.batchnorm)
    
    # where does the correct label rank in the predicted distribution?
    ranks.append(pred_rank(pred, labels))
    # how much of the predicted distribution is valid?
    masked_fields = get_masked_fields(batch[0])
    valid_mass.append(valid_prediction_mass(pred, masked_fields))
    valid_mass_n5.append(valid_prediction_mass(pred, masked_fields, top_n=5))

    # check if argmax prediction is valid token for masked fields
    valid_pred.append(is_tok_valid(pred.argmax(axis=-1), masked_fields, v))

In [105]:
np.array(ranks).mean()

Array(6302.8657, dtype=float32)

In [106]:
np.array(valid_mass).mean()

Array(0.13728166, dtype=float32)

In [None]:
np.array(valid_mass_n5).mean()

Array(0.28699377, dtype=float32)

In [504]:
np.array(ranks).mean()

Array(5913.6465, dtype=float32)

In [508]:
np.median(np.array(ranks))

Array(5504., dtype=float32)

In [505]:
np.array(valid_mass).mean()

Array(0.13728018, dtype=float32)

In [506]:
np.array(valid_mass_n5).mean()

Array(0.16054794, dtype=float32)

In [133]:
tok = Message_Tokenizer()

In [113]:
tok.col_idx_by_encoder

{'time': [0, 1, 2, 3, 4, 10, 11, 12, 13, 14],
 'event_type': [5, 15],
 'size': [6, 16],
 'price': [7, 8, 17, 18],
 'direction': [9, 19]}

Array([[-9.314257 , -9.318846 , -9.314268 , ..., -9.317285 , -9.313043 ,
        -9.319928 ],
       [-9.314671 , -9.319134 , -9.314537 , ..., -9.31646  , -9.313398 ,
        -9.321061 ],
       [-9.313926 , -9.3199835, -9.313149 , ..., -9.316932 , -9.312886 ,
        -9.320968 ],
       ...,
       [-9.314899 , -9.318279 , -9.316247 , ..., -9.3185425, -9.314482 ,
        -9.319631 ],
       [-9.314888 , -9.318323 , -9.3155575, ..., -9.317863 , -9.313861 ,
        -9.319694 ],
       [-9.315161 , -9.318303 , -9.316559 , ..., -9.317473 , -9.314913 ,
        -9.318973 ]], dtype=float32)

In [118]:
tok.invalid_toks_per_msg(batch[0], v).mean()

0.0

In [122]:
pred.argmax(axis=-1)

Array([6332, 6332, 6332, 3101, 6865, 1616, 6332, 3101, 3101, 3101, 5606,
       5606, 5606, 5606, 5606, 5606], dtype=int32)

In [419]:
pred.shape

(16, 11111)

In [418]:
tok.invalid_toks_per_msg(, v).mean()

ValueError: cannot reshape array of size 177776 into shape (16,newaxis,20)

In [417]:
tok._validate_syntax(batch[0][0], v)

(True,
 array([['022', '674', '807', ..., 'NAN', 'NAN', 'NAN'],
        ['022', '674', '807', ..., 'NAN', 'NAN', 'NAN'],
        ['022', '674', '807', ..., '+', '01', '0'],
        ...,
        ['022', '681', '010', ..., '-', '00', '1'],
        ['022', '681', '010', ..., '-', '00', '1'],
        ['022', '681', '010', ..., '-', '00', '1']], dtype='<U3'))

In [228]:
#valid_prediction_mass(pred[0:2], get_masked_fields(batch[0][0:2]))
valid_prediction_mass(pred, get_masked_fields(batch[0]))

Array([0.09027135, 0.09027305, 0.09027165, 0.00945231, 0.00063061,
       0.00945135, 0.00045093, 0.09027136], dtype=float32)

In [281]:
# probability of syntactically valid tokens in the top_n predictions
print('1:')
print(valid_prediction_mass(pred, get_masked_fields(batch[0]), top_n=1))

print('5:')
print(valid_prediction_mass(pred, get_masked_fields(batch[0]), top_n=5))

print('10:')
print(valid_prediction_mass(pred, get_masked_fields(batch[0]), top_n=10))

1:
[1. 0. 1. 0. 0. 0. 0. 1.]
5:
[0.40015867 0.2000537  0.200092   0.         0.         0.
 0.1999747  0.6000089 ]
10:
[0.30012807 0.20000866 0.30003056 0.         0.         0.
 0.1000105  0.30007356]


In [33]:
# for a batch, print predicted tokens (and vocab) and actually correct labels

for pred_tok, label in zip(pred.argmax(axis=-1).tolist(), labels.tolist()):
    print('pred:')
    print(pred_tok)
    print(v.DECODING_GLOBAL[pred_tok])
    print('label', label)
    print(v.DECODING_GLOBAL[label])
    print()

pred:
260
('time', '257')
label 2.0
('generic', 'NAN')

pred:
5392
('size', '4385')
label 54.0
('time', '051')

pred:
260
('time', '257')
label 2.0
('generic', 'NAN')

pred:
7353
('size', '6346')
label 2.0
('generic', 'NAN')

pred:
1401
('size', '0394')
label 1005.0
('event_type', '3')

pred:
633
('time', '630')
label 2.0
('generic', 'NAN')

pred:
260
('time', '257')
label 11109.0
('direction', '0')

pred:
260
('time', '257')
label 2.0
('generic', 'NAN')

