
Batched matrix multiplication copying the input data (CUDA) #52111

Closed
o-alexandre-felipe opened this issue Feb 10, 2021 · 1 comment
Labels
module: cuda (Related to torch.cuda, and CUDA support in general)
module: linear algebra (Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul)
module: memory usage (PyTorch is using more memory than it should, or it is leaking memory)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments


o-alexandre-felipe commented Feb 10, 2021

🐛 Bug

Batched matrix multiplication makes a copy of the input data under certain conditions. The bug was discovered while trying to compute inner products over strided matrices.

To Reproduce

Steps to reproduce the behavior:

The following example works on the CPU without copying the input. Notice that one of the matrices is a view whose materialized size would not fit in memory.

import torch

B, N, L = 1000, 100, 100000
# materialized, lhs would require ~40 GiB (which I don't have)
lhs = torch.arange(0, B + N + L, device='cpu', dtype=torch.float32).as_strided([B, N, L], [1, 1, 1])
rhs = torch.arange(0, B + L + 1, device='cpu', dtype=torch.float32).as_strided([B, L, 1], [1, 1, 1])
x = torch.bmm(lhs, rhs)

This is essentially the same code, but running on the GPU; it tries to allocate a huge amount of memory that corresponds to the materialized size of lhs.

B, N, L = 1000, 100, 100000
lhs = torch.arange(0, B + N + L, device='cuda', dtype=torch.float32).as_strided([B, N, L], [1, 1, 1])
rhs = torch.arange(0, B + L + 1, device='cuda', dtype=torch.float32).as_strided([B, L, 1], [1, 1, 1])
x = torch.bmm(lhs, rhs)  # here it tries to allocate 37.25 GiB
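
For reference, the requested allocation is exactly what a contiguous copy of lhs would occupy; a quick back-of-the-envelope check (my addition, not part of the original report):

B, N, L = 1000, 100, 100000
bytes_needed = B * N * L * 4     # float32 takes 4 bytes per element
print(bytes_needed / 2**30)      # ~37.25, i.e. the 37.25 GiB figure above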

Expected behavior

cuBLAS 8.0 supports strided batched matrix multiplication (https://developer.nvidia.com/blog/cublas-strided-batched-matrix-multiply/), which can be used to do the trick. To confirm this claim I implemented the reference algorithm:

def gemmStridedBatched(handle,
                       transA, transB,
                       M, N, K,
                       alpha,
                       A, ldA, strideA,
                       B, ldB, strideB,
                       beta,
                       C, ldC, strideC,
                       batchCount):
    # Reference semantics: for each batch p, C[p] = alpha * A[p] @ B[p] + beta * C[p]
    for p in range(batchCount):
        for m in range(M):
            for n in range(N):
                c_mnp = sum(A[m + k*ldA + p*strideA] * B[k + n*ldB + p*strideB]
                            for k in range(K))
                C[m + n*ldC + p*strideC] = alpha*c_mnp + beta*C[m + n*ldC + p*strideC]
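
In equation form (the transpose flags are unused in this reference), each batch element p performs the standard GEMM update:

$$C^{(p)}_{m,n} \leftarrow \alpha \sum_{k=0}^{K-1} A^{(p)}_{m,k}\, B^{(p)}_{k,n} + \beta\, C^{(p)}_{m,n}, \qquad p = 0, \dots, \text{batchCount}-1$$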

... and a function that uses it to compute the bmm output:

def bmm(A, B):
    strideA, ldA, a1 = A.stride()
    strideB, ldB, b1 = B.stride()
    assert a1 == 1
    assert b1 == 1

    bA, mA, nA = A.shape
    bB, mB, nB = B.shape
    assert bA == bB or bA == 1 or bB == 1  # batch broadcastable
    assert nA == mB  # matrix multiplication constraint

    # fixes the stride for 1-element batches
    if bA == 1:
        strideA = 0
    if bB == 1:
        strideB = 0

    batchCount = max(bA, bB)

    C = torch.empty((batchCount, mA, nB), dtype=torch.float32)
    strideC, ldC, c1 = C.stride()
    assert c1 == 1

    # view the underlying data as a 1-d array
    A_1d = A.as_strided((strideA * bA + ldA * mA + nA,), (1,))
    B_1d = B.as_strided((strideB * bB + ldB * mB + nB,), (1,))

    gemmStridedBatched(handle=None, transA=False, transB=False,  # not used here
                       M=mA, N=nB, K=mB, alpha=1.0,
                       # input matrices
                       A=A_1d, ldA=ldA, strideA=strideA,
                       B=B_1d, ldB=ldB, strideB=strideB,
                       beta=0.0,
                       # output
                       C=C.view(-1), ldC=ldC, strideC=strideC,
                       batchCount=batchCount)
    return C

Tested against the current implementation with:

B, N, L = 10, 15, 100
lhs = torch.arange(0, B + N + L, device='cpu', dtype=torch.float32).as_strided([B, N, L], [1, 1, 1])
rhs = torch.arange(0, B + L + 1, device='cpu', dtype=torch.float32).as_strided([B, L, 1], [1, 1, 1])
x = torch.bmm(lhs, rhs)
xnew = bmm(lhs, rhs)
assert torch.allclose(x, xnew)
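
As an extra check (my addition, not in the original test), the 1-element-batch broadcast path can be exercised as well; torch.bmm does not broadcast batch dimensions, so the comparison is against torch.matmul:

# Illustrative extra test: bA == 1 exercises the strideA = 0 branch.
B, N, L = 10, 15, 100
lhs = torch.arange(0, 1 + N + L, device='cpu', dtype=torch.float32).as_strided([1, N, L], [1, 1, 1])
rhs = torch.arange(0, B + L + 1, device='cpu', dtype=torch.float32).as_strided([B, L, 1], [1, 1, 1])
assert torch.allclose(torch.matmul(lhs, rhs), bmm(lhs, rhs))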

Environment

Collecting environment information...
PyTorch version: 1.7.1
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Pro
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce GTX 1660 Ti with Max-Q Design
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] torch==1.7.1
[pip3] torchaudio==0.7.2
[pip3] torchvision==0.8.2
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               10.2.89              h74a9793_1  
[conda] mkl                       2020.2                      256  
[conda] mkl-service               2.3.0            py38h196d8e1_0  
[conda] mkl_fft                   1.2.0            py38h45dec08_0  
[conda] mkl_random                1.1.1            py38h47e9c7a_0  
[conda] numpy                     1.19.2           py38hadc3359_0  
[conda] numpy-base                1.19.2           py38ha3acd2a_0  
[conda] pytorch                   1.7.1           py3.8_cuda102_cudnn7_0    pytorch
[conda] torch                     1.7.1                    pypi_0    pypi
[conda] torchaudio                0.7.2                      py38    pytorch
[conda] torchvision               0.8.2                py38_cu102    pytorch

Additional context

I wanted to compute a batched inner product (a short slice of a convolution).

cc @ngimel @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @Lezcano
@albanD added the enhancement (Not as big of a feature, but technically not a bug. Should be easy to fix), module: cuda (Related to torch.cuda, and CUDA support in general), module: memory usage (PyTorch is using more memory than it should, or it is leaking memory), and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels on Feb 11, 2021
ngimel (Collaborator) commented Feb 13, 2021

While in some cases bmm does make memory copies that can be avoided, this is not one of them. The BLAS libraries used to compute matrix multiplications cannot accept lda/ldb parameters smaller than the corresponding matrix dimensions.
Also, the same copy happens on the CPU.
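
To illustrate the point (a sketch of my own, not from the thread): GEMM describes each matrix by a single leading dimension, and for a non-transposed column-major A cuBLAS requires ldA >= max(1, M). The view in this issue keeps consecutive rows only one element apart, so no legal leading dimension can describe it and a contiguous copy is unavoidable:

import torch

B, N, L = 1000, 100, 100000
lhs = torch.arange(0, B + N + L, dtype=torch.float32).as_strided([B, N, L], [1, 1, 1])

# The batch, row and column strides are all 1: the leading dimension GEMM
# would need is 1, far smaller than the matrix extent it must step over.
print(lhs.stride())  # (1, 1, 1)

# Producing a GEMM-compatible layout therefore means a contiguous copy:
# lhs.contiguous() would allocate B * N * L * 4 bytes (~37.25 GiB)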

@ngimel removed the enhancement (Not as big of a feature, but technically not a bug. Should be easy to fix) label on Feb 13, 2021
@ngimel added the module: linear algebra (Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul) label on Oct 1, 2021
@ngimel closed this as completed on Oct 1, 2021