-
Notifications
You must be signed in to change notification settings - Fork 36
[Benchmark] Add low mem dropout example #641
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
""" | ||
Low mem dropout Example | ||
================ | ||
|
||
This example demonstrates how to implement a Low mem dropout using Helion. | ||
""" | ||
|
||
# %% | ||
# Imports | ||
# ------- | ||
from __future__ import annotations | ||
|
||
from typing import Callable | ||
|
||
import torch | ||
|
||
import helion | ||
import helion.language as hl | ||
|
||
|
||
# %% | ||
# Low mem dropout forward implementations | ||
# ------------------- | ||
@helion.kernel() | ||
def low_mem_dropout(p: float, x: torch.Tensor, seed: int) -> torch.Tensor: | ||
""" | ||
Applies dropout on x using p | ||
Args: | ||
p (float): dropout probability | ||
x (torch.Tensor): input tensor | ||
Returns: | ||
Output tensor | ||
""" | ||
scale = 1.0 / (1.0 - p) | ||
# flatten to 1D so we can use tile | ||
n = x.numel() | ||
x_flat = x.view(-1) | ||
out_flat = torch.empty_like(x_flat) | ||
for tidx in hl.tile(n): | ||
xi = x_flat[tidx].to(torch.float32) | ||
r = hl.rand([tidx], seed=seed) | ||
keep = r > p | ||
yscaled = xi * scale | ||
yi = torch.where(keep, yscaled, 0.0) | ||
out_flat[tidx] = yi.to(x.dtype) | ||
return out_flat.view_as(x) | ||
|
||
|
||
# %% | ||
# Low mem dropout backward implementation | ||
# ------------------- | ||
@helion.kernel() | ||
def low_mem_dropout_bwd(p: float, grad_y: torch.Tensor, seed: int) -> torch.Tensor: | ||
""" | ||
For low mem dropout we are applying randomness inside both fwd and bwd | ||
technically dropout bwd is same as fwd | ||
Args: | ||
p (float): Dropout probability | ||
grad_y (torch.Tensor): Gradient tensor | ||
Returns: | ||
Output tensor | ||
""" | ||
scale = 1.0 / (1.0 - p) | ||
n = grad_y.numel() | ||
grad_y_flat = grad_y.view(-1) | ||
out_flat = torch.empty_like(grad_y_flat) | ||
for tidx in hl.tile(n): | ||
gi = grad_y_flat[tidx].to(torch.float32) | ||
r = hl.rand([tidx], seed=seed) | ||
keep = r > p | ||
g_scaled = gi * scale | ||
gxi = torch.where(keep, g_scaled, 0.0) | ||
out_flat[tidx] = gxi.to(grad_y.dtype) | ||
return out_flat.view_as(grad_y) | ||
|
||
|
||
# %% | ||
# TritonBench Wrapper | ||
# ------------------- | ||
def low_mem_dropout_tritonbench(tb_op: object, p: float, x: torch.Tensor) -> Callable: | ||
""" | ||
Wrapper for TritonBench compatibility. | ||
|
||
Args: | ||
tb_op: TritonBench operator instance | ||
p (float): dropout probability | ||
x (torch.Tensor): Input tensor | ||
|
||
Returns: | ||
Callable: A function that performs the low_mem_dropout. | ||
""" | ||
|
||
def _inner() -> torch.Tensor: | ||
return low_mem_dropout(p, x, seed=123) | ||
|
||
return _inner | ||
|
||
|
||
# %% | ||
# Verification Function | ||
# ------------------- | ||
def check(p: float, size: int) -> None: | ||
""" | ||
Verify the low mem dropout kernel implementation against PyTorch's native dropout implementation. | ||
|
||
Args: | ||
p (float): dropout probability | ||
size (int): input tensor size | ||
""" | ||
x = torch.randn(size=(size,)).cuda() | ||
seed = 123 | ||
|
||
out = low_mem_dropout(p, x, seed) | ||
grad_y = torch.ones_like(x) | ||
grad_x = low_mem_dropout_bwd(p, grad_y, seed) | ||
mask_fwd = out != 0 | ||
mask_bwd = grad_x != 0 | ||
assert torch.equal(mask_fwd, mask_bwd) | ||
|
||
|
||
# %% | ||
# Main Function | ||
# ----------- | ||
def main() -> None: | ||
""" | ||
Main entry point that runs the low mem dropout kernel verification with different tensor sizes. | ||
Tests with two configurations: | ||
- p=0.25, s=8192 | ||
- p=0.25, s=32768 | ||
""" | ||
check(0.25, 8192) | ||
check(0.25, 32768) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Test backwards (not just fwd) and assert that the same elements are dropped out in bwd as fwd (and different elements are dopped out if you change the seed).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks, I've updated the test case with dropout mask checking.