Skip to content

Commit

Permalink
fix combine validities
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Nov 29, 2023
1 parent 65ca198 commit ba498ba
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 31 deletions.
17 changes: 17 additions & 0 deletions crates/polars-core/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ use arrow::legacy::kernels::concatenate::concatenate_owned_unchecked;
use arrow::legacy::prelude::*;
use bitflags::bitflags;

use self::upstream_traits::finish_validities;
use crate::series::IsSorted;
use crate::utils::{first_non_null, last_non_null, CustomIterTools};

Expand Down Expand Up @@ -235,6 +236,22 @@ impl<T: PolarsDataType> ChunkedArray<T> {
self.chunks.iter().map(to_validity)
}

/// Build one [`Bitmap`] describing the validity of every chunk, leaving the
/// chunked values themselves untouched (no rechunk of the data is performed).
///
/// Returns `None` when no chunk carries a validity bitmap. Otherwise each
/// chunk's optional validity is paired with that chunk's length and the pairs
/// are handed to `finish_validities` together with the total array length.
pub fn rechunked_validity(&self) -> Option<Bitmap> {
    // Nothing to combine when no chunk has a validity bitmap.
    if !self.has_validity() {
        return None;
    }

    // One `(validity, chunk length)` entry per chunk, in chunk order.
    let per_chunk: Vec<(Option<Bitmap>, usize)> = self
        .iter_validities()
        .zip(self.chunk_id())
        .map(|(validity, chunk_len)| (validity.cloned(), chunk_len))
        .collect();

    finish_validities(per_chunk, self.len())
}

#[inline]
/// Return whether any of the chunks in this [`ChunkedArray`] have a validity bitmap.
/// No bitmap means no null values.
Expand Down
46 changes: 42 additions & 4 deletions crates/polars-core/src/chunked_array/ops/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -535,12 +535,14 @@ where
}

if unsafe {
(lhs.len() == 1 && !lhs.downcast_get_unchecked(0).is_valid(0))
|| (rhs.len() == 1 && !rhs.downcast_get_unchecked(0).is_valid(0))
(lhs.len() == 1 && lhs.downcast_get_unchecked(0).is_null_unchecked(0))
|| (rhs.len() == 1 && rhs.downcast_get_unchecked(0).is_null_unchecked(0))
} {
let broadcast_to = lhs.len().max(rhs.len());
let arr = &*new_null_array(V::get_dtype().to_arrow(), broadcast_to);
debug_assert!(arr.is_null(0), "[dyn] expected null");
let arr = unsafe { std::ptr::read(arr as *const dyn Array as *const V::Array) };
debug_assert!(arr.is_null(0), "[static] expected null");
return ChunkedArray::with_chunk(lhs.name(), arr);
}

Expand All @@ -564,6 +566,23 @@ where
let mut b = ca_to_iter!(rhs);
let out: V::Array = broadcast_apply!(collect_func, a, b);

let out_len = out.len();

let out = out.with_validity_typed(combine_validities_and(
if lhs.len() >= out_len {
lhs.slice(0, out_len).rechunked_validity()
} else {
None
}
.as_ref(),
if rhs.len() >= out_len {
rhs.slice(0, out_len).rechunked_validity()
} else {
None
}
.as_ref(),
));

ChunkedArray::with_chunk(lhs.name(), out)
}

Expand All @@ -584,12 +603,14 @@ where
}

if unsafe {
(lhs.len() == 1 && !lhs.downcast_get_unchecked(0).is_valid(0))
|| (rhs.len() == 1 && !rhs.downcast_get_unchecked(0).is_valid(0))
(lhs.len() == 1 && lhs.downcast_get_unchecked(0).is_null_unchecked(0))
|| (rhs.len() == 1 && rhs.downcast_get_unchecked(0).is_null_unchecked(0))
} {
let broadcast_to = lhs.len().max(rhs.len());
let arr = &*new_null_array(V::get_dtype().to_arrow(), broadcast_to);
debug_assert!(arr.is_null(0), "[dyn] expected null");
let arr = unsafe { std::ptr::read(arr as *const dyn Array as *const V::Array) };
debug_assert!(arr.is_null(0), "[static] expected null");
return Ok(ChunkedArray::with_chunk(lhs.name(), arr));
}

Expand All @@ -613,5 +634,22 @@ where
let mut b = ca_to_iter!(rhs);
let out: V::Array = broadcast_apply!(collect_func, a, b)?;

let out_len = out.len();

let out = out.with_validity_typed(combine_validities_and(
if lhs.len() >= out_len {
lhs.slice(0, out_len).rechunked_validity()
} else {
None
}
.as_ref(),
if rhs.len() >= out_len {
rhs.slice(0, out_len).rechunked_validity()
} else {
None
}
.as_ref(),
));

Ok(ChunkedArray::with_chunk(lhs.name(), out))
}
5 changes: 4 additions & 1 deletion crates/polars-core/src/chunked_array/upstream_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,10 @@ where
}
}

fn finish_validities(validities: Vec<(Option<Bitmap>, usize)>, capacity: usize) -> Option<Bitmap> {
pub fn finish_validities(
validities: Vec<(Option<Bitmap>, usize)>,
capacity: usize,
) -> Option<Bitmap> {
if validities.iter().any(|(v, _)| v.is_some()) {
let mut bitmap = MutableBitmap::with_capacity(capacity);
for (valids, len) in validities {
Expand Down
73 changes: 47 additions & 26 deletions py-polars/tests/unit/test_arity.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from __future__ import annotations

import pytest

import polars as pl
from polars.testing import assert_frame_equal

Expand Down Expand Up @@ -28,33 +32,50 @@ def test_when_then_broadcast_nulls_12665() -> None:
).to_dict(as_series=False) == {"when": [0, 0, 0, 1]}


def test_broadcast_string_ops_12632() -> None:
df = pl.DataFrame(
[
{"name": "COMPANY A", "id": 1},
{"name": "COMPANY B", "id": 2},
{"name": "COMPANY C", "id": 3},
]
@pytest.mark.parametrize(
("needs_broadcast", "expect_contains"),
[
(pl.lit("a"), [True, False, False]),
(pl.col("name").head(1), [True, False, False]),
(pl.lit(None, dtype=pl.Utf8), [None, None, None]),
(pl.col("null_utf8").head(1), [None, None, None]),
],
)
@pytest.mark.parametrize("literal", [True, False])
@pytest.mark.parametrize(
"df",
[
pl.DataFrame(
{
"name": ["a", "b", "c"],
"null_utf8": pl.Series([None, None, None], dtype=pl.Utf8),
}
)
],
)
def test_broadcast_string_ops_12632(
df: pl.DataFrame,
needs_broadcast: pl.Expr,
expect_contains: list[bool],
literal: bool,
) -> None:
assert (
df.select(needs_broadcast.str.contains(pl.col("name"), literal=literal))
.to_series()
.to_list()
== expect_contains
)

for needs_broadcast in (pl.lit("COMPANY A"), pl.col("name").head(1)):
for literal in (True, False):
assert df.select(
needs_broadcast.str.contains(pl.col("name"), literal=literal)
).to_series().to_list() == [True, False, False]

assert df.select(
needs_broadcast.str.starts_with(pl.col("name"))
).to_series().to_list() == [True, False, False]
assert (
df.select(needs_broadcast.str.starts_with(pl.col("name"))).to_series().to_list()
== expect_contains
)

assert df.select(
needs_broadcast.str.ends_with(pl.col("name"))
).to_series().to_list() == [True, False, False]
assert (
df.select(needs_broadcast.str.ends_with(pl.col("name"))).to_series().to_list()
== expect_contains
)

assert df.select(needs_broadcast.str.strip_chars(pl.col("name"))).height == 3
assert (
df.select(needs_broadcast.str.strip_chars_start(pl.col("name"))).height == 3
)
assert (
df.select(needs_broadcast.str.strip_chars_end(pl.col("name"))).height == 3
)
assert df.select(needs_broadcast.str.strip_chars(pl.col("name"))).height == 3
assert df.select(needs_broadcast.str.strip_chars_start(pl.col("name"))).height == 3
assert df.select(needs_broadcast.str.strip_chars_end(pl.col("name"))).height == 3

0 comments on commit ba498ba

Please sign in to comment.