In [2]:
import tensorflow as tf

In [12]:
import tensorflow as tf

class ScaledAttention(tf.keras.layers.Layer):
    """Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V.

    The layer has no trainable weights; it only combines its inputs.
    """

    def __init__(self, **kwargs):
        # Forward standard Layer options (name, dtype, ...) to the base class.
        super().__init__(**kwargs)

    def call(self, query, key, value, mask=None):
        """Compute attention over the last two axes of the inputs.

        Args:
            query: tensor of shape (..., seq_q, depth).
            key: tensor of shape (..., seq_k, depth).
            value: tensor of shape (..., seq_k, depth_v).
            mask: optional tensor broadcastable to (..., seq_q, seq_k).
                Non-zero entries mark positions to suppress: they receive a
                large negative logit, so their softmax weight is ~0.

        Returns:
            A tuple (context_vector, attention_weights) with shapes
            (..., seq_q, depth_v) and (..., seq_q, seq_k).
        """
        matmul_qk = tf.matmul(query, key, transpose_b=True)

        # Scale in the dtype of the logits rather than a hard-coded float32;
        # dividing a float16/float64 tensor by a float32 one raises a dtype
        # mismatch (e.g. under a mixed-precision policy).
        logits_dtype = matmul_qk.dtype
        key_depth = tf.cast(tf.shape(key)[-1], logits_dtype)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(key_depth)

        if mask is not None:
            # -1e9 overflows to -inf in float16, and 0 * -inf is NaN for the
            # unmasked positions; use a value float16 can represent instead.
            large_negative = -3e4 if logits_dtype == tf.float16 else -1e9
            scaled_attention_logits += tf.cast(mask, logits_dtype) * large_negative

        # Normalise over the key axis so each query's weights sum to 1.
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
        context_vector = tf.matmul(attention_weights, value)

        return context_vector, attention_weights

class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head attention built from Dense projections and ScaledAttention.

    Query, key and value are each projected to `d_model` features, split into
    `num_heads` heads of size `d_model // num_heads`, attended per head, then
    merged and passed through a final Dense(d_model) projection.

    Args:
        num_heads: number of attention heads; must divide `d_model`.
        d_model: feature size of the projections and of the output.
        **kwargs: standard `tf.keras.layers.Layer` options (name, dtype, ...).
    """

    def __init__(self, num_heads, d_model, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % self.num_heads == 0, (
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        )
        self.depth = d_model // self.num_heads

        # Input projections for query, key and value.
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        self.attention = ScaledAttention()
        # Output projection applied after the heads are merged.
        self.dense = tf.keras.layers.Dense(d_model)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({"num_heads": self.num_heads, "d_model": self.d_model})
        return config

    def split_heads(self, x, batch_size):
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, depth)."""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, query, key, value, mask=None):
        """Apply multi-head attention.

        Args:
            query, key, value: tensors whose first axis is the batch and whose
                last axis is projected by a Dense layer, e.g.
                (batch, seq_len, features).
            mask: optional mask forwarded unchanged to ScaledAttention; it
                must broadcast against (batch, num_heads, seq_q, seq_k).

        Returns:
            A tuple (output, attention_weights) with shapes
            (batch, seq_q, d_model) and (batch, num_heads, seq_q, seq_k).
        """
        # Dynamic batch size so the reshapes work with an unknown batch dim.
        batch_size = tf.shape(query)[0]

        query = self.wq(query)
        key = self.wk(key)
        value = self.wv(value)

        query = self.split_heads(query, batch_size)
        key = self.split_heads(key, batch_size)
        value = self.split_heads(value, batch_size)

        scaled_attention, attention_weights = self.attention(query, key, value, mask)

        # Undo split_heads: (batch, heads, seq, depth) -> (batch, seq, d_model).
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))
        output = self.dense(concat_attention)

        return output, attention_weights


In [13]:
# Seed the Python, NumPy and TensorFlow RNGs so the random inputs (and the
# layer's randomly initialised weights) are repeatable on Restart & Run All.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

num_heads = 8
d_model = 64
batch_size = 2
seq_len = 5

# Self-attention-style demo: query, key and value share one shape.
input_shape = (batch_size, seq_len, d_model)
query = tf.random.normal(input_shape)
key = tf.random.normal(input_shape)
value = tf.random.normal(input_shape)

mha = MultiHeadAttention(num_heads=num_heads, d_model=d_model)

context_vector, attention_weights = mha(query, key, value)

# Sanity-check the output shapes before inspecting the values.
assert tuple(context_vector.shape) == (batch_size, seq_len, d_model)
assert tuple(attention_weights.shape) == (batch_size, num_heads, seq_len, seq_len)

In [14]:
# Shapes of the attention output and of the per-head attention weights.
(context_vector.shape, attention_weights.shape)

(TensorShape([2, 5, 64]), TensorShape([2, 8, 5, 5]))

In [15]:
# One print per statement: joining the calls with a comma builds the tuple
# (None, None), which the notebook then echoes as a spurious cell output.
print("context vector: ", context_vector)
print("attention weights: ", attention_weights)

context vector:  tf.Tensor(
[[[ 1.17092955e+00  1.74959034e-01  8.14614236e-01  2.72841901e-01
    1.34353900e+00 -1.30290017e-01  5.54659963e-01 -1.54419971e+00
   -3.62821728e-01 -5.39235994e-02  2.17677712e-01 -3.49965960e-01
   -7.07851112e-01  2.61607729e-02 -6.27807736e-01 -4.51061606e-01
   -6.79228753e-02  2.77737081e-01 -8.42350185e-01  8.42121243e-01
    7.87424743e-01  7.60869265e-01 -7.86327124e-01  5.84989846e-01
   -1.35983229e+00 -4.33838546e-01 -1.49074733e-01 -5.51107466e-01
   -6.89930081e-01  2.37578437e-01 -1.07542656e-01 -1.31057358e+00
   -6.41678870e-01 -2.73314446e-01 -5.26086152e-01  1.76131446e-02
   -2.85808176e-01  6.76846355e-02 -1.13950111e-01  1.12478115e-01
    1.91762790e-01 -3.83961618e-01 -1.15867168e-01 -4.52213645e-01
   -2.61492521e-01 -3.65843594e-01  4.58277971e-01 -6.46082610e-02
    9.02185380e-01  2.70749420e-01 -7.25422811e-04 -1.07466352e+00
   -4.62511659e-01  3.71136725e-01 -3.28190923e-02 -8.58771086e-01
    1.54520556e-01 -4.25858855e-01

(None, None)