
[sparse] torch.sparse.sum() #12430

Closed · wants to merge 7 commits

Conversation

Contributor

@weiyangfb weiyangfb commented Oct 7, 2018

  • to fix torch.sum() for sparse tensor #12241
  • add _sparse_sum() to ATen and expose it as torch.sparse.sum(); SparseTensor.sum() is not supported currently (a short usage sketch appears at the end of this description)
  • this PR depends on [sparse] Autograd get_indices/values and sparse_coo ctor #11253, and will need to be updated once it lands
  • implement forward
  • implement backward
  • performance benchmark script (https://gist.github.com/weiyangfb/f4c55c88b6092ef8f7e348f6b9ad8946#file-sparse_sum_benchmark-py):
    • summing all dims is fastest for a sparse tensor
    • when the input is sparse enough (nnz = 0.1%), summing a sparse tensor is faster than a dense one on CPU, but not necessarily on CUDA
    • CUDA backward is comparable (<2x) between summing several dims and summing all dims in sparse
    • CPU backward, which uses binary search, is still slow in sparse: summing dims [0, 2, 3] takes 5x the time of summing all dims
      • optimize CUDA backward for now
        • using thrust for sort and binary search, but the runtime did not improve
    • both CPU and CUDA forward are slow in sparse (summing several dims vs summing all dims): at most 20x slower on CPU and 10x on CUDA
      • improve CPU and CUDA forward kernels
(nnz, sizes, sum_dims, keepdim, sum all or dims, bk=backward) | CPU (sparse vs dense) | CUDA (sparse vs dense)
-- | -- | --
(1000, [1000, 1000, 2, 2], [0, 1], False, sumAll) | 8.77 µs vs 72.9 µs | 42.5 µs vs 108 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumD) | 112 µs vs 4.47 ms | 484 µs vs 407 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 141 µs vs 148 µs | 647 µs vs 231 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 235 µs vs 1.23 ms | 781 µs vs 213 µs
(1000, [1000, 1000, 2, 2], [2, 3], False, sumD) | 48.5 µs vs 360 µs | 160 µs vs 2.03 ms
(1000, [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 258 µs vs 1.22 ms | 798 µs vs 224 µs
(1000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 204 µs vs 882 µs | 443 µs vs 133 µs
(1000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 709 µs vs 1.15 ms | 893 µs vs 202 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumAll) | 39.8 µs vs 81 µs | 42.4 µs vs 113 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumD) | 747 µs vs 4.7 ms | 2.4 ms vs 414 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 1.04 ms vs 126 µs | 5.03 ms vs 231 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 1.12 ms vs 1.24 ms | 5.99 ms vs 213 µs
(10000, [1000, 1000, 2, 2], [2, 3], False, sumD) | 133 µs vs 366 µs | 463 µs vs 2.03 ms
(10000, [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 1.56 ms vs 1.22 ms | 6.11 ms vs 229 µs
(10000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 1.53 ms vs 799 µs | 824 µs vs 134 µs
(10000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 5.15 ms vs 1.09 ms | 7.02 ms vs 205 µs
  • after improving the CPU and CUDA forward kernels:
    • in the (1000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD) forward, CPU takes 171 µs, of which 130 µs is spent in coalesce(); on CUDA the total time is 331 µs, of which 141 µs is spent in coalesce(). We need to reduce the time spent outside coalesce().
    • after a few simple tweaks, the forward is now at most 10x slower on CPU and 7x slower on CUDA. Summing the dense dims only ([2, 3]) takes ~2x the time of summing all dims, and summing all sparse dims ([0, 1]) is on par with summing all dims.
(nnz, sizes, sum_dims, keepdim, sum all or dims, bk=backward) | CPU (sparse vs dense) | CUDA (sparse vs dense)
-- | -- | --
(1000, [1000, 1000, 2, 2], [0, 1], False, sumAll) | 7 µs vs 69.5 µs | 31.5 µs vs 61.6 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumD) | 11.3 µs vs 4.72 ms | 35.2 µs vs 285 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 197 µs vs 124 µs | 857 µs vs 134 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 124 µs vs 833 µs | 796 µs vs 106 µs
(1000, [1000, 1000, 2, 2], [2, 3], False, sumD) | 20.5 µs vs 213 µs | 39.4 µs vs 1.24 ms
(1000, [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 131 µs vs 830 µs | 881 µs vs 132 µs
(1000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 95.8 µs vs 409 µs | 246 µs vs 87.2 µs
(1000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 624 µs vs 820 µs | 953 µs vs 124 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumAll) | 45.3 µs vs 72.9 µs | 33.9 µs vs 57.2 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumD) | 81.4 µs vs 4.49 ms | 39.7 µs vs 280 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 984 µs vs 111 µs | 6.41 ms vs 121 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 1.45 ms vs 828 µs | 6.77 ms vs 113 µs
(10000, [1000, 1000, 2, 2], [2, 3], False, sumD) | 74.9 µs vs 209 µs | 37.7 µs vs 1.23 ms
(10000, [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 1.48 ms vs 845 µs | 6.96 ms vs 132 µs
(10000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 1.14 ms vs 411 µs | 252 µs vs 87.8 µs
(10000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 4.53 ms vs 851 µs | 7.12 ms vs 128 µs
  • the CUDA backward of sparse takes a very long time with large variance (for nnz = 10000 it normally takes 6-7 ms). To improve the backward of sparse ops, we will need to debug places other than the CUDA kernels. Here is a benchmark of Tensor.copy_():
>>> d = [1000, 1000, 2, 2]
>>> nnz = 10000
>>> I = torch.cat([torch.randint(0, d[0], size=(nnz,)), 
               torch.randint(0, d[1], size=(nnz,))], 0).reshape(2, nnz)
>>> V = torch.randn(nnz, d[2], d[3])
>>> size = torch.Size(d)
>>> S = torch.sparse_coo_tensor(I, V, size).coalesce().cuda()
>>> S2 = torch.sparse_coo_tensor(I, V, size).coalesce().cuda().requires_grad_()
>>> data = S2.clone()
>>> S.copy_(S2)
>>> y = S * 2
>>> torch.cuda.synchronize()
>>> %timeit y.backward(data, retain_graph=True); torch.cuda.synchronize()
7.07 ms ± 3.06 ms per loop (mean ± std. dev. of 7 runs, 1000 loops each)
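
For reference, here is a minimal usage sketch of the torch.sparse.sum() API added by this PR (the shapes and values are arbitrary and only illustrative; the result stays sparse unless all sparse dims are summed):

```python
import torch

# A small COO tensor with 2 sparse dims (3 x 4) and 1 dense dim of size 2.
i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])
v = torch.randn(3, 2)
S = torch.sparse_coo_tensor(i, v, size=(3, 4, 2))

total = torch.sparse.sum(S)              # sum over all dims -> 0-dim dense tensor
part  = torch.sparse.sum(S, dim=[1, 2])  # sparse dim 0 survives -> sparse, shape (3,)
dense = torch.sparse.sum(S, dim=[0, 1])  # all sparse dims summed -> dense, shape (2,)
```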

Tensor sparse_sum(const SparseTensor& t, IntList dims, bool keepdim) {

const int64_t total_dims = t.dim();
check_dims_errors(dims, total_dims);

// check if number of axis in dim is valid
AT_CHECK(flip_dims_size > 0 && flip_dims_size <= total_dims,
"flip dims size out of range, got flip dims size=", flip_dims_size);
static inline void check_dims_errors(IntList dims, int64_t total_dims) {

@@ -1790,6 +1790,23 @@
SparseCPU: norm_sparse
SparseCUDA: norm_sparse

# TODO: reduce signatures down to one when optional args are available

apaszke previously requested changes Oct 9, 2018
Contributor

@apaszke apaszke left a comment


Why are we adding a new function for this instead of simply using torch.sum?

- func: sparse_sum(Tensor self, *, ScalarType dtype) -> Tensor
variant: method, function

- func: sparse_sum(Tensor self, IntList[1] dims, bool keepdim) -> Tensor

}

Tensor sparse_sum(const SparseTensor& t, ScalarType dtype) {
return t._values().sum().to(dtype);

@weiyangfb
Contributor Author

@apaszke It is for the sake of autograd support. Ops prefixed with sparse_ will have autograd support, with gradients zeroed out at zero-valued input locations during the backward. But yes, I can also provide sum(SparseTensor) without backward.
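
As an illustration of the behavior described above, a minimal sketch using the torch.sparse.sum() API from this PR (indices and values are arbitrary; the comments describe the intended gradient semantics, not a guaranteed storage format):

```python
import torch

i = torch.tensor([[0, 2], [1, 3]])
v = torch.tensor([1.0, 2.0])
S = torch.sparse_coo_tensor(i, v, size=(4, 4)).coalesce().requires_grad_()

out = torch.sparse.sum(S)   # summing all dims returns a 0-dim dense tensor
out.backward()

# The gradient lives only at S's stored (nnz) locations; entries of the dense
# 4x4 tensor that were never materialized receive no gradient at all.
print(S.grad)
```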

@apaszke
Contributor

apaszke commented Oct 9, 2018

Why is zeroing those gradients a good thing? That's not the real gradient of this operation.

@weiyangfb
Contributor Author

@apaszke I agree, zeroing out gradients may not make sense in my application, but it might be useful for other use cases, such as Graph Networks @AntoinePrv (#10043)

@apaszke
Contributor

apaszke commented Oct 10, 2018

I see the use case, but I don't feel like adding a sparse_X variant for every single operation we have is a good solution. The right way to do it would be to simply have some kind of specialized masked tensor type, which also holds on to a mask specifying which entries are valid (where the mask can, of course, be represented in a sparse way). That's what @AntoinePrv originally proposed, and it is very reasonable.

@weiyangfb
Contributor Author

@apaszke OK, I think the concern here is mainly about the proliferation of sparse_X ops. To address it, how about adding a sparse=bool arg to every op that supports SparseTensor? Similar to embedding, sparse=True would indicate sparse gradients during the backward. As for the masked tensor representation, it is more of one possible way to implement the backward, since it might not be universally optimal for all ops.

@ezyang
Contributor

ezyang commented Oct 10, 2018

@apaszke Are you recommending that we have a new backend, SparseMaskedTensor? Even if we do it that way, you still shouldn't call the operation sum, because it's not a sum. I suppose that if SparseMaskedTensor is just a Python-level only wrapper on top of tensor and not actually a tensor, you're allowed to do that, but if you want it to be an honest to goodness Tensor you have to respect our semantics which say, at the very least, that the derivative of all operations called "sum" should be the same.

@apaszke
Contributor

apaszke commented Oct 11, 2018

@weiyangfb adding a sparse flag to every single operation we have is having a tail wag the dog.

@ezyang why is that not a sum? I don't understand. If the sparsity pattern encodes the "valid entries" on which you want to do the compute, then it's exactly the sum. In this case you shouldn't think of it as a tensor with zeros defined everywhere, but as one holding sentinel values that always act as neutral elements of every operation. I don't know if I'm proposing a backend, because "backend" is heavily overloaded in our vocabulary. What I want is just a tensor-like object that overloads operations which we already have.

Of course the new type would be some kind of wrapper around other tensor operations, and in cases where something can't be expressed using regular tensor math, it can always desugar to more specialized implementations in the backend (think something like sparse_sum, except that would be an internal function). However, none of this should be of concern to the user. Otherwise, we'll end up with 100000 variations of every function for every weird quirk that people want, and both maintenance and discoverability will be a problem.
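
Purely to illustrate the proposal above, a toy Python-level wrapper could look like the sketch below. The MaskedSparseTensor class and everything about it are hypothetical, invented for this example; it is not part of this PR or of PyTorch.

```python
import torch

class MaskedSparseTensor:
    """Toy wrapper: the stored entries of a COO tensor are the only 'valid'
    ones; reductions ignore everything else instead of treating it as zero."""

    def __init__(self, sparse):
        self.data = sparse.coalesce()

    def sum(self, dim=None):
        if dim is None:
            # Reduce only over the stored values; unstored entries never
            # enter the computation at all.
            return self.data.values().sum()
        # Desugar to a specialized sparse reduction for per-dim sums.
        return torch.sparse.sum(self.data, dim=dim)

# Same spelling as dense `t.sum()`, with no sparse_-prefixed API exposed to users.
i = torch.tensor([[0, 1], [1, 2]])
v = torch.tensor([3.0, 4.0])
m = MaskedSparseTensor(torch.sparse_coo_tensor(i, v, size=(3, 3)))
print(m.sum())          # tensor(7.)
print(m.sum(dim=[0]))   # sparse tensor of shape (3,)
```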

@weiyangfb
Contributor Author

@apaszke Thanks for the clarification! I think your proposal sounds reasonable to me now. One important point is to define the zero locations of a SparseTensor as masked locations, i.e., not involved in computations during either forward or backward. This allows SparseTensor ops to share the same names as dense tensor ops while supporting autograd. A new type sounds fine to me. Is it meant to replace the current dispatch mechanism for SparseTensor?

@weiyangfb
Contributor Author

Hi @apaszke, would you like to comment on this further so that I can unblock this work? Thanks!

@apaszke
Contributor

apaszke commented Oct 14, 2018

I'm not sure what the current dispatch for sparse tensors is. My point is that we should avoid increasing our API surface by duplicating every single function with a sparse_ prefix. If this is blocking a few models then we can merge this PR, but I'd like us to clean this mess up in the future (by deprecating these functions and replacing them with a new tensor type).

@soumith
Member

soumith commented Oct 15, 2018

We could put them in a torch.sparse.* namespace, for example torch.sparse.sum, to clearly indicate that the gradients are sparsified approximations, and not true gradients.

We can clearly declare this at the top-level doc page for torch.sparse.*.
I think that's the best way forward. @apaszke thoughts?

@apaszke
Contributor

apaszke commented Oct 15, 2018

@soumith I don't really care if they're called torch.sparse_sum or torch.sparse.sum (although it does seem slightly nicer to put it in a submodule). My point is that it's not the right API for this kind of thing, and should get deprecated sooner or later.

@weiyangfb
Contributor Author

@apaszke The current dispatch for SparseTensor relies on files like pytorch/aten/src/ATen/native/LegacyBridge.cpp, which dispatch to different backends depending on the type of the input tensor. Without introducing a new type for SparseTensor, we can still keep the API surface the same by relying on the current dispatch along with the masked-positions rule (masked positions are not involved in computations in either forward or backward). This way we can support autograd for sparse with the same API names. On the other hand, a new type is welcome if it makes more sense than the existing dispatching mechanism.

@apaszke
Contributor

apaszke commented Oct 17, 2018

I thought you meant that we would want to have both sum and sparse_sum, where the latter returns the incorrect gradient. In that case we need an entirely new sparse tensor type (which says that "zero" entries are not really zero, but "invalid").

@weiyangfb
Contributor Author

@apaszke actually I meant to keep sum() only. But yes, it will have a backward, and zero entries are treated as invalid. We will need to define the "correctness" of the gradient for sparse tensors in this way.

@apaszke
Contributor

apaszke commented Oct 18, 2018

Hmmm, I'm not sure we want this to be the default meaning of a sparse tensor. In my view they should really be a drop-in replacement for dense semantics, except that we would be optimizing most entries out for space reasons. On the other hand, allocating large dense tensors as their gradients might OOM too...

@weiyangfb
Contributor Author

Hi @ezyang, I think the PR is good for review again :) Thanks!

@weiyangfb
Contributor Author

@soumith can I get a review?

Member

@soumith soumith left a comment


I'm just getting back up to speed with sparse, so parts of my review are probably ill-informed.
I left some comments in-line.

Still to review: the CUDA implementation and the backward.

@@ -1 +1,92 @@
# The Tensor classes are added to this module by python_tensor.cpp
import torch

// Ex2:
// dims_to_flatten = [1]
// new_indices = [ 3, 1, 3 ] # uncoalesced
inline LongTensor flatten_indices_by_dims(const LongTensor& indices, const IntList& sizes, const IntList& dims_to_flatten){
  LongTensor new_indices = at::zeros({indices.size(1)}, indices.options());
  for (auto d : dims_to_flatten) {
    new_indices.mul_(sizes[d]);
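
A rough Python rendering of the flattening loop above, just to make the row-major folding explicit (an illustrative sketch, not the ATen implementation):

```python
import torch

def flatten_indices_by_dims(indices, sizes, dims_to_flatten):
    # Fold the selected index rows into one linear index, row-major:
    # new = (... (0 * s_d0 + i_d0) * s_d1 + i_d1 ...) over dims_to_flatten.
    new_indices = torch.zeros(indices.size(1), dtype=indices.dtype)
    for d in dims_to_flatten:
        new_indices = new_indices * sizes[d] + indices[d]
    return new_indices  # may be uncoalesced (duplicates possible)

# Ex2 from the comment above: flatten only dim 1.
indices = torch.tensor([[1, 0, 1],
                        [3, 1, 3]])
print(flatten_indices_by_dims(indices, sizes=[2, 4], dims_to_flatten=[1]))
# tensor([3, 1, 3])
```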

// for ops like sum, max, and min.
// --------------------------------------------------------------------
Tensor _sparse_sum(const SparseTensor& input) {
return input.coalesce().values().sum();

else {
if (keepdim) {
new_indices = at::zeros_like(indices);
if (!sum_all_sparse_dim) {

}
}
else {
if (sum_all_sparse_dim) {

new_indices = at::zeros_like(indices);
if (!sum_all_sparse_dim) {
for (int64_t d = 0; d < sparse_dim; d++) {
if (!dims_to_sum_b[d]) new_indices[d].copy_(indices[d]);

@skipIfRocm
def test_sparse_sum(self):

def run_tests(S, td=None, k=False):

S_grad_dense = S_grad.to_dense() if S_grad.is_sparse else S_grad
self.assertEqual(S_grad_dense, D_grad)

nnz = 10

@weiyangfb weiyangfb force-pushed the sparse_sum branch 2 times, most recently from 6ebca4b to b41fba6 on November 21, 2018 06:12
@weiyangfb
Contributor Author

I removed the keepdim arg and addressed the comments; the CI failures do not look related. This PR is ready for review again. cc @soumith

Contributor

@facebook-github-bot facebook-github-bot left a comment


@weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


- move sparse_sum to sparse.sum
- require the input sparse tensor to be coalesced in sparse_sum to ease autograd support
- optimize the CPU backward kernel with binary search (a conceptual sketch of the index matching follows this list)
- optimize runtime of forward and backward; use the cheap sparse tensor ctor _sparse_coo_tensor_with_dims_and_tensors()
- optimize the CUDA backward kernel with binary search; runtime doesn't seem to improve at all, not sure why
- improve the speed of forward when summing all sparse dims; address comments
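
The binary-search backward mentioned in the commits above matches each input nonzero against the nonzeros of the output gradient via their flattened indices. A rough conceptual sketch of that matching step (torch.searchsorted stands in for the hand-written kernel; gather_grad_values and the toy data below are made up for this example):

```python
import torch

def gather_grad_values(input_flat_idx, grad_flat_idx_sorted, grad_values):
    """For each input nonzero, binary-search its flattened index in the sorted,
    flattened indices of the output gradient; unmatched entries get zero."""
    pos = torch.searchsorted(grad_flat_idx_sorted, input_flat_idx)
    pos = pos.clamp(max=grad_flat_idx_sorted.numel() - 1)
    found = grad_flat_idx_sorted[pos] == input_flat_idx
    out = torch.zeros(input_flat_idx.numel(), dtype=grad_values.dtype)
    out[found] = grad_values[pos[found]]
    return out

# Toy example: input nonzeros at flattened indices [0, 2, 2, 5];
# grad has nonzeros at [2, 5] with values [10., 20.].
inp_idx  = torch.tensor([0, 2, 2, 5])
grad_idx = torch.tensor([2, 5])
grad_val = torch.tensor([10.0, 20.0])
print(gather_grad_values(inp_idx, grad_idx, grad_val))
# tensor([ 0., 10., 10., 20.])
```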

zdevito pushed a commit to zdevito/ATen that referenced this pull request Nov 28, 2018
Summary: same as the PR description at the top of this conversation (including the benchmark tables and the copy_ benchmark).
Pull Request resolved: pytorch/pytorch#12430

Differential Revision: D12878313

Pulled By: weiyangfb

fbshipit-source-id: e16dc7681ba41fdabf4838cf05e491ca9108c6fe
@ezyang ezyang added the merged label Jun 26, 2019