
PyG's DataLoader hangs when I use it in a subprocess created by the "fork" start method #3565

Closed · Enolerobotti opened this issue Nov 25, 2021 · 6 comments · Fixed by #3566
🐛 Bug

There are two identical, completely independent functions that use torch_geometric.loader.DataLoader. If we run function 1 in the main process and then function 2 in a subprocess created by the "fork" start method, the program hangs (in the example below, while fetching the first batch from the loader). Note that everything works with the standard torch.utils.data.DataLoader, and everything also works with any other start method. The program likewise works if we run only function 2 in a subprocess created by any start method (without calling function 1 first). Most interestingly, it also runs fine if we move the tensors to the GPU.
I attach example code for the broken case below.

To Reproduce

Steps to reproduce the behavior:

import queue
import torch
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from multiprocessing import Queue, get_context


def worker_function(_queue: Queue = None):
    # Build ten small graphs with random node features and an empty edge index,
    # then fetch the first batch from PyG's DataLoader.
    data = [Data(x=torch.rand(size=(100, 100)), edge_index=torch.randint(100, size=(0, 100))) for _ in range(10)]
    dataloader = DataLoader(data, num_workers=0, batch_size=len(data)//2)
    dataloader_iter = enumerate(dataloader)
    idx, batch = next(dataloader_iter)  # hangs here in the broken case
    if _queue is not None:
        _queue.put(idx)
    else:
        return idx


def completely_independent_function(_queue: Queue = None):
    # Actually, it is a copy of the function above
    data = [Data(x=torch.rand(size=(100, 100)), edge_index=torch.randint(100, size=(0, 100))) for _ in range(10)]
    dataloader = DataLoader(data, num_workers=0, batch_size=len(data)//2)
    dataloader_iter = enumerate(dataloader)
    idx, batch = next(dataloader_iter)
    if _queue is not None:
        _queue.put(idx)
    else:
        return idx


def main(context: str):
    # Run worker_function in a subprocess created with the given start method.
    ctx = get_context(context)
    q = ctx.Queue()
    p = ctx.Process(target=worker_function, args=(q,))
    p.start()
    result = q.get(timeout=10)
    p.join()
    return result


if __name__ == '__main__':
    print("Start")
    i = completely_independent_function()
    print("Finished completely_independent_function in the main process")
    for method in ['fork', 'forkserver', 'spawn']:
        try:
            r = main(method)
            assert i == r
            print(f"The '{method}' method succeed")
        except queue.Empty:
            print(f"The '{method}' hangs due to unknown reason")

Expected behavior

The queue is not empty, i.e., the timeout value could be omitted.

Environment

  • OS: Ubuntu 20.04
  • Python version: 3.9.7
  • PyTorch version: 1.9.0 and 1.10.0
  • CUDA/cuDNN version: 10.2 (although I actually run on the CPU!)
  • GCC version: 9.3.0 (PyG installed from PyPI, not compiled from source)
  • Any other relevant information:
  • torch-geometric version 2.0.2 and 2.0.1
  • torch_scatter version 2.0.9 and 2.0.8
  • torch_sparse version 0.6.12
  • torch_cluster version 1.5.9
  • torch_spline_conv version 1.2.1

Additional context

I faced this issue while running the unit tests of a library I develop. Several of the tests check parallel runs. When I ran a single unit test for the fork start method, it passed; however, it hung when I ran multiple tests via python -m unittest discover. I use PyTorch Lightning for model training, and the lines containing next(enumerate(dataloader)) come from there. So I could not find a good workaround other than using the other start methods, which are more expensive (see the sketch below).
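
For anyone who needs a stop-gap until the fix lands, the start method can be forced globally at program startup. A minimal sketch, assuming the slower startup of 'spawn' is acceptable:

import multiprocessing as mp

if __name__ == '__main__':
    # Force 'spawn' for all subsequently created processes. Children re-import
    # the main module instead of inheriting the parent's state, which sidesteps
    # the fork-related hang at the cost of slower process startup.
    mp.set_start_method('spawn', force=True)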


rusty1s commented Nov 25, 2021

That's super interesting. It's indeed caused by the torch.repeat_interleave call. I need to investigate that. Completely removing the DataLoader and just calling repeat_interleave in worker_function leads to the same hang.
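
For reference, here is a minimal sketch of that reduction (my reconstruction of the described experiment, not code from the thread). Running the op once in the parent and then in a forked child reproduces the hang on the affected setups; without the parent call, the child finishes normally:

import torch
from multiprocessing import get_context


def worker_function():
    # Hangs here on affected setups when the parent already ran the op.
    torch.repeat_interleave(torch.tensor([1, 2, 3]), repeats=2)
    print("child finished")


if __name__ == '__main__':
    # Calling the op in the parent first appears to initialize internal
    # threading state that is not fork-safe.
    torch.repeat_interleave(torch.tensor([1, 2, 3]), repeats=2)
    p = get_context('fork').Process(target=worker_function)
    p.start()
    p.join(timeout=10)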


rusty1s commented Nov 25, 2021

As this seems to be a PyTorch issue, I'm not sure if it's worth fixing on our end. WDYT?

@rusty1s rusty1s linked a pull request Nov 25, 2021 that will close this issue

rusty1s commented Nov 25, 2021

Created a PR nonetheless. Please let me know what you think.

@Enolerobotti (Contributor, Author)

Hi @rusty1s!

Thank you for your PR! This indeed seems to be a PyTorch issue. However, it is good to avoid using repeat_interleave, because users may run different PyTorch versions and could still hit this error (a repeat_interleave-free construction is sketched below).

Cheers, Artem.
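
A minimal sketch of such a repeat_interleave-free construction (my illustration of the idea, not necessarily the exact code of the linked PR): the repeated index vector is composed from torch.full and torch.cat instead of the native op.

from typing import List

import torch


def repeat_interleave(repeats: List[int], device=None) -> torch.Tensor:
    # Same result as torch.repeat_interleave(torch.tensor(repeats)), but
    # composed from ops that avoid the problematic native kernel.
    outs = [torch.full((n,), i, dtype=torch.long, device=device)
            for i, n in enumerate(repeats)]
    return torch.cat(outs, dim=0)


print(repeat_interleave([2, 3, 1]))  # tensor([0, 0, 1, 1, 1, 2])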

@Enolerobotti (Contributor, Author)

BTW, I can confirm that the PR fixes the error.

@Enolerobotti (Contributor, Author)

I raised the issue on the PyTorch repo.
