Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions vortex-tensor/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub vortex_tensor::encodings::turboquant::TurboQuantConfig::bit_width: u8

pub vortex_tensor::encodings::turboquant::TurboQuantConfig::num_rounds: u8

pub vortex_tensor::encodings::turboquant::TurboQuantConfig::seed: core::option::Option<u64>
pub vortex_tensor::encodings::turboquant::TurboQuantConfig::seed: u64

impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantConfig

Expand Down Expand Up @@ -440,11 +440,11 @@ pub fn vortex_tensor::scalar_fns::sorf_transform::SorfMatrix::padded_dim(&self)

pub fn vortex_tensor::scalar_fns::sorf_transform::SorfMatrix::rotate(&self, input: &[f32], output: &mut [f32])

pub fn vortex_tensor::scalar_fns::sorf_transform::SorfMatrix::try_new(seed: u64, dimension: usize, num_rounds: usize) -> vortex_error::VortexResult<Self>
pub fn vortex_tensor::scalar_fns::sorf_transform::SorfMatrix::try_new(seed: u64, dimensions: usize, num_rounds: usize) -> vortex_error::VortexResult<Self>

pub struct vortex_tensor::scalar_fns::sorf_transform::SorfOptions

pub vortex_tensor::scalar_fns::sorf_transform::SorfOptions::dimension: u32
pub vortex_tensor::scalar_fns::sorf_transform::SorfOptions::dimensions: u32

pub vortex_tensor::scalar_fns::sorf_transform::SorfOptions::element_ptype: vortex_array::dtype::ptype::PType

Expand Down Expand Up @@ -490,7 +490,7 @@ pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::clone(&self) ->

impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_tensor::scalar_fns::sorf_transform::SorfTransform

pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::deserialize(&self, dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts<Self>>
pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::deserialize(&self, dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, session: &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts<Self>>

pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::serialize(&self, view: &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>

Expand Down
24 changes: 12 additions & 12 deletions vortex-tensor/src/encodings/turboquant/centroids.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ static CENTROID_CACHE: LazyLock<DashMap<(u32, u8), Buffer<f32>>> = LazyLock::new
/// Returns `2^bit_width` centroids sorted in ascending order, representing optimal scalar
/// quantization levels for the coordinate distribution after random rotation in
/// `dimension`-dimensional space.
pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Buffer<f32>> {
pub fn compute_or_get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Buffer<f32>> {
vortex_ensure!(
(1..=MAX_BIT_WIDTH).contains(&bit_width),
"TurboQuant bit_width must be 1-{}, got {bit_width}",
Expand Down Expand Up @@ -239,7 +239,7 @@ mod tests {
#[case] bits: u8,
#[case] expected: usize,
) -> VortexResult<()> {
let centroids = get_centroids(dim, bits)?;
let centroids = compute_or_get_centroids(dim, bits)?;
assert_eq!(centroids.len(), expected);
Ok(())
}
Expand All @@ -251,7 +251,7 @@ mod tests {
#[case(128, 4)]
#[case(768, 2)]
fn centroids_are_sorted(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> {
let centroids = get_centroids(dim, bits)?;
let centroids = compute_or_get_centroids(dim, bits)?;
for window in centroids.windows(2) {
assert!(
window[0] < window[1],
Expand All @@ -268,7 +268,7 @@ mod tests {
#[case(256, 2)]
#[case(768, 2)]
fn centroids_are_symmetric(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> {
let centroids = get_centroids(dim, bits)?;
let centroids = compute_or_get_centroids(dim, bits)?;
let count = centroids.len();
for idx in 0..count / 2 {
let diff = (centroids[idx] + centroids[count - 1 - idx]).abs();
Expand All @@ -287,7 +287,7 @@ mod tests {
#[case(128, 1)]
#[case(128, 4)]
fn centroids_within_bounds(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> {
let centroids = get_centroids(dim, bits)?;
let centroids = compute_or_get_centroids(dim, bits)?;
for &val in centroids.iter() {
assert!(
(-1.0..=1.0).contains(&val),
Expand All @@ -299,15 +299,15 @@ mod tests {

#[test]
fn centroids_cached() -> VortexResult<()> {
let c1 = get_centroids(128, 2)?;
let c2 = get_centroids(128, 2)?;
let c1 = compute_or_get_centroids(128, 2)?;
let c2 = compute_or_get_centroids(128, 2)?;
assert_eq!(c1, c2);
Ok(())
}

#[test]
fn find_nearest_basic() -> VortexResult<()> {
let centroids = get_centroids(128, 2)?;
let centroids = compute_or_get_centroids(128, 2)?;
let boundaries = compute_centroid_boundaries(&centroids);
assert_eq!(find_nearest_centroid(-1.0, &boundaries), 0);

Expand All @@ -324,9 +324,9 @@ mod tests {

#[test]
fn rejects_invalid_params() {
assert!(get_centroids(128, 0).is_err());
assert!(get_centroids(128, 9).is_err());
assert!(get_centroids(1, 2).is_err());
assert!(get_centroids(127, 2).is_err());
assert!(compute_or_get_centroids(128, 0).is_err());
assert!(compute_or_get_centroids(128, 9).is_err());
assert!(compute_or_get_centroids(1, 2).is_err());
assert!(compute_or_get_centroids(127, 2).is_err());
}
}
16 changes: 8 additions & 8 deletions vortex-tensor/src/encodings/turboquant/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ use vortex_error::vortex_ensure;
use crate::encodings::turboquant::MAX_BIT_WIDTH;
use crate::encodings::turboquant::MIN_DIMENSION;
use crate::encodings::turboquant::centroids::compute_centroid_boundaries;
use crate::encodings::turboquant::centroids::compute_or_get_centroids;
use crate::encodings::turboquant::centroids::find_nearest_centroid;
use crate::encodings::turboquant::centroids::get_centroids;
use crate::scalar_fns::l2_denorm::L2Denorm;
use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm;
use crate::scalar_fns::sorf_transform::SorfMatrix;
Expand All @@ -48,8 +48,8 @@ use crate::utils::cast_to_f32;
pub struct TurboQuantConfig {
/// Bits per coordinate (1-8).
pub bit_width: u8,
/// Optional seed for the rotation matrix. If None, the default seed is used.
pub seed: Option<u64>,
/// Seed for the rotation matrix.
pub seed: u64,
/// Number of sign-diagonal + WHT rounds in the structured rotation (default 3).
pub num_rounds: u8,
}
Expand All @@ -58,7 +58,7 @@ impl Default for TurboQuantConfig {
fn default() -> Self {
Self {
bit_width: MAX_BIT_WIDTH,
seed: Some(42),
seed: 42,
num_rounds: 3,
}
}
Expand Down Expand Up @@ -141,7 +141,7 @@ pub unsafe fn turboquant_encode_unchecked(
let vector_metadata = ext_dtype.as_extension().metadata::<AnyVector>();
let element_ptype = vector_metadata.element_ptype();

let seed = config.seed.unwrap_or(42);
let seed = config.seed;
let num_rows = fsl.len();

if fsl.is_empty() {
Expand All @@ -161,7 +161,7 @@ pub unsafe fn turboquant_encode_unchecked(
let sorf_options = SorfOptions {
seed,
num_rounds: config.num_rounds,
dimension,
dimensions: dimension,
element_ptype,
};
return Ok(
Expand All @@ -177,7 +177,7 @@ pub unsafe fn turboquant_encode_unchecked(
let sorf_options = SorfOptions {
seed,
num_rounds: config.num_rounds,
dimension,
dimensions: dimension,
element_ptype,
};
Ok(SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array())
Expand Down Expand Up @@ -213,7 +213,7 @@ fn turboquant_quantize_core(
let elements_prim: PrimitiveArray = fsl.elements().clone().execute(ctx)?;
let f32_elements = cast_to_f32(elements_prim)?;

let centroids = get_centroids(padded_dim_u32, bit_width)?;
let centroids = compute_or_get_centroids(padded_dim_u32, bit_width)?;
let boundaries = compute_centroid_boundaries(&centroids);

let mut all_indices = BufferMut::<u8>::with_capacity(num_rows * padded_dim);
Expand Down
2 changes: 1 addition & 1 deletion vortex-tensor/src/encodings/turboquant/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@
//! // Normalize and quantize at 2 bits per coordinate in one pass.
//! let session = VortexSession::empty().with::<ArraySession>();
//! let mut ctx = session.create_execution_ctx();
//! let config = TurboQuantConfig { bit_width: 2, seed: Some(42), num_rounds: 3 };
//! let config = TurboQuantConfig { bit_width: 2, seed: 42, num_rounds: 3 };
//! let tq = turboquant_encode(vector, &config, &mut ctx).unwrap();
//!
//! // Verify compression: 100 vectors x 128 dims x 4 bytes = 51200 bytes input.
Expand Down
18 changes: 11 additions & 7 deletions vortex-tensor/src/encodings/turboquant/scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,21 +111,25 @@ impl Scheme for TurboQuantScheme {
fn estimate_compression_ratio(element_bit_width: u8, dimensions: u32, num_vectors: usize) -> f64 {
let config = TurboQuantConfig::default();
let padded_dim = dimensions.next_power_of_two() as usize;
let element_bits = usize::from(element_bit_width);

// Per-vector: MSE codes per padded coordinate, plus one stored norm in the input element
// float width.
let compressed_bits_per_vector =
usize::from(element_bit_width) + usize::from(config.bit_width) * padded_dim;
// Get the size of the fully uncompressed vector data.
let uncompressed_size_bits = element_bits * dimensions as usize * num_vectors;

// Per-vector: MSE codes per padded coordinate, plus one stored norm in the input element float
// width.
let norm_bits = element_bits;
let compressed_bits_per_vector = usize::from(config.bit_width) * padded_dim;
let total_bits_per_vector = norm_bits + compressed_bits_per_vector;

// Shared overhead: codebook centroids (2^bit_width f32 values).
// Note: rotation signs are no longer stored — rotation is deterministic from seed.
let num_centroids = 1usize << config.bit_width;
debug_assert!(num_centroids <= MAX_CENTROIDS);
let overhead_bits = num_centroids * 32; // centroids are always f32

let compressed_size_bits = compressed_bits_per_vector * num_vectors + overhead_bits;
// This includes the quantized vectors, norms, and centroid codebook.
let compressed_size_bits = total_bits_per_vector * num_vectors + overhead_bits;

let uncompressed_size_bits = usize::from(element_bit_width) * dimensions as usize * num_vectors;
uncompressed_size_bits as f64 / compressed_size_bits as f64
}

Expand Down
10 changes: 5 additions & 5 deletions vortex-tensor/src/encodings/turboquant/tests/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ fn slice_preserves_data() -> VortexResult<()> {
let ext = make_vector_ext(&fsl);
let config = TurboQuantConfig {
bit_width: 3,
seed: Some(123),
seed: 123,
num_rounds: 4,
};
let mut ctx = SESSION.create_execution_ctx();
Expand Down Expand Up @@ -85,7 +85,7 @@ fn scalar_at_matches_decompress() -> VortexResult<()> {
let ext = make_vector_ext(&fsl);
let config = TurboQuantConfig {
bit_width: 3,
seed: Some(123),
seed: 123,
num_rounds: 2,
};
let mut ctx = SESSION.create_execution_ctx();
Expand All @@ -108,7 +108,7 @@ fn l2_norm_readthrough() -> VortexResult<()> {
let ext = make_vector_ext(&fsl);
let config = TurboQuantConfig {
bit_width: 3,
seed: Some(123),
seed: 123,
num_rounds: 5,
};
let mut ctx = SESSION.create_execution_ctx();
Expand Down Expand Up @@ -146,7 +146,7 @@ fn l2_norm_readthrough_is_authoritative_for_lossy_storage() -> VortexResult<()>
let ext = make_vector_ext(&fsl);
let config = TurboQuantConfig {
bit_width: 1,
seed: Some(123),
seed: 123,
num_rounds: 3,
};
let mut ctx = SESSION.create_execution_ctx();
Expand Down Expand Up @@ -183,7 +183,7 @@ fn cosine_similarity_readthrough_is_authoritative_for_lossy_storage() -> VortexR
let ext = make_vector_ext(&fsl);
let config = TurboQuantConfig {
bit_width: 1,
seed: Some(123),
seed: 123,
num_rounds: 3,
};
let mut ctx = SESSION.create_execution_ctx();
Expand Down
8 changes: 4 additions & 4 deletions vortex-tensor/src/encodings/turboquant/tests/nullable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ fn nullable_vectors_roundtrip() -> VortexResult<()> {

let config = TurboQuantConfig {
bit_width: 3,
seed: Some(123),
seed: 123,
num_rounds: 4,
};
let mut ctx = SESSION.create_execution_ctx();
Expand Down Expand Up @@ -84,7 +84,7 @@ fn nullable_norms_match_validity() -> VortexResult<()> {

let config = TurboQuantConfig {
bit_width: 2,
seed: Some(123),
seed: 123,
num_rounds: 3,
};
let mut ctx = SESSION.create_execution_ctx();
Expand Down Expand Up @@ -114,7 +114,7 @@ fn nullable_l2_norm_readthrough() -> VortexResult<()> {

let config = TurboQuantConfig {
bit_width: 3,
seed: Some(123),
seed: 123,
num_rounds: 3,
};
let mut ctx = SESSION.create_execution_ctx();
Expand Down Expand Up @@ -156,7 +156,7 @@ fn nullable_slice_preserves_validity() -> VortexResult<()> {

let config = TurboQuantConfig {
bit_width: 3,
seed: Some(123),
seed: 123,
num_rounds: 2,
};
let mut ctx = SESSION.create_execution_ctx();
Expand Down
Loading
Loading