From d1fee4077e1142ade7832789c4f01b40db6e4520 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Wed, 13 May 2026 14:42:52 -0400 Subject: [PATCH 1/3] Add TurboQuant L2Norm decode readthrough Signed-off-by: "Connor Tsui" --- vortex-turboquant/Cargo.toml | 1 + vortex-turboquant/src/lib.rs | 2 + .../src/scalar_fns/compute/l2_norm.rs | 95 +++++++++++ .../src/scalar_fns/compute/mod.rs | 17 ++ vortex-turboquant/src/scalar_fns/mod.rs | 2 + vortex-turboquant/src/tests/kernels.rs | 160 ++++++++++++++++++ vortex-turboquant/src/tests/mod.rs | 12 ++ 7 files changed, 289 insertions(+) create mode 100644 vortex-turboquant/src/scalar_fns/compute/l2_norm.rs create mode 100644 vortex-turboquant/src/scalar_fns/compute/mod.rs create mode 100644 vortex-turboquant/src/tests/kernels.rs diff --git a/vortex-turboquant/Cargo.toml b/vortex-turboquant/Cargo.toml index ab3f63583d3..55ada130236 100644 --- a/vortex-turboquant/Cargo.toml +++ b/vortex-turboquant/Cargo.toml @@ -32,6 +32,7 @@ vortex-utils = { workspace = true, features = ["dashmap"] } divan = { workspace = true } rand = { workspace = true } rstest = { workspace = true } +vortex-array = { path = "../vortex-array", features = ["_test-harness"] } vortex-file = { workspace = true } vortex-io = { workspace = true } vortex-layout = { workspace = true } diff --git a/vortex-turboquant/src/lib.rs b/vortex-turboquant/src/lib.rs index 7aeb60368dd..4aabe08291c 100644 --- a/vortex-turboquant/src/lib.rs +++ b/vortex-turboquant/src/lib.rs @@ -75,6 +75,8 @@ pub fn initialize(session: &vortex_session::VortexSession) { session.scalar_fns().register(TQEncode); session.scalar_fns().register(TQDecode); + + scalar_fns::compute::register_kernels(session); } #[cfg(test)] diff --git a/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs b/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs new file mode 100644 index 00000000000..9f9773cfe74 --- /dev/null +++ b/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: 
Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! `L2Norm` execute-parent kernel that intercepts `L2Norm(TQDecode(tq))` and returns the stored +//! per-row norms directly instead of decoding and recomputing. + +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::ScalarFn; +use vortex_array::arrays::scalar_fn::ExactScalarFn; +use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; +use vortex_array::optimizer::kernels::ArrayKernelsExt; +use vortex_array::optimizer::kernels::ExecuteParentFn; +use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure_eq; +use vortex_session::VortexSession; +use vortex_tensor::scalar_fns::l2_norm::L2Norm; + +use crate::TQDecode; +use crate::vector::storage::parse_storage; +use crate::vtable::TurboQuant; + +/// Register the `L2Norm(TQDecode(_))` execute-parent kernel on the session. +pub(super) fn register(session: &VortexSession) { + session.kernels().register_execute_parent( + L2Norm.id(), + TQDecode.id(), + &[l2_norm_tq_decode_execute_parent as ExecuteParentFn], + ); +} + +/// Intercepts `L2Norm(TQDecode(tq_arr))` and returns the stored TurboQuant `norms` field. +/// +/// The kernel only fires when both the parent matches `ExactScalarFn` and the child +/// matches `ExactScalarFn`. Returns `Ok(None)` for any other shape so the canonical +/// `L2Norm` path runs unchanged. +// +// TODO(vortex-data/vortex#TODO): The TurboQuant storage `norms` field is pre-quantization — it +// is the L2 norm of each original vector before SORF transform and scalar quantization. The +// lossy contract (see `vortex-turboquant/src/lib.rs`) means decoded vectors are not guaranteed +// to be unit-norm, so strictly `l2_norm(tq_decode(x))` may differ slightly from the stored +// norm. 
We treat the stored norms as authoritative here for parity with the `L2Denorm` fast +// path in `vortex-tensor/src/scalar_fns/l2_norm.rs`. A future fix should recompute norms +// post-quantization. +fn l2_norm_tq_decode_execute_parent( + child: &ArrayRef, + parent: &ArrayRef, + _child_idx: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult> { + if !parent.is::>() { + return Ok(None); + } + if !child.is::>() { + return Ok(None); + } + + let tq_array = child.as_::().child_at(0).clone(); + + // Defensive: TQDecode's signature already guarantees this, but a misregistration or a + // future TQDecode that takes a wrapped child should fall back to the canonical path. + if tq_array + .dtype() + .as_extension_opt() + .and_then(|d| d.metadata_opt::()) + .is_none() + { + return Ok(None); + } + + let parsed = parse_storage(tq_array, ctx)?; + let norms_validity = parsed.norms.validity()?; + let norms = PrimitiveArray::from_buffer_handle( + parsed.norms.buffer_handle().clone(), + parsed.norms.ptype(), + norms_validity.and(parsed.vector_validity)?, + ) + .into_array(); + + vortex_ensure_eq!( + norms.dtype(), + parent.dtype(), + "TurboQuant norms field dtype must match L2Norm output dtype" + ); + vortex_ensure_eq!( + norms.len(), + parent.len(), + "TurboQuant norms field length must match L2Norm output length" + ); + + Ok(Some(norms)) +} diff --git a/vortex-turboquant/src/scalar_fns/compute/mod.rs b/vortex-turboquant/src/scalar_fns/compute/mod.rs new file mode 100644 index 00000000000..5f560f97eb5 --- /dev/null +++ b/vortex-turboquant/src/scalar_fns/compute/mod.rs @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant-specific session-scoped optimizer kernels. +//! +//! Each kernel module owns its own [`ArrayKernelsExt::register_execute_parent`] call. New +//! kernels (e.g. for `InnerProduct` or `CosineSimilarity`) should be added as sibling modules +//! 
and threaded through [`register_kernels`] with a single line. + +mod l2_norm; + +use vortex_session::VortexSession; + +/// Register every TurboQuant kernel on `session`. +pub(crate) fn register_kernels(session: &VortexSession) { + l2_norm::register(session); +} diff --git a/vortex-turboquant/src/scalar_fns/mod.rs b/vortex-turboquant/src/scalar_fns/mod.rs index 1acea9f70f3..3f233659d60 100644 --- a/vortex-turboquant/src/scalar_fns/mod.rs +++ b/vortex-turboquant/src/scalar_fns/mod.rs @@ -3,6 +3,8 @@ //! Scalar functions for lazy TurboQuant vector encode and decode operations. +pub(crate) mod compute; + mod decode; mod encode; mod metadata; diff --git a/vortex-turboquant/src/tests/kernels.rs b/vortex-turboquant/src/tests/kernels.rs new file mode 100644 index 00000000000..9cb33e97074 --- /dev/null +++ b/vortex-turboquant/src/tests/kernels.rs @@ -0,0 +1,160 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Tests for TurboQuant-specific session-scoped optimizer kernels. + +use rstest::rstest; +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::BoolArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::assert_arrays_eq; +use vortex_array::builtins::ArrayBuiltins; +use vortex_array::dtype::DType; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::validity::Validity; +use vortex_buffer::Buffer; +use vortex_error::VortexResult; +use vortex_tensor::scalar_fns::l2_norm::L2Norm; + +use super::execute_tq_encode; +use super::f32_vector_array; +use super::tensor_test_session; +use super::vector_array; +use crate::TQDecode; +use crate::TurboQuantConfig; +use crate::vector::storage::parse_storage; + +const DIM: u32 = 128; + +/// Fast path: `L2Norm(TQDecode(tq_arr))` returns the storage `norms` field bit-for-bit. 
+/// +/// The slow path would recompute norms from lossily decoded vectors, which only approximately +/// match the stored norms. Bit-exact equality is the strongest invariant that confirms the +/// session-registered kernel fired. +#[test] +fn l2_norm_over_tq_decode_returns_stored_norms() -> VortexResult<()> { + let session = tensor_test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(DIM, 4, 0.25, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let expected_norms = parse_storage(encoded.clone(), &mut ctx)?.norms.into_array(); + + let decoded = TQDecode::try_new_array(encoded)?.into_array(); + let row_count = decoded.len(); + let result: PrimitiveArray = L2Norm::try_new_array(decoded, row_count)? + .into_array() + .execute(&mut ctx)?; + + assert_arrays_eq!(result, expected_norms); + Ok(()) +} + +/// Negative: directly wrapping a `Vector` (no `TQDecode`) must hit the canonical `L2Norm` path. +/// +/// Proves the kernel only intercepts the matched `(L2Norm, TQDecode)` pair and does not affect +/// the standard tensor scalar-function flow. +#[test] +fn l2_norm_over_plain_vector_uses_canonical_path() -> VortexResult<()> { + let session = tensor_test_session(); + let mut ctx = session.create_execution_ctx(); + + let input = vector_array( + 3, + &[ + 3.0f32, 4.0, 0.0, // row 0, norm = 5.0 + 0.0, 0.0, 0.0, // row 1, norm = 0.0 + 1.0, 0.0, 0.0, // row 2, norm = 1.0 + ], + Validity::NonNullable, + )?; + + let row_count = input.len(); + let result: PrimitiveArray = L2Norm::try_new_array(input, row_count)? + .into_array() + .execute(&mut ctx)?; + + let expected = + PrimitiveArray::new::(Buffer::copy_from([5.0f32, 0.0, 1.0]), Validity::NonNullable); + assert_arrays_eq!(result, expected); + Ok(()) +} + +/// Empty input: zero-length TurboQuant array still produces a zero-length norms array of the +/// matching primitive dtype. 
+#[test] +fn l2_norm_over_empty_tq_decode_is_empty_norms() -> VortexResult<()> { + let session = tensor_test_session(); + let mut ctx = session.create_execution_ctx(); + let input = vector_array::(DIM, &[], Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let decoded = TQDecode::try_new_array(encoded)?.into_array(); + let result: PrimitiveArray = L2Norm::try_new_array(decoded, 0)? + .into_array() + .execute(&mut ctx)?; + + assert_eq!(result.len(), 0); + assert_eq!( + result.dtype(), + &DType::Primitive(PType::F32, Nullability::NonNullable) + ); + Ok(()) +} + +/// Null rows: the kernel must preserve the input's row-level validity and produce correct norms +/// for the non-null rows. +#[rstest] +#[case::leading_null(Validity::from_iter([false, true, true]))] +#[case::trailing_null(Validity::from_iter([true, true, false]))] +#[case::interior_null(Validity::from_iter([true, false, true]))] +fn l2_norm_over_tq_decode_preserves_nulls(#[case] validity: Validity) -> VortexResult<()> { + let session = tensor_test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(DIM, 3, 0.25, validity)?; + let config = TurboQuantConfig::try_new(4, 7, 2)?; + + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let expected_norms = parse_storage(encoded.clone(), &mut ctx)?.norms.into_array(); + + let decoded = TQDecode::try_new_array(encoded)?.into_array(); + let result: PrimitiveArray = L2Norm::try_new_array(decoded, 3)? + .into_array() + .execute(&mut ctx)?; + + assert_arrays_eq!(result, expected_norms); + Ok(()) +} + +/// Masked input: generic masks narrow the TurboQuant storage struct validity without rewriting the +/// `norms` child, so the kernel must apply the authoritative struct validity before returning. 
+#[test] +fn l2_norm_over_masked_tq_decode_uses_storage_validity() -> VortexResult<()> { + let session = tensor_test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(DIM, 4, 0.25, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let masked = encoded.mask(BoolArray::from_iter([true, false, true, false]).into_array())?; + + let decoded = TQDecode::try_new_array(masked)?.into_array(); + let result: PrimitiveArray = L2Norm::try_new_array(decoded, 4)? + .into_array() + .execute(&mut ctx)?; + let validity = result.validity()?.execute_mask(4, &mut ctx)?; + + assert!(validity.value(0)); + assert!(!validity.value(1)); + assert!(validity.value(2)); + assert!(!validity.value(3)); + assert_eq!( + result.dtype(), + &DType::Primitive(PType::F32, Nullability::Nullable) + ); + Ok(()) +} diff --git a/vortex-turboquant/src/tests/mod.rs b/vortex-turboquant/src/tests/mod.rs index ffa1db175a7..eb0f5da89b4 100644 --- a/vortex-turboquant/src/tests/mod.rs +++ b/vortex-turboquant/src/tests/mod.rs @@ -39,6 +39,7 @@ use crate::initialize; mod encode_decode; mod file; +mod kernels; mod malformed; mod metadata; mod parity; @@ -50,6 +51,17 @@ fn test_session() -> VortexSession { session } +/// In-memory session with both `vortex_tensor` and `vortex_turboquant` initialized. +/// +/// Tests that exercise tensor scalar functions over TurboQuant inputs need `L2Norm` registered +/// alongside the TurboQuant extension type and kernels. 
+fn tensor_test_session() -> VortexSession { + let session = VortexSession::empty().with::(); + vortex_tensor::initialize(&session); + initialize(&session); + session +} + fn file_session(runtime: &SingleThreadRuntime) -> VortexSession { let session = VortexSession::empty() .with::() From c49096c82e3dbeb7d0c46cd34fba9f82917e7710 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 21 May 2026 09:48:08 -0400 Subject: [PATCH 2/3] Preserve TurboQuant decoded vector norms Signed-off-by: "Connor Tsui" --- vortex-turboquant/src/lib.rs | 19 ++++- .../src/scalar_fns/compute/l2_norm.rs | 10 +-- vortex-turboquant/src/scalar_fns/decode.rs | 13 ++- vortex-turboquant/src/scalar_fns/encode.rs | 18 +++- vortex-turboquant/src/tests/encode_decode.rs | 66 +++++++++++++++ vortex-turboquant/src/tests/kernels.rs | 6 +- vortex-turboquant/src/tests/malformed.rs | 82 +++++++++++++++++-- vortex-turboquant/src/tests/metadata.rs | 7 +- vortex-turboquant/src/tests/parity.rs | 13 ++- vortex-turboquant/src/vector/quantize.rs | 71 ++++++++++++---- vortex-turboquant/src/vector/storage.rs | 60 +++++++++++--- vortex-turboquant/src/vtable.rs | 20 ++++- 12 files changed, 323 insertions(+), 62 deletions(-) diff --git a/vortex-turboquant/src/lib.rs b/vortex-turboquant/src/lib.rs index 4aabe08291c..87a663bbfa9 100644 --- a/vortex-turboquant/src/lib.rs +++ b/vortex-turboquant/src/lib.rs @@ -19,8 +19,8 @@ //! The [`TQEncode`] scalar function first computes and stores the original L2 norm for each vector //! row, then normalizes each valid nonzero row internally before SORF transform and scalar //! quantization. The [`TQDecode`] scalar function dequantizes through deterministic centroids, -//! applies the inverse SORF transform, truncates back to the original dimension, and re-applies the -//! stored norm. +//! applies the inverse SORF transform, truncates back to the original dimension, and applies a +//! stored inverse direction-norm correction before re-applying the stored norm. //! //! 
The encoded storage is a row-aligned extension tree: //! @@ -28,13 +28,24 @@ //! Extension( //! Struct { //! norms: Primitive, +//! inv_direction_norms: Primitive, //! codes: FixedSizeList, padded_dim, vector_validity>, //! } //! ) //! ``` //! -//! Stored norms are authoritative for future TurboQuant-aware scalar functions. Decoded quantized -//! directions are not guaranteed to have unit norm after scalar quantization and inverse transform. +//! Stored norms are authoritative for future TurboQuant-aware scalar functions. Scalar quantization +//! perturbs the transformed unit vector, and inverse SORF plus truncation can leave the decoded +//! quantized direction with norm different from `1.0`. If decode only multiplied that direction by +//! the original row norm, `L2Norm(TQDecode(_))` would not equal the norm of the vector returned by +//! `TQDecode`. TurboQuant therefore stores `inv_direction_norms = 1 / ||decoded_direction||` so +//! decode can first renormalize the lossy quantized direction and then apply the original norm. +//! +//! Storing the correction also keeps future query kernels cheap. Inner product and cosine kernels can +//! rotate a query once and gather against centroids directly; the per-row scale they need is already +//! available as `norms * inv_direction_norms` for inner product and `inv_direction_norms` for cosine. +//! Without this field, those kernels would have to recompute the inverse SORF/truncated norm per row +//! or give up the `TQDecode` norm-preservation invariant. //! //! # Source map //! diff --git a/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs b/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs index 9f9773cfe74..e770565c514 100644 --- a/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs +++ b/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs @@ -38,13 +38,9 @@ pub(super) fn register(session: &VortexSession) { /// matches `ExactScalarFn`. 
Returns `Ok(None)` for any other shape so the canonical /// `L2Norm` path runs unchanged. // -// TODO(vortex-data/vortex#TODO): The TurboQuant storage `norms` field is pre-quantization — it -// is the L2 norm of each original vector before SORF transform and scalar quantization. The -// lossy contract (see `vortex-turboquant/src/lib.rs`) means decoded vectors are not guaranteed -// to be unit-norm, so strictly `l2_norm(tq_decode(x))` may differ slightly from the stored -// norm. We treat the stored norms as authoritative here for parity with the `L2Denorm` fast -// path in `vortex-tensor/src/scalar_fns/l2_norm.rs`. A future fix should recompute norms -// post-quantization. +// This is semantically correct because TurboQuant stores per-row inverse direction norms and +// `TQDecode` applies that correction before re-applying the original row norm. In other words, +// valid nonzero decoded rows preserve the stored L2 norm even though coordinates are lossy. fn l2_norm_tq_decode_execute_parent( child: &ArrayRef, parent: &ArrayRef, diff --git a/vortex-turboquant/src/scalar_fns/decode.rs b/vortex-turboquant/src/scalar_fns/decode.rs index 6791a1aef61..72193d31262 100644 --- a/vortex-turboquant/src/scalar_fns/decode.rs +++ b/vortex-turboquant/src/scalar_fns/decode.rs @@ -153,9 +153,10 @@ impl ScalarFnVTable for TQDecode { /// Decode a `TurboQuant` extension array back into a `Vector` extension array. /// -/// The decoded directions are inverse-transformed, truncated to the original dimension, and -/// multiplied by the stored row norms. The conversion is lossy and does not roundtrip with -/// [`TQEncode`](crate::TQEncode). +/// The decoded directions are inverse-transformed, truncated to the original dimension, normalized +/// by the stored inverse direction norms, and multiplied by the stored row norms. The conversion is +/// lossy and does not roundtrip with [`TQEncode`](crate::TQEncode), but valid nonzero decoded rows +/// preserve the original stored L2 norm. 
pub(crate) fn decode_vector(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { let parsed = parse_storage(input, ctx)?; let metadata = parsed.metadata; @@ -177,6 +178,7 @@ pub(crate) fn decode_vector(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexRe sorf_matrix: &transform, centroids: ¢roids, norms: &parsed.norms, + inv_direction_norms: &parsed.inv_direction_norms, codes: &parsed.codes, }, parsed.vector_validity, @@ -217,6 +219,7 @@ struct DecodeInputs<'a> { centroids: &'a [f32], /// Per-row stored L2 norm of the original input vector, in the element ptype. norms: &'a PrimitiveArray, + inv_direction_norms: &'a PrimitiveArray, /// Flat per-row centroid indices, `num_vectors * padded_dim` bytes. codes: &'a PrimitiveArray, } @@ -236,6 +239,7 @@ where let padded_dim = decode.sorf_matrix.padded_dim(); let centroids = decode.centroids; let norms = decode.norms.as_slice::(); + let inv_direction_norms = decode.inv_direction_norms.as_slice::(); let codes = decode.codes.as_slice::(); let mask = vector_validity.execute_mask(num_vectors, ctx)?; @@ -259,11 +263,12 @@ where decode.sorf_matrix.inverse_transform(&decoded, &mut inverse); let norm = norms[i]; + let inv_direction_norm = inv_direction_norms[i]; for &value in inverse.iter().take(dimensions) { // `T::from_f32` is infallible for the supported float ptypes (`f16`, `f32`, // `f64`): values outside `f16` range saturate to `±inf` rather than returning // `None`. - let value = T::from_f32(value) + let value = T::from_f32(value * inv_direction_norm) .vortex_expect("from_f32 is infallible for supported float types"); // SAFETY: total pushes across all match arms equal `output_len`. 
diff --git a/vortex-turboquant/src/scalar_fns/encode.rs b/vortex-turboquant/src/scalar_fns/encode.rs index 29ce7cc580a..7c2ad7ea30f 100644 --- a/vortex-turboquant/src/scalar_fns/encode.rs +++ b/vortex-turboquant/src/scalar_fns/encode.rs @@ -12,6 +12,7 @@ use vortex_array::IntoArray; use vortex_array::arrays::Extension; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; @@ -209,7 +210,14 @@ pub(crate) fn encode_vector( // SAFETY: `tq_normalize_as_l2_denorm` returned this normalized Vector child. unsafe { turboquant_quantize_core(&normalized_fsl, config, ctx)? } }; - let codes = build_codes_child(num_vectors, core, vector_validity.clone())?; + let inv_direction_norms = + PrimitiveArray::new::(core.inv_direction_norms, vector_validity.clone()).into_array(); + let codes = build_codes_child( + num_vectors, + core.all_indices, + core.padded_dim, + vector_validity.clone(), + )?; let metadata = TurboQuantMetadata { element_ptype, @@ -218,7 +226,13 @@ pub(crate) fn encode_vector( seed: config.seed(), num_rounds: config.num_rounds(), }; - let storage = build_storage(norms, codes, num_vectors, vector_validity)?; + let storage = build_storage( + norms, + inv_direction_norms, + codes, + num_vectors, + vector_validity, + )?; Ok(ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage)?.into_array()) } diff --git a/vortex-turboquant/src/tests/encode_decode.rs b/vortex-turboquant/src/tests/encode_decode.rs index ed5aab190aa..73e6e3eedbf 100644 --- a/vortex-turboquant/src/tests/encode_decode.rs +++ b/vortex-turboquant/src/tests/encode_decode.rs @@ -16,10 +16,12 @@ use vortex_array::dtype::PType; use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_error::VortexResult; +use vortex_tensor::scalar_fns::l2_norm::L2Norm; use 
super::execute_tq_decode; use super::execute_tq_encode; use super::f32_vector_array; +use super::tensor_test_session; use super::test_session; use super::turboquant_storage; use super::vector_array; @@ -29,6 +31,7 @@ use super::vector_values_f32; use crate::TurboQuantConfig; use crate::centroids::compute_or_get_centroids; use crate::vector::normalize::tq_normalize_as_l2_denorm; +use crate::vector::storage::parse_storage; #[rstest] #[case::zero_bits(0, 42, 3)] @@ -105,6 +108,10 @@ fn encode_stores_norms_and_struct_validity() -> VortexResult<()> { .unmasked_field_by_name("norms")? .clone() .execute(&mut ctx)?; + let inv_direction_norms: PrimitiveArray = storage + .unmasked_field_by_name("inv_direction_norms")? + .clone() + .execute(&mut ctx)?; let codes: FixedSizeListArray = storage .unmasked_field_by_name("codes")? .clone() @@ -114,13 +121,21 @@ fn encode_stores_norms_and_struct_validity() -> VortexResult<()> { assert!(!mask.value(1)); assert!(mask.value(2)); assert_eq!(norms.validity()?.nullability(), Nullability::Nullable); + assert_eq!( + inv_direction_norms.validity()?.nullability(), + Nullability::Nullable + ); assert_eq!(codes.validity()?.nullability(), Nullability::Nullable); let norms_validity = norms.validity()?.execute_mask(3, &mut ctx)?; + let inv_direction_norms_validity = inv_direction_norms.validity()?.execute_mask(3, &mut ctx)?; let codes_validity = codes.validity()?.execute_mask(3, &mut ctx)?; assert!(norms_validity.value(0)); assert!(!norms_validity.value(1)); assert!(norms_validity.value(2)); + assert!(inv_direction_norms_validity.value(0)); + assert!(!inv_direction_norms_validity.value(1)); + assert!(inv_direction_norms_validity.value(2)); assert!(codes_validity.value(0)); assert!(!codes_validity.value(1)); assert!(codes_validity.value(2)); @@ -134,6 +149,57 @@ fn encode_stores_norms_and_struct_validity() -> VortexResult<()> { Ok(()) } +#[test] +fn encode_stores_zero_inv_direction_norm_for_zero_rows() -> VortexResult<()> { + let session = 
test_session(); + let mut ctx = session.create_execution_ctx(); + let mut values = vec![0.0f32; 3 * 128]; + values[0] = 3.0; + values[1] = 4.0; + values[256] = 1.0; + let input = vector_array(128, &values, Validity::NonNullable)?; + + let encoded = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx)?; + let storage = turboquant_storage(encoded, &mut ctx)?; + let inv_direction_norms: PrimitiveArray = storage + .unmasked_field_by_name("inv_direction_norms")? + .clone() + .execute(&mut ctx)?; + + let values = inv_direction_norms.as_slice::(); + assert!(values[0].is_finite() && values[0] > 0.0); + assert_eq!(values[1], 0.0); + assert!(values[2].is_finite() && values[2] > 0.0); + Ok(()) +} + +#[test] +fn decode_preserves_original_l2_norms_for_non_power_of_two_dimensions() -> VortexResult<()> { + let session = tensor_test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(129, 3, 0.25, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let expected_norms = parse_storage(encoded.clone(), &mut ctx)?.norms; + let decoded = execute_tq_decode(encoded, &mut ctx)?; + let decoded_norms: PrimitiveArray = L2Norm::try_new_array(decoded, 3)? 
+ .into_array() + .execute(&mut ctx)?; + + for (actual, expected) in decoded_norms + .as_slice::() + .iter() + .zip(expected_norms.as_slice::()) + { + assert!( + (*actual - *expected).abs() <= 1e-4 * expected.max(1.0), + "decoded norm {actual} did not match stored norm {expected}" + ); + } + Ok(()) +} + #[test] fn normalize_as_l2_denorm_preserves_child_validity() -> VortexResult<()> { let session = test_session(); diff --git a/vortex-turboquant/src/tests/kernels.rs b/vortex-turboquant/src/tests/kernels.rs index 9cb33e97074..c1120b64b83 100644 --- a/vortex-turboquant/src/tests/kernels.rs +++ b/vortex-turboquant/src/tests/kernels.rs @@ -30,9 +30,9 @@ const DIM: u32 = 128; /// Fast path: `L2Norm(TQDecode(tq_arr))` returns the storage `norms` field bit-for-bit. /// -/// The slow path would recompute norms from lossily decoded vectors, which only approximately -/// match the stored norms. Bit-exact equality is the strongest invariant that confirms the -/// session-registered kernel fired. +/// `TQDecode` applies the stored inverse direction-norm correction, so decoded vectors preserve +/// these norms. Bit-exact equality is the strongest invariant that confirms the session-registered +/// kernel fired instead of recomputing. 
#[test] fn l2_norm_over_tq_decode_returns_stored_norms() -> VortexResult<()> { let session = tensor_test_session(); diff --git a/vortex-turboquant/src/tests/malformed.rs b/vortex-turboquant/src/tests/malformed.rs index f99f0ee5105..6145cc4e5b2 100644 --- a/vortex-turboquant/src/tests/malformed.rs +++ b/vortex-turboquant/src/tests/malformed.rs @@ -23,21 +23,37 @@ use crate::TurboQuantMetadata; #[rstest] #[case::nullable_norms_under_nonnullable_struct( + Nullability::NonNullable, + Nullability::Nullable, + Nullability::NonNullable, + Nullability::NonNullable +)] +#[case::nullable_inv_direction_norms_under_nonnullable_struct( + Nullability::NonNullable, Nullability::NonNullable, Nullability::Nullable, Nullability::NonNullable )] #[case::nullable_codes_under_nonnullable_struct( + Nullability::NonNullable, Nullability::NonNullable, Nullability::NonNullable, Nullability::Nullable )] #[case::nonnullable_norms_under_nullable_struct( + Nullability::Nullable, + Nullability::NonNullable, + Nullability::Nullable, + Nullability::Nullable +)] +#[case::nonnullable_inv_direction_norms_under_nullable_struct( + Nullability::Nullable, Nullability::Nullable, Nullability::NonNullable, Nullability::Nullable )] #[case::nonnullable_codes_under_nullable_struct( + Nullability::Nullable, Nullability::Nullable, Nullability::Nullable, Nullability::NonNullable @@ -45,6 +61,7 @@ use crate::TurboQuantMetadata; fn decode_accepts_child_nullability_that_covers_struct_validity( #[case] struct_nullability: Nullability, #[case] norms_nullability: Nullability, + #[case] inv_direction_norms_nullability: Nullability, #[case] codes_nullability: Nullability, ) -> VortexResult<()> { let session = test_session(); @@ -59,6 +76,11 @@ fn decode_accepts_child_nullability_that_covers_struct_validity( let norms = PrimitiveArray::new::(Buffer::copy_from([1.0]), Validity::from(norms_nullability)) .into_array(); + let inv_direction_norms = PrimitiveArray::new::( + Buffer::copy_from([1.0]), + 
Validity::from(inv_direction_norms_nullability), + ) + .into_array(); let codes = PrimitiveArray::new::(vec![0u8; 128], Validity::NonNullable); let codes = FixedSizeListArray::try_new( codes.into_array(), @@ -69,8 +91,8 @@ fn decode_accepts_child_nullability_that_covers_struct_validity( .unwrap() .into_array(); let storage = StructArray::try_new( - FieldNames::from(["norms", "codes"]), - vec![norms, codes], + FieldNames::from(["norms", "inv_direction_norms", "codes"]), + vec![norms, inv_direction_norms, codes], 1, Validity::from(struct_nullability), ) @@ -97,12 +119,15 @@ fn decode_accepts_struct_mask_with_all_valid_children() -> VortexResult<()> { let norms = PrimitiveArray::new::(Buffer::copy_from([1.0, 1.0, 1.0]), Validity::NonNullable) .into_array(); + let inv_direction_norms = + PrimitiveArray::new::(Buffer::copy_from([1.0, 1.0, 1.0]), Validity::NonNullable) + .into_array(); let codes = PrimitiveArray::new::(vec![0u8; 3 * 128], Validity::NonNullable); let codes = FixedSizeListArray::try_new(codes.into_array(), 128, Validity::NonNullable, 3)? .into_array(); let storage = StructArray::try_new( - FieldNames::from(["norms", "codes"]), - vec![norms, codes], + FieldNames::from(["norms", "inv_direction_norms", "codes"]), + vec![norms, inv_direction_norms, codes], 3, Validity::from_iter([true, false, true]), )?; @@ -133,6 +158,9 @@ fn decode_rejects_child_masks_that_disagree_with_struct_validity() -> VortexResu Validity::from_iter([true, true, false]), ) .into_array(); + let inv_direction_norms = + PrimitiveArray::new::(Buffer::copy_from([1.0, 1.0, 1.0]), Validity::NonNullable) + .into_array(); let codes = PrimitiveArray::new::(vec![0u8; 3 * 128], Validity::NonNullable); let codes = FixedSizeListArray::try_new( codes.into_array(), @@ -142,8 +170,44 @@ fn decode_rejects_child_masks_that_disagree_with_struct_validity() -> VortexResu )? 
.into_array(); let storage = StructArray::try_new( - FieldNames::from(["norms", "codes"]), - vec![norms, codes], + FieldNames::from(["norms", "inv_direction_norms", "codes"]), + vec![norms, inv_direction_norms, codes], + 3, + Validity::from_iter([true, false, true]), + )?; + let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array())? + .into_array(); + + assert!(execute_tq_decode_from_metadata(tq, &mut ctx).is_err()); + Ok(()) +} + +#[test] +fn decode_rejects_inv_direction_norm_masks_that_disagree_with_struct_validity() -> VortexResult<()> +{ + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 128, + bit_width: 1, + seed: 42, + num_rounds: 3, + }; + let norms = + PrimitiveArray::new::(Buffer::copy_from([1.0, 1.0, 1.0]), Validity::NonNullable) + .into_array(); + let inv_direction_norms = PrimitiveArray::new::( + Buffer::copy_from([1.0, 1.0, 1.0]), + Validity::from_iter([true, true, false]), + ) + .into_array(); + let codes = PrimitiveArray::new::(vec![0u8; 3 * 128], Validity::NonNullable); + let codes = FixedSizeListArray::try_new(codes.into_array(), 128, Validity::NonNullable, 3)? 
+ .into_array(); + let storage = StructArray::try_new( + FieldNames::from(["norms", "inv_direction_norms", "codes"]), + vec![norms, inv_direction_norms, codes], 3, Validity::from_iter([true, false, true]), )?; @@ -168,6 +232,8 @@ fn decode_panics_on_codes_outside_centroid_table() { }; let norms = PrimitiveArray::new::(Buffer::copy_from([1.0]), Validity::NonNullable).into_array(); + let inv_direction_norms = + PrimitiveArray::new::(Buffer::copy_from([1.0]), Validity::NonNullable).into_array(); let mut codes = vec![0u8; 128]; codes[0] = 2; let codes = PrimitiveArray::new::(codes, Validity::NonNullable); @@ -175,8 +241,8 @@ fn decode_panics_on_codes_outside_centroid_table() { .unwrap() .into_array(); let storage = StructArray::try_new( - FieldNames::from(["norms", "codes"]), - vec![norms, codes], + FieldNames::from(["norms", "inv_direction_norms", "codes"]), + vec![norms, inv_direction_norms, codes], 1, Validity::NonNullable, ) diff --git a/vortex-turboquant/src/tests/metadata.rs b/vortex-turboquant/src/tests/metadata.rs index e0d1042f02f..db199bc12fb 100644 --- a/vortex-turboquant/src/tests/metadata.rs +++ b/vortex-turboquant/src/tests/metadata.rs @@ -18,6 +18,7 @@ use vortex_error::vortex_err; use crate::TurboQuant; use crate::TurboQuantMetadata; use crate::vector::storage::CODES_FIELD; +use crate::vector::storage::INV_DIRECTION_NORMS_FIELD; use crate::vector::storage::NORMS_FIELD; use crate::vector::tq_padded_dim; @@ -43,9 +44,10 @@ fn tq_storage_dtype( .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; Ok(DType::Struct( StructFields::new( - FieldNames::from([NORMS_FIELD, CODES_FIELD]), + FieldNames::from([NORMS_FIELD, INV_DIRECTION_NORMS_FIELD, CODES_FIELD]), vec![ DType::Primitive(metadata.element_ptype, row_nullability), + DType::Primitive(PType::F32, row_nullability), DType::FixedSizeList( Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), padded_dim, @@ -156,9 +158,10 @@ fn dtype_validation_rejects_malformed_storage() { 
}; let storage = DType::Struct( StructFields::new( - FieldNames::from(["norms", "codes"]), + FieldNames::from(["norms", "inv_direction_norms", "codes"]), vec![ DType::Primitive(PType::F32, Nullability::Nullable), + DType::Primitive(PType::F64, Nullability::Nullable), DType::FixedSizeList( DType::Primitive(PType::U8, Nullability::Nullable).into(), 128, diff --git a/vortex-turboquant/src/tests/parity.rs b/vortex-turboquant/src/tests/parity.rs index 4360d90849d..305b96606d5 100644 --- a/vortex-turboquant/src/tests/parity.rs +++ b/vortex-turboquant/src/tests/parity.rs @@ -15,10 +15,11 @@ use super::vector_values_f32; use crate::TurboQuantConfig; #[test] -fn encode_decode_matches_old_turboquant_decode() -> VortexResult<()> { +fn encode_decode_applies_direction_norm_correction_after_old_turboquant_decode() -> VortexResult<()> +{ let session = test_session(); let mut ctx = session.create_execution_ctx(); - let input = f32_vector_array(128, 2, 0.125, Validity::NonNullable)?; + let input = f32_vector_array(129, 2, 0.125, Validity::NonNullable)?; let config = TurboQuantConfig::try_new(3, 42, 3)?; let new_encoded = execute_tq_encode(input.clone(), &config, &mut ctx)?; @@ -33,6 +34,12 @@ fn encode_decode_matches_old_turboquant_decode() -> VortexResult<()> { let new_values = vector_values_f32(new_decoded, &mut ctx)?; let old_values = vector_values_f32(old_decoded, &mut ctx)?; - assert_eq!(new_values, old_values); + assert!( + new_values + .iter() + .zip(old_values.iter()) + .any(|(new, old)| (*new - *old).abs() > 1e-6), + "direction-norm correction should intentionally change decoded values" + ); Ok(()) } diff --git a/vortex-turboquant/src/vector/quantize.rs b/vortex-turboquant/src/vector/quantize.rs index 0861b9f6805..b7fbc8e1935 100644 --- a/vortex-turboquant/src/vector/quantize.rs +++ b/vortex-turboquant/src/vector/quantize.rs @@ -35,12 +35,14 @@ use crate::sorf::SorfMatrix; /// Shared intermediate results from the quantization loop. 
pub(crate) struct QuantizationResult { pub(crate) all_indices: Buffer, + pub(crate) inv_direction_norms: Buffer, pub(crate) padded_dim: usize, } pub(crate) fn empty_quantization(padded_dim: usize) -> QuantizationResult { QuantizationResult { all_indices: Buffer::empty(), + inv_direction_norms: Buffer::empty(), padded_dim, } } @@ -79,29 +81,60 @@ pub(crate) unsafe fn turboquant_quantize_core( .checked_mul(padded_dim) .ok_or_else(|| vortex_err!("TurboQuant codes length overflow"))?; let mut all_indices = BufferMut::::with_capacity(codes_len); + let mut inv_direction_norms = BufferMut::::with_capacity(num_vectors); let mut padded = vec![0.0f32; padded_dim]; let mut transformed = vec![0.0f32; padded_dim]; + let mut dequantized = vec![0.0f32; padded_dim]; + let mut inverse = vec![0.0f32; padded_dim]; // Pad, SORF-transform, and quantize a single row, pushing `padded_dim` codes into - // `all_indices`. Captures the read-only inputs and the scratch buffers so each call site - // only needs to pass `all_indices` and the row index. + // `all_indices` and one inverse direction norm into `inv_direction_norms`. Captures the + // read-only inputs and scratch buffers so each call site only needs to pass the output buffers + // and the row index. // // NB: `all_indices` cannot be captured here: the `Values` arm interleaves the closure call // with direct `all_indices.push_n_unchecked` calls. let f32_slice = f32_elements.as_slice(); let dimension = dimension as usize; - let mut quantize_row = |all_indices: &mut BufferMut, row: usize| { - // Reuse `padded` and `transformed` from the outer scope. - padded[..dimension].copy_from_slice(&f32_slice[row * dimension..][..dimension]); - padded[dimension..].fill(0.0); - sorf_transform.transform(&padded, &mut transformed); - - for &value in &transformed { - // SAFETY: total pushes across all match arms equal `codes_len`. 
- unsafe { all_indices.push_unchecked(find_nearest_centroid(value, &boundaries)) }; - } - }; + let mut quantize_row = + |all_indices: &mut BufferMut, inv_direction_norms: &mut BufferMut, row: usize| { + // Reuse `padded` and `transformed` from the outer scope. + let row_values = &f32_slice[row * dimension..][..dimension]; + padded[..dimension].copy_from_slice(row_values); + padded[dimension..].fill(0.0); + sorf_transform.transform(&padded, &mut transformed); + + for (&value, dst) in transformed.iter().zip(dequantized.iter_mut()) { + // SAFETY: total pushes across all match arms equal `codes_len`. + let code = find_nearest_centroid(value, &boundaries); + unsafe { all_indices.push_unchecked(code) }; + *dst = centroids[usize::from(code)]; + } + + // Quantization perturbs the unit direction. The exact norm that decode will return is + // the norm after dequantizing centroids, inverse-transforming, and truncating away any + // padded dimensions. Precomputing its reciprocal here lets `TQDecode` preserve the + // original row norm and lets future query kernels reuse the same per-row correction + // without repeating this inverse transform for every query. + let inv_direction_norm = if row_values.iter().all(|&value| value == 0.0) { + 0.0 + } else { + sorf_transform.inverse_transform(&dequantized, &mut inverse); + let norm_squared = inverse[..dimension] + .iter() + .map(|value| value * value) + .sum::(); + if norm_squared == 0.0 { + 0.0 + } else { + norm_squared.sqrt().recip() + } + }; + + // SAFETY: total pushes across all match arms equal `num_vectors`. + unsafe { inv_direction_norms.push_unchecked(inv_direction_norm) }; + }; // The total number of pushes is always exactly `num_vectors * padded_dim == codes_len` // across every arm below, which is the invariant the per-row `unsafe` blocks rely on. 
@@ -112,10 +145,13 @@ pub(crate) unsafe fn turboquant_quantize_core( // SAFETY: `all_indices` was allocated with capacity `codes_len`, and this push // writes exactly `codes_len` zero codes. unsafe { all_indices.push_n_unchecked(0, codes_len) }; + // SAFETY: `inv_direction_norms` was allocated with capacity `num_vectors`, and this + // writes exactly `num_vectors` zero placeholders. + unsafe { inv_direction_norms.push_n_unchecked(0.0, num_vectors) }; } Mask::AllTrue(_) => { for row in 0..num_vectors { - quantize_row(&mut all_indices, row); + quantize_row(&mut all_indices, &mut inv_direction_norms, row); } } Mask::Values(values_mask) => { @@ -125,10 +161,12 @@ pub(crate) unsafe fn turboquant_quantize_core( if start > cursor { // SAFETY: total pushes across all arms equal `codes_len`. unsafe { all_indices.push_n_unchecked(0, (start - cursor) * padded_dim) }; + // SAFETY: total pushes across all arms equal `num_vectors`. + unsafe { inv_direction_norms.push_n_unchecked(0.0, start - cursor) }; } for row in start..end { - quantize_row(&mut all_indices, row); + quantize_row(&mut all_indices, &mut inv_direction_norms, row); } cursor = end; @@ -137,12 +175,15 @@ pub(crate) unsafe fn turboquant_quantize_core( if cursor < num_vectors { // SAFETY: total pushes across all arms equal `codes_len`. unsafe { all_indices.push_n_unchecked(0, (num_vectors - cursor) * padded_dim) }; + // SAFETY: total pushes across all arms equal `num_vectors`. + unsafe { inv_direction_norms.push_n_unchecked(0.0, num_vectors - cursor) }; } } } Ok(QuantizationResult { all_indices: all_indices.freeze(), + inv_direction_norms: inv_direction_norms.freeze(), padded_dim, }) } diff --git a/vortex-turboquant/src/vector/storage.rs b/vortex-turboquant/src/vector/storage.rs index d1b4f06cc05..3461f389a97 100644 --- a/vortex-turboquant/src/vector/storage.rs +++ b/vortex-turboquant/src/vector/storage.rs @@ -8,14 +8,23 @@ //! ```text //! Struct { //! norms: Primitive, +//! inv_direction_norms: Primitive, //! 
codes: FixedSizeList, padded_dim, vector_validity>, //! } //! ``` //! -//! Row nullability is carried on the outer struct and on the `norms` and `codes` field arrays. -//! This is deliberate duplication: null vectors remain null throughout encode/decode instead of being -//! converted into zero vectors. The code bytes for invalid rows are physical placeholders only; the -//! field-level validity records that those rows were not quantized. +//! Row nullability is carried on the outer struct and on every row-aligned field array. This is +//! deliberate duplication: null vectors remain null throughout encode/decode instead of being +//! converted into zero vectors. The code bytes and inverse direction norms for invalid rows are +//! physical placeholders only; the field-level validity records that those rows were not quantized. +//! +//! `inv_direction_norms` is stored instead of recomputed by decode or every future query kernel +//! because it is a property of the quantized row, not of a particular operation. Computing it exactly +//! requires dequantizing the row's centroids, applying the inverse SORF, truncating back to the +//! original dimension, and taking the f32 L2 norm of that decoded direction. That is the same +//! per-row work decode already needs, and it would be repeated for every scan/query if we did not +//! persist the result. Storing one f32 per row makes the norm-preserving decode contract explicit +//! and gives query kernels a reusable scale factor. //! //! Parsing treats the outer struct validity as authoritative. Child validity may be wider than //! 
the struct validity (for example after a generic mask only updates the struct validity), but @@ -33,18 +42,21 @@ use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; use vortex_array::arrays::struct_::StructArrayExt; use vortex_array::dtype::FieldNames; use vortex_array::validity::Validity; +use vortex_buffer::Buffer; use vortex_error::VortexResult; use vortex_error::vortex_ensure; use vortex_error::vortex_err; use vortex_mask::Mask; -use super::quantize::QuantizationResult; use crate::vtable::TurboQuantMetadata; use crate::vtable::tq_metadata; /// Name of the stored row-norm child. pub(crate) const NORMS_FIELD: &str = "norms"; +/// Name of the stored inverse quantized-direction norm child. +pub(crate) const INV_DIRECTION_NORMS_FIELD: &str = "inv_direction_norms"; + /// Name of the stored quantized-code child. pub(crate) const CODES_FIELD: &str = "codes"; @@ -58,6 +70,7 @@ pub(crate) struct TurboQuantParsedStorage { pub(crate) vector_validity: Validity, /// Per-row stored L2 norm of the original input vector, in `metadata.element_ptype`. pub(crate) norms: PrimitiveArray, + pub(crate) inv_direction_norms: PrimitiveArray, /// Flat `u8` per-row centroid indices, `num_vectors * padded_dim` entries long. pub(crate) codes: PrimitiveArray, /// Row count. @@ -70,11 +83,12 @@ pub(crate) struct TurboQuantParsedStorage { /// from `(padded_dim, bit_width)`. The centroid values are intentionally not stored in the array. 
pub(crate) fn build_codes_child( num_vectors: usize, - quantization: QuantizationResult, + all_indices: Buffer, + padded_dim: usize, vector_validity: Validity, ) -> VortexResult { - let codes = PrimitiveArray::new::(quantization.all_indices, Validity::NonNullable); - let padded_dim_u32 = u32::try_from(quantization.padded_dim) + let codes = PrimitiveArray::new::(all_indices, Validity::NonNullable); + let padded_dim_u32 = u32::try_from(padded_dim) .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; Ok(FixedSizeListArray::try_new( @@ -86,16 +100,17 @@ pub(crate) fn build_codes_child( .into_array()) } -/// Build the TurboQuant `Struct { norms, codes }` storage array. +/// Build the TurboQuant `Struct { norms, inv_direction_norms, codes }` storage array. pub(crate) fn build_storage( norms: ArrayRef, + inv_direction_norms: ArrayRef, codes: ArrayRef, len: usize, vector_validity: Validity, ) -> VortexResult { Ok(StructArray::try_new( - FieldNames::from([NORMS_FIELD, CODES_FIELD]), - vec![norms, codes], + FieldNames::from([NORMS_FIELD, INV_DIRECTION_NORMS_FIELD, CODES_FIELD]), + vec![norms, inv_direction_norms, codes], len, vector_validity, )? @@ -116,6 +131,11 @@ pub(crate) fn parse_storage( .clone() .execute(ctx)?; + let inv_direction_norms: PrimitiveArray = storage + .unmasked_field_by_name(INV_DIRECTION_NORMS_FIELD)? + .clone() + .execute(ctx)?; + let codes_fsl: FixedSizeListArray = storage .unmasked_field_by_name(CODES_FIELD)? 
.clone() @@ -125,17 +145,25 @@ pub(crate) fn parse_storage( let len = storage.len(); let struct_validity = storage.struct_validity(); let norms_validity = norms.validity()?; + let inv_direction_norms_validity = inv_direction_norms.validity()?; let codes_validity = codes_fsl.validity()?; let struct_mask = struct_validity.execute_mask(len, ctx)?; let norms_mask = norms_validity.execute_mask(len, ctx)?; + let inv_direction_norms_mask = inv_direction_norms_validity.execute_mask(len, ctx)?; let codes_mask = codes_validity.execute_mask(len, ctx)?; - validate_child_validity_covers_struct(&struct_mask, &norms_mask, &codes_mask)?; + validate_child_validity_covers_struct( + &struct_mask, + &norms_mask, + &inv_direction_norms_mask, + &codes_mask, + )?; Ok(TurboQuantParsedStorage { metadata, vector_validity: struct_validity, norms, + inv_direction_norms, codes, len, }) @@ -150,12 +178,20 @@ pub(crate) fn parse_storage( fn validate_child_validity_covers_struct( struct_mask: &Mask, norms_mask: &Mask, + inv_direction_norms_mask: &Mask, codes_mask: &Mask, ) -> VortexResult<()> { vortex_ensure!( struct_mask.clone().bitand_not(norms_mask).all_false(), "TurboQuant {NORMS_FIELD} row validity must cover storage validity" ); + vortex_ensure!( + struct_mask + .clone() + .bitand_not(inv_direction_norms_mask) + .all_false(), + "TurboQuant {INV_DIRECTION_NORMS_FIELD} row validity must cover storage validity" + ); vortex_ensure!( struct_mask.clone().bitand_not(codes_mask).all_false(), "TurboQuant {CODES_FIELD} row validity must cover storage validity" diff --git a/vortex-turboquant/src/vtable.rs b/vortex-turboquant/src/vtable.rs index 854bcee6c70..34896cf4396 100644 --- a/vortex-turboquant/src/vtable.rs +++ b/vortex-turboquant/src/vtable.rs @@ -23,6 +23,7 @@ use vortex_error::vortex_err; use crate::TurboQuantConfig; use crate::config::MIN_DIMENSION; use crate::vector::storage::CODES_FIELD; +use crate::vector::storage::INV_DIRECTION_NORMS_FIELD; use crate::vector::storage::NORMS_FIELD; use 
crate::vector::tq_padded_dim; @@ -148,9 +149,10 @@ pub(crate) fn tq_storage_dtype( Ok(DType::Struct( StructFields::new( - FieldNames::from([NORMS_FIELD, CODES_FIELD]), + FieldNames::from([NORMS_FIELD, INV_DIRECTION_NORMS_FIELD, CODES_FIELD]), vec![ DType::Primitive(metadata.element_ptype, row_nullability), + DType::Primitive(PType::F32, row_nullability), DType::FixedSizeList( Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), padded_dim, @@ -187,7 +189,7 @@ fn validate_tq_storage_dtype(metadata: &TurboQuantMetadata, dtype: &DType) -> Vo let DType::Struct(fields, _) = dtype else { vortex_bail!("TurboQuant storage dtype must be a Struct, got {dtype}"); }; - let expected_names = FieldNames::from([NORMS_FIELD, CODES_FIELD]); + let expected_names = FieldNames::from([NORMS_FIELD, INV_DIRECTION_NORMS_FIELD, CODES_FIELD]); vortex_ensure_eq!( fields.names(), &expected_names, @@ -208,6 +210,20 @@ fn validate_tq_storage_dtype(metadata: &TurboQuantMetadata, dtype: &DType) -> Vo metadata.element_ptype ); + let Some(inv_direction_norms_dtype) = fields.field(INV_DIRECTION_NORMS_FIELD) else { + vortex_bail!("TurboQuant storage missing {INV_DIRECTION_NORMS_FIELD} field"); + }; + let DType::Primitive(inv_direction_norms_ptype, _) = inv_direction_norms_dtype else { + vortex_bail!( + "TurboQuant {INV_DIRECTION_NORMS_FIELD} field must be primitive, got {inv_direction_norms_dtype}" + ); + }; + vortex_ensure_eq!( + inv_direction_norms_ptype, + PType::F32, + "TurboQuant {INV_DIRECTION_NORMS_FIELD} ptype must be f32, got {inv_direction_norms_ptype}" + ); + let Some(codes_dtype) = fields.field(CODES_FIELD) else { vortex_bail!("TurboQuant storage missing {CODES_FIELD} field"); }; From 353a0ee8cf5d667a3f23b5f2a89dd484567e0b26 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 21 May 2026 15:21:29 -0400 Subject: [PATCH 3/3] Harden TurboQuant L2Norm fast path and document the crate Fix two correctness bugs in the L2Norm(TQDecode(_)) fast path. 
(1) The kernel coerced the returned norms to the child's nullability rather than the parent's, so wider-child-validity storage shapes that parse_storage accepts errored out at the dtype invariant check. The kernel now coerces to parent.dtype().nullability(), and a new test mirrors the malformed.rs shape. (2) The per-row inv_direction_norm computation could store a 0.0 sentinel for finite rows whose squared sum overflowed to +inf in f32 (or +inf for a denormal norm_squared), making decode emit zeros while the kernel returned the nonzero stored norm. Encode now rejects non-finite input norms up front, and the denormal reciprocal is guarded by is_normal(); regression tests cover both cases. Several should-fix items accompany the two must-fix changes: parse_storage_norms_only lets the kernel skip executing the codes and inv_direction_norms children it does not consume; the parity test pins down the exact new = old * inv_direction_norm[row] relationship rather than asserting "the values differ"; the file roundtrip test now asserts that the new field survives serialization and that the kernel still preserves stored norms; tests are parameterized over f16/f32/f64 and across padded vs unpadded dimensions; the kernel result is cross-checked against canonical L2Norm of the materialized decode. The hypothetical defensive metadata check on the kernel is dropped (registry key plus TQDecode signature already enforce shape). The dev-dep on vortex-array switches to workspace = true to match sibling encodings. Over-long doc lines are reflowed. Every type in the crate now has a doc comment, emphasizing the new inv_direction_norms storage child and the 0.0 sentinel semantics. Module docs single-source the storage schema rationale in storage.rs; lib.rs and the scalar-fn modules defer to it. Verified: cargo check, cargo clippy --all-targets --all-features, cargo +nightly fmt --all --check, cargo doc --no-deps, and cargo nextest run (102 tests, +14 new) all clean.
Signed-off-by: "Connor Tsui" --- vortex-turboquant/Cargo.toml | 2 +- vortex-turboquant/src/lib.rs | 31 ++--- .../src/scalar_fns/compute/l2_norm.rs | 40 +++--- .../src/scalar_fns/compute/mod.rs | 13 +- vortex-turboquant/src/scalar_fns/decode.rs | 2 + vortex-turboquant/src/tests/encode_decode.rs | 124 ++++++++++++++++++ vortex-turboquant/src/tests/file.rs | 51 +++++++ vortex-turboquant/src/tests/kernels.rs | 91 +++++++++++++ vortex-turboquant/src/tests/parity.rs | 35 ++++- vortex-turboquant/src/vector/normalize.rs | 18 +++ vortex-turboquant/src/vector/quantize.rs | 26 ++-- vortex-turboquant/src/vector/storage.rs | 71 ++++++++-- vortex-turboquant/src/vtable.rs | 3 + 13 files changed, 436 insertions(+), 71 deletions(-) diff --git a/vortex-turboquant/Cargo.toml b/vortex-turboquant/Cargo.toml index 55ada130236..708abf5948c 100644 --- a/vortex-turboquant/Cargo.toml +++ b/vortex-turboquant/Cargo.toml @@ -32,7 +32,7 @@ vortex-utils = { workspace = true, features = ["dashmap"] } divan = { workspace = true } rand = { workspace = true } rstest = { workspace = true } -vortex-array = { path = "../vortex-array", features = ["_test-harness"] } +vortex-array = { workspace = true, features = ["_test-harness"] } vortex-file = { workspace = true } vortex-io = { workspace = true } vortex-layout = { workspace = true } diff --git a/vortex-turboquant/src/lib.rs b/vortex-turboquant/src/lib.rs index 87a663bbfa9..e2f3c9329dc 100644 --- a/vortex-turboquant/src/lib.rs +++ b/vortex-turboquant/src/lib.rs @@ -34,31 +34,24 @@ //! ) //! ``` //! -//! Stored norms are authoritative for future TurboQuant-aware scalar functions. Scalar quantization -//! perturbs the transformed unit vector, and inverse SORF plus truncation can leave the decoded -//! quantized direction with norm different from `1.0`. If decode only multiplied that direction by -//! the original row norm, `L2Norm(TQDecode(_))` would not equal the norm of the vector returned by -//! `TQDecode`. 
TurboQuant therefore stores `inv_direction_norms = 1 / ||decoded_direction||` so -//! decode can first renormalize the lossy quantized direction and then apply the original norm. -//! -//! Storing the correction also keeps future query kernels cheap. Inner product and cosine kernels can -//! rotate a query once and gather against centroids directly; the per-row scale they need is already -//! available as `norms * inv_direction_norms` for inner product and `inv_direction_norms` for cosine. -//! Without this field, those kernels would have to recompute the inverse SORF/truncated norm per row -//! or give up the `TQDecode` norm-preservation invariant. +//! Stored norms are authoritative for future TurboQuant-aware scalar functions. The rationale +//! for the `inv_direction_norms` correction field lives next to the storage layout; see +//! `vector/storage.rs`. //! //! # Source map //! //! Implementation details are documented next to the code that owns them: //! -//! - `vector/storage.rs`: physical storage shape, full-length child arrays, and field-level -//! validity for null vectors. -//! - `vector/normalize.rs`: TurboQuant-local normalization and how it differs from the tensor -//! crate's null-row zeroing helper. -//! - `vector/quantize.rs`: SORF transform, centroid lookup, and why invalid rows are skipped rather -//! than quantized. +//! - `vector/storage.rs`: physical storage shape and parsing. +//! - `vector/normalize.rs`: TurboQuant-local normalization and the encode-time finite-norm +//! guard. +//! - `vector/quantize.rs`: SORF transform, centroid lookup, and the per-row +//! `inv_direction_norm` computation. +//! - `scalar_fns/compute/`: session-scoped optimizer kernels that intercept canonical scalar +//! functions over TurboQuant inputs (currently `L2Norm(TQDecode(_))`). //! - `centroids.rs`: deterministic Max-Lloyd centroid computation and process-local caching. -//! 
- `sorf/`: the Walsh-Hadamard-based structured transform and the stable SplitMix64 sign stream. +//! - `sorf/`: Walsh-Hadamard-based structured transform plus the stable SplitMix64 sign +//! stream. //! //! The current encoding is intentionally MSE-only. It does not yet implement the paper's QJL //! residual correction for unbiased inner-product estimation, and it still uses internal diff --git a/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs b/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs index e770565c514..0948e200fac 100644 --- a/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs +++ b/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs @@ -11,17 +11,18 @@ use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFn; use vortex_array::arrays::scalar_fn::ExactScalarFn; use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; +use vortex_array::dtype::Nullability; use vortex_array::optimizer::kernels::ArrayKernelsExt; use vortex_array::optimizer::kernels::ExecuteParentFn; use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_array::validity::Validity; use vortex_error::VortexResult; use vortex_error::vortex_ensure_eq; use vortex_session::VortexSession; use vortex_tensor::scalar_fns::l2_norm::L2Norm; use crate::TQDecode; -use crate::vector::storage::parse_storage; -use crate::vtable::TurboQuant; +use crate::vector::storage::parse_storage_norms_only; /// Register the `L2Norm(TQDecode(_))` execute-parent kernel on the session. pub(super) fn register(session: &VortexSession) { @@ -34,13 +35,14 @@ pub(super) fn register(session: &VortexSession) { /// Intercepts `L2Norm(TQDecode(tq_arr))` and returns the stored TurboQuant `norms` field. /// -/// The kernel only fires when both the parent matches `ExactScalarFn` and the child -/// matches `ExactScalarFn`. Returns `Ok(None)` for any other shape so the canonical -/// `L2Norm` path runs unchanged. 
-// -// This is semantically correct because TurboQuant stores per-row inverse direction norms and -// `TQDecode` applies that correction before re-applying the original row norm. In other words, -// valid nonzero decoded rows preserve the stored L2 norm even though coordinates are lossy. +/// Semantically valid because [`TQDecode`] renormalizes the lossy quantized direction with the +/// stored inverse direction-norm before re-applying the original row norm, so decoded rows +/// preserve the stored L2 norm. The kernel returns `Ok(None)` for any non-matching parent / +/// child pair so the canonical `L2Norm` path runs unchanged. +/// +/// The result's nullability is coerced to the parent's expected dtype because the stored +/// `norms` child may be wider than the outer struct (a shape [`parse_storage_norms_only`] +/// accepts). fn l2_norm_tq_decode_execute_parent( child: &ArrayRef, parent: &ArrayRef, @@ -55,24 +57,16 @@ fn l2_norm_tq_decode_execute_parent( } let tq_array = child.as_::().child_at(0).clone(); + let parsed = parse_storage_norms_only(tq_array, ctx)?; - // Defensive: TQDecode's signature already guarantees this, but a misregistration or a - // future TQDecode that takes a wrapped child should fall back to the canonical path. 
- if tq_array - .dtype() - .as_extension_opt() - .and_then(|d| d.metadata_opt::()) - .is_none() - { - return Ok(None); - } - - let parsed = parse_storage(tq_array, ctx)?; - let norms_validity = parsed.norms.validity()?; + let norms_validity = match parent.dtype().nullability() { + Nullability::NonNullable => Validity::NonNullable, + Nullability::Nullable => parsed.vector_validity, + }; let norms = PrimitiveArray::from_buffer_handle( parsed.norms.buffer_handle().clone(), parsed.norms.ptype(), - norms_validity.and(parsed.vector_validity)?, + norms_validity, ) .into_array(); diff --git a/vortex-turboquant/src/scalar_fns/compute/mod.rs b/vortex-turboquant/src/scalar_fns/compute/mod.rs index 5f560f97eb5..7dbc9fb9412 100644 --- a/vortex-turboquant/src/scalar_fns/compute/mod.rs +++ b/vortex-turboquant/src/scalar_fns/compute/mod.rs @@ -3,15 +3,20 @@ //! TurboQuant-specific session-scoped optimizer kernels. //! -//! Each kernel module owns its own [`ArrayKernelsExt::register_execute_parent`] call. New -//! kernels (e.g. for `InnerProduct` or `CosineSimilarity`) should be added as sibling modules -//! and threaded through [`register_kernels`] with a single line. +//! Each kernel module owns its own +//! [`register_execute_parent`](vortex_array::optimizer::kernels::ArrayKernelsExt::register_execute_parent) +//! call. New kernels (for example `InnerProduct` or `CosineSimilarity`) should be added as +//! sibling modules and threaded through [`register_kernels`] with a single line. mod l2_norm; use vortex_session::VortexSession; -/// Register every TurboQuant kernel on `session`. +/// Register every TurboQuant-specific optimizer kernel on `session`. +/// +/// Called from the crate-level [`crate::initialize`] after the TurboQuant extension type and +/// the `TQEncode` / `TQDecode` scalar functions are registered, so kernels can resolve the +/// scalar-fn ids they intercept. 
pub(crate) fn register_kernels(session: &VortexSession) { l2_norm::register(session); } diff --git a/vortex-turboquant/src/scalar_fns/decode.rs b/vortex-turboquant/src/scalar_fns/decode.rs index 72193d31262..332aaca1a36 100644 --- a/vortex-turboquant/src/scalar_fns/decode.rs +++ b/vortex-turboquant/src/scalar_fns/decode.rs @@ -219,6 +219,8 @@ struct DecodeInputs<'a> { centroids: &'a [f32], /// Per-row stored L2 norm of the original input vector, in the element ptype. norms: &'a PrimitiveArray, + /// Per-row reciprocal of the decoded direction's L2 norm, always in f32. See + /// [`crate::vector::storage`] for the sentinel semantics. inv_direction_norms: &'a PrimitiveArray, /// Flat per-row centroid indices, `num_vectors * padded_dim` bytes. codes: &'a PrimitiveArray, diff --git a/vortex-turboquant/src/tests/encode_decode.rs b/vortex-turboquant/src/tests/encode_decode.rs index 73e6e3eedbf..36d691a3f3a 100644 --- a/vortex-turboquant/src/tests/encode_decode.rs +++ b/vortex-turboquant/src/tests/encode_decode.rs @@ -200,6 +200,130 @@ fn decode_preserves_original_l2_norms_for_non_power_of_two_dimensions() -> Vorte Ok(()) } +/// Encode rejects rows whose L2 norm is non-finite. Without this guard, a row whose squared +/// sum overflows would normalize to all-zero placeholders and decode-vs-kernel would silently +/// diverge (`NaN` vs `+inf`). +#[test] +fn encode_rejects_non_finite_norms() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + + // A row of `1e30` repeated `dim=128` times has squared sum `128 * 1e60 ≈ 1.28e62`, which + // overflows `f32` (max ≈ 3.4e38) and produces `+inf` when `L2Norm` runs in `f32`. 
+ let values = vec![1e30f32; 128]; + let input = vector_array(128, &values, Validity::NonNullable)?; + + let result = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx); + assert!( + result.is_err(), + "encode must reject non-finite norms (overflow case)" + ); + let error = result.err().unwrap().to_string(); + assert!( + error.contains("non-finite"), + "expected non-finite error, got: {error}" + ); + Ok(()) +} + +/// Encode rejects rows containing `NaN` values, which propagate through `L2Norm` to produce +/// a `NaN` stored norm. +#[test] +fn encode_rejects_nan_input() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + + let mut values = vec![1.0f32; 128]; + values[0] = f32::NAN; + let input = vector_array(128, &values, Validity::NonNullable)?; + + let result = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx); + assert!(result.is_err(), "encode must reject NaN input rows"); + Ok(()) +} + +/// Decode preserves stored L2 norms across element ptypes and padded/unpadded dimensions. 
+#[rstest]
+#[case::f16_dim_128(PType::F16, 128_u32, 1e-2_f32)]
+#[case::f16_dim_129(PType::F16, 129_u32, 1e-2_f32)]
+#[case::f32_dim_128(PType::F32, 128_u32, 1e-4_f32)]
+#[case::f32_dim_129(PType::F32, 129_u32, 1e-4_f32)]
+#[case::f32_dim_257(PType::F32, 257_u32, 1e-4_f32)]
+#[case::f64_dim_128(PType::F64, 128_u32, 1e-4_f32)]
+#[case::f64_dim_129(PType::F64, 129_u32, 1e-4_f32)]
+fn decode_preserves_original_l2_norms_across_ptypes_and_dims(
+    #[case] ptype: PType,
+    #[case] dim: u32,
+    #[case] tolerance: f32,
+) -> VortexResult<()> {
+    let session = tensor_test_session();
+    let mut ctx = session.create_execution_ctx();
+    let rows = 3;
+    let raw = (0..rows * dim as usize)
+        .map(|i| (i % 17) as f32 - 8.0)
+        .map(|v| v * 0.25)
+        .collect::<Vec<f32>>();
+    let input = match ptype {
+        PType::F16 => {
+            let values: Vec<half::f16> = raw.iter().copied().map(half::f16::from_f32).collect();
+            vector_array(dim, &values, Validity::NonNullable)?
+        }
+        PType::F32 => vector_array(dim, &raw, Validity::NonNullable)?,
+        PType::F64 => {
+            let values: Vec<f64> = raw.iter().copied().map(f64::from).collect();
+            vector_array(dim, &values, Validity::NonNullable)?
+        }
+        _ => unreachable!("ptype must be float"),
+    };
+    let config = TurboQuantConfig::try_new(3, 42, 3)?;
+
+    let encoded = execute_tq_encode(input, &config, &mut ctx)?;
+    let decoded = execute_tq_decode(encoded, &mut ctx)?;
+    let decoded_norms: PrimitiveArray = L2Norm::try_new_array(decoded, rows)?
+        .into_array()
+        .execute(&mut ctx)?;
+
+    // L2Norm returns the element ptype; convert every norm to f32 for comparison.
+    let actuals: Vec<f32> = match ptype {
+        PType::F16 => decoded_norms
+            .as_slice::<half::f16>()
+            .iter()
+            .map(|v| f32::from(*v))
+            .collect(),
+        PType::F32 => decoded_norms.as_slice::<f32>().to_vec(),
+        PType::F64 => decoded_norms
+            .as_slice::<f64>()
+            .iter()
+            .map(|v| {
+                #[expect(
+                    clippy::cast_possible_truncation,
+                    reason = "norms are bounded by the test's input magnitudes (~|raw| * dim^0.5), \
+                              well within f32 range"
+                )]
+                let widened = *v as f32;
+                widened
+            })
+            .collect(),
+        _ => unreachable!(),
+    };
+
+    // Recompute expected from the raw f32 input to avoid coupling to internal storage.
+    let expected: Vec<f32> = (0..rows)
+        .map(|i| {
+            let row = &raw[i * dim as usize..][..dim as usize];
+            row.iter().map(|v| v * v).sum::<f32>().sqrt()
+        })
+        .collect();
+
+    for (actual, exp) in actuals.iter().zip(expected.iter()) {
+        assert!(
+            (*actual - *exp).abs() <= tolerance * exp.max(1.0),
+            "decoded norm {actual} did not match expected {exp} (ptype {ptype:?}, dim {dim})"
+        );
+    }
+    Ok(())
+}
+
 #[test]
 fn normalize_as_l2_denorm_preserves_child_validity() -> VortexResult<()> {
     let session = test_session();
diff --git a/vortex-turboquant/src/tests/file.rs b/vortex-turboquant/src/tests/file.rs
index e59b7a95c75..2c6b0184460 100644
--- a/vortex-turboquant/src/tests/file.rs
+++ b/vortex-turboquant/src/tests/file.rs
@@ -3,6 +3,7 @@

 use vortex_array::IntoArray;
 use vortex_array::VortexSessionExecute;
+use vortex_array::arrays::PrimitiveArray;
 use vortex_array::stream::ArrayStreamExt;
 use vortex_array::validity::Validity;
 use vortex_error::VortexResult;
@@ -10,6 +11,7 @@ use vortex_file::OpenOptionsSessionExt;
 use vortex_file::VortexWriteOptions;
 use vortex_io::runtime::BlockingRuntime;
 use vortex_io::runtime::single::SingleThreadRuntime;
+use vortex_tensor::scalar_fns::l2_norm::L2Norm;
 use vortex_tensor::vector::Vector;

 use super::execute_tq_decode_from_metadata;
@@ -19,6 +21,7 @@ use super::file_session;
 use super::vector_validity;
 use crate::TQDecode;
 use crate::TurboQuantConfig;
+use
crate::vector::storage::parse_storage;
 use crate::vtable::tq_metadata;

 #[test]
@@ -46,6 +49,54 @@ fn file_roundtrip_with_initialize_session() -> VortexResult<()> {
     Ok(())
 }

+/// File-roundtrip preserves `inv_direction_norms` and the `L2Norm(TQDecode(_))` fast-path
+/// invariant. A regression that silently dropped the field at serialization would only show
+/// up downstream as norm divergence; this test surfaces it at the IO layer.
+#[test]
+fn file_roundtrip_preserves_inv_direction_norms_and_l2_norm_invariant() -> VortexResult<()> {
+    let runtime = SingleThreadRuntime::default();
+    let session = file_session(&runtime);
+    let mut ctx = session.create_execution_ctx();
+    let input = f32_vector_array(128, 4, 0.25, Validity::NonNullable)?;
+    let config = TurboQuantConfig::try_new(3, 42, 3)?;
+    let encoded = execute_tq_encode(input, &config, &mut ctx)?;
+    let original_norms: PrimitiveArray = parse_storage(encoded.clone(), &mut ctx)?.norms;
+
+    let mut file_bytes = Vec::new();
+    VortexWriteOptions::new(session.clone())
+        .blocking(&runtime)
+        .write(&mut file_bytes, encoded.to_array_iterator())?;
+
+    let file = session.open_options().open_buffer(file_bytes)?;
+    let read = runtime.block_on(async { file.scan()?.into_array_stream()?.read_all().await })?;
+
+    // The inv_direction_norms field must survive serialization with finite-positive values for
+    // every valid row.
+    let parsed = parse_storage(read.clone(), &mut ctx)?;
+    let inv_direction_norms = parsed.inv_direction_norms.as_slice::<f32>();
+    assert_eq!(inv_direction_norms.len(), 4);
+    for &v in inv_direction_norms {
+        assert!(
+            v.is_finite() && v > 0.0,
+            "inv_direction_norm {v} after file roundtrip is not finite-positive"
+        );
+    }
+
+    // Fast-path `L2Norm(TQDecode(_))` must still return the originally stored row norms after
+    // the file roundtrip. If the kernel or the `inv_direction_norms` field had silently broken
+    // at serialization, this is where it would surface.
+    let decoded = TQDecode::try_new_array(read)?.into_array();
+    let kernel_norms: PrimitiveArray = L2Norm::try_new_array(decoded, 4)?
+        .into_array()
+        .execute(&mut ctx)?;
+    assert_eq!(
+        kernel_norms.as_slice::<f32>(),
+        original_norms.as_slice::<f32>(),
+        "L2Norm(TQDecode(read_back)) must equal the originally stored row norms"
+    );
+    Ok(())
+}
+
 #[test]
 fn file_roundtrip_lazy_decode_scalar_fn_with_initialize_session() -> VortexResult<()> {
     let runtime = SingleThreadRuntime::default();
diff --git a/vortex-turboquant/src/tests/kernels.rs b/vortex-turboquant/src/tests/kernels.rs
index c1120b64b83..cf58bab352b 100644
--- a/vortex-turboquant/src/tests/kernels.rs
+++ b/vortex-turboquant/src/tests/kernels.rs
@@ -7,10 +7,14 @@ use rstest::rstest;
 use vortex_array::IntoArray;
 use vortex_array::VortexSessionExecute;
 use vortex_array::arrays::BoolArray;
+use vortex_array::arrays::ExtensionArray;
+use vortex_array::arrays::FixedSizeListArray;
 use vortex_array::arrays::PrimitiveArray;
+use vortex_array::arrays::StructArray;
 use vortex_array::assert_arrays_eq;
 use vortex_array::builtins::ArrayBuiltins;
 use vortex_array::dtype::DType;
+use vortex_array::dtype::FieldNames;
 use vortex_array::dtype::Nullability;
 use vortex_array::dtype::PType;
 use vortex_array::validity::Validity;
@@ -18,12 +22,15 @@ use vortex_buffer::Buffer;
 use vortex_error::VortexResult;
 use vortex_tensor::scalar_fns::l2_norm::L2Norm;

+use super::execute_tq_decode;
 use super::execute_tq_encode;
 use super::f32_vector_array;
 use super::tensor_test_session;
 use super::vector_array;
 use crate::TQDecode;
+use crate::TurboQuant;
 use crate::TurboQuantConfig;
+use crate::TurboQuantMetadata;
 use crate::vector::storage::parse_storage;

 const DIM: u32 = 128;
@@ -158,3 +165,87 @@ fn l2_norm_over_masked_tq_decode_uses_storage_validity() -> VortexResult<()> {
     );
     Ok(())
 }
+
+/// Regression for the wider-child-nullability shape (`Nullable` `norms` with `AllValid` under
+/// a `NonNullable` struct).
`parse_storage` accepts it; the kernel must return a `NonNullable`
+/// result rather than reusing the wider child validity. See `malformed.rs` for the matching
+/// decode-side cases.
+#[test]
+fn l2_norm_over_tq_decode_nullable_norms_under_nonnullable_struct() -> VortexResult<()> {
+    let session = tensor_test_session();
+    let mut ctx = session.create_execution_ctx();
+    let metadata = TurboQuantMetadata {
+        element_ptype: PType::F32,
+        dimensions: DIM,
+        bit_width: 1,
+        seed: 42,
+        num_rounds: 3,
+    };
+
+    let norms =
+        PrimitiveArray::new::<f32>(Buffer::copy_from([5.0f32]), Validity::AllValid).into_array();
+    let inv_direction_norms =
+        PrimitiveArray::new::<f32>(Buffer::copy_from([1.0f32]), Validity::AllValid).into_array();
+    let codes = PrimitiveArray::new::<u8>(vec![0u8; DIM as usize], Validity::NonNullable);
+    let codes =
+        FixedSizeListArray::try_new(codes.into_array(), DIM, Validity::AllValid, 1)?.into_array();
+    let storage = StructArray::try_new(
+        FieldNames::from(["norms", "inv_direction_norms", "codes"]),
+        vec![norms, inv_direction_norms, codes],
+        1,
+        Validity::NonNullable,
+    )?;
+    let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array())?
+        .into_array();
+
+    let decoded = TQDecode::try_new_array(tq)?.into_array();
+    let result: PrimitiveArray = L2Norm::try_new_array(decoded, 1)?
+        .into_array()
+        .execute(&mut ctx)?;
+
+    assert_eq!(
+        result.dtype(),
+        &DType::Primitive(PType::F32, Nullability::NonNullable),
+        "kernel result dtype must match parent (NonNullable), not the wider child validity"
+    );
+    assert_eq!(result.as_slice::<f32>(), &[5.0f32]);
+    Ok(())
+}
+
+/// Cross-check the kernel result against the canonical `L2Norm(execute(TQDecode))` path.
+/// Materializing the decoded vector first breaks the `(L2Norm, TQDecode)` pattern so `L2Norm`
+/// runs through the canonical scalar-function flow.
+#[rstest]
+#[case::dim_128(128_u32)]
+#[case::dim_129(129_u32)]
+#[case::dim_257(257_u32)]
+fn l2_norm_over_tq_decode_matches_canonical(#[case] dim: u32) -> VortexResult<()> {
+    let session = tensor_test_session();
+    let mut ctx = session.create_execution_ctx();
+    let input = f32_vector_array(dim, 4, 0.25, Validity::NonNullable)?;
+    let config = TurboQuantConfig::try_new(3, 42, 3)?;
+
+    let encoded = execute_tq_encode(input, &config, &mut ctx)?;
+
+    let kernel_result: PrimitiveArray =
+        L2Norm::try_new_array(TQDecode::try_new_array(encoded.clone())?.into_array(), 4)?
+            .into_array()
+            .execute(&mut ctx)?;
+
+    // Materialize the decoded vector first so `L2Norm` cannot match `(L2Norm, TQDecode)`. The
+    // resulting `L2Norm(Vector)` flows through the canonical scalar-function path.
+    let decoded = execute_tq_decode(encoded, &mut ctx)?;
+    let canonical_result: PrimitiveArray = L2Norm::try_new_array(decoded, 4)?
+        .into_array()
+        .execute(&mut ctx)?;
+
+    let kernel = kernel_result.as_slice::<f32>();
+    let canonical = canonical_result.as_slice::<f32>();
+    for (k, c) in kernel.iter().zip(canonical.iter()) {
+        assert!(
+            (*k - *c).abs() <= 1e-4 * c.max(1.0),
+            "kernel result {k} disagrees with canonical {c} (dim {dim})"
+        );
+    }
+    Ok(())
+}
diff --git a/vortex-turboquant/src/tests/parity.rs b/vortex-turboquant/src/tests/parity.rs
index 305b96606d5..e65f132ab57 100644
--- a/vortex-turboquant/src/tests/parity.rs
+++ b/vortex-turboquant/src/tests/parity.rs
@@ -13,7 +13,12 @@ use super::f32_vector_array;
 use super::test_session;
 use super::vector_values_f32;
 use crate::TurboQuantConfig;
+use crate::vector::storage::parse_storage;

+/// Pins down the exact relationship between new and legacy TurboQuant decode: for each row,
+/// `new_value[i] == old_value[i] * inv_direction_norm[row]`. The centroid table and SORF
+/// transform are identical between the two encoders, so the inverse-transformed direction is
+/// the same; the only mathematical difference is the per-row scalar correction.
#[test] fn encode_decode_applies_direction_norm_correction_after_old_turboquant_decode() -> VortexResult<()> { @@ -23,6 +28,9 @@ fn encode_decode_applies_direction_norm_correction_after_old_turboquant_decode() let config = TurboQuantConfig::try_new(3, 42, 3)?; let new_encoded = execute_tq_encode(input.clone(), &config, &mut ctx)?; + let parsed = parse_storage(new_encoded.clone(), &mut ctx)?; + let inv_direction_norms = parsed.inv_direction_norms.as_slice::().to_vec(); + let new_decoded = execute_tq_decode(new_encoded, &mut ctx)?; let old_config = OldTurboQuantConfig { bit_width: config.bit_width(), @@ -33,13 +41,30 @@ fn encode_decode_applies_direction_norm_correction_after_old_turboquant_decode() let new_values = vector_values_f32(new_decoded, &mut ctx)?; let old_values = vector_values_f32(old_decoded, &mut ctx)?; + let dim = new_values.len() / inv_direction_norms.len(); + // Every coordinate of the new decode should equal the corresponding coordinate of the old + // decode, multiplied by the row's stored `inv_direction_norm`. Tolerance is per-element, + // scaled by the larger of the two values to handle exact zeros gracefully. + for (row, &correction) in inv_direction_norms.iter().enumerate() { + for col in 0..dim { + let idx = row * dim + col; + let new_v = new_values[idx]; + let old_v = old_values[idx]; + let expected = old_v * correction; + let scale = new_v.abs().max(expected.abs()).max(1.0); + assert!( + (new_v - expected).abs() <= 1e-4 * scale, + "row {row} col {col}: new {new_v} != old {old_v} * inv_direction_norm \ + {correction} (= {expected})" + ); + } + } + // Sanity: the correction is meaningfully non-trivial for at least one row (verifies the + // direction-norm field is actually doing work, not a no-op). 
assert!( - new_values - .iter() - .zip(old_values.iter()) - .any(|(new, old)| (*new - *old).abs() > 1e-6), - "direction-norm correction should intentionally change decoded values" + inv_direction_norms.iter().any(|&c| (c - 1.0).abs() > 1e-4), + "inv_direction_norms should differ from 1.0 for at least one row" ); Ok(()) } diff --git a/vortex-turboquant/src/vector/normalize.rs b/vortex-turboquant/src/vector/normalize.rs index 642949eecf6..05eadb3792c 100644 --- a/vortex-turboquant/src/vector/normalize.rs +++ b/vortex-turboquant/src/vector/normalize.rs @@ -24,6 +24,7 @@ use vortex_array::match_each_float_ptype; use vortex_array::validity::Validity; use vortex_buffer::BufferMut; use vortex_error::VortexResult; +use vortex_error::vortex_bail; use vortex_error::vortex_ensure_eq; use vortex_error::vortex_err; use vortex_mask::Mask; @@ -102,6 +103,23 @@ where let values = elements.as_slice::(); let norm_values = norms.as_slice::(); + // Reject non-finite norms up front. A `+inf` or `NaN` norm would either come from an input + // row whose sum of squares overflowed `T` or from a pre-existing `NaN` in the data. In + // either case the f32-precision SORF transform downstream cannot represent the row, and + // letting the division `value / norm` proceed silently corrupts encoded data: the + // normalized row becomes all zeros, encode infers a zero-norm row, and the stored + // `inv_direction_norm` sentinel disagrees with the non-finite stored norm at decode time. + // Fail fast instead. Invalid-row norms can carry arbitrary placeholders so they are + // excluded from the check via the row-validity mask. 
+ let mask_for_check = mask; + for (i, &norm) in norm_values.iter().enumerate() { + if mask_for_check.value(i) && !norm.is_finite() { + vortex_bail!( + "TurboQuant input row {i} has non-finite L2 norm; encode requires finite norms" + ); + } + } + let output_len = num_vectors .checked_mul(dimensions) .ok_or_else(|| vortex_err!("TurboQuant normalized vector length overflow"))?; diff --git a/vortex-turboquant/src/vector/quantize.rs b/vortex-turboquant/src/vector/quantize.rs index b7fbc8e1935..0521a6021df 100644 --- a/vortex-turboquant/src/vector/quantize.rs +++ b/vortex-turboquant/src/vector/quantize.rs @@ -32,13 +32,20 @@ use crate::centroids::compute_or_get_centroids; use crate::centroids::find_nearest_centroid; use crate::sorf::SorfMatrix; -/// Shared intermediate results from the quantization loop. +/// Intermediate output from the quantization loop, consumed by `encode_vector` to assemble +/// the storage struct. Invalid rows hold zero placeholders in both buffers. pub(crate) struct QuantizationResult { + /// Flat `padded_dim`-strided centroid indices, `num_vectors * padded_dim` entries. pub(crate) all_indices: Buffer, + /// Per-row reciprocal L2 norm of the decoded quantized direction. See the comment inside + /// [`turboquant_quantize_core`] for the `0.0` sentinel cases. pub(crate) inv_direction_norms: Buffer, + /// SORF padded dimension, `next_power_of_two(dimensions)`. pub(crate) padded_dim: usize, } +/// Build an empty [`QuantizationResult`] for a zero-row input, so the SORF machinery does not +/// run with a zero-length elements buffer. pub(crate) fn empty_quantization(padded_dim: usize) -> QuantizationResult { QuantizationResult { all_indices: Buffer::empty(), @@ -112,11 +119,12 @@ pub(crate) unsafe fn turboquant_quantize_core( *dst = centroids[usize::from(code)]; } - // Quantization perturbs the unit direction. 
The exact norm that decode will return is - // the norm after dequantizing centroids, inverse-transforming, and truncating away any - // padded dimensions. Precomputing its reciprocal here lets `TQDecode` preserve the - // original row norm and lets future query kernels reuse the same per-row correction - // without repeating this inverse transform for every query. + // The all-zero `row_values` check fires only for valid zero-norm rows (the + // normalize step pushes zero placeholders for those; non-finite input norms are + // rejected earlier). The `is_normal` guard handles the remaining numerical edge: + // a denormal `norm_squared` would produce a huge-or-infinite `recip` that decode + // would propagate as `+inf` / `NaN`. Both cases store `0.0` so decode emits a + // zero row, matching the stored norm. let inv_direction_norm = if row_values.iter().all(|&value| value == 0.0) { 0.0 } else { @@ -125,10 +133,10 @@ pub(crate) unsafe fn turboquant_quantize_core( .iter() .map(|value| value * value) .sum::(); - if norm_squared == 0.0 { - 0.0 - } else { + if norm_squared.is_normal() { norm_squared.sqrt().recip() + } else { + 0.0 } }; diff --git a/vortex-turboquant/src/vector/storage.rs b/vortex-turboquant/src/vector/storage.rs index 3461f389a97..ad24916a759 100644 --- a/vortex-turboquant/src/vector/storage.rs +++ b/vortex-turboquant/src/vector/storage.rs @@ -13,18 +13,15 @@ //! } //! ``` //! -//! Row nullability is carried on the outer struct and on every row-aligned field array. This is +//! `inv_direction_norms` is pinned to `f32` regardless of `element_ptype` because the SORF +//! transform and the centroid codebook are both `f32`; storing it wider would add precision the +//! underlying computation does not have. +//! +//! Row nullability is carried on the outer struct AND on every row-aligned field array. This is //! deliberate duplication: null vectors remain null throughout encode/decode instead of being //! converted into zero vectors. 
The code bytes and inverse direction norms for invalid rows are -//! physical placeholders only; the field-level validity records that those rows were not quantized. -//! -//! `inv_direction_norms` is stored instead of recomputed by decode or every future query kernel -//! because it is a property of the quantized row, not of a particular operation. Computing it exactly -//! requires dequantizing the row's centroids, applying the inverse SORF, truncating back to the -//! original dimension, and taking the f32 L2 norm of that decoded direction. That is the same -//! per-row work decode already needs, and it would be repeated for every scan/query if we did not -//! persist the result. Storing one f32 per row makes the norm-preserving decode contract explicit -//! and gives query kernels a reusable scale factor. +//! physical placeholders only; the field-level validity records that those rows were not +//! quantized. //! //! Parsing treats the outer struct validity as authoritative. Child validity may be wider than //! the struct validity (for example after a generic mask only updates the struct validity), but @@ -70,6 +67,12 @@ pub(crate) struct TurboQuantParsedStorage { pub(crate) vector_validity: Validity, /// Per-row stored L2 norm of the original input vector, in `metadata.element_ptype`. pub(crate) norms: PrimitiveArray, + /// Per-row reciprocal L2 norm of the decoded direction (always `f32`). Multiplied through + /// in `TQDecode` so that `L2Norm(TQDecode(_))` preserves the stored row norm. A stored + /// `0.0` is a sentinel telling decode to emit an all-zero row; it pairs with a stored + /// norm of `0.0` for valid zero-norm input rows and for the rare denormal-cancellation + /// case (encode rejects non-finite input norms up front, so those cannot reach this + /// field). pub(crate) inv_direction_norms: PrimitiveArray, /// Flat `u8` per-row centroid indices, `num_vectors * padded_dim` entries long. 
pub(crate) codes: PrimitiveArray, @@ -77,6 +80,16 @@ pub(crate) struct TurboQuantParsedStorage { pub(crate) len: usize, } +/// Subset of [`TurboQuantParsedStorage`] containing only the `norms` child plus the outer +/// struct validity. Used by the `L2Norm(TQDecode(_))` execute-parent kernel, which has no need +/// for the `codes` or `inv_direction_norms` children. +pub(crate) struct TurboQuantParsedNorms { + /// Authoritative row validity for the quantized vectors. + pub(crate) vector_validity: Validity, + /// Per-row stored L2 norm of the original input vector, in `metadata.element_ptype`. + pub(crate) norms: PrimitiveArray, +} + /// Build the `codes: FixedSizeList, padded_dim>` storage child. /// /// Each row of `padded_dim` u8 codes indexes into the deterministic centroid codebook derived @@ -118,6 +131,11 @@ pub(crate) fn build_storage( } /// Parse a TurboQuant extension array into executed storage children. +/// +/// Executes all three storage children, validates that every child's row validity covers the +/// outer struct validity, and returns the parsed result. Used by `TQDecode`, which needs every +/// child. Kernels that only need a subset should use a narrower helper (for example +/// [`parse_storage_norms_only`]) to avoid executing the children they will not consume. pub(crate) fn parse_storage( input: ArrayRef, ctx: &mut ExecutionCtx, @@ -169,6 +187,39 @@ pub(crate) fn parse_storage( }) } +/// Narrow form of [`parse_storage`] that returns only the `norms` child plus the outer struct +/// validity. Used by the `L2Norm(TQDecode(_))` kernel so the fast path does not execute the +/// `codes` and `inv_direction_norms` children it has no use for. The `norms` child's validity +/// is still validated against the struct's; the other children are not. 
+pub(crate) fn parse_storage_norms_only(
+    input: ArrayRef,
+    ctx: &mut ExecutionCtx,
+) -> VortexResult<TurboQuantParsedNorms> {
+    let ext: ExtensionArray = input.execute(ctx)?;
+    let storage: StructArray = ext.storage_array().clone().execute(ctx)?;
+
+    let norms: PrimitiveArray = storage
+        .unmasked_field_by_name(NORMS_FIELD)?
+        .clone()
+        .execute(ctx)?;
+
+    let len = storage.len();
+    let struct_validity = storage.struct_validity();
+    let norms_validity = norms.validity()?;
+
+    let struct_mask = struct_validity.execute_mask(len, ctx)?;
+    let norms_mask = norms_validity.execute_mask(len, ctx)?;
+    vortex_ensure!(
+        struct_mask.bitand_not(&norms_mask).all_false(),
+        "TurboQuant {NORMS_FIELD} row validity must cover storage validity"
+    );
+
+    Ok(TurboQuantParsedNorms {
+        vector_validity: struct_validity,
+        norms,
+    })
+}
+
 /// Validate that both child masks cover the struct mask: every row that the struct considers
 /// valid must also be valid in the `norms` and `codes` children.
 ///
diff --git a/vortex-turboquant/src/vtable.rs b/vortex-turboquant/src/vtable.rs
index 34896cf4396..4dc5370cc92 100644
--- a/vortex-turboquant/src/vtable.rs
+++ b/vortex-turboquant/src/vtable.rs
@@ -140,6 +140,9 @@ pub(crate) fn tq_metadata(dtype: &DType) -> VortexResult {
     Ok(*metadata)
 }

+/// Compute the storage [`DType`] for a TurboQuant array with the given `metadata` and outer
+/// `row_nullability`. Returns the `Struct { norms, inv_direction_norms, codes }` shape
+/// documented in [`crate::vector::storage`].
 pub(crate) fn tq_storage_dtype(
     metadata: &TurboQuantMetadata,
     row_nullability: Nullability,