Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions vortex-array/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -13514,13 +13514,13 @@ pub fn vortex_array::expr::Expression::children(&self) -> &alloc::sync::Arc<allo

pub fn vortex_array::expr::Expression::display_tree(&self) -> impl core::fmt::Display

pub fn vortex_array::expr::Expression::falsify(&self, &vortex_session::VortexSession) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::Expression>>
pub fn vortex_array::expr::Expression::falsify(&self, &vortex_array::dtype::DType, &vortex_session::VortexSession) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::Expression>>

pub fn vortex_array::expr::Expression::fmt_sql(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result

pub fn vortex_array::expr::Expression::return_dtype(&self, &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>

pub fn vortex_array::expr::Expression::satisfy(&self, &vortex_session::VortexSession) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::Expression>>
pub fn vortex_array::expr::Expression::satisfy(&self, &vortex_array::dtype::DType, &vortex_session::VortexSession) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::Expression>>

pub fn vortex_array::expr::Expression::scalar_fn(&self) -> &vortex_array::scalar_fn::ScalarFnRef

Expand Down Expand Up @@ -20524,21 +20524,21 @@ pub mod vortex_array::stats::flatbuffers

pub mod vortex_array::stats::session

pub struct vortex_array::stats::session::StatsRewriteSession
pub struct vortex_array::stats::session::StatsSession

impl core::default::Default for vortex_array::stats::StatsRewriteSession
impl core::default::Default for vortex_array::stats::StatsSession

pub fn vortex_array::stats::StatsRewriteSession::default() -> vortex_array::stats::StatsRewriteSession
pub fn vortex_array::stats::StatsSession::default() -> vortex_array::stats::StatsSession

impl core::fmt::Debug for vortex_array::stats::StatsRewriteSession
impl core::fmt::Debug for vortex_array::stats::StatsSession

pub fn vortex_array::stats::StatsRewriteSession::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result
pub fn vortex_array::stats::StatsSession::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result

impl vortex_session::SessionVar for vortex_array::stats::StatsRewriteSession
impl vortex_session::SessionVar for vortex_array::stats::StatsSession

pub fn vortex_array::stats::StatsRewriteSession::as_any(&self) -> &dyn core::any::Any
pub fn vortex_array::stats::StatsSession::as_any(&self) -> &dyn core::any::Any

pub fn vortex_array::stats::StatsRewriteSession::as_any_mut(&mut self) -> &mut dyn core::any::Any
pub fn vortex_array::stats::StatsSession::as_any_mut(&mut self) -> &mut dyn core::any::Any

pub struct vortex_array::stats::ArrayStats

Expand Down Expand Up @@ -20600,21 +20600,21 @@ pub fn vortex_array::stats::MutTypedStatsSetRef<'_, '_>::is_empty(&self) -> bool

pub fn vortex_array::stats::MutTypedStatsSetRef<'_, '_>::len(&self) -> usize

pub struct vortex_array::stats::StatsRewriteSession
pub struct vortex_array::stats::StatsSession

impl core::default::Default for vortex_array::stats::StatsRewriteSession
impl core::default::Default for vortex_array::stats::StatsSession

pub fn vortex_array::stats::StatsRewriteSession::default() -> vortex_array::stats::StatsRewriteSession
pub fn vortex_array::stats::StatsSession::default() -> vortex_array::stats::StatsSession

impl core::fmt::Debug for vortex_array::stats::StatsRewriteSession
impl core::fmt::Debug for vortex_array::stats::StatsSession

pub fn vortex_array::stats::StatsRewriteSession::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result
pub fn vortex_array::stats::StatsSession::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result

impl vortex_session::SessionVar for vortex_array::stats::StatsRewriteSession
impl vortex_session::SessionVar for vortex_array::stats::StatsSession

pub fn vortex_array::stats::StatsRewriteSession::as_any(&self) -> &dyn core::any::Any
pub fn vortex_array::stats::StatsSession::as_any(&self) -> &dyn core::any::Any

pub fn vortex_array::stats::StatsRewriteSession::as_any_mut(&mut self) -> &mut dyn core::any::Any
pub fn vortex_array::stats::StatsSession::as_any_mut(&mut self) -> &mut dyn core::any::Any

pub struct vortex_array::stats::StatsSet

Expand Down
20 changes: 16 additions & 4 deletions vortex-array/src/expr/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,18 +138,30 @@ impl Expression {

/// Returns an expression that proves this predicate is definitely false from stats.
///
/// `scope` is the dtype of the row this expression evaluates over.
///
/// If the returned expression evaluates to `true` for a stats scope, this expression is
/// guaranteed to be false for every row in that scope. `false` and `null` are unknown.
pub fn falsify(&self, session: &VortexSession) -> VortexResult<Option<Expression>> {
crate::stats::rewrite::StatsRewriteCtx::new(session).falsify(self)
pub fn falsify(
&self,
scope: &DType,
session: &VortexSession,
) -> VortexResult<Option<Expression>> {
crate::stats::rewrite::StatsRewriteCtx::new(session, scope).falsify(self)
}

/// Returns an expression that proves this predicate is definitely true from stats.
///
/// `scope` is the dtype of the row this expression evaluates over.
///
/// If the returned expression evaluates to `true` for a stats scope, this expression is
/// guaranteed to be true for every row in that scope. `false` and `null` are unknown.
pub fn satisfy(&self, session: &VortexSession) -> VortexResult<Option<Expression>> {
crate::stats::rewrite::StatsRewriteCtx::new(session).satisfy(self)
pub fn satisfy(
&self,
scope: &DType,
session: &VortexSession,
) -> VortexResult<Option<Expression>> {
crate::stats::rewrite::StatsRewriteCtx::new(session, scope).satisfy(self)
}

/// Returns an expression representing the zoned statistic for the given stat, if available.
Expand Down
76 changes: 58 additions & 18 deletions vortex-array/src/stats/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ use std::fmt::Debug;
use std::sync::Arc;

use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_session::VortexSession;

use crate::dtype::DType;
use crate::expr::Expression;
use crate::expr::or_collect;
use crate::scalar_fn::ScalarFnId;
use crate::stats::session::StatsRewriteSessionExt;
use crate::stats::session::StatsSessionExt;

/// Shared reference to a stats rewrite rule.
pub(crate) type StatsRewriteRuleRef = Arc<dyn StatsRewriteRule>;
Expand Down Expand Up @@ -54,28 +56,45 @@ pub(crate) trait StatsRewriteRule: Debug + Send + Sync + 'static {
/// Context passed to stats rewrite rules.
pub(crate) struct StatsRewriteCtx<'a> {
session: &'a VortexSession,
scope: &'a DType,
}

impl<'a> StatsRewriteCtx<'a> {
/// Create a rewrite context for `session`.
pub(crate) fn new(session: &'a VortexSession) -> Self {
Self { session }
pub(crate) fn new(session: &'a VortexSession, scope: &'a DType) -> Self {
Self { session, scope }
}

/// Returns the session that owns the rewrite registry.
pub(crate) fn session(&self) -> &'a VortexSession {
self.session
}

/// Return the dtype of `expr` within this rewrite scope.
pub(crate) fn return_dtype(&self, expr: &Expression) -> VortexResult<DType> {
expr.return_dtype(self.scope)
}

/// Rewrite `expr` into a stats-backed falsifier.
pub(crate) fn falsify(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
self.ensure_predicate(expr)?;
rewrite(expr, self, StatsRewriteRule::falsify)
}

/// Rewrite `expr` into a stats-backed satisfier.
pub(crate) fn satisfy(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
self.ensure_predicate(expr)?;
rewrite(expr, self, StatsRewriteRule::satisfy)
}

fn ensure_predicate(&self, expr: &Expression) -> VortexResult<()> {
let dtype = self.return_dtype(expr)?;
vortex_ensure!(
matches!(dtype, DType::Bool(_)),
"Stats rewrites require a boolean predicate, got {dtype}",
);
Ok(())
}
}

fn rewrite(
Expand All @@ -89,8 +108,8 @@ fn rewrite(
) -> VortexResult<Option<Expression>> {
let rules = ctx
.session()
.stats_rewrites()
.rules_for(expr.scalar_fn().id());
.stats()
.rewrite_rules_for(expr.scalar_fn().id());
let Some(rules) = rules else {
return Ok(None);
};
Expand All @@ -112,14 +131,17 @@ mod tests {

use super::StatsRewriteCtx;
use super::StatsRewriteRule;
use crate::dtype::DType;
use crate::dtype::Nullability;
use crate::dtype::PType;
use crate::expr::Expression;
use crate::expr::lit;
use crate::expr::or;
use crate::scalar_fn::ScalarFnId;
use crate::scalar_fn::ScalarFnVTable;
use crate::scalar_fn::fns::literal::Literal;
use crate::stats::session::StatsRewriteSession;
use crate::stats::session::StatsRewriteSessionExt;
use crate::stats::session::StatsSession;
use crate::stats::session::StatsSessionExt;

#[derive(Debug)]
struct StaticLiteralRule {
Expand Down Expand Up @@ -151,42 +173,60 @@ mod tests {

#[test]
fn combines_multiple_falsifiers_with_or() -> VortexResult<()> {
let session = VortexSession::empty().with::<StatsRewriteSession>();
session.stats_rewrites().register(StaticLiteralRule {
let session = VortexSession::empty().with::<StatsSession>();
let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
session.stats().register_rewrite(StaticLiteralRule {
falsifier: Some(lit(false)),
satisfier: None,
});
session.stats_rewrites().register(StaticLiteralRule {
session.stats().register_rewrite(StaticLiteralRule {
falsifier: Some(lit(true)),
satisfier: None,
});

assert_eq!(lit(7).falsify(&session)?, Some(or(lit(false), lit(true))));
assert_eq!(
lit(true).falsify(&dtype, &session)?,
Some(or(lit(false), lit(true)))
);
Ok(())
}

#[test]
fn combines_multiple_satisfiers_with_or() -> VortexResult<()> {
let session = VortexSession::empty().with::<StatsRewriteSession>();
session.stats_rewrites().register(StaticLiteralRule {
let session = VortexSession::empty().with::<StatsSession>();
let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
session.stats().register_rewrite(StaticLiteralRule {
falsifier: None,
satisfier: Some(lit(false)),
});
session.stats_rewrites().register(StaticLiteralRule {
session.stats().register_rewrite(StaticLiteralRule {
falsifier: None,
satisfier: Some(lit(true)),
});

assert_eq!(lit(7).satisfy(&session)?, Some(or(lit(false), lit(true))));
assert_eq!(
lit(true).satisfy(&dtype, &session)?,
Some(or(lit(false), lit(true)))
);
Ok(())
}

#[test]
fn unregistered_expression_has_no_rewrite() -> VortexResult<()> {
let session = VortexSession::empty().with::<StatsRewriteSession>();
let session = VortexSession::empty().with::<StatsSession>();
let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);

assert_eq!(lit(7).falsify(&session)?, None);
assert_eq!(lit(7).satisfy(&session)?, None);
assert_eq!(lit(true).falsify(&dtype, &session)?, None);
assert_eq!(lit(true).satisfy(&dtype, &session)?, None);
Ok(())
}

#[test]
fn non_predicate_expression_errors() {
let session = VortexSession::empty().with::<StatsSession>();
let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);

assert!(lit(7).falsify(&dtype, &session).is_err());
assert!(lit(7).satisfy(&dtype, &session).is_err());
}
}
39 changes: 21 additions & 18 deletions vortex-array/src/stats/session.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! Session state for stats rewrite rules.
//! Session state for stats APIs.

use std::any::Any;
use std::sync::Arc;
Expand All @@ -18,23 +18,23 @@ use crate::stats::rewrite::StatsRewriteRuleRef;

type StatsRewriteRuleSet = Arc<[StatsRewriteRuleRef]>;

/// Session state for stats rewrite rules.
/// Session state for stats APIs.
#[derive(Debug, Default)]
pub struct StatsRewriteSession {
rules: RwLock<HashMap<ScalarFnId, StatsRewriteRuleSet>>,
pub struct StatsSession {
rewrite_rules: RwLock<HashMap<ScalarFnId, StatsRewriteRuleSet>>,
}

impl StatsRewriteSession {
impl StatsSession {
/// Register a stats rewrite rule.
#[allow(dead_code)]
pub(crate) fn register<R: StatsRewriteRule>(&self, rule: R) {
self.register_ref(Arc::new(rule));
pub(crate) fn register_rewrite<R: StatsRewriteRule>(&self, rule: R) {
self.register_rewrite_ref(Arc::new(rule));
}

/// Register a shared stats rewrite rule.
#[allow(dead_code)]
pub(crate) fn register_ref(&self, rule: StatsRewriteRuleRef) {
let mut rules = self.rules.write();
pub(crate) fn register_rewrite_ref(&self, rule: StatsRewriteRuleRef) {
let mut rules = self.rewrite_rules.write();
let rule_id = rule.scalar_fn_id();
let mut updated_rules = rules
.get(&rule_id)
Expand All @@ -45,12 +45,15 @@ impl StatsRewriteSession {
}

/// Return the rewrite rules registered for `scalar_fn_id`.
pub(crate) fn rules_for(&self, scalar_fn_id: ScalarFnId) -> Option<StatsRewriteRuleSet> {
self.rules.read().get(&scalar_fn_id).cloned()
pub(crate) fn rewrite_rules_for(
&self,
scalar_fn_id: ScalarFnId,
) -> Option<StatsRewriteRuleSet> {
self.rewrite_rules.read().get(&scalar_fn_id).cloned()
}
}

impl SessionVar for StatsRewriteSession {
impl SessionVar for StatsSession {
fn as_any(&self) -> &dyn Any {
self
}
Expand All @@ -60,11 +63,11 @@ impl SessionVar for StatsRewriteSession {
}
}

/// Extension trait for accessing stats rewrite session data.
pub(crate) trait StatsRewriteSessionExt: SessionExt {
/// Returns the stats rewrite rule registry.
fn stats_rewrites(&self) -> Ref<'_, StatsRewriteSession> {
self.get::<StatsRewriteSession>()
/// Extension trait for accessing stats session data.
pub(crate) trait StatsSessionExt: SessionExt {
/// Returns the stats session state.
fn stats(&self) -> Ref<'_, StatsSession> {
self.get::<StatsSession>()
}
}
impl<S: SessionExt> StatsRewriteSessionExt for S {}
impl<S: SessionExt> StatsSessionExt for S {}
4 changes: 2 additions & 2 deletions vortex/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use vortex_array::optimizer::kernels::ArrayKernels;
pub use vortex_array::scalar_fn;
use vortex_array::scalar_fn::session::ScalarFnSession;
use vortex_array::session::ArraySession;
use vortex_array::stats::session::StatsRewriteSession;
use vortex_array::stats::session::StatsSession;
use vortex_io::session::RuntimeSession;
use vortex_layout::session::LayoutSession;
use vortex_session::VortexSession;
Expand Down Expand Up @@ -168,7 +168,7 @@ impl VortexSessionDefault for VortexSession {
.with::<ArraySession>()
.with::<LayoutSession>()
.with::<ScalarFnSession>()
.with::<StatsRewriteSession>()
.with::<StatsSession>()
.with::<ArrayKernels>()
.with::<AggregateFnSession>()
.with::<RuntimeSession>();
Expand Down
Loading