[Inductor] added aten.exponential_ decomp #91673
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91673
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 1ba5577. This comment was automatically generated by Dr. CI and updates every 15 minutes.
torch/_decomp/decompositions.py
Outdated
@@ -1963,6 +1963,11 @@ def uniform(
    )

@register_decomposition(aten.exponential_)
def exponential_(self, rate=1, generator=None):
    return self.copy_(-1/rate * torch.log(1 - torch.rand_like(self)))
Don't you have to use generator in the torch.rand_like call if it is not None?
Thanks @soumith, looks like prims does not support generator yet; there is a TODO comment here: https://github.com/pytorch/pytorch/blob/master/torch/_prims/__init__.py#L2669-L2670. For now, I think we can just add assert generator is None.
makes sense.
See comment #90869 (comment), #91673 (comment). Pull Request resolved: #91833 Approved by: https://github.com/jansel
This function also has an out-of-place variant and an out= variant. A better way to implement it would be to implement the out-of-place variant and then generate the in-place via register_inplace and the out= via the out wrapper.
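For illustration, a minimal sketch of that structure, assuming aten.exponential is the functional variant and that out_wrapper from torch._prims_common.wrappers applies here the way it does for other refs; register_inplace below is the hypothetical helper named in this comment, not an API I'm asserting exists:

import torch
from torch._decomp import register_decomposition
from torch._prims_common.wrappers import out_wrapper

aten = torch.ops.aten

# Out-of-place reference: sample U(0, 1) and apply the inverse CDF.
@register_decomposition(aten.exponential)
@out_wrapper()
def exponential(self, rate=1, generator=None):
    assert generator is None
    return -1 / rate * torch.log1p(-torch.rand_like(self))

# The in-place variant would then be derived from the functional one, e.g.
# exponential_ = register_inplace(aten.exponential_, exponential)  # hypothetical helper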
torch/_decomp/decompositions.py
Outdated
@register_decomposition(aten.exponential_)
def exponential_(self, rate=1, generator=None):
    assert generator is None
    return self.copy_(-1 / rate * torch.log1p(-torch.rand_like(self)))
Suggested change:
-    return self.copy_(-1 / rate * torch.log1p(-torch.rand_like(self)))
+    return self.copy_(-1 / rate * torch.log(torch.rand_like(self)))

If x ~ U(0,1), 1-x ~ U(0,1).
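For the record, the standard inverse-CDF derivation behind both formulas (textbook math, not text from this thread): if U ~ Uniform(0,1) and X = -ln(U)/λ, then

$$P(X > x) = P\left(U < e^{-\lambda x}\right) = e^{-\lambda x},$$

which is exactly the survival function of Exp(λ); substituting 1-U for U changes nothing, since both are Uniform(0,1).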
This can go to refs?
Same for the other distributions, yep
@lezcano this won't work with triton (and generally with any fast log approximation) even if compute is properly done in fp32; this is explained in the comments in the eager code. The exponential distribution should not generate 0's because the pdf at 0 is 0, yet with half dtype the result of the fast log approximation gets truncated to 0:
In [28]: max_rand = torch.rand(10000000000, device="cuda").amax()
In [29]: def fn(x):
...: return x.log().half()
...:
In [30]: opt_fn = torch.compile(fn)
/scratch/ngimel/work/pytorch/torch/_dynamo/eval_frame.py:372: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled.Consider setting `torch.set_float32_matmul_precision('high')`
warnings.warn(
In [31]: fn(max_rand)
Out[31]: tensor(-5.9605e-08, device='cuda:0', dtype=torch.float16) # fine, eager log doesn't use fast approximation
In [33]: opt_fn(max_rand)
Out[33]: tensor(-0., device='cuda:0', dtype=torch.float16) #fast log approx
Right. I guess we'd have similar issues even if we cast it to float?
Or is this a device-specific definition of exponential?
The current formula you have implemented should do. Comments #91673 (review) and #91673 (comment) are still relevant, though.
Apparently cpu incorrectly implements exponential in eager, but cuda exponential_ indeed doesn't contain zero, at least not until a large lambda causes underflow:
In [5]: torch.empty(100000000, device="cuda").exponential_().min()
Out[5]: tensor(5.9605e-08, device='cuda:0')
This is because the CPU implementation uses MKL (vdRngExponential(VSL_RNG_METHOD_EXPONENTIAL_ICDF, stream, len, ...)). I'm not sure whether vRngExponential excludes zero or not. If it doesn't, we should fix the MKL-based implementation from the PyTorch side. cc @CaoE
The link above says f(x) is defined for x >= a, where a is a displacement parameter.

pytorch/aten/src/ATen/native/cpu/DistributionKernels.cpp, lines 149 to 150 in b8057aa:
vdRngExponential(VSL_RNG_METHOD_EXPONENTIAL_ICDF, stream, len,
                 (double *)(sample_ptr + begin), 0, 1./lambda);

Here a is 0, so probably that's why cpu generated 0.
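For context, the MKL ICDF exponential generator (per its documentation, paraphrased here as an aside) samples the density with displacement a and scale β:

$$f(x) = \frac{1}{\beta}\, e^{-(x-a)/\beta}, \qquad x \ge a,$$

so with a = 0 and β = 1/λ the support starts at 0 inclusive, which is consistent with the CPU path occasionally returning exactly 0.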
torch/_decomp/decompositions.py
Outdated
@@ -1963,6 +1963,12 @@ def uniform(
    )

@register_decomposition(aten.exponential_)
what about casting halfs to higher precision for intermediate computations?
Decorating it with pw_cast_for_opmath should be good enough, right?
Oh, thanks for the catch, pw_cast_for_opmath should do. I'm checking if an ELEMENTWISE_TYPE_PROMOTION_KIND other than the default would be applicable.
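For intuition, a tiny standalone example of the upcast-compute-downcast pattern that pw_cast_for_opmath applies to pointwise decomps (plain torch code written for this thread, not the decorator itself):

import torch

def log1p_opmath(x: torch.Tensor) -> torch.Tensor:
    # Compute in float32 for reduced-precision inputs, then cast back.
    opmath = x.float() if x.dtype in (torch.half, torch.bfloat16) else x
    return torch.log1p(opmath).to(x.dtype)

u = torch.rand(4, dtype=torch.half)
print(log1p_opmath(-u))  # log1p evaluated in float32, result returned as half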
Cool!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Also, could you add the relevant OpInfo test?
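A hedged sketch of the kind of property such a test could assert (the name, tensor size, and placement here are assumptions for illustration, not the actual OpInfo entry):

import torch

def test_compiled_exponential_excludes_zero():
    # exponential_ samples from (0, inf); the pdf at 0 is 0, so no sample
    # should be exactly zero even when the decomp runs under torch.compile.
    fn = torch.compile(lambda t: t.exponential_())
    x = torch.empty(1_000_000, dtype=torch.float16, device="cuda")
    assert fn(x).min() > 0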
…113195) The range of the sampled random variable needs to be clarified for `torch.tensor.exponential_`, whose supported interval (0, inf) is different from the [0, inf) support of the exponential distribution. Background: #37984 (comment), #48841 (comment), #91673 (comment). Pull Request resolved: #113195 Approved by: https://github.com/albanD
Fixes #91276
cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire