Segmentation fault (core dumped) #33

Closed
xiaojinglu opened this issue Aug 29, 2019 · 13 comments

@xiaojinglu

When I run python reddit.py in pytorch_geometric, which uses neighbor_sampler from torch_cluster, it fails with 'Segmentation fault (core dumped)'. My machine runs Ubuntu 16.04.5 with Python 3.6 and CUDA 10.0.

@rusty1s
Owner

rusty1s commented Aug 29, 2019

Can you test if the test suite of torch-cluster works for you?

@xiaojinglu
Author

xiaojinglu commented Sep 2, 2019

Can you test if the test suite of torch-cluster works for you?

I ran python setup.py install and then python setup.py test, and all tests passed.

test/test_fps.py ....
test/test_graclus.py ............................
test/test_grid.py ........................................................
test/test_knn.py ........
test/test_nearest.py ....
test/test_radius.py ........
test/test_rw.py ..
test/test_sampler.py .

----------- coverage: platform linux, python 3.6.8-final-0 -----------
Name                        Stmts   Miss  Cover
-----------------------------------------------
torch_cluster/__init__.py      10      0   100%
torch_cluster/fps.py           10      0   100%
torch_cluster/graclus.py       10      0   100%
torch_cluster/grid.py           9      0   100%
torch_cluster/knn.py           37      0   100%
torch_cluster/nearest.py       21      0   100%
torch_cluster/radius.py        31      0   100%
torch_cluster/rw.py            11      0   100%
torch_cluster/sampler.py        8      0   100%
-----------------------------------------------
TOTAL                         147      0   100%


============================= 111 passed in 24.19s =============================

@rusty1s
Owner

rusty1s commented Sep 2, 2019

Okay, I did not expect this. So the installation succeeded. The reddit.py example works for me, so I am not sure what may cause this issue. Could you do some debugging on your side?

@Chandrahasd

I am also facing the same issue.
My graph has 643474 nodes and 1222204 edges.
I am trying to use NeighborSampler to generate batches of size 128. I tried using the full neighborhood (size=1.0) as well as subsampling (e.g. size=16).
In both cases, the code terminates with a segmentation fault at this line of the pytorch_cluster/torch_cluster/sampler.py file.

The reddit.py example from PyTorch Geometric works fine for me.
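To make the two size settings above concrete, here is a minimal pure-Python sketch of per-node neighbor sampling over a CSR-style layout (node v's neighbors are neighbors[offsets[v]:offsets[v + 1]]). It assumes, as the NeighborSampler documentation of this era describes, that an int size is a neighbor count and a float size is a fraction of each node's neighbors; rounding the fraction up is an assumption, and sample_neighbors is a hypothetical helper, not the torch_cluster implementation.

```python
import math
import random


def sample_neighbors(start, offsets, neighbors, size, rng=None):
    """Sample neighbors of each node in `start` (hypothetical helper).

    Node v's neighbors are neighbors[offsets[v]:offsets[v + 1]] (CSR layout).
    A float `size` is read as a fraction of v's degree (rounded up); an int
    `size` is read as a maximum neighbor count.
    """
    if rng is None:
        rng = random.Random()
    sampled = {}
    for v in start:
        nbrs = neighbors[offsets[v]:offsets[v + 1]]
        if isinstance(size, float):
            k = math.ceil(size * len(nbrs))  # fraction of the degree
        else:
            k = size  # fixed cap on the neighbor count
        k = min(k, len(nbrs))
        # Sample without replacement; a node with no neighbors yields [].
        sampled[v] = rng.sample(nbrs, k)
    return sampled


# Toy graph: 0 -> [1, 2, 3], 1 -> [0], 2 -> [], 3 -> [0].
offsets = [0, 3, 4, 4, 5]
neighbors = [1, 2, 3, 0, 0]

full = sample_neighbors([0, 1, 2, 3], offsets, neighbors, size=1.0)
half = sample_neighbors([0, 1, 2, 3], offsets, neighbors, size=0.5)
capped = sample_neighbors([0, 1, 2, 3], offsets, neighbors, size=2)
print(full, half, capped)
```

With size=1.0 every neighbor is kept (in random order), while size=2 keeps at most two neighbors per node regardless of degree.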

@rusty1s
Owner

rusty1s commented Sep 2, 2019

Can you send me your data.edge_index per mail? Curious to debug this!

@xiaojinglu
Author

My code also terminates with a segmentation fault at this line when I run python reddit.py.

@Chandrahasd

Can you send me your data.edge_index per mail? Curious to debug this!

@rusty1s, I shared the data.edge_index file over email.
Thank you for looking into this issue.

@rusty1s
Owner

rusty1s commented Sep 3, 2019

I received your file and will look into it. An initial guess is that it may be due to isolated nodes.

@Chandrahasd

I thought so too.
I did try adding self-loops to avoid zero-degree nodes, but that did not prevent the error.
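The workaround described above, giving every zero-degree node a self-loop, can be sketched in pure Python on an edge_index stored as two parallel lists (row, col). The helper add_self_loops_to_isolated is hypothetical and not part of PyG (torch_geometric.utils.add_self_loops exists, but it adds a loop to every node). num_nodes must be passed explicitly, because nodes without edges cannot be recovered from the edge list alone.

```python
def add_self_loops_to_isolated(row, col, num_nodes):
    """Append a self-loop (v, v) for every node with no incident edges.

    row/col are parallel lists forming edge_index; num_nodes must be given
    explicitly, since isolated nodes cannot be inferred from the edges.
    """
    touched = set(row) | set(col)
    isolated = [v for v in range(num_nodes) if v not in touched]
    return row + isolated, col + isolated, isolated


row, col = [0, 1, 1], [1, 0, 2]
new_row, new_col, isolated = add_self_loops_to_isolated(row, col, num_nodes=5)
print(isolated)  # [3, 4]
print(new_row)   # [0, 1, 1, 3, 4]
print(new_col)   # [1, 0, 2, 3, 4]
```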

@rusty1s
Owner

rusty1s commented Sep 4, 2019

Works flawlessly for me:

import torch
from torch_geometric.read import read_txt_array
from torch_geometric.data import Data
from torch_geometric.data import NeighborSampler

edge_index = read_txt_array('edge_indices.txt', sep=' ', dtype=torch.long)
edge_index = edge_index.t().contiguous()
data = Data(edge_index=edge_index)
print(data)

loader = NeighborSampler(data, size=1.0, num_hops=2, batch_size=100,
                         shuffle=False, add_self_loops=False)

for data_flow in loader():
    print(data_flow)

Can you run the code and see if it works for you?

@Chandrahasd

This code works.
The problem occurs when the edge_index, the data object, or both are in GPU memory.

import torch
from torch_geometric.read import read_txt_array
from torch_geometric.data import Data
from torch_geometric.data import NeighborSampler

# Put edge_index in GPU memory
edge_index = read_txt_array('edge_indices.txt', sep=' ', dtype=torch.long).cuda()
edge_index = edge_index.t().contiguous()
# Put data object to GPU memory
data = Data(edge_index=edge_index).to('cuda')
print(data)

loader = NeighborSampler(data, size=1.0, num_hops=2, batch_size=100,
                         shuffle=False, add_self_loops=False)

for data_flow in loader():
    print(data_flow)

Output:

torch_geometric/data/data.py:177: UserWarning: The number of nodes in your data object can only be inferred by its edge indices, and hence may result in unexpected batch-wise behavior, e.g., in case there exists isolated nodes. Please consider explicitly setting the number of nodes for this data object by assigning it to data.num_nodes.
  warnings.warn(__num_nodes_warn_msg__.format('edge'))
Segmentation fault (core dumped)
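The UserWarning above concerns a separate pitfall: without an explicit num_nodes, the node count is inferred from edge_index, which (as far as this sketch assumes) amounts to the largest node index plus one. The pure-Python sketch below shows why that undercounts when the highest-numbered nodes are isolated; infer_num_nodes is a hypothetical helper, not the PyG function.

```python
def infer_num_nodes(row, col):
    """Infer a node count from edge indices as max index + 1 (hypothetical helper)."""
    return max(max(row), max(col)) + 1 if row else 0


row, col = [0, 1, 1], [1, 0, 2]
true_num_nodes = 5  # nodes 3 and 4 exist but have no edges
print(infer_num_nodes(row, col))  # 3, so nodes 3 and 4 are silently dropped
```

This is why the warning recommends assigning data.num_nodes explicitly.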

@rusty1s
Owner

rusty1s commented Sep 9, 2019

This makes sense, and we should add an assertion to catch this. NeighborSampler is a CPU-only operation and hence expects the data to be on the CPU.
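The assertion suggested above could look roughly like the following sketch. It fails fast with a readable error instead of letting the crash happen, which is most likely a CPU kernel reading memory that lives on the GPU. The helper name assert_on_cpu is hypothetical. It relies only on the is_cuda attribute that torch.Tensor exposes, read via getattr so that the sketch itself runs without PyTorch. The practical workaround is to keep edge_index and the data object on the CPU (e.g. data = data.to('cpu')) when constructing the sampler and to move each sampled batch to the GPU afterwards.

```python
def assert_on_cpu(**tensors):
    """Raise a clear error if any named tensor lives on a CUDA device.

    Hypothetical helper. Uses getattr so objects without an `is_cuda`
    attribute are treated as CPU-resident; torch.Tensor provides `is_cuda`.
    """
    on_gpu = [name for name, t in tensors.items() if getattr(t, "is_cuda", False)]
    if on_gpu:
        raise ValueError(
            "NeighborSampler is a CPU op; move these to the CPU first "
            "(e.g. with .cpu()): " + ", ".join(on_gpu)
        )


# Usage with PyTorch (not executed here):
#   assert_on_cpu(edge_index=data.edge_index)
```

Raising ValueError rather than using a bare assert keeps the check active even when Python runs with -O.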

@github-actions

This issue had no activity for 6 months. It will be closed in 2 weeks unless there is some new activity. Is this issue already resolved?
