### Imports

In [3]:
from flax import nnx
import jax
import jax.numpy as jnp

### Masked Self-Attention Block

In [8]:
class MaskedSelfAttention(nnx.Module):
    """Single-head masked scaled dot-product self-attention.

    Projects the token encodings to queries, keys and values with three
    independently initialized ``(d_model, d_model)`` weight matrices, computes
    ``softmax(Q @ K^T / sqrt(d_model))`` (optionally masked) and returns the
    attention-weighted combination of the values.

    Args:
        d_model: Size of each token encoding and of every projection.
        row_dim: Axis of the key matrix swapped with ``col_dim`` to form K^T.
        col_dim: Axis of the key matrix swapped with ``row_dim`` to form K^T;
            also the axis the softmax normalizes over.
        rngs: Flax NNX RNG streams. One ``params`` key is drawn and split into
            three, so each projection receives its own initialization.
    """

    def __init__(self, d_model: int = 2,
                 row_dim: int = 0,
                 col_dim: int = 1,
                 *,
                 rngs: nnx.Rngs):
        # Split a single params key into three subkeys. Reusing one key for
        # all three uniform draws would make W_q, W_k and W_v identical.
        key_q, key_k, key_v = jax.random.split(rngs.params(), 3)
        self.W_q = nnx.Param(jax.random.uniform(key_q, (d_model, d_model)))
        self.W_k = nnx.Param(jax.random.uniform(key_k, (d_model, d_model)))
        self.W_v = nnx.Param(jax.random.uniform(key_v, (d_model, d_model)))
        self.row_dim = row_dim
        self.col_dim = col_dim
        self.d_model = d_model

    def __call__(self, token_encodings: jax.Array,
                 mask: jax.Array | None = None) -> jax.Array:
        """Apply masked self-attention to ``token_encodings``.

        Args:
            token_encodings: Input encodings whose last axis has size
                ``d_model``. With the default ``row_dim=0, col_dim=1`` this is
                presumably 2-D ``(seq_len, d_model)``. TODO: confirm. For
                batched input, ``row_dim``/``col_dim`` would need to name the
                last two axes.
            mask: Optional array broadcastable to the similarity matrix.
                Positions where it is truthy are excluded from attention: they
                are replaced with -1e9 before the softmax.

        Returns:
            The attention output, i.e. the attention weights applied to the
            value projections.
        """
        q = token_encodings @ self.W_q
        k = token_encodings @ self.W_k
        v = token_encodings @ self.W_v

        # Raw similarities Q @ K^T, scaled by sqrt(d_model) to keep the
        # softmax from saturating as d_model grows.
        sims = q @ k.swapaxes(self.row_dim, self.col_dim)
        scaled_sims = sims / jnp.sqrt(self.d_model)

        if mask is not None:
            # A large negative value drives masked positions to ~0 after the
            # softmax.
            scaled_sims = jnp.where(mask, -1e9, scaled_sims)

        attention_percents = jax.nn.softmax(scaled_sims, axis=self.col_dim)
        output = attention_percents @ v

        return output