From 9b2057cbba1d4d20cf39d2bd229185f42ef9f9e8 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Sun, 17 May 2026 17:40:53 +0100 Subject: [PATCH 1/2] Thread scope dtype through stats rewrites Signed-off-by: Nicholas Gates --- vortex-array/public-api.lock | 4 +-- vortex-array/src/expr/expression.rs | 20 ++++++++--- vortex-array/src/stats/rewrite.rs | 52 +++++++++++++++++++++++++---- 3 files changed, 64 insertions(+), 12 deletions(-) diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 84014af52df..3920b90aa5d 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -13514,13 +13514,13 @@ pub fn vortex_array::expr::Expression::children(&self) -> &alloc::sync::Arc impl core::fmt::Display -pub fn vortex_array::expr::Expression::falsify(&self, &vortex_session::VortexSession) -> vortex_error::VortexResult> +pub fn vortex_array::expr::Expression::falsify(&self, &vortex_array::dtype::DType, &vortex_session::VortexSession) -> vortex_error::VortexResult> pub fn vortex_array::expr::Expression::fmt_sql(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result pub fn vortex_array::expr::Expression::return_dtype(&self, &vortex_array::dtype::DType) -> vortex_error::VortexResult -pub fn vortex_array::expr::Expression::satisfy(&self, &vortex_session::VortexSession) -> vortex_error::VortexResult> +pub fn vortex_array::expr::Expression::satisfy(&self, &vortex_array::dtype::DType, &vortex_session::VortexSession) -> vortex_error::VortexResult> pub fn vortex_array::expr::Expression::scalar_fn(&self) -> &vortex_array::scalar_fn::ScalarFnRef diff --git a/vortex-array/src/expr/expression.rs b/vortex-array/src/expr/expression.rs index 10b4389f6e8..cc21fb9a9a6 100644 --- a/vortex-array/src/expr/expression.rs +++ b/vortex-array/src/expr/expression.rs @@ -138,18 +138,30 @@ impl Expression { /// Returns an expression that proves this predicate is definitely false from stats. /// + /// `scope` is the dtype of the row this expression evaluates over. + /// /// If the returned expression evaluates to `true` for a stats scope, this expression is /// guaranteed to be false for every row in that scope. `false` and `null` are unknown. - pub fn falsify(&self, session: &VortexSession) -> VortexResult> { - crate::stats::rewrite::StatsRewriteCtx::new(session).falsify(self) + pub fn falsify( + &self, + scope: &DType, + session: &VortexSession, + ) -> VortexResult> { + crate::stats::rewrite::StatsRewriteCtx::new(session, scope).falsify(self) } /// Returns an expression that proves this predicate is definitely true from stats. /// + /// `scope` is the dtype of the row this expression evaluates over. + /// /// If the returned expression evaluates to `true` for a stats scope, this expression is /// guaranteed to be true for every row in that scope. `false` and `null` are unknown. - pub fn satisfy(&self, session: &VortexSession) -> VortexResult> { - crate::stats::rewrite::StatsRewriteCtx::new(session).satisfy(self) + pub fn satisfy( + &self, + scope: &DType, + session: &VortexSession, + ) -> VortexResult> { + crate::stats::rewrite::StatsRewriteCtx::new(session, scope).satisfy(self) } /// Returns an expression representing the zoned statistic for the given stat, if available. diff --git a/vortex-array/src/stats/rewrite.rs b/vortex-array/src/stats/rewrite.rs index 0eacc2d6629..4d829905d6f 100644 --- a/vortex-array/src/stats/rewrite.rs +++ b/vortex-array/src/stats/rewrite.rs @@ -7,8 +7,10 @@ use std::fmt::Debug; use std::sync::Arc; use vortex_error::VortexResult; +use vortex_error::vortex_ensure; use vortex_session::VortexSession; +use crate::dtype::DType; use crate::expr::Expression; use crate::expr::or_collect; use crate::scalar_fn::ScalarFnId; @@ -54,12 +56,13 @@ pub(crate) trait StatsRewriteRule: Debug + Send + Sync + 'static { /// Context passed to stats rewrite rules. pub(crate) struct StatsRewriteCtx<'a> { session: &'a VortexSession, + scope: &'a DType, } impl<'a> StatsRewriteCtx<'a> { /// Create a rewrite context for `session`. - pub(crate) fn new(session: &'a VortexSession) -> Self { - Self { session } + pub(crate) fn new(session: &'a VortexSession, scope: &'a DType) -> Self { + Self { session, scope } } /// Returns the session that owns the rewrite registry. @@ -67,15 +70,31 @@ impl<'a> StatsRewriteCtx<'a> { self.session } + /// Return the dtype of `expr` within this rewrite scope. + pub(crate) fn return_dtype(&self, expr: &Expression) -> VortexResult { + expr.return_dtype(self.scope) + } + /// Rewrite `expr` into a stats-backed falsifier. pub(crate) fn falsify(&self, expr: &Expression) -> VortexResult> { + self.ensure_predicate(expr)?; rewrite(expr, self, StatsRewriteRule::falsify) } /// Rewrite `expr` into a stats-backed satisfier. pub(crate) fn satisfy(&self, expr: &Expression) -> VortexResult> { + self.ensure_predicate(expr)?; rewrite(expr, self, StatsRewriteRule::satisfy) } + + fn ensure_predicate(&self, expr: &Expression) -> VortexResult<()> { + let dtype = self.return_dtype(expr)?; + vortex_ensure!( + matches!(dtype, DType::Bool(_)), + "Stats rewrites require a boolean predicate, got {dtype}", + ); + Ok(()) + } } fn rewrite( @@ -112,6 +131,9 @@ mod tests { use super::StatsRewriteCtx; use super::StatsRewriteRule; + use crate::dtype::DType; + use crate::dtype::Nullability; + use crate::dtype::PType; use crate::expr::Expression; use crate::expr::lit; use crate::expr::or; @@ -152,6 +174,7 @@ mod tests { #[test] fn combines_multiple_falsifiers_with_or() -> VortexResult<()> { let session = VortexSession::empty().with::(); + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); session.stats_rewrites().register(StaticLiteralRule { falsifier: Some(lit(false)), satisfier: None, @@ -161,13 +184,17 @@ mod tests { satisfier: None, }); - assert_eq!(lit(7).falsify(&session)?, Some(or(lit(false), lit(true)))); + assert_eq!( + lit(true).falsify(&dtype, &session)?, + Some(or(lit(false), lit(true))) + ); Ok(()) } #[test] fn combines_multiple_satisfiers_with_or() -> VortexResult<()> { let session = VortexSession::empty().with::(); + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); session.stats_rewrites().register(StaticLiteralRule { falsifier: None, satisfier: Some(lit(false)), @@ -177,16 +204,29 @@ mod tests { satisfier: Some(lit(true)), }); - assert_eq!(lit(7).satisfy(&session)?, Some(or(lit(false), lit(true)))); + assert_eq!( + lit(true).satisfy(&dtype, &session)?, + Some(or(lit(false), lit(true))) + ); Ok(()) } #[test] fn unregistered_expression_has_no_rewrite() -> VortexResult<()> { let session = VortexSession::empty().with::(); + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - assert_eq!(lit(7).falsify(&session)?, None); - assert_eq!(lit(7).satisfy(&session)?, None); + assert_eq!(lit(true).falsify(&dtype, &session)?, None); + assert_eq!(lit(true).satisfy(&dtype, &session)?, None); Ok(()) } + + #[test] + fn non_predicate_expression_errors() { + let session = VortexSession::empty().with::(); + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + + assert!(lit(7).falsify(&dtype, &session).is_err()); + assert!(lit(7).satisfy(&dtype, &session).is_err()); + } } From dd555d9bfadc06fb9333958fe8eb021be60cb2f7 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Sun, 17 May 2026 17:47:27 +0100 Subject: [PATCH 2/2] Rename stats rewrite session Signed-off-by: Nicholas Gates --- vortex-array/public-api.lock | 32 ++++++++++++------------- vortex-array/src/stats/rewrite.rs | 26 ++++++++++----------- vortex-array/src/stats/session.rs | 39 +++++++++++++++++-------------- vortex/src/lib.rs | 4 ++-- 4 files changed, 52 insertions(+), 49 deletions(-) diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 3920b90aa5d..eba7967f2c6 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -20524,21 +20524,21 @@ pub mod vortex_array::stats::flatbuffers pub mod vortex_array::stats::session -pub struct vortex_array::stats::session::StatsRewriteSession +pub struct vortex_array::stats::session::StatsSession -impl core::default::Default for vortex_array::stats::StatsRewriteSession +impl core::default::Default for vortex_array::stats::StatsSession -pub fn vortex_array::stats::StatsRewriteSession::default() -> vortex_array::stats::StatsRewriteSession +pub fn vortex_array::stats::StatsSession::default() -> vortex_array::stats::StatsSession -impl core::fmt::Debug for vortex_array::stats::StatsRewriteSession +impl core::fmt::Debug for vortex_array::stats::StatsSession -pub fn vortex_array::stats::StatsRewriteSession::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result +pub fn vortex_array::stats::StatsSession::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result -impl vortex_session::SessionVar for vortex_array::stats::StatsRewriteSession +impl vortex_session::SessionVar for vortex_array::stats::StatsSession -pub fn vortex_array::stats::StatsRewriteSession::as_any(&self) -> &dyn core::any::Any +pub fn vortex_array::stats::StatsSession::as_any(&self) -> &dyn core::any::Any -pub fn vortex_array::stats::StatsRewriteSession::as_any_mut(&mut self) -> &mut dyn core::any::Any +pub fn vortex_array::stats::StatsSession::as_any_mut(&mut self) -> &mut dyn core::any::Any pub struct vortex_array::stats::ArrayStats @@ -20600,21 +20600,21 @@ pub fn vortex_array::stats::MutTypedStatsSetRef<'_, '_>::is_empty(&self) -> bool pub fn vortex_array::stats::MutTypedStatsSetRef<'_, '_>::len(&self) -> usize -pub struct vortex_array::stats::StatsRewriteSession +pub struct vortex_array::stats::StatsSession -impl core::default::Default for vortex_array::stats::StatsRewriteSession +impl core::default::Default for vortex_array::stats::StatsSession -pub fn vortex_array::stats::StatsRewriteSession::default() -> vortex_array::stats::StatsRewriteSession +pub fn vortex_array::stats::StatsSession::default() -> vortex_array::stats::StatsSession -impl core::fmt::Debug for vortex_array::stats::StatsRewriteSession +impl core::fmt::Debug for vortex_array::stats::StatsSession -pub fn vortex_array::stats::StatsRewriteSession::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result +pub fn vortex_array::stats::StatsSession::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result -impl vortex_session::SessionVar for vortex_array::stats::StatsRewriteSession +impl vortex_session::SessionVar for vortex_array::stats::StatsSession -pub fn vortex_array::stats::StatsRewriteSession::as_any(&self) -> &dyn core::any::Any +pub fn vortex_array::stats::StatsSession::as_any(&self) -> &dyn core::any::Any -pub fn vortex_array::stats::StatsRewriteSession::as_any_mut(&mut self) -> &mut dyn core::any::Any +pub fn vortex_array::stats::StatsSession::as_any_mut(&mut self) -> &mut dyn core::any::Any pub struct vortex_array::stats::StatsSet diff --git a/vortex-array/src/stats/rewrite.rs b/vortex-array/src/stats/rewrite.rs index 4d829905d6f..98c9c01f894 100644 --- a/vortex-array/src/stats/rewrite.rs +++ b/vortex-array/src/stats/rewrite.rs @@ -14,7 +14,7 @@ use crate::dtype::DType; use crate::expr::Expression; use crate::expr::or_collect; use crate::scalar_fn::ScalarFnId; -use crate::stats::session::StatsRewriteSessionExt; +use crate::stats::session::StatsSessionExt; /// Shared reference to a stats rewrite rule. pub(crate) type StatsRewriteRuleRef = Arc; @@ -108,8 +108,8 @@ fn rewrite( ) -> VortexResult> { let rules = ctx .session() - .stats_rewrites() - .rules_for(expr.scalar_fn().id()); + .stats() + .rewrite_rules_for(expr.scalar_fn().id()); let Some(rules) = rules else { return Ok(None); }; @@ -140,8 +140,8 @@ mod tests { use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; use crate::scalar_fn::fns::literal::Literal; - use crate::stats::session::StatsRewriteSession; - use crate::stats::session::StatsRewriteSessionExt; + use crate::stats::session::StatsSession; + use crate::stats::session::StatsSessionExt; #[derive(Debug)] struct StaticLiteralRule { @@ -173,13 +173,13 @@ mod tests { #[test] fn combines_multiple_falsifiers_with_or() -> VortexResult<()> { - let session = VortexSession::empty().with::(); + let session = VortexSession::empty().with::(); let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - session.stats_rewrites().register(StaticLiteralRule { + session.stats().register_rewrite(StaticLiteralRule { falsifier: Some(lit(false)), satisfier: None, }); - session.stats_rewrites().register(StaticLiteralRule { + session.stats().register_rewrite(StaticLiteralRule { falsifier: Some(lit(true)), satisfier: None, }); @@ -193,13 +193,13 @@ mod tests { #[test] fn combines_multiple_satisfiers_with_or() -> VortexResult<()> { - let session = VortexSession::empty().with::(); + let session = VortexSession::empty().with::(); let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - session.stats_rewrites().register(StaticLiteralRule { + session.stats().register_rewrite(StaticLiteralRule { falsifier: None, satisfier: Some(lit(false)), }); - session.stats_rewrites().register(StaticLiteralRule { + session.stats().register_rewrite(StaticLiteralRule { falsifier: None, satisfier: Some(lit(true)), }); @@ -213,7 +213,7 @@ mod tests { #[test] fn unregistered_expression_has_no_rewrite() -> VortexResult<()> { - let session = VortexSession::empty().with::(); + let session = VortexSession::empty().with::(); let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); assert_eq!(lit(true).falsify(&dtype, &session)?, None); @@ -223,7 +223,7 @@ mod tests { #[test] fn non_predicate_expression_errors() { - let session = VortexSession::empty().with::(); + let session = VortexSession::empty().with::(); let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); assert!(lit(7).falsify(&dtype, &session).is_err()); diff --git a/vortex-array/src/stats/session.rs b/vortex-array/src/stats/session.rs index da9fe9dc786..8b1dc639441 100644 --- a/vortex-array/src/stats/session.rs +++ b/vortex-array/src/stats/session.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! Session state for stats rewrite rules. +//! Session state for stats APIs. use std::any::Any; use std::sync::Arc; @@ -18,23 +18,23 @@ use crate::stats::rewrite::StatsRewriteRuleRef; type StatsRewriteRuleSet = Arc<[StatsRewriteRuleRef]>; -/// Session state for stats rewrite rules. +/// Session state for stats APIs. #[derive(Debug, Default)] -pub struct StatsRewriteSession { - rules: RwLock>, +pub struct StatsSession { + rewrite_rules: RwLock>, } -impl StatsRewriteSession { +impl StatsSession { /// Register a stats rewrite rule. #[allow(dead_code)] - pub(crate) fn register(&self, rule: R) { - self.register_ref(Arc::new(rule)); + pub(crate) fn register_rewrite(&self, rule: R) { + self.register_rewrite_ref(Arc::new(rule)); } /// Register a shared stats rewrite rule. #[allow(dead_code)] - pub(crate) fn register_ref(&self, rule: StatsRewriteRuleRef) { - let mut rules = self.rules.write(); + pub(crate) fn register_rewrite_ref(&self, rule: StatsRewriteRuleRef) { + let mut rules = self.rewrite_rules.write(); let rule_id = rule.scalar_fn_id(); let mut updated_rules = rules .get(&rule_id) @@ -45,12 +45,15 @@ impl StatsRewriteSession { } /// Return the rewrite rules registered for `scalar_fn_id`. - pub(crate) fn rules_for(&self, scalar_fn_id: ScalarFnId) -> Option { - self.rules.read().get(&scalar_fn_id).cloned() + pub(crate) fn rewrite_rules_for( + &self, + scalar_fn_id: ScalarFnId, + ) -> Option { + self.rewrite_rules.read().get(&scalar_fn_id).cloned() } } -impl SessionVar for StatsRewriteSession { +impl SessionVar for StatsSession { fn as_any(&self) -> &dyn Any { self } @@ -60,11 +63,11 @@ impl SessionVar for StatsRewriteSession { } } -/// Extension trait for accessing stats rewrite session data. -pub(crate) trait StatsRewriteSessionExt: SessionExt { - /// Returns the stats rewrite rule registry. - fn stats_rewrites(&self) -> Ref<'_, StatsRewriteSession> { - self.get::() +/// Extension trait for accessing stats session data. +pub(crate) trait StatsSessionExt: SessionExt { + /// Returns the stats session state. + fn stats(&self) -> Ref<'_, StatsSession> { + self.get::() } } -impl StatsRewriteSessionExt for S {} +impl StatsSessionExt for S {} diff --git a/vortex/src/lib.rs b/vortex/src/lib.rs index 3da75c68a3b..8668de339cb 100644 --- a/vortex/src/lib.rs +++ b/vortex/src/lib.rs @@ -16,7 +16,7 @@ use vortex_array::optimizer::kernels::ArrayKernels; pub use vortex_array::scalar_fn; use vortex_array::scalar_fn::session::ScalarFnSession; use vortex_array::session::ArraySession; -use vortex_array::stats::session::StatsRewriteSession; +use vortex_array::stats::session::StatsSession; use vortex_io::session::RuntimeSession; use vortex_layout::session::LayoutSession; use vortex_session::VortexSession; @@ -168,7 +168,7 @@ impl VortexSessionDefault for VortexSession { .with::() .with::() .with::() - .with::() + .with::() .with::() .with::() .with::();