Skip to content

Commit

Permalink
fix join negative keys (#3730)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 18, 2022
1 parent 7a9d7ac commit 2b7d403
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 29 deletions.
62 changes: 55 additions & 7 deletions polars/polars-core/src/chunked_array/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,38 @@
#[cfg(feature = "dtype-categorical")]
use crate::chunked_array::categorical::CategoricalChunkedBuilder;
use crate::prelude::*;
use arrow::compute::cast::CastOptions;
use polars_arrow::compute::cast;
use std::convert::TryFrom;

pub(crate) fn cast_chunks(chunks: &[ArrayRef], dtype: &DataType) -> Result<Vec<ArrayRef>> {
pub(crate) fn cast_chunks(
chunks: &[ArrayRef],
dtype: &DataType,
checked: bool,
) -> Result<Vec<ArrayRef>> {
let options = if checked {
Default::default()
} else {
CastOptions {
wrapped: true,
partial: false,
}
};

let chunks = chunks
.iter()
.map(|arr| cast::cast(arr.as_ref(), &dtype.to_arrow()))
.map(|arr| arrow::compute::cast::cast(arr.as_ref(), &dtype.to_arrow(), options))
.collect::<arrow::error::Result<Vec<_>>>()?;
Ok(chunks)
}

fn cast_impl(name: &str, chunks: &[ArrayRef], dtype: &DataType) -> Result<Series> {
let chunks = cast_chunks(chunks, &dtype.to_physical())?;
fn cast_impl_inner(
name: &str,
chunks: &[ArrayRef],
dtype: &DataType,
checked: bool,
) -> Result<Series> {
let chunks = cast_chunks(chunks, &dtype.to_physical(), checked)?;
let out = Series::try_from((name, chunks))?;
use DataType::*;
let out = match dtype {
Expand All @@ -29,17 +48,21 @@ fn cast_impl(name: &str, chunks: &[ArrayRef], dtype: &DataType) -> Result<Series
Ok(out)
}

impl<T> ChunkCast for ChunkedArray<T>
/// Checked cast of raw `chunks` to `dtype`, materialized as a `Series`.
///
/// Thin convenience wrapper over [`cast_impl_inner`] with overflow
/// checking enabled.
fn cast_impl(name: &str, chunks: &[ArrayRef], dtype: &DataType) -> Result<Series> {
    // `checked = true`: fail rather than silently wrap on invalid numeric casts.
    cast_impl_inner(name, chunks, dtype, true)
}

impl<T> ChunkedArray<T>
where
T: PolarsNumericType,
{
fn cast(&self, data_type: &DataType) -> Result<Series> {
fn cast_impl(&self, data_type: &DataType, checked: bool) -> Result<Series> {
match data_type {
#[cfg(feature = "dtype-categorical")]
DataType::Categorical(_) => {
Ok(CategoricalChunked::full_null(self.name(), self.len()).into_series())
}
_ => cast_impl(self.name(), &self.chunks, data_type).map(|mut s| {
_ => cast_impl_inner(self.name(), &self.chunks, data_type, checked).map(|mut s| {
// maintain sorted if data types remain signed
if self.is_sorted()
|| self.is_sorted_reverse() && (s.null_count() == self.null_count())
Expand All @@ -60,6 +83,19 @@ where
}
}

impl<T> ChunkCast for ChunkedArray<T>
where
    T: PolarsNumericType,
{
    /// Checked cast: invalid/overflowing values are handled by the default
    /// (checked) arrow cast options.
    fn cast(&self, data_type: &DataType) -> Result<Series> {
        self.cast_impl(data_type, true)
    }

    /// Unchecked cast: numeric values may wrap on over/underflow.
    /// Used where bit-level reinterpretation is intended (e.g. hashing).
    fn cast_unchecked(&self, data_type: &DataType) -> Result<Series> {
        self.cast_impl(data_type, false)
    }
}

impl ChunkCast for Utf8Chunked {
fn cast(&self, data_type: &DataType) -> Result<Series> {
match data_type {
Expand All @@ -74,6 +110,10 @@ impl ChunkCast for Utf8Chunked {
_ => cast_impl(self.name(), &self.chunks, data_type),
}
}

fn cast_unchecked(&self, data_type: &DataType) -> Result<Series> {
self.cast(data_type)
}
}

fn boolean_to_utf8(ca: &BooleanChunked) -> Utf8Chunked {
Expand All @@ -96,6 +136,10 @@ impl ChunkCast for BooleanChunked {
cast_impl(self.name(), &self.chunks, data_type)
}
}

fn cast_unchecked(&self, data_type: &DataType) -> Result<Series> {
self.cast(data_type)
}
}

fn cast_inner_list_type(list: &ListArray<i64>, child_type: &DataType) -> Result<ArrayRef> {
Expand Down Expand Up @@ -133,6 +177,10 @@ impl ChunkCast for ListChunked {
_ => Err(PolarsError::ComputeError("Cannot cast list type".into())),
}
}

fn cast_unchecked(&self, data_type: &DataType) -> Result<Series> {
self.cast(data_type)
}
}

#[cfg(test)]
Expand Down
6 changes: 5 additions & 1 deletion polars/polars-core/src/chunked_array/ops/bit_repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ where
.collect::<Vec<_>>();
UInt32Chunked::from_chunks(self.name(), chunks)
} else {
self.cast(&DataType::UInt32).unwrap().u32().unwrap().clone()
self.cast_unchecked(&DataType::UInt32)
.unwrap()
.u32()
.unwrap()
.clone()
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions polars/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,9 @@ pub trait ChunkSet<'a, A, B> {
pub trait ChunkCast {
    /// Cast a `[ChunkedArray]` to `[DataType]`, validating the conversion.
    fn cast(&self, data_type: &DataType) -> Result<Series>;

    /// Cast without validity checks: numeric values may silently wrap on
    /// over/underflow. Only use when wrapping is acceptable or intended,
    /// e.g. bit-level reinterpretation for hashing/join keys.
    fn cast_unchecked(&self, data_type: &DataType) -> Result<Series>;
}

/// Fastest way to do elementwise operations on a ChunkedArray<T> when the operation is cheaper than
Expand Down
18 changes: 6 additions & 12 deletions polars/polars-core/src/frame/hash_join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,12 @@ pub(crate) unsafe fn get_hash_tbl_threaded_join_partitioned<Item>(
hash_tables: &[Item],
len: u64,
) -> &Item {
let mut idx = 0;
for i in 0..len {
// can only be done for powers of two.
// n % 2^i = n & (2^i - 1)
if (h + i) & (len - 1) == 0 {
idx = i as usize;
if this_partition(h, i, len) {
return hash_tables.get_unchecked(i as usize);
}
}
hash_tables.get_unchecked(idx)
unreachable!()
}

#[allow(clippy::type_complexity)]
Expand All @@ -155,15 +152,12 @@ unsafe fn get_hash_tbl_threaded_join_mut_partitioned<T, H>(
hash_tables: &mut [HashMap<T, (bool, Vec<IdxSize>), H>],
len: u64,
) -> &mut HashMap<T, (bool, Vec<IdxSize>), H> {
let mut idx = 0;
for i in 0..len {
// can only be done for powers of two.
// n % 2^i = n & (2^i - 1)
if (h + i) & (len - 1) == 0 {
idx = i as usize;
if this_partition(h, i, len) {
return hash_tables.get_unchecked_mut(i as usize);
}
}
hash_tables.get_unchecked_mut(idx)
unreachable!()
}

pub trait ZipOuterJoinColumn {
Expand Down
14 changes: 7 additions & 7 deletions polars/polars-core/src/series/from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl Series {
match dtype {
ArrowDataType::LargeUtf8 => Ok(Utf8Chunked::from_chunks(name, chunks).into_series()),
ArrowDataType::Utf8 => {
let chunks = cast_chunks(&chunks, &DataType::Utf8).unwrap();
let chunks = cast_chunks(&chunks, &DataType::Utf8, false).unwrap();
Ok(Utf8Chunked::from_chunks(name, chunks).into_series())
}
ArrowDataType::List(_) => {
Expand All @@ -114,14 +114,14 @@ impl Series {
ArrowDataType::Float64 => Ok(Float64Chunked::from_chunks(name, chunks).into_series()),
#[cfg(feature = "dtype-date")]
ArrowDataType::Date32 => {
let chunks = cast_chunks(&chunks, &DataType::Int32).unwrap();
let chunks = cast_chunks(&chunks, &DataType::Int32, false).unwrap();
Ok(Int32Chunked::from_chunks(name, chunks)
.into_date()
.into_series())
}
#[cfg(feature = "dtype-datetime")]
ArrowDataType::Date64 => {
let chunks = cast_chunks(&chunks, &DataType::Int64).unwrap();
let chunks = cast_chunks(&chunks, &DataType::Int64, false).unwrap();
let ca = Int64Chunked::from_chunks(name, chunks);
Ok(ca.into_datetime(TimeUnit::Milliseconds, None).into_series())
}
Expand All @@ -132,7 +132,7 @@ impl Series {
tz = None;
}
// we still drop timezone for now
let chunks = cast_chunks(&chunks, &DataType::Int64).unwrap();
let chunks = cast_chunks(&chunks, &DataType::Int64, false).unwrap();
let s = Int64Chunked::from_chunks(name, chunks)
.into_datetime(tu.into(), tz)
.into_series();
Expand All @@ -145,7 +145,7 @@ impl Series {
}
#[cfg(feature = "dtype-duration")]
ArrowDataType::Duration(tu) => {
let chunks = cast_chunks(&chunks, &DataType::Int64).unwrap();
let chunks = cast_chunks(&chunks, &DataType::Int64, false).unwrap();
let s = Int64Chunked::from_chunks(name, chunks)
.into_duration(tu.into())
.into_series();
Expand All @@ -160,9 +160,9 @@ impl Series {
ArrowDataType::Time64(tu) | ArrowDataType::Time32(tu) => {
let mut chunks = chunks;
if matches!(dtype, ArrowDataType::Time32(_)) {
chunks = cast_chunks(&chunks, &DataType::Int32).unwrap();
chunks = cast_chunks(&chunks, &DataType::Int32, false).unwrap();
}
let chunks = cast_chunks(&chunks, &DataType::Int64).unwrap();
let chunks = cast_chunks(&chunks, &DataType::Int64, false).unwrap();
let s = Int64Chunked::from_chunks(name, chunks)
.into_time()
.into_series();
Expand Down
5 changes: 3 additions & 2 deletions polars/polars-core/src/vector_hasher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,14 +248,14 @@ impl AsU64 for i32 {
#[inline]
fn as_u64(self) -> u64 {
let asu32: u32 = unsafe { std::mem::transmute(self) };
asu32 as u64
dbg!(asu32 as u64)
}
}

impl AsU64 for i64 {
    /// Reinterpret the i64 bit pattern as u64 so negative keys keep a
    /// stable two's-complement representation for hashing/partitioning.
    ///
    /// Fix: removed a leftover `dbg!(…)` that printed to stderr on every
    /// hash in the join hot path; `self as u64` between same-width integers
    /// is a bit reinterpretation, so the `unsafe` transmute is unnecessary.
    #[inline]
    fn as_u64(self) -> u64 {
        self as u64
    }
}

Expand Down Expand Up @@ -397,6 +397,7 @@ impl<'a> AsU64 for StrHash<'a> {
/// For partitions that are a power of 2 we can use a bitmask instead of a
/// modulo: returns true when `(h + thread_no) % n_partitions == 0`.
#[inline]
pub(crate) fn this_partition(h: u64, thread_no: u64, n_partitions: u64) -> bool {
    debug_assert!(n_partitions.is_power_of_two());
    // n % 2^k == n & (2^k - 1); wrapping ops keep extreme hash values
    // well-defined in release builds.
    let mask = n_partitions.wrapping_sub(1);
    h.wrapping_add(thread_no) & mask == 0
}
Expand Down
25 changes: 25 additions & 0 deletions py-polars/tests/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,28 @@ def test_sorted_merge_joins() -> None:
).join(df_b.with_column(pl.col("a").set_sorted(reverse)), on="a", how=how)

assert out_hash_join.frame_equal(out_sorted_merge_join)


def test_join_negative_integers() -> None:
    # Joining on negative keys must match on value, not on raw sign bits
    # (regression test for the hash-join partitioning of signed dtypes).
    expected = {"a": [-6, -1, 0], "b": [-6, -1, 0]}

    df1 = pl.DataFrame({"a": [-1, -6, -3, 0]})
    df2 = pl.DataFrame(
        {
            "a": [-6, -1, -4, -2, 0],
            "b": [-6, -1, -4, -2, 0],
        }
    )

    # Exercise every signed integer width.
    for dt in [pl.Int8, pl.Int16, pl.Int32, pl.Int64]:
        left = df1.with_column(pl.all().cast(dt))
        right = df2.with_column(pl.all().cast(dt))
        joined = left.join(right, on="a", how="inner")
        assert joined.to_dict(False) == expected

0 comments on commit 2b7d403

Please sign in to comment.