#### Copyright 2020 Google LLC.

In [0]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Reformer: Text Generation [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/text_generation.ipynb)

This notebook was designed to run on TPU.

To use TPUs in Colab, click "Runtime" on the main menu bar and select "Change runtime type". Set "TPU" as the hardware accelerator.

In [0]:
# Install JAX.
# The jaxlib 0.1.39 and jax 0.1.59 wheels (cp36 builds) are copied from the
# trax-ml bucket into the working directory and then installed from those
# local files.
!gsutil cp gs://trax-ml/reformer/jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl .
!gsutil cp gs://trax-ml/reformer/jax-0.1.59-cp36-none-manylinux2010_x86_64.whl .
!pip install --upgrade -q ./jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl
!pip install --upgrade -q ./jax-0.1.59-cp36-none-manylinux2010_x86_64.whl

# Make sure the Colab Runtime is set to Accelerator: TPU.
# The code below reads os.environ['COLAB_TPU_ADDR'] and raises KeyError if
# that variable is not set.
import requests
import os
# TPU_DRIVER_MODE is a run-once guard: the version request is only sent the
# first time this cell runs in a given kernel.
if 'TPU_DRIVER_MODE' not in globals():
  # Host part of COLAB_TPU_ADDR, port 8475, requesting the pinned tpu_driver
  # build.
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
  # NOTE(review): the response is stored but its status is never checked.
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
# Echo the backend target so the reader can see which address is in use.
print(config.FLAGS.jax_backend_target)

In [0]:
!pip install --upgrade -q sentencepiece
!pip install --upgrade -q gin git+https://github.com/google/trax.git@v1.2.3

from tensorflow.compat.v1.io.gfile import GFile
import gin
import os
import jax
import trax
from trax.models.beam_search import Search
from trax.supervised import inputs

import numpy as np
import jax.numpy as jnp

from scipy.special import softmax

from sentencepiece import SentencePieceProcessor

## Setting up data and model

In this notebook, we'll be pushing the limits of just how many tokens we can fit on a single TPU device. The TPUs available in Colab have 8GB of memory per core, and 8 cores. We will set up a Reformer model that can fit a copy of "Crime and Punishment" on *each* of the 8 TPU cores (over 500,000 tokens per 8GB of memory).

In [0]:
# Import a copy of "Crime and Punishment", by Fyodor Dostoevsky
with GFile('gs://trax-ml/reformer/crime-and-punishment-2554.txt') as f:
  text = f.read()

# The file read above includes metadata and licensing information.
# For training our language model, we will only use the actual novel text.
# str.index / str.rindex raise ValueError when a marker is missing; find /
# rfind would instead return -1 and silently produce a wrong slice (for
# example, end == -1 would just drop the last character of the file).
start = text.index('CRIME AND PUNISHMENT')  # skip header
start = text.index('CRIME AND PUNISHMENT', start + 1)  # skip header
start = text.index('CRIME AND PUNISHMENT', start + 1)  # skip translator preface
end = text.rindex('End of Project')  # skip extra text at the end
text = text[start:end].strip()

In [0]:
# Load a BPE vocabulary with 320 types. This mostly consists of single letters
# and pairs of letters, but it has some common words and word pieces, too.
# The wildcard copies every cp.320.* file from the bucket into the working
# directory; only cp.320.model is loaded below.
!gsutil cp gs://trax-ml/reformer/cp.320.* .

TOKENIZER = SentencePieceProcessor()
TOKENIZER.load('cp.320.model')

In [0]:
# Tokenize
# Encode the trimmed novel into BPE ids, stored as a 1-D int32 array.
IDS = np.asarray(TOKENIZER.EncodeAsIds(text), dtype=np.int32)
# Number of padding positions needed to fill a 512 * 1024 = 524288-token
# sequence.
PAD_AMOUNT = 512 * 1024 - IDS.shape[0]
print("Number of tokens:", len(IDS))

Number of tokens: 513812


As we see above, "Crime and Punishment" has just over half a million tokens with the BPE vocabulary we have selected.

Normally we would have a dataset with many examples, but for this demonstration we fit a language model on the single novel only. We don't want the model to just memorize the dataset by encoding the words in its position embeddings, so at each training iteration we will randomly select how much padding to put before the text vs. after it.

We have 8 TPU cores, so we will separately randomize the amount of padding for each core.

In [0]:
# Set up the data pipeline.
def my_inputs(n_devices, ids=None, pad_amount=None):
  """Yields endless batches holding one randomly padded copy of the text per device.

  For every batch and every device, a random split of the padding is drawn:
  `before` positions are placed in front of the text and the remaining
  `pad_amount - before` positions after it, so the text starts at a different
  position each time.

  Args:
    n_devices: number of rows (one per device) in each batch.
    ids: 1-D array of token ids. Defaults to the notebook-level `IDS`.
    pad_amount: total number of padding positions added around `ids`.
      Defaults to the notebook-level `PAD_AMOUNT`.

  Yields:
    A tuple `(inputs, targets, mask)`. `inputs` and `targets` are the same
    array of shape `(n_devices, len(ids) + pad_amount)`, padded with zeros.
    `mask` is a float32 array of the same shape holding 1.0 at text positions
    and 0.0 at padding positions.
  """
  if ids is None:
    ids = IDS
  if pad_amount is None:
    pad_amount = PAD_AMOUNT
  # The mask before padding never changes, so build it once.
  token_mask = np.ones_like(ids, dtype=np.float32)
  while True:
    # randint's upper bound is exclusive, hence the + 1: the padding in front
    # ranges over 0..pad_amount inclusive, and pad_amount == 0 is valid.
    pads_before = np.random.randint(0, pad_amount + 1, size=n_devices)
    pad_widths = [(before, pad_amount - before) for before in pads_before]
    tokens = np.stack(
        [np.pad(ids, width, mode='constant') for width in pad_widths])
    mask = np.stack(
        [np.pad(token_mask, width, mode='constant') for width in pad_widths])
    yield (tokens, tokens, mask)

# Sanity check: draw a single batch and print the shape of its first element
# (the padded token ids), using one row per available device.
# NOTE(review): the install cell pins trax v1.2.3; confirm that `trax.fastmath`
# exists in that release (the module may be exposed under a different name).
print("(device count, tokens per device) = ",
      next(my_inputs(trax.fastmath.device_count()))[0].shape)

(device count, tokens per device) =  (8, 524288)


In [0]:
# Configure hyperparameters.
# Everything inside the triple-quoted string is gin configuration text that is
# parsed by this single call. Summary of what it binds:
#   * A 6-layer ReformerLM (d_model=256, d_ff=512, 2 heads, vocab_size=320)
#     whose layers alternate SelfAttention and LSHSelfAttention.
#   * max_len = n_tokens = 524288, which equals the product of
#     axial_pos_shape (512 * 1024); d_axial_pos_embs (64 + 192) sums to
#     d_model.
#   * Adam driven by a MultifactorSchedule: constant 0.01, 100 linear warmup
#     steps, cosine decay with 900 steps per cycle.
gin.parse_config("""
import trax.layers
import trax.models
import trax.optimizers
import trax.supervised.inputs
import trax.supervised.trainer_lib

# Parameters that will vary between experiments:
# ==============================================================================
train.model = @trax.models.ReformerLM
# Our model will have 6 layers, alternating between the LSH attention proposed
# in the Reformer paper and local attention within a certain context window.
n_layers = 6
attn_type = [
  @SelfAttention,
  @LSHSelfAttention,  
  @SelfAttention,
  @LSHSelfAttention,
  @SelfAttention,
  @LSHSelfAttention,
  ]
share_qk = False  # LSH attention ignores this flag and always shares q & k
n_heads = 2
attn_kv = 64
dropout = 0.05
n_tokens = 524288

# Parameters for MultifactorSchedule:
# ==============================================================================
MultifactorSchedule.constant = 0.01
MultifactorSchedule.factors = 'constant * linear_warmup * cosine_decay'
MultifactorSchedule.warmup_steps = 100
MultifactorSchedule.steps_per_cycle = 900

# Parameters for Adam:
# ==============================================================================
Adam.weight_decay_rate=0.0
Adam.b1 = 0.86
Adam.b2 = 0.92
Adam.eps = 1e-9

# Parameters for SelfAttention:
# ==============================================================================
SelfAttention.attention_dropout = 0.05
SelfAttention.chunk_len = 64
SelfAttention.n_chunks_before = 1
SelfAttention.n_parallel_heads = 1

# Parameters for LSHSelfAttention:
# ==============================================================================
LSHSelfAttention.attention_dropout = 0.0
LSHSelfAttention.chunk_len = 64
LSHSelfAttention.n_buckets = [64, 128]
LSHSelfAttention.n_chunks_after = 0
LSHSelfAttention.n_chunks_before = 1
LSHSelfAttention.n_hashes = 1
LSHSelfAttention.n_parallel_heads = 1
LSHSelfAttention.predict_drop_len = 128
LSHSelfAttention.predict_mem_len = 1024

# Parameters for ReformerLM:
# ==============================================================================
ReformerLM.attention_type = %attn_type
ReformerLM.d_attention_key = %attn_kv
ReformerLM.d_attention_value = %attn_kv
ReformerLM.d_model = 256
ReformerLM.d_ff = 512
ReformerLM.dropout = %dropout
ReformerLM.ff_activation = @trax.layers.Relu
ReformerLM.max_len = %n_tokens
ReformerLM.mode = 'train'
ReformerLM.n_heads = %n_heads
ReformerLM.n_layers = %n_layers
ReformerLM.vocab_size = 320
ReformerLM.share_qk = %share_qk
ReformerLM.axial_pos_shape = (512, 1024)
ReformerLM.d_axial_pos_embs= (64, 192)
""")

In [0]:
# Set up a Trainer.
# Training output is written under ~/train_dir/. Any existing model.pkl there
# is deleted first (presumably so the Trainer does not resume from a stale
# checkpoint -- confirm against the Trainer's restore behavior).
output_dir = os.path.expanduser('~/train_dir/')
!rm -f ~/train_dir/model.pkl  # Remove old model
trainer = trax.supervised.Trainer(
    model=trax.models.ReformerLM,
    loss_fn=trax.layers.CrossEntropyLoss,
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.MultifactorSchedule,
    inputs=trax.supervised.inputs.Inputs(my_inputs),
    output_dir=output_dir,
    # my_inputs yields (inputs, targets, mask) triples; has_weights=True
    # presumably makes the Trainer use the third element as per-token loss
    # weights -- TODO confirm.
    has_weights=True)

In [9]:
# Run one training step, to make sure the model fits in memory.
# The first time trainer.train_epoch is called, it will JIT the entire network
# architecture, which takes around 2 minutes. The JIT-compiled model is saved
# so subsequent runs will be much faster than the first.
# n_eval_steps=1 also runs a single evaluation step, which produces the
# train/eval metrics printed below. This call is step 1 of the 600 total steps
# mentioned in the next training cell.
trainer.train_epoch(n_steps=1, n_eval_steps=1)


Step      1: Ran 1 train steps in 161.13 secs
Step      1: Evaluation
Step      1: train                   accuracy |  0.00300037
Step      1: train                       loss |  6.38712692
Step      1: train         neg_log_perplexity |  6.38712692
Step      1: train          sequence_accuracy |  0.00000000
Step      1: train weights_per_batch_per_core |  513812.00000000
Step      1: eval                    accuracy |  0.00302907
Step      1: eval                        loss |  6.38727570
Step      1: eval          neg_log_perplexity |  6.38727570
Step      1: eval           sequence_accuracy |  0.00000000
Step      1: eval  weights_per_batch_per_core |  513812.00000000
Step      1: Finished evaluation


In [0]:
# Train until 600 steps have been run in total: 1 step was already taken by
# the single-step cell, and this cell adds 9 + 59 * 10 = 599 more, evaluating
# once after each chunk.
# The first ~20 steps are slow to run, but after that it reaches steady-state
# speed. This will take at least 30 minutes to run to completion, but can safely
# be interrupted by selecting "Runtime > Interrupt Execution" from the menu.
# The language model won't be exceptionally good when trained for just a few
# steps and with minimal regularization. However, we can still sample from it to
# see what it learns.
for n_steps in [9] + [10] * 59:
  trainer.train_epoch(n_steps=n_steps, n_eval_steps=1)

## Sample from the model

In [0]:
# As we report in the Reformer paper, increasing the number of hashing rounds
# helps with quality. We can even increase the number of hashing rounds at
# evaluation time only.
# The model configured in this notebook uses LSHSelfAttention layers, so the
# override has to target that configurable; a binding on any other attention
# class would not reach the model.
gin.parse_config("""LSHSelfAttention.n_hashes = 4""")

In [0]:
# Construct the decoder instance.
# Unfortunately this code ends up leaking some memory, so we can only set up the
# decoder once before the memory leak prevents us from running the model and we
# have to restart the notebook.
# Search receives the ReformerLM constructor (not an instance) together with
# the weights currently held by the trainer; temperature=1.0 and
# max_decode_len=128 are passed through as decoding settings.
sampling_decoder = Search(
    trax.models.ReformerLM,
    trainer.model_weights,
    temperature=1.0,
    max_decode_len=128,
    )

In [0]:
# Sample from the Reformer language model.
# decode() returns both sequences and scores; only the sequences are used here.
seqs, scores = sampling_decoder.decode(batch_size=1)
# Take the last entry along axis 1 of the first (and only) batch element.
# NOTE(review): presumably axis 1 indexes candidate sequences and the last one
# is the preferred sample -- confirm against Search.decode.
sample = seqs[0, -1]
# Convert the ids back to text; as the cell's last expression, the decoded
# string is what the notebook displays.
TOKENIZER.DecodeIds(sample.tolist())