diff --git a/docs/api/exceptions.md b/docs/api/exceptions.md
index 6533c1065..fffc6c27f 100644
--- a/docs/api/exceptions.md
+++ b/docs/api/exceptions.md
@@ -350,6 +350,10 @@ Warnings can be suppressed by including them in the `ignore_warnings` setting:
 
    Warns when operations return tensors on wrong device.
 
+.. autoclass:: TiledKMatmulAccumulationWarning
+
+   Warns when the ``acc += lhs @ rhs`` pattern is used inside tiled device loops.
+
 ```
 
 ### Warning Suppression
diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py
index 69ca13338..626b98fc0 100644
--- a/helion/_compiler/type_propagation.py
+++ b/helion/_compiler/type_propagation.py
@@ -1646,6 +1646,32 @@ def generic_visit(self, node: ast.AST) -> TypeInfo:
         super().generic_visit(node)
         raise exc.UnsupportedPythonType(f"ast.{node.__class__.__name__}")
 
+    @staticmethod
+    def _contains_matmul(node: ast.AST | None) -> bool:
+        if node is None:
+            return False
+
+        matmul_functions = ("torch.matmul", "torch.mm", "torch.bmm", "hl.dot")
+
+        for sub_node in ast.walk(node):
+            # Check for the `@` operator
+            if isinstance(sub_node, ast.BinOp) and isinstance(sub_node.op, ast.MatMult):
+                return True
+
+            # Only calls can be qualified matmul invocations
+            if not isinstance(sub_node, ast.Call):
+                continue
+
+            func = sub_node.func
+
+            # Check for calls such as `torch.matmul(...)` or `hl.dot(...)`
+            if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
+                qualified_name = f"{func.value.id}.{func.attr}"
+                if qualified_name in matmul_functions:
+                    return True
+
+        return False
+
     def _bool_op(self, op: ast.boolop, left: TypeInfo, right: TypeInfo) -> TypeInfo:
         try:
             val = left.truth_value()
@@ -2094,6 +2120,12 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> TypeInfo:
 
     def visit_AugAssign(self, node: ast.AugAssign) -> TypeInfo:
         assert isinstance(node.target, ExtendedAST)
+        if (
+            self.device_loop_depth > 0
+            and isinstance(node.op, ast.Add)
+            and self._contains_matmul(node.value)
+        ):
+            warning(exc.TiledKMatmulAccumulationWarning)
         try:
             type_info = self.visit(
                 create(
diff --git a/helion/exc.py b/helion/exc.py
index fbc20b234..3f4aea565 100644
--- a/helion/exc.py
+++ b/helion/exc.py
@@ -414,6 +414,22 @@ class BlockSizeIgnoredInInterpretMode(BaseWarning):
     message = "block_size is specified to be {0}, but in interpret mode, the full dimension size is always used."
 
 
+class TiledKMatmulAccumulationWarning(BaseWarning):
+    message = (
+        "Detected one of the following usage patterns inside a Helion device loop:\n"
+        "- `acc += lhs @ rhs`\n"
+        "- `acc += torch.matmul(lhs, rhs)`\n"
+        "- `acc += torch.mm(lhs, rhs)`\n"
+        "- `acc += torch.bmm(lhs, rhs)`\n"
+        "- `acc += hl.dot(lhs, rhs)`\n"
+        "For accurate numerics, please use one of:\n"
+        "- `torch.addmm(acc, ...)`\n"
+        "- `torch.baddbmm(acc, ...)`\n"
+        "- `hl.dot(acc=...)`\n"
+        "to accumulate across tiled-K iterations of a matmul operation."
+    )
+
+
 class AutotuningDisallowedInEnvironment(BaseError):
     message = "Autotuning is disabled {0}, please provide a config to @helion.kernel via the config= argument."
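Reviewer note: to make the before/after concrete, here is a minimal sketch of the rewrite the new warning asks for, mirroring the 2D kernels in the tests below. This is illustrative only: `matmul_accumulate` is a hypothetical name, and the `hl.dot(..., acc=...)` keyword form is taken from the warning message above rather than from this diff.

```python
import torch

import helion
import helion.language as hl


@helion.kernel(static_shapes=True)
def matmul_accumulate(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, k = x.shape
    k2, n = y.shape
    assert k == k2
    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            # Flagged pattern: `acc += x[tile_m, tile_k] @ y[tile_k, tile_n]`
            # splits the matmul and the accumulation into two ops per
            # tiled-K iteration. Recommended form per the warning message:
            # thread the accumulator through the dot itself.
            acc = hl.dot(x[tile_m, tile_k], y[tile_k, tile_n], acc=acc)
        out[tile_m, tile_n] = acc
    return out
```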
diff --git a/test/test_dot.py b/test/test_dot.py
index 5d56807bb..107e3a341 100644
--- a/test/test_dot.py
+++ b/test/test_dot.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import contextlib
+import io
 import itertools
 from typing import Callable
 import unittest
@@ -288,6 +290,141 @@ def bmm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
         expected = torch.bmm(A, B).to(result.dtype) * 2
         torch.testing.assert_close(result, expected, atol=1e-2, rtol=1e-2)
 
+    def _assert_warning_in_stderr(
+        self, kernel, args, expected_result, warning_str, *, atol=1e-2, rtol=1e-2
+    ):
+        stderr_buffer = io.StringIO()
+        with contextlib.redirect_stderr(stderr_buffer):
+            _, out = code_and_output(kernel, args)
+
+        torch.testing.assert_close(out, expected_result, atol=atol, rtol=rtol)
+
+        warning_text = stderr_buffer.getvalue()
+        self.assertIn(warning_str, warning_text)
+
+    @skipIfRefEager("Warning emitted in compile mode only")
+    def test_augassign_at_operator_warning(self):
+        @helion.kernel(static_shapes=True)
+        def warn_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, k = x.shape
+            k2, n = y.shape
+            assert k == k2
+            out = torch.empty([m, n], dtype=x.dtype, device=x.device)
+            for tile_m, tile_n in hl.tile([m, n]):
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+                for tile_k in hl.tile(k):
+                    lhs = x[tile_m, tile_k]
+                    rhs = y[tile_k, tile_n]
+                    acc += lhs @ rhs
+                out[tile_m, tile_n] = acc
+            return out
+
+        x = torch.randn(32, 16, device=DEVICE, dtype=torch.float32)
+        y = torch.randn(16, 32, device=DEVICE, dtype=torch.float32)
+
+        self._assert_warning_in_stderr(
+            warn_kernel, (x, y), x @ y, "WARNING[TiledKMatmulAccumulationWarning]"
+        )
+
+    @skipIfRefEager("Warning emitted in compile mode only")
+    def test_augassign_torch_matmul_warning(self):
+        @helion.kernel(static_shapes=True)
+        def warn_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, k = x.shape
+            k2, n = y.shape
+            assert k == k2
+            out = torch.empty([m, n], dtype=x.dtype, device=x.device)
+            for tile_m, tile_n in hl.tile([m, n]):
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+                for tile_k in hl.tile(k):
+                    lhs = x[tile_m, tile_k]
+                    rhs = y[tile_k, tile_n]
+                    acc += torch.matmul(lhs, rhs)
+                out[tile_m, tile_n] = acc
+            return out
+
+        x = torch.randn(32, 16, device=DEVICE, dtype=torch.float32)
+        y = torch.randn(16, 32, device=DEVICE, dtype=torch.float32)
+
+        self._assert_warning_in_stderr(
+            warn_kernel, (x, y), x @ y, "WARNING[TiledKMatmulAccumulationWarning]"
+        )
+
+    @skipIfRefEager("Warning emitted in compile mode only")
+    def test_augassign_torch_mm_warning(self):
+        @helion.kernel(static_shapes=True)
+        def warn_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, k = x.shape
+            k2, n = y.shape
+            assert k == k2
+            out = torch.empty([m, n], dtype=x.dtype, device=x.device)
+            for tile_m, tile_n in hl.tile([m, n]):
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+                for tile_k in hl.tile(k):
+                    lhs = x[tile_m, tile_k]
+                    rhs = y[tile_k, tile_n]
+                    acc += torch.mm(lhs, rhs)
+                out[tile_m, tile_n] = acc
+            return out
+
+        x = torch.randn(32, 16, device=DEVICE, dtype=torch.float32)
+        y = torch.randn(16, 32, device=DEVICE, dtype=torch.float32)
+
+        self._assert_warning_in_stderr(
+            warn_kernel, (x, y), x @ y, "WARNING[TiledKMatmulAccumulationWarning]"
+        )
+
+    @skipIfRefEager("Warning emitted in compile mode only")
+    def test_augassign_torch_bmm_warning(self):
+        @helion.kernel(static_shapes=True)
+        def warn_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            b, m, k = x.shape
+            b2, k2, n = y.shape
+            assert b == b2 and k == k2
+            out = torch.empty([b, m, n], dtype=x.dtype, device=x.device)
+            for tile_b, tile_m, tile_n in hl.tile([b, m, n]):
+                acc = hl.zeros([tile_b, tile_m, tile_n], dtype=torch.float32)
+                for tile_k in hl.tile(k):
+                    lhs = x[tile_b, tile_m, tile_k]
+                    rhs = y[tile_b, tile_k, tile_n]
+                    acc += torch.bmm(lhs, rhs)
+                out[tile_b, tile_m, tile_n] = acc
+            return out
+
+        x = torch.randn(4, 32, 16, device=DEVICE, dtype=torch.float32)
+        y = torch.randn(4, 16, 32, device=DEVICE, dtype=torch.float32)
+
+        self._assert_warning_in_stderr(
+            warn_kernel,
+            (x, y),
+            torch.bmm(x, y),
+            "WARNING[TiledKMatmulAccumulationWarning]",
+        )
+
+    @skipIfRefEager("Warning emitted in compile mode only")
+    def test_augassign_hl_dot_warning(self):
+        @helion.kernel(static_shapes=True)
+        def warn_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, k = x.shape
+            k2, n = y.shape
+            assert k == k2
+            out = torch.empty([m, n], dtype=x.dtype, device=x.device)
+            for tile_m, tile_n in hl.tile([m, n]):
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+                for tile_k in hl.tile(k):
+                    lhs = x[tile_m, tile_k]
+                    rhs = y[tile_k, tile_n]
+                    acc += hl.dot(lhs, rhs)
+                out[tile_m, tile_n] = acc
+            return out
+
+        x = torch.randn(32, 16, device=DEVICE, dtype=torch.float32)
+        y = torch.randn(16, 32, device=DEVICE, dtype=torch.float32)
+
+        self._assert_warning_in_stderr(
+            warn_kernel, (x, y), x @ y, "WARNING[TiledKMatmulAccumulationWarning]"
+        )
+
+    # Note: numerical behavior for differing acc dtypes is covered by existing dot tests; the tests above focus on warning emission.
+    # torch.baddbmm codegen shape is covered indirectly by broader matmul tests; a brittle code-inspection test is deliberately skipped here.
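Reviewer note: the docs hunk ties the new warning into the existing `ignore_warnings` suppression mechanism. Below is a hedged sketch of opting out for kernels that intentionally keep the `+=` form, assuming `ignore_warnings` accepts a list of warning classes passed through `@helion.kernel` (the suppression plumbing itself is not part of this diff, and `legacy_matmul` is a hypothetical name).

```python
import torch

import helion
import helion.language as hl
from helion import exc


# Assumed opt-out: suppress only this warning class via the
# `ignore_warnings` setting named in the docs hunk above.
@helion.kernel(
    static_shapes=True,
    ignore_warnings=[exc.TiledKMatmulAccumulationWarning],
)
def legacy_matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, k = x.shape
    k2, n = y.shape
    assert k == k2
    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            # Still the flagged pattern; compilation proceeds without
            # printing WARNING[TiledKMatmulAccumulationWarning].
            acc += x[tile_m, tile_k] @ y[tile_k, tile_n]
        out[tile_m, tile_n] = acc
    return out
```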