153 changes: 153 additions & 0 deletions test/test_pallas.py
@@ -6,6 +6,7 @@
from torch import nn as nn

import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla import runtime as xr
from torch_xla._internal import tpu

@@ -283,6 +284,158 @@ def add_minus_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
self.assertTrue(torch.allclose(o[0].cpu(), expected_o0.cpu()))
self.assertTrue(torch.allclose(o[1].cpu(), expected_o1.cpu()))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
def test__flash_attention_impl(self):
from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_impl
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
MIN_BLOCK_SIZE = 128

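# The wrapped kernel returns three tensors (o, l, m): o matches q's shape and
# dtype, while the l/m residuals are float32 with a trailing dimension of
# MIN_BLOCK_SIZE.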
def shape_dtype(q, *args):
res_shape = list(q.shape)
res_shape[-1] = MIN_BLOCK_SIZE
return [(q.shape, q.dtype), (res_shape, torch.float32),
(res_shape, torch.float32)]

flash_attention_kernel = make_kernel_from_pallas(_flash_attention_impl,
shape_dtype)

q = torch.randn(3, 2, 128, 4, dtype=torch.bfloat16).to("xla")
k = torch.randn(3, 2, 128, 4, dtype=torch.bfloat16).to("xla")
v = torch.randn(3, 2, 128, 4, dtype=torch.bfloat16).to("xla")

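# The trailing non-tensor arguments (the boolean flags, the softmax scale,
# and the block sizes) are marked static via static_argnums=range(5, 13) so
# JAX treats them as compile-time constants while tracing.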
o, l, m = flash_attention_kernel(
q,
k,
v,
None,
None,
True,
False,
1.0,
2,
128,
128,
128,
False,
static_argnums=range(5, 13))
xm.mark_step()

# TODO: I don't really know how to test the value. Let's do the shape check for now.
self.assertEqual(l.shape, (3, 2, 128, MIN_BLOCK_SIZE))
self.assertEqual(l.dtype, torch.float32)
self.assertEqual(m.shape, (3, 2, 128, MIN_BLOCK_SIZE))
self.assertEqual(m.dtype, torch.float32)

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
def test__flash_attention_bwd_dkv(self):
from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dkv
from torch_xla.experimental.custom_kernel import trace_pallas
MIN_BLOCK_SIZE = 128
DEFAULT_MASK_VALUE = -0.7 * float(torch.finfo(torch.float32).max)

q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")
l = torch.randn(3, 2, 128).to("xla")
m = torch.randn(3, 2, 128).to("xla")
grad_i = torch.randn(3, 2, 128, dtype=torch.float32).to("xla")
grad_o = torch.randn(3, 2, 128, 4).to("xla")

payload, _ = trace_pallas(
_flash_attention_bwd_dkv,
q,
k,
v,
None,
None,
l,
m,
grad_o,
grad_i,
block_q_major=128,
block_k_major=128,
block_k=128,
block_q=128,
sm_scale=1.0,
causal=False,
mask_value=DEFAULT_MASK_VALUE,
debug=False,
static_argnames=[
"block_q_major", "block_k_major", "block_k", "block_q", "sm_scale",
"causal", "mask_value", "debug"
])

# TODO: Because of the following reshapes, we can't use make_kernel_from_pallas directly.
l = l.unsqueeze(-1).expand(3, 2, 128, MIN_BLOCK_SIZE)
m = m.unsqueeze(-1).expand(3, 2, 128, MIN_BLOCK_SIZE)
grad_i = grad_i.unsqueeze(-1).expand(3, 2, 128, MIN_BLOCK_SIZE)
grad_k = torch.randn(3, 2, 128, 4).to("xla")
grad_v = torch.randn(3, 2, 128, 4).to("xla")
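# grad_k and grad_v act as output buffers: _xla_tpu_custom_call_ fills them
# in place with the kernel results.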
torch_xla._XLAC._xla_tpu_custom_call_([grad_k, grad_v],
[q, k, v, l, m, grad_o, grad_i],
payload)

xm.mark_step()

# TODO: I don't really know how to test the value. Let's do the shape check for now.
self.assertEqual(grad_k.shape, (3, 2, 128, 4))
self.assertEqual(grad_v.shape, (3, 2, 128, 4))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
def test__flash_attention_bwd_dq(self):
from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq
from torch_xla.experimental.custom_kernel import trace_pallas
MIN_BLOCK_SIZE = 128
DEFAULT_MASK_VALUE = -0.7 * float(torch.finfo(torch.float32).max)

q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")
l = torch.randn(3, 2, 128).to("xla")
m = torch.randn(3, 2, 128).to("xla")
grad_i = torch.randn(3, 2, 128, dtype=torch.float32).to("xla")
grad_o = torch.randn(3, 2, 128, 4).to("xla")

payload, _ = trace_pallas(
_flash_attention_bwd_dq,
q,
k,
v,
None,
None,
l,
m,
grad_o,
grad_i,
block_q_major=128,
block_k_major=128,
block_k=128,
sm_scale=1.0,
causal=False,
mask_value=DEFAULT_MASK_VALUE,
debug=False,
static_argnames=[
"block_q_major", "block_k_major", "block_k", "sm_scale", "causal",
"mask_value", "debug"
])

# TODO: Because of the following reshapes, we can't use make_kernel_from_pallas directly.
l = l.unsqueeze(-1).expand(3, 2, 128, MIN_BLOCK_SIZE)
m = m.unsqueeze(-1).expand(3, 2, 128, MIN_BLOCK_SIZE)
grad_i = grad_i.unsqueeze(-1).expand(3, 2, 128, MIN_BLOCK_SIZE)
grad_q = torch.randn(3, 2, 128, 4).to("xla")
torch_xla._XLAC._xla_tpu_custom_call_([grad_q],
[q, k, v, l, m, grad_o, grad_i],
payload)

xm.mark_step()

# TODO: I don't really know how to test the value. Let's do the shape check for now.
self.assertEqual(grad_q.shape, (3, 2, 128, 4))
Collaborator:

If we do the fwd and then res.backward(), and check the grad on q, they should match?

Collaborator (Author):

The softmax is done differently. I don't think there are any guarantees.

Collaborator:

The result should still be somewhat close, right? We can tune down the precision. If the result returned by this is dramatically different from the one computed using dot attention, that seems wrong.

Collaborator (Author):

https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html

Softmax requires all the elements to produce its result, but flash attention chunks the data into blocks and uses a technique called tiling to keep the softmax numerically stable. Since there is no aggregation, I don't know how the tiled softmax could produce the same results as the regular one.

In JAX, I have to use atol=1e-01, rtol=1e-01 to do the comparisons...
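
For reference, a minimal sketch of the comparison discussed in this thread, in PyTorch. It assumes the forward test's setup above (causal masking off, sm_scale=1.0) and borrows the loose tolerances mentioned for JAX; the reference path is plain dot-product attention and is not part of this PR:

def reference_attention(q, k, v):
  # Plain softmax(q @ k^T) @ v in float32; no causal mask, scale 1.0.
  scores = torch.einsum('bhqd,bhkd->bhqk', q.float(), k.float())
  probs = torch.softmax(scores, dim=-1)
  return torch.einsum('bhqk,bhkd->bhqd', probs, v.float())

# Hypothetical check that could be added to test__flash_attention_impl:
# expected = reference_attention(q, k, v)
# torch.testing.assert_close(o.cpu().float(), expected.cpu(),
#                            atol=1e-1, rtol=1e-1)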



if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
56 changes: 36 additions & 20 deletions torch_xla/experimental/custom_kernel.py
@@ -68,7 +68,11 @@ def jax_import_guard():
torch_xla._XLAC._init_computation_client()


def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):
def trace_pallas(kernel: Callable,
*args,
static_argnums=None,
static_argnames=None,
**kwargs):
# Import JAX within the function such that we don't need to call the jax_import_guard()
# in the global scope which could cause problems for xmp.spawn.
jax_import_guard()
@@ -102,37 +106,49 @@ def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
else:
raise ValueError(f"Unsupported dtype: {dtype}")

jax_args = [] # for tracing
tensor_args = [] # for execution
for i, arg in enumerate(args):
# TODO: Could the args be a tuple of tensors or a list of tensors? Flatten them?
if torch.is_tensor(arg):
# ShapeDtypeStruct doesn't have any storage and thus is very suitable for generating the payload.
jax_meta_tensor = jax.ShapeDtypeStruct(
arg.shape, convert_torch_dtype_to_jax(arg.dtype))
jax_args.append(jax_meta_tensor)
tensor_args.append(arg)
else:
jax_args.append(arg)

# Here we ignore the kwargs for execution as, most of the time, they are only used in traced code.
ir = jax.jit(
kernel, static_argnums=static_argnums,
static_argnames=static_argnames).lower(*jax_args, **kwargs).compiler_ir()
payload = _extract_backend_config(ir)
return payload, tensor_args


def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):
# TODO: Maybe we can cache the payload for the same input.
def wrapped_kernel(kernel: Callable,
output_shape_dtype_fn: Callable,
*args,
static_argnames=[],
static_argnums=None,
static_argnames=None,
**kwargs) -> Callable:
jax_args = []
for i, arg in enumerate(args):
if torch.is_tensor(arg):
# ShapeDtypeStruct doesn't have any storage and thus is very suitable for generating the payload.
jax_meta_tensor = jax.ShapeDtypeStruct(
arg.shape, convert_torch_dtype_to_jax(arg.dtype))
jax_args.append(jax_meta_tensor)
else:
# TODO: We can support more types here.
assert False, f"Unsupported argument type: {type(arg)}"

# Here we ignore the kwargs for execution as, most of the time, they are only used in traced code.
ir = jax.jit(
kernel, static_argnames=static_argnames).lower(*jax_args,
**kwargs).compiler_ir()
payload = _extract_backend_config(ir)
# TODO: We can consider supporting non-array output.
payload, tensor_args = trace_pallas(
kernel,
*args,
static_argnums=static_argnums,
static_argnames=static_argnames,
**kwargs)
outputs = []
output_shape_dtype = output_shape_dtype_fn(*args)
assert isinstance(output_shape_dtype,
list), "The output_shape_dtype_fn should return a list."
for output_shape, output_dtype in output_shape_dtype:
outputs.append(
torch.empty(output_shape, dtype=output_dtype).to(xm.xla_device()))
torch_xla._XLAC._xla_tpu_custom_call_(outputs, args, payload)
torch_xla._XLAC._xla_tpu_custom_call_(outputs, tensor_args, payload)

# Make the output easier to use.
if len(outputs) == 1:
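
For context, a minimal usage sketch of the refactored API: trace_pallas now owns the JAX tracing and returns a (payload, tensor_args) pair, and make_kernel_from_pallas builds on it to allocate outputs and dispatch _xla_tpu_custom_call_. The element-wise kernel below is hypothetical (not part of this PR) and follows the style of the existing add/minus tests in test_pallas.py:

import jax
from jax.experimental import pallas as pl
import torch
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas

def add_kernel(x_ref, y_ref, o_ref):
  # Element-wise add inside the Pallas kernel body.
  o_ref[...] = x_ref[...] + y_ref[...]

def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(
      add_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x, y)

# output_shape_dtype_fn maps the torch inputs to a list of (shape, dtype)
# pairs for the outputs, matching the assert in make_kernel_from_pallas.
pt_add = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)])

x = torch.arange(8, dtype=torch.int).to("xla")
y = torch.arange(8, dtype=torch.int).to("xla")
out = pt_add(x, y)  # runs the traced Pallas kernel via _xla_tpu_custom_call_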