
Unstable softmax trace when exporting to onnx #34585

Closed
Enolerobotti opened this issue Mar 11, 2020 · 3 comments
Labels
module: onnx (Related to torch.onnx), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

Enolerobotti commented Mar 11, 2020

🐛 Bug

An ONNX model containing a softmax, exported with the torch.onnx module, predicts NaNs instead of the expected numerical values. The discrepancy is not a small precision (mantissa) difference; the result is unstable. An expected value such as 5.0390004e-32 may come out either as NaN or as 5.0390004e-32, and even when that element converts correctly, another element of the tensor, e.g. 1.0000000e+00, may be converted to NaN instead.

To Reproduce

Please consider the following code to reproduce:

import torch
import torch.onnx
import torch.nn as nn
import torch.nn.functional as F
import onnxruntime as rt


class TraceCheck(nn.Module):
    def forward(self, x):
        # Softmax over the second-to-last dimension; the -100 scaling makes
        # exp() overflow in the exported Exp/ReduceSum/Div fallback graph.
        return F.softmax(-100 * x, dim=-2)


if __name__ == '__main__':
    net = TraceCheck()
    x = torch.randn(5, 5)
    pred = net(x).detach().cpu().numpy()
    tmp_filename = 'test.onnx'
    torch.onnx.export(net, x, tmp_filename)
    sess = rt.InferenceSession(tmp_filename)
    pred_onnx = sess.run([sess.get_outputs()[-1].name],
                         {sess.get_inputs()[-1].name: x.detach().cpu().numpy()})[0]
    print(pred)
    print(pred_onnx)

Example outputs are:
pred:

[[1.0000000e+00 0.0000000e+00 0.0000000e+00 1.0000000e+00 0.0000000e+00]
[0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
[0.0000000e+00 6.1350977e-01 0.0000000e+00 0.0000000e+00 0.0000000e+00]
[0.0000000e+00 3.8649023e-01 8.0113404e-19 0.0000000e+00 1.0000000e+00]
[1.1149160e-09 5.0298510e-32 1.0000000e+00 0.0000000e+00 0.0000000e+00]]

pred_onnx:

[[ nan 0.0000000e+00 0.0000000e+00 nan 0.0000000e+00]
[0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
[0.0000000e+00 6.1350977e-01 0.0000000e+00 0.0000000e+00 0.0000000e+00]
[0.0000000e+00 3.8649023e-01 nan 0.0000000e+00 nan]
[ nan 5.0298498e-32 nan 0.0000000e+00 0.0000000e+00]]

We see that pred != pred_onnx

Expected behavior

pred == pred_onnx

Environment

  • PyTorch Version: 1.4.0
  • OS: Windows 10 Pro 64 bit
  • How you installed PyTorch (conda, pip, source): current repo (git clone ...)
  • Build command you used (if compiling from source): setup.py build/install
  • Python version: 3.6
  • CUDA/cuDNN version: None
  • GPU models and configuration: None
  • Any other relevant information: opset 11

Additional context

cc @houseroad @spandantiwari @lara-hdr @BowenBao @neginraoof

@ngimel added the module: onnx and triaged labels Mar 11, 2020
@spandantiwari

@Enolerobotti - Do you see anything wrong with the exported ONNX model? Let us know if you do. If not, then this is most likely an implementation issue in ONNX Runtime. Could you please report it in the onnxruntime repo?

@Enolerobotti
Author

@spandantiwari Yes, I do see something wrong with the exported graph:

graph(%0 : Float(5, 5)):
  %1 : Float() = onnx::Constant[value={-100}]()
  %2 : Float(5, 5) = onnx::Mul(%0, %1)
  %3 : Tensor = onnx::Exp(%2)
  %4 : Tensor = onnx::ReduceSum[axes=[0]](%3)
  %5 : Float(5, 5) = onnx::Div(%3, %4) # c:\anaconda3\envs\python36b\lib\site-packages\torch\nn\functional.py:1231:0
  return (%5)

I guess the NaNs come from the division Div(%3, %4): Exp overflows to infinity for large inputs, and inf/inf is NaN. So this bug does not come from onnxruntime.
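
For illustration, here is a minimal sketch of that failure mode in plain numpy (not taken from the exported graph; the values are made up):

import numpy as np

# float32 exp overflows to inf for arguments above ~88.7, and inf / inf is NaN.
# A softmax that subtracts the maximum before exponentiating avoids this, but
# the decomposed Exp -> ReduceSum -> Div path does not.
x = np.float32(-1.2)              # a plausible sample from randn
z = np.exp(np.float32(-100) * x)  # exp(120) overflows to inf in float32
print(z)                          # inf
print(z / z)                      # nan -- the Div(%3, %4) failure mode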

I've found a workaround: replace the negative axis -2 in the softmax call with x.dim()-1. After the replacement I get correct results. The graph is now:

graph(%0 : Float(5, 5)):
  %1 : Float() = onnx::Constant[value={-100}]()
  %2 : Float(5, 5) = onnx::Mul(%0, %1)
  %3 : Float(5, 5) = onnx::Softmax[axis=1](%2) # c:\anaconda3\envs\python36b\lib\site-packages\torch\nn\functional.py:1231:0
  return (%3)
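
For reference, the workaround amounts to something like the following sketch (illustrative only, not the actual third-party model):

import torch
import torch.nn as nn
import torch.nn.functional as F


class TraceCheckWorkaround(nn.Module):
    # Passing a dim equal to ndim - 1 lets the exporter emit a single
    # onnx::Softmax node instead of the Exp/ReduceSum/Div fallback.
    # Note that for the 2D example above, dim=-2 (dim 0) and x.dim() - 1
    # (dim 1) are different dimensions, so the intended axis should be
    # checked against the original model.
    def forward(self, x):
        return F.softmax(-100 * x, dim=x.dim() - 1)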

Unfortunately, this workaround is not useful for me because I am actually trying to export a third-party network.

@p12tic
Contributor

p12tic commented Apr 27, 2020

The problem here is that the semantics of softmax differ between PyTorch and ONNX, so PyTorch can map onto the ONNX Softmax operator only in certain situations. #37326 improves the situation; a sketch of the difference follows below.

cc @spandantiwari
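
A rough sketch of the difference in plain numpy (assuming the old, opset <= 12 Softmax coercion; these are not the actual operator implementations):

import numpy as np

def onnx_softmax_coerced(x, axis):
    # Old ONNX semantics: coerce the input into a 2D tensor at `axis`,
    # apply softmax over each flattened row, then reshape back.
    rows = int(np.prod(x.shape[:axis]))
    x2d = x.reshape(rows, -1)
    e = np.exp(x2d - x2d.max(axis=1, keepdims=True))
    return (e / e.sum(axis=1, keepdims=True)).reshape(x.shape)

def torch_style_softmax(x, dim):
    # PyTorch semantics: normalize along the single dimension `dim`.
    e = np.exp(x - x.max(axis=dim, keepdims=True))
    return e / e.sum(axis=dim, keepdims=True)

x = np.random.randn(2, 3, 4).astype(np.float32)
print(np.allclose(onnx_softmax_coerced(x, 2), torch_style_softmax(x, 2)))  # True: last dim matches
print(np.allclose(onnx_softmax_coerced(x, 1), torch_style_softmax(x, 1)))  # generally False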

ShawnZhong pushed a commit to ShawnZhong/pytorch that referenced this issue May 5, 2020
…pytorch#37326)

Summary:
Fixes pytorch#34585.

This PR improves the workaround for the mismatch between ONNX softmax and PyTorch softmax semantics.

In PyTorch the `dim` parameter specifies the dimension over which the values are normalized. ONNX, on the other hand, always coerces the input into a 2D tensor, and the `axis` parameter specifies which dimensions form the rows and columns of the resulting tensor. As a result, the semantics are the same only when we are normalizing the last dimension (`dim == ndim - 1`).

Previously this was handled by recognizing the `dim == ndim - 1` case and using `softmax` for it. All other cases used a fallback path of explicit invocations of exp, reducesum, and div operators to compute the result. Unfortunately, this leads to numeric errors when the input values are large: exp produces infinity in both the numerator and the denominator, and the division then yields NaN.

This can be improved by transposing the input tensor so that we can reuse the ONNX softmax.

A similar approach was applied to the `logsoftmax` function in pytorch#30433.
Pull Request resolved: pytorch#37326

Reviewed By: hl475

Differential Revision: D21389712

Pulled By: houseroad

fbshipit-source-id: 554fd1b98231a28984c30c7e7abd3c0643386ff7
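
A minimal sketch of the transpose approach described above (illustrative, not the actual exporter code):

import numpy as np

def softmax_via_last_dim(x, dim):
    # Move the softmax dimension to the end, apply a last-dimension softmax
    # (which maps directly onto onnx::Softmax), then transpose back.
    x_t = np.swapaxes(x, dim, -1)
    e = np.exp(x_t - x_t.max(axis=-1, keepdims=True))  # numerically stable
    y_t = e / e.sum(axis=-1, keepdims=True)
    return np.swapaxes(y_t, dim, -1)

x = -100 * np.random.randn(5, 5).astype(np.float32)
print(softmax_via_last_dim(x, dim=-2))  # no NaNs, unlike the Exp/ReduceSum/Div fallback
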
bharatr21 pushed a commit to bharatr21/pytorch that referenced this issue May 5, 2020
…pytorch#37326)
