### Multihead Attention

### Imports

In [1]:
from flax import nnx
import jax
import jax.numpy as jnp

### Attention Block

In [2]:
class Attention(nnx.Module):
    def __init__(self, d_model: int = 2, 
                 row_dim: int = 0, 
                 col_dim: int = 1, 
                 *, 
                 rngs: nnx.Rngs):
        key = rngs.params()
        self.W_q = nnx.Param(jax.random.uniform(key, (d_model, d_model)))
        self.W_k = nnx.Param(jax.random.uniform(key, (d_model, d_model)))
        self.W_v = nnx.Param(jax.random.uniform(key, (d_model, d_model)))
        self.row_dim = row_dim
        self.col_dim = col_dim
        self.d_model = d_model

    def __call__(self, encodings_for_q: jax.Array, 
                 encodings_for_k: jax.Array, 
                 encodings_for_v: jax.Array, 
                 mask: jax.Array = None):
        q = encodings_for_q @ self.W_q
        k = encodings_for_k @ self.W_k
        v = encodings_for_v @ self.W_v

        sims = q @ k.swapaxes(self.row_dim, self.col_dim)
        scaled_sims = sims / jnp.sqrt(self.d_model)

        if mask is not None:
            scaled_sims = jnp.where(mask, -1e9, scaled_sims)

        attention_percents = jax.nn.softmax(scaled_sims, axis=self.col_dim)
        attention_scores = attention_percents @ v

        return attention_scores

### Multihead attention block

In [3]:
class MultiHeadAttention(nnx.Module):
    def __init__(self, 
                 d_model: int = 2, 
                 row_dim: int = 0, 
                 col_dim: int = 1, 
                 num_heads: int = 1, 
                 *, 
                 rngs: nnx.Rngs):
        self.heads = [Attention(d_model, row_dim, col_dim, rngs=rngs) 
                     for _ in range(num_heads)]
        self.col_dim = col_dim

    def __call__(self, 
                 encodings_for_q: jax.Array, 
                 encodings_for_k: jax.Array, 
                 encodings_for_v: jax.Array):
        # Run data through all attention heads and concatenate along col_dim
        outputs = [head(encodings_for_q, encodings_for_k, encodings_for_v) 
                  for head in self.heads]
        return jnp.concatenate(outputs, axis=self.col_dim)

### Verifying Attention block

In [4]:
# Create input matrices
encodings_for_q = jnp.array([[1.16, 0.23],
                           [0.57, 1.36],
                           [4.41, -2.16]])
encodings_for_k = jnp.array([[1.16, 0.23],
                           [0.57, 1.36],
                           [4.41, -2.16]])
encodings_for_v = jnp.array([[1.16, 0.23],
                           [0.57, 1.36],
                           [4.41, -2.16]])

# Set random seed
key = jax.random.PRNGKey(42)

# Calculate Encoder-Decoder Attention
attention = Attention(d_model=2, row_dim=0, col_dim=1, rngs=nnx.Rngs(params=key))
attention_output = attention(encodings_for_q, encodings_for_k, encodings_for_v)
print("Encoder-Decoder Attention Output:")
print(attention_output)

Encoder-Decoder Attention Output:
[[1.1276929  0.73620886]
 [1.2950552  0.9069855 ]
 [0.95412695 0.5512294 ]]


### Verifying Multihead attention block

In [5]:
# Calculate Multi-Head Attention with single head
multi_head_attention_single = MultiHeadAttention(d_model=2, 
                                               row_dim=0, 
                                               col_dim=1, 
                                               num_heads=1, 
                                               rngs=nnx.Rngs(params=key))
single_head_output = multi_head_attention_single(encodings_for_q, 
                                               encodings_for_k, 
                                               encodings_for_v)
print("\nMulti-Head Attention (1 head) Output:")
print(single_head_output)

# Calculate Multi-Head Attention with two heads
multi_head_attention_double = MultiHeadAttention(d_model=2, 
                                               row_dim=0, 
                                               col_dim=1, 
                                               num_heads=2, 
                                               rngs=nnx.Rngs(params=key))
double_head_output = multi_head_attention_double(encodings_for_q, 
                                               encodings_for_k, 
                                               encodings_for_v)
print("\nMulti-Head Attention (2 heads) Output:")
print(double_head_output)


Multi-Head Attention (1 head) Output:
[[1.1276929  0.73620886]
 [1.2950552  0.9069855 ]
 [0.95412695 0.5512294 ]]

Multi-Head Attention (2 heads) Output:
[[1.1276929  0.73620886 2.5678048  2.6611998 ]
 [1.2950552  0.9069855  2.4171586  2.512345  ]
 [0.95412695 0.5512294  2.815371   2.9060547 ]]
