In [None]:
from functools import partial
from tqdm import trange
import torch
import torch.nn.functional as F
import numpy as np
import pytorch_pretrained_bert
from pytorch_pretrained_bert import TransfoXLTokenizer, TransfoXLModel, TransfoXLLMHeadModel
for mod in (np, torch, pytorch_pretrained_bert):
    print(f'{mod.__name__}: {mod.__version__}')

# Build model Transformer XL

In [None]:
seed = 0
np.random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name_or_path = 'transfo-xl-wt103'
tokenizer = TransfoXLTokenizer.from_pretrained(model_name_or_path)
model = TransfoXLLMHeadModel.from_pretrained(model_name_or_path)
model.to(device)
model.eval()

Dummy prediction, to check vocab size:

In [None]:
line = "Dummy"
line_tokenized = tokenizer.tokenize(line)
line_indexed = tokenizer.convert_tokens_to_ids(line_tokenized)
tokens_tensor = torch.tensor([line_indexed])
predictions, _ = model(tokens_tensor)
vocab_size = predictions.shape[-1]
assert vocab_size == 267735  # WikiText-103 vocab size

# Minimal example

## Online text generation

In [None]:
line = "Cars were invented in"
max_predictions = 16
top_k = 2

line_tokenized = tokenizer.tokenize(line)
line_indexed = tokenizer.convert_tokens_to_ids(line_tokenized)
tokens_tensor = torch.tensor([line_indexed])
tokens_tensor = tokens_tensor.to(device)
mems = None

for i in range(max_predictions):
    predictions, mems = model(tokens_tensor, mems=mems)
    context_size = tokens_tensor.shape[1]
    assert predictions.shape == (1, context_size, vocab_size)
    topk = torch.topk(predictions[0, -1, :], 10)
    predicted_index = topk.indices[top_k-1].item()
    predicted_index = torch.tensor([[predicted_index]]).to(device)
    tokens_tensor = torch.cat((tokens_tensor, predicted_index), dim=1)
    
    predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
    print(predicted_token, end=' ', flush=True)
    print('\n', tokenizer.convert_ids_to_tokens(topk.indices), '\n')

> **NOTE**: this text is generated choosing at each step the top_k most probable token.
> This is **online text generation**, since at each step, the model only knows the past.

## Off-line text generation

In [None]:
def print_text(input_tokens, predicted_tensor, top_k=5):
    print(f'\n[top {top_k} token] PROMPT:', line)
    for i in range(len(line_indexed) - 1, context_size):
        topk = torch.topk(predicted_tensor[0, i, :], top_k)
        top_k_predictions = tokenizer.convert_ids_to_tokens(topk.indices)
        print(top_k_predictions[top_k - 1], end=' ')
    print()
        
input_text = tokenizer.convert_ids_to_tokens(tokens_tensor.tolist()[0])
for i in range(1, 5):
    print_text(input_text, predictions, top_k=i)      

> **NOTE**: this text is generated choosing at each step the top_k most probable token.
> This is **offline text generation** using the final `prediction` tensor 
> that has information about the whole sequence (so for each word the prediction has been influenced by the future).
>
> The text seems worst, probably because the model is trained to optimize the online prediction
> like in the previous example

In [None]:
def print_input_output(input_tokens, predicted_tensor, top_k=10):
    print(f'  MODEL INPUTS    MODEL OUTPUT (top {top_k} tokens)')
    print(f'  ------------    -----------------------------')
    for i in range(context_size):
        topk = torch.topk(predicted_tensor[0, i, :], top_k)
        p = '* ' if i < len(line_indexed) else '  '
        print(f'{p}{input_tokens[i]:14s}:', end=' ')
        top_k_predictions = tokenizer.convert_ids_to_tokens(topk.indices)
        print(' '.join(top_k_predictions))
        #print('', np.round(topk.values.tolist(), 2))
        
input_text = tokenizer.convert_ids_to_tokens(tokens_tensor.tolist()[0])
print_input_output(input_text, predictions)      

> **NOTE**: lines starting with `*` are inputs in the initial prompt.

> **NOTE 2**: the top tokens are imprecise, because the prediction was done online,
> while here we use the final `prediction` tensor to score the tokens (offline prediction)

## Online text generation with sampling

In [None]:
seed = 0
prompt = "Cars were invented in"
max_predictions = 25
top_k = 40

np.random.seed(seed)
torch.random.manual_seed(seed)
line_tokenized = tokenizer.tokenize(prompt)
line_indexed = tokenizer.convert_tokens_to_ids(line_tokenized)
tokens_tensor = torch.tensor([line_indexed])
tokens_tensor = tokens_tensor.to(device)
mems = None

print(f'PROMPT: {prompt}')
print('MODEL:  ', end='')
for i in range(max_predictions):
    predictions, mems = model(tokens_tensor, mems=mems)
    context_size = tokens_tensor.shape[1]
    assert predictions.shape == (1, context_size, vocab_size)
    
    # sample next token from the most probable top-k
    last_prediction = predictions[0, -1, :]
    topk = torch.topk(last_prediction, top_k)
    log_probs = F.softmax(topk.values, dim=-1)  # softmax among the top-k
    rand_idx_in_topk = torch.multinomial(log_probs, num_samples=1)
    predicted_index = topk.indices[rand_idx_in_topk]
    
    # test
    last_pred_trunk = top_k_logits(last_prediction.reshape(1, -1), top_k)
    sorted_valid_values = last_pred_trunk[last_pred_trunk > -1e10].sort(descending=True).values
    assert all(sorted_valid_values == topk.values)
    
    # update model state
    predicted_index = torch.tensor([[predicted_index]]).to(device)
    tokens_tensor = torch.cat((tokens_tensor, predicted_index), dim=1)
    
    # print current token
    predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
    print(predicted_token, end=' ', flush=True)

In [166]:
def gen_text_sample(
        prompt = "Cars were invented in",
        seed = 0,
        length = 5,
        top_k = 40,
        top_p = None,
    ):
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    line_tokenized = tokenizer.tokenize(prompt)
    line_indexed = tokenizer.convert_tokens_to_ids(line_tokenized)
    tokens_tensor = torch.tensor([line_indexed])
    tokens_tensor = tokens_tensor.to(device)
    if top_p is not None:
        assert 0 < top_p <= 1, '`top_p` must be in (0..1]'
        top_k = round(tokens_tensor.shape[1] * top_p)

    print(f'PROMPT: {prompt}')
    print('MODEL:  ', end='')
    mems = None
    for i in range(length):
        predictions, mems = model(tokens_tensor, mems=mems)
        context_size = tokens_tensor.shape[1]
        assert predictions.shape == (1, context_size, vocab_size)

        # sample next token from the most probable top-k
        last_prediction = predictions[0, -1, :]
        topk = torch.topk(last_prediction, top_k)
        log_probs = F.softmax(topk.values, dim=-1)  # softmax among the top-k
        rand_idx_in_topk = torch.multinomial(log_probs, num_samples=1)
        predicted_index = topk.indices[rand_idx_in_topk]

        # test
        last_pred_trunk = top_k_logits(last_prediction.reshape(1, -1), top_k)
        sorted_valid_values = last_pred_trunk[last_pred_trunk > -1e10].sort(descending=True).values
        assert all(sorted_valid_values == topk.values)

        # update model state
        predicted_index = torch.tensor([[predicted_index]]).to(device)
        tokens_tensor = torch.cat((tokens_tensor, predicted_index), dim=1)

        # print current token
        predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
        print(predicted_token, end=' ', flush=True)
        
        if predicted_token == '=':
            break
            # avoid wasting time with bad predictions

In [None]:
prompt = "Cars were invented in"
gen_text_sample(top_k=40, length=10)

In [None]:
prompt = "Cars were invented in"
gen_text_sample(top_p=0.5, length=10)

# Test generation

In [169]:
prompt = 'What do you know about Machine Learning and Natural Language Processing?'
length = 60
for seed in range(5):
    gen_text_sample(prompt, top_p=0.9, length=length, seed=seed)
    print('\n')

PROMPT: What do you know about Machine Learning and Natural Language Processing?
MODEL:  for a " real <unk> . " <eos> What do you know about Machine Learning and Natural Language <unk> for a " <unk> <unk> . " <eos> What do you know about Machine Learning and natural Language <unk> for a " <unk> <unk> . " <eos> What do I know about Machine Learning , Natural Language <unk> for a " 

PROMPT: What do you know about Machine Learning and Natural Language Processing?
MODEL:  , <eos> <eos> = 

PROMPT: What do you know about Machine Learning and Natural Language Processing?
MODEL:  is <unk> , a book about the <unk> , which is about the <unk> of <unk> ( <unk> , <unk> ) ) and is <unk> , a <unk> , and <unk> , a <unk> <unk> , <unk> on the <unk> of <unk> and is <unk> . <unk> , an <unk> <unk> , is <unk> , a <unk> <unk> , <unk> 

PROMPT: What do you know about Machine Learning and Natural Language Processing?
MODEL:  , <eos> the only way to know about Machine Learning and natural Language <unk> and t

In [None]:
prompt = 'What do you know about Machine Learning and Natural Language Processing?<eos> Natural Language Processing is a branche of machine learning'
length = 60
for seed in range(5):
    gen_text_sample(prompt, top_k=40, length=length, seed=seed)
    print('\n')

PROMPT: What do you know about Machine Learning and Natural Language Processing?
MODEL:  has and is part of . <eos> The most important part of the book is about the development of a new theory of Machine Learning and how it could be integrated into the new theory that was being implemented . <eos> <eos> = 

PROMPT: What do you know about Machine Learning and Natural Language Processing?
MODEL:  , <eos> <eos> = 

PROMPT: What do you know about Machine Learning and Natural Language Processing?
MODEL:  @-@ <unk> , ? â ? <eos> The questions about the ways a computer could be of use with the work that I have about the job of developing the language <unk> @-@ <unk> , ? ? â ? <eos> What do you know about a computer ? â ? â ? â ? â ? â ? â ? 

PROMPT: What do you know about Machine Learning and Natural Language Processing?
MODEL:  , <eos> a system for teaching the lessons of Natural Language <unk> , <eos> which is a system used 

In [None]:
prompt = ('What do you know about Machine Learning and Natural Language Processing?<eos>
          'Natural Language Processing is a branche of machine learning ')
length = 60
for seed in range(5):
    gen_text_sample(prompt, top_k=40, length=length, seed=seed)
    print('\n')

PROMPT: What do you know about Machine Learning and Natural Language Processing?
MODEL:  has and is part of . <eos> The most important part of the book is about the development of a new theory of Machine Learning and how it could be integrated into the new theory that was being implemented . <eos> <eos> = 

PROMPT: What do you know about Machine Learning and Natural Language Processing?
MODEL:  , <eos> <eos> = 

PROMPT: What do you know about Machine Learning and Natural Language Processing?
MODEL:  @-@ <unk> , ? â ? <eos> The questions about the ways a computer could be of use with the work that I have about the job of developing the language <unk> @-@ <unk> , ? ? â ? <eos> What do you know about a computer ? â ? â ? â ? â ? â ? â ? 

PROMPT: What do you know about Machine Learning and Natural Language Processing?
MODEL:  , <eos> a system for teaching the lessons of Natural Language <unk> , <eos> 

# Model exploration

The file `textgen.py` provides an API for text generation for both *Transformer XL* and other models (*GPT2*, etc..).

It requires:

- mode signature: `model(prev, past=tensor)` 
- function `decoder(ids)` returning tokens
- from `generate_text_<model>` function, use partial to assign model specific args and create a
  function `gen_text` with will have the same signature for all models

In [None]:
%run -i textgen.py

In [None]:
def model_comp(prev, past):
    return model(prev, mems=past)

In [None]:
decoder = partial(decoder_transformer_xl, tokenizer=tokenizer)
#decoder?

In [None]:
gen_text = partial(generate_text_transformer_xl, model_comp, tokenizer, decoder)

In [None]:
gen_text(line, 
         length=10, sample=False, top_k=2)

In [None]:
# logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
#                     datefmt = '%m/%d/%Y %H:%M:%S',
#                     level = logging.INFO)
# logger = logging.getLogger(__name__)

In [None]:
def top_k_logits(logits, k):
    """
    Masks everything but the k top entries as -infinity (1e10).
    Used to mask logits such that e^-infinity -> 0 won't contribute to the
    sum of the denominator.
    """
    if k == 0:
        return logits
    else:
        values = torch.topk(logits, k)[0]
        batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
        return torch.where(logits < batch_mins, torch.ones_like(logits) * -1e10, logits)

    
def top_p_logits(logits, p):
    """
    Masks everything but the top-p entries as -infinity (1e10).
    
    Differently from `top_k_logits`, here we we don't take a fixed number
    k of elements in `logits`, but a fraction `p`
    of elements. These are the elements higher that the `p` percentile.
    
    Used to mask logits such that e^-infinity -> 0 won't contribute to the
    sum of the denominator.
    """
    if p == 1:
        return logits
    else:
        k = round(logits.shape[1] * p)
        print(f'top_p = {top_p:.1g}, k = {k}', flush=True)
        return top_k_logits(logits, k)

    
def sample_sequence(model, length, context, batch_size=None, 
                    temperature=1, top_k=0, top_p=None, device='cuda', sample=True):
    context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
    prev = context
    output = context
    past = None
    with torch.no_grad():
        for i in trange(length):
            logits, past = model(prev, past=past)
            logits = logits[:, -1, :] / temperature
            if top_p is None:
                logits = top_k_logits(logits, k=top_k)
            else:
                logits = top_p_logits(logits, p=top_p)
            log_probs = F.softmax(logits, dim=-1)
            if sample:
                prev = torch.multinomial(log_probs, num_samples=1)
            else:
                _, prev = torch.topk(logits, k=1, dim=-1)
            output = torch.cat((output, prev), dim=1)
    return output

In [None]:
def encode_transformer_xl(text, encoder, device):
    text_tokenized = encoder.tokenize(text)
    text_indexed = encoder.convert_tokens_to_ids(text_tokenized)
    text_indexed_tensor = torch.tensor([text_indexed])
    text_indexed_tensor = text_indexed_tensor.to(device)
    return text_indexed_tensor

def run_model(
        prompt = None,
        batch_size = 1,
        nsamples = 1,    
        length = -1,
        temperature = 1,
        top_k = 0,
        top_p=None,
        sample = True,
        seed = 0,
        EOT = '<|endoftext|>',
    ):
    # Arguments checks
    assert nsamples % batch_size == 0
    assert prompt is not None and len(prompt) > 0
    
#     if length == -1:
#         length = model.config.n_ctx // 2
#     elif length > model.config.n_ctx:
#         raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)

    # Seed the random-number generators
    if seed is not None:
        np.random.seed(seed)
        torch.random.manual_seed(seed)
        torch.cuda.manual_seed(seed)
    
    # Encode prompt (str -> tokens -> tensor(vocabulary))
    context_tokens = encode_transformer_xl(prompt, tokenizer, device)

    # Generate an output text (multiple times if (nsamples / batch_size) > 1)
    generated = 0
    for _ in range(nsamples // batch_size):
        out = sample_sequence(
            model=model, length=length,
            context=context_tokens,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k, device=device, sample=sample,
        )
        print(f'PROMPT: {prompt}')
        out = out[:, len(context_tokens):].tolist()
        for i in range(batch_size):
            generated += 1
            text = enc.decode(out[i])
            print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
            end = text.find(EOT)
            end = len(text) if end == -1 else end+len(EOT)
            print(text[:end])
    print("=" * 80)
    


In [None]:
seed = 0
np.random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name_or_path = 'transfo-xl-wt103'
tokenizer = TransfoXLTokenizer.from_pretrained(model_name_or_path)
model = TransfoXLLMHeadModel.from_pretrained(model_name_or_path)
model.to(device)
model.eval()

In [None]:
model.config

In [None]:
seed = 0
np.random.seed(seed)
torch.random.manual_seed(seed)
run_model('What do you know about Machine Learning and Natural Language Processing?', length=128)

In [None]:
seed = 0
for seed in range(10):
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    run_model('What do you know about Machine Learning and Natural Language Processing?', length=128)