Skip to content

Commit

Permalink
Initial conversion to const generics
Browse files Browse the repository at this point in the history
  • Loading branch information
Amanieu committed Feb 26, 2021
1 parent fc71718 commit 0c3a550
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 72 deletions.
17 changes: 16 additions & 1 deletion crates/assert-instr-macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ pub fn assert_instr(
);
let mut inputs = Vec::new();
let mut input_vals = Vec::new();
let mut const_vals = Vec::new();
let ret = &func.sig.output;
for arg in func.sig.inputs.iter() {
let capture = match *arg {
Expand All @@ -82,6 +83,20 @@ pub fn assert_instr(
input_vals.push(quote! { #ident });
}
}
for arg in func.sig.generics.params.iter() {
let c = match *arg {
syn::GenericParam::Const(ref c) => c,
ref v => panic!(
"only const generics are allowed: `{:?}`",
v.clone().into_token_stream()
),
};
if let Some(&(_, ref tokens)) = invoc.args.iter().find(|a| c.ident == a.0) {
const_vals.push(quote! { #tokens });
} else {
panic!("const generics must have a value for tests");
}
}

let attrs = func
.attrs
Expand Down Expand Up @@ -133,7 +148,7 @@ pub fn assert_instr(
std::mem::transmute(#shim_name_str.as_bytes().as_ptr()),
std::sync::atomic::Ordering::Relaxed,
);
#name(#(#input_vals),*)
#name::<#(#const_vals),*>(#(#input_vals),*)
}
};

Expand Down
3 changes: 2 additions & 1 deletion crates/core_arch/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@
clippy::shadow_reuse,
clippy::cognitive_complexity,
clippy::similar_names,
clippy::many_single_char_names
clippy::many_single_char_names,
non_upper_case_globals
)]
#![cfg_attr(test, allow(unused_imports))]
#![no_std]
Expand Down
4 changes: 2 additions & 2 deletions crates/core_arch/src/x86/avx512f.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22694,7 +22694,7 @@ pub unsafe fn _mm_mask_shuffle_ps(
) -> __m128 {
macro_rules! call {
($imm8:expr) => {
_mm_shuffle_ps(a, b, $imm8)
_mm_shuffle_ps::<$imm8>(a, b)
};
}
let r = constify_imm8_sae!(imm8, call);
Expand All @@ -22711,7 +22711,7 @@ pub unsafe fn _mm_mask_shuffle_ps(
pub unsafe fn _mm_maskz_shuffle_ps(k: __mmask8, a: __m128, b: __m128, imm8: i32) -> __m128 {
macro_rules! call {
($imm8:expr) => {
_mm_shuffle_ps(a, b, $imm8)
_mm_shuffle_ps::<$imm8>(a, b)
};
}
let r = constify_imm8_sae!(imm8, call);
Expand Down
98 changes: 33 additions & 65 deletions crates/core_arch/src/x86/sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1007,52 +1007,20 @@ pub const fn _MM_SHUFFLE(z: u32, y: u32, x: u32, w: u32) -> i32 {
#[inline]
#[target_feature(enable = "sse")]
#[cfg_attr(test, assert_instr(shufps, mask = 3))]
#[rustc_args_required_const(2)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_shuffle_ps(a: __m128, b: __m128, mask: i32) -> __m128 {
let mask = (mask & 0xFF) as u8;

macro_rules! shuffle_done {
($x01:expr, $x23:expr, $x45:expr, $x67:expr) => {
simd_shuffle4(a, b, [$x01, $x23, $x45, $x67])
};
}
macro_rules! shuffle_x67 {
($x01:expr, $x23:expr, $x45:expr) => {
match (mask >> 6) & 0b11 {
0b00 => shuffle_done!($x01, $x23, $x45, 4),
0b01 => shuffle_done!($x01, $x23, $x45, 5),
0b10 => shuffle_done!($x01, $x23, $x45, 6),
_ => shuffle_done!($x01, $x23, $x45, 7),
}
};
}
macro_rules! shuffle_x45 {
($x01:expr, $x23:expr) => {
match (mask >> 4) & 0b11 {
0b00 => shuffle_x67!($x01, $x23, 4),
0b01 => shuffle_x67!($x01, $x23, 5),
0b10 => shuffle_x67!($x01, $x23, 6),
_ => shuffle_x67!($x01, $x23, 7),
}
};
}
macro_rules! shuffle_x23 {
($x01:expr) => {
match (mask >> 2) & 0b11 {
0b00 => shuffle_x45!($x01, 0),
0b01 => shuffle_x45!($x01, 1),
0b10 => shuffle_x45!($x01, 2),
_ => shuffle_x45!($x01, 3),
}
};
}
match mask & 0b11 {
0b00 => shuffle_x23!(0),
0b01 => shuffle_x23!(1),
0b10 => shuffle_x23!(2),
_ => shuffle_x23!(3),
}
#[rustc_legacy_const_generics(2)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_shuffle_ps<const mask: i32>(a: __m128, b: __m128) -> __m128 {
// `mask` is a const generic parameter that packs four 2-bit lane selectors.
// Reject any value outside `0..=255`, since only eight bits are decoded below.
assert!(mask >= 0 && mask <= 255);
simd_shuffle4(
a,
b,
[
// Output lanes 0 and 1: selectors from bits 0-1 and 2-3, giving indices 0..=3
// (these pick elements of `a`).
mask as u32 & 0b11,
(mask as u32 >> 2) & 0b11,
// Output lanes 2 and 3: selectors from bits 4-5 and 6-7, offset by 4 to give
// indices 4..=7 (these pick elements of `b`).
((mask as u32 >> 4) & 0b11) + 4,
((mask as u32 >> 6) & 0b11) + 4,
],
)
}

/// Unpacks and interleaves single-precision (32-bit) floating-point elements
Expand Down Expand Up @@ -1725,6 +1693,14 @@ pub const _MM_HINT_T2: i32 = 1;
#[stable(feature = "simd_x86", since = "1.27.0")]
pub const _MM_HINT_NTA: i32 = 0;

/// See [`_mm_prefetch`](fn._mm_prefetch.html).
#[stable(feature = "simd_x86", since = "1.27.0")]
pub const _MM_HINT_ET0: i32 = 7;

/// See [`_mm_prefetch`](fn._mm_prefetch.html).
#[stable(feature = "simd_x86", since = "1.27.0")]
pub const _MM_HINT_ET1: i32 = 6;

/// Fetch the cache line that contains address `p` using the given `strategy`.
///
/// The `strategy` must be one of:
Expand All @@ -1742,6 +1718,10 @@ pub const _MM_HINT_NTA: i32 = 0;
/// but outside of the cache hierarchy. This is used to reduce access latency
/// without polluting the cache.
///
/// * [`_MM_HINT_ET0`](constant._MM_HINT_ET0.html) and
/// [`_MM_HINT_ET1`](constant._MM_HINT_ET1.html) are similar to `_MM_HINT_T0`
/// and `_MM_HINT_T1` but indicate an anticipation to write to the address.
///
/// The actual implementation depends on the particular CPU. This instruction
/// is considered a hint, so the CPU is also free to simply ignore the request.
///
Expand Down Expand Up @@ -1769,24 +1749,12 @@ pub const _MM_HINT_NTA: i32 = 0;
#[cfg_attr(test, assert_instr(prefetcht1, strategy = _MM_HINT_T1))]
#[cfg_attr(test, assert_instr(prefetcht2, strategy = _MM_HINT_T2))]
#[cfg_attr(test, assert_instr(prefetchnta, strategy = _MM_HINT_NTA))]
#[rustc_args_required_const(1)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_prefetch(p: *const i8, strategy: i32) {
// The `strategy` must be a compile-time constant, so we use a short form
// of `constify_imm8!` for now.
// We use the `llvm.prefetch` intrinsic with `rw` = 0 (read), and
// `cache type` = 1 (data cache). `locality` is based on our `strategy`.
macro_rules! pref {
($imm8:expr) => {
match $imm8 {
0 => prefetch(p, 0, 0, 1),
1 => prefetch(p, 0, 1, 1),
2 => prefetch(p, 0, 2, 1),
_ => prefetch(p, 0, 3, 1),
}
};
}
pref!(strategy)
#[rustc_legacy_const_generics(1)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_prefetch<const strategy: i32>(p: *const i8) {
// We use the `llvm.prefetch` intrinsic with `cache type` = 1 (data cache).
// `locality` and `rw` are based on our `strategy`:
// bit 2 of `strategy` is passed as `rw` and the low two bits as `locality`,
// so `_MM_HINT_ET0` (7) and `_MM_HINT_ET1` (6) yield `rw` = 1 with
// `locality` 3 and 2 respectively, while values 0..=3 yield `rw` = 0.
prefetch(p, (strategy >> 2) & 1, strategy & 3, 1);
}

/// Returns vector of type __m128 with undefined elements.
Expand Down Expand Up @@ -2976,7 +2944,7 @@ mod tests {
unsafe fn test_mm_shuffle_ps() {
let a = _mm_setr_ps(1.0, 2.0, 3.0, 4.0);
let b = _mm_setr_ps(5.0, 6.0, 7.0, 8.0);
let r = _mm_shuffle_ps(a, b, 0b00_01_01_11);
let r = _mm_shuffle_ps::<0b00_01_01_11>(a, b);
assert_eq_m128(r, _mm_setr_ps(4.0, 2.0, 6.0, 5.0));
}

Expand Down
30 changes: 27 additions & 3 deletions crates/stdarch-verify/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,21 @@ fn functions(input: TokenStream, dirs: &[&str]) -> TokenStream {
let name = &f.sig.ident;
// println!("{}", name);
let mut arguments = Vec::new();
let mut const_arguments = Vec::new();
for input in f.sig.inputs.iter() {
let ty = match *input {
syn::FnArg::Typed(ref c) => &c.ty,
_ => panic!("invalid argument on {}", name),
};
arguments.push(to_type(ty));
}
for generic in f.sig.generics.params.iter() {
let ty = match *generic {
syn::GenericParam::Const(ref c) => &c.ty,
_ => panic!("invalid generic argument on {}", name),
};
const_arguments.push(to_type(ty));
}
let ret = match f.sig.output {
syn::ReturnType::Default => quote! { None },
syn::ReturnType::Type(_, ref t) => {
Expand All @@ -101,7 +109,23 @@ fn functions(input: TokenStream, dirs: &[&str]) -> TokenStream {
} else {
quote! { None }
};
let required_const = find_required_const(&f.attrs);

let required_const = find_required_const("rustc_args_required_const", &f.attrs);
let mut legacy_const_generics =
find_required_const("rustc_legacy_const_generics", &f.attrs);
if !required_const.is_empty() && !legacy_const_generics.is_empty() {
panic!(
"Can't have both #[rustc_args_required_const] and \
#[rustc_legacy_const_generics]"
);
}
legacy_const_generics.sort();
for (idx, ty) in legacy_const_generics
.into_iter()
.zip(const_arguments.into_iter())
{
arguments.insert(idx, ty);
}

// strip leading underscore from fn name when building a test
// _mm_foo -> mm_foo such that the test name is test_mm_foo.
Expand Down Expand Up @@ -390,11 +414,11 @@ fn find_target_feature(attrs: &[syn::Attribute]) -> Option<syn::Lit> {
})
}

fn find_required_const(attrs: &[syn::Attribute]) -> Vec<usize> {
fn find_required_const(name: &str, attrs: &[syn::Attribute]) -> Vec<usize> {
attrs
.iter()
.flat_map(|a| {
if a.path.segments[0].ident == "rustc_args_required_const" {
if a.path.segments[0].ident == name {
syn::parse::<RustcArgsRequiredConst>(a.tokens.clone().into())
.unwrap()
.args
Expand Down

0 comments on commit 0c3a550

Please sign in to comment.