From d455af07706533ca784c42f301ff23a8d4608a34 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 24 Apr 2026 10:07:30 -0400 Subject: [PATCH 1/3] feat(vortex-array): add `is_refinement` method to ExtVTable MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces a refinement-type mechanism for extension types. An extension vtable declares itself a refinement of its storage dtype by returning `true` from the new defaulted method: fn is_refinement(&self) -> bool { false } Semantically this asserts that every value of the extension type is also a value of its storage dtype, with additional invariants carried by the vtable. Plain extension types (e.g. `Uuid` over `FSL`) remain `false`: their storage is just a physical encoding, not substitutable with the logical type. The flag is surfaced through `ExtDTypeRef::is_refinement()` via a forwarder on the object-safe `DynExtDType` trait. Also adds `ExtDTypeRef::as_typed()` — a borrowing downcast that does not consume the ref — which is useful for any `ExtVTable` caller, refinement or not. The existing `DivisibleInt` and `EvenDivisibleInt` test extensions are rewired as plain `ExtVTable` impls that return `true` from `is_refinement`. Their divisibility checks now run in `unpack_native` (invoked by `validate_scalar_value`). `EvenDivisibleInt` composes with `DivisibleInt` by drilling through its storage via `as_typed::<DivisibleInt>()` and calling the inner vtable's `unpack_native` before applying its own evenness check. No behavioural impact on non-refinement extension types: the new method is defaulted to `false` and nothing consumes it yet. A follow-up commit adds the blanket scalar-fn pushdown that uses this flag. 
Signed-off-by: Connor Tsui --- vortex-array/public-api.lock | 46 +++-- vortex-array/src/dtype/extension/erased.rs | 20 ++ vortex-array/src/dtype/extension/typed.rs | 5 + vortex-array/src/dtype/extension/vtable.rs | 20 ++ .../src/extension/tests/divisible_int.rs | 62 +++--- .../src/extension/tests/even_divisible_int.rs | 177 ++++++++++++++++++ vortex-array/src/extension/tests/mod.rs | 1 + 7 files changed, 296 insertions(+), 35 deletions(-) create mode 100644 vortex-array/src/extension/tests/even_divisible_int.rs diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 8ee3c97f17b..47366b941c5 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -8562,6 +8562,20 @@ pub struct vortex_array::dtype::extension::ExtDTypeRef(_) impl vortex_array::dtype::extension::ExtDTypeRef +pub fn vortex_array::dtype::extension::ExtDTypeRef::as_typed(&self) -> core::option::Option<&vortex_array::dtype::extension::ExtDType> + +pub fn vortex_array::dtype::extension::ExtDTypeRef::downcast(self) -> alloc::sync::Arc> + +pub fn vortex_array::dtype::extension::ExtDTypeRef::is(&self) -> bool + +pub fn vortex_array::dtype::extension::ExtDTypeRef::metadata(&self) -> ::Match + +pub fn vortex_array::dtype::extension::ExtDTypeRef::metadata_opt(&self) -> core::option::Option<::Match> + +pub fn vortex_array::dtype::extension::ExtDTypeRef::try_downcast(self) -> core::result::Result>, vortex_array::dtype::extension::ExtDTypeRef> + +impl vortex_array::dtype::extension::ExtDTypeRef + pub fn vortex_array::dtype::extension::ExtDTypeRef::can_coerce_from(&self, other: &vortex_array::dtype::DType) -> bool pub fn vortex_array::dtype::extension::ExtDTypeRef::can_coerce_to(&self, other: &vortex_array::dtype::DType) -> bool @@ -8574,6 +8588,8 @@ pub fn vortex_array::dtype::extension::ExtDTypeRef::id(&self) -> vortex_array::d pub fn vortex_array::dtype::extension::ExtDTypeRef::is_nullable(&self) -> bool +pub fn 
vortex_array::dtype::extension::ExtDTypeRef::is_refinement(&self) -> bool + pub fn vortex_array::dtype::extension::ExtDTypeRef::least_supertype(&self, other: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::dtype::extension::ExtDTypeRef::nullability(&self) -> vortex_array::dtype::Nullability @@ -8584,18 +8600,6 @@ pub fn vortex_array::dtype::extension::ExtDTypeRef::storage_dtype(&self) -> &vor pub fn vortex_array::dtype::extension::ExtDTypeRef::with_nullability(&self, nullability: vortex_array::dtype::Nullability) -> Self -impl vortex_array::dtype::extension::ExtDTypeRef - -pub fn vortex_array::dtype::extension::ExtDTypeRef::downcast(self) -> alloc::sync::Arc> - -pub fn vortex_array::dtype::extension::ExtDTypeRef::is(&self) -> bool - -pub fn vortex_array::dtype::extension::ExtDTypeRef::metadata(&self) -> ::Match - -pub fn vortex_array::dtype::extension::ExtDTypeRef::metadata_opt(&self) -> core::option::Option<::Match> - -pub fn vortex_array::dtype::extension::ExtDTypeRef::try_downcast(self) -> core::result::Result>, vortex_array::dtype::extension::ExtDTypeRef> - impl core::clone::Clone for vortex_array::dtype::extension::ExtDTypeRef pub fn vortex_array::dtype::extension::ExtDTypeRef::clone(&self) -> vortex_array::dtype::extension::ExtDTypeRef @@ -8644,6 +8648,8 @@ pub fn vortex_array::dtype::extension::ExtVTable::deserialize_metadata(&self, me pub fn vortex_array::dtype::extension::ExtVTable::id(&self) -> vortex_array::dtype::extension::ExtId +pub fn vortex_array::dtype::extension::ExtVTable::is_refinement(&self) -> bool + pub fn vortex_array::dtype::extension::ExtVTable::least_supertype(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::dtype::extension::ExtVTable::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> @@ -8668,6 +8674,8 @@ pub fn vortex_array::extension::datetime::Date::deserialize_metadata(&self, meta pub fn 
vortex_array::extension::datetime::Date::id(&self) -> vortex_array::dtype::extension::ExtId +pub fn vortex_array::extension::datetime::Date::is_refinement(&self) -> bool + pub fn vortex_array::extension::datetime::Date::least_supertype(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::extension::datetime::Date::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> @@ -8692,6 +8700,8 @@ pub fn vortex_array::extension::datetime::Time::deserialize_metadata(&self, data pub fn vortex_array::extension::datetime::Time::id(&self) -> vortex_array::dtype::extension::ExtId +pub fn vortex_array::extension::datetime::Time::is_refinement(&self) -> bool + pub fn vortex_array::extension::datetime::Time::least_supertype(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::extension::datetime::Time::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> @@ -8716,6 +8726,8 @@ pub fn vortex_array::extension::datetime::Timestamp::deserialize_metadata(&self, pub fn vortex_array::extension::datetime::Timestamp::id(&self) -> vortex_array::dtype::extension::ExtId +pub fn vortex_array::extension::datetime::Timestamp::is_refinement(&self) -> bool + pub fn vortex_array::extension::datetime::Timestamp::least_supertype(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::extension::datetime::Timestamp::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> @@ -8740,6 +8752,8 @@ pub fn vortex_array::extension::uuid::Uuid::deserialize_metadata(&self, metadata pub fn vortex_array::extension::uuid::Uuid::id(&self) -> vortex_array::dtype::extension::ExtId +pub fn vortex_array::extension::uuid::Uuid::is_refinement(&self) -> bool + pub fn 
vortex_array::extension::uuid::Uuid::least_supertype(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::extension::uuid::Uuid::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> @@ -12602,6 +12616,8 @@ pub fn vortex_array::extension::datetime::Date::deserialize_metadata(&self, meta pub fn vortex_array::extension::datetime::Date::id(&self) -> vortex_array::dtype::extension::ExtId +pub fn vortex_array::extension::datetime::Date::is_refinement(&self) -> bool + pub fn vortex_array::extension::datetime::Date::least_supertype(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::extension::datetime::Date::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> @@ -12658,6 +12674,8 @@ pub fn vortex_array::extension::datetime::Time::deserialize_metadata(&self, data pub fn vortex_array::extension::datetime::Time::id(&self) -> vortex_array::dtype::extension::ExtId +pub fn vortex_array::extension::datetime::Time::is_refinement(&self) -> bool + pub fn vortex_array::extension::datetime::Time::least_supertype(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::extension::datetime::Time::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> @@ -12716,6 +12734,8 @@ pub fn vortex_array::extension::datetime::Timestamp::deserialize_metadata(&self, pub fn vortex_array::extension::datetime::Timestamp::id(&self) -> vortex_array::dtype::extension::ExtId +pub fn vortex_array::extension::datetime::Timestamp::is_refinement(&self) -> bool + pub fn vortex_array::extension::datetime::Timestamp::least_supertype(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option pub fn 
vortex_array::extension::datetime::Timestamp::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> @@ -12798,6 +12818,8 @@ pub fn vortex_array::extension::uuid::Uuid::deserialize_metadata(&self, metadata pub fn vortex_array::extension::uuid::Uuid::id(&self) -> vortex_array::dtype::extension::ExtId +pub fn vortex_array::extension::uuid::Uuid::is_refinement(&self) -> bool + pub fn vortex_array::extension::uuid::Uuid::least_supertype(ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::extension::uuid::Uuid::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> diff --git a/vortex-array/src/dtype/extension/erased.rs b/vortex-array/src/dtype/extension/erased.rs index c84b198516c..cc0e3749e63 100644 --- a/vortex-array/src/dtype/extension/erased.rs +++ b/vortex-array/src/dtype/extension/erased.rs @@ -42,6 +42,12 @@ impl ExtDTypeRef { self.0.id() } + /// Returns `true` when this extension type's vtable declares itself a refinement of its + /// storage dtype. See [`ExtVTable::is_refinement`] for the full semantics. + pub fn is_refinement(&self) -> bool { + self.0.is_refinement() + } + /// Returns the storage dtype of the extension type. pub fn storage_dtype(&self) -> &DType { self.0.storage_dtype() @@ -138,6 +144,20 @@ impl ExtDTypeRef { .vortex_expect("Failed to downcast ExtDTypeRef") } + /// Borrow the erased dtype as a concrete [`ExtDType`]. + /// + /// Unlike [`try_downcast()`], this does not consume the [`ExtDTypeRef`] or its backing + /// [`Arc`], so the returned reference inherits the lifetime of `&self`. Useful when the + /// enclosing borrow must be preserved (for example, when unpacking a native value that + /// borrows from the enclosing storage `DType`). + /// + /// Returns `None` if the concrete type is not `V`. 
+ /// + /// [`try_downcast()`]: Self::try_downcast + pub fn as_typed(&self) -> Option<&ExtDType> { + self.0.as_any().downcast_ref::>() + } + /// Downcast to the concrete [`ExtDType`]. /// /// Returns `Err(self)` if the downcast fails. diff --git a/vortex-array/src/dtype/extension/typed.rs b/vortex-array/src/dtype/extension/typed.rs index 5b176bbb34d..18d4ec6636f 100644 --- a/vortex-array/src/dtype/extension/typed.rs +++ b/vortex-array/src/dtype/extension/typed.rs @@ -137,6 +137,7 @@ impl ExtDType { pub(super) trait DynExtDType: 'static + Send + Sync + super::sealed::Sealed { fn as_any(&self) -> &dyn Any; fn id(&self) -> ExtId; + fn is_refinement(&self) -> bool; fn storage_dtype(&self) -> &DType; fn metadata_any(&self) -> &dyn Any; fn metadata_debug(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result; @@ -166,6 +167,10 @@ impl DynExtDType for ExtDType { self.id() } + fn is_refinement(&self) -> bool { + self.vtable.is_refinement() + } + fn storage_dtype(&self) -> &DType { self.storage_dtype() } diff --git a/vortex-array/src/dtype/extension/vtable.rs b/vortex-array/src/dtype/extension/vtable.rs index d4a00fbdec4..54e721e2101 100644 --- a/vortex-array/src/dtype/extension/vtable.rs +++ b/vortex-array/src/dtype/extension/vtable.rs @@ -28,6 +28,26 @@ pub trait ExtVTable: 'static + Sized + Send + Sync + Clone + Debug + Eq + Hash { /// Returns the ID for this extension type. fn id(&self) -> ExtId; + /// Returns `true` when this extension type logically refines its storage dtype. + /// + /// A refinement asserts that every value in this extension's logical domain is also a + /// value of the storage dtype's domain, with additional invariants enforced by this + /// vtable. The canonical examples are `NormalizedVector` (refines `Vector` by + /// requiring unit L2 norm) and a hypothetical `PositiveInt` (refines `Primitive(U64)` + /// by excluding zero). + /// + /// The default is `false`: plain extension types (e.g. 
`Uuid` over + /// `FixedSizeList`) encode a distinct logical type in their storage that is + /// not substitutable with it. + /// + /// This flag is consumed by `peel_refinements_and_resolve_dtype` (in + /// `arrays::scalar_fn::array`) to decide whether to transparently peel a refinement + /// input and retry a scalar fn on the source type. See that function for the full + /// semantics; the `TODO(connor)` there also tracks refinement-closure preservation. + fn is_refinement(&self) -> bool { + false + } + // Methods related to the extension `DType`. /// Serialize the metadata into a byte vector. diff --git a/vortex-array/src/extension/tests/divisible_int.rs b/vortex-array/src/extension/tests/divisible_int.rs index 4ab08f7b9f5..6bf821d1641 100644 --- a/vortex-array/src/extension/tests/divisible_int.rs +++ b/vortex-array/src/extension/tests/divisible_int.rs @@ -2,6 +2,11 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors //! A test extension type representing unsigned integers divisible by a given divisor. +//! +//! `DivisibleInt` is a refinement of `Primitive(U64)`: every valid value is a `u64`, with the +//! additional invariant that it is divisible by the metadata-provided [`Divisor`]. Its +//! `ExtVTable::is_refinement` returns `true` so that generic scalar-fn dispatch can peel it +//! to its storage dtype automatically. use std::fmt; @@ -26,7 +31,7 @@ impl fmt::Display for Divisor { } } -/// Extension type for unsigned integers that must be divisible by the metadata divisor. +/// Refinement type for unsigned integers that must be divisible by the metadata divisor. 
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] pub struct DivisibleInt; @@ -38,6 +43,33 @@ impl ExtVTable for DivisibleInt { ExtId::new("test.divisible_int") } + fn is_refinement(&self) -> bool { + true + } + + fn validate_dtype(ext_dtype: &ExtDType) -> VortexResult<()> { + match ext_dtype.storage_dtype() { + DType::Primitive(PType::U64, _) => Ok(()), + other => vortex_bail!("`DivisibleInt` requires `U64` storage, got {other}"), + } + } + + fn unpack_native<'a>( + ext_dtype: &'a ExtDType, + storage_value: &'a ScalarValue, + ) -> VortexResult> { + let ScalarValue::Primitive(pv) = storage_value else { + vortex_bail!("`DivisibleInt` expected a primitive scalar, got {storage_value:?}"); + }; + let n = pv.cast::()?; + let divisor = ext_dtype.metadata().0; + vortex_ensure!( + n.is_multiple_of(divisor), + "{n} is not divisible by {divisor}", + ); + Ok(n) + } + fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult> { Ok(metadata.0.to_le_bytes().to_vec()) } @@ -51,26 +83,6 @@ impl ExtVTable for DivisibleInt { vortex_ensure!(n > 0, "divisor must be greater than 0"); Ok(Divisor(n)) } - - fn validate_dtype(ext_dtype: &ExtDType) -> VortexResult<()> { - vortex_ensure!( - matches!(ext_dtype.storage_dtype(), DType::Primitive(PType::U64, _)), - "divisible int storage dtype must be u64" - ); - Ok(()) - } - - fn unpack_native<'a>( - ext_dtype: &'a ExtDType, - storage_value: &'a ScalarValue, - ) -> VortexResult> { - let value = storage_value.as_primitive().cast::()?; - let metadata = ext_dtype.metadata(); - if value % metadata.0 != 0 { - vortex_bail!("{} is not divisible by {}", value, metadata.0); - } - Ok(value) - } } #[cfg(test)] @@ -99,9 +111,8 @@ mod tests { #[test] fn rejects_zero_divisor() { - let vtable = DivisibleInt; let bytes = 0u64.to_le_bytes(); - assert!(vtable.deserialize_metadata(&bytes).is_err()); + assert!(DivisibleInt.deserialize_metadata(&bytes).is_err()); } #[test] @@ -127,4 +138,9 @@ mod tests { .is_ok() ); } + + #[test] + fn 
is_refinement_is_true() { + assert!(DivisibleInt.is_refinement()); + } } diff --git a/vortex-array/src/extension/tests/even_divisible_int.rs b/vortex-array/src/extension/tests/even_divisible_int.rs new file mode 100644 index 00000000000..bd78b42510b --- /dev/null +++ b/vortex-array/src/extension/tests/even_divisible_int.rs @@ -0,0 +1,177 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! A test extension type layering a refinement on top of another refinement. +//! +//! [`EvenDivisibleInt`] refines [`DivisibleInt`] with the additional requirement that the +//! value is even. Its storage `DType` is therefore `DType::Extension(DivisibleInt)`, and its +//! validation chain transitively inherits `DivisibleInt`'s divisibility check: when the +//! outer `ExtDType` is constructed, the inner `DivisibleInt` extension already ran its own +//! `validate_dtype`. + +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; +use vortex_error::vortex_err; + +use super::divisible_int::DivisibleInt; +use crate::dtype::DType; +use crate::dtype::extension::ExtDType; +use crate::dtype::extension::ExtId; +use crate::dtype::extension::ExtVTable; +use crate::extension::EmptyMetadata; +use crate::scalar::ScalarValue; + +/// Refinement of [`DivisibleInt`] requiring the stored value to additionally be even. 
+#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +pub struct EvenDivisibleInt; + +impl ExtVTable for EvenDivisibleInt { + type Metadata = EmptyMetadata; + type NativeValue<'a> = u64; + + fn id(&self) -> ExtId { + ExtId::new("test.even_divisible_int") + } + + fn is_refinement(&self) -> bool { + true + } + + fn validate_dtype(ext_dtype: &ExtDType) -> VortexResult<()> { + let DType::Extension(inner) = ext_dtype.storage_dtype() else { + vortex_bail!( + "`EvenDivisibleInt` requires extension storage, got {}", + ext_dtype.storage_dtype(), + ); + }; + vortex_ensure!( + inner.is::(), + "`EvenDivisibleInt` requires `DivisibleInt` storage, got {}", + inner.id(), + ); + Ok(()) + } + + fn unpack_native<'a>( + ext_dtype: &'a ExtDType, + storage_value: &'a ScalarValue, + ) -> VortexResult> { + // Compose with `DivisibleInt::unpack_native`: the inner refinement's divisibility + // check runs first; only values that pass reach the even-ness check here. + let DType::Extension(inner) = ext_dtype.storage_dtype() else { + vortex_bail!("unreachable: validate_dtype rejects non-extension storage"); + }; + let inner_typed = inner.as_typed::().ok_or_else(|| { + vortex_err!("unreachable: validate_dtype rejects non-`DivisibleInt` inner extension") + })?; + let n = DivisibleInt::unpack_native(inner_typed, storage_value)?; + vortex_ensure!(n.is_multiple_of(2), "{n} is not even"); + Ok(n) + } + + fn serialize_metadata(&self, _metadata: &Self::Metadata) -> VortexResult> { + Ok(Vec::new()) + } + + fn deserialize_metadata(&self, _bytes: &[u8]) -> VortexResult { + Ok(EmptyMetadata) + } +} + +#[cfg(test)] +mod tests { + use vortex_error::VortexResult; + + use super::super::divisible_int::DivisibleInt; + use super::super::divisible_int::Divisor; + use super::EvenDivisibleInt; + use crate::dtype::DType; + use crate::dtype::Nullability; + use crate::dtype::PType; + use crate::dtype::extension::ExtDType; + use crate::dtype::extension::ExtVTable; + use crate::extension::EmptyMetadata; + use 
crate::scalar::PValue; + use crate::scalar::ScalarValue; + + fn even_dtype(divisor: u64) -> VortexResult> { + let inner = ExtDType::::try_new( + Divisor(divisor), + DType::Primitive(PType::U64, Nullability::NonNullable), + )?; + ExtDType::::try_new(EmptyMetadata, DType::Extension(inner.erased())) + } + + #[test] + fn accepts_valid_ext_over_ext_storage() -> VortexResult<()> { + let _dtype = even_dtype(4)?; + Ok(()) + } + + #[test] + fn rejects_non_extension_storage() { + let built = ExtDType::::try_new( + EmptyMetadata, + DType::Primitive(PType::U64, Nullability::NonNullable), + ); + assert!(built.is_err(), "must reject non-extension storage"); + } + + #[test] + fn rejects_mismatched_inner_extension() -> VortexResult<()> { + use crate::extension::uuid::Uuid; + let uuid = ExtDType::::try_new( + Default::default(), + DType::FixedSizeList( + std::sync::Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), + 16, + Nullability::NonNullable, + ), + )?; + + let built = + ExtDType::::try_new(EmptyMetadata, DType::Extension(uuid.erased())); + assert!( + built.is_err(), + "must reject non-DivisibleInt inner extension" + ); + Ok(()) + } + + #[test] + fn unpack_accepts_even_divisible_value() -> VortexResult<()> { + let dtype = even_dtype(3)?; + // 12 is divisible by 3 and even, so both the inner DivisibleInt predicate and the + // outer EvenDivisibleInt predicate succeed. + let storage = ScalarValue::Primitive(PValue::U64(12)); + let value = EvenDivisibleInt::unpack_native(&dtype, &storage)?; + assert_eq!(value, 12); + Ok(()) + } + + #[test] + fn unpack_rejects_odd_divisible_value() -> VortexResult<()> { + let dtype = even_dtype(3)?; + // 9 is divisible by 3 (inner predicate succeeds) but odd (outer predicate fails). 
+ let storage = ScalarValue::Primitive(PValue::U64(9)); + assert!(EvenDivisibleInt::unpack_native(&dtype, &storage).is_err()); + Ok(()) + } + + #[test] + fn unpack_rejects_not_divisible_value() -> VortexResult<()> { + let dtype = even_dtype(3)?; + // 8 is even but not divisible by 3. The inner `DivisibleInt` predicate fires before + // we ever reach the outer even-ness check, proving that refinements compose via + // nested `unpack_native` calls. + let storage = ScalarValue::Primitive(PValue::U64(8)); + assert!(EvenDivisibleInt::unpack_native(&dtype, &storage).is_err()); + Ok(()) + } + + #[test] + fn is_refinement_is_true() { + assert!(EvenDivisibleInt.is_refinement()); + } +} diff --git a/vortex-array/src/extension/tests/mod.rs b/vortex-array/src/extension/tests/mod.rs index 31df677e61d..702e4c144bf 100644 --- a/vortex-array/src/extension/tests/mod.rs +++ b/vortex-array/src/extension/tests/mod.rs @@ -4,3 +4,4 @@ //! Test extension types for exercising the [`ExtVTable`] contract. mod divisible_int; +mod even_divisible_int; From 33571c2e8e46bfe7feb405e07a92ee91a19652c1 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 24 Apr 2026 10:08:34 -0400 Subject: [PATCH 2/3] feat(vortex-array): blanket refinement pushdown at ScalarFnArray construction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a scalar fn is given a refinement-typed input that its `return_dtype` rejects, the framework now transparently peels the refinement one level at a time until the fn accepts the shape. This lands in `ScalarFnArray::try_new` rather than as a reduce rule because `try_new` already calls `return_dtype` and aborts on error — no ScalarFnArray tree is ever constructed to reduce when a refinement input is rejected. Algorithm (see `peel_refinements_and_resolve_dtype` for details): 1. Compute `scalar_fn.return_dtype(arg_dtypes)` on the current children. 2. If it succeeds, done. 
This covers both non-refinement inputs and fns that explicitly accept the refinement (category B / C / D from the plan — specialization path). 3. If it errors, peel one level from every child whose dtype is an extension dtype with `is_refinement() == true`. Replace each with its storage array. 4. If no children were peeled, return the original error. 5. Otherwise, loop back to step 1 with the peeled children. Multi-level refinement chains (e.g. EvenDivisibleInt → DivisibleInt → U64) unwind one level per iteration. Implements category A (refinement-transparent scalar fns): when a generic fn doesn't know about a refinement, the refinement is lost and the fn operates on the source storage. Refinement-preserving semantics (categories C and D) are deferred; the TODO(connor) in `peel_refinements_and_resolve_dtype` documents the intended direction — an inverted-control hook on the refinement vtable itself, rather than per-fn specialization, which is blocked by the vortex-array → downstream crate dependency direction. Exposes `crate::extension::tests::{divisible_int, even_divisible_int}` as `pub(crate)` so the scalar-fn tests can reuse them. No user-facing public API changes beyond what the prior commit landed. Four unit tests: - peels_single_level_refinement_through_strict_add: `Binary(Add)` over `DivisibleInt(U64)` succeeds and returns `U64`. - peels_two_level_refinement_chain_through_strict_add: `EvenDivisibleInt(DivisibleInt(U64))` unwinds both layers. - does_not_peel_non_refinement_extension: `Uuid` is not peeled; the fn's original error surfaces. - does_not_peel_when_scalar_fn_accepts_refinement: `Binary(Eq)` accepts extension inputs directly, so children retain their refinement dtypes. 
Signed-off-by: Connor Tsui --- vortex-array/src/arrays/scalar_fn/array.rs | 342 +++++++++++++++++- .../src/arrays/scalar_fn/vtable/mod.rs | 6 +- vortex-array/src/extension/mod.rs | 2 +- vortex-array/src/extension/tests/mod.rs | 4 +- 4 files changed, 345 insertions(+), 9 deletions(-) diff --git a/vortex-array/src/arrays/scalar_fn/array.rs b/vortex-array/src/arrays/scalar_fn/array.rs index 4a82951ebdb..00e66bf4113 100644 --- a/vortex-array/src/arrays/scalar_fn/array.rs +++ b/vortex-array/src/arrays/scalar_fn/array.rs @@ -9,10 +9,16 @@ use vortex_error::VortexResult; use vortex_error::vortex_ensure; use crate::ArrayRef; +use crate::IntoArray; use crate::array::Array; use crate::array::ArrayParts; use crate::array::TypedArrayRef; +use crate::arrays::Constant; +use crate::arrays::ConstantArray; +use crate::arrays::Extension; use crate::arrays::ScalarFn; +use crate::arrays::extension::ExtensionArrayExt; +use crate::dtype::DType; use crate::scalar_fn::ScalarFnRef; // ScalarFnArray has a variable number of slots (one per child) @@ -84,13 +90,18 @@ impl> ScalarFnArrayExt for T {} impl Array { /// Create a new ScalarFnArray from a scalar function and its children. + /// + /// When a child has a refinement extension dtype + /// (`ExtVTable::is_refinement() == true`) and the scalar function rejects the original + /// input shape, refinement children are transparently peeled one level at a time until + /// the fn accepts the shape or no refinement children remain to peel. 
pub fn try_new( scalar_fn: ScalarFnRef, - children: Vec, + mut children: Vec, len: usize, ) -> VortexResult { - let arg_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect(); - let dtype = scalar_fn.return_dtype(&arg_dtypes)?; + let dtype = peel_refinements_and_resolve_dtype(&scalar_fn, &mut children)?; + let data = ScalarFnData::build(scalar_fn.clone(), children.clone(), len)?; let vtable = ScalarFn { id: scalar_fn.id() }; Ok(unsafe { @@ -101,3 +112,328 @@ impl Array { }) } } + +// TODO(connor): Refinement-preserving (e.g. `add(PositiveInt, PositiveInt) -> PositiveInt`) and +// refinement-changing (e.g. `negate(PositiveInt) -> NegativeInt`) semantics require an +// inverted-control hook like `ExtVTable::closure_result(&self, scalar_fn, arg_dtypes) -> +// Option` so the refinement can inspect the fn and preserve/rewrite itself. Until then, +// refinements are lost through any scalar fn whose `return_dtype` doesn't explicitly accept them +// (the peel loop strips the refinement and operates on storage instead). +/// Resolves the scalar function's return dtype against `children`, transparently peeling +/// refinement-typed children when the fn doesn't accept the original shape. +/// +/// Why this exists: +/// +/// Refinement extensions are logically stricter views of a storage dtype. For example, a +/// `DivisibleInt` is still represented as `u64`, and many scalar fns are semantically valid on +/// the storage values even though their `return_dtype` implementation only understands `u64`. +/// That means `add(DivisibleInt, DivisibleInt)` should be allowed to fall back to +/// `add(u64, u64)` when the fn does not explicitly accept the refinement. +/// +/// Why this happens during construction rather than as a normal optimizer rewrite: +/// +/// Public array construction goes through `ScalarFnFactoryExt::try_new_array`, which must call +/// `return_dtype` before a `ScalarFnArray` exists. A post-construction reduce rule would fire too +/// late. 
So refinement peeling has to be a construction-time fallback for dtype resolution. +/// +/// Why peeling is intentionally narrow: +/// +/// Only arrays that are *representation wrappers* can be peeled safely. Today that means: +/// - `ExtensionArray`, which can peel to `storage_array()` +/// - `ConstantArray` holding an extension scalar, which can peel to a constant storage scalar +/// +/// Other encodings may preserve a refinement dtype without exposing storage as child slot 0. For +/// example, `mask(refined, mask)` still has a refinement dtype, but child 0 is the masked input +/// expression, not "the refinement storage". Peeling such arrays structurally via `nth_child(0)` +/// silently drops semantics and can produce wrong results. In those cases we must reject +/// construction until we have a semantics-preserving rewrite. +/// +/// The algorithm is a fixpoint: +/// +/// 1. Ask the scalar fn for its return dtype on the current children's dtypes. +/// 2. If it succeeds, return the result. This covers both the all-non-refinement case and the case +/// where the fn explicitly accepts a refinement type as a child. +/// 3. If it errors, try peeling one level from every child whose dtype is an extension dtype with +/// [`is_refinement`] set. Extension arrays peel to their storage arrays; constant extension +/// literals peel to constant storage scalars. +/// 4. If no children were peeled, return the original error. +/// 5. Otherwise, loop back to step 1 with the peeled children. 
+/// +/// [`is_refinement`]: crate::dtype::extension::ExtVTable::is_refinement +pub(crate) fn peel_refinements_and_resolve_dtype( + scalar_fn: &ScalarFnRef, + children: &mut [ArrayRef], +) -> VortexResult { + loop { + let children_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect(); + + match scalar_fn.return_dtype(&children_dtypes) { + Ok(dtype) => return Ok(dtype), + Err(err) => { + let any_peeled = peel_refinement_layers(children); + if !any_peeled { + return Err(err); + } + } + } + } +} + +// TODO(connor): Is it correct/ok to peel away refinement on all children at once? Do we need +// small-step semantics here? +/// Peels one layer of refinement extensions from all `children`. +/// +/// Returns a flag indicating whether any child was actually peeled. +fn peel_refinement_layers(children: &mut [ArrayRef]) -> bool { + let mut any_peeled = false; + + for child in children.iter_mut() { + if let Some(peeled) = peel_refinement_child(child) { + *child = peeled; + any_peeled = true; + } + } + + any_peeled +} + +/// Peel exactly one refinement wrapper when the array is a real representation wrapper. +/// +/// This must not inspect arbitrary structural children. Only `ExtensionArray` and refinement +/// literals in `ConstantArray` are safe to unwrap here. 
+fn peel_refinement_child(child: &ArrayRef) -> Option { + let DType::Extension(ext_dtype) = child.dtype() else { + return None; + }; + if !ext_dtype.is_refinement() { + return None; + } + + if let Some(ext_array) = child.as_opt::() { + return Some(ext_array.storage_array().clone()); + } + + if let Some(const_array) = child.as_opt::() { + let constant = const_array.scalar(); + let ext_scalar = constant.as_extension_opt()?; + return Some(ConstantArray::new(ext_scalar.to_storage_scalar(), child.len()).into_array()); + } + + None +} + +#[cfg(test)] +mod tests { + use vortex_buffer::Buffer; + use vortex_error::VortexResult; + + use super::*; + use crate::IntoArray; + use crate::arrays::ConstantArray; + use crate::arrays::ExtensionArray; + use crate::arrays::FixedSizeListArray; + use crate::arrays::PrimitiveArray; + use crate::arrays::scalar_fn::ScalarFnArrayExt; + use crate::builtins::ArrayBuiltins; + use crate::dtype::DType; + use crate::dtype::Nullability; + use crate::dtype::PType; + use crate::dtype::extension::ExtDType; + use crate::extension::EmptyMetadata; + use crate::extension::tests::divisible_int::DivisibleInt; + use crate::extension::tests::divisible_int::Divisor; + use crate::extension::tests::even_divisible_int::EvenDivisibleInt; + use crate::extension::uuid::Uuid; + use crate::scalar::Scalar; + use crate::scalar_fn::EmptyOptions; + use crate::scalar_fn::ScalarFnVTableExt; + use crate::scalar_fn::fns::binary::Binary; + use crate::scalar_fn::fns::mask::Mask; + use crate::scalar_fn::fns::operators::Operator; + use crate::validity::Validity; + + fn divisible_int_dtype(divisor: u64) -> VortexResult { + ExtDType::::try_new( + Divisor(divisor), + DType::Primitive(PType::U64, Nullability::NonNullable), + ) + .map(|dtype| dtype.erased()) + } + + fn divisible_int_array(divisor: u64, values: Vec) -> VortexResult { + let ext_dtype = divisible_int_dtype(divisor)?; + let storage = + PrimitiveArray::new::(Buffer::::copy_from(&values), Validity::NonNullable) + 
.into_array(); + Ok(ExtensionArray::try_new(ext_dtype, storage)?.into_array()) + } + + fn even_divisible_int_array(divisor: u64, values: Vec) -> VortexResult { + let inner_dtype = ExtDType::::try_new( + Divisor(divisor), + DType::Primitive(PType::U64, Nullability::NonNullable), + )? + .erased(); + let outer_dtype = ExtDType::::try_new( + EmptyMetadata, + DType::Extension(inner_dtype.clone()), + )? + .erased(); + let primitive = + PrimitiveArray::new::(Buffer::::copy_from(&values), Validity::NonNullable) + .into_array(); + let inner_ext = ExtensionArray::try_new(inner_dtype, primitive)?.into_array(); + Ok(ExtensionArray::try_new(outer_dtype, inner_ext)?.into_array()) + } + + fn uuid_array(row_count: usize) -> VortexResult { + let fsl_element = DType::Primitive(PType::U8, Nullability::NonNullable); + let fsl = DType::FixedSizeList( + std::sync::Arc::new(fsl_element), + 16, + Nullability::NonNullable, + ); + let uuid_dtype = ExtDType::::try_new(Default::default(), fsl)?.erased(); + let bytes: Buffer = Buffer::copy_from(vec![0u8; row_count * 16]); + let primitive = PrimitiveArray::new::(bytes, Validity::NonNullable).into_array(); + let storage = FixedSizeListArray::try_new(primitive, 16, Validity::NonNullable, row_count)? + .into_array(); + Ok(ExtensionArray::try_new(uuid_dtype, storage)?.into_array()) + } + + /// `Binary(Add)` applied to two `DivisibleInt` children must peel one level, producing + /// a ScalarFnArray whose dtype is `Primitive(U64)`. The refinement is lost by design + /// (category A). 
+ #[test] + fn peels_single_level_refinement_through_strict_add() -> VortexResult<()> { + let lhs = divisible_int_array(3, vec![0u64, 3, 6])?; + let rhs = divisible_int_array(3, vec![0u64, 3, 6])?; + + let sfn = Binary.bind(Operator::Add); + let arr = Array::::try_new(sfn, vec![lhs, rhs], 3)?; + + assert_eq!( + arr.dtype(), + &DType::Primitive(PType::U64, Nullability::NonNullable), + ); + Ok(()) + } + + /// `Binary(Add)` applied to two `EvenDivisibleInt` children must peel through both + /// refinement layers (EvenDivisibleInt → DivisibleInt → U64) via the fixpoint loop. + #[test] + fn peels_two_level_refinement_chain_through_strict_add() -> VortexResult<()> { + let lhs = even_divisible_int_array(3, vec![0u64, 6, 12])?; + let rhs = even_divisible_int_array(3, vec![0u64, 6, 12])?; + + let sfn = Binary.bind(Operator::Add); + let arr = Array::::try_new(sfn, vec![lhs, rhs], 3)?; + + assert_eq!( + arr.dtype(), + &DType::Primitive(PType::U64, Nullability::NonNullable), + ); + Ok(()) + } + + #[test] + fn peels_refinement_constant_literal_through_strict_add() -> VortexResult<()> { + let ext_dtype = divisible_int_dtype(3)?; + let lhs = divisible_int_array(3, vec![0u64, 3, 6])?; + let rhs = ConstantArray::new( + Scalar::extension_ref(ext_dtype, Scalar::from(3u64)), + lhs.len(), + ) + .into_array(); + + let arr = Array::::try_new(Binary.bind(Operator::Add), vec![lhs, rhs], 3)?; + + assert_eq!( + arr.dtype(), + &DType::Primitive(PType::U64, Nullability::NonNullable), + ); + Ok(()) + } + + #[test] + fn array_builtins_binary_reuses_refinement_peeling() -> VortexResult<()> { + let lhs = divisible_int_array(3, vec![0u64, 3, 6])?; + let rhs = divisible_int_array(3, vec![0u64, 3, 6])?; + + let arr = lhs.binary(rhs, Operator::Add)?; + + assert_eq!( + arr.dtype(), + &DType::Primitive(PType::U64, Nullability::NonNullable), + ); + Ok(()) + } + + /// Regression for an easy-to-make `nth_child(0)` peel bug. 
+ /// + /// `mask(refined, false)` preserves the refinement dtype, but it is not an `ExtensionArray`. + /// Peeling it via child 0 would discard the `mask(...)` node entirely, after which the retry + /// loop would eventually type-check `add(lhs_storage, rhs_storage)` and silently change the + /// program. The correct behavior today is to reject the construction. + #[test] + fn does_not_drop_masked_refinement_via_child_zero_peel() -> VortexResult<()> { + let lhs = divisible_int_array(3, vec![0u64, 3, 6])?; + let rhs = divisible_int_array(3, vec![0u64, 3, 6])?; + let mask = ConstantArray::new(false, lhs.len()).into_array(); + let masked_lhs = + Array::::try_new(Mask.bind(EmptyOptions), vec![lhs, mask], 3)?.into_array(); + + let result = + Array::::try_new(Binary.bind(Operator::Add), vec![masked_lhs, rhs], 3); + + assert!( + result.is_err(), + "non-Extension refinement producers must not be peeled via structural slot access" + ); + Ok(()) + } + + /// `Uuid` is a non-refinement extension (`is_refinement() == false`). When fed to a + /// strict primitive-typed scalar fn, peeling must NOT happen — the fn's original error + /// is surfaced as-is. + #[test] + fn does_not_peel_non_refinement_extension() -> VortexResult<()> { + let lhs = uuid_array(2)?; + let rhs = uuid_array(2)?; + + let sfn = Binary.bind(Operator::Add); + let result = Array::::try_new(sfn, vec![lhs, rhs], 2); + + assert!( + result.is_err(), + "Uuid is not a refinement; peel must not fire" + ); + Ok(()) + } + + /// When the scalar fn's `return_dtype` already accepts the refinement (category B / C / + /// D — specialization path), the peel loop must short-circuit without touching + /// children. + #[test] + fn does_not_peel_when_scalar_fn_accepts_refinement() -> VortexResult<()> { + // `Binary(Eq)` is a comparison, and its return_dtype accepts any pair of matching + // extension dtypes (see `binary::mod::return_dtype` — comparisons allow extensions + // as long as the two sides share the same dtype). 
So `Eq(DivisibleInt, + // DivisibleInt)` succeeds at return_dtype time with no peel. + let lhs = divisible_int_array(3, vec![0u64, 3, 6])?; + let rhs = divisible_int_array(3, vec![0u64, 3, 6])?; + + let sfn = Binary.bind(Operator::Eq); + let arr = Array::::try_new(sfn, vec![lhs, rhs], 3)?; + + // Comparison returns Bool; children retain their refinement dtypes. + assert_eq!(arr.dtype(), &DType::Bool(Nullability::NonNullable)); + let child0 = arr.child_at(0); + assert!( + matches!(child0.dtype(), DType::Extension(ext) if ext.is::()), + "child 0 retained its DivisibleInt refinement (got {})", + child0.dtype(), + ); + Ok(()) + } +} diff --git a/vortex-array/src/arrays/scalar_fn/vtable/mod.rs b/vortex-array/src/arrays/scalar_fn/vtable/mod.rs index 34c472eac07..5f9f641af97 100644 --- a/vortex-array/src/arrays/scalar_fn/vtable/mod.rs +++ b/vortex-array/src/arrays/scalar_fn/vtable/mod.rs @@ -28,6 +28,7 @@ use crate::array::ArrayView; use crate::array::VTable; use crate::arrays::scalar_fn::array::ScalarFnArrayExt; use crate::arrays::scalar_fn::array::ScalarFnData; +use crate::arrays::scalar_fn::array::peel_refinements_and_resolve_dtype; use crate::arrays::scalar_fn::rules::PARENT_RULES; use crate::arrays::scalar_fn::rules::RULES; use crate::buffer::BufferHandle; @@ -175,14 +176,13 @@ pub trait ScalarFnFactoryExt: scalar_fn::ScalarFnVTable { ) -> VortexResult { let scalar_fn = scalar_fn::TypedScalarFnInstance::new(self.clone(), options).erased(); - let children = children.into(); + let mut children = children.into(); vortex_ensure!( children.iter().all(|c| c.len() == len), "All child arrays must have the same length as the scalar function array" ); - let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec(); - let dtype = scalar_fn.return_dtype(&child_dtypes)?; + let dtype = peel_refinements_and_resolve_dtype(&scalar_fn, &mut children)?; let data = ScalarFnData { scalar_fn: scalar_fn.clone(), diff --git a/vortex-array/src/extension/mod.rs 
b/vortex-array/src/extension/mod.rs index 9f81e7fb310..077af4a8337 100644 --- a/vortex-array/src/extension/mod.rs +++ b/vortex-array/src/extension/mod.rs @@ -9,7 +9,7 @@ pub mod datetime; pub mod uuid; #[cfg(test)] -mod tests; +pub(crate) mod tests; /// An empty metadata struct for extension dtypes that do not require any metadata. #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/vortex-array/src/extension/tests/mod.rs b/vortex-array/src/extension/tests/mod.rs index 702e4c144bf..1eb7f8d14d7 100644 --- a/vortex-array/src/extension/tests/mod.rs +++ b/vortex-array/src/extension/tests/mod.rs @@ -3,5 +3,5 @@ //! Test extension types for exercising the [`ExtVTable`] contract. -mod divisible_int; -mod even_divisible_int; +pub(crate) mod divisible_int; +pub(crate) mod even_divisible_int; From fe47467d419bf8f608ed1ea53af3f239286cdfbd Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 24 Apr 2026 16:15:59 -0400 Subject: [PATCH 3/3] fix expression pushdown Signed-off-by: Connor Tsui --- vortex-array/src/arrays/scalar_fn/array.rs | 40 ++++------ vortex-array/src/expr/expression.rs | 85 ++++++++++++++++++++- vortex-array/src/expr/mod.rs | 34 +++++++++ vortex-array/src/expr/optimize.rs | 73 +++++++++++++++--- vortex-array/src/scalar_fn/mod.rs | 3 + vortex-array/src/scalar_fn/refinement.rs | 39 ++++++++++ vortex-layout/src/scan/scan_builder.rs | 87 +++++++++++++++++++++- 7 files changed, 320 insertions(+), 41 deletions(-) create mode 100644 vortex-array/src/scalar_fn/refinement.rs diff --git a/vortex-array/src/arrays/scalar_fn/array.rs b/vortex-array/src/arrays/scalar_fn/array.rs index 00e66bf4113..34258a78ab0 100644 --- a/vortex-array/src/arrays/scalar_fn/array.rs +++ b/vortex-array/src/arrays/scalar_fn/array.rs @@ -19,7 +19,9 @@ use crate::arrays::Extension; use crate::arrays::ScalarFn; use crate::arrays::extension::ExtensionArrayExt; use crate::dtype::DType; +use crate::scalar_fn::RefinementFallbackArg; use crate::scalar_fn::ScalarFnRef; +use 
crate::scalar_fn::resolve_return_dtype_with_refinement_fallback; // ScalarFnArray has a variable number of slots (one per child) @@ -164,37 +166,21 @@ pub(crate) fn peel_refinements_and_resolve_dtype( scalar_fn: &ScalarFnRef, children: &mut [ArrayRef], ) -> VortexResult { - loop { - let children_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect(); - - match scalar_fn.return_dtype(&children_dtypes) { - Ok(dtype) => return Ok(dtype), - Err(err) => { - let any_peeled = peel_refinement_layers(children); - if !any_peeled { - return Err(err); - } - } - } - } + resolve_return_dtype_with_refinement_fallback(scalar_fn, children) } -// TODO(connor): Is it correct/ok to peel away refinement on all children at once? Do we need -// small-step semantics here? -/// Peels one layer of refinement extensions from all `children`. -/// -/// Returns a flag indicating whether any child was actually peeled. -fn peel_refinement_layers(children: &mut [ArrayRef]) -> bool { - let mut any_peeled = false; - - for child in children.iter_mut() { - if let Some(peeled) = peel_refinement_child(child) { - *child = peeled; - any_peeled = true; - } +impl RefinementFallbackArg for ArrayRef { + fn current_dtype(&self) -> &DType { + self.dtype() } - any_peeled + fn peel_one_refinement_layer(&mut self) -> bool { + let Some(peeled) = peel_refinement_child(self) else { + return false; + }; + *self = peeled; + true + } } /// Peel exactly one refinement wrapper when the array is a real representation wrapper. 
diff --git a/vortex-array/src/expr/expression.rs b/vortex-array/src/expr/expression.rs index 8dae54ffa36..869c8a9a499 100644 --- a/vortex-array/src/expr/expression.rs +++ b/vortex-array/src/expr/expression.rs @@ -17,8 +17,11 @@ use crate::dtype::DType; use crate::expr::StatsCatalog; use crate::expr::display::DisplayTreeExpr; use crate::expr::stats::Stat; +use crate::scalar_fn::RefinementFallbackArg; use crate::scalar_fn::ScalarFnRef; +use crate::scalar_fn::fns::literal::Literal; use crate::scalar_fn::fns::root::Root; +use crate::scalar_fn::resolve_return_dtype_with_refinement_fallback; /// A node in a Vortex expression tree. /// @@ -94,16 +97,27 @@ impl Expression { /// Computes the return dtype of this expression given the input dtype. pub fn return_dtype(&self, scope: &DType) -> VortexResult { + Ok(self.return_dtype_info(scope)?.dtype) + } + + pub(crate) fn return_dtype_info(&self, scope: &DType) -> VortexResult { if self.is::() { - return Ok(scope.clone()); + return Ok(ExprReturnInfo::storage_chain_repr_wrapper(scope.clone())); + } + + if let Some(literal) = self.as_opt::() { + return Ok(ExprReturnInfo::storage_chain_repr_wrapper( + literal.dtype().clone(), + )); } - let dtypes: Vec<_> = self + let mut children: Vec<_> = self .children .iter() - .map(|c| c.return_dtype(scope)) + .map(|c| c.return_dtype_info(scope)) .try_collect()?; - self.scalar_fn.return_dtype(&dtypes) + + expr_return_dtype_info(&self.scalar_fn, &mut children) } /// Returns a new expression representing the validity mask output of this expression. 
@@ -243,3 +257,66 @@ impl Drop for Expression { } } } + +#[derive(Clone, Debug)] +pub(crate) struct ExprReturnInfo { + pub(crate) dtype: DType, + peel_strategy: ExprPeelStrategy, +} + +#[derive(Clone, Copy, Debug, Default)] +enum ExprPeelStrategy { + #[default] + None, + StorageChainRepresentationWrapper, +} + +impl ExprReturnInfo { + pub(crate) fn opaque(dtype: DType) -> Self { + Self { + dtype, + peel_strategy: ExprPeelStrategy::None, + } + } + + pub(crate) fn storage_chain_repr_wrapper(dtype: DType) -> Self { + Self { + dtype, + peel_strategy: ExprPeelStrategy::StorageChainRepresentationWrapper, + } + } +} + +impl RefinementFallbackArg for ExprReturnInfo { + fn current_dtype(&self) -> &DType { + &self.dtype + } + + fn peel_one_refinement_layer(&mut self) -> bool { + if !matches!( + self.peel_strategy, + ExprPeelStrategy::StorageChainRepresentationWrapper + ) { + return false; + } + + let DType::Extension(ext_dtype) = &self.dtype else { + return false; + }; + if !ext_dtype.is_refinement() { + return false; + } + + self.dtype = ext_dtype.storage_dtype().clone(); + true + } +} + +pub(crate) fn expr_return_dtype_info( + scalar_fn: &ScalarFnRef, + children: &mut [ExprReturnInfo], +) -> VortexResult { + Ok(ExprReturnInfo::opaque( + resolve_return_dtype_with_refinement_fallback(scalar_fn, children)?, + )) +} diff --git a/vortex-array/src/expr/mod.rs b/vortex-array/src/expr/mod.rs index 9d973a01cae..0513e260d81 100644 --- a/vortex-array/src/expr/mod.rs +++ b/vortex-array/src/expr/mod.rs @@ -127,7 +127,9 @@ mod tests { use crate::dtype::Nullability; use crate::dtype::PType; use crate::dtype::StructFields; + use crate::dtype::extension::ExtDType; use crate::expr::and; + use crate::expr::checked_add; use crate::expr::col; use crate::expr::eq; use crate::expr::get_item; @@ -142,6 +144,8 @@ mod tests { use crate::expr::root; use crate::expr::select; use crate::expr::select_exclude; + use crate::extension::tests::divisible_int::DivisibleInt; + use 
crate::extension::tests::divisible_int::Divisor; use crate::scalar::Scalar; #[test] @@ -260,4 +264,34 @@ mod tests { "{dog: 32u32, cat: \"rufus\"}" ); } + + fn divisible_int_dtype(divisor: u64) -> DType { + DType::Extension( + ExtDType::::try_new( + Divisor(divisor), + DType::Primitive(PType::U64, Nullability::NonNullable), + ) + .unwrap() + .erased(), + ) + } + + #[test] + fn return_dtype_peels_refinement_root_and_literal() { + let scope = divisible_int_dtype(3); + + assert_eq!( + checked_add(root(), root()).return_dtype(&scope).unwrap(), + DType::Primitive(PType::U64, Nullability::NonNullable), + ); + + let ext_dtype = scope.as_extension().clone(); + let refined_literal = Scalar::extension_ref(ext_dtype, Scalar::from(3u64)); + assert_eq!( + checked_add(root(), lit(refined_literal)) + .return_dtype(&scope) + .unwrap(), + DType::Primitive(PType::U64, Nullability::NonNullable), + ); + } } diff --git a/vortex-array/src/expr/optimize.rs b/vortex-array/src/expr/optimize.rs index 55b7e6e8df6..470a4689dd2 100644 --- a/vortex-array/src/expr/optimize.rs +++ b/vortex-array/src/expr/optimize.rs @@ -3,7 +3,6 @@ use std::any::Any; use std::cell::RefCell; -use std::ops::Deref; use std::sync::Arc; use itertools::Itertools; @@ -13,12 +12,15 @@ use vortex_utils::aliases::hash_map::HashMap; use crate::dtype::DType; use crate::expr::Expression; +use crate::expr::expression::ExprReturnInfo; +use crate::expr::expression::expr_return_dtype_info; use crate::expr::transform::match_between::find_between; use crate::scalar_fn::ReduceCtx; use crate::scalar_fn::ReduceNode; use crate::scalar_fn::ReduceNodeRef; use crate::scalar_fn::ScalarFnRef; use crate::scalar_fn::SimplifyCtx; +use crate::scalar_fn::fns::literal::Literal; use crate::scalar_fn::fns::root::Root; impl Expression { @@ -206,14 +208,21 @@ impl Expression { struct SimplifyCache<'a> { scope: &'a DType, - dtype_cache: RefCell>, + dtype_cache: RefCell>, } -impl SimplifyCtx for SimplifyCache<'_> { - fn return_dtype(&self, expr: 
&Expression) -> VortexResult { - // If the expression is "root", return the scope dtype +impl SimplifyCache<'_> { + fn return_dtype_info(&self, expr: &Expression) -> VortexResult { if expr.is::() { - return Ok(self.scope.clone()); + return Ok(ExprReturnInfo::storage_chain_repr_wrapper( + self.scope.clone(), + )); + } + + if let Some(literal) = expr.as_opt::() { + return Ok(ExprReturnInfo::storage_chain_repr_wrapper( + literal.dtype().clone(), + )); } if let Some(dtype) = self.dtype_cache.borrow().get(expr) { @@ -221,12 +230,12 @@ impl SimplifyCtx for SimplifyCache<'_> { } // Otherwise, compute dtype from children - let input_dtypes: Vec<_> = expr + let mut input_dtypes: Vec<_> = expr .children() .iter() - .map(|c| self.return_dtype(c)) + .map(|c| self.return_dtype_info(c)) .try_collect()?; - let dtype = expr.deref().return_dtype(&input_dtypes)?; + let dtype = expr_return_dtype_info(expr.scalar_fn(), &mut input_dtypes)?; self.dtype_cache .borrow_mut() .insert(expr.clone(), dtype.clone()); @@ -235,6 +244,12 @@ impl SimplifyCtx for SimplifyCache<'_> { } } +impl SimplifyCtx for SimplifyCache<'_> { + fn return_dtype(&self, expr: &Expression) -> VortexResult { + Ok(self.return_dtype_info(expr)?.dtype) + } +} + struct ExpressionReduceNode { expression: Expression, scope: DType, @@ -294,3 +309,43 @@ impl ReduceCtx for ExpressionReduceCtx { })) } } + +#[cfg(test)] +mod tests { + use crate::dtype::DType; + use crate::dtype::Nullability; + use crate::dtype::PType; + use crate::dtype::extension::ExtDType; + use crate::expr::checked_add; + use crate::expr::fill_null; + use crate::expr::lit; + use crate::expr::root; + use crate::extension::tests::divisible_int::DivisibleInt; + use crate::extension::tests::divisible_int::Divisor; + + fn divisible_int_dtype(divisor: u64) -> DType { + DType::Extension( + ExtDType::::try_new( + Divisor(divisor), + DType::Primitive(PType::U64, Nullability::NonNullable), + ) + .unwrap() + .erased(), + ) + } + + #[test] + fn 
optimize_recursive_uses_refinement_fallback_for_typed_simplify() { + let scope = divisible_int_dtype(3); + let add = checked_add(root(), root()); + let expr = fill_null(add.clone(), lit(0u64)); + + let optimized = expr.optimize_recursive(&scope).unwrap(); + + assert_eq!(optimized, add); + assert_eq!( + optimized.return_dtype(&scope).unwrap(), + DType::Primitive(PType::U64, Nullability::NonNullable), + ); + } +} diff --git a/vortex-array/src/scalar_fn/mod.rs b/vortex-array/src/scalar_fn/mod.rs index 6f8ba1b6544..305add2ea6d 100644 --- a/vortex-array/src/scalar_fn/mod.rs +++ b/vortex-array/src/scalar_fn/mod.rs @@ -24,6 +24,9 @@ pub use typed::*; mod erased; pub use erased::*; +mod refinement; +pub(crate) use refinement::*; + mod options; pub use options::*; diff --git a/vortex-array/src/scalar_fn/refinement.rs b/vortex-array/src/scalar_fn/refinement.rs new file mode 100644 index 00000000000..f7ca9d33dfd --- /dev/null +++ b/vortex-array/src/scalar_fn/refinement.rs @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; + +use crate::dtype::DType; +use crate::scalar_fn::ScalarFnRef; + +/// An argument that may be able to peel one refinement layer and retry scalar-fn dtype +/// resolution on its storage dtype. +pub(crate) trait RefinementFallbackArg { + fn current_dtype(&self) -> &DType; + + fn peel_one_refinement_layer(&mut self) -> bool; +} + +/// Resolve a scalar function's return dtype, retrying after peeling one refinement layer from all +/// currently peelable children when the original argument dtypes are rejected. 
+pub(crate) fn resolve_return_dtype_with_refinement_fallback( + scalar_fn: &ScalarFnRef, + args: &mut [A], +) -> VortexResult { + loop { + let arg_dtypes: Vec<_> = args.iter().map(|arg| arg.current_dtype().clone()).collect(); + + match scalar_fn.return_dtype(&arg_dtypes) { + Ok(dtype) => return Ok(dtype), + Err(err) => { + let mut any_peeled = false; + for arg in args.iter_mut() { + any_peeled |= arg.peel_one_refinement_layer(); + } + if !any_peeled { + return Err(err); + } + } + } + } +} diff --git a/vortex-layout/src/scan/scan_builder.rs b/vortex-layout/src/scan/scan_builder.rs index bdf5d0bfb11..c1d3ea71bfd 100644 --- a/vortex-layout/src/scan/scan_builder.rs +++ b/vortex-layout/src/scan/scan_builder.rs @@ -472,7 +472,16 @@ mod test { use vortex_array::dtype::FieldMask; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; + use vortex_array::dtype::extension::ExtDType; + use vortex_array::dtype::extension::ExtId; + use vortex_array::dtype::extension::ExtVTable; use vortex_array::expr::Expression; + use vortex_array::expr::checked_add; + use vortex_array::expr::fill_null; + use vortex_array::expr::lit; + use vortex_array::expr::root; + use vortex_array::extension::EmptyMetadata; + use vortex_array::scalar::ScalarValue; use vortex_error::VortexResult; use vortex_error::vortex_err; use vortex_io::runtime::BlockingRuntime; @@ -483,6 +492,59 @@ mod test { use crate::ArrayFuture; use crate::LayoutReader; + #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] + struct TestRefinement; + + impl ExtVTable for TestRefinement { + type Metadata = EmptyMetadata; + type NativeValue<'a> = u64; + + fn id(&self) -> ExtId { + ExtId::new("test.scan_builder_refinement") + } + + fn is_refinement(&self) -> bool { + true + } + + fn serialize_metadata(&self, _metadata: &Self::Metadata) -> VortexResult> { + Ok(vec![]) + } + + fn deserialize_metadata(&self, _metadata: &[u8]) -> VortexResult { + Ok(EmptyMetadata) + } + + fn validate_dtype(ext_dtype: &ExtDType) -> 
VortexResult<()> { + assert_eq!( + ext_dtype.storage_dtype(), + &DType::Primitive(PType::U64, Nullability::NonNullable) + ); + Ok(()) + } + + fn unpack_native<'a>( + _ext_dtype: &'a ExtDType, + storage_value: &'a ScalarValue, + ) -> VortexResult> { + let ScalarValue::Primitive(value) = storage_value else { + unreachable!("storage dtype is validated to primitive u64"); + }; + value.cast::() + } + } + + fn test_refinement_dtype() -> DType { + DType::Extension( + ExtDType::::try_new( + EmptyMetadata, + DType::Primitive(PType::U64, Nullability::NonNullable), + ) + .unwrap() + .erased(), + ) + } + #[derive(Debug)] struct CountingLayoutReader { name: Arc, @@ -493,9 +555,16 @@ mod test { impl CountingLayoutReader { fn new(register_splits_calls: Arc) -> Self { + Self::new_with_dtype( + DType::Primitive(PType::I32, Nullability::NonNullable), + register_splits_calls, + ) + } + + fn new_with_dtype(dtype: DType, register_splits_calls: Arc) -> Self { Self { name: Arc::from("counting"), - dtype: DType::Primitive(PType::I32, Nullability::NonNullable), + dtype, row_count: 1, register_splits_calls, } @@ -572,6 +641,22 @@ mod test { assert_eq!(calls.load(Ordering::Relaxed), 0); } + #[test] + fn prepare_accepts_refinement_projection_in_planning_paths() -> VortexResult<()> { + let calls = Arc::new(AtomicUsize::new(0)); + let reader = Arc::new(CountingLayoutReader::new_with_dtype( + test_refinement_dtype(), + Arc::clone(&calls), + )); + let session = crate::scan::test::SCAN_SESSION.clone(); + + ScanBuilder::new(session, reader) + .with_projection(fill_null(checked_add(root(), root()), lit(0u64))) + .prepare()?; + + Ok(()) + } + #[derive(Debug)] struct SplittingLayoutReader { name: Arc,