Batched symeig and qr are very slow on GPU #22573
Comments
It seems that the same thing happens for QR decompositions, so I assume it's related to the fact that batching for these functions is implemented in a for-loop style (#20689 (comment) is relevant). But that makes these functions really unusable on GPU.
Hi @calincru. Thank you for opening the issue. As you have pointed out in your comment, the QR decomposition is slow for smaller matrices. A similar statement can be made for the symmetric eigendecomposition in MAGMA. However, MAGMA is comparatively faster on larger matrices, as explained in this comment. The code comments above are slightly outdated, and I will fix them as soon as possible to prevent confusion. MAGMA uses hybrid CPU-GPU routines for many of its functions, and some arguments passed to these functions must be placed on the CPU. Not placing them on the CPU causes a segmentation fault, which is what the comments in the code describe.
Regarding your comment on the for-loop style: these calls are parallelized internally in libraries like MKL, and parallelizing the batch for-loop on top of that could lead to errors or performance issues.
I understand, but the discrepancy between the two is still surprisingly large. Out of curiosity, I compared against TensorFlow: while it has the same issue (CPU is faster than GPU for many small matrices), its GPU time is much smaller. For instance, on 10k 2x2 matrices TensorFlow takes ~2.5s while PyTorch takes ~17.2s, and the difference only grows with the number of matrices. Just a wild thought: wouldn't it make sense to build block-diagonal matrices from several such small matrices (using this property) and thereby leverage the fact that GPUs are much better at handling large matrices? I would expect that a library such as MAGMA could hide such logic behind something like …
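The property behind the block-diagonal idea is that the spectrum of a block-diagonal matrix is the union of the spectra of its blocks. The sketch below (the function name is hypothetical) checks this with `torch.block_diag` and `torch.linalg.eigh`, and maps each eigenvalue back to its block through the eigenvector, which is supported on a single block when no eigenvalue is shared between blocks (an assumption that holds for generic random inputs). Note that a dense eigensolver on the big matrix costs on the order of (b·n)³, so in this naive form the trick does not by itself deliver the hoped-for speed-up.

```python
import torch


def eigvalsh_via_block_diag(a: torch.Tensor) -> torch.Tensor:
    """Eigenvalues of a batch (b, n, n) of symmetric matrices, computed from a
    single (b*n, b*n) block-diagonal matrix. Returns (b, n), each row ascending.

    Hypothetical helper; assumes no eigenvalue is shared between two blocks.
    """
    b, n, _ = a.shape
    big = torch.block_diag(*a.unbind(0))  # blocks placed along the diagonal
    w, v = torch.linalg.eigh(big)         # w ascending, eigenvectors in columns
    # Squared mass of each eigenvector inside each block: shape (b, b*n).
    mass = v.pow(2).reshape(b, n, b * n).sum(dim=1)
    owner = mass.argmax(dim=0)            # block that owns each eigenvector
    # A stable sort groups eigenvalues by block while keeping them ascending.
    order = torch.sort(owner, stable=True).indices
    return w[order].reshape(b, n)
```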
It is my understanding that MAGMA uses a divide and conquer algorithm for computing the eigenvalues and eigenvectors, whereas cuSolver uses a Jacobi method for the same. To quote the documentation on the cuSolver website,
This could be one reason for the difference you are seeing.
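To illustrate what a Jacobi-type eigensolver does (this is not cuSolver's implementation), here is a sketch of a batched cyclic Jacobi method in PyTorch. Every plane rotation is computed and applied for the whole batch at once, which is one reason Jacobi-style methods are considered a good fit for many small matrices. The function name and the fixed sweep count are assumptions; a real solver would stop on a convergence test.

```python
import math

import torch


def batched_jacobi_eigh(a: torch.Tensor, sweeps: int = 10):
    """Cyclic Jacobi eigendecomposition of a batch of symmetric matrices.

    a: (..., n, n), assumed symmetric. Returns (w, v): eigenvalues w in
    ascending order and eigenvectors as the columns of v.
    """
    n = a.shape[-1]
    a = a.clone()
    eye = torch.eye(n, dtype=a.dtype, device=a.device)
    v = eye.expand_as(a).clone()
    for _ in range(sweeps):
        # One sweep visits every off-diagonal pair (p, q) once.
        for p in range(n - 1):
            for q in range(p + 1, n):
                app, aqq, apq = a[..., p, p], a[..., q, q], a[..., p, q]
                # Angle for which (J^T A J)[p, q] == 0, per batch element.
                theta = 0.5 * torch.atan2(2 * apq, aqq - app)
                # Fold theta into [-pi/4, pi/4] (the "inner" rotation).
                theta = torch.where(theta > math.pi / 4, theta - math.pi / 2, theta)
                theta = torch.where(theta < -math.pi / 4, theta + math.pi / 2, theta)
                c, s = torch.cos(theta), torch.sin(theta)
                # Rotation J in the (p, q) plane, one per batch element.
                j = eye.expand_as(a).clone()
                j[..., p, p] = c
                j[..., q, q] = c
                j[..., p, q] = s
                j[..., q, p] = -s
                a = j.transpose(-1, -2) @ a @ j
                v = v @ j
    w = torch.diagonal(a, dim1=-2, dim2=-1)
    # Sort eigenpairs by ascending eigenvalue; reorder eigenvector columns to match.
    w, order = torch.sort(w, dim=-1)
    v = torch.gather(v, -1, order.unsqueeze(-2).expand_as(v))
    return w, v
```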
I did some more digging and found that if the matrices are too small, MAGMA internally calls LAPACK. To be more precise, the MAGMA kernel magma_ssyevd_gpu calls lapack77f_ssyevd if the matrix size is smaller than or equal to 128 x 128. I suspect that a lot of time is being wasted on D2H and H2D transfers before and after the LAPACK API calls.
I just stumbled upon this older issue that mentions the exact same thing: #10172. I don't fully get the implications of what you said, probably because I'm not familiar with the backend dependencies of each framework. Could you briefly explain? For instance, the author of that issue asks for hooks into cuSolver, while from you I gather that PyTorch uses MAGMA.
Briefly speaking, cuSolver is slower than MAGMA on larger problem sizes, so adding cuSolver hooks won’t be as useful in general. Furthermore, cuSolver doesn’t support matrix sizes larger than 32, which makes it hard to use directly. And yes, MAGMA is used internally in PyTorch for linear algebra routines; we don’t use cuSolver for common linear algebra routines. The reason I mentioned cuSolver is that it is what TensorFlow uses internally.
So what does TF do for matrix sizes greater than 32? I guess that ideally one would have the best of both: use cuSolver for small matrices and MAGMA for larger ones.
One alternative is to offload the computation to the CPU for smaller matrices (matching what MAGMA does internally). I am a bit busy at the moment and can’t get to this immediately.
I don’t know; I haven’t looked at their code in great detail.
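The CPU-offload alternative could look roughly like the following sketch. The helper name and the default cut-off are hypothetical; the 128 default only mirrors the MAGMA note about magma_ssyevd_gpu falling back to LAPACK and would need tuning per machine. It uses `torch.linalg.eigh`, the successor of `torch.symeig`.

```python
import torch

# Hypothetical cut-off, mirroring the 128 x 128 size mentioned for MAGMA's
# LAPACK fallback; tune it for your own hardware.
SMALL_N = 128


def eigh_offload_small(a: torch.Tensor, small_n: int = SMALL_N):
    """torch.linalg.eigh, except that small CUDA matrices are solved on the CPU."""
    if a.is_cuda and a.shape[-1] <= small_n:
        # The whole batch is copied to the host once and the results back once.
        w, v = torch.linalg.eigh(a.cpu())
        return w.to(a.device), v.to(a.device)
    return torch.linalg.eigh(a)
```

Whether this beats the GPU path depends on batch size, matrix size and the CPU, so it is worth timing both before adopting it.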
#22618) Summary: …te argument in macro
Changelog:
- Update note about tensors on CPU for the following MAGMA functions
  - magma_(d/s)getrf_gpu and magma_getrf_nopiv_gpu require tensors on CPU for pivots
  - magma_(d/s)geqrf2_gpu requires tensors on CPU for elementary reflectors
  - magma_(d/s)syevd_gpu requires tensors on CPU for eigenvalues
- Remove dummy tensor in ALLOCATE_ARRAY MACRO
Pull Request resolved: #22618
Test Plan:
- All existing tests should pass to verify that the patch is correct
This PR has been proposed to eliminate confusion due to the previous comments, as indicated in #22573
Differential Revision: D16227440
Pulled By: zou3519
fbshipit-source-id: 97d5537c5da98c0ed3edc4668a09294794fc426b
#22618) Summary: …te argument in macro
Changelog:
- Update note about tensors on CPU for the following MAGMA functions
  - magma_(d/s)getrf_gpu and magma_getrf_nopiv_gpu require tensors on CPU for pivots
  - magma_(d/s)geqrf2_gpu requires tensors on CPU for elementary reflectors
  - magma_(d/s)syevd_gpu requires tensors on CPU for eigenvalues
- Remove dummy tensor in ALLOCATE_ARRAY MACRO
Pull Request resolved: #22618
Test Plan:
- All existing tests should pass to verify that the patch is correct
This PR has been proposed to eliminate confusion due to the previous comments, as indicated in #22573
Differential Revision: D16286198
Pulled By: zou3519
fbshipit-source-id: a5a6ec829084bdb752ca6006b8795227cbaf63b1
FWIW, we've also run into this problem: in our use case, we want to factor many small matrices (e.g. 5000 16x16 matrices), which neither the GPU nor the CPU implementation of PyTorch seems well-suited for. The CPU version does not seem to parallelize well across the batch dimension (no matter the batch size, I only saturate 1.5 cores on my laptop), while the GPU version is painfully slow (easily an order of magnitude slower than the CPU version). We currently plan to work around this by using Numba's guvectorize functionality, which makes parallelizing over the batch dimension easy.
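As a lighter-weight variant of the same idea (plain Python threads instead of Numba's guvectorize), the batch can be split into chunks that are factored concurrently. This sketch assumes that PyTorch releases the GIL while running `torch.linalg.qr`, and the helper name is hypothetical. Combined with PyTorch's own intra-op threads it can oversubscribe the CPU (the concern raised earlier about MKL), so measure it, possibly with a lower `torch.set_num_threads`, before relying on it.

```python
from concurrent.futures import ThreadPoolExecutor

import torch


def qr_over_batch_chunks(a: torch.Tensor, num_workers: int = 4):
    """Batched torch.linalg.qr on CPU, with the batch dimension split across threads.

    a: (b, m, n). Each worker factors one contiguous chunk of the batch.
    """
    chunks = torch.chunk(a, num_workers, dim=0)
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        results = list(pool.map(torch.linalg.qr, chunks))
    # torch.linalg.qr returns a named tuple with fields Q and R.
    q = torch.cat([res.Q for res in results], dim=0)
    r = torch.cat([res.R for res in results], dim=0)
    return q, r
```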
Hi @alanhdu, I'm hitting the same problem. Are you defining a custom …
That's our tentative plan, but we haven't gotten around to implementing it yet. |
Same problem here with 10000 matrices of size 300x20 on … In my case, it seems the minimum size for the GPU to be faster than the CPU is something x 256.
- Although it works and is more stable than Cholesky! There is a bug in PyTorch for small matrices that has not been well analysed. This makes the calculation run entirely on the CPU. pytorch/pytorch#22573
Fixes (temporarily) #1157, until pytorch/pytorch#22573 is addressed
Hi, with apologies for reviving an old thread, I was hoping to check whether this issue is still relevant. Since the table in #47953 indicates that QR is done with cuSolver in all cases, is the above material about MAGMA still relevant? Is batched QR for small matrices still a slow case? Based on a small test, this seems to be the case, but I wanted to see if this particular issue is still relevant. Thanks!
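For checking which backend is responsible, newer PyTorch releases expose an experimental global switch, `torch.backends.cuda.preferred_linalg_library`, that sets the preferred CUDA linear-algebra backend (`"default"`, `"cusolver"` or `"magma"`). The wrapper below is a hypothetical helper that assumes an installed version with this API. It runs `torch.linalg.eigh` under a given preference and restores the previous setting, so the same batch can be timed under both backends. The preference has no effect on CPU tensors.

```python
import torch


def eigh_with_backend(a: torch.Tensor, backend: str):
    """Run torch.linalg.eigh with a temporarily preferred CUDA linalg backend.

    backend: "default", "cusolver" or "magma". The previous preference is
    restored even if the call raises.
    """
    previous = torch.backends.cuda.preferred_linalg_library()
    torch.backends.cuda.preferred_linalg_library(backend)
    try:
        return torch.linalg.eigh(a)
    finally:
        torch.backends.cuda.preferred_linalg_library(previous)
```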
This is a quick-and-dirty replacement using a basic Gram-Schmidt process. I hope it helps people who are still looking for an alternative.

```python
import torch


def torch_qr(a, mode='complete', out=None, gram='classical'):
    """
    Due to a bug in MAGMA, QR on CUDA is very slow for small matrices,
    so this step would otherwise have to be performed on the CPU.
    This function aims to provide temporary relief for using
    `torch.linalg.qr` on GPU by implementing a Gram-Schmidt process.

    Note: This implementation does not support backward propagation, and
    only supports the 'complete' mode.

    See the following regarding this bug:
        https://github.com/pytorch/pytorch/issues/22573
        https://github.com/cornellius-gp/gpytorch/pull/1224

    The input arguments, other than 'gram', follow the PyTorch standard.
    See the following for their definition:
        https://pytorch.org/docs/stable/generated/torch.linalg.qr.html

    Parameters
    ----------
    a: (torch.tensor) the input tensor. Must have a shape of
        `(*mb_dims, dim, dim)`, where `mb_dims` are the batch dimensions.
    mode: (str) Either `'complete'` or `'reduced'`. The current
        implementation only supports the former.
    out: (None or torch.tensor) The output tensor for the `Q` matrix.
        If provided, must have the same shape as `a`.
    gram: (str) The Gram-Schmidt process variant.
        * The `classical` variant makes `O(dim)` calls to CUDA
          and can be more efficient.
        * The `modified` variant can be slightly more accurate,
          but makes `O(dim^2)` CUDA calls and thus is less efficient.
          See Section 14.2 of "Numerical Linear Algebra with Applications"
          by William Ford on the numerical stability of Gram-Schmidt and
          its modified variant:
          https://www.sciencedirect.com/science/article/abs/pii/B9780123944351000144
        * The `cpu` variant uses PyTorch's routine on the CPU.
        This has to be one of `('classical', 'modified', 'cpu')`.

    Output
    ------
    q: (torch.tensor) The output orthonormal matrix.
        This should have a shape of `(*mb_dims, dim, dim)`.
    r: (torch.tensor) The output upper-triangular matrix.
        This should have a shape of `(*mb_dims, dim, dim)`.
    """
    assert not a.requires_grad

    # First solution: performing the QR decomposition on the CPU.
    # Issues:
    #   1. PyTorch may still only utilize one thread in practice,
    #      even though `torch.get_num_threads()` may be large.
    #   2. Reliance on CPU resources.
    if gram == 'cpu':
        # `torch.linalg.qr` expects `out` to be a (Q, R) pair on the CPU,
        # so compute first and copy Q into the caller's tensor afterwards.
        q, r = torch.linalg.qr(a.detach().cpu(), mode=mode)
        q, r = q.to(device=a.device), r.to(device=a.device)
        if out is not None:
            out.copy_(q)
            q = out
        return q, r

    ###############################################################
    ################## Initializing & Identifying #################
    ###############################################################
    assert mode == 'complete', 'reduced is not implemented yet'
    # The batch dimensions
    mb_dims = a.shape[:-2]
    # The input device
    tch_device = a.device
    # The data type for performing the mathematical calculations.
    # Note: Gram-Schmidt is numerically unstable. For this reason, even
    # when the input is float32, we do everything in float64.
    tch_dtype = torch.float64
    # The QR process dimension
    dim = a.shape[-1]
    assert a.shape == (*mb_dims, dim, dim)
    if out is not None:
        assert out.shape == (*mb_dims, dim, dim)
    # Always work in a float64 buffer; `out` (if given) is filled at the end.
    q = torch.empty(*mb_dims, dim, dim, device=tch_device, dtype=tch_dtype)
    # Casting the `a` input to `tch_dtype` and using it from now on
    a_f64 = a.to(dtype=tch_dtype)

    ###############################################################
    ################### Performing Gram-Schmidt ###################
    ###############################################################
    if gram == 'classical':
        # Each new `qk` is projected out of all columns in one batched call,
        # so this loop makes only O(dim) calls. (Numerically, this
        # column-oriented form behaves like modified Gram-Schmidt.)
        # Creating a copy of `a` to avoid messing up the original input
        acp = a_f64.detach().clone()
        assert acp.shape == (*mb_dims, dim, dim)
        for k in range(dim):
            qk_unnorm = acp[..., :, k:k+1]
            assert qk_unnorm.shape == (*mb_dims, dim, 1)
            qk = qk_unnorm / qk_unnorm.norm(dim=-2, keepdim=True)
            assert qk.shape == (*mb_dims, dim, 1)
            a_qkcomps = qk.reshape(*mb_dims, 1, dim).matmul(acp)
            assert a_qkcomps.shape == (*mb_dims, 1, dim)
            # Removing the `qk` components from `a`
            acp -= qk.matmul(a_qkcomps)
            assert acp.shape == (*mb_dims, dim, dim)
            q[..., :, k] = qk.reshape(*mb_dims, dim)
    elif gram == 'modified':
        # Performing the modified Gram-Schmidt process, one column pair at a time.
        for i in range(dim):
            q[..., i] = a_f64[..., i]
            for j in range(i):
                err_ij = torch.einsum('...i,...i->...', q[..., j], q[..., i])
                assert err_ij.shape == (*mb_dims,)
                q[..., i] -= err_ij.reshape(*mb_dims, 1) * q[..., j]
            q[..., i] /= q[..., i].norm(dim=-1, keepdim=True)
    else:
        raise ValueError(f'Unknown gram={gram}')

    r = q.transpose(-1, -2).matmul(a_f64)
    assert r.shape == (*mb_dims, dim, dim)

    ###############################################################
    ######################## Final Cleanup ########################
    ###############################################################
    # Making sure the lower triangle of `r` is exactly zero
    r = torch.triu(r)
    # Casting the `q` and `r` outputs to the `a` input dtype for compatibility
    q_out, r_out = q.to(dtype=a.dtype), r.to(dtype=a.dtype)
    if out is not None:
        # Honour the `out` argument by writing Q into the caller's tensor
        out.copy_(q_out)
        q_out = out
    return q_out, r_out
```

Gram-Schmidt is numerically unstable (i.e., it isn't really super-accurate), but it may get the job done. In a million random runs, I got a maximum absolute error of …

```python
###############################################################
################### Unit-testing `torch_qr` ###################
###############################################################
import torch

n_bch = 1000000
dim = 10
torch.manual_seed(12345)
tch_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
tch_dtype = torch.float32
a = torch.randn(n_bch, dim, dim, device=tch_device, dtype=tch_dtype)
q, r = torch_qr(a, mode='complete', gram='classical')

rtol = 1e-05
atol = 1e-06
# Test 1: Checking if `q` is orthonormal
eye = torch.eye(dim, device=tch_device, dtype=tch_dtype).reshape(1, dim, dim)
assert q.transpose(-1, -2).matmul(q).allclose(eye, rtol=rtol, atol=atol)
# Test 2: Checking if `a == q @ r` holds
assert a.allclose(q.matmul(r), rtol=rtol, atol=atol)
# Test 3: Checking if `r` is upper-triangular
col = torch.arange(dim, device=tch_device, dtype=tch_dtype
                   ).reshape(1, 1, dim).expand(n_bch, 1, dim)
assert col.shape == (n_bch, 1, dim)
row = col.reshape(n_bch, dim, 1)
assert row.shape == (n_bch, dim, 1)
r_lowtriang = r[row > col]
assert r_lowtriang.allclose(torch.zeros_like(r_lowtriang), rtol=rtol, atol=atol)
```
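If the accuracy of Gram-Schmidt is a concern, Householder reflections are the textbook, more stable way to compute a QR decomposition, and they batch just as easily. The sketch below is a swapped-in technique (not the Gram-Schmidt code above and not PyTorch's own implementation); the function name is hypothetical and it has not been tuned for speed. Like `torch_qr`, it makes O(dim) batched calls, one reflection per column for the whole batch.

```python
import torch


def householder_qr(a: torch.Tensor):
    """Batched QR via Householder reflections.

    a: (..., m, n) with m >= n. Returns a complete Q of shape (..., m, m)
    and R of shape (..., m, n) such that a == Q @ R.
    """
    m, n = a.shape[-2:]
    r = a.clone()
    q = torch.eye(m, dtype=a.dtype, device=a.device).expand(*a.shape[:-2], m, m).clone()
    for k in range(min(m - 1, n)):
        x = r[..., k:, k]                      # entries on/below the diagonal, (..., m-k)
        x_norm = x.norm(dim=-1)
        # Pick the sign that avoids cancellation in v[0] = x[0] + sign * ||x||.
        sign = 1 - 2 * (x[..., 0] < 0).to(a.dtype)
        v = x.clone()
        v[..., 0] += sign * x_norm
        v_norm = v.norm(dim=-1, keepdim=True)
        # v is zero only if the column is already zero; H is then the identity.
        v = v / torch.where(v_norm > 0, v_norm, torch.ones_like(v_norm))
        v = v.unsqueeze(-1)                    # (..., m-k, 1)
        # R <- H R on the trailing rows, with H = I - 2 v v^T.
        r[..., k:, :] -= 2 * v @ (v.transpose(-1, -2) @ r[..., k:, :])
        # Q <- Q H on the trailing columns, so a == Q @ R is preserved.
        q[..., :, k:] -= 2 * (q[..., :, k:] @ v) @ v.transpose(-1, -2)
    # Clear round-off below the diagonal.
    return q, torch.triu(r)
```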
🐛 Bug
The recently added batched eigenvalue decomposition via torch.symeig is very slow on GPU (PR: #21858, issue: #7500).
To Reproduce
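A minimal timing sketch for this kind of CPU-vs-GPU comparison could look like the following. It is an assumption, not the reporter's original script, and it uses `torch.linalg.eigh`, which superseded `torch.symeig` in later PyTorch releases. The function name is hypothetical.

```python
import time

import torch


def time_batched_eigh(batch: int, n: int, device: str, repeats: int = 3) -> float:
    """Average seconds per torch.linalg.eigh call on `batch` random symmetric n x n matrices."""
    a = torch.randn(batch, n, n, device=device)
    a = a + a.transpose(-1, -2)  # symmetrize so that eigh's assumption holds
    is_cuda = torch.device(device).type == "cuda"
    torch.linalg.eigh(a)  # warm-up: library initialization, memory allocation
    if is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        torch.linalg.eigh(a)
    if is_cuda:
        # CUDA work may run asynchronously; wait for it before stopping the clock.
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / repeats
```

For example, `time_batched_eigh(10_000, 2, "cpu")` and `time_batched_eigh(10_000, 2, "cuda")` compare the two devices on 10k 2x2 matrices.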
Expected behavior
The GPU variant should be at least as fast as the CPU one. This is an elementary matrix operation and GPUs should be fast at that.
Environment
PyTorch version: 1.2.0.dev20190707
Is debug build: No
CUDA used to build PyTorch: 9.0.176
OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: version 3.5.1
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: GeForce GTX TITAN X
GPU 1: GeForce GTX TITAN X
GPU 2: GeForce GTX TITAN X
GPU 3: GeForce GTX TITAN X
Nvidia driver version: 418.67
cuDNN version: Probably one of the following:
/usr/local/cuda-10.0/targets/x86_64-linux/lib/libcudnn.so.7
/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudnn.so.5
/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudnn.so.7.2.1
Versions of relevant libraries:
[pip3] numpy==1.15.1
[conda] mkl 2019.4 243
[conda] pytorch-nightly 1.2.0.dev20190707 py3.7_cuda9.0.176_cudnn7.5.1_0 pytorch
Additional context
I assume this is not surprising given the following comment (CC @vishwakftw)
pytorch/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
Lines 1208 to 1212 in bcb5fd8
cc @ngimel @vincentqb @vishwakftw @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @VitalyFedyunin