In [7]:
import os
import torch.multiprocessing as mp
import torch.distributed as dist
import torch

def distributed_init(backend: str = "nccl"):
    """Join the default process group and pin this process to its node-local GPU.

    Meant to be called from a script launched by ``torchrun``. ``init_process_group`` is
    called without an ``init_method``, so it uses ``env://`` and reads ``MASTER_ADDR``,
    ``MASTER_PORT``, ``RANK`` and ``WORLD_SIZE`` from the environment.

    Args:
        backend: communication backend passed to ``dist.init_process_group``
            (default ``"nccl"``, the previously hard-coded value).

    Returns:
        ``(local_rank, world_size)``: the GPU index this process was bound to, and the
        total number of ranks in the group.
    """
    dist.init_process_group(backend=backend)
    # The CUDA device index must be the rank *within this node*. dist.get_rank() is the
    # global rank, which equals it only on a single node (on the second node of a 2x8-GPU
    # job it would be 8..15). torchrun exports LOCAL_RANK; fall back to the global rank
    # when it is absent so single-node behaviour is unchanged.
    local_rank = int(os.environ.get("LOCAL_RANK", dist.get_rank()))
    world_size = dist.get_world_size()
    torch.cuda.set_device(local_rank)

    return local_rank, world_size

In [8]:
def find_free_port(host: str = ''):
    """Return a currently unused TCP port number, as a decimal string.

    Binds a throwaway socket to port 0 so the OS assigns a free ephemeral port, reads
    that port back, then closes the socket. Technique from
    https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number

    Args:
        host: interface to bind while probing. The default ``''`` (all interfaces)
            matches the previous behaviour; pass e.g. ``'127.0.0.1'`` to probe only the
            address the port will actually be used on.

    Returns:
        The port number as a string, ready to assign to ``os.environ['MASTER_PORT']``.

    Note:
        The socket is closed before returning, so another process can take the port
        before the caller binds it. This race is inherent to the technique.
    """
    import socket

    # Socket objects are context managers, so no contextlib.closing wrapper is needed.
    # SO_REUSEADDR is deliberately not set: it only influences bind() calls made after it
    # is set, and this socket is closed without ever listening or connecting.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((host, 0))
        return str(s.getsockname()[1])

In [12]:
# Rendezvous settings consumed by the process-group cell further down.
master_addr = '127.0.0.1'       # every rank reaches rank 0 over loopback (single host)
world_size = 2                  # number of ranks expected to join the group
master_port = find_free_port()  # OS-chosen free port, returned as a string

In [13]:
# Display the port picked above. A fresh OS-assigned port is chosen on every run, so any
# saved output of this cell is not reproducible.
master_port

'57727'

In [None]:
from datetime import timedelta

backend = 'nccl'


def _smoke_test_rank(rank, world_size, backend, master_addr, master_port):
    """Child-process body: join the default process group as ``rank``, then leave it.

    Started once per rank by the launcher call below (``rank`` is the child index that
    ``mp.start_processes`` passes as the first argument). With ``env://`` init,
    ``init_process_group`` waits until all ``world_size`` ranks have reached
    MASTER_ADDR:MASTER_PORT, or until the 30 s timeout expires.
    """
    print(f'setting up {rank=} {world_size=} {backend=}')

    # set up the master's ip address so this child process can coordinate
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = master_port
    print(f"{master_addr=} {master_port=}")

    # Initializes the default distributed process group, and this will also initialize the distributed package.
    dist.init_process_group(backend, rank=rank, world_size=world_size, timeout=timedelta(seconds=30), init_method="env://")
    print(f"{rank=} init complete")
    dist.destroy_process_group()
    print(f"{rank=} destroy complete")


# Every one of the `world_size` ranks must call init_process_group. Running only rank 0
# here in the kernel blocks until the 30 s timeout, because the other ranks never join.
# Start one child process per rank instead.
# - "fork" rather than the default "spawn": spawned children re-import __main__ and
#   cannot find functions defined interactively in a notebook.
# - fork only works if CUDA has not yet been initialised in this kernel.
# NOTE(review): the nccl backend needs a visible GPU per rank; use 'gloo' for a CPU-only check.
mp.start_processes(
    _smoke_test_rank,
    args=(world_size, backend, master_addr, master_port),
    nprocs=world_size,
    join=True,
    start_method="fork",
)

In [None]:
# Run the tensor-parallel baseline outside the kernel. torchrun starts one process per
# GPU (--nproc_per_node=2, matching the two devices exposed below) and sets
# RANK / LOCAL_RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT for each of them.
# The leading `!` hands the line to the shell; without it IPython parses it as Python
# and raises SyntaxError.
# NOTE(review): OMP_NUM_THREADS=48 applies to *each* of the 2 processes; confirm the
# host has enough cores. The script path is relative to the kernel's working directory.
!CUDA_VISIBLE_DEVICES=6,7 OMP_NUM_THREADS=48 torchrun --nproc_per_node=2 test/TP_baseline.py