Unstable softmax trace when exporting to ONNX #34585
Comments
@Enolerobotti - Do you see anything wrong with the exported ONNX model? Let us know if you do. If not, then this is most likely an implementation issue in ORT. Could you please report this in the onnxruntime repo?
@spandantiwari yes, I see something wrong with the exported graph.
I guess I get NaN during the division Div(%3, %4), so this bug does not come from onnxruntime. I've found a workaround: replace the negative axis -2 in the softmax with x.dim()-1. After the replacement I get correct results. Now the graph is: [exported graph image not preserved]
Unfortunately, this workaround is not useful for me, because I am actually trying to export a third-party net.
The problem here is that the semantics of softmax are different in PyTorch and ONNX, and PyTorch can use ONNX softmax only in certain situations. #37326 improves the situation.
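For context, here is a minimal sketch (not from the original thread) of the semantic difference. Before opset 13, ONNX `Softmax` coerces the input into a 2D tensor at `axis` and normalizes each row, so it matches PyTorch only when the normalized dimension is the last one. The helper `onnx_softmax_opset11` below is an illustrative emulation, not an actual API:

```python
import torch

def onnx_softmax_opset11(x, axis):
    # Emulate pre-opset-13 ONNX Softmax: flatten to 2D where rows are
    # the dims before `axis` and columns are the dims from `axis` on,
    # then normalize each row.
    rows = 1
    for d in x.shape[:axis]:
        rows *= d
    flat = x.reshape(rows, -1)
    return torch.softmax(flat, dim=1).reshape(x.shape)

x = torch.randn(2, 3, 4)
# Last dimension: the two semantics agree.
print(torch.allclose(torch.softmax(x, dim=2), onnx_softmax_opset11(x, 2)))  # True
# Inner dimension: ONNX normalizes over dims 1 and 2 jointly, PyTorch
# over dim 1 only, so the results differ.
print(torch.allclose(torch.softmax(x, dim=1), onnx_softmax_opset11(x, 1)))  # False
```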
…pytorch#37326) Summary: Fixes pytorch#34585. This PR improves the workaround for the problem of differing semantics between ONNX softmax and PyTorch softmax. In PyTorch the `dim` parameter specifies over which dimension to normalize the values. ONNX, on the other hand, always coerces the input into a 2D tensor, and the `axis` parameter specifies which dimensions represent the rows and columns of the resulting tensor. As a result, the semantics are the same only when we are normalizing the last dimension (`dim == ndim - 1`). Previously this was handled by recognizing the `dim == ndim - 1` case and using `softmax` for it; all other cases used a fallback path of explicit invocations of exp, reducesum, and div operators to compute the result. Unfortunately, this produces numeric errors when input values are large: exp produces infinity in both the numerator and the denominator, and dividing the two yields NaN. This can be improved by transposing the input tensor so that we can reuse ONNX softmax. A similar approach was applied to the `logsoftmax` function in pytorch#30433. Pull Request resolved: pytorch#37326 Reviewed By: hl475 Differential Revision: D21389712 Pulled By: houseroad fbshipit-source-id: 554fd1b98231a28984c30c7e7abd3c0643386ff7
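The following is an illustrative sketch of the transpose-based strategy the commit message describes; it is not the actual symbolic function in `torch/onnx`. The idea is to move the target dimension to the end, where ONNX Softmax semantics match PyTorch and runtime implementations are numerically stable, then move it back:

```python
import torch

def softmax_via_last_dim(x, dim):
    # Normalize `dim` to a non-negative index.
    dim = dim % x.dim()
    if dim == x.dim() - 1:
        # Semantics already match ONNX Softmax; export directly.
        return torch.softmax(x, dim=dim)
    # Move the target dimension to the end, normalize, move it back.
    x_t = x.transpose(dim, -1)
    y_t = torch.softmax(x_t, dim=-1)  # stable: implementations subtract the max
    return y_t.transpose(dim, -1)

x = torch.randn(2, 3, 4) * 1000                        # large values
print(torch.isnan(softmax_via_last_dim(x, 1)).any())   # tensor(False)
```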
🐛 Bug
An ONNX model containing softmax, exported with the torch.onnx module, predicts NaNs instead of the expected numerical values. Whether a given element becomes NaN does not depend on its mantissa; the result is unstable: an expected value of 5.0390004e-32 may be converted either to NaN or to 5.0390004e-32, and if 5.0390004e-32 happens to be converted correctly, another element of the tensor, e.g. 1.0000000e+00, might be converted to NaN instead.
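As diagnosed in the fix (see the commit message above), the NaNs come from the exporter's naive exp / reducesum / div fallback. A minimal illustration (not from the issue) of that failure mode:

```python
import numpy as np

x = np.array([1000.0, 1000.0])
e = np.exp(x)                # overflows to inf
print(e / e.sum())           # inf / inf -> [nan nan]
# The standard stable formulation subtracts the max first:
s = np.exp(x - x.max())
print(s / s.sum())           # [0.5 0.5]
```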
To Reproduce
Please consider the following code to reproduce the issue.
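The original snippet is not preserved in this copy. Below is a hypothetical reconstruction from the discussion (the model, tensor shapes, scaling, and file name are assumptions); with a PyTorch version from the time of the report, the dim=-2 softmax goes through the unstable fallback path:

```python
import numpy as np
import torch
import onnxruntime as ort

class Net(torch.nn.Module):
    def forward(self, x):
        # dim=-2 (a non-last dimension) routed the export through the
        # unstable exp / reducesum / div fallback before #37326.
        return torch.softmax(x, dim=-2)

model = Net().eval()
x = torch.randn(1, 3, 4) * 100  # large magnitudes make exp() overflow

pred = model(x).numpy()

torch.onnx.export(model, (x,), "softmax.onnx",
                  input_names=["x"], output_names=["y"])
sess = ort.InferenceSession("softmax.onnx",
                            providers=["CPUExecutionProvider"])
pred_onnx = sess.run(None, {"x": x.numpy()})[0]

print("pred:", pred)
print("pred_onnx:", pred_onnx)
print("equal:", np.allclose(pred, pred_onnx))
```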
Example outputs are:
pred: [expected tensor values; original output not preserved]
pred_onnx: [output containing NaNs; original output not preserved]
We see that pred != pred_onnx.
Expected behavior
pred == pred_onnx
Environment
Additional context
cc @houseroad @spandantiwari @lara-hdr @BowenBao @neginraoof