diff --git a/crates/core_simd/src/simd/num/int.rs b/crates/core_simd/src/simd/num/int.rs index 6ebf0ba397c..31d39982635 100644 --- a/crates/core_simd/src/simd/num/int.rs +++ b/crates/core_simd/src/simd/num/int.rs @@ -235,6 +235,52 @@ pub trait SimdInt: Copy + Sealed { /// Returns the number of trailing ones in the binary representation of each element. fn trailing_ones(self) -> Self::Unsigned; + + /// Unchecked shift left. + /// + /// Computes `self << rhs`, assuming that `rhs` is less than the number of bits in `T`. + /// + /// # Safety + /// + /// This results in undefined behavior if any element of `rhs` is greater than or equal + /// to the number of bits in `T`. + /// + /// # Examples + /// + /// ``` + /// # #![feature(portable_simd)] + /// # #[cfg(feature = "as_crate")] use core_simd::simd; + /// # #[cfg(not(feature = "as_crate"))] use core::simd; + /// # use simd::{prelude::*, num::SimdInt}; + /// let a = i32x4::from_array([1, 2, 3, 4]); + /// let b = i32x4::from_array([1, 2, 3, 4]); + /// let c = unsafe { SimdInt::unchecked_shl(a, b) }; + /// assert_eq!(c, i32x4::from_array([2, 8, 24, 64])); + /// ``` + unsafe fn unchecked_shl(self, rhs: Self) -> Self; + + /// Unchecked shift right. + /// + /// Computes `self >> rhs`, assuming that `rhs` is less than the number of bits in `T`. + /// + /// # Safety + /// + /// This results in undefined behavior if any element of `rhs` is greater than or equal + /// to the number of bits in `T`. + /// + /// # Examples + /// + /// ``` + /// # #![feature(portable_simd)] + /// # #[cfg(feature = "as_crate")] use core_simd::simd; + /// # #[cfg(not(feature = "as_crate"))] use core::simd; + /// # use simd::{prelude::*, num::SimdInt}; + /// let a = i32x4::from_array([16, 32, 64, 128]); + /// let b = i32x4::from_array([1, 2, 3, 4]); + /// let c = unsafe { SimdInt::unchecked_shr(a, b) }; + /// assert_eq!(c, i32x4::from_array([8, 8, 8, 8])); + /// ``` + unsafe fn unchecked_shr(self, rhs: Self) -> Self; } macro_rules! impl_trait { @@ -394,6 +440,18 @@ macro_rules! impl_trait { fn trailing_ones(self) -> Self::Unsigned { self.cast::<$unsigned>().trailing_ones() } + + #[inline] + unsafe fn unchecked_shl(self, rhs: Self) -> Self { + // Safety: the caller must ensure rhs < T::BITS for all lanes + unsafe { core::intrinsics::simd::simd_shl(self, rhs) } + } + + #[inline] + unsafe fn unchecked_shr(self, rhs: Self) -> Self { + // Safety: the caller must ensure rhs < T::BITS for all lanes + unsafe { core::intrinsics::simd::simd_shr(self, rhs) } + } } )* } diff --git a/crates/core_simd/src/simd/num/uint.rs b/crates/core_simd/src/simd/num/uint.rs index f8a40f8ec56..4d7a37ce886 100644 --- a/crates/core_simd/src/simd/num/uint.rs +++ b/crates/core_simd/src/simd/num/uint.rs @@ -118,6 +118,52 @@ pub trait SimdUint: Copy + Sealed { /// Returns the number of trailing ones in the binary representation of each element. fn trailing_ones(self) -> Self; + + /// Unchecked shift left. + /// + /// Computes `self << rhs`, assuming that `rhs` is less than the number of bits in `T`. + /// + /// # Safety + /// + /// This results in undefined behavior if any element of `rhs` is greater than or equal + /// to the number of bits in `T`. + /// + /// # Examples + /// + /// ``` + /// # #![feature(portable_simd)] + /// # #[cfg(feature = "as_crate")] use core_simd::simd; + /// # #[cfg(not(feature = "as_crate"))] use core::simd; + /// # use simd::{prelude::*, num::SimdUint}; + /// let a = u32x4::from_array([1, 2, 3, 4]); + /// let b = u32x4::from_array([1, 2, 3, 4]); + /// let c = unsafe { SimdUint::unchecked_shl(a, b) }; + /// assert_eq!(c, u32x4::from_array([2, 8, 24, 64])); + /// ``` + unsafe fn unchecked_shl(self, rhs: Self) -> Self; + + /// Unchecked shift right. + /// + /// Computes `self >> rhs`, assuming that `rhs` is less than the number of bits in `T`. + /// + /// # Safety + /// + /// This results in undefined behavior if any element of `rhs` is greater than or equal + /// to the number of bits in `T`. + /// + /// # Examples + /// + /// ``` + /// # #![feature(portable_simd)] + /// # #[cfg(feature = "as_crate")] use core_simd::simd; + /// # #[cfg(not(feature = "as_crate"))] use core::simd; + /// # use simd::{prelude::*, num::SimdUint}; + /// let a = u32x4::from_array([16, 32, 64, 128]); + /// let b = u32x4::from_array([1, 2, 3, 4]); + /// let c = unsafe { SimdUint::unchecked_shr(a, b) }; + /// assert_eq!(c, u32x4::from_array([8, 8, 8, 8])); + /// ``` + unsafe fn unchecked_shr(self, rhs: Self) -> Self; } macro_rules! impl_trait { @@ -247,6 +293,18 @@ macro_rules! impl_trait { fn trailing_ones(self) -> Self { (!self).trailing_zeros() } + + #[inline] + unsafe fn unchecked_shl(self, rhs: Self) -> Self { + // Safety: the caller must ensure rhs < T::BITS for all lanes + unsafe { core::intrinsics::simd::simd_shl(self, rhs) } + } + + #[inline] + unsafe fn unchecked_shr(self, rhs: Self) -> Self { + // Safety: the caller must ensure rhs < T::BITS for all lanes + unsafe { core::intrinsics::simd::simd_shr(self, rhs) } + } } )* } diff --git a/crates/core_simd/tests/ops_macros.rs b/crates/core_simd/tests/ops_macros.rs index 6de78f51e59..66ff130553f 100644 --- a/crates/core_simd/tests/ops_macros.rs +++ b/crates/core_simd/tests/ops_macros.rs @@ -263,6 +263,46 @@ macro_rules! impl_common_integer_tests { &|_| true, ) } + + fn unchecked_shl() { + // Test with valid shift amounts + let a = $vector::::splat(1); + let b = $vector::::splat(2); + let result = unsafe { a.unchecked_shl(b) }; + assert_eq!(result, $vector::::splat(4)); + + // Test with zero shift + let a = $vector::::splat(42); + let b = $vector::::splat(0); + let result = unsafe { a.unchecked_shl(b) }; + assert_eq!(result, $vector::::splat(42)); + + // Test with shift by 1 + let a = $vector::::splat(8); + let b = $vector::::splat(1); + let result = unsafe { a.unchecked_shl(b) }; + assert_eq!(result, $vector::::splat(16)); + } + + fn unchecked_shr() { + // Test with valid shift amounts + let a = $vector::::splat(16); + let b = $vector::::splat(2); + let result = unsafe { a.unchecked_shr(b) }; + assert_eq!(result, $vector::::splat(4)); + + // Test with zero shift + let a = $vector::::splat(42); + let b = $vector::::splat(0); + let result = unsafe { a.unchecked_shr(b) }; + assert_eq!(result, $vector::::splat(42)); + + // Test with shift by 1 + let a = $vector::::splat(8); + let b = $vector::::splat(1); + let result = unsafe { a.unchecked_shr(b) }; + assert_eq!(result, $vector::::splat(4)); + } } } }