Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: broadcasting of unit LHS in string operations #12737

Merged
merged 6 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions crates/polars-arrow/src/array/static_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ pub trait StaticArray:
fn from_zeroable_vec(v: Vec<Self::ZeroableValueT<'_>>, dtype: ArrowDataType) -> Self {
Self::arr_from_iter_with_dtype(dtype, v)
}

fn full_null(length: usize, dtype: ArrowDataType) -> Self;
}

pub trait ParameterFreeDtypeStaticArray: StaticArray {
Expand Down Expand Up @@ -118,6 +120,10 @@ impl<T: NativeType> StaticArray for PrimitiveArray<T> {
fn from_zeroable_vec(v: Vec<Self::ZeroableValueT<'_>>, _dtype: ArrowDataType) -> Self {
PrimitiveArray::from_vec(v)
}

fn full_null(length: usize, dtype: ArrowDataType) -> Self {
Self::new_null(dtype, length)
}
}

impl<T: NativeType> ParameterFreeDtypeStaticArray for PrimitiveArray<T> {
Expand Down Expand Up @@ -155,6 +161,10 @@ impl StaticArray for BooleanArray {
fn from_zeroable_vec(v: Vec<Self::ValueT<'_>>, _dtype: ArrowDataType) -> Self {
BooleanArray::from_slice(v)
}

fn full_null(length: usize, dtype: ArrowDataType) -> Self {
Self::new_null(dtype, length)
}
}

impl ParameterFreeDtypeStaticArray for BooleanArray {
Expand Down Expand Up @@ -184,6 +194,10 @@ impl StaticArray for Utf8Array<i64> {
fn with_validity_typed(self, validity: Option<Bitmap>) -> Self {
self.with_validity(validity)
}

fn full_null(length: usize, dtype: ArrowDataType) -> Self {
Self::new_null(dtype, length)
}
}

impl ParameterFreeDtypeStaticArray for Utf8Array<i64> {
Expand Down Expand Up @@ -213,6 +227,10 @@ impl StaticArray for BinaryArray<i64> {
fn with_validity_typed(self, validity: Option<Bitmap>) -> Self {
self.with_validity(validity)
}

fn full_null(length: usize, dtype: ArrowDataType) -> Self {
Self::new_null(dtype, length)
}
}

impl ParameterFreeDtypeStaticArray for BinaryArray<i64> {
Expand Down Expand Up @@ -242,6 +260,10 @@ impl StaticArray for ListArray<i64> {
fn with_validity_typed(self, validity: Option<Bitmap>) -> Self {
self.with_validity(validity)
}

fn full_null(length: usize, dtype: ArrowDataType) -> Self {
Self::new_null(dtype, length)
}
}

impl StaticArray for FixedSizeListArray {
Expand All @@ -265,4 +287,8 @@ impl StaticArray for FixedSizeListArray {
fn with_validity_typed(self, validity: Option<Bitmap>) -> Self {
self.with_validity(validity)
}

fn full_null(length: usize, dtype: ArrowDataType) -> Self {
Self::new_null(dtype, length)
}
}
228 changes: 228 additions & 0 deletions crates/polars-core/src/chunked_array/ops/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@ use crate::datatypes::{ArrayCollectIterExt, ArrayFromIter, StaticArray};
use crate::prelude::{ChunkedArray, PolarsDataType};
use crate::utils::{align_chunks_binary, align_chunks_ternary};

// We need this helper because for<'a> notation can't yet be applied properly
// on the return type.
pub trait UnaryFnMut<A1>: FnMut(A1) -> Self::Ret {
type Ret;
}

impl<A1, R, T: FnMut(A1) -> R> UnaryFnMut<A1> for T {
type Ret = R;
}

// We need this helper because for<'a> notation can't yet be applied properly
// on the return type.
pub trait TernaryFnMut<A1, A2, A3>: FnMut(A1, A2, A3) -> Self::Ret {
Expand All @@ -27,6 +37,82 @@ impl<A1, A2, R, T: FnMut(A1, A2) -> R> BinaryFnMut<A1, A2> for T {
type Ret = R;
}

#[inline]
pub fn unary_elementwise<'a, T, V, F>(ca: &'a ChunkedArray<T>, mut op: F) -> ChunkedArray<V>
where
T: PolarsDataType,
V: PolarsDataType,
F: UnaryFnMut<Option<T::Physical<'a>>>,
V::Array: ArrayFromIter<<F as UnaryFnMut<Option<T::Physical<'a>>>>::Ret>,
{
nameexhaustion marked this conversation as resolved.
Show resolved Hide resolved
let iter = ca
.downcast_iter()
.map(|arr| arr.iter().map(&mut op).collect_arr());
ChunkedArray::from_chunk_iter(ca.name(), iter)
}

#[inline]
pub fn try_unary_elementwise<'a, T, V, F, K, E>(
ca: &'a ChunkedArray<T>,
mut op: F,
) -> Result<ChunkedArray<V>, E>
where
T: PolarsDataType,
V: PolarsDataType,
F: FnMut(Option<T::Physical<'a>>) -> Result<Option<K>, E>,
V::Array: ArrayFromIter<Option<K>>,
{
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

Copy link
Collaborator Author

@nameexhaustion nameexhaustion Dec 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to confirm - this would mean requiring unary functions that take Option - i.e. F: FnMut(Option<T::Physical<'a>>) -> Result<Option<K>, E> to always behave as F(None) == None - as otherwise the fast-path would be incorrect.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although, we could also cheaply elide this requirement by checking the output of F(input[0]) if input.len() == input.null_count()

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm.. strictly speaking this would also require the function to be Fn instead of FnMut

let iter = ca
.downcast_iter()
.map(|arr| arr.iter().map(&mut op).try_collect_arr());
ChunkedArray::try_from_chunk_iter(ca.name(), iter)
}

#[inline]
pub fn unary_elementwise_values<'a, T, V, F>(ca: &'a ChunkedArray<T>, mut op: F) -> ChunkedArray<V>
where
T: PolarsDataType,
V: PolarsDataType,
F: UnaryFnMut<T::Physical<'a>>,
V::Array: ArrayFromIter<<F as UnaryFnMut<T::Physical<'a>>>::Ret>,
{
nameexhaustion marked this conversation as resolved.
Show resolved Hide resolved
if ca.null_count() == ca.len() {
let arr = V::Array::full_null(ca.len(), V::get_dtype().to_arrow());
return ChunkedArray::with_chunk(ca.name(), arr);
}

let iter = ca.downcast_iter().map(|arr| {
let validity = arr.validity().cloned();
let arr: V::Array = arr.values_iter().map(&mut op).collect_arr();
arr.with_validity_typed(validity)
});
ChunkedArray::from_chunk_iter(ca.name(), iter)
}

#[inline]
pub fn try_unary_elementwise_values<'a, T, V, F, K, E>(
ca: &'a ChunkedArray<T>,
mut op: F,
) -> Result<ChunkedArray<V>, E>
where
T: PolarsDataType,
V: PolarsDataType,
F: FnMut(T::Physical<'a>) -> Result<K, E>,
V::Array: ArrayFromIter<K>,
{
nameexhaustion marked this conversation as resolved.
Show resolved Hide resolved
if ca.null_count() == ca.len() {
let arr = V::Array::full_null(ca.len(), V::get_dtype().to_arrow());
return Ok(ChunkedArray::with_chunk(ca.name(), arr));
}

let iter = ca.downcast_iter().map(|arr| {
let validity = arr.validity().cloned();
let arr: V::Array = arr.values_iter().map(&mut op).try_collect_arr()?;
Ok(arr.with_validity_typed(validity))
});
ChunkedArray::try_from_chunk_iter(ca.name(), iter)
}

/// Applies a kernel that produces `Array` types.
///
/// Intended for kernels that apply on values, this function will apply the
Expand Down Expand Up @@ -176,6 +262,13 @@ where
F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> K,
V::Array: ArrayFromIter<K>,
{
if lhs.null_count() == lhs.len() || rhs.null_count() == rhs.len() {
let len = lhs.len().min(rhs.len());
let arr = V::Array::full_null(len, V::get_dtype().to_arrow());

return ChunkedArray::with_chunk(lhs.name(), arr);
}

let (lhs, rhs) = align_chunks_binary(lhs, rhs);

let iter = lhs
Expand Down Expand Up @@ -208,6 +301,13 @@ where
F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> Result<K, E>,
V::Array: ArrayFromIter<K>,
{
if lhs.null_count() == lhs.len() || rhs.null_count() == rhs.len() {
let len = lhs.len().min(rhs.len());
let arr = V::Array::full_null(len, V::get_dtype().to_arrow());

return Ok(ChunkedArray::with_chunk(lhs.name(), arr));
}

let (lhs, rhs) = align_chunks_binary(lhs, rhs);
let iter = lhs
.downcast_iter()
Expand Down Expand Up @@ -446,3 +546,131 @@ where
});
ChunkedArray::from_chunk_iter(ca1.name(), iter)
}

pub fn broadcast_binary_elementwise<T, U, V, F>(
lhs: &ChunkedArray<T>,
rhs: &ChunkedArray<U>,
mut op: F,
) -> ChunkedArray<V>
where
T: PolarsDataType,
U: PolarsDataType,
V: PolarsDataType,
F: for<'a> BinaryFnMut<Option<T::Physical<'a>>, Option<U::Physical<'a>>>,
V::Array: for<'a> ArrayFromIter<
<F as BinaryFnMut<Option<T::Physical<'a>>, Option<U::Physical<'a>>>>::Ret,
>,
{
Copy link
Collaborator

@orlp orlp Dec 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately we can only do a fast-path here if both are full-null... Makes me wonder if this is a good idea after all :( We could still handle that wherever we call them though. Nevertheless, we should do a fast-path here for that case.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To confirm - this would mean we require F: for<'a> BinaryFnMut<Option<T::Physical<'a>>, Option<U::Physical<'a>>> to always behave as F(None, None) == None

match (lhs.len(), rhs.len()) {
(1, _) => {
let a = unsafe { lhs.get_unchecked(0) };
let mut out = unary_elementwise(rhs, |b| op(a.clone(), b));
out.rename(lhs.name());
out
},
(_, 1) => {
let b = unsafe { rhs.get_unchecked(0) };
unary_elementwise(lhs, |a| op(a, b.clone()))
},
_ => binary_elementwise(lhs, rhs, op),
}
}

pub fn broadcast_try_binary_elementwise<T, U, V, F, K, E>(
lhs: &ChunkedArray<T>,
rhs: &ChunkedArray<U>,
mut op: F,
) -> Result<ChunkedArray<V>, E>
where
T: PolarsDataType,
U: PolarsDataType,
V: PolarsDataType,
F: for<'a> FnMut(Option<T::Physical<'a>>, Option<U::Physical<'a>>) -> Result<Option<K>, E>,
V::Array: ArrayFromIter<Option<K>>,
{
nameexhaustion marked this conversation as resolved.
Show resolved Hide resolved
match (lhs.len(), rhs.len()) {
(1, _) => {
let a = unsafe { lhs.get_unchecked(0) };
let mut out = try_unary_elementwise(rhs, |b| op(a.clone(), b))?;
out.rename(lhs.name());
Ok(out)
},
(_, 1) => {
let b = unsafe { rhs.get_unchecked(0) };
try_unary_elementwise(lhs, |a| op(a, b.clone()))
},
_ => try_binary_elementwise(lhs, rhs, op),
}
}

pub fn broadcast_binary_elementwise_values<T, U, V, F, K>(
lhs: &ChunkedArray<T>,
rhs: &ChunkedArray<U>,
mut op: F,
) -> ChunkedArray<V>
where
T: PolarsDataType,
U: PolarsDataType,
V: PolarsDataType,
F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> K,
V::Array: ArrayFromIter<K>,
{
if lhs.null_count() == lhs.len() || rhs.null_count() == rhs.len() {
let min = lhs.len().min(rhs.len());
let max = lhs.len().max(rhs.len());
let len = if min == 1 { max } else { min };
let arr = V::Array::full_null(len, V::get_dtype().to_arrow());

return ChunkedArray::with_chunk(lhs.name(), arr);
}

match (lhs.len(), rhs.len()) {
(1, _) => {
let a = unsafe { lhs.value_unchecked(0) };
let mut out = unary_elementwise_values(rhs, |b| op(a.clone(), b));
out.rename(lhs.name());
out
},
(_, 1) => {
let b = unsafe { rhs.value_unchecked(0) };
unary_elementwise_values(lhs, |a| op(a, b.clone()))
},
_ => binary_elementwise_values(lhs, rhs, op),
}
}

pub fn broadcast_try_binary_elementwise_values<T, U, V, F, K, E>(
lhs: &ChunkedArray<T>,
rhs: &ChunkedArray<U>,
mut op: F,
) -> Result<ChunkedArray<V>, E>
where
T: PolarsDataType,
U: PolarsDataType,
V: PolarsDataType,
F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> Result<K, E>,
V::Array: ArrayFromIter<K>,
{
if lhs.null_count() == lhs.len() || rhs.null_count() == rhs.len() {
let min = lhs.len().min(rhs.len());
let max = lhs.len().max(rhs.len());
let len = if min == 1 { max } else { min };
let arr = V::Array::full_null(len, V::get_dtype().to_arrow());

return Ok(ChunkedArray::with_chunk(lhs.name(), arr));
}

match (lhs.len(), rhs.len()) {
(1, _) => {
let a = unsafe { lhs.value_unchecked(0) };
let mut out = try_unary_elementwise_values(rhs, |b| op(a.clone(), b))?;
out.rename(lhs.name());
Ok(out)
},
(_, 1) => {
let b = unsafe { rhs.value_unchecked(0) };
try_unary_elementwise_values(lhs, |a| op(a, b.clone()))
},
_ => try_binary_elementwise_values(lhs, rhs, op),
}
}
2 changes: 1 addition & 1 deletion crates/polars-core/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub struct Flat;
///
/// The StaticArray and dtype return must be correct.
pub unsafe trait PolarsDataType: Send + Sync + Sized {
type Physical<'a>: std::fmt::Debug;
type Physical<'a>: std::fmt::Debug + Clone;
type ZeroablePhysical<'a>: Zeroable + From<Self::Physical<'a>>;
type Array: for<'a> StaticArray<
ValueT<'a> = Self::Physical<'a>,
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-core/src/datatypes/static_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,8 @@ impl<T: PolarsObject> StaticArray for ObjectArray<T> {
fn with_validity_typed(self, validity: Option<Bitmap>) -> Self {
self.with_validity(validity)
}

fn full_null(_length: usize, _dtype: ArrowDataType) -> Self {
panic!("ObjectArray does not support full_null");
}
}
14 changes: 11 additions & 3 deletions crates/polars-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,8 @@ where
}
}

/// # Panics
/// This will panic if `a.len() != b.len() || b.len() != c.len()` and array is chunked.
#[allow(clippy::type_complexity)]
pub fn align_chunks_ternary<'a, A, B, C>(
a: &'a ChunkedArray<A>,
Expand All @@ -706,10 +708,16 @@ where
B: PolarsDataType,
C: PolarsDataType,
{
debug_assert_eq!(a.len(), b.len());
debug_assert_eq!(b.len(), c.len());
if a.chunks.len() == 1 && b.chunks.len() == 1 && c.chunks.len() == 1 {
return (Cow::Borrowed(a), Cow::Borrowed(b), Cow::Borrowed(c));
}

assert!(
a.len() == b.len() && b.len() == c.len(),
"expected arrays of the same length"
);

match (a.chunks.len(), b.chunks.len(), c.chunks.len()) {
(1, 1, 1) => (Cow::Borrowed(a), Cow::Borrowed(b), Cow::Borrowed(c)),
(_, 1, 1) => (
Cow::Borrowed(a),
Cow::Owned(b.match_chunks(a.chunk_id())),
Expand Down
Loading