Support sparse inputs for torch.block_diag #31942
Labels: function request, module: sparse, module: tensor creation, triaged
zou3519 added the feature, module: operators, module: sparse and triaged labels on Jan 9, 2020
I'd like to work on this and make a PR.

Adapting #31932 (comment), I'm currently using:
mruberry added the function request and module: tensor creation labels and removed the feature and module: operators (deprecated) labels on Oct 10, 2020
@taralloc May I suggest the following changes for increased performance without dependence on itertools and numpy?

```python
import torch

def block_diagonal(*arrs):
    """Build a sparse COO block-diagonal matrix from 2-D (dense) tensors."""
    bad_args = [k for k, a in enumerate(arrs)
                if not (isinstance(a, torch.Tensor) and a.ndim == 2)]
    if bad_args:
        raise ValueError(
            "arguments in the following positions must be 2-dimensional tensors: %s" % bad_args)
    device = arrs[0].device
    indices, values = [], []
    r, c = 0, 0  # running row/column offset of the current block
    for a in arrs:
        rr, cc = a.shape  # plain Python ints
        rows = torch.arange(r, r + rr, device=device)
        cols = torch.arange(c, c + cc, device=device)
        # Row-major (row, col) coordinates of every entry of this block,
        # matching the order of a.flatten().
        index = torch.stack((rows.repeat_interleave(cc), cols.repeat(rr)), dim=0)
        indices.append(index)
        values.append(a.flatten())
        r += rr
        c += cc
    # sparse_coo_tensor keeps the dtype and device of the values, so neither a
    # separate CPU/CUDA branch nor a cast to float64 (DoubleTensor) is needed.
    return torch.sparse_coo_tensor(torch.cat(indices, dim=1), torch.cat(values), (r, c))
```
IvanYashchuk changed the title from "sparse torch.blkdiag method" to "Support sparse inputs for torch.block_diag" on Jan 6, 2022
```python
import torch

def batched_block_diagonal(sparse_coo):
    """
    @param sparse_coo: sparse COO tensor of shape [bs, h, w]
    @return: sparse COO tensor of shape [bs*h, bs*w]
    """
    sparse_coo = sparse_coo.coalesce()  # .indices()/.values() require a coalesced tensor
    bs, h, w = sparse_coo.size()
    indices = sparse_coo.indices()  # [3, nnz]: (batch, row, col)
    # [b, h, w] -> [b*h, b*w]
    new_shape = (bs * h, bs * w)
    # indices: (b, i, j) -> (i + b*h, j + b*w)
    new_indices = torch.stack((indices[1] + indices[0] * h,
                               indices[2] + indices[0] * w), dim=0)
    # Keep the input's values; use torch.ones_like(sparse_coo.values(), dtype=torch.bool)
    # instead if only the sparsity pattern is needed.
    return torch.sparse_coo_tensor(new_indices, sparse_coo.values(), new_shape)
```
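One way to sanity-check the (b, i, j) -> (i + b*h, j + b*w) index mapping is to apply it to a small dense batch converted with `to_sparse()` and compare the result against the dense `torch.block_diag`. This is a minimal, self-contained sketch; the sizes and random data are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
bs, h, w = 3, 2, 4
x = torch.randn(bs, h, w)
x[x < 0] = 0  # zero out some entries so the batch is actually sparse

coo = x.to_sparse().coalesce()  # indices: [3, nnz] = (batch, row, col)
b, i, j = coo.indices()
# Shift each block down by b*h rows and right by b*w columns.
flat = torch.sparse_coo_tensor(torch.stack((i + b * h, j + b * w)),
                               coo.values(), (bs * h, bs * w))

expected = torch.block_diag(*x.unbind(0))  # dense reference
print(torch.equal(flat.to_dense(), expected))  # True
```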
🚀 Feature
The blkdiag method is clearly defined in #31932.
#31932 suggests blkdiag should create a dense Tensor, which may also be helpful in some cases.
However, for graph neural networks we always want a sparse block-diagonal tensor rather than a dense one: a dense block-diagonal tensor is even slower than multiplying the submatrices one by one, and it easily causes out-of-memory (OOM) errors.
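To make the dense-vs-sparse point concrete, the sketch below multiplies stacked features by a block-diagonal matrix in three ways: block by block, and via a sparse block-diagonal matrix, while also reporting how many entries the dense `torch.block_diag` result stores. The block sizes and random data are arbitrary; `to_sparse()` is used only for illustration, since it still materializes the dense matrix first (avoiding that step is exactly what a sparse blkdiag would do).

```python
import torch

torch.manual_seed(0)
# Three "graphs" with different node counts; each block is an adjacency-like matrix.
sizes = [2, 3, 4]
blocks = [torch.randn(n, n) for n in sizes]
x = torch.randn(sum(sizes), 5)  # stacked node features, one row per node

# Per-block reference: multiply each block by its own slice of x.
loop_out = torch.cat([a @ xs for a, xs in zip(blocks, torch.split(x, sizes))])

# Dense block-diagonal: stores sum(sizes)**2 entries, most of them zeros.
dense = torch.block_diag(*blocks)
# Sparse block-diagonal: stores only the entries inside the blocks.
sparse = dense.to_sparse().coalesce()
sparse_out = torch.sparse.mm(sparse, x)

print(dense.numel(), sparse.values().numel())
print(torch.allclose(sparse_out, loop_out, atol=1e-5))
```

The dense matrix grows quadratically with the total number of nodes, while the sparse one grows only with the sum of the block sizes squared.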
A use case can be found in https://stackoverflow.com/a/59641321/7699035
This matches the convention of the widely used pytorch_geometric library, where the node features x1, x2, x3, ..., xn of different graphs are concatenated into one large tensor and a batch index is provided. I've also asked the author of pytorch_geometric about this problem here: rusty1s/pytorch_scatter#95.
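That batching convention can be sketched in plain PyTorch. This is not pytorch_geometric's code, only a minimal illustration that assumes each graph comes as a node-feature matrix plus an `edge_index` of local node ids; the names `xs`, `edge_indices`, `batch` and `adj` are illustrative. Offsetting each graph's node ids by the number of nodes that precede it makes the combined adjacency block-diagonal, so it can stay sparse.

```python
import torch

# Hypothetical mini-batch of three graphs: node features and an edge list
# (edge_index: [2, num_edges], node ids local to each graph).
xs = [torch.randn(2, 4), torch.randn(3, 4), torch.randn(1, 4)]
edge_indices = [torch.tensor([[0, 1], [1, 0]]),
                torch.tensor([[0, 1, 2], [1, 2, 0]]),
                torch.empty(2, 0, dtype=torch.long)]

num_nodes = torch.tensor([xi.size(0) for xi in xs])
# Node features are concatenated; `batch` maps each row back to its graph.
x = torch.cat(xs)
batch = torch.repeat_interleave(torch.arange(len(xs)), num_nodes)
# Offset each graph's node ids by the number of nodes before it.
offsets = torch.cumsum(num_nodes, 0) - num_nodes
edge_index = torch.cat([e + off for e, off in zip(edge_indices, offsets)], dim=1)

# The batched adjacency is block-diagonal and stays sparse.
n = int(num_nodes.sum())
adj = torch.sparse_coo_tensor(edge_index, torch.ones(edge_index.size(1)), (n, n))
print(batch.tolist())  # [0, 0, 1, 1, 1, 2]
```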
Pitch
This issue is for something like torch.sparse.blkdiag rather than torch.blkdiag.
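A sparse-input variant mostly needs index bookkeeping. The following is a hypothetical sketch (the name `sparse_block_diag` is illustrative, not an existing PyTorch function) that assumes every input is a 2-D sparse COO tensor: each input is coalesced, its indices are shifted by the running row/column offsets, and everything is concatenated into one sparse COO tensor without ever materializing a dense matrix.

```python
import torch

def sparse_block_diag(*mats):
    """Hypothetical sparse-input block_diag sketch (not a PyTorch API).

    Each input is a 2-D sparse COO tensor; the result is a sparse COO tensor
    whose diagonal blocks are the inputs, in order.
    """
    indices, values = [], []
    row_off, col_off = 0, 0
    for m in mats:
        m = m.coalesce()
        # Shift this block's (row, col) coordinates to its diagonal position.
        offset = torch.tensor([[row_off], [col_off]], device=m.device)
        indices.append(m.indices() + offset)
        values.append(m.values())
        row_off += m.size(0)
        col_off += m.size(1)
    return torch.sparse_coo_tensor(torch.cat(indices, dim=1), torch.cat(values),
                                   (row_off, col_off))

a = torch.tensor([[1., 0.], [0., 2.]]).to_sparse()
b = torch.tensor([[0., 3., 0.]]).to_sparse()
out = sparse_block_diag(a, b)
print(torch.equal(out.to_dense(), torch.block_diag(a.to_dense(), b.to_dense())))  # True
```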
Alternatives
The operation in https://stackoverflow.com/a/59641321/7699035 is clearly parallelizable. I want an efficient solution in PyTorch, and a torch.sparse.blkdiag method seems like the best one.
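For the special case where every block has the same shape, the per-block products are already parallel without any block-diagonal matrix: stack the blocks and use `torch.bmm`. The sketch below (random data, arbitrary sizes) checks that this matches multiplying by the dense `torch.block_diag`. It does not cover graphs of different sizes, which is the case this issue is about.

```python
import torch

torch.manual_seed(0)
bs, n, f = 4, 3, 5
A = torch.randn(bs, n, n)  # batch of equal-sized blocks
X = torch.randn(bs, n, f)  # per-block features

batched = torch.bmm(A, X)  # one batched kernel, no block-diagonal matrix
blockdiag = torch.block_diag(*A.unbind(0)) @ X.reshape(bs * n, f)
print(torch.allclose(batched.reshape(bs * n, f), blockdiag, atol=1e-5))
```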
cc @vincentqb @aocsa @nikitaved @pearu @mruberry