diff --git a/vortex/examples/turboquant_vector_search.rs b/vortex/examples/turboquant_vector_search.rs index c4302e0c5ad..47e1e56ba3a 100644 --- a/vortex/examples/turboquant_vector_search.rs +++ b/vortex/examples/turboquant_vector_search.rs @@ -77,7 +77,7 @@ use vortex_tensor::vector::Vector; /// Cosine threshold for the demo filter. The query comes from the test split, so it may or may not /// have nearby rows in the train split. -const COSINE_THRESHOLD: f32 = 0.90; +const COSINE_THRESHOLD: f32 = 0.85; /// Slack for checking decoded rows against a predicate that was evaluated on TurboQuant's lossy /// readthrough representation. @@ -99,13 +99,14 @@ async fn main() -> Result<()> { let session = VortexSession::default().with_tokio(); vortex_tensor::initialize(&session); + let mut ctx = session.create_execution_ctx(); println!("session initialized with tensor plugins"); let dataset = VectorDataset::CohereSmall100k; // This is one of the smaller datasets. // Download the source parquet files. 
let dataset_paths = vector_dataset::download(dataset, TrainLayout::Single).await?; - let (_id, query_vector) = get_query_vector(dataset_paths.test).await?; + let (_id, query_vector) = get_query_vector(dataset_paths.test, &mut ctx).await?; println!( "query vector selected (id = {_id}, dim = {})", query_vector.len() @@ -142,18 +143,21 @@ async fn main() -> Result<()> { verify_roundtrip(&session, &bytes, struct_array.clone()).await?; println!("verifying filter pushdown with cosine similarity..."); - verify_filter_pushdown(&session, &bytes, &query_vector, struct_array).await?; + verify_filter_pushdown(&session, &bytes, &query_vector, struct_array, &mut ctx).await?; println!("all checks passed!"); Ok(()) } -async fn get_query_vector(query_vectors_path: PathBuf) -> Result<(usize, Vec<f32>)> { +async fn get_query_vector( query_vectors_path: PathBuf, ctx: &mut ExecutionCtx, ) -> Result<(usize, Vec<f32>)> { let test_vectors = parquet_to_vortex_chunks(query_vectors_path).await?; // Get a random query vector. let idx = rand::random_range(0..test_vectors.len()); - let struct_scalar = test_vectors.scalar_at(idx)?; + let struct_scalar = test_vectors.execute_scalar(idx, ctx)?; let id_scalar = struct_scalar .as_struct() .field("id") @@ -260,6 +264,7 @@ async fn verify_filter_pushdown( bytes: &ByteBuffer, query: &[f32], original: ArrayRef, + ctx: &mut ExecutionCtx, ) -> Result<()> { // Build the filter as `cosine_similarity(emb, <query>) > threshold`. The RHS of // `CosineSimilarity` is a `lit(...)` wrapping a `Vector` scalar; during scan @@ -297,13 +302,12 @@ async fn verify_filter_pushdown( // Materialize the matching rows and dump each `emb` vector so the reader can see what the // pushed-down filter actually selected. Vectors are truncated to the first few elements since // DIM is typically large. - let mut ctx = session.create_execution_ctx(); let filtered: StructArray = ChunkedArray::try_new(chunks, original.dtype().clone())?
.into_array() - .execute(&mut ctx)?; + .execute(ctx)?; let emb = filtered.unmasked_field_by_name("emb")?.clone(); - let flat = flatten_vector_column(emb, &mut ctx)?; + let flat = flatten_vector_column(emb, ctx)?; let dim = query.len(); for (i, row) in flat.chunks_exact(dim).enumerate() {