From 39ff22d7e39dfdab25c75ac951f8b8ddac310ee7 Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Wed, 29 May 2024 08:38:56 +0000
Subject: [PATCH 1/2] initial commit

---
 test/test_gmm.py                        | 62 ++++++++++++++++++++-
 torch_xla/experimental/custom_kernel.py | 73 +++++++++++++++++++++----
 2 files changed, 122 insertions(+), 13 deletions(-)

diff --git a/test/test_gmm.py b/test/test_gmm.py
index 141c66ca3422..8b73891d7361 100644
--- a/test/test_gmm.py
+++ b/test/test_gmm.py
@@ -7,7 +7,7 @@
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.metrics as met
-from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram
+from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram, tgmm
 from torch_xla import runtime as xr
 from torch_xla._internal import tpu
 
@@ -24,7 +24,7 @@ class MegabloxTest(unittest.TestCase):
 
   def _reference_gmm(self, lhs: torch.Tensor, rhs: torch.Tensor,
-                     group_sizes: torch.Tensor) -> np.array:
+                     group_sizes: torch.Tensor) -> torch.Tensor:
     start = 0
     out = []
     for i, size in enumerate(group_sizes):
@@ -33,6 +33,16 @@ def _reference_gmm(self, lhs: torch.Tensor, rhs: torch.Tensor,
       start += group_sizes[i]
     return torch.cat(out)
 
+  def _reference_tgmm(self, lhs: torch.Tensor, rhs: torch.Tensor,
+                     group_sizes: torch.Tensor) -> torch.Tensor:
+    start = 0
+    out = []
+    for i, size in enumerate(group_sizes):
+      result = lhs[:, start:start + size] @ rhs[start:start + size, :]
+      out.append(result)
+      start += group_sizes[i]
+    return torch.stack(out)
+
   def _group_sizes_strategy(self, m: int, num_groups: int) -> torch.Tensor:
     # Randomly sample the ends of the groups in the m-dimension. Let the fuzzer
     # sample with replacement so that it's possible to get zero-sized groups. Get
@@ -280,6 +290,54 @@ def test_sorting_input(self):
     self.assertTrue(
         torch.all(group_sizes == torch.tensor([1, 2, 3, 2], device="xla")))
 
+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
+  def test_tgmm(self):
+    met.clear_all()
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
+
+    self._init_test_cases()
+    for test_case in self.tests_cases:
+      num_groups = test_case['num_groups']
+      k = test_case['k']
+      m = test_case['m']
+      n = test_case['n']
+      lhs_dtype = rhs_dtype = test_case['dtype']
+
+      lhs = torch.rand(k, m, dtype=lhs_dtype)
+      rhs = torch.rand(m, n, dtype=rhs_dtype)
+      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+      ref_out = self._reference_tgmm(lhs, rhs, group_sizes)
+
+      out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+      self.assertTrue(torch.allclose(ref_out, out.cpu()))
+
+    # Make sure tgmm doesn't fallback.
+    self.assertNotIn("aten::", met.short_metrics_report())
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
+
+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
+  def test_tgmm_bf16(self):
+    met.clear_all()
+
+    self._init_test_cases()
+    for test_case in self.tests_cases:
+      num_groups = test_case['num_groups']
+      k = test_case['k']
+      m = test_case['m']
+      n = test_case['n']
+      lhs_dtype = rhs_dtype = torch.bfloat16
+
+      lhs = torch.rand(k, m, dtype=lhs_dtype)
+      rhs = torch.rand(m, n, dtype=rhs_dtype)
+      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+      ref_out = self._reference_tgmm(lhs, rhs, group_sizes)
+
+      out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+      self.assertTrue(torch.allclose(ref_out, out.cpu()))
+
+    # Make sure tgmm doesn't fallback.
+    self.assertNotIn("aten::", met.short_metrics_report())
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index c28870e18123..d49f2ad308b8 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -528,7 +528,7 @@ def _make_group_metadata(
   """Create the metadata needed for grouped matmul computation.
 
   Args:
-    group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype.
+    group_sizes: A 1d, torch.Tensor with shape [num_groups] and torch.int32 dtype.
     m: The number of rows in lhs.
     tm: The m-dimension tile size being used.
     visit_empty_groups: If True, do not squeeze tiles for empty groups out of
@@ -537,14 +537,14 @@ def _make_group_metadata(
 
   Returns:
     tuple of:
-    group_offsets: A 1d, jnp.ndarray with shape [num_groups + 1] and jnp.int32
+    group_offsets: A 1d, torch.Tensor with shape [num_groups + 1] and torch.int32
       dtype. group_offsets[i] indicates the row at which group [i] starts in
       the lhs matrix and group_offsets[i-1] = m.
-    group_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups - 1] and
-      jnp.int32 dtype. group_ids[i] indicates which group grid index 'i' will
+    group_ids: A 1d, torch.Tensor with shape [m_tiles + num_groups - 1] and
+      torch.int32 dtype. group_ids[i] indicates which group grid index 'i' will
       work on.
-    m_tile_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups - 1] and
-      jnp.int32. m_tile_ids[i] indicates which m-dimension tile grid index 'i'
+    m_tile_ids: A 1d, torch.Tensor with shape [m_tiles + num_groups - 1] and
+      torch.int32. m_tile_ids[i] indicates which m-dimension tile grid index 'i'
       will work on.
     num_tiles: The number of m-dimension tiles to execute including overlapping
       executions. And don't confuse this with m_tiles which is m // tm.
@@ -723,14 +723,13 @@ def gmm(
   """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.
 
   Args:
-    lhs: A 2d, jnp.ndarray with shape [m, k].
-    rhs: A 3d, jnp.ndarray with shape [num_groups, k, n].
-    group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype.
-    preferred_element_type: jnp.dtype, the element type for the output matrix.
+    lhs: A 2d, torch.Tensor with shape [m, k].
+    rhs: A 3d, torch.Tensor with shape [num_groups, k, n].
+    group_sizes: A 1d, torch.Tensor with shape [num_groups] and torch.int32 dtype.
     tiling: 3-tuple of ints. The m, k and n-dimension tile sizes.
 
   Returns:
-    A 2d, jnp.ndarray with shape [m, n].
+    A 2d, torch.Tensor with shape [m, n].
   """
   # Import JAX within the function such that we don't need to call the jax_import_guard()
   # in the global scope which could cause problems for xmp.spawn.
@@ -766,6 +765,58 @@ def gmm(
   ], payload, [torch.Size([m, n])], [preferred_element_type])[0]
 
 
+def tgmm(
+    lhs: torch.Tensor,
+    rhs: torch.Tensor,
+    group_sizes: torch.Tensor,
+    tiling: tuple[int, int, int] = (128, 128, 128)
+) -> torch.Tensor:
+  """Compute lhs[:, sizes[i-1]:sizes[i]] @ rhs[sizes[i-1]:sizes[i], :].
+
+  Args:
+    lhs: A 2d, torch.Tensor with shape [k, m].
+    rhs: A 2d, torch.Tensor with shape [m, n].
+    group_sizes: A 1d, torch.Tensor with shape [num_groups] and torch.int32 dtype.
+    tiling: 3-tuple of ints. The m, k and n-dimension tile sizes.
+
+  Returns:
+    A 3d, torch.Tensor with shape [num_groups, k, n].
+  """
+  # Import JAX within the function such that we don't need to call the jax_import_guard()
+  # in the global scope which could cause problems for xmp.spawn.
+  jax_import_guard()
+  from jax.experimental.pallas.ops.tpu.megablox.gmm import tgmm
+
+  k, m, n, num_groups = lhs.shape[0], lhs.shape[1], rhs.shape[1], group_sizes.shape[0]
+  tm, tk, tn = min(tiling[0], m), min(tiling[1], k), min(tiling[2], n)
+  preferred_element_type = lhs.dtype
+
+  payload, _ = trace_pallas(
+      tgmm,
+      lhs,
+      rhs,
+      group_sizes,
+      static_argnames=["tiling", "preferred_element_type"],
+      preferred_element_type=convert_torch_dtype_to_jax(preferred_element_type),
+      tiling=(tm, tk, tn))
+
+  # Create the metadata we need for computation, and that's why need to separate
+  # the tracing and execution part.
+  group_offsets, group_ids, m_tile_ids, num_tiles = _make_group_metadata(
+      group_sizes=group_sizes,
+      m=m,
+      tm=tm,
+      visit_empty_groups=True,
+  )
+  group_offset_torch = torch.tensor([0], dtype=torch.int32).to(lhs.device)
+
+  lhs = lhs.swapaxes(0, 1)
+  return torch_xla._XLAC._xla_tpu_custom_call([
+      num_tiles, group_offsets, group_ids, m_tile_ids, group_offset_torch, lhs,
+      rhs
+  ], payload, [torch.Size([num_groups, k, n])], [preferred_element_type])[0]
+
+
 def non_xla_attetion(q, k, v, attention_type):
   # This will be called when dynamo use fake tensor to construct the fake output.
   # We need to make sure output tensor's shape is correct.

From aa125f995b251bfcd6df806579183502067e074b Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Wed, 29 May 2024 08:40:58 +0000
Subject: [PATCH 2/2] Fix linters

---
 test/test_gmm.py                        | 2 +-
 torch_xla/experimental/custom_kernel.py | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/test/test_gmm.py b/test/test_gmm.py
index 8b73891d7361..7f9626405af3 100644
--- a/test/test_gmm.py
+++ b/test/test_gmm.py
@@ -34,7 +34,7 @@ def _reference_gmm(self, lhs: torch.Tensor, rhs: torch.Tensor,
     return torch.cat(out)
 
   def _reference_tgmm(self, lhs: torch.Tensor, rhs: torch.Tensor,
-                     group_sizes: torch.Tensor) -> torch.Tensor:
+                      group_sizes: torch.Tensor) -> torch.Tensor:
     start = 0
     out = []
     for i, size in enumerate(group_sizes):
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index d49f2ad308b8..cdb7ff481bb1 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -787,7 +787,8 @@ def tgmm(
   jax_import_guard()
   from jax.experimental.pallas.ops.tpu.megablox.gmm import tgmm
 
-  k, m, n, num_groups = lhs.shape[0], lhs.shape[1], rhs.shape[1], group_sizes.shape[0]
+  k, m, n, num_groups = lhs.shape[0], lhs.shape[1], rhs.shape[
+      1], group_sizes.shape[0]
   tm, tk, tn = min(tiling[0], m), min(tiling[1], k), min(tiling[2], n)
   preferred_element_type = lhs.dtype
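
For context when reviewing, here is a minimal usage sketch of the new tgmm op, mirroring test_tgmm in the first patch. It is not part of the patch itself; it assumes a TPU device, a torch_xla build that already includes this change, and made-up shapes chosen only for illustration.

import torch
import torch_xla  # registers the "xla" device
from torch_xla.experimental.custom_kernel import tgmm

k, m, n, num_groups = 128, 1024, 256, 8
lhs = torch.rand(k, m, dtype=torch.float32)
rhs = torch.rand(m, n, dtype=torch.float32)
# Partition the shared m dimension into num_groups contiguous groups.
group_sizes = torch.full((num_groups,), m // num_groups, dtype=torch.int32)

# Reference: one [k, n] product per group, stacked to [num_groups, k, n],
# exactly what _reference_tgmm in the test computes.
ref, start = [], 0
for size in group_sizes.tolist():
  ref.append(lhs[:, start:start + size] @ rhs[start:start + size, :])
  start += size
ref = torch.stack(ref)

out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
assert out.shape == (num_groups, k, n)
print(torch.allclose(ref, out.cpu(), atol=1e-3, rtol=1e-3))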
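One way to read the docstring's contract: because group_sizes splits the shared m dimension into contiguous groups, the stacked per-group products are equivalent to a single einsum against a one-hot group-assignment matrix. The following CPU-only sketch of that equivalence is an editor-added illustration, not part of the patch, with made-up shapes.

import torch
import torch.nn.functional as F

k, m, n, num_groups = 4, 12, 3, 3
lhs = torch.rand(k, m)
rhs = torch.rand(m, n)
group_sizes = torch.tensor([5, 0, 7], dtype=torch.int32)  # zero-sized groups are allowed

# Group id of every column of lhs (equivalently, row of rhs), one-hot encoded.
ids = torch.repeat_interleave(torch.arange(num_groups), group_sizes.long())
assign = F.one_hot(ids, num_groups).T.to(lhs.dtype)  # [num_groups, m]

# out[g] == lhs[:, start_g:end_g] @ rhs[start_g:end_g, :]
out = torch.einsum('km,gm,mn->gkn', lhs, assign, rhs)

ref, start = [], 0
for size in group_sizes.tolist():
  ref.append(lhs[:, start:start + size] @ rhs[start:start + size, :])
  start += size
print(torch.allclose(out, torch.stack(ref)))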