Fix byte transmutes in groupby (#4290)
* fix groupby byte transmute

* don't reinterpret bits in sorted joins
ritchie46 committed Aug 6, 2022
1 parent 20c066e commit a077528
Showing 10 changed files with 174 additions and 52 deletions.
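The root cause behind both fixes is the same: reinterpreting the bytes of signed integers as unsigned ones preserves *equality* but not *order*, because two's-complement negatives map to the top half of the unsigned range. A minimal standalone Rust sketch of this (illustrative only, not part of the commit):

```rust
// Two's-complement reinterpretation keeps values distinct but scrambles order.
fn main() {
    let signed: Vec<i32> = vec![-5, -2, 3, 9]; // sorted ascending
    let unsigned: Vec<u32> = signed.iter().map(|&x| x as u32).collect(); // bitwise reinterpret
    println!("{:?}", unsigned); // [4294967291, 4294967294, 3, 9]
    assert!(signed.windows(2).all(|w| w[0] <= w[1])); // still sorted
    assert!(!unsigned.windows(2).all(|w| w[0] <= w[1])); // no longer sorted
}
```

This is harmless for hash-based groupby, where only key identity matters, but fatal for anything relying on a sorted flag. Hence the two changes: drop the sorted flag when grouping on reinterpreted keys, and dispatch sorted joins on the real physical dtype instead of a bit reinterpretation.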
1 change: 1 addition & 0 deletions polars/polars-arrow/Cargo.toml
@@ -23,3 +23,4 @@ strings = []
compute = ["arrow/compute_cast"]
temporal = ["arrow/compute_temporal"]
bigidx = []
performant = []
1 change: 1 addition & 0 deletions polars/polars-arrow/src/kernels/mod.rs
@@ -8,6 +8,7 @@ pub mod list;
pub mod rolling;
pub mod set;
pub mod sort_partition;
#[cfg(feature = "performant")]
pub mod sorted_join;
#[cfg(feature = "strings")]
pub mod string;
2 changes: 1 addition & 1 deletion polars/polars-core/Cargo.toml
@@ -20,7 +20,7 @@ lazy = ["sort_multiple"]

# ~40% faster collect, needed until TrustedLen iterators stabilize
# more fast paths, slower compilation
performant = []
performant = ["polars-arrow/performant"]

# extra utilities for Utf8Chunked
strings = ["regex", "polars-arrow/strings", "arrow/compute_substring"]
10 changes: 10 additions & 0 deletions polars/polars-core/src/chunked_array/ops/bit_repr.rs
@@ -11,6 +11,11 @@

    fn bit_repr_large(&self) -> UInt64Chunked {
        if std::mem::size_of::<T::Native>() == 8 {
            if matches!(self.dtype(), DataType::UInt64) {
                let ca = self.clone();
                // convince the compiler we are this type;
                // this keeps the flags (e.g. sorted) intact
                return unsafe { std::mem::transmute(ca) };
            }
            let chunks = self
                .downcast_iter()
                .map(|array| {
@@ -42,6 +47,11 @@

    fn bit_repr_small(&self) -> UInt32Chunked {
        if std::mem::size_of::<T::Native>() == 4 {
            if matches!(self.dtype(), DataType::UInt32) {
                let ca = self.clone();
                // convince the compiler we are this type;
                // this keeps the flags (e.g. sorted) intact
                return unsafe { std::mem::transmute(ca) };
            }
            let chunks = self
                .downcast_iter()
                .map(|array| {
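The guarded transmute above is a cast between two names for the same concrete type; its point is to avoid rebuilding the array from raw chunks, which would drop metadata. A toy sketch of the idea, with a hypothetical wrapper type rather than the polars types:

```rust
// Toy illustration: a typed wrapper transmuted to its own type is free and
// keeps metadata flags, while rebuilding from the raw buffer would drop them.
struct Wrapper {
    values: Vec<u64>,
    sorted: bool, // metadata that should survive the "conversion"
}

fn reinterpret(w: Wrapper) -> Wrapper {
    // Same layout on both sides, so this is a no-op that preserves `sorted`.
    unsafe { std::mem::transmute::<Wrapper, Wrapper>(w) }
}

fn main() {
    let w = Wrapper { values: vec![1, 2, 3], sorted: true };
    assert!(reinterpret(w).sorted);
}
```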
2 changes: 1 addition & 1 deletion polars/polars-core/src/frame/groupby/into_groups.rs
@@ -174,7 +174,7 @@
                num_groups_proxy(&ca, multithreaded, sorted)
            }
            _ => {
                let ca = self.cast(&DataType::UInt32).unwrap();
                let ca = self.cast_unchecked(&DataType::UInt32).unwrap();
                let ca = ca.u32().unwrap();
                num_groups_proxy(ca, multithreaded, sorted)
            }
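The swap from `cast` to `cast_unchecked` matters because a numeric cast can mangle signed values that don't fit in an unsigned type, while the unchecked cast reinterprets bits. For computing group identities any injective mapping into u32 is enough; a standalone sketch (toy code, not the polars implementation):

```rust
// Groupby only needs group *identity*. A two's-complement reinterpret is
// injective even though it preserves neither order nor numeric value.
fn key(x: i32) -> u32 {
    x as u32 // bitwise reinterpret: distinct i32 -> distinct u32
}

fn main() {
    assert_eq!(key(-1), u32::MAX); // numeric value changes...
    assert_ne!(key(-1), key(1));   // ...but distinct keys stay distinct
}
```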
54 changes: 30 additions & 24 deletions polars/polars-core/src/frame/groupby/mod.rs
@@ -27,6 +27,30 @@ pub use into_groups::*;
use polars_arrow::array::ValueSize;
pub use proxy::*;

// This will remove the sorted flag on signed integers
fn prepare_dataframe_unsorted(by: &[Series]) -> DataFrame {
    DataFrame::new_no_checks(
        by.iter()
            .map(|s| match s.dtype() {
                #[cfg(feature = "dtype-categorical")]
                DataType::Categorical(_) => s.cast(&DataType::UInt32).unwrap(),
                _ => {
                    if s.dtype().to_physical().is_numeric() {
                        let s = s.to_physical_repr();
                        if s.bit_repr_is_large() {
                            s.bit_repr_large().into_series()
                        } else {
                            s.bit_repr_small().into_series()
                        }
                    } else {
                        s.clone()
                    }
                }
            })
            .collect(),
    )
}
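Bit representations are used as group keys in the first place because floats are not `Hash`/`Eq` in Rust, but their raw bits are. A standalone sketch of that idea (illustrative only, not the polars implementation):

```rust
// Group row indices by the bit pattern of an f64 key.
use std::collections::HashMap;

fn main() {
    let keys = [1.5f64, 2.25, 1.5, 2.25];
    let mut groups: HashMap<u64, Vec<usize>> = HashMap::new();
    for (row, k) in keys.iter().enumerate() {
        groups.entry(k.to_bits()).or_default().push(row);
    }
    assert_eq!(groups[&1.5f64.to_bits()], vec![0, 2]);
}
```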

impl DataFrame {
    pub fn groupby_with_series(
        &self,
@@ -81,29 +105,6 @@ impl DataFrame {
            ));
        };

        use DataType::*;
        // make sure that categorical and small integers are used as uint32 in value type
        let keys_df = DataFrame::new_no_checks(
            by.iter()
                .map(|s| match s.dtype() {
                    Int8 | UInt8 | Int16 | UInt16 => s.cast(&DataType::UInt32).unwrap(),
                    #[cfg(feature = "dtype-categorical")]
                    Categorical(_) => s.cast(&DataType::UInt32).unwrap(),
                    Float32 => s.bit_repr_small().into_series(),
                    // otherwise we use the vec hash for float
                    Float64 => s.bit_repr_large().into_series(),
                    _ => {
                        // is date like
                        if !s.dtype().is_numeric() && s.is_numeric_physical() {
                            s.to_physical_repr().into_owned()
                        } else {
                            s.clone()
                        }
                    }
                })
                .collect(),
        );

        let n_partitions = set_partition_size();

        let groups = match by.len() {
@@ -114,6 +115,8 @@
            2 => {
                // multiple keys is always multi-threaded
                // reduce code paths
                let keys_df = prepare_dataframe_unsorted(&by);

                let s0 = &keys_df.get_columns()[0];
                let s1 = &keys_df.get_columns()[1];

@@ -159,7 +162,10 @@
                    groupby_threaded_multiple_keys_flat(keys_df, n_partitions, sorted)
                }
            }
            _ => groupby_threaded_multiple_keys_flat(keys_df, n_partitions, sorted),
            _ => {
                let keys_df = prepare_dataframe_unsorted(&by);
                groupby_threaded_multiple_keys_flat(keys_df, n_partitions, sorted)
            }
        };
        Ok(GroupBy::new(self, by, groups, None))
    }
28 changes: 21 additions & 7 deletions polars/polars-core/src/frame/hash_join/mod.rs
@@ -660,7 +660,14 @@ impl DataFrame {
        check_categorical_src(s_left.dtype(), s_right.dtype())?;

        let (join_tuples_left, join_tuples_right) = if use_sort_merge(s_left, s_right) {
            #[cfg(feature = "performant")]
            {
                par_sorted_merge_inner(s_left, s_right)
            }
            #[cfg(not(feature = "performant"))]
            {
                s_left.hash_join_inner(s_right)
            }
        } else {
            s_left.hash_join_inner(s_right)
        };
@@ -824,15 +831,22 @@ impl DataFrame {
        check_categorical_src(s_left.dtype(), s_right.dtype())?;

        let ids = if use_sort_merge(s_left, s_right) {
            #[cfg(feature = "performant")]
            {
                let (left_idx, right_idx) = par_sorted_merge_left(s_left, s_right);
                #[cfg(feature = "chunked_ids")]
                {
                    (Either::Left(left_idx), Either::Left(right_idx))
                }

                #[cfg(not(feature = "chunked_ids"))]
                {
                    (left_idx, right_idx)
                }
            }
            #[cfg(not(feature = "performant"))]
            {
                s_left.hash_join_left(s_right)
            }
        } else {
            s_left.hash_join_left(s_right)
90 changes: 71 additions & 19 deletions polars/polars-core/src/frame/hash_join/sort_merge.rs
@@ -1,12 +1,20 @@
use super::*;
#[cfg(feature = "performant")]
use crate::utils::split_offsets;
#[cfg(feature = "performant")]
use polars_arrow::kernels::sorted_join;
#[cfg(feature = "performant")]
use polars_utils::flatten;

pub(super) fn use_sort_merge(s_left: &Series, s_right: &Series) -> bool {
    // only use for numeric data for now
    use IsSorted::*;
    let out = match (s_left.is_sorted(), s_right.is_sorted()) {
        (Ascending, Ascending) => s_left.null_count() == 0 && s_right.null_count() == 0,
        (Ascending, Ascending) => {
            s_left.null_count() == 0
                && s_right.null_count() == 0
                && s_left.dtype().to_physical().is_numeric()
        }
        _ => false,
    };
    if out && std::env::var("POLARS_VERBOSE").is_ok() {
@@ -15,6 +23,7 @@ pub(super) fn use_sort_merge(s_left: &Series, s_right: &Series) -> bool {
    out
}
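Once both keys are known to be sorted ascending with no nulls, the join itself is a linear two-pointer merge. A minimal sketch of the inner-join variant (illustrative; the real `sorted_join` kernel also splits the input for parallelism and has a left-join form):

```rust
type IdxSize = u32;

// Inner join on two ascending slices via a two-pointer merge.
fn sorted_inner_join(left: &[i32], right: &[i32]) -> (Vec<IdxSize>, Vec<IdxSize>) {
    let (mut l, mut r) = (0usize, 0usize);
    let (mut out_l, mut out_r) = (Vec::new(), Vec::new());
    while l < left.len() && r < right.len() {
        if left[l] < right[r] {
            l += 1;
        } else if left[l] > right[r] {
            r += 1;
        } else {
            // emit all pairs for this run of equal keys on the right
            let mut rr = r;
            while rr < right.len() && right[rr] == left[l] {
                out_l.push(l as IdxSize);
                out_r.push(rr as IdxSize);
                rr += 1;
            }
            l += 1;
        }
    }
    (out_l, out_r)
}

fn main() {
    // mirrors the data in the Python test further down
    let (l, r) = sorted_inner_join(&[-5, -2, 3, 3, 9, 10], &[-3, -2, 3, 10]);
    assert_eq!(l, vec![1, 2, 3, 5]);
    assert_eq!(r, vec![1, 2, 2, 3]);
}
```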

#[cfg(feature = "performant")]
fn par_sorted_merge_left_impl<T>(
    s_left: &ChunkedArray<T>,
    s_right: &ChunkedArray<T>,
@@ -43,23 +52,45 @@
    (flatten(&lefts, None), flatten(&rights, None))
}
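Most of this function's body is collapsed in the diff, but the new imports and the visible return line suggest a split/merge/flatten shape: `split_offsets` partitions the left keys, each partition is merged in parallel, and `flatten` concatenates the per-partition results. A hedged sketch of that pattern (toy code; `split_offsets` here is a stand-in with an assumed signature, and rayon supplies the parallel map):

```rust
use rayon::prelude::*;

// Assumed helper: (offset, length) pairs covering 0..len in n roughly equal parts.
fn split_offsets(len: usize, n: usize) -> Vec<(usize, usize)> {
    let chunk = (len + n - 1) / n;
    (0..len).step_by(chunk).map(|o| (o, chunk.min(len - o))).collect()
}

fn main() {
    let left: Vec<i32> = (0..10).collect();
    // Process each partition in parallel, then flatten in order.
    let parts: Vec<Vec<i32>> = split_offsets(left.len(), 4)
        .into_par_iter()
        .map(|(offset, len)| left[offset..offset + len].to_vec())
        .collect();
    let flat: Vec<i32> = parts.into_iter().flatten().collect();
    assert_eq!(flat, left);
}
```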

#[cfg(feature = "performant")]
pub(super) fn par_sorted_merge_left(
    s_left: &Series,
    s_right: &Series,
) -> (Vec<IdxSize>, Vec<Option<IdxSize>>) {
    // Don't use bit_repr here. It messes up sortedness.
    debug_assert_eq!(s_left.dtype(), s_right.dtype());
    if s_left.bit_repr_is_large() {
        let left = s_left.bit_repr_large();
        let right = s_right.bit_repr_large();

        par_sorted_merge_left_impl(&left, &right)
    } else {
        let left = s_left.bit_repr_small();
        let right = s_right.bit_repr_small();

        par_sorted_merge_left_impl(&left, &right)
    }
    let s_left = s_left.to_physical_repr();
    let s_right = s_right.to_physical_repr();

    match s_left.dtype() {
        #[cfg(feature = "dtype-i8")]
        DataType::Int8 => par_sorted_merge_left_impl(s_left.i8().unwrap(), s_right.i8().unwrap()),
        #[cfg(feature = "dtype-u8")]
        DataType::UInt8 => par_sorted_merge_left_impl(s_left.u8().unwrap(), s_right.u8().unwrap()),
        #[cfg(feature = "dtype-u16")]
        DataType::UInt16 => {
            par_sorted_merge_left_impl(s_left.u16().unwrap(), s_right.u16().unwrap())
        }
        #[cfg(feature = "dtype-i16")]
        DataType::Int16 => {
            par_sorted_merge_left_impl(s_left.i16().unwrap(), s_right.i16().unwrap())
        }
        DataType::UInt32 => {
            par_sorted_merge_left_impl(s_left.u32().unwrap(), s_right.u32().unwrap())
        }
        DataType::Int32 => {
            par_sorted_merge_left_impl(s_left.i32().unwrap(), s_right.i32().unwrap())
        }
        DataType::UInt64 => {
            par_sorted_merge_left_impl(s_left.u64().unwrap(), s_right.u64().unwrap())
        }
        DataType::Int64 => {
            par_sorted_merge_left_impl(s_left.i64().unwrap(), s_right.i64().unwrap())
        }
        _ => unreachable!(),
    }
}

#[cfg(feature = "performant")]
fn par_sorted_merge_inner_impl<T>(
    s_left: &ChunkedArray<T>,
    s_right: &ChunkedArray<T>,
@@ -88,20 +119,41 @@
    (flatten(&lefts, None), flatten(&rights, None))
}

#[cfg(feature = "performant")]
pub(super) fn par_sorted_merge_inner(
    s_left: &Series,
    s_right: &Series,
) -> (Vec<IdxSize>, Vec<IdxSize>) {
    // Don't use bit_repr here. It messes up sortedness.
    debug_assert_eq!(s_left.dtype(), s_right.dtype());
    if s_left.bit_repr_is_large() {
        let left = s_left.bit_repr_large();
        let right = s_right.bit_repr_large();

        par_sorted_merge_inner_impl(&left, &right)
    } else {
        let left = s_left.bit_repr_small();
        let right = s_right.bit_repr_small();

        par_sorted_merge_inner_impl(&left, &right)
    }
    let s_left = s_left.to_physical_repr();
    let s_right = s_right.to_physical_repr();

    match s_left.dtype() {
        #[cfg(feature = "dtype-i8")]
        DataType::Int8 => par_sorted_merge_inner_impl(s_left.i8().unwrap(), s_right.i8().unwrap()),
        #[cfg(feature = "dtype-u8")]
        DataType::UInt8 => par_sorted_merge_inner_impl(s_left.u8().unwrap(), s_right.u8().unwrap()),
        #[cfg(feature = "dtype-u16")]
        DataType::UInt16 => {
            par_sorted_merge_inner_impl(s_left.u16().unwrap(), s_right.u16().unwrap())
        }
        #[cfg(feature = "dtype-i16")]
        DataType::Int16 => {
            par_sorted_merge_inner_impl(s_left.i16().unwrap(), s_right.i16().unwrap())
        }
        DataType::UInt32 => {
            par_sorted_merge_inner_impl(s_left.u32().unwrap(), s_right.u32().unwrap())
        }
        DataType::Int32 => {
            par_sorted_merge_inner_impl(s_left.i32().unwrap(), s_right.i32().unwrap())
        }
        DataType::UInt64 => {
            par_sorted_merge_inner_impl(s_left.u64().unwrap(), s_right.u64().unwrap())
        }
        DataType::Int64 => {
            par_sorted_merge_inner_impl(s_left.i64().unwrap(), s_right.i64().unwrap())
        }
        _ => unreachable!(),
    }
}
16 changes: 16 additions & 0 deletions py-polars/tests/test_groupby.py
@@ -97,3 +97,19 @@ def test_groupby_rolling_negative_offset_3914() -> None:
        [14, 15],
        [15, 16],
    ]


def test_groupby_signed_transmutes() -> None:
    df = pl.DataFrame({"foo": [-1, -2, -3, -4, -5], "bar": [500, 600, 700, 800, 900]})

    for dt in [pl.Int8, pl.Int16, pl.Int32, pl.Int64]:
        df = (
            df.with_columns([pl.col("foo").cast(dt), pl.col("bar")])
            .groupby("foo", maintain_order=True)
            .agg(pl.col("bar").median())
        )

        assert df.to_dict(False) == {
            "foo": [-1, -2, -3, -4, -5],
            "bar": [500.0, 600.0, 700.0, 800.0, 900.0],
        }
22 changes: 22 additions & 0 deletions py-polars/tests/test_sort.py
@@ -179,3 +179,25 @@ def test_sort_aggregation_fast_paths() -> None:
        ]
    )
    assert out.frame_equal(expected)


def test_sorted_join_and_dtypes() -> None:
    for dt in [pl.Int8, pl.Int16, pl.Int32, pl.Int64]:
        df_a = (
            pl.DataFrame({"a": [-5, -2, 3, 3, 9, 10]})
            .with_row_count()
            .with_column(pl.col("a").cast(dt).set_sorted())
        )

        df_b = pl.DataFrame({"a": [-3, -2, 3, 10]}).with_column(
            pl.col("a").cast(dt).set_sorted()
        )

        assert df_a.join(df_b, on="a", how="inner").to_dict(False) == {
            "row_nr": [1, 2, 3, 5],
            "a": [-2, 3, 3, 10],
        }
        assert df_a.join(df_b, on="a", how="left").to_dict(False) == {
            "row_nr": [0, 1, 2, 3, 4, 5],
            "a": [-5, -2, 3, 3, 9, 10],
        }
