# TransformerLM Quick Start and Guide 
Language models are machine learning models that power some of the most impressive applications involving text and language (e.g. machine translation, sentiment analysis, chatbots, summarization). At the time of this writing, some of the largest ML models in existence are language models. They are also based on the [transformer](https://arxiv.org/abs/1706.03762) architecture. The transformer language model (TransformerLM) is a simpler [variation](https://arxiv.org/pdf/1801.10198.pdf) of the original transformer architecture and is useful for plenty of tasks.

<img style="max-width:350px;" src="https://storage.googleapis.com/ml-intro/t/transformerLM-1.png" />

The [Trax](https://trax-ml.readthedocs.io/en/latest/) implementation of TransformerLM focuses on clear code and speed.  It runs without any changes on CPUs, GPUs and TPUs.

In this notebook, we will:

1. Use a pre-trained TransformerLM
2. Train a TransformerLM model
3. Look inside the Trax TransformerLM


In [1]:
import os
import numpy as np
! pip install -q -U trax
import trax


## Using a pre-trained TransformerLM

The following cell loads a pre-trained TransformerLM that sorts a list of four integers.

In [2]:
# Create a Transformer model.
# It must use the same configuration as the pre-trained model we'll load next.
model = trax.models.TransformerLM(  
    d_model=32, d_ff=128, n_layers=2, 
    vocab_size=32, mode='predict')

# Initialize using pre-trained weights.
model.init_from_file('gs://ml-intro/models/sort-transformer.pkl.gz',
                     weights_only=True, 
                     input_signature=trax.shapes.ShapeDtype((1,1), dtype=np.int32))

# Input sequence
# The 0s indicate the beginning and end of the input sequence
input = [0, 3, 14, 15, 9, 0]


# Run the model
output = trax.supervised.decoding.autoregressive_sample(
    model, np.array([input]), temperature=0.0, max_length=4)

# Show us the output
output



array([[ 3,  9, 14, 15]])

This is a trivial example to get you started and put a toy transformer into your hands. Language models get their name from their ability to assign probabilities to sequences of words. This property makes them useful for generating text (and other types of sequences) by choosing the highest probability item as the next item in the sequence -- exactly like the next-word suggestion feature of your smartphone keyboard.
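To make "choosing the highest probability item" concrete, here is a minimal sketch in plain Python. It is not how Trax implements sampling. The scores in `example_logits` are made up for illustration, and the helper names `softmax` and `greedy_next_token` are hypothetical.

```python
import math

def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    shift = max(logits)  # Subtract the max for numerical stability.
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def greedy_next_token(logits):
    """Return the index of the highest-probability token."""
    probs = softmax(logits)
    return max(range(len(probs)), key=probs.__getitem__)

# Hypothetical scores over a 5-token vocabulary (made up for illustration).
example_logits = [0.1, 2.5, -1.0, 0.3, 1.2]
example_probs = softmax(example_logits)
next_token = greedy_next_token(example_logits)  # Token 1 has the largest score.
```

Picking the single most likely token at every step is what `temperature=0.0` asks for in the call above. Higher temperatures would sample from the distribution instead.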

In Trax, TransformerLM is a series of [Layers]() combined using the [Serial]() combinator. A high-level view of the TransformerLM we've declared above can look like this:

<img src="https://storage.googleapis.com/ml-intro/t/transformerLM-layers-1.png" />

The model has two decoder layers because we set `n_layers` to 2. TransformerLM makes predictions by being fed one token at a time, with output tokens typically fed back as inputs (that's the `autoregressive` part of the `autoregressive_sample` function we used to generate the output from the model).

If we think of a simple model that takes the sequence `1, 2` and returns `3, 4`, this is what that process would look like:
<img src="https://storage.googleapis.com/ml-intro/t/transformerLM%20input-output.gif" />
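The feedback loop in the animation can be sketched in a few lines of plain Python. This is only an illustration of the idea, not the implementation of `autoregressive_sample`. `toy_next_token` is a hypothetical stand-in for a trained model that always predicts the last token plus one.

```python
def toy_next_token(tokens):
    """Hypothetical stand-in for a model: predicts the last token plus one."""
    return tokens[-1] + 1

def autoregressive_generate(next_token_fn, prompt, max_length):
    """Generate max_length tokens, feeding each output back in as input."""
    tokens = list(prompt)
    generated = []
    for _ in range(max_length):
        token = next_token_fn(tokens)  # Predict from everything seen so far.
        generated.append(token)
        tokens.append(token)           # The prediction becomes part of the input.
    return generated

generated = autoregressive_generate(toy_next_token, [1, 2], max_length=2)
# generated == [3, 4]
```

Each step sees the prompt plus every token produced so far, which is why the second prediction can depend on the first.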

## Train a TransformerLM Model

Let's train a TransformerLM model. We'll train this one to reverse a list of integers. This is another toy task that a small transformer can learn. But using the concepts we'll go over, you'll be able to train proper language models on larger datasets.

**Example**: The model should take a sequence like `[1, 2, 3, 4]` and return `[4, 3, 2, 1]`.

1. Create the Model
1. Prepare the Dataset
1. Train the model using `Trainer`



### Create the Model

In [3]:
# Create a Transformer model.
def tiny_transformer_lm(mode='train'):
  return trax.models.TransformerLM(  
          d_model=32, d_ff=128, n_layers=2, 
          vocab_size=32, mode=mode)

Refer to [TransformerLM in the API reference](https://trax-ml.readthedocs.io/en/latest/trax.models.html#trax.models.transformer.TransformerLM) to understand each of its parameters and their default values. We chose small values for `d_model`, `d_ff`, and `n_layers` so that the model trains quickly on this simple task.

<img src="https://storage.googleapis.com/ml-intro/t/untrained-transformer.png" />

### Prepare the Dataset

Trax models are trained on streams of data represented as Python iterators. [`trax.data`](https://trax-ml.readthedocs.io/en/latest/trax.data.html) gives you the tools to construct your data pipeline. Trax also gives you readily available access to [TensorFlow Datasets](https://www.tensorflow.org/datasets).

For this simple task, we will create a Python generator. Every time we ask it for the next item, it returns a batch of training examples.

In [4]:
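def reverse_ints_task(batch_size, length=4):
  """Yields batches of `0, source, 0, reversed source` sequences forever."""
  while True:
    # Random integers in [1, 30]; 0 is reserved as the delimiter token.
    source = np.random.randint(1, 31, (batch_size, length))
    target = np.flip(source, 1)  # Reverse each row.

    zero = np.zeros([batch_size, 1], np.int32)
    x = np.concatenate([zero, source, zero, target], axis=1)

    # Weight 0 on the prefix (0, source, 0) and 1 on the reversed target,
    # so only the target positions contribute to the loss.
    loss_weights = np.concatenate([np.zeros((batch_size, length+2)),
                                    np.ones((batch_size, length))], axis=1)
    yield (x, x, loss_weights)  # Here inputs and targets are the same.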

reverse_ints_inputs = trax.data.inputs.Inputs(lambda _: reverse_ints_task(16))



This function prepares a dataset and returns one batch at a time. If we ask for a batch size of 8, for example, it returns the following:

In [5]:
a = reverse_ints_task(8)
sequence_batch, _ , masks = next(a)
sequence_batch

array([[ 0,  9, 18,  2, 14,  0, 14,  2, 18,  9],
       [ 0, 28, 18,  5,  3,  0,  3,  5, 18, 28],
       [ 0, 20, 20, 13, 21,  0, 21, 13, 20, 20],
       [ 0, 16, 14, 26, 12,  0, 12, 26, 14, 16],
       [ 0, 16,  5, 25,  5,  0,  5, 25,  5, 16],
       [ 0,  7, 14, 17, 12,  0, 12, 17, 14,  7],
       [ 0, 20, 15, 10, 10,  0, 10, 10, 15, 20],
       [ 0, 29, 18, 19, 28,  0, 28, 19, 18, 29]])

You can see that each example starts with 0, then a list of integers, then another 0, then the reverse of the list of integers. The function will give us as many examples and batches as we request.
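As a quick sanity check of that layout, the sketch below validates a single example written as a plain Python list. `is_valid_example` is a hypothetical helper added for illustration and is not part of Trax. `first_row` is copied from the batch printed above.

```python
def is_valid_example(row, length=4):
    """Check the layout 0, source, 0, reversed source for one example."""
    if len(row) != 2 * length + 2:
        return False
    source = row[1:length + 1]
    target = row[length + 2:]
    return (row[0] == 0 and row[length + 1] == 0
            and target == source[::-1])

# First row of the batch printed above, written as a plain Python list.
first_row = [0, 9, 18, 2, 14, 0, 14, 2, 18, 9]
row_is_valid = is_valid_example(first_row)

# A row whose second half is not reversed fails the check.
bad_row = [0, 9, 18, 2, 14, 0, 9, 18, 2, 14]
bad_row_is_valid = is_valid_example(bad_row)
```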

In addition to the example, the generator returns a mask vector. During training, the model is challenged to predict the tokens hidden by the mask (those whose positions have a mask value of 1). For example, if the first element in the batch is the following vector:

<table><tr>
<td><strong>0</strong></td><td>5</td><td>6</td><td>7</td><td>8</td><td><strong>0</strong></td><td>8</td><td>7</td><td>6</td><td>5</td>
</tr></table> 

And the associated mask vector for this example is:
<table><tr>
<td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>0</td><td>1</td><td>1</td><td>1</td><td>1</td>
</tr></table> 

Then the model will only be presented with the following prefix items, and it has to predict the rest:
<table><tr>
<td><strong>0</strong></td><td>5</td><td>6</td><td>7</td><td>8</td><td><strong>0</strong></td><td>_</td><td>_</td><td>_ </td><td>_</td>
</tr></table> 

It's important here to note that while `5, 6, 7, 8` constitute the input sequence, the **zeros** serve a different purpose. We are using them as special tokens to delimit where the source sequence begins and ends. 
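One way to see what the mask does is to compute a weighted loss by hand. The sketch below is an illustration in plain Python, not Trax's loss code. It assumes the loss is the negative log-probability of the correct token, averaged over the positions with weight 1. The probabilities in `probs` are made up.

```python
import math

def weighted_cross_entropy(probs_of_correct_token, loss_weights):
    """Average negative log-likelihood over positions with nonzero weight."""
    total = 0.0
    for p, w in zip(probs_of_correct_token, loss_weights):
        total += -math.log(p) * w
    return total / sum(loss_weights)

# Mask for the example above: 0 on the prefix, 1 on the reversed target.
mask = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]

# Hypothetical probabilities a model assigned to the correct token at each position.
probs = [0.1, 0.2, 0.1, 0.3, 0.2, 0.1, 0.9, 0.8, 0.5, 0.5]
loss = weighted_cross_entropy(probs, mask)

# Changing the prefix probabilities leaves the loss unchanged,
# because those positions carry weight 0.
other_prefix = [0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.8, 0.5, 0.5]
same_loss = weighted_cross_entropy(other_prefix, mask)
```

Under this assumption, the model is never penalized for failing to predict the random source integers, which it could not guess anyway.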

With this, we now have a method that streams the dataset in addition to the method that creates the model.

<img src="https://storage.googleapis.com/ml-intro/t/untrained-transformer-and-dataset.png" />


### Train the model using `Trainer`

Trax's [Trainer](https://trax-ml.readthedocs.io/en/latest/trax.html#module-trax.trainer) takes care of the training process. We hand it the model, the method streaming the dataset, and other training parameters. We then start the training loop.

In [None]:
output_dir = os.path.expanduser('~/train_dir/')
!rm -f ~/train_dir/model.pkl.gz  # Remove old model.

# Train tiny model with Trainer.
trainer = trax.supervised.Trainer(
    model=tiny_transformer_lm,
    loss_fn=trax.layers.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adafactor,  # Change optimizer params here.
    lr_schedule=trax.lr.constant(0.001),  # Change lr schedule here.
    inputs=reverse_ints_inputs,
    output_dir=output_dir)

# Train for 3 epochs, each consisting of 800 train batches; evaluate on 2 batches.
n_epochs = 3
train_steps = 800
eval_steps = 2
for _ in range(n_epochs):
  trainer.train_epoch(train_steps, eval_steps)


Step    800: Ran 800 train steps in 56.55 secs
Step    800: Evaluation
Step    800: train                   accuracy |  0.32812500
Step    800: train                       loss |  2.61360502
Step    800: train         neg_log_perplexity | -2.61360502
Step    800: train          sequence_accuracy |  0.00000000
Step    800: train weights_per_batch_per_core |  64.00000000
Step    800: eval                    accuracy |  0.25000000
Step    800: eval                        loss |  2.71211648
Step    800: eval          neg_log_perplexity | -2.71211648
Step    800: eval           sequence_accuracy |  0.00000000
Step    800: eval  weights_per_batch_per_core |  64.00000000
Step    800: Finished evaluation

Step   1600: Ran 800 train steps in 6.29 secs
Step   1600: Evaluation
Step   1600: train                   accuracy |  0.74218750
Step   1600: train                       loss |  1.02415180
Step   1600: train         neg_log_perplexity | -1.02415180
Step   1600: train          sequence_accur

The Trainer is the third key component in this process that helps us arrive at the trained model.

<img src="https://storage.googleapis.com/ml-intro/t/transformerLM-training.png" />
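The `loss` and `neg_log_perplexity` columns in the log above are the same number with opposite signs. A small sketch of how to read them, assuming the loss is an average cross-entropy measured in nats (natural logarithm):

```python
import math

def perplexity(loss):
    """Perplexity is the exponential of the average cross-entropy loss."""
    return math.exp(loss)

# Training losses copied from the log above.
loss_at_800 = 2.61360502
loss_at_1600 = 1.02415180

ppl_at_800 = perplexity(loss_at_800)    # Roughly 13.6
ppl_at_1600 = perplexity(loss_at_1600)  # Roughly 2.8

# Baseline: guessing uniformly among the 30 values the generator draws from.
uniform_loss = math.log(30)
```

A perplexity of about 13.6 means the model is roughly as uncertain as if it were choosing among 13 or 14 equally likely tokens. That is already better than the uniform baseline of 30, and it drops sharply by step 1600.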

### Make predictions

Let's take our newly minted model for a ride. To do that, we load it up and use the handy `autoregressive_sample` function to feed it our input sequence and return the output sequence. These components now look like this:

<img src="https://storage.googleapis.com/ml-intro/t/transformerLM-sampling-prediction.png" />

And this is the code to do just that:

In [None]:

input = np.array([[0, 4, 6, 8, 10, 0]])

# Initialize model for inference.
predict_model = tiny_transformer_lm(mode='predict')
predict_signature = trax.shapes.ShapeDtype((1,1), dtype=np.int32)
predict_model.init_from_file(os.path.join(output_dir, "model.pkl.gz"),
                             weights_only=True, input_signature=predict_signature)

# Run the model
output = trax.supervised.decoding.autoregressive_sample(
    predict_model, input, temperature=0.0, max_length=4)

# Print the contents of output
output

array([[10,  8,  6,  4]])
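To confirm that this output is the reversal we asked for, a small check in plain Python is enough. `expected_reversal` is a hypothetical helper for illustration. The input and output values are copied from the cell above.

```python
def expected_reversal(model_input):
    """Strip the 0 delimiters from the input and reverse what is left."""
    source = [token for token in model_input if token != 0]
    return source[::-1]

model_input = [0, 4, 6, 8, 10, 0]
model_output = [10, 8, 6, 4]  # Copied from the output printed above.
prediction_is_correct = model_output == expected_reversal(model_input)
```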

## TODO: Looking inside the Trax TransformerLM
Visualize tl.Serial, Layers, and the parameters to initialize the model

Work in progress visualization:
<img src="https://storage.googleapis.com/ml-intro/t/trax-layers-serial-draft.png" />

## TODO: Training a proper language model
Train a language model to generate text. A character-level model is perhaps a gentler learning curve, but I see the merit in also going directly to a word-based model.