From 7b9846c98890b67d24930699f6845b7aa511e5e3 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Wed, 15 Apr 2026 14:10:42 -0400 Subject: [PATCH] demo tq basic search with serialization Signed-off-by: Connor Tsui --- Cargo.lock | 3 + vortex-bench/Cargo.toml | 14 +- vortex-bench/src/vector_dataset/convert.rs | 503 ++++++++++++++++++++ vortex-bench/src/vector_dataset/mod.rs | 2 + vortex/Cargo.toml | 6 + vortex/examples/tracing_vortex.rs | 7 +- vortex/examples/turboquant_vector_search.rs | 395 +++++++++++++++ 7 files changed, 922 insertions(+), 8 deletions(-) create mode 100644 vortex-bench/src/vector_dataset/convert.rs create mode 100644 vortex/examples/turboquant_vector_search.rs diff --git a/Cargo.lock b/Cargo.lock index 9a31e702e77..e5f2b5d254c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10056,6 +10056,7 @@ dependencies = [ "arrow-array 58.0.0", "codspeed-divan-compat", "fastlanes", + "futures", "mimalloc", "parquet 58.0.0", "paste", @@ -10068,6 +10069,7 @@ dependencies = [ "vortex", "vortex-alp", "vortex-array", + "vortex-bench", "vortex-btrblocks", "vortex-buffer", "vortex-bytebool", @@ -10230,6 +10232,7 @@ dependencies = [ "url", "uuid", "vortex", + "vortex-tensor", ] [[package]] diff --git a/vortex-bench/Cargo.toml b/vortex-bench/Cargo.toml index 8b2bb6efe25..e170a6552a8 100644 --- a/vortex-bench/Cargo.toml +++ b/vortex-bench/Cargo.toml @@ -17,6 +17,14 @@ version = { workspace = true } workspace = true [dependencies] +vortex = { workspace = true, features = [ + "object_store", + "files", + "tokio", + "zstd", +] } +vortex-tensor = { workspace = true } # TODO(connor): In the future, this might be inside vortex. 
+ anyhow = { workspace = true } arrow-array = { workspace = true } arrow-schema = { workspace = true } @@ -57,12 +65,6 @@ tracing-subscriber = { workspace = true, features = [ ] } url = { workspace = true } uuid = { workspace = true, features = ["v4"] } -vortex = { workspace = true, features = [ - "object_store", - "files", - "tokio", - "zstd", -] } [features] unstable_encodings = ["vortex/unstable_encodings"] diff --git a/vortex-bench/src/vector_dataset/convert.rs b/vortex-bench/src/vector_dataset/convert.rs new file mode 100644 index 00000000000..f0c13ff33aa --- /dev/null +++ b/vortex-bench/src/vector_dataset/convert.rs @@ -0,0 +1,503 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +// TODO(connor): Should we re-export this through `conversions.rs`? + +use vortex::array::ArrayRef; +use vortex::array::IntoArray; +use vortex::array::arrays::Chunked; +use vortex::array::arrays::ChunkedArray; +use vortex::array::arrays::ExtensionArray; +use vortex::array::arrays::FixedSizeListArray; +use vortex::array::arrays::List; +use vortex::array::arrays::ListView; +use vortex::array::arrays::Primitive; +use vortex::array::arrays::PrimitiveArray; +use vortex::array::arrays::chunked::ChunkedArrayExt; +use vortex::array::arrays::list::ListArrayExt; +use vortex::array::arrays::listview::recursive_list_from_list_view; +use vortex::array::validity::Validity; +use vortex::dtype::DType; +use vortex::dtype::extension::ExtDType; +use vortex::error::VortexExpect; +use vortex::error::VortexResult; +use vortex::error::vortex_bail; +use vortex::error::vortex_err; +use vortex::extension::EmptyMetadata; +use vortex_tensor::vector::Vector; + +/// Rewrap a list-of-float column as a [`vortex_tensor::vector::Vector`] extension array. +/// +/// Parquet has no fixed-size list logical type, so an embedding column ingested via +/// `parquet_to_vortex_chunks` arrives as `List` (or `List`) even when every row has the +/// same length. 
+/// +/// This helper validates that every list in `input` has the same length `D` and reconstructs the +/// column as an `Extension(FixedSizeList)` array, which is the type expected by the +/// vector search scalar functions in `vortex-tensor`. +/// +/// The input may be either a single [`ListView`] array or a [`Chunked`] array of lists (the common +/// case after `parquet_to_vortex_chunks`). Chunked inputs are converted chunk-by-chunk and +/// reassembled as a [`ChunkedArray`] of `Extension`. We also convert [`ListView`] to +/// [`List`] so that we know all elements are contiguous (this might be slow). +/// +/// # Errors +/// +/// Returns an error if: +/// - `input` is not a `ListView`, `List`, or `Chunked` array. +/// - The element type is not a float primitive (`f16`, `f32`, or `f64`). +/// - A nullable element dtype (`List`) is accepted as long as the runtime validity is +/// `NonNullable` or `AllValid` since parquet has no non-nullable-element list logical type, so +/// arrow-rs always marks list-of-float element fields as nullable on read regardless of whether +/// any element is actually missing. In that case the elements are rewrapped as non-nullable +/// before being embedded in the FSL. +/// - The element dtype is nullable *and* any element is actually null (i.e., `Validity::AllInvalid` +/// or any `Validity::Array` mask). Vector extension elements must be non-null, and that is +/// verified on construction. +/// - Any row has a different length than the first row. +/// - The list validity is nullable (vector elements cannot be null at the row level). +/// - The input has zero rows (the dimension cannot be inferred from empty input). 
+pub fn list_to_vector_ext(input: ArrayRef) -> VortexResult { + if let Some(chunked) = input.as_opt::() { + let converted: Vec = chunked + .iter_chunks() + .map(|chunk| list_to_vector_ext(chunk.clone())) + .collect::>()?; + + let Some(first) = converted.first() else { + vortex_bail!("list_to_vector_ext: chunked input has no chunks"); + }; + + let dtype = first.dtype().clone(); + return Ok(ChunkedArray::try_new(converted, dtype)?.into_array()); + } + + // `parquet_to_vortex_chunks` produces `ListView` arrays for list columns by default; + // materialize them into a flat `List` representation before we validate offsets. + if input.as_opt::().is_some() { + let flat = recursive_list_from_list_view(input)?; + return list_to_vector_ext(flat); + } + + let Some(list) = input.as_opt::() else { + vortex_bail!( + "list_to_vector_ext: expected a List array, got dtype {}", + input.dtype() + ); + }; + + if !matches!( + list.list_validity(), + Validity::NonNullable | Validity::AllValid + ) { + vortex_bail!( + "list_to_vector_ext: list rows must be non-nullable for Vector extension wrapping" + ); + } + + let element_dtype = list.element_dtype().clone(); + let DType::Primitive(ptype, elem_nullability) = &element_dtype else { + vortex_bail!( + "list_to_vector_ext: element dtype must be a primitive float, got {}", + element_dtype + ); + }; + if !ptype.is_float() { + vortex_bail!( + "list_to_vector_ext: element type must be float (f16/f32/f64), got {}", + ptype + ); + } + + // Extract the flat elements buffer up front: the nullable-handling branch below + // needs to inspect runtime validity before we can decide whether to rewrap it. + let raw_elements = list.sliced_elements()?; + + let num_rows = input.len(); + if num_rows == 0 { + vortex_bail!("list_to_vector_ext: cannot infer vector dimension from empty input"); + } + + // Walk the offsets array once, reusing the previous iteration's `end` as the + // next iteration's `start`. 
Each `offset_at` call goes through + // `ListArrayExt::offset_at`, which has a fast path when the offsets child is a + // `Primitive` array (direct slice index). That's the common case after + // `parquet_to_vortex_chunks`, so for a 100K-row column we do ~100K primitive + // slice indexes rather than 200K. The loop body is O(1) either way. + let mut prev_end = list.offset_at(0)?; + let first_end = list.offset_at(1)?; + + let dim = first_end.checked_sub(prev_end).ok_or_else(|| { + vortex_err!("list_to_vector_ext: offsets are not monotonically increasing") + })?; + if dim == 0 { + vortex_bail!("list_to_vector_ext: first row has zero elements"); + } + + prev_end = first_end; + + for i in 1..num_rows { + let end = list.offset_at(i + 1)?; + + let row_len = end + .checked_sub(prev_end) + .vortex_expect("list offsets must be monotonically increasing"); + if row_len != dim { + vortex_bail!( + "list_to_vector_ext: row {} has length {} but expected {}", + i, + row_len, + dim + ); + } + + prev_end = end; + } + + let expected_elements = num_rows + .checked_mul(dim) + .ok_or_else(|| vortex_err!("list_to_vector_ext: num_rows * dim overflows usize"))?; + if raw_elements.len() != expected_elements { + vortex_bail!( + "list_to_vector_ext: elements buffer has length {} but expected {}", + raw_elements.len(), + expected_elements + ); + } + + // Parquet has no non-nullable-element list logical type, so arrow-rs marks every + // `List`'s element field as nullable on read regardless of what the writer intended. + // That propagates through `DType::from_arrow`, so every real embedding parquet file arrives + // shaped as `List` even when every value is present. A nullable element dtype is + // losslessly convertible to a non-nullable FSL as long as the runtime validity is + // `NonNullable`/`AllValid`; we must only reject when a real null is present. 
+ let elements = if elem_nullability.is_nullable() { + let primitive = raw_elements.as_opt::().ok_or_else(|| { + vortex_err!( + "list_to_vector_ext: expected nullable-float elements to downcast to \ + Primitive, got dtype {}", + raw_elements.dtype() + ) + })?; + match primitive.validity()? { + Validity::NonNullable | Validity::AllValid => { + // `to_host_sync` is a no-op for host-resident buffers, so this is a + // metadata change (rebuilding the array with a non-nullable dtype), + // not a data copy. + let byte_buffer = primitive.buffer_handle().to_host_sync(); + PrimitiveArray::from_byte_buffer(byte_buffer, *ptype, Validity::NonNullable) + .into_array() + } + Validity::AllInvalid => { + vortex_bail!( + "list_to_vector_ext: list has nullable element dtype with all-invalid \ + elements; Vector extension elements must be non-null" + ); + } + Validity::Array(_) => { + vortex_bail!( + "list_to_vector_ext: list has nullable element dtype with one or more \ + actual null elements; Vector extension elements must be non-null" + ); + } + } + } else { + raw_elements + }; + + let dim_u32 = u32::try_from(dim) + .map_err(|_| vortex_err!("list_to_vector_ext: dimension {dim} does not fit in u32"))?; + + // Finally, construct the `FixedSizeListArray` and wrap it in a Vector array. 
+ let fsl = FixedSizeListArray::try_new(elements, dim_u32, Validity::NonNullable, num_rows)?; + let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) +} + +#[cfg(test)] +mod tests { + use vortex::array::Array; + use vortex::array::ArrayRef; + use vortex::array::IntoArray; + use vortex::array::arrays::BoolArray; + use vortex::array::arrays::ChunkedArray; + use vortex::array::arrays::Extension; + use vortex::array::arrays::List; + use vortex::array::arrays::ListViewArray; + use vortex::array::arrays::PrimitiveArray; + use vortex::array::arrays::extension::ExtensionArrayExt; + use vortex::array::validity::Validity; + use vortex::buffer::BufferMut; + use vortex::dtype::DType; + + use super::list_to_vector_ext; + + /// Build a `List` whose elements carry the given [`Validity`]. Passing + /// `Validity::NonNullable` produces a `List`; any other variant produces + /// a `List`, matching the shape `parquet_to_vortex_chunks` produces for + /// embedding columns after arrow-rs' canonicalization. 
+ fn list_f32_with_element_validity( + values: &[f32], + dim: usize, + element_validity: Validity, + ) -> ArrayRef { + assert_eq!( + values.len() % dim, + 0, + "values.len() must be a multiple of dim" + ); + let num_rows = values.len() / dim; + let elements = PrimitiveArray::new::( + BufferMut::::from_iter(values.iter().copied()).freeze(), + element_validity, + ) + .into_array(); + let mut offsets_buf = BufferMut::::with_capacity(num_rows + 1); + for i in 0..=num_rows { + offsets_buf.push(i32::try_from(i * dim).unwrap()); + } + let offsets = + PrimitiveArray::new::(offsets_buf.freeze(), Validity::NonNullable).into_array(); + Array::::new(elements, offsets, Validity::NonNullable).into_array() + } + + fn list_f32(rows: &[&[f32]]) -> ArrayRef { + let mut elements = BufferMut::::with_capacity(rows.iter().map(|r| r.len()).sum()); + let mut offsets = BufferMut::::with_capacity(rows.len() + 1); + offsets.push(0); + for row in rows { + for &v in row.iter() { + elements.push(v); + } + offsets.push(i32::try_from(elements.len()).unwrap()); + } + + let elements_array = + PrimitiveArray::new::(elements.freeze(), Validity::NonNullable).into_array(); + let offsets_array = + PrimitiveArray::new::(offsets.freeze(), Validity::NonNullable).into_array(); + Array::::new(elements_array, offsets_array, Validity::NonNullable).into_array() + } + + #[test] + fn uniform_list_becomes_vector_extension() { + let list = list_f32(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]]); + let wrapped = list_to_vector_ext(list).unwrap(); + assert_eq!(wrapped.len(), 3); + let ext = wrapped.as_opt::().expect("returns Extension"); + assert!(matches!( + ext.storage_array().dtype(), + DType::FixedSizeList(_, 3, _) + )); + } + + #[test] + fn mismatched_row_length_is_rejected() { + let list = list_f32(&[&[1.0, 2.0, 3.0], &[4.0, 5.0]]); + let err = list_to_vector_ext(list).unwrap_err().to_string(); + assert!( + err.contains("row 1 has length 2 but expected 3"), + "unexpected error: {err}", + ); + } + + 
#[test] + fn non_list_input_is_rejected() { + let primitive = PrimitiveArray::new::( + BufferMut::::from_iter([1.0f32, 2.0, 3.0]).freeze(), + Validity::NonNullable, + ) + .into_array(); + let err = list_to_vector_ext(primitive).unwrap_err().to_string(); + assert!( + err.contains("expected a List array"), + "unexpected error: {err}" + ); + } + + #[test] + fn empty_input_is_rejected() { + let list = list_f32(&[]); + let err = list_to_vector_ext(list).unwrap_err().to_string(); + assert!( + err.contains("cannot infer vector dimension from empty input"), + "unexpected error: {err}", + ); + } + + /// Build a `ListView` whose every row is a length-`dim` slice of the flattened + /// `values` buffer. This shape matches what `parquet_to_vortex_chunks` produces for + /// embedding columns after arrow-rs' canonicalization, and exercises the + /// `list_to_vector_ext` fast-path that collapses `ListView` → `List` before + /// validating offsets. + fn list_view_f32(dim: usize, rows: &[&[f32]]) -> ArrayRef { + let mut values = BufferMut::::with_capacity(rows.len() * dim); + for row in rows { + assert_eq!(row.len(), dim); + for &v in row.iter() { + values.push(v); + } + } + let elements = + PrimitiveArray::new::(values.freeze(), Validity::NonNullable).into_array(); + + let dim_i32 = i32::try_from(dim).unwrap(); + let num_rows = rows.len(); + + let mut offsets_buf = BufferMut::::with_capacity(num_rows); + for i in 0..num_rows { + offsets_buf.push(i32::try_from(i).unwrap() * dim_i32); + } + let offsets = + PrimitiveArray::new::(offsets_buf.freeze(), Validity::NonNullable).into_array(); + + let mut sizes_buf = BufferMut::::with_capacity(num_rows); + for _ in 0..num_rows { + sizes_buf.push(dim_i32); + } + let sizes = + PrimitiveArray::new::(sizes_buf.freeze(), Validity::NonNullable).into_array(); + + ListViewArray::try_new(elements, offsets, sizes, Validity::NonNullable) + .unwrap() + .into_array() + } + + #[test] + fn list_view_input_is_rewrapped_as_vector_extension() { + // Simulates 
the post-parquet-ingest shape: the `emb` column arrives as a + // ListView, not a List. `list_to_vector_ext` must materialize it via + // `recursive_list_from_list_view` and then validate offsets on the flattened + // `List` form. + let list_view = list_view_f32(3, &[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]); + let wrapped = list_to_vector_ext(list_view).unwrap(); + assert_eq!(wrapped.len(), 2); + let ext = wrapped.as_opt::().expect("returns Extension"); + assert!(matches!( + ext.storage_array().dtype(), + DType::FixedSizeList(_, 3, _) + )); + } + + #[test] + fn all_invalid_list_validity_is_rejected() { + // A list with `Validity::AllInvalid` means every row is null. The Vector + // extension type requires non-nullable elements at the FSL level, so we + // must reject this input rather than silently dropping the validity mask. + let elements = PrimitiveArray::new::( + BufferMut::::from_iter([1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).freeze(), + Validity::NonNullable, + ) + .into_array(); + let offsets = PrimitiveArray::new::( + BufferMut::::from_iter([0i32, 3, 6]).freeze(), + Validity::NonNullable, + ) + .into_array(); + let list = Array::::new(elements, offsets, Validity::AllInvalid).into_array(); + + let err = list_to_vector_ext(list).unwrap_err().to_string(); + assert!( + err.contains("list rows must be non-nullable"), + "unexpected error: {err}" + ); + } + + #[test] + fn non_float_element_type_is_rejected() { + // Build a List. 
+ let elements = PrimitiveArray::new::( + BufferMut::::from_iter([1i32, 2, 3, 4]).freeze(), + Validity::NonNullable, + ) + .into_array(); + let offsets = PrimitiveArray::new::( + BufferMut::::from_iter([0i32, 2, 4]).freeze(), + Validity::NonNullable, + ) + .into_array(); + let list = Array::::new(elements, offsets, Validity::NonNullable).into_array(); + + let err = list_to_vector_ext(list).unwrap_err().to_string(); + assert!( + err.contains("element type must be float"), + "unexpected error: {err}", + ); + } + + #[test] + fn nullable_elements_with_real_nulls_are_rejected() { + // A `List` whose elements carry a real `Validity::Array` mask with + // at least one `false` bit has one or more actually-missing values. The + // rejection here is about runtime nulls, not dtype metadata: a nullable + // element dtype with all-valid runtime validity is accepted (see + // `nullable_element_dtype_with_all_valid_elements_is_accepted`), because + // parquet-ingested embeddings always arrive shaped that way even when + // every value is present. A real null, on the other hand, cannot be + // represented in the Vector extension FSL and must be rejected rather + // than silently dropped. + let element_validity = Validity::Array( + BoolArray::from_iter([true, true, false, true, true, true]).into_array(), + ); + let list = + list_f32_with_element_validity(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, element_validity); + + let err = list_to_vector_ext(list).unwrap_err().to_string(); + assert!( + err.contains("one or more actual null elements"), + "unexpected error: {err}" + ); + } + + #[test] + fn nullable_element_dtype_with_all_valid_elements_is_accepted() { + // This is the regression test for the Cohere parquet case: every real + // VectorDBBench parquet file arrives as `List` with + // `Validity::AllValid` elements because parquet has no non-nullable + // list-element logical type and arrow-rs propagates the nullable bit + // through `DType::from_arrow`. 
`list_to_vector_ext` must accept this + // shape by rewrapping the elements as non-nullable before building the + // FSL, rather than rejecting outright on the dtype metadata. + let list = + list_f32_with_element_validity(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, Validity::AllValid); + + let wrapped = list_to_vector_ext(list).unwrap(); + assert_eq!(wrapped.len(), 2); + let ext = wrapped.as_opt::().expect("returns Extension"); + assert!(matches!( + ext.storage_array().dtype(), + DType::FixedSizeList(_, 3, _) + )); + } + + #[test] + fn nullable_element_dtype_with_all_invalid_elements_is_rejected() { + // A `List` whose elements are `Validity::AllInvalid` means every + // value is missing. Rewrapping as non-nullable would silently drop the + // validity and produce bogus vectors, so this must be rejected. + let list = list_f32_with_element_validity( + &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + 3, + Validity::AllInvalid, + ); + + let err = list_to_vector_ext(list).unwrap_err().to_string(); + assert!( + err.contains("all-invalid elements"), + "unexpected error: {err}" + ); + } + + #[test] + fn chunked_input_with_mixed_dimensions_returns_error() { + let dim_three = list_f32(&[&[1.0, 2.0, 3.0]]); + let dim_two = list_f32(&[&[4.0, 5.0]]); + let chunked = + ChunkedArray::try_new(vec![dim_three.clone(), dim_two], dim_three.dtype().clone()) + .unwrap() + .into_array(); + + let err = list_to_vector_ext(chunked).unwrap_err().to_string(); + assert!(err.contains("Mismatched types"), "unexpected error: {err}"); + } +} diff --git a/vortex-bench/src/vector_dataset/mod.rs b/vortex-bench/src/vector_dataset/mod.rs index fe1e42a68d0..d826aa9fc5d 100644 --- a/vortex-bench/src/vector_dataset/mod.rs +++ b/vortex-bench/src/vector_dataset/mod.rs @@ -22,12 +22,14 @@ //! into per-flavor `.vortex` files, after which the scan driver re-opens those files per iteration. 
mod catalog; +mod convert; mod download; mod layout; mod paths; pub use catalog::ALL_VECTOR_DATASETS; pub use catalog::VectorDataset; +pub use convert::list_to_vector_ext; pub use download::DatasetPaths; pub use download::download; pub use download::neighbors_url; diff --git a/vortex/Cargo.toml b/vortex/Cargo.toml index f042f568c11..982127a4035 100644 --- a/vortex/Cargo.toml +++ b/vortex/Cargo.toml @@ -54,6 +54,7 @@ anyhow = { workspace = true } arrow-array = { workspace = true } divan = { workspace = true } fastlanes = { workspace = true } +futures = { workspace = true } mimalloc = { workspace = true } parquet = { workspace = true } paste = { workspace = true } @@ -64,6 +65,7 @@ tokio = { workspace = true, features = ["full"] } tracing = { workspace = true } tracing-subscriber = { workspace = true } vortex = { path = ".", features = ["tokio"] } +vortex-bench = { workspace = true, features = ["unstable_encodings"] } vortex-tensor = { workspace = true } [features] @@ -83,6 +85,10 @@ unstable_encodings = [ "vortex-zstd?/unstable_encodings", ] +[[example]] +name = "turboquant_vector_search" +required-features = ["files", "tokio", "unstable_encodings"] + [[bench]] name = "single_encoding_throughput" harness = false diff --git a/vortex/examples/tracing_vortex.rs b/vortex/examples/tracing_vortex.rs index 7fd0267988e..f1227e2e017 100644 --- a/vortex/examples/tracing_vortex.rs +++ b/vortex/examples/tracing_vortex.rs @@ -91,8 +91,11 @@ async fn main() -> Result<(), Box> { Ok(()) } -/// Simulates application activity with various log levels and spans -#[expect(clippy::cognitive_complexity)] +/// Simulates application activity with various log levels and spans. 
+#[allow( + clippy::cognitive_complexity, + reason = "tracing sometimes triggers this" +)] async fn simulate_application_activity(user_id: u32) { // Simulate HTTP request handling let request_span = span!( diff --git a/vortex/examples/turboquant_vector_search.rs b/vortex/examples/turboquant_vector_search.rs new file mode 100644 index 00000000000..c4302e0c5ad --- /dev/null +++ b/vortex/examples/turboquant_vector_search.rs @@ -0,0 +1,395 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant vector-search roundtrip on a vector-embedding dataset. +//! +//! Load a parquet dataset (cohere-small), wrap the `emb` column as a `Vector` +//! extension, compress with BtrBlocks + TurboQuant, write to an in-memory Vortex file, then read +//! the file back twice: +//! +//! 1. plain scan — decode to canonical `FixedSizeList` and verify the per-element diff +//! against the original. TurboQuant is lossy, so we only check the reconstructed values are +//! within a tolerance. +//! 2. scan with a pushed-down cosine-similarity filter `cosine_similarity(emb, query) > thresh`. +//! The `CosineSimilarity` scalar fn is expressed directly as a filter `Expression`, so row +//! selection happens inside the scan rather than after materialization. +//! +//! The parquet file is cached under `vortex-bench/data//` after the first download. Run +//! with: +//! +//! ```sh +//! cargo run --example turboquant_vector_search \ +//! -p vortex --features unstable_encodings --release +//! 
``` + +use std::path::PathBuf; +use std::time::Instant; + +use anyhow::Result; +use anyhow::bail; +use anyhow::ensure; +use futures::TryStreamExt; +use vortex::VortexSessionDefault; +use vortex::array::ArrayRef; +use vortex::array::IntoArray; +use vortex::array::VortexSessionExecute; +use vortex::array::arrays::ChunkedArray; +use vortex::array::arrays::ExtensionArray; +use vortex::array::arrays::FixedSizeListArray; +use vortex::array::arrays::PrimitiveArray; +use vortex::array::arrays::StructArray; +use vortex::array::arrays::extension::ExtensionArrayExt; +use vortex::array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex::array::arrays::struct_::StructArrayExt; +use vortex::array::expr::col; +use vortex::array::expr::gt; +use vortex::array::expr::lit; +use vortex::array::extension::EmptyMetadata; +use vortex::array::scalar::Scalar; +use vortex::array::scalar_fn::EmptyOptions; +use vortex::array::scalar_fn::ScalarFnVTable; +use vortex::array::scalar_fn::ScalarFnVTableExt; +use vortex::buffer::ByteBuffer; +use vortex::buffer::ByteBufferMut; +use vortex::dtype::DType; +use vortex::dtype::Nullability; +use vortex::dtype::PType; +use vortex::file::ALLOWED_ENCODINGS; +use vortex::file::OpenOptionsSessionExt; +use vortex::file::WriteOptionsSessionExt; +use vortex::file::WriteStrategyBuilder; +use vortex::io::session::RuntimeSessionExt; +use vortex::session::VortexSession; +use vortex_array::ExecutionCtx; +use vortex_array::builtins::ArrayBuiltins; +use vortex_bench::conversions::parquet_to_vortex_chunks; +use vortex_bench::vector_dataset; +use vortex_bench::vector_dataset::TrainLayout; +use vortex_bench::vector_dataset::VectorDataset; +use vortex_bench::vector_dataset::list_to_vector_ext; +use vortex_btrblocks::BtrBlocksCompressorBuilder; +use vortex_error::VortexExpect; +use vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity; +use vortex_tensor::scalar_fns::l2_denorm::L2Denorm; +use vortex_tensor::scalar_fns::sorf_transform::SorfTransform; +use 
vortex_tensor::vector::AnyVector; +use vortex_tensor::vector::Vector; + +/// Cosine threshold for the demo filter. The query comes from the test split, so it may or may not +/// have nearby rows in the train split. +const COSINE_THRESHOLD: f32 = 0.90; + +/// Slack for checking decoded rows against a predicate that was evaluated on TurboQuant's lossy +/// readthrough representation. +const COSINE_THRESHOLD_TOL: f32 = 0.02; + +/// Regression ceiling on the decoded vs original max-abs-diff for 8-bit TurboQuant on 768-dim f32 +/// embeddings. Observed on cohere-small: ~0.10. Pinned with slack so the check catches large +/// quality regressions without flapping on normal run-to-run variation. +const MAX_ABS_DIFF_TOL: f32 = 0.2; + +#[tokio::main] +async fn main() -> Result<()> { + // Opt in to registering the tensor scalar-fn array plugins before building the session. + // Without this, the TurboQuant-compressed `emb` column cannot be serialized into the Vortex + // file or deserialized on read. + // + // SAFETY: single-threaded setup before any other thread exists. + unsafe { std::env::set_var(vortex_tensor::SCALAR_FN_ARRAY_TENSOR_PLUGIN_ENV, "1") }; + + let session = VortexSession::default().with_tokio(); + vortex_tensor::initialize(&session); + println!("session initialized with tensor plugins"); + + let dataset = VectorDataset::CohereSmall100k; // This is one of the smaller datasets. + + // Download the source parquet files. + let dataset_paths = vector_dataset::download(dataset, TrainLayout::Single).await?; + let (_id, query_vector) = get_query_vector(dataset_paths.test).await?; + println!( + "query vector selected (id = {_id}, dim = {})", + query_vector.len() + ); + + // Bring the parquet file into memory so that we can write it as a vortex file (after prep). 
+ let single_train_file = dataset_paths + .train_files + .first() + .vortex_expect("we know that there must be a file here") + .clone(); + + println!("reading parquet into chunked array..."); + let chunked_table = parquet_to_vortex_chunks(single_train_file) + .await? + .into_array(); + let len = chunked_table.len(); + println!("parquet loaded: {len} rows"); + + let id = chunked_table.get_item("id")?; + let emb = chunked_table.get_item("emb")?; + + println!("converting emb column to Vector extension type..."); + let vector_array = list_to_vector_ext(emb)?; + + let fields = [("id", id), ("emb", vector_array)]; + let struct_array = StructArray::from_fields(&fields)?.into_array(); + + println!("compressing with TurboQuant and writing to in-memory Vortex file..."); + let bytes = write_turboquant(&session, struct_array.clone().into_array()).await?; + println!("vortex file written: {} bytes", bytes.len()); + + println!("verifying roundtrip fidelity..."); + verify_roundtrip(&session, &bytes, struct_array.clone()).await?; + + println!("verifying filter pushdown with cosine similarity..."); + verify_filter_pushdown(&session, &bytes, &query_vector, struct_array).await?; + + println!("all checks passed!"); + Ok(()) +} + +async fn get_query_vector(query_vectors_path: PathBuf) -> Result<(usize, Vec)> { + let test_vectors = parquet_to_vortex_chunks(query_vectors_path).await?; + + // Get a random query vector. + let idx = rand::random_range(0..test_vectors.len()); + let struct_scalar = test_vectors.scalar_at(idx)?; + let id_scalar = struct_scalar + .as_struct() + .field("id") + .vortex_expect("test parquet file missing `id` field"); + + ensure!( + id_scalar + .as_primitive() + .as_::() + .vortex_expect("id was not a i64") + == idx as i64 + ); + + let emb_scalar = struct_scalar + .as_struct() + .field("emb") + .vortex_expect("test parquet file missing `emb` field"); + + // Pack into a `Vec`. 
+ let query_vector: Vec = emb_scalar + .as_list() + .elements() + .vortex_expect("somehow had a null test vector") + .iter() + .map(|element| { + element + .as_primitive() + .as_::() + .vortex_expect("value was not a f32") + }) + .collect(); + + Ok((idx, query_vector)) +} + +async fn write_turboquant(session: &VortexSession, array: ArrayRef) -> Result { + let compressor = BtrBlocksCompressorBuilder::default() + .with_turboquant() + .build(); + + // TurboQuant produces `L2Denorm(SorfTransform(FSL(Dict(...))), norms)`. The default write + // allow-list only covers canonical/compressed array encodings, so the tensor scalar-fn + // encodings it emits get rejected during normalization. Extend the set with the two encoding + // IDs this scheme actually uses. + let mut allowed = ALLOWED_ENCODINGS.clone(); + allowed.insert(L2Denorm.id()); + allowed.insert(SorfTransform.id()); + + let strategy = WriteStrategyBuilder::default() + .with_compressor(compressor) + .with_allow_encodings(allowed) + .build(); + + let mut buf = ByteBufferMut::empty(); + session + .write_options() + .with_strategy(strategy) + .write(&mut buf, array.to_array_stream()) + .await?; + Ok(buf.freeze()) +} + +async fn verify_roundtrip( + session: &VortexSession, + bytes: &ByteBuffer, + original: ArrayRef, +) -> Result<()> { + let chunks: Vec = session + .open_options() + .open_buffer(bytes.clone())? + .scan()? + .into_array_stream()? + .try_collect() + .await?; + + let mut ctx = session.create_execution_ctx(); + + let read: StructArray = ChunkedArray::try_new(chunks, original.dtype().clone())? 
+ .into_array() + .execute(&mut ctx)?; + let original: StructArray = original.execute(&mut ctx)?; + ensure!(read.len() == original.len()); + + let read_emb = read.unmasked_field_by_name("emb")?.clone(); + let original_emb = original.unmasked_field_by_name("emb")?.clone(); + + let decoded = flatten_vector_column(read_emb, &mut ctx)?; + let original_decoded = flatten_vector_column(original_emb, &mut ctx)?; + + let (max_abs, mean_abs) = diff_stats(&original_decoded, &decoded); + println!( + "roundtrip fidelity: max_abs_diff = {max_abs:.6}, mean_abs_diff = {mean_abs:.6} \ + (tol = {MAX_ABS_DIFF_TOL})" + ); + if max_abs > MAX_ABS_DIFF_TOL { + bail!("TurboQuant max_abs_diff {max_abs} exceeds tolerance {MAX_ABS_DIFF_TOL}"); + } + + Ok(()) +} + +async fn verify_filter_pushdown( + session: &VortexSession, + bytes: &ByteBuffer, + query: &[f32], + original: ArrayRef, +) -> Result<()> { + // Build the filter as `cosine_similarity(emb, ) > threshold`. The RHS of + // `CosineSimilarity` is a `lit(...)` wrapping a `Vector` scalar; during scan + // evaluation the Literal expands to a ConstantArray whose row count matches the current batch, + // satisfying `CosineSimilarity`'s same-length requirement. The entire expression is pushed + // through `with_filter`, so row selection happens inside the scan rather than after the whole + // column is materialized. + println!("query: {}", preview_vector(query)); + + let query_scalar = build_query_vector_scalar(query)?; + let cosine_expr = CosineSimilarity.new_expr(EmptyOptions, [col("emb"), lit(query_scalar)]); + let filter = gt(cosine_expr, lit(COSINE_THRESHOLD)); + + let scan_start = Instant::now(); + let chunks: Vec = session + .open_options() + .open_buffer(bytes.clone())? + .scan()? + .with_filter(filter) + .into_array_stream()? 
+        .try_collect()
+        .await?;
+    let scan_ms = scan_start.elapsed().as_secs_f64() * 1e3;
+
+    // Total surviving row count across all returned chunks.
+    let hits: usize = chunks.iter().map(|c| c.len()).sum();
+    println!(
+        "pushed down `cosine_similarity(emb, query) > {COSINE_THRESHOLD}`: {hits} rows survived \
+         in {scan_ms:.2} ms"
+    );
+    if hits == 0 {
+        println!(" no rows survived the filter for this random query");
+        return Ok(());
+    }
+
+    // Materialize the matching rows and dump each `emb` vector so the reader can see what the
+    // pushed-down filter actually selected. Vectors are truncated to the first few elements since
+    // DIM is typically large.
+    let mut ctx = session.create_execution_ctx();
+    let filtered: StructArray = ChunkedArray::try_new(chunks, original.dtype().clone())?
+        .into_array()
+        .execute(&mut ctx)?;
+
+    let emb = filtered.unmasked_field_by_name("emb")?.clone();
+    let flat = flatten_vector_column(emb, &mut ctx)?;
+
+    // The flattened buffer is row-major: each consecutive `dim`-length window is one embedding.
+    let dim = query.len();
+    for (i, row) in flat.chunks_exact(dim).enumerate() {
+        // Recompute the cosine on the decoded floats; allow a small tolerance because the scan
+        // evaluated the predicate on the quantized representation.
+        let cos = cosine_similarity(query, row);
+        ensure!(
+            cos >= COSINE_THRESHOLD - COSINE_THRESHOLD_TOL,
+            "filtered row {i} had decoded cosine {cos:+.6}, below threshold {COSINE_THRESHOLD} \
+             by more than tolerance {COSINE_THRESHOLD_TOL}"
+        );
+        println!(" match {i}: cos = {cos:+.6} {}", preview_vector(row));
+    }
+
+    Ok(())
+}
+
+/// Plain `dot(a, b) / (||a|| * ||b||)` over two equal-length f32 slices. Used purely for reporting
+/// — the actual row selection is done inside the scan by the pushed-down `CosineSimilarity`
+/// expression. This lets the reader cross-check that the surviving rows really do clear the
+/// threshold once decoded.
+fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
+    assert_eq!(a.len(), b.len());
+    let mut dot = 0.0f32;
+    let mut a_sq = 0.0f32;
+    let mut b_sq = 0.0f32;
+    for (&x, &y) in a.iter().zip(b) {
+        dot += x * y;
+        a_sq += x * x;
+        b_sq += y * y;
+    }
+    dot / (a_sq.sqrt() * b_sq.sqrt())
+}
+
+/// Render a vector as `[v0, v1, ..., vN-1, vN]` with the first 4 and last 1 elements at 4-decimal
+/// precision. Keeps the output compact for high-dim embeddings while still giving the reader
+/// something concrete to eyeball.
+fn preview_vector(row: &[f32]) -> String {
+    let dim = row.len();
+    if dim <= 5 {
+        return format!("[{}] (dim = {dim})", fmt_slice(row));
+    }
+    format!(
+        "[{}, ..., {}] (dim = {dim})",
+        fmt_slice(&row[..4]),
+        fmt_slice(&row[dim - 1..])
+    )
+}
+
+/// Format a slice as comma-separated signed values at 4-decimal precision.
+fn fmt_slice(s: &[f32]) -> String {
+    s.iter()
+        .map(|v| format!("{v:+.4}"))
+        .collect::<Vec<_>>()
+        .join(", ")
+}
+
+/// Wrap a query vector in a `Vector` extension scalar suitable for use as the RHS of a
+/// `CosineSimilarity` filter expression via `lit(...)`.
+fn build_query_vector_scalar(query: &[f32]) -> Result<Scalar> {
+    let element_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
+    let children: Vec<Scalar> = query
+        .iter()
+        .map(|&v| Scalar::primitive(v, Nullability::NonNullable))
+        .collect();
+    // Storage is a fixed-size list of f32; the `Vector` extension type wraps it.
+    let fsl_scalar = Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable);
+    Ok(Scalar::extension::<Vector>(EmptyMetadata, fsl_scalar))
+}
+
+/// Decode a `Vector` extension array's storage down to its flat f32 buffer.
+fn flatten_vector_column(emb: ArrayRef, ctx: &mut ExecutionCtx) -> Result<Vec<f32>> {
+    let ext: ExtensionArray = emb.execute(ctx)?;
+    ensure!(ext.ext_dtype().is::<Vector>());
+
+    // Vector storage is FSL(f32); execute down to the primitive element buffer.
+    let fsl: FixedSizeListArray = ext.storage_array().clone().execute(ctx)?;
+    let elements: PrimitiveArray = fsl.elements().clone().execute(ctx)?;
+    Ok(elements.as_slice::<f32>().to_vec())
+}
+
+/// Return `(max_abs_diff, mean_abs_diff)` between two equal-length f32 slices.
+fn diff_stats(original: &[f32], decoded: &[f32]) -> (f32, f32) {
+    assert_eq!(original.len(), decoded.len());
+    // Avoid a NaN mean from dividing by zero on empty input.
+    if original.is_empty() {
+        return (0.0, 0.0);
+    }
+    let (sum_abs, max_abs) =
+        original
+            .iter()
+            .zip(decoded)
+            .fold((0.0f32, 0.0f32), |(sum, peak), (&orig, &dec)| {
+                let diff = (orig - dec).abs();
+                (sum + diff, peak.max(diff))
+            });
+    let mean_abs = sum_abs / original.len() as f32;
+    (max_abs, mean_abs)
+}