
[ONNX] Fix numerical errors in softmax when dim is not last dimension #37326

Closed

Conversation


@p12tic p12tic commented Apr 26, 2020

Fixes #34585.

This PR improves the workaround for the mismatch in semantics between ONNX softmax and PyTorch softmax.

In PyTorch, the `dim` parameter specifies the dimension over which the values are normalized. ONNX, on the other hand, always coerces the input into a 2D tensor, and the `axis` parameter specifies which dimensions become the rows and columns of that tensor. As a result, the semantics match only when the last dimension is normalized (`dim == ndim - 1`).
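
To make the difference concrete, here is a minimal NumPy sketch (not part of this PR; the function names are purely illustrative) of the two behaviours:

```python
import numpy as np

def onnx_style_softmax(x, axis):
    # ONNX Softmax (as specified at the time): coerce to 2D, with the dimensions
    # before `axis` forming the rows and the remaining dimensions the columns,
    # then normalize each row.
    rows = int(np.prod(x.shape[:axis]))
    flat = x.reshape(rows, -1)
    e = np.exp(flat - flat.max(axis=1, keepdims=True))
    return (e / e.sum(axis=1, keepdims=True)).reshape(x.shape)

def pytorch_style_softmax(x, dim):
    # PyTorch: normalize along a single dimension only.
    e = np.exp(x - x.max(axis=dim, keepdims=True))
    return e / e.sum(axis=dim, keepdims=True)

x = np.random.randn(2, 3, 4)
print(np.allclose(onnx_style_softmax(x, 2), pytorch_style_softmax(x, 2)))  # True  (last dimension)
print(np.allclose(onnx_style_softmax(x, 1), pytorch_style_softmax(x, 1)))  # False (any other dimension)
```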

Previously this was handled by recognizing the `dim == ndim - 1` case and emitting ONNX `softmax` for it. All other cases used a fallback path that computed the result via explicit `exp`, `reducesum` and `div` operators. Unfortunately, this produces numeric errors when the input values are large: `exp` overflows to infinity in both the numerator and the denominator, and the division then yields NaN.
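
For illustration only (not code from the PR), a small NumPy sketch of why the fallback path breaks down for large inputs, compared with the max-normalized form that a numerically stable softmax uses:

```python
import numpy as np

x = np.array([1000.0, 1000.0, 1000.0])

# Fallback path: exp overflows to inf, and inf / inf yields NaN.
e = np.exp(x)
naive = e / e.sum()      # [nan nan nan]

# Max-normalized softmax: subtracting the maximum keeps exp finite.
e = np.exp(x - x.max())
stable = e / e.sum()     # [0.333... 0.333... 0.333...]

print(naive)
print(stable)
```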

This can be improved by transposing the input tensor so that we can reuse ONNX softmax.
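
A hedged sketch of the transpose idea in plain PyTorch terms (the actual change lives in torch/onnx/symbolic_opset9.py and emits the equivalent ONNX operators; the helper name below is hypothetical):

```python
import torch

def softmax_via_last_dim(x, dim):
    # Swap `dim` with the last dimension, where ONNX and PyTorch semantics agree,
    # apply softmax there, then swap the dimensions back.
    x_t = x.transpose(dim, -1)
    y_t = torch.softmax(x_t, dim=-1)
    return y_t.transpose(dim, -1)

x = torch.randn(2, 3, 4) * 1000  # large values that break the exp/sum/div fallback
assert torch.allclose(softmax_via_last_dim(x, 1), torch.softmax(x, dim=1))
```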

A similar approach was applied to the `logsoftmax` function in #30433.


@dr-ci dr-ci bot commented Apr 27, 2020

💊 Build failures summary and remediations

As of commit 2a9c810 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



@mrshenli mrshenli requested a review from houseroad Apr 27, 2020
@p12tic p12tic force-pushed the onnx-softmax-fix-numerical-errors branch from 63c4108 to ea8f934 Apr 28, 2020

@BowenBao BowenBao left a comment

Thanks @p12tic for the fix! Overall looks good, please see inline comments for some additional updates.


test/onnx/test_pytorch_onnx_onnxruntime.py
torch/onnx/symbolic_opset9.py

@BowenBao BowenBao commented May 1, 2020

@p12tic could you take a look at the comments and update your PR? Thanks!



@p12tic p12tic commented May 2, 2020

@BowenBao Yes, I will do so. My priorities were elsewhere this week, hence the delay.


@p12tic p12tic force-pushed the onnx-softmax-fix-numerical-errors branch from ea8f934 to 2a9c810 May 3, 2020

@p12tic p12tic commented May 3, 2020

@BowenBao I've applied your suggestions to the PR. Thanks for suggesting the changes instead of making them yourself; I got to learn what max normalization in softmax is :-)



@BowenBao BowenBao left a comment

@p12tic Thanks for the fix! LGTM



@facebook-github-bot facebook-github-bot left a comment

@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.



@facebook-github-bot facebook-github-bot commented May 5, 2020

@houseroad merged this pull request in d16c823.


ShawnZhong added a commit to ShawnZhong/pytorch that referenced this issue May 5, 2020
…pytorch#37326)

Summary:
Fixes pytorch#34585.

This PR improves the workaround for the mismatch in semantics between ONNX softmax and PyTorch softmax.

In PyTorch, the `dim` parameter specifies the dimension over which the values are normalized. ONNX, on the other hand, always coerces the input into a 2D tensor, and the `axis` parameter specifies which dimensions become the rows and columns of that tensor. As a result, the semantics match only when the last dimension is normalized (`dim == ndim - 1`).

Previously this was handled by recognizing the `dim == ndim - 1` case and emitting ONNX `softmax` for it. All other cases used a fallback path that computed the result via explicit `exp`, `reducesum` and `div` operators. Unfortunately, this produces numeric errors when the input values are large: `exp` overflows to infinity in both the numerator and the denominator, and the division then yields NaN.

This can be improved by transposing the input tensor so that we can reuse ONNX softmax.

A similar approach was applied to the `logsoftmax` function in pytorch#30433.
Pull Request resolved: pytorch#37326

Reviewed By: hl475

Differential Revision: D21389712

Pulled By: houseroad

fbshipit-source-id: 554fd1b98231a28984c30c7e7abd3c0643386ff7
Bharat123rox added a commit to Bharat123rox/pytorch that referenced this issue May 5, 2020
…pytorch#37326)
