Commit
[Gradient Compression] Directly let world_size = group_to_use.size() (#49715)

Summary:
Pull Request resolved: #49715

Address the comment on #49417 (comment)
ghstack-source-id: 119049598

Test Plan: waitforbuildbot

Reviewed By: rohan-varma

Differential Revision: D25673997

fbshipit-source-id: 44eb2540e5a77331c34ba503285cbd0bd63c2c0a
Yi Wang authored and facebook-github-bot committed Dec 23, 2020
1 parent 88c33ff commit 55b431b
Showing 3 changed files with 7 additions and 21 deletions.
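
For context, each hook below currently computes world_size with its own conditional on process_group; the commit replaces that with a single call on the already-resolved group. A minimal sketch of the equivalence (the helper name is illustrative; it assumes the default process group has been initialized):

import torch.distributed as dist

def resolve_world_size(process_group):
    # Old pattern (removed by this commit): branch once for the group and
    # again for the size.
    #   world_size = (
    #       process_group.size() if process_group is not None else dist.get_world_size()
    #   )

    # New pattern: resolve the group first, then ask it for its size. When
    # process_group is None, dist.group.WORLD.size() reports the same value
    # as dist.get_world_size(), so the two forms agree.
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    return group_to_use.size()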
12 changes: 3 additions & 9 deletions torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
@@ -18,9 +18,7 @@ def allreduce_hook(
         >>> ddp_model.register_comm_hook(process_group, allreduce_hook)
     """
     group_to_use = process_group if process_group is not None else dist.group.WORLD
-    world_size = (
-        process_group.size() if process_group is not None else dist.get_world_size()
-    )
+    world_size = group_to_use.size()
 
     tensor = bucket.get_tensors()[0]
     fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()
@@ -46,9 +44,7 @@ def fp16_compress_hook(
         >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
     """
     group_to_use = process_group if process_group is not None else dist.group.WORLD
-    world_size = (
-        process_group.size() if process_group is not None else dist.get_world_size()
-    )
+    world_size = group_to_use.size()
 
     compressed_tensor = bucket.get_tensors()[0].to(torch.float16)
 
@@ -100,9 +96,7 @@ def _allgather_then_aggregate_hook(
     """
     group_to_use = process_group if process_group is not None else dist.group.WORLD
     rank = process_group.rank() if process_group is not None else dist.get_rank()
-    world_size = (
-        process_group.size() if process_group is not None else dist.get_world_size()
-    )
+    world_size = group_to_use.size()
 
     tensor = bucket.get_tensors()[0]
     fut = dist.all_gather(
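For reference, the docstrings in the hunks above show how these hooks attach to a DDP model. A minimal registration sketch along those lines (the model and setup are illustrative; it assumes torch.distributed has already been initialized, e.g. with a NCCL backend and one GPU per rank):

import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

rank = dist.get_rank()
model = nn.Linear(10, 10).to(rank)
ddp_model = DDP(model, device_ids=[rank])

# Passing None as the process group makes the hook fall back to
# dist.group.WORLD, the exact case this commit simplifies.
ddp_model.register_comm_hook(None, default_hooks.fp16_compress_hook)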
8 changes: 2 additions & 6 deletions torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -126,9 +126,7 @@ def powerSGD_hook(
     """
     process_group = state.process_group
     group_to_use = process_group if process_group is not None else dist.group.WORLD
-    world_size = (
-        process_group.size() if process_group is not None else dist.get_world_size()
-    )
+    world_size = group_to_use.size()
 
     # The input tensor is a flattened 1D tensor.
     input_tensor = bucket.get_tensors()[0]
@@ -363,9 +361,7 @@ def batched_powerSGD_hook(
     """
     process_group = state.process_group
     group_to_use = process_group if process_group is not None else dist.group.WORLD
-    world_size = (
-        process_group.size() if process_group is not None else dist.get_world_size()
-    )
+    world_size = group_to_use.size()
 
     # The input tensor is a flattened 1D tensor.
     input_tensor = bucket.get_tensors()[0]
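The PowerSGD hooks read the group from their state object instead (state.process_group in the hunks above). A registration sketch, reusing ddp_model from the previous example; the PowerSGDState constructor arguments shown here are assumptions about its signature at this commit:

from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

# With process_group=None, the hook resolves dist.group.WORLD and takes
# world_size from it, matching the simplification above.
state = powerSGD.PowerSGDState(process_group=None, matrix_approximation_rank=1)
ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)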
@@ -63,9 +63,7 @@ def quantization_pertensor_hook(
     """
     group_to_use = process_group if process_group is not None else dist.group.WORLD
     rank = process_group.rank() if process_group is not None else dist.get_rank()
-    world_size = (
-        process_group.size() if process_group is not None else dist.get_world_size()
-    )
+    world_size = group_to_use.size()
 
     tensor = bucket.get_tensors()[0]
 
@@ -144,9 +142,7 @@ def quantization_perchannel_hook(
     """
     group_to_use = process_group if process_group is not None else dist.group.WORLD
     rank = process_group.rank() if process_group is not None else dist.get_rank()
-    world_size = (
-        process_group.size() if process_group is not None else dist.get_world_size()
-    )
+    world_size = group_to_use.size()
 
     tensor = bucket.get_tensors()[0]
 
