Skip to content

Commit

Permalink
fix[rust]: unset sorted flag on mutation (#4943)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 23, 2022
1 parent 52824d9 commit 43c9fb1
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 1 deletion.
2 changes: 2 additions & 0 deletions polars/polars-core/src/chunked_array/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use arrow::compute::arity_assign;
use num::{Num, NumCast, ToPrimitive};

use crate::prelude::*;
use crate::series::IsSorted;
use crate::utils::{align_chunks_binary, align_chunks_binary_owned};

macro_rules! apply_operand_on_chunkedarray_by_iter {
Expand Down Expand Up @@ -117,6 +118,7 @@ where
.zip(rhs.downcast_iter_mut())
.for_each(|(lhs, rhs)| kernel(lhs, rhs));
}
lhs.set_sorted2(IsSorted::Not);
lhs
}
// broadcast right path
Expand Down
3 changes: 3 additions & 0 deletions polars/polars-core/src/chunked_array/ops/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use polars_arrow::array::PolarsArray;
use polars_arrow::trusted_len::PushUnchecked;

use crate::prelude::*;
use crate::series::IsSorted;
use crate::utils::{CustomIterTools, NoNull};

macro_rules! try_apply {
Expand Down Expand Up @@ -118,6 +119,8 @@ impl<T: PolarsNumericType> ChunkedArray<T> {
self.downcast_iter_mut()
.for_each(|arr| arrow::compute::arity_assign::unary(arr, f))
};
// can be in any order now
self.set_sorted2(IsSorted::Not);
}
}

Expand Down
4 changes: 3 additions & 1 deletion polars/polars-core/src/chunked_array/ops/downcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ where
}

/// # Safety
/// The caller must ensure the length remains correct.
/// The caller must ensure:
/// * the length remains correct.
/// * the flags (sorted, etc) are correct.
pub unsafe fn downcast_iter_mut(
&mut self,
) -> impl Iterator<Item = &mut PrimitiveArray<T::Native>> + DoubleEndedIterator {
Expand Down
3 changes: 3 additions & 0 deletions polars/polars-ops/src/chunked_array/set.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use polars_arrow::export::arrow::array::PrimitiveArray;
use polars_core::export::arrow::array::Array;
use polars_core::prelude::*;
use polars_core::series::IsSorted;
use polars_core::utils::arrow::bitmap::MutableBitmap;
use polars_core::utils::arrow::types::NativeType;

Expand Down Expand Up @@ -127,6 +128,8 @@ where

// safety:
// we will not modify the length
// and we unset the sorted flag.
ca.set_sorted2(IsSorted::Not);
let arr = unsafe { ca.downcast_iter_mut() }.next().unwrap();
let len = arr.len();

Expand Down
21 changes: 21 additions & 0 deletions py-polars/tests/unit/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,24 @@ def test_top_k() -> None:

assert s.top_k(3).to_list() == [8, 5, 3]
assert s.top_k(4, reverse=True).to_list() == [1, 2, 3, 5]


def test_sorted_flag_unset_by_arithmetic_4937() -> None:
df = pl.DataFrame(
{
"ts": [1, 1, 1, 0, 1],
"price": [3.3, 3.0, 3.5, 3.6, 3.7],
"mask": [1, 1, 1, 1, 0],
}
)

assert df.sort("price").groupby("ts").agg(
[
(pl.col("price") * pl.col("mask")).max().alias("pmax"),
(pl.col("price") * pl.col("mask")).min().alias("pmin"),
]
).sort("ts").to_dict(False) == {
"ts": [0, 1],
"pmax": [3.6, 3.5],
"pmin": [3.6, 0.0],
}

0 comments on commit 43c9fb1

Please sign in to comment.