4 changes: 1 addition & 3 deletions helion/autotuner/config_spec.py
@@ -399,9 +399,7 @@ def _fragment(self, base: ConfigSpec) -> BlockSizeFragment:
reduction_numel = _product(
[next_power_of_2(spec.size_hint) for spec in base.reduction_loops]
)
-if total_ndim <= 1 and reduction_numel <= 1:
-    default = 256
-elif total_ndim <= 2 and reduction_numel <= 128:
+if total_ndim <= 2 and reduction_numel <= 128:
default = 32
elif reduction_numel <= 256:
default = 16
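A minimal sketch of the default-block-size heuristic as it reads after this hunk, for readers skimming the rest of the diff. The standalone function form is illustrative only: the surrounding ConfigSpec/BlockSizeFragment machinery and the branches below the visible hunk are elided. With the dedicated 1-D, no-reduction branch gone, those kernels now fall into the first branch and default to 32 instead of 256, which is what the expected-output updates below reflect.

# Sketch under the assumptions above: it mirrors only the branches visible in this hunk.
def default_block_size(total_ndim: int, reduction_numel: int) -> int:
    if total_ndim <= 2 and reduction_numel <= 128:
        # Now also covers the 1-D, no-reduction case, which previously had
        # its own branch returning 256.
        return 32
    if reduction_numel <= 256:
        return 16
    raise NotImplementedError("remaining branches are outside the visible hunk")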
10 changes: 5 additions & 5 deletions test/test_closures.expected
@@ -76,7 +76,7 @@ def sin_func_arg(a, fn, *, _launcher=_default_launcher):
# src[test_closures.py:N]: out = torch.empty_like(a)
out = torch.empty_like(a)
# src[test_closures.py:N]: for tile in hl.tile(a.size()):
-_BLOCK_SIZE_0 = 256
+_BLOCK_SIZE_0 = 32
# src[test_closures.py:N]: for tile in hl.tile(a.size()):
# src[test_closures.py:N]: out[tile] = fn(torch.sin(a[tile]), tile)
_launcher(_helion_sin_func_arg, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, fn.__closure__[0].cell_contents, out, a.size(0), a.stride(0), fn.__closure__[0].cell_contents.stride(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=1)
@@ -112,7 +112,7 @@ def sin_func_arg(a, fn, *, _launcher=_default_launcher):
# src[test_closures.py:N]: out = torch.empty_like(a)
out = torch.empty_like(a)
# src[test_closures.py:N]: for tile in hl.tile(a.size()):
-_BLOCK_SIZE_0 = 256
+_BLOCK_SIZE_0 = 32
# src[test_closures.py:N]: for tile in hl.tile(a.size()):
# src[test_closures.py:N]: out[tile] = fn(torch.sin(a[tile]), tile)
_launcher(_helion_sin_func_arg, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, _source_module.global_tensor, out, a.size(0), _source_module.global_tensor.stride(0), a.stride(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=1)
@@ -147,7 +147,7 @@ def sin_func_arg(a, fn, *, _launcher=_default_launcher):
# src[test_closures.py:N]: out = torch.empty_like(a)
out = torch.empty_like(a)
# src[test_closures.py:N]: for tile in hl.tile(a.size()):
-_BLOCK_SIZE_0 = 256
+_BLOCK_SIZE_0 = 32
# src[test_closures.py:N]: for tile in hl.tile(a.size()):
# src[test_closures.py:N]: out[tile] = fn(torch.sin(a[tile]), tile)
_launcher(_helion_sin_func_arg, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, out, a.size(0), a.stride(0), out.stride(0), _global_source0.global_float, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
@@ -183,7 +183,7 @@ def sin_func_arg(a, fn, *, _launcher=_default_launcher):
# src[test_closures.py:N]: out = torch.empty_like(a)
out = torch.empty_like(a)
# src[test_closures.py:N]: for tile in hl.tile(a.size()):
-_BLOCK_SIZE_0 = 256
+_BLOCK_SIZE_0 = 32
# src[test_closures.py:N]: for tile in hl.tile(a.size()):
# src[test_closures.py:N]: out[tile] = fn(torch.sin(a[tile]), tile)
_launcher(_helion_sin_func_arg, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, fn.__closure__[0].cell_contents.__closure__[0].cell_contents, out, a.size(0), a.stride(0), fn.__closure__[0].cell_contents.__closure__[0].cell_contents.stride(0), out.stride(0), fn.__closure__[1].cell_contents, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
@@ -215,7 +215,7 @@ def call_func_arg_on_host(a, alloc, *, _launcher=_default_launcher):
# src[test_closures.py:N]: out = alloc(a)
out = alloc(a)
# src[test_closures.py:N]: for tile in hl.tile(a.size()):
-_BLOCK_SIZE_0 = 256
+_BLOCK_SIZE_0 = 32
# src[test_closures.py:N]: for tile in hl.tile(a.size()):
# src[test_closures.py:N]: out[tile] = a[tile].sin()
_launcher(_helion_call_func_arg_on_host, (triton.cdiv(512, _BLOCK_SIZE_0),), a, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
4 changes: 2 additions & 2 deletions test/test_examples.expected
@@ -4136,7 +4136,7 @@ def low_mem_dropout(p: float, x: torch.Tensor, seed: int, *, _launcher=_default_launcher):
# src[low_mem_dropout.py:N]: out_flat = torch.empty_like(x_flat)
out_flat = torch.empty_like(x_flat)
# src[low_mem_dropout.py:N]: for tidx in hl.tile(n):
-_BLOCK_SIZE_0 = 256
+_BLOCK_SIZE_0 = 32
# src[low_mem_dropout.py:N]: for tidx in hl.tile(n):
# src[low_mem_dropout.py:N]: xi = x_flat[tidx].to(torch.float32)
# src[low_mem_dropout.py:N]: r = hl.rand([tidx], seed=seed)
@@ -6077,7 +6077,7 @@ def swiglu_bwd(gout: Tensor, x1: Tensor, x2: Tensor, *, _launcher=_default_launcher):
# src[swiglu.py:N]: dx2_flat = dx2.view(-1)
dx2_flat = dx2.view(-1)
# src[swiglu.py:N]: for tile in hl.tile(x1.numel()):
-_BLOCK_SIZE_0 = 256
+_BLOCK_SIZE_0 = 32
# src[swiglu.py:N]: for tile in hl.tile(x1.numel()):
# src[swiglu.py:N]: x1_vals = x1_flat[tile].to(torch.float32)
# src[swiglu.py:N]: gout_vals = gout_flat[tile].to(torch.float32)
4 changes: 2 additions & 2 deletions test/test_graph_module.expected
@@ -29,7 +29,7 @@ def apply_graph_module(func_m, x, *, _launcher=_default_launcher):
# src[test_graph_module.py:N]: out = torch.empty_like(x)
out = torch.empty_like(x)
# src[test_graph_module.py:N]: for tile in hl.tile(out.size()):
-_BLOCK_SIZE_0 = 256
+_BLOCK_SIZE_0 = 32
# src[test_graph_module.py:N]: for tile in hl.tile(out.size()):
# src[test_graph_module.py:N]: out[tile] = func_m(x[tile])
_launcher(_helion_apply_graph_module, (triton.cdiv(1000, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
@@ -68,7 +68,7 @@ def apply_graph_module(func_m, x, *, _launcher=_default_launcher):
# src[test_graph_module.py:N]: out = torch.empty_like(x)
out = torch.empty_like(x)
# src[test_graph_module.py:N]: for tile in hl.tile(out.size()):
-_BLOCK_SIZE_0 = 256
+_BLOCK_SIZE_0 = 32
# src[test_graph_module.py:N]: for tile in hl.tile(out.size()):
# src[test_graph_module.py:N]: out[tile] = func_m(x[tile])
_launcher(_helion_apply_graph_module, (triton.cdiv(512, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
4 changes: 2 additions & 2 deletions test/test_indexing.expected
@@ -56,7 +56,7 @@ def arange_block_size_mul(x: torch.Tensor, *, _launcher=_default_launcher):
# src[test_indexing.py:N]: out = torch.zeros([x.size(0) * 2], dtype=torch.int32, device=x.device)
out = torch.zeros([x.size(0) * 2], dtype=torch.int32, device=x.device)
# src[test_indexing.py:N]: for tile in hl.tile(x.size(0)):
-_BLOCK_SIZE_0 = 64
+_BLOCK_SIZE_0 = 32
# src[test_indexing.py:N]: for tile in hl.tile(x.size(0)):
# src[test_indexing.py:N]: indices = hl.arange(
# src[test_indexing.py:N]: tile.begin * 2, tile.begin * 2 + tile.block_size * 2
@@ -816,7 +816,7 @@ def arange_block_size_mul(x: torch.Tensor, *, _launcher=_default_launcher):
# src[test_indexing.py:N]: ones = torch.ones_like(out)
ones = torch.ones_like(out)
# src[test_indexing.py:N]: for tile in hl.tile(x.size(0)):
-_BLOCK_SIZE_0 = 64
+_BLOCK_SIZE_0 = 32
# src[test_indexing.py:N]: for tile in hl.tile(x.size(0)):
# src[test_indexing.py:N]: indices_start = tile.begin * 2
# src[test_indexing.py:N]: indices_end = indices_start + tile.block_size * 2
6 changes: 3 additions & 3 deletions test/test_inline_asm_elementwise.expected
@@ -108,7 +108,7 @@ def kernel_multiple_outputs(a: torch.Tensor, b: torch.Tensor, *, _launcher=_default_launcher):
# src[test_inline_asm_elementwise.py:N]: result_d = torch.empty_like(a)
result_d = torch.empty_like(a)
# src[test_inline_asm_elementwise.py:N]: for tile in hl.tile(a.shape):
-_BLOCK_SIZE_0 = 64
+_BLOCK_SIZE_0 = 32
# src[test_inline_asm_elementwise.py:N]: for tile in hl.tile(a.shape):
# src[test_inline_asm_elementwise.py:N]: val_a = a[tile]
# src[test_inline_asm_elementwise.py:N]: val_b = b[tile]
@@ -146,7 +146,7 @@ def kernel_packed_asm(x: torch.Tensor, *, _launcher=_default_launcher):
# src[test_inline_asm_elementwise.py:N]: result = torch.empty_like(x)
result = torch.empty_like(x)
# src[test_inline_asm_elementwise.py:N]: for tile in hl.tile(x.shape):
-_BLOCK_SIZE_0 = 256
+_BLOCK_SIZE_0 = 32
# src[test_inline_asm_elementwise.py:N]: for tile in hl.tile(x.shape):
# src[test_inline_asm_elementwise.py:N]: val = x[tile]
# src[test_inline_asm_elementwise.py:N]: # Shift 4x8bit values together, pack=4
@@ -187,7 +187,7 @@ def kernel_shift_asm(x: torch.Tensor, y: torch.Tensor, n: int, *, _launcher=_default_launcher):
# src[test_inline_asm_elementwise.py:N]: result = torch.empty_like(x)
result = torch.empty_like(x)
# src[test_inline_asm_elementwise.py:N]: for tile in hl.tile(x.shape):
-_BLOCK_SIZE_0 = 128
+_BLOCK_SIZE_0 = 32
# src[test_inline_asm_elementwise.py:N]: for tile in hl.tile(x.shape):
# src[test_inline_asm_elementwise.py:N]: val_x = x[tile]
# src[test_inline_asm_elementwise.py:N]: val_y = y[tile]
4 changes: 2 additions & 2 deletions test/test_inline_triton.expected
@@ -88,7 +88,7 @@ def kernel(a: torch.Tensor, b: torch.Tensor, *, _launcher=_default_launcher):
# src[test_inline_triton.py:N]: diff_out = torch.empty_like(a)
diff_out = torch.empty_like(a)
# src[test_inline_triton.py:N]: for tile in hl.tile(a.shape):
-_BLOCK_SIZE_0 = 64
+_BLOCK_SIZE_0 = 32
# src[test_inline_triton.py:N]: for tile in hl.tile(a.shape):
# src[test_inline_triton.py:N]: a_val = a[tile]
# src[test_inline_triton.py:N]: b_val = b[tile]
@@ -131,7 +131,7 @@ def kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
# src[test_inline_triton.py:N]: out = torch.empty_like(x)
out = torch.empty_like(x)
# src[test_inline_triton.py:N]: for tile in hl.tile(x.shape):
-_BLOCK_SIZE_0 = 128
+_BLOCK_SIZE_0 = 32
# src[test_inline_triton.py:N]: for tile in hl.tile(x.shape):
# src[test_inline_triton.py:N]: x_val = x[tile]
# src[test_inline_triton.py:N]: y_val = y[tile]
2 changes: 1 addition & 1 deletion test/test_misc.expected
@@ -113,7 +113,7 @@ def kernel_with_scalar_item(x: torch.Tensor, scalar_tensor: torch.Tensor, *, _launcher=_default_launcher):
# src[test_misc.py:N]: scalar_val = scalar_tensor.item()
scalar_val = scalar_tensor.item()
# src[test_misc.py:N]: for tile in hl.tile(x.shape):
-_BLOCK_SIZE_0 = 128
+_BLOCK_SIZE_0 = 32
# src[test_misc.py:N]: for tile in hl.tile(x.shape):
# src[test_misc.py:N]: result[tile] = x[tile] + scalar_val
_launcher(_helion_kernel_with_scalar_item, (triton.cdiv(100, _BLOCK_SIZE_0),), x, result, scalar_val, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
2 changes: 1 addition & 1 deletion test/test_random.expected
@@ -26,7 +26,7 @@ def rand_kernel_tiled_1d(x: torch.Tensor, seed: int, *, _launcher=_default_launcher):
# src[test_random.py:N]: (m,) = x.shape
m, = x.shape
# src[test_random.py:N]: for tile_m in hl.tile(m):
-_BLOCK_SIZE_0 = 128
+_BLOCK_SIZE_0 = 32
# src[test_random.py:N]: for tile_m in hl.tile(m):
# src[test_random.py:N]: output[tile_m] = hl.rand([tile_m], seed=seed)
_launcher(_helion_rand_kernel_tiled_1d, (triton.cdiv(m, _BLOCK_SIZE_0),), output, output.stride(0), m, seed, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
2 changes: 1 addition & 1 deletion test/test_register_tunable.expected
@@ -2,7 +2,7 @@ This file is automatically generated by assertExpectedJournal calls in test_register_tunable.py.
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.

--- assertExpectedJournal(TestRegisterTunable.test_integer_fragment)
-helion.Config(block_sizes=[128], indexing=['pointer', 'pointer'], load_eviction_policies=[''], multiplier=3, num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[])
+helion.Config(block_sizes=[32], indexing=['pointer', 'pointer'], load_eviction_policies=[''], multiplier=3, num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[])

--- assertExpectedJournal(TestRegisterTunable.test_integer_fragment)
from __future__ import annotations
6 changes: 3 additions & 3 deletions test/test_unroll_tuples.expected
@@ -656,7 +656,7 @@ def kernel_nested_tuple_iteration(a_tuple: tuple[torch.Tensor, torch.Tensor], b_
# src[test_unroll_tuples.py:N]: result = torch.zeros_like(a_tuple[0])
result = torch.zeros_like(a_tuple[0])
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
-_BLOCK_SIZE_0 = 64
+_BLOCK_SIZE_0 = 32
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
# src[test_unroll_tuples.py:N]: temp = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
# src[test_unroll_tuples.py:N-N]: ...
@@ -881,7 +881,7 @@ def kernel_tuple_with_scaling(tensor1: torch.Tensor, tensor2: torch.Tensor, tens
# src[test_unroll_tuples.py:N]: output = torch.zeros_like(tensor1)
output = torch.zeros_like(tensor1)
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(output.size(0)):
-_BLOCK_SIZE_0 = 64
+_BLOCK_SIZE_0 = 32
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(output.size(0)):
# src[test_unroll_tuples.py:N]: temp = torch.zeros([tile_idx], dtype=torch.float32, device=output.device)
# src[test_unroll_tuples.py:N]: for tensor, scale in zip(tensors, scales, strict=True):
@@ -924,7 +924,7 @@ def kernel_zip_iteration(tensors_a: tuple[torch.Tensor, torch.Tensor], tensors_b
# src[test_unroll_tuples.py:N]: result = torch.zeros_like(tensors_a[0])
result = torch.zeros_like(tensors_a[0])
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
-_BLOCK_SIZE_0 = 64
+_BLOCK_SIZE_0 = 32
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
# src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
# src[test_unroll_tuples.py:N]: # Iterate over zip of tensors