From c90168060616d4f32f04025211ce325d90e56971 Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Wed, 8 May 2024 00:14:06 +0000
Subject: [PATCH 1/2] initial commit

---
 test/test_pallas.py                     | 64 +++++++++++++++++++++++++
 torch_xla/experimental/custom_kernel.py | 24 ++++++----
 2 files changed, 78 insertions(+), 10 deletions(-)

diff --git a/test/test_pallas.py b/test/test_pallas.py
index 8901a84c80ad..58b759d49213 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -736,6 +736,70 @@ def test_flash_attention_backward_segment_ids(self):
       self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
     jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)

+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+      "This test only works on TPUv3+.")
+  def test_flash_attention_wrapper_sm_scale(self):
+      jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
+      from torch_xla.experimental.custom_kernel import flash_attention
+
+      q = torch.randn(3, 2, 128, 4).to("xla")
+      k = torch.randn(3, 2, 128, 4).to("xla")
+      v = torch.randn(3, 2, 128, 4).to("xla")
+      sm_scale = 0.7
+      o = flash_attention(q, k, v, False, None, None, sm_scale)
+
+      expected_o = self._attention(
+          q * sm_scale,
+          k,
+          v)
+      self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+      jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
+
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+      "This test only works on TPUv3+.")
+  def test_flash_attention_sm_scale_backward(self):
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
+    from torch_xla.experimental.custom_kernel import flash_attention
+
+    torch.manual_seed(42)
+    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    sm_scale = 0.7
+    q.retain_grad()
+    k.retain_grad()
+    v.retain_grad()
+
+    o = flash_attention(q, k, v, False, None, None, sm_scale)
+    loss = o.sum()
+    loss.backward()
+    xm.mark_step()
+
+    q_grad = q.grad
+    k_grad = k.grad
+    v_grad = v.grad
+
+    torch.manual_seed(42)
+    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    q.retain_grad()
+    k.retain_grad()
+    v.retain_grad()
+
+    o = self._attention(
+        q * sm_scale,
+        k,
+        v)
+    loss = o.sum()
+    loss.backward()
+    xm.mark_step()
+
+    # Hmm, the gradients are the same even though the autograd graph seems different.
+    for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
+      self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
+

 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 484002124d5a..4c7406621213 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -202,11 +202,12 @@ def forward(ctx,
               q,
               k,
               v,
-              causal=False,
-              q_segment_ids=None,
-              kv_segment_ids=None,
-              partition_spec=None,
-              mesh=None):
+              causal,
+              q_segment_ids,
+              kv_segment_ids,
+              sm_scale,
+              partition_spec,
+              mesh):
     # Import JAX within the function such that we don't need to call the jax_import_guard()
     # in the global scope which could cause problems for xmp.spawn.
     jax_import_guard()
@@ -214,6 +215,7 @@ def forward(ctx,
     from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_impl

     ctx.causal = causal
+    ctx.sm_scale = sm_scale
     ctx.partition_spec = partition_spec
     ctx.mesh = mesh
     ctx.full_shape = None
@@ -258,7 +260,7 @@ def forward(ctx,
         segment_ids,
         save_residuals,
         causal,
-        1.0,
+        sm_scale,
         min(FlashAttention.DEFAULT_BLOCK_SIZES["block_b"], q.shape[0]),
         min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q"], q.shape[2]),
         min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2]),
@@ -300,6 +302,7 @@ def backward(ctx, grad_output):

     q, k, v, o, l, m, q_segment_ids, kv_segment_ids = ctx.saved_tensors
     causal = ctx.causal
+    sm_scale = ctx.sm_scale
     partition_spec = ctx.partition_spec
     mesh = ctx.mesh
     full_shape = ctx.full_shape
@@ -350,7 +353,7 @@ def backward(ctx, grad_output):
                             k.shape[2]),
           block_k=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_dq"],
                       k.shape[2]),
-          sm_scale=1.0,
+          sm_scale=sm_scale,
           causal=causal,
           mask_value=FlashAttention.DEFAULT_MASK_VALUE,
           debug=False,
@@ -388,7 +391,7 @@ def backward(ctx, grad_output):
                       k.shape[2]),
           block_q=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q_dkv"],
                       q.shape[2]),
-          sm_scale=1.0,
+          sm_scale=sm_scale,
           causal=causal,
           mask_value=FlashAttention.DEFAULT_MASK_VALUE,
           debug=False,
@@ -418,7 +421,7 @@ def backward(ctx, grad_output):
       grad_v = xs.disable_manual_sharding(
           grad_v, partition_spec, full_shape, mesh=mesh).global_tensor

-    return grad_q, grad_k, grad_v, None, None, None, None, None
+    return grad_q, grad_k, grad_v, None, None, None, None, None, None


 def flash_attention(
     q,
     k,
     v,
     causal=False,
     q_segment_ids=None,
     kv_segment_ids=None,
+    sm_scale=1.0,
     *,
     partition_spec=None,
     mesh=None):
   # TODO: support SPMD and Dynamo with segment_ids.
-  return FlashAttention.apply(q, k, v, causal, q_segment_ids, kv_segment_ids,
+  return FlashAttention.apply(q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale,
                               partition_spec, mesh)

From 1279d7436f2b848c91096a4c8284869f37f3d8ba Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Wed, 8 May 2024 00:15:49 +0000
Subject: [PATCH 2/2] Fix linters

---
 test/test_pallas.py                     | 32 ++++++++++-----------------
 torch_xla/experimental/custom_kernel.py | 16 ++++---------
 2 files changed, 17 insertions(+), 31 deletions(-)

diff --git a/test/test_pallas.py b/test/test_pallas.py
index 58b759d49213..37d5932c4f9d 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -737,26 +737,23 @@ def test_flash_attention_backward_segment_ids(self):
     jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
-      "This test only works on TPUv3+.")
+                   "This test only works on TPUv3+.")
   def test_flash_attention_wrapper_sm_scale(self):
-      jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
-      from torch_xla.experimental.custom_kernel import flash_attention
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
+    from torch_xla.experimental.custom_kernel import flash_attention

-      q = torch.randn(3, 2, 128, 4).to("xla")
-      k = torch.randn(3, 2, 128, 4).to("xla")
-      v = torch.randn(3, 2, 128, 4).to("xla")
-      sm_scale = 0.7
-      o = flash_attention(q, k, v, False, None, None, sm_scale)
+    q = torch.randn(3, 2, 128, 4).to("xla")
+    k = torch.randn(3, 2, 128, 4).to("xla")
+    v = torch.randn(3, 2, 128, 4).to("xla")
+    sm_scale = 0.7
+    o = flash_attention(q, k, v, False, None, None, sm_scale)

-      expected_o = self._attention(
-          q * sm_scale,
-          k,
-          v)
-      self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
-      jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
+    expected_o = self._attention(q * sm_scale, k, v)
+    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
-      "This test only works on TPUv3+.")
+                   "This test only works on TPUv3+.")
   def test_flash_attention_sm_scale_backward(self):
     jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
     from torch_xla.experimental.custom_kernel import flash_attention
@@ -787,10 +784,7 @@ def test_flash_attention_sm_scale_backward(self):
     k.retain_grad()
     v.retain_grad()

-    o = self._attention(
-        q * sm_scale,
-        k,
-        v)
+    o = self._attention(q * sm_scale, k, v)
     loss = o.sum()
     loss.backward()
     xm.mark_step()
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 4c7406621213..43e1ab1c128f 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -198,16 +198,8 @@ def prepare_segment_ids(q_segment_ids, kv_segment_ids):
     return segment_ids, q_segment_ids, kv_segment_ids

   @staticmethod
-  def forward(ctx,
-              q,
-              k,
-              v,
-              causal,
-              q_segment_ids,
-              kv_segment_ids,
-              sm_scale,
-              partition_spec,
-              mesh):
+  def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale,
+              partition_spec, mesh):
     # Import JAX within the function such that we don't need to call the jax_import_guard()
     # in the global scope which could cause problems for xmp.spawn.
     jax_import_guard()
@@ -436,8 +428,8 @@ def flash_attention(
     partition_spec=None,
     mesh=None):
   # TODO: support SPMD and Dynamo with segment_ids.
-  return FlashAttention.apply(q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale,
-                              partition_spec, mesh)
+  return FlashAttention.apply(q, k, v, causal, q_segment_ids, kv_segment_ids,
+                              sm_scale, partition_spec, mesh)


 def paged_attention(q, k_pages, v_pages, lengths, page_indices,
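
For reference, a minimal usage sketch of the patched wrapper (not part of the diffs above). It assumes a TPU-backed torch_xla environment with both patches applied; the tensor shapes and the 0.7 scale simply mirror the new tests, and passing sm_scale here is equivalent to scaling q before a plain attention, which is exactly what the tests check against.

    import torch
    import torch_xla.core.xla_model as xm
    from torch_xla.experimental.custom_kernel import flash_attention

    # Inputs on the XLA (TPU) device, same shapes as the new tests.
    q = torch.randn(3, 2, 128, 4).to("xla")
    k = torch.randn(3, 2, 128, 4).to("xla")
    v = torch.randn(3, 2, 128, 4).to("xla")

    # sm_scale is the new positional argument after kv_segment_ids; its default
    # of 1.0 keeps the previous behavior, any other value scales the softmax logits.
    o = flash_attention(q, k, v, False, None, None, 0.7)
    xm.mark_step()  # materialize the lazy XLA computation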