
Construction of MultivariateNormal much slower on GPU than CPU #23780

Open · danielcrane opened this issue Aug 5, 2019 · 4 comments
Labels: module: cuda (Related to torch.cuda, and CUDA support in general), module: distributions (Related to torch.distributions), module: performance (Issues related to performance, either of kernel code or framework glue), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@danielcrane

🐛 Bug

Constructing a MultivariateNormal distribution is much slower with GPU-based FloatTensors than with CPU-based ones.

On my machine the GPU version is ~33x slower than CPU.

To Reproduce

Steps to reproduce the behavior:

import time
import torch
from torch.distributions.multivariate_normal import MultivariateNormal

mu = torch.FloatTensor([2, 4])
sigma = torch.FloatTensor([[5, 0], [0, 2]])

mu_gpu = mu.cuda()
sigma_gpu = sigma.cuda()

num_runs = 1000
t_cpu, t_gpu = 0, 0
for _ in range(num_runs):
    # Time construction with CPU tensors.
    st = time.perf_counter()
    m1 = MultivariateNormal(mu, sigma)
    t_cpu += time.perf_counter() - st

    # Time construction with GPU tensors, synchronizing before and after
    # so that queued CUDA work is attributed to the correct interval.
    torch.cuda.synchronize()
    st = time.perf_counter()
    m2 = MultivariateNormal(mu_gpu, sigma_gpu)
    torch.cuda.synchronize()
    t_gpu += time.perf_counter() - st

print(f'[CPU] Time Taken: {t_cpu}s')
print(f'[GPU] Time Taken: {t_gpu}s')

Output on my machine:

[CPU] Time Taken: 0.08132426194060827s
[GPU] Time Taken: 2.7058167830073216s

Expected behavior

I'd expect the GPU to be faster, or at least comparable in speed to the CPU.

Environment

PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Pop!_OS 18.10
GCC version: (Ubuntu 8.3.0-6ubuntu1~18.10) 8.3.0
CMake version: version 3.12.1

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration: GPU 0: GeForce GTX 1070 With Max-Q Design
Nvidia driver version: 410.78
cuDNN version: /usr/lib/cuda-10.0/lib64/libcudnn.so.7.4.1

Versions of relevant libraries:
[pip3] numpy==1.17.0
[pip3] torch==1.1.0
[pip3] torchvision==0.3.0
[conda] blas                      1.0                         mkl  
[conda] mkl                       2019.4                      243  
[conda] mkl_fft                   1.0.12           py36ha843d7b_0  
[conda] mkl_random                1.0.2            py36hd81dba3_0
@vishwakftw (Contributor) commented Aug 5, 2019

Hi @danielcrane, thank you for raising the issue.

This is related to #20700.

To paraphrase what I said there: MAGMA (the GPU backend for most linear algebra operations in PyTorch) is extremely fast for large problem sizes, but unfortunately quite slow for small ones (for example, the 2 x 2 covariance matrix you provided). This is a recurring issue, and I am thinking about ways to circumvent it.
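
As a quick check, benchmarking torch.cholesky alone on a small matrix reproduces the same gap (a minimal sketch; timings are illustrative and will vary by hardware), which points at the factorization backend rather than at MultivariateNormal itself:

import time
import torch

# Tiny SPD matrix, comparable to the 2 x 2 covariance above.
A = torch.eye(2) * 3.0
A_gpu = A.cuda()
torch.cuda.synchronize()

st = time.perf_counter()
for _ in range(1000):
    torch.cholesky(A)
print(f'[CPU] cholesky: {time.perf_counter() - st:.4f}s')

st = time.perf_counter()
for _ in range(1000):
    torch.cholesky(A_gpu)
torch.cuda.synchronize()
print(f'[GPU] cholesky: {time.perf_counter() - st:.4f}s')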

vishwakftw added the module: operators and module: performance labels on Aug 5, 2019
pbelevich added the module: cpu, module: cuda, and module: distributions labels on Aug 5, 2019
vishwakftw removed the module: cpu label on Aug 5, 2019
@danielcrane (Author) commented Aug 6, 2019

Hi @vishwakftw, thanks for your prompt response.

I'm glad to hear that this issue is already known to you.

I guess for now the best temporary workaround is to compute the Cholesky factorisation on the CPU, move it to the GPU, and pass it in as the scale_tril argument.
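
Something like the following (a minimal sketch reusing mu_gpu and sigma from the reproduction script above):

# Factorize on CPU, where small Cholesky decompositions are fast,
# then move only the triangular factor to the GPU.
scale_tril = torch.cholesky(sigma).cuda()
m = MultivariateNormal(mu_gpu, scale_tril=scale_tril)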

ailzhang added the triaged label on Aug 7, 2019
@Balandat (Contributor) commented:

@vishwakftw, MAGMA 2.5 brings support for magma_Xpotrf_native, a GPU-only implementation of Cholesky factorization. Not sure if CPU/GPU overhead is in play here, but is it worth looking into making use of this?

@GoingMyWay commented Jun 9, 2021

Same issue here on RTX 3090 and PyTorch 1.8.1.

To improve the speed, my workaround is to perform the Cholesky decomposition first and then pass the factor as scale_tril, which attains roughly a 10x speedup:

# logits[0] holds the mean and logits[1] the covariance matrix (model outputs).
scale_tril = torch.cholesky(logits[1])  # factorize once up front
dist = MultivariateNormal(loc=logits[0], scale_tril=scale_tril)
