# Attention: overview and exploration

Brief: implement and explore attention layers to gain understanding and intuition

Author: Stephen Menary, sbmenary@gmail.com

Date: 13/12/2022

Last edit: 21/12/2022

---

## Intro

An attention layer takes $N$ feature vectors as input and creates $N$ feature vectors as output. For each token, we look at the features of a reference sequence. Using a simple linear transformation of the token and reference token features, we calculate a set of _attention weights_ that determine how much information the token should be pulling from the reference. The attention weights are normalised to unity using a `softmax` operation to maintain a consistent scale.

Self-attention occurs when the query and reference are the same sequence. Our model effectively allows the tokens to pass information to one another with strengths determined by the attention matrix. This makes the self-attention layer a type of graph network with a continuous adjacency matrix defined by the attention weights. The strength of information diffusion is small when the attention mechanism deems the tokens to be independent. We could view other attention layers as graph networks as well, where the graph contains both the query and reference sequences and the adjacency matrix contains many $0$s so that information may only propagate in certain directions.
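
The graph picture can be illustrated with a small NumPy sketch. It is not part of the model built later: the attention scores are set by hand, and the weights are applied directly to the input features with no query, key or value transforms (both are simplifying assumptions). Equal scores mix all tokens together, whereas scores dominated by the diagonal leave the tokens almost independent.

```python
import numpy as np

def softmax(s):
    # Normalise each row to unity; subtracting the row maximum avoids overflow
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Three tokens with two features each
x = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])

# Equal scores: a fully connected graph with equal edge weights
w_uniform = softmax(np.zeros((3, 3)))
mixed = w_uniform @ x            # every row becomes the mean of the input rows

# Scores dominated by the diagonal: each token attends almost only to itself
w_diag = softmax(20.0 * np.eye(3))
unmixed = w_diag @ x             # approximately equal to x

print(mixed)
print(unmixed)
```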


---

## Attention

Consider the transformation $x^{i,l} \rightarrow x^{i,l+1}$, where $x^{i,l}$ is the feature vector for token $i$ in layer $l$. We "attend" to a _set of_ vectors $\{y^j \forall j\}$ (the feature vectors of the reference sequence). We do this by performing a linear transformation $Q^{i}_\alpha = \sum_\beta A_{\alpha\beta} x^{i,l}_\beta$, and a separate linear transformation $K^{j}_\alpha = \sum_\beta B_{\alpha\beta} y^{j}_\beta$ for every $j$. The strength of the connection is calculated as the dot product $s_{i,j} = \sum_\alpha Q^{i}_\alpha K^{j}_\alpha= \sum_{\alpha,\beta,\gamma} y^{j}_\gamma B^T_{\gamma\alpha} A_{\alpha\beta} x^{i,l}_\beta$.

The attention weights are obtained by enforcing positivity and normalisation to unity using a `Softmax` function, $w_{ij} ~=~ \mathrm{Softmax}_j\left(s_{ij}\right)$, which allows us to interpret $w_{ij}$ as a probability distribution over $j$. At this stage, all we have done is calculate a matrix connecting element $i$ of the query sequence with element $j$ of the reference sequence. This is turned into an actual propagation of node features by (i) transforming the reference sequence into a new sequence of feature vectors called _values_, then (ii) aggregating these value vectors according to the attention weights. 

Using the same method as for keys, the value vectors are calculated as $V^{j}_\alpha = \sum_\beta C_{\alpha\beta} y^{j}_\beta$. The updated features are then calculated as $x_\mu^{i,l+1} = \sum_{j} w_{ij} V^{j}_\mu$. Compacting the indices over feature vectors, we may summarise the entire layer as
\begin{equation}
    \mathrm{Attention}\left(Q,K,V\right) ~=~ \mathrm{Softmax}\left(QK^T\right)~V
\end{equation}
Why do we want to write the attention layer in this way? The answer is that we may construct $Q$, $K$ and $V$ in whatever way we want. Here we used simple `matmul` transformations using the matrices $A$, $B$ and $C$. But we didn't have to! For example, we can use dense layers that are the same but with additive bias terms, or CNN layers, and we may share layers when constructing $Q$, $K$ and $V$.
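
As a minimal sketch, the index expressions above can be written out in NumPy for a single unbatched example. The sizes and the random matrices are illustrative assumptions; `A`, `B` and `C` play the roles of the matrices in the text, stored with shape (output length, input length).

```python
import numpy as np

def softmax(s, axis=-1):
    # Subtract the maximum for numerical stability; the result is unchanged
    e = np.exp(s - s.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n_query, n_ref, d_in, d_k, d_v = 3, 5, 4, 2, 4   # illustrative sizes

x = rng.normal(size=(n_query, d_in))   # query features x^i
y = rng.normal(size=(n_ref, d_in))     # reference features y^j
A = rng.normal(size=(d_k, d_in))       # query transform
B = rng.normal(size=(d_k, d_in))       # key transform
C = rng.normal(size=(d_v, d_in))       # value transform

Q = x @ A.T              # Q^i_alpha = sum_beta A_{alpha beta} x^i_beta
K = y @ B.T              # K^j_alpha = sum_beta B_{alpha beta} y^j_beta
V = y @ C.T              # V^j_alpha = sum_beta C_{alpha beta} y^j_beta
s = Q @ K.T              # scores s_ij, shape (n_query, n_ref)
w = softmax(s, axis=-1)  # weights w_ij, each row sums to 1 over j
x_new = w @ V            # updated features, shape (n_query, d_v)

print(w.sum(axis=-1))
print(x_new.shape)
```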

Actually, we have one more detail to include. We have implemented "dot-product attention", because the attention scores $s_{ij}$ are calculated as the dot-product of the transformed feature vectors of length $d_k$. Other forms of attention mechanism are possible, e.g. additive attention. Now consider what happens when $d_k$ is very large. The weights scale as $w_{ij}\sim\exp\left[s_{ij}\right]=\exp\left[\sum_\alpha Q^{i}_\alpha K^{j}_\alpha\right]$. If we model the initial components of $Q$ and $K$ as independent and normally distributed with mean $0$ and variance $1$, then the exponent is a sum of $d_k$ independent terms, each with mean $0$ and variance $1$, which is approximately normally distributed with variance $d_k$. We now have $w_{ij} \sim \exp\left[\sim \mathcal{N}(0;d_k)\right]$. As $d_k$ gets larger, the variance gets larger, and so we begin to see large values in the exponent. However, the _gradient_ of the `Softmax` function becomes very small when its arguments are large in magnitude. This makes learning very slow. To counteract this effect, the AIAYN paper scales the dot-product by $1/\sqrt{d_k}$ so the scale of arguments is roughly independent of $d_k$:
\begin{equation}
    \mathrm{Attention}\left(Q,K,V\right) ~=~ \mathrm{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)~V
\end{equation}
Of course, the $Q$ and $K$ matrices should be initialised such that they have reasonable values themselves.
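
The variance argument can be checked numerically. The sketch below assumes that the components of the query and key vectors are independent standard normal variables, which is a modelling assumption rather than a property of a trained layer. The sample size and the helper name `score_std` are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 5000

def score_std(d_k, scale):
    # Draw n_samples independent query-key pairs and return the spread of their dot products
    q = rng.normal(size=(n_samples, d_k))
    k = rng.normal(size=(n_samples, d_k))
    s = (q * k).sum(axis=-1)
    if scale:
        s = s / np.sqrt(d_k)
    return s.std()

results = {d_k: (score_std(d_k, False), score_std(d_k, True)) for d_k in (4, 64, 256)}

for d_k, (raw, scaled) in results.items():
    print(f"d_k={d_k:4d}  sqrt(d_k)={np.sqrt(d_k):6.2f}  raw std={raw:6.2f}  scaled std={scaled:5.2f}")
```

The raw standard deviation should track $\sqrt{d_k}$, while the scaled one should stay close to $1$ for every $d_k$.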

The other configurable length is the size of the output feature vectors, $d_v$. In principle this can be any length, but we often use the same as the input so they may be linearly combined. This means that all feature vectors will be the same length throughout the model, and we can call this $d_\mathrm{model}$.

Let us now try to implement a custom attention layer. Of course, Keras has its own attention layers, but for educational purposes we will implement our own.


In [1]:
#============================================
#===  Import packages and print versions  ===
#============================================

##  Import Python packages
import sys

##  Import additional packages
import matplotlib as mpl
import numpy      as np
import tensorflow as tf
#import tensorflow_datasets as tfds
import tensorflow_text

##  Import objects from packages
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Add, Concatenate, Dense, Dropout, Embedding, Input, Layer, LayerNormalization, Softmax
from tensorflow.keras.models import Model

##  Print versions for reproducibility
print(f"Python version is {sys.version}".replace("\n", " | "))
print(f"Matplotlib version is {mpl.__version__}")
print(f"Numpy version is {np.__version__}")
print(f"Tensorflow version is {tf.__version__}")
#print(f"Tensorflow-datasets version is {tfds.__version__}")


Python version is 3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:25:29) [Clang 14.0.6 ]
Matplotlib version is 3.6.2
Numpy version is 1.23.2
Tensorflow version is 2.11.0


In [2]:
#======================================================
#===  Create keras layer for dot-product attention  ===
#======================================================

class LinearDotProductAttention(Layer) :
    def __init__(self, dk, dv, same_transform_QK=False, same_transform_KV=False, use_causal_mask=False):
        """
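        Single-head scaled dot-product attention with Dense transforms for Q, K and V.
        - dk: length of the query and key vectors
        - dv: length of the value (output) vectors
        - same_transform_QK: if True, use the same Dense layer for queries and keys
        - same_transform_KV: if True, use the same Dense layer for keys and values (requires dk == dv)
        - use_causal_mask: if True, each query position only attends to reference positions <= its own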
        """
        ##  Initialise base class
        super().__init__()
        
        ##  Store dimension information
        self.dk      = dk
        self.dv      = dv
        self.sqrt_dk = np.sqrt(dk)
        self.use_causal_mask = use_causal_mask
        
        ##  Create Softmax layer, which can handle masking whereas tf.nn.softmax cannot
        self.softmax = Softmax()
        
        ##  Create Dense transform of query feature vectors
        self.query_transform = Dense(dk)
        
        ##  Create Dense transform of key feature vectors
        if same_transform_QK : self.key_transform = self.query_transform
        else                 : self.key_transform = Dense(dk)
            
        ##  Create Dense transform of value feature vectors
        if same_transform_KV : 
            if dk != dv :
                raise ValueError(f"Can only use same transform for keys and values if dk (={dk}) "
                                 f"is equal to dv (={dv})")
            self.value_transform = self.key_transform
        else : 
            self.value_transform = Dense(dv)
                    
    def call(self, inputs, training=False):
        """
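        Apply attention to inputs = [query_features, ref_features] and return a tensor of shape
        [batch_size, query_sequence_length, dv]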
        """
        ##  Resolve inputs
        if len(inputs) != 2 : raise ValueError(f"inputs of length {len(inputs)}, expected 2")
        query_features = inputs[0]  # shape [batch_size, query_sequence_length, query_feature_length]
        ref_features   = inputs[1]  # shape [batch_size, ref_sequence_length  , ref_feature_length  ]
                
        ##  Calculate Q, K, V matrices
        Q = self.query_transform(query_features, training=training)  # shape [batch_size, query_sequence_length, dk]
        K = self.key_transform  (ref_features  , training=training)  # shape [batch_size, ref_sequence_length  , dk]
        V = self.value_transform(ref_features  , training=training)  # shape [batch_size, ref_sequence_length  , dv]
                
        ##  Calculate transpose of K, without modifying the first axis which indexes batch samples
        K_T = tf.transpose(K, perm=[0,2,1])  # shape [batch_size, dk, ref_sequence_length]
                
        ##  Calculate dot-product attention scores
        x = tf.matmul(Q, K_T)             # shape [batch_size, query_sequence_length, ref_sequence_length]
        x = x / self.sqrt_dk              # shape [batch_size, query_sequence_length, ref_sequence_length]
        
        ##  Create a causal mask on-the-fly if needed
        mask = None
        if self.use_causal_mask :
            mask_shape = tf.shape(x)
            mask = self._create_causal_mask(mask_shape)
                
        ##  Calculate attention weights
        x = self.softmax(x, mask=mask, training=training)   # shape [batch_size, query_sequence_length, ref_sequence_length]
                        
        ##  Attend to reference sequence and return updated feature vector of length dv
        x = tf.matmul(x, V)              # shape [batch_size, query_sequence_length, dv] 
                
        return x
    
    def _create_causal_mask(self, mask_shape) :
        """
        Method docstring
        - using trick for creating causal mask from keras base_dense_attention class method
        - https://github.com/keras-team/keras/blob/e6784e4302c7b8cd116b74a784f4b78d60e83c26/keras/layers/attention/base_dense_attention.py
        Mask is 1 for elements that we want to include, 0 for elements we want to exclude
        Axis -1 is the "reference sequence" index
        Axis -2 is the "query sequence" index
        We want mask to be 1 only when Axis -1 <= Axis -2, the same as Axis -2 >= Axis -1
        With indices [row, col], we have 1 when row >= col, giving a lower triangular matrix
        """
        ones_like_x = tf.ones(shape=mask_shape, dtype=tf.int32)
        query_index = tf.cumsum(ones_like_x, axis=-2)
        ref_index   = tf.cumsum(ones_like_x, axis=-1)
        return tf.greater_equal(query_index, ref_index)
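
The cumulative-sum trick used in `_create_causal_mask` can be checked on a small example. This is a NumPy sketch for a single unbatched 4x4 score matrix, not the layer itself. Setting masked scores to negative infinity before the softmax is one common way to apply such a mask and is an assumption about what the Keras `Softmax` layer does internally.

```python
import numpy as np

n_query, n_ref = 4, 4
ones = np.ones((n_query, n_ref), dtype=int)

# Cumulative sums give 1-based row and column indices at every position
query_index = np.cumsum(ones, axis=-2)   # row number, constant along each row
ref_index = np.cumsum(ones, axis=-1)     # column number, constant along each column
mask = query_index >= ref_index          # True on and below the diagonal

# Apply the mask to uniform scores: excluded positions get zero weight
scores = np.zeros((n_query, n_ref))
masked_scores = np.where(mask, scores, -np.inf)
e = np.exp(masked_scores - masked_scores.max(axis=-1, keepdims=True))
weights = e / e.sum(axis=-1, keepdims=True)

print(mask.astype(int))
print(weights)
```

Each query position spreads its weight only over itself and earlier reference positions.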
    

In [3]:
#=======================================================================================
#===  Create method for building simple model that performs one attention operation  ===
#=======================================================================================

def build_attention_model(query_length, ref_length=None, dk=10, dv=None, name=None) :
    """
    Method docstring
    """
    ##  If ref_length or dv not provided then set equal to query_length
    if not ref_length : ref_length = query_length
    if not dv         : dv         = query_length
        
    ##  Create model inputs
    query_sequence = Input((None, query_length))
    ref_sequence   = Input((None, ref_length  ))
    
    ##  Perform attention step
    x = LinearDotProductAttention(dk, dv)([query_sequence, ref_sequence])
    
    ##  Create model, but don't compile as we don't yet have a well-defined loss function
    model = Model([query_sequence, ref_sequence], x, name=name)
    
    ##  Return
    return model
    

In [4]:
#==============================================================
#===  Build and print model to check for structural errors  ===
#==============================================================

d_in    = 512
d_model = 128
dk      = d_model // 8   # integer division so the layer sizes are ints
dv      = d_model // 8

model = build_attention_model(d_in, d_in, dk=dk, dv=dv, name="attention_model")

model.summary(expand_nested=True)


Model: "attention_model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, None, 512)]  0           []                               
                                                                                                  
 input_2 (InputLayer)           [(None, None, 512)]  0           []                               
                                                                                                  
 linear_dot_product_attention (  (None, None, 16)    24624       ['input_1[0][0]',                
 LinearDotProductAttention)                                       'input_2[0][0]']                
                                                                                                  
Total params: 24,624
Trainable params: 24,624
Non-trainable params: 0
______________

---

## Multi-Head Attention

Multi-head attention essentially means that we run several attention operations in parallel. Each head calculates a simple representation for every token, and combines pairs of them into a pair-wise attention weight. This weight only knows the context of those two tokens and nothing else. The head then uses the attention weights to average the values assigned to each token. This averaging step limits how many different tokens may be attended to, and the simple representation limits how much information may be contained.

By performing multiple attention steps in parallel, each token may attend to many different locations, using many different representations. Furthermore, by combining the results using a dense layer (rather than simple averaging), we may attend to multi-token combinations.

In summary: multi-head attention allows us to attend to many different locations using different feature transformations, allowing for a more expressive propagation of information.

The recipe is:
- perform $h$ attention steps in parallel
- concatenate the results into a $h \cdot d_v$-dimensional feature vector
- use a Dense layer to project down to a $d_v$-dimensional feature vector
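
The recipe can be sketched in NumPy for a single unbatched sequence attending to itself. The sizes, the random weights and the helper `attention_head` are illustrative assumptions. Biases are omitted in the forward pass, and the projection target is called `d_model` here.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention_head(x, y, A, B, C):
    # One scaled dot-product attention step with linear Q, K, V transforms
    Q, K, V = x @ A, y @ B, y @ C
    w = softmax(Q @ K.T / np.sqrt(A.shape[1]))
    return w @ V

rng = np.random.default_rng(0)
h, d_model, d_k, d_v, n_tokens = 8, 128, 16, 16, 5
x = rng.normal(size=(n_tokens, d_model))

# Step 1: perform h attention steps, each with its own transforms
heads = []
for _ in range(h):
    A = rng.normal(size=(d_model, d_k)) / np.sqrt(d_model)
    B = rng.normal(size=(d_model, d_k)) / np.sqrt(d_model)
    C = rng.normal(size=(d_model, d_v)) / np.sqrt(d_model)
    heads.append(attention_head(x, x, A, B, C))

# Step 2: concatenate the results into an (h * d_v)-dimensional feature vector per token
concat = np.concatenate(heads, axis=-1)

# Step 3: project with a final linear layer
W_out = rng.normal(size=(h * d_v, d_model)) / np.sqrt(h * d_v)
out = concat @ W_out

# Parameter count if every transform were a Dense layer with a bias
n_params = h * (2 * (d_model * d_k + d_k) + (d_model * d_v + d_v)) + (h * d_v * d_model + d_model)
print(concat.shape, out.shape, n_params)
```

With these sizes `n_params` evaluates to 66,048, the same figure that the Keras summary for the multi-head model reports.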


In [5]:
#=====================================================
#===  Create keras layer for multi-head attention  ===
#=====================================================

class MultiHeadAttention(Layer) :
    """
    Class docstring
    """
    def __init__(self, num_heads, d_out, dk_per_head, dv_per_head, use_causal_mask=True, dropout=0.1):
        """
        Init docstring
        """
        ##  Initialise base class
        super().__init__()
        
        ##  Store dimension information
        self.num_heads       = num_heads
        self.d_out           = d_out
        self.dk_per_head     = dk_per_head   
        self.dv_per_head     = dv_per_head
        self.use_causal_mask = use_causal_mask
        
        ##  Create heads
        self.heads = [LinearDotProductAttention(self.dk_per_head, self.dv_per_head, use_causal_mask=use_causal_mask) 
                      for hi in range(num_heads)]
                    
        ##  Create other keras layers
        self.concat  = Concatenate()
        self.linear  = Dense(d_out)
        self.dropout = Dropout(dropout)
        
    def call(self, inputs, training=False):
        """
        Call docstring
        """    
        ##  Calculate attention heads
        #        shape [batch_size, query_sequence_length, dv] for each list element
        x = [head(inputs, training=training) for head in self.heads]
        
        ##  Concatenate heads and project onto single output
        x = self.concat (x)
        x = self.dropout(x, training=training)  
        x = self.linear (x, training=training)     
        return x


In [6]:
#==================================================================================================
#===  Create method for building simple model that performs one multi-head attention operation  ===
#==================================================================================================

def build_multi_head_attention_model(num_heads, d_out, query_length, ref_length=None, dk_per_head=10, 
                                     dv_per_head=None, name=None) :
    """
    Method docstring
    """
    ##  If ref_length not provided then set equal to query_length; dv_per_head defaults to query_length / num_heads
    if not ref_length  : ref_length  = query_length
    if not dv_per_head : dv_per_head = int(query_length/num_heads)
        
    ##  Create model inputs
    query_sequence = Input((None, query_length))
    ref_sequence   = Input((None, ref_length  ))
    
    ##  Perform attention step
    x = MultiHeadAttention(num_heads, d_out, dk_per_head, dv_per_head)([query_sequence, ref_sequence])
    
    ##  Create model, but don't compile as we don't yet have a well-defined loss function
    model = Model([query_sequence, ref_sequence], x, name=name)
    
    ##  Return
    return model
    

In [7]:
#==============================================================
#===  Build and print model to check for structural errors  ===
#==============================================================

d_in        = 128
d_model     = 128
num_heads   = 8
dk_per_head = int(d_model / num_heads)
dv_per_head = int(d_model / num_heads)

model = build_multi_head_attention_model(num_heads, d_model, d_in, dk_per_head=dk_per_head, 
                                         dv_per_head=dv_per_head, name="multi_head_attention_model")

model.summary(expand_nested=True)


Model: "multi_head_attention_model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, None, 128)]  0           []                               
                                                                                                  
 input_4 (InputLayer)           [(None, None, 128)]  0           []                               
                                                                                                  
 multi_head_attention (MultiHea  (None, None, 128)   66048       ['input_3[0][0]',                
 dAttention)                                                      'input_4[0][0]']                
                                                                                                  
Total params: 66,048
Trainable params: 66,048
Non-trainable params: 0
___

In [8]:

from tensorflow.keras.layers import MultiHeadAttention as KMHA

def build_keras_multi_head_attention_model(num_heads, d_model, dk_per_head, dv_per_head, name=None) :
    """
    Method docstring
    """
        
    ##  Create model inputs
    query_sequence = Input((None, d_model))
    
    ##  Perform attention step
    x = KMHA(num_heads, key_dim=dk_per_head, value_dim=dv_per_head)(query_sequence, query_sequence)
    
    ##  Create model, but don't compile as we don't yet have a well-defined loss function
    model = Model(query_sequence, x, name=name)
    
    ##  Return
    return model
    
test_model = build_keras_multi_head_attention_model(num_heads, d_model, dk_per_head, dv_per_head)

test_model.summary()


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_5 (InputLayer)           [(None, None, 128)]  0           []                               
                                                                                                  
 multi_head_attention_1 (MultiH  (None, None, 128)   66048       ['input_5[0][0]',                
 eadAttention)                                                    'input_5[0][0]']                
                                                                                                  
Total params: 66,048
Trainable params: 66,048
Non-trainable params: 0
__________________________________________________________________________________________________


---

##   Encoder blocks

We will build a sequence encoder by stacking blocks built on the attention mechanism. Each block will perform **self-attention**, meaning that we use the same sequence for both the query and reference. This means that a sequence will attend to itself to create an abstract representation, just like how we stack layers in a CNN or MLP. 

To construct each block, the first ingredient we will add is a **skip connection**. This means that we use addition or concatenation to combine the output of the attention layer with its own input to create a joint representation. In our case we will use addition to retain the same dimensionality. We could include a variable scalar to control how much information we want to include from the processed vs unprocessed representations, although the original AIAYN does not. Perhaps this is because this scale will be learned naturally as the scale of $V$. Skip connections can be _short_, skipping only one or two layers, or _long_, skipping many layers and maybe connecting directly to the final processing layers. We will use short skip connections that bypass only the multi-head attention layer. The advantages of skip connections are:
- we may re-use low-level features in our output / later layers
- if e.g. positional information is lost in the attention layer, it may still propagate further down the network
- allow us to learn deeper networks (the usual explanation is that the identity path gives gradients a direct route back to earlier layers, which eases optimisation)

We will use **layer normalisation** to keep our activations on a consistent scale. This is like batch normalisation, but it applies to individual datapoints rather than batches. For this reason, it only really works if the dimension $d_v$ is large enough, so the activations form a distribution with some notable shape. Note that vectors $[-1, 9]$, $[400, 500]$, $[-0.1, 0.1]$ etc will all be represented as $[-1, 1]$ after layer normalisation.
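
The three example vectors can be checked with a bare-bones NumPy version of the normalisation step. This sketch subtracts the mean and divides by the standard deviation over the feature axis only. The Keras `LayerNormalization` layer additionally includes a small epsilon in the denominator and a learnable scale and offset, so its outputs would match only approximately.

```python
import numpy as np

def layer_norm(v):
    # Normalise a single feature vector to zero mean and unit standard deviation
    v = np.asarray(v, dtype=float)
    return (v - v.mean(axis=-1, keepdims=True)) / v.std(axis=-1, keepdims=True)

vectors = [[-1, 9], [400, 500], [-0.1, 0.1]]
normalised = [layer_norm(v) for v in vectors]

for v, n in zip(vectors, normalised):
    print(v, "->", n)
```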

Instead of feeding the resulting feature vectors directly into a new layer, we will introduce two dense layers to process the vector. This will turn the information into something abstract that will work together with the next block to propagate information efficiently. The first layer has dimension $d_{ff}$ and a `relu` activation, and the second layer has dimension $d_\mathrm{model}$ and a linear activation. Once again we will combine the output with a skip connection so the encoder block may be considered to _modify_ and not _replace_ the input.
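
A minimal NumPy sketch of this feed-forward step with its skip connection is given below. The weights are random placeholders and the sizes are illustrative assumptions. The same two layers are applied to every token independently, which the example checks by processing one token on its own.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d_model, d_ff = 5, 128, 512
x = rng.normal(size=(n_tokens, d_model))

# Two dense layers: d_model -> d_ff with relu, then d_ff -> d_model with linear activation
W1 = rng.normal(size=(d_model, d_ff)) / np.sqrt(d_model)
b1 = np.zeros(d_ff)
W2 = rng.normal(size=(d_ff, d_model)) / np.sqrt(d_ff)
b2 = np.zeros(d_model)

def feed_forward(v):
    hidden = np.maximum(v @ W1 + b1, 0.0)   # relu activation
    return hidden @ W2 + b2

# Skip connection: the block modifies the input rather than replacing it
out = x + feed_forward(x)

# Processing one token alone gives the same result as processing the whole sequence
single = feed_forward(x[2:3])[0]

n_params = W1.size + b1.size + W2.size + b2.size
print(out.shape, n_params)
```

With $d_\mathrm{model} = 128$ and $d_{ff} = 512$ the two dense layers hold 131,712 parameters. Together with 66,048 attention parameters and 512 layer-normalisation parameters, this accounts for the 198,272 reported in the encoder block summary.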

Let us combine all of these encoder block elements into a single layer.


In [9]:
#==============================================
#===  Create keras layer for encoder block  ===
#==============================================

class EncoderBlock(Layer) :
    """
    Class docstring
    """
    def __init__(self, d_model, num_heads, dk_per_head, dv_per_head, dff, dropout=0.1):
        """
        Init docstring
        """
        ##  Initialise base class
        super().__init__()
        
        ##  Store dimension information
        self.d_model     = d_model
        self.num_heads   = num_heads
        self.dk_per_head = dk_per_head
        self.dv_per_head = dv_per_head
        self.dff         = dff
        
        ##  Create attention layer
        self.mha = MultiHeadAttention(num_heads, d_model, dk_per_head, dv_per_head)
                    
        ##  Create other keras layers
        self.add1    = Add()
        self.norm1   = LayerNormalization()
        self.add2    = Add()
        self.norm2   = LayerNormalization()
        self.dense1  = Dense(dff, activation="relu")
        self.dense2  = Dense(d_model)
        self.dropout = Dropout(dropout)
        
    def call(self, query_features, training=False):
        """
        Call docstring
        """
        ##  Calculate multi-head attention
        x_skip = query_features
        x = self.mha([query_features, query_features], training=training)   # shape [batch_size, query_sequence_length, d_model] 
                
        ##  Combine attention output with skip-connection
        x = self.add1([x, x_skip], training=training)    # shape [batch_size, query_sequence_length, d_model] 
        x = self.norm1(x, training=training)             # shape [batch_size, query_sequence_length, d_model] 
        x_skip = x
                
        ##  Feed-forward processing of linearly-combined feature vectors from each head
        x = self.dense1 (x, training=training)            # shape [batch_size, query_sequence_length, dff] 
        x = self.dense2 (x, training=training)            # shape [batch_size, query_sequence_length, d_model] 
        x = self.dropout(x, training=training)            # shape [batch_size, query_sequence_length, d_model] 
                
        ##  Skip-connect and return
        x = self.add2([x, x_skip], training=training)    # shape [batch_size, query_sequence_length, d_model] 
        x = self.norm2(x, training=training)             # shape [batch_size, query_sequence_length, d_model]     
        return x


In [10]:
#================================================================================
#===  Create method for building simple model that applies one encoder block  ===
#================================================================================

def build_encoder_block_model(d_model, num_heads, dk_per_head, dv_per_head, dff, name=None) :
    """
    Method docstring
    """ 
    ##  Create model inputs
    query_features = Input((None, d_model))
    
    ##  Perform attention step
    x = EncoderBlock(d_model, num_heads, dk_per_head, dv_per_head, dff)(query_features)
    
    ##  Create model, but don't compile as we don't yet have a well-defined loss function
    model = Model(query_features, x, name=name)
    
    ##  Return
    return model
    

In [11]:
#==============================================================
#===  Build and print model to check for structural errors  ===
#==============================================================

d_model     = 128
num_heads   = 8
dk_per_head = int(d_model / num_heads)
dv_per_head = int(d_model / num_heads)
dff         = 512

model = build_encoder_block_model(d_model, num_heads, dk_per_head=dk_per_head, dv_per_head=dv_per_head, 
                                  dff=dff, name="encoder")

model.summary(expand_nested=True)


Model: "encoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_6 (InputLayer)        [(None, None, 128)]       0         
                                                                 
 encoder_block (EncoderBlock  (None, None, 128)        198272    
 )                                                               
                                                                 
Total params: 198,272
Trainable params: 198,272
Non-trainable params: 0
_________________________________________________________________


---

## Input embedding

Our total encoder model is made of:
- An embedding layer to create an efficient representation of the inputs. This uses a keras Embedding layer which possibly requires that the inputs are integer type. The output is of size $d_\mathrm{model}$.
- A positional encoding layer which adds information about sequence position. We add this to the embedding to create tokens that represent the bare token and its position
- Several encoder blocks which perform attention and feed-forward processing, with output size $d_\mathrm{model}$
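
For reference, the sinusoidal positional encoding from the AIAYN paper uses the angles $\mathrm{pos} / 10000^{2i/d_\mathrm{model}}$, with the sine on even feature indices and the cosine on odd ones. The NumPy sketch below follows that formula. The function name and the `max_length` argument are hypothetical, and `d_model` is assumed to be even. Concatenating the sine and cosine halves instead of interleaving them only permutes the feature order.

```python
import numpy as np

def sinusoidal_encoding(max_length, d_model):
    # Angles pos / 10000^(2i / d_model) for every position and every frequency index i
    positions = np.arange(max_length)[:, np.newaxis]    # shape (max_length, 1)
    i = np.arange(d_model // 2)[np.newaxis, :]          # shape (1, d_model/2)
    angles = positions / 10000.0 ** (2 * i / d_model)   # shape (max_length, d_model/2)

    # Interleave: sine on even feature indices, cosine on odd feature indices
    pe = np.zeros((max_length, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

pe = sinusoidal_encoding(max_length=50, d_model=8)
print(pe.shape)
print(pe[0])
```

Position $0$ encodes as alternating zeros and ones, since $\sin(0) = 0$ and $\cos(0) = 1$.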

In [12]:

class PositionalEncoding(tf.keras.layers.Layer) :
    
    def __init__(self, d_in, d_model) :
        """
        Init docstring
        """
        ##  Initialise base class
        super().__init__()
        
        ##  Store dimension information
        self.d_in      = d_in
        self.d_model   = d_model
        
        ##  Store Tensor object with pre-computed positional encodings
        self.encoded_positions = self.create_encoded_positions_tensor(d_in, d_model)

    def call(self, x) :
        """
        Call docstring
        """
        ##  Return slice of stored encoded_positions Tensor with correct shape
        length = tf.shape(x)[1]
        return self.encoded_positions[tf.newaxis, :length, :]
    
    def create_encoded_positions_tensor(self, d_in, d_model) :
        """
        Method docstring
        """
        ##  Create numpy array with positions
        positions = np.arange(d_in)   # shape (d_in)
        
        ##  Combine with indices to create 2D array of angles
        ##  Frequencies are 1 / 10000^(2i / d_model), as in the AIAYN paper
        half_indices = np.arange(d_model // 2)               # shape (d_model/2)
        angles = 10000.0 ** (-2 * half_indices / d_model)    # shape (d_model/2)
        angles = np.outer(positions, angles)                 # shape (d_in, d_model/2)

        ##  Concatenate sin and cos of angles into single 2D array of positional encodings
        pos_encoding = np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)   # shape (d_in, d_model)

        ##  Return Tensor of positional encodings
        return tf.cast(pos_encoding, dtype=tf.float32)
        

In [13]:
#========================================
#===  Create keras layer for encoder  ===
#========================================

class Encoder(Layer) :
    """
    Class docstring
    """
    def __init__(self, num_blocks, d_in, d_model, num_heads, dk_per_head, dv_per_head, dff):
        """
        Init docstring
        """
        ##  Initialise base class
        super().__init__()
        
        ##  Store dimension information
        self.num_blocks  = num_blocks
        self.d_model     = d_model
        self.num_heads   = num_heads
        self.dk_per_head = dk_per_head
        self.dv_per_head = dv_per_head
        self.dff         = dff
        self.emb_scalar  = tf.math.sqrt(tf.cast(d_model, tf.float32))
                    
        ##  Create embedding layers
        self.add               = Add()
        self.token_embedding   = Embedding(d_in, d_model, mask_zero=True)
        self.position_encoding = PositionalEncoding(d_in, d_model)
        
        ##  Create encoder block layers
        self.encoder_blocks = [EncoderBlock(d_model, num_heads, dk_per_head, dv_per_head, dff) 
                               for ei in range(num_blocks)]
        
    def call(self, query_features, training=False):
        """
        Call docstring
        """
        ##  Calculate embeddings
        token_embedding   = self.token_embedding(query_features, training=training) # shape [batch_size, query_sequence_length, d_model]
        token_embedding  *= self.emb_scalar                                         # shape [batch_size, query_sequence_length, d_model]
        position_encoding = self.position_encoding(query_features)                  # shape [batch_size, query_sequence_length, d_model]
        x = self.add([token_embedding, position_encoding], training=training)       # shape [batch_size, query_sequence_length, d_model] 
                
        ##  Pass through encoder blocks
        for encoder_block in self.encoder_blocks :
            x = encoder_block(x, training=training)
            
        ##  Return encoded sequence
        return x
    
    '''def compute_mask(self, *args, **kwargs):
        """
        Method docstring
        """
        return self.token_embedding.compute_mask(*args, **kwargs)'''


In [14]:
#============================================================================
#===  Create method for building simple model that performs the encoding  ===
#============================================================================

def build_encoder_model(num_blocks, d_in, d_model, num_heads, dk_per_head, dv_per_head, dff, name=None) :
    """
    Method docstring
    """ 
    ##  Create model inputs
    query_features = Input((None,))
    
    ##  Encode the sequence
    x = Encoder(num_blocks, d_in, d_model, num_heads, dk_per_head, dv_per_head, dff)(query_features)
    
    ##  Create model, but don't compile as we don't yet have a well-defined loss function
    model = Model(query_features, x, name=name)
    
    ##  Return
    return model
    

In [15]:
#==============================================================
#===  Build and print model to check for structural errors  ===
#==============================================================

d_in        = 7765   # portuguese vocab size
d_model     = 128
num_blocks  = 4
num_heads   = 8
dk_per_head = d_model  # int(d_model / num_heads)
dv_per_head = d_model  # int(d_model / num_heads)
dff         = 512

model = build_encoder_model(num_blocks, d_in, d_model, num_heads, dk_per_head, dv_per_head, dff, name="encoder")

model.summary(expand_nested=True)


Model: "encoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_7 (InputLayer)        [(None, None)]            0         
                                                                 
 encoder (Encoder)           (None, None, 128)         3632768   
                                                                 
Total params: 3,632,768
Trainable params: 3,632,768
Non-trainable params: 0
_________________________________________________________________


## Decoder block

In [16]:
#==============================================
#===  Create keras layer for decoder block  ===
#==============================================

class DecoderBlock(Layer) :
    """
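    Decoder block: masked self-attention, cross-attention over the encoder output, then a
    feed-forward network, each followed by a skip-connection and layer normalisation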
    """
    def __init__(self, d_model, num_heads, dk_per_head, dv_per_head, dff, dropout=0.1):
        """
        Store the layer dimensions and create the sub-layers
        """
        ##  Initialise base class
        super().__init__()
        
        ##  Store dimension information
        self.d_model     = d_model
        self.num_heads   = num_heads
        self.dk_per_head = dk_per_head
        self.dv_per_head = dv_per_head
        self.dff         = dff
        
        ##  Create attention layers
        self.masked_attention = MultiHeadAttention(num_heads, d_model, dk_per_head, dv_per_head, use_causal_mask=True )
        self.cross_attention  = MultiHeadAttention(num_heads, d_model, dk_per_head, dv_per_head, use_causal_mask=False)
                    
        ##  Create other keras layers
        self.add1    = Add()
        self.norm1   = LayerNormalization()
        self.add2    = Add()
        self.norm2   = LayerNormalization()
        self.add3    = Add()
        self.norm3   = LayerNormalization()
        self.dense1  = Dense(dff, activation="relu")
        self.dense2  = Dense(d_model)
        self.dropout = Dropout(dropout)
        
    def call(self, inputs, training=False):
        """
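        Expects inputs = [decoder_input, encoder_output], both with d_model features per token
        Returns a tensor of shape [batch_size, query_sequence_length, d_model]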
        """
        ##  Resolve inputs
        if len(inputs) != 2 : raise ValueError(f"inputs of length {len(inputs)}, expected 2")
        decoder_input  = inputs[0]
        encoder_output = inputs[1]
        
        ##  Calculate masked self-attention
        x_skip = decoder_input
        x = self.masked_attention([decoder_input, decoder_input], training=training)   # shape [batch_size, query_sequence_length, d_model] 
                
        ##  Combine attention output with skip-connection
        x = self.add1([x, x_skip], training=training)          # shape [batch_size, query_sequence_length, d_model] 
        x = self.norm1(x, training=training)                   # shape [batch_size, query_sequence_length, d_model] 
        x_skip = x
        
        ##  Calculate cross-attention
        x = self.cross_attention([x, encoder_output], training=training)   # shape [batch_size, query_sequence_length, d_model] 
                
        ##  Combine attention output with skip-connection
        x = self.add2([x, x_skip], training=training)          # shape [batch_size, query_sequence_length, d_model] 
        x = self.norm2(x, training=training)                   # shape [batch_size, query_sequence_length, d_model] 
        x_skip = x
        
        ##  Feed-forward processing of linearly-combined feature vectors from each head
        x = self.dense1 (x, training=training)            # shape [batch_size, query_sequence_length, dff] 
        x = self.dense2 (x, training=training)            # shape [batch_size, query_sequence_length, d_model] 
        x = self.dropout(x, training=training)            # shape [batch_size, query_sequence_length, d_model] 
                
        ##  Skip-connect and return
        x = self.add3([x, x_skip], training=training)          # shape [batch_size, query_sequence_length, d_model] 
        x = self.norm3(x, training=training)                   # shape [batch_size, query_sequence_length, d_model] 
                
        return x
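
The `use_causal_mask=True` flag on the first attention layer is what stops a decoder token from attending to later positions. The `MultiHeadAttention` layer itself is defined earlier in the notebook, so the following is only a minimal pure-Python sketch of the idea, assuming the mask acts by excluding future positions before the softmax. `causal_softmax` is a hypothetical helper, not part of the notebook.

```python
import math

def causal_softmax(scores):
    """Row-wise softmax over a square score matrix in which query i may only
    attend to reference positions j <= i (hypothetical helper)."""
    weights = []
    for i, row in enumerate(scores):
        allowed = row[: i + 1]                        # positions up to and including i
        peak = max(allowed)                           # subtract max for numerical stability
        exps = [math.exp(s - peak) for s in allowed]
        total = sum(exps)
        # Future positions receive exactly zero weight
        weights.append([e / total for e in exps] + [0.0] * (len(row) - i - 1))
    return weights

scores = [[0.0, 2.0, 1.0],
          [1.0, 1.0, 5.0],
          [0.5, 0.5, 0.5]]
attn = causal_softmax(scores)
for row in attn:
    print([round(w, 3) for w in row])
```

The first row puts all of its weight on the first token, and the large score of 5.0 in the second row is ignored because it belongs to a future position.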


In [17]:
#================================================================================
#===  Create method for building simple model that applies one decoder block  ===
#================================================================================

def build_decoder_block_model(d_model, num_heads, dk_per_head, dv_per_head, dff, name=None) :
    """
    Build an uncompiled keras Model that applies a single DecoderBlock to a query and a reference sequence
    """ 
    ##  Create model inputs
    query_features = Input((None, d_model))
    ref_features   = Input((None, d_model))
    
    ##  Apply the decoder block
    x = DecoderBlock(d_model, num_heads, dk_per_head, dv_per_head, dff)([query_features, ref_features])
    
    ##  Create model, but don't compile as we don't yet have a well-defined loss function
    model = Model([query_features, ref_features], x, name=name)
    
    ##  Return
    return model
    

In [18]:
#==============================================================
#===  Build and print model to check for structural errors  ===
#==============================================================

d_model     = 128
num_heads   = 8
dk_per_head = int(d_model / num_heads)
dv_per_head = int(d_model / num_heads)
dff         = 512

model = build_decoder_block_model(d_model, num_heads, dk_per_head=dk_per_head, dv_per_head=dv_per_head, 
                                  dff=dff, name="decoder_block_model")

model.summary(expand_nested=True)


Model: "decoder_block_model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_8 (InputLayer)           [(None, None, 128)]  0           []                               
                                                                                                  
 input_9 (InputLayer)           [(None, None, 128)]  0           []                               
                                                                                                  
 decoder_block (DecoderBlock)   (None, None, 128)    264576      ['input_8[0][0]',                
                                                                  'input_9[0][0]']                
                                                                                                  
Total params: 264,576
Trainable params: 264,576
Non-trainable params: 0
________

## Decoder

In [19]:
#========================================
#===  Create keras layer for decoder  ===
#========================================

class Decoder(Layer) :
    """
    Decoder: token embedding plus positional encoding, followed by a stack of decoder blocks
    """
    def __init__(self, num_blocks, d_in, d_model, num_heads, dk_per_head, dv_per_head, dff):
        """
        Store the layer dimensions and create the embedding layers and decoder blocks
        """
        ##  Initialise base class
        super().__init__()
        
        ##  Store dimension information
        self.num_blocks  = num_blocks
        self.d_in        = d_in
        self.d_model     = d_model
        self.num_heads   = num_heads
        self.dk_per_head = dk_per_head
        self.dv_per_head = dv_per_head
        self.dff         = dff
        self.emb_scalar  = tf.math.sqrt(tf.cast(d_model, tf.float32))
                    
        ##  Create embedding layers
        self.add               = Add()
        self.token_embedding   = Embedding(d_in, d_model, mask_zero=True)
        self.position_encoding = PositionalEncoding(d_in, d_model)
        
        ##  Create decoder block layers
        self.decoder_blocks = [DecoderBlock(d_model, num_heads, dk_per_head, dv_per_head, dff) 
                               for ei in range(num_blocks)]
        
    def call(self, inputs, training=False):
        """
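        Embed the decoder token ids, add the positional encoding, then apply each decoder block in turn
        Expects inputs = [decoder_input, encoder_output]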
        """
        ##  Resolve inputs
        if len(inputs) != 2 : raise ValueError(f"inputs of length {len(inputs)}, expected 2")
        decoder_input  = inputs[0]
        encoder_output = inputs[1]
        
        ##  Calculate embeddings
        token_embedding   = self.token_embedding(decoder_input, training=training) # shape [batch_size, query_sequence_length, d_model]
        token_embedding  *= self.emb_scalar                                        # shape [batch_size, query_sequence_length, d_model]
        position_encoding = self.position_encoding(decoder_input)                  # shape [batch_size, query_sequence_length, d_model]
        x = self.add([token_embedding, position_encoding], training=training)      # shape [batch_size, query_sequence_length, d_model] 
                
        ##  Pass through decoder blocks
        for decoder_block in self.decoder_blocks :
            x = decoder_block([x, encoder_output], training=training)
            
        ##  Return decoded sequence
        return x


In [20]:
#============================================================================
#===  Create method for building simple model that performs the decoding  ===
#============================================================================

def build_decoder_model(num_blocks, d_in, d_model, num_heads, dk_per_head, dv_per_head, dff, name=None) :
    """
    Build an uncompiled keras Model that decodes a sequence of token ids against an encoded reference sequence
    """ 
    ##  Create model inputs
    query_features = Input((None,))
    ref_features   = Input((None, d_model))
    
    ##  Decode the sequence
    x = Decoder(num_blocks, d_in, d_model, num_heads, dk_per_head, dv_per_head, dff)([query_features, ref_features])
    
    ##  Create model, but don't compile as we don't yet have a well-defined loss function
    model = Model([query_features, ref_features], x, name=name)
    
    ##  Return
    return model
    

In [21]:
#==============================================================
#===  Build and print model to check for structural errors  ===
#==============================================================

d_in        = 7010
d_model     = 128
num_blocks  = 4
num_heads   = 8
dk_per_head = d_model  # int(d_model / num_heads)
dv_per_head = d_model  # int(d_model / num_heads)
dff         = 512

model = build_decoder_model(num_blocks, d_in, d_model, num_heads, dk_per_head, dv_per_head, dff, name="decoder")

model.summary(expand_nested=True)


Model: "decoder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_10 (InputLayer)          [(None, None)]       0           []                               
                                                                                                  
 input_11 (InputLayer)          [(None, None, 128)]  0           []                               
                                                                                                  
 decoder (Decoder)              (None, None, 128)    5647104     ['input_10[0][0]',               
                                                                  'input_11[0][0]']               
                                                                                                  
Total params: 5,647,104
Trainable params: 5,647,104
Non-trainable params: 0
________________

## Transformer

In [22]:
#============================================
#===  Create keras layer for transformer  ===
#============================================

class Transformer(Layer) :
    """
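    Encoder-decoder transformer that maps a pair of token-id sequences to token probabilities
    over the output vocabulary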
    """
    def __init__(self, num_blocks, d_in, d_out, d_model, num_heads, dk_per_head, dv_per_head, dff):
        """
        Store the layer dimensions and create the encoder, decoder and output layers
        """
        ##  Initialise base class
        super().__init__()
        
        ##  Store dimension information
        self.num_blocks  = num_blocks
        self.d_in        = d_in
        self.d_out       = d_out
        self.d_model     = d_model
        self.num_heads   = num_heads
        self.dk_per_head = dk_per_head
        self.dv_per_head = dv_per_head
        self.dff         = dff
        self.emb_scalar  = tf.math.sqrt(tf.cast(d_model, tf.float32))
                    
        ##  Create encoder and decoder layers
        self.encoder = Encoder(num_blocks, d_in , d_model, num_heads, dk_per_head, dv_per_head, dff)
        self.decoder = Decoder(num_blocks, d_out, d_model, num_heads, dk_per_head, dv_per_head, dff)
                
        ##  Create layers to convert decoder output to token probabilities
        self.linear  = Dense(d_out)
        self.softmax = Softmax()
        
    def call(self, inputs, training=False):
        """
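        Encode the first input, decode the second input against that encoding, and return token probabilities
        Expects inputs = [encoder_input, decoder_input]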
        """
        ##  Resolve inputs
        if len(inputs) != 2 : raise ValueError(f"inputs of length {len(inputs)}, expected 2")
        encoder_input = inputs[0]
        decoder_input = inputs[1]
        
        ##  Calculate encoding
        x = self.encoder(encoder_input, training=training)
                
        ##  Calculate decoding
        x = self.decoder([decoder_input, x], training=training)
                        
        ##  Turn decoder outputs into token probabilities
        x = self.linear (x, training=training)
        x = self.softmax(x, training=training)
            
        ##  Return token probabilities
        return x


In [37]:
#============================================================================
#===  Create method for building simple model that wraps the transformer  ===
#============================================================================

def build_transformer_model(num_blocks, d_in, d_out, d_model, num_heads, dk_per_head, dv_per_head, dff, name=None) :
    """
    Build an uncompiled keras Model that maps [original_message, translated_message] token ids to token probabilities
    """ 
    ##  Create model inputs
    original_message   = Input((None,))
    translated_message = Input((None,))
    
    ##  Apply the transformer
    x = Transformer(num_blocks, d_in, d_out, d_model, num_heads, dk_per_head, dv_per_head, dff)([original_message, translated_message])
    
    ##  Create model
    model = Model([original_message, translated_message], x, name=name)
    
    ##  Return
    return model
    

In [38]:
#==============================================================
#===  Build and print model to check for structural errors  ===
#==============================================================

d_in        = 7765    # portuguese vocab size
d_out       = 7010    # english vocab size
d_model     = 128
num_blocks  = 4
num_heads   = 8
dk_per_head = d_model  # int(d_model / num_heads)
dv_per_head = d_model  # int(d_model / num_heads)
dff         = 512

transformer = build_transformer_model(num_blocks, d_in, d_out, d_model, num_heads, dk_per_head, dv_per_head, 
                                      dff, name="transformer")

transformer.summary(expand_nested=True)


Model: "transformer"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_14 (InputLayer)          [(None, None)]       0           []                               
                                                                                                  
 input_15 (InputLayer)          [(None, None)]       0           []                               
                                                                                                  
 transformer_1 (Transformer)    (None, None, 7010)   10184162    ['input_14[0][0]',               
                                                                  'input_15[0][0]']               
                                                                                                  
Total params: 10,184,162
Trainable params: 10,184,162
Non-trainable params: 0
__________
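
As a sanity check on the summaries above, the parameter totals can be reproduced by hand. The sketch below is pure arithmetic and rests on assumptions about layers defined earlier in the notebook: that `MultiHeadAttention` uses separate biased Q, K and V projections per head followed by one biased output projection, that `EncoderBlock` has one attention layer, the same feed-forward pair and two layer normalisations, and that `PositionalEncoding` has no trainable weights. The helper names are hypothetical.

```python
def dense(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

def attention(d_model, num_heads, dk, dv):
    # ASSUMPTION: separate Q, K, V projections per head, then one output projection
    per_head = 2 * dense(d_model, dk) + dense(d_model, dv)
    return num_heads * per_head + dense(num_heads * dv, d_model)

def block(d_model, num_heads, dk, dv, dff, num_attention):
    # ASSUMPTION: every attention or feed-forward sub-layer is followed by a
    # LayerNormalization (2 * d_model parameters: one scale and one offset per feature)
    feed_forward = dense(d_model, dff) + dense(dff, d_model)
    layer_norms  = (num_attention + 1) * 2 * d_model
    return num_attention * attention(d_model, num_heads, dk, dv) + feed_forward + layer_norms

d_in, d_out, d_model, num_blocks, num_heads, dff = 7765, 7010, 128, 4, 8, 512

# Stand-alone decoder block, built with dk = dv = d_model / num_heads = 16
decoder_block_params = block(d_model, num_heads, 16, 16, dff, num_attention=2)

# Full stacks, built with dk = dv = d_model = 128; embeddings add vocab_size * d_model
encoder_params = d_in  * d_model + num_blocks * block(d_model, num_heads, 128, 128, dff, num_attention=1)
decoder_params = d_out * d_model + num_blocks * block(d_model, num_heads, 128, 128, dff, num_attention=2)
transformer_params = encoder_params + decoder_params + dense(d_model, d_out)

print(decoder_block_params, encoder_params, decoder_params, transformer_params)
# prints 264576 3632768 5647104 10184162
```

The four totals agree with the model summaries, which is consistent with the assumed layer layout but does not prove it. Note how much of the budget is spent on attention when `dk_per_head = dv_per_head = d_model` rather than `d_model / num_heads`.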

In [39]:

input_pt = np.random.randint(d_in , size=(1,23))
seq_en   = np.random.randint(d_out, size=(1,17))
input_en = seq_en[:,:-1]
label_en = seq_en[:,1: ]

input_pt = tf.cast(input_pt, dtype=tf.int32)
input_en = tf.cast(input_en, dtype=tf.int32)
label_en = tf.cast(label_en, dtype=tf.int32)

print(input_pt)
print(input_en)
print(label_en)


tf.Tensor(
[[ 343 5351 5122 7580 4693 5521  965 2847 1894  267  421 2502  405 5599
    62 4444 4641 3127 1506 6238 2282 7389  680]], shape=(1, 23), dtype=int32)
tf.Tensor(
[[2349 6745 2682 4439 6642 3803 1927 6798  301 5927 2179 1697 2458 5773
  2286 4358]], shape=(1, 16), dtype=int32)
tf.Tensor(
[[6745 2682 4439 6642 3803 1927 6798  301 5927 2179 1697 2458 5773 2286
  4358 1331]], shape=(1, 16), dtype=int32)
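
The printed tensors show the one-token shift between the decoder input and the label: at every position the decoder is asked to predict the token that follows the ones it has seen. A minimal sketch with a plain list makes the pairing explicit; the token ids are made up for illustration.

```python
# Toy target sequence of token ids (values are made up for illustration)
seq_en = [2, 45, 17, 98, 3]

input_en = seq_en[:-1]   # what the decoder sees: every token except the last
label_en = seq_en[1:]    # what it must predict: every token except the first

# At position t the decoder has seen input_en[:t+1] and is scored against label_en[t]
pairs = list(zip(input_en, label_en))
for t, (seen, target) in enumerate(pairs):
    print(f"position {t}: last seen token {seen} -> target {target}")
```

Combined with the causal mask in the decoder, this lets every position of a sequence contribute a training signal in a single forward pass.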


In [40]:
model_output = transformer([input_pt, input_en])

In [41]:
model_output.shape

TensorShape([1, 16, 7010])

## Training

In [42]:
'''
def masked_loss(label, pred):
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
    loss = loss_object(label, pred)

    mask = tf.where(label==0, 0, 1)
    mask = tf.cast(mask, dtype=loss.dtype)
    loss *= mask

    loss = tf.reduce_sum(loss) / tf.reduce_sum(mask)
    return loss

def masked_accuracy(label, pred):
    pred  = tf.argmax(pred)
    label = tf.cast (label, pred.dtype)
    mask  = tf.where(label==0, 0, 1)
    match = tf.where(label==pred, 1, 0)
    match = mask * match
    match = tf.cast(match, dtype=tf.float32)
    mask  = tf.cast(mask , dtype=tf.float32)
    return tf.reduce_sum(match) / tf.reduce_sum(mask)
'''


In [43]:
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=4000):
        super().__init__()
        self.d_model = d_model
        self.d_model = tf.cast(self.d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, dtype=tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

def masked_loss(label, pred):
    mask = label != 0
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False, reduction='none')
    loss = loss_object(label, pred)
    mask = tf.cast(mask, dtype=loss.dtype)
    loss *= mask
    loss = tf.reduce_sum(loss)/tf.reduce_sum(mask)
    return loss

def masked_accuracy(label, pred):
    pred = tf.argmax(pred, axis=2)
    label = tf.cast(label, pred.dtype)
    match = label == pred
    mask = label != 0
    match = match & mask
    match = tf.cast(match, dtype=tf.float32)
    mask = tf.cast(mask, dtype=tf.float32)
    return tf.reduce_sum(match)/tf.reduce_sum(mask)
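
Both metrics average over non-padding positions only, so padded sequences neither dilute the loss nor inflate the accuracy. The following pure-Python sketch mirrors that logic for a single unbatched sequence, assuming token id 0 is padding as in the functions above. `masked_loss_and_accuracy` is a hypothetical helper and the probabilities are invented.

```python
import math

def masked_loss_and_accuracy(labels, probs, pad_id=0):
    """Mean cross-entropy and accuracy over non-padding positions only.
    labels: list of token ids; probs: one probability list per position."""
    total_loss, correct, count = 0.0, 0, 0
    for label, p in zip(labels, probs):
        if label == pad_id:
            continue                                        # padding enters neither sum
        total_loss += -math.log(p[label])                   # sparse categorical cross-entropy
        best = max(range(len(p)), key=p.__getitem__)        # argmax over the vocabulary
        correct += int(best == label)
        count += 1
    return total_loss / count, correct / count

labels = [2, 1, 0, 0]                  # last two positions are padding
probs  = [[0.10, 0.20, 0.70],          # correct: argmax is 2
          [0.60, 0.30, 0.10],          # wrong:   argmax is 0, label is 1
          [0.90, 0.05, 0.05],          # ignored
          [0.20, 0.40, 0.40]]          # ignored
loss, acc = masked_loss_and_accuracy(labels, probs)
print(loss, acc)
```

Only the first two positions count, so the accuracy is one half regardless of what the model predicts on the padded positions.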


In [44]:

optimizer = tf.keras.optimizers.Adam(CustomSchedule(d_model), beta_1=0.9, beta_2=0.98, epsilon=1e-9)

transformer.compile(
    loss      = masked_loss,
    optimizer = optimizer,
    metrics   = [masked_accuracy])
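
`CustomSchedule` ramps the learning rate up linearly for `warmup_steps` steps and then decays it as the inverse square root of the step. A pure-Python mirror of its `__call__` method, valid for `step >= 1`, makes the shape easy to inspect without TensorFlow. `transformer_lr` is a hypothetical name, and the defaults are the values used in this notebook.

```python
def transformer_lr(step, d_model=128, warmup_steps=4000):
    """Pure-Python mirror of CustomSchedule.__call__ for step >= 1."""
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

peak = transformer_lr(4000)
for step in (1, 1000, 4000, 16000):
    print(f"step {step:>6}: lr = {transformer_lr(step):.3e}  ({transformer_lr(step) / peak:.2f} of peak)")
```

The two arguments of the `min` cross at `step = warmup_steps`, where the rate peaks at $(d_{model} \cdot \text{warmup\_steps})^{-1/2} \approx 1.4 \times 10^{-3}$ for these settings. The rate is a quarter of the peak at step 1000 and half of the peak at step 16000.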


In [45]:
train_batches = tf.data.Dataset.load("tokenised_train_batches")
val_batches   = tf.data.Dataset.load("tokenised_val_batches"  )


In [54]:

transformer.fit(
    train_batches, 
    epochs=20, 
    validation_data=val_batches)


Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x286f0c7f0>

In [55]:

# transformer.save('saved_model/transformer.h5')
# transformer.save_weights('saved_model/transformer_weights')


In [56]:

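def translate_sentence(sentence, transformer, tokenizers, max_tokens=30) :
    """
    Greedily translate a Portuguese sentence into English: repeatedly append the most probable
    next token until the end token is produced or max_tokens is reached
    """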
    with tf.device("CPU") :
        tokenised_sentence = tokenizers.pt.tokenize([sentence]).to_tensor()
        
        start_end   = tokenizers.en.tokenize([''])[0]
        begin_token = start_end[0 ]
        end_token   = start_end[-1]
        
        tokenised_translated_sentence = tf.cast([[begin_token]], dtype=tf.int64)
                
        best_token, num_tokens = begin_token, 1
        while best_token != end_token and num_tokens < max_tokens :
            token_probs = transformer([tokenised_sentence, tokenised_translated_sentence])
            best_token  = tf.argmax(token_probs[0,-1,:], axis=-1)
            tokenised_translated_sentence = tf.concat([tokenised_translated_sentence, 
                                                       tf.reshape(best_token, (1,1))],
                                                      axis=-1)
            num_tokens += 1
            
        translated_sentence = tokenizers.en.detokenize(tokenised_translated_sentence)[0]
        return translated_sentence.numpy().decode('utf-8')
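
The loop in `translate_sentence` is a greedy decoder: it re-runs the model on everything generated so far and keeps only the single most probable next token. The sketch below isolates that control flow in pure Python, with a toy stand-in for the transformer. Both `greedy_decode` and `toy_model` are hypothetical and exist only for illustration.

```python
def greedy_decode(next_token_probs, begin_token, end_token, max_tokens=30):
    """Append the most probable next token until end_token appears or
    max_tokens is reached. `next_token_probs` is a hypothetical callable
    mapping the tokens generated so far to a list of probabilities."""
    tokens = [begin_token]
    while tokens[-1] != end_token and len(tokens) < max_tokens:
        probs = next_token_probs(tokens)
        best = max(range(len(probs)), key=probs.__getitem__)   # argmax
        tokens.append(best)
    return tokens

def toy_model(tokens, vocab_size=6):
    """Toy stand-in for the transformer: always favours (last token + 1)."""
    probs = [0.01] * vocab_size
    probs[(tokens[-1] + 1) % vocab_size] = 0.95
    return probs

print(greedy_decode(toy_model, begin_token=1, end_token=4))                 # prints [1, 2, 3, 4]
print(greedy_decode(toy_model, begin_token=1, end_token=4, max_tokens=3))   # prints [1, 2, 3]
```

Because each step commits to one token and never revisits it, an early mistake can propagate through the rest of the output. Beam search is a common alternative, at the cost of more model evaluations.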
    

In [57]:

tokenizer_model_fname = 'ted_hrlr_translate_pt_en_converter'
tokenizers = tf.saved_model.load(tokenizer_model_fname)


In [58]:
en_sentence = "There is no apple in this room, there are only three bananas and a grapefruit"
pt_sentence = "Não há maçã nesta sala, há apenas três bananas e uma toranja"

pred_sentence = translate_sentence(pt_sentence, transformer, tokenizers)

print(f"Portuguese sentence  : {pt_sentence}")
print(f"Ground truth         : {en_sentence}")
print(f"Predicted translation: {pred_sentence}")


Portuguese sentence  : Não há maçã nesta sala, há apenas três bananas e uma toranja
Ground truth         : There is no apple in this room, there are only three bananas and a grapefruit
Predicted translation: there ' s no apple in this room , there ' s only three braft and a srarararararararara
