diff --git a/Cargo.lock b/Cargo.lock index 804cfc87cdc..70e4ab1f4e3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ [root] name = "tikv" -version = "0.0.1" +version = "0.9.0" dependencies = [ "backtrace 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", "byteorder 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", @@ -957,7 +957,7 @@ dependencies = [ [[package]] name = "tipb" version = "0.0.1" -source = "git+https://github.com/pingcap/tipb.git#fae5d734815d6024d8264a008d5736e362efcc77" +source = "git+https://github.com/pingcap/tipb.git#ffbc1bb02a2425782286825eec276be817b7f93a" dependencies = [ "protobuf 1.4.1 (registry+https://github.com/rust-lang/crates.io-index)", ] diff --git a/src/coprocessor/dag/expr/builtin_control.rs b/src/coprocessor/dag/expr/builtin_control.rs index 910f903aade..67ead246bab 100644 --- a/src/coprocessor/dag/expr/builtin_control.rs +++ b/src/coprocessor/dag/expr/builtin_control.rs @@ -14,7 +14,8 @@ use std::borrow::Cow; use super::{FnCall, Result, StatementContext}; use coprocessor::codec::Datum; -use coprocessor::codec::mysql::{Decimal, Duration, Time}; +use coprocessor::codec::mysql::{Decimal, Duration, Json, Time}; +use coprocessor::dag::expr::Expression; fn if_null(f: F) -> Result> where @@ -44,6 +45,30 @@ where } } +/// See https://dev.mysql.com/doc/refman/5.7/en/case.html +fn case_when<'a, F, T>( + expr: &'a FnCall, + ctx: &StatementContext, + row: &'a [Datum], + f: F, +) -> Result> +where + F: Fn(&'a Expression) -> Result>, +{ + for chunk in expr.children.chunks(2) { + if chunk.len() == 1 { + // else statement + return f(&chunk[0]); + } + let cond = try!(chunk[0].eval_int(ctx, row)); + if cond.unwrap_or(0) == 0 { + continue; + } + return f(&chunk[1]); + } + Ok(None) +} + impl FnCall { pub fn if_null_int(&self, ctx: &StatementContext, row: &[Datum]) -> Result> { if_null(|i| self.children[i].eval_int(ctx, row)) @@ -124,16 +149,68 @@ impl FnCall { ) -> Result>> { if_condition(self, ctx, row, |i| self.children[i].eval_duration(ctx, row)) } + + pub fn case_when_int(&self, ctx: &StatementContext, row: &[Datum]) -> Result> { + case_when(self, ctx, row, |v| v.eval_int(ctx, row)) + } + + pub fn case_when_real(&self, ctx: &StatementContext, row: &[Datum]) -> Result> { + case_when(self, ctx, row, |v| v.eval_real(ctx, row)) + } + + pub fn case_when_decimal<'a, 'b: 'a>( + &'b self, + ctx: &StatementContext, + row: &'a [Datum], + ) -> Result>> { + case_when(self, ctx, row, |v| v.eval_decimal(ctx, row)) + } + + pub fn case_when_string<'a, 'b: 'a>( + &'b self, + ctx: &StatementContext, + row: &'a [Datum], + ) -> Result>>> { + case_when(self, ctx, row, |v| v.eval_string(ctx, row)) + } + + pub fn case_when_time<'a, 'b: 'a>( + &'b self, + ctx: &StatementContext, + row: &'a [Datum], + ) -> Result>> { + case_when(self, ctx, row, |v| v.eval_time(ctx, row)) + } + + pub fn case_when_duration<'a, 'b: 'a>( + &'b self, + ctx: &StatementContext, + row: &'a [Datum], + ) -> Result>> { + case_when(self, ctx, row, |v| v.eval_duration(ctx, row)) + } + + pub fn case_when_json<'a, 'b: 'a>( + &'b self, + ctx: &StatementContext, + row: &'a [Datum], + ) -> Result>> { + case_when(self, ctx, row, |v| v.eval_json(ctx, row)) + } } #[cfg(test)] mod test { - use tipb::expression::ScalarFuncSig; + use protobuf::RepeatedField; + use tipb::expression::{Expr, ExprType, ScalarFuncSig}; + use coprocessor::codec::Datum; - use coprocessor::codec::mysql::Duration; + use coprocessor::codec::mysql::{Duration, Json, Time}; use coprocessor::dag::expr::{Expression, StatementContext}; use coprocessor::dag::expr::test::{fncall_expr, str2dec}; use coprocessor::select::xeval::evaluator::test::datum_expr; + use coprocessor::select::xeval::evaluator::test::col_expr; + #[test] fn test_if_null() { @@ -344,4 +421,77 @@ mod test { assert_eq!(lhs, rhs); } } + + fn cond(ok: bool) -> Datum { + if ok { + Datum::I64(1) + } else { + Datum::I64(0) + } + } + + #[test] + fn test_case_when() { + let dec1 = Datum::Dec("1.1".parse().unwrap()); + let dec2 = Datum::Dec("2.2".parse().unwrap()); + let dur1 = Datum::Dur(Duration::parse(b"01:00:00", 0).unwrap()); + let dur2 = Datum::Dur(Duration::parse(b"12:00:12", 0).unwrap()); + let time1 = Datum::Time(Time::parse_utc_datetime("2012-12-12 12:00:23", 0).unwrap()); + let s = "你好".as_bytes().to_owned(); + + let cases = vec![ + ( + ScalarFuncSig::CaseWhenInt, + vec![cond(true), Datum::I64(3), cond(true), Datum::I64(5)], + Datum::I64(3), + ), + ( + ScalarFuncSig::CaseWhenDecimal, + vec![cond(false), dec1, cond(true), dec2.clone()], + dec2, + ), + ( + ScalarFuncSig::CaseWhenDuration, + vec![Datum::Null, dur1, cond(true), dur2.clone()], + dur2, + ), + (ScalarFuncSig::CaseWhenTime, vec![time1.clone()], time1), + ( + ScalarFuncSig::CaseWhenReal, + vec![cond(false), Datum::Null], + Datum::Null, + ), + ( + ScalarFuncSig::CaseWhenString, + vec![cond(true), Datum::Bytes(s.clone())], + Datum::Bytes(s), + ), + ( + ScalarFuncSig::CaseWhenJson, + vec![ + cond(false), + Datum::Null, + Datum::Null, + Datum::Null, + Datum::Json(Json::I64(23)), + ], + Datum::Json(Json::I64(23)), + ), + ]; + + let ctx = StatementContext::default(); + + for (sig, row, exp) in cases { + let children: Vec = (0..row.len()).map(|id| col_expr(id as i64)).collect(); + let mut expr = Expr::new(); + expr.set_tp(ExprType::ScalarFunc); + expr.set_sig(sig); + + expr.set_children(RepeatedField::from_vec(children)); + let e = Expression::build(expr, &ctx).unwrap(); + let res = e.eval(&ctx, &row).unwrap(); + assert_eq!(res, exp); + } + } + } diff --git a/src/coprocessor/dag/expr/fncall.rs b/src/coprocessor/dag/expr/fncall.rs index f6d1a30a133..af7ddf6c59b 100644 --- a/src/coprocessor/dag/expr/fncall.rs +++ b/src/coprocessor/dag/expr/fncall.rs @@ -12,6 +12,7 @@ // limitations under the License. use std::borrow::Cow; +use std::usize; use tipb::expression::ScalarFuncSig; @@ -174,6 +175,16 @@ impl FnCall { ScalarFuncSig::IfDecimal | ScalarFuncSig::IfTime | ScalarFuncSig::IfDuration => (3, 3), + + ScalarFuncSig::CaseWhenDecimal | + ScalarFuncSig::CaseWhenDuration | + ScalarFuncSig::CaseWhenInt | + ScalarFuncSig::CaseWhenJson | + ScalarFuncSig::CaseWhenReal | + ScalarFuncSig::CaseWhenString | + ScalarFuncSig::CaseWhenTime => (1, usize::MAX), + _ => return Err(Error::UnknownSignature(sig)), + }; if args < min_args || args > max_args { return Err(box_err!("unexpected arguments")); @@ -294,6 +305,7 @@ macro_rules! dispatch_call { $(ScalarFuncSig::$j_sig => { self.$j_func(ctx, row, $($j_arg)*).map(Datum::from) })* + _=> Err(Error::UnknownSignature(self.sig)), } } } @@ -395,6 +407,8 @@ dispatch_call! { IfNullInt => if_null_int, IfInt => if_int, + + CaseWhenInt => case_when_int, } REAL_CALLS { CastIntAsReal => cast_int_as_real, @@ -416,6 +430,8 @@ dispatch_call! { IfNullReal => if_null_real, IfReal => if_real, + + CaseWhenReal => case_when_real, } DEC_CALLS { CastIntAsDecimal => cast_int_as_decimal, @@ -439,6 +455,8 @@ dispatch_call! { IfNullDecimal => if_null_decimal, IfDecimal => if_decimal, + + CaseWhenDecimal => case_when_decimal, } BYTES_CALLS { CastIntAsString => cast_int_as_str, @@ -451,6 +469,8 @@ dispatch_call! { IfNullString => if_null_string, IfString => if_string, + + CaseWhenString => case_when_string, } TIME_CALLS { CastIntAsTime => cast_int_as_time, @@ -463,6 +483,8 @@ dispatch_call! { IfNullTime => if_null_time, IfTime => if_time, + + CaseWhenTime => case_when_time, } DUR_CALLS { CastIntAsDuration => cast_int_as_duration, @@ -475,6 +497,8 @@ dispatch_call! { IfNullDuration => if_null_duration, IfDuration => if_duration, + + CaseWhenDuration => case_when_duration, } JSON_CALLS { CastIntAsJson => cast_int_as_json, @@ -484,5 +508,7 @@ dispatch_call! { CastTimeAsJson => cast_time_as_json, CastDurationAsJson => cast_duration_as_json, CastJsonAsJson => cast_json_as_json, + + CaseWhenJson => case_when_json, } }