
Conversation

IvanYashchuk
Collaborator

This PR adds references for:

  • torch.softmax
  • torch.log_softmax
  • torch.logsumexp

Unfortunately, none of them currently pass test_python_ref_executor even with "aten" executor.

@facebook-github-bot
Contributor

facebook-github-bot commented Jun 13, 2022

🔗 Helpful links

✅ No Failures (0 Pending)

As of commit 8893024 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@ngimel
Collaborator

ngimel commented Jun 13, 2022

What are the errors with aten executor?

@mikaylagawarecki added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Jun 13, 2022
@IvanYashchuk
Collaborator Author

> What are the errors with aten executor?

For logsumexp, the Python code generated from the FX graph is invalid, leading to a SyntaxError.
For log_softmax and softmax, the error is due to #78923.

) -> TensorLikeType:
result_dtype = dtype or a.dtype
computation_dtype = utils.get_computation_dtype(a.dtype)
a = prims.convert_element_type(a, computation_dtype)
Collaborator

Let's make this conditional on the conversion being required -- we actually have a function for this if you'd prefer, but it's in an awkward place (it could be moved to utils if you'd rather use it vs. a custom conditional)

def _maybe_convert_to_dtype(

Collaborator Author

Thanks! _maybe_convert_to_dtype is already used in this file, so we can keep it defined in the awkward place for now 🙂
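
For illustration, a minimal sketch of the conditional conversion being suggested, written as a free-standing helper; the import paths and the helper's availability outside the refs file are assumptions here, not part of the PR's actual diff:

import torch
import torch._prims.utils as utils               # assumed module path at the time of this PR
from torch._refs import _maybe_convert_to_dtype  # assumed location of the helper

def to_computation_dtype_sketch(a: torch.Tensor) -> torch.Tensor:
    computation_dtype = utils.get_computation_dtype(a.dtype)
    # _maybe_convert_to_dtype leaves `a` untouched when a.dtype already equals the
    # target, replacing the unconditional prims.convert_element_type call above.
    return _maybe_convert_to_dtype(a, computation_dtype)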

computation_dtype = utils.get_computation_dtype(a.dtype)
a = prims.convert_element_type(a, computation_dtype)
a_max = amax(a, dim, keepdim=True)
shifted = a - a_max
Collaborator

Nit: comment for why the shift occurs would be nice

Collaborator Author

Oh, the shift is actually not required here because stabilized logsumexp is used. I will remove it.
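
(For context, a rough sketch of what log_softmax reduces to once the shift is dropped, using the public torch.logsumexp for brevity rather than the refs internals:)

import torch

def log_softmax_sketch(a: torch.Tensor, dim: int) -> torch.Tensor:
    # No explicit max-shift needed: logsumexp already subtracts the max internally.
    return a - torch.logsumexp(a, dim, keepdim=True)

# Example usage: agrees with torch.log_softmax
x = torch.randn(2, 3)
print(torch.allclose(log_softmax_sketch(x, dim=1), torch.log_softmax(x, dim=1)))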

a_max = amax(a, dim, keepdim=True)
shifted = a - a_max
shifted_logsumexp = logsumexp(shifted, dim, keepdim=True)
return prims.convert_element_type(
@mruberry (Collaborator) Jun 14, 2022

Let's also make this conditional on the conversion being required

)


def _squeeze_multiple(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
Collaborator

Unify this with prims.squeeze -- see

squeeze = _make_prim(

I think we only need one? But maybe this implementation is better for the prim?

Collaborator Author

I didn't notice that prims.squeeze works for multiple specified dimensions. Both implementations are fine, so I will not touch prims.squeeze at this time.
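
For reference, a sketch of how squeezing several dimensions can be expressed with repeated single-dimension squeezes; this is an illustration only, assuming dims are already canonicalized to non-negative indices, and is not either of the two implementations discussed:

import torch

def squeeze_multiple_sketch(a: torch.Tensor, dims) -> torch.Tensor:
    # Squeeze from the highest dim down so earlier squeezes don't shift
    # the indices of the dims still to be removed.
    for d in sorted(dims, reverse=True):
        a = torch.squeeze(a, d)
    return a

# Example: shape (1, 3, 1) squeezed over dims (0, 2) -> shape (3,)
print(squeeze_multiple_sketch(torch.ones(1, 3, 1), (0, 2)).shape)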

@out_wrapper
def logsumexp(
a: TensorLikeType,
dims: DimsType,
Collaborator
Collaborator Author

Right, it should be dim to match the torch namespace. I think I was confused because it also accepts and works for several dimensions:

dim (int or tuple of ints)

import torch
a = torch.ones(3, 3)
torch.logsumexp(a, (0, 1))
# tensor(3.1972)

keepdim: bool = False,
) -> TensorLikeType:
dims = utils.canonicalize_dims(a.ndim, dims)
# ATen specifies int[1] type dims which expands integers to tuples of length 1
Collaborator

Nice comment

a_max_squeezed = _squeeze_multiple(a_max, dims) if not keepdim else a_max
result = log(sum(exp(a - a_max), dims, keepdim=keepdim)) + a_max_squeezed
else:
result = log(sum(exp(a), dims, keepdim=keepdim))
Collaborator

Add a comment for what this case covers (integer and boolean dtypes)

dims = (dims,)
if utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype):
a_max = amax(a, dims, keepdim=True)
a_max = where(abs(a_max) == float("inf"), 0.0, a_max)
Collaborator

really elegant code here
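
Piecing the review fragments together, the logsumexp reference roughly follows the shape below. This is a simplified sketch using public torch ops instead of the refs/prims internals, and it omits the out= wrapper and the dim canonicalization that the real code handles:

import torch

def logsumexp_sketch(a: torch.Tensor, dims, keepdim: bool = False) -> torch.Tensor:
    # ATen specifies int[1]-style dims, so a bare integer behaves like a 1-tuple.
    if isinstance(dims, int):
        dims = (dims,)
    if a.dtype.is_floating_point or a.dtype.is_complex:
        # Stabilized path: shift out the per-slice max before exponentiating,
        # and zero out infinite maxima so exp(a - a_max) stays well defined.
        a_max = torch.amax(a, dims, keepdim=True)
        a_max = torch.where(a_max.abs() == float("inf"), torch.zeros_like(a_max), a_max)
        result = torch.log(torch.sum(torch.exp(a - a_max), dims, keepdim=keepdim))
        if not keepdim:
            for d in sorted(dims, reverse=True):
                a_max = a_max.squeeze(d)
        return result + a_max
    # Remaining dtypes (integer and boolean) take the unshifted path;
    # exp promotes them to the default float dtype.
    return torch.log(torch.sum(torch.exp(a), dims, keepdim=keepdim))

# Matches torch.logsumexp on a simple example: both print tensor(3.1972)
a = torch.ones(3, 3)
print(logsumexp_sketch(a, (0, 1)), torch.logsumexp(a, (0, 1)))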

) -> TensorLikeType:
result_dtype = dtype or a.dtype
computation_dtype = utils.get_computation_dtype(a.dtype)
a = prims.convert_element_type(a, computation_dtype)
Collaborator

Similar comments here as with log_softmax re: conditional conversions
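
Applying the same idea to softmax, a simplified sketch with the conversions made conditional. The dtype handling here approximates get_computation_dtype and _maybe_convert_to_dtype with plain checks and .to() calls, so treat it as an illustration of the suggestion rather than the PR's code:

from typing import Optional
import torch

def softmax_sketch(a: torch.Tensor, dim: int, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    result_dtype = dtype or a.dtype
    # Stand-in for utils.get_computation_dtype: widen low-precision floats.
    computation_dtype = (
        torch.float32 if a.dtype in (torch.float16, torch.bfloat16) else a.dtype
    )
    if a.dtype != computation_dtype:  # convert in only when required
        a = a.to(computation_dtype)
    a_max = torch.amax(a, dim, keepdim=True)  # shift for numerical stability
    a_exp = torch.exp(a - a_max)
    result = a_exp / torch.sum(a_exp, dim, keepdim=True)
    if result.dtype != result_dtype:  # convert out only when required
        result = result.to(result_dtype)
    return result

# Example usage: agrees with torch.softmax on a float32 input
x = torch.randn(4, 5)
print(torch.allclose(softmax_sketch(x, dim=1), torch.softmax(x, dim=1)))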

@mruberry (Collaborator) left a comment

This is awesome, @IvanYashchuk! I made some inline comments for your review, nothing major, and the lint job needs to be fixed, but approving this for velocity because I'm sure you'll sort out the review and the jobs.

@IvanYashchuk
Collaborator Author

_maybe_convert_to_dtype is now used for the optional conversions. I removed the unnecessary shift from log_softmax and fixed the mypy errors. In the case where _maybe_convert_to_dtype produces the returned value, the mypy errors are silenced.

@IvanYashchuk
Collaborator Author

@pytorchbot merge -g

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

@github-actions
Contributor

Hey @IvanYashchuk.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Jun 16, 2022
#79423)

Summary:
This PR adds references for:

- `torch.softmax`
- `torch.log_softmax`
- `torch.logsumexp`

Unfortunately, none of them currently pass `test_python_ref_executor` even with `"aten"` executor.

Pull Request resolved: #79423
Approved by: https://github.com/mruberry

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/4fc7832d72926d65d773ae4e4ae0ed7fc573f0c7

Reviewed By: malfet

Differential Revision: D37156829

fbshipit-source-id: 88a1ed3d42fda30b880d8a2fe48f385ebdb98d22
Labels
cla signed · Merged · module: primTorch · open source · triaged