Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions helion/_compiler/type_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2071,14 +2071,25 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> TypeInfo:

def visit_AugAssign(self, node: ast.AugAssign) -> TypeInfo:
    """Type-propagate an augmented assignment (e.g. ``x += y``).

    The statement is rewritten as the equivalent ``BinOp`` expression
    (``x + y`` with the target re-read in ``Load`` context) and visited,
    so the combined operation is traced exactly once; the result type is
    then assigned back to the target.

    Args:
        node: The ``ast.AugAssign`` node being visited.

    Returns:
        ``NoType`` — statements produce no value.

    Raises:
        exc.CannotModifyHostVariableOnDevice: when tracing fails because a
            host-origin variable is being mutated inside a device loop
            (``hl.tile``/``hl.grid``) without a subscript assignment.
    """
    assert isinstance(node.target, ExtendedAST)
    try:
        type_info = self.visit(
            create(
                ast.BinOp,
                left=node.target.copy(ctx=ast.Load()),
                op=node.op,
                right=node.value,
            )
        )
    except exc.TorchOpTracingError as e:
        # A tracing failure here is commonly a shape mismatch caused by
        # augmented-assigning a host variable inside a device loop; raise
        # a targeted, actionable error instead of the raw tracing failure.
        if (
            isinstance(node.target, ast.Name)
            and (existing_type := self.scope.maybe_get(node.target.id)) is not None
            and existing_type.origin.is_host()
            and self.device_loop_depth > 0
        ):
            raise exc.CannotModifyHostVariableOnDevice(node.target.id) from e
        raise
    self._assign(node.target, type_info)
    return NoType(origin=self.origin())

Expand Down
2 changes: 1 addition & 1 deletion helion/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ class NotAllowedInHelperFunction(BaseError):


class CannotModifyHostVariableOnDevice(BaseError):
    # Raised when a host-allocated variable is reassigned in place
    # (e.g. ``x += ...``) inside an ``hl.tile``/``hl.grid`` device loop.
    # The fix is to write through a subscript so the store stays on-device.
    message = "Cannot modify host variable '{0}' inside `hl.tile` or `hl.grid` loop without subscript assignment. Use '{0}[tile] = ...' or '{0}[:] = ...' instead."


class AtomicOnDeviceTensor(BaseError):
Expand Down
19 changes: 19 additions & 0 deletions test/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,25 @@ def bad_fn(x: torch.Tensor) -> torch.Tensor:
with self.assertRaises(helion.exc.CannotReadDeviceVariableOnHost):
code_and_output(bad_fn, (torch.randn(8, device=DEVICE),))

def test_augmented_assign_without_subscript(self):
    """Augmented assignment to a host tensor inside a device loop must
    surface ``CannotModifyHostVariableOnDevice`` rather than a raw
    tracing error."""

    @helion.kernel()
    def bad_fn(grad_out: torch.Tensor) -> torch.Tensor:
        num_rows, num_cols = grad_out.shape
        num_cols = hl.specialize(num_cols)
        grad_block = torch.zeros(
            num_cols, dtype=torch.float32, device=grad_out.device
        )

        for tile in hl.tile(num_rows):
            rows = grad_out[tile, :].to(torch.float32)
            # The correct form would be `grad_block[:] += ...`
            grad_block += torch.sum(rows, dim=0)

        return grad_block

    with self.assertRaises(helion.exc.CannotModifyHostVariableOnDevice):
        code_and_output(bad_fn, (torch.randn(4096, 5632, device=DEVICE),))

def test_device_tensor_subscript(self):
@helion.kernel()
def bad_fn(x: torch.Tensor) -> torch.Tensor:
Expand Down
Loading