## Transformer-XL

The Transformer-XL model was proposed in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" by Zihang Dai\*, Zhilin Yang\*, Yiming Yang, Jaime Carbonell, Quoc V. Le, and Ruslan Salakhutdinov. It is a causal (unidirectional) transformer with relative sinusoidal positional embeddings. It can reuse previously computed hidden states (the memory) to attend to a longer context. The model also uses adaptive softmax for its inputs and outputs, with tied weights.
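To make the relative sinusoidal embeddings more concrete, here is a minimal sketch in plain Python. It assumes one embedding vector per relative distance, ordered from the farthest position to the current one, with the sine components followed by the cosine components. The function name `relative_positional_embedding` and the toy sizes are illustrative and not part of the library.

```python
import math


def relative_positional_embedding(klen, d_model):
    """Return klen vectors, one per relative distance klen-1, ..., 0.

    Assumes d_model is even: half the dimensions hold sines, half cosines.
    """
    # One inverse frequency per pair of dimensions
    inv_freq = [1.0 / (10000 ** (i / d_model)) for i in range(0, d_model, 2)]
    embeddings = []
    for distance in range(klen - 1, -1, -1):
        angles = [distance * f for f in inv_freq]
        # Sine components first, then cosine components
        embeddings.append([math.sin(a) for a in angles] + [math.cos(a) for a in angles])
    return embeddings


emb = relative_positional_embedding(klen=4, d_model=6)
print(len(emb), len(emb[0]))  # 4 6
print(emb[-1])                # distance 0: [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
```

Because the vectors depend only on the distance between two positions, the same table can be reused when the attended context grows to include the memory.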

In [1]:
import torch
from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103').to(device)

In [30]:
text = "Jay chou is a"
tokens_tensor = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(device)

The forward pass returns, among other outputs:

- `prediction_scores` (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
> Prediction scores of the language modeling head (scores for each vocabulary token before softmax).

- `mems` (`List[torch.FloatTensor]` of length `config.n_layers`):
> Contains pre-computed hidden states (keys and values in the attention blocks). Pass them back through the `mems` input to speed up sequential decoding. Token ids whose past is already held in `mems` should not be passed again as input ids, because they have already been computed.
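The way such a memory might be maintained between forward passes can be sketched as follows. This is a simplified illustration in plain Python, assuming one list of per-position states per layer and a fixed memory length. The name `update_memory` and the string placeholders for hidden states are hypothetical, and a real implementation works on tensors.

```python
def update_memory(mems, hidden_states, mem_len):
    """Append the new per-layer states and keep only the last mem_len positions."""
    if mems is None:
        # First segment: start with an empty memory for every layer
        mems = [[] for _ in hidden_states]
    new_mems = []
    for layer_mem, layer_hidden in zip(mems, hidden_states):
        combined = layer_mem + layer_hidden
        # Guard mem_len == 0 explicitly, since combined[-0:] would keep everything
        new_mems.append(combined[-mem_len:] if mem_len > 0 else [])
    return new_mems


# Two layers; strings stand in for hidden-state vectors
mems = update_memory(None, [["a0", "a1", "a2"], ["b0", "b1", "b2"]], mem_len=4)
mems = update_memory(mems, [["a3", "a4"], ["b3", "b4"]], mem_len=4)
print(mems[0])  # ['a1', 'a2', 'a3', 'a4']
```

Each position enters the memory exactly once, which is why tokens already covered by `mems` do not need to be fed to the model again.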

In [31]:
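mems = None  # no memory yet; each forward pass returns the updated memory

predicted_tokens = []

with torch.no_grad():  # inference only, so gradients are not needed
    for i in range(50):  # stop at 50 predicted tokens
        # Generate predictions. The outputs are indexed rather than unpacked,
        # since the model may return more than two values.
        # Note: the whole sequence is fed again on every step, together with mems.
        outputs = model(tokens_tensor, mems=mems)
        predictions, mems = outputs[0], outputs[1]

        # Get the index of the most probable next token (greedy choice)
        predicted_index = torch.topk(predictions[0, -1, :], 1)[1]

        # Convert the index back to a token string
        predicted_token = tokenizer.decode(predicted_index)

        # Stop if the end-of-sequence token is reached
        if predicted_token == tokenizer.eos_token:
            break

        # Store the current token
        predicted_tokens.append(predicted_token)

        # Append the new token to the existing sequence
        tokens_tensor = torch.cat((tokens_tensor, predicted_index.unsqueeze(1)), dim=1)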

print('Initial sequence: ' + text)
print('Predicted output: ' + " ".join(predicted_tokens))

Initial sequence: Jay chou is a
Predicted output: type of a type of a type of a type of a type of a type of a type of a type of a type of or a type of a or of a or of a or of a or of of a or of of or of the
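The description of `mems` above says that tokens whose past is already in memory should not be passed as input again. A loop that follows that rule can be sketched as below. Everything here is a stand-in: `step_fn` is a hypothetical callable that plays the role of the model call and returns the scores for the last position plus the updated memory, and `toy_step` is a toy replacement for a real model.

```python
def greedy_decode(step_fn, prompt_ids, max_new_tokens, eos_id=None):
    """Greedy decoding that feeds every token to step_fn exactly once."""
    mems = None
    new_ids = list(prompt_ids)  # first call: the whole prompt
    generated = []
    for _ in range(max_new_tokens):
        scores, mems = step_fn(new_ids, mems)
        # Greedy choice: index of the highest score
        next_id = max(range(len(scores)), key=scores.__getitem__)
        if next_id == eos_id:
            break
        generated.append(next_id)
        new_ids = [next_id]  # later calls: only the token not yet in memory
    return generated


def toy_step(new_ids, mems):
    """Toy 'model': the memory is the list of ids seen; it predicts last id + 1."""
    seen = (mems or []) + new_ids
    vocab_size = 5
    scores = [0.0] * vocab_size
    scores[(seen[-1] + 1) % vocab_size] = 1.0
    return scores, seen


print(greedy_decode(toy_step, [0, 1], max_new_tokens=6, eos_id=0))  # [2, 3, 4]
```

In terms of the loop above, this would presumably correspond to replacing the `torch.cat` step with `tokens_tensor = predicted_index.unsqueeze(0)`, so that only the new token, shaped `(1, 1)`, is passed after the first step. That variant is an untested assumption, so no output is shown for it.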
