Skip to content

Commit

Permalink
API: Add requirement of non-negative strides to raw ptr constructors
Browse files Browse the repository at this point in the history
For raw views and array views, require non-negative strides and check
this with a debug assertion.
  • Loading branch information
bluss committed Jan 13, 2021
1 parent 0d8b965 commit e6a4f10
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/dimension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,19 @@ pub fn stride_offset_checked(dim: &[Ix], strides: &[Ix], index: &[Ix]) -> Option
Some(offset)
}

/// Checks if strides are non-negative.
pub fn strides_non_negative<D>(strides: &D) -> Result<(), ShapeError>
where
D: Dimension,
{
for &stride in strides.slice() {
if (stride as isize) < 0 {
return Err(from_kind(ErrorKind::Unsupported));
}
}
Ok(())
}

/// Implementation-specific extensions to `Dimension`
pub trait DimensionExt {
// note: many extensions go in the main trait if they need to be special-
Expand Down
12 changes: 12 additions & 0 deletions src/impl_raw_views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ where
/// [`.offset()`] regardless of the starting point due to past offsets.
///
/// * The product of non-zero axis lengths must not exceed `isize::MAX`.
///
/// * Strides must be non-negative.
///
/// This function can use debug assertions to check some of these requirements,
/// but it's not a complete check.
///
/// [`.offset()`]: https://doc.rust-lang.org/stable/std/primitive.pointer.html#method.offset
pub unsafe fn from_shape_ptr<Sh>(shape: Sh, ptr: *const A) -> Self
Expand All @@ -73,6 +78,7 @@ where
if cfg!(debug_assertions) {
assert!(!ptr.is_null(), "The pointer must be non-null.");
if let Strides::Custom(strides) = &shape.strides {
dimension::strides_non_negative(strides).unwrap();
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
} else {
dimension::size_of_shape_checked(&dim).unwrap();
Expand Down Expand Up @@ -202,6 +208,11 @@ where
/// [`.offset()`] regardless of the starting point due to past offsets.
///
/// * The product of non-zero axis lengths must not exceed `isize::MAX`.
///
/// * Strides must be non-negative.
///
/// This function can use debug assertions to check some of these requirements,
/// but it's not a complete check.
///
/// [`.offset()`]: https://doc.rust-lang.org/stable/std/primitive.pointer.html#method.offset
pub unsafe fn from_shape_ptr<Sh>(shape: Sh, ptr: *mut A) -> Self
Expand All @@ -213,6 +224,7 @@ where
if cfg!(debug_assertions) {
assert!(!ptr.is_null(), "The pointer must be non-null.");
if let Strides::Custom(strides) = &shape.strides {
dimension::strides_non_negative(strides).unwrap();
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
} else {
dimension::size_of_shape_checked(&dim).unwrap();
Expand Down
10 changes: 10 additions & 0 deletions src/impl_views/constructors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ where
///
/// * The product of non-zero axis lengths must not exceed `isize::MAX`.
///
/// * Strides must be non-negative.
///
/// This function can use debug assertions to check some of these requirements,
/// but it's not a complete check.
///
/// [`.offset()`]: https://doc.rust-lang.org/stable/std/primitive.pointer.html#method.offset
pub unsafe fn from_shape_ptr<Sh>(shape: Sh, ptr: *const A) -> Self
where
Expand Down Expand Up @@ -188,6 +193,11 @@ where
///
/// * The product of non-zero axis lengths must not exceed `isize::MAX`.
///
/// * Strides must be non-negative.
///
/// This function can use debug assertions to check some of these requirements,
/// but it's not a complete check.
///
/// [`.offset()`]: https://doc.rust-lang.org/stable/std/primitive.pointer.html#method.offset
pub unsafe fn from_shape_ptr<Sh>(shape: Sh, ptr: *mut A) -> Self
where
Expand Down
15 changes: 15 additions & 0 deletions tests/raw_views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,18 @@ fn raw_view_deref_into_view_misaligned() {
let data: [u16; 2] = [0x0011, 0x2233];
misaligned_deref(&data);
}

#[test]
#[cfg(debug_assertions)]
#[should_panic = "Unsupported"]
fn raw_view_negative_strides() {
fn misaligned_deref(data: &[u16; 2]) -> ArrayView1<'_, u16> {
let ptr: *const u16 = data.as_ptr();
unsafe {
let raw_view = RawArrayView::from_shape_ptr(1.strides((-1isize) as usize), ptr);
raw_view.deref_into_view()
}
}
let data: [u16; 2] = [0x0011, 0x2233];
misaligned_deref(&data);
}

0 comments on commit e6a4f10

Please sign in to comment.