From 24250a3ee89bd9245c06d31b9142717722512574 Mon Sep 17 00:00:00 2001
From: Grigory Evko
Date: Sat, 15 Nov 2025 08:00:45 +0300
Subject: [PATCH] Add unchecked_shl and unchecked_shr for SIMD integers

Adds unchecked shift operations to SimdInt and SimdUint traits.
These skip the masking that regular shifts perform, providing
optimization opportunities when the shift amount is known to be valid.

Regular shifts mask: (value << rhs) where rhs &= (BITS-1)
Unchecked shifts: direct simd_shl/simd_shr without masking

Safety: Caller must ensure 0 <= shift amount < BITS for all lanes.
---
 crates/core_simd/src/simd/num/int.rs  | 58 +++++++++++++++++++++++++++
 crates/core_simd/src/simd/num/uint.rs | 58 +++++++++++++++++++++++++++
 crates/core_simd/tests/ops_macros.rs  | 40 ++++++++++++++++++
 3 files changed, 156 insertions(+)

diff --git a/crates/core_simd/src/simd/num/int.rs b/crates/core_simd/src/simd/num/int.rs
index 6ebf0ba397c..31d39982635 100644
--- a/crates/core_simd/src/simd/num/int.rs
+++ b/crates/core_simd/src/simd/num/int.rs
@@ -235,6 +235,52 @@ pub trait SimdInt: Copy + Sealed {
 
     /// Returns the number of trailing ones in the binary representation of each element.
     fn trailing_ones(self) -> Self::Unsigned;
+
+    /// Unchecked shift left.
+    ///
+    /// Computes `self << rhs`, assuming each element of `rhs` is a valid shift amount.
+    ///
+    /// # Safety
+    ///
+    /// This results in undefined behavior if any element of `rhs` is negative, or is
+    /// greater than or equal to the number of bits in an element of `Self`.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// # #![feature(portable_simd)]
+    /// # #[cfg(feature = "as_crate")] use core_simd::simd;
+    /// # #[cfg(not(feature = "as_crate"))] use core::simd;
+    /// # use simd::{prelude::*, num::SimdInt};
+    /// let a = i32x4::from_array([1, 2, 3, 4]);
+    /// let b = i32x4::from_array([1, 2, 3, 4]);
+    /// let c = unsafe { SimdInt::unchecked_shl(a, b) };
+    /// assert_eq!(c, i32x4::from_array([2, 8, 24, 64]));
+    /// ```
+    unsafe fn unchecked_shl(self, rhs: Self) -> Self;
+
+    /// Unchecked shift right.
+    ///
+    /// Computes `self >> rhs`, assuming each element of `rhs` is a valid shift amount.
+    ///
+    /// # Safety
+    ///
+    /// This results in undefined behavior if any element of `rhs` is negative, or is
+    /// greater than or equal to the number of bits in an element of `Self`.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// # #![feature(portable_simd)]
+    /// # #[cfg(feature = "as_crate")] use core_simd::simd;
+    /// # #[cfg(not(feature = "as_crate"))] use core::simd;
+    /// # use simd::{prelude::*, num::SimdInt};
+    /// let a = i32x4::from_array([16, 32, 64, 128]);
+    /// let b = i32x4::from_array([1, 2, 3, 4]);
+    /// let c = unsafe { SimdInt::unchecked_shr(a, b) };
+    /// assert_eq!(c, i32x4::from_array([8, 8, 8, 8]));
+    /// ```
+    unsafe fn unchecked_shr(self, rhs: Self) -> Self;
 }
 
 macro_rules! impl_trait {
@@ -394,6 +440,18 @@ macro_rules! 
impl_trait {
             fn trailing_ones(self) -> Self::Unsigned {
                 self.cast::<$unsigned>().trailing_ones()
             }
+
+            #[inline]
+            unsafe fn unchecked_shl(self, rhs: Self) -> Self {
+                // Safety: the caller must ensure every lane of rhs is non-negative and < BITS
+                unsafe { core::intrinsics::simd::simd_shl(self, rhs) }
+            }
+
+            #[inline]
+            unsafe fn unchecked_shr(self, rhs: Self) -> Self {
+                // Safety: the caller must ensure every lane of rhs is non-negative and < BITS
+                unsafe { core::intrinsics::simd::simd_shr(self, rhs) }
+            }
         }
         )*
     }
diff --git a/crates/core_simd/src/simd/num/uint.rs b/crates/core_simd/src/simd/num/uint.rs
index f8a40f8ec56..4d7a37ce886 100644
--- a/crates/core_simd/src/simd/num/uint.rs
+++ b/crates/core_simd/src/simd/num/uint.rs
@@ -118,6 +118,52 @@ pub trait SimdUint: Copy + Sealed {
 
     /// Returns the number of trailing ones in the binary representation of each element.
     fn trailing_ones(self) -> Self;
+
+    /// Unchecked shift left.
+    ///
+    /// Computes `self << rhs`, assuming each element of `rhs` is a valid shift amount.
+    ///
+    /// # Safety
+    ///
+    /// This results in undefined behavior if any element of `rhs` is greater than or equal
+    /// to the number of bits in an element of `Self`.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// # #![feature(portable_simd)]
+    /// # #[cfg(feature = "as_crate")] use core_simd::simd;
+    /// # #[cfg(not(feature = "as_crate"))] use core::simd;
+    /// # use simd::{prelude::*, num::SimdUint};
+    /// let a = u32x4::from_array([1, 2, 3, 4]);
+    /// let b = u32x4::from_array([1, 2, 3, 4]);
+    /// let c = unsafe { SimdUint::unchecked_shl(a, b) };
+    /// assert_eq!(c, u32x4::from_array([2, 8, 24, 64]));
+    /// ```
+    unsafe fn unchecked_shl(self, rhs: Self) -> Self;
+
+    /// Unchecked shift right.
+    ///
+    /// Computes `self >> rhs`, assuming each element of `rhs` is a valid shift amount.
+    ///
+    /// # Safety
+    ///
+    /// This results in undefined behavior if any element of `rhs` is greater than or equal
+    /// to the number of bits in an element of `Self`.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// # #![feature(portable_simd)]
+    /// # #[cfg(feature = "as_crate")] use core_simd::simd;
+    /// # #[cfg(not(feature = "as_crate"))] use core::simd;
+    /// # use simd::{prelude::*, num::SimdUint};
+    /// let a = u32x4::from_array([16, 32, 64, 128]);
+    /// let b = u32x4::from_array([1, 2, 3, 4]);
+    /// let c = unsafe { SimdUint::unchecked_shr(a, b) };
+    /// assert_eq!(c, u32x4::from_array([8, 8, 8, 8]));
+    /// ```
+    unsafe fn unchecked_shr(self, rhs: Self) -> Self;
 }
 
 macro_rules! impl_trait {
@@ -247,6 +293,18 @@ macro_rules! impl_trait {
             fn trailing_ones(self) -> Self {
                 (!self).trailing_zeros()
             }
+
+            #[inline]
+            unsafe fn unchecked_shl(self, rhs: Self) -> Self {
+                // Safety: the caller must ensure every lane of rhs is < BITS of the element type
+                unsafe { core::intrinsics::simd::simd_shl(self, rhs) }
+            }
+
+            #[inline]
+            unsafe fn unchecked_shr(self, rhs: Self) -> Self {
+                // Safety: the caller must ensure every lane of rhs is < BITS of the element type
+                unsafe { core::intrinsics::simd::simd_shr(self, rhs) }
+            }
         }
         )*
     }
diff --git a/crates/core_simd/tests/ops_macros.rs b/crates/core_simd/tests/ops_macros.rs
index 6de78f51e59..66ff130553f 100644
--- a/crates/core_simd/tests/ops_macros.rs
+++ b/crates/core_simd/tests/ops_macros.rs
@@ -263,6 +263,46 @@ macro_rules! 
impl_common_integer_tests {
                     &|_| true,
                 )
             }
+
+            fn unchecked_shl<const LANES: usize>() {
+                // Test with valid shift amounts
+                let a = $vector::<LANES>::splat(1);
+                let b = $vector::<LANES>::splat(2);
+                let result = unsafe { a.unchecked_shl(b) };
+                assert_eq!(result, $vector::<LANES>::splat(4));
+
+                // Test with zero shift
+                let a = $vector::<LANES>::splat(42);
+                let b = $vector::<LANES>::splat(0);
+                let result = unsafe { a.unchecked_shl(b) };
+                assert_eq!(result, $vector::<LANES>::splat(42));
+
+                // Test with shift by 1
+                let a = $vector::<LANES>::splat(8);
+                let b = $vector::<LANES>::splat(1);
+                let result = unsafe { a.unchecked_shl(b) };
+                assert_eq!(result, $vector::<LANES>::splat(16));
+            }
+
+            fn unchecked_shr<const LANES: usize>() {
+                // Test with valid shift amounts
+                let a = $vector::<LANES>::splat(16);
+                let b = $vector::<LANES>::splat(2);
+                let result = unsafe { a.unchecked_shr(b) };
+                assert_eq!(result, $vector::<LANES>::splat(4));
+
+                // Test with zero shift
+                let a = $vector::<LANES>::splat(42);
+                let b = $vector::<LANES>::splat(0);
+                let result = unsafe { a.unchecked_shr(b) };
+                assert_eq!(result, $vector::<LANES>::splat(42));
+
+                // Test with shift by 1
+                let a = $vector::<LANES>::splat(8);
+                let b = $vector::<LANES>::splat(1);
+                let result = unsafe { a.unchecked_shr(b) };
+                assert_eq!(result, $vector::<LANES>::splat(4));
+            }
         }
     }
 }