diff --git a/helion/_compiler/host_function.py b/helion/_compiler/host_function.py
index 6b9ff44d8..4238bc3ad 100644
--- a/helion/_compiler/host_function.py
+++ b/helion/_compiler/host_function.py
@@ -24,6 +24,7 @@
 from .output_header import SOURCE_MODULE
 from .source_location import SourceLocation
 from .source_location import UnknownLocation
+from .tensor_utils import patch_tensor_factories
 from .type_printer import print_ast
 from .variable_origin import AttributeOrigin
 from .variable_origin import GlobalOrigin
@@ -112,7 +113,8 @@ def __init__(
         unroll_static_loops(self)
         propagate_types(self)
         env.finalize_config_spec()
-        self.device_ir = lower_to_device_ir(self)
+        with patch_tensor_factories():
+            self.device_ir = lower_to_device_ir(self)
 
     @staticmethod
     def validate_ast(root: ast.FunctionDef) -> None:
diff --git a/helion/_compiler/tensor_utils.py b/helion/_compiler/tensor_utils.py
new file mode 100644
index 000000000..0c0cc5035
--- /dev/null
+++ b/helion/_compiler/tensor_utils.py
@@ -0,0 +1,58 @@
+from __future__ import annotations
+
+from typing import Callable
+from typing import ClassVar
+
+import torch
+from torch.utils._python_dispatch import TorchDispatchMode
+from torch.utils._pytree import tree_map
+from triton import next_power_of_2
+
+
+class _PadTensorFactoryMode(TorchDispatchMode):
+    """Dispatch mode that pads tensor factory size arguments."""
+
+    _SIZE_ARG_INDEX: ClassVar[dict[Callable[..., torch.Tensor], int]] = {
+        torch.ops.aten.zeros.default: 0,  # pyright: ignore[reportAttributeAccessIssue]
+        torch.ops.aten.ones.default: 0,  # pyright: ignore[reportAttributeAccessIssue]
+        torch.ops.aten.empty.memory_format: 0,  # pyright: ignore[reportAttributeAccessIssue]
+        torch.ops.aten.full.default: 0,  # pyright: ignore[reportAttributeAccessIssue]
+        torch.ops.aten.new_empty.default: 1,  # pyright: ignore[reportAttributeAccessIssue]
+        torch.ops.aten.new_full.default: 1,  # pyright: ignore[reportAttributeAccessIssue]
+        torch.ops.aten.new_zeros.default: 1,  # pyright: ignore[reportAttributeAccessIssue]
+        torch.ops.aten.new_ones.default: 1,  # pyright: ignore[reportAttributeAccessIssue]
+    }
+
+    def __torch_dispatch__(
+        self,
+        func: Callable[..., torch.Tensor],
+        types: tuple[type, ...],
+        args: tuple[object, ...] = (),
+        kwargs: dict[str, object] | None = None,
+    ) -> torch.Tensor:
+        def _pad_shape(shape: object) -> object:
+            """Pad positive integer dimension sizes to the next power of 2."""
+
+            def _pad_dim(dim_size: object) -> object:
+                if isinstance(dim_size, int) and dim_size > 0:
+                    return next_power_of_2(dim_size)
+                return dim_size
+
+            return tree_map(_pad_dim, shape)
+
+        kwargs = dict(kwargs or {})
+        size_index = self._SIZE_ARG_INDEX.get(func)
+        if size_index is not None:
+            if "size" in kwargs:
+                kwargs["size"] = _pad_shape(kwargs["size"])
+            elif size_index < len(args):
+                args_list = list(args)
+                args_list[size_index] = _pad_shape(args_list[size_index])
+                args = tuple(args_list)
+        return func(*args, **kwargs)
+
+
+patch_tensor_factories = _PadTensorFactoryMode
+
+
+__all__ = ["patch_tensor_factories"]
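
Reviewer note (not part of the patch): `_PadTensorFactoryMode` intercepts the aten factory ops listed in `_SIZE_ARG_INDEX` and rounds each positive integer dimension up via `triton.next_power_of_2`, so a size-56 allocation made under the mode comes back with size 64; dtype, device, and fill value pass through untouched. A minimal sketch of the intended behavior, assuming a working helion checkout:

```python
# Sketch only: exercises the dispatch table above.
import torch
from triton import next_power_of_2
from helion._compiler.tensor_utils import patch_tensor_factories

assert next_power_of_2(56) == 64

with patch_tensor_factories():
    t = torch.zeros(56)           # aten.zeros.default: size is positional arg 0
    u = t.new_full((3, 56), 1.0)  # aten.new_full.default: size is positional arg 1

print(t.shape)  # torch.Size([64])
print(u.shape)  # torch.Size([4, 64]) -- each dim is padded independently
```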
diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py
index 3a408f438..635731886 100644
--- a/helion/_compiler/type_propagation.py
+++ b/helion/_compiler/type_propagation.py
@@ -42,6 +42,7 @@
 from .host_function import SymbolOrigin
 from .output_header import library_imports
 from .source_location import current_location
+from .tensor_utils import patch_tensor_factories
 from .utils import compute_slice_size
 from .variable_origin import ArgumentOrigin
 from .variable_origin import AttributeOrigin
@@ -1042,7 +1043,8 @@ def proxy(self) -> object:
             torch._C._TorchDispatchModeKey.FAKE  # pyright: ignore[reportAttributeAccessIssue]
         )
         try:
-            return Tile(self.block_id)
+            with torch._C._DisableTorchDispatch():  # pyright: ignore[reportAttributeAccessIssue]
+                return Tile(self.block_id)
         finally:
             assert fake_mode is not None
             torch._C._set_dispatch_mode(fake_mode)  # pyright: ignore[reportAttributeAccessIssue]
@@ -2191,12 +2193,18 @@ def visit_For(self, node: ast.For) -> TypeInfo:
                 raise exc.NestedGridLoop
 
             self.device_loop_depth += device_loop
-        body = self._loop_body(node.body)
-        with self.swap_scope(body):
-            # second pass for fixed point
-            body.merge(self._loop_body(node.body))
-        orelse = self._body(node.orelse)
-        self.scope.merge_if_else(body, orelse)
+        _maybe_patch_tensor_factories = (
+            patch_tensor_factories
+            if self.device_loop_depth > 0
+            else contextlib.nullcontext
+        )
+        with _maybe_patch_tensor_factories():
+            body = self._loop_body(node.body)
+            with self.swap_scope(body):
+                # second pass for fixed point
+                body.merge(self._loop_body(node.body))
+            orelse = self._body(node.orelse)
+            self.scope.merge_if_else(body, orelse)
         self.device_loop_depth -= device_loop
         return NoType(origin=self.origin())
 
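Reviewer note (not part of the patch): two things stand out in the hunks above. `Tile` construction is wrapped in `_DisableTorchDispatch`, presumably so that no active dispatch mode (including the new padding mode) intercepts it; and `visit_For` enables the padding mode only while type-propagating a loop body at device-loop depth > 0, so host-level statements keep their exact sizes. The guard reduces to this pattern:

```python
# Condensed restatement of the visit_For guard, assuming the import above.
import contextlib

from helion._compiler.tensor_utils import patch_tensor_factories

def factory_context(device_loop_depth: int):
    # Inside a device loop: factory sizes are padded to powers of 2.
    # Host path: nullcontext leaves torch.zeros / Tensor.new_zeros untouched.
    cm = patch_tensor_factories if device_loop_depth > 0 else contextlib.nullcontext
    return cm()
```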
diff --git a/test/test_specialize.expected b/test/test_specialize.expected
index 5fd75107f..53052fc0a 100644
--- a/test/test_specialize.expected
+++ b/test/test_specialize.expected
@@ -185,6 +185,42 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher):
     _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, out, x.size(0), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestSpecialize.test_hl_zeros_specialize_non_power_of_2)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_reduce_kernel(x, grad_weight, x_size_0, grad_weight_stride_0, grad_weight_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_3 < 56
+    grad_w_m = tl.full([64], 0.0, tl.float32)
+    tile_end = tl.minimum(offset_0 + _BLOCK_SIZE_0, x_size_0)
+    for offset_2 in tl.range(offset_0.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < tile_end
+        grad_w_m_copy = grad_w_m
+        grad_w_m_copy_0 = grad_w_m_copy
+        load = tl.load(x + (indices_2[:, None] * x_stride_0 + indices_3[None, :] * x_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
+        sum_1 = tl.cast(tl.sum(load, 0), tl.float32)
+        grad_w_m = grad_w_m_copy_0 + sum_1
+    tile_id = offset_0 // _BLOCK_SIZE_0
+    tl.store(grad_weight + (tile_id * grad_weight_stride_0 + indices_3 * grad_weight_stride_1), grad_w_m, mask_2)
+
+def reduce_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
+    m_block = 32
+    grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, x.size(1)], dtype=torch.float32)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 64
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_reduce_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_weight, x.size(0), grad_weight.stride(0), grad_weight.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return grad_weight.sum(0).to(x.dtype)
+
 --- assertExpectedJournal(TestSpecialize.test_specialize_host)
 from __future__ import annotations
 
@@ -270,3 +306,639 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher):
     _BLOCK_SIZE_0_1 = 32
     _launcher(_helion_fn, (triton.cdiv(x.size(0) * x.size(1), _BLOCK_SIZE_0_1), 1, 1), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), scale, _BLOCK_SIZE_0_1, num_warps=4, num_stages=3)
     return out
+
+--- assertExpectedJournal(TestSpecialize.test_tensor_factory_specialize_non_power_of_2)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+import test.test_specialize as _source_module
+
+@triton.jit
+def _helion_reduce_kernel(x, grad_weight, x_size_0, grad_weight_stride_0, grad_weight_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_3 < 56
+    grad_w_m = tl.full([64], 0, tl.float32)
+    full_1 = tl.full([64], 0, tl.float32)
+    v_0 = grad_w_m * full_1
+    tile_end = tl.minimum(offset_0 + _BLOCK_SIZE_0, x_size_0)
+    for offset_2 in tl.range(offset_0.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < tile_end
+        v_0_copy = v_0
+        v_0_copy_0 = v_0_copy
+        load = tl.load(x + (indices_2[:, None] * x_stride_0 + indices_3[None, :] * x_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
+        sum_1 = tl.cast(tl.sum(load, 0), tl.float32)
+        v_0 = v_0_copy_0 + sum_1
+    tile_id = offset_0 // _BLOCK_SIZE_0
+    tl.store(grad_weight + (tile_id * grad_weight_stride_0 + indices_3 * grad_weight_stride_1), v_0, mask_2)
+
+def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_default_launcher):
+    m_block = 32
+    grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, x.size(1)], dtype=torch.float32)
+    weight_shape = 56
+    if test_host:
+        host_buffer = tensor_factory_fn(x, weight_shape, dtype=torch.float32)
+        assert host_buffer.size(0) == 56
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 64
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_reduce_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_weight, x.size(0), grad_weight.stride(0), grad_weight.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return grad_weight.sum(0).to(x.dtype)
+
+--- assertExpectedJournal(TestSpecialize.test_tensor_factory_specialize_non_power_of_2)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+import test.test_specialize as _source_module
+
+@triton.jit
+def _helion_reduce_kernel(x, grad_weight, x_size_0, grad_weight_stride_0, grad_weight_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_3 < 56
+    grad_w_m = tl.full([64], 1, tl.float32)
+    full_1 = tl.full([64], 0, tl.float32)
+    v_0 = grad_w_m * full_1
+    tile_end = tl.minimum(offset_0 + _BLOCK_SIZE_0, x_size_0)
+    for offset_2 in tl.range(offset_0.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < tile_end
+        v_0_copy = v_0
+        v_0_copy_0 = v_0_copy
+        load = tl.load(x + (indices_2[:, None] * x_stride_0 + indices_3[None, :] * x_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
+        sum_1 = tl.cast(tl.sum(load, 0), tl.float32)
+        v_0 = v_0_copy_0 + sum_1
+    tile_id = offset_0 // _BLOCK_SIZE_0
+    tl.store(grad_weight + (tile_id * grad_weight_stride_0 + indices_3 * grad_weight_stride_1), v_0, mask_2)
+
+def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_default_launcher):
+    m_block = 32
+    grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, x.size(1)], dtype=torch.float32)
+    weight_shape = 56
+    if test_host:
+        host_buffer = tensor_factory_fn(x, weight_shape, dtype=torch.float32)
+        assert host_buffer.size(0) == 56
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 64
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_reduce_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_weight, x.size(0), grad_weight.stride(0), grad_weight.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return grad_weight.sum(0).to(x.dtype)
+
+--- assertExpectedJournal(TestSpecialize.test_tensor_factory_specialize_non_power_of_2)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+import test.test_specialize as _source_module
+
+@triton.jit
+def _helion_reduce_kernel(x, grad_weight, x_size_0, grad_weight_stride_0, grad_weight_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_3 < 56
+    v_0 = 0.0
+    full = tl.full([64], 0, tl.float32)
+    v_1 = v_0 * full
+    tile_end = tl.minimum(offset_0 + _BLOCK_SIZE_0, x_size_0)
+    for offset_2 in tl.range(offset_0.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < tile_end
+        v_1_copy = v_1
+        v_1_copy_0 = v_1_copy
+        load = tl.load(x + (indices_2[:, None] * x_stride_0 + indices_3[None, :] * x_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
+        sum_1 = tl.cast(tl.sum(load, 0), tl.float32)
+        v_1 = v_1_copy_0 + sum_1
+    tile_id = offset_0 // _BLOCK_SIZE_0
+    tl.store(grad_weight + (tile_id * grad_weight_stride_0 + indices_3 * grad_weight_stride_1), v_1, mask_2)
+
+def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_default_launcher):
+    m_block = 32
+    grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, x.size(1)], dtype=torch.float32)
+    weight_shape = 56
+    if test_host:
+        host_buffer = tensor_factory_fn(x, weight_shape, dtype=torch.float32)
+        assert host_buffer.size(0) == 56
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 64
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_reduce_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_weight, x.size(0), grad_weight.stride(0), grad_weight.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return grad_weight.sum(0).to(x.dtype)
+
+--- assertExpectedJournal(TestSpecialize.test_tensor_factory_specialize_non_power_of_2)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+import test.test_specialize as _source_module
+
+@triton.jit
+def _helion_reduce_kernel(x, grad_weight, x_size_0, grad_weight_stride_0, grad_weight_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_3 < 56
+    grad_w_m = tl.full([64], 1.0, tl.float32)
+    full_1 = tl.full([64], 0, tl.float32)
+    v_0 = grad_w_m * full_1
+    tile_end = tl.minimum(offset_0 + _BLOCK_SIZE_0, x_size_0)
+    for offset_2 in tl.range(offset_0.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < tile_end
+        v_0_copy = v_0
+        v_0_copy_0 = v_0_copy
+        load = tl.load(x + (indices_2[:, None] * x_stride_0 + indices_3[None, :] * x_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
+        sum_1 = tl.cast(tl.sum(load, 0), tl.float32)
+        v_0 = v_0_copy_0 + sum_1
+    tile_id = offset_0 // _BLOCK_SIZE_0
+    tl.store(grad_weight + (tile_id * grad_weight_stride_0 + indices_3 * grad_weight_stride_1), v_0, mask_2)
+
+def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_default_launcher):
+    m_block = 32
+    grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, x.size(1)], dtype=torch.float32)
+    weight_shape = 56
+    if test_host:
+        host_buffer = tensor_factory_fn(x, weight_shape, dtype=torch.float32)
+        assert host_buffer.size(0) == 56
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 64
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_reduce_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_weight, x.size(0), grad_weight.stride(0), grad_weight.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return grad_weight.sum(0).to(x.dtype)
+
+--- assertExpectedJournal(TestSpecialize.test_tensor_factory_specialize_non_power_of_2)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_reduce_kernel(x, grad_weight, x_size_0, grad_weight_stride_0, grad_weight_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_3 < 56
+    grad_w_m = tl.full([64], 0, tl.float32)
+    full_1 = tl.full([64], 0, tl.float32)
+    v_0 = grad_w_m * full_1
+    tile_end = tl.minimum(offset_0 + _BLOCK_SIZE_0, x_size_0)
+    for offset_2 in tl.range(offset_0.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < tile_end
+        v_0_copy = v_0
+        v_0_copy_0 = v_0_copy
+        load = tl.load(x + (indices_2[:, None] * x_stride_0 + indices_3[None, :] * x_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
+        sum_1 = tl.cast(tl.sum(load, 0), tl.float32)
+        v_0 = v_0_copy_0 + sum_1
+    tile_id = offset_0 // _BLOCK_SIZE_0
+    tl.store(grad_weight + (tile_id * grad_weight_stride_0 + indices_3 * grad_weight_stride_1), v_0, mask_2)
+
+def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_default_launcher):
+    m_block = 32
+    grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, x.size(1)], dtype=torch.float32)
+    weight_shape = 56
+    if test_host:
+        host_buffer = tensor_factory_fn(x, weight_shape, dtype=torch.float32)
+        assert host_buffer.size(0) == 56
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 64
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_reduce_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_weight, x.size(0), grad_weight.stride(0), grad_weight.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return grad_weight.sum(0).to(x.dtype)
+
+--- assertExpectedJournal(TestSpecialize.test_tensor_factory_specialize_non_power_of_2)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_reduce_kernel(x, grad_weight, x_size_0, grad_weight_stride_0, grad_weight_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_3 < 56
+    grad_w_m = tl.full([64], 1, tl.float32)
+    full_1 = tl.full([64], 0, tl.float32)
+    v_0 = grad_w_m * full_1
+    tile_end = tl.minimum(offset_0 + _BLOCK_SIZE_0, x_size_0)
+    for offset_2 in tl.range(offset_0.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < tile_end
+        v_0_copy = v_0
+        v_0_copy_0 = v_0_copy
+        load = tl.load(x + (indices_2[:, None] * x_stride_0 + indices_3[None, :] * x_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
+        sum_1 = tl.cast(tl.sum(load, 0), tl.float32)
+        v_0 = v_0_copy_0 + sum_1
+    tile_id = offset_0 // _BLOCK_SIZE_0
+    tl.store(grad_weight + (tile_id * grad_weight_stride_0 + indices_3 * grad_weight_stride_1), v_0, mask_2)
+
+def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_default_launcher):
+    m_block = 32
+    grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, x.size(1)], dtype=torch.float32)
+    weight_shape = 56
+    if test_host:
+        host_buffer = tensor_factory_fn(x, weight_shape, dtype=torch.float32)
+        assert host_buffer.size(0) == 56
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 64
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_reduce_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_weight, x.size(0), grad_weight.stride(0), grad_weight.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return grad_weight.sum(0).to(x.dtype)
+
+--- assertExpectedJournal(TestSpecialize.test_tensor_factory_specialize_non_power_of_2)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_reduce_kernel(x, grad_weight, x_size_0, grad_weight_stride_0, grad_weight_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_3 < 56
+    v_0 = 0.0
+    full = tl.full([64], 0, tl.float32)
+    v_1 = v_0 * full
+    tile_end = tl.minimum(offset_0 + _BLOCK_SIZE_0, x_size_0)
+    for offset_2 in tl.range(offset_0.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < tile_end
+        v_1_copy = v_1
+        v_1_copy_0 = v_1_copy
+        load = tl.load(x + (indices_2[:, None] * x_stride_0 + indices_3[None, :] * x_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
+        sum_1 = tl.cast(tl.sum(load, 0), tl.float32)
+        v_1 = v_1_copy_0 + sum_1
+    tile_id = offset_0 // _BLOCK_SIZE_0
+    tl.store(grad_weight + (tile_id * grad_weight_stride_0 + indices_3 * grad_weight_stride_1), v_1, mask_2)
+
+def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_default_launcher):
+    m_block = 32
+    grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, x.size(1)], dtype=torch.float32)
+    weight_shape = 56
+    if test_host:
+        host_buffer = tensor_factory_fn(x, weight_shape, dtype=torch.float32)
+        assert host_buffer.size(0) == 56
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 64
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_reduce_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_weight, x.size(0), grad_weight.stride(0), grad_weight.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return grad_weight.sum(0).to(x.dtype)
+
+--- assertExpectedJournal(TestSpecialize.test_tensor_factory_specialize_non_power_of_2)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_reduce_kernel(x, grad_weight, x_size_0, grad_weight_stride_0, grad_weight_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_3 < 56
+    grad_w_m = tl.full([64], 1.0, tl.float32)
+    full_1 = tl.full([64], 0, tl.float32)
+    v_0 = grad_w_m * full_1
+    tile_end = tl.minimum(offset_0 + _BLOCK_SIZE_0, x_size_0)
+    for offset_2 in tl.range(offset_0.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < tile_end
+        v_0_copy = v_0
+        v_0_copy_0 = v_0_copy
+        load = tl.load(x + (indices_2[:, None] * x_stride_0 + indices_3[None, :] * x_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
+        sum_1 = tl.cast(tl.sum(load, 0), tl.float32)
+        v_0 = v_0_copy_0 + sum_1
+    tile_id = offset_0 // _BLOCK_SIZE_0
+    tl.store(grad_weight + (tile_id * grad_weight_stride_0 + indices_3 * grad_weight_stride_1), v_0, mask_2)
+
+def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_default_launcher):
+    m_block = 32
+    grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, x.size(1)], dtype=torch.float32)
+    weight_shape = 56
+    if test_host:
+        host_buffer = tensor_factory_fn(x, weight_shape, dtype=torch.float32)
+        assert host_buffer.size(0) == 56
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 64
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_reduce_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_weight, x.size(0), grad_weight.stride(0), grad_weight.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return grad_weight.sum(0).to(x.dtype)
+
+--- assertExpectedJournal(TestSpecialize.test_tensor_factory_specialize_non_power_of_2)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+import test.test_specialize as _source_module
+
+@triton.jit
+def _helion_reduce_kernel(x, grad_weight, x_size_0, grad_weight_stride_0, grad_weight_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_3 < 56
+    grad_w_m = tl.full([64], 0.0, tl.float32)
+    full_1 = tl.full([64], 0, tl.float32)
+    v_0 = grad_w_m * full_1
+    tile_end = tl.minimum(offset_0 + _BLOCK_SIZE_0, x_size_0)
+    for offset_2 in tl.range(offset_0.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < tile_end
+        v_0_copy = v_0
+        v_0_copy_0 = v_0_copy
+        load = tl.load(x + (indices_2[:, None] * x_stride_0 + indices_3[None, :] * x_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
+        sum_1 = tl.cast(tl.sum(load, 0), tl.float32)
+        v_0 = v_0_copy_0 + sum_1
+    tile_id = offset_0 // _BLOCK_SIZE_0
+    tl.store(grad_weight + (tile_id * grad_weight_stride_0 + indices_3 * grad_weight_stride_1), v_0, mask_2)
+
+def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_default_launcher):
+    m_block = 32
+    grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, x.size(1)], dtype=torch.float32)
+    weight_shape = 56
+    if test_host:
+        host_buffer = tensor_factory_fn(x, weight_shape, dtype=torch.float32)
+        assert host_buffer.size(0) == 56
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 64
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_reduce_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_weight, x.size(0), grad_weight.stride(0), grad_weight.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return grad_weight.sum(0).to(x.dtype)
+
+--- assertExpectedJournal(TestSpecialize.test_tensor_factory_specialize_non_power_of_2)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+import test.test_specialize as _source_module
+
+@triton.jit
+def _helion_reduce_kernel(x, grad_weight, x_size_0, grad_weight_stride_0, grad_weight_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_3 < 56
+    grad_w_m = tl.full([64], 1.0, tl.float32)
+    full_1 = tl.full([64], 0, tl.float32)
+    v_0 = grad_w_m * full_1
+    tile_end = tl.minimum(offset_0 + _BLOCK_SIZE_0, x_size_0)
+    for offset_2 in tl.range(offset_0.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < tile_end
+        v_0_copy = v_0
+        v_0_copy_0 = v_0_copy
+        load = tl.load(x + (indices_2[:, None] * x_stride_0 + indices_3[None, :] * x_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
+        sum_1 = tl.cast(tl.sum(load, 0), tl.float32)
+        v_0 = v_0_copy_0 + sum_1
+    tile_id = offset_0 // _BLOCK_SIZE_0
+    tl.store(grad_weight + (tile_id * grad_weight_stride_0 + indices_3 * grad_weight_stride_1), v_0, mask_2)
+
+def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_default_launcher):
+    m_block = 32
+    grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, x.size(1)], dtype=torch.float32)
+    weight_shape = 56
+    if test_host:
+        host_buffer = tensor_factory_fn(x, weight_shape, dtype=torch.float32)
+        assert host_buffer.size(0) == 56
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 64
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_reduce_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_weight, x.size(0), grad_weight.stride(0), grad_weight.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return grad_weight.sum(0).to(x.dtype)
+
+--- assertExpectedJournal(TestSpecialize.test_tensor_factory_specialize_non_power_of_2)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+import test.test_specialize as _source_module
+
+@triton.jit
+def _helion_reduce_kernel(x, grad_weight, x_size_0, grad_weight_stride_0, grad_weight_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_3 < 56
+    grad_w_m = tl.full([64], 1.0, tl.float32)
+    full_1 = tl.full([64], 0, tl.float32)
+    v_0 = grad_w_m * full_1
+    tile_end = tl.minimum(offset_0 + _BLOCK_SIZE_0, x_size_0)
+    for offset_2 in tl.range(offset_0.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < tile_end
+        v_0_copy = v_0
+        v_0_copy_0 = v_0_copy
+        load = tl.load(x + (indices_2[:, None] * x_stride_0 + indices_3[None, :] * x_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
+        sum_1 = tl.cast(tl.sum(load, 0), tl.float32)
+        v_0 = v_0_copy_0 + sum_1
+    tile_id = offset_0 // _BLOCK_SIZE_0
+    tl.store(grad_weight + (tile_id * grad_weight_stride_0 + indices_3 * grad_weight_stride_1), v_0, mask_2)
+
+def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_default_launcher):
+    m_block = 32
+    grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, x.size(1)], dtype=torch.float32)
+    weight_shape = 56
+    if test_host:
+        host_buffer = tensor_factory_fn(x, weight_shape, dtype=torch.float32)
+        assert host_buffer.size(0) == 56
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 64
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_reduce_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_weight, x.size(0), grad_weight.stride(0), grad_weight.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return grad_weight.sum(0).to(x.dtype)
+
+--- assertExpectedJournal(TestSpecialize.test_tensor_new_zeros_specialize_non_power_of_2)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_reduce_kernel(x, grad_weight, x_size_0, grad_weight_stride_0, grad_weight_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_3 < 56
+    grad_w_m = tl.full([64], 0, tl.float32)
+    tile_end = tl.minimum(offset_0 + _BLOCK_SIZE_0, x_size_0)
+    for offset_2 in tl.range(offset_0.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < tile_end
+        grad_w_m_copy = grad_w_m
+        grad_w_m_copy_0 = grad_w_m_copy
+        load = tl.load(x + (indices_2[:, None] * x_stride_0 + indices_3[None, :] * x_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
+        sum_1 = tl.cast(tl.sum(load, 0), tl.float32)
+        grad_w_m = grad_w_m_copy_0 + sum_1
+    tile_id = offset_0 // _BLOCK_SIZE_0
+    tl.store(grad_weight + (tile_id * grad_weight_stride_0 + indices_3 * grad_weight_stride_1), grad_w_m, mask_2)
+
+def reduce_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
+    m_block = 32
+    grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, x.size(1)], dtype=torch.float32)
+    weight_shape = 56
+    host_buffer = x.new_zeros(weight_shape, dtype=torch.float32)
+    assert host_buffer.size(0) == 56
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 64
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_reduce_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_weight, x.size(0), grad_weight.stride(0), grad_weight.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return grad_weight.sum(0).to(x.dtype)
+
+--- assertExpectedJournal(TestSpecialize.test_torch_zeros_specialize_non_power_of_2)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_reduce_kernel(x, grad_weight, x_size_0, grad_weight_stride_0, grad_weight_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_3 < 56
+    grad_w_m = tl.full([64], 0, tl.float32)
+    tile_end = tl.minimum(offset_0 + _BLOCK_SIZE_0, x_size_0)
+    for offset_2 in tl.range(offset_0.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < tile_end
+        grad_w_m_copy = grad_w_m
+        grad_w_m_copy_0 = grad_w_m_copy
+        load = tl.load(x + (indices_2[:, None] * x_stride_0 + indices_3[None, :] * x_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
+        sum_1 = tl.cast(tl.sum(load, 0), tl.float32)
+        grad_w_m = grad_w_m_copy_0 + sum_1
+    tile_id = offset_0 // _BLOCK_SIZE_0
+    tl.store(grad_weight + (tile_id * grad_weight_stride_0 + indices_3 * grad_weight_stride_1), grad_w_m, mask_2)
+
+def reduce_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
+    m_block = 32
+    grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, x.size(1)], dtype=torch.float32)
+    weight_shape = 56
+    host_buffer = x.new_zeros(weight_shape, dtype=torch.float32)
+    assert host_buffer.size(0) == 56
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 64
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_reduce_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_weight, x.size(0), grad_weight.stride(0), grad_weight.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return grad_weight.sum(0).to(x.dtype)
+
+--- assertExpectedJournal(TestSpecialize.test_zeros_specialize_non_power_of_2)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_reduce_kernel(x, grad_weight, x_size_0, grad_weight_stride_0, grad_weight_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_3 < 56
+    grad_w_m = tl.full([64], 0, tl.float32)
+    tile_end = tl.minimum(offset_0 + _BLOCK_SIZE_0, x_size_0)
+    for offset_2 in tl.range(offset_0.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < tile_end
+        grad_w_m_copy = grad_w_m
+        grad_w_m_copy_0 = grad_w_m_copy
+        load = tl.load(x + (indices_2[:, None] * x_stride_0 + indices_3[None, :] * x_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
+        sum_1 = tl.cast(tl.sum(load, 0), tl.float32)
+        grad_w_m = grad_w_m_copy_0 + sum_1
+    tile_id = offset_0 // _BLOCK_SIZE_0
+    tl.store(grad_weight + (tile_id * grad_weight_stride_0 + indices_3 * grad_weight_stride_1), grad_w_m, mask_2)
+
+def reduce_kernel(x: torch.Tensor, zeros_factory, test_host, *, _launcher=_default_launcher):
+    m_block = 32
+    grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, x.size(1)], dtype=torch.float32)
+    weight_shape = 56
+    if test_host:
+        host_buffer = zeros_factory(x, weight_shape, dtype=torch.float32)
+        assert host_buffer.size(0) == 56
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 64
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_reduce_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_weight, x.size(0), grad_weight.stride(0), grad_weight.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return grad_weight.sum(0).to(x.dtype)
+
+--- assertExpectedJournal(TestSpecialize.test_zeros_specialize_non_power_of_2)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+import test.test_specialize as _source_module
+
+@triton.jit
+def _helion_reduce_kernel(x, grad_weight, x_size_0, grad_weight_stride_0, grad_weight_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_3 < 56
+    grad_w_m = tl.full([64], 0.0, tl.float32)
+    tile_end = tl.minimum(offset_0 + _BLOCK_SIZE_0, x_size_0)
+    for offset_2 in tl.range(offset_0.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < tile_end
+        grad_w_m_copy = grad_w_m
+        grad_w_m_copy_0 = grad_w_m_copy
+        load = tl.load(x + (indices_2[:, None] * x_stride_0 + indices_3[None, :] * x_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
+        sum_1 = tl.cast(tl.sum(load, 0), tl.float32)
+        grad_w_m = grad_w_m_copy_0 + sum_1
+    tile_id = offset_0 // _BLOCK_SIZE_0
+    tl.store(grad_weight + (tile_id * grad_weight_stride_0 + indices_3 * grad_weight_stride_1), grad_w_m, mask_2)
+
+def reduce_kernel(x: torch.Tensor, zeros_factory, test_host, *, _launcher=_default_launcher):
+    m_block = 32
+    grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, x.size(1)], dtype=torch.float32)
+    weight_shape = 56
+    if test_host:
+        host_buffer = zeros_factory(x, weight_shape, dtype=torch.float32)
+        assert host_buffer.size(0) == 56
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 64
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_reduce_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_weight, x.size(0), grad_weight.stride(0), grad_weight.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return grad_weight.sum(0).to(x.dtype)
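
Reviewer note (not part of the patch): the journals above differ only in the generated initializer — `tl.full([64], 0, ...)` for the zeros factories, `1`/`1.0` for the ones/full variants, and the `v_0 = 0.0` form for the normalized empty case. In every journal the device-side buffer is padded from the specialized size 56 up to 64 and masked with `indices_3 < 56`, while the host-side allocation keeps its exact size 56.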
diff --git a/test/test_specialize.py b/test/test_specialize.py
index 6dc9e7e2e..7a8487cfb 100644
--- a/test/test_specialize.py
+++ b/test/test_specialize.py
@@ -225,6 +225,65 @@ def fn(
         )
         self.assertExpectedJournal(code)
 
+    def test_tensor_factory_specialize_non_power_of_2(self):
+        def _test_with_factory(factory_fn, test_host=True):
+            @helion.kernel()
+            def reduce_kernel(
+                x: torch.Tensor, tensor_factory_fn, test_host
+            ) -> torch.Tensor:
+                m_block = hl.register_block_size(x.size(0))
+                grad_weight = x.new_empty(
+                    [(x.size(0) + m_block - 1) // m_block, x.size(1)],
+                    dtype=torch.float32,
+                )
+                weight_shape = hl.specialize(x.size(1))
+                if test_host:
+                    # Host-side tensor creation should NOT be padded
+                    host_buffer = tensor_factory_fn(
+                        x, weight_shape, dtype=torch.float32
+                    )
+                    # Verify host-side tensor has correct non-padded size
+                    assert host_buffer.size(0) == 56
+                for mb_cta in hl.tile(x.size(0), block_size=m_block):
+                    # Device-side tensor creation should be padded to 64
+                    grad_w_m = tensor_factory_fn(x, weight_shape, dtype=torch.float32)
+                    # Set to 0 to normalize different factory functions
+                    grad_w_m = grad_w_m * grad_w_m.new_zeros(weight_shape)
+                    for mb in hl.tile(mb_cta.begin, mb_cta.end):
+                        grad_w_m += x[mb, :].to(torch.float32).sum(0)
+                    grad_weight[mb_cta.id, :] = grad_w_m
+                return grad_weight.sum(0).to(x.dtype)
+
+            x = torch.randn([128, 56], device=DEVICE, dtype=torch.float32)
+            code, result = code_and_output(reduce_kernel, (x, factory_fn, test_host))
+            reference = x.sum(0)
+            torch.testing.assert_close(result, reference, rtol=1e-3, atol=1e-3)
+            self.assertExpectedJournal(code)
+
+        for name in ["zeros", "ones", "empty"]:
+            _test_with_factory(
+                lambda x, s, factory_name=name, **kw: getattr(torch, factory_name)(
+                    s, device=x.device, **kw
+                )
+            )
+        _test_with_factory(
+            lambda x, s, **kw: torch.full([s], 1.0, device=x.device, **kw)
+        )
+
+        for name in ["zeros", "ones", "empty"]:
+            _test_with_factory(
+                lambda x, s, method_name=name, **kw: getattr(x, f"new_{method_name}")(
+                    s, **kw
+                ),
+                test_host=True,
+            )
+        _test_with_factory(
+            lambda x, s, **kw: x.new_full([s], 1.0, **kw), test_host=True
+        )
+
+        _test_with_factory(lambda x, s, **kw: hl.zeros([s], **kw), test_host=False)
+        _test_with_factory(lambda x, s, **kw: hl.full([s], 1.0, **kw), test_host=False)
+
     def test_specialize_reduce(self):
         @helion.kernel()
         def fn(
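
For orientation, a condensed standalone version of the kernel these tests exercise (hypothetical names, written against the test body above; assumes a CUDA device): `torch.zeros` over the specialized size 56 is padded to 64 when called inside the device loop, but not on the host.

```python
import torch
import helion
import helion.language as hl

@helion.kernel()
def reduce_cols(x: torch.Tensor) -> torch.Tensor:
    m_block = hl.register_block_size(x.size(0))
    partial = x.new_empty(
        [(x.size(0) + m_block - 1) // m_block, x.size(1)], dtype=torch.float32
    )
    n = hl.specialize(x.size(1))  # 56 for the test inputs
    host_buf = torch.zeros([n], dtype=torch.float32, device=x.device)  # stays size 56
    assert host_buf.size(0) == 56
    for mb_cta in hl.tile(x.size(0), block_size=m_block):
        acc = torch.zeros([n], dtype=torch.float32, device=x.device)  # padded to 64
        for mb in hl.tile(mb_cta.begin, mb_cta.end):
            acc += x[mb, :].to(torch.float32).sum(0)
        partial[mb_cta.id, :] = acc
    return partial.sum(0).to(x.dtype)

x = torch.randn([128, 56], device="cuda")
torch.testing.assert_close(reduce_cols(x), x.sum(0), rtol=1e-3, atol=1e-3)
```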