Skip to content

Commit

Permalink
refactor[rust]: fix polars-sql compilation (#4947)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 23, 2022
1 parent 2d527ed commit 4421c6e
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 16 deletions.
23 changes: 11 additions & 12 deletions polars-sql/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
use std::collections::HashMap;

use crate::sql_expr::parse_sql_expr;
use polars::error::PolarsError;
use polars::prelude::{col, DataFrame, IntoLazy, LazyFrame};
use polars::error::PolarsResult;
use polars::prelude::*;
use sqlparser::ast::{
Expr as SqlExpr, Select, SelectItem, SetExpr, Statement, TableFactor, Value as SQLValue,
};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;

use crate::sql_expr::parse_sql_expr;

#[derive(Default)]
pub struct SQLContext {
table_map: HashMap<String, LazyFrame>,
table_map: PlHashMap<String, LazyFrame>,
dialect: GenericDialect,
}

impl SQLContext {
pub fn new() -> Self {
Self {
table_map: HashMap::new(),
table_map: PlHashMap::new(),
dialect: GenericDialect::default(),
}
}
Expand All @@ -27,11 +26,11 @@ impl SQLContext {
self.table_map.insert(name.to_owned(), df.clone().lazy());
}

fn execute_select(&self, select_stmt: &Select) -> PolarsResult<LazyFrame, PolarsError> {
fn execute_select(&self, select_stmt: &Select) -> PolarsResult<LazyFrame> {
// Determine involved dataframe
// Implicit join require some more work in query parsers, Explicit join are preferred for now.
let tbl = select_stmt.from.get(0).unwrap();
let mut alias_map = HashMap::new();
let mut alias_map = PlHashMap::new();
let tbl_name = match &tbl.relation {
TableFactor::Table { name, alias, .. } => {
let tbl_name = name.0.get(0).unwrap().value.as_str();
Expand All @@ -50,7 +49,7 @@ impl SQLContext {
_ => return Err(PolarsError::ComputeError("Not implemented".into())),
};
let df = &self.table_map[tbl_name];
let mut raw_projection_before_alias: HashMap<String, usize> = HashMap::new();
let mut raw_projection_before_alias: PlHashMap<String, usize> = PlHashMap::new();
let mut contain_wildcard = false;
// Filter Expression
let df = match select_stmt.selection.as_ref() {
Expand Down Expand Up @@ -146,14 +145,14 @@ impl SQLContext {
final_proj_pos.sort_by(|(proj_pa, _), (proj_pb, _)| proj_pa.cmp(proj_pb));
let final_proj = final_proj_pos
.into_iter()
.map(|(_, shm_p)| col(agg_df.schema().get_index(shm_p).unwrap().0))
.map(|(_, shm_p)| col(agg_df.schema().unwrap().get_index(shm_p).unwrap().0))
.collect::<Vec<_>>();
agg_df.select(final_proj)
};
Ok(df)
}

pub fn execute(&self, query: &str) -> PolarsResult<LazyFrame, PolarsError> {
pub fn execute(&self, query: &str) -> PolarsResult<LazyFrame> {
let ast = Parser::parse_sql(&self.dialect, query)
.map_err(|e| PolarsError::ComputeError(format!("{:?}", e).into()))?;
if ast.len() != 1 {
Expand Down
3 changes: 2 additions & 1 deletion polars-sql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ mod sql_expr;

#[cfg(test)]
mod test {
use super::*;
use polars::prelude::*;

use super::*;

fn create_sample_df() -> PolarsResult<DataFrame> {
let a = Series::new("a", (1..10000i64).map(|i| i / 100).collect::<Vec<_>>());
let b = Series::new("b", 1..10000i64);
Expand Down
7 changes: 4 additions & 3 deletions polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use polars::error::PolarsError;
use polars::prelude::{col, lit, DataType, Expr, LiteralValue, Result, TimeUnit};

use polars::prelude::*;
use sqlparser::ast::{
BinaryOperator as SQLBinaryOperator, DataType as SQLDataType, Expr as SqlExpr,
Function as SQLFunction, Value as SqlValue, WindowSpec,
Expand Down Expand Up @@ -120,6 +119,8 @@ pub(crate) fn parse_sql_expr(expr: &SqlExpr) -> PolarsResult<Expr> {
SqlExpr::Cast { expr, data_type } => cast_(parse_sql_expr(expr)?, data_type)?,
SqlExpr::Nested(expr) => parse_sql_expr(expr)?,
SqlExpr::Value(value) => literal_expr(value)?,
SqlExpr::IsNull(expr) => parse_sql_expr(expr)?.is_null(),
SqlExpr::IsNotNull(expr) => parse_sql_expr(expr)?.is_not_null(),
_ => {
return Err(PolarsError::ComputeError(
format!(
Expand All @@ -140,7 +141,7 @@ fn apply_window_spec(expr: Expr, window_spec: &Option<WindowSpec>) -> PolarsResu
.partition_by
.iter()
.map(parse_sql_expr)
.collect::<Result<Vec<_>>>()?;
.collect::<PolarsResult<Vec<_>>>()?;
expr.over(partition_by)
// Order by and Row range may not be supported at the moment
}
Expand Down

0 comments on commit 4421c6e

Please sign in to comment.