2 changes: 1 addition & 1 deletion helion/_compiler/compile_environment.py

@@ -611,7 +611,7 @@ def from_config(self, config: Config, block_size_info: BlockSizeInfo) -> int | None:
             len(config.reduction_loops) <= self.reduction_loop
             or config.reduction_loops[self.reduction_loop] is None
         ):
-            return next_power_of_2(block_size_info.size_hint())
+            return max(1, next_power_of_2(block_size_info.size_hint()))
         return config.reduction_loops[self.reduction_loop]


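The clamp guards the case where a reduction dimension is statically zero-sized: the size hint is then 0, and rounding 0 up to a power of two still yields 0, which is not a valid block size. A minimal sketch of the arithmetic, assuming next_power_of_2 follows the usual bit-length definition (Helion's actual helper may differ):

    def next_power_of_2(n: int) -> int:
        # Common bit-trick definition: 0 maps to 0, positive n rounds up.
        return 1 << (n - 1).bit_length() if n > 0 else 0

    next_power_of_2(0)          # 0 -- would propagate an illegal block size
    max(1, next_power_of_2(0))  # 1 -- smallest valid power-of-two block size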
37 changes: 35 additions & 2 deletions helion/_compiler/reduction_strategy.py

@@ -98,6 +98,15 @@ def call_reduction_function(
             return f"triton_helpers.prod({input_name}, {dim})"
         raise NotImplementedError(f"Unsupported reduction type: {reduction_type}")
 
+    def _index_init_expr(self, block_size_var: str, dtype: str, block_idx: int) -> str:
+        env = CompileEnvironment.current()
+        size = env.block_sizes[block_idx].size
+        if isinstance(size, int) and size == 0:
+            return f"tl.zeros([0], {dtype})"
+        if isinstance(size, torch.SymInt) and env.known_equal(size, 0):
+            return f"tl.zeros([0], {dtype})"
+        return f"tl.arange(0, {block_size_var}).to({dtype})"
+
     def call_argmin_argmax(
         self,
         input_name: str,
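The new helper picks the index-initialization expression for a dimension: when the block's size is statically zero (either a plain int 0 or a SymInt known to equal 0), it emits an empty tl.zeros([0], dtype) tensor rather than tl.arange over a block size that has been clamped to at least 1. Illustratively, the generated Triton lines differ like this (variable names here are placeholders, not actual generated identifiers):

    # non-empty dimension:
    indices_0 = tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
    # statically zero-sized dimension:
    indices_0 = tl.zeros([0], tl.int32)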
@@ -187,7 +196,7 @@ def codegen_preamble(self, state: CodegenState) -> None:
         )
         state.codegen.host_statements.append(stmt)
         state.add_statement(
-            f"{index_var} = tl.arange(0, {block_size_var}).to({env.triton_index_type()})"
+            f"{index_var} = {self._index_init_expr(block_size_var, env.triton_index_type(), block_idx)}"
         )
         if mask_var is not None:
             state.add_statement(
@@ -213,6 +222,15 @@ def codegen_reduction(
         fake_input: torch.Tensor,
         fake_output: torch.Tensor,
     ) -> ast.AST:
+        env = CompileEnvironment.current()
+        numel = env.block_sizes[self.block_index].numel
+        if isinstance(numel, sympy.Integer) and numel == 0:
+            default = ir.Reduction.default_accumulator(reduction_type, fake_input.dtype)
+            assert isinstance(default, (float, int, bool))
+            shape = self.fn.tile_strategy.shape_str([*fake_output.size()])
+            return expr_from_string(
+                f"tl.full({shape}, {constant_repr(default)}, {triton_type(fake_output.dtype)})"
+            )
         expr = self.call_reduction_function(
             input_name,
             reduction_type,
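When the reduced dimension is statically empty there is nothing to combine, so codegen short-circuits to a tl.full(...) tensor holding the reduction's identity (default accumulator) value. This matches PyTorch eager semantics for reductions over an empty dimension, for example:

    import torch

    x = torch.empty(5, 0)
    torch.sum(x, dim=1)   # tensor([0., 0., 0., 0., 0.]) -- identity of sum
    torch.prod(x, dim=1)  # tensor([1., 1., 1., 1., 1.]) -- identity of prod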
@@ -260,7 +278,7 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
         )
         body: list[ast.AST] = [
             statement_from_string(
-                f"{index_var} = {offset_var} + tl.arange(0, ({block_size_var})).to({env.triton_index_type()})"
+                f"{index_var} = {offset_var} + {self._index_init_expr(f'({block_size_var})', env.triton_index_type(), block_index)}"
             ),
         ]
         if (mask_var := self._mask_var) is not None:
@@ -395,6 +413,21 @@ def codegen_reduction(
     ) -> ast.AST:
         default = ir.Reduction.default_accumulator(reduction_type, fake_input.dtype)
         assert isinstance(default, (float, int, bool))
+        env = CompileEnvironment.current()
+        dim_size = fake_input.size(dim)
+        is_zero_dim = False
+        if (
+            isinstance(dim_size, int)
+            and dim_size == 0
+            or isinstance(dim_size, torch.SymInt)
+            and env.known_equal(dim_size, 0)
+        ):
+            is_zero_dim = True
+        if is_zero_dim:
+            shape = self.fn.tile_strategy.shape_str([*fake_output.size()])
+            return expr_from_string(
+                f"tl.full({shape}, {constant_repr(default)}, {triton_type(fake_output.dtype)})"
+            )
         expr = self.call_reduction_function(
             input_name,
             reduction_type,
10 changes: 7 additions & 3 deletions helion/autotuner/config_spec.py

@@ -364,9 +364,12 @@ def __init__(
         super().__init__([block_id])
         self.size_hint = size_hint
         self.min_size: int = min_size
+        bounded_hint = max(size_hint, 1)
         self.max_size: int = (
-            next_power_of_2(size_hint) if max_size is None else max_size
+            next_power_of_2(bounded_hint) if max_size is None else max_size
         )
+        if self.max_size < self.min_size:
+            self.max_size = self.min_size
         assert self.min_size <= self.max_size
 
     def __repr__(self) -> str:

@@ -388,11 +391,12 @@ def update_min(self, value: int) -> None:
         self.max_size = self.min_size
 
     def update_max(self, value: int) -> None:
-        self.max_size = assert_integer_power_of_two(min(value, self.max_size))
+        clamped = max(value, 1)
+        self.max_size = assert_integer_power_of_two(min(clamped, self.max_size))
 
     def update_hint(self, value: int) -> None:
         self.size_hint = value
-        self.update_max(next_power_of_2(value))
+        self.update_max(next_power_of_2(max(value, 1)))
 
     def _fragment(self, base: ConfigSpec) -> BlockSizeFragment:
         total_ndim = len(base.block_sizes)
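Without these clamps, a size hint of 0 inverts the autotuner's search range and trips the new assertion, since next_power_of_2(0) is 0 while min_size is at least 1. A small before/after illustration, assuming a min_size of 1 (the exact default is not shown in this diff):

    size_hint, min_size = 0, 1

    # before: max_size = next_power_of_2(size_hint)         -> 0 < min_size (invalid range)
    # after:  max_size = next_power_of_2(max(size_hint, 1)) -> 1 >= min_size (valid range)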
51 changes: 51 additions & 0 deletions test/test_zero_size.py

@@ -0,0 +1,51 @@
+from __future__ import annotations
+
+import torch
+
+import helion
+from helion._testing import DEVICE
+from helion._testing import RefEagerTestBase
+from helion._testing import TestCase
+from helion._testing import code_and_output
+import helion.language as hl
+
+
+class TestZeroSizeTensors(RefEagerTestBase, TestCase):
+    def test_pointwise_zero_rows(self) -> None:
+        @helion.kernel(autotune_effort="none")
+        def pointwise_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size(0)):
+                out[tile, :] = x[tile, :] + y[tile, :]
+            return out
+
+        x = torch.randn([0, 32], device=DEVICE)
+        y = torch.randn([0, 32], device=DEVICE)
+        _, result = code_and_output(pointwise_add, (x, y))
+        torch.testing.assert_close(result, x + y)
+
+    def test_reduce_zero_inner_dim(self) -> None:
+        @helion.kernel(autotune_effort="none")
+        def row_sums(x: torch.Tensor) -> torch.Tensor:
+            out = torch.empty([x.size(0)], dtype=x.dtype, device=x.device)
+            for tile in hl.tile(x.size(0)):
+                rows = x[tile, :]
+                out[tile] = torch.sum(rows, dim=1)
+            return out
+
+        x = torch.randn([5, 0], device=DEVICE)
+        _, result = code_and_output(row_sums, (x,))
+        torch.testing.assert_close(result, torch.sum(x, dim=1))
+
+    def test_local_zero_width_allocation(self) -> None:
+        @helion.kernel(autotune_effort="none")
+        def zero_width(x: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size(0)):
+                scratch = hl.zeros([tile, 0], dtype=x.dtype)
+                out[tile, :] = scratch
+            return out
+
+        x = torch.empty([4, 0], device=DEVICE)
+        _, result = code_and_output(zero_width, (x,))
+        torch.testing.assert_close(result, torch.zeros_like(x))
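The three tests cover the distinct zero-size paths touched above: a zero-size outer (tiled) dimension, a zero-size reduced dimension, and a zero-width local allocation. Assuming the repository's usual pytest workflow, they can be run directly:

    python -m pytest test/test_zero_size.py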