Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 4 additions & 48 deletions src/shims/x86/avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use rustc_target::callconv::FnAbi;

use super::{
ShiftOp, horizontal_bin_op, mask_load, mask_store, mpsadbw, packssdw, packsswb, packusdw,
packuswb, pmulhrsw, psadbw, psign, shift_simd_by_scalar, shift_simd_by_simd,
packuswb, permute, pmaddbw, pmulhrsw, psadbw, psign, shift_simd_by_scalar, shift_simd_by_simd,
};
use crate::*;

Expand Down Expand Up @@ -102,39 +102,11 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
}
// Used to implement the _mm256_maddubs_epi16 function.
// Multiplies packed 8-bit unsigned integers from `left` and packed
// signed 8-bit integers from `right` into 16-bit signed integers. Then,
// the saturating sum of the products with indices `2*i` and `2*i+1`
// produces the output at index `i`.
"pmadd.ub.sw" => {
let [left, right] =
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;

let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(left_len, right_len);
assert_eq!(dest_len.strict_mul(2), left_len);

for i in 0..dest_len {
let j1 = i.strict_mul(2);
let left1 = this.read_scalar(&this.project_index(&left, j1)?)?.to_u8()?;
let right1 = this.read_scalar(&this.project_index(&right, j1)?)?.to_i8()?;

let j2 = j1.strict_add(1);
let left2 = this.read_scalar(&this.project_index(&left, j2)?)?.to_u8()?;
let right2 = this.read_scalar(&this.project_index(&right, j2)?)?.to_i8()?;

let dest = this.project_index(&dest, i)?;

// Multiplication of a u8 and an i8 into an i16 cannot overflow.
let mul1 = i16::from(left1).strict_mul(right1.into());
let mul2 = i16::from(left2).strict_mul(right2.into());
let res = mul1.saturating_add(mul2);

this.write_scalar(Scalar::from_i16(res), &dest)?;
}
pmaddbw(this, left, right, dest)?;
}
// Used to implement the _mm_maskload_epi32, _mm_maskload_epi64,
// _mm256_maskload_epi32 and _mm256_maskload_epi64 functions.
Expand Down Expand Up @@ -217,28 +189,12 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {

packusdw(this, left, right, dest)?;
}
// Used to implement the _mm256_permutevar8x32_epi32 and
// _mm256_permutevar8x32_ps function.
// Shuffles `left` using the three low bits of each element of `right`
// as indices.
// Used to implement _mm256_permutevar8x32_epi32 and _mm256_permutevar8x32_ps.
"permd" | "permps" => {
let [left, right] =
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;

let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);

for i in 0..dest_len {
let dest = this.project_index(&dest, i)?;
let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u32()?;
let left = this.project_index(&left, (right & 0b111).into())?;

this.copy_op(&left, &dest)?;
}
permute(this, left, right, dest)?;
}
// Used to implement the _mm256_sad_epu8 function.
"psad.bw" => {
Expand Down
16 changes: 15 additions & 1 deletion src/shims/x86/avx512.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use rustc_middle::ty::Ty;
use rustc_span::Symbol;
use rustc_target::callconv::FnAbi;

use super::psadbw;
use super::{permute, pmaddbw, psadbw};
use crate::*;

impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
Expand Down Expand Up @@ -88,6 +88,20 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {

psadbw(this, left, right, dest)?
}
// Used to implement the _mm512_maddubs_epi16 function.
"pmaddubs.w.512" => {
let [left, right] =
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;

pmaddbw(this, left, right, dest)?;
}
// Used to implement the _mm512_permutexvar_epi32 function.
"permvar.si.512" => {
let [left, right] =
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;

permute(this, left, right, dest)?;
}
_ => return interp_ok(EmulateItemResult::NotSupported),
}
interp_ok(EmulateItemResult::NeedsReturn)
Expand Down
84 changes: 84 additions & 0 deletions src/shims/x86/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,90 @@ fn psadbw<'tcx>(
interp_ok(())
}

/// Shared implementation of the `maddubs` family of intrinsics.
///
/// Each element of `left` is read as an unsigned 8-bit integer and each
/// element of `right` as a signed 8-bit integer. Elements at the same
/// position are multiplied into a signed 16-bit product, and output element
/// `i` receives the saturating sum of the products at positions `2*i` and
/// `2*i+1`.
///
/// `left` and `right` must have the same number of elements and `dest` must
/// have exactly half as many; both conditions are asserted.
///
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_maddubs_epi16>
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_maddubs_epi16>
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_maddubs_epi16>
fn pmaddbw<'tcx>(
    ecx: &mut crate::MiriInterpCx<'tcx>,
    left: &OpTy<'tcx>,
    right: &OpTy<'tcx>,
    dest: &MPlaceTy<'tcx>,
) -> InterpResult<'tcx, ()> {
    let (unsigned_src, unsigned_len) = ecx.project_to_simd(left)?;
    let (signed_src, signed_len) = ecx.project_to_simd(right)?;
    let (out, out_len) = ecx.project_to_simd(dest)?;

    // Expected operand shapes:
    //   fn pmaddubsw128(a: u8x16, b: i8x16) -> i16x8;
    //   fn pmaddubsw(   a: u8x32, b: i8x32) -> i16x16;
    //   fn vpmaddubsw(  a: u8x64, b: i8x64) -> i16x32;
    assert_eq!(unsigned_len, signed_len);
    assert_eq!(out_len.strict_mul(2), unsigned_len);

    for out_idx in 0..out_len {
        // First element of the input pair feeding this output element.
        let even = out_idx.strict_mul(2);
        let unsigned_even = ecx.read_scalar(&ecx.project_index(&unsigned_src, even)?)?.to_u8()?;
        let signed_even = ecx.read_scalar(&ecx.project_index(&signed_src, even)?)?.to_i8()?;

        // Second element of the pair.
        let odd = even.strict_add(1);
        let unsigned_odd = ecx.read_scalar(&ecx.project_index(&unsigned_src, odd)?)?.to_u8()?;
        let signed_odd = ecx.read_scalar(&ecx.project_index(&signed_src, odd)?)?.to_i8()?;

        let slot = ecx.project_index(&out, out_idx)?;

        // A `u8` times an `i8` always fits in an `i16`, so the products
        // themselves never overflow; only their sum may need to saturate.
        let product_even = i16::from(unsigned_even).strict_mul(i16::from(signed_even));
        let product_odd = i16::from(unsigned_odd).strict_mul(i16::from(signed_odd));
        let sum = product_even.saturating_add(product_odd);

        ecx.write_scalar(Scalar::from_i16(sum), &slot)?;
    }

    interp_ok(())
}

/// Shared implementation of the variable-index permute intrinsics.
///
/// Output element `i` is a copy of the element of `values` selected by
/// element `i` of `indices`. Each index is masked with `len - 1` before use,
/// so out-of-range indices wrap into range instead of being rejected.
///
/// `values`, `indices` and `dest` must all have the same number of elements,
/// and that number must be a power of two; both conditions are asserted.
///
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_permutevar8x32_epi32>
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_permutevar8x32_ps>
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_permutexvar_epi32>
fn permute<'tcx>(
    ecx: &mut crate::MiriInterpCx<'tcx>,
    values: &OpTy<'tcx>,
    indices: &OpTy<'tcx>,
    dest: &MPlaceTy<'tcx>,
) -> InterpResult<'tcx, ()> {
    let (source, source_len) = ecx.project_to_simd(values)?;
    let (selectors, selectors_len) = ecx.project_to_simd(indices)?;
    let (out, out_len) = ecx.project_to_simd(dest)?;

    // Expected operand shapes:
    //   fn permd( a: u32x8,  b: u32x8)    -> u32x8;
    //   fn permps(a: __m256, b: i32x8)    -> __m256;
    //   fn vpermd(a: i32x16, idx: i32x16) -> i32x16;
    assert_eq!(out_len, source_len);
    assert_eq!(out_len, selectors_len);

    // With a power-of-two lane count, `len - 1` keeps exactly the low bits
    // that address a lane: 3 bits for 8 lanes, 4 bits for 16 lanes.
    assert!(out_len.is_power_of_two());
    let lane_mask = u32::try_from(out_len).unwrap().strict_sub(1);

    for lane in 0..out_len {
        let target = ecx.project_index(&out, lane)?;
        let selector = ecx.read_scalar(&ecx.project_index(&selectors, lane)?)?.to_u32()?;
        let picked_lane = u64::from(selector & lane_mask);
        let picked = ecx.project_index(&source, picked_lane)?;

        ecx.copy_op(&picked, &target)?;
    }

    interp_ok(())
}

/// Multiplies packed 16-bit signed integer values, truncates the 32-bit
/// product to the 18 most significant bits by right-shifting, and then
/// divides the 18-bit value by 2 (rounding to nearest) by first adding
Expand Down
33 changes: 2 additions & 31 deletions src/shims/x86/ssse3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use rustc_middle::ty::Ty;
use rustc_span::Symbol;
use rustc_target::callconv::FnAbi;

use super::{horizontal_bin_op, pmulhrsw, psign};
use super::{horizontal_bin_op, pmaddbw, pmulhrsw, psign};
use crate::*;

impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
Expand Down Expand Up @@ -67,40 +67,11 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
horizontal_bin_op(this, which, /*saturating*/ true, left, right, dest)?;
}
// Used to implement the _mm_maddubs_epi16 function.
// Multiplies packed 8-bit unsigned integers from `left` and packed
// signed 8-bit integers from `right` into 16-bit signed integers. Then,
// the saturating sum of the products with indices `2*i` and `2*i+1`
// produces the output at index `i`.
// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_maddubs_epi16
"pmadd.ub.sw.128" => {
let [left, right] =
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;

let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(left_len, right_len);
assert_eq!(dest_len.strict_mul(2), left_len);

for i in 0..dest_len {
let j1 = i.strict_mul(2);
let left1 = this.read_scalar(&this.project_index(&left, j1)?)?.to_u8()?;
let right1 = this.read_scalar(&this.project_index(&right, j1)?)?.to_i8()?;

let j2 = j1.strict_add(1);
let left2 = this.read_scalar(&this.project_index(&left, j2)?)?.to_u8()?;
let right2 = this.read_scalar(&this.project_index(&right, j2)?)?.to_i8()?;

let dest = this.project_index(&dest, i)?;

// Multiplication of a u8 and an i8 into an i16 cannot overflow.
let mul1 = i16::from(left1).strict_mul(right1.into());
let mul2 = i16::from(left2).strict_mul(right2.into());
let res = mul1.saturating_add(mul2);

this.write_scalar(Scalar::from_i16(res), &dest)?;
}
pmaddbw(this, left, right, dest)?;
}
// Used to implement the _mm_mulhrs_epi16 function.
// Multiplies packed 16-bit signed integer values, truncates the 32-bit
Expand Down
88 changes: 88 additions & 0 deletions tests/pass/shims/x86/intrinsics-x86-avx512.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,94 @@ unsafe fn test_avx512() {
assert_eq_m512i(r, e);
}
test_mm512_sad_epu8();

#[target_feature(enable = "avx512bw")]
unsafe fn test_mm512_maddubs_epi16() {
    // The first operand is interpreted as unsigned bytes (`u8x64`), but
    // `_mm512_set_epi8` takes `i8` arguments, so values above 127 are
    // written as `u8` literals and cast. The same 16-byte pattern is
    // repeated four times to fill the 512-bit vector.
    #[rustfmt::skip]
    let unsigned_bytes = _mm512_set_epi8(
        255u8 as i8, 255u8 as i8, 60, 50, 100, 100, 255u8 as i8, 200u8 as i8,
        255u8 as i8, 200u8 as i8, 200u8 as i8, 100, 60, 50, 20, 10,

        255u8 as i8, 255u8 as i8, 60, 50, 100, 100, 255u8 as i8, 200u8 as i8,
        255u8 as i8, 200u8 as i8, 200u8 as i8, 100, 60, 50, 20, 10,

        255u8 as i8, 255u8 as i8, 60, 50, 100, 100, 255u8 as i8, 200u8 as i8,
        255u8 as i8, 200u8 as i8, 200u8 as i8, 100, 60, 50, 20, 10,

        255u8 as i8, 255u8 as i8, 60, 50, 100, 100, 255u8 as i8, 200u8 as i8,
        255u8 as i8, 200u8 as i8, 200u8 as i8, 100, 60, 50, 20, 10,
    );

    // The second operand is interpreted as signed bytes.
    let signed_bytes = _mm512_set_epi8(
        64, 64, -2, 1, 100, 100, -128, -128, //
        127, 127, -1, 1, 2, 2, 1, 1, //
        64, 64, -2, 1, 100, 100, -128, -128, //
        127, 127, -1, 1, 2, 2, 1, 1, //
        64, 64, -2, 1, 100, 100, -128, -128, //
        127, 127, -1, 1, 2, 2, 1, 1, //
        64, 64, -2, 1, 100, 100, -128, -128, //
        127, 127, -1, 1, 2, 2, 1, 1, //
    );

    let actual = _mm512_maddubs_epi16(unsigned_bytes, signed_bytes);

    // Each output is the saturating sum of two adjacent products, e.g.
    //   255*64   + 255*64   =  32640
    //   60*-2    + 50*1     =    -70
    //   255*-128 + 200*-128 = -58240, which saturates to -32768
    //   255*127  + 200*127  =  57785, which saturates to  32767
    let expected = _mm512_set_epi16(
        32640, -70, 20000, -32768, 32767, -100, 220, 30, //
        32640, -70, 20000, -32768, 32767, -100, 220, 30, //
        32640, -70, 20000, -32768, 32767, -100, 220, 30, //
        32640, -70, 20000, -32768, 32767, -100, 220, 30, //
    );

    assert_eq_m512i(actual, expected);
}
test_mm512_maddubs_epi16();

#[target_feature(enable = "avx512f")]
unsafe fn test_mm512_permutexvar_epi32() {
    // `_mm512_set_epi32` lists lanes from highest to lowest, so lane `i`
    // holds the value `i`. A permuted result therefore shows directly which
    // source lane each output lane was taken from.
    let source = _mm512_set_epi32(
        15, 14, 13, 12, //
        11, 10, 9, 8, //
        7, 6, 5, 4, //
        3, 2, 1, 0, //
    );

    // Selecting lane `i` for output lane `i` must reproduce the input.
    let identity = _mm512_set_epi32(
        15, 14, 13, 12, //
        11, 10, 9, 8, //
        7, 6, 5, 4, //
        3, 2, 1, 0, //
    );
    let unchanged = _mm512_permutexvar_epi32(identity, source);
    assert_eq_m512i(unchanged, source);

    // Out-of-bounds and negative indices: the expected values below
    // correspond to keeping only the low four bits of each index.
    let wild_indices = _mm512_set_epi32(
        0, -1, -128, i32::MIN, 15, 16, 128, i32::MAX, //
        0, -1, -128, i32::MIN, 15, 16, 128, i32::MAX, //
    );

    let actual = _mm512_permutexvar_epi32(wild_indices, source);

    let expected = _mm512_set_epi32(
        0, 15, 0, 0, 15, 0, 0, 15, //
        0, 15, 0, 0, 15, 0, 0, 15, //
    );

    assert_eq_m512i(actual, expected);
}
test_mm512_permutexvar_epi32();
}

// Some of the constants in the tests below are just bit patterns. They should not
Expand Down