Pytorch Geometric RandomLinkSplit: Expected all tensors to be on the same device #3641

Closed
Sticksword opened this issue Dec 7, 2021 · 4 comments

@Sticksword

🐛 Bug

The RandomLinkSplit transform performs its negative sampling on the CPU. If the original data lives on a CUDA device, we run into the error "Expected all tensors to be on the same device" as soon as the two tensors are combined (e.g. via torch.cat).

The culprit can be found around this line.

To Reproduce

import torch

import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import negative_sampling

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
])

dataset = Planetoid(root='/tmp/Planetoid', name='Cora', transform=transform)

data = dataset[0]

transform = T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,)
train_data, val_data, test_data = transform(data) # throws error

Example pulled from https://github.com/pyg-team/pytorch_geometric/blob/master/examples/link_pred.py

Expected behavior

I expect the sampled negative edges to end up on the same device as the original input tensors.
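
In other words (my own illustration, assuming the edge_label_index attributes that RandomLinkSplit attaches to each split), a check along these lines should pass:

# Hypothetical check, not part of the original report: every split returned by
# RandomLinkSplit should live on the same device as the input data.
for split in (train_data, val_data, test_data):
    assert split.edge_label_index.device == data.x.device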

Environment

  • PyG version (torch_geometric.__version__): 2.0.2
  • PyTorch version: (torch.__version__): 1.9.0+cu111
  • OS (e.g., Linux): Linux
  • Python version (e.g., 3.9): 3.9
  • CUDA/cuDNN version: cuda
  • How you installed PyTorch and PyG (conda, pip, source): pip
  • Any other relevant information (e.g., version of torch-scatter):

Additional context

We can see from this line that the negative sampling method does not take a device parameter.

I originally opened a StackOverflow post, and a user there suggested I raise this officially. I then took a closer look, and it is indeed a bug.

Inspecting how negative_sampling works further, I confirmed the buggy behavior: the returned negative samples are on the CPU device.

import torch

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import negative_sampling

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
])

dataset = Planetoid(root='/tmp/Planetoid', name='Cora', transform=transform)

data = dataset[0]

neg_edge_index = negative_sampling(data.edge_index, data.x.size(),
                                   num_neg_samples=500,
                                   method='sparse')

data.x.device, neg_edge_index.device

outputs

=> (device(type='cuda', index=0), device(type='cpu'))
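
As a manual workaround until a fix lands (my own suggestion, not part of the original report), the CPU-sampled edges can simply be moved back to the data's device before they are combined with on-device tensors:

import torch
from torch_geometric.utils import negative_sampling

# Workaround sketch: move the negative samples to the device of the input data
# so that subsequent torch.cat calls see tensors on a single device.
neg_edge_index = negative_sampling(data.edge_index, num_nodes=data.num_nodes,
                                   num_neg_samples=500, method='sparse')
neg_edge_index = neg_edge_index.to(data.x.device)

edge_index = torch.cat([data.edge_index, neg_edge_index], dim=1)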
@Sticksword Sticksword added the bug label Dec 7, 2021
@Sticksword
Author

Sticksword commented Dec 7, 2021

Possible solutions

  • we can take the device of the sampled tensor and apply it to the negative sample tensor (sketched below)
  • we can add a parameter that allows RandomLinkSplit to accept a device and then apply that device to the negative sample tensor

Open to other suggestions too! This is my first time opening an issue, so I'm learning the ropes as I go.
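
A rough sketch of the first option (hypothetical names, not the actual patch): wherever RandomLinkSplit calls negative_sampling, the result could be moved to the device of the edge index it was sampled from before any concatenation happens:

# Hypothetical sketch of option 1; edge_index, num_nodes, and num_neg stand in
# for whatever RandomLinkSplit uses internally.
neg_edge_index = negative_sampling(edge_index, num_nodes=num_nodes,
                                   num_neg_samples=num_neg, method='sparse')
# Take the device from the sampled tensor and apply it to the negative samples.
neg_edge_index = neg_edge_index.to(edge_index.device)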

@saiden89
Contributor

saiden89 commented Dec 7, 2021

Thanks for opening the issue.
As you have experienced, tensors must reside on the same device before performing certain operations, so the first solution, at least for me, is the one that feels more natural. Furthermore, allowing the user to specify a (possibly wrong) device has no real benefit, and may as well introduce bugs if the user inadvertently inputs the wrong device.

@rusty1s
Member

rusty1s commented Dec 7, 2021

Thanks for spotting. This is now fixed in master, see 8630831.

@rusty1s rusty1s closed this as completed Dec 7, 2021
@Sticksword
Author

awesome, thank you!
