Skip to content

Commit

Permalink
*: check sc.IgnoreZeroInDate when parsing string or number to date/da…
Browse files Browse the repository at this point in the history
…tetime/timestamp (#4732)
  • Loading branch information
XuHuaiyu committed Oct 11, 2017
1 parent 2eab78f commit 7d2804e
Show file tree
Hide file tree
Showing 21 changed files with 245 additions and 175 deletions.
5 changes: 5 additions & 0 deletions executor/prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,17 +342,20 @@ func ResetStmtCtx(ctx context.Context, s ast.StmtNode) {
sc.TruncateAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr
sc.InUpdateOrDeleteStmt = true
sc.DividedByZeroAsWarning = stmt.IgnoreErr
sc.IgnoreZeroInDate = !sessVars.StrictSQLMode || stmt.IgnoreErr
case *ast.DeleteStmt:
sc.IgnoreTruncate = false
sc.OverflowAsWarning = false
sc.TruncateAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr
sc.InUpdateOrDeleteStmt = true
sc.DividedByZeroAsWarning = stmt.IgnoreErr
sc.IgnoreZeroInDate = !sessVars.StrictSQLMode || stmt.IgnoreErr
case *ast.InsertStmt:
sc.IgnoreTruncate = false
sc.TruncateAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr
sc.InInsertStmt = true
sc.DividedByZeroAsWarning = stmt.IgnoreErr
sc.IgnoreZeroInDate = !sessVars.StrictSQLMode || stmt.IgnoreErr
case *ast.CreateTableStmt, *ast.AlterTableStmt:
// Make sure the sql_mode is strict when checking column default value.
sc.IgnoreTruncate = false
Expand All @@ -374,6 +377,7 @@ func ResetStmtCtx(ctx context.Context, s ast.StmtNode) {
// Return warning for truncate error in selection.
sc.IgnoreTruncate = false
sc.TruncateAsWarning = true
sc.IgnoreZeroInDate = true
if opts := stmt.SelectStmtOpts; opts != nil {
sc.Priority = opts.Priority
sc.NotFillCache = !opts.SQLCache
Expand All @@ -387,6 +391,7 @@ func ResetStmtCtx(ctx context.Context, s ast.StmtNode) {
sc.SetWarnings(sessVars.StmtCtx.GetWarnings())
}
}
sc.IgnoreZeroInDate = true
}
if sessVars.LastInsertID > 0 {
sessVars.PrevLastInsertID = sessVars.LastInsertID
Expand Down
13 changes: 13 additions & 0 deletions executor/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,19 @@ func (s *testSuite) TestInsert(c *C) {
tk.MustExec("INSERT INTO t VALUES (1.000000);")
r = tk.MustQuery("SHOW WARNINGS;")
r.Check(testkit.Rows())

// issue 4653
tk.MustExec("DROP TABLE IF EXISTS t;")
tk.MustExec("CREATE TABLE t(a datetime);")
_, err = tk.Exec("INSERT INTO t VALUES('2017-00-00')")
c.Assert(err, NotNil)
tk.MustExec("set sql_mode = ''")
tk.MustExec("INSERT INTO t VALUES('2017-00-00')")
r = tk.MustQuery("SELECT * FROM t;")
r.Check(testkit.Rows("2017-00-00 00:00:00"))
tk.MustExec("set sql_mode = 'strict_all_tables';")
r = tk.MustQuery("SELECT * FROM t;")
r.Check(testkit.Rows("2017-00-00 00:00:00"))
}

func (s *testSuite) TestInsertAutoInc(c *C) {
Expand Down
32 changes: 23 additions & 9 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,13 +513,16 @@ func (b *builtinCastIntAsTimeSig) evalTime(row []types.Datum) (res types.Time, i
if isNull || err != nil {
return res, isNull, errors.Trace(err)
}
res, err = types.ParseTimeFromNum(val, b.tp.Tp, b.tp.Decimal)
res, err = types.ParseTimeFromNum(sc, val, b.tp.Tp, b.tp.Decimal)
if err != nil {
return res, true, errors.Trace(err)
}
if b.tp.Tp == mysql.TypeDate {
// Truncate hh:mm:ss part if the type is Date.
res.Time = types.FromDate(res.Time.Year(), res.Time.Month(), res.Time.Day(), 0, 0, 0, 0)
}
res.TimeZone = sc.TimeZone
return res, false, errors.Trace(err)
return res, false, nil
}

type builtinCastIntAsDurationSig struct {
Expand Down Expand Up @@ -706,7 +709,7 @@ func (b *builtinCastRealAsTimeSig) evalTime(row []types.Datum) (types.Time, bool
if isNull || err != nil {
return types.Time{}, true, errors.Trace(err)
}
res, err := types.ParseTime(strconv.FormatFloat(val, 'f', -1, 64), b.tp.Tp, b.tp.Decimal)
res, err := types.ParseTime(sc, strconv.FormatFloat(val, 'f', -1, 64), b.tp.Tp, b.tp.Decimal)
if err != nil {
return types.Time{}, true, errors.Trace(err)
}
Expand All @@ -723,7 +726,8 @@ type builtinCastRealAsDurationSig struct {
}

func (b *builtinCastRealAsDurationSig) evalDuration(row []types.Datum) (res types.Duration, isNull bool, err error) {
val, isNull, err := b.args[0].EvalReal(row, b.getCtx().GetSessionVars().StmtCtx)
sc := b.getCtx().GetSessionVars().StmtCtx
val, isNull, err := b.args[0].EvalReal(row, sc)
if isNull || err != nil {
return res, isNull, errors.Trace(err)
}
Expand Down Expand Up @@ -816,7 +820,10 @@ func (b *builtinCastDecimalAsTimeSig) evalTime(row []types.Datum) (res types.Tim
if isNull || err != nil {
return res, isNull, errors.Trace(err)
}
res, err = types.ParseTime(string(val.ToString()), b.tp.Tp, b.tp.Decimal)
res, err = types.ParseTime(sc, string(val.ToString()), b.tp.Tp, b.tp.Decimal)
if err != nil {
return res, false, errors.Trace(err)
}
if b.tp.Tp == mysql.TypeDate {
// Truncate hh:mm:ss part if the type is Date.
res.Time = types.FromDate(res.Time.Year(), res.Time.Month(), res.Time.Day(), 0, 0, 0, 0)
Expand All @@ -830,7 +837,8 @@ type builtinCastDecimalAsDurationSig struct {
}

func (b *builtinCastDecimalAsDurationSig) evalDuration(row []types.Datum) (res types.Duration, isNull bool, err error) {
val, isNull, err := b.args[0].EvalDecimal(row, b.getCtx().GetSessionVars().StmtCtx)
sc := b.getCtx().GetSessionVars().StmtCtx
val, isNull, err := b.args[0].EvalDecimal(row, sc)
if isNull || err != nil {
return res, false, errors.Trace(err)
}
Expand Down Expand Up @@ -969,7 +977,10 @@ func (b *builtinCastStringAsTimeSig) evalTime(row []types.Datum) (res types.Time
if isNull || err != nil {
return res, isNull, errors.Trace(err)
}
res, err = types.ParseTime(val, b.tp.Tp, b.tp.Decimal)
res, err = types.ParseTime(sc, val, b.tp.Tp, b.tp.Decimal)
if err != nil {
return res, false, errors.Trace(err)
}
if b.tp.Tp == mysql.TypeDate {
// Truncate hh:mm:ss part if the type is Date.
res.Time = types.FromDate(res.Time.Year(), res.Time.Month(), res.Time.Day(), 0, 0, 0, 0)
Expand Down Expand Up @@ -1006,7 +1017,7 @@ func (b *builtinCastTimeAsTimeSig) evalTime(row []types.Datum) (res types.Time,
return res, isNull, errors.Trace(err)
}

if res, err = res.Convert(b.tp.Tp); err != nil {
if res, err = res.Convert(sc, b.tp.Tp); err != nil {
return res, true, errors.Trace(err)
}
res, err = res.RoundFrac(b.tp.Decimal)
Expand Down Expand Up @@ -1265,7 +1276,10 @@ func (b *builtinCastJSONAsTimeSig) evalTime(row []types.Datum) (res types.Time,
if err != nil {
return res, false, errors.Trace(err)
}
res, err = types.ParseTime(s, b.tp.Tp, b.tp.Decimal)
res, err = types.ParseTime(sc, s, b.tp.Tp, b.tp.Decimal)
if err != nil {
return res, false, errors.Trace(err)
}
if b.tp.Tp == mysql.TypeDate {
// Truncate hh:mm:ss part if the type is Date.
res.Time = types.FromDate(res.Time.Year(), res.Time.Month(), res.Time.Day(), 0, 0, 0, 0)
Expand Down
4 changes: 2 additions & 2 deletions expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ func (b *builtinGreatestTimeSig) evalString(row []types.Datum) (_ string, isNull
if isNull || err != nil {
return "", true, errors.Trace(err)
}
t, err = types.ParseDatetime(v)
t, err = types.ParseDatetime(sc, v)
if err != nil {
if err = handleInvalidTimeError(b.ctx, err); err != nil {
return v, true, errors.Trace(err)
Expand Down Expand Up @@ -670,7 +670,7 @@ func (b *builtinLeastTimeSig) evalString(row []types.Datum) (res string, isNull
if isNull || err != nil {
return "", true, errors.Trace(err)
}
t, err = types.ParseDatetime(v)
t, err = types.ParseDatetime(sc, v)
if err != nil {
if err = handleInvalidTimeError(b.ctx, err); err != nil {
return v, true, errors.Trace(err)
Expand Down
Loading

0 comments on commit 7d2804e

Please sign in to comment.