diff --git a/vortex-turboquant/Cargo.toml b/vortex-turboquant/Cargo.toml index ab3f63583d3..708abf5948c 100644 --- a/vortex-turboquant/Cargo.toml +++ b/vortex-turboquant/Cargo.toml @@ -32,6 +32,7 @@ vortex-utils = { workspace = true, features = ["dashmap"] } divan = { workspace = true } rand = { workspace = true } rstest = { workspace = true } +vortex-array = { workspace = true, features = ["_test-harness"] } vortex-file = { workspace = true } vortex-io = { workspace = true } vortex-layout = { workspace = true } diff --git a/vortex-turboquant/src/lib.rs b/vortex-turboquant/src/lib.rs index 7aeb60368dd..e2f3c9329dc 100644 --- a/vortex-turboquant/src/lib.rs +++ b/vortex-turboquant/src/lib.rs @@ -19,8 +19,8 @@ //! The [`TQEncode`] scalar function first computes and stores the original L2 norm for each vector //! row, then normalizes each valid nonzero row internally before SORF transform and scalar //! quantization. The [`TQDecode`] scalar function dequantizes through deterministic centroids, -//! applies the inverse SORF transform, truncates back to the original dimension, and re-applies the -//! stored norm. +//! applies the inverse SORF transform, truncates back to the original dimension, and applies a +//! stored inverse direction-norm correction before re-applying the stored norm. //! //! The encoded storage is a row-aligned extension tree: //! @@ -28,26 +28,30 @@ //! Extension( //! Struct { //! norms: Primitive, +//! inv_direction_norms: Primitive, //! codes: FixedSizeList, padded_dim, vector_validity>, //! } //! ) //! ``` //! -//! Stored norms are authoritative for future TurboQuant-aware scalar functions. Decoded quantized -//! directions are not guaranteed to have unit norm after scalar quantization and inverse transform. +//! Stored norms are authoritative for future TurboQuant-aware scalar functions. The rationale +//! for the `inv_direction_norms` correction field lives next to the storage layout; see +//! `vector/storage.rs`. //! //! # Source map //! //! 
Implementation details are documented next to the code that owns them: //! -//! - `vector/storage.rs`: physical storage shape, full-length child arrays, and field-level -//! validity for null vectors. -//! - `vector/normalize.rs`: TurboQuant-local normalization and how it differs from the tensor -//! crate's null-row zeroing helper. -//! - `vector/quantize.rs`: SORF transform, centroid lookup, and why invalid rows are skipped rather -//! than quantized. +//! - `vector/storage.rs`: physical storage shape and parsing. +//! - `vector/normalize.rs`: TurboQuant-local normalization and the encode-time finite-norm +//! guard. +//! - `vector/quantize.rs`: SORF transform, centroid lookup, and the per-row +//! `inv_direction_norm` computation. +//! - `scalar_fns/compute/`: session-scoped optimizer kernels that intercept canonical scalar +//! functions over TurboQuant inputs (currently `L2Norm(TQDecode(_))`). //! - `centroids.rs`: deterministic Max-Lloyd centroid computation and process-local caching. -//! - `sorf/`: the Walsh-Hadamard-based structured transform and the stable SplitMix64 sign stream. +//! - `sorf/`: Walsh-Hadamard-based structured transform plus the stable SplitMix64 sign +//! stream. //! //! The current encoding is intentionally MSE-only. It does not yet implement the paper's QJL //! residual correction for unbiased inner-product estimation, and it still uses internal @@ -75,6 +79,8 @@ pub fn initialize(session: &vortex_session::VortexSession) { session.scalar_fns().register(TQEncode); session.scalar_fns().register(TQDecode); + + scalar_fns::compute::register_kernels(session); } #[cfg(test)] diff --git a/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs b/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs new file mode 100644 index 00000000000..0948e200fac --- /dev/null +++ b/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! 
`L2Norm` execute-parent kernel that intercepts `L2Norm(TQDecode(tq))` and returns the stored +//! per-row norms directly instead of decoding and recomputing. + +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::ScalarFn; +use vortex_array::arrays::scalar_fn::ExactScalarFn; +use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; +use vortex_array::dtype::Nullability; +use vortex_array::optimizer::kernels::ArrayKernelsExt; +use vortex_array::optimizer::kernels::ExecuteParentFn; +use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_array::validity::Validity; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure_eq; +use vortex_session::VortexSession; +use vortex_tensor::scalar_fns::l2_norm::L2Norm; + +use crate::TQDecode; +use crate::vector::storage::parse_storage_norms_only; + +/// Register the `L2Norm(TQDecode(_))` execute-parent kernel on the session. +pub(super) fn register(session: &VortexSession) { + session.kernels().register_execute_parent( + L2Norm.id(), + TQDecode.id(), + &[l2_norm_tq_decode_execute_parent as ExecuteParentFn], + ); +} + +/// Intercepts `L2Norm(TQDecode(tq_arr))` and returns the stored TurboQuant `norms` field. +/// +/// Semantically valid because [`TQDecode`] renormalizes the lossy quantized direction with the +/// stored inverse direction-norm before re-applying the original row norm, so decoded rows +/// preserve the stored L2 norm. The kernel returns `Ok(None)` for any non-matching parent / +/// child pair so the canonical `L2Norm` path runs unchanged. +/// +/// The result's nullability is coerced to the parent's expected dtype because the stored +/// `norms` child may be wider than the outer struct (a shape [`parse_storage_norms_only`] +/// accepts). 
+fn l2_norm_tq_decode_execute_parent( + child: &ArrayRef, + parent: &ArrayRef, + _child_idx: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult<Option<ArrayRef>> { + if !parent.is::<ExactScalarFn<L2Norm>>() { + return Ok(None); + } + if !child.is::<ExactScalarFn<TQDecode>>() { + return Ok(None); + } + + let tq_array = child.as_::<ScalarFn>().child_at(0).clone(); + let parsed = parse_storage_norms_only(tq_array, ctx)?; + + let norms_validity = match parent.dtype().nullability() { + Nullability::NonNullable => Validity::NonNullable, + Nullability::Nullable => parsed.vector_validity, + }; + let norms = PrimitiveArray::from_buffer_handle( + parsed.norms.buffer_handle().clone(), + parsed.norms.ptype(), + norms_validity, + ) + .into_array(); + + vortex_ensure_eq!( + norms.dtype(), + parent.dtype(), + "TurboQuant norms field dtype must match L2Norm output dtype" + ); + vortex_ensure_eq!( + norms.len(), + parent.len(), + "TurboQuant norms field length must match L2Norm output length" + ); + + Ok(Some(norms)) +} diff --git a/vortex-turboquant/src/scalar_fns/compute/mod.rs b/vortex-turboquant/src/scalar_fns/compute/mod.rs new file mode 100644 index 00000000000..7dbc9fb9412 --- /dev/null +++ b/vortex-turboquant/src/scalar_fns/compute/mod.rs @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant-specific session-scoped optimizer kernels. +//! +//! Each kernel module owns its own +//! [`register_execute_parent`](vortex_array::optimizer::kernels::ArrayKernelsExt::register_execute_parent) +//! call. New kernels (for example `InnerProduct` or `CosineSimilarity`) should be added as +//! sibling modules and threaded through [`register_kernels`] with a single line. + +mod l2_norm; + +use vortex_session::VortexSession; + +/// Register every TurboQuant-specific optimizer kernel on `session`. 
+/// +/// Called from the crate-level [`crate::initialize`] after the TurboQuant extension type and +/// the `TQEncode` / `TQDecode` scalar functions are registered, so kernels can resolve the +/// scalar-fn ids they intercept. +pub(crate) fn register_kernels(session: &VortexSession) { + l2_norm::register(session); +} diff --git a/vortex-turboquant/src/scalar_fns/decode.rs b/vortex-turboquant/src/scalar_fns/decode.rs index 6791a1aef61..332aaca1a36 100644 --- a/vortex-turboquant/src/scalar_fns/decode.rs +++ b/vortex-turboquant/src/scalar_fns/decode.rs @@ -153,9 +153,10 @@ impl ScalarFnVTable for TQDecode { /// Decode a `TurboQuant` extension array back into a `Vector` extension array. /// -/// The decoded directions are inverse-transformed, truncated to the original dimension, and -/// multiplied by the stored row norms. The conversion is lossy and does not roundtrip with -/// [`TQEncode`](crate::TQEncode). +/// The decoded directions are inverse-transformed, truncated to the original dimension, normalized +/// by the stored inverse direction norms, and multiplied by the stored row norms. The conversion is +/// lossy and does not roundtrip with [`TQEncode`](crate::TQEncode), but valid nonzero decoded rows +/// preserve the original stored L2 norm. pub(crate) fn decode_vector(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { let parsed = parse_storage(input, ctx)?; let metadata = parsed.metadata; @@ -177,6 +178,7 @@ pub(crate) fn decode_vector(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexRe sorf_matrix: &transform, centroids: &centroids, norms: &parsed.norms, + inv_direction_norms: &parsed.inv_direction_norms, codes: &parsed.codes, }, parsed.vector_validity, @@ -217,6 +219,9 @@ struct DecodeInputs<'a> { centroids: &'a [f32], /// Per-row stored L2 norm of the original input vector, in the element ptype. norms: &'a PrimitiveArray, + /// Per-row reciprocal of the decoded direction's L2 norm, always in f32. 
See + /// [`crate::vector::storage`] for the sentinel semantics. + inv_direction_norms: &'a PrimitiveArray, /// Flat per-row centroid indices, `num_vectors * padded_dim` bytes. codes: &'a PrimitiveArray, } @@ -236,6 +241,7 @@ where let padded_dim = decode.sorf_matrix.padded_dim(); let centroids = decode.centroids; let norms = decode.norms.as_slice::(); + let inv_direction_norms = decode.inv_direction_norms.as_slice::(); let codes = decode.codes.as_slice::(); let mask = vector_validity.execute_mask(num_vectors, ctx)?; @@ -259,11 +265,12 @@ where decode.sorf_matrix.inverse_transform(&decoded, &mut inverse); let norm = norms[i]; + let inv_direction_norm = inv_direction_norms[i]; for &value in inverse.iter().take(dimensions) { // `T::from_f32` is infallible for the supported float ptypes (`f16`, `f32`, // `f64`): values outside `f16` range saturate to `±inf` rather than returning // `None`. - let value = T::from_f32(value) + let value = T::from_f32(value * inv_direction_norm) .vortex_expect("from_f32 is infallible for supported float types"); // SAFETY: total pushes across all match arms equal `output_len`. diff --git a/vortex-turboquant/src/scalar_fns/encode.rs b/vortex-turboquant/src/scalar_fns/encode.rs index 29ce7cc580a..7c2ad7ea30f 100644 --- a/vortex-turboquant/src/scalar_fns/encode.rs +++ b/vortex-turboquant/src/scalar_fns/encode.rs @@ -12,6 +12,7 @@ use vortex_array::IntoArray; use vortex_array::arrays::Extension; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; @@ -209,7 +210,14 @@ pub(crate) fn encode_vector( // SAFETY: `tq_normalize_as_l2_denorm` returned this normalized Vector child. unsafe { turboquant_quantize_core(&normalized_fsl, config, ctx)? 
} }; - let codes = build_codes_child(num_vectors, core, vector_validity.clone())?; + let inv_direction_norms = + PrimitiveArray::new::(core.inv_direction_norms, vector_validity.clone()).into_array(); + let codes = build_codes_child( + num_vectors, + core.all_indices, + core.padded_dim, + vector_validity.clone(), + )?; let metadata = TurboQuantMetadata { element_ptype, @@ -218,7 +226,13 @@ pub(crate) fn encode_vector( seed: config.seed(), num_rounds: config.num_rounds(), }; - let storage = build_storage(norms, codes, num_vectors, vector_validity)?; + let storage = build_storage( + norms, + inv_direction_norms, + codes, + num_vectors, + vector_validity, + )?; Ok(ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage)?.into_array()) } diff --git a/vortex-turboquant/src/scalar_fns/mod.rs b/vortex-turboquant/src/scalar_fns/mod.rs index 1acea9f70f3..3f233659d60 100644 --- a/vortex-turboquant/src/scalar_fns/mod.rs +++ b/vortex-turboquant/src/scalar_fns/mod.rs @@ -3,6 +3,8 @@ //! Scalar functions for lazy TurboQuant vector encode and decode operations. 
+pub(crate) mod compute; + mod decode; mod encode; mod metadata; diff --git a/vortex-turboquant/src/tests/encode_decode.rs b/vortex-turboquant/src/tests/encode_decode.rs index ed5aab190aa..36d691a3f3a 100644 --- a/vortex-turboquant/src/tests/encode_decode.rs +++ b/vortex-turboquant/src/tests/encode_decode.rs @@ -16,10 +16,12 @@ use vortex_array::dtype::PType; use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_error::VortexResult; +use vortex_tensor::scalar_fns::l2_norm::L2Norm; use super::execute_tq_decode; use super::execute_tq_encode; use super::f32_vector_array; +use super::tensor_test_session; use super::test_session; use super::turboquant_storage; use super::vector_array; @@ -29,6 +31,7 @@ use super::vector_values_f32; use crate::TurboQuantConfig; use crate::centroids::compute_or_get_centroids; use crate::vector::normalize::tq_normalize_as_l2_denorm; +use crate::vector::storage::parse_storage; #[rstest] #[case::zero_bits(0, 42, 3)] @@ -105,6 +108,10 @@ fn encode_stores_norms_and_struct_validity() -> VortexResult<()> { .unmasked_field_by_name("norms")? .clone() .execute(&mut ctx)?; + let inv_direction_norms: PrimitiveArray = storage + .unmasked_field_by_name("inv_direction_norms")? + .clone() + .execute(&mut ctx)?; let codes: FixedSizeListArray = storage .unmasked_field_by_name("codes")? 
.clone() @@ -114,13 +121,21 @@ fn encode_stores_norms_and_struct_validity() -> VortexResult<()> { assert!(!mask.value(1)); assert!(mask.value(2)); assert_eq!(norms.validity()?.nullability(), Nullability::Nullable); + assert_eq!( + inv_direction_norms.validity()?.nullability(), + Nullability::Nullable + ); assert_eq!(codes.validity()?.nullability(), Nullability::Nullable); let norms_validity = norms.validity()?.execute_mask(3, &mut ctx)?; + let inv_direction_norms_validity = inv_direction_norms.validity()?.execute_mask(3, &mut ctx)?; let codes_validity = codes.validity()?.execute_mask(3, &mut ctx)?; assert!(norms_validity.value(0)); assert!(!norms_validity.value(1)); assert!(norms_validity.value(2)); + assert!(inv_direction_norms_validity.value(0)); + assert!(!inv_direction_norms_validity.value(1)); + assert!(inv_direction_norms_validity.value(2)); assert!(codes_validity.value(0)); assert!(!codes_validity.value(1)); assert!(codes_validity.value(2)); @@ -134,6 +149,181 @@ fn encode_stores_norms_and_struct_validity() -> VortexResult<()> { Ok(()) } +#[test] +fn encode_stores_zero_inv_direction_norm_for_zero_rows() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let mut values = vec![0.0f32; 3 * 128]; + values[0] = 3.0; + values[1] = 4.0; + values[256] = 1.0; + let input = vector_array(128, &values, Validity::NonNullable)?; + + let encoded = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx)?; + let storage = turboquant_storage(encoded, &mut ctx)?; + let inv_direction_norms: PrimitiveArray = storage + .unmasked_field_by_name("inv_direction_norms")? 
+ .clone() + .execute(&mut ctx)?; + + let values = inv_direction_norms.as_slice::(); + assert!(values[0].is_finite() && values[0] > 0.0); + assert_eq!(values[1], 0.0); + assert!(values[2].is_finite() && values[2] > 0.0); + Ok(()) +} + +#[test] +fn decode_preserves_original_l2_norms_for_non_power_of_two_dimensions() -> VortexResult<()> { + let session = tensor_test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(129, 3, 0.25, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let expected_norms = parse_storage(encoded.clone(), &mut ctx)?.norms; + let decoded = execute_tq_decode(encoded, &mut ctx)?; + let decoded_norms: PrimitiveArray = L2Norm::try_new_array(decoded, 3)? + .into_array() + .execute(&mut ctx)?; + + for (actual, expected) in decoded_norms + .as_slice::() + .iter() + .zip(expected_norms.as_slice::()) + { + assert!( + (*actual - *expected).abs() <= 1e-4 * expected.max(1.0), + "decoded norm {actual} did not match stored norm {expected}" + ); + } + Ok(()) +} + +/// Encode rejects rows whose L2 norm is non-finite. Without this guard, a row whose squared +/// sum overflows would normalize to all-zero placeholders and decode-vs-kernel would silently +/// diverge (`NaN` vs `+inf`). +#[test] +fn encode_rejects_non_finite_norms() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + + // A row of `1e30` repeated `dim=128` times has squared sum `128 * 1e60 ≈ 1.28e62`, which + // overflows `f32` (max ≈ 3.4e38) and produces `+inf` when `L2Norm` runs in `f32`. 
+ let values = vec![1e30f32; 128]; + let input = vector_array(128, &values, Validity::NonNullable)?; + + let result = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx); + assert!( + result.is_err(), + "encode must reject non-finite norms (overflow case)" + ); + let error = result.err().unwrap().to_string(); + assert!( + error.contains("non-finite"), + "expected non-finite error, got: {error}" + ); + Ok(()) +} + +/// Encode rejects rows containing `NaN` values, which propagate through `L2Norm` to produce +/// a `NaN` stored norm. +#[test] +fn encode_rejects_nan_input() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + + let mut values = vec![1.0f32; 128]; + values[0] = f32::NAN; + let input = vector_array(128, &values, Validity::NonNullable)?; + + let result = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx); + assert!(result.is_err(), "encode must reject NaN input rows"); + Ok(()) +} + +/// Decode preserves stored L2 norms across element ptypes and padded/unpadded dimensions. 
+#[rstest] +#[case::f16_dim_128(PType::F16, 128_u32, 1e-2_f32)] +#[case::f16_dim_129(PType::F16, 129_u32, 1e-2_f32)] +#[case::f32_dim_128(PType::F32, 128_u32, 1e-4_f32)] +#[case::f32_dim_129(PType::F32, 129_u32, 1e-4_f32)] +#[case::f32_dim_257(PType::F32, 257_u32, 1e-4_f32)] +#[case::f64_dim_128(PType::F64, 128_u32, 1e-4_f32)] +#[case::f64_dim_129(PType::F64, 129_u32, 1e-4_f32)] +fn decode_preserves_original_l2_norms_across_ptypes_and_dims( + #[case] ptype: PType, + #[case] dim: u32, + #[case] tolerance: f32, +) -> VortexResult<()> { + let session = tensor_test_session(); + let mut ctx = session.create_execution_ctx(); + let rows = 3; + let raw = (0..rows * dim as usize) + .map(|i| (i % 17) as f32 - 8.0) + .map(|v| v * 0.25) + .collect::>(); + let input = match ptype { + PType::F16 => { + let values: Vec = raw.iter().copied().map(half::f16::from_f32).collect(); + vector_array(dim, &values, Validity::NonNullable)? + } + PType::F32 => vector_array(dim, &raw, Validity::NonNullable)?, + PType::F64 => { + let values: Vec = raw.iter().copied().map(f64::from).collect(); + vector_array(dim, &values, Validity::NonNullable)? + } + _ => unreachable!("ptype must be float"), + }; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let decoded = execute_tq_decode(encoded, &mut ctx)?; + let decoded_norms: PrimitiveArray = L2Norm::try_new_array(decoded, rows)? + .into_array() + .execute(&mut ctx)?; + + // L2Norm returns the element ptype; widen to f32 for comparison. 
+ let actuals: Vec = match ptype { + PType::F16 => decoded_norms + .as_slice::() + .iter() + .map(|v| f32::from(*v)) + .collect(), + PType::F32 => decoded_norms.as_slice::().to_vec(), + PType::F64 => decoded_norms + .as_slice::() + .iter() + .map(|v| { + #[expect( + clippy::cast_possible_truncation, + reason = "norms are bounded by the test's input magnitudes (~|raw| * dim^0.5), \ + well within f32 range" + )] + let widened = *v as f32; + widened + }) + .collect(), + _ => unreachable!(), + }; + + // Recompute expected from the raw f32 input to avoid coupling to internal storage. + let expected: Vec = (0..rows) + .map(|i| { + let row = &raw[i * dim as usize..][..dim as usize]; + row.iter().map(|v| v * v).sum::().sqrt() + }) + .collect(); + + for (actual, exp) in actuals.iter().zip(expected.iter()) { + assert!( + (*actual - *exp).abs() <= tolerance * exp.max(1.0), + "decoded norm {actual} did not match expected {exp} (ptype {ptype:?}, dim {dim})" + ); + } + Ok(()) +} + #[test] fn normalize_as_l2_denorm_preserves_child_validity() -> VortexResult<()> { let session = test_session(); diff --git a/vortex-turboquant/src/tests/file.rs b/vortex-turboquant/src/tests/file.rs index e59b7a95c75..2c6b0184460 100644 --- a/vortex-turboquant/src/tests/file.rs +++ b/vortex-turboquant/src/tests/file.rs @@ -3,6 +3,7 @@ use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; +use vortex_array::arrays::PrimitiveArray; use vortex_array::stream::ArrayStreamExt; use vortex_array::validity::Validity; use vortex_error::VortexResult; @@ -10,6 +11,7 @@ use vortex_file::OpenOptionsSessionExt; use vortex_file::VortexWriteOptions; use vortex_io::runtime::BlockingRuntime; use vortex_io::runtime::single::SingleThreadRuntime; +use vortex_tensor::scalar_fns::l2_norm::L2Norm; use vortex_tensor::vector::Vector; use super::execute_tq_decode_from_metadata; @@ -19,6 +21,7 @@ use super::file_session; use super::vector_validity; use crate::TQDecode; use crate::TurboQuantConfig; +use 
crate::vector::storage::parse_storage; use crate::vtable::tq_metadata; #[test] @@ -46,6 +49,54 @@ fn file_roundtrip_with_initialize_session() -> VortexResult<()> { Ok(()) } +/// File-roundtrip preserves `inv_direction_norms` and the `L2Norm(TQDecode(_))` fast-path +/// invariant. A regression that silently dropped the field at serialization would only show +/// up downstream as norm divergence; this test surfaces it at the IO layer. +#[test] +fn file_roundtrip_preserves_inv_direction_norms_and_l2_norm_invariant() -> VortexResult<()> { + let runtime = SingleThreadRuntime::default(); + let session = file_session(&runtime); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(128, 4, 0.25, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let original_norms: PrimitiveArray = parse_storage(encoded.clone(), &mut ctx)?.norms; + + let mut file_bytes = Vec::new(); + VortexWriteOptions::new(session.clone()) + .blocking(&runtime) + .write(&mut file_bytes, encoded.to_array_iterator())?; + + let file = session.open_options().open_buffer(file_bytes)?; + let read = runtime.block_on(async { file.scan()?.into_array_stream()?.read_all().await })?; + + // The inv_direction_norms field must survive serialization with finite-positive values for + // every valid row. + let parsed = parse_storage(read.clone(), &mut ctx)?; + let inv_direction_norms = parsed.inv_direction_norms.as_slice::(); + assert_eq!(inv_direction_norms.len(), 4); + for &v in inv_direction_norms { + assert!( + v.is_finite() && v > 0.0, + "inv_direction_norm {v} after file roundtrip is not finite-positive" + ); + } + + // Fast-path `L2Norm(TQDecode(_))` must still return the originally stored row norms after + // the file roundtrip. If the kernel or the `inv_direction_norms` field had silently broken + // at serialization, this is where it would surface. 
+ let decoded = TQDecode::try_new_array(read)?.into_array(); + let kernel_norms: PrimitiveArray = L2Norm::try_new_array(decoded, 4)? + .into_array() + .execute(&mut ctx)?; + assert_eq!( + kernel_norms.as_slice::(), + original_norms.as_slice::(), + "L2Norm(TQDecode(read_back)) must equal the originally stored row norms" + ); + Ok(()) +} + #[test] fn file_roundtrip_lazy_decode_scalar_fn_with_initialize_session() -> VortexResult<()> { let runtime = SingleThreadRuntime::default(); diff --git a/vortex-turboquant/src/tests/kernels.rs b/vortex-turboquant/src/tests/kernels.rs new file mode 100644 index 00000000000..cf58bab352b --- /dev/null +++ b/vortex-turboquant/src/tests/kernels.rs @@ -0,0 +1,251 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Tests for TurboQuant-specific session-scoped optimizer kernels. + +use rstest::rstest; +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::BoolArray; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::StructArray; +use vortex_array::assert_arrays_eq; +use vortex_array::builtins::ArrayBuiltins; +use vortex_array::dtype::DType; +use vortex_array::dtype::FieldNames; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::validity::Validity; +use vortex_buffer::Buffer; +use vortex_error::VortexResult; +use vortex_tensor::scalar_fns::l2_norm::L2Norm; + +use super::execute_tq_decode; +use super::execute_tq_encode; +use super::f32_vector_array; +use super::tensor_test_session; +use super::vector_array; +use crate::TQDecode; +use crate::TurboQuant; +use crate::TurboQuantConfig; +use crate::TurboQuantMetadata; +use crate::vector::storage::parse_storage; + +const DIM: u32 = 128; + +/// Fast path: `L2Norm(TQDecode(tq_arr))` returns the storage `norms` field bit-for-bit. 
+/// +/// `TQDecode` applies the stored inverse direction-norm correction, so decoded vectors preserve +/// these norms. Bit-exact equality is the strongest invariant that confirms the session-registered +/// kernel fired instead of recomputing. +#[test] +fn l2_norm_over_tq_decode_returns_stored_norms() -> VortexResult<()> { + let session = tensor_test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(DIM, 4, 0.25, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let expected_norms = parse_storage(encoded.clone(), &mut ctx)?.norms.into_array(); + + let decoded = TQDecode::try_new_array(encoded)?.into_array(); + let row_count = decoded.len(); + let result: PrimitiveArray = L2Norm::try_new_array(decoded, row_count)? + .into_array() + .execute(&mut ctx)?; + + assert_arrays_eq!(result, expected_norms); + Ok(()) +} + +/// Negative: directly wrapping a `Vector` (no `TQDecode`) must hit the canonical `L2Norm` path. +/// +/// Proves the kernel only intercepts the matched `(L2Norm, TQDecode)` pair and does not affect +/// the standard tensor scalar-function flow. +#[test] +fn l2_norm_over_plain_vector_uses_canonical_path() -> VortexResult<()> { + let session = tensor_test_session(); + let mut ctx = session.create_execution_ctx(); + + let input = vector_array( + 3, + &[ + 3.0f32, 4.0, 0.0, // row 0, norm = 5.0 + 0.0, 0.0, 0.0, // row 1, norm = 0.0 + 1.0, 0.0, 0.0, // row 2, norm = 1.0 + ], + Validity::NonNullable, + )?; + + let row_count = input.len(); + let result: PrimitiveArray = L2Norm::try_new_array(input, row_count)? + .into_array() + .execute(&mut ctx)?; + + let expected = + PrimitiveArray::new::(Buffer::copy_from([5.0f32, 0.0, 1.0]), Validity::NonNullable); + assert_arrays_eq!(result, expected); + Ok(()) +} + +/// Empty input: zero-length TurboQuant array still produces a zero-length norms array of the +/// matching primitive dtype. 
+#[test] +fn l2_norm_over_empty_tq_decode_is_empty_norms() -> VortexResult<()> { + let session = tensor_test_session(); + let mut ctx = session.create_execution_ctx(); + let input = vector_array::(DIM, &[], Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let decoded = TQDecode::try_new_array(encoded)?.into_array(); + let result: PrimitiveArray = L2Norm::try_new_array(decoded, 0)? + .into_array() + .execute(&mut ctx)?; + + assert_eq!(result.len(), 0); + assert_eq!( + result.dtype(), + &DType::Primitive(PType::F32, Nullability::NonNullable) + ); + Ok(()) +} + +/// Null rows: the kernel must preserve the input's row-level validity and produce correct norms +/// for the non-null rows. +#[rstest] +#[case::leading_null(Validity::from_iter([false, true, true]))] +#[case::trailing_null(Validity::from_iter([true, true, false]))] +#[case::interior_null(Validity::from_iter([true, false, true]))] +fn l2_norm_over_tq_decode_preserves_nulls(#[case] validity: Validity) -> VortexResult<()> { + let session = tensor_test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(DIM, 3, 0.25, validity)?; + let config = TurboQuantConfig::try_new(4, 7, 2)?; + + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let expected_norms = parse_storage(encoded.clone(), &mut ctx)?.norms.into_array(); + + let decoded = TQDecode::try_new_array(encoded)?.into_array(); + let result: PrimitiveArray = L2Norm::try_new_array(decoded, 3)? + .into_array() + .execute(&mut ctx)?; + + assert_arrays_eq!(result, expected_norms); + Ok(()) +} + +/// Masked input: generic masks narrow the TurboQuant storage struct validity without rewriting the +/// `norms` child, so the kernel must apply the authoritative struct validity before returning. 
+#[test] +fn l2_norm_over_masked_tq_decode_uses_storage_validity() -> VortexResult<()> { + let session = tensor_test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(DIM, 4, 0.25, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let masked = encoded.mask(BoolArray::from_iter([true, false, true, false]).into_array())?; + + let decoded = TQDecode::try_new_array(masked)?.into_array(); + let result: PrimitiveArray = L2Norm::try_new_array(decoded, 4)? + .into_array() + .execute(&mut ctx)?; + let validity = result.validity()?.execute_mask(4, &mut ctx)?; + + assert!(validity.value(0)); + assert!(!validity.value(1)); + assert!(validity.value(2)); + assert!(!validity.value(3)); + assert_eq!( + result.dtype(), + &DType::Primitive(PType::F32, Nullability::Nullable) + ); + Ok(()) +} + +/// Regression for the wider-child-nullability shape (`Nullable` `norms` with `AllValid` under +/// a `NonNullable` struct). `parse_storage` accepts it; the kernel must return a `NonNullable` +/// result rather than reusing the wider child validity. See `malformed.rs` for the matching +/// decode-side cases. 
+#[test] +fn l2_norm_over_tq_decode_nullable_norms_under_nonnullable_struct() -> VortexResult<()> { + let session = tensor_test_session(); + let mut ctx = session.create_execution_ctx(); + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: DIM, + bit_width: 1, + seed: 42, + num_rounds: 3, + }; + + let norms = + PrimitiveArray::new::(Buffer::copy_from([5.0f32]), Validity::AllValid).into_array(); + let inv_direction_norms = + PrimitiveArray::new::(Buffer::copy_from([1.0f32]), Validity::AllValid).into_array(); + let codes = PrimitiveArray::new::(vec![0u8; DIM as usize], Validity::NonNullable); + let codes = + FixedSizeListArray::try_new(codes.into_array(), DIM, Validity::AllValid, 1)?.into_array(); + let storage = StructArray::try_new( + FieldNames::from(["norms", "inv_direction_norms", "codes"]), + vec![norms, inv_direction_norms, codes], + 1, + Validity::NonNullable, + )?; + let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array())? + .into_array(); + + let decoded = TQDecode::try_new_array(tq)?.into_array(); + let result: PrimitiveArray = L2Norm::try_new_array(decoded, 1)? + .into_array() + .execute(&mut ctx)?; + + assert_eq!( + result.dtype(), + &DType::Primitive(PType::F32, Nullability::NonNullable), + "kernel result dtype must match parent (NonNullable), not the wider child validity" + ); + assert_eq!(result.as_slice::(), &[5.0f32]); + Ok(()) +} + +/// Cross-check the kernel result against the canonical `L2Norm(execute(TQDecode))` path. +/// Materializing the decoded vector first breaks the `(L2Norm, TQDecode)` pattern so `L2Norm` +/// runs through the canonical scalar-function flow. 
+#[rstest] +#[case::dim_128(128_u32)] +#[case::dim_129(129_u32)] +#[case::dim_257(257_u32)] +fn l2_norm_over_tq_decode_matches_canonical(#[case] dim: u32) -> VortexResult<()> { + let session = tensor_test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(dim, 4, 0.25, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + + let kernel_result: PrimitiveArray = + L2Norm::try_new_array(TQDecode::try_new_array(encoded.clone())?.into_array(), 4)? + .into_array() + .execute(&mut ctx)?; + + // Materialize the decoded vector first so `L2Norm` cannot match `(L2Norm, TQDecode)`. The + // resulting `L2Norm(Vector)` flows through the canonical scalar-function path. + let decoded = execute_tq_decode(encoded, &mut ctx)?; + let canonical_result: PrimitiveArray = L2Norm::try_new_array(decoded, 4)? + .into_array() + .execute(&mut ctx)?; + + let kernel = kernel_result.as_slice::(); + let canonical = canonical_result.as_slice::(); + for (k, c) in kernel.iter().zip(canonical.iter()) { + assert!( + (*k - *c).abs() <= 1e-4 * c.max(1.0), + "kernel result {k} disagrees with canonical {c} (dim {dim})" + ); + } + Ok(()) +} diff --git a/vortex-turboquant/src/tests/malformed.rs b/vortex-turboquant/src/tests/malformed.rs index f99f0ee5105..6145cc4e5b2 100644 --- a/vortex-turboquant/src/tests/malformed.rs +++ b/vortex-turboquant/src/tests/malformed.rs @@ -23,21 +23,37 @@ use crate::TurboQuantMetadata; #[rstest] #[case::nullable_norms_under_nonnullable_struct( + Nullability::NonNullable, + Nullability::Nullable, + Nullability::NonNullable, + Nullability::NonNullable +)] +#[case::nullable_inv_direction_norms_under_nonnullable_struct( + Nullability::NonNullable, Nullability::NonNullable, Nullability::Nullable, Nullability::NonNullable )] #[case::nullable_codes_under_nonnullable_struct( + Nullability::NonNullable, Nullability::NonNullable, Nullability::NonNullable, 
Nullability::Nullable )] #[case::nonnullable_norms_under_nullable_struct( + Nullability::Nullable, + Nullability::NonNullable, + Nullability::Nullable, + Nullability::Nullable +)] +#[case::nonnullable_inv_direction_norms_under_nullable_struct( + Nullability::Nullable, Nullability::Nullable, Nullability::NonNullable, Nullability::Nullable )] #[case::nonnullable_codes_under_nullable_struct( + Nullability::Nullable, Nullability::Nullable, Nullability::Nullable, Nullability::NonNullable @@ -45,6 +61,7 @@ use crate::TurboQuantMetadata; fn decode_accepts_child_nullability_that_covers_struct_validity( #[case] struct_nullability: Nullability, #[case] norms_nullability: Nullability, + #[case] inv_direction_norms_nullability: Nullability, #[case] codes_nullability: Nullability, ) -> VortexResult<()> { let session = test_session(); @@ -59,6 +76,11 @@ fn decode_accepts_child_nullability_that_covers_struct_validity( let norms = PrimitiveArray::new::(Buffer::copy_from([1.0]), Validity::from(norms_nullability)) .into_array(); + let inv_direction_norms = PrimitiveArray::new::( + Buffer::copy_from([1.0]), + Validity::from(inv_direction_norms_nullability), + ) + .into_array(); let codes = PrimitiveArray::new::(vec![0u8; 128], Validity::NonNullable); let codes = FixedSizeListArray::try_new( codes.into_array(), @@ -69,8 +91,8 @@ fn decode_accepts_child_nullability_that_covers_struct_validity( .unwrap() .into_array(); let storage = StructArray::try_new( - FieldNames::from(["norms", "codes"]), - vec![norms, codes], + FieldNames::from(["norms", "inv_direction_norms", "codes"]), + vec![norms, inv_direction_norms, codes], 1, Validity::from(struct_nullability), ) @@ -97,12 +119,15 @@ fn decode_accepts_struct_mask_with_all_valid_children() -> VortexResult<()> { let norms = PrimitiveArray::new::(Buffer::copy_from([1.0, 1.0, 1.0]), Validity::NonNullable) .into_array(); + let inv_direction_norms = + PrimitiveArray::new::(Buffer::copy_from([1.0, 1.0, 1.0]), Validity::NonNullable) + 
.into_array(); let codes = PrimitiveArray::new::(vec![0u8; 3 * 128], Validity::NonNullable); let codes = FixedSizeListArray::try_new(codes.into_array(), 128, Validity::NonNullable, 3)? .into_array(); let storage = StructArray::try_new( - FieldNames::from(["norms", "codes"]), - vec![norms, codes], + FieldNames::from(["norms", "inv_direction_norms", "codes"]), + vec![norms, inv_direction_norms, codes], 3, Validity::from_iter([true, false, true]), )?; @@ -133,6 +158,9 @@ fn decode_rejects_child_masks_that_disagree_with_struct_validity() -> VortexResu Validity::from_iter([true, true, false]), ) .into_array(); + let inv_direction_norms = + PrimitiveArray::new::(Buffer::copy_from([1.0, 1.0, 1.0]), Validity::NonNullable) + .into_array(); let codes = PrimitiveArray::new::(vec![0u8; 3 * 128], Validity::NonNullable); let codes = FixedSizeListArray::try_new( codes.into_array(), @@ -142,8 +170,44 @@ fn decode_rejects_child_masks_that_disagree_with_struct_validity() -> VortexResu )? .into_array(); let storage = StructArray::try_new( - FieldNames::from(["norms", "codes"]), - vec![norms, codes], + FieldNames::from(["norms", "inv_direction_norms", "codes"]), + vec![norms, inv_direction_norms, codes], + 3, + Validity::from_iter([true, false, true]), + )?; + let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array())? 
+ .into_array(); + + assert!(execute_tq_decode_from_metadata(tq, &mut ctx).is_err()); + Ok(()) +} + +#[test] +fn decode_rejects_inv_direction_norm_masks_that_disagree_with_struct_validity() -> VortexResult<()> +{ + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 128, + bit_width: 1, + seed: 42, + num_rounds: 3, + }; + let norms = + PrimitiveArray::new::(Buffer::copy_from([1.0, 1.0, 1.0]), Validity::NonNullable) + .into_array(); + let inv_direction_norms = PrimitiveArray::new::( + Buffer::copy_from([1.0, 1.0, 1.0]), + Validity::from_iter([true, true, false]), + ) + .into_array(); + let codes = PrimitiveArray::new::(vec![0u8; 3 * 128], Validity::NonNullable); + let codes = FixedSizeListArray::try_new(codes.into_array(), 128, Validity::NonNullable, 3)? + .into_array(); + let storage = StructArray::try_new( + FieldNames::from(["norms", "inv_direction_norms", "codes"]), + vec![norms, inv_direction_norms, codes], 3, Validity::from_iter([true, false, true]), )?; @@ -168,6 +232,8 @@ fn decode_panics_on_codes_outside_centroid_table() { }; let norms = PrimitiveArray::new::(Buffer::copy_from([1.0]), Validity::NonNullable).into_array(); + let inv_direction_norms = + PrimitiveArray::new::(Buffer::copy_from([1.0]), Validity::NonNullable).into_array(); let mut codes = vec![0u8; 128]; codes[0] = 2; let codes = PrimitiveArray::new::(codes, Validity::NonNullable); @@ -175,8 +241,8 @@ fn decode_panics_on_codes_outside_centroid_table() { .unwrap() .into_array(); let storage = StructArray::try_new( - FieldNames::from(["norms", "codes"]), - vec![norms, codes], + FieldNames::from(["norms", "inv_direction_norms", "codes"]), + vec![norms, inv_direction_norms, codes], 1, Validity::NonNullable, ) diff --git a/vortex-turboquant/src/tests/metadata.rs b/vortex-turboquant/src/tests/metadata.rs index e0d1042f02f..db199bc12fb 100644 --- a/vortex-turboquant/src/tests/metadata.rs +++ 
b/vortex-turboquant/src/tests/metadata.rs @@ -18,6 +18,7 @@ use vortex_error::vortex_err; use crate::TurboQuant; use crate::TurboQuantMetadata; use crate::vector::storage::CODES_FIELD; +use crate::vector::storage::INV_DIRECTION_NORMS_FIELD; use crate::vector::storage::NORMS_FIELD; use crate::vector::tq_padded_dim; @@ -43,9 +44,10 @@ fn tq_storage_dtype( .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; Ok(DType::Struct( StructFields::new( - FieldNames::from([NORMS_FIELD, CODES_FIELD]), + FieldNames::from([NORMS_FIELD, INV_DIRECTION_NORMS_FIELD, CODES_FIELD]), vec![ DType::Primitive(metadata.element_ptype, row_nullability), + DType::Primitive(PType::F32, row_nullability), DType::FixedSizeList( Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), padded_dim, @@ -156,9 +158,10 @@ fn dtype_validation_rejects_malformed_storage() { }; let storage = DType::Struct( StructFields::new( - FieldNames::from(["norms", "codes"]), + FieldNames::from(["norms", "inv_direction_norms", "codes"]), vec![ DType::Primitive(PType::F32, Nullability::Nullable), + DType::Primitive(PType::F64, Nullability::Nullable), DType::FixedSizeList( DType::Primitive(PType::U8, Nullability::Nullable).into(), 128, diff --git a/vortex-turboquant/src/tests/mod.rs b/vortex-turboquant/src/tests/mod.rs index ffa1db175a7..eb0f5da89b4 100644 --- a/vortex-turboquant/src/tests/mod.rs +++ b/vortex-turboquant/src/tests/mod.rs @@ -39,6 +39,7 @@ use crate::initialize; mod encode_decode; mod file; +mod kernels; mod malformed; mod metadata; mod parity; @@ -50,6 +51,17 @@ fn test_session() -> VortexSession { session } +/// In-memory session with both `vortex_tensor` and `vortex_turboquant` initialized. +/// +/// Tests that exercise tensor scalar functions over TurboQuant inputs need `L2Norm` registered +/// alongside the TurboQuant extension type and kernels. 
+fn tensor_test_session() -> VortexSession { + let session = VortexSession::empty().with::(); + vortex_tensor::initialize(&session); + initialize(&session); + session +} + fn file_session(runtime: &SingleThreadRuntime) -> VortexSession { let session = VortexSession::empty() .with::() diff --git a/vortex-turboquant/src/tests/parity.rs b/vortex-turboquant/src/tests/parity.rs index 4360d90849d..e65f132ab57 100644 --- a/vortex-turboquant/src/tests/parity.rs +++ b/vortex-turboquant/src/tests/parity.rs @@ -13,15 +13,24 @@ use super::f32_vector_array; use super::test_session; use super::vector_values_f32; use crate::TurboQuantConfig; +use crate::vector::storage::parse_storage; +/// Pins down the exact relationship between new and legacy TurboQuant decode: for each row, +/// `new_value[i] == old_value[i] * inv_direction_norm[row]`. The centroid table and SORF +/// transform are identical between the two encoders, so the inverse-transformed direction is +/// the same; the only mathematical difference is the per-row scalar correction. 
#[test] -fn encode_decode_matches_old_turboquant_decode() -> VortexResult<()> { +fn encode_decode_applies_direction_norm_correction_after_old_turboquant_decode() -> VortexResult<()> +{ let session = test_session(); let mut ctx = session.create_execution_ctx(); - let input = f32_vector_array(128, 2, 0.125, Validity::NonNullable)?; + let input = f32_vector_array(129, 2, 0.125, Validity::NonNullable)?; let config = TurboQuantConfig::try_new(3, 42, 3)?; let new_encoded = execute_tq_encode(input.clone(), &config, &mut ctx)?; + let parsed = parse_storage(new_encoded.clone(), &mut ctx)?; + let inv_direction_norms = parsed.inv_direction_norms.as_slice::().to_vec(); + let new_decoded = execute_tq_decode(new_encoded, &mut ctx)?; let old_config = OldTurboQuantConfig { bit_width: config.bit_width(), @@ -32,7 +41,30 @@ fn encode_decode_matches_old_turboquant_decode() -> VortexResult<()> { let new_values = vector_values_f32(new_decoded, &mut ctx)?; let old_values = vector_values_f32(old_decoded, &mut ctx)?; + let dim = new_values.len() / inv_direction_norms.len(); - assert_eq!(new_values, old_values); + // Every coordinate of the new decode should equal the corresponding coordinate of the old + // decode, multiplied by the row's stored `inv_direction_norm`. Tolerance is per-element, + // scaled by the larger of the two values to handle exact zeros gracefully. + for (row, &correction) in inv_direction_norms.iter().enumerate() { + for col in 0..dim { + let idx = row * dim + col; + let new_v = new_values[idx]; + let old_v = old_values[idx]; + let expected = old_v * correction; + let scale = new_v.abs().max(expected.abs()).max(1.0); + assert!( + (new_v - expected).abs() <= 1e-4 * scale, + "row {row} col {col}: new {new_v} != old {old_v} * inv_direction_norm \ + {correction} (= {expected})" + ); + } + } + // Sanity: the correction is meaningfully non-trivial for at least one row (verifies the + // direction-norm field is actually doing work, not a no-op). 
+ assert!( + inv_direction_norms.iter().any(|&c| (c - 1.0).abs() > 1e-4), + "inv_direction_norms should differ from 1.0 for at least one row" + ); Ok(()) } diff --git a/vortex-turboquant/src/vector/normalize.rs b/vortex-turboquant/src/vector/normalize.rs index 642949eecf6..05eadb3792c 100644 --- a/vortex-turboquant/src/vector/normalize.rs +++ b/vortex-turboquant/src/vector/normalize.rs @@ -24,6 +24,7 @@ use vortex_array::match_each_float_ptype; use vortex_array::validity::Validity; use vortex_buffer::BufferMut; use vortex_error::VortexResult; +use vortex_error::vortex_bail; use vortex_error::vortex_ensure_eq; use vortex_error::vortex_err; use vortex_mask::Mask; @@ -102,6 +103,23 @@ where let values = elements.as_slice::(); let norm_values = norms.as_slice::(); + // Reject non-finite norms up front. A `+inf` or `NaN` norm would either come from an input + // row whose sum of squares overflowed `T` or from a pre-existing `NaN` in the data. In + // either case the f32-precision SORF transform downstream cannot represent the row, and + // letting the division `value / norm` proceed silently corrupts encoded data: the + // normalized row becomes all zeros, encode infers a zero-norm row, and the stored + // `inv_direction_norm` sentinel disagrees with the non-finite stored norm at decode time. + // Fail fast instead. Invalid-row norms can carry arbitrary placeholders so they are + // excluded from the check via the row-validity mask. 
+ let mask_for_check = mask; + for (i, &norm) in norm_values.iter().enumerate() { + if mask_for_check.value(i) && !norm.is_finite() { + vortex_bail!( + "TurboQuant input row {i} has non-finite L2 norm; encode requires finite norms" + ); + } + } + let output_len = num_vectors .checked_mul(dimensions) .ok_or_else(|| vortex_err!("TurboQuant normalized vector length overflow"))?; diff --git a/vortex-turboquant/src/vector/quantize.rs b/vortex-turboquant/src/vector/quantize.rs index 0861b9f6805..0521a6021df 100644 --- a/vortex-turboquant/src/vector/quantize.rs +++ b/vortex-turboquant/src/vector/quantize.rs @@ -32,15 +32,24 @@ use crate::centroids::compute_or_get_centroids; use crate::centroids::find_nearest_centroid; use crate::sorf::SorfMatrix; -/// Shared intermediate results from the quantization loop. +/// Intermediate output from the quantization loop, consumed by `encode_vector` to assemble +/// the storage struct. Invalid rows hold zero placeholders in both buffers. pub(crate) struct QuantizationResult { + /// Flat `padded_dim`-strided centroid indices, `num_vectors * padded_dim` entries. pub(crate) all_indices: Buffer, + /// Per-row reciprocal L2 norm of the decoded quantized direction. See the comment inside + /// [`turboquant_quantize_core`] for the `0.0` sentinel cases. + pub(crate) inv_direction_norms: Buffer, + /// SORF padded dimension, `next_power_of_two(dimensions)`. pub(crate) padded_dim: usize, } +/// Build an empty [`QuantizationResult`] for a zero-row input, so the SORF machinery does not +/// run with a zero-length elements buffer. 
pub(crate) fn empty_quantization(padded_dim: usize) -> QuantizationResult { QuantizationResult { all_indices: Buffer::empty(), + inv_direction_norms: Buffer::empty(), padded_dim, } } @@ -79,29 +88,61 @@ pub(crate) unsafe fn turboquant_quantize_core( .checked_mul(padded_dim) .ok_or_else(|| vortex_err!("TurboQuant codes length overflow"))?; let mut all_indices = BufferMut::::with_capacity(codes_len); + let mut inv_direction_norms = BufferMut::::with_capacity(num_vectors); let mut padded = vec![0.0f32; padded_dim]; let mut transformed = vec![0.0f32; padded_dim]; + let mut dequantized = vec![0.0f32; padded_dim]; + let mut inverse = vec![0.0f32; padded_dim]; // Pad, SORF-transform, and quantize a single row, pushing `padded_dim` codes into - // `all_indices`. Captures the read-only inputs and the scratch buffers so each call site - // only needs to pass `all_indices` and the row index. + // `all_indices` and one inverse direction norm into `inv_direction_norms`. Captures the + // read-only inputs and scratch buffers so each call site only needs to pass the output buffers + // and the row index. // // NB: `all_indices` cannot be captured here: the `Values` arm interleaves the closure call // with direct `all_indices.push_n_unchecked` calls. let f32_slice = f32_elements.as_slice(); let dimension = dimension as usize; - let mut quantize_row = |all_indices: &mut BufferMut, row: usize| { - // Reuse `padded` and `transformed` from the outer scope. - padded[..dimension].copy_from_slice(&f32_slice[row * dimension..][..dimension]); - padded[dimension..].fill(0.0); - sorf_transform.transform(&padded, &mut transformed); - - for &value in &transformed { - // SAFETY: total pushes across all match arms equal `codes_len`. - unsafe { all_indices.push_unchecked(find_nearest_centroid(value, &boundaries)) }; - } - }; + let mut quantize_row = + |all_indices: &mut BufferMut, inv_direction_norms: &mut BufferMut, row: usize| { + // Reuse `padded` and `transformed` from the outer scope. 
+ let row_values = &f32_slice[row * dimension..][..dimension]; + padded[..dimension].copy_from_slice(row_values); + padded[dimension..].fill(0.0); + sorf_transform.transform(&padded, &mut transformed); + + for (&value, dst) in transformed.iter().zip(dequantized.iter_mut()) { + // SAFETY: total pushes across all match arms equal `codes_len`. + let code = find_nearest_centroid(value, &boundaries); + unsafe { all_indices.push_unchecked(code) }; + *dst = centroids[usize::from(code)]; + } + + // The all-zero `row_values` check fires only for valid zero-norm rows (the + // normalize step pushes zero placeholders for those; non-finite input norms are + // rejected earlier). The `is_normal` guard handles the remaining numerical edge: + // a denormal `norm_squared` would produce a huge-or-infinite `recip` that decode + // would propagate as `+inf` / `NaN`. Both cases store `0.0` so decode emits a + // zero row, matching the stored norm. + let inv_direction_norm = if row_values.iter().all(|&value| value == 0.0) { + 0.0 + } else { + sorf_transform.inverse_transform(&dequantized, &mut inverse); + let norm_squared = inverse[..dimension] + .iter() + .map(|value| value * value) + .sum::(); + if norm_squared.is_normal() { + norm_squared.sqrt().recip() + } else { + 0.0 + } + }; + + // SAFETY: total pushes across all match arms equal `num_vectors`. + unsafe { inv_direction_norms.push_unchecked(inv_direction_norm) }; + }; // The total number of pushes is always exactly `num_vectors * padded_dim == codes_len` // across every arm below, which is the invariant the per-row `unsafe` blocks rely on. @@ -112,10 +153,13 @@ pub(crate) unsafe fn turboquant_quantize_core( // SAFETY: `all_indices` was allocated with capacity `codes_len`, and this push // writes exactly `codes_len` zero codes. unsafe { all_indices.push_n_unchecked(0, codes_len) }; + // SAFETY: `inv_direction_norms` was allocated with capacity `num_vectors`, and this + // writes exactly `num_vectors` zero placeholders. 
+ unsafe { inv_direction_norms.push_n_unchecked(0.0, num_vectors) }; } Mask::AllTrue(_) => { for row in 0..num_vectors { - quantize_row(&mut all_indices, row); + quantize_row(&mut all_indices, &mut inv_direction_norms, row); } } Mask::Values(values_mask) => { @@ -125,10 +169,12 @@ pub(crate) unsafe fn turboquant_quantize_core( if start > cursor { // SAFETY: total pushes across all arms equal `codes_len`. unsafe { all_indices.push_n_unchecked(0, (start - cursor) * padded_dim) }; + // SAFETY: total pushes across all arms equal `num_vectors`. + unsafe { inv_direction_norms.push_n_unchecked(0.0, start - cursor) }; } for row in start..end { - quantize_row(&mut all_indices, row); + quantize_row(&mut all_indices, &mut inv_direction_norms, row); } cursor = end; @@ -137,12 +183,15 @@ pub(crate) unsafe fn turboquant_quantize_core( if cursor < num_vectors { // SAFETY: total pushes across all arms equal `codes_len`. unsafe { all_indices.push_n_unchecked(0, (num_vectors - cursor) * padded_dim) }; + // SAFETY: total pushes across all arms equal `num_vectors`. + unsafe { inv_direction_norms.push_n_unchecked(0.0, num_vectors - cursor) }; } } } Ok(QuantizationResult { all_indices: all_indices.freeze(), + inv_direction_norms: inv_direction_norms.freeze(), padded_dim, }) } diff --git a/vortex-turboquant/src/vector/storage.rs b/vortex-turboquant/src/vector/storage.rs index d1b4f06cc05..ad24916a759 100644 --- a/vortex-turboquant/src/vector/storage.rs +++ b/vortex-turboquant/src/vector/storage.rs @@ -8,14 +8,20 @@ //! ```text //! Struct { //! norms: Primitive, +//! inv_direction_norms: Primitive, //! codes: FixedSizeList, padded_dim, vector_validity>, //! } //! ``` //! -//! Row nullability is carried on the outer struct and on the `norms` and `codes` field arrays. -//! This is deliberate duplication: null vectors remain null throughout encode/decode instead of being -//! converted into zero vectors. The code bytes for invalid rows are physical placeholders only; the -//! 
field-level validity records that those rows were not quantized. +//! `inv_direction_norms` is pinned to `f32` regardless of `element_ptype` because the SORF +//! transform and the centroid codebook are both `f32`; storing it wider would add precision the +//! underlying computation does not have. +//! +//! Row nullability is carried on the outer struct AND on every row-aligned field array. This is +//! deliberate duplication: null vectors remain null throughout encode/decode instead of being +//! converted into zero vectors. The code bytes and inverse direction norms for invalid rows are +//! physical placeholders only; the field-level validity records that those rows were not +//! quantized. //! //! Parsing treats the outer struct validity as authoritative. Child validity may be wider than //! the struct validity (for example after a generic mask only updates the struct validity), but @@ -33,18 +39,21 @@ use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; use vortex_array::arrays::struct_::StructArrayExt; use vortex_array::dtype::FieldNames; use vortex_array::validity::Validity; +use vortex_buffer::Buffer; use vortex_error::VortexResult; use vortex_error::vortex_ensure; use vortex_error::vortex_err; use vortex_mask::Mask; -use super::quantize::QuantizationResult; use crate::vtable::TurboQuantMetadata; use crate::vtable::tq_metadata; /// Name of the stored row-norm child. pub(crate) const NORMS_FIELD: &str = "norms"; +/// Name of the stored inverse quantized-direction norm child. +pub(crate) const INV_DIRECTION_NORMS_FIELD: &str = "inv_direction_norms"; + /// Name of the stored quantized-code child. pub(crate) const CODES_FIELD: &str = "codes"; @@ -58,23 +67,41 @@ pub(crate) struct TurboQuantParsedStorage { pub(crate) vector_validity: Validity, /// Per-row stored L2 norm of the original input vector, in `metadata.element_ptype`. pub(crate) norms: PrimitiveArray, + /// Per-row reciprocal L2 norm of the decoded direction (always `f32`). 
Multiplied through + /// in `TQDecode` so that `L2Norm(TQDecode(_))` preserves the stored row norm. A stored + /// `0.0` is a sentinel telling decode to emit an all-zero row; it pairs with a stored + /// norm of `0.0` for valid zero-norm input rows and for the rare denormal-cancellation + /// case (encode rejects non-finite input norms up front, so those cannot reach this + /// field). + pub(crate) inv_direction_norms: PrimitiveArray, /// Flat `u8` per-row centroid indices, `num_vectors * padded_dim` entries long. pub(crate) codes: PrimitiveArray, /// Row count. pub(crate) len: usize, } +/// Subset of [`TurboQuantParsedStorage`] containing only the `norms` child plus the outer +/// struct validity. Used by the `L2Norm(TQDecode(_))` execute-parent kernel, which has no need +/// for the `codes` or `inv_direction_norms` children. +pub(crate) struct TurboQuantParsedNorms { + /// Authoritative row validity for the quantized vectors. + pub(crate) vector_validity: Validity, + /// Per-row stored L2 norm of the original input vector, in `metadata.element_ptype`. + pub(crate) norms: PrimitiveArray, +} + /// Build the `codes: FixedSizeList, padded_dim>` storage child. /// /// Each row of `padded_dim` u8 codes indexes into the deterministic centroid codebook derived /// from `(padded_dim, bit_width)`. The centroid values are intentionally not stored in the array. 
pub(crate) fn build_codes_child( num_vectors: usize, - quantization: QuantizationResult, + all_indices: Buffer, + padded_dim: usize, vector_validity: Validity, ) -> VortexResult { - let codes = PrimitiveArray::new::(quantization.all_indices, Validity::NonNullable); - let padded_dim_u32 = u32::try_from(quantization.padded_dim) + let codes = PrimitiveArray::new::(all_indices, Validity::NonNullable); + let padded_dim_u32 = u32::try_from(padded_dim) .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; Ok(FixedSizeListArray::try_new( @@ -86,16 +113,17 @@ pub(crate) fn build_codes_child( .into_array()) } -/// Build the TurboQuant `Struct { norms, codes }` storage array. +/// Build the TurboQuant `Struct { norms, inv_direction_norms, codes }` storage array. pub(crate) fn build_storage( norms: ArrayRef, + inv_direction_norms: ArrayRef, codes: ArrayRef, len: usize, vector_validity: Validity, ) -> VortexResult { Ok(StructArray::try_new( - FieldNames::from([NORMS_FIELD, CODES_FIELD]), - vec![norms, codes], + FieldNames::from([NORMS_FIELD, INV_DIRECTION_NORMS_FIELD, CODES_FIELD]), + vec![norms, inv_direction_norms, codes], len, vector_validity, )? @@ -103,6 +131,11 @@ pub(crate) fn build_storage( } /// Parse a TurboQuant extension array into executed storage children. +/// +/// Executes all three storage children, validates that every child's row validity covers the +/// outer struct validity, and returns the parsed result. Used by `TQDecode`, which needs every +/// child. Kernels that only need a subset should use a narrower helper (for example +/// [`parse_storage_norms_only`]) to avoid executing the children they will not consume. pub(crate) fn parse_storage( input: ArrayRef, ctx: &mut ExecutionCtx, @@ -116,6 +149,11 @@ pub(crate) fn parse_storage( .clone() .execute(ctx)?; + let inv_direction_norms: PrimitiveArray = storage + .unmasked_field_by_name(INV_DIRECTION_NORMS_FIELD)? 
+ .clone() + .execute(ctx)?; + let codes_fsl: FixedSizeListArray = storage .unmasked_field_by_name(CODES_FIELD)? .clone() @@ -125,22 +163,63 @@ pub(crate) fn parse_storage( let len = storage.len(); let struct_validity = storage.struct_validity(); let norms_validity = norms.validity()?; + let inv_direction_norms_validity = inv_direction_norms.validity()?; let codes_validity = codes_fsl.validity()?; let struct_mask = struct_validity.execute_mask(len, ctx)?; let norms_mask = norms_validity.execute_mask(len, ctx)?; + let inv_direction_norms_mask = inv_direction_norms_validity.execute_mask(len, ctx)?; let codes_mask = codes_validity.execute_mask(len, ctx)?; - validate_child_validity_covers_struct(&struct_mask, &norms_mask, &codes_mask)?; + validate_child_validity_covers_struct( + &struct_mask, + &norms_mask, + &inv_direction_norms_mask, + &codes_mask, + )?; Ok(TurboQuantParsedStorage { metadata, vector_validity: struct_validity, norms, + inv_direction_norms, codes, len, }) } +/// Narrow form of [`parse_storage`] that returns only the `norms` child plus the outer struct +/// validity. Used by the `L2Norm(TQDecode(_))` kernel so the fast path does not execute the +/// `codes` and `inv_direction_norms` children it has no use for. The `norms` child's validity +/// is still validated against the struct's; the other children are not. +pub(crate) fn parse_storage_norms_only( + input: ArrayRef, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let ext: ExtensionArray = input.execute(ctx)?; + let storage: StructArray = ext.storage_array().clone().execute(ctx)?; + + let norms: PrimitiveArray = storage + .unmasked_field_by_name(NORMS_FIELD)? 
+ .clone() + .execute(ctx)?; + + let len = storage.len(); + let struct_validity = storage.struct_validity(); + let norms_validity = norms.validity()?; + + let struct_mask = struct_validity.execute_mask(len, ctx)?; + let norms_mask = norms_validity.execute_mask(len, ctx)?; + vortex_ensure!( + struct_mask.bitand_not(&norms_mask).all_false(), + "TurboQuant {NORMS_FIELD} row validity must cover storage validity" + ); + + Ok(TurboQuantParsedNorms { + vector_validity: struct_validity, + norms, + }) +} + /// Validate that both child masks cover the struct mask: every row that the struct considers /// valid must also be valid in the `norms` and `codes` children. /// @@ -150,12 +229,20 @@ pub(crate) fn parse_storage( fn validate_child_validity_covers_struct( struct_mask: &Mask, norms_mask: &Mask, + inv_direction_norms_mask: &Mask, codes_mask: &Mask, ) -> VortexResult<()> { vortex_ensure!( struct_mask.clone().bitand_not(norms_mask).all_false(), "TurboQuant {NORMS_FIELD} row validity must cover storage validity" ); + vortex_ensure!( + struct_mask + .clone() + .bitand_not(inv_direction_norms_mask) + .all_false(), + "TurboQuant {INV_DIRECTION_NORMS_FIELD} row validity must cover storage validity" + ); vortex_ensure!( struct_mask.clone().bitand_not(codes_mask).all_false(), "TurboQuant {CODES_FIELD} row validity must cover storage validity" diff --git a/vortex-turboquant/src/vtable.rs b/vortex-turboquant/src/vtable.rs index 854bcee6c70..4dc5370cc92 100644 --- a/vortex-turboquant/src/vtable.rs +++ b/vortex-turboquant/src/vtable.rs @@ -23,6 +23,7 @@ use vortex_error::vortex_err; use crate::TurboQuantConfig; use crate::config::MIN_DIMENSION; use crate::vector::storage::CODES_FIELD; +use crate::vector::storage::INV_DIRECTION_NORMS_FIELD; use crate::vector::storage::NORMS_FIELD; use crate::vector::tq_padded_dim; @@ -139,6 +140,9 @@ pub(crate) fn tq_metadata(dtype: &DType) -> VortexResult { Ok(*metadata) } +/// Compute the storage [`DType`] for a TurboQuant array with the given 
`metadata` and outer +/// `row_nullability`. Returns the `Struct { norms, inv_direction_norms, codes }` shape +/// documented in [`crate::vector::storage`]. pub(crate) fn tq_storage_dtype( metadata: &TurboQuantMetadata, row_nullability: Nullability, @@ -148,9 +152,10 @@ pub(crate) fn tq_storage_dtype( Ok(DType::Struct( StructFields::new( - FieldNames::from([NORMS_FIELD, CODES_FIELD]), + FieldNames::from([NORMS_FIELD, INV_DIRECTION_NORMS_FIELD, CODES_FIELD]), vec![ DType::Primitive(metadata.element_ptype, row_nullability), + DType::Primitive(PType::F32, row_nullability), DType::FixedSizeList( Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), padded_dim, @@ -187,7 +192,7 @@ fn validate_tq_storage_dtype(metadata: &TurboQuantMetadata, dtype: &DType) -> Vo let DType::Struct(fields, _) = dtype else { vortex_bail!("TurboQuant storage dtype must be a Struct, got {dtype}"); }; - let expected_names = FieldNames::from([NORMS_FIELD, CODES_FIELD]); + let expected_names = FieldNames::from([NORMS_FIELD, INV_DIRECTION_NORMS_FIELD, CODES_FIELD]); vortex_ensure_eq!( fields.names(), &expected_names, @@ -208,6 +213,20 @@ fn validate_tq_storage_dtype(metadata: &TurboQuantMetadata, dtype: &DType) -> Vo metadata.element_ptype ); + let Some(inv_direction_norms_dtype) = fields.field(INV_DIRECTION_NORMS_FIELD) else { + vortex_bail!("TurboQuant storage missing {INV_DIRECTION_NORMS_FIELD} field"); + }; + let DType::Primitive(inv_direction_norms_ptype, _) = inv_direction_norms_dtype else { + vortex_bail!( + "TurboQuant {INV_DIRECTION_NORMS_FIELD} field must be primitive, got {inv_direction_norms_dtype}" + ); + }; + vortex_ensure_eq!( + inv_direction_norms_ptype, + PType::F32, + "TurboQuant {INV_DIRECTION_NORMS_FIELD} ptype must be f32, got {inv_direction_norms_ptype}" + ); + let Some(codes_dtype) = fields.field(CODES_FIELD) else { vortex_bail!("TurboQuant storage missing {CODES_FIELD} field"); };