Reference implementations for softmax, log_softmax, logsumexp #79423
Conversation
Dr. CI: ✅ No failures (0 pending) as of commit 8893024. 💚 Looks good so far! There are no failures yet. (This comment was automatically generated by Dr. CI.)
What are the errors with the aten executor?

For …
torch/_refs/__init__.py (Outdated)

) -> TensorLikeType:
    result_dtype = dtype or a.dtype
    computation_dtype = utils.get_computation_dtype(a.dtype)
    a = prims.convert_element_type(a, computation_dtype)
Let's make this conditional on the conversion being required -- we actually have a function for this if you'd prefer, but it's in an awkward place (it could be moved to utils if you'd rather use it vs. a custom conditional)
pytorch/torch/_prims/wrappers.py, line 20 in 38e717d:

    def _maybe_convert_to_dtype(
Thanks! _maybe_convert_to_dtype is already used in this file, so we can keep it defined in the awkward place for now 🙂
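As a side note for readers, here is a minimal sketch of the conditional-conversion idea in plain PyTorch (a hypothetical helper; the actual `_maybe_convert_to_dtype` in torch/_prims/wrappers.py handles more cases):

```python
import torch

def maybe_convert_to_dtype(a: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Only emit a conversion when the input is not already in the target dtype,
    # so traces of the reference don't contain no-op casts.
    if a.dtype != dtype:
        return a.to(dtype)  # stand-in for prims.convert_element_type
    return a
```

In the refs, both the upcast to the computation dtype and the final cast back to result_dtype can go through a guard like this.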
torch/_refs/__init__.py (Outdated)

    computation_dtype = utils.get_computation_dtype(a.dtype)
    a = prims.convert_element_type(a, computation_dtype)
    a_max = amax(a, dim, keepdim=True)
    shifted = a - a_max
Nit: a comment explaining why the shift occurs would be nice
Oh, the shift is actually not required here because stabilized logsumexp is used. I will remove it.
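A minimal plain-PyTorch sketch of that point (not the _refs code itself): once logsumexp is already stabilized internally, log_softmax needs no extra shift:

```python
import torch

def log_softmax_sketch(a: torch.Tensor, dim: int) -> torch.Tensor:
    # torch.logsumexp already subtracts the per-slice max internally,
    # so an explicit `a - a.amax(dim, keepdim=True)` shift is redundant here.
    return a - torch.logsumexp(a, dim, keepdim=True)

x = torch.randn(4, 5)
assert torch.allclose(log_softmax_sketch(x, -1), torch.log_softmax(x, -1))
```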
torch/_refs/__init__.py (Outdated)

    a_max = amax(a, dim, keepdim=True)
    shifted = a - a_max
    shifted_logsumexp = logsumexp(shifted, dim, keepdim=True)
    return prims.convert_element_type(
Let's also make this conditional on the conversion being required
torch/_refs/__init__.py (Outdated)

    )


def _squeeze_multiple(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
Unify this with prims.squeeze -- see
pytorch/torch/_prims/__init__.py, line 1636 in 38e717d:

    squeeze = _make_prim(
I think we only need one? But maybe this implementation is better for the prim?
I didn't notice that prims.squeeze works for multiple specified dimensions. Both implementations are fine, so I will not touch prims.squeeze at this time.
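For illustration, a small standalone sketch of squeezing several dimensions in plain PyTorch (a hypothetical helper named `squeeze_multiple` here, not the code in this PR):

```python
import torch

def squeeze_multiple(a: torch.Tensor, dims) -> torch.Tensor:
    # Canonicalize negative dims, then squeeze from the highest index down so
    # earlier squeezes don't shift the positions of the remaining dims.
    for d in sorted((d % a.ndim for d in dims), reverse=True):
        if a.shape[d] == 1:
            a = a.squeeze(d)
    return a

x = torch.randn(1, 3, 1, 4)
print(squeeze_multiple(x, (0, -2)).shape)  # torch.Size([3, 4])
```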
torch/_refs/__init__.py (Outdated)

@out_wrapper
def logsumexp(
    a: TensorLikeType,
    dims: DimsType,
Right, it should be dim to match the torch namespace. I think I was confused because it also accepts and works for several dimensions: dim (int or tuple of python:ints).

    import torch
    a = torch.ones(3, 3)
    torch.logsumexp(a, (0, 1))
    # tensor(3.1972)
    keepdim: bool = False,
) -> TensorLikeType:
    dims = utils.canonicalize_dims(a.ndim, dims)
    # ATen specifies int[1] type dims which expands integers to tuples of length 1
Nice comment
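As a small illustration of what that comment refers to (a hypothetical standalone helper, not utils.canonicalize_dims itself): a bare integer dim behaves like a one-element tuple, mirroring ATen's int[1] type, and negative dims are wrapped:

```python
def canonicalize_dims_sketch(ndim: int, dims):
    # Mirror ATen's int[1] behaviour: a bare integer becomes a 1-tuple,
    # and negative dims are mapped into range(ndim).
    if isinstance(dims, int):
        dims = (dims,)
    return tuple(d % ndim for d in dims)

print(canonicalize_dims_sketch(3, -1))       # (2,)
print(canonicalize_dims_sketch(3, (0, -1)))  # (0, 2)
```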
torch/_refs/__init__.py (Outdated)

        a_max_squeezed = _squeeze_multiple(a_max, dims) if not keepdim else a_max
        result = log(sum(exp(a - a_max), dims, keepdim=keepdim)) + a_max_squeezed
    else:
        result = log(sum(exp(a), dims, keepdim=keepdim))
Add a comment for what this case covers (integer and boolean dtypes)
        dims = (dims,)
    if utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype):
        a_max = amax(a, dims, keepdim=True)
        a_max = where(abs(a_max) == float("inf"), 0.0, a_max)
really elegant code here
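Putting those pieces together, here is a plain-PyTorch sketch of the same stabilization (hypothetical, not the _refs implementation; complex inputs are omitted): the per-slice max is subtracted before exponentiating, an infinite max is zeroed out so that inf - inf does not produce NaN, and the non-float branch skips the shift entirely:

```python
import torch

def logsumexp_sketch(a: torch.Tensor, dims, keepdim: bool = False) -> torch.Tensor:
    if a.is_floating_point():
        # Shift by the max so exp() cannot overflow; replace an infinite max
        # with 0, otherwise `a - a_max` would yield inf - inf = nan.
        a_max = a.amax(dims, keepdim=True)
        a_max = torch.where(a_max.abs() == float("inf"), torch.zeros_like(a_max), a_max)
        out = (a - a_max).exp().sum(dims, keepdim=True).log() + a_max
        return out if keepdim else out.squeeze(dims)  # tuple-dim squeeze needs PyTorch >= 2.0
    # Integer and boolean inputs: exp() promotes them to float, and the
    # max-shift stabilization is skipped for this case.
    return a.exp().sum(dims, keepdim=keepdim).log()

x = torch.randn(2, 3, 4)
assert torch.allclose(logsumexp_sketch(x, (1, 2)), torch.logsumexp(x, (1, 2)))
```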
torch/_refs/__init__.py (Outdated)

) -> TensorLikeType:
    result_dtype = dtype or a.dtype
    computation_dtype = utils.get_computation_dtype(a.dtype)
    a = prims.convert_element_type(a, computation_dtype)
Similar comments here as with log_softmax re: conditional conversions
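The same pattern for softmax, as a hedged plain-PyTorch sketch (hypothetical; the real ref goes through prims and utils.get_computation_dtype): upcast only when the computation dtype differs, compute, then cast back to the requested result dtype only if needed:

```python
from typing import Optional
import torch

def softmax_sketch(a: torch.Tensor, dim: int, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    result_dtype = dtype or a.dtype
    # Assumed upcast rule: accumulate low-precision floats in float32.
    computation_dtype = torch.float32 if a.dtype in (torch.float16, torch.bfloat16) else a.dtype
    if a.dtype != computation_dtype:  # conditional conversion, as suggested above
        a = a.to(computation_dtype)
    a_exp = (a - a.amax(dim, keepdim=True)).exp()
    result = a_exp / a_exp.sum(dim, keepdim=True)
    return result if result.dtype == result_dtype else result.to(result_dtype)

x = torch.randn(4, 5)
print(torch.allclose(softmax_sketch(x, -1), torch.softmax(x, -1)))  # True
```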
This is awesome, @IvanYashchuk! I made some inline comments for your review, nothing major, and the lint job needs to be fixed, but approving this for velocity because I'm sure you'll sort out the review and the jobs.
Forward AD test fails with new sample input that reduces over multiple dims
@pytorchbot merge -g

@pytorchbot successfully started a merge job. Check the current status here.

Hey @IvanYashchuk.
Reference implementations for softmax, log_softmax, logsumexp (#79423)

Summary: This PR adds references for:
- `torch.softmax`
- `torch.log_softmax`
- `torch.logsumexp`

Unfortunately, none of them currently pass `test_python_ref_executor` even with the `"aten"` executor.

Pull Request resolved: #79423
Approved by: https://github.com/mruberry
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/4fc7832d72926d65d773ae4e4ae0ed7fc573f0c7
Reviewed By: malfet
Differential Revision: D37156829
fbshipit-source-id: 88a1ed3d42fda30b880d8a2fe48f385ebdb98d22
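As a rough usage sketch (assuming the module path torch._refs and the dim keyword discussed above; the real coverage comes from the OpInfo-based test_python_ref tests), each reference is expected to agree numerically with its ATen counterpart:

```python
import torch
import torch._refs as refs

x = torch.randn(3, 4)

# Each reference should match the eager ATen implementation.
torch.testing.assert_close(refs.softmax(x, dim=-1), torch.softmax(x, dim=-1))
torch.testing.assert_close(refs.log_softmax(x, dim=-1), torch.log_softmax(x, dim=-1))
torch.testing.assert_close(refs.logsumexp(x, dim=-1), torch.logsumexp(x, dim=-1))
```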