## Multi Head Attention

In [82]:
import numpy as np
import tensorflow as tf

In [83]:
sequence_length = 4
batch_size = 1
input_dim = 512
d_model = 512
x = np.random.randn(batch_size, sequence_length, input_dim)

In [84]:
x.shape

(1, 4, 512)

In [85]:
qkv_layer = tf.keras.layers.Dense(3 * d_model)

In [86]:
qkv = qkv_layer(x)



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



In [87]:
qkv.shape

TensorShape([1, 4, 1536])

In [88]:
num_heads = 8
head_dim = d_model // num_heads
qkv = tf.reshape(qkv, [batch_size, sequence_length, num_heads, 3 * head_dim])

In [89]:
qkv.shape

TensorShape([1, 4, 8, 192])

In [90]:
qkv = tf.transpose(qkv, perm=[0, 2, 1, 3])
qkv.shape

TensorShape([1, 8, 4, 192])

In [91]:
q, k, v = tf.split(qkv, num_or_size_splits=3, axis = -1)
q.shape, k.shape, v.shape

(TensorShape([1, 8, 4, 64]),
 TensorShape([1, 8, 4, 64]),
 TensorShape([1, 8, 4, 64]))

## Self Attention for multiple heads

For a single head:
$$
\text{self attention} = softmax\bigg(\frac{Q.K^T}{\sqrt{d_k}}+M\bigg)
$$

$$
\text{new V} = \text{self attention}.V
$$ 

In [92]:
d_k = k.shape[-1]
scaled = tf.matmul(q, tf.transpose(k, perm=[0, 1, 3, 2])) / np.sqrt(d_k)
scaled.shape

TensorShape([1, 8, 4, 4])

In [93]:
mask = tf.linalg.band_part((tf.ones(scaled.shape)), -1,0,)
mask = tf.where(mask == 1, 0, tf.constant(float('-inf')))
mask[0][1]

<tf.Tensor: shape=(4, 4), dtype=float32, numpy=
array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]], dtype=float32)>

In [94]:
(scaled + mask)[0][0]

<tf.Tensor: shape=(4, 4), dtype=float32, numpy=
array([[-0.0669643 ,        -inf,        -inf,        -inf],
       [ 0.8429377 , -0.22580042,        -inf,        -inf],
       [ 0.6326221 ,  0.31469107, -0.54598874,        -inf],
       [ 0.3652924 , -0.2694186 ,  0.5640688 , -0.15838446]],
      dtype=float32)>

In [95]:
scaled += mask

In [96]:
scaled.ndim

4

In [97]:
attention = tf.keras.activations.softmax(scaled, axis=-1)

In [98]:
values = tf.matmul(attention,v)
values.shape

TensorShape([1, 8, 4, 64])

## Function

In [99]:
def scaled_dot_product(q, k, v, mask = None):
    d_k = k.shape[-1]
    scaled = tf.matmul(q, tf.transpose(v, perm=[0, 1, 3, 2])) / np.sqrt(d_k)
    if mask is not None:
        scaled += mask
    attention = tf.keras.activations.softmax(scaled, axis = -1)
    values = tf.matmul(attention, v)
    return values, attention

In [100]:
values, attention = scaled_dot_product(q, k, v)

In [101]:
values.shape, attention.shape

(TensorShape([1, 8, 4, 64]), TensorShape([1, 8, 4, 4]))

In [102]:
values = tf.reshape(values, (batch_size, sequence_length, num_heads * head_dim))
values.shape

TensorShape([1, 4, 512])

In [104]:
linear_layer = tf.keras.layers.Dense(d_model)

In [105]:
out = linear_layer(values)

In [106]:
out.shape

TensorShape([1, 4, 512])

In [107]:
out

<tf.Tensor: shape=(1, 4, 512), dtype=float32, numpy=
array([[[-0.53420955,  0.10022534,  0.07923827, ...,  0.3830155 ,
         -0.37225428,  0.11568592],
        [ 0.20668475, -0.10040505,  0.30879456, ..., -0.03135356,
         -0.3400464 ,  0.28938416],
        [ 0.5138191 , -0.4163006 , -0.19618417, ..., -0.1770674 ,
          0.08143231, -0.37141818],
        [ 1.1689396 ,  0.40013433,  0.08186674, ..., -0.37103337,
          0.87498176, -0.30928454]]], dtype=float32)>

## Class

In [118]:
class MultiheadAttention():
    def __init__(self, input_dim, d_model, num_heads):
        super().__init__()
        self.input_dim = input_dim
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv_layer = tf.keras.layers.Dense(3 * d_model)
        self.linear_layer = tf.keras.layers.Dense(d_model)

    def forward(self, x, mask = None):
        batch_size, sequence_length, input_dim = x.shape
        print(f"x.shape: {x.shape}")
        qkv = self.qkv_layer(x)
        print(f"qkv.shape: {qkv.shape}")
        qkv = tf.reshape(qkv, [batch_size, sequence_length, self.num_heads, self.head_dim * 3])
        print(f"qkv.shape: {qkv.shape}")
        qkv = tf.transpose(qkv, perm=[0, 2, 1, 3])
        print(f"qkv.shape: {qkv.shape}")
        q, k, v = tf.split(qkv, num_or_size_splits=3, axis = -1)
        print(f"q shape: {q.shape}, k shape: {k.shape}, v shape: {v.shape}")
        values, attention = scaled_dot_product(q, k, v, mask)
        print(f"values.shape: {values.shape}, attention.shape: {attention.shape}")
        values = tf.reshape(values, [batch_size, sequence_length, self.head_dim * self.num_heads])
        print(f"values.shape: {values.shape}")
        out = self.linear_layer(values)
        print(f"out.shape: {out.shape}")
        return out

In [119]:
input_dim = 1024
d_model = 512
num_heads = 8

batch_size = 10
sequence_length = 5
x = tf.random.normal((batch_size, sequence_length, input_dim))

model = MultiheadAttention(input_dim, d_model, num_heads)
out = model.forward(x)

x.shape: (10, 5, 1024)
qkv.shape: (10, 5, 1536)
qkv.shape: (10, 5, 8, 192)
qkv.shape: (10, 8, 5, 192)
q shape: (10, 8, 5, 64), k shape: (10, 8, 5, 64), v shape: (10, 8, 5, 64)
values.shape: (10, 8, 5, 64), attention.shape: (10, 8, 5, 5)
values.shape: (10, 5, 512)
out.shape: (10, 5, 512)
