Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions vortex-array/src/arrays/filter/execute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,23 @@ fn take_impl(

match indices_validity.bit_buffer() {
AllOr::All => {
let result_dtype = array
.dtype()
.union_nullability(indices.dtype().nullability());

if let Some((start, end)) =
contiguous_sequential_take_range_indices(array.filter_mask(), indices)?
{
return array.child().slice(start..end);
return array.child().slice(start..end)?.cast(result_dtype);
}

if let Some(take_len) = sequential_take_len(indices, array.len())? {
if take_len == 0 {
return Ok(Canonical::empty(array.child().dtype()).into_array());
return Ok(Canonical::empty(&result_dtype).into_array());
}
let rank_mask = Mask::from_slices(array.len(), vec![(0, take_len)]);
let mask = array.filter_mask().intersect_by_rank(&rank_mask);
return array.child().filter(mask);
return array.child().filter(mask)?.cast(result_dtype);
}

let translated = translate_indices(array.filter_mask(), indices, None)?;
Expand Down
4 changes: 2 additions & 2 deletions vortex-array/src/arrays/filter/execute/take/fixed_width.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ where
};

let output_validity = if child_validity.no_nulls() {
ranks_validity
ranks_validity.union_nullability(child_validity.nullability())
} else {
let translated_indices =
PrimitiveArray::new(translate_ranks(filter, ranks, None)?, ranks_validity)
Expand All @@ -121,7 +121,7 @@ where
let taken = take_filtered_values(child.as_slice(), filter, ranks, Some(buf))?;

let output_validity = if child_validity.no_nulls() {
ranks_validity
ranks_validity.union_nullability(child_validity.nullability())
} else {
let translated_indices =
PrimitiveArray::new(translate_ranks(filter, ranks, Some(buf))?, ranks_validity)
Expand Down
48 changes: 48 additions & 0 deletions vortex-array/src/arrays/filter/execute/take/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,28 @@ fn test_take_execute_kernel_handles_nullable_primitive_filter_child() -> VortexR
Ok(())
}

#[test]
fn test_take_execute_kernel_preserves_nullable_all_valid_fixed_width_child() -> VortexResult<()> {
let filter = FilterArray::new(
PrimitiveArray::new(buffer![10i32, 20, 30], Validity::AllValid).into_array(),
Mask::new_true(3),
)
.into_array();
let parent = DictArray::try_new(buffer![0u64, 1].into_array(), filter.clone())?.into_array();
let mut ctx = ExecutionCtx::new(VortexSession::empty());

let result = filter
.execute_parent(&parent, 1, &mut ctx)?
.expect("filter child should execute its take parent");

assert_eq!(result.dtype(), parent.dtype());
assert_arrays_eq!(
result.execute::<RecursiveCanonical>(&mut ctx)?.0,
PrimitiveArray::new(buffer![10i32, 20], Validity::AllValid).into_array()
);
Ok(())
}

#[test]
fn test_take_execute_kernel_handles_nullable_decimal_filter_child() -> VortexResult<()> {
let decimal_dtype = DecimalDType::new(19, 2);
Expand Down Expand Up @@ -431,6 +453,32 @@ fn test_take_execute_kernel_handles_string_filter_child() -> VortexResult<()> {
)
}

#[test]
fn test_take_execute_kernel_preserves_nullable_indices_dtype_fast_path() -> VortexResult<()> {
let filter = FilterArray::new(
VarBinViewArray::from_iter_str(["a", "b", "c"]).into_array(),
Mask::new_true(3),
)
.into_array();
let parent = DictArray::try_new(
PrimitiveArray::new(buffer![0u64, 1], Validity::AllValid).into_array(),
filter.clone(),
)?
.into_array();
let mut ctx = ExecutionCtx::new(VortexSession::empty());

let result = filter
.execute_parent(&parent, 1, &mut ctx)?
.expect("filter child should execute its nullable take parent");

assert_eq!(result.dtype(), parent.dtype());
assert_arrays_eq!(
result.execute::<RecursiveCanonical>(&mut ctx)?.0,
VarBinViewArray::from_iter_nullable_str([Some("a"), Some("b")]).into_array()
);
Ok(())
}

#[test]
fn test_take_execute_kernel_handles_struct_filter_child() -> VortexResult<()> {
assert_take_execute_maps_child_dtype(
Expand Down
Loading