diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py index d90033229..2ab50d429 100644 --- a/helion/_compiler/device_ir.py +++ b/helion/_compiler/device_ir.py @@ -1036,15 +1036,15 @@ def visit_Tuple(self, node: ast.Tuple) -> tuple[object, ...]: def visit_List(self, node: ast.List) -> list[object]: return [self.visit(x) for x in node.elts] - def visit_ListComp(self, node: ast.ListComp) -> tuple[object, ...]: - """Handle list comprehension unrolling similar to tuple unrolling.""" + def _visit_comprehension( + self, node: ast.ListComp | ast.GeneratorExp, name: str + ) -> tuple[object, ...]: + """Handle list comprehension or generator expression unrolling.""" assert isinstance(node, ExtendedAST) # Only handle simple cases with single generator and no if conditions if len(node.generators) != 1 or node.generators[0].ifs: - raise exc.StatementNotSupported( - "Complex list comprehensions are not supported" - ) + raise exc.StatementNotSupported(f"Complex {name}s are not supported") generator = node.generators[0] assert isinstance(generator.iter, ExtendedAST) @@ -1052,20 +1052,27 @@ def visit_ListComp(self, node: ast.ListComp) -> tuple[object, ...]: # Check if we're iterating over a sequence (similar to tuple unrolling) if isinstance(iter_type, SequenceType): - return self._handle_listcomp_unrolling(node) + return self._handle_comprehension_unrolling(node.elt, generator) # For non-sequence iterables, we could extend this later raise exc.StatementNotSupported( - "List comprehensions over non-sequence types are not supported" + f"{name.capitalize()}s over non-sequence types are not supported" ) - def _handle_listcomp_unrolling(self, node: ast.ListComp) -> tuple[object, ...]: - """Handle unrolling of list comprehensions over sequences.""" - generator = node.generators[0] + def visit_ListComp(self, node: ast.ListComp) -> tuple[object, ...]: + return self._visit_comprehension(node, "list comprehension") + + def visit_GeneratorExp(self, node: ast.GeneratorExp) -> tuple[object, ...]: + return self._visit_comprehension(node, "generator expression") + + def _handle_comprehension_unrolling( + self, elt: ast.expr, generator: ast.comprehension + ) -> tuple[object, ...]: + """Handle unrolling of comprehensions (list comp or generator exp) over sequences.""" def evaluate_expression() -> object: # Evaluate the comprehension expression - result = self.visit(node.elt) + result = self.visit(elt) # If the result is a SymInt that can be evaluated to a concrete value, do so if isinstance(result, torch.SymInt): try: diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py index dfb112ae8..adf56182a 100644 --- a/helion/_compiler/type_propagation.py +++ b/helion/_compiler/type_propagation.py @@ -2429,21 +2429,26 @@ def _evaluate_comprehension( # Fallback to generic list type return SequenceType(self.origin(), [element_result_type]) - def visit_ListComp(self, node: ast.ListComp) -> TypeInfo: - """Type propagation for list comprehensions.""" + def _visit_comprehension( + self, node: ast.ListComp | ast.GeneratorExp, name: str + ) -> TypeInfo: + """Type propagation for list comprehensions and generator expressions.""" if len(node.generators) != 1: raise exc.StatementNotSupported( - "List comprehensions with multiple generators are not supported" + f"{name.capitalize()}s with multiple generators are not supported" ) - return self._evaluate_comprehension(node.generators[0], node.elt) + def visit_ListComp(self, node: ast.ListComp) -> TypeInfo: + return self._visit_comprehension(node, "list comprehension") + + def visit_GeneratorExp(self, node: ast.GeneratorExp) -> TypeInfo: + return self._visit_comprehension(node, "generator expression") + # TODO(jansel): need to implement these # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_SetComp: _VisitMethod = _not_supported # pyrefly: ignore [bad-assignment, bad-param-name-override] - visit_GeneratorExp: _VisitMethod = _not_supported - # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_DictComp: _VisitMethod = _not_supported # TODO(jansel): support closure functions defined on host diff --git a/test/test_unroll_tuples.expected b/test/test_unroll_tuples.expected index 1303d8a57..21780bc29 100644 --- a/test/test_unroll_tuples.expected +++ b/test/test_unroll_tuples.expected @@ -890,6 +890,152 @@ def kernel_static_range_with_start(x: torch.Tensor, *, _launcher=_default_launch # src[test_unroll_tuples.py:N]: return result return result +--- assertExpectedJournal(TestUnrollTuples.test_tuple_comprehension) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_kernel_tuple_comprehension(x, result, multipliers_item_0, multipliers_item_1, multipliers_item_2, _BLOCK_SIZE_0: tl.constexpr): + # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) + acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32) + # src[test_unroll_tuples.py:N]: acc += x[tile_idx] * multiplier + load = tl.load(x + indices_0 * 1, None) + v_0 = tl.cast(multipliers_item_0, tl.float32) + v_1 = load * v_0 + v_2 = acc + v_1 + load_1 = tl.load(x + indices_0 * 1, None) + v_3 = tl.cast(multipliers_item_1, tl.float32) + v_4 = load_1 * v_3 + v_5 = v_2 + v_4 + load_2 = tl.load(x + indices_0 * 1, None) + v_6 = tl.cast(multipliers_item_2, tl.float32) + v_7 = load_2 * v_6 + v_8 = v_5 + v_7 + # src[test_unroll_tuples.py:N]: result[tile_idx] = acc + tl.store(result + indices_0 * 1, v_8, None) + +def kernel_tuple_comprehension(x: torch.Tensor, *, _launcher=_default_launcher): + """Test tuple comprehension with generator expression.""" + # src[test_unroll_tuples.py:N]: result = torch.zeros_like(x) + result = torch.zeros_like(x) + # src[test_unroll_tuples.py:N]: multipliers = tuple(m * 2 for m in (1, 2, 3)) + multipliers = tuple((m * 2 for m in (1, 2, 3))) + # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)): + _BLOCK_SIZE_0 = 16 + # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)): + # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) + # src[test_unroll_tuples.py:N]: for multiplier in multipliers: + # src[test_unroll_tuples.py:N-N]: ... + _launcher(_helion_kernel_tuple_comprehension, (triton.cdiv(16, _BLOCK_SIZE_0),), x, result, multipliers[0], multipliers[1], multipliers[2], _BLOCK_SIZE_0, num_warps=4, num_stages=1) + # src[test_unroll_tuples.py:N]: return result + return result + +--- assertExpectedJournal(TestUnrollTuples.test_tuple_comprehension_with_static_range) +from __future__ import annotations + +import torch +import helion.language as hl +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_kernel_tuple_comprehension_with_static_range(x, result, multipliers_item_0, multipliers_item_1, multipliers_item_2, multipliers_item_3, _BLOCK_SIZE_0: tl.constexpr): + # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) + acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32) + # src[test_unroll_tuples.py:N]: acc += x[tile_idx] * multipliers[i] + load = tl.load(x + indices_0 * 1, None) + v_0 = tl.cast(multipliers_item_0, tl.float32) + v_1 = load * v_0 + v_2 = acc + v_1 + load_1 = tl.load(x + indices_0 * 1, None) + v_3 = tl.cast(multipliers_item_1, tl.float32) + v_4 = load_1 * v_3 + v_5 = v_2 + v_4 + load_2 = tl.load(x + indices_0 * 1, None) + v_6 = tl.cast(multipliers_item_2, tl.float32) + v_7 = load_2 * v_6 + v_8 = v_5 + v_7 + load_3 = tl.load(x + indices_0 * 1, None) + v_9 = tl.cast(multipliers_item_3, tl.float32) + v_10 = load_3 * v_9 + v_11 = v_8 + v_10 + # src[test_unroll_tuples.py:N]: result[tile_idx] = acc + tl.store(result + indices_0 * 1, v_11, None) + +def kernel_tuple_comprehension_with_static_range(x: torch.Tensor, N: hl.constexpr, *, _launcher=_default_launcher): + """Test tuple comprehension with static_range for indexing.""" + # src[test_unroll_tuples.py:N]: result = torch.zeros_like(x) + result = torch.zeros_like(x) + # src[test_unroll_tuples.py:N]: multipliers = tuple(i + 1 for i in range(N)) + multipliers = tuple((i + 1 for i in range(4))) + # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)): + _BLOCK_SIZE_0 = 16 + # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)): + # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) + # src[test_unroll_tuples.py:N]: for i in hl.static_range(N): + # src[test_unroll_tuples.py:N-N]: ... + _launcher(_helion_kernel_tuple_comprehension_with_static_range, (triton.cdiv(16, _BLOCK_SIZE_0),), x, result, multipliers[0], multipliers[1], multipliers[2], multipliers[3], _BLOCK_SIZE_0, num_warps=4, num_stages=1) + # src[test_unroll_tuples.py:N]: return result + return result + +--- assertExpectedJournal(TestUnrollTuples.test_tuple_comprehension_with_tensors) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_kernel_tuple_comprehension_with_tensors(scaled_item_0, scaled_item_1, scaled_item_2, result, _BLOCK_SIZE_0: tl.constexpr): + # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < 18 + # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) + acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32) + # src[test_unroll_tuples.py:N]: acc += tensor[tile_idx] + load = tl.load(scaled_item_0 + indices_0 * 1, mask_0, other=0) + v_0 = acc + load + load_1 = tl.load(scaled_item_1 + indices_0 * 1, mask_0, other=0) + v_1 = v_0 + load_1 + load_2 = tl.load(scaled_item_2 + indices_0 * 1, mask_0, other=0) + v_2 = v_1 + load_2 + # src[test_unroll_tuples.py:N]: result[tile_idx] = acc + tl.store(result + indices_0 * 1, v_2, mask_0) + +def kernel_tuple_comprehension_with_tensors(tensors: tuple[torch.Tensor, torch.Tensor, torch.Tensor], *, _launcher=_default_launcher): + """Test tuple comprehension that transforms a tuple of tensors.""" + # src[test_unroll_tuples.py:N]: result = torch.zeros_like(tensors[0]) + result = torch.zeros_like(tensors[0]) + # src[test_unroll_tuples.py:N]: scales = (0.5, 1.0, 1.5) + scales = (0.5, 1.0, 1.5) + # src[test_unroll_tuples.py:N]: scaled = tuple(t * s for t, s in zip(tensors, scales, strict=False)) + scaled = tuple((t * s for t, s in zip(tensors, scales, strict=False))) + # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)): + _BLOCK_SIZE_0 = 32 + # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)): + # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) + # src[test_unroll_tuples.py:N]: for tensor in scaled: + # src[test_unroll_tuples.py:N-N]: ... + _launcher(_helion_kernel_tuple_comprehension_with_tensors, (triton.cdiv(18, _BLOCK_SIZE_0),), scaled[0], scaled[1], scaled[2], result, _BLOCK_SIZE_0, num_warps=4, num_stages=1) + # src[test_unroll_tuples.py:N]: return result + return result + --- assertExpectedJournal(TestUnrollTuples.test_tuple_with_scaling_factors) from __future__ import annotations diff --git a/test/test_unroll_tuples.py b/test/test_unroll_tuples.py index 9aef11fbf..d2a064e6e 100644 --- a/test/test_unroll_tuples.py +++ b/test/test_unroll_tuples.py @@ -227,6 +227,57 @@ def kernel_simple_list_comprehension( return result +@helion.kernel(autotune_effort="none") +def kernel_tuple_comprehension( + x: torch.Tensor, +) -> torch.Tensor: + """Test tuple comprehension with generator expression.""" + result = torch.zeros_like(x) + # Create tuple using generator expression + multipliers = tuple(m * 2 for m in (1, 2, 3)) + for tile_idx in hl.tile(result.size(0)): + acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) + for multiplier in multipliers: + acc += x[tile_idx] * multiplier + result[tile_idx] = acc + return result + + +@helion.kernel(autotune_effort="none") +def kernel_tuple_comprehension_with_static_range( + x: torch.Tensor, + N: hl.constexpr, +) -> torch.Tensor: + """Test tuple comprehension with static_range for indexing.""" + result = torch.zeros_like(x) + # Create tuple using generator expression with range + multipliers = tuple(i + 1 for i in range(N)) + for tile_idx in hl.tile(result.size(0)): + acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) + for i in hl.static_range(N): + acc += x[tile_idx] * multipliers[i] + result[tile_idx] = acc + return result + + +@helion.kernel(autotune_effort="none") +def kernel_tuple_comprehension_with_tensors( + tensors: tuple[torch.Tensor, torch.Tensor, torch.Tensor], +) -> torch.Tensor: + """Test tuple comprehension that transforms a tuple of tensors.""" + result = torch.zeros_like(tensors[0]) + # Create scaled versions using generator expression + scales = (0.5, 1.0, 1.5) + scaled = tuple(t * s for t, s in zip(tensors, scales, strict=False)) + + for tile_idx in hl.tile(result.size(0)): + acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) + for tensor in scaled: + acc += tensor[tile_idx] + result[tile_idx] = acc + return result + + @helion.kernel(autotune_effort="none") def kernel_list_comprehension_with_function( x: torch.Tensor, @@ -623,6 +674,57 @@ def test_simple_list_comprehension(self): expected = x * 12 torch.testing.assert_close(result, expected) + def test_tuple_comprehension(self): + """Test tuple comprehension with generator expression.""" + size = (16,) + x = torch.randn(size, device=DEVICE) + + code, result = code_and_output(kernel_tuple_comprehension, (x,)) + + # Validate generated code + self.assertExpectedJournal(code) + + # Test correctness - should be x * (2 + 4 + 6) = x * 12 + expected = x * 12 + torch.testing.assert_close(result, expected) + + def test_tuple_comprehension_with_static_range(self): + """Test tuple comprehension with static_range for indexing.""" + size = (16,) + x = torch.randn(size, device=DEVICE) + N = 4 + + code, result = code_and_output( + kernel_tuple_comprehension_with_static_range, (x, N) + ) + + # Validate generated code + self.assertExpectedJournal(code) + + # Test correctness - should be x * (1 + 2 + 3 + 4) = x * 10 + expected = x * 10 + torch.testing.assert_close(result, expected) + + def test_tuple_comprehension_with_tensors(self): + """Test tuple comprehension that transforms a tuple of tensors.""" + size = (18,) + tensor1 = torch.randn(size, device=DEVICE) + tensor2 = torch.randn(size, device=DEVICE) + tensor3 = torch.randn(size, device=DEVICE) + + tensors = (tensor1, tensor2, tensor3) + + code, result = code_and_output( + kernel_tuple_comprehension_with_tensors, (tensors,) + ) + + # Validate generated code + self.assertExpectedJournal(code) + + # Test correctness - should be tensor1*0.5 + tensor2*1.0 + tensor3*1.5 + expected = tensor1 * 0.5 + tensor2 * 1.0 + tensor3 * 1.5 + torch.testing.assert_close(result, expected) + def test_list_comprehension_with_function(self): """Test list comprehension with expressions.""" size = (14,)