Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions vortex-array/src/compute/conformance/binary_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
//! - Subtraction (`-`)
//! - Multiplication (`*`)
//! - Division (`/`)
//! - Safe division (`//`) — like division but returns 0 on a zero divisor

use itertools::Itertools;
use num_traits::Num;
Expand Down Expand Up @@ -104,11 +105,12 @@ where
.cast(array.dtype())
.vortex_expect("operation should succeed in conformance test");

let operators: [NumericOperator; 4] = [
let operators: [NumericOperator; 5] = [
NumericOperator::Add,
NumericOperator::Sub,
NumericOperator::Mul,
NumericOperator::Div,
NumericOperator::SafeDiv,
];

for operator in operators {
Expand Down Expand Up @@ -338,20 +340,23 @@ where
.cast(array.dtype())
.vortex_expect("operation should succeed in conformance test");

// Only test operators that make sense for the given scalar
// Only test operators that make sense for the given scalar.
// Regular Div is skipped when the scalar is zero (it would error); SafeDiv is always
// exercised because it is specifically defined to yield zero on a zero divisor.
let operators = if scalar_value == T::zero() {
// Skip division by zero
vec![
NumericOperator::Add,
NumericOperator::Sub,
NumericOperator::Mul,
NumericOperator::SafeDiv,
]
} else {
vec![
NumericOperator::Add,
NumericOperator::Sub,
NumericOperator::Mul,
NumericOperator::Div,
NumericOperator::SafeDiv,
]
};

Expand Down
6 changes: 5 additions & 1 deletion vortex-array/src/scalar/typed_view/decimal/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,17 @@ impl<'a> DecimalScalar<'a> {
// Handle null cases using SQL semantics
let result_value = match (self.decimal_value, other.decimal_value) {
(None, _) | (_, None) => None,
(Some(_), Some(rhs)) if op == NumericOperator::SafeDiv && rhs.is_zero() => {
// Non-null zero divisor: return zero rather than erroring.
Some(DecimalValue::zero(&self.decimal_type))
}
(Some(lhs), Some(rhs)) => {
// Perform the operation
let operation_result = match op {
NumericOperator::Add => lhs.checked_add(&rhs),
NumericOperator::Sub => lhs.checked_sub(&rhs),
NumericOperator::Mul => lhs.checked_mul(&rhs),
NumericOperator::Div => lhs.checked_div(&rhs),
NumericOperator::Div | NumericOperator::SafeDiv => lhs.checked_div(&rhs),
}?;

// Check if the result fits within the precision constraints
Expand Down
25 changes: 25 additions & 0 deletions vortex-array/src/scalar/typed_view/decimal/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,31 @@ fn test_decimal_scalar_checked_div_by_zero() {
assert_eq!(result, None);
}

#[test]
fn test_decimal_scalar_safe_div_by_zero() {
use crate::scalar::NumericOperator;

let decimal1 = Scalar::decimal(
DecimalValue::I64(1000),
DecimalDType::new(10, 2),
Nullability::NonNullable,
);
let scalar1 = decimal1.as_decimal();

let decimal2 = Scalar::decimal(
DecimalValue::I64(0),
DecimalDType::new(10, 2),
Nullability::NonNullable,
);
let scalar2 = decimal2.as_decimal();

// SafeDiv returns zero (not None) when the non-null divisor is zero.
let result = scalar1
.checked_binary_numeric(&scalar2, NumericOperator::SafeDiv)
.unwrap();
assert!(result.decimal_value().unwrap().is_zero());
}

#[test]
fn test_decimal_scalar_null_handling() {
use crate::scalar::NumericOperator;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ pub enum NumericOperator {
Mul,
/// Binary element-wise division of two arrays or of two scalars.
Div,
/// Binary element-wise division that yields 0 instead of erroring when the right-hand-side
/// is a non-null zero. Integer overflow still errs, matching [`NumericOperator::Div`].
SafeDiv,
}

impl fmt::Display for NumericOperator {
Expand All @@ -33,6 +36,7 @@ impl From<NumericOperator> for crate::scalar_fn::fns::operators::Operator {
NumericOperator::Sub => crate::scalar_fn::fns::operators::Operator::Sub,
NumericOperator::Mul => crate::scalar_fn::fns::operators::Operator::Mul,
NumericOperator::Div => crate::scalar_fn::fns::operators::Operator::Div,
NumericOperator::SafeDiv => crate::scalar_fn::fns::operators::Operator::SafeDiv,
}
}
}
11 changes: 11 additions & 0 deletions vortex-array/src/scalar/typed_view/primitive/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use num_traits::CheckedAdd;
use num_traits::CheckedDiv;
use num_traits::CheckedMul;
use num_traits::CheckedSub;
use num_traits::Zero;
use vortex_error::VortexError;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
Expand Down Expand Up @@ -339,6 +340,9 @@ impl<'a> PrimitiveScalar<'a> {
NumericOperator::Sub => Some(lhs - rhs),
NumericOperator::Mul => Some(lhs * rhs),
NumericOperator::Div => Some(lhs / rhs),
NumericOperator::SafeDiv => {
Some(if rhs == P::zero() { P::zero() } else { lhs / rhs })
}
}
};
Some(Self { dtype: result_dtype, ptype, pvalue: value_or_null.map(PValue::from) })
Expand Down Expand Up @@ -373,6 +377,13 @@ impl<'a> PrimitiveScalar<'a> {
NumericOperator::Sub => lhs.checked_sub(&rhs).map(Some),
NumericOperator::Mul => lhs.checked_mul(&rhs).map(Some),
NumericOperator::Div => lhs.checked_div(&rhs).map(Some),
NumericOperator::SafeDiv => {
if rhs == P::zero() {
Some(Some(P::zero()))
} else {
lhs.checked_div(&rhs).map(Some)
}
}
},
};

Expand Down
70 changes: 70 additions & 0 deletions vortex-array/src/scalar/typed_view/primitive/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -734,3 +734,73 @@ fn test_f16_nans_equal() {
let nan3 = f16::from_f16(nan1).unwrap();
assert_eq!(nan1.to_bits(), nan3.to_bits(),);
}

#[test]
fn test_checked_binary_numeric_safe_div_integer() {
let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
let value1 = ScalarValue::Primitive(PValue::I32(20));
let value2 = ScalarValue::Primitive(PValue::I32(4));
let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap();
let scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap();

let result = scalar1
.checked_binary_numeric(&scalar2, NumericOperator::SafeDiv)
.unwrap();
assert_eq!(result.typed_value::<i32>(), Some(5));
}

#[test]
fn test_checked_binary_numeric_safe_div_integer_by_zero() {
let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
let value1 = ScalarValue::Primitive(PValue::I32(10));
let value2 = ScalarValue::Primitive(PValue::I32(0));
let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap();
let scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap();

let result = scalar1
.checked_binary_numeric(&scalar2, NumericOperator::SafeDiv)
.unwrap();
assert_eq!(result.typed_value::<i32>(), Some(0));
}

#[test]
fn test_checked_binary_numeric_safe_div_integer_overflow_still_errs() {
// i32::MIN / -1 overflows; SafeDiv preserves that behavior (returns None).
let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
let value1 = ScalarValue::Primitive(PValue::I32(i32::MIN));
let value2 = ScalarValue::Primitive(PValue::I32(-1));
let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap();
let scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap();

let result = scalar1.checked_binary_numeric(&scalar2, NumericOperator::SafeDiv);
assert!(result.is_none());
}

#[test]
fn test_checked_binary_numeric_safe_div_float_by_zero() {
// Float SafeDiv by zero yields 0.0, not Inf or NaN.
let dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
let value1 = ScalarValue::Primitive(PValue::F32(1.0));
let value2 = ScalarValue::Primitive(PValue::F32(0.0));
let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap();
let scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap();

let result = scalar1
.checked_binary_numeric(&scalar2, NumericOperator::SafeDiv)
.unwrap();
assert_eq!(result.typed_value::<f32>(), Some(0.0));
}

#[test]
fn test_checked_binary_numeric_safe_div_null_propagates() {
// Null on either side yields null.
let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
let null_scalar = PrimitiveScalar::try_new(&dtype, None).unwrap();
let value = ScalarValue::Primitive(PValue::I32(0));
let zero_scalar = PrimitiveScalar::try_new(&dtype, Some(&value)).unwrap();

let result = null_scalar
.checked_binary_numeric(&zero_scalar, NumericOperator::SafeDiv)
.unwrap();
assert_eq!(result.typed_value::<i32>(), None);
}
15 changes: 9 additions & 6 deletions vortex-array/src/scalar_fn/fns/binary/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ impl ScalarFnVTable for Binary {
&self,
op: &Operator,
args: &dyn ExecutionArgs,
_ctx: &mut ExecutionCtx,
ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
let lhs = args.get(0)?;
let rhs = args.get(1)?;
Expand All @@ -156,10 +156,11 @@ impl ScalarFnVTable for Binary {
Operator::Gte => execute_compare(&lhs, &rhs, CompareOperator::Gte),
Operator::And => execute_boolean(&lhs, &rhs, Operator::And),
Operator::Or => execute_boolean(&lhs, &rhs, Operator::Or),
Operator::Add => execute_numeric(&lhs, &rhs, NumericOperator::Add),
Operator::Sub => execute_numeric(&lhs, &rhs, NumericOperator::Sub),
Operator::Mul => execute_numeric(&lhs, &rhs, NumericOperator::Mul),
Operator::Div => execute_numeric(&lhs, &rhs, NumericOperator::Div),
Operator::Add => execute_numeric(&lhs, &rhs, NumericOperator::Add, ctx),
Operator::Sub => execute_numeric(&lhs, &rhs, NumericOperator::Sub, ctx),
Operator::Mul => execute_numeric(&lhs, &rhs, NumericOperator::Mul, ctx),
Operator::Div => execute_numeric(&lhs, &rhs, NumericOperator::Div, ctx),
Operator::SafeDiv => execute_numeric(&lhs, &rhs, NumericOperator::SafeDiv, ctx),
}
}

Expand Down Expand Up @@ -263,7 +264,9 @@ impl ScalarFnVTable for Binary {
lhs.stat_falsification(catalog)?,
rhs.stat_falsification(catalog)?,
)),
Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
Operator::Add | Operator::Sub | Operator::Mul | Operator::Div | Operator::SafeDiv => {
None
}
}
}

Expand Down
Loading
Loading