Skip to content

Commit

Permalink
try add and use unary functions
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Dec 1, 2023
1 parent 471beaf commit 9d4a246
Showing 1 changed file with 101 additions and 79 deletions.
180 changes: 101 additions & 79 deletions crates/polars-core/src/chunked_array/ops/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@ use crate::datatypes::{ArrayCollectIterExt, ArrayFromIter, StaticArray};
use crate::prelude::{ChunkedArray, PolarsDataType};
use crate::utils::{align_chunks_binary, align_chunks_ternary};

// We need this helper because for<'a> notation can't yet be applied properly
// on the return type.
pub trait UnaryFnMut<A1>: FnMut(A1) -> Self::Ret {
type Ret;
}

impl<A1, R, T: FnMut(A1) -> R> UnaryFnMut<A1> for T {
type Ret = R;
}

// We need this helper because for<'a> notation can't yet be applied properly
// on the return type.
pub trait TernaryFnMut<A1, A2, A3>: FnMut(A1, A2, A3) -> Self::Ret {
Expand All @@ -27,16 +37,70 @@ impl<A1, A2, R, T: FnMut(A1, A2) -> R> BinaryFnMut<A1, A2> for T {
type Ret = R;
}

impl<T: PolarsDataType> ChunkedArray<T> {
/// Get an iterator over the values of all chunks.
fn downcast_iter_values(&self) -> impl Iterator<Item = Option<T::Physical<'_>>> {
self.downcast_iter().flat_map(|arr| arr.iter())
}
#[inline]
pub fn unary_elementwise<T, V, F>(ca: &ChunkedArray<T>, mut op: F) -> ChunkedArray<V>
where
T: PolarsDataType,
V: PolarsDataType,
F: for<'a> UnaryFnMut<Option<T::Physical<'a>>>,
V::Array: for<'a> ArrayFromIter<<F as UnaryFnMut<Option<T::Physical<'a>>>>::Ret>,
{
let iter = ca
.downcast_iter()
.map(|arr| arr.iter().map(|x| op(x)).collect_arr());
ChunkedArray::from_chunk_iter(ca.name(), iter)
}

/// Get an iterator over the values of all chunks ignoring validity.
fn downcast_iter_values_unchecked(&self) -> impl Iterator<Item = T::Physical<'_>> {
self.downcast_iter().flat_map(|arr| arr.values_iter())
}
#[inline]
pub fn try_unary_elementwise<T, V, F, K, E>(
ca: &ChunkedArray<T>,
mut op: F,
) -> Result<ChunkedArray<V>, E>
where
T: PolarsDataType,
V: PolarsDataType,
F: for<'a> FnMut(Option<T::Physical<'a>>) -> Result<Option<K>, E>,
V::Array: ArrayFromIter<Option<K>>,
{
let iter = ca
.downcast_iter()
.map(|arr| arr.iter().map(|x| op(x)).try_collect_arr());
ChunkedArray::try_from_chunk_iter(ca.name(), iter)
}

#[inline]
pub fn unary_elementwise_values<T, V, F>(ca: &ChunkedArray<T>, mut op: F) -> ChunkedArray<V>
where
T: PolarsDataType,
V: PolarsDataType,
F: for<'a> UnaryFnMut<T::Physical<'a>>,
V::Array: for<'a> ArrayFromIter<<F as UnaryFnMut<T::Physical<'a>>>::Ret>,
{
let iter = ca.downcast_iter().map(|arr| {
let validity = arr.validity().map(|x| x.clone());
let arr: V::Array = arr.values_iter().map(|x| op(x)).collect_arr();
arr.with_validity_typed(validity)
});
ChunkedArray::from_chunk_iter(ca.name(), iter)
}

#[inline]
pub fn try_unary_elementwise_values<T, V, F, K, E>(
ca: &ChunkedArray<T>,
mut op: F,
) -> Result<ChunkedArray<V>, E>
where
T: PolarsDataType,
V: PolarsDataType,
F: for<'a> FnMut(T::Physical<'a>) -> Result<K, E>,
V::Array: ArrayFromIter<K>,
{
let iter = ca.downcast_iter().map(|arr| {
let validity = arr.validity().map(|x| x.clone());
let arr: V::Array = arr.values_iter().map(|x| op(x)).try_collect_arr()?;
Ok(arr.with_validity_typed(validity))
});
ChunkedArray::try_from_chunk_iter(ca.name(), iter)
}

/// Applies a kernel that produces `Array` types.
Expand Down Expand Up @@ -473,26 +537,17 @@ where
<F as BinaryFnMut<Option<T::Physical<'a>>, Option<U::Physical<'a>>>>::Ret,
>,
{
if lhs.len() == rhs.len() {
return binary_elementwise(lhs, rhs, op);
}

let mut lhs_iter = lhs.downcast_iter_values();
let mut rhs_iter = rhs.downcast_iter_values();

let chunk: V::Array = match (lhs.len(), rhs.len()) {
match (lhs.len(), rhs.len()) {
(1, _) => {
let a = lhs_iter.next().unwrap();
rhs_iter.map(|b| op(a.clone(), b)).collect_arr()
let a = unsafe { lhs.downcast_get_unchecked(0).get_unchecked(0).clone() };
unary_elementwise(rhs, |b| op(a.clone(), b))
},
(_, 1) => {
let b = rhs_iter.next().unwrap();
lhs_iter.map(|a| op(a, b.clone())).collect_arr()
let b = unsafe { rhs.downcast_get_unchecked(0).get_unchecked(0) };
unary_elementwise(lhs, |a| op(a, b.clone()))
},
_ => lhs_iter.zip(rhs_iter).map(|(a, b)| op(a, b)).collect_arr(),
};

ChunkedArray::with_chunk(lhs.name(), chunk)
_ => binary_elementwise(lhs, rhs, op),
}
}

pub fn broadcast_try_binary_elementwise<T, U, V, F, K, E>(
Expand All @@ -507,29 +562,17 @@ where
F: for<'a> FnMut(Option<T::Physical<'a>>, Option<U::Physical<'a>>) -> Result<Option<K>, E>,
V::Array: ArrayFromIter<Option<K>>,
{
if lhs.len() == rhs.len() {
return try_binary_elementwise(lhs, rhs, op);
}

let mut lhs_iter = lhs.downcast_iter_values();
let mut rhs_iter = rhs.downcast_iter_values();

let chunk: V::Array = match (lhs.len(), rhs.len()) {
match (lhs.len(), rhs.len()) {
(1, _) => {
let a = lhs_iter.next().unwrap();
rhs_iter.map(|b| op(a.clone(), b)).try_collect_arr()
let a = unsafe { lhs.downcast_get_unchecked(0).get_unchecked(0) };
try_unary_elementwise(rhs, |b| op(a.clone(), b))
},
(_, 1) => {
let b = rhs_iter.next().unwrap();
lhs_iter.map(|a| op(a, b.clone())).try_collect_arr()
let b = unsafe { rhs.downcast_get_unchecked(0).get_unchecked(0) };
try_unary_elementwise(lhs, |a| op(a, b.clone()))
},
_ => lhs_iter
.zip(rhs_iter)
.map(|(a, b)| op(a, b))
.try_collect_arr(),
}?;

Ok(ChunkedArray::with_chunk(lhs.name(), chunk))
_ => try_binary_elementwise(lhs, rhs, op),
}
}

pub fn broadcast_binary_elementwise_values<T, U, V, F, K>(
Expand All @@ -544,10 +587,6 @@ where
F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> K,
V::Array: ArrayFromIter<K>,
{
if lhs.len() == rhs.len() {
return binary_elementwise_values(lhs, rhs, op);
}

if unsafe {
(lhs.len() == 1 && lhs.downcast_get_unchecked(0).is_null_unchecked(0))
|| (rhs.len() == 1 && rhs.downcast_get_unchecked(0).is_null_unchecked(0))
Expand All @@ -559,22 +598,17 @@ where
return ChunkedArray::with_chunk(lhs.name(), arr);
}

let mut lhs_iter = lhs.downcast_iter_values_unchecked();
let mut rhs_iter = rhs.downcast_iter_values_unchecked();

let chunk: V::Array = match (lhs.len(), rhs.len()) {
match (lhs.len(), rhs.len()) {
(1, _) => {
let a = lhs_iter.next().unwrap();
rhs_iter.map(|b| op(a.clone(), b)).collect_arr()
let a = unsafe { lhs.downcast_get_unchecked(0).value_unchecked(0) };
unary_elementwise_values(rhs, |b| op(a.clone(), b))
},
(_, 1) => {
let b = rhs_iter.next().unwrap();
lhs_iter.map(|a| op(a, b.clone())).collect_arr()
let b = unsafe { rhs.downcast_get_unchecked(0).value_unchecked(0) };
unary_elementwise_values(lhs, |a| op(a, b.clone()))
},
_ => lhs_iter.zip(rhs_iter).map(|(a, b)| op(a, b)).collect_arr(),
};

ChunkedArray::with_chunk(lhs.name(), chunk)
_ => binary_elementwise_values(lhs, rhs, op),
}
}

pub fn broadcast_try_binary_elementwise_values<T, U, V, F, K, E>(
Expand All @@ -589,10 +623,6 @@ where
F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> Result<K, E>,
V::Array: ArrayFromIter<K>,
{
if lhs.len() == rhs.len() {
return try_binary_elementwise_values(lhs, rhs, op);
}

if unsafe {
(lhs.len() == 1 && lhs.downcast_get_unchecked(0).is_null_unchecked(0))
|| (rhs.len() == 1 && rhs.downcast_get_unchecked(0).is_null_unchecked(0))
Expand All @@ -604,23 +634,15 @@ where
return Ok(ChunkedArray::with_chunk(lhs.name(), arr));
}

let mut lhs_iter = lhs.downcast_iter_values_unchecked();
let mut rhs_iter = rhs.downcast_iter_values_unchecked();

let chunk: V::Array = match (lhs.len(), rhs.len()) {
match (lhs.len(), rhs.len()) {
(1, _) => {
let a = lhs_iter.next().unwrap();
rhs_iter.map(|b| op(a.clone(), b)).try_collect_arr()
let a = unsafe { lhs.downcast_get_unchecked(0).value_unchecked(0) };
try_unary_elementwise_values(rhs, |b| op(a.clone(), b))
},
(_, 1) => {
let b = rhs_iter.next().unwrap();
lhs_iter.map(|a| op(a, b.clone())).try_collect_arr()
let b = unsafe { rhs.downcast_get_unchecked(0).value_unchecked(0) };
try_unary_elementwise_values(lhs, |a| op(a, b.clone()))
},
_ => lhs_iter
.zip(rhs_iter)
.map(|(a, b)| op(a, b))
.try_collect_arr(),
}?;

Ok(ChunkedArray::with_chunk(lhs.name(), chunk))
_ => try_binary_elementwise_values(lhs, rhs, op),
}
}

0 comments on commit 9d4a246

Please sign in to comment.