Labels: module: ddp · module: docs · oncall: distributed
Description
📚 The doc issue
The current documentation has no definition of, or reference to, how a device mesh should be used with DDP: https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
Is this the correct way to initialize the device mesh?
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.nn.parallel import DistributedDataParallel as DDP

def init_dist():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    assert torch.cuda.is_available()
    device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
    torch.cuda.set_device(device)
    dist.init_process_group(backend="nccl", device_id=device)
    dist.barrier()
    master_process = (rank == 0)
    # init_device_mesh expects the device type as a string (e.g. "cuda"), not a torch.device
    ddp_mesh = init_device_mesh(device_type=device.type, mesh_shape=(world_size,))
    return ddp_mesh, device, master_process

ddp_mesh, device, master_proc = init_dist()

model = MyPyTorchModel().to(device)
model = DDP(model, device_mesh=ddp_mesh)
model = torch.compile(model, dynamic=False)
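
If it helps, here is a small sanity check I would expect to hold, based on my reading of the DDP constructor (an assumption, since the docs don't state it): the mesh passed to DDP should be 1-D and span all ranks.

# Sanity checks (assumptions, reusing ddp_mesh from the snippet above)
assert ddp_mesh.ndim == 1                      # DDP appears to expect a 1-D mesh
assert ddp_mesh.size() == dist.get_world_size()  # one data-parallel dimension over all ranks
dp_group = ddp_mesh.get_group()                # the process group DDP would communicate over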
Suggest a potential alternative/fix
Add an example and relevant documentation showing how to use DDP with a device mesh.
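
For context, a minimal sketch of what such an example might contrast (an assumption about useful content, not an official recommendation; it reuses MyPyTorchModel, device, and ddp_mesh from the snippet above):

# Traditional form: DDP uses the default process group created by init_process_group.
model_pg = DDP(MyPyTorchModel().to(device), device_ids=[device.index])

# Device-mesh form: DDP derives its process group from the 1-D mesh instead.
model_mesh = DDP(MyPyTorchModel().to(device), device_mesh=ddp_mesh)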
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @svekars @sekyondaMeta @AlannaBurke