diff --git a/src/shims/x86/avx2.rs b/src/shims/x86/avx2.rs index 01e1ac6de5..142258c697 100644 --- a/src/shims/x86/avx2.rs +++ b/src/shims/x86/avx2.rs @@ -6,7 +6,7 @@ use rustc_target::callconv::FnAbi; use super::{ ShiftOp, horizontal_bin_op, mask_load, mask_store, mpsadbw, packssdw, packsswb, packusdw, - packuswb, pmulhrsw, psign, shift_simd_by_scalar, shift_simd_by_simd, + packuswb, pmulhrsw, psadbw, psign, shift_simd_by_scalar, shift_simd_by_simd, }; use crate::*; @@ -241,41 +241,11 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { } } // Used to implement the _mm256_sad_epu8 function. - // Compute the absolute differences of packed unsigned 8-bit integers - // in `left` and `right`, then horizontally sum each consecutive 8 - // differences to produce four unsigned 16-bit integers, and pack - // these unsigned 16-bit integers in the low 16 bits of 64-bit elements - // in `dest`. - // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_sad_epu8 "psad.bw" => { let [left, right] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?; - let (left, left_len) = this.project_to_simd(left)?; - let (right, right_len) = this.project_to_simd(right)?; - let (dest, dest_len) = this.project_to_simd(dest)?; - - assert_eq!(left_len, right_len); - assert_eq!(left_len, dest_len.strict_mul(8)); - - for i in 0..dest_len { - let dest = this.project_index(&dest, i)?; - - let mut acc: u16 = 0; - for j in 0..8 { - let src_index = i.strict_mul(8).strict_add(j); - - let left = this.project_index(&left, src_index)?; - let left = this.read_scalar(&left)?.to_u8()?; - - let right = this.project_index(&right, src_index)?; - let right = this.read_scalar(&right)?.to_u8()?; - - acc = acc.strict_add(left.abs_diff(right).into()); - } - - this.write_scalar(Scalar::from_u64(acc.into()), &dest)?; - } + psadbw(this, left, right, dest)? } // Used to implement the _mm256_shuffle_epi8 intrinsic. // Shuffles bytes from `left` using `right` as pattern. diff --git a/src/shims/x86/avx512.rs b/src/shims/x86/avx512.rs index e15b99beba..4957b3b88c 100644 --- a/src/shims/x86/avx512.rs +++ b/src/shims/x86/avx512.rs @@ -3,6 +3,7 @@ use rustc_middle::ty::Ty; use rustc_span::Symbol; use rustc_target::callconv::FnAbi; +use super::psadbw; use crate::*; impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {} @@ -78,6 +79,15 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { this.write_scalar(Scalar::from_u32(r), &d_lane)?; } } + // Used to implement the _mm512_sad_epu8 function. + "psad.bw.512" => { + this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?; + + let [left, right] = + this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?; + + psadbw(this, left, right, dest)? + } _ => return interp_ok(EmulateItemResult::NotSupported), } interp_ok(EmulateItemResult::NeedsReturn) diff --git a/src/shims/x86/mod.rs b/src/shims/x86/mod.rs index c730609a4a..258ad9f8de 100644 --- a/src/shims/x86/mod.rs +++ b/src/shims/x86/mod.rs @@ -1038,6 +1038,54 @@ fn mpsadbw<'tcx>( interp_ok(()) } +/// Compute the absolute differences of packed unsigned 8-bit integers +/// in `left` and `right`, then horizontally sum each consecutive 8 +/// differences to produce unsigned 16-bit integers, and pack +/// these unsigned 16-bit integers in the low 16 bits of 64-bit elements +/// in `dest`. +/// +/// +/// +/// +fn psadbw<'tcx>( + ecx: &mut crate::MiriInterpCx<'tcx>, + left: &OpTy<'tcx>, + right: &OpTy<'tcx>, + dest: &MPlaceTy<'tcx>, +) -> InterpResult<'tcx, ()> { + let (left, left_len) = ecx.project_to_simd(left)?; + let (right, right_len) = ecx.project_to_simd(right)?; + let (dest, dest_len) = ecx.project_to_simd(dest)?; + + // fn psadbw(a: u8x16, b: u8x16) -> u64x2; + // fn psadbw(a: u8x32, b: u8x32) -> u64x4; + // fn vpsadbw(a: u8x64, b: u8x64) -> u64x8; + assert_eq!(left_len, right_len); + assert_eq!(left_len, left.layout.layout.size().bytes()); + assert_eq!(dest_len, left_len.strict_div(8)); + + for i in 0..dest_len { + let dest = ecx.project_index(&dest, i)?; + + let mut acc: u16 = 0; + for j in 0..8 { + let src_index = i.strict_mul(8).strict_add(j); + + let left = ecx.project_index(&left, src_index)?; + let left = ecx.read_scalar(&left)?.to_u8()?; + + let right = ecx.project_index(&right, src_index)?; + let right = ecx.read_scalar(&right)?.to_u8()?; + + acc = acc.strict_add(left.abs_diff(right).into()); + } + + ecx.write_scalar(Scalar::from_u64(acc.into()), &dest)?; + } + + interp_ok(()) +} + /// Multiplies packed 16-bit signed integer values, truncates the 32-bit /// product to the 18 most significant bits by right-shifting, and then /// divides the 18-bit value by 2 (rounding to nearest) by first adding diff --git a/src/shims/x86/sse2.rs b/src/shims/x86/sse2.rs index 9af7f05d3b..8389813903 100644 --- a/src/shims/x86/sse2.rs +++ b/src/shims/x86/sse2.rs @@ -6,7 +6,7 @@ use rustc_target::callconv::FnAbi; use super::{ FloatBinOp, ShiftOp, bin_op_simd_float_all, bin_op_simd_float_first, convert_float_to_int, - packssdw, packsswb, packuswb, shift_simd_by_scalar, + packssdw, packsswb, packuswb, psadbw, shift_simd_by_scalar, }; use crate::*; @@ -37,41 +37,11 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { // vectors. match unprefixed_name { // Used to implement the _mm_sad_epu8 function. - // Computes the absolute differences of packed unsigned 8-bit integers in `a` - // and `b`, then horizontally sum each consecutive 8 differences to produce - // two unsigned 16-bit integers, and pack these unsigned 16-bit integers in - // the low 16 bits of 64-bit elements returned. - // - // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_sad_epu8 "psad.bw" => { let [left, right] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?; - let (left, left_len) = this.project_to_simd(left)?; - let (right, right_len) = this.project_to_simd(right)?; - let (dest, dest_len) = this.project_to_simd(dest)?; - - // left and right are u8x16, dest is u64x2 - assert_eq!(left_len, right_len); - assert_eq!(left_len, 16); - assert_eq!(dest_len, 2); - - for i in 0..dest_len { - let dest = this.project_index(&dest, i)?; - - let mut res: u16 = 0; - let n = left_len.strict_div(dest_len); - for j in 0..n { - let op_i = j.strict_add(i.strict_mul(n)); - let left = this.read_scalar(&this.project_index(&left, op_i)?)?.to_u8()?; - let right = - this.read_scalar(&this.project_index(&right, op_i)?)?.to_u8()?; - - res = res.strict_add(left.abs_diff(right).into()); - } - - this.write_scalar(Scalar::from_u64(res.into()), &dest)?; - } + psadbw(this, left, right, dest)? } // Used to implement the _mm_{sll,srl,sra}_epi{16,32,64} functions // (except _mm_sra_epi64, which is not available in SSE2). diff --git a/tests/pass/shims/x86/intrinsics-x86-avx512.rs b/tests/pass/shims/x86/intrinsics-x86-avx512.rs index c22227a8c6..0a9bb2d315 100644 --- a/tests/pass/shims/x86/intrinsics-x86-avx512.rs +++ b/tests/pass/shims/x86/intrinsics-x86-avx512.rs @@ -15,12 +15,48 @@ fn main() { assert!(is_x86_feature_detected!("avx512vpopcntdq")); unsafe { + test_avx512(); test_avx512bitalg(); test_avx512vpopcntdq(); test_avx512ternarylogic(); } } +#[target_feature(enable = "avx512bw")] +unsafe fn test_avx512() { + #[target_feature(enable = "avx512bw")] + unsafe fn test_mm512_sad_epu8() { + let a = _mm512_set_epi8( + 71, 70, 69, 68, 67, 66, 65, 64, // + 55, 54, 53, 52, 51, 50, 49, 48, // + 47, 46, 45, 44, 43, 42, 41, 40, // + 39, 38, 37, 36, 35, 34, 33, 32, // + 31, 30, 29, 28, 27, 26, 25, 24, // + 23, 22, 21, 20, 19, 18, 17, 16, // + 15, 14, 13, 12, 11, 10, 9, 8, // + 7, 6, 5, 4, 3, 2, 1, 0, // + ); + + // `d` is the absolute difference with the corresponding row in `a`. + let b = _mm512_set_epi8( + 63, 62, 61, 60, 59, 58, 57, 56, // lane 7 (d = 8) + 62, 61, 60, 59, 58, 57, 56, 55, // lane 6 (d = 7) + 53, 52, 51, 50, 49, 48, 47, 46, // lane 5 (d = 6) + 44, 43, 42, 41, 40, 39, 38, 37, // lane 4 (d = 5) + 35, 34, 33, 32, 31, 30, 29, 28, // lane 3 (d = 4) + 26, 25, 24, 23, 22, 21, 20, 19, // lane 2 (d = 3) + 17, 16, 15, 14, 13, 12, 11, 10, // lane 1 (d = 2) + 8, 7, 6, 5, 4, 3, 2, 1, // lane 0 (d = 1) + ); + + let r = _mm512_sad_epu8(a, b); + let e = _mm512_set_epi64(64, 56, 48, 40, 32, 24, 16, 8); + + assert_eq_m512i(r, e); + } + test_mm512_sad_epu8(); +} + // Some of the constants in the tests below are just bit patterns. They should not // be interpreted as integers; signedness does not make sense for them, but // __mXXXi happens to be defined in terms of signed integers.