Batched matrix multiplication copying the input data (CUDA) #52111
Labels
module: cuda
Related to torch.cuda, and CUDA support in general
module: linear algebra
Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul
module: memory usage
PyTorch is using more memory than it should, or it is leaking memory
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
🐛 Bug
Batched matrix multiplication makes a copy of the input data under certain conditions. The bug was discovered while trying to compute inner products over strided matrices.
To Reproduce
Steps to reproduce the behavior:
The following example works on the CPU without copying the input. Notice that one of the matrices is a view that would not fit in memory if it were materialized.
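The original snippet was not captured here; a minimal sketch of the kind of CPU repro the report describes might look like this (the shapes and batch size are assumptions, not the original script):

```python
import torch

n = 100_000                                  # large batch count
lhs = torch.randn(64, 64).expand(n, 64, 64)  # view with batch stride 0: no copy
rhs = torch.randn(n, 64, 1)

# On the CPU this succeeds without materializing lhs
# (~1.6 GB would be needed if it were copied).
out = torch.bmm(lhs, rhs)
print(out.shape)      # torch.Size([100000, 64, 1])
print(lhs.stride(0))  # 0 -- every batch index shares one 64x64 block of storage
```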
This is essentially the same code but running on the GPU, and it tries to allocate a huge amount of memory matching the size of lhs.
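The GPU variant might look like the following sketch (sizes are again assumptions); on affected versions the expanded lhs is materialized before the cuBLAS call, triggering the huge allocation:

```python
import torch

n = 1_000_000
# If lhs were materialized, the copy alone would need n * 64 * 64 * 4 bytes.
expected_copy_bytes = n * 64 * 64 * 4  # ~16 GB

if torch.cuda.is_available():
    lhs = torch.randn(64, 64, device="cuda").expand(n, 64, 64)
    rhs = torch.randn(n, 64, 1, device="cuda")
    # On affected versions this raises a CUDA out-of-memory error
    # because the stride-0 view is copied into a contiguous buffer.
    out = torch.bmm(lhs, rhs)
```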
Expected behavior
I see that cuBLAS 8.0 supports strided batched matrix multiply (https://developer.nvidia.com/blog/cublas-strided-batched-matrix-multiply/), which can be used to do the trick. To confirm this claim I implemented the reference algorithm
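The reference computation described in the cuBLAS post can be sketched in PyTorch as a plain loop (the function name and signature here are mine, not from the issue); each operand advances by its own batch stride, and an expanded view gives the stride-0 case for free:

```python
import torch

def gemm_strided_batched_ref(A, B, C, alpha=1.0, beta=0.0):
    """Reference semantics of cublas<T>gemmStridedBatched:
    C[p] = alpha * A[p] @ B[p] + beta * C[p] for each batch index p.
    Indexing an expanded tensor yields a view, so a batch stride of 0
    (one shared matrix) needs no copy."""
    for p in range(C.shape[0]):
        C[p] = alpha * (A[p] @ B[p]) + beta * C[p]
    return C
```

With alpha=1 and beta=0 this reduces to exactly what torch.bmm computes.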
... and a function that uses the provided computation to produce the bmm output, tested against the current implementation with
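The comparison against the current implementation was not captured here; it presumably looked something like the following (shapes and tolerance are assumptions):

```python
import torch

A = torch.randn(1, 8, 8).expand(16, 8, 8)  # batch-stride-0 view
B = torch.randn(16, 8, 4)

# Loop-based reference output, one matrix product per batch index.
ref = torch.stack([A[p] @ B[p] for p in range(A.shape[0])])
out = torch.bmm(A, B)

assert torch.allclose(out, ref, atol=1e-6)
```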
Environment
You can get the script and run it with: