62 changes: 60 additions & 2 deletions test/test_gmm.py
@@ -7,7 +7,7 @@
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
-from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram
+from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram, tgmm
from torch_xla import runtime as xr
from torch_xla._internal import tpu

@@ -24,7 +24,7 @@
class MegabloxTest(unittest.TestCase):

  def _reference_gmm(self, lhs: torch.Tensor, rhs: torch.Tensor,
-                     group_sizes: torch.Tensor) -> np.array:
+                     group_sizes: torch.Tensor) -> torch.Tensor:
    start = 0
    out = []
    for i, size in enumerate(group_sizes):
@@ -33,6 +33,16 @@ def _reference_gmm(self, lhs: torch.Tensor, rhs: torch.Tensor,
      start += group_sizes[i]
    return torch.cat(out)

+  def _reference_tgmm(self, lhs: torch.Tensor, rhs: torch.Tensor,
+                      group_sizes: torch.Tensor) -> torch.Tensor:
+    start = 0
+    out = []
+    for i, size in enumerate(group_sizes):
+      result = lhs[:, start:start + size] @ rhs[start:start + size, :]
+      out.append(result)
+      start += group_sizes[i]
+    return torch.stack(out)
+
  def _group_sizes_strategy(self, m: int, num_groups: int) -> torch.Tensor:
    # Randomly sample the ends of the groups in the m-dimension. Let the fuzzer
    # sample with replacement so that it's possible to get zero-sized groups. Get
@@ -280,6 +290,54 @@ def test_sorting_input(self):
    self.assertTrue(
        torch.all(group_sizes == torch.tensor([1, 2, 3, 2], device="xla")))

+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
+  def test_tgmm(self):
+    met.clear_all()
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
+
+    self._init_test_cases()
+    for test_case in self.tests_cases:
+      num_groups = test_case['num_groups']
+      k = test_case['k']
+      m = test_case['m']
+      n = test_case['n']
+      lhs_dtype = rhs_dtype = test_case['dtype']
+
+      lhs = torch.rand(k, m, dtype=lhs_dtype)
+      rhs = torch.rand(m, n, dtype=rhs_dtype)
+      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+      ref_out = self._reference_tgmm(lhs, rhs, group_sizes)
+
+      out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+      self.assertTrue(torch.allclose(ref_out, out.cpu()))
+
+    # Make sure tgmm doesn't fall back.
+    self.assertNotIn("aten::", met.short_metrics_report())
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
+
+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
+  def test_tgmm_bf16(self):
+    met.clear_all()
+
+    self._init_test_cases()
+    for test_case in self.tests_cases:
+      num_groups = test_case['num_groups']
+      k = test_case['k']
+      m = test_case['m']
+      n = test_case['n']
+      lhs_dtype = rhs_dtype = torch.bfloat16
+
+      lhs = torch.rand(k, m, dtype=lhs_dtype)
+      rhs = torch.rand(m, n, dtype=rhs_dtype)
+      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+      ref_out = self._reference_tgmm(lhs, rhs, group_sizes)
+
+      out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+      self.assertTrue(torch.allclose(ref_out, out.cpu()))
+
+    # Make sure tgmm doesn't fall back.
+    self.assertNotIn("aten::", met.short_metrics_report())
+

if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
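
Note: a minimal sketch of the semantics the new tests exercise, mirroring _reference_tgmm above (sizes here are illustrative, not the PR's actual test cases):

import torch

k, m, n, num_groups = 8, 16, 4, 3
lhs = torch.rand(k, m)  # [k, m]; note lhs is transposed relative to gmm's [m, k]
rhs = torch.rand(m, n)  # [m, n]
group_sizes = torch.tensor([4, 0, 12], dtype=torch.int32)  # sums to m; zero-sized groups are allowed

out, start = [], 0
for size in group_sizes.tolist():
  # Each group multiplies a column slice of lhs by the matching row slice of rhs.
  out.append(lhs[:, start:start + size] @ rhs[start:start + size, :])
  start += size
ref = torch.stack(out)  # one [k, n] product per group
assert ref.shape == (num_groups, k, n)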
74 changes: 63 additions & 11 deletions torch_xla/experimental/custom_kernel.py
@@ -528,7 +528,7 @@ def _make_group_metadata(
"""Create the metadata needed for grouped matmul computation.

Args:
group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype.
group_sizes: A 1d, torch.Tensor with shape [num_groups] and torch.int32 dtype.
m: The number of rows in lhs.
tm: The m-dimension tile size being used.
visit_empty_groups: If True, do not squeeze tiles for empty groups out of
@@ -537,14 +537,14 @@

  Returns:
    tuple of:
-    group_offsets: A 1d, jnp.ndarray with shape [num_groups + 1] and jnp.int32
+    group_offsets: A 1d, torch.Tensor with shape [num_groups + 1] and torch.int32
      dtype. group_offsets[i] indicates the row at which group [i] starts in
      the lhs matrix and group_offsets[num_groups] = m.
-    group_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups - 1] and
-      jnp.int32 dtype. group_ids[i] indicates which group grid index 'i' will
+    group_ids: A 1d, torch.Tensor with shape [m_tiles + num_groups - 1] and
+      torch.int32 dtype. group_ids[i] indicates which group grid index 'i' will
      work on.
-    m_tile_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups - 1] and
-      jnp.int32. m_tile_ids[i] indicates which m-dimension tile grid index 'i'
+    m_tile_ids: A 1d, torch.Tensor with shape [m_tiles + num_groups - 1] and
+      torch.int32. m_tile_ids[i] indicates which m-dimension tile grid index 'i'
      will work on.
    num_tiles: The number of m-dimension tiles to execute, including overlapping
      executions. Don't confuse this with m_tiles, which is m // tm.
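
For intuition, a small sketch of the group_offsets layout described above (a plain cumulative sum; this snippet is illustrative and not part of the PR):

import torch

group_sizes = torch.tensor([3, 0, 5], dtype=torch.int32)  # num_groups = 3, m = 8
group_offsets = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    torch.cumsum(group_sizes, 0).to(torch.int32)
])
# group_offsets == tensor([0, 3, 3, 8]): group i occupies rows
# group_offsets[i]:group_offsets[i + 1] of lhs, and group_offsets[num_groups] == m.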
@@ -723,14 +723,13 @@ def gmm(
"""Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.

Args:
lhs: A 2d, jnp.ndarray with shape [m, k].
rhs: A 3d, jnp.ndarray with shape [num_groups, k, n].
group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype.
preferred_element_type: jnp.dtype, the element type for the output matrix.
lhs: A 2d, torch.Tensor with shape [m, k].
rhs: A 3d, torch.Tensor with shape [num_groups, k, n].
group_sizes: A 1d, torch.Tensor with shape [num_groups] and torch.int32 dtype.
tiling: 3-tuple of ints. The m, k and n-dimension tile sizes.

Returns:
A 2d, jnp.ndarray with shape [m, n].
A 2d, torch.Tensor with shape [m, n].
"""
  # Import JAX within the function such that we don't need to call the jax_import_guard()
  # in the global scope, which could cause problems for xmp.spawn.
@@ -766,6 +765,59 @@ def gmm(
  ], payload, [torch.Size([m, n])], [preferred_element_type])[0]


+def tgmm(
+    lhs: torch.Tensor,
+    rhs: torch.Tensor,
+    group_sizes: torch.Tensor,
+    tiling: tuple[int, int, int] = (128, 128, 128)
+) -> torch.Tensor:
+  """Compute lhs[:, sizes[i-1]:sizes[i]] @ rhs[sizes[i-1]:sizes[i], :] for each group 'i'.
+
+  Args:
+    lhs: A 2d, torch.Tensor with shape [k, m].
+    rhs: A 2d, torch.Tensor with shape [m, n].
+    group_sizes: A 1d, torch.Tensor with shape [num_groups] and torch.int32 dtype.
+    tiling: 3-tuple of ints. The m, k and n-dimension tile sizes.
+
+  Returns:
+    A 3d, torch.Tensor with shape [num_groups, k, n].
+  """
+  # Import JAX within the function such that we don't need to call the jax_import_guard()
+  # in the global scope, which could cause problems for xmp.spawn.
+  jax_import_guard()
+  from jax.experimental.pallas.ops.tpu.megablox.gmm import tgmm
+
+  k, m, n, num_groups = lhs.shape[0], lhs.shape[1], rhs.shape[
+      1], group_sizes.shape[0]
+  tm, tk, tn = min(tiling[0], m), min(tiling[1], k), min(tiling[2], n)
+  preferred_element_type = lhs.dtype
+
+  payload, _ = trace_pallas(
+      tgmm,
+      lhs,
+      rhs,
+      group_sizes,
+      static_argnames=["tiling", "preferred_element_type"],
+      preferred_element_type=convert_torch_dtype_to_jax(preferred_element_type),
+      tiling=(tm, tk, tn))
+
+  # Create the metadata we need for the computation; this is why we separate
+  # the tracing and execution parts.
+  group_offsets, group_ids, m_tile_ids, num_tiles = _make_group_metadata(
+      group_sizes=group_sizes,
+      m=m,
+      tm=tm,
+      visit_empty_groups=True,
+  )
+  group_offset_torch = torch.tensor([0], dtype=torch.int32).to(lhs.device)
+
+  lhs = lhs.swapaxes(0, 1)
+  return torch_xla._XLAC._xla_tpu_custom_call([
+      num_tiles, group_offsets, group_ids, m_tile_ids, group_offset_torch, lhs,
+      rhs
+  ], payload, [torch.Size([num_groups, k, n])], [preferred_element_type])[0]
+
+
def non_xla_attetion(q, k, v, attention_type):
  # This will be called when dynamo uses a fake tensor to construct the fake output.
  # We need to make sure the output tensor's shape is correct.
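
A hedged end-to-end usage sketch of the new kernel, modeled directly on the tests above (requires a TPU-backed XLA device; shapes are illustrative):

import torch
import torch_xla  # makes the 'xla' device available
from torch_xla.experimental.custom_kernel import tgmm

k, m, n, num_groups = 128, 256, 64, 4
lhs = torch.rand(k, m).to("xla")  # [k, m]
rhs = torch.rand(m, n).to("xla")  # [m, n]
group_sizes = torch.tensor([64, 64, 64, 64], dtype=torch.int32).to("xla")  # sums to m

out = tgmm(lhs, rhs, group_sizes)  # [num_groups, k, n], one k-by-n product per group
assert out.shape == (num_groups, k, n)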