Discrepancy between CPU->GPU and GPU->CPU data transfer speeds #52718

@agnz

Description

🐛 Bug

I've run into an issue where the data transfer rate from GPU to CPU is much slower than expected. After benchmarking, there appears to be a large discrepancy between CPU->GPU and GPU->CPU transfer speeds, as shown below.

Two transfer methods are benchmarked, each with both pinned and unpinned (pageable) host memory. Method 1: the .to() method. Method 2: creating a new tensor on the destination device and copying the source into it with the .copy_() method.
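A minimal sketch of the two methods in isolation, assuming only that PyTorch is installed; `method1_to` and `method2_copy` are hypothetical helper names, not PyTorch APIs. The point it illustrates is where each method's destination comes from: with Method 1, `.to()` allocates the CPU destination itself as ordinary pageable memory, so in the GPU->CPU direction a pinning choice can only be made with Method 2.

```python
import torch


def method1_to(src: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Method 1 (hypothetical helper): .to() allocates a new tensor on `device`
    # and copies into it. For a CUDA -> CPU copy that new tensor is pageable
    # host memory; there is no argument here to request pinning.
    return src.to(device, non_blocking=False)


def method2_copy(src: torch.Tensor, device: torch.device,
                 pin_memory: bool = False) -> torch.Tensor:
    # Method 2 (hypothetical helper): allocate the destination first, then
    # copy_() into it. Pinning only applies to a CPU destination and needs CUDA.
    pin = pin_memory and device.type == "cpu" and torch.cuda.is_available()
    dst = torch.empty(src.shape, dtype=src.dtype, device=device, pin_memory=pin)
    dst.copy_(src, non_blocking=False)
    return dst


if torch.cuda.is_available():
    cpu = torch.device("cpu")
    g = torch.rand(1000, 1000, device="cuda")
    a = method1_to(g, cpu)                     # pageable destination
    b = method2_copy(g, cpu, pin_memory=True)  # page-locked destination
    print(a.is_pinned(), b.is_pinned())
```

On a CPU-only machine both helpers still run (pinning is silently skipped), which makes them easy to check before timing them on a GPU.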

For my application I need to reach a certain GPU->CPU throughput, and I'm trying to understand why every GPU->CPU variant except .copy_() into a pinned (page-locked) CPU tensor is so much slower than the CPU->GPU transfers. That one fast method also requires allocating the pinned destination tensor first, a cost the benchmark does not measure, so in practice it performs worse than the numbers suggest.
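One way to amortize the cost of allocating the pinned destination is to allocate it once and reuse it for every transfer. The sketch below assumes all transferred tensors share one dtype; `PinnedStagingBuffer` and `fetch` are hypothetical names, not part of PyTorch, and I have not benchmarked this variant.

```python
import torch


class PinnedStagingBuffer:
    """Hypothetical reusable host buffer for GPU->CPU copies (not a PyTorch API).

    Page-locked memory is allocated once (when CUDA is available) and each
    source tensor is copied into a view of it, so the pinning cost is paid
    again only when a larger tensor forces the buffer to grow.
    """

    def __init__(self, capacity: int, dtype: torch.dtype = torch.float32):
        self.dtype = dtype
        self.pin = torch.cuda.is_available()
        self.buf = torch.empty(capacity, dtype=dtype, pin_memory=self.pin)

    def fetch(self, src: torch.Tensor) -> torch.Tensor:
        if src.dtype != self.dtype:
            raise TypeError(f"expected {self.dtype}, got {src.dtype}")
        n = src.numel()
        if n > self.buf.numel():
            # Grow (and re-pin) only when a larger tensor arrives.
            self.buf = torch.empty(n, dtype=self.dtype, pin_memory=self.pin)
        # Flat prefix of the buffer, reshaped to match the source.
        view = self.buf[:n].view(src.shape)
        view.copy_(src)  # blocking GPU->CPU copy into page-locked memory
        return view
```

The returned tensor aliases the buffer, so the next `fetch` overwrites it; call `.clone()` on the result if it has to outlive the next transfer.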

I don't know whether this behaviour is a bug, but I was hoping the devs could shed some light on the differences.

To Reproduce

Steps to reproduce the behavior:

```python
import torch
import time
import os
import numpy as np
from sklearn.linear_model import LinearRegression
import sys
import logging

# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
# os.environ["CUDA_VISIBLE_DEVICES"]="3"

log_fname = sys.argv[1] if len(sys.argv) > 1 else None
logging.basicConfig(filename=log_fname, filemode='w', level=logging.INFO)

def fit_regression(x, y):
    '''Fits a linear regression of size (y, GB) on time (x, s) and returns the slope in GB/s.'''
    x = np.array(x).reshape(-1, 1)
    y = np.array(y).reshape(-1, 1)
    model = LinearRegression().fit(x, y)
    return model.coef_[0][0]

def benchmark_cpu_to_gpu(method=1, pin_memory=False):
    # Transfer float32 tensors of 10000 x 10000 x i (i = 1..20): ~0.37 GB to ~7.45 GB (GB = 2**30 bytes)
    tgpu = torch.rand(200, 200, device=torch.device("cuda"))
    x = []  # elapsed times in seconds
    y = []  # tensor sizes in GB
    for i in range(1, 21):
        # free the previous GPU tensor before allocating the next one
        del tgpu
        torch.cuda.empty_cache()
        # source tensor on CPU (optionally pinned)
        tcpu = torch.empty(10000, 10000, i, device=torch.device("cpu"), pin_memory=pin_memory)
        y.append(tcpu.nelement() * tcpu.element_size() / 1024**3)  # tensor size in GB
        # destination tensor on GPU for Method 2
        if method == 2:
            tgpu = torch.rand(10000, 10000, i, device=torch.device("cuda"))
        torch.cuda.synchronize()

        start_time = time.perf_counter()
        if method == 1:
            tgpu = tcpu.to(torch.device("cuda"), non_blocking=False)
        elif method == 2:
            tgpu.copy_(tcpu, non_blocking=False)
        torch.cuda.synchronize()
        elapsed_time = time.perf_counter() - start_time
        x.append(elapsed_time)
    logging.info(f"CPU->GPU Method {method}, pin_memory={pin_memory}, Transfer rate: {fit_regression(x, y):.4f} GB/s.")

def benchmark_gpu_to_cpu(method=1, pin_memory=False):
    # Transfer float32 tensors of 10000 x 10000 x i (i = 1..20): ~0.37 GB to ~7.45 GB (GB = 2**30 bytes)
    tgpu = torch.rand(200, 200, device=torch.device("cuda"))
    x = []  # elapsed times in seconds
    y = []  # tensor sizes in GB
    for i in range(1, 21):
        # free the previous GPU tensor before allocating the next one
        del tgpu
        torch.cuda.empty_cache()
        # source tensor on GPU
        tgpu = torch.rand(10000, 10000, i, device=torch.device("cuda"))
        y.append(tgpu.nelement() * tgpu.element_size() / 1024**3)  # tensor size in GB
        # destination tensor on CPU for Method 2 (Method 1 ignores pin_memory:
        # .to() allocates its own CPU destination)
        if method == 2:
            tcpu = torch.empty(10000, 10000, i, device=torch.device("cpu"), pin_memory=pin_memory)
        torch.cuda.synchronize()

        start_time = time.perf_counter()
        if method == 1:
            tcpu = tgpu.to(torch.device("cpu"), non_blocking=False)
        elif method == 2:
            tcpu.copy_(tgpu, non_blocking=False)
        torch.cuda.synchronize()
        elapsed_time = time.perf_counter() - start_time
        x.append(elapsed_time)
    logging.info(f"GPU->CPU Method {method}, pin_memory={pin_memory}, Transfer rate: {fit_regression(x, y):.4f} GB/s.")

benchmark_cpu_to_gpu(1, False)
benchmark_cpu_to_gpu(2, False)
benchmark_cpu_to_gpu(1, True)
benchmark_cpu_to_gpu(2, True)

benchmark_gpu_to_cpu(1, False)
benchmark_gpu_to_cpu(2, False)
benchmark_gpu_to_cpu(1, True)
benchmark_gpu_to_cpu(2, True)
```
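The reported rate is the slope of transferred size against elapsed time. A NumPy-only equivalent of `fit_regression` is sketched below (`transfer_rate_gbps` is a hypothetical name); like the scikit-learn version it fits an intercept, so a constant per-transfer overhead (launch or synchronization latency) lands in the intercept rather than distorting the slope.

```python
import numpy as np


def transfer_rate_gbps(times_s, sizes_gb):
    """Least-squares slope of size (GB) vs. elapsed time (s), i.e. GB/s.

    Hypothetical drop-in for fit_regression(x, y): np.polyfit with degree 1
    fits size = rate * time + intercept and returns [rate, intercept].
    """
    slope, _intercept = np.polyfit(np.asarray(times_s, dtype=float),
                                   np.asarray(sizes_gb, dtype=float), 1)
    return float(slope)


# Synthetic data: 10 GB/s plus a fixed 0.05 s overhead per transfer.
times = [0.15, 0.25, 0.35]
sizes = [1.0, 2.0, 3.0]
print(transfer_rate_gbps(times, sizes))  # approximately 10.0
```

Dividing size by time for a single point instead (1.0 / 0.15) would report about 6.7 GB/s here, which is why fitting over several sizes is the more robust choice.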

Output on an Intel(R) Core(TM) i9-9900X with 4x GeForce RTX 2080 Ti GPUs:

CPU->GPU Method 1, pin_memory=False, Transfer rate: 3.6929 GB/s.
CPU->GPU Method 2, pin_memory=False, Transfer rate: 4.1244 GB/s.
CPU->GPU Method 1, pin_memory=True, Transfer rate: 11.2812 GB/s.
CPU->GPU Method 2, pin_memory=True, Transfer rate: 12.0484 GB/s.
GPU->CPU Method 1, pin_memory=False, Transfer rate: 1.8152 GB/s.
GPU->CPU Method 2, pin_memory=False, Transfer rate: 1.9026 GB/s.
GPU->CPU Method 1, pin_memory=True, Transfer rate: 1.8079 GB/s.
GPU->CPU Method 2, pin_memory=True, Transfer rate: 12.2043 GB/s.

Output on an Intel E5-2695v4 with 2x P100 GPUs:

CPU->GPU Method 1, pin_memory=False, Transfer rate: 6.8572 GB/s.
CPU->GPU Method 2, pin_memory=False, Transfer rate: 8.3023 GB/s.
CPU->GPU Method 1, pin_memory=True, Transfer rate: 10.7468 GB/s.
CPU->GPU Method 2, pin_memory=True, Transfer rate: 8.3660 GB/s.
GPU->CPU Method 1, pin_memory=False, Transfer rate: 2.0553 GB/s.
GPU->CPU Method 2, pin_memory=False, Transfer rate: 1.9625 GB/s.
GPU->CPU Method 1, pin_memory=True, Transfer rate: 2.0373 GB/s.
GPU->CPU Method 2, pin_memory=True, Transfer rate: 12.2455 GB/s.

Expected behavior

Transfer rates using the .to() method (Method 1) should be similar in both directions.

Environment

PyTorch version: 1.7.1
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.3 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: Could not collect

Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 11.1.74
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti
GPU 2: GeForce RTX 2080 Ti
GPU 3: GeForce RTX 2080 Ti

Nvidia driver version: 455.23.05
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.4
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] torch==1.7.1
[pip3] torchvision==0.8.2
[conda] blas                      1.0                         mkl
[conda] cudatoolkit               10.2.89            684.g752c550    https://public.dhe.ibm.com/ibmdl/export/pub/software/server/ibm-ai/conda
[conda] mkl                       2020.2                      256
[conda] mkl-service               2.3.0            py37he8ac12f_0
[conda] mkl_fft                   1.2.0            py37h23d657b_0
[conda] mkl_random                1.1.1            py37h0573a6f_0
[conda] numpy                     1.19.2           py37h54aff64_0
[conda] numpy-base                1.19.2           py37hfa32c7d_0
[conda] pytorch                   1.7.1           py3.7_cuda10.2.89_cudnn7.6.5_0    pytorch
[conda] torchvision               0.8.2                py37_cu102    pytorch

cc @ngimel @VitalyFedyunin

Metadata

Assignees

No one assigned

    Labels

    module: cuda, module: performance, triaged

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests