Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize is_ascii for str and [u8]. #74066

Merged
merged 6 commits into from
Jul 12, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion src/libcore/benches/ascii.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,31 @@ macro_rules! benches {
}
)+
}
}
};

// For some tests the vec allocation tends to dominate, so it can be avoided.
(@readonly $( fn $name: ident($arg: ident: &[u8]) $body: block )+) => {
benches!(@ro mod short_readonly SHORT $($name $arg $body)+);
benches!(@ro mod medium_readonly MEDIUM $($name $arg $body)+);
benches!(@ro mod long_readonly LONG $($name $arg $body)+);
};
(@ro mod $mod_name: ident $input: ident $($name: ident $arg: ident $body: block)+) => {
mod $mod_name {
use super::*;

$(
#[bench]
fn $name(bencher: &mut Bencher) {
bencher.bytes = $input.len() as u64;
let vec = $input.as_bytes().to_vec();
bencher.iter(|| {
let $arg = black_box(&vec[..]);
black_box($body)
})
}
)+
}
};
}

use test::black_box;
Expand Down Expand Up @@ -245,6 +269,17 @@ benches! {
is_ascii_control,
}

// Benchmarks for the optimized `[u8]::is_ascii` versus the naive
// byte-at-a-time check it replaces. Uses the `@readonly` macro variant
// (defined above) so that the per-iteration `Vec` allocation does not
// dominate the measurement.
benches! {
    @readonly
    // The libcore implementation under test (word-at-a-time when possible).
    fn is_ascii_slice_libcore(bytes: &[u8]) {
        bytes.is_ascii()
    }

    // The scalar baseline: check every byte individually.
    fn is_ascii_slice_iter_all(bytes: &[u8]) {
        bytes.iter().all(|b| b.is_ascii())
    }
}

macro_rules! repeat {
($s: expr) => {
concat!($s, $s, $s, $s, $s, $s, $s, $s, $s, $s)
Expand Down
102 changes: 101 additions & 1 deletion src/libcore/slice/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2795,7 +2795,7 @@ impl [u8] {
#[stable(feature = "ascii_methods_on_intrinsics", since = "1.23.0")]
#[inline]
pub fn is_ascii(&self) -> bool {
self.iter().all(|b| b.is_ascii())
is_ascii(self)
}

/// Checks that two slices are an ASCII case-insensitive match.
Expand Down Expand Up @@ -2843,6 +2843,106 @@ impl [u8] {
}
}

/// Returns `true` if any byte in the word `v` is nonascii (>= 128). Snarfed
/// from `../str/mod.rs`, which does something similar for utf8 validation.
#[inline]
fn contains_nonascii(v: usize) -> bool {
    // One `0x80` per byte; the cast truncates to the low bytes on 32-bit
    // targets, which is exactly the mask we want there.
    const NONASCII_MASK: usize = 0x80808080_80808080u64 as usize;
    (NONASCII_MASK & v) != 0
}

/// Optimized ASCII test that will use usize-at-a-time operations instead of
/// byte-at-a-time operations (when possible).
///
/// The algorithm we use here is pretty simple. If `s` is too short, we just
/// check each byte and be done with it. Otherwise:
///
/// - Read the first word with an unaligned load.
/// - Align the pointer, read subsequent words until end with aligned loads.
/// - If there's a tail, read the last `usize` from `s` with an unaligned load.
///
/// If any of these loads produces something for which `contains_nonascii`
/// (above) returns true, then we know the answer is false.
#[inline]
fn is_ascii(s: &[u8]) -> bool {
    const USIZE_SIZE: usize = mem::size_of::<usize>();

    let len = s.len();
    let align_offset = s.as_ptr().align_offset(USIZE_SIZE);

    // If we wouldn't gain anything from the word-at-a-time implementation, fall
    // back to a scalar loop.
    //
    // We also do this for architectures where `size_of::<usize>()` isn't
    // sufficient alignment for `usize`, because it's a weird edge case.
    if len < USIZE_SIZE || len < align_offset || USIZE_SIZE < mem::align_of::<usize>() {
        return s.iter().all(|b| b.is_ascii());
    }

    // We always read the first word unaligned, which means if `align_offset`
    // is 0 we'd read the same value again for the first aligned read, so in
    // that case we skip a whole word ahead instead.
    let offset_to_aligned = if align_offset == 0 { USIZE_SIZE } else { align_offset };

    let start = s.as_ptr();
    // SAFETY: We verify `len >= USIZE_SIZE` above, so at least one full word
    // can be read from the start of the slice.
    let first_word = unsafe { (start as *const usize).read_unaligned() };

    if contains_nonascii(first_word) {
        return false;
    }
    // We checked this above, somewhat implicitly. Note that `offset_to_aligned`
    // is either `align_offset` or `USIZE_SIZE`, both of which are explicitly
    // checked above.
    debug_assert!(offset_to_aligned <= len);

    // SAFETY: `offset_to_aligned <= len` (asserted above), so this stays in
    // bounds. word_ptr is the (properly aligned) usize ptr we use to read the
    // middle chunk of the slice.
    let mut word_ptr = unsafe { start.add(offset_to_aligned) as *const usize };

    // `byte_pos` is the byte index of `word_ptr`, used for loop end checks.
    let mut byte_pos = offset_to_aligned;

    // Paranoia check about alignment, since we're about to do a bunch of
    // aligned loads. In practice this should be impossible barring a bug in
    // `align_offset` though.
    debug_assert_eq!((word_ptr as usize) % mem::align_of::<usize>(), 0);

    while byte_pos <= len - USIZE_SIZE {
        debug_assert!(
            // Sanity check that the read is in bounds
            (word_ptr as usize + USIZE_SIZE) <= (start.wrapping_add(len) as usize) &&
            // And that our assumptions about `byte_pos` hold.
            (word_ptr as usize) - (start as usize) == byte_pos
        );

        // SAFETY: We know `word_ptr` is properly aligned (because of
        // `align_offset`), and we know that we have enough bytes between
        // `word_ptr` and the end (`byte_pos <= len - USIZE_SIZE` still holds).
        let word = unsafe { word_ptr.read() };
        if contains_nonascii(word) {
            return false;
        }

        byte_pos += USIZE_SIZE;
        // SAFETY: We know that `byte_pos <= len - USIZE_SIZE`, which means that
        // after this `add`, `word_ptr` will be at most one-past-the-end.
        word_ptr = unsafe { word_ptr.add(1) };
    }

    // If we have anything left over, it should be at-most 1 usize worth of
    // bytes, which we handle with one final unaligned read of the last word.
    // When the loop already consumed the slice exactly, skip the redundant
    // load (on some platforms `read_unaligned` is comparatively expensive).
    if byte_pos == len {
        return true;
    }

    // Sanity check to ensure there really is only one `usize` left. This should
    // be guaranteed by our loop condition.
    debug_assert!(byte_pos < len && len - byte_pos < USIZE_SIZE);

    // SAFETY: This relies on `len >= USIZE_SIZE`, which we check at the start.
    let last_word = unsafe { (start.add(len - USIZE_SIZE) as *const usize).read_unaligned() };

    !contains_nonascii(last_word)
}

#[stable(feature = "rust1", since = "1.0.0")]
impl<T, I> ops::Index<I> for [T]
where
Expand Down
2 changes: 1 addition & 1 deletion src/libcore/str/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4348,7 +4348,7 @@ impl str {
// We can treat each byte as character here: all multibyte characters
// start with a byte that is not in the ascii range, so we will stop
// there already.
self.bytes().all(|b| b.is_ascii())
self.as_bytes().is_ascii()
}

/// Checks that two strings are an ASCII case-insensitive match.
Expand Down
56 changes: 56 additions & 0 deletions src/libcore/tests/ascii.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,59 @@ fn test_is_ascii_control() {
" ",
);
}

// `is_ascii` does a good amount of pointer manipulation and has
// alignment-dependent computation. This is all sanity-checked via
// `debug_assert!`s, so we exhaustively compare many sizes and alignments
// against an "obviously correct" baseline function.
#[test]
fn test_is_ascii_align_size_thoroughly() {
    // The "obviously-correct" baseline mentioned above.
    fn naive_is_ascii(bytes: &[u8]) -> bool {
        bytes.iter().all(|byte| byte.is_ascii())
    }

    // Builds `half` copies of `first` followed by `half` copies of `second`.
    fn two_runs(first: u8, second: u8, half: usize) -> Vec<u8> {
        let mut out = vec![first; half];
        out.resize(half * 2, second);
        out
    }

    // Miri is too slow for much of this, and in miri `align_offset` always
    // returns `usize::max_value()` anyway (at the moment), so we just test
    // lightly.
    let sizes = if cfg!(miri) { 0..5 } else { 0..100 };

    for size in sizes {
        #[cfg(not(miri))]
        let cases = &[
            b"a".repeat(size),
            b"\0".repeat(size),
            b"\x7f".repeat(size),
            b"\x80".repeat(size),
            b"\xff".repeat(size),
            two_runs(b'a', 0x80u8, size),
            two_runs(0x80u8, b'a', size),
        ];

        #[cfg(miri)]
        let cases = &[two_runs(b'a', 0x80u8, size)];

        for case in cases {
            for split in 0..=case.len() {
                // Check a potentially misaligned head, a potentially
                // misaligned tail, and a slice trimmed on both ends.
                let head = &case[split..];
                let tail = &case[..case.len() - split];
                let trimmed = &case[(split / 2)..(case.len() - (split / 2))];
                for &slice in &[head, tail, trimmed] {
                    assert_eq!(naive_is_ascii(slice), slice.is_ascii());
                }
            }
        }
    }
}