Memory bug for backward on torch.sparse.mm? #41128
I think it might be caused by the dense gradient? I found a related comment here: #12430 (comment). But I think the gradient wrt a sparse tensor should be a sparse tensor as well, right?
So yes, the problem is that the gradient will be dense.
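Concretely, spelling out the rule: for a product $C = AB$ with $A$ sparse of shape $(n, n)$, autograd computes

$$\frac{\partial L}{\partial A} \;=\; \frac{\partial L}{\partial C}\, B^{\top} \;\in\; \mathbb{R}^{n \times n},$$

which is a fully dense $n \times n$ matrix even though only the `nnz` stored entries of $A$ are ever needed.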
Adding to @t-vi's suggestion of scatter ops, the torch-sparse package (which is related to torch-scatter) doesn't seem to have this issue; the OP's problem fits on a Tesla T4 (16 GB VRAM) just fine.
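A minimal sketch of that torch_sparse path (a reconstruction, not the comment's original snippet, with sizes assumed to match the repro further down):

```python
import torch
import torch_sparse  # rusty1s/pytorch_sparse

n, nnz, m = 100_000, 20_000, 200
row = torch.randint(0, n, (nnz,), device='cuda')
col = torch.randint(0, n, (nnz,), device='cuda')
val = torch.randn(nnz, device='cuda', requires_grad=True)
Y = torch.randn(n, m, device='cuda', requires_grad=True)

A = torch_sparse.SparseTensor(row=row, col=col, value=val, sparse_sizes=(n, n))
torch_sparse.matmul(A, Y).sum().backward()
# val.grad has nnz entries; no dense (n, n) gradient is allocated
```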
Thanks for sharing. I used torch-sparse as well, and it perfectly solves the problem; sorry for not updating here. I'd like to know whether there is any plan to support sparse gradients natively in PyTorch? I think it would be an important feature, as many applications and research directions rely on sparse tensors.
cc @pearu
I'm a little curious about this: in the code, why do we need about 45 GiB of GPU memory in the backward pass? Can someone provide a simple calculation?
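A back-of-the-envelope answer, assuming (per the discussion above) that backward materializes the dense gradient of the sparse factor, an n x n float32 matrix. The exact n from the original report is not shown in this thread, but a hypothetical n = 110,000 reproduces the number in the error message:

```python
# Memory needed for a dense n x n float32 gradient
n = 110_000                # hypothetical size; the repro below uses n = 100_000
print(n * n * 4 / 2**30)   # ~45.08 GiB -- matches "Tried to allocate 45.08 GiB"
# With n = 100_000 the same calculation gives ~37.25 GiB
```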
Out of curiosity I checked whether the behaviour would be different when requiring the gradient on the values of the sparse COO matrix (rather than on the sparse matrix itself, so as to clarify that we are not after a dense result as per #41128 (comment)), but unfortunately the same out-of-memory issue arises. I can also confirm that pytorch_sparse doesn't face the memory bottleneck. Also, reading through the documentation of `torch.sparse.mm`, backward is supported for both matrices and the gradient of the sparse argument is described as a sparse tensor. As such, I would expect the backward pass to stay within a sparse memory footprint.

Here (and on Colab) is an expanded version of the first test case in this issue:

```python
import torch
print(torch.__version__)
torchdevice = torch.device('cpu')
if torch.cuda.is_available():
    torchdevice = torch.device('cuda')
    print('Default GPU is ' + torch.cuda.get_device_name(torch.device('cuda')))
print('Running on ' + str(torchdevice))
# Colab shell magic: install torch-scatter / torch-sparse wheels matching the torch build
!pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
import torch_sparse
# Convenience wrapper around torch_sparse matmul
def ts_spmm(A, B):
    Ats = torch_sparse.SparseTensor.from_torch_sparse_coo_tensor(A)
    AB = torch_sparse.matmul(Ats, B)
    return AB
print(f'GPU memory usage: {torch.cuda.memory_allocated(device=torchdevice)/10**9}')
# Dimension of the square sparse matrix
n = 100000
# Number of non-zero elements
nnz = 20000
# Second dimension of the dense matrix
m = 200
rowidx = torch.randint(low=0, high=n, size=(nnz,), device=torchdevice)
colidx = torch.randint(low=0, high=n, size=(nnz,), device=torchdevice)
itemidx = torch.vstack((rowidx,colidx))
xvalues = torch.randn(nnz, device=torchdevice)
Y_dense = torch.randn((n,m), device=torchdevice)
# Require gradients on the values of the sparse matrix (not the matrix itself)
xvalues_0 = xvalues.detach().clone().requires_grad_(True)
X_sparse_0 = torch.sparse_coo_tensor(itemidx, xvalues_0, size=(n,n))
Y_dense_0 = Y_dense.detach().clone().requires_grad_(True)
xvalues_1 = xvalues.detach().clone().requires_grad_(True)
X_sparse_1 = torch.sparse_coo_tensor(itemidx, xvalues_1, size=(n,n))
Y_dense_1 = Y_dense.detach().clone().requires_grad_(True)
print(f'GPU memory usage: {torch.cuda.memory_allocated(device=None)/10**9}')
# torch_sparse path
t0 = ts_spmm(X_sparse_0, Y_dense_0).sum()
t0.backward()
print(f'torch_sparse SPMM, x.grad: {xvalues_0.grad}, y.grad: {Y_dense_0.grad}')
print(f'GPU memory usage: {torch.cuda.memory_allocated(device=None)/10**9}')
# vanilla pytorch path
t1 = torch.sparse.mm(X_sparse_1, Y_dense_1).sum()
t1.backward()
print(f'Vanilla torch SPMM, x.grad: {xvalues_1.grad}, y.grad: {Y_dense_1.grad}')
print(f'GPU memory usage: {torch.cuda.memory_allocated(device=None)/10**9}')
```
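On a 16 GB card the vanilla path above is expected to abort with the OOM error from the original report before reaching the final print. A guarded variant of that path (an addition here, not part of the original comment) lets the comparison finish either way:

```python
# Optional: guard the vanilla torch.sparse.mm path so the script completes on small GPUs
try:
    t1 = torch.sparse.mm(X_sparse_1, Y_dense_1).sum()
    t1.backward()
    print(f'Vanilla torch SPMM, x.grad: {xvalues_1.grad}, y.grad: {Y_dense_1.grad}')
except RuntimeError as err:  # CUDA OOM surfaces as a RuntimeError
    print(f'Vanilla torch SPMM failed: {err}')
```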
Following the discussion in #86963 and #87358 (in particular #87358 (comment)), I understand that the semantics of autograd for sparse tensors treat the sparse matrix as a compressed dense tensor, so a dense gradient is the intended behaviour. The expected sparse semantics may become available at some point with MaskedTensors, as suggested by @lezcano in #87358 (comment), but for now MaskedTensor is labelled as a prototype library only.

In the interim, the need remains for a sparse matrix multiplication with a sparsity-preserving gradient that can provide a memory-efficient gradient with respect to the vector of non-zero values in the sparse matrix. While rusty1s/pytorch_sparse offers a solution for COO matrices, it doesn't support CSR matrices and its interaction with PyTorch can be fiddly.

As of now, the least problematic solution I found is to write a custom sparse @ dense multiplication operation where I manually specify the backward pass. This would certainly benefit from being handled better in PyTorch proper, though. The rough custom op, copied here for convenience, seems to require much less memory than a direct use of `torch.sparse.mm`:

```python
class MySparseMatMul(torch.autograd.Function):

    @staticmethod
    def forward(ctx, a, b):
        # Is the detach needed / helpful?
        ad = a.detach()
        bd = b.detach()
        x = torch.sparse.mm(ad, bd)
        if a.requires_grad or b.requires_grad:
            # Not sure if the following is needed / helpful
            x.requires_grad = True
        # Save context for backward pass
        ctx.save_for_backward(ad, bd)
        return x

    @staticmethod
    def backward(ctx, prev_grad):
        # Recover context
        a, b = ctx.saved_tensors
        # The gradient with respect to the matrix a seen as a dense matrix would
        # lead to a backprop rule as follows:
        #   grad_a = prev_grad @ b.T
        # but we are only interested in the gradient with respect to the
        # (non-zero) values of a. To save memory, instead of computing the full
        # dense matrix prev_grad @ b.T and then subsampling at the nnz locations
        # in a, we directly compute only the required values:
        #   grad_a[i,j] = dotprod(prev_grad[i,:], b[j,:])
        # We start by getting the i and j indices:
        if a.layout == torch.sparse_coo:
            grad_a_idx = a.indices()
            grad_a_row_idx = grad_a_idx[0, :]
            grad_a_col_idx = grad_a_idx[1, :]
        elif a.layout == torch.sparse_csr:
            grad_a_col_idx = a.col_indices()
            grad_a_crow_idx = a.crow_indices()
            # Uncompress the CSR row pointer into one row index per nnz
            grad_a_row_idx = torch.repeat_interleave(
                torch.arange(a.size()[0], device=a.device),
                grad_a_crow_idx[1:] - grad_a_crow_idx[:-1])
        else:
            raise ValueError(f"Unsupported layout: {a.layout}")
        # Get prev_grad[a_row_idx, :]
        prev_grad_select = prev_grad.index_select(0, grad_a_row_idx)
        # Get b[a_col_idx, :]
        b_select = b.index_select(0, grad_a_col_idx)
        # Element-wise multiplication
        prev_grad_b_ewise = prev_grad_select * b_select
        if b.dim() == 1:
            # If b is a vector, the dot product and the element-wise product coincide
            grad_a_vals = prev_grad_b_ewise
        else:
            # If b is a matrix, the dot product requires a summation
            grad_a_vals = torch.sum(prev_grad_b_ewise, dim=1)
        # Create a sparse matrix holding the gradient with respect to the nnz of a
        if a.layout == torch.sparse_coo:
            grad_a = torch.sparse_coo_tensor(grad_a_idx, grad_a_vals, a.shape)
        elif a.layout == torch.sparse_csr:
            grad_a = torch.sparse_csr_tensor(grad_a_crow_idx, grad_a_col_idx,
                                             grad_a_vals, a.shape)
        # Now compute the (dense) gradient with respect to b
        grad_b = torch.t(a) @ prev_grad
        return grad_a, grad_b


my_sparse_mm = MySparseMatMul.apply
```
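A quick smoke test of the op above (a sketch added for illustration; the shapes are arbitrary):

```python
# Minimal usage of the custom op: the gradient w.r.t. the sparse leaf stays sparse
n, nnz, m = 1000, 200, 8
idx = torch.randint(0, n, (2, nnz))
A = torch.sparse_coo_tensor(idx, torch.randn(nnz), (n, n)).coalesce().requires_grad_()
B = torch.randn(n, m, requires_grad=True)

out = my_sparse_mm(A, B).sum()
out.backward()
print(A.grad.layout)  # torch.sparse_coo -- no dense n x n gradient is materialized
print(B.grad.shape)   # torch.Size([1000, 8])
```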
Just confirming that this memory issue in the backward pass of `torch.sparse.mm` can still be reproduced. The workaround suggested above can however now be found here for convenience (with additional support for batched operations):
🐛 Bug
To Reproduce
Steps to reproduce the behavior:
I got an OOM error only in the backward step:

```
RuntimeError: CUDA out of memory. Tried to allocate 45.08 GiB (GPU 0; 10.73 GiB total capacity; 202.00 MiB already allocated; 9.68 GiB free; 204.00 MiB reserved in total by PyTorch)
```

Expected behavior

I think the memory usage of backward should be similar to that of forward. Before backward there is still ~9 GiB free, so I did not expect it to try to allocate 45 GiB. I don't know whether there is a bug in the backward support for torch.sparse.mm; feel free to correct me if I am wrong.
Environment

How you installed PyTorch (conda, pip, source): pip

Additional context
I am not sure whether it is a bug, but if it is and anyone can point out the problem, I am willing to help and dig further to see how to fix it.
cc @ezyang @ssnl @albanD @zou3519 @gqchen @vincentqb @aocsa