<a href="https://colab.research.google.com/github/vedanttalnikar-bits/transformer_study_public/blob/main/multi_head_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:

import tensorflow as tf

import time
import numpy as np
import matplotlib.pyplot as plt

## Multi-head attention

<img src="https://www.tensorflow.org/images/tutorials/transformer/multi_head_attention.png" width="500" alt="multi-head attention">


Multi-head attention consists of four parts:
*    Linear layers and split into heads.
*    Scaled dot-product attention.
*    Concatenation of heads.
*    Final linear layer.

Each multi-head attention block gets three inputs; Q (query), K (key), V (value). These are put through linear (Dense) layers and split up into multiple heads. 

The `scaled_dot_product_attention` defined above is applied to each head (broadcasted for efficiency). An appropriate mask must be used in the attention step.  The attention output for each head is then concatenated (using `tf.transpose`, and `tf.reshape`) and put through a final `Dense` layer.

Instead of one single attention head, Q, K, and V are split into multiple heads because it allows the model to jointly attend to information at different positions from different representational spaces. After the split each head has a reduced dimensionality, so the total computation cost is the same as a single head attention with full dimensionality.

In [69]:
def calc_mul_add(q, k):
  # mul = depth * (seq_len_q * seq_len_k)
  # add = (depth - 1) * (seq_len_q * seq_len_k))

  #depth
  dk = tf.cast(tf.shape(k)[3], tf.float32)

  seq_len_k = tf.cast(tf.shape(k)[2], tf.float32)
  seq_len_q = tf.cast(tf.shape(q)[2], tf.float32)
  batch_size = tf.cast(tf.shape(q)[0],tf.float32)
  heads = tf.cast(tf.shape(q)[1],tf.float32)

  mul = dk * (seq_len_q * seq_len_k)
  add = (dk - 1 ) * (seq_len_q * seq_len_k)
  #tf.print(heads, batch_size, seq_len_k, seq_len_q,dk)
  total_mul = batch_size * (mul * heads) 
  total_add = batch_size * (add * heads) 
  tf.print("Activation vs Activation per head:")
  tf.print("Multiplications:", mul)
  tf.print("Additions:", add)
  tf.print("Total Mul = ", total_mul  )
  tf.print("Total add = ", total_add  )


In [70]:
def calc_mul_add2(w, v):
  # mul = depth * (seq_len_q * seq_len_k)
  # add = (depth - 1) * (seq_len_q * seq_len_k))

  #depth
  dk = tf.cast(tf.shape(w)[3], tf.float32)
  #tf.print(v.shape)
  #tf.print(w.shape)
  #tf.print(dk)
  seq_len_v = tf.cast(tf.shape(v)[2], tf.float32)
  seq_len_w = tf.cast(tf.shape(w)[2], tf.float32)
  batch_size = tf.cast(tf.shape(w)[0],tf.float32)
  heads = tf.cast(tf.shape(w)[1],tf.float32)

  mul = dk * (seq_len_w * seq_len_v)
  add = (dk - 1 ) * (seq_len_w * seq_len_v)

  tf.print("Weight vs Activation:")
  tf.print("Multiplications:", mul )
  tf.print("Additions:", add)



In [71]:
def scaled_dot_product_attention(q, k, v, mask):
  """Calculate the attention weights.
  q, k, v must have matching leading dimensions.
  k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
  The mask has different shapes depending on its type(padding or look ahead) 
  but it must be broadcastable for addition.
  
  Args:
    q: query shape == (..., seq_len_q, depth)
    k: key shape == (..., seq_len_k, depth)
    v: value shape == (..., seq_len_v, depth_v)
    mask: Float tensor with shape broadcastable 
          to (..., seq_len_q, seq_len_k). Defaults to None.
    
  Returns:
    output, attention_weights
  """

  matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)
  calc_mul_add(q,k)
  #tf.print(q.shape, k.shape)

  # scale matmul_qk
  dk = tf.cast(tf.shape(k)[-1], tf.float32)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

  # add the mask to the scaled tensor.
  if mask is not None:
    scaled_attention_logits += m  

  # softmax is normalized on the last axis (seq_len_k) so that the scores
  # add up to 1.
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

  output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)
  calc_mul_add2(attention_weights,v)

  return output, attention_weights

In [76]:
class MultiHeadAttention(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads):
    super(MultiHeadAttention, self).__init__()
    self.num_heads = num_heads
    self.d_model = d_model
    
    assert d_model % self.num_heads == 0
    
    self.depth = d_model // self.num_heads
    
    self.wq = tf.keras.layers.Dense(d_model)
    self.wk = tf.keras.layers.Dense(d_model)
    self.wv = tf.keras.layers.Dense(d_model)
    
    self.dense = tf.keras.layers.Dense(d_model)
        
  def split_heads(self, x, batch_size):
    """Split the last dimension into (num_heads, depth).
    Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
    """
    x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
    return tf.transpose(x, perm=[0, 2, 1, 3])
    
  def call(self, v, k, q, mask):
    batch_size = tf.shape(q)[0]
    
    q = self.wq(q)  # (batch_size, seq_len, d_model)
    k = self.wk(k)  # (batch_size, seq_len, d_model)
    v = self.wv(v)  # (batch_size, seq_len, d_model)
    
    q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
    k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
    v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)
    
    # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
    # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
    scaled_attention, attention_weights = scaled_dot_product_attention(
        q, k, v, mask)
    #tf.print("output before concat:", scaled_attention.shape)
    scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)

    concat_attention = tf.reshape(scaled_attention, 
                                  (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

    #tf.print("post concatation:", concat_attention.shape)
    output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
    tf.print(output.shape)
    return output, attention_weights

Create a `MultiHeadAttention` layer to try out. At each location in the sequence, `y`, the `MultiHeadAttention` runs all 8 attention heads across all other locations in the sequence, returning a new vector of the same length at each location.

In [77]:
temp_mha = MultiHeadAttention(d_model=512, num_heads=8)
y = tf.random.uniform((1, 60, 512))  # (batch_size, encoder_sequence, d_model)
out, attn = temp_mha(y, k=y, q=y, mask=None)
out.shape, attn.shape

Activation vs Activation per head:
Multiplications: 230400
Additions: 226800
Total Mul =  1.8432e+06
Total add =  1.8144e+06
Weight vs Activation:
Multiplications: 216000
Additions: 212400
shape_dense: TensorShape([1, 8, 60, 64])
shape_before_dense: TensorShape([1, 60, 512])
TensorShape([1, 60, 512])


(TensorShape([1, 60, 512]), TensorShape([1, 8, 60, 60]))