diff --git a/src/coprocessor/dag/aggr_fn/impl_bit_op.rs b/src/coprocessor/dag/aggr_fn/impl_bit_op.rs new file mode 100644 index 00000000000..aba48e1c747 --- /dev/null +++ b/src/coprocessor/dag/aggr_fn/impl_bit_op.rs @@ -0,0 +1,461 @@ +// Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0. + +use std::convert::TryFrom; + +use cop_codegen::AggrFunction; +use cop_datatype::{EvalType, FieldTypeAccessor}; +use tipb::expression::{Expr, ExprType, FieldType}; + +use crate::coprocessor::codec::data_type::*; +use crate::coprocessor::codec::mysql::Tz; +use crate::coprocessor::dag::expr::EvalContext; +use crate::coprocessor::dag::rpn_expr::{RpnExpression, RpnExpressionBuilder}; +use crate::coprocessor::{Error, Result}; + +/// A trait for all bit operations +pub trait BitOp: Clone + std::fmt::Debug + Send + Sync + 'static { + /// Returns the bit operation type + fn tp() -> ExprType; + + /// Returns the bit operation initial state + fn init_state() -> u64; + + /// Executes the special bit operation + fn op(lhs: &mut u64, rhs: u64); +} + +macro_rules! bit_op { + ($name:ident, $tp:path, $init:tt, $op:tt) => { + #[derive(Debug, Clone, Copy)] + pub struct $name; + impl BitOp for $name { + fn tp() -> ExprType { + $tp + } + + fn init_state() -> u64 { + $init + } + + fn op(lhs: &mut u64, rhs: u64) { + *lhs $op rhs + } + } + }; +} + +bit_op!(BitAnd, ExprType::Agg_BitAnd, 0xffff_ffff_ffff_ffff, &=); +bit_op!(BitOr, ExprType::Agg_BitOr, 0, |=); +bit_op!(BitXor, ExprType::Agg_BitXor, 0, ^=); + +/// The parser for bit operation aggregate functions. +pub struct AggrFnDefinitionParserBitOp(std::marker::PhantomData); + +impl AggrFnDefinitionParserBitOp { + pub fn new() -> Self { + AggrFnDefinitionParserBitOp(std::marker::PhantomData) + } +} + +impl super::AggrDefinitionParser for AggrFnDefinitionParserBitOp { + fn check_supported(&self, aggr_def: &Expr) -> Result<()> { + assert_eq!(aggr_def.get_tp(), T::tp()); + + super::util::check_aggr_exp_supported_one_child(aggr_def)?; + + // Check whether or not the children's field type is supported. Currently we only support + // Int and does not support other types (which need casting). + // TODO: remove this check after implementing `CAST as Int` + let child = &aggr_def.get_children()[0]; + let eval_type = EvalType::try_from(child.get_field_type().tp()) + .map_err(|e| Error::Other(box_err!(e)))?; + match eval_type { + EvalType::Int => {} + _ => return Err(box_err!("Cast from {:?} is not supported", eval_type)), + } + Ok(()) + } + + fn parse( + &self, + mut aggr_def: Expr, + time_zone: &Tz, + // We use the same structure for all data types, so this parameter is not needed. + src_schema: &[FieldType], + out_schema: &mut Vec, + out_exp: &mut Vec, + ) -> Result> { + assert_eq!(aggr_def.get_tp(), T::tp()); + + // bit operation outputs one column. + out_schema.push(aggr_def.take_field_type()); + + // Rewrite expression to insert CAST() if needed. + let child = aggr_def.take_children().into_iter().next().unwrap(); + let mut exp = + RpnExpressionBuilder::build_from_expr_tree(child, time_zone, src_schema.len())?; + super::util::rewrite_exp_for_bit_op(src_schema, &mut exp).unwrap(); + out_exp.push(exp); + + Ok(Box::new(AggrFnBitOp::(std::marker::PhantomData))) + } +} + +/// The bit operation aggregate functions. +#[derive(Debug, AggrFunction)] +#[aggr_function(state = AggrFnStateBitOp::::new())] +pub struct AggrFnBitOp(std::marker::PhantomData); + +/// The state of the BitAnd aggregate function. +#[derive(Debug)] +pub struct AggrFnStateBitOp { + c: u64, + _phantom: std::marker::PhantomData, +} + +impl AggrFnStateBitOp { + pub fn new() -> Self { + Self { + c: T::init_state(), + _phantom: std::marker::PhantomData, + } + } +} + +impl super::ConcreteAggrFunctionState for AggrFnStateBitOp { + type ParameterType = Int; + + #[inline] + fn update_concrete( + &mut self, + _ctx: &mut EvalContext, + value: &Option, + ) -> Result<()> { + match value { + None => Ok(()), + Some(value) => { + T::op(&mut self.c, *value as u64); + Ok(()) + } + } + } + + #[inline] + fn push_result(&self, _ctx: &mut EvalContext, target: &mut [VectorValue]) -> Result<()> { + target[0].push_int(Some(self.c as Int)); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use cop_datatype::EvalType; + use tipb_helper::ExprDefBuilder; + + use super::super::AggrFunction; + use super::*; + use crate::coprocessor::codec::batch::{LazyBatchColumn, LazyBatchColumnVec}; + use crate::coprocessor::dag::aggr_fn::parser::AggrDefinitionParser; + use cop_datatype::{FieldTypeAccessor, FieldTypeTp}; + + #[test] + fn test_bit_and() { + let mut ctx = EvalContext::default(); + let function = AggrFnBitOp::(std::marker::PhantomData); + let mut state = function.create_state(); + + let mut result = [VectorValue::with_capacity(0, EvalType::Int)]; + + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!( + result[0].as_int_slice(), + &[Some(0xffff_ffff_ffff_ffff_u64 as i64)] + ); + + state.update(&mut ctx, &Option::::None).unwrap(); + result[0].clear(); + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!( + result[0].as_int_slice(), + &[Some(0xffff_ffff_ffff_ffff_u64 as i64)] + ); + + // 7 & 4 == 4 + state.update(&mut ctx, &Some(7i64)).unwrap(); + result[0].clear(); + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!(result[0].as_int_slice(), &[Some(7)]); + + state.update(&mut ctx, &Some(4i64)).unwrap(); + result[0].clear(); + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!(result[0].as_int_slice(), &[Some(4)]); + + state.update_repeat(&mut ctx, &Some(4), 10).unwrap(); + state + .update_repeat(&mut ctx, &Option::::None, 7) + .unwrap(); + + result[0].clear(); + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!(result[0].as_int_slice(), &[Some(4)]); + + // Reset the state + let mut state = function.create_state(); + // 7 & 1 == 1 + state.update(&mut ctx, &Some(7i64)).unwrap(); + state + .update_vector(&mut ctx, &[Some(1i64), None, Some(1i64)]) + .unwrap(); + result[0].clear(); + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!(result[0].as_int_slice(), &[Some(1)]); + + // 7 & 1 & 2 == 0 + state.update(&mut ctx, &Some(2i64)).unwrap(); + result[0].clear(); + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!(result[0].as_int_slice(), &[Some(0)]); + } + + #[test] + fn test_bit_or() { + let mut ctx = EvalContext::default(); + let function = AggrFnBitOp::(std::marker::PhantomData); + let mut state = function.create_state(); + + let mut result = [VectorValue::with_capacity(0, EvalType::Int)]; + + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!(result[0].as_int_slice(), &[Some(0)]); + + state.update(&mut ctx, &Option::::None).unwrap(); + result[0].clear(); + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!(result[0].as_int_slice(), &[Some(0)]); + + // 1 | 4 == 5 + state.update(&mut ctx, &Some(1i64)).unwrap(); + result[0].clear(); + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!(result[0].as_int_slice(), &[Some(1)]); + + state.update(&mut ctx, &Some(4i64)).unwrap(); + result[0].clear(); + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!(result[0].as_int_slice(), &[Some(5)]); + + state.update_repeat(&mut ctx, &Some(8), 10).unwrap(); + state + .update_repeat(&mut ctx, &Option::::None, 7) + .unwrap(); + + result[0].clear(); + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!(result[0].as_int_slice(), &[Some(13)]); + + // 13 | 2 == 15 + state.update(&mut ctx, &Some(2i64)).unwrap(); + state + .update_vector(&mut ctx, &[Some(2i64), None, Some(1i64)]) + .unwrap(); + result[0].clear(); + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!(result[0].as_int_slice(), &[Some(15)]); + + // 15 | 2 == 15 + state.update(&mut ctx, &Some(2i64)).unwrap(); + result[0].clear(); + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!(result[0].as_int_slice(), &[Some(15)]); + + // 15 | 2 | -1 == 18446744073709551615 + state.update(&mut ctx, &Some(-1i64)).unwrap(); + result[0].clear(); + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!( + result[0].as_int_slice(), + &[Some(18446744073709551615u64 as i64)] + ); + } + + #[test] + fn test_bit_xor() { + let mut ctx = EvalContext::default(); + let function = AggrFnBitOp::(std::marker::PhantomData); + let mut state = function.create_state(); + + let mut result = [VectorValue::with_capacity(0, EvalType::Int)]; + + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!(result[0].as_int_slice(), &[Some(0)]); + + state.update(&mut ctx, &Option::::None).unwrap(); + result[0].clear(); + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!(result[0].as_int_slice(), &[Some(0)]); + + // 1 ^ 5 == 4 + state.update(&mut ctx, &Some(1i64)).unwrap(); + result[0].clear(); + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!(result[0].as_int_slice(), &[Some(1)]); + + state.update(&mut ctx, &Some(5i64)).unwrap(); + result[0].clear(); + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!(result[0].as_int_slice(), &[Some(4)]); + + // 1 ^ 5 ^ 8 == 12 + state.update_repeat(&mut ctx, &Some(8), 9).unwrap(); + state + .update_repeat(&mut ctx, &Option::::None, 7) + .unwrap(); + + result[0].clear(); + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!(result[0].as_int_slice(), &[Some(12)]); + + // Will not change due to xor even times + state.update_repeat(&mut ctx, &Some(9), 10).unwrap(); + result[0].clear(); + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!(result[0].as_int_slice(), &[Some(12)]); + + // 1 ^ 5 ^ 8 ^ ^ 2 ^ 2 ^ 1 == 13 + state.update(&mut ctx, &Some(2i64)).unwrap(); + state + .update_vector(&mut ctx, &[Some(2i64), None, Some(1i64)]) + .unwrap(); + result[0].clear(); + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!(result[0].as_int_slice(), &[Some(13)]); + + // 13 ^ 2 == 15 + state.update(&mut ctx, &Some(2i64)).unwrap(); + result[0].clear(); + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!(result[0].as_int_slice(), &[Some(15)]); + + // 15 ^ 2 ^ -1 == 18446744073709551602 + state.update(&mut ctx, &Some(2i64)).unwrap(); + state.update(&mut ctx, &Some(-1i64)).unwrap(); + result[0].clear(); + state.push_result(&mut ctx, &mut result).unwrap(); + assert_eq!( + result[0].as_int_slice(), + &[Some(18446744073709551602u64 as i64)] + ); + } + + #[test] + fn test_integration() { + let bit_and_parser = AggrFnDefinitionParserBitOp::::new(); + let bit_or_parser = AggrFnDefinitionParserBitOp::::new(); + let bit_xor_parser = AggrFnDefinitionParserBitOp::::new(); + + let bit_and = ExprDefBuilder::aggr_func(ExprType::Agg_BitAnd, FieldTypeTp::LongLong) + .push_child(ExprDefBuilder::column_ref(0, FieldTypeTp::LongLong)) + .build(); + bit_and_parser.check_supported(&bit_and).unwrap(); + + let bit_or = ExprDefBuilder::aggr_func(ExprType::Agg_BitOr, FieldTypeTp::LongLong) + .push_child(ExprDefBuilder::column_ref(0, FieldTypeTp::LongLong)) + .build(); + bit_or_parser.check_supported(&bit_or).unwrap(); + + let bit_xor = ExprDefBuilder::aggr_func(ExprType::Agg_BitXor, FieldTypeTp::LongLong) + .push_child(ExprDefBuilder::column_ref(0, FieldTypeTp::LongLong)) + .build(); + bit_xor_parser.check_supported(&bit_xor).unwrap(); + + let src_schema = [FieldTypeTp::LongLong.into()]; + let mut columns = LazyBatchColumnVec::from(vec![{ + let mut col = LazyBatchColumn::decoded_with_capacity_and_tp(0, EvalType::Int); + col.mut_decoded().push_int(Some(1)); + col.mut_decoded().push_int(Some(23)); + col.mut_decoded().push_int(Some(42)); + col.mut_decoded().push_int(None); + col.mut_decoded().push_int(Some(99)); + col.mut_decoded().push_int(Some(-1)); + col + }]); + + let mut schema = vec![]; + let mut exp = vec![]; + + let bit_and_fn = bit_and_parser + .parse(bit_and, &Tz::utc(), &src_schema, &mut schema, &mut exp) + .unwrap(); + assert_eq!(schema.len(), 1); + assert_eq!(schema[0].tp(), FieldTypeTp::LongLong); + assert_eq!(exp.len(), 1); + + let bit_or_fn = bit_or_parser + .parse(bit_or, &Tz::utc(), &src_schema, &mut schema, &mut exp) + .unwrap(); + assert_eq!(schema.len(), 2); + assert_eq!(schema[1].tp(), FieldTypeTp::LongLong); + assert_eq!(exp.len(), 2); + + let bit_xor_fn = bit_xor_parser + .parse(bit_xor, &Tz::utc(), &src_schema, &mut schema, &mut exp) + .unwrap(); + assert_eq!(schema.len(), 3); + assert_eq!(schema[2].tp(), FieldTypeTp::LongLong); + assert_eq!(exp.len(), 3); + + let mut ctx = EvalContext::default(); + let mut bit_and_state = bit_and_fn.create_state(); + let mut bit_or_state = bit_or_fn.create_state(); + let mut bit_xor_state = bit_xor_fn.create_state(); + + let mut aggr_result = [VectorValue::with_capacity(0, EvalType::Int)]; + + // bit and + { + let bit_and_result = exp[0].eval(&mut ctx, 6, &src_schema, &mut columns).unwrap(); + assert!(bit_and_result.is_vector()); + let bit_and_slice: &[Option] = bit_and_result.vector_value().unwrap().as_ref(); + bit_and_state + .update_vector(&mut ctx, bit_and_slice) + .unwrap(); + bit_and_state + .push_result(&mut ctx, &mut aggr_result) + .unwrap(); + } + + // bit or + { + let bit_or_result = exp[1].eval(&mut ctx, 6, &src_schema, &mut columns).unwrap(); + assert!(bit_or_result.is_vector()); + let bit_or_slice: &[Option] = bit_or_result.vector_value().unwrap().as_ref(); + bit_or_state.update_vector(&mut ctx, bit_or_slice).unwrap(); + bit_or_state + .push_result(&mut ctx, &mut aggr_result) + .unwrap(); + } + + // bit xor + { + let bit_xor_result = exp[2].eval(&mut ctx, 6, &src_schema, &mut columns).unwrap(); + assert!(bit_xor_result.is_vector()); + let bit_xor_slice: &[Option] = bit_xor_result.vector_value().unwrap().as_ref(); + bit_xor_state + .update_vector(&mut ctx, bit_xor_slice) + .unwrap(); + bit_xor_state + .push_result(&mut ctx, &mut aggr_result) + .unwrap(); + } + + assert_eq!( + aggr_result[0].as_int_slice(), + &[ + Some(0), + Some(18446744073709551615u64 as i64), + Some(18446744073709551520u64 as i64) + ] + ); + } +} diff --git a/src/coprocessor/dag/aggr_fn/mod.rs b/src/coprocessor/dag/aggr_fn/mod.rs index afe1b73db5e..db7625f780f 100644 --- a/src/coprocessor/dag/aggr_fn/mod.rs +++ b/src/coprocessor/dag/aggr_fn/mod.rs @@ -3,6 +3,7 @@ //! This module provides aggregate functions for batch executors. mod impl_avg; +mod impl_bit_op; mod impl_count; mod impl_first; mod impl_sum; diff --git a/src/coprocessor/dag/aggr_fn/parser.rs b/src/coprocessor/dag/aggr_fn/parser.rs index 0a7fcd1dcad..8b2c81cdd44 100644 --- a/src/coprocessor/dag/aggr_fn/parser.rs +++ b/src/coprocessor/dag/aggr_fn/parser.rs @@ -3,6 +3,7 @@ use tipb::expression::{Expr, ExprType, FieldType}; use crate::coprocessor::codec::mysql::Tz; +use crate::coprocessor::dag::aggr_fn::impl_bit_op::*; use crate::coprocessor::dag::aggr_fn::AggrFunction; use crate::coprocessor::dag::rpn_expr::RpnExpression; use crate::coprocessor::{Error, Result}; @@ -46,6 +47,9 @@ fn map_pb_sig_to_aggr_func_parser(value: ExprType) -> Result Ok(Box::new(super::impl_sum::AggrFnDefinitionParserSum)), ExprType::Avg => Ok(Box::new(super::impl_avg::AggrFnDefinitionParserAvg)), ExprType::First => Ok(Box::new(super::impl_first::AggrFnDefinitionParserFirst)), + ExprType::Agg_BitAnd => Ok(Box::new(AggrFnDefinitionParserBitOp::::new())), + ExprType::Agg_BitOr => Ok(Box::new(AggrFnDefinitionParserBitOp::::new())), + ExprType::Agg_BitXor => Ok(Box::new(AggrFnDefinitionParserBitOp::::new())), v => Err(box_err!( "Aggregation function expr type {:?} is not supported in batch mode", v diff --git a/src/coprocessor/dag/aggr_fn/util.rs b/src/coprocessor/dag/aggr_fn/util.rs index 46e4962a4ee..257868134f7 100644 --- a/src/coprocessor/dag/aggr_fn/util.rs +++ b/src/coprocessor/dag/aggr_fn/util.rs @@ -3,7 +3,7 @@ use std::convert::TryFrom; use cop_datatype::builder::FieldTypeBuilder; -use cop_datatype::{EvalType, FieldTypeAccessor, FieldTypeTp}; +use cop_datatype::{EvalType, FieldTypeAccessor, FieldTypeFlag, FieldTypeTp}; use tipb::expression::{Expr, FieldType}; use crate::coprocessor::dag::rpn_expr::impl_cast::get_cast_fn; @@ -57,3 +57,24 @@ pub fn rewrite_exp_for_sum_avg(schema: &[FieldType], exp: &mut RpnExpression) -> }); Ok(()) } + +/// Rewrites the expression to insert necessary cast functions for Bit operation family functions. +pub fn rewrite_exp_for_bit_op(schema: &[FieldType], exp: &mut RpnExpression) -> Result<()> { + let ret_field_type = exp.ret_field_type(schema); + let ret_eval_type = box_try!(EvalType::try_from(ret_field_type.tp())); + let new_ret_field_type = match ret_eval_type { + EvalType::Int => { + return Ok(()); + } + _ => FieldTypeBuilder::new() + .tp(FieldTypeTp::LongLong) + .flag(FieldTypeFlag::UNSIGNED) + .build(), + }; + let func = get_cast_fn(ret_field_type, &new_ret_field_type)?; + exp.push(RpnExpressionNode::FnCall { + func, + field_type: new_ret_field_type, + }); + Ok(()) +}