Skip to content

Commit

Permalink
add and use unary_* functions
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Dec 1, 2023
1 parent fa037a1 commit 3177d8c
Showing 1 changed file with 96 additions and 62 deletions.
158 changes: 96 additions & 62 deletions crates/polars-core/src/chunked_array/ops/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@ use crate::datatypes::{ArrayCollectIterExt, ArrayFromIter, StaticArray};
use crate::prelude::{ChunkedArray, PolarsDataType};
use crate::utils::{align_chunks_binary, align_chunks_ternary};

// We need this helper because for<'a> notation can't yet be applied properly
// on the return type.
/// A one-argument `FnMut` whose return type is nameable as the associated
/// type [`UnaryFnMut::Ret`].
///
/// Bounds elsewhere in this file use it in the form
/// `<F as UnaryFnMut<A1>>::Ret` to refer to the closure's output type.
pub trait UnaryFnMut<A1>: FnMut(A1) -> Self::Ret {
/// The type produced by calling the closure with an `A1`.
type Ret;
}

/// Blanket implementation: every `FnMut(A1) -> R` is a `UnaryFnMut<A1>` with
/// `Ret = R`, so plain closures and functions satisfy the trait automatically.
impl<A1, R, T: FnMut(A1) -> R> UnaryFnMut<A1> for T {
type Ret = R;
}

// We need this helper because for<'a> notation can't yet be applied properly
// on the return type.
pub trait TernaryFnMut<A1, A2, A3>: FnMut(A1, A2, A3) -> Self::Ret {
Expand All @@ -27,16 +37,66 @@ impl<A1, A2, R, T: FnMut(A1, A2) -> R> BinaryFnMut<A1, A2> for T {
type Ret = R;
}

impl<T: PolarsDataType> ChunkedArray<T> {
/// Get an iterator over the values of all chunks.
fn downcast_iter_values(&self) -> impl Iterator<Item = Option<T::Physical<'_>>> {
self.downcast_iter().flat_map(|arr| arr.iter())
}
/// Applies `op` to every element of `ca`, nulls included, and collects the
/// results into a new [`ChunkedArray`] named after `ca`.
///
/// `op` is called with an `Option` for each slot, as yielded by the chunk's
/// `iter()`. The input is processed chunk by chunk and each chunk is
/// collected into its own output array.
#[inline]
pub fn unary_elementwise<T, V, F>(ca: &ChunkedArray<T>, mut op: F) -> ChunkedArray<V>
where
    T: PolarsDataType,
    V: PolarsDataType,
    F: for<'a> UnaryFnMut<Option<T::Physical<'a>>>,
    V::Array: for<'a> ArrayFromIter<<F as UnaryFnMut<Option<T::Physical<'a>>>>::Ret>,
{
    // Borrow `op` mutably per chunk: passing it by value would move it out of
    // the enclosing `FnMut` closure, which is called once per chunk.
    let iter = ca
        .downcast_iter()
        .map(|arr| arr.iter().map(&mut op).collect_arr());
    // The output takes its name from the only input, `ca`.
    ChunkedArray::from_chunk_iter(ca.name(), iter)
}

/// Get an iterator over the values of all chunks ignoring validity.
fn downcast_iter_values_unchecked(&self) -> impl Iterator<Item = T::Physical<'_>> {
self.downcast_iter().flat_map(|arr| arr.values_iter())
}
/// Fallible counterpart of `unary_elementwise`: applies `op` to every element
/// of `ca`, nulls included, and stops at the first error.
///
/// `op` is called with an `Option` for each slot and returns
/// `Result<Option<K>, E>`, mirroring the closure shape accepted by the
/// fallible binary helpers in this file.
///
/// # Errors
///
/// Returns the first `Err` produced by `op`.
#[inline]
pub fn try_unary_elementwise<T, V, F, K, E>(
    ca: &ChunkedArray<T>,
    mut op: F,
) -> Result<ChunkedArray<V>, E>
where
    T: PolarsDataType,
    V: PolarsDataType,
    F: for<'a> FnMut(Option<T::Physical<'a>>) -> Result<Option<K>, E>,
    V::Array: ArrayFromIter<Option<K>>,
{
    // Borrow `op` mutably per chunk: passing it by value would move it out of
    // the enclosing `FnMut` closure, which is called once per chunk.
    let iter = ca
        .downcast_iter()
        .map(|arr| arr.iter().map(&mut op).try_collect_arr());
    // The output takes its name from the only input, `ca`.
    ChunkedArray::try_from_chunk_iter(ca.name(), iter)
}

/// Applies `op` to the values of `ca` while ignoring validity, then
/// re-attaches each chunk's original validity to the corresponding output
/// chunk.
///
/// `op` is called for every slot, including those masked out as null, so it
/// must tolerate whatever value is stored there. The result is named after
/// `ca`.
#[inline]
pub fn unary_elementwise_values<T, V, F>(ca: &ChunkedArray<T>, mut op: F) -> ChunkedArray<V>
where
    T: PolarsDataType,
    V: PolarsDataType,
    F: for<'a> UnaryFnMut<T::Physical<'a>>,
    V::Array: for<'a> ArrayFromIter<<F as UnaryFnMut<T::Physical<'a>>>::Ret>,
{
    let iter = ca.downcast_iter().map(|arr| {
        // `validity()` hands out a borrowed bitmap; `cloned()` yields the
        // owned one that `with_validity_typed` expects.
        let validity = arr.validity().cloned();
        // `values_iter` yields bare values, matching what `op` accepts;
        // `iter` would yield `Option`s instead.
        let out: V::Array = arr.values_iter().map(&mut op).collect_arr();
        out.with_validity_typed(validity)
    });
    // The output takes its name from the only input, `ca`.
    ChunkedArray::from_chunk_iter(ca.name(), iter)
}

/// Fallible counterpart of `unary_elementwise_values`: applies `op` to the
/// values of `ca` while ignoring validity, stops at the first error, and
/// re-attaches each chunk's original validity to the corresponding output
/// chunk.
///
/// `op` is called for every slot, including those masked out as null, so it
/// must tolerate whatever value is stored there.
///
/// # Errors
///
/// Returns the first `Err` produced by `op`.
#[inline]
pub fn try_unary_elementwise_values<T, V, F, K, E>(
    ca: &ChunkedArray<T>,
    mut op: F,
) -> Result<ChunkedArray<V>, E>
where
    T: PolarsDataType,
    V: PolarsDataType,
    F: for<'a> FnMut(T::Physical<'a>) -> Result<K, E>,
    V::Array: ArrayFromIter<K>,
{
    let iter = ca.downcast_iter().map(|arr| {
        // `validity()` hands out a borrowed bitmap; `cloned()` yields the
        // owned one that `with_validity_typed` expects.
        let validity = arr.validity().cloned();
        // Unwrap the collected array first: the validity can only be attached
        // to an array, not to the `Result` wrapping it.
        let out: V::Array = arr.values_iter().map(&mut op).try_collect_arr()?;
        Ok(out.with_validity_typed(validity))
    });
    // The output takes its name from the only input, `ca`.
    ChunkedArray::try_from_chunk_iter(ca.name(), iter)
}

/// Applies a kernel that produces `Array` types.
Expand Down Expand Up @@ -473,23 +533,16 @@ where
<F as BinaryFnMut<Option<T::Physical<'a>>, Option<U::Physical<'a>>>>::Ret,
>,
{
if lhs.len() == rhs.len() {
return binary_elementwise(lhs, rhs, op);
}

let mut lhs_iter = lhs.downcast_iter_values();
let mut rhs_iter = rhs.downcast_iter_values();

let chunk: V::Array = match (lhs.len(), rhs.len()) {
let chunk = match (lhs.len(), rhs.len()) {
(1, _) => {
let a = lhs_iter.next().unwrap();
rhs_iter.map(|b| op(a.clone(), b)).collect_arr()
let a = unsafe { lhs.downcast_get_unchecked(0).get_unchecked(0) };
unary_elementwise(rhs, |b| op(a.clone(), b))
},
(_, 1) => {
let b = rhs_iter.next().unwrap();
lhs_iter.map(|a| op(a, b.clone())).collect_arr()
let b = unsafe { rhs.downcast_get_unchecked(0).get_unchecked(0) };
unary_elementwise(lhs, |a| op(a, b.clone()))
},
_ => lhs_iter.zip(rhs_iter).map(|(a, b)| op(a, b)).collect_arr(),
_ => binary_elementwise(lhs, rhs, op),
};

ChunkedArray::with_chunk(lhs.name(), chunk)
Expand All @@ -507,29 +560,19 @@ where
F: for<'a> FnMut(Option<T::Physical<'a>>, Option<U::Physical<'a>>) -> Result<Option<K>, E>,
V::Array: ArrayFromIter<Option<K>>,
{
if lhs.len() == rhs.len() {
return try_binary_elementwise(lhs, rhs, op);
}

let mut lhs_iter = lhs.downcast_iter_values();
let mut rhs_iter = rhs.downcast_iter_values();

let chunk: V::Array = match (lhs.len(), rhs.len()) {
let chunk = match (lhs.len(), rhs.len()) {
(1, _) => {
let a = lhs_iter.next().unwrap();
rhs_iter.map(|b| op(a.clone(), b)).try_collect_arr()
let a = unsafe { lhs.downcast_get_unchecked(0).get_unchecked(0) };
try_unary_elementwise(rhs, |b| op(a.clone(), b))
},
(_, 1) => {
let b = rhs_iter.next().unwrap();
lhs_iter.map(|a| op(a, b.clone())).try_collect_arr()
let b = unsafe { rhs.downcast_get_unchecked(0).get_unchecked(0) };
try_unary_elementwise(lhs, |a| op(a, b.clone()))
},
_ => lhs_iter
.zip(rhs_iter)
.map(|(a, b)| op(a, b))
.try_collect_arr(),
}?;
_ => try_binary_elementwise(lhs, rhs, op),
};

Ok(ChunkedArray::with_chunk(lhs.name(), chunk))
ChunkedArray::try_with_chunk(lhs.name(), chunk)
}

pub fn broadcast_binary_elementwise_values<T, U, V, F, K>(
Expand Down Expand Up @@ -559,19 +602,16 @@ where
return ChunkedArray::with_chunk(lhs.name(), arr);
}

let mut lhs_iter = lhs.downcast_iter_values_unchecked();
let mut rhs_iter = rhs.downcast_iter_values_unchecked();

let chunk: V::Array = match (lhs.len(), rhs.len()) {
let chunk = match (lhs.len(), rhs.len()) {
(1, _) => {
let a = lhs_iter.next().unwrap();
rhs_iter.map(|b| op(a.clone(), b)).collect_arr()
let a = unsafe { lhs.downcast_get_unchecked(0).get_unchecked(0) };
unary_elementwise_values(rhs, |b| op(a.clone(), b))
},
(_, 1) => {
let b = rhs_iter.next().unwrap();
lhs_iter.map(|a| op(a, b.clone())).collect_arr()
let b = unsafe { rhs.downcast_get_unchecked(0).get_unchecked(0) };
unary_elementwise_values(lhs, |a| op(a, b.clone()))
},
_ => lhs_iter.zip(rhs_iter).map(|(a, b)| op(a, b)).collect_arr(),
_ => binary_elementwise_values(lhs, rhs, op),
};

ChunkedArray::with_chunk(lhs.name(), chunk)
Expand Down Expand Up @@ -604,23 +644,17 @@ where
return Ok(ChunkedArray::with_chunk(lhs.name(), arr));
}

let mut lhs_iter = lhs.downcast_iter_values_unchecked();
let mut rhs_iter = rhs.downcast_iter_values_unchecked();

let chunk: V::Array = match (lhs.len(), rhs.len()) {
let chunk = match (lhs.len(), rhs.len()) {
(1, _) => {
let a = lhs_iter.next().unwrap();
rhs_iter.map(|b| op(a.clone(), b)).try_collect_arr()
let a = unsafe { lhs.downcast_get_unchecked(0).get_unchecked(0) };
try_unary_elementwise_values(rhs, |b| op(a.clone(), b))
},
(_, 1) => {
let b = rhs_iter.next().unwrap();
lhs_iter.map(|a| op(a, b.clone())).try_collect_arr()
let b = unsafe { rhs.downcast_get_unchecked(0).get_unchecked(0) };
try_unary_elementwise_values(lhs, |a| op(a, b.clone()))
},
_ => lhs_iter
.zip(rhs_iter)
.map(|(a, b)| op(a, b))
.try_collect_arr(),
}?;
_ => try_binary_elementwise_values(lhs, rhs, op),
};

Ok(ChunkedArray::with_chunk(lhs.name(), chunk))
}

0 comments on commit 3177d8c

Please sign in to comment.