Skip to content

Commit

Permalink
fix(rust, python): fix dtypes in join_asof_by (#5746)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 8, 2022
1 parent 7e39f7f commit 48d415e
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 111 deletions.
194 changes: 85 additions & 109 deletions polars/polars-core/src/frame/asof_join/groups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,11 +372,11 @@ where
let right_asof = right_asof.cont_slice().unwrap();

let n_threads = POOL.current_num_threads();
let splitted_left = split_ca(by_left, n_threads).unwrap();
let splitted_by_left = split_ca(by_left, n_threads).unwrap();
let splitted_right = split_ca(by_right, n_threads).unwrap();

let hb = RandomState::default();
let vals_left = prepare_strs(&splitted_left, &hb);
let vals_left = prepare_strs(&splitted_by_left, &hb);
let vals_right = prepare_strs(&splitted_right, &hb);

let hash_tbls = create_probe_table(vals_right);
Expand Down Expand Up @@ -571,131 +571,99 @@ where
})
}

/// Dispatch an asof join with `by` keys to the kernel matching the `by` columns.
///
/// When `left_by` holds exactly one column, the join is keyed on `left_by_s` /
/// `right_by_s`:
/// * `Utf8` keys are handled by `asof_join_by_utf8`;
/// * every other dtype is handled by `asof_join_by_numeric` on the large or small
///   bit representation of the keys, chosen by `left_by_s.bit_repr_is_large()`.
///
/// Otherwise every pair of `by` columns is validated with `check_asof_columns`
/// (and `check_categorical_src` when the `dtype-categorical` feature is enabled)
/// and the join is delegated to `asof_join_by_multiple`.
///
/// NOTE(review): the returned vector presumably holds, per left row, the index of
/// the matched right row (or `None`) — confirm against the join kernels.
///
/// # Errors
///
/// Returns an error when
/// * `left_by_s` is `Utf8` but `right_by_s` cannot be downcast to `Utf8`,
/// * `asof_join_by_numeric` fails, or
/// * one of the per-column checks in the multi-column path fails.
// All eight arguments are forwarded to the join kernels, so they are kept flat.
#[allow(clippy::too_many_arguments)]
fn dispatch_join<T: PolarsNumericType>(
    left_asof: &ChunkedArray<T>,
    right_asof: &ChunkedArray<T>,
    left_by_s: &Series,
    right_by_s: &Series,
    left_by: &mut DataFrame,
    right_by: &mut DataFrame,
    strategy: AsofStrategy,
    tolerance: Option<AnyValue<'static>>,
) -> PolarsResult<Vec<Option<IdxSize>>> {
    let out = if left_by.width() == 1 {
        match left_by_s.dtype() {
            DataType::Utf8 => asof_join_by_utf8(
                // Only the left dtype is matched above; the right column may have a
                // different dtype, so propagate the downcast error instead of panicking.
                left_by_s.utf8()?,
                right_by_s.utf8()?,
                left_asof,
                right_asof,
                tolerance,
                strategy,
            ),
            _ => {
                if left_by_s.bit_repr_is_large() {
                    let left_by = left_by_s.bit_repr_large();
                    let right_by = right_by_s.bit_repr_large();
                    asof_join_by_numeric(
                        &left_by, &right_by, left_asof, right_asof, tolerance, strategy,
                    )?
                } else {
                    let left_by = left_by_s.bit_repr_small();
                    let right_by = right_by_s.bit_repr_small();
                    asof_join_by_numeric(
                        &left_by, &right_by, left_asof, right_asof, tolerance, strategy,
                    )?
                }
            }
        }
    } else {
        // Validate every pair of `by` columns before running the multi-key join.
        for (lhs, rhs) in left_by.get_columns().iter().zip(right_by.get_columns()) {
            check_asof_columns(lhs, rhs)?;
            #[cfg(feature = "dtype-categorical")]
            check_categorical_src(lhs.dtype(), rhs.dtype())?;
        }
        asof_join_by_multiple(
            left_by, right_by, left_asof, right_asof, tolerance, strategy,
        )
    };
    Ok(out)
}

impl DataFrame {
#[cfg_attr(docsrs, doc(cfg(feature = "asof_join")))]
#[allow(clippy::too_many_arguments)]
#[doc(hidden)]
pub fn _join_asof_by<I, S>(
pub fn _join_asof_by(
&self,
other: &DataFrame,
left_on: &str,
right_on: &str,
left_by: I,
right_by: I,
left_by: Vec<String>,
right_by: Vec<String>,
strategy: AsofStrategy,
tolerance: Option<AnyValue<'static>>,
slice: Option<(i64, usize)>,
) -> PolarsResult<DataFrame>
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
use DataType::*;
let left_asof = self.column(left_on)?;
let right_asof = other.column(right_on)?;
) -> PolarsResult<DataFrame> {
let left_asof = self.column(left_on)?.to_physical_repr();
let right_asof = other.column(right_on)?.to_physical_repr();
let right_asof_name = right_asof.name();
let left_asof_name = left_asof.name();

check_asof_columns(left_asof, right_asof)?;
check_asof_columns(&left_asof, &right_asof)?;

let mut left_by = self.select(left_by)?;
let mut right_by = other.select(right_by)?;

let left_by_s = &left_by.get_columns()[0];
let right_by_s = &right_by.get_columns()[0];

let right_join_tuples = if left_asof.bit_repr_is_large() {
// we cannot use bit repr as that loses ordering
let left_asof = left_asof.cast(&DataType::Int64)?;
let right_asof = right_asof.cast(&DataType::Int64)?;
let left_asof = left_asof.i64().unwrap();
let right_asof = right_asof.i64().unwrap();

if left_by.width() == 1 {
match left_by_s.dtype() {
Utf8 => asof_join_by_utf8(
left_by_s.utf8().unwrap(),
right_by_s.utf8().unwrap(),
left_asof,
right_asof,
tolerance,
strategy,
),
_ => {
if left_by_s.bit_repr_is_large() {
let left_by = left_by_s.bit_repr_large();
let right_by = right_by_s.bit_repr_large();
asof_join_by_numeric(
&left_by, &right_by, left_asof, right_asof, tolerance, strategy,
)?
} else {
let left_by = left_by_s.bit_repr_small();
let right_by = right_by_s.bit_repr_small();
asof_join_by_numeric(
&left_by, &right_by, left_asof, right_asof, tolerance, strategy,
)?
}
}
}
} else {
for (lhs, rhs) in left_by.get_columns().iter().zip(right_by.get_columns()) {
check_asof_columns(lhs, rhs)?;
#[cfg(feature = "dtype-categorical")]
check_categorical_src(lhs.dtype(), rhs.dtype())?;
}
asof_join_by_multiple(
&mut left_by,
&mut right_by,
left_asof,
right_asof,
tolerance,
strategy,
)
}
} else {
// we cannot use bit repr as that loses ordering
let left_asof = left_asof.cast(&DataType::Int32)?;
let right_asof = right_asof.cast(&DataType::Int32)?;
let left_asof = left_asof.i32().unwrap();
let right_asof = right_asof.i32().unwrap();

if left_by.width() == 1 {
match left_by_s.dtype() {
Utf8 => asof_join_by_utf8(
left_by_s.utf8().unwrap(),
right_by_s.utf8().unwrap(),
left_asof,
right_asof,
tolerance,
strategy,
),
_ => {
if left_by_s.bit_repr_is_large() {
let left_by = left_by_s.bit_repr_large();
let right_by = right_by_s.bit_repr_large();
asof_join_by_numeric(
&left_by, &right_by, left_asof, right_asof, tolerance, strategy,
)?
} else {
let left_by = left_by_s.bit_repr_small();
let right_by = right_by_s.bit_repr_small();
asof_join_by_numeric(
&left_by, &right_by, left_asof, right_asof, tolerance, strategy,
)?
}
}
}
} else {
asof_join_by_multiple(
&mut left_by,
&mut right_by,
left_asof,
right_asof,
tolerance,
strategy,
)
}
};
let left_by_s = left_by.get_columns()[0].to_physical_repr().into_owned();
let right_by_s = right_by.get_columns()[0].to_physical_repr().into_owned();

let right_join_tuples = with_match_physical_numeric_polars_type!(left_asof.dtype(), |$T| {
let left_asof: &ChunkedArray<$T> = left_asof.as_ref().as_ref().as_ref();
let right_asof: &ChunkedArray<$T> = right_asof.as_ref().as_ref().as_ref();

dispatch_join(
left_asof,
right_asof,
&left_by_s,
&right_by_s,
&mut left_by,
&mut right_by,
strategy,
tolerance
)
})?;

let mut drop_these = right_by.get_column_names();
if left_asof_name == right_asof_name {
Expand Down Expand Up @@ -755,6 +723,14 @@ impl DataFrame {
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
let left_by = left_by
.into_iter()
.map(|s| s.as_ref().to_string())
.collect();
let right_by = right_by
.into_iter()
.map(|s| s.as_ref().to_string())
.collect();
self._join_asof_by(
other, left_on, right_on, left_by, right_by, strategy, tolerance, None,
)
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,11 +332,11 @@ macro_rules! match_arrow_data_type_apply_macro_ca {

#[macro_export]
macro_rules! with_match_physical_numeric_type {(
$key_type:expr, | $_:tt $T:ident | $($body:tt)*
$dtype:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
use $crate::datatypes::DataType::*;
match $key_type {
match $dtype {
Int8 => __with_ty__! { i8 },
Int16 => __with_ty__! { i16 },
Int32 => __with_ty__! { i32 },
Expand Down
1 change: 1 addition & 0 deletions polars/polars-core/src/vector_hasher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,7 @@ impl<'a> BytesHash<'a> {
}

impl<'a> PartialEq for BytesHash<'a> {
/// Two values are equal when both their `hash` and their `payload` are equal.
#[inline]
fn eq(&self, other: &Self) -> bool {
    // `&&` short-circuits: the payload is only compared when the hashes match.
    self.hash == other.hash && self.payload == other.payload
}
Expand Down
25 changes: 25 additions & 0 deletions py-polars/tests/unit/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,31 @@ def test_join_asof_floats() -> None:
"b_right": ["rrow1", "rrow2", "rrow3"],
}

# with by argument
# 5740
df1 = pl.DataFrame(
{"b": np.linspace(0, 5, 7), "c": ["x" if i < 4 else "y" for i in range(7)]}
)
df2 = pl.DataFrame(
{
"val": [0, 2.5, 2.6, 2.7, 3.4, 4, 5],
"c": ["x", "x", "x", "y", "y", "y", "y"],
}
).with_column(pl.col("val").alias("b"))
assert df1.join_asof(df2, on="b", by="c").to_dict(False) == {
"b": [
0.0,
0.8333333333333334,
1.6666666666666667,
2.5,
3.3333333333333335,
4.166666666666667,
5.0,
],
"c": ["x", "x", "x", "x", "y", "y", "y"],
"val": [0.0, 0.0, 0.0, 2.5, 2.7, 4.0, 5.0],
}


def test_join_asof_tolerance() -> None:
df_trades = pl.DataFrame(
Expand Down

0 comments on commit 48d415e

Please sign in to comment.