Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Skip the extra copy operation in broadcast_object_list if tensor_list…
… has only one element (#107509) The `broadcast_object_list` function can easily broadcast the state_dict of models/optimizers. However, the `torch.cat` operation performed within `broadcast_object_list` consumes an additional double amount of memory space. This means that only objects with a maximum memory occupancy of half the device capacity can be broadcasted. This PR improves usability by skipping the `torch.cat` operation on object_lists with only a single element. Before (30G tensor): <img width="607" alt="image" src="https://github.com/pytorch/pytorch/assets/22362311/c0c67931-0851-4f27-81c1-0119c6cd2944"> After (46G tensor): <img width="600" alt="image" src="https://github.com/pytorch/pytorch/assets/22362311/90cd1536-be7c-43f4-82ef-257234afcfa5"> Test Code: ```python if __name__ == "__main__": dist.init_process_group(backend='nccl') torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) fake_tensor = torch.randn(30 * 1024 * 1024 * 1024 // 4) if dist.get_rank() == 0: state_dict = {"fake_tensor": fake_tensor} else: state_dict = {} object_list = [state_dict] dist.broadcast_object_list(object_list, src=0) print("Rank: ", dist.get_rank(), " Broadcasted Object: ", object_list[0].keys()) dist.barrier() ``` Pull Request resolved: #107509 Approved by: https://github.com/awgu
- Loading branch information