
Different numerical results during training with CUDA graphs #119873

Open
tsengalb99 opened this issue Feb 14, 2024 · 8 comments
Labels
module: cuda graphs (Ability to capture and then replay streams of CUDA kernels) · module: numerical-reproducibility · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@tsengalb99

tsengalb99 commented Feb 14, 2024

🐛 Describe the bug

I am trying to wrap a sequence of Llama decoder layers (as implemented by huggingface) with a cuda graph to speed up training. Without cuda graphs, the loss decreases normally. With cuda graphs, the loss decreases at a slower rate and does not converge to the same solution. I have some custom kernels in my linear layer that decompress quantized weights but I verified these are not the source of the issue by pre-manifesting the decompressed weights and just calling x@W.T in the linear layer.

Is there a specific way that cuda graphs need to be set up to support training?

Details:

  • Without cuda graphs you should get
I0214 05:23:27.217544 4143374 finetune.py:164] initial loss 2.3495054244995117
I0214 05:24:07.684923 4143374 finetune.py:190] epoch 0 new loss 2.3075146675109863 old loss 2.3495054244995117 BETTER
I0214 05:24:47.764547 4143374 finetune.py:190] epoch 1 new loss 2.2501564025878906 old loss 2.3075146675109863 BETTER
  • whereas with cuda graphs (comment out L69 and uncomment L68) you should get
I0214 04:24:51.857225 4141303 finetune.py:164] initial loss 2.349527359008789
I0214 04:25:05.204223 4141303 finetune.py:204] epoch 0 new loss 2.3479909896850586 old loss 2.349527359008789 BETTER
I0214 04:25:18.243131 4141303 finetune.py:204] epoch 1 new loss 2.347287654876709 old loss 2.3479909896850586 BETTER
  • You can test manifesting the whole weight matrix with --ft_train_mode, which should give the same results. This also happens when using only one GPU, so it should not be due to using more than one GPU.

Versions

Collecting environment information...
PyTorch version: 2.1.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-139-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100 80GB PCIe
GPU 1: NVIDIA A100 80GB PCIe
GPU 2: NVIDIA A100 80GB PCIe
GPU 3: NVIDIA A100 80GB PCIe
GPU 4: NVIDIA A100 80GB PCIe
GPU 5: NVIDIA A100 80GB PCIe
GPU 6: NVIDIA A100 80GB PCIe
GPU 7: NVIDIA A100 80GB PCIe

Nvidia driver version: 525.105.17
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 57 bits virtual
CPU(s): 96
On-line CPU(s) list: 0-95
Thread(s) per core: 2
Core(s) per socket: 6
Socket(s): 8
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 106
Model name: Intel(R) Xeon(R) Gold 6338 CPU @ 2.00GHz
Stepping: 6
CPU MHz: 1999.982
BogoMIPS: 3999.96
Virtualization: VT-x
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.5 MiB
L1i cache: 1.5 MiB
L2 cache: 192 MiB
L3 cache: 128 MiB
NUMA node0 CPU(s): 0-47
NUMA node1 CPU(s): 48-95
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd arat avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid md_clear arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.26.3
[pip3] torch==2.1.1
[pip3] torchaudio==2.1.1
[pip3] torchvision==0.16.1
[pip3] triton==2.1.0
[conda] numpy 1.26.3 pypi_0 pypi
[conda] torch 2.1.1 pypi_0 pypi
[conda] torchaudio 2.1.1 pypi_0 pypi
[conda] torchvision 0.16.1 pypi_0 pypi
[conda] triton 2.1.0 pypi_0 pypi

cc @mcarilli @ezyang

@colesbury added the module: cuda graphs, module: numerical-reproducibility, and triaged labels Feb 14, 2024
@ezyang
Contributor

ezyang commented Feb 15, 2024

There are some safety conditions which must be fulfilled for CUDA graphs to give valid results. Hypothetically, torch.compile on PT2 with cudagraphs should automatically test these safety conditions. If you're willing to do a detour in making your code torch.compile'able, it might tell you what the problem is.
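
For concreteness, a minimal sketch of what that detour might look like, assuming a generic module and training loop (MyModel, loader, and the loss are placeholders, not the code from this issue):

```python
# Minimal sketch, assuming a generic nn.Module and data loader (MyModel and
# loader are placeholders, not the code from this issue).
# mode="reduce-overhead" is the torch.compile mode that wraps the compiled
# graphs in CUDA graphs where it judges that safe.
import torch
import torch.nn.functional as F

model = MyModel().cuda()                      # hypothetical module
opt = torch.optim.AdamW(model.parameters())

compiled = torch.compile(model, mode="reduce-overhead")

for x, y in loader:                           # hypothetical data loader
    opt.zero_grad(set_to_none=True)
    loss = F.cross_entropy(compiled(x.cuda()), y.cuda())
    loss.backward()
    opt.step()
```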

@tsengalb99
Author

I did try torch.compile in both regular and 'reduce-overhead' (cuda graphs) mode on the Shard class. In both cases I got a rather lengthy error, which I can post later today. Is there anything obviously wrong with what I'm doing? Essentially the structure is this:

  • A shard wrapper model that has n shards, each of which is on a specific gpu
  • Each shard has its own cuda graph
  • Only capture the forward pass in the graph
  • During training, run the usual training loop, but call graph.replay() on each shard in sequence for the forward pass and run the backward pass normally (see the sketch after this list).
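
For reference, roughly the shape of the setup described above, capturing only the forward pass of a single shard; ShardModule, the tensor shape, and next_batch are placeholders, not the actual code:

```python
# Rough sketch of the forward-only capture described above; ShardModule,
# the tensor shape, and next_batch are placeholders, not the actual code.
import torch

shard = ShardModule().cuda(0)                             # hypothetical shard
static_in = torch.zeros(8, 2048, 4096, device="cuda:0")   # assumed static shape

# Warm up on a side stream before capture, per the CUDA graphs docs.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        shard(static_in)
torch.cuda.current_stream().wait_stream(s)

# Capture just the forward pass into a graph.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = shard(static_in)

# Each training step: copy fresh data into the static input and replay.
static_in.copy_(next_batch)        # next_batch: placeholder for the real input
g.replay()
# The backward pass is then run eagerly from static_out, which is exactly
# what the next comment turns on.
```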

@ezyang
Contributor

ezyang commented Feb 19, 2024

Well, it's not clear to me how you can "do the backward pass normally": ordinarily, when you run a forward pass, we set up an autograd graph on the CPU that says how to do the backward. If you replay a cudagraph, though, that CPU compute is skipped (which is the point of cudagraphs), and then we can't run backward. To cudagraph, you need to cudagraph both forward and backward (which is what PT2 would do for you). Not so sure about the multi-GPU interaction though.
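
For what it's worth, the stock helper that captures both passes for a callable is torch.cuda.make_graphed_callables; a minimal sketch of using it on a single shard (ShardModule and the shapes are placeholders):

```python
# Sketch of graphing both forward and backward with torch.cuda.make_graphed_callables;
# ShardModule and the shapes are placeholders, not the code from this issue.
import torch

shard = ShardModule().cuda(0)                                   # hypothetical
sample = torch.randn(8, 2048, 4096, device="cuda:0", requires_grad=True)

# Captures one CUDA graph for the forward and one for the backward of `shard`,
# after a few warmup iterations on a side stream.
graphed_shard = torch.cuda.make_graphed_callables(shard, (sample,))

# Use it like a normal module; real inputs must match the sample's
# shape/dtype/device and requires_grad state.
x = torch.randn(8, 2048, 4096, device="cuda:0", requires_grad=True)
out = graphed_shard(x)      # replays the captured forward
out.sum().backward()        # autograd replays the captured backward
```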

@tsengalb99
Author

tsengalb99 commented Feb 19, 2024 via email

@ezyang
Contributor

ezyang commented Feb 20, 2024

I know you're not going to like this answer... but if you use DDP instead of DP you will not have this problem :P

@tsengalb99
Author

I'm not using DataParallel though? The shard wrapper thing is just a quick wrapper class I wrote to manually do sharding. I did consider FSDP - do you know if that works with cuda graphs correctly?

@ezyang
Contributor

ezyang commented Feb 21, 2024

cc @awgu, it feels like it could in principle but I don't know if we've done it in practice

@Abhishekghosh1998

Hypothetically, torch.compile on PT2 with cudagraphs should automatically test these safety conditions. If you're willing to do a detour in making your code torch.compile'able, it might tell you what the problem is.

@ezyang Could you please point to some documentation or blog post that discusses the capabilities of torch.compile in the context of CUDA graphs? For example, what are the "safety conditions" that get tested automatically?
