Missing documentation for device mesh on DDP #159836

@conceptofmind

Description

📚 The doc issue

The current documentation has no definition of, or reference to, how a device mesh should be used with DDP: https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html

Is this the correct way to initialize the device mesh?

import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.nn.parallel import DistributedDataParallel as DDP

def init_dist():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    assert torch.cuda.is_available()
    device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
    torch.cuda.set_device(device)
    dist.init_process_group(backend="nccl", device_id=device)
    dist.barrier()
    master_process = (rank == 0)
    # init_device_mesh expects a device type string ("cuda"), not a torch.device
    ddp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
    return ddp_mesh, device, master_process

ddp_mesh, device, master_process = init_dist()

model = MyPyTorchModel().to(device)
model = DDP(model, device_mesh=ddp_mesh)
model = torch.compile(model, dynamic=False)
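
If it helps, here is a quick sanity check I would expect to hold (my assumption: a 1-D mesh spanning all ranks is backed by a process group the same size as the default WORLD group):

# Assumption: a 1-D mesh covering every rank should be backed by a group
# whose size matches the default (WORLD) process group.
mesh_group = ddp_mesh.get_group(0)
assert dist.get_world_size(mesh_group) == dist.get_world_size()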

Suggest a potential alternative/fix

An example and relevant documentation on how to use DDP with device mesh.
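
Something like the following is what I would expect such an example to look like. This is a minimal sketch, assuming device_mesh is meant as a drop-in alternative to process_group, that only 1-D meshes are accepted, and that passing both arguments is not allowed; nn.Linear stands in for a real model:

import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.nn.parallel import DistributedDataParallel as DDP

# Launched via: torchrun --nproc-per-node=N this_script.py
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")

# A 1-D mesh spanning every rank in the job.
mesh = init_device_mesh("cuda", mesh_shape=(dist.get_world_size(),))

model = nn.Linear(16, 16).cuda()

# Assumption: these two constructions are equivalent, since DDP should
# derive its process group from the 1-D mesh. Pass one or the other,
# not both.
ddp = DDP(model, device_mesh=mesh)
# ddp = DDP(model, process_group=mesh.get_group(0))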

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @svekars @sekyondaMeta @AlannaBurke

Labels: module: ddp, module: docs, oncall: distributed