diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py
index 6e0ad4dd4..edf2f4faf 100644
--- a/helion/_compiler/type_propagation.py
+++ b/helion/_compiler/type_propagation.py
@@ -2071,14 +2071,25 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> TypeInfo:
 
     def visit_AugAssign(self, node: ast.AugAssign) -> TypeInfo:
         assert isinstance(node.target, ExtendedAST)
-        type_info = self.visit(
-            create(
-                ast.BinOp,
-                left=node.target.copy(ctx=ast.Load()),
-                op=node.op,
-                right=node.value,
+        try:
+            type_info = self.visit(
+                create(
+                    ast.BinOp,
+                    left=node.target.copy(ctx=ast.Load()),
+                    op=node.op,
+                    right=node.value,
+                )
             )
-        )
+        except exc.TorchOpTracingError as e:
+            # Check if this is a shape mismatch when modifying a host variable in device loop
+            if (
+                isinstance(node.target, ast.Name)
+                and (existing_type := self.scope.maybe_get(node.target.id)) is not None
+                and existing_type.origin.is_host()
+                and self.device_loop_depth > 0
+            ):
+                raise exc.CannotModifyHostVariableOnDevice(node.target.id) from e
+            raise
         self._assign(node.target, type_info)
         return NoType(origin=self.origin())
 
diff --git a/helion/exc.py b/helion/exc.py
index a394054ae..45a3a95e6 100644
--- a/helion/exc.py
+++ b/helion/exc.py
@@ -367,7 +367,7 @@ class NotAllowedInHelperFunction(BaseError):
 
 
 class CannotModifyHostVariableOnDevice(BaseError):
-    message = "Cannot modify host variable '{0}' inside `hl.tile` or `hl.grid` loop without subscript assignment. Use '{0}[tile] = ...' instead."
+    message = "Cannot modify host variable '{0}' inside `hl.tile` or `hl.grid` loop without subscript assignment. Use '{0}[tile] = ...' or '{0}[:] = ...' instead."
 
 
 class AtomicOnDeviceTensor(BaseError):
diff --git a/test/test_errors.py b/test/test_errors.py
index 18dae0f19..bfaf807b2 100644
--- a/test/test_errors.py
+++ b/test/test_errors.py
@@ -228,6 +228,25 @@ def bad_fn(x: torch.Tensor) -> torch.Tensor:
         with self.assertRaises(helion.exc.CannotReadDeviceVariableOnHost):
             code_and_output(bad_fn, (torch.randn(8, device=DEVICE),))
 
+    def test_augmented_assign_without_subscript(self):
+        """Test that augmented assignment to host variable in device loop raises proper error."""
+
+        @helion.kernel()
+        def bad_fn(grad_out: torch.Tensor) -> torch.Tensor:
+            m, n = grad_out.shape
+            n = hl.specialize(n)
+            grad_block = torch.zeros(n, dtype=torch.float32, device=grad_out.device)
+
+            for tile_m in hl.tile(m):
+                dy_m = grad_out[tile_m, :].to(torch.float32)
+                # Should use `grad_block[:] += ...` instead
+                grad_block += torch.sum(dy_m, dim=0)
+
+            return grad_block
+
+        with self.assertRaises(helion.exc.CannotModifyHostVariableOnDevice):
+            code_and_output(bad_fn, (torch.randn(4096, 5632, device=DEVICE),))
+
     def test_device_tensor_subscript(self):
         @helion.kernel()
         def bad_fn(x: torch.Tensor) -> torch.Tensor: