### Imports

In [1]:
from flax import nnx
import jax
import jax.numpy as jnp

### Self attention block

![Attention Block](https://i.sstatic.net/t6qJz.png)

In [2]:
class SelfAttention(nnx.Module):
    def __init__(self, d_model: int = 2, 
                 row_dim: int = 0, 
                 col_dim: int = 1, 
                 *, 
                 rngs: nnx.Rngs):
        """
        d_model: the number of embedding values per token (default=2 for manual math)
        row_dim, col_dim: indices to access rows or columns
        rngs: random number generator state for initialization
        """
        # In Flax NNX, we initialize parameters directly in the Module
        # Using nnx.Param for trainable parameters
        key = rngs.params()
        
        # Initialize weight matrices (equivalent to nn.Linear without bias)
        # Shape: (d_model, d_model) for each W_q, W_k, W_v
        self.W_q = nnx.Param(jax.random.uniform(key, (d_model, d_model)))
        self.W_k = nnx.Param(jax.random.uniform(key, (d_model, d_model)))
        self.W_v = nnx.Param(jax.random.uniform(key, (d_model, d_model)))
        
        # Store static attributes directly
        self.row_dim = row_dim
        self.col_dim = col_dim
        self.d_model = d_model

    def __call__(self, token_encodings: jax.Array):
        """Forward pass of self-attention"""
        # Create query, key, and values by matrix multiplication
        # In JAX, we use @ for matrix multiplication
        q = token_encodings @ self.W_q
        k = token_encodings @ self.W_k
        v = token_encodings @ self.W_v

        # Compute similarity scores: (q * k^T)
        # Transpose k using swapaxes for the specified dimensions
        k_t = k.swapaxes(self.row_dim, self.col_dim)
        sims = q @ k_t

        # Scale similarities by sqrt(d_model)
        scaled_sims = sims / jnp.sqrt(self.d_model)

        # Apply softmax along col_dim
        attention_percents = jax.nn.softmax(scaled_sims, axis=self.col_dim)

        # Compute final attention scores
        attention_scores = attention_percents @ v

        return attention_scores

### Manual dry run

In [3]:
# Create input tensor
encodings_matrix = jnp.array([[1.16, 0.23],
                             [0.57, 1.36],
                             [4.41, -2.16]])

# Set random seed and create RNG key
key = jax.random.PRNGKey(42)

# Initialize the self-attention module
self_attention = SelfAttention(d_model=2,
                             row_dim=0,
                             col_dim=1,
                             rngs=nnx.Rngs(params=key))

# Calculate attention scores
attention_output = self_attention(encodings_matrix)

# Display results
print("Attention output:")
print(attention_output)
nnx.display(self_attention)

Attention output:
[[1.1276929  0.73620886]
 [1.2950552  0.9069855 ]
 [0.95412695 0.5512294 ]]
