# Transformers
In this notebook, we will introduce and implement the individual building blocks of the transformer. Later notebooks will re-use these building blocks for several applications. 

This notebook is rather technical and therefore optional.

In [None]:
import logging
import time

import numpy as np
import matplotlib.pyplot as plt

import tensorflow_datasets as tfds
import tensorflow as tf

## Positional Encoding

In [None]:
def positional_encoding(length, depth):
    """Return a sinusoidal positional-encoding table.

    The first half of the columns holds sines and the second half cosines
    (concatenated, not interleaved). Column frequencies decrease
    geometrically with base 10000.

    Args:
        length: number of positions (rows of the table).
        depth: width of the encoding (columns), typically the embedding size.

    Returns:
        A float32 tensor of shape (length, depth).
    """
    half_depth = depth / 2

    positions = np.arange(length)[:, np.newaxis]                # (length, 1)
    depths = np.arange(half_depth)[np.newaxis, :] / half_depth  # (1, ceil(depth/2))

    angle_rates = 1 / (10000**depths)
    angle_rads = positions * angle_rates

    pos_encoding = np.concatenate(
        [np.sin(angle_rads), np.cos(angle_rads)],
        axis=-1)
    # For odd `depth`, sin and cos each get ceil(depth/2) columns, one too
    # many in total; trim so the table width always equals `depth` (no-op
    # for even depths).
    pos_encoding = pos_encoding[:, :int(depth)]

    return tf.cast(pos_encoding, dtype=tf.float32)

## Positional Embedding
The positional embedding combines a regular embedding layer (as we have already seen in the context of bag-of-words models) with the positional encoding, which adds information about the *location* of each word in its context.

In [None]:
class PositionalEmbedding(tf.keras.layers.Layer):
    """Token embedding plus a fixed sinusoidal positional encoding.

    Args:
        vocab_size: size of the token vocabulary.
        embed_dim: embedding width; also the width of the positional encoding.
        max_length: number of positions precomputed in the encoding table.
            Input sequences must not be longer than this. Defaults to 2048.
    """

    def __init__(self, vocab_size, embed_dim, max_length=2048):
        super().__init__()
        self.embed_dim = embed_dim
        # mask_zero=True: token id 0 is treated as padding and yields a mask.
        self.embedding = tf.keras.layers.Embedding(vocab_size, embed_dim, mask_zero=True)
        self.pos_encoding = positional_encoding(length=max_length, depth=embed_dim)

    def compute_mask(self, *args, **kwargs):
        # Expose the embedding's padding mask so downstream layers receive it.
        return self.embedding.compute_mask(*args, **kwargs)

    def call(self, x):
        # x: token ids; dimension 1 is taken as the sequence length.
        length = tf.shape(x)[1]
        x = self.embedding(x)
        # This factor sets the relative scale of the embedding and the
        # positional encoding.
        x *= tf.math.sqrt(tf.cast(self.embed_dim, tf.float32))
        x = x + self.pos_encoding[tf.newaxis, :length, :]
        return x

## Attention Blocks
Attention blocks combine a `MultiHeadAttention` layer with an add & norm block, which adds the output of the attention layer to the pass-through (skip) connection and applies layer normalization to the result. We define a `BaseAttention` class that holds these components as attributes. Afterwards, we derive several subclasses of `BaseAttention` that implement the different variants of attention.

In [None]:
class BaseAttention(tf.keras.layers.Layer):
    """Shared components for the attention blocks in this notebook.

    Holds a MultiHeadAttention layer plus the Add and LayerNormalization
    layers used for the residual "add & norm" step. Subclasses provide
    ``call``.
    """

    def __init__(self, **kwargs):
        # All keyword arguments (e.g. num_heads, key_dim, dropout) are
        # forwarded to MultiHeadAttention, not to this Layer's constructor.
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
        self.layernorm = tf.keras.layers.LayerNormalization()
        self.add = tf.keras.layers.Add()

### Global Self Attention

In [None]:
class GlobalSelfAttention(BaseAttention):
    """Self-attention without a causal mask.

    Query, key and value are all the input sequence. The attention output is
    added back to the input (residual) and layer-normalized.
    """

    def call(self, x):
        attended = self.mha(query=x, key=x, value=x)
        return self.layernorm(self.add([x, attended]))

In [None]:
class CausalSelfAttention(BaseAttention):
    """Masked self-attention: each position attends only to itself and earlier ones.

    The causal mask is requested from MultiHeadAttention via
    ``use_causal_mask=True``. The result goes through the residual add and
    layer normalization.
    """

    def call(self, x):
        attended = self.mha(query=x, key=x, value=x, use_causal_mask=True)
        return self.layernorm(self.add([x, attended]))

In [None]:
class CrossAttention(BaseAttention):
    """Attention from the sequence ``x`` (queries) onto ``context`` (keys/values).

    The attention output is added to ``x`` (residual) and layer-normalized,
    so the result has the same shape as ``x``.
    """

    def call(self, x, context):
        attended = self.mha(query=x, key=context, value=context)
        return self.layernorm(self.add([x, attended]))

## Feed Forward Layer

In [None]:
class FeedForward(tf.keras.layers.Layer):
    """Feed-forward block with a residual connection and layer normalization.

    Applies Dense(ff_dim, relu) -> Dense(embed_dim) -> Dropout, adds the
    result to the input, and layer-normalizes the sum.

    Args:
        embed_dim: output width; must match the input's last dimension for
            the residual addition.
        ff_dim: hidden width of the first Dense layer.
        dropout_rate: rate of the Dropout applied after the second Dense layer.
    """

    def __init__(self, embed_dim, ff_dim, dropout_rate=0.1):
        super().__init__()
        stack = [
            tf.keras.layers.Dense(ff_dim, activation='relu'),
            tf.keras.layers.Dense(embed_dim),
            tf.keras.layers.Dropout(dropout_rate),
        ]
        self.seq = tf.keras.Sequential(stack)
        self.add = tf.keras.layers.Add()
        self.layer_norm = tf.keras.layers.LayerNormalization()

    def call(self, x):
        transformed = self.seq(x)
        return self.layer_norm(self.add([x, transformed]))

## Encoder Layer
The encoder layer consists of a global self-attention block followed by a feed-forward block.

In [None]:
class EncoderLayer(tf.keras.layers.Layer):
    """One encoder layer: global self-attention followed by a feed-forward block.

    Each sub-block applies its own residual connection and layer
    normalization.

    Args:
        embed_dim: model width; also used as the per-head key dimension.
        num_heads: number of attention heads.
        ff_dim: hidden width of the feed-forward block.
        dropout_rate: dropout rate for both the attention and the
            feed-forward sub-blocks.
    """

    def __init__(self,*, embed_dim, num_heads, ff_dim, dropout_rate=0.1):
        super().__init__()
        self.self_attention = GlobalSelfAttention(
            num_heads=num_heads,
            key_dim=embed_dim,
            dropout=dropout_rate)

        # Pass dropout_rate through so the FFN honours the caller's value
        # instead of falling back to FeedForward's own default.
        self.ffn = FeedForward(embed_dim, ff_dim, dropout_rate=dropout_rate)

    def call(self, x):
        x = self.self_attention(x)
        x = self.ffn(x)
        return x

## Encoder
The encoder (the left side of the transformer diagram in the original paper) consists of a positional embedding followed by several encoder layers.

In [None]:
class Encoder(tf.keras.layers.Layer):
    """Transformer encoder: positional embedding, dropout, then stacked encoder layers.

    Args:
        num_layers: number of stacked EncoderLayer instances.
        embed_dim: model width.
        num_heads: attention heads per layer.
        ff_dim: hidden width of each feed-forward block.
        vocab_size: size of the input token vocabulary.
        dropout_rate: dropout rate used after the embedding and inside each layer.
    """

    def __init__(self, *, num_layers, embed_dim, num_heads,
                 ff_dim, vocab_size, dropout_rate=0.1):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_layers = num_layers

        self.pos_embedding = PositionalEmbedding(
            vocab_size=vocab_size, embed_dim=embed_dim)

        layer_kwargs = dict(embed_dim=embed_dim, num_heads=num_heads,
                            ff_dim=ff_dim, dropout_rate=dropout_rate)
        self.enc_layers = [EncoderLayer(**layer_kwargs)
                           for _ in range(num_layers)]

        self.dropout = tf.keras.layers.Dropout(dropout_rate)

    def call(self, x):
        # Token ids (batch, seq_len) -> embeddings (batch, seq_len, embed_dim).
        hidden = self.dropout(self.pos_embedding(x))

        for layer in self.enc_layers:
            hidden = layer(hidden)

        return hidden  # (batch, seq_len, embed_dim)

## Decoder Layer
Now we move to the right side of the diagram. The decoder layer is a bit more complex than the encoder layer: it has two attention blocks followed by a feed-forward layer. The first is a masked (causal) self-attention block over the target sequence. The second is a cross-attention block, in which the target sequence attends to the encoder's output (the context).

In [None]:
class DecoderLayer(tf.keras.layers.Layer):
    """One decoder layer: causal self-attention, cross-attention, then feed-forward.

    Each sub-block applies its own residual connection and layer
    normalization.

    Args:
        embed_dim: model width; also used as the per-head key dimension.
        num_heads: number of attention heads in both attention blocks.
        ff_dim: hidden width of the feed-forward block.
        dropout_rate: dropout rate for the attention and feed-forward
            sub-blocks.
    """

    def __init__(self,
                 *,
                 embed_dim,
                 num_heads,
                 ff_dim,
                 dropout_rate=0.1):
        super().__init__()

        self.causal_self_attention = CausalSelfAttention(
            num_heads=num_heads,
            key_dim=embed_dim,
            dropout=dropout_rate)

        self.cross_attention = CrossAttention(
            num_heads=num_heads,
            key_dim=embed_dim,
            dropout=dropout_rate)

        # Pass dropout_rate through so the FFN honours the caller's value
        # instead of falling back to FeedForward's own default.
        self.ffn = FeedForward(embed_dim, ff_dim, dropout_rate=dropout_rate)

    def call(self, x, context):
        # context: the sequence attended to by cross-attention; in this
        # notebook the Transformer passes the encoder output here.
        x = self.causal_self_attention(x=x)
        x = self.cross_attention(x=x, context=context)

        x = self.ffn(x)  # Shape `(batch_size, seq_len, embed_dim)`.
        return x

## Decoder
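Similar to the encoder, the decoder consists of a positional embedding of the target sequence followed by several decoder layers. The final linear layer that maps the decoder output to vocabulary logits is added in the `Transformer` class below; no explicit softmax layer is applied, since the model outputs logits.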

In [None]:
class Decoder(tf.keras.layers.Layer):
    """Transformer decoder: positional embedding, dropout, then stacked decoder layers.

    Produces hidden states of width ``embed_dim``; projecting them to
    vocabulary logits is not done here.

    Args:
        num_layers: number of stacked DecoderLayer instances.
        embed_dim: model width.
        num_heads: attention heads per attention block.
        ff_dim: hidden width of each feed-forward block.
        vocab_size: size of the target token vocabulary.
        dropout_rate: dropout rate used after the embedding and inside each layer.
    """

    def __init__(self, *, num_layers, embed_dim, num_heads, ff_dim, vocab_size,
                 dropout_rate=0.1):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_layers = num_layers

        self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size,
                                                 embed_dim=embed_dim)
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

        layer_kwargs = dict(embed_dim=embed_dim, num_heads=num_heads,
                            ff_dim=ff_dim, dropout_rate=dropout_rate)
        self.dec_layers = [DecoderLayer(**layer_kwargs)
                           for _ in range(num_layers)]

    def call(self, x, context):
        # Token ids (batch, target_seq_len) -> (batch, target_seq_len, embed_dim).
        hidden = self.dropout(self.pos_embedding(x))

        for layer in self.dec_layers:
            hidden = layer(hidden, context)

        return hidden  # (batch, target_seq_len, embed_dim)

## Transformer
Now we have all the building blocks to define the full-fledged transformer:

In [None]:
class Transformer(tf.keras.Model):
    """Encoder-decoder transformer producing per-position vocabulary logits.

    Args:
        num_layers: number of layers in both the encoder and the decoder.
        embed_dim: model width.
        num_heads: attention heads per attention block.
        ff_dim: hidden width of the feed-forward blocks.
        input_vocab_size: vocabulary size of the context (source) tokens.
        target_vocab_size: vocabulary size of the target tokens and of the
            output logits.
        dropout_rate: dropout rate used throughout.
    """

    def __init__(self, *, num_layers, embed_dim, num_heads, ff_dim,
                 input_vocab_size, target_vocab_size, dropout_rate=0.1):
        super().__init__()
        shared = dict(num_layers=num_layers, embed_dim=embed_dim,
                      num_heads=num_heads, ff_dim=ff_dim,
                      dropout_rate=dropout_rate)
        self.encoder = Encoder(vocab_size=input_vocab_size, **shared)
        self.decoder = Decoder(vocab_size=target_vocab_size, **shared)
        self.final_layer = tf.keras.layers.Dense(target_vocab_size)

    def call(self, inputs):
        # Keras `.fit` passes all model inputs through the first argument, so
        # both sequences arrive packed as a (context, target) pair.
        context_tokens, target_tokens = inputs

        encoded = self.encoder(context_tokens)          # (batch_size, context_len, embed_dim)
        decoded = self.decoder(target_tokens, encoded)  # (batch_size, target_len, embed_dim)
        logits = self.final_layer(decoded)              # (batch_size, target_len, target_vocab_size)

        # Drop the propagated Keras mask so it doesn't scale the
        # losses/metrics (b/250038731); the attribute may be absent.
        try:
            del logits._keras_mask
        except AttributeError:
            pass

        # Only the logits are returned (no attention weights).
        return logits