Let's import all the necessary libraries.

In [None]:
import os
import pathlib
import random
import string
import re
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import TextVectorization
import tensorflow.strings as tf_strings
import tensorflow.data as tf_data
from sklearn.model_selection import train_test_split


In [None]:
# NOTE(review): Keras reads KERAS_BACKEND when it is first imported, and keras
# was already imported (via `from tensorflow import keras`) in the cell above,
# so this assignment most likely has no effect here. If backend selection
# matters, set it before any keras/tensorflow import.
os.environ["KERAS_BACKEND"] = "tensorflow"

Now let's download an English-to-Spanish translation dataset.

In [None]:
# Download the zipped English-Spanish sentence-pair file and extract it next to
# the archive. HTTPS is used so the archive cannot be tampered with in transit.
text_file = keras.utils.get_file(
    fname="spa-eng.zip",
    origin="https://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip",
    extract=True,
)
# NOTE: assumes get_file returns the path of the downloaded archive (tf.keras 2.x
# behaviour); the extracted data then lives in a "spa-eng" directory beside it.
# Newer Keras versions may return the extraction directory instead - confirm if
# upgrading.
text_file = pathlib.Path(text_file).parent / "spa-eng" / "spa.txt"

Downloading data from http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip


When parsing the data, we treat each line as a pair of sentences: an English sentence and its corresponding Spanish sentence. The English sentence serves as the source sequence, while the Spanish sentence is the target sequence. To prepare the Spanish sentence for training, we add a special token "[start]" at the beginning and "[end]" at the end.


In [None]:
# Read the tab-separated sentence pairs. The encoding is pinned to UTF-8 because
# the Spanish side contains non-ASCII characters (e.g. "mandón") and Python's
# default file encoding is locale-dependent (not UTF-8 everywhere, e.g. Windows).
with open(text_file, "r", encoding="utf-8") as f:
    lines = f.readlines()

# Each line is "<english>\t<spanish>..."; the Spanish sentence is wrapped in the
# "[start]" / "[end]" markers used by the decoder. Blank lines are skipped
# instead of raising IndexError.
text_pairs = []
for line in lines:
    if not line.strip():
        continue
    fields = line.split("\t")
    text_pairs.append((fields[0], "[start] " + fields[1].strip() + " [end]"))

for _ in range(5):
    print(random.choice(text_pairs))

("I'm not bossy.", '[start] No soy mandón. [end]')
('Many people use ATMs to withdraw money.', '[start] Mucha gente usa cajeros automáticos para retirar dinero. [end]')
('Tom became a successful photographer.', '[start] Tom se convirtió en un exitoso fotógrafo. [end]')
('Cows provide us with milk.', '[start] Las vacas nos proveen de leche. [end]')
("It's time to go.", '[start] Es hora de irse. [end]')


This is what our sentence pairs look like.

Let's split the dataset into training, validation, and test sets.

In [None]:
# Hold out 15% of the pairs for validation and another 15% for testing.
val_size = int(0.15 * len(text_pairs))
test_size = val_size

# train_test_split shuffles on its own; with a fixed random_state the split is
# reproducible across kernel restarts. Shuffling text_pairs beforehand with an
# unseeded random.shuffle would defeat that (a different split on every run)
# and would also mutate text_pairs in place, making this cell non-idempotent.
train_pairs, remaining_pairs = train_test_split(
    text_pairs, test_size=val_size + test_size, random_state=42
)
val_pairs, test_pairs = train_test_split(
    remaining_pairs, test_size=test_size, random_state=42
)

print(f"{len(text_pairs)} total pairs")
print(f"{len(train_pairs)} training pairs")
print(f"{len(val_pairs)} validation pairs")
print(f"{len(test_pairs)} test pairs")

118964 total pairs
83276 training pairs
17844 validation pairs
17844 test pairs


To vectorize the text data, we use two TextVectorization layers: one for English and one for Spanish. These layers convert the original strings into integer sequences, with each integer representing the index of a word in a vocabulary.

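The English layer uses the default standardization: it lowercases the text, strips punctuation characters, and splits on whitespace. For the Spanish layer, we use a custom standardization that adds the character "¿" to the set of punctuation characters to be stripped, while keeping "[" and "]" so that the "[start]" and "[end]" tokens survive.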

Note: In a production-grade machine translation model, it's recommended to avoid stripping punctuation characters. Instead, each punctuation character could be turned into its own token by providing a custom split function to the TextVectorization layer.


In [21]:
# Characters removed from Spanish text during standardization: every ASCII
# punctuation character plus the inverted question mark "¿", except the square
# brackets, which must survive so the "[start]" / "[end]" markers stay intact.
strip_chars = "".join(
    char for char in string.punctuation + "¿" if char not in "[]"
)


In [22]:
# Maximum vocabulary size of each TextVectorization layer.
vocab_size = 15_000
# Number of token ids per vectorized English sequence (Spanish gets one extra).
sequence_length = 20
# Number of sentence pairs per batch.
batch_size = 64

In [24]:
def custom_standardization(input_string):
    """Lowercase Spanish text and strip punctuation, including "¿".

    Used as the ``standardize`` callable of the Spanish TextVectorization layer.
    The characters removed are those in the module-level ``strip_chars`` (built
    in an earlier cell from ``string.punctuation`` plus "¿", minus "[" and "]"),
    so the "[start]" / "[end]" markers are preserved.

    Args:
        input_string: string tensor of raw Spanish sentences.

    Returns:
        String tensor of the same shape with the text lowercased and every
        character in ``strip_chars`` removed.
    """
    # encoding="utf-8" lowercases non-ASCII letters too (e.g. "Él" -> "él");
    # with the default encoding only ASCII characters are lowercased.
    lowercase = tf.strings.lower(input_string, encoding="utf-8")
    # Use the full module-level punctuation set. Redefining strip_chars locally
    # would shadow it and strip only "¿", leaving tokens like "mandón.".
    return tf.strings.regex_replace(lowercase, "[%s]" % re.escape(strip_chars), "")

# English uses TextVectorization's default standardization; Spanish uses the
# custom_standardization function. Spanish sequences get one extra position
# because format_dataset later splits them into decoder inputs (all but the
# last id) and targets (all but the first id).
eng_vectorization = TextVectorization(
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=sequence_length,
)
spa_vectorization = TextVectorization(
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=sequence_length + 1,
    standardize=custom_standardization,
)

# Build both vocabularies from the training split only.
train_eng_texts = [english for english, _ in train_pairs]
train_spa_texts = [spanish for _, spanish in train_pairs]
eng_vectorization.adapt(train_eng_texts)
spa_vectorization.adapt(train_spa_texts)

For dataset formatting, at each training step the model predicts target word N+1 from the source sentence and target words 0 to N. The training dataset yields `(inputs, targets)` tuples: `inputs` is a dict holding `encoder_inputs` (the vectorized source sentence) and `decoder_inputs` (target words 0 to N), and `targets` is the target sentence shifted by one step, i.e. the next word at each position.

In [27]:
def format_dataset(eng, spa):
    """Vectorize a batch of sentence pairs into ``(inputs, targets)`` for training.

    The decoder receives the Spanish ids without their last position, and the
    targets are the same ids shifted left by one, so at every position the
    model learns to predict the next token.

    Args:
        eng: batch of English source strings.
        spa: batch of "[start] ... [end]" Spanish target strings.

    Returns:
        ``({"encoder_inputs": ..., "decoder_inputs": ...}, targets)``.
    """
    source_ids = eng_vectorization(eng)
    target_ids = spa_vectorization(spa)
    model_inputs = {
        "encoder_inputs": source_ids,
        "decoder_inputs": target_ids[:, :-1],
    }
    next_token_ids = target_ids[:, 1:]
    return model_inputs, next_token_ids


def make_dataset(pairs, shuffle=True):
    """Build a batched, vectorized tf.data pipeline from sentence pairs.

    Args:
        pairs: non-empty iterable of ``(english, spanish)`` string tuples.
        shuffle: if True (the default, matching the original behaviour), the
            cached batches are reshuffled every epoch. Pass False for
            validation/test sets, where shuffling is unnecessary work.

    Returns:
        A ``tf.data.Dataset`` yielding ``(inputs_dict, targets)`` batches as
        produced by ``format_dataset``.

    Raises:
        ValueError: if ``pairs`` is empty.
    """
    pairs = list(pairs)
    if not pairs:
        raise ValueError("make_dataset() needs at least one (english, spanish) pair")

    eng_texts, spa_texts = zip(*pairs)
    dataset = tf_data.Dataset.from_tensor_slices((list(eng_texts), list(spa_texts)))
    # Vectorize whole batches at once, then cache the vectorized batches so the
    # TextVectorization work is only done during the first pass.
    dataset = dataset.batch(batch_size).map(format_dataset).cache()
    if shuffle:
        # NOTE: this shuffles whole cached batches, so each batch keeps the same
        # members across epochs; only the order of the batches changes.
        dataset = dataset.shuffle(2048)
    return dataset.prefetch(16)

# Training and validation pipelines built from the corresponding splits.
train_ds, val_ds = make_dataset(train_pairs), make_dataset(val_pairs)

These are the sequence shapes: each batch contains 64 pairs, and all sequences are 20 steps long.

In [28]:
# Peek at one batch; `inputs` and `targets` stay defined for the cells below.
for inputs, targets in train_ds.take(1):
    print(
        f'inputs["encoder_inputs"].shape: {inputs["encoder_inputs"].shape}',
        f'inputs["decoder_inputs"].shape: {inputs["decoder_inputs"].shape}',
        f"targets.shape: {targets.shape}",
        sep="\n",
    )

inputs["encoder_inputs"].shape: (64, 20)
inputs["decoder_inputs"].shape: (64, 20)
targets.shape: (64, 20)


In [29]:
# First decoder input of the batch: token ids of the "[start]"-prefixed Spanish
# sentence with its last position dropped; the trailing 0s are presumably padding.
print(inputs["decoder_inputs"][0])

tf.Tensor([ 2 16 26  1  3  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0], shape=(20,), dtype=int64)


In [30]:
# First encoder input of the batch: token ids of the English source sentence,
# presumably zero-padded up to sequence_length (see the trailing 0s below).
print(inputs["encoder_inputs"][0])

tf.Tensor(
[  24   23    5 1314   20    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0], shape=(20,), dtype=int64)


Our sequence-to-sequence Transformer comprises a TransformerEncoder and TransformerDecoder. The TransformerEncoder processes the source sequence to generate a new representation, which is then passed to the TransformerDecoder along with the target sequence up to the current point. The TransformerDecoder predicts the next words in the target sequence, ensuring it only uses information from past tokens via causal masking to prevent future information leakage.

In [31]:
class TransformerEncoder(layers.Layer):
    """One Transformer encoder block.

    Self-attention followed by a two-layer feed-forward projection; each step is
    wrapped in a residual connection and layer normalization.

    Args:
        embed_dim: size of the input/output feature dimension.
        dense_dim: width of the hidden layer of the feed-forward projection.
        num_heads: number of attention heads.
    """

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        # Sub-layers are created in a fixed order; keep it stable so the weight
        # order (and therefore saved checkpoints) does not change.
        self.attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.dense_proj = keras.Sequential(
            [layers.Dense(dense_dim, activation="relu"), layers.Dense(embed_dim)]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, mask=None):
        # A rank-2 padding mask (presumably (batch, seq), truthy at real tokens)
        # becomes (batch, 1, seq) so it broadcasts over query positions.
        attention_mask = (
            None if mask is None else tf.cast(mask[:, None, :], dtype="int32")
        )
        attended = self.attention(
            query=inputs, value=inputs, key=inputs, attention_mask=attention_mask
        )
        residual = self.layernorm_1(inputs + attended)
        return self.layernorm_2(residual + self.dense_proj(residual))

    def get_config(self):
        return {
            **super().get_config(),
            "embed_dim": self.embed_dim,
            "dense_dim": self.dense_dim,
            "num_heads": self.num_heads,
        }


In [32]:
class PositionalEmbedding(layers.Layer):
    """Sum of a learned token embedding and a learned position embedding.

    Args:
        sequence_length: number of positions the position table can index.
        vocab_size: number of token ids the token table can index.
        embed_dim: size of each embedding vector.

    NOTE(review): ``compute_mask`` returns None whenever no upstream mask is
    passed in. When this layer is applied directly to a ``keras.Input`` (as in
    the model-building cell), there is presumably no upstream mask, so padding
    (id 0) is never masked downstream - confirm. Before making it always return
    ``inputs != 0``, check how TransformerDecoder uses that mask in its
    cross-attention.
    """

    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        # Token table first, then position table: keep this order so the
        # weight order stays the same.
        self.token_embeddings = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

    def call(self, inputs):
        # Position ids 0..len-1 along the last axis; their (len, embed_dim)
        # embeddings broadcast against the token embeddings.
        positions = tf.range(tf.shape(inputs)[-1])
        return self.token_embeddings(inputs) + self.position_embeddings(positions)

    def compute_mask(self, inputs, mask=None):
        return None if mask is None else tf.not_equal(inputs, 0)

    def get_config(self):
        return {
            **super().get_config(),
            "sequence_length": self.sequence_length,
            "vocab_size": self.vocab_size,
            "embed_dim": self.embed_dim,
        }


In [33]:
class TransformerDecoder(layers.Layer):
    """One Transformer decoder block.

    Causal self-attention over the target sequence, cross-attention over the
    encoder outputs, and a two-layer feed-forward projection; each step is
    followed by a residual connection and layer normalization.

    Args:
        embed_dim: size of the input/output feature dimension.
        latent_dim: width of the hidden layer of the feed-forward projection.
        num_heads: number of attention heads in both attention layers.
    """

    def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(latent_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, encoder_outputs, mask=None):
        """Decode ``inputs`` while attending to ``encoder_outputs``.

        Args:
            inputs: embedded target sequence, assumed (batch, target_len, embed_dim).
            encoder_outputs: encoder representation, assumed
                (batch, source_len, embed_dim).
            mask: optional rank-2 padding mask for ``inputs`` (truthy at real
                tokens).

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        causal_mask = self.get_causal_attention_mask(inputs)
        if mask is not None:
            # Self-attention must respect both causality and target padding.
            padding_mask = tf.cast(mask[:, None, :], dtype="int32")
            self_attention_mask = tf.minimum(padding_mask, causal_mask)
        else:
            self_attention_mask = causal_mask

        attention_output_1 = self.attention_1(
            query=inputs, value=inputs, key=inputs, attention_mask=self_attention_mask
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        # Cross-attention: every target position may look at every source
        # position. The causal / target-padding masks index the *target* axis,
        # so they must not be applied to the source keys here (doing so would
        # hide source tokens causally and break when the lengths differ).
        # NOTE(review): source padding is not masked explicitly here; Keras
        # MultiHeadAttention may pick it up from a Keras mask attached to
        # encoder_outputs - confirm if that matters.
        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        proj_output = self.dense_proj(out_2)
        return self.layernorm_3(out_2 + proj_output)

    def get_causal_attention_mask(self, inputs):
        """Return a (batch, seq, seq) int32 lower-triangular mask.

        Entry [b, i, j] is 1 when j <= i, i.e. position i may attend to j.
        """
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = tf.range(sequence_length)[:, None]
        j = tf.range(sequence_length)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, sequence_length, sequence_length))
        # Tile the single (1, seq, seq) mask across the batch dimension.
        mult = tf.concat(
            [tf.expand_dims(batch_size, -1), tf.convert_to_tensor([1, 1])],
            axis=0,
        )
        return tf.tile(mask, mult)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "latent_dim": self.latent_dim,
                "num_heads": self.num_heads,
            }
        )
        return config

In [35]:
# Model hyperparameters: embedding size, feed-forward hidden width (used as the
# encoder's dense_dim and the decoder's latent_dim), and attention heads.
embed_dim = 256
latent_dim = 2048
num_heads = 8

# Variable-length token-id inputs for the source (English) and target
# (Spanish) sequences; names match the keys produced by format_dataset.
transformer_encoder_inputs = keras.Input(shape=(None,), dtype="int64", name="encoder_inputs")
transformer_decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs")
# Encoder: token ids -> positional embeddings -> one TransformerEncoder block.
x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(transformer_encoder_inputs)
transformer_encoder_outputs = TransformerEncoder(embed_dim, latent_dim, num_heads)(x)
# Decoder: target ids -> positional embeddings -> TransformerDecoder attending to
# an encoder representation supplied through `decoder_state_inputs`, presumably
# so the decoder can also be run as a standalone model.
x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(transformer_decoder_inputs)
decoder_state_inputs = keras.Input(shape=(None, embed_dim), name="decoder_state_inputs")
x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, decoder_state_inputs)
x = layers.Dropout(0.5)(x)
# Per-position softmax over the target (Spanish) vocabulary.
transformer_decoder_outputs = layers.Dense(vocab_size, activation="softmax")(x)

transformer_encoder = keras.Model(transformer_encoder_inputs, transformer_encoder_outputs)
transformer_decoder = keras.Model([transformer_decoder_inputs, decoder_state_inputs], transformer_decoder_outputs)

# End-to-end training model: feed the encoder outputs into the decoder model.
decoder_outputs = transformer_decoder([transformer_decoder_inputs, transformer_encoder_outputs])
transformer = keras.Model(
    [transformer_encoder_inputs, transformer_decoder_inputs], decoder_outputs, name="transformer"
)


In [36]:
from tensorflow.python.ops import math_ops as ops


In [37]:
epochs = 3

transformer.summary()
transformer.compile(
    "rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)
# Keep the History object (per-epoch loss/accuracy) for later inspection or
# plotting; assigning it also suppresses the bare History repr in the output.
# NOTE(review): unless a padding mask reaches the loss, padded target
# positions (id 0) presumably count toward loss and accuracy - confirm.
history = transformer.fit(train_ds, epochs=epochs, validation_data=val_ds)

Model: "transformer"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 encoder_inputs (InputLayer  [(None, None)]               0         []                            
 )                                                                                                
                                                                                                  
 positional_embedding_2 (Po  (None, None, 256)            3845120   ['encoder_inputs[0][0]']      
 sitionalEmbedding)                                                                               
                                                                                                  
 decoder_inputs (InputLayer  [(None, None)]               0         []                            
 )                                                                                      

<keras.src.callbacks.History at 0x7db95bb19090>