Implementation of torch.sparse.sampled_baddmm #105319
Comments
@candyflower2005, unfortunately we do not have that operation, but if you have CUDA and Triton, you can convert your CSR inputs to BSR with block sizes that are powers of 2 and greater than 16, and then use the Triton-based BSR kernels.
This op's batch mode would also be very beneficial, as it enables custom attention patterns (and I assume for GNNs too, cc @fzhao3): https://discuss.pytorch.org/t/incomplete-sparse-gemm-for-some-sort-of-local-attention-with-only-given-indices/187428/2
@vadimkantorov will take care of it. So the inputs for this operator are as below?
And the output has the same shape and sparsity pattern as the input?
@mingfeima, yes, but be aware of the linked issues above, i.e. there is no
@mingfeima for my usecase - yes! The alternative formulation is at https://discuss.pytorch.org/t/incomplete-sparse-gemm-for-some-sort-of-local-attention-with-only-given-indices/187428/2 :

In this original formulation the number of non-zero entries per row is fixed (so every "vertex" has the same number of "outbound edges", i.e. the number of non-zeros per output_mask[b, m] row is constant), but I guess the sparse formulation covers this as well. Ideally, I'd also need this method for CUDA.

Also, in my usecase, the "number of outbound edges" is always fixed (typical for k-nn graphs, which are regular), so the "output mask" and the "output" can simply be dense tensors (and the "output mask" can simply contain the non-zero column indices). Is this equivalent to the CSR format? I don't know if it's worth a separate function for my original format, but this sort of regular neighborhood size can often be accommodated when experimenting with attention patterns, and it might be simpler than dealing with sparse tensors from the UX standpoint (and if some "vertex" really needs fewer than the fixed number of "outbound edges", the remaining edges can just be some sort of padding towards some sort of sink node, or e.g. some special value).
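The regular, fixed-out-degree format described above maps straightforwardly onto CSR: when every row stores exactly k column indices, the CSR row pointers are just an arithmetic progression. A minimal pure-Python sketch (the function name is illustrative, not an existing API):

```python
def knn_indices_to_csr(col_per_row):
    """Convert a regular neighbor list (k column indices per row) to CSR
    crow_indices / col_indices arrays."""
    k = len(col_per_row[0])
    # every row stores exactly k entries, so the row pointers are an
    # arithmetic progression: 0, k, 2k, ...
    crow = [i * k for i in range(len(col_per_row) + 1)]
    col = [c for row in col_per_row for c in row]
    return crow, col

crow, col = knn_indices_to_csr([[1, 2], [0, 3], [2, 3], [1, 3]])
print(crow)  # [0, 2, 4, 6, 8]
print(col)   # [1, 2, 0, 3, 2, 3, 1, 3]
```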
A few findings upon this topic. Following is an example:

```python
import torch

# input is CSR of shape {B, M, N}
# mat1 is dense of shape {B, M, K}
# mat2 is dense of shape {B, K, N}
input = torch.tensor([[[1., 0], [2., 3.]], [[4., 0], [5., 6.]]]).to_sparse_csr()
mat1 = torch.randn(2, 2, 5)
mat2 = torch.randn(2, 5, 2)
out = torch.sparse.sampled_addmm(input, mat1, mat2)
out2 = torch.sparse.sampled_addmm(input.cuda(), mat1.cuda(), mat2.cuda())
print(torch.allclose(out.to_dense(), out2.to_dense().cpu()))
```

As for a per-sample mask, there is a constraint that the nnz of each batch slice should be identical; this constraint comes from the batched CSR layout itself.

@vadimkantorov your request of applying a sparse mask in attention can be done as follows; in brief, the scaled dot-product attention goes like:

```python
import torch

B = 2   # batch_size
T = 10  # seq_length
H = 6   # num_heads
E = 20  # emb_size
qkv = torch.randn(B, T, 3 * H * E)
query, key, value = qkv.chunk(3, dim=2)
# q, k and v are physically contiguous in order of B-T-H-E;
# chunk() returns non-contiguous views, hence reshape rather than view
query = query.reshape(B, T, H, E).transpose(1, 2)
key = key.reshape(B, T, H, E).transpose(1, 2)
value = value.reshape(B, T, H, E).transpose(1, 2)
# top 5 queries will have 5 keys
# bottom 5 queries will have 2 keys
# each {b, h} slice will have 5 * 5 + 5 * 2 = 35 samples (nnz = 35)
mask = torch.ones(B, H, T, T)
mask[:, :, :5, 5:].fill_(0)
mask[:, :, 5:, 2:].fill_(0)
mask = mask.to_sparse_csr()
# attn = query @ key.t()
# BLAS term: SDDMM
# currently, this kernel triggers an additional memory copy via .contiguous(), but that can be fixed
attn = torch.sparse.sampled_addmm(mask, query, key.transpose(-1, -2))
# SparseCSR does not have softmax for now
# attn.softmax(-1)
# out = attn @ value
# directly using torch.mm or torch.sparse.mm gives the following error:
#   RuntimeError: mat1 must be a matrix, got 4-D tensor
# torch.matmul doesn't support this either
# BLAS term: SpMM
out = []
for b in range(B):
    for h in range(H):
        out.append(torch.sparse.mm(attn[b][h], value[b][h]))
output = torch.cat(out, dim=0)
output = output.view(B, H, T, E)
print(output.size())
```

A few issues to be addressed:
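The missing CSR softmax noted in the snippet can be emulated by normalizing each row's stored values segment-wise, so implicit zeros are excluded exactly as in a masked softmax. A pure-Python sketch of the idea (illustrative, not the eventual PyTorch API):

```python
import math

def csr_row_softmax(crow, values):
    """Softmax over the stored values of each CSR row; implicit zeros are
    treated as excluded (-inf), matching a masked softmax."""
    out = values[:]
    for r in range(len(crow) - 1):
        start, end = crow[r], crow[r + 1]
        row = values[start:end]
        if not row:
            continue  # empty row: nothing to normalize
        m = max(row)  # subtract the max for numerical stability
        exps = [math.exp(v - m) for v in row]
        s = sum(exps)
        out[start:end] = [e / s for e in exps]
    return out

# two rows, two stored entries each, equal scores -> uniform weights per row
print(csr_row_softmax([0, 2, 4], [1.0, 1.0, 0.0, 0.0]))  # [0.5, 0.5, 0.5, 0.5]
```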
Hmm, I didn't fully understand whether different per-batch-element masks are supported or not. For my original usecase, am I understanding right that a fixed number of "outgoing edges" per vertex fits sampled_addmm? It would be nice to have a snippet implementing it. Also, does "the same number of nnz" allow for some semantically-zero values (e.g. a cell that occupies memory in the sparse tensor but has values[i] == 0) to specify padding? For my usecase I don't care much about this, but it still might be useful sometimes.
@vadimkantorov I might not be following the "outgoing edges" translation to mask nnz correctly, but yes, the sampled addmm should support different masks per batch. However, every batch must have the same number of specified elements in order to use the batched CSR layout, and the conversion function will enforce this for a 3-D mask. Does that answer your question?
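That equal-nnz constraint can be checked up front before converting to batched CSR. A small pure-Python sketch (names are illustrative) that counts specified elements per batch slice of a dense 0/1 mask and verifies they agree:

```python
def batch_nnz(mask):
    """Count non-zero entries in each 2-D batch slice of a nested-list mask."""
    return [sum(1 for row in m for v in row if v != 0) for m in mask]

def fits_batched_csr(mask):
    """True iff every batch slice has the same nnz, as batched CSR requires."""
    counts = batch_nnz(mask)
    return all(c == counts[0] for c in counts)

ok = [[[1, 0], [0, 1]], [[0, 1], [1, 0]]]   # nnz per slice: 2, 2
bad = [[[1, 1], [0, 1]], [[0, 1], [0, 0]]]  # nnz per slice: 3, 1
print(fits_batched_csr(ok), fits_batched_csr(bad))  # True False
```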
I'm modeling a regular graph, so yeah, every vertex has an equal number of edges. So it seems sampled_addmm should work well for my usecase :) I meant that there is always an ambiguity between an element with values[i] == 0 that is still stored in the sparse tensor and an implicit zero value. How should one convert
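Stored explicit zeros do count toward nnz, so they can be used to pad irregular rows up to a fixed out-degree (e.g. pointing spare edges at a sink column, as suggested earlier in the thread). A hypothetical pure-Python sketch; `pad_rows_to_k` and the sink-column convention are illustrative, not an existing API:

```python
def pad_rows_to_k(rows_cols, k, sink=0):
    """Pad each row's column list to exactly k entries using a 'sink'
    column; padded entries get an explicit 0.0 value, which still counts
    toward nnz and thus keeps the per-row (and per-batch) nnz uniform."""
    cols, vals = [], []
    for rc in rows_cols:
        pad = k - len(rc)
        cols.append(rc + [sink] * pad)
        vals.append([1.0] * len(rc) + [0.0] * pad)
    return cols, vals

cols, vals = pad_rows_to_k([[1, 2], [3]], k=2)
print(cols, vals)  # [[1, 2], [3, 0]] [[1.0, 1.0], [1.0, 0.0]]
```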
@vadimkantorov The whole process in transformer attention is very much like message propagation in a GNN (we can treat transformer attention as a fully connected graph). In your use case:

```python
### say we have T = 4, then the adjacency matrix should be 4 x 4
### where q_0 is connected to k_1 and k_2
### q_1 is connected to k_0 and k_3
### q_2 is connected to k_2 and k_3
### q_3 is connected to k_1 and k_3
### physically it needs only 4 * 2 storage but logically it is a 4 x 4 tensor
### the adjacency pattern is arbitrary; you can also have different numbers of keys for each query,
### but you need to make sure that the nnz of each batch slice is the same
>>> ind = torch.tensor([[0, 1, 1, 0], [1, 0, 0, 1], [0, 0, 1, 1], [0, 1, 0, 1]])
>>> ind
tensor([[0, 1, 1, 0],
        [1, 0, 0, 1],
        [0, 0, 1, 1],
        [0, 1, 0, 1]])
>>> ind_csr = ind.to_sparse_csr()
>>> ind_csr
tensor(crow_indices=tensor([0, 2, 4, 6, 8]),
       col_indices=tensor([1, 2, 0, 3, 2, 3, 1, 3]),
       values=tensor([1, 1, 1, 1, 1, 1, 1, 1]), size=(4, 4), nnz=8,
       layout=torch.sparse_csr)
```

Note that CSR would be much faster than dense when doing a GEMM if the sparsity is high (and also much faster than COO).
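The speed-up comes from computing the product only at the stored edges (the SDDMM pattern mentioned above): one dot product per (row, col) pair instead of a full dense GEMM. A pure-Python sketch of that semantics, with illustrative names:

```python
def sddmm(crow, col, A, B):
    """Compute (A @ B) only at the positions stored in the CSR pattern
    (crow, col): one dot product per stored edge, returned as the CSR
    values array."""
    vals = []
    for r in range(len(crow) - 1):
        for idx in range(crow[r], crow[r + 1]):
            c = col[idx]
            # dot product of row r of A with column c of B
            vals.append(sum(a * brow[c] for a, brow in zip(A[r], B)))
    return vals

A = [[1.0, 0.0], [0.0, 2.0]]
B = [[3.0, 4.0], [5.0, 6.0]]
# pattern: row 0 stores col 1, row 1 stores col 0
print(sddmm([0, 1, 2], [1, 0], A, B))  # [4.0, 10.0]
```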
Thanks! Yeah, I would say that my
Exactly! Using a sparse GEMM here would be much faster than the dense logic (which requires
@mingfeima I think there are some related discussions in #71465 (comment) mentioning fused index_select + matmul in an MoE context. I wonder if my semi-structured per-vertex indices list is related to that...
🚀 The feature, motivation and pitch
Hi,
I would like to perform a batch matrix-matrix product with a per-sample mask.
It's similar to torch.sparse.sampled_addmm; the only difference is that `input` would be a (b, m, n) sparse tensor in the CSR format, unless we could provide masks as a list consisting of b (m, n) tensors. It might be blocked by #104193 though.
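As a reference for the requested semantics, here is a sketch under the assumption that the op behaves like per-batch sampled_addmm applied slice by slice; the function name and nested-list representation are illustrative only:

```python
def sampled_baddmm_ref(inp, mat1, mat2, mask):
    """Reference semantics on nested lists:
    out[b][i][j] = inp[b][i][j] + dot(mat1[b][i], mat2[b][:, j]) where
    mask[b][i][j] != 0, and 0 elsewhere (implicit zero)."""
    out = []
    for b in range(len(mask)):
        M, N, K = len(mask[b]), len(mask[b][0]), len(mat2[b])
        out.append([
            [inp[b][i][j] + sum(mat1[b][i][k] * mat2[b][k][j] for k in range(K))
             if mask[b][i][j] else 0.0
             for j in range(N)]
            for i in range(M)
        ])
    return out

inp = [[[1.0, 0.0], [0.0, 2.0]]]          # b=1, m=2, n=2
mask = [[[1, 0], [0, 1]]]                 # per-sample sparsity pattern
mat1 = [[[1.0], [2.0]]]                   # b=1, m=2, k=1
mat2 = [[[3.0, 4.0]]]                     # b=1, k=1, n=2
print(sampled_baddmm_ref(inp, mat1, mat2, mask))  # [[[4.0, 0.0], [0.0, 10.0]]]
```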
Alternatives
No response
Additional context
No response
cc @alexsamardzic @nikitaved @pearu @cpuhrsch @amjames @bhosmer