Description
🐛 Describe the bug
I've been testing the new native c10d functional collectives integration with inductor (#112439), both because @eellison told me it would fix some buffer reuse bugs I was hitting (which it does!) and because of its ability to codegen to different backends. However, there still seems to be an issue when combining these collectives with the CUDAGraph integration in torch.compile(). Specifically, I'm getting different output logits when running the script in https://github.com/foundation-model-stack/foundation-model-stack/blob/cudagraph-repros/scripts/cudagraph_repro1.py (cudagraph-repros branch). For eager and regular compile, the outputs basically match, but in reduce-overhead mode they are completely different.
I've been experimenting to better isolate the issue, and I've narrowed it down to tensors being used before the allReduce and allGather operations are done with them. The issue doesn't always happen, which points to a race condition of some kind. It only happens when the compute kernels and their launches are faster than NCCL: running the script with 2 processes might work 50% of the time, but making NCCL slower by running 4 or 8 processes triggers the issue 100% of the time on an AWS p4de node.
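To make the suspected failure mode concrete, here is a toy, CPU-only sketch of the race described above. It is not PyTorch, c10d, or NCCL code: `ToyAsyncAllReduce`, `consume`, and `run` are hypothetical names, and a background thread stands in for a slow collective. It only illustrates why reading an output buffer before the collective has completed gives wrong values, while waiting first gives the expected ones.

```python
import threading


class ToyAsyncAllReduce:
    """Toy stand-in for an asynchronous all-reduce (not a c10d or NCCL API).

    A background thread plays the role of the communication stream: it sums
    the per-rank inputs into `out`, but only once `network_ready` is set,
    which models a collective that is slower than the compute around it.
    """

    def __init__(self, per_rank_inputs, out):
        self.out = out
        self.network_ready = threading.Event()
        self._done = threading.Event()
        self._thread = threading.Thread(
            target=self._run, args=(per_rank_inputs,), daemon=True
        )
        self._thread.start()

    def _run(self, per_rank_inputs):
        self.network_ready.wait()  # the "slow" communication step
        for i in range(len(self.out)):
            self.out[i] = sum(rank[i] for rank in per_rank_inputs)
        self._done.set()

    def wait(self):
        """Block until the reduced values are visible in `out`."""
        self._done.wait()
        self._thread.join()
        return self.out


def consume(buf):
    """Hypothetical downstream compute kernel: doubles every element."""
    return [2 * x for x in buf]


def run(wait_before_use):
    inputs = [[1.0, 2.0], [10.0, 20.0]]  # two simulated ranks
    out = [0.0, 0.0]                     # output buffer, stale until reduced
    work = ToyAsyncAllReduce(inputs, out)
    if wait_before_use:
        work.network_ready.set()
        result = consume(work.wait())    # read only after completion
    else:
        result = consume(out)            # bug: read while collective is pending
        work.network_ready.set()
        work.wait()
    return result


if __name__ == "__main__":
    print("with wait:   ", run(wait_before_use=True))
    print("without wait:", run(wait_before_use=False))
```

In this sketch the "without wait" path always reads stale data because the collective is held back explicitly. In the real setting the outcome depends on relative timing, which would be consistent with the failure rate going up as NCCL gets slower.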
To run the repro script:
git clone https://github.com/foundation-model-stack/foundation-model-stack.git
cd foundation-model-stack
git checkout cudagraph-repros
python setup.py develop
torchrun --nproc_per_node=2 scripts/cudagraph_repro1.py
torchrun --nproc_per_node=4 scripts/cudagraph_repro1.py
Other relevant pieces of code: our Tensor Parallel implementation using the new native functional collectives is in https://github.com/foundation-model-stack/foundation-model-stack/blob/cudagraph-repros/fms/distributed/tensorparallel.py#L50-L65, and the Llama model implementation is in https://github.com/foundation-model-stack/foundation-model-stack/blob/cudagraph-repros/fms/models/llama.py.
I'll try to get a smaller repro, although the current one already runs in under 20s on my test machine.
Versions
Collecting environment information...
PyTorch version: 2.2.0.dev20231116
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31
Python version: 3.11.6 | packaged by conda-forge | (main, Oct 3 2023, 10:40:35) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-1041-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB
Nvidia driver version: 535.54.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 96
On-line CPU(s) list: 0-95
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
Stepping: 7
CPU MHz: 2999.998
BogoMIPS: 5999.99
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.5 MiB
L1i cache: 1.5 MiB
L2 cache: 48 MiB
L3 cache: 71.5 MiB
NUMA node0 CPU(s): 0-23,48-71
NUMA node1 CPU(s): 24-47,72-95
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke
Versions of relevant libraries:
[pip3] numpy==1.26.0
[pip3] torch==2.2.0.dev20231116
[pip3] torchaudio==2.2.0.dev20231116
[pip3] torchvision==0.17.0.dev20231116
[pip3] triton==2.1.0
[conda] blas 2.116 mkl conda-forge
[conda] blas-devel 3.9.0 16_linux64_mkl conda-forge
[conda] brotlipy 0.7.0 py311h9bf148f_1002 pytorch-nightly
[conda] cffi 1.15.1 py311h9bf148f_3 pytorch-nightly
[conda] cryptography 38.0.4 py311h46ebde7_0 pytorch-nightly
[conda] filelock 3.9.0 py311_0 pytorch-nightly
[conda] libblas 3.9.0 16_linux64_mkl conda-forge
[conda] libcblas 3.9.0 16_linux64_mkl conda-forge
[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch-nightly
[conda] liblapack 3.9.0 16_linux64_mkl conda-forge
[conda] liblapacke 3.9.0 16_linux64_mkl conda-forge
[conda] mkl 2022.1.0 h84fe81f_915 conda-forge
[conda] mkl-devel 2022.1.0 ha770c72_916 conda-forge
[conda] mkl-include 2022.1.0 h84fe81f_915 conda-forge
[conda] mpmath 1.2.1 py311_0 pytorch-nightly
[conda] numpy 1.26.0 py311h64a7726_0 conda-forge
[conda] pillow 9.3.0 py311h3fd9d12_2 pytorch-nightly
[conda] pysocks 1.7.1 py311_0 pytorch-nightly
[conda] pytorch 2.2.0.dev20231116 py3.11_cuda12.1_cudnn8.9.2_0 pytorch-nightly
[conda] pytorch-cuda 12.1 ha16c6d3_5 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch-nightly
[conda] requests 2.28.1 py311_0 pytorch-nightly
[conda] torchaudio 2.2.0.dev20231116 py311_cu121 pytorch-nightly
[conda] torchtriton 2.1.0+6e4932cda8 py311 pytorch-nightly
[conda] torchvision 0.17.0.dev20231116 py311_cu121 pytorch-nightly
[conda] urllib3 1.26.14 py311_0 pytorch-nightly
cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler