# Build a bidirectional text generator with XLNet

by [Rostyslav Neskorozhenyi](https://www.linkedin.com/in/slanj)

Current [Transformer](https://arxiv.org/abs/1706.03762)-based models such as GPT-2 or even GPT-3 show incredible achievements in [text generation](https://huggingface.co/blog/how-to-generate) (predicting the next probable word based on the previous sequence of words). These models can create long, creative and cohesive texts, but they usually generate text in only one direction, from left to right. I was wondering whether there is a way to generate text in both directions: to take some start phrase (for example "text generation is cool") and see what story unfolds around it. [XLNet](https://huggingface.co/transformers/model_doc/xlnet.html) was the solution: because it is trained over permutations of the factorization order of the input sequence, this model can help to generate text in any direction.

In this article we will not study the internal principles of XLNet in detail (you can find an excellent brief explanation [here](https://towardsdatascience.com/xlnet-a-clever-language-modeling-solution-ab41e87798b0)). Instead, we'll start experimenting right away: we will practice a little masked word prediction with XLNet, try to implement top-K bidirectional generation, and then implement a more efficient approach that combines beam search and top-K sampling.

At the end of the article we will get a generator capable of creating such text based on the start phrase (which is highlighted in bold):

> Following up on my initial thoughts: **text generation is cool**! It works great for creating blog header, title etc. You will need Word 2013





## Install needed modules

We will conduct all our experiments in a Google Colab notebook (with a GPU environment), which is available at this [link](https://colab.research.google.com/drive/1RhHiKTp0os2_q5z6pKS6vQUz0SM1EXrM), so the only module we need to install is the excellent [Transformers](https://huggingface.co/transformers/) library. This library provides a simple interface to XLNet, as well as to many other Transformer-based models.

In [None]:
!pip install transformers

Collecting transformers
  Downloading https://files.pythonhosted.org/packages/27/3c/91ed8f5c4e7ef3227b4119200fc0ed4b4fd965b1f0172021c25701087825/transformers-3.0.2-py3-none-any.whl (769kB)
Collecting sentencepiece!=0.1.92
  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
Collecting sacremoses
  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
Collecting tokenizers==0.8.1.rc1
  Downloading https://files.pythonhosted.org/packages/40/d0/30d5f8d221a0ed981a186c8eb986ce1c94e3a6e87f994eae9f4aa5250217/tokenizers-0.8.1rc1-cp36-cp36m-manylinux1_x86_64.whl (3.0MB

## Example of masked words prediction with XLNet

One of the advantages of XLNet is that it copes well with predicting several related masked words while taking the preceding context into account. For example, I will mention in the text that I gave you three apples, and then ask the model who now has the apples by feeding it a sentence with masked words: "\<mask> have \<mask> apples in hands". As a result, we will see that the model understands who has the apples and how many.

Before we can start communicating with the model, we need to load it, along with a tokenizer that converts the incoming text into a numeric form the model understands. In its basic form, tokenization splits the text into words or subwords, which are then converted to ids. Each model requires text to be tokenized in a specific way; XLNet uses the SentencePiece method. You can read more about the tokenization process [here](https://huggingface.co/transformers/tokenizer_summary.html).

In [None]:
# Predict mentioned words in a sentence with XLNet

from transformers import XLNetTokenizer, XLNetLMHeadModel
import torch

tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')


We also need to add padding text to help XLNet with short inputs, as [proposed](https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e) by Aman Rusia.

In [None]:
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""

Now let's predict the top 5 words for each \<mask> token. To make a prediction we need to feed the model the tokenized text, the indexes of the masked words, and a permutation mask. The permutation mask prevents input tokens from attending to the masked tokens. You can read more about the model parameters [here](https://huggingface.co/transformers/model_doc/xlnet.html#xlnetlmheadmodel).
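
Before running the model, it may help to picture the two extra inputs on a toy scale. The sketch below uses plain Python lists instead of tensors and a hand-written token list (an assumption for illustration, not the real SentencePiece output), so it needs neither the model nor a GPU. It follows the convention described in the Transformers documentation: a `1.0` at `perm_mask[i][j]` means token `i` does not attend to token `j`, and a `1.0` at `target_mapping[p][j]` means prediction `p` is made for position `j`.

```python
# Toy illustration of perm_mask and target_mapping with plain lists.
# The token list is hand-written; the real tokenizer may split text differently.
tokens = ["I", "gave", "you", "three", "apples", ".", "<mask>", "have", "<mask>", "apples"]
seq_len = len(tokens)
targets = [i for i, tok in enumerate(tokens) if tok == "<mask>"]  # positions to predict

# perm_mask[i][j] == 1.0 means "token i may NOT attend to token j".
# Filling a whole column j with 1.0 hides token j from every position.
perm_mask = [[0.0] * seq_len for _ in range(seq_len)]
for row in perm_mask:
    for j in targets:
        row[j] = 1.0

# target_mapping[p][j] == 1.0 means "prediction number p is made for position j".
target_mapping = [[0.0] * seq_len for _ in targets]
for p, j in enumerate(targets):
    target_mapping[p][j] = 1.0

print("masked positions:", targets)
print("perm_mask shape:", (seq_len, seq_len))
print("target_mapping shape:", (len(targets), seq_len))
```

The real inputs additionally carry a leading batch dimension of size 1.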

In [None]:
torch.manual_seed(0)
# We show how to setup inputs to predict a next token using a bi-directional context.
# We will predict masked tokens
input_ids = torch.tensor(tokenizer.encode(PADDING_TEXT + "I gave you three apples. <mask> have <mask> apples in hands", add_special_tokens=False)).unsqueeze(0)  

targets = [ -6, -4]

perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
perm_mask[0, :, targets] = 1.0  # No token may attend to the masked target positions

target_mapping = torch.zeros((1, len(targets), input_ids.shape[1]), dtype=torch.float)  

target_mapping[0, 0, targets[0]] = 1.0  # Our first  prediction 
target_mapping[0, 1, targets[1]] = 1.0  # Our second  prediction 

# Run on the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
input_ids_tensor = input_ids.to(device)
target_mapping_tensor = target_mapping.to(device)
perm_mask_tensor = perm_mask.to(device)

model.eval()
model.to(device)

with torch.no_grad():
  outputs = model(input_ids_tensor, perm_mask=perm_mask_tensor, target_mapping=target_mapping_tensor)
next_token_logits = outputs[0]  # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]

for j in range(len(targets)):
  predicted_k_indexes = torch.topk(outputs[0][0][j],k=5)
  predicted_logits_list = predicted_k_indexes[0] 
  predicted_indexes_list = predicted_k_indexes[1] 
    
  print ("predicted word:",tokenizer.decode(input_ids[0][targets[j]].item()), j)
  for i,item  in enumerate(predicted_indexes_list):
      the_index = predicted_indexes_list[i].item()
      print("word and logits",tokenizer.decode(the_index),predicted_logits_list[i].item())

predicted word: <mask> 0
word and logits You -9.070054054260254
word and logits I -10.822368621826172
word and logits We -12.820359230041504
word and logits Now -14.133552551269531
word and logits They -14.863320350646973
predicted word: <mask> 1
word and logits three -23.045528411865234
word and logits the -24.3369083404541
word and logits these -25.59902000427246
word and logits two -25.809444427490234
word and logits your -25.947147369384766


## Top-k bi-directional generation

Now that we know how to predict masked words with XLNet, it's time to create a top-k bidirectional text generator. Its working principle is simple. We create a loop, and at each iteration the model predicts the top-k tokens for masked words placed on the left and on the right of the start phrase. We then attach a random token from the top-k, alternating between the left and the right side, and repeat this for n iterations.


In [None]:
import random
import numpy as np

# Select the top-K tokens from the probability list, then sample token IDs
# (without replacement) from the renormalized distribution over those K tokens

def choose_from_top(probs, k=5, sample_size=1):
    ind = np.argpartition(probs, -k)[-k:]
    top_prob = probs[ind]
    # print(tokenizer.decode(ind))
    top_prob = top_prob / np.sum(top_prob) # Normalize
    choice = np.random.choice(k, sample_size, p = top_prob, replace=False)
    token_ids = ind[choice]
    return token_ids
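
As a quick sanity check of the idea, here is a dependency-free sketch of top-K sampling on a made-up distribution. The helper name `top_k_sample` and the probabilities are assumptions for illustration; it follows the same principle as `choose_from_top` (keep the K most probable entries, then draw without replacement in proportion to their probability) but is not guaranteed to be numerically identical to the NumPy version.

```python
import random

def top_k_sample(probs, k, sample_size, rng):
    """Pick `sample_size` distinct indices among the k most probable entries."""
    # Indices of the k largest probabilities
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    chosen = []
    for _ in range(sample_size):
        weights = [probs[i] for i in top]
        # random.choices treats the weights as relative, so no explicit renormalization is needed
        pick = rng.choices(top, weights=weights, k=1)[0]
        chosen.append(pick)
        top.remove(pick)  # draw without replacement
    return chosen

probs = [0.02, 0.40, 0.03, 0.25, 0.05, 0.20, 0.05]  # made-up distribution
rng = random.Random(0)
samples = [top_k_sample(probs, k=3, sample_size=2, rng=rng) for _ in range(200)]
seen = {idx for pair in samples for idx in pair}
print(sorted(seen))  # only indices among the three most probable entries can appear
```

However often we sample, indices outside the top 3 (here 1, 3 and 5) are never returned, which is exactly what keeps unlikely tokens out of the generated text.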

In [None]:
# top-K bidirectional generation

sent = "text generation is cool"
topk = 10
n = 20
# Lower temperatures make the model more confident in its top choices, while temperatures greater than 1 decrease confidence.
temperature = 5
# Run on the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model.eval()
model.to(device)

sent_tokens = tokenizer.encode(sent, add_special_tokens=False)
mask_tokens = tokenizer.encode('<mask>', add_special_tokens=False)
padding_tokens = tokenizer.encode(PADDING_TEXT, add_special_tokens=False)

for i in range(n):
  # Place one <mask> on each side of the current phrase
  input_tokens = mask_tokens + sent_tokens + mask_tokens
  target_id1 = -len(input_tokens)  # position of the left mask
  target_id2 = -1                  # position of the right mask

  input_ids = torch.tensor(padding_tokens + input_tokens).unsqueeze(0)

  perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
  perm_mask[0, :, [target_id1, target_id2]] = 1.0  # No token may attend to the two masked positions

  target_mapping = torch.zeros((1, 2, input_ids.shape[1]), dtype=torch.float)
  target_mapping[0, 0, target_id1] = 1.0  # First prediction: the left mask
  target_mapping[0, 1, target_id2] = 1.0  # Second prediction: the right mask

  input_ids_tensor = input_ids.to(device)
  target_mapping_tensor = target_mapping.to(device)
  perm_mask_tensor = perm_mask.to(device)

  with torch.no_grad():
    outputs = model(input_ids_tensor, perm_mask=perm_mask_tensor, target_mapping=target_mapping_tensor)

  predicted_tokens = []
  
  for j in range(2):
    probs = torch.nn.functional.softmax(outputs[0][0][j]/temperature, dim = 0).to('cpu').numpy()
    predicted_tokens.append(choose_from_top(probs, k=topk, sample_size=1))

  if i % 2 == 0:    
    tok = predicted_tokens[0][0]
    sent_tokens = [tok] + sent_tokens 
    print('left: ', tokenizer.decode(sent_tokens))
  else:     
    tok = predicted_tokens[1][0]
    sent_tokens = sent_tokens + [tok]
    print("right: ", tokenizer.decode(sent_tokens)) 

left:  The text generation is cool
right:  The text generation is cool for
left:  ? The text generation is cool for
right:  ? The text generation is cool for me
left:  ? The text generation is cool for me
right:  ? The text generation is cool for me to
left:  :? The text generation is cool for me to
right:  :? The text generation is cool for me to see
left:  says:? The text generation is cool for me to see
right:  says:? The text generation is cool for me to see and
left:  and says:? The text generation is cool for me to see and
right:  and says:? The text generation is cool for me to see and the
left:  reviews and says:? The text generation is cool for me to see and the
right:  reviews and says:? The text generation is cool for me to see and the font
left:  User reviews and says:? The text generation is cool for me to see and the font
right:  User reviews and says:? The text generation is cool for me to see and the font size
left:  1 User reviews and says:? The text generation is cool

Not too impressive. There are a lot of repetitions and the whole text looks meaningless. But we will find a better solution.

## Top-k-beam bi-directional text generation

As we can see, it is still quite difficult for the model to generate text right-to-left. We often get a word that does not fit into the context well, which leads to an even less suitable next word. As a result, the generated text becomes incoherent.

We can increase the chances of finding connected word sequences by not generating a single word at a time on each side of the starting phrase. Instead, we create a number of beams of word sequences and choose one of the most probable beams of a given length.

Thus, we get some kind of combination of top-k sampling and beam search. The principle of the resulting method is shown in the diagram.

![Generation schema](https://drive.google.com/uc?id=16ZqB6g5T7dlTwcrCHaSMP5hQnAv0BCiI)

Image was created by Rostyslav Neskorozhenyi with [draw.io](https://draw.io/) 

The bidirectional generation process consists of n iterations. I split each iteration into four steps for better understanding:

- In the first step, we take a start phrase and generate a number of beams of a given length on its left side, working right-to-left (at each stage of the beam search, we select next-token candidates with top-K sampling).

- In the second step, we take a random beam from the top-K most probable beams and add it to the start phrase.

- The resulting new phrase serves as a start phrase for the third step, in which we generate a certain number of beams on the right side of the new start phrase.

- In the fourth step, we take a random beam from the top-k beams obtained in the third step and add that beam to the new starting phrase. The resulting phrase serves as the starting point for the next iteration.

I hope that the description was clear enough and the diagram will help you figure it out. The main thing is that, based on my experiments, this method in most cases allows you to generate quite coherent text bidirectionally.
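
The four steps can also be sketched as a toy program. Everything here is an assumption made for illustration: `toy_next_token_probs` is a hand-written stand-in for the language model, the vocabulary has eight words, candidates are drawn uniformly from the top-K for brevity, and one of the better half of the beams is attached. It only shows the control flow, not the quality of real generation.

```python
import math
import random

VOCAB = ["the", "a", "cool", "text", "model", "is", "very", "and"]

def toy_next_token_probs(context, direction):
    """Hypothetical stand-in for the model: a probability for every word in VOCAB.
    It only discourages repeating the word adjacent to the new position."""
    neighbour = context[0] if direction == "left" else context[-1]
    weights = [0.2 if w == neighbour else 1.0 + 0.1 * i for i, w in enumerate(VOCAB)]
    total = sum(weights)
    return [w / total for w in weights]

def expand(phrase, beam, direction, n_candidates, topk, rng):
    """Extend one beam by one token, drawing n_candidates tokens from the top-k."""
    tokens, logp = beam
    context = tokens + phrase if direction == "left" else phrase + tokens
    probs = toy_next_token_probs(context, direction)
    top = sorted(range(len(VOCAB)), key=lambda i: probs[i], reverse=True)[:topk]
    extended = []
    for i in rng.sample(top, n_candidates):  # uniform draw, a simplification
        new_tokens = [VOCAB[i]] + tokens if direction == "left" else tokens + [VOCAB[i]]
        extended.append((new_tokens, logp + math.log(probs[i])))
    return extended

def grow_side(phrase, direction, depth, n_candidates, topk, rng):
    """Steps 1-2 (left) or 3-4 (right): build beams, then attach one of the best."""
    beams = [([], 0.0)]
    for _ in range(depth):
        beams = [nb for b in beams
                 for nb in expand(phrase, b, direction, n_candidates, topk, rng)]
    beams.sort(key=lambda b: b[1], reverse=True)   # most probable beams first
    best = beams[: max(1, len(beams) // 2)]         # keep the better half
    chosen = rng.choice(best)[0]
    return chosen + phrase if direction == "left" else phrase + chosen

rng = random.Random(0)
phrase = ["text", "is", "cool"]
for _ in range(2):  # two full passes through the four steps
    phrase = grow_side(phrase, "left", depth=2, n_candidates=2, topk=4, rng=rng)
    phrase = grow_side(phrase, "right", depth=2, n_candidates=2, topk=4, rng=rng)
print(" ".join(phrase))
```

Each pass adds two tokens on the left and two on the right, so the start phrase always stays in the middle of the growing text.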

Let's implement the method in code. First we will create a function that takes a tokenized start sentence and a candidate token sequence with its probabilities, and generates the next n probable candidate sequences on the right or on the left side. We will use this function iteratively, so the token sequences generated in one iteration serve as input to the next.

In [None]:
# create a combination of beam and top-k generation to generate sequences of n tokens from both sides 

import random

padding_tokens = tokenizer.encode(PADDING_TEXT, add_special_tokens=False)
mask_tokens = tokenizer.encode('<mask>', add_special_tokens=False)

model.eval()
if torch.cuda.is_available(): model.to('cuda') #if we have a GPU 

# create a function that takes a tokenized sentence and one candidate (tokens, cumulative probability, per-token probabilities)
# and generates the next n probable token sequences on the right or on the left side

def candidates_gen(sent_tokens, candidate=([], 1, []), d='left', n_candidates=5, topk=20, temperature=5):
  branch_candidates = []  
  cand_tokens = candidate[0]
  
  if d == 'right':    
    input = sent_tokens + cand_tokens + mask_tokens     
    
    target_id = -1
    input_ids = torch.tensor(padding_tokens + input).unsqueeze(0)  

    perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
    perm_mask[0, :, target_id] = 1.0  # Previous tokens don't see last token
  else:        
    input = mask_tokens + cand_tokens + sent_tokens    
    
    target_id = -len(input)  
    input_ids = torch.tensor(padding_tokens + input).unsqueeze(0)  

    perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
    perm_mask[0, :, [target_id - i for i in range(100)]] = 1.0  # Hide the mask and the 99 tokens before it to improve left-side generation

  # We will predict a single masked token
  target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float)  
  target_mapping[0, 0, target_id] = 1.0  # The one prediction is made at the masked position

  if torch.cuda.is_available():
    input_ids_tensor = input_ids.to("cuda")
    target_mapping_tensor = target_mapping.to("cuda")
    perm_mask_tensor = perm_mask.to("cuda")
  else:
    input_ids_tensor = input_ids
    target_mapping_tensor = target_mapping
    perm_mask_tensor = perm_mask

  with torch.no_grad():
    outputs = model(input_ids_tensor, perm_mask=perm_mask_tensor, target_mapping=target_mapping_tensor)

  probs = torch.nn.functional.softmax(outputs[0][0][0]/temperature, dim = 0)
  selected_indexes = choose_from_top(probs.to('cpu').numpy(), k=topk, sample_size=n_candidates)
  selected_probs = probs[selected_indexes]

  # predicted_k_probs = torch.topk(probs, k = topk)
  # predicted_indexes_list = predicted_k_probs[1]
  # indexes = list(range(predicted_indexes_list.shape[0]))
  # selected = random.sample(indexes, n_candidates)
  # selected_indexes = predicted_indexes_list[selected]
  # selected_probs = predicted_k_probs[0][selected]

  for i,item  in enumerate(selected_indexes):
      the_index = item.item()
      if d == "right":
        new_sent = cand_tokens + [the_index]
      elif d == "left":
        new_sent = [the_index] + cand_tokens
      
      prob = selected_probs[i].item()
      # add word combinations to branch_candidates in format [sentence, cumulative probability, all probs]
      branch_candidates.append((new_sent, candidate[1] * prob, candidate[2] + [prob]))
  
  return branch_candidates

In [None]:
# test our text branch generator
sent = "Text generation is cool"
sent_tokens = tokenizer.encode(sent, add_special_tokens=False)
first_sample_size = 5
beams = candidates_gen(sent_tokens=sent_tokens, d='left', n_candidates=first_sample_size, temperature=5)
for beam in beams:
  print(tokenizer.decode(beam[0]), beam[1])

" 0.00016840094758663327
that 0.00021763828408438712
<eop> 0.00023566206800751388
! 0.00023651127412449569
The 0.00018732658645603806


Now we will create a **beam_gen** function that generates a list of token beams of a given length (depth) using the token candidates proposed by **candidates_gen**.

The **beam_gen** function returns all intermediate beams together with the final beams sorted by probability.

In [None]:
import random
import numpy as np

def beam_gen(sent_tokens, candidates, depth=5, d='right', sample_size=2, topk=10, temperature=5):
  beams = candidates[:]
  new_candidates = candidates[:]
  while depth > 0:
    new_candidates = []
    for candidate in candidates:
      for new_candidate in candidates_gen(sent_tokens, candidate, d, sample_size, topk, temperature):
        beams.append(new_candidate)
        new_candidates.append(new_candidate)   
    print("Number of beams:", len(new_candidates))    
    candidates = new_candidates[:]
    depth -= 1
  # Sort candidate beams by the sum of the logarithms of their per-token probabilities, which gives the same order as sorting by the product of the probabilities
  sorted_beams = sorted(new_candidates, key=lambda tup: np.sum(np.log10(tup[2])), reverse=True)
  return beams, sorted_beams
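
To see why sorting by the sum of logarithms is a safe replacement for sorting by the product, here is a small check on made-up per-token probabilities (the beam names and numbers are assumptions for illustration). It also shows a practical reason to prefer logarithms: for long beams the raw product underflows to zero.

```python
import math

# Per-token probabilities of three hypothetical beams
beams = {
    "beam_a": [0.020, 0.015, 0.030],
    "beam_b": [0.050, 0.001, 0.040],
    "beam_c": [0.010, 0.020, 0.025],
}

def log_score(probs):
    """Sum of base-10 logarithms, the ranking key used for beams."""
    return sum(math.log10(p) for p in probs)

by_product = sorted(beams, key=lambda name: math.prod(beams[name]), reverse=True)
by_log_sum = sorted(beams, key=lambda name: log_score(beams[name]), reverse=True)
print(by_product)
print(by_log_sum)

# A long beam of small probabilities: the product underflows, the log sum does not
long_beam = [1e-4] * 100
print(math.prod(long_beam), log_score(long_beam))
```

Both rankings agree because the logarithm is monotonic, so the order of the products is preserved.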

Let's gather all the parts together in a **bi_generator** function. 
**bi_generator** can generate text left-to-right (parameter **direction**='right'), right-to-left (parameter **direction**='left'), or in both directions (parameter **direction**='both').

If **both** directions are selected, the generator works as follows: 
it generates **n_tokens** on the left side, then **n_tokens** on the right side, then again **n_tokens** on the left side, and so on.
The number of such steps is set by the **iterations** parameter.

We will separately specify in the **first_sample_size** parameter the number of candidates in the first stage of the beam search. This number can be higher than the number of candidates in the next stages (specified in the **sample_size** parameter), since it is important to get enough candidates for the first token, on which all subsequent sequences will be based. According to my observations, this approach increases the likelihood of generating a coherent and reasonably probable sequence of tokens.
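
These two parameters also determine how much work each side costs. The sketch below is a back-of-the-envelope helper (the name `beam_budget` is hypothetical) that counts the beams after every stage and the number of model calls, assuming, as in **candidates_gen**, one forward pass per beam being extended.

```python
def beam_budget(first_sample_size, sample_size, n_tokens):
    """Beams after each expansion stage and model calls for one side of the phrase."""
    counts = []
    forward_passes = 1            # the call that proposes candidates for the first token
    n_beams = first_sample_size   # beams holding a single token
    for _ in range(n_tokens - 1):
        forward_passes += n_beams  # one model call per beam being extended
        n_beams *= sample_size     # every beam branches into sample_size new beams
        counts.append(n_beams)
    return counts, forward_passes

counts, passes = beam_budget(first_sample_size=4, sample_size=2, n_tokens=4)
print(counts, passes)  # [8, 16, 32] 29
```

With the settings used in the example run in this article (`first_sample_size=4`, `sample_size=2`, `n_tokens=4`) this gives 8, 16 and 32 beams, matching the "Number of beams" lines in its log. Because the count grows exponentially with **n_tokens**, small values keep generation affordable.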

We will use a high **temperature** to lower the model's confidence in its top token choices. This makes the generation more varied and keeps it from getting stuck in the most likely repeating sequences of tokens.
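
The effect of temperature is easy to see on a handful of numbers. In this sketch the five logits are rounded from the masked-word example earlier in the article (You, I, We, Now, They), and the softmax is taken over just these five values rather than the whole vocabulary, which is a simplification for illustration.

```python
import math

def softmax_with_temperature(logits, temperature):
    """Softmax of logits divided by temperature (computed in a numerically stable way)."""
    scaled = [x / temperature for x in logits]
    largest = max(scaled)
    exps = [math.exp(x - largest) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [-9.07, -10.82, -12.82, -14.13, -14.86]  # rounded values, five tokens only
for temperature in (1, 5):
    probs = softmax_with_temperature(logits, temperature)
    print(temperature, [round(p, 3) for p in probs])
```

At temperature 1 the first token takes most of the probability mass, while at temperature 5 the distribution is much flatter, so the less likely candidates get a real chance of being sampled.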

In [None]:
import random
import numpy as np

def bi_generator(sent, direction, first_sample_size, sample_size, n_tokens, topk, iterations, temperature):
  sent_tokens = tokenizer.encode(sent, add_special_tokens=False) 

  for i in range(iterations):
    if (i % 2 == 0 and direction == 'both') or direction == 'left':
      print('>> left side generation')
      candidates = candidates_gen(sent_tokens=sent_tokens, d='left', n_candidates=first_sample_size,  topk=topk, temperature=temperature)
      beams, sorted_beams = beam_gen(sent_tokens, candidates, n_tokens-1, 'left', sample_size, topk, temperature=temperature)
      topn = len(sorted_beams)//5 if len(sorted_beams) > 4 else len(sorted_beams)
      selected_candidate = random.choice(sorted_beams[:topn])
      sent_tokens = selected_candidate[0] + sent_tokens
      print(tokenizer.decode(sent_tokens))
    if (i % 2 != 0 and direction == 'both') or direction == 'right':
      print('>> right side generation')
      candidates = candidates_gen(sent_tokens=sent_tokens, d='right', n_candidates=first_sample_size, topk=topk, temperature=temperature)
      beams, sorted_beams = beam_gen(sent_tokens, candidates, n_tokens-1, 'right', sample_size, topk, temperature=temperature)
      topn = len(sorted_beams)//5 if len(sorted_beams) > 4 else len(sorted_beams)
      selected_candidate = random.choice(sorted_beams[:topn])
      sent_tokens = sent_tokens + selected_candidate[0]
      print(tokenizer.decode(sent_tokens))
    
  return tokenizer.decode(sent_tokens)

And finally we will try our bidirectional text generator.

In [None]:
sent = "James Bond"  
first_sample_size = 4
sample_size = 2
n_tokens = 4
topk = 20
iterations = 6
temperature = 4
direction = "both"

bi_generator(sent, direction, first_sample_size, sample_size, n_tokens, topk, iterations, temperature);

>> left side generation
Number of beams: 8
Number of beams: 16
Number of beams: 32
his starring role as James Bond
>> right side generation
Number of beams: 8
Number of beams: 16
Number of beams: 32
his starring role as James Bond (2006-
>> left side generation
Number of beams: 8
Number of beams: 16
Number of beams: 32
is last seen in his starring role as James Bond (2006-
>> right side generation
Number of beams: 8
Number of beams: 16
Number of beams: 32
is last seen in his starring role as James Bond (2006-2014) and,
>> left side generation
Number of beams: 8
Number of beams: 16
Number of beams: 32
Valenciennes is last seen in his starring role as James Bond (2006-2014) and,
>> right side generation
Number of beams: 8
Number of beams: 16
Number of beams: 32
Valenciennes is last seen in his starring role as James Bond (2006-2014) and, when released from contract


## Conclusion



It's also starting to seem somewhat like we are embarking into a new world largely controlled by **artificial intelligence** based on its ability over a long period to manipulate, manage and adapt our daily lives.

The entire previous paragraph was generated by our new text generator. The text is pretty convincing, isn't it? Therefore, please accept my congratulations. We have created one of the first **Transformer-based bidirectional text generators** of its kind. And while it still makes a lot of mistakes, it can be used to create a lot of interesting and fun stories that will grow around any phrase that comes to your mind.

**More examples of text from the bidirectional generator:**


> Follow the trend: Graphic design is cool, **text generation** is cool, data manipulation and algorithms are cool, etc.

> Theoretical scientific framework for the field is enriched across various technological disciplines and these topics include genetic programming and **machine learning**


> Most **drink some beer** and vodka everyday and have no knowledge on the importance of drinking 

> Following up on my initial thoughts: **text generation is cool**! it works great for creating blog header, title etc. You will need Word 2013


