/
algorithm.hpp
138 lines (120 loc) · 4.98 KB
/
algorithm.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#ifndef ENTT_CORE_ALGORITHM_HPP
#define ENTT_CORE_ALGORITHM_HPP
#include <algorithm>
#include <functional>
#include <iterator>
#include <utility>
#include <vector>
#include "utility.hpp"
namespace entt {
/**
 * @brief Callable type that forwards to `std::sort`.
 *
 * A function template such as `std::sort` cannot be supplied directly as a
 * template argument. Wrapping the call in a class type with a call operator
 * makes it usable wherever a sort _policy_ type is expected.
 */
struct std_sort {
    /**
     * @brief Sorts the range `[first, last)` by means of `std::sort`.
     *
     * Any extra argument is forwarded ahead of the iterators in the call to
     * `std::sort`, followed by the two iterators and the comparison function
     * object.
     *
     * @tparam It Type of random access iterator.
     * @tparam Compare Type of comparison function object.
     * @tparam Args Types of the extra arguments handed to `std::sort`.
     * @param first Iterator to the first element of the range to sort.
     * @param last Iterator past the last element of the range to sort.
     * @param compare Comparison function object, `std::less<>` by default.
     * @param args Extra arguments placed before the iterators, if any.
     */
    template<typename It, typename Compare = std::less<>, typename... Args>
    void operator()(It first, It last, Compare compare = Compare{}, Args &&...args) const {
        std::sort(
            std::forward<Args>(args)...,
            std::move(first),
            std::move(last),
            std::move(compare));
    }
};
/*! @brief Callable type implementing an in-place insertion sort. */
struct insertion_sort {
    /**
     * @brief Sorts the range `[first, last)` with an insertion sort.
     *
     * Each element is moved out of the range, the elements before it that
     * compare greater are shifted one slot to the right, and the element is
     * then moved into the slot left free. Nothing happens unless
     * `first < last`.
     *
     * @tparam It Type of random access iterator.
     * @tparam Compare Type of comparison function object.
     * @param first Iterator to the first element of the range to sort.
     * @param last Iterator past the last element of the range to sort.
     * @param compare Comparison function object, `std::less<>` by default.
     */
    template<typename It, typename Compare = std::less<>>
    void operator()(It first, It last, Compare compare = Compare{}) const {
        if(!(first < last)) {
            return;
        }

        for(auto curr = first + 1; curr < last; ++curr) {
            // take the element out so that its slot can be overwritten
            auto elem = std::move(*curr);
            auto hole = curr;

            // shift right every predecessor that must come after the element
            while(hole > first && compare(elem, *(hole - 1))) {
                *hole = std::move(*(hole - 1));
                --hole;
            }

            *hole = std::move(elem);
        }
    }
};
/**
 * @brief Callable type implementing a least significant digit radix sort.
 * @tparam Bit Number of bits consumed by each pass.
 * @tparam N Total number of bits to sort on, a multiple of `Bit`.
 */
template<std::size_t Bit, std::size_t N>
struct radix_sort {
    static_assert((N % Bit) == 0, "The maximum number of bits to sort must be a multiple of the number of bits processed per pass");

    /**
     * @brief Sorts the range `[first, last)` on the keys returned by a getter.
     *
     * The _getter_ is invoked on each element to obtain the value whose bits
     * drive the sort. The range is processed `N / Bit` times, `Bit` bits at a
     * time starting from the least significant ones, using an auxiliary
     * buffer with one element per element of the range. Nothing happens
     * unless `first < last`.
     *
     * This implementation is inspired by the online book
     * [Physically Based Rendering](http://www.pbr-book.org/3ed-2018/Primitives_and_Intersection_Acceleration/Bounding_Volume_Hierarchies.html#RadixSort).
     *
     * @tparam It Type of random access iterator.
     * @tparam Getter Type of _getter_ function object.
     * @param first Iterator to the first element of the range to sort.
     * @param last Iterator past the last element of the range to sort.
     * @param getter Function object that returns the key of an element.
     */
    template<typename It, typename Getter = identity>
    void operator()(It first, It last, Getter getter = Getter{}) const {
        if(!(first < last)) {
            return;
        }

        using value_type = typename std::iterator_traits<It>::value_type;

        constexpr auto passes = N / Bit;
        // number of passes that can be run in pairs (range -> buffer -> range)
        constexpr std::size_t paired = passes - (passes % 2u);

        std::vector<value_type> aux(std::distance(first, last));

        // one counting pass: distributes [from, to) into out on a group of bits
        auto distribute = [getter = std::move(getter)](auto from, auto to, auto out, auto start) {
            constexpr auto buckets = 1 << Bit;
            constexpr auto mask = buckets - 1;

            std::size_t count[buckets]{};
            std::size_t index[buckets]{};

            // histogram of the current group of bits
            for(auto it = from; it != to; ++it) {
                ++count[(getter(*it) >> start) & mask];
            }

            // exclusive prefix sum: first output slot of each bucket
            for(std::size_t curr = 1u, total = buckets; curr < total; ++curr) {
                index[curr] = index[curr - 1u] + count[curr - 1u];
            }

            // move each element to the next free slot of its bucket
            for(auto it = from; it != to; ++it) {
                out[index[(getter(*it) >> start) & mask]++] = std::move(*it);
            }
        };

        // two passes per iteration, so that elements end up in the range again
        for(std::size_t pass{}; pass < paired; pass += 2u) {
            distribute(first, last, aux.begin(), pass * Bit);
            distribute(aux.begin(), aux.end(), first, (pass + 1u) * Bit);
        }

        // odd number of passes: the last one lands in the buffer, move it back
        if constexpr((passes % 2u) != 0u) {
            distribute(first, last, aux.begin(), (passes - 1) * Bit);
            std::move(aux.begin(), aux.end(), first);
        }
    }
};
} // namespace entt
#endif