Skip to content

Commit

Permalink
add all-null fastpaths, fix output name
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Dec 8, 2023
1 parent 6c540b2 commit cef38ab
Showing 1 changed file with 63 additions and 19 deletions.
82 changes: 63 additions & 19 deletions crates/polars-core/src/chunked_array/ops/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ where
F: UnaryFnMut<Option<T::Physical<'a>>>,
V::Array: ArrayFromIter<<F as UnaryFnMut<Option<T::Physical<'a>>>>::Ret>,
{
if ca.null_count() == ca.len() {
let arr = StaticArray::full_null(ca.len(), V::get_dtype());
return ChunkedArray::with_chunk(ca.name(), arr);
}

let iter = ca
.downcast_iter()
.map(|arr| arr.iter().map(&mut op).collect_arr());
Expand All @@ -62,6 +67,11 @@ where
F: FnMut(Option<T::Physical<'a>>) -> Result<Option<K>, E>,
V::Array: ArrayFromIter<Option<K>>,
{
if ca.null_count() == ca.len() {
let arr = StaticArray::full_null(ca.len(), V::get_dtype());
return Ok(ChunkedArray::with_chunk(ca.name(), arr));
}

let iter = ca
.downcast_iter()
.map(|arr| arr.iter().map(&mut op).try_collect_arr());
Expand All @@ -76,6 +86,11 @@ where
F: UnaryFnMut<T::Physical<'a>>,
V::Array: ArrayFromIter<<F as UnaryFnMut<T::Physical<'a>>>::Ret>,
{
if ca.null_count() == ca.len() {
let arr = StaticArray::full_null(ca.len(), V::get_dtype());
return ChunkedArray::with_chunk(ca.name(), arr);
}

let iter = ca.downcast_iter().map(|arr| {
let validity = arr.validity().cloned();
let arr: V::Array = arr.values_iter().map(&mut op).collect_arr();
Expand All @@ -95,6 +110,11 @@ where
F: FnMut(T::Physical<'a>) -> Result<K, E>,
V::Array: ArrayFromIter<K>,
{
if ca.null_count() == ca.len() {
let arr = StaticArray::full_null(ca.len(), V::get_dtype());
return Ok(ChunkedArray::with_chunk(ca.name(), arr));
}

let iter = ca.downcast_iter().map(|arr| {
let validity = arr.validity().cloned();
let arr: V::Array = arr.values_iter().map(&mut op).try_collect_arr()?;
Expand Down Expand Up @@ -537,10 +557,21 @@ where
<F as BinaryFnMut<Option<T::Physical<'a>>, Option<U::Physical<'a>>>>::Ret,
>,
{
if lhs.null_count() == lhs.len() && rhs.null_count() == rhs.len() {
let min = lhs.len().min(rhs.len());
let max = lhs.len().max(rhs.len());
let len = if min == 1 { max } else { min };
let arr = StaticArray::full_null(len, V::get_dtype());

return ChunkedArray::with_chunk(lhs.name(), arr);
}

match (lhs.len(), rhs.len()) {
(1, _) => {
let a = unsafe { lhs.get_unchecked(0) };
unary_elementwise(rhs, |b| op(a.clone(), b))
let mut out = unary_elementwise(rhs, |b| op(a.clone(), b));
out.rename(lhs.name());
out
},
(_, 1) => {
let b = unsafe { rhs.get_unchecked(0) };
Expand All @@ -562,10 +593,21 @@ where
F: for<'a> FnMut(Option<T::Physical<'a>>, Option<U::Physical<'a>>) -> Result<Option<K>, E>,
V::Array: ArrayFromIter<Option<K>>,
{
if lhs.null_count() == lhs.len() && rhs.null_count() == rhs.len() {
let min = lhs.len().min(rhs.len());
let max = lhs.len().max(rhs.len());
let len = if min == 1 { max } else { min };
let arr = StaticArray::full_null(len, V::get_dtype());

return Ok(ChunkedArray::with_chunk(lhs.name(), arr));
}

match (lhs.len(), rhs.len()) {
(1, _) => {
let a = unsafe { lhs.get_unchecked(0) };
try_unary_elementwise(rhs, |b| op(a.clone(), b))
let mut out = try_unary_elementwise(rhs, |b| op(a.clone(), b))?;
out.rename(lhs.name());
Ok(out)
},
(_, 1) => {
let b = unsafe { rhs.get_unchecked(0) };
Expand All @@ -587,22 +629,24 @@ where
F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> K,
V::Array: ArrayFromIter<K>,
{
if unsafe {
(lhs.len() == 1 && lhs.downcast_get_unchecked(0).is_null_unchecked(0))
|| (rhs.len() == 1 && rhs.downcast_get_unchecked(0).is_null_unchecked(0))
} {
let broadcast_to = lhs.len().max(rhs.len());
let arr = StaticArray::full_null(broadcast_to, V::get_dtype());
if lhs.null_count() == lhs.len() || rhs.null_count() == rhs.len() {
let min = lhs.len().min(rhs.len());
let max = lhs.len().max(rhs.len());
let len = if min == 1 { max } else { min };
let arr = StaticArray::full_null(len, V::get_dtype());

return ChunkedArray::with_chunk(lhs.name(), arr);
}

match (lhs.len(), rhs.len()) {
(1, _) => {
let a = unsafe { lhs.downcast_get_unchecked(0).value_unchecked(0) };
unary_elementwise_values(rhs, |b| op(a.clone(), b))
let a = unsafe { lhs.value_unchecked(0) };
let mut out = unary_elementwise_values(rhs, |b| op(a.clone(), b));
out.rename(lhs.name());
out
},
(_, 1) => {
let b = unsafe { rhs.downcast_get_unchecked(0).value_unchecked(0) };
let b = unsafe { rhs.value_unchecked(0) };
unary_elementwise_values(lhs, |a| op(a, b.clone()))
},
_ => binary_elementwise_values(lhs, rhs, op),
Expand All @@ -621,22 +665,22 @@ where
F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> Result<K, E>,
V::Array: ArrayFromIter<K>,
{
if unsafe {
(lhs.len() == 1 && lhs.downcast_get_unchecked(0).is_null_unchecked(0))
|| (rhs.len() == 1 && rhs.downcast_get_unchecked(0).is_null_unchecked(0))
} {
let broadcast_to = lhs.len().max(rhs.len());
let arr = StaticArray::full_null(broadcast_to, V::get_dtype());
if lhs.null_count() == lhs.len() || rhs.null_count() == rhs.len() {
let min = lhs.len().min(rhs.len());
let max = lhs.len().max(rhs.len());
let len = if min == 1 { max } else { min };
let arr = StaticArray::full_null(len, V::get_dtype());

return Ok(ChunkedArray::with_chunk(lhs.name(), arr));
}

match (lhs.len(), rhs.len()) {
(1, _) => {
let a = unsafe { lhs.downcast_get_unchecked(0).value_unchecked(0) };
let a = unsafe { lhs.value_unchecked(0) };
try_unary_elementwise_values(rhs, |b| op(a.clone(), b))
},
(_, 1) => {
let b = unsafe { rhs.downcast_get_unchecked(0).value_unchecked(0) };
let b = unsafe { rhs.value_unchecked(0) };
try_unary_elementwise_values(lhs, |a| op(a, b.clone()))
},
_ => try_binary_elementwise_values(lhs, rhs, op),
Expand Down

0 comments on commit cef38ab

Please sign in to comment.