In [59]:
import jax
import jax.numpy as jnp
from flax import nnx
from attentions import MultiHeadAttentionRoPE
from flax.nnx import MultiHeadAttention

batch_size, seq_len, in_features = 2, 4, 8
key1, key2, key3 = jax.random.split(jax.random.key(1), 3)
q = jax.random.normal(key1, (batch_size, seq_len, in_features))
k = jax.random.normal(key2, (batch_size, seq_len, in_features))
v = jax.random.normal(key3, (batch_size, seq_len, in_features))

In [60]:
rngs = nnx.Rngs(0)
mha_rope = MultiHeadAttentionRoPE(
    num_heads=4,
    in_features=in_features,
    qkv_features=8,
    out_features=8,
    rngs=rngs,
    use_rope=False,
    decode=False,
)

rngs = nnx.Rngs(0)
mha = MultiHeadAttention(
    num_heads=4,
    in_features=in_features,
    qkv_features=8,
    out_features=8,
    rngs=rngs,
    decode=False,
)


print("=============================MHA ROPE=============================")
out = mha_rope(q, k, v)
print("Output shape:", out.shape)
print("Output sample:", out[0, 0, :5])

print("================================MHA===============================")
out = mha(q, k, v)
print("Output shape:", out.shape)
print("Output sample:", out[0, 0, :5])

Output shape: (2, 4, 8)
Output sample: [ 0.06869456 -0.03120548 -0.16521749 -0.2619766   0.08072615]
Output shape: (2, 4, 8)
Output sample: [ 0.06869456 -0.03120548 -0.16521749 -0.2619766   0.08072615]


In [61]:
mha = MultiHeadAttention(
    num_heads=4,
    in_features=in_features,
    qkv_features=8,
    out_features=8,
    rngs=nnx.Rngs(1),
    decode=False,
)
mha_plain = MultiHeadAttentionRoPE(
    num_heads=4,
    in_features=in_features,
    qkv_features=8,
    out_features=8,
    rngs=nnx.Rngs(1),
    use_rope=False,
    decode=False,
)

out_plain = mha(q, k, v)

print("Same shape:", out.shape == out_plain.shape)
print("Allclose?", jnp.allclose(out, out_plain, atol=1e-5))

out_plain = mha_plain(q, k, v)

print("Same shape:", out.shape == out_plain.shape)
print("Allclose?", jnp.allclose(out, out_plain, atol=1e-5))



Same shape: True
Allclose? False
Same shape: True
Allclose? False


In [62]:
rngs = nnx.Rngs(0)
batch_size, seq_len, in_features = 2, 4, 8
num_heads = 4

q = jax.random.normal(jax.random.key(0), (batch_size, seq_len, in_features))

mha_rope = MultiHeadAttentionRoPE(
    num_heads=num_heads,
    in_features=in_features,
    qkv_features=8,
    out_features=8,
    rngs=nnx.Rngs(1),
    decode=False,
    dropout_rate=0.0,
    use_rope=True,
)

tri = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
causal_mask = jnp.broadcast_to(tri, (batch_size, num_heads, seq_len, seq_len))

out_full = mha_rope(q, mask=causal_mask, decode=False)

mha_decode = MultiHeadAttentionRoPE(
    num_heads=num_heads,
    in_features=in_features,
    qkv_features=8,
    out_features=8,
    rngs=nnx.Rngs(1),
    decode=True,
    dropout_rate=0.0,
    use_rope=True,
)

mha_decode.init_cache((batch_size, 1, in_features))

outs = []
for t in range(seq_len):
    out_t = mha_decode(q[:, t:t+1, :])
    outs.append(out_t)

out_step = jnp.concatenate(outs, axis=1)

print("Full output shape:", out_full.shape)
print("Step output shape:", out_step.shape)
print("Allclose:", jnp.allclose(out_full, out_step, atol=1e-4))

print("\nFirst token diff:")
print(out_full[0, 0, :5])
print(out_step[0, 0, :5])

Full output shape: (2, 4, 8)
Step output shape: (2, 4, 8)
Allclose: False

First token diff:
[ 0.1624817  1.6478726  0.7238274 -1.3535289  0.8455794]
[ 0.16221178  1.6474111   0.72350943 -1.3529046   0.84493995]
