Skip to content

Commit

Permalink
Auto merge of rust-lang#75600 - nagisa:improve_align_offset, r=KodrAus
Browse files Browse the repository at this point in the history
Improve codegen for `align_offset`

In this PR the `align_offset` implementation is changed/improved to produce better code in certain scenarios such as when pointer type is has a stride of 1 or when building for low optimisation levels.

While these changes do not achieve the "ideal" codegen referenced in rust-lang#75579, it gets significantly closer to it. I’m not actually sure if the codegen can actually be much better with this function returning the offset, rather than the aligned pointer.

See the descriptions for separate commits for further information.
  • Loading branch information
bors committed Aug 19, 2020
2 parents 5b04bbf + 5d22b18 commit 11a44ad
Showing 1 changed file with 47 additions and 28 deletions.
75 changes: 47 additions & 28 deletions library/core/src/ptr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1166,16 +1166,20 @@ pub unsafe fn write_volatile<T>(dst: *mut T, src: T) {
/// Any questions go to @nagisa.
#[lang = "align_offset"]
pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
// FIXME(#75598): Direct use of these intrinsics improves codegen significantly at opt-level <=
// 1, where the method versions of these operations are not inlined.
use intrinsics::{unchecked_shl, unchecked_shr, unchecked_sub, wrapping_mul, wrapping_sub};

/// Calculate multiplicative modular inverse of `x` modulo `m`.
///
/// This implementation is tailored for align_offset and has following preconditions:
/// This implementation is tailored for `align_offset` and has following preconditions:
///
/// * `m` is a power-of-two;
/// * `x < m`; (if `x ≥ m`, pass in `x % m` instead)
///
/// Implementation of this function shall not panic. Ever.
#[inline]
fn mod_inv(x: usize, m: usize) -> usize {
unsafe fn mod_inv(x: usize, m: usize) -> usize {
/// Multiplicative modular inverse table modulo 2⁴ = 16.
///
/// Note, that this table does not contain values where inverse does not exist (i.e., for
Expand All @@ -1187,8 +1191,10 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
const INV_TABLE_MOD_SQUARED: usize = INV_TABLE_MOD * INV_TABLE_MOD;

let table_inverse = INV_TABLE_MOD_16[(x & (INV_TABLE_MOD - 1)) >> 1] as usize;
// SAFETY: `m` is required to be a power-of-two, hence non-zero.
let m_minus_one = unsafe { unchecked_sub(m, 1) };
if m <= INV_TABLE_MOD {
table_inverse & (m - 1)
table_inverse & m_minus_one
} else {
// We iterate "up" using the following formula:
//
Expand All @@ -1204,49 +1210,50 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
// uses e.g., subtraction `mod n`. It is entirely fine to do them `mod
// usize::MAX` instead, because we take the result `mod n` at the end
// anyway.
inverse = inverse.wrapping_mul(2usize.wrapping_sub(x.wrapping_mul(inverse)));
inverse = wrapping_mul(inverse, wrapping_sub(2usize, wrapping_mul(x, inverse)));
if going_mod >= m {
return inverse & (m - 1);
return inverse & m_minus_one;
}
going_mod = going_mod.wrapping_mul(going_mod);
going_mod = wrapping_mul(going_mod, going_mod);
}
}
}

let stride = mem::size_of::<T>();
let a_minus_one = a.wrapping_sub(1);
let pmoda = p as usize & a_minus_one;
// SAFETY: `a` is a power-of-two, therefore non-zero.
let a_minus_one = unsafe { unchecked_sub(a, 1) };
if stride == 1 {
// `stride == 1` case can be computed more efficiently through `-p (mod a)`.
return wrapping_sub(0, p as usize) & a_minus_one;
}

let pmoda = p as usize & a_minus_one;
if pmoda == 0 {
// Already aligned. Yay!
return 0;
}

if stride <= 1 {
return if stride == 0 {
// If the pointer is not aligned, and the element is zero-sized, then no amount of
// elements will ever align the pointer.
!0
} else {
a.wrapping_sub(pmoda)
};
} else if stride == 0 {
// If the pointer is not aligned, and the element is zero-sized, then no amount of
// elements will ever align the pointer.
return usize::MAX;
}

let smoda = stride & a_minus_one;
// SAFETY: a is power-of-two so cannot be 0. stride = 0 is handled above.
// SAFETY: a is power-of-two hence non-zero. stride == 0 case is handled above.
let gcdpow = unsafe { intrinsics::cttz_nonzero(stride).min(intrinsics::cttz_nonzero(a)) };
let gcd = 1usize << gcdpow;
// SAFETY: gcdpow has an upper-bound that’s at most the number of bits in an usize.
let gcd = unsafe { unchecked_shl(1usize, gcdpow) };

if p as usize & (gcd.wrapping_sub(1)) == 0 {
// SAFETY: gcd is always greater or equal to 1.
if p as usize & unsafe { unchecked_sub(gcd, 1) } == 0 {
// This branch solves for the following linear congruence equation:
//
// ` p + so = 0 mod a `
//
// `p` here is the pointer value, `s` - stride of `T`, `o` offset in `T`s, and `a` - the
// requested alignment.
//
// With `g = gcd(a, s)`, and the above asserting that `p` is also divisible by `g`, we can
// denote `a' = a/g`, `s' = s/g`, `p' = p/g`, then this becomes equivalent to:
// With `g = gcd(a, s)`, and the above condition asserting that `p` is also divisible by
// `g`, we can denote `a' = a/g`, `s' = s/g`, `p' = p/g`, then this becomes equivalent to:
//
// ` p' + s'o = 0 mod a' `
// ` o = (a' - (p' mod a')) * (s'^-1 mod a') `
Expand All @@ -1259,11 +1266,23 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
//
// Furthermore, the result produced by this solution is not "minimal", so it is necessary
// to take the result `o mod lcm(s, a)`. We can replace `lcm(s, a)` with just a `a'`.
let a2 = a >> gcdpow;
let a2minus1 = a2.wrapping_sub(1);
let s2 = smoda >> gcdpow;
let minusp2 = a2.wrapping_sub(pmoda >> gcdpow);
return (minusp2.wrapping_mul(mod_inv(s2, a2))) & a2minus1;

// SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
// `a`.
let a2 = unsafe { unchecked_shr(a, gcdpow) };
// SAFETY: `a2` is non-zero. Shifting `a` by `gcdpow` cannot shift out any of the set bits
// in `a` (of which it has exactly one).
let a2minus1 = unsafe { unchecked_sub(a2, 1) };
// SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
// `a`.
let s2 = unsafe { unchecked_shr(smoda, gcdpow) };
// SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
// `a`. Furthermore, the subtraction cannot overflow, because `a2 = a >> gcdpow` will
// always be strictly greater than `(p % a) >> gcdpow`.
let minusp2 = unsafe { unchecked_sub(a2, unchecked_shr(pmoda, gcdpow)) };
// SAFETY: `a2` is a power-of-two, as proven above. `s2` is strictly less than `a2`
// because `(s % a) >> gcdpow` is strictly less than `a >> gcdpow`.
return wrapping_mul(minusp2, unsafe { mod_inv(s2, a2) }) & a2minus1;
}

// Cannot be aligned at all.
Expand Down

0 comments on commit 11a44ad

Please sign in to comment.