
Conversation

karthickai
Contributor

@karthickai karthickai commented Sep 19, 2025

Stacked PRs:


[Benchmark] Add low mem dropout example

karthickai added a commit that referenced this pull request Sep 19, 2025
stack-info: PR: #641, branch: karthickai/stack/1
@meta-cla bot added the CLA Signed label Sep 19, 2025
karthickai added a commit that referenced this pull request Sep 19, 2025
stack-info: PR: #641, branch: karthickai/stack/1
@karthickai karthickai requested a review from yf225 September 19, 2025 22:42
Contributor Author

karthickai commented Sep 19, 2025

Speedup and Accuracy

HELION_USE_DEFAULT_CONFIG=0 python benchmarks/run.py --kernel low_mem_dropout --metrics accuracy,speedup

  x_val    triton_dropout-speedup    triton_dropout-accuracy    torch_compile_dropout-speedup    torch_compile_dropout-accuracy    seeded_dropout-speedup    seeded_dropout-accuracy    helion_low_mem_dropout_tritonbench-speedup    helion_low_mem_dropout_tritonbench-accuracy
-------  ------------------------  -------------------------  -------------------------------  --------------------------------  ------------------------  -------------------------  --------------------------------------------  ---------------------------------------------
     32                   1.99543                          1                          1.99543                                 1                   2.17413                          0                                       2.13171                                              0
    128                   1.68981                          1                          1.89119                                 1                   1.74641                          0                                       1.88144                                              0
    512                   2.16744                          1                          2.18779                                 1                   2.08036                          0                                       2.51892                                              0
   2048                   2                                1                          2.31088                                 1                   1.93913                          0                                       2.31088                                              0
   8192                   2.09302                          1                          2.21675                                 1                   1.98238                          0                                       2.0362                                               0
  32768                   2.47264                          1                          2.4604                                  1                   2.05372                          0                                       2.47264                                              0
 131072                   2.18884                          1                          2.26667                                 1                   1.96911                          0                                       2.10744                                              0
 524288                   1.85586                          1                          1.91331                                 1                   1.7913                           0                                       2.20714                                              0
average                   2.05788                          1                          2.1553                                  1                   1.96707                          0                                       2.2083 

Args:
    p (float): dropout probability
    x (torch.Tensor): input tensor
    x_keep (torch.Tensor): mask tensor indicating which elements to keep
Contributor

Do you know if the tritonbench Triton kernel also takes this as input? Ideally we want to run x_keep = torch.rand_like(x) > p within the Helion kernel's hl.tile device loop; however, if the tritonbench Triton kernel is not doing that either, we can stay with the current design.

Contributor Author

Tritonbench has two variants, _triton_dropout and _seeded_triton_dropout; I referred to the _triton_dropout variant, which takes x_keep as an argument.


for tidx in hl.tile(n):
    xi = x_flat[tidx].to(torch.float32)
    mi = m_flat[tidx].to(torch.float32) > 0.5
Contributor

I think this is wrong:

  1. We should be generating a random number here, not reading an input. What makes "low mem" dropout "low mem" is that you don't take the randomness as an input; you generate it from a seed, so you use O(1) reads rather than O(n).
  2. We may need to add some hl.random ops to make this possible.
  3. Low mem dropout is mainly interesting when you include the backwards, since you can use the same seed for forwards and backwards.
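
As a rough illustration of points 1 and 3 (a sketch, not this PR's kernel), seeded dropout written in plain Triton and modeled on Triton's low-memory dropout tutorial regenerates the keep-mask from a seed instead of reading a precomputed mask, and the backward pass can rebuild the identical mask from the same seed. The kernel and wrapper names and the block size below are illustrative.

import torch
import triton
import triton.language as tl


@triton.jit
def _seeded_dropout_kernel(x_ptr, out_ptr, n, p, seed, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    # Regenerate the keep-mask from (seed, offs): no mask tensor is read,
    # and a backward kernel can rebuild the exact same mask from the seed.
    r = tl.rand(seed, offs)
    keep = r > p
    y = tl.where(keep, x / (1.0 - p), 0.0)
    tl.store(out_ptr + offs, y, mask=mask)


def seeded_dropout(x, p, seed, BLOCK=1024):
    # Assumes a contiguous CUDA tensor.
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, BLOCK),)
    _seeded_dropout_kernel[grid](x, out, n, p, seed, BLOCK=BLOCK)
    return out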

Contributor

Agreed - maybe we should model after _seeded_triton_dropout, and also we could try to use torch.rand_like support (PR).

Contributor Author

Thanks @jansel, I agree with your points. I've updated the kernel to generate randomness inside the kernel using torch.rand_like per tile. Thanks @yf225 for your suggestion.

@helion.kernel()
def low_mem_dropout(p: float, x: torch.Tensor) -> torch.Tensor:
    """
    Applies dropout on x using p
    Args:
        p (float): dropout probability
        x (torch.Tensor): input tensor
    Returns:
        Output tensor
    """
    scale = 1.0 / (1.0 - p)
    # flatten to 1D so we can use tile
    n = x.numel()
    x_flat = x.view(-1)
    out_flat = torch.empty_like(x_flat)

    for tidx in hl.tile(n):
        xi = x_flat[tidx].to(torch.float32)
        r = torch.rand_like(xi, dtype=torch.float32)
        keep = r > p
        yscaled = xi * scale
        zeros = xi - xi
        yi = torch.where(keep, yscaled, zeros)
        out_flat[tidx] = yi.to(x.dtype)
    return out_flat.view_as(x)

After the change, the kernel's output doesn't match eager, even with manual_seed. I believe this is expected; could you kindly advise whether I should adjust the testing approach?

Contributor

Triton/Helion random will not match eager mode; this is expected. What you do need to do is make sure the randomness matches between forwards and backwards.

I worry torch.rand_like will make this hard since it doesn't accept a seed arg.
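
A plain-PyTorch sketch (not Helion code) of the property described above: forward and backward must be able to regenerate the same keep-mask from a shared seed, without storing the mask. The helper name below is illustrative.

import torch

def keep_mask(shape, p, seed):
    g = torch.Generator().manual_seed(seed)
    return torch.rand(shape, generator=g) > p

m_fwd = keep_mask((1024,), p=0.5, seed=123)
m_bwd = keep_mask((1024,), p=0.5, seed=123)
assert torch.equal(m_fwd, m_bwd)  # same seed -> same elements dropped
# torch.rand_like has no seed/generator argument, so there is no per-call handle
# for a backward kernel to reproduce the forward's mask.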

Contributor Author

Thanks @jansel, I've updated the test to verify that the dropout mask from fwd matches the mask in bwd. Since torch.rand_like doesn't take a seed, I reseed with torch.manual_seed before each kernel call.

karthickai added a commit that referenced this pull request Sep 21, 2025
stack-info: PR: #641, branch: karthickai/stack/1
Contributor

@jansel jansel left a comment

See below

Comment on lines 121 to 127
torch.manual_seed(123)
y, fwd_mask = low_mem_dropout(p, x)

# need to set seed again else we can't reproduce
torch.manual_seed(123)
grad_y = torch.ones_like(x)
grad_x, bwd_mask = low_mem_dropout_bwd(p, grad_y)
Contributor

This still isn't right. The forward is returning a fwd_mask, which is the signature of non-low-mem dropout. The point of low-mem dropout is to not store the mask in memory. If you have the mask, then the backward is just fwd_mask * grad_y * scale.

Also, needing to call torch.manual_seed(123) is kind of clunky since it mutates global state and results in extra kernel launches.

I'd suggest either:

  1. Just implement regular (not low-mem) dropout and make the backward use the fwd_mask
  2. Add a seeded random op so we can do low mem dropout properly
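
A minimal sketch of option 1 at the plain-PyTorch level (illustrative only; the PR's Helion kernels would replace the forward/backward bodies): the mask is materialized in the forward, and the backward is simply fwd_mask * grad_y * scale.

import torch

class RegularDropout(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, p):
        scale = 1.0 / (1.0 - p)
        mask = torch.rand_like(x) > p          # O(n) memory: the mask is stored
        ctx.save_for_backward(mask)
        ctx.scale = scale
        return torch.where(mask, x * scale, torch.zeros_like(x))

    @staticmethod
    def backward(ctx, grad_y):
        (mask,) = ctx.saved_tensors
        return torch.where(mask, grad_y * ctx.scale, torch.zeros_like(grad_y)), None

x = torch.randn(8, requires_grad=True)
y = RegularDropout.apply(x, 0.5)
y.sum().backward()
# x.grad is scale where the element was kept and 0 where it was dropped.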

Contributor Author

Thanks again, I will try to implement the seeded random op; once that's ready I'll update the low mem dropout example.

Contributor Author

I've updated the low_mem_dropout with the newly created hl.rand op, which takes a seed arg. Now it's not storing the mask in memory.

karthickai added a commit that referenced this pull request Sep 24, 2025
stack-info: PR: #641, branch: karthickai/stack/1
karthickai added a commit that referenced this pull request Sep 24, 2025
stack-info: PR: #641, branch: karthickai/stack/1
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < n
xi = tl.load(x_flat + indices_0 * x_flat_stride_0, mask_0, other=0)
rand = tl.rand(seed, tl.arange(0, _BLOCK_SIZE_0).reshape([_BLOCK_SIZE_0]))
Contributor

I believe this will result in the same RNG used for every tile.

Tile 1 will be:

  • seed, range(0, BLOCK_SIZE)

Tile 2 will be:

  • seed, range(0, BLOCK_SIZE)

So the same elements will get dropped each tile. I think this needs to be the index.
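
Sketched against the generated code quoted above (a fragment for illustration, not the actual codegen output), the fix is to feed the global element indices into tl.rand so each tile gets a distinct random stream:

indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < n
xi = tl.load(x_flat + indices_0 * x_flat_stride_0, mask_0, other=0)
rand = tl.rand(seed, indices_0)  # per-element offset -> different elements dropped per tile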

Contributor Author

Thanks for the excellent catch! I'll update hl.rand to use the index instead of tl.arange(0, BLOCK_SIZE).

Contributor Author

I've updated hl.rand to generate unique random numbers for each tile (#685). I'll update this once it's merged.

karthickai added a commit that referenced this pull request Oct 7, 2025
stack-info: PR: #641, branch: karthickai/stack/1
@karthickai karthickai requested a review from jansel October 7, 2025 21:26
)
)

def test_low_mem_dropout(self):
Contributor

Test backwards (not just fwd) and assert that the same elements are dropped out in bwd as fwd (and different elements are dropped out if you change the seed).
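
A hedged sketch of such a test; the low_mem_dropout / low_mem_dropout_bwd signatures used here (taking p, the tensor, and a seed) and recovering the dropped set from zero positions are assumptions for illustration, not the PR's actual API.

def test_low_mem_dropout_fwd_bwd(self):
    # Hypothetical signatures: low_mem_dropout(p, x, seed) and
    # low_mem_dropout_bwd(p, grad_y, seed). Dropped positions are recovered as
    # the exact zeros of the output (the inputs are almost surely nonzero).
    p, seed = 0.25, 123
    x = torch.randn(32768, device="cuda")
    grad_y = torch.ones_like(x)

    y = low_mem_dropout(p, x, seed)
    grad_x = low_mem_dropout_bwd(p, grad_y, seed)

    # Same seed: the same elements are dropped in fwd and bwd.
    self.assertTrue(torch.equal(y == 0, grad_x == 0))

    # Different seed: a different set of elements is dropped.
    grad_x_other = low_mem_dropout_bwd(p, grad_y, seed + 1)
    self.assertFalse(torch.equal(grad_x == 0, grad_x_other == 0))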

Contributor Author

Thanks, I've updated the test case with dropout mask checking.

karthickai added a commit that referenced this pull request Oct 8, 2025
stack-info: PR: #641, branch: karthickai/stack/1
@karthickai karthickai merged commit abe8339 into main Oct 8, 2025
14 checks passed