Skip to content

Commit

Permalink
rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Dec 28, 2023
1 parent 67f920d commit 26e109c
Show file tree
Hide file tree
Showing 9 changed files with 352 additions and 25 deletions.
26 changes: 26 additions & 0 deletions crates/polars-arrow/src/array/static_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ pub trait StaticArray:
fn from_zeroable_vec(v: Vec<Self::ZeroableValueT<'_>>, dtype: ArrowDataType) -> Self {
Self::arr_from_iter_with_dtype(dtype, v)
}

fn full_null(length: usize, dtype: DataType) -> Self;
}

pub trait ParameterFreeDtypeStaticArray: StaticArray {
Expand Down Expand Up @@ -118,6 +120,10 @@ impl<T: NativeType> StaticArray for PrimitiveArray<T> {
fn from_zeroable_vec(v: Vec<Self::ZeroableValueT<'_>>, _dtype: ArrowDataType) -> Self {
PrimitiveArray::from_vec(v)
}

fn full_null(length: usize, dtype: DataType) -> Self {
Self::new_null(dtype.to_arrow(), length)
}
}

impl<T: NativeType> ParameterFreeDtypeStaticArray for PrimitiveArray<T> {
Expand Down Expand Up @@ -155,6 +161,10 @@ impl StaticArray for BooleanArray {
fn from_zeroable_vec(v: Vec<Self::ValueT<'_>>, _dtype: ArrowDataType) -> Self {
BooleanArray::from_slice(v)
}

fn full_null(length: usize, dtype: DataType) -> Self {
Self::new_null(dtype.to_arrow(), length)
}
}

impl ParameterFreeDtypeStaticArray for BooleanArray {
Expand Down Expand Up @@ -184,6 +194,10 @@ impl StaticArray for Utf8Array<i64> {
fn with_validity_typed(self, validity: Option<Bitmap>) -> Self {
self.with_validity(validity)
}

fn full_null(length: usize, dtype: DataType) -> Self {
Self::new_null(dtype.to_arrow(), length)
}
}

impl ParameterFreeDtypeStaticArray for Utf8Array<i64> {
Expand Down Expand Up @@ -213,6 +227,10 @@ impl StaticArray for BinaryArray<i64> {
fn with_validity_typed(self, validity: Option<Bitmap>) -> Self {
self.with_validity(validity)
}

fn full_null(length: usize, dtype: DataType) -> Self {
Self::new_null(dtype.to_arrow(), length)
}
}

impl ParameterFreeDtypeStaticArray for BinaryArray<i64> {
Expand Down Expand Up @@ -242,6 +260,10 @@ impl StaticArray for ListArray<i64> {
fn with_validity_typed(self, validity: Option<Bitmap>) -> Self {
self.with_validity(validity)
}

fn full_null(length: usize, dtype: DataType) -> Self {
Self::new_null(dtype.to_arrow(), length)
}
}

impl StaticArray for FixedSizeListArray {
Expand All @@ -265,4 +287,8 @@ impl StaticArray for FixedSizeListArray {
fn with_validity_typed(self, validity: Option<Bitmap>) -> Self {
self.with_validity(validity)
}

fn full_null(length: usize, dtype: DataType) -> Self {
Self::new_null(dtype.to_arrow(), length)
}
}
228 changes: 228 additions & 0 deletions crates/polars-core/src/chunked_array/ops/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@ use crate::datatypes::{ArrayCollectIterExt, ArrayFromIter, StaticArray};
use crate::prelude::{ChunkedArray, PolarsDataType};
use crate::utils::{align_chunks_binary, align_chunks_ternary};

// We need this helper because for<'a> notation can't yet be applied properly
// on the return type.
pub trait UnaryFnMut<A1>: FnMut(A1) -> Self::Ret {
type Ret;
}

impl<A1, R, T: FnMut(A1) -> R> UnaryFnMut<A1> for T {
type Ret = R;
}

// We need this helper because for<'a> notation can't yet be applied properly
// on the return type.
pub trait TernaryFnMut<A1, A2, A3>: FnMut(A1, A2, A3) -> Self::Ret {
Expand All @@ -27,6 +37,82 @@ impl<A1, A2, R, T: FnMut(A1, A2) -> R> BinaryFnMut<A1, A2> for T {
type Ret = R;
}

#[inline]
pub fn unary_elementwise<'a, T, V, F>(ca: &'a ChunkedArray<T>, mut op: F) -> ChunkedArray<V>
where
T: PolarsDataType,
V: PolarsDataType,
F: UnaryFnMut<Option<T::Physical<'a>>>,
V::Array: ArrayFromIter<<F as UnaryFnMut<Option<T::Physical<'a>>>>::Ret>,
{
let iter = ca
.downcast_iter()
.map(|arr| arr.iter().map(&mut op).collect_arr());
ChunkedArray::from_chunk_iter(ca.name(), iter)
}

#[inline]
pub fn try_unary_elementwise<'a, T, V, F, K, E>(
ca: &'a ChunkedArray<T>,
mut op: F,
) -> Result<ChunkedArray<V>, E>
where
T: PolarsDataType,
V: PolarsDataType,
F: FnMut(Option<T::Physical<'a>>) -> Result<Option<K>, E>,
V::Array: ArrayFromIter<Option<K>>,
{
let iter = ca
.downcast_iter()
.map(|arr| arr.iter().map(&mut op).try_collect_arr());
ChunkedArray::try_from_chunk_iter(ca.name(), iter)
}

#[inline]
pub fn unary_elementwise_values<'a, T, V, F>(ca: &'a ChunkedArray<T>, mut op: F) -> ChunkedArray<V>
where
T: PolarsDataType,
V: PolarsDataType,
F: UnaryFnMut<T::Physical<'a>>,
V::Array: ArrayFromIter<<F as UnaryFnMut<T::Physical<'a>>>::Ret>,
{
if ca.null_count() == ca.len() {
let arr = V::Array::full_null(ca.len(), V::get_dtype());
return ChunkedArray::with_chunk(ca.name(), arr);
}

let iter = ca.downcast_iter().map(|arr| {
let validity = arr.validity().cloned();
let arr: V::Array = arr.values_iter().map(&mut op).collect_arr();
arr.with_validity_typed(validity)
});
ChunkedArray::from_chunk_iter(ca.name(), iter)
}

#[inline]
pub fn try_unary_elementwise_values<'a, T, V, F, K, E>(
ca: &'a ChunkedArray<T>,
mut op: F,
) -> Result<ChunkedArray<V>, E>
where
T: PolarsDataType,
V: PolarsDataType,
F: FnMut(T::Physical<'a>) -> Result<K, E>,
V::Array: ArrayFromIter<K>,
{
if ca.null_count() == ca.len() {
let arr = V::Array::full_null(ca.len(), V::get_dtype());
return Ok(ChunkedArray::with_chunk(ca.name(), arr));
}

let iter = ca.downcast_iter().map(|arr| {
let validity = arr.validity().cloned();
let arr: V::Array = arr.values_iter().map(&mut op).try_collect_arr()?;
Ok(arr.with_validity_typed(validity))
});
ChunkedArray::try_from_chunk_iter(ca.name(), iter)
}

/// Applies a kernel that produces `Array` types.
///
/// Intended for kernels that apply on values, this function will apply the
Expand Down Expand Up @@ -176,6 +262,13 @@ where
F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> K,
V::Array: ArrayFromIter<K>,
{
if lhs.null_count() == lhs.len() || rhs.null_count() == rhs.len() {
let len = lhs.len().min(rhs.len());
let arr = V::Array::full_null(len, V::get_dtype());

return ChunkedArray::with_chunk(lhs.name(), arr);
}

let (lhs, rhs) = align_chunks_binary(lhs, rhs);

let iter = lhs
Expand Down Expand Up @@ -208,6 +301,13 @@ where
F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> Result<K, E>,
V::Array: ArrayFromIter<K>,
{
if lhs.null_count() == lhs.len() || rhs.null_count() == rhs.len() {
let len = lhs.len().min(rhs.len());
let arr = V::Array::full_null(len, V::get_dtype());

return Ok(ChunkedArray::with_chunk(lhs.name(), arr));
}

let (lhs, rhs) = align_chunks_binary(lhs, rhs);
let iter = lhs
.downcast_iter()
Expand Down Expand Up @@ -446,3 +546,131 @@ where
});
ChunkedArray::from_chunk_iter(ca1.name(), iter)
}

pub fn broadcast_binary_elementwise<T, U, V, F>(
lhs: &ChunkedArray<T>,
rhs: &ChunkedArray<U>,
mut op: F,
) -> ChunkedArray<V>
where
T: PolarsDataType,
U: PolarsDataType,
V: PolarsDataType,
F: for<'a> BinaryFnMut<Option<T::Physical<'a>>, Option<U::Physical<'a>>>,
V::Array: for<'a> ArrayFromIter<
<F as BinaryFnMut<Option<T::Physical<'a>>, Option<U::Physical<'a>>>>::Ret,
>,
{
match (lhs.len(), rhs.len()) {
(1, _) => {
let a = unsafe { lhs.get_unchecked(0) };
let mut out = unary_elementwise(rhs, |b| op(a.clone(), b));
out.rename(lhs.name());
out
},
(_, 1) => {
let b = unsafe { rhs.get_unchecked(0) };
unary_elementwise(lhs, |a| op(a, b.clone()))
},
_ => binary_elementwise(lhs, rhs, op),
}
}

pub fn broadcast_try_binary_elementwise<T, U, V, F, K, E>(
lhs: &ChunkedArray<T>,
rhs: &ChunkedArray<U>,
mut op: F,
) -> Result<ChunkedArray<V>, E>
where
T: PolarsDataType,
U: PolarsDataType,
V: PolarsDataType,
F: for<'a> FnMut(Option<T::Physical<'a>>, Option<U::Physical<'a>>) -> Result<Option<K>, E>,
V::Array: ArrayFromIter<Option<K>>,
{
match (lhs.len(), rhs.len()) {
(1, _) => {
let a = unsafe { lhs.get_unchecked(0) };
let mut out = try_unary_elementwise(rhs, |b| op(a.clone(), b))?;
out.rename(lhs.name());
Ok(out)
},
(_, 1) => {
let b = unsafe { rhs.get_unchecked(0) };
try_unary_elementwise(lhs, |a| op(a, b.clone()))
},
_ => try_binary_elementwise(lhs, rhs, op),
}
}

pub fn broadcast_binary_elementwise_values<T, U, V, F, K>(
lhs: &ChunkedArray<T>,
rhs: &ChunkedArray<U>,
mut op: F,
) -> ChunkedArray<V>
where
T: PolarsDataType,
U: PolarsDataType,
V: PolarsDataType,
F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> K,
V::Array: ArrayFromIter<K>,
{
if lhs.null_count() == lhs.len() || rhs.null_count() == rhs.len() {
let min = lhs.len().min(rhs.len());
let max = lhs.len().max(rhs.len());
let len = if min == 1 { max } else { min };
let arr = V::Array::full_null(len, V::get_dtype());

return ChunkedArray::with_chunk(lhs.name(), arr);
}

match (lhs.len(), rhs.len()) {
(1, _) => {
let a = unsafe { lhs.value_unchecked(0) };
let mut out = unary_elementwise_values(rhs, |b| op(a.clone(), b));
out.rename(lhs.name());
out
},
(_, 1) => {
let b = unsafe { rhs.value_unchecked(0) };
unary_elementwise_values(lhs, |a| op(a, b.clone()))
},
_ => binary_elementwise_values(lhs, rhs, op),
}
}

pub fn broadcast_try_binary_elementwise_values<T, U, V, F, K, E>(
lhs: &ChunkedArray<T>,
rhs: &ChunkedArray<U>,
mut op: F,
) -> Result<ChunkedArray<V>, E>
where
T: PolarsDataType,
U: PolarsDataType,
V: PolarsDataType,
F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> Result<K, E>,
V::Array: ArrayFromIter<K>,
{
if lhs.null_count() == lhs.len() || rhs.null_count() == rhs.len() {
let min = lhs.len().min(rhs.len());
let max = lhs.len().max(rhs.len());
let len = if min == 1 { max } else { min };
let arr = V::Array::full_null(len, V::get_dtype());

return Ok(ChunkedArray::with_chunk(lhs.name(), arr));
}

match (lhs.len(), rhs.len()) {
(1, _) => {
let a = unsafe { lhs.value_unchecked(0) };
let mut out = try_unary_elementwise_values(rhs, |b| op(a.clone(), b))?;
out.rename(lhs.name());
Ok(out)
},
(_, 1) => {
let b = unsafe { rhs.value_unchecked(0) };
try_unary_elementwise_values(lhs, |a| op(a, b.clone()))
},
_ => try_binary_elementwise_values(lhs, rhs, op),
}
}
2 changes: 1 addition & 1 deletion crates/polars-core/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub struct Flat;
///
/// The StaticArray and dtype return must be correct.
pub unsafe trait PolarsDataType: Send + Sync + Sized {
type Physical<'a>: std::fmt::Debug;
type Physical<'a>: std::fmt::Debug + Clone;
type ZeroablePhysical<'a>: Zeroable + From<Self::Physical<'a>>;
type Array: for<'a> StaticArray<
ValueT<'a> = Self::Physical<'a>,
Expand Down
5 changes: 5 additions & 0 deletions crates/polars-core/src/datatypes/static_array.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use arrow::array::Array;
use arrow::bitmap::utils::{BitmapIter, ZipValidity};
use arrow::bitmap::Bitmap;

Expand Down Expand Up @@ -25,4 +26,8 @@ impl<T: PolarsObject> StaticArray for ObjectArray<T> {
fn with_validity_typed(self, validity: Option<Bitmap>) -> Self {
self.with_validity(validity)
}

fn full_null(_length: usize, _dtype: DataType) -> Self {
panic!("ObjectArray does not support full_null");
}
}
14 changes: 11 additions & 3 deletions crates/polars-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,8 @@ where
}
}

/// # Panics
/// This will panic if `a.len() != b.len() || b.len() != c.len()` and array is chunked.
#[allow(clippy::type_complexity)]
pub fn align_chunks_ternary<'a, A, B, C>(
a: &'a ChunkedArray<A>,
Expand All @@ -706,10 +708,16 @@ where
B: PolarsDataType,
C: PolarsDataType,
{
debug_assert_eq!(a.len(), b.len());
debug_assert_eq!(b.len(), c.len());
if a.chunks.len() == 1 && b.chunks.len() == 1 && c.chunks.len() == 1 {
return (Cow::Borrowed(a), Cow::Borrowed(b), Cow::Borrowed(c));
}

assert!(
a.len() == b.len() && b.len() == c.len(),
"expected arrays of the same length"
);

match (a.chunks.len(), b.chunks.len(), c.chunks.len()) {
(1, 1, 1) => (Cow::Borrowed(a), Cow::Borrowed(b), Cow::Borrowed(c)),
(_, 1, 1) => (
Cow::Borrowed(a),
Cow::Owned(b.match_chunks(a.chunk_id())),
Expand Down
Loading

0 comments on commit 26e109c

Please sign in to comment.