Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions vortex-tensor/src/encodings/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

pub mod norm;
// mod spherical;
211 changes: 211 additions & 0 deletions vortex-tensor/src/encodings/norm/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use num_traits::Float;
use vortex::array::ArrayRef;
use vortex::array::ExecutionCtx;
use vortex::array::IntoArray;
use vortex::array::ToCanonical;
use vortex::array::arrays::ExtensionArray;
use vortex::array::arrays::FixedSizeListArray;
use vortex::array::arrays::PrimitiveArray;
use vortex::array::arrays::ScalarFnArray;
use vortex::array::match_each_float_ptype;
use vortex::array::validity::Validity;
use vortex::dtype::DType;
use vortex::dtype::Nullability;
use vortex::dtype::extension::ExtDType;
use vortex::error::VortexResult;
use vortex::error::vortex_ensure;
use vortex::error::vortex_ensure_eq;
use vortex::error::vortex_err;
use vortex::extension::EmptyMetadata;
use vortex::scalar_fn::EmptyOptions;
use vortex::scalar_fn::ScalarFn;

use crate::scalar_fns::l2_norm::L2Norm;
use crate::utils::extension_element_ptype;
use crate::utils::extension_list_size;
use crate::utils::extension_storage;
use crate::utils::extract_flat_elements;
use crate::vector::Vector;

/// A normalized array that stores unit-normalized vectors alongside their original L2 norms.
///
/// Each vector in the array is divided by its L2 norm, producing a unit-normalized vector. The
/// original norms are stored separately so that the original vectors can be reconstructed.
///
/// Invariants (enforced by [`NormVectorArray::try_new`]):
/// - `vector_array` is a `Vector` extension array whose elements are floating-point.
/// - `norms` is a non-nullable primitive array of the same float type, with one entry per row.
#[derive(Debug, Clone)]
pub struct NormVectorArray {
    /// The backing vector array that has been unit normalized.
    ///
    /// The underlying elements of the vector array must be floating-point.
    pub(crate) vector_array: ArrayRef,

    /// The L2 (Frobenius) norms of each vector.
    ///
    /// This must have the same dtype as the elements of the vector array.
    pub(crate) norms: ArrayRef,
}

impl NormVectorArray {
    /// Creates a new [`NormVectorArray`] from a unit-normalized vector array and its L2 norms.
    ///
    /// The `vector_array` must be a [`Vector`] extension array with floating-point elements, and
    /// `norms` must be a non-nullable primitive array of the same float type with the same length.
    ///
    /// # Errors
    ///
    /// Returns an error if `vector_array` is not a [`Vector`] extension array, if the dtype of
    /// `norms` does not match the vector element type, or if the two arrays differ in length.
    pub fn try_new(vector_array: ArrayRef, norms: ArrayRef) -> VortexResult<Self> {
        // The dtype must first be an extension dtype before we can check that it is a Vector.
        let ext = vector_array.dtype().as_extension_opt().ok_or_else(|| {
            vortex_err!(
                "vector_array dtype must be an extension type, got {}",
                vector_array.dtype()
            )
        })?;

        vortex_ensure!(
            ext.is::<Vector>(),
            "vector_array must have the Vector extension type, got {}",
            vector_array.dtype()
        );

        let element_ptype = extension_element_ptype(ext)?;

        // Norms carry one value per row and must use the exact element float type, non-nullable.
        let expected_norms_dtype = DType::Primitive(element_ptype, Nullability::NonNullable);
        vortex_ensure_eq!(
            *norms.dtype(),
            expected_norms_dtype,
            "norms dtype must match vector element type"
        );

        // Exactly one norm per vector.
        vortex_ensure_eq!(
            vector_array.len(),
            norms.len(),
            "vector_array and norms must have the same length"
        );

        Ok(Self {
            vector_array,
            norms,
        })
    }

    /// Encodes a [`Vector`] extension array into a [`NormVectorArray`] by computing L2 norms and
    /// dividing each vector by its norm.
    ///
    /// The input must be a [`Vector`] extension array with floating-point elements. All-zero
    /// vectors are preserved as-is (their inverse norm is treated as zero by [`safe_inv_norm`],
    /// so no division by zero occurs) and their stored norm is zero.
    ///
    /// # Errors
    ///
    /// Returns an error if the input is not a [`Vector`] extension array, or if any of the
    /// underlying storage extraction or construction steps fail.
    pub fn compress(vector_array: ArrayRef) -> VortexResult<Self> {
        let ext = vector_array.dtype().as_extension_opt().ok_or_else(|| {
            vortex_err!(
                "vector_array dtype must be an extension type, got {}",
                vector_array.dtype()
            )
        })?;

        vortex_ensure!(
            ext.is::<Vector>(),
            "vector_array must have the Vector extension type, got {}",
            vector_array.dtype()
        );

        let list_size = extension_list_size(ext)?;
        let row_count = vector_array.len();

        // Compute L2 norms using the scalar function; the result is canonicalized to a
        // primitive array immediately so it can be sliced below.
        let l2_norm_fn = ScalarFn::new(L2Norm, EmptyOptions).erased();
        let norms = ScalarFnArray::try_new(l2_norm_fn, vec![vector_array.clone()], row_count)?
            .to_primitive()
            .into_array();

        // Divide each vector element by its corresponding norm. `flat` exposes the
        // row-major elements of the fixed-size-list storage.
        let storage = extension_storage(&vector_array)?;
        let flat = extract_flat_elements(&storage, list_size)?;
        let norms_prim = norms.to_canonical()?.into_primitive();

        match_each_float_ptype!(flat.ptype(), |T| {
            let norms_slice = norms_prim.as_slice::<T>();

            // Multiply by the precomputed reciprocal rather than dividing per element;
            // safe_inv_norm maps a zero norm to a zero reciprocal, keeping zero vectors zero.
            let normalized_elems: PrimitiveArray = (0..row_count)
                .flat_map(|i| {
                    let inv_norm = safe_inv_norm(norms_slice[i]);
                    flat.row::<T>(i).iter().map(move |&v| v * inv_norm)
                })
                .collect();

            let fsl = FixedSizeListArray::new(
                normalized_elems.into_array(),
                u32::try_from(list_size)?,
                Validity::NonNullable,
                row_count,
            );

            // Re-wrap the normalized storage in the Vector extension dtype before validating
            // the pair through `try_new`.
            let ext_dtype =
                ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone())?.erased();
            let normalized_vector = ExtensionArray::new(ext_dtype, fsl.into_array()).into_array();

            Self::try_new(normalized_vector, norms)
        })
    }

    /// Returns a reference to the backing vector array that has been unit normalized.
    pub fn vector_array(&self) -> &ArrayRef {
        &self.vector_array
    }

    /// Returns a reference to the L2 (Frobenius) norms of each vector.
    pub fn norms(&self) -> &ArrayRef {
        &self.norms
    }

    /// Reconstructs the original vectors by multiplying each unit-normalized vector by its L2 norm.
    ///
    /// The execution context is currently unused.
    ///
    /// # Errors
    ///
    /// Returns an error if the stored vector array does not carry a Vector extension dtype, or if
    /// any of the storage extraction or construction steps fail.
    pub fn decompress(&self, _ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
        let ext_dtype = self
            .vector_array
            .dtype()
            .as_extension_opt()
            .ok_or_else(|| {
                vortex_err!(
                    "expected Vector extension dtype, got {}",
                    self.vector_array.dtype()
                )
            })?;

        let list_size = extension_list_size(ext_dtype)?;
        let row_count = self.vector_array.len();

        let storage = extension_storage(&self.vector_array)?;
        let flat = extract_flat_elements(&storage, list_size)?;

        let norms_prim = self.norms.to_canonical()?.into_primitive();

        match_each_float_ptype!(flat.ptype(), |T| {
            let norms_slice = norms_prim.as_slice::<T>();

            // Inverse of `compress`: scale each row back up by its stored norm. Zero vectors
            // have a zero norm, so they round-trip to zero.
            let result_elems: PrimitiveArray = (0..row_count)
                .flat_map(|i| {
                    let norm = norms_slice[i];
                    flat.row::<T>(i).iter().map(move |&v| v * norm)
                })
                .collect();

            let fsl = FixedSizeListArray::new(
                result_elems.into_array(),
                u32::try_from(list_size)?,
                Validity::NonNullable,
                row_count,
            );

            let ext_dtype =
                ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone())?.erased();
            Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array())
        })
    }
}

/// Returns the reciprocal `1 / norm`, or zero when `norm` is exactly zero.
///
/// Mapping a zero norm to a zero reciprocal makes normalization a no-op for
/// zero-length or all-zero vectors instead of dividing by zero.
fn safe_inv_norm<T: Float>(norm: T) -> T {
    if norm.is_zero() {
        return T::zero();
    }
    norm.recip()
}
13 changes: 13 additions & 0 deletions vortex-tensor/src/encodings/norm/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

mod array;
pub use array::NormVectorArray;

// pub(crate) mod compute;

mod vtable;
pub use vtable::NormVector;

#[cfg(test)]
mod tests;
135 changes: 135 additions & 0 deletions vortex-tensor/src/encodings/norm/tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex::array::IntoArray;
use vortex::array::VortexSessionExecute;
use vortex::array::arrays::Extension;
use vortex::error::VortexResult;

use crate::encodings::norm::NormVectorArray;
use crate::utils::extension_list_size;
use crate::utils::extension_storage;
use crate::utils::extract_flat_elements;
use crate::utils::test_helpers::assert_close;
use crate::utils::test_helpers::vector_array;

#[test]
fn encode_unit_vectors() -> VortexResult<()> {
    // Vectors that are already unit length: compression must record a norm of
    // 1.0 per row and leave the stored vectors unchanged.
    let input = vector_array(
        3,
        &[
            1.0, 0.0, 0.0, // norm = 1.0
            0.0, 1.0, 0.0, // norm = 1.0
        ],
    )?;

    let encoded = NormVectorArray::compress(input)?;

    let norms = encoded.norms().to_canonical()?.into_primitive();
    assert_close(norms.as_slice::<f64>(), &[1.0, 1.0]);

    let normalized = encoded.vector_array();
    let ext_dtype = normalized.dtype().as_extension_opt().unwrap();
    let size = extension_list_size(ext_dtype)?;
    let storage = extension_storage(normalized)?;
    let rows = extract_flat_elements(&storage, size)?;
    assert_close(rows.row::<f64>(0), &[1.0, 0.0, 0.0]);
    assert_close(rows.row::<f64>(1), &[0.0, 1.0, 0.0]);

    Ok(())
}

#[test]
fn encode_non_unit_vectors() -> VortexResult<()> {
    let input = vector_array(
        2,
        &[
            3.0, 4.0, // norm = 5.0
            0.0, 0.0, // norm = 0.0 (zero vector)
        ],
    )?;

    let encoded = NormVectorArray::compress(input)?;

    // The 3-4-5 triangle yields a norm of 5; the zero vector reports zero.
    let norms = encoded.norms().to_canonical()?.into_primitive();
    assert_close(norms.as_slice::<f64>(), &[5.0, 0.0]);

    // The first row is scaled to unit length; the zero vector stays zero.
    let normalized = encoded.vector_array();
    let ext_dtype = normalized.dtype().as_extension_opt().unwrap();
    let size = extension_list_size(ext_dtype)?;
    let storage = extension_storage(normalized)?;
    let rows = extract_flat_elements(&storage, size)?;
    assert_close(rows.row::<f64>(0), &[3.0 / 5.0, 4.0 / 5.0]);
    assert_close(rows.row::<f64>(1), &[0.0, 0.0]);

    Ok(())
}

#[test]
fn execute_round_trip() -> VortexResult<()> {
    let elements = &[
        3.0, 4.0, // norm = 5.0
        6.0, 8.0, // norm = 10.0
    ];
    let encoded = NormVectorArray::compress(vector_array(2, elements)?)?;

    // Decompressing should rebuild the original vectors.
    let mut exec_ctx = vortex::array::LEGACY_SESSION.create_execution_ctx();
    let restored = encoded.decompress(&mut exec_ctx)?;

    // Decompression must yield a Vector extension array.
    assert!(restored.as_opt::<Extension>().is_some());

    let ext_dtype = restored.dtype().as_extension_opt().unwrap();
    let size = extension_list_size(ext_dtype)?;
    let storage = extension_storage(&restored)?;
    let rows = extract_flat_elements(&storage, size)?;
    assert_close(rows.row::<f64>(0), &[3.0, 4.0]);
    assert_close(rows.row::<f64>(1), &[6.0, 8.0]);

    Ok(())
}

#[test]
fn execute_round_trip_zero_vector() -> VortexResult<()> {
    // A single all-zero vector: compression records a zero norm, and the
    // round trip must return the zero vector rather than NaNs.
    let encoded = NormVectorArray::compress(vector_array(2, &[0.0, 0.0])?)?;

    let mut exec_ctx = vortex::array::LEGACY_SESSION.create_execution_ctx();
    let restored = encoded.decompress(&mut exec_ctx)?;

    let ext_dtype = restored.dtype().as_extension_opt().unwrap();
    let size = extension_list_size(ext_dtype)?;
    let storage = extension_storage(&restored)?;
    let rows = extract_flat_elements(&storage, size)?;
    assert_close(rows.row::<f64>(0), &[0.0, 0.0]);

    Ok(())
}

#[test]
fn scalar_at_returns_original_vector() -> VortexResult<()> {
    let input = vector_array(
        2,
        &[
            3.0, 4.0, // norm = 5.0
            6.0, 8.0, // norm = 10.0
        ],
    )?;
    let encoded = NormVectorArray::compress(input)?;

    // Point lookups through the encoded array must agree with the fully
    // decompressed result for every row.
    let mut exec_ctx = vortex::array::LEGACY_SESSION.create_execution_ctx();
    let decompressed = encoded.decompress(&mut exec_ctx)?;

    let encoded_array = encoded.into_array();
    for row in 0..2 {
        assert_eq!(encoded_array.scalar_at(row)?, decompressed.scalar_at(row)?);
    }

    Ok(())
}
Loading
Loading