### Imports

In [3]:
from flax import nnx
import jax
import jax.numpy as jnp

### Masked Self Attention Block

In [8]:
class MaskedSelfAttention(nnx.Module):
    def __init__(self, d_model: int = 2, 
                 row_dim: int = 0, 
                 col_dim: int = 1, 
                 *, 
                 rngs: nnx.Rngs):
        key = rngs.params()
        self.W_q = nnx.Param(jax.random.uniform(key, (d_model, d_model)))
        self.W_k = nnx.Param(jax.random.uniform(key, (d_model, d_model)))
        self.W_v = nnx.Param(jax.random.uniform(key, (d_model, d_model)))
        self.row_dim = row_dim
        self.col_dim = col_dim
        self.d_model = d_model

    def __call__(self, token_encodings: jax.Array, mask: jax.Array = None):
        q = token_encodings @ self.W_q
        k = token_encodings @ self.W_k
        v = token_encodings @ self.W_v

        sims = q @ k.swapaxes(self.row_dim, self.col_dim)
        scaled_sims = sims / jnp.sqrt(self.d_model)

        if mask is not None:
            # Mask out values with a large negative number
            scaled_sims = jnp.where(mask, -1e9, scaled_sims)

        attention_percents = jax.nn.softmax(scaled_sims, axis=self.col_dim)
        attention_scores = attention_percents @ v

        return attention_scores

### Manual Dry Run

In [9]:
# Create input and model
encodings_matrix = jnp.array([[1.16, 0.23],
                             [0.57, 1.36],
                             [4.41, -2.16]])

# Set random seed
key = jax.random.PRNGKey(42)

# Create masked self-attention object
masked_self_attention = MaskedSelfAttention(d_model=2,
                                          row_dim=0,
                                          col_dim=1,
                                          rngs=nnx.Rngs(params=key))

# Create causal mask (lower triangular)
mask = jnp.tril(jnp.ones((3, 3))) == 0
print("Mask:")
print(mask)

# Calculate masked self-attention
output = masked_self_attention(encodings_matrix, mask)
print("\nMasked Self-Attention Output:")
print(output)

Mask:
[[False  True  True]
 [False False  True]
 [False False False]]

Masked Self-Attention Output:
[[0.8224545  0.52411664]
 [1.3705109  0.99325544]
 [0.95412695 0.5512294 ]]


### Manual verification

In [10]:
# Print weights and verify calculations
print("\nQuery weights (W_q) transposed:")
print(masked_self_attention.W_q.value.T)

print("\nKey weights (W_k) transposed:")
print(masked_self_attention.W_k.value.T)

print("\nValue weights (W_v) transposed:")
print(masked_self_attention.W_v.value.T)

# Calculate intermediate values
q = encodings_matrix @ masked_self_attention.W_q.value
print("\nQueries (q):")
print(q)

k = encodings_matrix @ masked_self_attention.W_k.value
print("\nKeys (k):")
print(k)

v = encodings_matrix @ masked_self_attention.W_v.value
print("\nValues (v):")
print(v)

sims = q @ k.swapaxes(0, 1)
print("\nSimilarity scores (sims):")
print(sims)

scaled_sims = sims / jnp.sqrt(2)
print("\nScaled similarities:")
print(scaled_sims)

masked_scaled_sims = jnp.where(mask, -1e9, scaled_sims)
print("\nMasked scaled similarities:")
print(masked_scaled_sims)

attention_percents = jax.nn.softmax(masked_scaled_sims, axis=1)
print("\nAttention percentages:")
print(attention_percents)

attention_output = attention_percents @ v
print("\nFinal attention output:")
print(attention_output)


Query weights (W_q) transposed:
[[0.5302608  0.90153027]
 [0.31336212 0.6983329 ]]

Key weights (W_k) transposed:
[[0.5302608  0.90153027]
 [0.31336212 0.6983329 ]]

Value weights (W_v) transposed:
[[0.5302608  0.90153027]
 [0.31336212 0.6983329 ]]

Queries (q):
[[ 0.8224545   0.52411664]
 [ 1.5283298   1.1283492 ]
 [ 0.3911445  -0.12647225]]

Keys (k):
[[ 0.8224545   0.52411664]
 [ 1.5283298   1.1283492 ]
 [ 0.3911445  -0.12647225]]

Values (v):
[[ 0.8224545   0.52411664]
 [ 1.5283298   1.1283492 ]
 [ 0.3911445  -0.12647225]]

Similarity scores (sims):
[[0.9511297  1.8483683  0.25541237]
 [1.8483683  3.608964   0.45509294]
 [0.25541237 0.45509294 0.16898927]]

Scaled similarities:
[[0.67255026 1.3069937  0.18060382]
 [1.3069937  2.551923   0.3217993 ]
 [0.18060382 0.3217993  0.11949346]]

Masked scaled similarities:
[[ 6.7255026e-01 -1.0000000e+09 -1.0000000e+09]
 [ 1.3069937e+00  2.5519230e+00 -1.0000000e+09]
 [ 1.8060382e-01  3.2179931e-01  1.1949346e-01]]

Attention percentages:
[