diff --git a/test/test_pallas.py b/test/test_pallas.py
index dcc915a6ba8f..17527b8b7dcf 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -20,6 +20,12 @@ class PallasTest(unittest.TestCase):
 
+  def _attention(self, q, k, v):
+    attn_weight = q @ k.transpose(-2, -1)
+    attn_weight = nn.functional.softmax(attn_weight, dim=-1)
+    attn_output = attn_weight @ v
+    return attn_output
+
   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_tpu_custom_call_pallas_add(self):
     # This payload is generated by the following Pallas code:
@@ -80,13 +86,7 @@ def test_tpu_custom_call_pallas_flash_attention(self):
     v = torch.ones(3, 2, 128, 4).to("xla")
     o = torch.zeros(3, 2, 128, 4).to("xla")
 
-    def attention(q, k, v):
-      attn_weight = q @ k.transpose(-2, -1)
-      attn_weight = nn.functional.softmax(attn_weight, dim=-1)
-      attn_output = attn_weight @ v
-      return attn_output
-
-    expected_o = attention(q, k, v)
+    expected_o = self._attention(q, k, v)
     torch_xla._XLAC._xla_tpu_custom_call_([o], [q, k, v], payload)
     self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))
 
@@ -182,12 +182,6 @@ def test_tpu_custom_call_pallas_wrap_flash_attention(self):
     flash_attention_kernel = make_kernel_from_pallas(
         flash_attention, lambda q, k, v: [(q.shape, q.dtype)])
 
-    def attention(q, k, v):
-      attn_weight = q @ k.transpose(-2, -1)
-      attn_weight = nn.functional.softmax(attn_weight, dim=-1)
-      attn_output = attn_weight @ v
-      return attn_output
-
     q_mini = torch.arange(128 * 4, dtype=torch.bfloat16).reshape(128, 4) / 13
     k_mini = torch.arange(
         1000, 1000 + 128 * 4, dtype=torch.bfloat16).reshape(128, 4) / 13
@@ -196,7 +190,7 @@ def attention(q, k, v):
     v = torch.ones(3, 2, 128, 4, dtype=torch.bfloat16).to("xla")
 
     o = flash_attention_kernel(q, k, v)
-    expected_o = attention(q, k, v)
+    expected_o = self._attention(q, k, v)
     self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))
 
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
@@ -205,18 +199,12 @@ def test_flash_attention_wrapper(self):
     jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
     from torch_xla.experimental.custom_kernel import flash_attention
 
-    def attention(q, k, v):
-      attn_weight = q @ k.transpose(-2, -1)
-      attn_weight = nn.functional.softmax(attn_weight, dim=-1)
-      attn_output = attn_weight @ v
-      return attn_output
-
     q = torch.randn(3, 2, 128, 4).to("xla")
     k = torch.randn(3, 2, 128, 4).to("xla")
     v = torch.randn(3, 2, 128, 4).to("xla")
 
     o = flash_attention(q, k, v)
-    expected_o = attention(q, k, v)
+    expected_o = self._attention(q, k, v)
     self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))
     jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
@@ -226,12 +214,6 @@ def test_flash_attention_wrapper_with_dynamo(self):
     jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
     from torch_xla.experimental.custom_kernel import flash_attention
 
-    def attention(q, k, v):
-      attn_weight = q @ k.transpose(-2, -1)
-      attn_weight = nn.functional.softmax(attn_weight, dim=-1)
-      attn_output = attn_weight @ v
-      return attn_output
-
     def flash_attention_wrapper(q, k, v, causal=False):
       return torch.ops.xla.flash_attention(q, k, v, causal)
 
@@ -243,7 +225,7 @@ def flash_attention_wrapper(q, k, v, causal=False):
         flash_attention_wrapper, backend="openxla")
     o_no_causal = compiled_flash_attention(q, k, v)
     o_with_causal = compiled_flash_attention(q, k, v, causal=True)
-    expected_o = attention(q, k, v)
+    expected_o = self._attention(q, k, v)
     self.assertTrue(torch.allclose(o_no_causal.cpu(), expected_o.cpu()))
     # The causal mask is turned on by default in the wrapper.
     # It masks out the top right triangle of the attention matrix,
@@ -257,12 +239,6 @@ def test_flash_attention_wrapper_causal(self):
     jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
     from torch_xla.experimental.custom_kernel import flash_attention
 
-    def attention(q, k, v):
-      attn_weight = q @ k.transpose(-2, -1)
-      attn_weight = nn.functional.softmax(attn_weight, dim=-1)
-      attn_output = attn_weight @ v
-      return attn_output
-
     q = torch.randn(3, 2, 128, 4).to("xla")
     k = torch.randn(3, 2, 128, 4).to("xla")
     v = torch.randn(3, 2, 128, 4).to("xla")
@@ -270,7 +246,7 @@ def attention(q, k, v):
     # The causal mask is turned on by default in the wrapper.
     # It masks out the top right triangle of the attention matrix, therefore it speeds up the compute but also changes the output.
     o = flash_attention(q, k, v, causal=True)
-    expected_o = attention(q, k, v)
+    expected_o = self._attention(q, k, v)
     self.assertFalse(torch.allclose(o.cpu(), expected_o.cpu()))
     jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
@@ -467,6 +443,45 @@ def test__flash_attention_bwd_dkv(self):
     # TODO: I don't really know how to test the value. Let's do the shape check for now.
     self.assertEqual(grad_q.shape, (3, 2, 128, 4))
 
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+                   "This test only works on TPUv3+.")
+  def test_flash_attention_backward(self):
+    from torch_xla.experimental.custom_kernel import flash_attention
+
+    torch.manual_seed(42)
+    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    q.retain_grad()
+    k.retain_grad()
+    v.retain_grad()
+
+    o = flash_attention(q, k, v)
+    loss = o.sum()
+    loss.backward()
+    xm.mark_step()
+
+    q_grad = q.grad
+    k_grad = k.grad
+    v_grad = v.grad
+
+    torch.manual_seed(42)
+    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    q.retain_grad()
+    k.retain_grad()
+    v.retain_grad()
+
+    o = self._attention(q, k, v)
+    loss = o.sum()
+    loss.backward()
+    xm.mark_step()
+
+    mse = torch.nn.MSELoss()
+    for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
+      self.assertTrue(mse(i[0].grad.cpu(), i[1].cpu()) < 1e-4)
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index ef59e7ee8e0b..d91369d11708 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -160,47 +160,181 @@ def wrapped_kernel(kernel: Callable,
   return functools.partial(wrapped_kernel, kernel, output_shape_dtype_fn)
 
 
-# This is a simplified wrapper on top of https://github.com/google/jax/blob/b2058d72b7e1693a41303d5411572aabf99b7981/jax/experimental/pallas/ops/tpu/flash_attention.py#L139
-# where we only takes q, k, v, segment_ids and causal as input and set block_sizes for the users.
+class FlashAttention(torch.autograd.Function):
+  """
+  This is a simplified wrapper on top of https://github.com/google/jax/blob/b2058d72b7e1693a41303d5411572aabf99b7981/jax/experimental/pallas/ops/tpu/flash_attention.py#L139
+  where we only take q, k, v and causal as input and set block_sizes for the users.
+ """ + + MIN_BLOCK_SIZE = 128 + DEFAULT_MASK_VALUE = -0.7 * float(torch.finfo(torch.float32).max) + # The block_sizes configuration is copied from https://github.com/google/maxtext/blob/0fee320451738166c8e596dc63a57a4673671576/MaxText/layers/attentions.py#L215-L240 + # It yields much better performance than the default block_sizes. + DEFAULT_BLOCK_SIZES = { + "block_q": 512, + "block_k_major": 512, + "block_k": 512, + "block_b": 2, + "block_q_major_dkv": 512, + "block_k_major_dkv": 512, + "block_q_dkv": 512, + "block_k_dkv": 512, + "block_q_dq": 1024, + "block_k_dq": 256, + "block_k_major_dq": 512, + } + + @staticmethod + def forward(ctx, q, k, v, causal=False): + # Import JAX within the function such that we don't need to call the jax_import_guard() + # in the global scope which could cause problems for xmp.spawn. + jax_import_guard() + from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_impl + + ctx.causal = causal + save_residuals = q.requires_grad or k.requires_grad or v.requires_grad + + # It returns the shape and type of o, l, m. + def shape_dtype(q, *arg): + if not save_residuals: + return [(q.shape, q.dtype)] + res_shape = list(q.shape) + res_shape[-1] = FlashAttention.MIN_BLOCK_SIZE + return [(q.shape, q.dtype), (res_shape, torch.float32), + (res_shape, torch.float32)] + + # We can't directly use flash_attention as we need to override the save_residuals flag which returns + # l and m that is needed for the backward. Then we lose all the shape checks. + # TODO: replicate the shape checks on flash_attention. + _flash_attention_impl = make_kernel_from_pallas(_flash_attention_impl, + shape_dtype) + with torch.no_grad(): + o = _flash_attention_impl( + q, + k, + v, + None, + None, + save_residuals, + causal, + 1.0, + min(FlashAttention.DEFAULT_BLOCK_SIZES["block_b"], q.shape[0]), + min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q"], q.shape[2]), + min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2]), + min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2]), + False, + static_argnums=range(5, 13)) + if not save_residuals: + return o + o, *aux = o + l, m = (v[..., 0] for v in aux[-2:]) + + ctx.save_for_backward(q, k, v, o, l, m) + return o + + @staticmethod + def backward(ctx, grad_output): + from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv + + q, k, v, o, l, m = ctx.saved_tensors + causal = ctx.causal + grad_q = grad_k = grad_v = None + + grad_i = torch.sum( + o.to(torch.float32) * grad_output.to(torch.float32), + axis=-1) # [batch_size, num_heads, q_seq_len] + + expanded_l = l.unsqueeze(-1).expand([-1 for _ in l.shape] + + [FlashAttention.MIN_BLOCK_SIZE]) + expanded_m = m.unsqueeze(-1).expand([-1 for _ in m.shape] + + [FlashAttention.MIN_BLOCK_SIZE]) + expanded_grad_i = grad_i.unsqueeze(-1).expand( + [-1 for _ in grad_i.shape] + [FlashAttention.MIN_BLOCK_SIZE]) + + if ctx.needs_input_grad[0]: + payload, _ = trace_pallas( + _flash_attention_bwd_dq, + q, + k, + v, + None, + None, + l, + m, + grad_output, + grad_i, + block_q_major=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q_dq"], + q.shape[2]), + block_k_major=min( + FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major_dq"], + k.shape[2]), + block_k=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_dq"], + k.shape[2]), + sm_scale=1.0, + causal=causal, + mask_value=FlashAttention.DEFAULT_MASK_VALUE, + debug=False, + static_argnames=[ + "block_q_major", "block_k_major", "block_k", "sm_scale", "causal", + "mask_value", "debug" + ]) + 
+      grad_q = torch.empty(q.shape, dtype=q.dtype).to(q.device)
+      torch_xla._XLAC._xla_tpu_custom_call_(
+          [grad_q],
+          [q, k, v, expanded_l, expanded_m, grad_output, expanded_grad_i],
+          payload)
+
+    if ctx.needs_input_grad[1] or ctx.needs_input_grad[2]:
+      payload, _ = trace_pallas(
+          _flash_attention_bwd_dkv,
+          q,
+          k,
+          v,
+          None,
+          None,
+          l,
+          m,
+          grad_output,
+          grad_i,
+          block_q_major=min(
+              FlashAttention.DEFAULT_BLOCK_SIZES["block_q_major_dkv"],
+              q.shape[2]),
+          block_k_major=min(
+              FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major_dkv"],
+              k.shape[2]),
+          block_k=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_dkv"],
+                      k.shape[2]),
+          block_q=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q_dkv"],
+                      q.shape[2]),
+          sm_scale=1.0,
+          causal=causal,
+          mask_value=FlashAttention.DEFAULT_MASK_VALUE,
+          debug=False,
+          static_argnames=[
+              "block_q_major", "block_k_major", "block_k", "block_q",
+              "sm_scale", "causal", "mask_value", "debug"
+          ])
+      grad_k = torch.empty(k.shape, dtype=k.dtype).to(k.device)
+      grad_v = torch.empty(v.shape, dtype=v.dtype).to(v.device)
+      torch_xla._XLAC._xla_tpu_custom_call_(
+          [grad_k, grad_v],
+          [q, k, v, expanded_l, expanded_m, grad_output, expanded_grad_i],
+          payload)
+    if not ctx.needs_input_grad[1]:
+      grad_k = None
+    if not ctx.needs_input_grad[2]:
+      grad_v = None
+
+    return grad_q, grad_k, grad_v, None
+
+
 def flash_attention(
     q,  # [batch_size, num_heads, q_seq_len, d_model]
     k,  # [batch_size, num_heads, kv_seq_len, d_model]
     v,  # [batch_size, num_heads, kv_seq_len, d_model]
-    segment_ids=None,  # q of [batch_size, q_seq_len] and kv of [batch_size, kv_seq_len]
     causal=False,
 ):
-  # Import JAX within the function such that we don't need to call the jax_import_guard()
-  # in the global scope which could cause problems for xmp.spawn.
-  jax_import_guard()
-  import jax
-  import jax.numpy as jnp
-  import jax.experimental.pallas.ops.tpu.flash_attention as tpu_flash_attention
-
-  # TODO: Support segment_ids.
-  flash_attention_kernel = make_kernel_from_pallas(
-      tpu_flash_attention.flash_attention, lambda q, k, v: [(q.shape, q.dtype)])
-
-  # The block_sizes configuration is copied from https://github.com/google/maxtext/blob/0fee320451738166c8e596dc63a57a4673671576/MaxText/layers/attentions.py#L215-L240
-  # It yields much better performance than the default block_sizes.
-  return flash_attention_kernel(
-      q,
-      k,
-      v,
-      static_argnames=["block_sizes", "causal"],
-      block_sizes=tpu_flash_attention.BlockSizes(
-          block_q=min(512, q.shape[2]),
-          block_k_major=min(512, k.shape[2]),
-          block_k=min(512, k.shape[2]),
-          block_b=min(2, q.shape[0]),
-          block_q_major_dkv=min(512, q.shape[2]),
-          block_k_major_dkv=min(512, k.shape[2]),
-          block_q_dkv=min(512, q.shape[2]),
-          block_k_dkv=min(512, k.shape[2]),
-          block_q_dq=min(1024, q.shape[2]),
-          block_k_dq=min(256, k.shape[2]),
-          block_k_major_dq=min(512, k.shape[2]),
-      ),
-      causal=causal)
+  return FlashAttention.apply(q, k, v, causal)
 
 
 XLA_LIB.define(
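
Usage sketch (not part of the diff above): with FlashAttention registered as a torch.autograd.Function, the public flash_attention wrapper becomes differentiable and is used like any other PyTorch op on an XLA device. The shapes and the retain_grad()/mark_step() pattern below are copied from the new test_flash_attention_backward test; no API beyond what the patch adds is assumed.

    import torch
    import torch_xla.core.xla_model as xm
    from torch_xla.experimental.custom_kernel import flash_attention

    # [batch_size, num_heads, seq_len, head_dim], as in the new test.
    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    # .to("xla") returns non-leaf tensors, so retain their grads explicitly.
    q.retain_grad()
    k.retain_grad()
    v.retain_grad()

    # Forward runs the Pallas _flash_attention_impl kernel; the (l, m)
    # residuals are saved because the inputs require grad.
    o = flash_attention(q, k, v)

    # Backward dispatches to the Pallas dq and dkv kernels.
    o.sum().backward()
    xm.mark_step()

    print(q.grad.shape, k.grad.shape, v.grad.shape)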