Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion encodings/runend/src/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ mod test {
use rstest::rstest;
use vortex_array::Array;
use vortex_array::ArrayRef;
use vortex_array::Canonical;
use vortex_array::IntoArray;
use vortex_array::LEGACY_SESSION;
use vortex_array::VortexSessionExecute;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::assert_arrays_eq;
use vortex_array::compute::conformance::take::test_take_conformance;
Expand Down Expand Up @@ -126,7 +129,10 @@ mod test {
#[test]
#[should_panic]
fn ree_take_out_of_bounds() {
take(ree_array().as_ref(), buffer![12].into_array().as_ref()).unwrap();
let _array = take(ree_array().as_ref(), buffer![12].into_array().as_ref())
.unwrap()
.execute::<Canonical>(&mut LEGACY_SESSION.create_execution_ctx())
.unwrap();
}

#[test]
Expand Down
8 changes: 7 additions & 1 deletion encodings/sequence/src/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ impl TakeExecute for SequenceVTable {
#[cfg(test)]
mod test {
use rstest::rstest;
use vortex_array::Canonical;
use vortex_array::LEGACY_SESSION;
use vortex_array::VortexSessionExecute;
use vortex_array::compute::take;
use vortex_dtype::Nullability;

Expand Down Expand Up @@ -168,6 +171,9 @@ mod test {
fn test_bounds_check() {
let array = SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 10).unwrap();
let indices = vortex_array::arrays::PrimitiveArray::from_iter([0i32, 20]);
let _array = take(array.as_ref(), indices.as_ref()).unwrap();
let _array = take(array.as_ref(), indices.as_ref())
.unwrap()
.execute::<Canonical>(&mut LEGACY_SESSION.create_execution_ctx())
.unwrap();
}
}
19 changes: 10 additions & 9 deletions vortex-array/src/arrays/chunked/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ use crate::arrays::PrimitiveArray;
use crate::arrays::TakeExecute;
use crate::arrays::chunked::ChunkedArray;
use crate::compute::cast;
use crate::compute::take;
use crate::executor::ExecutionCtx;
use crate::validity::Validity;

Expand Down Expand Up @@ -47,10 +46,11 @@ fn take_chunked(array: &ChunkedArray, indices: &dyn Array) -> VortexResult<Array
indices_in_chunk.clone().freeze(),
Validity::from_mask(indices_mask.slice(start..stop), nullability),
);
chunks.push(take(
array.chunk(prev_chunk_idx),
indices_in_chunk_array.as_ref(),
)?);
chunks.push(
array
.chunk(prev_chunk_idx)
.take(indices_in_chunk_array.into_array())?,
);
indices_in_chunk.clear();
start = stop;
}
Expand All @@ -65,10 +65,11 @@ fn take_chunked(array: &ChunkedArray, indices: &dyn Array) -> VortexResult<Array
indices_in_chunk.freeze(),
Validity::from_mask(indices_mask.slice(start..stop), nullability),
);
chunks.push(take(
array.chunk(prev_chunk_idx),
indices_in_chunk_array.as_ref(),
)?);
chunks.push(
array
.chunk(prev_chunk_idx)
.take(indices_in_chunk_array.into_array())?,
);
}

// SAFETY: take on chunks that all have same DType retains same DType
Expand Down
40 changes: 10 additions & 30 deletions vortex-array/src/arrays/decimal/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,65 +53,45 @@ mod tests {
use vortex_buffer::Buffer;
use vortex_buffer::buffer;
use vortex_dtype::DecimalDType;
use vortex_dtype::Nullability;
use vortex_scalar::DecimalValue;
use vortex_scalar::Scalar;

use crate::IntoArray;
use crate::arrays::DecimalArray;
use crate::arrays::DecimalVTable;
use crate::arrays::PrimitiveArray;
use crate::assert_arrays_eq;
use crate::compute::conformance::take::test_take_conformance;
use crate::compute::take;
use crate::validity::Validity;

#[test]
fn test_take() {
let ddtype = DecimalDType::new(19, 1);
let array = DecimalArray::new(
buffer![10i128, 11i128, 12i128, 13i128],
DecimalDType::new(19, 1),
ddtype,
Validity::NonNullable,
);

let indices = buffer![0, 2, 3].into_array();
let taken = take(array.as_ref(), indices.as_ref()).unwrap();
let taken_decimals = taken.as_::<DecimalVTable>();
assert_eq!(
taken_decimals.buffer::<i128>(),
buffer![10i128, 12i128, 13i128]
);
assert_eq!(taken_decimals.decimal_dtype(), DecimalDType::new(19, 1));

let expected = DecimalArray::from_iter([10i128, 12, 13], ddtype);
assert_arrays_eq!(expected, taken);
}

#[test]
fn test_take_null_indices() {
let ddtype = DecimalDType::new(19, 1);
let array = DecimalArray::new(
buffer![i128::MAX, 11i128, 12i128, 13i128],
DecimalDType::new(19, 1),
ddtype,
Validity::NonNullable,
);

let indices = PrimitiveArray::from_option_iter([None, Some(2), Some(3)]).into_array();
let taken = take(array.as_ref(), indices.as_ref()).unwrap();

assert!(taken.scalar_at(0).unwrap().is_null());
assert_eq!(
taken.scalar_at(1).unwrap(),
Scalar::decimal(
DecimalValue::I128(12i128),
array.decimal_dtype(),
Nullability::Nullable
)
);

assert_eq!(
taken.scalar_at(2).unwrap(),
Scalar::decimal(
DecimalValue::I128(13i128),
array.decimal_dtype(),
Nullability::Nullable
)
);
let expected = DecimalArray::from_option_iter([None, Some(12i128), Some(13)], ddtype);
assert_arrays_eq!(expected, taken);
}

#[rstest]
Expand Down
101 changes: 25 additions & 76 deletions vortex-array/src/arrays/fixed_size_list/tests/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ use super::common::create_single_element_fsl;
use crate::Array;
use crate::IntoArray;
use crate::arrays::FixedSizeListArray;
use crate::arrays::FixedSizeListVTable;
use crate::arrays::PrimitiveArray;
use crate::assert_arrays_eq;
use crate::builders::ArrayBuilder;
use crate::builders::FixedSizeListBuilder;
use crate::compute::conformance::take::test_take_conformance;
Expand All @@ -39,39 +39,20 @@ fn test_take_fsl_conformance(#[case] fsl: FixedSizeListArray) {

#[test]
fn test_take_basic_smoke_test() {
// Basic smoke test to ensure take works for FSL and preserves structure.
let elements = buffer![1i32, 2, 3, 4, 5, 6].into_array();
let fsl = FixedSizeListArray::new(elements.into_array(), 2, Validity::NonNullable, 3);

let indices = buffer![2u32, 0, 1].into_array();
let result = take(fsl.as_ref(), indices.as_ref()).unwrap();
let result_fsl = result.as_::<FixedSizeListVTable>();

assert_eq!(result_fsl.len(), 3, "Wrong number of lists after take");
assert_eq!(result_fsl.list_size(), 2, "list_size should be preserved");

// First list should be the original third list [5, 6].
let first = result_fsl.fixed_size_list_elements_at(0).unwrap();
assert_eq!(
first.scalar_at(0).unwrap(),
5i32.into(),
"Wrong value at [2][0] after take"
);
assert_eq!(
first.scalar_at(1).unwrap(),
6i32.into(),
"Wrong value at [2][1] after take"
// Expected: [[5,6], [1,2], [3,4]]
let expected = FixedSizeListArray::new(
buffer![5i32, 6, 1, 2, 3, 4].into_array(),
2,
Validity::NonNullable,
3,
);

// Second list should be the original first list [1, 2].
let second = result_fsl.fixed_size_list_elements_at(1).unwrap();
assert_eq!(second.scalar_at(0).unwrap(), 1i32.into());
assert_eq!(second.scalar_at(1).unwrap(), 2i32.into());

// Third list should be the original second list [3, 4].
let third = result_fsl.fixed_size_list_elements_at(2).unwrap();
assert_eq!(third.scalar_at(0).unwrap(), 3i32.into());
assert_eq!(third.scalar_at(1).unwrap(), 4i32.into());
assert_arrays_eq!(expected, result);
}

// Parameterized test for FSL-specific degenerate (list_size=0) cases.
Expand Down Expand Up @@ -110,73 +91,44 @@ fn test_take_degenerate_lists(
// Also test the specific behavior.
let indices_array = PrimitiveArray::from_option_iter(indices);
let result = take(fsl.as_ref(), indices_array.as_ref()).unwrap();
let result_fsl = result.as_::<FixedSizeListVTable>();

assert_eq!(result_fsl.len(), expected_len);
assert_eq!(result_fsl.list_size(), 0);
assert_eq!(result_fsl.elements().len(), 0);

// Check nullability of results.
assert_eq!(result.len(), expected_len);
for (i, expected_null) in expected_nulls.iter().enumerate() {
assert_eq!(result_fsl.scalar_at(i).unwrap().is_null(), *expected_null);
assert_eq!(result.scalar_at(i).unwrap().is_null(), *expected_null);
}
}

#[test]
fn test_take_large_list_size() {
// Test FSL-specific behavior with large list sizes.
// This tests the performance characteristics specific to FSL's element expansion.
let elements = buffer![0i32..300].into_array();
let fsl = FixedSizeListArray::new(elements, 100, Validity::NonNullable, 3);

let indices = buffer![2u16, 0].into_array();
let result = take(fsl.as_ref(), indices.as_ref()).unwrap();
let result_fsl = result.as_::<FixedSizeListVTable>();

assert_eq!(result_fsl.len(), 2);
assert_eq!(result_fsl.list_size(), 100);

// First list should be [200..300].
let first = result_fsl.fixed_size_list_elements_at(0).unwrap();
for i in 0..100i32 {
assert_eq!(first.scalar_at(i as usize).unwrap(), (200 + i).into());
}

// Second list should be [0..100].
let second = result_fsl.fixed_size_list_elements_at(1).unwrap();
for i in 0..100i32 {
assert_eq!(second.scalar_at(i as usize).unwrap(), i.into());
}
// Expected: [[200..300], [0..100]]
let expected_elems = PrimitiveArray::from_iter((200i32..300).chain(0..100)).into_array();
let expected = FixedSizeListArray::new(expected_elems, 100, Validity::NonNullable, 2);
assert_arrays_eq!(expected, result);
}

#[test]
fn test_take_fsl_with_null_indices_preserves_elements() {
// FSL-specific test: verify that null indices don't affect element array indexing.
let elements = buffer![1i32, 2, 3, 4, 5, 6].into_array();
let fsl = FixedSizeListArray::new(elements.into_array(), 2, Validity::NonNullable, 3);

// Create indices with nulls: [1, null, 0].
// Indices with nulls: [1, null, 0].
let indices = PrimitiveArray::from_option_iter([Some(1u32), None, Some(0)]);
let result = take(fsl.as_ref(), indices.as_ref()).unwrap();
let result_fsl = result.as_::<FixedSizeListVTable>();

assert_eq!(result_fsl.len(), 3);
assert_eq!(result_fsl.list_size(), 2);

// First list should be [3, 4].
assert!(!result_fsl.scalar_at(0).unwrap().is_null());
let first = result_fsl.fixed_size_list_elements_at(0).unwrap();
assert_eq!(first.scalar_at(0).unwrap(), 3i32.into());
assert_eq!(first.scalar_at(1).unwrap(), 4i32.into());

// Second list should be null.
assert!(result_fsl.scalar_at(1).unwrap().is_null());

// Third list should be [1, 2].
assert!(!result_fsl.scalar_at(2).unwrap().is_null());
let third = result_fsl.fixed_size_list_elements_at(2).unwrap();
assert_eq!(third.scalar_at(0).unwrap(), 1i32.into());
assert_eq!(third.scalar_at(1).unwrap(), 2i32.into());
// Expected: [[3,4], null, [1,2]]
let expected = FixedSizeListArray::new(
buffer![3i32, 4, 0, 0, 1, 2].into_array(),
2,
Validity::from_iter([true, false, true]),
3,
);
assert_arrays_eq!(expected, result);
}

// Parameterized test for nullable array scenarios that are specific to FSL's implementation.
Expand Down Expand Up @@ -234,12 +186,9 @@ fn test_take_nullable_arrays_fsl_specific(
// Create indices (with possible nulls).
let indices_array = PrimitiveArray::from_option_iter(indices.clone());
let result = take(fsl.as_ref(), indices_array.as_ref()).unwrap();
let result_fsl = result.as_::<FixedSizeListVTable>();

assert_eq!(result_fsl.len(), indices.len());

// Check nullability of results.
assert_eq!(result.len(), indices.len());
for (i, expected_null) in expected_nulls.iter().enumerate() {
assert_eq!(result_fsl.scalar_at(i).unwrap().is_null(), *expected_null);
assert_eq!(result.scalar_at(i).unwrap().is_null(), *expected_null);
}
}
13 changes: 2 additions & 11 deletions vortex-array/src/arrays/list/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use vortex_error::VortexResult;

use crate::Array;
use crate::ArrayRef;
use crate::IntoArray;
use crate::ToCanonical;
use crate::arrays::ListArray;
use crate::arrays::ListVTable;
Expand Down Expand Up @@ -101,11 +100,7 @@ fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
let elements_to_take = elements_to_take.finish();
let new_offsets = new_offsets.finish();

let new_elements = array
.elements()
.take(elements_to_take.to_array())?
.to_canonical()?
.into_array();
let new_elements = array.elements().take(elements_to_take.to_array())?;

Ok(ListArray::try_new(
new_elements,
Expand Down Expand Up @@ -173,11 +168,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPTy

let elements_to_take = elements_to_take.finish();
let new_offsets = new_offsets.finish();
let new_elements = array
.elements()
.take(elements_to_take.to_array())?
.to_canonical()?
.into_array();
let new_elements = array.elements().take(elements_to_take.to_array())?;

Ok(ListArray::try_new(
new_elements,
Expand Down
15 changes: 8 additions & 7 deletions vortex-array/src/arrays/varbin/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ mod tests {
use crate::IntoArray;
use crate::arrays::PrimitiveArray;
use crate::arrays::VarBinArray;
use crate::arrays::VarBinViewVTable;
use crate::arrays::VarBinViewArray;
use crate::assert_arrays_eq;
use crate::compute::conformance::take::test_take_conformance;
use crate::compute::take;
use crate::validity::Validity;
Expand Down Expand Up @@ -313,13 +314,13 @@ mod tests {
Validity::NonNullable,
);

let indices = buffer![0u32, 0u32, 0u32].into_array();
let indices = buffer![0u32; 3].into_array();
let taken = take(array.as_ref(), indices.as_ref()).unwrap();

let taken_view = taken.as_::<VarBinViewVTable>();
assert_eq!(taken_view.len(), 3);
assert_eq!(taken_view.bytes_at(0).as_slice(), scream.as_bytes());
assert_eq!(taken_view.bytes_at(1).as_slice(), scream.as_bytes());
assert_eq!(taken_view.bytes_at(2).as_slice(), scream.as_bytes());
let expected = VarBinViewArray::from_iter(
[Some(scream.clone()), Some(scream.clone()), Some(scream)],
DType::Utf8(Nullability::NonNullable),
);
assert_arrays_eq!(expected, taken);
}
}
Loading
Loading