[Benchmark] Add low mem dropout example #641
Conversation
stack-info: PR: #641, branch: karthickai/stack/1 (force-pushed 89f5048 to 5992548)
stack-info: PR: #641, branch: karthickai/stack/1 (force-pushed 5992548 to c8b0148)
Speedup and Accuracy:

HELION_USE_DEFAULT_CONFIG=0 python benchmarks/run.py --kernel low_mem_dropout --metrics accuracy,speedup

| x_val | triton_dropout-speedup | triton_dropout-accuracy | torch_compile_dropout-speedup | torch_compile_dropout-accuracy | seeded_dropout-speedup | seeded_dropout-accuracy | helion_low_mem_dropout_tritonbench-speedup | helion_low_mem_dropout_tritonbench-accuracy |
|---|---|---|---|---|---|---|---|---|
| 32 | 1.99543 | 1 | 1.99543 | 1 | 2.17413 | 0 | 2.13171 | 0 |
| 128 | 1.68981 | 1 | 1.89119 | 1 | 1.74641 | 0 | 1.88144 | 0 |
| 512 | 2.16744 | 1 | 2.18779 | 1 | 2.08036 | 0 | 2.51892 | 0 |
| 2048 | 2 | 1 | 2.31088 | 1 | 1.93913 | 0 | 2.31088 | 0 |
| 8192 | 2.09302 | 1 | 2.21675 | 1 | 1.98238 | 0 | 2.0362 | 0 |
| 32768 | 2.47264 | 1 | 2.4604 | 1 | 2.05372 | 0 | 2.47264 | 0 |
| 131072 | 2.18884 | 1 | 2.26667 | 1 | 1.96911 | 0 | 2.10744 | 0 |
| 524288 | 1.85586 | 1 | 1.91331 | 1 | 1.7913 | 0 | 2.20714 | 0 |
| average | 2.05788 | 1 | 2.1553 | 1 | 1.96707 | 0 | 2.2083 | |
examples/low_mem_dropout.py
Outdated
Args:
    p (float): dropout probability
    x (torch.Tensor): input tensor
    x_keep (torch.Tensor): mask tensor indicating which elements to keep
Do you know if the tritonbench Triton kernel also takes this as input? Ideally we want to run `x_keep = torch.rand_like(x) > p` within the Helion kernel's `hl.tile` device loop; however, if the tritonbench Triton kernel is not doing that either, we can stay with the current design.
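For reference, a rough sketch of that suggestion (hypothetical code; it assumes `torch.rand_like` works inside the `hl.tile` device loop and reuses the example's flatten-then-tile structure):

```python
import torch
import helion
import helion.language as hl


@helion.kernel()
def dropout_sketch(p: float, x: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch of the suggestion: build the keep-mask inside the
    # device loop instead of passing x_keep in as an argument.
    x_flat = x.view(-1)
    out_flat = torch.empty_like(x_flat)
    for tidx in hl.tile(x_flat.numel()):
        xi = x_flat[tidx]
        x_keep = torch.rand_like(xi) > p      # per-tile mask, nothing read from memory
        zeros = xi - xi                       # zeros with matching shape/dtype
        out_flat[tidx] = torch.where(x_keep, xi / (1.0 - p), zeros)
    return out_flat.view_as(x)
```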
Tritonbench has two variants, `_triton_dropout` and `_seeded_triton_dropout`; I referred to the `_triton_dropout` variant, which takes `x_keep` as an arg.
examples/low_mem_dropout.py
Outdated
for tidx in hl.tile(n):
    xi = x_flat[tidx].to(torch.float32)
    mi = m_flat[tidx].to(torch.float32) > 0.5
I think this is wrong:
- We should be generating a random number here, not reading an input. What makes "low mem" dropout "low mem" is that you don't take the mask as an input; you generate it from a seed, so you use O(1) reads rather than O(n).
- We may need to add some `hl.random` ops to make this possible.
- Low-mem dropout is mainly interesting when you include the backwards, since you can use the same seed for forwards and backwards.
Agreed - maybe we should model this after `_seeded_triton_dropout`, and we could also try to use the `torch.rand_like` support (PR).
Thanks @jansel, I agree with your points. I've updated the kernel to generate randomness inside the kernel using `torch.rand_like` per tile. Thanks @yf225 for your suggestion.
@helion.kernel()
def low_mem_dropout(p: float, x: torch.Tensor) -> torch.Tensor:
    """
    Applies dropout on x using p
    Args:
        p (float): dropout probability
        x (torch.Tensor): input tensor
    Returns:
        Output tensor
    """
    scale = 1.0 / (1.0 - p)
    # flatten to 1D so we can use tile
    n = x.numel()
    x_flat = x.view(-1)
    out_flat = torch.empty_like(x_flat)
    for tidx in hl.tile(n):
        xi = x_flat[tidx].to(torch.float32)
        r = torch.rand_like(xi, dtype=torch.float32)
        keep = r > p
        yscaled = xi * scale
        zeros = xi - xi
        yi = torch.where(keep, yscaled, zeros)
        out_flat[tidx] = yi.to(x.dtype)
    return out_flat.view_as(x)
After the change, the kernel's output doesn't match eager, even with manual_seed. I believe this is expected; could you kindly advise whether I should adjust the testing approach?
Triton/Helion random will not match eager mode; this is expected. What you do need to do is make sure the randomness matches between forwards and backwards.
I worry `torch.rand_like` will make this hard since it doesn't accept a seed arg.
Thanks @jansel, I've updated the test to verify that the dropout mask from fwd matches the mask in bwd. Since `torch.rand_like` doesn't take a seed, I reseed with `torch.manual_seed` before each kernel.
stack-info: PR: #641, branch: karthickai/stack/1 (force-pushed c8b0148 to 4ce0b23)
stack-info: PR: #641, branch: karthickai/stack/1 (force-pushed 4ce0b23 to 29bd548)
stack-info: PR: #641, branch: karthickai/stack/1 (force-pushed 29bd548 to 17ff8d7)
stack-info: PR: #641, branch: karthickai/stack/1 (force-pushed 17ff8d7 to bcd53b3)
stack-info: PR: #641, branch: karthickai/stack/1 (force-pushed bcd53b3 to 9164ec2)
See below
examples/low_mem_dropout.py
Outdated
torch.manual_seed(123)
y, fwd_mask = low_mem_dropout(p, x)

# need to set seed again else we can't reproduce
torch.manual_seed(123)
grad_y = torch.ones_like(x)
grad_x, bwd_mask = low_mem_dropout_bwd(p, grad_y)
This still isn't right. The forward is returning a `fwd_mask`, which is the signature of non-low-mem dropout. The point of low-mem dropout is to not store the mask in memory. If you have the mask, then the backward is just `fwd_mask * grad_y * scale`.
Also, needing to call `torch.manual_seed(123)` is kind of clunky, since it mutates global state and results in extra kernel launches.
I'd suggest either:
- Just implement regular (not low-mem) dropout and make the backward use the `fwd_mask` (see the sketch after this list), or
- Add a seeded random op so we can do low-mem dropout properly.
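A minimal sketch of the first option's backward, using the formula above (hypothetical function and tensor names):

```python
import torch


def dropout_bwd_with_mask(grad_y: torch.Tensor, fwd_mask: torch.Tensor, p: float) -> torch.Tensor:
    # Regular (non-low-mem) dropout backward: reuse the mask saved by the forward.
    scale = 1.0 / (1.0 - p)
    return fwd_mask.to(grad_y.dtype) * grad_y * scale
```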
Thanks again, I will try to implement the seeded random op; once that's ready I'll update the low mem dropout example.
I've updated `low_mem_dropout` with the newly created `hl.rand` op, which takes a seed arg. Now it's not storing the mask in memory.
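Roughly, the seeded kernel then looks like this (a sketch only; the exact `hl.rand` signature is an assumption, inferred from the generated Triton code quoted below):

```python
import torch
import helion
import helion.language as hl


@helion.kernel()
def low_mem_dropout_sketch(p: float, x: torch.Tensor, seed: int) -> torch.Tensor:
    # Sketch of seed-based (low-mem) dropout: the mask is recomputed from the seed
    # in both forward and backward, so it is never materialized in memory.
    scale = 1.0 / (1.0 - p)
    x_flat = x.view(-1)
    out_flat = torch.empty_like(x_flat)
    for tidx in hl.tile(x_flat.numel()):
        xi = x_flat[tidx]
        r = hl.rand([tidx], seed=seed)   # assumed signature: per-element uniforms for this tile
        keep = r > p
        zeros = xi - xi                  # zeros with matching shape/dtype
        out_flat[tidx] = torch.where(keep, xi * scale, zeros).to(x.dtype)
    return out_flat.view_as(x)
```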
stack-info: PR: #641, branch: karthickai/stack/1 (force-pushed 9164ec2 to b7cbc36)
stack-info: PR: #641, branch: karthickai/stack/1 (force-pushed b7cbc36 to 4458226)
test/test_examples.expected
Outdated
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < n
xi = tl.load(x_flat + indices_0 * x_flat_stride_0, mask_0, other=0)
rand = tl.rand(seed, tl.arange(0, _BLOCK_SIZE_0).reshape([_BLOCK_SIZE_0]))
I believe this will result in the same RNG being used for every tile.
Tile 1 will be: seed, range(0, BLOCK_SIZE)
Tile 2 will be: seed, range(0, BLOCK_SIZE)
So the same elements will get dropped in each tile. I think this needs to be the index.
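For contrast, here is what a generic seeded-dropout Triton kernel looks like when the global element index is fed into `tl.rand` (a standalone sketch, not Helion's generated code):

```python
import triton
import triton.language as tl


@triton.jit
def _seeded_dropout_kernel(x_ptr, out_ptr, n, p, seed, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    # Global element indices: each tile covers a different range, so tl.rand(seed, offsets)
    # produces a different mask per tile (unlike tl.arange(0, BLOCK_SIZE), which repeats).
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n
    x = tl.load(x_ptr + offsets, mask=mask)
    r = tl.rand(seed, offsets)
    keep = r > p
    y = tl.where(keep, x / (1.0 - p), 0.0)
    tl.store(out_ptr + offsets, y, mask=mask)
```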
Thanks for the excellent catch! I'll update `hl.rand` to use the index instead of `tl.arange(0, BLOCK_SIZE)`.
I've updated `hl.rand` to generate unique RNG values for each tile (#685). I'll update this once it's merged.
stack-info: PR: #641, branch: karthickai/stack/1 (force-pushed 4458226 to 41e4926)
)
)

def test_low_mem_dropout(self):
Test backwards (not just fwd) and assert that the same elements are dropped out in bwd as in fwd (and that different elements are dropped out if you change the seed).
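A sketch of such a test (hypothetical; it assumes forward and backward kernels that each take a seed argument and run on CUDA):

```python
def test_low_mem_dropout_seeded(self):
    # Hypothetical test sketch: the same seed must drop the same elements in fwd
    # and bwd; a different seed must produce a different drop pattern.
    p, seed = 0.25, 123
    x = torch.randn(8192, device="cuda")
    grad_y = torch.ones_like(x)

    y = low_mem_dropout(p, x, seed)
    grad_x = low_mem_dropout_bwd(p, grad_y, seed)

    fwd_dropped = y == 0
    bwd_dropped = grad_x == 0
    assert torch.equal(fwd_dropped, bwd_dropped)          # same seed -> same mask

    grad_x_other = low_mem_dropout_bwd(p, grad_y, seed + 1)
    assert not torch.equal(grad_x_other == 0, bwd_dropped)  # different seed -> different mask
```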
Thanks, I've updated the test case with dropout-mask checking.
stack-info: PR: #641, branch: karthickai/stack/1 (force-pushed 41e4926 to 822afac)
stack-info: PR: #641, branch: karthickai/stack/1 (force-pushed 822afac to 0f4844a)
Stacked PRs:
[Benchmark] Add low mem dropout example