Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions vortex-array/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ name = "expr_case_when"
path = "benches/expr/case_when_bench.rs"
harness = false

[[bench]]
name = "expr_optimize"
path = "benches/expr/optimize_bench.rs"
harness = false

[[bench]]
name = "chunked_dict_builder"
harness = false
Expand Down
45 changes: 45 additions & 0 deletions vortex-array/benches/expr/optimize_bench.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

#![expect(clippy::unwrap_used)]
#![expect(clippy::cast_possible_truncation)]

use divan::Bencher;
use vortex_array::dtype::DType;
use vortex_array::dtype::Nullability;
use vortex_array::dtype::PType;
use vortex_array::dtype::StructFields;
use vortex_array::expr::Expression;
use vortex_array::expr::eq;
use vortex_array::expr::get_item;
use vortex_array::expr::lit;
use vortex_array::expr::or;
use vortex_array::expr::root;

/// Benchmark binary entry point.
///
/// The bench target is declared with `harness = false`, so this `main` hands control
/// straight to divan's runner.
fn main() {
    divan::main()
}

/// Builds the scope dtype used by the benchmark: a non-nullable struct with a single
/// non-nullable `i32` field named `x`.
fn struct_scope() -> DType {
    let x_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
    let fields = StructFields::new(["x"].into(), vec![x_dtype]);
    DType::Struct(fields, Nullability::NonNullable)
}

/// Builds a left-nested chain `x == 0 OR x == 1 OR ... OR x == n - 1` over the field `x`
/// of the root scope.
///
/// The chain always starts with the `x == 0` term, so `n == 0` and `n == 1` both produce
/// that single comparison without any `or` node.
fn build_or_chain(n: usize) -> Expression {
    let mut chain = eq(get_item("x", root()), lit(0i32));
    for value in 1..n {
        // Truncating cast is accepted here (see the file-level `expect` attribute).
        let term = eq(get_item("x", root()), lit(value as i32));
        chain = or(chain, term);
    }
    chain
}

/// Measures `optimize_recursive` on an OR-chain of `n` equality comparisons.
///
/// The expression and the scope are built once, outside the timed closure, so only the
/// optimizer call itself is measured.
#[divan::bench(args = [200])]
fn optimize_or_chain(bencher: Bencher, n: usize) {
    let scope = struct_scope();
    let expr = build_or_chain(n);
    bencher.bench(|| {
        let optimized = expr.optimize_recursive(&scope);
        // Returned from the closure so the result is handed back to the bencher.
        optimized.unwrap()
    });
}
2 changes: 0 additions & 2 deletions vortex-array/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -12812,8 +12812,6 @@ pub fn vortex_array::expr::Expression::simplify(&self, &vortex_array::dtype::DTy

pub fn vortex_array::expr::Expression::simplify_untyped(&self) -> vortex_error::VortexResult<vortex_array::expr::Expression>

pub fn vortex_array::expr::Expression::try_optimize(&self, &vortex_array::dtype::DType) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::Expression>>

pub fn vortex_array::expr::Expression::try_optimize_recursive(&self, &vortex_array::dtype::DType) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::Expression>>

impl core::clone::Clone for vortex_array::expr::Expression
Expand Down
82 changes: 68 additions & 14 deletions vortex-array/src/expr/optimize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,22 @@ impl Expression {
/// 2. `simplify` - type-aware simplifications
/// 3. `reduce` - abstract reduction rules via `ReduceNode`/`ReduceCtx`
pub fn optimize(&self, scope: &DType) -> VortexResult<Expression> {
let cache = SimplifyCache {
scope,
dtype_cache: RefCell::new(HashMap::new()),
};
Ok(self
.clone()
.try_optimize(scope)?
.try_optimize(scope, &cache)?
.unwrap_or_else(|| self.clone()))
}

/// Try to optimize the root expression node only, returning None if no optimizations applied.
pub fn try_optimize(&self, scope: &DType) -> VortexResult<Option<Expression>> {
let cache = SimplifyCache {
scope,
dtype_cache: RefCell::new(HashMap::new()),
};
fn try_optimize(
&self,
scope: &DType,
cache: &SimplifyCache<'_>,
) -> VortexResult<Option<Expression>> {
let reduce_ctx = ExpressionReduceCtx {
scope: scope.clone(),
};
Expand All @@ -67,7 +71,7 @@ impl Expression {
}

// Try simplify (typed)
if let Some(simplified) = current.scalar_fn().simplify(&current, &cache)? {
if let Some(simplified) = current.scalar_fn().simplify(&current, cache)? {
current = simplified;
changed = true;
any_optimizations = true;
Expand Down Expand Up @@ -114,11 +118,28 @@ impl Expression {

/// Try to optimize the entire expression tree recursively.
///
/// A single `SimplifyCache` is created here and shared with every node visited by
/// `try_optimize_recursive_inner`. The "between" pass (`find_between`) is then applied
/// exactly once, to the final tree.
///
/// Note that this always returns `Some`: when the recursive pass reports no change,
/// `find_between` is applied to a clone of `self` and that result is returned.
pub fn try_optimize_recursive(&self, scope: &DType) -> VortexResult<Option<Expression>> {
    let cache = SimplifyCache {
        scope,
        dtype_cache: RefCell::new(HashMap::new()),
    };
    let optimized = match self.try_optimize_recursive_inner(scope, &cache)? {
        Some(expr) => expr,
        None => self.clone(),
    };

    // Apply the between optimization once at the top level only.
    // TODO(ngates): remove the "between" optimization, or rewrite it to not always convert
    //  to CNF?
    Ok(Some(find_between(optimized)))
}

fn try_optimize_recursive_inner(
&self,
scope: &DType,
cache: &SimplifyCache<'_>,
) -> VortexResult<Option<Expression>> {
let mut current = self.clone();
let mut any_optimizations = false;

// First optimize the root
if let Some(optimized) = current.clone().try_optimize(scope)? {
if let Some(optimized) = current.clone().try_optimize(scope, cache)? {
current = optimized;
any_optimizations = true;
}
Expand All @@ -127,7 +148,7 @@ impl Expression {
let mut new_children = Vec::with_capacity(current.children().len());
let mut any_child_optimized = false;
for child in current.children().iter() {
if let Some(optimized) = child.try_optimize_recursive(scope)? {
if let Some(optimized) = child.try_optimize_recursive_inner(scope, cache)? {
new_children.push(optimized);
any_child_optimized = true;
} else {
Expand All @@ -140,15 +161,11 @@ impl Expression {
any_optimizations = true;

// After updating children, try to optimize root again
if let Some(optimized) = current.clone().try_optimize(scope)? {
if let Some(optimized) = current.clone().try_optimize(scope, cache)? {
current = optimized;
}
}

// TODO(ngates): remove the "between" optimization, or rewrite it to not always convert
// to CNF?
let current = find_between(current);

if any_optimizations {
Ok(Some(current))
} else {
Expand Down Expand Up @@ -294,3 +311,40 @@ impl ReduceCtx for ExpressionReduceCtx {
}))
}
}

#[cfg(test)]
mod tests {
    use vortex_error::VortexResult;

    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::eq;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::or;
    use crate::expr::root;

    /// Optimizing `x == 1 OR x == 2` must keep the field reference and both literals
    /// in the printed expression.
    #[test]
    fn optimize_or_chain_correctness() -> VortexResult<()> {
        let expr = or(
            eq(get_item("x", root()), lit(1i32)),
            eq(get_item("x", root()), lit(2i32)),
        );
        let scope = DType::Struct(
            StructFields::new(
                ["x"].into(),
                vec![DType::Primitive(PType::I32, Nullability::NonNullable)],
            ),
            Nullability::NonNullable,
        );
        let optimized = expr.optimize_recursive(&scope)?;

        let s = optimized.to_string();
        assert!(s.contains("$.x"), "expected $.x in {s}");

        // Compare against each literal's own rendering rather than a bare digit: a
        // check such as `contains('2')` is satisfied by any `i32` type suffix in the
        // output and would pass even if the literal had been dropped.
        for literal in [lit(1i32), lit(2i32)] {
            let rendered = literal.to_string();
            assert!(s.contains(rendered.as_str()), "expected {rendered} in {s}");
        }
        Ok(())
    }
}
Loading