[Feature request] Scatter sort #48
Hi, this is not really a …
Hi @rusty1s, …
Yes and yes. A good alternative is to use the KeOps library, which comes with memory-efficient …
Hi @rusty1s, I was looking for something similar to what you refer to as a scatter sort. Here is a possible workaround:

```python
import torch

def sparse_sort(src: torch.Tensor, index: torch.Tensor, dim=0, descending=False, eps=1e-12):
    f_src = src.float()
    f_min, f_max = f_src.min(dim)[0], f_src.max(dim)[0]
    norm = (f_src - f_min) / (f_max - f_min + eps) + index.float() * (-1) ** int(descending)
    perm = norm.argsort(dim=dim, descending=descending)
    return src[perm], perm
```

In practice, the values are rescaled to [0, 1) and then offset by their group index, so a single `argsort` orders the elements group by group. The only drawback could be numerical stability with wide value ranges, but I think this could work in most cases.

I believe we could also extend it, for example to keep only the first `k` sorted elements of each group with a mask like:

```python
idx = index[perm]
mask = torch.ones_like(idx, dtype=torch.bool)  # bool instead of the deprecated uint8 mask
mask[k:] = idx[k:] != idx[:-k]
```

Thank you so much for your library, anyway. I hope you and your relatives are all OK and safe.

Francesco

Edit: here I paste the results of a toy example:

```python
In [5]: src = torch.randperm(100)
   ...: index = torch.arange(100) // 10
   ...: src, index
Out[5]:
(tensor([76, 33, 30, 1, 19, 39, 97, 36, 77, 9, 92, 45, 37, 18, 25, 75, 61, 86,
         71, 7, 20, 38, 47, 54, 93, 17, 4, 29, 96, 69, 10, 65, 74, 57, 8, 89,
         56, 21, 34, 44, 68, 66, 31, 84, 73, 81, 78, 11, 24, 67, 87, 64, 32, 79,
         28, 91, 94, 62, 35, 51, 12, 23, 27, 70, 5, 52, 6, 95, 0, 63, 53, 85,
         14, 15, 49, 13, 41, 16, 43, 99, 98, 48, 80, 40, 26, 50, 58, 82, 83, 72,
         22, 2, 60, 88, 55, 42, 59, 90, 3, 46]),
 tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4,
         4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7,
         7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9,
         9, 9, 9, 9]))

In [6]: sort, perm = sparse_sort(src, index)
   ...: sort, perm
Out[6]:
(tensor([ 1, 9, 19, 30, 33, 36, 39, 76, 77, 97, 7, 18, 25, 37, 45, 61, 71, 75,
         86, 92, 4, 17, 20, 29, 38, 47, 54, 69, 93, 96, 8, 10, 21, 34, 44, 56,
         57, 65, 74, 89, 11, 24, 31, 66, 67, 68, 73, 78, 81, 84, 28, 32, 35, 51,
         62, 64, 79, 87, 91, 94, 0, 5, 6, 12, 23, 27, 52, 63, 70, 95, 13, 14,
         15, 16, 41, 43, 49, 53, 85, 99, 26, 40, 48, 50, 58, 72, 80, 82, 83, 98,
         2, 3, 22, 42, 46, 55, 59, 60, 88, 90]),
 tensor([ 3, 9, 4, 2, 1, 7, 5, 0, 8, 6, 19, 13, 14, 12, 11, 16, 18, 15,
         17, 10, 26, 25, 20, 27, 21, 22, 23, 29, 24, 28, 34, 30, 37, 38, 39, 36,
         33, 31, 32, 35, 47, 48, 42, 41, 49, 40, 44, 46, 45, 43, 54, 52, 58, 59,
         57, 51, 53, 50, 55, 56, 68, 64, 66, 60, 61, 62, 65, 69, 63, 67, 75, 72,
         73, 77, 76, 78, 74, 70, 71, 79, 84, 83, 81, 85, 86, 89, 82, 87, 88, 80,
         91, 98, 90, 95, 99, 94, 96, 92, 93, 97]))

In [7]: index.equal(perm // 10)
Out[7]: True
```
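As a concrete completion of the top-k masking idea above (my own sketch, not from the thread: the helper name `sparse_topk`, the 1-D simplification of `sparse_sort`, and the fixed `k` parameter are all assumptions):

```python
import torch

def sparse_sort(src, index, descending=False, eps=1e-12):
    # 1-D version of the normalization trick: rescale values to [0, 1)
    # and offset by the group index so one argsort sorts per group.
    f_src = src.float()
    f_min, f_max = f_src.min(), f_src.max()
    norm = (f_src - f_min) / (f_max - f_min + eps) + index.float() * (-1) ** int(descending)
    perm = norm.argsort(descending=descending)
    return src[perm], perm

def sparse_topk(src, index, k, descending=False):
    # After the grouped sort, an element is among the first k of its group
    # exactly when the element k positions earlier belongs to a different group.
    out, perm = sparse_sort(src, index, descending=descending)
    idx = index[perm]
    mask = torch.ones_like(idx, dtype=torch.bool)
    mask[k:] = idx[k:] != idx[:-k]
    return out[mask], perm[mask]
```

Groups with fewer than `k` elements are kept whole, since the mask only drops an element when another element of the same group sits `k` positions earlier.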
That looks like a totally valid solution!
A lexicographic sort (or two stable sorts) would be better, anyway! This is just an alternative until PyTorch provides one of the two.
This issue had no activity for 6 months. It will be closed in 2 weeks unless there is some new activity.
@flandolfi Thanks for the brilliant solution! Here's an improved one that works with multidimensional tensors:

```python
def sparse_sort(src: torch.Tensor, index: torch.Tensor, dim=0, descending=False, eps=1e-12):
    f_src = src.float()
    f_min, f_max = f_src.min(dim=dim, keepdim=True).values, f_src.max(dim=dim, keepdim=True).values
    norm = (f_src - f_min) / (f_max - f_min + eps) + index.float() * (-1) ** int(descending)
    perm = list(torch.meshgrid(*[torch.arange(i).to(src.device) for i in src.shape], indexing='ij'))
    perm[dim] = norm.argsort(dim=dim, descending=descending)
    return src[perm], perm[dim]
```
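As a sanity check, here is a small usage sketch of the multidimensional version (the function body is copied from the comment above; the example data and the per-column interpretation are my own, and `torch.meshgrid(..., indexing='ij')` assumes PyTorch 1.10 or newer):

```python
import torch

def sparse_sort(src: torch.Tensor, index: torch.Tensor, dim=0, descending=False, eps=1e-12):
    f_src = src.float()
    f_min, f_max = f_src.min(dim=dim, keepdim=True).values, f_src.max(dim=dim, keepdim=True).values
    norm = (f_src - f_min) / (f_max - f_min + eps) + index.float() * (-1) ** int(descending)
    perm = list(torch.meshgrid(*[torch.arange(i).to(src.device) for i in src.shape], indexing='ij'))
    perm[dim] = norm.argsort(dim=dim, descending=descending)
    return src[perm], perm[dim]

# Sort each column of a (6, 3) matrix within two groups of rows.
src = torch.tensor([[5., 2., 9.],
                    [1., 7., 3.],
                    [4., 0., 6.],
                    [8., 6., 1.],
                    [2., 9., 0.],
                    [7., 3., 4.]])
index = torch.tensor([[0], [0], [0], [1], [1], [1]])  # broadcasts over columns
out, perm = sparse_sort(src, index, dim=0)
# Each column of `out` is now sorted within rows 0-2 (group 0)
# and within rows 3-5 (group 1); `perm` gathers `src` into `out` along dim 0.
```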
Hi @simonrouse9461! Thanks for the update :) Honestly, I think this approach can be abandoned, since newer versions of PyTorch allow stable sorts, which make it easier to compute sparse sorts. Here you can find an example where a sparse sort is obtained by a sequence of two stable sorts:
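Since the linked example is not included above, here is a minimal sketch of the two-stable-sorts idea, assuming `torch.sort(..., stable=True)` (available since PyTorch 1.9); the actual linked example may differ:

```python
import torch

def sparse_sort_stable(src: torch.Tensor, index: torch.Tensor, descending: bool = False):
    # First sort by value, then stably sort by group index: the stable sort
    # preserves the value ordering within each group.
    _, perm_val = torch.sort(src, descending=descending)
    _, perm_grp = torch.sort(index[perm_val], stable=True)
    perm = perm_val[perm_grp]
    return src[perm], perm
```

Unlike the normalization trick, this avoids any floating-point rescaling, so it has no numerical-stability caveat for wide value ranges.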
@flandolfi Looks like a much simpler and more elegant solution. Thanks!
Hi!

Similar to `scatter_max`, can we have a `scatter_sort` that sorts tensors based on `index`? Here, `index` represents the graph to which a node belongs, and `src` represents all the nodes in a batch consisting of multiple graphs.