New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

coprocessor, dag: support decode Time from tipb.Expr #2199

Merged
merged 43 commits into from Aug 28, 2017
Commits
Jump to file or symbol
Failed to load files and symbols.
+240 −13
Diff settings

Always

Just for now

Viewing a subset of changes. View all

add tests.

  • Loading branch information...
hicqu committed Aug 18, 2017
commit e184bdbd4fa2f048d89257ca1d1c9c42cc27da1f
@@ -13,8 +13,8 @@
use std::{f64, i64, u64};
use std::borrow::Cow;
use coprocessor::codec::{datum, mysql, Datum};
use coprocessor::codec::mysql::{Decimal, MAX_FSP};
use coprocessor::codec::Datum;
use coprocessor::codec::mysql::Decimal;
use super::{Error, FnCall, Result, StatementContext};
impl FnCall {
@@ -26,7 +26,7 @@ impl FnCall {
if !res.is_finite() {
return Err(Error::Overflow);
}
Ok(r)
Ok(res)
})
}
@@ -74,7 +74,7 @@ impl FnCall {
if !res.is_finite() {
return Err(Error::Overflow);
}
Ok(r)
Ok(res)
})
}
@@ -122,7 +122,7 @@ impl FnCall {
if !res.is_finite() {
return Err(Error::Overflow);
}
Ok(r)
Ok(res)
})
}
@@ -175,3 +175,194 @@ where
(Some(lhs), Some(rhs)) => op(lhs, rhs).map(|t| Some(t)),
}
}
#[cfg(test)]
mod test {
use std::{f64, i64, u64};
use tipb::expression::ScalarFuncSig;
use coprocessor::codec::Datum;
use coprocessor::dag::expr::{Error, Expression, StatementContext};
use coprocessor::dag::expr::test::{fncall_expr, str2dec};
use coprocessor::select::xeval::evaluator::test::datum_expr;
fn check_overflow(e: Error) -> Result<(), ()> {
match e {
Error::Overflow => Ok(()),
_ => Err(()),
}
}
#[test]
fn test_arithmetic() {
let tests = vec![
(
ScalarFuncSig::PlusInt,
Datum::Null,
Datum::I64(1),
Datum::Null,
),
(
ScalarFuncSig::PlusInt,
Datum::I64(1),
Datum::Null,
Datum::Null,
),
(
ScalarFuncSig::PlusInt,
Datum::I64(12),
Datum::I64(1),
Datum::I64(13),
),
(
ScalarFuncSig::PlusIntUnsigned,
Datum::U64(12),
Datum::U64(1),
Datum::U64(13),
),
(
ScalarFuncSig::PlusReal,
Datum::F64(1.01001),
Datum::F64(-0.01),
Datum::F64(1.00001),
),
(
ScalarFuncSig::PlusDecimal,
str2dec("1.1"),
str2dec("2.2"),
str2dec("3.3"),
),
];
let ctx = StatementContext::default();
for tt in tests {
let lhs = datum_expr(tt.1);
let rhs = datum_expr(tt.2);
let expected = Expression::build(datum_expr(tt.3), 0).unwrap();
let op = Expression::build(fncall_expr(tt.0, &[lhs, rhs]), 0).unwrap();
match tt.0 {
ScalarFuncSig::PlusInt |
ScalarFuncSig::MinusInt |
ScalarFuncSig::MultiplyInt |
ScalarFuncSig::PlusIntUnsigned |
ScalarFuncSig::MinusIntUnsigned |
ScalarFuncSig::MultiplyIntUnsigned => {
let lhs = op.eval_int(&ctx, &[]).unwrap();
let rhs = expected.eval_int(&ctx, &[]).unwrap();
assert_eq!(lhs, rhs);
}
ScalarFuncSig::PlusReal | ScalarFuncSig::MinusReal => {
let lhs = op.eval_real(&ctx, &[]).unwrap();
let rhs = expected.eval_real(&ctx, &[]).unwrap();
assert_eq!(lhs, rhs);
}
ScalarFuncSig::MultiplyReal => {
let lhs = op.eval_real(&ctx, &[]).unwrap();
let rhs = expected.eval_real(&ctx, &[]).unwrap();
assert_eq!(lhs, rhs);
}
ScalarFuncSig::PlusDecimal |
ScalarFuncSig::MinusDecimal |
ScalarFuncSig::MultiplyDecimal => {
let lhs = op.eval_decimal(&ctx, &[]).unwrap();
let rhs = expected.eval_decimal(&ctx, &[]).unwrap();
assert_eq!(lhs, rhs);
}
_ => unreachable!(),
}
}
}
#[test]
fn test_arithmetic_overflow() {
let tests = vec![
(
ScalarFuncSig::PlusInt,
Datum::I64(i64::MAX),
Datum::I64(i64::MAX),
),
(
ScalarFuncSig::PlusInt,
Datum::I64(i64::MIN),
Datum::I64(i64::MIN),
),
(
ScalarFuncSig::PlusIntUnsigned,
Datum::U64(u64::MAX),
Datum::U64(u64::MAX),
),
(
ScalarFuncSig::PlusReal,
Datum::F64(f64::MAX),
Datum::F64(f64::MAX),
),
(
ScalarFuncSig::MinusInt,
Datum::I64(i64::MIN),
Datum::I64(i64::MAX),
),
(
ScalarFuncSig::MinusInt,
Datum::I64(i64::MAX),
Datum::I64(i64::MIN),
),
(
ScalarFuncSig::MinusIntUnsigned,
Datum::U64(1u64),
Datum::U64(2u64),
),
(
ScalarFuncSig::MinusReal,
Datum::F64(f64::MIN),
Datum::F64(f64::MAX),
),
(
ScalarFuncSig::MultiplyInt,
Datum::I64(i64::MIN),
Datum::I64(i64::MAX),
),
(
ScalarFuncSig::MultiplyIntUnsigned,
Datum::U64(u64::MAX),
Datum::U64(u64::MAX),
),
(
ScalarFuncSig::MultiplyReal,
Datum::F64(f64::MIN),
Datum::F64(f64::MAX),
),
];
let ctx = StatementContext::default();
for tt in tests {
let lhs = datum_expr(tt.1);
let rhs = datum_expr(tt.2);
let op = Expression::build(fncall_expr(tt.0, &[lhs, rhs]), 0).unwrap();
match tt.0 {
ScalarFuncSig::PlusInt |
ScalarFuncSig::MinusInt |
ScalarFuncSig::MultiplyInt |
ScalarFuncSig::PlusIntUnsigned |
ScalarFuncSig::MinusIntUnsigned |
ScalarFuncSig::MultiplyIntUnsigned => {
let lhs = op.eval_int(&ctx, &[]).unwrap_err();
assert!(check_overflow(lhs).is_ok());
}
ScalarFuncSig::PlusReal | ScalarFuncSig::MinusReal => {
let lhs = op.eval_real(&ctx, &[]).unwrap_err();
assert!(check_overflow(lhs).is_ok());
}
ScalarFuncSig::MultiplyReal => {
let lhs = op.eval_real(&ctx, &[]).unwrap_err();
assert!(check_overflow(lhs).is_ok());
}
ScalarFuncSig::PlusDecimal |
ScalarFuncSig::MinusDecimal |
ScalarFuncSig::MultiplyDecimal => {
let lhs = op.eval_decimal(&ctx, &[]).unwrap_err();
assert!(check_overflow(lhs).is_ok());
}
_ => unreachable!(),
}
}
}
}
@@ -18,6 +18,36 @@ impl FnCall {
pub fn check_args(sig: ScalarFuncSig, args: usize) -> Result<()> {
let (min_args, max_args) = match sig {
ScalarFuncSig::LTInt => (2, 2),
ScalarFuncSig::LEInt => (2, 2),
ScalarFuncSig::GTInt => (2, 2),
ScalarFuncSig::GEInt => (2, 2),
ScalarFuncSig::EQInt => (2, 2),
ScalarFuncSig::NEInt => (2, 2),
ScalarFuncSig::NullEQInt => (2, 2),
ScalarFuncSig::LTReal => (2, 2),
ScalarFuncSig::LEReal => (2, 2),
ScalarFuncSig::GTReal => (2, 2),
ScalarFuncSig::GEReal => (2, 2),
ScalarFuncSig::EQReal => (2, 2),
ScalarFuncSig::NEReal => (2, 2),
ScalarFuncSig::NullEQReal => (2, 2),
ScalarFuncSig::PlusReal => (2, 2),
ScalarFuncSig::PlusDecimal => (2, 2),
ScalarFuncSig::PlusIntUnsigned => (2, 2),
ScalarFuncSig::PlusInt => (2, 2),
ScalarFuncSig::MinusReal => (2, 2),
ScalarFuncSig::MinusDecimal => (2, 2),
ScalarFuncSig::MinusIntUnsigned => (2, 2),
ScalarFuncSig::MinusInt => (2, 2),
ScalarFuncSig::MultiplyReal => (2, 2),
ScalarFuncSig::MultiplyDecimal => (2, 2),
ScalarFuncSig::MultiplyIntUnsigned => (2, 2),
ScalarFuncSig::MultiplyInt => (2, 2),
ScalarFuncSig::CastIntAsInt => (1, 1),
_ => unimplemented!(),
};
@@ -201,6 +201,8 @@ impl Expression {
ScalarFuncSig::PlusIntUnsigned => f.plus_uint(ctx, row),
ScalarFuncSig::MinusInt => f.minus_int(ctx, row),
ScalarFuncSig::MinusIntUnsigned => f.minus_uint(ctx, row),
ScalarFuncSig::MultiplyInt => f.multiply_int(ctx, row),
ScalarFuncSig::MultiplyIntUnsigned => f.multiply_uint(ctx, row),
_ => Err(Error::Other("Unknown signature")),
},
@@ -214,6 +216,7 @@ impl Expression {
Expression::ScalarFn(ref f) => match f.sig {
ScalarFuncSig::PlusReal => f.plus_real(ctx, row),
ScalarFuncSig::MinusReal => f.minus_real(ctx, row),
ScalarFuncSig::MultiplyReal => f.multiply_real(ctx, row),
_ => unimplemented!(),
},
}
@@ -230,6 +233,7 @@ impl Expression {
Expression::ScalarFn(ref f) => match f.sig {
ScalarFuncSig::PlusDecimal => f.plus_decimal(ctx, row),
ScalarFuncSig::MinusDecimal => f.minus_decimal(ctx, row),
ScalarFuncSig::MultiplyDecimal => f.multiply_decimal(ctx, row),
_ => unimplemented!(),
},
}
@@ -371,15 +375,21 @@ impl Expression {
#[cfg(test)]
mod test {
use coprocessor::codec::Datum;
use coprocessor::codec::mysql::Decimal;
use coprocessor::select::xeval::evaluator::test::{col_expr, datum_expr};
use tipb::expression::{Expr, ExprType, FieldType, ScalarFuncSig};
use super::Expression;
pub fn fncall_expr(sig: ScalarFuncSig, ft: FieldType, children: &[Expr]) -> Expr {
#[inline]
pub fn str2dec(s: &str) -> Datum {
Datum::Dec(s.parse::<Decimal>().unwrap())
}
pub fn fncall_expr(sig: ScalarFuncSig, children: &[Expr]) -> Expr {
let mut expr = Expr::new();
expr.set_tp(ExprType::ScalarFunc);
expr.set_sig(sig);
expr.set_field_type(ft);
expr.set_field_type(FieldType::new());
for child in children {
expr.mut_children().push(child.clone());
}
@@ -396,16 +406,12 @@ mod test {
(colref.clone(), 2, true),
(constant.clone(), 0, true),
(
fncall_expr(
ScalarFuncSig::LTInt,
FieldType::new(),
&[colref.clone(), constant.clone()],
),
fncall_expr(ScalarFuncSig::LTInt, &[colref.clone(), constant.clone()]),
2,
true,
),
(
fncall_expr(ScalarFuncSig::LTInt, FieldType::new(), &[colref.clone()]),
fncall_expr(ScalarFuncSig::LTInt, &[colref.clone()]),
0,
false,
),
ProTip! Use n and p to navigate between commits in a pull request.