From 869d2c6c3e8c34ade1bf791bd61841c61ed25ac4 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sat, 4 Oct 2025 12:05:42 -0700 Subject: [PATCH] Add out_dtype arg to hl.dot Fixes #747 stack-info: PR: https://github.com/pytorch/helion/pull/813, branch: jansel/stack/163 --- helion/_compiler/matmul_utils.py | 6 ++++- helion/language/matmul_ops.py | 39 +++++++++++++++++++++++++------- test/test_dot.py | 32 ++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 9 deletions(-) diff --git a/helion/_compiler/matmul_utils.py b/helion/_compiler/matmul_utils.py index 42defefed..1390e351e 100644 --- a/helion/_compiler/matmul_utils.py +++ b/helion/_compiler/matmul_utils.py @@ -188,7 +188,11 @@ def emit_tl_dot_with_padding( lhs_cast, rhs_cast = cast_ast(lhs, common_dtype), cast_ast(rhs, common_dtype) m, n, k = (_resolve_dim_size(d) for d in (m, n, k)) - fuse_acc = acc is not None and acc_dtype in (common_dtype, torch.float32) + fuse_acc = ( + acc is not None + and acc_dtype in (common_dtype, torch.float32) + and (out_dtype is None or out_dtype == acc_dtype) + ) acc_out = acc if not fuse_acc else None acc_for_dot = acc if fuse_acc else None acc_cast_dtype = acc_dtype if not fuse_acc else None diff --git a/helion/language/matmul_ops.py b/helion/language/matmul_ops.py index 8e475d46e..dbf5ef3af 100644 --- a/helion/language/matmul_ops.py +++ b/helion/language/matmul_ops.py @@ -22,6 +22,7 @@ def dot( mat1: torch.Tensor, mat2: torch.Tensor, acc: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, ) -> torch.Tensor: """ Performs a matrix multiplication of tensors with support for multiple dtypes. @@ -36,6 +37,9 @@ def dot( acc: The accumulator tensor (2D or 3D tensor of torch.float16, torch.float32, or torch.int32). If not None, the result is added to this tensor. If None, a new tensor is created with appropriate dtype based on inputs. + out_dtype: Optional dtype that controls the output type of the multiplication prior + to any accumulation. This maps directly to the Triton ``tl.dot`` ``out_dtype`` + argument and overrides the default promotion rules when provided. Returns: Result of matrix multiplication. If acc is provided, returns acc + (mat1 @ mat2). 
@@ -67,7 +71,8 @@ def _( mat1: torch.Tensor, mat2: torch.Tensor, acc: torch.Tensor | None = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: + out_dtype: torch.dtype | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.dtype | None]: # Define supported dtypes supported_dtypes = ( torch.float16, @@ -101,6 +106,11 @@ def _( f"{mat1.shape} @ {mat2.shape}" ) + if out_dtype is not None and not isinstance(out_dtype, torch.dtype): + raise TypeError( + f"hl.dot: out_dtype must be a torch.dtype or None, got {type(out_dtype)}" + ) + # Validate accumulator if provided if acc is not None: # Allow int32 accumulator for int8 inputs @@ -132,7 +142,7 @@ def _( # Apply min-dot-size constraints so autotuner won't pick invalid block_size enforce_dot_requirements(mat1, mat2) - return (mat1, mat2, acc) + return (mat1, mat2, acc, out_dtype) def enforce_dot_requirements(lhs: torch.Tensor, rhs: torch.Tensor) -> None: @@ -159,7 +169,10 @@ def enforce_dot_requirements(lhs: torch.Tensor, rhs: torch.Tensor) -> None: @_decorators.register_fake(dot) def _( - mat1: torch.Tensor, mat2: torch.Tensor, acc: torch.Tensor | None = None + mat1: torch.Tensor, + mat2: torch.Tensor, + acc: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, ) -> torch.Tensor: # Matrix multiplication shape computation result_shape = list(mat1.shape) @@ -169,8 +182,8 @@ def _( return acc.new_empty(result_shape) # Determine output dtype using the helper function - out_dtype = _compute_out_dtype(mat1.dtype, mat2.dtype) - return torch.empty(result_shape, dtype=out_dtype, device=mat1.device) + resolved_out_dtype = out_dtype or _compute_out_dtype(mat1.dtype, mat2.dtype) + return torch.empty(result_shape, dtype=resolved_out_dtype, device=mat1.device) @_decorators.codegen(dot) @@ -186,6 +199,7 @@ def _(state: CodegenState) -> object: rhs_proxy = state.proxy_args[1] assert isinstance(rhs_proxy, FakeTensor), "rhs_proxy must be a FakeTensor" acc_proxy = state.proxy_args[2] if len(state.proxy_args) > 2 else None + out_dtype_proxy = state.proxy_args[3] if len(state.proxy_args) > 3 else None lhs_dtype = lhs_proxy.dtype rhs_dtype = rhs_proxy.dtype @@ -194,6 +208,13 @@ def _(state: CodegenState) -> object: assert isinstance(acc_proxy, FakeTensor), "acc_proxy must be a FakeTensor" acc_dtype = acc_proxy.dtype + out_dtype: torch.dtype | None = None + if out_dtype_proxy is not None: + assert isinstance(out_dtype_proxy, torch.dtype), ( + "out_dtype must be a torch.dtype" + ) + out_dtype = out_dtype_proxy + # Check if accumulator is None is_acc_none = isinstance(acc_ast, ast.Constant) and acc_ast.value is None @@ -216,6 +237,7 @@ def _(state: CodegenState) -> object: lhs_shape=lhs_shape, rhs_shape=rhs_shape, acc_shape=acc_shape, + out_dtype=out_dtype, ) @@ -224,8 +246,9 @@ def _( mat1: torch.Tensor, mat2: torch.Tensor, acc: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, ) -> torch.Tensor: - out_dtype = _compute_out_dtype( + resolved_out_dtype = out_dtype or _compute_out_dtype( mat1.dtype, mat2.dtype, None if acc is None else acc.dtype ) @@ -246,11 +269,11 @@ def _( scale_a, scale_b, use_fast_accum=False, - out_dtype=out_dtype, + out_dtype=resolved_out_dtype, ) else: # For non-FP8 tensors, use regular matmul - result = torch.mm(mat1, mat2, out_dtype=out_dtype) + result = torch.mm(mat1, mat2, out_dtype=resolved_out_dtype) if acc is not None: return acc + result diff --git a/test/test_dot.py b/test/test_dot.py index 1642bd899..3899738d4 100644 --- a/test/test_dot.py +++ b/test/test_dot.py 
@@ -224,6 +224,38 @@ def test_hl_dot_codegen_acc_differs_uses_addition(self): # Check that we cast the result to acc_dtype self.assertIn("tl.cast", code3) + @skipIfRefEager("Codegen inspection not applicable in ref eager mode") + def test_hl_dot_out_dtype_argument(self): + @helion.kernel( + config=helion.Config(block_sizes=[32, 32, 32]), dot_precision="tf32" + ) + def dot_kernel_out_dtype(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=torch.float16, device=x.device) + + for tile_m, tile_n in hl.tile([m, n]): + acc = hl.zeros([tile_m, tile_n], dtype=torch.float16) + for tile_k in hl.tile(k): + acc = hl.dot( + x[tile_m, tile_k], + y[tile_k, tile_n], + acc=acc, + out_dtype=torch.float16, + ) + out[tile_m, tile_n] = acc + return out + + x = torch.randn(32, 48, device=DEVICE, dtype=torch.float32) + y = torch.randn(48, 16, device=DEVICE, dtype=torch.float32) + + code, result = code_and_output(dot_kernel_out_dtype, (x, y)) + + self.assertEqual(result.dtype, torch.float16) + expected = (x @ y).to(torch.float16) + torch.testing.assert_close(result, expected, atol=1e-2, rtol=1e-2) + self.assertIn("out_dtype=tl.float16", code) + def test_torch_matmul_3d(self): @helion.kernel(static_shapes=True) def bmm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
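
For illustration (outside the diff), a minimal sketch of the dtype-resolution rule this patch applies in both the reference implementation and the fake-tensor path: an explicit out_dtype overrides the default promotion, otherwise _compute_out_dtype decides. The _compute_out_dtype_stub below is a hypothetical stand-in for the real helper in helion/language/matmul_ops.py and only mirrors the common cases.

import torch

def _compute_out_dtype_stub(
    lhs: torch.dtype,
    rhs: torch.dtype,
    acc: torch.dtype | None = None,
) -> torch.dtype:
    # Hypothetical stand-in: an explicit accumulator dtype wins; int8 inputs
    # accumulate in int32; low-precision float inputs promote to float32.
    if acc is not None:
        return acc
    if lhs == torch.int8 and rhs == torch.int8:
        return torch.int32
    return torch.float32

def resolve_out_dtype(
    mat1: torch.Tensor,
    mat2: torch.Tensor,
    acc: torch.Tensor | None = None,
    out_dtype: torch.dtype | None = None,
) -> torch.dtype:
    # The rule added by this patch: out_dtype, when provided, takes priority
    # over the promotion-based default.
    return out_dtype or _compute_out_dtype_stub(
        mat1.dtype, mat2.dtype, None if acc is None else acc.dtype
    )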
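
Similarly, a sketch of the updated accumulator-fusion predicate in emit_tl_dot_with_padding: the accumulator is fused into tl.dot(acc=...) only when its dtype is one tl.dot can accumulate in and no conflicting out_dtype was requested; otherwise the dot result is cast and added to the accumulator afterwards. can_fuse_acc is a hypothetical name used here for illustration.

import torch

def can_fuse_acc(
    acc_dtype: torch.dtype | None,
    common_dtype: torch.dtype,
    out_dtype: torch.dtype | None,
) -> bool:
    # Mirrors the new fuse_acc condition: fusion requires an accumulator whose
    # dtype tl.dot supports for accumulation (the operands' common dtype or
    # float32) and, with this patch, an out_dtype that matches it (or None).
    return (
        acc_dtype is not None
        and acc_dtype in (common_dtype, torch.float32)
        and (out_dtype is None or out_dtype == acc_dtype)
    )

For example, with float16 operands and a float16 accumulator, requesting out_dtype=torch.float32 now takes the unfused path, so the requested dot output dtype is honored before the accumulation.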