Auto merge of #103779 - the8472:simd-str-contains, r=thomcc
x86_64 SSE2 fast-path for str.contains(&str) and short needles

Based on Wojciech Muła's [SIMD-friendly algorithms for substring searching](http://0x80.pl/articles/simd-strfind.html#sse-avx2)

The two-way algorithm is Big-O efficient, but it needs to preprocess the needle
to find a "critical factorization" of it. This additional work is significant
for short needles. Additionally, it mostly advances by needle.len() bytes at a time.

The SIMD-based approach used here, on the other hand, can advance by its
vector width, which can exceed the needle length. Pathological cases do exist,
but since the fast path is limited to small needles the worst-case blowup is also small.
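
As an illustration of that probing idea, here is a minimal stand-alone sketch using the
baseline SSE2 intrinsics. It is not the code added by this PR: the function name
`probe_contains`, the fixed 16-byte stride, and the naive tail handling are assumptions
made for the example, and the real implementation below picks its second probe byte
more carefully and unrolls the main loop.

```rust
// Illustrative sketch only; the actual fast path lives in library/core/src/str/pattern.rs.
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
fn probe_contains(haystack: &[u8], needle: &[u8]) -> bool {
    use std::arch::x86_64::*;

    let n = needle.len();
    // The real fast path handles empty and single-byte needles separately.
    assert!(n >= 2);
    if haystack.len() < n {
        return false;
    }

    let mut i = 0;
    // Each iteration inspects 16 candidate start positions at once.
    while i + (n - 1) + 16 <= haystack.len() {
        // SAFETY: the loop condition guarantees 16 readable bytes at both load offsets.
        let mut mask = unsafe {
            // Broadcast the first and last needle bytes across 16-byte vectors.
            let first = _mm_set1_epi8(needle[0] as i8);
            let last = _mm_set1_epi8(needle[n - 1] as i8);
            let a = _mm_loadu_si128(haystack.as_ptr().add(i) as *const __m128i);
            let b = _mm_loadu_si128(haystack.as_ptr().add(i + n - 1) as *const __m128i);
            // A set bit marks a position where both probe bytes match.
            _mm_movemask_epi8(_mm_and_si128(_mm_cmpeq_epi8(a, first), _mm_cmpeq_epi8(b, last)))
        };

        // Verify each candidate with a plain byte comparison.
        while mask != 0 {
            let pos = i + mask.trailing_zeros() as usize;
            if &haystack[pos..pos + n] == needle {
                return true;
            }
            mask &= mask - 1; // clear the lowest set bit
        }
        i += 16;
    }

    // Fall back to a naive scan for whatever tail does not fit a full vector load.
    haystack[i..].windows(n).any(|w| w == needle)
}
```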

Benchmarks taken on a Zen2, compiled with `-Ccodegen-units=1`:

```
OLD:
test str::bench_contains_16b_in_long                     ... bench:         504 ns/iter (+/- 14) = 5061 MB/s
test str::bench_contains_2b_repeated_long                ... bench:         948 ns/iter (+/- 175) = 2690 MB/s
test str::bench_contains_32b_in_long                     ... bench:         445 ns/iter (+/- 6) = 5732 MB/s
test str::bench_contains_bad_naive                       ... bench:         130 ns/iter (+/- 1) = 569 MB/s
test str::bench_contains_bad_simd                        ... bench:          84 ns/iter (+/- 8) = 880 MB/s
test str::bench_contains_equal                           ... bench:         142 ns/iter (+/- 7) = 394 MB/s
test str::bench_contains_short_long                      ... bench:         677 ns/iter (+/- 25) = 3768 MB/s
test str::bench_contains_short_short                     ... bench:          27 ns/iter (+/- 2) = 2074 MB/s

NEW:
test str::bench_contains_16b_in_long                     ... bench:          82 ns/iter (+/- 0) = 31109 MB/s
test str::bench_contains_2b_repeated_long                ... bench:          73 ns/iter (+/- 0) = 34945 MB/s
test str::bench_contains_32b_in_long                     ... bench:          71 ns/iter (+/- 1) = 35929 MB/s
test str::bench_contains_bad_naive                       ... bench:           7 ns/iter (+/- 0) = 10571 MB/s
test str::bench_contains_bad_simd                        ... bench:          97 ns/iter (+/- 41) = 762 MB/s
test str::bench_contains_equal                           ... bench:           4 ns/iter (+/- 0) = 14000 MB/s
test str::bench_contains_short_long                      ... bench:          73 ns/iter (+/- 0) = 34945 MB/s
test str::bench_contains_short_short                     ... bench:          12 ns/iter (+/- 0) = 4666 MB/s
```
bors committed Nov 17, 2022
2 parents 251831e + a2b2010 commit 9340e5c
Showing 3 changed files with 311 additions and 12 deletions.
65 changes: 58 additions & 7 deletions library/alloc/benches/str.rs
@@ -1,3 +1,4 @@
use core::iter::Iterator;
use test::{black_box, Bencher};

#[bench]
@@ -122,14 +123,13 @@ fn bench_contains_short_short(b: &mut Bencher) {
let haystack = "Lorem ipsum dolor sit amet, consectetur adipiscing elit.";
let needle = "sit";

b.bytes = haystack.len() as u64;
b.iter(|| {
assert!(haystack.contains(needle));
assert!(black_box(haystack).contains(black_box(needle)));
})
}

#[bench]
fn bench_contains_short_long(b: &mut Bencher) {
let haystack = "\
static LONG_HAYSTACK: &str = "\
Lorem ipsum dolor sit amet, consectetur adipiscing elit. Suspendisse quis lorem sit amet dolor \
ultricies condimentum. Praesent iaculis purus elit, ac malesuada quam malesuada in. Duis sed orci \
eros. Suspendisse sit amet magna mollis, mollis nunc luctus, imperdiet mi. Integer fringilla non \
@@ -164,10 +164,48 @@ feugiat. Etiam quis mauris vel risus luctus mattis a a nunc. Nullam orci quam, i
vehicula in, porttitor ut nibh. Duis sagittis adipiscing nisl vitae congue. Donec mollis risus eu \
leo suscipit, varius porttitor nulla porta. Pellentesque ut sem nec nisi euismod vehicula. Nulla \
malesuada sollicitudin quam eu fermentum.";

#[bench]
fn bench_contains_2b_repeated_long(b: &mut Bencher) {
let haystack = LONG_HAYSTACK;
let needle = "::";

b.bytes = haystack.len() as u64;
b.iter(|| {
assert!(!black_box(haystack).contains(black_box(needle)));
})
}

#[bench]
fn bench_contains_short_long(b: &mut Bencher) {
let haystack = LONG_HAYSTACK;
let needle = "english";

b.bytes = haystack.len() as u64;
b.iter(|| {
assert!(!black_box(haystack).contains(black_box(needle)));
})
}

#[bench]
fn bench_contains_16b_in_long(b: &mut Bencher) {
let haystack = LONG_HAYSTACK;
let needle = "english language";

b.bytes = haystack.len() as u64;
b.iter(|| {
assert!(!black_box(haystack).contains(black_box(needle)));
})
}

#[bench]
fn bench_contains_32b_in_long(b: &mut Bencher) {
let haystack = LONG_HAYSTACK;
let needle = "the english language sample text";

b.bytes = haystack.len() as u64;
b.iter(|| {
assert!(!haystack.contains(needle));
assert!(!black_box(haystack).contains(black_box(needle)));
})
}

@@ -176,8 +214,20 @@ fn bench_contains_bad_naive(b: &mut Bencher) {
let haystack = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
let needle = "aaaaaaaab";

b.bytes = haystack.len() as u64;
b.iter(|| {
assert!(!black_box(haystack).contains(black_box(needle)));
})
}

#[bench]
fn bench_contains_bad_simd(b: &mut Bencher) {
let haystack = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
let needle = "aaabaaaa";

b.bytes = haystack.len() as u64;
b.iter(|| {
assert!(!haystack.contains(needle));
assert!(!black_box(haystack).contains(black_box(needle)));
})
}

@@ -186,8 +236,9 @@ fn bench_contains_equal(b: &mut Bencher) {
let haystack = "Lorem ipsum dolor sit amet, consectetur adipiscing elit.";
let needle = "Lorem ipsum dolor sit amet, consectetur adipiscing elit.";

b.bytes = haystack.len() as u64;
b.iter(|| {
assert!(haystack.contains(needle));
assert!(black_box(haystack).contains(black_box(needle)));
})
}

26 changes: 21 additions & 5 deletions library/alloc/tests/str.rs
@@ -1590,11 +1590,27 @@ fn test_bool_from_str() {
assert_eq!("not even a boolean".parse::<bool>().ok(), None);
}

fn check_contains_all_substrings(s: &str) {
assert!(s.contains(""));
for i in 0..s.len() {
for j in i + 1..=s.len() {
assert!(s.contains(&s[i..j]));
fn check_contains_all_substrings(haystack: &str) {
let mut modified_needle = String::new();

for i in 0..haystack.len() {
// check different haystack lengths since we special-case short haystacks.
let haystack = &haystack[0..i];
assert!(haystack.contains(""));
for j in 0..haystack.len() {
for k in j + 1..=haystack.len() {
let needle = &haystack[j..k];
assert!(haystack.contains(needle));
modified_needle.clear();
modified_needle.push_str(needle);
modified_needle.replace_range(0..1, "\0");
assert!(!haystack.contains(&modified_needle));

modified_needle.clear();
modified_needle.push_str(needle);
modified_needle.replace_range(needle.len() - 1..needle.len(), "\0");
assert!(!haystack.contains(&modified_needle));
}
}
}
}
232 changes: 232 additions & 0 deletions library/core/src/str/pattern.rs
@@ -39,6 +39,7 @@
)]

use crate::cmp;
use crate::cmp::Ordering;
use crate::fmt;
use crate::slice::memchr;

@@ -946,6 +947,32 @@ impl<'a, 'b> Pattern<'a> for &'b str {
haystack.as_bytes().starts_with(self.as_bytes())
}

/// Checks whether the pattern matches anywhere in the haystack.
#[inline]
fn is_contained_in(self, haystack: &'a str) -> bool {
if self.len() == 0 {
return true;
}

match self.len().cmp(&haystack.len()) {
Ordering::Less => {
if self.len() == 1 {
return haystack.as_bytes().contains(&self.as_bytes()[0]);
}

#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
if self.len() <= 32 {
if let Some(result) = simd_contains(self, haystack) {
return result;
}
}

self.into_searcher(haystack).next_match().is_some()
}
_ => self == haystack,
}
}

/// Removes the pattern from the front of haystack, if it matches.
#[inline]
fn strip_prefix_of(self, haystack: &'a str) -> Option<&'a str> {
@@ -1684,3 +1711,208 @@ impl TwoWayStrategy for RejectAndMatch {
SearchStep::Match(a, b)
}
}

/// SIMD search for short needles based on
/// Wojciech Muła's "SIMD-friendly algorithms for substring searching"[0]
///
/// It skips ahead by the vector width on each iteration (rather than by the needle length, as two-way
/// does) by probing the first and last bytes of the needle across the whole vector width
/// and only doing full needle comparisons when the vectorized probe indicates potential matches.
///
/// Since the x86_64 baseline only offers SSE2, we only use u8x16 here.
/// If we ever ship std built for x86-64-v3 or adapt this for other platforms, then wider vectors
/// should be evaluated.
///
/// For haystacks smaller than vector-size + needle length it falls back to
/// a naive O(n*m) search so this implementation should not be called on larger needles.
///
/// [0]: http://0x80.pl/articles/simd-strfind.html#sse-avx2
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
#[inline]
fn simd_contains(needle: &str, haystack: &str) -> Option<bool> {
let needle = needle.as_bytes();
let haystack = haystack.as_bytes();

debug_assert!(needle.len() > 1);

use crate::ops::BitAnd;
use crate::simd::mask8x16 as Mask;
use crate::simd::u8x16 as Block;
use crate::simd::{SimdPartialEq, ToBitMask};

let first_probe = needle[0];

// the offset used for the 2nd vector
let second_probe_offset = if needle.len() == 2 {
// never bail out on len=2 needles because the probes will fully cover them and have
// no degenerate cases.
1
} else {
// try a few bytes in case the first and last bytes of the needle are the same
let Some(second_probe_offset) = (needle.len().saturating_sub(4)..needle.len()).rfind(|&idx| needle[idx] != first_probe) else {
// fall back to other search methods if we can't find any different bytes
// since we could otherwise hit some degenerate cases
return None;
};
second_probe_offset
};

// do a naive search if the haystack is too small to fit a whole vector read plus the second probe offset
if haystack.len() < Block::LANES + second_probe_offset {
return Some(haystack.windows(needle.len()).any(|c| c == needle));
}

let first_probe: Block = Block::splat(first_probe);
let second_probe: Block = Block::splat(needle[second_probe_offset]);
// the first byte is already checked by the outer loop. to verify a match only the
// remainder has to be compared.
let trimmed_needle = &needle[1..];

// this #[cold] is load-bearing, benchmark before removing it...
let check_mask = #[cold]
|idx, mask: u16, skip: bool| -> bool {
if skip {
return false;
}

// and so is this. optimizations are weird.
let mut mask = mask;

while mask != 0 {
let trailing = mask.trailing_zeros();
let offset = idx + trailing as usize + 1;
// SAFETY: mask has between 0 and 15 trailing zeroes; we skip one additional byte that was already compared
// and then take trimmed_needle.len() bytes. This is within the bounds defined by the outer loop
unsafe {
let sub = haystack.get_unchecked(offset..).get_unchecked(..trimmed_needle.len());
if small_slice_eq(sub, trimmed_needle) {
return true;
}
}
mask &= !(1 << trailing);
}
return false;
};

let test_chunk = |idx| -> u16 {
// SAFETY: this requires at least LANES bytes being readable at idx
// that is ensured by the loop ranges (see comments below)
let a: Block = unsafe { haystack.as_ptr().add(idx).cast::<Block>().read_unaligned() };
// SAFETY: this requires LANES + second_probe_offset bytes being readable at idx
let b: Block = unsafe {
haystack.as_ptr().add(idx).add(second_probe_offset).cast::<Block>().read_unaligned()
};
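// a lane where both probe bytes match marks a candidate position
// that still needs a full needle comparison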
let eq_first: Mask = a.simd_eq(first_probe);
let eq_last: Mask = b.simd_eq(second_probe);
let both = eq_first.bitand(eq_last);
let mask = both.to_bitmask();

return mask;
};

let mut i = 0;
let mut result = false;
// The loop condition must ensure that there's enough headroom to read LANES bytes,
// and not only at the current index but also at the index shifted by second_probe_offset
const UNROLL: usize = 4;
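// main loop, manually unrolled: compute the probe masks for UNROLL chunks first, then verify any candidates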
while i + second_probe_offset + UNROLL * Block::LANES < haystack.len() && !result {
let mut masks = [0u16; UNROLL];
for j in 0..UNROLL {
masks[j] = test_chunk(i + j * Block::LANES);
}
for j in 0..UNROLL {
let mask = masks[j];
if mask != 0 {
result |= check_mask(i + j * Block::LANES, mask, result);
}
}
i += UNROLL * Block::LANES;
}
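// process the remaining chunks one vector width at a time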
while i + second_probe_offset + Block::LANES < haystack.len() && !result {
let mask = test_chunk(i);
if mask != 0 {
result |= check_mask(i, mask, result);
}
i += Block::LANES;
}

// Process the tail that didn't fit into LANES-sized steps.
// This simply repeats the same procedure but as a right-aligned chunk instead
// of a left-aligned one. The last byte must be exactly flush with the string end so
// we don't miss a single byte or read out of bounds.
let i = haystack.len() - second_probe_offset - Block::LANES;
let mask = test_chunk(i);
if mask != 0 {
result |= check_mask(i, mask, result);
}

Some(result)
}

/// Compares short slices for equality.
///
/// It avoids a call to libc's memcmp, which is faster on long slices
/// due to SIMD optimizations but incurs function call overhead.
///
/// # Safety
///
/// Both slices must have the same length.
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))] // only called on x86
#[inline]
unsafe fn small_slice_eq(x: &[u8], y: &[u8]) -> bool {
// This function is adapted from
// https://github.com/BurntSushi/memchr/blob/8037d11b4357b0f07be2bb66dc2659d9cf28ad32/src/memmem/util.rs#L32

// If we don't have enough bytes to do 4-byte-at-a-time loads, then
// fall back to the naive slow version.
//
// Potential alternative: We could do a copy_nonoverlapping combined with a mask instead
// of a loop. Benchmark it.
if x.len() < 4 {
for (&b1, &b2) in x.iter().zip(y) {
if b1 != b2 {
return false;
}
}
return true;
}
// When we have 4 or more bytes to compare, then proceed in chunks of 4 at
// a time using unaligned loads.
//
// Also, why do 4 byte loads instead of, say, 8 byte loads? The reason is
// that this particular version of memcmp is likely to be called with tiny
// needles. That means that if we do 8 byte loads, then a higher proportion
// of memcmp calls will use the slower variant above. With that said, this
// is a hypothesis and is only loosely supported by benchmarks. There's
// likely some improvement that could be made here. The main thing here
// though is to optimize for latency, not throughput.

// SAFETY: Via the conditional above, we know that both `px` and `py`
// have the same length, so `px < pxend` implies that `py < pyend`.
// Thus, dereferencing both `px` and `py` in the loop below is safe.
//
// Moreover, we set `pxend` and `pyend` to be 4 bytes before the actual
// end of `px` and `py`. Thus, the final dereference outside of the
// loop is guaranteed to be valid. (The final comparison will overlap with
// the last comparison done in the loop for lengths that aren't multiples
// of four.)
//
// Finally, we needn't worry about alignment here, since we do unaligned
// loads.
unsafe {
let (mut px, mut py) = (x.as_ptr(), y.as_ptr());
let (pxend, pyend) = (px.add(x.len() - 4), py.add(y.len() - 4));
while px < pxend {
let vx = (px as *const u32).read_unaligned();
let vy = (py as *const u32).read_unaligned();
if vx != vy {
return false;
}
px = px.add(4);
py = py.add(4);
}
let vx = (pxend as *const u32).read_unaligned();
let vy = (pyend as *const u32).read_unaligned();
vx == vy
}
}
