# Text generation with transformers

In this notebook we train a decoder-only model that can run in generative mode, just like modern LLMs. We will however train it on a rather specific type of text -- the IMDb reviews we have been classifiying in the past. Now we will not be classifying anything, but rather generate new reviews.

The text generation is an _autoregressive_ process, and there are different strategies one can inplement in order to obtain natural-looking text. We will try to implement several ones, and see how they compare.

In [1]:
import tensorflow as tf
import keras
import tensorflow_datasets
from string import punctuation

Exception ignored in: <function _xla_gc_callback at 0x781ade688860>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/lib/__init__.py", line 96, in _xla_gc_callback
    def _xla_gc_callback(*args):
    
KeyboardInterrupt: 


## Load data

In [2]:
dataset, info = tensorflow_datasets.load(
    'imdb_reviews',
    with_info=True,
    as_supervised=True,
    split=['train', 'test']
)

train_ds, test_ds = dataset[0], dataset[1]



Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]

Generating train examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/incomplete.QQBFXY_1.0.0/imdb_reviews-train.tfrecor…

Generating test examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/incomplete.QQBFXY_1.0.0/imdb_reviews-test.tfrecord…

Generating unsupervised examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/incomplete.QQBFXY_1.0.0/imdb_reviews-unsupervised.…

Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.


Quick check:

In [3]:
for example, label in train_ds.take(1):
  print('text: ', example)
  print('label: ', label.numpy())

text:  tf.Tensor(b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it.", shape=(), dtype=string)
label:  0


## Configuration

We need to make som choices on hyperparameters and sequence lengths. You can change these if you like.

In [4]:
vocab_size = 20000  # Only consider the top 20k words
sequence_length = 80  # Max sequence size
embed_dim = 256  # Embedding size for each token
num_heads = 2  # Number of attention heads
feed_forward_dim = 256  # Hidden layer size in feed forward network inside transformer


## Text vectorisation

The usual process:

In [5]:
def custom_standardization(input_string):
    """Remove html line-break tags and handle punctuation"""
    lowercased = tf.strings.lower(input_string)
    stripped_html = tf.strings.regex_replace(lowercased, "<br />", " ")
    return tf.strings.regex_replace(stripped_html, f"([{punctuation}])", r" \1")

In [6]:
text_vectorization = keras.layers.TextVectorization(
    standardize=custom_standardization,
    max_tokens=vocab_size - 1,
    output_mode="int",
    output_sequence_length=sequence_length + 1,
)

text_only_ds = train_ds.map(lambda x, y: x)

text_vectorization.adapt(text_only_ds)
vocabulary = text_vectorization.get_vocabulary()

## Prepare the dataset

We want our decoder to predict the next token of the input sentence -- hence our labels will be the next true token in the sentence.

Create a dataset where the labels are shifted by one position.

In [7]:
def prepare_lm_inputs_labels(text, labels):
    """
    Shift word sequences by 1 position so that the target for position (i) is
    word at position (i+1). The model will use all words up till position (i)
    to predict the next word.

    Discard the original labels, which we don't need.
    """
    #text = tf.expand_dims(text, -1)
    tokenized_sentences = text_vectorization(text)
    x = tokenized_sentences[:-1]
    y = tokenized_sentences[1:]
    return x, y

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.map(prepare_lm_inputs_labels, num_parallel_calls=AUTOTUNE)
test_ds = test_ds.map(prepare_lm_inputs_labels, num_parallel_calls=AUTOTUNE)

Batching and prefetching

In [8]:
batchsize = 64

train_ds = train_ds.batch(batchsize).prefetch(AUTOTUNE)
test_ds = test_ds.batch(batchsize).prefetch(AUTOTUNE)

Verify that the targets are in fact the original sequence, but shifted one position to the right:

In [9]:
for example, label in train_ds.take(1):
    print('text.shape:', example.shape)
    print('text: ', example[0].numpy())
    print('label: ', label[0].numpy())

text.shape: (64, 80)
text:  [   13    16    40   436   398    20     3    99    26    33 11022    11
    39  1466  3257    50   524 11297     3   213    30    94   164     4
    21    13   220   339    33    72   256   222    11   483     3    66
    72    94   123   106    28  5808    13    20    15   638   770     3
    13    20     9    40   417  9254   187  2523   430     3     2    97
  1227   147    76   157    60     2     1  7793    76   264    72  2967
    18     1     3  3040     1     1  1488  5083]
label:  [   16    40   436   398    20     3    99    26    33 11022    11    39
  1466  3257    50   524 11297     3   213    30    94   164     4    21
    13   220   339    33    72   256   222    11   483     3    66    72
    94   123   106    28  5808    13    20    15   638   770     3    13
    20     9    40   417  9254   187  2523   430     3     2    97  1227
   147    76   157    60     2     1  7793    76   264    72  2967    18
     1     3  3040     1     1  1488  

## Model components

We need positional embeddings, and we need a transformer decoder.

In [10]:
class PositionalEmbedding(keras.layers.Layer):
    def __init__(self, sequence_length, input_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = keras.layers.Embedding(
            input_dim=input_dim, output_dim=output_dim)
        self.position_embeddings = keras.layers.Embedding(
            input_dim=sequence_length, output_dim=output_dim)
        self.sequence_length = sequence_length
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.not_equal = keras.layers.Lambda(lambda x: tf.math.not_equal(x, 0))

    def call(self, inputs):
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(self.positions)
        return embedded_tokens + embedded_positions

    def build(self, input_shape):
        length = input_shape[-1]
        self.positions = tf.range(start=0, limit=length, delta=1)

    def compute_mask(self, inputs, mask=None):
        return self.not_equal(inputs)

    def get_config(self):
        config = super(PositionalEmbedding, self).get_config()
        config.update({
            "output_dim": self.output_dim,
            "sequence_length": self.sequence_length,
            "input_dim": self.input_dim,
        })
        return config


class TransformerDecoder(keras.layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention_1 = keras.layers.MultiHeadAttention(
          num_heads=num_heads, key_dim=embed_dim)
        self.attention_2 = keras.layers.MultiHeadAttention(
          num_heads=num_heads, key_dim=embed_dim)
        self.dense_proj = keras.Sequential(
            [keras.layers.Dense(dense_dim, activation="relu"),
             keras.layers.Dense(embed_dim),]
        )
        self.layernorm_1 = keras.layers.LayerNormalization()
        self.layernorm_2 = keras.layers.LayerNormalization()
        self.layernorm_3 = keras.layers.LayerNormalization()
        self.supports_masking = True

    def get_config(self):
        config = super(TransformerDecoder, self).get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "dense_dim": self.dense_dim,
        })
        return config

    def get_causal_attention_mask(self, inputs):
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = tf.range(sequence_length)[:, tf.newaxis]
        j = tf.range(sequence_length)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = tf.concat(
            [tf.expand_dims(batch_size, -1),
             tf.constant([1, 1], dtype=tf.int32)], axis=0)
        return tf.tile(mask, mult)

    def call(self, inputs, encoder_outputs, mask=None):
        causal_mask = self.get_causal_attention_mask(inputs)
        if mask is not None:
            padding_mask = tf.cast(
                mask[:, tf.newaxis, :], dtype="int32")
            padding_mask = tf.minimum(padding_mask, causal_mask)
        else:
            padding_mask = mask
        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=causal_mask)
        attention_output_1 = self.layernorm_1(inputs + attention_output_1)
        attention_output_2 = self.attention_2(
            query=attention_output_1,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=padding_mask,
        )
        attention_output_2 = self.layernorm_2(
            attention_output_1 + attention_output_2)
        proj_output = self.dense_proj(attention_output_2)
        return self.layernorm_3(attention_output_2 + proj_output)

## Define the model

We set up our model to output the logits, and not the score after softmax, so that we can add temperature scaling to the softmax later.

In this case we need to match out

In [11]:

embed_dim = 256
latent_dim = 2048
num_heads = 2

inputs = keras.Input(shape=(sequence_length,), dtype="int64")
x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(inputs)
x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, x)
outputs = keras.layers.Dense(vocab_size, activation=None)(x)    # no softmax, apply it later
model = keras.Model(inputs, outputs)

model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer="rmsprop"
)

Train the model:

In [12]:
model.fit(train_ds, epochs=15, validation_data=test_ds)

Epoch 1/15
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m80s[0m 167ms/step - loss: 6.3468 - val_loss: 5.4153
Epoch 2/15
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m57s[0m 134ms/step - loss: 5.3335 - val_loss: 5.2011
Epoch 3/15
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m60s[0m 153ms/step - loss: 5.1201 - val_loss: 5.1191
Epoch 4/15
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m83s[0m 157ms/step - loss: 4.9727 - val_loss: 5.0637
Epoch 5/15
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m73s[0m 133ms/step - loss: 4.8482 - val_loss: 5.0381
Epoch 6/15
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m91s[0m 157ms/step - loss: 4.7410 - val_loss: 5.0150
Epoch 7/15
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m73s[0m 134ms/step - loss: 4.6387 - val_loss: 5.0039
Epoch 8/15
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m82s[0m 134ms/step - loss: 4.5393 - val_loss: 4.9879
Epoch 9/15
[1m3

<keras.src.callbacks.history.History at 0x781aab5182d0>

## Generate text

Approach 1: Select most probable token.

In [13]:
import numpy as np

def most_probable(predictions):
    """
    Return index of the most probable token
    """
    return np.argmax(predictions)

Get token indices from vocabulary

In [14]:
tokens_index = dict(enumerate(text_vectorization.get_vocabulary()))

In [25]:
prompt = "This movie"
generate_length = 50

sentence = prompt
for i in range(generate_length):
    tokenized_sentence = text_vectorization([sentence])[:, :sequence_length]
    predictions = model(tokenized_sentence)
    next_token = most_probable(
        predictions[0, i, :]
    )
    sampled_token = tokens_index[next_token]
    sentence += " " + sampled_token
print(sentence)

This movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie movie


Approach 2:

### <span style="color: red;">Exercise:<span>

Implement top-K sampling.

For the sampling itself (after selecting the top token scores), you can use

```
samples = np.random.multinomial(1, predictions, 1)
return np.argmax(samples)
```

([NumPy docs](https://numpy.org/doc/2.2/reference/random/generated/numpy.random.multinomial.html))

In [34]:
def find_top_k_sample(preds, k=5):
  top_k_indices = np.argsort(preds)[-k:]
  top_k_probs = preds[top_k_indices]
  top_k_probs = top_k_probs / np.sum(top_k_probs)
  return np.random.choice(top_k_indices, p=top_k_probs)


Approach 3:

### <span style="color: red;">Exercise:<span>:

Compute token scores using softmax with temperature.

The equation is

$$
y = \frac{\exp(a_i / T)}{\sum_j \exp(a_j /T)} \,,
$$

where $T$ is the temperature, $a_i$ is the logit of the token in question.


In [37]:
def apply_temperature(predictions, temperature=1):
  predictions = np.exp(predictions / temperature)
  return predictions / np.sum(predictions)

prompt = "This movie is such a good movie"
generate_length = 50

sentence = prompt
for i in range(generate_length):
    tokenized_sentence = text_vectorization([sentence])[:, :sequence_length]
    predictions = model(tokenized_sentence)
    # Convert predictions to a numpy array if not already; adjust indexing as needed.
    preds = predictions[0, i, :].numpy()
    # Sample the next token index using top-k sampling with your desired k (e.g., 5)
    next_token = find_top_k_sample(preds, k=5)
    sampled_token = tokens_index[next_token]
    sentence += " " + sampled_token
print(sentence)



This movie is such a good movie it was .this as great film it was such as favorite film , were different as saturday film i are different although late film you are that because saturday 1970 anyone 're that although saturday lucky anyone peter that although saturday lucky ones peter whatever since saturday california 50s .


Approach 4:

### <span style="color: red;">Exercise:<span> (more difficult)

Implement beam search. For this you will need to manage several (let's say 3 to 5) parallel branches of outputs up to a certain length, and then compute the probabilities of each branch, before selecting the most likely one:

![](https://d2l.ai/_images/beam-search.svg)

For more information about beam search, have a look at https://d2l.ai/chapter_recurrent-modern/beam-search.html, or other sources.