customized distributed store causes crash during p2p communication #115977

Open
xial-thu opened this issue Dec 16, 2023 · 4 comments
Labels
oncall: distributed Add this issue/PR to distributed oncall triage queue

Comments

xial-thu commented Dec 16, 2023

🐛 Describe the bug

I'm wrapping the torch distributed store with a Python class so I can see what is happening and do tracing, monitoring, and logging during process group initialization. Here's the code:

import os
from datetime import timedelta
from typing import List, Optional

from torch._C._distributed_c10d import Store, TCPStore


class TcpStoreWrapper(Store):
    '''a wrapper of TCPStore, to trace the internal behavior'''

    def __init__(self, rank: int, world_size: int, port: Optional[int] = None, hostname: Optional[str] = None, timeout: Optional[timedelta] = None):
        super().__init__()
        self.rank = rank
        self.is_master = rank == 0
        self.real_store = TCPStore(
            host_name=hostname if hostname is not None else os.getenv("MASTER_ADDR"),
            port=port if port is not None else int(os.getenv("MASTER_PORT")) + 100,
            timeout=timeout,
            world_size=world_size,
            is_master=self.is_master,
            wait_for_workers=True,
        )

    def set(self, key: str, value: str):
        return self.real_store.set(key, value)

    def get(self, key: str) -> bytes:
        return self.real_store.get(key)

    def add(self, key: str, value: int) -> int:
        return self.real_store.add(key, value)

    def compare_set(self, key: str, expected_value: str, desired_value: str) -> bytes:
        return self.real_store.compare_set(key, expected_value, desired_value)

    def delete_key(self, key: str) -> bool:
        return self.real_store.delete_key(key)

    def num_keys(self) -> int:
        return self.real_store.num_keys()

    def set_timeout(self, timeout: timedelta):
        return self.real_store.set_timeout(timeout)

    def wait(self, keys: List[str], timeout: Optional[timedelta] = None):
        # Handles both wait(keys) and wait(keys, timeout): the C++ side may
        # call either form, and both land on this single Python method.
        if timeout is None:
            return self.real_store.wait(keys)
        return self.real_store.wait(keys, timeout)

But things go wrong when I run it with Megatron-LM. The logic is roughly as follows:

  1. Init the DP, TP, and PP groups. For example, create a group with group = torch.distributed.new_group(ranks) and store it in a global variable _TENSOR_MODEL_PARALLEL_GROUP.
  2. In other files (I think this may be the key to the bug), fetch the group and use it, e.g. dist.barrier(group=mpu._TENSOR_MODEL_PARALLEL_GROUP); see the sketch below.

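A minimal sketch of this pattern (illustrative only; the module layout follows the Megatron-LM convention described above but is not taken from Megatron-LM):

# mpu.py (illustrative)
import torch.distributed as dist

_TENSOR_MODEL_PARALLEL_GROUP = None

def initialize_model_parallel(ranks):
    global _TENSOR_MODEL_PARALLEL_GROUP
    _TENSOR_MODEL_PARALLEL_GROUP = dist.new_group(ranks)

# elsewhere, e.g. training.py (illustrative)
import torch.distributed as dist
import mpu

def sync():
    # the group is fetched from another module's global variable
    dist.barrier(group=mpu._TENSOR_MODEL_PARALLEL_GROUP)
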
Then, when running, the program crashes. Here's the log; it can be divided into two cases.

Case 1: some ranks cannot get a value from the store. This is expected, because the root cause is not here.

File "/home/xialei/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3330, in barrier
work = group.barrier(opts=opts)
│ │ └ <torch.distributed.distributed_c10d.BarrierOptions object at 0x7f8cf0563870>
│ └ <instancemethod barrier at 0x7f8d0719e140>
└ <torch.distributed.distributed_c10d.ProcessGroup object at 0x7f8cf057c9b0>

RuntimeError: [1] is setting up NCCL communicator and retrieving ncclUniqueId from [0] via c10d key-value store by key '0', but store->get('0') got error: fn INTERNAL ASSERT FAILED at "../torch/csrc/distributed/c10d/init.cpp":151, please report a bug to PyTorch.

Case 2: the root cause. When rank 0 tries to set the key-value pair, it crashes:

File "/home/xialei/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3330, in barrier
work = group.barrier(opts=opts)
│ │ └ <torch.distributed.distributed_c10d.BarrierOptions object at 0x7f05b0c702f0>
│ └ <instancemethod barrier at 0x7f05c1bde140>
└ <torch.distributed.distributed_c10d.ProcessGroup object at 0x7f05b0513bf0>

RuntimeError: fn INTERNAL ASSERT FAILED at "../torch/csrc/distributed/c10d/init.cpp":137, please report a bug to PyTorch.

The file referenced in the log shows:

class PythonStore : public ::c10d::Store {
 public:
  using ::c10d::Store::Store;

  // Note: this function manually calls the Python-side overload
  // for this function instead of using the PYBIND11_OVERLOAD_XYZ
  // macros. This is done so that we can call the Python-side
  // function with a std::string instead of a std::vector<uint8_t>.
  void set(const std::string& key, const std::vector<uint8_t>& value) override {
    pybind11::gil_scoped_acquire gil;
    pybind11::function fn =
        pybind11::get_overload(static_cast<const ::c10d::Store*>(this), "set");
    TORCH_INTERNAL_ASSERT(fn, "Not implemented.");
    // Call function with a py::bytes object for the value.
    fn(key,
       py::bytes(reinterpret_cast<const char*>(value.data()), value.size()));
  }

But what really annoys me is that when I write a simple test program, it works perfectly! init_process_group, group all-reduce, global barrier... everything runs fine in the simple test program. Here's the code:

import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import random
from datetime import timedelta
import threading
from typing import Optional
import time
from loguru import logger
from torch._C._distributed_c10d import Store, TCPStore
from typing import List
# torchrun sets RANK and WORLD_SIZE in the environment
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])

store = TcpStoreWrapper(
    hostname="localhost",
    port=4567,
    world_size=world_size,
    rank=rank,
    timeout=timedelta(seconds=60),
)
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank, store=store)
assert dist.is_initialized()

local_world_size = int(os.environ['LOCAL_WORLD_SIZE'])
local_rank = int(os.environ['LOCAL_RANK'])
n = 1 * 1000_000_000 // 4
stride = 2
x = torch.zeros(n).cuda(local_rank)

nnodes = dist.get_world_size() // local_world_size
groups = [None] * local_world_size
for i in range(0, local_world_size, stride):
    groups[i] = dist.new_group(ranks=[i + j * local_world_size for j in range(nnodes)])

for i in range(10):
    dist.barrier()
    t = time.time()
    if local_rank % stride == 0:
        dist.all_reduce(x, group=groups[local_rank])
    torch.cuda.synchronize()
    dist.barrier()
    if dist.get_rank() == 0:
        print("{:.1f} MB/s".format(n * 4 / (time.time() - t) / 1e6))

Execute the script with torchrun --nproc_per_node=8 --nnodes 1 --node_rank 0 --master_addr=localhost --master_port=23791 main.py. As you can see, I create new groups, put them in a list, and later fetch them from that list, but it doesn't crash.

Comparing the two cases, my guess is that a group fetched from another file loses some properties (maybe related to Python itself), so torch can no longer find the store inside the group?

Please help me; this would really help open up the black box of training. Otherwise, once I get stuck on the TCP store, all I can do is restart, and that's painful.

Versions

Collecting environment information...
PyTorch version: 2.0.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version: Could not collect
CMake version: version 3.26.3
Libc version: glibc-2.35

Python version: 3.10.6 (main, Nov 14 2022, 16:10:14) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-5.10.25-nvidia-gpu-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   46 bits physical, 57 bits virtual
Byte Order:                      Little Endian
CPU(s):                          65
On-line CPU(s) list:             0-64
Vendor ID:                       GenuineIntel
Model name:                      Intel(R) Xeon(R) Gold 6330 CPU @ 2.00GHz
CPU family:                      6
Model:                           106
Thread(s) per core:              1
Core(s) per socket:              1
Socket(s):                       65
Stepping:                        6
BogoMIPS:                        4000.00
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd arat avx512vbmi umip pku avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid md_clear arch_capabilities
Hypervisor vendor:               KVM
Virtualization type:             full
L1d cache:                       2 MiB (65 instances)
L1i cache:                       2 MiB (65 instances)
L2 cache:                        260 MiB (65 instances)
L3 cache:                        1 GiB (65 instances)
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.3
[pip3] torch==2.0.1
[pip3] torch-tb-profiler==0.4.3
[pip3] torchaudio==2.0.2
[pip3] torchvision==0.15.2
[pip3] triton==2.0.0
[conda] Could not collect

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225

cpuhrsch added the oncall: distributed label Dec 17, 2023
@pkuwangh

@xial-thu , did you figure out how to overcome this?

@wconstab
Contributor

wconstab commented Mar 1, 2024

Just wondering, are you trying to wrap the TCPStore due to performance issues with initialization? If so, you may want to try enabling the new libuv backend for TCPStore.

export USE_LIBUV=1 to test this out

@XilunWu maybe you can take a look at the wrapping issues blocking the customized store

@pkuwangh

pkuwangh commented Mar 1, 2024

I actually figured out what's wrong on my end yesterday and I have a fix for my code.

I created my CustomStore object within a function and passed it into init_process_group. Later, that object gets destroyed/garbage-collected, so pybind11::get_overload(static_cast<const ::c10d::Store*>(this), "set") returns null.

The interface of init_process_group makes it feel like it takes ownership of that Store object; my guess is that the store is passed into the C++ world and Python's reference counting then loses track of it.
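In other words, the fix is to keep a Python-side reference to the store alive for the lifetime of the process group. A minimal sketch of the workaround (CustomStore stands for any Python Store subclass, such as the wrapper above; the module-level _STORE name is hypothetical):

import torch.distributed as dist

_STORE = None  # module-level reference keeps the Python Store object alive

def init_distributed(rank: int, world_size: int):
    global _STORE
    store = CustomStore(rank, world_size)  # hypothetical Store subclass
    # If `store` stays a local variable only, it can be garbage-collected
    # after this function returns; the C++ PythonStore trampoline then finds
    # no Python override for set/get and hits the INTERNAL ASSERT above.
    _STORE = store
    dist.init_process_group(backend="nccl", rank=rank,
                            world_size=world_size, store=store)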

@wconstab
Contributor

wconstab commented Mar 1, 2024

Maybe we can follow up here by adding a documentation note on the pyi file for TCPStore explaining this for anyone else trying to extend the store. Cc @XilunWu
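Such a note could look roughly like this (a sketch only, not actual PyTorch documentation):

# torch/_C/_distributed_c10d.pyi (illustrative sketch of a possible note)
class Store:
    """Base class for c10d key-value stores.

    Note: if you subclass Store in Python and pass an instance to
    init_process_group(), keep a Python-side reference to that instance for
    the lifetime of the process group. If the object is garbage-collected,
    the C++ PythonStore trampoline can no longer find the Python overrides
    (e.g. set/get), which surfaces as an INTERNAL ASSERT failure in
    torch/csrc/distributed/c10d/init.cpp.
    """
    ...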
