
Implement hybrid sparse to/from dense conversions. #90177

Closed
wants to merge 22 commits

Conversation

pytorch-bot bot commented Dec 5, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90177

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0498f42:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

alexsamardzic added a commit that referenced this pull request Dec 5, 2022
ghstack-source-id: 40aa7e2c6e47c6d67fffc954dd6c31dd17a529cc
Pull Request resolved: #90177
alexsamardzic (Collaborator Author) commented Dec 5, 2022

Just the first version of hybrid CSR/CSC to dense conversions... It mimics the BSR/BSC to dense conversions, so it remains to:

  1. unify those
  2. try to avoid looping over batches

Testing script (create a CSR matrix, then create a CSC matrix with transposed rows/columns, convert both to dense, and check that they are transposes of each other along the sparse dimensions):
import torch

BATCH_DIMS_MIN = 1
BATCH_DIMS_MAX = 3
BATCH_SIZE_MIN = 1
BATCH_SIZE_MAX = 5
SPARSE_SIZE_MIN = 1
SPARSE_SIZE_MAX = 10
HYBRID_DIMS_MIN = 1
HYBRID_DIMS_MAX = 5
HYBRID_SIZE_MIN = 1
HYBRID_SIZE_MAX = 10
NTESTS = 10000

for i in range(NTESTS):
    batch_dims = torch.randint(BATCH_DIMS_MIN, BATCH_DIMS_MAX + 1, (1,)).item()
    batch_size = tuple(torch.randint(BATCH_SIZE_MIN, BATCH_SIZE_MAX + 1, (batch_dims,)).tolist())
    (nrows, ncols) = sparse_size = tuple(torch.randint(SPARSE_SIZE_MIN, SPARSE_SIZE_MAX + 1, (2,)).tolist())
    hybrid_dims = torch.randint(HYBRID_DIMS_MIN, HYBRID_DIMS_MAX + 1, (1,)).item()
    hybrid_size = tuple(torch.randint(HYBRID_SIZE_MIN, HYBRID_SIZE_MAX + 1, (hybrid_dims,)).tolist())

    dense = torch.randn(sparse_size, dtype=torch.float32)
    dense = dense * dense.relu().bool()
    dense = dense.repeat(batch_size + (1, 1))
    csr = dense.to_sparse_csr()
    hybrid_values = torch.randn(csr.values().shape + hybrid_size, dtype=torch.float32)
    hybrid_csr = torch.sparse_compressed_tensor(
        csr.crow_indices(),
        csr.col_indices(),
        hybrid_values,
        batch_size + (nrows, ncols) + hybrid_size,
        layout=torch.sparse_csr,
        dtype=torch.float32)
    #print("hybrid_csr:")
    #print(hybrid_csr)
    dense_hybrid_csr = hybrid_csr.to_dense()
    #print(dense_hybrid_csr)

    hybrid_csc = torch.sparse_compressed_tensor(
        csr.crow_indices(),
        csr.col_indices(),
        hybrid_values,
        batch_size + (ncols, nrows) + hybrid_size,
        layout=torch.sparse_csc,
        dtype=torch.float32)
    #print("hybrid_csc:")
    #print(hybrid_csc)
    dense_hybrid_csc = hybrid_csc.to_dense()
    #print(dense_hybrid_csc)

    assert torch.all(torch.transpose(dense_hybrid_csr, batch_dims, batch_dims + 1) == dense_hybrid_csc)

alexsamardzic added a commit that referenced this pull request Dec 5, 2022
ghstack-source-id: b40f9c43ef1b231507ad3c0122231211182fa0d3
Pull Request resolved: #90177
alexsamardzic added the "module: sparse" (Related to torch.sparse) label Dec 5, 2022
Comment on lines 580 to 583
compressed_indices.unsqueeze_(0);
plain_indices.unsqueeze_(0);
values = values.unsqueeze_(0);
dense = dense.unsqueeze_(0);
Collaborator:

nit: mixed-style in-place unsqueeze. It is probably better to do values.unsqueeze_(...); dense.unsqueeze_(...) for more clarity.
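A quick Python illustration of why the assignment is redundant (the snippet under review is C++, but ATen's in-place ops behave the same way in both languages): unsqueeze_ mutates its receiver and returns that very tensor.

import torch

values = torch.randn(2, 3)
ret = values.unsqueeze_(0)    # in-place: mutates values and returns it
assert ret is values          # so "values = values.unsqueeze_(0)" adds nothing
assert values.shape == (1, 2, 3)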

Comment on lines 603 to 614
for (auto batch : c10::irange(n_batch)) {
Tensor batch_indices = at::_convert_indices_from_csr_to_coo(
compressed_indices[batch],
plain_indices[batch],
false,
self.layout() == kSparseCsc);
auto batch_row_indices = batch_indices.select(0, 0);
auto batch_col_indices = batch_indices.select(0, 1);

auto offsets = batch_col_indices + batch_row_indices * self.size(batch_ndim + 1);
dense[batch].index_add_(0, offsets, values[batch]);
}
nikitaved (Collaborator) commented Dec 5, 2022:

It is possible to remove the for-loop and use compressed_indices_to_plain_indices from https://github.com/pytorch/pytorch/pull/88078/files#r1013344803, reimplemented in CPP as a utility function, so that a single call to _convert_indices_from_csr_to_coo is sufficient. A similar pattern, but reversed, was used in #82122.
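A rough Python sketch of the idea, relying on PyTorch's batched-CSR invariant of equal nnz per batch and on the internal op torch._convert_indices_from_csr_to_coo (the helper name is hypothetical; the actual suggestion is a C++ utility): offset each batch's row pointers so the whole batch reads as one tall CSR matrix, letting the conversion run once.

import torch

def batched_csr_to_coo_indices(crow, col, nrows):
    # crow: (B, nrows + 1), col: (B, nnz); batched CSR has equal nnz per batch.
    B, nnz = col.shape
    # Shift each batch's row pointers by the nonzeros of preceding batches,
    # turning the batch into one tall (B * nrows, ncols) CSR matrix.
    offsets = (torch.arange(B, dtype=crow.dtype) * nnz).view(B, 1)
    flat_crow = torch.cat([(crow[:, :-1] + offsets).reshape(-1),
                           torch.tensor([B * nnz], dtype=crow.dtype)])
    coo = torch._convert_indices_from_csr_to_coo(flat_crow, col.reshape(-1))
    # Split the tall-matrix row indices back into (batch, row) pairs.
    return torch.stack([coo[0] // nrows, coo[0] % nrows, coo[1]])

# Sanity check via dense reconstruction (same sparsity pattern per batch).
mat = torch.randn(3, 5).relu()
dense = mat.repeat(4, 1, 1)
csr = dense.to_sparse_csr()
idx = batched_csr_to_coo_indices(csr.crow_indices(), csr.col_indices(), dense.size(-2))
recon = torch.zeros_like(dense)
recon[idx[0], idx[1], idx[2]] = csr.values().reshape(-1)
assert torch.equal(recon, dense)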

alexsamardzic (Collaborator Author):

Thanks for the pointer - as mentioned in my comment, that would be a good optimization to have.

alexsamardzic added a commit that referenced this pull request Dec 6, 2022
ghstack-source-id: b40f9c43ef1b231507ad3c0122231211182fa0d3
Pull Request resolved: #90177
cc nikitaved pearu cpuhrsch amjames bhosmer

[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Dec 6, 2022
ghstack-source-id: a92221a72a0a9d8cfeb54062f9b256acdf028022
Pull Request resolved: #90177
alexsamardzic changed the title from "Implement hybrid CSR/CSC to dense conversions." to "Implement hybrid compressed sparse to dense conversions." Dec 6, 2022
alexsamardzic (Collaborator Author) commented:

Added support for hybrid BSR/BSC to dense conversions. The unification with the CSR/CSC to dense conversions, as well as the optimization to eliminate looping over batches, are next to come.
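For reference, a minimal sketch of a hybrid BSR tensor and the conversion added here (all sizes illustrative): the values carry the (nnz,) + blocksize + dense_dims layout, and to_dense materializes the sparse, block, and dense dimensions together.

import torch

crow = torch.tensor([0, 1, 2])   # 2 block rows, one nonzero block each
col = torch.tensor([0, 1])
# values: nnz=2 blocks, blocksize (2, 2), one trailing dense dim of size 3
vals = torch.randn(2, 2, 2, 3)
bsr = torch.sparse_compressed_tensor(
    crow, col, vals, (4, 4, 3), layout=torch.sparse_bsr)
assert bsr.to_dense().shape == (4, 4, 3)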

alexsamardzic added a commit that referenced this pull request Dec 7, 2022
ghstack-source-id: a92221a72a0a9d8cfeb54062f9b256acdf028022
Pull Request resolved: #90177
cc nikitaved pearu cpuhrsch amjames bhosmer

[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Dec 7, 2022
ghstack-source-id: 10794a6aa30bd21f1144a33f80fa0ac4ef4b319f
Pull Request resolved: #90177
alexsamardzic (Collaborator Author) commented:

Unified CSR/CSC and BSR/BSC conversion to dense.

amjames (Collaborator) left a comment:

Great start, I found one error, and I have a cleaner-code note. See inline comments.

}
if (self.dim() > 3) {
if (batch_ndim > 0) {
Collaborator:

This condition should either be batch_ndim > 1, or there should be a note that when there is already a single batch dim (batch_ndim == 1) the flatten ops below result in flatten(0, 0) and return the input tensors unmodified, hence we are not doing anything incorrect or unnecessary.
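For clarity, a one-liner confirming that flattening over a single dimension leaves the shape untouched:

import torch

t = torch.randn(2, 3)
assert t.flatten(0, 0).shape == (2, 3)  # start == end: nothing to flatten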

dense = dense.flatten(1, 2);
}
else {
blocksize = {values.size(batch_ndim + 2), values.size(batch_ndim + 3)};
Collaborator:

values was reshaped above so it contains exactly one batch dimension, but batch_ndim tracks the number of batch dims in the original shape.

After the flatten/unsqueeze is applied to values, the generalized shape would be (n_batch, nnz) + blocksize + dense_dims, where n_batch and nnz are scalars, blocksize is a length-2 tuple, and dense_dims is a tuple with as many members as there are dense dimensions in the tensor, possibly zero.

So at this point the blocksizes can always be computed as values.size(2) and values.size(3) for any input which enters this region.
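A quick shape check of that invariant (a sketch assuming batched dense-to-BSR conversion as enabled by this line of work; sizes illustrative): with exactly one leading batch dim, the blocksize always sits at positions 2 and 3 of values.

import torch

dense = torch.ones(5, 4, 6)   # one batch dim of size 5
bsr = dense.to_sparse_bsr((2, 3))
v = bsr.values()              # (n_batch, nnz, *blocksize)
assert (v.size(2), v.size(3)) == (2, 3)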

alexsamardzic (Collaborator Author):

Fixed - thanks! Also updated, in the next commit below, according to your other comments.

dense = dense.reshape({n_batch, -1, values.size(-2), values.size(-1)});

int64_t nrows, ncols;
std::array<int64_t, 2> blocksize = {1, 1};
Collaborator:

It looks like blocksize is only used to compute the dense shape in the else block of the conditional below, so it can be moved into that scope and need not exist at this level.

alexsamardzic added a commit that referenced this pull request Dec 8, 2022
ghstack-source-id: 10794a6aa30bd21f1144a33f80fa0ac4ef4b319f
Pull Request resolved: #90177
cc nikitaved pearu cpuhrsch amjames bhosmer

[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Dec 8, 2022
ghstack-source-id: 0a2917c48e28ac8f2f43ef183ae5d29d0f5dd5d6
Pull Request resolved: #90177
alexsamardzic (Collaborator Author) commented:

Removed looping over batches. Please do not review yet, until comments are added and tests are enabled.

cc nikitaved pearu cpuhrsch amjames bhosmer

[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Dec 8, 2022
ghstack-source-id: 61a7db0b67df3e1552e9766ee302a45a701ee000
Pull Request resolved: #90177
cc nikitaved pearu cpuhrsch amjames bhosmer

[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Dec 8, 2022
ghstack-source-id: 33f48b6672f74513d974234e7f59b83f5fbf9da0
Pull Request resolved: #90177
alexsamardzic added a commit that referenced this pull request Jan 3, 2023
ghstack-source-id: 3b141893eab378f8b743fa14d8561f2427d390ba
Pull Request resolved: #90177
alexsamardzic added the "release notes: sparse" label Jan 3, 2023
pearu (Collaborator) left a comment:

Thanks, @alexsamardzic, this looks good!

I have a suggestion to add support for

to_sparse(layout=torch.sparse_coo, dense_dim=...)

as low-hanging fruit.

Also, I have a question about the non-sparse-non-dense dimensions in to_sparse_csr/csc/bsr/bsc.
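A sketch of how the suggested to_sparse call behaves (the layout=/dense_dim= keywords are part of this line of work; sizes illustrative), splitting the trailing dimension off as a dense one:

import torch

t = torch.zeros(3, 3, 2)
t[0, 0] = t[1, 2] = 1.0
s = t.to_sparse(layout=torch.sparse_coo, dense_dim=1)
assert (s.sparse_dim(), s.dense_dim()) == (2, 1)
assert torch.equal(s.to_dense(), t)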

Inline review comments (since resolved) on: aten/src/ATen/native/TensorConversions.cpp, torch/_tensor_docs.py (two threads), aten/src/ATen/native/sparse/SparseTensor.cpp (two threads)
…ns."

cc nikitaved pearu cpuhrsch amjames bhosmer gujinghui PenghuiCheng XiaobingSuper jianyuh jgong5 mingfeima sanchitintel ashokei jingxu10 min-jean-cho yanbing-j Guobing-Chen Xia-Weiwen

[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Jan 5, 2023
ghstack-source-id: 819bec92c32e05b3e87a6876d39ade9c7192c759
Pull Request resolved: #90177
alexsamardzic changed the title from "Implement hybrid compressed sparse to/from dense conversions." to "Implement hybrid sparse to/from dense conversions." Jan 5, 2023
values=tensor([[[[1.]],

[[1.]]],

cpuhrsch (Contributor):

Is this additional whitespace really how we print BSR Tensors?

@alexsamardzic - you can see the preview of the Python documentation as part of the Helpful Links. You'll just need to search for the function name that you changed to see how it renders on the web.

Collaborator:

@cpuhrsch printing of BSR tensors is defined by how we print strided tensors. For example:

>>> torch.ones(4, 4).to_sparse_bsr((2, 2)).values()
tensor([[[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]]])

Personally, I also find the whitespace annoying, but I guess this is supposed to help with viewing high-dimensional tensors.

alexsamardzic (Collaborator Author):

Indeed, I've copied the output from the interpreter for these examples.

@@ -5467,15 +5517,43 @@ def callable(a, b) -> number
>>> sparse._nnz()
25

>>> dense = torch.zeros(3, 3, 1, 1)
>>> dense[0, 0] = dense[1, 2] = dense[2, 1] = 0
Contributor:

Should this be = 1?

alexsamardzic (Collaborator Author):

It should; strange that the output is correct, i.e. as if that = 1 were there. In any case, an update is pushed.

cc nikitaved pearu cpuhrsch amjames bhosmer gujinghui PenghuiCheng XiaobingSuper jianyuh jgong5 mingfeima sanchitintel ashokei jingxu10 min-jean-cho yanbing-j Guobing-Chen Xia-Weiwen

[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Jan 6, 2023
ghstack-source-id: ed051929490979d1a948c1ded30419f956f2b203
Pull Request resolved: #90177
pearu (Collaborator) left a comment:

I have a couple of nits and questions. Thanks, @alexsamardzic!

Inline review comments (since resolved) on: aten/src/ATen/native/TensorConversions.cpp (three threads), test/test_sparse_csr.py, torch/_tensor_docs.py
alexsamardzic added a commit that referenced this pull request Jan 6, 2023
ghstack-source-id: ed051929490979d1a948c1ded30419f956f2b203
Pull Request resolved: #90177
cc nikitaved pearu cpuhrsch amjames bhosmer gujinghui PenghuiCheng XiaobingSuper jianyuh jgong5 mingfeima sanchitintel ashokei jingxu10 min-jean-cho yanbing-j Guobing-Chen Xia-Weiwen

[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Jan 6, 2023
ghstack-source-id: bf398893331f43a4e4594a6b26f1aa075db1a762
Pull Request resolved: #90177
pearu (Collaborator) left a comment:

I have a couple more clean-up nits, but otherwise this looks good to me!

Inline review comments (since resolved) on: aten/src/ATen/native/TensorConversions.cpp (three threads)
cc nikitaved pearu cpuhrsch amjames bhosmer gujinghui PenghuiCheng XiaobingSuper jianyuh jgong5 mingfeima sanchitintel ashokei jingxu10 min-jean-cho yanbing-j Guobing-Chen Xia-Weiwen

[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Jan 6, 2023
ghstack-source-id: 20e9a46245463eb384b4b6bfaccc3d3f99d751d2
Pull Request resolved: #90177
alexsamardzic added a commit that referenced this pull request Jan 11, 2023
ghstack-source-id: 20e9a46245463eb384b4b6bfaccc3d3f99d751d2
Pull Request resolved: #90177
cc nikitaved pearu cpuhrsch amjames bhosmer gujinghui PenghuiCheng XiaobingSuper jianyuh jgong5 mingfeima sanchitintel ashokei jingxu10 min-jean-cho yanbing-j Guobing-Chen Xia-Weiwen

[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Jan 11, 2023
ghstack-source-id: ed8d8ba2b373c4126794d21c20cf7179fcc4313b
Pull Request resolved: #90177
pearu (Collaborator) left a comment:

LGTM! Thanks, @alexsamardzic!

cpuhrsch (Contributor) commented:

@pytorchbot merge

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

facebook-github-bot deleted the gh/alexsamardzic/4/head branch June 8, 2023 15:10
Labels: ciflow/trunk, Merged, module: mkldnn, module: sparse, open source, release notes: sparse