
non zero depreciated fix #2314

Closed

Conversation

java-abhinav07 commented Jun 11, 2020

Fixes the nonzero deprecation warning. This significantly decreases the execution time of the frcnn, keypoint_rcnn, etc. models, on account of the time spent emitting the warning.
Fixes #2154.
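For context, the two nonzero behaviors under discussion can be sketched in plain Python (no PyTorch required; `nonzero_coords` and `nonzero_tuple` are hypothetical helpers that mirror the semantics of `torch.nonzero` with `as_tuple=False` and `as_tuple=True` respectively):

```python
def nonzero_coords(matrix):
    """Return an (N, 2) list of [row, col] coordinates of nonzero
    entries, mirroring torch.nonzero(t) (i.e. as_tuple=False)."""
    return [[r, c] for r, row in enumerate(matrix)
                   for c, v in enumerate(row) if v != 0]

def nonzero_tuple(matrix):
    """Return one index vector per dimension, mirroring
    as_tuple=True (and torch.where(condition))."""
    coords = nonzero_coords(matrix)
    rows = [rc[0] for rc in coords]
    cols = [rc[1] for rc in coords]
    return rows, cols

t = [[1, 0], [0, 4]]
print(nonzero_coords(t))  # [[0, 0], [1, 1]]
print(nonzero_tuple(t))   # ([0, 1], [0, 1])
```

The deprecation warning is about the first form being called through an old overload; the fix discussed below is about moving callers onto an equivalent of the second form.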

java-abhinav07 changed the title from "Abhinav/non zero bug" to "non zero depreciated fix" on Jun 11, 2020
java-abhinav07 changed the title from "non zero depreciated fix" to "WIP: non zero depreciated fix" on Jun 11, 2020
java-abhinav07 changed the title from "WIP: non zero depreciated fix" to "non zero depreciated fix" on Jun 12, 2020
fmassa (Member) commented Jun 12, 2020

Hi,

Thanks for the PR!

There are a few issues with the torchvision CI; let me try to fix them and then get back to you. Some of the errors are unrelated to your PR.

java-abhinav07 (Author) commented

Yes, absolutely. I was a little concerned it might be a bug in my code. Thanks for the clarification!

fmassa (Member) commented Jun 17, 2020

Hum, looks like torchscript and ONNX might not support the as_tuple keyword in nonzero.

I think a potentially better fix is to replace torch.nonzero(tensor) with tensor.nonzero(), as it fixes the deprecated-overload issue, although this sounds a bit weird. cc @gchanan: is it expected that torch.nonzero(tensor) raises a warning but tensor.nonzero() doesn't, with the message

/Users/fmassa/anaconda3/bin/ipython:1: UserWarning: This overload of nonzero is deprecated:
	nonzero(Tensor input, *, Tensor out)
Consider using one of the following signatures instead:
	nonzero(Tensor input, *, bool as_tuple) (Triggered internally at  /Users/distiller/project/conda/conda-bld/pytorch_1592377720291/work/torch/csrc/utils/python_arg_parser.cpp:761.)
  #!/Users/fmassa/anaconda3/bin/python

cc @eellison and @neginraoof FYI about torchscript and ONNX apparent limitations on nonzero

fmassa (Member) commented Jun 17, 2020

So, according to @gchanan, the difference between torch.nonzero(tensor) and tensor.nonzero() is a bug.

This means that merging this PR is currently blocked by torchscript and ONNX changes.

fmassa (Member) commented Jun 17, 2020

@eellison just proposed to use torch.where(condition) instead of torch.nonzero(condition, as_tuple=True), this should fix the torchscript issues (let's hope ONNX will support it)

java-abhinav07 (Author) commented Jun 18, 2020

Yes, completely agree with @gchanan; it's surprising that tensor.nonzero() doesn't throw a warning. I'll try replacing it with torch.where. However, I'm not sure how to verify whether ONNX will support it. Maybe trying to export the model once would help?

fmassa (Member) commented Jun 19, 2020

@java-abhinav07 if the ONNX tests pass with the change, then torch.where is supported :-)

java-abhinav07 (Author) commented Jun 19, 2020

Great, I'll commit the new changes by tomorrow! I did find this open issue, though: pytorch/pytorch#27959

java-abhinav07 (Author) commented Jun 20, 2020

#1916 I'm facing this issue locally as well (the Travis CI check doesn't pass because of it). Could you give some insight into why this is happening, @fmassa? Possibly a version mismatch (I'm using a CPU device)?
Here's the environment information:

PyTorch version: 1.5.0
Is debug build: No
CUDA used to build PyTorch: 10.2

OS: Ubuntu 18.04.4 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: version 3.10.2

Python version: 3.6
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip3] numpy==1.18.1
[pip3] torch==1.5.0
[conda] Could not collect

fmassa (Member) commented Jun 22, 2020

@java-abhinav07 looks like ONNX doesn't support torch.where, as you pointed out, so there isn't anything we can do as of now; we won't be able to merge this PR until the ONNX issue is fixed in PyTorch. cc @neginraoof: could you look into adding support for this overload of torch.where to ONNX?

Also, @java-abhinav07, why did you use torch.stack(torch.where( instead of just torch.where? Using stack adds unnecessary overhead, and you can just use the same implementation as before with indexing, right?

java-abhinav07 (Author) commented Jun 23, 2020

@fmassa, no, using torch.where alone doesn't solve the issue for multi-dimensional tensors. torch.where outputs a tuple of tensors, one per dimension, identifying the indices of the nonzero values. Those tensors need to be stacked together to replicate the functionality of torch.nonzero with as_tuple=False.

>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.nonzero(t).squeeze(1)
/pytorch/torch/csrc/utils/python_arg_parser.cpp:756: UserWarning: This overload of nonzero is deprecated:
	nonzero(Tensor input, *, Tensor out)
Consider using one of the following signatures instead:
	nonzero(Tensor input, *, bool as_tuple)
tensor([[0, 0],
        [0, 1],
        [1, 0],
        [1, 1]])
>>> torch.where(t>0)
(tensor([0, 0, 1, 1]), tensor([0, 1, 0, 1]))
>>> torch.stack(torch.where(t>0), dim=1)
tensor([[0, 0],
        [0, 1],
        [1, 0],
        [1, 1]])
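To spell out the stack step in the snippet above: stacking the tuple along dim=1 is just a transpose of the per-dimension index vectors back into the (N, ndim) coordinate layout. A plain-Python sketch of that relationship (not the PyTorch implementation):

```python
# where-style output for t > 0 above: one index vector per dimension
rows, cols = [0, 0, 1, 1], [0, 1, 0, 1]

# pairing the i-th entry of each vector recovers the (N, ndim)
# coordinate rows that torch.nonzero(t) would return
stacked = [list(pair) for pair in zip(rows, cols)]
print(stacked)  # [[0, 0], [0, 1], [1, 0], [1, 1]]
```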

fmassa (Member) commented Jun 23, 2020

@java-abhinav07 most usages of nonzero() are on a 1-D tensor, and we call .squeeze(1) just after to remove the added dimension. This means that, most of the time, the stack is not necessary.

In [2]: a = torch.rand(10)

In [3]: a
Out[3]:
tensor([0.4069, 0.6907, 0.0880, 0.4860, 0.0721, 0.3742, 0.2973, 0.7450, 0.3434,
        0.9473])

In [4]: (a > 0.5).nonzero()
Out[4]:
tensor([[1],
        [7],
        [9]])

In [5]: (a > 0.5).nonzero().squeeze(1)
Out[5]: tensor([1, 7, 9])

In [6]: torch.where(a > 0.5)[0]
Out[6]: tensor([1, 7, 9])

java-abhinav07 (Author) commented

@fmassa I see. Yes, if we are using a 1-D tensor for the subsequent nonzero computations, then there is no need for the stack. If we are sure it's always a 1-D tensor, I'll simplify it to torch.where(tensor > 0)[0]. However, I believe it makes sense to check tensor.size() and make sure it is indeed 1-D, since the code won't throw any error if we simply take the dim-0 nonzero indices as you suggested.
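A minimal plain-Python sketch of the guard being proposed (`indices_of_nonzero_1d` is a hypothetical helper, with nested lists standing in for tensors; in PyTorch terms the check would be on `tensor.dim()`):

```python
def indices_of_nonzero_1d(values):
    """Return the indices of positive entries of a 1-D sequence,
    mirroring torch.where(tensor > 0)[0]."""
    # Guard: silently taking only the dim-0 indices of a
    # multi-dimensional input would hide bugs, so reject it.
    if any(isinstance(v, (list, tuple)) for v in values):
        raise ValueError("expected a 1-D sequence")
    return [i for i, v in enumerate(values) if v > 0]

print(indices_of_nonzero_1d([0, 3, 0, 5]))  # [1, 3]
```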

fmassa (Member) commented Jul 7, 2020

@java-abhinav07 sorry for the delay in replying.

Most places are sure to have a 1-D tensor all the time; I think there is only one occurrence where this is not the case.

But given that ONNX doesn't support this overload yet, we can't merge this PR for now, unfortunately.

neginraoof (Contributor) commented Jul 10, 2020

As you mentioned, ONNX export is missing for std::vector<Tensor> where(const Tensor & condition)
But there's also another problem with the ONNX export after this change. The issue is that ONNX Greater does not support a bool input type. Currently the exporter does not take care of the required input-type casts for this op, and this results in a failure when running the exported model in ONNX Runtime.

Export of where op is possible using a combination of NonZero and Unbind ops.

We need to fix both issues to be able to fix the Faster_RCNN tests with these updates.
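The NonZero + Unbind decomposition mentioned above can be illustrated in plain Python (a sketch of the semantics for a 2-D boolean mask, not the actual exporter code):

```python
def where_via_nonzero_unbind(mask):
    """Emulate torch.where(condition) for a 2-D mask as
    Unbind(NonZero(condition), dim=1)."""
    # NonZero step: (N, 2) row-major coordinates of the True entries
    coords = [[r, c] for r, row in enumerate(mask)
                     for c, v in enumerate(row) if v]
    # Unbind along dim=1: one index vector per dimension
    return tuple([rc[d] for rc in coords] for d in range(2))

print(where_via_nonzero_unbind([[True, False],
                                [True, True]]))  # ([0, 1, 1], [0, 0, 1])
```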

java-abhinav07 (Author) commented

We can wait for the ONNX support to be fixed, or open another PR later (@fmassa, feel free to close the PR accordingly). @neginraoof, if opening an ONNX export issue would help, we could create one. Either way, understood; it makes sense!

neginraoof (Contributor) commented

I'm going to send out a PR for export of where op and necessary export fixes.
I'll follow up on this thread.

neginraoof (Contributor) commented

PR to address export issue: pytorch/pytorch#41544

fmassa (Member) commented Sep 24, 2020

Subsumed by #2705. Thanks for all the work @java-abhinav07 !
