Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion helion/_compiler/host_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .output_header import SOURCE_MODULE
from .source_location import SourceLocation
from .source_location import UnknownLocation
from .tensor_utils import patch_tensor_factories
from .type_printer import print_ast
from .variable_origin import AttributeOrigin
from .variable_origin import GlobalOrigin
Expand Down Expand Up @@ -112,7 +113,8 @@ def __init__(
unroll_static_loops(self)
propagate_types(self)
env.finalize_config_spec()
self.device_ir = lower_to_device_ir(self)
with patch_tensor_factories():
self.device_ir = lower_to_device_ir(self)

@staticmethod
def validate_ast(root: ast.FunctionDef) -> None:
Expand Down
58 changes: 58 additions & 0 deletions helion/_compiler/tensor_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from __future__ import annotations

from typing import Callable
from typing import ClassVar

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map
from triton import next_power_of_2


class _PadTensorFactoryMode(TorchDispatchMode):
Copy link
Contributor Author

@yf225 yf225 Oct 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to use TorchDispatchMode to intercept tensor factory ops at the aten level and avoid monkey-patching.

"""Dispatch mode that pads tensor factory size arguments."""

_SIZE_ARG_INDEX: ClassVar[dict[Callable[..., torch.Tensor], int]] = {
torch.ops.aten.zeros.default: 0, # pyright: ignore[reportAttributeAccessIssue]
torch.ops.aten.ones.default: 0, # pyright: ignore[reportAttributeAccessIssue]
torch.ops.aten.empty.memory_format: 0, # pyright: ignore[reportAttributeAccessIssue]
torch.ops.aten.full.default: 0, # pyright: ignore[reportAttributeAccessIssue]
torch.ops.aten.new_empty.default: 1, # pyright: ignore[reportAttributeAccessIssue]
torch.ops.aten.new_full.default: 1, # pyright: ignore[reportAttributeAccessIssue]
torch.ops.aten.new_zeros.default: 1, # pyright: ignore[reportAttributeAccessIssue]
torch.ops.aten.new_ones.default: 1, # pyright: ignore[reportAttributeAccessIssue]
}

def __torch_dispatch__(
self,
func: Callable[..., torch.Tensor],
types: tuple[type, ...],
args: tuple[object, ...] = (),
kwargs: dict[str, object] | None = None,
) -> torch.Tensor:
def _pad_shape(shape: object) -> object:
"""Pad positive integer dimension sizes to the next power of 2."""

def _pad_dim(dim_size: object) -> object:
if isinstance(dim_size, int) and dim_size > 0:
return next_power_of_2(dim_size)
return dim_size

return tree_map(_pad_dim, shape)

kwargs = dict(kwargs or {})
size_index = self._SIZE_ARG_INDEX.get(func)
if size_index is not None:
if "size" in kwargs:
kwargs["size"] = _pad_shape(kwargs["size"])
elif size_index < len(args):
args_list = list(args)
args_list[size_index] = _pad_shape(args_list[size_index])
args = tuple(args_list)
return func(*args, **kwargs)


# Public alias: callers enter the mode as a context manager
# (``with patch_tensor_factories(): ...``); ``TorchDispatchMode`` subclasses
# support ``with``, so the class itself is the patching entry point.
patch_tensor_factories = _PadTensorFactoryMode


__all__ = ["patch_tensor_factories"]
22 changes: 15 additions & 7 deletions helion/_compiler/type_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .host_function import SymbolOrigin
from .output_header import library_imports
from .source_location import current_location
from .tensor_utils import patch_tensor_factories
from .utils import compute_slice_size
from .variable_origin import ArgumentOrigin
from .variable_origin import AttributeOrigin
Expand Down Expand Up @@ -1042,7 +1043,8 @@ def proxy(self) -> object:
torch._C._TorchDispatchModeKey.FAKE # pyright: ignore[reportAttributeAccessIssue]
)
try:
return Tile(self.block_id)
with torch._C._DisableTorchDispatch(): # pyright: ignore[reportAttributeAccessIssue]
return Tile(self.block_id)
finally:
assert fake_mode is not None
torch._C._set_dispatch_mode(fake_mode) # pyright: ignore[reportAttributeAccessIssue]
Expand Down Expand Up @@ -2191,12 +2193,18 @@ def visit_For(self, node: ast.For) -> TypeInfo:
raise exc.NestedGridLoop

self.device_loop_depth += device_loop
body = self._loop_body(node.body)
with self.swap_scope(body):
# second pass for fixed point
body.merge(self._loop_body(node.body))
orelse = self._body(node.orelse)
self.scope.merge_if_else(body, orelse)
_maybe_patch_tensor_factories = (
patch_tensor_factories
if self.device_loop_depth > 0
else contextlib.nullcontext
)
with _maybe_patch_tensor_factories():
body = self._loop_body(node.body)
with self.swap_scope(body):
# second pass for fixed point
body.merge(self._loop_body(node.body))
orelse = self._body(node.orelse)
self.scope.merge_if_else(body, orelse)
self.device_loop_depth -= device_loop
return NoType(origin=self.origin())

Expand Down
Loading
Loading