
Parallelize hierarchical allreduce algorithm (#411)

* Change Hierarchical Allreduce algorithm into NCCLReduceScatter - MPIAllreduce - NCCLAllgather pattern to parallelize inter-node reduction and improve Hierarchical Allreduce performance

* Remove NCCL_REDUCE and NCCL_BCAST definitions from header file, update timeline.md appropriately

* Offset buffers for ReduceScatter and Allgather to allow in-place operations; pad buffers before hierarchical allreduce to make data size a multiple of # local ranks

* Do hybrid hierarchical allreduce: first do NCCLReduceScatter - MPIAllreduce - NCCLAllgather for the part of the data divisible by hvd.local_size(), then do NCCLReduce - MPIAllreduce - NCCLBcast for the remainder (see the sketch after this list)

* Fix formatting and variable names

* Fix bug in offsetting pointer before operating on remainder data in hierarchical allreduce

* Add synchronization before operating on remainder data to make timeline work properly

* Make hierarchical allreduce steps pipelined: (NCCL ReduceScatter + NCCL Reduce) / single MPI Allreduce / (NCCL Allgather + NCCL Bcast)

* Clean up comments

* Add support for heterogeneous clusters; for homogeneous case allocate fusion buffer size divisible by local_size

* Round up tensor fusion threshold; round up even when the environment var is not set; do synchronous memory copy to host buffer to produce correct timeline

* For hierarchical allreduce make sure num_elements is divisible by 64 for improved performance

* Define fusion buffer atomic unit (64) as a constant

* Free the local_sizes buffer after use during initialization
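
Taken together, these steps boil down to a small amount of buffer arithmetic. The following standalone C++ sketch (hypothetical values; names simplified and all NCCL/MPI calls elided, so it is an illustration rather than the actual implementation) shows how the fused buffer is padded and then split between the parallel path and the remainder path:

    #include <cstdint>
    #include <iostream>

    // Same value as the FUSION_BUFFER_ATOMIC_UNIT constant this commit adds.
    const int64_t kFusionBufferAtomicUnit = 64;

    int main() {
      // Hypothetical inputs: 4 ranks per node, homogeneous cluster, and a
      // fused buffer of 1000 float32 elements.
      int64_t local_size = 4;
      bool is_homogeneous = true;
      bool multiple_tensors_fused = true;
      int64_t num_elements = 1000;
      int64_t element_size = 4;  // bytes per float32 element

      // Padding: round the element count up to a multiple of
      // local_size * FUSION_BUFFER_ATOMIC_UNIT. This is safe only because
      // the fusion buffer itself is allocated with a size divisible by
      // local_size.
      if (is_homogeneous && multiple_tensors_fused) {
        int64_t div = local_size * kFusionBufferAtomicUnit;    // 256
        num_elements = ((num_elements + div - 1) / div) * div; // 1000 -> 1024
      }

      // Hybrid split: the part divisible by local_size is reduced via
      // ReduceScatter -> MPI_Allreduce -> Allgather in parallel across local
      // ranks; any remainder is reduced via Reduce -> MPI_Allreduce -> Bcast
      // at the root rank (local_size - 1 when homogeneous, else 0).
      int64_t per_rank = is_homogeneous ? num_elements / local_size : 0;
      int64_t remainder =
          is_homogeneous ? num_elements % local_size : num_elements;
      int64_t root_rank = is_homogeneous ? local_size - 1 : 0;

      // With padding the remainder is 0; an unfused 1026-element tensor
      // would instead give per_rank = 256 and remainder = 2.
      std::cout << "per-rank elements: " << per_rank << " ("
                << per_rank * element_size << " bytes), remainder: "
                << remainder << ", root rank: " << root_rank << std::endl;
      return 0;
    }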
karakusc authored and alsrgv committed Sep 6, 2018
1 parent bdbd056 commit 9166c1af27c7ff27c92d05dbb69754dcb9c8f02e
Showing with 208 additions and 54 deletions.
  1. +2 −2 docs/timeline.md
  2. +198 −50 horovod/common/operations.cc
  3. +8 −2 horovod/common/operations.h
@@ -39,5 +39,5 @@ workers were early and which were late.
* *NCCL_ALLREDUCE*, *MPI_ALLREDUCE*, *MPI_ALLGATHER*, or *MPI_BCAST* indicate time taken to do the actual operation on GPU
(or CPU) and highlights whether the operation was performed using NCCL or pure MPI.
- * In case of `HOROVOD_HIERARCHICAL_ALLREDUCE=1`, *NCCL_ALLREDUCE* will become a sequence of *NCCL_REDUCE*,
-   *MEMCPY_IN_HOST_BUFFER*, *MPI_ALLREDUCE*, *MEMCPY_OUT_HOST_BUFFER*, *NCCL_BCAST*.
* In case of `HOROVOD_HIERARCHICAL_ALLREDUCE=1`, *NCCL_ALLREDUCE* will become a sequence or a subsequence of *NCCL_REDUCESCATTER*,
*NCCL_REDUCE*, *MEMCPY_IN_HOST_BUFFER*, *MPI_ALLREDUCE*, *MEMCPY_OUT_HOST_BUFFER*, *NCCL_ALLGATHER*, *NCCL_BCAST*.
@@ -176,6 +176,7 @@ struct HorovodGlobalState {
int local_size = 1;
int cross_size = 1;
bool mpi_threads_supported = false;
bool is_homogeneous = false;
std::vector<int> ranks;
// COMM_WORLD ranks of processes running on this node.
@@ -258,11 +259,11 @@ const Status SHUT_DOWN_ERROR = Status::Aborted(
#define OP_ERROR(entries, error_message) \
{ \
-   for (auto& e : (entries)) { \
-     timeline.End(e.tensor_name, nullptr); \
-     e.callback(Status::UnknownError(error_message)); \
-   } \
-   return; \
for (auto& e : (entries)) { \
timeline.End(e.tensor_name, nullptr); \
e.callback(Status::UnknownError(error_message)); \
} \
return; \
}
// Store the MPIRequest for a name, and return whether the total count of
@@ -920,13 +921,15 @@ void PerformOperation(TensorTable& tensor_table, MPIResponse response) {
// Initialize DDL
auto ddl_options = std::getenv("DDL_OPTIONS");
if (ddl_options == nullptr) {
OP_ERROR(entries, "DDL_OPTIONS env variable needs to be set to use DDL.")
OP_ERROR(entries,
"DDL_OPTIONS env variable needs to be set to use DDL.")
}
DDL_CHECK(entries, "ddl_init", ddl_init(ddl_options))
horovod_global.ddl_initialized = true;
horovod_global.ddl_local_device_id = first_entry.device;
} else if (horovod_global.ddl_local_device_id != first_entry.device) {
OP_ERROR(entries, "DDL does not support more than one GPU device per process.")
OP_ERROR(entries,
"DDL does not support more than one GPU device per process.")
}
#endif
@@ -959,7 +962,9 @@ void PerformOperation(TensorTable& tensor_table, MPIResponse response) {
cudaMemcpyDeviceToDevice, stream))
offset += e.tensor->size();
}
buffer_len = (size_t)offset;
if (timeline.Initialized() || horovod_global.ddl_initialized) {
RECORD_EVENT(entries, event_queue, MEMCPY_IN_FUSION_BUFFER, stream)
}
@@ -971,23 +976,25 @@ void PerformOperation(TensorTable& tensor_table, MPIResponse response) {
for (auto& e : entries) {
num_elements += e.tensor->shape().num_elements();
}
} else {
fused_input_data = first_entry.tensor->data();
buffer_data = (void*)first_entry.output->data();
num_elements = first_entry.tensor->shape().num_elements();
buffer_len = (size_t)first_entry.output->size();
if (horovod_global.ddl_initialized) {
// Copy input buffer content to output buffer
// because DDL only supports in-place allreduce
CUDA_CHECK(entries, "cudaMemcpyAsync",
-            cudaMemcpyAsync(buffer_data, fused_input_data,
-                            buffer_len,
cudaMemcpyAsync(buffer_data, fused_input_data, buffer_len,
cudaMemcpyDeviceToDevice, stream))
RECORD_EVENT(entries, event_queue, MEMCPY_IN_FUSION_BUFFER, stream)
}
}
void* host_buffer = nullptr;
#if HOROVOD_GPU_ALLREDUCE == 'D'
// Synchronize.
WAIT_FOR_EVENTS(entries, timeline, event_queue)
@@ -998,68 +1005,169 @@ void PerformOperation(TensorTable& tensor_table, MPIResponse response) {
OP_ERROR(entries, ex.what())
}
DDL_CHECK(entries, "ddl_allreduce",
-               ddl_allreduce(buffer_data,
-                             (size_t)num_elements,
-                             ddl_data_type,
ddl_allreduce(buffer_data, (size_t)num_elements, ddl_data_type,
DDL_OP_SUM))
#else
if (horovod_global.hierarchical_allreduce) {
NCCL_CHECK(entries, "ncclReduce",
ncclReduce(fused_input_data, buffer_data,
(size_t)num_elements,
GetNCCLDataType(first_entry.tensor), ncclSum, 0,
nccl_comm, stream))
if (timeline.Initialized()) {
RECORD_EVENT(entries, event_queue, NCCL_REDUCE, stream)
int element_size = buffer_len / num_elements;
// If cluster is homogeneous and we are using fusion buffer, include
// dummy elements from the buffer (if necessary) to make sure the data
// is divisible by local_size. This is always possible since we
// set the fusion buffer size divisible by local_size.
if (horovod_global.is_homogeneous && entries.size() > 1) {
// Making sure the number of elements is divisible by
// FUSION_BUFFER_ATOMIC_UNIT for improved performance
int div = horovod_global.local_size * FUSION_BUFFER_ATOMIC_UNIT;
num_elements = ((num_elements + div - 1) / div) * div;
buffer_len = num_elements * element_size;
}
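          // Worked example with hypothetical numbers: local_size = 4 gives
          // div = 4 * 64 = 256, so a fused buffer of 1000 elements is padded
          // up to num_elements = 1024. The dummy tail elements get reduced
          // along with the real data, but they are never copied back out of
          // the fusion buffer into any output tensor.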
- if (horovod_global.local_rank == 0) {
// Split the elements into two groups: num_elements_per_rank*local_size,
// and num_elements_remaining. Cross-node reduction for the first group
// is done by all local_rank's in parallel, while for the second group
// it is only done by the root_rank. If the cluster is not
// homogeneous, the first group is zero, and root_rank is 0.
// Homogeneous case:
// For the part of data divisible by local_size, perform NCCL
// ReduceScatter - Parallelized MPI Allreduce - NCCL Allgather. For the
// non-divisible part (if any), do NCCL Reduce (at rank local_size-1),
// MPI Allreduce (across rank (local_size-1)'s), and NCCL Bcast
int64_t num_elements_per_rank =
horovod_global.is_homogeneous
? num_elements / horovod_global.local_size
: 0;
size_t buffer_len_per_rank = element_size * num_elements_per_rank;
void* buffer_data_at_rank_offset =
(uint8_t*)buffer_data +
buffer_len_per_rank * horovod_global.local_rank;
int64_t num_elements_remaining =
horovod_global.is_homogeneous
? num_elements % horovod_global.local_size
: num_elements;
size_t buffer_len_remaining = element_size * num_elements_remaining;
void* buffer_data_remainder =
(uint8_t*)buffer_data +
buffer_len_per_rank * horovod_global.local_size;
void* fused_input_data_remainder =
(uint8_t*)fused_input_data +
buffer_len_per_rank * horovod_global.local_size;
int root_rank =
horovod_global.is_homogeneous ? horovod_global.local_size - 1 : 0;
bool is_root_rank = horovod_global.local_rank == root_rank;
int64_t total_num_elements =
is_root_rank ? num_elements_per_rank + num_elements_remaining
: num_elements_per_rank;
int64_t total_buffer_len =
is_root_rank ? buffer_len_per_rank + buffer_len_remaining
: buffer_len_per_rank;
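          // Worked example with hypothetical numbers (possible when padding
          // was skipped, e.g. for a single unfused tensor): num_elements =
          // 1026 and local_size = 4 give num_elements_per_rank = 256,
          // num_elements_remaining = 2, and root_rank = 3, so rank 3 moves
          // total_num_elements = 258 through its host buffer while ranks
          // 0-2 move 256 each.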
if (num_elements_per_rank > 0) {
NCCL_CHECK(entries, "ncclReduceScatter",
ncclReduceScatter(fused_input_data,
buffer_data_at_rank_offset,
(size_t)num_elements_per_rank,
GetNCCLDataType(first_entry.tensor),
ncclSum, nccl_comm, stream))
if (timeline.Initialized()) {
RECORD_EVENT(entries, event_queue, NCCL_REDUCESCATTER, stream)
}
}
if (num_elements_remaining > 0) {
// Reduce the remaining data at local_size-1 to append to
// existing buffer
NCCL_CHECK(entries, "ncclReduce",
ncclReduce(fused_input_data_remainder,
buffer_data_remainder,
(size_t)num_elements_remaining,
GetNCCLDataType(first_entry.tensor), ncclSum,
root_rank, nccl_comm, stream))
if (timeline.Initialized()) {
RECORD_EVENT(entries, event_queue, NCCL_REDUCE, stream)
}
}
if (horovod_global.is_homogeneous || is_root_rank) {
// cudaHostAlloc is significantly slower than malloc. Pre-allocating
// a buffer is not safe since the tensor can be arbitrarily large.
- host_buffer = malloc(buffer_len);
- CUDA_CHECK(entries, "cudaMemcpyAsync",
-            cudaMemcpyAsync(host_buffer, buffer_data, buffer_len,
-                            cudaMemcpyDeviceToHost, stream))
- // This event must be recorded for the subsequent synchronize.
- RECORD_EVENT(entries, event_queue, MEMCPY_IN_HOST_BUFFER, stream)
host_buffer = malloc(total_buffer_len);
// Synchronize.
WAIT_FOR_EVENTS(entries, timeline, event_queue)
// According to https://docs.nvidia.com/cuda/cuda-runtime-api/
// api-sync-behavior.html#api-sync-behavior__memcpy-async,
// cudaMemcpyAsync is synchronous with respect to the host, so we
// memcpy (effectively) synchronously to generate an accurate timeline
ACTIVITY_START_ALL(entries, timeline, MEMCPY_IN_HOST_BUFFER)
CUDA_CHECK(entries, "cudaMemcpyAsync",
cudaMemcpyAsync(host_buffer, buffer_data_at_rank_offset,
total_buffer_len, cudaMemcpyDeviceToHost,
stream))
ACTIVITY_END_ALL(entries, timeline)
ACTIVITY_START_ALL(entries, timeline, MPI_ALLREDUCE)
MPI_CHECK(entries, "MPI_Allreduce",
-                     MPI_Allreduce(MPI_IN_PLACE, host_buffer, (int)num_elements,
MPI_Allreduce(MPI_IN_PLACE, host_buffer,
(int)total_num_elements,
GetMPIDataType(first_entry.tensor), MPI_SUM,
horovod_global.cross_comm))
ACTIVITY_END_ALL(entries, timeline)
ACTIVITY_START_ALL(entries, timeline, MEMCPY_OUT_HOST_BUFFER)
CUDA_CHECK(entries, "cudaMemcpyAsync",
-                      cudaMemcpyAsync(buffer_data, host_buffer, buffer_len,
-                                      cudaMemcpyHostToDevice, stream))
cudaMemcpyAsync(buffer_data_at_rank_offset, host_buffer,
total_buffer_len, cudaMemcpyHostToDevice,
stream))
ACTIVITY_END_ALL(entries, timeline)
}
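        // Note the parallelism this buys: in the homogeneous case every
        // local rank (not just one root, as before this change) now runs
        // MPI_Allreduce on its own shard of the host buffer over its own
        // cross_comm, so the inter-node reduction proceeds in local_size
        // concurrent streams.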
if (num_elements_per_rank > 0) {
NCCL_CHECK(entries, "ncclAllGather",
ncclAllGather(buffer_data_at_rank_offset, buffer_data,
(size_t)num_elements_per_rank,
GetNCCLDataType(first_entry.tensor),
nccl_comm, stream))
if (timeline.Initialized()) {
-   RECORD_EVENT(entries, event_queue, MEMCPY_OUT_HOST_BUFFER, stream)
RECORD_EVENT(entries, event_queue, NCCL_ALLGATHER, stream)
}
}
if (num_elements_remaining > 0) {
NCCL_CHECK(entries, "ncclBcast",
ncclBcast(buffer_data_remainder,
(size_t)num_elements_remaining,
GetNCCLDataType(first_entry.tensor), root_rank,
nccl_comm, stream))
NCCL_CHECK(entries, "ncclBcast",
ncclBcast(buffer_data, (size_t)num_elements,
GetNCCLDataType(first_entry.tensor), 0, nccl_comm,
stream))
if (timeline.Initialized()) {
RECORD_EVENT(entries, event_queue, NCCL_BCAST, stream)
- if (timeline.Initialized()) {
-   RECORD_EVENT(entries, event_queue, NCCL_BCAST, stream)
- }
}
} else {
NCCL_CHECK(entries, "ncclAllReduce",
ncclAllReduce(fused_input_data, buffer_data,
(size_t)num_elements,
GetNCCLDataType(first_entry.tensor), ncclSum,
nccl_comm, stream))
if (timeline.Initialized()) {
RECORD_EVENT(entries, event_queue, NCCL_ALLREDUCE, stream)
}
}
#endif
- if (timeline.Initialized()) {
-   RECORD_EVENT(entries, event_queue, NCCL_ALLREDUCE, stream)
- }
if (entries.size() > 1) {
// Copy memory out of the fusion buffer.
@@ -1250,7 +1358,8 @@ void CheckForStalledTensors(HorovodGlobalState& state) {
<< " seconds. ";
std::cerr << "This may indicate that different ranks are trying to "
"submit different tensors or that only subset of ranks is "
"submitting tensors, which will cause deadlock. " << std::endl;
"submitting tensors, which will cause deadlock. "
<< std::endl;
std::cerr << "Stalled ops:" << std::endl;
preamble = true;
}
@@ -1260,7 +1369,7 @@ void CheckForStalledTensors(HorovodGlobalState& state) {
bool missing_preamble = false;
for (auto msg_iter = messages.begin(); msg_iter != messages.end();
msg_iter++) {
- ready_ranks.insert(msg_iter->request_rank());
ready_ranks.insert(msg_iter->request_rank());
}
for (int32_t rank = 0; rank < state.size; rank++) {
if (ready_ranks.find(rank) == ready_ranks.end()) {
@@ -1375,6 +1484,22 @@ void BackgroundThreadLoop(HorovodGlobalState& state) {
MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, local_comm_ranks.data(), 1,
MPI_INT, local_comm);
// Determine if cluster is homogeneous, i.e., if every node has the same
// local_size
auto local_sizes = new int[size];
MPI_Allgather(&local_size, 1, MPI_INT, local_sizes, 1, MPI_INT,
state.mpi_comm);
bool is_homogeneous = true;
for (int i = 0; i < size; i++) {
if (local_sizes[i] != local_size) {
is_homogeneous = false;
break;
}
}
delete[] local_sizes;
state.is_homogeneous = is_homogeneous;
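  // Example with hypothetical values: two nodes with 4 processes each give
  // local_sizes = {4, 4, 4, 4, 4, 4, 4, 4}, so is_homogeneous stays true;
  // a 4-process node plus a 2-process node gives
  // local_sizes = {4, 4, 4, 4, 2, 2}, so is_homogeneous becomes false.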
// Set up cross-communicator in case of hierarchical allreduce.
MPI_Comm cross_comm;
MPI_Comm_split(state.mpi_comm, local_rank, rank, &cross_comm);
@@ -1399,13 +1524,6 @@ void BackgroundThreadLoop(HorovodGlobalState& state) {
state.timeline.Initialize(std::string(horovod_timeline));
}
- // Override Tensor Fusion threshold, if it's set.
- auto horovod_fusion_threshold = std::getenv(HOROVOD_FUSION_THRESHOLD);
- if (horovod_fusion_threshold != nullptr) {
-   state.tensor_fusion_threshold =
-       std::strtol(horovod_fusion_threshold, nullptr, 10);
- }
// Override the cycle time.
auto horovod_cycle_time = std::getenv(HOROVOD_CYCLE_TIME);
if (horovod_cycle_time != nullptr) {
@@ -1425,10 +1543,40 @@ void BackgroundThreadLoop(HorovodGlobalState& state) {
std::getenv(HOROVOD_HIERARCHICAL_ALLREDUCE);
if (horovod_hierarchical_allreduce != nullptr &&
std::strtol(horovod_hierarchical_allreduce, nullptr, 10) > 0 &&
-     cross_size > 1) {
(size != local_size)) {
state.hierarchical_allreduce = true;
}
// Issue warning if hierarchical allreduce is enabled in a heterogeneous cluster
if (is_coordinator && state.hierarchical_allreduce && !state.is_homogeneous) {
std::cerr
<< "WARNING: Using different number of ranks per node might hurt "
"performance of hierarchical allreduce. Consider assigning the same "
"number of ranks to each node or disabling hierarchical allreduce."
<< std::endl;
}
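  // Illustration with hypothetical values: 8 ranks over two nodes give
  // size = 8 and local_size = 4, so HOROVOD_HIERARCHICAL_ALLREDUCE=1
  // enables the hierarchical path; on a single node size == local_size
  // and the flag is ignored, since there is no inter-node reduction to
  // parallelize.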
// Override Tensor Fusion threshold, if it's set.
auto horovod_fusion_threshold = std::getenv("HOROVOD_FUSION_THRESHOLD");
int64_t proposed_fusion_threshold = (horovod_fusion_threshold != nullptr) ?
std::strtol(horovod_fusion_threshold, nullptr, 10) :
state.tensor_fusion_threshold;
// If the cluster is homogeneous and hierarchical allreduce is enabled,
// adjust buffer size to make sure it is divisible by local_size to improve
// performance.
if (state.is_homogeneous && state.hierarchical_allreduce) {
// Assume the worst-case data type float64, since if it is divisible
// for float64, it will be divisible for all other types too.
// Ensuring that fusion buffer can hold a number of elements divisible by
// FUSION_BUFFER_ATOMIC_UNIT for performance
int64_t div = state.local_size * sizeof(MPI_DOUBLE) * FUSION_BUFFER_ATOMIC_UNIT;
state.tensor_fusion_threshold = ((proposed_fusion_threshold+div-1) / div) * div;
} else {
state.tensor_fusion_threshold = proposed_fusion_threshold;
}
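  // Worked example with hypothetical values: local_size = 4 gives
  // div = 4 * 8 * 64 = 2048 bytes (assuming the worst-case element size
  // is 8 bytes). A proposed threshold of 1000000 bytes rounds up to
  // ((1000000 + 2047) / 2048) * 2048 = 1001472, while the 64 MB default
  // (67108864) is already a multiple of 2048 and stays unchanged.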
// Initialize the tensor count table. No tensors are available yet.
if (is_coordinator) {
state.message_table = std::unique_ptr<MessageTable>(new MessageTable());