
ONNX export of CumSum produces different data type #77842

Open
mergian opened this issue May 19, 2022 · 1 comment
Labels
module: onnx Related to torch.onnx onnx-triaged triaged by ONNX team triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments


mergian commented May 19, 2022

🐛 Describe the bug

When running a model that calls torch.cumsum(..., dtype=None), an integer input is promoted to torch.int64. However, when the same model is exported to ONNX, the dtype is not promoted. See the following example:

import torch
import onnx

class Model(torch.nn.Module):
    def forward(self, x):
        # Three variants: default (dtype=None), explicit int32, explicit float
        return x.cumsum(0, dtype=None), x.cumsum(0, dtype=torch.int32), x.cumsum(0, dtype=torch.float)

input = torch.rand(1, 2, 3).to(torch.int32)
model = Model()
output = model(input)

torch.onnx.export(model, input, 'error.onnx', opset_version=11)

omodel = onnx.load('error.onnx')
lookup = {
    onnx.TensorProto.BOOL:      'onnx.TensorProto.BOOL',
    onnx.TensorProto.DOUBLE:    'onnx.TensorProto.DOUBLE',
    onnx.TensorProto.FLOAT16:   'onnx.TensorProto.FLOAT16',
    onnx.TensorProto.FLOAT:     'onnx.TensorProto.FLOAT',
    onnx.TensorProto.INT8:      'onnx.TensorProto.INT8',
    onnx.TensorProto.INT16:     'onnx.TensorProto.INT16',
    onnx.TensorProto.INT32:     'onnx.TensorProto.INT32',
    onnx.TensorProto.INT64:     'onnx.TensorProto.INT64',
    onnx.TensorProto.UINT8:     'onnx.TensorProto.UINT8',
    onnx.TensorProto.UINT16:    'onnx.TensorProto.UINT16',
    onnx.TensorProto.UINT32:    'onnx.TensorProto.UINT32',
    onnx.TensorProto.UINT64:    'onnx.TensorProto.UINT64'
}

print('PyTorch Output DTypes: {}'.format(tuple(o.dtype for o in output)))
print('ONNX Output DTypes: {}'.format(
    tuple(lookup.get(o.type.tensor_type.elem_type) for o in omodel.graph.output))
)

Output is:

PyTorch Output DTypes: (torch.int64, torch.int32, torch.float32)
ONNX Output DTypes: ('onnx.TensorProto.INT32', 'onnx.TensorProto.INT32', 'onnx.TensorProto.FLOAT')

As you can see, in the dtype=None case, PyTorch uses INT64 while ONNX uses INT32.
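Until this is fixed on the export side, a possible workaround (a sketch, assuming the promoted int64 result is what the model needs) is to spell the promoted dtype out explicitly, so that eager PyTorch and the exported ONNX graph agree:

```python
import torch

x = torch.zeros(1, 2, 3, dtype=torch.int32)

# With dtype=None, PyTorch promotes integral inputs to int64 ...
assert x.cumsum(0).dtype == torch.int64

# ... so passing the promoted dtype explicitly keeps the eager result
# unchanged while giving the ONNX exporter an unambiguous target type:
assert x.cumsum(0, dtype=torch.int64).dtype == torch.int64
```

Casting the input to int64 before calling cumsum should have the same effect. As an aside, the manual lookup table in the repro script can likely be replaced by the protobuf enum helper `onnx.TensorProto.DataType.Name(elem_type)`.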

Versions

Collecting environment information...
PyTorch version: 1.11.0+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: CentOS Linux 7 (Core) (x86_64)
GCC version: (GCC) 10.3.0
Clang version: Could not collect
CMake version: version 3.23.1
Libc version: glibc-2.17

Python version: 3.7.12 (default, Feb 6 2022, 20:29:18) [GCC 10.2.1 20210130 (Red Hat 10.2.1-11)] (64-bit runtime)
Python platform: Linux-3.10.0-1160.59.1.el7.x86_64-x86_64-with-centos-7.9.2009-Core
Is CUDA available: False
CUDA runtime version: 11.3.109
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.21.6
[pip3] torch==1.11.0
[pip3] torchvision==0.12.0
[conda] Could not collect

@justinchuby added the module: onnx label May 20, 2022
@justinchuby self-assigned this May 20, 2022
@anjali411 added the triaged label and removed the onnx-triaged label May 23, 2022
@justinchuby removed their assignment Jun 1, 2022
@justinchuby

Thanks for reporting this issue! We have added it to our list for investigation.

@BowenBao added the onnx-triaged label Jun 7, 2022
pytorchmergebot pushed a commit that referenced this issue Aug 10, 2022
Currently we don't have a dtype check when verifying the consistency between PyTorch and ONNX outputs. As a result, some dtype inconsistencies were found and reported: #77842 #77845

This is a POC.

Failed workflows:
- [linux-xenial-py3.7-clang7-onnx / test (default, 2, 2, linux.2xlarge)]
  - inconsistent shape
    - TestONNXRuntime_opset10.test_all (#79371)
    - TestONNXRuntime_opset10.test_any (#79371)
    - TestONNXRuntime_opset10.test_argmin_argmax (#79503)
    - TestONNXRuntime_opset10.test_hardshrink (#79695)
    - TestONNXRuntime_opset10.test_linalg_norm (#79506)
    - TestONNXRuntime_opset10.test_linalg_vector_norm (#79506)
    - TestONNXRuntime_opset10.test_prelu_scalar (#79846)
    - TestONNXRuntime_opset10.test_softshrink (#79695)
    - TestONNXRuntime_opset10.test_sum_empty_tensor (skipped)
    - TestONNXRuntime_opset10.test_tolist (skipped)
  - inconsistent dtype
    - test_arithmetic_prim_bool (skipped)
    - test_arithmeticOps_with_low_precision (skipped)
    - test_arithmetic_prim_float (skipped)
    - test_logical_and (#79339)
    - test_logical_or (#79339)
    - test_logical_xor (#79339)
    - test_pow (skipped)
    - test_primitive_input_floating (skipped)
    - test_quantize_per_tensor (#79690)
    - test_quantized_adaptive_avg_pool2d (#79690)
    - test_quantized_arithmetic (#79690)
    - test_quantized_arithmetic_qfunctional (#79690)
    - test_quantized_conv2d (#79690)
    - test_quantized_conv2d_relu (#79690)
    - test_quantized_flatten (#79690)
    - test_quantized_hardsigmoid (#79690)
    - test_quantized_hardswish (#79690)
    - test_quantized_linear (#79690)
    - test_quantized_sigmoid (#79690)
    - test_item (skipped)
    - test_full_like_value (skipped)
    - TestONNXRuntime_opset7.test_div_rounding_mode (skipped)
    - TestONNXRuntime_opset8.test_div_rounding_mode (skipped)
    - TestONNXRuntime_opset9.test_div_rounding_mode (skipped)
    - TestONNXRuntime_opset9_IRv4.test_div_rounding_mode (skipped)
    - test_outer (skipped)
    - test_symbolic_shape_inference_arange_2 (skipped)
Pull Request resolved: #79263
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
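The consistency check described in that PR can be sketched in a simplified, hypothetical form: compare both the shape and the dtype of each PyTorch output against the corresponding ONNX output. The helper name and structure below are illustrative, not the actual torch.onnx verification code:

```python
import torch

def check_outputs(torch_outputs, onnx_outputs):
    """Hypothetical, simplified version of the shape/dtype consistency
    check added in #79263; the real implementation lives in torch.onnx."""
    for i, (expected, actual) in enumerate(zip(torch_outputs, onnx_outputs)):
        if expected.shape != actual.shape:
            raise AssertionError(
                f"output {i}: shape mismatch {expected.shape} vs {actual.shape}")
        if expected.dtype != actual.dtype:
            raise AssertionError(
                f"output {i}: dtype mismatch {expected.dtype} vs {actual.dtype}")

# The cumsum case from this issue would now be flagged:
expected = torch.zeros(2, 3, dtype=torch.int64)  # PyTorch result (promoted)
actual = torch.zeros(2, 3, dtype=torch.int32)    # ONNX result (not promoted)
try:
    check_outputs([expected], [actual])
except AssertionError as e:
    print(e)  # output 0: dtype mismatch torch.int64 vs torch.int32
```

With such a check in place, exports like the cumsum repro fail verification instead of silently producing a graph with a different output type.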
facebook-github-bot pushed a commit that referenced this issue Aug 10, 2022
Summary: duplicate of the commit message above.

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/d9a7e93aaf3166e639ea413123bd6c38b9144adc

Reviewed By: seemethere

Differential Revision: D38585848

fbshipit-source-id: 9da98581ceec51142ae31d3f8a06f9f296a16b23