
# How to generate text: using different decoding methods for language generation with Transformers

(based on https://github.com/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)

### **Introduction**

In recent years, there has been an increasing interest in open-ended language generation thanks to the rise of large transformer-based language models trained on millions of webpages, such as OpenAI's famous [GPT2 model](https://openai.com/blog/better-language-models/). The results on conditioned open-ended language generation are impressive, e.g. [GPT2 on unicorns](https://openai.com/blog/better-language-models/#samples), [XLNet](https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e), [Controlled language with CTRL](https://blog.einstein.ai/introducing-a-conditional-transformer-language-model-for-controllable-generation/). Besides the improved transformer architecture and massive unsupervised training data, **better decoding methods** have also played an important role. 

This blog post gives a brief overview of different decoding strategies and more importantly shows how *you* can implement them with very little effort using the popular `transformers` library! 

All of the following functionalities can be used for **auto-regressive** language generation ([here](http://jalammar.github.io/illustrated-gpt2/) is a refresher). In short, *auto-regressive* language generation is based on the assumption that the probability distribution of a word sequence can be decomposed into the product of conditional next word distributions: 
$$ P(w_{1:T} | W_0 ) = \prod_{t=1}^T P(w_{t} | w_{1: t-1}, W_0) \text{, with } w_{1: 0} = \emptyset, $$

and $W_0$ being the initial *context* word sequence. The length $T$ of the word sequence is usually determined *on the fly* and corresponds to the time step $t=T$ at which the EOS token is generated from $P(w_{t} | w_{1: t-1}, W_{0})$.
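As a toy illustration of this factorization, the snippet below multiplies conditional next-word probabilities along a sequence. The table `cond_probs` and the helper `sequence_probability` are hypothetical (no language model is involved): the probabilities of "nice" and "woman" are taken from the greedy-search example later in this post, and the remaining values are made up so that each distribution sums to 1.

```python
# Hypothetical conditional next-word distributions, keyed by all words so far
# (context W_0 = "The" plus the generated words). Values other than
# "nice" (0.5) and "woman" (0.4) are assumed for illustration.
cond_probs = {
    ("The",): {"nice": 0.5, "dog": 0.4, "car": 0.1},
    ("The", "nice"): {"woman": 0.4, "house": 0.3, "guy": 0.3},
}

def sequence_probability(context, continuation):
    """P(w_{1:T} | W_0) as the product of P(w_t | w_{1:t-1}, W_0)."""
    prob = 1.0
    history = tuple(context)
    for word in continuation:
        prob *= cond_probs[history][word]  # conditional next-word probability
        history += (word,)                 # extend the conditioning sequence
    return prob

print(sequence_probability(["The"], ["nice", "woman"]))  # 0.2
```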


Auto-regressive language generation is now available for `GPT2`, `XLNet`, `OpenAI-GPT`, `CTRL`, `TransfoXL`, `XLM`, `Bart`, `T5` in both PyTorch and TensorFlow >= 2.0!

We will give a tour of the currently most prominent decoding methods, namely *Greedy search*, *Beam search*, *Top-K sampling* and *Top-p sampling*.


Let's quickly install `transformers` and load the model. We will use GPT2 in PyTorch for demonstration.

In [1]:
#!pip install -q git+https://github.com/huggingface/transformers.git  # the bleeding edge version 
#!pip install -q tensorflow==2.1
!pip install transformers==4.18.0


In [2]:
import torch

In [3]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# add the EOS token as PAD token to avoid warnings
model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)


### **Greedy Search**

Greedy search simply selects the word with the highest probability as its next word: $w_t = \arg\max_{w}P(w | w_{1:t-1})$ at each time step $t$. The following sketch shows greedy search. 

![Greedy Search](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/greedy_search.png)

Starting from the word $\text{"The"}$, the algorithm 
greedily chooses the next word of highest probability $\text{"nice"}$ and so on, so that the final generated word sequence is $\text{"The", "nice", "woman"}$ having an overall probability of $0.5 \times 0.4 = 0.2$.
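The greedy procedure can be sketched in a few lines of plain Python. This is a minimal illustration on a hand-written toy tree, not how `generate` is implemented: the values for "nice", "dog", "woman" and "has" come from the text, while the other words and probabilities are our own assumption.

```python
# Toy next-word distributions keyed by the sequence generated so far.
# Only "nice", "dog", "woman" and "has" are quoted in the text; the rest is assumed.
tree = {
    ("The",): {"nice": 0.5, "dog": 0.4, "car": 0.1},
    ("The", "nice"): {"woman": 0.4, "house": 0.3, "guy": 0.3},
    ("The", "dog"): {"has": 0.9, "runs": 0.05, "and": 0.05},
    ("The", "car"): {"is": 0.3, "drives": 0.5, "turns": 0.2},
}

def greedy_search(start, steps):
    seq, prob = (start,), 1.0
    for _ in range(steps):
        dist = tree[seq]
        word = max(dist, key=dist.get)  # argmax over the next-word distribution
        seq += (word,)
        prob *= dist[word]              # accumulate the sequence probability
    return list(seq), prob

print(greedy_search("The", steps=2))  # (['The', 'nice', 'woman'], 0.2)
```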

In the following we will generate word sequences using GPT2 on the context $(\text{"I", "enjoy", "walking", "with", "my", "cute", "dog"})$. Let's see how greedy search can be used in `transformers` as follows:

In [4]:
# encode context the generation is conditioned on
input_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='pt')

# generate text until the output length (which includes the context length) reaches 50
greedy_output = model.generate(input_ids, max_length=50)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(greedy_output[0], skip_special_tokens=True))

Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my dog.

I'm not sure if I'll


Alright! We have generated our first short text with GPT2 😊. The generated words following the context are reasonable, but the model quickly starts repeating itself! This is a very common problem in language generation in general and seems to be even more so in greedy and beam search - check out [Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424) and [Shao et al., 2017](https://arxiv.org/abs/1701.03185).

The major drawback of greedy search, though, is that it misses high-probability words hidden behind a low-probability word, as can be seen in our sketch above:

The word $\text{"has"}$ with its high conditional probability of $0.9$ is hidden behind the word $\text{"dog"}$, which has only the second-highest conditional probability, so that greedy search misses the word sequence $\text{"The"}, \text{"dog"}, \text{"has"}$.

Thankfully, we have beam search to alleviate this problem!


### **Beam search**

Beam search reduces the risk of missing hidden high-probability word sequences by keeping the `num_beams` most likely hypotheses at each time step and eventually choosing the hypothesis that has the overall highest probability. Let's illustrate with `num_beams=2`:

![Beam search](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/beam_search.png)

At time step $1$, besides the most likely hypothesis $\text{"The", "nice"}$, beam search also keeps track of the second most likely one $\text{"The", "dog"}$. At time step $2$, beam search finds that the word sequence $\text{"The", "dog", "has"}$ has a higher probability ($0.36$) than $\text{"The", "nice", "woman"}$ ($0.2$). Great, it has found the most likely word sequence in our toy example! 
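Below is a minimal beam-search sketch on the same kind of toy tree (again, only the probabilities quoted in the text are taken from the sketch; the other values are assumed). With `num_beams=2` it recovers $\text{"The", "dog", "has"}$, and with `num_beams=1` it reduces to greedy search. The function `beam_search` here is our own illustration, not the library's implementation.

```python
import heapq

# Toy next-word distributions; values not quoted in the text are assumed.
tree = {
    ("The",): {"nice": 0.5, "dog": 0.4, "car": 0.1},
    ("The", "nice"): {"woman": 0.4, "house": 0.3, "guy": 0.3},
    ("The", "dog"): {"has": 0.9, "runs": 0.05, "and": 0.05},
    ("The", "car"): {"is": 0.3, "drives": 0.5, "turns": 0.2},
}

def beam_search(start, steps, num_beams):
    beams = [((start,), 1.0)]  # list of (sequence, probability) hypotheses
    for _ in range(steps):
        # expand every hypothesis by every possible next word
        candidates = [
            (seq + (word,), prob * p)
            for seq, prob in beams
            for word, p in tree[seq].items()
        ]
        # keep only the num_beams most probable hypotheses
        beams = heapq.nlargest(num_beams, candidates, key=lambda c: c[1])
    best_seq, best_prob = max(beams, key=lambda c: c[1])
    return list(best_seq), best_prob

seq, prob = beam_search("The", steps=2, num_beams=2)
print(seq, round(prob, 2))  # ['The', 'dog', 'has'] 0.36
```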

Beam search will often find an output sequence with a higher probability than greedy search, but it is not guaranteed to do so, nor to find the most likely output. 

Let's see how beam search can be used in `transformers`. We set `num_beams > 1` and `early_stopping=True` so that generation is finished when all beam hypotheses have reached the EOS token.

In [None]:
# activate beam search and early_stopping
beam_output = model.generate(
    input_ids,  
    max_length=50, 
    num_beams=5, 
    early_stopping=True
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(beam_output[0], skip_special_tokens=True))

Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with him again.

I'm not sure if I'll ever be able to walk with him again. I'm not sure if I'll


While the result is arguably more fluent, the output still includes repetitions of the same word sequences.  
A simple remedy is to introduce *n-gram* penalties (an *n-gram* being a sequence of $n$ words), as proposed by [Paulus et al. (2017)](https://arxiv.org/abs/1705.04304) and [Klein et al. (2017)](https://arxiv.org/abs/1701.02810). The most common *n-gram* penalty makes sure that no *n-gram* appears twice by manually setting the probability of next words that could create an already seen *n-gram* to $0$.
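The blocking rule can be sketched as follows: take the last $n-1$ generated tokens as a prefix; any token that previously followed that same prefix would complete a repeated *n-gram* and is banned. The helpers `banned_next_tokens` and `apply_ngram_penalty` are our own simplified illustration working on words instead of token ids, not the library's code.

```python
def banned_next_tokens(tokens, n):
    """Tokens that would complete an n-gram already present in `tokens`."""
    if len(tokens) < n - 1:
        return set()
    prefix = tuple(tokens[len(tokens) - (n - 1):])  # last n-1 tokens
    banned = set()
    for i in range(len(tokens) - n + 1):
        ngram = tuple(tokens[i:i + n])
        if ngram[:-1] == prefix:  # same prefix seen before
            banned.add(ngram[-1])
    return banned

def apply_ngram_penalty(next_probs, tokens, n):
    """Set the probability of every banned next token to 0."""
    banned = banned_next_tokens(tokens, n)
    return {w: (0.0 if w in banned else p) for w, p in next_probs.items()}

words = "I enjoy walking with my dog . I enjoy".split()
print(banned_next_tokens(words, 2))  # {'walking'}
print(apply_ngram_penalty({"walking": 0.7, "running": 0.3}, words, 2))
# {'walking': 0.0, 'running': 0.3}
```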

Let's try it out by setting `no_repeat_ngram_size=2` so that no *2-gram* appears twice:

In [None]:
# set no_repeat_ngram_size to 2
beam_output = model.generate(
    input_ids, 
    max_length=50, 
    num_beams=5, 
    no_repeat_ngram_size=2, 
    early_stopping=True
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(beam_output[0], skip_special_tokens=True))

Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with him again.

I've been thinking about this for a while now, and I think it's time for me to take a break


Nice, that looks much better! We can see that the repetition does not appear anymore. Nevertheless, *n-gram* penalties have to be used with care. An article generated about the city *New York* should not use a *2-gram* penalty; otherwise, the name of the city would only appear once in the whole text!

Another important feature of beam search is that we can compare the top beams after generation and choose the generated beam that best fits our purpose. 

In `transformers`, we simply set the parameter `num_return_sequences` to the number of highest scoring beams that should be returned. Make sure though that `num_return_sequences <= num_beams`!

In [None]:
# set num_return_sequences > 1
beam_outputs = model.generate(
    input_ids, 
    max_length=50, 
    num_beams=5, 
    no_repeat_ngram_size=2, 
    num_return_sequences=5, 
    early_stopping=True
)

# now we have 5 output sequences
print("Output:\n" + 100 * '-')
for i, beam_output in enumerate(beam_outputs):
  print("{}: {}\n".format(i, tokenizer.decode(beam_output, skip_special_tokens=True)))

Output:
----------------------------------------------------------------------------------------------------
0: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with him again.

I've been thinking about this for a while now, and I think it's time for me to take a break

1: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with him again.

I've been thinking about this for a while now, and I think it's time for me to get back to

2: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with her again.

I've been thinking about this for a while now, and I think it's time for me to take a break

3: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with her again.

I've been thinking about this for a while now, and I think it's time for me to get back to

4: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with him again.

I've been thinking ab

As can be seen, the five beam hypotheses are only marginally different from each other, which should not be too surprising when using only 5 beams.

## Diverse beam search
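Diverse beam search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424)) splits the `num_beams` beams into `num_beam_groups` groups that are decoded one after another at each time step. A group is penalized for selecting tokens that earlier groups have already selected at the same step, scaled by `diversity_penalty`. The snippet below is a minimal sketch of this *Hamming diversity* term on made-up log-probabilities; the function `hamming_diverse_scores` is hypothetical and only meant to illustrate the idea.

```python
from collections import Counter

def hamming_diverse_scores(log_probs, tokens_chosen_by_earlier_groups, diversity_penalty):
    """Subtract diversity_penalty once for each time an earlier group picked a token at this step."""
    counts = Counter(tokens_chosen_by_earlier_groups)
    return {tok: lp - diversity_penalty * counts[tok] for tok, lp in log_probs.items()}

# Made-up next-token log-probabilities for a beam in the second group.
log_probs = {"but": -0.5, "and": -1.0, ".": -1.2}
# Suppose the first group already picked "but" at this time step.
scores = hamming_diverse_scores(log_probs, ["but"], diversity_penalty=1.0)
best = max(scores, key=scores.get)
print(scores)  # {'but': -1.5, 'and': -1.0, '.': -1.2}
print(best)    # and
```

Without the penalty, both groups would continue with "but"; with it, the second group switches to "and".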

In [None]:
# set num_return_sequences > 1 and activate diverse beam search
beam_outputs = model.generate(
    input_ids, 
    max_length=50, 
    num_beams=5, 
    no_repeat_ngram_size=2, 
    num_return_sequences=5, 
    early_stopping=True,
    num_beam_groups=5, # this must be a divisor of num_beams
    diversity_penalty=1.0,
)

print("Output:\n" + 100 * '-')
for i, beam_output in enumerate(beam_outputs):
  print("{}: {}\n".format(i, tokenizer.decode(beam_output, skip_special_tokens=True)))



Output:
----------------------------------------------------------------------------------------------------
0: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with him again.

I'm a big fan of the "I love you" sign, and I love the fact that it's a

1: I enjoy walking with my cute dog and I'm always looking for a place to go. I've been to a lot of places and it's always been a great experience.

I've always wanted to be a veterinarian. My parents were both

2: I enjoy walking with my cute dog. I love to play with her and she loves to be with me. She loves being with us and I'm happy to have her around.

I love my dog and her love for me and my family

3: I enjoy walking with my cute dog. He's a great dog and I love to play with him. I'm a big fan of his.

I love my dog, and he's my best friend. We're both very happy and happy

4: I enjoy walking with my cute dog, and I love to play with her. I'm also a big fan of the new "Pony" series, which is a little 

In open-ended generation, a couple of reasons have recently been brought forward why beam search might not be the best possible option:

- Beam search can work very well in tasks where the length of the desired generation is more or less predictable as in machine translation or summarization - see [Murray et al. (2018)](https://arxiv.org/abs/1808.10006) and [Yang et al. (2018)](https://arxiv.org/abs/1808.09582). But this is not the case for open-ended generation where the desired output length can vary greatly, e.g. dialog and story generation.

- We have seen that beam search heavily suffers from repetitive generation. This is especially hard to control with *n-gram*- or other penalties in story generation since finding a good trade-off between forced "no-repetition" and repeating cycles of identical *n-grams* requires a lot of finetuning.

- As argued in [Ari Holtzman et al. (2019)](https://arxiv.org/abs/1904.09751), high-quality human language does not follow a distribution of high-probability next words. In other words, as humans, we want generated text to surprise us and not to be boring or predictable. The authors show this nicely by plotting the probability a model assigns to human text vs. to text generated by beam search.

![alt text](https://blog.fastforwardlabs.com/images/2019/05/Screen_Shot_2019_05_08_at_3_06_36_PM-1557342561886.png)


So let's stop being boring and introduce some randomness 🤪.

## **Sampling**

In its most basic form, sampling means randomly picking the next word $w_t$ according to its conditional probability distribution:

$$w_t \sim P(w|w_{1:t-1})$$

Taking the example from above, the following graphic visualizes language generation when sampling.

![vanilla_sampling](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/sampling_search.png)

It becomes obvious that language generation using sampling is not *deterministic* anymore. The word $\text{"car"}$ is sampled from the conditional probability distribution $P(w | \text{"The"})$, followed by sampling $\text{"drives"}$ from $P(w | \text{"The"}, \text{"car"})$.
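Sampling one word from a categorical distribution is a one-liner with Python's `random.choices`. The sketch below draws many samples from a toy $P(w | \text{"The"})$ to show that the empirical frequencies approach the given probabilities. It reuses made-up values `"nice": 0.5, "dog": 0.4, "car": 0.1`, of which only the first two appear in the text; `sample_next_word` is a hypothetical helper.

```python
import random

# Toy distribution P(w | "The"); the value for "car" is assumed.
next_word_probs = {"nice": 0.5, "dog": 0.4, "car": 0.1}

def sample_next_word(probs, rng):
    """Draw one word w ~ P(w | context)."""
    words = list(probs)
    weights = [probs[w] for w in words]
    return rng.choices(words, weights=weights, k=1)[0]

rng = random.Random(0)  # fixed seed for reproducibility
samples = [sample_next_word(next_word_probs, rng) for _ in range(10_000)]
for w in next_word_probs:
    print(w, samples.count(w) / len(samples))  # roughly 0.5, 0.4 and 0.1
```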

In `transformers`, we set `do_sample=True` and deactivate *Top-K* sampling (more on this later) via `top_k=0`. In the following, we fix the random seed with `torch.random.manual_seed` for illustration purposes. Feel free to change the seed to play around with the model.


In [None]:
# set seed to reproduce results. Feel free to change the seed though to get different results
torch.random.manual_seed(1)

# activate sampling and deactivate top_k by setting top_k sampling to 0
sample_output = model.generate(
    input_ids, 
    do_sample=True, 
    max_length=50, 
    top_k=0
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))

Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog when there are other people around, though.

No, ladies, enjoying your dog and publicly embracing her is not my thing. It doesn't even bother me, woman-like. I'm happy you think


Interesting! The text seems alright, but when taking a closer look, it is not very coherent. Phrases such as *It doesn't even bother me, woman-like* are very weird and don't sound like they were written by a human. That is the big problem when sampling word sequences: the models often generate incoherent gibberish, *cf.* [Ari Holtzman et al. (2019)](https://arxiv.org/abs/1904.09751).

A trick is to make the distribution $P(w|w_{1:t-1})$ sharper (increasing the likelihood of high probability words and decreasing the likelihood of low probability words) by lowering the so-called `temperature` of the [softmax](https://en.wikipedia.org/wiki/Softmax_function#Smooth_arg_max). 

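$$ q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} $$

where $z_i$ is the logit of word $i$ and $T$ is the temperature (not to be confused with the sequence length $T$ above). $T < 1$ sharpens the distribution, while $T > 1$ flattens it.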

An illustration of applying temperature to our example from above could look as follows.

![sampling_with_temperature](https://github.com/patrickvonplaten/scientific_images/blob/master/sampling_search_with_temp.png?raw=true)

The conditional next word distribution of step $t=1$ becomes much sharper, leaving almost no chance for the word $\text{"car"}$ to be selected.
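The temperature-scaled softmax can be written directly from the formula above. In this sketch the logits are made-up numbers and `softmax_with_temperature` is our own helper; printing the distribution for a few temperatures shows the probability mass concentrating on the largest logit as $T$ decreases.

```python
import math

def softmax_with_temperature(logits, temperature):
    """q_i = exp(z_i / T) / sum_j exp(z_j / T), computed in a numerically stable way."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtracting the max avoids overflow and leaves q unchanged
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.1]  # made-up logits for three candidate words
for t in (1.0, 0.7, 0.1):
    print(t, [round(q, 3) for q in softmax_with_temperature(logits, t)])
```

As $T \to 0$, the output approaches a one-hot vector on the largest logit, which is why temperature-scaled sampling then behaves like greedy decoding.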


Let's see how we can cool down the distribution in the library by setting `temperature=0.7`:

In [None]:
# set seed to reproduce results. Feel free to change the seed though to get different results
torch.random.manual_seed(0)

# use temperature to decrease the sensitivity to low probability candidates
sample_output = model.generate(
    input_ids, 
    do_sample=True, 
    max_length=50, 
    top_k=0, 
    temperature=0.7
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))

Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog," she said. "He has a lot of aggression and eventually gets aggressive and starts barking at you. So I just make sure I'm smart enough to find a safe place to stop and look for him. It


OK. There are fewer weird n-grams and the output is a bit more coherent now! While applying temperature can make a distribution less random, in the limit `temperature` $\to 0$, temperature-scaled sampling becomes equal to greedy decoding and will suffer from the same problems as before. 



## **Top-K Sampling**

[Fan et al. (2018)](https://arxiv.org/pdf/1805.04833.pdf) introduced a simple, but very powerful sampling scheme, called ***Top-K*** sampling. In *Top-K* sampling, the *K* most likely next words are filtered and the probability mass is redistributed among only those *K* next words. 
GPT2 adopted this sampling scheme, which was one of the reasons for its success in story generation. 

We extend the range of words used for both sampling steps in the example above from 3 words to 10 words to better illustrate *Top-K* sampling.

![top_k_sampling](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/top_k_sampling.png)

Having set $K = 6$, in both sampling steps we limit our sampling pool to 6 words. While the 6 most likely words, defined as $V_{\text{top-K}}$, encompass only *ca.* two-thirds of the whole probability mass in the first step, they include almost all of the probability mass in the second step. Nevertheless, we see that it successfully eliminates the rather weird candidates $\text{"not", "the", "small", "told"}$ in the second sampling step.
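*Top-K* filtering itself is short: sort the candidates, keep the first $K$, and renormalize. Here is a minimal sketch on a made-up distribution; the function `top_k_filter` is our own illustration and works on probabilities, whereas `generate` operates on logits.

```python
def top_k_filter(probs, k):
    """Keep the k most likely words and renormalize their probabilities."""
    top = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:k]
    mass = sum(p for _, p in top)  # probability mass covered by the top-k words
    return {w: p / mass for w, p in top}

probs = {"a": 0.4, "b": 0.3, "c": 0.15, "d": 0.1, "e": 0.05}  # made-up values
filtered = top_k_filter(probs, k=3)
print({w: round(p, 3) for w, p in filtered.items()})
```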


Let's see how *Top-K* can be used in the library by setting `top_k=50`:

In [None]:
# set seed to reproduce results. Feel free to change the seed though to get different results
torch.random.manual_seed(0)

# set top_k to 50
sample_output = model.generate(
    input_ids, 
    do_sample=True, 
    max_length=50, 
    top_k=50
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))

Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog," she says. "You get a lot of love and support out of it. It has helped me to be open and see why and what I have to do to be successful."

I'd say the


Not bad at all! The text is arguably the most *human-sounding* text so far. 
One concern though with *Top-K* sampling is that it does not dynamically adapt the number of words that are filtered from the next word probability distribution $P(w|w_{1:t-1})$.
This can be problematic as some words might be sampled from a very sharp distribution (distribution on the right in the graph above), whereas others are sampled from a much flatter distribution (distribution on the left in the graph above). 

In step $t=1$, *Top-K* eliminates the possibility of sampling $\text{"people", "big", "house", "cat"}$, which seem like reasonable candidates. On the other hand, in step $t=2$ the method includes the arguably ill-fitted words $\text{"down", "a"}$ in the sample pool of words. Thus, limiting the sample pool to a fixed size *K* could make the model produce gibberish for sharp distributions and limit the model's creativity for flat distributions.
This intuition led [Ari Holtzman et al. (2019)](https://arxiv.org/abs/1904.09751) to create ***Top-p***- or ***nucleus***-sampling. 



## **Top-p (nucleus) sampling**

Instead of sampling only from the *K* most likely words, *Top-p* sampling chooses from the smallest possible set of words whose cumulative probability exceeds the probability *p*. The probability mass is then redistributed among this set of words. This way, the size of the set of words (*a.k.a.* the number of words in the set) can dynamically increase and decrease according to the next word's probability distribution. Ok, that was very wordy, let's visualize.

![top_p_sampling](https://github.com/patrickvonplaten/scientific_images/blob/master/top_p_sampling.png?raw=true)

Having set $p=0.92$, *Top-p* sampling picks the *minimum* number of words that together exceed $p=92\%$ of the probability mass, defined as $V_{\text{top-p}}$. In the first example, this includes the 9 most likely words, whereas it only has to pick the top 3 words in the second example to exceed 92%. Quite simple actually! It can be seen that it keeps a wide range of words where the next word is arguably less predictable, *e.g.* $P(w | \text{"The"})$, and only a few words when the next word seems more predictable, *e.g.* $P(w | \text{"The", "car"})$.
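The nucleus can be computed by sorting the words by probability and accumulating until the running sum exceeds $p$. The sketch below uses two made-up distributions, one flat and one sharp, to show that the size of the kept set adapts; `top_p_filter` is a hypothetical helper, not the library's implementation.

```python
def top_p_filter(probs, p):
    """Keep the smallest set of most likely words whose cumulative probability exceeds p."""
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)
    kept, cumulative = [], 0.0
    for word, prob in ranked:
        kept.append((word, prob))
        cumulative += prob
        if cumulative > p:  # stop as soon as the nucleus exceeds p
            break
    mass = sum(prob for _, prob in kept)
    return {w: prob / mass for w, prob in kept}  # redistribute the probability mass

flat = {"a": 0.3, "b": 0.25, "c": 0.2, "d": 0.15, "e": 0.07, "f": 0.03}
sharp = {"x": 0.85, "y": 0.1, "z": 0.05}
print(len(top_p_filter(flat, 0.92)), len(top_p_filter(sharp, 0.92)))  # 5 2
```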

Alright, time to check it out in `transformers`!
We activate *Top-p* sampling by setting `0 < top_p < 1`:

In [None]:
# set seed to reproduce results. Feel free to change the seed though to get different results
torch.random.manual_seed(3)

# deactivate top_k sampling and sample only from 92% most likely words
sample_output = model.generate(
    input_ids, 
    do_sample=True, 
    max_length=50, 
    top_p=0.92, 
    top_k=0
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))

Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog and traveling." The owner

Jillian said, "Our pet started showing us signs of anxiety and suicide last week and he's now all swollen. But then he started showing signs of freedom for a change."


Great, that sounds like it could have been written by a human. Well, maybe not quite yet. 

While in theory *Top-p* seems more elegant than *Top-K*, both methods work well in practice. *Top-p* can also be used in combination with *Top-K*, which can avoid very low-ranked words while allowing for some dynamic selection.
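Combining the two filters can be sketched as follows, assuming (as we understand `generate` to do) that *Top-K* is applied first and *Top-p* then operates on the renormalized *Top-K* distribution. The function `top_k_then_top_p` and the toy probabilities are hypothetical, chosen so that each filter removes some words.

```python
import random

def top_k_then_top_p(probs, k, p):
    """Apply Top-K filtering, then Top-p filtering on the renormalized result."""
    # Top-K: keep the k most likely words.
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:k]
    total = sum(pr for _, pr in ranked)
    # Top-p: keep the smallest prefix whose renormalized cumulative mass exceeds p.
    kept, cumulative = [], 0.0
    for word, pr in ranked:
        kept.append((word, pr / total))
        cumulative += pr / total
        if cumulative > p:
            break
    mass = sum(pr for _, pr in kept)
    return {w: pr / mass for w, pr in kept}

# Made-up distribution with a long tail of low-ranked words.
probs = {"a": 0.5, "b": 0.2, "c": 0.1, "d": 0.08, "e": 0.06, "f": 0.04, "g": 0.02}
filtered = top_k_then_top_p(probs, k=4, p=0.9)
print(sorted(filtered))  # ['a', 'b', 'c']

rng = random.Random(0)  # sample the next word from the filtered distribution
word = rng.choices(list(filtered), weights=list(filtered.values()), k=1)[0]
print(word)
```

Here *Top-K* drops the tail words "e", "f" and "g", and *Top-p* additionally drops "d".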

Finally, to get multiple independently sampled outputs, we can *again* set the parameter `num_return_sequences > 1`: 

In [None]:
# set seed to reproduce results. Feel free to change the seed though to get different results
torch.random.manual_seed(0)

# set top_k = 50 and set top_p = 0.95 and num_return_sequences = 3
sample_outputs = model.generate(
    input_ids,
    do_sample=True, 
    max_length=50, 
    top_k=50, 
    top_p=0.95, 
    num_return_sequences=3
)

print("Output:\n" + 100 * '-')
for i, sample_output in enumerate(sample_outputs):
  print("{}: {}".format(i, tokenizer.decode(sample_output, skip_special_tokens=True)))

Output:
----------------------------------------------------------------------------------------------------
0: I enjoy walking with my cute dog," she says. "You get a lot of love and support out of it. It has helped me to be open and see what's really cool. I'm happy to see people are supporting my cause and just
1: I enjoy walking with my cute dog. I would also like to see a new feature for our cats, the cute bear, that is called 'Spend Your Sunday, Beating Dogs, by Feeding Dogs'.

Please see our page for
2: I enjoy walking with my cute dog, but I would definitely encourage anyone that will play around with your dog's ears to use a bit of patience and patience.

The dog's ears should be removed right away. After they are gone from the


Cool, now you should have all the tools to let your model write your stories with `transformers`!

## Constrained generation

For better readability, we will use end-of-line as the EOS token, instead of just always generating 50 tokens. 

In [6]:
END_OF_LINE = tokenizer('\n').input_ids[0]
print(END_OF_LINE)

198


The baseline beam search produces very similar sentences; all of them contain "not sure" or "don't think".

In [None]:
beam_outputs = model.generate(
    input_ids, 
    max_length=50, 
    num_beams=5, 
    no_repeat_ngram_size=2, 
    num_return_sequences=5, 
    early_stopping=True,
    eos_token_id=END_OF_LINE,
)

print("Output:\n" + 100 * '-')
for i, beam_output in enumerate(beam_outputs):
  print("{}: {}".format(i, tokenizer.decode(beam_output, skip_special_tokens=True)))

Output:
----------------------------------------------------------------------------------------------------
0: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with him again.

1: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with her again.

2: I enjoy walking with my cute dog, but I don't think I'll ever be able to walk with her again.

3: I enjoy walking with my cute dog, but I don't think I'll ever be able to walk with him again.

4: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with him again. I don't know what to do.



What would happen if we forbid the model from using these words?

In [None]:
beam_outputs = model.generate(
    input_ids, 
    max_length=50, 
    num_beams=5, 
    no_repeat_ngram_size=2, 
    num_return_sequences=5, 
    early_stopping=True,
    eos_token_id=END_OF_LINE,
    bad_words_ids=tokenizer(['sure', 'think', 'thundersnatch'], add_prefix_space=True)['input_ids'],
)

print("Output:\n" + 100 * '-')
for i, beam_output in enumerate(beam_outputs):
  print("{}: {}".format(i, tokenizer.decode(beam_output, skip_special_tokens=True)))

Output:
----------------------------------------------------------------------------------------------------
0: I enjoy walking with my cute dog, but I don't like to walk alone.

1: I enjoy walking with my cute dog, but I don't want to have to go through the hassle of going to the vet.

2: I enjoy walking with my cute dog, but I don't want to have to go through the hassle of going to the vet to get a new dog.

3: I enjoy walking with my cute dog, but I don't want to have to go through the hassle of going to the vet to get a new one.

4: I enjoy walking with my cute dog, but I don't want to have to go through the hassle of going to the vet to see if my dog is sick.
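Conceptually, banning a word means making it impossible to generate the token that would complete a banned token sequence. The helper `ban_bad_words` below is our own simplified sketch of that rule; to our understanding, `bad_words_ids` does the equivalent on the logits by setting them to $-\infty$. Words stand in for token ids, the split of "thundersnatch" into two tokens is hypothetical, and we skip renormalization for brevity.

```python
def ban_bad_words(next_probs, generated, bad_words):
    """Zero out any next token that would complete a banned token sequence."""
    banned = set()
    for seq in bad_words:
        prefix, last = seq[:-1], seq[-1]
        # A single-token bad word is always banned; a multi-token one only
        # when the tokens generated so far end with the rest of the sequence.
        if not prefix or generated[-len(prefix):] == prefix:
            banned.add(last)
    return {tok: (0.0 if tok in banned else p) for tok, p in next_probs.items()}

bad_words = [["sure"], ["thunder", "snatch"]]  # hypothetical tokenization
print(ban_bad_words({"sure": 0.6, "certain": 0.3, "thunder": 0.1}, ["I", "am", "not"], bad_words))
# {'sure': 0.0, 'certain': 0.3, 'thunder': 0.1}
print(ban_bad_words({"snatch": 0.7, "storm": 0.3}, ["a", "thunder"], bad_words))
# {'snatch': 0.0, 'storm': 0.3}
```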



In [8]:
beam_outputs = model.generate(
    input_ids, 
    max_length=50, 
    num_beams=5, 
    no_repeat_ngram_size=2, 
    num_return_sequences=3,         # number of generated sentences
    early_stopping=True,
    eos_token_id=END_OF_LINE,
    bad_words_ids=tokenizer(['sure', 'think', 'thundersnatch', 'but', 'love'], add_prefix_space=True)['input_ids'],
)

print("Output:\n" + 100 * '-')
for i, beam_output in enumerate(beam_outputs):
  print("{}: {}".format(i, tokenizer.decode(beam_output, skip_special_tokens=True)))

Output:
----------------------------------------------------------------------------------------------------
0: I enjoy walking with my cute dog.

1: I enjoy walking with my cute dog, and I'm always looking for a place to stay.

2: I enjoy walking with my cute dog, and I'm always looking for a place to sit and play with him.



Why `add_prefix_space`? Because the byte-level BPE tokenization used by GPT2 attaches the preceding space to the next word, and this changes the token: 

In [None]:
tokenizer(['sure', ' sure', ' I am not sure'])['input_ids']

[[19532], [1654], [314, 716, 407, 1654]]

Going back to the outputs above, we can see that banning a few words changed the meaning of the generated texts a lot, but in a rather unpredictable way. 

Can we force the model to write a text involving cats?

In [None]:
beam_outputs = model.generate(
    input_ids, 
    max_length=50, 
    num_beams=5, 
    no_repeat_ngram_size=2, 
    num_return_sequences=5, 
    early_stopping=True,
    eos_token_id=END_OF_LINE,
    bad_words_ids=tokenizer(['sure', 'think'], add_prefix_space=True)['input_ids'],
    force_words_ids=[tokenizer(['cat'], add_prefix_space=True, add_special_tokens=False).input_ids],
)

print("Output:\n" + 100 * '-')
for i, beam_output in enumerate(beam_outputs):
  print("{}: {}".format(i, tokenizer.decode(beam_output, skip_special_tokens=True)))

Output:
----------------------------------------------------------------------------------------------------
0: I enjoy walking with my cute dog, but I'm not a cat person.

1: I enjoy walking with my cute dog, but I'm not a cat person."

2: I enjoy walking with my cute dog, but I'm not a cat person, so I don't know what to do with him."

3: I enjoy walking with my cute dog, but I'm not a cat person, so I don't know what to do with him. He's my best friend."

4: I enjoy walking with my cute dog, but I'm not a cat person, so I don't know what to do with him. He's my best friend.



A clarification: **force_words_ids** is a list of constraints. Each constraint is a list of expressions, at least one of which must be included in the generated text. And each expression is just a list of tokens. 
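To make the nesting concrete, the sketch below builds a `force_words_ids`-style structure from made-up token ids and checks whether a token sequence satisfies it. For instance, `force_words_ids=[tokenizer(['cat'], ...).input_ids]` in the cell above is a single constraint containing a single one-token expression. The ids and the helpers `contains` and `satisfies` are hypothetical; the real ids come from the tokenizer, and the library enforces the constraints during beam search rather than checking them afterwards.

```python
# Hypothetical token ids, for illustration only (not real GPT2 ids).
force_words_ids = [
    [[101], [102], [103, 104]],  # constraint 1: "cat" OR "cats" OR a two-token "kitten"
    [[201], [202]],              # constraint 2: "mouse" OR "mice"
]

def contains(tokens, expression):
    """True if `expression` appears as a contiguous run inside `tokens`."""
    n = len(expression)
    return any(tokens[i:i + n] == expression for i in range(len(tokens) - n + 1))

def satisfies(tokens, constraints):
    """True if every constraint has at least one of its expressions in `tokens`."""
    return all(any(contains(tokens, expr) for expr in constraint) for constraint in constraints)

print(satisfies([5, 103, 104, 7, 202], force_words_ids))  # True
print(satisfies([5, 101, 7], force_words_ids))            # False (no "mouse"/"mice")
```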

See the discussion in [the HF GitHub issue](https://github.com/huggingface/transformers/issues/14081), or read the paper "[Guided Generation of Cause and Effect](https://www.ijcai.org/proceedings/2020/0502.pdf)" by Li et al., where the algorithm was proposed.  

To evaluate the power of these constraints, let us force the model to include a mouse (or even many mice) in the text. We can also relax the "cat" constraint by allowing the words "cats", "kitten" or "feline" instead.

In [None]:
beam_outputs = model.generate(
    input_ids, 
    max_length=50, 
    num_beams=5, 
    no_repeat_ngram_size=2, 
    num_return_sequences=5, 
    early_stopping=True,
    eos_token_id=END_OF_LINE,
    bad_words_ids=tokenizer(['sure', 'think'], add_prefix_space=True)['input_ids'],
    force_words_ids = [
        tokenizer(['cat', 'cats', 'kitten', 'feline', 'Cat', 'Cats'], add_prefix_space=True, add_special_tokens=False).input_ids,
        tokenizer(['mouse', 'mice'], add_prefix_space=True, add_special_tokens=False).input_ids,
    ],
)

print("Output:\n" + 100 * '-')
for i, beam_output in enumerate(beam_outputs):
  print("{}: {}".format(i, tokenizer.decode(beam_output, skip_special_tokens=True)))

Output:
----------------------------------------------------------------------------------------------------
0: I enjoy walking with my cute dog. I love feline companionship and I like mice."

1: I enjoy walking with my cute dog. I love feline companionship and I like mice.

2: I enjoy walking with my cute dog. I love feline companionship and I like mice and cats.

3: I enjoy walking with my cute dog. I love feline companionship and I like to mouse my way around the house.

4: I enjoy walking with my cute dog. I love feline companionship and I like to mouse my way through the world.



The texts satisfy the constraints and look fluent. Still, the model has partly fooled us: in the last two outputs it uses "mouse" as a verb instead of referring to the animal.

# Conclusion and appendix

## **Conclusion**

As *ad-hoc* decoding methods, *top-p* and *top-K* sampling seem to produce more fluent text than traditional *greedy* and *beam* search on open-ended language generation. 
Recently, however, there has been more evidence that the apparent flaws of *greedy* and *beam* search (mainly generating repetitive word sequences) are caused by the model (especially the way the model is trained) rather than by the decoding method, *cf.* [Welleck et al. (2019)](https://arxiv.org/pdf/1908.04319.pdf). Also, as demonstrated in [Welleck et al. (2020)](https://arxiv.org/abs/2002.02492), *top-K* and *top-p* sampling appear to suffer from generating repetitive word sequences as well.

In [Welleck et al. (2019)](https://arxiv.org/pdf/1908.04319.pdf), the authors show that, according to human evaluations, *beam* search can generate more fluent text than *top-p* sampling when the model's training objective is adapted.
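To compare decoding methods on the repetition problem discussed above, it helps to put a number on it. A minimal sketch, assuming whitespace tokenization: the share of n-grams in a continuation that duplicate an earlier n-gram. This is similar in spirit to the seq-rep-n metric used by Welleck et al. (2019); the function names here are our own.

```python
from typing import List, Sequence, Tuple


def ngrams(tokens: Sequence[str], n: int) -> List[Tuple[str, ...]]:
    """All contiguous n-grams of a token sequence."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def repeated_ngram_ratio(tokens: Sequence[str], n: int = 4) -> float:
    """Fraction of n-grams that duplicate an earlier one (0.0 = no repetition)."""
    grams = ngrams(tokens, n)
    if not grams:
        return 0.0
    return 1.0 - len(set(grams)) / len(grams)


looping = "I like my dog . I like my dog . I like my dog .".split()
varied = "I like my dog because he is always happy to see me .".split()
print(repeated_ngram_ratio(looping, n=2))  # about 0.64
print(repeated_ngram_ratio(varied, n=2))   # 0.0
```

Applied to the decoded outputs of different `generate` settings, a higher ratio indicates more looping text.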

Open-ended language generation is a rapidly evolving field of research, and, as is often the case, there is no one-size-fits-all method here, so it is worth checking what works best for your specific use case.

Fortunately, *you* can try out all the different decoding methods in `transformers` 🤗. 

That was a short introduction on how to use different decoding methods in `transformers` and recent trends in open-ended language generation. 

Feedback and questions are very welcome on the [GitHub repository](https://github.com/huggingface/transformers).

For more fun generating stories, please take a look at [Writing with Transformers](https://transformer.huggingface.co).

Thanks to everybody who has contributed to the blog post: Alexander Rush, Julien Chaumond, Thomas Wolf, Victor Sanh, Sam Shleifer, Clément Delangue, Yacine Jernite, Oliver Åstrand and John de Wasseige.


## **Appendix**

There are a couple of additional parameters for the `generate` method that were not mentioned above. We will explain them here briefly!

- `min_length` can be used to force the model not to produce an EOS token (i.e. not to finish the sentence) before `min_length` is reached. This is used quite frequently in summarization, but can be useful in general if the user wants longer outputs.
- `repetition_penalty` can be used to penalize words that were already generated or belong to the context. It was first introduced by [Keskar et al. (2019)](https://arxiv.org/abs/1909.05858) and is also used in the training objective in [Welleck et al. (2019)](https://arxiv.org/pdf/1908.04319.pdf). It can be quite effective at preventing repetitions, but seems to be very sensitive to different models and use cases, *e.g.* see this [discussion](https://github.com/huggingface/transformers/pull/2303) on GitHub.
- `attention_mask` can be used to mask padded tokens.
- `pad_token_id`, `bos_token_id`, `eos_token_id`: if the model does not have those tokens by default, the user can manually choose other token ids to represent them.
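Roughly, `min_length` and `repetition_penalty` both work by adjusting the next-token scores before a token is picked. A pure-Python sketch of that idea on a toy list of scores; the function names are hypothetical, and the real `transformers` implementation operates on batched tensors, so details may differ between versions.

```python
import math
from typing import List, Sequence


def apply_min_length(logits: List[float], cur_len: int, min_length: int, eos_token_id: int) -> List[float]:
    """Forbid EOS (score -inf) until the sequence has reached `min_length` tokens."""
    logits = list(logits)
    if cur_len < min_length:
        logits[eos_token_id] = -math.inf
    return logits


def apply_repetition_penalty(logits: List[float], previous_ids: Sequence[int], penalty: float) -> List[float]:
    """CTRL-style penalty that makes every already-seen token less likely.

    Positive scores are divided by `penalty` and negative ones multiplied by it,
    so with penalty > 1 the score of a seen token always moves down.
    """
    logits = list(logits)
    for token_id in set(previous_ids):
        score = logits[token_id]
        logits[token_id] = score * penalty if score < 0 else score / penalty
    return logits


scores = [2.0, -1.0, 0.5, 3.0]  # toy vocabulary of 4 tokens; id 3 plays the role of EOS
print(apply_min_length(scores, cur_len=2, min_length=5, eos_token_id=3))
# [2.0, -1.0, 0.5, -inf]
print(apply_repetition_penalty(scores, previous_ids=[0, 1, 1], penalty=2.0))
# [1.0, -2.0, 0.5, 3.0]
```

The sign-dependent rule matters: simply dividing every score by the penalty would make a negative score *larger*, i.e. reward the repetition.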

For more information please also look into the `generate` function [docstring](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.TFPreTrainedModel.generate).

# Chatbots

In [None]:
def respond_to_dialog(texts):
    # the dialog is formatted as alternating "x:" (user) and "y:" (bot) lines
    prefix = '\nx:'
    for i, t in enumerate(texts):
        prefix += t
        prefix += '\nx:' if i % 2 == 1 else '\ny:'
    tokens = tokenizer(prefix, return_tensors='pt').to(model.device)
    end_token_id = tokenizer.encode('\n')[0]  # stop at the end of the bot's line
    size = tokens['input_ids'].shape[1]
    output = model.generate(
        **tokens,
        eos_token_id=end_token_id,
        do_sample=True,
        max_length=size + 32,
        repetition_penalty=3.2,
        temperature=1,
        num_beams=1,
        pad_token_id=tokenizer.eos_token_id,
    )
    # decode only the newly generated tokens and drop special tokens such as <|endoftext|>
    result = tokenizer.decode(output[0][size:], skip_special_tokens=True)
    return result.strip()
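`respond_to_dialog` builds its prompt inline, and the loop below passes `history[-10:]`, an even number of turns. The five rounds below never exceed ten turns, but in a longer conversation the truncated history would start with a bot turn labeled `x:`, flipping the speakers. A standalone sketch of the same prompt format with an odd turn limit (the name `build_dialog_prompt` and the default of 9 turns are our own choices; it assumes the history has an odd length, i.e. ends with a user turn):

```python
from typing import List


def build_dialog_prompt(turns: List[str], max_turns: int = 9) -> str:
    """Format alternating user (x) / bot (y) turns, ending with an open 'y:' line.

    `max_turns` is forced to be odd so that truncation always keeps a user turn first.
    """
    if max_turns % 2 == 0:
        max_turns -= 1
    max_turns = max(max_turns, 1)
    turns = turns[-max_turns:]
    prompt = '\nx:'
    for i, turn in enumerate(turns):
        prompt += turn
        prompt += '\nx:' if i % 2 == 1 else '\ny:'
    return prompt


print(repr(build_dialog_prompt(['Hello!', 'Hi there.', 'How are you?'])))
# '\nx:Hello!\ny:Hi there.\nx:How are you?\ny:'
```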

In [None]:
seed = input('Start the dialog by saying anything to the bot:\n')
history = [seed]
for i in range(5):
    result = respond_to_dialog(history[-10:])
    next_sentence = input(result + '\n')
    history.append(result)
    history.append(next_sentence)

Start the dialog by saying anything to the bot:
Hello!
I guess it's time to play. I'm doing a game for my wife on Facebook that'll let me tell the stories of those who survived such atrocities with
Yeh
And they're all happy…if you want and there were plenty people in our corner, why don't y? Maybe we can turn into an empire now or
oh
We have no idea what happened anymore..because then Y'all is back too soon :( XXXXXX<|endoftext|>Hackers at Cisco say security weaknesses could allow attackers
hi
If even someone has access from something like Netflix without having ever accessed anything but their webcam (which may be more difficult than other means)...why not go nuts
hi
Cisco routers are hard wired which makes them very good as well http://blog/tech-hacking--that happens so often...the only way out was
hi


# Controlling generation


Let's try to solve a seq2seq task with a GPT model: text detoxification (i.e. paraphrasing toxic text in less toxic words).

We will use the dataset from https://github.com/skoltech-nlp/parallel_detoxification_dataset

In [None]:
!pip install transformers sacrebleu -q

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# add the EOS token as PAD token to avoid warnings
model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id).cuda()

Let's use the end-of-line character as the EOS token.

In [None]:
END_OF_LINE = tokenizer('\n').input_ids[0]
print(END_OF_LINE)

198


In [None]:
import pandas as pd
data = pd.read_csv('https://raw.githubusercontent.com/skoltech-nlp/parallel_detoxification_dataset/main/parallel_detoxification_dataset_small.tsv', sep='\t')

In [None]:
pd.options.display.max_colwidth = 300
data.sample(3)

|  | toxic_comment | civil_comment |
|---|---|---|
| 2156 | they are worse than dirty sicilians . | They are very poor even than unsuitable sicilians. |
| 1279 | it s refreshing to see someone who is so fucking smart and edgy . | it's refreshing to see someone who is so smart and edgy . |
| 978 | i couldn t even tell what the fuck it s saying . | i couldnt even tell its saying |


In [None]:
data.describe()

|  | toxic_comment | civil_comment |
|---|---|---|
| count | 2778 | 2778 |
| unique | 1108 | 2778 |
| top | yurope is fucking awesome that way , yeah s . | or the loud one - thousand ton beast roaring towards you howling its horn . |
| freq | 4 | 1 |


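To evaluate the paraphrasing algorithms, we need a test set. Because one toxic comment in the dataset may have several paraphrases, we split by unique toxic comment, so that all paraphrases of a comment end up on the same side of the split and no test inputs leak into the training data.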
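The key point is that the split is made over *unique* toxic comments, and only then are the rows (with all their paraphrases) assigned to train or test. The same idea in plain Python; the helper `split_by_source` is hypothetical and uses its own shuffling, so it will not reproduce the exact `train_test_split` split in the next cell.

```python
import random
from typing import List, Set, Tuple

Pair = Tuple[str, str]  # (toxic_comment, civil_comment)


def split_by_source(pairs: List[Pair], test_size: int, seed: int = 1) -> Tuple[List[Pair], List[Pair]]:
    """Split paraphrase pairs so that each toxic source lands on only one side."""
    sources = sorted({src for src, _ in pairs})  # unique toxic comments
    random.Random(seed).shuffle(sources)
    test_sources: Set[str] = set(sources[:test_size])
    train = [p for p in pairs if p[0] not in test_sources]
    test = [p for p in pairs if p[0] in test_sources]
    return train, test


pairs = [('a', 'a1'), ('a', 'a2'), ('b', 'b1'), ('c', 'c1'), ('c', 'c2')]
train, test = split_by_source(pairs, test_size=1)
print(len(train), len(test))
```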

In [None]:
from sklearn.model_selection import train_test_split
train_texts, test_texts = train_test_split(data.toxic_comment.drop_duplicates(), random_state=1, test_size=100)
train_texts, test_texts = set(train_texts), set(test_texts)

train_data = data[data.toxic_comment.apply(lambda x: x in train_texts)]
test_data = data[data.toxic_comment.apply(lambda x: x in test_texts)]
print(train_data.shape, test_data.shape)

(2529, 2) (249, 2)


For evaluation, we will use BLEU as implemented in sacrebleu (see https://github.com/mjpost/sacrebleu#variable-number-of-references for a description of the interface).

In [None]:
from sacrebleu.metrics import BLEU
refs = [ 
    # First set of references
    ['The dog bit the man.', 'It was not unexpected.', 'The man bit him first.'],
    # Second set of references
    ['The dog had bit the man.', 'No one was surprised.', 'The man had bitten the dog.'],
]
sys = ['The dog bit the man.', "It wasn't surprising.", 'The man had just bitten him.']
bleu = BLEU(force=True)
print(bleu.corpus_score(sys, refs))

BLEU = 48.53 82.4/50.0/45.5/37.5 (BP = 0.943 ratio = 0.944 hyp_len = 17 ref_len = 18)
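Note the layout of `refs` above: it is a list of reference *streams*, where the j-th stream holds the j-th reference for every hypothesis in `sys`. Our dataset instead gives a list of references per sentence, and the number varies from sentence to sentence, so it has to be transposed and padded with empty strings, as the next cell does. A standalone sketch of that transposition (the function name is ours):

```python
from typing import List


def to_reference_streams(refs_per_sentence: List[List[str]]) -> List[List[str]]:
    """Turn per-sentence reference lists into sacrebleu-style reference streams.

    Stream j holds the j-th reference of every sentence; sentences with fewer
    references are padded with empty strings.
    """
    n_streams = max(len(refs) for refs in refs_per_sentence)
    return [
        [refs[j] if j < len(refs) else '' for refs in refs_per_sentence]
        for j in range(n_streams)
    ]


per_sentence = [['r1a', 'r1b'], ['r2a'], ['r3a', 'r3b', 'r3c']]
print(to_reference_streams(per_sentence))
# [['r1a', 'r2a', 'r3a'], ['r1b', '', 'r3b'], ['', '', 'r3c']]
```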


We will re-format the test set for easier BLEU calculation:

In [None]:
test_inputs = []
test_outputs = []
for k, v in test_data.groupby('toxic_comment'):
    test_inputs.append(k)
    test_outputs.append(v.civil_comment.tolist())
max_n_refs = max(len(r) for r in test_outputs)
test_outputs_transposed = [[item[i] if i < len(item) else '' for item in test_outputs] for i in range(max_n_refs)]

print(bleu.corpus_score(test_inputs, test_outputs_transposed))

BLEU = 49.10 69.4/53.6/43.7/35.8 (BP = 1.000 ratio = 1.124 hyp_len = 1242 ref_len = 1105)


We can see that just copying the original (toxic) sentence already gives a BLEU score of 49.

Can we do better with a GPT model?

## Try zero-shot and few-shot generation

In [None]:
row = train_data.sample(1, random_state=20)
bad_text = row.toxic_comment.iloc[0]
row

|  | toxic_comment | civil_comment |
|---|---|---|
| 1200 | if you think a town in texas would cover this kind of shit up you 're insane . | If you think a town in texas would cover something like this up..you are wrong |


In [None]:
template = ' a toxic text: {}\n a civil rephrase:'

In [None]:
tokenizer(template.format(bad_text), return_tensors='pt')

{'input_ids': tensor([[  257, 11422,  2420,    25,   611,   345,   892,   257,  3240,   287,
         48659,   292,   561,  3002,   428,  1611,   286,  7510,   510,   345,
           705,   260, 13251,   764,   198,   257,  3026,   302, 34675,    25]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1]])}

In [None]:
inputs = tokenizer(template.format(bad_text), return_tensors='pt').to(model.device)
length = inputs.input_ids.shape[1]

beam_outputs = model.generate(
    **inputs, 
    max_length=100, 
    min_length=length+3,  # the new text should be at least 3 tokens long
    num_beams=5, 
    num_return_sequences=1, 
    early_stopping=True,
    eos_token_id=END_OF_LINE,
)

for i, beam_output in enumerate(beam_outputs):
  print(tokenizer.decode(beam_output, skip_special_tokens=True))

 a toxic text: if you think a town in texas would cover this kind of shit up you're insane.
 a civil rephrase: if you think a town in texas would cover this kind of shit up



Zero-shot prompting does not work: the model simply starts copying the toxic text. Let's try a few-shot approach.

In [None]:
examples = train_data.sample(3, random_state=20)
examples

|  | toxic_comment | civil_comment |
|---|---|---|
| 1200 | if you think a town in texas would cover this kind of shit up you 're insane . | If you think a town in texas would cover something like this up..you are wrong |
| 2294 | this shit has been happening for generations . | Its something that has been there in decades |
| 2593 | yah some straight up glory hole shit to , no less . | No less, some straight up glory hole stuff to. |


In [None]:
template2 = '\n\n'.join([template.format(' ' + row.toxic_comment) + ' ' + row.civil_comment for i, row in examples.iterrows()] + [template])
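One caveat with `template2`: the example texts are pasted into it while it still contains the `{}` placeholder, so `template2.format(text)` would fail if any example contained `{` or `}`. A sketch that formats each part separately (the name `build_few_shot_prompt` is hypothetical). Note that it also drops the extra `' '` the cell above puts before each example's toxic text, so examples and query are formatted identically.

```python
from typing import List, Tuple

TEMPLATE = ' a toxic text: {}\n a civil rephrase:'  # same as `template` above


def build_few_shot_prompt(examples: List[Tuple[str, str]], query: str, template: str = TEMPLATE) -> str:
    """Join solved examples and the unsolved query, separated by blank lines.

    Each part is formatted on its own, so braces inside the texts
    are never interpreted as format fields.
    """
    shots = [template.format(toxic) + ' ' + civil for toxic, civil in examples]
    return '\n\n'.join(shots + [template.format(query)])


print(build_few_shot_prompt([('you are {dumb}', 'you are wrong')], 'shut up'))
```

With the notebook's data, the examples would be built as `list(zip(examples.toxic_comment, examples.civil_comment))`.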

In [None]:
inputs = tokenizer(template2.format(bad_text), return_tensors='pt').to(model.device)
length = inputs.input_ids.shape[1]

beam_outputs = model.generate(
    **inputs, 
    max_length=length+100, 
    min_length=length+3,  # the new text should be at least 3 tokens long
    num_beams=5, 
    num_return_sequences=1, 
    early_stopping=True,
    eos_token_id=END_OF_LINE,
)

for i, beam_output in enumerate(beam_outputs):
  print(tokenizer.decode(beam_output, skip_special_tokens=True))

 a toxic text:  if you think a town in texas would cover this kind of shit up you're insane.
 a civil rephrase: If you think a town in texas would cover something like this up..you are wrong

 a toxic text:  this shit has been happening for generations.
 a civil rephrase: Its something that has been there in decades

 a toxic text:  yah some straight up glory hole shit to, no less.
 a civil rephrase: No less, some straight up glory hole stuff to.

 a toxic text: if you think a town in texas would cover this kind of shit up you're insane.
 a civil rephrase: If you think a town in texas would cover this kind of shit up you're insane.



Still, the model does not seem to understand what we want from it. 

Let's try to quantify this.

In [None]:
def generate(prompt):
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    length = inputs.input_ids.shape[1]

    beam_outputs = model.generate(
        **inputs, 
        max_length=length+32, 
        min_length=length+3,  # the new text should be at least 3 tokens long
        num_beams=3, 
        num_return_sequences=1, 
        early_stopping=True,
        eos_token_id=END_OF_LINE,
    )
    return tokenizer.decode(beam_outputs[0][length:], skip_special_tokens=True).strip()

print(generate(template2.format(bad_text)))

If you think a town in texas would cover this kind of shit up you're insane.


In [None]:
from tqdm.auto import tqdm, trange

In [None]:
outputs_0shot = [generate(template.format(text)) for text in tqdm(test_inputs)]
print(bleu.corpus_score(outputs_0shot, test_outputs_transposed))

BLEU = 32.83 50.6/36.2/28.7/22.1 (BP = 1.000 ratio = 1.145 hyp_len = 1069 ref_len = 934)


In [None]:
outputs_3shot = [generate(template2.format(text)) for text in tqdm(test_inputs)]
print(bleu.corpus_score(outputs_3shot, test_outputs_transposed))

BLEU = 23.37 40.2/25.2/19.5/15.1 (BP = 1.000 ratio = 1.212 hyp_len = 1265 ref_len = 1044)


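The 3-shot prompt scores even lower than the zero-shot one (23.4 vs. 32.8 BLEU), and both are far below the 49.1 of the copy-the-input baseline. Looking at the outputs, the model often either copies the input or repeats one of the example outputs instead of paraphrasing the last input.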

In [None]:
pd.DataFrame({'x': test_inputs, 'y': outputs_3shot}).sample(10)

|  | x | y |
|---|---|---|
| 24 | does a catholic bear shit in the vatican ? . | No less, some straight up glory hole shit to, no less. |
| 8 | another piece of useless manufactured junk ! | Another piece of useless manufactured junk! |
| 45 | i sell quite a bit of shit on there , and paypal has actually helped me pay the bills . | i sell quite a bit of shit on there, and paypal has actually helped me pay the bills. |
| 29 | fuck beavers this is capitalism . | Fuck beavers this is capitalism. |
| 95 | you have got to be insane . | You have got to be insane. |
| 82 | this is some clockwork orange level shit . | This is some clockwork orange level shit. |
| 80 | they will find someone as retarded as him to become the glorious leader . | No less, some straight up glory hole stuff to. |
| 41 | i cant wait until south park makes an episode about your bitch cult . | No less, some straight up glory hole shit to, no less. a toxic text: i cant wait until south park makes an episode about your bitch cult. |
| 26 | either way you sound foolish and ill informed . | No less, some straight up glory hole stuff to. |
| 57 | look the shit up yourself . | Look the shit up yourself. |
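The copying visible in this table is easy to quantify. A rough sketch, where the helper names and the normalization rule are our own choices: an output counts as a copy if it matches its input after lowercasing and removing spaces and punctuation.

```python
import re
from typing import List


def normalize(text: str) -> str:
    """Lowercase and drop everything except ASCII letters and digits."""
    return re.sub(r'[^a-z0-9]', '', text.lower())


def copy_rate(inputs: List[str], outputs: List[str]) -> float:
    """Share of outputs that equal their input up to case, spacing and punctuation."""
    copies = sum(normalize(x) == normalize(y) for x, y in zip(inputs, outputs))
    return copies / len(inputs)


xs = ['another piece of useless manufactured junk !', 'you have got to be insane .']
ys = ['Another piece of useless manufactured junk!', 'No less, some straight up glory hole stuff to.']
print(copy_rate(xs, ys))  # 0.5
```

The same function can be applied to `test_inputs` together with `outputs_3shot`, or with any of the output lists produced later.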


## Prompt tuning

Instead of manually engineering the prompt, we can just learn it with gradient descent!

We will implement it manually here. Alternatively, you can use [RuPrompts](https://github.com/ai-forever/ru-prompts), a library by Sber (see a [post](https://habr.com/ru/company/sberdevices/blog/596103/) about it in Russian).

In [None]:
import torch

We initialize each prompt as a matrix of embeddings of random tokens. 

The inputs to train the model will look like `<prompt1><toxic text>\n<prompt2><safe text>\n`.
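Only the safe text should be predicted, so every position of the two prompts and of the toxic text gets the label `-100`, which the loss ignores. The labels are not shifted by hand because `GPT2LMHeadModel` shifts them internally when computing the loss. A list-based sketch of this alignment; `build_labels` is hypothetical and the token ids are toy values, except 198, which is the GPT-2 newline id printed earlier. `compute_loss` below does the same with tensors.

```python
from typing import List

IGNORE_INDEX = -100  # positions with this label are excluded from the loss


def build_labels(n_prompt1: int, x_ids: List[int], n_prompt2: int, y_ids: List[int]) -> List[int]:
    """Labels for the sequence <prompt1><x><prompt2><y>: only y is supervised."""
    n_context = n_prompt1 + len(x_ids) + n_prompt2
    return [IGNORE_INDEX] * n_context + list(y_ids)


labels = build_labels(n_prompt1=2, x_ids=[11, 12], n_prompt2=2, y_ids=[21, 22, 198])
print(labels)  # [-100, -100, -100, -100, -100, -100, 21, 22, 198]
```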

In [None]:
prompt_matrix1 = torch.nn.Parameter(
    data=model.transformer.wte(torch.randint(0, len(tokenizer), size=(50,)).to(model.device).unsqueeze(0)), 
)
prompt_matrix2 = torch.nn.Parameter(
    data=model.transformer.wte(torch.randint(0, len(tokenizer), size=(50,)).to(model.device).unsqueeze(0)), 
)
prompt_matrix1

Parameter containing:
tensor([[[ 0.0008, -0.3023,  0.0966,  ..., -0.0629,  0.1701, -0.0052],
         [ 0.0656, -0.2265,  0.2496,  ..., -0.0322,  0.0831, -0.1027],
         [ 0.0765, -0.0331,  0.1032,  ..., -0.0568, -0.1340, -0.0646],
         ...,
         [-0.0087, -0.2668,  0.1832,  ...,  0.1033, -0.0098,  0.4049],
         [-0.1144,  0.0671,  0.1157,  ..., -0.1018, -0.0307,  0.1123],
         [-0.0030, -0.2840,  0.0932,  ..., -0.0109, -0.0957,  0.1503]]],
       device='cuda:0', requires_grad=True)

For simplicity, we use batch_size=1 here, but a more elaborate training loop would use larger batches. 
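Larger batches need the examples padded to a common length. A minimal sketch of such a collate step on plain lists; the function name is hypothetical, and it assumes right padding, which is fine for training a causal LM as long as the attention mask and the `-100` labels cover the padding. For the soft-prompt model the same masks would apply, with the padded positions filled in `inputs_embeds` instead of `input_ids`.

```python
from typing import Dict, List

IGNORE_INDEX = -100


def pad_batch(seqs: List[List[int]], labels: List[List[int]], pad_id: int) -> Dict[str, List[List[int]]]:
    """Right-pad a batch of token ids and labels to the same length.

    Padding positions get attention_mask 0 and label -100, so they are not
    attended to and do not contribute to the loss.
    """
    max_len = max(len(s) for s in seqs)
    batch: Dict[str, List[List[int]]] = {'input_ids': [], 'attention_mask': [], 'labels': []}
    for ids, lab in zip(seqs, labels):
        n_pad = max_len - len(ids)
        batch['input_ids'].append(ids + [pad_id] * n_pad)
        batch['attention_mask'].append([1] * len(ids) + [0] * n_pad)
        batch['labels'].append(lab + [IGNORE_INDEX] * n_pad)
    return batch


print(pad_batch([[1, 2, 3], [4]], [[-100, 2, 3], [4]], pad_id=0))
```

For GPT-2, `pad_id` would be `tokenizer.eos_token_id`, as in the model loading cell above.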


In [None]:
prompt_matrix1.shape, prompt_matrix2.shape

(torch.Size([1, 50, 768]), torch.Size([1, 50, 768]))

In [None]:
def compute_loss(x_text, y_text):
    x_ids = tokenizer(x_text + '\n', return_tensors='pt', add_prefix_space=True).to(model.device).input_ids
    y_ids = tokenizer(y_text + '\n', return_tensors='pt', add_prefix_space=True).to(model.device).input_ids
    input_embeds = torch.cat([prompt_matrix1, model.transformer.wte(x_ids), prompt_matrix2, model.transformer.wte(y_ids)], 1)
    labels = torch.cat([torch.tensor([[-100]]).to(model.device).repeat(1, prompt_matrix1.shape[1] + x_ids.shape[1] + prompt_matrix2.shape[1] ), y_ids], 1)
    out = model(
        inputs_embeds=input_embeds,
        labels=labels
    )
    return out.loss

In [None]:
from torch.optim import Adam
optimizer = Adam([prompt_matrix1, prompt_matrix2], lr=1e-4)

In [None]:
for epoch in trange(1):
    sum_loss = 0
    tq = tqdm(train_data.sample(frac=1.0).values)
    for i, (x_text, y_text) in enumerate(tq):
        loss = compute_loss(x_text, y_text)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        sum_loss += loss.item()
        tq.set_description(str(loss.item()))
    print('epoch', epoch, 'loss', sum_loss / len(train_data))

epoch 0 loss 2.4593590524897553


In [None]:
def generate_with_soft_prompt(text):
    x_ids = tokenizer(text + '\n', return_tensors='pt', add_prefix_space=True).to(model.device).input_ids
    input_embeds = torch.cat([prompt_matrix1, model.transformer.wte(x_ids), prompt_matrix2], 1)

    # we use a hand-written greedy decoding loop, because at the time of writing model.generate()
    # did not support inputs_embeds for GPT models (newer versions of transformers may support it);
    # see https://github.com/huggingface/transformers/issues/6535#issuecomment-983454474 for discussion.
    with torch.inference_mode():
        out = model(inputs_embeds=input_embeds)
    
    result = []
    for i in range(100):
        with torch.inference_mode():
            i2 = out.logits[0, -1, :].argmax().unsqueeze(0).unsqueeze(0)
            if i2.item() == END_OF_LINE:
                break
            result.append(i2.item())
            out = model(input_ids=i2, past_key_values=out.past_key_values)
    return tokenizer.decode(result).strip()

In [None]:
generate_with_soft_prompt('Go fuck yourself!')

'Go fuck yourself!'

In [None]:
generate_with_soft_prompt('Who is this idiot?')

'Who is this idiot?'

In [None]:
outputs_pt = [generate_with_soft_prompt(text) for text in tqdm(test_inputs)]
print(bleu.corpus_score(outputs_pt, test_outputs_transposed))

BLEU = 43.89 62.1/48.5/39.1/31.5 (BP = 1.000 ratio = 1.109 hyp_len = 996 ref_len = 898)


The outputs look slightly better than the inputs, but the BLEU score (43.9) is still below the 49.1 that we get by simply copying the original texts.

In [None]:
pd.DataFrame({'x': test_inputs, 'y': outputs_pt}).sample(10)

|  | x | y |
|---|---|---|
| 83 | this sick fuck is just a sociopath who doesn t want to face the consequences for his actions . | This is just a sociopath who doesn t want to face the consequences for his actions |
| 72 | that more than enough room statement is ridiculous . | that more than enough room statement is ridiculous |
| 61 | oh shit you just blew their minds . | Oh shit you just blew their minds. |
| 82 | this is some clockwork orange level shit . | This is some clockwork orange level shit. |
| 46 | i suck at poems so fuck erogan . | i suck at poems so fuck erogan |
| 32 | fuck ur mom and ur family , bitch , suck my mother fucking cock . | Fuck my mom and my family |
| 69 | sucked like nothings ever sucked before . |  |
| 60 | nobody in politics , especially democrats , give a shit about citizens , let alone the military . | Nobody in politics, especially democrats, give a shit about citizens |
| 63 | people don t even wash their damn hands after taking a shit . | People don t wash their hands after taking a shit. |
| 48 | i was told he s a real dick . | i was told he s a real dick |


## Full model tuning

Let us now use the simple prompt, but fine-tune the model itself. 

In [None]:
template

' a toxic text: {}\n a civil rephrase:'

In [None]:
def compute_loss(x_text, y_text):
    x_ids = tokenizer(template.format(x_text), return_tensors='pt', add_prefix_space=True).to(model.device).input_ids
    y_ids = tokenizer(y_text + '\n', return_tensors='pt', add_prefix_space=True).to(model.device).input_ids
    input_ids = torch.cat([x_ids, y_ids], 1)
    labels = torch.cat([torch.tensor([[-100]]).to(model.device).repeat(1, x_ids.shape[1]), y_ids], 1)
    out = model(
        input_ids=input_ids,
        labels=labels
    )
    return out.loss

In [None]:
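from torch.optim import Adam
# the prompt-tuning loop above also accumulated gradients in the model weights
# (its optimizer only zeroed the prompt matrices), so clear them before fine-tuning
model.zero_grad()
optimizer = Adam(model.parameters(), lr=1e-4)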

In [None]:
model.train()
for epoch in trange(1):
    sum_loss = 0
    tq = tqdm(train_data.sample(frac=1.0).values)
    for i, (x_text, y_text) in enumerate(tq):
        loss = compute_loss(x_text, y_text)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        sum_loss += loss.item()
        tq.set_description(str(loss.item()))
    print('epoch', epoch, 'loss', sum_loss / len(train_data))
model.eval();

epoch 0 loss 1.9432618660199024


In [None]:
print(generate(template.format('Go fuck yourself!')))
print(generate(template.format('Who is this idiot?')))

Go away!
Who is this person?


In [None]:
outputs_full = [generate(template.format(text)) for text in tqdm(test_inputs)]
print(bleu.corpus_score(outputs_full, test_outputs_transposed))

BLEU = 54.52 73.2/58.6/49.4/41.7 (BP = 1.000 ratio = 1.049 hyp_len = 1003 ref_len = 956)


Only after fine-tuning the whole model were we able to beat the "do not change the text" baseline (54.5 vs. 49.1 BLEU). 

But who knows: with larger models, few-shot prompting or prompt tuning might work better.

In [None]:
pd.DataFrame({'x': test_inputs, 'y': outputs_full}).sample(10)

|  | x | y |
|---|---|---|
| 34 | haha , fuck me we need to wake up . | We need to wake up. |
| 99 | your questions are kind of odd , are you rapist ? | Your questions are kind of odd, are you rapist? |
| 47 | i thought you were a russian troll , but maybe you are a chinese troll . | I thought you were a Russian troll, but maybe you are a Chinese troll. |
| 77 | they are worse than dirty sicilians . | they are not good. |
| 8 | another piece of useless manufactured junk ! | Another piece of useless |
| 81 | this asshole was arrested twice , in one month , . | this person was arrested twice, in one month,. |
| 6 | and redditors still choose to keep their heads up their asses . | and redditors still choose to keep their heads up their asses |
| 31 | fuck novick and his supposed fighting for the little man . | He is not a good person. |
| 38 | holy fucking shit ! i just read that as " ebola confirmed in mail . | I just read that as " ebola confirmed in mail. |
| 73 | that scum is smiling in court ! | that person is smiling |
