OpInfo for nn.functional.softmax #62077
Conversation
💊 CI failures summary and remediations: as of commit bd9a1d3 (more details on the Dr. CI page): 💚 Looks good so far! There are no failures yet. 💚
This looks pretty good @krshrimali; I have one suggestion (inline) for how to tweak testing the "dtype" kwarg.
A few updates:

- This PR now uses the sample inputs function of `log_softmax` for `softmax`.
- Added an alias for `softmax` (a rough sketch of such an entry follows below).
- Code clean-up for some OpInfos where params were passed with default values (which isn't needed).
- Skip removal for the `test_jit_alias_remapping` test of `log_softmax`.
cc: @mruberry @zou3519 (sorry for the ping over the weekend, please review whenever you find time)
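For reference, a minimal sketch of how such an aliased entry might be wired up in common_methods_invocations.py (the dtype list and flags here are illustrative assumptions, not the exact lines from this PR):

# Hypothetical sketch of an op_db entry: `softmax` reuses the shared
# sample-inputs helper and registers nn.functional.softmax as an alias.
OpInfo('softmax',
       aliases=('nn.functional.softmax',),
       dtypes=floating_types_and(torch.bfloat16),
       sample_inputs_func=sample_inputs_softmax_variant,
       supports_out=False),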
Update: I'm taking a look at the XLA error; if it's something non-trivial, I'll probably remove the scalar input and file an issue for this error. Also, this is only reproducible on XLA (tested on Google Colab) with scalar tensors:

# Works on CPU
>>> torch.log_softmax(torch.rand((), device='cpu'), dim=0)
tensor(0.)

# Fails on XLA
>>> torch.log_softmax(torch.rand((), device='xla'), dim=0)
ERROR (see the full error below; it's too verbose to inline)

Error on XLA:
RuntimeError Traceback (most recent call last)
<ipython-input-4-470607feefbc> in <module>()
----> 1 torch.log_softmax(torch.rand((), device='xla'), dim=0)
RuntimeError: torch_xla/csrc/helpers.cpp:97 : Check failed: min_shape_dim <= dim && dim <= max_shape_dim
*** Begin stack trace ***
tensorflow::CurrentStackTrace()
torch_xla::XlaHelpers::GetCanonicalDimensionIndex(long, long)
torch_xla::XLATensor::log_softmax(torch_xla::XLATensor const&, long, c10::optional<c10::ScalarType>)
torch_xla::AtenXlaType::_log_softmax(at::Tensor const&, long, bool)
c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(at::Tensor const&, long, bool), at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, long, bool> >, at::Tensor (at::Tensor const&, long, bool)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, long, bool)
at::Tensor c10::Dispatcher::redispatch<at::Tensor, at::Tensor const&, long, bool>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, long, bool)> const&, c10::DispatchKeySet, at::Tensor const&, long, bool) const
at::redispatch::_log_softmax(c10::DispatchKeySet, at::Tensor const&, long, bool)
at::_log_softmax(at::Tensor const&, long, bool)
at::native::log_softmax(at::Tensor const&, long, c10::optional<c10::ScalarType>)
at::Tensor::log_softmax(long, c10::optional<c10::ScalarType>) const
              ... (repeated Python interpreter frames elided) ...
*** End stack trace ***
Value out of range (expected to be in range of [0, -1], but got 0)

Also, the error message doesn't look correct, since [0, -1] is an empty range; for a 0d tensor the expected range should be [-1, 0]. Will update once I've finished building XLA locally to test the macros used. Thanks!
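For context, CPU treats a 0d tensor as one-dimensional when canonicalizing dim, so both dim=0 and dim=-1 are accepted there (a quick illustrative check):

# CPU wraps dim for 0d tensors into the range [-1, 0], so both succeed;
# log-softmax of a single element is always log(1.0) == 0.
>>> torch.log_softmax(torch.rand(()), dim=0)
tensor(0.)
>>> torch.log_softmax(torch.rand(()), dim=-1)
tensor(0.)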
Updates: PyTorch on 0d tensors with dim=0 doesn't throw an error, but XLA does, so that case is now handled separately. This PR should be ready for review; hopefully the tests should pass. Thanks!
Two nits inline, otherwise LGTM! Thanks @krshrimali.
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
def generator():
    for shape, args, kwargs in cases:
        yield SampleInput(make_arg(shape), args=args, kwargs=kwargs)

cases = (
This is a pathological case, but could we add `((), (0,))` to the list?
Thanks for the comment @zou3519! I had added this before, but the test fails on the XLA device: #62077 (comment). PyTorch on 0d tensors with `dim=0` doesn't throw an error, but on XLA it does.
Gotcha, sorry for not seeing that! The action items you proposed (file an issue, leave that case out of the OpInfo) sgtm
Thanks, @zou3519! After discussing with @mruberry, we thought we'd add a separate case when the device type isn't XLA, so that we don't skip this input for CPU/CUDA devices. I've also filed an issue here: pytorch/xla#3061.
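Concretely, the guard is something like this (a minimal sketch, assuming the helper builds a local `cases` list):

# Keep the 0d-tensor case for CPU/CUDA, but drop it on XLA,
# which rejects a dim argument for 0d tensors (pytorch/xla#3061).
if torch.device(device).type != 'xla':
    cases.append(((), (0,)))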
This looks pretty good. I added some comments about some more cases for completeness; after that we should be good to go.
Lint is failing: https://github.com/pytorch/pytorch/pull/62077/checks?check_run_id=3191463113
Thanks, @zou3519, for the pointer; I've fixed it now. Hopefully the tests should pass :)
@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
for shape, args, kwargs in cases:
    yield SampleInput(make_arg(shape), args=args, kwargs=kwargs)

# PyTorch on XLA throws an error when passed with dim argument for 0d tensor.
# See https://github.com/pytorch/xla/issues/3061 for more details.
Nice comment
- return list(generator())
+ return [
+     SampleInput(make_arg(shape), args=dim, kwargs=dict(dtype=torch.float64) if with_dtype else None)
+     for shape, dim in cases
Style nit: put the for loop first for readability (doesn't have to be changed in this PR)
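One reading of this nit, as an explicit loop (an assumption about the intended shape; the helper's local names are reused):

# Writing the loop first lets the iteration read before the long
# SampleInput expression; behavior is identical to the comprehension.
samples = []
for shape, dim in cases:
    sample_kwargs = dict(dtype=torch.float64) if with_dtype else None
    samples.append(SampleInput(make_arg(shape), args=dim, kwargs=sample_kwargs))
return samples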
@@ -3669,19 +3669,27 @@ def sample_inputs_to_sparse(op_info, device, dtype, requires_grad, **kwargs):
     SampleInput(make_arg((S, S)), args=(1,), output_process_fn_grad=lambda x: x.to_dense()),)


-def sample_inputs_log_softmax(op_info, device, dtype, requires_grad, with_dtype=False, **kwargs):
+# Used for both log_softmax and softmax
+def sample_inputs_softmax_variant(op_info, device, dtype, requires_grad, with_dtype=False, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
"with_dtype" should be kwarg-only
@@ -6372,6 +6380,24 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
         dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
         supports_forward_ad=True,
         sample_inputs_func=sample_inputs_max_min_binary,),
+    # `softmax` supports different dtypes based on whether `dtype` argument,
Nice comment
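For completeness, the dtype-variant entry the comment refers to might look roughly like this (a hedged sketch; the exact dtype lists in the PR may differ):

# Hypothetical sketch: supplying `dtype` upcasts the computation, so this
# second entry can exercise a wider set of input dtypes.
OpInfo('softmax',
       variant_test_name='with_dtype',
       aliases=('nn.functional.softmax',),
       dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
       sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
       supports_out=False),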
Really nice skip removal. Overall looks good. I made a few comments inline to consider in future PRs; no changes needed for this one.
This PR:

- Adds an OpInfo for `softmax` and `nn.functional.softmax` (alias).
- Removes the skip for the `test_jit_alias_remapping` test of `log_softmax`.

Please see pytorch/functorch#78 and #54261.
cc: @mruberry @zou3519 @pmeier