From 4baa7361f2021915b4975f93b62c6bf267c3c5dd Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Thu, 30 May 2024 04:12:53 +0000
Subject: [PATCH 1/3] Introduce GMM

---
 test/test_gmm.py                        | 38 ++++++++++++++++++++++++-
 torch_xla/experimental/custom_kernel.py | 14 +++++++++
 2 files changed, 51 insertions(+), 1 deletion(-)

diff --git a/test/test_gmm.py b/test/test_gmm.py
index bf8fdebe24ca..e979731a46b4 100644
--- a/test/test_gmm.py
+++ b/test/test_gmm.py
@@ -7,7 +7,7 @@
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.metrics as met
-from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram, tgmm, gmm_backward
+from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram, tgmm, gmm_backward, GMM
 from torch_xla import runtime as xr
 from torch_xla._internal import tpu

@@ -374,6 +374,42 @@ def test_gmm_backward(self):
     # Make sure gmm doesn't fallback.
     self.assertNotIn("aten::", met.short_metrics_report())

+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
+  def test_gmm_backward_2(self):
+    self._init_test_cases()
+    for test_case in self.tests_cases:
+      num_groups = test_case['num_groups']
+      k = test_case['k']
+      m = test_case['m']
+      n = test_case['n']
+      lhs_dtype = rhs_dtype = torch.bfloat16
+
+      torch.manual_seed(42)
+      lhs = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True)
+      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True)
+      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+      lhs.retain_grad()
+      rhs.retain_grad()
+
+      ref_out = self._reference_gmm(lhs, rhs, group_sizes)
+      ref_out.sum().backward()
+
+      torch.manual_seed(42)
+      lhs_xla = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True).to("xla")
+      rhs_xla = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True).to("xla")
+      lhs_xla.retain_grad()
+      rhs_xla.retain_grad()
+
+      out = GMM.apply(lhs_xla, rhs_xla, group_sizes.to("xla"))
+      out.sum().backward()
+
+      self.assertTrue(torch.allclose(ref_out, out.cpu()))
+      self.assertTrue(torch.allclose(lhs.grad, lhs_xla.grad.cpu()))
+      self.assertTrue(torch.allclose(rhs.grad, rhs_xla.grad.cpu()))
+
+      # Make sure gmm doesn't fallback.
+      self.assertNotIn("aten::", met.short_metrics_report())
+

 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 9bc32a3fc3fa..619496bcf1ac 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -825,6 +825,20 @@ def gmm_backward(grad, lhs, rhs, group_sizes, tiling=(512, 512, 512)):
   return grad_lhs, grad_rhs


+class GMM(torch.autograd.Function):
+  @staticmethod
+  def forward(ctx, lhs, rhs, group_sizes, tiling=(512, 512, 512)):
+    ctx.save_for_backward(lhs, rhs, group_sizes)
+    ctx.tiling = tiling
+    return gmm(lhs, rhs, group_sizes, tiling)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    lhs, rhs, group_sizes = ctx.saved_tensors
+    grad_lhs, grad_rhs = gmm_backward(grad_output, lhs, rhs, group_sizes, ctx.tiling)
+    return grad_lhs, grad_rhs, None, None
+
+
 def non_xla_attetion(q, k, v, attention_type):
   # This will be called when dynamo use fake tensor to construct the fake output.
   # We need to make sure output tensor's shape is correct.

From 4dd8826f50b811c9633068accc90a052e44932e6 Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Thu, 30 May 2024 04:26:07 +0000
Subject: [PATCH 2/3] Use torch.autograd.backward

---
 test/test_gmm.py | 37 +++++++++++++++++++++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/test/test_gmm.py b/test/test_gmm.py
index e979731a46b4..56ce9572ce60 100644
--- a/test/test_gmm.py
+++ b/test/test_gmm.py
@@ -410,6 +410,43 @@ def test_gmm_backward_2(self):
     # Make sure gmm doesn't fallback.
     self.assertNotIn("aten::", met.short_metrics_report())

+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
+  def test_gmm_backward_3(self):
+    self._init_test_cases()
+    for test_case in self.tests_cases:
+      num_groups = test_case['num_groups']
+      k = test_case['k']
+      m = test_case['m']
+      n = test_case['n']
+      lhs_dtype = rhs_dtype = torch.bfloat16
+
+      torch.manual_seed(42)
+      lhs = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True)
+      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True)
+      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+      lhs.retain_grad()
+      rhs.retain_grad()
+
+      ref_out = self._reference_gmm(lhs, rhs, group_sizes)
+      ref_out.sum().backward()
+
+      torch.manual_seed(42)
+      lhs_xla = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True).to("xla")
+      rhs_xla = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True).to("xla")
+      lhs_xla.retain_grad()
+      rhs_xla.retain_grad()
+
+      out = GMM.apply(lhs_xla, rhs_xla, group_sizes.to("xla"))
+      grad_out = torch.ones_like(out)
+      torch.autograd.backward([out], [grad_out, lhs_xla, rhs_xla])
+
+      self.assertTrue(torch.allclose(ref_out, out.cpu()))
+      self.assertTrue(torch.allclose(lhs.grad, lhs_xla.grad.cpu()))
+      self.assertTrue(torch.allclose(rhs.grad, rhs_xla.grad.cpu()))
+
+      # Make sure gmm doesn't fallback.
+      self.assertNotIn("aten::", met.short_metrics_report())
+

 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)

From f7131d98ddfa0868e3a7cf19b3c647a363ad9838 Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Thu, 30 May 2024 04:27:27 +0000
Subject: [PATCH 3/3] Fix linters

---
 test/test_gmm.py                        | 6 ++++--
 torch_xla/experimental/custom_kernel.py | 4 +++-
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/test/test_gmm.py b/test/test_gmm.py
index 56ce9572ce60..b594a85c065c 100644
--- a/test/test_gmm.py
+++ b/test/test_gmm.py
@@ -396,7 +396,8 @@ def test_gmm_backward_2(self):

       torch.manual_seed(42)
       lhs_xla = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True).to("xla")
-      rhs_xla = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True).to("xla")
+      rhs_xla = torch.rand(
+          num_groups, k, n, dtype=rhs_dtype, requires_grad=True).to("xla")
       lhs_xla.retain_grad()
       rhs_xla.retain_grad()

@@ -432,7 +433,8 @@ def test_gmm_backward_3(self):

       torch.manual_seed(42)
       lhs_xla = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True).to("xla")
-      rhs_xla = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True).to("xla")
+      rhs_xla = torch.rand(
+          num_groups, k, n, dtype=rhs_dtype, requires_grad=True).to("xla")
       lhs_xla.retain_grad()
       rhs_xla.retain_grad()

diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 619496bcf1ac..1a8a8cd3852d 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -826,6 +826,7 @@ def gmm_backward(grad, lhs, rhs, group_sizes, tiling=(512, 512, 512)):
 class GMM(torch.autograd.Function):
+
   @staticmethod
   def forward(ctx, lhs, rhs, group_sizes, tiling=(512, 512, 512)):
     ctx.save_for_backward(lhs, rhs, group_sizes)
     ctx.tiling = tiling
     return gmm(lhs, rhs, group_sizes, tiling)
@@ -835,7 +836,8 @@ def forward(ctx, lhs, rhs, group_sizes, tiling=(512, 512, 512)):

   @staticmethod
   def backward(ctx, grad_output):
     lhs, rhs, group_sizes = ctx.saved_tensors
-    grad_lhs, grad_rhs = gmm_backward(grad_output, lhs, rhs, group_sizes, ctx.tiling)
+    grad_lhs, grad_rhs = gmm_backward(grad_output, lhs, rhs, group_sizes,
+                                      ctx.tiling)
     return grad_lhs, grad_rhs, None, None

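
Usage note (not part of the patches above): the snippet below is a minimal sketch of how the GMM autograd wrapper introduced in PATCH 1/3 might be exercised end to end, assuming a TPU/XLA environment with torch_xla installed. The shapes, the even group split, and the int32 dtype for group_sizes are illustrative assumptions rather than requirements stated in the patches.

    # Hypothetical usage sketch for the GMM autograd.Function added in this series.
    import torch
    from torch_xla.experimental.custom_kernel import GMM

    m, k, n, num_groups = 512, 128, 256, 4

    # Group sizes must sum to m; split evenly here purely for illustration.
    group_sizes = torch.full((num_groups,), m // num_groups,
                             dtype=torch.int32).to("xla")

    # Mirrors the test pattern: build leaf tensors on CPU, move them to the XLA
    # device, and call retain_grad() because the moved tensors are non-leaf.
    lhs = torch.rand(m, k, dtype=torch.bfloat16, requires_grad=True).to("xla")
    rhs = torch.rand(num_groups, k, n, dtype=torch.bfloat16,
                     requires_grad=True).to("xla")
    lhs.retain_grad()
    rhs.retain_grad()

    # Forward runs the grouped matmul kernel; backward dispatches to
    # gmm_backward with the tiling saved on the autograd context
    # (default (512, 512, 512)).
    out = GMM.apply(lhs, rhs, group_sizes)
    out.sum().backward()

    print(out.shape)       # (m, n)
    print(lhs.grad.shape)  # (m, k)
    print(rhs.grad.shape)  # (num_groups, k, n)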