diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py
index 268bfa31c..a1a77afbf 100644
--- a/helion/_compiler/compile_environment.py
+++ b/helion/_compiler/compile_environment.py
@@ -611,7 +611,7 @@ def from_config(self, config: Config, block_size_info: BlockSizeInfo) -> int | N
             len(config.reduction_loops) <= self.reduction_loop
             or config.reduction_loops[self.reduction_loop] is None
         ):
-            return next_power_of_2(block_size_info.size_hint())
+            return max(1, next_power_of_2(block_size_info.size_hint()))
         return config.reduction_loops[self.reduction_loop]
diff --git a/helion/_compiler/reduction_strategy.py b/helion/_compiler/reduction_strategy.py
index e4856fac8..37ef84099 100644
--- a/helion/_compiler/reduction_strategy.py
+++ b/helion/_compiler/reduction_strategy.py
@@ -98,6 +98,15 @@ def call_reduction_function(
             return f"triton_helpers.prod({input_name}, {dim})"
         raise NotImplementedError(f"Unsupported reduction type: {reduction_type}")
 
+    def _index_init_expr(self, block_size_var: str, dtype: str, block_idx: int) -> str:
+        env = CompileEnvironment.current()
+        size = env.block_sizes[block_idx].size
+        if isinstance(size, int) and size == 0:
+            return f"tl.zeros([0], {dtype})"
+        if isinstance(size, torch.SymInt) and env.known_equal(size, 0):
+            return f"tl.zeros([0], {dtype})"
+        return f"tl.arange(0, {block_size_var}).to({dtype})"
+
     def call_argmin_argmax(
         self,
         input_name: str,
@@ -187,7 +196,7 @@ def codegen_preamble(self, state: CodegenState) -> None:
        )
        state.codegen.host_statements.append(stmt)
        state.add_statement(
-            f"{index_var} = tl.arange(0, {block_size_var}).to({env.triton_index_type()})"
+            f"{index_var} = {self._index_init_expr(block_size_var, env.triton_index_type(), block_idx)}"
        )
        if mask_var is not None:
            state.add_statement(
@@ -213,6 +222,15 @@ def codegen_reduction(
         fake_input: torch.Tensor,
         fake_output: torch.Tensor,
     ) -> ast.AST:
+        env = CompileEnvironment.current()
+        numel = env.block_sizes[self.block_index].numel
+        if isinstance(numel, sympy.Integer) and numel == 0:
+            default = ir.Reduction.default_accumulator(reduction_type, fake_input.dtype)
+            assert isinstance(default, (float, int, bool))
+            shape = self.fn.tile_strategy.shape_str([*fake_output.size()])
+            return expr_from_string(
+                f"tl.full({shape}, {constant_repr(default)}, {triton_type(fake_output.dtype)})"
+            )
         expr = self.call_reduction_function(
             input_name,
             reduction_type,
@@ -260,7 +278,7 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
        )
        body: list[ast.AST] = [
            statement_from_string(
-                f"{index_var} = {offset_var} + tl.arange(0, ({block_size_var})).to({env.triton_index_type()})"
+                f"{index_var} = {offset_var} + {self._index_init_expr(f'({block_size_var})', env.triton_index_type(), block_index)}"
            ),
        ]
        if (mask_var := self._mask_var) is not None:
@@ -395,6 +413,21 @@ def codegen_reduction(
     ) -> ast.AST:
         default = ir.Reduction.default_accumulator(reduction_type, fake_input.dtype)
         assert isinstance(default, (float, int, bool))
+        env = CompileEnvironment.current()
+        dim_size = fake_input.size(dim)
+        is_zero_dim = False
+        if (
+            isinstance(dim_size, int)
+            and dim_size == 0
+            or isinstance(dim_size, torch.SymInt)
+            and env.known_equal(dim_size, 0)
+        ):
+            is_zero_dim = True
+        if is_zero_dim:
+            shape = self.fn.tile_strategy.shape_str([*fake_output.size()])
+            return expr_from_string(
+                f"tl.full({shape}, {constant_repr(default)}, {triton_type(fake_output.dtype)})"
+            )
         expr = self.call_reduction_function(
             input_name,
             reduction_type,
diff --git a/helion/autotuner/config_spec.py b/helion/autotuner/config_spec.py
index bf7cf1270..29b07676f 100644
--- a/helion/autotuner/config_spec.py
+++ b/helion/autotuner/config_spec.py
@@ -364,9 +364,12 @@ def __init__(
         super().__init__([block_id])
         self.size_hint = size_hint
         self.min_size: int = min_size
+        bounded_hint = max(size_hint, 1)
         self.max_size: int = (
-            next_power_of_2(size_hint) if max_size is None else max_size
+            next_power_of_2(bounded_hint) if max_size is None else max_size
         )
+        if self.max_size < self.min_size:
+            self.max_size = self.min_size
         assert self.min_size <= self.max_size
 
     def __repr__(self) -> str:
@@ -388,11 +391,12 @@ def update_min(self, value: int) -> None:
             self.max_size = self.min_size
 
     def update_max(self, value: int) -> None:
-        self.max_size = assert_integer_power_of_two(min(value, self.max_size))
+        clamped = max(value, 1)
+        self.max_size = assert_integer_power_of_two(min(clamped, self.max_size))
 
     def update_hint(self, value: int) -> None:
         self.size_hint = value
-        self.update_max(next_power_of_2(value))
+        self.update_max(next_power_of_2(max(value, 1)))
 
     def _fragment(self, base: ConfigSpec) -> BlockSizeFragment:
         total_ndim = len(base.block_sizes)
diff --git a/test/test_zero_size.py b/test/test_zero_size.py
new file mode 100644
index 000000000..ff32b177f
--- /dev/null
+++ b/test/test_zero_size.py
@@ -0,0 +1,51 @@
+from __future__ import annotations
+
+import torch
+
+import helion
+from helion._testing import DEVICE
+from helion._testing import RefEagerTestBase
+from helion._testing import TestCase
+from helion._testing import code_and_output
+import helion.language as hl
+
+
+class TestZeroSizeTensors(RefEagerTestBase, TestCase):
+    def test_pointwise_zero_rows(self) -> None:
+        @helion.kernel(autotune_effort="none")
+        def pointwise_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size(0)):
+                out[tile, :] = x[tile, :] + y[tile, :]
+            return out
+
+        x = torch.randn([0, 32], device=DEVICE)
+        y = torch.randn([0, 32], device=DEVICE)
+        _, result = code_and_output(pointwise_add, (x, y))
+        torch.testing.assert_close(result, x + y)
+
+    def test_reduce_zero_inner_dim(self) -> None:
+        @helion.kernel(autotune_effort="none")
+        def row_sums(x: torch.Tensor) -> torch.Tensor:
+            out = torch.empty([x.size(0)], dtype=x.dtype, device=x.device)
+            for tile in hl.tile(x.size(0)):
+                rows = x[tile, :]
+                out[tile] = torch.sum(rows, dim=1)
+            return out
+
+        x = torch.randn([5, 0], device=DEVICE)
+        _, result = code_and_output(row_sums, (x,))
+        torch.testing.assert_close(result, torch.sum(x, dim=1))
+
+    def test_local_zero_width_allocation(self) -> None:
+        @helion.kernel(autotune_effort="none")
+        def zero_width(x: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size(0)):
+                scratch = hl.zeros([tile, 0], dtype=x.dtype)
+                out[tile, :] = scratch
+            return out
+
+        x = torch.empty([4, 0], device=DEVICE)
+        _, result = code_and_output(zero_width, (x,))
+        torch.testing.assert_close(result, torch.zeros_like(x))