In [2]:
import jax
from jax.experimental.pallas.ops.tpu import flash_attention as pallas_attention

BATCH = 1
HEADS = 1
SEQUENCE = 2048
HEAD_DIM = 128

Q = jax.random.normal( jax.random.key(0), (BATCH, SEQUENCE, HEADS, HEAD_DIM))
K = jax.random.normal( jax.random.key(1), (BATCH, SEQUENCE, HEADS, HEAD_DIM))
V = jax.random.normal( jax.random.key(2), (BATCH, SEQUENCE, HEADS, HEAD_DIM))

def attention_ourselves(_Q, _K, _V):
    _weights_unnormalized = jax.numpy.einsum("BSHD,BTHD->BHST", _Q, _K)
    _weights_unnormalized_to_zero_out = jax.numpy.triu( jax.numpy.ones((SEQUENCE,SEQUENCE), jax.numpy.bfloat16), 1)
    _weights = jax.nn.softmax(_weights_unnormalized - 1e6 * _weights_unnormalized_to_zero_out)  ### Creating something of size (B,HEADS, SEQUENCE, SEQUENCE)
    #print(f"{_weights.size=}")
    output = jax.numpy.einsum("BHST,BTHD->BSHD", _weights, _V)

    return output

attn_ourselves_value = attention_ourselves(Q,K,V)
dropout_rate = 0.1  # Replace with an appropriate value
attn_value = pallas_attention.mha_reference(Q, K, V, ab=dropout_rate, segment_ids=None, causal=True)
# attn_value = pallas_attention.mha_reference(Q, K, V,  segment_ids=None, causal=True)

assert jax.numpy.allclose(attn_ourselves_value, attn_value, atol=1e-1, rtol=1e-1)

AssertionError: 

In [5]:
import jax
from jax.experimental.pallas.ops.tpu import flash_attention as pallas_attention

BATCH = 1
HEADS = 1
SEQUENCE = 2048
HEAD_DIM = 128

Q = jax.random.normal(jax.random.key(0), (BATCH, SEQUENCE, HEADS, HEAD_DIM))
K = jax.random.normal(jax.random.key(1), (BATCH, SEQUENCE, HEADS, HEAD_DIM))
V = jax.random.normal(jax.random.key(2), (BATCH, SEQUENCE, HEADS, HEAD_DIM))

def attention_ourselves(_Q, _K, _V):
    _weights_unnormalized = jax.numpy.einsum("BSHD,BTHD->BHST", _Q, _K)
    _weights_unnormalized_to_zero_out = jax.numpy.triu(jax.numpy.ones((SEQUENCE, SEQUENCE), jax.numpy.bfloat16), 1)
    _weights = jax.nn.softmax(_weights_unnormalized - 1e6 * _weights_unnormalized_to_zero_out)  ### Creating something of size (B,HEADS, SEQUENCE, SEQUENCE)
    # print(f"{_weights.size=}")
    output = jax.numpy.einsum("BHST,BTHD->BSHD", _weights, _V)
    return output

attn_ourselves_value = attention_ourselves(Q, K, V)
attn_value = pallas_attention.mha_reference(Q, K, V, ab=None, segment_ids=None, causal=True)
# attn_value = pallas_attention.mha_reference(Q, K, V, segment_ids=None, causal=True)
print(f"{attention_ourselves=}")
print(f"{attn_value=}")

print(1e-1)
# Relax the tolerances to allow for potential numerical differences
# or implementation variations between the two attention functions.
assert jax.numpy.allclose(attn_ourselves_value, attn_value, atol=0.1, rtol=0.1), f"Arrays are not close: \n{jax.numpy.max(jax.numpy.abs(attn_ourselves_value - attn_value))}"

attention_ourselves=<function attention_ourselves at 0x78aafc6334c0>
attn_value=Array([[[[ 0.36057308,  1.2849717 , -0.7387236 , ...,  0.9781729 ,
           2.2189548 ,  0.38818341]],

        [[-1.5729874 , -0.76189977,  0.6108661 , ...,  0.42959312,
          -0.21007903,  2.0072043 ]],

        [[-0.7875979 , -0.11790697,  0.38122553, ..., -0.15842435,
           0.4373327 , -0.92844707]],

        ...,

        [[ 0.39558995,  0.74515975,  0.20654577, ...,  2.295514  ,
           0.30290228, -0.48013103]],

        [[-0.91627604, -0.6682399 , -1.6973673 , ...,  0.8714954 ,
          -1.0584216 , -0.708561  ]],

        [[ 0.60500056,  0.08252969,  1.6498895 , ..., -0.04127837,
          -0.15703496, -0.72441936]]]], dtype=float32)
0.1


AssertionError: Arrays are not close: 
6.370874404907227

## Fix the non causal by using correct einsum

In [7]:
import jax
from jax.experimental.pallas.ops.tpu import flash_attention as pallas_attention

BATCH = 1
HEADS = 1
SEQUENCE = 2048
HEAD_DIM = 128

Q = jax.random.normal(jax.random.key(0), (BATCH, SEQUENCE, HEADS, HEAD_DIM))
K = jax.random.normal(jax.random.key(1), (BATCH, SEQUENCE, HEADS, HEAD_DIM))
V = jax.random.normal(jax.random.key(2), (BATCH, SEQUENCE, HEADS, HEAD_DIM))

def attention_ourselves(_Q, _K, _V):
    # Batch Sequence Heads HeadsDimension turns into
    # Batch Heads Sequence T=OutputSequence
    _weights_unnormalized = jax.numpy.einsum("BSHD,BTHD->BHST", _Q, _K)
    _weights = jax.nn.softmax(_weights_unnormalized)  ### Creating something of size (B,HEADS, SEQUENCE, SEQUENCE)
    print(f"{_weights.size=}")
    ## This was wrong in original code, code in github was "BHST,BTHD->BSHD"
    # Question: Why was that wrong?
    output = jax.numpy.einsum("BHST,BSHD->BSHD", _weights, _V)
    return output

attn_ourselves_value = attention_ourselves(Q, K, V)
attn_value = pallas_attention.mha_reference(Q, K, V, ab=None, segment_ids=None, causal=False)
# attn_value = pallas_attention.mha_reference(Q, K, V, segment_ids=None, causal=True)
print(f"{attention_ourselves=}")
print(f"{attn_value=}")

print(1e-1)
# Relax the tolerances to allow for potential numerical differences
# or implementation variations between the two attention functions.
assert jax.numpy.allclose(attn_ourselves_value, attn_value, atol=0.1, rtol=0.1), f"Arrays are not close: \n{jax.numpy.max(jax.numpy.abs(attn_ourselves_value - attn_value))}"

_weights.size=4194304
attention_ourselves=<function attention_ourselves at 0x78a8c6bece00>
attn_value=Array([[[[ 0.36057308,  1.2849717 , -0.7387236 , ...,  0.9781729 ,
           2.2189548 ,  0.38818341]],

        [[-1.5729874 , -0.76189977,  0.6108661 , ...,  0.42959312,
          -0.21007903,  2.0072043 ]],

        [[-0.7875979 , -0.11790697,  0.38122553, ..., -0.15842435,
           0.4373327 , -0.92844707]],

        ...,

        [[ 0.39558995,  0.74515975,  0.20654577, ...,  2.295514  ,
           0.30290228, -0.48013103]],

        [[-0.91627604, -0.6682399 , -1.6973673 , ...,  0.8714954 ,
          -1.0584216 , -0.708561  ]],

        [[ 0.60500056,  0.08252969,  1.6498895 , ..., -0.04127837,
          -0.15703496, -0.72441936]]]], dtype=float32)
0.1


## Fix causal

In [2]:
import jax
from jax.experimental.pallas.ops.tpu import flash_attention as pallas_attention

BATCH = 1
HEADS = 1
SEQUENCE = 2048
HEAD_DIM = 128

Q = jax.random.normal(jax.random.key(0), (BATCH, SEQUENCE, HEADS, HEAD_DIM))
K = jax.random.normal(jax.random.key(1), (BATCH, SEQUENCE, HEADS, HEAD_DIM))
V = jax.random.normal(jax.random.key(2), (BATCH, SEQUENCE, HEADS, HEAD_DIM))

def attention_ourselves(_Q, _K, _V):
    _weights_unnormalized = jax.numpy.einsum("BSHD,BTHD->BHST", _Q, _K)
    _weights_unnormalized_to_zero_out = jax.numpy.triu( jax.numpy.ones((SEQUENCE,SEQUENCE), jax.numpy.bfloat16), 1)
    _weights = jax.nn.softmax(_weights_unnormalized - 1e6 * _weights_unnormalized_to_zero_out)  ### Creating something of size (B,HEADS, SEQUENCE, SEQUENCE)
    # weight is Seq x Seq
    print(f"{_weights.size=}")
    ## This was wrong in original code, code in github was "BHST,BTHD->BSHD"
    # Question: Why was that wrong?
    output = jax.numpy.einsum("BHST,BSHD->BSHD", _weights, _V)
    return output

attn_ourselves_value = attention_ourselves(Q, K, V)
attn_value = pallas_attention.mha_reference(Q, K, V, ab=None, segment_ids=None, causal=True)
# attn_value = pallas_attention.mha_reference(Q, K, V, segment_ids=None, causal=True)
print(f"{attention_ourselves=}")
print(f"{attn_value=}")

print(1e-1)
# Relax the tolerances to allow for potential numerical differences
# or implementation variations between the two attention functions.
assert jax.numpy.allclose(attn_ourselves_value, attn_value, atol=0.1, rtol=0.1), f"Arrays are not close: \n{jax.numpy.max(jax.numpy.abs(attn_ourselves_value - attn_value))}"

_weights.size=4194304
attention_ourselves=<function attention_ourselves at 0x7c3028082a20>
attn_value=Array([[[[ 0.36057308,  1.2849717 , -0.7387236 , ...,  0.9781729 ,
           2.2189548 ,  0.38818341]],

        [[-1.5729874 , -0.76189977,  0.6108661 , ...,  0.42959312,
          -0.21007903,  2.0072043 ]],

        [[-0.7875979 , -0.11790697,  0.38122553, ..., -0.15842435,
           0.4373327 , -0.92844707]],

        ...,

        [[ 0.39558995,  0.74515975,  0.20654577, ...,  2.295514  ,
           0.30290228, -0.48013103]],

        [[-0.91627604, -0.6682399 , -1.6973673 , ...,  0.8714954 ,
          -1.0584216 , -0.708561  ]],

        [[ 0.60500056,  0.08252969,  1.6498895 , ..., -0.04127837,
          -0.15703496, -0.72441936]]]], dtype=float32)
0.1
