<a href="https://colab.research.google.com/github/pavansai26/Attention-mechanisms/blob/main/Multihead_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Multihead Attention



#A mechanism that lets a model attend to different parts of an input sequence several times in parallel; each attention head uses its own learned projections, so each head can focus on a different kind of relationship.

#Steps:


#Project and Split: The input is linearly projected into query, key, and value matrices, and each projection is then split into multiple heads of size d_model / num_heads.

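#Scaled Dot-Product Attention: For each head, attention scores are computed as the dot products between queries and keys, divided by the square root of the head dimension. If a mask is given, masked positions are set to a large negative score so they receive (almost) zero weight.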

#Softmax and Weighted Sum: The scores are normalized with a softmax over the key positions, and the resulting attention weights are used to compute a weighted sum of the values.

#Concatenate and Project: The outputs from all heads are concatenated and projected into the final output.

#Advantages:



#Captures Diverse Relationships: Allows the model to learn multiple, diverse relationships between elements in the input sequence.

#Improves Representation Learning: Leads to richer, more informative representations of the input.

#Enhances Parallelism: The heads can be calculated in parallel, improving computational efficiency.

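#Multiple Perspectives: Because each head attends independently, the model does not rely on a single attention pattern. This can help generalization in practice, but it does not by itself prevent overfitting (see Potential Overfitting below).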

#Disadvantages:

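#Increased Complexity: The query, key, value, and output projections add about 4 * d_model^2 weights (plus biases) when the inputs already have d_model features. For a fixed d_model the parameter count does not grow with the number of heads, but the attention-weight tensor (batch x num_heads x seq_len x seq_len) does, and num_heads becomes one more hyperparameter to tune.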

#Potential Overfitting: If not regularized appropriately, it can lead to overfitting.

#Interpretability: Can be harder to interpret the model's decision-making process.

In [1]:
import tensorflow as tf

"""
MultiheadAttention class for implementing multihead attention in TensorFlow.
"""
class MultiheadAttention(tf.keras.layers.Layer):
   """
   Initialize a MultiheadAttention layer.

   Args:
       d_model: The dimensionality of the input and output representations.
       num_heads: The number of attention heads.
   """
   def __init__(self, d_model, num_heads):
       super().__init__()
       self.d_model = d_model
       self.num_heads = num_heads
       self.head_dim = d_model // num_heads  # Dimensionality of each attention head

       # Linear layers for projecting queries, keys, and values
       self.wq = tf.keras.layers.Dense(d_model)  # Query projection
       self.wk = tf.keras.layers.Dense(d_model)  # Key projection
       self.wv = tf.keras.layers.Dense(d_model)  # Value projection
       self.wo = tf.keras.layers.Dense(d_model)  # Output projection

   """
   Perform multihead attention on the given inputs.

   Args:
       query: The query tensor.
       key: The key tensor.
       value: The value tensor.
       mask: (optional) A mask to prevent attention to certain positions.

   Returns:
       The output of the multihead attention layer and the attention weights.
   """
   def call(self, query, key, value, mask=None):
       batch_size = tf.shape(query)[0]  # Get batch size

       # Project queries, keys, and values using linear layers
       q = self.wq(query)
       k = self.wk(key)
       v = self.wv(value)

       # Reshape into multiple heads (batch_size, sequence_length, num_heads, head_dim)
       q = tf.reshape(q, (batch_size, -1, self.num_heads, self.head_dim))
       k = tf.reshape(k, (batch_size, -1, self.num_heads, self.head_dim))
       v = tf.reshape(v, (batch_size, -1, self.num_heads, self.head_dim))

       # Transpose to (batch_size, num_heads, sequence_length, head_dim)
       q = tf.transpose(q, perm=[0, 2, 1, 3])
       k = tf.transpose(k, perm=[0, 2, 1, 3])
       v = tf.transpose(v, perm=[0, 2, 1, 3])

       # Calculate scaled dot-product attention scores
       scores = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(tf.cast(self.head_dim, tf.float32))

       # Apply masking if provided
       if mask is not None:
           scores = tf.where(mask == 0, tf.fill(tf.shape(scores), -1e9), scores)

       # Normalize scores with softmax
       attention = tf.nn.softmax(scores, axis=-1)

       # Compute weighted values
       output = tf.matmul(attention, v)

       # Transpose back to original shape (batch_size, sequence_length, num_heads, head_dim)
       output = tf.transpose(output, perm=[0, 2, 1, 3])

       # Reshape to (batch_size, sequence_length, d_model)
       output = tf.reshape(output, (batch_size, -1, self.d_model))

       # Final output projection
       output = self.wo(output)

       return output, attention
