Skip to content

Commit

Permalink
refactor(rust!): prepare for join coalescing argument (#15418)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Apr 29, 2024
1 parent 2805eca commit f0dbb6a
Show file tree
Hide file tree
Showing 17 changed files with 128 additions and 73 deletions.
13 changes: 12 additions & 1 deletion crates/polars-lazy/src/frame/mod.rs
Expand Up @@ -31,6 +31,7 @@ pub use ndjson::*;
pub use parquet::*;
use polars_core::prelude::*;
use polars_io::RowIndex;
use polars_ops::frame::JoinCoalesce;
pub use polars_plan::frame::{AllowedOptimizations, OptState};
use polars_plan::global::FETCH_ROWS;
use smartstring::alias::String as SmartString;
Expand Down Expand Up @@ -1124,7 +1125,7 @@ impl LazyFrame {
other,
[left_on.into()],
[right_on.into()],
JoinArgs::new(JoinType::Outer { coalesce: false }),
JoinArgs::new(JoinType::Outer),
)
}

Expand Down Expand Up @@ -1195,6 +1196,7 @@ impl LazyFrame {
.right_on(right_on)
.how(args.how)
.validate(args.validation)
.coalesce(args.coalesce)
.join_nulls(args.join_nulls);

if let Some(suffix) = args.suffix {
Expand Down Expand Up @@ -1764,6 +1766,7 @@ pub struct JoinBuilder {
force_parallel: bool,
suffix: Option<String>,
validation: JoinValidation,
coalesce: JoinCoalesce,
join_nulls: bool,
}
impl JoinBuilder {
Expand All @@ -1780,6 +1783,7 @@ impl JoinBuilder {
join_nulls: false,
suffix: None,
validation: Default::default(),
coalesce: Default::default(),
}
}

Expand Down Expand Up @@ -1851,6 +1855,12 @@ impl JoinBuilder {
self
}

/// Whether to coalesce join columns.
pub fn coalesce(mut self, coalesce: JoinCoalesce) -> Self {
self.coalesce = coalesce;
self
}

/// Finish builder
pub fn finish(self) -> LazyFrame {
let mut opt_state = self.lf.opt_state;
Expand All @@ -1865,6 +1875,7 @@ impl JoinBuilder {
suffix: self.suffix,
slice: None,
join_nulls: self.join_nulls,
coalesce: self.coalesce,
};

let lp = self
Expand Down
5 changes: 4 additions & 1 deletion crates/polars-lazy/src/tests/streaming.rs
@@ -1,3 +1,5 @@
use polars_ops::frame::JoinCoalesce;

use super::*;

fn get_csv_file() -> LazyFrame {
Expand Down Expand Up @@ -295,7 +297,8 @@ fn test_streaming_partial() -> PolarsResult<()> {
.left_on([col("a")])
.right_on([col("a")])
.suffix("_foo")
.how(JoinType::Outer { coalesce: true })
.how(JoinType::Outer)
.coalesce(JoinCoalesce::CoalesceColumns)
.finish();

let q = q.left_join(
Expand Down
66 changes: 51 additions & 15 deletions crates/polars-ops/src/frame/join/args.rs
Expand Up @@ -26,6 +26,36 @@ pub struct JoinArgs {
pub suffix: Option<String>,
pub slice: Option<(i64, usize)>,
pub join_nulls: bool,
pub coalesce: JoinCoalesce,
}

/// Setting that controls whether join columns are coalesced.
///
/// The setting is resolved against a concrete join type by
/// `JoinCoalesce::coalesce`.
#[derive(Clone, PartialEq, Eq, Debug, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum JoinCoalesce {
/// Let the join type decide (the default): resolves to coalescing for
/// left and inner joins and to no coalescing for every other join type.
#[default]
JoinSpecific,
/// Request coalescing: resolves to coalescing for left, inner and outer
/// joins; asof, cross, semi and anti joins still resolve to no coalescing.
CoalesceColumns,
/// Never coalesce: resolves to no coalescing for every join type.
KeepColumns,
}

impl JoinCoalesce {
pub fn coalesce(&self, join_type: &JoinType) -> bool {
use JoinCoalesce::*;
use JoinType::*;
match join_type {
Left | Inner => {
matches!(self, JoinSpecific | CoalesceColumns)
},
Outer { .. } => {
matches!(self, CoalesceColumns)
},
#[cfg(feature = "asof_join")]
AsOf(_) => false,
Cross => false,
#[cfg(feature = "semi_anti_join")]
Semi | Anti => false,
}
}
}

impl Default for JoinArgs {
Expand All @@ -36,6 +66,7 @@ impl Default for JoinArgs {
suffix: None,
slice: None,
join_nulls: false,
coalesce: Default::default(),
}
}
}
Expand All @@ -48,9 +79,15 @@ impl JoinArgs {
suffix: None,
slice: None,
join_nulls: false,
coalesce: Default::default(),
}
}

pub fn with_coalesce(mut self, coalesce: JoinCoalesce) -> Self {
self.coalesce = coalesce;
self
}

/// The configured suffix, or `"_right"` when no suffix was configured.
pub fn suffix(&self) -> &str {
    match self.suffix.as_deref() {
        Some(custom) => custom,
        None => "_right",
    }
}
Expand All @@ -61,9 +98,7 @@ impl JoinArgs {
pub enum JoinType {
Left,
Inner,
Outer {
coalesce: bool,
},
Outer,
#[cfg(feature = "asof_join")]
AsOf(AsOfOptions),
Cross,
Expand All @@ -73,18 +108,6 @@ pub enum JoinType {
Anti,
}

impl JoinType {
    /// Whether this join type merges its join keys.
    ///
    /// Outer joins follow their own `coalesce` flag, asof joins never merge,
    /// and every other join type always merges.
    pub fn merges_join_keys(&self) -> bool {
        match self {
            #[cfg(feature = "asof_join")]
            Self::AsOf(_) => false,
            // An outer join that was asked not to coalesce keeps both keys.
            Self::Outer { coalesce: false } => false,
            _ => true,
        }
    }
}

impl From<JoinType> for JoinArgs {
fn from(value: JoinType) -> Self {
JoinArgs::new(value)
Expand Down Expand Up @@ -116,6 +139,19 @@ impl Debug for JoinType {
}
}

impl JoinType {
    /// Whether this is an asof join.
    ///
    /// Always `false` when the `asof_join` feature is disabled, because the
    /// `AsOf` variant does not exist in that configuration.
    pub fn is_asof(&self) -> bool {
        match self {
            #[cfg(feature = "asof_join")]
            Self::AsOf(_) => true,
            _ => false,
        }
    }
}

#[derive(Copy, Clone, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum JoinValidation {
Expand Down
4 changes: 1 addition & 3 deletions crates/polars-ops/src/frame/join/hash_join/mod.rs
Expand Up @@ -271,9 +271,7 @@ pub trait JoinDispatch: IntoDf {
|| unsafe { other.take_unchecked(&idx_ca_r) },
);

let JoinType::Outer { coalesce } = args.how else {
unreachable!()
};
let coalesce = args.coalesce.coalesce(&JoinType::Outer);
let out = _finish_join(df_left, df_right, args.suffix.as_deref());
if coalesce {
Ok(_coalesce_outer_join(
Expand Down
18 changes: 6 additions & 12 deletions crates/polars-ops/src/frame/join/mod.rs
Expand Up @@ -209,9 +209,7 @@ pub trait DataFrameJoinOps: IntoDf {
JoinType::Left => {
left_df._left_join_from_series(other, s_left, s_right, args, _verbose, None)
},
JoinType::Outer { .. } => {
left_df._outer_join_from_series(other, s_left, s_right, args)
},
JoinType::Outer => left_df._outer_join_from_series(other, s_left, s_right, args),
#[cfg(feature = "semi_anti_join")]
JoinType::Anti => left_df._semi_anti_join_from_series(
s_left,
Expand Down Expand Up @@ -278,13 +276,14 @@ pub trait DataFrameJoinOps: IntoDf {
JoinType::Cross => {
unreachable!()
},
JoinType::Outer { coalesce } => {
JoinType::Outer => {
let names_left = selected_left.iter().map(|s| s.name()).collect::<Vec<_>>();
args.how = JoinType::Outer { coalesce: false };
let coalesce = args.coalesce;
args.coalesce = JoinCoalesce::KeepColumns;
let suffix = args.suffix.clone();
let out = left_df._outer_join_from_series(other, &lhs_keys, &rhs_keys, args);

if coalesce {
if coalesce.coalesce(&JoinType::Outer) {
Ok(_coalesce_outer_join(
out?,
&names_left,
Expand Down Expand Up @@ -411,12 +410,7 @@ pub trait DataFrameJoinOps: IntoDf {
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
self.join(
other,
left_on,
right_on,
JoinArgs::new(JoinType::Outer { coalesce: false }),
)
self.join(other, left_on, right_on, JoinArgs::new(JoinType::Outer))
}
}

Expand Down
16 changes: 9 additions & 7 deletions crates/polars-pipe/src/executors/sinks/joins/generic_build.rs
Expand Up @@ -5,6 +5,7 @@ use hashbrown::hash_map::RawEntryMut;
use polars_core::export::ahash::RandomState;
use polars_core::prelude::*;
use polars_core::utils::{_set_partition_size, accumulate_dataframes_vertical_unchecked};
use polars_ops::prelude::JoinArgs;
use polars_utils::arena::Node;
use polars_utils::slice::GetSaferUnchecked;
use polars_utils::unitvec;
Expand Down Expand Up @@ -34,6 +35,7 @@ pub struct GenericBuild<K: ExtraPayload> {
materialized_join_cols: Vec<BinaryArray<i64>>,
suffix: Arc<str>,
hb: RandomState,
join_args: JoinArgs,
// partitioned tables that will be used for probing
// stores the key and the chunk_idx, df_idx of the left table
hash_tables: PartitionedMap<K>,
Expand All @@ -45,7 +47,6 @@ pub struct GenericBuild<K: ExtraPayload> {
// amortize allocations
join_columns: Vec<ArrayRef>,
hashes: Vec<u64>,
join_type: JoinType,
// the join order is swapped to ensure we hash the smaller table
swapped: bool,
join_nulls: bool,
Expand All @@ -59,7 +60,7 @@ impl<K: ExtraPayload> GenericBuild<K> {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
suffix: Arc<str>,
join_type: JoinType,
join_args: JoinArgs,
swapped: bool,
join_columns_left: Arc<Vec<Arc<dyn PhysicalPipedExpr>>>,
join_columns_right: Arc<Vec<Arc<dyn PhysicalPipedExpr>>>,
Expand All @@ -76,7 +77,7 @@ impl<K: ExtraPayload> GenericBuild<K> {
}));
GenericBuild {
chunks: vec![],
join_type,
join_args,
suffix,
hb,
swapped,
Expand Down Expand Up @@ -278,7 +279,7 @@ impl<K: ExtraPayload> Sink for GenericBuild<K> {
fn split(&self, _thread_no: usize) -> Box<dyn Sink> {
let mut new = Self::new(
self.suffix.clone(),
self.join_type.clone(),
self.join_args.clone(),
self.swapped,
self.join_columns_left.clone(),
self.join_columns_right.clone(),
Expand Down Expand Up @@ -317,7 +318,7 @@ impl<K: ExtraPayload> Sink for GenericBuild<K> {
let mut hashes = std::mem::take(&mut self.hashes);
hashes.clear();

match self.join_type {
match self.join_args.how {
JoinType::Inner | JoinType::Left => {
let probe_operator = GenericJoinProbe::new(
left_df,
Expand All @@ -330,13 +331,14 @@ impl<K: ExtraPayload> Sink for GenericBuild<K> {
self.swapped,
hashes,
context,
self.join_type.clone(),
self.join_args.how.clone(),
self.join_nulls,
);
self.placeholder.replace(Box::new(probe_operator));
Ok(FinalizedSink::Operator)
},
JoinType::Outer { coalesce } => {
JoinType::Outer => {
let coalesce = self.join_args.coalesce.coalesce(&JoinType::Outer);
let probe_operator = GenericOuterJoinProbe::new(
left_df,
materialized_join_cols,
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-pipe/src/pipeline/convert.rs
Expand Up @@ -285,12 +285,12 @@ where
};

match jt {
join_type @ JoinType::Inner | join_type @ JoinType::Left => {
JoinType::Inner | JoinType::Left => {
let (join_columns_left, join_columns_right) = swap_eval();

Box::new(GenericBuild::<()>::new(
Arc::from(options.args.suffix()),
join_type.clone(),
options.args.clone(),
swapped,
join_columns_left,
join_columns_right,
Expand All @@ -317,7 +317,7 @@ where

Box::new(GenericBuild::<Tracker>::new(
Arc::from(options.args.suffix()),
jt.clone(),
options.args.clone(),
swapped,
join_columns_left,
join_columns_right,
Expand Down
Expand Up @@ -258,7 +258,8 @@ pub(super) fn process_join(
already_added_local_to_local_projected.insert(local_name);
}
// In outer joins both columns remain. So `add_local=true` also for the right table
let add_local = matches!(options.args.how, JoinType::Outer { coalesce: false });
let add_local = matches!(options.args.how, JoinType::Outer)
&& !options.args.coalesce.coalesce(&options.args.how);
for e in &right_on {
// In case of outer joins we also add the columns.
// But before we do that we must check if the column wasn't already added by the lhs.
Expand Down
13 changes: 8 additions & 5 deletions crates/polars-plan/src/logical_plan/schema.rs
Expand Up @@ -313,11 +313,11 @@ pub(crate) fn det_join_schema(
new_schema.with_column(field.name, field.dtype);
arena.clear();
}
// except in asof joins. Asof joins are not equi-joins
// Except in asof joins. Asof joins are not equi-joins, so the columns
// that are joined on may hold different values. If the right column
// has a different name, it is therefore added to the schema.
#[cfg(feature = "asof_join")]
if !options.args.how.merges_join_keys() {
if !options.args.coalesce.coalesce(&options.args.how) {
for (left_on, right_on) in left_on.iter().zip(right_on) {
let field_left =
left_on.to_field_amortized(schema_left, Context::Default, &mut arena)?;
Expand All @@ -342,10 +342,13 @@ pub(crate) fn det_join_schema(
join_on_right.insert(field.name);
}

let are_coalesced = options.args.coalesce.coalesce(&options.args.how);
let is_asof = options.args.how.is_asof();

// Asof joins are special: if the names are equal, they will not be coalesced.
for (name, dtype) in schema_right.iter() {
if !join_on_right.contains(name.as_str()) // The names that are joined on are merged
|| matches!(&options.args.how, JoinType::Outer{coalesce: false})
// The names are not merged
if !join_on_right.contains(name.as_str()) || (!are_coalesced && !is_asof)
// The names that are joined on are merged
{
if schema_left.contains(name.as_str()) {
#[cfg(feature = "asof_join")]
Expand Down

0 comments on commit f0dbb6a

Please sign in to comment.