# Neural machine translation with attention

## Overview 😊

### Overall Structure of This Notebook 📝✨

1. **Introduction** 🚀
    - Overview of neural machine translation with attention
    - Motivation and goals

2. **Setup** ⚙️
    - Installation of required packages
    - Importing libraries

3. **Data Preparation** 🗂️
    - Download and inspect the Spanish-English dataset
    - Load and split data into context (Spanish) and target (English)
    - Create `tf.data.Dataset` objects for training and validation

4. **Text Preprocessing** 🧹🔤
    - Standardize and clean sentences
    - Tokenize and vectorize text using Keras `TextVectorization`
    - Convert sentences to padded token ID tensors

5. **Model Components** 🧩
    - Encoder: Bidirectional GRU for context processing
    - Attention Layer: Cross-attention mechanism
    - Decoder: Unidirectional GRU for target sequence generation

6. **Training Preparation** 🏋️‍♂️
    - Masking and handling padding tokens
    - Custom loss and accuracy functions

7. **Model Construction** 🏗️
    - Combine encoder, attention, and decoder into a `Translator` model

8. **Training** 🔥
    - Compile and train the model
    - Monitor loss and accuracy

9. **Evaluation and Visualization** 📊👀
    - Plot training/validation loss and accuracy
    - Visualize attention weights

10. **Inference and Translation** 🌍💬
     - Translate new sentences
     - Export and use the trained model

11. **Conclusion** ✅
     - Summary and next steps

This tutorial demonstrates how to train a sequence-to-sequence (seq2seq) model for Spanish-to-English translation roughly based on [Effective Approaches to Attention-based Neural Machine Translation](https://arxiv.org/abs/1508.04025v5) (Luong et al., 2015).


The model in this tutorial is an encoder/decoder connected by attention.


While this architecture is somewhat outdated, it is still a very useful project to work through to get a deeper understanding of sequence-to-sequence models and attention mechanisms (before going on to [Transformers](transformer.ipynb)).



This example assumes some knowledge of TensorFlow fundamentals below the level of a Keras layer:
  * [Working with tensors](https://www.tensorflow.org/guide/tensor) directly
  * [Writing custom `keras.Model`s and `keras.layers`](https://www.tensorflow.org/guide/keras/custom_layers_and_models)

After training the model in this notebook, you will be able to input a Spanish sentence, such as "*¿todavia estan en casa?*", and get back the English translation: "*are you still at home?*"

The resulting model is exportable as a `tf.saved_model`, so it can be used in other TensorFlow environments.

The translation quality is reasonable for a toy example, but the generated attention plot is perhaps more interesting. It shows which parts of the input sentence have the model's attention while translating:

<img src="https://tensorflow.org/images/spanish-english.png" alt="spanish-english attention plot">

Note: This example takes approximately 10 minutes to run.

## Setup ⚙️

In [None]:
!pip install tensorflow==2.19.0 tensorflow-text matplotlib einops wrapt==1.15.0


In [None]:
# !pip install typing_extensions==4.5.0

In [None]:
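import os
# `!export` would run in a throwaway subshell and not affect this Python process,
# so set the variable with os.environ before wrapt is imported.
os.environ['WRAPT_DISABLE_EXTENSIONS'] = 'true'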

In [None]:
import numpy as np

import einops
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import tensorflow as tf
import tensorflow_text as tf_text
import wrapt
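if wrapt.__version__ != "1.15.0" or tf.__version__ != "2.19.0":
  raise Exception(f"Please restart your session: expected wrapt 1.15.0 and tensorflow 2.19.0, "
                  f"but found wrapt {wrapt.__version__} and tensorflow {tf.__version__}")
# list_physical_devices() always includes the CPU, so more than one device means a GPU/TPU is attached.
devices = tf.config.list_physical_devices()
if len(devices) > 1:
  raise Exception("Please use a CPU runtime for this notebook. Go to Runtime at the top, choose 'Change runtime type', and select CPU.")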

This tutorial uses a lot of low-level APIs where it's easy to get shapes wrong. The following class is used to check shapes throughout the tutorial. 🧩


In [None]:
#@title
class ShapeChecker():
  def __init__(self):
    # Keep a cache of every axis-name seen
    self.shapes = {}

  def __call__(self, tensor, names, broadcast=False):
    # Only check shapes when running eagerly (not in graph mode)
    if not tf.executing_eagerly():
      return

    # einops is a Python library for flexible and readable tensor operations.
    # It provides functions to parse, reshape, and manipulate tensor shapes using simple string-based notation.
    # Here, einops.parse_shape is used to extract the dimensions of the tensor according to the provided axis names.
    parsed = einops.parse_shape(tensor, names)

    # Check each axis name and its dimension
    for name, new_dim in parsed.items():
      old_dim = self.shapes.get(name, None)

      # If broadcasting is allowed and the new dimension is 1, skip the check
      if (broadcast and new_dim == 1):
        continue

      if old_dim is None:
        # If the axis name is new, add its length to the cache.
        self.shapes[name] = new_dim
        continue

      # If the dimension has changed, raise an error
      if new_dim != old_dim:
        raise ValueError(f"Shape mismatch for dimension: '{name}'\n"
                         f"    found: {new_dim}\n"
                         f"    expected: {old_dim}\n")
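To see what the checker catches without TensorFlow or einops, here is a minimal sketch of the same idea. It assumes NumPy arrays and space-separated axis names (one per dimension). `MiniShapeChecker` is a hypothetical, simplified stand-in and is not used elsewhere in the tutorial:

```python
import numpy as np

class MiniShapeChecker:
    """Hypothetical, simplified stand-in for ShapeChecker (no einops, NumPy only)."""

    def __init__(self):
        self.shapes = {}  # axis name -> first length seen for that name

    def __call__(self, array, names, broadcast=False):
        axis_names = names.split()
        if len(axis_names) != array.ndim:
            raise ValueError(f"expected {len(axis_names)} axes, got {array.ndim}")
        for name, new_dim in zip(axis_names, array.shape):
            if broadcast and new_dim == 1:
                continue  # size-1 axes are allowed to broadcast
            old_dim = self.shapes.setdefault(name, new_dim)  # cache on first sight
            if new_dim != old_dim:
                raise ValueError(f"Shape mismatch for '{name}': found {new_dim}, expected {old_dim}")

checker = MiniShapeChecker()
checker(np.zeros((64, 12)), 'batch s')             # records batch=64, s=12
checker(np.zeros((64, 12, 256)), 'batch s units')  # consistent, so it passes
try:
    checker(np.zeros((32, 12)), 'batch s')         # batch changed from 64 to 32
except ValueError as err:
    print(err)  # Shape mismatch for 'batch': found 32, expected 64
```

Because the cache lives on the checker instance, the tutorial creates a fresh `ShapeChecker()` inside each `call`, so axis names like `batch` are only compared within one forward pass.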

## Data Preparation 🗂️

The tutorial uses a language dataset provided by [Anki](http://www.manythings.org/anki/). This dataset contains language translation pairs in the format:

```
May I borrow this book?	¿Puedo tomar prestado este libro?
```

The site offers a variety of languages, but this example uses the English-Spanish dataset.

### Download and prepare the dataset 📥🗂️

For convenience, a copy of this dataset is hosted on Google Cloud, but you can also download your own copy. After downloading the dataset, here are the steps you need to take to prepare the data:

1. Add a *start* and *end* token to each sentence. 🚩🏁
2. Clean the sentences by removing special characters. 🧹
3. Create a word index and reverse word index (dictionaries mapping from word → id and id → word). 🔢🔄
4. Pad each sentence to a maximum length. 📏
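The four steps can be sketched in plain Python on a hypothetical two-sentence corpus. This only illustrates the ideas: the notebook itself performs them with `tf_lower_and_split_punct` and `TextVectorization`, and the alphabetically sorted word index here is an arbitrary choice:

```python
import re

# Hypothetical toy sentences (no accented letters, so Unicode normalization is skipped here).
sentences = ["Puedo tomar prestado este libro?", "Hola, amigo."]

def preprocess(sentence):
    """Steps 1 and 2: clean the sentence, then wrap it in start/end tokens."""
    sentence = sentence.lower()
    sentence = re.sub(r"[^ a-z.?!,]", "", sentence)    # remove special characters
    sentence = re.sub(r"([.?!,])", r" \1 ", sentence)  # split punctuation off words
    # Add the tokens after cleaning; otherwise the regex would strip '[', ']' and the capitals.
    return ["[START]"] + sentence.split() + ["[END]"]

tokenized = [preprocess(s) for s in sentences]

# Step 3: word index (word -> id) and reverse word index (id -> word); id 0 is kept for padding.
vocab = sorted({word for sent in tokenized for word in sent})
word_to_id = {word: i for i, word in enumerate(vocab, start=1)}
id_to_word = {i: word for word, i in word_to_id.items()}

# Step 4: pad every id sequence with 0 up to the longest sentence.
max_len = max(len(sent) for sent in tokenized)
padded = [[word_to_id[w] for w in sent] + [0] * (max_len - len(sent)) for sent in tokenized]

for row in padded:
    print(row, " ".join(id_to_word[i] for i in row if i != 0))
```

Cleaning runs before the tokens are added. Without Unicode normalization, an accented letter such as `á` would simply be deleted by the first regex, which is why the notebook normalizes first.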

In [None]:
# Download the Spanish-English translation dataset as a zip file.
import pathlib

# Use TensorFlow utility to download and extract the dataset.
path_to_zip = tf.keras.utils.get_file(
    'spa-eng.zip',  # Name of the file to download.
    origin='http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip',  # URL of the dataset.
    extract=True  # Automatically extract the zip file after downloading.
)

# Construct the path to the extracted text file containing sentence pairs.
path_to_file = pathlib.Path(path_to_zip).parent/'spa-eng_extracted/spa-eng/spa.txt'

To take a look at the data, display the first 10 lines of the file with the `head` command.

In [None]:
!head -n 10 {path_to_file}

In [None]:
def load_data(path: pathlib.Path):
  """Load the dataset from a text file and split each line into a target (English) sentence and its context (Spanish) translation."""
  # Read the entire text file as a string using UTF-8 encoding
  text = path.read_text(encoding='utf-8')

  # Split the text into lines
  lines = text.splitlines()
  # Split each line into a pair (target, context) using tab as the separator
  pairs = [line.split('\t') for line in lines]

  # Extract the context sentences (Spanish) from each pair
  context = np.array([context for target, context in pairs])
  # Extract the target sentences (English) from each pair
  target = np.array([target for target, context in pairs])

  # Return the target and context arrays
  return target, context

In [None]:
# Load the data from the specified file path.
# The load_data function reads the file, splits each line into English (target) and Spanish (context) sentence pairs,
# and returns them as numpy arrays.
target_raw, context_raw = load_data(path_to_file)

print(context_raw[-1])

In [None]:
print(target_raw[-1])

### Create a tf.data dataset 📦

From these arrays of strings you can create a `tf.data.Dataset` of strings that shuffles and batches them efficiently: 🎲📦

In [None]:
# Set the buffer size for shuffling to the total number of context sentences
BUFFER_SIZE = len(context_raw)
# Set the batch size for training and validation
BATCH_SIZE = 64

# Randomly assign each example to the training set (80%) or validation set (20%)
is_train = np.random.uniform(size=(len(target_raw),)) < 0.8

# Create the training dataset:
# - Select context and target sentences assigned to training
# - Shuffle the dataset
# - Batch the dataset
train_raw = (
    tf.data.Dataset
    .from_tensor_slices((context_raw[is_train], target_raw[is_train]))
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE))

# Create the validation dataset:
# - Select context and target sentences not assigned to training
# - Shuffle the dataset
# - Batch the dataset
val_raw = (
    tf.data.Dataset
    .from_tensor_slices((context_raw[~is_train], target_raw[~is_train]))
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE))
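The split above assigns each sentence pair independently, so the training fraction is only approximately 80%. Because no seed is set, the split also changes on every run. The NumPy sketch below uses hypothetical toy arrays and an assumed fixed seed. It shows the same boolean-mask technique and confirms that indexing both arrays with one mask keeps each context aligned with its target:

```python
import numpy as np

# Hypothetical stand-ins for context_raw / target_raw.
context = np.array([f"frase {i}" for i in range(10_000)])
target = np.array([f"sentence {i}" for i in range(10_000)])

rng = np.random.default_rng(seed=42)   # assumption: a fixed seed, unlike the notebook
is_train = rng.uniform(size=(len(target),)) < 0.8

train_pairs = (context[is_train], target[is_train])
val_pairs = (context[~is_train], target[~is_train])   # ~ flips the boolean mask

print(is_train.mean())   # close to, but generally not exactly, 0.8
```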

In [None]:
# Iterate over one batch from the training dataset
for example_context_strings, example_target_strings in train_raw.take(1):
  # Print the first 5 context (Spanish) sentences in the batch
  print(example_context_strings[:5])
  print()
  # Print the first 5 target (English) sentences in the batch
  print(example_target_strings[:5])
  break  # Exit after processing the first batch

### Text preprocessing

One of the goals of this tutorial is to build a model that can be exported as a `tf.saved_model`. To make that exported model useful, it should take `tf.string` inputs and return `tf.string` outputs. All the text processing therefore happens inside the model, mainly in a `layers.TextVectorization` layer. 🏗️🔤✨

#### Standardization

The model deals with multilingual text and a limited vocabulary, so it is important to standardize the input text. 🌐🔤

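The first step is Unicode normalization. It splits accented characters into a base letter plus a separate combining accent, and it replaces compatibility characters with their plain equivalents (for example, the ligature `ﬁ` becomes `fi`). 🧹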

The `tensorflow_text` package contains a unicode normalize operation:

In [None]:
# Define a sample Spanish sentence as a TensorFlow constant
example_text = tf.constant('¿Todavía está en casa?')

# Print the raw bytes of the example text
print(example_text.numpy())

# Normalize the example text using Unicode normalization (NFKD form)
# This splits accented characters into base letter + combining accent and replaces compatibility characters with their plain equivalents
print(tf_text.normalize_utf8(example_text, 'NFKD').numpy())
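Python's standard `unicodedata` module implements the same NFKD normal form, so you can inspect its effect without TensorFlow. This is an illustration using the standard library, not the `tensorflow_text` op:

```python
import unicodedata

text = "\u00bfTodav\u00eda est\u00e1 en casa?"   # "¿Todavía está en casa?" with precomposed í and á
nfkd = unicodedata.normalize("NFKD", text)

print(len(text), len(nfkd))   # 22 24: í and á each become a base letter + a combining accent
print("".join(ch for ch in nfkd if not unicodedata.combining(ch)))   # ¿Todavia esta en casa?
print(unicodedata.normalize("NFKD", "\ufb01"))   # the 'ﬁ' ligature becomes "fi"
```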

Unicode normalization will be the first step in the text standardization function:

In [None]:
def tf_lower_and_split_punct(text):
  # Unicode normalization (NFKD) splits accented characters into a base letter
  # plus a combining accent, and replaces compatibility characters with their
  # plain equivalents. This helps standardize multilingual text for processing.
  text = tf_text.normalize_utf8(text, 'NFKD')

  # Convert all characters to lowercase for consistency.
  text = tf.strings.lower(text)

  # Keep only spaces, lowercase ASCII letters, and selected punctuation.
  # This also drops the combining accents split off by NFKD (e.g. 'á' -> 'a').
  text = tf.strings.regex_replace(text, '[^ a-z.?!,¿]', '')

  # Add spaces around punctuation marks to separate them from words.
  text = tf.strings.regex_replace(text, '[.?!,¿]', r' \0 ')

  # Remove leading and trailing whitespace.
  text = tf.strings.strip(text)

  # Add [START] and [END] tokens to mark the beginning and end of the sentence.
  # These tokens help the model know where a sentence starts and ends,
  # which is important for sequence-to-sequence tasks like translation.
  text = tf.strings.join(['[START]', text, '[END]'], separator=' ')
  return text

In [None]:
# Print the original example text as a string
print(example_text.numpy().decode())

# Print the example text after applying the text preprocessing function
# This function normalizes, lowercases, removes unwanted characters,
# adds spaces around punctuation, strips whitespace, and adds [START] and [END] tokens
print(tf_lower_and_split_punct(example_text).numpy().decode())
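For readers without TensorFlow at hand, here is a pure-Python sketch of the same standardization steps. `py_lower_and_split_punct` is a hypothetical helper, not part of the tutorial. One porting detail matters: TensorFlow's RE2-based `regex_replace` uses `\0` for the whole match, but in Python's `re` a bare `\0` inserts a NUL character, so the sketch uses `\g<0>`:

```python
import re
import unicodedata

def py_lower_and_split_punct(text):
    """Pure-Python sketch of tf_lower_and_split_punct (no TensorFlow)."""
    text = unicodedata.normalize("NFKD", text)   # split accents off their base letters
    text = text.lower()
    text = re.sub(r"[^ a-z.?!,¿]", "", text)      # also removes the detached accents
    text = re.sub(r"[.?!,¿]", r" \g<0> ", text)   # \g<0> = the whole match in Python's re
    text = text.strip()
    return " ".join(["[START]", text, "[END]"])

print(py_lower_and_split_punct("\u00bfTodav\u00eda est\u00e1 en casa?"))
# [START] ¿ todavia esta en casa ? [END]
```

Punctuation in the middle of a sentence ("hola, amigo") leaves a double space. That is harmless because the text is later split on whitespace, and `str.split()` with no argument likewise treats runs of spaces as one separator.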

#### Text Vectorization

This standardization function will be wrapped up in a `tf.keras.layers.TextVectorization` layer 😊, which will handle the vocabulary extraction and the conversion of input text to sequences of tokens.

In [None]:
# Set the maximum vocabulary size for the text vectorization layer
max_vocab_size = 5000

# Create a TextVectorization layer for processing Spanish context sentences.
# - standardize: applies the tf_lower_and_split_punct function to clean the text (the layer then splits it on whitespace)
# - max_tokens: limits the vocabulary size to max_vocab_size
# - ragged: allows variable-length outputs for tokenized sentences
context_text_processor = tf.keras.layers.TextVectorization(
    standardize=tf_lower_and_split_punct,
    max_tokens=max_vocab_size,
    ragged=True)

The `TextVectorization` layer and many other [Keras preprocessing layers](https://www.tensorflow.org/guide/keras/preprocessing_layers) have an `adapt` method. This method reads one epoch of the training data and works a lot like `Model.fit`: it initializes the layer's state based on the data. Here it determines the vocabulary: 🧠🔤

In [None]:
# Adapt the context_text_processor to the training data.
# This step analyzes the context (Spanish) sentences in train_raw,
# builds the vocabulary, and prepares the text vectorization layer.
context_text_processor.adapt(train_raw.map(lambda context, target: context))

# Display the first 10 words from the vocabulary learned by the context_text_processor.
context_text_processor.get_vocabulary()[:10]
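For intuition about what `adapt` computes, here is a plain-Python sketch of a frequency-ranked vocabulary. The function name `build_vocabulary` is hypothetical. The sketch mirrors the layout that `get_vocabulary()` reports above:

* `''` for padding at index 0.
* `'[UNK]'` at index 1.
* The remaining tokens, from most to least frequent.

`max_tokens` counts the two reserved entries. Ties keep first-seen order here, which is an assumption rather than documented Keras behavior:

```python
from collections import Counter

def build_vocabulary(sentences, max_tokens):
    """Hypothetical sketch: padding and OOV entries first, then tokens by frequency."""
    counts = Counter(token for s in sentences for token in s.split())
    # most_common() sorts by count; equal counts stay in first-encountered order.
    ranked = [token for token, _ in counts.most_common()]
    return ["", "[UNK]"] + ranked[: max_tokens - 2]

corpus = ["[START] hola . [END]", "[START] hola amigo . [END]", "[START] adios amigo [END]"]
vocab = build_vocabulary(corpus, max_tokens=6)
print(vocab)   # ['', '[UNK]', '[START]', '[END]', 'hola', '.']
```

Words that fall outside the `max_tokens` cut (here `amigo` and `adios`) are mapped to `[UNK]` when the layer later converts text.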

That's the Spanish `TextVectorization` layer; now build and `.adapt()` the English one:

In [None]:
# Create a TextVectorization layer for processing English target sentences.
# - standardize: applies the tf_lower_and_split_punct function to clean the text (the layer then splits it on whitespace)
# - max_tokens: limits the vocabulary size to max_vocab_size
# - ragged: allows variable-length outputs for tokenized sentences
target_text_processor = tf.keras.layers.TextVectorization(
    standardize=tf_lower_and_split_punct,
    max_tokens=max_vocab_size,
    ragged=True)

# Adapt the target_text_processor to the training data.
# This step analyzes the target (English) sentences in train_raw,
# builds the vocabulary, and prepares the text vectorization layer.
target_text_processor.adapt(train_raw.map(lambda context, target: target))

# Display the first 10 words from the vocabulary learned by the target_text_processor.
target_text_processor.get_vocabulary()[:10]

Now these layers can convert a batch of strings into a batch of token IDs:

In [None]:
# Convert the batch of example context strings (Spanish sentences) into token IDs using the context_text_processor
example_tokens = context_text_processor(example_context_strings)

# Display the token IDs for the first 3 sentences in the batch
example_tokens[:3, :]

The `get_vocabulary` method can be used to convert token IDs back to text:

In [None]:
# Get the vocabulary from the context_text_processor as a numpy array
context_vocab = np.array(context_text_processor.get_vocabulary())

# Convert the first example's token IDs to their corresponding words using the vocabulary
tokens = context_vocab[example_tokens[0].numpy()]

# Join the tokens into a single string for readability
' '.join(tokens)
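The lookup in both directions amounts to a dictionary for word → ID and array indexing for ID → word. Here is a sketch with a hypothetical toy vocabulary in the same layout, assuming unknown words map to the `[UNK]` ID `1`:

```python
import numpy as np

# Hypothetical toy vocabulary: index 0 = padding (''), index 1 = out-of-vocabulary ('[UNK]').
vocab = ["", "[UNK]", "[START]", "[END]", ".", "hola", "amigo"]
word_to_id = {word: i for i, word in enumerate(vocab)}

def encode(sentence):
    # Unknown words fall back to the [UNK] id (1).
    return [word_to_id.get(word, 1) for word in sentence.split()]

ids = encode("[START] hola gato . [END]")
print(ids)                                # [2, 5, 1, 4, 3]  ('gato' is not in the vocabulary)
print(" ".join(np.array(vocab)[ids]))     # [START] hola [UNK] . [END]
```

Decoding a zero-padded row this way yields empty strings at the padding positions, because index 0 is `''`.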

With `ragged=True`, the layer returns ragged token IDs. Converting them with `.to_tensor()` zero-pads them, and the padded tensor can easily be turned into a mask:

---

### Why masking is important 🛡️

Token sequences are padded with zeros to ensure all sequences in a batch have the same length. These padding tokens do **not** represent actual data and should be ignored during computations such as loss calculation, accuracy measurement, and attention visualization.

Without masking, the model would treat padding as meaningful input, introducing noise and bias that can negatively affect training and evaluation. The mask allows us to focus only on the real tokens, ensuring that metrics and model updates are based solely on valid data. ✅

**In summary:**  
- Padding tokens = 🚫 not real data  
- Masking = 🕵️‍♂️ focus on valid tokens  
- Better training & evaluation = 🎯
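To make the effect of masking concrete, here is a small NumPy sketch with hypothetical token IDs and per-token losses. It assumes padding ID `0`, as above, and averages the loss only over real tokens:

```python
import numpy as np

# Hypothetical batch of 2 sequences padded to length 4 (0 = padding).
token_ids = np.array([[5, 8, 3, 0],
                      [7, 2, 0, 0]])
# Hypothetical per-token losses; the values at padding positions are meaningless.
per_token_loss = np.array([[0.5, 1.0, 1.5, 9.0],
                           [2.0, 4.0, 9.0, 9.0]])

mask = (token_ids != 0).astype(per_token_loss.dtype)
naive_loss = per_token_loss.mean()                          # counts padding: 36 / 8
masked_loss = (per_token_loss * mask).sum() / mask.sum()    # real tokens only: 9 / 5
print(naive_loss, masked_loss)   # 4.5 1.8
```

Dividing by `mask.sum()` rather than the total number of positions matters. Otherwise batches with more padding would report a smaller loss even when the real tokens are predicted equally well.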

In [None]:
# If running in a Jupyter notebook, ensure plots display inline
%matplotlib inline

# Plot the token IDs for the batch of example sentences.
plt.subplot(1, 2, 1)
plt.pcolormesh(example_tokens.to_tensor())  # Convert ragged tensor to dense and plot token IDs
plt.title('Token IDs')

# Plot the mask showing which positions are non-padding (token ID != 0)
plt.subplot(1, 2, 2)
plt.pcolormesh(example_tokens.to_tensor() != 0)  # True for non-padding tokens
plt.title('Mask')

### Process the dataset 🛠️✨

The `process_text` function below converts the `Datasets` of strings into 0-padded tensors of token IDs. It also converts each `(context, target)` pair to a `((context, target_in), target_out)` pair for training with `keras.Model.fit`. Keras expects `(inputs, labels)` pairs: the inputs are `(context, target_in)` and the labels are `target_out`. `target_in` and `target_out` are shifted by one step relative to each other, so that at each location the label is the next token. 🧩🔢
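Here is a plain-Python sketch of the one-step shift, with hypothetical token IDs (assuming `2` = `[START]` and `3` = `[END]`). It also shows why `process_text` slices the ragged tensor *before* calling `.to_tensor()`. Slicing after padding would remove a padding zero from short rows instead of their `[END]` token:

```python
# Hypothetical token-ID sequences of different lengths (2 = [START], 3 = [END]).
batch = [[2, 10, 11, 12, 3],
         [2, 20, 3]]

def pad(seqs, length):
    return [s + [0] * (length - len(s)) for s in seqs]

# Slice each variable-length sequence first, then pad (as process_text does with ragged tensors).
targ_in = pad([s[:-1] for s in batch], 4)    # drop [END]   -> decoder input
targ_out = pad([s[1:] for s in batch], 4)    # drop [START] -> labels

# Padding first and slicing afterwards removes a padding 0 instead of [END] for the short row.
wrong_in = [s[:-1] for s in pad(batch, 5)]

print(targ_in)    # [[2, 10, 11, 12], [2, 20, 0, 0]]
print(targ_out)   # [[10, 11, 12, 3], [20, 3, 0, 0]]
print(wrong_in)   # [[2, 10, 11, 12], [2, 20, 3, 0]]  <- [END] leaks into the decoder input
```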

In [None]:
def process_text(context, target):
  # Convert context sentences (Spanish) to token IDs and pad to tensor
  context = context_text_processor(context).to_tensor()
  # Convert target sentences (English) to token IDs (ragged tensor)
  target = target_text_processor(target)
  # Prepare decoder input by removing the last token ([END])
  targ_in = target[:, :-1].to_tensor()
  # Prepare decoder output by removing the first token ([START])
  targ_out = target[:, 1:].to_tensor()
  # Return ((context, decoder_input), decoder_output) for training
  return (context, targ_in), targ_out

# Map the process_text function over the training and validation datasets
train_ds = train_raw.map(process_text, tf.data.AUTOTUNE)
val_ds = val_raw.map(process_text, tf.data.AUTOTUNE)

Here is the first sequence of each, from the first batch:

In [None]:
# Iterate over one batch from the training dataset
for (ex_context_tok, ex_tar_in), ex_tar_out in train_ds.take(1):
  # Print the first 10 token IDs of the first context (Spanish) sentence in the batch
  print(ex_context_tok[0, :10].numpy())
  print()
  # Print the first 10 token IDs of the first target (English) input sentence in the batch
  print(ex_tar_in[0, :10].numpy())
  # Print the first 10 token IDs of the first target (English) output sentence in the batch
  print(ex_tar_out[0, :10].numpy())

## The encoder/decoder 🤖📝

The following diagrams show an overview of the model. In both, the encoder is on the left and the decoder is on the right. At each time step, the decoder's output is combined with the encoder's output to predict the next word.

The original [left] contains a few extra connections that are intentionally omitted from this tutorial's model [right], as they are generally unnecessary and difficult to implement. Those missing connections are:

1. Feeding the state from the encoder's RNN to the decoder's RNN
2. Feeding the attention output back to the RNN's input.

<table>
<tr>
  <td>
   <img width=500 src="https://www.tensorflow.org/images/seq2seq/attention_mechanism.jpg"/>
  </td>
  <td>
   <img width=380 src="https://www.tensorflow.org/images/tutorials/transformer/RNN+attention.png"/>
  </td>
</tr>
<tr>
  <th colspan=1>The original from <a href=https://arxiv.org/abs/1508.04025v5>Effective Approaches to Attention-based Neural Machine Translation</a></th>
  <th colspan=1>This tutorial's model</th>
</tr>
</table>


Before getting into it, define constants for the model:

In [None]:
# Set the number of units (dimensions) for the model's internal representations.
# This value is used for the size of the RNN hidden state and embedding vectors.
UNITS = 256

### The encoder 🤖

The goal of the encoder is to process the context sequence into a sequence of vectors that are useful for the decoder as it attempts to predict the next output for each timestep. Since the context sequence is constant, there is no restriction on how information can flow in the encoder, so use a bidirectional RNN to do the processing:

<table>
<tr>
  <td>
   <img width=500 src="https://tensorflow.org/images/tutorials/transformer/RNN-bidirectional.png"/>
  </td>
</tr>
<tr>
  <th>A bidirectional RNN</th>
</tr>
</table>

The encoder:

1. Takes a list of token IDs (from `context_text_processor`). 🆔
2. Looks up an embedding vector for each token (Using a `layers.Embedding`). 🧩
3. Processes the embeddings into a new sequence (Using a bidirectional `layers.GRU`). 🔄
4. Returns the processed sequence. This will be passed to the attention head. 🎯
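Here is a minimal NumPy sketch of what `Bidirectional(merge_mode='sum')` computes. It uses a plain tanh RNN as a hypothetical stand-in for the GRU, with random weights and made-up sizes. The default `merge_mode='concat'` would instead double the last dimension to `2 * units`:

```python
import numpy as np

rng = np.random.default_rng(0)
s, d, units = 5, 3, 4   # hypothetical sizes: sequence length, embedding dim, RNN units
x = rng.normal(size=(s, d))

def simple_rnn(inputs, W, U):
    """Plain tanh RNN (a stand-in for the GRU) returning one output per timestep."""
    h = np.zeros(U.shape[0])
    outputs = []
    for x_t in inputs:
        h = np.tanh(x_t @ W + h @ U)
        outputs.append(h)
    return np.stack(outputs)   # (s, units), like return_sequences=True

W_f, U_f = rng.normal(size=(d, units)), rng.normal(size=(units, units))
W_b, U_b = rng.normal(size=(d, units)), rng.normal(size=(units, units))

def bidirectional_sum(inputs):
    forward = simple_rnn(inputs, W_f, U_f)
    # Read the sequence backwards, then flip the outputs so they line up with the timesteps.
    backward = simple_rnn(inputs[::-1], W_b, U_b)[::-1]
    return forward, forward + backward   # summing keeps the last dimension at `units`

forward, merged = bidirectional_sum(x)
print(merged.shape)   # (5, 4)
```

Position 0 of `forward` only ever sees the first token. Position 0 of `merged` also reflects the last token, through the backward pass. The encoder can afford this, but the decoder cannot, since it must not see future target tokens.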

In [None]:
class Encoder(tf.keras.layers.Layer):
  def __init__(self, text_processor, units):
    super(Encoder, self).__init__()
    # Store the text processor and model units
    self.text_processor = text_processor
    self.vocab_size = text_processor.vocabulary_size()
    self.units = units

    # Embedding layer: Converts token IDs to dense vectors
    self.embedding = tf.keras.layers.Embedding(self.vocab_size, units,
                                               mask_zero=True)

    # Bidirectional GRU: processes the sequence of embeddings in both directions.
    # merge_mode='sum' combines the forward and backward outputs by summing them,
    # which keeps the output dimension equal to 'units' for downstream layers.
    self.rnn = tf.keras.layers.Bidirectional(
      merge_mode='sum',
      layer=tf.keras.layers.GRU(units,
                # Return the full output sequence (one vector per token), not just the last output
                return_sequences=True,
                recurrent_initializer='glorot_uniform'))

  def call(self, x):
    # Check input shape: (batch, sequence_length)
    shape_checker = ShapeChecker()
    shape_checker(x, 'batch s')

    # Convert token IDs to embeddings: (batch, sequence_length, units)
    x = self.embedding(x)
    shape_checker(x, 'batch s units')

    # Process embeddings with bidirectional GRU: (batch, sequence_length, units)
    x = self.rnn(x)
    shape_checker(x, 'batch s units')

    # Return processed sequence for attention
    return x

  def convert_input(self, texts):
    # Convert input texts to tensor
    texts = tf.convert_to_tensor(texts)
    # If input is a scalar, add batch dimension
    if len(texts.shape) == 0:
      texts = texts[tf.newaxis]
    # Tokenize and pad input texts
    context = self.text_processor(texts).to_tensor()
    # Encode the tokenized input
    context = self(context)
    return context

Try it out:

In [None]:
# Create an Encoder instance using the Spanish text processor and model units.
encoder = Encoder(context_text_processor, UNITS)

# Pass the batch of context token IDs through the encoder to get the encoded sequence.
ex_context = encoder(ex_context_tok)

# Print the shape of the input context tokens (batch size, sequence length).
print(f'Context tokens, shape (batch, s): {ex_context_tok.shape}')

# Print the shape of the encoder output (batch size, sequence length, units).
print(f'Encoder output, shape (batch, s, units): {ex_context.shape}')

### The attention layer 🧠✨

The attention layer lets the decoder access the information extracted by the encoder. It computes a vector from the entire context sequence, and adds that to the decoder's output.

The simplest way you could calculate a single vector from the entire sequence would be to take the average across the sequence (`layers.GlobalAveragePooling1D`). An attention layer is similar, but calculates a **weighted** average across the context sequence, where the weights are calculated from the combination of context and "query" vectors.

<table>
<tr>
  <td>
   <img width=500 src="https://www.tensorflow.org/images/tutorials/transformer/CrossAttention-new-full.png"/>
  </td>
</tr>
<tr>
  <th colspan=1>The attention layer</th>
</tr>
</table>
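The weighted average can be written out in a few lines of NumPy. This sketch uses the hypothetical helper `dot_product_attention`, which implements single-head scaled dot-product attention. It omits the learned query/key/value/output projections that `layers.MultiHeadAttention` adds. As in `CrossAttention`, the context serves as both keys and values. With an all-zero query every score is equal, so the weights fall back to a plain average over the non-padding positions:

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def dot_product_attention(query, context, mask=None):
    """Single-head scaled dot-product attention without learned projections.

    query:   (t, units) decoder-side vectors
    context: (s, units) encoder outputs, used as both keys and values
    mask:    (s,) bool, True for real (non-padding) context tokens
    """
    scores = query @ context.T / np.sqrt(query.shape[-1])    # (t, s)
    if mask is not None:
        scores = np.where(mask[np.newaxis, :], scores, -1e9)  # padding gets ~0 weight
    weights = softmax(scores, axis=-1)                        # each row sums to 1
    return weights @ context, weights                         # weighted average of the context

rng = np.random.default_rng(0)
context = rng.normal(size=(4, 8))          # s=4 context vectors; the last one is padding
mask = np.array([True, True, True, False])
query = rng.normal(size=(3, 8))            # t=3 query vectors

output, weights = dot_product_attention(query, context, mask)
print(output.shape, weights.shape)         # (3, 8) (3, 4)
```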

In [None]:
class CrossAttention(tf.keras.layers.Layer):
  def __init__(self, units, **kwargs):
    super().__init__()
    # Multi-head attention layer (single head, key_dim=units)
    self.mha = tf.keras.layers.MultiHeadAttention(key_dim=units, num_heads=1, **kwargs)
    # Layer normalization for output stability
    self.layernorm = tf.keras.layers.LayerNormalization()
    # Add layer for residual connection
    # Residual connections help gradients flow through the network, improve training stability,
    # and allow the model to reuse the original input along with the attention output.
    self.add = tf.keras.layers.Add()

  def call(self, x, context):
    shape_checker = ShapeChecker()

    # x: target sequence embeddings (batch, t, units)
    shape_checker(x, 'batch t units')
    # context: encoder output (batch, s, units)
    shape_checker(context, 'batch s units')

    # Compute attention output and attention scores
    attn_output, attn_scores = self.mha(
        query=x,         # queries: target sequence
        value=context,   # values: context sequence (encoder output)
        return_attention_scores=True)

    # Check shapes after attention
    shape_checker(attn_output, 'batch t units')
    shape_checker(attn_scores, 'batch heads t s')

    # Average over heads to get (batch, t, s) attention weights
    attn_scores = tf.reduce_mean(attn_scores, axis=1)
    shape_checker(attn_scores, 'batch t s')
    # Cache attention weights for later visualization
    self.last_attention_weights = attn_scores

    # Residual connection: add attention output to input
    x = self.add([x, attn_output])
    # Normalize output
    x = self.layernorm(x)

    # Return attended and normalized output
    return x

In [None]:
# Create an instance of the CrossAttention layer with the specified number of units
attention_layer = CrossAttention(UNITS)

# Create an embedding layer for the target (English) tokens
embed = tf.keras.layers.Embedding(
    target_text_processor.vocabulary_size(),  # Vocabulary size for target language
    output_dim=UNITS,                        # Embedding dimension
    mask_zero=True                           # Mask padding tokens
)

# Embed the input target tokens (decoder input)
ex_tar_embed = embed(ex_tar_in)

# Apply the attention layer: attend to the encoded context using the embedded target tokens
result = attention_layer(ex_tar_embed, ex_context)

# Print the shapes of the input and output tensors for verification
print(f'Context sequence, shape (batch, s, units): {ex_context.shape}')           # Encoder output
print(f'Target sequence, shape (batch, t, units): {ex_tar_embed.shape}')          # Embedded target input
print(f'Attention result, shape (batch, t, units): {result.shape}')               # Output of attention layer
print(f'Attention weights, shape (batch, t, s):    {attention_layer.last_attention_weights.shape}')  # Attention scores

The attention weights will sum to `1` over the context sequence, at each location in the target sequence.

In [None]:
# Sum the attention weights over the context sequence for the first example in the batch.
# This checks that the attention weights for each target token sum to 1 (as expected for a probability distribution).
attention_layer.last_attention_weights[0].numpy().sum(axis=-1)



Here are the attention weights across the context sequences at `t=0`:

In [None]:
# Get the attention weights from the attention layer
attention_weights = attention_layer.last_attention_weights

# Create a mask for the context tokens (True for non-padding tokens)
mask = (ex_context_tok != 0).numpy()

# Plot the attention weights for the first target token, masked by valid context positions
plt.subplot(1, 2, 1)
plt.pcolormesh(mask * attention_weights[:, 0, :])
plt.title('Attention weights')

# Plot the mask itself to show valid context positions
plt.subplot(1, 2, 2)
plt.pcolormesh(mask)
plt.title('Mask');


🔍 **Interpreting Initial Attention Weights**

Because of the small random initialization, the attention weights are initially all close to `1/(sequence_length)`.  
This means that, before training, the model does not "know" which parts of the input are important, so it distributes its attention almost uniformly across all positions in the context sequence.

As training progresses, the model learns to assign higher attention weights to the most relevant tokens in the input sequence for each output token.  
This non-uniform attention allows the model to focus on specific words or phrases that are most helpful for generating accurate translations, improving both the quality of the output and the interpretability of the attention mechanism. 🎯✨

### The decoder 🤖📝

The decoder's job is to generate predictions for the next token at each location in the target sequence.

1. It looks up embeddings for each token in the target sequence. 🧩
2. It uses an RNN to process the target sequence, and keep track of what it has generated so far. 🔄
3. It uses RNN output as the "query" to the attention layer, when attending to the encoder's output. 🎯
4. At each location in the output it predicts the next token. 📝

When training, the model predicts the next word at each location. So it's important that the information only flows in one direction through the model. The decoder uses a unidirectional (not bidirectional) RNN to process the target sequence.

When running inference with this model it produces one word at a time, and those are fed back into the model.

<table>
<tr>
  <td>
   <img width=500 src="https://tensorflow.org/images/tutorials/transformer/RNN.png"/>
  </td>
</tr>
<tr>
  <th>A unidirectional RNN</th>
</tr>
</table>
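The one-word-at-a-time inference loop can be sketched in plain Python. `TOY_NEXT_TOKEN` and `greedy_decode` are hypothetical: the lookup table stands in for the trained decoder. Choosing a single next token per step (greedy decoding) is an assumption, not necessarily the strategy this notebook's inference code uses. The `max_length` cap guards against a model that never emits `[END]`:

```python
# Hypothetical stand-in for the trained decoder: maps the last token to the next one.
TOY_NEXT_TOKEN = {"[START]": "are", "are": "you", "you": "still", "still": "at",
                  "at": "home", "home": "?", "?": "[END]"}

def greedy_decode(next_token_fn, max_length=10):
    """Generate one token at a time, feeding each prediction back in as the next input."""
    tokens = ["[START]"]
    for _ in range(max_length):
        next_token = next_token_fn(tokens)   # a real decoder would also use the context and RNN state
        if next_token == "[END]":
            break
        tokens.append(next_token)
    return tokens[1:]   # drop [START]

result = greedy_decode(lambda tokens: TOY_NEXT_TOKEN[tokens[-1]])
print(" ".join(result))   # are you still at home ?
```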

Here is the `Decoder` class's initializer. The initializer creates all the necessary layers.

In [None]:
class Decoder(tf.keras.layers.Layer):
  @classmethod
  def add_method(cls, fun):
    # Decorator that attaches `fun` to the class as a method, so the class can be extended in later notebook cells.
    setattr(cls, fun.__name__, fun)
    return fun

  def __init__(self, text_processor, units):
      super(Decoder, self).__init__()
      # Store the text processor and model units
      self.text_processor = text_processor
      self.vocab_size = text_processor.vocabulary_size()

      # Lookup layer: maps words to token IDs
      self.word_to_id = tf.keras.layers.StringLookup(
        vocabulary=text_processor.get_vocabulary(),
        mask_token='', oov_token='[UNK]')

      # Lookup layer: maps token IDs back to words
      self.id_to_word = tf.keras.layers.StringLookup(
        vocabulary=text_processor.get_vocabulary(),
        mask_token='', oov_token='[UNK]',
        invert=True)

      # Special token IDs for start and end of sequence
      self.start_token = self.word_to_id('[START]')
      self.end_token = self.word_to_id('[END]')

      self.units = units

      # 1. Embedding layer: converts token IDs to dense vectors
      self.embedding = tf.keras.layers.Embedding(
        self.vocab_size, units, mask_zero=True)

      # 2. RNN layer: processes the sequence and keeps track of generated tokens
      self.rnn = tf.keras.layers.GRU(
        units,
        return_sequences=True,
        return_state=True,
        recurrent_initializer='glorot_uniform'
      )

      # 3. Attention layer: attends to the encoder output using RNN output as query
      self.attention = CrossAttention(units)

      # 4. Output layer: produces logits for each output token
      self.output_layer = tf.keras.layers.Dense(self.vocab_size)

#### Training

Next, the `call` method takes these arguments: 🤖

* `context` - the context from the encoder's output. 🌐
* `x` - the target sequence input (token IDs). 📝
* `state` - Optional, the previous `state` output from the decoder (the internal state of the decoder's RNN). Pass the state from a previous run to continue generating text where you left off. 🔄
* `return_state` - [Default: False] - Set this to `True` to return the RNN state. 🏁
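The `state` argument is what makes step-by-step generation possible. Suppose an RNN runs over a sequence in two chunks, passing the final state of the first chunk as the initial state of the second. This gives the same outputs as one call over the whole sequence. Here is a NumPy sketch with a plain tanh RNN standing in for the GRU, using random, hypothetical weights:

```python
import numpy as np

rng = np.random.default_rng(1)
W = rng.normal(size=(3, 4))   # hypothetical input weights (input dim 3, 4 units)
U = rng.normal(size=(4, 4))   # hypothetical recurrent weights

def rnn(inputs, state=None):
    """Plain tanh RNN returning (outputs, final_state), a stand-in for a GRU
    created with return_sequences=True and return_state=True."""
    h = np.zeros(U.shape[0]) if state is None else state
    outputs = []
    for x_t in inputs:
        h = np.tanh(x_t @ W + h @ U)
        outputs.append(h)
    return np.stack(outputs), h

x = rng.normal(size=(6, 3))
full_out, _ = rnn(x)                     # all 6 steps in one call
first_out, state = rnn(x[:4])            # first 4 steps
rest_out, _ = rnn(x[4:], state=state)    # continue from the saved state
fresh_out, _ = rnn(x[4:])                # restart from zeros instead: the history is lost

print(np.allclose(np.concatenate([first_out, rest_out]), full_out))   # True
```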

In [None]:
@Decoder.add_method
def call(self,
  context, x,
  state=None,
  return_state=False):
    shape_checker = ShapeChecker()
    shape_checker(x, 'batch t')  # Check shape of target input tokens
    shape_checker(context, 'batch s units')  # Check shape of encoder output

    # 1. Lookup the embeddings for the target input tokens
    x = self.embedding(x)
    shape_checker(x, 'batch t units')  # Check shape after embedding

    # 2. Process the target sequence with the RNN
    x, state = self.rnn(x, initial_state=state)
    shape_checker(x, 'batch t units')  # Check shape after RNN

    # 3. Use the RNN output as the query for attention over the context
    x = self.attention(x, context)
    self.last_attention_weights = self.attention.last_attention_weights  # Cache attention weights for visualization
    shape_checker(x, 'batch t units')  # Check shape after attention
    shape_checker(self.last_attention_weights, 'batch t s')  # Check shape of attention weights

    # 4. Generate logit predictions for the next token
    logits = self.output_layer(x)
    shape_checker(logits, 'batch t target_vocab_size')  # Check shape of output logits

    # Optionally return the RNN state for inference
    if return_state:
        return logits, state
    else:
        return logits

That will be sufficient for training. Create an instance of the decoder to test it out:

In [None]:
# Create an instance of the Decoder class using the target_text_processor and UNITS.
# The Decoder will be used to generate English translations from encoded Spanish context.
decoder = Decoder(target_text_processor, UNITS)

In training you'll use the decoder like this:

Given the context and target tokens, for each target token it predicts the next target token.
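The decoder is trained with teacher forcing: its input at each step is the ground-truth previous token, and its label is the token that follows. Here is a minimal pure-Python sketch of that one-position offset, assuming the target input and labels are built by slicing the same tokenized sentence. The `shift_for_teacher_forcing` helper and all token IDs are hypothetical.

```python
# Hypothetical token IDs: 2 = [START], 3 = [END], 0 = padding.
START_ID, END_ID, PAD_ID = 2, 3, 0

def shift_for_teacher_forcing(target_ids):
    """Split one tokenized target sentence into decoder inputs and labels.

    The decoder reads tokens[:-1] and is trained to predict tokens[1:],
    so label i is the token that follows input token i.
    """
    decoder_input = target_ids[:-1]  # drop the final token
    labels = target_ids[1:]          # drop the leading [START]
    return decoder_input, labels

sentence = [START_ID, 17, 42, 9, END_ID, PAD_ID]  # hypothetical sentence plus one pad
tar_in, tar_out = shift_for_teacher_forcing(sentence)
for inp, out in zip(tar_in, tar_out):
    print(f"input {inp:>2} -> predict {out:>2}")
```

The last pair maps `[END]` to padding; positions whose label is 0 are the ones the masked loss and accuracy functions later in this notebook ignore.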

In [None]:
# Print the shapes of the encoder output, input target tokens, and logits for inspection
logits = decoder(ex_context, ex_tar_in)

# The encoder output shape: (batch, sequence_length, units)
print(f'encoder output shape: (batch, s, units) {ex_context.shape}')

# The input target tokens shape: (batch, target_sequence_length)
print(f'input target tokens shape: (batch, t) {ex_tar_in.shape}')

# The logits shape: (batch, target_sequence_length, target_vocabulary_size)
print(f'logits shape: (batch, t, target_vocabulary_size) {logits.shape}')

#### Inference

To use it for inference you'll need a couple more methods:

In [None]:
@Decoder.add_method
def get_initial_state(self, context):
  # This method is needed for inference to initialize the decoder's generation loop.
  # It provides the initial start tokens, a mask indicating which sequences are finished,
  # and the initial RNN state for the decoder. This setup allows the model to begin
  # generating output tokens one step at a time, starting from the [START] token.

  batch_size = tf.shape(context)[0]
  # Every sequence starts with the [START] token.
  start_tokens = tf.fill([batch_size, 1], self.start_token)
  # No sequence is finished yet.
  done = tf.zeros([batch_size, 1], dtype=tf.bool)
  # Zero-initialized GRU state for the whole batch.
  return start_tokens, done, self.rnn.get_initial_state(batch_size)[0]


In [None]:
@Decoder.add_method
def tokens_to_text(self, tokens):
  # During inference, the model generates sequences of token IDs.
  # This method converts those token IDs back to human-readable text.
  # It's needed in inference to interpret the model's output as actual sentences.

  # Convert token IDs to words using the id_to_word lookup layer
  words = self.id_to_word(tokens)
  # Join the words into sentences (strings), separating by spaces
  result = tf.strings.reduce_join(words, axis=-1, separator=' ')
  # Remove the [START] token from the beginning of each sentence
  result = tf.strings.regex_replace(result, r'^ *\[START\] *', '')
  # Remove the [END] token from the end of each sentence
  result = tf.strings.regex_replace(result, r' *\[END\] *$', '')
  # Return the cleaned sentences as strings
  return result
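The two regular expressions can be checked outside TensorFlow. This sketch applies the same patterns with Python's standard `re` module, assuming `tf.strings.regex_replace` (which uses RE2 syntax) treats these simple patterns the same way. The `strip_special_tokens` name and the sample sentence are made up for illustration.

```python
import re

def strip_special_tokens(sentence: str) -> str:
    # Same patterns as tokens_to_text: a leading [START] and a trailing [END],
    # each together with any surrounding spaces.
    sentence = re.sub(r'^ *\[START\] *', '', sentence)
    sentence = re.sub(r' *\[END\] *$', '', sentence)
    return sentence

print(strip_special_tokens('[START] are you still at home ? [END]'))
```

The raw-string prefix (`r'...'`) keeps `\[` as a regex escape instead of an invalid Python string escape.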

In [None]:
@Decoder.add_method
def get_next_token(self, context, next_token, done, state, temperature=0.0):
  # This method is needed for inference because, during generation,
  # we produce one token at a time, feeding the previous output back in.
  # It runs the decoder for a single step, predicts the next token,
  # and updates the RNN state and done mask for each sequence in the batch.

  # Run the decoder for one step to get logits and new RNN state
  logits, state = self(
    context, next_token,
    state=state,
    return_state=True)

  # If temperature is 0, use greedy decoding (argmax)
  if temperature == 0.0:
    next_token = tf.argmax(logits, axis=-1)
  else:
    # Otherwise, sample from the probability distribution (softmax with temperature)
    logits = logits[:, -1, :]/temperature
    next_token = tf.random.categorical(logits, num_samples=1)

  # Mark sequences as done if they produce the end token
  done = done | (next_token == self.end_token)
  # For finished sequences, pad with zeros
  next_token = tf.where(done, tf.constant(0, dtype=tf.int64), next_token)

  # Return the next token, done mask, and new state
  return next_token, done, state
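To see what `temperature` does, here is a framework-free sketch of one decoding step for a single sequence. It uses greedy argmax when the temperature is 0 and otherwise samples from `softmax(logits / temperature)`, which is the distribution `tf.random.categorical` draws from when given the scaled logits. The done-mask update mirrors the method above. The `next_token_step` function and the toy logits are hypothetical.

```python
import math
import random

def next_token_step(logits, done, end_id, temperature=0.0, rng=random):
    """Pick the next token ID for one sequence and update its done flag."""
    if temperature == 0.0:
        # Greedy decoding: index of the largest logit.
        token = max(range(len(logits)), key=lambda i: logits[i])
    else:
        # Scale logits by 1/temperature, then sample from their softmax.
        scaled = [v / temperature for v in logits]
        top = max(scaled)  # subtract the max for numerical stability
        weights = [math.exp(s - top) for s in scaled]
        token = rng.choices(range(len(logits)), weights=weights, k=1)[0]
    done = done or token == end_id
    # Finished sequences emit the padding ID 0, as in get_next_token.
    return (0 if done else token), done

logits = [0.1, 2.0, 0.5, 1.9]  # toy scores over a 4-token vocabulary
print(next_token_step(logits, done=False, end_id=3))  # (1, False): greedy picks index 1
print(next_token_step(logits, done=False, end_id=3, temperature=1.0, rng=random.Random(0)))
```

Lower temperatures sharpen the distribution toward the greedy choice; higher ones flatten it, giving more varied but less reliable output.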

With those extra functions, you can write a generation loop:

In [None]:
# Setup the loop variables for inference.
# Get initial decoder state, start tokens, and done mask for the batch.
next_token, done, state = decoder.get_initial_state(ex_context)
tokens = []

# Generate up to 10 tokens for each sequence in the batch.
for n in range(10):
  # Run one decoding step: get next token, update done mask and RNN state.
  next_token, done, state = decoder.get_next_token(
      ex_context, next_token, done, state, temperature=1.0)
  # Collect the generated token for this step.
  tokens.append(next_token)

# Concatenate all generated tokens along the sequence axis to form output sequences.
tokens = tf.concat(tokens, axis=-1) # (batch, t)

# Convert the token IDs back to text strings.
result = decoder.tokens_to_text(tokens)
# Display the first 3 generated sequences.
result[:3].numpy()

Since the model is untrained, it outputs tokens from the vocabulary almost uniformly at random.

## The model 🤖✨

Now that you have all the model components, combine them to build the model for training:

In [None]:
class Translator(tf.keras.Model):
  @classmethod
  def add_method(cls, fun):
    # Utility to add methods to the class dynamically
    setattr(cls, fun.__name__, fun)
    return fun

  def __init__(self, units,
               context_text_processor,
               target_text_processor):
    super().__init__()
    # Build the encoder and decoder using the provided text processors and units
    encoder = Encoder(context_text_processor, units)
    decoder = Decoder(target_text_processor, units)

    self.encoder = encoder
    self.decoder = decoder

  def call(self, inputs):
    # Unpack the inputs: context (source language) and x (target language input)
    context, x = inputs
    # Encode the context sequence
    context = self.encoder(context)
    # Decode the target input sequence using the encoded context
    logits = self.decoder(context, x)

    try:
      # Delete the keras mask, so keras doesn't scale the loss+accuracy.
      del logits._keras_mask
    except AttributeError:
      pass

    # Return the output logits (predictions for next token)
    return logits

During training the model will be used like this:

In [None]:
# Instantiate the Translator model with the specified number of units and text processors
model = Translator(UNITS, context_text_processor, target_text_processor)

# Pass a batch of context and target input tokens through the model to get output logits
logits = model((ex_context_tok, ex_tar_in))

# Print the shapes of the input and output tensors for inspection
print(f'Context tokens, shape: (batch, s) {ex_context_tok.shape}')        # Shape of context input tokens
print(f'Target tokens, shape: (batch, t) {ex_tar_in.shape}')              # Shape of target input tokens
print(f'logits, shape: (batch, t, target_vocabulary_size) {logits.shape}') # Shape of model output logits

### Training 🏋️‍♂️✨

For training, you'll want to implement your own masked loss and accuracy functions:

In [None]:
def masked_loss(y_true, y_pred):
    # Create a loss function for sparse categorical crossentropy.
    # from_logits=True means y_pred are raw logits, not probabilities.
    # reduction='none' means the loss is computed per element, not averaged.
    # SparseCategoricalCrossentropy is needed because the targets (y_true) are integer class labels (token IDs),
    # not one-hot encoded vectors. It efficiently computes the cross-entropy loss for each token prediction.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')

    # Compute the loss for each token position in the batch.
    loss = loss_fn(y_true, y_pred)

    # Create a mask to ignore losses for padding tokens (where y_true == 0).
    mask = tf.cast(y_true != 0, loss.dtype)
    # Apply the mask to the loss.
    loss *= mask

    # Return the mean loss over all non-padding tokens.
    return tf.reduce_sum(loss) / tf.reduce_sum(mask)
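A small worked example of the masked mean, in plain Python: per-token cross-entropy is `-log softmax(logits)[label]`, padding positions (label 0) are zeroed out, and the sum is divided by the number of real tokens rather than the full sequence length. The helper names, logits, and labels are made up.

```python
import math

def token_cross_entropy(logits, label):
    """-log(softmax(logits)[label]), computed stably with log-sum-exp."""
    top = max(logits)
    log_sum_exp = top + math.log(sum(math.exp(v - top) for v in logits))
    return log_sum_exp - logits[label]

def masked_loss_py(labels, logits_per_token):
    """Mean cross-entropy over positions whose label is not the padding ID 0."""
    losses = [token_cross_entropy(lg, y) for y, lg in zip(labels, logits_per_token)]
    mask = [1.0 if y != 0 else 0.0 for y in labels]
    return sum(loss * m for loss, m in zip(losses, mask)) / sum(mask)

labels = [1, 2, 0]          # two real tokens, one padding position
logits = [[0.0, 2.0, 0.0],  # confident and correct: small loss
          [0.0, 0.0, 0.0],  # uniform: loss = log(3)
          [5.0, 0.0, 0.0]]  # padding position: ignored by the mask
print(masked_loss_py(labels, logits))
```

Whatever the model outputs at the padding position, the result stays the average of the first two token losses.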

In [None]:
def masked_acc(y_true, y_pred):
    # Compute the predicted token IDs by taking the argmax over the last axis (vocabulary dimension)
    y_pred = tf.argmax(y_pred, axis=-1)
    # Cast predictions to the same dtype as y_true for comparison
    y_pred = tf.cast(y_pred, y_true.dtype)

    # Compare predictions to true labels; 1.0 for match, 0.0 otherwise
    match = tf.cast(y_true == y_pred, tf.float32)
    # Create a mask to ignore padding tokens (where y_true == 0)
    mask = tf.cast(y_true != 0, tf.float32)
    # Zero out matches at padding positions so they are not counted as correct
    match *= mask

    # Compute the mean accuracy over all non-padding tokens
    return tf.reduce_sum(match) / tf.reduce_sum(mask)
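The mask matters for the numerator as well as the denominator. At padding positions the label is 0, so if the model also predicts ID 0 there, an unmasked count would score those positions as correct and could even push the ratio above 1. This plain-Python sketch, with a hypothetical `masked_accuracy_py` helper and made-up predictions, compares the two counts.

```python
def masked_accuracy_py(labels, predictions, mask_matches=True):
    """Fraction of non-padding labels (label != 0) that were predicted exactly."""
    mask = [1.0 if y != 0 else 0.0 for y in labels]
    match = [1.0 if y == p else 0.0 for y, p in zip(labels, predictions)]
    if mask_matches:
        # Only count matches at real (non-padding) positions.
        match = [a * m for a, m in zip(match, mask)]
    return sum(match) / sum(mask)

labels = [7, 4, 9, 0, 0, 0]       # three real tokens, three padding positions
predictions = [7, 1, 9, 0, 0, 0]  # two real tokens right; padding predicted as 0
print(masked_accuracy_py(labels, predictions))                      # ≈ 0.667 (2 of 3)
print(masked_accuracy_py(labels, predictions, mask_matches=False))  # ≈ 1.667, above 1
```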

Configure the model for training:

In [None]:
# Compile the model for training.
# - optimizer='adam': Use the Adam optimizer for training.
# - loss=masked_loss: Use the custom masked loss function to ignore padding tokens.
# - metrics=[masked_acc, masked_loss]: Track masked accuracy and masked loss during training.
model.compile(optimizer='adam',
              loss=masked_loss,
              metrics=[masked_acc, masked_loss])

The model is randomly initialized, and should give roughly uniform output probabilities. So it's easy to predict what the initial values of the metrics should be:

In [None]:
# Get the vocabulary size of the target language as a float
vocab_size = 1.0 * target_text_processor.vocabulary_size()

# Calculate the expected loss and accuracy for a randomly initialized model:
# - expected_loss: log(vocab_size), since the output probabilities are uniform
# - expected_acc: 1/vocab_size, since the chance of guessing the correct token is 1 out of vocab_size
{
    "expected_loss": tf.math.log(vocab_size).numpy(),
    "expected_acc": 1/vocab_size
}

That should roughly match the values returned by running a few steps of evaluation:

In [None]:
# Evaluate the model on the validation dataset.
# - val_ds: validation dataset, already processed and batched
# - steps=20: run evaluation for 20 batches
# - return_dict=True: return the results as a dictionary of metric names and values
model.evaluate(val_ds, steps=20, return_dict=True)
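The expected values rest on one identity: when every logit is equal, softmax assigns each of the V tokens probability 1/V, so the cross-entropy of any label is -log(1/V) = log V, and a guess is correct with probability 1/V. A plain-Python check, using a hypothetical vocabulary size of 5000 rather than this notebook's actual value:

```python
import math

V = 5000            # hypothetical vocabulary size
logits = [0.0] * V  # all-equal logits, an idealized untrained model

# Softmax of equal logits is uniform: every token gets probability 1/V.
top = max(logits)
probs = [math.exp(v - top) for v in logits]
total = sum(probs)
probs = [p / total for p in probs]

label = 123  # any label gives the same loss under a uniform distribution
cross_entropy = -math.log(probs[label])
print(round(cross_entropy, 3), math.log(V))  # both ≈ 8.517
print(1 / V)                                 # 0.0002
```

A freshly initialized model's logits are only roughly equal, which is why the evaluated metrics only roughly match these numbers.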

In [None]:
# Train the model using the training dataset.
# - train_ds.repeat(): Repeat the training dataset indefinitely for multiple epochs.
# - epochs=100: Train for up to 100 epochs.
# - steps_per_epoch=100: Each epoch consists of 100 batches.
# - validation_data=val_ds: Use the validation dataset for evaluation.
# - validation_steps=20: Evaluate on 20 batches from the validation set each epoch.
# - callbacks=[tf.keras.callbacks.EarlyStopping(patience=3)]: Stop training early if validation loss doesn't improve for 3 epochs.
history = model.fit(
    train_ds.repeat(),
    epochs=100,
    steps_per_epoch=100,
    validation_data=val_ds,
    validation_steps=20,
    callbacks=[
        tf.keras.callbacks.EarlyStopping(patience=3)
    ]
)
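With `patience=3`, training stops once the monitored value (by default `val_loss`) has failed to improve for three consecutive epochs. The following is a simplified plain-Python model of that rule, not Keras's implementation: it ignores options such as `min_delta` and `baseline`, and the `early_stopping_epoch` function and loss values are invented.

```python
def early_stopping_epoch(val_losses, patience=3):
    """Return the 0-based epoch after which training stops, or None if it never does."""
    best = float('inf')
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:  # improvement: remember it and reset the counter
            best = loss
            wait = 0
        else:            # no improvement this epoch
            wait += 1
            if wait >= patience:
                return epoch
    return None

history_sketch = [3.1, 2.4, 2.0, 2.1, 2.05, 2.2, 1.9]
print(early_stopping_epoch(history_sketch))  # 5: epochs 3, 4 and 5 did not beat 2.0
```

By default the Keras callback leaves the model with the weights from the last epoch rather than the best one; passing `restore_best_weights=True` changes that.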

In [None]:
# Plot the training and validation loss curves over epochs.
plt.plot(history.history['loss'], label='loss')         # Training loss per epoch
plt.plot(history.history['val_loss'], label='val_loss') # Validation loss per epoch

# Set the y-axis limits to start at 0 and end at the maximum value currently shown.
plt.ylim([0, max(plt.ylim())])

# Label the x-axis as 'Epoch #' and y-axis as 'CE/token' (cross-entropy per token).
plt.xlabel('Epoch #')
plt.ylabel('CE/token')

# Add a legend to distinguish between training and validation loss curves.
plt.legend()

In [None]:
# Plot training accuracy over epochs
plt.plot(history.history['masked_acc'], label='accuracy')
# Plot validation accuracy over epochs
plt.plot(history.history['val_masked_acc'], label='val_accuracy')
# Set y-axis limits to start at 0 and end at the current maximum
plt.ylim([0, max(plt.ylim())])
# Label the x-axis as 'Epoch #'
plt.xlabel('Epoch #')
# Label the y-axis as 'Accuracy' (masked, per token)
plt.ylabel('Accuracy')
# Add a legend to distinguish between training and validation accuracy
plt.legend()

### Translate 🌍✨

Now that the model is trained, implement a function to execute the full `text => text` translation. This code is basically identical to the [inference example](#inference) in the [decoder section](#the_decoder), but this also captures the attention weights. 🧠🔤

In [None]:
#@title
@Translator.add_method
def translate(self,
              texts, *,
              max_length=50,
              temperature=0.0):
  # Process the input texts
  context = self.encoder.convert_input(texts)
  batch_size = tf.shape(texts)[0]

  # Setup the loop inputs
  tokens = []
  attention_weights = []
  next_token, done, state = self.decoder.get_initial_state(context)

  for _ in range(max_length):
    # Generate the next token
    next_token, done, state = self.decoder.get_next_token(
        context, next_token, done, state, temperature)

    # Collect the generated tokens
    tokens.append(next_token)
    attention_weights.append(self.decoder.last_attention_weights)

    if tf.executing_eagerly() and tf.reduce_all(done):
      break

  # Stack the lists of tokens and attention weights.
  tokens = tf.concat(tokens, axis=-1)   # t*[(batch, 1)] -> (batch, t)
  self.last_attention_weights = tf.concat(attention_weights, axis=1)  # t*[(batch, 1, s)] -> (batch, t, s)

  result = self.decoder.tokens_to_text(tokens)
  return result

Try translating a Spanish sentence:

In [None]:
# Translate the given Spanish sentence using the trained model.
result = model.translate(['¿Todavía está en casa?']) # Are you still home

# Convert the first result from a TensorFlow tensor to a Python string and print it.
result[0].numpy().decode()

Use that to generate the attention plot:

In [None]:
#@title
@Translator.add_method
def plot_attention(self, text, **kwargs):
  # Ensure the input is a string
  assert isinstance(text, str)
  # Translate the input text and get the output sentence
  output = self.translate([text], **kwargs)
  output = output[0].numpy().decode()

  # Get the attention weights for the first example
  attention = self.last_attention_weights[0]

  # Tokenize and split the input context text
  context = tf_lower_and_split_punct(text)
  context = context.numpy().decode().split()

  # Tokenize and split the output text, skipping the [START] token
  output = tf_lower_and_split_punct(output)
  output = output.numpy().decode().split()[1:]

  # Create a figure for the attention plot
  fig = plt.figure(figsize=(10, 10))
  ax = fig.add_subplot(1, 1, 1)

  # Display the attention weights as an image
  ax.matshow(attention, cmap='viridis', vmin=0.0)

  # Set font size for axis labels
  fontdict = {'fontsize': 14}

  # Set x-axis labels to context tokens and rotate them
  ax.set_xticklabels([''] + context, fontdict=fontdict, rotation=90)
  # Set y-axis labels to output tokens
  ax.set_yticklabels([''] + output, fontdict=fontdict)

  # Set major tick locations to every token
  ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
  ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

  # Label the axes
  ax.set_xlabel('Input text')
  ax.set_ylabel('Output text')
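Each row of the plotted matrix is one output token's attention distribution over the input tokens, so it should sum to 1, assuming `CrossAttention` normalizes its scores with a softmax as standard attention does. This plain-Python sketch, using invented weights and tokens and a hypothetical `most_attended` helper, reads off the most-attended input token for each output token, which is what the brightest cell in each row of the plot shows.

```python
context_tokens = ['[START]', 'esta', 'es', 'mi', 'vida', '.', '[END]']
output_tokens = ['this', 'is', 'my', 'life', '.']

# Invented attention weights: one row per output token, one column per input token.
attention = [
    [0.05, 0.80, 0.05, 0.03, 0.03, 0.02, 0.02],
    [0.02, 0.10, 0.75, 0.05, 0.04, 0.02, 0.02],
    [0.02, 0.03, 0.05, 0.80, 0.06, 0.02, 0.02],
    [0.02, 0.02, 0.03, 0.08, 0.80, 0.03, 0.02],
    [0.02, 0.02, 0.02, 0.02, 0.07, 0.80, 0.05],
]

def most_attended(attention, context_tokens):
    """For each row (output token), return the input token with the largest weight."""
    picks = []
    for row in attention:
        # Each row should be a probability distribution over the input tokens.
        assert abs(sum(row) - 1.0) < 1e-9
        best = max(range(len(row)), key=lambda j: row[j])
        picks.append(context_tokens[best])
    return picks

for word, source in zip(output_tokens, most_attended(attention, context_tokens)):
    print(f'{word!r:>8} attends most to {source!r}')
```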

In [None]:
# Plot the attention weights for the translation of the given Spanish sentence.
# This will visualize which input words the model attends to when generating each output word.
model.plot_attention('¿Todavía está en casa?') # Are you still home

Translate a few more sentences and plot them:

In [None]:
%%time
# Visualize the attention weights for the translation of the Spanish sentence "Esta es mi vida." "This is my life."
# This will show which input words the model attends to when generating each output word.
model.plot_attention('Esta es mi vida.')

In [None]:
%%time
# Visualize the attention weights for the translation of the Spanish sentence "Tratar de descubrir." "Try to find out."
# This will show which input words the model attends to when generating each output word.
model.plot_attention('Tratar de descubrir.')

The short sentences often work well, but if the input is too long the model literally loses focus and stops providing reasonable predictions. There are two main reasons for this:

1. The model was trained with teacher forcing, feeding the correct token at each step regardless of the model's predictions. The model could be made more robust if it were sometimes fed its own predictions. 🏫🤖
2. The model only has access to its previous output through the RNN state. If the RNN state loses track of where it was in the context sequence, there's no way for the model to recover. [Transformers](transformer.ipynb) improve on this by letting the decoder look at what it has output so far. 🔄🧠✨
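One known way to sometimes feed the model its own predictions is scheduled sampling (Bengio et al., 2015): at each step, the decoder input is the model's previous prediction with some probability and the ground-truth token otherwise. This tutorial does not implement it. The plain-Python sketch below shows only the input-selection rule, with a hypothetical `predict_next` callable standing in for the decoder; a real implementation would run the decoder step by step and carry its RNN state.

```python
import random

def scheduled_sampling_inputs(target, predict_next, sample_prob, rng=random):
    """Build decoder inputs one step at a time.

    target: ground-truth token IDs, starting with [START].
    predict_next: hypothetical stand-in for the decoder; maps the inputs
        fed so far to a predicted next token ID.
    sample_prob: probability of feeding the model's own prediction
        instead of the ground-truth token (0.0 is pure teacher forcing).
    """
    inputs = [target[0]]  # the first input is always [START]
    for t in range(1, len(target) - 1):
        if rng.random() < sample_prob:
            inputs.append(predict_next(inputs))  # the model's own prediction
        else:
            inputs.append(target[t])             # ground truth (teacher forcing)
    return inputs

target = [2, 17, 42, 9, 3]  # hypothetical IDs: [START] ... [END]
always_99 = lambda seq: 99  # dummy "model" that always predicts token 99
print(scheduled_sampling_inputs(target, always_99, 0.0))  # [2, 17, 42, 9]
print(scheduled_sampling_inputs(target, always_99, 1.0))  # [2, 99, 99, 99]
```

With `sample_prob=0.0` this reduces to the teacher-forced input used in this tutorial; schedules typically raise the probability gradually during training.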

The raw data is sorted by length, so try translating the longest sequence:

In [None]:
# Get the longest Spanish sentence from the dataset
long_text = context_raw[-1]

# Import the textwrap module for formatting long strings
import textwrap

# Print the expected English translation, wrapped for readability
print('Expected output:\n', '\n'.join(textwrap.wrap(target_raw[-1])))

In [None]:
# Plot the attention weights for the translation of the longest Spanish sentence in the dataset.
# This will visualize which input words the model attends to when generating each output word.
model.plot_attention(long_text)

The `translate` function works on batches, so if you have multiple texts to translate you can pass them all at once, which is much more efficient than translating them one at a time:

In [None]:
# List of Spanish input sentences to translate
inputs = [
    'Hace mucho frio aqui.', # "It's really cold here."
    'Esta es mi vida.',      # "This is my life."
    'Su cuarto es un desastre.' # "His room is a mess"
]

In [None]:
%%time
# Loop through each Spanish input sentence in the 'inputs' list
for t in inputs:
  # Translate the sentence using the trained model and print the result as a decoded string
  print(model.translate([t])[0].numpy().decode())

print()  # Print a blank line for separation

In [None]:
%%time
# Translate a batch of Spanish input sentences using the trained model
result = model.translate(inputs)

# Print the English translation for each input sentence
print(result[0].numpy().decode())
print(result[1].numpy().decode())
print(result[2].numpy().decode())
print()

So overall this text generation function mostly gets the job done, but so far you've only used it in Python with eager execution. Next, try exporting it:

### Save Model 💾✨

If you want to export this model you'll need to wrap the `translate` method in a `tf.function`. The following implementation will get the job done:


In [None]:
class Export(tf.Module):
  def __init__(self, model):
    # Store the trained translation model
    self.model = model

  @tf.function(input_signature=[tf.TensorSpec(dtype=tf.string, shape=[None])])
  def translate(self, inputs):
    # Exported translation function: takes a batch of strings and returns translations
    return self.model.translate(inputs)

In [None]:
# Create an instance of the Export class, wrapping the trained translation model.
# This allows you to export the model with a tf.function for inference.
export = Export(model)

Run the `tf.function` once to compile it:

In [None]:
%%time
# Run the exported translation function on a batch of input sentences.
# This will compile the tf.function and measure execution time.
_ = export.translate(tf.constant(inputs))

In [None]:
%%time
# Translate a batch of Spanish input sentences using the exported model.
result = export.translate(tf.constant(inputs))

# Print the English translation for each input sentence in the batch.
print(result[0].numpy().decode())
print(result[1].numpy().decode())
print(result[2].numpy().decode())
print()

Now that the function has been traced it can be exported using `saved_model.save`:

In [None]:
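# `%env` sets the variable in the kernel's environment (`!export` only affects a throwaway subshell); wrapt reads it at import time.
%env WRAPT_DISABLE_EXTENSIONS=true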

In [None]:
%%time
import os
# Set the path where the exported translator model will be saved
TRANSLATOR_PATH = os.path.join("artifacts", "translator")

# Export the model using TensorFlow's SavedModel format
# - export: the Export wrapper containing the translation model
# - TRANSLATOR_PATH: directory to save the model
# - signatures: specify the default serving signature for inference
tf.saved_model.save(export, TRANSLATOR_PATH,
                    signatures={'serving_default': export.translate})

In [None]:
%%time
# Load the exported SavedModel from the specified path
reloaded = tf.saved_model.load(TRANSLATOR_PATH)

# Run the translate function once to warm up the model (compiles the graph for faster inference)
_ = reloaded.translate(tf.constant(inputs)) #warmup

In [None]:
%%time
# Translate a batch of Spanish input sentences using the reloaded exported model.
result = reloaded.translate(tf.constant(inputs))

# Print the English translation for each input sentence in the batch.
print(result[0].numpy().decode())
print(result[1].numpy().decode())
print(result[2].numpy().decode())
print()

#### [Optional] Use a dynamic loop

It's worth noting that this initial implementation is not optimal. It uses a Python loop:

```
for _ in range(max_length):
  ...
  if tf.executing_eagerly() and tf.reduce_all(done):
    break
```

The Python loop is relatively simple, but when `tf.function` converts it to a graph, it **statically unrolls** that loop. Unrolling the loop has three disadvantages:

1. It makes `max_length` copies of the loop body. So the generated graphs take longer to build, save and load.
1. You have to choose a fixed value for the `max_length`.
1. You can't `break` from a statically unrolled loop. The `tf.function`
  version will run the full `max_length` iterations on every call.
  That's why the `break` only works with eager execution. This is
  still marginally faster than eager execution, but not as fast as it could be.


To fix these shortcomings, the new version of the `translate` method, below, uses a TensorFlow loop:

```
for t in tf.range(max_length):
  ...
  if tf.reduce_all(done):
      break
```

It looks like a python loop, but when you use a tensor as the input to a `for` loop (or the condition of a `while` loop) `tf.function` converts it to a dynamic loop using operations like `tf.while_loop`.

There's no strict need for a `max_length` here; it's just a safeguard in case the model gets stuck generating a loop like: `the united states of the united states of the united states...`.

On the downside, to accumulate tokens from this dynamic loop you can't just append them to a Python `list`; you need to use a `tf.TensorArray`:

```
tokens = tf.TensorArray(tf.int64, size=1, dynamic_size=True)
...
for t in tf.range(max_length):
  ...
  tokens = tokens.write(t, next_token) # next_token shape is (batch, 1)
  ...
tokens = tokens.stack()
tokens = einops.rearrange(tokens, 't batch 1 -> batch t')
```
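The shape bookkeeping at the end is easy to check by hand: `tokens.stack()` produces a `(t, batch, 1)` tensor with time first, and the `einops` pattern `'t batch 1 -> batch t'` drops the size-1 axis and swaps the first two. A plain-Python equivalent with nested lists, using a hypothetical `time_major_to_batch_major` helper and made-up token values:

```python
# Three decoding steps for a batch of two sequences; each step is (batch, 1).
steps = [
    [[11], [21]],  # t = 0
    [[12], [22]],  # t = 1
    [[13], [0]],   # t = 2 (second sequence already finished, padded with 0)
]

def time_major_to_batch_major(stacked):
    """Rearrange a (t, batch, 1) nested list into (batch, t)."""
    num_steps, batch = len(stacked), len(stacked[0])
    return [[stacked[t][b][0] for t in range(num_steps)] for b in range(batch)]

print(time_major_to_batch_major(steps))  # [[11, 12, 13], [21, 22, 0]]
```

This is the same `(batch, t)` layout that `tf.concat(tokens, axis=-1)` produced from the Python list of `(batch, 1)` tensors in the first version.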

This version of the code can be quite a bit more efficient:

In [None]:
#@title
@Translator.add_method
def translate(self,
              texts,
              *,
              max_length=500,
              temperature=tf.constant(0.0)):

  # Create a shape checker utility to validate tensor shapes during debugging
  shape_checker = ShapeChecker()
  # Convert input texts to encoder input format
  context = self.encoder.convert_input(texts)
  # Get the batch size from the context tensor
  batch_size = tf.shape(context)[0]
  # Check the shape of the context tensor: (batch, sequence_length, units)
  shape_checker(context, 'batch s units')

  # Get initial decoder state, start tokens, and done mask for the batch
  next_token, done, state = self.decoder.get_initial_state(context)

  # Initialize a dynamic TensorArray to accumulate generated tokens
  tokens = tf.TensorArray(tf.int64, size=1, dynamic_size=True)

  # Loop to generate up to max_length tokens for each sequence in the batch
  for t in tf.range(max_length):
    # Generate the next token, update done mask and RNN state
    next_token, done, state = self.decoder.get_next_token(
        context, next_token, done, state, temperature)
    # Check the shape of the next token: (batch, 1)
    shape_checker(next_token, 'batch t1')

    # Write the generated token to the TensorArray
    tokens = tokens.write(t, next_token)

    # If all sequences are done, break out of the loop early
    if tf.reduce_all(done):
      break

  # Stack the generated tokens into a tensor: (time, batch, 1)
  tokens = tokens.stack()
  shape_checker(tokens, 't batch t1')
  # Rearrange the tensor to shape: (batch, time)
  tokens = einops.rearrange(tokens, 't batch 1 -> batch t')
  shape_checker(tokens, 'batch t')

  # Convert the token IDs back to text strings
  text = self.decoder.tokens_to_text(tokens)
  shape_checker(text, 'batch')

  # Return the generated text for each input sequence
  return text

With eager execution this implementation performs on par with the original:

In [None]:
%%time
# Translate a batch of Spanish input sentences using the trained model
result = model.translate(inputs)

# Print the English translation for each input sentence in the batch
print(result[0].numpy().decode())
print(result[1].numpy().decode())
print(result[2].numpy().decode())
print()

But when you wrap it in a `tf.function` you'll notice two differences.

In [None]:
class Export(tf.Module):
  def __init__(self, model):
    # Store the trained translation model
    self.model = model

  @tf.function(input_signature=[tf.TensorSpec(dtype=tf.string, shape=[None])])
  def translate(self, inputs):
    # Exported translation function: takes a batch of strings and returns translations
    return self.model.translate(inputs)

In [None]:
# Create an instance of the Export class, wrapping the trained translation model.
# This allows you to export the model with a tf.function for inference.
export = Export(model)

First, it's much quicker to trace, since it only creates one copy of the loop body:

In [None]:
%%time
# Run the exported translation function on the batch of input sentences.
# This will compile and execute the tf.function for inference timing.
_ = export.translate(inputs)

The `tf.function` is much faster than running with eager execution, and on small inputs it's often several times faster than the unrolled version, because it can break out of the loop.

In [None]:
%%time
# Translate the batch of Spanish input sentences using the exported model
result = export.translate(inputs)

# Print the English translation for each input sentence in the batch
print(result[0].numpy().decode())
print(result[1].numpy().decode())
print(result[2].numpy().decode())
print()

So save this version as well:

In [None]:
%%time
# Set the path for saving the dynamically exported translator model
DYNAMIC_TRANSLATOR_PATH = os.path.join("artifacts", "dynamic_translator")

# Export the model using TensorFlow's SavedModel format
# - export: the Export wrapper containing the translation model
# - DYNAMIC_TRANSLATOR_PATH: directory to save the model
# - signatures: specify the default serving signature for inference
tf.saved_model.save(export, DYNAMIC_TRANSLATOR_PATH,
                    signatures={'serving_default': export.translate})

In [None]:
%%time
# Load the exported dynamic translator model from the specified path
reloaded = tf.saved_model.load(DYNAMIC_TRANSLATOR_PATH)

# Run the translate function once on the batch of input sentences to warm up the model (compiles the graph for faster inference)
_ = reloaded.translate(tf.constant(inputs)) #warmup

In [None]:
%%time
# Translate the batch of Spanish input sentences using the reloaded exported model.
result = reloaded.translate(tf.constant(inputs))

# Print the English translation for each input sentence in the batch.
print(result[0].numpy().decode())
print(result[1].numpy().decode())
print(result[2].numpy().decode())
print()

## Next steps

* [Download a different dataset](http://www.manythings.org/anki/) to experiment with translations, for example, English to German, or English to French.
* Experiment with training on a larger dataset, or using more epochs.
* Try the [transformer tutorial](transformer.ipynb) which implements a similar translation task but uses transformer layers instead of RNNs. This version also uses a `text.BertTokenizer` to implement word-piece tokenization.
* Visit the [`tensorflow_addons.seq2seq` tutorial](https://www.tensorflow.org/addons/tutorials/networks_seq2seq_nmt), which demonstrates a higher-level functionality for implementing this sort of sequence-to-sequence model, such as `seq2seq.BeamSearchDecoder`.