diff --git a/vortex-array/src/aggregate_fn/fns/min_max/mod.rs b/vortex-array/src/aggregate_fn/fns/min_max/mod.rs index f91ae351392..540b5608e28 100644 --- a/vortex-array/src/aggregate_fn/fns/min_max/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/min_max/mod.rs @@ -61,8 +61,8 @@ pub fn min_max(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult DType { ) } +fn minmax_supported_dtype(input_dtype: &DType) -> bool { + match input_dtype { + DType::Bool(_) + | DType::Primitive(..) + | DType::Decimal(..) + | DType::Utf8(..) + | DType::Binary(..) + | DType::Extension(..) => true, + DType::List(element_dtype, _) => minmax_supported_dtype(element_dtype), + DType::FixedSizeList(element_dtype, ..) => minmax_supported_dtype(element_dtype), + _ => false, + } +} + +/// Returns whether [`min_max`] can currently compute extrema for this logical dtype. +/// +/// This is intentionally narrower than [`minmax_supported_dtype`]. List and fixed-size-list +/// extrema have a defined output dtype for aggregate expression lowering, but the accumulator does +/// not yet implement lexicographic list comparison. +fn minmax_compute_supported_dtype(input_dtype: &DType) -> bool { + matches!( + input_dtype, + DType::Bool(_) + | DType::Primitive(..) + | DType::Decimal(..) + | DType::Utf8(..) + | DType::Binary(..) + | DType::Extension(..) + ) +} + impl AggregateFnVTable for MinMax { type Options = EmptyOptions; type Partial = MinMaxPartial; @@ -175,15 +206,7 @@ impl AggregateFnVTable for MinMax { } fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> { - match input_dtype { - DType::Bool(_) - | DType::Primitive(..) - | DType::Decimal(..) - | DType::Utf8(..) - | DType::Binary(..) - | DType::Extension(..) 
=> Some(make_minmax_dtype(input_dtype)), - _ => None, - } + minmax_supported_dtype(input_dtype).then(|| make_minmax_dtype(input_dtype)) } fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option { @@ -278,6 +301,8 @@ impl AggregateFnVTable for MinMax { #[cfg(test)] mod tests { + use std::sync::Arc; + use vortex_buffer::BitBuffer; use vortex_buffer::buffer; use vortex_error::VortexExpect; @@ -298,6 +323,8 @@ mod tests { use crate::arrays::ChunkedArray; use crate::arrays::ConstantArray; use crate::arrays::DecimalArray; + use crate::arrays::FixedSizeListArray; + use crate::arrays::ListArray; use crate::arrays::NullArray; use crate::arrays::PrimitiveArray; use crate::arrays::VarBinArray; @@ -570,6 +597,47 @@ mod tests { Ok(()) } + #[test] + fn list_and_fixed_size_list_return_dtype() { + let element_dtype = DType::Primitive(PType::I32, Nullability::Nullable); + let list_dtype = DType::List(Arc::new(element_dtype.clone()), Nullability::Nullable); + let fixed_size_list_dtype = + DType::FixedSizeList(Arc::new(element_dtype), 1, Nullability::Nullable); + + assert_eq!( + MinMax.return_dtype(&EmptyOptions, &list_dtype), + Some(make_minmax_dtype(&list_dtype)) + ); + assert_eq!( + MinMax.return_dtype(&EmptyOptions, &fixed_size_list_dtype), + Some(make_minmax_dtype(&fixed_size_list_dtype)) + ); + } + + #[test] + fn list_and_fixed_size_list_min_max_returns_none() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + + let list_array = ListArray::try_new( + buffer![1i32, 2, 3].into_array(), + buffer![0u32, 2, 3].into_array(), + Validity::NonNullable, + )? + .into_array(); + assert_eq!(min_max(&list_array, &mut ctx)?, None); + + let fixed_size_list_array = FixedSizeListArray::try_new( + buffer![1i32, 2, 3, 4].into_array(), + 2, + Validity::NonNullable, + 2, + )? 
+ .into_array(); + assert_eq!(min_max(&fixed_size_list_array, &mut ctx)?, None); + + Ok(()) + } + use crate::dtype::half::f16; #[test] diff --git a/vortex-layout/src/layouts/zoned/zone_map.rs b/vortex-layout/src/layouts/zoned/zone_map.rs index a3719e8e17e..96154e69571 100644 --- a/vortex-layout/src/layouts/zoned/zone_map.rs +++ b/vortex-layout/src/layouts/zoned/zone_map.rs @@ -179,13 +179,13 @@ impl ZoneMap { return Ok(lit(0u64)); } - // When the aggregate function does not support the column dtype the stat is - // not computable on this layout, so treat it the same as a missing stat and - // lower to a nullable "unknown" rather than failing the whole scan. Min/Max - // share their input dtype, so falling back to `input_dtype.as_nullable()` - // keeps the rewrite well-typed for the most common case. - let Some(return_dtype) = options.aggregate_fn().return_dtype(&input_dtype) else { - return Ok(null_expr(input_dtype.as_nullable())); + let return_dtype = match options.aggregate_fn().return_dtype(&input_dtype) { + Some(return_dtype) => return_dtype, + None => vortex_bail!( + "Aggregate function {} does not support input dtype {}", + options.aggregate_fn(), + input_dtype + ), }; if !input_is_root { @@ -277,6 +277,7 @@ mod tests { use vortex_array::arrays::StructArray; use vortex_array::assert_arrays_eq; use vortex_array::dtype::DType; + use vortex_array::dtype::DecimalDType; use vortex_array::dtype::FieldNames; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; @@ -606,14 +607,14 @@ mod tests { } #[test] - fn unsupported_aggregate_input_dtype_lowers_to_unknown() { - // Regression test for issue #8189: a pruning predicate that contains a - // `StatFn(Max, $)` (or `Min`) over a column dtype that the aggregate - // function does not support (e.g. `FixedSizeList`) used to - // bail out of `lower_stat_fn` with "Aggregate function vortex.max() does - // not support input dtype ...", panicking the scan instead of treating - // the stat as unknown. 
- let elem_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)); + fn fixed_size_list_min_max_stat_fn_lowers_to_unknown_mask() { + // Regression test for issue #8189: Min/Max is defined for FixedSizeList<T> + // when T is orderable. If the zone map does not carry the requested stat, + // lowering should produce an unknown typed null rather than rejecting the dtype. + let elem_dtype = Arc::new(DType::Decimal( + DecimalDType::new(10, 2), + Nullability::Nullable, + )); let column_dtype = DType::FixedSizeList(elem_dtype, 1, Nullability::Nullable); let zone_map = ZoneMap::try_new( @@ -630,12 +631,36 @@ mod tests { .expect("max should have an aggregate function"); let predicate = is_null(vortex_array::stats::stat(root(), max_fn)); - // Must not panic; the unsupported StatFn lowers to a nullable null - // literal, so `is_null(...)` is true for every zone. + // Missing StatFn lowers to a nullable null literal, so `is_null(...)` is true for every zone. let mask = zone_map.prune(&predicate, &SESSION).unwrap(); assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([true, true, true])); } + #[test] + fn unsupported_aggregate_input_dtype_errors() { + let zone_map = ZoneMap::try_new( + DType::Null, + StructArray::try_new(FieldNames::empty(), vec![], 3, Validity::NonNullable).unwrap(), + Arc::new([]), + 4, + 10, + ) + .unwrap(); + + let max_fn = Stat::Max + .aggregate_fn() + .expect("max should have an aggregate function"); + let predicate = is_null(vortex_array::stats::stat(root(), max_fn)); + let error = zone_map.prune(&predicate, &SESSION).unwrap_err(); + + assert!( + error + .to_string() + .contains("Aggregate function vortex.max() does not support input dtype null"), + "{error}" + ); + } + #[test] fn row_count_prunes_all_null_uniform_zones() { let zone_map = ZoneMap::try_new(