
Implemented NCCL Distributed Backend for PyTorch with new dist APIs #3435

Merged: 12 commits merged into pytorch:master on Nov 29, 2017

Conversation

@teng-li (Contributor) commented Nov 2, 2017:

This PR adds a new PyTorch distributed backend called NcclDataChannel that uses NVIDIA NCCL 2.0+ for collective operations across GPUs on different nodes. The NCCL backend uses a TCP socket as the rendezvous for the initial broadcast of the unique NCCL ID.
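For intuition, here is a rough, illustrative Python sketch of what such a rendezvous looks like; the real implementation lives in the C++ DataChannelNccl, and the names below (share_unique_id, master_addr, master_port) are made up for the example:

    import pickle
    import socket

    def share_unique_id(rank, world_size, master_addr, master_port, unique_id=None):
        """Rank 0 serves its NCCL unique ID over TCP; every other rank receives it."""
        if rank == 0:
            server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            server.bind((master_addr, master_port))
            server.listen(world_size - 1)
            payload = pickle.dumps(unique_id)
            for _ in range(world_size - 1):
                conn, _ = server.accept()
                conn.sendall(payload)
                conn.close()
            server.close()
            return unique_id
        # Non-zero ranks connect and read the serialized ID.
        # (No retry logic here; a real rendezvous would retry until rank 0 is listening.)
        client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        client.connect((master_addr, master_port))
        chunks = []
        while True:
            data = client.recv(4096)
            if not data:
                break
            chunks.append(data)
        client.close()
        return pickle.loads(b"".join(chunks))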

The following new PyTorch APIs are added to perform collective operations on multiple GPU tensors, each of which resides on a different GPU.

    torch.distributed.all_reduce_multigpu(tensor_list, op=reduce_op.SUM, group=group.WORLD)
    torch.distributed.reduce_multigpu(tensor_list, dst, op=reduce_op.SUM, group=group.WORLD)
    torch.distributed.all_gather_multigpu(output_tensor_lists, input_tensor_list, group=group.WORLD)
    torch.distributed.broadcast_multigpu(tensor_list, src, group=group.WORLD)
    torch.distributed.destroy_process_group()

How to use?

torch.distributed.init_process_group("nccl")
# Optional
grp = torch.distributed.new_group()
torch.distributed.all_reduce_multigpu(tensor_list, group=grp)
# Clean shutdown to release all GPU resources
torch.distributed.destroy_process_group()

The existing PyTorch APIs such as torch.distributed.all_reduce work with the NCCL backend as well. Some small new code paths were also added so that the distributed data parallel model can use the NCCL backend.
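A slightly fuller sketch of the usage above, assuming one process per node and one tensor per local GPU (the tensor shape and the env:// rendezvous settings are illustrative, not prescribed by this PR):

    import torch
    import torch.distributed as dist

    # Rendezvous via environment variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE).
    dist.init_process_group(backend="nccl", init_method="env://")

    # One tensor per local GPU; every tensor in the list must live on a different device.
    num_gpus = torch.cuda.device_count()
    tensor_list = [torch.ones(1024).cuda(i) for i in range(num_gpus)]

    # Sum across all GPUs of all processes in the default group.
    dist.all_reduce_multigpu(tensor_list)

    # Clean shutdown releases the NCCL communicators and GPU resources.
    dist.destroy_process_group()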

Testing:

All new functions were tested using my own test scripts.

On DGX-1, 8 nodes with 4 InfiniBand adapters, I was able to run ResNet-50 (which uses the modified distributed data parallel model) for 140+ epochs without any issues.

TODO: write unit tests for the entire backend. Currently marked as experimental only.


@csarofeen (Contributor) commented:
Distributed table needs to be modified to enforce the ordering of all_reduce calls. If the order is not the same across processes, NCCL will hang.

@teng-li (Contributor, author) commented Nov 3, 2017:

@csarofeen Could you elaborate more on this?

@csarofeen (Contributor) commented:
Since backward is multi-threaded, the following could happen:
process one finishes the params in bucket[0], then the params in bucket[1];
process two finishes the params in bucket[1], then the params in bucket[0].
The way the code is structured, process one will try to all_reduce bucket[0] while process two tries to all_reduce bucket[1].
This will deadlock NCCL.
Distributed should enforce that the all_reduce call on bucket[i] cannot be issued before the all_reduce call on bucket[i-1].

@teng-li (Contributor, author) commented Nov 3, 2017:

@csarofeen Thanks for the explanation. I will put a workaround in place to maintain the reduction order across threads.

@teng-li (Contributor, author) commented Nov 8, 2017:

@ngimel Addressed your comments.

@csarofeen I made distributed.py use a single reduction thread. Since all the NCCL calls are asynchronous, using multiple threads wouldn't provide much of a performance gain for us, and it would be pretty tricky to make each thread execute in a fixed order. This simple single-thread model worked without NCCL deadlocks in my testing. Each reduction bucket writes its reduction request to a single queue, in the order in which the gradients become available, and a dedicated reduction thread drains that queue and issues the NCCL reductions one after another, preserving the NCCL call order.
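For illustration, a minimal sketch of the single-queue, single-reduction-thread pattern described above (the names reduction_queue, _reduction_worker, and nccl_all_reduce_bucket are hypothetical, not the actual distributed.py internals):

    import queue
    import threading

    reduction_queue = queue.Queue()

    def nccl_all_reduce_bucket(bucket):
        # Stand-in for the real NCCL all-reduce on a bucket of gradients.
        pass

    def _reduction_worker():
        # A single worker drains the queue, so the NCCL collectives are issued
        # serially, in exactly the order the requests were enqueued.
        while True:
            bucket = reduction_queue.get()
            if bucket is None:  # shutdown sentinel
                break
            nccl_all_reduce_bucket(bucket)

    worker = threading.Thread(target=_reduction_worker, daemon=True)
    worker.start()

    # Each backward hook simply enqueues its bucket when the gradients are ready:
    # reduction_queue.put(bucket)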

@ngimel (Collaborator) commented Nov 8, 2017:

You also have to use gpuGuard.setDevice instead of cudaSetDevice; it won't work otherwise (when you create gpuGuard without arguments, it sets original_device to -1, and when you destroy it without actually setting a device in between, it's a no-op).
Otherwise, LGTM.

@teng-li (Contributor, author) commented Nov 8, 2017:

@ngimel Oh right, I forgot to update it; it is updated now.

@csarofeen (Contributor) commented:
@teng-li Sorry, I think I'm missing something. I don't understand how you guarantee the order of all_reduce calls will be the same across processes. Can you point me to the right section?


@teng-li (Contributor, author) commented Nov 8, 2017:

@apaszke I would like to hear your comments too.

@apaszke (Contributor) left a review:

I don't think requiring users to call torch.distributed.destroy_process_group() is necessary. If NCCL needs this, then register an atexit handler.
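A minimal sketch of the atexit approach suggested here (assuming destroy_process_group is safe to call at interpreter shutdown):

    import atexit
    import torch.distributed as dist

    dist.init_process_group(backend="nccl", init_method="env://")

    # Release NCCL communicators and GPU resources automatically on exit,
    # instead of requiring users to call destroy_process_group() themselves.
    atexit.register(dist.destroy_process_group)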



@pietern (Contributor) commented Nov 10, 2017:

@pytorchbot retest this please



@teng-li (Contributor, author) commented Nov 13, 2017:

@apaszke Addressed your comments:
(1) Removed barrier support, but we still need these events for the destroy functions; later on these events will also be useful for benchmarking and other extension work, so I kept them in the backend code.
(2) Added a warning telling users to use NCCL 2 with care.
(3) Made the experimental functions in distributed.py private so as not to confuse users.

@csarofeen (Contributor) commented:
@teng-li Can you create a small PyTorch repro script for the deadlock outside of data parallel? If so, I can take a look at it and send it over to the NCCL guys to check as well.

@teng-li (Contributor, author) commented Nov 21, 2017:

@csarofeen Will do, I will send it to you later.
@apaszke I made the simplest change to distributed.py with a single thread, which actually gives the best performance for ImageNet on DGX-1 with very good scalability. The basic distributed data parallel model now works with the NCCL2 backend.


@apaszke (Contributor) left a review:

The DataChannel itself could be simplified, but that's not a blocker for this merge (only for the public release). There are two things that need fixing: the raw-string prefix in the DDP docstring, and the incorrect usage of PySequence_Fast.


@teng-li (Contributor, author) commented Nov 29, 2017:

@apaszke

Added two commits. The first one uses PySequence_Fast correctly everywhere in module.cpp and adds the "r" prefix back. I also changed, tested, and made sure that the invalid-backend error is actually thrown.

The second one removes the limitation that each group is bound to a given device sequence. I don't feel comfortable caching many NCCL communicators, so the current design is: for each process group, we only cache a single NCCL communicator (say, for device sequence 0,1,2,3,4,5,6,7). If there is a new collective NCCL call with a new device sequence (say, 1,0,2,3,4,5,6,7), we destroy the existing cached communicator (for sequence 0,1,2,3,4,5,6,7), build a new communicator, and cache that one instead. This design is pretty clean and simple. I tested with different device sequences and made sure it works. This shouldn't be a release blocker for now.

As for the NCCL deadlock, it's not entirely clear to me why it happens, but it occurs on several consecutive broadcast calls that use the same communicator. If the communicator is rebuilt, it works (that is the current workaround).
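A small sketch of the caching policy described above (Python pseudocode with hypothetical names; the real logic lives in the C++ DataChannelNccl):

    # At most one cached NCCL communicator per process group, keyed by the
    # device sequence it was built for.
    _comm_cache = {}  # group_id -> (device_sequence, communicator)

    def get_communicator(group_id, device_sequence, build_comm, destroy_comm):
        cached = _comm_cache.get(group_id)
        if cached is not None:
            cached_devices, comm = cached
            if cached_devices == device_sequence:
                return comm  # same device order: reuse the cached communicator
            destroy_comm(comm)  # different device order: drop the old communicator
        comm = build_comm(device_sequence)
        _comm_cache[group_id] = (device_sequence, comm)
        return comm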

@apaszke (Contributor) left a review:

Rebuilding communicators sounds quite expensive, so I'm not sure it's the best solution, but let's stick with it for now.

I still have a few comments on the Python bindings, but that's it. Right now they are leaking memory.


@apaszke (Contributor) commented Nov 29, 2017:

Also, when you make new changes, please add them as new commits at the end instead of squashing. It makes repeating the review much easier, because I explicitly see what changed. You can always clean up the history right before merge.

@soumith merged commit 926ed2b into pytorch:master on Nov 29, 2017.
@seba-1511 commented:

@teng-li This work looks quite good. I like the distributed interface much more than the NCCL one, even for single-machine training.

Is there a way I can contact you privately? I would like to start playing with this as soon as possible, so as to add it to the distributed tutorial.

@teng-li (Contributor, author) commented Nov 30, 2017:

@seba-1511 Thank you, sure, you can leave me your email and I will send you an email.

@teng-li deleted the nccl2 branch on November 30, 2017 at 05:55.
@seba-1511 commented:

My email is: arnolds at usc dot edu
