Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Coprocessor: Support SUM() #4797

Merged
merged 13 commits into from May 30, 2019
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
190 changes: 190 additions & 0 deletions src/coprocessor/dag/aggr_fn/impl_sum.rs
@@ -0,0 +1,190 @@
// Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.

use cop_codegen::AggrFunction;
use cop_datatype::EvalType;
use tipb::expression::{Expr, ExprType, FieldType};

use super::summable::Summable;
use crate::coprocessor::codec::data_type::*;
use crate::coprocessor::codec::mysql::Tz;
use crate::coprocessor::dag::expr::EvalContext;
use crate::coprocessor::dag::rpn_expr::{RpnExpression, RpnExpressionBuilder};
use crate::coprocessor::Result;

/// The parser for SUM aggregate function.
///
/// This is a field-less unit type. Its `AggrDefinitionParser` implementation
/// validates a SUM definition and builds an `AggrFnSum` specialized for either
/// `Decimal` or `Real`.
pub struct AggrFnDefinitionParserSum;

impl super::parser::AggrDefinitionParser for AggrFnDefinitionParserSum {
    /// Reports whether the given SUM definition can be handled by this parser.
    ///
    /// The actual validation is delegated to the shared "exactly one child" check.
    ///
    /// # Panics
    ///
    /// Panics if `aggr_def` is not an `ExprType::Sum` expression.
    fn check_supported(&self, aggr_def: &Expr) -> Result<()> {
        let def_tp = aggr_def.get_tp();
        assert_eq!(def_tp, ExprType::Sum);
        super::util::check_aggr_exp_supported_one_child(aggr_def)
    }

    /// Turns a SUM definition into a concrete aggregate function.
    ///
    /// On success exactly one field type has been appended to `out_schema` (the
    /// field type carried by `aggr_def`) and exactly one expression has been appended
    /// to `out_exp` (the rewritten argument of SUM).
    ///
    /// # Errors
    ///
    /// Returns the error produced while building the RPN expression from the child
    /// expression tree. Note that `out_schema` has already been extended at that point.
    ///
    /// # Panics
    ///
    /// Panics if `aggr_def` is not an `ExprType::Sum` expression, if it has no child,
    /// if rewriting the argument expression fails, or if the rewritten expression
    /// evaluates to something other than `Decimal` or `Real`.
    fn parse(
        &self,
        mut aggr_def: Expr,
        time_zone: &Tz,
        src_schema: &[FieldType],
        out_schema: &mut Vec<FieldType>,
        out_exp: &mut Vec<RpnExpression>,
    ) -> Result<Box<dyn super::AggrFunction>> {
        use cop_datatype::FieldTypeAccessor;
        use std::convert::TryFrom;

        assert_eq!(aggr_def.get_tp(), ExprType::Sum);

        // A SUM contributes a single output column, typed as declared by the definition.
        out_schema.push(aggr_def.take_field_type());

        // Build the argument expression, then rewrite it so that a CAST is inserted
        // when needed. See `typeInfer4Sum` in TiDB.
        let child_def = aggr_def.take_children().into_iter().next().unwrap();
        let mut arg_exp =
            RpnExpressionBuilder::build_from_expr_tree(child_def, time_zone, src_schema.len())?;
        // The rewrite should always succeed.
        super::util::rewrite_exp_for_sum_avg(src_schema, &mut arg_exp).unwrap();

        let arg_tp = arg_exp.ret_field_type(src_schema).tp();
        let arg_eval_type = EvalType::try_from(arg_tp).unwrap();
        out_exp.push(arg_exp);

        // Pick the SUM implementation matching the eval type of the rewritten argument.
        let aggr_fn: Box<dyn super::AggrFunction> = match arg_eval_type {
            EvalType::Real => Box::new(AggrFnSum::<Real>::new()),
            EvalType::Decimal => Box::new(AggrFnSum::<Decimal>::new()),
            // Any other type after rewriting indicates an implementation fault.
            _ => unreachable!(),
        };
        Ok(aggr_fn)
    }
}

/// The SUM aggregate function.
///
/// Note that there are `SUM(Decimal) -> Decimal` and `SUM(Double) -> Double`.
///
/// The type parameter `T` is the value type being summed. The struct itself stores
/// no data; the `aggr_function` attribute below names `AggrFnStateSum::<T>::new()`
/// as the expression that produces the accumulation state.
#[derive(Debug, AggrFunction)]
#[aggr_function(state = AggrFnStateSum::<T>::new())]
pub struct AggrFnSum<T>
where
T: Summable,
VectorValue: VectorValueExt<T>,
{
// Zero-sized marker binding this type to `T` without holding a `T` value.
_phantom: std::marker::PhantomData<T>,
}

impl<T> AggrFnSum<T>
where
T: Summable,
VectorValue: VectorValueExt<T>,
{
pub fn new() -> Self {
Self {
_phantom: std::marker::PhantomData,
}
}
}

/// The state of the SUM aggregate function.
///
/// Keeps the running total together with a flag recording whether any non-`None`
/// value has been accumulated, so that the result can be `None` instead of zero
/// when nothing was summed.
#[derive(Debug)]
pub struct AggrFnStateSum<T>
where
T: Summable,
VectorValue: VectorValueExt<T>,
{
// Running total; starts at `T::zero()`.
sum: T,
// Becomes `true` once a `Some` value has been added to `sum`.
has_value: bool,
}

impl<T> AggrFnStateSum<T>
where
T: Summable,
VectorValue: VectorValueExt<T>,
{
pub fn new() -> Self {
Self {
sum: T::zero(),
has_value: false,
}
}
}

impl<T> super::ConcreteAggrFunctionState for AggrFnStateSum<T>
where
    T: Summable,
    VectorValue: VectorValueExt<T>,
{
    type ParameterType = T;

    /// Folds one input cell into the running total.
    ///
    /// A `None` input is skipped and leaves the state untouched. If the addition
    /// fails, the error is returned and `has_value` is not set by this call.
    #[inline]
    fn update_concrete(&mut self, ctx: &mut EvalContext, value: &Option<T>) -> Result<()> {
        if let Some(addend) = value {
            self.sum.add_assign(ctx, addend)?;
            self.has_value = true;
        }
        Ok(())
    }

    /// Appends the aggregation result to the first column of `target`.
    ///
    /// The pushed cell is `None` when no value has been accumulated, otherwise a
    /// clone of the running total. This always returns `Ok(())`.
    ///
    /// # Panics
    ///
    /// Panics if `target` is empty.
    #[inline]
    fn push_result(&self, _ctx: &mut EvalContext, target: &mut [VectorValue]) -> Result<()> {
        let out_column = &mut target[0];
        let total = if self.has_value {
            Some(self.sum.clone())
        } else {
            None
        };
        out_column.push(total);
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use cop_datatype::{FieldTypeAccessor, FieldTypeTp};
    use tipb_helper::ExprDefBuilder;

    use crate::coprocessor::codec::batch::{LazyBatchColumn, LazyBatchColumnVec};
    use crate::coprocessor::dag::aggr_fn::parser::AggrDefinitionParser;

    /// End-to-end check: SUM over a `Bytes` column should yield a `Real` result,
    /// with `None` cells ignored.
    #[test]
    fn test_integration() {
        // SUM(col#0) where the column is a VarString and the declared output is Double.
        let sum_def = ExprDefBuilder::aggr_func(ExprType::Sum, FieldTypeTp::Double)
            .push_child(ExprDefBuilder::column_ref(0, FieldTypeTp::VarString))
            .build();
        AggrFnDefinitionParserSum
            .check_supported(&sum_def)
            .unwrap();

        let src_schema = [FieldTypeTp::VarString.into()];

        // Four rows: two numeric strings interleaved with two NULLs.
        let mut bytes_col = LazyBatchColumn::decoded_with_capacity_and_tp(0, EvalType::Bytes);
        for cell in vec![Some(b"12.5".to_vec()), None, Some(b"42.0".to_vec()), None] {
            bytes_col.mut_decoded().push_bytes(cell);
        }
        let mut columns = LazyBatchColumnVec::from(vec![bytes_col]);

        let mut out_schema = vec![];
        let mut out_exp = vec![];

        let aggr_fn = AggrFnDefinitionParserSum
            .parse(
                sum_def,
                &Tz::utc(),
                &src_schema,
                &mut out_schema,
                &mut out_exp,
            )
            .unwrap();

        // The parser reports a single Double output column and a single argument expression.
        assert_eq!(out_schema.len(), 1);
        assert_eq!(out_schema[0].tp(), FieldTypeTp::Double);

        assert_eq!(out_exp.len(), 1);

        let mut state = aggr_fn.create_state();
        let mut ctx = EvalContext::default();

        // Evaluate the argument expression over all four rows and feed it to the state.
        let evaluated = out_exp[0]
            .eval(&mut ctx, 4, &src_schema, &mut columns)
            .unwrap();
        assert!(evaluated.is_vector());
        let reals: &[Option<Real>] = evaluated.vector_value().unwrap().as_ref();
        state.update_vector(&mut ctx, reals).unwrap();

        let mut aggr_result = [VectorValue::with_capacity(0, EvalType::Real)];
        state.push_result(&mut ctx, &mut aggr_result).unwrap();

        // 12.5 + 42.0, NULLs skipped.
        assert_eq!(aggr_result[0].as_real_slice(), &[Real::new(54.5).ok()]);
    }
}
1 change: 1 addition & 0 deletions src/coprocessor/dag/aggr_fn/mod.rs
Expand Up @@ -5,6 +5,7 @@
mod impl_avg;
mod impl_count;
mod impl_first;
mod impl_sum;
mod parser;
mod summable;
mod util;
Expand Down
1 change: 1 addition & 0 deletions src/coprocessor/dag/aggr_fn/parser.rs
Expand Up @@ -43,6 +43,7 @@ pub trait AggrDefinitionParser {
fn map_pb_sig_to_aggr_func_parser(value: ExprType) -> Result<Box<dyn AggrDefinitionParser>> {
match value {
ExprType::Count => Ok(Box::new(super::impl_count::AggrFnDefinitionParserCount)),
ExprType::Sum => Ok(Box::new(super::impl_sum::AggrFnDefinitionParserSum)),
ExprType::Avg => Ok(Box::new(super::impl_avg::AggrFnDefinitionParserAvg)),
ExprType::First => Ok(Box::new(super::impl_first::AggrFnDefinitionParserFirst)),
v => Err(box_err!(
Expand Down