Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 138 additions & 47 deletions vortex-array/src/arrays/primitive/compute/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use num_traits::AsPrimitive;
use num_traits::NumCast;
use vortex_buffer::Buffer;
use vortex_buffer::BufferMut;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_err;
use vortex_mask::Mask;

use crate::ArrayRef;
use crate::ExecutionCtx;
Expand All @@ -19,9 +22,13 @@ use crate::dtype::DType;
use crate::dtype::NativePType;
use crate::dtype::Nullability;
use crate::dtype::PType;
use crate::expr::stats::Precision;
use crate::expr::stats::Stat;
use crate::expr::stats::StatsProvider;
use crate::match_each_native_ptype;
use crate::scalar_fn::fns::cast::CastKernel;
use crate::scalar_fn::fns::cast::CastReduce;
use crate::validity::Validity;

impl CastReduce for Primitive {
fn cast(array: ArrayView<'_, Primitive>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
Expand Down Expand Up @@ -63,79 +70,163 @@ impl CastKernel for Primitive {
return Ok(None);
};
let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
let src_ptype = array.ptype();

// First, check that the cast is compatible with the source array's validity
let new_validity = array
.validity()?
.cast_nullability(new_nullability, array.len(), ctx)?;

// Same ptype: zero-copy, just update validity.
if array.ptype() == new_ptype {
// SAFETY: validity and data buffer still have same length
return Ok(Some(unsafe {
PrimitiveArray::new_unchecked_from_handle(
array.buffer_handle().clone(),
array.ptype(),
new_validity,
)
.into_array()
}));
// Same bit representation: either the same ptype (only the nullability changed) or two
// same-width integers (identical layout under 2's complement). The only non-trivial case
// is the sign change between same-width ints, which still needs a value-range check.
let same_rep = src_ptype == new_ptype
|| (src_ptype.is_int()
&& new_ptype.is_int()
&& src_ptype.byte_width() == new_ptype.byte_width());
if same_rep {
if !values_fit_in(array, new_ptype, ctx, true) {
vortex_bail!(
Compute: "Cannot cast {} to {} — values exceed target range",
src_ptype, new_ptype,
);
}
return Ok(Some(reinterpret(array, new_ptype, new_validity)));
}

if !values_fit_in(array, new_ptype, ctx) {
vortex_bail!(
Compute: "Cannot cast {} to {} — values exceed target range",
array.ptype(),
new_ptype,
);
}

// Same-width integers have identical bit representations due to 2's
// complement. If all values fit in the target range, reinterpret with
// no allocation.
if array.ptype().is_int()
&& new_ptype.is_int()
&& array.ptype().byte_width() == new_ptype.byte_width()
{
// SAFETY: both types are integers with the same size and alignment, and
// min/max confirm all valid values are representable in the target type.
return Ok(Some(unsafe {
PrimitiveArray::new_unchecked_from_handle(
array.buffer_handle().clone(),
new_ptype,
new_validity,
)
.into_array()
}));
}

// Otherwise, we need to cast the values one-by-one.
// Different bit rep: cast each element. `cast_values` picks a pure or checked loop based
// on whether the conversion is statically infallible.
Ok(Some(match_each_native_ptype!(new_ptype, |T| {
match_each_native_ptype!(array.ptype(), |F| {
PrimitiveArray::new(cast::<F, T>(array.as_slice()), new_validity).into_array()
match_each_native_ptype!(src_ptype, |F| {
cast_values::<F, T>(array, new_validity, ctx)?
})
})))
}
}

/// Casts every element of `array` from `F` to `T` and wraps the result in a primitive array
/// carrying `new_validity`.
///
/// Two strategies are used:
///
/// * **Fast path** - when the conversion is statically infallible (`values_always_fit`) or the
///   already-cached min/max statistics prove that every valid value fits in `T`, all lanes are
///   converted with the wrapping `as` cast in a single pass. The cached check is invoked with
///   `compute = false`, so it never triggers a statistics computation.
/// * **Checked path** - otherwise every valid lane goes through a checked `NumCast::from` and
///   the whole cast fails on the first value that does not fit in `T`. Invalid lanes are
///   replaced by zero before the checked cast, so whatever garbage they hold can never cause a
///   failure.
///
/// # Errors
///
/// Returns a `Compute` error if a valid value is not representable in `T`, and propagates any
/// error raised while reading the validity or materializing the validity mask.
fn cast_values<F, T>(
    array: ArrayView<'_, Primitive>,
    new_validity: Validity,
    ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef>
where
    F: NativePType + AsPrimitive<T>,
    T: NativePType,
{
    let values = array.as_slice::<F>();

    // Fast path: statically infallible, or cached min/max prove every valid value fits in `T`.
    // The cached check never triggers a stats computation - if the bounds aren't already known
    // we fall through to the per-lane loop below.
    if values_always_fit(F::PTYPE, T::PTYPE) || values_fit_in(array, T::PTYPE, ctx, false) {
        return Ok(PrimitiveArray::new(cast::<F, T>(values), new_validity).into_array());
    }

    // TODO(joe): if the values source and target have the same bit-width we can
    // mutate in place.

    // Checked path: valid lanes go through `NumCast::from` and the whole cast bails on the first
    // overflow; invalid lanes are substituted with zero so they always succeed.
    let mask = array.validity()?.execute_mask(array.len(), ctx)?;
    let overflow = || {
        vortex_err!(
            Compute: "Cannot cast {} to {} — value exceeds target range",
            F::PTYPE, T::PTYPE,
        )
    };
    let buffer: Buffer<T> = match &mask {
        Mask::AllTrue(_) => BufferMut::try_from_trusted_len_iter(
            values
                .iter()
                .map(|&v| <T as NumCast>::from(v).ok_or_else(overflow)),
        )?
        .freeze(),
        // Nothing is valid: the values are irrelevant, so skip the conversion entirely.
        Mask::AllFalse(_) => BufferMut::<T>::zeroed(values.len()).freeze(),
        Mask::Values(m) => BufferMut::try_from_trusted_len_iter(
            values.iter().zip(m.bit_buffer().iter()).map(|(&v, valid)| {
                // Select zero for invalid lanes instead of multiplying by a 0/1 factor:
                // for float sources `NaN * 0` and `±Inf * 0` are NaN, which the checked cast
                // rejects, so a masked-out NaN/Inf would wrongly fail the whole cast.
                let lane = if valid { v } else { F::zero() };
                <T as NumCast>::from(lane).ok_or_else(overflow)
            }),
        )?
        .freeze(),
    };

    Ok(PrimitiveArray::new(buffer, new_validity).into_array())
}

/// Converts a slice of `F` into a freshly allocated `Buffer<T>` using the primitive `as`
/// conversion for every element.
///
/// The conversion is unchecked: values that do not fit in `T` are truncated/wrapped by `as`.
/// That is acceptable for lanes that are masked out by validity.
fn cast<F: NativePType + AsPrimitive<T>, T: NativePType>(array: &[F]) -> Buffer<T> {
    let converted = array.iter().map(|&value| value.as_());
    let buffer = BufferMut::from_trusted_len_iter(converted);
    buffer.freeze()
}

/// Builds a primitive array of `new_ptype` that shares `array`'s buffer handle (the handle is
/// cloned, no element is converted) and carries `new_validity`.
fn reinterpret(
    array: ArrayView<'_, Primitive>,
    new_ptype: PType,
    new_validity: Validity,
) -> ArrayRef {
    let handle = array.buffer_handle().clone();
    // SAFETY: caller has verified the bit representation is compatible and that validity length
    // still matches the buffer length.
    let reinterpreted =
        unsafe { PrimitiveArray::new_unchecked_from_handle(handle, new_ptype, new_validity) };
    reinterpreted.into_array()
}

/// Static (type-level) check: returns `true` when any value of `src` can be converted to
/// `target` without going out of range.
///
/// The accepted conversions are:
/// * identical types;
/// * integer to a strictly wider integer, provided the sign is compatible (unsigned source, or
///   signed target);
/// * float to a strictly wider float;
/// * any integer to `f32` or `f64`.
///
/// Precision may be lost (e.g. large integers cast to `f32`), but the result is never
/// out of range.
fn values_always_fit(src: PType, target: PType) -> bool {
    if src == target {
        return true;
    }

    let widens = target.byte_width() > src.byte_width();
    match (src.is_int(), target.is_int()) {
        (true, true) => widens && (src.is_unsigned_int() || target.is_signed_int()),
        _ if src.is_float() && target.is_float() => widens,
        (src_is_int, _) => src_is_int && matches!(target, PType::F32 | PType::F64),
    }
}

/// Returns `true` if all valid values in `array` are representable as `target_ptype`.
///
/// The check proceeds from cheapest to most expensive:
///
/// 1. If the array already has `target_ptype`, the answer is trivially `true` and no statistics
///    are read or computed (this is the nullability-only cast).
/// 2. Cached min/max statistics are consulted.
/// 3. If either bound is missing and `compute` is `true`, min/max are computed with a single
///    pass. If `compute` is `false`, the function returns `false` so the caller can fall back
///    to a slower, per-value checked path.
///
/// When min/max computation reports no bounds (`Ok(None)`) there is nothing that could be out
/// of range and the result is `true`. When the computation itself fails, the result is `false`:
/// a failure proves nothing about the values, and callers rely on `true` to justify unchecked
/// reinterpretation of the buffer.
fn values_fit_in(
    array: ArrayView<'_, Primitive>,
    target_ptype: PType,
    ctx: &mut ExecutionCtx,
    compute: bool,
) -> bool {
    // Same type: every value fits by definition, so avoid a pointless min/max pass.
    if array.ptype() == target_ptype {
        return true;
    }

    let target_dtype = DType::Primitive(target_ptype, Nullability::NonNullable);
    if let Some(fits) = cached_values_fit_in(array, &target_dtype) {
        return fits;
    }
    if !compute {
        return false;
    }

    match aggregate_fn::fns::min_max::min_max(array.array(), ctx) {
        Ok(Some(mm)) => {
            mm.min.cast(&target_dtype).is_ok() && mm.max.cast(&target_dtype).is_ok()
        }
        // No bounds reported: nothing that could exceed the target range.
        Ok(None) => true,
        // Could not determine the bounds: be conservative rather than claim the values fit.
        Err(_) => false,
    }
}

/// Caller must ensure all valid values are representable via `values_fit_in`.
/// Out-of-range values at invalid positions are truncated/wrapped by `as`,
/// which is fine because they are masked out by validity.
fn cast<F: NativePType + AsPrimitive<T>, T: NativePType>(array: &[F]) -> Buffer<T> {
BufferMut::from_trusted_len_iter(array.iter().map(|&src| src.as_())).freeze()
/// Answers the "do all values fit?" question from the statistics cache alone.
///
/// Returns `None` unless both `Stat::Min` and `Stat::Max` are available as exact values.
/// Otherwise returns `Some(true)` when both bounds can be cast to `target_dtype`, and
/// `Some(false)` when either cast fails.
fn cached_values_fit_in(array: ArrayView<'_, Primitive>, target_dtype: &DType) -> Option<bool> {
    let cache = array.array().statistics();
    let lower = cache.get(Stat::Min).and_then(Precision::as_exact)?;
    let upper = cache.get(Stat::Max).and_then(Precision::as_exact)?;

    let both_fit = [lower, upper]
        .into_iter()
        .all(|bound| bound.cast(target_dtype).is_ok());
    Some(both_fit)
}

#[cfg(test)]
Expand Down
4 changes: 4 additions & 0 deletions vortex-buffer/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,10 @@ pub fn vortex_buffer::BufferMut<T>::extend_trusted<I: vortex_buffer::trusted_len

pub fn vortex_buffer::BufferMut<T>::from_trusted_len_iter<I>(I) -> Self where I: vortex_buffer::trusted_len::TrustedLen<Item = T>

pub fn vortex_buffer::BufferMut<T>::try_extend_trusted<E, I>(&mut self, I) -> core::result::Result<(), E> where I: vortex_buffer::trusted_len::TrustedLen<Item = core::result::Result<T, E>>

pub fn vortex_buffer::BufferMut<T>::try_from_trusted_len_iter<E, I>(I) -> core::result::Result<Self, E> where I: vortex_buffer::trusted_len::TrustedLen<Item = core::result::Result<T, E>>

impl<'a, T> core::iter::traits::collect::Extend<&'a T> for vortex_buffer::BufferMut<T> where T: core::marker::Copy + 'a

pub fn vortex_buffer::BufferMut<T>::extend<I: core::iter::traits::collect::IntoIterator<Item = &'a T>>(&mut self, I)
Expand Down
79 changes: 79 additions & 0 deletions vortex-buffer/src/buffer_mut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,66 @@ impl<T> BufferMut<T> {
buffer.extend_trusted(iter);
buffer
}

/// Like [`extend_trusted()`](Self::extend_trusted), but the iterator yields `Result<T, E>`
/// and the extension short-circuits on the first `Err`.
///
/// On error, items written before the failure remain in the buffer.
///
/// # Errors
///
/// Returns the first `Err` produced by `iter`; no further items are pulled from the iterator
/// after that.
///
/// # Panics
///
/// Panics (via `vortex_expect`) if the iterator's `size_hint` has no upper bound.
pub fn try_extend_trusted<E, I>(&mut self, iter: I) -> Result<(), E>
where
I: TrustedLen<Item = Result<T, E>>,
{
// Reserve room for the iterator's upper bound up front: the loop below writes through a
// raw pointer and therefore cannot grow the buffer on demand.
let (_, upper_bound) = iter.size_hint();
self.reserve(
upper_bound
.vortex_expect("`TrustedLen` iterator somehow didn't have valid upper bound"),
);

// `begin` marks the start of the spare capacity; `dst` is the write cursor that advances by
// one element per successfully written item.
let begin: *const T = self.bytes.spare_capacity_mut().as_mut_ptr().cast();
let mut dst: *mut T = begin.cast_mut();
// Holds the first error, if any, so the length can still be fixed up after the loop.
let mut result: Result<(), E> = Ok(());

for item in iter {
match item {
Ok(value) => {
// SAFETY: We reserved enough capacity to hold this item, and `dst` is a
// pointer derived from a valid reference to byte data.
unsafe { dst.write(value) };
// SAFETY: The offset fits in `isize` because we reserved that much capacity.
unsafe { dst = dst.add(1) };
}
Err(e) => {
// Stop at the first failure; the length update below still accounts for the
// items written so far.
result = Err(e);
break;
}
}
}

// SAFETY: `dst` was derived from `begin`, both valid references to byte data, and
// `dst >= begin` since the only operation on `dst` is `add`.
let items_written = unsafe { dst.offset_from_unsigned(begin) };
let length = self.len() + items_written;
// SAFETY: We have written valid items between the old length and the new length.
unsafe { self.set_len(length) };

result
}

/// Fallible counterpart of [`from_trusted_len_iter()`](Self::from_trusted_len_iter): builds a
/// new buffer from an iterator of `Result<T, E>`, stopping at the first `Err`.
///
/// The buffer is allocated with the iterator's upper-bound length. On failure the error is
/// returned and the partially filled buffer is dropped.
pub fn try_from_trusted_len_iter<E, I>(iter: I) -> Result<Self, E>
where
    I: TrustedLen<Item = Result<T, E>>,
{
    let capacity = iter
        .size_hint()
        .1
        .vortex_expect("`TrustedLen` iterator somehow didn't have valid upper bound");

    let mut filled = Self::with_capacity(capacity);
    match filled.try_extend_trusted(iter) {
        Ok(()) => Ok(filled),
        Err(error) => Err(error),
    }
}
}

impl<T> Extend<T> for BufferMut<T> {
Expand Down Expand Up @@ -814,6 +874,25 @@ mod test {
assert_eq!(buf.as_slice(), &[0, 10, 20, 30]);
}

#[test]
fn try_from_trusted_len_iter_ok() {
    // Every item is `Ok`, so the buffer must contain all of them in order.
    let source = [0, 10, 20, 30];
    let items = source.iter().map(|&v| Ok::<_, ()>(v));
    let buf = BufferMut::<i32>::try_from_trusted_len_iter(items).unwrap();
    assert_eq!(buf.as_slice(), &[0, 10, 20, 30]);
}

#[test]
fn try_from_trusted_len_iter_err() {
    // The third item fails, so construction must surface that error.
    let source = [0, 10, 20, 30];
    let items = source.iter().map(|&v| match v {
        20 => Err("bad"),
        other => Ok(other),
    });
    let result: Result<BufferMut<i32>, &'static str> =
        BufferMut::try_from_trusted_len_iter(items);
    assert_eq!(result.err(), Some("bad"));
}

#[test]
fn extend() {
let mut buf = BufferMut::empty();
Expand Down
Loading