Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 85 additions & 18 deletions vortex-tensor/src/scalar_fns/cosine_similarity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ use vortex_array::arrays::ScalarFnArray;
use vortex_array::arrays::scalar_fn::ScalarFnArrayView;
use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts;
use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable;
use vortex_array::builtins::ArrayBuiltins;
use vortex_array::dtype::DType;
use vortex_array::dtype::Nullability;
use vortex_array::expr::Expression;
Expand All @@ -28,7 +27,6 @@ use vortex_array::scalar_fn::ScalarFnId;
use vortex_array::scalar_fn::ScalarFnVTable;
use vortex_array::scalar_fn::TypedScalarFnInstance;
use vortex_array::serde::ArrayChildren;
use vortex_array::validity::Validity;
use vortex_buffer::Buffer;
use vortex_error::VortexResult;
use vortex_session::VortexSession;
Expand Down Expand Up @@ -144,7 +142,7 @@ impl ScalarFnVTable for CosineSimilarity {
// Take any L2Denorm-wrapped fast path that applies.
match DenormOrientation::classify(&lhs_ref, &rhs_ref) {
DenormOrientation::Both { lhs, rhs } => {
return self.execute_both_denorm(lhs, rhs, len);
return self.execute_both_denorm(lhs, rhs, len, ctx);
}
DenormOrientation::One { denorm, plain } => {
return self.execute_one_denorm(denorm, plain, len, ctx);
Expand Down Expand Up @@ -244,22 +242,39 @@ impl CosineSimilarity {
lhs_ref: &ArrayRef,
rhs_ref: &ArrayRef,
len: usize,
ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?;

let (normalized_l, _) = extract_l2_denorm_children(lhs_ref);
let (normalized_r, _) = extract_l2_denorm_children(rhs_ref);
let (normalized_l, norms_l) = extract_l2_denorm_children(lhs_ref);
let (normalized_r, norms_r) = extract_l2_denorm_children(rhs_ref);

// `L2Denorm` makes the normalized children authoritative, so their dot product is the
// cosine similarity even for lossy storage wrappers.
let dot = InnerProduct::try_new_array(normalized_l, normalized_r, len)?.into_array();

if !matches!(validity, Validity::NonNullable) {
// Masking always changes the nullability to nullable.
dot.mask(validity.to_array(len))
} else {
Ok(dot)
}
// cosine similarity even for lossy storage wrappers, except that a zero stored norm still
// represents a zero vector.
let dot: PrimitiveArray = InnerProduct::try_new_array(normalized_l, normalized_r, len)?
.into_array()
.execute(ctx)?;
let norms_l: PrimitiveArray = norms_l.execute(ctx)?;
let norms_r: PrimitiveArray = norms_r.execute(ctx)?;

match_each_float_ptype!(dot.ptype(), |T| {
let dots = dot.as_slice::<T>();
let norms_l = norms_l.as_slice::<T>();
let norms_r = norms_r.as_slice::<T>();
let buffer: Buffer<T> = (0..len)
.map(|i| {
if norms_l[i] == T::zero() || norms_r[i] == T::zero() {
T::zero()
} else {
dots[i]
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

definitely premature optimization, but I wonder if rustc is able to see that this can be written in a branchless way.

let either_is_zero = norms_l[i] == T::zero() || norms_r[i] == T::zero();
T::from(!either_is_zero) * dots[i]

})
.collect();

// SAFETY: The buffer length equals `len`, which matches the source validity length.
Ok(unsafe { PrimitiveArray::new_unchecked(buffer, validity) }.into_array())
})
}

/// One side is `L2Denorm`: treat the normalized child as authoritative, so
Expand All @@ -275,25 +290,28 @@ impl CosineSimilarity {
) -> VortexResult<ArrayRef> {
let validity = denorm_ref.validity()?.and(plain_ref.validity()?)?;

let (normalized, _) = extract_l2_denorm_children(denorm_ref);
let (normalized, denorm_norms) = extract_l2_denorm_children(denorm_ref);

let dot_arr = InnerProduct::try_new_array(normalized, plain_ref.clone(), len)?;
let dot: PrimitiveArray = dot_arr.into_array().execute(ctx)?;

let denorm_norms: PrimitiveArray = denorm_norms.execute(ctx)?;

let norm_arr = L2Norm::try_new_array(plain_ref.clone(), len)?;
let plain_norm: PrimitiveArray = norm_arr.into_array().execute(ctx)?;

// TODO(connor): Ideally we would have a `SafeDiv` binary numeric operation.
// TODO(connor): This can be written in a more SIMD-friendly manner.
match_each_float_ptype!(dot.ptype(), |T| {
let dots = dot.as_slice::<T>();
let norms = plain_norm.as_slice::<T>();
let denorm_norms = denorm_norms.as_slice::<T>();
let plain_norms = plain_norm.as_slice::<T>();
let buffer: Buffer<T> = (0..len)
.map(|i| {
if norms[i] == T::zero() {
if denorm_norms[i] == T::zero() || plain_norms[i] == T::zero() {
T::zero()
} else {
dots[i] / norms[i]
dots[i] / plain_norms[i]
}
})
.collect();
Expand Down Expand Up @@ -596,6 +614,55 @@ mod tests {
Ok(())
}

#[test]
fn both_denorm_lossy_zero_stored_norm_returns_zero() -> VortexResult<()> {
    // Models lossy storage (e.g. TurboQuant): the stored norm stays authoritative even
    // when the decoded normalized child is physically nonzero. A stored norm of `0.0`
    // denotes a zero vector, so that row's cosine similarity must be `0.0` even though
    // the normalized children have a nonzero dot product.
    let left_normalized = tensor_array(&[2], &[0.6, 0.8])?;
    let left_norms = PrimitiveArray::from_iter([0.0f64]).into_array();
    // SAFETY: This test intentionally violates the unit-norm invariant by pairing a
    // nonzero normalized row with a stored norm of `0.0`, mimicking lossy storage.
    let lhs =
        unsafe { L2Denorm::new_array_unchecked(left_normalized, left_norms, 1)? }.into_array();

    let right_normalized = tensor_array(&[2], &[0.6, 0.8])?;
    let right_norms = PrimitiveArray::from_iter([0.0f64]).into_array();
    // SAFETY: Same deliberate invariant violation as the lhs operand above.
    let rhs =
        unsafe { L2Denorm::new_array_unchecked(right_normalized, right_norms, 1)? }.into_array();

    // Even though `dot(normalized_l, normalized_r) = 1.0`, both authoritative stored
    // norms are `0.0`, so the cosine similarity must come out as `0.0`.
    assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.0]);
    Ok(())
}

#[test]
fn one_side_denorm_lossy_zero_stored_norm_returns_zero() -> VortexResult<()> {
    // Models lossy storage (e.g. TurboQuant): the stored norm on the denorm side is
    // authoritative even though its decoded normalized child is physically nonzero. The
    // plain side is an ordinary nonzero tensor with positive norm, yet the result must
    // still be `0.0` because the authoritative stored norm on the denorm side is `0.0`.
    let normalized_child = tensor_array(&[2], &[0.6, 0.8])?;
    let stored_norms = PrimitiveArray::from_iter([0.0f64]).into_array();
    // SAFETY: This test deliberately pairs a nonzero normalized row with a stored norm of
    // `0.0`, mimicking lossy storage where the stored norm is the source of truth.
    let denorm =
        unsafe { L2Denorm::new_array_unchecked(normalized_child, stored_norms, 1)? }.into_array();

    let plain = tensor_array(&[2], &[1.0, 0.0])?;

    // Denorm operand on the lhs: classified as `One { denorm: lhs, plain: rhs }`.
    assert_close(
        &eval_cosine_similarity(denorm.clone(), plain.clone(), 1)?,
        &[0.0],
    );

    // Denorm operand on the rhs: classified as `One { denorm: rhs, plain: lhs }`. The
    // zero-norm guard must fire regardless of which side carries the denorm wrapper.
    assert_close(&eval_cosine_similarity(plain, denorm, 1)?, &[0.0]);
    Ok(())
}

#[test]
fn constant_lhs_matches_plain_tensor() -> VortexResult<()> {
// The constant query `[1, 2, 2]` has norm 3, so its normalized form is `[1/3, 2/3, 2/3]`.
Expand Down
Loading