Implementation of torch.sparse.sampled_baddmm #105319

Open
candyflower2005 opened this issue Jul 17, 2023 · 14 comments

candyflower2005 commented Jul 17, 2023

🚀 The feature, motivation and pitch

Hi,
I would like to perform a batch matrix-matrix product with a per-sample mask.
It's similar to torch.sparse.sampled_addmm; the only difference is that input would be a (b, m, n) sparse tensor in CSR format (or, alternatively, the masks could be provided as a list of b (m, n) tensors).
It might be blocked by #104193 though.

Alternatives

No response

Additional context

No response

cc @alexsamardzic @nikitaved @pearu @cpuhrsch @amjames @bhosmer

@candyflower2005 candyflower2005 changed the title Implement torch.sparse.sampled_baddmm Implementation of torch.sparse.sampled_baddmm Jul 17, 2023
@mikaylagawarecki mikaylagawarecki added module: sparse Related to torch.sparse triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Jul 17, 2023
@nikitaved nikitaved added the topic: new features topic category label Jul 17, 2023
@nikitaved
Collaborator

nikitaved commented Jul 17, 2023

@candyflower2005 , unfortunately, we do not have that operation. But if you have CUDA and Triton, you can convert your CSR inputs to BSR with block sizes that are powers of 2 and greater than 16, and then use sampled_addmm from torch.sparse._triton_ops. That one does support batch dimensions and broadcasting.
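
A minimal sketch of that route (hedged: torch.sparse._triton_ops is a private module, and the sketch assumes its sampled_addmm takes (input, mat1, mat2) like the public torch.sparse.sampled_addmm):

import torch
from torch.sparse import _triton_ops  # private module, API may change

# mask as BSR with a power-of-2 blocksize greater than 16, dense operands on CUDA
m = n = k = 64
mask = (torch.rand(m, n, device="cuda") > 0.9).float()
mat1 = torch.randn(m, k, device="cuda")
mat2 = torch.randn(k, n, device="cuda")

mask_bsr = mask.to_sparse_bsr((32, 32))  # power of 2, > 16

out = _triton_ops.sampled_addmm(mask_bsr, mat1, mat2)
print(out.layout, out.shape)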

@vadimkantorov
Contributor

vadimkantorov commented Aug 30, 2023

Would also benefit very much from the batch mode of this op as it enables custom attention patterns (and I assume for GNNs too, cc @fzhao3): https://discuss.pytorch.org/t/incomplete-sparse-gemm-for-some-sort-of-local-attention-with-only-given-indices/187428/2

@mingfeima
Collaborator

@vadimkantorov will take care of it.

So the inputs for this operator would be like below?

  • input – a sparse CSR tensor of shape (b, m, n)

  • mat1 – a dense tensor of shape (b, m, k)

  • mat2 – a dense tensor of shape (b, k, n)

And the output has the same shape and sparsity pattern as input?
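
For reference, a dense emulation of those semantics (just a sketch; sampled_baddmm does not exist, and the real op would return a sparse CSR tensor with input's pattern rather than a masked dense tensor):

import torch

# hypothetical dense reference:
# out = (beta * input + alpha * (mat1 @ mat2)) restricted to input's sparsity pattern
def sampled_baddmm_dense_ref(input_csr, mat1, mat2, beta=1.0, alpha=1.0):
    dense = input_csr.to_dense()
    keep = (dense != 0).to(mat1.dtype)
    return (beta * dense + alpha * torch.bmm(mat1, mat2)) * keep

inp = torch.tensor([[[1., 0.], [2., 3.]],
                    [[4., 0.], [5., 6.]]]).to_sparse_csr()  # equal nnz per batch
print(sampled_baddmm_dense_ref(inp, torch.randn(2, 2, 4), torch.randn(2, 4, 2)))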

@nikitaved
Collaborator

nikitaved commented Aug 31, 2023

@mingfeima , yes, but be aware of the linked issues above, i.e. there is no softmax for CSR inputs. It is not very hard to implement for dim=-1, though, especially if the Triton BSR kernel is reused. In fact, the Triton BSR kernel could be used for CSR inputs with a CSR -> BSR conversion with blocksize 1x1 (a trivial value unsqueeze, or just a code modification inside the kernel). You might find the code in torch.sparse._triton_ops.sampled_addmm useful, although that operation's implementation does perform broadcasting.
The other problem of sampled_addmm for CSR inputs is that the full matrix-matrix product is likely unavoidable for a general-purpose algorithm. So one has to be careful and probably implement two kernels depending on the sparsity level of the mask, and/or directly call into cuSPARSE, which is designed for very sparse matrices.
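
To illustrate that a dim=-1 softmax restricted to the CSR sparsity pattern is straightforward in principle, here is a sketch over a 2-D CSR tensor's values (not PyTorch's implementation, just a reference built on crow_indices/col_indices/values):

import torch

def csr_softmax_lastdim(a):
    # row-wise softmax over the specified values of a 2-D CSR tensor
    assert a.layout == torch.sparse_csr and a.dim() == 2
    crow, col, val = a.crow_indices(), a.col_indices(), a.values()
    counts = crow[1:] - crow[:-1]  # nnz per row
    row = torch.repeat_interleave(torch.arange(a.size(0), device=val.device), counts)
    # subtract the per-row max for numerical stability
    row_max = torch.full((a.size(0),), float("-inf"), device=val.device)
    row_max = row_max.scatter_reduce(0, row, val, reduce="amax")
    e = torch.exp(val - row_max[row])
    denom = torch.zeros(a.size(0), device=val.device).scatter_add_(0, row, e)
    return torch.sparse_csr_tensor(crow, col, e / denom[row], a.shape)

mask = (torch.rand(4, 6) > 0.5).float().to_sparse_csr()
print(csr_softmax_lastdim(mask).to_dense())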

@vadimkantorov
Contributor

vadimkantorov commented Aug 31, 2023

@mingfeima for my usecase - yes!

The alternative formulation is at https://discuss.pytorch.org/t/incomplete-sparse-gemm-for-some-sort-of-local-attention-with-only-given-indices/187428/2 :

Given two tensors: emb[B, T, C] and ind[B, T, K] where ind[b, t, :] contains indices of the neighborhood of (b, t).
I would like to compute out[B, T, K], where out[b, t, k] = \sum_c emb[b, t, c] * emb[b, ind[b, t, k], c].

In this original formulation the number of non-zero entries per row is fixed (so every "vertex" has the same number of "outbound edges", i.e. the number of non-zeros in output_mask[b, m] is constant). But I guess the sparse formulation of sampled_baddbmm is a more general variant.

Ideally, I'd also need this method for CUDA.
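
For concreteness, a dense reference of that formulation (a sketch with made-up shapes, not an existing op):

import torch

# out[b, t, k] = sum_c emb[b, t, c] * emb[b, ind[b, t, k], c]
B, T, K, C = 2, 8, 3, 4
emb = torch.randn(B, T, C)
ind = torch.randint(0, T, (B, T, K))

# gather the K neighbour embeddings for every (b, t): shape (B, T, K, C)
neigh = emb.unsqueeze(1).expand(B, T, T, C).gather(2, ind.unsqueeze(-1).expand(B, T, K, C))
out = (emb.unsqueeze(2) * neigh).sum(-1)  # (B, T, K)
print(out.shape)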

@vadimkantorov
Contributor

vadimkantorov commented Aug 31, 2023

Also, in my usecase, the "number of outbound edges" is always fixed (typical for k-nn graphs which are regular), so the "output mask" and the "output" can be simply dense tensors (and the "output mask" can simply contain the non-zero column indices).

Is this equivalent to CSR format (with my ind being col_indices for CSR and the output being the values for CSR, and row_indices just being consecutive arange)?

I don't know if it's worth a separate function for my original format, but this sort of regular neighborhood size can often be accommodated when experimenting with attention patterns and might be simpler than dealing with sparse tensors from the UX standpoint (and if some "vertex" really needs fewer than the fixed number of "outbound edges", the remaining edges can just be padding towards some sort of sink node, or e.g. a special value like -1 which can be ignored by the underlying op).

@mingfeima
Collaborator

A few findings on this topic:

First of all, torch.sparse.sampled_addmm already supports a batch mode right now, but this is not clearly documented in https://pytorch.org/docs/stable/generated/torch.sparse.sampled_addmm.html. The doc says input, mat1 and mat2 should be 2-D, but the code actually supports batch dimension(s), e.g. {batch0, batch1, ..., M, N}.

Following is an example:

# input is csr of {B, M, N}
# mat1 is dense of {B, M, K}
# mat2 is dense of {B, K, N}
input = torch.tensor([[[1., 0], [2., 3.]], [[4., 0], [5., 6.]]]).to_sparse_csr()
mat1 = torch.randn(2, 2, 5)
mat2 = torch.randn(2, 5, 2)

out = torch.sparse.sampled_addmm(input, mat1, mat2)
out2 = torch.sparse.sampled_addmm(input.cuda(), mat1.cuda(), mat2.cuda())
print(torch.allclose(out.to_dense(), out2.to_dense().cpu()))

As for the per-sample mask, there is a constraint that the nnz from each batch should be identical. This constraint comes from to_sparse_csr(), not sampled_addmm, as described in #104193
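
To illustrate the constraint (a sketch; the exact error from to_sparse_csr() may differ):

import torch

# different sparsity patterns per batch element but equal nnz (2 each): converts fine
ok = torch.tensor([[[1., 0.], [0., 2.]],
                   [[0., 3.], [4., 0.]]]).to_sparse_csr()
print(ok.col_indices())

# unequal nnz per batch element (1 vs 3): to_sparse_csr() rejects this
bad = torch.tensor([[[1., 0.], [0., 0.]],
                    [[5., 6.], [7., 0.]]])
try:
    bad.to_sparse_csr()
except Exception as e:
    print(type(e).__name__, e)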

@vadimkantorov your request of applying a sparse mask in attention can be done as follows.

In brief, scaled dot-product attention decomposes into:

  • q @ k.t: this is SDDMM, which is torch.sparse.sampled_addmm
  • attn @ v: this is SPMM, which is torch.mm or torch.sparse.mm
B = 2 # batch_size
T = 10 # seq_length
H = 6 # num_heads
E = 20 # emb_size

qkv = torch.randn(B, T, 3 * H * E)
query, key, value = qkv.chunk(3, dim=2)

# q, k and v are physically contiguous in order
# of B-T-H-E
query = query.view(B, T, H, E).transpose_(1, 2)
key = key.view(B, T, H, E).transpose_(1, 2)
value = value.view(B, T, H, E).transpose_(1, 2)

# top 5 queries will have 5 keys
# bottom 5 queries will have 2 keys
# each {b, h} slice will have 5 * 5 + 5 * 2 = 35 samples (nnz = 35)
mask = torch.ones(B, H, T, T)
mask[:,:,:5,5:].fill_(0)
mask[:,:,5:,2:].fill_(0)
mask = mask.to_sparse_csr()

# attn = query @ key.t()
# BLAS term: SDDMM
# currently, this kernel will have additional memory copy triggered by .contiguous(), but it can be fixed
attn = torch.sparse.sampled_addmm(mask, query, key.transpose(-1, -2))

# SparseCSR does not have softmax now
# attn.softmax(-1)

# out = attn @ value
# directly using torch.mm or torch.sparse.mm will give the following error
#   RuntimeError: mat1 must be a matrix, got 4-D tensor

# torch.matmul doesn't support this either

# BLAS term: SPMM
out = []
for b in range(B):
    for h in range(H):
        out.append(torch.sparse.mm(attn[b][h], value[b][h]))

output = torch.cat(out, dim=0)
output = output.view(B, H, T, E)
print(output.size())

A few issues to be addressed:

  • SparseCSR doesn't have softmax; @nikitaved mentioned this earlier
  • torch.mm and torch.matmul don't work on CSR with batch dimensions right now

@vadimkantorov
Contributor

vadimkantorov commented Sep 7, 2023

As for the per-sample mask, there is a constraint that the nnz from each batch should be identical.

Hmm, I didn't fully understand whether different per-batch-element masks are supported or not.

For my original usecase emb[B, T, C] x ind[B, T, K] -> out[B, T, K] (alternatively, instead of a single emb, separate key and query tensors may be wanted), different per-batch-element masks are crucial for processing multiple images in one go.

Am I understanding right that my use case of a fixed number of "outgoing edges" fits sampled_addmm? It would be nice to have a snippet implementing def graph_mm(emb: '[B, T, C]', adj_ind: '[B, T, K]') -> '[B, T, K]': ... directly on these dense inputs.

Also, does "the same number of nnz" allow for some semantically zero values (e.g. a cell that occupies memory in the sparse tensor but has values[i] == 0) to specify some padding? For my usecase I don't care much about this, but it still might be useful sometimes.

@amjames
Collaborator

amjames commented Sep 7, 2023

@vadimkantorov I might not be following the "outgoing edges" translation to mask nnz correctly, but yes, the sampled addmm should support different masks per batch. However, every batch should have the same number of specified elements in order to use the layout, and the conversion function will enforce this.

For a 3-D mask M, (M[i] == 0).sum() should be equal for any i; however, M[i] == 0 need not be the same for all i.

Does that answer your question?

@vadimkantorov
Contributor

I'm modeling a regular graph, so yeah, every vertex has an equal number of edges. So it seems sampled_addmm should work well for my usecase :)

I meant that there is always an ambiguity between an element that is stored in the sparse tensor but has values[i] == 0 and an implicit zero value.

How should one convert adj_ind : '[B, T, K]' to an input mask (and, in reverse, convert the output of sampled_addmm to out : '[B, T, K]')?

@mingfeima
Collaborator

mingfeima commented Sep 8, 2023

@vadimkantorov The whole process in transformer attention is very similar to message propagation in a GNN (we can treat transformer attention as a fully connected graph).

In your use case, out[b, t, k] = \sum_c emb[b, t, c] * emb[b, ind[b, t, k], c], to fit it into sampled_addmm or SDDMM (sampled dense-dense matrix multiplication):

  • ind is the tensor which describes the adjacency matrix from queries to keys: logically it is a tensor of shape {B, T, T}; physically it should be stored as a sparse CSR tensor. If we take a look at the [b] slice, the {T, T} matrix is the connection between queries and keys. For a conventional transformer each query has T connections (fully connected); here each query has only K connections, so the nnz is T * K. In sampled_addmm, ind is the input.
  • You can have different connections for different queries, for example:
### say we have T = 4, then the adjacency matrix should be 4 x 4
### where q_0 is connected to k_1 and k_2
### q_1 is connected to k_0 and k_3
### q_2 is connected to k_2 and k_3
### q_3 is connected to k_1 and k_3
### physically it needs only 4 * 2 storage but logically it is a 4*4 tensor
### the adjacency pattern is arbitrary, you can also have different numbers of keys for each query
### but need to make sure that the nnz from each batch slice is the same
>>> ind = torch.tensor([[0, 1, 1, 0],[1, 0, 0, 1],[0, 0, 1, 1],[0, 1, 0, 1]])
>>> ind
tensor([[0, 1, 1, 0],
        [1, 0, 0, 1],
        [0, 0, 1, 1],
        [0, 1, 0, 1]])

>>> ind_csr = ind.to_sparse_csr()
>>> ind_csr
tensor(crow_indices=tensor([0, 2, 4, 6, 8]),
       col_indices=tensor([1, 2, 0, 3, 2, 3, 1, 3]),
       values=tensor([1, 1, 1, 1, 1, 1, 1, 1]), size=(4, 4), nnz=8,
       layout=torch.sparse_csr)

Note that CSR would be much faster than dense when doing a gemm if the sparsity is high (also much faster than COO)

  • emb plays the role of mat1 and mat2 in sampled_addmm; it is a dense tensor of shape {B, T, H}.
  • out is a sparse CSR tensor of shape {B, T, T} which has the same sparsity pattern as ind. This would be the attention in the transformer, and we need to do attn @ value next, which is {B, T, T}_sparse @ {B, T, C}_dense = {B, T, C}_dense. This is an SPMM (sparse matrix multiplication). Now you get back to a dense tensor again.

@vadimkantorov
Contributor

Thanks! Yeah, I would say that my adj_ind : '[B, T, K]' is the col_indices, right? And the crow_indices is sth like torch.arange(T)[None, :, None].expand(B, -1, K)? And then I can get out : '[B, T, K]' as .values().reshape(B, T, K)?

@mingfeima
Collaborator

Thanks! Yeah, I would say that my adj_ind : '[B, T, K]' is the col_indices, right? And the crow_indices is sth like torch.arange(T)[None, :, None].expand(B, -1, K)? And then I can get out : '[B, T, K]' as .values().reshape(B, T, K)?

Exactly! Using a sparse gemm here would be much faster than the dense logic (which requires scatter_add and index_select) ...
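
Putting that together, a hedged sketch of the graph_mm helper discussed above (the name and shapes come from this thread, not from an existing PyTorch API; it assumes the neighbour indices within each row are sorted and unique, and relies on the batched sampled_addmm behaviour shown earlier):

import torch

def graph_mm(emb, adj_ind):
    # emb: dense [B, T, C]; adj_ind: int64 [B, T, K], K neighbours per vertex
    B, T, C = emb.shape
    K = adj_ind.shape[-1]
    # every row has exactly K specified elements, so crow_indices is 0, K, 2K, ..., T*K
    crow = (torch.arange(T + 1, device=emb.device) * K).expand(B, -1).contiguous()
    col = adj_ind.reshape(B, T * K).contiguous()
    mask = torch.sparse_csr_tensor(
        crow, col, torch.ones(B, T * K, dtype=emb.dtype, device=emb.device),
        size=(B, T, T))
    # out[b, t, k] = sum_c emb[b, t, c] * emb[b, adj_ind[b, t, k], c]
    prod = torch.sparse.sampled_addmm(mask, emb, emb.transpose(-1, -2), beta=0.0)
    return prod.values().reshape(B, T, K)

B, T, K, C = 2, 6, 3, 4
emb = torch.randn(B, T, C)
adj_ind = torch.rand(B, T, T).topk(K, dim=-1).indices.sort(-1).values  # K sorted, unique neighbours
print(graph_mm(emb, adj_ind).shape)  # torch.Size([2, 6, 3])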

@vadimkantorov
Contributor

@mingfeima I think there are some related discussions in #71465 (comment) mentioning fused index_select + matmul in the MoE context. I wonder if my semi-structured per-vertex index list is related to that...
