
Direct P2P GPU <-> GPU communication with torch.to does not seem to work. #119638

Closed
morgangiraud opened this issue Feb 10, 2024 · 15 comments
Labels
module: cuda (Related to torch.cuda, and CUDA support in general)
module: python frontend (For issues relating to PyTorch's Python frontend)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@morgangiraud

morgangiraud commented Feb 10, 2024

🐛 Describe the bug

Hi,

I've been looking at direct GPU <-> GPU communication using the PyTorch tensor.to function, and I've found that it doesn't seem to copy the tensor correctly from one CUDA device to the other.
I'm sorry if I've missed something obvious, but I couldn't find anything saying this shouldn't work as expected.

import torch
import importlib.metadata
import time


print(f"torch version: {importlib.metadata.version('torch')}\n")

print(f"cuda is_available: {torch.cuda.is_available()}")
d_count = torch.cuda.device_count()
print(f"device_count(): {d_count}")
for idx in range(d_count):
    print(f"get_device_name({idx}): {torch.cuda.get_device_name(idx)}")
    print(f"get_device_properties({idx}): {torch.cuda.get_device_properties(idx)}")
    print(f"get_device_capability({idx}): {torch.cuda.get_device_capability(idx)}")
print(f"current device: {torch.cuda.current_device()}\n")

if d_count > 1:
    access_mat = torch.zeros((d_count, d_count), dtype=torch.bool)
    for i in range(d_count):
        for j in range(d_count):
            access_mat[i, j] = (
                torch.cuda.can_device_access_peer(i, j) if i != j else True
            )

    print("Devices access matrix:\n", access_mat.data, "\n")


def get_tensor_info(name, t):
    return f"{name:10s} -> device:{t.device}, dtype:{t.dtype}, shape:{t.shape}, mean:{t.mean()}"


class DistributedMatMul(torch.nn.Module):
    def __init__(self, D):
        super().__init__()

        self.device0 = torch.device("cuda", 0)
        self.device1 = torch.device("cuda", 1)

        self.w0 = torch.ones((D, 2 * D), dtype=torch.float32, device=self.device0)
        self.w1 = torch.ones((2 * D, D), dtype=torch.float32, device=self.device1)

    def forward(self, x):
        x_gpu_0 = x.to(self.device0)

        y0 = x_gpu_0 @ self.w0
        print(f"{'y0':10s} -> {y0}")

        # y0_gpu_1 = y0.to("cpu").to(self.device1)  # This works

        y0_gpu_1 = y0.to(self.device1)  # This does not work
        print(
            f"{'y0_gpu_1':10s} -> {y0_gpu_1}"
        )  # should return [[2., 2., 2., 2.]] but returns [[1., 0., 0., 0.]]

        y1 = y0_gpu_1 @ self.w1

        y_cpu = y1.cpu().mean()

        return y_cpu

    def __str__(self):
        w0 = get_tensor_info("w0", self.w0)
        w1 = get_tensor_info("w1", self.w1)
        return f"{w0}\n{w1}"


torch.manual_seed(0)

N = 1
D = 2

model = DistributedMatMul(D)
print(model)
x_cpu = torch.ones((N, D), dtype=torch.float32, device="cpu")
y_cpu = model(x_cpu)

# Should return 8.0; when the direct copy fails it returns a wrong value (1.0 in the run below)
print(get_tensor_info("y_cpu", y_cpu))

######
# Output for the above code
######
# torch version: 2.1.2.post301

# cuda is_available: True
# device_count(): 2
# get_device_name(0): NVIDIA Graphics Device
# get_device_properties(0): _CudaDeviceProperties(name='NVIDIA Graphics Device', major=8, minor=9, total_memory=15868MB, multi_processor_count=66)
# get_device_capability(0): (8, 9)
# get_device_name(1): NVIDIA Graphics Device
# get_device_properties(1): _CudaDeviceProperties(name='NVIDIA Graphics Device', major=8, minor=9, total_memory=15868MB, multi_processor_count=66)
# get_device_capability(1): (8, 9)
# current device: 0

# Devices access matrix:
#  tensor([[True, True],
#         [True, True]])

# w0         -> device:cuda:0, dtype:torch.float32, shape:torch.Size([2, 4]), mean:1.0
# w1         -> device:cuda:1, dtype:torch.float32, shape:torch.Size([4, 2]), mean:1.0
# y0         -> tensor([[2., 2., 2., 2.]], device='cuda:0')
# y0_gpu_1   -> tensor([[1., 0., 0., 0.]], device='cuda:1') <------------ !!!! The copy is completely wrong
# y_cpu      -> device:cpu, dtype:torch.float32, shape:torch.Size([]), mean:1.0


######
# I also ran the Nvidia cuda samples p2pBandwidthLatencyTest
######
# Device: 0, NVIDIA Graphics Device, pciBusID: 1, pciDeviceID: 0, pciDomainID:0
# Device: 1, NVIDIA Graphics Device, pciBusID: 3, pciDeviceID: 0, pciDomainID:0
# Device=0 CAN Access Peer Device=1
# Device=1 CAN Access Peer Device=0

# ***NOTE: In case a device doesn't have P2P access to other one, it falls back to normal memcopy procedure.
# So you can see lesser Bandwidth (GB/s) and unstable Latency (us) in those cases.

# P2P Connectivity Matrix
#      D\D     0     1
#      0	     1     1
#      1	     1     1
# Unidirectional P2P=Disabled Bandwidth Matrix (GB/s)
#    D\D     0      1
#      0 608.69  12.13
#      1  12.11 613.95
# Unidirectional P2P=Enabled Bandwidth (P2P Writes) Matrix (GB/s)
#    D\D     0      1
#      0 609.64  13.55
#      1  13.55 614.61
# Bidirectional P2P=Disabled Bandwidth Matrix (GB/s)
#    D\D     0      1
#      0 611.41  17.34
#      1  17.30 613.95
# Bidirectional P2P=Enabled Bandwidth Matrix (GB/s)
#    D\D     0      1
#      0 611.31  27.10
#      1  27.11 613.83
# P2P=Disabled Latency Matrix (us)
#    GPU     0      1
#      0   1.16  10.69
#      1  10.36   1.20

#    CPU     0      1
#      0   1.30   4.03
#      1   4.09   1.29
# P2P=Enabled Latency (P2P Writes) Matrix (us)
#    GPU     0      1
#      0   1.16   0.90
#      1   0.90   1.20

#    CPU     0      1
#      0   1.32   1.08
#      1   1.07   1.30

Versions

PyTorch version: 2.1.2.post301
Is debug build: False
CUDA used to build PyTorch: 12.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.12.1 | packaged by conda-forge | (main, Dec 23 2023, 08:03:24) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-6.5.0-17-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.3.107
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA Graphics Device
GPU 1: NVIDIA Graphics Device

Nvidia driver version: 545.23.08
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.0.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      48 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             24
On-line CPU(s) list:                0-23
Vendor ID:                          AuthenticAMD
Model name:                         AMD Ryzen 9 7900X 12-Core Processor
CPU family:                         25
Model:                              97
Thread(s) per core:                 2
Core(s) per socket:                 12
Socket(s):                          1
Stepping:                           2
CPU max MHz:                        5733,0000
CPU min MHz:                        400,0000
BogoMIPS:                           9399.81
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid overflow_recov succor smca fsrm flush_l1d
Virtualization:                     AMD-V
L1d cache:                          384 KiB (12 instances)
L1i cache:                          384 KiB (12 instances)
L2 cache:                           12 MiB (12 instances)
L3 cache:                           64 MiB (2 instances)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-23
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced / Automatic IBRS, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.1.2.post301
[conda] libmagma                  2.7.2                h173bb3b_2    conda-forge
[conda] libmagma_sparse           2.7.2                h173bb3b_2    conda-forge
[conda] libtorch                  2.1.2           cuda120_h2aa5df7_301    conda-forge
[conda] magma                     2.7.2                h51420fd_2    conda-forge
[conda] mkl                       2023.2.0         h84fe81f_50496    conda-forge
[conda] numpy                     1.26.4          py312heda63a1_0    conda-forge
[conda] pytorch                   2.1.2           cuda120_py312h6e1fc47_301    conda-forge

cc @ptrblck @albanD
@jbschlosser added the module: cuda, triaged, and module: python frontend labels Feb 13, 2024
@morgangiraud
Author

Let me add a missing piece of information; here is the result of nvidia-smi topo -m:

	GPU0	GPU1	CPU Affinity	NUMA Affinity	GPU NUMA ID
GPU0	 X 	PHB	0-23	0		N/A
GPU1	PHB	 X 	0-23	0		N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

@hanzhi713
Contributor

hanzhi713 commented Feb 18, 2024

@morgangiraud I've seen reports of some versions of NVIDIA driver 545 having broken P2P support. You may want to try upgrading your driver to 545.29+ (see vllm-project/vllm#1801) or downgrading to 535.x.

@morgangiraud
Author

Thanks for helping me here. I've narrowed the problem down a bit:

If I downgrade to driver 535.*, the driver no longer reports P2P capabilities, so the code works, but I take a non-negligible throughput penalty.

Then I started looking at NCCL to see if I could pinpoint the problem more precisely, following a few related threads.

In particular, I ended up on this exact problem: NVIDIA/nccl#606 (comment)

If I run the NCCL tests with NCCL_P2P_LEVEL=PHB they hang, and with NCCL_P2P_LEVEL=PXB they work.

Finally, I ported my code to TensorFlow and the problem is the same, so this issue is not directly related to PyTorch.

Anyway, do you know if there is an equivalent to NCCL_P2P_LEVEL=PXB for PyTorch?

@hanzhi713
Contributor

@morgangiraud NCCL_P2P_LEVEL=PXB will effectively disable P2P on your platform, since PHB is the only link you have available, and it is a level higher than PXB.
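
For reference, PyTorch's NCCL backend honors the same NCCL environment variables, so collectives run through torch.distributed can be steered the same way. A minimal sketch, assuming a single-node 2-GPU job launched with torchrun --nproc_per_node=2 script.py (the script name is illustrative); note that this only affects NCCL collectives, not plain tensor.to() copies:

import os

# Must be set before NCCL initializes; exporting it in the shell works just as well.
os.environ.setdefault("NCCL_P2P_LEVEL", "PXB")  # or NCCL_P2P_DISABLE=1 to turn P2P off entirely

import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")  # torchrun provides rank/world_size via env vars
rank = dist.get_rank()
torch.cuda.set_device(rank)

t = torch.ones(4, device=f"cuda:{rank}")
dist.all_reduce(t)   # runs through NCCL, which reads NCCL_P2P_LEVEL
print(rank, t)       # expected on both ranks: tensor([2., 2., 2., 2.])

dist.destroy_process_group()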

@morgangiraud
Author

Oh, all right, so that's why!
Thanks.

Do you see any link with the .to function in PyTorch?
Or is there maybe a way to force it to go through the CPU for now, like a shortcut for .cpu().to('cuda:1')?

@hanzhi713
Contributor

I'm not quite sure how .to is implemented, but I suppose it's just a cudaMemcpy, and we can't control the data route (whether it goes through the CPU or not). You might be able to improve performance a little by using .to("cpu", non_blocking=True).to('cuda:1', non_blocking=True).
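
For what it's worth, a minimal sketch of that CPU-staged pattern with a pinned host buffer (the tensor name, shape, and values are illustrative; the explicit synchronizations make sure each copy has landed before the data is reused):

import torch

# Stand-in for the y0 tensor from the snippet above.
y0 = torch.full((1, 4), 2.0, device="cuda:0")

# Pinned host buffer so the device -> host copy can run asynchronously.
staging = torch.empty(y0.shape, dtype=y0.dtype, device="cpu", pin_memory=True)
staging.copy_(y0, non_blocking=True)                 # cuda:0 -> pinned host
torch.cuda.synchronize("cuda:0")                     # wait until the host buffer is filled
y0_gpu_1 = staging.to("cuda:1", non_blocking=True)   # pinned host -> cuda:1
torch.cuda.synchronize("cuda:1")
print(y0_gpu_1)  # expected: tensor([[2., 2., 2., 2.]], device='cuda:1')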

@morgangiraud
Author

I see, thanks.

I will give the 550.* beta drivers a try before giving up.

@morgangiraud
Author

Well:

  • drivers 535.* and 550.* tell me I do not have P2P capabilities, and then, of course, everything works fine.
  • driver 545.* tells me I have P2P capabilities, but then NCCL hangs and common frameworks don't work.

So I'm left wondering whether I have P2P capabilities in the end. What is strange is that the Nvidia p2pBandwidthLatencyTest sample does show better P2P numbers with driver 545.*. I'm puzzled.

@hanzhi713
Contributor

The driver reporting that it supports P2P doesn't always mean it supports it correctly. That is what I mean by broken P2P support. In your case, it might be that P2P is not supported on your platform, but the 545 driver somehow thinks it is.

Also, p2pBandwidthLatencyTest only measures P2P timing. It doesn't verify that the P2P result is correct. It's likely that it is not correct in your case.
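
For completeness, a quick correctness check along those lines can be done directly in PyTorch by comparing a direct copy against one staged through the CPU (the tensor size here is arbitrary):

import torch

src = torch.randn(1024, 1024, device="cuda:0")
direct = src.to("cuda:1")             # direct device-to-device copy (P2P if the driver enables it)
staged = src.to("cpu").to("cuda:1")   # reference copy routed through host memory
print("direct copy matches CPU-staged copy:", torch.equal(direct, staged))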

@morgangiraud
Author

morgangiraud commented Feb 20, 2024

Yes, you were right.

The end result is that the 40 series does not support P2P, and driver 545.* is broken in that regard. (Nvidia could be clearer about that feature, though.)
I'm writing a small blog post to go through the post-mortem of this. If you don't mind, I will link it there for others to find.

Thanks a lot for your help!

@hanzhi713
Contributor

No problem. You can link this issue.

Also, there's a blog post about the 4090 lacking P2P support: https://www.tomshardware.com/news/nvidia-confirms-geforce-cards-lack-p2p-support


@Klohto

Klohto commented Apr 12, 2024

For anyone coming across this and the great write-up by @morgangiraud: tinygrad just published an experimental P2P driver for the 4090 (https://github.com/tinygrad/open-gpu-kernel-modules); it should be compatible with torch AFAIK.

@morgangiraud
Author

Yes!
I actually tested it on my build and it worked (dual 4070 Ti SUPER 16GB)!
TF, PyTorch, NCCL: all good.

I do see some speed improvements, but they are not dramatic.

@morgangiraud
Author

NCCL all-reduce test for reference: https://pastebin.com/ne4ipn6K

A 58% speed-up in the end.
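
For anyone who wants to reproduce a similar measurement without the nccl-tests binaries, a rough torch.distributed sketch (the message size, iteration counts, and rendezvous address below are illustrative, not the exact configuration from the pastebin):

import os
import time

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # ~64 MB of float32 per rank (size is arbitrary).
    x = torch.ones(16 * 1024 * 1024, device=f"cuda:{rank}")

    for _ in range(5):  # warm-up
        dist.all_reduce(x)
    torch.cuda.synchronize()

    iters = 20
    start = time.perf_counter()
    for _ in range(iters):
        dist.all_reduce(x)
    torch.cuda.synchronize()

    if rank == 0:
        print(f"avg all_reduce: {(time.perf_counter() - start) / iters * 1e3:.2f} ms")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)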
