[Feature request] Scatter sort #48

Closed

soumyasanyal opened this issue Jun 4, 2019 · 10 comments

@soumyasanyal

Hi!
Similar to scatter_max, can we have a scatter_sort that sorts tensors based on index? Here, index represents the graph to which a node belongs and src represents all the nodes in a batch consisting of multiple graphs.

rusty1s (Owner) commented Jun 4, 2019

Hi, this is not really a scatter op but rather a sparse_sort, since the number of elements stays the same. Hence, it does not really fit the scope of this package. For implementing top-k pooling and sort pooling in PyG, we implemented a sparse_sort op by first creating dense matrices, sorting them, and masking out invalid entries. This is not ideal, but a true sparse_sort op is rather difficult to implement, and the dense workaround is fast enough for most use cases.
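
Roughly, that dense workaround might look like the sketch below (my own reconstruction of the description above, not the actual PyG code; it assumes a 1-D src and a sorted, zero-based index):

import torch

def dense_sparse_sort(src: torch.Tensor, index: torch.Tensor, descending=False):
    # Scatter each group into its own row of a padded dense matrix,
    # sort the rows, then mask the padding out again.
    num_groups = int(index.max()) + 1
    counts = torch.bincount(index, minlength=num_groups)
    max_count = int(counts.max())
    # Pad with a value that sorts after every real entry.
    fill = (src.min() - 1).item() if descending else (src.max() + 1).item()
    dense = src.new_full((num_groups, max_count), fill)
    # Position of each element inside its group (requires index to be sorted).
    start = torch.cumsum(counts, 0) - counts
    pos = torch.arange(src.numel(), device=src.device) - start[index]
    dense[index, pos] = src
    dense, _ = dense.sort(dim=1, descending=descending)
    mask = torch.arange(max_count, device=src.device).unsqueeze(0) < counts.unsqueeze(1)
    return dense[mask]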

ghost commented Sep 13, 2019

Hi @rusty1s,
If I understand correctly:

  • there is no easy way to extend scatter_max to scatter_topk?
  • a workaround is to use the topk function you implemented in torch_geometric.nn.pool.topk_pool?

rusty1s (Owner) commented Sep 13, 2019

Yes and yes. A good alternative is the KeOps library, which comes with a memory-efficient argKmin implementation.
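
For reference, argKmin is KeOps' symbolic top-k reduction; the canonical use from the KeOps documentation is a k-nearest-neighbor query (a sketch assuming pykeops is installed; the shapes and names follow the KeOps docs, not this package):

import torch
from pykeops.torch import LazyTensor

x = torch.randn(10000, 3)              # query points
y = torch.randn(10000, 3)              # candidate points
K = 5

x_i = LazyTensor(x[:, None, :])        # (N, 1, 3) symbolic tensor
y_j = LazyTensor(y[None, :, :])        # (1, M, 3) symbolic tensor
D_ij = ((x_i - y_j) ** 2).sum(-1)      # (N, M) symbolic squared distances, never materialized
ind = D_ij.argKmin(K, dim=1)           # (N, K) indices of the K smallest entries per row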

flandolfi commented Mar 17, 2020

Hi @rusty1s,

I was looking for something similar to what you refer to as sparse_sort. Do you think this might be a valid approach?

import torch

def sparse_sort(src: torch.Tensor, index: torch.Tensor, dim=0, descending=False, eps=1e-12):
    # Normalize src into [0, 1), then add the (signed) group index so that
    # the values of each group fall into disjoint unit intervals.
    f_src = src.float()
    f_min, f_max = f_src.min(dim)[0], f_src.max(dim)[0]
    norm = (f_src - f_min) / (f_max - f_min + eps) + index.float() * (-1) ** int(descending)
    perm = norm.argsort(dim=dim, descending=descending)

    return src[perm], perm

In practice:

  1. we normalize src in the range [0, 1);
  2. we sum (or subtract, if we sort in descending order) the index value. In this way, the values of src for each index i will fall in the range [i, i + 1) (or, alternatively, [-i, -i + 1));
  3. we argsort the results and obtain the final permutation.

The only drawback could be the numerical stability with wide ranges, but I think this could work in most cases.

I believe we could also extend it to a sparse_topk, creating a mask with something like

idx = index[perm]                              # group of each sorted element
mask = torch.ones_like(idx, dtype=torch.bool)  # torch.bool instead of the deprecated uint8 mask
mask[k:] = idx[k:] != idx[:-k]                 # drop elements past the k-th of their group
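
Putting the pieces together, a full sparse_topk could look like the following sketch (my own assembly of the two snippets above, only checked against the 1-D case):

def sparse_topk(src: torch.Tensor, index: torch.Tensor, k: int, descending=True, eps=1e-12):
    # Sort within groups, then keep an element only if it is among the first
    # k of its group: if idx[j] == idx[j - k], at least k earlier (sorted)
    # elements share its group, so element j is dropped.
    sort, perm = sparse_sort(src, index, descending=descending, eps=eps)
    idx = index[perm]
    mask = torch.ones_like(idx, dtype=torch.bool)
    mask[k:] = idx[k:] != idx[:-k]
    return sort[mask], perm[mask]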

Thank you so much for your library, anyway. I hope you and your relatives are all ok and safe.

Francesco

Edit: here are the results of a toy example:

In[5]: src = torch.randperm(100)
  ...: index = torch.arange(100) // 10
  ...: src, index
Out[5]: 
(tensor([76, 33, 30,  1, 19, 39, 97, 36, 77,  9, 92, 45, 37, 18, 25, 75, 61, 86,
         71,  7, 20, 38, 47, 54, 93, 17,  4, 29, 96, 69, 10, 65, 74, 57,  8, 89,
         56, 21, 34, 44, 68, 66, 31, 84, 73, 81, 78, 11, 24, 67, 87, 64, 32, 79,
         28, 91, 94, 62, 35, 51, 12, 23, 27, 70,  5, 52,  6, 95,  0, 63, 53, 85,
         14, 15, 49, 13, 41, 16, 43, 99, 98, 48, 80, 40, 26, 50, 58, 82, 83, 72,
         22,  2, 60, 88, 55, 42, 59, 90,  3, 46]),
 tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4,
         4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7,
         7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9,
         9, 9, 9, 9]))
In[6]: sort, perm = sparse_sort(src, index)
  ...: sort, perm
Out[6]: 
(tensor([ 1,  9, 19, 30, 33, 36, 39, 76, 77, 97,  7, 18, 25, 37, 45, 61, 71, 75,
         86, 92,  4, 17, 20, 29, 38, 47, 54, 69, 93, 96,  8, 10, 21, 34, 44, 56,
         57, 65, 74, 89, 11, 24, 31, 66, 67, 68, 73, 78, 81, 84, 28, 32, 35, 51,
         62, 64, 79, 87, 91, 94,  0,  5,  6, 12, 23, 27, 52, 63, 70, 95, 13, 14,
         15, 16, 41, 43, 49, 53, 85, 99, 26, 40, 48, 50, 58, 72, 80, 82, 83, 98,
          2,  3, 22, 42, 46, 55, 59, 60, 88, 90]),
 tensor([ 3,  9,  4,  2,  1,  7,  5,  0,  8,  6, 19, 13, 14, 12, 11, 16, 18, 15,
         17, 10, 26, 25, 20, 27, 21, 22, 23, 29, 24, 28, 34, 30, 37, 38, 39, 36,
         33, 31, 32, 35, 47, 48, 42, 41, 49, 40, 44, 46, 45, 43, 54, 52, 58, 59,
         57, 51, 53, 50, 55, 56, 68, 64, 66, 60, 61, 62, 65, 69, 63, 67, 75, 72,
         73, 77, 76, 78, 74, 70, 71, 79, 84, 83, 81, 85, 86, 89, 82, 87, 88, 80,
         91, 98, 90, 95, 99, 94, 96, 92, 93, 97]))
In[7]: index.equal(perm // 10)
Out[7]: True

rusty1s (Owner) commented Mar 18, 2020

That looks like a totally valid solution!

@flandolfi

A lexicographic sort (or two stable sorts) would be better, anyway! This is just an alternative until PyTorch provides one of the two.

@github-actions

This issue had no activity for 6 months. It will be closed in 2 weeks unless there is some new activity.

@simonrouse9461

@flandolfi Thanks for the brilliant solution! Here's an improved one that works with multidimensional tensors:

def sparse_sort(src: torch.Tensor, index: torch.Tensor, dim=0, descending=False, eps=1e-12):
    # Same interval trick as above, but with keepdim=True so the
    # normalization broadcasts along `dim` for tensors of any rank.
    f_src = src.float()
    f_min, f_max = f_src.min(dim=dim, keepdim=True).values, f_src.max(dim=dim, keepdim=True).values
    norm = (f_src - f_min) / (f_max - f_min + eps) + index.float() * (-1) ** int(descending)
    # Build full coordinate grids, then replace the sorted dimension with the
    # argsort permutation so that advanced indexing gathers along `dim` only.
    perm = list(torch.meshgrid(*[torch.arange(i, device=src.device) for i in src.shape], indexing='ij'))
    perm[dim] = norm.argsort(dim=dim, descending=descending)

    return src[tuple(perm)], perm[dim]
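
For what it's worth, a usage sketch for the multidimensional case (my assumption: index must be reshaped so that it broadcasts against src along every dimension other than dim):

src = torch.rand(6, 3)
index = torch.tensor([0, 0, 0, 1, 1, 1]).view(-1, 1)  # (6, 1): broadcasts over the 3 columns
sorted_src, perm = sparse_sort(src, index, dim=0)     # each column sorted within groups 0 and 1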

flandolfi commented Jan 24, 2023

Hi @simonrouse9461! Thanks for the update :)

Honestly, I think this approach can be abandoned: newer versions of PyTorch support stable sorts, which make sparse sorts much easier to compute.

Here you can find an example where sparse sort is obtained by a sequence of two stable sorts:

https://github.com/pyg-team/pytorch_geometric/blob/ecf40202a4f5aaeb264b7183cc5038fdf14a45ad/torch_geometric/nn/aggr/quantile.py#L88-L93
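
In outline, the linked code does something like the following (a paraphrased sketch, not the exact PyG source; torch.sort's stable=True flag requires PyTorch >= 1.9):

def sparse_sort_stable(src: torch.Tensor, index: torch.Tensor, descending=False):
    # First stable-sort by value, then stable-sort by group index: ties in
    # the second sort preserve the value order established by the first.
    src, perm = src.sort(descending=descending, stable=True)
    index = index[perm]
    index, perm2 = index.sort(stable=True)
    return src[perm2], perm[perm2]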

@simonrouse9461

@flandolfi Looks like a much simpler and more elegant solution. Thanks!
