Skip to content

Commit

Permalink
feat: support TPC-H Q9 (#761)
Browse files Browse the repository at this point in the history
* add function `extract`

Signed-off-by: Runji Wang <wangrunji0408@163.com>

* fix and evaluate LIKE

Signed-off-by: Runji Wang <wangrunji0408@163.com>

* evaluate EXTRACT YEAR

Signed-off-by: Runji Wang <wangrunji0408@163.com>

* add tpch q9

Signed-off-by: Runji Wang <wangrunji0408@163.com>

* fix clippy

Signed-off-by: Runji Wang <wangrunji0408@163.com>

* fix decimal scale

Signed-off-by: Runji Wang <wangrunji0408@163.com>

* special tpch test for v1

Signed-off-by: Runji Wang <wangrunji0408@163.com>

Signed-off-by: Runji Wang <wangrunji0408@163.com>
  • Loading branch information
wangrunji0408 committed Jan 10, 2023
1 parent 8418cf3 commit 857159e
Show file tree
Hide file tree
Showing 27 changed files with 571 additions and 59 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,15 @@ paste = "1"
pgwire = "0.6.3"
prost = "0.11.0"
ref-cast = "1.0"
regex = "1"
risinglight_proto = "0.2"
rust_decimal = "1"
rustyline = "10"
serde = { version = "1", features = ["derive", "rc"] }
serde_json = "1"
smallvec = { version = "1", features = ["serde"] }
sqllogictest = "0.9"
sqlparser = { version = "0.27", features = ["serde"] }
sqlparser = { version = "0.30", features = ["serde"] }
thiserror = "1"
tikv-jemallocator = { version = "0.5", optional = true, features=["disable_initial_exec_tls"] }
tokio = { version = "1", features = ["full"] }
Expand Down
5 changes: 5 additions & 0 deletions src/array/data_chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ impl DataChunk {
&self.arrays[idx]
}

/// Returns a mutable reference to the array at position `idx`.
///
/// # Panics
///
/// Panics if the shared `Arc` holding the arrays has other references
/// (`Arc::get_mut` returns `None` in that case), or if `idx` is out of
/// bounds.
pub fn array_mut_at(&mut self, idx: usize) -> &mut ArrayImpl {
    let arrays = Arc::get_mut(&mut self.arrays).unwrap();
    &mut arrays[idx]
}

/// Get all arrays.
pub fn arrays(&self) -> &[ArrayImpl] {
&self.arrays
Expand Down
74 changes: 53 additions & 21 deletions src/array/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
use std::borrow::Borrow;

use num_traits::ToPrimitive;
use regex::Regex;
use rust_decimal::prelude::FromStr;
use rust_decimal::Decimal;

use super::*;
use crate::for_all_variants;
use crate::parser::{BinaryOperator, UnaryOperator};
use crate::types::{Blob, ConvertError, DataTypeKind, DataValue, Date, Interval, F64};
use crate::types::{
Blob, ConvertError, DataTypeKind, DataValue, Date, DateTimeField, Interval, F64,
};

type A = ArrayImpl;

Expand Down Expand Up @@ -155,6 +158,46 @@ impl ArrayImpl {
Ok(A::new_bool(clear_null(unary_op(a.as_ref(), |b| !b))))
}

/// Evaluates SQL `LIKE` against every string element, returning a bool array.
///
/// In the pattern, `%` matches any sequence of characters (including none)
/// and `_` matches exactly one character; every other character must match
/// literally.
///
/// # Errors
///
/// Returns `ConvertError::NoUnaryOp` if `self` is not a string array.
pub fn like(&self, pattern: &str) -> Result<Self, ConvertError> {
    /// Converts a SQL LIKE pattern to a regex pattern.
    ///
    /// Regex metacharacters in the pattern are escaped so they match
    /// literally (e.g. `LIKE 'a.c'` must not match `"abc"`), and the result
    /// is anchored with `^`/`$` because SQL LIKE matches the entire string,
    /// while `Regex::is_match` performs a substring search.
    fn like_to_regex(pattern: &str) -> String {
        let mut regex = String::with_capacity(pattern.len() + 2);
        regex.push('^');
        for c in pattern.chars() {
            match c {
                '%' => regex.push_str(".*"),
                '_' => regex.push('.'),
                // escape regex metacharacters so they are matched literally
                c if "\\^$.|?*+()[]{}".contains(c) => {
                    regex.push('\\');
                    regex.push(c);
                }
                c => regex.push(c),
            }
        }
        regex.push('$');
        regex
    }
    let A::Utf8(a) = self else {
        return Err(ConvertError::NoUnaryOp("like".into(), self.type_string()));
    };
    // `unwrap` is safe: all metacharacters are escaped above, so the
    // generated pattern is always a valid regex.
    let regex = Regex::new(&like_to_regex(pattern)).unwrap();
    Ok(A::new_bool(clear_null(unary_op(a.as_ref(), |s| {
        regex.is_match(s)
    }))))
}

/// Extracts the given datetime `field` (e.g. `YEAR`) from every element.
///
/// Currently only `EXTRACT(YEAR FROM date)` is implemented; other fields
/// and interval sources hit `todo!`.
///
/// # Errors
///
/// Returns `ConvertError::NoUnaryOp` when `self` is neither a date nor an
/// interval array.
pub fn extract(&self, field: DateTimeField) -> Result<Self, ConvertError> {
    match self {
        A::Date(dates) => match field.0 {
            sqlparser::ast::DateTimeField::Year => {
                Ok(A::new_int32(unary_op(dates.as_ref(), |date| date.year())))
            }
            f => todo!("extract {f} from date"),
        },
        A::Interval(_) => todo!("extract {field} from interval"),
        _ => Err(ConvertError::NoUnaryOp(
            "extract".into(),
            self.type_string(),
        )),
    }
}

/// Perform binary operation.
pub fn binary_op(
&self,
Expand Down Expand Up @@ -244,30 +287,19 @@ impl ArrayImpl {
},
Self::Float64(a) => match data_type {
Type::Bool => Self::new_bool(unary_op(a.as_ref(), |&f| f != 0.0)),
Type::Int32 => Self::new_int32(try_unary_op(a.as_ref(), |&b| match b.to_i32() {
Some(d) => Ok(d),
None => Err(ConvertError::Overflow(DataValue::Float64(b), Type::Int32)),
Type::Int32 => Self::new_int32(try_unary_op(a.as_ref(), |&b| {
b.to_i32()
.ok_or(ConvertError::Overflow(DataValue::Float64(b), Type::Int32))
})?),
Type::Int64 => Self::new_int64(try_unary_op(a.as_ref(), |&b| match b.to_i64() {
Some(d) => Ok(d),
None => Err(ConvertError::Overflow(DataValue::Float64(b), Type::Int64)),
Type::Int64 => Self::new_int64(try_unary_op(a.as_ref(), |&b| {
b.to_i64()
.ok_or(ConvertError::Overflow(DataValue::Float64(b), Type::Int64))
})?),
Type::Float64 => Self::Float64(a.clone()),
Type::String => Self::new_utf8(Utf8Array::from_iter_display(a.iter())),
Type::Decimal(_, scale) => {
Self::new_decimal(try_unary_op(
a.as_ref(),
|&f| match Decimal::from_f64_retain(f.0) {
Some(mut d) => {
if let Some(s) = scale {
d.rescale(*s as u32);
}
Ok(d)
}
None => Err(ConvertError::ToDecimalError(DataValue::Float64(f))),
},
)?)
}
Type::Decimal(_, _) => Self::new_decimal(unary_op(a.as_ref(), |&f| {
Decimal::from_f64_retain(f.0).unwrap()
})),
Type::Null | Type::Date | Type::Interval | Type::Blob | Type::Struct(_) => {
return Err(ConvertError::NoCast("DOUBLE", data_type.clone()));
}
Expand Down
10 changes: 10 additions & 0 deletions src/array/primitive_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::iter::FromIterator;
use std::mem;

use bitvec::vec::BitVec;
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};

use super::ops::BitVecExt;
Expand Down Expand Up @@ -188,6 +189,15 @@ impl PrimitiveArray<bool> {
}
}

impl PrimitiveArray<Decimal> {
    /// Rescales every decimal in the array to `scale` fractional digits.
    pub fn rescale(&mut self, scale: u8) {
        let scale = u32::from(scale);
        self.data.iter_mut().for_each(|d| d.rescale(scale));
    }
}

pub fn clear_null(mut array: BoolArray) -> BoolArray {
use std::simd::ToBitMask;
let mut valid = Vec::with_capacity(array.valid.as_raw_slice().len() * 64);
Expand Down
7 changes: 7 additions & 0 deletions src/binder_v2/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ impl Binder {
leading_field,
..
} => self.bind_interval(*value, leading_field),
Expr::Extract { field, expr } => self.bind_extract(field, *expr),
_ => todo!("bind expression: {:?}", expr),
}?;
self.check_type(id)?;
Expand Down Expand Up @@ -182,6 +183,12 @@ impl Binder {
Ok(self.egraph.add(Node::Constant(value)))
}

/// Binds an `EXTRACT(field FROM expr)` expression.
///
/// Inserts a `Node::Field` for the datetime field and a `Node::Extract`
/// combining it with the bound child expression.
fn bind_extract(&mut self, field: DateTimeField, expr: Expr) -> Result {
    // Bind the child first so that a binding error returns before any
    // nodes are added to the e-graph by this call.
    let child = self.bind_expr(expr)?;
    let field_id = self.egraph.add(Node::Field(field.into()));
    let extract_id = self.egraph.add(Node::Extract([field_id, child]));
    Ok(extract_id)
}

fn bind_function(&mut self, func: Function) -> Result {
// TODO: Support scalar function
let mut args = vec![];
Expand Down
5 changes: 3 additions & 2 deletions src/binder_v2/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ impl Binder {
.insert(alias.value, ref_id);
select_list.push(id);
}
SelectItem::Wildcard => {
SelectItem::Wildcard(_) => {
select_list.append(&mut self.schema(from));
}
_ => todo!("bind select list"),
Expand Down Expand Up @@ -135,7 +135,8 @@ impl Binder {
}

/// Binds the VALUES clause. Returns a [`Values`](Node::Values) plan.
fn bind_values(&mut self, Values(values): Values) -> Result {
fn bind_values(&mut self, values: Values) -> Result {
let values = values.rows;
let mut bound_values = Vec::with_capacity(values.len());
if values.is_empty() {
return Ok(self.egraph.add(Node::Values([].into())));
Expand Down
30 changes: 18 additions & 12 deletions src/executor_v2/copy_from_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ use std::fs::File;
use std::io::BufReader;

use indicatif::{ProgressBar, ProgressStyle};
use itertools::izip;
use tokio::sync::mpsc::Sender;

use super::*;
use crate::array::DataChunkBuilder;
use crate::array::{ArrayImpl, DataChunkBuilder};
use crate::binder_v2::copy::{ExtSource, FileFormat};
use crate::types::DataTypeKind;

/// The executor of loading file data.
pub struct CopyFromFileExecutor {
Expand All @@ -23,12 +23,21 @@ const IMPORT_PROGRESS_BAR_LIMIT: u64 = 1024 * 1024;
impl CopyFromFileExecutor {
#[try_stream(boxed, ok = DataChunk, error = ExecutorError)]
pub async fn execute(self) {
let types = self.types.clone();
let (tx, mut rx) = tokio::sync::mpsc::channel(1);
// # Cancellation
// When this stream is dropped, the `rx` is dropped, the spawned task will fail to send to
// `tx`, then the task will finish.
let handle = tokio::task::spawn_blocking(|| self.read_file_blocking(tx));
while let Some(chunk) = rx.recv().await {
while let Some(mut chunk) = rx.recv().await {
// rescale decimals
for (i, ty) in types.iter().enumerate() {
if let (ArrayImpl::Decimal(a), DataTypeKind::Decimal(_, Some(scale))) =
(chunk.array_mut_at(i), ty.kind())
{
Arc::get_mut(a).unwrap().rescale(scale);
}
}
yield chunk;
}
handle.await.unwrap()?;
Expand Down Expand Up @@ -88,18 +97,15 @@ impl CopyFromFileExecutor {
});
}

let str_row_data: Result<Vec<&str>, _> = izip!(record.iter(), &self.types)
.map(|(v, ty)| {
if !ty.nullable && v.is_empty() {
return Err(ExecutorError::NotNullable);
}
Ok(v)
})
.collect();
for (v, ty) in record.iter().zip(&self.types) {
if !ty.nullable && v.is_empty() {
return Err(ExecutorError::NotNullable);
}
}
size_count += record.as_slice().as_bytes().len();

// push a raw str row and send it if necessary
if let Some(chunk) = chunk_builder.push_str_row(str_row_data?)? {
if let Some(chunk) = chunk_builder.push_str_row(record.iter())? {
bar.set_position(size_count as u64);
tx.blocking_send(chunk).map_err(|_| ExecutorError::Abort)?;
}
Expand Down
12 changes: 12 additions & 0 deletions src/executor_v2/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,18 @@ impl<'a> Evaluator<'a> {
array.get_valid_bitmap().iter().map(|v| !v).collect(),
))
}
Like([a, b]) => match self.next(*b).node() {
Expr::Constant(DataValue::String(pattern)) => {
let a = self.next(*a).eval(chunk)?;
a.like(pattern)
}
_ => panic!("like pattern must be a string constant"),
},
Extract([field, a]) => {
let a = self.next(*a).eval(chunk)?;
let Expr::Field(field) = self.expr[*field] else { panic!("not a field") };
a.extract(field)
}
Asc(a) | Desc(a) | Ref(a) => self.next(*a).eval(chunk),
// for aggs, evaluate its children
RowCount => Ok(ArrayImpl::new_null(
Expand Down
5 changes: 5 additions & 0 deletions src/planner/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ impl Display for Explain<'_> {
self.expr(else_)
),

// functions
Extract([field, e]) => write!(f, "extract({} from {})", self.expr(field), self.expr(e)),
Field(field) => write!(f, "{field}"),

// aggregations
RowCount => write!(f, "rowcount"),
Max(a) | Min(a) | Sum(a) | Avg(a) | Count(a) | First(a) | Last(a) => {
write!(f, "{}({})", enode, self.expr(a))
Expand Down
12 changes: 8 additions & 4 deletions src/planner/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::binder_v2::copy::ExtSource;
use crate::binder_v2::{BoundDrop, CreateTable};
use crate::catalog::{ColumnRefId, TableRefId};
use crate::parser::{BinaryOperator, UnaryOperator};
use crate::types::{ColumnIndex, DataTypeKind, DataValue};
use crate::types::{ColumnIndex, DataTypeKind, DataValue, DateTimeField};

mod cost;
mod explain;
Expand Down Expand Up @@ -60,7 +60,11 @@ define_language! {

"if" = If([Id; 3]), // (if cond then else)

// aggregates
// functions
"extract" = Extract([Id; 2]), // (extract field expr)
Field(DateTimeField),

// aggregations
"max" = Max(Id),
"min" = Min(Id),
"sum" = Sum(Id),
Expand Down Expand Up @@ -227,8 +231,8 @@ pub fn optimize(expr: &RecExpr) -> RecExpr {
}
best_cost = cost;
// println!(
// "{i}:\n{}",
// crate::planner::Explain::with_costs(&expr, &costs(&expr))
// "{}",
// crate::planner::Explain::of(&expr).with_costs(&costs(&expr))
// );
}

Expand Down
21 changes: 20 additions & 1 deletion src/planner/rules/type_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@ pub fn analyze_type(enode: &Expr, x: impl Fn(&Id) -> Type, catalog: &RootCatalog
merge(enode, [x(a)?, x(b)?], |[a, b]| {
match if a > b { (b, a) } else { (a, b) } {
(Kind::Null, _) => Some(Kind::Null),
(Kind::Decimal(Some(p1), Some(s1)), Kind::Decimal(Some(p2), Some(s2))) => {
match enode {
Add(_) | Sub(_) => Some(Kind::Decimal(
Some((p1 - s1).max(p2 - s2) + s1.max(s2) + 1),
Some(s1.max(s2)),
)),
Mul(_) => Some(Kind::Decimal(Some(p1 + p2), Some(s1 + s2))),
Div(_) | Mod(_) => Some(Kind::Decimal(None, None)),
_ => unreachable!(),
}
}
(a, b) if a.is_number() && b.is_number() => Some(b),
(Kind::Date, Kind::Interval) => Some(Kind::Date),
_ => None,
Expand All @@ -51,9 +62,12 @@ pub fn analyze_type(enode: &Expr, x: impl Fn(&Id) -> Type, catalog: &RootCatalog
}

// string ops
StringConcat([a, b]) | Like([a, b]) => merge(enode, [x(a)?, x(b)?], |[a, b]| {
StringConcat([a, b]) => merge(enode, [x(a)?, x(b)?], |[a, b]| {
(a == Kind::String && b == Kind::String).then_some(Kind::String)
}),
Like([a, b]) => merge(enode, [x(a)?, x(b)?], |[a, b]| {
(a == Kind::String && b == Kind::String).then_some(Kind::Bool)
}),

// bool ops
Not(a) => check(enode, x(a)?, |a| a == Kind::Bool),
Expand All @@ -79,6 +93,11 @@ pub fn analyze_type(enode: &Expr, x: impl Fn(&Id) -> Type, catalog: &RootCatalog
// null ops
IsNull(_) => Ok(Kind::Bool.not_null()),

// functions
Extract([_, a]) => merge(enode, [x(a)?], |[a]| {
matches!(a, Kind::Date | Kind::Interval).then_some(Kind::Int32)
}),

// number agg
Max(a) | Min(a) => x(a),
Sum(a) => check(enode, x(a)?, |a| a.is_number()),
Expand Down
Loading

0 comments on commit 857159e

Please sign in to comment.