torch.all with dim Onnx: _all() takes 2 positional arguments but 4 were given #65817

Closed
simetin opened this issue Sep 29, 2021 · 2 comments
Labels
module: onnx Related to torch.onnx triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

simetin commented Sep 29, 2021

🐛 Bug

When I call torch.all with the dim parameter, the ONNX export fails with the following error: _all() takes 2 positional arguments but 4 were given

To Reproduce

import torch
from torch import nn

class TorchAll(nn.Module):
    def forward(self, tensor):
        tensor = torch.all(tensor, dim=1)
        return tensor

X = torch.ones((3, 300, 300), dtype=torch.int32)

torch.onnx.export(
    TorchAll(),
    (X,),  # dummy input tuple used to trace shapes
    "torch_all_model.onnx",
    opset_version=12,
    do_constant_folding=True,
)

Expected behavior

I expect the export to complete without errors.

Environment

PyTorch version: 1.8.1
Is debug build: False
CUDA used to build PyTorch: 10.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.18.4
Libc version: glibc-2.27

Python version: 3.8.11 (default, Aug 3 2021, 15:09:35) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-4.15.0-143-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: GeForce GTX 1070
GPU 1: GeForce GTX 1070

Nvidia driver version: 450.119.03
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.20.3
[pip3] torch==1.8.1
[pip3] torchaudio==0.8.0a0+e4e171a
[pip3] torchvision==0.9.1
[conda] blas 1.0 mkl
[conda] cudatoolkit 10.1.243 h6bb024c_0
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.3.0 h06a4308_520
[conda] mkl-service 2.4.0 py38h7f8727e_0
[conda] mkl_fft 1.3.0 py38h42c9631_2
[conda] mkl_random 1.2.2 py38h51133e4_0
[conda] numpy 1.20.3 py38hf144106_0
[conda] numpy-base 1.20.3 py38h74d4b33_0
[conda] pytorch 1.8.1 py3.8_cuda10.1_cudnn7.6.3_0 pytorch
[conda] torchaudio 0.8.1 py38 pytorch
[conda] torchvision 0.9.1 py38_cu101 pytorch

Additional context

The same error occurs on my Mac with PyTorch 1.9.1.

cc @BowenBao @neginraoof

VitalyFedyunin added the module: onnx and triaged labels on Oct 1, 2021
shubhambhokare1 self-assigned this on Oct 4, 2021
shubhambhokare1 (Collaborator) commented

PR #66093 should fix this issue.
@simetin, let me know if this works.
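A minimal, hypothetical sketch of the failure mode, written in plain Python so it runs without PyTorch. The error text is Python's standard TypeError for calling a function with more positional arguments than it declares. The sketch assumes (consistent with the "4 were given" count, but not checked against the diff of PR #66093) that the exporter's symbolic for `all` accepted only `(g, input)`, while `torch.all(tensor, dim=1)` reaches it as `(g, input, dim, keepdim)`. The names `all_symbolic_v1` and `all_symbolic_v2` and the tuple "graph nodes" are invented for illustration and are not torch.onnx internals.

```python
# Hypothetical stand-ins that mimic the call pattern only; they are NOT the
# real torch.onnx symbolic functions and build no actual ONNX graph.


def all_symbolic_v1(g, input):
    """Handles only torch.all(input): the exporter passes (graph, input)."""
    return ("Not", ("Any", ("Not", input)))


def all_symbolic_v2(g, input, *args):
    """Also handles torch.all(input, dim, keepdim) via optional trailing args."""
    if not args:
        # Full reduction over every element, same as v1.
        return ("Not", ("Any", ("Not", input)))
    dim, keepdim = args
    # Reduce only along `dim`, optionally keeping the reduced axis.
    return ("Not", ("Any", ("Not", input), dim, keepdim))


graph = object()  # placeholder for the exporter's graph object

# Assumed: torch.all(tensor, dim=1) arrives as (graph, input, dim, keepdim).
message = None
try:
    all_symbolic_v1(graph, "tensor", 1, False)
except TypeError as err:
    message = str(err)
    print(message)

print(all_symbolic_v2(graph, "tensor", 1, False))
```

Accepting the optional trailing arguments with `*args` lets one symbolic serve both overloads; the actual PR may structure this differently.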

simetin (Author) commented Oct 5, 2021

Yes, this fixes the issue. Thank you!

simetin closed this as completed on Oct 5, 2021