diff --git a/test/test_pallas.py b/test/test_pallas.py
index 8901a84c80ad..37d5932c4f9d 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -736,6 +736,64 @@ def test_flash_attention_backward_segment_ids(self):
       self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
     jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
 
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+                   "This test only works on TPUv3+.")
+  def test_flash_attention_wrapper_sm_scale(self):
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
+    from torch_xla.experimental.custom_kernel import flash_attention
+
+    q = torch.randn(3, 2, 128, 4).to("xla")
+    k = torch.randn(3, 2, 128, 4).to("xla")
+    v = torch.randn(3, 2, 128, 4).to("xla")
+    sm_scale = 0.7
+    o = flash_attention(q, k, v, False, None, None, sm_scale)
+
+    expected_o = self._attention(q * sm_scale, k, v)
+    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
+
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
+                   "This test only works on TPUv3+.")
+  def test_flash_attention_sm_scale_backward(self):
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
+    from torch_xla.experimental.custom_kernel import flash_attention
+
+    torch.manual_seed(42)
+    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    sm_scale = 0.7
+    q.retain_grad()
+    k.retain_grad()
+    v.retain_grad()
+
+    o = flash_attention(q, k, v, False, None, None, sm_scale)
+    loss = o.sum()
+    loss.backward()
+    xm.mark_step()
+
+    q_grad = q.grad
+    k_grad = k.grad
+    v_grad = v.grad
+
+    torch.manual_seed(42)
+    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    q.retain_grad()
+    k.retain_grad()
+    v.retain_grad()
+
+    o = self._attention(q * sm_scale, k, v)
+    loss = o.sum()
+    loss.backward()
+    xm.mark_step()
+
+    # The gradients match even though the autograd graphs differ.
+    for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
+      self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 484002124d5a..43e1ab1c128f 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -198,15 +198,8 @@ def prepare_segment_ids(q_segment_ids, kv_segment_ids):
     return segment_ids, q_segment_ids, kv_segment_ids
 
   @staticmethod
-  def forward(ctx,
-              q,
-              k,
-              v,
-              causal=False,
-              q_segment_ids=None,
-              kv_segment_ids=None,
-              partition_spec=None,
-              mesh=None):
+  def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale,
+              partition_spec, mesh):
     # Import JAX within the function such that we don't need to call the jax_import_guard()
     # in the global scope which could cause problems for xmp.spawn.
     jax_import_guard()
@@ -214,6 +207,7 @@ def forward(ctx,
     from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_impl
 
     ctx.causal = causal
+    ctx.sm_scale = sm_scale
     ctx.partition_spec = partition_spec
     ctx.mesh = mesh
     ctx.full_shape = None
@@ -258,7 +252,7 @@ def forward(ctx,
         segment_ids,
         save_residuals,
         causal,
-        1.0,
+        sm_scale,
         min(FlashAttention.DEFAULT_BLOCK_SIZES["block_b"], q.shape[0]),
         min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q"], q.shape[2]),
         min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2]),
@@ -300,6 +294,7 @@ def backward(ctx, grad_output):
 
     q, k, v, o, l, m, q_segment_ids, kv_segment_ids = ctx.saved_tensors
     causal = ctx.causal
+    sm_scale = ctx.sm_scale
     partition_spec = ctx.partition_spec
     mesh = ctx.mesh
     full_shape = ctx.full_shape
@@ -350,7 +345,7 @@ def backward(ctx, grad_output):
                             k.shape[2]),
           block_k=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_dq"],
                       k.shape[2]),
-          sm_scale=1.0,
+          sm_scale=sm_scale,
           causal=causal,
           mask_value=FlashAttention.DEFAULT_MASK_VALUE,
           debug=False,
@@ -388,7 +383,7 @@ def backward(ctx, grad_output):
                             k.shape[2]),
           block_q=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q_dkv"],
                       q.shape[2]),
-          sm_scale=1.0,
+          sm_scale=sm_scale,
           causal=causal,
           mask_value=FlashAttention.DEFAULT_MASK_VALUE,
           debug=False,
@@ -418,7 +413,7 @@ def backward(ctx, grad_output):
       grad_v = xs.disable_manual_sharding(
           grad_v, partition_spec, full_shape, mesh=mesh).global_tensor
 
-    return grad_q, grad_k, grad_v, None, None, None, None, None
+    return grad_q, grad_k, grad_v, None, None, None, None, None, None
 
 
 def flash_attention(
@@ -428,12 +423,13 @@ def flash_attention(
     causal=False,
     q_segment_ids=None,
    kv_segment_ids=None,
+    sm_scale=1.0,
     *,
     partition_spec=None,
     mesh=None):
   # TODO: support SPMD and Dynamo with segment_ids.
   return FlashAttention.apply(q, k, v, causal, q_segment_ids, kv_segment_ids,
-                              partition_spec, mesh)
+                              sm_scale, partition_spec, mesh)
 
 
 def paged_attention(q, k_pages, v_pages, lengths, page_indices,
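
Reviewer note (not part of the diff): sm_scale is applied to the attention logits before the softmax, which is why the tests above compare the kernel output against self._attention(q * sm_scale, k, v); scaling q by sm_scale is equivalent to scaling Q @ K^T. Below is a minimal CPU-only sketch of that reference semantics, assuming the usual [batch, heads, seq, head_dim] layout; reference_attention is an illustrative helper, not part of this change.

import torch


def reference_attention(q, k, v, sm_scale=1.0):
  # softmax(sm_scale * Q @ K^T) @ V over [batch, heads, seq, head_dim] inputs.
  # Scaling q by sm_scale is the same as scaling the logits, matching the
  # self._attention(q * sm_scale, k, v) reference used in the tests above.
  logits = torch.einsum("bhqd,bhkd->bhqk", q * sm_scale, k)
  return torch.einsum("bhqk,bhkd->bhqd", logits.softmax(dim=-1), v)


q = torch.randn(3, 2, 128, 4)
k = torch.randn(3, 2, 128, 4)
v = torch.randn(3, 2, 128, 4)
o = reference_attention(q, k, v, sm_scale=0.7)
print(o.shape)  # torch.Size([3, 2, 128, 4])

With the new wrapper signature, callers on TPU can pass the scale positionally as the tests do, or by keyword, e.g. flash_attention(q, k, v, sm_scale=0.7), since sm_scale defaults to 1.0 and sits before the keyword-only partition_spec and mesh arguments.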