In [38]:
!!pip install -q rouge-score
!!pip install -q git+https://github.com/keras-team/keras-nlp.git --upgrade

[]

In [39]:
import keras_nlp
import pathlib
import random
import tensorflow as tf

from tensorflow import keras
from tensorflow_text.tools.wordpiece_vocab import (
    bert_vocab_from_dataset as bert_vocab,
)

In [2]:
# Training hyperparameters.
BATCH_SIZE = 64  # Number of sentence pairs grouped into each dataset batch.
EPOCHS = 1  # Kept small for a quick run; should be at least 10 for convergence.
MAX_SEQUENCE_LENGTH = 40  # Length the packed token sequences are padded to.
ENG_VOCAB_SIZE = 15000  # Requested size of the English WordPiece vocabulary.
SPA_VOCAB_SIZE = 15000  # Requested size of the Spanish WordPiece vocabulary.

# Model dimensions.
EMBED_DIM = 256  # Embedding width; also the feature size of the encoder output.
INTERMEDIATE_DIM = 2048  # `intermediate_dim` of the Transformer encoder/decoder.
NUM_HEADS = 8  # Number of attention heads in each Transformer layer.

In [3]:
# Download and unpack the English-Spanish sentence-pair archive.
# The archive is fetched over HTTPS so the training data cannot be tampered
# with in transit.
archive_path = keras.utils.get_file(
    fname="spa-eng.zip",
    origin="https://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip",
    extract=True,
)
# The corpus is expected in a `spa-eng/` folder next to the downloaded
# archive. A separate name is used for the archive path so that `text_file`
# only ever refers to the corpus file itself.
text_file = pathlib.Path(archive_path).parent / "spa-eng" / "spa.txt"

Downloading data from http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip


In [4]:
# Read the corpus explicitly as UTF-8: the Spanish side contains accented
# characters, and the platform default encoding is not UTF-8 everywhere.
with open(text_file, encoding="utf-8") as f:
    # Drop the last element of the split (the text after the final newline).
    lines = f.read().split("\n")[:-1]

# Each line holds "<english>\t<spanish>". Both sides are lowercased here,
# once, so that later stages can work on already-normalized text.
text_pairs = []
for line in lines:
    eng, spa = line.split("\t")
    eng = eng.lower()
    spa = spa.lower()
    text_pairs.append((eng, spa))

In [5]:
# Shuffle with a dedicated, seeded generator so that the train/validation/test
# split (and everything derived from it) is identical on every
# "Restart & Run All", without touching the global `random` state.
random.Random(42).shuffle(text_pairs)

# 15% validation, 15% test, and the remainder (~70%) for training.
num_val_samples = int(0.15 * len(text_pairs))
num_train_samples = len(text_pairs) - 2 * num_val_samples
train_pairs = text_pairs[:num_train_samples]
val_pairs = text_pairs[num_train_samples : num_train_samples + num_val_samples]
test_pairs = text_pairs[num_train_samples + num_val_samples :]

print(f"{len(text_pairs)} total pairs")
print(f"{len(train_pairs)} training pairs")
print(f"{len(val_pairs)} validation pairs")
print(f"{len(test_pairs)} test pairs")

118964 total pairs
83276 training pairs
17844 validation pairs
17844 test pairs


In [6]:
def train_word_piece(text_samples, vocab_size, reserved_tokens):
    """Learn a WordPiece vocabulary from raw text samples.

    Args:
        text_samples: the strings to learn the vocabulary from.
        vocab_size: forwarded as `vocabulary_size` to the vocabulary trainer.
        reserved_tokens: special tokens forwarded to the vocabulary trainer.

    Returns:
        The vocabulary produced by
        `keras_nlp.tokenizers.compute_word_piece_vocabulary`.
    """
    # Feed the samples to the trainer in batches of 1000 strings.
    batched_samples = (
        tf.data.Dataset.from_tensor_slices(text_samples).batch(1000).prefetch(2)
    )
    return keras_nlp.tokenizers.compute_word_piece_vocabulary(
        batched_samples,
        vocabulary_size=vocab_size,
        reserved_tokens=reserved_tokens,
    )

In [7]:
# Special tokens reserved in both vocabularies.
reserved_tokens = ["[PAD]", "[UNK]", "[START]", "[END]"]

# Only the training split is used to learn the vocabularies.
eng_samples = [eng_text for eng_text, _ in train_pairs]
spa_samples = [spa_text for _, spa_text in train_pairs]

eng_vocab = train_word_piece(eng_samples, ENG_VOCAB_SIZE, reserved_tokens)
spa_vocab = train_word_piece(spa_samples, SPA_VOCAB_SIZE, reserved_tokens)

In [8]:
# Sanity check: peek at a small slice of each learned vocabulary.
print("English Tokens: ", eng_vocab[100:110])
print("Spanish Tokens: ", spa_vocab[100:110])

English Tokens:  ['him', 'there', 'they', 'go', 'her', 'has', 're', 'will', 'll', 'time']
Spanish Tokens:  ['te', 'para', 'mary', 'las', 'más', 'al', 'yo', 'estoy', 'muy', 'tu']


In [9]:
# One WordPiece tokenizer per language, each backed by its own vocabulary.
# `lowercase=False`: the sentence pairs were already lowercased when the
# corpus was loaded, so the tokenizers do not need to do it again.
eng_tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(
    vocabulary=eng_vocab, lowercase=False
)
spa_tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(
    vocabulary=spa_vocab, lowercase=False
)

In [10]:
def _show_round_trip(language, tokenizer, sentence):
    """Print a sentence, its token ids, and the detokenized text.

    Returns the token ids so the caller can keep them for later inspection.
    """
    token_ids = tokenizer.tokenize(sentence)
    print(f"{language} sentence: ", sentence)
    print("Tokens: ", token_ids)
    print(
        "Recovered text after detokenizing: ",
        tokenizer.detokenize(token_ids),
    )
    return token_ids


# Round-trip the first sentence pair through each tokenizer.
eng_input_ex = text_pairs[0][0]
eng_tokens_ex = _show_round_trip("English", eng_tokenizer, eng_input_ex)

print()

spa_input_ex = text_pairs[0][1]
spa_tokens_ex = _show_round_trip("Spanish", spa_tokenizer, spa_input_ex)

English sentence:  my hands are tied.
Tokens:  tf.Tensor([  79  612   86 2140   12], shape=(5,), dtype=int32)
Recovered text after detokenizing:  tf.Tensor(b'my hands are tied .', shape=(), dtype=string)

Spanish sentence:  tengo las manos atadas.
Tokens:  tf.Tensor([ 117  103  640   29 2063   91   15], shape=(7,), dtype=int32)
Recovered text after detokenizing:  tf.Tensor(b'tengo las manos atadas .', shape=(), dtype=string)


In [11]:
def preprocess_batch(eng, spa):
    """Tokenize and pack a batch of sentence pairs into model inputs/targets.

    Args:
        eng: batch of English sentences (strings).
        spa: batch of Spanish sentences (strings).

    Returns:
        A `(features, targets)` tuple. `features` is a dict with
        `"encoder_inputs"` (packed English token ids) and `"decoder_inputs"`
        (packed Spanish token ids without the last position). `targets` is the
        packed Spanish token ids without the first position, i.e. the decoder
        inputs shifted left by one step.
    """
    eng = eng_tokenizer(eng)
    spa = spa_tokenizer(spa)

    # Pad `eng` to `MAX_SEQUENCE_LENGTH`.
    eng_start_end_packer = keras_nlp.layers.StartEndPacker(
        sequence_length=MAX_SEQUENCE_LENGTH,
        pad_value=eng_tokenizer.token_to_id("[PAD]"),
    )
    eng = eng_start_end_packer(eng)

    # Add special tokens (`"[START]"` and `"[END]"`) to `spa` and pad it as
    # well. One extra position is packed so that both the decoder inputs and
    # the targets below end up with `MAX_SEQUENCE_LENGTH` positions.
    spa_start_end_packer = keras_nlp.layers.StartEndPacker(
        sequence_length=MAX_SEQUENCE_LENGTH + 1,
        start_value=spa_tokenizer.token_to_id("[START]"),
        end_value=spa_tokenizer.token_to_id("[END]"),
        pad_value=spa_tokenizer.token_to_id("[PAD]"),
    )
    spa = spa_start_end_packer(spa)

    return (
        {
            "encoder_inputs": eng,
            "decoder_inputs": spa[:, :-1],
        },
        spa[:, 1:],
    )


def make_dataset(pairs):
    """Build a batched, tokenized `tf.data.Dataset` from sentence pairs.

    Args:
        pairs: sequence of `(english, spanish)` string tuples.

    Returns:
        A dataset yielding `(features, targets)` batches as produced by
        `preprocess_batch`.
    """
    eng_texts, spa_texts = zip(*pairs)
    eng_texts = list(eng_texts)
    spa_texts = list(spa_texts)
    dataset = tf.data.Dataset.from_tensor_slices((eng_texts, spa_texts))
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.map(preprocess_batch, num_parallel_calls=tf.data.AUTOTUNE)
    # Order matters here:
    #   * `cache()` first, so the (expensive) tokenization runs only once;
    #   * `shuffle()` after the cache, otherwise the cache stores the shuffled
    #     stream and every later epoch replays exactly the same batch order;
    #   * `prefetch()` last, so it overlaps input preparation with training.
    return dataset.cache().shuffle(2048).prefetch(16)


# Input pipelines for training and validation, built from their respective
# splits of the sentence pairs.
train_ds = make_dataset(train_pairs)
val_ds = make_dataset(val_pairs)

In [12]:
# Pull a single batch and report the shape of every tensor the model sees.
for inputs, targets in train_ds.take(1):
    for feature_name in ("encoder_inputs", "decoder_inputs"):
        print(f'inputs["{feature_name}"].shape: {inputs[feature_name].shape}')
    print(f"targets.shape: {targets.shape}")

inputs["encoder_inputs"].shape: (64, 40)
inputs["decoder_inputs"].shape: (64, 40)
targets.shape: (64, 40)


In [13]:
# Encoder
# The input name matches the `"encoder_inputs"` key of the feature dict
# produced by `preprocess_batch`.
encoder_inputs = keras.Input(shape=(None,), dtype="int64", name="encoder_inputs")

# Token + position embeddings for the English token ids.
# NOTE(review): `mask_zero=True` treats id 0 as padding; this presumably
# matches `"[PAD]"`, the first reserved token — confirm against the vocabulary.
x = keras_nlp.layers.TokenAndPositionEmbedding(
    vocabulary_size=ENG_VOCAB_SIZE,
    sequence_length=MAX_SEQUENCE_LENGTH,
    embedding_dim=EMBED_DIM,
    mask_zero=True,
)(encoder_inputs)

encoder_outputs = keras_nlp.layers.TransformerEncoder(
    intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
)(inputs=x)
# Standalone encoder model: token ids in, encoded sequence out.
encoder = keras.Model(encoder_inputs, encoder_outputs)


# Decoder
# Built as a standalone model with two inputs: the Spanish token ids and an
# already-encoded source sequence of feature size `EMBED_DIM`.
decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs")
encoded_seq_inputs = keras.Input(shape=(None, EMBED_DIM), name="decoder_state_inputs")

x = keras_nlp.layers.TokenAndPositionEmbedding(
    vocabulary_size=SPA_VOCAB_SIZE,
    sequence_length=MAX_SEQUENCE_LENGTH,
    embedding_dim=EMBED_DIM,
    mask_zero=True,
)(decoder_inputs)

# The decoder layer receives the embedded Spanish tokens as
# `decoder_sequence` and the encoded source as `encoder_sequence`.
x = keras_nlp.layers.TransformerDecoder(
    intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
)(decoder_sequence=x, encoder_sequence=encoded_seq_inputs)
x = keras.layers.Dropout(0.5)(x)
# Softmax head: one probability per entry of the Spanish vocabulary at every
# position.
decoder_outputs = keras.layers.Dense(SPA_VOCAB_SIZE, activation="softmax")(x)
decoder = keras.Model(
    [
        decoder_inputs,
        encoded_seq_inputs,
    ],
    decoder_outputs,
)
# Wire the standalone decoder to the encoder output. Note that this rebinds
# `decoder_outputs` to the output of the combined graph.
decoder_outputs = decoder([decoder_inputs, encoder_outputs])

# End-to-end model: (English ids, Spanish ids) -> next-token probabilities.
transformer = keras.Model(
    [encoder_inputs, decoder_inputs],
    decoder_outputs,
    name="transformer",
)

In [14]:
# Print the layer-by-layer architecture before training.
transformer.summary()
# The model ends in a softmax layer, so the loss is given probabilities
# together with integer token-id targets (hence the "sparse" variant).
transformer.compile(
    "rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)
# Train for `EPOCHS` passes over `train_ds`, using `val_ds` for validation.
transformer.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)

Model: "transformer"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 encoder_inputs (InputLayer)    [(None, None)]       0           []                               
                                                                                                  
 token_and_position_embedding (  (None, None, 256)   3850240     ['encoder_inputs[0][0]']         
 TokenAndPositionEmbedding)                                                                       
                                                                                                  
 decoder_inputs (InputLayer)    [(None, None)]       0           []                               
                                                                                                  
 transformer_encoder (Transform  (None, None, 256)   1315072     ['token_and_position_em

<keras.callbacks.History at 0x7f9ca9dc7a60>

In [43]:
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Text generation utilities."""

import tensorflow as tf
from absl import logging
from tensorflow import keras


def _validate_prompt(prompt):
    """Return `prompt` as a tensor.

    Dense and ragged tensors are passed through untouched; anything else
    (e.g. a Python list) is converted with `tf.convert_to_tensor`.
    """
    already_tensor = isinstance(prompt, (tf.Tensor, tf.RaggedTensor))
    return prompt if already_tensor else tf.convert_to_tensor(prompt)


def _validate_token_probability_fn(token_probability_fn, prompt):
    """Check that `token_probability_fn` yields a 2D output for `prompt`.

    The function is called once on `prompt`. If its output is not rank 2 a
    `ValueError` is raised; otherwise the size of the output's last dimension
    and the output dtype are returned.
    """
    sample_output = token_probability_fn(prompt)
    output_rank = len(sample_output.shape)
    if output_rank == 2:
        return tf.shape(sample_output)[-1], sample_output.dtype

    raise ValueError(
        "Output of `token_probability_fn` is not a 2D tensor, "
        "please provide a function with the output shape "
        "[batch_size, vocab_size]."
    )


def _get_prompt_shape(prompt):
    """Return `(batch_size, length)` for a dense or ragged prompt.

    For a ragged prompt, `length` is the length of its shortest row.
    """
    if isinstance(prompt, tf.RaggedTensor):
        num_rows = prompt.nrows()
        shortest_row = tf.math.reduce_min(prompt.row_lengths())
        return (num_rows, shortest_row)
    if isinstance(prompt, tf.Tensor):
        dims = tf.shape(prompt)
        return (dims[0], dims[1])


def _pad_prompt(prompt, max_length):
    """Pad `prompt` to `max_length` and mark which positions it originally set.

    Returns the padded prompt together with a boolean mask that is `True`
    wherever the original (possibly ragged) prompt supplied a token. The
    generation loops use the mask to avoid overwriting those tokens when
    writing the token for the next timestep.
    """
    if isinstance(prompt, tf.Tensor):
        prompt_shape = tf.shape(prompt)
        batch_dim, prompt_len = prompt_shape[0], prompt_shape[1]
        # Number of zero columns to append (none if the prompt is already
        # at least `max_length` long).
        num_pad = tf.math.maximum(0, max_length - prompt_len)
        padding_shape = [batch_dim, num_pad]

        prompt_positions = tf.ones(prompt_shape, tf.bool)
        padded_positions = tf.zeros(padding_shape, tf.bool)
        mask = tf.concat((prompt_positions, padded_positions), axis=1)

        pad_values = tf.zeros(padding_shape, prompt.dtype)
        prompt = tf.concat((prompt, pad_values), axis=1)
    elif isinstance(prompt, tf.RaggedTensor):
        # TODO: `to_tensor()` works with `jit_compile = True` in TF 2.8.x but
        # fails in TF 2.9.x. Fix this. After this issue has been fixed, we can
        # condense the two branches into one by starting off with a ragged tensor.
        dense_shape = (None, max_length)
        mask = tf.ones_like(prompt, dtype=tf.bool).to_tensor(shape=dense_shape)
        prompt = prompt.to_tensor(shape=dense_shape)
    return prompt, mask


def _mask_tokens_after_end_token(
    prompt, max_length, end_token_id, pad_token_id
):
    """Replace every token after the first `end_token_id` with `pad_token_id`.

    Args:
        prompt: int tensor of generated token ids; the last axis is the
            sequence axis.
        max_length: int, the length of the sequence axis.
        end_token_id: int, id of the token that terminates a sequence.
        pad_token_id: int, id written to every position after the end token.

    Returns:
        A tensor shaped like `prompt`. The end token itself is kept; sequences
        that contain no end token are returned unchanged.
    """
    is_end_token = tf.math.equal(prompt, end_token_id)
    # Whether each sequence contains an end token at all. This must be
    # tracked separately: `argmax` returns 0 both when the end token sits at
    # index 0 and when there is none, so comparing the index against 0 would
    # leave a sequence that starts with the end token completely unmasked.
    has_end_token = tf.math.reduce_any(is_end_token, axis=-1)
    # Index of the first end token (only meaningful where `has_end_token`).
    first_end_index = tf.math.argmax(tf.cast(is_end_token, tf.int32), axis=-1)
    # Use `max_length` if no `end_token_id` is found.
    end_indices = tf.where(
        has_end_token,
        first_end_index,
        tf.cast(max_length, dtype=first_end_index.dtype),
    )
    # Build a mask including end_token and replace tokens after end_token
    # with `pad_token_id`.
    valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length)
    return tf.where(valid_indices, prompt, pad_token_id)


def greedy_search(
    token_probability_fn,
    prompt,
    max_length,
    end_token_id=None,
    pad_token_id=0,
):
    """Text generation utility based on greedy search.

    Greedy search always appends the token having the largest probability to
    the existing sequence.

    Args:
        token_probability_fn: a callable, which takes in input_sequence
            and output the probability distribution or the logits of the next
            token.
        prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
            append generated tokens.
        max_length: int. The max length of generated text.
        end_token_id: int, defaults to None. The token marking the end of the
            sequence, once encountered the generation is finished for the exact
            sequence. If None, every sequence is generated up to `max_length`.
            If set, all tokens after encountering `end_token_id` will be
            replaced with `pad_token_id`.
        pad_token_id: int, defaults to 0. The pad token after `end_token_id`
            is received.

    Returns:
        A 1D int Tensor (when `prompt` is 1D), or 2D int Tensor representing
        the generated sequences.

    Raises:
        ValueError: if the output of `token_probability_fn` is not 2D.

    Examples:
    ```python
    BATCH_SIZE = 8
    VOCAB_SIZE = 10
    FEATURE_SIZE = 16
    START_ID = 1
    END_ID = 2

    # Create a dummy model to predict the next token.
    model = keras.Sequential(
        [
            keras.Input(shape=[None]),
            keras.layers.Embedding(
                input_dim=VOCAB_SIZE,
                output_dim=FEATURE_SIZE,
            ),
            keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
        ]
    )

    # Define a function that outputs the next token's probability given the
    # input sequence.
    def token_probability_fn(inputs):
        return model(inputs)[:, -1, :]

    prompt = tf.fill((BATCH_SIZE, 1), START_ID)

    # Print the generated sequence (token ids).
    keras_nlp.utils.greedy_search(
        token_probability_fn,
        prompt,
        max_length=10,
        end_token_id=END_ID,
    )
    ```

    """
    prompt = _validate_prompt(prompt)

    # A 1D prompt is treated as a batch of one and unbatched again on return.
    input_is_1d = prompt.shape.rank == 1
    if input_is_1d:
        prompt = prompt[tf.newaxis, :]

    batch_size, length = _get_prompt_shape(prompt)
    prompt, mask = _pad_prompt(prompt, max_length)

    _validate_token_probability_fn(token_probability_fn, prompt)

    def one_step(length, prompt):
        pred = token_probability_fn(prompt[:, :length])
        next_token = tf.cast(tf.argmax(pred, axis=-1), dtype=prompt.dtype)
        # Keep the caller's token wherever the original prompt already
        # provided one at this position (longer rows of a ragged prompt).
        next_token = tf.where(mask[:, length], prompt[:, length], next_token)

        # Append the next token to current sequence.
        prompt = tf.tensor_scatter_nd_update(
            tensor=prompt,
            indices=tf.stack(
                (
                    tf.cast(tf.range(batch_size), dtype=length.dtype),
                    tf.repeat(length, batch_size),
                ),
                axis=1,
            ),
            updates=next_token,
        )

        length = tf.add(length, 1)
        return (length, prompt)

    # Run a while loop till text of length `max_length` has been generated.
    length, prompt = tf.while_loop(
        cond=lambda length, _: tf.less(length, max_length),
        body=one_step,
        loop_vars=(length, prompt),
    )

    if end_token_id is not None:
        prompt = _mask_tokens_after_end_token(
            prompt, max_length, end_token_id, pad_token_id
        )

    # Only drop the batch axis that was added above. A bare `tf.squeeze`
    # would also collapse the sequence axis when it has length 1 and return
    # a scalar instead of a 1D tensor.
    return tf.squeeze(prompt, axis=0) if input_is_1d else prompt


def beam_search(
    token_probability_fn,
    prompt,
    max_length,
    num_beams,
    from_logits=False,
    end_token_id=None,
    pad_token_id=0,
):
    """Text generation utility based on beam search algorithm.

    At each time-step, beam search keeps the beams (sequences) of the top
    `num_beams` highest accumulated probabilities, and uses each one of the
    beams to predict candidate next tokens.

    Args:
        token_probability_fn: a callable, which takes in input_sequence
            and outputs the probability distribution of the next token. If
            `from_logits` set to True, it should output the logits of the next
            token. The input shape would be `[batch_size * num_beams, length]`
            and the output should be `[batch_size * num_beams, vocab_size]`.
        prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
            append generated tokens. The initial beam for beam search.
        max_length: int. The max length of generated text.
        num_beams: int. The number of beams that should be kept at each
            time-step. `num_beams` should be strictly positive. With
            `num_beams=1` the call is delegated to `greedy_search`.
        from_logits: bool. Indicates whether `token_probability_fn` outputs
            logits or probabilities.
        end_token_id: int, defaults to None. The token marking the end of the
            sequence, once encountered the generation is finished for the exact
            sequence. If None, every sequence is generated up to `max_length`.
            If set, all tokens after encountering `end_token_id` will be
            replaced with `pad_token_id`.
        pad_token_id: int, defaults to 0. The pad token after `end_token_id`
            is received.

    Returns:
        A 1D int Tensor (when `prompt` is 1D), or 2D int Tensor representing
        the generated sequences.

    Raises:
        ValueError: if `num_beams` is not strictly positive, or if the output
            of `token_probability_fn` is not 2D.

    Examples:
    ```python
    BATCH_SIZE = 8
    VOCAB_SIZE = 10
    FEATURE_SIZE = 16
    START_ID = 1
    END_ID = 2

    # Create a dummy model to predict the next token.
    model = keras.Sequential(
        [
            keras.Input(shape=[None]),
            keras.layers.Embedding(
                input_dim=VOCAB_SIZE,
                output_dim=FEATURE_SIZE,
            ),
            keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
        ]
    )

    # Define a function that outputs the next token's probability given the
    # input sequence.
    def token_probability_fn(inputs):
        return model(inputs)[:, -1, :]

    prompt = tf.fill((BATCH_SIZE, 1), START_ID)

    # Print the generated sequence (token ids).
    keras_nlp.utils.beam_search(
        token_probability_fn,
        prompt,
        max_length=10,
        num_beams=5,
        end_token_id=END_ID,
    )
    ```

    """
    if num_beams <= 0:
        raise ValueError(
            f"`num_beams` should be strictly positive. Received: `num_beams={num_beams}`."
        )
    if num_beams == 1:
        return greedy_search(
            token_probability_fn=token_probability_fn,
            prompt=prompt,
            max_length=max_length,
            end_token_id=end_token_id,
            pad_token_id=pad_token_id,
        )

    prompt = _validate_prompt(prompt)

    # A 1D prompt is treated as a batch of one and unbatched again on return.
    input_is_1d = prompt.shape.rank == 1
    if input_is_1d:
        prompt = prompt[tf.newaxis, :]

    batch_size, length = _get_prompt_shape(prompt)
    prompt, mask = _pad_prompt(prompt, max_length)

    vocab_size, pred_dtype = _validate_token_probability_fn(
        token_probability_fn, prompt
    )

    # Nothing to generate: the prompt already fills `max_length`.
    if length >= max_length:
        return tf.squeeze(prompt, axis=0) if input_is_1d else prompt

    # Initialize beam with shape `(batch_size, num_beams, length)`.
    beams = tf.repeat(tf.expand_dims(prompt, axis=1), num_beams, axis=1)

    # Initialize `beams_prob` (accumulated log-probabilities) with shape
    # `(batch_size, num_beams)`. Only the first beam starts at 0; the others
    # start at the dtype minimum so that the first step does not pick
    # `num_beams` copies of the same continuation.
    beams_prob = tf.zeros([batch_size, 1], dtype=pred_dtype)
    beams_prob = tf.concat(
        [beams_prob, tf.fill((batch_size, num_beams - 1), pred_dtype.min)],
        axis=-1,
    )

    def one_step(beams, beams_prob, length):
        truncated_beams = beams[..., :length]

        flattened_beams = tf.reshape(
            truncated_beams, shape=[batch_size * num_beams, -1]
        )
        preds = token_probability_fn(flattened_beams)
        if from_logits:
            preds = keras.activations.softmax(preds, axis=-1)
        # Reshape `preds` to shape `(batch_size, num_beams * vocab_size)`.
        preds = tf.reshape(preds, shape=[batch_size, -1])

        # Accumulated log-probability of every (beam, next token) candidate.
        probs = tf.math.log(preds) + tf.repeat(
            beams_prob, repeats=vocab_size, axis=1
        )

        candidate_prob, candidate_indexes = tf.math.top_k(
            probs, k=num_beams, sorted=False
        )
        # A flat candidate index encodes `beam_index * vocab_size + token`.
        candidate_beam_indexes = candidate_indexes // vocab_size
        # Cast to the beam dtype so that `tf.where` and the scatter update
        # below also work when the prompt is not int32 (the other search
        # utilities cast their sampled token in the same way).
        next_token = tf.cast(candidate_indexes % vocab_size, dtype=beams.dtype)

        beams = tf.gather(beams, candidate_beam_indexes, axis=1, batch_dims=1)

        # Build a new column of updates to scatter into the beam tensor,
        # keeping the caller's token wherever the original prompt already
        # provided one at this position.
        next_token = tf.where(
            condition=mask[..., length, tf.newaxis],
            x=beams[..., length],
            y=next_token,
        )
        next_token = tf.reshape(next_token, shape=[-1])

        # Generate `(batch_index, beam_index)` tuples for each beam.
        beam_indices = tf.where(tf.ones((batch_size, num_beams), tf.bool))
        beam_indices = tf.cast(beam_indices, dtype=length.dtype)
        # Build a tensor of repeated `length` values.
        length_indices = tf.fill((batch_size * num_beams, 1), length)
        # Concatenate to a triplet of `(batch_index, beam_index, length)`.
        indices = tf.concat([beam_indices, length_indices], axis=-1)

        # Update `beams[:, :, length]` with `next_token`.
        beams = tf.tensor_scatter_nd_update(
            tensor=beams,
            indices=indices,
            updates=next_token,
        )

        beams_prob = candidate_prob
        length = tf.add(length, 1)

        return beams, beams_prob, length

    # Run a while loop till text of length `max_length` has been generated.
    beams, beams_prob, length = tf.while_loop(
        cond=lambda beams, beams_prob, length: tf.less(length, max_length),
        body=one_step,
        loop_vars=(beams, beams_prob, length),
    )

    # Get the beam with the maximum probability.
    max_indexes = tf.math.argmax(beams_prob, axis=-1)
    max_beams = tf.gather(
        beams, max_indexes[:, tf.newaxis], axis=1, batch_dims=1
    )
    # Drop only the (size-1) beam axis. A bare `tf.squeeze` would also drop
    # the batch axis when `batch_size == 1`, returning a 1D result for a 2D
    # prompt.
    prompt = tf.squeeze(max_beams, axis=1)

    if end_token_id is not None:
        prompt = _mask_tokens_after_end_token(
            prompt, max_length, end_token_id, pad_token_id
        )
    return tf.squeeze(prompt, axis=0) if input_is_1d else prompt


def random_search(
    token_probability_fn,
    prompt,
    max_length,
    seed=None,
    from_logits=False,
    end_token_id=None,
    pad_token_id=0,
):
    """Text generation utility based on randomly sampling the entire probability
    distribution.

    Random sampling samples the next token from the probability distribution
    provided by `token_probability_fn` and appends it to the existing sequence.

    Args:
        token_probability_fn: a callable, which takes in input_sequence
            and output the probability distribution of the next token. If
            `from_logits` set to True, it should output the logits of the next
            token.
        prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
            append generated tokens.
        max_length: int. The max length of generated text.
        seed: int, defaults to None. The random seed used for sampling.
        from_logits: bool. Indicates whether `token_probability_fn` outputs
            logits or probabilities.
        end_token_id: int, defaults to None. The token marking the end of the
            sequence, once encountered the generation is finished for the exact
            sequence. If None, every sequence is generated up to `max_length`.
            If set, all tokens after encountering `end_token_id` will be
            replaced with `pad_token_id`.
        pad_token_id: int, defaults to 0. The pad token after `end_token_id`
            is received.

    Returns:
        A 1D int Tensor (when `prompt` is 1D), or 2D int Tensor representing
        the generated sequences.

    Raises:
        ValueError: if the output of `token_probability_fn` is not 2D.

    Examples:
    ```python
    BATCH_SIZE = 8
    VOCAB_SIZE = 10
    FEATURE_SIZE = 16
    START_ID = 1
    END_ID = 2

    # Create a dummy model to predict the next token.
    model = keras.Sequential(
        [
            keras.Input(shape=[None]),
            keras.layers.Embedding(
                input_dim=VOCAB_SIZE,
                output_dim=FEATURE_SIZE,
            ),
            keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
        ]
    )

    # Define a function that outputs the next token's probability given the
    # input sequence.
    def token_probability_fn(inputs):
        return model(inputs)[:, -1, :]

    prompt = tf.fill((BATCH_SIZE, 1), START_ID)

    # Print the generated sequence (token ids).
    keras_nlp.utils.random_search(
        token_probability_fn,
        prompt,
        max_length=10,
        end_token_id=END_ID,
    )
    ```

    """
    prompt = _validate_prompt(prompt)

    # A 1D prompt is treated as a batch of one and unbatched again on return.
    input_is_1d = prompt.shape.rank == 1
    if input_is_1d:
        prompt = prompt[tf.newaxis, :]

    batch_size, length = _get_prompt_shape(prompt)
    prompt, mask = _pad_prompt(prompt, max_length)

    _validate_token_probability_fn(token_probability_fn, prompt)

    def one_step(length, prompt):
        pred = token_probability_fn(prompt[:, :length])
        if from_logits:
            pred = keras.activations.softmax(pred, axis=-1)
        # `tf.random.categorical` expects (unnormalized) log-probabilities
        # and returns one sampled index per row, shaped `[batch_size, 1]`.
        next_token = tf.squeeze(
            tf.cast(
                tf.random.categorical(tf.math.log(pred), 1, seed=seed),
                dtype=prompt.dtype,
            ),
            axis=1,
        )
        # Keep the caller's token wherever the original prompt already
        # provided one at this position (longer rows of a ragged prompt).
        next_token = tf.where(mask[:, length], prompt[:, length], next_token)

        # Append the next token to current sequence.
        prompt = tf.tensor_scatter_nd_update(
            tensor=prompt,
            indices=tf.stack(
                (
                    tf.cast(tf.range(batch_size), dtype=length.dtype),
                    tf.repeat(length, batch_size),
                ),
                axis=1,
            ),
            updates=next_token,
        )

        length = tf.add(length, 1)
        return (length, prompt)

    # Run a while loop till text of length `max_length` has been generated.
    length, prompt = tf.while_loop(
        cond=lambda length, _: tf.less(length, max_length),
        body=one_step,
        loop_vars=(length, prompt),
    )

    if end_token_id is not None:
        prompt = _mask_tokens_after_end_token(
            prompt, max_length, end_token_id, pad_token_id
        )

    # Only drop the batch axis that was added above. A bare `tf.squeeze`
    # would also collapse the sequence axis when it has length 1 and return
    # a scalar instead of a 1D tensor.
    return tf.squeeze(prompt, axis=0) if input_is_1d else prompt


def top_k_search(
    token_probability_fn,
    prompt,
    max_length,
    k,
    seed=None,
    from_logits=False,
    end_token_id=None,
    pad_token_id=0,
):
    """Text generation utility based on top-k sampling.

    Top-k search samples the next token from the top-k tokens in the
    probability distribution provided by `token_probability_fn` and appends it
    to the existing sequence.

    Args:
        token_probability_fn: a callable, which takes in input_sequence
            and output the probability distribution of the next token. If
            `from_logits` set to True, it should output the logits of the next
            token.
        prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
            append generated tokens.
        max_length: int. The max length of generated text.
        k: int. The number of top tokens to sample from. Should be strictly
            positive. A value larger than the vocabulary size is clamped to
            the vocabulary size (a warning is logged).
        seed: int, defaults to None. The random seed used for sampling.
        from_logits: bool. Indicates whether `token_probability_fn` outputs
            logits or probabilities.
        end_token_id: int, defaults to None. The token marking the end of the
            sequence, once encountered the generation is finished for the exact
            sequence. If None, every sequence is generated up to `max_length`.
            If set, all tokens after encountering `end_token_id` will be
            replaced with `pad_token_id`.
        pad_token_id: int, defaults to 0. The pad token after `end_token_id`
            is received.

    Returns:
        A 1D int Tensor (when `prompt` is 1D), or 2D int Tensor representing
        the generated sequences.

    Raises:
        ValueError: if `k` is not strictly positive, or if the output of
            `token_probability_fn` is not 2D.

    Examples:
    ```python
    BATCH_SIZE = 8
    VOCAB_SIZE = 10
    FEATURE_SIZE = 16
    START_ID = 1
    END_ID = 2

    # Create a dummy model to predict the next token.
    model = keras.Sequential(
        [
            keras.Input(shape=[None]),
            keras.layers.Embedding(
                input_dim=VOCAB_SIZE,
                output_dim=FEATURE_SIZE,
            ),
            keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
        ]
    )

    # Define a function that outputs the next token's probability given the
    # input sequence.
    def token_probability_fn(inputs):
        return model(inputs)[:, -1, :]

    prompt = tf.fill((BATCH_SIZE, 1), START_ID)

    # Print the generated sequence (token ids).
    keras_nlp.utils.top_k_search(
        token_probability_fn,
        prompt,
        max_length=10,
        k=4,
        end_token_id=END_ID,
    )
    ```

    """
    if k <= 0:
        raise ValueError(f"`k` should be strictly positive. Received: `k={k}`.")

    prompt = _validate_prompt(prompt)

    # A 1D prompt is treated as a batch of one and unbatched again on return.
    input_is_1d = prompt.shape.rank == 1
    if input_is_1d:
        prompt = prompt[tf.newaxis, :]

    batch_size, length = _get_prompt_shape(prompt)
    prompt, mask = _pad_prompt(prompt, max_length)

    _validate_token_probability_fn(token_probability_fn, prompt)

    # If k is greater than the vocabulary size, use the entire vocabulary.
    # The static output shape is used here so `k` stays a plain Python int.
    pred = token_probability_fn(prompt)
    if k > pred.shape[1]:
        logging.warning(
            f"`k` larger than vocabulary size={pred.shape[1]}. "
            f"Setting `k` to vocabulary size. Received: `k={k}`."
        )
        k = pred.shape[1]

    def one_step(length, prompt):
        pred = token_probability_fn(prompt[:, :length])
        if from_logits:
            pred = keras.activations.softmax(pred, axis=-1)

        # Filter out top-k tokens.
        top_k_pred, top_k_indices = tf.math.top_k(pred, k=k, sorted=False)
        # Sample the next token from the probability distribution. The
        # sampled value is a position within the top-k slice, not a token id.
        next_token = tf.random.categorical(
            tf.math.log(top_k_pred), 1, seed=seed
        )

        # Rearrange to get the next token idx from the original order.
        next_token = tf.gather_nd(top_k_indices, next_token, batch_dims=1)
        next_token = tf.cast(next_token, dtype=prompt.dtype)
        # Keep the caller's token wherever the original prompt already
        # provided one at this position (longer rows of a ragged prompt).
        next_token = tf.where(mask[:, length], prompt[:, length], next_token)

        # Append the next token to current sequence.
        prompt = tf.tensor_scatter_nd_update(
            tensor=prompt,
            indices=tf.stack(
                (
                    tf.cast(tf.range(batch_size), dtype=length.dtype),
                    tf.repeat(length, batch_size),
                ),
                axis=1,
            ),
            updates=next_token,
        )

        length = tf.add(length, 1)
        return (length, prompt)

    # Run a while loop till text of length `max_length` has been generated.
    length, prompt = tf.while_loop(
        cond=lambda length, _: tf.less(length, max_length),
        body=one_step,
        loop_vars=(length, prompt),
    )

    if end_token_id is not None:
        prompt = _mask_tokens_after_end_token(
            prompt, max_length, end_token_id, pad_token_id
        )

    # Only drop the batch axis that was added above. A bare `tf.squeeze`
    # would also collapse the sequence axis when it has length 1 and return
    # a scalar instead of a 1D tensor.
    return tf.squeeze(prompt, axis=0) if input_is_1d else prompt


def top_p_search(
    token_probability_fn,
    prompt,
    max_length,
    p,
    seed=None,
    from_logits=False,
    end_token_id=None,
    pad_token_id=0,
):
    """Text generation utility based on top-p (nucleus) sampling.

    Top-p search selects tokens from the smallest subset of output probabilities
    that sum to greater than `p`. Put another way, top-p will first order
    token predictions by likelihood, and ignore all tokens after the cumulative
    probability of selected tokens exceeds `p`. The probability of each
    token is provided by `token_probability_fn`.

    Args:
        token_probability_fn: a callable, which takes in input_sequence
            and output the probability distribution of the next token. If
            `from_logits` set to True, it should output the logits of the next
            token.
        prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
            append generated tokens.
        max_length: int. The max length of generated text.
        p: float. The probability that the top tokens sums up to. Should
            follow the constraint of 0 < p < 1.
        seed: int, defaults to None. The random seed used for sampling.
        from_logits: bool. Indicates whether `token_probability_fn` outputs
            logits or probabilities.
        end_token_id: int, defaults to None. The token marking the end of the
            sequence, once encountered the generation is finished for the exact
            sequence. If None, every sequence is generated up to `max_length`.
            If set, all tokens after encountering `end_token_id` will be
            replaced with `pad_token_id`.
        pad_token_id: int, defaults to 0. The pad token after `end_token_id`
            is received.

    Returns:
        A 1D int Tensor, or 2D int Tensor representing the generated
        sequences. A 1D `prompt` yields a 1D result; a 2D `prompt` yields a
        2D result.

    Raises:
        ValueError: if `p` is not strictly between 0 and 1.

    Examples:
    ```python
    BATCH_SIZE = 8
    VOCAB_SIZE = 10
    FEATURE_SIZE = 16
    START_ID = 1
    END_ID = 2

    # Create a dummy model to predict the next token.
    model = keras.Sequential(
        [
            keras.Input(shape=[None]),
            keras.layers.Embedding(
                input_dim=VOCAB_SIZE,
                output_dim=FEATURE_SIZE,
            ),
            keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
        ]
    )

    # Define a function that outputs the next token's probability given the
    # input sequence.
    def token_probability_fn(inputs):
        return model(inputs)[:, -1, :]

    prompt = tf.fill((BATCH_SIZE, 1), START_ID)

    # Print the generated sequence (token ids).
    keras_nlp.utils.top_p_search(
        token_probability_fn,
        prompt,
        max_length=10,
        p=0.8,
        end_token_id=END_ID,
    )
    ```

    """
    if p <= 0 or p >= 1:
        raise ValueError(
            f"`p` should be in the range (0, 1). Received: `p={p}`."
        )

    prompt = _validate_prompt(prompt)

    # Work on a batched (2D) prompt internally; the added batch axis is
    # removed again before returning.
    input_is_1d = prompt.shape.rank == 1
    if input_is_1d:
        prompt = prompt[tf.newaxis, :]

    batch_size, length = _get_prompt_shape(prompt)
    prompt, mask = _pad_prompt(prompt, max_length)

    _validate_token_probability_fn(token_probability_fn, prompt)

    def one_step(length, prompt):
        """Sample the token at position `length` and write it into `prompt`."""
        pred = token_probability_fn(prompt[:, :length])
        if from_logits:
            pred = keras.activations.softmax(pred, axis=-1)
        # Sort preds in descending order.
        sorted_preds, sorted_indices = tf.math.top_k(
            pred, k=pred.shape[1], sorted=True
        )
        # Calculate cumulative probability distribution.
        cumulative_probs = tf.math.cumsum(sorted_preds, axis=-1)
        # Create a mask for the tokens to keep.
        keep_mask = cumulative_probs <= p
        # Shift right by one so that the first token whose cumulative
        # probability exceeds `p` is kept too; the most likely token is
        # therefore always kept.
        shifted_keep_mask = tf.concat(
            [tf.ones_like(keep_mask[:, :1]), keep_mask[:, :-1]], axis=-1
        )
        # Zero out the dropped tokens and sample from what remains. The
        # zeros become -inf after the log below.
        probs = tf.where(
            shifted_keep_mask,
            sorted_preds,
            tf.zeros_like(sorted_preds),
        )
        sorted_next_token = tf.random.categorical(
            tf.math.log(probs), 1, seed=seed
        )
        # Map the sampled position in sorted order back to a vocabulary id.
        next_token = tf.gather_nd(
            sorted_indices, sorted_next_token, batch_dims=1
        )
        next_token = tf.cast(next_token, dtype=prompt.dtype)
        # Where `mask` is set at this position, keep the token already in
        # `prompt` instead of the sampled one.
        next_token = tf.where(mask[:, length], prompt[:, length], next_token)

        # Append the next token to current sequence.
        prompt = tf.tensor_scatter_nd_update(
            tensor=prompt,
            indices=tf.stack(
                (
                    tf.cast(tf.range(batch_size), dtype=length.dtype),
                    tf.repeat(length, batch_size),
                ),
                axis=1,
            ),
            updates=next_token,
        )

        length = tf.add(length, 1)
        return (length, prompt)

    # Run a while loop till text of length `max_length` has been generated.
    length, prompt = tf.while_loop(
        cond=lambda length, _: tf.less(length, max_length),
        body=one_step,
        loop_vars=(length, prompt),
    )

    if end_token_id is not None:
        prompt = _mask_tokens_after_end_token(
            prompt, max_length, end_token_id, pad_token_id
        )

    # Remove only the batch axis added above. A bare `tf.squeeze` would also
    # drop the sequence axis when `max_length == 1`, returning a scalar.
    return tf.squeeze(prompt, axis=0) if input_is_1d else prompt

In [44]:
def decode_sequences(input_sentences, p=0.1, max_length=None):
    """Translate a batch of English sentences using top-p sampling.

    Args:
        input_sentences: batch of English sentences; its first dimension is
            taken as the batch size and it is passed to `eng_tokenizer`.
        p: float, defaults to 0.1. Cumulative-probability threshold passed
            to `top_p_search`.
        max_length: int, defaults to None. Maximum length of the generated
            token sequence. When None, `MAX_SEQUENCE_LENGTH` is used.

    Returns:
        The output of `spa_tokenizer.detokenize` for the generated token
        ids. Special tokens such as "[START]", "[END]" and "[PAD]" are not
        stripped here; callers remove them.
    """
    # Resolve the default at call time so changes to the notebook-level
    # constant are picked up without redefining this function.
    if max_length is None:
        max_length = MAX_SEQUENCE_LENGTH

    batch_size = tf.shape(input_sentences)[0]

    # Tokenize the encoder input and convert it to a dense tensor with
    # `MAX_SEQUENCE_LENGTH` columns.
    encoder_input_tokens = eng_tokenizer(input_sentences).to_tensor(
        shape=(None, MAX_SEQUENCE_LENGTH)
    )

    # Define a function that outputs the next token's probability given the
    # input sequence.
    def token_probability_fn(decoder_input_tokens):
        # Keep only the prediction for the last decoder position.
        return transformer([encoder_input_tokens, decoder_input_tokens])[:, -1, :]

    # Set the prompt to the "[START]" token.
    prompt = tf.fill((batch_size, 1), spa_tokenizer.token_to_id("[START]"))

    generated_tokens = top_p_search(
        token_probability_fn,
        prompt,
        p=p,
        max_length=max_length,
        end_token_id=spa_tokenizer.token_to_id("[END]"),
    )
    generated_sentences = spa_tokenizer.detokenize(generated_tokens)
    return generated_sentences


# Print sample translations for ten randomly drawn English test sentences.
test_eng_texts = [eng for eng, _ in test_pairs]
for i in range(10):
    input_sentence = random.choice(test_eng_texts)
    # Decode a batch of one sentence and convert the result to a Python str.
    translated = decode_sequences(tf.constant([input_sentence]))
    translated = translated.numpy()[0].decode("utf-8")
    # Drop the special tokens, then trim the surrounding whitespace.
    translated = translated.replace("[PAD]", "")
    translated = translated.replace("[START]", "")
    translated = translated.replace("[END]", "")
    translated = translated.strip()
    # Header, source sentence, translation, then a blank separator line.
    print(f"** Example {i} **", input_sentence, translated, "", sep="\n")

** Example 0 **
tom doesn't have a cat.
tom no tiene una .

** Example 1 **
she was trembling with fear.
ella no fue aricumpro .



In [45]:
# Accumulate ROUGE-1 and ROUGE-2 over the first 30 test pairs.
rouge_1 = keras_nlp.metrics.RougeN(order=1)
rouge_2 = keras_nlp.metrics.RougeN(order=2)

for test_pair in test_pairs[:30]:
    input_sentence, reference_sentence = test_pair

    # Decode a batch of one sentence and convert the result to a Python str.
    translated_sentence = decode_sequences(tf.constant([input_sentence]))
    translated_sentence = translated_sentence.numpy()[0].decode("utf-8")

    # Drop the special tokens, then trim the surrounding whitespace.
    translated_sentence = translated_sentence.replace("[PAD]", "")
    translated_sentence = translated_sentence.replace("[START]", "")
    translated_sentence = translated_sentence.replace("[END]", "")
    translated_sentence = translated_sentence.strip()

    # Update both metrics with (reference, prediction).
    rouge_1(reference_sentence, translated_sentence)
    rouge_2(reference_sentence, translated_sentence)

print("ROUGE-1 Score: ", rouge_1.result())
print("ROUGE-2 Score: ", rouge_2.result())

ROUGE-1 Score:  {'precision': <tf.Tensor: shape=(), dtype=float32, numpy=0.37296093>, 'recall': <tf.Tensor: shape=(), dtype=float32, numpy=0.3413588>, 'f1_score': <tf.Tensor: shape=(), dtype=float32, numpy=0.35042742>}
ROUGE-2 Score:  {'precision': <tf.Tensor: shape=(), dtype=float32, numpy=0.18063973>, 'recall': <tf.Tensor: shape=(), dtype=float32, numpy=0.16484126>, 'f1_score': <tf.Tensor: shape=(), dtype=float32, numpy=0.16860461>}
