# Trax Quick Intro

This notebook trains a small **Trax Transformer** language model (or Reformer) on a simple copy task and then runs inference with it.
* See how to create your inputs from Python.
* Learn how to run the trainer.
* Run fast inference with the trained Transformer.




## General Setup
Execute the following few cells (once) before running any of the code samples in this notebook.

In [0]:
# Copyright 2020 Google LLC.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np




In [0]:
# Import Trax

! pip install -q -U trax
! pip install -q tensorflow

import trax

# Transformer

In [3]:
# Construct inputs, see one batch
def copy_task(batch_size, vocab_size, length):
  """This task is to copy a random string w, so the input is 0w0w."""
  while True:
    assert length % 2 == 0
    w_length = (length // 2) - 1
    w = np.random.randint(low=1, high=vocab_size-1,
                          size=(batch_size, w_length))
    zero = np.zeros([batch_size, 1], np.int32)
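    # Weight 0 on the first half (0w) and 1 on the second half (0w),
    # matching the printed mask and the 80 weights per batch in the logs.
    loss_weights = np.concatenate([np.zeros((batch_size, w_length+1)),
                                   np.ones((batch_size, w_length+1))], axis=1)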
    x = np.concatenate([zero, w, zero, w], axis=1)
    yield (x, x, loss_weights)  # Here inputs and targets are the same.
copy_inputs = trax.supervised.Inputs(lambda _: copy_task(16, 32, 10))

# Peek into the inputs.
data_stream = copy_inputs.train_stream(1)
inputs, targets, mask = next(data_stream)
print("Inputs[0]:  %s" % str(inputs[0]))
print("Targets[0]: %s" % str(targets[0]))
print("Mask[0]:    %s" % str(mask[0]))

Inputs[0]:  [ 0  6 13 29 22  0  6 13 29 22]
Targets[0]: [ 0  6 13 29 22  0  6 13 29 22]
Mask[0]:    [0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
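
To make the printed batch easier to read, here is a small NumPy-only sketch that takes the values shown above and checks two things: that the sequence has the form 0w0w, and which target positions the mask actually scores. The variable names (`is_copy`, `scored_targets`, `n_scored`) are illustrative and not part of Trax.

```python
import numpy as np

# Values copied from the printed batch above.
targets = np.array([0, 6, 13, 29, 22, 0, 6, 13, 29, 22])
mask = np.array([0., 0., 0., 0., 0., 1., 1., 1., 1., 1.])

half = len(targets) // 2
# The sequence has the form 0w0w, so both halves are identical.
is_copy = np.array_equal(targets[:half], targets[half:])

# Only positions with weight 1 contribute to the loss and accuracy.
scored_targets = targets[mask.astype(bool)]
n_scored = int(mask.sum())

print(is_copy)
print(scored_targets)
print(n_scored)
```

With this mask, the model is scored only on the second half of the sequence, that is, on reproducing the separator and the copied string.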


In [4]:
# Transformer LM
def tiny_transformer_lm(mode):
  return trax.models.TransformerLM(   # You can try trax.models.ReformerLM too.
    d_model=32, d_ff=128, n_layers=2, vocab_size=32, mode=mode)

# Train tiny model with Trainer.
output_dir = os.path.expanduser('~/train_dir/')
!rm -f ~/train_dir/model.pkl  # Remove old model.
trainer = trax.supervised.Trainer(
    model=tiny_transformer_lm,
    loss_fn=trax.layers.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adafactor,  # Change optimizer params here.
    lr_schedule=trax.lr.MultifactorSchedule,  # Change lr schedule here.
    inputs=copy_inputs,
    output_dir=output_dir)

# Train for 3 epochs each consisting of 500 train batches, eval on 2 batches.
n_epochs = 3
train_steps = 500
eval_steps = 2
for _ in range(n_epochs):
  trainer.train_epoch(train_steps, eval_steps)


Step    500: Ran 500 train steps in 16.51 secs
Step    500: Evaluation
Step    500: train                   accuracy |  0.53125000
Step    500: train                       loss |  1.83887446
Step    500: train         neg_log_perplexity | -1.83887446
Step    500: train weights_per_batch_per_core |  80.00000000
Step    500: eval                    accuracy |  0.52500004
Step    500: eval                        loss |  1.92791247
Step    500: eval          neg_log_perplexity | -1.92791247
Step    500: eval  weights_per_batch_per_core |  80.00000000
Step    500: Finished evaluation

Step   1000: Ran 500 train steps in 2.54 secs
Step   1000: Evaluation
Step   1000: train                   accuracy |  1.00000000
Step   1000: train                       loss |  0.00707983
Step   1000: train         neg_log_perplexity | -0.00707983
Step   1000: train weights_per_batch_per_core |  80.00000000
Step   1000: eval                    accuracy |  1.00000000
Step   1000: eval                        
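
The logged metrics are related to each other: in the log above, `neg_log_perplexity` is the negative of `loss`, and a `weights_per_batch_per_core` of 80 matches 16 sequences times 5 weighted positions. The sketch below shows one way such weighted metrics can be computed in plain NumPy. It assumes the loss and accuracy are averaged over the weighted positions; `weighted_metrics` is a hypothetical helper, not Trax's implementation, and the logits are random rather than model outputs.

```python
import numpy as np

def weighted_metrics(logits, targets, weights):
    """Hypothetical helper: weighted mean cross-entropy and accuracy."""
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Log-probability assigned to each target token.
    target_log_probs = np.take_along_axis(
        log_probs, targets[..., None], axis=-1)[..., 0]
    total_weight = weights.sum()
    loss = -(target_log_probs * weights).sum() / total_weight
    correct = (logits.argmax(axis=-1) == targets).astype(np.float64)
    accuracy = (correct * weights).sum() / total_weight
    return loss, accuracy, total_weight

rng = np.random.default_rng(0)
batch_size, length, vocab_size = 16, 10, 32
logits = rng.normal(size=(batch_size, length, vocab_size))  # Random stand-in.
targets = rng.integers(1, vocab_size - 1, size=(batch_size, length))
# Same mask layout as the printed batch: 5 zeros, then 5 ones.
weights = np.tile(np.array([0.] * 5 + [1.] * 5), (batch_size, 1))

loss, accuracy, total_weight = weighted_metrics(logits, targets, weights)
print(total_weight)   # Number of scored positions in the batch.
print(-loss)          # Same sign convention as neg_log_perplexity in the log.
```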

In [9]:
# Initialize model for inference.
predict_model = tiny_transformer_lm(mode='predict')
predict_signature = trax.shapes.ShapeDtype((1,1), dtype=np.int32)
predict_model.init(predict_signature)
predict_model.init_from_file(os.path.join(output_dir, "model.pkl"),
                             weights_only=True)
# You can also do: predict_model.weights = trainer.model_weights

# Run inference
prefix = [0, 1, 2, 3, 4, 0]   # Change non-0 digits to see if it's copying
cur_input = np.array([[0]])
result = []
for i in range(10):
  logits = predict_model(cur_input)
  next_input = np.argmax(logits[0, 0, :], axis=-1)
  if i < len(prefix) - 1:
    next_input = prefix[i]  # Force the prefix; afterwards keep the model's prediction.
  cur_input = np.array([[next_input]])
  result.append(int(next_input))  # Append to the result
print(result)

[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
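
The control flow of the decoding loop above can be tried without a trained model. The sketch below wraps the same loop in a function and runs it against a stand-in: `FakeCopyModel` and `greedy_decode` are hypothetical names, and the stand-in hard-codes the copy behaviour and returns one-hot logits of shape `(1, 1, vocab_size)`. It illustrates only how the prefix is forced and how predictions are fed back, not how the Transformer computes them.

```python
import numpy as np

class FakeCopyModel:
    """Hypothetical stand-in for predict_model; remembers the tokens it was fed."""

    def __init__(self, vocab_size=32, w_length=4):
        self.vocab_size = vocab_size
        self.period = w_length + 1  # Length of one "0w" block.
        self.history = []

    def __call__(self, token):
        self.history.append(int(token[0, 0]))
        position = len(self.history) - 1  # Sequence position being predicted.
        if position >= self.period:
            # Copy the token from one block earlier (history[0] is the start token).
            predicted = self.history[position - self.period + 1]
        else:
            predicted = 0
        logits = np.zeros((1, 1, self.vocab_size))
        logits[0, 0, predicted] = 1.0
        return logits

def greedy_decode(model, prefix, n_steps):
    cur_input = np.array([[0]])  # Start token, as in the cell above.
    result = []
    for i in range(n_steps):
        logits = model(cur_input)
        next_token = int(np.argmax(logits[0, 0, :]))
        if i < len(prefix) - 1:
            next_token = prefix[i]  # Force the prefix except its last token.
        cur_input = np.array([[next_token]])
        result.append(next_token)
    return result

result = greedy_decode(FakeCopyModel(), [0, 1, 2, 3, 4, 0], 10)
print(result)
```

Because the last prefix token is not forced, the separator `0` at position 5 must come from the model itself, which is why the mask has to score that position during training.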
