Permalink
Browse files

Tensor fusion for allgather (#732)

  • Loading branch information...
abditag2 authored and alsrgv committed Jan 11, 2019
1 parent e5a4de0 commit 3db450d9d9f1f4d8cf715e4e878e3f8dfdfcd5f6
Showing with 424 additions and 122 deletions.
  1. +1 −1 docs/gpus.md
  2. +12 −1 horovod/common/mpi_message.cc
  3. +2 −0 horovod/common/mpi_message.h
  4. +304 −120 horovod/common/operations.cc
  5. +105 −0 test/test_tensorflow.py
@@ -74,7 +74,7 @@ $ HOROVOD_GPU_ALLREDUCE=MPI HOROVOD_GPU_ALLGATHER=MPI HOROVOD_GPU_BROADCAST=MPI
```

**Note**: Allgather allocates an output tensor which is proportionate to the number of processes participating in the
training. If you find yourself running out of GPU memory, you can force allgather to happen on CPU by passing
`device_sparse='/cpu:0'` to `hvd.DistributedOptimizer`:

```python
@@ -299,8 +299,19 @@ void MPIResponse::add_tensor_sizes(int64_t value) {
tensor_sizes_.push_back(value);
}

// Fuses another allgather response into this one: the name of the single
// tensor carried by `response` and every one of its tensor sizes are appended
// to this response's own lists.
//
// The asserts spell out what is expected of the inputs: this response is of
// type ALLGATHER, `response` names exactly one tensor, and both responses
// report the same devices.
void MPIResponse::add_allgather_response(
    horovod::common::MPIResponse response) {
  assert(response_type() == MPIResponse::ResponseType::ALLGATHER);
  assert(response.tensor_names().size() == 1);
  assert(response.devices() == devices());

  // Exactly one name is present (asserted above), so take the first entry.
  add_tensor_names(response.tensor_names()[0]);

  // Carry over the incoming sizes in their original order.
  const auto& incoming_sizes = response.tensor_sizes();
  for (auto it = incoming_sizes.begin(); it != incoming_sizes.end(); ++it) {
    add_tensor_sizes(*it);
  }
}

void MPIResponse_ParseFromWire(MPIResponse& response,
const wire::MPIResponse* obj) {
const wire::MPIResponse* obj) {
response.set_response_type((MPIResponse::ResponseType)obj->response_type());
for (const auto& tensor_name_obj : *obj->tensor_names()) {
response.add_tensor_names(tensor_name_obj->str());
@@ -142,6 +142,8 @@ class MPIResponse {
// Returns the tensor sizes by const reference.
const std::vector<int64_t>& tensor_sizes() const;
// NOTE(review): definition not visible here; presumably replaces the stored
// sizes with `value` -- confirm against mpi_message.cc.
void set_tensor_sizes(const std::vector<int64_t>& value);
// Appends `value` to the end of the tensor sizes list.
void add_tensor_sizes(int64_t value);
// To fuse multiple allgather responses: appends the name of the single tensor
// carried by `response`, followed by all of its tensor sizes, to this
// response. The implementation asserts that this response is of type
// ALLGATHER, that `response` names exactly one tensor, and that both
// responses have the same devices.
void add_allgather_response(MPIResponse response);

// NOTE(review): definition not visible here; presumably fills `response` from
// the serialized data in `input` -- confirm against mpi_message.cc.
static void ParseFromString(MPIResponse& response, const std::string& input);
// NOTE(review): definition not visible here; presumably writes the serialized
// form of `response` into `output` -- confirm against mpi_message.cc.
static void SerializeToString(MPIResponse& response, std::string& output);
Oops, something went wrong.

0 comments on commit 3db450d

Please sign in to comment.