[Gradient Compression] Directly let world_size = process_group.size()
Address the comment on #49417.

Differential Revision: [D25673997](https://our.internmc.facebook.com/intern/diff/D25673997/)

ghstack-source-id: 119021459
Pull Request resolved: #49715
wayi committed Dec 21, 2020
1 parent 4544f0f commit 4d551bc
Showing 3 changed files with 7 additions and 21 deletions.
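
For context, every hook touched in this commit resolves a process group and then needs that group's world size to turn the summed all-reduce result into an average. A minimal illustrative sketch of the pattern follows; the name `allreduce_mean_hook` is a placeholder, and this is a sketch of the idea rather than the exact library code:

```python
import torch.distributed as dist


def allreduce_mean_hook(process_group, bucket):
    # Resolve the group once; the world size is then read directly from the
    # group instead of the old conditional on dist.get_world_size().
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = group_to_use.size()

    tensor = bucket.get_tensors()[0]
    fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()

    def then_callback(fut):
        # The all-reduce sums gradients across ranks; divide by world_size
        # to average them before DDP copies the result back into the bucket.
        return [fut.value()[0].div_(world_size)]

    return fut.then(then_callback)
```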
12 changes: 3 additions & 9 deletions torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
@@ -18,9 +18,7 @@ def allreduce_hook(
    >>> ddp_model.register_comm_hook(process_group, allreduce_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
-    world_size = (
-        process_group.size() if process_group is not None else dist.get_world_size()
-    )
+    world_size = process_group.size()

    tensor = bucket.get_tensors()[0]
    fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()
@@ -46,9 +44,7 @@ def fp16_compress_hook(
    >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
-    world_size = (
-        process_group.size() if process_group is not None else dist.get_world_size()
-    )
+    world_size = process_group.size()

    compressed_tensor = bucket.get_tensors()[0].to(torch.float16)

@@ -100,9 +96,7 @@ def _allgather_then_aggregate_hook(
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    rank = process_group.rank() if process_group is not None else dist.get_rank()
-    world_size = (
-        process_group.size() if process_group is not None else dist.get_world_size()
-    )
+    world_size = process_group.size()

    tensor = bucket.get_tensors()[0]
    fut = dist.all_gather(
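For reference, the docstrings above show how these functional hooks are attached to a model. A hedged usage sketch, assuming the process group has already been initialized and that the module lives on this rank's CUDA device:

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

# Assumes dist.init_process_group(...) has already been called and the
# current CUDA device has been set for this rank.
model = torch.nn.Linear(1024, 1024).cuda()
ddp_model = DDP(model, device_ids=[torch.cuda.current_device()])

# The first argument is forwarded to the hook as process_group; the hooks
# fall back to dist.group.WORLD when no explicit group is supplied.
ddp_model.register_comm_hook(dist.group.WORLD, default_hooks.fp16_compress_hook)
```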
8 changes: 2 additions & 6 deletions torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -126,9 +126,7 @@ def powerSGD_hook(
    """
    process_group = state.process_group
    group_to_use = process_group if process_group is not None else dist.group.WORLD
-    world_size = (
-        process_group.size() if process_group is not None else dist.get_world_size()
-    )
+    world_size = group_to_use.size()

    # The input tensor is a flattened 1D tensor.
    input_tensor = bucket.get_tensors()[0]
@@ -363,9 +361,7 @@ def batched_powerSGD_hook(
    """
    process_group = state.process_group
    group_to_use = process_group if process_group is not None else dist.group.WORLD
-    world_size = (
-        process_group.size() if process_group is not None else dist.get_world_size()
-    )
+    world_size = process_group.size()

    # The input tensor is a flattened 1D tensor.
    input_tensor = bucket.get_tensors()[0]
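powerSGD_hook and batched_powerSGD_hook take a PowerSGDState rather than a bare process group, which is why they read state.process_group before resolving group_to_use. A hedged registration sketch, reusing the `ddp_model` from the previous sketch; PowerSGDState options beyond process_group and matrix_approximation_rank are omitted because they vary by version:

```python
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

# Assumes dist.init_process_group(...) has been called and ddp_model is the
# DistributedDataParallel instance from the earlier sketch.
state = powerSGD.PowerSGDState(
    process_group=None,  # None resolves to dist.group.WORLD inside the hook
    matrix_approximation_rank=1,  # rank of the low-rank gradient approximation
)
ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)
```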
8 changes: 2 additions & 6 deletions torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py
@@ -63,9 +63,7 @@ def quantization_pertensor_hook(
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    rank = process_group.rank() if process_group is not None else dist.get_rank()
-    world_size = (
-        process_group.size() if process_group is not None else dist.get_world_size()
-    )
+    world_size = group_to_use.size()

    tensor = bucket.get_tensors()[0]

@@ -144,9 +142,7 @@ def quantization_perchannel_hook(
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    rank = process_group.rank() if process_group is not None else dist.get_rank()
-    world_size = (
-        process_group.size() if process_group is not None else dist.get_world_size()
-    )
+    world_size = process_group.size()

    tensor = bucket.get_tensors()[0]

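The per-tensor and per-channel quantization hooks follow the same (process_group, bucket) signature, so registration mirrors the default hooks. A sketch under the assumption that the file above is quantization_hooks.py in the same package and that the earlier setup (initialized group, CUDA `ddp_model`) is in place:

```python
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks import quantization_hooks as quant

# Same assumptions as the earlier sketches. These hooks all-gather quantized
# tensors across ranks, which is why they need both the rank and the
# world_size computed in the lines changed by this commit.
ddp_model.register_comm_hook(dist.group.WORLD, quant.quantization_pertensor_hook)
```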
