diff --git a/src/array.rs b/src/array.rs index a134c19..4ffb599 100644 --- a/src/array.rs +++ b/src/array.rs @@ -59,21 +59,54 @@ fn build_array_from_shape( ArrayView::::from_shape(shape, data).map_err(ActiveStorageError::ShapeInvalid) } -/// Returns an optional [ndarray] SliceInfo object corresponding to the selection. +/// Returns an array index in numpy semantics to an index with ndarray semantics. +/// +/// The resulting value will be clamped such that it is safe for indexing in ndarray. +/// This allows us to accept selections with NumPy's less restrictive semantics. +/// When the stride is negative (`reverse` is `true`), the result is offset by one to allow for +/// Numpy's non-inclusive start and inclusive end in this scenario. +/// +/// # Arguments +/// +/// * `index`: Selection index +/// * `length`: Length of corresponding axis +/// * `reverse`: Whether the stride is negative +fn to_ndarray_index(index: isize, length: usize, reverse: bool) -> isize { + let length_isize = length.try_into().expect("Length too large!"); + let result = if reverse { index + 1 } else { index }; + if index < 0 { + std::cmp::max(result + length_isize, 0) + } else { + std::cmp::min(result, length_isize) + } +} + +/// Convert a [crate::models::Slice] object with indices in numpy semantics to an +/// [ndarray::SliceInfoElem::Slice] with ndarray semantics. +/// +/// See [ndarray docs](https://docs.rs/ndarray/0.15.6/ndarray/macro.s.html#negative-step) for +/// information about ndarray's handling of negative strides. +fn to_ndarray_slice(slice: &models::Slice, length: usize) -> ndarray::SliceInfoElem { + let reverse = slice.stride < 0; + let start = to_ndarray_index(slice.start, length, reverse); + let end = to_ndarray_index(slice.end, length, reverse); + let (start, end) = if reverse { (end, start) } else { (start, end) }; + ndarray::SliceInfoElem::Slice { + start, + end: Some(end), + step: slice.stride, + } +} + +/// Returns an [ndarray] SliceInfo object corresponding to the selection. pub fn build_slice_info( selection: &Option>, shape: &[usize], ) -> ndarray::SliceInfo, ndarray::IxDyn, ndarray::IxDyn> { match selection { Some(selection) => { - let si: Vec = selection - .iter() - .map(|slice| ndarray::SliceInfoElem::Slice { - // FIXME: usize should be isize? - start: slice.start as isize, - end: Some(slice.end as isize), - step: slice.stride as isize, - }) + let si: Vec = std::iter::zip(selection, shape) + .map(|(slice, length)| to_ndarray_slice(slice, *length)) .collect(); ndarray::SliceInfo::try_from(si).expect("SliceInfo should not fail for IxDyn") } @@ -81,7 +114,6 @@ pub fn build_slice_info( let si: Vec = shape .iter() .map(|_| ndarray::SliceInfoElem::Slice { - // FIXME: usize should be isize? start: 0, end: None, step: 1, @@ -309,7 +341,7 @@ mod tests { #[test] fn build_slice_info_1d_selection() { let selection = Some(vec![models::Slice::new(0, 1, 1)]); - let shape = []; + let shape = [1]; let slice_info = build_slice_info::(&selection, &shape); assert_eq!( [ndarray::SliceInfoElem::Slice { @@ -321,6 +353,51 @@ mod tests { ); } + #[test] + fn build_slice_info_1d_selection_negative_stride() { + let selection = Some(vec![models::Slice::new(1, 0, -1)]); + let shape = [1]; + let slice_info = build_slice_info::(&selection, &shape); + assert_eq!( + [ndarray::SliceInfoElem::Slice { + start: 1, + end: Some(1), + step: -1 + }], + slice_info.as_ref() + ); + } + + #[test] + fn build_slice_info_1d_selection_negative_start() { + let selection = Some(vec![models::Slice::new(-1, 1, 1)]); + let shape = [1]; + let slice_info = build_slice_info::(&selection, &shape); + assert_eq!( + [ndarray::SliceInfoElem::Slice { + start: 0, + end: Some(1), + step: 1 + }], + slice_info.as_ref() + ); + } + + #[test] + fn build_slice_info_1d_selection_negative_end() { + let selection = Some(vec![models::Slice::new(0, -1, 1)]); + let shape = [1]; + let slice_info = build_slice_info::(&selection, &shape); + assert_eq!( + [ndarray::SliceInfoElem::Slice { + start: 0, + end: Some(0), + step: 1 + }], + slice_info.as_ref() + ); + } + #[test] fn build_slice_info_2d_no_selection() { let selection = None; @@ -349,7 +426,7 @@ mod tests { models::Slice::new(0, 1, 1), models::Slice::new(0, 1, 1), ]); - let shape = []; + let shape = [1, 1]; let slice_info = build_slice_info::(&selection, &shape); assert_eq!( [ @@ -405,4 +482,136 @@ mod tests { let array = build_array::(&request_data, &bytes).unwrap(); assert_eq!(array![[0x04030201_i64], [0x08070605_i64]].into_dyn(), array); } + + // Helper function for tests that slice an array using a selection. + fn test_selection(slice: models::Slice, expected: Array1) { + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let request_data = models::RequestData { + source: Url::parse("http://example.com").unwrap(), + bucket: "bar".to_string(), + object: "baz".to_string(), + dtype: models::DType::Uint32, + offset: None, + size: None, + shape: None, + order: None, + selection: None, + }; + let bytes = Bytes::copy_from_slice(&data); + let array = build_array::(&request_data, &bytes).unwrap(); + let shape = vec![2]; + let slice_info = build_slice_info::(&Some(vec![slice]), &shape); + let sliced = array.slice(slice_info); + assert_eq!(sliced, expected.into_dyn().view()); + } + + #[test] + fn build_array_with_selection_all() { + test_selection( + models::Slice::new(0, 2, 1), + array![0x04030201_u32, 0x08070605_u32], + ) + } + + #[test] + fn build_array_with_selection_negative_start() { + test_selection( + models::Slice::new(-2, 2, 1), + array![0x04030201_u32, 0x08070605_u32], + ) + } + + #[test] + fn build_array_with_selection_start_lt_negative_length() { + test_selection( + models::Slice::new(-3, 2, 1), + array![0x04030201_u32, 0x08070605_u32], + ) + } + + #[test] + fn build_array_with_selection_start_eq_length() { + test_selection(models::Slice::new(2, 2, 1), array![]) + } + + #[test] + fn build_array_with_selection_start_gt_length() { + test_selection(models::Slice::new(3, 2, 1), array![]) + } + + #[test] + fn build_array_with_selection_negative_end() { + test_selection(models::Slice::new(0, -1, 1), array![0x04030201_u32]) + } + + #[test] + fn build_array_with_selection_end_lt_negative_length() { + test_selection(models::Slice::new(0, -3, 1), array![]) + } + + #[test] + fn build_array_with_selection_end_gt_length() { + test_selection( + models::Slice::new(0, 3, 1), + array![0x04030201_u32, 0x08070605_u32], + ) + } + + #[test] + fn build_array_with_selection_all_negative_stride() { + // Need to end at -3 to read first item. + // translates to [0, 2] + test_selection( + models::Slice::new(1, -3, -1), + array![0x08070605_u32, 0x04030201_u32], + ) + } + + #[test] + fn build_array_with_selection_negative_start_negative_stride() { + // translates to [0, 2] + test_selection( + models::Slice::new(-1, -3, -1), + array![0x08070605_u32, 0x04030201_u32], + ) + } + + #[test] + fn build_array_with_selection_start_lt_negative_length_negative_stride() { + // translates to [1, 0] + test_selection(models::Slice::new(-3, 0, -1), array![]) + } + + #[test] + fn build_array_with_selection_start_eq_length_negative_stride() { + // translates to [2, 2] + test_selection(models::Slice::new(2, 1, -1), array![]) + } + + #[test] + fn build_array_with_selection_start_gt_length_negative_stride() { + // translates to [2, 2] + test_selection(models::Slice::new(3, 1, -1), array![]) + } + + #[test] + fn build_array_with_selection_negative_end_negative_stride() { + // translates to [2, 2] + test_selection(models::Slice::new(2, -1, -1), array![]) + } + + #[test] + fn build_array_with_selection_end_lt_negative_length_negative_stride() { + // translates to [0, 2] + test_selection( + models::Slice::new(1, -3, -1), + array![0x08070605_u32, 0x04030201_u32], + ) + } + + #[test] + fn build_array_with_selection_end_gt_length_negative_stride() { + // translates to [1, 2] + test_selection(models::Slice::new(3, 0, -1), array![0x08070605_u32]) + } } diff --git a/src/models.rs b/src/models.rs index e447a3f..5dc2c86 100644 --- a/src/models.rs +++ b/src/models.rs @@ -50,6 +50,19 @@ pub enum Order { } /// A slice of a single dimension of an array +/// +/// The API uses NumPy slice semantics: +/// +/// When start or end is negative: +/// * positive_start = start + length +/// * positive_end = end + length +/// Start and end are clamped: +/// * positive_start = min(positive_start, 0) +/// * positive_end + max(positive_end, length) +/// When the stride is positive: +/// * positive_start <= i < positive_end +/// When the stride is negative: +/// * positive_end <= i < positive_start // NOTE: In serde, structs can be deserialised from sequences or maps. This allows us to support // the [, , ] API, with the convenience of named fields. #[derive(Clone, Copy, Debug, Deserialize, PartialEq, Serialize, Validate)] @@ -57,18 +70,17 @@ pub enum Order { #[validate(schema(function = "validate_slice"))] pub struct Slice { /// Start of the slice - pub start: usize, + pub start: isize, /// End of the slice - pub end: usize, + pub end: isize, /// Stride size - #[validate(range(min = 1, message = "stride must be greater than 0"))] - pub stride: usize, + pub stride: isize, } impl Slice { /// Return a new Slice object. #[allow(dead_code)] - pub fn new(start: usize, end: usize, stride: usize) -> Self { + pub fn new(start: isize, end: isize, stride: isize) -> Self { Slice { start, end, stride } } } @@ -118,10 +130,9 @@ fn validate_shape(shape: &[usize]) -> Result<(), ValidationError> { /// Validate an array slice fn validate_slice(slice: &Slice) -> Result<(), ValidationError> { - if slice.end <= slice.start { - let mut error = ValidationError::new("Selection end must be greater than start"); - error.add_param("start".into(), &slice.start); - error.add_param("end".into(), &slice.end); + if slice.stride == 0 { + let mut error = ValidationError::new("Selection stride must not be equal to zero"); + error.add_param("stride".into(), &slice.stride); return Err(error); } Ok(()) @@ -138,23 +149,12 @@ fn validate_shape_selection( error.add_param("selection".into(), &selection.len()); return Err(error); } - for (shape_i, selection_i) in std::iter::zip(shape, selection) { - if selection_i.end > *shape_i { - let mut error = ValidationError::new( - "Selection end must be less than or equal to corresponding shape index", - ); - error.add_param("shape".into(), &shape_i); - error.add_param("selection".into(), &selection_i); - return Err(error); - } - } Ok(()) } /// Validate request data fn validate_request_data(request_data: &RequestData) -> Result<(), ValidationError> { // Validation of multiple fields in RequestData. - // TODO: More validation of shape & selection vs. size if let Some(size) = &request_data.size { let dtype_size = request_data.dtype.size_of(); if size % dtype_size != 0 { @@ -478,7 +478,7 @@ mod tests { } #[test] - #[should_panic(expected = "stride must be greater than 0")] + #[should_panic(expected = "Selection stride must not be equal to zero")] fn test_invalid_selection2() { let mut request_data = get_test_request_data(); request_data.selection = Some(vec![Slice::new(1, 2, 0)]); @@ -486,10 +486,19 @@ mod tests { } #[test] - #[should_panic(expected = "Selection end must be greater than start")] - fn test_invalid_selection3() { + fn test_selection_end_lt_start() { + // Numpy sementics: start >= end yields an empty array let mut request_data = get_test_request_data(); - request_data.selection = Some(vec![Slice::new(1, 1, 1)]); + request_data.shape = Some(vec![1]); + request_data.selection = Some(vec![Slice::new(1, 0, 1)]); + request_data.validate().unwrap() + } + + #[test] + fn test_selection_negative_stride() { + let mut request_data = get_test_request_data(); + request_data.shape = Some(vec![1]); + request_data.selection = Some(vec![Slice::new(1, 0, -1)]); request_data.validate().unwrap() } @@ -511,16 +520,41 @@ mod tests { } #[test] - #[should_panic( - expected = "Selection end must be less than or equal to corresponding shape index" - )] + fn test_selection_start_gt_shape() { + // Numpy sementics: start > length yields an empty array + let mut request_data = get_test_request_data(); + request_data.shape = Some(vec![4]); + request_data.selection = Some(vec![Slice::new(5, 5, 1)]); + request_data.validate().unwrap() + } + + #[test] + fn test_selection_start_lt_negative_shape() { + // Numpy sementics: start < -length gets clamped to zero + let mut request_data = get_test_request_data(); + request_data.shape = Some(vec![4]); + request_data.selection = Some(vec![Slice::new(-5, 5, 1)]); + request_data.validate().unwrap() + } + + #[test] fn test_selection_end_gt_shape() { + // Numpy semantics: end > length gets clamped to length let mut request_data = get_test_request_data(); request_data.shape = Some(vec![4]); request_data.selection = Some(vec![Slice::new(1, 5, 1)]); request_data.validate().unwrap() } + #[test] + fn test_selection_end_lt_negative_shape() { + // Numpy semantics: end < -length gets clamped to zero + let mut request_data = get_test_request_data(); + request_data.shape = Some(vec![4]); + request_data.selection = Some(vec![Slice::new(1, -5, 1)]); + request_data.validate().unwrap() + } + #[test] #[should_panic(expected = "Selection requires shape to be specified")] fn test_selection_without_shape() {