In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
#os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".25"
import torch
torch.multiprocessing.set_start_method('spawn')

In [3]:
from functools import partial
from argparse import Namespace
from glob import glob
import numpy as onp

import jax
import jax.numpy as np
from jax import random
from jax.scipy.linalg import block_diag
from flax.training import checkpoints
import orbax

from lob.lob_seq_model import BatchLobPredModel
from lob.train_helpers import create_train_state, eval_step, prep_batch, cross_entropy_loss, compute_accuracy
from s5.ssm import init_S5SSM
from s5.ssm_init import make_DPLR_HiPPO
from s5.dataloading import make_data_loader
from lob.lobster_dataloader import LOBSTER_Dataset, LOBSTER

from encoding import Vocab, Message_Tokenizer

#from validation_helpers import *
import validation_helpers as valh

In [4]:
# necessary for checkpoints to be loaded

import nest_asyncio
nest_asyncio.apply()

In [5]:
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()

raw_restored = checkpoints.restore_checkpoint(
    '../checkpoints_causal_seq/',
    None,
    #step=11,
    orbax_checkpointer=orbax_checkpointer
)# , )

In [6]:
args = Namespace(**raw_restored['config'])

In [7]:
args

Namespace(C_init='trunc_standard_normal', USE_WANDB=True, activation_fn='half_glu1', batchnorm=True, bidirectional=False, blocks=8, bn_momentum=0.95, bsz=16, clip_eigs=False, conj_sym=True, cosine_anneal=True, d_model=32, dataset='lobster-prediction', dir_name='./data', discretization='zoh', dt_global=False, dt_max=0.1, dt_min=0.001, early_stop_patience=1000, epochs=100, jax_seed=1919, lr_factor=1, lr_min=0, lr_patience=1000000, masking='causal', mode='pool', n_layers=6, opt_config='standard', p_dropout=0.0, prenorm=True, reduce_factor=1.0, ssm_lr_base=0.001, ssm_size_base=32, wandb_entity='peer-nagy', wandb_project='LOBS5', warmup_end=1, weight_decay=0.05)

In [8]:
args.bsz

16

In [9]:
args.d_model

32

In [10]:
v = Vocab()
n_classes = len(v)
print('n_classes', n_classes)

seq_len = 10000

n_classes 11111


In [11]:
v = Vocab()
n_classes = len(v)
in_dim = n_classes

ssm_size = args.ssm_size_base
ssm_lr = args.ssm_lr_base

# Set global learning rate lr (e.g. encoders, etc.) as function of ssm_lr
lr = args.lr_factor * ssm_lr

# determine the size of initial blocks
block_size = int(ssm_size / args.blocks)

key = random.PRNGKey(args.jax_seed)
init_rng, train_rng = random.split(key, num=2)

# Initialize state matrix A using approximation to HiPPO-LegS matrix
Lambda, _, B, V, B_orig = make_DPLR_HiPPO(block_size)

if args.conj_sym:
    block_size = block_size // 2
    ssm_size = ssm_size // 2

Lambda = Lambda[:block_size]
V = V[:, :block_size]
Vc = V.conj().T

# If initializing state matrix A as block-diagonal, put HiPPO approximation
# on each block
Lambda = (Lambda * np.ones((args.blocks, block_size))).ravel()
V = block_diag(*([V] * args.blocks))
Vinv = block_diag(*([Vc] * args.blocks))

print("Lambda.shape={}".format(Lambda.shape))
print("V.shape={}".format(V.shape))
print("Vinv.shape={}".format(Vinv.shape))

padded = False

ssm_init_fn = init_S5SSM(
    H=args.d_model,
    P=ssm_size,
    Lambda_re_init=Lambda.real,
    Lambda_im_init=Lambda.imag,
    V=V,
    Vinv=Vinv,
    C_init=args.C_init,
    discretization=args.discretization,
    dt_min=args.dt_min,
    dt_max=args.dt_max,
    conj_sym=args.conj_sym,
    clip_eigs=args.clip_eigs,
    bidirectional=args.bidirectional
)

model_cls = partial(
    BatchLobPredModel,
    ssm=ssm_init_fn,
    d_output=n_classes,
    d_model=args.d_model,
    n_layers=args.n_layers,
    padded=padded,
    activation=args.activation_fn,
    dropout=args.p_dropout,
    mode=args.mode,
    prenorm=args.prenorm,
    batchnorm=args.batchnorm,
    bn_momentum=args.bn_momentum,
)

state = create_train_state(
    model_cls,
    init_rng,
    padded,  # padded
    False,  # retrieval
    in_dim=in_dim,
    bsz=args.bsz,
    seq_len=seq_len,
    weight_decay=args.weight_decay,
    batchnorm=args.batchnorm,
    opt_config=args.opt_config,
    ssm_lr=ssm_lr,
    lr=lr,
    dt_global=args.dt_global
)

Lambda.shape=(16,)
V.shape=(32, 16)
Vinv.shape=(16, 32)


2023-04-05 11:58:48.988201: W external/org_tensorflow/tensorflow/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 6.62GiB (rounded to 7111040000)requested by op 
2023-04-05 11:58:48.988326: W external/org_tensorflow/tensorflow/tsl/framework/bfc_allocator.cc:497] *___________________________________________________________________________________________________
2023-04-05 11:58:48.988381: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2389] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 7111040000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:         4B
              constant allocation:         0B
        maybe_live_out allocation:    6.62GiB
     preallocated temp allocation:         0B
                 total allocation:    6.62GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 7111040000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:         4B
              constant allocation:         0B
        maybe_live_out allocation:    6.62GiB
     preallocated temp allocation:         0B
                 total allocation:    6.62GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 6.62GiB
		Operator: op_name="jit(broadcast_in_dim)/jit(main)/broadcast_in_dim[shape=(16, 10000, 11111) broadcast_dimensions=()]" source_file="/nfs/home/peern/LOBS5/lob/train_helpers.py" source_line=128
		XLA Label: broadcast
		Shape: f32[16,10000,11111]
		==========================

	Buffer 2:
		Size: 4B
		Entry Parameter Subshape: f32[]
		==========================



In [None]:
ckpt = {
    'model': state,
    'config': raw_restored['config'],
    'metrics': {
        'loss_train': np.nan,
        'loss_val': np.nan,
        'loss_test': np.nan,
        'acc_val': np.nan,
        'acc_test': np.nan,
    }
}

In [None]:
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()

restored = checkpoints.restore_checkpoint(
    '../checkpoints_causal_seq/',
    ckpt,
    #step=11,
    orbax_checkpointer=orbax_checkpointer
)

In [None]:
restored

{'config': {'C_init': 'trunc_standard_normal',
  'USE_WANDB': True,
  'activation_fn': 'half_glu1',
  'batchnorm': True,
  'bidirectional': False,
  'blocks': 8,
  'bn_momentum': 0.95,
  'bsz': 16,
  'clip_eigs': False,
  'conj_sym': True,
  'cosine_anneal': True,
  'd_model': 32,
  'dataset': 'lobster-prediction',
  'dir_name': './data',
  'discretization': 'zoh',
  'dt_global': False,
  'dt_max': 0.1,
  'dt_min': 0.001,
  'early_stop_patience': 1000,
  'epochs': 100,
  'jax_seed': 1919,
  'lr_factor': 1,
  'lr_min': 0,
  'lr_patience': 1000000,
  'masking': 'causal',
  'mode': 'pool',
  'n_layers': 6,
  'opt_config': 'standard',
  'p_dropout': 0.0,
  'prenorm': True,
  'reduce_factor': 1.0,
  'ssm_lr_base': 0.001,
  'ssm_size_base': 32,
  'wandb_entity': 'peer-nagy',
  'wandb_project': 'LOBS5',
  'warmup_end': 1,
  'weight_decay': 0.05},
 'metrics': {'acc_test': array(0.65965, dtype=float32),
  'acc_val': array(0.6576391, dtype=float32),
  'loss_test': array(1.6603074, dtype=float32)

In [None]:
state = restored['model']

In [None]:
#data_dir = '/nfs/home/peern/LOBS5/data/'
#messages = sorted(glob(data_dir + '*message*.npy'))

#d = LOBSTER_Dataset(
#    message_files=messages,
#    n_messages=500,
#    n_buffer_files=2,
#    mask_fn=LOBSTER_Dataset.random_mask,
#    randomize_offset=True,
#    seed=42
#)

In [11]:
dataset_obj = LOBSTER(
    'lobster',
    data_dir='/nfs/home/peern/LOBS5/data/',
    mask_fn=LOBSTER_Dataset.random_mask
)
dataset_obj.setup()

In [12]:
val_loader = make_data_loader(
    dataset_obj.dataset_val,
    dataset_obj,
    seed=args.jax_seed,
    batch_size=args.bsz,
    drop_last=True,
    shuffle=False,
    num_workers=0
)

In [14]:
model = model_cls(training=False, step_rescale=1.0)
#model.apply = restored['model'].apply_fn

In [13]:
dataset_obj.dataset_train[0]

(array([    3,     3,   691, ..., 11108, 11007, 11110]), array([11109]))

In [14]:
type(dataset_obj.dataset_train[0][0])

numpy.ndarray

In [15]:
counter = 0
for a in val_loader:
    X = a[0]
    y = a[1]
    # print(type(a[0]))
    # print(len(a))
    # print(a[0]) # X
    # print(a[1]) # y
    counter += 1
    if counter > 1:
        break

In [16]:
X[0][0:20]

Array([    3,    28,   989,   729,   168,  1003,  1020, 11108, 11008,
       11110,     2,     2,     2,     2,     2,     2,     2,     2,
           2,     2], dtype=int32)

In [144]:
msg_str = tok.decode_to_str(
    X[0][0:20],
    v
)#.flatten()
msg_str

array([['000', '025', '986', '726', '165', '1', '001', '-', '01', '1',
        'NAN', 'NAN', 'NAN', 'NAN', 'NAN', 'NAN', 'NAN', 'NAN', 'NAN',
        'NAN']], dtype='<U3')

In [52]:
# modification message from end of sequence (take orig message from beginning)
t_msg = X[0][-40:-20][:10].copy()
t_seq = X[0].copy()

In [51]:
X[0][-40:-20][10:]

Array([    3,    30,   215,    81,   132,  1005,  1124, 11107, 11009,
       11109], dtype=int32)

In [49]:
t_msg

Array([    3,    30,   215,    81,   132,  1005,  1124, 11107, 11009,
       11109], dtype=int32)

In [44]:
t_seq[483 * l: 484*l]

Array([    3,    30,    84,   427,   450,  1003,  1124, 11107, 11010,
       11109,     2,     2,     2,     2,     2,     2,     2,     2,
           2,     2], dtype=int32)

In [None]:
t_seq[498 * l: 499*l]

Array([    3,    30,    84,   427,   450,  1003,  1124, 11107, 11010,
       11109,     3,    30,   215,    81,   132,  1005,  1124, 11107,
       11009, 11109], dtype=int32)

In [61]:
valh.find_orig_msg(t_msg, t_seq)

# "orig msg" might also just be a modification 

Array(483, dtype=int32)

In [25]:
tok.decode(
    X[0][-40:-20],
    v
)

array([[2.70814244e+10, 1.00000000e+00, 1.10000000e+01, 3.00000000e+00,
        0.00000000e+00, 2.72120781e+10, 3.00000000e+00, 1.10000000e+01,
        2.00000000e+00, 0.00000000e+00]])

In [107]:
# form test msg to check possible errors
m_i = 50
l = 20
test = X[100, m_i*l:(m_i+1)*l]#.at[5].set(1004)

In [121]:
y[:,0].shape

(16,)

In [119]:
X[X == v.MASK_TOK].shape

(16,)

In [18]:
tok = Message_Tokenizer()

In [22]:
test_seq[0, 19*l:(19+1)*l]

Array([    3,    28,   990,   391,   114,  1006,  1038, 11108, 11007,
       11110,     2,     2,     2,     2,     2,     2,     2,     2,
           2,     2], dtype=int32)

In [35]:
test_seq = X.at[X == v.MASK_TOK].set(y[:,0])

l = 20
batch_i = 0
for m_i in range(0, 500):
    test_msg = test_seq[batch_i, m_i*l:(m_i+1)*l]#.at[5].set(1004)
    #assert test_msg[5] == 1003, m_i
    valh.validate_msg(test_msg, tok, v)

In [108]:
tok.decode_to_str(test, v)

array([['000', '051', '583', '924', '373', '1', '010', '-', '03', '1',
        '000', '052', '896', '268', '647', '3', '010', '-', '04', '1']],
      dtype='<U3')

In [109]:
test

Array([    3,    54,   586,   927,   376,  1003,  1107, 11108, 11010,
       11110,     3,    55,   899,   271,   650,  1005,  1107, 11108,
       11011, 11110], dtype=int32)

In [110]:
valh.validate_msg(test, tok, v)

event_type [1003]


True

In [41]:
# TODO: pay attention to the field (location) in addition to the label
#       because generic labels can be used in different fields

for pred_tok, label in zip(pred.argmax(axis=-1).tolist(), labels.tolist()):
    field, label_dec = v.DECODING_GLOBAL[label]
    print(field)
    print(v.ENCODING[field])

generic
{'MSK': 0, 'HID': 1, 'NAN': 2}
time
{'000': 3, '001': 4, '002': 5, '003': 6, '004': 7, '005': 8, '006': 9, '007': 10, '008': 11, '009': 12, '010': 13, '011': 14, '012': 15, '013': 16, '014': 17, '015': 18, '016': 19, '017': 20, '018': 21, '019': 22, '020': 23, '021': 24, '022': 25, '023': 26, '024': 27, '025': 28, '026': 29, '027': 30, '028': 31, '029': 32, '030': 33, '031': 34, '032': 35, '033': 36, '034': 37, '035': 38, '036': 39, '037': 40, '038': 41, '039': 42, '040': 43, '041': 44, '042': 45, '043': 46, '044': 47, '045': 48, '046': 49, '047': 50, '048': 51, '049': 52, '050': 53, '051': 54, '052': 55, '053': 56, '054': 57, '055': 58, '056': 59, '057': 60, '058': 61, '059': 62, '060': 63, '061': 64, '062': 65, '063': 66, '064': 67, '065': 68, '066': 69, '067': 70, '068': 71, '069': 72, '070': 73, '071': 74, '072': 75, '073': 76, '074': 77, '075': 78, '076': 79, '077': 80, '078': 81, '079': 82, '080': 83, '081': 84, '082': 85, '083': 86, '084': 87, '085': 88, '086': 89, '087'

In [19]:
rng = jax.random.PRNGKey(42)

In [20]:
rng

Array([ 0, 42], dtype=uint32)

In [None]:
sample_pred(pred, 5, jax.random.split(rng, pred.shape[0]))

Array([1616, 1039, 1616, 1616, 3101, 6865, 1039, 1616, 1039,  986, 5077,
       6524, 5606, 1616, 5606, 5606], dtype=int32)

In [277]:
rng_ = jax.random.split(rng, batch[0].shape[0])
result = fill_predicted_toks(batch[0], pred, top_n=5, rng=rng_)
tok.invalid_toks_per_seq(result, v)
rng = rng_[-1]

array([1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [278]:
result2 = fill_predicted_toks(batch[0], pred)
tok.invalid_toks_per_seq(result2, v)

array([1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [132]:
tok.invalid_toks_per_msg(result, v).sum(axis=-1)

array([1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [123]:
pred.shape

(16, 11111)

In [95]:
batch[0][1][batch[0][1] == v.MASK_TOK]

Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)

In [96]:
batch[0].shape

(8, 10000)

In [19]:
def is_tok_valid(tok, field, vocab):
    tok = tok.tolist()
    if isinstance(field, str):
        return tok in vocab.DECODING[Message_Tokenizer.FIELD_ENC_TYPES[field]]
    else:
        return [t in vocab.DECODING[Message_Tokenizer.FIELD_ENC_TYPES[f]] 
                for t, f in zip(tok, field)]

In [491]:
is_tok_valid(
    pred.argmax(axis=-1),
    get_masked_fields(batch[0]),
    v
)

[False,
 False,
 False,
 False,
 False,
 False,
 True,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False]

In [None]:
import jax.numpy as jnp

def syntax_validation_matrix():
    """ Create a matrix of shape (MSG_LEN, VOCAB_SIZE) where a
        True value indicates that the token is valid for the location
        in the message.
    """
    v = Vocab()

    idx = []
    for i in range(Message_Tokenizer.MSG_LEN):
        field = get_field_from_idx(i)
        decoder_key = Message_Tokenizer.FIELD_ENC_TYPES[field[0]]
        for tok, val in v.DECODING[decoder_key].items():
            idx.append([i, tok])
    idx = tuple(jnp.array(idx).T)
    mask = jnp.zeros((Message_Tokenizer.MSG_LEN, len(v)), dtype=bool)
    mask = mask.at[idx].set(True)

    # adjustments for special tokens (no MSK, NAN, HID) allowed
    mask = mask.at[:, v.MASK_TOK].set(False)
    mask = mask.at[:, v.HIDDEN_TOK].set(False)

    # adjustment for positions only allowing subset of field
    # e.g. +/- at start of price
    enc_type = 'price'
    allowed_toks = jnp.array([v.ENCODING[enc_type]['+'], v.ENCODING[enc_type]['-']])
    adj_col = np.zeros((mask.shape[1],), dtype=bool).at[allowed_toks].set(True)
    # TODO: remove hardcoding and make this more general
    mask = mask.at[(7, 17), :].set(adj_col)
    return mask

In [23]:
Message_Tokenizer()

NameError: name 'Message_Tokenizer' is not defined

In [13]:
rng = jax.random.PRNGKey(42)
tok = Message_Tokenizer()

all_pred_toks = []
all_labels = []

losses = []
accuracy = []
ranks = []
valid_mass = []
valid_mass_n5 = []
valid_pred = []
losses_baseline = []

VALID_MATRIX = Message_Tokenizer.syntax_validation_matrix()

for batch_idx, batch in enumerate(val_loader):
    
    # PREPARE BATCH
    inputs, labels, integration_timesteps = prep_batch(batch, seq_len, in_dim)
    # INFERENCE STEP
    loss, acc, pred = eval_step(
        inputs, labels, integration_timesteps, state, model, args.batchnorm)
    
    # STORE RESULTS
    pred_toks = pred.argmax(axis=-1)
    all_labels += labels.tolist()
    all_pred_toks += pred_toks.tolist()
    
    # STATS
    losses.append(cross_entropy_loss(pred, labels))
    accuracy.append(compute_accuracy(pred, labels))
    
    # where does the correct label rank in the predicted distribution?
    ranks.append(valh.pred_rank(pred, labels))
    # how much of the predicted distribution is valid?
    masked_fields = valh.get_masked_fields(batch[0])
    valid_mass.append(valh.valid_prediction_mass(pred, masked_fields))
    valid_mass_n5.append(valh.valid_prediction_mass(pred, masked_fields, top_n=5))

    # check if argmax prediction is valid token for masked fields
    valid_pred.append(valh.is_tok_valid(pred_toks, masked_fields, v))

    # benchmark: uniform prediction over syntactically valid tokens
    pos = valh.get_masked_idx(batch[0])[..., -1]
    baseline_distr = VALID_MATRIX[pos] / VALID_MATRIX[pos].sum(axis=-1, keepdims=True)
    losses_baseline.append(cross_entropy_loss(np.log(
            np.where(baseline_distr==0, 1e-10, baseline_distr)
        ), labels)
    )

all_labels = np.array(all_labels)
all_pred_toks = np.array(all_pred_toks)
losses = np.array(losses)
accuracy = np.array(accuracy)
ranks = np.array(ranks)
valid_mass = np.array(valid_mass)
valid_mass_n5 = np.array(valid_mass_n5)
valid_pred = np.array(valid_pred)
losses_baseline = np.array(losses_baseline)

NameError: name 'val_loader' is not defined

In [30]:
print('mean loss', losses.mean())
print('mean accuracy', accuracy.mean())
print('mean rank', ranks.mean())
print('median rank', np.median(ranks))
print('mean valid mass', valid_mass.mean())
print('mean valid mass (top 5)', valid_mass_n5.mean())
print('mean valid prediction', valid_pred.mean())
print('mean baseline loss (uniform over valid syntax)', losses_baseline.mean())

mean loss 26.86415
mean accuracy 0.27653304
mean rank 756.5858
median rank 122.5
mean valid mass 0.56203675
mean valid mass (top 5) 0.5690942
mean valid prediction 0.6020047
mean baseline loss (uniform over valid syntax) 9.793726


In [64]:
import sklearn
import numpy as onp
import pandas as pd


precision, recall, fscore, support = sklearn.metrics.precision_recall_fscore_support(
    all_labels.astype(int),
    all_pred_toks,
    labels=range(len(v)),
    zero_division=0,
    average=None
)

'''
print('precision: {}'.format(precision))
print('recall: {}'.format(recall))
print('fscore: {}'.format(fscore))
print('support: {}'.format(support))
'''

"\nprint('precision: {}'.format(precision))\nprint('recall: {}'.format(recall))\nprint('fscore: {}'.format(fscore))\nprint('support: {}'.format(support))\n"

In [93]:
field_dec = onp.array([(field, dec) for tok, (field, dec) in sorted(v.DECODING_GLOBAL.items())])

scores_df = pd.DataFrame({
    'field': field_dec[:, 0],
    'decoded': field_dec[:, 1],
    'precision': precision,
    'recall': recall,
    'fscore': fscore,
    'support': support,
})
#scores_df

In [97]:
scores_df.loc[scores_df.support > 0].groupby('field').agg(
    precision=('precision', 'mean'),
    recall=('recall', 'mean'),
    fscore=('fscore', 'mean'),
    support=('support', 'sum'),
)

Unnamed: 0_level_0,precision,recall,fscore,support
field,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
direction,0.983327,0.986676,0.984959,231
event_type,0.744489,0.620784,0.661801,282
generic,0.992991,1.0,0.996483,850
price,0.378611,0.428705,0.380395,493
size,0.075944,0.093591,0.079177,257
time,0.021673,0.035929,0.025314,1279


In [None]:
# TODO: longer validation by masking every index seperately
#       and using different sequence offsets
#       this way we can predict every token in the data set over time

In [31]:
seq = valh.fill_predicted_toks(batch[0][0], pred[0])

In [115]:
valh.append_hid_msg(seq)[-25:]

Array([ 1005,  1107, 11108, 11007, 11110,     1,     1,     1,     1,
           1,     1,     1,     1,     1,     1,     1,     1,     1,
           1,     1,     1,     1,     1,     1,     1], dtype=int32)

In [138]:
# TODO: loop through and increase i to AR. predict next token
valh.mask_last_msg_in_seq(seq, 1)

(Array([   25,   677,   810, ..., 11108, 11007, 11110], dtype=int32),
 Array(684, dtype=int32))

In [142]:
np.expand_dims(seq, axis=0)

Array([[   25,   677,   810, ..., 11108, 11007, 11110]], dtype=int32)

In [32]:
start_seq = seq

In [100]:
# TODO: try this for causal model, check (syntactic) validity of predictions

pred_n_messages = 10
valid_mask_array = tok.syntax_validation_matrix()
inf_seq = valh.pred_msg(
    start_seq,
    pred_n_messages,
    state,
    model,
    args.batchnorm,
    rng,
    valid_mask_array
)

In [102]:
l = Message_Tokenizer.MSG_LEN
tok_output = onp.array(inf_seq[-pred_n_messages*l:].reshape((-1,l)))
tok_output

array([[  617,    10,    10,    10,    10,  1005,  3207, 11108, 11046,
        11109,   337,    10,    10,    10,    10,  1004,  3990, 11107,
        11071, 11110],
       [  702,    10,    10,    10,    10,  1005,  4436, 11108, 11046,
        11109,   846,    10,    10,    10,    10,  1004,  3990, 11107,
        11071, 11110],
       [  702,    10,    10,    10,    10,  1005,  6239, 11108, 11081,
        11109,   846,    10,    10,    10,    10,  1004,  3474, 11108,
        11089, 11109],
       [  900,   398,    10,    10,    10,  1005,  6239, 11108, 11081,
        11109,   846,    10,    10,    10,    10,  1005,  3474, 11108,
        11071, 11109],
       [  900,   475,    10,    10,    10,  1005,  8171, 11108, 11074,
        11109,   846,    10,    10,    10,    10,  1005,  3474, 11108,
        11071, 11109],
       [  772,   602,   398,    10,    10,  1005,  6239, 11108, 11067,
        11110,   296,    10,    10,    10,    10,  1005,  4692, 11108,
        11071, 11109],
       [  

In [None]:
whole_seq = np.concatenate([start_seq, inf_seq], axis=0)
for m in inf_seq:
    valh.validate_msg(m, tok, v)
    valh.find_orig_msg(m, whole_seq[])

In [98]:
start_seq.reshape((-1, Message_Tokenizer.MSG_LEN))

Array([[   25,   677,   810, ...,     2,     2,     2],
       [   25,   677,   810, ...,     2,     2,     2],
       [   25,   677,   810, ..., 11107, 11008, 11109],
       ...,
       [   25,   684,    13, ..., 11108, 11007, 11110],
       [   25,   684,    13, ..., 11108, 11007, 11110],
       [   25,   684,    13, ..., 11108, 11007, 11110]], dtype=int32)

In [104]:
for m in tok_output: # onp.array(start_seq.reshape((-1, Message_Tokenizer.MSG_LEN))): # :
    for i, t in enumerate(m):
        # print('i', i)
        f = valh.get_field_from_idx(i)[0]
        # print('field', f)
        print(
            v.DECODING_GLOBAL[t],
            valh.is_tok_valid(onp.array(t), f, v))
    print()

('time', '614') True
('time', '007') True
('time', '007') True
('time', '007') True
('time', '007') True
('event_type', '3') True
('size', '2200') True
('price', '-') True
('price', '39') True
('direction', '0') True
('time', '334') True
('time', '007') True
('time', '007') True
('time', '007') True
('time', '007') True
('event_type', '2') True
('size', '2983') True
('price', '+') True
('price', '64') True
('direction', '1') True

('time', '699') True
('time', '007') True
('time', '007') True
('time', '007') True
('time', '007') True
('event_type', '3') True
('size', '3429') True
('price', '-') True
('price', '39') True
('direction', '0') True
('time', '843') True
('time', '007') True
('time', '007') True
('time', '007') True
('time', '007') True
('event_type', '2') True
('size', '2983') True
('price', '+') True
('price', '64') True
('direction', '1') True

('time', '699') True
('time', '007') True
('time', '007') True
('time', '007') True
('time', '007') True
('event_type', '3') True


In [48]:
tok.decode_to_str(onp.array(inf_seq[-10*l:].reshape((-1,l))), v)

array([[['015', '015', '616', '142', '616', '', '005', '01', '03', '',
         '', '022', '020', '020', '020', '', '005', '', '03', '']],

       [['015', '015', '020', '020', '785', '', '005', '01', '02', '',
         '', '022', '020', '020', '', '', '020', '', '01', '']],

       [['015', '015', '020', '785', '', '', '005', '01', '02', '', '',
         '020', '020', '020', '', '3', '020', '', '01', '']],

       [['015', '015', '015', '785', '', '', '005', '', '02', '', '',
         '022', '022', '020', '', '3', '020', '', '01', '']],

       [['015', '015', '020', '785', '', '', '005', '', '03', '', '',
         '022', '020', '020', '', '3', '020', '', '01', '']],

       [['015', '015', '020', '785', '', '', '020', '01', '01', '', '',
         '020', '020', '', '', '', '020', '02', '02', '']],

       [['015', '020', '020', '', '', '', '010', '03', '01', '', '015',
         '020', '785', '', '', '', '020', '02', '02', '']],

       [['015', '015', '015', '', '', '', '', '02', '02'

In [46]:
tok.decode(onp.array(inf_seq[-10*l:].reshape((-1,l))), v)

TypeError: sequence item 0: expected str instance, numpy.ndarray found

In [99]:
with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    display(scores_df.loc[scores_df.support > 0])

Unnamed: 0,field,decoded,precision,recall,fscore,support
2,generic,NAN,0.992991,1.0,0.996483,850
3,time,000,0.709677,1.0,0.830189,22
4,time,001,0.73913,0.85,0.790698,20
5,time,002,0.681818,0.9375,0.789474,16
6,time,003,0.75,0.818182,0.782609,11
7,time,004,0.785714,1.0,0.88,11
8,time,005,0.923077,0.8,0.857143,15
9,time,006,0.4,1.0,0.571429,6
10,time,007,0.866667,0.928571,0.896552,14
11,time,008,0.5,1.0,0.666667,6


In [66]:
scores_df.to_csv('scores.csv')

In [88]:
import flax.linen as nn


def __pred_rank(pred, labels):
    """ Get the rank of the correct label in the predicted distribution.
        Lower is better (0 is correct prediction).
    """
    correct_mask = nn.one_hot(labels.astype(int), pred.shape[-1]).astype(bool)
    # ::-1 sorts in descending order (0 is highest rank)
    a = pred.argsort(axis=-1)
    ranks = a[..., ::-1].argsort(axis=-1)
    return ranks[correct_mask]

def __pred_rank2(pred, labels):
    """ Get the rank of the correct label in the predicted distribution.
        Lower is better (0 is correct prediction).
    """
    # ::-1 sorts in descending order (0 is highest rank)
    a = pred[..., ::-1].argsort(axis=-1)
    return np.argwhere(a == np.expand_dims(labels, -1))

In [275]:
# calculate loss of uniform prediction as reference
cross_entropy_loss(np.log(np.ones_like(pred) / pred.shape[-1]), labels).mean()

Array(9.315691, dtype=float32)

In [391]:
import numpy as onp

mask_fn = LOBSTER_Dataset.get_masking_fn(
    random_msg_idxs=None,
    random_fields=['time'],
    randomize_message=False)  # True
    
gen = onp.random.default_rng(42)

In [407]:
# test masking

# first, fill back in the correct label
seq = onp.array(batch[0][0].copy()).reshape((-1, Message_Tokenizer.MSG_LEN))
seq[seq == Vocab.MASK_TOK] = labels[0]

# LOBSTER_Dataset.causal_mask(
masked_toks, target = mask_fn( # LOBSTER_Dataset.random_mask(
    seq,
    gen)

print(masked_toks)
print('target', target)
print('location', onp.argwhere(masked_toks == Vocab.MASK_TOK))

[[   25   677   810 ...     2     2     2]
 [   25   677   810 ...     2     2     2]
 [   25   677   810 ... 11107 11008 11109]
 ...
 [   25   684    13 ... 11108 11007 11110]
 [   25   684    13 ... 11108 11007 11110]
 [    1     1     1 ... 11108 11007     0]]
target 11110
location [[499  19]]


In [302]:
onp.argwhere(masked_toks == Vocab.MASK_TOK)

array([[272,  17],
       [367,   4]])

In [125]:
def get_valid_mask(inp, shape):
    mask = np.zeros(shape)
    valid_toks = get_valid_toks_for_input(inp)
    tuple_index = (
        tuple(i for i in range(len(valid_toks)) for _ in range(len(valid_toks[i]))),
        tuple(t for tok_tuple in valid_toks for t in tok_tuple))
    return mask.at[tuple_index].set(1.)

In [127]:
get_valid_mask(batch[0], pred.shape)

Array([1.0030e+03, 1.0030e+03, 1.0003e+04, 5.0000e+00, 1.0030e+03,
       1.0030e+03, 1.0500e+02, 7.0000e+00, 1.0030e+03, 1.0030e+03,
       1.0030e+03, 1.0500e+02, 1.0500e+02, 1.0030e+03, 1.0030e+03,
       1.0030e+03], dtype=float32)

In [None]:
# TODO: loss of unconditional (training data) prediction per field

In [None]:
# TODO: accuracy per token (for each correct label, how often is it predicted?)
#       recall, precision, per token

In [22]:
tok = Message_Tokenizer()

In [113]:
tok.col_idx_by_encoder

{'time': [0, 1, 2, 3, 4, 10, 11, 12, 13, 14],
 'event_type': [5, 15],
 'size': [6, 16],
 'price': [7, 8, 17, 18],
 'direction': [9, 19]}

Array([[-9.314257 , -9.318846 , -9.314268 , ..., -9.317285 , -9.313043 ,
        -9.319928 ],
       [-9.314671 , -9.319134 , -9.314537 , ..., -9.31646  , -9.313398 ,
        -9.321061 ],
       [-9.313926 , -9.3199835, -9.313149 , ..., -9.316932 , -9.312886 ,
        -9.320968 ],
       ...,
       [-9.314899 , -9.318279 , -9.316247 , ..., -9.3185425, -9.314482 ,
        -9.319631 ],
       [-9.314888 , -9.318323 , -9.3155575, ..., -9.317863 , -9.313861 ,
        -9.319694 ],
       [-9.315161 , -9.318303 , -9.316559 , ..., -9.317473 , -9.314913 ,
        -9.318973 ]], dtype=float32)

In [118]:
tok.invalid_toks_per_msg(batch[0], v).mean()

0.0

In [122]:
pred.argmax(axis=-1)

Array([6332, 6332, 6332, 3101, 6865, 1616, 6332, 3101, 3101, 3101, 5606,
       5606, 5606, 5606, 5606, 5606], dtype=int32)

In [419]:
pred.shape

(16, 11111)

In [418]:
tok.invalid_toks_per_msg(, v).mean()

ValueError: cannot reshape array of size 177776 into shape (16,newaxis,20)

In [417]:
tok._validate_syntax(batch[0][0], v)

(True,
 array([['022', '674', '807', ..., 'NAN', 'NAN', 'NAN'],
        ['022', '674', '807', ..., 'NAN', 'NAN', 'NAN'],
        ['022', '674', '807', ..., '+', '01', '0'],
        ...,
        ['022', '681', '010', ..., '-', '00', '1'],
        ['022', '681', '010', ..., '-', '00', '1'],
        ['022', '681', '010', ..., '-', '00', '1']], dtype='<U3'))

In [228]:
#valid_prediction_mass(pred[0:2], get_masked_fields(batch[0][0:2]))
valid_prediction_mass(pred, get_masked_fields(batch[0]))

Array([0.09027135, 0.09027305, 0.09027165, 0.00945231, 0.00063061,
       0.00945135, 0.00045093, 0.09027136], dtype=float32)

In [281]:
# probability of syntactically valid tokens in the top_n predictions
print('1:')
print(valid_prediction_mass(pred, get_masked_fields(batch[0]), top_n=1))

print('5:')
print(valid_prediction_mass(pred, get_masked_fields(batch[0]), top_n=5))

print('10:')
print(valid_prediction_mass(pred, get_masked_fields(batch[0]), top_n=10))

1:
[1. 0. 1. 0. 0. 0. 0. 1.]
5:
[0.40015867 0.2000537  0.200092   0.         0.         0.
 0.1999747  0.6000089 ]
10:
[0.30012807 0.20000866 0.30003056 0.         0.         0.
 0.1000105  0.30007356]


In [33]:
# for a batch, print predicted tokens (and vocab) and actually correct labels

for pred_tok, label in zip(pred.argmax(axis=-1).tolist(), labels.tolist()):
    print('pred:')
    print(pred_tok)
    print(v.DECODING_GLOBAL[pred_tok])
    print('label', label)
    print(v.DECODING_GLOBAL[label])
    print()

pred:
260
('time', '257')
label 2.0
('generic', 'NAN')

pred:
5392
('size', '4385')
label 54.0
('time', '051')

pred:
260
('time', '257')
label 2.0
('generic', 'NAN')

pred:
7353
('size', '6346')
label 2.0
('generic', 'NAN')

pred:
1401
('size', '0394')
label 1005.0
('event_type', '3')

pred:
633
('time', '630')
label 2.0
('generic', 'NAN')

pred:
260
('time', '257')
label 11109.0
('direction', '0')

pred:
260
('time', '257')
label 2.0
('generic', 'NAN')

