[Gradient Compression] Make GradBucket class public (#53099)
Summary:
Pull Request resolved: #53099

Publish the GradBucket APIs used to write DDP communication hooks.

s/_GradBucket/GradBucket
ghstack-source-id: 123030921

Test Plan: waitforbuildbot

Reviewed By: rohan-varma

Differential Revision: D26721121

fbshipit-source-id: ee5f68e33095b9965b51937b86cdeb331fd2419a
Yi Wang authored and facebook-github-bot committed Mar 4, 2021
1 parent b59075e commit 68b6249
Showing 8 changed files with 27 additions and 27 deletions.
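
For reference, a minimal sketch of a DDP communication hook written against the now-public dist.GradBucket, modeled on the allreduce hooks in this diff. It assumes model is an already-constructed DistributedDataParallel instance (as in the tests below), process_group is the c10d process group it trains over, and the backend's work objects support get_future (e.g. NCCL):

import torch
import torch.distributed as dist

def allreduce_hook(
    process_group: object, bucket: dist.GradBucket
) -> torch.futures.Future:
    # Average the bucket's gradient tensors across all ranks.
    tensors = [t / process_group.size() for t in bucket.get_tensors()]
    # Launch an async allreduce and hand its future back to DDP.
    return process_group.allreduce(tensors).get_future()

model.register_comm_hook(state=process_group, hook=allreduce_hook)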
20 changes: 10 additions & 10 deletions test/distributed/test_c10d.py
@@ -360,7 +360,7 @@ def _create_client(self, index, addr, port, world_size, messages):
client_store = dist.TCPStore(addr, port, world_size, timeout=timedelta(seconds=10))
self.assertEqual("value".encode(), client_store.get("key"))
client_store.set(f"new_key{index}", f"new_value{index}")
- self.assertEqual(f"next_value{index}".encode(),
+ self.assertEqual(f"next_value{index}".encode(),
client_store.compare_set(f"new_key{index}", f"new_value{index}", f"next_value{index}"))
except Exception:
messages.put('Caught exception: \n{}exiting process with exit code: {}'
@@ -3057,7 +3057,7 @@ def test_accumulate_gradients_no_sync_allreduce_hook(self):
"""

def allreduce_hook(
- process_group: object, bucket: dist._GradBucket
+ process_group: object, bucket: dist.GradBucket
) -> torch._C.Future:
tensors = [t / self.world_size for t in bucket.get_tensors()]
return process_group.allreduce(tensors).get_future()
@@ -3077,7 +3077,7 @@ def test_accumulate_gradients_no_sync_allreduce_with_then_hook(self):
"""

def allreduce_with_then_hook(
- process_group: object, bucket: dist._GradBucket
+ process_group: object, bucket: dist.GradBucket
) -> torch.futures.Future:
fut = process_group.allreduce(bucket.get_tensors()).get_future()

@@ -3727,7 +3727,7 @@ def _run_and_verify_hook(self, model, input, expected_grad):
[self.assertEqual(p.grad, expected_grad) for p in model.parameters()]

def _simple_hook(
- self, state: object, bucket: dist._GradBucket
+ self, state: object, bucket: dist.GradBucket
) -> torch.futures.Future:
fut = torch.futures.Future()
fut.set_result([torch.ones_like(t) for t in bucket.get_tensors()])
@@ -3782,7 +3782,7 @@ def _test_ddp_comm_hook_allreduce_hook_nccl(self, gradient_as_bucket_view=False)
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)

- def allreduce_hook(state: object, bucket: dist._GradBucket) -> torch._C.Future:
+ def allreduce_hook(state: object, bucket: dist.GradBucket) -> torch._C.Future:
tensors = [t / self.world_size for t in bucket.get_tensors()]
return process_group.allreduce(tensors).get_future()

@@ -3930,7 +3930,7 @@ def test_ddp_comm_hook_allreduce_with_then_hook_nccl(self):
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)

def allreduce_with_then_hook(
- state: object, bucket: dist._GradBucket
+ state: object, bucket: dist.GradBucket
) -> torch.futures.Future:
tensors = [t / self.world_size for t in bucket.get_tensors()]
fut = process_group.allreduce(tensors).get_future()
@@ -3972,7 +3972,7 @@ def test_ddp_invalid_comm_hook_init(self):
model.register_comm_hook(state=None, hook=1)

with self.assertRaisesRegex(
- ValueError, "bucket annotation should be dist._GradBucket."
+ ValueError, "bucket annotation should be dist.GradBucket."
):

def comm_hook(state: object, bucket: int) -> torch.futures.Future:
@@ -3999,7 +3999,7 @@ def test_ddp_invalid_comm_hook_return_type(self):
"Communication hook: return annotation should be torch.futures.Future or torch._C.Future.",
):

- def comm_hook(state: object, bucket: dist._GradBucket) -> int:
+ def comm_hook(state: object, bucket: dist.GradBucket) -> int:
return torch.futures.Future()

model.register_comm_hook(state=None, hook=comm_hook)
@@ -4009,7 +4009,7 @@ def comm_hook(state: object, bucket: dist._GradBucket) -> int:
"callback must return a torch.futures.Future or torch._C.Future object, but got",
):

- def comm_hook(state: object, bucket: dist._GradBucket):
+ def comm_hook(state: object, bucket: dist.GradBucket):
return 1

model.register_comm_hook(state=None, hook=comm_hook)
@@ -4067,7 +4067,7 @@ def test_ddp_comm_hook_sparse_gradients(self):
# "get_future" API does not support gloo backend, see GH Issue #42048.
# Instead, we wait for an allreduce work, and write its result to a Future.
def allreduce_hook_gloo(
- state: object, bucket: dist._GradBucket
+ state: object, bucket: dist.GradBucket
) -> torch.futures.Future:
# Prepare allreduced grad bucket tensors by running an async work.
work = process_group.allreduce(bucket.get_tensors())
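
The gloo hook above is truncated by the diff view; a sketch of how such a hook completes is shown below (not the verbatim test code; process_group and the averaging step are assumptions based on the surrounding tests):

def allreduce_hook_gloo(
    state: object, bucket: dist.GradBucket
) -> torch.futures.Future:
    # Run the allreduce as an async work and block on it, since gloo work
    # objects do not expose get_future() (see GH Issue #42048 referenced above).
    work = process_group.allreduce(bucket.get_tensors())
    work.wait()
    # Return the averaged, allreduced tensors to DDP through a completed Future.
    fut = torch.futures.Future()
    fut.set_result([t / process_group.size() for t in bucket.get_tensors()])
    return fut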
2 changes: 1 addition & 1 deletion torch/_C/_distributed_c10d.pyi
@@ -15,7 +15,7 @@ class BuiltinCommHookType(Enum):
def _register_comm_hook(reducer: Reducer, state: Any, comm_hook: Any): ...
def _register_builtin_comm_hook(reducer: Reducer, comm_hook_type: BuiltinCommHookType): ...

- class _GradBucket:
+ class GradBucket:
def __init__(self, tensors: List[Tensor]): ...
def get_index(self) -> int: ...
def get_tensors(self) -> List[Tensor]: ...
4 changes: 2 additions & 2 deletions torch/csrc/distributed/c10d/init.cpp
@@ -183,7 +183,7 @@ PyObject* c10d_init(PyObject* _unused, PyObject* noargs) {
py::arg("reducer"),
py::arg("comm_hook_type"));

- shared_ptr_class_<::c10d::GradBucket>(module, "_GradBucket")
+ shared_ptr_class_<::c10d::GradBucket>(module, "GradBucket")
.def(
py::init<
size_t,
@@ -1231,7 +1231,7 @@ that adds a prefix to each key inserted to the store.
``get_future` API to retrieve a Future associated with the completion of
``allreduce`` work.
- >>> def allreduce(state: object, bucket: dist._GradBucket): -> torch._C.Future
+ >>> def allreduce(state: object, bucket: dist.GradBucket): -> torch._C.Future
>>> tensors = [t / process_group.world_size for t in bucket.get_tensors()]
>>> work = process_group.allreduce(tensors)
>>> return work.get_future()
2 changes: 1 addition & 1 deletion torch/distributed/__init__.py
@@ -63,8 +63,8 @@ def _get_debug_mode():
Reducer,
Logger,
BuiltinCommHookType,
+ GradBucket,
_DEFAULT_FIRST_BUCKET_BYTES,
- _GradBucket,
_register_comm_hook,
_register_builtin_comm_hook,
_broadcast_coalesced,
4 changes: 2 additions & 2 deletions torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
@@ -17,7 +17,7 @@ def div_by_group_size(fut):


def allreduce_hook(
- process_group: dist.ProcessGroup, bucket: dist._GradBucket
+ process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future:
"""
This DDP communication hook just calls ``allreduce`` using ``GradBucket``
@@ -35,7 +35,7 @@ def allreduce_hook(


def fp16_compress_hook(
- process_group: dist.ProcessGroup, bucket: dist._GradBucket
+ process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future:
"""
This DDP communication hook implements a simple gradient compression
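
A usage sketch for these built-in hooks, assuming model is an already-constructed DistributedDataParallel instance and that passing None as the state selects the default process group:

from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default

# Cast gradients to FP16 before allreduce; the hook casts the averaged result back afterwards.
model.register_comm_hook(state=None, hook=default.fp16_compress_hook)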
8 changes: 4 additions & 4 deletions torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -172,7 +172,7 @@ def maybe_increase_iter(self, bucket):


def powerSGD_hook(
- state: PowerSGDState, bucket: dist._GradBucket
+ state: PowerSGDState, bucket: dist.GradBucket
) -> torch.futures.Future:
r"""
This DDP communication hook implements PowerSGD gradient compression
@@ -217,7 +217,7 @@ def powerSGD_hook(
state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
To tune the compression configs, mainly need to tune ``matrix_approximation_rank``, ``start_powerSGD_iter``
and ``min_compression_rate``.
- bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
+ bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
Note that since DDP comm hook only supports single process single device mode at this time,
only exactly one tensor is stored in this bucket.
@@ -440,7 +440,7 @@ def decompress(fut):


def batched_powerSGD_hook(
- state: PowerSGDState, bucket: dist._GradBucket
+ state: PowerSGDState, bucket: dist.GradBucket
) -> torch.futures.Future:
r"""
This DDP communication hook implements a simplified PowerSGD gradient compression
@@ -484,7 +484,7 @@ def batched_powerSGD_hook(
Args:
state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
To tune the compression configs, mainly need to tune ``matrix_approximation_rank`` and ``start_powerSGD_iter``.
- bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
+ bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
Note that since DDP comm hook only supports single process single device mode at this time,
only exactly one tensor is stored in this bucket.
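
A usage sketch for the PowerSGD hooks; the keyword names are taken from the config options named in the docstrings above, so the exact PowerSGDState constructor signature is an assumption to verify against the source:

from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

state = powerSGD.PowerSGDState(
    process_group=None,           # None is assumed to select the default process group
    matrix_approximation_rank=1,  # rank of the low-rank gradient approximation
    start_powerSGD_iter=10,       # run vanilla allreduce for the first 10 iterations
)
model.register_comm_hook(state, powerSGD.powerSGD_hook)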
4 changes: 2 additions & 2 deletions torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py
@@ -43,7 +43,7 @@ def _get_allgather_out_list(all_gather_in_list, world_size):


def quantization_pertensor_hook(
- process_group: dist.ProcessGroup, bucket: dist._GradBucket
+ process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future:
"""
Applies the ``torch.quantize_per_tensor`` logic to DDP using ``allgather``
@@ -116,7 +116,7 @@ def dequantize_and_aggregate(fut):


def quantization_perchannel_hook(
- process_group: dist.ProcessGroup, bucket: dist._GradBucket, bucket_size=512
+ process_group: dist.ProcessGroup, bucket: dist.GradBucket, bucket_size=512
) -> torch.futures.Future:
"""
Applies the ``torch.quantize_per_channel`` logic to DDP using ``allgather``
10 changes: 5 additions & 5 deletions torch/nn/parallel/distributed.py
@@ -1065,7 +1065,7 @@ def register_comm_hook(self, state: object, hook: callable):
It is locally stored by each worker
and shared by all the gradient tensors on the worker.
hook (callable): Averages gradient tensors across workers and defined as:
- ``hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future``:
+ ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future``:
This function is called once the bucket is ready. The
hook can perform whatever processing is needed and return
@@ -1107,7 +1107,7 @@ def register_comm_hook(self, state: object, hook: callable):
Example::
Below is an example of a noop hook that returns the same tensors.
- >>> def noop(state: object, bucket: dist._GradBucket): -> torch.futures.Future
+ >>> def noop(state: object, bucket: dist.GradBucket): -> torch.futures.Future
>>> fut = torch.futures.Future()
>>> fut.set_result(bucket.get_tensors())
>>> return fut
@@ -1118,7 +1118,7 @@ def register_comm_hook(self, state: object, hook: callable):
Below is an example of a Parallel SGD algorithm where gradients are encoded before
allreduce, and then decoded after allreduce.
- >>> def encode_and_decode(state: object, bucket: dist._GradBucket): -> torch.futures.Future
+ >>> def encode_and_decode(state: object, bucket: dist.GradBucket): -> torch.futures.Future
>>> tensors = [t / process_group.world_size for t in bucket.get_tensors()]
>>> encoded_tensors = encode(tensors) # encode gradients
>>> fut = process_group.allreduce(encoded_tensors).get_future()
@@ -1270,10 +1270,10 @@ def _check_comm_hook(self, hook):
sig = inspect.signature(hook)
if (
sig.parameters["bucket"].annotation != inspect._empty
- and sig.parameters["bucket"].annotation != dist._GradBucket
+ and sig.parameters["bucket"].annotation != dist.GradBucket
):
raise ValueError(
"Communication hook: bucket annotation should be dist._GradBucket."
"Communication hook: bucket annotation should be dist.GradBucket."
)

if sig.return_annotation != inspect._empty and (
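
To illustrate the check in _check_comm_hook above: annotating bucket stays optional, but when an annotation is present it must now be dist.GradBucket. A sketch mirroring the tests in this diff:

import torch
import torch.distributed as dist

def untyped_hook(state, bucket):  # accepted: no annotation to validate
    fut = torch.futures.Future()
    fut.set_result(bucket.get_tensors())
    return fut

def typed_hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future:
    fut = torch.futures.Future()  # accepted: annotation matches dist.GradBucket
    fut.set_result(bucket.get_tensors())
    return fut

def bad_hook(state: object, bucket: int) -> torch.futures.Future:
    return torch.futures.Future()  # rejected at register_comm_hook() with the ValueError above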
