Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/api/exceptions.md
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,10 @@ Warnings can be suppressed by including them in the `ignore_warnings` setting:

Warns when operations return tensors on the wrong device.

.. autoclass:: TiledKMatmulAccumulationWarning

Warns when the ``acc += lhs @ rhs`` pattern (or the equivalent ``torch.matmul``, ``torch.mm``, ``torch.bmm``, or ``hl.dot`` call) is used inside tiled device loops.

```

### Warning Suppression
Expand Down
32 changes: 32 additions & 0 deletions helion/_compiler/type_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1646,6 +1646,32 @@ def generic_visit(self, node: ast.AST) -> TypeInfo:
super().generic_visit(node)
raise exc.UnsupportedPythonType(f"ast.{node.__class__.__name__}")

@staticmethod
def _contains_matmul(node: ast.AST | None) -> bool:
    """Report whether the AST subtree rooted at ``node`` contains a matmul.

    ``node`` and all of its descendants are inspected. A node counts as a
    matmul when it is either a binary ``@`` expression, or a call whose
    callee is written literally as ``<name>.<attr>`` and spells one of
    ``torch.matmul``, ``torch.mm``, ``torch.bmm`` or ``hl.dot``. The check
    is purely syntactic: bare names and longer attribute chains do not match.

    Args:
        node: Root of the subtree to scan, or ``None``.

    Returns:
        ``True`` if a matching node is found; ``False`` otherwise, including
        when ``node`` is ``None``.
    """
    if node is None:
        return False

    known_calls = ("torch.matmul", "torch.mm", "torch.bmm", "hl.dot")

    def is_matmul(candidate: ast.AST) -> bool:
        if isinstance(candidate, ast.BinOp):
            # A BinOp can never be a Call, so any operator other than ``@``
            # means this node is not a match.
            return isinstance(candidate.op, ast.MatMult)
        if isinstance(candidate, ast.Call):
            callee = candidate.func
            if isinstance(callee, ast.Attribute) and isinstance(
                callee.value, ast.Name
            ):
                return f"{callee.value.id}.{callee.attr}" in known_calls
        return False

    return any(is_matmul(child) for child in ast.walk(node))

def _bool_op(self, op: ast.boolop, left: TypeInfo, right: TypeInfo) -> TypeInfo:
try:
val = left.truth_value()
Expand Down Expand Up @@ -2094,6 +2120,12 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> TypeInfo:

def visit_AugAssign(self, node: ast.AugAssign) -> TypeInfo:
assert isinstance(node.target, ExtendedAST)
if (
self.device_loop_depth > 0
and isinstance(node.op, ast.Add)
and self._contains_matmul(node.value)
):
warning(exc.TiledKMatmulAccumulationWarning)
try:
type_info = self.visit(
create(
Expand Down
16 changes: 16 additions & 0 deletions helion/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,22 @@ class BlockSizeIgnoredInInterpretMode(BaseWarning):
message = "block_size is specified to be {0}, but in interpret mode, the full dimension size is always used."


class TiledKMatmulAccumulationWarning(BaseWarning):
    """Warning for ``acc += <matmul>`` accumulation in a Helion device loop.

    The message lists the flagged spellings followed by the accumulating
    alternatives suggested in their place.
    """

    # One entry per output line; joined with newlines, no trailing newline.
    message = "\n".join(
        (
            "Detected one of the following usage patterns inside a Helion device loop:",
            "- `acc += lhs @ rhs`",
            "- `acc += torch.matmul(lhs, rhs)`",
            "- `acc += torch.mm(lhs, rhs)`",
            "- `acc += torch.bmm(lhs, rhs)`",
            "- `acc += hl.dot(lhs, rhs)`",
            "For accurate numerics, please use one of:",
            "- `torch.addmm(acc, ...)`",
            "- `torch.baddbmm(acc, ...)`",
            "- `hl.dot(acc=...)`",
            "to accumulate across tiled-K iterations of a matmul operation.",
        )
    )


class AutotuningDisallowedInEnvironment(BaseError):
message = "Autotuning is disabled {0}, please provide a config to @helion.kernel via the config= argument."

Expand Down
137 changes: 137 additions & 0 deletions test/test_dot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import contextlib
import io
import itertools
from typing import Callable
import unittest
Expand Down Expand Up @@ -288,6 +290,141 @@ def bmm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
expected = torch.bmm(A, B).to(result.dtype) * 2
torch.testing.assert_close(result, expected, atol=1e-2, rtol=1e-2)

def _assert_warning_in_stderr(
    self, kernel, args, expected_result, warning_str, *, atol=1e-2, rtol=1e-2
):
    """Run ``kernel`` and check both its output and what it wrote to stderr.

    Args:
        kernel: Kernel passed to ``code_and_output``.
        args: Argument tuple passed to ``code_and_output`` alongside ``kernel``.
        expected_result: Value the kernel output is compared against.
        warning_str: Substring that must appear in the captured stderr text.
        atol: Absolute tolerance for the output comparison.
        rtol: Relative tolerance for the output comparison.
    """
    # redirect_stderr yields the replacement stream, so the buffer stays
    # readable after the block exits and sys.stderr has been restored.
    with contextlib.redirect_stderr(io.StringIO()) as captured:
        _code, result = code_and_output(kernel, args)

    torch.testing.assert_close(result, expected_result, atol=atol, rtol=rtol)
    self.assertIn(warning_str, captured.getvalue())

@skipIfRefEager("Warning emitted in compile mode only")
def test_augassign_at_operator_warning(self):
    # `acc += lhs @ rhs` inside the inner tiled loop must emit the warning
    # while the kernel still produces the matmul result.
    @helion.kernel(static_shapes=True)
    def warn_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        rows, inner = x.shape
        inner_y, cols = y.shape
        assert inner == inner_y
        result = torch.empty([rows, cols], dtype=x.dtype, device=x.device)
        for row_tile, col_tile in hl.tile([rows, cols]):
            partial = hl.zeros([row_tile, col_tile], dtype=torch.float32)
            for inner_tile in hl.tile(inner):
                a_block = x[row_tile, inner_tile]
                b_block = y[inner_tile, col_tile]
                # The pattern under test: augmented add of an `@` product.
                partial += a_block @ b_block
            result[row_tile, col_tile] = partial
        return result

    a = torch.randn(32, 16, device=DEVICE, dtype=torch.float32)
    b = torch.randn(16, 32, device=DEVICE, dtype=torch.float32)
    reference = a @ b

    self._assert_warning_in_stderr(
        warn_kernel,
        (a, b),
        reference,
        "WARNING[TiledKMatmulAccumulationWarning]",
    )

@skipIfRefEager("Warning emitted in compile mode only")
def test_augassign_torch_matmul_warning(self):
    # `acc += torch.matmul(lhs, rhs)` inside the inner tiled loop must emit
    # the warning while the kernel still produces the matmul result.
    @helion.kernel(static_shapes=True)
    def warn_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        rows, inner = x.shape
        inner_y, cols = y.shape
        assert inner == inner_y
        result = torch.empty([rows, cols], dtype=x.dtype, device=x.device)
        for row_tile, col_tile in hl.tile([rows, cols]):
            partial = hl.zeros([row_tile, col_tile], dtype=torch.float32)
            for inner_tile in hl.tile(inner):
                a_block = x[row_tile, inner_tile]
                b_block = y[inner_tile, col_tile]
                # The pattern under test: augmented add of torch.matmul.
                partial += torch.matmul(a_block, b_block)
            result[row_tile, col_tile] = partial
        return result

    a = torch.randn(32, 16, device=DEVICE, dtype=torch.float32)
    b = torch.randn(16, 32, device=DEVICE, dtype=torch.float32)
    reference = a @ b

    self._assert_warning_in_stderr(
        warn_kernel,
        (a, b),
        reference,
        "WARNING[TiledKMatmulAccumulationWarning]",
    )

@skipIfRefEager("Warning emitted in compile mode only")
def test_augassign_torch_mm_warning(self):
    # `acc += torch.mm(lhs, rhs)` inside the inner tiled loop must emit the
    # warning while the kernel still produces the matmul result.
    @helion.kernel(static_shapes=True)
    def warn_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        rows, inner = x.shape
        inner_y, cols = y.shape
        assert inner == inner_y
        result = torch.empty([rows, cols], dtype=x.dtype, device=x.device)
        for row_tile, col_tile in hl.tile([rows, cols]):
            partial = hl.zeros([row_tile, col_tile], dtype=torch.float32)
            for inner_tile in hl.tile(inner):
                a_block = x[row_tile, inner_tile]
                b_block = y[inner_tile, col_tile]
                # The pattern under test: augmented add of torch.mm.
                partial += torch.mm(a_block, b_block)
            result[row_tile, col_tile] = partial
        return result

    a = torch.randn(32, 16, device=DEVICE, dtype=torch.float32)
    b = torch.randn(16, 32, device=DEVICE, dtype=torch.float32)
    reference = a @ b

    self._assert_warning_in_stderr(
        warn_kernel,
        (a, b),
        reference,
        "WARNING[TiledKMatmulAccumulationWarning]",
    )

@skipIfRefEager("Warning emitted in compile mode only")
def test_augassign_torch_bmm_warning(self):
    # `acc += torch.bmm(lhs, rhs)` inside the inner tiled loop must emit the
    # warning while the kernel still produces the batched matmul result.
    @helion.kernel(static_shapes=True)
    def warn_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        batch, rows, inner = x.shape
        batch_y, inner_y, cols = y.shape
        assert batch == batch_y and inner == inner_y
        result = torch.empty([batch, rows, cols], dtype=x.dtype, device=x.device)
        for batch_tile, row_tile, col_tile in hl.tile([batch, rows, cols]):
            partial = hl.zeros(
                [batch_tile, row_tile, col_tile], dtype=torch.float32
            )
            for inner_tile in hl.tile(inner):
                a_block = x[batch_tile, row_tile, inner_tile]
                b_block = y[batch_tile, inner_tile, col_tile]
                # The pattern under test: augmented add of torch.bmm.
                partial += torch.bmm(a_block, b_block)
            result[batch_tile, row_tile, col_tile] = partial
        return result

    a = torch.randn(4, 32, 16, device=DEVICE, dtype=torch.float32)
    b = torch.randn(4, 16, 32, device=DEVICE, dtype=torch.float32)
    reference = torch.bmm(a, b)

    self._assert_warning_in_stderr(
        warn_kernel,
        (a, b),
        reference,
        "WARNING[TiledKMatmulAccumulationWarning]",
    )

@skipIfRefEager("Warning emitted in compile mode only")
def test_augassign_hl_dot_warning(self):
    # `acc += hl.dot(lhs, rhs)` inside the inner tiled loop must emit the
    # warning while the kernel still produces the matmul result.
    #
    # The kernel is named `warn_kernel`, like the kernels in the sibling
    # warning tests, because this test asserts that the warning IS emitted.
    @helion.kernel(static_shapes=True)
    def warn_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        m, k = x.shape
        k2, n = y.shape
        assert k == k2
        out = torch.empty([m, n], dtype=x.dtype, device=x.device)
        for tile_m, tile_n in hl.tile([m, n]):
            acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
            for tile_k in hl.tile(k):
                lhs = x[tile_m, tile_k]
                rhs = y[tile_k, tile_n]
                # The pattern under test: augmented add of hl.dot without
                # passing the accumulator via `acc=`.
                acc += hl.dot(lhs, rhs)
            out[tile_m, tile_n] = acc
        return out

    x = torch.randn(32, 16, device=DEVICE, dtype=torch.float32)
    y = torch.randn(16, 32, device=DEVICE, dtype=torch.float32)

    self._assert_warning_in_stderr(
        warn_kernel, (x, y), x @ y, "WARNING[TiledKMatmulAccumulationWarning]"
    )

# Note: numerical behavior for differing acc dtype is covered by existing dot tests; here we focus on codegen shape

# torch.baddbmm codegen shape is covered indirectly by broader matmul tests; skipping a brittle code-inspection here
Expand Down
Loading