<a href="https://colab.research.google.com/github/ybw9000/jax_playground/blob/main/jax_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as jnp
import jax.nn as nn

In [2]:
devices = jax.devices()

In [3]:
devices

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [44]:
def attention(query: jnp.ndarray, key: jnp.ndarray, value: jnp.ndarray):
    mm0 = jnp.matmul(query, jnp.transpose(key, [0, 1, 3, 2]))
    sm = nn.softmax(mm0 / jnp.sqrt(query.shape[-1]), axis=-1)
    return jnp.matmul(sm, value)

In [8]:
key = jax.random.PRNGKey(42)

In [9]:
x = jax.random.normal(key, [2, 4, 128, 64], jnp.bfloat16)
x = x.to_device(devices[0])

In [10]:
def custom_softmax(x: jnp.ndarray, axis: int):
    max_x = jnp.max(x, axis=axis, keepdims=True)
    exp_x = jnp.exp(x - max_x)
    sum_x = jnp.sum(exp_x, axis=axis, keepdims=True)
    return exp_x / sum_x

In [11]:
with jax.default_device(devices[0]):
    test_input = jax.random.normal(key, [2, 128], jnp.bfloat16)
    test_custom = custom_softmax(x, axis=-1)
    test_jax = nn.softmax(x, axis=-1)
    print(jnp.allclose(test_custom, test_jax))


True


In [106]:
def flashattention(query: jnp.ndarray, key: jnp.ndarray, value: jnp.ndarray, blocksize_q: int, blocksize_kv: int):
    query = query.reshape(query.shape[0], query.shape[1], query.shape[2] // blocksize_q, 1, blocksize_q, query.shape[3])
    key = key.reshape(key.shape[0], key.shape[1], 1, key.shape[2] // blocksize_kv, blocksize_kv, key.shape[3])
    key = key.transpose(0, 1, 2, 3, 5, 4)
    mm0 = jnp.matmul(query, key) / jnp.sqrt(query.shape[-1])  # b, h, num_blocks_q, num_blocks_kv, blocksize_q, block_size_kv
    softmax_max = jnp.max(mm0, axis=-1, keepdims=True)  # b, h, num_blocks_q, num_blocks_kv, blocksize_q, 1
    softmax_exp = jnp.exp(mm0 - softmax_max) # b, h, num_blocks_q, num_blocks_kv, blocksize_q, blocksize_kv
    value = value.reshape(value.shape[0], value.shape[1], 1, value.shape[2] // blocksize_kv, blocksize_kv, value.shape[3])
    mm1_pristine = jnp.matmul(softmax_exp, value)  # b, h, num_blocks_q, num_blocks_kv, blocksize_q, d
    softmax_block_max = jnp.max(softmax_max, axis=-3, keepdims=True)  # b, h, num_blocks_q, 1, blocksize_q, 1
    softmax_max_offset = softmax_max - softmax_block_max # b, h, num_blocks_q, num_blocks_kv, blocksize_q, 1
    softmax_exp_offset = jnp.exp(softmax_max_offset)  # b, h, num_blocks_q, num_blocks_kv, blocksize_q, 1
    mm1_exp_scaled = jnp.multiply(mm1_pristine, softmax_exp_offset)  # b, h, num_blocks_q, num_blocks_kv, blocksize_q, d
    mm1_scaled_sum = jnp.sum(mm1_exp_scaled, axis=-3, keepdims=True)  # b, h, num_blocks_q, 1, blocksize_q, d
    softmax_exp_scaled = jnp.multiply(softmax_exp, softmax_exp_offset)  # b, h, num_blocks_q, num_blocks_kv, blocksize_q, blocksize_kv
    softmax_sum = jnp.sum(softmax_exp_scaled, axis=[-3, -1], keepdims=True)  # b, h, num_blocks_q, 1, blocksize_q, 1
    res = jnp.divide(mm1_scaled_sum, softmax_sum)  # b, h, num_blocks_q, 1, blocksize_q, d
    return res.reshape(res.shape[0], res.shape[1], -1, res.shape[-1])  # b, h, q, d

In [63]:
y = attention(x, x, x)

In [4]:
def dot_product(x, y):
  return jnp.dot(x, y)

In [6]:
gemmv = jax.vmap(dot_product, in_axes=(0, None), out_axes=0)
gemm = jax.vmap(gemmv, in_axes=(None, 1), out_axes=1)

In [10]:
with jax.default_device(devices[0]):
    test_x = jax.random.normal(key, [4, 2], jnp.bfloat16)
    test_y = jax.random.normal(key, [2, 8], jnp.bfloat16)
    test_custom = gemm(test_x, test_y)
    test_jax = jnp.matmul(test_x, test_y)
    print(jnp.allclose(test_custom, test_jax))

True


In [58]:
with jax.default_device(devices[0]):
    test_x = jax.random.normal(key, [4, 2], jnp.bfloat16)
    test_y = jax.random.normal(key, [2, 8], jnp.bfloat16)
    test_custom = gemm(test_x, test_y)
    jaxpr = jax.make_jaxpr(gemm)(test_x, test_y)
    print(jaxpr)

{ lambda ; a:bf16[4,2] b:bf16[2,8]. let
    c:bf16[4,8] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=bfloat16
    ] a b
  in (c,) }


In [99]:
from functools import partial

def attention_2d(query: jnp.ndarray, key: jnp.ndarray, value: jnp.ndarray):
    # query: d
    # key, value: sk x d
    key_t = key.transpose()  # sk x d
    mm0 = jnp.matmul(query, key_t)  # sk
    mm0_scale = mm0 / jnp.sqrt(query.shape[-1])  # sk
    mm0_max = jnp.max(mm0_scale, axis=-1, keepdims=True)  # 1
    mm0_exp = jnp.exp(mm0_scale - mm0_max)  # sk
    mm1 = jnp.matmul(mm0_exp, value)  # d
    return mm1, mm0_max, mm0_exp

def attention_kv_tiling(query: jnp.ndarray, key: jnp.ndarray, value: jnp.ndarray):
    # query: d
    # key, value: num_sk x sk_size x d
    attention_out, attention_max, attention_exp = jax.vmap(attention_2d, in_axes=(None, 0, 0))(query, key, value)  # num_sk, d; num_sk, 1; num_sk, sk_size
    attention_max_global = attention_max.max(axis=0, keepdims=True)  # 1, 1
    attention_max_offset = attention_max - attention_max_global  # num_sk, 1
    attention_exp_offset = jnp.exp(attention_max_offset)  # num_sk, 1
    attention_out_scaled = jnp.multiply(attention_out, attention_exp_offset)  # num_sk, d
    attention_out_sum = jnp.sum(attention_out_scaled, axis=0, keepdims=False)  # d
    attention_exp_scaled = jnp.multiply(attention_exp, attention_exp_offset)  # num_sk, sk_size
    attention_sm_sum = jnp.sum(attention_exp_scaled, keepdims=False)  # 1
    return attention_out_sum / attention_sm_sum

def attention_kv_tiling(query: jnp.ndarray, key: jnp.ndarray, value: jnp.ndarray):
    # query: d
    # key, value: num_sk x sk_size x d
    attention_out, attention_max, attention_exp = jax.vmap(attention_2d, in_axes=(None, 0, 0))(query, key, value)  # num_sk, d; num_sk, 1; num_sk, sk_size
    attention_max_global = attention_max.max(axis=0, keepdims=True)  # 1, 1
    attention_max_offset = attention_max - attention_max_global  # num_sk, 1
    attention_exp_offset = jnp.exp(attention_max_offset)  # num_sk, 1
    attention_out_scaled = jnp.multiply(attention_out, attention_exp_offset)  # num_sk, d
    attention_out_sum = jnp.sum(attention_out_scaled, axis=0, keepdims=False)  # d
    attention_exp_scaled = jnp.multiply(attention_exp, attention_exp_offset)  # num_sk, sk_size
    attention_sm_sum = jnp.sum(attention_exp_scaled, keepdims=False)  # 1
    return attention_out_sum / attention_sm_sum


def attention_kv_looping(query: jnp.ndarray, key: jnp.ndarray, value: jnp.ndarray):
    # query: d
    # key, value: num_sk x sk_size x d
    attention_accum, attention_max, attention_exp = attention_2d(query, key[0], value[0])  # d; 1; sk_size
    for i in range(1, len(key)):
        attention_out, attention_max_i, attention_exp_i = attention_2d(query, key[i], value[i])  # d; 1; sk_size
        attention_max_offset = attention_max - attention_max_i # 1
        attenion_exp_offset = jnp.exp(attention_max_offset) # 1
        attention_accum *= attenion_exp_offset # d
        attention_exp *= attenion_exp_offset # sk_size
        attention_max = jnp.maximum(attention_max, attention_max_i) # 1
        attention_exp_i_offset = jnp.exp(attention_max_i - attention_max) # 1
        attention_accum += attention_out * attention_exp_i_offset # d
        attention_exp += attention_exp_i * attention_exp_i_offset # sk_size
    attention_accum /= attention_exp.sum()  # d
    return attention_accum


def attention_kv_looping_stable(query: jnp.ndarray, key: jnp.ndarray, value: jnp.ndarray):
    # query: d
    # key, value: num_sk x sk_size x d
    attention_accum, attention_max, attention_exp = attention_2d(query, key[0], value[0])  # d; 1; sk_size
    attention_exp_sum = attention_exp.sum()  # 1
    attention_accum /= attention_exp_sum  # d
    for i in range(1, len(key)):
        attention_out, attention_max_i, attention_exp_i = attention_2d(query, key[i], value[i])  # d; 1; sk_size
        attention_max_offset = attention_max - attention_max_i # 1
        attenion_exp_offset = jnp.exp(attention_max_offset) # 1
        attention_max = jnp.maximum(attention_max, attention_max_i) # 1
        attention_exp_i_offset = jnp.exp(attention_max_i - attention_max) # 1
        attention_accum = attention_accum * attention_exp_sum * attenion_exp_offset + attention_out * attention_exp_i_offset # d
        attention_exp_sum = attention_exp_sum * attenion_exp_offset + attention_exp_i.sum() * attention_exp_i_offset # 1
        attention_accum /= attention_exp_sum # d
    return attention_accum


def attention_q_tiling(query: jnp.ndarray, key: jnp.ndarray, value: jnp.ndarray):
    # query: sq_size x d
    # key, value: num_sk x sk_size x d
    return jax.vmap(attention_kv_looping_stable, in_axes=(0, None, None))(query, key, value)  # sq_size, d

def attention_q_blocking(query: jnp.ndarray, key: jnp.ndarray, value: jnp.ndarray):
    # query: num_sq x sq_size x d
    # key, value: num_sk x sk_size x d
    return jax.vmap(attention_q_tiling, in_axes=(0, None, None))(query, key, value).reshape(-1, query.shape[-1])  # num_sq x sq_size, d

def fa_(query: jnp.ndarray, key: jnp.ndarray, value: jnp.ndarray, blocksize_q: int, blocksize_kv: int):
    # query: sq x d
    # key, value: sk x d
    query = query.reshape(query.shape[0]// blocksize_q, blocksize_q, query.shape[1])
    key = key.reshape(key.shape[0] // blocksize_kv, blocksize_kv, key.shape[1])
    value = value.reshape(value.shape[0] // blocksize_kv, blocksize_kv, value.shape[1])
    return attention_q_blocking(query, key, value)

@partial(jax.jit, static_argnames=("blocksize_q", "blocksize_kv"))
def fa_full(query: jnp.ndarray, key: jnp.ndarray, value: jnp.ndarray, blocksize_q: int, blocksize_kv: int):
    # query: b x h x sq x d
    # key, value: b x h x sk x d
    fa_partial = partial(fa_, blocksize_q=blocksize_q, blocksize_kv=blocksize_kv)
    return jax.vmap(jax.vmap(fa_partial))(query, key, value)

In [15]:
def attention_h(query: jnp.ndarray, key: jnp.ndarray, value: jnp.ndarray):
    mm0 = jnp.matmul(query, jnp.transpose(key, [0, 2, 1]))
    sm = nn.softmax(mm0 / jnp.sqrt(query.shape[-1]), axis=-1)
    return jnp.matmul(sm, value)

In [42]:
x = jax.random.normal(key, [2, 2, 8, 4], jnp.bfloat16)
x = x.to_device(devices[0])

In [46]:
at_full = attention(x, x, x)

In [100]:
at_fa = fa_full(x, x, x, 4, 4)
at_fa

Array([[[[0.664062, -1.65625, 0.490234, -2.03125],
         [0.269531, 0.00402832, 0.667969, -0.0100708],
         [1.15625, -2.23438, 0.699219, -2.09375],
         [-0.878906, 1.54688, -2.89062, 0.953125],
         [-0.143555, 0.59375, -0.898438, 0.209961],
         [0.163086, -0.189453, -0.546875, -0.0385742],
         [-1.23438, -1.50781, 0.121094, 1.5625],
         [-0.355469, -0.267578, 1, 1.11719]],

        [[-0.699219, 0.18457, 1.21094, 1.5],
         [-0.554688, -0.137695, 1.42969, 0.933594],
         [-0.972656, 0.279297, 2.40625, 1.84375],
         [0.480469, 1.34375, -1.04688, -1.25],
         [1.05469, -2.04688, -0.00469971, -0.259766],
         [-0.0732422, -0.648438, 0.0844727, -0.15918],
         [-0.289062, 0.210938, 0.902344, 1.42188],
         [-1.5625, 0.240234, 0.316406, -0.554688]]],


       [[[-0.324219, -0.347656, 0.566406, 0.710938],
         [-0.800781, -0.178711, 0.664062, 0.582031],
         [-0.19043, -1.08594, 0.746094, 0.570312],
         [-1.38281, -0.0

In [95]:
at_fa

Array([[[[0.664062, -1.65625, 0.492188, -2.03125],
         [0.267578, 0.00390625, 0.664062, -0.0117798],
         [1.16406, -2.23438, 0.699219, -2.10938],
         [-0.878906, 1.54688, -2.89062, 0.953125],
         [-0.144531, 0.59375, -0.894531, 0.209961],
         [0.163086, -0.189453, -0.546875, -0.0378418],
         [-1.23438, -1.50781, 0.121094, 1.5625],
         [-0.355469, -0.267578, 1, 1.11719]],

        [[-0.699219, 0.18457, 1.20312, 1.5],
         [-0.554688, -0.137695, 1.42188, 0.933594],
         [-0.976562, 0.279297, 2.40625, 1.85156],
         [0.480469, 1.34375, -1.04688, -1.25],
         [1.05469, -2.04688, -0.0045166, -0.259766],
         [-0.0732422, -0.648438, 0.0844727, -0.15918],
         [-0.289062, 0.210938, 0.90625, 1.42188],
         [-1.5625, 0.240234, 0.316406, -0.554688]]],


       [[[-0.322266, -0.349609, 0.566406, 0.710938],
         [-0.800781, -0.178711, 0.664062, 0.582031],
         [-0.19043, -1.07812, 0.75, 0.570312],
         [-1.38281, -0.0825195

In [83]:
at_full

Array([[[[0.632812, -1.57812, 0.451172, -1.89062],
         [0.236328, 0.012207, 0.660156, 0.0751953],
         [1.16406, -2.23438, 0.695312, -2.09375],
         [-0.863281, 1.52344, -2.84375, 0.941406],
         [-0.0869141, 0.527344, -0.753906, 0.24707],
         [0.162109, -0.189453, -0.546875, -0.0383301],
         [-1.23438, -1.50781, 0.120605, 1.5625],
         [-0.353516, -0.267578, 1, 1.10938]],

        [[-0.589844, 0.175781, 1.04688, 1.39844],
         [-0.435547, -0.296875, 1.10156, 0.695312],
         [-0.949219, 0.28125, 2.34375, 1.82031],
         [0.308594, 1.11719, -0.902344, -1.11719],
         [1.05469, -2.04688, -0.00430298, -0.259766],
         [-0.0722656, -0.648438, 0.0844727, -0.160156],
         [-0.289062, 0.210938, 0.902344, 1.42188],
         [-1.5625, 0.240234, 0.316406, -0.554688]]],


       [[[-0.304688, -0.34375, 0.566406, 0.71875],
         [-0.628906, -0.0839844, 0.679688, 0.636719],
         [-0.0576172, -1.04688, 0.664062, 0.671875],
         [-1.109

In [84]:
jax_pr = jax.make_jaxpr(fa_full, static_argnums=[3, 4])(x, x, x, 4, 4)

In [85]:
jax_pr

{ lambda ; a:bf16[2,2,8,4] b:bf16[2,2,8,4] c:bf16[2,2,8,4]. let
    d:bf16[2,2,2,4,4] = reshape[dimensions=None new_sizes=(2, 2, 2, 4, 4)] a
    e:bf16[2,2,2,4,4] = reshape[dimensions=None new_sizes=(2, 2, 2, 4, 4)] b
    f:bf16[2,2,2,4,4] = reshape[dimensions=None new_sizes=(2, 2, 2, 4, 4)] c
    g:bf16[2,2,1,4,4] = slice[
      limit_indices=(2, 2, 1, 4, 4)
      start_indices=(0, 0, 0, 0, 0)
      strides=None
    ] e
    h:bf16[2,2,4,4] = squeeze[dimensions=(2,)] g
    i:bf16[2,2,1,4,4] = slice[
      limit_indices=(2, 2, 1, 4, 4)
      start_indices=(0, 0, 0, 0, 0)
      strides=None
    ] f
    j:bf16[2,2,4,4] = squeeze[dimensions=(2,)] i
    k:bf16[2,2,4,4] = transpose[permutation=(0, 1, 3, 2)] h
    l:bf16[2,2,2,4,4] = dot_general[
      dimension_numbers=(([4], [2]), ([0, 1], [0, 1]))
      preferred_element_type=bfloat16
    ] d k
    m:f32[] = sqrt 4.0
    n:bf16[] = convert_element_type[new_dtype=bfloat16 weak_type=False] m
    o:bf16[2,2,2,4,4] = div l n
    p:bf16[2,2,2,4