The result of gloo all_gather error #20421
@qijianan777 can you share a minimal repro for this problem? I tried the following, but cannot reproduce the error.

import os, sys
import torch
import torch.distributed as dist
from torch.multiprocessing import Process

def run(rank, size):
    # Each rank contributes a 2x2 tensor filled with its own rank.
    x = torch.ones(2, 2) * rank
    outputs = []
    outputs.append(torch.zeros_like(x))
    outputs.append(torch.zeros_like(x))
    dist.all_gather(outputs, x)
    print("rank ", rank, ": ", outputs)
    sys.stdout.flush()

def init_processes(rank, size, fn, backend='gloo'):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)

if __name__ == "__main__":
    size = 2
    processes = []
    for rank in range(size):
        p = Process(target=init_processes, args=(rank, size, run))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
We use torch.chunk to split a tensor into a and b, and then gather a and b. I will try to reproduce the error with your code.
The code that reproduces the error:

import os, sys
import torch
import torch.distributed as dist
from torch.multiprocessing import Process

def run(rank, size, x):
    # The result is wrong when dim is 0
    input = torch.chunk(x, size, dim=0)[rank]
    # The result is correct when dim is not 0
    # input = torch.chunk(x, size, dim=1)[rank]
    # The result is correct when clone() is used
    # input = torch.chunk(x, size, dim=0)[rank].clone()
    outputs = []
    outputs.append(torch.zeros_like(input))
    outputs.append(torch.zeros_like(input))
    dist.all_gather(outputs, input)
    print("rank ", rank, ": ", outputs)
    sys.stdout.flush()

def init_processes(rank, size, fn, x, backend='gloo'):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size, x)

if __name__ == "__main__":
    size = 2
    processes = []
    x = torch.reshape(torch.tensor(range(4), dtype=torch.float32), [2, 2])
    print(x)
    for rank in range(size):
        p = Process(target=init_processes, args=(rank, size, run, x))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
@qijianan777 I can confirm that I can reproduce this, and it is indeed a bug. Thanks for reporting; I will add a fix for it.
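Until a fix lands, the repro above suggests a workaround: its comments report that the result is correct when the chunk is cloned before all_gather. A minimal sketch of that idea follows. The helper name chunk_for_rank is hypothetical (it is not part of PyTorch), and the snippet only prepares the per-rank input; it does not call dist.all_gather itself.

```python
import torch

def chunk_for_rank(x, world_size, rank, dim=0):
    # Hypothetical helper, not a PyTorch API: take this rank's chunk and
    # clone it so the tensor passed to all_gather owns its own memory
    # instead of being a view into x.
    return torch.chunk(x, world_size, dim=dim)[rank].clone()

x = torch.reshape(torch.tensor(range(4), dtype=torch.float32), [2, 2])
parts = [chunk_for_rank(x, 2, rank) for rank in range(2)]

# Writing to a cloned chunk does not touch the original tensor.
parts[1].add_(100)
print(x)
print(parts)
```

In the repro, this would replace the line that assigns input, at the cost of one extra copy of each chunk.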
I use gloo for model parallelism. When I use all_gather, the result is wrong.
There are two processes. I expect the all_gather result to be [Tensor1, Tensor2], but the result is actually [Tensor1, Tensor1].
Tensor2 looks like this:
The result looks like this. The second tensor in the result should be equal to Tensor2:
However, when I use torch.reshape(torch.tensor(range(xxxx), dtype=torch.float32), [16, 16, 16]) to create the two tensors before calling all_gather, the result is correct.
The code is:
Environment:
macos
pytorch 1.0.1
pytorch-cpu 1.1.0
numpy 1.16.2
PS: We use torch.chunk to split the tensor along dim 0 and then all_gather the chunked tensor with gloo, and the all_gather result is wrong. I think that although I call contiguous() to make the memory contiguous, it has no effect after I chunk the tensor along dim 0.
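The observation about contiguous() can be checked without any distributed setup. The sketch below inspects the chunks locally. It shows that a dim-0 chunk of a contiguous tensor is itself a contiguous view with a nonzero storage offset, so contiguous() returns it without copying, while clone() allocates fresh memory. Whether this is the root cause of the gloo behavior is an assumption; the thread does not confirm it.

```python
import torch

x = torch.reshape(torch.tensor(range(4), dtype=torch.float32), [2, 2])

# Second chunk along dim 0: a view of the second row of x.
row_chunk = torch.chunk(x, 2, dim=0)[1]
print(row_chunk.is_contiguous())    # already contiguous
print(row_chunk.storage_offset())   # starts 2 elements into x's storage

# contiguous() does not copy a tensor that is already contiguous.
same = row_chunk.contiguous()
print(same.data_ptr() == row_chunk.data_ptr())

# clone() always copies, so the result starts at offset 0 of new storage.
copy = row_chunk.clone()
print(copy.storage_offset(), copy.data_ptr() != row_chunk.data_ptr())

# A dim-1 chunk is a strided view, so it is not contiguous.
col_chunk = torch.chunk(x, 2, dim=1)[1]
print(col_chunk.is_contiguous())
```

This is consistent with the repro comments: the dim=1 and clone() variants give tensors that either are not contiguous views or own their memory, while the dim=0 variant keeps sharing storage with x even after contiguous().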