Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions helion/_compiler/device_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,36 +1036,43 @@ def visit_Tuple(self, node: ast.Tuple) -> tuple[object, ...]:
def visit_List(self, node: ast.List) -> list[object]:
return [self.visit(x) for x in node.elts]

def visit_ListComp(self, node: ast.ListComp) -> tuple[object, ...]:
"""Handle list comprehension unrolling similar to tuple unrolling."""
def _visit_comprehension(
self, node: ast.ListComp | ast.GeneratorExp, name: str
) -> tuple[object, ...]:
"""Handle list comprehension or generator expression unrolling."""
assert isinstance(node, ExtendedAST)

# Only handle simple cases with single generator and no if conditions
if len(node.generators) != 1 or node.generators[0].ifs:
raise exc.StatementNotSupported(
"Complex list comprehensions are not supported"
)
raise exc.StatementNotSupported(f"Complex {name}s are not supported")

generator = node.generators[0]
assert isinstance(generator.iter, ExtendedAST)
iter_type = generator.iter._type_info

# Check if we're iterating over a sequence (similar to tuple unrolling)
if isinstance(iter_type, SequenceType):
return self._handle_listcomp_unrolling(node)
return self._handle_comprehension_unrolling(node.elt, generator)

# For non-sequence iterables, we could extend this later
raise exc.StatementNotSupported(
"List comprehensions over non-sequence types are not supported"
f"{name.capitalize()}s over non-sequence types are not supported"
)

def _handle_listcomp_unrolling(self, node: ast.ListComp) -> tuple[object, ...]:
"""Handle unrolling of list comprehensions over sequences."""
generator = node.generators[0]
def visit_ListComp(self, node: ast.ListComp) -> tuple[object, ...]:
return self._visit_comprehension(node, "list comprehension")

def visit_GeneratorExp(self, node: ast.GeneratorExp) -> tuple[object, ...]:
return self._visit_comprehension(node, "generator expression")

def _handle_comprehension_unrolling(
self, elt: ast.expr, generator: ast.comprehension
) -> tuple[object, ...]:
"""Handle unrolling of comprehensions (list comp or generator exp) over sequences."""

def evaluate_expression() -> object:
# Evaluate the comprehension expression
result = self.visit(node.elt)
result = self.visit(elt)
# If the result is a SymInt that can be evaluated to a concrete value, do so
if isinstance(result, torch.SymInt):
try:
Expand Down
17 changes: 11 additions & 6 deletions helion/_compiler/type_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2429,21 +2429,26 @@ def _evaluate_comprehension(
# Fallback to generic list type
return SequenceType(self.origin(), [element_result_type])

def visit_ListComp(self, node: ast.ListComp) -> TypeInfo:
"""Type propagation for list comprehensions."""
def _visit_comprehension(
self, node: ast.ListComp | ast.GeneratorExp, name: str
) -> TypeInfo:
"""Type propagation for list comprehensions and generator expressions."""
if len(node.generators) != 1:
raise exc.StatementNotSupported(
"List comprehensions with multiple generators are not supported"
f"{name.capitalize()}s with multiple generators are not supported"
)

return self._evaluate_comprehension(node.generators[0], node.elt)

def visit_ListComp(self, node: ast.ListComp) -> TypeInfo:
return self._visit_comprehension(node, "list comprehension")

def visit_GeneratorExp(self, node: ast.GeneratorExp) -> TypeInfo:
return self._visit_comprehension(node, "generator expression")

# TODO(jansel): need to implement these
# pyrefly: ignore [bad-assignment, bad-param-name-override]
visit_SetComp: _VisitMethod = _not_supported
# pyrefly: ignore [bad-assignment, bad-param-name-override]
visit_GeneratorExp: _VisitMethod = _not_supported
# pyrefly: ignore [bad-assignment, bad-param-name-override]
visit_DictComp: _VisitMethod = _not_supported

# TODO(jansel): support closure functions defined on host
Expand Down
146 changes: 146 additions & 0 deletions test/test_unroll_tuples.expected
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,152 @@ def kernel_static_range_with_start(x: torch.Tensor, *, _launcher=_default_launch
# src[test_unroll_tuples.py:N]: return result
return result

--- assertExpectedJournal(TestUnrollTuples.test_tuple_comprehension)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_kernel_tuple_comprehension(x, result, multipliers_item_0, multipliers_item_1, multipliers_item_2, _BLOCK_SIZE_0: tl.constexpr):
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
# src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
# src[test_unroll_tuples.py:N]: acc += x[tile_idx] * multiplier
load = tl.load(x + indices_0 * 1, None)
v_0 = tl.cast(multipliers_item_0, tl.float32)
v_1 = load * v_0
v_2 = acc + v_1
load_1 = tl.load(x + indices_0 * 1, None)
v_3 = tl.cast(multipliers_item_1, tl.float32)
v_4 = load_1 * v_3
v_5 = v_2 + v_4
load_2 = tl.load(x + indices_0 * 1, None)
v_6 = tl.cast(multipliers_item_2, tl.float32)
v_7 = load_2 * v_6
v_8 = v_5 + v_7
# src[test_unroll_tuples.py:N]: result[tile_idx] = acc
tl.store(result + indices_0 * 1, v_8, None)

def kernel_tuple_comprehension(x: torch.Tensor, *, _launcher=_default_launcher):
"""Test tuple comprehension with generator expression."""
# src[test_unroll_tuples.py:N]: result = torch.zeros_like(x)
result = torch.zeros_like(x)
# src[test_unroll_tuples.py:N]: multipliers = tuple(m * 2 for m in (1, 2, 3))
multipliers = tuple((m * 2 for m in (1, 2, 3)))
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
_BLOCK_SIZE_0 = 16
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
# src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
# src[test_unroll_tuples.py:N]: for multiplier in multipliers:
# src[test_unroll_tuples.py:N-N]: ...
_launcher(_helion_kernel_tuple_comprehension, (triton.cdiv(16, _BLOCK_SIZE_0),), x, result, multipliers[0], multipliers[1], multipliers[2], _BLOCK_SIZE_0, num_warps=4, num_stages=1)
# src[test_unroll_tuples.py:N]: return result
return result

--- assertExpectedJournal(TestUnrollTuples.test_tuple_comprehension_with_static_range)
from __future__ import annotations

import torch
import helion.language as hl
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_kernel_tuple_comprehension_with_static_range(x, result, multipliers_item_0, multipliers_item_1, multipliers_item_2, multipliers_item_3, _BLOCK_SIZE_0: tl.constexpr):
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
# src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
# src[test_unroll_tuples.py:N]: acc += x[tile_idx] * multipliers[i]
load = tl.load(x + indices_0 * 1, None)
v_0 = tl.cast(multipliers_item_0, tl.float32)
v_1 = load * v_0
v_2 = acc + v_1
load_1 = tl.load(x + indices_0 * 1, None)
v_3 = tl.cast(multipliers_item_1, tl.float32)
v_4 = load_1 * v_3
v_5 = v_2 + v_4
load_2 = tl.load(x + indices_0 * 1, None)
v_6 = tl.cast(multipliers_item_2, tl.float32)
v_7 = load_2 * v_6
v_8 = v_5 + v_7
load_3 = tl.load(x + indices_0 * 1, None)
v_9 = tl.cast(multipliers_item_3, tl.float32)
v_10 = load_3 * v_9
v_11 = v_8 + v_10
# src[test_unroll_tuples.py:N]: result[tile_idx] = acc
tl.store(result + indices_0 * 1, v_11, None)

def kernel_tuple_comprehension_with_static_range(x: torch.Tensor, N: hl.constexpr, *, _launcher=_default_launcher):
"""Test tuple comprehension with static_range for indexing."""
# src[test_unroll_tuples.py:N]: result = torch.zeros_like(x)
result = torch.zeros_like(x)
# src[test_unroll_tuples.py:N]: multipliers = tuple(i + 1 for i in range(N))
multipliers = tuple((i + 1 for i in range(4)))
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
_BLOCK_SIZE_0 = 16
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
# src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
# src[test_unroll_tuples.py:N]: for i in hl.static_range(N):
# src[test_unroll_tuples.py:N-N]: ...
_launcher(_helion_kernel_tuple_comprehension_with_static_range, (triton.cdiv(16, _BLOCK_SIZE_0),), x, result, multipliers[0], multipliers[1], multipliers[2], multipliers[3], _BLOCK_SIZE_0, num_warps=4, num_stages=1)
# src[test_unroll_tuples.py:N]: return result
return result

--- assertExpectedJournal(TestUnrollTuples.test_tuple_comprehension_with_tensors)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_kernel_tuple_comprehension_with_tensors(scaled_item_0, scaled_item_1, scaled_item_2, result, _BLOCK_SIZE_0: tl.constexpr):
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < 18
# src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
# src[test_unroll_tuples.py:N]: acc += tensor[tile_idx]
load = tl.load(scaled_item_0 + indices_0 * 1, mask_0, other=0)
v_0 = acc + load
load_1 = tl.load(scaled_item_1 + indices_0 * 1, mask_0, other=0)
v_1 = v_0 + load_1
load_2 = tl.load(scaled_item_2 + indices_0 * 1, mask_0, other=0)
v_2 = v_1 + load_2
# src[test_unroll_tuples.py:N]: result[tile_idx] = acc
tl.store(result + indices_0 * 1, v_2, mask_0)

def kernel_tuple_comprehension_with_tensors(tensors: tuple[torch.Tensor, torch.Tensor, torch.Tensor], *, _launcher=_default_launcher):
"""Test tuple comprehension that transforms a tuple of tensors."""
# src[test_unroll_tuples.py:N]: result = torch.zeros_like(tensors[0])
result = torch.zeros_like(tensors[0])
# src[test_unroll_tuples.py:N]: scales = (0.5, 1.0, 1.5)
scales = (0.5, 1.0, 1.5)
# src[test_unroll_tuples.py:N]: scaled = tuple(t * s for t, s in zip(tensors, scales, strict=False))
scaled = tuple((t * s for t, s in zip(tensors, scales, strict=False)))
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
_BLOCK_SIZE_0 = 32
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
# src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
# src[test_unroll_tuples.py:N]: for tensor in scaled:
# src[test_unroll_tuples.py:N-N]: ...
_launcher(_helion_kernel_tuple_comprehension_with_tensors, (triton.cdiv(18, _BLOCK_SIZE_0),), scaled[0], scaled[1], scaled[2], result, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
# src[test_unroll_tuples.py:N]: return result
return result

--- assertExpectedJournal(TestUnrollTuples.test_tuple_with_scaling_factors)
from __future__ import annotations

Expand Down
102 changes: 102 additions & 0 deletions test/test_unroll_tuples.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,57 @@ def kernel_simple_list_comprehension(
return result


@helion.kernel(autotune_effort="none")
def kernel_tuple_comprehension(
    x: torch.Tensor,
) -> torch.Tensor:
    """Test tuple comprehension with generator expression.

    The ``tuple(m * 2 for m in (1, 2, 3))`` generator expression is
    evaluated on the host into a constant tuple, and the device loop over
    it is fully unrolled by the compiler (the generated kernel receives
    each multiplier as a separate scalar argument).
    """
    result = torch.zeros_like(x)
    # Create tuple using generator expression -> (2, 4, 6)
    multipliers = tuple(m * 2 for m in (1, 2, 3))
    for tile_idx in hl.tile(result.size(0)):
        # Fresh per-tile accumulator.
        acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
        # Unrolled: one multiply-accumulate per compile-time-known multiplier,
        # so result == x * (2 + 4 + 6) == x * 12.
        for multiplier in multipliers:
            acc += x[tile_idx] * multiplier
        result[tile_idx] = acc
    return result


@helion.kernel(autotune_effort="none")
def kernel_tuple_comprehension_with_static_range(
    x: torch.Tensor,
    N: hl.constexpr,
) -> torch.Tensor:
    """Test tuple comprehension with static_range for indexing.

    ``N`` is a compile-time constant, so ``tuple(i + 1 for i in range(N))``
    is materialized on the host and the ``hl.static_range(N)`` loop is
    unrolled, indexing a different tuple element per unrolled iteration.
    """
    result = torch.zeros_like(x)
    # Create tuple using generator expression with range -> (1, 2, ..., N)
    multipliers = tuple(i + 1 for i in range(N))
    for tile_idx in hl.tile(result.size(0)):
        # Fresh per-tile accumulator.
        acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
        # Unrolled: result == x * sum(1..N).
        for i in hl.static_range(N):
            acc += x[tile_idx] * multipliers[i]
        result[tile_idx] = acc
    return result


@helion.kernel(autotune_effort="none")
def kernel_tuple_comprehension_with_tensors(
    tensors: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
) -> torch.Tensor:
    """Test tuple comprehension that transforms a tuple of tensors.

    The generator expression over ``zip(tensors, scales)`` runs on the
    host, producing a tuple of pre-scaled tensors; the device loop over
    that tuple is unrolled, each scaled tensor becoming its own kernel
    argument.
    """
    result = torch.zeros_like(tensors[0])
    # Create scaled versions using generator expression.
    # strict=False is safe here: both tuples have exactly 3 elements.
    scales = (0.5, 1.0, 1.5)
    scaled = tuple(t * s for t, s in zip(tensors, scales, strict=False))

    for tile_idx in hl.tile(result.size(0)):
        # Fresh per-tile accumulator.
        acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
        # Unrolled: result == 0.5*t0 + 1.0*t1 + 1.5*t2.
        for tensor in scaled:
            acc += tensor[tile_idx]
        result[tile_idx] = acc
    return result


@helion.kernel(autotune_effort="none")
def kernel_list_comprehension_with_function(
x: torch.Tensor,
Expand Down Expand Up @@ -623,6 +674,57 @@ def test_simple_list_comprehension(self):
expected = x * 12
torch.testing.assert_close(result, expected)

def test_tuple_comprehension(self):
    """Test tuple comprehension with generator expression."""
    x = torch.randn((16,), device=DEVICE)

    code, result = code_and_output(kernel_tuple_comprehension, (x,))

    # Generated code must match the recorded journal.
    self.assertExpectedJournal(code)

    # Multipliers are (2, 4, 6), so the kernel computes x * 12.
    torch.testing.assert_close(result, x * 12)

def test_tuple_comprehension_with_static_range(self):
    """Test tuple comprehension with static_range for indexing."""
    n = 4
    x = torch.randn((16,), device=DEVICE)

    code, result = code_and_output(
        kernel_tuple_comprehension_with_static_range, (x, n)
    )

    # Generated code must match the recorded journal.
    self.assertExpectedJournal(code)

    # Multipliers are 1..4, so the kernel computes x * (1+2+3+4) = x * 10.
    torch.testing.assert_close(result, x * 10)

def test_tuple_comprehension_with_tensors(self):
    """Test tuple comprehension that transforms a tuple of tensors."""
    shape = (18,)
    t1, t2, t3 = (torch.randn(shape, device=DEVICE) for _ in range(3))

    code, result = code_and_output(
        kernel_tuple_comprehension_with_tensors, ((t1, t2, t3),)
    )

    # Generated code must match the recorded journal.
    self.assertExpectedJournal(code)

    # Scales are (0.5, 1.0, 1.5), applied elementwise then summed.
    torch.testing.assert_close(result, t1 * 0.5 + t2 * 1.0 + t3 * 1.5)

def test_list_comprehension_with_function(self):
"""Test list comprehension with expressions."""
size = (14,)
Expand Down
Loading