Skip to content

Commit

Permalink
Add add/sub methods that only panic with debug assertions to rustc
Browse files Browse the repository at this point in the history
This mitigates the perf impact of enabling overflow checks on rustc.
The change to use overflow checks will be done in a later PR.
  • Loading branch information
Nilstrieb committed Apr 13, 2024
1 parent c3b05c6 commit 5039160
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 27 deletions.
38 changes: 21 additions & 17 deletions compiler/rustc_data_structures/src/sip128.rs
@@ -1,5 +1,8 @@
//! This is a copy of `core::hash::sip` adapted to providing 128 bit hashes.

// This code is very hot and uses lots of arithmetic, avoid overflow checks for performance.
// See https://github.com/rust-lang/rust/pull/119440#issuecomment-1874255727
use rustc_serialize::int_overflow::{DebugStrictAdd, DebugStrictSub};
use std::hash::Hasher;
use std::mem::{self, MaybeUninit};
use std::ptr;
Expand Down Expand Up @@ -103,19 +106,19 @@ unsafe fn copy_nonoverlapping_small(src: *const u8, dst: *mut u8, count: usize)
}

let mut i = 0;
if i + 3 < count {
if i.debug_strict_add(3) < count {
ptr::copy_nonoverlapping(src.add(i), dst.add(i), 4);
i += 4;
i = i.debug_strict_add(4);
}

if i + 1 < count {
if i.debug_strict_add(1) < count {
ptr::copy_nonoverlapping(src.add(i), dst.add(i), 2);
i += 2
i = i.debug_strict_add(2)
}

if i < count {
*dst.add(i) = *src.add(i);
i += 1;
i = i.debug_strict_add(1);
}

debug_assert_eq!(i, count);
Expand Down Expand Up @@ -211,14 +214,14 @@ impl SipHasher128 {
debug_assert!(nbuf < BUFFER_SIZE);
debug_assert!(nbuf + LEN < BUFFER_WITH_SPILL_SIZE);

if nbuf + LEN < BUFFER_SIZE {
if nbuf.debug_strict_add(LEN) < BUFFER_SIZE {
unsafe {
// The memcpy call is optimized away because the size is known.
let dst = (self.buf.as_mut_ptr() as *mut u8).add(nbuf);
ptr::copy_nonoverlapping(bytes.as_ptr(), dst, LEN);
}

self.nbuf = nbuf + LEN;
self.nbuf = nbuf.debug_strict_add(LEN);

return;
}
Expand Down Expand Up @@ -265,8 +268,9 @@ impl SipHasher128 {
// This function should only be called when the write fills the buffer.
// Therefore, when LEN == 1, the new `self.nbuf` must be zero.
// LEN is statically known, so the branch is optimized away.
self.nbuf = if LEN == 1 { 0 } else { nbuf + LEN - BUFFER_SIZE };
self.processed += BUFFER_SIZE;
self.nbuf =
if LEN == 1 { 0 } else { nbuf.debug_strict_add(LEN).debug_strict_sub(BUFFER_SIZE) };
self.processed = self.processed.debug_strict_add(BUFFER_SIZE);
}
}

Expand All @@ -277,7 +281,7 @@ impl SipHasher128 {
let nbuf = self.nbuf;
debug_assert!(nbuf < BUFFER_SIZE);

if nbuf + length < BUFFER_SIZE {
if nbuf.debug_strict_add(length) < BUFFER_SIZE {
unsafe {
let dst = (self.buf.as_mut_ptr() as *mut u8).add(nbuf);

Expand All @@ -289,7 +293,7 @@ impl SipHasher128 {
}
}

self.nbuf = nbuf + length;
self.nbuf = nbuf.debug_strict_add(length);

return;
}
Expand All @@ -315,7 +319,7 @@ impl SipHasher128 {
// This function should only be called when the write fills the buffer,
// so we know that there is enough input to fill the current element.
let valid_in_elem = nbuf % ELEM_SIZE;
let needed_in_elem = ELEM_SIZE - valid_in_elem;
let needed_in_elem = ELEM_SIZE.debug_strict_sub(valid_in_elem);

let src = msg.as_ptr();
let dst = (self.buf.as_mut_ptr() as *mut u8).add(nbuf);
Expand All @@ -327,7 +331,7 @@ impl SipHasher128 {
// ELEM_SIZE` to show the compiler that this loop's upper bound is > 0.
// We know that is true, because last step ensured we have a full
// element in the buffer.
let last = nbuf / ELEM_SIZE + 1;
let last = (nbuf / ELEM_SIZE).debug_strict_add(1);

for i in 0..last {
let elem = self.buf.get_unchecked(i).assume_init().to_le();
Expand All @@ -338,7 +342,7 @@ impl SipHasher128 {

// Process the remaining element-sized chunks of input.
let mut processed = needed_in_elem;
let input_left = length - processed;
let input_left = length.debug_strict_sub(processed);
let elems_left = input_left / ELEM_SIZE;
let extra_bytes_left = input_left % ELEM_SIZE;

Expand All @@ -347,7 +351,7 @@ impl SipHasher128 {
self.state.v3 ^= elem;
Sip13Rounds::c_rounds(&mut self.state);
self.state.v0 ^= elem;
processed += ELEM_SIZE;
processed = processed.debug_strict_add(ELEM_SIZE);
}

// Copy remaining input into start of buffer.
Expand All @@ -356,7 +360,7 @@ impl SipHasher128 {
copy_nonoverlapping_small(src, dst, extra_bytes_left);

self.nbuf = extra_bytes_left;
self.processed += nbuf + processed;
self.processed = self.processed.debug_strict_add(nbuf.debug_strict_add(processed));
}
}

Expand Down Expand Up @@ -394,7 +398,7 @@ impl SipHasher128 {
};

// Finalize the hash.
let length = self.processed + self.nbuf;
let length = self.processed.debug_strict_add(self.nbuf);
let b: u64 = ((length as u64 & 0xff) << 56) | elem;

state.v3 ^= b;
Expand Down
65 changes: 65 additions & 0 deletions compiler/rustc_serialize/src/int_overflow.rs
@@ -0,0 +1,65 @@
// This would belong to `rustc_data_structures`, but `rustc_serialize` needs it too.

/// Addition, but only overflow checked when `cfg(debug_assertions)` is set
/// instead of respecting `-Coverflow-checks`.
///
/// This exists for performance reasons, as we ship rustc with overflow checks.
/// While overflow checks are perf neutral in almost all of the compiler, there
/// are a few particularly hot areas where we don't want overflow checks in our
/// dist builds. Overflow is still a bug there, so we want overflow check for
/// builds with debug assertions.
///
/// That's a long way to say that this should be used in areas where overflow
/// is a bug but overflow checking is too slow.
pub trait DebugStrictAdd {
/// See [`DebugStrictAdd`].
fn debug_strict_add(self, other: Self) -> Self;
}

macro_rules! impl_debug_strict_add {
($( $ty:ty )*) => {
$(
impl DebugStrictAdd for $ty {
fn debug_strict_add(self, other: Self) -> Self {
if cfg!(debug_assertions) {
self + other
} else {
self.wrapping_add(other)
}
}
}
)*
};
}

/// See [`DebugStrictAdd`].
pub trait DebugStrictSub {
/// See [`DebugStrictAdd`].
fn debug_strict_sub(self, other: Self) -> Self;
}

macro_rules! impl_debug_strict_sub {
($( $ty:ty )*) => {
$(
impl DebugStrictSub for $ty {
fn debug_strict_sub(self, other: Self) -> Self {
if cfg!(debug_assertions) {
self - other
} else {
self.wrapping_sub(other)
}
}
}
)*
};
}

impl_debug_strict_add! {
u8 u16 u32 u64 u128 usize
i8 i16 i32 i64 i128 isize
}

impl_debug_strict_sub! {
u8 u16 u32 u64 u128 usize
i8 i16 i32 i64 i128 isize
}
14 changes: 9 additions & 5 deletions compiler/rustc_serialize/src/leb128.rs
@@ -1,6 +1,10 @@
use crate::opaque::MemDecoder;
use crate::serialize::Decoder;

// This code is very hot and uses lots of arithmetic, avoid overflow checks for performance.
// See https://github.com/rust-lang/rust/pull/119440#issuecomment-1874255727
use crate::int_overflow::DebugStrictAdd;

/// Returns the length of the longest LEB128 encoding for `T`, assuming `T` is an integer type
pub const fn max_leb128_len<T>() -> usize {
// The longest LEB128 encoding for an integer uses 7 bits per byte.
Expand All @@ -24,15 +28,15 @@ macro_rules! impl_write_unsigned_leb128 {
*out.get_unchecked_mut(i) = value as u8;
}

i += 1;
i = i.debug_strict_add(1);
break;
} else {
unsafe {
*out.get_unchecked_mut(i) = ((value & 0x7f) | 0x80) as u8;
}

value >>= 7;
i += 1;
i = i.debug_strict_add(1);
}
}

Expand Down Expand Up @@ -69,7 +73,7 @@ macro_rules! impl_read_unsigned_leb128 {
} else {
result |= ((byte & 0x7F) as $int_ty) << shift;
}
shift += 7;
shift = shift.debug_strict_add(7);
}
}
};
Expand Down Expand Up @@ -101,7 +105,7 @@ macro_rules! impl_write_signed_leb128 {
*out.get_unchecked_mut(i) = byte;
}

i += 1;
i = i.debug_strict_add(1);

if !more {
break;
Expand Down Expand Up @@ -130,7 +134,7 @@ macro_rules! impl_read_signed_leb128 {
loop {
byte = decoder.read_u8();
result |= <$int_ty>::from(byte & 0x7F) << shift;
shift += 7;
shift = shift.debug_strict_add(7);

if (byte & 0x80) == 0 {
break;
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_serialize/src/lib.rs
Expand Up @@ -23,5 +23,6 @@ pub use self::serialize::{Decodable, Decoder, Encodable, Encoder};

mod serialize;

pub mod int_overflow;
pub mod leb128;
pub mod opaque;
10 changes: 7 additions & 3 deletions compiler/rustc_serialize/src/opaque.rs
Expand Up @@ -7,6 +7,10 @@ use std::ops::Range;
use std::path::Path;
use std::path::PathBuf;

// This code is very hot and uses lots of arithmetic, avoid overflow checks for performance.
// See https://github.com/rust-lang/rust/pull/119440#issuecomment-1874255727
use crate::int_overflow::DebugStrictAdd;

// -----------------------------------------------------------------------------
// Encoder
// -----------------------------------------------------------------------------
Expand Down Expand Up @@ -65,7 +69,7 @@ impl FileEncoder {
// Tracking position this way instead of having a `self.position` field
// means that we only need to update `self.buffered` on a write call,
// as opposed to updating `self.position` and `self.buffered`.
self.flushed + self.buffered
self.flushed.debug_strict_add(self.buffered)
}

#[cold]
Expand Down Expand Up @@ -119,7 +123,7 @@ impl FileEncoder {
}
if let Some(dest) = self.buffer_empty().get_mut(..buf.len()) {
dest.copy_from_slice(buf);
self.buffered += buf.len();
self.buffered = self.buffered.debug_strict_add(buf.len());
} else {
self.write_all_cold_path(buf);
}
Expand Down Expand Up @@ -158,7 +162,7 @@ impl FileEncoder {
if written > N {
Self::panic_invalid_write::<N>(written);
}
self.buffered += written;
self.buffered = self.buffered.debug_strict_add(written);
}

#[cold]
Expand Down
8 changes: 6 additions & 2 deletions compiler/rustc_span/src/span_encoding.rs
Expand Up @@ -5,6 +5,10 @@ use crate::{BytePos, SpanData};

use rustc_data_structures::fx::FxIndexSet;

// This code is very hot and uses lots of arithmetic, avoid overflow checks for performance.
// See https://github.com/rust-lang/rust/pull/119440#issuecomment-1874255727
use rustc_serialize::int_overflow::DebugStrictAdd;

/// A compressed span.
///
/// [`SpanData`] is 16 bytes, which is too big to stick everywhere. `Span` only
Expand Down Expand Up @@ -166,7 +170,7 @@ impl Span {
debug_assert!(len <= MAX_LEN);
SpanData {
lo: BytePos(self.lo_or_index),
hi: BytePos(self.lo_or_index + len),
hi: BytePos(self.lo_or_index.debug_strict_add(len)),
ctxt: SyntaxContext::from_u32(self.ctxt_or_parent_or_marker as u32),
parent: None,
}
Expand All @@ -179,7 +183,7 @@ impl Span {
};
SpanData {
lo: BytePos(self.lo_or_index),
hi: BytePos(self.lo_or_index + len),
hi: BytePos(self.lo_or_index.debug_strict_add(len)),
ctxt: SyntaxContext::root(),
parent: Some(parent),
}
Expand Down

0 comments on commit 5039160

Please sign in to comment.