[polars-sql] Adding SQL Context, SELECT and GROUP BY (#3024)
Initializes a SQL frontend for Polars
potter420 committed Apr 7, 2022
1 parent 198e88f commit 91d7a0c
Showing 5 changed files with 548 additions and 0 deletions.
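
For orientation, here is a minimal usage sketch of the API this commit introduces (create a SQLContext, register a DataFrame under a table name, execute a query, collect the resulting LazyFrame), mirroring the tests added in polars-sql/src/lib.rs. The sample column values and the main wrapper are illustrative assumptions, not part of the commit:

use polars::prelude::*;
use polars_sql::SQLContext;

fn main() -> Result<()> {
    // Illustrative data; any DataFrame can be registered.
    let a = Series::new("a", &[1i64, 2, 3]);
    let b = Series::new("b", &[10i64, 20, 30]);
    let df = DataFrame::new(vec![a, b])?;

    let mut ctx = SQLContext::new();
    // Register the frame under the table name referenced in the query.
    ctx.register("df", &df);

    // `execute` returns a LazyFrame; collect it to materialize the result.
    let out = ctx
        .execute("SELECT a, b, a + b AS c FROM df WHERE a > 1 LIMIT 10")?
        .collect()?;
    println!("{}", out);
    Ok(())
}
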
72 changes: 72 additions & 0 deletions polars-sql/Cargo.toml
@@ -0,0 +1,72 @@
[package]
name = "polars-sql"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
serde = "1"
serde_json = { version = "1" }
sqlparser = { version = "0.15.0" }

[dependencies.polars]
path = "../polars"
default-features = false
features = [
"dynamic_groupby",
"zip_with",
"simd",
"lazy",
"strings",
"temporal",
"random",
"object",
"csv-file",
"fmt",
"performant",
"dtype-full",
"rows",
"private",
"round_series",
"is_first",
"asof_join",
"cross_join",
"dot_product",
"concat_str",
"row_hash",
"reinterpret",
"decompress-fast",
"mode",
"extract_jsonpath",
"lazy_regex",
"cum_agg",
"rolling_window",
"interpolate",
"list",
"rank",
"diff",
"pct_change",
"moment",
"arange",
"true_div",
"dtype-categorical",
"diagonal_concat",
"horizontal_concat",
"abs",
"ewma",
"dot_diagram",
"dataframe_arithmetic",
"json",
"string_encoding",
"product",
"ndarray",
"series_from_anyvalue",
"avro",
"parquet",
"ipc",
"is_in",
"serde",
]

[workspace]
1 change: 1 addition & 0 deletions polars-sql/rust-toolchain
@@ -0,0 +1 @@
nightly
200 changes: 200 additions & 0 deletions polars-sql/src/context.rs
@@ -0,0 +1,200 @@
use std::collections::HashMap;

use crate::sql_expr::parse_sql_expr;
use polars::error::PolarsError;
use polars::prelude::{col, DataFrame, IntoLazy, LazyFrame};
use sqlparser::ast::{
Expr as SqlExpr, Select, SelectItem, SetExpr, Statement, TableFactor, Value as SQLValue,
};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;

#[derive(Default)]
pub struct SQLContext {
table_map: HashMap<String, LazyFrame>,
dialect: GenericDialect,
}

impl SQLContext {
pub fn new() -> Self {
Self {
table_map: HashMap::new(),
dialect: GenericDialect::default(),
}
}

pub fn register(&mut self, name: &str, df: &DataFrame) {
self.table_map.insert(name.to_owned(), df.clone().lazy());
}

fn execute_select(&self, select_stmt: &Select) -> Result<LazyFrame, PolarsError> {
// Determine the involved dataframe.
// Implicit joins require more work in the query parser; explicit joins are preferred for now.
let tbl = select_stmt.from.get(0).unwrap();
let mut alias_map = HashMap::new();
let tbl_name = match &tbl.relation {
TableFactor::Table { name, alias, .. } => {
let tbl_name = name.0.get(0).unwrap().value.as_str();
if self.table_map.contains_key(tbl_name) {
if let Some(alias) = alias {
alias_map.insert(alias.name.value.clone(), tbl_name.to_owned());
};
tbl_name
} else {
return Err(PolarsError::ComputeError(
format!("Table name {tbl_name} was not found").into(),
));
}
}
// Only a bare table, optionally with an alias, is supported for now.
_ => return Err(PolarsError::ComputeError("Not implemented".into())),
};
let df = &self.table_map[tbl_name];
let mut raw_projection_before_alias: HashMap<String, usize> = HashMap::new();
let mut contain_wildcard = false;
// Filter Expression
let df = match select_stmt.selection.as_ref() {
Some(expr) => {
let filter_expression = parse_sql_expr(expr)?;
df.clone().filter(filter_expression)
}
None => df.clone(),
};
// Column Projections
let projection = select_stmt
.projection
.iter()
.enumerate()
.map(|(i, select_item)| {
Ok(match select_item {
SelectItem::UnnamedExpr(expr) => {
let expr = parse_sql_expr(expr)?;
raw_projection_before_alias.insert(format!("{:?}", expr), i);
expr
}
SelectItem::ExprWithAlias { expr, alias } => {
let expr = parse_sql_expr(expr)?;
raw_projection_before_alias.insert(format!("{:?}", expr), i);
expr.alias(&alias.value)
}
SelectItem::QualifiedWildcard(_) | SelectItem::Wildcard => {
contain_wildcard = true;
col("*")
}
})
})
.collect::<Result<Vec<_>, PolarsError>>()?;
// Resolve the GROUP BY clause.
// This happens after the projection because GROUP BY items may be column positions.
let group_by = select_stmt
.group_by
.iter()
.map(|e| match e {
// A bare number refers to a 1-based position in the SELECT list.
SqlExpr::Value(SQLValue::Number(idx, _)) => {
let idx = match idx.parse::<usize>() {
Ok(0) | Err(_) => Err(PolarsError::ComputeError(
format!("Group By Error: Only positive numbers or expressions are supported, got {idx}").into(),
)),
Ok(idx) => Ok(idx),
}?;
// SQL positions are 1-based; convert to a 0-based index into the projection.
Ok(projection[idx - 1].clone())
}
SqlExpr::Value(_) => Err(PolarsError::ComputeError(
"Group By Error: Only positive numbers or expressions are supported".into(),
)),
_ => parse_sql_expr(e),
})
.collect::<Result<Vec<_>, PolarsError>>()?;

let df = if group_by.is_empty() {
df.select(projection)
} else {
// Reconcile the GROUP BY clause with the projection, since SQL and polars
// order the output columns differently.
// A wildcard cannot be resolved here, so return an error.
if contain_wildcard {
return Err(PolarsError::ComputeError(
"Group By Error: Can't process a wildcard in GROUP BY".into(),
));
}
// Polars' groupby puts the group-by columns at the front by default, so record
// where each group-by column and each aggregated projection should end up in
// the final projection, keeping the original SELECT order.
let (exclude_expr, groupby_pos): (Vec<_>, Vec<_>) = group_by
.iter()
.map(|expr| raw_projection_before_alias.get(&format!("{:?}", expr)))
.enumerate()
.filter(|(_, proj_p)| proj_p.is_some())
.map(|(gb_p, proj_p)| (*proj_p.unwrap(), (*proj_p.unwrap(), gb_p)))
.unzip();
let (agg_projection, agg_proj_pos): (Vec<_>, Vec<_>) = projection
.iter()
.enumerate()
.filter(|(i, _)| !exclude_expr.contains(i))
.enumerate()
.map(|(agg_pj, (proj_p, expr))| (expr.clone(), (proj_p, agg_pj + group_by.len())))
.unzip();
let agg_df = df.groupby(group_by).agg(agg_projection);
let mut final_proj_pos = groupby_pos
.into_iter()
.chain(agg_proj_pos.into_iter())
.collect::<Vec<_>>();

final_proj_pos.sort_by(|(proj_pa, _), (proj_pb, _)| proj_pa.cmp(proj_pb));
let final_proj = final_proj_pos
.into_iter()
.map(|(_, shm_p)| col(agg_df.schema().get_index(shm_p).unwrap().0))
.collect::<Vec<_>>();
agg_df.select(final_proj)
};
Ok(df)
}

pub fn execute(&self, query: &str) -> Result<LazyFrame, PolarsError> {
let ast = Parser::parse_sql(&self.dialect, query)
.map_err(|e| PolarsError::ComputeError(format!("{:?}", e).into()))?;
if ast.len() != 1 {
Err(PolarsError::ComputeError(
"One and only one statement at a time please".into(),
))
} else {
let ast = ast.get(0).unwrap();
Ok(match ast {
Statement::Query(query) => {
let rs = match &query.body {
SetExpr::Select(select_stmt) => self.execute_select(&*select_stmt)?,
_ => {
return Err(PolarsError::ComputeError(
"INSERT, UPDATE is not supported for polars".into(),
))
}
};
match &query.limit {
Some(SqlExpr::Value(SQLValue::Number(nrow, _))) => {
let nrow = nrow.parse().map_err(|err| {
PolarsError::ComputeError(
format!("Conversion Error: {:?}", err).into(),
)
})?;
rs.limit(nrow)
}
None => rs,
_ => {
return Err(PolarsError::ComputeError(
"Only support number argument to LIMIT clause".into(),
))
}
}
}
_ => {
return Err(PolarsError::ComputeError(
format!("Statement type {:?} is not supported", ast).into(),
))
}
})
}
}
}
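
As the execute_select implementation above shows, a numeric GROUP BY item is interpreted as a 1-based position into the SELECT list (zero and other literal values are rejected). A small sketch of a query relying on this, assuming a context set up as in the earlier example; the helper name and the b_sum alias are illustrative:

use polars::prelude::*;
use polars_sql::SQLContext;

// "GROUP BY 1" groups by the first projected column, `a`, so this query is
// equivalent to "GROUP BY a" for the frame registered as "df" above.
fn positional_group_by(ctx: &SQLContext) -> Result<DataFrame> {
    ctx.execute("SELECT a, sum(b) AS b_sum FROM df GROUP BY 1")?
        .collect()
}
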
84 changes: 84 additions & 0 deletions polars-sql/src/lib.rs
@@ -0,0 +1,84 @@
pub use context::SQLContext;

mod context;
mod sql_expr;

#[cfg(test)]
mod test {
use super::*;
use polars::prelude::*;

fn create_sample_df() -> Result<DataFrame> {
let a = Series::new("a", (1..10000i64).map(|i| i / 100).collect::<Vec<_>>());
let b = Series::new("b", 1..10000i64);
DataFrame::new(vec![a, b])
}

#[test]
fn test_simple_select() -> Result<()> {
let df = create_sample_df()?;
let mut context = SQLContext::new();
context.register("df", &df);
let df_sql = context
.execute(
r#"
SELECT a, b, a + b as c
FROM df
where a > 10 and a < 20
LIMIT 100
"#,
)?
.collect()?;
let df_pl = df
.lazy()
.filter(col("a").gt(lit(10)).and(col("a").lt(lit(20))))
.select(&[col("a"), col("b"), (col("a") + col("b")).alias("c")])
.limit(100)
.collect()?;
assert_eq!(df_sql, df_pl);
Ok(())
}

#[test]
fn test_groupby_simple() -> Result<()> {
let df = create_sample_df()?;
let mut context = SQLContext::new();
context.register("df", &df);
let df_sql = context
.execute(
r#"
SELECT a, sum(b) as b, sum(a + b) as c, count(a) as total_count
FROM df
GROUP BY a
LIMIT 100
"#,
)?
.sort(
"a",
SortOptions {
descending: false,
nulls_last: false,
},
)
.collect()?;
let df_pl = df
.lazy()
.groupby(&[col("a")])
.agg(&[
col("b").sum().alias("b"),
(col("a") + col("b")).sum().alias("c"),
col("a").count().alias("total_count"),
])
.limit(100)
.sort(
"a",
SortOptions {
descending: false,
nulls_last: false,
},
)
.collect()?;
assert_eq!(df_sql, df_pl);
Ok(())
}
}
