
Support sparse inputs for torch.block_diag #31942

Open
ThyrixYang opened this issue Jan 8, 2020 · 4 comments
Labels
function request A request for a new function or the addition of new arguments/modes to an existing function. module: sparse Related to torch.sparse module: tensor creation triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments


ThyrixYang commented Jan 8, 2020

🚀 Feature

The blkdiag method is defined clearly in #31932

#31932 suggests blkdiag should create a dense Tensor, which may also be helpful in some cases.
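For context, the dense operation behaves as follows (sketch using today's torch.block_diag, which implements the #31932 proposal):

```python
import torch

a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5]])
out = torch.block_diag(a, b)
# out:
# tensor([[1, 2, 0],
#         [3, 4, 0],
#         [0, 0, 5]])
```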

However, for graph neural networks we almost always want a sparse block tensor rather than a dense one: a dense block tensor is even slower than multiplying the submatrices one by one, and it easily causes OOM.

A use case can be found in https://stackoverflow.com/a/59641321/7699035

This is consistent with the popular pytorch_geometric package, where node features x1, x2, x3, ..., xn of different graphs are concatenated into one large tensor and a batch index is provided. I've also asked the author of pytorch_geometric about this in rusty1s/pytorch_scatter#95.
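The batching scheme referred to above can be sketched like this (the names x1, x2, batch are illustrative of the pytorch_geometric convention, not its API):

```python
import torch

x1 = torch.randn(3, 8)                 # graph 1: 3 nodes, 8 features
x2 = torch.randn(2, 8)                 # graph 2: 2 nodes, 8 features
x = torch.cat([x1, x2], dim=0)         # all nodes stacked: shape (5, 8)
batch = torch.tensor([0, 0, 0, 1, 1])  # node -> graph assignment
```

A block-diagonal adjacency matrix over the 5 stacked nodes then treats the batch as one big disconnected graph.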

Pitch

This issue is for something like torch.sparse.blkdiag rather than torch.blkdiag.
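As a reference for the intended semantics (the name sparse_block_diag below is illustrative, not a real torch API): a naive version can go through the dense op and convert back, whereas a real implementation would build the COO indices directly and never materialize the dense result.

```python
import torch

def sparse_block_diag(*tensors):
    # Naive reference only: densify, apply the dense op, re-sparsify.
    dense = torch.block_diag(*[t.to_dense() if t.is_sparse else t
                               for t in tensors])
    return dense.to_sparse()
```

For example, `sparse_block_diag(a.to_sparse(), b)` returns a sparse COO tensor whose `to_dense()` equals `torch.block_diag(a, b)`.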

Alternatives

The operation in https://stackoverflow.com/a/59641321/7699035 is clearly parallelizable, and I want an efficient solution in PyTorch; a torch.sparse.blkdiag method seems like the best solution.
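For comparison, SciPy already provides this operation for its sparse matrices; a torch.sparse.blkdiag could mirror these semantics (sketch, assuming scipy is available):

```python
import numpy as np
from scipy.sparse import block_diag, coo_matrix

a = coo_matrix(np.eye(2))
b = coo_matrix(np.ones((2, 3)))
out = block_diag([a, b])  # 4x5 sparse result with a and b on the diagonal
```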

cc @vincentqb @aocsa @nikitaved @pearu @mruberry

@zou3519 zou3519 added feature A request for a proper, new feature. module: operators module: sparse Related to torch.sparse triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Jan 9, 2020
@ThyrixYang
Author

I'd like to work on this and make a PR.

@taralloc

Adapting #31932 (comment), I'm currently using:

import itertools

import numpy as np
import torch

def block_diag_sparse(*arrs):
    bad_args = [k for k in range(len(arrs))
                if not (isinstance(arrs[k], torch.Tensor) and arrs[k].ndim == 2)]
    if bad_args:
        raise ValueError("arguments in the following positions must be 2-dimensional tensors: %s" % bad_args)

    shapes = torch.tensor([a.shape for a in arrs])

    i = []
    v = []
    r, c = 0, 0
    for k, (rr, cc) in enumerate(shapes):
        # (row, col) pairs in row-major order, matching flatten() below
        i += [torch.LongTensor(list(itertools.product(np.arange(r, r + rr), np.arange(c, c + cc)))).t()]
        v += [arrs[k].flatten()]
        r += rr
        c += cc
    # sparse_coo_tensor infers the dtype from the values, and moving the
    # indices to the inputs' device covers both CPU and CUDA
    out = torch.sparse_coo_tensor(torch.cat(i, dim=1).to(arrs[0].device),
                                  torch.cat(v),
                                  torch.sum(shapes, dim=0).tolist())
    return out

@mruberry mruberry added function request A request for a new function or the addition of new arguments/modes to an existing function. module: tensor creation and removed feature A request for a proper, new feature. module: operators (deprecated) labels Oct 10, 2020
@Whadup

Whadup commented Aug 3, 2021

@taralloc May I suggest the following changes for increased performance, without depending on itertools and numpy?

import torch

def block_diagonal(*arrs):
    bad_args = [k for k in range(len(arrs))
                if not (isinstance(arrs[k], torch.Tensor) and arrs[k].ndim == 2)]
    if bad_args:
        raise ValueError("arguments in the following positions must be 2-dimensional tensors: %s" % bad_args)
    shapes = torch.tensor([a.shape for a in arrs])
    i = []
    v = []
    r, c = 0, 0
    for k, (rr, cc) in enumerate(shapes):
        first_index = torch.arange(r, r + rr, device=arrs[0].device)
        second_index = torch.arange(c, c + cc, device=arrs[0].device)
        # each row index repeated cc times, column indices tiled rr times:
        # (row, col) pairs in row-major order, matching flatten() below
        index = torch.stack((first_index.tile((cc, 1)).transpose(0, 1).flatten(),
                             second_index.repeat(rr)), dim=0)
        i += [index]
        v += [arrs[k].flatten()]
        r += rr
        c += cc
    out_shape = torch.sum(shapes, dim=0).tolist()
    # sparse_coo_tensor infers dtype and device from its inputs, so no
    # separate CPU/CUDA branches are needed
    return torch.sparse_coo_tensor(torch.cat(i, dim=1), torch.cat(v), out_shape)

@pearu pearu added this to To do in Sparse tensors Aug 10, 2021
@IvanYashchuk IvanYashchuk changed the title sparse torch.blkdiag method Support sparse inputs for torch.block_diag Jan 6, 2022
@krshrimali krshrimali removed their assignment Feb 10, 2022
@shenshanf

import torch

def batched_block_diagonal(sparse_coo):
    """Turn a batched sparse COO tensor into one block-diagonal sparse matrix.

    @param sparse_coo: sparse COO tensor of shape [bs, h, w]
    @return: sparse COO tensor of shape [bs*h, bs*w]
    """
    sparse_coo = sparse_coo.coalesce()  # indices() requires a coalesced tensor
    nnz = sparse_coo._nnz()
    shape = sparse_coo.size()
    indices = sparse_coo.indices()
    # [b, h, w] -> [b*h, b*w]
    new_shape = (shape[0] * shape[1], shape[0] * shape[2])

    new_indices = torch.empty(2, nnz, device=indices.device, dtype=indices.dtype)
    # indices: [b, h, w] -> [h + b*H, w + b*W]
    new_indices[0, :] = indices[1, :] + indices[0, :] * shape[1]
    new_indices[1, :] = indices[2, :] + indices[0, :] * shape[2]

    # keep the original values rather than replacing them with a boolean mask
    return torch.sparse_coo_tensor(indices=new_indices, values=sparse_coo.values(), size=new_shape)
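The index arithmetic above (new row = h + b*H, new col = w + b*W) can be sanity-checked on a tiny batch against the dense torch.block_diag (self-contained sketch, not the snippet itself):

```python
import torch

# batch of 2 matrices, each 2x2, as one 3-D sparse COO tensor
dense = torch.zeros(2, 2, 2)
dense[0, 0, 1] = 1.0
dense[1, 1, 0] = 2.0
sp = dense.to_sparse().coalesce()

b, h, w = sp.indices()
new_idx = torch.stack([h + b * 2, w + b * 2])  # H = W = 2
out = torch.sparse_coo_tensor(new_idx, sp.values(), (4, 4))
assert torch.equal(out.to_dense(), torch.block_diag(dense[0], dense[1]))
```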


No branches or pull requests

7 participants