In [None]:
import pathlib
import re
import functools

import numpy as np
import matplotlib.pyplot as plt

import tensorflow_datasets as tfds
import tensorflow as tf

import tensorflow_text as text
from tensorflow_text.tools.wordpiece_vocab import bert_vocab_from_dataset as bert_vocab

Data Handling

In [None]:
# Load the TED Portuguese->English translation corpus.
# as_supervised=True yields each element as a (source, target) tuple;
# with_info=True additionally returns the dataset metadata object.
examples, metadata = tfds.load(
    'ted_hrlr_translate/pt_to_en',
    as_supervised=True,
    with_info=True,
)

train_ex = examples['train']
valid_ex = examples['validation']

Data Insights

In [None]:
# Peek at the first three sentence pairs; the raw elements are byte strings,
# so they are decoded before printing.
for pt_examples, en_examples in train_ex.batch(3).take(1):
  sections = (('Portuguese:', pt_examples), ('English:', en_examples))
  for position, (heading, sentences) in enumerate(sections):
    if position:
      print()  # blank line between the two language sections
    print(heading)
    for raw in sentences.numpy():
      print(raw.decode('utf-8'))

Tokenizer

In [None]:
# Split the paired dataset into single-language datasets for vocab building.
# Each element is a (portuguese, english) pair, in that order.
train_en = train_ex.map(lambda _pt_text, en_text: en_text)
train_pt = train_ex.map(lambda pt_text, _en_text: pt_text)

In [None]:
# Tokens that always occupy the first vocabulary ids, in this order.
reserved_tokens = ["[PAD]", "[UNK]", "[START]", "[END]"]

# Options forwarded to the BERT tokenizer used while learning the vocabulary.
bert_tokenizer_params = {'lower_case': True}

# Keyword arguments shared by both vocabulary-generation calls.
bert_vocab_args = {
    'vocab_size': 8000,
    'reserved_tokens': reserved_tokens,
    'bert_tokenizer_params': bert_tokenizer_params,
    'learn_params': {},
}

In [None]:
# Learn one wordpiece vocabulary per language from the training split.
# Large batches plus a small prefetch keep the vocabulary builder fed.
pt_vocab_batches = train_pt.batch(1000).prefetch(2)
pt_vocab = bert_vocab.bert_vocab_from_dataset(pt_vocab_batches, **bert_vocab_args)

en_vocab_batches = train_en.batch(1000).prefetch(2)
en_vocab = bert_vocab.bert_vocab_from_dataset(en_vocab_batches, **bert_vocab_args)

In [None]:
def write_vocab_file(filepath, vocab):
  """Write a vocabulary to disk, one token per line, encoded as UTF-8.

  Args:
    filepath: Destination path; an existing file is overwritten.
    vocab: Iterable of tokens, written in iteration order.
  """
  # Pin the encoding instead of relying on the platform default: the
  # vocabulary contains non-ASCII (Portuguese) tokens, and a locale-dependent
  # encoding would make the file unreadable or mis-decoded elsewhere.
  with open(filepath, 'w', encoding='utf-8') as f:
    for token in vocab:
      print(token, file=f)


# Persist both vocabularies so the tokenizers can be built from files.
write_vocab_file(filepath='pt_vocab.txt', vocab=pt_vocab)
write_vocab_file(filepath='en_vocab.txt', vocab=en_vocab)

In [None]:
# Token ids of the sentence markers: their positions in `reserved_tokens`.
_reserved_tensor = tf.constant(reserved_tokens)
START = tf.argmax(tf.equal(_reserved_tensor, "[START]"))
END = tf.argmax(tf.equal(_reserved_tensor, "[END]"))

def adding_start_end(ragged):
  """Wrap every row of token ids with the START and END ids.

  Args:
    ragged: Batch of token-id rows (used here with a RaggedTensor).

  Returns:
    The input with one START id prepended and one END id appended per row.
  """
  n_rows = ragged.bounding_shape()[0]
  column_shape = [n_rows, 1]  # one marker per row
  return tf.concat(
      [tf.fill(column_shape, START), ragged, tf.fill(column_shape, END)],
      axis=1)

def cleanup_text(reserved_tokens, token_txt):
  """Drop reserved marker tokens and join the remaining words per sentence.

  Args:
    reserved_tokens: Reserved token strings; all except "[UNK]" are removed.
    token_txt: Word strings grouped per sentence along the last axis.

  Returns:
    One space-joined string per sentence.
  """
  # "[UNK]" is deliberately kept so unknown words remain visible in the text.
  pattern = "|".join(
      re.escape(tok) for tok in reserved_tokens if tok != "[UNK]")

  is_reserved = tf.strings.regex_full_match(token_txt, pattern)
  kept_words = tf.ragged.boolean_mask(token_txt, ~is_reserved)
  return tf.strings.reduce_join(kept_words, separator=' ', axis=-1)


In [None]:
class CustomTokenizer(tf.Module):
  """BERT wordpiece tokenizer wrapper that adds START/END markers.

  Holds a `text.BertTokenizer` built from a vocabulary file plus the same
  vocabulary as a string variable for id -> token lookups.
  """

  def __init__(self, reserved_tokens, vocab_path):
    """Build the tokenizer.

    Args:
      reserved_tokens: List of reserved token strings.
      vocab_path: Path to a vocabulary file with one token per line.
    """
    self.tokenizer = text.BertTokenizer(vocab_path, lower_case=True)
    self._reserved_tokens = reserved_tokens
    self._vocab_path = tf.saved_model.Asset(vocab_path)

    # Read the vocabulary explicitly as UTF-8 rather than with the platform
    # default encoding, so accented tokens in the lookup table are decoded
    # consistently regardless of locale.
    vocab = pathlib.Path(vocab_path).read_text(encoding='utf-8').splitlines()
    self.vocab = tf.Variable(vocab)

  @tf.function
  def tokenize(self, strings):
    """Convert strings to token ids wrapped in START/END markers."""
    enc = self.tokenizer.tokenize(strings)
    # Merge the word and word-piece axes into a single token axis.
    enc = enc.merge_dims(-2, -1)
    enc = adding_start_end(enc)
    return enc

  @tf.function
  def detokenize(self, tokenized):
    """Convert token ids back to text with reserved markers removed."""
    words = self.tokenizer.detokenize(tokenized)
    return cleanup_text(self._reserved_tokens, words)

  @tf.function
  def lookup(self, token_ids):
    """Map token ids to their vocabulary strings."""
    return tf.gather(self.vocab, token_ids)

  @tf.function
  def get_vocab_size(self):
    """Return the number of vocabulary entries."""
    return tf.shape(self.vocab)[0]

  @tf.function
  def get_reserved_tokens(self):
    """Return the reserved tokens as a string tensor."""
    return tf.constant(self._reserved_tokens)


In [None]:
# One tokenizer per language, each backed by the vocabulary file written above.
pt_tokenizer = CustomTokenizer(reserved_tokens=reserved_tokens,
                               vocab_path='pt_vocab.txt')
en_tokenizer = CustomTokenizer(reserved_tokens=reserved_tokens,
                               vocab_path='en_vocab.txt')

In [None]:
# Round-trip one small batch of English sentences through the tokenizer.
for pt_examples, en_examples in train_ex.batch(3).take(1):
  # Tokenize once and reuse the result for every view below, instead of
  # running the tokenizer a second time just to print it.
  tokenized = en_tokenizer.tokenize(en_examples)
  print('Tokenized using en_tokenizer:')
  print(tokenized)
  print('Detokenized using en_tokenizer:')
  print(en_tokenizer.detokenize(tokenized))
  print('Text split into tokens:')
  print(en_tokenizer.lookup(tokenized))


Data Pipeline

In [None]:
# Upper bound on the number of tokens kept per sentence.
MAX_TOKENS = 128

def prepare_batch(pt, en):
  """Tokenize a batch of sentence pairs into model inputs and labels.

  Args:
    pt: Batch of Portuguese strings.
    en: Batch of English strings.

  Returns:
    ((pt_ids, en_inputs), en_labels), all padded dense tensors. The English
    inputs drop the last token and the labels drop the first, so each label
    is the token that follows the corresponding input position.
  """
  pt_ids = pt_tokenizer.tokenize(pt)[:, :MAX_TOKENS].to_tensor()

  # Keep one extra token so inputs and labels both have up to MAX_TOKENS.
  en_ids = en_tokenizer.tokenize(en)[:, :(MAX_TOKENS + 1)]
  decoder_inputs = en_ids[:, :-1].to_tensor()
  decoder_labels = en_ids[:, 1:].to_tensor()

  return (pt_ids, decoder_inputs), decoder_labels

In [None]:
# Input-pipeline settings used by make_batches.
BATCH_SIZE = 64         # sentence pairs per batch
BUFFER_SIZE = 20_000    # shuffle buffer size, in examples

In [None]:
def make_batches(ds):
  """Turn a dataset of raw sentence pairs into shuffled, tokenized batches.

  Args:
    ds: Dataset of (portuguese, english) string pairs.

  Returns:
    A prefetched dataset of ((pt_ids, en_inputs), en_labels) batches.
  """
  shuffled = ds.shuffle(BUFFER_SIZE)
  batched = shuffled.batch(BATCH_SIZE)
  # Tokenize after batching so the tokenizer runs on whole batches.
  tokenized = batched.map(prepare_batch, num_parallel_calls=tf.data.AUTOTUNE)
  return tokenized.prefetch(tf.data.AUTOTUNE)

In [None]:
# Build the training and validation input pipelines.
train_batches, val_batches = make_batches(train_ex), make_batches(valid_ex)

Transformer

Positional Embedding

In [None]:
def positional_encoding(length, depth):
  """Build a sinusoidal positional-encoding table.

  Args:
    length: Number of positions to encode.
    depth: Encoding width; half the columns are sines, half cosines.

  Returns:
    A float32 tensor with one row per position; the sine block is
    concatenated before the cosine block along the last axis.
  """
  half_depth = depth/2

  position_col = np.arange(length)[:, np.newaxis]
  depth_fraction = np.arange(half_depth)[np.newaxis, :]/half_depth

  # Frequencies fall geometrically from 1 towards 1/10000 across the columns.
  rates = 1 / (10000**depth_fraction)
  radians = position_col * rates

  table = np.concatenate([np.sin(radians), np.cos(radians)], axis=-1)
  return tf.cast(table, dtype=tf.float32)

In [None]:
class PositionalEmbedding(tf.keras.layers.Layer):
  """Token embedding scaled by sqrt(d_model) plus a positional encoding."""

  def __init__(self, vocab_size, d_model):
    """Create the embedding table and precompute the positional encoding.

    Args:
      vocab_size: Number of token ids in the vocabulary.
      d_model: Embedding width.
    """
    super().__init__()
    self.d_model = d_model
    # mask_zero=True marks id 0 (padding) as masked.
    self.embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True)
    self.positional_encoding = positional_encoding(length=2048, depth=d_model)

  # This must be a method of the layer, not a function nested in __init__:
  # Keras only consults `compute_mask` when it is defined on the class, and
  # without it the embedding's padding mask is not propagated downstream.
  def compute_mask(self, *args, **kwargs):
    """Delegate mask computation to the wrapped embedding layer."""
    return self.embedding.compute_mask(*args, **kwargs)

  def call(self, x):
    """Embed token ids and add the positional encoding.

    Args:
      x: Token ids; axis 1 is treated as the sequence axis.

    Returns:
      The scaled embeddings with positional encodings added.
    """
    length = tf.shape(x)[1]
    x = self.embedding(x)
    # Scale embeddings by sqrt(d_model) before adding the encoding.
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x = x + self.positional_encoding[tf.newaxis, :length, :]
    return x

Attention Blocks

- Base Attention Layer

In [None]:
class BaseAttention(tf.keras.layers.Layer):
  """Shared building blocks for the attention layers.

  Owns a multi-head attention layer, a layer normalization and an add layer;
  subclasses define `call`. All keyword arguments are forwarded to
  `MultiHeadAttention`.
  """

  def __init__(self, **kwargs):
    super().__init__()
    keras_layers = tf.keras.layers
    self.mha = keras_layers.MultiHeadAttention(**kwargs)
    self.layernorm = keras_layers.LayerNormalization()
    self.add = keras_layers.Add()



- Cross Attention

In [None]:
class CrossAttention(BaseAttention):
  """Attention from a query sequence onto a separate context sequence."""

  def call(self, x, context):
    """Attend from `x` to `context`, then add the residual and normalize.

    The attention scores of the most recent call are kept on `self.scores`.
    """
    attended, attention_scores = self.mha(
        query=x,
        value=context,
        key=context,
        return_attention_scores=True)

    # Cache the scores so callers can inspect them after the forward pass.
    self.scores = attention_scores

    return self.layernorm(self.add([x, attended]))

- Global Self Attention

In [None]:
class GlobalSelfAttention(BaseAttention):
  """Self-attention without a causal mask."""

  def call(self, x):
    """Self-attend over `x`, then add the residual and normalize."""
    attended = self.mha(query=x, key=x, value=x)
    residual = self.add([x, attended])
    return self.layernorm(residual)

- Causal Self Attention (Masked Multi-head Attention)

In [None]:
class CausalSelfAttention(BaseAttention):
  """Self-attention that requests a causal mask from the attention layer."""

  def call(self, x):
    """Causally self-attend over `x`, then add the residual and normalize."""
    attended = self.mha(query=x, key=x, value=x, use_causal_mask=True)
    residual = self.add([x, attended])
    return self.layernorm(residual)

Feed Forward Network

In [None]:
class FeedForward(tf.keras.layers.Layer):
  """Two-layer dense network with dropout, a residual add and layer norm."""

  def __init__(self, d_model, dff, dropout_rate=0.1):
    """Create the sublayers.

    Args:
      d_model: Output width of the second dense layer.
      dff: Width of the hidden ReLU dense layer.
      dropout_rate: Dropout rate applied after the second dense layer.
    """
    super().__init__()
    keras_layers = tf.keras.layers
    self.seq = tf.keras.Sequential([
        keras_layers.Dense(dff, activation='relu'),
        keras_layers.Dense(d_model),
        keras_layers.Dropout(dropout_rate),
    ])
    self.add = keras_layers.Add()
    self.layer_norm = keras_layers.LayerNormalization()

  def call(self, x):
    """Apply the network with a residual connection, then normalize."""
    transformed = self.seq(x)
    return self.layer_norm(self.add([x, transformed]))

Encoder

In [None]:
class EncoderLayer(tf.keras.layers.Layer):
  """One encoder block: global self-attention followed by a feed-forward net."""

  def __init__(self,*, d_model, num_heads, dff, dropout_rate=0.1):
    """Create the sublayers.

    Args:
      d_model: Model width; also passed as the attention `key_dim`.
      num_heads: Number of attention heads.
      dff: Hidden width of the feed-forward network.
      dropout_rate: Dropout rate for both the attention and the feed-forward
        sublayers.
    """
    super().__init__()

    self.self_attention = GlobalSelfAttention(
        num_heads=num_heads,
        key_dim=d_model,
        dropout=dropout_rate
    )

    # Forward dropout_rate so a non-default value reaches the feed-forward
    # sublayer too instead of silently falling back to its own default.
    self.ffn = FeedForward(d_model, dff, dropout_rate=dropout_rate)

  def call(self, x):
    """Run self-attention and then the feed-forward network on `x`."""
    x = self.self_attention(x)
    x = self.ffn(x)
    return x

In [None]:
class Encoder(tf.keras.layers.Layer):
  """Positional embedding, dropout, then a stack of encoder layers."""

  def __init__(self,*, num_layers, d_model, num_heads,
               vocab_size, dff, dropout_rate=0.1):
    """Create the embedding, the layer stack and the dropout layer.

    Args:
      num_layers: Number of encoder layers in the stack.
      d_model: Model width.
      num_heads: Number of attention heads per layer.
      vocab_size: Size of the input vocabulary.
      dff: Hidden width of each feed-forward network.
      dropout_rate: Dropout rate.
    """
    super().__init__()
    self.d_model = d_model
    self.num_layers = num_layers

    self.pos_embedding = PositionalEmbedding(vocab_size, d_model)

    layer_kwargs = dict(d_model=d_model, num_heads=num_heads, dff=dff,
                        dropout_rate=dropout_rate)
    self.enc_layers = [EncoderLayer(**layer_kwargs) for _ in range(num_layers)]
    self.dropout = tf.keras.layers.Dropout(dropout_rate)

  def call(self, x):
    """Encode a batch of token ids."""
    hidden = self.dropout(self.pos_embedding(x))

    for encoder_layer in self.enc_layers:
      hidden = encoder_layer(hidden)

    return hidden

Decoder

In [None]:
class DecoderLayer(tf.keras.layers.Layer):
  """One decoder block: causal self-attention, cross-attention, feed-forward."""

  def __init__(self,*, d_model, num_heads, dff, dropout_rate=0.1):
    """Create the sublayers.

    Args:
      d_model: Model width; also passed as the attention `key_dim`.
      num_heads: Number of attention heads.
      dff: Hidden width of the feed-forward network.
      dropout_rate: Dropout rate for the attention and feed-forward sublayers.
    """
    super().__init__()

    self.causal_self_attention = CausalSelfAttention(
        num_heads=num_heads,
        key_dim=d_model,
        dropout=dropout_rate)

    self.cross_attention = CrossAttention(
        num_heads=num_heads,
        key_dim=d_model,
        dropout=dropout_rate)

    # Forward dropout_rate so a non-default value reaches the feed-forward
    # sublayer too instead of silently falling back to its own default.
    self.ffn = FeedForward(d_model, dff, dropout_rate=dropout_rate)

  def call(self, x, context):
    """Decode `x` while attending to the encoder output `context`."""
    x = self.causal_self_attention(x=x)
    x = self.cross_attention(x=x, context=context)

    # Expose the latest cross-attention scores for later inspection.
    self.attn_scores = self.cross_attention.scores

    x = self.ffn(x)
    return x


In [None]:
class Decoder(tf.keras.layers.Layer):
  """Positional embedding, dropout, then a stack of decoder layers."""

  def __init__(self,*, d_model, num_layers, num_heads, dff, vocab_size,
               dropout_rate=0.1):
    """Create the embedding, the dropout layer and the layer stack.

    Args:
      d_model: Model width.
      num_layers: Number of decoder layers in the stack.
      num_heads: Number of attention heads per layer.
      dff: Hidden width of each feed-forward network.
      vocab_size: Size of the target vocabulary.
      dropout_rate: Dropout rate.
    """
    super().__init__()

    self.d_model = d_model
    self.num_layers = num_layers
    self.pos_embedding = PositionalEmbedding(vocab_size, d_model)
    self.dropout = tf.keras.layers.Dropout(dropout_rate)

    layer_kwargs = dict(d_model=d_model, num_heads=num_heads, dff=dff,
                        dropout_rate=dropout_rate)
    self.dec_layers = [DecoderLayer(**layer_kwargs) for _ in range(num_layers)]

    # Filled in by call() with the last layer's cross-attention scores.
    self.attn_scores = None

  def call(self, x, context):
    """Decode target token ids `x` given the encoder output `context`."""
    hidden = self.dropout(self.pos_embedding(x))

    for decoder_layer in self.dec_layers:
      hidden = decoder_layer(hidden, context)

    self.attn_scores = self.dec_layers[-1].attn_scores
    return hidden

Transformer

In [None]:
# Model hyperparameters.
d_model = 128        # embedding / model width
num_heads = 8        # attention heads per attention layer
dff = 512            # hidden width of each feed-forward network
num_layers = 4       # encoder layers and decoder layers
dropout_rate = 0.1

In [None]:
class Transformer(tf.keras.Model):
  """Encoder-decoder model ending in a dense projection onto the target vocab."""

  def __init__(self, *, num_layers, d_model, num_heads, dff,
               input_vocab_size, target_vocab_size, dropout_rate=0.1):
    """Create the encoder, the decoder and the output projection.

    Args:
      num_layers: Number of layers in both the encoder and the decoder.
      d_model: Model width.
      num_heads: Number of attention heads.
      dff: Hidden width of the feed-forward networks.
      input_vocab_size: Size of the source vocabulary.
      target_vocab_size: Size of the target vocabulary.
      dropout_rate: Dropout rate.
    """
    super().__init__()
    # Settings common to the encoder and the decoder.
    shared_kwargs = dict(num_layers=num_layers, d_model=d_model,
                         num_heads=num_heads, dff=dff,
                         dropout_rate=dropout_rate)

    self.encoder = Encoder(vocab_size=input_vocab_size, **shared_kwargs)
    self.decoder = Decoder(vocab_size=target_vocab_size, **shared_kwargs)
    self.final_layer = tf.keras.layers.Dense(target_vocab_size)

  def call(self, inputs):
    """Compute target-vocabulary logits.

    Args:
      inputs: A (context, x) pair: source token ids and target token ids.

    Returns:
      The logits produced by the final dense layer.
    """
    source_ids, target_ids = inputs
    encoded = self.encoder(source_ids)
    decoded = self.decoder(target_ids, encoded)
    logits = self.final_layer(decoded)

    # Remove any attached Keras mask so it is not applied again to the
    # loss/metrics; the tensor may not carry one, which is fine.
    try:
      del logits._keras_mask
    except AttributeError:
      pass

    return logits


Compiling the Transformer

- Custom Learning Rate Scheduler

In [None]:
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Learning rate with linear warmup followed by inverse-sqrt decay.

  rate = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
  """

  def __init__(self, d_model, warmup_steps=3000):
    """Store the model width (as float32) and the warmup length."""
    super().__init__()
    self.d_model = tf.cast(d_model, tf.float32)
    self.warmup_steps = warmup_steps

  def __call__(self, step):
    """Return the learning rate for optimizer step `step`."""
    step_f = tf.cast(step, tf.float32)
    decay_term = tf.math.rsqrt(step_f)
    warmup_term = step_f * (self.warmup_steps ** -1.5)
    model_scale = tf.math.rsqrt(self.d_model)

    return model_scale * tf.math.minimum(decay_term, warmup_term)

In [None]:
# Adam driven by the warmup/decay schedule defined above.
learning_rate = CustomSchedule(d_model)

optimizer = tf.keras.optimizers.Adam(
    learning_rate=learning_rate,
    beta_1=0.9,
    beta_2=0.98,
    epsilon=1e-9,
)

- Loss and Metrics

In [None]:
def masked_loss(label, pred):
  """Sparse categorical cross-entropy averaged over non-padding positions.

  Args:
    label: Target token ids; id 0 is treated as padding and ignored.
    pred: Logits over the target vocabulary.

  Returns:
    Summed per-token loss at non-padding positions divided by their count.
  """
  per_token_loss = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction='none')(label, pred)

  # 1.0 where the label is a real token, 0.0 where it is padding.
  keep = tf.cast(label != 0, per_token_loss.dtype)

  return tf.reduce_sum(per_token_loss * keep) / tf.reduce_sum(keep)

In [None]:
class MaskedAccuracy(tf.keras.metrics.Metric):
  """Token accuracy that ignores positions whose label is 0 (padding)."""

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    # Running counts of correct tokens and of non-padding tokens.
    self.total_match = self.add_weight(name='total_match', initializer='zeros')
    self.total_mask = self.add_weight(name='total_mask', initializer='zeros')

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulate counts from one batch; `sample_weight` is not used."""
    predicted = tf.argmax(y_pred, axis=2)
    target = tf.cast(y_true, predicted.dtype)

    not_padding = tf.not_equal(target, 0)
    correct = tf.logical_and(tf.equal(target, predicted), not_padding)

    self.total_match.assign_add(tf.reduce_sum(tf.cast(correct, tf.float32)))
    self.total_mask.assign_add(tf.reduce_sum(tf.cast(not_padding, tf.float32)))

  def result(self):
    """Return correct tokens divided by non-padding tokens seen so far."""
    return self.total_match / self.total_mask

  def reset_state(self):
    """Zero both running counts."""
    for counter in (self.total_match, self.total_mask):
      counter.assign(0.0)

In [None]:
# Instantiate the model; vocabulary sizes come from the two tokenizers.
transformer = Transformer(
    input_vocab_size=pt_tokenizer.get_vocab_size().numpy(),
    target_vocab_size=en_tokenizer.get_vocab_size().numpy(),
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    dff=dff,
    dropout_rate=dropout_rate,
)

In [None]:
# Train with the padding-aware loss and accuracy defined above.
transformer.compile(
    optimizer=optimizer,
    loss=masked_loss,
    metrics=[MaskedAccuracy()],
)


In [None]:
# Train for 20 epochs, evaluating on the validation batches each epoch.
transformer.fit(
    x=train_batches,
    validation_data=val_batches,
    epochs=20,
)


Translator

In [None]:
class Translator(tf.Module):
  """Greedy Portuguese -> English translator built on a trained transformer."""

  def __init__(self, tokenizers, transformer):
    """Store the tokenizers and the model.

    Args:
      tokenizers: Sequence of (portuguese_tokenizer, english_tokenizer).
      transformer: Model called as transformer([encoder_input, output]).
    """
    self.pt_tokenizer = tokenizers[0]
    self.en_tokenizer = tokenizers[1]
    # The model must be kept on the instance: __call__ uses self.transformer.
    self.transformer = transformer

  def __call__(self, sentence, max_length=MAX_TOKENS):
    """Translate one sentence by greedy decoding.

    Args:
      sentence: A string tf.Tensor; a scalar is expanded to a batch of one.
      max_length: Maximum number of tokens to generate.

    Returns:
      (text, tokens, attention_weights): the detokenized translation, its
      token strings, and the decoder's last cross-attention scores.
    """
    assert isinstance(sentence, tf.Tensor)
    if len(sentence.shape) == 0:
      sentence = sentence[tf.newaxis]

    # The tokenizer is a module exposing `.tokenize`; it is not callable.
    encoder_input = self.pt_tokenizer.tokenize(sentence).to_tensor()

    # Tokenizing an empty string yields just the [START] and [END] ids.
    start_end = self.en_tokenizer.tokenize([''])[0]
    start = start_end[0][tf.newaxis]
    end = start_end[1][tf.newaxis]

    output_array = tf.TensorArray(dtype=tf.int64, size=0, dynamic_size=True)
    output_array = output_array.write(0, start)

    for i in range(max_length):
      output = tf.transpose(output_array.stack())
      predictions = self.transformer([encoder_input, output], training=False)

      # Slice with -1: (not -1) to keep the sequence axis, so the element
      # written below has the same shape as the start token written at index 0.
      predictions = predictions[:, -1:, :]

      predicted_id = tf.argmax(predictions, axis=-1)
      output_array = output_array.write(i+1, predicted_id[0])

      if predicted_id == end:
        break

    output = tf.transpose(output_array.stack())
    # Use the tokenizer held by this instance rather than a notebook global.
    translated_text = self.en_tokenizer.detokenize(output)[0]
    tokens = self.en_tokenizer.lookup(output)[0]

    # Run the model once more on the finished sequence so the stored
    # attention scores cover the whole output rather than the last loop step.
    self.transformer([encoder_input, output[:,:-1]], training=False)
    attention_weights = self.transformer.decoder.attn_scores

    return translated_text, tokens, attention_weights

In [None]:
# Wrap the trained model and both tokenizers for inference.
translator = Translator(tokenizers=[pt_tokenizer, en_tokenizer],
                        transformer=transformer)

Translation

In [None]:
# Translate a sample Portuguese sentence.
sentence = 'a tradução é facilitada usando máquina'

translated_text, translated_tokens, attention_weights = translator(
    sentence=tf.constant(sentence))