Support for other integer types by MessagePassing #5087

Open
DomInvivo opened this issue Jul 29, 2022 · 10 comments

@DomInvivo
Contributor

🐛 Describe the bug

Why does PyG enforce edge_index to be of type long? Certain graphs, such as molecules, work fine with int16: they average about 20 atoms, so node indices will rarely exceed 32767 unless batch sizes larger than 1000 are used. Instead, we could simply check that there are no negative numbers (which appear when the edge index overflows), or that the maximum of the specific datatype is not reached.
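The arithmetic behind this claim can be sketched in plain Python. This is only an illustration: the helper names (`max_value`, `max_graphs_per_batch`, `wrap_signed`) are hypothetical and not part of PyG or PyTorch, and the 20-atom figure is the average quoted above.

```python
# Signed integer widths (in bits) for the dtypes discussed in this issue.
SIGNED_BITS = {"int16": 16, "int32": 32, "int64": 64}


def max_value(dtype_name):
    """Largest value a signed two's-complement integer of this width can hold."""
    return 2 ** (SIGNED_BITS[dtype_name] - 1) - 1


def max_graphs_per_batch(dtype_name, nodes_per_graph):
    """Number of equally sized graphs that fit before a node index overflows."""
    # Node indices run from 0 to max_value, so max_value + 1 nodes are addressable.
    return (max_value(dtype_name) + 1) // nodes_per_graph


def wrap_signed(value, dtype_name):
    """Simulate the wraparound that occurs when a value exceeds the signed range."""
    bits = SIGNED_BITS[dtype_name]
    half = 2 ** (bits - 1)
    return (value + half) % (2 ** bits) - half


print(max_value("int16"))                 # largest int16 node index
print(max_graphs_per_batch("int16", 20))  # 20-atom molecules per batch
print(wrap_signed(32768, "int16"))        # first index past the int16 range
```

Under these assumptions, an int16 index leaves room for roughly 1600 molecules of 20 atoms each, and the first index past the range wraps to a negative value, which is why a check for negative entries would catch the overflow.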

Also, some hardware, such as TPUs and IPUs, does not support long and is limited to int32.

assert edge_index.dtype == torch.long, \

Environment

  • PyG version: All
  • PyTorch version: All
  • OS: All
  • Python version: All
  • CUDA/cuDNN version: All
  • How you installed PyTorch and PyG (conda, pip, source): Any
  • Any other relevant information (e.g., version of torch-scatter): Any
DomInvivo added the bug label Jul 29, 2022
@rusty1s
Member

rusty1s commented Jul 30, 2022

This is mostly due to how PyTorch works, e.g., index_select only works with indices of dtype=torch.long. Not much we can do about it currently, sorry!

rusty1s added the feature label and removed the bug label Jul 30, 2022
@DomInvivo
Contributor Author

Since PyTorch 1.8.0, index_select supports IntTensor as well as LongTensor, according to the docs.

@rusty1s
Member

rusty1s commented Jul 30, 2022

Oh, my bad. I had only tested with torch.short and was convinced it only worked for torch.long. I guess we can then start looking into supporting both torch.long and torch.int.

@hatemhelal
Contributor

I can take a look at this: as a first attempt I'll try to relax the check for edge_index.dtype == torch.long and add some additional tests in test/nn/conv/test_message_passing.py and see what breaks.

@rusty1s
Member

rusty1s commented Aug 15, 2022

Sounds good. Thanks! We also added support to make use of torch.scatter_reduce (see utils/scatter.py) which will help in this transition.

@EdisonLeeeee
Contributor

I can take a look at this: as a first attempt I'll try to relax the check for edge_index.dtype == torch.long and add some additional tests in test/nn/conv/test_message_passing.py and see what breaks.

Hi @howardjp, I've tried and it breaks at torch_scatter.scatter_add.

I also tried to use torch.scatter_reduce as suggested by @rusty1s, but it currently does not support IntTensor.

Finally, I also implemented scatter_reduce based on torch.scatter. Sadly, torch.scatter does not support IntTensor either. This issue was mentioned in pytorch/pytorch#61819 and pytorch/pytorch#51323, which have not yet been resolved. It seems only index_select supports IntTensor.

@rusty1s
Member

rusty1s commented Aug 18, 2022

Thanks for digging into this. It looks like we need to wait for the PyTorch team to catch up, sorry :(

@hatemhelal
Contributor

See #5281 for a proposal to relax the type assertion in the message passing interface. This patch effectively lets the execution backend decide which integer types to support when executing the aggregation step. That said, with the default CPU backend the error message changes to:

E           RuntimeError: scatter(): Expected dtype int64 for index
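Since the error above is raised by the scatter kernel rather than by the PyG assertion, one possible workaround is to keep the index in a narrow dtype and widen it only at the point of aggregation. The following is a minimal sketch for the 1-D case, assuming a temporary int64 copy of the index is acceptable; `scatter_add_with_upcast` is a hypothetical helper, not a PyG or PyTorch function.

```python
import torch


def scatter_add_with_upcast(src, index, dim_size):
    """Sum the values of `src` into `dim_size` slots selected by `index` (1-D).

    `index` may be int16, int32, or int64; it is widened to int64 only for
    the scatter call. Hypothetical helper for illustration, not part of PyG.
    """
    if index.dtype not in (torch.int16, torch.int32, torch.int64):
        raise TypeError(f"unsupported index dtype: {index.dtype}")
    # A negative or too-large entry suggests the narrow dtype has overflowed.
    if index.numel() > 0 and (int(index.min()) < 0 or int(index.max()) >= dim_size):
        raise ValueError("index out of range (possible overflow in a narrow dtype)")
    out = torch.zeros(dim_size, dtype=src.dtype, device=src.device)
    return out.scatter_add(0, index.to(torch.long), src)


src = torch.tensor([1.0, 2.0, 3.0, 4.0])
index = torch.tensor([0, 1, 0, 1], dtype=torch.int32)
result = scatter_add_with_upcast(src, index, dim_size=2)
print(result)
```

The cast costs an extra copy of the index, so this saves storage and transfer size but not peak memory during aggregation, and it does not help on hardware that cannot represent int64 at all.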

@vadimkantorov

pytorch/pytorch#51323 seems fixed now, but not other nasty things like pytorch/pytorch#56975 or pytorch/pytorch#61819 ...

@robertparley

robertparley commented Sep 6, 2023

Could I add try...except code to avoid the RuntimeError?
It would introduce redundant code, and I don't know whether such a modification is appropriate.
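For reference, a try...except fallback of the kind asked about might look like the sketch below. It is an assumption-laden illustration for the 1-D case only: `scatter_add_try_native` is a hypothetical name, and the code relies on the backend raising RuntimeError for unsupported index dtypes, as in the message quoted earlier in this thread.

```python
import torch


def scatter_add_try_native(src, index, dim_size):
    """Try the scatter with the index's own dtype, then retry with int64.

    Hypothetical helper for illustration, not part of PyG or PyTorch.
    """
    out = torch.zeros(dim_size, dtype=src.dtype, device=src.device)
    try:
        # First attempt: let the backend decide whether it accepts this dtype.
        return out.scatter_add(0, index, src)
    except RuntimeError:
        # Fallback: use a fresh output buffer and an int64 copy of the index.
        out = torch.zeros(dim_size, dtype=src.dtype, device=src.device)
        return out.scatter_add(0, index.to(torch.long), src)


src = torch.tensor([1.0, 2.0, 3.0, 4.0])
index = torch.tensor([0, 1, 0, 1], dtype=torch.int32)
fallback_result = scatter_add_try_native(src, index, dim_size=2)
print(fallback_result)
```

The trade-off is that a broad `except RuntimeError` can also swallow unrelated failures such as out-of-bounds indices or device errors, so an unconditional `index.to(torch.long)` before the call may be simpler and easier to debug.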
@vadimkantorov
