Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions vortex-tensor/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,12 @@ impl core::clone::Clone for vortex_tensor::scalar_fns::cosine_similarity::Cosine

pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::clone(&self) -> vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity

impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity

pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::deserialize(&self, _dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, session: &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts<Self>>

pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::serialize(&self, view: &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>

impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity

pub type vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::Options = vortex_array::scalar_fn::vtable::EmptyOptions
Expand Down Expand Up @@ -284,6 +290,12 @@ impl core::clone::Clone for vortex_tensor::scalar_fns::inner_product::InnerProdu

pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::clone(&self) -> vortex_tensor::scalar_fns::inner_product::InnerProduct

impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_tensor::scalar_fns::inner_product::InnerProduct

pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::deserialize(&self, _dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, session: &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts<Self>>

pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::serialize(&self, view: &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>

impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::inner_product::InnerProduct

pub type vortex_tensor::scalar_fns::inner_product::InnerProduct::Options = vortex_array::scalar_fn::vtable::EmptyOptions
Expand Down Expand Up @@ -322,6 +334,12 @@ impl core::clone::Clone for vortex_tensor::scalar_fns::l2_denorm::L2Denorm

pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::clone(&self) -> vortex_tensor::scalar_fns::l2_denorm::L2Denorm

impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_tensor::scalar_fns::l2_denorm::L2Denorm

pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::deserialize(&self, _dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, session: &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts<Self>>

pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::serialize(&self, view: &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>

impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::l2_denorm::L2Denorm

pub type vortex_tensor::scalar_fns::l2_denorm::L2Denorm::Options = vortex_array::scalar_fn::vtable::EmptyOptions
Expand Down Expand Up @@ -362,6 +380,12 @@ impl core::clone::Clone for vortex_tensor::scalar_fns::l2_norm::L2Norm

pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::clone(&self) -> vortex_tensor::scalar_fns::l2_norm::L2Norm

impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_tensor::scalar_fns::l2_norm::L2Norm

pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::deserialize(&self, _dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, session: &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts<Self>>

pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::serialize(&self, view: &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>

impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::l2_norm::L2Norm

pub type vortex_tensor::scalar_fns::l2_norm::L2Norm::Options = vortex_array::scalar_fn::vtable::EmptyOptions
Expand Down Expand Up @@ -444,6 +468,12 @@ impl core::clone::Clone for vortex_tensor::scalar_fns::sorf_transform::SorfTrans

pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::clone(&self) -> vortex_tensor::scalar_fns::sorf_transform::SorfTransform

impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_tensor::scalar_fns::sorf_transform::SorfTransform

pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::deserialize(&self, dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts<Self>>

pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::serialize(&self, view: &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>

impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::sorf_transform::SorfTransform

pub type vortex_tensor::scalar_fns::sorf_transform::SorfTransform::Options = vortex_tensor::scalar_fns::sorf_transform::SorfOptions
Expand Down
8 changes: 1 addition & 7 deletions vortex-tensor/src/encodings/turboquant/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ mod nullable;
mod roundtrip;
mod structural;

use std::sync::LazyLock;

use rand::SeedableRng;
use rand::rngs::StdRng;
use rand_distr::Distribution;
Expand All @@ -29,22 +27,18 @@ use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt;
use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
use vortex_array::dtype::extension::ExtDType;
use vortex_array::extension::EmptyMetadata;
use vortex_array::session::ArraySession;
use vortex_array::validity::Validity;
use vortex_buffer::BufferMut;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_session::VortexSession;

use crate::encodings::turboquant::TurboQuantConfig;
use crate::encodings::turboquant::turboquant_encode_unchecked;
use crate::scalar_fns::l2_denorm::L2Denorm;
use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm;
use crate::tests::SESSION;
use crate::vector::Vector;

static SESSION: LazyLock<VortexSession> =
LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

/// Create a FixedSizeListArray of random f32 vectors with the given validity.
fn make_fsl_with_validity(
num_rows: usize,
Expand Down
35 changes: 30 additions & 5 deletions vortex-tensor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
//! including unit vectors, spherical coordinates, and similarity measures such as cosine
//! similarity.

use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin;
use vortex_array::dtype::session::DTypeSessionExt;
use vortex_array::scalar_fn::session::ScalarFnSessionExt;
use vortex_array::session::ArraySessionExt;
use vortex_session::VortexSession;

use crate::fixed_shape::FixedShapeTensor;
Expand Down Expand Up @@ -34,9 +36,32 @@ pub fn initialize(session: &VortexSession) {
session.dtypes().register(Vector);
session.dtypes().register(FixedShapeTensor);

session.scalar_fns().register(CosineSimilarity);
session.scalar_fns().register(InnerProduct);
session.scalar_fns().register(L2Denorm);
session.scalar_fns().register(L2Norm);
session.scalar_fns().register(SorfTransform);
let session_fns = session.scalar_fns();
let session_arrays = session.arrays();

session_fns.register(CosineSimilarity);
session_fns.register(InnerProduct);
session_fns.register(L2Denorm);
session_fns.register(L2Norm);
session_fns.register(SorfTransform);

session_arrays.register(ScalarFnArrayPlugin::new(CosineSimilarity));
session_arrays.register(ScalarFnArrayPlugin::new(InnerProduct));
session_arrays.register(ScalarFnArrayPlugin::new(L2Denorm));
session_arrays.register(ScalarFnArrayPlugin::new(L2Norm));
session_arrays.register(ScalarFnArrayPlugin::new(SorfTransform));
}

#[cfg(test)]
mod tests {
use std::sync::LazyLock;

use vortex_array::session::ArraySession;
use vortex_session::VortexSession;

pub static SESSION: LazyLock<VortexSession> = LazyLock::new(|| {
let session = VortexSession::empty().with::<ArraySession>();
crate::initialize(&session);
session
});
}
85 changes: 79 additions & 6 deletions vortex-tensor/src/scalar_fns/cosine_similarity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ use vortex_array::IntoArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::ScalarFnArray;
use vortex_array::arrays::scalar_fn::ExactScalarFn;
use vortex_array::arrays::scalar_fn::ScalarFnArrayView;
use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts;
use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable;
use vortex_array::builtins::ArrayBuiltins;
use vortex_array::dtype::DType;
use vortex_array::dtype::Nullability;
Expand All @@ -25,11 +28,14 @@ use vortex_array::scalar_fn::ExecutionArgs;
use vortex_array::scalar_fn::ScalarFn;
use vortex_array::scalar_fn::ScalarFnId;
use vortex_array::scalar_fn::ScalarFnVTable;
use vortex_array::serde::ArrayChildren;
use vortex_array::validity::Validity;
use vortex_buffer::Buffer;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_session::VortexSession;

use crate::scalar_fns::inner_product::BinaryTensorOpMetadata;
use crate::scalar_fns::inner_product::InnerProduct;
use crate::scalar_fns::l2_denorm::L2Denorm;
use crate::scalar_fns::l2_denorm::try_build_constant_l2_denorm;
Expand Down Expand Up @@ -221,6 +227,37 @@ impl ScalarFnVTable for CosineSimilarity {
}
}

impl ScalarFnArrayVTable for CosineSimilarity {
fn serialize(
&self,
view: &ScalarFnArrayView<Self>,
_session: &VortexSession,
) -> VortexResult<Option<Vec<u8>>> {
Ok(Some(BinaryTensorOpMetadata::encode_from_view(view)?))
}

fn deserialize(
&self,
_dtype: &DType,
len: usize,
metadata: &[u8],
children: &dyn ArrayChildren,
session: &VortexSession,
) -> VortexResult<ScalarFnArrayParts<Self>> {
let reconstructed = BinaryTensorOpMetadata::decode_children(
metadata,
len,
children,
session,
"CosineSimilarity",
)?;
Ok(ScalarFnArrayParts {
options: EmptyOptions,
children: reconstructed,
})
}
}

impl CosineSimilarity {
/// Both sides are `L2Denorm`: treat the normalized children as authoritative, so
/// `cosine_similarity = dot(n_l, n_r)`.
Expand Down Expand Up @@ -292,31 +329,28 @@ impl CosineSimilarity {

#[cfg(test)]
mod tests {
use std::sync::LazyLock;

use rstest::rstest;
use vortex_array::ArrayPlugin;
use vortex_array::ArrayRef;
use vortex_array::IntoArray;
use vortex_array::VortexSessionExecute;
use vortex_array::arrays::MaskedArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::ScalarFnArray;
use vortex_array::session::ArraySession;
use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin;
use vortex_array::validity::Validity;
use vortex_error::VortexResult;
use vortex_session::VortexSession;

use crate::scalar_fns::cosine_similarity::CosineSimilarity;
use crate::scalar_fns::l2_denorm::L2Denorm;
use crate::tests::SESSION;
use crate::utils::test_helpers::assert_close;
use crate::utils::test_helpers::constant_tensor_array;
use crate::utils::test_helpers::constant_vector_array;
use crate::utils::test_helpers::tensor_array;
use crate::utils::test_helpers::vector_array;

static SESSION: LazyLock<VortexSession> =
LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

/// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec<f64>`.
fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult<Vec<f64>> {
let scalar_fn = CosineSimilarity::new().erased();
Expand Down Expand Up @@ -693,4 +727,43 @@ mod tests {
);
Ok(())
}

#[rstest]
#[case::vector(
vector_array(3, &[1.0, 0.0, 0.0, 3.0, 4.0, 0.0]).unwrap(),
vector_array(3, &[0.0, 1.0, 0.0, 3.0, 4.0, 0.0]).unwrap(),
2,
)]
#[case::fixed_shape_tensor(
tensor_array(&[2], &[1.0, 0.0, 3.0, 4.0]).unwrap(),
tensor_array(&[2], &[0.0, 1.0, 3.0, 4.0]).unwrap(),
2,
)]
fn serde_round_trip(
#[case] lhs: ArrayRef,
#[case] rhs: ArrayRef,
#[case] len: usize,
) -> VortexResult<()> {
let original = CosineSimilarity::try_new_array(lhs.clone(), rhs.clone(), len)?.into_array();

let plugin = ScalarFnArrayPlugin::new(CosineSimilarity);
let metadata = plugin
.serialize(&original, &SESSION)?
.expect("CosineSimilarity serialize must produce metadata");

let children = vec![lhs, rhs];
let recovered = plugin.deserialize(
original.dtype(),
original.len(),
&metadata,
&[],
&children,
&SESSION,
)?;

assert_eq!(recovered.dtype(), original.dtype());
assert_eq!(recovered.len(), original.len());
assert_eq!(recovered.encoding_id(), original.encoding_id());
Ok(())
}
}
Loading
Loading