
Batched symeig and qr are very slow on GPU #22573

Open

calincru opened this issue Jul 7, 2019 · 16 comments

Labels: module: cuda, module: linear algebra, module: performance, triaged

Comments

calincru commented Jul 7, 2019

🐛 Bug

The recently added batched eigenvalue decomposition via torch.symeig is very slow on GPU (pr: #21858, issue: #7500).

To Reproduce

import torch
a = torch.rand(500, 2, 2)
a = 0.5 * (a + a.transpose(1, 2))
w, _ = torch.symeig(a)  # fast (~0.0006s)
a = a.cuda()
w, _ = torch.symeig(a)  # slow (~0.9s)
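
For reference, a minimal timing sketch (assuming one wants to exclude one-off CUDA/MAGMA initialization; torch.cuda.synchronize() is needed for meaningful GPU timings, and on current PyTorch torch.symeig is replaced by torch.linalg.eigvalsh):

import time
import torch

def time_symeig(a, iters=10):
    torch.symeig(a)                   # warm-up: exclude one-off CUDA/MAGMA init
    if a.is_cuda:
        torch.cuda.synchronize()      # wait for queued kernels before timing
    t0 = time.time()
    for _ in range(iters):
        torch.symeig(a)
    if a.is_cuda:
        torch.cuda.synchronize()
    return (time.time() - t0) / iters

a = torch.rand(500, 2, 2)
a = 0.5 * (a + a.transpose(1, 2))
print('cpu :', time_symeig(a))
print('cuda:', time_symeig(a.cuda()))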

Expected behavior

The GPU variant should be at least as fast as the CPU one. This is an elementary matrix operation, and GPUs should be fast at it.

Environment

PyTorch version: 1.2.0.dev20190707
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: version 3.5.1

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: GeForce GTX TITAN X
GPU 1: GeForce GTX TITAN X
GPU 2: GeForce GTX TITAN X
GPU 3: GeForce GTX TITAN X

Nvidia driver version: 418.67
cuDNN version: Probably one of the following:
/usr/local/cuda-10.0/targets/x86_64-linux/lib/libcudnn.so.7
/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudnn.so.5
/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudnn.so.7.2.1

Versions of relevant libraries:
[pip3] numpy==1.15.1
[conda] mkl 2019.4 243
[conda] pytorch-nightly 1.2.0.dev20190707 py3.7_cuda9.0.176_cudnn7.5.1_0 pytorch

Additional context

I assume this is not surprising given the following comment (CC @vishwakftw)

// We create temporary tensors on the CPU, because tensors on the GPU
// cause segfault when passed to magmaSymeig. The data is later
// moved to the appropriate device.
// In the case where self.numel() == 0, we just return an empty tensor of
// dimensions on the CUDA (to avoid the unnecessary "to(at::kCUDA)")
but it should nonetheless be fixed. It's not clear to me whether that implies a bug in MAGMA and, if so, whether anything is being done about it.

cc @ngimel @vincentqb @vishwakftw @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @VitalyFedyunin

calincru (Author) commented Jul 7, 2019

It seems that the same thing happens for QR decompositions,

Tensor tau = at::empty({k}, Q.options().device(at::kCPU));

so I assume it's related to the fact that batching for these functions is implemented in the for-loop style (#20689 (comment) is relevant). But that makes these functions really unusable on GPU.
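
For context, a rough sketch of what "for-loop style" batching amounts to, as I understand it (illustrative only, not the actual ATen code; torch.qr matches the API of this era, today it is torch.linalg.qr):

def batched_qr_forloop(a):
    # a: (batch, m, n); one single-matrix MAGMA/LAPACK call per iteration,
    # so on GPU each matrix also pays per-call launch/transfer overhead.
    qs, rs = [], []
    for i in range(a.shape[0]):
        q_i, r_i = torch.qr(a[i])
        qs.append(q_i)
        rs.append(r_i)
    return torch.stack(qs), torch.stack(rs)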

@calincru changed the title from "Batched symeig is very slow on GPU" to "Batched symeig and qr are very slow on GPU" Jul 7, 2019
vishwakftw (Contributor) commented

Hi @calincru. Thank you for opening the issue.

As you have pointed out in your comment, the QR decomposition is slow for smaller matrices, and a similar statement can be made for MAGMA's symmetric eigendecomposition. However, MAGMA is comparatively faster on larger matrices, as elucidated in this comment.

The comments above are slightly outdated, and I will fix them as soon as possible to prevent confusion. MAGMA uses hybrid CPU-GPU routines for many of its functions, and some arguments passed to these functions are required to reside on the CPU. Not placing those arguments on the CPU causes a segmentation fault, which is what the comments in the code refer to.

@vishwakftw added the module: operators and module: performance labels Jul 7, 2019
vishwakftw (Contributor) commented

Regarding your comment on the for-loop style: these per-matrix calls are parallelized internally by libraries like MKL, and additionally parallelizing the batch for-loop could lead to errors or performance issues.
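
To illustrate the concern (a sketch of my own, not PyTorch internals): MKL/LAPACK already parallelizes each factorization, so an outer thread pool generally has to reduce the intra-op thread count first, and even then the benefit is workload-dependent.

import torch
from concurrent.futures import ThreadPoolExecutor

def qr_batch_threaded(a, workers=4):
    # Limit intra-op (MKL) threads so the outer pool does not oversubscribe cores.
    torch.set_num_threads(1)
    with ThreadPoolExecutor(workers) as pool:
        results = list(pool.map(torch.qr, a))   # a: (batch, m, n) CPU tensor
    q = torch.stack([res[0] for res in results])
    r = torch.stack([res[1] for res in results])
    return q, r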

calincru (Author) commented Jul 7, 2019

I understand, but the discrepancy between the two is still surprisingly large. Out of curiosity I compared with TensorFlow, and while it has the same issue (CPU is faster than GPU for many small matrices), its GPU times are much smaller: on 10k 2x2 matrices it takes ~2.5s while PyTorch takes ~17.2s. The difference only grows with the number of matrices.

Just a wild thought: wouldn't it make sense to assemble several such small matrices into a block-diagonal matrix (using this property) and thereby leverage the fact that GPUs are much better at dealing with large matrices? I would expect that a library such as MAGMA could hide such logic behind something like magma_ssyevd_batched (which obviously doesn't currently exist). But this question should perhaps be addressed to the MAGMA developers.
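
For what it's worth, a small sketch of the block-diagonal property referred to above (torch.block_diag and torch.linalg.eigvalsh are from newer PyTorch releases than the one in this report; this only checks the property and says nothing about speed, and recovering per-matrix eigenvectors would need extra bookkeeping):

import torch

batch, n = 4, 2
a = torch.rand(batch, n, n)
a = 0.5 * (a + a.transpose(1, 2))

big = torch.block_diag(*a)            # (batch*n, batch*n) block-diagonal matrix
w_big = torch.linalg.eigvalsh(big)    # one large eigendecomposition
w_small = torch.linalg.eigvalsh(a)    # batched small eigendecompositions

# Same multiset of eigenvalues, just ordered differently.
assert torch.allclose(w_big.sort().values,
                      w_small.flatten().sort().values, atol=1e-6)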

@zou3519 added the triaged label Jul 8, 2019
vishwakftw (Contributor) commented

It is my understanding that MAGMA uses a divide and conquer algorithm for computing the eigenvalues and eigenvectors, whereas cuSolver uses a Jacobi method for the same. To quote the documentation on the cuSolver website,

The parallelism of Jacobi method gives better GPU performance on small and medium size matrices.

This could be one reason why you are seeing the difference.

vishwakftw (Contributor) commented Jul 9, 2019

I did some more digging and found that if the matrices are small enough, MAGMA internally calls LAPACK. To be more precise, the MAGMA kernel magma_ssyevd_gpu calls lapackf77_ssyevd if the matrix size is smaller than or equal to 128 x 128. I suspect that a lot of time is being wasted on D2H and H2D transfers before and after the LAPACK calls.

calincru (Author) commented Jul 14, 2019

It is my understanding that MAGMA uses a divide and conquer algorithm for computing the eigenvalues and eigenvectors, whereas cuSolver uses a Jacobi method for the same.

I just stumbled upon this older issue that mentions the exact same thing: #10172.

I do not fully get the implications of what you said, probably because I'm not familiar with the backend dependencies of each framework. Could you briefly explain? For instance, the author of the above issue asks for hooks into cuSolver, while from your comments I gather that PyTorch uses MAGMA.

vishwakftw (Contributor) commented

Briefly speaking, cuSolver is rather slower than MAGMA on larger problem sizes, and hence adding cuSolver hooks won't be as useful in general.

Furthermore, cuSolver's batched routines don't support matrix sizes larger than 32, which makes them hard to use directly.

And yes, MAGMA is used internally in PyTorch for linear algebra routines; we don't use cuSolver for the common linear algebra routines.

The reason I mentioned cuSolver is that it is what TensorFlow uses internally.

calincru (Author) commented

So what does TF do for matrix sizes greater than 32?

I guess that ideally one would have the best of both: use cuSolver for small matrices and MAGMA for larger ones.
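
(For readers arriving much later: recent PyTorch releases expose a knob for choosing the CUDA linear-algebra backend, which gets at this "best of both" idea; a hedged example, to the best of my knowledge of that API:)

import torch

# Prefer cuSolver (alternatives: "magma", "default") for CUDA linalg routines.
torch.backends.cuda.preferred_linalg_library("cusolver")

a = torch.randn(5000, 16, 16, device="cuda")
q, r = torch.linalg.qr(a)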

vishwakftw (Contributor) commented Jul 15, 2019

One alternative is to offload the computation to the CPU for smaller matrices (matching what MAGMA does internally anyway). I am a bit busy at the moment and can't get to this immediately.
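
A minimal sketch of what such a dispatch could look like (a hypothetical helper, not PyTorch's implementation; the 128 threshold mirrors the MAGMA fallback mentioned above, and on current PyTorch torch.symeig would be torch.linalg.eigh):

def symeig_dispatch(a, eigenvectors=False, small_side=128):
    # Small CUDA matrices: compute on the CPU and copy the results back,
    # since MAGMA would round-trip through LAPACK anyway.
    if a.is_cuda and a.shape[-1] <= small_side:
        w, v = torch.symeig(a.cpu(), eigenvectors=eigenvectors)
        return w.to(a.device), v.to(a.device)
    return torch.symeig(a, eigenvectors=eigenvectors)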

So what does TF do for matrix sizes greater than 32?

I don’t know, I haven’t looked at their code in great detail.

facebook-github-bot pushed a commit that referenced this issue Jul 16, 2019
#22618)

Summary:
…te argument in macro

Changelog:
- Update note about tensors on CPU for the following MAGMA functions
  - magma_(d/s)getrf_gpu and magma_getrf_nopiv_gpu require tensors on CPU for pivots
  - magma_(d/s)geqrf2_gpu requires tensors on CPU for elementary reflectors
  - magma_(d/s)syevd_gpu requires tensors on CPU for eigenvalues
- Remove dummy tensor in ALLOCATE_ARRAY MACRO
Pull Request resolved: #22618

Test Plan:
- All existing tests should pass to verify that the patch is correct

This PR has been proposed to eliminate confusion due to the previous comments, as indicated in #22573

Differential Revision: D16227440

Pulled By: zou3519

fbshipit-source-id: 97d5537c5da98c0ed3edc4668a09294794fc426b
zdevito pushed a commit to zdevito/ATen that referenced this issue Jul 16, 2019 (same change as above, pytorch/pytorch#22618)
facebook-github-bot pushed a commit that referenced this issue Jul 17, 2019 (same change as above, #22618; Differential Revision: D16286198; fbshipit-source-id: a5a6ec829084bdb752ca6006b8795227cbaf63b1)
zdevito pushed a commit to zdevito/ATen that referenced this issue Jul 17, 2019 (same change as above, pytorch/pytorch#22618)
alanhdu (Contributor) commented Mar 11, 2020

FWIW, we've also run into this problem: in our use case, we want to factor many small matrices (e.g. 5000 16x16 matrices), which neither the GPU nor the CPU implementation of PyTorch seems well-suited for: the CPU version does not seem to parallelize across the batch dimension well (no matter the batch size, I only saturate 1.5 cores on my laptop), while the GPU version is painfully slow (easily an order of magnitude slower than the CPU version).

We currently plan to work around this by using Numba's guvectorize functionality, which makes parallelizing over the batch dimension easy.
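
In case it helps others, a minimal sketch of that kind of workaround (my own example, not our production code; shown for the symmetric-eigenvalue case, and the same pattern works with np.linalg.qr or np.linalg.cholesky, which Numba also supports in nopython mode):

import numpy as np
from numba import guvectorize, float64

@guvectorize([(float64[:, :], float64[:])], '(n,n)->(n)', target='parallel')
def batched_eigvalsh(a, w):
    # One LAPACK-backed call per matrix, parallelized over the batch by Numba.
    w[:] = np.linalg.eigvalsh(np.ascontiguousarray(a))

mats = np.random.randn(5000, 16, 16)
mats = 0.5 * (mats + mats.transpose(0, 2, 1))   # symmetrize
evals = batched_eigvalsh(mats)                  # shape (5000, 16)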

lucmos commented Mar 12, 2020

Hi @alanhdu,

I'm hitting the same problem. Are you defining a custom backward pass for the QR decomposition?

alanhdu (Contributor) commented Mar 12, 2020

That's our tentative plan, but we haven't gotten around to implementing it yet.

eusoubrasileiro commented Apr 4, 2020

Same problem here with 10,000 matrices of size 300x20 on torch.qr.

In my case it seems the minimum size at which the GPU becomes faster than the CPU is roughly something x 256. Even then, GPU utilization averages about 7% while the CPU sits at 80% (GTX 1060).

eusoubrasileiro added a commit to eusoubrasileiro/stocks that referenced this issue Apr 4, 2020
- although it works and is more stable than Cholesky! There is a bug in PyTorch for small matrices that is not well analyzed, which means the calculation is done entirely on the CPU. pytorch/pytorch#22573
Balandat added a commit to cornellius-gp/gpytorch that referenced this issue Jul 23, 2020
@mruberry added the module: cuda and module: linear algebra labels and removed the module: operators (deprecated) label Oct 10, 2020
cwindolf commented Apr 12, 2023

Hi, with apologies for reviving an old thread, I wanted to check whether this issue still applies. Since the table in #47953 indicates that QR is now done with cuSolver in all cases, is the material above about MAGMA still relevant? Is batched QR on small matrices still a slow case? A small test suggests it is, but I wanted to confirm. Thanks!

ehsansaleh commented May 23, 2024

This is a quick-and-dirty replacement using a basic Gram-Schmidt process.

Hope this helps people who are still looking for an alternative.

def torch_qr(a, mode='complete', out=None, gram='classical'):
    """
    Due to a bug in MAGMA, qr on cuda is super slow for small matrices. 
    Therefore, this step must be performed on the cpu.
    
    This function aims to provide a temporary relief for using 
    `torch.linalg.qr` on GPU by implementing a Gram-Schmidt process. 
    
    Note: This implementation does not support backward propagation, and 
          only supports the 'complete' mode.
    
    See the following regarding this Bug:
        https://github.com/pytorch/pytorch/issues/22573
        https://github.com/cornellius-gp/gpytorch/pull/1224
        
    The input arguments, other than 'gram', follow the PyTorch standard. 
    See the following for their definition:
        https://pytorch.org/docs/stable/generated/torch.linalg.qr.html
        
    Parameters
    ----------
    a: (torch.tensor) the input tensor. Must have a shape of 
        `(*mb_dims, dim, dim)`, where `mb_dims` shows the batch 
        dimensions.

    mode: (str) Either `'complete'` or `'reduced'`. This current 
        implementation only supports the former.
        
    out: (None or torch.tensor) The output tensor for the `Q` matrix. 
        If provided, must have the same shape as `a`.
        
    gram: (str) The Gram-Schmidt process variant. 
    
        * The `classical` variant makes `O(dim)` calls to CUDA 
          and can be more efficient. 
          
        * The `modified` variant can be slightly more accurate, 
          but makes CUDA `O(dim^2)` calls and thus is less efficient.
          
          See Section 14.2 of "Numerical Linear Algebra with Applications" 
          by William Ford on the numerical stability of Gram-Schmidt and 
          its modified variant:
          
          https://www.sciencedirect.com/science/article/abs/pii/B9780123944351000144
          
        * The `cpu` variant uses Pytorch's routine on CPU.
          
        This has to be one of `('classical', 'modified', 'cpu')`.
        
    Output
    ------
    q: (torch.tensor) The output orthonormal matrix. 
        This should have a shape of `(*mb_dims, dim, dim)`.
    
    r: (torch.tensor) The output upper triangular matrix.
        This should have a shape of `(*mb_dims, dim, dim)`.
    """

    assert not a.requires_grad

    # First Solution: Performing the QR decomposition on CPU
    # Issues: 
    #    1. Pytorch may still only utilize one thread 
    #       practically even though `torch.get_num_threads()` 
    #       may be large.
    #    2. Reliance on CPU resources.
    if gram == 'cpu':
        q, r = torch.linalg.qr(a.detach().cpu(), mode=mode, out=out)
        return q.to(device=a.device), r.to(device=a.device)
    
    ###############################################################
    ################## Initializing & Identifying #################
    ###############################################################
    assert mode == 'complete', 'reduced is not implemented yet'
    # The batch dimensions
    mb_dims = a.shape[:-2]
    # The input device
    tch_device = a.device
    
    # The data type for performing the mathematical calculations
    # Note: Gram-Schmidt is numerically unstable. For this reason, even
    # when the input is float32, we do everything in float64.
    tch_dtype = torch.float64
    
    # The QR process dimension
    dim = a.shape[-1]
    assert a.shape == (*mb_dims, dim, dim)

    if out is None:
        q = torch.empty(*mb_dims, dim, dim, device=tch_device, dtype=tch_dtype)
    else:
        q = out
    assert q.shape == (*mb_dims, dim, dim)
    
    # Casting the `a` input to `tch_dtype` and using it from now on
    a_f64 = a.to(dtype=tch_dtype)
    
    ###############################################################
    ################### Performing Gram-Schmidt ###################
    ###############################################################
    if gram == 'classical':
        # Performing the classical Gram-Schmidt Process.
        
        # Creating a copy of `a` to avoid messing up the original input
        acp = a_f64.detach().clone()
        assert acp.shape == (*mb_dims, dim, dim)
        
        for k in range(dim):
            qk_unnorm = acp[..., :, k:k+1]
            assert qk_unnorm.shape == (*mb_dims, dim, 1)

            qk = qk_unnorm / qk_unnorm.norm(dim=-2, keepdim=True)
            assert qk.shape == (*mb_dims, dim, 1)

            a_qkcomps = qk.reshape(*mb_dims, 1, dim).matmul(acp)
            assert a_qkcomps.shape == (*mb_dims, 1, dim)

            # Removing the `qk` components from `a`
            acp -= qk.matmul(a_qkcomps)
            assert acp.shape == (*mb_dims, dim, dim)

            q[..., :, k] = qk.reshape(*mb_dims, dim)
    elif gram == 'modified':
        # Performing the modified Gram-Schmidt Process.
        for i in range(dim):
            q[..., i] = a_f64[..., i]
            for j in range(i):
                err_ij = torch.einsum('...i,...i->...', q[..., j], q[..., i])
                assert err_ij.shape == (*mb_dims,)
                q[..., i] -=  err_ij.reshape(*mb_dims, 1) * q[..., j]
            q[..., i] /= q[..., i].norm(dim=-1, keepdim=True)
    else:
        raise ValueError(f'Unknown gram={gram}')

    r = q.transpose(-1, -2).matmul(a_f64)
    assert r.shape == (*mb_dims, dim, dim)

    ###############################################################
    ######################## Final Cleanup ########################
    ###############################################################
    # Making sure the lower triangle of `r` is absolutely zero!
    col = torch.arange(dim, device=tch_device, dtype=tch_dtype).reshape(1, dim)
    assert col.shape == (1, dim)

    row = col.reshape(dim, 1)
    assert row.shape == (dim, 1)
    
    mb_ones = [1] * len(mb_dims)
    r *= (row <= col).reshape(*mb_ones, dim, dim)
    
    # Casting the `q` and `r` outputs to the `a` input dtype for compatibility
    q_out, r_out = q.to(dtype=a.dtype), r.to(dtype=a.dtype)
    
    return q_out, r_out

Gram-Schmidt is numerically unstable (i.e., it isn't super accurate), but it may get the job done. Over a million random runs, I got a maximum absolute error on the order of 1e-7. Here are the precision tests:

###############################################################
################### Unit-testing `torch_qr` ###################
###############################################################
n_bch = 1000000
dim = 10

torch.manual_seed(12345)
tch_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
tch_dtype = torch.float32
a = torch.randn(n_bch, dim, dim, device=tch_device, dtype=tch_dtype)

q, r = torch_qr(a, mode='complete', gram='classical')

rtol = 1e-05
atol = 1e-06

# Test 1: Checking if `q` is orthonormal
eye = torch.eye(dim, device=tch_device, dtype=tch_dtype).reshape(1, dim, dim)
assert q.transpose(-1, -2).matmul(q).allclose(eye, rtol=rtol, atol=atol)

# Test 2: Checking if `a == q @ r` holds
assert a.allclose(q.matmul(r), rtol=rtol, atol=atol)

# Test 3: Checking if `r` is upper-triangle
col = torch.arange(dim, device=tch_device, dtype=tch_dtype
    ).reshape(1, 1, dim).expand(n_bch, 1, dim)
assert col.shape == (n_bch, 1, dim)
row = col.reshape(n_bch, dim, 1)
assert row.shape == (n_bch, dim, 1)
r_lowtriang = r[row > col]
assert r_lowtriang.allclose(torch.zeros_like(r_lowtriang), rtol=rtol, atol=atol)
