# Probing tsGT code functionality
This notebook contains basic tests that check the functionality of the tsGT model code, organised as one section per module.

In [None]:
import sys
sys.path.append('code')

import datasets as ds
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

## First, let's get an idea of how the Dataset is built

In [None]:
import pickle
# Peek at the pre-generated synthetic segments. pickle.load can execute
# arbitrary code, so only point this at trusted, locally generated files.
segments_path = 'data/synthetic_datasets_segments.pkl'
with open(segments_path, 'rb') as segments_file:
    samples = pickle.load(segments_file)
dgp_samples = samples["dgp_dataset"]["samples"]
dgp_samples.shape

In [None]:
def data_to_df(samples, start_date='2018-01-01 00:00:00', freq='1h'):
    """Map a ``[n_signals, length]`` array of series to a time-indexed DataFrame.

    Each row of ``samples`` becomes one column of the result, and the rows of the
    result are labelled by a ``DatetimeIndex`` named ``'date'``. No covariate
    columns are produced: covariates don't matter here and are dropped during
    training.

    Args:
        samples: 2-D array-like of shape ``[n_signals, length]``.
        start_date: First timestamp of the index.
        freq: pandas frequency string giving the spacing of the index.

    Returns:
        A DataFrame of shape ``[length, n_signals]`` with integer column labels
        ``0 .. n_signals - 1``.

    Raises:
        ValueError: If ``samples`` is not 2-D.
    """
    samples = np.asarray(samples)
    if samples.ndim != 2:
        raise ValueError(
            f"expected samples of shape [n_signals, length], got shape {samples.shape}")
    length = samples.shape[1]
    df = pd.DataFrame(samples.T)  # one column per signal
    df.index = pd.date_range(start_date, periods=length, freq=freq)
    df.index.name = 'date'
    return df

## datasets.py
The relevant parameters here are:
- series_length: the length of the slices taken from the dataset. For example, if the dataset has shape [n_series, length], then during training we take random samples of shape [batch_size, start_idx:end_idx] where end_idx - start_idx = series_length (256 in the tests below).
- train_window: the length of the part of each series from which training examples are drawn. If eval_window > 0, the evaluation window is situated at the end of the series, so the training data is [n_series, :train_window]. If eval_window = 0, the evaluation window takes up the first series_length points of the series.
- eval_window: the window over which the loss is calculated during validation. It is measured within eval_data, whose length is always equal to series_length. For example, if series_length = 256 and eval_window = 128, then eval_data has length 256. During validation the loss is calculated over eval_data using teacher forcing; if eval_window > 0, it is calculated only on the last 128 points, i.e. those corresponding to eval_window.

In [None]:
# Dummy dataset of 1024 hourly time steps x 1024 series: the first 896 time
# steps are 1s and the remaining 128 are 0s, so the train/eval split is visible
# in the values themselves.
data = np.ones((1024, 1024))
data[896:] = 0
df = pd.DataFrame(data, index=pd.date_range('2018-01-01', periods=1024, freq='h'))
dataset = ds.Dataset(
    # Bind df now, so a later cell that rebinds `df` cannot change what this
    # dataset loads.
    data_full=ds.DataCollection(data_loader=lambda df=df: df),
    series_length=256,
    start_date='2018-01-01',
    train_window=896,
    eval_window=128)
assert dataset.train_data.shape == (1024, 896)  # [n_series, train_window]
assert dataset.eval_data.shape == (1024, 256)   # [n_series, series_length]
assert dataset.eval_horizon == 128
assert np.all(dataset.train_data == 1)
# The first 128 eval points overlap the end of the training window, so they are 1s.
assert np.all(dataset.eval_data[:, :128] == 1)
# The last 128 eval points (the eval window proper) come from the 0s region.
assert np.all(dataset.eval_data[:, 128:] == 0)
# 1152 == train_window (896) + len(eval_data) (256): the 128 overlapping steps
# appear to be counted once for the train slice and once for the eval slice.
# NOTE(review): inferred from the numbers only; confirm against datasets.py.
assert dataset.cov.shape == (1152, 3)
print("All tests passed")

## index_streams.py
This is essentially a supporting module for getting indices and slices for the input iterables, so there is not much to test beyond sanity checks. (Note: the only relevant parts of this module are the functions that randomly sample train and eval indices; everything else is unused. For example, the module also provides functions for extracting the full index streams, but these sample deterministically.) The two relevant functions are:
- create_train_index_stream: creates the index stream for training examples, i.e. a generator of (series_index, slice_start, slice_stop) tuples. It chooses a random series and a random starting point. The only relevant parameter is weighted_sampling (bool), which is kept False. Weighted sampling helps avoid sampling series/time points that are 0 (for example in seasonal data), but our data does not show such patterns.
- create_eval_index_stream: similar to the above, but for the eval index stream. The difference is that each slice runs to the end of the series, so slice_stop=None. The relevant parameter is full_eval, which is False by default; setting it to True makes us evaluate on the whole dataset.

The cell below exercises `create_uniform_train_index_stream` and `create_random_eval_index_stream` directly.

In [None]:
from index_streams import (
    create_train_index_stream,
    create_eval_index_stream,
    create_random_eval_index_stream,
    create_uniform_train_index_stream
)
# Dummy [n_series, length] array for the index streams. It uses its own seeded
# generator, so its values are reproducible without disturbing the global numpy
# seed that the streams below are run under.
toy_series = np.random.default_rng(0).standard_normal((5, 100))
series_length = 20
n_draws = 20

# Training stream: a random series and a random window of exactly series_length
# points that lies fully inside the series.
np.random.seed(0)
stream = create_uniform_train_index_stream(toy_series, series_length)
for _ in range(n_draws):
    series_idx, slice_start, slice_stop = next(stream)
    assert slice_stop - slice_start == series_length
    assert 0 <= series_idx < toy_series.shape[0]
    assert 0 <= slice_start < toy_series.shape[1]
    assert slice_stop <= toy_series.shape[1]

# Eval stream: a random series, always sliced as its last series_length points
# (slice_start=-series_length, slice_stop=None).
np.random.seed(0)
stream = create_random_eval_index_stream(toy_series, series_length)
for _ in range(n_draws):
    series_idx, slice_start, slice_stop = next(stream)
    assert 0 <= series_idx < toy_series.shape[0]
    assert slice_start == -series_length
    assert slice_stop is None
    sliced = toy_series[series_idx, slice_start:slice_stop]
    assert sliced.shape[0] == series_length
print("All tests passed")

## inputs.py
Using the above index and slice streams, the inputs module extracts the inputs used for training/eval. Its main entry point is CreateInputs, which takes the index streams above and returns the actual train and eval streams with train and eval data. A very important variable here is the mask, which determines on which parts of the sequence the loss is calculated. For train sequences the mask is always 1. During eval, the mask is 1 on the last eval_horizon points and 0 before them (or 1 everywhere if eval_horizon = 0). For example, if eval_horizon = 0 and series_length = 256, the mask consists of 256 1s; if eval_horizon = 128 and series_length = 256, the mask has 128 0s followed by 128 1s.

In [None]:
# Rebuild the toy dataset: 1024 hourly steps x 1024 series, with 1s for the
# first 896 steps and 0s for the last 128.
data = np.vstack([np.ones((896, 1024)), np.zeros((128, 1024))])
hourly_index = pd.date_range('2018-01-01', periods=1024, freq='h')
df = pd.DataFrame(data, index=hourly_index)
collection = ds.DataCollection(data_loader=lambda: df)
dataset = ds.Dataset(
    data_full=collection,
    start_date='2018-01-01',
    series_length=256,
    train_window=896,
    eval_window=128,
)

In [None]:
from index_streams import create_train_index_stream, create_eval_index_stream
from inputs import slice_stream, minibatch_stream, CreateInputs
import numpy as np
import random

train_stream, eval_stream = CreateInputs(dataset=dataset, batch_size=16, series_length=256,
            weighted_sampling=False, full_eval=False, traxify=False)
train_batch = next(train_stream(None))
eval_batch = next(eval_stream(None))
series, inp, target, mask = train_batch
eval_series, eval_inp, eval_target, eval_mask = eval_batch

# Training batch: 16 windows of 256 points, all drawn from the 1s region.
assert series.shape == (16, 256)
assert inp.shape[0] == 16
assert mask.shape == (16, 256)
# Target is the unshifted series; the shift for teacher forcing comes later.
assert np.array_equal(series, target)
assert np.all(series == 1)
assert np.all(mask == 1)  # All 1s: the loss is calculated over the whole train series.

# Eval batch: the first 128 points overlap the training window (1s) and are
# masked out; the last 128 points are the eval window (0s) where loss is computed.
assert eval_mask.shape == eval_series.shape
assert np.all(eval_mask[:, :128] == 0)
assert np.all(eval_mask[:, 128:] == 1)
assert np.all(eval_series[:, :128] == 1)
assert np.all(eval_series[:, 128:] == 0)
assert np.array_equal(eval_series, eval_target)
print("All tests passed")

## normalization.py

This module implements the first step of the data preparation process: normalising each series according to the method laid out in the paper. As the checks below show, each series is divided by its mean absolute value plus a small regularizer. We test that everything works correctly on data with known answers. The mask also matters: positions where mask = 0 are excluded when calculating the mean.

In [None]:
from predictors.normalization import PerTsNormalizer
import numpy as np
normalizer = PerTsNormalizer(regularizer=0.001)

# Each row's scaling factor is its mean absolute value plus the regularizer.
data = np.array([
    [1.0, 1.0, 1.0, 1.0],   # absmean = 1.0, scaling = 1.001
    [2.0, 0.0, 2.0, 0.0],   # absmean = 1.0, scaling = 1.001
])
scaled_data, parameters, _ = normalizer.normalize(data)
recovered = normalizer.denormalize(scaled_data, parameters)
assert np.allclose(parameters.scaling_factor, [[1.001], [1.001]])
assert np.allclose(recovered, data)

# An all-zero series is scaled by the regularizer alone and stays zero.
zs = np.array([[0.0, 0.0, 0.0, 0.0]])
scaled_zs, par_zs, _ = normalizer.normalize(zs)
assert np.allclose(par_zs.scaling_factor, [[0.001]])
assert np.allclose(scaled_zs, 0.0)

# Signs are preserved; the scale uses absolute values (absmean = 3.0).
neg = np.array([[-3.0, 3.0, -3.0, 3.0]])
scaled_neg, _, _ = normalizer.normalize(neg)
assert np.allclose(scaled_neg, [[-3.0/3.001, 3.0/3.001, -3.0/3.001, 3.0/3.001]])

# Masking: positions with mask=0 are excluded from the mean.
data = np.array([[1.0, 1.0, 10.0, 10.0]])
mask = np.array([[0.0, 0.0, 1.0, 1.0]])  # last two are 1, as when eval_horizon = 2 and series_length = 4
scaled_with_mask, params_mask, _ = normalizer.normalize(data, mask=mask)
scaled_no_mask, params_no_mask, _ = normalizer.normalize(data, mask=None)
assert np.allclose(params_mask.scaling_factor, [[10.001]])
assert np.allclose(scaled_with_mask, [[1.0/10.001, 1.0/10.001, 10.0/10.001, 10.0/10.001]])
assert np.allclose(params_no_mask.scaling_factor, [[5.501]])
assert np.allclose(scaled_no_mask, [[1.0/5.501, 1.0/5.501, 10.0/5.501, 10.0/5.501]])
assert np.allclose(normalizer.denormalize(scaled_with_mask, params_mask), data)
assert np.allclose(normalizer.denormalize(scaled_no_mask, params_no_mask), data)
print("All tests passed")

## serializers.py
The serializers module is important: it serializes floats into discrete tokens and deserializes tokens back into floats. A gym `Box` space defines the bounds [low, high]; during preprocessing, values are rescaled from this range to [0, 1] and clipped. The important parameters are:
- vocab_size: the base of the digit representation, i.e. the number of distinct tokens per digit. In the paper this is set to 10.
- precision: how many tokens each float is represented with. precision = 3 throughout the paper, so with vocab_size = 10 a (preprocessed) value of 0.143 is represented by the three tokens [1, 4, 3].

In [None]:
import gym
from serializers import BoxSpaceSerializer
# Single source of truth for the Box bounds used below.
low, high = -0.3, 1.5
space = gym.spaces.Box(low=low, high=high, shape=())
serializer = BoxSpaceSerializer(space, vocab_size=10, precision=3)
assert space.shape == ()
# isclose accounts for floating-point error in the stored bounds.
assert np.isclose(float(space.low), low)
assert np.isclose(float(space.high), high)

data = np.array([[0.5], [1.5], [-0.3], [0.0], [1.0]])
preprocessed = serializer._preprocess(data)  # maps [low, high] to [0, 1] and clips
assert np.allclose(preprocessed[0], (0.5 - low) / (high - low))
assert np.allclose(preprocessed[1], 1.0)
assert np.allclose(preprocessed[2], 0.0)
assert np.allclose(preprocessed[3], (0.0 - low) / (high - low))
assert np.allclose(preprocessed[4], (1.0 - low) / (high - low))
assert np.all(preprocessed >= 0.0)
assert np.all(preprocessed <= 1.0)

serialized = serializer.serialize(data)
deserialized = serializer.deserialize(serialized)
assert serialized.shape == (5, 3)
assert np.array_equal(serialized[1], [9, 9, 9])  # upper bound
assert np.array_equal(serialized[2], [0, 0, 0])  # lower bound
assert np.isclose(float(deserialized[1]), high, atol=0.01)
assert np.isclose(float(deserialized[2]), low, atol=0.01)
# Every value round-trips to within the quantization error
# ((high - low) / 10**3 = 0.0018 per step here).
assert np.allclose(np.ravel(deserialized), np.ravel(data), atol=0.01)
print("All tests passed")

## predictors/inputs.py
This module builds a trax layer that injects information from the time covariates (as well as the series id) into the input embeddings. With `input_vocab_sizes=None` no covariates are used, so here we just check that the covariates are dropped and the embeddings pass through unchanged.

In [None]:
from predictors.inputs import InjectInputs
from trax import shapes
import jax.numpy as jnp
import numpy as np
batch_size, seq_len, d_emb, n_covariates = 2, 10, 128, 4
# With input_vocab_sizes=None the layer is expected to drop the covariates and
# return the context embedding unchanged. d_emb matches the width of the
# context embeddings fed in below.
injection_layer = InjectInputs(input_vocab_sizes=None, d_emb=d_emb)
context_emb = jnp.ones((batch_size, seq_len, d_emb))
covariates = jnp.ones((batch_size, seq_len, n_covariates))
inpts = (
    shapes.ShapeDtype((batch_size, seq_len, d_emb)),
    shapes.ShapeDtype((batch_size, seq_len, n_covariates)),
)
injection_layer.init(inpts)
output = injection_layer((context_emb, covariates))
assert output.shape == (batch_size, seq_len, d_emb)
assert np.allclose(output, context_emb)
print("All tests passed")

## attention.py
Attention here is fairly standard; rotary position embeddings (RoPE, from RoFormer) are used for the positional encodings. We test the helper functions individually to check that their numerical outputs match the expected behavior, then check causality and order sensitivity of the full attention layer.

In [None]:
import numpy as np
from trax.fastmath import numpy as jnp #jnp when working with trax layers
from trax import shapes, fastmath
from attention import (
    rotate_every_two,
    calculate_sin_cos_rotary,
    DotProductCausalRotaryAttention,
)
# --- rotate_every_two: each consecutive feature pair (a, b) becomes (-b, a),
# i.e. a quarter-turn rotation of every pair.
pairs = jnp.array([[[1.0, 2.0, 3.0, 4.0]]])
rotated = rotate_every_two(pairs)
assert jnp.allclose(rotated, jnp.array([[[-2.0, 1.0, -4.0, 3.0]]]))
assert rotate_every_two(jnp.ones((2, 5, 8))).shape == (2, 5, 8)  # shape preserved
# Two quarter-turns are a half-turn, i.e. negation.
assert jnp.allclose(rotate_every_two(rotate_every_two(pairs)), -pairs)

# --- calculate_sin_cos_rotary: precomputed [n_ctx, rotary_dim] tables.
rotary_dim, n_ctx = 8, 32
sin, cos = calculate_sin_cos_rotary(rotary_dim=rotary_dim, n_ctx=n_ctx)
assert sin.shape == (n_ctx, rotary_dim)
assert cos.shape == (n_ctx, rotary_dim)
assert jnp.allclose(sin**2 + cos**2, 1.0, atol=1e-3)
# Position 0 corresponds to a zero rotation angle.
assert jnp.allclose(sin[0], 0.0, atol=1e-6)
assert jnp.allclose(cos[0], 1.0, atol=1e-6)
assert jnp.all(sin >= -1.0) and jnp.all(sin <= 1.0)  # bounded
assert jnp.all(cos >= -1.0) and jnp.all(cos <= 1.0)

# --- DotProductCausalRotaryAttention on random q/k/v.
batch_heads, seq_len, d_head, fraction_to_rotate = 2, 6, 16, 0.25
layer = DotProductCausalRotaryAttention(fraction_to_rotate=fraction_to_rotate, dropout=0.0, max_inference_length=64, mode='eval')
sig = (
    shapes.ShapeDtype((batch_heads, seq_len, d_head)),
    shapes.ShapeDtype((batch_heads, seq_len, d_head)),
    shapes.ShapeDtype((batch_heads, seq_len, d_head)),
)
layer.init(sig)
np.random.seed(42)
q = jnp.array(np.random.randn(batch_heads, seq_len, d_head))
k = jnp.array(np.random.randn(batch_heads, seq_len, d_head))
v = jnp.array(np.random.randn(batch_heads, seq_len, d_head))
out = layer((q, k, v))
assert out.shape == (batch_heads, seq_len, d_head)

# Causality: changing only the last position must not change earlier outputs.
q_last = q.at[:, -1, :].set(99.0)
k_last = k.at[:, -1, :].set(99.0)
v_last = v.at[:, -1, :].set(99.0)
out_last = layer((q_last, k_last, v_last))
assert jnp.allclose(out[:, :-1, :], out_last[:, :-1, :], atol=1e-5)

# Order sensitivity: attention must NOT be permutation invariant, so swapping
# positions 0 and 1 has to change the output.
perm = jnp.array([1, 0] + list(range(2, seq_len)))
out_perm = layer((q[:, perm, :], k[:, perm, :], v[:, perm, :]))
assert not jnp.allclose(out, out_perm, atol=1e-3)
print("All tests passed")

## layers.py
This module contains many utility layers that are not actually used, either in the original experimental setting (when testing the actual tsGT model) or in the one proposed in the thesis. Essentially the only layer used is CausalConv, and only with a kernel width of 1, which makes it a per-position linear projection.

In [None]:
import numpy as np
from trax import shapes
from trax.fastmath import numpy as jnp
from layers import CausalConv
np.random.seed(0)  # reproducible random inputs

# Shape check: kernel_width=1 projects 128 features to 256 filters per position.
conv = CausalConv(filters=256, kernel_width=1, mode='eval')
conv.init(shapes.ShapeDtype((2, 10, 128)))
x = jnp.array(np.random.randn(2, 10, 128).astype(np.float32))
out = conv(x)
assert out.shape == (2, 10, 256)

# Locality: with kernel_width=1 each output position depends only on the input
# at the same position, so perturbing position 3 must change output 3 and
# leave every other position untouched.
conv = CausalConv(filters=32, kernel_width=1, mode='eval')
conv.init(shapes.ShapeDtype((1, 8, 16)))
x = jnp.array(np.random.randn(1, 8, 16).astype(np.float32))
out = conv(x)
x_mod = x.at[:, 3, :].set(99.0)
out_mod = conv(x_mod)
assert jnp.allclose(out[:, :3, :], out_mod[:, :3, :], atol=1e-3)
assert jnp.allclose(out[:, 4:, :], out_mod[:, 4:, :], atol=1e-3)
assert not jnp.allclose(out[:, 3, :], out_mod[:, 3, :], atol=1e-3)
print("All tests passed")

## models.py
This code contains the main transformer body defined using trax layers that is to be used throughout the experiment. The model consists of decoder blocks, feedforward blocks and a causal convolution layer of kernel width 1, essentially a Dense layer.

In [None]:
import numpy as np
from trax import layers as tl, shapes
from trax.fastmath import numpy as jnp
import models

def return_body(mode='eval'):
    """Build the small TransformerBody checked in this section.

    Args:
        mode: trax layer mode (e.g. 'eval' or 'predict'), forwarded to
            models.TransformerBody.
    """
    return models.TransformerBody(
        d_model=128, d_ff_mul=2, n_layers=4, n_heads=4,
        max_len=2048, dropout=0.0, ff_activation=tl.FastGelu, precision=3,
        mode=mode,
    )


np.random.seed(0)  # reproducible random inputs
body = return_body()
batch, seq_len, d_model = 2, 24, 128
sig = shapes.ShapeDtype((batch, seq_len, d_model))
body.init(sig)
x = jnp.array(np.random.randn(batch, seq_len, d_model).astype(np.float32))
out = body(x)
assert out.shape == (batch, seq_len, d_model)

# Determinism: with dropout=0.0, repeated calls give the same output.
out1 = body(x)
out2 = body(x)
assert jnp.allclose(out1, out2, atol=1e-3)

# Order sensitivity: swapping positions 0 and 1 changes the output.
perm = jnp.array([1, 0] + list(range(2, seq_len)))
x_perm = x[:, perm, :]
out_perm = body(x_perm)
assert not jnp.allclose(out, out_perm, atol=1e-3)

# Causality: keep position 0 fixed and replace every later position. The output
# at position 0 must not change, while the later outputs should.
x_alt = x.at[:, 1:, :].set(
    jnp.array(np.random.randn(batch, seq_len - 1, d_model).astype(np.float32)))
out_alt = body(x_alt)
assert jnp.allclose(out[:, 0, :], out_alt[:, 0, :], atol=1e-3)
assert not jnp.allclose(out[:, 1:, :], out_alt[:, 1:, :], atol=1e-3)

# Variable sequence lengths (within max_len).
for sl in [12, 48, 768]:
    x_var = jnp.array(np.random.randn(1, sl, d_model).astype(np.float32))
    out_var = body(x_var)
    assert out_var.shape == (1, sl, d_model)

# Batch independence: batching two sequences gives the same outputs as running
# them separately.
x_a = jnp.array(np.random.randn(1, 24, d_model).astype(np.float32))
x_b = jnp.array(np.random.randn(1, 24, d_model).astype(np.float32))
x_batched = jnp.concatenate([x_a, x_b], axis=0)
out_a = body(x_a)
out_b = body(x_b)
out_batched = body(x_batched)
assert jnp.allclose(out_batched[0], out_a[0], atol=1e-3)
assert jnp.allclose(out_batched[1], out_b[0], atol=1e-3)

# The body is not a trivial one-step shift of its input.
x = jnp.array(np.random.randn(1, 24, d_model).astype(np.float32))
out = body(x)
assert not jnp.allclose(out[:, 1:, :], x[:, :-1, :], atol=1e-3)

# Numerical stability for large- and small-magnitude inputs.
x_large = jnp.array(np.random.randn(2, 24, d_model).astype(np.float32)) * 10.0
out_large = body(x_large)
assert jnp.all(jnp.isfinite(out_large))

x_small = jnp.array(np.random.randn(2, 24, d_model).astype(np.float32)) * 1e-6
out_small = body(x_small)
assert jnp.all(jnp.isfinite(out_small))

print("All tests passed")

## decoding.py
Autoregressive decoding of discretised continuous values. This decoding scheme needs a SerialDecoder model that is available in predictors/serial_predictor. It essentially lets the model handle the discrete digits. Here we just test the functionality of the code with some sanity checks.

In [None]:
import numpy as np
import gym
from trax import layers as tl, shapes
from trax.fastmath import numpy as jnp
import models
from decoding import autoregressive_sample
from predictors.serial_predictor import SerialDecoder
from serializers import BoxSpaceSerializer

np.random.seed(0)  # reproducible random context
vocab_size, d_model, precision = 10, 64, 3
serializer = BoxSpaceSerializer(space=gym.spaces.Box(shape=(), low=0.0, high=10.0), vocab_size=vocab_size, precision=precision)
body = models.TransformerBody(d_model=d_model, d_ff_mul=2, n_layers=2, n_heads=2,
    max_len=2048, dropout=0.0, ff_activation=tl.FastGelu, precision=precision, mode='predict')

model = SerialDecoder(model_body=body, serializer=serializer, d_emb=d_model, input_vocab_sizes=None, mode='predict')
sig = (shapes.ShapeDtype((1, 1), dtype=np.int32), shapes.ShapeDtype((1, 1, 0), dtype=np.int32))
model.init(sig)
# Snapshot the freshly initialised state; in predict mode the state carries
# decoding progress (presumably the attention caches), so every run below
# restarts from this snapshot.
init_state = model.state


def greedy_sample(logits):
    """Placeholder sampler: pick the highest-scoring token along the last axis."""
    return np.argmax(logits, axis=-1)


def run_greedy(context, inputs, horizon_length):
    """Greedily decode `horizon_length` tokens after `context`, from a fresh model state."""
    model.state = init_state
    return autoregressive_sample(
        model=model, sample_fn=greedy_sample,
        context=context, inputs=inputs,
        batch_size=1, horizon_length=horizon_length,
    )


context_len, horizon = 6, 9  # horizon = 3 values x precision 3 tokens
context = np.random.randint(0, vocab_size, (1, context_len)).astype(np.int32)
inputs = np.zeros((1, context_len + horizon, 0))
result = run_greedy(context, inputs, horizon)
assert result.shape == (1, horizon)
assert np.all(result >= 0) and np.all(result < vocab_size)

# Decoding is sensitive to the context.
result_a = run_greedy(np.full((1, context_len), 1, dtype=np.int32), inputs, horizon)
result_b = run_greedy(np.full((1, context_len), 9, dtype=np.int32), inputs, horizon)
assert not np.array_equal(result_a, result_b)

# Deserialise: every `precision` consecutive tokens encode one value.
pred_repr = np.reshape(result, (-1, precision))
values = serializer.deserialize(pred_repr)
assert np.all(np.isfinite(values))
assert values.shape == (horizon // precision,)
assert np.all((values >= 0.0) & (values <= 10.0 + 1e-6))  # within the Box bounds
print("All tests passed.")

## serial_predictor.py
This is where the whole pipeline is put together, making everything from training to prediction possible. The predictor itself is stateless: as long as every component is instantiated with the same parameters and the trained weights (mainly those of the decoder) are available, the predictor used during training can, for example, be reconstructed for evaluation. Here we test each part of the predictor.

In [None]:
import numpy as np
import gym
import gin
from trax import layers as tl, shapes
from trax.fastmath import numpy as jnp
from trax.rl import serialization_utils as srl_utils
import models
from predictors.serial_predictor import SerialPredictor, SerialDecoder, SerialTraining
from predictors.normalization import Normalizer
from serializers import BoxSpaceSerializer
from distributions import Categorical
import contextlib

gin.clear_config()

vocab_size, d_model, precision, significance_decay, low, high, batch_size, context_len, horizon = 10, 64, 3, 0.3, 0.0, 10.0, 2, 24, 8


def make_body(mode='eval', precision=3):
    """Build the small TransformerBody wrapped by the predictor, in the given trax mode."""
    return models.TransformerBody(
        d_model=d_model, d_ff_mul=2, n_layers=2, n_heads=2,
        max_len=2048, dropout=0.0, ff_activation=tl.FastGelu,
        precision=precision, mode=mode,
    )


@contextlib.contextmanager
def greedy_sampling(pred):
    """Temporarily make ``pred`` decode greedily (argmax) instead of sampling.

    The original sampler is restored on exit even if an assertion inside the
    ``with`` block fails, so a failing check cannot leave later checks (or
    later cells) running with the patched sampler.
    """
    categorical = pred._categorical
    original = categorical.sample
    categorical.sample = lambda logits, **kwargs: np.argmax(logits, axis=-1)
    try:
        yield
    finally:
        categorical.sample = original


predictor = SerialPredictor(
    model_body_fn=make_body,
    d_in=d_model,
    vocab_size=vocab_size,
    precision=precision,
    significance_decay=significance_decay,  # weighs each digit's loss contribution by its significance
    low=low,
    high=high,
    accelerate_predict_model=False,  # disabled for testing
    normalization_regularizer=0.001)
series_len = context_len + horizon
train_model = predictor.make_train_eval_model(mode='train')
sig = (
    shapes.ShapeDtype((batch_size, series_len), dtype=np.float32),
    shapes.ShapeDtype((batch_size, series_len, 0), dtype=np.int32),
    shapes.ShapeDtype((batch_size, series_len), dtype=np.float32),
    shapes.ShapeDtype((batch_size, series_len), dtype=np.float32),
)
train_model.init(sig)
train_weights = train_model.weights
np.random.seed(67)
context = np.random.rand(batch_size, context_len) * 5.0
inputs = np.zeros((batch_size, context_len + horizon, 0))

# --- Prediction with the default (stochastic) sampler.
pred = predictor.predict(
    weights=train_weights,
    context=context,
    inputs=inputs,
    horizon_length=horizon)
assert pred.shape == (batch_size, horizon)
assert np.all(np.isfinite(pred))

# --- Prediction with greedy decoding.
ctx_low = np.full((batch_size, context_len), 0.1)
ctx_high = np.full((batch_size, context_len), 9.0)
with greedy_sampling(predictor):
    # Greedy decoding is deterministic.
    pred1 = predictor.predict(
        weights=train_weights, context=context,
        inputs=inputs, horizon_length=horizon,
    )
    pred2 = predictor.predict(
        weights=train_weights, context=context,
        inputs=inputs, horizon_length=horizon,
    )
    assert np.allclose(pred1, pred2, atol=1e-6)

    # The forecast depends on the level of the context.
    pred_low = predictor.predict(
        weights=train_weights, context=ctx_low,
        inputs=inputs, horizon_length=horizon,
    )
    pred_high = predictor.predict(
        weights=train_weights, context=ctx_high,
        inputs=inputs, horizon_length=horizon,
    )
    assert not np.allclose(pred_low, pred_high, atol=1e-3)

    # Other horizon lengths yield forecasts of the requested length.
    for h in [4, 12, 24]:
        p = predictor.predict(
            weights=train_weights, context=ctx_low,
            inputs=np.zeros((batch_size, context_len + h, 0)), horizon_length=h,
        )
        assert p.shape == (batch_size, h)

# --- Normalizer round trip.
normalizer = predictor._normalizer
series = np.random.rand(batch_size, 50) * 8.0
norm_series, params, _ = normalizer.normalize(series)
denorm_series = normalizer.denormalize(norm_series, params)
assert np.allclose(series, denorm_series, atol=1e-3)

# --- Serializer round trip on the normalised context.
serializer = predictor._serializer
norm_data, _, _ = normalizer.normalize(context)
serialized = serializer.serialize(norm_data)
assert serialized.shape == (batch_size, context_len * precision)
assert np.all(serialized >= 0) and np.all(serialized < vocab_size)
# Deserialize back into continuous values: `precision` tokens per value.
reshaped = np.reshape(serialized, (-1, precision))
deserialized = serializer.deserialize(reshaped)
deserialized = np.reshape(deserialized, (batch_size, context_len))
assert np.allclose(norm_data, deserialized, atol=0.01)

# --- Train model forward pass with the loss masked to the horizon.
train_model = predictor.make_train_eval_model(mode='train')
series_len = context_len + horizon
series = np.random.rand(batch_size, series_len) * 5.0
inputs_train = np.zeros((batch_size, series_len, 0))
target = series.copy()
mask = np.zeros((batch_size, series_len))
mask[:, -horizon:] = 1.0
sig = (
    shapes.ShapeDtype(series.shape, dtype=np.float32),
    shapes.ShapeDtype(inputs_train.shape, dtype=np.int32),
    shapes.ShapeDtype(target.shape, dtype=np.float32),
    shapes.ShapeDtype(mask.shape, dtype=np.float32),
)
train_model.init(sig)
output = train_model((series, inputs_train, target, mask))
logits, target_repr, weights = output
assert logits.shape[-1] == vocab_size
assert np.issubdtype(target_repr.dtype, np.integer)  # token ids
assert np.all(weights >= 0)

# --- Digit significance: expected per-digit weights are decay ** significance,
# so leading digits should weigh the most.
sig_map = serializer.significance_map
expected_weights_pattern = significance_decay ** sig_map
assert np.allclose(significance_decay ** np.arange(3), [1.0, 0.3, 0.09])
assert expected_weights_pattern[0] > expected_weights_pattern[1] > expected_weights_pattern[2]

# --- WeightedCrossEntropyLoss gives a finite, positive scalar.
loss_layer = predictor.make_loss()
loss_layer.init(
    (
        shapes.ShapeDtype(logits.shape, dtype=np.float32),
        shapes.ShapeDtype(target_repr.shape, dtype=target_repr.dtype),
        shapes.ShapeDtype(weights.shape, dtype=np.float32),
    )
)
loss = loss_layer((logits, target_repr, weights))
assert np.isscalar(loss) or loss.shape == ()
assert np.isfinite(float(loss))
assert float(loss) > 0

# --- Eval model.
eval_model = predictor.make_train_eval_model(mode='eval')
eval_model.init(sig)
output_eval = eval_model((series, inputs_train, target, mask))
logits_eval, target_repr_eval, weights_eval = output_eval
assert logits_eval.shape[-1] == vocab_size
assert logits_eval.shape == logits.shape
assert np.all(np.isfinite(logits_eval))
# NOTE(review): eval_model is initialised separately from train_model, so this
# inequality may reflect different initial weights rather than a train/eval
# mode difference; confirm before relying on it as a mode check.
assert not np.allclose(logits, logits_eval, atol=1e-3)

print("All tests passed")

## Training tests
These tests probe the trax training utilities. It is important to test the library's compatibility with the code because, for example, trainer.py essentially just uses the trax training loop to run training. This makes training somewhat monolithic, since everything is handled inside trax.

In [None]:
import numpy as np
import tempfile
import gin
from trax import layers as tl, shapes, optimizers
from trax.fastmath import numpy as jnp
from trax.supervised import training, lr_schedules
import models
from predictors.serial_predictor import SerialPredictor
import os
gin.clear_config()
np.random.seed(0)  # reproducible toy batches
vocab_size, d_model, precision, batch_size, series_len, horizon = 10, 64, 3, 2, 32, 8


def make_body(mode='eval', precision=3):
    """Build the small TransformerBody wrapped by the predictor, in the given trax mode."""
    return models.TransformerBody(
        d_model=d_model, d_ff_mul=2, n_layers=2, n_heads=2,
        max_len=2048, dropout=0.0, ff_activation=tl.FastGelu,
        precision=precision, mode=mode,
    )


def make_lr_schedule(warmup_steps):
    """LR schedule from the paper: constant 0.03 with linear warmup, then rsqrt decay."""
    return lr_schedules.multifactor(
        constant=0.03,
        factors='constant * linear_warmup * rsqrt_decay',
        warmup_steps=warmup_steps,
    )


def example_stream():
    """Endless stream of random (series, inputs, target, mask) batches with a full mask."""
    while True:
        s = np.random.rand(batch_size, series_len) * 5.0
        i = np.zeros((batch_size, series_len, 0))
        t = s.copy()
        m = np.ones((batch_size, series_len))
        yield (s, i, t, m)


def make_train_task(n_steps_per_checkpoint):
    """TrainTask over a fresh example stream, with Adam and the paper's schedule (100 warmup steps)."""
    return training.TrainTask(
        example_stream(),
        loss_layer=predictor.make_loss(),
        optimizer=optimizers.Adam(),
        lr_schedule=make_lr_schedule(warmup_steps=100),
        n_steps_per_checkpoint=n_steps_per_checkpoint,
    )


# tl.Accelerate should work here; training is slow, but that is not caused by Accelerate.
predictor = SerialPredictor(
    model_body_fn=make_body,
    d_in=d_model,
    vocab_size=vocab_size,
    precision=precision,
    significance_decay=0.3,
    low=0.0, high=10.0,
    accelerate_predict_model=False,
    normalization_regularizer=0.001)
train_model = predictor.make_train_eval_model(mode='train')
acc_model = tl.Accelerate(train_model)
series = np.random.rand(batch_size, series_len) * 5.0
inputs_train = np.zeros((batch_size, series_len, 0))
target = series.copy()
mask = np.ones((batch_size, series_len))
sig = (
    shapes.ShapeDtype(series.shape, dtype=np.float32),
    shapes.ShapeDtype(inputs_train.shape, dtype=np.int32),
    shapes.ShapeDtype(target.shape, dtype=np.float32),
    shapes.ShapeDtype(mask.shape, dtype=np.float32),
)
acc_model.init(sig)
out = acc_model((series, inputs_train, target, mask))
logits, target_repr, weights = out
assert logits.shape[-1] == vocab_size
assert np.all(np.isfinite(logits))

optimizer = optimizers.Adam()  # the trax optimizer should construct
assert optimizer is not None

# The paper's schedule: zero at step 0, rising during warmup, decaying after it.
schedule = make_lr_schedule(warmup_steps=1000)
assert schedule(0) >= 0
assert schedule(500) > schedule(0)
assert schedule(1000) >= schedule(500)
assert schedule(5000) < schedule(1000)
assert all(np.isfinite([schedule(0), schedule(500), schedule(1000), schedule(5000)]))

train_task = make_train_task(n_steps_per_checkpoint=50)
assert train_task is not None  # constructs correctly

# Mirrors the training method: a short Loop run in a throwaway output directory.
with tempfile.TemporaryDirectory() as tmpdir:
    train_model = tl.Accelerate(predictor.make_train_eval_model(mode='train'))
    eval_model = tl.Accelerate(predictor.make_train_eval_model(mode='eval'))

    eval_task = training.EvalTask(
        example_stream(),
        metrics=[predictor.make_loss()],
        metric_names=['loss'],
        n_eval_batches=2,
    )

    loop = training.Loop(
        model=train_model,
        tasks=[make_train_task(n_steps_per_checkpoint=10)],
        eval_model=eval_model,
        eval_tasks=[eval_task],
        output_dir=tmpdir,
        n_devices=1,
        checkpoint_at=lambda step: False,
        permanent_checkpoint_at=lambda step: False,
    )
    loop.run(5)  # make sure a few steps run end to end
    assert loop.step == 5
print("All tests passed")