Skip to content

Type propagation error when summing tiled gradients (5632 vs u1 size mismatch) #726

@yf225

Description

@yf225

Repro:

from __future__ import annotations

import torch

import helion
import helion.language as hl


# Minimal Helion kernel reproducing issue #726: column-sum a 2-D gradient tile
# by tile over the first dimension and accumulate into a buffer whose length
# is the specialized second dimension. Type propagation rejects the in-place
# add below with "size of tensor a (5632) must match ... tensor b (u1)".
# Only comments are used here (no docstring) so the traced AST is unchanged.
@helion.kernel
def _sum_feature_mismatch(grad_out: torch.Tensor) -> torch.Tensor:
    m, n = grad_out.shape  # grad_out is 2-D: [m, n]
    n = hl.specialize(n)  # bake the feature size n into the kernel as a constant

    # float32 accumulator of length n, allocated on grad_out's device.
    grad_block = torch.zeros(n, dtype=torch.float32, device=grad_out.device)  # grad_block: [n]

    for tile_m in hl.tile(m):  # iterate over tiles of the first (row) dimension
        # Load all n columns for the current row tile and upcast to float32.
        dy_m = grad_out[tile_m, :].to(torch.float32)  # dy_m: [tile_m, n]
        # Intended: add this tile's per-column sums ([n]) into grad_block.
        # Observed: during type propagation the sum's length is the symbolic
        # u1 rather than the specialized 5632, so this line raises
        # TorchOpTracingError (see the traceback in this issue).
        grad_block += torch.sum(dy_m, dim=0)  # torch.sum(..., dim=0): [n] (symbolic)

    return grad_block  # [n]


def main() -> None:
    """Run the repro kernel once on a random fp16 CUDA input.

    The feature width 5632 is the size that appears in the reported
    "5632 vs u1" mismatch error; the row count is arbitrary.
    """
    shape = (4096, 5632)  # (rows m, feature columns n)
    grad_out = torch.randn(shape, dtype=torch.float16, device="cuda")

    print(f"Running minimal mismatch repro with shape {shape}…")
    _sum_feature_mismatch(grad_out)


if __name__ == "__main__":
    main()

Error:

Traceback (most recent call last):
  File "/home/willfeng/local/helion2/helion/_compiler/type_propagation.py", line 1911, in visit_BinOp
    _eval_binary(node.op, left_example, right_example),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion2/helion/_compiler/type_propagation.py", line 1493, in _eval_binary
    return left + right  # pyright: ignore[reportOperatorIssue]
           ~~~~~^~~~~~~
  File "/home/willfeng/local/pytorch-nightly/torch/utils/_stats.py", line 28, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/_subclasses/fake_tensor.py", line 1375, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/_subclasses/fake_tensor.py", line 2102, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/_subclasses/fake_tensor.py", line 1517, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/_subclasses/fake_tensor.py", line 2625, in _dispatch_impl
    return maybe_propagate_real_tensors(fast_impl(self, *args, **kwargs))
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/_subclasses/fake_impls.py", line 1211, in fast_binary_impl
    final_shape = infer_size(final_shape, shape)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/_subclasses/fake_impls.py", line 1171, in infer_size
    torch._check(
  File "/home/willfeng/local/pytorch-nightly/torch/__init__.py", line 1702, in _check
    _check_with(RuntimeError, cond, message)
  File "/home/willfeng/local/pytorch-nightly/torch/__init__.py", line 1684, in _check_with
    raise error_type(message_evaluated)
RuntimeError: The size of tensor a (5632) must match the size of tensor b (u1) at non-singleton dimension 0)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/willfeng/local/helion2/layer_norm_bwd_symbolic_min_repro.py", line 38, in <module>
    main()
  File "/home/willfeng/local/helion2/layer_norm_bwd_symbolic_min_repro.py", line 34, in main
    _sum_feature_mismatch(grad_out)
  File "/home/willfeng/local/helion2/helion/runtime/kernel.py", line 285, in __call__
    return self.bind(args)(*args)
           ^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion2/helion/runtime/kernel.py", line 168, in bind
    bound_kernel = BoundKernel(self, args)
                   ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion2/helion/runtime/kernel.py", line 351, in __init__
    self.host_function: HostFunction = HostFunction(
                                       ^^^^^^^^^^^^^
  File "/home/willfeng/local/helion2/helion/_compiler/host_function.py", line 113, in __init__
    propagate_types(self)
  File "/home/willfeng/local/helion2/helion/_compiler/type_propagation.py", line 2327, in propagate_types
    prop.visit(stmt)
  File "/home/willfeng/local/helion2/helion/_compiler/type_propagation.py", line 1625, in visit
    type_info = visitor(node)
                ^^^^^^^^^^^^^
  File "/home/willfeng/local/helion2/helion/_compiler/type_propagation.py", line 2164, in visit_For
    body = self._loop_body(node.body)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion2/helion/_compiler/type_propagation.py", line 2128, in _loop_body
    self.visit(stmt)
  File "/home/willfeng/local/helion2/helion/_compiler/type_propagation.py", line 1625, in visit
    type_info = visitor(node)
                ^^^^^^^^^^^^^
  File "/home/willfeng/local/helion2/helion/_compiler/type_propagation.py", line 2066, in visit_AugAssign
    type_info = self.visit(
                ^^^^^^^^^^^
  File "/home/willfeng/local/helion2/helion/_compiler/type_propagation.py", line 1625, in visit
    type_info = visitor(node)
                ^^^^^^^^^^^^^
  File "/home/willfeng/local/helion2/helion/_compiler/type_propagation.py", line 1917, in visit_BinOp
    raise exc.TorchOpTracingError(e) from e
helion.exc.TorchOpTracingError: RuntimeError: The size of tensor a (5632) must match the size of tensor b (u1) at non-singleton dimension 0)
While processing:
  File "/home/willfeng/local/helion2/layer_norm_bwd_symbolic_min_repro.py", line 21, in _sum_feature_mismatch
    grad_block += torch.sum(dy_m, dim=0)  # torch.sum(..., dim=0): [n] (symbolic)

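The kernel accumulates a tiled gradient sum into a pre-sized buffer. Although `n` is specialized to 5632 via `hl.specialize`, type propagation gives the result of `torch.sum(dy_m, dim=0)` the symbolic length `u1`. This appears to come from the full `:` slice in `grad_out[tile_m, :]`. As a result, the in-place add into the 5632-element buffer fails with a size mismatch.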

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions