Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions crates/core_simd/src/simd/num/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,52 @@ pub trait SimdInt: Copy + Sealed {

/// Returns the number of trailing ones in the binary representation of each element.
fn trailing_ones(self) -> Self::Unsigned;

/// Unchecked shift left.
///
/// Computes `self << rhs`, assuming that `rhs` is less than the number of bits in `T`.
///
/// # Safety
///
/// This results in undefined behavior if any element of `rhs` is greater than or equal
/// to the number of bits in `T`.
///
/// # Examples
///
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::{prelude::*, num::SimdInt};
/// let a = i32x4::from_array([1, 2, 3, 4]);
/// let b = i32x4::from_array([1, 2, 3, 4]);
/// let c = unsafe { SimdInt::unchecked_shl(a, b) };
/// assert_eq!(c, i32x4::from_array([2, 8, 24, 64]));
/// ```
unsafe fn unchecked_shl(self, rhs: Self) -> Self;

/// Unchecked shift right.
///
/// Computes `self >> rhs`, assuming that `rhs` is less than the number of bits in `T`.
///
/// # Safety
///
/// This results in undefined behavior if any element of `rhs` is greater than or equal
/// to the number of bits in `T`.
///
/// # Examples
///
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::{prelude::*, num::SimdInt};
/// let a = i32x4::from_array([16, 32, 64, 128]);
/// let b = i32x4::from_array([1, 2, 3, 4]);
/// let c = unsafe { SimdInt::unchecked_shr(a, b) };
/// assert_eq!(c, i32x4::from_array([8, 8, 8, 8]));
/// ```
unsafe fn unchecked_shr(self, rhs: Self) -> Self;
}

macro_rules! impl_trait {
Expand Down Expand Up @@ -394,6 +440,18 @@ macro_rules! impl_trait {
fn trailing_ones(self) -> Self::Unsigned {
self.cast::<$unsigned>().trailing_ones()
}

#[inline]
unsafe fn unchecked_shl(self, rhs: Self) -> Self {
// Safety: the caller must ensure rhs < T::BITS for all lanes
unsafe { core::intrinsics::simd::simd_shl(self, rhs) }
}

#[inline]
unsafe fn unchecked_shr(self, rhs: Self) -> Self {
// Safety: the caller must ensure rhs < T::BITS for all lanes
unsafe { core::intrinsics::simd::simd_shr(self, rhs) }
}
}
)*
}
Expand Down
58 changes: 58 additions & 0 deletions crates/core_simd/src/simd/num/uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,52 @@ pub trait SimdUint: Copy + Sealed {

/// Returns the number of trailing ones in the binary representation of each element.
fn trailing_ones(self) -> Self;

/// Unchecked shift left.
///
/// Computes `self << rhs`, assuming that `rhs` is less than the number of bits in `T`.
///
/// # Safety
///
/// This results in undefined behavior if any element of `rhs` is greater than or equal
/// to the number of bits in `T`.
///
/// # Examples
///
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::{prelude::*, num::SimdUint};
/// let a = u32x4::from_array([1, 2, 3, 4]);
/// let b = u32x4::from_array([1, 2, 3, 4]);
/// let c = unsafe { SimdUint::unchecked_shl(a, b) };
/// assert_eq!(c, u32x4::from_array([2, 8, 24, 64]));
/// ```
unsafe fn unchecked_shl(self, rhs: Self) -> Self;

/// Unchecked shift right.
///
/// Computes `self >> rhs`, assuming that `rhs` is less than the number of bits in `T`.
///
/// # Safety
///
/// This results in undefined behavior if any element of `rhs` is greater than or equal
/// to the number of bits in `T`.
///
/// # Examples
///
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::{prelude::*, num::SimdUint};
/// let a = u32x4::from_array([16, 32, 64, 128]);
/// let b = u32x4::from_array([1, 2, 3, 4]);
/// let c = unsafe { SimdUint::unchecked_shr(a, b) };
/// assert_eq!(c, u32x4::from_array([8, 8, 8, 8]));
/// ```
unsafe fn unchecked_shr(self, rhs: Self) -> Self;
}

macro_rules! impl_trait {
Expand Down Expand Up @@ -247,6 +293,18 @@ macro_rules! impl_trait {
fn trailing_ones(self) -> Self {
(!self).trailing_zeros()
}

#[inline]
unsafe fn unchecked_shl(self, rhs: Self) -> Self {
// Safety: the caller must ensure rhs < T::BITS for all lanes
unsafe { core::intrinsics::simd::simd_shl(self, rhs) }
}

#[inline]
unsafe fn unchecked_shr(self, rhs: Self) -> Self {
// Safety: the caller must ensure rhs < T::BITS for all lanes
unsafe { core::intrinsics::simd::simd_shr(self, rhs) }
}
}
)*
}
Expand Down
40 changes: 40 additions & 0 deletions crates/core_simd/tests/ops_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,46 @@ macro_rules! impl_common_integer_tests {
&|_| true,
)
}

fn unchecked_shl<const LANES: usize>() {
// Test with valid shift amounts
let a = $vector::<LANES>::splat(1);
let b = $vector::<LANES>::splat(2);
let result = unsafe { a.unchecked_shl(b) };
assert_eq!(result, $vector::<LANES>::splat(4));

// Test with zero shift
let a = $vector::<LANES>::splat(42);
let b = $vector::<LANES>::splat(0);
let result = unsafe { a.unchecked_shl(b) };
assert_eq!(result, $vector::<LANES>::splat(42));

// Test with shift by 1
let a = $vector::<LANES>::splat(8);
let b = $vector::<LANES>::splat(1);
let result = unsafe { a.unchecked_shl(b) };
assert_eq!(result, $vector::<LANES>::splat(16));
}

fn unchecked_shr<const LANES: usize>() {
// Test with valid shift amounts
let a = $vector::<LANES>::splat(16);
let b = $vector::<LANES>::splat(2);
let result = unsafe { a.unchecked_shr(b) };
assert_eq!(result, $vector::<LANES>::splat(4));

// Test with zero shift
let a = $vector::<LANES>::splat(42);
let b = $vector::<LANES>::splat(0);
let result = unsafe { a.unchecked_shr(b) };
assert_eq!(result, $vector::<LANES>::splat(42));

// Test with shift by 1
let a = $vector::<LANES>::splat(8);
let b = $vector::<LANES>::splat(1);
let result = unsafe { a.unchecked_shr(b) };
assert_eq!(result, $vector::<LANES>::splat(4));
}
}
}
}
Expand Down