## Why DeviceMesh?

Setting up distributed communicators for distributed training, such as `NCCL` (NVIDIA Collective Communications Library) communicators, can pose a significant challenge.

Users may need to manually set up and manage NCCL communicators (e.g. `ProcessGroup`s) for each parallelism solution. This becomes a real headache when the workload combines several kinds of parallelism.

Moreover, because the setup is manual, it is prone to errors.

`DeviceMesh` can simplify this process, making it more manageable and less prone to errors.

### What is a DeviceMesh?

It is a higher-level abstraction that manages `ProcessGroup`s:

- Easy setup of intra-node and inter-node process groups
- No need to worry about setting up ranks correctly for each subgroup
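To make that bookkeeping concrete, here is a minimal pure-Python sketch of the rank layout a 2D mesh implies. It assumes ranks are assigned to mesh coordinates in row-major order, which is our understanding of how `init_device_mesh` builds its mesh from `range(world_size)`. The helpers `mesh_layout` and `groups_along_dim` are hypothetical and not part of PyTorch. Each group they return corresponds to one sub-`ProcessGroup` that DeviceMesh would create for that mesh dimension.

```python
from itertools import product
from math import prod


def mesh_layout(shape):
    """Map each mesh coordinate to a global rank, row-major (hypothetical helper)."""
    return {coord: rank for rank, coord in enumerate(product(*(range(n) for n in shape)))}


def groups_along_dim(shape, dim):
    """Return the rank groups that vary only along mesh dimension `dim` (hypothetical helper)."""
    layout = mesh_layout(shape)
    other_dims = [d for d in range(len(shape)) if d != dim]
    groups = []
    # Fix every coordinate except `dim`, then collect the ranks along `dim`.
    for fixed in product(*(range(shape[d]) for d in other_dims)):
        group = []
        for i in range(shape[dim]):
            coord = list(fixed)
            coord.insert(dim, i)
            group.append(layout[tuple(coord)])
        groups.append(group)
    return groups


shape = (2, 4)  # ("replicate", "shard") on 8 GPUs
assert prod(shape) == 8
print(groups_along_dim(shape, 1))  # shard groups -> [[0, 1, 2, 3], [4, 5, 6, 7]]
print(groups_along_dim(shape, 0))  # replicate groups -> [[0, 4], [1, 5], [2, 6], [3, 7]]
```

These match the shard groups `(0, 1, 2, 3), (4, 5, 6, 7)` and replicate groups `(0, 4), (1, 5), (2, 6), (3, 7)` that the manual setup computes by hand.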

### Use

DeviceMesh is useful when working with multi-dimensional parallelism (e.g. 3D parallelism), where parallelism composability is required, for example:
- Parallelism solutions that require communication both across hosts and within each host.

For example, suppose there are 2 host machines and each host machine has 2 GPUs.

Without DeviceMesh, the user has to manually set up the NCCL communicators and CUDA devices on each process before applying any parallelism, which is complicated and error-prone.
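Part of what makes the manual route error-prone is that, on more than one host, a process's global rank is not the index of its local GPU. The sketch below assumes the common contiguous assignment in which host `h` runs global ranks `h * gpus_per_host` through `h * gpus_per_host + gpus_per_host - 1`; `locate` is a hypothetical helper.

```python
def locate(rank, gpus_per_host):
    """Split a global rank into (host index, local GPU index).

    Assumes ranks are assigned contiguously per host (hypothetical helper).
    """
    return divmod(rank, gpus_per_host)


num_hosts, gpus_per_host = 2, 2
world_size = num_hosts * gpus_per_host

for rank in range(world_size):
    host, local_gpu = locate(rank, gpus_per_host)
    # Intra-host peers share the host index; inter-host peers share the local GPU index.
    intra = [r for r in range(world_size) if locate(r, gpus_per_host)[0] == host]
    inter = [r for r in range(world_size) if locate(r, gpus_per_host)[1] == local_gpu]
    print(f"rank {rank}: host {host}, cuda:{local_gpu}, intra-host {intra}, inter-host {inter}")
```

Global rank 2 lives on host 1 as `cuda:0`, so code that selects the device by global rank would target a GPU that does not exist on a 2-GPU host. When launching with torchrun, the local index is also available in the `LOCAL_RANK` environment variable.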

In [1]:
import os

import torch
import torch.distributed as dist

In [2]:
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['MASTER_ADDR'] = 'localhost'  # Replace with your master address if distributed across machines
os.environ['MASTER_PORT'] = '12345'
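The cell above hard-codes the variables that `torchrun` sets for every process it launches (`RANK`, `LOCAL_RANK`, `WORLD_SIZE`, `MASTER_ADDR`, `MASTER_PORT`), which is only needed when running inside a single notebook process. A minimal sketch, assuming you want one script that works both in a notebook and under torchrun: fill in single-process defaults only when a variable is missing. `ensure_single_process_env` is a hypothetical helper, and the default port simply reuses the value above.

```python
import os


def ensure_single_process_env(environ=os.environ, port="12345"):
    """Set single-process defaults for the variables torchrun normally provides.

    Values that already exist (e.g. those exported by torchrun) are left untouched.
    Returns (rank, world_size) as integers.
    """
    defaults = {
        "RANK": "0",
        "LOCAL_RANK": "0",
        "WORLD_SIZE": "1",
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": port,
    }
    for key, value in defaults.items():
        environ.setdefault(key, value)
    return int(environ["RANK"]), int(environ["WORLD_SIZE"])


# Simulate the environment of one worker launched by torchrun with 8 processes.
fake_env = {"RANK": "3", "LOCAL_RANK": "3", "WORLD_SIZE": "8"}
print(ensure_single_process_env(fake_env))  # -> (3, 8)
```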

Save the script below as `2d_setup.py` and run it with torchrun:

```bash
torchrun --nproc_per_node=8 --rdzv_id=100 --rdzv_endpoint=localhost:29400 2d_setup.py
```

In [3]:
# Tested in a Kaggle notebook with 2 GPUs.
# Fails in Colab (1 GPU) unless the shard rank lists are adjusted for a single device.

import os

import torch
import torch.distributed as dist

# Understand world topology
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
print(f"Running example on {rank=} in a world with {world_size=}")

# Create process groups to manage 2-D like parallel pattern
dist.init_process_group("nccl")
torch.cuda.set_device(rank)

# Create shard groups (e.g. (0, 1, 2, 3), (4, 5, 6, 7))
# and assign the correct shard group to each rank
num_node_devices = torch.cuda.device_count()
shard_rank_lists = list(range(0, num_node_devices // 2)), list(range(num_node_devices // 2, num_node_devices))
shard_groups = (
    dist.new_group(shard_rank_lists[0]),
    dist.new_group(shard_rank_lists[1]),
)
current_shard_group = (
    shard_groups[0] if rank in shard_rank_lists[0] else shard_groups[1]
)

# Create replicate groups (for example, (0, 4), (1, 5), (2, 6), (3, 7))
# and assign the correct replicate group to each rank
current_replicate_group = None
shard_factor = len(shard_rank_lists[0])
for i in range(num_node_devices // 2):
    replicate_group_ranks = list(range(i, num_node_devices, shard_factor))
    replicate_group = dist.new_group(replicate_group_ranks)
    if rank in replicate_group_ranks:
        current_replicate_group = replicate_group

Running example on rank=0 in a world with world_size=1


TypeError: __init__(): incompatible constructor arguments. The following argument types are supported:
    1. torch._C._distributed_c10d.ProcessGroup(arg0: int, arg1: int)
    2. torch._C._distributed_c10d.ProcessGroup(arg0: torch._C._distributed_c10d.Store, arg1: int, arg2: int, arg3: c10d::ProcessGroup::Options)

Invoked with: <torch.distributed.distributed_c10d.PrefixStore object at 0x78955ad69bb0>, None, 0, <torch._C._distributed_c10d.ProcessGroup.Options object at 0x789559921bb0>
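The single-device caveat in the script's header follows from its rank arithmetic. The sketch below copies the list-building logic from that script with the `dist.new_group` calls removed (`manual_rank_lists` is a hypothetical name). With 8 devices it produces the intended groups; with 1 device the first shard list is empty and the replicate loop never runs.

```python
def manual_rank_lists(num_node_devices):
    """Reproduce the manual shard/replicate rank lists without creating any groups."""
    half = num_node_devices // 2
    shard_rank_lists = list(range(0, half)), list(range(half, num_node_devices))
    shard_factor = len(shard_rank_lists[0])
    # Same stride logic as the loop in the script: one replicate group per shard position.
    replicate_rank_lists = [
        list(range(i, num_node_devices, shard_factor)) for i in range(half)
    ]
    return shard_rank_lists, replicate_rank_lists


print(manual_rank_lists(8))
# -> (([0, 1, 2, 3], [4, 5, 6, 7]), [[0, 4], [1, 5], [2, 6], [3, 7]])
print(manual_rank_lists(1))
# -> (([], [0]), [])
```

With one device, rank 0 falls into the second shard list and `current_replicate_group` stays `None`, so the 2D pattern collapses. The script also derives everything from `torch.cuda.device_count()` on a single node, so as written it targets one host.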

Save the code below to a file named `2d_setup_with_device_mesh.py` and run it with torchrun:

```bash
torchrun --nproc_per_node=8 2d_setup_with_device_mesh.py
```

In [5]:
# Does not work on Kaggle due to a PyTorch version issue (`device_mesh` not found).
# Works in Colab on a single GPU with mesh shape (1, 1).
# For 8 GPUs split into 2 shard groups of 4, use mesh shape (2, 4).

from torch.distributed.device_mesh import init_device_mesh
mesh_2d = init_device_mesh("cuda", (1, 1), mesh_dim_names=("replicate", "shard"))

# Users can access the underlying process group through the `get_group` API.
replicate_group = mesh_2d.get_group(mesh_dim="replicate")
shard_group = mesh_2d.get_group(mesh_dim="shard")

In [6]:
replicate_group

<torch.distributed.distributed_c10d.ProcessGroup at 0x7b9e74970c70>

In [7]:
shard_group

<torch.distributed.distributed_c10d.ProcessGroup at 0x7b9e749727f0>
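The mesh shape has to match the hardware: `(1, 1)` for one GPU, `(2, 4)` for 8 GPUs split into 2 shard groups of 4. When the mesh covers every process, the product of its dimensions equals the world size. A minimal sketch of deriving a `(replicate, shard)` shape from the world size and a chosen shard-group size; `hsdp_mesh_shape` is a hypothetical helper, not a PyTorch API.

```python
def hsdp_mesh_shape(world_size, shard_size):
    """Return a (replicate, shard) mesh shape that covers all `world_size` ranks."""
    if shard_size <= 0 or world_size % shard_size != 0:
        raise ValueError(
            f"shard_size={shard_size} must evenly divide world_size={world_size}"
        )
    return (world_size // shard_size, shard_size)


print(hsdp_mesh_shape(8, 4))  # -> (2, 4): 2 replicas, each sharded over 4 GPUs
print(hsdp_mesh_shape(1, 1))  # -> (1, 1): the single-GPU shape used above
```

The result could be passed straight to `init_device_mesh("cuda", shape, mesh_dim_names=("replicate", "shard"))`; choosing the shard size equal to the number of GPUs per host keeps each shard group inside one machine.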

## Using DeviceMesh with HSDP

HSDP (Hybrid Sharding Data Parallel) is a 2D strategy that performs FSDP within a host and DDP across hosts. To run it, save the code below as `hsdp.py` and launch it with torchrun:

```bash
torchrun --nproc_per_node=8 hsdp.py
```


In [4]:
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

In [None]:
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

# HSDP: MeshShape(2, 4) on 8 GPUs; (1, 1) here so the cell runs on a single GPU
mesh_2d = init_device_mesh("cuda", (1, 1))
model = FSDP(
    ToyModel(), device_mesh=mesh_2d, sharding_strategy=ShardingStrategy.HYBRID_SHARD
)

# Note: a single machine with a single GPU is not a meaningful HSDP example.
# The (2, 4) shape fits best on 2 hosts with 4 GPUs each, so that FSDP shards
# within each host and DDP replicates across hosts. On one machine with 8 GPUs,
# the DDP dimension would also stay within that host.
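To see what the hybrid strategy trades off, here is back-of-the-envelope arithmetic for `ToyModel` on 8 GPUs. It counts only parameters stored per rank and ignores the padding FSDP adds when flattening parameters, as well as optimizer state and activations; `params_per_rank` is a hypothetical helper.

```python
from math import ceil

# ToyModel: Linear(10, 10) + Linear(10, 5), weights plus biases.
total_params = (10 * 10 + 10) + (10 * 5 + 5)  # 165


def params_per_rank(total, shard_size):
    """Parameters each rank stores when a model is sharded over `shard_size` ranks."""
    return ceil(total / shard_size)


print("DDP (full replica on all 8):", params_per_rank(total_params, 1))  # 165
print("FSDP FULL_SHARD over 8:     ", params_per_rank(total_params, 8))  # 21
print("HSDP on a (2, 4) mesh:      ", params_per_rank(total_params, 4))  # 42
```

HSDP stores twice as much per rank as full sharding across all 8 GPUs, but its parameter all-gathers and gradient reduce-scatters stay inside a 4-GPU shard group, and only the gradient all-reduce crosses the replicate dimension.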