Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vortex-turboquant/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ vortex-utils = { workspace = true, features = ["dashmap"] }
divan = { workspace = true }
rand = { workspace = true }
rstest = { workspace = true }
vortex-array = { workspace = true, features = ["_test-harness"] }
vortex-file = { workspace = true }
vortex-io = { workspace = true }
vortex-layout = { workspace = true }
Expand Down
28 changes: 17 additions & 11 deletions vortex-turboquant/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,39 @@
//! The [`TQEncode`] scalar function first computes and stores the original L2 norm for each vector
//! row, then normalizes each valid nonzero row internally before SORF transform and scalar
//! quantization. The [`TQDecode`] scalar function dequantizes through deterministic centroids,
//! applies the inverse SORF transform, truncates back to the original dimension, and re-applies the
//! stored norm.
//! applies the inverse SORF transform, truncates back to the original dimension, and applies a
//! stored inverse direction-norm correction before re-applying the stored norm.
//!
//! The encoded storage is a row-aligned extension tree:
//!
//! ```text
//! Extension<TurboQuant>(
//! Struct {
//! norms: Primitive<element_ptype, vector_validity>,
//! inv_direction_norms: Primitive<f32, vector_validity>,
//! codes: FixedSizeList<Primitive<u8>, padded_dim, vector_validity>,
//! }
//! )
//! ```
//!
//! Stored norms are authoritative for future TurboQuant-aware scalar functions. Decoded quantized
//! directions are not guaranteed to have unit norm after scalar quantization and inverse transform.
//! Stored norms are authoritative for future TurboQuant-aware scalar functions. The rationale
//! for the `inv_direction_norms` correction field lives next to the storage layout; see
//! `vector/storage.rs`.
//!
//! # Source map
//!
//! Implementation details are documented next to the code that owns them:
//!
//! - `vector/storage.rs`: physical storage shape, full-length child arrays, and field-level
//! validity for null vectors.
//! - `vector/normalize.rs`: TurboQuant-local normalization and how it differs from the tensor
//! crate's null-row zeroing helper.
//! - `vector/quantize.rs`: SORF transform, centroid lookup, and why invalid rows are skipped rather
//! than quantized.
//! - `vector/storage.rs`: physical storage shape and parsing.
//! - `vector/normalize.rs`: TurboQuant-local normalization and the encode-time finite-norm
//! guard.
//! - `vector/quantize.rs`: SORF transform, centroid lookup, and the per-row
//! `inv_direction_norm` computation.
//! - `scalar_fns/compute/`: session-scoped optimizer kernels that intercept canonical scalar
//! functions over TurboQuant inputs (currently `L2Norm(TQDecode(_))`).
//! - `centroids.rs`: deterministic Max-Lloyd centroid computation and process-local caching.
//! - `sorf/`: the Walsh-Hadamard-based structured transform and the stable SplitMix64 sign stream.
//! - `sorf/`: Walsh-Hadamard-based structured transform plus the stable SplitMix64 sign
//! stream.
//!
//! The current encoding is intentionally MSE-only. It does not yet implement the paper's QJL
//! residual correction for unbiased inner-product estimation, and it still uses internal
Expand Down Expand Up @@ -75,6 +79,8 @@ pub fn initialize(session: &vortex_session::VortexSession) {

session.scalar_fns().register(TQEncode);
session.scalar_fns().register(TQDecode);

scalar_fns::compute::register_kernels(session);
}

#[cfg(test)]
Expand Down
85 changes: 85 additions & 0 deletions vortex-turboquant/src/scalar_fns/compute/l2_norm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! `L2Norm` execute-parent kernel that intercepts `L2Norm(TQDecode(tq))` and returns the stored
//! per-row norms directly instead of decoding and recomputing.

use vortex_array::ArrayRef;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::ScalarFn;
use vortex_array::arrays::scalar_fn::ExactScalarFn;
use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
use vortex_array::dtype::Nullability;
use vortex_array::optimizer::kernels::ArrayKernelsExt;
use vortex_array::optimizer::kernels::ExecuteParentFn;
use vortex_array::scalar_fn::ScalarFnVTable;
use vortex_array::validity::Validity;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure_eq;
use vortex_session::VortexSession;
use vortex_tensor::scalar_fns::l2_norm::L2Norm;

use crate::TQDecode;
use crate::vector::storage::parse_storage_norms_only;

/// Installs [`l2_norm_tq_decode_execute_parent`] on `session` as the execute-parent kernel for
/// the `L2Norm(TQDecode(_))` pairing.
pub(super) fn register(session: &VortexSession) {
    // Resolve both scalar-fn ids up front, in the same order the registration call takes them.
    let l2_norm_id = L2Norm.id();
    let tq_decode_id = TQDecode.id();

    session.kernels().register_execute_parent(
        l2_norm_id,
        tq_decode_id,
        &[l2_norm_tq_decode_execute_parent as ExecuteParentFn],
    );
}

/// Execute-parent kernel for `L2Norm(TQDecode(tq_arr))`: answers with the stored TurboQuant
/// `norms` field instead of decoding the vectors and recomputing their norms.
///
/// The shortcut is sound because [`TQDecode`] rescales each lossy quantized direction by its
/// stored inverse direction-norm before multiplying by the original row norm, so decoded rows
/// carry the stored L2 norm.
///
/// # Returns
///
/// - `Ok(None)` when `parent` is not exactly `L2Norm` or `child` is not exactly `TQDecode`, so
///   the canonical `L2Norm` path runs unchanged.
/// - `Ok(Some(norms))` otherwise, where `norms` reuses the stored norms buffer.
///
/// # Errors
///
/// Propagates any failure from [`parse_storage_norms_only`], and fails if the rebuilt norms
/// array does not match the parent's dtype or length.
///
/// # Nullability
///
/// The output validity follows the parent's expected dtype rather than the stored `norms`
/// child, because that child may be wider than the outer struct (a shape
/// [`parse_storage_norms_only`] accepts).
fn l2_norm_tq_decode_execute_parent(
    child: &ArrayRef,
    parent: &ArrayRef,
    _child_idx: usize,
    ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
    // Only the exact `L2Norm` over exact `TQDecode` pairing is handled here.
    let is_l2_norm_of_tq_decode =
        parent.is::<ExactScalarFn<L2Norm>>() && child.is::<ExactScalarFn<TQDecode>>();
    if !is_l2_norm_of_tq_decode {
        return Ok(None);
    }

    // What the parent promises to produce; the shortcut result has to agree with both.
    let expected_dtype = parent.dtype();
    let expected_len = parent.len();

    // The sole input of `TQDecode` is the TurboQuant-encoded array.
    let encoded = child.as_::<ScalarFn>().child_at(0).clone();
    let storage = parse_storage_norms_only(encoded, ctx)?;

    let norms_buffer = storage.norms.buffer_handle().clone();
    let norms_ptype = storage.norms.ptype();

    // A non-nullable parent gets non-nullable output; otherwise use the vector-level validity.
    let output_validity = if matches!(expected_dtype.nullability(), Nullability::NonNullable) {
        Validity::NonNullable
    } else {
        storage.vector_validity
    };

    let norms =
        PrimitiveArray::from_buffer_handle(norms_buffer, norms_ptype, output_validity).into_array();

    vortex_ensure_eq!(
        norms.dtype(),
        expected_dtype,
        "TurboQuant norms field dtype must match L2Norm output dtype"
    );
    vortex_ensure_eq!(
        norms.len(),
        expected_len,
        "TurboQuant norms field length must match L2Norm output length"
    );

    Ok(Some(norms))
}
22 changes: 22 additions & 0 deletions vortex-turboquant/src/scalar_fns/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! TurboQuant-specific session-scoped optimizer kernels.
//!
//! Each kernel module owns its own
//! [`register_execute_parent`](vortex_array::optimizer::kernels::ArrayKernelsExt::register_execute_parent)
//! call. New kernels (for example `InnerProduct` or `CosineSimilarity`) should be added as
//! sibling modules and threaded through [`register_kernels`] with a single line.

mod l2_norm;

use vortex_session::VortexSession;

/// Register every TurboQuant-specific optimizer kernel on `session`.
///
/// This is the single entry point for kernel registration: it delegates to the `register`
/// function of each kernel module (currently only `l2_norm`). A new kernel module is wired in
/// by adding one more delegating call here.
///
/// Called from the crate-level [`crate::initialize`] after the TurboQuant extension type and
/// the `TQEncode` / `TQDecode` scalar functions are registered, so kernels can resolve the
/// scalar-fn ids they intercept.
pub(crate) fn register_kernels(session: &VortexSession) {
// `L2Norm(TQDecode(_))` shortcut kernel.
l2_norm::register(session);
}
15 changes: 11 additions & 4 deletions vortex-turboquant/src/scalar_fns/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,10 @@ impl ScalarFnVTable for TQDecode {

/// Decode a `TurboQuant` extension array back into a `Vector` extension array.
///
/// The decoded directions are inverse-transformed, truncated to the original dimension, and
/// multiplied by the stored row norms. The conversion is lossy and does not roundtrip with
/// [`TQEncode`](crate::TQEncode).
/// The decoded directions are inverse-transformed, truncated to the original dimension, normalized
/// by the stored inverse direction norms, and multiplied by the stored row norms. The conversion is
/// lossy and does not roundtrip with [`TQEncode`](crate::TQEncode), but valid nonzero decoded rows
/// preserve the original stored L2 norm.
pub(crate) fn decode_vector(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
let parsed = parse_storage(input, ctx)?;
let metadata = parsed.metadata;
Expand All @@ -177,6 +178,7 @@ pub(crate) fn decode_vector(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexRe
sorf_matrix: &transform,
centroids: &centroids,
norms: &parsed.norms,
inv_direction_norms: &parsed.inv_direction_norms,
codes: &parsed.codes,
},
parsed.vector_validity,
Expand Down Expand Up @@ -217,6 +219,9 @@ struct DecodeInputs<'a> {
centroids: &'a [f32],
/// Per-row stored L2 norm of the original input vector, in the element ptype.
norms: &'a PrimitiveArray,
/// Per-row reciprocal of the decoded direction's L2 norm, always in f32. See
/// [`crate::vector::storage`] for the sentinel semantics.
inv_direction_norms: &'a PrimitiveArray,
/// Flat per-row centroid indices, `num_vectors * padded_dim` bytes.
codes: &'a PrimitiveArray,
}
Expand All @@ -236,6 +241,7 @@ where
let padded_dim = decode.sorf_matrix.padded_dim();
let centroids = decode.centroids;
let norms = decode.norms.as_slice::<T>();
let inv_direction_norms = decode.inv_direction_norms.as_slice::<f32>();
let codes = decode.codes.as_slice::<u8>();
let mask = vector_validity.execute_mask(num_vectors, ctx)?;

Expand All @@ -259,11 +265,12 @@ where
decode.sorf_matrix.inverse_transform(&decoded, &mut inverse);

let norm = norms[i];
let inv_direction_norm = inv_direction_norms[i];
for &value in inverse.iter().take(dimensions) {
// `T::from_f32` is infallible for the supported float ptypes (`f16`, `f32`,
// `f64`): values outside `f16` range saturate to `±inf` rather than returning
// `None`.
let value = T::from_f32(value)
let value = T::from_f32(value * inv_direction_norm)
.vortex_expect("from_f32 is infallible for supported float types");

// SAFETY: total pushes across all match arms equal `output_len`.
Expand Down
18 changes: 16 additions & 2 deletions vortex-turboquant/src/scalar_fns/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use vortex_array::IntoArray;
use vortex_array::arrays::Extension;
use vortex_array::arrays::ExtensionArray;
use vortex_array::arrays::FixedSizeListArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::ScalarFnArray;
use vortex_array::arrays::extension::ExtensionArrayExt;
use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
Expand Down Expand Up @@ -209,7 +210,14 @@ pub(crate) fn encode_vector(
// SAFETY: `tq_normalize_as_l2_denorm` returned this normalized Vector child.
unsafe { turboquant_quantize_core(&normalized_fsl, config, ctx)? }
};
let codes = build_codes_child(num_vectors, core, vector_validity.clone())?;
let inv_direction_norms =
PrimitiveArray::new::<f32>(core.inv_direction_norms, vector_validity.clone()).into_array();
let codes = build_codes_child(
num_vectors,
core.all_indices,
core.padded_dim,
vector_validity.clone(),
)?;

let metadata = TurboQuantMetadata {
element_ptype,
Expand All @@ -218,7 +226,13 @@ pub(crate) fn encode_vector(
seed: config.seed(),
num_rounds: config.num_rounds(),
};
let storage = build_storage(norms, codes, num_vectors, vector_validity)?;
let storage = build_storage(
norms,
inv_direction_norms,
codes,
num_vectors,
vector_validity,
)?;

Ok(ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage)?.into_array())
}
2 changes: 2 additions & 0 deletions vortex-turboquant/src/scalar_fns/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

//! Scalar functions for lazy TurboQuant vector encode and decode operations.

pub(crate) mod compute;

mod decode;
mod encode;
mod metadata;
Expand Down
Loading
Loading