-
-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[polars-sql] Adding SQL Context, SELECT and GROUP BY (#3024)
Initializes a SQL frontend for Polars
- Loading branch information
Showing
5 changed files
with
548 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
[package] | ||
name = "polars-sql" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||
|
||
[dependencies] | ||
serde = "1" | ||
serde_json = { version = "1" } | ||
sqlparser = { version = "0.15.0" } | ||
|
||
[dependencies.polars] | ||
path = "../polars" | ||
default-features = false | ||
features = [ | ||
"dynamic_groupby", | ||
"zip_with", | ||
"simd", | ||
"lazy", | ||
"strings", | ||
"temporal", | ||
"random", | ||
"object", | ||
"csv-file", | ||
"fmt", | ||
"performant", | ||
"dtype-full", | ||
"rows", | ||
"private", | ||
"round_series", | ||
"is_first", | ||
"asof_join", | ||
"cross_join", | ||
"dot_product", | ||
"concat_str", | ||
"row_hash", | ||
"reinterpret", | ||
"decompress-fast", | ||
"mode", | ||
"extract_jsonpath", | ||
"lazy_regex", | ||
"cum_agg", | ||
"rolling_window", | ||
"interpolate", | ||
"list", | ||
"rank", | ||
"diff", | ||
"pct_change", | ||
"moment", | ||
"arange", | ||
"true_div", | ||
"dtype-categorical", | ||
"diagonal_concat", | ||
"horizontal_concat", | ||
"abs", | ||
"ewma", | ||
"dot_diagram", | ||
"dataframe_arithmetic", | ||
"json", | ||
"string_encoding", | ||
"product", | ||
"ndarray", | ||
"series_from_anyvalue", | ||
"avro", | ||
"parquet", | ||
"ipc", | ||
"is_in", | ||
"serde", | ||
] | ||
|
||
[workspace] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
nightly |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
use std::collections::HashMap; | ||
|
||
use crate::sql_expr::parse_sql_expr; | ||
use polars::error::PolarsError; | ||
use polars::prelude::{col, DataFrame, IntoLazy, LazyFrame}; | ||
use sqlparser::ast::{ | ||
Expr as SqlExpr, Select, SelectItem, SetExpr, Statement, TableFactor, Value as SQLValue, | ||
}; | ||
use sqlparser::dialect::GenericDialect; | ||
use sqlparser::parser::Parser; | ||
|
||
#[derive(Default)] | ||
pub struct SQLContext { | ||
table_map: HashMap<String, LazyFrame>, | ||
dialect: GenericDialect, | ||
} | ||
|
||
impl SQLContext { | ||
pub fn new() -> Self { | ||
Self { | ||
table_map: HashMap::new(), | ||
dialect: GenericDialect::default(), | ||
} | ||
} | ||
|
||
pub fn register(&mut self, name: &str, df: &DataFrame) { | ||
self.table_map.insert(name.to_owned(), df.clone().lazy()); | ||
} | ||
|
||
fn execute_select(&self, select_stmt: &Select) -> Result<LazyFrame, PolarsError> { | ||
// Determine involved dataframe | ||
// Implicit join require some more work in query parsers, Explicit join are preferred for now. | ||
let tbl = select_stmt.from.get(0).unwrap(); | ||
let mut alias_map = HashMap::new(); | ||
let tbl_name = match &tbl.relation { | ||
TableFactor::Table { name, alias, .. } => { | ||
let tbl_name = name.0.get(0).unwrap().value.as_str(); | ||
if self.table_map.contains_key(tbl_name) { | ||
if let Some(alias) = alias { | ||
alias_map.insert(alias.name.value.clone(), tbl_name.to_owned()); | ||
}; | ||
tbl_name | ||
} else { | ||
return Err(PolarsError::ComputeError( | ||
format!("Table name {tbl_name} was not found").into(), | ||
)); | ||
} | ||
} | ||
// Support bare table, optional with alias for now | ||
_ => return Err(PolarsError::ComputeError("Not implemented".into())), | ||
}; | ||
let df = &self.table_map[tbl_name]; | ||
let mut raw_projection_before_alias: HashMap<String, usize> = HashMap::new(); | ||
let mut contain_wildcard = false; | ||
// Filter Expression | ||
let df = match select_stmt.selection.as_ref() { | ||
Some(expr) => { | ||
let filter_expression = parse_sql_expr(expr)?; | ||
df.clone().filter(filter_expression) | ||
} | ||
None => df.clone(), | ||
}; | ||
// Column Projections | ||
let projection = select_stmt | ||
.projection | ||
.iter() | ||
.enumerate() | ||
.map(|(i, select_item)| { | ||
Ok(match select_item { | ||
SelectItem::UnnamedExpr(expr) => { | ||
let expr = parse_sql_expr(expr)?; | ||
raw_projection_before_alias.insert(format!("{:?}", expr), i); | ||
expr | ||
} | ||
SelectItem::ExprWithAlias { expr, alias } => { | ||
let expr = parse_sql_expr(expr)?; | ||
raw_projection_before_alias.insert(format!("{:?}", expr), i); | ||
expr.alias(&alias.value) | ||
} | ||
SelectItem::QualifiedWildcard(_) | SelectItem::Wildcard => { | ||
contain_wildcard = true; | ||
col("*") | ||
} | ||
}) | ||
}) | ||
.collect::<Result<Vec<_>, PolarsError>>()?; | ||
// Check for group by | ||
// After projection since there might be number. | ||
let group_by = select_stmt | ||
.group_by | ||
.iter() | ||
.map( | ||
|e|match e { | ||
SqlExpr::Value(SQLValue::Number(idx, _)) => { | ||
let idx = match idx.parse::<usize>() { | ||
Ok(0)| Err(_) => Err( | ||
PolarsError::ComputeError( | ||
format!("Group By Error: Only positive number or expression are supported, got {idx}").into() | ||
)), | ||
Ok(idx) => Ok(idx) | ||
}?; | ||
Ok(projection[idx].clone()) | ||
} | ||
SqlExpr::Value(_) => Err( | ||
PolarsError::ComputeError("Group By Error: Only positive number or expression are supported".into()) | ||
), | ||
_ => parse_sql_expr(e) | ||
} | ||
) | ||
.collect::<Result<Vec<_>, PolarsError>>()?; | ||
|
||
let df = if group_by.is_empty() { | ||
df.select(projection) | ||
} else { | ||
// check groupby and projection due to difference between SQL and polars | ||
// Return error on wild card, shouldn't process this | ||
if contain_wildcard { | ||
return Err(PolarsError::ComputeError( | ||
"Group By Error: Can't processed wildcard in groupby".into(), | ||
)); | ||
} | ||
// Default polars group by will have group by columns at the front | ||
// need some container to contain position of group by columns and its position | ||
// at the final agg projection, check the schema for the existant of group by column | ||
// and its projections columns, keeping the original index | ||
let (exclude_expr, groupby_pos): (Vec<_>, Vec<_>) = group_by | ||
.iter() | ||
.map(|expr| raw_projection_before_alias.get(&format!("{:?}", expr))) | ||
.enumerate() | ||
.filter(|(_, proj_p)| proj_p.is_some()) | ||
.map(|(gb_p, proj_p)| (*proj_p.unwrap(), (*proj_p.unwrap(), gb_p))) | ||
.unzip(); | ||
let (agg_projection, agg_proj_pos): (Vec<_>, Vec<_>) = projection | ||
.iter() | ||
.enumerate() | ||
.filter(|(i, _)| !exclude_expr.contains(i)) | ||
.enumerate() | ||
.map(|(agg_pj, (proj_p, expr))| (expr.clone(), (proj_p, agg_pj + group_by.len()))) | ||
.unzip(); | ||
let agg_df = df.groupby(group_by).agg(agg_projection); | ||
let mut final_proj_pos = groupby_pos | ||
.into_iter() | ||
.chain(agg_proj_pos.into_iter()) | ||
.collect::<Vec<_>>(); | ||
|
||
final_proj_pos.sort_by(|(proj_pa, _), (proj_pb, _)| proj_pa.cmp(proj_pb)); | ||
let final_proj = final_proj_pos | ||
.into_iter() | ||
.map(|(_, shm_p)| col(agg_df.schema().get_index(shm_p).unwrap().0)) | ||
.collect::<Vec<_>>(); | ||
agg_df.select(final_proj) | ||
}; | ||
Ok(df) | ||
} | ||
|
||
pub fn execute(&self, query: &str) -> Result<LazyFrame, PolarsError> { | ||
let ast = Parser::parse_sql(&self.dialect, query) | ||
.map_err(|e| PolarsError::ComputeError(format!("{:?}", e).into()))?; | ||
if ast.len() != 1 { | ||
Err(PolarsError::ComputeError( | ||
"One and only one statement at a time please".into(), | ||
)) | ||
} else { | ||
let ast = ast.get(0).unwrap(); | ||
Ok(match ast { | ||
Statement::Query(query) => { | ||
let rs = match &query.body { | ||
SetExpr::Select(select_stmt) => self.execute_select(&*select_stmt)?, | ||
_ => { | ||
return Err(PolarsError::ComputeError( | ||
"INSERT, UPDATE is not supported for polars".into(), | ||
)) | ||
} | ||
}; | ||
match &query.limit { | ||
Some(SqlExpr::Value(SQLValue::Number(nrow, _))) => { | ||
let nrow = nrow.parse().map_err(|err| { | ||
PolarsError::ComputeError( | ||
format!("Conversion Error: {:?}", err).into(), | ||
) | ||
})?; | ||
rs.limit(nrow) | ||
} | ||
None => rs, | ||
_ => { | ||
return Err(PolarsError::ComputeError( | ||
"Only support number argument to LIMIT clause".into(), | ||
)) | ||
} | ||
} | ||
} | ||
_ => { | ||
return Err(PolarsError::ComputeError( | ||
format!("Statement type {:?} is not supported", ast).into(), | ||
)) | ||
} | ||
}) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
pub use context::SQLContext; | ||
|
||
mod context; | ||
mod sql_expr; | ||
|
||
#[cfg(test)] | ||
mod test { | ||
use super::*; | ||
use polars::prelude::*; | ||
|
||
fn create_sample_df() -> Result<DataFrame> { | ||
let a = Series::new("a", (1..10000i64).map(|i| i / 100).collect::<Vec<_>>()); | ||
let b = Series::new("b", 1..10000i64); | ||
DataFrame::new(vec![a, b]) | ||
} | ||
|
||
#[test] | ||
fn test_simple_select() -> Result<()> { | ||
let df = create_sample_df()?; | ||
let mut context = SQLContext::new(); | ||
context.register("df", &df); | ||
let df_sql = context | ||
.execute( | ||
r#" | ||
SELECT a, b, a + b as c | ||
FROM df | ||
where a > 10 and a < 20 | ||
LIMIT 100 | ||
"#, | ||
)? | ||
.collect()?; | ||
let df_pl = df | ||
.lazy() | ||
.filter(col("a").gt(lit(10)).and(col("a").lt(lit(20)))) | ||
.select(&[col("a"), col("b"), (col("a") + col("b")).alias("c")]) | ||
.limit(100) | ||
.collect()?; | ||
assert_eq!(df_sql, df_pl); | ||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn test_groupby_simple() -> Result<()> { | ||
let df = create_sample_df()?; | ||
let mut context = SQLContext::new(); | ||
context.register("df", &df); | ||
let df_sql = context | ||
.execute( | ||
r#" | ||
SELECT a, sum(b) as b , sum(a + b) as c, count(a) as total_count | ||
FROM df | ||
GROUP BY a | ||
LIMIT 100 | ||
"#, | ||
)? | ||
.sort( | ||
"a", | ||
SortOptions { | ||
descending: false, | ||
nulls_last: false, | ||
}, | ||
) | ||
.collect()?; | ||
let df_pl = df | ||
.lazy() | ||
.groupby(&[col("a")]) | ||
.agg(&[ | ||
col("b").sum().alias("b"), | ||
(col("a") + col("b")).sum().alias("c"), | ||
col("a").count().alias("total_count"), | ||
]) | ||
.limit(100) | ||
.sort( | ||
"a", | ||
SortOptions { | ||
descending: false, | ||
nulls_last: false, | ||
}, | ||
) | ||
.collect()?; | ||
assert_eq!(df_sql, df_pl); | ||
Ok(()) | ||
} | ||
} |
Oops, something went wrong.