# Neural machine translation with a Transformer and Keras

## Overview 🚀

This notebook offers a hands-on tutorial for building a neural machine translation model using the Transformer architecture with Keras and TensorFlow.

You'll be guided through:
- 📦 Preparing and tokenizing a Portuguese-English translation dataset.
- 🏗️ Constructing essential model components, including positional embeddings and attention layers.
- 🧩 Assembling a multi-layer Transformer encoder-decoder model.

By the end, you will:
- 🏋️‍♂️ Train the model.
- 🌍 Generate translations.
- 🔍 Visualize attention mechanisms.

Complex concepts are broken down into clear, manageable steps, making this notebook accessible for anyone interested in state-of-the-art sequence-to-sequence models for natural language processing tasks.


## Goal 🎯

This tutorial demonstrates how to create and train a [sequence-to-sequence](https://developers.google.com/machine-learning/glossary#sequence-to-sequence-task) [Transformer](https://developers.google.com/machine-learning/glossary#Transformer) model to translate [Portuguese into English](https://www.tensorflow.org/datasets/catalog/ted_hrlr_translate#ted_hrlr_translatept_to_en). The Transformer was originally proposed in ["Attention is all you need"](https://arxiv.org/abs/1706.03762) by Vaswani et al. (2017). 🌍🔄

Transformers are deep neural networks that replace CNNs and RNNs with [self-attention](https://developers.google.com/machine-learning/glossary#self-attention). 🤖✨ Self-attention allows Transformers to easily transmit information across the input sequences.

As explained in the [Google AI Blog post](https://ai.googleblog.com/2017/08/transformer-novel-neural-network.html):

> Neural networks for machine translation typically contain an encoder reading the input sentence and generating a representation of it. A decoder then generates the output sentence word by word while consulting the representation generated by the encoder. The Transformer starts by generating initial representations, or embeddings, for each word... Then, using self-attention, it aggregates information from all of the other words, generating a new representation per word informed by the entire context, represented by the filled balls. This step is then repeated multiple times in parallel for all words, successively generating new representations.

<img src="https://www.tensorflow.org/images/tutorials/transformer/apply_the_transformer_to_machine_translation.gif" alt="Applying the Transformer to machine translation">

Figure 1: Applying the Transformer to machine translation. Source: [Google AI Blog](https://ai.googleblog.com/2017/08/transformer-novel-neural-network.html).


That's a lot to digest! The goal of this tutorial is to break it down into easy-to-understand parts. In this tutorial you will:

- 📦 Prepare the data.
- 🏗️ Implement necessary components:
  - 🧩 Positional embeddings.
  - 🎯 Attention layers.
  - 🏛️ The encoder and decoder.
- 🤖 Build & train the Transformer.
- 🌍 Generate translations.
- 📦 Export the model.

A Transformer is a sequence-to-sequence encoder-decoder model similar to the model in week 1. 🤖🔄

A single-layer Transformer takes a little more code to write, but is almost identical to that encoder-decoder RNN model. The only difference is that the RNN layers are replaced with self-attention layers. ✨
This tutorial builds a 4-layer Transformer which is larger and more powerful, but not fundamentally more complex. 🚀

<table>
<tr>
  <th>The <a href=https://www.tensorflow.org/text/tutorials/nmt_with_attention>RNN+Attention model</a></th>
  <th>A 1-layer transformer</th>
</tr>
<tr>
  <td>
   <img width=411 src="https://www.tensorflow.org/images/tutorials/transformer/RNN+attention-words.png"/>
  </td>
  <td>
   <img width=400 src="https://www.tensorflow.org/images/tutorials/transformer/Transformer-1layer-words.png"/>
  </td>
</tr>
</table>

After training the model in this notebook, you will be able to input a Portuguese sentence and return the English translation.

<img src="https://www.tensorflow.org/images/tutorials/transformer/attention_map_portuguese.png" alt="Attention heatmap">

Figure 2: Visualized attention weights that you can generate at the end of this tutorial.

## Why Transformers are significant 🤖✨

- Transformers excel at modeling sequential data, such as natural language. 🗣️
- Unlike [recurrent neural networks (RNNs)](./text_generation.ipynb), Transformers are parallelizable. This makes them efficient on hardware like GPUs and TPUs. ⚡ The main reason is that Transformers replace recurrence with attention, so the computations for all positions can happen simultaneously. Layer outputs can be computed in parallel, instead of in series as in an RNN.
- Unlike [RNNs](https://www.tensorflow.org/guide/keras/rnn) (such as [seq2seq, 2014](https://arxiv.org/abs/1409.3215)) or [convolutional neural networks (CNNs)](https://www.tensorflow.org/tutorials/images/cnn) (for example, [ByteNet](https://arxiv.org/abs/1610.10099)), Transformers can capture long-range dependencies between distant positions in the input or output sequences. 🔗 Attention gives each location access to the entire input at each layer, while in RNNs and CNNs information has to pass through many processing steps to travel a long distance, which makes such dependencies harder to learn.
- Transformers make no assumptions about the temporal/spatial relationships across the data. This is ideal for processing a set of objects (for example, [StarCraft units](https://www.deepmind.com/blog/alphastar-mastering-the-real-time-strategy-game-starcraft-ii)). 🕹️
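
To make the last two points concrete, here is a minimal NumPy sketch of single-head scaled dot-product self-attention. It is an illustration under stated assumptions, not the layer built later in this tutorial: the function name `self_attention`, the random projection matrices, and the toy sizes are all hypothetical. The whole sequence is processed with a few matrix products (no loop over positions), and every row of the weight matrix assigns a non-zero weight to every position, so any two positions are connected in a single step.

```python
import numpy as np

def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over one sequence.

    x has shape (seq_len, d_model); w_q, w_k, w_v are hypothetical
    projection matrices of shape (d_model, d_k).
    """
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])         # (seq_len, seq_len)
    scores -= scores.max(axis=-1, keepdims=True)    # for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over the keys
    return weights @ v, weights

rng = np.random.default_rng(0)
seq_len, d_model, d_k = 6, 8, 4                     # toy sizes (assumptions)
x = rng.normal(size=(seq_len, d_model))
w_q, w_k, w_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))

output, weights = self_attention(x, w_q, w_k, w_v)
print(output.shape)   # (6, 4)
print(weights.shape)  # (6, 6): one weight for every (query, key) pair
```

Each row of `weights` sums to 1, so each output vector is a weighted average over all positions in the sequence.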

<img src="https://www.tensorflow.org/images/tutorials/transformer/encoder_self_attention_distribution.png" width="800" alt="Encoder self-attention distribution for the word it from the 5th to the 6th layer of a Transformer trained on English-to-French translation">

Figure 3: The encoder self-attention distribution for the word “it” from the 5th to the 6th layer of a Transformer trained on English-to-French translation (one of eight attention heads). Source: [Google AI Blog](https://ai.googleblog.com/2017/08/transformer-novel-neural-network.html).

## Setup

Begin by installing [TensorFlow Datasets](https://tensorflow.org/datasets) for loading the dataset and [TensorFlow Text](https://www.tensorflow.org/text) for text preprocessing:

In [None]:
# # Install a specific version of CUDA's cuDNN library required for TensorFlow GPU support.
# !apt install --allow-change-held-packages libcudnn8=8.1.0.77-1+cuda11.2

# Uninstall any existing versions of TensorFlow, Keras, TensorFlow Estimator, and TensorFlow Text
# to avoid conflicts and ensure a clean environment.
!pip uninstall -y -q tensorflow keras tensorflow-estimator tensorflow-text

# Install a compatible version of protobuf, which is required by TensorFlow and TensorFlow Datasets.
!pip install protobuf~=3.20.3

# Install TensorFlow Datasets for easy access to pre-built datasets.
!pip install -q tensorflow_datasets

# Install or upgrade TensorFlow and TensorFlow Text to the latest versions.
# TensorFlow Text provides text processing ops compatible with TensorFlow.
!pip install -q -U tensorflow-text tensorflow

!pip install matplotlib
!pip install wrapt==1.15.0
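# `!export` only affects a temporary subshell, so use the `%env` magic
# to set the variable for the notebook kernel itself.
%env WRAPT_DISABLE_EXTENSIONS=true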

Import the necessary modules:

In [None]:
import logging
import time

import numpy as np
import matplotlib.pyplot as plt

import tensorflow_datasets as tfds
import tensorflow as tf

import tensorflow_text
import wrapt
if wrapt.__version__ != "1.15.0" or tf.__version__ != "2.19.1":
  raise Exception(f"Please restart your session as you are still using wrapt version: {wrapt.__version__} and tensorflow: {tf.__version__}")

## Data handling 📦

This section downloads the dataset and the subword tokenizer, from [this tutorial](https://www.tensorflow.org/text/guide/subwords_tokenizer), then wraps it all up in a `tf.data.Dataset` for training.

 <section class="expandable tfo-display-only-on-site">
 <button type="button" class="button-red button expand-control">Toggle section</button>


### Download the dataset

Use TensorFlow Datasets to load the [Portuguese-English translation dataset](https://www.tensorflow.org/datasets/catalog/ted_hrlr_translate#ted_hrlr_translatept_to_en) from the TED Talks Open Translation Project. This dataset contains approximately 52,000 training, 1,200 validation and 1,800 test examples.

In [None]:
# Load the TED Talks Portuguese-English translation dataset using TensorFlow Datasets.
# 'with_info=True' returns the dataset and its metadata (such as feature info, splits, etc.).
# 'as_supervised=True' returns each example as a tuple (input, target) instead of a dictionary.

examples, metadata = tfds.load(
    'ted_hrlr_translate/pt_to_en',  # Dataset name: TED Talks Portuguese to English
    with_info=True,                 # Also return metadata about the dataset
    as_supervised=True              # Return examples as (input, target) pairs
)

# Split the loaded dataset into training and validation sets.
train_examples = examples['train']        # Training set: used to train the model
val_examples = examples['validation']     # Validation set: used to evaluate model performance during training

The `tf.data.Dataset` object returned by TensorFlow Datasets yields pairs of text examples:

In [None]:
# Iterate over a single batch of 3 examples from the training dataset.
for pt_examples, en_examples in train_examples.batch(3).take(1):
  # Print the Portuguese examples in the batch.
  print('> Examples in Portuguese:')
  for pt in pt_examples.numpy():
    # Decode the byte string to a regular string for display.
    print(pt.decode('utf-8'))
  print()

  # Print the corresponding English examples in the batch.
  print('> Examples in English:')
  for en in en_examples.numpy():
    # Decode the byte string to a regular string for display.
    print(en.decode('utf-8'))

### Set up the tokenizer 🧩

Now that you have loaded the dataset, you need to tokenize the text, so that each element is represented as a [token](https://developers.google.com/machine-learning/glossary#token) or token ID (a numeric representation). ✂️🔢

Tokenization is the process of breaking up text into "tokens". Depending on the tokenizer, these tokens can represent sentence-pieces, words, subwords, or characters. 🧩 To learn more about tokenization, visit [this guide](https://www.tensorflow.org/text/guide/tokenizers).

This tutorial uses the tokenizers built in the [subword tokenizer](https://www.tensorflow.org/text/guide/subwords_tokenizer) tutorial. That tutorial optimizes two `text.BertTokenizer` objects (one for English, one for Portuguese) for **this dataset** and exports them in a TensorFlow `saved_model` format. 🧩🔤

> Note: This is different from the [original paper](https://arxiv.org/pdf/1706.03762.pdf), section 5.1, where they used a single byte-pair tokenizer for both the source and target with a vocabulary-size of 37000.

Download, extract, and import the `saved_model`: 📦⬇️

In [None]:
import os

# Define the directory where artifacts (such as downloaded models) will be stored.
SAVE_DIR = 'artifacts/.'

# If the directory does not exist, create it.
if not os.path.exists(SAVE_DIR):
    os.makedirs(SAVE_DIR)

# Name of the tokenizer model to download.
model_name = 'ted_hrlr_translate_pt_en_converter'

# Download the tokenizer model zip file from TensorFlow's public storage.
# tf.keras.utils.get_file will:
#   - Download the file if it does not exist in the cache.
#   - Store it in the specified cache_dir and cache_subdir.
#   - Extract the contents of the zip file after downloading.
tf.keras.utils.get_file(
    f'{model_name}.zip',  # Name for the downloaded file.
    f'https://storage.googleapis.com/download.tensorflow.org/models/{model_name}.zip',  # URL to download from.
    cache_dir=SAVE_DIR,   # Directory to cache the file.
    cache_subdir='.',     # Subdirectory within the cache directory.
    extract=True          # Automatically extract the zip file after download.
)

In [None]:
# Construct the path to the tokenizer SavedModel directory.
# The tokenizer was downloaded and extracted in the previous cell.
# SAVE_DIR: Directory where artifacts are stored (created earlier).
# model_name: Name of the tokenizer model (set earlier).
# The "_extracted" subdirectory is created by tf.keras.utils.get_file(extract=True).
TOKENIZER_PATH = os.path.join(SAVE_DIR, f"{model_name}_extracted", model_name)

# Load the tokenizer SavedModel from the constructed path.
# This loads a TensorFlow SavedModel containing two tokenizers:
#   - tokenizers.pt: Portuguese tokenizer
#   - tokenizers.en: English tokenizer
# Each tokenizer provides methods for tokenization, detokenization, and vocabulary lookup.
tokenizers = tf.saved_model.load(TOKENIZER_PATH)

The `tf.saved_model` contains two text tokenizers, one for English and one for Portuguese. Both have the same methods:

In [None]:
# List all public attributes and methods of the English tokenizer object.
# The dir() function returns all attributes (including methods) of an object.
# The list comprehension filters out any attribute names that start with an underscore ('_'),
# which are typically private or internal attributes in Python.
[item for item in dir(tokenizers.en) if not item.startswith('_')]

The `tokenize` method converts a batch of strings to a batch of token IDs, returned as a `tf.RaggedTensor` with one variable-length row per string. This method splits punctuation, lowercases and unicode-normalizes the input before tokenizing. That standardization is not visible here because the input data is already standardized.

In [None]:
# Print a batch of English strings from the dataset.
print('> This is a batch of strings:')
for en in en_examples.numpy():
  # Each element in en_examples is a byte string, so decode it to a regular string for display.
  print(en.decode('utf-8'))

In [None]:
# Tokenize a batch of English examples using the loaded tokenizer.
# This converts each string in en_examples to a sequence of token IDs.
encoded = tokenizers.en.tokenize(en_examples)

print('> This is a ragged batch of token IDs:')
# Iterate through each row (example) in the tokenized batch.
for row in encoded.to_list():
  # Print the list of token IDs for each example.
  print(row)

The `detokenize` method attempts to convert these token IDs back to human-readable text:

In [None]:
# Detokenize the batch of token IDs back into human-readable text.
# 'encoded' is a batch of token IDs produced by the tokenizer.
# 'tokenizers.en.detokenize' converts these token IDs back to strings.
round_trip = tokenizers.en.detokenize(encoded)

print('> This is human-readable text:')
# Iterate through each detokenized string in the batch.
for line in round_trip.numpy():
  # Each line is a byte string, so decode it to a regular string for display.
  print(line.decode('utf-8'))

The lower-level `lookup` method converts from token IDs to token text:

In [None]:
# Print a header to indicate what is being displayed.
print('> This is the text split into tokens:')

# Use the English tokenizer's 'lookup' method to convert token IDs (from 'encoded')
# back into their corresponding token strings. This returns a RaggedTensor of tokens.
tokens = tokenizers.en.lookup(encoded)

# Display the resulting tokens. Each example in the batch is shown as a list of tokens.
tokens

The output demonstrates the "subword" aspect of the subword tokenization.

For example, the word `'searchability'` is decomposed into `'search'` and `'##ability'`, and the word `'serendipity'` into `'s'`, `'##ere'`, `'##nd'`, `'##ip'` and `'##ity'`.

Note that the tokenized text includes `'[START]'` and `'[END]'` tokens.
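
The splits above are consistent with greedy longest-match-first matching, the strategy used by WordPiece-style tokenizers such as `text.BertTokenizer`. The following pure-Python sketch illustrates the idea only. The function `wordpiece_split` and the tiny `toy_vocab` are hypothetical, and the real tokenizer's vocabulary and preprocessing are far richer.

```python
def wordpiece_split(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of a single word into subword pieces."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        piece = None
        # Try the longest remaining substring first, then shrink it.
        while end > start:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # marks a word-internal piece
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]  # no piece matches: the whole word is unknown
        pieces.append(piece)
        start = end
    return pieces

# Hypothetical toy vocabulary, chosen to reproduce the examples in the text.
toy_vocab = {"search", "##ability", "s", "##ere", "##nd", "##ip", "##ity"}

print(wordpiece_split("searchability", toy_vocab))
# ['search', '##ability']
print(wordpiece_split("serendipity", toy_vocab))
# ['s', '##ere', '##nd', '##ip', '##ity']
```

A word with no matching pieces, such as `"xyz"` with this vocabulary, maps to `['[UNK]']`.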

The distribution of tokens per example in the dataset is as follows:

In [None]:
# Initialize an empty list to store the lengths of tokenized sequences.
lengths = []

# Iterate over the training examples in batches of 1024.
for pt_examples, en_examples in train_examples.batch(1024):
  # Tokenize the batch of Portuguese examples using the pretrained tokenizer.
  pt_tokens = tokenizers.pt.tokenize(pt_examples)
  # Append the lengths of each tokenized Portuguese sequence in the batch to the 'lengths' list.
  lengths.append(pt_tokens.row_lengths())

  # Tokenize the batch of English examples using the pretrained tokenizer.
  en_tokens = tokenizers.en.tokenize(en_examples)
  # Append the lengths of each tokenized English sequence in the batch to the 'lengths' list.
  lengths.append(en_tokens.row_lengths())

  # Print a dot for each batch processed, to indicate progress.
  print('.', end='', flush=True)

In [None]:
# Concatenate all tokenized sequence lengths from the 'lengths' list into a single numpy array.
all_lengths = np.concatenate(lengths)

# Plot a histogram of the token counts per example.
# np.linspace(0, 500, 101) produces 101 bin edges, i.e. 100 bins of width 5.
plt.hist(all_lengths, np.linspace(0, 500, 101))

# Pin the y-axis to its current limits so that the vertical line drawn below
# does not cause the axis to rescale.
plt.ylim(plt.ylim())

# Find the maximum token count in the dataset.
max_length = max(all_lengths)

# Draw a vertical line at the maximum token count to highlight it on the histogram.
plt.plot([max_length, max_length], plt.ylim())

# Add a title to the plot showing the maximum number of tokens per example.
plt.title(f'Maximum tokens per example: {max_length}');
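
A histogram like this can also be summarized numerically, which may help when choosing a length limit such as the `MAX_TOKENS` value used in the next section. The sketch below is only an illustration: `coverage` is a hypothetical helper, and `toy_lengths` is synthetic, right-skewed data standing in for `all_lengths`, not the real dataset.

```python
import numpy as np

def coverage(lengths, limit):
    """Fraction of examples whose token count is at most `limit`."""
    return float(np.mean(np.asarray(lengths) <= limit))

rng = np.random.default_rng(0)
# Synthetic stand-in for `all_lengths` (an assumption, NOT the TED data).
toy_lengths = rng.geometric(p=0.04, size=10_000) + 2

for limit in (32, 64, 128, 256):
    print(f"limit={limit:>3}: {coverage(toy_lengths, limit):.1%} fit without trimming")
```

In the notebook, passing the real `all_lengths` array to the same helper would report how many examples a given limit leaves untrimmed.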

### Set up a data pipeline with `tf.data` 📊🔄

The following function takes batches of text as input, and converts them to a format suitable for training. 📝➡️🤖

1. It tokenizes them into ragged batches. ✂️
2. It trims each to be no longer than `MAX_TOKENS`. ✂️🔢
3. It splits the target (English) tokens into inputs and labels. These are shifted by one step so that at each input location the `label` is the id of the next token. 🔄
4. It converts the `RaggedTensor`s to padded dense `Tensor`s. 🧩
5. It returns an `(inputs, labels)` pair. 🎯


In [None]:
# Set the maximum number of tokens for input and output sequences.
MAX_TOKENS = 128

def prepare_batch(pt, en):
    # Tokenize the batch of Portuguese sentences.
    # Output is a RaggedTensor of token IDs.
    pt = tokenizers.pt.tokenize(pt)
    # Trim each Portuguese sequence to a maximum of MAX_TOKENS tokens.
    pt = pt[:, :MAX_TOKENS]
    # Convert the RaggedTensor to a dense Tensor, padding with zeros as needed.
    pt = pt.to_tensor()

    # Tokenize the batch of English sentences.
    en = tokenizers.en.tokenize(en)
    # Trim each English sequence to a maximum of MAX_TOKENS+1 tokens.
    # The extra token is for shifting during input/label creation.
    en = en[:, :(MAX_TOKENS + 1)]
    # Prepare the decoder input by removing the last token ([END]).
    en_inputs = en[:, :-1].to_tensor()
    # Prepare the decoder labels by removing the first token ([START]).
    en_labels = en[:, 1:].to_tensor()

    # Return a tuple: ((Portuguese tokens, English input tokens), English label tokens)
    # This matches the expected input format for training a sequence-to-sequence model.
    return (pt, en_inputs), en_labels
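
To see the trim, shift, and pad steps on concrete values, here is a plain-Python sketch that mirrors what `prepare_batch` does to the target tokens, using lists instead of `RaggedTensor`s. The helper `shift_and_pad` and the token IDs are hypothetical (2 and 3 stand in for `[START]` and `[END]`, 0 for padding).

```python
PAD = 0  # padding ID, matching the zero padding of `to_tensor()`

def shift_and_pad(token_rows, max_tokens):
    """Build padded (inputs, labels) lists from variable-length token rows."""
    trimmed = [row[:max_tokens + 1] for row in token_rows]
    inputs = [row[:-1] for row in trimmed]  # drop the last token
    labels = [row[1:] for row in trimmed]   # drop the first token
    width = max(len(row) for row in inputs)

    def pad(row):
        return row + [PAD] * (width - len(row))

    return [pad(row) for row in inputs], [pad(row) for row in labels]

# Hypothetical token IDs: 2 = [START], 3 = [END].
rows = [[2, 11, 12, 13, 3], [2, 21, 3]]
inputs, labels = shift_and_pad(rows, max_tokens=128)
print(inputs)  # [[2, 11, 12, 13], [2, 21, 0, 0]]
print(labels)  # [[11, 12, 13, 3], [21, 3, 0, 0]]
```

At every non-padding position, the label is the token that follows the corresponding input token.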

The function below converts a dataset of text examples into batches of data for training. 📝➡️📦

1. `shuffle` randomizes the order of the examples, using a buffer of `BUFFER_SIZE` elements. 🔀
2. `batch` groups the raw text into batches of `BATCH_SIZE` examples. Batching before tokenizing matters because the tokenizer is much more efficient on large batches.
3. `map` applies `prepare_batch` to each batch, which tokenizes the text, trims sequences that are too long, and pads them. ✂️
4. Finally, `prefetch` runs the dataset in parallel with the model to ensure that data is available when needed. ⚡ See [Better performance with the `tf.data` API](https://www.tensorflow.org/guide/data_performance.ipynb) for details.

In [None]:
# Set the buffer size for shuffling the dataset.
# A larger buffer size means better shuffling, but uses more memory.
BUFFER_SIZE = 20000

# Set the batch size for training.
# This determines how many examples are processed together in one training step.
BATCH_SIZE = 64

In [None]:
def make_batches(ds):
  # Shuffle the dataset with a buffer size for randomness.
  # This helps ensure that batches are not correlated and improves training.
  return (
      ds
      .shuffle(BUFFER_SIZE)  # Randomly shuffle the dataset using the specified buffer size.
      .batch(BATCH_SIZE)     # Group the dataset into batches of size BATCH_SIZE.
      # Map the prepare_batch function to each batch.
      # prepare_batch tokenizes, trims, and formats the data for training.
      .map(prepare_batch, tf.data.AUTOTUNE)
      # Prefetch allows data loading and processing to happen asynchronously,
      # so the model always has data ready for training, improving performance.
      .prefetch(buffer_size=tf.data.AUTOTUNE)
  )
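
The comment on `BUFFER_SIZE` says that a larger buffer shuffles better. The following plain-Python sketch approximates the documented behavior of `Dataset.shuffle` (fill a buffer, repeatedly emit a random buffered element and replace it with the next incoming one). It is not TensorFlow's implementation, and `buffered_shuffle` is a hypothetical name.

```python
import random

def buffered_shuffle(items, buffer_size, seed=None):
    """Yield `items` in a buffer-limited random order (buffer_size >= 1)."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)   # still filling the buffer
            continue
        idx = rng.randrange(len(buffer))
        yield buffer[idx]         # emit a random buffered element
        buffer[idx] = item        # and replace it with the new one
    rng.shuffle(buffer)           # drain what is left in random order
    yield from buffer

# With a buffer of 1 there is nothing to choose from, so order is unchanged.
print(list(buffered_shuffle(range(10), buffer_size=1)))
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# With a buffer of 4, elements can move, but only within a limited window.
print(list(buffered_shuffle(range(10), buffer_size=4, seed=0)))
```

An element can never be emitted more than `buffer_size - 1` positions early, which is why a buffer at least as large as the dataset is needed for a uniform shuffle.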

## Test the dataset 🧪✨

In [None]:
# Create training and validation set batches for the Transformer model.
#
# The make_batches function:
#   - Shuffles the dataset for randomness (improves training).
#   - Batches the data into groups of BATCH_SIZE examples.
#   - Tokenizes and formats each batch using prepare_batch (converts text to token IDs, trims, splits, pads).
#   - Prefetches batches for efficient input pipeline (overlaps data preparation and model execution).
#
# train_batches: tf.data.Dataset of (inputs, labels) pairs for training.
# val_batches: tf.data.Dataset of (inputs, labels) pairs for validation.
# Each 'inputs' is a tuple of (Portuguese tokens, English input tokens).
# Each 'labels' is the English target tokens, shifted by one position for teacher forcing.
train_batches = make_batches(train_examples)
val_batches = make_batches(val_examples)

The resulting `tf.data.Dataset` objects are set up for training with Keras.  
Keras `Model.fit` training expects `(inputs, labels)` pairs.  
The `inputs` are pairs of tokenized Portuguese and English sequences, `(pt, en)`.  
The `labels` are the same English sequences shifted by 1.  
This shift ensures that at each location in the input `en` sequence, the `label` is the next token.  

🧑‍💻➡️🤖


<table>
<tr>
  <th>Inputs at the bottom, labels at the top.</th>
</tr>
<tr>
  <td>
   <img width=400 src="https://www.tensorflow.org/images/tutorials/transformer/Transformer-1layer-words.png"/>
  </td>
</tr>
</table>

This is the same as the [text generation tutorial](text_generation.ipynb) ✍️,
except here you have additional input "context" (the Portuguese sequence) that the model is "conditioned" on. 🌐➡️🇬🇧

This setup is called "teacher forcing" 👨‍🏫 because regardless of the model's output at each timestep, it gets the true value as input for the next timestep.
This is a simple and efficient way to train a text generation model. ⚡
It's efficient because you don't need to run the model sequentially: the outputs at the different sequence locations can be computed in parallel. 🏃‍♂️🏃‍♀️

You might have expected the `input, output` pairs to simply be the `Portuguese, English` sequences. 🇵🇹➡️🇬🇧
Given the Portuguese sequence, the model would try to generate the English sequence.

It's possible to train a model that way. You'd need to write out the inference loop and pass the model's output back to the input. 🔄
It's slower (time steps can't run in parallel), and a harder task to learn (the model can't get the end of a sentence right until it gets the beginning right), 🐢
but it can give a more stable model because the model has to learn to correct its own errors during training. 🛠️
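
The difference between the two training setups can be sketched with a deliberately tiny stand-in for a model. Everything here is hypothetical: `toy_model` is a lookup table from the previous token to a predicted next token, with one planted mistake, and there is no Portuguese context at all.

```python
# Hypothetical "model": previous token -> predicted next token.
# The entry "the" -> "dog" is a deliberate mistake.
toy_model = {
    "[START]": "the", "the": "dog", "cat": "sat", "sat": "[END]",
    "dog": "ran", "ran": "[END]",
}
target = ["[START]", "the", "cat", "sat", "[END]"]
inputs, labels = target[:-1], target[1:]

# Teacher forcing: every step sees the TRUE previous token, so the steps
# are independent of each other and could be computed in parallel.
teacher_forced = [toy_model[tok] for tok in inputs]

# Free running: every step sees the model's OWN previous output, so the
# steps must run one after another and an early mistake propagates.
free_running, tok = [], "[START]"
for _ in labels:
    tok = toy_model[tok]
    free_running.append(tok)

def count_errors(predictions):
    return sum(p != l for p, l in zip(predictions, labels))

print(teacher_forced, count_errors(teacher_forced))
# ['the', 'dog', 'sat', '[END]'] 1
print(free_running, count_errors(free_running))
# ['the', 'dog', 'ran', '[END]'] 2
```

With teacher forcing the single mistake stays local. In the free-running loop the same mistake changes the following input and causes a second error, which is the situation the model would have to learn to recover from.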

In [None]:
# Iterate over one batch from the training dataset.
# train_batches yields ((pt, en), en_labels) for each batch:
#   - pt: Portuguese token IDs, shape (batch_size, seq_len)
#   - en: English input token IDs, shape (batch_size, seq_len)
#   - en_labels: English target token IDs, shape (batch_size, seq_len)
for (pt, en), en_labels in train_batches.take(1):
  # 'break' is used to exit after the first batch is retrieved.
  break

# Print the shapes of the token tensors in the batch.
# These shapes help verify the batch dimensions and sequence lengths.
print(pt.shape)        # Shape of Portuguese token batch (batch_size, seq_len)
print(en.shape)        # Shape of English input token batch (batch_size, seq_len)
print(en_labels.shape) # Shape of English label token batch (batch_size, seq_len)

The `en` and `en_labels` are the same, just shifted by 1:

In [None]:
# Print the first 10 token IDs from the first example in the English input batch.
# 'en' contains the input token IDs for the decoder (English), shape: (batch_size, seq_len).
print(en[0][:10])

# Print the first 10 token IDs from the first example in the English label batch.
# 'en_labels' contains the target token IDs for the decoder (English), shape: (batch_size, seq_len).
# These are shifted by one position compared to 'en', so each label is the next token for each input position.
print(en_labels[0][:10])

## Define the components 🧩

There's a lot going on inside a Transformer. 🤖 The important things to remember are:

1. It follows the same general pattern as a standard sequence-to-sequence model with an encoder and a decoder. 🔄
2. If you work through it step by step it will all make sense. 🪜✨

<table>
<tr>
  <th colspan=1>The original Transformer diagram</th>
  <th colspan=1>A representation of a 4-layer Transformer</th>
</tr>
<tr>
  <td>
   <img width=400 src="https://www.tensorflow.org/images/tutorials/transformer/transformer.png"/>
  </td>
  <td>
   <img width=307 src="https://www.tensorflow.org/images/tutorials/transformer/Transformer-4layer-compact.png"/>
  </td>
</tr>
</table>

Each of the components in these two diagrams will be explained as you progress through the tutorial.

### The embedding and positional encoding layer 🧩✨

The inputs to both the encoder and decoder use the same embedding and positional encoding logic. ✨🔢

<table>
<tr>
  <th colspan=1>The embedding and positional encoding layer</th>
</tr>
<tr>
  <td>
   <img src="https://www.tensorflow.org/images/tutorials/transformer/PositionalEmbedding.png"/>
  </td>
</tr>
</table>

Given a sequence of tokens, both the input tokens (Portuguese) and target tokens (English) have to be converted to vectors using a `tf.keras.layers.Embedding` layer. 🧩🔢

The attention layers used throughout the model see their input as a set of vectors, with no order. Since the model doesn't contain any recurrent or convolutional layers, it needs some way to identify word order, otherwise it would see the input sequence as a [bag of words](https://developers.google.com/machine-learning/glossary#bag-of-words) instance — `how are you`, `how you are`, `you how are`, and so on, are indistinguishable. 👜
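
A small NumPy experiment can illustrate this order-blindness. It is a sketch under simplifying assumptions: the `self_attention` helper uses identity projections instead of learned weights, and the word vectors are random stand-ins for embeddings. Reordering the words only reorders the outputs, so each word's representation is the same no matter where it appears.

```python
import numpy as np

def self_attention(x):
    """Self-attention with identity projections (a simplification)."""
    scores = x @ x.T / np.sqrt(x.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x

rng = np.random.default_rng(0)
# Random vectors standing in for learned word embeddings (an assumption).
embeddings = {word: rng.normal(size=4) for word in ("how", "are", "you")}

def encode(sentence):
    x = np.stack([embeddings[word] for word in sentence.split()])
    return self_attention(x)

a = encode("how are you")
b = encode("you how are")

# "you" is at index 2 in the first sentence and index 0 in the second,
# yet its output vector is the same in both.
print(np.allclose(a[2], b[0]))                      # True
# Any order-independent summary of the outputs is identical as well.
print(np.allclose(a.mean(axis=0), b.mean(axis=0)))  # True
```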

A Transformer adds a "Positional Encoding" to the embedding vectors. It uses a set of sines and cosines at different frequencies (across the sequence). By definition, nearby elements will have similar position encodings. 🌊🔢

The original paper uses the following formula for calculating the positional encoding: ✨

$$\Large{PE_{(pos, 2i)} = \sin(pos / 10000^{2i / d_{model}})} $$
$$\Large{PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i / d_{model}})} $$

> 📝 Note: The code below implements it, but instead of interleaving the sines and cosines, the vectors of sines and cosines are simply concatenated. Permuting the channels like this is functionally equivalent, and just a little easier to implement and show in the plots below.

In [None]:
def positional_encoding(length, depth):
  # The positional encoding uses sine and cosine functions at different frequencies.
  # The depth is split in half for sine and cosine components.
  depth = depth / 2

  # Create a column vector of positions (sequence indices), shape: (length, 1)
  positions = np.arange(length)[:, np.newaxis]

  # Create a row vector of normalized depths, shape: (1, depth)
  # This determines the frequency for each dimension.
  depths = np.arange(depth)[np.newaxis, :] / depth

  # Calculate the angle rates for each depth dimension.
  # This controls the frequency of the sine/cosine for each channel.
  angle_rates = 1 / (10000 ** depths)  # shape: (1, depth)

  # Compute the angle radians for each position and depth.
  # This is the core of the positional encoding formula.
  angle_rads = positions * angle_rates  # shape: (length, depth)

  # Concatenate the sine and cosine encodings along the last axis.
  # This doubles the depth, matching the original embedding dimension.
  pos_encoding = np.concatenate(
      [np.sin(angle_rads), np.cos(angle_rads)],
      axis=-1
  )

  # Convert the numpy array to a TensorFlow tensor of type float32.
  return tf.cast(pos_encoding, dtype=tf.float32)
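
The claim that concatenating is equivalent to interleaving can be checked numerically. The NumPy sketch below is self-contained and assumes an even `d_model`. `pe_interleaved` follows the paper's formula literally, and `pe_concatenated` mirrors the layout used by `positional_encoding` above (both helper names are hypothetical).

```python
import numpy as np

def _angles(length, d_model):
    pos = np.arange(length)[:, np.newaxis]       # (length, 1)
    i = np.arange(d_model // 2)[np.newaxis, :]   # (1, d_model/2)
    return pos / 10000 ** (2 * i / d_model)      # (length, d_model/2)

def pe_interleaved(length, d_model):
    """sin on even channels, cos on odd channels, as in the formula."""
    angles = _angles(length, d_model)
    pe = np.zeros((length, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

def pe_concatenated(length, d_model):
    """All sines first, then all cosines."""
    angles = _angles(length, d_model)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)

inter = pe_interleaved(50, 16)
concat = pe_concatenated(50, 16)
half = 16 // 2

# The two layouts contain the same channels in a different order.
print(np.allclose(concat[:, :half], inter[:, 0::2]))   # True
print(np.allclose(concat[:, half:], inter[:, 1::2]))   # True
# Dot products between positions are therefore identical.
print(np.allclose(concat @ concat.T, inter @ inter.T))  # True
```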

The position encoding function is a stack of sines and cosines that vibrate at different frequencies depending on their location along the depth of the embedding vector. They vibrate across the position axis.

In [None]:
# Generate positional encoding for a sequence of length 2048 and embedding depth 512.
# This encoding will be used to inject positional information into token embeddings.
pos_encoding = positional_encoding(length=2048, depth=512)

# Print the shape of the positional encoding tensor.
# Expected shape: (2048, 512), where 2048 is the sequence length and 512 is the embedding dimension.
print(pos_encoding.shape)

# Visualize the positional encoding matrix.
# Transpose the matrix so that each row corresponds to a depth dimension and each column to a position.
# Use a color mesh plot to show how the encoding values vary across positions and embedding dimensions.
plt.pcolormesh(pos_encoding.numpy().T, cmap='RdBu')
plt.ylabel('Depth')      # Y-axis: embedding dimension (depth)
plt.xlabel('Position')   # X-axis: sequence position
plt.colorbar()           # Add a color bar to indicate value scale
plt.show()               # Display the plot

By definition these vectors align well with nearby vectors along the position axis.  
Below, the position encoding vectors are normalized and the vector from position `1000` is compared, by dot-product, to all the others: ✨🔢

In [None]:
# Normalize each positional encoding vector to unit length along the depth axis.
# This ensures that the dot products below measure only the direction similarity, not magnitude.
pos_encoding /= tf.norm(pos_encoding, axis=1, keepdims=True)

# Select the positional encoding vector at position 1000.
p = pos_encoding[1000]

# Compute the dot product between the position-1000 vector and every other position vector.
# tf.einsum('pd,d -> p', pos_encoding, p) computes the dot product for each position.
dots = tf.einsum('pd,d -> p', pos_encoding, p)

# Plot the dot products for all positions to visualize similarity with position 1000.
plt.subplot(2,1,1)
plt.plot(dots)                # Plot similarity across all positions.
plt.ylim([0,1])               # Limit y-axis to [0, 1] for clarity.

# Draw vertical lines to highlight the zoom region (positions 950 to 1050).
plt.plot([950, 950, float('nan'), 1050, 1050],
         [0,1,float('nan'),0,1], color='k', label='Zoom')
plt.legend()

# Plot a zoomed-in view of the dot products around position 1000.
plt.subplot(2,1,2)
plt.plot(dots)
plt.xlim([950, 1050])         # Focus x-axis on positions near 1000.
plt.ylim([0,1])               # Keep y-axis limits consistent.
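
One way to see why the similarity in this plot peaks at position `1000` and depends only on the distance from it is a short derivation. The shorthand $\omega_i = 10000^{-2i / d_{model}}$ for the angle rate of channel pair $i$ is introduced here for convenience and does not appear in the original text. Using the identity $\sin a \sin b + \cos a \cos b = \cos(a - b)$:

$$PE_{pos} \cdot PE_{pos+k} = \sum_{i=0}^{d_{model}/2 - 1} \left[ \sin(\omega_i\, pos)\sin(\omega_i (pos+k)) + \cos(\omega_i\, pos)\cos(\omega_i (pos+k)) \right] = \sum_{i=0}^{d_{model}/2 - 1} \cos(\omega_i k)$$

At $k = 0$ every term equals 1, so each encoding vector has squared norm $d_{model}/2$ and the normalized dot product is $\frac{2}{d_{model}} \sum_i \cos(\omega_i k)$. The result does not depend on $pos$, only on the offset $k$, and it is largest at $k = 0$.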


Use the `positional_encoding` function to create a `PositionalEmbedding` layer that looks up a token's embedding vector and adds the position vector:

In [None]:
class PositionalEmbedding(tf.keras.layers.Layer):
  def __init__(self, vocab_size, d_model):
    super().__init__()
    self.d_model = d_model
    # Embedding layer: maps token IDs to dense vectors of size d_model.
    # mask_zero=True ensures that padding tokens (ID=0) are masked out.
    self.embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True)
    # Precompute positional encodings for sequences up to length 2048.
    # These encodings inject position information into the token embeddings.
    self.pos_encoding = positional_encoding(length=2048, depth=d_model)

  def compute_mask(self, *args, **kwargs):
    # Propagate the mask from the embedding layer.
    # This allows downstream layers to ignore padding tokens.
    return self.embedding.compute_mask(*args, **kwargs)

  def call(self, x):
    # x: input tensor of token IDs, shape (batch_size, seq_len)
    length = tf.shape(x)[1]  # Get the sequence length for this batch.
    x = self.embedding(x)    # Convert token IDs to embedding vectors.
    # Scale embeddings by sqrt(d_model) as in the original Transformer paper.
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    # Add positional encoding to each token embedding.
    # pos_encoding is sliced to match the current sequence length.
    x = x + self.pos_encoding[tf.newaxis, :length, :]
    # Return the combined embeddings (with position information).
    return x


> Note: The [original paper](https://arxiv.org/pdf/1706.03762.pdf), sections 3.4 and 5.1, uses a single tokenizer and weight matrix for both the source and target languages. This tutorial uses two separate tokenizers and weight matrices.

In [None]:
# Create positional embedding layers for Portuguese and English tokens.
# - vocab_size: Number of unique tokens in the tokenizer's vocabulary.
# - d_model: Dimensionality of the embedding vectors (512, as used in the Transformer model).
embed_pt = PositionalEmbedding(
    vocab_size=tokenizers.pt.get_vocab_size().numpy(),  # Portuguese vocab size
    d_model=512                                         # Embedding dimension
)
embed_en = PositionalEmbedding(
    vocab_size=tokenizers.en.get_vocab_size().numpy(),  # English vocab size
    d_model=512                                         # Embedding dimension
)

# Apply the positional embedding layers to the tokenized input sequences.
# - pt: Tensor of Portuguese token IDs (shape: [batch_size, seq_len])
# - en: Tensor of English token IDs (shape: [batch_size, seq_len])
# The output is a tensor of shape [batch_size, seq_len, d_model] for each language,
# where each token is represented by a vector that combines its learned embedding and positional encoding.
pt_emb = embed_pt(pt)  # Portuguese token embeddings with position encoding
en_emb = embed_en(en)  # English token embeddings with position encoding

In [None]:
# The _keras_mask attribute indicates which positions in the input are padding (masked out).
# This mask is automatically generated by the Embedding layer when mask_zero=True.
# It is a boolean tensor of shape (batch_size, sequence_length), where True means the token is not padding.
# This mask is used by subsequent layers (like attention) to ignore padding tokens during computation.
en_emb._keras_mask
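
As a rough illustration of what this mask contains, the NumPy sketch below rebuilds it by hand for a small batch. The token IDs are made up for the example, and the only assumption carried over from the layer is that ID 0 marks padding.

```python
import numpy as np

# Hypothetical batch of token IDs; 0 is the padding ID, as with mask_zero=True.
token_ids = np.array([[5, 8, 2, 0, 0],
                      [7, 3, 9, 4, 1]])

# True where a real token is present, False at padding positions.
mask = token_ids != 0
print(mask)

# Number of real (non-padding) tokens in each sequence.
lengths = mask.sum(axis=1)
print(lengths)
```

The mask has shape `(batch_size, seq_len)`, one boolean per token, which is the same layout as the `_keras_mask` tensor shown above.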

### Add and normalize

<table>
<tr>
  <th colspan=2>Add and normalize</th>
</tr>
<tr>
  <td>
   <img src="https://www.tensorflow.org/images/tutorials/transformer/Add+Norm.png"/>
  </td>
</tr>
</table>

These "Add & Norm" blocks are scattered throughout the model. Each one joins a residual connection and runs the result through a `LayerNormalization` layer. ✨➕📏

The easiest way to organize the code is around these residual blocks. The following sections will define custom layer classes for each. 🧩

The residual "Add & Norm" blocks are included so that training is efficient. The residual connection provides a direct path for the gradient (and ensures that vectors are **updated** by the attention layers instead of **replaced**), while the normalization maintains a reasonable scale for the outputs. 🚀

Note: The implementations below use the `Add` layer to ensure that Keras masks are propagated (the `+` operator does not). 🛡️
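
To make the "Add & Norm" step concrete, here is a minimal NumPy sketch of a residual add followed by normalization over the feature axis. It is not the Keras implementation: the helper name `add_and_norm` is hypothetical, the learned scale and offset of `LayerNormalization` are left out, and `eps` is an assumed small constant.

```python
import numpy as np

def add_and_norm(x, sublayer_out, eps=1e-3):
    """Residual add, then normalize each position over its feature axis.

    Sketch only: omits the learned scale (gamma) and offset (beta).
    """
    y = x + sublayer_out                       # residual connection
    mean = y.mean(axis=-1, keepdims=True)      # per-position mean
    var = y.var(axis=-1, keepdims=True)        # per-position variance
    return (y - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, 8))                    # (batch, seq_len, d_model)
sublayer_out = 10.0 * rng.normal(size=(2, 4, 8))  # deliberately large scale
out = add_and_norm(x, sublayer_out)

print(out.shape)
print(np.abs(out.mean(axis=-1)).max())  # close to 0 at every position
print(out.std(axis=-1))                 # close to 1 at every position
```

Even though the sublayer output is about ten times larger than the input, every position comes out with roughly zero mean and unit scale, which is the "reasonable scale" the normalization is there to maintain.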


### The base attention layer 🤖✨

Attention layers are used throughout the model. These are all identical except for how the attention is configured. Each one contains a `layers.MultiHeadAttention`, a `layers.LayerNormalization` and a `layers.Add`. 🤖✨🧠

<table>
<tr>
  <th colspan=2>The base attention layer</th>
</tr>
<tr>
  <td>
   <img src="https://www.tensorflow.org/images/tutorials/transformer/BaseAttention.png"/>
  </td>
</tr>
</table>

To implement these attention layers, start with a simple base class that just contains the component layers. Each use-case will be implemented as a subclass. It's a little more code to write this way, but it keeps the intention clear. 🤖✨

In [None]:
class BaseAttention(tf.keras.layers.Layer):
  def __init__(self, **kwargs):
    super().__init__()
    # MultiHeadAttention layer: computes attention over input sequences.
    # kwargs can include num_heads, key_dim, etc.
    self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
    # LayerNormalization: normalizes the output for stable training.
    self.layernorm = tf.keras.layers.LayerNormalization()
    # Add layer: adds residual connections (input + attention output).
    self.add = tf.keras.layers.Add()

#### Attention refresher 🤖✨

Before you get into the specifics of each usage, here is a quick refresher on how attention works:

<table>
<tr>
  <th colspan=1>The base attention layer</th>
</tr>
<tr>
  <td>
   <img width=430 src="https://www.tensorflow.org/images/tutorials/transformer/BaseAttention-new.png"/>
  </td>
</tr>
</table>

There are two inputs: 🤔🔍

1. The query sequence; the sequence being processed; the sequence doing the attending (bottom).
2. The context sequence; the sequence being attended to (left).

The output has the same shape as the query-sequence.

The common comparison is that this operation is like a dictionary lookup.
A **fuzzy**, **differentiable**, **vectorized** dictionary lookup.

Here's a regular Python dictionary, with 3 keys and 3 values, being passed a single query.

```
d = {'color': 'blue', 'age': 22, 'type': 'pickup'}
result = d['color']
```

- The `query` is what you're trying to find.
- The `key` is what sort of information the dictionary has.
- The `value` is that information.

When you look up a `query` in a regular dictionary, the dictionary finds the matching `key`, and returns its associated `value`.
The `query` either has a matching `key` or it doesn't.
You can imagine a **fuzzy** dictionary where the keys don't have to match perfectly.
If you looked up `d["species"]` in the dictionary above, maybe you'd want it to return `"pickup"` since that's the best match for the query.

An attention layer does a fuzzy lookup like this, but it's not just looking for the best key.
It combines the `values` based on how well the `query` matches each `key`.

How does that work? In an attention layer the `query`, `key`, and `value` are each vectors.
Instead of doing a hash lookup, the attention layer combines the `query` and `key` vectors to determine how well they match. This match strength is the "attention score".
The layer returns the average across all the `values`, weighted by the "attention scores".

Each location in the query sequence provides a `query` vector.
The context sequence acts as the dictionary. Each location in the context sequence provides a `key` and a `value` vector.
The input vectors are not used directly; the `layers.MultiHeadAttention` layer includes `layers.Dense` layers to project the input vectors before using them.
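
The fuzzy lookup described above can be sketched in a few lines of NumPy. This is a simplified, single-head version using scaled dot-product scores and a softmax, without the learned `Dense` projections that `layers.MultiHeadAttention` applies. The function names and the toy keys, values, and query are assumptions made for illustration.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract max for stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(query, key, value):
    """query: (q_len, d), key: (kv_len, d), value: (kv_len, d_v)."""
    d = query.shape[-1]
    scores = query @ key.T / np.sqrt(d)       # how well each query matches each key
    weights = softmax(scores, axis=-1)        # attention weights, one row per query
    return weights @ value, weights           # weighted average of the values

# Three keys, each paired with one value vector.
key = np.eye(3)
value = np.array([[1.0, 0.0],
                  [0.0, 1.0],
                  [5.0, 5.0]])

# A single query that matches key 0 best, key 1 a little, key 2 least.
query = np.array([[4.0, 1.0, 0.0]])

out, weights = attention(query, key, value)
print(weights)  # largest weight on key 0, but every key contributes
print(out)      # a blend of all three values
```

Unlike the dictionary lookup, no single key "wins": the result mixes all the values, with the best-matching key contributing the most.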


### The cross attention layer 🤝✨

At the literal center of the Transformer is the cross-attention layer. 🤝✨ This layer connects the encoder and decoder. It is the most straightforward use of attention in the model: it performs the same task as the attention block in the [NMT with attention tutorial](https://www.tensorflow.org/text/tutorials/nmt_with_attention).

<table>
<tr>
  <th colspan=1>The cross attention layer</th>
</tr>
<tr>
  <td>
   <img src="https://www.tensorflow.org/images/tutorials/transformer/CrossAttention.png"/>
  </td>
</tr>
</table>

To implement this you pass the target sequence `x` as the `query` and the `context` sequence as the `key/value` when calling the `mha` layer: 🤝✨

In [None]:
class CrossAttention(BaseAttention):
  def call(self, x, context):
    # Compute multi-head attention:
    # - query: the target sequence (x)
    # - key/value: the context sequence (context, typically encoder output)
    # - return_attention_scores=True: also returns the attention weights for visualization
    attn_output, attn_scores = self.mha(
        query=x,
        key=context,
        value=context,
        return_attention_scores=True)

    # Store the attention scores for later analysis or visualization
    self.last_attn_scores = attn_scores

    # Add the attention output to the original input (residual connection)
    x = self.add([x, attn_output])
    # Normalize the result for stable training
    x = self.layernorm(x)

    # Return the processed output (same shape as input x)
    return x

The caricature below shows how information flows through this layer. The columns represent the weighted sum over the context sequence. 🎨🔄

For simplicity the residual connections are not shown.

<table>
<tr>
  <th>The cross attention layer</th>
</tr>
<tr>
  <td>
   <img width=430 src="https://www.tensorflow.org/images/tutorials/transformer/CrossAttention-new-full.png"/>
  </td>
</tr>
</table>

The output length is the length of the `query` sequence, and not the length of the context `key/value` sequence. 🕵️‍♂️🔑

The diagram is further simplified, below. There's no need to draw the entire "Attention weights" matrix. 🎨
The point is that each `query` location can see all the `key/value` pairs in the context, but no information is exchanged between the queries. 👀🔄

<table>
<tr>
  <th>Each query sees the whole context.</th>
</tr>
<tr>
  <td>
   <img width=430 src="https://www.tensorflow.org/images/tutorials/transformer/CrossAttention-new.png"/>
  </td>
</tr>
</table>
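
Both properties, the output length following the queries and the queries not exchanging information, can be checked with a small NumPy sketch. It uses a single head with no learned projections, so it is a simplification of the `CrossAttention` layer, and the sequence lengths and helper name are arbitrary choices for the example.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def cross_attention(x, context):
    """x: (target_len, d) queries; context: (context_len, d) keys and values."""
    scores = x @ context.T / np.sqrt(x.shape[-1])
    return softmax(scores) @ context

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 4))        # 3 query (target) positions
context = rng.normal(size=(5, 4))  # 5 context positions

out = cross_attention(x, context)
print(out.shape)  # follows the query length, not the context length

# Perturb only the first query and recompute.
x_changed = x.copy()
x_changed[0] += 1.0
out_changed = cross_attention(x_changed, context)

# Only the first output row moves; the other queries are unaffected.
print(np.abs(out - out_changed).max(axis=-1))
```

Changing one query changes only that query's output row, since each row is computed from its own query and the shared context.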

Test run it on sample inputs:

In [None]:
# Create a CrossAttention layer instance with 2 attention heads and key dimension of 512.
# This layer will allow the decoder (English embeddings) to attend to the encoder (Portuguese embeddings).
sample_ca = CrossAttention(num_heads=2, key_dim=512)

# Print the shape of the Portuguese token embeddings with positional encoding.
# pt_emb: Tensor of shape (batch_size, pt_seq_len, d_model)
print(pt_emb.shape)

# Print the shape of the English token embeddings with positional encoding.
# en_emb: Tensor of shape (batch_size, en_seq_len, d_model)
print(en_emb.shape)

# Pass the English embeddings (as queries) and Portuguese embeddings (as context) to the CrossAttention layer.
# This computes attention from each English token to all Portuguese tokens.
# The output shape matches the input query shape: (batch_size, en_seq_len, d_model)
print(sample_ca(en_emb, pt_emb).shape)

### The global self-attention layer 🤖✨

This layer is responsible for processing the context sequence, and propagating information along its length: 🔄✨

<table>
<tr>
  <th colspan=1>The global self-attention layer</th>
</tr>
<tr>
  <td>
   <img src="https://www.tensorflow.org/images/tutorials/transformer/SelfAttention.png"/>
  </td>
</tr>
</table>

Since the context sequence is fixed while the translation is being generated, information is allowed to flow in both directions. 🔄

Before Transformers and self-attention, models commonly used RNNs or CNNs to do this task: 🤖🧠

<table>
<tr>
  <th colspan=1>Bidirectional RNNs and CNNs</th>
</tr>
<tr>
  <td>
   <img width=500 src="https://www.tensorflow.org/images/tutorials/transformer/RNN-bidirectional.png"/>
  </td>
</tr>
<tr>
  <td>
   <img width=500 src="https://www.tensorflow.org/images/tutorials/transformer/CNN.png"/>
  </td>
</tr>
</table>

RNNs and CNNs have their limitations. 🤔

- The RNN allows information to flow all the way across the sequence, but it passes through many processing steps to get there (limiting gradient flow). These RNN steps have to be run sequentially and so the RNN is less able to take advantage of modern parallel devices. 🐢
- In the CNN each location can be processed in parallel, but it only provides a limited receptive field. The receptive field only grows linearly with the number of CNN layers, so you need to stack many convolution layers to transmit information across the sequence ([Wavenet](https://arxiv.org/abs/1609.03499) reduces this problem by using dilated convolutions). 🏗️
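
As a back-of-the-envelope illustration of the receptive-field point, the sketch below counts how many input positions one output position can see after a stack of stride-1 1D convolutions. The kernel size of 3, the layer count, and the doubling dilation schedule are illustrative assumptions, not values taken from this tutorial.

```python
def receptive_field(kernel_size, dilations):
    """Receptive field of stacked stride-1 1D convolutions.

    Each layer widens the field by (kernel_size - 1) * dilation positions.
    """
    field = 1
    for dilation in dilations:
        field += (kernel_size - 1) * dilation
    return field

num_layers = 8

# Ordinary convolutions: every layer has dilation 1, so growth is linear.
plain = receptive_field(3, [1] * num_layers)

# Dilated convolutions with the dilation doubling at each layer.
dilated = receptive_field(3, [2 ** i for i in range(num_layers)])

print(plain, dilated)  # 17 511
```

With eight plain layers each output still sees only 17 input positions, while the dilated stack reaches 511. A self-attention layer reaches every position in a single layer.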

The global self-attention layer on the other hand lets every sequence element directly access every other sequence element, with only a few operations, and all the outputs can be computed in parallel. ⚡🤖

To implement this layer you just need to pass the input sequence, `x`, as the `query`, `key`, and `value` arguments to the `mha` layer:

In [None]:
class GlobalSelfAttention(BaseAttention):
  def call(self, x):
    # Compute multi-head self-attention:
    # - query: the input sequence (x)
    # - key: the input sequence (x)
    # - value: the input sequence (x)
    # This allows every position in the sequence to attend to every other position.
    attn_output = self.mha(
        query=x,
        value=x,
        key=x)
    # Add the attention output to the original input (residual connection).
    x = self.add([x, attn_output])
    # Normalize the result for stable training.
    x = self.layernorm(x)
    # Return the processed output (same shape as input x).
    return x

In [None]:
# Create an instance of the GlobalSelfAttention layer.
# - num_heads=2: The attention mechanism will use 2 separate attention heads.
# - key_dim=512: Each attention head projects the input to a 512-dimensional space.
sample_gsa = GlobalSelfAttention(num_heads=2, key_dim=512)

# Print the shape of the Portuguese token embeddings with positional encoding.
# pt_emb: Tensor of shape (batch_size, sequence_length, embedding_dim)
print(pt_emb.shape)

# Pass the Portuguese embeddings through the GlobalSelfAttention layer.
# This allows each token in the sequence to attend to every other token in the same sequence.
# The output shape matches the input: (batch_size, sequence_length, embedding_dim)
print(sample_gsa(pt_emb).shape)

Sticking with the same style as before you could draw it like this: 🎨✨

<table>
<tr>
  <th colspan=1>The global self-attention layer</th>
</tr>
<tr>
  <td>
   <img width=330 src="https://www.tensorflow.org/images/tutorials/transformer/SelfAttention-new-full.png"/>
  </td>
</tr>
</table>

Again, the residual connections are omitted for clarity. 😊

It's more compact, and just as accurate to draw it like this: 🎨

<table>
<tr>
  <th colspan=1>The global self-attention layer</th>
</tr>
<tr>
  <td>
   <img width=500 src="https://www.tensorflow.org/images/tutorials/transformer/SelfAttention-new.png"/>
  </td>
</tr>
</table>

### The causal self-attention layer ⏩🤖✨

This layer does a job similar to that of the global self-attention layer, but for the output sequence: ⏩🤖✨

<table>
<tr>
  <th colspan=1>The causal self-attention layer</th>
</tr>
<tr>
  <td>
   <img src="https://www.tensorflow.org/images/tutorials/transformer/CausalSelfAttention.png"/>
  </td>
</tr>
</table>

This needs to be handled differently from the encoder's global self-attention layer.  

Like the models in the [text generation tutorial](https://www.tensorflow.org/text/tutorials/text_generation) and the [NMT with attention](https://www.tensorflow.org/text/tutorials/nmt_with_attention) tutorial, Transformers are "autoregressive" models: they generate the text one token at a time and feed that output back to the input. To make this _efficient_, these models ensure that the output for each sequence element only depends on the previous sequence elements; the models are "causal". ⏩🤖✨

A single-direction RNN is causal by definition. 🕒 To make a causal convolution you just need to pad the input and shift the output so that it aligns correctly (use `layers.Conv1D(padding='causal')`). 🧩➡️

<table>
<tr>
  <th colspan=1>Causal RNNs and CNNs</th>
</tr>
<tr>
  <td>
   <img width=500 src="https://www.tensorflow.org/images/tutorials/transformer/RNN.png"/>
  </td>
</tr>
<tr>
  <td>
   <img width=500 src="https://www.tensorflow.org/images/tutorials/transformer/CNN-causal.png"/>
  </td>
</tr>
</table>
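
The "pad the input" idea behind a causal convolution can be sketched directly in NumPy. This is a hand-rolled single-channel version for illustration, not what `layers.Conv1D(padding='causal')` does internally, and the helper name, kernel, and input values are arbitrary.

```python
import numpy as np

def causal_conv1d(x, kernel):
    """x: (seq_len,), kernel: (k,). The output at step t uses x[t-k+1 .. t] only."""
    k = len(kernel)
    # Left-pad with k - 1 zeros so the window ending at step t never looks ahead.
    padded = np.concatenate([np.zeros(k - 1), x])
    return np.array([padded[t:t + k] @ kernel for t in range(len(x))])

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
kernel = np.array([0.5, 0.25, 0.25])
y = causal_conv1d(x, kernel)

# Overwrite the inputs from step 3 onward and recompute.
x_changed = x.copy()
x_changed[3:] = 100.0
y_changed = causal_conv1d(x_changed, kernel)

print(y[:3], y_changed[:3])  # identical: early outputs ignore later inputs
print(y[3:], y_changed[3:])  # different: these steps can see the changed inputs
```

Outputs before step 3 are unchanged because none of their windows reach the modified inputs.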

A causal model is efficient in two ways: ⚡⏩

1. In training, it lets you compute loss for every location in the output sequence while executing the model just once. 🏋️‍♂️
2. During inference, for each new token generated you only need to calculate its outputs; the outputs for the previous sequence elements can be reused. 🔄
   - For an RNN you just need the RNN state to account for previous computations (pass `return_state=True` to the RNN layer's constructor). 🧠
   - For a CNN you would need to follow the approach of [Fast Wavenet](https://arxiv.org/abs/1611.09482). 🚀
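
For attention, the same reuse is possible by keeping the earlier positions around and computing only the newest output at each step. The NumPy sketch below compares that step-by-step approach with one full causal pass. It is a single-head toy without learned projections, `decode_step` and the cache are hypothetical helpers, and it is not how the model in this notebook runs inference.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def causal_attention(x):
    """Full causal self-attention over x: (seq_len, d), computed in one pass."""
    seq_len, d = x.shape
    scores = x @ x.T / np.sqrt(d)
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)  # block attention to later positions
    return softmax(scores) @ x

def decode_step(new_x, cached_x):
    """Compute the output for one new position, reusing the cached positions."""
    keys = np.vstack([cached_x, new_x[np.newaxis, :]])
    scores = keys @ new_x / np.sqrt(new_x.shape[-1])
    return softmax(scores) @ keys, keys

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 4))

full = causal_attention(x)

cache = np.empty((0, 4))
steps = []
for t in range(len(x)):
    out_t, cache = decode_step(x[t], cache)
    steps.append(out_t)
incremental = np.stack(steps)

print(np.abs(full - incremental).max())  # agreement up to floating-point error
```

Because of causality, appending a new token never changes the outputs already computed, so the step-by-step results match the full pass.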

To build a causal self-attention layer, you need to use an appropriate mask when computing the attention scores and summing the attention `value`s. 🛡️✨

This is taken care of automatically if you pass `use_causal_mask=True` to the `MultiHeadAttention` layer when you call it: 🤖⏩

In [None]:
class CausalSelfAttention(BaseAttention):
  def call(self, x):
    # Compute multi-head self-attention with a causal mask:
    # - query: the input sequence (x)
    # - key: the input sequence (x)
    # - value: the input sequence (x)
    # - use_causal_mask=True: ensures that each position can only attend to itself and earlier positions (not future ones)
    attn_output = self.mha(
        query=x,
        value=x,
        key=x,
        use_causal_mask=True)
    # Add the attention output to the original input (residual connection)
    x = self.add([x, attn_output])
    # Normalize the result for stable training
    x = self.layernorm(x)
    # Return the processed output (same shape as input x)
    return x

The causal mask ensures that each location only has access to itself and the locations that come before it:

<table>
<tr>
  <th colspan=1>The causal self-attention layer</th>
</tr>
<tr>
  <td>
   <img width=330 src="https://www.tensorflow.org/images/tutorials/transformer/CausalSelfAttention-new-full.png"/>
  </td>
</tr>
</table>

Again, the residual connections are omitted for simplicity.

The more compact representation of this layer would be:

<table>
<tr>
  <th colspan=1>The causal self-attention layer</th>
</tr>
<tr>
  <td>
   <img width=430 src="https://www.tensorflow.org/images/tutorials/transformer/CausalSelfAttention-new.png"/>
  </td>
</tr>
</table>
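
For reference, a causal mask of this kind is just a lower-triangular boolean matrix. The NumPy sketch below builds one and applies it to a set of uniform scores. The sequence length and the uniform scores are chosen only to make the pattern easy to read.

```python
import numpy as np

seq_len = 4

# True where attention is allowed: position i may attend to positions j <= i.
causal_mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))
print(causal_mask.astype(int))

# Uniform scores, so that the effect of the mask alone is visible.
scores = np.zeros((seq_len, seq_len))

# Masked-out entries get -inf, which becomes a weight of exactly 0 after softmax.
masked = np.where(causal_mask, scores, -np.inf)
weights = np.exp(masked)
weights /= weights.sum(axis=-1, keepdims=True)
print(weights)
```

Row `i` of `weights` spreads its attention evenly over positions `0..i` and puts zero weight on everything after it.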

Test out the layer:

In [None]:
# Create an instance of the CausalSelfAttention layer.
# - num_heads=2: The attention mechanism will use 2 separate attention heads.
# - key_dim=512: Each attention head projects the input to a 512-dimensional space.
sample_csa = CausalSelfAttention(num_heads=2, key_dim=512)

# Print the shape of the English token embeddings with positional encoding.
# en_emb: Tensor of shape (batch_size, sequence_length, embedding_dim)
print(en_emb.shape)

# Pass the English embeddings through the CausalSelfAttention layer.
# This allows each token in the sequence to attend only to previous tokens (not future ones),
# enforcing causality for autoregressive decoding.
# The output shape matches the input: (batch_size, sequence_length, embedding_dim)
print(sample_csa(en_emb).shape)

The output for early sequence elements doesn't depend on later elements, so it shouldn't matter if you trim elements before or after applying the layer:

In [None]:
# Compute the output of the CausalSelfAttention layer on the first 3 tokens of each sequence.
# - en[:, :3]: Selects the first 3 tokens from each sequence in the English input batch.
# - embed_en(en[:, :3]): Applies the positional embedding to these tokens.
# - sample_csa(...): Passes the embedded tokens through the causal self-attention layer.
out1 = sample_csa(embed_en(en[:, :3]))

# Compute the output of the CausalSelfAttention layer on the full sequence,
# then select only the first 3 tokens from the output.
# - embed_en(en): Applies positional embedding to the full English input batch.
# - sample_csa(...): Passes the full embedded sequence through the causal self-attention layer.
# - ...[:, :3]: Selects the first 3 tokens from the output sequence.
out2 = sample_csa(embed_en(en))[:, :3]

# Calculate the maximum absolute difference between the two outputs.
# - abs(out1 - out2): Element-wise absolute difference between the two tensors.
# - tf.reduce_max(...): Finds the maximum value in the difference tensor.
# - .numpy(): Converts the result to a NumPy scalar for display.
tf.reduce_max(abs(out1 - out2)).numpy()

Note: When using Keras masks, the output values at invalid locations are not well defined. So the above may not hold for masked regions. ⚠️🤖

### The feed forward network ⚡🧠

The Transformer also includes this point-wise feed-forward network in both the encoder and decoder: ⚡🧠

<table>
<tr>
  <th colspan=1>The feed forward network</th>
</tr>
<tr>
  <td>
   <img src="https://www.tensorflow.org/images/tutorials/transformer/FeedForward.png"/>
  </td>
</tr>
</table>

The network consists of two linear layers (`tf.keras.layers.Dense`) with a ReLU activation in-between, and a dropout layer. As with the attention layers the code here also includes the residual connection and normalization: ⚡🧠➕📏

In [None]:
class FeedForward(tf.keras.layers.Layer):
  def __init__(self, d_model, dff, dropout_rate=0.1):
    super().__init__()
    # Sequential block: two dense layers with ReLU and dropout in between.
    # - First Dense: expands to dff units (hidden size), ReLU activation.
    # - Second Dense: projects back to d_model (original embedding size).
    # - Dropout: regularizes the output to prevent overfitting.
    self.seq = tf.keras.Sequential([
      tf.keras.layers.Dense(dff, activation='relu'),   # Hidden layer with ReLU
      tf.keras.layers.Dense(d_model),                  # Output layer, restores original size
      tf.keras.layers.Dropout(dropout_rate)            # Dropout for regularization
    ])
    # Add layer for residual connection (input + output of FFN)
    self.add = tf.keras.layers.Add()
    # LayerNormalization to stabilize training and scale outputs
    self.layer_norm = tf.keras.layers.LayerNormalization()

  def call(self, x):
    # Apply the feed-forward network to the input
    # Add the input (residual connection) to the output of the FFN
    x = self.add([x, self.seq(x)])
    # Normalize the result for stable training
    x = self.layer_norm(x)
    # Return the processed tensor
    return x


Test the layer. The output has the same shape as the input:

In [None]:
# Create an instance of the FeedForward layer.
# - d_model=512: The input and output dimensionality of the layer (matches embedding size).
# - dff=2048: The hidden layer size inside the feed-forward network.
sample_ffn = FeedForward(512, 2048)

# Print the shape of the English token embeddings with positional encoding.
# en_emb: Tensor of shape (batch_size, sequence_length, embedding_dim)
print(en_emb.shape)

# Pass the English embeddings through the FeedForward layer.
# The output shape matches the input: (batch_size, sequence_length, embedding_dim)
# This layer applies two dense layers with ReLU and dropout, then adds a residual connection and normalizes the result.
print(sample_ffn(en_emb).shape)
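
"Point-wise" here means the same two dense layers are applied to every position independently. The NumPy sketch below checks this with randomly initialized weights. It leaves out the dropout, residual connection, and normalization from the `FeedForward` layer, and the sizes are arbitrary small values.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, dff, seq_len = 4, 16, 5

# Randomly initialized weights standing in for the two Dense layers.
w1, b1 = rng.normal(size=(d_model, dff)), np.zeros(dff)
w2, b2 = rng.normal(size=(dff, d_model)), np.zeros(d_model)

def ffn(x):
    hidden = np.maximum(x @ w1 + b1, 0.0)  # first dense layer + ReLU
    return hidden @ w2 + b2                # second dense layer, back to d_model

x = rng.normal(size=(seq_len, d_model))

whole = ffn(x)                                              # whole sequence at once
per_position = np.stack([ffn(x[t]) for t in range(seq_len)])  # one position at a time

print(whole.shape)
print(np.abs(whole - per_position).max())
```

The two results agree, so unlike the attention layers, this network never mixes information between sequence positions.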

### The encoder layer 🧩✨

The encoder contains a stack of `N` encoder layers, where each `EncoderLayer` contains a `GlobalSelfAttention` and a `FeedForward` layer: 🧩✨

<table>
<tr>
  <th colspan=1>The encoder layer</th>
</tr>
<tr>
  <td>
   <img src="https://www.tensorflow.org/images/tutorials/transformer/EncoderLayer.png"/>
  </td>
</tr>
</table>

Here is the definition of the `EncoderLayer`:

In [None]:
class EncoderLayer(tf.keras.layers.Layer):
  def __init__(self, *, d_model, num_heads, dff, dropout_rate=0.1):
    super().__init__()
    # GlobalSelfAttention: allows each token in the input sequence to attend to every other token.
    # - num_heads: Number of attention heads.
    # - key_dim: Size of each attention head's query/key projection (set to d_model here; the original paper uses d_model / num_heads).
    # - dropout: Dropout rate for regularization.
    self.self_attention = GlobalSelfAttention(
        num_heads=num_heads,
        key_dim=d_model,
        dropout=dropout_rate)

    # FeedForward: point-wise feed-forward network applied to each position.
    # - d_model: Input/output dimensionality (matches embedding size).
    # - dff: Hidden layer size inside the feed-forward network.
    self.ffn = FeedForward(d_model, dff, dropout_rate=dropout_rate)

  def call(self, x):
    # Apply global self-attention to the input sequence.
    # This propagates information along the sequence.
    x = self.self_attention(x)
    # Apply the feed-forward network to each position in the sequence.
    x = self.ffn(x)
    # Return the processed sequence (same shape as input).
    return x

And a quick test. The output will have the same shape as the input:

In [None]:
# Create an instance of the EncoderLayer.
# - d_model=512: The dimensionality of the input and output vectors (embedding size).
# - num_heads=8: Number of attention heads in the multi-head attention mechanism.
# - dff=2048: The hidden layer size inside the feed-forward network.
sample_encoder_layer = EncoderLayer(d_model=512, num_heads=8, dff=2048)

# Print the shape of the Portuguese token embeddings with positional encoding.
# pt_emb: Tensor of shape (batch_size, sequence_length, embedding_dim)
print(pt_emb.shape)

# Pass the Portuguese embeddings through the EncoderLayer.
# The EncoderLayer applies:
#   1. Global self-attention: Each token attends to every other token in the sequence.
#   2. Feed-forward network: Point-wise transformation of each token embedding.
# The output shape matches the input: (batch_size, sequence_length, embedding_dim)
print(sample_encoder_layer(pt_emb).shape)

### The encoder 🧩✨

Next build the encoder. 🏗️✨

<table>
<tr>
  <th colspan=1>The encoder</th>
</tr>
<tr>
  <td>
   <img src="https://www.tensorflow.org/images/tutorials/transformer/Encoder.png"/>
  </td>
</tr>
</table>

The encoder consists of: 🏗️✨

- A `PositionalEmbedding` layer at the input. 🧩🔢
- A stack of `EncoderLayer` layers. 🧱🧱🧱

In [None]:
class Encoder(tf.keras.layers.Layer):
  def __init__(self, *, num_layers, d_model, num_heads,
               dff, vocab_size, dropout_rate=0.1):
    super().__init__()

    # Store the model dimension and number of layers for reference
    self.d_model = d_model
    self.num_layers = num_layers

    # PositionalEmbedding: Converts token IDs to embeddings and adds positional encoding
    # - vocab_size: Number of unique tokens in the vocabulary
    # - d_model: Dimensionality of the embedding vectors
    self.pos_embedding = PositionalEmbedding(
        vocab_size=vocab_size, d_model=d_model)

    # Stack of EncoderLayer instances
    # Each EncoderLayer contains:
    #   - GlobalSelfAttention: lets each token attend to every other token in the sequence
    #   - FeedForward: point-wise feed-forward network
    # - num_layers: Number of layers to stack
    # - d_model: Embedding dimension
    # - num_heads: Number of attention heads
    # - dff: Hidden layer size in the feed-forward network
    # - dropout_rate: Dropout rate for regularization
    self.enc_layers = [
        EncoderLayer(d_model=d_model,
                     num_heads=num_heads,
                     dff=dff,
                     dropout_rate=dropout_rate)
        for _ in range(num_layers)]

    # Dropout layer applied after positional embedding for regularization
    self.dropout = tf.keras.layers.Dropout(dropout_rate)

  def call(self, x):
    # Input x: token IDs of shape (batch_size, seq_len)
    # Convert token IDs to embeddings and add positional encoding
    x = self.pos_embedding(x)  # Shape: (batch_size, seq_len, d_model)

    # Apply dropout to the embeddings
    x = self.dropout(x)

    # Pass the embeddings through each EncoderLayer in the stack
    for i in range(self.num_layers):
      x = self.enc_layers[i](x)

    # Output: processed sequence of shape (batch_size, seq_len, d_model)
    return x

Test the encoder:

In [None]:
# Instantiate the encoder with the specified configuration:
# - num_layers=4: The encoder will have 4 stacked EncoderLayer blocks.
# - d_model=512: Each token embedding and hidden state will have 512 dimensions.
# - num_heads=8: Multi-head attention will use 8 parallel attention heads.
# - dff=2048: The feed-forward network inside each EncoderLayer will expand to 2048 units before projecting back to d_model.
# - vocab_size=8500: The vocabulary size for the input tokens (number of unique tokens).
sample_encoder = Encoder(
    num_layers=4,      # Number of encoder layers to stack
    d_model=512,       # Dimensionality of embeddings and hidden states
    num_heads=8,       # Number of attention heads in MultiHeadAttention
    dff=2048,          # Hidden layer size in the feed-forward network
    vocab_size=8500    # Size of the input vocabulary
)

# Pass the Portuguese token IDs (pt) through the encoder.
# - pt: Tensor of shape (batch_size, input_seq_len), containing token IDs for each input sentence.
# - training=False: Indicates that the encoder is running in inference mode (no dropout applied).
# The encoder will:
#   1. Embed the input tokens and add positional encoding.
#   2. Apply dropout (skipped if training=False).
#   3. Pass the embeddings through 4 stacked EncoderLayer blocks.
#   4. Each EncoderLayer applies global self-attention and a feed-forward network.
# The output is a tensor of shape (batch_size, input_seq_len, d_model), representing the encoded input sequence.
sample_encoder_output = sample_encoder(pt, training=False)

# Print the shape of the input token IDs tensor.
print(pt.shape)  # Shape: (batch_size, input_seq_len)

# Print the shape of the encoder output tensor.
# The output shape should be (batch_size, input_seq_len, d_model), where:
#   - batch_size: Number of input sequences in the batch
#   - input_seq_len: Length of each input sequence (number of tokens)
#   - d_model: Dimensionality of the encoder output vectors (512)
print(sample_encoder_output.shape)  # Shape `(batch_size, input_seq_len, d_model)`.

### The decoder layer 🧩✨

The decoder's stack is slightly more complex, with each `DecoderLayer` containing a `CausalSelfAttention` ⏩🤖, a `CrossAttention` 🤝✨, and a `FeedForward` ⚡🧠 layer:

<table>
<tr>
  <th colspan=1>The decoder layer</th>
</tr>
<tr>
  <td>
   <img src="https://www.tensorflow.org/images/tutorials/transformer/DecoderLayer.png"/>
  </td>
</tr>
</table>

In [None]:
class DecoderLayer(tf.keras.layers.Layer):
  def __init__(self,
               *,
               d_model,
               num_heads,
               dff,
               dropout_rate=0.1):
    super().__init__()

    # CausalSelfAttention: allows each token in the output sequence to attend only to previous tokens (not future ones).
    # - num_heads: Number of attention heads.
    # - key_dim: Size of each attention head's query/key projection (set to d_model here; the original paper uses d_model / num_heads).
    # - dropout: Dropout rate for regularization.
    self.causal_self_attention = CausalSelfAttention(
        num_heads=num_heads,
        key_dim=d_model,
        dropout=dropout_rate)

    # CrossAttention: allows each token in the output sequence to attend to all tokens in the encoder output (context).
    # - num_heads: Number of attention heads.
    # - key_dim: Size of each attention head's query/key projection (set to d_model here; the original paper uses d_model / num_heads).
    # - dropout: Dropout rate for regularization.
    self.cross_attention = CrossAttention(
        num_heads=num_heads,
        key_dim=d_model,
        dropout=dropout_rate)

    # FeedForward: point-wise feed-forward network applied to each position in the sequence.
    # - d_model: Input/output dimensionality (matches embedding size).
    # - dff: Hidden layer size inside the feed-forward network.
    self.ffn = FeedForward(d_model, dff, dropout_rate=dropout_rate)

  def call(self, x, context):
    # Apply causal self-attention to the input sequence (decoder input).
    # This enforces autoregressive decoding by masking future tokens.
    x = self.causal_self_attention(x=x)

    # Apply cross-attention, allowing the decoder to attend to the encoder output (context).
    x = self.cross_attention(x=x, context=context)

    # Cache the last attention scores from the cross-attention layer for visualization or analysis.
    self.last_attn_scores = self.cross_attention.last_attn_scores

    # Apply the feed-forward network to each position in the sequence.
    # The output shape matches the input: (batch_size, seq_len, d_model).
    x = self.ffn(x)
    return x

Test the decoder layer:

In [None]:
# Instantiate a DecoderLayer with the specified configuration:
# - d_model=512: Dimensionality of the input and output vectors (embedding size).
# - num_heads=8: Number of attention heads in the multi-head attention mechanism.
# - dff=2048: Hidden layer size inside the feed-forward network.
sample_decoder_layer = DecoderLayer(d_model=512, num_heads=8, dff=2048)

# Pass the English token embeddings (en_emb) and Portuguese token embeddings (pt_emb) to the decoder layer.
# - x=en_emb: The input to the decoder (target sequence embeddings with positional encoding).
# - context=pt_emb: Stand-in for the encoder output (here, the source sequence embeddings with positional encoding).
# The DecoderLayer applies:
#   1. Causal self-attention: Each token in the target sequence attends only to previous tokens (autoregressive).
#   2. Cross-attention: Each token in the target sequence attends to all tokens in the source sequence (encoder output).
#   3. Feed-forward network: Point-wise transformation of each token embedding.
# The output shape matches the input: (batch_size, seq_len, d_model)
sample_decoder_layer_output = sample_decoder_layer(
    x=en_emb, context=pt_emb)

# Print the shape of the English token embeddings (input to the decoder).
print(en_emb.shape)  # Shape: (batch_size, target_seq_len, d_model)

# Print the shape of the Portuguese token embeddings (encoder output/context).
print(pt_emb.shape)  # Shape: (batch_size, source_seq_len, d_model)

# Print the shape of the decoder layer output.
# The output shape should be (batch_size, target_seq_len, d_model), where:
#   - batch_size: Number of sequences in the batch
#   - target_seq_len: Length of each target sequence (number of tokens)
#   - d_model: Dimensionality of the decoder output vectors (512)
print(sample_decoder_layer_output.shape)  # `(batch_size, seq_len, d_model)`
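
To make the causal masking step more concrete, here is a minimal pure-Python sketch of how future positions can be excluded before the softmax. It is not the `CausalSelfAttention` layer itself: the helper names (`softmax`, `causal_attention_weights`) and the toy scores are assumptions made for illustration only.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def causal_attention_weights(scores):
    """Mask future positions, then normalize each row.

    scores[i][j] is the raw score of query position i against key position j.
    """
    weights = []
    for i, row in enumerate(scores):
        # Positions j > i lie in the future; -inf becomes weight 0 after softmax.
        masked = [v if j <= i else float("-inf") for j, v in enumerate(row)]
        weights.append(softmax(masked))
    return weights


# Hypothetical raw scores for a 3-token target sequence (all equal for clarity).
toy_scores = [[0.0, 0.0, 0.0] for _ in range(3)]
weights = causal_attention_weights(toy_scores)
for row in weights:
    print([round(w, 3) for w in row])
```

With equal scores, the first token can only attend to itself, the second splits its attention over two positions, and the third over all three. No weight ever lands on a later position.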

### The decoder 🧩✨

Similar to the `Encoder`, the `Decoder` consists of a `PositionalEmbedding`, and a stack of `DecoderLayer`s:

<table>
<tr>
  <th colspan=1>The decoder</th>
</tr>
<tr>
  <td>
   <img src="https://www.tensorflow.org/images/tutorials/transformer/Decoder.png"/>
  </td>
</tr>
</table>


Define the decoder by extending `tf.keras.layers.Layer`:

In [None]:
class Decoder(tf.keras.layers.Layer):
  def __init__(self, *, num_layers, d_model, num_heads, dff, vocab_size,
               dropout_rate=0.1):
    super(Decoder, self).__init__()

    # Store model hyperparameters for reference
    self.d_model = d_model
    self.num_layers = num_layers

    # PositionalEmbedding: Converts token IDs to embeddings and adds positional encoding
    # - vocab_size: Number of unique tokens in the vocabulary
    # - d_model: Dimensionality of the embedding vectors
    self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size,
                                             d_model=d_model)
    # Dropout layer for regularization after embedding
    self.dropout = tf.keras.layers.Dropout(dropout_rate)
    # Stack of DecoderLayer instances
    # Each DecoderLayer contains:
    #   - CausalSelfAttention: lets each token attend only to previous tokens in the sequence
    #   - CrossAttention: lets each token attend to all tokens in the encoder output
    #   - FeedForward: point-wise feed-forward network
    self.dec_layers = [
        DecoderLayer(d_model=d_model, num_heads=num_heads,
                     dff=dff, dropout_rate=dropout_rate)
        for _ in range(num_layers)]

    # Will store the attention scores from the last DecoderLayer for visualization/analysis
    self.last_attn_scores = None

  def call(self, x, context):
    # x: token IDs of shape (batch_size, target_seq_len)
    # context: encoder output, shape (batch_size, input_seq_len, d_model)

    # Convert token IDs to embeddings and add positional encoding
    x = self.pos_embedding(x)  # (batch_size, target_seq_len, d_model)

    # Apply dropout to the embeddings
    x = self.dropout(x)

    # Pass the embeddings through each DecoderLayer in the stack
    # Each layer applies causal self-attention, cross-attention, and feed-forward network
    for i in range(self.num_layers):
      x = self.dec_layers[i](x, context)

    # Cache the last attention scores from the final DecoderLayer for visualization/analysis
    self.last_attn_scores = self.dec_layers[-1].last_attn_scores

    # Output: processed sequence of shape (batch_size, target_seq_len, d_model)
    return x

Test the decoder:

In [None]:
# Instantiate the decoder with the specified configuration:
# - num_layers=4: The decoder will have 4 stacked DecoderLayer blocks.
# - d_model=512: Each token embedding and hidden state will have 512 dimensions.
# - num_heads=8: Multi-head attention will use 8 parallel attention heads.
# - dff=2048: The feed-forward network inside each DecoderLayer will expand to 2048 units before projecting back to d_model.
# - vocab_size=8000: The vocabulary size for the target language tokens.
sample_decoder = Decoder(
    num_layers=4,      # Number of decoder layers to stack
    d_model=512,       # Dimensionality of embeddings and hidden states
    num_heads=8,       # Number of attention heads in MultiHeadAttention
    dff=2048,          # Hidden layer size in the feed-forward network
    vocab_size=8000    # Size of the target vocabulary
)

# Pass the English token IDs (en) and Portuguese embeddings (pt_emb) to the decoder.
# - x=en: Tensor of shape (batch_size, target_seq_len), containing token IDs for each target sentence.
# - context=pt_emb: Tensor of shape (batch_size, source_seq_len, d_model), containing encoder output embeddings.
# The decoder will:
#   1. Embed the target tokens and add positional encoding.
#   2. Apply dropout (if training).
#   3. Pass the embeddings through 4 stacked DecoderLayer blocks.
#   4. Each DecoderLayer applies causal self-attention, cross-attention, and a feed-forward network.
# The output is a tensor of shape (batch_size, target_seq_len, d_model), representing the decoded sequence.
output = sample_decoder(
    x=en,         # Target token IDs
    context=pt_emb  # Encoder output embeddings
)

# Print the shapes of the input and output tensors for verification.
print(en.shape)        # Shape: (batch_size, target_seq_len) - input token IDs
print(pt_emb.shape)    # Shape: (batch_size, source_seq_len, d_model) - encoder output
print(output.shape)    # Shape: (batch_size, target_seq_len, d_model) - decoder output

In [None]:
# The attention scores from the last DecoderLayer in the sample_decoder.
# This tensor contains the cross-attention weights, showing how much each target token (in the output sequence)
# attends to each input token (in the encoder output) for each attention head and batch.
# The shape of the tensor is (batch_size, num_heads, target_seq_len, input_seq_len):
#   - batch_size: Number of sequences in the batch.
#   - num_heads: Number of attention heads in the decoder.
#   - target_seq_len: Length of the output (target) sequence.
#   - input_seq_len: Length of the input (source) sequence.
sample_decoder.last_attn_scores.shape  # (batch, heads, target_seq, input_seq)
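
One way to read a tensor with this layout is to average over the heads and then, for each target position, look up the source position with the largest weight. The sketch below does this with nested Python lists for a single batch element. The function names and the toy weights are hypothetical, and it assumes each row along the last axis is a softmax output that sums to 1.

```python
def average_heads(attn):
    """Average attention weights over heads.

    attn: nested list with shape (num_heads, target_seq, input_seq).
    Returns a nested list with shape (target_seq, input_seq).
    """
    num_heads = len(attn)
    target_len = len(attn[0])
    input_len = len(attn[0][0])
    return [[sum(attn[h][t][s] for h in range(num_heads)) / num_heads
             for s in range(input_len)]
            for t in range(target_len)]


def most_attended(avg):
    """For each target position, the index of the most attended source position."""
    return [max(range(len(row)), key=row.__getitem__) for row in avg]


# Hypothetical weights: 2 heads, 2 target tokens, 3 source tokens.
toy_attn = [
    [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]],  # head 0
    [[0.5, 0.3, 0.2], [0.2, 0.2, 0.6]],  # head 1
]
avg = average_heads(toy_attn)
print(most_attended(avg))
```

Averaging is only one possible summary. Individual heads can attend to quite different source tokens, as the second target token in the toy data shows.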

Having created the Transformer encoder and decoder, it's time to build the Transformer model and train it. 🚀🤖✨

## The Transformer 🚀🤖✨

You now have the `Encoder` 🧩 and the `Decoder` 🧩. To complete the `Transformer` model 🚀🤖✨, put them together and add a final linear (`Dense`) layer 🧠 that converts the resulting vector at each location into one logit per output token 🔢. Applying a softmax to these logits yields the output token probabilities.

The output of the decoder is the input to this final linear layer.
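
As a small illustration of how the per-position scores relate to token probabilities, the sketch below applies a softmax to a made-up score vector. The 4-token vocabulary and the values in `toy_logits` are assumptions chosen for readability, not outputs of the model.

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


toy_logits = [2.0, 0.5, -1.0, 0.0]  # hypothetical scores for a 4-token vocabulary
probs = softmax(toy_logits)

# Softmax preserves the ordering of the scores, so the most likely token
# can be read from either the raw scores or the probabilities.
best_from_logits = max(range(len(toy_logits)), key=toy_logits.__getitem__)
best_from_probs = max(range(len(probs)), key=probs.__getitem__)
print([round(p, 3) for p in probs], best_from_logits, best_from_probs)
```

Because the ordering is preserved, greedy decoding can take the `argmax` of the raw scores directly and skip the softmax.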

<table>
<tr>
  <th colspan=1>The transformer</th>
</tr>
<tr>
  <td>
   <img src="https://www.tensorflow.org/images/tutorials/transformer/transformer.png"/>
  </td>
</tr>
</table>

A `Transformer` with one layer in both the `Encoder` and `Decoder` looks almost exactly like the model from the [RNN+attention tutorial](https://www.tensorflow.org/text/tutorials/nmt_with_attention). 🤖✨ A multi-layer Transformer has more layers, but is fundamentally doing the same thing. 🧩🔄

<table>
<tr>
  <th colspan=1>A 1-layer transformer</th>
  <th colspan=1>A 4-layer transformer</th>
</tr>
<tr>
  <td>
   <img width=400 src="https://www.tensorflow.org/images/tutorials/transformer/Transformer-1layer-compact.png"/>
  </td>
  <td rowspan=3>
   <img width=330 src="https://www.tensorflow.org/images/tutorials/transformer/Transformer-4layer-compact.png"/>
  </td>
</tr>
<tr>
  <th colspan=1>The RNN+Attention model</th>
</tr>
<tr>
  <td>
   <img width=400 src="https://www.tensorflow.org/images/tutorials/transformer/RNN+attention-compact.png"/>
  </td>
</tr>
</table>

Create the `Transformer` by extending `tf.keras.Model` 🤖✨:

> Note: The [original paper](https://arxiv.org/pdf/1706.03762.pdf), section 3.4, shares the weight matrix between the embedding layer and the final linear layer. To keep things simple, this tutorial uses two separate weight matrices.

In [None]:
class Transformer(tf.keras.Model):
  def __init__(self, *, num_layers, d_model, num_heads, dff,
               input_vocab_size, target_vocab_size, dropout_rate=0.1):
    super().__init__()
    # Encoder: Processes the input sequence and produces a context representation.
    # - num_layers: Number of stacked EncoderLayer blocks.
    # - d_model: Dimensionality of embeddings and hidden states.
    # - num_heads: Number of attention heads in MultiHeadAttention.
    # - dff: Hidden layer size in the feed-forward network.
    # - vocab_size: Size of the input vocabulary.
    # - dropout_rate: Dropout rate for regularization.
    self.encoder = Encoder(num_layers=num_layers, d_model=d_model,
                           num_heads=num_heads, dff=dff,
                           vocab_size=input_vocab_size,
                           dropout_rate=dropout_rate)

    # Decoder: Processes the target sequence and attends to the encoder output.
    # - num_layers: Number of stacked DecoderLayer blocks.
    # - d_model: Dimensionality of embeddings and hidden states.
    # - num_heads: Number of attention heads in MultiHeadAttention.
    # - dff: Hidden layer size in the feed-forward network.
    # - vocab_size: Size of the target vocabulary.
    # - dropout_rate: Dropout rate for regularization.
    self.decoder = Decoder(num_layers=num_layers, d_model=d_model,
                           num_heads=num_heads, dff=dff,
                           vocab_size=target_vocab_size,
                           dropout_rate=dropout_rate)

    # Final linear layer: Projects decoder outputs to logits for each target token.
    # - target_vocab_size: Number of possible output tokens.
    self.final_layer = tf.keras.layers.Dense(target_vocab_size)

  def call(self, inputs):
    # The call method defines the forward pass of the model.
    # Inputs:
    #   - inputs: Tuple (context, x)
    #     - context: Input token IDs (source sequence), shape (batch_size, context_len)
    #     - x: Target token IDs (target sequence), shape (batch_size, target_len)
    context, x = inputs

    # Pass the input tokens through the encoder to get the context representation.
    # Output shape: (batch_size, context_len, d_model)
    context = self.encoder(context)

    # Pass the target tokens and context through the decoder.
    # Output shape: (batch_size, target_len, d_model)
    x = self.decoder(x, context)

    # Project the decoder output to logits for each token in the target vocabulary.
    # Output shape: (batch_size, target_len, target_vocab_size)
    logits = self.final_layer(x)

    try:
      # Remove the Keras mask from logits to avoid affecting loss/metrics calculations.
      # This is a workaround for a known issue (b/250038731).
      del logits._keras_mask
    except AttributeError:
      pass

    # Return the logits (unnormalized scores over the target vocabulary for each position).
    return logits

### Hyperparameters ⚙️✨

To keep this example small and relatively fast, the number of layers (`num_layers`), the dimensionality of the embeddings (`d_model`), and the internal dimensionality of the `FeedForward` layer (`dff`) have been reduced. ⚡✨

The base model described in the original Transformer paper used `num_layers=6`, `d_model=512`, and `dff=2048`. 📄🤖

The number of self-attention heads remains the same (`num_heads=8`). 🧠🔄


In [None]:
# Set the hyperparameters for the Transformer model.
# These control the size and complexity of the model.

num_layers = 4        # Number of encoder/decoder layers to stack in the Transformer.
d_model = 128         # Dimensionality of the embedding vectors and hidden states.
dff = 512             # Dimensionality of the feed-forward network's hidden layer.
num_heads = 8         # Number of attention heads in the MultiHeadAttention layers.
dropout_rate = 0.1    # Dropout rate for regularization to prevent overfitting.
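
To get a rough feel for how much these reductions shrink the model, the arithmetic below counts the parameters of one feed-forward block at both sizes. It assumes the block is a `Dense(dff)` layer followed by a `Dense(d_model)` layer, each with a bias, and ignores any normalization layers. `ffn_param_count` is a hypothetical helper, not part of the notebook.

```python
def ffn_param_count(d_model, dff):
    """Weights plus biases of Dense(dff) followed by Dense(d_model)."""
    expand = d_model * dff + dff        # d_model -> dff
    project = dff * d_model + d_model   # dff -> d_model
    return expand + project


small = ffn_param_count(d_model=128, dff=512)    # this tutorial's settings
base = ffn_param_count(d_model=512, dff=2048)    # base model from the paper
print(small, base, round(base / small, 1))
# prints: 131712 2099712 15.9
```

Under these assumptions, each feed-forward block in the tutorial model is roughly 16 times smaller than in the base configuration.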

### Try it out 🚀🤖✨

Instantiate the `Transformer` model:

In [None]:
# Instantiate the Transformer model with the specified hyperparameters and vocabulary sizes.
# - num_layers: Number of encoder and decoder layers to stack in the Transformer.
# - d_model: Dimensionality of the embedding vectors and hidden states.
# - num_heads: Number of attention heads in the MultiHeadAttention layers.
# - dff: Dimensionality of the feed-forward network's hidden layer.
# - input_vocab_size: Size of the input vocabulary (Portuguese), obtained from the tokenizer.
# - target_vocab_size: Size of the target vocabulary (English), obtained from the tokenizer.
# - dropout_rate: Dropout rate for regularization to prevent overfitting.
transformer = Transformer(
    num_layers=num_layers,                          # Number of layers in encoder/decoder
    d_model=d_model,                                # Embedding and hidden state size
    num_heads=num_heads,                            # Number of attention heads
    dff=dff,                                        # Feed-forward network hidden size
    input_vocab_size=tokenizers.pt.get_vocab_size().numpy(),   # Portuguese vocab size
    target_vocab_size=tokenizers.en.get_vocab_size().numpy(),  # English vocab size
    dropout_rate=dropout_rate                       # Dropout rate for regularization
)

Test it:

In [None]:
# Pass the Portuguese and English token IDs through the Transformer model.
# - pt: Tensor of shape (batch_size, input_seq_len), containing token IDs for the source (Portuguese) sentences.
# - en: Tensor of shape (batch_size, target_seq_len), containing token IDs for the target (English) sentences.
# The transformer model will:
#   1. Encode the Portuguese input sequence using the encoder stack.
#   2. Decode the English target sequence using the decoder stack, attending to the encoder output.
#   3. Project the decoder output to logits for each token in the target vocabulary.
output = transformer((pt, en))

# Print the shape of the English token IDs tensor (target sequence).
# This shows the batch size and target sequence length.
print(en.shape)

# Print the shape of the Portuguese token IDs tensor (source sequence).
# This shows the batch size and input sequence length.
print(pt.shape)

# Print the shape of the output tensor from the Transformer model.
# The output shape is (batch_size, target_seq_len, target_vocab_size), representing the predicted logits for each token position in the target sequence.
print(output.shape)

In [None]:
# Retrieve the cross-attention scores from the last DecoderLayer in the Transformer.
# - transformer.decoder.dec_layers[-1]: Accesses the last DecoderLayer in the decoder stack.
# - last_attn_scores: Stores the cross-attention weights from the most recent forward pass.
#   These scores indicate how much each target token (in the output sequence) attends to each input token (in the encoder output),
#   for every attention head and batch in the input.
attn_scores = transformer.decoder.dec_layers[-1].last_attn_scores

# Print the shape of the attention scores tensor.
# The shape is (batch_size, num_heads, target_seq_len, input_seq_len), where:
#   - batch_size: Number of sequences in the batch.
#   - num_heads: Number of attention heads in the decoder.
#   - target_seq_len: Length of the output (target) sequence.
#   - input_seq_len: Length of the input (source) sequence.
print(attn_scores.shape)  # (batch, heads, target_seq, input_seq)

Visualize the model architecture:

In [None]:
# Install the pydot package, which Keras uses to render model diagrams.
# plot_model also needs the Graphviz binaries, which pip does not install.
# If Graphviz is missing, install it with your system package manager first
# (for example, via apt on Debian/Ubuntu).
!pip install pydot

In [None]:
# Import the plot_model utility from TensorFlow Keras.
# This function generates a visual diagram of the model architecture.
from tensorflow.keras.utils import plot_model

# Visualize the architecture of the 'transformer' model.
# - transformer: The model instance to visualize.
# - show_shapes=True: Display the shape of the input and output tensors for each layer.
# - dpi=64: Set the resolution of the generated image (dots per inch).
# - rankdir='LR': Arrange the diagram from left to right (horizontal layout).
plot_model(transformer, show_shapes=True, dpi=64, rankdir='LR')

## Training 🏋️‍♂️✨

It's time to prepare the model and start training it. 🚀

### Set up the optimizer ⚡🧠✨

Use the Adam optimizer with a custom learning rate scheduler according to the formula in the original Transformer [paper](https://arxiv.org/abs/1706.03762).

$$\Large{lrate = d_{model}^{-0.5} * \min(step{\_}num^{-0.5}, step{\_}num \cdot warmup{\_}steps^{-1.5})}$$

In [None]:
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  def __init__(self, d_model, warmup_steps=4000):
    super().__init__()

    # Store the model dimensionality (d_model) as a float tensor.
    # This is used to scale the learning rate according to the formula.
    self.d_model = tf.cast(d_model, tf.float32)

    # Store the number of warmup steps.
    # During warmup, the learning rate increases linearly.
    self.warmup_steps = warmup_steps

  def __call__(self, step):
    # Convert the current training step to float for computation.
    step = tf.cast(step, dtype=tf.float32)

    # Compute the first term: inverse square root of the step number.
    # This causes the learning rate to decrease as training progresses.
    arg1 = tf.math.rsqrt(step)

    # Compute the second term: step number multiplied by the inverse 1.5 power of warmup_steps.
    # This causes the learning rate to increase linearly during the warmup period.
    arg2 = step * (self.warmup_steps ** -1.5)

    # The learning rate is scaled by the inverse square root of d_model,
    # and is the minimum of the two terms above.
    # This implements the learning rate schedule from the original Transformer paper:
    # lrate = d_model^{-0.5} * min(step_num^{-0.5}, step_num * warmup_steps^{-1.5})
    return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

Instantiate the optimizer (in this example it's `tf.keras.optimizers.Adam`):

In [None]:
# Create a custom learning rate schedule using the formula from the original Transformer paper.
# - d_model: Dimensionality of the model's embeddings and hidden states.
# The learning rate will increase linearly for the first 'warmup_steps' training steps,
# then decay proportionally to the inverse square root of the step number.
learning_rate = CustomSchedule(d_model)

# Instantiate the Adam optimizer for training the Transformer model.
# - learning_rate: Uses the custom schedule defined above.
# - beta_1=0.9: Exponential decay rate for the first moment estimates (default for Adam).
# - beta_2=0.98: Exponential decay rate for the second moment estimates (lower than the Keras default of 0.999).
# - epsilon=1e-9: Small constant to prevent division by zero in the optimizer's update step.
optimizer = tf.keras.optimizers.Adam(
    learning_rate,      # Custom learning rate schedule
    beta_1=0.9,         # Decay rate for first moment estimates
    beta_2=0.98,        # Decay rate for second moment estimates
    epsilon=1e-9        # Numerical stability constant
)

Test the custom learning rate scheduler:

In [None]:
# Plot the learning rate schedule over training steps.
# - learning_rate: Custom learning rate scheduler (instance of CustomSchedule).
# - tf.range(40000, dtype=tf.float32): Generates a tensor of training steps from 0 to 39,999.
# - learning_rate(...): Computes the learning rate for each training step.
# - plt.plot(...): Plots the computed learning rates against training steps.
plt.plot(learning_rate(tf.range(40000, dtype=tf.float32)))

# Set the label for the y-axis to indicate it represents the learning rate.
plt.ylabel('Learning Rate')

# Set the label for the x-axis to indicate it represents the training step number.
plt.xlabel('Train Step')
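
To sanity-check the shape of the curve, the schedule formula can also be evaluated in plain Python. This is a sketch that mirrors the formula rather than the `CustomSchedule` class. `transformer_lr` is a hypothetical helper, and the defaults `d_model=128` and `warmup_steps=4000` are taken from this notebook.

```python
def transformer_lr(step, d_model=128, warmup_steps=4000):
    """Learning rate from the schedule formula, valid for step >= 1."""
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


for step in (1, 1000, 4000, 16000, 40000):
    print(step, f"{transformer_lr(step):.6f}")

peak = transformer_lr(4000)
```

The two arguments of `min` are equal at `step == warmup_steps`, so the peak is `(d_model * warmup_steps) ** -0.5`, about 0.0014 for these settings. After the warmup, the rate falls with the inverse square root of the step, so quadrupling the step count halves it.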

### Set up the loss and metrics ⚡📏✨

Since the target sequences are padded, it is important to apply a padding mask when calculating the loss. Use the cross-entropy loss function (`tf.keras.losses.SparseCategoricalCrossentropy`):

In [None]:
def masked_loss(label, pred):
  # Create a mask to ignore padding tokens (assumed to be 0)
  mask = label != 0

  # Define the loss function: SparseCategoricalCrossentropy
  # - from_logits=True: pred contains raw logits, not probabilities
  # - reduction='none': compute loss for each element, don't reduce yet
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

  # Compute the per-token loss
  loss = loss_object(label, pred)

  # Cast the mask to the same dtype as loss (float32)
  mask = tf.cast(mask, dtype=loss.dtype)

  # Apply the mask: zero out loss for padding tokens
  loss *= mask

  # Compute the mean loss over non-padding tokens
  loss = tf.reduce_sum(loss) / tf.reduce_sum(mask)
  return loss


def masked_accuracy(label, pred):
  # Get the predicted token IDs by taking argmax over the last axis (vocab)
  pred = tf.argmax(pred, axis=2)

  # Cast label to the same dtype as pred for comparison
  label = tf.cast(label, pred.dtype)

  # Compare predicted tokens to true labels
  match = label == pred

  # Create a mask to ignore padding tokens (assumed to be 0)
  mask = label != 0

  # Only count matches for non-padding tokens
  match = match & mask

  # Cast match and mask to float for averaging
  match = tf.cast(match, dtype=tf.float32)
  mask = tf.cast(mask, dtype=tf.float32)

  # Compute accuracy as the mean over non-padding tokens
  return tf.reduce_sum(match) / tf.reduce_sum(mask)
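
The effect of the padding mask is easiest to see on a tiny example. The sketch below uses plain Python lists instead of tensors. The per-token loss values are made up, and `masked_mean` is a hypothetical helper that mirrors the final `reduce_sum(loss) / reduce_sum(mask)` step.

```python
def masked_mean(per_token_loss, labels, pad_id=0):
    """Average the loss over non-padding positions only."""
    total, count = 0.0, 0
    for loss, label in zip(per_token_loss, labels):
        if label != pad_id:
            total += loss
            count += 1
    return total / count


labels = [5, 9, 3, 0, 0]                     # last two positions are padding
per_token_loss = [2.0, 1.0, 3.0, 4.0, 4.0]   # hypothetical cross-entropy values

masked = masked_mean(per_token_loss, labels)      # (2 + 1 + 3) / 3
naive = sum(per_token_loss) / len(per_token_loss)  # (2 + 1 + 3 + 4 + 4) / 5
print(masked, naive)
```

Without the mask, the padding positions would pull the average from 2.0 to 2.8, even though the model's predictions at those positions carry no meaning.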

### Train the model 🏋️‍♂️🤖✨

With all the components ready, configure the training procedure using `model.compile`, and then run it with `model.fit`: 🏋️‍♂️🤖✨

Note: Training for the full 20 epochs takes about an hour in Colab. ⏳💻 The cell below runs a much shorter demonstration.

In [None]:
# Compile the Transformer model for training.
# - loss: masked_loss function, which computes cross-entropy loss while ignoring padding tokens.
# - optimizer: Adam optimizer with a custom learning rate schedule (optimizer variable).
# - metrics: masked_accuracy function, which computes accuracy while ignoring padding tokens.
# This prepares the model for training with model.fit, ensuring that loss and metrics are calculated correctly for padded sequences.
transformer.compile(
    loss=masked_loss,          # Custom loss function that ignores padding tokens
    optimizer=optimizer,       # Adam optimizer with custom learning rate schedule
    metrics=[masked_accuracy]  # Custom accuracy metric that ignores padding tokens
)

In [None]:
# Train the Transformer model using the training and validation datasets.
# - train_batches: The training dataset, batched and prefetched for efficiency.
# - steps_per_epoch=20: Stop each epoch after 20 batches to keep this demonstration short.
# - epochs=1: Run a single shortened epoch; for full training, use epochs=20 and remove steps_per_epoch.
# - validation_data=val_batches: Use the validation dataset to evaluate the model after each epoch.
# During training:
#   1. The model receives batches of input and target sequences from train_batches.
#   2. For each batch, it computes predictions, calculates the masked loss and accuracy, and updates weights using the optimizer.
#   3. After each epoch, the model is evaluated on val_batches to monitor generalization and detect overfitting.
# fit returns a History object that holds the per-epoch loss and metric values.
transformer.fit(
    train_batches,               # Training data: batched (source, target) pairs
    steps_per_epoch=20,          # Limit the epoch to 20 batches, since training is computationally heavy
    epochs=1,                    # 1 epoch for demonstration purposes; use 20 for a full run
    validation_data=val_batches  # Validation data for evaluation after each epoch
)

## Run inference 🚀🤖✨

You can now test the model by performing a translation. The following steps are used for inference: 🌍🤖✨

* Encode the input sentence using the Portuguese tokenizer (`tokenizers.pt`). This is the encoder input. 🧩🔢
* The decoder input is initialized to the `[START]` token. 🚦
* Calculate the padding masks and the look ahead masks. 🛡️👀
* The `decoder` then outputs the predictions by looking at the `encoder output` and its own output (self-attention). 🔄✨
* Concatenate the predicted token to the decoder input and pass it to the decoder. ➕🔁
* In this approach, the decoder predicts the next token based on the previous tokens it predicted. ⏩🧠

Note: The model is optimized for _efficient training_ and makes a next-token prediction for each token in the output simultaneously. This is redundant during inference, and only the last prediction is used.  This model can be made more efficient for inference if you only calculate the last prediction when running in inference mode (`training=False`). ⚡
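
The steps above can be sketched as a plain greedy decoding loop. The example below is deliberately model-free: `toy_next_token_logits` is a made-up stand-in for the Transformer that copies the source tokens and then emits the end token. The token IDs `START = 1` and `END = 2` are assumptions for illustration only.

```python
START, END = 1, 2  # hypothetical token IDs


def toy_next_token_logits(source_ids, output_ids):
    """Stand-in for the model: copy the source, then emit END.

    Returns one score per vocabulary entry for the next position.
    """
    vocab_size = 10
    position = len(output_ids) - 1  # tokens generated so far, excluding START
    target = source_ids[position] if position < len(source_ids) else END
    logits = [0.0] * vocab_size
    logits[target] = 1.0
    return logits


def greedy_decode(source_ids, max_length=8):
    """Generate tokens one at a time, always picking the highest score."""
    output_ids = [START]                       # decoder input starts with START
    for _ in range(max_length):
        logits = toy_next_token_logits(source_ids, output_ids)
        next_id = max(range(len(logits)), key=logits.__getitem__)
        output_ids.append(next_id)             # feed the prediction back in
        if next_id == END:                     # stop once END is produced
            break
    return output_ids


print(greedy_decode([4, 7, 5]))
# prints: [1, 4, 7, 5, 2]
```

The loop structure is the part that carries over to the real model: start from the start token, predict, append, and stop at the end token or at `max_length`.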

Define the `Translator` class by subclassing `tf.Module`: 🏗️🤖

In [None]:
class Translator(tf.Module):
  def __init__(self, tokenizers, transformer):
    # Store the tokenizers and transformer model for use in translation
    self.tokenizers = tokenizers
    self.transformer = transformer

  def __call__(self, sentence, max_length=MAX_TOKENS):
    # The input sentence should be a tf.Tensor (Portuguese string)
    assert isinstance(sentence, tf.Tensor)
    # If the input is a scalar tensor, add a batch dimension
    if len(sentence.shape) == 0:
      sentence = sentence[tf.newaxis]

    # Tokenize the Portuguese sentence and pad to tensor
    sentence = self.tokenizers.pt.tokenize(sentence).to_tensor()
    encoder_input = sentence  # Encoder input for the transformer

    # Prepare the initial decoder input: English [START] token
    start_end = self.tokenizers.en.tokenize([''])[0]
    start = start_end[0][tf.newaxis]  # [START] token as tensor
    end = start_end[1][tf.newaxis]    # [END] token as tensor

    # Use tf.TensorArray for dynamic sequence generation in the loop
    output_array = tf.TensorArray(dtype=tf.int64, size=0, dynamic_size=True)
    output_array = output_array.write(0, start)  # Write [START] token at position 0

    # Loop to generate each token in the output sequence
    for i in tf.range(max_length):
      # Stack and transpose output_array to shape (batch, seq_len)
      output = tf.transpose(output_array.stack())
      # Run the transformer to get predictions for the next token
      predictions = self.transformer([encoder_input, output], training=False)
      # Select the logits for the last generated token position
      predictions = predictions[:, -1:, :]  # Shape: (batch_size, 1, vocab_size)
      # Get the predicted token ID (highest probability)
      predicted_id = tf.argmax(predictions, axis=-1)
      # Append the predicted token to the output sequence
      output_array = output_array.write(i+1, predicted_id[0])
      # If the predicted token is [END], stop generation
      if predicted_id == end:
        break

    # Final output sequence: transpose to shape (batch, tokens)
    output = tf.transpose(output_array.stack())
    # Detokenize to get the translated English text
    text = self.tokenizers.en.detokenize(output)[0]  # Shape: ()
    # Lookup token strings for the output token IDs
    tokens = self.tokenizers.en.lookup(output)[0]

    # Attention weights: recalculate for the final output sequence
    # (tf.function tracing prevents direct access inside the loop)
    self.transformer([encoder_input, output[:,:-1]], training=False)
    attention_weights = self.transformer.decoder.last_attn_scores

    # Return the translated text, token strings, and attention weights
    return text, tokens, attention_weights

Note: This function uses an unrolled loop, not a dynamic loop. It generates `MAX_TOKENS` on every call. Refer to the [NMT with attention](nmt_with_attention.ipynb) tutorial for an example implementation with a dynamic loop, which can be much more efficient. ⚡🤖✨

Create an instance of this `Translator` class, and try it out a few times: 🌍🤖✨

In [None]:
# Create an instance of the Translator class for inference.
# - tokenizers: Contains the Portuguese and English tokenizers for encoding/decoding text.
# - transformer: The trained Transformer model for translation.
# The Translator class wraps the model and tokenizers, providing a convenient interface
# for translating Portuguese sentences to English using the trained Transformer.
translator = Translator(tokenizers, transformer)

In [None]:
def print_translation(sentence, tokens, ground_truth):
  # Print the input sentence (Portuguese source text)
  print(f'{"Input":15s}: {sentence}')

  # Print the predicted translation (English target text)
  # - tokens: scalar string tensor holding the detokenized prediction
  #   (the calls below pass translated_text here)
  # - tokens.numpy(): returns the prediction as a Python bytes object
  # - .decode("utf-8"): decodes the bytes into a regular Python string
  print(f'{"Prediction":15s}: {tokens.numpy().decode("utf-8")}')

  # Print the ground truth translation (reference target sentence)
  print(f'{"Ground truth":15s}: {ground_truth}')

Example 1:

In [None]:
# Define the input sentence in Portuguese and its ground truth English translation.
sentence = 'este é um problema que temos que resolver.'
ground_truth = 'this is a problem we have to solve .'

# Use the Translator instance to translate the Portuguese sentence to English.
# - translator: An instance of the Translator class, which wraps the tokenizers and trained Transformer model.
# - tf.constant(sentence): Converts the input sentence to a TensorFlow tensor, as required by the Translator.
# - translated_text: The translated English sentence as a string.
# - translated_tokens: The tokenized output sequence (tokens) for the translated sentence.
# - attention_weights: The cross-attention weights from the Transformer, showing how each output token attends to the input tokens.
translated_text, translated_tokens, attention_weights = translator(
    tf.constant(sentence))

# Print the input sentence, the predicted translation, and the ground truth translation.
# - print_translation: Utility function to display the input, prediction, and reference translation.
print_translation(sentence, translated_text, ground_truth)

Example 2:

In [None]:
# Define the input sentence in Portuguese and its ground truth English translation.
sentence = 'os meus vizinhos ouviram sobre esta ideia.'
ground_truth = 'and my neighboring homes heard about this idea .'

# Use the Translator instance to translate the Portuguese sentence to English.
# - translator: An instance of the Translator class, which wraps the tokenizers and trained Transformer model.
# - tf.constant(sentence): Converts the input sentence to a TensorFlow tensor, as required by the Translator.
# - translated_text: The translated English sentence as a string.
# - translated_tokens: The tokenized output sequence (tokens) for the translated sentence.
# - attention_weights: The cross-attention weights from the Transformer, showing how each output token attends to the input tokens.
translated_text, translated_tokens, attention_weights = translator(
    tf.constant(sentence))

# Print the input sentence, the predicted translation, and the ground truth translation.
# - print_translation: Utility function to display the input, prediction, and reference translation.
print_translation(sentence, translated_text, ground_truth)

Example 3:

In [None]:
# Define the input sentence in Portuguese and its ground truth English translation.
sentence = 'vou então muito rapidamente partilhar convosco algumas histórias de algumas coisas mágicas que aconteceram.'
ground_truth = "so i'll just share with you some stories very quickly of some magical things that have happened."

# Use the Translator instance to translate the Portuguese sentence to English.
# - translator: An instance of the Translator class, which wraps the tokenizers and trained Transformer model.
# - tf.constant(sentence): Converts the input sentence to a TensorFlow tensor, as required by the Translator.
# - translated_text: The translated English sentence as a string.
# - translated_tokens: The tokenized output sequence (tokens) for the translated sentence.
# - attention_weights: The cross-attention weights from the Transformer, showing how each output token attends to the input tokens.
translated_text, translated_tokens, attention_weights = translator(
    tf.constant(sentence))

# Print the input sentence, the predicted translation, and the ground truth translation.
# - print_translation: Utility function to display the input, prediction, and reference translation.
print_translation(sentence, translated_text, ground_truth)

## Create attention plots 🎯✨

The `Translator` class you created in the previous section returns the cross-attention weights alongside the translation. You can plot these weights as heatmaps to visualize the internal workings of the model. 🎯✨

For example:

In [None]:
# Define the input sentence in Portuguese and its ground truth English translation.
sentence = 'este é o primeiro livro que eu fiz.'
ground_truth = "this is the first book i've ever done."

# Use the Translator instance to translate the Portuguese sentence to English.
# - translator: An instance of the Translator class, which wraps the tokenizers and trained Transformer model.
# - tf.constant(sentence): Converts the input sentence to a TensorFlow tensor, as required by the Translator.
# - translated_text: The translated English sentence as a string.
# - translated_tokens: The tokenized output sequence (tokens) for the translated sentence.
# - attention_weights: The cross-attention weights from the Transformer, showing how each output token attends to the input tokens.
translated_text, translated_tokens, attention_weights = translator(
    tf.constant(sentence))

# Print the input sentence, the predicted translation, and the ground truth translation.
# - print_translation: Utility function to display the input, prediction, and reference translation.
print_translation(sentence, translated_text, ground_truth)

Create a function that plots the attention when a token is generated:

In [None]:
def plot_attention_head(in_tokens, translated_tokens, attention):
  # Remove the <START> token from the translated output tokens for plotting.
  translated_tokens = translated_tokens[1:]

  # Get the current matplotlib axis for plotting.
  ax = plt.gca()
  # Display the attention matrix as a heatmap.
  ax.matshow(attention)
  # Set the x-axis ticks to match the number of input tokens.
  ax.set_xticks(range(len(in_tokens)))
  # Set the y-axis ticks to match the number of translated tokens (excluding <START>).
  ax.set_yticks(range(len(translated_tokens)))

  # Decode the input tokens from bytes to strings for labeling the x-axis.
  labels = [label.decode('utf-8') for label in in_tokens.numpy()]
  ax.set_xticklabels(labels, rotation=90)  # Rotate labels for readability.

  # Decode the translated tokens from bytes to strings for labeling the y-axis.
  labels = [label.decode('utf-8') for label in translated_tokens.numpy()]
  ax.set_yticklabels(labels)

In [None]:
head = 0  # Select which attention head to visualize (0-based index)

# The attention_weights tensor has shape (batch_size, num_heads, seq_len_q, seq_len_k)
# - batch_size: Number of sequences in the batch (usually 1 for inference/visualization)
# - num_heads: Number of attention heads in the decoder
# - seq_len_q: Length of the query sequence (target/output tokens)
# - seq_len_k: Length of the key sequence (source/input tokens)

# Remove the batch dimension since we're visualizing a single example
attention_heads = tf.squeeze(attention_weights, 0)  # Shape: (num_heads, seq_len_q, seq_len_k)

# Select the attention matrix for the desired head
attention = attention_heads[head]  # Shape: (seq_len_q, seq_len_k)

# Display the shape of the selected attention matrix for verification
# (a notebook cell echoes the value of its last expression)
attention.shape  # Should be (seq_len_q, seq_len_k)
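
To make the shape bookkeeping above concrete, here is a minimal NumPy sketch that builds a random stand-in for `attention_weights` and applies the same squeeze-and-index steps. The sizes (8 heads, 9 output tokens, 11 input tokens) and all `toy_*` names are assumptions made for this illustration, not values taken from the trained model.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed sizes for illustration only.
batch_size, num_heads, seq_len_q, seq_len_k = 1, 8, 9, 11

# Random scores, normalized with a softmax over the key (input) axis,
# mimic the layout of real attention weights.
scores = rng.normal(size=(batch_size, num_heads, seq_len_q, seq_len_k))
exp_scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
toy_attention_weights = exp_scores / exp_scores.sum(axis=-1, keepdims=True)

# Drop the batch dimension, then pick one head.
toy_attention_heads = np.squeeze(toy_attention_weights, axis=0)
toy_attention = toy_attention_heads[0]

print(toy_attention_heads.shape)  # (8, 9, 11)
print(toy_attention.shape)        # (9, 11)

# Each row is a distribution over the input tokens, so it sums to 1.
print(np.allclose(toy_attention.sum(axis=-1), 1.0))  # True
```

Each row of the selected matrix corresponds to one output token and each column to one input token, which is the orientation the heatmap uses.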

These are the input (Portuguese) tokens:

In [None]:
# Convert the input sentence (Portuguese) to a tensor with batch dimension.
in_tokens = tf.convert_to_tensor([sentence])

# Tokenize the input sentence using the Portuguese tokenizer.
# This converts the string to a sequence of token IDs.
in_tokens = tokenizers.pt.tokenize(in_tokens).to_tensor()

# Lookup the string representation of each token ID.
# This converts the token IDs back to their corresponding token strings.
in_tokens = tokenizers.pt.lookup(in_tokens)[0]

# Display the list of token strings for the input sentence.
in_tokens

And these are the output (English translation) tokens:

In [None]:
translated_tokens

In [None]:
# Plot the attention heatmap for a single attention head.
# This function visualizes how each output token attends to each input token during translation.
# Arguments:
#   in_tokens: Tensor of input (source language) tokens as strings.
#   translated_tokens: Tensor of output (target language) tokens as strings.
#   attention: 2D attention matrix (shape: [output_tokens, input_tokens]) for a single attention head.

plot_attention_head(in_tokens, translated_tokens, attention)

In [None]:
def plot_attention_weights(sentence, translated_tokens, attention_heads):
  # Convert the input sentence (string) to a tensor with batch dimension
  in_tokens = tf.convert_to_tensor([sentence])
  # Tokenize the input sentence using the Portuguese tokenizer and pad to tensor
  in_tokens = tokenizers.pt.tokenize(in_tokens).to_tensor()
  # Lookup the string representation of each token ID for display
  in_tokens = tokenizers.pt.lookup(in_tokens)[0]

  # Create a matplotlib figure with a specified size to hold all attention head plots
  fig = plt.figure(figsize=(16, 8))

  # Iterate over each attention head and its index
  for h, head in enumerate(attention_heads):
    # Add a subplot for each attention head (a 2x4 grid at index h+1, which assumes at most 8 heads)
    ax = fig.add_subplot(2, 4, h+1)

    # Plot the attention heatmap for this head using the helper function
    plot_attention_head(in_tokens, translated_tokens, head)

    # Label the x-axis with the head number for clarity
    ax.set_xlabel(f'Head {h+1}')

  # Adjust subplot layout to prevent overlap and display the figure
  plt.tight_layout()
  plt.show()
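
The fixed 2×4 grid in `plot_attention_weights` fits a model with 8 attention heads. If you change `num_heads`, the grid needs to change too. The helper below is a hypothetical sketch (the name `subplot_grid` and the `max_cols` parameter are not part of the tutorial) showing one way to derive the grid from the number of heads.

```python
import math

def subplot_grid(num_heads, max_cols=4):
  """Returns (rows, cols) for a grid that fits `num_heads` subplots.

  Hypothetical helper: assumes num_heads >= 1 and max_cols >= 1.
  """
  cols = min(num_heads, max_cols)
  rows = math.ceil(num_heads / cols)
  return rows, cols

print(subplot_grid(8))   # (2, 4)
print(subplot_grid(3))   # (1, 3)
print(subplot_grid(16))  # (4, 4)
```

With such a helper, the subplot call could become `fig.add_subplot(rows, cols, h+1)`, and the figure size could be scaled with `rows` and `cols` as well.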

In [None]:
# Plot the attention weights for all attention heads for a given input sentence and its translation.
# Arguments:
#   sentence: The input sentence (string) in the source language (Portuguese).
#   translated_tokens: The output tokens (tensor) from the translation (target language, English).
#   attention_weights[0]: The attention weights for the first example in the batch.
#     - Shape: (num_heads, output_tokens, input_tokens)
#     - Each head's attention matrix shows how each output token attends to each input token.

plot_attention_weights(
    sentence,            # The input sentence to be translated and visualized.
    translated_tokens,   # The tokens of the translated output sentence.
    attention_weights[0] # The attention weights for the first example (all heads).
)

The model can handle unfamiliar words. Neither `'triceratops'` nor `'enciclopédia'` is in the input dataset, yet the model attempts to transliterate them even without a shared vocabulary. For example:

In [None]:
# Define a Portuguese input sentence and its English ground truth translation.
sentence = 'Eu li sobre triceratops na enciclopédia.'
ground_truth = 'I read about triceratops in the encyclopedia.'

# Use the Translator instance to translate the Portuguese sentence to English.
# - translator: An instance of the Translator class, which wraps the tokenizers and trained Transformer model.
# - tf.constant(sentence): Converts the input sentence to a TensorFlow tensor, as required by the Translator.
# - translated_text: The translated English sentence as a string.
# - translated_tokens: The tokenized output sequence (tokens) for the translated sentence.
# - attention_weights: The cross-attention weights from the Transformer, showing how each output token attends to the input tokens.
translated_text, translated_tokens, attention_weights = translator(
    tf.constant(sentence))

# Print the input sentence, the predicted translation, and the ground truth translation.
# - print_translation: Utility function to display the input, prediction, and reference translation.
print_translation(sentence, translated_text, ground_truth)

# Visualize the attention weights for all heads for this translation.
# - plot_attention_weights: Plots the attention heatmaps for each attention head.
# - sentence: The input sentence (Portuguese).
# - translated_tokens: The output tokens (English translation).
# - attention_weights[0]: The attention weights for the first example in the batch (all heads).
plot_attention_weights(sentence, translated_tokens, attention_weights[0])

## Export the model 📦✨

You have tested the model and the inference is working. Next, you can export it as a SavedModel with `tf.saved_model.save`. To learn about saving and loading a model in the SavedModel format, use [this guide](https://www.tensorflow.org/guide/saved_model). 📦✨

Create a class called `ExportTranslator` by subclassing `tf.Module`, with a `tf.function` on the `__call__` method: 🤖📝

In [None]:
class ExportTranslator(tf.Module):
  def __init__(self, translator):
    # Store the Translator instance, which wraps the tokenizers and trained Transformer model.
    self.translator = translator

  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def __call__(self, sentence):
    # Run the translation for the input sentence using the Translator.
    # - sentence: A scalar tf.Tensor of dtype string (the input sentence to translate).
    # - max_length: Maximum number of tokens to generate in the output sequence.
    # The Translator returns:
    #   - result: The translated sentence as a string.
    #   - tokens: The tokenized output sequence (tokens) for the translated sentence.
    #   - attention_weights: The cross-attention weights from the Transformer.
    (result,
     tokens,
     attention_weights) = self.translator(sentence, max_length=MAX_TOKENS)

    # Return only the translated sentence (result).
    # This is the output of the exported model for inference.
    return result

In the `tf.function` above, only the output sentence is returned. Thanks to the [non-strict execution](https://tensorflow.org/guide/intro_to_graphs) in `tf.function`, any unnecessary values are never computed.

Wrap `translator` in the newly created `ExportTranslator`:

In [None]:
# Wrap the existing Translator instance in the ExportTranslator class for exporting as a SavedModel.
# - ExportTranslator: A tf.Module subclass that provides a tf.function for inference.
#   This allows you to export the translation model for serving or deployment.
# - translator: The Translator instance, which contains the tokenizers and trained Transformer model.
# After this assignment, 'translator' refers to the ExportTranslator instance,
# which can be used for inference and exporting the model.
translator = ExportTranslator(translator)

Since the model decodes its predictions using `tf.argmax`, the predictions are deterministic. The original model and one reloaded from its `SavedModel` should give identical predictions:

In [None]:
# Translate a Portuguese sentence to English using the exported Translator model.
# - translator: An instance of ExportTranslator, which wraps the trained Translator and Transformer model.
# - The input is a Portuguese sentence as a string.
# - The output is a TensorFlow string tensor containing the translated English sentence.
# - .numpy(): Returns the value of the scalar string tensor as a Python bytes object for display.
translator('este é o primeiro livro que eu fiz.').numpy()

In [None]:
import os

# Directory where the model will be saved
SAVE_DIR = "artifacts"

# Name for the exported model
model_name = "translator"

# Full path to save the model (artifacts/translator)
MODEL_PATH = os.path.join(SAVE_DIR, model_name)

# Export the translator model as a TensorFlow SavedModel.
# - translator: The ExportTranslator instance, which wraps the trained Translator and Transformer model.
# - export_dir: The directory path where the SavedModel will be stored.
# This will create the directory and save all necessary files for loading and serving the model later.
tf.saved_model.save(translator, export_dir=MODEL_PATH)

In [None]:
# Load the exported TensorFlow SavedModel from disk.
# - MODEL_PATH: The directory path where the SavedModel was saved (e.g., "artifacts/translator").
# - tf.saved_model.load: Loads the model and returns a callable object for inference.
#   This object can be used to perform translation using the exported model.
#   The loaded model will have the same behavior as the original ExportTranslator instance.
reloaded = tf.saved_model.load(MODEL_PATH)

In [None]:
# Use the reloaded SavedModel to translate a Portuguese sentence to English.
# - reloaded: This is the loaded TensorFlow SavedModel, which wraps the ExportTranslator module.
# - The input is a Portuguese sentence as a string.
# - The output is a TensorFlow string tensor containing the translated English sentence.
# - .numpy(): Returns the value of the scalar string tensor as a Python bytes object for display.
# This demonstrates that the exported and reloaded model produces the same deterministic output as the original model.
reloaded('este é o primeiro livro que eu fiz.').numpy()
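
The determinism claim rests on greedy decoding: taking the argmax of the logits at every step involves no randomness, so the same logits always yield the same tokens. The NumPy sketch below contrasts this with sampling. The logits and the function names `greedy_decode` and `sample_decode` are made up for illustration and stand in for the model's per-step outputs.

```python
import numpy as np

# Made-up logits for 3 decoding steps over a 3-token vocabulary.
toy_logits = np.array([
    [0.1, 2.0, 0.3],
    [1.5, 0.2, 0.1],
    [0.0, 0.1, 3.0],
])

def greedy_decode(logits):
  # argmax is a pure function of the logits: no randomness involved.
  return [int(np.argmax(step)) for step in logits]

def sample_decode(logits, seed):
  # Sampling draws from the softmax distribution, so the result
  # depends on the random seed as well as on the logits.
  rng = np.random.default_rng(seed)
  tokens = []
  for step in logits:
    probs = np.exp(step - step.max())
    probs /= probs.sum()
    tokens.append(int(rng.choice(len(probs), p=probs)))
  return tokens

print(greedy_decode(toy_logits))  # [1, 0, 2]
print(greedy_decode(toy_logits) == greedy_decode(toy_logits))  # True
print(sample_decode(toy_logits, seed=0))  # may differ for other seeds
```

Because the exported model uses the greedy path, comparing the two outputs above is a meaningful check that the SavedModel round trip preserved the weights and tokenizers.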

## Conclusion 🎉🤖✨

In this tutorial you learned about:

* Transformers and their significance in machine learning 🚀
* Attention, self-attention and multi-head attention 🎯
* Positional encoding with embeddings 🧩
* The encoder-decoder architecture of the original Transformer 🏗️
* Masking in self-attention 🛡️
* How to put it all together to translate text 🌍

The downsides of this architecture are:

- For a time-series, the output for a time-step is calculated from the *entire history* instead of only the inputs and current hidden-state. This _may_ be less efficient. ⏳
- If the input has a temporal/spatial relationship, like text or images, some positional encoding must be added or the model will effectively see a bag of words. 🗂️
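
The "bag of words" point can be checked numerically. The sketch below is a simplified, assumption-laden illustration: it uses a single attention head with no learned query/key/value projections and random vectors in place of real embeddings and positional encodings. It shows that shuffling the input tokens merely shuffles the outputs in the same way, unless position information is added first.

```python
import numpy as np

def self_attention(x):
  # Simplified single-head self-attention without learned projections.
  scores = x @ x.T / np.sqrt(x.shape[-1])
  scores = scores - scores.max(axis=-1, keepdims=True)
  weights = np.exp(scores)
  weights = weights / weights.sum(axis=-1, keepdims=True)
  return weights @ x

rng = np.random.default_rng(0)
tokens = rng.normal(size=(5, 4))     # 5 toy token embeddings of width 4
positions = rng.normal(size=(5, 4))  # random stand-in for a positional encoding
perm = np.array([3, 0, 4, 1, 2])     # a reordering of the 5 tokens

# Without positions: permuting the inputs only permutes the outputs.
out = self_attention(tokens)
out_shuffled = self_attention(tokens[perm])
order_ignored = np.allclose(out[perm], out_shuffled)

# With positions: the same token in a different slot gets a different output.
out_pos = self_attention(tokens + positions)
out_pos_shuffled = self_attention(tokens[perm] + positions)
order_matters = not np.allclose(out_pos[perm], out_pos_shuffled)

print(order_ignored)  # True
print(order_matters)  # True
```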

If you want to practice, there are many things you could try with it. For example:

* Use a different dataset to train the Transformer. 📚
* Create the "base" or "big" Transformer configurations from the original paper by changing the hyperparameters. 🛠️
* Use the layers defined here to create an implementation of [BERT](https://arxiv.org/abs/1810.04805) 🐻
* Use Beam search to get better predictions. 🔎
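
As a starting point for the beam search exercise, here is a minimal pure-Python sketch on a toy problem. Everything in it is hypothetical: `TOY_MODEL` is a hand-written table of next-token probabilities standing in for the Transformer's output distribution, and the scoring omits refinements such as length normalization. It shows a case where greedy decoding (beam width 1) misses the most probable sequence and a beam width of 2 finds it.

```python
import math

END = "<end>"

# Hypothetical next-token probabilities, keyed by the tokens generated so far.
TOY_MODEL = {
    (): {"a": 0.6, "b": 0.4},
    ("a",): {"x": 0.5, "y": 0.5},
    ("b",): {"z": 0.9, "w": 0.1},
}

def next_token_probs(prefix):
  # Any prefix missing from the table is followed by the end token.
  return TOY_MODEL.get(prefix, {END: 1.0})

def beam_search(beam_width, max_length=5):
  # Each beam is (tokens, total log-probability).
  beams = [((), 0.0)]
  for _ in range(max_length):
    candidates = []
    for prefix, score in beams:
      if prefix and prefix[-1] == END:
        candidates.append((prefix, score))  # finished: carry over unchanged
        continue
      for token, prob in next_token_probs(prefix).items():
        candidates.append((prefix + (token,), score + math.log(prob)))
    # Keep only the `beam_width` highest-scoring candidates.
    candidates.sort(key=lambda c: c[1], reverse=True)
    beams = candidates[:beam_width]
    if all(prefix[-1] == END for prefix, _ in beams):
      break
  return beams[0]

greedy_tokens, greedy_score = beam_search(beam_width=1)
beam_tokens, beam_score = beam_search(beam_width=2)

print(greedy_tokens, round(math.exp(greedy_score), 2))  # ('a', 'x', '<end>') 0.3
print(beam_tokens, round(math.exp(beam_score), 2))      # ('b', 'z', '<end>') 0.36
```

To adapt the idea to this tutorial, `next_token_probs` would be replaced by a softmax over the Transformer's logits for the last position, computed for each beam's token sequence.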

There is a wide variety of Transformer-based models, many of which improve upon the original 2017 Transformer, using encoder-decoder, encoder-only, and decoder-only architectures.

Some of these models are covered in the following research publications:

* ["Efficient Transformers: a survey"](https://arxiv.org/abs/2009.06732) (Tay et al., 2022) ⚡
* ["Formal algorithms for Transformers"](https://arxiv.org/abs/2207.09238) (Phuong and Hutter, 2022) 📐
* [T5 ("Exploring the limits of transfer learning with a unified text-to-text Transformer")](https://arxiv.org/abs/1910.10683) (Raffel et al., 2019) 🔤

You can learn more about other models in the following Google blog posts:

* [PaLM](https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html) 🌴
* [LaMDA](https://ai.googleblog.com/2022/01/lamda-towards-safe-grounded-and-high.html) 🗣️
* [MUM](https://blog.google/products/search/introducing-mum/) 🤹
* [Reformer](https://ai.googleblog.com/2020/01/reformer-efficient-transformer.html) 🔁
* [BERT](https://ai.googleblog.com/2018/11/open-sourcing-bert-state-of-art-pre.html) 🐻

If you're interested in studying how attention-based models have been applied in tasks outside of natural language processing, check out the following resources:

- Vision Transformer (ViT): [Transformers for image recognition at scale](https://ai.googleblog.com/2020/12/transformers-for-image-recognition-at.html) 🖼️
- [Multi-task multitrack music transcription (MT3)](https://magenta.tensorflow.org/transcription-with-transformers) with a Transformer 🎵
- [Code generation with AlphaCode](https://www.deepmind.com/blog/competitive-programming-with-alphacode) 💻
- [Reinforcement learning with multi-game decision Transformers](https://ai.googleblog.com/2022/07/training-generalist-agents-with-multi.html) 🎮
- [Protein structure prediction with AlphaFold](https://www.nature.com/articles/s41586-021-03819-2) 🧬
- [OptFormer: Towards universal hyperparameter optimization with Transformers](http://ai.googleblog.com/2022/08/optformer-towards-universal.html) ⚙️

