diff --git a/helion/autotuner/config_spec.py b/helion/autotuner/config_spec.py index 47399719b..5cf2e3c16 100644 --- a/helion/autotuner/config_spec.py +++ b/helion/autotuner/config_spec.py @@ -33,7 +33,7 @@ from ..runtime.config import PidTypeLiteral DEFAULT_NUM_WARPS = 4 -DEFAULT_NUM_STAGES = 2 +DEFAULT_NUM_STAGES = 1 VALID_KEYS: frozenset[str] = frozenset( [ "block_sizes", @@ -400,7 +400,7 @@ def _fragment(self, base: ConfigSpec) -> BlockSizeFragment: [next_power_of_2(spec.size_hint) for spec in base.reduction_loops] ) if total_ndim <= 1 and reduction_numel <= 1: - default = 1024 + default = 256 elif total_ndim <= 2 and reduction_numel <= 128: default = 32 elif reduction_numel <= 256: diff --git a/test/test_associative_scan.expected b/test/test_associative_scan.expected index eec03db2c..9f8eb8d4b 100644 --- a/test/test_associative_scan.expected +++ b/test/test_associative_scan.expected @@ -63,7 +63,7 @@ def cumulative_argmax_tuple_kernel(input_data: torch.Tensor, positions: torch.Te # src[test_associative_scan.py:N]: vals = input_data[tile_e, :] # src[test_associative_scan.py:N]: # Convert positions to float to match vals dtype, then broadcast to match vals shape # src[test_associative_scan.py:N-N]: ... - _launcher(_helion_cumulative_argmax_tuple_kernel, (triton.cdiv(4, _BLOCK_SIZE_0),), input_data, positions, max_values, max_indices, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_cumulative_argmax_tuple_kernel, (triton.cdiv(4, _BLOCK_SIZE_0),), input_data, positions, max_values, max_indices, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return max_values, max_indices return (max_values, max_indices) @@ -110,7 +110,7 @@ def test_scan_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.associative_scan(add_combine_fn, row_data, dim=1) - _launcher(_helion_test_scan_kernel, (triton.cdiv(3, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_scan_kernel, (triton.cdiv(3, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -153,7 +153,7 @@ def test_codegen_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.associative_scan(add_combine_fn, row_data, dim=1) - _launcher(_helion_test_codegen_kernel, (1,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_codegen_kernel, (1,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -219,7 +219,7 @@ def cumulative_argmax_kernel(input_data: torch.Tensor, positions: torch.Tensor, # src[test_associative_scan.py:N]: vals = input_data[tile_e, :] # src[test_associative_scan.py:N]: # Convert positions to float to match vals dtype, then broadcast to match vals shape # src[test_associative_scan.py:N-N]: ... - _launcher(_helion_cumulative_argmax_kernel, (triton.cdiv(4, _BLOCK_SIZE_0),), input_data, positions, max_values, max_indices, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_cumulative_argmax_kernel, (triton.cdiv(4, _BLOCK_SIZE_0),), input_data, positions, max_values, max_indices, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return max_values, max_indices return (max_values, max_indices) @@ -270,7 +270,7 @@ def test_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.associative_scan( # src[test_associative_scan.py:N-N]: ... - _launcher(_helion_test_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -321,7 +321,7 @@ def test_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.associative_scan( # src[test_associative_scan.py:N-N]: ... - _launcher(_helion_test_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -372,7 +372,7 @@ def test_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.associative_scan( # src[test_associative_scan.py:N-N]: ... - _launcher(_helion_test_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -423,7 +423,7 @@ def test_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.associative_scan( # src[test_associative_scan.py:N-N]: ... - _launcher(_helion_test_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -470,7 +470,7 @@ def test_size_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.associative_scan( # src[test_associative_scan.py:N-N]: ... - _launcher(_helion_test_size_kernel, (1,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_size_kernel, (1,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -522,7 +522,7 @@ def test_size_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.associative_scan( # src[test_associative_scan.py:N-N]: ... - _launcher(_helion_test_size_kernel, (triton.cdiv(3, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_size_kernel, (triton.cdiv(3, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -574,7 +574,7 @@ def test_size_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.associative_scan( # src[test_associative_scan.py:N-N]: ... - _launcher(_helion_test_size_kernel, (triton.cdiv(5, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_size_kernel, (triton.cdiv(5, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -623,7 +623,7 @@ def test_size_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.associative_scan( # src[test_associative_scan.py:N-N]: ... - _launcher(_helion_test_size_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_test_size_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -673,7 +673,7 @@ def test_size_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.associative_scan( # src[test_associative_scan.py:N-N]: ... - _launcher(_helion_test_size_kernel, (4,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_size_kernel, (4,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -723,7 +723,7 @@ def test_size_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.associative_scan( # src[test_associative_scan.py:N-N]: ... - _launcher(_helion_test_size_kernel, (8,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_size_kernel, (8,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -761,7 +761,7 @@ def test_single_element(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.associative_scan(add_combine_fn, row_data, dim=1) - _launcher(_helion_test_single_element, (1,), x, result, num_warps=4, num_stages=2) + _launcher(_helion_test_single_element, (1,), x, result, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -803,7 +803,7 @@ def test_single_element(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.associative_scan(add_combine_fn, row_data, dim=1) - _launcher(_helion_test_single_element, (1,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_single_element, (1,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -848,7 +848,7 @@ def test_helper_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: # Use the cumsum_helper function which internally calls hl.associative_scan # src[test_associative_scan.py:N]: result[i, :] = cumsum_helper(x[i, :]) - _launcher(_helion_test_helper_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_helper_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -890,7 +890,7 @@ def test_jit_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.associative_scan(jit_add_combine_fn, row_data, dim=1) - _launcher(_helion_test_jit_kernel, (1,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_jit_kernel, (1,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -935,7 +935,7 @@ def test_large_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.associative_scan(add_combine_fn, row_data, dim=1) - _launcher(_helion_test_large_kernel, (32,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_large_kernel, (32,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -983,7 +983,7 @@ def test_max_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.associative_scan(max_combine_fn, row_data, dim=1) - _launcher(_helion_test_max_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_max_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -1031,7 +1031,7 @@ def test_min_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.associative_scan(min_combine_fn, row_data, dim=1) - _launcher(_helion_test_min_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_min_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -1090,7 +1090,7 @@ def test_multi_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: # Prefix sum # src[test_associative_scan.py:N-N]: ... - _launcher(_helion_test_multi_kernel, (1,), x, sum_result, max_result, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_multi_kernel, (1,), x, sum_result, max_result, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return sum_result return sum_result @@ -1136,7 +1136,7 @@ def test_mul_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.associative_scan(mul_combine_fn, row_data, dim=1) - _launcher(_helion_test_mul_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_mul_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -1183,7 +1183,7 @@ def test_reverse_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.associative_scan( # src[test_associative_scan.py:N-N]: ... - _launcher(_helion_test_reverse_kernel, (1,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_reverse_kernel, (1,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -1251,7 +1251,7 @@ def segmented_scan_kernel(indices: torch.Tensor, input_data: torch.Tensor, *, _l # src[test_associative_scan.py:N]: vals = input_data[tile_e, tile_f] # src[test_associative_scan.py:N]: # Convert indices to float to match vals dtype and broadcast to match shape # src[test_associative_scan.py:N-N]: ... - _launcher(_helion_segmented_scan_kernel, (triton.cdiv(6, _BLOCK_SIZE_0) * triton.cdiv(3, _BLOCK_SIZE_1),), input_data, indices, output, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_segmented_scan_kernel, (triton.cdiv(6, _BLOCK_SIZE_0) * triton.cdiv(3, _BLOCK_SIZE_1),), input_data, indices, output, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return output return output @@ -1302,7 +1302,7 @@ def test_torch_hops_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: # Use torch._higher_order_ops.associative_scan directly # src[test_associative_scan.py:N-N]: ... - _launcher(_helion_test_torch_hops_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_torch_hops_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -1367,7 +1367,7 @@ def test_segmented_kernel(indices: torch.Tensor, input_data: torch.Tensor, *, _l # src[test_associative_scan.py:N]: vals = input_data[tile_e, tile_f] # src[test_associative_scan.py:N]: # Broadcast indices to match vals shape for the scan # src[test_associative_scan.py:N-N]: ... - _launcher(_helion_test_segmented_kernel, (triton.cdiv(4, _BLOCK_SIZE_0) * triton.cdiv(2, _BLOCK_SIZE_1),), input_data, indices, output, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_segmented_kernel, (triton.cdiv(4, _BLOCK_SIZE_0) * triton.cdiv(2, _BLOCK_SIZE_1),), input_data, indices, output, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return output return output @@ -1432,7 +1432,7 @@ def test_segmented_tuple_kernel(indices: torch.Tensor, input_data: torch.Tensor, # src[test_associative_scan.py:N]: vals = input_data[tile_e, tile_f] # src[test_associative_scan.py:N]: # Broadcast indices to match vals shape for the scan # src[test_associative_scan.py:N-N]: ... - _launcher(_helion_test_segmented_tuple_kernel, (triton.cdiv(4, _BLOCK_SIZE_0) * triton.cdiv(2, _BLOCK_SIZE_1),), input_data, indices, output, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_segmented_tuple_kernel, (triton.cdiv(4, _BLOCK_SIZE_0) * triton.cdiv(2, _BLOCK_SIZE_1),), input_data, indices, output, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return output return output @@ -1477,7 +1477,7 @@ def test_type_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.associative_scan(add_combine_fn, row_data, dim=1) - _launcher(_helion_test_type_kernel, (16,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_type_kernel, (16,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -1521,7 +1521,7 @@ def test_cumprod_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = torch.cumprod(row_data, dim=1) - _launcher(_helion_test_cumprod_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_cumprod_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -1565,7 +1565,7 @@ def test_cumprod_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.cumprod(row_data, dim=1) - _launcher(_helion_test_cumprod_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_cumprod_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -1609,7 +1609,7 @@ def test_cumprod_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.cumprod(row_data, dim=1) - _launcher(_helion_test_cumprod_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_cumprod_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -1653,7 +1653,7 @@ def test_cumprod_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.cumprod(row_data, dim=1) - _launcher(_helion_test_cumprod_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_cumprod_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -1697,7 +1697,7 @@ def test_cumprod_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.cumprod(row_data, dim=1) - _launcher(_helion_test_cumprod_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_cumprod_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -1737,7 +1737,7 @@ def test_cumprod_reverse_kernel(x: torch.Tensor, *, _launcher=_default_launcher) # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.cumprod(row_data, dim=1, reverse=True) - _launcher(_helion_test_cumprod_reverse_kernel, (1,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_cumprod_reverse_kernel, (1,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -1781,7 +1781,7 @@ def test_cumsum_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = torch.cumsum(row_data, dim=1) - _launcher(_helion_test_cumsum_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_cumsum_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -1837,7 +1837,7 @@ def test_mixed_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: # Cumulative sum # src[test_associative_scan.py:N-N]: ... - _launcher(_helion_test_mixed_kernel, (1,), x, sum_result, prod_result, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_mixed_kernel, (1,), x, sum_result, prod_result, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return sum_result return sum_result @@ -1881,7 +1881,7 @@ def test_cumsum_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = torch.cumsum(row_data, dim=1) - _launcher(_helion_test_cumsum_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_cumsum_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -1925,7 +1925,7 @@ def test_cumsum_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = torch.cumsum(row_data, dim=1) - _launcher(_helion_test_cumsum_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_cumsum_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -1969,7 +1969,7 @@ def test_cumsum_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = torch.cumsum(row_data, dim=1) - _launcher(_helion_test_cumsum_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_cumsum_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -2013,7 +2013,7 @@ def test_cumsum_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = torch.cumsum(row_data, dim=1) - _launcher(_helion_test_cumsum_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_cumsum_dtype_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result @@ -2053,6 +2053,6 @@ def test_cumsum_reverse_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_associative_scan.py:N]: for i in hl.tile(x.size(0)): # src[test_associative_scan.py:N]: row_data = x[i, :] # src[test_associative_scan.py:N]: result[i, :] = hl.cumsum(row_data, dim=1, reverse=True) - _launcher(_helion_test_cumsum_reverse_kernel, (1,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_cumsum_reverse_kernel, (1,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_associative_scan.py:N]: return result return result diff --git a/test/test_atomic_ops.expected b/test/test_atomic_ops.expected index ccfc88fd3..dddea09aa 100644 --- a/test/test_atomic_ops.expected +++ b/test/test_atomic_ops.expected @@ -32,7 +32,7 @@ def atomic_add_2d_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default _BLOCK_SIZE_1 = 8 # src[test_atomic_ops.py:N]: for i, j in hl.tile([y.size(0), y.size(1)]): # src[test_atomic_ops.py:N]: hl.atomic_add(x, [i, j], y[i, j]) - _launcher(_helion_atomic_add_2d_kernel, (triton.cdiv(3, _BLOCK_SIZE_0) * triton.cdiv(4, _BLOCK_SIZE_1),), y, x, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_atomic_add_2d_kernel, (triton.cdiv(3, _BLOCK_SIZE_0) * triton.cdiv(4, _BLOCK_SIZE_1),), y, x, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_atomic_ops.py:N]: return x return x @@ -77,7 +77,7 @@ def atomic_add_1d_tensor_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_ # src[test_atomic_ops.py:N]: x_tile = x[tile_m, :].to(torch.float32) # src[test_atomic_ops.py:N]: y_tile = y[tile_m, :].to(torch.float32) # src[test_atomic_ops.py:N-N]: ... - _launcher(_helion_atomic_add_1d_tensor_kernel, (triton.cdiv(32, _BLOCK_SIZE_0),), x, y, z, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_atomic_add_1d_tensor_kernel, (triton.cdiv(32, _BLOCK_SIZE_0),), x, y, z, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_atomic_ops.py:N]: return z return z @@ -108,7 +108,7 @@ def atomic_add_float_kernel(x: torch.Tensor, indices: torch.Tensor, *, _launcher # src[test_atomic_ops.py:N]: for i in hl.tile(indices.size(0)): # src[test_atomic_ops.py:N]: idx = indices[i] # src[test_atomic_ops.py:N]: hl.atomic_add(x, [idx], 2.0) - _launcher(_helion_atomic_add_float_kernel, (triton.cdiv(8, _BLOCK_SIZE_0),), indices, x, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_atomic_add_float_kernel, (triton.cdiv(8, _BLOCK_SIZE_0),), indices, x, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_atomic_ops.py:N]: return x return x @@ -140,7 +140,7 @@ def k(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): # src[test_atomic_ops.py:N]: for i in hl.tile(x.size(0)): # src[test_atomic_ops.py:N]: old = hl.atomic_add(x, [i], y[i]) # src[test_atomic_ops.py:N]: prev[i] = old - _launcher(_helion_k, (triton.cdiv(8, _BLOCK_SIZE_0),), y, x, prev, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_k, (triton.cdiv(8, _BLOCK_SIZE_0),), y, x, prev, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_atomic_ops.py:N]: return x, prev return (x, prev) @@ -168,7 +168,7 @@ def atomic_add_w_tile_attr(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 2 # src[test_atomic_ops.py:N]: for tile in hl.tile(x.size(0)): # src[test_atomic_ops.py:N]: hl.atomic_add(y, [tile.begin], 1) - _launcher(_helion_atomic_add_w_tile_attr, (triton.cdiv(20, _BLOCK_SIZE_0),), y, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_atomic_add_w_tile_attr, (triton.cdiv(20, _BLOCK_SIZE_0),), y, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_atomic_ops.py:N]: return y return y @@ -195,7 +195,7 @@ def atomic_and_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_la _BLOCK_SIZE_0 = 8 # src[test_atomic_ops.py:N]: for i in hl.tile(x.size(0)): # src[test_atomic_ops.py:N]: hl.atomic_and(x, [i], y[i]) - _launcher(_helion_atomic_and_kernel, (triton.cdiv(8, _BLOCK_SIZE_0),), y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_atomic_and_kernel, (triton.cdiv(8, _BLOCK_SIZE_0),), y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_atomic_ops.py:N]: return x return x @@ -223,7 +223,7 @@ def atomic_cas_kernel(x: torch.Tensor, y: torch.Tensor, expect: torch.Tensor, *, _BLOCK_SIZE_0 = 4 # src[test_atomic_ops.py:N]: for i in hl.tile(x.size(0)): # src[test_atomic_ops.py:N]: hl.atomic_cas(x, [i], expect[i], y[i]) - _launcher(_helion_atomic_cas_kernel, (triton.cdiv(4, _BLOCK_SIZE_0),), expect, y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_atomic_cas_kernel, (triton.cdiv(4, _BLOCK_SIZE_0),), expect, y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_atomic_ops.py:N]: return x return x @@ -250,7 +250,7 @@ def atomic_max_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_la _BLOCK_SIZE_0 = 4 # src[test_atomic_ops.py:N]: for i in hl.tile(x.size(0)): # src[test_atomic_ops.py:N]: hl.atomic_max(x, [i], y[i]) - _launcher(_helion_atomic_max_kernel, (triton.cdiv(4, _BLOCK_SIZE_0),), y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_atomic_max_kernel, (triton.cdiv(4, _BLOCK_SIZE_0),), y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_atomic_ops.py:N]: return x return x @@ -277,7 +277,7 @@ def atomic_min_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_la _BLOCK_SIZE_0 = 4 # src[test_atomic_ops.py:N]: for i in hl.tile(x.size(0)): # src[test_atomic_ops.py:N]: hl.atomic_min(x, [i], y[i]) - _launcher(_helion_atomic_min_kernel, (triton.cdiv(4, _BLOCK_SIZE_0),), y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_atomic_min_kernel, (triton.cdiv(4, _BLOCK_SIZE_0),), y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_atomic_ops.py:N]: return x return x @@ -304,7 +304,7 @@ def atomic_or_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_lau _BLOCK_SIZE_0 = 8 # src[test_atomic_ops.py:N]: for i in hl.tile(x.size(0)): # src[test_atomic_ops.py:N]: hl.atomic_or(x, [i], y[i]) - _launcher(_helion_atomic_or_kernel, (triton.cdiv(8, _BLOCK_SIZE_0),), y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_atomic_or_kernel, (triton.cdiv(8, _BLOCK_SIZE_0),), y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_atomic_ops.py:N]: return x return x @@ -331,7 +331,7 @@ def atomic_xchg_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_l _BLOCK_SIZE_0 = 8 # src[test_atomic_ops.py:N]: for i in hl.tile(x.size(0)): # src[test_atomic_ops.py:N]: hl.atomic_xchg(x, [i], y[i]) - _launcher(_helion_atomic_xchg_kernel, (triton.cdiv(8, _BLOCK_SIZE_0),), y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_atomic_xchg_kernel, (triton.cdiv(8, _BLOCK_SIZE_0),), y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_atomic_ops.py:N]: return x return x @@ -358,7 +358,7 @@ def atomic_xor_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_la _BLOCK_SIZE_0 = 8 # src[test_atomic_ops.py:N]: for i in hl.tile(x.size(0)): # src[test_atomic_ops.py:N]: hl.atomic_xor(x, [i], y[i]) - _launcher(_helion_atomic_xor_kernel, (triton.cdiv(8, _BLOCK_SIZE_0),), y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_atomic_xor_kernel, (triton.cdiv(8, _BLOCK_SIZE_0),), y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_atomic_ops.py:N]: return x return x @@ -387,7 +387,7 @@ def atomic_add_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_la _BLOCK_SIZE_0 = 32 # src[test_atomic_ops.py:N]: for i in hl.tile(x.size(0)): # src[test_atomic_ops.py:N]: hl.atomic_add(x, [i], y[i]) - _launcher(_helion_atomic_add_kernel, (triton.cdiv(10, _BLOCK_SIZE_0),), y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_atomic_add_kernel, (triton.cdiv(10, _BLOCK_SIZE_0),), y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_atomic_ops.py:N]: return x return x @@ -419,6 +419,6 @@ def atomic_add_overlap_kernel(x: torch.Tensor, y: torch.Tensor, indices: torch.T # src[test_atomic_ops.py:N]: for i in hl.tile([y.size(0)]): # src[test_atomic_ops.py:N]: idx = indices[i] # src[test_atomic_ops.py:N]: hl.atomic_add(x, [idx], y[i]) - _launcher(_helion_atomic_add_overlap_kernel, (triton.cdiv(10, _BLOCK_SIZE_0),), indices, y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_atomic_add_overlap_kernel, (triton.cdiv(10, _BLOCK_SIZE_0),), indices, y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_atomic_ops.py:N]: return x return x diff --git a/test/test_autotuner.expected b/test/test_autotuner.expected index 65e84c583..67c9e5821 100644 --- a/test/test_autotuner.expected +++ b/test/test_autotuner.expected @@ -2,7 +2,7 @@ This file is automatically generated by assertExpectedJournal calls in test_auto Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set. --- assertExpectedJournal(TestAutotuner.test_config_fragment0) -helion.Config(block_sizes=[16, 16, 16], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[None, None]) +helion.Config(block_sizes=[16, 16, 16], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[None, None]) helion.Config(block_sizes=[32, 128, 64], indexing=['tensor_descriptor', 'tensor_descriptor', 'pointer'], l2_groupings=[8], load_eviction_policies=['last', 'last'], loop_orders=[[1, 0]], num_stages=8, num_warps=8, pid_type='flat', range_flattens=[None, True], range_multi_buffers=[None, True], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[None, True]) helion.Config(block_sizes=[64, 16, 64], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[8], load_eviction_policies=['last', 'last'], loop_orders=[[1, 0]], num_stages=2, num_warps=32, pid_type='persistent_blocked', range_flattens=[False, False], range_multi_buffers=[True, None], range_num_stages=[0, 0], range_unroll_factors=[4, 3], range_warp_specializes=[False, None]) helion.Config(block_sizes=[16, 64, 512], indexing=['tensor_descriptor', 'tensor_descriptor', 'pointer'], l2_groupings=[4], load_eviction_policies=['last', ''], loop_orders=[[1, 0]], num_stages=3, num_warps=16, pid_type='persistent_interleaved', range_flattens=[False, None], range_multi_buffers=[None, True], range_num_stages=[4, 0], range_unroll_factors=[4, 0], range_warp_specializes=[None, True]) @@ -14,7 +14,7 @@ helion.Config(block_sizes=[64, 64, 16], indexing=['tensor_descriptor', 'tensor_d helion.Config(block_sizes=[16, 16, 16], indexing=['pointer', 'tensor_descriptor', 'pointer'], l2_groupings=[1], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1]], num_stages=7, num_warps=1, pid_type='persistent_interleaved', range_flattens=[True, False], range_multi_buffers=[None, None], range_num_stages=[4, 0], range_unroll_factors=[1, 3], range_warp_specializes=[True, None]) --- assertExpectedJournal(TestAutotuner.test_config_fragment1) -helion.Config(block_sizes=[8, 16, 16], flatten_loops=[False], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1, 2]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) +helion.Config(block_sizes=[8, 16, 16], flatten_loops=[False], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1, 2]], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) helion.Config(block_sizes=[4, 256, 256], flatten_loops=[False], indexing=['tensor_descriptor', 'tensor_descriptor', 'tensor_descriptor'], l2_groupings=[4], load_eviction_policies=['', ''], loop_orders=[[1, 2, 0]], num_stages=4, num_warps=8, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[1], range_warp_specializes=[True]) helion.Config(block_sizes=[1, 64, 128], flatten_loops=[True], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=['first', 'last'], loop_orders=[[1, 2, 0]], num_stages=7, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) helion.Config(block_sizes=[8, 1, 16], flatten_loops=[True], indexing=['pointer', 'tensor_descriptor', 'pointer'], l2_groupings=[16], load_eviction_policies=['last', 'last'], loop_orders=[[2, 1, 0]], num_stages=2, num_warps=16, pid_type='persistent_interleaved', range_flattens=[True], range_multi_buffers=[True], range_unroll_factors=[3], range_warp_specializes=[None]) @@ -26,7 +26,7 @@ helion.Config(block_sizes=[4, 16, 16], flatten_loops=[False], indexing=['pointer helion.Config(block_sizes=[4, 1, 2], flatten_loops=[True], indexing=['tensor_descriptor', 'tensor_descriptor', 'pointer'], l2_groupings=[16], load_eviction_policies=['', 'last'], loop_orders=[[0, 2, 1]], num_stages=7, num_warps=32, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[None], range_unroll_factors=[1], range_warp_specializes=[True]) --- assertExpectedJournal(TestAutotuner.test_config_warp_specialize_unroll) -helion.Config(block_sizes=[8, 16, 16], flatten_loops=[False], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1, 2]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) +helion.Config(block_sizes=[8, 16, 16], flatten_loops=[False], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1, 2]], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) helion.Config(block_sizes=[4, 256, 256], flatten_loops=[False], indexing=['tensor_descriptor', 'tensor_descriptor', 'tensor_descriptor'], l2_groupings=[4], load_eviction_policies=['', ''], loop_orders=[[1, 2, 0]], num_stages=4, num_warps=8, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) helion.Config(block_sizes=[1, 64, 128], flatten_loops=[True], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=['first', 'last'], loop_orders=[[1, 2, 0]], num_stages=7, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) helion.Config(block_sizes=[8, 1, 16], flatten_loops=[True], indexing=['pointer', 'tensor_descriptor', 'pointer'], l2_groupings=[16], load_eviction_policies=['last', 'last'], loop_orders=[[2, 1, 0]], num_stages=2, num_warps=16, pid_type='persistent_interleaved', range_flattens=[True], range_multi_buffers=[True], range_unroll_factors=[0], range_warp_specializes=[True]) diff --git a/test/test_broadcasting.expected b/test/test_broadcasting.expected index 06157f9e8..4e2d8ad83 100644 --- a/test/test_broadcasting.expected +++ b/test/test_broadcasting.expected @@ -41,7 +41,7 @@ def broadcast_fn(a, b, *, _launcher=_default_launcher): # src[test_broadcasting.py:N]: for tile0, tile1 in hl.tile(out0.size()): # src[test_broadcasting.py:N]: out0[tile0, tile1] = a[tile0, tile1] + b[tile0, None] # src[test_broadcasting.py:N]: out1[tile0, tile1] = a[tile0, tile1] + b[None, tile1] - _launcher(_helion_broadcast_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), a, b, out0, out1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_broadcast_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), a, b, out0, out1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_broadcasting.py:N]: return out0, out1 return (out0, out1) @@ -85,7 +85,7 @@ def broadcast_fn(a, b, *, _launcher=_default_launcher): # src[test_broadcasting.py:N]: for tile0, tile1 in hl.tile(out0.size()): # src[test_broadcasting.py:N]: out0[tile0, tile1] = a[tile0, tile1] + b[tile0, None] # src[test_broadcasting.py:N]: out1[tile0, tile1] = a[tile0, tile1] + b[None, tile1] - _launcher(_helion_broadcast_fn, (triton.cdiv(512, _BLOCK_SIZE_1) * triton.cdiv(512, _BLOCK_SIZE_0),), a, b, out0, out1, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_broadcast_fn, (triton.cdiv(512, _BLOCK_SIZE_1) * triton.cdiv(512, _BLOCK_SIZE_0),), a, b, out0, out1, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_broadcasting.py:N]: return out0, out1 return (out0, out1) @@ -128,7 +128,7 @@ def broadcast_fn(a, b, *, _launcher=_default_launcher): # src[test_broadcasting.py:N]: for tile0, tile1 in hl.tile(out0.size()): # src[test_broadcasting.py:N]: out0[tile0, tile1] = a[tile0, tile1] + b[tile0, None] # src[test_broadcasting.py:N]: out1[tile0, tile1] = a[tile0, tile1] + b[None, tile1] - _launcher(_helion_broadcast_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * 512,), a, b, out0, out1, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_broadcast_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * 512,), a, b, out0, out1, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_broadcasting.py:N]: return out0, out1 return (out0, out1) @@ -171,7 +171,7 @@ def broadcast_fn(a, b, *, _launcher=_default_launcher): # src[test_broadcasting.py:N]: for tile0, tile1 in hl.tile(out0.size()): # src[test_broadcasting.py:N]: out0[tile0, tile1] = a[tile0, tile1] + b[tile0, None] # src[test_broadcasting.py:N]: out1[tile0, tile1] = a[tile0, tile1] + b[None, tile1] - _launcher(_helion_broadcast_fn, (512 * triton.cdiv(512, _BLOCK_SIZE_1),), a, b, out0, out1, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_broadcast_fn, (512 * triton.cdiv(512, _BLOCK_SIZE_1),), a, b, out0, out1, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_broadcasting.py:N]: return out0, out1 return (out0, out1) @@ -213,7 +213,7 @@ def broadcast_fn(a, b, *, _launcher=_default_launcher): # src[test_broadcasting.py:N]: for tile0, tile1 in hl.tile(out0.size()): # src[test_broadcasting.py:N]: out0[tile0, tile1] = a[tile0, tile1] + b[tile0, None] # src[test_broadcasting.py:N]: out1[tile0, tile1] = a[tile0, tile1] + b[None, tile1] - _launcher(_helion_broadcast_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), a, b, out0, out1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_broadcast_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), a, b, out0, out1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_broadcasting.py:N]: return out0, out1 return (out0, out1) @@ -266,7 +266,7 @@ def fn(a, idx1, *, _launcher=_default_launcher): # src[test_broadcasting.py:N]: out0[tile0, tile1] = a[tile0, tile1] + a[tile0, 3, None] # src[test_broadcasting.py:N]: out1[tile0, tile1] = a[tile0, tile1] + a[idx0, tile1][None, :] # src[test_broadcasting.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), a, out0, out1, out2, idx1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), a, out0, out1, out2, idx1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_broadcasting.py:N]: return out0, out1, out2 return (out0, out1, out2) @@ -303,7 +303,7 @@ def fn(a, b, *, _launcher=_default_launcher): _BLOCK_SIZE_1 = 16 # src[test_broadcasting.py:N]: for tile0, tile1 in hl.tile(a.size()): # src[test_broadcasting.py:N]: out[tile0, tile1] = a[tile0, tile1] + b[tile1] - _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), a, b, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), a, b, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_broadcasting.py:N]: return out return out @@ -332,6 +332,6 @@ def fn(a, beta, *, _launcher=_default_launcher): # src[test_broadcasting.py:N]: for tile0 in hl.tile(a.shape[0]): # src[test_broadcasting.py:N]: b = a[tile0] # src[test_broadcasting.py:N]: a[tile0] = (1 - beta) * b - _launcher(_helion_fn, (triton.cdiv(1024, _BLOCK_SIZE_0),), a, beta, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(1024, _BLOCK_SIZE_0),), a, beta, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_broadcasting.py:N]: return a return a diff --git a/test/test_closures.expected b/test/test_closures.expected index e5cba058f..cbeb28467 100644 --- a/test/test_closures.expected +++ b/test/test_closures.expected @@ -45,7 +45,7 @@ def use_globals(a, *, _launcher=_default_launcher): # src[basic_kernels.py:N]: out[tile0, tile1] = ( # src[basic_kernels.py:N]: torch.sin(torch.add(a[tile0, tile1], global_tensor[None, tile1])) # src[basic_kernels.py:N-N]: ... - _launcher(_helion_use_globals, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), a, _source_module.global_tensor, out, _source_module.global_float, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_use_globals, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), a, _source_module.global_tensor, out, _source_module.global_float, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[basic_kernels.py:N]: return out return out @@ -76,10 +76,10 @@ def sin_func_arg(a, fn, *, _launcher=_default_launcher): # src[test_closures.py:N]: out = torch.empty_like(a) out = torch.empty_like(a) # src[test_closures.py:N]: for tile in hl.tile(a.size()): - _BLOCK_SIZE_0 = 512 + _BLOCK_SIZE_0 = 256 # src[test_closures.py:N]: for tile in hl.tile(a.size()): # src[test_closures.py:N]: out[tile] = fn(torch.sin(a[tile]), tile) - _launcher(_helion_sin_func_arg, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, fn.__closure__[0].cell_contents, out, a.size(0), a.stride(0), fn.__closure__[0].cell_contents.stride(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_sin_func_arg, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, fn.__closure__[0].cell_contents, out, a.size(0), a.stride(0), fn.__closure__[0].cell_contents.stride(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_closures.py:N]: return out return out @@ -112,10 +112,10 @@ def sin_func_arg(a, fn, *, _launcher=_default_launcher): # src[test_closures.py:N]: out = torch.empty_like(a) out = torch.empty_like(a) # src[test_closures.py:N]: for tile in hl.tile(a.size()): - _BLOCK_SIZE_0 = 512 + _BLOCK_SIZE_0 = 256 # src[test_closures.py:N]: for tile in hl.tile(a.size()): # src[test_closures.py:N]: out[tile] = fn(torch.sin(a[tile]), tile) - _launcher(_helion_sin_func_arg, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, _source_module.global_tensor, out, a.size(0), _source_module.global_tensor.stride(0), a.stride(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_sin_func_arg, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, _source_module.global_tensor, out, a.size(0), _source_module.global_tensor.stride(0), a.stride(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_closures.py:N]: return out return out @@ -147,10 +147,10 @@ def sin_func_arg(a, fn, *, _launcher=_default_launcher): # src[test_closures.py:N]: out = torch.empty_like(a) out = torch.empty_like(a) # src[test_closures.py:N]: for tile in hl.tile(a.size()): - _BLOCK_SIZE_0 = 512 + _BLOCK_SIZE_0 = 256 # src[test_closures.py:N]: for tile in hl.tile(a.size()): # src[test_closures.py:N]: out[tile] = fn(torch.sin(a[tile]), tile) - _launcher(_helion_sin_func_arg, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, out, a.size(0), a.stride(0), out.stride(0), _global_source0.global_float, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_sin_func_arg, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, out, a.size(0), a.stride(0), out.stride(0), _global_source0.global_float, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_closures.py:N]: return out return out @@ -183,10 +183,10 @@ def sin_func_arg(a, fn, *, _launcher=_default_launcher): # src[test_closures.py:N]: out = torch.empty_like(a) out = torch.empty_like(a) # src[test_closures.py:N]: for tile in hl.tile(a.size()): - _BLOCK_SIZE_0 = 512 + _BLOCK_SIZE_0 = 256 # src[test_closures.py:N]: for tile in hl.tile(a.size()): # src[test_closures.py:N]: out[tile] = fn(torch.sin(a[tile]), tile) - _launcher(_helion_sin_func_arg, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, fn.__closure__[0].cell_contents.__closure__[0].cell_contents, out, a.size(0), a.stride(0), fn.__closure__[0].cell_contents.__closure__[0].cell_contents.stride(0), out.stride(0), fn.__closure__[1].cell_contents, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_sin_func_arg, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, fn.__closure__[0].cell_contents.__closure__[0].cell_contents, out, a.size(0), a.stride(0), fn.__closure__[0].cell_contents.__closure__[0].cell_contents.stride(0), out.stride(0), fn.__closure__[1].cell_contents, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_closures.py:N]: return out return out @@ -215,9 +215,9 @@ def call_func_arg_on_host(a, alloc, *, _launcher=_default_launcher): # src[test_closures.py:N]: out = alloc(a) out = alloc(a) # src[test_closures.py:N]: for tile in hl.tile(a.size()): - _BLOCK_SIZE_0 = 512 + _BLOCK_SIZE_0 = 256 # src[test_closures.py:N]: for tile in hl.tile(a.size()): # src[test_closures.py:N]: out[tile] = a[tile].sin() - _launcher(_helion_call_func_arg_on_host, (triton.cdiv(512, _BLOCK_SIZE_0),), a, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_call_func_arg_on_host, (triton.cdiv(512, _BLOCK_SIZE_0),), a, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_closures.py:N]: return out return out diff --git a/test/test_constexpr.expected b/test/test_constexpr.expected index 05edc1f4d..4f2e0b247 100644 --- a/test/test_constexpr.expected +++ b/test/test_constexpr.expected @@ -35,7 +35,7 @@ def fn(x: torch.Tensor, v: hl.constexpr, *, _launcher=_default_launcher): _BLOCK_SIZE_1 = 32 # src[test_constexpr.py:N]: for tile in hl.tile(x.size()): # src[test_constexpr.py:N]: out[tile] = torch.sigmoid(x[tile] + v) - _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_constexpr.py:N]: return out return out @@ -72,7 +72,7 @@ def fn(x: torch.Tensor, v: float, *, _launcher=_default_launcher): _BLOCK_SIZE_1 = 32 # src[test_constexpr.py:N]: for tile in hl.tile(x.size()): # src[test_constexpr.py:N]: out[tile] = torch.sigmoid(x[tile] + v) - _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_constexpr.py:N]: return out return out @@ -111,7 +111,7 @@ def fn(x: torch.Tensor, s: hl.constexpr, *, _launcher=_default_launcher): _BLOCK_SIZE_1 = 16 # src[test_constexpr.py:N]: for tile_b, tile_s in hl.tile([b, s]): # src[test_constexpr.py:N]: out[tile_b, tile_s] = x[tile_b].view(-1, 1).expand(tile_b, tile_s) - _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(16, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(16, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_constexpr.py:N]: return out return out @@ -149,7 +149,7 @@ def fn(x: torch.Tensor, mode: str, *, _launcher=_default_launcher): # src[test_constexpr.py:N]: if mode == "add": # src[test_constexpr.py:N]: out[tile] = x[tile] + 1.0 # src[test_constexpr.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_constexpr.py:N]: return out return out @@ -187,7 +187,7 @@ def fn(x: torch.Tensor, mode: str, *, _launcher=_default_launcher): # src[test_constexpr.py:N]: if mode == "add": # src[test_constexpr.py:N]: out[tile] = x[tile] + 1.0 # src[test_constexpr.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_constexpr.py:N]: return out return out @@ -223,6 +223,6 @@ def fn(x: torch.Tensor, mode: str, *, _launcher=_default_launcher): # src[test_constexpr.py:N]: if mode == "add": # src[test_constexpr.py:N]: out[tile] = x[tile] + 1.0 # src[test_constexpr.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_constexpr.py:N]: return out return out diff --git a/test/test_control_flow.expected b/test/test_control_flow.expected index 33d9393d7..9df928683 100644 --- a/test/test_control_flow.expected +++ b/test/test_control_flow.expected @@ -33,7 +33,7 @@ def fn(x, *, _launcher=_default_launcher): # src[test_control_flow.py:N]: if 3 < v < 7: # src[test_control_flow.py:N]: out[tile] = torch.sigmoid(x[tile]) # src[test_control_flow.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_control_flow.py:N]: return out return out @@ -65,7 +65,7 @@ def fn(x, *, _launcher=_default_launcher): # src[test_control_flow.py:N]: if 3 < v < 7: # src[test_control_flow.py:N]: out[tile] = torch.sigmoid(x[tile]) # src[test_control_flow.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(262144, _BLOCK_SIZE_0_1), 1, 1), x, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(262144, _BLOCK_SIZE_0_1), 1, 1), x, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=1) # src[test_control_flow.py:N]: return out return out @@ -141,7 +141,7 @@ def mul_relu_block_backward_kernel(x: torch.Tensor, y: torch.Tensor, dz: torch.T # src[test_control_flow.py:N]: # Get input tiles # src[test_control_flow.py:N]: x_tile = x[tile_i, tile_j] # src[test_control_flow.py:N-N]: ... - _launcher(_helion_mul_relu_block_backward_kernel, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, dz, dx, dy, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_mul_relu_block_backward_kernel, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, dz, dx, dy, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_control_flow.py:N]: if use_atomics: # src[test_control_flow.py:N]: return dx, dy if True: @@ -203,7 +203,7 @@ def fn(x, v, *, _launcher=_default_launcher): # src[test_control_flow.py:N]: if 3 < v < 7: # src[test_control_flow.py:N]: out[tile] = torch.sigmoid(x[tile]) # src[test_control_flow.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), x, out, v, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), x, out, v, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_control_flow.py:N]: return out return out @@ -251,6 +251,6 @@ def fn(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): # src[test_control_flow.py:N]: # Since `y[idx]` is a scalar, comparing it against 0 will also create a scalar. # src[test_control_flow.py:N]: if y[idx] != 0: # src[test_control_flow.py:N-N]: ... - _launcher(_helion_fn, (4,), y, x, output, num_warps=4, num_stages=2) + _launcher(_helion_fn, (4,), y, x, output, num_warps=4, num_stages=1) # src[test_control_flow.py:N]: return output return output diff --git a/test/test_dot.expected b/test/test_dot.expected index 5f048a263..44c178f9a 100644 --- a/test/test_dot.expected +++ b/test/test_dot.expected @@ -54,7 +54,7 @@ def mm_small_dims(x: torch.Tensor, y: torch.Tensor, mm_func: Callable[[torch.Ten # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_mm_small_dims, (triton.cdiv(5, _BLOCK_SIZE_0) * triton.cdiv(7, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_mm_small_dims, (triton.cdiv(5, _BLOCK_SIZE_0) * triton.cdiv(7, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -111,7 +111,7 @@ def mm_reshape_k_2(x: torch.Tensor, y: torch.Tensor, mm_func: Callable[[torch.Te # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_dot.py:N]: # K is 2; don't tile it — slice with ':' # src[test_dot.py:N-N]: ... - _launcher(_helion_mm_reshape_k_2, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x_reshaped, y_reshaped, out, out.stride(0), out.stride(1), x_reshaped.stride(0), x_reshaped.stride(1), y_reshaped.stride(0), y_reshaped.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _RDIM_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_mm_reshape_k_2, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x_reshaped, y_reshaped, out, out.stride(0), out.stride(1), x_reshaped.stride(0), x_reshaped.stride(1), y_reshaped.stride(0), y_reshaped.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _RDIM_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -166,7 +166,7 @@ def mm_reshape_m_1(x: torch.Tensor, y: torch.Tensor, mm_func: Callable[[torch.Te # src[test_dot.py:N]: acc = hl.zeros([1, tile_n], dtype=torch.float32) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_mm_reshape_m_1, (triton.cdiv(n, _BLOCK_SIZE_0),), x_reshaped, y, out, out.stride(1), x_reshaped.stride(1), y.stride(0), y.stride(1), n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_mm_reshape_m_1, (triton.cdiv(n, _BLOCK_SIZE_0),), x_reshaped, y, out, out.stride(1), x_reshaped.stride(1), y.stride(0), y.stride(1), n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out.view(n) # Reshape back to vector return out.view(n) @@ -225,7 +225,7 @@ def mm_small_dims(x: torch.Tensor, y: torch.Tensor, mm_func: Callable[[torch.Ten # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_mm_small_dims, (triton.cdiv(5, _BLOCK_SIZE_0) * triton.cdiv(7, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_mm_small_dims, (triton.cdiv(5, _BLOCK_SIZE_0) * triton.cdiv(7, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -292,7 +292,7 @@ def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_defaul # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -359,7 +359,7 @@ def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_defaul # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -413,7 +413,7 @@ def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -467,7 +467,7 @@ def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -521,7 +521,7 @@ def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -575,7 +575,7 @@ def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -642,7 +642,7 @@ def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_defaul # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -709,7 +709,7 @@ def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_defaul # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -763,7 +763,7 @@ def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -817,7 +817,7 @@ def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -871,7 +871,7 @@ def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -925,7 +925,7 @@ def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -992,7 +992,7 @@ def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_defaul # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -1059,7 +1059,7 @@ def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_defaul # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -1113,7 +1113,7 @@ def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -1167,7 +1167,7 @@ def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -1221,7 +1221,7 @@ def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -1275,7 +1275,7 @@ def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -1342,7 +1342,7 @@ def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_defaul # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -1409,7 +1409,7 @@ def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_defaul # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -1463,7 +1463,7 @@ def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -1517,7 +1517,7 @@ def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -1571,7 +1571,7 @@ def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -1625,7 +1625,7 @@ def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -1692,7 +1692,7 @@ def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_defaul # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -1759,7 +1759,7 @@ def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_defaul # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -1813,7 +1813,7 @@ def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -1867,7 +1867,7 @@ def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -1921,7 +1921,7 @@ def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -1975,7 +1975,7 @@ def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -2042,7 +2042,7 @@ def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_defaul # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -2109,7 +2109,7 @@ def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_defaul # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_no_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -2163,7 +2163,7 @@ def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -2217,7 +2217,7 @@ def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_dot_kernel_acc_arg, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -2278,7 +2278,7 @@ def mm_small_dims(x: torch.Tensor, y: torch.Tensor, mm_func: Callable[[torch.Ten # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_mm_small_dims, (triton.cdiv(5, _BLOCK_SIZE_0) * triton.cdiv(7, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_mm_small_dims, (triton.cdiv(5, _BLOCK_SIZE_0) * triton.cdiv(7, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out @@ -2339,6 +2339,6 @@ def mm_small_dims(x: torch.Tensor, y: torch.Tensor, mm_func: Callable[[torch.Ten # src[test_dot.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_dot.py:N]: for tile_k in hl.tile(k): # src[test_dot.py:N-N]: ... - _launcher(_helion_mm_small_dims, (triton.cdiv(5, _BLOCK_SIZE_0) * triton.cdiv(7, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_mm_small_dims, (triton.cdiv(5, _BLOCK_SIZE_0) * triton.cdiv(7, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_dot.py:N]: return out return out diff --git a/test/test_eviction_policy.expected b/test/test_eviction_policy.expected index a03a14c88..9a822c151 100644 --- a/test/test_eviction_policy.expected +++ b/test/test_eviction_policy.expected @@ -31,7 +31,7 @@ def kernel_with_eviction(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default # src[test_eviction_policy.py:N]: val_x = hl.load(x, [tile]) # No eviction policy # src[test_eviction_policy.py:N]: val_y = hl.load(y, [tile]) # Should get evict_last # src[test_eviction_policy.py:N-N]: ... - _launcher(_helion_kernel_with_eviction, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_with_eviction, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_eviction_policy.py:N]: return out return out @@ -66,7 +66,7 @@ def kernel_with_eviction(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default # src[test_eviction_policy.py:N]: val_x = hl.load(x, [tile]) # No eviction policy # src[test_eviction_policy.py:N]: val_y = hl.load(y, [tile]) # Should get evict_last # src[test_eviction_policy.py:N-N]: ... - _launcher(_helion_kernel_with_eviction, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_with_eviction, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_eviction_policy.py:N]: return out return out @@ -101,7 +101,7 @@ def kernel_with_eviction(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default # src[test_eviction_policy.py:N]: val_x = hl.load(x, [tile]) # No eviction policy # src[test_eviction_policy.py:N]: val_y = hl.load(y, [tile]) # Should get evict_last # src[test_eviction_policy.py:N-N]: ... - _launcher(_helion_kernel_with_eviction, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_with_eviction, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_eviction_policy.py:N]: return out return out @@ -135,7 +135,7 @@ def kernel_with_override(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default # src[test_eviction_policy.py:N]: # Explicit eviction_policy should override tunable # src[test_eviction_policy.py:N]: val_x = hl.load(x, [tile], eviction_policy="evict_last") # src[test_eviction_policy.py:N-N]: ... - _launcher(_helion_kernel_with_override, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_with_override, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_eviction_policy.py:N]: return out return out @@ -170,7 +170,7 @@ def kernel_with_override(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default # src[test_eviction_policy.py:N]: # Explicit eviction_policy should override tunable # src[test_eviction_policy.py:N]: val_x = hl.load(x, [tile], eviction_policy="evict_last") # src[test_eviction_policy.py:N-N]: ... - _launcher(_helion_kernel_with_override, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_with_override, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_eviction_policy.py:N]: return out return out @@ -205,7 +205,7 @@ def kernel_with_override(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default # src[test_eviction_policy.py:N]: # Explicit eviction_policy should override tunable # src[test_eviction_policy.py:N]: val_x = hl.load(x, [tile], eviction_policy="evict_last") # src[test_eviction_policy.py:N-N]: ... - _launcher(_helion_kernel_with_override, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_with_override, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_eviction_policy.py:N]: return out return out @@ -235,7 +235,7 @@ def copy_with_eviction(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_eviction_policy.py:N]: for tile in hl.tile(x.size(0)): # src[test_eviction_policy.py:N]: val = hl.load(x, [tile], eviction_policy="evict_last") # src[test_eviction_policy.py:N]: out[tile] = val - _launcher(_helion_copy_with_eviction, (triton.cdiv(128, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_copy_with_eviction, (triton.cdiv(128, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_eviction_policy.py:N]: return out return out @@ -266,7 +266,7 @@ def copy_with_eviction(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_eviction_policy.py:N]: for tile in hl.tile(x.size(0)): # src[test_eviction_policy.py:N]: val = hl.load(x, [tile], eviction_policy="evict_last") # src[test_eviction_policy.py:N]: out[tile] = val - _launcher(_helion_copy_with_eviction, (triton.cdiv(128, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_copy_with_eviction, (triton.cdiv(128, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_eviction_policy.py:N]: return out return out @@ -297,7 +297,7 @@ def copy_with_eviction(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_eviction_policy.py:N]: for tile in hl.tile(x.size(0)): # src[test_eviction_policy.py:N]: val = hl.load(x, [tile], eviction_policy="evict_last") # src[test_eviction_policy.py:N]: out[tile] = val - _launcher(_helion_copy_with_eviction, (triton.cdiv(128, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_copy_with_eviction, (triton.cdiv(128, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_eviction_policy.py:N]: return out return out @@ -334,7 +334,7 @@ def kernel_multiple_loads(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, *, # src[test_eviction_policy.py:N]: val_x = hl.load(x, [tile]) # evict_first # src[test_eviction_policy.py:N]: val_y = hl.load(y, [tile]) # evict_last # src[test_eviction_policy.py:N-N]: ... - _launcher(_helion_kernel_multiple_loads, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, z, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_multiple_loads, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, z, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_eviction_policy.py:N]: return out return out @@ -372,7 +372,7 @@ def kernel_multiple_loads(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, *, # src[test_eviction_policy.py:N]: val_x = hl.load(x, [tile]) # evict_first # src[test_eviction_policy.py:N]: val_y = hl.load(y, [tile]) # evict_last # src[test_eviction_policy.py:N-N]: ... - _launcher(_helion_kernel_multiple_loads, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, z, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_multiple_loads, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, z, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_eviction_policy.py:N]: return out return out @@ -410,6 +410,6 @@ def kernel_multiple_loads(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, *, # src[test_eviction_policy.py:N]: val_x = hl.load(x, [tile]) # evict_first # src[test_eviction_policy.py:N]: val_y = hl.load(y, [tile]) # evict_last # src[test_eviction_policy.py:N-N]: ... - _launcher(_helion_kernel_multiple_loads, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, z, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_multiple_loads, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, z, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_eviction_policy.py:N]: return out return out diff --git a/test/test_examples.expected b/test/test_examples.expected index 247e24655..30f3965f3 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -44,7 +44,7 @@ def add(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0_1 = 128 # src[add.py:N]: for tile in hl.tile(out.size()): # src[add.py:N]: out[tile] = x[tile] + y[tile] - _launcher(_helion_add, (triton.cdiv(262144, _BLOCK_SIZE_0_1), 1, 1), x, y, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=2) + _launcher(_helion_add, (triton.cdiv(262144, _BLOCK_SIZE_0_1), 1, 1), x, y, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=1) # src[add.py:N]: return out return out @@ -196,7 +196,7 @@ def addmm_bwd(grad_out: Tensor, bias: Tensor, mat1: Tensor, mat2: Tensor, alpha: # src[matmul.py:N]: acc2 = hl.zeros([tile_k2, tile_n2], dtype=torch.float32) # src[matmul.py:N]: for tile_m2 in hl.tile(m): # src[matmul.py:N-N]: ... - _launcher(_helion_addmm_bwd, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1) + triton.cdiv(128, _BLOCK_SIZE_2) * triton.cdiv(128, _BLOCK_SIZE_3) + triton.cdiv(128, _BLOCK_SIZE_5) * triton.cdiv(128, _BLOCK_SIZE_6),), grad_out, grad_input, mat2, grad_mat1, mat1, grad_mat2, beta, alpha, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, _BLOCK_SIZE_6, _BLOCK_SIZE_7, num_warps=4, num_stages=2) + _launcher(_helion_addmm_bwd, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1) + triton.cdiv(128, _BLOCK_SIZE_2) * triton.cdiv(128, _BLOCK_SIZE_3) + triton.cdiv(128, _BLOCK_SIZE_5) * triton.cdiv(128, _BLOCK_SIZE_6),), grad_out, grad_input, mat2, grad_mat1, mat1, grad_mat2, beta, alpha, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, _BLOCK_SIZE_6, _BLOCK_SIZE_7, num_warps=4, num_stages=1) # src[matmul.py:N]: return grad_input, grad_mat1, grad_mat2 return (grad_input, grad_mat1, grad_mat2) @@ -461,7 +461,7 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la # src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32) # src[attention.py:N]: l_i = torch.full_like(m_i, 1.0) # src[attention.py:N-N]: ... - _launcher(_helion_attention, (q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, q_in.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_0, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_attention, (q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, q_in.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_0, _BLOCK_SIZE_3, num_warps=4, num_stages=1) # src[attention.py:N]: return out.view(q_in.size()) return out.view(q_in.size()) @@ -736,7 +736,7 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la # src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32) # src[attention.py:N]: l_i = torch.full_like(m_i, 1.0) # src[attention.py:N-N]: ... - _launcher(_helion_attention, (32 * triton.cdiv(512, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_0, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_attention, (32 * triton.cdiv(512, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_0, _BLOCK_SIZE_3, num_warps=4, num_stages=1) # src[attention.py:N]: return out.view(q_in.size()) return out.view(q_in.size()) @@ -803,7 +803,7 @@ def _bf16xint16_gemm(x: Tensor, w: Tensor, *, _launcher=_default_launcher): # src[bf16xint16_gemm.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[bf16xint16_gemm.py:N]: for tile_k in hl.tile(K): # src[bf16xint16_gemm.py:N-N]: ... - _launcher(_helion__bf16xint16_gemm, (triton.cdiv(65536, _BLOCK_SIZE_0) * triton.cdiv(1280, _BLOCK_SIZE_1),), x, w, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion__bf16xint16_gemm, (triton.cdiv(65536, _BLOCK_SIZE_0) * triton.cdiv(1280, _BLOCK_SIZE_1),), x, w, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[bf16xint16_gemm.py:N]: return out return out @@ -870,7 +870,7 @@ def _int16xbf16_gemm(x: Tensor, w: Tensor, *, _launcher=_default_launcher): # src[bf16xint16_gemm.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[bf16xint16_gemm.py:N]: for tile_k in hl.tile(K): # src[bf16xint16_gemm.py:N-N]: ... - _launcher(_helion__int16xbf16_gemm, (triton.cdiv(65536, _BLOCK_SIZE_0) * triton.cdiv(1280, _BLOCK_SIZE_1),), x, w, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion__int16xbf16_gemm, (triton.cdiv(65536, _BLOCK_SIZE_0) * triton.cdiv(1280, _BLOCK_SIZE_1),), x, w, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[bf16xint16_gemm.py:N]: return out return out @@ -949,7 +949,7 @@ def bmm(A: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher): # src[bmm.py:N]: acc = hl.zeros([tile_b, tile_m, tile_n], dtype=torch.float32) # src[bmm.py:N]: for tile_k in hl.tile(k): # src[bmm.py:N-N]: ... - _launcher(_helion_bmm, (triton.cdiv(16, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1) * triton.cdiv(1024, _BLOCK_SIZE_2),), A, B, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_bmm, (triton.cdiv(16, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1) * triton.cdiv(1024, _BLOCK_SIZE_2),), A, B, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=1) # src[bmm.py:N]: return out return out @@ -1026,7 +1026,7 @@ def concat2d_dim1(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launch # src[concatenate.py:N]: # Most masking is automatic in helion, but tile1 spans both x and y we need to do some manual masking # src[concatenate.py:N]: x_part = hl.load( # src[concatenate.py:N-N]: ... - _launcher(_helion_concat2d_dim1, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(1012, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_concat2d_dim1, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(1012, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[concatenate.py:N]: return out return out @@ -1104,7 +1104,7 @@ def concat2d_dim1(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launch # src[concatenate.py:N]: # Most masking is automatic in helion, but tile1 spans both x and y we need to do some manual masking # src[concatenate.py:N]: x_part = hl.load( # src[concatenate.py:N-N]: ... - _launcher(_helion_concat2d_dim1, (triton.cdiv(222, _BLOCK_SIZE_0) * triton.cdiv(251, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_concat2d_dim1, (triton.cdiv(222, _BLOCK_SIZE_0) * triton.cdiv(251, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[concatenate.py:N]: return out return out @@ -1185,7 +1185,7 @@ def cross_entropy(logits: torch.Tensor, labels: torch.Tensor, *, _launcher=_defa # src[cross_entropy.py:N]: # Get data for this tile # src[cross_entropy.py:N]: labels_tile = labels[tile_n] # [tile_size] # src[cross_entropy.py:N-N]: ... - _launcher(_helion_cross_entropy, (128,), labels, logits_flat, logits, losses, _RDIM_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_cross_entropy, (128,), labels, logits_flat, logits, losses, _RDIM_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[cross_entropy.py:N]: return losses.mean() return losses.mean() @@ -1236,7 +1236,7 @@ def embedding(x: torch.Tensor, weight: torch.Tensor, *, _launcher=_default_launc _BLOCK_SIZE_1 = 64 # src[embedding.py:N]: for tile_b, tile_e in hl.tile([x_flat.size(0), embedding_dim]): # src[embedding.py:N]: out[tile_b, tile_e] = weight[x_flat[tile_b], tile_e] - _launcher(_helion_embedding, (triton.cdiv(1024, _BLOCK_SIZE_0), triton.cdiv(256, _BLOCK_SIZE_1)), x_flat, weight, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_embedding, (triton.cdiv(1024, _BLOCK_SIZE_0), triton.cdiv(256, _BLOCK_SIZE_1)), x_flat, weight, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[embedding.py:N]: return out.view(*x.size(), embedding_dim) return out.view(*x.size(), embedding_dim) @@ -1288,7 +1288,7 @@ def embedding(x: torch.Tensor, weight: torch.Tensor, *, _launcher=_default_launc _BLOCK_SIZE_1 = 256 # src[embedding.py:N]: for tile_b, tile_e in hl.tile([x_flat.size(0), embedding_dim]): # src[embedding.py:N]: out[tile_b, tile_e] = weight[x_flat[tile_b], tile_e] - _launcher(_helion_embedding, (1024 * triton.cdiv(256, _BLOCK_SIZE_1),), x_flat, weight, out, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_embedding, (1024 * triton.cdiv(256, _BLOCK_SIZE_1),), x_flat, weight, out, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[embedding.py:N]: return out.view(*x.size(), embedding_dim) return out.view(*x.size(), embedding_dim) @@ -1510,7 +1510,7 @@ def fp8_attention_kernel(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, batc # src[fp8_attention.py:N]: # Calculate batch and head indices # src[fp8_attention.py:N]: b = bh // heads # src[fp8_attention.py:N-N]: ... - _launcher(_helion_fp8_attention_kernel, (8,), q, k, v, out, out.stride(0), heads, _RDIM_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_fp8_attention_kernel, (8,), q, k, v, out, out.stride(0), heads, _RDIM_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=1) # src[fp8_attention.py:N]: return out return out @@ -1672,7 +1672,7 @@ def fused_linear_jsd_kernel(beta: float, ignore_index: int, temperature: float, # src[fused_linear_jsd.py:N]: student_prob = torch.log_softmax(student_logits[batch, :] / temperature, dim=-1) # src[fused_linear_jsd.py:N]: teacher_prob = torch.log_softmax(teacher_logits[batch, :] / temperature, dim=-1) # src[fused_linear_jsd.py:N-N]: ... - _launcher(_helion_fused_linear_jsd_kernel, (triton.cdiv(64, _BLOCK_SIZE_0),), student_logits, teacher_logits, loss, temperature, beta, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fused_linear_jsd_kernel, (triton.cdiv(64, _BLOCK_SIZE_0),), student_logits, teacher_logits, loss, temperature, beta, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[fused_linear_jsd.py:N]: return (loss / student_logits.shape[0]).sum() return (loss / student_logits.shape[0]).sum() @@ -1871,7 +1871,7 @@ def grouped_gemm_jagged(A_packed: torch.Tensor, B: torch.Tensor, group_offsets: # src[grouped_gemm.py:N]: start = group_offsets[g] # src[grouped_gemm.py:N]: end = group_offsets[g + 1] # src[grouped_gemm.py:N-N]: ... - _launcher(_helion_grouped_gemm_jagged, (G,), group_offsets, A_packed, B, out, A_packed.stride(0), A_packed.stride(1), B.stride(0), B.stride(1), group_offsets.stride(0), out.stride(0), out.stride(1), N, K, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_grouped_gemm_jagged, (G,), group_offsets, A_packed, B, out, A_packed.stride(0), A_packed.stride(1), B.stride(0), B.stride(1), group_offsets.stride(0), out.stride(0), out.stride(1), N, K, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=1) # src[grouped_gemm.py:N]: return out return out @@ -2098,7 +2098,7 @@ def grouped_gemm_jagged_persistent(A_packed: torch.Tensor, B: torch.Tensor, grou # src[grouped_gemm.py:N]: # Persistent thread pattern: each worker processes tiles across all groups # src[grouped_gemm.py:N]: # using strided/interleaved assignment for load balancing. # src[grouped_gemm.py:N-N]: ... - _launcher(_helion_grouped_gemm_jagged_persistent, (num_workers,), group_offsets, A_packed, B, out, A_packed.stride(0), A_packed.stride(1), B.stride(0), B.stride(1), group_offsets.stride(0), out.stride(0), out.stride(1), num_workers, G, N, K, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_5, num_warps=4, num_stages=2) + _launcher(_helion_grouped_gemm_jagged_persistent, (num_workers,), group_offsets, A_packed, B, out, A_packed.stride(0), A_packed.stride(1), B.stride(0), B.stride(1), group_offsets.stride(0), out.stride(0), out.stride(1), num_workers, G, N, K, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_5, num_warps=4, num_stages=1) # src[grouped_gemm.py:N]: return out return out @@ -2321,7 +2321,7 @@ def jagged_dense_add_2d(x_data: torch.Tensor, x_offsets: torch.Tensor, y: torch. # src[jagged_dense_add.py:N]: starts = x_offsets[tile0] # src[jagged_dense_add.py:N]: ends = x_offsets[tile0.index + 1] # src[jagged_dense_add.py:N-N]: ... - _launcher(_helion_jagged_dense_add_2d, (triton.cdiv(500, _BLOCK_SIZE_0),), x_offsets, x_data, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_jagged_dense_add_2d, (triton.cdiv(500, _BLOCK_SIZE_0),), x_offsets, x_data, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[jagged_dense_add.py:N]: return out return out @@ -2478,7 +2478,7 @@ def _helion_jagged_attention_kernel(max_seq_len: int, alpha: float, q: torch.Ten # src[jagged_hstu_attn.py:N]: [num_batches, num_heads, max_seq_len], block_size=[1, 1, None] # src[jagged_hstu_attn.py:N]: ): # src[jagged_hstu_attn.py:N-N]: ... - _launcher(_helion__helion_jagged_attention_kernel, (4 * 8 * triton.cdiv(max_seq_len, _BLOCK_SIZE_2),), seq_offsets, q, k, v, out, max_seq_len, alpha, scale, _BLOCK_SIZE_2, _RDIM_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=2) + _launcher(_helion__helion_jagged_attention_kernel, (4 * 8 * triton.cdiv(max_seq_len, _BLOCK_SIZE_2),), seq_offsets, q, k, v, out, max_seq_len, alpha, scale, _BLOCK_SIZE_2, _RDIM_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=1) # src[jagged_hstu_attn.py:N]: return out return out @@ -2656,7 +2656,7 @@ def _helion_jagged_layer_norm_kernel(x_offsets, x_flat, out_flat, eps, _BLOCK_SI v_34 = var_acc / v_33 # src[jagged_layer_norm.py:N]: rstd = torch.rsqrt(variance + eps) v_35 = v_34 + eps - v_36 = libdevice.rsqrt(v_35) + v_36 = tl.rsqrt(v_35) # src[jagged_layer_norm.py:N]: for tile_m in hl.tile(M): # src[jagged_layer_norm.py:N]: for tile_k in hl.tile(0, max_seq_len): # src[jagged_layer_norm.py:N]: # Compute indices into x_values @@ -2796,7 +2796,7 @@ def jagged_layer_norm_kernel(x_values: torch.Tensor, x_offsets: torch.Tensor, ep # src[jagged_layer_norm.py:N]: # Get sequence boundaries for this tile # src[jagged_layer_norm.py:N]: starts = x_offsets[tile_b] # src[jagged_layer_norm.py:N-N]: ... - _launcher(_helion_jagged_layer_norm_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x_offsets, x_flat, out_flat, eps, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, _BLOCK_SIZE_6, num_warps=4, num_stages=2) + _launcher(_helion_jagged_layer_norm_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x_offsets, x_flat, out_flat, eps, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, _BLOCK_SIZE_6, num_warps=4, num_stages=1) # src[jagged_layer_norm.py:N]: return out.reshape(total_L, M) return out.reshape(total_L, M) @@ -2944,7 +2944,7 @@ def jagged_mean_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_ # src[jagged_mean.py:N]: starts = x_offsets[tile_b] # src[jagged_mean.py:N]: ends = x_offsets[tile_b.index + 1] # src[jagged_mean.py:N-N]: ... - _launcher(_helion_jagged_mean_kernel, (triton.cdiv(32, _BLOCK_SIZE_0),), x_offsets, x_feature_counts, x_flat, out, out.stride(0), max_M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_jagged_mean_kernel, (triton.cdiv(32, _BLOCK_SIZE_0),), x_offsets, x_feature_counts, x_flat, out, out.stride(0), max_M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[jagged_mean.py:N]: return out return out @@ -3165,7 +3165,7 @@ def jagged_softmax_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, *, _lau # src[jagged_softmax.py:N]: starts = x_offsets[tile_b] # src[jagged_softmax.py:N]: ends = x_offsets[tile_b.index + 1] # src[jagged_softmax.py:N-N]: ... - _launcher(_helion_jagged_softmax_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x_offsets, x_flat, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_jagged_softmax_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x_offsets, x_flat, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=1) # src[jagged_softmax.py:N]: return out.reshape(N, M) return out.reshape(N, M) @@ -3286,7 +3286,7 @@ def jagged_sum_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, *, _launche # src[jagged_sum.py:N]: starts = x_offsets[tile_b] # src[jagged_sum.py:N]: ends = x_offsets[tile_b.index + 1] # src[jagged_sum.py:N-N]: ... - _launcher(_helion_jagged_sum_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x_offsets, x_flat, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_jagged_sum_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x_offsets, x_flat, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[jagged_sum.py:N]: return out return out @@ -4136,12 +4136,12 @@ def low_mem_dropout(p: float, x: torch.Tensor, seed: int, *, _launcher=_default_ # src[low_mem_dropout.py:N]: out_flat = torch.empty_like(x_flat) out_flat = torch.empty_like(x_flat) # src[low_mem_dropout.py:N]: for tidx in hl.tile(n): - _BLOCK_SIZE_0 = 1024 + _BLOCK_SIZE_0 = 256 # src[low_mem_dropout.py:N]: for tidx in hl.tile(n): # src[low_mem_dropout.py:N]: xi = x_flat[tidx].to(torch.float32) # src[low_mem_dropout.py:N]: r = hl.rand([tidx], seed=seed) # src[low_mem_dropout.py:N-N]: ... - _launcher(_helion_low_mem_dropout, (triton.cdiv(n, _BLOCK_SIZE_0),), x_flat, out_flat, out_flat.stride(0), x_flat.stride(0), n, seed, p, scale, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_low_mem_dropout, (triton.cdiv(n, _BLOCK_SIZE_0),), x_flat, out_flat, out_flat.stride(0), x_flat.stride(0), n, seed, p, scale, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[low_mem_dropout.py:N]: return out_flat.view_as(x) return out_flat.view_as(x) @@ -4215,7 +4215,7 @@ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]] # src[matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[matmul.py:N]: for tile_k in hl.tile(k): # src[matmul.py:N-N]: ... - _launcher(_helion_matmul, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_matmul, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[matmul.py:N]: return out return out @@ -4342,7 +4342,7 @@ def matmul_bwd(grad_out: Tensor, mat1: Tensor, mat2: Tensor, *, _launcher=_defau # src[matmul.py:N]: acc2 = hl.zeros([tile_k2, tile_n2], dtype=torch.float32) # src[matmul.py:N]: for tile_m2 in hl.tile(m): # src[matmul.py:N-N]: ... - _launcher(_helion_matmul_bwd, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1) + triton.cdiv(128, _BLOCK_SIZE_3) * triton.cdiv(128, _BLOCK_SIZE_4),), grad_out, mat2, grad_mat1, mat1, grad_mat2, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, num_warps=4, num_stages=2) + _launcher(_helion_matmul_bwd, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1) + triton.cdiv(128, _BLOCK_SIZE_3) * triton.cdiv(128, _BLOCK_SIZE_4),), grad_out, mat2, grad_mat1, mat1, grad_mat2, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, num_warps=4, num_stages=1) # src[matmul.py:N]: return grad_mat1, grad_mat2 return (grad_mat1, grad_mat2) @@ -4396,7 +4396,7 @@ def _helion_matmul_layernorm(x, y, weight, bias, out, bias_stride_0, out_stride_ # src[matmul_layernorm.py:N]: normalized = centered * torch.rsqrt(var + eps) v_7 = 1e-05 v_8 = v_6 + v_7 - v_9 = libdevice.rsqrt(v_8) + v_9 = tl.rsqrt(v_8) v_10 = v_3 * v_9 # src[matmul_layernorm.py:N]: acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32)) load_2 = tl.load(weight + indices_5 * weight_stride_0, mask_2, other=0) @@ -4448,7 +4448,7 @@ def matmul_layernorm(x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor, bia # src[matmul_layernorm.py:N]: acc = hl.zeros([tile_m, n], dtype=torch.float32) # src[matmul_layernorm.py:N]: for tile_k in hl.tile(k): # src[matmul_layernorm.py:N-N]: ... - _launcher(_helion_matmul_layernorm, (triton.cdiv(m, _BLOCK_SIZE_0),), x, y, weight, bias, out, bias.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, k, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_matmul_layernorm, (triton.cdiv(m, _BLOCK_SIZE_0),), x, y, weight, bias, out, bias.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, k, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[matmul_layernorm.py:N]: return out return out @@ -4498,7 +4498,7 @@ def _helion_matmul_layernorm(x, y, weight, bias, out, _BLOCK_SIZE_0: tl.constexp # src[matmul_layernorm.py:N]: normalized = centered * torch.rsqrt(var + eps) v_7 = 1e-05 v_8 = v_6 + v_7 - v_9 = libdevice.rsqrt(v_8) + v_9 = tl.rsqrt(v_8) v_10 = v_3 * v_9 # src[matmul_layernorm.py:N]: acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32)) load_2 = tl.load(weight + indices_5 * 1, mask_2, other=0) @@ -4550,7 +4550,7 @@ def matmul_layernorm(x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor, bia # src[matmul_layernorm.py:N]: acc = hl.zeros([tile_m, n], dtype=torch.float32) # src[matmul_layernorm.py:N]: for tile_k in hl.tile(k): # src[matmul_layernorm.py:N-N]: ... - _launcher(_helion_matmul_layernorm, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, weight, bias, out, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_matmul_layernorm, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, weight, bias, out, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[matmul_layernorm.py:N]: return out return out @@ -4643,7 +4643,7 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor, epilogue: Callable[[torch.T # src[matmul_split_k.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[matmul_split_k.py:N]: for inner_k in hl.tile(outer_k.begin, outer_k.end): # src[matmul_split_k.py:N-N]: ... - _launcher(_helion_matmul_split_k, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1) * triton.cdiv(1024, _BLOCK_SIZE_2),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_matmul_split_k, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1) * triton.cdiv(1024, _BLOCK_SIZE_2),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=1) # src[matmul_split_k.py:N]: return out return out @@ -4777,7 +4777,7 @@ def moe_matmul_ogs(A: torch.Tensor, W: torch.Tensor, expert_token_counts: torch. # src[moe_matmul_ogs.py:N]: start = expert_token_offsets[e_idx] # src[moe_matmul_ogs.py:N]: num_tokens = expert_token_counts[e_idx] # src[moe_matmul_ogs.py:N-N]: ... - _launcher(_helion_moe_matmul_ogs, (E,), expert_token_offsets, expert_token_counts, sorted_to_orig_token_idx, A, W, C, A.stride(0), A.stride(1), C.stride(0), C.stride(1), W.stride(0), W.stride(1), W.stride(2), expert_token_counts.stride(0), expert_token_offsets.stride(0), sorted_to_orig_token_idx.stride(0), max_T_per_expert, N, K, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_moe_matmul_ogs, (E,), expert_token_offsets, expert_token_counts, sorted_to_orig_token_idx, A, W, C, A.stride(0), A.stride(1), C.stride(0), C.stride(1), W.stride(0), W.stride(1), W.stride(2), expert_token_counts.stride(0), expert_token_offsets.stride(0), sorted_to_orig_token_idx.stride(0), max_T_per_expert, N, K, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=1) # src[moe_matmul_ogs.py:N]: return C return C @@ -4914,7 +4914,7 @@ def _helion_rms_norm_fwd(x, weight, out, inv_rms, eps, _BLOCK_SIZE_0: tl.constex v_3 = mean_x_squared_extra / v_2.to(tl.float32) # src[rms_norm.py:N]: inv_rms_tile = torch.rsqrt(mean_x_squared + eps) v_4 = v_3 + eps - v_5 = libdevice.rsqrt(v_4) + v_5 = tl.rsqrt(v_4) # src[rms_norm.py:N]: normalized = x_tile * inv_rms_tile[:, None] subscript = v_5[:, None] v_6 = v_0 * subscript @@ -4959,7 +4959,7 @@ def rms_norm_fwd(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05, *, _la # src[rms_norm.py:N]: for tile_m in hl.tile(m): # src[rms_norm.py:N]: x_tile = x[tile_m, :].to(torch.float32) # src[rms_norm.py:N-N]: ... - _launcher(_helion_rms_norm_fwd, (triton.cdiv(128, _BLOCK_SIZE_0),), x, weight, out, inv_rms, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_rms_norm_fwd, (triton.cdiv(128, _BLOCK_SIZE_0),), x, weight, out, inv_rms, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[rms_norm.py:N]: return out, inv_rms.reshape(-1, 1) return (out, inv_rms.reshape(-1, 1)) @@ -5072,7 +5072,7 @@ def segmented_reduction_helion(indices: torch.Tensor, input_data: torch.Tensor, # src[segment_reduction.py:N]: vals = input_data[tile_e, tile_f] # src[segment_reduction.py:N]: idxs = indices[tile_e] # src[segment_reduction.py:N-N]: ... - _launcher(_helion_segmented_reduction_helion, (triton.cdiv(1000, _BLOCK_SIZE_0) * triton.cdiv(32, _BLOCK_SIZE_1),), input_data, indices, output, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_segmented_reduction_helion, (triton.cdiv(1000, _BLOCK_SIZE_0) * triton.cdiv(32, _BLOCK_SIZE_1),), input_data, indices, output, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[segment_reduction.py:N]: return output return output @@ -5210,7 +5210,7 @@ def softmax_bwd(grad_output: torch.Tensor, softmax_output: torch.Tensor, *, _lau # src[softmax.py:N]: sum_per_row = hl.zeros([tile_m], dtype=torch.float32) # src[softmax.py:N]: for tile_n in hl.tile(n): # src[softmax.py:N-N]: ... - _launcher(_helion_softmax_bwd, (triton.cdiv(2048, _BLOCK_SIZE_0),), softmax_output, grad_output, grad_input, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_softmax_bwd, (triton.cdiv(2048, _BLOCK_SIZE_0),), softmax_output, grad_output, grad_input, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[softmax.py:N]: return grad_input return grad_input @@ -5418,7 +5418,7 @@ def softmax_two_pass(x: torch.Tensor, *, _launcher=_default_launcher): # src[softmax.py:N]: mi = hl.full([tile_m], float("-inf"), dtype=torch.float32) # src[softmax.py:N]: di = hl.zeros([tile_m], dtype=torch.float32) # src[softmax.py:N-N]: ... - _launcher(_helion_softmax_two_pass, (triton.cdiv(1024, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_softmax_two_pass, (triton.cdiv(1024, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[softmax.py:N]: return out return out @@ -5512,7 +5512,7 @@ def softmax_two_pass(x: torch.Tensor, *, _launcher=_default_launcher): # src[softmax.py:N]: mi = hl.full([tile_m], float("-inf"), dtype=torch.float32) # src[softmax.py:N]: di = hl.zeros([tile_m], dtype=torch.float32) # src[softmax.py:N-N]: ... - _launcher(_helion_softmax_two_pass, (triton.cdiv(1024, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_softmax_two_pass, (triton.cdiv(1024, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[softmax.py:N]: return out return out @@ -5939,7 +5939,7 @@ def sum_kernel(x: torch.Tensor, *, _launcher=_default_launcher): _REDUCTION_BLOCK_1 = 32768 # src[sum.py:N]: for tile_m in hl.tile(m): # src[sum.py:N]: out[tile_m] = x[tile_m, :].sum(-1) - _launcher(_helion_sum_kernel, (512,), x, out, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=2) + _launcher(_helion_sum_kernel, (512,), x, out, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=1) # src[sum.py:N]: return out return out @@ -6077,12 +6077,12 @@ def swiglu_bwd(gout: Tensor, x1: Tensor, x2: Tensor, *, _launcher=_default_launc # src[swiglu.py:N]: dx2_flat = dx2.view(-1) dx2_flat = dx2.view(-1) # src[swiglu.py:N]: for tile in hl.tile(x1.numel()): - _BLOCK_SIZE_0 = 1024 + _BLOCK_SIZE_0 = 256 # src[swiglu.py:N]: for tile in hl.tile(x1.numel()): # src[swiglu.py:N]: x1_vals = x1_flat[tile].to(torch.float32) # src[swiglu.py:N]: gout_vals = gout_flat[tile].to(torch.float32) # src[swiglu.py:N-N]: ... - _launcher(_helion_swiglu_bwd, (triton.cdiv(1024, _BLOCK_SIZE_0),), x1_flat, gout_flat, dx2_flat, x2_flat, dx1_flat, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_swiglu_bwd, (triton.cdiv(1024, _BLOCK_SIZE_0),), x1_flat, gout_flat, dx2_flat, x2_flat, dx1_flat, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[swiglu.py:N]: return dx1, dx2 return (dx1, dx2) @@ -6409,7 +6409,7 @@ def _helion_welford(x, weight, bias, out, eps, _BLOCK_SIZE_0: tl.constexpr, _BLO # src[welford.py:N]: rstd_tile = torch.rsqrt(acc_m2 / acc_cnt + eps) v_23 = acc_m2 / acc_cnt v_24 = v_23 + eps - v_25 = libdevice.rsqrt(v_24) + v_25 = tl.rsqrt(v_24) # src[welford.py:N]: mean_col = acc_mean[:, None] mean_col = acc_mean[:, None] # src[welford.py:N]: rstd_col = rstd_tile[:, None] @@ -6471,6 +6471,6 @@ def welford(weight: torch.Tensor, bias: torch.Tensor, x: torch.Tensor, eps: floa # src[welford.py:N]: acc_cnt = torch.zeros_like(x[tile_m, 0], dtype=torch.float32) # src[welford.py:N]: acc_mean = torch.zeros_like(acc_cnt) # src[welford.py:N-N]: ... - _launcher(_helion_welford, (triton.cdiv(128, _BLOCK_SIZE_0),), x, weight, bias, out, eps, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_welford, (triton.cdiv(128, _BLOCK_SIZE_0),), x, weight, bias, out, eps, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[welford.py:N]: return out return out diff --git a/test/test_generate_ast.expected b/test/test_generate_ast.expected index 0963b5afe..fcf2ae633 100644 --- a/test/test_generate_ast.expected +++ b/test/test_generate_ast.expected @@ -30,7 +30,7 @@ def add(x, y, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 1024 # src[basic_kernels.py:N]: for tile in hl.tile(out.size()): # src[basic_kernels.py:N]: out[tile] = x[tile] + y[tile] - _launcher(_helion_add, (triton.cdiv(4096, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_add, (triton.cdiv(4096, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[basic_kernels.py:N]: return out return out @@ -64,7 +64,7 @@ def add(x, y, *, _launcher=_default_launcher): _BLOCK_SIZE_0_1 = 1024 # src[basic_kernels.py:N]: for tile in hl.tile(out.size()): # src[basic_kernels.py:N]: out[tile] = x[tile] + y[tile] - _launcher(_helion_add, (triton.cdiv(50000, _BLOCK_SIZE_0_1), 1, 1), x, y, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=2) + _launcher(_helion_add, (triton.cdiv(50000, _BLOCK_SIZE_0_1), 1, 1), x, y, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=1) # src[basic_kernels.py:N]: return out return out @@ -98,7 +98,7 @@ def add(x, y, *, _launcher=_default_launcher): _BLOCK_SIZE_0_1 = 1024 # src[basic_kernels.py:N]: for tile in hl.tile(out.size()): # src[basic_kernels.py:N]: out[tile] = x[tile] + y[tile] - _launcher(_helion_add, (triton.cdiv(50000, _BLOCK_SIZE_0_1), 1, 1), x, y, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=2) + _launcher(_helion_add, (triton.cdiv(50000, _BLOCK_SIZE_0_1), 1, 1), x, y, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=1) # src[basic_kernels.py:N]: return out return out @@ -133,7 +133,7 @@ def add(x, y, *, _launcher=_default_launcher): _BLOCK_SIZE_0_1_2 = 1024 # src[basic_kernels.py:N]: for tile in hl.tile(out.size()): # src[basic_kernels.py:N]: out[tile] = x[tile] + y[tile] - _launcher(_helion_add, (triton.cdiv(500000, _BLOCK_SIZE_0_1_2), 1, 1), x, y, out, _BLOCK_SIZE_0_1_2, num_warps=4, num_stages=2) + _launcher(_helion_add, (triton.cdiv(500000, _BLOCK_SIZE_0_1_2), 1, 1), x, y, out, _BLOCK_SIZE_0_1_2, num_warps=4, num_stages=1) # src[basic_kernels.py:N]: return out return out @@ -168,7 +168,7 @@ def add(x, y, *, _launcher=_default_launcher): _BLOCK_SIZE_0_1_2 = 1024 # src[basic_kernels.py:N]: for tile in hl.tile(out.size()): # src[basic_kernels.py:N]: out[tile] = x[tile] + y[tile] - _launcher(_helion_add, (triton.cdiv(500000, _BLOCK_SIZE_0_1_2), 1, 1), x, y, out, _BLOCK_SIZE_0_1_2, num_warps=4, num_stages=2) + _launcher(_helion_add, (triton.cdiv(500000, _BLOCK_SIZE_0_1_2), 1, 1), x, y, out, _BLOCK_SIZE_0_1_2, num_warps=4, num_stages=1) # src[basic_kernels.py:N]: return out return out @@ -212,7 +212,7 @@ def add(x, y, *, _launcher=_default_launcher): _BLOCK_SIZE_2 = 16 # src[basic_kernels.py:N]: for tile in hl.tile(out.size()): # src[basic_kernels.py:N]: out[tile] = x[tile] + y[tile] - _launcher(_helion_add, (triton.cdiv(100, _BLOCK_SIZE_0), triton.cdiv(500, _BLOCK_SIZE_1), triton.cdiv(10, _BLOCK_SIZE_2)), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_add, (triton.cdiv(100, _BLOCK_SIZE_0), triton.cdiv(500, _BLOCK_SIZE_1), triton.cdiv(10, _BLOCK_SIZE_2)), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[basic_kernels.py:N]: return out return out @@ -255,7 +255,7 @@ def add(x, y, *, _launcher=_default_launcher): _BLOCK_SIZE_2 = 32 # src[basic_kernels.py:N]: for tile in hl.tile(out.size()): # src[basic_kernels.py:N]: out[tile] = x[tile] + y[tile] - _launcher(_helion_add, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1) * triton.cdiv(512, _BLOCK_SIZE_2),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_add, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1) * triton.cdiv(512, _BLOCK_SIZE_2),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[basic_kernels.py:N]: return out return out @@ -298,7 +298,7 @@ def add(x, y, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 8 # src[basic_kernels.py:N]: for tile in hl.tile(out.size()): # src[basic_kernels.py:N]: out[tile] = x[tile] + y[tile] - _launcher(_helion_add, (triton.cdiv(512, _BLOCK_SIZE_2) * triton.cdiv(512, _BLOCK_SIZE_1) * triton.cdiv(512, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_add, (triton.cdiv(512, _BLOCK_SIZE_2) * triton.cdiv(512, _BLOCK_SIZE_1) * triton.cdiv(512, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[basic_kernels.py:N]: return out return out @@ -340,7 +340,7 @@ def add(x, y, *, _launcher=_default_launcher): _BLOCK_SIZE_2 = 32 # src[basic_kernels.py:N]: for tile in hl.tile(out.size()): # src[basic_kernels.py:N]: out[tile] = x[tile] + y[tile] - _launcher(_helion_add, (512 * triton.cdiv(512, _BLOCK_SIZE_1) * triton.cdiv(512, _BLOCK_SIZE_2),), x, y, out, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_add, (512 * triton.cdiv(512, _BLOCK_SIZE_1) * triton.cdiv(512, _BLOCK_SIZE_2),), x, y, out, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[basic_kernels.py:N]: return out return out @@ -419,7 +419,7 @@ def hl_full_usage(x: torch.Tensor, *, _launcher=_default_launcher): # src[basic_kernels.py:N]: tmp = hl.full(tile, 1, dtype=x.dtype) # src[basic_kernels.py:N]: tmp += x[tile] # src[basic_kernels.py:N-N]: ... - _launcher(_helion_hl_full_usage, (triton.cdiv(512, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_hl_full_usage, (triton.cdiv(512, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[basic_kernels.py:N]: return out return out @@ -457,7 +457,7 @@ def hl_zeros_usage(x: torch.Tensor, *, _launcher=_default_launcher): # src[basic_kernels.py:N]: tmp = hl.zeros(tile, dtype=x.dtype) # src[basic_kernels.py:N]: tmp += x[tile] # src[basic_kernels.py:N-N]: ... - _launcher(_helion_hl_zeros_usage, (triton.cdiv(262144, _BLOCK_SIZE_0_1), 1, 1), x, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=2) + _launcher(_helion_hl_zeros_usage, (triton.cdiv(262144, _BLOCK_SIZE_0_1), 1, 1), x, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=1) # src[basic_kernels.py:N]: return out return out @@ -500,7 +500,7 @@ def hl_zeros_usage(x: torch.Tensor, *, _launcher=_default_launcher): # src[basic_kernels.py:N]: tmp = hl.zeros(tile, dtype=x.dtype) # src[basic_kernels.py:N]: tmp += x[tile] # src[basic_kernels.py:N-N]: ... - _launcher(_helion_hl_zeros_usage, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_hl_zeros_usage, (triton.cdiv(512, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[basic_kernels.py:N]: return out return out @@ -531,7 +531,7 @@ def inplace_mul(x, c, *, _launcher=_default_launcher): _BLOCK_SIZE_0_1 = 128 # src[basic_kernels.py:N]: for tile in hl.tile(x.size()): # src[basic_kernels.py:N]: x[tile] *= c - _launcher(_helion_inplace_mul, (triton.cdiv(262144, _BLOCK_SIZE_0_1), 1, 1), x, c, _BLOCK_SIZE_0_1, num_warps=4, num_stages=2) + _launcher(_helion_inplace_mul, (triton.cdiv(262144, _BLOCK_SIZE_0_1), 1, 1), x, c, _BLOCK_SIZE_0_1, num_warps=4, num_stages=1) # src[basic_kernels.py:N]: return x return x @@ -566,6 +566,6 @@ def torch_ops_pointwise(x, y, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 128 # src[basic_kernels.py:N]: for tile in hl.tile(out.size()): # src[basic_kernels.py:N]: out[tile] = torch.sigmoid(torch.add(torch.sin(x[tile]), torch.cos(y[tile]))) - _launcher(_helion_torch_ops_pointwise, (triton.cdiv(1024, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_torch_ops_pointwise, (triton.cdiv(1024, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[basic_kernels.py:N]: return out return out diff --git a/test/test_graph_module.expected b/test/test_graph_module.expected index af77170f2..f207ffea1 100644 --- a/test/test_graph_module.expected +++ b/test/test_graph_module.expected @@ -29,10 +29,10 @@ def apply_graph_module(func_m, x, *, _launcher=_default_launcher): # src[test_graph_module.py:N]: out = torch.empty_like(x) out = torch.empty_like(x) # src[test_graph_module.py:N]: for tile in hl.tile(out.size()): - _BLOCK_SIZE_0 = 1024 + _BLOCK_SIZE_0 = 256 # src[test_graph_module.py:N]: for tile in hl.tile(out.size()): # src[test_graph_module.py:N]: out[tile] = func_m(x[tile]) - _launcher(_helion_apply_graph_module, (triton.cdiv(1000, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_apply_graph_module, (triton.cdiv(1000, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_graph_module.py:N]: return out return out @@ -68,9 +68,9 @@ def apply_graph_module(func_m, x, *, _launcher=_default_launcher): # src[test_graph_module.py:N]: out = torch.empty_like(x) out = torch.empty_like(x) # src[test_graph_module.py:N]: for tile in hl.tile(out.size()): - _BLOCK_SIZE_0 = 512 + _BLOCK_SIZE_0 = 256 # src[test_graph_module.py:N]: for tile in hl.tile(out.size()): # src[test_graph_module.py:N]: out[tile] = func_m(x[tile]) - _launcher(_helion_apply_graph_module, (triton.cdiv(512, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_apply_graph_module, (triton.cdiv(512, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_graph_module.py:N]: return out return out diff --git a/test/test_grid.expected b/test/test_grid.expected index e3f00eca1..eb7e6238e 100644 --- a/test/test_grid.expected +++ b/test/test_grid.expected @@ -63,7 +63,7 @@ def grid_1d(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): # src[test_grid.py:N]: for tile_m, tile_n in hl.tile([m, n]): # src[test_grid.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_grid.py:N-N]: ... - _launcher(_helion_grid_1d, (8,), x, y, out, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_grid_1d, (8,), x, y, out, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=1) # src[test_grid.py:N]: return out return out @@ -129,7 +129,7 @@ def grid_1d(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): # src[test_grid.py:N]: for tile_m, tile_n in hl.tile([m, n]): # src[test_grid.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_grid.py:N-N]: ... - _launcher(_helion_grid_1d, (8,), x, y, out, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_grid_1d, (8,), x, y, out, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=1) # src[test_grid.py:N]: return out return out @@ -205,7 +205,7 @@ def grid_2d_idx_list(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_lau # src[test_grid.py:N]: for tile_m, tile_n in hl.tile([m, n]): # src[test_grid.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_grid.py:N-N]: ... - _launcher(_helion_grid_2d_idx_list, (3 * 4,), x, y, out, _BLOCK_SIZE_3, _BLOCK_SIZE_2, _BLOCK_SIZE_4, num_warps=4, num_stages=2) + _launcher(_helion_grid_2d_idx_list, (3 * 4,), x, y, out, _BLOCK_SIZE_3, _BLOCK_SIZE_2, _BLOCK_SIZE_4, num_warps=4, num_stages=1) # src[test_grid.py:N]: return out return out @@ -291,7 +291,7 @@ def grid_2d_idx_list(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_lau # src[test_grid.py:N]: for tile_m, tile_n in hl.tile([m, n]): # src[test_grid.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_grid.py:N-N]: ... - _launcher(_helion_grid_2d_idx_list, (3 * 4,), x, y, out, _BLOCK_SIZE_3, _BLOCK_SIZE_2, _BLOCK_SIZE_4, num_warps=4, num_stages=2) + _launcher(_helion_grid_2d_idx_list, (3 * 4,), x, y, out, _BLOCK_SIZE_3, _BLOCK_SIZE_2, _BLOCK_SIZE_4, num_warps=4, num_stages=1) # src[test_grid.py:N]: return out return out @@ -369,7 +369,7 @@ def grid_2d_idx_nested(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_l # src[test_grid.py:N]: for j in hl.grid(bj): # src[test_grid.py:N]: for tile_m, tile_n in hl.tile([m, n]): # src[test_grid.py:N-N]: ... - _launcher(_helion_grid_2d_idx_nested, (3,), x, y, out, _BLOCK_SIZE_3, _BLOCK_SIZE_2, _BLOCK_SIZE_4, num_warps=4, num_stages=2) + _launcher(_helion_grid_2d_idx_nested, (3,), x, y, out, _BLOCK_SIZE_3, _BLOCK_SIZE_2, _BLOCK_SIZE_4, num_warps=4, num_stages=1) # src[test_grid.py:N]: return out return out @@ -400,7 +400,7 @@ def grid_begin_end(x: torch.Tensor, *, _launcher=_default_launcher): out = torch.zeros_like(x) # src[test_grid.py:N]: for i in hl.grid(2, n - 2): # grid(begin, end) # src[test_grid.py:N]: out[i] = x[i] * 2 - _launcher(_helion_grid_begin_end, (12,), x, out, num_warps=4, num_stages=2) + _launcher(_helion_grid_begin_end, (12,), x, out, num_warps=4, num_stages=1) # src[test_grid.py:N]: return out return out @@ -432,7 +432,7 @@ def grid_begin_end_step(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 2 # src[test_grid.py:N]: for i in hl.grid(0, n, 2): # grid(begin, end, step) # src[test_grid.py:N]: out[i] = x[i] * 2 - _launcher(_helion_grid_begin_end_step, (triton.cdiv(16, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_grid_begin_end_step, (triton.cdiv(16, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_grid.py:N]: return out return out @@ -464,7 +464,7 @@ def grid_end_step_kwarg(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 2 # src[test_grid.py:N]: for i in hl.grid(n, step=2): # grid(end, step=step) # src[test_grid.py:N]: out[i] = x[i] * 2 - _launcher(_helion_grid_end_step_kwarg, (triton.cdiv(16, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_grid_end_step_kwarg, (triton.cdiv(16, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_grid.py:N]: return out return out @@ -503,7 +503,7 @@ def grid_multidim_begin_end(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_grid.py:N]: [1, 1], [m - 1, n - 1] # src[test_grid.py:N]: ): # multidimensional grid(begin, end) # src[test_grid.py:N-N]: ... - _launcher(_helion_grid_multidim_begin_end, (6 * 6,), x, out, num_warps=4, num_stages=2) + _launcher(_helion_grid_multidim_begin_end, (6 * 6,), x, out, num_warps=4, num_stages=1) # src[test_grid.py:N]: return out return out @@ -545,7 +545,7 @@ def grid_multidim_begin_end_step(x: torch.Tensor, *, _launcher=_default_launcher # src[test_grid.py:N]: [0, 0], [m, n], [2, 3] # src[test_grid.py:N]: ): # multidimensional grid(begin, end, step) # src[test_grid.py:N-N]: ... - _launcher(_helion_grid_multidim_begin_end_step, (triton.cdiv(8, _BLOCK_SIZE_0) * triton.cdiv(9, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_grid_multidim_begin_end_step, (triton.cdiv(8, _BLOCK_SIZE_0) * triton.cdiv(9, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_grid.py:N]: return out return out @@ -588,7 +588,7 @@ def range_step_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_grid.py:N]: for tile_batch in hl.tile(batch): # src[test_grid.py:N]: for i in range(1, 10, 2): # range(begin, end, step) # src[test_grid.py:N]: out[tile_batch] += x[tile_batch] / i - _launcher(_helion_range_step_kernel, (triton.cdiv(6, _BLOCK_SIZE_0),), out, x, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_range_step_kernel, (triton.cdiv(6, _BLOCK_SIZE_0),), out, x, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_grid.py:N]: return out return out @@ -620,6 +620,6 @@ def tile_begin_end(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 4 # src[test_grid.py:N]: for tile in hl.tile(2, 10): # tile(begin, end) - simple range [2, 10) # src[test_grid.py:N]: out[tile] = x[tile] * 2 - _launcher(_helion_tile_begin_end, (triton.cdiv(8, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_tile_begin_end, (triton.cdiv(8, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_grid.py:N]: return out return out diff --git a/test/test_indexing.expected b/test/test_indexing.expected index 5ff927e27..202c85971 100644 --- a/test/test_indexing.expected +++ b/test/test_indexing.expected @@ -26,7 +26,7 @@ def arange(length: int, device: torch.device, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 32 # src[test_indexing.py:N]: for tile in hl.tile(length): # src[test_indexing.py:N]: out[tile] = tile.index - _launcher(_helion_arange, (triton.cdiv(length, _BLOCK_SIZE_0),), out, length, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_arange, (triton.cdiv(length, _BLOCK_SIZE_0),), out, length, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -61,7 +61,7 @@ def arange_block_size_mul(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_indexing.py:N]: indices = hl.arange( # src[test_indexing.py:N]: tile.begin * 2, tile.begin * 2 + tile.block_size * 2 # src[test_indexing.py:N-N]: ... - _launcher(_helion_arange_block_size_mul, (triton.cdiv(64, _BLOCK_SIZE_0),), out, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_arange_block_size_mul, (triton.cdiv(64, _BLOCK_SIZE_0),), out, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -95,7 +95,7 @@ def arange_three_args_step(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_indexing.py:N]: # Test the exact pattern requested: torch.arange(start, end, step=2, device=x.device) # src[test_indexing.py:N]: start_idx = tile.begin * 2 # src[test_indexing.py:N-N]: ... - _launcher(_helion_arange_three_args_step, (triton.cdiv(32, _BLOCK_SIZE_0),), out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_arange_three_args_step, (triton.cdiv(32, _BLOCK_SIZE_0),), out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -150,7 +150,7 @@ def broadcast_add_3d(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor, # src[test_indexing.py:N]: # bias1 has shape [1, d1, d2], bias2 has shape [d0, 1, d2] # src[test_indexing.py:N]: out[tile_l, tile_m, tile_n] = ( # src[test_indexing.py:N-N]: ... - _launcher(_helion_broadcast_add_3d, (triton.cdiv(16, _BLOCK_SIZE_0) * triton.cdiv(24, _BLOCK_SIZE_1) * triton.cdiv(32, _BLOCK_SIZE_2),), x, bias1, bias2, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_broadcast_add_3d, (triton.cdiv(16, _BLOCK_SIZE_0) * triton.cdiv(24, _BLOCK_SIZE_1) * triton.cdiv(32, _BLOCK_SIZE_2),), x, bias1, bias2, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -208,7 +208,7 @@ def broadcast_add_3d(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor, # src[test_indexing.py:N]: # bias1 has shape [1, d1, d2], bias2 has shape [d0, 1, d2] # src[test_indexing.py:N]: out[tile_l, tile_m, tile_n] = ( # src[test_indexing.py:N-N]: ... - _launcher(_helion_broadcast_add_3d, (triton.cdiv(16, _BLOCK_SIZE_0) * triton.cdiv(24, _BLOCK_SIZE_1) * triton.cdiv(32, _BLOCK_SIZE_2),), x, bias1, bias2, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_broadcast_add_3d, (triton.cdiv(16, _BLOCK_SIZE_0) * triton.cdiv(24, _BLOCK_SIZE_1) * triton.cdiv(32, _BLOCK_SIZE_2),), x, bias1, bias2, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -281,7 +281,7 @@ def broadcast_add_3d(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor, # src[test_indexing.py:N]: # bias1 has shape [1, d1, d2], bias2 has shape [d0, 1, d2] # src[test_indexing.py:N]: out[tile_l, tile_m, tile_n] = ( # src[test_indexing.py:N-N]: ... - _launcher(_helion_broadcast_add_3d, (triton.cdiv(16, _BLOCK_SIZE_0) * triton.cdiv(24, _BLOCK_SIZE_1) * triton.cdiv(32, _BLOCK_SIZE_2),), x, bias1, bias2, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_broadcast_add_3d, (triton.cdiv(16, _BLOCK_SIZE_0) * triton.cdiv(24, _BLOCK_SIZE_1) * triton.cdiv(32, _BLOCK_SIZE_2),), x, bias1, bias2, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -323,7 +323,7 @@ def masked_load(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 16 # src[test_indexing.py:N]: for tile in hl.tile(out.size(0)): # src[test_indexing.py:N]: out[tile] = hl.load(x, [tile], extra_mask=(tile.index % 2) == 0) - _launcher(_helion_masked_load, (triton.cdiv(200, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_masked_load, (triton.cdiv(200, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -365,7 +365,7 @@ def masked_store(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 16 # src[test_indexing.py:N]: for tile in hl.tile(out.size(0)): # src[test_indexing.py:N]: hl.store(out, [tile], x[tile], extra_mask=(tile.index % 2) == 0) - _launcher(_helion_masked_store, (triton.cdiv(200, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_masked_store, (triton.cdiv(200, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -399,7 +399,7 @@ def pairwise_add(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 32 # src[test_indexing.py:N]: for tile in hl.tile(out.size(0)): # src[test_indexing.py:N]: out[tile] = x[tile] + x[tile.index + 1] - _launcher(_helion_pairwise_add, (triton.cdiv(499, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_pairwise_add, (triton.cdiv(499, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -441,7 +441,7 @@ def pairwise_add_variants(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_indexing.py:N]: left = x[1 + tile.index] # src[test_indexing.py:N]: right = x[tile.index + 1 + 2] # src[test_indexing.py:N-N]: ... - _launcher(_helion_pairwise_add_variants, (triton.cdiv(253, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_pairwise_add_variants, (triton.cdiv(253, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -483,7 +483,7 @@ def load_store_kernel(a: torch.Tensor, b: torch.Tensor, *, _launcher=_default_la # src[test_indexing.py:N]: # 2 loads # src[test_indexing.py:N]: val_a = a[tile_m, tile_n] # src[test_indexing.py:N-N]: ... - _launcher(_helion_load_store_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), a, b, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_load_store_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), a, b, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -525,7 +525,7 @@ def load_store_kernel(a: torch.Tensor, b: torch.Tensor, *, _launcher=_default_la # src[test_indexing.py:N]: # 2 loads # src[test_indexing.py:N]: val_a = a[tile_m, tile_n] # src[test_indexing.py:N-N]: ... - _launcher(_helion_load_store_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), a, b, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_load_store_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), a, b, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -565,7 +565,7 @@ def load_store_kernel(a: torch.Tensor, b: torch.Tensor, *, _launcher=_default_la # src[test_indexing.py:N]: # 2 loads # src[test_indexing.py:N]: val_a = a[tile_m, tile_n] # src[test_indexing.py:N-N]: ... - _launcher(_helion_load_store_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), a, b, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_load_store_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), a, b, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -605,7 +605,7 @@ def load_store_kernel(a: torch.Tensor, b: torch.Tensor, *, _launcher=_default_la # src[test_indexing.py:N]: # 2 loads # src[test_indexing.py:N]: val_a = a[tile_m, tile_n] # src[test_indexing.py:N-N]: ... - _launcher(_helion_load_store_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), a, b, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_load_store_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), a, b, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -650,7 +650,7 @@ def multi_load_kernel(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, *, _lau # src[test_indexing.py:N]: val_a = a[tile_m, tile_n] # src[test_indexing.py:N]: val_b = b[tile_m, tile_n] # src[test_indexing.py:N-N]: ... - _launcher(_helion_multi_load_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), a, b, c, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_multi_load_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), a, b, c, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -695,7 +695,7 @@ def many_loads_kernel(a: torch.Tensor, *, _launcher=_default_launcher): # src[test_indexing.py:N]: v1 = a[tile_m, tile_n] # src[test_indexing.py:N]: v2 = a[tile_m, tile_n] # src[test_indexing.py:N-N]: ... - _launcher(_helion_many_loads_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), a, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_many_loads_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), a, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -740,7 +740,7 @@ def many_loads_kernel(a: torch.Tensor, *, _launcher=_default_launcher): # src[test_indexing.py:N]: v1 = a[tile_m, tile_n] # src[test_indexing.py:N]: v2 = a[tile_m, tile_n] # src[test_indexing.py:N-N]: ... - _launcher(_helion_many_loads_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), a, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_many_loads_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), a, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -785,7 +785,7 @@ def many_loads_kernel(a: torch.Tensor, *, _launcher=_default_launcher): # src[test_indexing.py:N]: v1 = a[tile_m, tile_n] # src[test_indexing.py:N]: v2 = a[tile_m, tile_n] # src[test_indexing.py:N-N]: ... - _launcher(_helion_many_loads_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), a, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_many_loads_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), a, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -821,7 +821,7 @@ def arange_block_size_mul(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_indexing.py:N]: indices_start = tile.begin * 2 # src[test_indexing.py:N]: indices_end = indices_start + tile.block_size * 2 # src[test_indexing.py:N-N]: ... - _launcher(_helion_arange_block_size_mul, (triton.cdiv(64, _BLOCK_SIZE_0),), ones, out, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_arange_block_size_mul, (triton.cdiv(64, _BLOCK_SIZE_0),), ones, out, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -851,7 +851,7 @@ def fn(n: int, device: torch.device, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 64 # src[test_indexing.py:N]: for tile in hl.tile(n, block_size=64): # src[test_indexing.py:N]: out[tile] = tile.count - _launcher(_helion_fn, (triton.cdiv(n, _BLOCK_SIZE_0),), out, n, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(n, _BLOCK_SIZE_0),), out, n, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -876,7 +876,7 @@ def fn(begin: int, end: int, device: torch.device, *, _launcher=_default_launche _BLOCK_SIZE_0 = 32 # src[test_indexing.py:N]: for tile in hl.tile(begin, end, block_size=32): # src[test_indexing.py:N]: out[0] = tile.count - _launcher(_helion_fn, (triton.cdiv(end + -1 * begin, _BLOCK_SIZE_0),), out, begin, end, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(end + -1 * begin, _BLOCK_SIZE_0),), out, begin, end, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -910,7 +910,7 @@ def tile_offset_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_indexing.py:N]: # Use tile + offset pattern # src[test_indexing.py:N]: tile_offset = tile + 10 # src[test_indexing.py:N-N]: ... - _launcher(_helion_tile_offset_kernel, (triton.cdiv(190, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_tile_offset_kernel, (triton.cdiv(190, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -1045,7 +1045,7 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la # src[test_indexing.py:N]: m_i = hl.zeros([tile_m]) - float("inf") # src[test_indexing.py:N]: l_i = hl.zeros([tile_m]) + 1.0 # src[test_indexing.py:N-N]: ... - _launcher(_helion_attention, (triton.cdiv(8192, _BLOCK_SIZE_0),), q, k, v, lse, o, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_attention, (triton.cdiv(8192, _BLOCK_SIZE_0),), q, k, v, lse, o, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return o.reshape(B, H, M, Dv), lse.reshape(B, H, M) return (o.reshape(B, H, M, Dv), lse.reshape(B, H, M)) @@ -1080,7 +1080,7 @@ def tile_offset_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_indexing.py:N]: # Use tile + offset pattern # src[test_indexing.py:N]: tile_offset = tile + 10 # src[test_indexing.py:N-N]: ... - _launcher(_helion_tile_offset_kernel, (triton.cdiv(190, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_tile_offset_kernel, (triton.cdiv(190, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out @@ -1127,6 +1127,6 @@ def tile_offset_2d_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_indexing.py:N]: # Use tile + offset pattern # src[test_indexing.py:N]: tile_offset = tile_m + 10 # src[test_indexing.py:N-N]: ... - _launcher(_helion_tile_offset_2d_kernel, (triton.cdiv(118, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_tile_offset_2d_kernel, (triton.cdiv(118, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_indexing.py:N]: return out return out diff --git a/test/test_inline_asm_elementwise.expected b/test/test_inline_asm_elementwise.expected index 5655f4821..e2ccca1ab 100644 --- a/test/test_inline_asm_elementwise.expected +++ b/test/test_inline_asm_elementwise.expected @@ -34,7 +34,7 @@ def kernel_basic(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_inline_asm_elementwise.py:N]: # Simple compilation test # src[test_inline_asm_elementwise.py:N]: result_val = hl.inline_asm_elementwise( # src[test_inline_asm_elementwise.py:N-N]: ... - _launcher(_helion_kernel_basic, (triton.cdiv(16, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_basic, (triton.cdiv(16, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_inline_asm_elementwise.py:N]: return result return result @@ -69,7 +69,7 @@ def kernel_empty_args(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_inline_asm_elementwise.py:N]: # Empty args should work - generates output with context shape # src[test_inline_asm_elementwise.py:N]: result_val = hl.inline_asm_elementwise( # src[test_inline_asm_elementwise.py:N-N]: ... - _launcher(_helion_kernel_empty_args, (triton.cdiv(16, _BLOCK_SIZE_0),), result, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_empty_args, (triton.cdiv(16, _BLOCK_SIZE_0),), result, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_inline_asm_elementwise.py:N]: return result return result @@ -113,7 +113,7 @@ def kernel_multiple_outputs(a: torch.Tensor, b: torch.Tensor, *, _launcher=_defa # src[test_inline_asm_elementwise.py:N]: val_a = a[tile] # src[test_inline_asm_elementwise.py:N]: val_b = b[tile] # src[test_inline_asm_elementwise.py:N-N]: ... - _launcher(_helion_kernel_multiple_outputs, (triton.cdiv(64, _BLOCK_SIZE_0),), a, b, result_c, result_d, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_multiple_outputs, (triton.cdiv(64, _BLOCK_SIZE_0),), a, b, result_c, result_d, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_inline_asm_elementwise.py:N]: return result_c, result_d return (result_c, result_d) @@ -146,12 +146,12 @@ def kernel_packed_asm(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_inline_asm_elementwise.py:N]: result = torch.empty_like(x) result = torch.empty_like(x) # src[test_inline_asm_elementwise.py:N]: for tile in hl.tile(x.shape): - _BLOCK_SIZE_0 = 512 + _BLOCK_SIZE_0 = 256 # src[test_inline_asm_elementwise.py:N]: for tile in hl.tile(x.shape): # src[test_inline_asm_elementwise.py:N]: val = x[tile] # src[test_inline_asm_elementwise.py:N]: # Shift 4x8bit values together, pack=4 # src[test_inline_asm_elementwise.py:N-N]: ... - _launcher(_helion_kernel_packed_asm, (triton.cdiv(512, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_packed_asm, (triton.cdiv(512, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_inline_asm_elementwise.py:N]: return result return result @@ -192,7 +192,7 @@ def kernel_shift_asm(x: torch.Tensor, y: torch.Tensor, n: int, *, _launcher=_def # src[test_inline_asm_elementwise.py:N]: val_x = x[tile] # src[test_inline_asm_elementwise.py:N]: val_y = y[tile] # src[test_inline_asm_elementwise.py:N-N]: ... - _launcher(_helion_kernel_shift_asm, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, result, n, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_shift_asm, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, result, n, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_inline_asm_elementwise.py:N]: return result return result @@ -229,6 +229,6 @@ def kernel_simple_asm(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_inline_asm_elementwise.py:N]: val = x[tile] # src[test_inline_asm_elementwise.py:N]: # Simple mov instruction - copy input to output # src[test_inline_asm_elementwise.py:N-N]: ... - _launcher(_helion_kernel_simple_asm, (triton.cdiv(16, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_simple_asm, (triton.cdiv(16, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_inline_asm_elementwise.py:N]: return result return result diff --git a/test/test_inline_triton.expected b/test/test_inline_triton.expected index da0e6a74d..450eae4a0 100644 --- a/test/test_inline_triton.expected +++ b/test/test_inline_triton.expected @@ -39,7 +39,7 @@ def kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): # src[test_inline_triton.py:N]: x_val = x[tile] # src[test_inline_triton.py:N]: y_val = y[tile] # src[test_inline_triton.py:N-N]: ... - _launcher(_helion_kernel, (triton.cdiv(16, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel, (triton.cdiv(16, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_inline_triton.py:N]: return out return out @@ -93,7 +93,7 @@ def kernel(a: torch.Tensor, b: torch.Tensor, *, _launcher=_default_launcher): # src[test_inline_triton.py:N]: a_val = a[tile] # src[test_inline_triton.py:N]: b_val = b[tile] # src[test_inline_triton.py:N-N]: ... - _launcher(_helion_kernel, (triton.cdiv(64, _BLOCK_SIZE_0),), a, b, sum_out, diff_out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel, (triton.cdiv(64, _BLOCK_SIZE_0),), a, b, sum_out, diff_out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_inline_triton.py:N]: return sum_out, diff_out return (sum_out, diff_out) @@ -136,6 +136,6 @@ def kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): # src[test_inline_triton.py:N]: x_val = x[tile] # src[test_inline_triton.py:N]: y_val = y[tile] # src[test_inline_triton.py:N-N]: ... - _launcher(_helion_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_inline_triton.py:N]: return out return out diff --git a/test/test_loops.expected b/test/test_loops.expected index d158c39d5..d56cc6b25 100644 --- a/test/test_loops.expected +++ b/test/test_loops.expected @@ -50,7 +50,7 @@ def device_loop_3d(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_loops.py:N]: for tile_b, tile_c, tile_d in hl.tile([b, c, d]): # src[test_loops.py:N]: out[tile_a, tile_b, tile_c, tile_d] = torch.sin( # src[test_loops.py:N-N]: ... - _launcher(_helion_device_loop_3d, (128,), x, out, _BLOCK_SIZE_3, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_device_loop_3d, (128,), x, out, _BLOCK_SIZE_3, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_loops.py:N]: return out return out @@ -104,7 +104,7 @@ def device_loop_3d(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_loops.py:N]: for tile_b, tile_c, tile_d in hl.tile([b, c, d]): # src[test_loops.py:N]: out[tile_a, tile_b, tile_c, tile_d] = torch.sin( # src[test_loops.py:N-N]: ... - _launcher(_helion_device_loop_3d, (triton.cdiv(128, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_device_loop_3d, (triton.cdiv(128, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_loops.py:N]: return out return out @@ -156,7 +156,7 @@ def device_loop_3d(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_loops.py:N]: for tile_b, tile_c, tile_d in hl.tile([b, c, d]): # src[test_loops.py:N]: out[tile_a, tile_b, tile_c, tile_d] = torch.sin( # src[test_loops.py:N-N]: ... - _launcher(_helion_device_loop_3d, (triton.cdiv(128, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1_2_3, num_warps=4, num_stages=2) + _launcher(_helion_device_loop_3d, (triton.cdiv(128, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1_2_3, num_warps=4, num_stages=1) # src[test_loops.py:N]: return out return out @@ -208,7 +208,7 @@ def device_loop_3d(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_loops.py:N]: for tile_b, tile_c, tile_d in hl.tile([b, c, d]): # src[test_loops.py:N]: out[tile_a, tile_b, tile_c, tile_d] = torch.sin( # src[test_loops.py:N-N]: ... - _launcher(_helion_device_loop_3d, (triton.cdiv(128, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_device_loop_3d, (triton.cdiv(128, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=1) # src[test_loops.py:N]: return out return out @@ -292,7 +292,7 @@ def chebyshev_kernel(x: torch.Tensor, w: torch.Tensor, *, _launcher=_default_lau # src[test_loops.py:N]: in_x = x[b_tile, c_tile] # src[test_loops.py:N]: T0 = hl.full((b_tile, c_tile), 1.0, x.dtype) # src[test_loops.py:N-N]: ... - _launcher(_helion_chebyshev_kernel, (triton.cdiv(123, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, w, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_chebyshev_kernel, (triton.cdiv(123, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), x, w, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_loops.py:N]: return out return out @@ -340,7 +340,7 @@ def fn(x: torch.Tensor, end: torch.Tensor, *, _launcher=_default_launcher): # src[test_loops.py:N]: acc = hl.zeros([tile0, bs]) # src[test_loops.py:N]: for tile1 in hl.tile(end[0], block_size=bs): # src[test_loops.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_1),), end, x, out, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_1),), end, x, out, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_loops.py:N]: return out return out @@ -388,7 +388,7 @@ def fn(x: torch.Tensor, end: torch.Tensor, *, _launcher=_default_launcher): # src[test_loops.py:N]: acc = hl.zeros([tile0]) # src[test_loops.py:N]: for tile1 in hl.tile(end[0]): # src[test_loops.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0),), end, x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0),), end, x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_loops.py:N]: return out return out @@ -447,7 +447,7 @@ def fn(x: torch.Tensor, end0: torch.Tensor, end1: torch.Tensor, *, _launcher=_de # src[test_loops.py:N]: acc = hl.zeros([tile0], dtype=x.dtype) # src[test_loops.py:N]: for tile1, tile2 in hl.tile([end0[0], end1[0]]): # src[test_loops.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(32, _BLOCK_SIZE_0),), end0, end1, x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(32, _BLOCK_SIZE_0),), end0, end1, x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_loops.py:N]: return out return out @@ -496,7 +496,7 @@ def fn(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor, *, _launcher=_de # src[test_loops.py:N]: acc = hl.zeros([tile0, bs]) # src[test_loops.py:N]: for tile1 in hl.tile(begin[0], end[0], block_size=bs): # src[test_loops.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_1),), begin, end, x, out, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_1),), begin, end, x, out, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_loops.py:N]: return out return out @@ -545,7 +545,7 @@ def fn(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor, *, _launcher=_de # src[test_loops.py:N]: acc = hl.zeros([tile0]) # src[test_loops.py:N]: for (tile1,) in hl.tile([begin[0]], [end[0]]): # src[test_loops.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0),), begin, end, x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0),), begin, end, x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_loops.py:N]: return out return out @@ -587,7 +587,7 @@ def kernel_with_dynamic_fill(x: torch.Tensor, fill_value: torch.Tensor, *, _laun # src[test_loops.py:N]: # Use scalar tensor as fill value # src[test_loops.py:N]: filled = hl.full((b_tile, c_tile), fill_value[0], x.dtype) # src[test_loops.py:N-N]: ... - _launcher(_helion_kernel_with_dynamic_fill, (triton.cdiv(4, _BLOCK_SIZE_0) * triton.cdiv(8, _BLOCK_SIZE_1),), fill_value, x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_kernel_with_dynamic_fill, (triton.cdiv(4, _BLOCK_SIZE_0) * triton.cdiv(8, _BLOCK_SIZE_1),), fill_value, x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_loops.py:N]: return out return out @@ -629,7 +629,7 @@ def add_3d_kernel_l2(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_lau result = x.new_empty(x.size()) # src[test_loops.py:N]: for tile in hl.grid(x.size()): # src[test_loops.py:N]: result[tile] = x[tile] + y[tile] - _launcher(_helion_add_3d_kernel_l2, (16 * 32 * 64,), x, y, result, num_warps=4, num_stages=2) + _launcher(_helion_add_3d_kernel_l2, (16 * 32 * 64,), x, y, result, num_warps=4, num_stages=1) # src[test_loops.py:N]: return result return result @@ -674,7 +674,7 @@ def add_4d_kernel_l2(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_lau result = x.new_empty(x.size()) # src[test_loops.py:N]: for tile in hl.grid(x.size()): # src[test_loops.py:N]: result[tile] = x[tile] + y[tile] - _launcher(_helion_add_4d_kernel_l2, (8 * 16 * 32 * 64,), x, y, result, num_warps=4, num_stages=2) + _launcher(_helion_add_4d_kernel_l2, (8 * 16 * 32 * 64,), x, y, result, num_warps=4, num_stages=1) # src[test_loops.py:N]: return result return result @@ -716,7 +716,7 @@ def add_3d_kernel_reordered(x: torch.Tensor, y: torch.Tensor, *, _launcher=_defa result = x.new_empty(x.size()) # src[test_loops.py:N]: for tile in hl.grid(x.size()): # src[test_loops.py:N]: result[tile] = x[tile] + y[tile] - _launcher(_helion_add_3d_kernel_reordered, (32 * 16 * 8,), x, y, result, num_warps=4, num_stages=2) + _launcher(_helion_add_3d_kernel_reordered, (32 * 16 * 8,), x, y, result, num_warps=4, num_stages=1) # src[test_loops.py:N]: return result return result @@ -756,7 +756,7 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_1 = 16 # src[test_loops.py:N]: for tile0, tile1 in hl.tile(x.size(), block_size=[bs0, bs1]): # src[test_loops.py:N]: out[tile0, tile1] = x[tile0, tile1] + 1 - _launcher(_helion_fn, (triton.cdiv(2048, _BLOCK_SIZE_0) * triton.cdiv(2048, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(2048, _BLOCK_SIZE_0) * triton.cdiv(2048, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_loops.py:N]: return out return out @@ -788,7 +788,7 @@ def fn(x: torch.Tensor, block_size: int, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = block_size # src[test_loops.py:N]: for tile_a in hl.tile(a, block_size=block_size): # src[test_loops.py:N]: out[tile_a] = torch.sin(x[tile_a]) - _launcher(_helion_fn, (triton.cdiv(1024, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(1024, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_loops.py:N]: return out return out @@ -831,7 +831,7 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_loops.py:N]: for tile_a, tile_b in hl.tile([a, b], block_size=[4, 8]): # src[test_loops.py:N]: for tile_c in hl.tile(c, block_size=16): # src[test_loops.py:N]: out[tile_a, tile_b, tile_c] = torch.sin(x[tile_a, tile_b, tile_c]) - _launcher(_helion_fn, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_loops.py:N]: return out return out @@ -875,7 +875,7 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_loops.py:N]: out[tile] = x[tile] # src[test_loops.py:N]: for i in [1, 2, 3]: # src[test_loops.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(4, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(4, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_loops.py:N]: return out return out @@ -919,7 +919,7 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_loops.py:N]: out[tile] = x[tile] # src[test_loops.py:N]: for i in (a, b, c): # src[test_loops.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(4, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(4, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_loops.py:N]: return out return out @@ -986,7 +986,7 @@ def addToBoth(a, b, c, *, _launcher=_default_launcher): _BLOCK_SIZE_2 = 8 # src[test_loops.py:N]: for tile in hl.tile(x2.size()): # src[test_loops.py:N]: x2[tile] += c2 - _launcher(_helion_addToBoth, (triton.cdiv(5, _BLOCK_SIZE_0) + triton.cdiv(5, _BLOCK_SIZE_1) + triton.cdiv(5, _BLOCK_SIZE_2),), x0, x1, x2, c0, c1, c2, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_addToBoth, (triton.cdiv(5, _BLOCK_SIZE_0) + triton.cdiv(5, _BLOCK_SIZE_1) + triton.cdiv(5, _BLOCK_SIZE_2),), x0, x1, x2, c0, c1, c2, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_loops.py:N]: return x0, x1, x2 return (x0, x1, x2) @@ -1085,7 +1085,7 @@ def addToBoth(a, b, c, *, _launcher=_default_launcher): # src[test_loops.py:N]: for tile_n in hl.tile(c_n): # src[test_loops.py:N]: for tile_m in hl.tile(c_m): # src[test_loops.py:N]: x2[tile_n, tile_m] += c2 - _launcher(_helion_addToBoth, (triton.cdiv(5, _BLOCK_SIZE_0) + triton.cdiv(5, _BLOCK_SIZE_2) + triton.cdiv(5, _BLOCK_SIZE_4),), x0, x1, x2, c0, c1, c2, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, num_warps=4, num_stages=2) + _launcher(_helion_addToBoth, (triton.cdiv(5, _BLOCK_SIZE_0) + triton.cdiv(5, _BLOCK_SIZE_2) + triton.cdiv(5, _BLOCK_SIZE_4),), x0, x1, x2, c0, c1, c2, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, num_warps=4, num_stages=1) # src[test_loops.py:N]: return x0, x1, x2 return (x0, x1, x2) @@ -1176,7 +1176,7 @@ def addToBoth(a, b, c, *, _launcher=_default_launcher): _BLOCK_SIZE_5 = 16 # src[test_loops.py:N]: for tile_n, tile_m in hl.tile([c_n, c_m]): # src[test_loops.py:N]: x2[tile_n, tile_m] += c2 - _launcher(_helion_addToBoth, (triton.cdiv(5, _BLOCK_SIZE_0) * triton.cdiv(10, _BLOCK_SIZE_1) + triton.cdiv(5, _BLOCK_SIZE_2) * triton.cdiv(10, _BLOCK_SIZE_3) + triton.cdiv(5, _BLOCK_SIZE_4) * triton.cdiv(10, _BLOCK_SIZE_5),), x0, x1, x2, c0, c1, c2, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, num_warps=4, num_stages=2) + _launcher(_helion_addToBoth, (triton.cdiv(5, _BLOCK_SIZE_0) * triton.cdiv(10, _BLOCK_SIZE_1) + triton.cdiv(5, _BLOCK_SIZE_2) * triton.cdiv(10, _BLOCK_SIZE_3) + triton.cdiv(5, _BLOCK_SIZE_4) * triton.cdiv(10, _BLOCK_SIZE_5),), x0, x1, x2, c0, c1, c2, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, num_warps=4, num_stages=1) # src[test_loops.py:N]: return x0, x1, x2 return (x0, x1, x2) @@ -1276,7 +1276,7 @@ def nested_loop_accumulator(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_loops.py:N]: # Initialize accumulator for this batch # src[test_loops.py:N]: acc = hl.zeros([tile_b], dtype=torch.float32) # src[test_loops.py:N-N]: ... - _launcher(_helion_nested_loop_accumulator, (2,), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=2) + _launcher(_helion_nested_loop_accumulator, (2,), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=1) # src[test_loops.py:N]: return out return out @@ -1318,7 +1318,7 @@ def pointwise_device_loop(x: torch.Tensor, *, _launcher=_default_launcher): # src[basic_kernels.py:N]: for tile_n in hl.tile(n): # src[basic_kernels.py:N]: for tile_m in hl.tile(m): # src[basic_kernels.py:N]: out[tile_n, tile_m] = torch.sigmoid(x[tile_n, tile_m] + 1) - _launcher(_helion_pointwise_device_loop, (triton.cdiv(512, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_pointwise_device_loop, (triton.cdiv(512, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[basic_kernels.py:N]: return out return out @@ -1360,7 +1360,7 @@ def nested_loop_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_loops.py:N]: # Inner loop becomes device loop with tl.range # src[test_loops.py:N]: for tile_inner in hl.tile(x.size(1)): # src[test_loops.py:N-N]: ... - _launcher(_helion_nested_loop_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_nested_loop_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_loops.py:N]: return out return out @@ -1430,7 +1430,7 @@ def kernel_fixed_block_size(y_pred: torch.Tensor, y_true: torch.Tensor, *, _laun # src[test_loops.py:N]: for tile_v in hl.tile(V_local, block_size=block_size_n): # src[test_loops.py:N-N]: ... _RDIM_SIZE_3 = triton.next_power_of_2(_BLOCK_SIZE_0) - _launcher(_helion_kernel_fixed_block_size, (triton.cdiv(64, _BLOCK_SIZE_1),), loss_sum, y_true, kl_loss, loss, loss_sum.stride(0), _BLOCK_SIZE_1, _RDIM_SIZE_2, _RDIM_SIZE_3, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_fixed_block_size, (triton.cdiv(64, _BLOCK_SIZE_1),), loss_sum, y_true, kl_loss, loss, loss_sum.stride(0), _BLOCK_SIZE_1, _RDIM_SIZE_2, _RDIM_SIZE_3, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_loops.py:N]: return torch.sum(loss) / BT return torch.sum(loss) / BT @@ -1464,7 +1464,7 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 64 # src[test_loops.py:N]: for tile0, tile1 in hl.tile(x.size(), block_size=[bs0, bs1]): # src[test_loops.py:N]: out[tile0, tile1] = x[tile0, tile1] + 1 - _launcher(_helion_fn, (triton.cdiv(2048, _BLOCK_SIZE_1) * triton.cdiv(2048, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(2048, _BLOCK_SIZE_1) * triton.cdiv(2048, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_loops.py:N]: return out return out @@ -1528,7 +1528,7 @@ def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): # src[test_loops.py:N]: for tile_n in hl.tile(n): # src[test_loops.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_loops.py:N-N]: ... - _launcher(_helion_matmul, (triton.cdiv(256, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_matmul, (triton.cdiv(256, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_loops.py:N]: return out return out @@ -1628,7 +1628,7 @@ def three_pass_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_loops.py:N]: # Pass 1: Compute sum # src[test_loops.py:N]: sum_val = hl.zeros([tile_b], dtype=torch.float32) # src[test_loops.py:N-N]: ... - _launcher(_helion_three_pass_kernel, (triton.cdiv(4, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_three_pass_kernel, (triton.cdiv(4, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=1) # src[test_loops.py:N]: return out return out @@ -1695,6 +1695,6 @@ def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): # src[test_loops.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_loops.py:N]: for tile_k in hl.tile(k): # src[test_loops.py:N-N]: ... - _launcher(_helion_matmul, (_NUM_SM,), x, y, out, _NUM_SM, _BLOCK_SIZE_1, _BLOCK_SIZE_0, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_matmul, (_NUM_SM,), x, y, out, _NUM_SM, _BLOCK_SIZE_1, _BLOCK_SIZE_0, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_loops.py:N]: return out return out diff --git a/test/test_masking.expected b/test/test_masking.expected index e82c2fab0..15f7c473c 100644 --- a/test/test_masking.expected +++ b/test/test_masking.expected @@ -55,7 +55,7 @@ def fn(x, *, _launcher=_default_launcher): # src[test_masking.py:N]: acc = hl.zeros([tile_m, block_n]) # src[test_masking.py:N]: for _ in hl.tile(n, block_size=block_n): # src[test_masking.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(100, _BLOCK_SIZE_1),), out, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(100, _BLOCK_SIZE_1),), out, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_masking.py:N]: return out return out @@ -118,7 +118,7 @@ def add1mm(x, y, *, _launcher=_default_launcher): # src[test_masking.py:N]: acc = hl.zeros([tile_m, tile_n]) # src[test_masking.py:N]: for tile_k in hl.tile(k): # src[test_masking.py:N-N]: ... - _launcher(_helion_add1mm, (triton.cdiv(100, _BLOCK_SIZE_0) * triton.cdiv(100, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_add1mm, (triton.cdiv(100, _BLOCK_SIZE_0) * triton.cdiv(100, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_masking.py:N]: return out return out @@ -163,6 +163,6 @@ def fn(x, *, _launcher=_default_launcher): # src[test_masking.py:N]: acc = hl.zeros([tile_m, block_size_n]) # src[test_masking.py:N]: for tile_n in hl.tile(0, n, block_size_n): # src[test_masking.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(100, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(100, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_masking.py:N]: return out return out diff --git a/test/test_matmul.expected b/test/test_matmul.expected index 982de1658..0284e8237 100644 --- a/test/test_matmul.expected +++ b/test/test_matmul.expected @@ -60,7 +60,7 @@ def matmul_without_addmm(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_matmul.py:N]: for tile_k in hl.tile(k): # src[test_matmul.py:N-N]: ... - _launcher(_helion_matmul_without_addmm, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_matmul_without_addmm, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_matmul.py:N]: return out return out @@ -128,7 +128,7 @@ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]] # src[matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[matmul.py:N]: for tile_k in hl.tile(k): # src[matmul.py:N-N]: ... - _launcher(_helion_matmul, (triton.cdiv(128, _BLOCK_SIZE_1) * triton.cdiv(128, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_1, _BLOCK_SIZE_0, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_matmul, (triton.cdiv(128, _BLOCK_SIZE_1) * triton.cdiv(128, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_1, _BLOCK_SIZE_0, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[matmul.py:N]: return out return out @@ -190,7 +190,7 @@ def matmul_with_addmm(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_la # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_matmul.py:N]: for tile_k in hl.tile(k): # src[test_matmul.py:N-N]: ... - _launcher(_helion_matmul_with_addmm, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_matmul_with_addmm, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_matmul.py:N]: return out return out @@ -261,7 +261,7 @@ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]] # src[matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[matmul.py:N]: for tile_k in hl.tile(k): # src[matmul.py:N-N]: ... - _launcher(_helion_matmul, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_matmul, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[matmul.py:N]: return out return out @@ -344,7 +344,7 @@ def matmul_with_packed_b(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, *, _ # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([M, N]): # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=A.dtype) # src[test_matmul.py:N-N]: ... - _launcher(_helion_matmul_with_packed_b, (triton.cdiv(M, _BLOCK_SIZE_1) * triton.cdiv(N, _BLOCK_SIZE_2),), A, B, C, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1), M, N, K, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_matmul_with_packed_b, (triton.cdiv(M, _BLOCK_SIZE_1) * triton.cdiv(N, _BLOCK_SIZE_2),), A, B, C, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1), M, N, K, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=1) --- assertExpectedJournal(TestMatmul.test_matmul_split_k) from __future__ import annotations @@ -403,7 +403,7 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launc # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_matmul.py:N]: for inner_k in hl.tile(outer_k.begin, outer_k.end): # src[test_matmul.py:N-N]: ... - _launcher(_helion_matmul_split_k, (triton.cdiv(32, _BLOCK_SIZE_0) * triton.cdiv(32, _BLOCK_SIZE_1) * triton.cdiv(2000, _BLOCK_SIZE_2),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_matmul_split_k, (triton.cdiv(32, _BLOCK_SIZE_0) * triton.cdiv(32, _BLOCK_SIZE_1) * triton.cdiv(2000, _BLOCK_SIZE_2),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=1) # src[test_matmul.py:N]: return out return out @@ -468,7 +468,7 @@ def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_matmul.py:N]: for tile_k in hl.tile(k): # src[test_matmul.py:N-N]: ... - _launcher(_helion_matmul_static_shapes, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_matmul_static_shapes, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_matmul.py:N]: return out return out @@ -533,7 +533,7 @@ def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_matmul.py:N]: for tile_k in hl.tile(k): # src[test_matmul.py:N-N]: ... - _launcher(_helion_matmul_static_shapes, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_matmul_static_shapes, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_matmul.py:N]: return out return out @@ -599,7 +599,7 @@ def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_matmul.py:N]: for tile_k in hl.tile(k): # src[test_matmul.py:N-N]: ... - _launcher(_helion_matmul_static_shapes, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_matmul_static_shapes, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_matmul.py:N]: return out return out @@ -666,7 +666,7 @@ def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_matmul.py:N]: for tile_k in hl.tile(k): # src[test_matmul.py:N-N]: ... - _launcher(_helion_matmul_static_shapes, (triton.cdiv(127, _BLOCK_SIZE_0) * triton.cdiv(127, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_matmul_static_shapes, (triton.cdiv(127, _BLOCK_SIZE_0) * triton.cdiv(127, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_matmul.py:N]: return out return out @@ -749,6 +749,6 @@ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]] # src[matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[matmul.py:N]: for tile_k in hl.tile(k): # src[matmul.py:N-N]: ... - _launcher(_helion_matmul, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_matmul, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[matmul.py:N]: return out return out diff --git a/test/test_misc.expected b/test/test_misc.expected index 7731a571e..281159b24 100644 --- a/test/test_misc.expected +++ b/test/test_misc.expected @@ -52,7 +52,7 @@ def kernel(a_list, b_dict, b_tuple, c_named_tuple, d_dataclass, *, _launcher=_de # src[test_misc.py:N]: for tile in hl.tile(a0.size()): # src[test_misc.py:N]: o0[tile] = a0[tile] + b0[tile] + c0[tile] + d0[tile] # src[test_misc.py:N]: o1[tile] = a1[tile] + b1[tile] + c1[tile] + d1[tile] - _launcher(_helion_kernel, (triton.cdiv(4, _BLOCK_SIZE_0),), a0, b0, c0, d0, o0, a1, b1, c1, d1, o1, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel, (triton.cdiv(4, _BLOCK_SIZE_0),), a0, b0, c0, d0, o0, a1, b1, c1, d1, o1, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_misc.py:N]: return [o0, o1] return [o0, o1] @@ -83,7 +83,7 @@ def copy_kernel(a: torch.Tensor, *, _launcher=_default_launcher): # src[test_misc.py:N]: t1 = tile # src[test_misc.py:N]: t2 = tile # src[test_misc.py:N-N]: ... - _launcher(_helion_copy_kernel, (triton.cdiv(16, _BLOCK_SIZE_0),), a, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_copy_kernel, (triton.cdiv(16, _BLOCK_SIZE_0),), a, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_misc.py:N]: return out return out @@ -116,7 +116,7 @@ def kernel_with_scalar_item(x: torch.Tensor, scalar_tensor: torch.Tensor, *, _la _BLOCK_SIZE_0 = 128 # src[test_misc.py:N]: for tile in hl.tile(x.shape): # src[test_misc.py:N]: result[tile] = x[tile] + scalar_val - _launcher(_helion_kernel_with_scalar_item, (triton.cdiv(100, _BLOCK_SIZE_0),), x, result, scalar_val, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_with_scalar_item, (triton.cdiv(100, _BLOCK_SIZE_0),), x, result, scalar_val, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_misc.py:N]: return result return result @@ -151,7 +151,7 @@ def kernel(a: torch.Tensor, b: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 16 # src[test_misc.py:N]: for tile in hl.tile(a.size()): # src[test_misc.py:N]: out[tile] = a[tile] + b[tile] - _launcher(_helion_kernel, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * 1,), a, out, a.size(0), a.stride(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * 1,), a, out, a.size(0), a.stride(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_misc.py:N]: return out return out @@ -185,7 +185,7 @@ def kernel(a: torch.Tensor, b: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 16 # src[test_misc.py:N]: for tile in hl.tile(a.size()): # src[test_misc.py:N]: out[tile] = a[tile] + b[tile] - _launcher(_helion_kernel, (triton.cdiv(16, _BLOCK_SIZE_0) * 1,), a, b, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel, (triton.cdiv(16, _BLOCK_SIZE_0) * 1,), a, b, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_misc.py:N]: return out return out @@ -237,7 +237,7 @@ def test_tile_block_size_usage(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_misc.py:N]: # This should not cause a compilation error when tile.block_size is used # src[test_misc.py:N]: # in expressions that generate .to() calls # src[test_misc.py:N-N]: ... - _launcher(_helion_test_tile_block_size_usage, (triton.cdiv(32, _BLOCK_SIZE_0),), out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_test_tile_block_size_usage, (triton.cdiv(32, _BLOCK_SIZE_0),), out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_misc.py:N]: return out return out @@ -284,7 +284,7 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_misc.py:N]: acc = x.new_zeros([tile_m, block_size_n]) # src[test_misc.py:N]: for tile_n in hl.tile(n, block_size=block_size_n): # src[test_misc.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_misc.py:N]: return out return out @@ -331,7 +331,7 @@ def add(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 16 # src[add.py:N]: for tile in hl.tile(out.size()): # src[add.py:N]: out[tile] = x[tile] + y[tile] - _launcher(_helion_add, (triton.cdiv(16, _BLOCK_SIZE_0) * 1,), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_add, (triton.cdiv(16, _BLOCK_SIZE_0) * 1,), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[add.py:N]: return out return out @@ -383,7 +383,7 @@ def kernel(t: torch.Tensor, i: int, s: str, b: bool, f: float, zero_dim_t: torch # src[test_misc.py:N]: for tile in hl.tile(t.size()): # src[test_misc.py:N]: if b and len(s) > 2: # src[test_misc.py:N]: out[tile] = t[tile] + i + f - _launcher(_helion_kernel, (triton.cdiv(t.size(0), _BLOCK_SIZE_0) * 1,), t, out, t.size(0), out.stride(0), t.stride(0), b, i, f, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel, (triton.cdiv(t.size(0), _BLOCK_SIZE_0) * 1,), t, out, t.size(0), out.stride(0), t.stride(0), b, i, f, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_misc.py:N]: return out, zero_dim_t return (out, zero_dim_t) @@ -438,7 +438,7 @@ def kernel(t: torch.Tensor, i: int, s: str, b: bool, f: float, zero_dim_t: torch # src[test_misc.py:N]: for tile in hl.tile(t.size()): # src[test_misc.py:N]: if b and len(s) > 2: # src[test_misc.py:N]: out[tile] = t[tile] + i + f - _launcher(_helion_kernel, (triton.cdiv(16, _BLOCK_SIZE_0) * 1,), t, out, b, i, f, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel, (triton.cdiv(16, _BLOCK_SIZE_0) * 1,), t, out, b, i, f, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_misc.py:N]: return out, zero_dim_t return (out, zero_dim_t) @@ -494,7 +494,7 @@ def tuple_literal_index_kernel(inp_tuple, *, _launcher=_default_launcher): _BLOCK_SIZE_1 = 8 # src[test_misc.py:N]: for tile in hl.tile(out.size()): # src[test_misc.py:N]: out[tile] = (inp_tuple[0][tile] + inp_tuple[1][tile]) * inp_tuple[2] - _launcher(_helion_tuple_literal_index_kernel, (triton.cdiv(8, _BLOCK_SIZE_0) * triton.cdiv(30, _BLOCK_SIZE_1),), inp_tuple[0], inp_tuple[1], out, inp_tuple[2], _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_tuple_literal_index_kernel, (triton.cdiv(8, _BLOCK_SIZE_0) * triton.cdiv(30, _BLOCK_SIZE_1),), inp_tuple[0], inp_tuple[1], out, inp_tuple[2], _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_misc.py:N]: return out return outfrom __future__ import annotations @@ -528,7 +528,7 @@ def tuple_literal_index_kernel(inp_tuple, *, _launcher=_default_launcher): _BLOCK_SIZE_1 = 8 # src[test_misc.py:N]: for tile in hl.tile(out.size()): # src[test_misc.py:N]: out[tile] = (inp_tuple[0][tile] + inp_tuple[1][tile]) * inp_tuple[2] - _launcher(_helion_tuple_literal_index_kernel, (triton.cdiv(8, _BLOCK_SIZE_0) * triton.cdiv(30, _BLOCK_SIZE_1),), inp_tuple[0], inp_tuple[1], out, inp_tuple[2], _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_tuple_literal_index_kernel, (triton.cdiv(8, _BLOCK_SIZE_0) * triton.cdiv(30, _BLOCK_SIZE_1),), inp_tuple[0], inp_tuple[1], out, inp_tuple[2], _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_misc.py:N]: return out return out @@ -577,7 +577,7 @@ def tuple_literal_index_kernel(inp_tuple, *, _launcher=_default_launcher): _BLOCK_SIZE_1 = 8 # src[test_misc.py:N]: for tile in hl.tile(out.size()): # src[test_misc.py:N]: out[tile] = (inp_tuple[0][tile] + inp_tuple[1][tile]) * inp_tuple[2] - _launcher(_helion_tuple_literal_index_kernel, (triton.cdiv(8, _BLOCK_SIZE_0) * triton.cdiv(30, _BLOCK_SIZE_1),), inp_tuple[0], inp_tuple[1], out, inp_tuple[2], _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_tuple_literal_index_kernel, (triton.cdiv(8, _BLOCK_SIZE_0) * triton.cdiv(30, _BLOCK_SIZE_1),), inp_tuple[0], inp_tuple[1], out, inp_tuple[2], _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_misc.py:N]: return out return out @@ -613,6 +613,6 @@ def tuple_unpack_kernel(inp_tuple, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 4 # src[test_misc.py:N]: for tile in hl.tile(out.size(0)): # src[test_misc.py:N]: out[tile] = a[tile] + b[tile] + x - _launcher(_helion_tuple_unpack_kernel, (triton.cdiv(16, _BLOCK_SIZE_0),), a, b, out, x, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_tuple_unpack_kernel, (triton.cdiv(16, _BLOCK_SIZE_0),), a, b, out, x, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_misc.py:N]: return out return out diff --git a/test/test_persistent_kernels.expected b/test/test_persistent_kernels.expected index cb689304c..efb27a4dd 100644 --- a/test/test_persistent_kernels.expected +++ b/test/test_persistent_kernels.expected @@ -70,7 +70,7 @@ def multi_loop_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_la _BLOCK_SIZE_3 = 8 # src[test_persistent_kernels.py:N]: for tile2 in hl.tile(y.size(), block_size=[16, 8]): # src[test_persistent_kernels.py:N]: result2[tile2] = y[tile2] * 3 - _launcher(_helion_multi_loop_kernel, (_NUM_SM,), x, result1, y, result2, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_multi_loop_kernel, (_NUM_SM,), x, result1, y, result2, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result1, result2 return (result1, result2) @@ -140,7 +140,7 @@ def multi_loop_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_la _BLOCK_SIZE_3 = 8 # src[test_persistent_kernels.py:N]: for tile2 in hl.tile(y.size(), block_size=[16, 8]): # src[test_persistent_kernels.py:N]: result2[tile2] = y[tile2] * 3 - _launcher(_helion_multi_loop_kernel, (_NUM_SM,), x, result1, y, result2, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_multi_loop_kernel, (_NUM_SM,), x, result1, y, result2, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result1, result2 return (result1, result2) @@ -180,7 +180,7 @@ def vector_add_1d(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launch _BLOCK_SIZE_0 = 128 # src[test_persistent_kernels.py:N]: for tile in hl.tile(x.size(), block_size=[128]): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + y[tile] - _launcher(_helion_vector_add_1d, (_NUM_SM,), x, y, result, _NUM_SM, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_vector_add_1d, (_NUM_SM,), x, y, result, _NUM_SM, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -217,7 +217,7 @@ def vector_add_1d(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launch _BLOCK_SIZE_0 = 128 # src[test_persistent_kernels.py:N]: for tile in hl.tile(x.size(), block_size=[128]): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + y[tile] - _launcher(_helion_vector_add_1d, (_NUM_SM,), x, y, result, _NUM_SM, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_vector_add_1d, (_NUM_SM,), x, y, result, _NUM_SM, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -248,7 +248,7 @@ def vector_add_1d(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launch _BLOCK_SIZE_0 = 128 # src[test_persistent_kernels.py:N]: for tile in hl.tile(x.size(), block_size=[128]): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + y[tile] - _launcher(_helion_vector_add_1d, (triton.cdiv(1024, _BLOCK_SIZE_0),), x, y, result, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_vector_add_1d, (triton.cdiv(1024, _BLOCK_SIZE_0),), x, y, result, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -292,7 +292,7 @@ def add_3d_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launch _NUM_SM = helion.runtime.get_num_sm(x.device) # src[test_persistent_kernels.py:N]: for tile in hl.grid(x.size()): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + y[tile] - _launcher(_helion_add_3d_kernel, (_NUM_SM,), x, y, result, _NUM_SM, num_warps=4, num_stages=2) + _launcher(_helion_add_3d_kernel, (_NUM_SM,), x, y, result, _NUM_SM, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -326,7 +326,7 @@ def add_3d_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launch result = x.new_empty(x.size()) # src[test_persistent_kernels.py:N]: for tile in hl.grid(x.size()): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + y[tile] - _launcher(_helion_add_3d_kernel, (32 * 64 * 48,), x, y, result, num_warps=4, num_stages=2) + _launcher(_helion_add_3d_kernel, (32 * 64 * 48,), x, y, result, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -394,7 +394,7 @@ def matmul_kernel(A: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launch # src[test_persistent_kernels.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_persistent_kernels.py:N]: for tile_k in hl.tile(K): # src[test_persistent_kernels.py:N-N]: ... - _launcher(_helion_matmul_kernel, (_NUM_SM,), A, B, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_matmul_kernel, (_NUM_SM,), A, B, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -451,7 +451,7 @@ def matmul_kernel(A: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launch # src[test_persistent_kernels.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_persistent_kernels.py:N]: for tile_k in hl.tile(K): # src[test_persistent_kernels.py:N-N]: ... - _launcher(_helion_matmul_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(96, _BLOCK_SIZE_1),), A, B, result, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_matmul_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(96, _BLOCK_SIZE_1),), A, B, result, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -492,7 +492,7 @@ def add_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher) _NUM_SM = helion.runtime.get_num_sm(x.device) # src[test_persistent_kernels.py:N]: for tile in hl.grid(x.size()): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + y[tile] - _launcher(_helion_add_kernel, (_NUM_SM,), x, y, result, _NUM_SM, num_warps=4, num_stages=2) + _launcher(_helion_add_kernel, (_NUM_SM,), x, y, result, _NUM_SM, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -539,7 +539,7 @@ def add_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher) _NUM_SM = helion.runtime.get_num_sm(x.device) # src[test_persistent_kernels.py:N]: for tile in hl.grid(x.size()): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + y[tile] - _launcher(_helion_add_kernel, (_NUM_SM,), x, y, result, _NUM_SM, num_warps=4, num_stages=2) + _launcher(_helion_add_kernel, (_NUM_SM,), x, y, result, _NUM_SM, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -576,7 +576,7 @@ def add_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher) result = x.new_empty(x.size()) # src[test_persistent_kernels.py:N]: for tile in hl.grid(x.size()): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + y[tile] - _launcher(_helion_add_kernel, (64 * 128,), x, y, result, num_warps=4, num_stages=2) + _launcher(_helion_add_kernel, (64 * 128,), x, y, result, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -617,7 +617,7 @@ def add_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher) _NUM_SM = helion.runtime.get_num_sm(x.device) # src[test_persistent_kernels.py:N]: for tile in hl.grid(x.size()): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + y[tile] - _launcher(_helion_add_kernel, (_NUM_SM,), x, y, result, _NUM_SM, num_warps=4, num_stages=2) + _launcher(_helion_add_kernel, (_NUM_SM,), x, y, result, _NUM_SM, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -653,7 +653,7 @@ def test_kernel(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_1 = 16 # src[test_persistent_kernels.py:N]: for tile in hl.tile(x.size(), block_size=[32, 16]): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + 1 - _launcher(_helion_test_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(96, _BLOCK_SIZE_1),), x, result, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(96, _BLOCK_SIZE_1),), x, result, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -698,7 +698,7 @@ def test_kernel(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_1 = 16 # src[test_persistent_kernels.py:N]: for tile in hl.tile(x.size(), block_size=[32, 16]): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + 1 - _launcher(_helion_test_kernel, (_NUM_SM,), x, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_kernel, (_NUM_SM,), x, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -740,7 +740,7 @@ def test_kernel(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_1 = 16 # src[test_persistent_kernels.py:N]: for tile in hl.tile(x.size(), block_size=[32, 16]): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + 1 - _launcher(_helion_test_kernel, (_NUM_SM,), x, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_kernel, (_NUM_SM,), x, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -781,7 +781,7 @@ def add_3d_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launch _NUM_SM = helion.runtime.get_num_sm(x.device) # src[test_persistent_kernels.py:N]: for tile in hl.grid(x.size()): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + y[tile] - _launcher(_helion_add_3d_kernel, (_NUM_SM,), x, y, result, _NUM_SM, num_warps=4, num_stages=2) + _launcher(_helion_add_3d_kernel, (_NUM_SM,), x, y, result, _NUM_SM, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -815,7 +815,7 @@ def add_3d_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launch result = x.new_empty(x.size()) # src[test_persistent_kernels.py:N]: for tile in hl.grid(x.size()): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + y[tile] - _launcher(_helion_add_3d_kernel, (32 * 64 * 48,), x, y, result, num_warps=4, num_stages=2) + _launcher(_helion_add_3d_kernel, (32 * 64 * 48,), x, y, result, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -880,7 +880,7 @@ def matmul_kernel(A: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launch # src[test_persistent_kernels.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_persistent_kernels.py:N]: for tile_k in hl.tile(K): # src[test_persistent_kernels.py:N-N]: ... - _launcher(_helion_matmul_kernel, (_NUM_SM,), A, B, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_matmul_kernel, (_NUM_SM,), A, B, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -937,7 +937,7 @@ def matmul_kernel(A: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launch # src[test_persistent_kernels.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) # src[test_persistent_kernels.py:N]: for tile_k in hl.tile(K): # src[test_persistent_kernels.py:N-N]: ... - _launcher(_helion_matmul_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(96, _BLOCK_SIZE_1),), A, B, result, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_matmul_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(96, _BLOCK_SIZE_1),), A, B, result, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -1041,7 +1041,7 @@ def multi_loop_l2_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default _BLOCK_SIZE_5 = 16 # src[test_persistent_kernels.py:N]: for tile3 in hl.tile(y.size(), block_size=[16, 16]): # src[test_persistent_kernels.py:N]: result3[tile3] = y[tile3] + 2.0 - _launcher(_helion_multi_loop_l2_kernel, (_NUM_SM,), x, result1, y, result2, result3, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, num_warps=4, num_stages=2) + _launcher(_helion_multi_loop_l2_kernel, (_NUM_SM,), x, result1, y, result2, result3, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result1, result2, result3 return (result1, result2, result3) @@ -1127,7 +1127,7 @@ def multi_loop_l2_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default _BLOCK_SIZE_5 = 16 # src[test_persistent_kernels.py:N]: for tile3 in hl.tile(y.size(), block_size=[16, 16]): # src[test_persistent_kernels.py:N]: result3[tile3] = y[tile3] + 2.0 - _launcher(_helion_multi_loop_l2_kernel, (triton.cdiv(32, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1) + triton.cdiv(32, _BLOCK_SIZE_2) * triton.cdiv(64, _BLOCK_SIZE_3) + triton.cdiv(32, _BLOCK_SIZE_4) * triton.cdiv(64, _BLOCK_SIZE_5),), x, result1, y, result2, result3, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, num_warps=4, num_stages=2) + _launcher(_helion_multi_loop_l2_kernel, (triton.cdiv(32, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1) + triton.cdiv(32, _BLOCK_SIZE_2) * triton.cdiv(64, _BLOCK_SIZE_3) + triton.cdiv(32, _BLOCK_SIZE_4) * triton.cdiv(64, _BLOCK_SIZE_5),), x, result1, y, result2, result3, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result1, result2, result3 return (result1, result2, result3) @@ -1193,7 +1193,7 @@ def multi_loop_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_la _BLOCK_SIZE_3 = 16 # src[test_persistent_kernels.py:N]: for tile2 in hl.tile(y.size(), block_size=[16, 16]): # src[test_persistent_kernels.py:N]: result2[tile2] = y[tile2] + 1.0 - _launcher(_helion_multi_loop_kernel, (_NUM_SM,), x, result1, y, result2, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_multi_loop_kernel, (_NUM_SM,), x, result1, y, result2, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result1, result2 return (result1, result2) @@ -1253,7 +1253,7 @@ def multi_loop_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_la _BLOCK_SIZE_3 = 16 # src[test_persistent_kernels.py:N]: for tile2 in hl.tile(y.size(), block_size=[16, 16]): # src[test_persistent_kernels.py:N]: result2[tile2] = y[tile2] + 1.0 - _launcher(_helion_multi_loop_kernel, (triton.cdiv(32, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1) + triton.cdiv(32, _BLOCK_SIZE_2) * triton.cdiv(64, _BLOCK_SIZE_3),), x, result1, y, result2, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_multi_loop_kernel, (triton.cdiv(32, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1) + triton.cdiv(32, _BLOCK_SIZE_2) * triton.cdiv(64, _BLOCK_SIZE_3),), x, result1, y, result2, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result1, result2 return (result1, result2) @@ -1291,7 +1291,7 @@ def add_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher) _NUM_SM = helion.runtime.get_num_sm(x.device) # src[test_persistent_kernels.py:N]: for tile in hl.grid(x.size()): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + y[tile] - _launcher(_helion_add_kernel, (_NUM_SM,), x, y, result, _NUM_SM, num_warps=4, num_stages=2) + _launcher(_helion_add_kernel, (_NUM_SM,), x, y, result, _NUM_SM, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -1339,7 +1339,7 @@ def single_loop_l2_kernel(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_1 = 16 # src[test_persistent_kernels.py:N]: for tile in hl.tile(x.size(), block_size=[16, 16]): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] * 2.0 - _launcher(_helion_single_loop_l2_kernel, (_NUM_SM,), x, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_single_loop_l2_kernel, (_NUM_SM,), x, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -1381,7 +1381,7 @@ def single_loop_l2_kernel(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_1 = 16 # src[test_persistent_kernels.py:N]: for tile in hl.tile(x.size(), block_size=[16, 16]): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] * 2.0 - _launcher(_helion_single_loop_l2_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, result, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_single_loop_l2_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, result, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -1441,7 +1441,7 @@ def complex_shared_kernel(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, *, _NUM_SM = helion.runtime.get_num_sm(x.device) # src[test_persistent_kernels.py:N]: for tile2 in hl.grid(y.size()): # src[test_persistent_kernels.py:N]: result2[tile2] = y[tile2] * z[tile2] - _launcher(_helion_complex_shared_kernel, (_NUM_SM,), x, y, result1, z, result2, _NUM_SM, num_warps=4, num_stages=2) + _launcher(_helion_complex_shared_kernel, (_NUM_SM,), x, y, result1, z, result2, _NUM_SM, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result1, result2 return (result1, result2) @@ -1498,7 +1498,7 @@ def complex_shared_kernel(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, *, _NUM_SM = helion.runtime.get_num_sm(x.device) # src[test_persistent_kernels.py:N]: for tile2 in hl.grid(y.size()): # src[test_persistent_kernels.py:N]: result2[tile2] = y[tile2] * z[tile2] - _launcher(_helion_complex_shared_kernel, (_NUM_SM,), x, y, result1, z, result2, _NUM_SM, num_warps=4, num_stages=2) + _launcher(_helion_complex_shared_kernel, (_NUM_SM,), x, y, result1, z, result2, _NUM_SM, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result1, result2 return (result1, result2) @@ -1548,7 +1548,7 @@ def complex_shared_kernel(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, *, result2 = y.new_empty(y.size()) # src[test_persistent_kernels.py:N]: for tile2 in hl.grid(y.size()): # src[test_persistent_kernels.py:N]: result2[tile2] = y[tile2] * z[tile2] - _launcher(_helion_complex_shared_kernel, (6 * 8 + 6 * 8,), x, y, result1, z, result2, num_warps=4, num_stages=2) + _launcher(_helion_complex_shared_kernel, (6 * 8 + 6 * 8,), x, y, result1, z, result2, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result1, result2 return (result1, result2) @@ -1595,7 +1595,7 @@ def test_kernel(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_1 = 16 # src[test_persistent_kernels.py:N]: for tile in hl.tile(x.size(), block_size=[32, 16]): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + 1 - _launcher(_helion_test_kernel, (_NUM_SM,), x, result, x.size(0), x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_kernel, (_NUM_SM,), x, result, x.size(0), x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -1639,7 +1639,7 @@ def test_kernel(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_1 = 16 # src[test_persistent_kernels.py:N]: for tile in hl.tile(x.size(), block_size=[32, 16]): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + 1 - _launcher(_helion_test_kernel, (_NUM_SM,), x, result, x.size(0), x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_kernel, (_NUM_SM,), x, result, x.size(0), x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -1686,7 +1686,7 @@ def test_kernel(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_1 = 16 # src[test_persistent_kernels.py:N]: for tile in hl.tile(x.size(), block_size=[32, 16]): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + 1 - _launcher(_helion_test_kernel, (_NUM_SM,), x, result, x.size(0), x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_kernel, (_NUM_SM,), x, result, x.size(0), x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -1730,7 +1730,7 @@ def test_kernel(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_1 = 16 # src[test_persistent_kernels.py:N]: for tile in hl.tile(x.size(), block_size=[32, 16]): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + 1 - _launcher(_helion_test_kernel, (_NUM_SM,), x, result, x.size(0), x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_kernel, (_NUM_SM,), x, result, x.size(0), x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -1777,7 +1777,7 @@ def test_kernel(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_1 = 16 # src[test_persistent_kernels.py:N]: for tile in hl.tile(x.size(), block_size=[32, 16]): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + 1 - _launcher(_helion_test_kernel, (_NUM_SM,), x, result, x.size(0), x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_kernel, (_NUM_SM,), x, result, x.size(0), x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -1837,7 +1837,7 @@ def multi_loop_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_la _NUM_SM = helion.runtime.get_num_sm(x.device) # src[test_persistent_kernels.py:N]: for tile2 in hl.grid(y.size()): # src[test_persistent_kernels.py:N]: result2[tile2] = y[tile2] * 3 - _launcher(_helion_multi_loop_kernel, (_NUM_SM,), x, result1, y, result2, _NUM_SM, num_warps=4, num_stages=2) + _launcher(_helion_multi_loop_kernel, (_NUM_SM,), x, result1, y, result2, _NUM_SM, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result1, result2 return (result1, result2) @@ -1894,7 +1894,7 @@ def multi_loop_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_la _NUM_SM = helion.runtime.get_num_sm(x.device) # src[test_persistent_kernels.py:N]: for tile2 in hl.grid(y.size()): # src[test_persistent_kernels.py:N]: result2[tile2] = y[tile2] * 3 - _launcher(_helion_multi_loop_kernel, (_NUM_SM,), x, result1, y, result2, _NUM_SM, num_warps=4, num_stages=2) + _launcher(_helion_multi_loop_kernel, (_NUM_SM,), x, result1, y, result2, _NUM_SM, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result1, result2 return (result1, result2) @@ -1944,7 +1944,7 @@ def multi_loop_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_la result2 = y.new_empty(y.size()) # src[test_persistent_kernels.py:N]: for tile2 in hl.grid(y.size()): # src[test_persistent_kernels.py:N]: result2[tile2] = y[tile2] * 3 - _launcher(_helion_multi_loop_kernel, (8 * 12 + 8 * 12,), x, result1, y, result2, num_warps=4, num_stages=2) + _launcher(_helion_multi_loop_kernel, (8 * 12 + 8 * 12,), x, result1, y, result2, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result1, result2 return (result1, result2) @@ -1997,7 +1997,7 @@ def tensor_descriptor_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_def _BLOCK_SIZE_1 = 32 # src[test_persistent_kernels.py:N]: for tile in hl.tile(x.size(), block_size=[32, 32]): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + y[tile] - _launcher(_helion_tensor_descriptor_kernel, (_NUM_SM,), x, y, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_tensor_descriptor_kernel, (_NUM_SM,), x, y, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -2047,7 +2047,7 @@ def tensor_descriptor_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_def _BLOCK_SIZE_1 = 32 # src[test_persistent_kernels.py:N]: for tile in hl.tile(x.size(), block_size=[32, 32]): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + y[tile] - _launcher(_helion_tensor_descriptor_kernel, (_NUM_SM,), x, y, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_tensor_descriptor_kernel, (_NUM_SM,), x, y, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -2137,7 +2137,7 @@ def test_kernel(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_1 = 16 # src[test_persistent_kernels.py:N]: for tile in hl.tile(x.size(), block_size=[32, 16]): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + 1 - _launcher(_helion_test_kernel, (_NUM_SM,), x, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_kernel, (_NUM_SM,), x, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -2179,7 +2179,7 @@ def test_kernel(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_1 = 16 # src[test_persistent_kernels.py:N]: for tile in hl.tile(x.size(), block_size=[32, 16]): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + 1 - _launcher(_helion_test_kernel, (_NUM_SM,), x, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_kernel, (_NUM_SM,), x, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -2239,7 +2239,7 @@ def multi_add_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_lau _NUM_SM = helion.runtime.get_num_sm(x.device) # src[test_persistent_kernels.py:N]: for tile2 in hl.grid(y.size()): # src[test_persistent_kernels.py:N]: result2[tile2] = y[tile2] * 2.0 - _launcher(_helion_multi_add_kernel, (_NUM_SM,), x, result1, y, result2, x.size(0), x.size(1), y.size(0), y.size(1), result1.stride(0), result1.stride(1), result2.stride(0), result2.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _NUM_SM, num_warps=4, num_stages=2) + _launcher(_helion_multi_add_kernel, (_NUM_SM,), x, result1, y, result2, x.size(0), x.size(1), y.size(0), y.size(1), result1.stride(0), result1.stride(1), result2.stride(0), result2.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _NUM_SM, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result1, result2 return (result1, result2) @@ -2286,7 +2286,7 @@ def simple_add(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher) _BLOCK_SIZE_1 = 16 # src[test_persistent_kernels.py:N]: for tile in hl.tile(x.size(), block_size=[32, 16]): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + y[tile] - _launcher(_helion_simple_add, (_NUM_SM,), x, y, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_simple_add, (_NUM_SM,), x, y, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result @@ -2330,6 +2330,6 @@ def simple_add(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher) _BLOCK_SIZE_1 = 16 # src[test_persistent_kernels.py:N]: for tile in hl.tile(x.size(), block_size=[32, 16]): # src[test_persistent_kernels.py:N]: result[tile] = x[tile] + y[tile] - _launcher(_helion_simple_add, (_NUM_SM,), x, y, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_simple_add, (_NUM_SM,), x, y, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_persistent_kernels.py:N]: return result return result diff --git a/test/test_random.expected b/test/test_random.expected index 0e52ce3c5..67f949c76 100644 --- a/test/test_random.expected +++ b/test/test_random.expected @@ -29,7 +29,7 @@ def rand_kernel_tiled_1d(x: torch.Tensor, seed: int, *, _launcher=_default_launc _BLOCK_SIZE_0 = 128 # src[test_random.py:N]: for tile_m in hl.tile(m): # src[test_random.py:N]: output[tile_m] = hl.rand([tile_m], seed=seed) - _launcher(_helion_rand_kernel_tiled_1d, (triton.cdiv(m, _BLOCK_SIZE_0),), output, output.stride(0), m, seed, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_rand_kernel_tiled_1d, (triton.cdiv(m, _BLOCK_SIZE_0),), output, output.stride(0), m, seed, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_random.py:N]: return output return output @@ -67,7 +67,7 @@ def rand_kernel_tiled_2d(x: torch.Tensor, seed: int, *, _launcher=_default_launc _BLOCK_SIZE_1 = 32 # src[test_random.py:N]: for tile_m, tile_n in hl.tile([m, n]): # src[test_random.py:N]: output[tile_m, tile_n] = hl.rand([tile_m, tile_n], seed=seed) - _launcher(_helion_rand_kernel_tiled_2d, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), output, output.stride(0), output.stride(1), m, n, seed, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_rand_kernel_tiled_2d, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), output, output.stride(0), output.stride(1), m, n, seed, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_random.py:N]: return output return output @@ -115,7 +115,7 @@ def rand_kernel_tiled_3d(x: torch.Tensor, seed: int, *, _launcher=_default_launc # src[test_random.py:N]: output[tile_b, tile_m, tile_n] = hl.rand( # src[test_random.py:N]: [tile_b, tile_m, tile_n], seed=seed # src[test_random.py:N-N]: ... - _launcher(_helion_rand_kernel_tiled_3d, (triton.cdiv(b, _BLOCK_SIZE_0) * triton.cdiv(m, _BLOCK_SIZE_1) * triton.cdiv(n, _BLOCK_SIZE_2),), output, output.stride(0), output.stride(1), output.stride(2), b, m, n, seed, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_rand_kernel_tiled_3d, (triton.cdiv(b, _BLOCK_SIZE_0) * triton.cdiv(m, _BLOCK_SIZE_1) * triton.cdiv(n, _BLOCK_SIZE_2),), output, output.stride(0), output.stride(1), output.stride(2), b, m, n, seed, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_random.py:N]: return output return output @@ -163,7 +163,7 @@ def rand_kernel_normal_order(x: torch.Tensor, seed: int, *, _launcher=_default_l # src[test_random.py:N]: output[tile_m, tile_n, tile_k] = hl.rand( # src[test_random.py:N]: [tile_m, tile_n, tile_k], seed=seed # src[test_random.py:N-N]: ... - _launcher(_helion_rand_kernel_normal_order, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1) * triton.cdiv(k, _BLOCK_SIZE_2),), output, output.stride(0), output.stride(1), output.stride(2), m, n, k, seed, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_rand_kernel_normal_order, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1) * triton.cdiv(k, _BLOCK_SIZE_2),), output, output.stride(0), output.stride(1), output.stride(2), m, n, k, seed, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_random.py:N]: return output return output @@ -211,7 +211,7 @@ def rand_kernel_mixed_order(x: torch.Tensor, seed: int, *, _launcher=_default_la # src[test_random.py:N]: output[tile_m, tile_n, tile_k] = hl.rand( # src[test_random.py:N]: [tile_m, tile_n, tile_k], seed=seed # src[test_random.py:N-N]: ... - _launcher(_helion_rand_kernel_mixed_order, (triton.cdiv(k, _BLOCK_SIZE_0) * triton.cdiv(m, _BLOCK_SIZE_1) * triton.cdiv(n, _BLOCK_SIZE_2),), output, output.stride(0), output.stride(1), output.stride(2), k, m, n, seed, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_rand_kernel_mixed_order, (triton.cdiv(k, _BLOCK_SIZE_0) * triton.cdiv(m, _BLOCK_SIZE_1) * triton.cdiv(n, _BLOCK_SIZE_2),), output, output.stride(0), output.stride(1), output.stride(2), k, m, n, seed, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) # src[test_random.py:N]: return output return output @@ -251,7 +251,7 @@ def rand_kernel_partial_tile(x: torch.Tensor, seed: int, *, _launcher=_default_l _RDIM_SIZE_2 = 8 # src[test_random.py:N]: for tile_m, tile_n in hl.tile([m, n]): # src[test_random.py:N]: output[tile_m, tile_n, :] = hl.rand([tile_m, tile_n, k], seed=seed) - _launcher(_helion_rand_kernel_partial_tile, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), output, output.stride(0), output.stride(1), output.stride(2), m, n, seed, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _RDIM_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_rand_kernel_partial_tile, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), output, output.stride(0), output.stride(1), output.stride(2), m, n, seed, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _RDIM_SIZE_2, num_warps=4, num_stages=1) # src[test_random.py:N]: return output return output @@ -295,7 +295,7 @@ def rand_kernel_with_reduction(x: torch.Tensor, seed: int, *, _launcher=_default # src[test_random.py:N]: tile_values = x[tile_m, :] # src[test_random.py:N]: rand_values = hl.rand([tile_m], seed=seed) # src[test_random.py:N-N]: ... - _launcher(_helion_rand_kernel_with_reduction, (triton.cdiv(m, _BLOCK_SIZE_0),), x, output, output.stride(0), x.stride(0), x.stride(1), m, n, seed, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_rand_kernel_with_reduction, (triton.cdiv(m, _BLOCK_SIZE_0),), x, output, output.stride(0), x.stride(0), x.stride(1), m, n, seed, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_random.py:N]: return output return output @@ -345,6 +345,6 @@ def rand_kernel_with_reduction(x: torch.Tensor, seed: int, *, _launcher=_default # src[test_random.py:N]: tile_values = x[tile_m, :] # src[test_random.py:N]: rand_values = hl.rand([tile_m], seed=seed) # src[test_random.py:N-N]: ... - _launcher(_helion_rand_kernel_with_reduction, (triton.cdiv(m, _BLOCK_SIZE_0),), x, output, output.stride(0), x.stride(0), x.stride(1), m, seed, n, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=2) + _launcher(_helion_rand_kernel_with_reduction, (triton.cdiv(m, _BLOCK_SIZE_0),), x, output, output.stride(0), x.stride(0), x.stride(1), m, seed, n, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=1) # src[test_random.py:N]: return output return output diff --git a/test/test_reduce.expected b/test/test_reduce.expected index bc9a26556..d73f9d65c 100644 --- a/test/test_reduce.expected +++ b/test/test_reduce.expected @@ -59,7 +59,7 @@ def test_argmax_negative_kernel(values: torch.Tensor, indices: torch.Tensor, *, # src[test_reduce.py:N]: row_values = values[i, :] # Shape: [TILE_SIZE, seq_len] # src[test_reduce.py:N]: row_indices = indices[i, :] # Shape: [TILE_SIZE, seq_len] # src[test_reduce.py:N-N]: ... - _launcher(_helion_test_argmax_negative_kernel, (triton.cdiv(3, _BLOCK_SIZE_0),), values, indices, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_argmax_negative_kernel, (triton.cdiv(3, _BLOCK_SIZE_0),), values, indices, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_reduce.py:N]: return result return result @@ -118,7 +118,7 @@ def test_argmax_unpacked_kernel(values: torch.Tensor, indices: torch.Tensor, *, # src[test_reduce.py:N]: row_values = values[i, :] # src[test_reduce.py:N]: row_indices = indices[i, :] # src[test_reduce.py:N-N]: ... - _launcher(_helion_test_argmax_unpacked_kernel, (triton.cdiv(3, _BLOCK_SIZE_0),), values, indices, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_argmax_unpacked_kernel, (triton.cdiv(3, _BLOCK_SIZE_0),), values, indices, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_reduce.py:N]: return result return result @@ -165,7 +165,7 @@ def test_reduce_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_reduce.py:N]: for i in hl.tile(x.size(0)): # src[test_reduce.py:N]: row_data = x[i, :] # Shape: [TILE_SIZE, seq_len] # src[test_reduce.py:N]: result[i] = hl.reduce(add_combine_fn, row_data, dim=1) - _launcher(_helion_test_reduce_kernel, (triton.cdiv(3, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_reduce_kernel, (triton.cdiv(3, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_reduce.py:N]: return result return result @@ -208,7 +208,7 @@ def test_reduce_codegen_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_reduce.py:N]: for i in hl.tile(x.size(0)): # src[test_reduce.py:N]: row_data = x[i, :] # src[test_reduce.py:N]: result[i] = hl.reduce(add_combine_fn, row_data, dim=1) - _launcher(_helion_test_reduce_codegen_kernel, (1,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_reduce_codegen_kernel, (1,), x, result, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_reduce.py:N]: return result return result @@ -254,7 +254,7 @@ def test_reduce_int_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_reduce.py:N]: for i in hl.tile(x.size(0)): # src[test_reduce.py:N]: row_data = x[i, :] # src[test_reduce.py:N]: result[i] = hl.reduce(add_combine_fn, row_data, dim=1) - _launcher(_helion_test_reduce_int_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_reduce_int_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_reduce.py:N]: return result return result @@ -300,7 +300,7 @@ def test_reduce_jit_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_reduce.py:N]: for i in hl.tile(x.size(0)): # src[test_reduce.py:N]: row_data = x[i, :] # src[test_reduce.py:N]: result[i] = hl.reduce(jit_add_combine_fn, row_data, dim=1) - _launcher(_helion_test_reduce_jit_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_reduce_jit_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_reduce.py:N]: return result return result @@ -348,7 +348,7 @@ def test_reduce_max_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_reduce.py:N]: for i in hl.tile(x.size(0)): # src[test_reduce.py:N]: row_data = x[i, :] # src[test_reduce.py:N]: result[i] = hl.reduce(max_combine_fn, row_data, dim=1) - _launcher(_helion_test_reduce_max_kernel, (triton.cdiv(3, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_reduce_max_kernel, (triton.cdiv(3, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_reduce.py:N]: return result return result @@ -396,7 +396,7 @@ def test_reduce_min_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_reduce.py:N]: for i in hl.tile(x.size(0)): # src[test_reduce.py:N]: row_data = x[i, :] # src[test_reduce.py:N]: result[i] = hl.reduce(min_combine_fn, row_data, dim=1) - _launcher(_helion_test_reduce_min_kernel, (triton.cdiv(3, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_reduce_min_kernel, (triton.cdiv(3, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_reduce.py:N]: return result return result @@ -445,7 +445,7 @@ def test_reduce_product_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_reduce.py:N]: for i in hl.tile(x.size(0)): # src[test_reduce.py:N]: row_data = x[i, :] # src[test_reduce.py:N]: result[i] = hl.reduce(mul_combine_fn, row_data, dim=1, other=1.0) - _launcher(_helion_test_reduce_product_kernel, (triton.cdiv(3, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_reduce_product_kernel, (triton.cdiv(3, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_reduce.py:N]: return result return result @@ -501,7 +501,7 @@ def test_reduce_tuple_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_def # src[test_reduce.py:N]: row_x = x[i, :] # src[test_reduce.py:N]: row_y = y[i, :] # src[test_reduce.py:N-N]: ... - _launcher(_helion_test_reduce_tuple_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, y, result_x, result_y, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_reduce_tuple_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, y, result_x, result_y, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_reduce.py:N]: return result_x, result_y return (result_x, result_y) @@ -561,7 +561,7 @@ def test_reduce_tuple_unpacked_kernel(x: torch.Tensor, y: torch.Tensor, *, _laun # src[test_reduce.py:N]: row_x = x[i, :] # src[test_reduce.py:N]: row_y = y[i, :] # src[test_reduce.py:N-N]: ... - _launcher(_helion_test_reduce_tuple_unpacked_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, y, result_x, result_y, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_reduce_tuple_unpacked_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, y, result_x, result_y, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_reduce.py:N]: return result_x, result_y return (result_x, result_y) @@ -620,7 +620,7 @@ def test_tuple_oneline_kernel(values: torch.Tensor, indices: torch.Tensor, *, _l # src[test_reduce.py:N]: row_values = values[i, :] # Shape: [TILE_SIZE, seq_len] # src[test_reduce.py:N]: row_indices = indices[i, :] # Shape: [TILE_SIZE, seq_len] # src[test_reduce.py:N-N]: ... - _launcher(_helion_test_tuple_oneline_kernel, (triton.cdiv(3, _BLOCK_SIZE_0),), values, indices, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_tuple_oneline_kernel, (triton.cdiv(3, _BLOCK_SIZE_0),), values, indices, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_reduce.py:N]: return result return result @@ -679,7 +679,7 @@ def test_tuple_twoline_kernel(values: torch.Tensor, indices: torch.Tensor, *, _l # src[test_reduce.py:N]: row_values = values[i, :] # Shape: [TILE_SIZE, seq_len] # src[test_reduce.py:N]: row_indices = indices[i, :] # Shape: [TILE_SIZE, seq_len] # src[test_reduce.py:N-N]: ... - _launcher(_helion_test_tuple_twoline_kernel, (triton.cdiv(3, _BLOCK_SIZE_0),), values, indices, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_tuple_twoline_kernel, (triton.cdiv(3, _BLOCK_SIZE_0),), values, indices, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_reduce.py:N]: return result return result @@ -730,6 +730,6 @@ def test_reduce_keep_dims_kernel(x: torch.Tensor, *, _launcher=_default_launcher # src[test_reduce.py:N]: row_data = x[i, :] # src[test_reduce.py:N]: result[i, :] = hl.reduce( # src[test_reduce.py:N-N]: ... - _launcher(_helion_test_reduce_keep_dims_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_reduce_keep_dims_kernel, (triton.cdiv(2, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_reduce.py:N]: return result return result diff --git a/test/test_reductions.expected b/test/test_reductions.expected index adf1c49dc..7d0107d6e 100644 --- a/test/test_reductions.expected +++ b/test/test_reductions.expected @@ -44,7 +44,7 @@ def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], o _RDIM_SIZE_1 = 512 # src[test_reductions.py:N]: for tile_n in hl.tile(n): # src[test_reductions.py:N]: out[tile_n] = fn(x[tile_n, :], dim=-1) - _launcher(_helion_reduce_kernel, (triton.cdiv(512, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_reduce_kernel, (triton.cdiv(512, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_reductions.py:N]: return out return out @@ -86,7 +86,7 @@ def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], o _REDUCTION_BLOCK_1 = 16 # src[test_reductions.py:N]: for tile_n in hl.tile(n): # src[test_reductions.py:N]: out[tile_n] = fn(x[tile_n, :], dim=-1) - _launcher(_helion_reduce_kernel, (512,), x, out, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=2) + _launcher(_helion_reduce_kernel, (512,), x, out, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=1) # src[test_reductions.py:N]: return out return out @@ -130,7 +130,7 @@ def _helion_layer_norm_fwd(x, weight, bias, out, eps, _BLOCK_SIZE_0: tl.constexp # src[test_reductions.py:N]: normalized = (acc - mean) * torch.rsqrt(var + eps) v_8 = acc - v_3 v_9 = v_7 + eps - v_10 = libdevice.rsqrt(v_9) + v_10 = tl.rsqrt(v_9) v_11 = v_8 * v_10 # src[test_reductions.py:N]: acc = normalized * (weight[:].to(torch.float32)) + ( load_1 = tl.load(weight + indices_1 * 1, None) @@ -159,7 +159,7 @@ def layer_norm_fwd(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, ep # src[test_reductions.py:N]: acc = x[tile_m, :].to(torch.float32) # src[test_reductions.py:N]: mean = hl.full([n], 0.0, acc.dtype) # src[test_reductions.py:N-N]: ... - _launcher(_helion_layer_norm_fwd, (triton.cdiv(2, _BLOCK_SIZE_0),), x, weight, bias, out, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_layer_norm_fwd, (triton.cdiv(2, _BLOCK_SIZE_0),), x, weight, bias, out, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_reductions.py:N]: return out return out @@ -180,7 +180,7 @@ def _helion_rsqrt_fp16_kernel(x, result, _BLOCK_SIZE_0: tl.constexpr): # src[test_reductions.py:N]: result[tile] = torch.rsqrt(x[tile]) load = tl.load(x + indices_0 * 1, None) v_0 = tl.cast(load, tl.float32) - v_1 = libdevice.rsqrt(v_0) + v_1 = tl.rsqrt(v_0) v_2 = tl.cast(v_1, tl.float16) tl.store(result + indices_0 * 1, v_2, None) @@ -192,7 +192,7 @@ def rsqrt_fp16_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_reductions.py:N]: for tile in hl.tile(x.size(0)): # src[test_reductions.py:N]: # This should now work via fp32 fallback # src[test_reductions.py:N]: result[tile] = torch.rsqrt(x[tile]) - _launcher(_helion_rsqrt_fp16_kernel, (triton.cdiv(32, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_rsqrt_fp16_kernel, (triton.cdiv(32, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_reductions.py:N]: return result return result @@ -214,7 +214,7 @@ def _helion_multi_math_ops_fp16_kernel(x, result, _BLOCK_SIZE_0: tl.constexpr): # src[test_reductions.py:N]: result[tile, 0] = torch.rsqrt(x[tile]) load = tl.load(x + indices_0 * 1, None) v_0 = tl.cast(load, tl.float32) - v_1 = libdevice.rsqrt(v_0) + v_1 = tl.rsqrt(v_0) v_2 = tl.cast(v_1, tl.float16) tl.store(result + (indices_0 * 8 + 0 * 1), v_2, None) # src[test_reductions.py:N]: result[tile, 1] = torch.sqrt(x[tile]) @@ -269,7 +269,7 @@ def multi_math_ops_fp16_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_reductions.py:N]: # Test multiple operations that have confirmed fallbacks # src[test_reductions.py:N]: result[tile, 0] = torch.rsqrt(x[tile]) # src[test_reductions.py:N-N]: ... - _launcher(_helion_multi_math_ops_fp16_kernel, (triton.cdiv(16, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_multi_math_ops_fp16_kernel, (triton.cdiv(16, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_reductions.py:N]: return result return result @@ -307,7 +307,7 @@ def _helion_layer_norm_fwd_repro(x, weight, bias, out, eps, _BLOCK_SIZE_0: tl.co v_10 = x_part - v_9 v_11 = tl.cast(v_8, tl.float32) v_12 = v_11 + eps - v_13 = libdevice.rsqrt(v_12) + v_13 = tl.rsqrt(v_12) v_14 = tl.cast(v_10, tl.float32) v_15 = v_14 * v_13 # src[test_reductions.py:N]: out[tile_m, :] = normalized * (weight[:].to(torch.float32)) + ( @@ -338,7 +338,7 @@ def layer_norm_fwd_repro(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tens # src[test_reductions.py:N]: x_part = x[tile_m, :] # src[test_reductions.py:N]: var, mean = torch.var_mean(x_part, dim=-1, keepdim=True, correction=0) # src[test_reductions.py:N-N]: ... - _launcher(_helion_layer_norm_fwd_repro, (triton.cdiv(32, _BLOCK_SIZE_0),), x, weight, bias, out, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_layer_norm_fwd_repro, (triton.cdiv(32, _BLOCK_SIZE_0),), x, weight, bias, out, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_reductions.py:N]: return out return out @@ -389,7 +389,7 @@ def _helion_layer_norm_fwd_repro(x, weight, bias, out, eps, _BLOCK_SIZE_0: tl.co # src[test_reductions.py:N]: normalized = (x_part - mean) * torch.rsqrt(var.to(torch.float32) + eps) v_12 = tl.cast(v_10, tl.float32) v_13 = v_12 + eps - v_14 = libdevice.rsqrt(v_13) + v_14 = tl.rsqrt(v_13) # src[test_reductions.py:N]: x_part = x[tile_m, :] for roffset_1 in tl.range(0, 64, _REDUCTION_BLOCK_1): rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32) @@ -429,7 +429,7 @@ def layer_norm_fwd_repro(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tens # src[test_reductions.py:N]: x_part = x[tile_m, :] # src[test_reductions.py:N]: var, mean = torch.var_mean(x_part, dim=-1, keepdim=True, correction=0) # src[test_reductions.py:N-N]: ... - _launcher(_helion_layer_norm_fwd_repro, (triton.cdiv(32, _BLOCK_SIZE_0),), x, weight, bias, out, eps, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=2) + _launcher(_helion_layer_norm_fwd_repro, (triton.cdiv(32, _BLOCK_SIZE_0),), x, weight, bias, out, eps, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=1) # src[test_reductions.py:N]: return out return out @@ -639,7 +639,7 @@ def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], o _RDIM_SIZE_1 = 512 # src[test_reductions.py:N]: for tile_n in hl.tile(n): # src[test_reductions.py:N]: out[tile_n] = fn(x[tile_n, :], dim=-1) - _launcher(_helion_reduce_kernel, (triton.cdiv(512, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_reduce_kernel, (triton.cdiv(512, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_reductions.py:N]: return out return out @@ -687,7 +687,7 @@ def _helion_layer_norm_reduction(x, weight, bias, out, eps, _BLOCK_SIZE_0: tl.co v_9 = var_mean_extra_2 / v_8.to(tl.float32) # src[test_reductions.py:N]: normalized = (acc - mean) * torch.rsqrt(var + eps) v_10 = v_9 + eps - v_11 = libdevice.rsqrt(v_10) + v_11 = tl.rsqrt(v_10) # src[test_reductions.py:N]: acc = x[tile_m, :].to(torch.float32) for roffset_1 in tl.range(0, 64, _REDUCTION_BLOCK_1): rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32) @@ -728,7 +728,7 @@ def layer_norm_reduction(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tens # src[test_reductions.py:N]: acc = x[tile_m, :].to(torch.float32) # src[test_reductions.py:N]: var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0) # src[test_reductions.py:N-N]: ... - _launcher(_helion_layer_norm_reduction, (triton.cdiv(32, _BLOCK_SIZE_0),), x, weight, bias, out, eps, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=2) + _launcher(_helion_layer_norm_reduction, (triton.cdiv(32, _BLOCK_SIZE_0),), x, weight, bias, out, eps, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=1) # src[test_reductions.py:N]: return out return out @@ -764,7 +764,7 @@ def sum_kernel(x: torch.Tensor, *, _launcher=_default_launcher): _RDIM_SIZE_1 = 512 # src[test_reductions.py:N]: for tile_n in hl.tile(n): # src[test_reductions.py:N]: out[tile_n] = x[tile_n, :].sum(-1) - _launcher(_helion_sum_kernel, (512,), x, out, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_sum_kernel, (512,), x, out, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_reductions.py:N]: return out return out @@ -809,7 +809,7 @@ def sum_kernel_keepdims(x: torch.Tensor, *, _launcher=_default_launcher): _RDIM_SIZE_1 = 512 # src[test_reductions.py:N]: for tile_m in hl.tile(m): # src[test_reductions.py:N]: out[:, tile_m] = x[:, tile_m].sum(0, keepdim=True) - _launcher(_helion_sum_kernel_keepdims, (triton.cdiv(512, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_sum_kernel_keepdims, (triton.cdiv(512, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_reductions.py:N]: return out return out @@ -851,6 +851,6 @@ def sum_kernel(x: torch.Tensor, *, _launcher=_default_launcher): _REDUCTION_BLOCK_1 = 64 # src[test_reductions.py:N]: for tile_n in hl.tile(n): # src[test_reductions.py:N]: out[tile_n] = x[tile_n, :].sum(-1) - _launcher(_helion_sum_kernel, (triton.cdiv(512, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=2) + _launcher(_helion_sum_kernel, (triton.cdiv(512, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=1) # src[test_reductions.py:N]: return out return out diff --git a/test/test_register_tunable.expected b/test/test_register_tunable.expected index dca66f22b..fe76170b2 100644 --- a/test/test_register_tunable.expected +++ b/test/test_register_tunable.expected @@ -2,7 +2,7 @@ This file is automatically generated by assertExpectedJournal calls in test_regi Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set. --- assertExpectedJournal(TestRegisterTunable.test_integer_fragment) -helion.Config(block_sizes=[128], indexing=['pointer', 'pointer'], load_eviction_policies=[''], multiplier=3, num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]) +helion.Config(block_sizes=[128], indexing=['pointer', 'pointer'], load_eviction_policies=[''], multiplier=3, num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]) --- assertExpectedJournal(TestRegisterTunable.test_integer_fragment) from __future__ import annotations @@ -37,7 +37,7 @@ def kernel_with_int_param(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 64 # src[test_register_tunable.py:N]: for tile_n in hl.tile([n]): # src[test_register_tunable.py:N]: out[tile_n] = x[tile_n] * multiplier - _launcher(_helion_kernel_with_int_param, (triton.cdiv(128, _BLOCK_SIZE_0),), x, out, multiplier, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_with_int_param, (triton.cdiv(128, _BLOCK_SIZE_0),), x, out, multiplier, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_register_tunable.py:N]: return out return out @@ -151,7 +151,7 @@ def kernel_with_tunable(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 2 * block_size # src[test_register_tunable.py:N]: for tile_n in hl.tile([n], block_size=[block_size * 2]): # src[test_register_tunable.py:N]: out[tile_n] = x[tile_n] * 2.0 - _launcher(_helion_kernel_with_tunable, (triton.cdiv(128, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_with_tunable, (triton.cdiv(128, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_register_tunable.py:N]: return out return out @@ -189,6 +189,6 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 64 # src[test_register_tunable.py:N]: for tile in hl.tile(m, block_size=block_m): # src[test_register_tunable.py:N]: partial[tile.begin // block_m] = x[tile].sum() - _launcher(_helion_fn, (triton.cdiv(1024, _BLOCK_SIZE_0),), x, partial, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(1024, _BLOCK_SIZE_0),), x, partial, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_register_tunable.py:N]: return partial.sum() return partial.sum() diff --git a/test/test_rng.expected b/test/test_rng.expected index 3f3fb1531..7fbf80627 100644 --- a/test/test_rng.expected +++ b/test/test_rng.expected @@ -73,7 +73,7 @@ def multiple_rng_ops_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_rng.py:N]: # Two independent rand operations # src[test_rng.py:N]: rand1[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n]) # src[test_rng.py:N-N]: ... - _launcher(_helion_multiple_rng_ops_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, normal.stride(0), normal.stride(1), rand1.stride(0), rand1.stride(1), rand2.stride(0), rand2.stride(1), randn_a.stride(0), randn_a.stride(1), randn_b.stride(0), randn_b.stride(1), randn_c.stride(0), randn_c.stride(1), uniform.stride(0), uniform.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _rng_seed_buffer, num_warps=4, num_stages=2) + _launcher(_helion_multiple_rng_ops_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, normal.stride(0), normal.stride(1), rand1.stride(0), rand1.stride(1), rand2.stride(0), rand2.stride(1), randn_a.stride(0), randn_a.stride(1), randn_b.stride(0), randn_b.stride(1), randn_c.stride(0), randn_c.stride(1), uniform.stride(0), uniform.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _rng_seed_buffer, num_warps=4, num_stages=1) # src[test_rng.py:N]: randn_sum = randn_a + randn_b + randn_c randn_sum = randn_a + randn_b + randn_c # src[test_rng.py:N]: return rand1, rand2, uniform, normal, randn_sum diff --git a/test/test_signal_wait.expected b/test/test_signal_wait.expected index a2931e374..bdc5bdb93 100644 --- a/test/test_signal_wait.expected +++ b/test/test_signal_wait.expected @@ -42,7 +42,7 @@ def gmem_multi_bar_sync_kernel(signal_pad: torch.Tensor, *, _launcher=_default_l # src[test_signal_wait.py:N]: for tile in hl.tile(N, block_size=N): # src[test_signal_wait.py:N]: hl.signal( # src[test_signal_wait.py:N-N]: ... - _launcher(_helion_gmem_multi_bar_sync_kernel, (4,), signal_pad, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_gmem_multi_bar_sync_kernel, (4,), signal_pad, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_signal_wait.py:N]: return signal_pad return signal_pad @@ -68,7 +68,7 @@ def gmem_signal_scalar_bar_kernel(signal_pad: torch.Tensor, *, _launcher=_defaul n, = signal_pad.shape # src[test_signal_wait.py:N]: for i in hl.grid(n): # src[test_signal_wait.py:N]: hl.signal(signal_pad, [i], signal=1) - _launcher(_helion_gmem_signal_scalar_bar_kernel, (4,), signal_pad, num_warps=4, num_stages=2) + _launcher(_helion_gmem_signal_scalar_bar_kernel, (4,), signal_pad, num_warps=4, num_stages=1) # src[test_signal_wait.py:N]: return signal_pad return signal_pad @@ -94,7 +94,7 @@ def gmem_signal_cas_kernel(signal_pad: torch.Tensor, *, _launcher=_default_launc n, = signal_pad.shape # src[test_signal_wait.py:N]: for i in hl.grid(n): # src[test_signal_wait.py:N]: hl.signal(signal_pad, [i], signal=1, wait_for=0) - _launcher(_helion_gmem_signal_cas_kernel, (4,), signal_pad, num_warps=4, num_stages=2) + _launcher(_helion_gmem_signal_cas_kernel, (4,), signal_pad, num_warps=4, num_stages=1) # src[test_signal_wait.py:N]: return signal_pad return signal_pad @@ -123,7 +123,7 @@ def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor, *, _launcher=_defaul _BLOCK_SIZE_0 = 4 # src[test_signal_wait.py:N]: for tile in hl.tile(n): # src[test_signal_wait.py:N]: hl.signal(signal_pad, [tile], signal=1) - _launcher(_helion_gmem_signal_tensor_bar_kernel, (triton.cdiv(16, _BLOCK_SIZE_0),), signal_pad, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_gmem_signal_tensor_bar_kernel, (triton.cdiv(16, _BLOCK_SIZE_0),), signal_pad, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_signal_wait.py:N]: return signal_pad return signal_pad @@ -152,7 +152,7 @@ def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor, *, _launcher=_defaul _BLOCK_SIZE_0 = 4 # src[test_signal_wait.py:N]: for tile in hl.tile(n): # src[test_signal_wait.py:N]: hl.signal(signal_pad, [tile], wait_for=0, signal=1) - _launcher(_helion_gmem_signal_tensor_bar_kernel, (triton.cdiv(16, _BLOCK_SIZE_0),), signal_pad, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_gmem_signal_tensor_bar_kernel, (triton.cdiv(16, _BLOCK_SIZE_0),), signal_pad, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_signal_wait.py:N]: return signal_pad return signal_pad @@ -183,7 +183,7 @@ def gmem_signal_pointers_kernel(signal_pad_ptrs: torch.Tensor, example: torch.Te # src[test_signal_wait.py:N]: ptr_tile = signal_pad_ptrs[:] # src[test_signal_wait.py:N]: stack_signal_pad = hl.stacktensor_like(example, ptr_tile) # src[test_signal_wait.py:N-N]: ... - _launcher(_helion_gmem_signal_pointers_kernel, (4,), signal_pad_ptrs, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_gmem_signal_pointers_kernel, (4,), signal_pad_ptrs, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_signal_wait.py:N]: return signal_pad_ptrs return signal_pad_ptrs @@ -225,7 +225,7 @@ def wait_for_2d_tile_kernel(signal_pad: torch.Tensor, x: torch.Tensor, *, _launc # src[test_signal_wait.py:N]: for tile_n, tile_m in hl.tile([n, m]): # src[test_signal_wait.py:N]: hl.wait(signal_pad, [tile_n.id, tile_m.id], signal=1) # src[test_signal_wait.py:N]: out[tile_n, tile_m] = x[tile_n, tile_m] - _launcher(_helion_wait_for_2d_tile_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), signal_pad, x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_wait_for_2d_tile_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), signal_pad, x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_signal_wait.py:N]: return out return out @@ -256,7 +256,7 @@ def gmem_wait_kernel(signal_pad: torch.Tensor, *, _launcher=_default_launcher): # src[test_signal_wait.py:N]: for i in hl.grid(n): # src[test_signal_wait.py:N]: hl.wait(signal_pad, [i], signal=1) # src[test_signal_wait.py:N]: out[i] = i - _launcher(_helion_gmem_wait_kernel, (4,), signal_pad, out, num_warps=4, num_stages=2) + _launcher(_helion_gmem_wait_kernel, (4,), signal_pad, out, num_warps=4, num_stages=1) # src[test_signal_wait.py:N]: return out return out @@ -295,7 +295,7 @@ def gmem_wait_multi_bar_kernel(signal_pad: torch.Tensor, *, _launcher=_default_l # src[test_signal_wait.py:N]: for tile in hl.tile(N, block_size=n): # src[test_signal_wait.py:N]: hl.wait(signal_pad, [tile], signal=1) # src[test_signal_wait.py:N]: out[tile.id] = tile.id - _launcher(_helion_gmem_wait_multi_bar_kernel, (triton.cdiv(16, _BLOCK_SIZE_0),), signal_pad, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_gmem_wait_multi_bar_kernel, (triton.cdiv(16, _BLOCK_SIZE_0),), signal_pad, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_signal_wait.py:N]: return out return out @@ -324,7 +324,7 @@ def gmem_wait_multi_bar_kernel_cas(signal_pad: torch.Tensor, *, _launcher=_defau _BLOCK_SIZE_0 = 4 # src[test_signal_wait.py:N]: for tile in hl.tile(N, block_size=n): # src[test_signal_wait.py:N]: hl.wait(signal_pad, [tile], signal=1, update=2) - _launcher(_helion_gmem_wait_multi_bar_kernel_cas, (triton.cdiv(16, _BLOCK_SIZE_0),), signal_pad, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_gmem_wait_multi_bar_kernel_cas, (triton.cdiv(16, _BLOCK_SIZE_0),), signal_pad, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_signal_wait.py:N]: return signal_pad return signal_pad @@ -359,6 +359,6 @@ def gmem_wait_pointers_kernel(signal_pad_ptrs: torch.Tensor, example: torch.Tens # src[test_signal_wait.py:N]: dev_tile = signal_pad_ptrs[:] # src[test_signal_wait.py:N]: stack_tensor = hl.stacktensor_like(example, dev_tile) # src[test_signal_wait.py:N-N]: ... - _launcher(_helion_gmem_wait_pointers_kernel, (4,), signal_pad_ptrs, out, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_gmem_wait_pointers_kernel, (4,), signal_pad_ptrs, out, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_signal_wait.py:N]: return out return out diff --git a/test/test_specialize.expected b/test/test_specialize.expected index d237546d5..c8b19be6f 100644 --- a/test/test_specialize.expected +++ b/test/test_specialize.expected @@ -38,7 +38,7 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_specialize.py:N]: acc = hl.zeros([tile, helion.next_power_of_2(x.size(1))]) # src[test_specialize.py:N]: acc += x[tile, :] + 1 # src[test_specialize.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(500, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(500, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_specialize.py:N]: return out return out @@ -83,7 +83,7 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_specialize.py:N]: acc = hl.zeros([tile, helion.next_power_of_2(x.size(1))]) # src[test_specialize.py:N]: acc2 = hl.full([tile, helion.next_power_of_2(x.size(1))], 1.0) # src[test_specialize.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(500, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(500, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_specialize.py:N]: return out return out @@ -135,7 +135,7 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_specialize.py:N]: acc = hl.full( # src[test_specialize.py:N]: [tile, helion.next_power_of_2(x.size(1))], # src[test_specialize.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(500, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(500, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_specialize.py:N]: return out return out @@ -176,7 +176,7 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_specialize.py:N]: acc = hl.zeros([tile, helion.next_power_of_2(x.size(1))]) # src[test_specialize.py:N]: acc = acc + x[tile, :] + 1 # src[test_specialize.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(500, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(500, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_specialize.py:N]: return out return out @@ -217,7 +217,7 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_specialize.py:N]: acc = hl.zeros([tile, helion.next_power_of_2(x.size(1))]) # src[test_specialize.py:N]: acc = x[tile, :] + acc + 1 # src[test_specialize.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(500, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(500, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_specialize.py:N]: return out return out @@ -256,7 +256,7 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_specialize.py:N]: acc = hl.zeros([tile, x.size(1)]) # src[test_specialize.py:N]: acc += x[tile, :] + 1 # src[test_specialize.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(512, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_specialize.py:N]: return out return out @@ -297,7 +297,7 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_1 = 32 # src[test_specialize.py:N]: for tile in hl.tile(x.size()): # src[test_specialize.py:N]: out[tile] = x[tile] * scale - _launcher(_helion_fn, (triton.cdiv(500, _BLOCK_SIZE_0) * triton.cdiv(500, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(500, _BLOCK_SIZE_0) * triton.cdiv(500, _BLOCK_SIZE_1),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_specialize.py:N]: return out return out @@ -331,7 +331,7 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher): _RDIM_SIZE_1 = 512 # src[test_specialize.py:N]: for tile in hl.tile(x.size(0)): # src[test_specialize.py:N]: out[tile] = x[tile, :].sum(-1) - _launcher(_helion_fn, (triton.cdiv(500, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(500, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_specialize.py:N]: return out return out @@ -365,7 +365,7 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0_1 = 32 # src[test_specialize.py:N]: for tile in hl.tile(x.size()): # src[test_specialize.py:N]: out[tile] = x[tile] * scale - _launcher(_helion_fn, (triton.cdiv(262144, _BLOCK_SIZE_0_1), 1, 1), x, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(262144, _BLOCK_SIZE_0_1), 1, 1), x, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=1) # src[test_specialize.py:N]: return out return out @@ -439,7 +439,7 @@ def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_d # src[test_specialize.py:N]: # Device-side tensor creation should be padded to 64 # src[test_specialize.py:N]: grad_w_m = tensor_factory_fn(x, weight_shape, dtype=torch.float32) # src[test_specialize.py:N-N]: ... - _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_specialize.py:N]: return grad_weight.sum(0).to(x.dtype) return grad_weight.sum(0).to(x.dtype) @@ -513,7 +513,7 @@ def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_d # src[test_specialize.py:N]: # Device-side tensor creation should be padded to 64 # src[test_specialize.py:N]: grad_w_m = tensor_factory_fn(x, weight_shape, dtype=torch.float32) # src[test_specialize.py:N-N]: ... - _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_specialize.py:N]: return grad_weight.sum(0).to(x.dtype) return grad_weight.sum(0).to(x.dtype) @@ -587,7 +587,7 @@ def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_d # src[test_specialize.py:N]: # Device-side tensor creation should be padded to 64 # src[test_specialize.py:N]: grad_w_m = tensor_factory_fn(x, weight_shape, dtype=torch.float32) # src[test_specialize.py:N-N]: ... - _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_specialize.py:N]: return grad_weight.sum(0).to(x.dtype) return grad_weight.sum(0).to(x.dtype) @@ -661,7 +661,7 @@ def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_d # src[test_specialize.py:N]: # Device-side tensor creation should be padded to 64 # src[test_specialize.py:N]: grad_w_m = tensor_factory_fn(x, weight_shape, dtype=torch.float32) # src[test_specialize.py:N-N]: ... - _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_specialize.py:N]: return grad_weight.sum(0).to(x.dtype) return grad_weight.sum(0).to(x.dtype) @@ -733,7 +733,7 @@ def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_d # src[test_specialize.py:N]: # Device-side tensor creation should be padded to 64 # src[test_specialize.py:N]: grad_w_m = tensor_factory_fn(x, weight_shape, dtype=torch.float32) # src[test_specialize.py:N-N]: ... - _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_specialize.py:N]: return grad_weight.sum(0).to(x.dtype) return grad_weight.sum(0).to(x.dtype) @@ -805,7 +805,7 @@ def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_d # src[test_specialize.py:N]: # Device-side tensor creation should be padded to 64 # src[test_specialize.py:N]: grad_w_m = tensor_factory_fn(x, weight_shape, dtype=torch.float32) # src[test_specialize.py:N-N]: ... - _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_specialize.py:N]: return grad_weight.sum(0).to(x.dtype) return grad_weight.sum(0).to(x.dtype) @@ -877,7 +877,7 @@ def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_d # src[test_specialize.py:N]: # Device-side tensor creation should be padded to 64 # src[test_specialize.py:N]: grad_w_m = tensor_factory_fn(x, weight_shape, dtype=torch.float32) # src[test_specialize.py:N-N]: ... - _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_specialize.py:N]: return grad_weight.sum(0).to(x.dtype) return grad_weight.sum(0).to(x.dtype) @@ -949,7 +949,7 @@ def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_d # src[test_specialize.py:N]: # Device-side tensor creation should be padded to 64 # src[test_specialize.py:N]: grad_w_m = tensor_factory_fn(x, weight_shape, dtype=torch.float32) # src[test_specialize.py:N-N]: ... - _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_specialize.py:N]: return grad_weight.sum(0).to(x.dtype) return grad_weight.sum(0).to(x.dtype) @@ -1023,7 +1023,7 @@ def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_d # src[test_specialize.py:N]: # Device-side tensor creation should be padded to 64 # src[test_specialize.py:N]: grad_w_m = tensor_factory_fn(x, weight_shape, dtype=torch.float32) # src[test_specialize.py:N-N]: ... - _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_specialize.py:N]: return grad_weight.sum(0).to(x.dtype) return grad_weight.sum(0).to(x.dtype) @@ -1097,6 +1097,6 @@ def reduce_kernel(x: torch.Tensor, tensor_factory_fn, test_host, *, _launcher=_d # src[test_specialize.py:N]: # Device-side tensor creation should be padded to 64 # src[test_specialize.py:N]: grad_w_m = tensor_factory_fn(x, weight_shape, dtype=torch.float32) # src[test_specialize.py:N-N]: ... - _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_reduce_kernel, (triton.cdiv(128, _BLOCK_SIZE_0),), x, grad_weight, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_specialize.py:N]: return grad_weight.sum(0).to(x.dtype) return grad_weight.sum(0).to(x.dtype) diff --git a/test/test_stack_tensor.expected b/test/test_stack_tensor.expected index f9f0b394b..7886a9217 100644 --- a/test/test_stack_tensor.expected +++ b/test/test_stack_tensor.expected @@ -36,7 +36,7 @@ def stack_load_kernel_2d(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, * # src[test_stack_tensor.py:N]: ptr_tile = dev_ptrs[:, :] # src[test_stack_tensor.py:N]: tensors = hl.stacktensor_like(example_tensor, ptr_tile) # src[test_stack_tensor.py:N-N]: ... - _launcher(_helion_stack_load_kernel_2d, (triton.cdiv(4, _BLOCK_SIZE_0),), dev_ptrs, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_stack_load_kernel_2d, (triton.cdiv(4, _BLOCK_SIZE_0),), dev_ptrs, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_stack_tensor.py:N]: return out return outfrom __future__ import annotations @@ -77,7 +77,7 @@ def stack_load_2d_looped(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, * # src[test_stack_tensor.py:N]: for i in range(M1): # src[test_stack_tensor.py:N]: ptr_tile = dev_ptrs[i, :] # src[test_stack_tensor.py:N-N]: ... - _launcher(_helion_stack_load_2d_looped, (triton.cdiv(4, _BLOCK_SIZE_0),), dev_ptrs, out, _BLOCK_SIZE_0, _RDIM_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_stack_load_2d_looped, (triton.cdiv(4, _BLOCK_SIZE_0),), dev_ptrs, out, _BLOCK_SIZE_0, _RDIM_SIZE_2, num_warps=4, num_stages=1) # src[test_stack_tensor.py:N]: return out return out @@ -121,7 +121,7 @@ def stack_load_kernel(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *, _ # src[test_stack_tensor.py:N]: ptr_tile = dev_ptrs[:] # src[test_stack_tensor.py:N]: tensors = hl.stacktensor_like(example_tensor, ptr_tile) # src[test_stack_tensor.py:N-N]: ... - _launcher(_helion_stack_load_kernel, (triton.cdiv(4, _BLOCK_SIZE_0) * triton.cdiv(4, _BLOCK_SIZE_1),), dev_ptrs, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _RDIM_SIZE_2, num_warps=4, num_stages=2) + _launcher(_helion_stack_load_kernel, (triton.cdiv(4, _BLOCK_SIZE_0) * triton.cdiv(4, _BLOCK_SIZE_1),), dev_ptrs, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _RDIM_SIZE_2, num_warps=4, num_stages=1) # src[test_stack_tensor.py:N]: return out return out @@ -158,7 +158,7 @@ def stack_load_kernel(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *, _ # src[test_stack_tensor.py:N]: ptr_tile = dev_ptrs[:] # src[test_stack_tensor.py:N]: tensors = hl.stacktensor_like(example_tensor, ptr_tile) # src[test_stack_tensor.py:N-N]: ... - _launcher(_helion_stack_load_kernel, (4,), dev_ptrs, out, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_stack_load_kernel, (4,), dev_ptrs, out, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_stack_tensor.py:N]: return out return out @@ -211,7 +211,7 @@ def stack_load_w_mask(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *, _ # src[test_stack_tensor.py:N]: for stack_tile in hl.tile(M, block_size=4): # src[test_stack_tensor.py:N]: ptr_tile = dev_ptrs[stack_tile] # src[test_stack_tensor.py:N-N]: ... - _launcher(_helion_stack_load_w_mask, (triton.cdiv(15, _BLOCK_SIZE_0),), dev_ptrs, out, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_stack_load_w_mask, (triton.cdiv(15, _BLOCK_SIZE_0),), dev_ptrs, out, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_stack_tensor.py:N]: return out return out @@ -250,7 +250,7 @@ def stack_store_kernel(x: torch.Tensor, dev_ptrs: torch.Tensor, example_tensor: # src[test_stack_tensor.py:N]: ptr_tile = dev_ptrs[:] # src[test_stack_tensor.py:N]: tensors = hl.stacktensor_like(example_tensor, ptr_tile) # src[test_stack_tensor.py:N-N]: ... - _launcher(_helion_stack_store_kernel, (triton.cdiv(15, _BLOCK_SIZE_0),), dev_ptrs, x, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_stack_store_kernel, (triton.cdiv(15, _BLOCK_SIZE_0),), dev_ptrs, x, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) --- assertExpectedJournal(TestStackTensor.test_stack_store_grid) from __future__ import annotations @@ -281,7 +281,7 @@ def stack_store_kernel(x: torch.Tensor, dev_ptrs: torch.Tensor, example_tensor: # src[test_stack_tensor.py:N]: ptr_tile = dev_ptrs[:] # src[test_stack_tensor.py:N]: tensors = hl.stacktensor_like(example_tensor, ptr_tile) # src[test_stack_tensor.py:N-N]: ... - _launcher(_helion_stack_store_kernel, (16,), dev_ptrs, x, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_stack_store_kernel, (16,), dev_ptrs, x, _RDIM_SIZE_1, num_warps=4, num_stages=1) --- assertExpectedJournal(TestStackTensor.test_stack_store_scatter) from __future__ import annotations @@ -313,4 +313,4 @@ def stack_store_arange_kernel(dev_ptrs: torch.Tensor, example_tensor: torch.Tens # src[test_stack_tensor.py:N]: ptr_tile = dev_ptrs[:] # src[test_stack_tensor.py:N]: tensors = hl.stacktensor_like(example_tensor, ptr_tile) # src[test_stack_tensor.py:N-N]: ... - _launcher(_helion_stack_store_arange_kernel, (15,), dev_ptrs, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_stack_store_arange_kernel, (15,), dev_ptrs, _RDIM_SIZE_1, num_warps=4, num_stages=1) diff --git a/test/test_tensor_descriptor.expected b/test/test_tensor_descriptor.expected index f279c6af3..33be7a54e 100644 --- a/test/test_tensor_descriptor.expected +++ b/test/test_tensor_descriptor.expected @@ -132,7 +132,7 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la # src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32) # src[attention.py:N]: l_i = torch.full_like(m_i, 1.0) # src[attention.py:N-N]: ... - _launcher(_helion_attention, (q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, q_in.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_0, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_attention, (q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, q_in.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_0, _BLOCK_SIZE_3, num_warps=4, num_stages=1) # src[attention.py:N]: return out.view(q_in.size()) return out.view(q_in.size()) @@ -268,6 +268,6 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la # src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32) # src[attention.py:N]: l_i = torch.full_like(m_i, 1.0) # src[attention.py:N-N]: ... - _launcher(_helion_attention, (64 * triton.cdiv(1024, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_0, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_attention, (64 * triton.cdiv(1024, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_0, _BLOCK_SIZE_3, num_warps=4, num_stages=1) # src[attention.py:N]: return out.view(q_in.size()) return out.view(q_in.size()) diff --git a/test/test_unroll_tuples.expected b/test/test_unroll_tuples.expected index 5327ba330..29761b6db 100644 --- a/test/test_unroll_tuples.expected +++ b/test/test_unroll_tuples.expected @@ -37,7 +37,7 @@ def kernel_tuple_addition(a_shared_tuple: tuple[torch.Tensor, ...], *, _launcher # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_n], dtype=torch.float32, device=out.device) # src[test_unroll_tuples.py:N]: for a_tensor in a_shared_tuple: # src[test_unroll_tuples.py:N-N]: ... - _launcher(_helion_kernel_tuple_addition, (triton.cdiv(32, _BLOCK_SIZE_0),), a_shared_tuple[0], a_shared_tuple[1], a_shared_tuple[2], out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_tuple_addition, (triton.cdiv(32, _BLOCK_SIZE_0),), a_shared_tuple[0], a_shared_tuple[1], a_shared_tuple[2], out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_unroll_tuples.py:N]: return out return out @@ -84,7 +84,7 @@ def kernel_constants_iteration(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) # src[test_unroll_tuples.py:N]: # Iterate over constants # src[test_unroll_tuples.py:N-N]: ... - _launcher(_helion_kernel_constants_iteration, (triton.cdiv(24, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_constants_iteration, (triton.cdiv(24, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_unroll_tuples.py:N]: return result return result @@ -137,7 +137,7 @@ def kernel_enumerate_constants(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) # src[test_unroll_tuples.py:N]: # Enumerate over constant values # src[test_unroll_tuples.py:N-N]: ... - _launcher(_helion_kernel_enumerate_constants, (triton.cdiv(20, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_enumerate_constants, (triton.cdiv(20, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_unroll_tuples.py:N]: return result return result @@ -184,7 +184,7 @@ def kernel_enumerate_iteration(tensors: tuple[torch.Tensor, torch.Tensor, torch. # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) # src[test_unroll_tuples.py:N]: # Iterate with enumerate to get index and tensor # src[test_unroll_tuples.py:N-N]: ... - _launcher(_helion_kernel_enumerate_iteration, (triton.cdiv(24, _BLOCK_SIZE_0),), tensors[0], tensors[1], tensors[2], result, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_enumerate_iteration, (triton.cdiv(24, _BLOCK_SIZE_0),), tensors[0], tensors[1], tensors[2], result, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_unroll_tuples.py:N]: return result return result @@ -227,7 +227,7 @@ def kernel_enumerate_with_start(tensors: tuple[torch.Tensor, torch.Tensor], *, _ # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) # src[test_unroll_tuples.py:N]: # Enumerate starting from 5 # src[test_unroll_tuples.py:N-N]: ... - _launcher(_helion_kernel_enumerate_with_start, (triton.cdiv(18, _BLOCK_SIZE_0),), tensors[0], tensors[1], result, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_enumerate_with_start, (triton.cdiv(18, _BLOCK_SIZE_0),), tensors[0], tensors[1], result, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_unroll_tuples.py:N]: return result return result @@ -335,7 +335,7 @@ def kernel_list_comprehension_host_and_device(x: torch.Tensor, *, _launcher=_def # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)): # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) # src[test_unroll_tuples.py:N-N]: ... - _launcher(_helion_kernel_list_comprehension_host_and_device, (triton.cdiv(26, _BLOCK_SIZE_0),), x, result, host_multipliers[0], host_multipliers[1], host_multipliers[2], host_multipliers[3], _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_list_comprehension_host_and_device, (triton.cdiv(26, _BLOCK_SIZE_0),), x, result, host_multipliers[0], host_multipliers[1], host_multipliers[2], host_multipliers[3], _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_unroll_tuples.py:N]: return result return result @@ -384,7 +384,7 @@ def kernel_list_comprehension_with_function(x: torch.Tensor, *, _launcher=_defau # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) # src[test_unroll_tuples.py:N]: for value in squared_values: # src[test_unroll_tuples.py:N-N]: ... - _launcher(_helion_kernel_list_comprehension_with_function, (triton.cdiv(14, _BLOCK_SIZE_0),), x, result, squared_values[0], squared_values[1], squared_values[2], _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_list_comprehension_with_function, (triton.cdiv(14, _BLOCK_SIZE_0),), x, result, squared_values[0], squared_values[1], squared_values[2], _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_unroll_tuples.py:N]: return result return result @@ -427,7 +427,7 @@ def kernel_list_comprehension_with_tensors(tensors: tuple[torch.Tensor, torch.Te # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) # src[test_unroll_tuples.py:N]: for tensor in tensor_list: # src[test_unroll_tuples.py:N-N]: ... - _launcher(_helion_kernel_list_comprehension_with_tensors, (triton.cdiv(18, _BLOCK_SIZE_0),), tensors[0], tensors[1], tensors[2], result, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_list_comprehension_with_tensors, (triton.cdiv(18, _BLOCK_SIZE_0),), tensors[0], tensors[1], tensors[2], result, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_unroll_tuples.py:N]: return result return result @@ -472,7 +472,7 @@ def kernel_list_comprehension_with_tuple_unrolling(tensors: tuple[torch.Tensor, # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) # src[test_unroll_tuples.py:N]: # Iterate over the scaled tensors (both list comp and tuple unrolling) # src[test_unroll_tuples.py:N-N]: ... - _launcher(_helion_kernel_list_comprehension_with_tuple_unrolling, (triton.cdiv(22, _BLOCK_SIZE_0),), scaled_tensors[0], scaled_tensors[1], scaled_tensors[2], result, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_list_comprehension_with_tuple_unrolling, (triton.cdiv(22, _BLOCK_SIZE_0),), scaled_tensors[0], scaled_tensors[1], scaled_tensors[2], result, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_unroll_tuples.py:N]: return result return result @@ -519,7 +519,7 @@ def kernel_list_constants_iteration(x: torch.Tensor, *, _launcher=_default_launc # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) # src[test_unroll_tuples.py:N]: # Iterate over constants in a list # src[test_unroll_tuples.py:N-N]: ... - _launcher(_helion_kernel_list_constants_iteration, (triton.cdiv(20, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_list_constants_iteration, (triton.cdiv(20, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_unroll_tuples.py:N]: return result return result @@ -562,7 +562,7 @@ def kernel_mixed_constants_and_tensors(tensors: tuple[torch.Tensor, torch.Tensor # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)): # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) # src[test_unroll_tuples.py:N-N]: ... - _launcher(_helion_kernel_mixed_constants_and_tensors, (triton.cdiv(22, _BLOCK_SIZE_0),), tensors[0], tensors[1], result, constants[0], constants[1], _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_mixed_constants_and_tensors, (triton.cdiv(22, _BLOCK_SIZE_0),), tensors[0], tensors[1], result, constants[0], constants[1], _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_unroll_tuples.py:N]: return result return result @@ -617,7 +617,7 @@ def kernel_nested_list_comprehension(x: torch.Tensor, *, _launcher=_default_laun # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) # src[test_unroll_tuples.py:N]: for i, j in pairs: # src[test_unroll_tuples.py:N-N]: ... - _launcher(_helion_kernel_nested_list_comprehension, (triton.cdiv(12, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_nested_list_comprehension, (triton.cdiv(12, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_unroll_tuples.py:N]: return result return result @@ -660,7 +660,7 @@ def kernel_nested_tuple_iteration(a_tuple: tuple[torch.Tensor, torch.Tensor], b_ # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)): # src[test_unroll_tuples.py:N]: temp = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) # src[test_unroll_tuples.py:N-N]: ... - _launcher(_helion_kernel_nested_tuple_iteration, (triton.cdiv(40, _BLOCK_SIZE_0),), a_tuple[0], a_tuple[1], b_tuple[0], b_tuple[1], result, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_nested_tuple_iteration, (triton.cdiv(40, _BLOCK_SIZE_0),), a_tuple[0], a_tuple[1], b_tuple[0], b_tuple[1], result, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_unroll_tuples.py:N]: return result return result @@ -708,7 +708,7 @@ def kernel_simple_list_comprehension(x: torch.Tensor, *, _launcher=_default_laun # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) # src[test_unroll_tuples.py:N]: for multiplier in multipliers: # src[test_unroll_tuples.py:N-N]: ... - _launcher(_helion_kernel_simple_list_comprehension, (triton.cdiv(16, _BLOCK_SIZE_0),), x, result, multipliers[0], multipliers[1], multipliers[2], _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_simple_list_comprehension, (triton.cdiv(16, _BLOCK_SIZE_0),), x, result, multipliers[0], multipliers[1], multipliers[2], _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_unroll_tuples.py:N]: return result return result @@ -744,7 +744,7 @@ def kernel_tuple_addition(a_shared_tuple: tuple[torch.Tensor, ...], *, _launcher # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_n], dtype=torch.float32, device=out.device) # src[test_unroll_tuples.py:N]: for a_tensor in a_shared_tuple: # src[test_unroll_tuples.py:N-N]: ... - _launcher(_helion_kernel_tuple_addition, (triton.cdiv(16, _BLOCK_SIZE_0),), a_shared_tuple[0], out, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_tuple_addition, (triton.cdiv(16, _BLOCK_SIZE_0),), a_shared_tuple[0], out, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_unroll_tuples.py:N]: return out return out @@ -795,7 +795,7 @@ def kernel_static_range_iteration(x: torch.Tensor, *, _launcher=_default_launche # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) # src[test_unroll_tuples.py:N]: # Use static_range for unrolled loop # src[test_unroll_tuples.py:N-N]: ... - _launcher(_helion_kernel_static_range_iteration, (triton.cdiv(28, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_static_range_iteration, (triton.cdiv(28, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_unroll_tuples.py:N]: return result return result @@ -842,7 +842,7 @@ def kernel_static_range_with_start(x: torch.Tensor, *, _launcher=_default_launch # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) # src[test_unroll_tuples.py:N]: # Use static_range(start, end) # src[test_unroll_tuples.py:N-N]: ... - _launcher(_helion_kernel_static_range_with_start, (triton.cdiv(18, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_static_range_with_start, (triton.cdiv(18, _BLOCK_SIZE_0),), x, result, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_unroll_tuples.py:N]: return result return result @@ -886,7 +886,7 @@ def kernel_tuple_with_scaling(tensor1: torch.Tensor, tensor2: torch.Tensor, tens # src[test_unroll_tuples.py:N]: temp = torch.zeros([tile_idx], dtype=torch.float32, device=output.device) # src[test_unroll_tuples.py:N]: for tensor, scale in zip(tensors, scales, strict=True): # src[test_unroll_tuples.py:N-N]: ... - _launcher(_helion_kernel_tuple_with_scaling, (triton.cdiv(48, _BLOCK_SIZE_0),), tensor1, tensor2, tensor3, output, scale1, scale2, scale3, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_tuple_with_scaling, (triton.cdiv(48, _BLOCK_SIZE_0),), tensor1, tensor2, tensor3, output, scale1, scale2, scale3, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_unroll_tuples.py:N]: return output return output @@ -929,6 +929,6 @@ def kernel_zip_iteration(tensors_a: tuple[torch.Tensor, torch.Tensor], tensors_b # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) # src[test_unroll_tuples.py:N]: # Iterate over zip of tensors # src[test_unroll_tuples.py:N-N]: ... - _launcher(_helion_kernel_zip_iteration, (triton.cdiv(36, _BLOCK_SIZE_0),), tensors_a[0], tensors_b[0], tensors_a[1], tensors_b[1], result, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_kernel_zip_iteration, (triton.cdiv(36, _BLOCK_SIZE_0),), tensors_a[0], tensors_b[0], tensors_a[1], tensors_b[1], result, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_unroll_tuples.py:N]: return result return result diff --git a/test/test_views.expected b/test/test_views.expected index 7e10d6fd4..e9c86e941 100644 --- a/test/test_views.expected +++ b/test/test_views.expected @@ -49,7 +49,7 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_views.py:N]: for tile1, tile2 in hl.tile([x.size(1), x.size(2)]): # src[test_views.py:N-N]: ... _RDIM_SIZE_3 = triton.next_power_of_2(_BLOCK_SIZE_1 * _BLOCK_SIZE_2) - _launcher(_helion_fn, (triton.cdiv(3, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(3, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_views.py:N]: return out return out @@ -96,7 +96,7 @@ def softmax(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_views.py:N]: values = x[tile_n, :] # src[test_views.py:N]: amax = torch.amax(values, dim=1).unsqueeze(1) # src[test_views.py:N-N]: ... - _launcher(_helion_softmax, (1024,), x, out, _RDIM_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_softmax, (1024,), x, out, _RDIM_SIZE_1, num_warps=4, num_stages=1) # src[test_views.py:N]: return out return out @@ -145,7 +145,7 @@ def softmax(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_views.py:N]: values = x[tile_n, :] # src[test_views.py:N]: amax = torch.amax(values, dim=1).view(tile_n, 1) # src[test_views.py:N-N]: ... - _launcher(_helion_softmax, (1024,), x, out, _RDIM_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=2) + _launcher(_helion_softmax, (1024,), x, out, _RDIM_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=1) # src[test_views.py:N]: return out return out @@ -200,7 +200,7 @@ def fn(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): # src[test_views.py:N]: out[tile_n, tile_m] = x[tile_n, tile_m] + y[tile_m, :].squeeze( # src[test_views.py:N]: 1 # src[test_views.py:N-N]: ... - _launcher(_helion_fn, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_fn, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_views.py:N]: return out return out @@ -267,7 +267,7 @@ def test_stack_dim0_kernel(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, *, # src[test_views.py:N]: for tile_n in hl.tile(N): # src[test_views.py:N]: a_tile = a[tile_m, tile_n] # src[test_views.py:N-N]: ... - _launcher(_helion_test_stack_dim0_kernel, (triton.cdiv(65, _BLOCK_SIZE_0),), a, b, c, result, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_stack_dim0_kernel, (triton.cdiv(65, _BLOCK_SIZE_0),), a, b, c, result, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_views.py:N]: return result return result @@ -334,6 +334,6 @@ def test_stack_non_power_of_2_kernel(a: torch.Tensor, b: torch.Tensor, c: torch. # src[test_views.py:N]: for tile_n in hl.tile(N): # src[test_views.py:N]: a_tile = a[tile_m, tile_n] # src[test_views.py:N-N]: ... - _launcher(_helion_test_stack_non_power_of_2_kernel, (triton.cdiv(65, _BLOCK_SIZE_0),), a, b, c, result, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=2) + _launcher(_helion_test_stack_non_power_of_2_kernel, (triton.cdiv(65, _BLOCK_SIZE_0),), a, b, c, result, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=1) # src[test_views.py:N]: return result return result