diff --git a/vortex-array/Cargo.toml b/vortex-array/Cargo.toml index ff897840a3c..e792882dc54 100644 --- a/vortex-array/Cargo.toml +++ b/vortex-array/Cargo.toml @@ -143,6 +143,11 @@ name = "expr_case_when" path = "benches/expr/case_when_bench.rs" harness = false +[[bench]] +name = "expr_optimize" +path = "benches/expr/optimize_bench.rs" +harness = false + [[bench]] name = "chunked_dict_builder" harness = false diff --git a/vortex-array/benches/expr/optimize_bench.rs b/vortex-array/benches/expr/optimize_bench.rs new file mode 100644 index 00000000000..16a3f8765ba --- /dev/null +++ b/vortex-array/benches/expr/optimize_bench.rs @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +#![expect(clippy::unwrap_used)] +#![expect(clippy::cast_possible_truncation)] + +use divan::Bencher; +use vortex_array::dtype::DType; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::dtype::StructFields; +use vortex_array::expr::Expression; +use vortex_array::expr::eq; +use vortex_array::expr::get_item; +use vortex_array::expr::lit; +use vortex_array::expr::or; +use vortex_array::expr::root; + +fn main() { + divan::main(); +} + +fn struct_scope() -> DType { + DType::Struct( + StructFields::new( + ["x"].into(), + vec![DType::Primitive(PType::I32, Nullability::NonNullable)], + ), + Nullability::NonNullable, + ) +} + +fn build_or_chain(n: usize) -> Expression { + let base = eq(get_item("x", root()), lit(0i32)); + (1..n).fold(base, |acc, i| { + or(acc, eq(get_item("x", root()), lit(i as i32))) + }) +} + +#[divan::bench(args = [200])] +fn optimize_or_chain(bencher: Bencher, n: usize) { + let expr = build_or_chain(n); + let scope = struct_scope(); + bencher.bench(|| expr.optimize_recursive(&scope).unwrap()); +} diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index a790196e149..e1cd3077409 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -12812,8 +12812,6 @@ pub fn vortex_array::expr::Expression::simplify(&self, &vortex_array::dtype::DTy pub fn vortex_array::expr::Expression::simplify_untyped(&self) -> vortex_error::VortexResult -pub fn vortex_array::expr::Expression::try_optimize(&self, &vortex_array::dtype::DType) -> vortex_error::VortexResult> - pub fn vortex_array::expr::Expression::try_optimize_recursive(&self, &vortex_array::dtype::DType) -> vortex_error::VortexResult> impl core::clone::Clone for vortex_array::expr::Expression diff --git a/vortex-array/src/expr/optimize.rs b/vortex-array/src/expr/optimize.rs index 55b7e6e8df6..27959a96070 100644 --- a/vortex-array/src/expr/optimize.rs +++ b/vortex-array/src/expr/optimize.rs @@ -29,18 +29,22 @@ impl Expression { /// 2. `simplify` - type-aware simplifications /// 3. `reduce` - abstract reduction rules via `ReduceNode`/`ReduceCtx` pub fn optimize(&self, scope: &DType) -> VortexResult { + let cache = SimplifyCache { + scope, + dtype_cache: RefCell::new(HashMap::new()), + }; Ok(self .clone() - .try_optimize(scope)? + .try_optimize(scope, &cache)? .unwrap_or_else(|| self.clone())) } /// Try to optimize the root expression node only, returning None if no optimizations applied. - pub fn try_optimize(&self, scope: &DType) -> VortexResult> { - let cache = SimplifyCache { - scope, - dtype_cache: RefCell::new(HashMap::new()), - }; + fn try_optimize( + &self, + scope: &DType, + cache: &SimplifyCache<'_>, + ) -> VortexResult> { let reduce_ctx = ExpressionReduceCtx { scope: scope.clone(), }; @@ -67,7 +71,7 @@ impl Expression { } // Try simplify (typed) - if let Some(simplified) = current.scalar_fn().simplify(¤t, &cache)? { + if let Some(simplified) = current.scalar_fn().simplify(¤t, cache)? { current = simplified; changed = true; any_optimizations = true; @@ -114,11 +118,28 @@ impl Expression { /// Try to optimize the entire expression tree recursively. pub fn try_optimize_recursive(&self, scope: &DType) -> VortexResult> { + let cache = SimplifyCache { + scope, + dtype_cache: RefCell::new(HashMap::new()), + }; + let result = self.try_optimize_recursive_inner(scope, &cache)?; + + // Apply the between optimization once at the top level only. + // TODO(ngates): remove the "between" optimization, or rewrite it to not always convert + // to CNF? + Ok(Some(find_between(result.unwrap_or_else(|| self.clone())))) + } + + fn try_optimize_recursive_inner( + &self, + scope: &DType, + cache: &SimplifyCache<'_>, + ) -> VortexResult> { let mut current = self.clone(); let mut any_optimizations = false; // First optimize the root - if let Some(optimized) = current.clone().try_optimize(scope)? { + if let Some(optimized) = current.clone().try_optimize(scope, cache)? { current = optimized; any_optimizations = true; } @@ -127,7 +148,7 @@ impl Expression { let mut new_children = Vec::with_capacity(current.children().len()); let mut any_child_optimized = false; for child in current.children().iter() { - if let Some(optimized) = child.try_optimize_recursive(scope)? { + if let Some(optimized) = child.try_optimize_recursive_inner(scope, cache)? { new_children.push(optimized); any_child_optimized = true; } else { @@ -140,15 +161,11 @@ impl Expression { any_optimizations = true; // After updating children, try to optimize root again - if let Some(optimized) = current.clone().try_optimize(scope)? { + if let Some(optimized) = current.clone().try_optimize(scope, cache)? { current = optimized; } } - // TODO(ngates): remove the "between" optimization, or rewrite it to not always convert - // to CNF? - let current = find_between(current); - if any_optimizations { Ok(Some(current)) } else { @@ -294,3 +311,40 @@ impl ReduceCtx for ExpressionReduceCtx { })) } } + +#[cfg(test)] +mod tests { + use vortex_error::VortexResult; + + use crate::dtype::DType; + use crate::dtype::Nullability; + use crate::dtype::PType; + use crate::dtype::StructFields; + use crate::expr::eq; + use crate::expr::get_item; + use crate::expr::lit; + use crate::expr::or; + use crate::expr::root; + + #[test] + fn optimize_or_chain_correctness() -> VortexResult<()> { + let expr = or( + eq(get_item("x", root()), lit(1i32)), + eq(get_item("x", root()), lit(2i32)), + ); + let scope = DType::Struct( + StructFields::new( + ["x"].into(), + vec![DType::Primitive(PType::I32, Nullability::NonNullable)], + ), + Nullability::NonNullable, + ); + let optimized = expr.optimize_recursive(&scope)?; + + let s = optimized.to_string(); + assert!(s.contains("$.x"), "expected $.x in {s}"); + assert!(s.contains("1i32") || s.contains('1'), "expected 1 in {s}"); + assert!(s.contains("2i32") || s.contains('2'), "expected 2 in {s}"); + Ok(()) + } +}