##### Copyright 2019 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Trax: Train Models in JAX

[JAX](https://github.com/google/jax) lets you write [NumPy](https://www.numpy.org/) code and run it fast on accelerators.

This makes ML research more *fun* and *clear*, so we built [Trax](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/trax), a library of models in JAX.

In this demo we show how to:
* Train a Trax model on a toy copy problem.
* Decode from a pre-trained [Transformer](https://arxiv.org/abs/1706.03762) language model.
* Define a [Transformer](https://arxiv.org/abs/1706.03762) from scratch in Trax.
* Do research in Trax: play with hard attention to see how it impacts training and results.

We would like your feedback!
* What are the parts you like or dislike in JAX and Trax?
* Will you start doing your research in Trax? If not, why? What would change your mind?
* What should we focus on? Speed, cleanliness, memory use?
* If you cannot tell us in person, please add your feedback on [this GitHub issue](https://github.com/tensorflow/tensor2tensor/issues/1478).


## Installs

We install JAX and Trax (which ships inside Tensor2Tensor at these versions), then download a pretrained language-model checkpoint and its vocabulary file.

In [2]:
# Install JAX for GPU and Tensor2Tensor.
!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda100/jaxlib-0.1.14-cp36-none-linux_x86_64.whl
!pip install --upgrade -q jax==0.1.27
!pip install --upgrade -q tensor2tensor==1.13.4
# Grab language-model checkpoint and vocab file.
!rm -f model.pkl
!wget https://storage.googleapis.com/traxdemo/model.pkl
!wget https://storage.googleapis.com/traxdemo/vocab.lm1b.en.32768
# Show GPU type.
!nvidia-smi -L

--2019-05-14 22:57:21--  https://storage.googleapis.com/traxdemo/model.pkl
Resolving storage.googleapis.com (storage.googleapis.com)... 209.85.234.128, 2607:f8b0:4001:c12::80
Connecting to storage.googleapis.com (storage.googleapis.com)|209.85.234.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 211170062 (201M) [application/octet-stream]
Saving to: ‘model.pkl’



## Imports

In [3]:
from six.moves import cPickle
import os
import datetime
import random

import numpy as onp
from matplotlib import pyplot as plt

from jax.ops import index, index_update

from tensor2tensor.trax import trax
from tensor2tensor.trax import layers as tl
from tensor2tensor.trax import inputs as trax_input
from tensor2tensor.trax import models as trax_models
from tensor2tensor.trax import optimizers as trax_optimizers
from tensor2tensor.trax import backend
from tensor2tensor.trax.backend import numpy as np
from tensor2tensor.trax.backend import random as trax_random





# Toy Copy Problem

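Here we define batched random integer inputs for a trivial sequence-copy task. Every example has the form 0w0w, where w is a random string of nonzero tokens and 0 acts as a separator, so a language model can only predict the second copy of w by attending back to the first.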

In [0]:
VOCAB_SIZE = 128
def toy_problem_inputs(num_devices, batch_size=64,
                       train_lengths=(10, 20), eval_lengths=(20,)):
  """Make Inputs for the toy problem of the language 0w0w for w in [1..127]*.

  Args:
    num_devices: how many devices to build the inputs for (must be 1 in colab).
    batch_size: number of examples per batch.
    train_lengths: candidate lengths for training; a batch drawn with length L
      uses len(w) = L // 2, so each sequence has 2 * (L // 2) + 2 tokens.
    eval_lengths: same as train_lengths, for eval.

  Returns:
    trax.inputs.Inputs
  """
  assert num_devices == 1
  def random_minibatches(length_list):
    """Generate a stream of random mini-batches."""
    while True:
      length = random.choice(length_list)
      # `high` is exclusive, so tokens lie in [1, VOCAB_SIZE - 1]; 0 is reserved
      # as the separator.
      w = onp.random.randint(low=1, high=VOCAB_SIZE,
                             size=(batch_size, length // 2))
      zero = onp.zeros([batch_size, 1], onp.int32)
      x = onp.concatenate([zero, w, zero, w], axis=1)
      yield (x, x)  # In a language model input and output are the same.

  return trax_input.Inputs(
      train_stream=lambda: random_minibatches(train_lengths),
      train_eval_stream=lambda: random_minibatches(train_lengths),
      eval_stream=lambda: random_minibatches(eval_lengths),
      input_shape=(None,))

In [5]:
inputs = toy_problem_inputs(1)
print(next(inputs.train_stream())[0][0])

[  0  68  91  99 107 115 113 111  17 102  48   0  68  91  99 107 115 113
 111  17 102  48]
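To make the copy structure explicit, here is a standalone NumPy sketch that builds one batch the same way `random_minibatches` does and checks that the second half of every sequence repeats the first. The helper name `make_copy_batch` is hypothetical and not part of Trax.

```python
import numpy as onp

def make_copy_batch(batch_size, length, vocab_size=128, seed=0):
    """Hypothetical helper: one batch of 0 w 0 w sequences with len(w) = length // 2."""
    rs = onp.random.RandomState(seed)
    # Tokens of w are drawn from [1, vocab_size - 1]; 0 is kept as the separator.
    w = rs.randint(1, vocab_size, size=(batch_size, length // 2))
    zero = onp.zeros((batch_size, 1), dtype=w.dtype)
    return onp.concatenate([zero, w, zero, w], axis=1)

x_copy = make_copy_batch(batch_size=4, length=20)
half = x_copy.shape[1] // 2
print(x_copy.shape)                                        # (4, 22)
print(onp.array_equal(x_copy[:, :half], x_copy[:, half:]))  # True
```

Only the tokens after the second 0 are predictable from context, which is what makes the task a test of attention rather than of memorizing token statistics.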


## Baseline Transformer on Toy Problem

In [5]:
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M")
output_dir = os.path.expanduser("~/trax_lm_%s" % timestamp)
def model(mode):
  return trax_models.TransformerLM(
      VOCAB_SIZE, feature_depth=128,
      feedforward_depth=256, num_layers=3,
      num_heads=4, mode=mode)
_ = trax.train(model=model,
               inputs=toy_problem_inputs,
               output_dir=output_dir,
               train_steps=3000,
               eval_steps=10,
               eval_frequency=1000)

Step      0: Starting training using 1 devices

Step      1: Ran 1 train steps in 36.77 secs
Step      1: Total trainable parameters size: 692736
Step      1: Evaluation
Step      1: train           accuracy |  0.00616714
Step      1: train neg_log_perplexity | -5.06836748
Step      1: train               loss |  5.06836748
Step      1: eval            accuracy |  0.00610795
Step      1: eval  neg_log_perplexity | -5.20451212
Step      1: eval                loss |  5.20451212
Step      1: Finished evaluation

Step   1000: Ran 999 train steps in 89.13 secs
Step   1000: Evaluation
Step   1000: train           accuracy |  0.45719695
Step   1000: train neg_log_perplexity | -2.71764731
Step   1000: train               loss |  2.71764731
Step   1000: eval            accuracy |  0.41278410
Step   1000: eval  neg_log_perplexity | -2.94052887
Step   1000: eval                loss |  2.94052887
Step   1000: Finished evaluation

Step   2000: Ran 1000 train steps in 15.61 secs
Step   2000: Evalua

# Decoding from a Pre-Trained Transformer Language Model

In [6]:
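# Load the model checkpoint: parameters, training step and training history.
with open("model.pkl", "rb") as f:
  (params, step, history) = cPickle.load(f, encoding="latin1")

# lm1b subword vocab: each line holds one quoted subword, e.g. 'the_'.
def clean(x):
  return x[1:-2]  # Strip the leading quote and the trailing quote + newline.
with open("vocab.lm1b.en.32768", "r") as fp:
  vocab = list(map(clean, fp.readlines()))
vocab_map = {v: idx for idx, v in enumerate(vocab)}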

list(enumerate(vocab))[:10]

[(0, '<pad>_'),
 (1, '<EOS>_'),
 (2, 'the_'),
 (3, ' , _'),
 (4, ' ._'),
 (5, 'to_'),
 (6, 'of_'),
 (7, 'a_'),
 (8, 'and_'),
 (9, 'in_')]
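The `clean` function and the two lookup tables can be exercised without the real file. This sketch assumes, as the output above suggests, that each vocab line is a subword wrapped in single quotes; the in-memory file and the `toy_` names are illustrative only, chosen so they do not overwrite the notebook's `vocab` and `vocab_map`.

```python
import io

# Stand-in for the first lines of the vocab file (assumed format: one quoted subword per line).
fake_vocab_file = io.StringIO("'<pad>_'\n'<EOS>_'\n'the_'\n'to_'\n")

def clean_line(line):
    # Drop the newline first, then the surrounding quotes, so a final line
    # without a trailing newline is handled too.
    return line.rstrip("\n")[1:-1]

toy_vocab = [clean_line(line) for line in fake_vocab_file]
toy_vocab_map = {tok: idx for idx, tok in enumerate(toy_vocab)}

ids = [toy_vocab_map[t] for t in ["the_", "to_", "<EOS>_"]]
print(ids)                                  # [2, 3, 1]
print("".join(toy_vocab[i] for i in ids))   # the_to_<EOS>_
```

Ids map to subwords by list position, and subwords map back to ids through the dictionary; decoding below uses the same two tables.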

In [0]:
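# These hyperparameters must match the pretrained checkpoint in model.pkl;
# mode='eval' turns dropout off.
tlm = trax_models.TransformerLM(
  dropout=0.1,
  feature_depth=512,
  feedforward_depth=2048,
  max_len=2048,
  mode='eval',
  num_heads=8,
  num_layers=6,
  vocab_size=32000)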

In [0]:
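def gumbel_sample(v, temperature=0.8):
  """Sample an index from log-probabilities v with the Gumbel-max trick."""
  u = onp.random.uniform(low=1e-9, high=1.0, size=v.shape)  # Avoid log(0).
  g = -onp.log(-onp.log(u))  # Standard Gumbel noise.
  # argmax(v + T * g) == argmax(v / T + g), i.e. a draw from softmax(v / T).
  return np.argmax(v + g * temperature)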
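A quick way to convince yourself that `gumbel_sample` really samples: for temperature T > 0, argmax(v + T * g) picks the same index as argmax(v / T + g), and the Gumbel-max trick says the latter is distributed as softmax(v / T). The NumPy sketch below compares empirical frequencies with that softmax; the toy logits and the helper name `gumbel_max_sample` are illustrative assumptions.

```python
import numpy as onp

def gumbel_max_sample(logits, temperature, rs):
    """Draw one index as argmax(logits + temperature * Gumbel noise)."""
    u = rs.uniform(low=1e-9, high=1.0, size=logits.shape)
    g = -onp.log(-onp.log(u))  # standard Gumbel noise
    return int(onp.argmax(logits + temperature * g))

logits = onp.log(onp.array([0.5, 0.3, 0.15, 0.05]))
temperature = 0.8
rs = onp.random.RandomState(0)
n = 20000
counts = onp.bincount([gumbel_max_sample(logits, temperature, rs) for _ in range(n)],
                      minlength=len(logits))

# Target distribution: softmax(logits / temperature).
scaled = logits / temperature
expected = onp.exp(scaled - scaled.max())
expected /= expected.sum()
print(onp.round(counts / n, 3))
print(onp.round(expected, 3))
```

Lowering the temperature sharpens the distribution toward the most likely token; the notebook's default of 0.8 is a mild sharpening.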

In [10]:
prompt = "Please_"
num_samples = 5
max_length = 20
for _ in range(num_samples):
  enc = [vocab_map[w] for w in prompt.split()]
  pos = len(enc)
  rng = trax_random.get_prng(0)
  data = np.zeros((1, 50), dtype=np.int32)
  data = index_update(data, index[0, 0:pos], enc)  # Write the prompt into row 0.

  while pos < max_length:
    # TransformerLM shifts its input right, so output position `pos` holds the
    # prediction for token `pos` given the tokens before it.
    tmp = tlm(data, params=params, rng=rng)
    next_sym = gumbel_sample(tmp[0, pos])
    data = index_update(data, index[0, pos], next_sym)
    pos += 1
    if int(next_sym) == 1:  # Id 1 is <EOS>.
      break

  print("".join([vocab[idx] for idx in onp.array(data)[0, 0:pos]]))

Please_write_to_him_to_tell_him_about_the_Wallace_and_Gromit_films_. _and_to_give_him_this_
Please_do_not_turn_to_making_sure_your_children_are_already_in_school_or_that_you_have_school_ ._<EOS>_
Please_read_the_full_prospectus_to_see_if_the_proposed_transaction_may_be_accurate_ ._<EOS>_
Please_note_that_the_new_policy_has_been_strengthened_by_the_fact_that_Britney_Spears_ ' _mother_ , _Janet_Jackson_
Please_ , _please_aim_at_your_brother_ , _if_you_want_to_ ._<EOS>_
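The decoding loop above relies on one alignment detail: because the model shifts its input right, the output at position `pos` only depends on tokens before `pos`, so the whole prefix is simply re-run at every step (no caching). The sketch below reproduces that loop in plain NumPy with a hypothetical bigram model `toy_lm` standing in for the Transformer, and uses greedy argmax instead of `gumbel_sample` so that the result is deterministic.

```python
import numpy as onp

EOS_ID = 1  # as in the notebook's vocab, id 1 is <EOS>

def toy_lm(tokens, next_table):
    """Hypothetical stand-in for `tlm`: log-probs of shape (batch, length, vocab).

    Like a model that starts with ShiftRight, output position i only sees
    tokens[:, :i]; this bigram model looks at the previous token only.
    """
    shifted = onp.concatenate([onp.zeros_like(tokens[:, :1]), tokens[:, :-1]], axis=1)
    vocab_size = next_table.shape[0]
    logp = onp.full(tokens.shape + (vocab_size,), -1e9)
    # Put all probability mass on next_table[previous token].
    onp.put_along_axis(logp, next_table[shifted][..., None], 0.0, axis=-1)
    return logp

def greedy_decode(prompt_ids, next_table, max_length=10, width=16):
    data = onp.zeros((1, width), dtype=onp.int32)
    pos = len(prompt_ids)
    data[0, :pos] = prompt_ids              # row 0: the batch holds a single example
    while pos < max_length:
        logp = toy_lm(data, next_table)     # re-run the whole prefix every step
        next_sym = int(onp.argmax(logp[0, pos]))
        data[0, pos] = next_sym
        pos += 1
        if next_sym == EOS_ID:
            break
    return data[0, :pos].tolist()

# Bigram table: 0->2, 1->1, 2->3, 3->4, 4->5, 5->EOS.
next_table = onp.array([2, 1, 3, 4, 5, 1])
print(greedy_decode([2], next_table))  # [2, 3, 4, 5, 1]
```

Re-running the full prefix costs one forward pass per generated token over an ever-longer sequence; that is simple and correct, at the price of speed.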


# Transformer from Scratch

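Here we re-implement multi-headed self-attention and a Transformer language model from scratch, using only a few simple building blocks from Trax's layers library (such as Dense, LayerNorm, Dropout and the Serial, Parallel and Branch combinators).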
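At its core, the attention code below computes scaled dot products, masks out future positions by replacing their scores with -1e9, and applies a softmax over the keys. The following single-head NumPy sketch shows that computation in isolation. Building the causal mask with `onp.tril` is an assumption made for the sketch; the notebook obtains its mask from `tl.CausalMask` and computes the softmax through `logsumexp`.

```python
import numpy as onp

def causal_attention(q, k, v):
    """Single-head scaled dot-product attention with a causal mask.

    q, k, v: arrays of shape (seq_len, depth).
    """
    depth = q.shape[-1]
    scores = q @ k.T / onp.sqrt(depth)
    seq_len = q.shape[0]
    mask = onp.tril(onp.ones((seq_len, seq_len))) > 0  # position i may see j <= i
    scores = onp.where(mask, scores, -1e9)
    # Numerically stable softmax over the key axis.
    weights = onp.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

rs = onp.random.RandomState(0)
x_demo = rs.normal(size=(5, 8))
attn_out, attn_weights = causal_attention(x_demo, x_demo, x_demo)  # q = k = v
print(onp.round(attn_weights, 2))  # everything above the diagonal is 0
```

The first position can only attend to itself, so its output equals its own value vector.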

Note in particular the commented modification inside the core __DotProductAttention__ function: it shows how easily layers and models can be modified for research in Trax.
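The hard-attention change can be tried on a single row of attention weights. This NumPy sketch copies the same three steps (subtract the k-th largest weight, clip at 0, renormalize); the function name `hard_attention` and the example row are illustrative. One consequence is easy to miss: the k-th largest weight is itself zeroed, so when the top weights are distinct only hard_k - 1 of them survive.

```python
import numpy as onp

def hard_attention(weights, hard_k):
    """Apply the notebook's top-k trick to the last axis of `weights`."""
    kth = onp.sort(weights, axis=-1)[..., -hard_k]     # k-th largest weight per row
    out = onp.maximum(weights - kth[..., None], 0.0)   # k-th and smaller entries -> 0
    return out / out.sum(axis=-1, keepdims=True)       # renormalize to sum to 1

row = onp.array([0.40, 0.25, 0.15, 0.10, 0.06, 0.04])
res = hard_attention(row, hard_k=4)
print(onp.round(res, 3))  # [0.6 0.3 0.1 0.  0.  0. ]
```

To keep exactly k weights, one could subtract the (k+1)-th largest value instead. Also note that if the largest hard_k weights are all tied (for instance a uniform row), every entry becomes 0 and the division yields NaN; adding a small epsilon to the denominator would guard against that.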

In [0]:
def DotProductAttention(query, key, value, mask, dropout, mode, rng, hard_k=4):
  """Core dot product self-attention.
  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use
    hard_k: int: if > 0, keep only the largest attention weights (see below)
  Returns:
    Self attention for q, k, v arrays.
  """
  depth = np.shape(query)[-1]
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    dots = np.where(mask, dots, -1e9)
  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
  # ----------------------------------------------------------------------
  # As an example of a simple research modification, we modify the typical
  # dot-product attention mechanism with top-k "hard attention":
  # ----------------------------------------------------------------------
  if hard_k > 0:
    top_k = np.sort(dots)[..., -hard_k]  # Get the k-th largest weight.
    # Subtract it: the k-th and all smaller weights become <= 0, so at most
    # hard_k - 1 weights stay positive after clipping at 0.
    dots -= top_k[..., np.newaxis]
    dots = np.maximum(dots, 0)
    dots /= np.sum(dots, axis=-1, keepdims=True)  # Re-normalize.
  # ----------------------------------------------------------------------
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = np.where(keep, dots / (1.0 - dropout), 0)
  out = np.matmul(dots, value)
  # Uncomment to see an example Trax stack trace to this point:
  # ----------------------------------------------------------------------
  # raise ValueError("err")
  # ----------------------------------------------------------------------
  return out


def _multihead_attention_output_shape(  # pylint: disable=invalid-name
    input_shapes, **unused_kwargs):
  """Helper: calculate multihead attention output shape."""
  q_shape = input_shapes[0][0]  # Inputs are ((q, k, v), mask).
  mask_shape = input_shapes[1]
  return q_shape, mask_shape


@tl.layer(output_shape=_multihead_attention_output_shape)
def PureMultiHeadedAttention(x, params, num_heads=8, dropout=0.0,
                             mode='train', **kwargs):
  """Pure transformer-style multi-headed attention.
  Args:
    x: inputs ((q, k, v), mask)
    params: parameters (none)
    num_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'
    **kwargs: other arguments including the rng
  Returns:
    Pure Multi-headed attention result, and the mask.
  """
  del params
  rng = kwargs.get('rng', None)
  (q, k, v), mask = x
  feature_depth = q.shape[-1]
  assert feature_depth % num_heads == 0
  head_depth = feature_depth // num_heads
  nbatch = np.shape(q)[0]
  # nbatch, seqlen, feature_depth --> nbatch, num_heads, seqlen, head_depth
  def SplitHeads(x):
    return np.transpose(
        np.reshape(x, (nbatch, -1, num_heads, head_depth)), (0, 2, 1, 3))
  # nbatch, num_heads, seqlen, head_depth --> nbatch, seqlen, feature_depth
  def JoinHeads(x):  # pylint: disable=invalid-name
    return np.reshape(
        np.transpose(x, (0, 2, 1, 3)), (nbatch, -1, num_heads*head_depth))
  # Split heads, dot-product attention, rejoin heads.
  res = JoinHeads(
      DotProductAttention(
          SplitHeads(q), SplitHeads(k), SplitHeads(v), mask,
          dropout=dropout, mode=mode, rng=rng))
  return res, mask  # Keep the mask.


def MultiHeadedAttentionQKV(
    feature_depth, num_heads=8, dropout=0.0, mode='train'):
  """Transformer-style multi-headed attention.
  Accepts inputs of the form (q, k, v), mask.
  Args:
    feature_depth: int:  depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'
  Returns:
    Multi-headed self-attention result and the mask.
  """
  return tl.Serial(
      tl.Parallel(
          tl.Parallel(
              tl.Dense(feature_depth),
              tl.Dense(feature_depth),
              tl.Dense(feature_depth),
          ),
          tl.Copy()
      ),
      PureMultiHeadedAttention(  # pylint: disable=no-value-for-parameter
          feature_depth=feature_depth, num_heads=num_heads,
          dropout=dropout, mode=mode),
      tl.Parallel(tl.Dense(feature_depth), tl.Copy())
  )


def MultiHeadedAttention(
    feature_depth, num_heads=8, dropout=0.0, mode='train'):
  """Transformer-style multi-headed attention.
  Accepts inputs of the form (x, mask) and constructs (q, k, v) from x.
  Args:
    feature_depth: int:  depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'
  Returns:
    Multi-headed self-attention layer.
  """
  return tl.Serial(
      tl.Parallel(
          # q = k = v = first input
          tl.Branch(
              tl.Copy(), tl.Copy(), tl.Copy()),
          tl.Copy()  # pass the mask
      ),
      MultiHeadedAttentionQKV(  # pylint: disable=no-value-for-parameter
          feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
  )
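SplitHeads and JoinHeads inside PureMultiHeadedAttention are pure reshapes and transposes. The NumPy sketch below repeats them with the toy model's sizes (feature_depth=128, num_heads=4, so head_depth=32) to show that head h owns the slice [h * head_depth, (h + 1) * head_depth) of the feature axis and that joining undoes splitting. The function names mirror the notebook's helpers but are standalone reimplementations.

```python
import numpy as onp

def split_heads(x, num_heads):
    """(batch, seq_len, depth) -> (batch, num_heads, seq_len, depth // num_heads)."""
    batch, seq_len, depth = x.shape
    head_depth = depth // num_heads
    return onp.transpose(x.reshape(batch, seq_len, num_heads, head_depth), (0, 2, 1, 3))

def join_heads(x):
    """(batch, num_heads, seq_len, head_depth) -> (batch, seq_len, num_heads * head_depth)."""
    batch, num_heads, seq_len, head_depth = x.shape
    return onp.transpose(x, (0, 2, 1, 3)).reshape(batch, seq_len, num_heads * head_depth)

x_heads_demo = onp.arange(2 * 5 * 128).reshape(2, 5, 128)
heads = split_heads(x_heads_demo, num_heads=4)
print(heads.shape)                                        # (2, 4, 5, 32)
print(onp.array_equal(join_heads(heads), x_heads_demo))   # True
```

Because every head attends over its own 32-dimensional slice, the heads cost about as much as a single 128-dimensional attention while learning different attention patterns.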

In [0]:
def ResidualFeedForward(feature_depth,
                        feedforward_depth,
                        dropout,
                        mode):
  """Residual feed-forward layer with normalization at start."""
  return tl.Residual(
      tl.LayerNorm(),
      tl.Dense(feedforward_depth),
      tl.Relu(),
      tl.Dropout(rate=dropout, mode=mode),
      tl.Dense(feature_depth),
      tl.Dropout(rate=dropout, mode=mode)
  )


def DecoderLayer(feature_depth,
                 feedforward_depth,
                 num_heads,
                 dropout,
                 mode):
  """Transformer decoder layer.
  Args:
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'
  Returns:
    the layer.
  """
  return tl.Serial(
      tl.Residual(  # Self-attention block.
          tl.LayerNorm(),
          tl.Branch(tl.Copy(), tl.CausalMask(axis=-2)),  # Create mask.
          # We replace the "stock" self-attention layer with the one defined
          # above:
          # tl.MultiHeadedAttention(feature_depth, num_heads=num_heads,
          #                         dropout=dropout, mode=mode),
          MultiHeadedAttention(feature_depth, num_heads=num_heads,
                               dropout=dropout, mode=mode),
          tl.Select(0),  # Drop the mask.
          tl.Dropout(rate=dropout, mode=mode)
      ),
      ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode=mode)
  )


def TransformerLM(vocab_size,
                  feature_depth=512,
                  feedforward_depth=2048,
                  num_layers=6,
                  num_heads=8,
                  dropout=0.1,
                  max_len=2048,
                  mode='train'):
  """Transformer language model (only uses the decoder part of Transformer).
  Args:
    vocab_size: int: vocab size
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_layers: int: number of decoder layers
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'
  Returns:
    the layer.
  """
  return tl.Serial(
      tl.ShiftRight(),
      tl.Embedding(feature_depth, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      tl.PositionalEncoding(max_len=max_len),
      tl.Serial(*[DecoderLayer(feature_depth, feedforward_depth, num_heads,
                               dropout, mode)
                  for _ in range(num_layers)]),
      tl.LayerNorm(),
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )
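Both training logs in this notebook report 692736 trainable parameters for the toy configuration (VOCAB_SIZE=128, feature_depth=128, feedforward_depth=256, num_layers=3). A back-of-the-envelope count reproduces that number under three assumptions: every Dense layer has a bias, every LayerNorm has a scale and a bias, and the 2048 x 128 table of tl.PositionalEncoding is included in the count. Without that table the total would be 430592.

```python
def dense_params(d_in, d_out):
    return d_in * d_out + d_out  # weight matrix + bias

def transformer_lm_params(vocab_size, d, ff, num_layers, max_len):
    layer_norm = 2 * d                                        # scale + bias
    attention = 3 * dense_params(d, d) + dense_params(d, d)   # q, k, v + output
    feed_forward = dense_params(d, ff) + dense_params(ff, d)
    per_layer = 2 * layer_norm + attention + feed_forward
    embedding = vocab_size * d
    positional = max_len * d  # assumption: the positional table is counted
    head = layer_norm + dense_params(d, vocab_size)          # final LayerNorm + Dense
    return embedding + positional + num_layers * per_layer + head

toy_param_count = transformer_lm_params(vocab_size=128, d=128, ff=256,
                                        num_layers=3, max_len=2048)
print(toy_param_count)  # 692736
```

The match also suggests that the from-scratch model has the same layer structure as the stock TransformerLM it replaces.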

In [22]:
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M")
output_dir = os.path.expanduser("~/trax_lm_%s" % timestamp)
def new_model(mode):
  return TransformerLM(
      VOCAB_SIZE, feature_depth=128,
      feedforward_depth=256, num_layers=3,
      num_heads=4, mode=mode)
_ = trax.train(model=new_model,
               inputs=toy_problem_inputs,
               output_dir=output_dir,
               train_steps=3000,
               eval_steps=10,
               eval_frequency=1000)

Step      0: Starting training using 1 devices

Step      1: Ran 1 train steps in 42.29 secs
Step      1: Total trainable parameters size: 692736
Step      1: Evaluation
Step      1: train           accuracy |  0.00686553
Step      1: train neg_log_perplexity | -5.42891455
Step      1: train               loss |  5.42891455
Step      1: eval            accuracy |  0.00809659
Step      1: eval  neg_log_perplexity | -5.39403439
Step      1: eval                loss |  5.39403439
Step      1: Finished evaluation

Step   1000: Ran 999 train steps in 109.64 secs
Step   1000: Evaluation
Step   1000: train           accuracy |  0.12875238
Step   1000: train neg_log_perplexity | -4.29979420
Step   1000: train               loss |  4.29979420
Step   1000: eval            accuracy |  0.09928977
Step   1000: eval  neg_log_perplexity | -4.45948172
Step   1000: eval                loss |  4.45948172
Step   1000: Finished evaluation

Step   2000: Ran 1000 train steps in 16.89 secs
Step   2000: Evalu