Random hangs and failures when sending tensors that are split using torch.split in a JoinableQueue #95606

Closed
RobertoLat opened this issue Feb 27, 2023 · 13 comments
Labels
high priority module: multiprocessing Related to torch.multiprocessing triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@RobertoLat

RobertoLat commented Feb 27, 2023

🐛 Random hangs and failures when sending tensors that are split using torch.split in a JoinableQueue

Splitting tensors using torch.split and sending them to processes through a JoinableQueue seems to cause random errors and hangs on 2.0.0.dev20230130+cu116, while the same code works perfectly fine on 1.9.1+cu102.

I tried to make the reproducing code as small as I could. The key ingredients are torch.split and JoinableQueue.

The following script hangs on my CUDA machine using PyTorch 2.0 while it completes successfully on PyTorch 1.9.

import os
import sys
import tempfile

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def setup(rank: int, world_size: int) -> None:
    backend = 'nccl' if torch.cuda.is_available() else 'gloo'
    dist.init_process_group(backend, init_method='tcp://{}'.format('127.0.0.1:23456'), rank=rank, world_size=world_size)


def cleanup() -> None:
    dist.destroy_process_group()


def demo_basic(rank: int, queue: mp.JoinableQueue, world_size: int) -> None:
    setup(rank, world_size)
    device = f'cuda:{rank}' if torch.cuda.is_available() else 'cpu'

    while True:
        batch = queue.get()
        batch = batch.to(device)

        try:
            negative_in_batch = batch.lt(0).all().item()
            if negative_in_batch:
                print("Found negative in batch", sys.stderr)

        finally:
            queue.task_done()


def split_batch(batch: torch.Tensor, world_size: int) -> torch.Tensor:
    return torch.split(batch, batch.shape[0] // world_size) # if I use torch.split(batch, batch.shape[0] // world_size).clone() instead no error is observed


def run_demo(world_size: int) -> None:
    print(torch.__version__, file=sys.stderr)
    num_batches = 10000
    batch_size = 64

    ctx = mp.get_context('spawn')
    queues = [ctx.JoinableQueue() for _ in range(world_size)]
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    processes = [ctx.Process(target=demo_basic, args=(i, queues[i], world_size)) for i in range(world_size)]

    for p in processes:
        p.start()

    for i in range(num_batches):
        large_batch = torch.randint(100000, size=(batch_size,))
        batches = split_batch(large_batch, world_size) # if I remove this line and send the large batch instead no error is observed
        print(f'queuing batch {i}', file=sys.stderr)

        for batch, queue in zip(batches, queues):
            queue.put(batch)

        for q in queues:
            q.join()

    for p in processes:
        p.terminate()


def main() -> None:
    run_demo(4)


if __name__ == '__main__':
    main()

On CPU the behaviour is more random: I sometimes observe the following error after the code has been running for a while:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/queues.py", line 239, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "/opt/conda/lib/python3.8/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/reductions.py", line 358, in reduce_storage
    metadata = storage._share_filename_cpu_()
RuntimeError: Trying to resize storage that is not resizable

while other times the code runs successfully.

I verified that the code runs fine on 1.9.1+cu102 on both CPU and GPU, but I don't know about other versions.

Versions

Cuda environment:
PyTorch version: 2.0.0.dev20230130+cu116
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.25.0
Libc version: glibc-2.27

Python version: 3.8.12 (default, Oct 12 2021, 13:49:34) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-4.14.301-224.520.amzn2.x86_64-x86_64-with-glibc2.17
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 32
On-line CPU(s) list: 0-31
Thread(s) per core: 2
Core(s) per socket: 16
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 79
Model name: Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz
Stepping: 1
CPU MHz: 2700.202
CPU max MHz: 3000.0000
CPU min MHz: 1200.0000
BogoMIPS: 4600.03
Hypervisor vendor: Xen
Virtualization type: full
L1d cache: 32K
L1i cache: 32K
L2 cache: 256K
L3 cache: 46080K
NUMA node0 CPU(s): 0-31
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single pti fsgsbase bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx xsaveopt

Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] numpyro==0.6.0
[pip3] pytorch-triton==2.0.0+0d7e753227
[pip3] torch==2.0.0.dev20230130+cu116
[pip3] torchaudio==2.0.0.dev20230130+cu116
[pip3] torchvision==0.15.0.dev20230130+cu116
[conda] blas 1.0 mkl
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py38h7f8727e_0
[conda] mkl_fft 1.3.1 py38hd3c417c_0
[conda] mkl_random 1.2.2 py38h51133e4_0
[conda] numpy 1.24.1 pypi_0 pypi
[conda] numpyro 0.6.0 pypi_0 pypi
[conda] pytorch-triton 2.0.0+0d7e753227 pypi_0 pypi
[conda] torch 2.0.0.dev20230130+cu116 pypi_0 pypi
[conda] torchaudio 2.0.0.dev20230130+cu116 pypi_0 pypi
[conda] torchvision 0.15.0.dev20230130+cu116 pypi_0 pypi

CPU environment:
PyTorch version: 2.0.0.dev20230130+cu116
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.25.0
Libc version: glibc-2.27

Python version: 3.8.12 (default, Oct 12 2021, 13:49:34) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.15.49-linuxkit-x86_64-with-glibc2.17
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 6
On-line CPU(s) list: 0-5
Thread(s) per core: 1
Core(s) per socket: 6
Socket(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 158
Model name: Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
Stepping: 10
CPU MHz: 2591.608
BogoMIPS: 5183.21
L1d cache: 32K
L1i cache: 32K
L2 cache: 256K
L3 cache: 12288K
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch fsgsbase bmi1 avx2 smep bmi2 erms rdseed adx smap clflushopt xsaveopt xsavec xgetbv1 arat

Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] numpyro==0.6.0
[pip3] pytorch-triton==2.0.0+0d7e753227
[pip3] torch==2.0.0.dev20230130+cu116
[pip3] torchaudio==2.0.0.dev20230130+cu116
[pip3] torchvision==0.15.0.dev20230130+cu116
[conda] blas 1.0 mkl
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py38h7f8727e_0
[conda] mkl_fft 1.3.1 py38hd3c417c_0
[conda] mkl_random 1.2.2 py38h51133e4_0
[conda] numpy 1.24.1 pypi_0 pypi
[conda] numpyro 0.6.0 pypi_0 pypi
[conda] pytorch-triton 2.0.0+0d7e753227 pypi_0 pypi
[conda] torch 2.0.0.dev20230130+cu116 pypi_0 pypi
[conda] torchaudio 2.0.0.dev20230130+cu116 pypi_0 pypi
[conda] torchvision 0.15.0.dev20230130+cu116 pypi_0 pypi

cc @ezyang @gchanan @zou3519 @VitalyFedyunin @ejguan

@vadimkantorov
Contributor

Might be related: #95278: the outputs from torch.split are views onto the original (large) storage. So it's worth checking whether any copying of the full storage is happening... If that's the case, it might be worth opening a feature request for something like split_copy / chunk_copy, or introducing a copy=True argument, so that individual chunks can be copied without holding onto the original storage.
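
For reference, a minimal sketch of the copy-based workaround being suggested here; split_copy is just an illustrative helper, not an existing PyTorch API. Cloning each chunk gives it its own storage, so queueing a chunk no longer drags the whole parent storage along.

import torch

def split_copy(batch: torch.Tensor, chunk_size: int):
    # torch.split returns views into batch's storage; clone() materialises each
    # chunk into its own storage, so nothing keeps a reference to the parent.
    return [chunk.clone() for chunk in torch.split(batch, chunk_size)]

# usage: four independent 16-element tensors instead of four views of one storage
chunks = split_copy(torch.randint(100_000, size=(64,)), 16)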

@RobertoLat
Author

Maybe I don't fully understand, but why would it hang even if it were copying the entire tensor? The tensor is tiny (64 elements).
Has the behaviour of torch.split changed since version 1.9? It works fine with 1.9.

@mikaylagawarecki mikaylagawarecki added module: multiprocessing Related to torch.multiprocessing triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Mar 1, 2023
@imurray

imurray commented Mar 2, 2023

Inserting large_batch.share_memory_() after creating large_batch fixes the problem on CPU for me.

queue.put(batch) causes batch to be shared (but asynchronously, not immediately), and because batch is a view into large_batch it actually shares all of the memory for large_batch. It seems that attempting to share large_batch's memory more than once, concurrently, is unsafe and sometimes causes crashes.

Is it possible for PyTorch to make putting views of unshared memory into a queue safe, or should the documentation warn about this case?
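
For concreteness, a sketch of the producer loop from the repro above with that workaround applied (batch_size, world_size and queues are as defined in the original script). This assumes the explanation above is right, i.e. that the crashes come from the parent storage being shared asynchronously while views of it are still in use:

    large_batch = torch.randint(100_000, size=(batch_size,))
    large_batch.share_memory_()  # move the parent storage to shared memory now, synchronously
    batches = torch.split(large_batch, batch_size // world_size)
    for batch, queue in zip(batches, queues):
        # the views already point at shared memory, so put() has nothing left to move
        queue.put(batch)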

@imurray

imurray commented Mar 4, 2023

This issue was caused by PR #85389. That change improved performance in some cases, but this issue shows it has also caused crashes in other code that worked in PyTorch 1.12 and earlier. (Mentioning @ezhang887 and @albanD who apparently discussed that PR.)

Reverting the change (as below) makes the code in this issue work again (tested on current git HEAD).

diff --git a/torch/csrc/StorageSharing.cpp b/torch/csrc/StorageSharing.cpp
index bb66bfa3af5..cf7e40b97de 100644
--- a/torch/csrc/StorageSharing.cpp
+++ b/torch/csrc/StorageSharing.cpp
@@ -111,11 +111,7 @@ static PyObject* THPStorage_shareFilename(PyObject* _self, PyObject* noargs) {
         /*resizable=*/false));
 
     at::Storage _self_aten = torch::createStorage(_self);
-    {
-      // Copying into shared memory can be slow, so release the GIL
-      pybind11::gil_scoped_release no_gil;
-      at::storage_copy(new_storage, _self_aten);
-    }
+    at::storage_copy(new_storage, _self_aten);
 
     std::swap(*storage, *new_storage.unsafeGetStorageImpl());
     ctx = THManagedMapAllocator::fromDataPtr(storage->data_ptr());
@@ -200,11 +196,7 @@ static PyObject* THPStorage_shareFd(PyObject* _self, PyObject* noargs) {
   } else {
     at::Storage new_storage(at::new_shm_fd_storage(storage->nbytes()));
     at::Storage _self_aten = torch::createStorage(_self);
-    {
-      // Copying into shared memory can be slow, so release the GIL
-      pybind11::gil_scoped_release no_gil;
-      at::storage_copy(new_storage, _self_aten);
-    }
+    at::storage_copy(new_storage, _self_aten);
 
     std::swap(*storage, *new_storage.unsafeGetStorageImpl());
     ctx = at::MapAllocator::fromDataPtr(storage->data_ptr());

@albanD
Collaborator

albanD commented Mar 10, 2023

Thanks for the detailed repro. I can reproduce locally, but I'm not sure what the root cause is, to be honest. I re-checked, and storage_copy is safe to do without the GIL.
So there must be some weird interaction between views and this...

@cchan
Contributor

cchan commented Mar 10, 2023

How are you checking that storage_copy is thread-safe? (Coincidentally, our team hit the same issue recently too, lol)

@albanD
Collaborator

albanD commented Mar 11, 2023

It is a pure C++ implementation that doesn't rely on any Python objects:

C10_EXPORT void storage_copy(
    c10::Storage& dst,
    const c10::Storage& src,
    bool non_blocking) {
  auto dst_options = c10::TensorOptions().device(dst.device()).dtype(at::kByte);
  auto dst_t = at::empty({0}, {}, dst_options).set_(dst);
  auto src_options = c10::TensorOptions().device(src.device()).dtype(at::kByte);
  auto src_t = at::empty({0}, {}, src_options).set_(src);
  dst_t.copy_(src_t, non_blocking);
}

@cchan
Contributor

cchan commented Mar 11, 2023

I mean that even in pure C++, perhaps set_(), for example, is not a thread-safe operation? The GIL was implicitly locking access to this C++ code, so perhaps it was never thread-safe and we're only seeing the problem now because of the GIL release.

It's not obvious to me why it would be unsafe though.

@imurray

imurray commented Mar 11, 2023

The issue is that queue.put(batch) shares the memory in batch if the memory is not already shared, and does that in a separate thread. If the GIL is released, then other Python code that uses batch can run while the object is being moved and the old memory for the object is being freed.

In the original code in this issue, the batch objects are views into the same object. On each loop iteration, the parent large_batch object is shared if it isn't already. When the second iteration of the loop comes along, the thread kicked off in the first iteration hasn't always finished, so we sometimes start to try to share the object again. Then multiple things could go wrong; a double free looks likely (I haven't checked what actually happens).

Another way to trigger the issue is to send the same large_batch (which isn't actually that large) to multiple queues. Again, if it's not already shared then the multiple put calls will try to share it more than once.

However, considering views and the share code being called twice is probably a distraction. Any attempt to use batch while its memory is being shared could lead to crashes. The CPU-only code below is another example: here there is no view, so all the batches are separate objects, but crashes still occur because the computation on batch (in the line after putting it into the queue) happens while the underlying memory is being moved around.

import os
import sys
import tempfile

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def demo_basic(rank: int, queue: mp.JoinableQueue, world_size: int) -> None:
    dist.init_process_group(backend='gloo', init_method='tcp://{}'.format('127.0.0.1:23456'), rank=rank, world_size=world_size)
    while True:
        batch = queue.get()
        assert(batch.sum() == batch.numel())
        queue.task_done()

def split_batch(batch: torch.Tensor, world_size: int) -> torch.Tensor:
    # The .clone() is so there are no views to confuse the issue:
    return [x.clone() for x in torch.split(batch, batch.shape[0] // world_size)]

def run_demo(world_size: int) -> None:
    print(torch.__version__, file=sys.stderr)
    num_batches = 10000
    large_batch_size = 64

    ctx = mp.get_context('spawn')
    queues = [ctx.JoinableQueue() for _ in range(world_size)]
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    processes = [ctx.Process(target=demo_basic, args=(i, queues[i], world_size)) for i in range(world_size)]

    for p in processes:
        p.start()

    for i in range(num_batches):
        large_batch = torch.ones(large_batch_size)
        batches = split_batch(large_batch, world_size)

        for batch, queue in zip(batches, queues):
            #batch.share_memory_() # This line would prevent the crashes
            queue.put(batch)
            assert(batch.sum() == batch.numel()) # trying to use batch when it might be in the middle of being moved.

        for q in queues:
            q.join()

    for p in processes:
        p.terminate()

if __name__ == '__main__':
    run_demo(4)
    print('Finished without crashing.')

@ezyang
Contributor

ezyang commented Mar 11, 2023

I suggest we yank the optimization PR for now.

@albanD
Collaborator

albanD commented Mar 13, 2023

@imurray

Note that `assert(batch.sum() == batch.numel())  # trying to use batch when it might be in the middle of being moved` is unsafe even with the old code that was acquiring the GIL: the sum() call itself releases the GIL, so the two can overlap even if the move didn't finish. And even if neither released the GIL, there is still a very small chance that the GIL would be passed from one thread to the other and crash in the same way.

The root problem here is that queue.put(batch) looks like a nice out-of-place op, but it is actually doing an in-place op on a Tensor in a separate thread.
In general, the Tensor object is not expected to be thread-safe w.r.t. in-place ops (like most other Python and C++ objects), so modifying it in place while you read it in a different thread, without locking, is expected to go bad. I'm not sure there is much we can do about this.
I tested this locally and, even with the GIL lock added back, it still fails.

The fact that calling share_memory_ explicitly fixes it is because the two calls are then made from the same thread and so are naturally serialized.

While it does not fix the entire problem, I propose #96664 to keep the benefits of parallel per-storage reduction without the issue raised at the top of this thread.

pytorchmergebot pushed a commit that referenced this issue Mar 14, 2023
To achieve this, a per-StorageImpl lock (it was keyed on data_ptr in the previous version of this PR, but moved to StorageImpl to ensure stability of the key before/after sharing) is created when we are about to share a storage, and all other calls to share that memory wait on this lock before moving forward.
This does NOT make the call generally thread-safe, as any call that is not sharing memory will still race and lead to UB.

This ensures that the sample from @RobertoLat in #95606 works fine.
This does NOT fix the example from @imurray in the same issue, as the share still races with the `.sum()` call. This race is expected, and there is no easy way for us to make it work, I'm afraid (see the issue for more details).

Pull Request resolved: #96664
Approved by: https://github.com/colesbury
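
For illustration only, a rough Python-level sketch of the serialization idea the commit message describes. This is not the actual C++ implementation; the lock registry and helper names are simplifications for the sketch, and it assumes PyTorch >= 2.0 for Tensor.untyped_storage():

import threading
import torch

_registry_lock = threading.Lock()
_share_locks = {}  # one lock per underlying storage

def _lock_for(storage) -> threading.Lock:
    # Keyed on the storage's base address here for simplicity; as noted above,
    # the real fix keys on the StorageImpl so the key stays stable across sharing.
    key = storage.data_ptr()
    with _registry_lock:
        return _share_locks.setdefault(key, threading.Lock())

def share_serialized(tensor: torch.Tensor) -> torch.Tensor:
    # Concurrent attempts to share the same storage queue up on one lock instead
    # of both trying to move the same memory at once (other concurrent uses of
    # the tensor are still unsafe, as discussed above).
    with _lock_for(tensor.untyped_storage()):
        tensor.share_memory_()
    return tensor
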
@byronyi

byronyi commented Mar 23, 2023

Would #83623 be relevant, too?

@albanD
Collaborator

albanD commented Mar 27, 2023

@byronyi yes, this is the PR that uncovered the issue by making the race a lot more common.

#96664 fixes the hangs (note that the mix of concurrent writes/reads cannot really be fixed, though).

@albanD albanD closed this as completed Mar 27, 2023