
Conversation

@gmagogsfm
Contributor

Fixes a bug I discovered while authoring Helion kernels for vLLM.

InductorLowering was incorrectly expanding scalars (0-D tensors) with `[None, None]` to match the max ndim of all inputs. This produced broadcast shape mismatches in the generated Triton code, e.g. `scale_val[None, None]` when multiplying a 2-D tensor by a scalar.

Fix: special-case 0-D tensors so they are not expanded, and rely on Triton to handle scalar broadcasting.

Example:

import torch
import helion
import helion.language as hl

@helion.kernel(
    config=helion.Config(
        block_sizes=[2, 64],
        flatten_loops=[True],
        indexing=["pointer", "pointer", "tensor_descriptor"],
    )
)
def scalar_multiply(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    m, n = x.shape
    out = torch.empty_like(x)
    for tile_idx in hl.tile(out.shape):
        scale_val = hl.load(scale, [0])  # Load scalar from 1-element tensor
        out[tile_idx] = x[tile_idx] * scale_val
    return out

result = scalar_multiply(torch.randn([4, 128]), torch.tensor([2.0]))

Before

  scale_val = tl.load(...)
  v_0 = scale_val[None, None]  
  result = x_tile * v_0 

  Error: Cannot broadcast, rank mismatch: ['4', '128'], ['1', '1']

After

  scale_val = tl.load(...)
  result = x_tile * scale_val  # ✅ Scalar broadcasts correctly to [4, 128]
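For comparison, Triton's native scalar broadcasting is easy to see in a standalone kernel. This is a hand-written minimal sketch, not Helion-generated code; the kernel name, block size, and launch parameters are illustrative:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def scalar_mul_kernel(x_ptr, scale_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    scale = tl.load(scale_ptr)  # 0-D load: a plain scalar, no [None, None] needed
    # The scalar broadcasts over the whole block, following NumPy-style rules.
    tl.store(out_ptr + offs, x * scale, mask=mask)

x = torch.randn(512, device="cuda")
out = torch.empty_like(x)
scale = torch.tensor([2.0], device="cuda")
grid = (triton.cdiv(x.numel(), 128),)
scalar_mul_kernel[grid](x, scale, out, x.numel(), BLOCK=128)
```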

meta-cla bot added the CLA Signed label on Nov 21, 2025
@gmagogsfm
Contributor Author

@yf225 I think I hit a bug in Helion's handling of scalar tensors and attempted a fix. Could you check whether this makes sense?

gmagogsfm force-pushed the fix-scalar-broadcast-bug branch from 35cdbc9 to 4a276af on November 21, 2025 17:05
Bug: InductorLowering was incorrectly expanding scalars (0-D tensors) with
[None, None] to match the max ndim of all inputs. This created broadcast
shape mismatches in generated Triton code like `scale_val[None, None]`
when multiplying a 2D tensor by a scalar.

Fix: Skip dimension expansion for 0-D tensors (fake_val.ndim > 0 check).
Triton naturally handles scalar broadcasting without explicit expansion,
following standard NumPy broadcasting rules.

Added regression test test_scalar_broadcast_2d() with a config known to
trigger the bug (block_sizes=[2, 64], flatten_loops=[True]).
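
A minimal sketch of the guard described above, assuming a helper that appends `[None, ...]` indexing to the emitted expression (a hypothetical reconstruction of the described check, not the actual Helion source; `maybe_expand` and its arguments are illustrative):

```python
def maybe_expand(expr: str, fake_val, max_ndim: int) -> str:
    # Hypothetical sketch: skip expansion for 0-D tensors (the fake_val.ndim > 0
    # idea from the fix) and let Triton broadcast the scalar on its own.
    if fake_val.ndim == 0:
        return expr  # emit no indexing; avoids generating scale_val[None, None]
    missing = max_ndim - fake_val.ndim
    if missing == 0:
        return expr
    # e.g. a 1-D value used in a 2-D expression becomes expr[None, :]
    index = ", ".join(["None"] * missing + [":"] * fake_val.ndim)
    return f"{expr}[{index}]"
```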
gmagogsfm force-pushed the fix-scalar-broadcast-bug branch from 4a276af to ff8d4e7 on November 22, 2025 00:06
@yf225
Contributor

lgtm thanks! might need to fix lint

update: lint job seems broken on trunk, will try to see if I can repro and fix tomorrow

gmagogsfm merged commit 7acbc82 into pytorch:main on Nov 22, 2025
15 of 16 checks passed