Skip to content

Commit

Permalink
[Gradient Compression] Add an index field to GradBucket for PowerSGD
Browse files Browse the repository at this point in the history
Add an index field to GradBucekt, so error_dict is keyed by this index instead of the hashcode of input tensor.

Howevever, sometimes the buckets can be rebuilt in the forward pass. In this case, the shape of the bucket with the same index will not be consistent with the one in the previous iteration, and hence the error tensor will be re--initialized as a zero tensor of the new shape.

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202

Differential Revision: [D25288496](https://our.internmc.facebook.com/intern/diff/D25288496/)

ghstack-source-id: 117719173
Pull Request resolved: #48757
  • Loading branch information
wayi committed Dec 3, 2020
1 parent 90a3049 commit a9e30a4
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 14 deletions.
17 changes: 13 additions & 4 deletions torch/csrc/distributed/c10d/init.cpp
Expand Up @@ -163,7 +163,8 @@ PyObject* c10d_init(PyObject* _unused, PyObject* noargs) {
}

auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
auto m = torch_C_m.def_submodule("_distributed_c10d", "distributed c10d bindings");
auto m =
torch_C_m.def_submodule("_distributed_c10d", "distributed c10d bindings");

auto module = py::handle(m).cast<py::module>();

Expand All @@ -184,14 +185,20 @@ PyObject* c10d_init(PyObject* _unused, PyObject* noargs) {
shared_ptr_class_<::c10d::GradBucket>(module, "_GradBucket")
.def(
py::init<
size_t,
const std::vector<Tensor>&,
const std::vector<size_t>&,
const std::vector<size_t>&,
const std::vector<c10::IntArrayRef>&>(),
py::arg("index"),
py::arg("tensors"),
py::arg("offsets"),
py::arg("lengths"),
py::arg("sizes_list"))
.def(
"get_index",
&::c10d::GradBucket::getIndex,
py::call_guard<py::gil_scoped_release>())
.def(
"get_tensors",
&::c10d::GradBucket::getTensors,
Expand Down Expand Up @@ -1095,7 +1102,8 @@ that adds a prefix to each key inserted to the store.
&::c10d::ProcessGroup::Work::wait,
py::arg("timeout") = kNoTimeout,
py::call_guard<py::gil_scoped_release>())
.def("get_future",
.def(
"get_future",
[](::c10d::ProcessGroup::Work& work)
-> std::shared_ptr<jit::PythonFutureWrapper> {
return std::make_shared<jit::PythonFutureWrapper>(work.getFuture());
Expand Down Expand Up @@ -1261,7 +1269,6 @@ static const auto ProcessGroupWorkTorchBind =
})
.def("result", &::c10d::ProcessGroup::Work::result);


// TODO: Support argument names in Python API.
static const auto ProcessGroupTorchBind =
torch::class_<::c10d::ProcessGroup>("dist_c10d", "ProcessGroup")
Expand Down Expand Up @@ -1558,7 +1565,9 @@ static const auto DistributedC10dFrontendTorchBind =
c10::make_intrusive<::c10d::DistributedC10d>();
return c10d_frontend_singleton;
}))
.def("new_process_group_helper", &::c10d::DistributedC10d::newProcessGroupHelper);
.def(
"new_process_group_helper",
&::c10d::DistributedC10d::newProcessGroupHelper);

} // namespace

Expand Down
29 changes: 21 additions & 8 deletions torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -1,3 +1,4 @@
import logging
import math

import numpy as np
Expand Down Expand Up @@ -63,10 +64,7 @@ def __init__(
# there will be differences between the gradients that are never synchronized.
self.rng = np.random.RandomState(random_seed)
# Since there is only a single state instance for all the input buckets,
# need to maintain a dictionary that maps each bucket to the local error.
# TODO(wayi): Currently the key is the (hashcode of) input tensor, which may change across steps,
# since the bucket can be rebuilt in the forward pass (to save peak memory usage).
# Need to add an index field to the input bucket of comm hook.
# need to maintain a dictionary that maps each bucket index to the local error.
self.error_dict = {}


Expand Down Expand Up @@ -127,11 +125,26 @@ def powerSGD_hook(
input_tensor[total_length:padded_total_length].fill_(0)

# Incorporate the error from the previous state into the gradients.
bucket_index = bucket.get_index()
if state.use_error_feedback:
if input_tensor in state.error_dict:
input_tensor.add_(state.error_dict[input_tensor])
# The buckets can be rebuilt during training.
# In this case, the error tensor shape will not be aligned with the input tensor,
# and the error will be re-initialized as zeros.
if (
bucket_index in state.error_dict
and state.error_dict[bucket_index].shape[0] == padded_total_length
):
input_tensor.add_(state.error_dict[bucket_index])
else:
state.error_dict[input_tensor] = torch.zeros(padded_total_length, device=device)
logging.info(
"A zero tensor of length {} that represents local error is created.".format(
padded_total_length
)
)
state.error_dict[bucket_index] = torch.zeros(
padded_total_length, device=device
)

# Keep a copy of the input tensor,
# so that we can compute the local error caused by compression later,
# by comparing this copy and the input tensor updated after decompression.
Expand Down Expand Up @@ -181,7 +194,7 @@ def decompress(fut):

if state.use_error_feedback:
# Memorize the local errors.
state.error_dict[input_tensor] = input_tensor_cp - input_tensor
state.error_dict[bucket_index] = input_tensor_cp - input_tensor
ret = input_tensor.resize_(total_length)
return [ret]

Expand Down
11 changes: 10 additions & 1 deletion torch/lib/c10d/comm.hpp
Expand Up @@ -20,15 +20,22 @@ void broadcast_coalesced(
class GradBucket {
public:
explicit GradBucket(
size_t index,
const std::vector<at::Tensor>& tensors,
const std::vector<size_t>& offsets = {},
const std::vector<size_t>& lengths = {},
const std::vector<c10::IntArrayRef>& sizes_vec = {})
: tensors_(tensors),
: index_(index),
tensors_(tensors),
offsets_(offsets),
lengths_(lengths),
sizes_vec_(sizes_vec) {}

// Returns the index of the bucket, which is unique across all the buckets.
size_t getIndex() const {
return index_;
}

// Each tensor in the list that getTensors returns refers to the replica on
// each device. There will be multiple replicas only in the case of single
// process multiple device mode. In the single process single device mode,
Expand All @@ -37,6 +44,7 @@ class GradBucket {
return tensors_;
}

// Returns a mutable tensor vector compared with the above method.
std::vector<at::Tensor>& getTensorsRef() {
return tensors_;
}
Expand All @@ -58,6 +66,7 @@ class GradBucket {
}

private:
size_t index_;
std::vector<at::Tensor> tensors_;

// Per-variable info in tensors_[0].
Expand Down
3 changes: 2 additions & 1 deletion torch/lib/c10d/reducer.cpp
Expand Up @@ -2,11 +2,11 @@

#include <functional>

#include <c10d/comm.hpp>
#include <c10/core/DeviceGuard.h>
#include <c10/core/StreamGuard.h>
#include <c10/util/Exception.h>
#include <c10/util/hash.h>
#include <c10d/comm.hpp>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function_hook.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
Expand Down Expand Up @@ -713,6 +713,7 @@ void Reducer::mark_bucket_ready(size_t bucket_index) {
bucket.work = process_group_->allreduce(tensors);
} else {
GradBucket grad_bucket(
next_bucket_,
tensors,
// Since currently we do not support single-process multiple-device
// mode, we can assume only one replica in the bucket.
Expand Down

0 comments on commit a9e30a4

Please sign in to comment.