Skip to content

Commit

Permalink
feat[sql]: compounded names in join operation and is between expressi…
Browse files Browse the repository at this point in the history
…ons (#4971)
  • Loading branch information
ritchie46 committed Sep 25, 2022
1 parent 5768efa commit 548984b
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 32 deletions.
33 changes: 9 additions & 24 deletions polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ use polars::error::PolarsResult;
use polars::prelude::*;
use polars_lazy::utils::expressions_to_schema;
use sqlparser::ast::{
Expr as SqlExpr, JoinConstraint, JoinOperator, OrderByExpr, Select, SelectItem, SetExpr,
Statement, TableFactor, TableWithJoins, Value as SQLValue,
Expr as SqlExpr, JoinOperator, OrderByExpr, Select, SelectItem, SetExpr, Statement,
TableFactor, TableWithJoins, Value as SQLValue,
};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;

use crate::sql_expr::parse_sql_expr;
use crate::sql_expr::{parse_sql_expr, process_join_constraint};

#[derive(Default, Clone)]
pub struct SQLContext {
Expand Down Expand Up @@ -78,15 +78,18 @@ impl SQLContext {
let join_tbl = self.get_table(join_tbl_name)?;
match &tbl.join_operator {
JoinOperator::Inner(constraint) => {
let (left_on, right_on) = process_join_constraint(&constraint)?;
let (left_on, right_on) =
process_join_constraint(&constraint, tbl_name, join_tbl_name)?;
lf = lf.inner_join(join_tbl, left_on, right_on)
}
JoinOperator::LeftOuter(constraint) => {
let (left_on, right_on) = process_join_constraint(&constraint)?;
let (left_on, right_on) =
process_join_constraint(&constraint, tbl_name, join_tbl_name)?;
lf = lf.left_join(join_tbl, left_on, right_on)
}
JoinOperator::FullOuter(constraint) => {
let (left_on, right_on) = process_join_constraint(&constraint)?;
let (left_on, right_on) =
process_join_constraint(&constraint, tbl_name, join_tbl_name)?;
lf = lf.outer_join(join_tbl, left_on, right_on)
}
JoinOperator::CrossJoin => lf = lf.cross_join(join_tbl),
Expand Down Expand Up @@ -282,21 +285,3 @@ impl SQLContext {
Ok(aggregated.select(final_projection))
}
}

/// Split an `ON <lhs> = <rhs>` join constraint into its two polars key expressions.
///
/// The constraint expression is converted with `parse_sql_expr`; when the result is a
/// binary expression using the equality operator, its left and right operands are returned
/// as `(left_on, right_on)`.
///
/// # Errors
///
/// Propagates any error from `parse_sql_expr`, and returns `PolarsError::ComputeError` for
/// every constraint that is not an `ON` equality.
fn process_join_constraint(constraint: &JoinConstraint) -> PolarsResult<(Expr, Expr)> {
    match constraint {
        JoinConstraint::On(sql_expr) => {
            let parsed = parse_sql_expr(sql_expr)?;
            if let Expr::BinaryExpr {
                left: lhs,
                right: rhs,
                op: Operator::Eq,
            } = parsed
            {
                // The parsed expression is owned here, so the operands can be moved out.
                return Ok((*lhs, *rhs));
            }
        }
        _ => {}
    }

    let message = format!(
        "Join constraint {:?} not yet supported in polars-sql",
        constraint
    );
    Err(PolarsError::ComputeError(message.into()))
}
85 changes: 83 additions & 2 deletions polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use polars::error::PolarsError;
use polars::prelude::*;
use sqlparser::ast::{
BinaryOperator as SQLBinaryOperator, DataType as SQLDataType, Expr as SqlExpr,
Function as SQLFunction, Value as SqlValue, WindowSpec,
BinaryOperator as SQLBinaryOperator, BinaryOperator, DataType as SQLDataType, Expr as SqlExpr,
Function as SQLFunction, JoinConstraint, Value as SqlValue, WindowSpec,
};

fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult<DataType> {
Expand Down Expand Up @@ -121,6 +121,22 @@ pub(crate) fn parse_sql_expr(expr: &SqlExpr) -> PolarsResult<Expr> {
SqlExpr::Value(value) => literal_expr(value)?,
SqlExpr::IsNull(expr) => parse_sql_expr(expr)?.is_null(),
SqlExpr::IsNotNull(expr) => parse_sql_expr(expr)?.is_not_null(),
SqlExpr::Between {
expr,
negated,
low,
high,
} => {
let expr = parse_sql_expr(expr)?;
let low = parse_sql_expr(low)?;
let high = parse_sql_expr(high)?;

if *negated {
expr.clone().lt(low).or(expr.gt(high))
} else {
expr.clone().gt(low).and(expr.lt(high))
}
}
_ => {
return Err(PolarsError::ComputeError(
format!(
Expand Down Expand Up @@ -201,3 +217,68 @@ fn parse_sql_function(sql_function: &SQLFunction) -> PolarsResult<Expr> {
))
}
}

/// Translate a SQL join constraint into the `(left_on, right_on)` pair of polars column
/// expressions.
///
/// Only `ON <lhs> = <rhs>` equality constraints are handled, in two shapes:
///
/// * `tbl_a.col_a = tbl_b.col_b` — both sides are two-part compound identifiers whose table
///   parts are `left_name` and `right_name`, in either order. The returned pair is always
///   ordered so that the first column belongs to `left_name` and the second to `right_name`.
/// * `col_a = col_b` — both sides are bare identifiers; they are returned in the order
///   written.
///
/// # Errors
///
/// Returns `PolarsError::ComputeError` for every other constraint: anything that is not an
/// `ON` clause, a non-equality operator, compound identifiers that do not have exactly two
/// parts, or table qualifiers that do not match `left_name`/`right_name`.
pub(super) fn process_join_constraint(
    constraint: &JoinConstraint,
    left_name: &str,
    right_name: &str,
) -> PolarsResult<(Expr, Expr)> {
    // The operator is checked once, up front, so that every identifier shape handled below
    // is guaranteed to describe an equi-join. A constraint such as `ON a < b` must not be
    // silently turned into a join on `a = b`.
    if let JoinConstraint::On(SqlExpr::BinaryOp {
        left,
        op: BinaryOperator::Eq,
        right,
    }) = constraint
    {
        match (left.as_ref(), right.as_ref()) {
            (SqlExpr::CompoundIdentifier(left), SqlExpr::CompoundIdentifier(right))
                if left.len() == 2 && right.len() == 2 =>
            {
                let tbl_a = &left[0].value;
                let col_a = &left[1].value;

                let tbl_b = &right[0].value;
                let col_b = &right[1].value;

                if left_name == tbl_a && right_name == tbl_b {
                    return Ok((col(col_a), col(col_b)));
                } else if left_name == tbl_b && right_name == tbl_a {
                    // The constraint was written as `right.x = left.y`; swap the columns so
                    // the first key still refers to the left table.
                    return Ok((col(col_b), col(col_a)));
                }
            }
            (SqlExpr::Identifier(left), SqlExpr::Identifier(right)) => {
                return Ok((col(&left.value), col(&right.value)));
            }
            _ => {}
        }
    }
    Err(PolarsError::ComputeError(
        format!(
            "Join constraint {:?} not yet supported in polars-sql",
            constraint
        )
        .into(),
    ))
}
10 changes: 5 additions & 5 deletions py-polars/docs/source/reference/sql.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

============
Utils
============
===
SQL
===
.. currentmodule:: polars

.. autosummary::
Expand All @@ -13,8 +13,8 @@ Utils

Run SQL query against a LazyFrame.

Attributes
----------
Methods
-------

.. autosummary::
:toctree: api/
Expand Down
39 changes: 38 additions & 1 deletion py-polars/tests/unit/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_sql_join(foods_ipc: str) -> None:
out = c.query(
"""
SELECT * FROM
foods1 INNER JOIN foods2 ON category = category
foods1 INNER JOIN foods2 ON foods1.category = foods2.category
LIMIT 2
"""
)
Expand All @@ -52,3 +52,40 @@ def test_sql_join(foods_ipc: str) -> None:
"fats_g_right": [0.5, 0.5],
"sugars_g_right": [2, 2],
}


# Checks the rows returned by `BETWEEN` and `NOT BETWEEN` filters in SQLContext queries.
def test_sql_is_between(foods_ipc: str) -> None:
c = pl.SQLContext()

# Lazily scan the IPC file at `foods_ipc` and register it as table "foods1".
lf = pl.scan_ipc(foods_ipc)
c.register("foods1", lf)

# First query: keep rows matching `calories BETWEEN 20 AND 31`, limited to 4 rows.
out = c.query(
"""
SELECT * FROM foods1
WHERE calories BETWEEN 20 AND 31
LIMIT 4
"""
)

# Expected result of the first query, column by column.
assert out.to_dict(False) == {
"category": ["fruit", "vegetables", "fruit", "vegetables"],
"calories": [30, 25, 30, 22],
"fats_g": [0.0, 0.0, 0.0, 0.0],
"sugars_g": [5, 2, 3, 3],
}

# Second query: the negated filter (`NOT BETWEEN`) with the same bounds and limit.
out = c.query(
"""
SELECT * FROM foods1
WHERE calories NOT BETWEEN 20 AND 31
LIMIT 4
"""
)

# Expected result of the negated query, column by column.
assert out.to_dict(False) == {
"category": ["vegetables", "seafood", "meat", "fruit"],
"calories": [45, 150, 100, 60],
"fats_g": [0.5, 5.0, 5.0, 0.0],
"sugars_g": [2, 0, 0, 11],
}

0 comments on commit 548984b

Please sign in to comment.