Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vortex-array/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -8756,9 +8756,9 @@ impl<V: vortex_array::dtype::extension::ExtVTable> vortex_array::dtype::extensio

pub type V::Match<'a> = &'a <V as vortex_array::dtype::extension::ExtVTable>::Metadata

pub fn V::matches(item: &vortex_array::dtype::extension::ExtDTypeRef) -> bool
pub fn V::matches(ext_dtype: &vortex_array::dtype::extension::ExtDTypeRef) -> bool

pub fn V::try_match<'a>(item: &'a vortex_array::dtype::extension::ExtDTypeRef) -> core::option::Option<<V as vortex_array::dtype::extension::Matcher>::Match>
pub fn V::try_match<'a>(ext_dtype: &'a vortex_array::dtype::extension::ExtDTypeRef) -> core::option::Option<<V as vortex_array::dtype::extension::Matcher>::Match>

pub type vortex_array::dtype::extension::ExtDTypePluginRef = alloc::sync::Arc<dyn vortex_array::dtype::extension::ExtDTypePlugin>

Expand Down
9 changes: 5 additions & 4 deletions vortex-array/src/dtype/extension/matcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ pub trait Matcher {
impl<V: ExtVTable> Matcher for V {
type Match<'a> = &'a V::Metadata;

fn matches(item: &ExtDTypeRef) -> bool {
item.0.as_any().is::<ExtDType<V>>()
fn matches(ext_dtype: &ExtDTypeRef) -> bool {
ext_dtype.0.as_any().is::<ExtDType<V>>()
}

fn try_match<'a>(item: &'a ExtDTypeRef) -> Option<Self::Match<'a>> {
item.0
fn try_match<'a>(ext_dtype: &'a ExtDTypeRef) -> Option<Self::Match<'a>> {
ext_dtype
.0
.as_any()
.downcast_ref::<ExtDType<V>>()
.map(|inner| inner.metadata())
Expand Down
96 changes: 92 additions & 4 deletions vortex-tensor/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub const vortex_tensor::encodings::turboquant::TurboQuant::MIN_DIMENSION: u32

pub fn vortex_tensor::encodings::turboquant::TurboQuant::try_new_array(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<vortex_tensor::encodings::turboquant::TurboQuantArray>

pub fn vortex_tensor::encodings::turboquant::TurboQuant::validate_dtype(dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<&vortex_array::dtype::extension::erased::ExtDTypeRef>
pub fn vortex_tensor::encodings::turboquant::TurboQuant::validate_dtype(dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_tensor::vector::VectorMatcherMetadata>

impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuant

Expand Down Expand Up @@ -188,6 +188,14 @@ pub type vortex_tensor::encodings::turboquant::TurboQuantArray = vortex_array::a

pub mod vortex_tensor::fixed_shape

pub struct vortex_tensor::fixed_shape::AnyFixedShapeTensor

impl vortex_array::dtype::extension::matcher::Matcher for vortex_tensor::fixed_shape::AnyFixedShapeTensor

pub type vortex_tensor::fixed_shape::AnyFixedShapeTensor::Match<'a> = vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>

pub fn vortex_tensor::fixed_shape::AnyFixedShapeTensor::try_match<'a>(ext_dtype: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option<Self::Match>

pub struct vortex_tensor::fixed_shape::FixedShapeTensor

impl core::clone::Clone for vortex_tensor::fixed_shape::FixedShapeTensor
Expand Down Expand Up @@ -230,6 +238,34 @@ pub fn vortex_tensor::fixed_shape::FixedShapeTensor::unpack_native<'a>(_ext_dtyp

pub fn vortex_tensor::fixed_shape::FixedShapeTensor::validate_dtype(ext_dtype: &vortex_array::dtype::extension::typed::ExtDType<Self>) -> vortex_error::VortexResult<()>

pub struct vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>

impl vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_>

pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_>::element_ptype(&self) -> vortex_array::dtype::ptype::PType

pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_>::list_size(&self) -> usize

pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_>::metadata(&self) -> &vortex_tensor::fixed_shape::FixedShapeTensorMetadata

impl<'a> core::clone::Clone for vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>

pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>::clone(&self) -> vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>

impl<'a> core::cmp::Eq for vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>

impl<'a> core::cmp::PartialEq for vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>

pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>::eq(&self, other: &vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>) -> bool

impl<'a> core::fmt::Debug for vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>

pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result

impl<'a> core::marker::Copy for vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>

impl<'a> core::marker::StructuralPartialEq for vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>

pub struct vortex_tensor::fixed_shape::FixedShapeTensorMetadata

impl vortex_tensor::fixed_shape::FixedShapeTensorMetadata
Expand Down Expand Up @@ -280,9 +316,19 @@ pub mod vortex_tensor::matcher

pub enum vortex_tensor::matcher::TensorMatch<'a>

pub vortex_tensor::matcher::TensorMatch::FixedShapeTensor(&'a vortex_tensor::fixed_shape::FixedShapeTensorMetadata)
pub vortex_tensor::matcher::TensorMatch::FixedShapeTensor(vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>)

pub vortex_tensor::matcher::TensorMatch::Vector(vortex_tensor::vector::VectorMatcherMetadata)

impl vortex_tensor::matcher::TensorMatch<'_>

pub fn vortex_tensor::matcher::TensorMatch<'_>::element_ptype(self) -> vortex_array::dtype::ptype::PType

pub fn vortex_tensor::matcher::TensorMatch<'_>::list_size(self) -> usize

pub vortex_tensor::matcher::TensorMatch::Vector
impl<'a> core::clone::Clone for vortex_tensor::matcher::TensorMatch<'a>

pub fn vortex_tensor::matcher::TensorMatch<'a>::clone(&self) -> vortex_tensor::matcher::TensorMatch<'a>

impl<'a> core::cmp::Eq for vortex_tensor::matcher::TensorMatch<'a>

Expand All @@ -294,6 +340,8 @@ impl<'a> core::fmt::Debug for vortex_tensor::matcher::TensorMatch<'a>

pub fn vortex_tensor::matcher::TensorMatch<'a>::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result

impl<'a> core::marker::Copy for vortex_tensor::matcher::TensorMatch<'a>

impl<'a> core::marker::StructuralPartialEq for vortex_tensor::matcher::TensorMatch<'a>

pub struct vortex_tensor::matcher::AnyTensor
Expand All @@ -302,7 +350,7 @@ impl vortex_array::dtype::extension::matcher::Matcher for vortex_tensor::matcher

pub type vortex_tensor::matcher::AnyTensor::Match<'a> = vortex_tensor::matcher::TensorMatch<'a>

pub fn vortex_tensor::matcher::AnyTensor::try_match<'a>(item: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option<Self::Match>
pub fn vortex_tensor::matcher::AnyTensor::try_match<'a>(ext_dtype: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option<Self::Match>

pub mod vortex_tensor::scalar_fns

Expand Down Expand Up @@ -456,6 +504,14 @@ impl core::marker::StructuralPartialEq for vortex_tensor::scalar_fns::ApproxOpti

pub mod vortex_tensor::vector

pub struct vortex_tensor::vector::AnyVector

impl vortex_array::dtype::extension::matcher::Matcher for vortex_tensor::vector::AnyVector

pub type vortex_tensor::vector::AnyVector::Match<'a> = vortex_tensor::vector::VectorMatcherMetadata

pub fn vortex_tensor::vector::AnyVector::try_match<'a>(ext_dtype: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option<Self::Match>

pub struct vortex_tensor::vector::Vector

impl core::clone::Clone for vortex_tensor::vector::Vector
Expand Down Expand Up @@ -498,4 +554,36 @@ pub fn vortex_tensor::vector::Vector::unpack_native<'a>(_ext_dtype: &'a vortex_a

pub fn vortex_tensor::vector::Vector::validate_dtype(ext_dtype: &vortex_array::dtype::extension::typed::ExtDType<Self>) -> vortex_error::VortexResult<()>

pub struct vortex_tensor::vector::VectorMatcherMetadata

impl vortex_tensor::vector::VectorMatcherMetadata

pub fn vortex_tensor::vector::VectorMatcherMetadata::dimensions(&self) -> u32

pub fn vortex_tensor::vector::VectorMatcherMetadata::element_ptype(&self) -> vortex_array::dtype::ptype::PType

pub fn vortex_tensor::vector::VectorMatcherMetadata::try_new(element_ptype: vortex_array::dtype::ptype::PType, dimensions: u32) -> vortex_error::VortexResult<Self>

impl core::clone::Clone for vortex_tensor::vector::VectorMatcherMetadata

pub fn vortex_tensor::vector::VectorMatcherMetadata::clone(&self) -> vortex_tensor::vector::VectorMatcherMetadata

impl core::cmp::Eq for vortex_tensor::vector::VectorMatcherMetadata

impl core::cmp::PartialEq for vortex_tensor::vector::VectorMatcherMetadata

pub fn vortex_tensor::vector::VectorMatcherMetadata::eq(&self, other: &vortex_tensor::vector::VectorMatcherMetadata) -> bool

impl core::fmt::Debug for vortex_tensor::vector::VectorMatcherMetadata

pub fn vortex_tensor::vector::VectorMatcherMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result

impl core::hash::Hash for vortex_tensor::vector::VectorMatcherMetadata

pub fn vortex_tensor::vector::VectorMatcherMetadata::hash<__H: core::hash::Hasher>(&self, state: &mut __H)

impl core::marker::Copy for vortex_tensor::vector::VectorMatcherMetadata

impl core::marker::StructuralPartialEq for vortex_tensor::vector::VectorMatcherMetadata

pub fn vortex_tensor::initialize(session: &vortex_session::VortexSession)
34 changes: 10 additions & 24 deletions vortex-tensor/src/encodings/turboquant/array/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ use vortex_error::vortex_ensure_eq;

use crate::encodings::turboquant::array::slots::Slot;
use crate::encodings::turboquant::vtable::TurboQuant;
use crate::utils::tensor_element_ptype;
use crate::utils::tensor_list_size;

/// TurboQuant array data.
///
Expand All @@ -41,16 +39,13 @@ pub struct TurboQuantData {
}

impl TurboQuantData {
/// Build a TurboQuant array with validation.
///
/// The `dimension` and `bit_width` are derived from the inputs:
/// - `dimension` from the `dtype`'s `FixedSizeList` storage list size.
/// - `bit_width` from `log2(centroids.len())` (0 for degenerate empty arrays).
/// Build a `TurboQuantData` with validation.
///
/// # Errors
///
/// Returns an error if the provided components do not satisfy the invariants documented
/// in [`new_unchecked`](Self::new_unchecked).
/// Returns an error if:
/// - `dimension` is less than [`MIN_DIMENSION`](TurboQuant::MIN_DIMENSION).
/// - `bit_width` is greater than 8.
pub fn try_new(dimension: u32, bit_width: u8) -> VortexResult<Self> {
vortex_ensure!(
dimension >= TurboQuant::MIN_DIMENSION,
Expand All @@ -67,23 +62,14 @@ impl TurboQuantData {
})
}

/// Build a TurboQuant array without validation.
/// Build a `TurboQuantData` without validation.
///
/// # Safety
///
/// The caller must ensure:
///
/// - `dtype` is a [`Vector`](crate::vector::Vector) extension type whose storage list size
/// is >= [`MIN_DIMENSION`](crate::encodings::turboquant::TurboQuant::MIN_DIMENSION).
/// - `codes` is a non-nullable `FixedSizeListArray<u8>` with `list_size == padded_dim` and
/// `codes.len() == norms.len()`. Null vectors are represented by all-zero codes.
/// - `norms` is a primitive array whose ptype matches the element type of the Vector's storage
/// dtype. The nullability must match `dtype.nullability()`. Norms carry the validity of the
/// entire array, since null vectors have null norms.
/// - `centroids` is a non-nullable `PrimitiveArray<f32>` whose length is a power of 2 in
/// `[2, 256]` (i.e., `2^bit_width` for bit_width 1-8), or empty for degenerate arrays.
/// - `rotation_signs` has `3 * padded_dim` elements, or is empty for degenerate arrays.
/// - For degenerate (empty) arrays: all children must be empty.
/// - `dimension` is >= [`MIN_DIMENSION`](TurboQuant::MIN_DIMENSION).
/// - `bit_width` is in the range `[0, 8]`.
///
/// Violating these invariants may produce incorrect results during decompression.
pub unsafe fn new_unchecked(dimension: u32, bit_width: u8) -> Self {
Expand All @@ -103,8 +89,8 @@ impl TurboQuantData {
centroids: &ArrayRef,
rotation_signs: &ArrayRef,
) -> VortexResult<()> {
let ext = TurboQuant::validate_dtype(dtype)?;
let dimension = tensor_list_size(ext)?;
let vector_metadata = TurboQuant::validate_dtype(dtype)?;
let dimension = vector_metadata.dimensions();
let padded_dim = dimension.next_power_of_two();

// Codes must be a non-nullable FixedSizeList<u8> with list_size == padded_dim.
Expand Down Expand Up @@ -159,7 +145,7 @@ impl TurboQuantData {

// Norms dtype must match the element ptype of the Vector, with the parent's nullability.
// Norms carry the validity of the entire TurboQuant array.
let element_ptype = tensor_element_ptype(ext)?;
let element_ptype = vector_metadata.element_ptype();
let expected_norms_dtype = DType::Primitive(element_ptype, dtype.nullability());
vortex_ensure_eq!(
*norms.dtype(),
Expand Down
28 changes: 14 additions & 14 deletions vortex-tensor/src/encodings/turboquant/array/scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@ use vortex_compressor::CascadingCompressor;
use vortex_compressor::ctx::CompressorContext;
use vortex_compressor::scheme::Scheme;
use vortex_compressor::stats::ArrayAndStats;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;

use crate::encodings::turboquant::TurboQuant;
use crate::encodings::turboquant::TurboQuantConfig;
use crate::encodings::turboquant::turboquant_encode;
use crate::utils::tensor_element_ptype;
use crate::utils::tensor_list_size;

/// TurboQuant compression scheme for [`Vector`] extension types.
///
Expand Down Expand Up @@ -58,15 +57,16 @@ impl Scheme for TurboQuantScheme {
let dtype = data.array().dtype();
let len = data.array().len();

let ext = TurboQuant::validate_dtype(dtype)?;
let element_ptype = tensor_element_ptype(ext)?;
let dimension = tensor_list_size(ext)?;
let vector_metadata =
TurboQuant::validate_dtype(dtype).vortex_expect("invalid dtype for TurboQuant");
let element_ptype = vector_metadata.element_ptype();
let bit_width: u8 = element_ptype
.bit_width()
.try_into()
.vortex_expect("invalid bit width for TurboQuant");
let dimension = vector_metadata.dimensions();

Ok(estimate_compression_ratio(
element_ptype.bit_width(),
dimension,
len,
))
Ok(estimate_compression_ratio(bit_width, dimension, len))
}

fn compress(
Expand All @@ -84,7 +84,7 @@ impl Scheme for TurboQuantScheme {
}

/// Estimate the compression ratio for TurboQuant MSE encoding with the default config.
fn estimate_compression_ratio(bits_per_element: usize, dimensions: u32, num_vectors: usize) -> f64 {
fn estimate_compression_ratio(bits_per_element: u8, dimensions: u32, num_vectors: usize) -> f64 {
let config = TurboQuantConfig::default();
let padded_dim = dimensions.next_power_of_two() as usize;

Expand All @@ -99,7 +99,7 @@ fn estimate_compression_ratio(bits_per_element: usize, dimensions: u32, num_vect
+ 3 * padded_dim; // rotation signs, 1 bit each

let compressed_size_bits = compressed_bits_per_vector * num_vectors + overhead_bits;
let uncompressed_size_bits = bits_per_element * num_vectors * dimensions as usize;
let uncompressed_size_bits = bits_per_element as usize * dimensions as usize * num_vectors;
uncompressed_size_bits as f64 / compressed_size_bits as f64
}

Expand All @@ -121,7 +121,7 @@ mod tests {
#[case::f64_768d(64, 768, 1000, 5.0, 7.0)]
#[case::f16_768d(16, 768, 1000, 1.2, 2.0)]
fn compression_ratio_in_expected_range(
#[case] bits_per_element: usize,
#[case] bits_per_element: u8,
#[case] dim: u32,
#[case] num_vectors: usize,
#[case] min_ratio: f64,
Expand All @@ -142,7 +142,7 @@ mod tests {
#[case(32, 768, 10)]
#[case(64, 256, 50)]
fn ratio_always_greater_than_one(
#[case] bits_per_element: usize,
#[case] bits_per_element: u8,
#[case] dim: u32,
#[case] num_vectors: usize,
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use vortex_error::vortex_ensure_eq;
use crate::encodings::turboquant::TurboQuant;
use crate::encodings::turboquant::TurboQuantArrayExt;
use crate::encodings::turboquant::array::float_from_f32;
use crate::utils::tensor_element_ptype;
use crate::vector::AnyVector;

/// Compute the per-row unit-norm dot products in f32 (centroids are always f32).
///
Expand Down Expand Up @@ -109,7 +109,11 @@ pub fn cosine_similarity_quantized_column(
"TurboQuant quantized dot product requires matching dimensions",
);

let element_ptype = tensor_element_ptype(lhs.dtype().as_extension())?;
let element_ptype = lhs
.dtype()
.as_extension()
.metadata::<AnyVector>()
.element_ptype();
let validity = lhs.norms().validity()?.and(rhs.norms().validity()?)?;
let dots = compute_unit_dots(&lhs, &rhs, ctx)?;

Expand Down Expand Up @@ -147,7 +151,11 @@ pub fn dot_product_quantized_column(
"TurboQuant quantized dot product requires matching dimensions",
);

let element_ptype = tensor_element_ptype(lhs.dtype().as_extension())?;
let element_ptype = lhs
.dtype()
.as_extension()
.metadata::<AnyVector>()
.element_ptype();
let validity = lhs.norms().validity()?.and(rhs.norms().validity()?)?;
let dots = compute_unit_dots(&lhs, &rhs, ctx)?;
let num_rows = lhs.norms().len();
Expand Down
6 changes: 3 additions & 3 deletions vortex-tensor/src/encodings/turboquant/decompress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ use crate::encodings::turboquant::TurboQuant;
use crate::encodings::turboquant::TurboQuantArrayExt;
use crate::encodings::turboquant::array::float_from_f32;
use crate::encodings::turboquant::array::rotation::RotationMatrix;
use crate::utils::tensor_element_ptype;
use crate::vector::AnyVector;

/// Decompress a `TurboQuantArray` into a [`Vector`] extension array.
///
/// The returned array is an [`ExtensionArray`] with the original Vector dtype wrapping a
/// `FixedSizeListArray` of f32 elements.
/// `FixedSizeListArray` of the original vector element type.
///
/// [`Vector`]: crate::vector::Vector
pub fn execute_decompress(
Expand All @@ -40,7 +40,7 @@ pub fn execute_decompress(
let padded_dim = array.padded_dim() as usize;
let num_rows = array.norms().len();
let ext_dtype = array.dtype().as_extension().clone();
let element_ptype = tensor_element_ptype(&ext_dtype)?;
let element_ptype = ext_dtype.metadata::<AnyVector>().element_ptype();

if num_rows == 0 {
let fsl_validity = Validity::from(ext_dtype.storage_dtype().nullability());
Expand Down
Loading
Loading