Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

coprocessor: Support all data types for AVG() #4777

Merged
merged 8 commits into from May 29, 2019
4 changes: 3 additions & 1 deletion components/cop_datatype/src/def/mod.rs
Expand Up @@ -12,4 +12,6 @@ pub use self::field_type::{Collation, FieldTypeAccessor, FieldTypeFlag, FieldTyp
pub const UNSPECIFIED_LENGTH: isize = -1;

/// MySQL type maximum length
pub const MAX_BLOB_WIDTH: i32 = 16_777_216;
pub const MAX_BLOB_WIDTH: i32 = 16_777_216; // FIXME: Should be isize
pub const MAX_DECIMAL_WIDTH: isize = 65;
pub const MAX_REAL_WIDTH: isize = 23;
128 changes: 80 additions & 48 deletions src/coprocessor/dag/aggr_fn/impl_avg.rs
@@ -1,6 +1,7 @@
// Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.

use cop_codegen::AggrFunction;
use cop_datatype::builder::FieldTypeBuilder;
use cop_datatype::{EvalType, FieldTypeFlag, FieldTypeTp};
use tipb::expression::{Expr, ExprType, FieldType};

Expand All @@ -9,79 +10,53 @@ use crate::coprocessor::codec::data_type::*;
use crate::coprocessor::codec::mysql::Tz;
use crate::coprocessor::dag::expr::EvalContext;
use crate::coprocessor::dag::rpn_expr::{RpnExpression, RpnExpressionBuilder};
use crate::coprocessor::{Error, Result};
use crate::coprocessor::Result;

/// The parser for AVG aggregate function.
pub struct AggrFnDefinitionParserAvg;

impl super::AggrDefinitionParser for AggrFnDefinitionParserAvg {
fn check_supported(&self, aggr_def: &Expr) -> Result<()> {
use cop_datatype::FieldTypeAccessor;
use std::convert::TryFrom;

assert_eq!(aggr_def.get_tp(), ExprType::Avg);
if aggr_def.get_children().len() != 1 {
return Err(box_err!(
"Expect 1 parameter, but got {}",
aggr_def.get_children().len()
));
}

// Check whether or not the children's field type is supported. Currently we only support
// Double and Decimal and does not support other types (which need casting).
let child = &aggr_def.get_children()[0];
let eval_type = EvalType::try_from(child.get_field_type().tp())
.map_err(|e| Error::Other(box_err!(e)))?;
match eval_type {
EvalType::Real | EvalType::Decimal => {}
_ => return Err(box_err!("Cast from {:?} is not supported", eval_type)),
}

// Check whether parameter expression is supported.
RpnExpressionBuilder::check_expr_tree_supported(child)?;

Ok(())
super::util::check_aggr_exp_supported_one_child(aggr_def)
}

fn parse(
&self,
mut aggr_def: Expr,
time_zone: &Tz,
max_columns: usize,
src_schema: &[FieldType],
out_schema: &mut Vec<FieldType>,
out_exp: &mut Vec<RpnExpression>,
) -> Result<Box<dyn super::AggrFunction>> {
use cop_datatype::FieldTypeAccessor;
use std::convert::TryFrom;

assert_eq!(aggr_def.get_tp(), ExprType::Avg);
let child = aggr_def.take_children().into_iter().next().unwrap();
let eval_type = EvalType::try_from(child.get_field_type().tp()).unwrap();

// AVG outputs two columns.
out_schema.push({
let mut ft = FieldType::new();
ft.as_mut_accessor()
.set_tp(FieldTypeTp::LongLong)
.set_flag(FieldTypeFlag::UNSIGNED);
ft
});
out_schema.push(
FieldTypeBuilder::new()
.tp(FieldTypeTp::LongLong)
.flag(FieldTypeFlag::UNSIGNED)
.build(),
);
out_schema.push(aggr_def.take_field_type());

// Currently we don't support casting in `check_supported`, so we can directly use the
// built expression.
out_exp.push(RpnExpressionBuilder::build_from_expr_tree(
child,
time_zone,
max_columns,
)?);

// Choose a type-aware AVG implementation based on eval type.
match eval_type {
EvalType::Real => Ok(Box::new(AggrFnAvg::<Real>::new())),
EvalType::Decimal => Ok(Box::new(AggrFnAvg::<Decimal>::new())),
// Rewrite expression to insert CAST() if needed.
let child = aggr_def.take_children().into_iter().next().unwrap();
let mut exp =
RpnExpressionBuilder::build_from_expr_tree(child, time_zone, src_schema.len())?;
super::util::rewrite_exp_for_sum_avg(src_schema, &mut exp).unwrap();

let rewritten_eval_type = EvalType::try_from(exp.ret_field_type(src_schema).tp()).unwrap();
out_exp.push(exp);

Ok(match rewritten_eval_type {
EvalType::Decimal => Box::new(AggrFnAvg::<Decimal>::new()),
EvalType::Real => Box::new(AggrFnAvg::<Real>::new()),
_ => unreachable!(),
}
})
}
}

Expand Down Expand Up @@ -172,6 +147,12 @@ mod tests {
use super::super::AggrFunction;
use super::*;

use cop_datatype::FieldTypeAccessor;
use tipb_helper::ExprDefBuilder;

use crate::coprocessor::codec::batch::{LazyBatchColumn, LazyBatchColumnVec};
use crate::coprocessor::dag::aggr_fn::parser::AggrDefinitionParser;

#[test]
fn test_update() {
let mut ctx = EvalContext::default();
Expand Down Expand Up @@ -217,4 +198,55 @@ mod tests {
&[None, None, Real::new(15.0).ok(), Real::new(10.5).ok()]
);
}

/// End-to-end check of AVG over an integer column: the parsed aggregate must
/// report a (LongLong, NewDecimal) output schema, the generated expression
/// must yield decimal values, and the pushed result must be (Int, Decimal).
#[test]
fn test_integration() {
    // AVG(col#0) where col#0 is LongLong and the aggregate is declared as NewDecimal.
    let avg_def = ExprDefBuilder::aggr_func(ExprType::Avg, FieldTypeTp::NewDecimal)
        .push_child(ExprDefBuilder::column_ref(0, FieldTypeTp::LongLong))
        .build();
    AggrFnDefinitionParserAvg.check_supported(&avg_def).unwrap();

    let src_schema = [FieldTypeTp::LongLong.into()];

    // A single decoded Int column holding four rows, two of which are `None`.
    let mut int_col = LazyBatchColumn::decoded_with_capacity_and_tp(0, EvalType::Int);
    for v in &[Some(1), None, Some(42), None] {
        int_col.mut_decoded().push_int(*v);
    }
    let mut columns = LazyBatchColumnVec::from(vec![int_col]);

    let mut out_schema = vec![];
    let mut out_exp = vec![];

    let aggr_fn = AggrFnDefinitionParserAvg
        .parse(
            avg_def,
            &Tz::utc(),
            &src_schema,
            &mut out_schema,
            &mut out_exp,
        )
        .unwrap();

    // Two output columns are expected, plus exactly one expression to evaluate.
    assert_eq!(out_schema.len(), 2);
    assert_eq!(out_schema[0].tp(), FieldTypeTp::LongLong);
    assert_eq!(out_schema[1].tp(), FieldTypeTp::NewDecimal);

    assert_eq!(out_exp.len(), 1);

    let mut state = aggr_fn.create_state();
    let mut ctx = EvalContext::default();

    // Evaluate the expression over all four rows and feed the decimal values
    // into the aggregate state.
    let evaluated = out_exp[0]
        .eval(&mut ctx, 4, &src_schema, &mut columns)
        .unwrap();
    assert!(evaluated.is_vector());
    let decimals: &[Option<Decimal>] = evaluated.vector_value().unwrap().as_ref();
    state.update_vector(&mut ctx, decimals).unwrap();

    let mut aggr_result = [
        VectorValue::with_capacity(0, EvalType::Int),
        VectorValue::with_capacity(0, EvalType::Decimal),
    ];
    state.push_result(&mut ctx, &mut aggr_result).unwrap();

    // Two non-`None` inputs (1 and 42) give 2 in the first column and 43 in the second.
    assert_eq!(aggr_result[0].as_int_slice(), &[Some(2)]);
    assert_eq!(
        aggr_result[1].as_decimal_slice(),
        &[Some(Decimal::from(43u64))]
    );
}
}
34 changes: 11 additions & 23 deletions src/coprocessor/dag/aggr_fn/impl_count.rs
@@ -1,6 +1,7 @@
// Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.

use cop_codegen::AggrFunction;
use cop_datatype::builder::FieldTypeBuilder;
use cop_datatype::{FieldTypeFlag, FieldTypeTp};
use tipb::expression::{Expr, ExprType, FieldType};

Expand All @@ -16,47 +17,34 @@ pub struct AggrFnDefinitionParserCount;
impl super::AggrDefinitionParser for AggrFnDefinitionParserCount {
fn check_supported(&self, aggr_def: &Expr) -> Result<()> {
assert_eq!(aggr_def.get_tp(), ExprType::Count);
if aggr_def.get_children().len() != 1 {
return Err(box_err!(
"Expect 1 parameter, but got {}",
aggr_def.get_children().len()
));
}

// Only check whether or not the children expr is supported.
let child = &aggr_def.get_children()[0];
RpnExpressionBuilder::check_expr_tree_supported(child)?;

Ok(())
super::util::check_aggr_exp_supported_one_child(aggr_def)
}

fn parse(
&self,
mut aggr_def: Expr,
time_zone: &Tz,
max_columns: usize,
// We use the same structure for all data types, so this parameter is not needed.
src_schema: &[FieldType],
out_schema: &mut Vec<FieldType>,
out_exp: &mut Vec<RpnExpression>,
) -> Result<Box<dyn super::AggrFunction>> {
use cop_datatype::FieldTypeAccessor;

assert_eq!(aggr_def.get_tp(), ExprType::Count);
let child = aggr_def.take_children().into_iter().next().unwrap();

// COUNT outputs one column.
out_schema.push({
let mut ft = FieldType::new();
ft.as_mut_accessor()
.set_tp(FieldTypeTp::LongLong)
.set_flag(FieldTypeFlag::UNSIGNED);
ft
});
out_schema.push(
FieldTypeBuilder::new()
.tp(FieldTypeTp::LongLong)
.flag(FieldTypeFlag::UNSIGNED)
.build(),
);

// COUNT doesn't need to cast, so using the expression directly.
out_exp.push(RpnExpressionBuilder::build_from_expr_tree(
child,
time_zone,
max_columns,
src_schema.len(),
)?);

Ok(Box::new(AggrFnCount))
Expand Down
1 change: 1 addition & 0 deletions src/coprocessor/dag/aggr_fn/mod.rs
Expand Up @@ -6,6 +6,7 @@ mod impl_avg;
mod impl_count;
mod parser;
mod summable;
mod util;

pub use self::parser::{AggrDefinitionParser, AllAggrDefinitionParser};

Expand Down
9 changes: 6 additions & 3 deletions src/coprocessor/dag/aggr_fn/parser.rs
Expand Up @@ -23,14 +23,17 @@ pub trait AggrDefinitionParser {
/// RPN expression (maybe wrapped by some casting according to types) will be appended in
/// `out_exp`.
///
/// The parser may choose a particular aggregate function implementation based on the data
/// type, so `src_schema` is also needed in case the data type depends on the column.
///
/// # Panics
///
/// May panic if the aggregate function definition is not supported by this parser.
fn parse(
&self,
aggr_def: Expr,
time_zone: &Tz,
max_columns: usize,
src_schema: &[FieldType],
out_schema: &mut Vec<FieldType>,
out_exp: &mut Vec<RpnExpression>,
) -> Result<Box<dyn AggrFunction>>;
Expand Down Expand Up @@ -76,11 +79,11 @@ impl AggrDefinitionParser for AllAggrDefinitionParser {
&self,
aggr_def: Expr,
time_zone: &Tz,
max_columns: usize,
src_schema: &[FieldType],
out_schema: &mut Vec<FieldType>,
out_exp: &mut Vec<RpnExpression>,
) -> Result<Box<dyn AggrFunction>> {
let parser = map_pb_sig_to_aggr_func_parser(aggr_def.get_tp()).unwrap();
parser.parse(aggr_def, time_zone, max_columns, out_schema, out_exp)
parser.parse(aggr_def, time_zone, src_schema, out_schema, out_exp)
}
}
59 changes: 59 additions & 0 deletions src/coprocessor/dag/aggr_fn/util.rs
@@ -0,0 +1,59 @@
// Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.

use std::convert::TryFrom;

use cop_datatype::builder::FieldTypeBuilder;
use cop_datatype::{EvalType, FieldTypeAccessor, FieldTypeTp};
use tipb::expression::{Expr, FieldType};

use crate::coprocessor::dag::rpn_expr::impl_cast::get_cast_fn;
use crate::coprocessor::dag::rpn_expr::types::RpnExpressionNode;
use crate::coprocessor::dag::rpn_expr::{RpnExpression, RpnExpressionBuilder};
use crate::coprocessor::Result;

/// Validates an aggregate definition that must take exactly one argument.
///
/// Returns an error when the number of children is not 1, or when the single
/// child expression tree is rejected by
/// `RpnExpressionBuilder::check_expr_tree_supported`.
pub fn check_aggr_exp_supported_one_child(aggr_def: &Expr) -> Result<()> {
    let children = aggr_def.get_children();
    match children.len() {
        1 => {
            // Exactly one parameter: make sure its expression tree is supported.
            RpnExpressionBuilder::check_expr_tree_supported(&children[0])?;
            Ok(())
        }
        n => Err(box_err!("Expect 1 parameter, but got {}", n)),
    }
}

/// Appends a cast node to `exp` when SUM / AVG cannot consume its result type directly.
///
/// The expression's return type (resolved against `schema`) decides what happens:
///
/// - Decimal or Real: `exp` is left untouched.
/// - Int: a cast to `NewDecimal` with `MAX_DECIMAL_WIDTH` is appended.
/// - Anything else: a cast to `Double` with `MAX_REAL_WIDTH` and an unspecified
///   number of decimals is appended.
///
/// This follows `typeInfer4Sum` and `typeInfer4Avg` in TiDB.
///
/// TODO: This logic should be performed by TiDB.
///
/// # Errors
///
/// Fails if the return field type cannot be mapped to an `EvalType`, or if
/// `get_cast_fn` returns an error for the source / target type pair.
pub fn rewrite_exp_for_sum_avg(schema: &[FieldType], exp: &mut RpnExpression) -> Result<()> {
    let src_field_type = exp.ret_field_type(schema);

    let target_field_type = match box_try!(EvalType::try_from(src_field_type.tp())) {
        // Already a type the aggregate can work with; nothing to rewrite.
        EvalType::Decimal | EvalType::Real => return Ok(()),
        EvalType::Int => FieldTypeBuilder::new()
            .tp(FieldTypeTp::NewDecimal)
            .flen(cop_datatype::MAX_DECIMAL_WIDTH)
            .build(),
        _ => FieldTypeBuilder::new()
            .tp(FieldTypeTp::Double)
            .flen(cop_datatype::MAX_REAL_WIDTH)
            .decimal(cop_datatype::UNSPECIFIED_LENGTH)
            .build(),
    };

    // Resolve the cast before mutating `exp`, so nothing is pushed on failure.
    let cast_fn = get_cast_fn(src_field_type, &target_field_type)?;
    exp.push(RpnExpressionNode::FnCall {
        func: cast_fn,
        field_type: target_field_type,
    });
    Ok(())
}
Expand Up @@ -475,7 +475,7 @@ mod tests {
&self,
_aggr_def: Expr,
_time_zone: &Tz,
_max_columns: usize,
_src_schema: &[FieldType],
out_schema: &mut Vec<FieldType>,
out_exp: &mut Vec<RpnExpression>,
) -> Result<Box<dyn AggrFunction>> {
Expand Down
4 changes: 2 additions & 2 deletions src/coprocessor/dag/batch/executors/simple_aggr_executor.rs
Expand Up @@ -327,7 +327,7 @@ mod tests {
&self,
aggr_def: Expr,
_time_zone: &Tz,
_max_columns: usize,
_src_schema: &[FieldType],
out_schema: &mut Vec<FieldType>,
out_exp: &mut Vec<RpnExpression>,
) -> Result<Box<dyn AggrFunction>> {
Expand Down Expand Up @@ -564,7 +564,7 @@ mod tests {
&self,
_aggr_def: Expr,
_time_zone: &Tz,
_max_columns: usize,
_src_schema: &[FieldType],
out_schema: &mut Vec<FieldType>,
out_exp: &mut Vec<RpnExpression>,
) -> Result<Box<dyn AggrFunction>> {
Expand Down
3 changes: 1 addition & 2 deletions src/coprocessor/dag/batch/executors/util/aggr_executor.rs
Expand Up @@ -117,7 +117,6 @@ impl<Src: BatchExecutor, I: AggregationExecutorImpl<Src>> AggregationExecutor<Sr
) -> Result<Self> {
let aggr_fn_len = aggr_defs.len();
let src_schema = src.schema();
let src_schema_len = src_schema.len();

let mut schema = Vec::with_capacity(aggr_fn_len * 2);
let mut each_aggr_fn = Vec::with_capacity(aggr_fn_len);
Expand All @@ -131,7 +130,7 @@ impl<Src: BatchExecutor, I: AggregationExecutorImpl<Src>> AggregationExecutor<Sr
let aggr_fn = aggr_def_parser.parse(
aggr_def,
&config.tz,
src_schema_len,
src_schema,
&mut schema,
&mut each_aggr_exprs,
)?;
Expand Down