Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions encodings/sparse/src/compute/filter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex_array::ArrayRef;
use vortex_array::IntoArray;
use vortex_array::arrays::ConstantArray;
use vortex_array::compute::FilterKernel;
use vortex_array::compute::FilterKernelAdapter;
use vortex_array::register_kernel;
use vortex_error::VortexResult;
use vortex_mask::Mask;

use crate::SparseArray;
use crate::SparseVTable;

impl FilterKernel for SparseVTable {
    /// Filters a sparse array by `mask`.
    ///
    /// The array's patches are filtered with the same mask. If that produces
    /// patches, a new [`SparseArray`] is built from them together with the
    /// original fill scalar. If it produces `None`, the result is a
    /// [`ConstantArray`] of the fill scalar with length `mask.true_count()`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while filtering the patches or while
    /// constructing the new sparse array from them.
    fn filter(&self, array: &SparseArray, mask: &Mask) -> VortexResult<ArrayRef> {
        match array.patches().filter(mask)? {
            Some(kept_patches) => {
                let fill = array.fill_scalar().clone();
                Ok(SparseArray::try_new_from_patches(kept_patches, fill)?.into_array())
            }
            // NOTE(review): `None` presumably means no patch survived the mask,
            // so every remaining position holds the fill value — confirm.
            None => {
                let kept_len = mask.true_count();
                Ok(ConstantArray::new(array.fill_scalar().clone(), kept_len).into_array())
            }
        }
    }
}

// Hand the `SparseVTable` filter kernel, wrapped in `FilterKernelAdapter` and
// lifted via `.lift()`, to vortex_array's `register_kernel!` macro.
register_kernel!(FilterKernelAdapter(SparseVTable).lift());

#[cfg(test)]
mod tests {
    use rstest::fixture;
    use rstest::rstest;
    use vortex_array::Array;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// Length-20 sparse `i32` array with a null fill scalar and the patch values
    /// 33, 44 and 55 stored at positions 2, 9 and 15.
    #[fixture]
    fn array() -> ArrayRef {
        let positions = buffer![2u64, 9, 15].into_array();
        let values = PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array();

        SparseArray::try_new(positions, values, 20, Scalar::null_typed::<i32>())
            .unwrap()
            .into_array()
    }

    /// Keeping only position 2 must yield a single-element array holding 33.
    #[rstest]
    fn test_filter(array: ArrayRef) {
        // 20 selection flags, only position 2 set.
        let selection: Vec<bool> = (0..20).map(|row| row == 2).collect();
        let mask = Mask::from_iter(selection);

        let actual = array.filter(mask).unwrap();

        // Position 2 carried the patch value 33; as the sole survivor it lands
        // at position 0 of the filtered array.
        let expected_positions = buffer![0u64].into_array();
        let expected_values =
            PrimitiveArray::new(buffer![33_i32], Validity::AllValid).into_array();
        let expected = SparseArray::try_new(
            expected_positions,
            expected_values,
            1,
            Scalar::null_typed::<i32>(),
        )
        .unwrap();

        assert_arrays_eq!(actual, expected);
    }

    /// Filtering with a mask that keeps a mix of patched and unpatched positions.
    #[test]
    fn true_fill_value() {
        let positions = buffer![0_u64, 3, 6].into_array();
        let values = PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array();
        let array = SparseArray::try_new(positions, values, 7, Scalar::null_typed::<i32>())
            .unwrap()
            .into_array();
        let mask = Mask::from_iter([false, true, false, true, false, true, true]);

        let actual = array.filter(mask).unwrap();

        // Patches sit at 0, 3, 6 (values 33, 44, 55). The mask retains
        // positions 1, 3, 5, 6, which are renumbered 0, 1, 2, 3:
        //   - old position 3 (value 44) becomes new position 1,
        //   - old position 6 (value 55) becomes new position 3.
        let expected_positions = buffer![1u64, 3].into_array();
        let expected_values =
            PrimitiveArray::new(buffer![44_i32, 55], Validity::AllValid).into_array();
        let expected = SparseArray::try_new(
            expected_positions,
            expected_values,
            4,
            Scalar::null_typed::<i32>(),
        )
        .unwrap();

        assert_arrays_eq!(actual, expected);
    }

    /// Runs the shared filter conformance suite against a null-filled and a
    /// non-null-filled sparse array.
    #[test]
    fn test_filter_sparse_array() {
        // Case 1: nullable i32 values with a null fill scalar.
        let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        let raw_values = buffer![100i32, 200, 300].into_array();
        let nullable_values = cast(&raw_values, null_fill_value.dtype()).unwrap();
        let null_filled = SparseArray::try_new(
            buffer![1u64, 2, 4].into_array(),
            nullable_values,
            5,
            null_fill_value,
        )
        .unwrap();
        test_filter_conformance(null_filled.as_ref());

        // Case 2: plain i32 values with a fill scalar of 10.
        let ten_fill_value = Scalar::from(10i32);
        let ten_filled = SparseArray::try_new(
            buffer![1u64, 2, 4].into_array(),
            buffer![100i32, 200, 300].into_array(),
            5,
            ten_fill_value,
        )
        .unwrap();
        test_filter_conformance(ten_filled.as_ref());
    }
}
115 changes: 1 addition & 114 deletions encodings/sparse/src/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,59 +1,27 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex_array::ArrayRef;
use vortex_array::IntoArray;
use vortex_array::arrays::ConstantArray;
use vortex_array::compute::FilterKernel;
use vortex_array::compute::FilterKernelAdapter;
use vortex_array::register_kernel;
use vortex_error::VortexResult;
use vortex_mask::Mask;

use crate::SparseArray;
use crate::SparseVTable;

mod binary_numeric;
mod cast;
mod filter;
mod invert;
mod take;

impl FilterKernel for SparseVTable {
    /// Filters a sparse array by `mask`.
    ///
    /// The array's patches are filtered with the same mask. If that produces
    /// patches, a new [`SparseArray`] is built from them together with the
    /// original fill scalar. If it produces `None`, the result is a
    /// [`ConstantArray`] of the fill scalar with length `mask.true_count()`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while filtering the patches or while
    /// constructing the new sparse array from them.
    fn filter(&self, array: &SparseArray, mask: &Mask) -> VortexResult<ArrayRef> {
        match array.patches().filter(mask)? {
            Some(kept_patches) => {
                let fill = array.fill_scalar().clone();
                Ok(SparseArray::try_new_from_patches(kept_patches, fill)?.into_array())
            }
            // NOTE(review): `None` presumably means no patch survived the mask,
            // so every remaining position holds the fill value — confirm.
            None => {
                let kept_len = mask.true_count();
                Ok(ConstantArray::new(array.fill_scalar().clone(), kept_len).into_array())
            }
        }
    }
}

// Hand the `SparseVTable` filter kernel, wrapped in `FilterKernelAdapter` and
// lifted via `.lift()`, to vortex_array's `register_kernel!` macro.
register_kernel!(FilterKernelAdapter(SparseVTable).lift());

#[cfg(test)]
mod test {
use rstest::fixture;
use rstest::rstest;
use vortex_array::Array;
use vortex_array::ArrayRef;
use vortex_array::IntoArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::assert_arrays_eq;
use vortex_array::compute::cast;
use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array;
use vortex_array::compute::conformance::filter::test_filter_conformance;
use vortex_array::compute::conformance::mask::test_mask_conformance;
use vortex_array::validity::Validity;
use vortex_buffer::buffer;
use vortex_dtype::DType;
use vortex_dtype::Nullability;
use vortex_dtype::PType;
use vortex_mask::Mask;
use vortex_scalar::Scalar;

use crate::SparseArray;
Expand All @@ -70,56 +38,6 @@ mod test {
.into_array()
}

/// Keeping only position 2 must yield a single-element array holding 33.
#[rstest]
fn test_filter(array: ArrayRef) {
    // 20 selection flags, only position 2 set.
    let selection: Vec<bool> = (0..20).map(|row| row == 2).collect();
    let mask = Mask::from_iter(selection);

    let actual = array.filter(mask).unwrap();

    // Position 2 carried the patch value 33; as the sole survivor it lands
    // at position 0 of the filtered array.
    let expected_positions = buffer![0u64].into_array();
    let expected_values = PrimitiveArray::new(buffer![33_i32], Validity::AllValid).into_array();
    let expected = SparseArray::try_new(
        expected_positions,
        expected_values,
        1,
        Scalar::null_typed::<i32>(),
    )
    .unwrap();

    assert_arrays_eq!(actual, expected);
}

/// Filtering with a mask that keeps a mix of patched and unpatched positions.
#[test]
fn true_fill_value() {
    let positions = buffer![0_u64, 3, 6].into_array();
    let values = PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array();
    let array = SparseArray::try_new(positions, values, 7, Scalar::null_typed::<i32>())
        .unwrap()
        .into_array();
    let mask = Mask::from_iter([false, true, false, true, false, true, true]);

    let actual = array.filter(mask).unwrap();

    // Patches sit at 0, 3, 6 (values 33, 44, 55). The mask retains
    // positions 1, 3, 5, 6, which are renumbered 0, 1, 2, 3:
    //   - old position 3 (value 44) becomes new position 1,
    //   - old position 6 (value 55) becomes new position 3.
    let expected_positions = buffer![1u64, 3].into_array();
    let expected_values = PrimitiveArray::new(buffer![44_i32, 55], Validity::AllValid).into_array();
    let expected = SparseArray::try_new(
        expected_positions,
        expected_values,
        4,
        Scalar::null_typed::<i32>(),
    )
    .unwrap();

    assert_arrays_eq!(actual, expected);
}

#[rstest]
fn test_sparse_binary_numeric(array: ArrayRef) {
test_binary_numeric_array(array)
Expand Down Expand Up @@ -155,37 +73,6 @@ mod test {
.as_ref(),
)
}

/// Runs the shared filter conformance suite against a null-filled and a
/// non-null-filled sparse array.
#[test]
fn test_filter_sparse_array() {
    // Case 1: nullable i32 values with a null fill scalar.
    let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
    let raw_values = buffer![100i32, 200, 300].into_array();
    let nullable_values = cast(&raw_values, null_fill_value.dtype()).unwrap();
    let null_filled = SparseArray::try_new(
        buffer![1u64, 2, 4].into_array(),
        nullable_values,
        5,
        null_fill_value,
    )
    .unwrap();
    test_filter_conformance(null_filled.as_ref());

    // Case 2: plain i32 values with a fill scalar of 10.
    let ten_fill_value = Scalar::from(10i32);
    let ten_filled = SparseArray::try_new(
        buffer![1u64, 2, 4].into_array(),
        buffer![100i32, 200, 300].into_array(),
        5,
        ten_fill_value,
    )
    .unwrap();
    test_filter_conformance(ten_filled.as_ref());
}
}

#[cfg(test)]
Expand Down
Loading
Loading