Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 2 additions & 32 deletions src/shims/x86/avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use rustc_target::callconv::FnAbi;

use super::{
ShiftOp, horizontal_bin_op, mask_load, mask_store, mpsadbw, packssdw, packsswb, packusdw,
packuswb, pmulhrsw, psign, shift_simd_by_scalar, shift_simd_by_simd,
packuswb, pmulhrsw, psadbw, psign, shift_simd_by_scalar, shift_simd_by_simd,
};
use crate::*;

Expand Down Expand Up @@ -241,41 +241,11 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
}
// Used to implement the _mm256_sad_epu8 function.
// Compute the absolute differences of packed unsigned 8-bit integers
// in `left` and `right`, then horizontally sum each consecutive 8
// differences to produce four unsigned 16-bit integers, and pack
// these unsigned 16-bit integers in the low 16 bits of 64-bit elements
// in `dest`.
// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_sad_epu8
"psad.bw" => {
let [left, right] =
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;

let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(left_len, right_len);
assert_eq!(left_len, dest_len.strict_mul(8));

for i in 0..dest_len {
let dest = this.project_index(&dest, i)?;

let mut acc: u16 = 0;
for j in 0..8 {
let src_index = i.strict_mul(8).strict_add(j);

let left = this.project_index(&left, src_index)?;
let left = this.read_scalar(&left)?.to_u8()?;

let right = this.project_index(&right, src_index)?;
let right = this.read_scalar(&right)?.to_u8()?;

acc = acc.strict_add(left.abs_diff(right).into());
}

this.write_scalar(Scalar::from_u64(acc.into()), &dest)?;
}
psadbw(this, left, right, dest)?
}
// Used to implement the _mm256_shuffle_epi8 intrinsic.
// Shuffles bytes from `left` using `right` as pattern.
Expand Down
10 changes: 10 additions & 0 deletions src/shims/x86/avx512.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use rustc_middle::ty::Ty;
use rustc_span::Symbol;
use rustc_target::callconv::FnAbi;

use super::psadbw;
use crate::*;

impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
Expand Down Expand Up @@ -78,6 +79,15 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
this.write_scalar(Scalar::from_u32(r), &d_lane)?;
}
}
// Used to implement the _mm512_sad_epu8 function.
"psad.bw.512" => {
this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;

let [left, right] =
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;

psadbw(this, left, right, dest)?
}
_ => return interp_ok(EmulateItemResult::NotSupported),
}
interp_ok(EmulateItemResult::NeedsReturn)
Expand Down
48 changes: 48 additions & 0 deletions src/shims/x86/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1038,6 +1038,54 @@ fn mpsadbw<'tcx>(
interp_ok(())
}

/// Shared implementation of the "sum of absolute differences" intrinsics
/// (`psad.bw` and its wider variants).
///
/// `left` and `right` are read as vectors of unsigned 8-bit integers. For
/// each group of 8 consecutive bytes, the absolute differences between the
/// corresponding bytes of `left` and `right` are added up into an unsigned
/// 16-bit sum. Each sum is then written, zero-extended, into the matching
/// 64-bit element of `dest`.
///
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_sad_epu8>
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_sad_epu8>
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_sad_epu8>
fn psadbw<'tcx>(
    ecx: &mut crate::MiriInterpCx<'tcx>,
    left: &OpTy<'tcx>,
    right: &OpTy<'tcx>,
    dest: &MPlaceTy<'tcx>,
) -> InterpResult<'tcx, ()> {
    let (left, left_len) = ecx.project_to_simd(left)?;
    let (right, right_len) = ecx.project_to_simd(right)?;
    let (dest, dest_len) = ecx.project_to_simd(dest)?;

    // Signatures this helper is expected to serve:
    //   fn psadbw(a: u8x16, b: u8x16) -> u64x2;
    //   fn psadbw(a: u8x32, b: u8x32) -> u64x4;
    //   fn vpsadbw(a: u8x64, b: u8x64) -> u64x8;
    assert_eq!(left_len, right_len);
    // The number of input elements must equal the input size in bytes,
    // i.e. the inputs are vectors of single bytes.
    assert_eq!(left_len, left.layout.layout.size().bytes());
    // Every output element consumes exactly 8 input bytes.
    assert_eq!(dest_len, left_len.strict_div(8));

    for lane in 0..dest_len {
        let out = ecx.project_index(&dest, lane)?;

        // Index of the first input byte that belongs to this output lane.
        let first_byte = lane.strict_mul(8);

        // At most 8 * 255, so the sum always fits in 16 bits.
        let mut sum: u16 = 0;
        for offset in 0..8 {
            let byte_index = first_byte.strict_add(offset);

            let l = ecx.read_scalar(&ecx.project_index(&left, byte_index)?)?.to_u8()?;
            let r = ecx.read_scalar(&ecx.project_index(&right, byte_index)?)?.to_u8()?;

            sum = sum.strict_add(u16::from(l.abs_diff(r)));
        }

        ecx.write_scalar(Scalar::from_u64(u64::from(sum)), &out)?;
    }

    interp_ok(())
}

/// Multiplies packed 16-bit signed integer values, truncates the 32-bit
/// product to the 18 most significant bits by right-shifting, and then
/// divides the 18-bit value by 2 (rounding to nearest) by first adding
Expand Down
34 changes: 2 additions & 32 deletions src/shims/x86/sse2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use rustc_target::callconv::FnAbi;

use super::{
FloatBinOp, ShiftOp, bin_op_simd_float_all, bin_op_simd_float_first, convert_float_to_int,
packssdw, packsswb, packuswb, shift_simd_by_scalar,
packssdw, packsswb, packuswb, psadbw, shift_simd_by_scalar,
};
use crate::*;

Expand Down Expand Up @@ -37,41 +37,11 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
// vectors.
match unprefixed_name {
// Used to implement the _mm_sad_epu8 function.
// Computes the absolute differences of packed unsigned 8-bit integers in `a`
// and `b`, then horizontally sum each consecutive 8 differences to produce
// two unsigned 16-bit integers, and pack these unsigned 16-bit integers in
// the low 16 bits of 64-bit elements returned.
//
// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_sad_epu8
"psad.bw" => {
let [left, right] =
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;

let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

// left and right are u8x16, dest is u64x2
assert_eq!(left_len, right_len);
assert_eq!(left_len, 16);
assert_eq!(dest_len, 2);

for i in 0..dest_len {
let dest = this.project_index(&dest, i)?;

let mut res: u16 = 0;
let n = left_len.strict_div(dest_len);
for j in 0..n {
let op_i = j.strict_add(i.strict_mul(n));
let left = this.read_scalar(&this.project_index(&left, op_i)?)?.to_u8()?;
let right =
this.read_scalar(&this.project_index(&right, op_i)?)?.to_u8()?;

res = res.strict_add(left.abs_diff(right).into());
}

this.write_scalar(Scalar::from_u64(res.into()), &dest)?;
}
psadbw(this, left, right, dest)?
}
// Used to implement the _mm_{sll,srl,sra}_epi{16,32,64} functions
// (except _mm_sra_epi64, which is not available in SSE2).
Expand Down
36 changes: 36 additions & 0 deletions tests/pass/shims/x86/intrinsics-x86-avx512.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,48 @@ fn main() {
assert!(is_x86_feature_detected!("avx512vpopcntdq"));

unsafe {
test_avx512();
test_avx512bitalg();
test_avx512vpopcntdq();
test_avx512ternarylogic();
}
}

#[target_feature(enable = "avx512bw")]
unsafe fn test_avx512() {
    // Checks `_mm512_sad_epu8`: each 64-bit output lane must hold the sum of
    // the 8 absolute byte differences of the corresponding input bytes.
    #[target_feature(enable = "avx512bw")]
    unsafe fn test_mm512_sad_epu8() {
        // Arguments are listed one 8-byte group per row; the first row ends
        // up in lane 7 and the last row in lane 0.
        let lhs = _mm512_set_epi8(
            71, 70, 69, 68, 67, 66, 65, 64, //
            55, 54, 53, 52, 51, 50, 49, 48, //
            47, 46, 45, 44, 43, 42, 41, 40, //
            39, 38, 37, 36, 35, 34, 33, 32, //
            31, 30, 29, 28, 27, 26, 25, 24, //
            23, 22, 21, 20, 19, 18, 17, 16, //
            15, 14, 13, 12, 11, 10, 9, 8, //
            7, 6, 5, 4, 3, 2, 1, 0, //
        );

        // Every byte in a row differs from the byte at the same position in
        // `lhs` by the same amount `d`, so the row's expected sum is `8 * d`.
        let rhs = _mm512_set_epi8(
            63, 62, 61, 60, 59, 58, 57, 56, // lane 7 (d = 8)
            62, 61, 60, 59, 58, 57, 56, 55, // lane 6 (d = 7)
            53, 52, 51, 50, 49, 48, 47, 46, // lane 5 (d = 6)
            44, 43, 42, 41, 40, 39, 38, 37, // lane 4 (d = 5)
            35, 34, 33, 32, 31, 30, 29, 28, // lane 3 (d = 4)
            26, 25, 24, 23, 22, 21, 20, 19, // lane 2 (d = 3)
            17, 16, 15, 14, 13, 12, 11, 10, // lane 1 (d = 2)
            8, 7, 6, 5, 4, 3, 2, 1, // lane 0 (d = 1)
        );

        // Expected sums, listed from lane 7 down to lane 0.
        let expected = _mm512_set_epi64(64, 56, 48, 40, 32, 24, 16, 8);

        assert_eq_m512i(_mm512_sad_epu8(lhs, rhs), expected);
    }
    test_mm512_sad_epu8();
}

// Some of the constants in the tests below are just bit patterns. They should not
// be interpreted as integers; signedness does not make sense for them, but
// __mXXXi happens to be defined in terms of signed integers.
Expand Down