lexsort.py (forked from chaitjo/learning-paradigms-for-tsp)
import torch
import numpy as np


def torch_lexsort(keys, dim=-1):
    """Lexicographical sort of a sequence of key tensors along dim, like np.lexsort.

    As in np.lexsort, the last key in `keys` is the primary sort key.
    """
    if keys[0].is_cuda:
        return _torch_lexsort_cuda(keys, dim)
    else:
        # On CPU, fall back to numpy's lexsort
        return torch.from_numpy(np.lexsort([k.numpy() for k in keys], axis=dim))
def _torch_lexsort_cuda(keys, dim=-1):
    """
    Computes a lexicographical sort order on the GPU, similar to np.lexsort.

    Relies heavily on undocumented behavior of torch.sort, namely that when sorting more than
    2048 entries along the sorting dim, it performs the sort using Thrust and that sort is stable:
    https://github.com/pytorch/pytorch/blob/695fd981924bd805704ecb5ccd67de17c56d7308/aten/src/THC/generic/THCTensorSort.cu#L330
    """
    MIN_NUMEL_STABLE_SORT = 2049  # Minimum number of elements for a stable sort

    # Swap axes such that the sort dim is last and flatten all other dims into a single (batch) dimension
    reordered_keys = tuple(key.transpose(dim, -1).contiguous() for key in keys)
    flat_keys = tuple(key.view(-1) for key in reordered_keys)
    d = keys[0].size(dim)  # Sort dimension size
    numel = flat_keys[0].numel()
    batch_size = numel // d

    # Append a batch key so that entries are only reordered within their own (flattened) row
    batch_key = torch.arange(batch_size, dtype=torch.int64, device=keys[0].device)[:, None].repeat(1, d).view(-1)
    flat_keys = flat_keys + (batch_key,)

    # We rely on the undocumented behavior that the sort is stable provided that at least
    # MIN_NUMEL_STABLE_SORT elements are sorted, so replicate the data if there are fewer
    if numel < MIN_NUMEL_STABLE_SORT:
        n_rep = (MIN_NUMEL_STABLE_SORT + numel - 1) // numel  # Ceiling division
        rep_key = torch.arange(n_rep, dtype=torch.int64, device=keys[0].device)[:, None].repeat(1, numel).view(-1)
        flat_keys = tuple(k.repeat(n_rep) for k in flat_keys) + (rep_key,)

    idx = None  # Identity ordering initially
    for k in flat_keys:
        if idx is None:
            _, idx = k.sort(-1)
        else:
            # Order the data according to idx and then apply the found ordering to the
            # current idx (a permutation of a permutation), so that the next key is
            # sorted according to the current sorting order
            _, idx_ = k[idx].sort(-1)
            idx = idx[idx_]

    # In the end keep only the first numel entries and strip off the extra sort keys
    if numel < MIN_NUMEL_STABLE_SORT:
        idx = idx[:numel]

    # Keep only numel entries (if we have replicated), reshape, swap the axes back, and take
    # the result modulo d to convert flat indices into indices along the sort dim
    return idx[:numel].view(*reordered_keys[0].size()).transpose(dim, -1) % d
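

A minimal usage sketch (not part of the original file), assuming a recent PyTorch and NumPy: it checks the CPU path against np.lexsort directly and, only if a GPU is available, checks the CUDA path by comparing the gathered key values rather than the raw index permutation, since tie-breaking may differ and the CUDA trick depends on undocumented stability behavior.

if __name__ == "__main__":
    torch.manual_seed(0)
    a = torch.randint(0, 5, (3, 100))  # secondary key
    b = torch.randint(0, 5, (3, 100))  # primary key (the last key, as in np.lexsort)

    # CPU path: delegates to np.lexsort, so the results match exactly
    idx = torch_lexsort((a, b), dim=-1)
    ref = torch.from_numpy(np.lexsort((a.numpy(), b.numpy()), axis=-1))
    assert torch.equal(idx, ref)

    # CUDA path: relies on the undocumented stability behavior described in the docstring,
    # so these checks may fail on PyTorch versions where that behavior has changed.
    # Comparing gathered key values avoids depending on how ties are broken.
    if torch.cuda.is_available():
        idx_cuda = torch_lexsort((a.cuda(), b.cuda()), dim=-1).cpu()
        assert torch.equal(b.gather(-1, idx_cuda), b.gather(-1, ref))
        assert torch.equal(a.gather(-1, idx_cuda), a.gather(-1, ref))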