In [None]:
! pip install trax

In [None]:
import os
import shutil

import trax
from trax import data
from trax import layers as tl
from trax.supervised import training
from trax.fastmath import numpy as np

## The `trax` deep learning framework

Reference:
- https://github.com/google/trax 

`trax` comes from the Google Brain team and is the latest iteration after almost a decade of work on TensorFlow, machine translation, and Tensor2Tensor. As a newcomer in a somewhat crowded space (`keras`, `pytorch`, `thinc`), it has been able to learn from the mistakes and the best practices of those APIs.
In particular:
- it is very concise
- it runs on a `JAX` backend by default and can alternatively use a `TensorFlow` (TensorFlow NumPy) backend
- it uses `JAX` to speed up tensor-based computation (as an accelerated stand-in for `numpy`)

## The AG News dataset

A great dataset to look into text classification. https://www.tensorflow.org/datasets/catalog/ag_news_subset

- 0 is "World News"
- 1 is "Sports News"
- 2 is "Business News"
- 3 is "Science-Technology News"

# Dataset
I'm not using the Kaggle dataset here, but rather the TensorFlow dataset as it is more convenient with `trax`.
`trax` needs generators of data. Each element is a tuple `(input, target)` or `(input, target, weight)`; the weight is usually 1 because all examples have the same importance.

In [None]:
# AG News from the TensorFlow Datasets catalog:
# https://www.tensorflow.org/datasets/catalog/ag_news_subset
# Note: on Kaggle this cell has to be run twice.
#
# Only the 'description' and 'label' fields of each record are requested, so
# every stream element is a (description, label) pair. The trailing () call
# turns the configured TFDS source into the actual example stream.
example_fields = ('description', 'label')
train_stream = data.TFDS('ag_news_subset', keys=example_fields, train=True)()
eval_stream = data.TFDS('ag_news_subset', keys=example_fields, train=False)()

In [None]:
# Peek at one raw example. next() advances the stream, so this example is
# consumed and will not be seen again by later cells.
first_example = next(train_stream)
print(first_example)

In [None]:
# Bucketing configuration. There is one more batch size than there are length
# boundaries; presumably the last size applies to sequences longer than the
# final boundary (confirm against the trax BucketByLength docs).
length_boundaries = [32, 128, 512, 2048]
batch_size_per_bucket = [512, 128, 32, 8, 1]

# Preprocessing shared by the training and evaluation streams. Field 0 of each
# example is the text, field 1 the label.
data_pipeline = data.Serial(
    # Turn the text (field 0) into subword ids using the en_8k vocabulary.
    data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
    data.Shuffle(),
    # Drop examples whose tokenized text exceeds 2048 tokens.
    data.FilterByLength(max_length=2048, length_keys=[0]),
    # Batch examples of similar length together.
    data.BucketByLength(
        boundaries=length_boundaries,
        batch_sizes=batch_size_per_bucket,
        length_keys=[0],
    ),
    # Append a weights field so each batch is (inputs, targets, weights).
    data.AddLossWeights(),
)

train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)

# Sanity check: pull one training batch and show the shape of each field.
example_batch = next(train_batches_stream)
print(f'shapes = {[x.shape for x in example_batch]}')

# Model
`trax` is really concise: you can build a model directly from its library of ready-made layers.

In [None]:
# A small bag-of-embeddings classifier, listed layer by layer.
classifier_layers = [
    # Token ids -> 50-dimensional vectors. 8192 presumably matches the size of
    # the en_8k.subword vocabulary used by the tokenizer.
    tl.Embedding(vocab_size=8192, d_feature=50),
    # Average over axis 1 (the token axis for batched inputs), giving one
    # vector per example.
    tl.Mean(axis=1),
    # One output per AG News class (4 classes).
    tl.Dense(4),
    # Convert the outputs to log-probabilities.
    tl.LogSoftmax(),
]
model = tl.Serial(*classifier_layers)
model

# Training
For training, there is the concept of a "task", which wraps the data, the optimiser, the metrics, etc.

In [None]:
# Training task: bundles the training batches with the loss and the optimizer.
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=500,
)

# Evaluation task: metrics computed on batches from the evaluation stream.
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
    n_eval_batches=20,  # For less variance in eval numbers.
)

# Training loop saves checkpoints to output_dir; the directory is removed
# first so each run starts from a clean slate.
#
# shutil.rmtree replaces the former `!rm -rf {output_dir}`: the shell version
# interpolated the path unquoted, so a home directory containing spaces would
# have made `rm -rf` act on the wrong paths, and it only worked where `rm`
# exists. ignore_errors=True keeps the "missing directory is fine" behaviour.
output_dir = os.path.expanduser('~/output-dir/')
shutil.rmtree(output_dir, ignore_errors=True)
training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)

# Run 2000 steps (batches).
training_loop.run(2000)

# Look at predictions

In [None]:
# Take one batch from the evaluation stream and split it into its three
# fields. This advances the stream by one batch.
eval_batch = next(eval_batches_stream)
inputs, targets, weights = eval_batch

In [None]:
# Inspect the first example of the evaluation batch fetched above.
example_input = inputs[0]
expected_class = targets[0]
example_input_str = data.detokenize(example_input, vocab_file='en_8k.subword')
print(f'example input_str: {example_input_str}')

# The model ends in LogSoftmax, so it returns log-probabilities over the four
# AG News classes (0 World, 1 Sports, 2 Business, 3 Science-Technology);
# np.exp turns them back into probabilities.
class_log_probs = model(example_input[None, :])  # Add batch dimension.
# Previous (misleading) name, kept as an alias in case later cells use it:
# this is topic classification, not sentiment analysis.
sentiment_log_probs = class_log_probs
# Index 0 drops the batch dimension added above; argmax picks the top class.
predicted_class = int(np.argmax(class_log_probs[0]))
print(f'Model returned class probabilities: {np.exp(class_log_probs)}')
print(f'Predicted class: {predicted_class}')
print(f'Expected class: {expected_class}')