Skip to content

Commit 2c752af

Browse files
committed
Optimized implementation for uN::{gather,scatter}_bits
1 parent 29e035e commit 2c752af

File tree

3 files changed

+178
-40
lines changed

3 files changed

+178
-40
lines changed

library/core/src/num/int_bits.rs

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
//! Implementations for `uN::gather_bits` and `uN::scatter_bits`
2+
//!
3+
//! For the purposes of this implementation, the operations can be thought
4+
//! of as operating on the input bits as a list, starting from the least
5+
//! significant bit. Gathering is like `Vec::retain` that deletes bits
6+
//! where the mask has a zero. Scattering is like doing the inverse by
7+
//! inserting the zeros that gathering would delete.
8+
//!
9+
//! Key observation: Each bit that is gathered/scattered needs to be
10+
//! shifted by the count of zeroes up to the corresponding mask bit.
11+
//!
12+
//! With that in mind, the general idea is to decompose the operation into
13+
//! a sequence of stages `k in 0..log2(BITS)`, where each stage shifts
14+
//! some of the bits by `n = 1 << k`. The masks for each stage are computed
15+
//! via prefix counts of zeroes in the mask.
16+
//!
17+
//! # Gathering
18+
//!
19+
//! Consider the input as a sequence of runs of data (bitstrings A,B,C,...),
20+
//! split by fixed-width groups of zeros ('.'), initially at width `n = 1`.
21+
//! Counting the groups of zeros, each stage shifts the odd-indexed runs of
22+
//! data right by `n`, effectively swapping them with the preceding zeros.
23+
//! For the next stage, `n` is doubled as all the zeros are now paired.
24+
//! ```
25+
//! .A.B.C.D.E.F.G.H
26+
//! ..AB..CD..EF..GH
27+
//! ....ABCD....EFGH
28+
//! ........ABCDEFGH
29+
//! ```
30+
//! What makes this nontrivial is that the lengths of the bitstrings are not
31+
//! the same and, using lowercase for individual bits, the above might look
32+
//! more like
33+
//! ```
34+
//! .a.bbb.ccccc.dd.e..g.hh
35+
//! ..abbb..cccccdd..e..ghh
36+
//! ....abbbcccccdd....eghh
37+
//! ........abbbcccccddeghh
38+
//! ```
39+
//!
40+
//! # Scattering
41+
//!
42+
//! For `scatter_bits`, the stages are reversed. We start with a single run of
43+
//! data in the low bits. Each stage then splits each run of data in two by
44+
//! shifting part of it left by `n`, which is halved each stage.
45+
//! ```
46+
//! ........ABCDEFGH
47+
//! ....ABCD....EFGH
48+
//! ..AB..CD..EF..GH
49+
//! .A.B.C.D.E.F.G.H
50+
//! ```
51+
//!
52+
//! # Stage masks
53+
//!
54+
//! To facilitate the shifts at each stage, we compute a mask that covers both
55+
//! the bitstrings to shift, and the zeros they shift into.
56+
//! ```
57+
//! .A.B.C.D.E.F.G.H
58+
//! ## ## ## ##
59+
//! ..AB..CD..EF..GH
60+
//! #### ####
61+
//! ....ABCD....EFGH
62+
//! ########
63+
//! ........ABCDEFGH
64+
//! ```
65+
66+
macro_rules! uint_impl {
67+
($U:ident) => {
68+
pub(super) mod $U {
69+
const STAGES: usize = $U::BITS.ilog2() as usize;
70+
#[inline]
71+
const fn prepare(m: $U) -> [$U; STAGES] {
72+
// We'll start with `zeros` as a mask of the bits to be removed,
73+
// and compute into `masks` the parts that shift at each stage.
74+
let mut zeros = !m;
75+
let mut masks = [0; STAGES];
76+
let mut n = 1;
77+
let mut k = 0;
78+
while n < $U::BITS {
79+
// Suppose `zeros` has bits set at ranges `{ a..a+n, b..b+n, ... }`.
80+
// Then `parity` will be computed as `{ a.. } XOR { b.. } XOR ...`,
81+
// which will be the ranges `{ a..b, c..d, e.. }`
82+
let mut parity = zeros;
83+
let mut j = n;
84+
while j < $U::BITS {
85+
parity ^= parity << j;
86+
j <<= 1;
87+
}
88+
masks[k] = parity;
89+
90+
// Toggle off the bits that are shifted into:
91+
// { a..a+n, b..b+n, ... } & !{ a..b, c..d, e.. }
92+
// == { b..b+n, d..d+n, ... }
93+
zeros &= !parity;
94+
// Expand the remaining ranges down to the bits that were
95+
// shifted from: { b-n..b+n, d-n..d+n, ... }
96+
zeros ^= zeros >> n;
97+
98+
n <<= 1;
99+
k += 1;
100+
}
101+
masks
102+
}
103+
104+
#[inline(always)]
105+
pub(in super::super) const fn gather_impl(mut x: $U, sparse: $U) -> $U {
106+
let masks = prepare(sparse);
107+
x &= sparse;
108+
let mut k = 0;
109+
while k < STAGES {
110+
let n = 1 << k;
111+
// Consider each two runs of data with their leading
112+
// groups of `n` 0-bits. Suppose that the run that is
113+
// shifted right has length `a`, and the other one has
114+
// length `b`. Assume that only zeros are shifted in.
115+
// ```
116+
// [0; n], [X; a], [0; n], [Y; b] // x
117+
// [0; n], [X; a], [0; n], [0; b] // q
118+
// [0; n], [0; a + n], [Y; b] // x ^= q
119+
// [0; n + n], [X; a], [0; b] // q >> n
120+
// [0; n], [0; n], [X; a], [Y; b] // x ^= q << n
121+
// ```
122+
// Only zeros are shifted out, satisfying the assumption
123+
// for the next group.
124+
125+
// In effect, the upper run of data is swapped with the
126+
// group of `n` zeros below it.
127+
let q = x & masks[k];
128+
x ^= q;
129+
x ^= q >> n;
130+
131+
k += 1;
132+
}
133+
x
134+
}
135+
#[inline(always)]
136+
pub(in super::super) const fn scatter_impl(mut x: $U, sparse: $U) -> $U {
137+
let masks = prepare(sparse);
138+
let mut k = STAGES;
139+
while k > 0 {
140+
k -= 1;
141+
let n = 1 << k;
142+
// Consider each run of data with the `2 * n` arbitrary bits
143+
// above it. Suppose that the run has length `a + b`, with
144+
// `a` being the length of the part that needs to be
145+
// shifted. Assume that only zeros are shifted in.
146+
// ```
147+
// [_; n], [_; n], [X; a], [Y; b] // x
148+
// [0; n], [_; n], [X; a], [0; b] // q
149+
// [_; n], [0; n + a], [Y; b] // x ^= q
150+
// [_; n], [X; a], [0; b + n] // q << n
151+
// [_; n], [X; a], [0; n], [Y; b] // x ^= q << n
152+
// ```
153+
// Only zeros are shifted out, satisfying the assumption
154+
// for the next group.
155+
156+
// In effect, `n` 0-bits are inserted somewhere in each run
157+
// of data to spread it, and the two groups of `n` bits
158+
// above are XOR'd together.
159+
let q = x & masks[k];
160+
x ^= q;
161+
x ^= q << n;
162+
}
163+
x & sparse
164+
}
165+
}
166+
};
167+
}
168+
169+
uint_impl!(u8);
170+
uint_impl!(u16);
171+
uint_impl!(u32);
172+
uint_impl!(u64);
173+
uint_impl!(u128);

library/core/src/num/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ mod int_macros; // import int_impl!
4444
mod uint_macros; // import uint_impl!
4545

4646
mod error;
47+
mod int_bits;
4748
mod int_log10;
4849
mod int_sqrt;
4950
pub(crate) mod libm;

library/core/src/num/uint_macros.rs

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -492,27 +492,8 @@ macro_rules! uint_impl {
492492
#[must_use = "this returns the result of the operation, \
493493
without modifying the original"]
494494
#[inline]
495-
pub const fn gather_bits(self, mut mask: Self) -> Self {
496-
let mut bit_position = 1;
497-
let mut result = 0;
498-
499-
// Iterate through the mask bits, unsetting the lowest bit after
500-
// each iteration. We fill the bits in the result starting from the
501-
// least significant bit.
502-
while mask != 0 {
503-
// Find the next lowest set bit in the mask
504-
let next_mask_bit = mask.isolate_lowest_one();
505-
506-
// Retrieve the masked bit and if present, set it in the result
507-
let src_bit = (self & next_mask_bit) != 0;
508-
result |= if src_bit { bit_position } else { 0 };
509-
510-
// Unset lowest set bit in the mask, prepare next position to set
511-
mask ^= next_mask_bit;
512-
bit_position <<= 1;
513-
}
514-
515-
result
495+
pub const fn gather_bits(self, mask: Self) -> Self {
496+
crate::num::int_bits::$ActualT::gather_impl(self as $ActualT, mask as $ActualT) as $SelfT
516497
}
517498

518499
/// Returns an integer with the least significant bits of `self`
@@ -528,25 +509,8 @@ macro_rules! uint_impl {
528509
#[must_use = "this returns the result of the operation, \
529510
without modifying the original"]
530511
#[inline]
531-
pub const fn scatter_bits(mut self, mut mask: Self) -> Self {
532-
let mut result = 0;
533-
534-
// Iterate through the mask bits, unsetting the lowest bit after
535-
// each iteration and right-shifting `self` by one to get the next
536-
// bit into the least significant bit position.
537-
while mask != 0 {
538-
// Find the next bit position to potentially set
539-
let next_mask_bit = mask.isolate_lowest_one();
540-
541-
// If bit is set, deposit it at the masked bit position
542-
result |= if (self & 1) != 0 { next_mask_bit } else { 0 };
543-
544-
// Unset lowest set bit in the mask, shift in next `self` bit
545-
mask ^= next_mask_bit;
546-
self >>= 1;
547-
}
548-
549-
result
512+
pub const fn scatter_bits(self, mask: Self) -> Self {
513+
crate::num::int_bits::$ActualT::scatter_impl(self as $ActualT, mask as $ActualT) as $SelfT
550514
}
551515

552516
/// Reverses the order of bits in the integer. The least significant bit becomes the most significant bit,

0 commit comments

Comments
 (0)