Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: fix wrong `DATE/DATETIME` comparison in `BETWEEN` function #10313

Merged
merged 6 commits into from May 6, 2019
Merged
Changes from 1 commit
Commits
File filter...
Filter file types
Jump to…
Jump to file or symbol
Failed to load files and symbols.

Always

Just for now

use GetCmpTp4MinMax

  • Loading branch information...
erjiaqing committed May 5, 2019
commit d44fe55a7658b9bc89cf3f1d19c7d857896bed75
@@ -430,12 +430,3 @@ Projection_7 10000.00 root 6_aux_0
│ └─TableScan_9 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo
└─TableReader_12 10000.00 root data:TableScan_11
└─TableScan_11 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo
drop table if exists t1;
create table t1 (f2 datetime);
insert into t1 values('2001-01-01 01:01:01');
select f2 from t1 where '2001-04-10 12:34:56' between f2 and '01-05-01';
f2
2001-01-01 01:01:01
select 1 from t1 where 20010410123456 between cast('2001-01-01 12:34:56' as datetime) and 010501;
1
drop table t1;
@@ -201,13 +201,3 @@ explain select a in (select a+b from t t2 where t2.b = t1.b) from t t1;
drop table t;
create table t(a int not null, b int);
explain select a in (select a from t t2 where t2.b = t1.b) from t t1;

# issue 9764
drop table if exists t1;
create table t1 (f2 datetime);
insert into t1 values('2001-01-01 01:01:01');
# convert string to DATETIME if one of fields in between is DATETIME
select f2 from t1 where '2001-04-10 12:34:56' between f2 and '01-05-01';
# do not do that if one of them is int, convert them to int instead
select 1 from t1 where 20010410123456 between cast('2001-01-01 12:34:56' as datetime) and 010501;
drop table t1;
@@ -379,8 +379,8 @@ func temporalWithDateAsNumEvalType(argTp *types.FieldType) (argEvalType types.Ev
return
}

// getCmpTp4MinMax gets compare type for GREATEST and LEAST.
func getCmpTp4MinMax(args []Expression) (argTp types.EvalType) {
// GetCmpTp4MinMax gets compare type for GREATEST and LEAST and BETWEEN (mainly for datetime).
func GetCmpTp4MinMax(args []Expression) (argTp types.EvalType) {
datetimeFound, isAllStr := false, true
cmpEvalType, isStr, isTemporalWithDate := temporalWithDateAsNumEvalType(args[0].GetType())
if !isStr {
@@ -421,7 +421,7 @@ func (c *greatestFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
if err = c.verifyArgs(args); err != nil {
return nil, err
}
tp, cmpAsDatetime := getCmpTp4MinMax(args), false
tp, cmpAsDatetime := GetCmpTp4MinMax(args), false
if tp == types.ETDatetime {
cmpAsDatetime = true
tp = types.ETString
@@ -615,7 +615,7 @@ func (c *leastFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi
if err = c.verifyArgs(args); err != nil {
return nil, err
}
tp, cmpAsDatetime := getCmpTp4MinMax(args), false
tp, cmpAsDatetime := GetCmpTp4MinMax(args), false
if tp == types.ETDatetime {
cmpAsDatetime = true
tp = types.ETString
@@ -954,18 +954,9 @@ type compareFunctionClass struct {
op opcode.Op
}

// AggCmpType aggregates extends getBaseCmpType to get type when comparing three or more fields.
func AggCmpType(fields ...*types.FieldType) types.EvalType {
ret := fields[0].EvalType()
for i := 1; i < len(fields); i++ {
ret = getBaseCmpType(ret, fields[i].EvalType(), nil, fields[i])
}
return ret
}

// getBaseCmpType gets the EvalType that the two args will be treated as when comparing.
func getBaseCmpType(lhs, rhs types.EvalType, lft, rft *types.FieldType) types.EvalType {
if lft != nil && rft != nil && (lft.Tp == mysql.TypeUnspecified || rft.Tp == mysql.TypeUnspecified) {
if lft.Tp == mysql.TypeUnspecified || rft.Tp == mysql.TypeUnspecified {
if lft.Tp == rft.Tp {
return types.ETString
}
@@ -977,10 +968,10 @@ func getBaseCmpType(lhs, rhs types.EvalType, lft, rft *types.FieldType) types.Ev
}
if lhs.IsStringKind() && rhs.IsStringKind() {
return types.ETString
} else if (lhs == types.ETInt || (lft != nil && lft.Hybrid())) && (rhs == types.ETInt || (rft != nil && rft.Hybrid())) {
} else if (lhs == types.ETInt || lft.Hybrid()) && (rhs == types.ETInt || rft.Hybrid()) {
return types.ETInt
} else if ((lhs == types.ETInt || (lft != nil && lft.Hybrid())) || lhs == types.ETDecimal) &&
((rhs == types.ETInt || (rft != nil && rft.Hybrid())) || rhs == types.ETDecimal) {
} else if ((lhs == types.ETInt || lft.Hybrid()) || lhs == types.ETDecimal) &&
((rhs == types.ETInt || rft.Hybrid()) || rhs == types.ETDecimal) {
return types.ETDecimal
}
return types.ETReal
@@ -1262,19 +1262,10 @@ func (er *expressionRewriter) betweenToExpression(v *ast.BetweenExpr) {

expr, lexp, rexp := er.ctxStack[stkLen-3], er.ctxStack[stkLen-2], er.ctxStack[stkLen-1]

if expression.AggCmpType(expr.GetType(), lexp.GetType(), rexp.GetType()) == types.ETString {
containsDateTime := false
for _, v := range []expression.Expression{expr, lexp, rexp} {
if v.GetType().EvalType() == types.ETDatetime {
containsDateTime = true
break
}
}
if containsDateTime {
expr = expression.WrapWithCastAsTime(er.ctx, expr, types.NewFieldType(mysql.TypeDatetime))
lexp = expression.WrapWithCastAsTime(er.ctx, lexp, types.NewFieldType(mysql.TypeDatetime))
rexp = expression.WrapWithCastAsTime(er.ctx, rexp, types.NewFieldType(mysql.TypeDatetime))
}
if expression.GetCmpTp4MinMax([]expression.Expression{expr, lexp, rexp}) == types.ETDatetime {
expr = expression.WrapWithCastAsTime(er.ctx, expr, types.NewFieldType(mysql.TypeDatetime))
lexp = expression.WrapWithCastAsTime(er.ctx, lexp, types.NewFieldType(mysql.TypeDatetime))
rexp = expression.WrapWithCastAsTime(er.ctx, rexp, types.NewFieldType(mysql.TypeDatetime))
}

var op string
ProTip! Use n and p to navigate between commits in a pull request.
You can’t perform that action at this time.