-
Notifications
You must be signed in to change notification settings - Fork 36
Add hl.rand op with seed arg lowering to tl.rand #652
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
stack-info: PR: #652, branch: karthickai/stack/2
04d0e9b
to
300de6b
Compare
stack-info: PR: #652, branch: karthickai/stack/2
300de6b
to
d82cc97
Compare
stack-info: PR: #652, branch: karthickai/stack/2
d82cc97
to
0d507e3
Compare
test/test_rng.py
Outdated
def test_hl_rand_3d(self): | ||
import helion | ||
|
||
@helion.kernel(ref_mode=helion.RefMode.EAGER) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe remove the ref_mode setting? The code as-is will explicitly test ref eager mode instead of normal Helion compile mode, but here I believe the intent is to test compile mode. (We have other harness to run tests in ref eager mode automatically so usually we don't need to worry about it.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for the catch! I added that line for debugging to check the ref implementation and forgot to remove it. I’ve updated it now.
test/test_rng.py
Outdated
self.assertTrue(torch.all(output < 1.0), "All values should be < 1") | ||
|
||
def test_hl_rand_3d(self): | ||
import helion |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can likely remove
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've removed that line.
stack-info: PR: #652, branch: karthickai/stack/2
0d507e3
to
c112edb
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks @karthickai !
|
||
numel = " * ".join(shape_str.strip("[]").split(",")) | ||
seed_ast = state.ast_arg(1) | ||
offs_expr = f"tl.arange(0, {numel}).reshape({shape_str})" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is incorrect.
- Every tile will get the same RNG values.
- The RNG values will depend on the tile size due to the reshape
Stacked PRs:
Add hl.rand op with seed arg lowering to tl.rand