Description
🐛 Bug
Tensor.argsort() is not supported by ONNX export, whereas Tensor.sort() is supported (it is exported as an ONNX TopK node).
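A possible workaround, sketched here under the assumption that only the indices are needed, is to call Tensor.sort() and keep its second output. Those indices are what argsort() returns for the same arguments, and sort() already exports. `ArgsortViaSort` is a hypothetical name used for illustration, not part of PyTorch.

```python
import torch
from torch import nn


class ArgsortViaSort(nn.Module):
    """Hypothetical workaround module: computes argsort through Tensor.sort()."""

    def __init__(self, dim: int = -1, descending: bool = True):
        super().__init__()
        self.dim = dim
        self.descending = descending

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Tensor.sort() returns (values, indices); the indices are the
        # permutation that Tensor.argsort() would return for the same arguments.
        _, inds = x.sort(dim=self.dim, descending=self.descending)
        return inds


if __name__ == "__main__":
    x = torch.tensor([3.0, 1.0, 2.0])
    print(ArgsortViaSort()(x))  # tensor([0, 2, 1])
```

Because the module only calls sort(), it should go through the same TopK export path as the Demo module below.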
To Reproduce
Steps to reproduce the behavior:
- Run the following script:
```python
import torch
from torch import nn


class Demo(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Exports fine: Tensor.sort() is lowered to ONNX TopK.
        v, inds = x.sort(descending=True)
        # Fails with KeyError: 'argsort' (swap it in for the line above to reproduce):
        # inds = x.argsort(descending=True)
        return inds


if __name__ == "__main__":
    # 61 float values 20..80, equivalent to the deprecated torch.range(20, 80).
    input_tensor = torch.arange(20, 81, dtype=torch.float32)
    demo = Demo()
    out = demo(input_tensor)
    torch.onnx.export(demo, input_tensor, "debug.onnx", verbose=True,
                      input_names=['data'],
                      opset_version=11,
                      do_constant_folding=True,
                      dynamic_axes={'data': {0: 'batch'}})
```
Expected behavior
Tensor.argsort() should export just as Tensor.sort() does. Instead, the export fails with:
```
  File "/home/god/python36/lib/python3.6/site-packages/torch/onnx/symbolic_registry.py", line 91, in get_registered_op
    return _registry[(domain, version)][opname]
KeyError: 'argsort'
```
Environment
Collecting environment information...
PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.1
OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: version 3.15.3
Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration: GPU 0: GeForce GTX 1080
Nvidia driver version: 430.50
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.3
Versions of relevant libraries:
[pip3] numpy==1.17.2
[pip3] torch==1.4.0
[pip3] torchvision==0.5.0
[conda] Could not collect
Additional context
Export result of Tensor.sort() with input_tensor = torch.range(20, 80):
```
graph(%data : Float(61)):
  %1 : Float(61), %2 : Long(61) = onnx::TopK[axis=-1, k=61](%data) # onnx.py:10:0
  return (%2)
```
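In the graph above, sort() appears as a TopK whose k (61) equals the length of the input, so TopK returns every element and its index output is a full sort permutation. The same equivalence can be checked in eager mode. This is a sketch; `argsort_via_topk` is a hypothetical helper, and for inputs with tied values TopK and sort may order the tied indices differently.

```python
import torch


def argsort_via_topk(x: torch.Tensor, descending: bool = True) -> torch.Tensor:
    """Hypothetical helper mirroring the exported graph: a full-length TopK.

    With k equal to the size of the last dimension, TopK returns every
    element, so (for distinct values) its indices match argsort().
    """
    k = x.size(-1)
    _, inds = torch.topk(x, k=k, dim=-1, largest=descending, sorted=True)
    return inds


x = torch.arange(20, 81, dtype=torch.float32)  # same values as torch.range(20, 80)
assert x.numel() == 61
assert torch.equal(argsort_via_topk(x), x.argsort(descending=True))
print(argsort_via_topk(x)[:5])  # tensor([60, 59, 58, 57, 56])
```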
cc @suo @houseroad @spandantiwari @lara-hdr @BowenBao @neginraoof