From a11fb5ff4b0cef2df0372ee4293fbfb8679dbaac Mon Sep 17 00:00:00 2001 From: Folkert de Vries Date: Mon, 17 Nov 2025 11:53:56 +0100 Subject: [PATCH 1/3] add shim for `_mm512_maddubs_epi16` --- src/shims/x86/avx2.rs | 32 +------------ src/shims/x86/avx512.rs | 9 +++- src/shims/x86/mod.rs | 46 +++++++++++++++++++ src/shims/x86/ssse3.rs | 33 +------------ tests/pass/shims/x86/intrinsics-x86-avx512.rs | 42 +++++++++++++++++ 5 files changed, 100 insertions(+), 62 deletions(-) diff --git a/src/shims/x86/avx2.rs b/src/shims/x86/avx2.rs index 142258c697..a3cfe5ffb8 100644 --- a/src/shims/x86/avx2.rs +++ b/src/shims/x86/avx2.rs @@ -6,7 +6,7 @@ use rustc_target::callconv::FnAbi; use super::{ ShiftOp, horizontal_bin_op, mask_load, mask_store, mpsadbw, packssdw, packsswb, packusdw, - packuswb, pmulhrsw, psadbw, psign, shift_simd_by_scalar, shift_simd_by_simd, + packuswb, pmaddbw, pmulhrsw, psadbw, psign, shift_simd_by_scalar, shift_simd_by_simd, }; use crate::*; @@ -102,39 +102,11 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { } } // Used to implement the _mm256_maddubs_epi16 function. - // Multiplies packed 8-bit unsigned integers from `left` and packed - // signed 8-bit integers from `right` into 16-bit signed integers. Then, - // the saturating sum of the products with indices `2*i` and `2*i+1` - // produces the output at index `i`. 
"pmadd.ub.sw" => { let [left, right] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?; - let (left, left_len) = this.project_to_simd(left)?; - let (right, right_len) = this.project_to_simd(right)?; - let (dest, dest_len) = this.project_to_simd(dest)?; - - assert_eq!(left_len, right_len); - assert_eq!(dest_len.strict_mul(2), left_len); - - for i in 0..dest_len { - let j1 = i.strict_mul(2); - let left1 = this.read_scalar(&this.project_index(&left, j1)?)?.to_u8()?; - let right1 = this.read_scalar(&this.project_index(&right, j1)?)?.to_i8()?; - - let j2 = j1.strict_add(1); - let left2 = this.read_scalar(&this.project_index(&left, j2)?)?.to_u8()?; - let right2 = this.read_scalar(&this.project_index(&right, j2)?)?.to_i8()?; - - let dest = this.project_index(&dest, i)?; - - // Multiplication of a u8 and an i8 into an i16 cannot overflow. - let mul1 = i16::from(left1).strict_mul(right1.into()); - let mul2 = i16::from(left2).strict_mul(right2.into()); - let res = mul1.saturating_add(mul2); - - this.write_scalar(Scalar::from_i16(res), &dest)?; - } + pmaddbw(this, left, right, dest)?; } // Used to implement the _mm_maskload_epi32, _mm_maskload_epi64, // _mm256_maskload_epi32 and _mm256_maskload_epi64 functions. diff --git a/src/shims/x86/avx512.rs b/src/shims/x86/avx512.rs index 4957b3b88c..e1f875b613 100644 --- a/src/shims/x86/avx512.rs +++ b/src/shims/x86/avx512.rs @@ -3,7 +3,7 @@ use rustc_middle::ty::Ty; use rustc_span::Symbol; use rustc_target::callconv::FnAbi; -use super::psadbw; +use super::{pmaddbw, psadbw}; use crate::*; impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {} @@ -88,6 +88,13 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { psadbw(this, left, right, dest)? } + // Used to implement the _mm512_maddubs_epi16 function. 
+ "pmaddubs.w.512" => { + let [left, right] = + this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?; + + pmaddbw(this, left, right, dest)?; + } _ => return interp_ok(EmulateItemResult::NotSupported), } interp_ok(EmulateItemResult::NeedsReturn) diff --git a/src/shims/x86/mod.rs b/src/shims/x86/mod.rs index 258ad9f8de..3a0d3f7cfb 100644 --- a/src/shims/x86/mod.rs +++ b/src/shims/x86/mod.rs @@ -1086,6 +1086,52 @@ fn psadbw<'tcx>( interp_ok(()) } +/// Multiplies packed 8-bit unsigned integers from `left` and packed +/// signed 8-bit integers from `right` into 16-bit signed integers. Then, +/// the saturating sum of the products with indices `2*i` and `2*i+1` +/// produces the output at index `i`. +/// +/// +/// +/// +fn pmaddbw<'tcx>( + ecx: &mut crate::MiriInterpCx<'tcx>, + left: &OpTy<'tcx>, + right: &OpTy<'tcx>, + dest: &MPlaceTy<'tcx>, +) -> InterpResult<'tcx, ()> { + let (left, left_len) = ecx.project_to_simd(left)?; + let (right, right_len) = ecx.project_to_simd(right)?; + let (dest, dest_len) = ecx.project_to_simd(dest)?; + + // fn pmaddubsw128(a: u8x16, b: i8x16) -> i16x8; + // fn pmaddubsw( a: u8x32, b: i8x32) -> i16x16; + // fn vpmaddubsw( a: u8x64, b: i8x64) -> i16x32; + assert_eq!(left_len, right_len); + assert_eq!(dest_len.strict_mul(2), left_len); + + for i in 0..dest_len { + let j1 = i.strict_mul(2); + let left1 = ecx.read_scalar(&ecx.project_index(&left, j1)?)?.to_u8()?; + let right1 = ecx.read_scalar(&ecx.project_index(&right, j1)?)?.to_i8()?; + + let j2 = j1.strict_add(1); + let left2 = ecx.read_scalar(&ecx.project_index(&left, j2)?)?.to_u8()?; + let right2 = ecx.read_scalar(&ecx.project_index(&right, j2)?)?.to_i8()?; + + let dest = ecx.project_index(&dest, i)?; + + // Multiplication of a u8 and an i8 into an i16 cannot overflow. 
+ let mul1 = i16::from(left1).strict_mul(right1.into()); + let mul2 = i16::from(left2).strict_mul(right2.into()); + let res = mul1.saturating_add(mul2); + + ecx.write_scalar(Scalar::from_i16(res), &dest)?; + } + + interp_ok(()) +} + /// Multiplies packed 16-bit signed integer values, truncates the 32-bit /// product to the 18 most significant bits by right-shifting, and then /// divides the 18-bit value by 2 (rounding to nearest) by first adding diff --git a/src/shims/x86/ssse3.rs b/src/shims/x86/ssse3.rs index 398f538e1b..56fc63ce14 100644 --- a/src/shims/x86/ssse3.rs +++ b/src/shims/x86/ssse3.rs @@ -4,7 +4,7 @@ use rustc_middle::ty::Ty; use rustc_span::Symbol; use rustc_target::callconv::FnAbi; -use super::{horizontal_bin_op, pmulhrsw, psign}; +use super::{horizontal_bin_op, pmaddbw, pmulhrsw, psign}; use crate::*; impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {} @@ -67,40 +67,11 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { horizontal_bin_op(this, which, /*saturating*/ true, left, right, dest)?; } // Used to implement the _mm_maddubs_epi16 function. - // Multiplies packed 8-bit unsigned integers from `left` and packed - // signed 8-bit integers from `right` into 16-bit signed integers. Then, - // the saturating sum of the products with indices `2*i` and `2*i+1` - // produces the output at index `i`. 
- // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_maddubs_epi16 "pmadd.ub.sw.128" => { let [left, right] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?; - let (left, left_len) = this.project_to_simd(left)?; - let (right, right_len) = this.project_to_simd(right)?; - let (dest, dest_len) = this.project_to_simd(dest)?; - - assert_eq!(left_len, right_len); - assert_eq!(dest_len.strict_mul(2), left_len); - - for i in 0..dest_len { - let j1 = i.strict_mul(2); - let left1 = this.read_scalar(&this.project_index(&left, j1)?)?.to_u8()?; - let right1 = this.read_scalar(&this.project_index(&right, j1)?)?.to_i8()?; - - let j2 = j1.strict_add(1); - let left2 = this.read_scalar(&this.project_index(&left, j2)?)?.to_u8()?; - let right2 = this.read_scalar(&this.project_index(&right, j2)?)?.to_i8()?; - - let dest = this.project_index(&dest, i)?; - - // Multiplication of a u8 and an i8 into an i16 cannot overflow. - let mul1 = i16::from(left1).strict_mul(right1.into()); - let mul2 = i16::from(left2).strict_mul(right2.into()); - let res = mul1.saturating_add(mul2); - - this.write_scalar(Scalar::from_i16(res), &dest)?; - } + pmaddbw(this, left, right, dest)?; } // Used to implement the _mm_mulhrs_epi16 function. // Multiplies packed 16-bit signed integer values, truncates the 32-bit diff --git a/tests/pass/shims/x86/intrinsics-x86-avx512.rs b/tests/pass/shims/x86/intrinsics-x86-avx512.rs index 0a9bb2d315..10a11ac809 100644 --- a/tests/pass/shims/x86/intrinsics-x86-avx512.rs +++ b/tests/pass/shims/x86/intrinsics-x86-avx512.rs @@ -55,6 +55,48 @@ unsafe fn test_avx512() { assert_eq_m512i(r, e); } test_mm512_sad_epu8(); + + #[target_feature(enable = "avx512bw")] + unsafe fn test_mm512_maddubs_epi16() { + // `a` is interpreted as `u8x64`, but `_mm512_set_epi8` expects `i8`, so we have to cast.
+ #[rustfmt::skip] + let a = _mm512_set_epi8( + 255u8 as i8, 255u8 as i8, 60, 50, 100, 100, 255u8 as i8, 200u8 as i8, + 255u8 as i8, 200u8 as i8, 200u8 as i8, 100, 60, 50, 20, 10, + + 255u8 as i8, 255u8 as i8, 60, 50, 100, 100, 255u8 as i8, 200u8 as i8, + 255u8 as i8, 200u8 as i8, 200u8 as i8, 100, 60, 50, 20, 10, + + 255u8 as i8, 255u8 as i8, 60, 50, 100, 100, 255u8 as i8, 200u8 as i8, + 255u8 as i8, 200u8 as i8, 200u8 as i8, 100, 60, 50, 20, 10, + + 255u8 as i8, 255u8 as i8, 60, 50, 100, 100, 255u8 as i8, 200u8 as i8, + 255u8 as i8, 200u8 as i8, 200u8 as i8, 100, 60, 50, 20, 10, + ); + + let b = _mm512_set_epi8( + 64, 64, -2, 1, 100, 100, -128, -128, // + 127, 127, -1, 1, 2, 2, 1, 1, // + 64, 64, -2, 1, 100, 100, -128, -128, // + 127, 127, -1, 1, 2, 2, 1, 1, // + 64, 64, -2, 1, 100, 100, -128, -128, // + 127, 127, -1, 1, 2, 2, 1, 1, // + 64, 64, -2, 1, 100, 100, -128, -128, // + 127, 127, -1, 1, 2, 2, 1, 1, // + ); + + let r = _mm512_maddubs_epi16(a, b); + + let e = _mm512_set_epi16( + 32640, -70, 20000, -32768, 32767, -100, 220, 30, // + 32640, -70, 20000, -32768, 32767, -100, 220, 30, // + 32640, -70, 20000, -32768, 32767, -100, 220, 30, // + 32640, -70, 20000, -32768, 32767, -100, 220, 30, // + ); + + assert_eq_m512i(r, e); + } + test_mm512_maddubs_epi16(); } // Some of the constants in the tests below are just bit patterns. 
They should not From 760047c3c9b4815764a13baa43218d1e88719c46 Mon Sep 17 00:00:00 2001 From: Folkert de Vries Date: Mon, 17 Nov 2025 12:19:50 +0100 Subject: [PATCH 2/3] add shim for `_mm512_permutexvar_epi32` --- src/shims/x86/avx512.rs | 20 ++++++++ tests/pass/shims/x86/intrinsics-x86-avx512.rs | 46 +++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/src/shims/x86/avx512.rs b/src/shims/x86/avx512.rs index e1f875b613..8840af0c0d 100644 --- a/src/shims/x86/avx512.rs +++ b/src/shims/x86/avx512.rs @@ -95,6 +95,26 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { pmaddbw(this, left, right, dest)?; } + // Used to implement the _mm512_permutexvar_epi32 function. + "permvar.si.512" => { + let [left, right] = + this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?; + + let (left, left_len) = this.project_to_simd(left)?; + let (right, right_len) = this.project_to_simd(right)?; + let (dest, dest_len) = this.project_to_simd(dest)?; + + assert_eq!(dest_len, left_len); + assert_eq!(dest_len, right_len); + + for i in 0..dest_len { + let dest = this.project_index(&dest, i)?; + let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u32()?; + let left = this.project_index(&left, (right & 0b1111).into())?; + + this.copy_op(&left, &dest)?; + } + } _ => return interp_ok(EmulateItemResult::NotSupported), } interp_ok(EmulateItemResult::NeedsReturn) diff --git a/tests/pass/shims/x86/intrinsics-x86-avx512.rs b/tests/pass/shims/x86/intrinsics-x86-avx512.rs index 10a11ac809..e778567b48 100644 --- a/tests/pass/shims/x86/intrinsics-x86-avx512.rs +++ b/tests/pass/shims/x86/intrinsics-x86-avx512.rs @@ -97,6 +97,52 @@ unsafe fn test_avx512() { assert_eq_m512i(r, e); } test_mm512_maddubs_epi16(); + + #[target_feature(enable = "avx512f")] + unsafe fn test_mm512_permutexvar_epi32() { + let a = _mm512_set_epi32( + 15, 14, 13, 12, // + 11, 10, 9, 8, // + 7, 6, 5, 4, // + 3, 2, 1, 0, // + ); + + let idx_identity = _mm512_set_epi32( + 15, 14, 
13, 12, // + 11, 10, 9, 8, // + 7, 6, 5, 4, // + 3, 2, 1, 0, // + ); + let r_id = _mm512_permutexvar_epi32(idx_identity, a); + assert_eq_m512i(r_id, a); + + // Test some out-of-bounds indices. + let edge_cases = _mm512_set_epi32( + 0, + -1, + -128, + i32::MIN, + 15, + 16, + 128, + i32::MAX, + 0, + -1, + -128, + i32::MIN, + 15, + 16, + 128, + i32::MAX, + ); + + let r = _mm512_permutexvar_epi32(edge_cases, a); + + let e = _mm512_set_epi32(0, 15, 0, 0, 15, 0, 0, 15, 0, 15, 0, 0, 15, 0, 0, 15); + + assert_eq_m512i(r, e); + } + test_mm512_permutexvar_epi32(); } // Some of the constants in the tests below are just bit patterns. They should not From 51f7643106a73f138d45e89f39b9df8ce3d1ffe8 Mon Sep 17 00:00:00 2001 From: Folkert de Vries Date: Mon, 17 Nov 2025 12:43:37 +0100 Subject: [PATCH 3/3] share `permute` shim --- src/shims/x86/avx2.rs | 22 +++------------------- src/shims/x86/avx512.rs | 17 ++--------------- src/shims/x86/mod.rs | 38 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 34 deletions(-) diff --git a/src/shims/x86/avx2.rs b/src/shims/x86/avx2.rs index a3cfe5ffb8..cf96a61ff0 100644 --- a/src/shims/x86/avx2.rs +++ b/src/shims/x86/avx2.rs @@ -6,7 +6,7 @@ use rustc_target::callconv::FnAbi; use super::{ ShiftOp, horizontal_bin_op, mask_load, mask_store, mpsadbw, packssdw, packsswb, packusdw, - packuswb, pmaddbw, pmulhrsw, psadbw, psign, shift_simd_by_scalar, shift_simd_by_simd, + packuswb, permute, pmaddbw, pmulhrsw, psadbw, psign, shift_simd_by_scalar, shift_simd_by_simd, }; use crate::*; @@ -189,28 +189,12 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { packusdw(this, left, right, dest)?; } - // Used to implement the _mm256_permutevar8x32_epi32 and - // _mm256_permutevar8x32_ps function. - // Shuffles `left` using the three low bits of each element of `right` - // as indices. + // Used to implement _mm256_permutevar8x32_epi32 and _mm256_permutevar8x32_ps. 
"permd" | "permps" => { let [left, right] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?; - let (left, left_len) = this.project_to_simd(left)?; - let (right, right_len) = this.project_to_simd(right)?; - let (dest, dest_len) = this.project_to_simd(dest)?; - - assert_eq!(dest_len, left_len); - assert_eq!(dest_len, right_len); - - for i in 0..dest_len { - let dest = this.project_index(&dest, i)?; - let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u32()?; - let left = this.project_index(&left, (right & 0b111).into())?; - - this.copy_op(&left, &dest)?; - } + permute(this, left, right, dest)?; } // Used to implement the _mm256_sad_epu8 function. "psad.bw" => { diff --git a/src/shims/x86/avx512.rs b/src/shims/x86/avx512.rs index 8840af0c0d..9b43aad96e 100644 --- a/src/shims/x86/avx512.rs +++ b/src/shims/x86/avx512.rs @@ -3,7 +3,7 @@ use rustc_middle::ty::Ty; use rustc_span::Symbol; use rustc_target::callconv::FnAbi; -use super::{pmaddbw, psadbw}; +use super::{permute, pmaddbw, psadbw}; use crate::*; impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {} @@ -100,20 +100,7 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { let [left, right] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?; - let (left, left_len) = this.project_to_simd(left)?; - let (right, right_len) = this.project_to_simd(right)?; - let (dest, dest_len) = this.project_to_simd(dest)?; - - assert_eq!(dest_len, left_len); - assert_eq!(dest_len, right_len); - - for i in 0..dest_len { - let dest = this.project_index(&dest, i)?; - let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u32()?; - let left = this.project_index(&left, (right & 0b1111).into())?; - - this.copy_op(&left, &dest)?; - } + permute(this, left, right, dest)?; } _ => return interp_ok(EmulateItemResult::NotSupported), } diff --git a/src/shims/x86/mod.rs b/src/shims/x86/mod.rs index 3a0d3f7cfb..febfc5afa2 100644 --- a/src/shims/x86/mod.rs +++ 
b/src/shims/x86/mod.rs @@ -1132,6 +1132,44 @@ fn pmaddbw<'tcx>( interp_ok(()) } +/// Shuffle 32-bit integers in `values` across lanes using the corresponding +/// index in `indices`, and store the results in dst. +/// +/// +/// +/// +fn permute<'tcx>( + ecx: &mut crate::MiriInterpCx<'tcx>, + values: &OpTy<'tcx>, + indices: &OpTy<'tcx>, + dest: &MPlaceTy<'tcx>, +) -> InterpResult<'tcx, ()> { + let (values, values_len) = ecx.project_to_simd(values)?; + let (indices, indices_len) = ecx.project_to_simd(indices)?; + let (dest, dest_len) = ecx.project_to_simd(dest)?; + + // fn permd(a: u32x8, b: u32x8) -> u32x8; + // fn permps(a: __m256, b: i32x8) -> __m256; + // fn vpermd(a: i32x16, idx: i32x16) -> i32x16; + assert_eq!(dest_len, values_len); + assert_eq!(dest_len, indices_len); + + // Only use the lower 3 bits to index into a vector with 8 lanes, + // or the lower 4 bits when indexing into a 16-lane vector. + assert!(dest_len.is_power_of_two()); + let mask = u32::try_from(dest_len).unwrap().strict_sub(1); + + for i in 0..dest_len { + let dest = ecx.project_index(&dest, i)?; + let index = ecx.read_scalar(&ecx.project_index(&indices, i)?)?.to_u32()?; + let element = ecx.project_index(&values, (index & mask).into())?; + + ecx.copy_op(&element, &dest)?; + } + + interp_ok(()) +} + /// Multiplies packed 16-bit signed integer values, truncates the 32-bit /// product to the 18 most significant bits by right-shifting, and then /// divides the 18-bit value by 2 (rounding to nearest) by first adding