# 集合通信实践

介绍：在分布式训练或推理中，集合通信（Collective Communication）是一项必备操作。理解常见集合通信的操作原理，是实践或优化分布式任务的基础。本练习将通过PyTorch的集合通信库，带大家了解常见操作的基本用法与原理。


相关文章：[分布式训练/推理基础：集合通信原理与实践](https://zhuanlan.zhihu.com/p/2006011081177457311)

Author: kaiyuan

Email: kaiyuanxie@yeah.net

# 1 用例说明

我们采用PyTorch的通信库（torch.distributed）API来进行实践，API的相关介绍参考：[Distributed communication](https://docs.pytorch.org/docs/stable/distributed.html)

建议用PyTorch的镜像运行用例，常用的镜像： nvcr.io/nvidia/pytorch:xxxx

```
docker pull nvcr.io/nvidia/pytorch:26.01-py3
```

启动示例：

```
docker run -itd --rm --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \
-v /data/nfs_87:/data/nfs_87 \
--name pytorch-dev nvcr.io/nvidia/pytorch:26.01-py3 bash
```

进入容器：

```
docker exec -it pytorch-dev bash
```

测试机器信息：
- NVIDIA A100-SXM4-80GB x 8
- NVIDIA-SMI 570.172.08
- Driver Version: 570.172.08
- CUDA Version: 13.1

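在示例中，通过设置 world_size 参数来控制使用的GPU数量，此处设定world_size=4（需要至少4张GPU）。如果进程退出前没有调用 `dist.destroy_process_group()`，运行时会出现如下告警信息，提示进程组资源可能没有被正确释放；该告警不影响演示效果，在示例函数结束前调用 `dist.destroy_process_group()` 即可消除。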

```
[rank0]:[W214 09:22:26.727076495 ProcessGroupNCCL.cpp:1565] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
```

# 2 聚合(Gather)


## 2.1 Row Gather

In [None]:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '62115'

def example(rank, world_size):
    """Row-wise gather: every rank sends a [1, 5] tensor filled with its rank to rank 0.

    Rank 0 concatenates the received pieces along dim 0 into a
    [world_size, 5] tensor and prints it; the other ranks only report that
    they have sent their data.

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: total number of processes in the group.
    """
    # Bind this process to its GPU before the NCCL process group is created.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        # Local tensor: shape [1, 5], every value equals rank.
        tensor = torch.ones(1, 5, device=rank) * rank

        # Only the destination rank (0) allocates the receive list.
        if rank == 0:
            gather_list = [torch.empty_like(tensor) for _ in range(world_size)]
        else:
            gather_list = None  # non-destination ranks need no receive list

        # Every rank sends `tensor` to rank 0.
        dist.gather(tensor, gather_list=gather_list, dst=0)

        if rank == 0:
            # Concatenate along dim 0 -> [world_size, 5].
            gathered_tensor = torch.cat(gather_list, dim=0)
            print(f"Rank {rank} gathered tensor shape: {gathered_tensor.shape}")
            print("Gathered tensor:\n", gathered_tensor)
        else:
            print(f"Rank {rank} has sent its data, no local copy of gathered tensor.")
    finally:
        # Release NCCL resources explicitly; without this PyTorch warns that
        # destroy_process_group() was not called before program exit.
        dist.destroy_process_group()

def main():
    """Spawn one process per rank and run `example` in each of them."""
    world_size = 4
    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()


Rank 3 has sent its data, no local copy of gathered tensor.
Rank 1 has sent its data, no local copy of gathered tensor.
Rank 2 has sent its data, no local copy of gathered tensor.
Rank 0 gathered tensor shape: torch.Size([4, 5])
Gathered tensor:
 tensor([[0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3.]], device='cuda:0')



## 2.2 Column Gather

In [None]:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '62115'

def example(rank, world_size):
    """Column-wise gather: each rank sends a [5, 2] block, rank 0 joins them along dim 1.

    Column j of each rank's block holds ``rank * 10 + j`` so the origin of
    every column is visible in the result. Rank 0 ends up with a
    [5, world_size * 2] tensor.

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: total number of processes in the group.
    """
    # Bind this process to its GPU before the NCCL process group is created.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        # Local block of shape [5, 2]: column 0 is rank*10, column 1 is rank*10+1.
        cols_per_rank = 2
        local_tensor = torch.zeros(5, cols_per_rank, device=rank, dtype=torch.float)
        for col in range(cols_per_rank):
            local_tensor[:, col] = rank * 10 + col

        print(f"Rank {rank} local tensor (shape {local_tensor.shape}):\n{local_tensor.cpu().numpy()}")

        # Only the destination rank (0) allocates world_size receive buffers
        # shaped like the local block, on its own device.
        if rank == 0:
            gather_list = [torch.empty_like(local_tensor, device=0) for _ in range(world_size)]
        else:
            gather_list = None  # non-destination ranks need no receive list

        # Every rank sends its local block to rank 0.
        dist.gather(local_tensor, gather_list=gather_list, dst=0)

        if rank == 0:
            # Concatenate along the column dim -> [5, world_size * 2].
            gathered = torch.cat(gather_list, dim=1)
            print(f"\nRank {rank} final gathered tensor (shape {gathered.shape}):\n{gathered.cpu().numpy()}")
        else:
            print(f"Rank {rank} has sent its data, no local copy of gathered tensor.\n")
    finally:
        # Release NCCL resources explicitly to avoid the shutdown warning.
        dist.destroy_process_group()

def main():
    """Spawn one process per rank and run `example` in each of them."""
    world_size = 4
    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()


Rank 1 local tensor (shape torch.Size([5, 2])):
[[10. 11.]
 [10. 11.]
 [10. 11.]
 [10. 11.]
 [10. 11.]]
Rank 3 local tensor (shape torch.Size([5, 2])):
[[30. 31.]
 [30. 31.]
 [30. 31.]
 [30. 31.]
 [30. 31.]]
Rank 2 local tensor (shape torch.Size([5, 2])):
[[20. 21.]
 [20. 21.]
 [20. 21.]
 [20. 21.]
 [20. 21.]]
Rank 0 local tensor (shape torch.Size([5, 2])):
[[0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]]
Rank 1 has sent its data, no local copy of gathered tensor.

Rank 2 has sent its data, no local copy of gathered tensor.

Rank 3 has sent its data, no local copy of gathered tensor.


Rank 0 final gathered tensor (shape torch.Size([5, 8])):
[[ 0.  1. 10. 11. 20. 21. 30. 31.]
 [ 0.  1. 10. 11. 20. 21. 30. 31.]
 [ 0.  1. 10. 11. 20. 21. 30. 31.]
 [ 0.  1. 10. 11. 20. 21. 30. 31.]
 [ 0.  1. 10. 11. 20. 21. 30. 31.]]



# 3 全聚合（All Gather）

## 3.1 Row allgather

In [None]:
import os
import time
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '62115'

def example(rank, world_size):
    """Row-wise all_gather: every rank ends up with all [1, 5] shards as a [world_size, 5] tensor.

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: total number of processes in the group.
    """
    # Bind this process to its GPU before the NCCL process group is created.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        # Local shard of shape [1, 5] on this rank's GPU, filled with rank.
        tensor_shard = torch.ones(1, 5, device=rank) * rank

        # Receive list: one [1, 5] buffer per rank.
        tensor_gather_list = [torch.empty_like(tensor_shard) for _ in range(world_size)]

        dist.all_gather(tensor_gather_list, tensor_shard)

        # Option 1: torch.cat along dim 0 -> [world_size, 5].
        gathered_tensor = torch.cat(tensor_gather_list, dim=0)
        print(f"rank {rank} after cat: {gathered_tensor.shape}", flush=True)

        # Option 2: torch.stack gives [world_size, 1, 5]; squeeze(1) restores [world_size, 5]:
        #   gathered_tensor = torch.stack(tensor_gather_list, dim=0).squeeze(1)

        # Crude pause so the shape lines of all ranks tend to appear before the tensors.
        time.sleep(1)
        # After all_gather EVERY rank (not only rank 0) holds the full result.
        # Header and tensor go out in one call so a rank's output is less
        # likely to be interleaved with other ranks' output.
        print(f"Rank {rank} gathered tensor:\n{gathered_tensor}", flush=True)
    finally:
        # Release NCCL resources explicitly to avoid the shutdown warning.
        dist.destroy_process_group()

def main():
    """Spawn one process per rank (4 assumed) and run `example` in each."""
    world_size = 4
    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()


rank 1 after cat: torch.Size([4, 5])
rank 0 after cat: torch.Size([4, 5])
rank 3 after cat: torch.Size([4, 5])
rank 2 after cat: torch.Size([4, 5])
Rank 1 gathered tensor:
Rank 0 gathered tensor:
Rank 3 gathered tensor:
Rank 2 gathered tensor:
 tensor([[0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3.]], device='cuda:1')
 tensor([[0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3.]], device='cuda:0')
 tensor([[0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3.]], device='cuda:2')
 tensor([[0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3.]], device='cuda:3')
")



## 3.2 Column allgather

In [None]:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '62115'

def example(rank, world_size):
    """Column-wise all_gather: every rank ends up with all [5, 2] blocks side by side.

    Column j of each rank's block holds ``rank * 10 + j``. After
    ``all_gather`` every rank concatenates the blocks along dim 1 into a
    [5, world_size * 2] tensor.

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: total number of processes in the group.
    """
    # Bind this process to its GPU before the NCCL process group is created.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        # Local [5, 2] block; values rank*10 + column index make origins visible.
        cols_per_rank = 2
        tensor = torch.zeros(5, cols_per_rank, device=rank, dtype=torch.float)
        for j in range(cols_per_rank):
            tensor[:, j] = rank * 10 + j
        print(f"Rank {rank} local tensor shape {tensor.shape}:\n{tensor.cpu().numpy()}")

        # Receive buffers: one per rank, same shape as the local block.
        gather_list = [torch.empty_like(tensor) for _ in range(world_size)]

        dist.all_gather(gather_list, tensor)

        # Concatenate along the column dim -> [5, world_size * cols_per_rank].
        gathered_tensor = torch.cat(gather_list, dim=1)
        print(f"Rank {rank} after column-wise all_gather, shape {gathered_tensor.shape}:\n{gathered_tensor.cpu().numpy()}")
    finally:
        # Release NCCL resources explicitly to avoid the shutdown warning.
        dist.destroy_process_group()

def main():
    """Spawn one process per rank and run `example` in each of them."""
    world_size = 4
    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()


Rank 1 local tensor shape torch.Size([5, 2]):
[[10. 11.]
 [10. 11.]
 [10. 11.]
 [10. 11.]
 [10. 11.]]
Rank 3 local tensor shape torch.Size([5, 2]):
[[30. 31.]
 [30. 31.]
 [30. 31.]
 [30. 31.]
 [30. 31.]]
Rank 2 local tensor shape torch.Size([5, 2]):
[[20. 21.]
 [20. 21.]
 [20. 21.]
 [20. 21.]
 [20. 21.]]
Rank 0 local tensor shape torch.Size([5, 2]):
[[0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]]
Rank 3 after column-wise all_gather, shape torch.Size([5, 8]):
[[ 0.  1. 10. 11. 20. 21. 30. 31.]
 [ 0.  1. 10. 11. 20. 21. 30. 31.]
 [ 0.  1. 10. 11. 20. 21. 30. 31.]
 [ 0.  1. 10. 11. 20. 21. 30. 31.]
 [ 0.  1. 10. 11. 20. 21. 30. 31.]]
Rank 0 after column-wise all_gather, shape torch.Size([5, 8]):
[[ 0.  1. 10. 11. 20. 21. 30. 31.]
 [ 0.  1. 10. 11. 20. 21. 30. 31.]
 [ 0.  1. 10. 11. 20. 21. 30. 31.]
 [ 0.  1. 10. 11. 20. 21. 30. 31.]
 [ 0.  1. 10. 11. 20. 21. 30. 31.]]
Rank 1 after column-wise all_gather, shape torch.Size([5, 8]):
[[ 0.  1. 10. 11. 20. 21. 30. 31.]
 [ 0.  1. 10. 11. 20. 21.

# 4 规约（Reduce）

In [None]:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '62115'

def example(rank, world_size):
    """Reduce (sum) a [1, 5] tensor from every rank onto rank 0.

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: total number of processes in the group.
    """
    # Bind this process to its GPU before the NCCL process group is created.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        # Local tensor: shape [1, 5], every value equals rank.
        tensor = torch.ones(1, 5, device=rank) * rank

        print(f"Rank {rank} before reduce: {tensor.cpu().tolist()}")

        # In-place sum across ranks; only rank 0 (dst) is guaranteed to hold
        # the result. Other ranks should not rely on their buffer afterwards.
        dist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM)

        if rank == 0:
            print(f"Rank {rank} after reduce (sum): {tensor.cpu().tolist()}")
        else:
            # Content is not guaranteed by the API; printed only for demonstration.
            print(f"Rank {rank} after reduce (tensor content undefined): {tensor.cpu().tolist()}")
    finally:
        # Release NCCL resources explicitly to avoid the shutdown warning.
        dist.destroy_process_group()

def main():
    """Spawn one process per rank and run `example` in each of them."""
    world_size = 4
    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()


Rank 2 before reduce: [[2.0, 2.0, 2.0, 2.0, 2.0]]
Rank 3 before reduce: [[3.0, 3.0, 3.0, 3.0, 3.0]]
Rank 1 before reduce: [[1.0, 1.0, 1.0, 1.0, 1.0]]
Rank 0 before reduce: [[0.0, 0.0, 0.0, 0.0, 0.0]]
Rank 1 after reduce (tensor content undefined): [[1.0, 1.0, 1.0, 1.0, 1.0]]
Rank 2 after reduce (tensor content undefined): [[2.0, 2.0, 2.0, 2.0, 2.0]]
Rank 0 after reduce (sum): [[6.0, 6.0, 6.0, 6.0, 6.0]]
Rank 3 after reduce (tensor content undefined): [[3.0, 3.0, 3.0, 3.0, 3.0]]



# 5 全规约（all reduce）

In [None]:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '62115'

def example(rank, world_size):
    """All-reduce (sum) a [1, 5] tensor so that every rank holds the total.

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: total number of processes in the group.
    """
    # Bind this process to its GPU before the NCCL process group is created.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        # Local tensor: shape [1, 5], every value equals rank.
        tensor = torch.ones(1, 5, device=rank) * rank

        print(f"Rank {rank} before all_reduce: {tensor.cpu().tolist()}")

        # In-place sum; the result is written back on every rank.
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

        print(f"Rank {rank} after all_reduce (sum): {tensor.cpu().tolist()}")
    finally:
        # Release NCCL resources explicitly to avoid the shutdown warning.
        dist.destroy_process_group()

def main():
    """Spawn one process per rank and run `example` in each of them."""
    world_size = 4
    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()


Rank 2 before all_reduce: [[2.0, 2.0, 2.0, 2.0, 2.0]]
Rank 3 before all_reduce: [[3.0, 3.0, 3.0, 3.0, 3.0]]
Rank 1 before all_reduce: [[1.0, 1.0, 1.0, 1.0, 1.0]]
Rank 0 before all_reduce: [[0.0, 0.0, 0.0, 0.0, 0.0]]
Rank 1 after all_reduce (sum): [[6.0, 6.0, 6.0, 6.0, 6.0]]
Rank 2 after all_reduce (sum): [[6.0, 6.0, 6.0, 6.0, 6.0]]
Rank 3 after all_reduce (sum): [[6.0, 6.0, 6.0, 6.0, 6.0]]
Rank 0 after all_reduce (sum): [[6.0, 6.0, 6.0, 6.0, 6.0]]



In [None]:
# 增加维度观测输出结果：
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '62115'

def example(rank, world_size):
    """All-reduce (sum) a [2, 5] tensor; shows the reduction is element-wise over any shape.

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: total number of processes in the group.
    """
    # Bind this process to its GPU before the NCCL process group is created.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        # Local tensor: shape [2, 5], every element equals rank.
        tensor = torch.full((2, 5), rank, dtype=torch.float, device=rank)

        print(f"Rank {rank} before all_reduce:\n{tensor.cpu().numpy()}")

        # In-place element-wise sum, result available on every rank.
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

        print(f"Rank {rank} after all_reduce (sum):\n{tensor.cpu().numpy()}")
    finally:
        # Release NCCL resources explicitly to avoid the shutdown warning.
        dist.destroy_process_group()

def main():
    """Spawn one process per rank and run `example` in each of them."""
    world_size = 4
    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()


Rank 2 before all_reduce:
[[2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2.]]
Rank 1 before all_reduce:
[[1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]]
Rank 3 before all_reduce:
[[3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3.]]
Rank 0 before all_reduce:
[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
Rank 0 after all_reduce (sum):
[[6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6.]]
Rank 2 after all_reduce (sum):
[[6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6.]]
Rank 1 after all_reduce (sum):
[[6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6.]]
Rank 3 after all_reduce (sum):
[[6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6.]]



# 6 分发（scatter）

In [None]:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '62115'

def example(rank, world_size):
    """Scatter the rows of a [world_size, 5] tensor from rank 0: rank i receives row i.

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: total number of processes in the group.
    """
    # Bind this process to its GPU before the NCCL process group is created.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        # Receive buffer: each rank gets one row of shape [5].
        recv_tensor = torch.empty(5, device=rank, dtype=torch.float)

        if rank == 0:
            # Source: a [world_size, 5] tensor whose row i is filled with i.
            data_to_scatter = torch.arange(world_size, device=0).float().unsqueeze(1).repeat(1, 5)
            print(f"Rank {rank}: Original data to scatter:\n{data_to_scatter.cpu().numpy()}")

            # One [5] tensor per destination rank (row views of the source).
            scatter_list = [data_to_scatter[i] for i in range(world_size)]
            print(f"Rank {rank}: Scatter list shapes: {[t.shape for t in scatter_list]}")
        else:
            scatter_list = None  # only the source rank supplies the data

        dist.scatter(recv_tensor, scatter_list=scatter_list, src=0)

        print(f"Rank {rank} received tensor of shape {recv_tensor.shape}:\n{recv_tensor.cpu().numpy()}")
    finally:
        # Release NCCL resources explicitly to avoid the shutdown warning.
        dist.destroy_process_group()

def main():
    """Spawn one process per rank and run `example` in each of them."""
    world_size = 4
    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()


Rank 0: Original data to scatter:
[[0. 0. 0. 0. 0.]
 [1. 1. 1. 1. 1.]
 [2. 2. 2. 2. 2.]
 [3. 3. 3. 3. 3.]]
Rank 0: Scatter list shapes: [torch.Size([5]), torch.Size([5]), torch.Size([5]), torch.Size([5])]
Rank 0 received tensor of shape torch.Size([5]):
[0. 0. 0. 0. 0.]
Rank 1 received tensor of shape torch.Size([5]):
[1. 1. 1. 1. 1.]
Rank 3 received tensor of shape torch.Size([5]):
[3. 3. 3. 3. 3.]
Rank 2 received tensor of shape torch.Size([5]):
[2. 2. 2. 2. 2.]



# 7 规约分发（Reduce Scatter）

以行切分为例：每个rank的输入为[world_size, 5]，沿第0维切分成world_size块；所有rank的第i块求和后交给rank i，因此每个rank得到[1, 5]的结果。7.3节演示沿第1维（列）切分的情况。


## 7.1  Row reduce scatter(使用dist.reduce_scatter接口)

In [None]:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '62115'

def example(rank, world_size):
    """Row-wise reduce_scatter with the list API: rank i gets the sum of row i over all ranks.

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: total number of processes in the group.
    """
    # Bind this process to its GPU before the NCCL process group is created.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        # Input of shape [world_size, 5]; row i holds rank*10 + i.
        data = torch.zeros(world_size, 5, device=rank, dtype=torch.float)
        for i in range(world_size):
            data[i] = rank * 10 + i

        print(f"Rank {rank} original data (shape {data.shape}):\n{data.cpu().numpy()}")

        # Split along dim 0 into world_size pieces, each kept 2-D as [1, 5].
        input_list = [data[i].unsqueeze(0) for i in range(world_size)]

        # Output: the reduced [1, 5] piece that belongs to this rank.
        output = torch.empty(1, 5, device=rank, dtype=torch.float)

        # Piece `rank` from every rank's input_list is summed into `output`.
        dist.reduce_scatter(output, input_list, op=dist.ReduceOp.SUM)

        print(f"Rank {rank} after reduce_scatter (sum), output shape {output.shape}:\n{output.cpu().numpy()}")
    finally:
        # Release NCCL resources explicitly to avoid the shutdown warning.
        dist.destroy_process_group()

def main():
    """Spawn one process per rank and run `example` in each of them."""
    world_size = 4
    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()


Rank 2 original data (shape torch.Size([4, 5])):
[[20. 20. 20. 20. 20.]
 [21. 21. 21. 21. 21.]
 [22. 22. 22. 22. 22.]
 [23. 23. 23. 23. 23.]]
Rank 3 original data (shape torch.Size([4, 5])):
[[30. 30. 30. 30. 30.]
 [31. 31. 31. 31. 31.]
 [32. 32. 32. 32. 32.]
 [33. 33. 33. 33. 33.]]
Rank 1 original data (shape torch.Size([4, 5])):
[[10. 10. 10. 10. 10.]
 [11. 11. 11. 11. 11.]
 [12. 12. 12. 12. 12.]
 [13. 13. 13. 13. 13.]]
Rank 0 original data (shape torch.Size([4, 5])):
[[0. 0. 0. 0. 0.]
 [1. 1. 1. 1. 1.]
 [2. 2. 2. 2. 2.]
 [3. 3. 3. 3. 3.]]
Rank 0 after reduce_scatter (sum), output shape torch.Size([1, 5]):
[[60. 60. 60. 60. 60.]]
Rank 1 after reduce_scatter (sum), output shape torch.Size([1, 5]):
[[64. 64. 64. 64. 64.]]
Rank 2 after reduce_scatter (sum), output shape torch.Size([1, 5]):
[[68. 68. 68. 68. 68.]]
Rank 3 after reduce_scatter (sum), output shape torch.Size([1, 5]):
[[72. 72. 72. 72. 72.]]



## 7.2  Row reduce scatter(使用reduce_scatter_tensor接口)

In [None]:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '62115'

def example(rank, world_size):
    """Row-wise reduce_scatter with the single-tensor API (reduce_scatter_tensor).

    Equivalent to the list-based version, but the input is one
    [world_size, 5] tensor that the API itself splits along dim 0.

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: total number of processes in the group.
    """
    # Bind this process to its GPU before the NCCL process group is created.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        # --- Build the input ---
        # One [world_size, 5] tensor per rank: the concatenation of what the
        # list API would receive as input_list. Row i holds rank*10 + i.
        input_tensor = torch.zeros(world_size, 5, device=rank, dtype=torch.float)
        for i in range(world_size):
            input_tensor[i] = rank * 10 + i
        print(f"Rank {rank} input_tensor (shape {input_tensor.shape}):\n{input_tensor.cpu().numpy()}")

        # --- Prepare the output ---
        # Splitting [world_size, 5] along dim 0 gives each rank a [1, 5] block.
        output = torch.empty(1, 5, device=rank, dtype=torch.float)

        # --- Run reduce_scatter_tensor ---
        # The input is split into world_size chunks along dim 0; chunk `rank`
        # from every rank is summed into `output`.
        dist.reduce_scatter_tensor(output, input_tensor, op=dist.ReduceOp.SUM)

        print(f"Rank {rank} after reduce_scatter_tensor (sum), output shape {output.shape}:\n{output.cpu().numpy()}")
    finally:
        # Release NCCL resources explicitly to avoid the shutdown warning.
        dist.destroy_process_group()

def main():
    """Spawn one process per rank and run `example` in each of them."""
    world_size = 4
    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()


Rank 1 input_tensor (shape torch.Size([4, 5])):
[[10. 10. 10. 10. 10.]
 [11. 11. 11. 11. 11.]
 [12. 12. 12. 12. 12.]
 [13. 13. 13. 13. 13.]]
Rank 2 input_tensor (shape torch.Size([4, 5])):
[[20. 20. 20. 20. 20.]
 [21. 21. 21. 21. 21.]
 [22. 22. 22. 22. 22.]
 [23. 23. 23. 23. 23.]]
Rank 0 input_tensor (shape torch.Size([4, 5])):
[[0. 0. 0. 0. 0.]
 [1. 1. 1. 1. 1.]
 [2. 2. 2. 2. 2.]
 [3. 3. 3. 3. 3.]]
Rank 3 input_tensor (shape torch.Size([4, 5])):
[[30. 30. 30. 30. 30.]
 [31. 31. 31. 31. 31.]
 [32. 32. 32. 32. 32.]
 [33. 33. 33. 33. 33.]]
Rank 3 after reduce_scatter_tensor (sum), output shape torch.Size([1, 5]):
[[72. 72. 72. 72. 72.]]
Rank 0 after reduce_scatter_tensor (sum), output shape torch.Size([1, 5]):
[[60. 60. 60. 60. 60.]]
Rank 1 after reduce_scatter_tensor (sum), output shape torch.Size([1, 5]):
[[64. 64. 64. 64. 64.]]
Rank 2 after reduce_scatter_tensor (sum), output shape torch.Size([1, 5]):
[[68. 68. 68. 68. 68.]]




## 7.3 Column reduce scatter

In [None]:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '62115'

def example(rank, world_size):
    """Column-wise reduce_scatter: rank j gets the sum of column j over all ranks as [5, 1].

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: total number of processes in the group.
    """
    # Bind this process to its GPU before the NCCL process group is created.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        # Input of shape [5, world_size]; column j holds rank*10 + j.
        data = torch.zeros(5, world_size, device=rank, dtype=torch.float)
        for j in range(world_size):
            data[:, j] = rank * 10 + j
        print(f"Rank {rank} original data (shape {data.shape}):\n{data.cpu().numpy()}")

        # Split along dim 1 into world_size [5, 1] column pieces. Column slices
        # are strided (non-contiguous) views of `data`, so make each piece
        # contiguous before handing it to the collective.
        input_list = [piece.contiguous() for piece in torch.split(data, 1, dim=1)]

        # Output: the reduced [5, 1] column that belongs to this rank.
        output = torch.empty(5, 1, device=rank, dtype=torch.float)

        # Column `rank` from every rank's input_list is summed into `output`.
        dist.reduce_scatter(output, input_list, op=dist.ReduceOp.SUM)

        print(f"Rank {rank} after reduce_scatter (sum), output shape {output.shape}:\n{output.cpu().numpy()}")
    finally:
        # Release NCCL resources explicitly to avoid the shutdown warning.
        dist.destroy_process_group()

def main():
    """Spawn one process per rank and run `example` in each of them."""
    world_size = 4
    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()


Rank 3 original data (shape torch.Size([5, 4])):
[[30. 31. 32. 33.]
 [30. 31. 32. 33.]
 [30. 31. 32. 33.]
 [30. 31. 32. 33.]
 [30. 31. 32. 33.]]
Rank 1 original data (shape torch.Size([5, 4])):
[[10. 11. 12. 13.]
 [10. 11. 12. 13.]
 [10. 11. 12. 13.]
 [10. 11. 12. 13.]
 [10. 11. 12. 13.]]
Rank 2 original data (shape torch.Size([5, 4])):
[[20. 21. 22. 23.]
 [20. 21. 22. 23.]
 [20. 21. 22. 23.]
 [20. 21. 22. 23.]
 [20. 21. 22. 23.]]
Rank 0 original data (shape torch.Size([5, 4])):
[[0. 1. 2. 3.]
 [0. 1. 2. 3.]
 [0. 1. 2. 3.]
 [0. 1. 2. 3.]
 [0. 1. 2. 3.]]
Rank 2 after reduce_scatter (sum), output shape torch.Size([5, 1]):
[[68.]
 [68.]
 [68.]
 [68.]
 [68.]]
Rank 1 after reduce_scatter (sum), output shape torch.Size([5, 1]):
[[64.]
 [64.]
 [64.]
 [64.]
 [64.]]Rank 0 after reduce_scatter (sum), output shape torch.Size([5, 1]):
[[60.]
 [60.]
 [60.]
 [60.]
 [60.]]

Rank 3 after reduce_scatter (sum), output shape torch.Size([5, 1]):
[[72.]
 [72.]
 [72.]
 [72.]
 [72.]]



# 8 多对多（all to all）

## 8.1 Row alltoall

In [None]:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '62115'

def example(rank, world_size):
    """Row-wise all-to-all: row d of rank r's send buffer ends up as row r of rank d's receive buffer.

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: total number of processes in the group.
    """
    # Bind this process to its GPU before the NCCL process group is created.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        # Number of elements in the block sent to each single peer.
        element_size = 2

        # --- Build the send buffer: [world_size, element_size] ---
        send_tensor = torch.zeros(world_size, element_size, device=rank, dtype=torch.float)
        for dst in range(world_size):
            # Block destined for rank `dst` is filled with rank*10 + dst.
            send_tensor[dst] = rank * 10 + dst
        print(f"Rank {rank} send_tensor (before alltoall):\n{send_tensor.cpu().numpy()}")

        # --- Receive buffer with the same [world_size, element_size] shape ---
        recv_tensor = torch.empty(world_size, element_size, device=rank, dtype=torch.float)

        # --- All-to-all ---
        # all_to_all_single takes single tensors. With output_split_sizes /
        # input_split_sizes omitted, dim 0 is split evenly across ranks.
        dist.all_to_all_single(recv_tensor, send_tensor)

        print(f"\nRank {rank} recv_tensor (after alltoall):\n{recv_tensor.cpu().numpy()}")
    finally:
        # Release NCCL resources explicitly to avoid the shutdown warning.
        dist.destroy_process_group()

def main():
    """Spawn one process per rank and run `example` in each of them."""
    world_size = 4
    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()


Rank 3 send_tensor (before alltoall):
[[30. 30.]
 [31. 31.]
 [32. 32.]
 [33. 33.]]
Rank 1 send_tensor (before alltoall):
[[10. 10.]
 [11. 11.]
 [12. 12.]
 [13. 13.]]
Rank 0 send_tensor (before alltoall):
[[0. 0.]
 [1. 1.]
 [2. 2.]
 [3. 3.]]
Rank 2 send_tensor (before alltoall):
[[20. 20.]
 [21. 21.]
 [22. 22.]
 [23. 23.]]

Rank 0 recv_tensor (after alltoall):
[[ 0.  0.]
 [10. 10.]
 [20. 20.]
 [30. 30.]]
Rank 1 recv_tensor (after alltoall):
[[ 1.  1.]
 [11. 11.]
 [21. 21.]
 [31. 31.]]


Rank 2 recv_tensor (after alltoall):
[[ 2.  2.]
 [12. 12.]
 [22. 22.]
 [32. 32.]]

Rank 3 recv_tensor (after alltoall):
[[ 3.  3.]
 [13. 13.]
 [23. 23.]
 [33. 33.]]



## 8.2 Column alltoall

In [None]:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '62115'

def example(rank, world_size):
    """Column-wise all-to-all with the list API: column d of rank r goes to rank d.

    Each rank's result, re-concatenated along dim 1, has column r holding
    what rank r sent to it.

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: total number of processes in the group.
    """
    # Bind this process to its GPU before the NCCL process group is created.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        # Local tensor [5, world_size]; column col holds rank*10 + col.
        local_tensor = torch.zeros(5, world_size, device=rank, dtype=torch.float)
        for col in range(world_size):
            local_tensor[:, col] = rank * 10 + col
        print(f"Rank {rank} local tensor (shape {local_tensor.shape}):\n{local_tensor.cpu().numpy()}")

        # Split into world_size [5, 1] columns. Column slices are strided views,
        # so each one is made contiguous.
        # Option 1: slicing + contiguous
        input_list = [local_tensor[:, col:col+1].contiguous() for col in range(world_size)]
        # Option 2: torch.chunk + contiguous
        #   input_list = [c.contiguous() for c in torch.chunk(local_tensor, world_size, dim=1)]

        # Output list: one freshly allocated (hence contiguous) [5, 1] buffer per peer.
        output_list = [torch.empty(5, 1, device=rank, dtype=torch.float) for _ in range(world_size)]

        # List-based all_to_all.
        dist.all_to_all(output_list, input_list)

        # Re-assemble along the column dim -> [5, world_size].
        result = torch.cat(output_list, dim=1)

        print(f"Rank {rank} after column-wise all_to_all (shape {result.shape}):\n{result.cpu().numpy()}")
    finally:
        # Release NCCL resources explicitly to avoid the shutdown warning.
        dist.destroy_process_group()

def main():
    """Spawn one process per rank and run `example` in each of them."""
    world_size = 4
    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()


Rank 3 local tensor (shape torch.Size([5, 4])):
[[30. 31. 32. 33.]
 [30. 31. 32. 33.]
 [30. 31. 32. 33.]
 [30. 31. 32. 33.]
 [30. 31. 32. 33.]]
Rank 2 local tensor (shape torch.Size([5, 4])):
[[20. 21. 22. 23.]
 [20. 21. 22. 23.]
 [20. 21. 22. 23.]
 [20. 21. 22. 23.]
 [20. 21. 22. 23.]]
Rank 0 local tensor (shape torch.Size([5, 4])):
[[0. 1. 2. 3.]
 [0. 1. 2. 3.]
 [0. 1. 2. 3.]
 [0. 1. 2. 3.]
 [0. 1. 2. 3.]]
Rank 1 local tensor (shape torch.Size([5, 4])):
[[10. 11. 12. 13.]
 [10. 11. 12. 13.]
 [10. 11. 12. 13.]
 [10. 11. 12. 13.]
 [10. 11. 12. 13.]]
Rank 0 after column-wise all_to_all (shape torch.Size([5, 4])):
[[ 0. 10. 20. 30.]
 [ 0. 10. 20. 30.]
 [ 0. 10. 20. 30.]
 [ 0. 10. 20. 30.]
 [ 0. 10. 20. 30.]]
Rank 3 after column-wise all_to_all (shape torch.Size([5, 4])):
[[ 3. 13. 23. 33.]
 [ 3. 13. 23. 33.]
 [ 3. 13. 23. 33.]
 [ 3. 13. 23. 33.]
 [ 3. 13. 23. 33.]]
Rank 1 after column-wise all_to_all (shape torch.Size([5, 4])):
[[ 1. 11. 21. 31.]
 [ 1. 11. 21. 31.]
 [ 1. 11. 21. 31.]
 [

# 9 广播（broadcast）

In [None]:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '62115'

def example(rank, world_size):
    """Broadcast a [5, 2] tensor of ones from rank 0, overwriting the zeros on every other rank.

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: total number of processes in the group.
    """
    # Bind this process to its GPU before the NCCL process group is created.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        # Every rank must pass a tensor of the same shape/dtype.
        shape = (5, 2)

        if rank == 0:
            # Source rank: tensor of ones.
            tensor = torch.ones(shape, device=rank, dtype=torch.float)
            print(f"Rank {rank} before broadcast:\n{tensor.cpu().numpy()}")
        else:
            # Receivers: zeros, so the overwrite is visible.
            tensor = torch.zeros(shape, device=rank, dtype=torch.float)
            print(f"Rank {rank} before broadcast (initial zeros):\n{tensor.cpu().numpy()}")

        # In place: every rank's `tensor` receives rank 0's data.
        dist.broadcast(tensor, src=0)

        print(f"Rank {rank} after broadcast:\n{tensor.cpu().numpy()}")
    finally:
        # Release NCCL resources explicitly to avoid the shutdown warning.
        dist.destroy_process_group()

def main():
    """Spawn one process per rank and run `example` in each of them."""
    world_size = 4
    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()


Rank 1 before broadcast (initial zeros):
[[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
Rank 2 before broadcast (initial zeros):
[[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
Rank 0 before broadcast:
[[1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]]
Rank 3 before broadcast (initial zeros):
[[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
Rank 0 after broadcast:
[[1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]]
Rank 1 after broadcast:
[[1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]]
Rank 3 after broadcast:
[[1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]]
Rank 2 after broadcast:
[[1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]]

