6 changes: 5 additions & 1 deletion helion/_compiler/matmul_utils.py
@@ -188,7 +188,11 @@ def emit_tl_dot_with_padding(
     lhs_cast, rhs_cast = cast_ast(lhs, common_dtype), cast_ast(rhs, common_dtype)
     m, n, k = (_resolve_dim_size(d) for d in (m, n, k))
 
-    fuse_acc = acc is not None and acc_dtype in (common_dtype, torch.float32)
+    fuse_acc = (
+        acc is not None
+        and acc_dtype in (common_dtype, torch.float32)
+        and (out_dtype is None or out_dtype == acc_dtype)
+    )
     acc_out = acc if not fuse_acc else None
     acc_for_dot = acc if fuse_acc else None
     acc_cast_dtype = acc_dtype if not fuse_acc else None
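Note: the stricter fusion condition can be read as a standalone predicate. The sketch below is a hypothetical restatement (can_fuse_acc is not a helper in the codebase): the accumulator is only passed straight into tl.dot when no explicit out_dtype is requested or the requested out_dtype already equals the accumulator dtype; otherwise the accumulation happens after the dot.

import torch


def can_fuse_acc(acc, acc_dtype, common_dtype, out_dtype):
    # Mirrors the fuse_acc predicate above: "fusing" means handing `acc`
    # directly to tl.dot instead of adding it to the dot result afterwards.
    return (
        acc is not None
        and acc_dtype in (common_dtype, torch.float32)
        and (out_dtype is None or out_dtype == acc_dtype)
    )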
39 changes: 31 additions & 8 deletions helion/language/matmul_ops.py
@@ -22,6 +22,7 @@ def dot(
     mat1: torch.Tensor,
     mat2: torch.Tensor,
     acc: torch.Tensor | None = None,
+    out_dtype: torch.dtype | None = None,
 ) -> torch.Tensor:
     """
     Performs a matrix multiplication of tensors with support for multiple dtypes.
@@ -36,6 +37,9 @@ def dot(
         acc: The accumulator tensor (2D or 3D tensor of torch.float16, torch.float32, or torch.int32).
             If not None, the result is added to this tensor.
             If None, a new tensor is created with appropriate dtype based on inputs.
+        out_dtype: Optional dtype that controls the output type of the multiplication prior
+            to any accumulation. This maps directly to the Triton ``tl.dot`` ``out_dtype``
+            argument and overrides the default promotion rules when provided.
 
     Returns:
         Result of matrix multiplication. If acc is provided, returns acc + (mat1 @ mat2).
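As a quick illustration of the new keyword (a minimal sketch modeled on the test added in this PR; the kernel name and the omission of extra config options are illustrative), out_dtype lets a kernel keep a float16 accumulator while each tl.dot produces float16 products directly:

import torch
import helion
import helion.language as hl


@helion.kernel(config=helion.Config(block_sizes=[32, 32, 32]))
def matmul_fp16(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, k = x.size()
    _, n = y.size()
    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float16)
        for tile_k in hl.tile(k):
            # out_dtype=torch.float16 makes each tile product float16 before
            # it is accumulated, instead of the default float32 promotion.
            acc = hl.dot(
                x[tile_m, tile_k],
                y[tile_k, tile_n],
                acc=acc,
                out_dtype=torch.float16,
            )
        out[tile_m, tile_n] = acc
    return out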
@@ -67,7 +71,8 @@ def _(
     mat1: torch.Tensor,
     mat2: torch.Tensor,
     acc: torch.Tensor | None = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
+    out_dtype: torch.dtype | None = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.dtype | None]:
     # Define supported dtypes
     supported_dtypes = (
         torch.float16,
@@ -101,6 +106,11 @@ def _(
             f"{mat1.shape} @ {mat2.shape}"
         )
 
+    if out_dtype is not None and not isinstance(out_dtype, torch.dtype):
+        raise TypeError(
+            f"hl.dot: out_dtype must be a torch.dtype or None, got {type(out_dtype)}"
+        )
+
     # Validate accumulator if provided
     if acc is not None:
         # Allow int32 accumulator for int8 inputs
@@ -132,7 +142,7 @@ def _(
     # Apply min-dot-size constraints so autotuner won't pick invalid block_size
     enforce_dot_requirements(mat1, mat2)
 
-    return (mat1, mat2, acc)
+    return (mat1, mat2, acc, out_dtype)
 
 
 def enforce_dot_requirements(lhs: torch.Tensor, rhs: torch.Tensor) -> None:
@@ -159,7 +169,10 @@ def enforce_dot_requirements(lhs: torch.Tensor, rhs: torch.Tensor) -> None:
 
 @_decorators.register_fake(dot)
 def _(
-    mat1: torch.Tensor, mat2: torch.Tensor, acc: torch.Tensor | None = None
+    mat1: torch.Tensor,
+    mat2: torch.Tensor,
+    acc: torch.Tensor | None = None,
+    out_dtype: torch.dtype | None = None,
 ) -> torch.Tensor:
     # Matrix multiplication shape computation
     result_shape = list(mat1.shape)
@@ -169,8 +182,8 @@ def _(
         return acc.new_empty(result_shape)
 
     # Determine output dtype using the helper function
-    out_dtype = _compute_out_dtype(mat1.dtype, mat2.dtype)
-    return torch.empty(result_shape, dtype=out_dtype, device=mat1.device)
+    resolved_out_dtype = out_dtype or _compute_out_dtype(mat1.dtype, mat2.dtype)
+    return torch.empty(result_shape, dtype=resolved_out_dtype, device=mat1.device)
 
 
 @_decorators.codegen(dot)
@@ -186,6 +199,7 @@ def _(state: CodegenState) -> object:
     rhs_proxy = state.proxy_args[1]
     assert isinstance(rhs_proxy, FakeTensor), "rhs_proxy must be a FakeTensor"
     acc_proxy = state.proxy_args[2] if len(state.proxy_args) > 2 else None
+    out_dtype_proxy = state.proxy_args[3] if len(state.proxy_args) > 3 else None
 
     lhs_dtype = lhs_proxy.dtype
     rhs_dtype = rhs_proxy.dtype
@@ -194,6 +208,13 @@ def _(state: CodegenState) -> object:
         assert isinstance(acc_proxy, FakeTensor), "acc_proxy must be a FakeTensor"
         acc_dtype = acc_proxy.dtype
 
+    out_dtype: torch.dtype | None = None
+    if out_dtype_proxy is not None:
+        assert isinstance(out_dtype_proxy, torch.dtype), (
+            "out_dtype must be a torch.dtype"
+        )
+        out_dtype = out_dtype_proxy
+
     # Check if accumulator is None
     is_acc_none = isinstance(acc_ast, ast.Constant) and acc_ast.value is None
 
@@ -216,6 +237,7 @@ def _(state: CodegenState) -> object:
         lhs_shape=lhs_shape,
         rhs_shape=rhs_shape,
         acc_shape=acc_shape,
+        out_dtype=out_dtype,
     )
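For orientation, the keyword forwarded here ends up on the generated Triton dot call; the exact emitted code depends on the chosen config, and the new test below only asserts that the keyword appears. An illustrative, hypothetical shape of that call:

# Illustrative only; actual generated code differs in naming and structure.
acc = tl.dot(lhs, rhs, acc=acc, out_dtype=tl.float16)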


@@ -224,8 +246,9 @@ def _(
     mat1: torch.Tensor,
     mat2: torch.Tensor,
     acc: torch.Tensor | None = None,
+    out_dtype: torch.dtype | None = None,
 ) -> torch.Tensor:
-    out_dtype = _compute_out_dtype(
+    resolved_out_dtype = out_dtype or _compute_out_dtype(
         mat1.dtype, mat2.dtype, None if acc is None else acc.dtype
     )
 
@@ -246,11 +269,11 @@ def _(
             scale_a,
             scale_b,
             use_fast_accum=False,
-            out_dtype=out_dtype,
+            out_dtype=resolved_out_dtype,
         )
     else:
         # For non-FP8 tensors, use regular matmul
-        result = torch.mm(mat1, mat2, out_dtype=out_dtype)
+        result = torch.mm(mat1, mat2, out_dtype=resolved_out_dtype)
 
     if acc is not None:
         return acc + result
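The reference path above makes the ordering explicit: the product is materialized in the requested dtype first, and only then added to the accumulator. A simplified standalone sketch of that semantics (the helper name dot_semantics is hypothetical, and it ignores _compute_out_dtype's full promotion rules):

import torch


def dot_semantics(mat1, mat2, acc=None, out_dtype=None):
    # The product's dtype is fixed before accumulation, not after.
    product = torch.mm(mat1, mat2)
    if out_dtype is not None:
        product = product.to(out_dtype)
    return product if acc is None else acc + product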
32 changes: 32 additions & 0 deletions test/test_dot.py
@@ -224,6 +224,38 @@ def test_hl_dot_codegen_acc_differs_uses_addition(self):
         # Check that we cast the result to acc_dtype
         self.assertIn("tl.cast", code3)
 
+    @skipIfRefEager("Codegen inspection not applicable in ref eager mode")
+    def test_hl_dot_out_dtype_argument(self):
+        @helion.kernel(
+            config=helion.Config(block_sizes=[32, 32, 32]), dot_precision="tf32"
+        )
+        def dot_kernel_out_dtype(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, k = x.size()
+            _, n = y.size()
+            out = torch.empty([m, n], dtype=torch.float16, device=x.device)
+
+            for tile_m, tile_n in hl.tile([m, n]):
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float16)
+                for tile_k in hl.tile(k):
+                    acc = hl.dot(
+                        x[tile_m, tile_k],
+                        y[tile_k, tile_n],
+                        acc=acc,
+                        out_dtype=torch.float16,
+                    )
+                out[tile_m, tile_n] = acc
+            return out
+
+        x = torch.randn(32, 48, device=DEVICE, dtype=torch.float32)
+        y = torch.randn(48, 16, device=DEVICE, dtype=torch.float32)
+
+        code, result = code_and_output(dot_kernel_out_dtype, (x, y))
+
+        self.assertEqual(result.dtype, torch.float16)
+        expected = (x @ y).to(torch.float16)
+        torch.testing.assert_close(result, expected, atol=1e-2, rtol=1e-2)
+        self.assertIn("out_dtype=tl.float16", code)
+
     def test_torch_matmul_3d(self):
         @helion.kernel(static_shapes=True)
         def bmm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: