Skip to content

Commit

Permalink
*: implement tidb_bounded_staleness built-in function (#24328)
Browse files Browse the repository at this point in the history
  • Loading branch information
JmPotato committed May 18, 2021
1 parent 9148ff9 commit e9488ce
Show file tree
Hide file tree
Showing 17 changed files with 421 additions and 21 deletions.
5 changes: 3 additions & 2 deletions executor/show_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1102,9 +1102,10 @@ func (s *testSuite5) TestShowBuiltin(c *C) {
res := tk.MustQuery("show builtins;")
c.Assert(res, NotNil)
rows := res.Rows()
c.Assert(268, Equals, len(rows))
const builtinFuncNum = 269
c.Assert(builtinFuncNum, Equals, len(rows))
c.Assert("abs", Equals, rows[0][0].(string))
c.Assert("yearweek", Equals, rows[267][0].(string))
c.Assert("yearweek", Equals, rows[builtinFuncNum-1][0].(string))
}

func (s *testSuite5) TestShowClusterConfig(c *C) {
Expand Down
4 changes: 3 additions & 1 deletion expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,9 @@ var funcs = map[string]functionClass{
ast.Year: &yearFunctionClass{baseFunctionClass{ast.Year, 1, 1}},
ast.YearWeek: &yearWeekFunctionClass{baseFunctionClass{ast.YearWeek, 1, 2}},
ast.LastDay: &lastDayFunctionClass{baseFunctionClass{ast.LastDay, 1, 1}},
// TSO functions
ast.TiDBBoundedStaleness: &tidbBoundedStalenessFunctionClass{baseFunctionClass{ast.TiDBBoundedStaleness, 2, 2}},
ast.TiDBParseTso: &tidbParseTsoFunctionClass{baseFunctionClass{ast.TiDBParseTso, 1, 1}},

// string functions
ast.ASCII: &asciiFunctionClass{baseFunctionClass{ast.ASCII, 1, 1}},
Expand Down Expand Up @@ -881,7 +884,6 @@ var funcs = map[string]functionClass{
// This function is used to show tidb-server version info.
ast.TiDBVersion: &tidbVersionFunctionClass{baseFunctionClass{ast.TiDBVersion, 0, 0}},
ast.TiDBIsDDLOwner: &tidbIsDDLOwnerFunctionClass{baseFunctionClass{ast.TiDBIsDDLOwner, 0, 0}},
ast.TiDBParseTso: &tidbParseTsoFunctionClass{baseFunctionClass{ast.TiDBParseTso, 1, 1}},
ast.TiDBDecodePlan: &tidbDecodePlanFunctionClass{baseFunctionClass{ast.TiDBDecodePlan, 1, 1}},

// TiDB Sequence function.
Expand Down
95 changes: 95 additions & 0 deletions expression/builtin_time.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (

"github.com/cznic/mathutil"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
Expand Down Expand Up @@ -7113,3 +7114,97 @@ func handleInvalidZeroTime(ctx sessionctx.Context, t types.Time) (bool, error) {
}
return true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, t.String()))
}

// tidbBoundedStalenessFunctionClass reads a time window [a, b] and compares it with the latest SafeTS
// to determine which TS to use in a read only transaction.
type tidbBoundedStalenessFunctionClass struct {
baseFunctionClass
}

func (c *tidbBoundedStalenessFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, err
}
bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, types.ETDatetime, types.ETDatetime)
if err != nil {
return nil, err
}
sig := &builtinTiDBBoundedStalenessSig{bf}
return sig, nil
}

type builtinTiDBBoundedStalenessSig struct {
baseBuiltinFunc
}

func (b *builtinTiDBBoundedStalenessSig) Clone() builtinFunc {
newSig := &builtinTidbParseTsoSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinTiDBBoundedStalenessSig) evalTime(row chunk.Row) (types.Time, bool, error) {
leftTime, isNull, err := b.args[0].EvalTime(b.ctx, row)
if isNull || err != nil {
return types.ZeroTime, true, handleInvalidTimeError(b.ctx, err)
}
rightTime, isNull, err := b.args[1].EvalTime(b.ctx, row)
if isNull || err != nil {
return types.ZeroTime, true, handleInvalidTimeError(b.ctx, err)
}
if invalidLeftTime, invalidRightTime := leftTime.InvalidZero(), rightTime.InvalidZero(); invalidLeftTime || invalidRightTime {
if invalidLeftTime {
err = handleInvalidTimeError(b.ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, leftTime.String()))
}
if invalidRightTime {
err = handleInvalidTimeError(b.ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, rightTime.String()))
}
return types.ZeroTime, true, err
}
timeZone := getTimeZone(b.ctx)
minTime, err := leftTime.GoTime(timeZone)
if err != nil {
return types.ZeroTime, true, err
}
maxTime, err := rightTime.GoTime(timeZone)
if err != nil {
return types.ZeroTime, true, err
}
if minTime.After(maxTime) {
return types.ZeroTime, true, nil
}
// Because the minimum unit of a TSO is millisecond, so we only need fsp to be 3.
return types.NewTime(types.FromGoTime(calAppropriateTime(minTime, maxTime, getMinSafeTime(b.ctx, timeZone))), mysql.TypeDatetime, 3), false, nil
}

func getMinSafeTime(sessionCtx sessionctx.Context, timeZone *time.Location) time.Time {
var minSafeTS uint64
if store := sessionCtx.GetStore(); store != nil {
minSafeTS = store.GetMinSafeTS(sessionCtx.GetSessionVars().CheckAndGetTxnScope())
}
// Inject mocked SafeTS for test.
failpoint.Inject("injectSafeTS", func(val failpoint.Value) {
injectTS := val.(int)
minSafeTS = uint64(injectTS)
})
// Try to get from the stmt cache to make sure this function is deterministic.
stmtCtx := sessionCtx.GetSessionVars().StmtCtx
minSafeTS = stmtCtx.GetOrStoreStmtCache(stmtctx.StmtSafeTSCacheKey, minSafeTS).(uint64)
return oracle.GetTimeFromTS(minSafeTS).In(timeZone)
}

// For a SafeTS t and a time range [t1, t2]:
// 1. If t < t1, we will use t1 as the result,
// and with it, a read request may fail because it's an unreached SafeTS.
// 2. If t1 <= t <= t2, we will use t as the result, and with it,
// a read request won't fail.
// 2. If t2 < t, we will use t2 as the result,
// and with it, a read request won't fail because it's bigger than the latest SafeTS.
func calAppropriateTime(minTime, maxTime, minSafeTime time.Time) time.Time {
if minSafeTime.Before(minTime) {
return minTime
} else if minSafeTime.After(maxTime) {
return maxTime
}
return minSafeTime
}
102 changes: 101 additions & 1 deletion expression/builtin_time_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,22 @@
package expression

import (
"fmt"
"math"
"strings"
"time"

. "github.com/pingcap/check"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/charset"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/store/tikv/oracle"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/mock"
Expand Down Expand Up @@ -804,7 +807,7 @@ func (s *testEvaluatorSuite) TestTime(c *C) {
}

func resetStmtContext(ctx sessionctx.Context) {
ctx.GetSessionVars().StmtCtx.ResetNowTs()
ctx.GetSessionVars().StmtCtx.ResetStmtCache()
}

func (s *testEvaluatorSuite) TestNowAndUTCTimestamp(c *C) {
Expand Down Expand Up @@ -2854,6 +2857,103 @@ func (s *testEvaluatorSuite) TestTidbParseTso(c *C) {
}
}

func (s *testEvaluatorSuite) TestTiDBBoundedStaleness(c *C) {
t1, err := time.Parse(types.TimeFormat, "2015-09-21 09:53:04")
c.Assert(err, IsNil)
// time.Parse uses UTC time zone by default, we need to change it to Local manually.
t1 = t1.Local()
t1Str := t1.Format(types.TimeFormat)
t2 := time.Now()
t2Str := t2.Format(types.TimeFormat)
timeZone := time.Local
s.ctx.GetSessionVars().TimeZone = timeZone
tests := []struct {
leftTime interface{}
rightTime interface{}
injectSafeTS uint64
isNull bool
expect time.Time
}{
// SafeTS is in the range.
{
leftTime: t1Str,
rightTime: t2Str,
injectSafeTS: oracle.GoTimeToTS(t2.Add(-1 * time.Second)),
isNull: false,
expect: t2.Add(-1 * time.Second),
},
// SafeTS is less than the left time.
{
leftTime: t1Str,
rightTime: t2Str,
injectSafeTS: oracle.GoTimeToTS(t1.Add(-1 * time.Second)),
isNull: false,
expect: t1,
},
// SafeTS is bigger than the right time.
{
leftTime: t1Str,
rightTime: t2Str,
injectSafeTS: oracle.GoTimeToTS(t2.Add(time.Second)),
isNull: false,
expect: t2,
},
// Wrong time order.
{
leftTime: t2Str,
rightTime: t1Str,
injectSafeTS: 0,
isNull: true,
expect: time.Time{},
},
}

fc := funcs[ast.TiDBBoundedStaleness]
for _, test := range tests {
c.Assert(failpoint.Enable("github.com/pingcap/tidb/expression/injectSafeTS",
fmt.Sprintf("return(%v)", test.injectSafeTS)), IsNil)
f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{types.NewDatum(test.leftTime), types.NewDatum(test.rightTime)}))
c.Assert(err, IsNil)
d, err := evalBuiltinFunc(f, chunk.Row{})
c.Assert(err, IsNil)
if test.isNull {
c.Assert(d.IsNull(), IsTrue)
} else {
goTime, err := d.GetMysqlTime().GoTime(timeZone)
c.Assert(err, IsNil)
c.Assert(goTime.Format(types.TimeFormat), Equals, test.expect.Format(types.TimeFormat))
}
resetStmtContext(s.ctx)
}

// Test whether it's deterministic.
safeTime1 := t2.Add(-1 * time.Second)
safeTS1 := oracle.ComposeTS(safeTime1.Unix()*1000, 0)
c.Assert(failpoint.Enable("github.com/pingcap/tidb/expression/injectSafeTS",
fmt.Sprintf("return(%v)", safeTS1)), IsNil)
f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{types.NewDatum(t1Str), types.NewDatum(t2Str)}))
c.Assert(err, IsNil)
d, err := evalBuiltinFunc(f, chunk.Row{})
c.Assert(err, IsNil)
goTime, err := d.GetMysqlTime().GoTime(timeZone)
c.Assert(err, IsNil)
resultTime := goTime.Format(types.TimeFormat)
c.Assert(resultTime, Equals, safeTime1.Format(types.TimeFormat))
// SafeTS updated.
safeTime2 := t2.Add(1 * time.Second)
safeTS2 := oracle.ComposeTS(safeTime2.Unix()*1000, 0)
c.Assert(failpoint.Enable("github.com/pingcap/tidb/expression/injectSafeTS",
fmt.Sprintf("return(%v)", safeTS2)), IsNil)
f, err = fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{types.NewDatum(t1Str), types.NewDatum(t2Str)}))
c.Assert(err, IsNil)
d, err = evalBuiltinFunc(f, chunk.Row{})
c.Assert(err, IsNil)
// Still safeTime1
c.Assert(resultTime, Equals, safeTime1.Format(types.TimeFormat))
resetStmtContext(s.ctx)
failpoint.Disable("github.com/pingcap/tidb/expression/injectSafeTS")
}

func (s *testEvaluatorSuite) TestGetIntervalFromDecimal(c *C) {
du := baseDateArithmitical{}

Expand Down
64 changes: 64 additions & 0 deletions expression/builtin_time_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,70 @@ func (b *builtinTidbParseTsoSig) vecEvalTime(input *chunk.Chunk, result *chunk.C
return nil
}

func (b *builtinTiDBBoundedStalenessSig) vectorized() bool {
return true
}

func (b *builtinTiDBBoundedStalenessSig) vecEvalTime(input *chunk.Chunk, result *chunk.Column) error {
n := input.NumRows()
buf0, err := b.bufAllocator.get(types.ETDatetime, n)
if err != nil {
return err
}
defer b.bufAllocator.put(buf0)
if err = b.args[0].VecEvalTime(b.ctx, input, buf0); err != nil {
return err
}
buf1, err := b.bufAllocator.get(types.ETDatetime, n)
if err != nil {
return err
}
defer b.bufAllocator.put(buf1)
if err = b.args[1].VecEvalTime(b.ctx, input, buf1); err != nil {
return err
}
args0 := buf0.Times()
args1 := buf1.Times()
timeZone := getTimeZone(b.ctx)
minSafeTime := getMinSafeTime(b.ctx, timeZone)
result.ResizeTime(n, false)
result.MergeNulls(buf0, buf1)
times := result.Times()
for i := 0; i < n; i++ {
if result.IsNull(i) {
continue
}
if invalidArg0, invalidArg1 := args0[i].InvalidZero(), args1[i].InvalidZero(); invalidArg0 || invalidArg1 {
if invalidArg0 {
err = handleInvalidTimeError(b.ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, args0[i].String()))
}
if invalidArg1 {
err = handleInvalidTimeError(b.ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, args1[i].String()))
}
if err != nil {
return err
}
result.SetNull(i, true)
continue
}
minTime, err := args0[i].GoTime(timeZone)
if err != nil {
return err
}
maxTime, err := args1[i].GoTime(timeZone)
if err != nil {
return err
}
if minTime.After(maxTime) {
result.SetNull(i, true)
continue
}
// Because the minimum unit of a TSO is millisecond, so we only need fsp to be 3.
times[i] = types.NewTime(types.FromGoTime(calAppropriateTime(minTime, maxTime, minSafeTime)), mysql.TypeDatetime, 3)
}
return nil
}

func (b *builtinFromDaysSig) vectorized() bool {
return true
}
Expand Down
7 changes: 7 additions & 0 deletions expression/builtin_time_vec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,13 @@ var vecBuiltinTimeCases = map[string][]vecExprBenchCase{
geners: []dataGenerator{newRangeInt64Gener(0, math.MaxInt64)},
},
},
// Todo: how to inject the safeTS for better testing.
ast.TiDBBoundedStaleness: {
{
retEvalType: types.ETDatetime,
childrenTypes: []types.EvalType{types.ETDatetime, types.ETDatetime},
},
},
ast.LastDay: {
{retEvalType: types.ETDatetime, childrenTypes: []types.EvalType{types.ETDatetime}},
},
Expand Down
3 changes: 2 additions & 1 deletion expression/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
driver "github.com/pingcap/tidb/types/parser_driver"
Expand Down Expand Up @@ -155,5 +156,5 @@ func getStmtTimestamp(ctx sessionctx.Context) (time.Time, error) {
return time.Unix(timestamp, 0), nil
}
stmtCtx := ctx.GetSessionVars().StmtCtx
return stmtCtx.GetNowTsCached(), nil
return stmtCtx.GetOrStoreStmtCache(stmtctx.StmtNowTsCacheKey, time.Now()).(time.Time), nil
}
Loading

0 comments on commit e9488ce

Please sign in to comment.