Skip the extra copy operation in broadcast_object_list if tensor_list has only one element (#107509)

The `broadcast_object_list` function is a convenient way to broadcast the state_dict of models/optimizers. However, the `torch.cat` operation performed within `broadcast_object_list` makes an extra copy of the serialized data, doubling peak memory usage. As a result, only objects occupying at most half of the device's memory can be broadcast. This PR improves usability by skipping the `torch.cat` operation when the object list contains only a single element.
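A minimal sketch of why the copy occurs (illustrative, not part of the PR): `torch.cat` always allocates a fresh output buffer, even when given a single-element list, so its result never aliases the input.

```python
import torch

t = torch.randn(4)
c = torch.cat([t])  # torch.cat allocates a new output buffer and copies t into it

# Even for a one-element list, the result does not share storage with the input:
print(c.data_ptr() == t.data_ptr())  # False
```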

Before (30G tensor):
![Memory usage before the change](https://github.com/pytorch/pytorch/assets/22362311/c0c67931-0851-4f27-81c1-0119c6cd2944)

After (46G tensor):
![Memory usage after the change](https://github.com/pytorch/pytorch/assets/22362311/90cd1536-be7c-43f4-82ef-257234afcfa5)

Test Code:
```python
import torch
import torch.distributed as dist

if __name__ == "__main__":
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # 30 GiB of float32 elements (4 bytes each)
    fake_tensor = torch.randn(30 * 1024 * 1024 * 1024 // 4)

    if dist.get_rank() == 0:
        state_dict = {"fake_tensor": fake_tensor}
    else:
        state_dict = {}
    object_list = [state_dict]

    # Rank 0 broadcasts its state_dict; the other ranks receive it in object_list.
    dist.broadcast_object_list(object_list, src=0)
    print("Rank: ", dist.get_rank(), " Broadcasted Object: ", object_list[0].keys())
    dist.barrier()
```
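The script can be launched with `torchrun` (assuming a multi-GPU node and a hypothetical file name `test_broadcast.py`): `torchrun --nproc_per_node=2 test_broadcast.py`.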
Pull Request resolved: #107509
Approved by: https://github.com/awgu
Codle authored and pytorchmergebot committed Aug 23, 2023
1 parent ecde622 commit 42738c5
Showing 1 changed file: torch/distributed/distributed_c10d.py (6 additions, 1 deletion)
```diff
@@ -2603,8 +2603,13 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
     broadcast(object_sizes_tensor, src=src, group=group)
 
     # Concatenate and broadcast serialized object tensors
+    # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list
+    # has only one element, we can skip the copy.
     if my_rank == src:
-        object_tensor = torch.cat(tensor_list)
+        if len(tensor_list) == 1:
+            object_tensor = tensor_list[0]
+        else:
+            object_tensor = torch.cat(tensor_list)
     else:
         object_tensor = torch.empty(  # type: ignore[call-overload]
             torch.sum(object_sizes_tensor).item(),  # type: ignore[arg-type]
```
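In other words, when the serialized `tensor_list` contains a single element, `object_tensor` on the source rank now aliases `tensor_list[0]` directly, so the subsequent `broadcast` sends the existing buffer instead of a freshly concatenated copy.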
