
RuntimeError: scatter_add_cuda_kernel does not have a deterministic implementation #3175

Closed
monk1337 opened this issue Sep 18, 2021 · 25 comments

Comments

@monk1337
Contributor

monk1337 commented Sep 18, 2021

I am trying to use GCN and GAT from the library and am getting this error:

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch_scatter/scatter.py", line 21, in softmax
            size[dim] = int(index.max()) + 1
        out = torch.zeros(size, dtype=src.dtype, device=src.device)
        return out.scatter_add_(dim, index, src)
               ~~~~~~~~~~~~~~~~ <--- HERE
    else:
        return out.scatter_add_(dim, index, src)
RuntimeError: scatter_add_cuda_kernel does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True)'. You can turn off determinism just for this operation if that's acceptable for your application. You can also file an issue at https://github.com/pytorch/pytorch/issues to help us prioritize adding deterministic support for this operation.

I tried adding torch.use_deterministic_algorithms(True), but it is not working.
The same code works fine on CPU.
How can I avoid this error?

@rusty1s
Member

rusty1s commented Sep 19, 2021

scatter_* is not a deterministic operation, and using torch.use_deterministic_algorithms(True) will result in an error. If you want to make use of deterministic operations in PyG, you have to use the SparseTensor class as an alternative to edge_index, see here.
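For reference, a minimal sketch of that conversion (the dataset, storage path, and layer size are placeholder assumptions, not taken from the linked docs): T.ToSparseTensor() replaces data.edge_index with a SparseTensor stored in data.adj_t, which is then passed to the convolution in place of edge_index.

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv

# ToSparseTensor() drops `data.edge_index` and stores the adjacency as `data.adj_t`.
data = Planetoid('/tmp/Cora', name='Cora', transform=T.ToSparseTensor())[0]
conv = GCNConv(data.num_node_features, 16)
out = conv(data.x, data.adj_t)  # pass adj_t instead of edge_index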

@monk1337
Contributor Author

@rusty1s Could you shed a little more light on this? As far as I can see from the example, the result will be the same, so what is the difference and what is the issue? And is there any accuracy difference between them?

@rusty1s
Member

rusty1s commented Sep 20, 2021

Sure. The difference between those two approaches is that, for scatter, the order of aggregation is not deterministic since internally scatter is implemented by making use of atomic operations. This may lead to slightly different outputs induced by floating point precision, e.g., 3 + 2 + 1 = 6.000001 while 1 + 2 + 3 = 5.9999999. In contrast, the order of aggregation in SparseTensor is always deterministic and is performed based on the ordering of node indices. In practice, either operation works fine, in particular because graphs do not usually obey a fixed node ordering. The final accuracy is usually the same.
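For illustration, the same order-dependence can be seen with plain Python floats (values chosen arbitrarily):

a, b, c = 0.1, 0.2, 0.3
print((a + b) + c == a + (b + c))  # False: the two summation orders differ in the last bits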

@monk1337
Contributor Author

Thank you very much for the explanation :)

@monk1337
Contributor Author

Sorry for reopening the issue; I tried the above solution with the GCNConv method, and it works fine, but using the SparseTensor approach with GATConv throws this error:

RuntimeError: index_add_cuda_ does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True)'. You can turn off determinism just for this operation if that's acceptable for your application. You can also file an issue at https://github.com/pytorch/pytorch/issues to help us prioritize adding deterministic support for this operation.

@monk1337 monk1337 reopened this Sep 21, 2021
@rusty1s
Member

rusty1s commented Sep 21, 2021

Can you show me a script to reproduce? I have no problems running the following script:

import torch
from torch_geometric.nn import GATConv
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T

torch.use_deterministic_algorithms(True)

data = Planetoid('/tmp/dawdh', name='Cora', transform=T.ToSparseTensor())[0]
data = data.cuda()

conv = GATConv(data.num_node_features, 20, heads=4)
conv = conv.cuda()

out = conv(data.x, data.adj_t)
out.mean().backward()

@aabbas90

aabbas90 commented Feb 22, 2022

Hi @rusty1s

As per this comment on the PyTorch repo, a deterministic scatter operation is now available for 1D tensors. Thus, I think that for heads = 1, using edge_index instead of SparseTensor should work in the script you gave, i.e.,

import torch
from torch_geometric.nn import TransformerConv
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T

torch.use_deterministic_algorithms(True)

data = Planetoid('/tmp/dawdh', name='Cora', transform=None)[0] #T.ToSparseTensor())[0]
data = data.cuda()

conv = TransformerConv(data.num_node_features, 20, heads=4)
conv = conv.cuda()

out = conv(data.x, data.edge_index)
out.mean().backward()

However, it does not work. The problem is only due to the presence of a singleton dimension in out here:

out_sum = scatter(out, index, dim, dim_size=N, reduce='sum')

Changing this line instead to the following does work:

out_sum = scatter(out.squeeze(-1), index, dim, dim_size=N, reduce='sum').unsqueeze(-1)
However, such a change might also be required in other places, such as here:

return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,

TLDR: It might be possible to start supporting deterministic scatter operations for 1D tensors using edge_index.

@rusty1s
Member

rusty1s commented Feb 23, 2022

Thanks for the reference @aabbas90. In my understanding, deterministic scatter is automatically supported for 1D tensors when you set torch.use_deterministic_algorithms(True), right?

Notably, I'm afraid that operating on 1D tensors will actually never be the case in PyG (as this requires the number of features to be one).

@aabbas90

aabbas90 commented Feb 23, 2022

Yes, deterministic scatter is automatically supported for 1D tensors. You are right, though, as most of the time multi-dimensional features would need to be used. The only exception I found is when calculating attention weights for transformer-related architectures, as they are 1D (with heads = 1), e.g., here:

out_sum = scatter(out, index, dim, dim_size=N, reduce='sum')

But I guess it does not help much since other ops would still be non-deterministic.

@rusty1s
Member

rusty1s commented Feb 23, 2022

Ah, I see. In that case we would need to try to squeeze the tensor into a one-dimensional one, got it. It might then be better to add this functionality directly inside torch-scatter. Are you interested in contributing this?
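A hypothetical sketch of what such a fallback could look like (the wrapper name and its placement are assumptions, not existing torch-scatter API):

from torch_scatter import scatter

def scatter_maybe_1d(src, index, dim, dim_size, reduce='sum'):
    # If the trailing dimension is a singleton, scatter the 1D view so that
    # PyTorch's deterministic 1D scatter_add_ path is used, then restore the shape.
    if src.dim() == 2 and src.size(-1) == 1:
        out = scatter(src.squeeze(-1), index, dim, dim_size=dim_size, reduce=reduce)
        return out.unsqueeze(-1)
    return scatter(src, index, dim, dim_size=dim_size, reduce=reduce)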

@aabbas90

aabbas90 commented Feb 23, 2022

Thanks, yes, I was partly interested in contributing, but I cannot currently build torch-scatter successfully from source. So I am afraid I do not have enough capacity to go down this rabbit hole.

@rusty1s
Member

rusty1s commented Feb 25, 2022

Please let me know which issues you run into when installing torch-scatter from source. You can probably also try to simply compile the CPU version via FORCE_ONLY_CPU=1 python setup.py develop.

@aabbas90

The error I am getting is:

$ python
>>> import torch_scatter
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/projects/pytorch_scatter/torch_scatter/__init__.py", line 16, in <module>
    torch.ops.load_library(spec.origin)
  File "/anaconda3/envs/py_scatter/lib/python3.7/site-packages/torch/_ops.py", line 110, in load_library
    ctypes.CDLL(path)
  File "/anaconda3/envs/py_scatter/lib/python3.7/ctypes/__init__.py", line 364, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: /projects/pytorch_scatter/torch_scatter/_scatter_cuda.so: undefined symbol: _ZNK2at6Tensor6deviceEv

To install from source, I did the following inside a conda environment with Python 3.7:

conda install pytorch cudatoolkit=10.2 -c pytorch
CPATH=/usr/lib/cuda-10.2/include/ python setup.py develop

where PATH=/anaconda3/envs/py_scatter/bin:/anaconda3/condabin:/software/gcc-7.5-build/bin:/usr/lib/cuda-10.2/bin:...

@rusty1s
Member

rusty1s commented Feb 25, 2022

Can you uninstall and try installing the CPU-only version (as detailed above)? Please report any installation logs as well.

@andrei-rusu

Sorry for intervening in your discussion, but I have a question regarding the workaround posted here for the OP's issue:
Is there any way to bypass the non-determinism of scatter by using SparseTensors while also retaining the ability to use edge attributes and self-loops? This combination seems to be missing at the moment, and unfortunately my GAT-based models take a performance hit when I use SparseTensors and set add_self_loops=False.
Thank you!

@rusty1s
Member

rusty1s commented Mar 4, 2022

Using SparseTensor ensures deterministic behavior, while its output and training performance should be identical to edge_index. There shouldn't be any significant performance hit; if there is one, it is probably induced by add_self_loops=False. Any reason to disable it at all?

@andrei-rusu

andrei-rusu commented Mar 5, 2022

I use edge attributes. If add_self_loops is enabled, a NotImplementedError arises (e.g., check the GATConv forward implementation when edge_dim is set). In more recent versions, one cannot actually omit setting edge_dim either, due to an assert failure in message(), even if the dimensionality of the edge attributes is kept in check externally to the call (which is also a bit annoying, to be fair). So having self-loops becomes effectively locked out for SparseTensors in GATConv or GATv2Conv.

@rusty1s
Member

rusty1s commented Mar 5, 2022

Can you clarify what you mean by one cannot omit setting edge_dim? I'm happy to fix it once I am sure about the issue.

Furthermore, note that you can always set add_self_loops=False in GATConv and apply self-loops to the SparseTensor in a pre-processing step (prior to any message passing), e.g., via transforms:

transform = T.Compose([T.AddSelfLoops(), T.ToSparseTensor()])

This should mimic the original behavior.
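A minimal sketch of this workaround (dataset, path, and layer sizes are placeholder assumptions):

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv

# Add self-loops to edge_index/edge_attr first, then convert to a SparseTensor,
# and disable self-loop insertion inside the layer itself.
transform = T.Compose([T.AddSelfLoops(), T.ToSparseTensor()])
data = Planetoid('/tmp/Cora', name='Cora', transform=transform)[0]

conv = GATConv(data.num_node_features, 20, heads=4, add_self_loops=False)
out = conv(data.x, data.adj_t)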

@andrei-rusu

andrei-rusu commented Mar 5, 2022

OK. So, in gat_conv.py, in the forward() method, you have the following logic, which raises an error for SparseTensors if add_self_loops is True and edge_dim is set:

if self.add_self_loops:
    if isinstance(edge_index, Tensor):
        # We only want to add self-loops for nodes that appear both as
        # source and target nodes:
        num_nodes = x_src.size(0)
        if x_dst is not None:
            num_nodes = min(num_nodes, x_dst.size(0))
        num_nodes = min(size) if size is not None else num_nodes
        edge_index, edge_attr = remove_self_loops(
            edge_index, edge_attr)
        edge_index, edge_attr = add_self_loops(
            edge_index, edge_attr, fill_value=self.fill_value,
            num_nodes=num_nodes)
    elif isinstance(edge_index, SparseTensor):
        if self.edge_dim is None:
            edge_index = set_diag(edge_index)
        else:
            raise NotImplementedError(
                "The usage of 'edge_attr' and 'add_self_loops' "
                "simultaneously is currently not yet supported for "
                "'edge_index' in a 'SparseTensor' form")

At the same time, in message() you have the assertion below, which fails if one does not set edge_dim but passes in edge attributes (self.lin_edge gets initialized only if edge_dim is not None). It is worth noting that layers like GINEConv do not enforce a non-None edge_dim, which is advantageous if one wants to write their own routine for projecting edge_attr.

if edge_attr is not None:
    if edge_attr.dim() == 1:
        edge_attr = edge_attr.view(-1, 1)
    assert self.lin_edge is not None

Maybe the intention here was to only support passing edge attributes directly to the SparseTensor constructor (i.e., without support for passing edge_attr as an argument to forward())? Nonetheless, I've been having trouble performing the required linear projections to match the dimensionality of x when passing the attributes through the SparseTensor constructor, so that also seems like a limitation of this approach. At the same time, since I've written some adapters that allow the edge attributes to also be utilized as edge_weight or edge_type where the layer supports that, and I've also implemented the Linear projections of edge_attr myself, utilizing the edge attributes from the SparseTensor requires some annoying code refactoring on my part to support both modes, as I only really need SparseTensors when running on CUDA with a set seed to ensure consistency.

@rusty1s
Member

rusty1s commented Mar 5, 2022

you have the assertion below, which fails if one does not set edge_dim but passes in edge attributes (self.lin_edge gets initialized only if edge_dim is not None).

I think this is definitely intended. If you pass in edge_attr, it shouldn't be ignored just because one didn't set edge_dim. The GINEConv layer does not have this constraint since it requires the same input dimensionality across node and edge features (which comes with other disadvantages).

I guess the easiest workaround for your problem would be the addition of self-loops in SparseTensor for multi-dimensional edge features as well, right? We would need to add support for this in add_self_loops.

@andrei-rusu

andrei-rusu commented Mar 5, 2022

Well, from my understanding GINEConv requires the edge_attr to have a dimensionality of in_channels. GATConv, on the other hand, requires edge_attr to have a dimensionality of heads * out_channels. I don't see why one wouldn't be able to create their own projections in order to ensure edge_attr has the correct size for each of these cases. I agree, however, that not enforcing edge_dim may be confusing to some (and the documentation is already explicit about this internal Linear layer), so I don't really have a problem with it as long as add_self_loops works fine.

And yep, the final solution should be what you said in the second paragraph! However, as per your previous comment, a temporary workaround could be to just add the self loops as part of the Transform or directly to the original edge_index & edge_attr tensors before creating the SparseTensor. Thanks!

@cxw-droid

Hi @rusty1s, I am trying to make mutag_gin.py output a deterministic result. Following the above suggestions, I have changed the dataset to SparseTensor, changed edge_index to adj_t, and changed line 53 to x = global_max_pool(x, batch), but I still get random results. I set the seed as follows:

seed = 2
torch.manual_seed(seed)  ##
# np.random.seed(seed)
# random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)

Any help would be appreciated. Thanks.

@FeiGSSS

FeiGSSS commented Nov 7, 2022


I have encountered the same problem as @cxw-droid described above; have you figured it out?

@rusty1s
Member

rusty1s commented Nov 10, 2022

You need to use torch_scatter.segment_csr(x, batch.ptr) instead of global_max_pool.
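For reference, a minimal sketch of the deterministic pooling (the shapes, and reduce='max' to mimic global_max_pool, are assumptions):

import torch
from torch_scatter import segment_csr

# x: node features of shape [num_nodes, hidden_dim]; ptr: CSR-style pointer from the
# PyG Batch object, delimiting the nodes belonging to each graph in the batch.
x = torch.randn(6, 8)
ptr = torch.tensor([0, 2, 6])               # two graphs: nodes 0-1 and 2-5
pooled = segment_csr(x, ptr, reduce='max')  # shape [2, 8]; aggregation order is fixed (no atomics)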

@akihironitta
Member

Closing this issue as there seems to be no action item, but feel free to create a new issue or discussion.
