Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(rust!): prepare for join coalescing argument #15418

Merged
merged 4 commits into from Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 12 additions & 1 deletion crates/polars-lazy/src/frame/mod.rs
Expand Up @@ -31,6 +31,7 @@ pub use ndjson::*;
pub use parquet::*;
use polars_core::prelude::*;
use polars_io::RowIndex;
use polars_ops::frame::JoinCoalesce;
pub use polars_plan::frame::{AllowedOptimizations, OptState};
use polars_plan::global::FETCH_ROWS;
use smartstring::alias::String as SmartString;
Expand Down Expand Up @@ -1124,7 +1125,7 @@ impl LazyFrame {
other,
[left_on.into()],
[right_on.into()],
JoinArgs::new(JoinType::Outer { coalesce: false }),
JoinArgs::new(JoinType::Outer),
)
}

Expand Down Expand Up @@ -1195,6 +1196,7 @@ impl LazyFrame {
.right_on(right_on)
.how(args.how)
.validate(args.validation)
.coalesce(args.coalesce)
.join_nulls(args.join_nulls);

if let Some(suffix) = args.suffix {
Expand Down Expand Up @@ -1764,6 +1766,7 @@ pub struct JoinBuilder {
force_parallel: bool,
suffix: Option<String>,
validation: JoinValidation,
coalesce: JoinCoalesce,
join_nulls: bool,
}
impl JoinBuilder {
Expand All @@ -1780,6 +1783,7 @@ impl JoinBuilder {
join_nulls: false,
suffix: None,
validation: Default::default(),
coalesce: Default::default(),
}
}

Expand Down Expand Up @@ -1851,6 +1855,12 @@ impl JoinBuilder {
self
}

/// Whether to coalesce join columns.
pub fn coalesce(mut self, coalesce: JoinCoalesce) -> Self {
self.coalesce = coalesce;
self
}

/// Finish builder
pub fn finish(self) -> LazyFrame {
let mut opt_state = self.lf.opt_state;
Expand All @@ -1865,6 +1875,7 @@ impl JoinBuilder {
suffix: self.suffix,
slice: None,
join_nulls: self.join_nulls,
coalesce: self.coalesce,
};

let lp = self
Expand Down
5 changes: 4 additions & 1 deletion crates/polars-lazy/src/tests/streaming.rs
@@ -1,3 +1,5 @@
use polars_ops::frame::JoinCoalesce;

use super::*;

fn get_csv_file() -> LazyFrame {
Expand Down Expand Up @@ -295,7 +297,8 @@ fn test_streaming_partial() -> PolarsResult<()> {
.left_on([col("a")])
.right_on([col("a")])
.suffix("_foo")
.how(JoinType::Outer { coalesce: true })
.how(JoinType::Outer)
.coalesce(JoinCoalesce::CoalesceColumns)
.finish();

let q = q.left_join(
Expand Down
66 changes: 51 additions & 15 deletions crates/polars-ops/src/frame/join/args.rs
Expand Up @@ -26,6 +26,36 @@ pub struct JoinArgs {
pub suffix: Option<String>,
pub slice: Option<(i64, usize)>,
pub join_nulls: bool,
pub coalesce: JoinCoalesce,
}

#[derive(Clone, PartialEq, Eq, Debug, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum JoinCoalesce {
#[default]
JoinSpecific,
CoalesceColumns,
KeepColumns,
}

impl JoinCoalesce {
pub fn coalesce(&self, join_type: &JoinType) -> bool {
use JoinCoalesce::*;
use JoinType::*;
match join_type {
Left | Inner => {
matches!(self, JoinSpecific | CoalesceColumns)
},
Outer { .. } => {
matches!(self, CoalesceColumns)
},
#[cfg(feature = "asof_join")]
AsOf(_) => false,
Cross => false,
#[cfg(feature = "semi_anti_join")]
Semi | Anti => false,
}
}
}

impl Default for JoinArgs {
Expand All @@ -36,6 +66,7 @@ impl Default for JoinArgs {
suffix: None,
slice: None,
join_nulls: false,
coalesce: Default::default(),
}
}
}
Expand All @@ -48,9 +79,15 @@ impl JoinArgs {
suffix: None,
slice: None,
join_nulls: false,
coalesce: Default::default(),
}
}

pub fn with_coalesce(mut self, coalesce: JoinCoalesce) -> Self {
self.coalesce = coalesce;
self
}

pub fn suffix(&self) -> &str {
self.suffix.as_deref().unwrap_or("_right")
}
Expand All @@ -61,9 +98,7 @@ impl JoinArgs {
pub enum JoinType {
Left,
Inner,
Outer {
coalesce: bool,
},
Outer,
#[cfg(feature = "asof_join")]
AsOf(AsOfOptions),
Cross,
Expand All @@ -73,18 +108,6 @@ pub enum JoinType {
Anti,
}

impl JoinType {
pub fn merges_join_keys(&self) -> bool {
match self {
Self::Outer { coalesce } => *coalesce,
// Merges them if they are equal
#[cfg(feature = "asof_join")]
Self::AsOf(_) => false,
_ => true,
}
}
}

impl From<JoinType> for JoinArgs {
fn from(value: JoinType) -> Self {
JoinArgs::new(value)
Expand Down Expand Up @@ -116,6 +139,19 @@ impl Debug for JoinType {
}
}

impl JoinType {
pub fn is_asof(&self) -> bool {
#[cfg(feature = "asof_join")]
{
matches!(self, JoinType::AsOf(_))
}
#[cfg(not(feature = "asof_join"))]
{
false
}
}
}

#[derive(Copy, Clone, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum JoinValidation {
Expand Down
4 changes: 1 addition & 3 deletions crates/polars-ops/src/frame/join/hash_join/mod.rs
Expand Up @@ -271,9 +271,7 @@ pub trait JoinDispatch: IntoDf {
|| unsafe { other.take_unchecked(&idx_ca_r) },
);

let JoinType::Outer { coalesce } = args.how else {
unreachable!()
};
let coalesce = args.coalesce.coalesce(&JoinType::Outer);
let out = _finish_join(df_left, df_right, args.suffix.as_deref());
if coalesce {
Ok(_coalesce_outer_join(
Expand Down
18 changes: 6 additions & 12 deletions crates/polars-ops/src/frame/join/mod.rs
Expand Up @@ -209,9 +209,7 @@ pub trait DataFrameJoinOps: IntoDf {
JoinType::Left => {
left_df._left_join_from_series(other, s_left, s_right, args, _verbose, None)
},
JoinType::Outer { .. } => {
left_df._outer_join_from_series(other, s_left, s_right, args)
},
JoinType::Outer => left_df._outer_join_from_series(other, s_left, s_right, args),
#[cfg(feature = "semi_anti_join")]
JoinType::Anti => left_df._semi_anti_join_from_series(
s_left,
Expand Down Expand Up @@ -278,13 +276,14 @@ pub trait DataFrameJoinOps: IntoDf {
JoinType::Cross => {
unreachable!()
},
JoinType::Outer { coalesce } => {
JoinType::Outer => {
let names_left = selected_left.iter().map(|s| s.name()).collect::<Vec<_>>();
args.how = JoinType::Outer { coalesce: false };
let coalesce = args.coalesce;
args.coalesce = JoinCoalesce::KeepColumns;
let suffix = args.suffix.clone();
let out = left_df._outer_join_from_series(other, &lhs_keys, &rhs_keys, args);

if coalesce {
if coalesce.coalesce(&JoinType::Outer) {
Ok(_coalesce_outer_join(
out?,
&names_left,
Expand Down Expand Up @@ -411,12 +410,7 @@ pub trait DataFrameJoinOps: IntoDf {
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
self.join(
other,
left_on,
right_on,
JoinArgs::new(JoinType::Outer { coalesce: false }),
)
self.join(other, left_on, right_on, JoinArgs::new(JoinType::Outer))
}
}

Expand Down
16 changes: 9 additions & 7 deletions crates/polars-pipe/src/executors/sinks/joins/generic_build.rs
Expand Up @@ -5,6 +5,7 @@ use hashbrown::hash_map::RawEntryMut;
use polars_core::export::ahash::RandomState;
use polars_core::prelude::*;
use polars_core::utils::{_set_partition_size, accumulate_dataframes_vertical_unchecked};
use polars_ops::prelude::JoinArgs;
use polars_utils::arena::Node;
use polars_utils::slice::GetSaferUnchecked;
use polars_utils::unitvec;
Expand Down Expand Up @@ -34,6 +35,7 @@ pub struct GenericBuild<K: ExtraPayload> {
materialized_join_cols: Vec<BinaryArray<i64>>,
suffix: Arc<str>,
hb: RandomState,
join_args: JoinArgs,
// partitioned tables that will be used for probing
// stores the key and the chunk_idx, df_idx of the left table
hash_tables: PartitionedMap<K>,
Expand All @@ -45,7 +47,6 @@ pub struct GenericBuild<K: ExtraPayload> {
// amortize allocations
join_columns: Vec<ArrayRef>,
hashes: Vec<u64>,
join_type: JoinType,
// the join order is swapped to ensure we hash the smaller table
swapped: bool,
join_nulls: bool,
Expand All @@ -59,7 +60,7 @@ impl<K: ExtraPayload> GenericBuild<K> {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
suffix: Arc<str>,
join_type: JoinType,
join_args: JoinArgs,
swapped: bool,
join_columns_left: Arc<Vec<Arc<dyn PhysicalPipedExpr>>>,
join_columns_right: Arc<Vec<Arc<dyn PhysicalPipedExpr>>>,
Expand All @@ -76,7 +77,7 @@ impl<K: ExtraPayload> GenericBuild<K> {
}));
GenericBuild {
chunks: vec![],
join_type,
join_args,
suffix,
hb,
swapped,
Expand Down Expand Up @@ -278,7 +279,7 @@ impl<K: ExtraPayload> Sink for GenericBuild<K> {
fn split(&self, _thread_no: usize) -> Box<dyn Sink> {
let mut new = Self::new(
self.suffix.clone(),
self.join_type.clone(),
self.join_args.clone(),
self.swapped,
self.join_columns_left.clone(),
self.join_columns_right.clone(),
Expand Down Expand Up @@ -317,7 +318,7 @@ impl<K: ExtraPayload> Sink for GenericBuild<K> {
let mut hashes = std::mem::take(&mut self.hashes);
hashes.clear();

match self.join_type {
match self.join_args.how {
JoinType::Inner | JoinType::Left => {
let probe_operator = GenericJoinProbe::new(
left_df,
Expand All @@ -330,13 +331,14 @@ impl<K: ExtraPayload> Sink for GenericBuild<K> {
self.swapped,
hashes,
context,
self.join_type.clone(),
self.join_args.how.clone(),
self.join_nulls,
);
self.placeholder.replace(Box::new(probe_operator));
Ok(FinalizedSink::Operator)
},
JoinType::Outer { coalesce } => {
JoinType::Outer => {
let coalesce = self.join_args.coalesce.coalesce(&JoinType::Outer);
let probe_operator = GenericOuterJoinProbe::new(
left_df,
materialized_join_cols,
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-pipe/src/pipeline/convert.rs
Expand Up @@ -285,12 +285,12 @@ where
};

match jt {
join_type @ JoinType::Inner | join_type @ JoinType::Left => {
JoinType::Inner | JoinType::Left => {
let (join_columns_left, join_columns_right) = swap_eval();

Box::new(GenericBuild::<()>::new(
Arc::from(options.args.suffix()),
join_type.clone(),
options.args.clone(),
swapped,
join_columns_left,
join_columns_right,
Expand All @@ -317,7 +317,7 @@ where

Box::new(GenericBuild::<Tracker>::new(
Arc::from(options.args.suffix()),
jt.clone(),
options.args.clone(),
swapped,
join_columns_left,
join_columns_right,
Expand Down
Expand Up @@ -258,7 +258,8 @@ pub(super) fn process_join(
already_added_local_to_local_projected.insert(local_name);
}
// In outer joins both columns remain. So `add_local=true` also for the right table
let add_local = matches!(options.args.how, JoinType::Outer { coalesce: false });
let add_local = matches!(options.args.how, JoinType::Outer)
&& !options.args.coalesce.coalesce(&options.args.how);
for e in &right_on {
// In case of outer joins we also add the columns.
// But before we do that we must check if the column wasn't already added by the lhs.
Expand Down
13 changes: 8 additions & 5 deletions crates/polars-plan/src/logical_plan/schema.rs
Expand Up @@ -313,11 +313,11 @@ pub(crate) fn det_join_schema(
new_schema.with_column(field.name, field.dtype);
arena.clear();
}
// except in asof joins. Asof joins are not equi-joins
// Except in asof joins. Asof joins are not equi-joins
// so the columns that are joined on, may have different
// values so if the right has a different name, it is added to the schema
#[cfg(feature = "asof_join")]
if !options.args.how.merges_join_keys() {
if !options.args.coalesce.coalesce(&options.args.how) {
for (left_on, right_on) in left_on.iter().zip(right_on) {
let field_left =
left_on.to_field_amortized(schema_left, Context::Default, &mut arena)?;
Expand All @@ -342,10 +342,13 @@ pub(crate) fn det_join_schema(
join_on_right.insert(field.name);
}

let are_coalesced = options.args.coalesce.coalesce(&options.args.how);
let is_asof = options.args.how.is_asof();

// Asof joins are special, if the names are equal they will not be coalesced.
for (name, dtype) in schema_right.iter() {
if !join_on_right.contains(name.as_str()) // The names that are joined on are merged
|| matches!(&options.args.how, JoinType::Outer{coalesce: false})
// The names are not merged
if !join_on_right.contains(name.as_str()) || (!are_coalesced && !is_asof)
// The names that are joined on are merged
{
if schema_left.contains(name.as_str()) {
#[cfg(feature = "asof_join")]
Expand Down