Skip to content

Commit

Permalink
allow joining on expressions (#4029)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ritchie46 committed Jul 15, 2022
1 parent 9545dd5 commit fd77f1f
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 26 deletions.
50 changes: 26 additions & 24 deletions polars/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1076,32 +1076,34 @@ impl DataFrame {

/// Add a new column to this `DataFrame` or replace an existing one.
///
/// The input is converted with `into_series()`. A `Series` of length 1 is
/// expanded to the frame height when that height is greater than 1. The
/// column is then added through `add_column_by_search` when its length equals
/// the frame height or the frame `is_empty()`. When the height is 0 and the
/// `Series` has length 1, `series.slice(0, 0)` is added instead.
///
/// # Errors
///
/// Returns `PolarsError::ShapeMisMatch` when none of the cases above apply,
/// and propagates any error from `add_column_by_search`.
//
// NOTE(review): this span is a rendered diff without +/- markers. The
// statements that operate on `self` appear to be the pre-change body and the
// nested `inner` function plus the two trailing statements the post-change
// body — confirm against the commit before treating this as compilable code.
pub fn with_column<S: IntoSeries>(&mut self, column: S) -> Result<&mut Self> {
let mut series = column.into_series();

let height = self.height();
// a single value is expanded to the full height of the frame
if series.len() == 1 && height > 1 {
series = series.expand_at_index(0, height);
}
// Non-generic helper that takes a concrete `Series`; the generic outer
// function only converts `column` and delegates to it.
// NOTE(review): presumably done so the body is not duplicated for every `S` — confirm.
fn inner(df: &mut DataFrame, mut series: Series) -> Result<&mut DataFrame> {
let height = df.height();
// a single value is expanded to the full height of the frame
if series.len() == 1 && height > 1 {
series = series.expand_at_index(0, height);
}

if series.len() == height || self.is_empty() {
self.add_column_by_search(series)?;
Ok(self)
}
// special case for literals
else if height == 0 && series.len() == 1 {
let s = series.slice(0, 0);
self.add_column_by_search(s)?;
Ok(self)
} else {
Err(PolarsError::ShapeMisMatch(
format!(
"Could not add column. The Series length {} differs from the DataFrame height: {}",
series.len(),
self.height()
)
.into(),
))
// lengths match, or the frame is empty: add the column as is
if series.len() == height || df.is_empty() {
df.add_column_by_search(series)?;
Ok(df)
}
// special case for literals
else if height == 0 && series.len() == 1 {
// `slice(0, 0)` is added in place of the length-1 series when the height is 0
let s = series.slice(0, 0);
df.add_column_by_search(s)?;
Ok(df)
} else {
// any other length combination is reported as a shape mismatch
Err(PolarsError::ShapeMisMatch(
format!(
"Could not add column. The Series length {} differs from the DataFrame height: {}",
series.len(),
df.height()
)
.into(),
))
}
}
// convert once, then hand over to the helper above
let series = column.into_series();
inner(self, series)
}

fn add_column_by_schema(&mut self, s: Series, schema: &Schema) -> Result<()> {
Expand Down
12 changes: 10 additions & 2 deletions polars/polars-lazy/src/physical_plan/executors/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ impl Executor for JoinExec {
(input_left.execute(state), input_right.execute(state))
};

let df_left = df_left?;
let df_right = df_right?;
let mut df_left = df_left?;
let mut df_right = df_right?;

let left_on_series = self
.left_on
Expand All @@ -85,6 +85,14 @@ impl Executor for JoinExec {
.map(|e| e.evaluate(&df_right, state))
.collect::<Result<Vec<_>>>()?;

// make sure that we can join on evaluated expressions
for s in &left_on_series {
df_left.with_column(s.clone())?;
}
for s in &right_on_series {
df_right.with_column(s.clone())?;
}

// prepare the tolerance
// we must ensure that we use the right units
#[cfg(feature = "asof_join")]
Expand Down
10 changes: 10 additions & 0 deletions py-polars/tests/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,13 @@ def test_deprecated() -> None:
df.lazy().join(other=other.lazy(), on="a").collect().to_numpy(),
result.to_numpy(),
)


def test_join_on_expressions() -> None:
    """Join two frames using expressions, not plain column names, as keys.

    The left key is ``a ** 2`` cast to ``int`` and the right key is ``b``;
    the ``"a"`` column of the joined frame must equal ``[1, 4, 9, 9]``.
    """
    left = pl.DataFrame({"a": [1, 2, 3]})
    right = pl.DataFrame({"b": [1, 4, 9, 9, 0]})

    # join keys are built as expressions up front
    left_key = (pl.col("a") ** 2).cast(int)
    right_key = pl.col("b")

    joined = left.join(right, left_on=left_key, right_on=right_key)

    assert joined["a"].to_list() == [1, 4, 9, 9]

0 comments on commit fd77f1f

Please sign in to comment.