-
Notifications
You must be signed in to change notification settings - Fork 38
Closed
Labels
Description
Repro:
from __future__ import annotations
import torch
import helion
import helion.language as hl
# Minimal repro kernel: reduces `grad_out` over dim 0 one row-tile at a time,
# accumulating into a float32 buffer whose length is the specialized column count.
@helion.kernel
def _sum_feature_mismatch(grad_out: torch.Tensor) -> torch.Tensor:
    rows, cols = grad_out.shape  # grad_out is 2-D: [rows, cols]
    cols = hl.specialize(cols)  # feature size, specialized to an int
    # Accumulator of length `cols`, allocated on the same device as the input.
    col_sums = torch.zeros(cols, dtype=torch.float32, device=grad_out.device)
    for row_tile in hl.tile(rows):  # tile over the row (batch) dimension
        # Upcast the current slab of rows to float32 before reducing it.
        slab = grad_out[row_tile, :].to(torch.float32)  # [row_tile, cols]
        # This in-place add is the statement the reported traceback points at:
        # the reduced operand carries a symbolic length instead of `cols`.
        col_sums += torch.sum(slab, dim=0)
    return col_sums  # [cols]
def main() -> None:
    """Build a random float16 CUDA input and run the repro kernel on it."""
    m, n = 4096, 5632
    # Allocate the input first so the banner is only printed once the tensor exists.
    grad_out = torch.randn((m, n), device="cuda", dtype=torch.float16)
    print(f"Running minimal mismatch repro with shape {(m, n)}…")
    # The call fails while the kernel is being bound; the return value is unused.
    _sum_feature_mismatch(grad_out)
# Run the repro only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
Error:
Traceback (most recent call last):
File "/home/willfeng/local/helion2/helion/_compiler/type_propagation.py", line 1911, in visit_BinOp
_eval_binary(node.op, left_example, right_example),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/helion2/helion/_compiler/type_propagation.py", line 1493, in _eval_binary
return left + right # pyright: ignore[reportOperatorIssue]
~~~~~^~~~~~~
File "/home/willfeng/local/pytorch-nightly/torch/utils/_stats.py", line 28, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch-nightly/torch/_subclasses/fake_tensor.py", line 1375, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch-nightly/torch/_subclasses/fake_tensor.py", line 2102, in dispatch
return self._cached_dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch-nightly/torch/_subclasses/fake_tensor.py", line 1517, in _cached_dispatch_impl
output = self._dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch-nightly/torch/_subclasses/fake_tensor.py", line 2625, in _dispatch_impl
return maybe_propagate_real_tensors(fast_impl(self, *args, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch-nightly/torch/_subclasses/fake_impls.py", line 1211, in fast_binary_impl
final_shape = infer_size(final_shape, shape)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch-nightly/torch/_subclasses/fake_impls.py", line 1171, in infer_size
torch._check(
File "/home/willfeng/local/pytorch-nightly/torch/__init__.py", line 1702, in _check
_check_with(RuntimeError, cond, message)
File "/home/willfeng/local/pytorch-nightly/torch/__init__.py", line 1684, in _check_with
raise error_type(message_evaluated)
RuntimeError: The size of tensor a (5632) must match the size of tensor b (u1) at non-singleton dimension 0)
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/willfeng/local/helion2/layer_norm_bwd_symbolic_min_repro.py", line 38, in <module>
main()
File "/home/willfeng/local/helion2/layer_norm_bwd_symbolic_min_repro.py", line 34, in main
_sum_feature_mismatch(grad_out)
File "/home/willfeng/local/helion2/helion/runtime/kernel.py", line 285, in __call__
return self.bind(args)(*args)
^^^^^^^^^^^^^^^
File "/home/willfeng/local/helion2/helion/runtime/kernel.py", line 168, in bind
bound_kernel = BoundKernel(self, args)
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/helion2/helion/runtime/kernel.py", line 351, in __init__
self.host_function: HostFunction = HostFunction(
^^^^^^^^^^^^^
File "/home/willfeng/local/helion2/helion/_compiler/host_function.py", line 113, in __init__
propagate_types(self)
File "/home/willfeng/local/helion2/helion/_compiler/type_propagation.py", line 2327, in propagate_types
prop.visit(stmt)
File "/home/willfeng/local/helion2/helion/_compiler/type_propagation.py", line 1625, in visit
type_info = visitor(node)
^^^^^^^^^^^^^
File "/home/willfeng/local/helion2/helion/_compiler/type_propagation.py", line 2164, in visit_For
body = self._loop_body(node.body)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/helion2/helion/_compiler/type_propagation.py", line 2128, in _loop_body
self.visit(stmt)
File "/home/willfeng/local/helion2/helion/_compiler/type_propagation.py", line 1625, in visit
type_info = visitor(node)
^^^^^^^^^^^^^
File "/home/willfeng/local/helion2/helion/_compiler/type_propagation.py", line 2066, in visit_AugAssign
type_info = self.visit(
^^^^^^^^^^^
File "/home/willfeng/local/helion2/helion/_compiler/type_propagation.py", line 1625, in visit
type_info = visitor(node)
^^^^^^^^^^^^^
File "/home/willfeng/local/helion2/helion/_compiler/type_propagation.py", line 1917, in visit_BinOp
raise exc.TorchOpTracingError(e) from e
helion.exc.TorchOpTracingError: RuntimeError: The size of tensor a (5632) must match the size of tensor b (u1) at non-singleton dimension 0)
While processing:
File "/home/willfeng/local/helion2/layer_norm_bwd_symbolic_min_repro.py", line 21, in _sum_feature_mismatch
grad_block += torch.sum(dy_m, dim=0) # torch.sum(..., dim=0): [n] (symbolic)
The kernel accumulates a tiled gradient sum into a pre-sized buffer, but during type propagation the result of `torch.sum(..., dim=0)` keeps a symbolic length (`u1`), so adding it to the specialized buffer of size 5632 fails with a size mismatch.