### Multihead Attention

### Imports

In [1]:
from flax import nnx
import jax
import jax.numpy as jnp

### Attention Block

In [2]:
class Attention(nnx.Module):
    """Single-head scaled dot-product attention with learned Q/K/V weights.

    The module owns three ``(d_model, d_model)`` weight matrices that project
    the incoming encodings to queries, keys and values. ``row_dim`` and
    ``col_dim`` name the two axes that are swapped to transpose the keys;
    the softmax is taken over ``col_dim``.
    """

    def __init__(self, d_model: int = 2,
                 row_dim: int = 0,
                 col_dim: int = 1,
                 *,
                 rngs: nnx.Rngs):
        """Create the three projection matrices.

        Args:
            d_model: Size of both dimensions of each weight matrix; also used
                to scale the similarity scores.
            row_dim: Axis swapped with ``col_dim`` when transposing the keys.
            col_dim: Axis swapped with ``row_dim``; also the softmax axis.
            rngs: Source of the PRNG key used to initialise the weights.
        """
        # One key is drawn from the ``params`` stream and split three ways so
        # that every projection gets an independent initialisation. Passing
        # the same key to each ``uniform`` call would make W_q, W_k and W_v
        # start out as identical matrices.
        q_key, k_key, v_key = jax.random.split(rngs.params(), 3)
        weight_shape = (d_model, d_model)
        self.W_q = nnx.Param(jax.random.uniform(q_key, weight_shape))
        self.W_k = nnx.Param(jax.random.uniform(k_key, weight_shape))
        self.W_v = nnx.Param(jax.random.uniform(v_key, weight_shape))
        self.row_dim = row_dim
        self.col_dim = col_dim
        self.d_model = d_model

    def __call__(self, encodings_for_q: jax.Array,
                 encodings_for_k: jax.Array,
                 encodings_for_v: jax.Array,
                 mask: jax.Array | None = None) -> jax.Array:
        """Compute attention output for the given encodings.

        Args:
            encodings_for_q: Encodings multiplied by ``W_q`` to form queries.
            encodings_for_k: Encodings multiplied by ``W_k`` to form keys.
            encodings_for_v: Encodings multiplied by ``W_v`` to form values.
            mask: Optional array; positions where it is truthy have their
                score replaced by ``-1e9`` before the softmax. It must be
                usable as the condition of ``jnp.where`` against the score
                matrix.

        Returns:
            The softmax-weighted combination of the values.
        """
        q = encodings_for_q @ self.W_q
        k = encodings_for_k @ self.W_k
        v = encodings_for_v @ self.W_v

        # Query/key similarities, scaled by sqrt(d_model).
        sims = q @ k.swapaxes(self.row_dim, self.col_dim)
        scaled_sims = sims / jnp.sqrt(self.d_model)

        if mask is not None:
            # A very negative score becomes ~0 weight after the softmax.
            scaled_sims = jnp.where(mask, -1e9, scaled_sims)

        attention_percents = jax.nn.softmax(scaled_sims, axis=self.col_dim)
        attention_scores = attention_percents @ v

        return attention_scores

### Multihead Attention Block

In [3]:
class MultiHeadAttention(nnx.Module):
    """Several independent ``Attention`` heads concatenated along ``col_dim``.

    Every head receives the same inputs; their outputs are joined with
    ``jnp.concatenate`` on ``col_dim``. No output projection is applied.
    """

    def __init__(self,
                 d_model: int = 2,
                 row_dim: int = 0,
                 col_dim: int = 1,
                 num_heads: int = 1,
                 *,
                 rngs: nnx.Rngs):
        """Build ``num_heads`` attention heads.

        Args:
            d_model: Forwarded to each ``Attention`` head.
            row_dim: Forwarded to each ``Attention`` head.
            col_dim: Forwarded to each head; also the concatenation axis.
            num_heads: Number of heads to create; must be at least 1.
            rngs: Shared RNG container passed to every head's constructor.

        Raises:
            ValueError: If ``num_heads`` is smaller than 1.
        """
        # Fail at construction time rather than later inside
        # ``jnp.concatenate``, which rejects an empty list of arrays.
        if num_heads < 1:
            raise ValueError(f"num_heads must be >= 1, got {num_heads}")
        self.heads = [Attention(d_model, row_dim, col_dim, rngs=rngs)
                      for _ in range(num_heads)]
        self.col_dim = col_dim

    def __call__(self,
                 encodings_for_q: jax.Array,
                 encodings_for_k: jax.Array,
                 encodings_for_v: jax.Array,
                 mask: jax.Array | None = None) -> jax.Array:
        """Run every head on the inputs and concatenate the results.

        Args:
            encodings_for_q: Passed to each head as its query encodings.
            encodings_for_k: Passed to each head as its key encodings.
            encodings_for_v: Passed to each head as its value encodings.
            mask: Optional mask forwarded unchanged to every head; ``None``
                (the default) applies no masking.

        Returns:
            The head outputs concatenated along ``col_dim``.
        """
        outputs = [
            head(encodings_for_q, encodings_for_k, encodings_for_v, mask)
            for head in self.heads
        ]
        return jnp.concatenate(outputs, axis=self.col_dim)