Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Incorrectly preserved sorted flag when concatenating sorted series containing nulls #15082

Merged
merged 7 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions crates/polars-core/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ impl<T: PolarsDataType> ChunkedArray<T> {
self.bit_settings.contains(Settings::SORTED_DSC)
}

/// Reports whether a sorted flag is set on `self`, regardless of direction.
///
/// Returns `true` if either the ascending or the descending sorted flag is
/// set, and `false` when neither is.
pub(crate) fn is_sorted_any(&self) -> bool {
    // The ascending flag is consulted first; the descending flag is only
    // checked when the ascending one is not set.
    if self.is_sorted_ascending_flag() {
        return true;
    }
    self.is_sorted_descending_flag()
}

/// Removes `Settings::FAST_EXPLODE_LIST` from this array's `bit_settings`.
pub fn unset_fast_explode_list(&mut self) {
    let settings = &mut self.bit_settings;
    settings.remove(Settings::FAST_EXPLODE_LIST);
}
Expand Down Expand Up @@ -224,10 +229,7 @@ impl<T: PolarsDataType> ChunkedArray<T> {
None
}
// We now know there is at least 1 non-null item in the array, and self.len() > 0
else if matches!(
self.is_sorted_flag(),
IsSorted::Ascending | IsSorted::Descending
) {
else if self.is_sorted_any() {
let out = if unsafe { self.downcast_get_unchecked(0).is_null_unchecked(0) } {
// nulls are all at the start
self.null_count()
Expand All @@ -236,6 +238,12 @@ impl<T: PolarsDataType> ChunkedArray<T> {
0
};

debug_assert!(
// If we are lucky this catches something.
unsafe { self.get_unchecked(out) }.is_some(),
"incorrect sorted flag"
ritchie46 marked this conversation as resolved.
Show resolved Hide resolved
);

Some(out)
} else {
first_non_null(self.iter_validities())
Expand All @@ -248,10 +256,7 @@ impl<T: PolarsDataType> ChunkedArray<T> {
None
}
// We now know there is at least 1 non-null item in the array, and self.len() > 0
else if matches!(
self.is_sorted_flag(),
IsSorted::Ascending | IsSorted::Descending
) {
else if self.is_sorted_any() {
let out = if unsafe { self.downcast_get_unchecked(0).is_null_unchecked(0) } {
// nulls are all at the start
self.len() - 1
Expand All @@ -260,6 +265,12 @@ impl<T: PolarsDataType> ChunkedArray<T> {
self.len() - self.null_count() - 1
};

debug_assert!(
// If we are lucky this catches something.
unsafe { self.get_unchecked(out) }.is_some(),
"incorrect sorted flag"
);

Some(out)
} else {
last_non_null(self.iter_validities(), self.len())
Expand Down
95 changes: 57 additions & 38 deletions crates/polars-core/src/chunked_array/ops/append.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,51 +19,70 @@ where
T: PolarsDataType,
for<'a> T::Physical<'a>: TotalOrd,
{
// TODO: attempt to maintain sortedness better in case of nulls.
// Note: Do not call (first|last)_non_null on an array here before checking
// it is sorted, otherwise it will lead to quadratic behavior.
let sorted_flag = match (
ca.null_count() != ca.len(),
other.null_count() != other.len(),
) {
(false, false) => IsSorted::Ascending, // all null
(false, true) => {
if other.is_sorted_any() && 1 + other.last_non_null().unwrap() == other.len() {
// nulls first
other.is_sorted_flag()
} else {
IsSorted::Not
}
},
(true, false) => {
if ca.is_sorted_any() && ca.first_non_null().unwrap() == 0 {
// nulls last
ca.is_sorted_flag()
} else {
IsSorted::Not
}
},
(true, true) => {
// both arrays have non-null values
if !ca.is_sorted_any()
|| !other.is_sorted_any()
|| ca.is_sorted_flag() != other.is_sorted_flag()
{
IsSorted::Not
} else {
let l_idx = ca.last_non_null().unwrap();
let r_idx = other.first_non_null().unwrap();

// If either is empty, copy the sorted flag from the other.
if ca.is_empty() {
ca.set_sorted_flag(other.is_sorted_flag());
return;
}
if other.is_empty() {
return;
}
let l_val = unsafe { ca.value_unchecked(l_idx) };
let r_val = unsafe { other.value_unchecked(r_idx) };

// Both need to be sorted, in the same order, if the order is maintained.
// TODO: rework sorted flags, ascending and descending are not mutually
// exclusive for all-equal/all-null arrays.
let ls = ca.is_sorted_flag();
let rs = other.is_sorted_flag();
if ls != rs || ls == IsSorted::Not || rs == IsSorted::Not {
ca.set_sorted_flag(IsSorted::Not);
return;
}
let keep_sorted =
// check null positions
// lhs does not end in nulls
(1 + l_idx == ca.len())
// rhs does not start with nulls
&& (r_idx == 0)
// if there are nulls, they are all on one end
&& !(ca.first_non_null().unwrap() != 0 && 1 + other.last_non_null().unwrap() != other.len());

let keep_sorted = keep_sorted
// compare values
&& if ca.is_sorted_ascending_flag() {
l_val.tot_le(&r_val)
} else {
l_val.tot_ge(&r_val)
};

// Check the order is maintained.
let still_sorted = {
// To prevent potential quadratic append behavior we do not find
// the last non-null element in ca.
if let Some(left) = ca.last() {
if let Some(right_idx) = other.first_non_null() {
let right = other.get(right_idx).unwrap();
if ca.is_sorted_ascending_flag() {
left.tot_le(&right)
if keep_sorted {
ca.is_sorted_flag()
} else {
left.tot_ge(&right)
IsSorted::Not
}
} else {
// Right is only nulls, trivially sorted.
true
}
} else {
// Last element in left is null, pessimistically assume not sorted.
false
}
},
};
if !still_sorted {
ca.set_sorted_flag(IsSorted::Not);
}

ca.set_sorted_flag(sorted_flag);
}

impl<T> ChunkedArray<T>
Expand Down
86 changes: 86 additions & 0 deletions py-polars/tests/unit/operations/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,3 +796,89 @@ def test_sorted_flag_14552() -> None:

a = pl.concat([a, a], rechunk=False)
assert not a.join(a, on="a", how="left")["a"].flags["SORTED_ASC"]


def test_sorted_flag_concat_15072() -> None:
    """Check which sorted flag survives `pl.concat` when nulls are involved.

    Each case concatenates two series and asserts both the resulting values
    and whether the ascending / descending sorted flag is (or is not) set.
    """

    def is_asc(s: pl.Series) -> bool:
        return s.flags["SORTED_ASC"]

    def is_desc(s: pl.Series) -> bool:
        return s.flags["SORTED_DESC"]

    def is_sorted_any(s: pl.Series) -> bool:
        return s.flags["SORTED_ASC"] or s.flags["SORTED_DESC"]

    def is_not_sorted(s: pl.Series) -> bool:
        return not is_sorted_any(s)

    def all_null() -> pl.Series:
        return pl.Series("x", [None, None], dtype=pl.Int8)

    def partly_null() -> pl.Series:
        return pl.Series("x", [1, 2, 1, None], dtype=pl.Int8)

    # Both sides all-null.
    assert is_asc(pl.concat((all_null(), all_null())))

    # Left all-null, right 0 < null_count < len.
    nulls, values = all_null(), partly_null()
    for descending, nulls_last, expected, flag_ok in (
        (False, False, [None, None, None, 1, 1, 2], is_asc),
        (True, False, [None, None, None, 2, 1, 1], is_desc),
        (False, True, [None, None, 1, 1, 2, None], is_not_sorted),
        (True, True, [None, None, 2, 1, 1, None], is_not_sorted),
    ):
        rhs = values.sort(descending=descending, nulls_last=nulls_last)
        out = pl.concat((nulls, rhs))
        assert out.to_list() == expected
        assert flag_ok(out)

    # Left 0 < null_count < len, right all-null.
    values, nulls = partly_null(), all_null()
    for descending, nulls_last, expected, flag_ok in (
        (False, False, [None, 1, 1, 2, None, None], is_not_sorted),
        (True, False, [None, 2, 1, 1, None, None], is_not_sorted),
        (False, True, [1, 1, 2, None, None, None], is_asc),
        (True, True, [2, 1, 1, None, None, None], is_desc),
    ):
        lhs = values.sort(descending=descending, nulls_last=nulls_last)
        out = pl.concat((lhs, nulls))
        assert out.to_list() == expected
        assert flag_ok(out)

    # Both sides contain non-null values; the flags are set by hand, so only
    # the null placement decides whether the flag may be kept.
    for left, right, descending, flag_ok in (
        ([None, 1], [2], False, is_asc),
        ([None, 1], [2, None], False, is_not_sorted),
        ([None, 2], [1], True, is_desc),
        ([None, 2], [1, None], True, is_not_sorted),
    ):
        out = pl.concat(
            (
                pl.Series(left).set_sorted(descending=descending),
                pl.Series(right).set_sorted(descending=descending),
            )
        )
        assert flag_ok(out)
Loading