
Conversation

@yf225 yf225 commented Oct 3, 2025

We want to support a usage pattern like:

@helion.kernel()
def reduce_kernel(x: torch.Tensor) -> torch.Tensor:
    m_block = hl.register_block_size(x.size(0))
    grad_weight = x.new_empty(
        [(x.size(0) + m_block - 1) // m_block, x.size(1)], dtype=torch.float32
    )
    weight_shape = hl.specialize(x.size(1))
    for mb_cta in hl.tile(x.size(0), block_size=m_block):
        grad_w_m = x.new_zeros(weight_shape, dtype=torch.float32)
        for mb in hl.tile(mb_cta.begin, mb_cta.end):
            grad_w_m += x[mb, :].to(torch.float32).sum(0)
        grad_weight[mb_cta.id, :] = grad_w_m
    return grad_weight.sum(0).to(x.dtype)

In particular, grad_w_m += x[mb, :].to(torch.float32).sum(0) is difficult to support, because the LHS grad_w_m has shape x.size(1), which can be a non-power-of-2 value (e.g. 56), while the RHS x[mb, :].to(torch.float32).sum(0) has the next power-of-2 size of x.size(1) (i.e. 64), resulting in a shape mismatch.

I explored many solutions, and the cleanest / simplest one is to bump the hl.specialize'ed shape value to the next power of 2 when it's used as a device tensor shape value (e.g. when passed to torch.zeros in a device loop).

An alternative is to ask users to explicitly wrap weight_shape with helion.next_power_of_2(...), e.g. grad_w_m = x.new_zeros(helion.next_power_of_2(weight_shape), dtype=torch.float32), but I feel the UX friction of that is quite high.
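For concreteness, here is a minimal sketch of the padding rule being proposed (the next_power_of_2 helper below is hypothetical and written only to illustrate the idea; it is not the actual Helion implementation):

import torch


def next_power_of_2(n: int) -> int:
    # Smallest power of 2 that is >= n (e.g. 56 -> 64).
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


x = torch.randn(128, 56)
weight_shape = x.size(1)  # the hl.specialize'ed value, e.g. 56

# With automatic padding, a device-loop allocation such as
#   grad_w_m = x.new_zeros(weight_shape, dtype=torch.float32)
# behaves as if it were written:
grad_w_m = x.new_zeros(next_power_of_2(weight_shape), dtype=torch.float32)
assert grad_w_m.shape == (64,)  # now matches the padded RHS shape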

Fixes #737.
Fixes #741.

cc. @v0i0 @mengluy0125

@yf225 yf225 requested review from jansel and oulgen October 3, 2025 21:43
@meta-cla meta-cla bot added the CLA Signed label Oct 3, 2025
@yf225 yf225 force-pushed the specialize_fix_v4 branch 2 times, most recently from b139ddb to 01fb3b0 on October 4, 2025 19:13
@yf225 yf225 changed the title Round to next power of 2 for hl.specialize'ed shape value used in device tensor creation Pad to next power of 2 for hl.specialize'ed shape value used in device tensor creation Oct 4, 2025
@yf225 yf225 marked this pull request as draft October 5, 2025 04:54
@yf225 yf225 force-pushed the specialize_fix_v4 branch 23 times, most recently from e214937 to 898e9ad on October 6, 2025 20:31
@yf225 yf225 force-pushed the specialize_fix_v4 branch 5 times, most recently from 1408a04 to 4bc5a08 on October 6, 2025 20:52
@yf225 yf225 marked this pull request as ready for review October 6, 2025 20:56
@yf225 yf225 force-pushed the specialize_fix_v4 branch from 4bc5a08 to 977ee7c on October 6, 2025 21:00

@jansel jansel left a comment

I think it would be better to do this at the aten level rather than monkey-patching torch.

At the aten level the factory functions are normalized (so you need to handle fewer of them), and if you have multiple Python threads, monkey-patching torch is not thread-safe.

@yf225 yf225 force-pushed the specialize_fix_v4 branch 4 times, most recently from f726b74 to cda0e25 on October 7, 2025 20:57
from collections.abc import Generator

from torch.utils._python_dispatch import TorchDispatchMode


class _PadTensorFactoryMode(TorchDispatchMode):

@yf225 yf225 Oct 7, 2025

Updated to use TorchDispatchMode to intercept tensor factory ops at the aten level and avoid monkey-patching.
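
For readers unfamiliar with the approach, below is a minimal, self-contained sketch of what a size-padding TorchDispatchMode can look like. The class name, the set of aten ops handled, and the fact that it pads every integer dimension are simplifications for illustration; the PR's _PadTensorFactoryMode presumably restricts padding to hl.specialize'ed shape values used inside device loops.

import torch
from torch.utils._python_dispatch import TorchDispatchMode


def _next_power_of_2(n: int) -> int:
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


class PadTensorFactoryMode(TorchDispatchMode):
    # aten factory ops whose first positional argument is the size list
    _SIZE_FIRST_OPS = {
        torch.ops.aten.empty.memory_format,
        torch.ops.aten.zeros.default,
        torch.ops.aten.ones.default,
    }

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in self._SIZE_FIRST_OPS and args and isinstance(args[0], (list, tuple)):
            padded = [_next_power_of_2(s) if isinstance(s, int) else s for s in args[0]]
            args = (padded, *args[1:])
        # Re-dispatch; the mode is disabled inside __torch_dispatch__, so this
        # call does not recurse back into the mode.
        return func(*args, **kwargs)


# Usage: factory calls made while the mode is active get padded sizes.
with PadTensorFactoryMode():
    t = torch.zeros(56)
assert t.shape == (64,)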

@yf225 yf225 requested a review from jansel October 7, 2025 21:01
@yf225 yf225 force-pushed the specialize_fix_v4 branch from cda0e25 to 24a2d6e on October 7, 2025 21:03
@yf225 yf225 force-pushed the specialize_fix_v4 branch from 24a2d6e to 1020dbc on October 8, 2025 04:51
@yf225 yf225 merged commit 55d6aa0 into main Oct 8, 2025
14 checks passed
@yf225 yf225 deleted the specialize_fix_v4 branch October 8, 2025 15:37

Development

Successfully merging this pull request may close these issues.

hl.specialize + torch.sum has size mismatch error
Something funny is going on with non-pow2 reduction/accumulators in rms_norm-bwd