Implement hybrid sparse to/from dense conversions. #90177
Conversation
[ghstack-poisoned]
Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/90177.
Note: links to docs will display an error until the docs builds have been completed. No failures as of commit 0498f42. This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: 40aa7e2c6e47c6d67fffc954dd6c31dd17a529cc Pull Request resolved: #90177
Just the first version of hybrid CSR/CSC to dense conversions. It mimics BSR/BSC to dense conversions, so it remains to:
Testing script (create a CSR tensor, then create a CSC tensor with transposed rows/columns, convert both to dense, and check that the results are transposes of each other along the sparse dimensions):

import torch

BATCH_DIMS_MIN = 1
BATCH_DIMS_MAX = 3
BATCH_SIZE_MIN = 1
BATCH_SIZE_MAX = 5
SPARSE_SIZE_MIN = 1
SPARSE_SIZE_MAX = 10
HYBRID_DIMS_MIN = 1
HYBRID_DIMS_MAX = 5
HYBRID_SIZE_MIN = 1
HYBRID_SIZE_MAX = 10
NTESTS = 10000

for i in range(NTESTS):
    batch_dims = torch.randint(BATCH_DIMS_MIN, BATCH_DIMS_MAX + 1, (1,)).item()
    batch_size = tuple(torch.randint(BATCH_SIZE_MIN, BATCH_SIZE_MAX + 1, (batch_dims,)).tolist())
    (nrows, ncols) = sparse_size = tuple(torch.randint(SPARSE_SIZE_MIN, SPARSE_SIZE_MAX + 1, (2,)).tolist())
    hybrid_dims = torch.randint(HYBRID_DIMS_MIN, HYBRID_DIMS_MAX + 1, (1,)).item()
    hybrid_size = tuple(torch.randint(HYBRID_SIZE_MIN, HYBRID_SIZE_MAX + 1, (hybrid_dims,)).tolist())

    # Build a sparsified dense tensor, repeat it over the batch dims,
    # and convert to CSR.
    dense = torch.randn(sparse_size, dtype=torch.float32)
    dense = dense * dense.relu().bool()
    dense = dense.repeat(batch_size + (1, 1))
    csr = dense.to_sparse_csr()

    # Attach hybrid (dense) dimensions to the values.
    hybrid_values = torch.randn(csr.values().shape + hybrid_size, dtype=torch.float32)
    hybrid_csr = torch.sparse_compressed_tensor(
        csr.crow_indices(),
        csr.col_indices(),
        hybrid_values,
        batch_size + (nrows, ncols) + hybrid_size,
        layout=torch.sparse_csr,
        dtype=torch.float32)
    dense_hybrid_csr = hybrid_csr.to_dense()

    # Reinterpret the same indices as CSC: the compressed indices now run
    # over columns, so the sparse dimensions are swapped.
    hybrid_csc = torch.sparse_compressed_tensor(
        csr.crow_indices(),
        csr.col_indices(),
        hybrid_values,
        batch_size + (ncols, nrows) + hybrid_size,
        layout=torch.sparse_csc,
        dtype=torch.float32)
    dense_hybrid_csc = hybrid_csc.to_dense()

    assert torch.all(torch.transpose(dense_hybrid_csr, batch_dims, batch_dims + 1) == dense_hybrid_csc)
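For reference, the conversion the script exercises can be mimicked in plain Python: a (non-batched) hybrid CSR tensor stores one dense sub-block per specified element, and `to_dense` scatters those blocks into a zero-initialized dense container. A minimal sketch with a hypothetical helper name (not the PR's implementation; at most one hybrid dimension, for simplicity):

```python
def csr_to_dense(crow_indices, col_indices, values, nrows, ncols):
    """Expand a (hybrid) CSR layout into nested dense lists.

    values[k] may itself be a list (the hybrid/dense part); it is copied
    as-is into position (row, col) of the result.
    """
    def make_zero():
        # Hybrid entries default to a zero block of the same shape;
        # a fresh object per cell avoids aliasing.
        if values and isinstance(values[0], list):
            return [0] * len(values[0])
        return 0

    dense = [[make_zero() for _ in range(ncols)] for _ in range(nrows)]
    for row in range(nrows):
        # crow_indices[row]:crow_indices[row + 1] delimits this row's
        # specified elements.
        for k in range(crow_indices[row], crow_indices[row + 1]):
            dense[row][col_indices[k]] = values[k]
    return dense
```

For example, with a hybrid dimension of size 2, `csr_to_dense([0, 1, 2], [1, 0], [[1, 2], [3, 4]], 2, 2)` produces `[[[0, 0], [1, 2]], [[3, 4], [0, 0]]]`.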
[ghstack-poisoned]
ghstack-source-id: b40f9c43ef1b231507ad3c0122231211182fa0d3 Pull Request resolved: #90177
compressed_indices.unsqueeze_(0);
plain_indices.unsqueeze_(0);
values = values.unsqueeze_(0);
dense = dense.unsqueeze_(0);
nit: mixed-style in-place unsqueeze. It is probably better to do values.unsqueeze_(...); dense.unsqueeze_(...) for more clarity.
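The nit can be illustrated with a toy tensor-like class (purely illustrative, not ATen code): an in-place `unsqueeze_` mutates its receiver and returns it, so assigning the result back, as in `values = values.unsqueeze_(0)`, rebinds the name to the very same object and the assignment is redundant:

```python
class Toy:
    """Minimal stand-in for a tensor supporting in-place unsqueeze."""
    def __init__(self, shape):
        self.shape = tuple(shape)

    def unsqueeze_(self, dim):
        # Insert a size-1 dimension at `dim`, mutating self in place.
        self.shape = self.shape[:dim] + (1,) + self.shape[dim:]
        return self  # in-place ops conventionally return the receiver

values = Toy((3, 4))
alias = values.unsqueeze_(0)  # mixed style: the assignment is redundant...
assert alias is values        # ...because the same object comes back
assert values.shape == (1, 3, 4)
```

Using the bare-call style everywhere makes the mutation explicit and avoids suggesting that a new object is produced.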
for (auto batch : c10::irange(n_batch)) {
  Tensor batch_indices = at::_convert_indices_from_csr_to_coo(
      compressed_indices[batch],
      plain_indices[batch],
      false,
      self.layout() == kSparseCsc);
  auto batch_row_indices = batch_indices.select(0, 0);
  auto batch_col_indices = batch_indices.select(0, 1);

  auto offsets = batch_col_indices + batch_row_indices * self.size(batch_ndim + 1);
  dense[batch].index_add_(0, offsets, values[batch]);
}
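The offsets trick above can be sketched in plain Python (an illustrative model, not the C++ code): COO (row, col) pairs are linearized into offsets into a row-major flattened view, and values are accumulated at those offsets, matching index_add_ semantics:

```python
def scatter_by_offsets(row_indices, col_indices, values, nrows, ncols):
    """Scatter COO-indexed values into a dense (nrows, ncols) layout."""
    # Row-major linearization: element (r, c) lives at r * ncols + c.
    flat = [0] * (nrows * ncols)
    for r, c, v in zip(row_indices, col_indices, values):
        flat[r * ncols + c] += v  # accumulate, like index_add_
    # Reshape the flat buffer back to (nrows, ncols).
    return [flat[i * ncols:(i + 1) * ncols] for i in range(nrows)]
```

In the C++ snippet the same linearization is done vectorized on whole index tensors, with the hybrid value blocks carried along by index_add_.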
It is possible to remove the for-loop and use compressed_indices_to_plain_indices from https://github.com/pytorch/pytorch/pull/88078/files#r1013344803, reimplemented in C++ as a utility function, so that a single call to _convert_indices_from_csr_to_coo is sufficient. A similar but reversed pattern was used in #82122.
Thanks for the pointer - as mentioned in my comment, that would be a good optimization to have.
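The suggested optimization amounts to mapping per-batch compressed indices into one global index space, so that a single CSR-to-COO conversion covers all batches. A plain-Python sketch of the idea (hypothetical helper, not the referenced utility):

```python
def flatten_batched_csr(batched_crow, batched_col):
    """Merge per-batch CSR indices into one CSR over stacked rows.

    batched_crow[b] / batched_col[b] describe batch b; the result
    describes a single matrix whose rows are all batches' rows stacked,
    so one CSR -> COO conversion handles every batch at once.
    """
    crow, col = [0], []
    for b_crow, b_col in zip(batched_crow, batched_col):
        nnz_before = crow[-1]
        # Shift this batch's row pointers past the nonzeros already
        # emitted; b_crow[0] == 0 is dropped to keep crow cumulative.
        crow.extend(nnz_before + p for p in b_crow[1:])
        col.extend(b_col)  # column ids are unchanged across batches
    return crow, col
```

Batch membership is then recoverable from the stacked row index by integer division with the per-batch row count.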
cc nikitaved pearu cpuhrsch amjames bhosmer [ghstack-poisoned]
ghstack-source-id: a92221a72a0a9d8cfeb54062f9b256acdf028022 Pull Request resolved: #90177
Added support for hybrid BSR/BSC to dense conversions. The unification with CSR/CSC to dense conversions, as well as an optimization eliminating the loop over batches, are next to come.
ghstack-source-id: 10794a6aa30bd21f1144a33f80fa0ac4ef4b319f Pull Request resolved: #90177
Unified CSR/CSC and BSR/BSC conversion to dense.
Great start, I found one error, and I have a cleaner-code note. See inline comments.
}
// was: if (self.dim() > 3) {
if (batch_ndim > 0) {
This condition should either be batch_ndim > 1, or there should be a note that when there is already a single batch dim (batch_ndim = 1) the flatten ops below result in flatten(0, 0) and return the input tensors unmodified, hence we are not doing anything incorrect or unnecessary.
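The flatten(0, 0) observation can be checked against a small shape helper (a toy model of flatten's shape rule, not ATen):

```python
def flatten_shape(shape, start, end):
    """Shape produced by flatten(start, end): merge dims start..end."""
    merged = 1
    for d in shape[start:end + 1]:
        merged *= d
    return shape[:start] + (merged,) + shape[end + 1:]

# With a single batch dim, flatten(0, 0) leaves the shape unchanged.
assert flatten_shape((2, 3, 4), 0, 0) == (2, 3, 4)
# With two batch dims, flatten(0, 1) merges them into one.
assert flatten_shape((2, 5, 3, 4), 0, 1) == (10, 3, 4)
```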
  dense = dense.flatten(1, 2);
}
else {
  blocksize = {values.size(batch_ndim + 2), values.size(batch_ndim + 3)};
values was reshaped above so it contains exactly one batch dimension, but batch_ndim tracks the number of batch dims in the original shape. After the flatten/unsqueeze is applied to values, the generalized shape would be (n_batch, nnz) + blocksize + dense_dims, where n_batch and nnz are scalars, blocksize is a length-2 tuple, and dense_dims is a tuple with as many members as there are dense dimensions in the tensor, possibly zero. So at this point the block sizes can always be computed as values.size(2) and values.size(3) for any input which enters this region.
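The invariant described here can be stated as a one-liner (illustrative, helper name assumed): after canonicalization, values always has shape (n_batch, nnz) + blocksize + dense_dims, so the block size lives at fixed positions 2 and 3:

```python
def canonical_values_shape(n_batch, nnz, blocksize, dense_dims=()):
    # Shape of values after the flatten/unsqueeze canonicalization.
    return (n_batch, nnz) + tuple(blocksize) + tuple(dense_dims)

shape = canonical_values_shape(6, 9, (2, 3), (4,))
# blocksize is recovered from positions 2 and 3, not batch_ndim + 2/3.
assert (shape[2], shape[3]) == (2, 3)
```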
Fixed - thanks! Also updated, in the next commit below, according to your other comments.
dense = dense.reshape({n_batch, -1, values.size(-2), values.size(-1)});

int64_t nrows, ncols;
std::array<int64_t, 2> blocksize = {1, 1};
It looks like blocksize is only used to compute the dense shape in the else block of the conditional below, so it can be moved into that scope and need not exist at this level.
ghstack-source-id: 0a2917c48e28ac8f2f43ef183ae5d29d0f5dd5d6 Pull Request resolved: #90177
Removed looping over batches. Please do not review yet, until comments are added and tests are enabled.
ghstack-source-id: 61a7db0b67df3e1552e9766ee302a45a701ee000 Pull Request resolved: #90177
ghstack-source-id: 33f48b6672f74513d974234e7f59b83f5fbf9da0 Pull Request resolved: #90177
ghstack-source-id: 3b141893eab378f8b743fa14d8561f2427d390ba Pull Request resolved: #90177
Thanks, @alexsamardzic, this looks good!
I have a suggestion to add support for to_sparse(layout=torch.sparse_coo, dense_dim=...) as low-hanging fruit. Also, I have a question about the non-sparse-non-dense dimensions in to_sparse_csr/csc/bsr/bsc.
…ns." cc nikitaved pearu cpuhrsch amjames bhosmer gujinghui PenghuiCheng XiaobingSuper jianyuh jgong5 mingfeima sanchitintel ashokei jingxu10 min-jean-cho yanbing-j Guobing-Chen Xia-Weiwen [ghstack-poisoned]
ghstack-source-id: 819bec92c32e05b3e87a6876d39ade9c7192c759 Pull Request resolved: #90177
values=tensor([[[[1.]],

                [[1.]]],
Is this additional whitespace really how we print BSR Tensors?
@alexsamardzic - you can see the preview of the Python documentation as part of the Helpful Links. You'll just need to search for the function name that you changed to see how it renders on the web.
@cpuhrsch printing of BSR tensors is defined by how we print strided tensors. For example:
>>> torch.ones(4, 4).to_sparse_bsr((2, 2)).values()
tensor([[[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]]])
Personally, I also find the whitespace annoying, but I guess it is supposed to help with viewing high-dimensional tensors.
Indeed, I've copied output from the interpreter for these examples.
torch/_tensor_docs.py (outdated)

@@ -5467,15 +5517,43 @@ def callable(a, b) -> number
>>> sparse._nnz()
25

>>> dense = torch.zeros(3, 3, 1, 1)
>>> dense[0, 0] = dense[1, 2] = dense[2, 1] = 0
Should this be = 1?
It should; strange that the output is correct, i.e. as if that = 1 was there. In any case, an update is pushed.
ghstack-source-id: ed051929490979d1a948c1ded30419f956f2b203 Pull Request resolved: #90177
I have a couple of nits and questions. Thanks, @alexsamardzic!
ghstack-source-id: bf398893331f43a4e4594a6b26f1aa075db1a762 Pull Request resolved: #90177
I have a couple more clean-up nits but otherwise this looks good to me!
ghstack-source-id: 20e9a46245463eb384b4b6bfaccc3d3f99d751d2 Pull Request resolved: #90177
ghstack-source-id: ed8d8ba2b373c4126794d21c20cf7179fcc4313b Pull Request resolved: #90177
LGTM! Thanks, @alexsamardzic!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
cc @nikitaved @pearu @cpuhrsch @amjames @bhosmer @gujinghui @PenghuiCheng @XiaobingSuper @jianyuh @jgong5 @mingfeima @sanchitintel @ashokei @jingxu10 @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen