Skip to content

Commit

Permalink
Specialize groupby (#4292)
Browse files Browse the repository at this point in the history
* specialize groupby of integers < 32 bits

* expression to join on argument
  • Loading branch information
ritchie46 committed Aug 6, 2022
1 parent a077528 commit 517e945
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 6 deletions.
2 changes: 1 addition & 1 deletion polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ avro = ["polars-io", "polars-io/avro"]
csv-file = ["polars-io", "polars-io/csv-file", "polars-lazy/csv-file"]

# slower builds
performant = ["polars-core/performant", "chunked_ids"]
performant = ["polars-core/performant", "chunked_ids", "dtype-u8", "dtype-u16"]

# Dataframe formatting.
fmt = ["polars-core/fmt"]
Expand Down
60 changes: 60 additions & 0 deletions polars/polars-core/src/chunked_array/ops/bit_repr.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,66 @@
use crate::prelude::*;
use arrow::buffer::Buffer;

#[cfg(feature = "performant")]
impl Int16Chunked {
    /// Reinterpret the underlying `i16` value buffers as `u16` without copying.
    ///
    /// Each chunk's value `Buffer<i16>` is transmuted to a `Buffer<u16>` and
    /// rewrapped in a new `PrimitiveArray` tagged `UInt16`; the validity (null)
    /// bitmap of each chunk is cloned and carried over unchanged. Used so the
    /// specialized small-integer code paths only need unsigned kernels.
    pub(crate) fn reinterpret_unsigned(&self) -> UInt16Chunked {
        let chunks = self
            .downcast_iter()
            .map(|array| {
                let buf = array.values().clone();
                // Safety
                // same bit length i16 <-> u16
                // The fields can still be reordered between generic types
                // so we do some extra assertions
                let len = buf.len();
                let offset = buf.offset();
                let ptr = buf.as_slice().as_ptr() as usize;
                #[allow(clippy::transmute_undefined_repr)]
                let reinterpreted_buf = unsafe { std::mem::transmute::<_, Buffer<u16>>(buf) };
                // Debug-only sanity checks: the transmuted buffer must alias the
                // exact same memory region with identical length and offset.
                debug_assert_eq!(reinterpreted_buf.len(), len);
                debug_assert_eq!(reinterpreted_buf.offset(), offset);
                debug_assert_eq!(reinterpreted_buf.as_slice().as_ptr() as usize, ptr);
                Box::new(PrimitiveArray::new(
                    ArrowDataType::UInt16,
                    reinterpreted_buf,
                    array.validity().cloned(),
                )) as ArrayRef
            })
            .collect::<Vec<_>>();
        UInt16Chunked::from_chunks(self.name(), chunks)
    }
}

#[cfg(feature = "performant")]
impl Int8Chunked {
    /// Reinterpret the underlying `i8` value buffers as `u8` without copying.
    ///
    /// Each chunk's value `Buffer<i8>` is transmuted to a `Buffer<u8>` and
    /// rewrapped in a new `PrimitiveArray` tagged `UInt8`; the validity (null)
    /// bitmap of each chunk is cloned and carried over unchanged. Used so the
    /// specialized small-integer code paths only need unsigned kernels.
    pub(crate) fn reinterpret_unsigned(&self) -> UInt8Chunked {
        let chunks = self
            .downcast_iter()
            .map(|array| {
                let buf = array.values().clone();
                // Safety
                // same bit length i8 <-> u8
                // The fields can still be reordered between generic types
                // so we do some extra assertions
                let len = buf.len();
                let offset = buf.offset();
                let ptr = buf.as_slice().as_ptr() as usize;
                #[allow(clippy::transmute_undefined_repr)]
                let reinterpreted_buf = unsafe { std::mem::transmute::<_, Buffer<u8>>(buf) };
                // Debug-only sanity checks: the transmuted buffer must alias the
                // exact same memory region with identical length and offset.
                debug_assert_eq!(reinterpreted_buf.len(), len);
                debug_assert_eq!(reinterpreted_buf.offset(), offset);
                debug_assert_eq!(reinterpreted_buf.as_slice().as_ptr() as usize, ptr);
                Box::new(PrimitiveArray::new(
                    ArrowDataType::UInt8,
                    reinterpreted_buf,
                    array.validity().cloned(),
                )) as ArrayRef
            })
            .collect::<Vec<_>>();
        UInt8Chunked::from_chunks(self.name(), chunks)
    }
}

impl<T> ToBitRepr for ChunkedArray<T>
where
T: PolarsNumericType,
Expand Down
46 changes: 43 additions & 3 deletions polars/polars-core/src/frame/groupby/into_groups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,37 @@ where
let ca = self.bit_repr_small();
num_groups_proxy(&ca, multithreaded, sorted)
}
#[cfg(feature = "performant")]
DataType::Int8 => {
// convince the compiler that we are this type.
let ca: &Int8Chunked =
unsafe { &*(self as *const ChunkedArray<T> as *const ChunkedArray<Int8Type>) };
let ca = ca.reinterpret_unsigned();
num_groups_proxy(&ca, multithreaded, sorted)
}
#[cfg(feature = "performant")]
DataType::UInt8 => {
// convince the compiler that we are this type.
let ca: &UInt8Chunked =
unsafe { &*(self as *const ChunkedArray<T> as *const ChunkedArray<UInt8Type>) };
num_groups_proxy(ca, multithreaded, sorted)
}
#[cfg(feature = "performant")]
DataType::Int16 => {
// convince the compiler that we are this type.
let ca: &Int16Chunked =
unsafe { &*(self as *const ChunkedArray<T> as *const ChunkedArray<Int16Type>) };
let ca = ca.reinterpret_unsigned();
num_groups_proxy(&ca, multithreaded, sorted)
}
#[cfg(feature = "performant")]
DataType::UInt16 => {
// convince the compiler that we are this type.
let ca: &UInt16Chunked = unsafe {
&*(self as *const ChunkedArray<T> as *const ChunkedArray<UInt16Type>)
};
num_groups_proxy(ca, multithreaded, sorted)
}
_ => {
let ca = self.cast_unchecked(&DataType::UInt32).unwrap();
let ca = ca.u32().unwrap();
Expand All @@ -183,9 +214,18 @@ where
}
impl IntoGroupsProxy for BooleanChunked {
    /// Compute group tuples for a boolean column by casting it to an unsigned
    /// integer column and delegating to the integer `group_tuples` path.
    fn group_tuples(&self, multithreaded: bool, sorted: bool) -> GroupsProxy {
        #[cfg(feature = "performant")]
        {
            // With the "performant" feature the specialized small-int kernels
            // handle u8 directly, so use the narrowest representation.
            self.cast(&DataType::UInt8)
                .unwrap()
                .u8()
                .unwrap()
                .group_tuples(multithreaded, sorted)
        }
        #[cfg(not(feature = "performant"))]
        {
            self.cast(&DataType::UInt32)
                .unwrap()
                .u32()
                .unwrap()
                .group_tuples(multithreaded, sorted)
        }
    }
}

Expand Down
40 changes: 40 additions & 0 deletions polars/polars-core/src/vector_hasher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,22 @@ pub(crate) trait AsU64 {
fn as_u64(self) -> u64;
}

#[cfg(feature = "performant")]
impl AsU64 for u8 {
    /// Widen to `u64`; lossless for every `u8` value.
    #[inline]
    fn as_u64(self) -> u64 {
        u64::from(self)
    }
}

#[cfg(feature = "performant")]
impl AsU64 for u16 {
    /// Widen to `u64`; lossless for every `u16` value.
    #[inline]
    fn as_u64(self) -> u64 {
        u64::from(self)
    }
}

impl AsU64 for u32 {
#[inline]
fn as_u64(self) -> u64 {
Expand Down Expand Up @@ -270,6 +286,30 @@ impl AsU64 for Option<u32> {
}
}

#[cfg(feature = "performant")]
impl AsU64 for Option<u8> {
    /// Widen `Some(v)` to `u64`; map `None` to a fixed sentinel
    /// (`u64::MAX >> 2`, a number safe from overflow downstream).
    #[inline]
    fn as_u64(self) -> u64 {
        self.map_or(u64::MAX >> 2, u64::from)
    }
}

#[cfg(feature = "performant")]
impl AsU64 for Option<u16> {
    /// Widen `Some(v)` to `u64`; map `None` to a fixed sentinel
    /// (`u64::MAX >> 2`, a number safe from overflow downstream).
    #[inline]
    fn as_u64(self) -> u64 {
        self.map_or(u64::MAX >> 2, u64::from)
    }
}

impl AsU64 for Option<u64> {
#[inline]
fn as_u64(self) -> u64 {
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3960,7 +3960,7 @@ def join(
else:
right_on_ = right_on

if isinstance(on, str):
if isinstance(on, (str, pli.Expr)):
left_on_ = [on]
right_on_ = [on]
elif isinstance(on, list):
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/lazy_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1471,7 +1471,7 @@ def join(
else:
right_on_ = right_on

if isinstance(on, str):
if isinstance(on, (str, pli.Expr)):
left_on_ = [on]
right_on_ = [on]
elif isinstance(on, list):
Expand Down
15 changes: 15 additions & 0 deletions py-polars/tests/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,18 @@ def test_joins_dispatch() -> None:
dfa.join(dfa, on=["date", "a"], how=how)
dfa.join(dfa, on=["a", "datetime"], how=how)
dfa.join(dfa, on=["date"], how=how)


def test_join_on_cast() -> None:
    """Joining on an expression key should evaluate (here: cast) it on both sides."""
    left = (
        pl.DataFrame({"a": [-5, -2, 3, 3, 9, 10]})
        .with_row_count()
        .with_column(pl.col("a").cast(pl.Int32))
    )
    right = pl.DataFrame({"a": [-2, -3, 3, 10]})

    # Left side is Int32, right side Int64; casting the key expression makes
    # the dtypes agree so matching rows are found.
    result = left.join(right, on=pl.col("a").cast(pl.Int64))
    assert result.to_dict(False) == {
        "row_nr": [1, 2, 3, 5],
        "a": [-2, 3, 3, 10],
    }

0 comments on commit 517e945

Please sign in to comment.