Merged
85 changes: 50 additions & 35 deletions test/test_pallas.py
@@ -20,6 +20,12 @@

class PallasTest(unittest.TestCase):

  def _attention(self, q, k, v):
    attn_weight = q @ k.transpose(-2, -1)
    attn_weight = nn.functional.softmax(attn_weight, dim=-1)
    attn_output = attn_weight @ v
    return attn_output

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_tpu_custom_call_pallas_add(self):
# This payload is generated by the following Pallas code:
@@ -80,13 +86,7 @@ def test_tpu_custom_call_pallas_flash_attention(self):
v = torch.ones(3, 2, 128, 4).to("xla")
o = torch.zeros(3, 2, 128, 4).to("xla")

def attention(q, k, v):
attn_weight = q @ k.transpose(-2, -1)
attn_weight = nn.functional.softmax(attn_weight, dim=-1)
attn_output = attn_weight @ v
return attn_output

expected_o = attention(q, k, v)
expected_o = self._attention(q, k, v)

torch_xla._XLAC._xla_tpu_custom_call_([o], [q, k, v], payload)
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))
@@ -182,12 +182,6 @@ def test_tpu_custom_call_pallas_wrap_flash_attention(self):
flash_attention_kernel = make_kernel_from_pallas(
flash_attention, lambda q, k, v: [(q.shape, q.dtype)])

def attention(q, k, v):
attn_weight = q @ k.transpose(-2, -1)
attn_weight = nn.functional.softmax(attn_weight, dim=-1)
attn_output = attn_weight @ v
return attn_output

q_mini = torch.arange(128 * 4, dtype=torch.bfloat16).reshape(128, 4) / 13
k_mini = torch.arange(
1000, 1000 + 128 * 4, dtype=torch.bfloat16).reshape(128, 4) / 13
@@ -196,7 +190,7 @@ def attention(q, k, v):
v = torch.ones(3, 2, 128, 4, dtype=torch.bfloat16).to("xla")

o = flash_attention_kernel(q, k, v)
expected_o = attention(q, k, v)
expected_o = self._attention(q, k, v)
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
@@ -205,18 +199,12 @@ def test_flash_attention_wrapper(self):
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
from torch_xla.experimental.custom_kernel import flash_attention

def attention(q, k, v):
attn_weight = q @ k.transpose(-2, -1)
attn_weight = nn.functional.softmax(attn_weight, dim=-1)
attn_output = attn_weight @ v
return attn_output

q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")

o = flash_attention(q, k, v)
expected_o = attention(q, k, v)
expected_o = self._attention(q, k, v)
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)

@@ -226,12 +214,6 @@ def test_flash_attention_wrapper_with_dynamo(self):
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
from torch_xla.experimental.custom_kernel import flash_attention

def attention(q, k, v):
attn_weight = q @ k.transpose(-2, -1)
attn_weight = nn.functional.softmax(attn_weight, dim=-1)
attn_output = attn_weight @ v
return attn_output

def flash_attention_wrapper(q, k, v, causal=False):
return torch.ops.xla.flash_attention(q, k, v, causal)

@@ -243,7 +225,7 @@ def flash_attention_wrapper(q, k, v, causal=False):
flash_attention_wrapper, backend="openxla")
o_no_causal = compiled_flash_attention(q, k, v)
o_with_causal = compiled_flash_attention(q, k, v, causal=True)
expected_o = attention(q, k, v)
expected_o = self._attention(q, k, v)
self.assertTrue(torch.allclose(o_no_causal.cpu(), expected_o.cpu()))
# The causal mask is turned on by default in the wrapper.
# It masks out the top right triangle of the attention matrix,
@@ -257,20 +239,14 @@ def test_flash_attention_wrapper_causal(self):
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
from torch_xla.experimental.custom_kernel import flash_attention

def attention(q, k, v):
attn_weight = q @ k.transpose(-2, -1)
attn_weight = nn.functional.softmax(attn_weight, dim=-1)
attn_output = attn_weight @ v
return attn_output

q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")

# The causal mask is turned on by default in the wrapper.
# It masks out the top right triangle of the attention matrix, therefore it speeds up the compute but also changes the output.
o = flash_attention(q, k, v, causal=True)
expected_o = attention(q, k, v)
expected_o = self._attention(q, k, v)
self.assertFalse(torch.allclose(o.cpu(), expected_o.cpu()))
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
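
As an aside, here is a tiny self-contained sketch (not part of the test file) of what the comment above means by masking out the top-right triangle of the attention matrix: with a causal mask, each query position only attends to itself and earlier key positions, so the result differs from unmasked attention.

```python
import torch

seq_len = 4
scores = torch.randn(seq_len, seq_len)  # q @ k.T for a single head

# Causal mask: True on/below the diagonal, False in the top-right triangle.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
masked_scores = scores.masked_fill(~causal_mask, float("-inf"))

causal_weights = torch.softmax(masked_scores, dim=-1)
print(causal_weights)  # row i has zero weight on columns j > i
```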

@@ -467,6 +443,45 @@ def test__flash_attention_bwd_dkv(self):
# TODO: I don't really know how to test the value. Let's do the shape check for now.
self.assertEqual(grad_q.shape, (3, 2, 128, 4))

  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                   "This test only works on TPUv3+.")
  def test_flash_attention_backward(self):
    from torch_xla.experimental.custom_kernel import flash_attention

    torch.manual_seed(42)
    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    q.retain_grad()
    k.retain_grad()
    v.retain_grad()

    o = flash_attention(q, k, v)
    loss = o.sum()
    loss.backward()
    xm.mark_step()

    q_grad = q.grad
    k_grad = k.grad
    v_grad = v.grad

    torch.manual_seed(42)
    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    q.retain_grad()
    k.retain_grad()
    v.retain_grad()

    o = self._attention(q, k, v)
    loss = o.sum()
    loss.backward()
    xm.mark_step()

    mse = torch.nn.MSELoss()
    for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
      self.assertTrue(mse(i[0].grad.cpu(), i[1].cpu()) < 1e-4)
Comment on lines +481 to +483

Collaborator: I am not sure what this part is checking, do you mind explaining a bit?

Collaborator (Author): Since the gradients are a little bit off, it's hard to use torch.allclose. I'm just trying to use MSE to calculate the difference to see if it's close to zero.
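
A minimal sketch (not from the PR) of the trade-off being described: with a small per-element error, torch.allclose at its default tolerances fails, while the mean squared error stays far below the 1e-4 bound used in the test.

```python
import torch

torch.manual_seed(0)
ref = torch.randn(128, 8)                      # stand-in for the reference gradient
approx = ref + 1e-3 * torch.randn_like(ref)    # slightly-off kernel gradient

print(torch.allclose(approx, ref))             # False: default rtol=1e-5, atol=1e-8 are too strict
print(torch.nn.MSELoss()(approx, ref).item())  # ~1e-6, comfortably under the 1e-4 threshold
```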



if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
204 changes: 169 additions & 35 deletions torch_xla/experimental/custom_kernel.py
@@ -160,47 +160,181 @@ def wrapped_kernel(kernel: Callable,
return functools.partial(wrapped_kernel, kernel, output_shape_dtype_fn)


# This is a simplified wrapper on top of https://github.com/google/jax/blob/b2058d72b7e1693a41303d5411572aabf99b7981/jax/experimental/pallas/ops/tpu/flash_attention.py#L139
# where we only takes q, k, v, segment_ids and causal as input and set block_sizes for the users.
class FlashAttention(torch.autograd.Function):
"""
This is a simplified wrapper on top of https://github.com/google/jax/blob/b2058d72b7e1693a41303d5411572aabf99b7981/jax/experimental/pallas/ops/tpu/flash_attention.py#L139
where we only take q, k, v and causal as input and set the block_sizes for the user.
"""

MIN_BLOCK_SIZE = 128
DEFAULT_MASK_VALUE = -0.7 * float(torch.finfo(torch.float32).max)
# The block_sizes configuration is copied from https://github.com/google/maxtext/blob/0fee320451738166c8e596dc63a57a4673671576/MaxText/layers/attentions.py#L215-L240
# It yields much better performance than the default block_sizes.
DEFAULT_BLOCK_SIZES = {
"block_q": 512,
"block_k_major": 512,
"block_k": 512,
"block_b": 2,
"block_q_major_dkv": 512,
"block_k_major_dkv": 512,
"block_q_dkv": 512,
"block_k_dkv": 512,
"block_q_dq": 1024,
"block_k_dq": 256,
"block_k_major_dq": 512,
}

@staticmethod
def forward(ctx, q, k, v, causal=False):
# Import JAX within the function so that we don't need to call jax_import_guard()
# in the global scope, which could cause problems for xmp.spawn.
jax_import_guard()
from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_impl

ctx.causal = causal
save_residuals = q.requires_grad or k.requires_grad or v.requires_grad

# It returns the shape and type of o, l, m.
def shape_dtype(q, *arg):
if not save_residuals:
return [(q.shape, q.dtype)]
res_shape = list(q.shape)
res_shape[-1] = FlashAttention.MIN_BLOCK_SIZE
return [(q.shape, q.dtype), (res_shape, torch.float32),
(res_shape, torch.float32)]

# We can't use flash_attention directly because we need to override the save_residuals flag, which returns
# the l and m that are needed for the backward pass. As a result, we lose all of its shape checks.
# TODO: replicate the shape checks on flash_attention.
_flash_attention_impl = make_kernel_from_pallas(_flash_attention_impl,
shape_dtype)
with torch.no_grad():
o = _flash_attention_impl(
q,
k,
v,
None,
None,
save_residuals,
causal,
1.0,
min(FlashAttention.DEFAULT_BLOCK_SIZES["block_b"], q.shape[0]),
min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q"], q.shape[2]),
min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2]),
min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2]),
False,
static_argnums=range(5, 13))
if not save_residuals:
return o
o, *aux = o
l, m = (v[..., 0] for v in aux[-2:])

ctx.save_for_backward(q, k, v, o, l, m)
return o

@staticmethod
def backward(ctx, grad_output):
from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv

q, k, v, o, l, m = ctx.saved_tensors
causal = ctx.causal
grad_q = grad_k = grad_v = None

grad_i = torch.sum(
o.to(torch.float32) * grad_output.to(torch.float32),
axis=-1) # [batch_size, num_heads, q_seq_len]

expanded_l = l.unsqueeze(-1).expand([-1 for _ in l.shape] +
[FlashAttention.MIN_BLOCK_SIZE])
expanded_m = m.unsqueeze(-1).expand([-1 for _ in m.shape] +
[FlashAttention.MIN_BLOCK_SIZE])
expanded_grad_i = grad_i.unsqueeze(-1).expand(
[-1 for _ in grad_i.shape] + [FlashAttention.MIN_BLOCK_SIZE])

if ctx.needs_input_grad[0]:
payload, _ = trace_pallas(
_flash_attention_bwd_dq,
q,
k,
v,
None,
None,
l,
m,
grad_output,
grad_i,
block_q_major=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q_dq"],
q.shape[2]),
block_k_major=min(
FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major_dq"],
k.shape[2]),
block_k=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_dq"],
k.shape[2]),
sm_scale=1.0,
causal=causal,
mask_value=FlashAttention.DEFAULT_MASK_VALUE,
debug=False,
static_argnames=[
"block_q_major", "block_k_major", "block_k", "sm_scale", "causal",
"mask_value", "debug"
])
grad_q = torch.empty(q.shape, dtype=q.dtype).to(q.device)
torch_xla._XLAC._xla_tpu_custom_call_(
[grad_q],
[q, k, v, expanded_l, expanded_m, grad_output, expanded_grad_i],
payload)

if ctx.needs_input_grad[1] or ctx.needs_input_grad[2]:
payload, _ = trace_pallas(
_flash_attention_bwd_dkv,
q,
k,
v,
None,
None,
l,
m,
grad_output,
grad_i,
block_q_major=min(
FlashAttention.DEFAULT_BLOCK_SIZES["block_q_major_dkv"],
q.shape[2]),
block_k_major=min(
FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major_dkv"],
k.shape[2]),
block_k=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_dkv"],
k.shape[2]),
block_q=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q_dkv"],
q.shape[2]),
sm_scale=1.0,
causal=causal,
mask_value=FlashAttention.DEFAULT_MASK_VALUE,
debug=False,
static_argnames=[
"block_q_major", "block_k_major", "block_k", "block_q",
"sm_scale", "causal", "mask_value", "debug"
])
grad_k = torch.empty(k.shape, dtype=k.dtype).to(k.device)
grad_v = torch.empty(v.shape, dtype=v.dtype).to(v.device)
torch_xla._XLAC._xla_tpu_custom_call_(
[grad_k, grad_v],
[q, k, v, expanded_l, expanded_m, grad_output, expanded_grad_i],
payload)
if not ctx.needs_input_grad[1]:
grad_k = None
if not ctx.needs_input_grad[2]:
grad_v = None

return grad_q, grad_k, grad_v, None


def flash_attention(
q, # [batch_size, num_heads, q_seq_len, d_model]
k, # [batch_size, num_heads, kv_seq_len, d_model]
v, # [batch_size, num_heads, kv_seq_len, d_model]
segment_ids=None, # q of [batch_size, q_seq_len] and kv of [batch_size, kv_seq_len]
causal=False,
):
# Import JAX within the function such that we don't need to call the jax_import_guard()
# in the global scope which could cause problems for xmp.spawn.
jax_import_guard()
import jax
import jax.numpy as jnp
import jax.experimental.pallas.ops.tpu.flash_attention as tpu_flash_attention

# TODO: Support segment_ids.
flash_attention_kernel = make_kernel_from_pallas(
tpu_flash_attention.flash_attention, lambda q, k, v: [(q.shape, q.dtype)])

# The block_sizes configuration is copied from https://github.com/google/maxtext/blob/0fee320451738166c8e596dc63a57a4673671576/MaxText/layers/attentions.py#L215-L240
# It yields much better performance than the default block_sizes.
return flash_attention_kernel(
q,
k,
v,
static_argnames=["block_sizes", "causal"],
block_sizes=tpu_flash_attention.BlockSizes(
block_q=min(512, q.shape[2]),
block_k_major=min(512, k.shape[2]),
block_k=min(512, k.shape[2]),
block_b=min(2, q.shape[0]),
block_q_major_dkv=min(512, q.shape[2]),
block_k_major_dkv=min(512, k.shape[2]),
block_q_dkv=min(512, q.shape[2]),
block_k_dkv=min(512, k.shape[2]),
block_q_dq=min(1024, q.shape[2]),
block_k_dq=min(256, k.shape[2]),
block_k_major_dq=min(512, k.shape[2]),
),
causal=causal)
return FlashAttention.apply(q, k, v, causal)
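
For reference, a minimal usage sketch of the new wrapper, modeled on test_flash_attention_backward above (assumes a TPU device with torch_xla and this change applied):

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
# .to("xla") produces non-leaf tensors, so retain_grad() is needed to read .grad later.
q.retain_grad()
k.retain_grad()
v.retain_grad()

o = flash_attention(q, k, v, causal=True)  # forward runs the Pallas flash attention kernel
o.sum().backward()                         # backward runs the dq and dkv Pallas kernels
xm.mark_step()

print(q.grad.shape, k.grad.shape, v.grad.shape)
```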


XLA_LIB.define(