
# How to generate text: using different decoding methods for language generation with Transformers

### **Introduction**

In recent years, there has been an increasing interest in open-ended language generation thanks to the rise of large transformer-based language models trained on millions of webpages, such as OpenAI's famous [GPT2 model](https://openai.com/blog/better-language-models/). The results on conditioned open-ended language generation are impressive, e.g. [GPT2 on unicorns](https://openai.com/blog/better-language-models/#samples), [XLNet](https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e), [Controlled language with CTRL](https://blog.einstein.ai/introducing-a-conditional-transformer-language-model-for-controllable-generation/). Besides the improved transformer architecture and massive unsupervised training data, **better decoding methods** have also played an important role. 

This blog post gives a brief overview of different decoding strategies and, more importantly, shows how *you* can implement them with very little effort using the popular `transformers` library!

All of the following functionalities can be used for **auto-regressive** language generation (see [here](http://jalammar.github.io/illustrated-gpt2/) for a refresher). In short, *auto-regressive* language generation is based on the assumption that the probability distribution of a word sequence can be decomposed into the product of conditional next-word distributions:
$$ P(w_{1:T} | W_0 ) = \prod_{t=1}^T P(w_{t} | w_{1: t-1}, W_0) \text{, with }  w_{1: 0} = \emptyset, $$

and $W_0$ being the initial *context* word sequence. The length $T$ of the word sequence is usually determined *on-the-fly* and corresponds to the timestep $t=T$ at which the EOS token is generated from $P(w_{t} | w_{1: t-1}, W_{0})$.
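
To make the factorization concrete, here is a minimal sketch that multiplies conditional next-word probabilities for a tiny hand-written model. The table `TOY_MODEL`, its words, and its probabilities are invented for illustration and do not come from any trained model.

```python
# Hypothetical toy model: maps the words seen so far (context W_0 plus
# generated words) to a distribution over the next word.
# All probabilities are invented for illustration.
TOY_MODEL = {
    ("The",): {"nice": 0.5, "dog": 0.4, "car": 0.1},
    ("The", "nice"): {"woman": 0.4, "house": 0.3, "guy": 0.3},
    ("The", "dog"): {"has": 0.9, "runs": 0.05, "and": 0.05},
}


def sequence_probability(context, words):
    """Return P(w_{1:T} | W_0) as the product of conditional next-word probabilities."""
    prefix = list(context)
    prob = 1.0
    for word in words:
        prob *= TOY_MODEL[tuple(prefix)][word]
        prefix.append(word)
    return prob


p_nice_woman = sequence_probability(["The"], ["nice", "woman"])  # 0.5 * 0.4
p_dog_has = sequence_probability(["The"], ["dog", "has"])        # 0.4 * 0.9
print(p_nice_woman, p_dog_has)
```

Each decoding method below is a different way of choosing the next word from these conditional distributions.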


Auto-regressive language generation is now available for `GPT2`, `XLNet`, `OpenAI-GPT`, `CTRL`, `TransfoXL`, `XLM`, `Bart`, `T5` in both PyTorch and TensorFlow >= 2.0!

We will give a tour of the currently most prominent decoding methods, mainly *Greedy search*, *Beam search*, *Top-K sampling* and *Top-p sampling*.


Let's quickly install transformers and load the model. We will load a causal language model checkpoint in PyTorch for demonstration, but the API is 1-to-1 the same for TensorFlow.

In [1]:
!pip install -q git+https://github.com/huggingface/transformers.git

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
    Preparing wheel metadata ... done
  Building wheel for transformers (PEP 517) ... done


In [2]:
!nvidia-smi

Thu Jan 21 04:27:38 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   40C    P0    28W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
model_path = './drive/My Drive/models/mtg_card_gen/checkpoint-200/'
# model_path = 'gpt2-large'

In [5]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import json
tokenizer = AutoTokenizer.from_pretrained(model_path)

# add the EOS token as PAD token to avoid warnings
model = AutoModelForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id).eval()

In [6]:
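# use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')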

In [7]:
model = model.to(device)

In [8]:
results = {}  

In [9]:
def print_and_add_to_results(output, key, add=True):
  """Decode generated token ids, print the text, and optionally store it in results[key]."""
  if key not in results:
    results[key] = []
  str_out = tokenizer.decode(output, skip_special_tokens=True)
  print(str_out)
  if add:
    results[key].append(str_out)

### **Greedy Search**
* Greedy search doesn't work very well here; the outputs aren't great.
* There is a decent amount of repetition.
* The structure of the generated card is wrong.
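
As a rough illustration of what greedy search does at each step, the sketch below always picks the single most likely next word from a hand-written toy table. `TOY_MODEL` and its probabilities are hypothetical and only stand in for the real model's next-word distribution.

```python
# Hypothetical toy model: prefix -> next-word distribution (invented numbers).
TOY_MODEL = {
    ("The",): {"nice": 0.5, "dog": 0.4, "car": 0.1},
    ("The", "nice"): {"woman": 0.4, "house": 0.3, "guy": 0.3},
    ("The", "dog"): {"has": 0.9, "runs": 0.05, "and": 0.05},
}


def greedy_decode(context, max_new_words):
    """Repeatedly append the most likely next word."""
    words = list(context)
    for _ in range(max_new_words):
        dist = TOY_MODEL.get(tuple(words))
        if dist is None:  # the toy table has no entry for this prefix
            break
        words.append(max(dist, key=dist.get))
    return words


greedy_words = greedy_decode(["The"], max_new_words=2)
print(" ".join(greedy_words))
```

In this toy table, greedy search commits to "nice" in the first step and therefore never reaches the more probable continuation "dog has" (0.4 * 0.9 versus 0.5 * 0.4).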


In [10]:
max_length = 256

In [11]:
# encode context the generation is conditioned on
prompt = "Flamebreathing Pidgeon"
input_ids = tokenizer.encode(prompt, return_tensors='pt', add_special_tokens=False).to(device)

In [12]:
# generate text until the output length (which includes the context length) reaches 135
greedy_output = model.generate(input_ids, max_length=135)

print_and_add_to_results(greedy_output[0], 'greedy')

Flamebreathing Pidgeon| Creature — Bird| {2}{G}| common| Flying line_break When Flamebreathing Pidgeon dies, create a 1/1 green Bird creature token with flying.| 2| 2| None| "The birds are the only thing that can stop the dragon." line_break —Gerrard of the Kessig end_of_card Gerrard of the Kessig| Creature — Bird| {2}{G}| common| Flying line_break When Gerrard of the Kessig dies, create a 1/1 green Bird creature token with flying.| 2| 2


### **Beam Search**

In [13]:
# activate beam search and early_stopping
beam_outputs = model.generate(
    input_ids,  
    max_length=84, 
    num_beams=5, 
    early_stopping=True,
    num_return_sequences=5, 
)

for i, beam_output in enumerate(beam_outputs):
  print_and_add_to_results(beam_output, 'beam', add=True)

Flamebreathing Pidgeon| Creature — Bird| {2}{G}| common| Flying line_break When Flamebreathing Pidgeon enters the battlefield, create a 1/1 green Bird creature token with flying.| 2| 2| None| None end_of_card Flamebreathing Pidgeon| Creature — Bird| {2}{G}| common|
Flamebreathing Pidgeon| Creature — Bird| {2}{G}| common| Flying line_break When Flamebreathing Pidgeon enters the battlefield, create a 1/1 green Bird creature token with flying.| 2| 2| None| None end_of_card Flamebreathing Salamander| Creature — Salamander| {2}{G}| common|
Flamebreathing Pidgeon| Creature — Bird| {2}{G}| common| Flying line_break When Flamebreathing Pidgeon enters the battlefield, create a 1/1 green Bird creature token with flying.| 2| 2| None| None end_of_card Flamebreathing Salamander| Creature — Salamander| {3}{G}| common|
Flamebreathing Pidgeon| Creature — Bird| {2}{G}| common| Flying line_break When Flamebreathing Pidgeon enters the battlefield, create a 1/1 green Bird creature token with flying.| 2| 2
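
Beam search keeps the `num_beams` most likely partial sequences at every step instead of only one. The following is a simplified sketch of that idea on a hand-written toy table; `TOY_MODEL` and its probabilities are hypothetical, and the sketch leaves out details such as length penalties and early stopping that a library implementation handles.

```python
import math

# Hypothetical toy model: prefix -> next-word distribution (invented numbers).
TOY_MODEL = {
    ("The",): {"nice": 0.5, "dog": 0.4, "car": 0.1},
    ("The", "nice"): {"woman": 0.4, "house": 0.3, "guy": 0.3},
    ("The", "dog"): {"has": 0.9, "runs": 0.05, "and": 0.05},
    ("The", "car"): {"is": 0.5, "drives": 0.3, "turns": 0.2},
}


def beam_search(context, max_new_words, num_beams):
    """Return the num_beams best (log-probability, words) hypotheses."""
    beams = [(0.0, list(context))]
    for _ in range(max_new_words):
        candidates = []
        for score, words in beams:
            dist = TOY_MODEL.get(tuple(words))
            if dist is None:  # no known continuation, keep the hypothesis as is
                candidates.append((score, words))
                continue
            for word, prob in dist.items():
                candidates.append((score + math.log(prob), words + [word]))
        # keep only the highest-scoring hypotheses
        candidates.sort(key=lambda candidate: candidate[0], reverse=True)
        beams = candidates[:num_beams]
    return beams


best_score, best_words = beam_search(["The"], max_new_words=2, num_beams=2)[0]
print(" ".join(best_words), math.exp(best_score))
```

With two beams the sketch keeps both "The nice" and "The dog" after the first step, so it can find "The dog has", which a single greedy pass over the same table would miss.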

### **Sampling**

In its most basic form, sampling means randomly picking the next word $w_t$ according to its conditional probability distribution:

$$w_t \sim P(w|w_{1:t-1})$$

The following graphic visualizes language generation when sampling, starting from the context word $\text{"The"}$.

![vanilla_sampling](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/sampling_search.png)

It becomes obvious that language generation using sampling is not *deterministic* anymore. The word 
$\text{"car"}$ is sampled from the conditional probability distribution $P(w | \text{"The"})$, followed by sampling $\text{"drives"}$ from $P(w | \text{"The"}, \text{"car"})$.
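
A single sampling step can be sketched with the standard library alone. The distribution `NEXT_WORD_DIST` below is a made-up stand-in for $P(w | \text{"The"})$, not real model output.

```python
import random

# Hypothetical next-word distribution standing in for P(w | "The").
NEXT_WORD_DIST = {"nice": 0.5, "dog": 0.4, "car": 0.1}


def sample_next_word(dist, rng):
    """Draw one word at random, weighted by its probability."""
    words = list(dist)
    weights = [dist[word] for word in words]
    return rng.choices(words, weights=weights, k=1)[0]


rng = random.Random(0)  # seeded so that repeated runs give the same draws
samples = [sample_next_word(NEXT_WORD_DIST, rng) for _ in range(10)]
print(samples)
```

Unlike greedy search, even the unlikely word "car" is picked from time to time, which is where the extra variety (and the occasional incoherence) of sampled text comes from.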

In `transformers`, we set `do_sample=True` and deactivate *Top-K* sampling (more on this later) via `top_k=0`. Sampling is random, so the outputs change from run to run; calling `torch.manual_seed(0)` before `generate` fixes the seed when reproducible outputs are needed. Feel free to change the seed to play around with the model.


In [14]:
# activate sampling and deactivate top_k by setting top_k sampling to 0
sample_outputs = model.generate(
    input_ids, 
    do_sample=True, 
    max_length=128, 
    top_k=0,
    num_return_sequences=5,
)

print("Output:\n" + 100 * '-')
for i, sample_output in enumerate(sample_outputs):
  print_and_add_to_results(sample_output, 'sample_base', True)

Output:
----------------------------------------------------------------------------------------------------
Flamebreathing Pidgeon| Creature — Bird Soldier| {4}{G}| common| Flying line_break Flash (You may cast this spell any time you could cast an instant.)| 3| 1| None| Every sparrow that rises above Petalwood yawns and withers within a thousand years. end_of_card Melsheim Madness| Enchantment — Aura| {3}{U}| common| Enchant creature line_break Enchanted creature gets +2/+1 and has flying.| None| None| None| "All metals, vermin fed on endogenously it
Flamebreathing Pidgeon| Creature — Snake| {4}{W}| common| Choose one — line_break • Tap and other cards you control: Suncloaks gain that many life. line_break • Tap and other cards you control: Suncloaks gain that many life.| 4| 4| None| None end_of_card Slobbe Mutilator| Creature — Horror| {2}{W}| common| Change Slobbe Mutilator's power and toughness as you cast this spell. (Damage and effects that say "destroy" don't destroy it.)
Flame

A trick is to make the distribution $P(w|w_{1:t-1})$ sharper (increasing the likelihood of high probability words and decreasing the likelihood of low probability words) by lowering the so-called `temperature` of the [softmax](https://en.wikipedia.org/wiki/Softmax_function#Smooth_arg_max). 
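
The effect of the temperature can be sketched as dividing the logits by the temperature before applying the softmax. The logits below are invented for illustration.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Softmax over logits / temperature; temperature < 1 sharpens the distribution."""
    scaled = [logit / temperature for logit in logits]
    max_scaled = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(value - max_scaled) for value in scaled]
    total = sum(exps)
    return [value / total for value in exps]


logits = [2.0, 1.0, 0.1]  # hypothetical scores for three candidate words
probs_default = softmax_with_temperature(logits, temperature=1.0)
probs_cool = softmax_with_temperature(logits, temperature=0.7)
print(probs_default)
print(probs_cool)
```

With `temperature=0.7` the most likely word gains probability and the least likely word loses it; as the temperature approaches 0, sampling behaves more and more like greedy search.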

![sampling_with_temperature](https://github.com/patrickvonplaten/scientific_images/blob/master/sampling_search_with_temp.png?raw=true)

Let's see how we can cool down the distribution in the library by setting `temperature=0.7`:

In [15]:
# use temperature to decrease the sensitivity to low probability candidates
sample_outputs = model.generate(
    input_ids, 
    do_sample=True, 
    max_length=128, 
    top_k=0, 
    temperature=0.7,
    num_return_sequences=5
)

print("Output:\n" + 100 * '-')
for i, sample_output in enumerate(sample_outputs):
  print_and_add_to_results(sample_output, 'temp_.7', True)

Output:
----------------------------------------------------------------------------------------------------
Flamebreathing Pidgeon| Creature — Fungus| {4}{W}{W}| common| Sacrifice a number of creature cards from an opponent's graveyard: Ragebreathing Pidgeon gets +2/+2 until end of turn.| 4| 3| None| "This is the time to draw my sword." end_of_card Fungus Whiptail| Creature — Wurm| {5}{W}| common| {T}: Add {W} or {U}.| 5| 4| None| "Those who dwell in it will be free to adapt to
Flamebreathing Pidgeon| Creature — Peacemaker| {B}| uncommon| Cumulative upkeep {2} (At the beginning of your upkeep, put an age counter on this permanent, then sacrifice it unless you pay its upkeep cost for each age counter on it.) line_break When Cumulative upkeep is paid, Cumulative upkeep deals X damage to you, where X is Cumulative upkeep's power.| 1| 1| None| "The sun is shining with all of us." end_of_card Eternal Destiny| Enchantment| {2}{W}{U}| rare| At
Flamebreathing Pidgeon| Creature — Elemental| {3

### **Top-K Sampling**

[Fan et al. (2018)](https://arxiv.org/pdf/1805.04833.pdf) introduced a simple, but very powerful sampling scheme, called ***Top-K*** sampling. In *Top-K* sampling, only the *K* most likely next words are kept and the probability mass is redistributed among those *K* words. 
GPT2 adopted this sampling scheme, which was one of the reasons for its success in story generation. 

We extend the range of words used for both sampling steps in the example above from 3 words to 10 words to better illustrate *Top-K* sampling.

![top_k_sampling](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/top_k_sampling.png)

Having set $K = 6$, in both sampling steps we limit our sampling pool to 6 words. While the 6 most likely words, defined as $V_{\text{top-K}}$, encompass only *ca.* two-thirds of the whole probability mass in the first step, they include almost all of the probability mass in the second step. Nevertheless, we see that *Top-K* successfully eliminates the rather weird candidates $\text{"not", "the", "small", "told"}$ in the second sampling step.
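
The filtering step itself is small. Here is a minimal sketch on a made-up distribution; `NEXT_WORD_DIST` and the helper name `top_k_filter` are hypothetical and not part of the `transformers` API.

```python
def top_k_filter(dist, k):
    """Keep the k most likely words and renormalize their probabilities."""
    top = sorted(dist.items(), key=lambda item: item[1], reverse=True)[:k]
    total = sum(prob for _, prob in top)
    return {word: prob / total for word, prob in top}


# Hypothetical next-word distribution (invented numbers).
NEXT_WORD_DIST = {"nice": 0.4, "dog": 0.3, "car": 0.15, "house": 0.1, "told": 0.05}
filtered = top_k_filter(NEXT_WORD_DIST, k=3)
print(filtered)
```

Sampling then proceeds as before, but only over the words that survive the filter.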


Let's see how *Top-K* can be used in the library by setting `top_k=50`:

In [16]:
# set top_k to 50
torch.manual_seed(0)
sample_outputs = model.generate(
    input_ids, 
    do_sample=True, 
    max_length=max_length, 
    top_k=50,
    num_return_sequences=5
)

print("Output:\n" + 100 * '-')
for i, sample_output in enumerate(sample_outputs):
  print_and_add_to_results(sample_output, 'top_k_50', True)

Output:
----------------------------------------------------------------------------------------------------
Flamebreathing Pidgeon| Creature — Bird| {6}{U}| common| Protection from red| 3| 3| None| "I will send that brood to be consumed with Fire for eternity, and I will be the one to bring it home. It will do the rest." end_of_card Flamebreathing Pidgeon| Instant| {1}{U}| common| Flamebreathing Pidgeon deals 3 damage to any target.| None| None| None| "Fire is a weapon whose power cannot be measured. If you dare to wield it against me, you will be incinerated!" line_break —Alena, Selesnya initiate end_of_card Herald of Justice| Creature — Construct| {6}| common| Hexproof line_break {4}, Sacrifice Herald of Justice: Draw a card.| 6| 5| None| "If you love your work, you must earn it. Your deeds shall triumph, not dishonor." line_break —Thalia, Selesnya initiate end_of_card Kalas, God of the Damned| Legendary Creature — God| {4}{U}{U}| mythic| When
Flamebreathing Pidgeon| Enchantment| {5

One concern though with *Top-K* sampling is that it does not dynamically adapt the number of words that are filtered from the next word probability distribution $P(w|w_{1:t-1})$.



### **Top-p (nucleus) sampling**

Instead of sampling only from the most likely *K* words, *Top-p* sampling chooses from the smallest possible set of words whose cumulative probability exceeds the probability *p*. The probability mass is then redistributed among this set of words. 
![top_p_sampling](https://github.com/patrickvonplaten/scientific_images/blob/master/top_p_sampling.png?raw=true)
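
A minimal sketch of the nucleus filter, again on a made-up distribution: words are added in order of decreasing probability until their cumulative probability reaches `p`. The helper name `top_p_filter` and the numbers are hypothetical, and the exact boundary handling in a library implementation may differ slightly.

```python
def top_p_filter(dist, p):
    """Keep the smallest set of most likely words whose cumulative probability reaches p."""
    ranked = sorted(dist.items(), key=lambda item: item[1], reverse=True)
    kept, cumulative = [], 0.0
    for word, prob in ranked:
        kept.append((word, prob))
        cumulative += prob
        if cumulative >= p:
            break
    total = sum(prob for _, prob in kept)
    return {word: prob / total for word, prob in kept}


# Hypothetical next-word distribution (invented numbers).
NEXT_WORD_DIST = {"nice": 0.5, "dog": 0.3, "car": 0.15, "told": 0.05}
nucleus_wide = top_p_filter(NEXT_WORD_DIST, p=0.9)
nucleus_narrow = top_p_filter(NEXT_WORD_DIST, p=0.6)
print(nucleus_wide)
print(nucleus_narrow)
```

The size of the kept set now depends on the shape of the distribution: a peaked distribution keeps few words, a flat one keeps many.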

In [17]:
torch.manual_seed(0)

# deactivate top_k sampling and sample only from the smallest set of words covering 92% of the probability mass
sample_outputs = model.generate(
    input_ids, 
    do_sample=True, 
    max_length=max_length, 
    top_p=0.92, 
    top_k=0,
    num_return_sequences=5
)

print("Output:\n" + 100 * '-')
for i, sample_output in enumerate(sample_outputs):
  print_and_add_to_results(sample_output, 'top_p_.92', True)

Output:
----------------------------------------------------------------------------------------------------
Flamebreathing Pidgeon| Creature — Bird| {6}{U}| common| Protection from red| 3| 3| None| "I will fly that way. If you speak of it, you'll agree that it's a blessing." line_break —Craska —Flamebreathing Pidgeon end_of_card Gill Strider| Creature — Insect| {4}{U}| common| Haste line_break Each creature you control with hexproof gets +1/+1 and has menace.| 4| 3| None| "May you protect the Gardens of Plenty." end_of_card Ice Climb| Enchantment| {3}{U}{U}| rare| Ice Climb enters the battlefield with two -1/-1 counters on it for each card in your hand. line_break Threshold — Ice Climb enters the battlefield with three -1/-1 counters on it for each card in your hand.| None| None| None| She taunted its inhabitants, saying, "Thumb your fingernails and lift your strength to the world of armaments." end_of_card Implant| Instant| {2}{U}| common| Target player loses 1
Flamebreathing Pidgeon

Great, that sounds like it could have been written by a human. Well, maybe not quite yet. 

While in theory *Top-p* seems more elegant than *Top-K*, both methods work well in practice. *Top-p* can also be used in combination with *Top-K*, which can avoid very low-ranked words while allowing for some dynamic selection.
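
One way to picture the combination is to apply the *Top-K* cut first and the *Top-p* cut on what remains. The sketch below assumes that order; the helper name `top_k_top_p_filter` and the distribution are hypothetical and only illustrate the idea.

```python
import random


def top_k_top_p_filter(dist, k, p):
    """Apply a Top-K cut, renormalize, then apply a Top-p cut and renormalize again."""
    ranked = sorted(dist.items(), key=lambda item: item[1], reverse=True)[:k]
    total = sum(prob for _, prob in ranked)
    ranked = [(word, prob / total) for word, prob in ranked]
    kept, cumulative = [], 0.0
    for word, prob in ranked:
        kept.append((word, prob))
        cumulative += prob
        if cumulative >= p:
            break
    total = sum(prob for _, prob in kept)
    return {word: prob / total for word, prob in kept}


# Hypothetical next-word distribution (invented numbers).
NEXT_WORD_DIST = {"nice": 0.4, "dog": 0.3, "car": 0.15, "house": 0.1, "told": 0.05}
candidates = top_k_top_p_filter(NEXT_WORD_DIST, k=4, p=0.8)

rng = random.Random(0)
next_word = rng.choices(list(candidates), weights=list(candidates.values()), k=1)[0]
print(candidates, next_word)
```

Here *Top-K* removes "told" outright, and *Top-p* then trims "house" because the three most likely words already cover more than 80% of the remaining mass.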

Finally, to get multiple independently sampled outputs, we can *again* set the parameter `num_return_sequences > 1`: 

In [18]:
torch.manual_seed(0)
# set top_k = 50 and set top_p = 0.95 and num_return_sequences = 5
sample_outputs = model.generate(
    input_ids,
    do_sample=True, 
    max_length=max_length, 
    top_k=50, 
    top_p=0.95, 
    num_return_sequences=5
)

print("Output:\n" + 100 * '-')
for i, sample_output in enumerate(sample_outputs):
  print_and_add_to_results(sample_output, 'top_p_.95_k_50', True)

Output:
----------------------------------------------------------------------------------------------------
Flamebreathing Pidgeon| Creature — Bird| {6}{U}| common| Protection from red| 3| 3| None| "I will send that brood to be consumed with the venom. Its heart will rise before the blood-grinding poison. It will do the rest." line_break —Aurelia, elder end_of_card Sporeclutch| Instant| {1}{U}| common| Destroy target blue or green creature. It gains indestructible until end of turn.| None| None| None| "In his place is the fire of passion." line_break —Leonin Tarlov end_of_card Sporecrasher| Creature — Bird| {3}{U}| uncommon| Flying line_break {2}{U}: Target creature gets -1/-1 until end of turn.| 2| 4| None| "The wings of sporecrasher sting as hard as steel and are as deadly as their blood." line_break —Leonin Tarlov end_of_card Sporecrasher's Blade| Creature — Bird| {3}{U}| uncommon| Flying line_break {4}{U}: Target creature gets -3
Flamebreathing Pidgeon| Creature — Imp| {3}{G}{W}| 

In [19]:
def save_json(path, to_save):
  with open(path, 'w', encoding='utf-8') as f:
    json.dump(to_save, f)

In [20]:
save_path = './drive/My Drive/models/mtg_card_gen/outputs/'
os.makedirs(save_path, exist_ok=True)  # make sure the output folder exists
specific_save_path = os.path.join(save_path, prompt + '.json')

In [21]:
save_json(specific_save_path, results)