The result of gloo all_gather error #20421

Closed
qijianan777 opened this issue May 13, 2019 · 4 comments

qijianan777 commented May 13, 2019

I use Gloo for model parallelism, and when I use all_gather the result is wrong.

There are two processes. I expect the all_gather result to be [Tensor1, Tensor2], but the actual result is [Tensor1, Tensor1].
[screenshot: contents of Tensor2]
[screenshot: the all_gather result; the second tensor in the result should equal Tensor2]
However, when I instead create the two tensors with torch.reshape(torch.tensor(range(xxxx), dtype=torch.float32), [16, 16, 16]) and all_gather them, the result is correct.
The code is:

for _ in range(stage.get_devices_num()):
    gather_tensor.append(torch.zeros_like(in_slice))
dist.all_gather(gather_tensor, in_slice.contiguous(), group=group)

Environment:
macOS
pytorch 1.0.1
pytorch-cpu 1.1.0
numpy 1.16.2

PS: We use torch.chunk to split the tensor along dim 0 and all_gather the chunked tensors with Gloo, and the all_gather result is wrong. I think that although I call contiguous() to make the memory contiguous, it has no effect after chunking the tensor along dim 0.
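
A minimal sketch (not part of the original report) illustrating why contiguous() has no effect here: a chunk taken along dim 0 is already contiguous, so contiguous() returns the same view of the parent storage instead of making a copy, whereas clone() does copy.

import torch

x = torch.arange(4, dtype=torch.float32).reshape(2, 2)
b = torch.chunk(x, 2, dim=0)[1]     # second half of x along dim 0

print(b.is_contiguous())            # True, so...
print(b.contiguous() is b)          # ...contiguous() returns the same view, no copy
print(b.storage_offset())           # 2: b starts 2 elements into x's storage
print(b.clone().storage_offset())   # 0: clone() copies into fresh storage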

@jeffreyksmithjr added the oncall: distributed and triaged labels on May 13, 2019

mrshenli commented Jun 3, 2019

@qijianan777 can you share a minimal repro for this problem? I tried the following, but cannot reproduce the error.

import os, sys
import torch
import torch.distributed as dist
from torch.multiprocessing import Process

def run(rank, size):
    x = torch.ones(2, 2) * rank;
    outputs = []
    outputs.append(torch.zeros_like(x))
    outputs.append(torch.zeros_like(x))
    dist.all_gather(outputs, x)

    print("rank ", rank, ": ", outputs)
    sys.stdout.flush()

def init_processes(rank, size, fn, backend='gloo'):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)


if __name__ == "__main__":
    size = 2
    processes = []
    for rank in range(size):
        p = Process(target=init_processes, args=(rank, size, run))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

@mrshenli self-assigned this on Jun 3, 2019
@mrshenli added the needs reproduction label on Jun 3, 2019
qijianan777 (Author) commented:

We use torch.chunk to split a tensor into a and b, and all_gather a and b. I will try to reproduce the error with your code.


qijianan777 commented Jun 6, 2019

The code which reproduces the error:

import os, sys
import torch
import torch.distributed as dist
from torch.multiprocessing import Process

def run(rank, size, x):
    # The result is wrong when chunking along dim 0
    input = torch.chunk(x, size, dim=0)[rank]
    # The result is correct when chunking along a dim other than 0
    # input = torch.chunk(x, size, dim=1)[rank]
    # The result is correct when using clone()
    # input = torch.chunk(x, size, dim=0)[rank].clone()
    outputs = []
    outputs.append(torch.zeros_like(input))
    outputs.append(torch.zeros_like(input))
    dist.all_gather(outputs, input)

    print("rank ", rank, ": ", outputs)
    sys.stdout.flush()

def init_processes(rank, size, fn, x, backend='gloo'):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size, x)


if __name__ == "__main__":
    size = 2
    processes = []
    x = torch.reshape(torch.tensor(range(4), dtype=torch.float32), [2, 2])
    print(x)
    for rank in range(size):
        p = Process(target=init_processes, args=(rank, size, run, x))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()


mrshenli commented Jun 6, 2019

@qijianan777 Confirmed that I can reproduce this, and it is indeed a bug in ProcessGroupGloo. More specifically, when you chunk on dim=0, the resulting tensors share the same underlying storage and are contiguous, but with different storage offsets. So when we flatten the tensors here, it does nothing. Later, when retrieving the data pointer of the tensors, the same pointer value is returned (this is the bug). As a result, both processes read the first 2 elements.

Thanks for reporting, I will add a fix for it.
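
A small illustrative sketch (not from the original thread) of the situation described above: both chunks are contiguous views into one shared storage and differ only in their storage offset, while clone() gives each rank a tensor with its own storage.

import torch

x = torch.arange(4, dtype=torch.float32).reshape(2, 2)
a, b = torch.chunk(x, 2, dim=0)

# Both chunks are contiguous views into x's storage...
print(a.is_contiguous(), b.is_contiguous())               # True True
print(a.storage().data_ptr() == b.storage().data_ptr())   # True: same underlying storage
# ...and only their storage offsets differ.
print(a.storage_offset(), b.storage_offset())             # 0 2
# clone() sidesteps the issue by copying into private storage.
print(b.clone().storage().data_ptr() == x.storage().data_ptr())   # False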

@mrshenli removed the needs reproduction label on Jun 6, 2019