Skip to content

Commit

Permalink
expression: use assertionEvalContext to check some constraints in r…
Browse files Browse the repository at this point in the history
…untime (#52141)

close #52140
  • Loading branch information
lcwangchao committed Apr 1, 2024
1 parent c3fa9ac commit fdcecc4
Show file tree
Hide file tree
Showing 16 changed files with 213 additions and 103 deletions.
2 changes: 1 addition & 1 deletion pkg/expression/builtin_arithmetic_vec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ func TestVectorizedDecimalErrOverflow(t *testing.T) {
baseFunc, err := funcs[tt.funcName].getFunction(ctx, cols)
require.NoError(t, err)
result := chunk.NewColumn(eType2FieldType(types.ETDecimal), 1)
err = baseFunc.vecEvalDecimal(ctx, input, result)
err = vecEvalType(ctx, baseFunc, types.ETDecimal, input, result)
require.EqualError(t, err, tt.errStr)
}
}
Expand Down
85 changes: 49 additions & 36 deletions pkg/expression/builtin_cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,10 +353,11 @@ func TestCastFuncSig(t *testing.T) {
case 5:
sig = &builtinCastDecimalAsDecimalSig{decFunc}
}
res, isNull, err := sig.evalDecimal(ctx, c.row.ToRow())
require.Equal(t, false, isNull)
res, err := evalBuiltinFunc(sig, ctx, c.row.ToRow())
require.NoError(t, err)
require.Equal(t, 0, res.Compare(c.after))
require.False(t, res.IsNull())
require.Equal(t, types.KindMysqlDecimal, res.Kind())
require.Equal(t, 0, res.GetMysqlDecimal().Compare(c.after))
}

durationColumn.RetType.SetDecimal(1)
Expand Down Expand Up @@ -439,10 +440,11 @@ func TestCastFuncSig(t *testing.T) {
case 5:
sig = &builtinCastDecimalAsDecimalSig{decFunc}
}
res, isNull, err := sig.evalDecimal(ctx, c.row.ToRow())
require.Equal(t, false, isNull)
res, err := evalBuiltinFunc(sig, ctx, c.row.ToRow())
require.NoError(t, err)
require.Equal(t, c.after.ToString(), res.ToString())
require.False(t, res.IsNull())
require.Equal(t, types.KindMysqlDecimal, res.Kind())
require.Equal(t, c.after.ToString(), res.GetMysqlDecimal().ToString())
}

durationColumn.RetType.SetDecimal(0)
Expand Down Expand Up @@ -522,10 +524,11 @@ func TestCastFuncSig(t *testing.T) {
case 6:
sig = &builtinCastJSONAsIntSig{intFunc}
}
res, isNull, err := sig.evalInt(ctx, c.row.ToRow())
require.False(t, isNull)
res, err := evalBuiltinFunc(sig, ctx, c.row.ToRow())
require.NoError(t, err)
require.Equal(t, c.after, res)
require.False(t, res.IsNull())
require.Equal(t, types.KindInt64, res.Kind())
require.Equal(t, c.after, res.GetInt64())
}

// Test cast as real.
Expand Down Expand Up @@ -590,10 +593,12 @@ func TestCastFuncSig(t *testing.T) {
case 5:
sig = &builtinCastJSONAsRealSig{realFunc}
}
res, isNull, err := sig.evalReal(ctx, c.row.ToRow())
require.False(t, isNull)

res, err := evalBuiltinFunc(sig, ctx, c.row.ToRow())
require.NoError(t, err)
require.Equal(t, c.after, res)
require.False(t, res.IsNull())
require.Equal(t, types.KindFloat64, res.Kind())
require.Equal(t, c.after, res.GetFloat64())
}

// Test cast as string.
Expand Down Expand Up @@ -667,10 +672,11 @@ func TestCastFuncSig(t *testing.T) {
case 6:
sig = &builtinCastStringAsStringSig{stringFunc}
}
res, isNull, err := sig.evalString(ctx, c.row.ToRow())
require.False(t, isNull)
res, err := evalBuiltinFunc(sig, ctx, c.row.ToRow())
require.NoError(t, err)
require.Equal(t, c.after, res)
require.False(t, res.IsNull())
require.Equal(t, types.KindString, res.Kind())
require.Equal(t, c.after, res.GetString())
}

// Test cast as string.
Expand Down Expand Up @@ -754,10 +760,11 @@ func TestCastFuncSig(t *testing.T) {
case 6:
sig = &builtinCastJSONAsStringSig{stringFunc}
}
res, isNull, err := sig.evalString(ctx, c.row.ToRow())
require.False(t, isNull)
res, err := evalBuiltinFunc(sig, ctx, c.row.ToRow())
require.NoError(t, err)
require.Equal(t, c.after, res)
require.False(t, res.IsNull())
require.Equal(t, types.KindString, res.Kind())
require.Equal(t, c.after, res.GetString())
}

castToTimeCases := []struct {
Expand Down Expand Up @@ -830,10 +837,11 @@ func TestCastFuncSig(t *testing.T) {
case 6:
sig = &builtinCastTimeAsTimeSig{timeFunc}
}
res, isNull, err := sig.evalTime(ctx, c.row.ToRow())
res, err := evalBuiltinFunc(sig, ctx, c.row.ToRow())
require.NoError(t, err)
require.False(t, isNull)
require.Equal(t, c.after.String(), res.String())
require.False(t, res.IsNull())
require.Equal(t, types.KindMysqlTime, res.Kind())
require.Equal(t, c.after.String(), res.GetMysqlTime().String())
}

castToTimeCases2 := []struct {
Expand Down Expand Up @@ -912,17 +920,18 @@ func TestCastFuncSig(t *testing.T) {
case 5:
sig = &builtinCastTimeAsTimeSig{timeFunc}
}
res, isNull, err := sig.evalTime(ctx, c.row.ToRow())
require.Equal(t, false, isNull)
res, err := evalBuiltinFunc(sig, ctx, c.row.ToRow())
require.NoError(t, err)
require.False(t, res.IsNull())
require.Equal(t, types.KindMysqlTime, res.Kind())
resAfter := c.after.String()
if c.fsp > 0 {
resAfter += "."
for i := 0; i < c.fsp; i++ {
resAfter += "0"
}
}
require.Equal(t, resAfter, res.String())
require.Equal(t, resAfter, res.GetMysqlTime().String())
}

castToDurationCases := []struct {
Expand Down Expand Up @@ -995,10 +1004,12 @@ func TestCastFuncSig(t *testing.T) {
case 6:
sig = &builtinCastDurationAsDurationSig{durationFunc}
}
res, isNull, err := sig.evalDuration(ctx, c.row.ToRow())
require.False(t, isNull)

res, err := evalBuiltinFunc(sig, ctx, c.row.ToRow())
require.NoError(t, err)
require.Equal(t, c.after.String(), res.String())
require.False(t, res.IsNull())
require.Equal(t, types.KindMysqlDuration, res.Kind())
require.Equal(t, c.after.String(), res.GetMysqlDuration().String())
}

castToDurationCases2 := []struct {
Expand Down Expand Up @@ -1070,17 +1081,18 @@ func TestCastFuncSig(t *testing.T) {
case 5:
sig = &builtinCastDurationAsDurationSig{durationFunc}
}
res, isNull, err := sig.evalDuration(ctx, c.row.ToRow())
require.False(t, isNull)
res, err := evalBuiltinFunc(sig, ctx, c.row.ToRow())
require.NoError(t, err)
require.False(t, res.IsNull())
require.Equal(t, types.KindMysqlDuration, res.Kind())
resAfter := c.after.String()
if c.fsp > 0 {
resAfter += "."
for j := 0; j < c.fsp; j++ {
resAfter += "0"
}
}
require.Equal(t, resAfter, res.String())
require.Equal(t, resAfter, res.GetMysqlDuration().String())
}

// null case
Expand All @@ -1089,20 +1101,21 @@ func TestCastFuncSig(t *testing.T) {
bf, err := newBaseBuiltinFunc(ctx, "", args, types.NewFieldType(mysql.TypeVarString))
require.NoError(t, err)
sig = &builtinCastRealAsStringSig{bf}
sRes, isNull, err := sig.evalString(ctx, row.ToRow())
require.Equal(t, "", sRes)
require.Equal(t, true, isNull)
sRes, err := evalBuiltinFunc(sig, ctx, row.ToRow())
require.NoError(t, err)
require.True(t, sRes.IsNull())
require.Equal(t, "", sRes.GetString())

// test hybridType case.
args = []Expression{&Constant{Value: types.NewDatum(types.Enum{Name: "a", Value: 0}), RetType: types.NewFieldType(mysql.TypeEnum)}}
b, err := newBaseBuiltinFunc(ctx, "", args, types.NewFieldType(mysql.TypeLonglong))
require.NoError(t, err)
sig = &builtinCastStringAsIntSig{newBaseBuiltinCastFunc(b, false)}
iRes, isNull, err := sig.evalInt(ctx, chunk.Row{})
require.Equal(t, false, isNull)
iRes, err := evalBuiltinFunc(sig, ctx, row.ToRow())
require.NoError(t, err)
require.Equal(t, int64(0), iRes)
require.False(t, iRes.IsNull())
require.Equal(t, types.KindInt64, iRes.Kind())
require.Equal(t, int64(0), iRes.GetInt64())
}

func TestCastJSONAsDecimalSig(t *testing.T) {
Expand Down
7 changes: 4 additions & 3 deletions pkg/expression/builtin_compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,11 @@ func TestCompare(t *testing.T) {
args := bf.getArgs()
require.Equal(t, test.tp, args[0].GetType().GetType())
require.Equal(t, test.tp, args[1].GetType().GetType())
res, isNil, err := bf.evalInt(ctx, chunk.Row{})
res, err := evalBuiltinFunc(bf, ctx, chunk.Row{})
require.NoError(t, err)
require.False(t, isNil)
require.Equal(t, test.expected, res)
require.False(t, res.IsNull())
require.Equal(t, types.KindInt64, res.Kind())
require.Equal(t, test.expected, res.GetInt64())
}

// test <non-const decimal expression> <cmp> <const string expression>
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/builtin_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -1348,7 +1348,7 @@ type builtinSetValSig struct {

func (b *builtinSetValSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet {
return b.SequenceOperatorPropReader.RequiredOptionalEvalProps() |
b.SequenceOperatorPropReader.RequiredOptionalEvalProps()
b.SessionVarsPropReader.RequiredOptionalEvalProps()
}

func (b *builtinSetValSig) Clone() builtinFunc {
Expand Down
16 changes: 8 additions & 8 deletions pkg/expression/builtin_miscellaneous_vec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,21 +156,21 @@ func TestSleepVectorized(t *testing.T) {
levels[errctx.ErrGroupBadNull] = errctx.LevelWarn
sessVars.StmtCtx.SetErrLevels(levels)
input.AppendFloat64(0, 1)
err = f.vecEvalInt(ctx, input, result)
err = vecEvalType(ctx, f, types.ETInt, input, result)
require.NoError(t, err)
require.Equal(t, int64(0), result.GetInt64(0))
require.Equal(t, uint16(warnCnt.add(0)), sessVars.StmtCtx.WarningCount())

input.Reset()
input.AppendFloat64(0, -1)
err = f.vecEvalInt(ctx, input, result)
err = vecEvalType(ctx, f, types.ETInt, input, result)
require.NoError(t, err)
require.Equal(t, int64(0), result.GetInt64(0))
require.Equal(t, uint16(warnCnt.add(1)), sessVars.StmtCtx.WarningCount())

input.Reset()
input.AppendNull(0)
err = f.vecEvalInt(ctx, input, result)
err = vecEvalType(ctx, f, types.ETInt, input, result)
require.NoError(t, err)
require.Equal(t, int64(0), result.GetInt64(0))
require.Equal(t, uint16(warnCnt.add(1)), sessVars.StmtCtx.WarningCount())
Expand All @@ -179,7 +179,7 @@ func TestSleepVectorized(t *testing.T) {
input.AppendNull(0)
input.AppendFloat64(0, 1)
input.AppendFloat64(0, -1)
err = f.vecEvalInt(ctx, input, result)
err = vecEvalType(ctx, f, types.ETInt, input, result)
require.NoError(t, err)
require.Equal(t, int64(0), result.GetInt64(0))
require.Equal(t, int64(0), result.GetInt64(1))
Expand All @@ -191,22 +191,22 @@ func TestSleepVectorized(t *testing.T) {
sessVars.StmtCtx.SetErrLevels(levels)
input.Reset()
input.AppendNull(0)
err = f.vecEvalInt(ctx, input, result)
err = vecEvalType(ctx, f, types.ETInt, input, result)
require.Error(t, err)
require.Equal(t, int64(0), result.GetInt64(0))

sessVars.StmtCtx.SetWarnings(nil)
input.Reset()
input.AppendFloat64(0, -2.5)
err = f.vecEvalInt(ctx, input, result)
err = vecEvalType(ctx, f, types.ETInt, input, result)
require.Error(t, err)
require.Equal(t, int64(0), result.GetInt64(0))

// strict model
input.Reset()
input.AppendFloat64(0, 0.5)
start := time.Now()
err = f.vecEvalInt(ctx, input, result)
err = vecEvalType(ctx, f, types.ETInt, input, result)
require.NoError(t, err)
require.Equal(t, int64(0), result.GetInt64(0))
sub := time.Since(start)
Expand All @@ -221,7 +221,7 @@ func TestSleepVectorized(t *testing.T) {
time.Sleep(1 * time.Second)
ctx.GetSessionVars().SQLKiller.SendKillSignal(sqlkiller.QueryInterrupted)
}()
err = f.vecEvalInt(ctx, input, result)
err = vecEvalType(ctx, f, types.ETInt, input, result)
sub = time.Since(start)
require.NoError(t, err)
require.Equal(t, int64(0), result.GetInt64(0))
Expand Down
12 changes: 6 additions & 6 deletions pkg/expression/builtin_op_vec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,25 +174,25 @@ func TestBuiltinUnaryMinusIntSig(t *testing.T) {

require.False(t, mysql.HasUnsignedFlag(col0.GetType().GetFlag()))
input.AppendInt64(0, 233333)
require.Nil(t, f.vecEvalInt(ctx, input, result))
require.NoError(t, vecEvalType(ctx, f, types.ETInt, input, result))
require.Equal(t, int64(-233333), result.GetInt64(0))
input.Reset()
input.AppendInt64(0, math.MinInt64)
require.NotNil(t, f.vecEvalInt(ctx, input, result))
require.Error(t, vecEvalType(ctx, f, types.ETInt, input, result))
input.Column(0).SetNull(0, true)
require.NoError(t, f.vecEvalInt(ctx, input, result))
require.NoError(t, vecEvalType(ctx, f, types.ETInt, input, result))
require.True(t, result.IsNull(0))

col0.GetType().AddFlag(mysql.UnsignedFlag)
require.True(t, mysql.HasUnsignedFlag(col0.GetType().GetFlag()))
input.Reset()
input.AppendUint64(0, 233333)
require.NoError(t, f.vecEvalInt(ctx, input, result))
require.NoError(t, vecEvalType(ctx, f, types.ETInt, input, result))
require.Equal(t, int64(-233333), result.GetInt64(0))
input.Reset()
input.AppendUint64(0, -(math.MinInt64)+1)
require.NotNil(t, f.vecEvalInt(ctx, input, result))
require.Error(t, vecEvalType(ctx, f, types.ETInt, input, result))
input.Column(0).SetNull(0, true)
require.NoError(t, f.vecEvalInt(ctx, input, result))
require.NoError(t, vecEvalType(ctx, f, types.ETInt, input, result))
require.True(t, result.IsNull(0))
}
9 changes: 5 additions & 4 deletions pkg/expression/builtin_other_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,14 +330,15 @@ func TestInFunc(t *testing.T) {
strD2 := types.NewCollationStringDatum("Á", "utf8_general_ci")
fn, err := fc.getFunction(ctx, datumsToConstants([]types.Datum{strD1, strD2}))
require.NoError(t, err)
d, isNull, err := fn.evalInt(ctx, chunk.Row{})
require.False(t, isNull)
d, err := evalBuiltinFunc(fn, ctx, chunk.Row{})
require.NoError(t, err)
require.Equalf(t, int64(1), d, "%v, %v", strD1, strD2)
require.False(t, d.IsNull())
require.Equal(t, types.KindInt64, d.Kind())
require.Equalf(t, int64(1), d.GetInt64(), "%v, %v", strD1, strD2)
chk1 := chunk.NewChunkWithCapacity(nil, 1)
chk1.SetNumVirtualRows(1)
chk2 := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeTiny)}, 1)
err = fn.vecEvalInt(ctx, chk1, chk2.Column(0))
err = vecEvalType(ctx, fn, types.ETInt, chk1, chk2.Column(0))
require.NoError(t, err)
require.Equal(t, int64(1), chk2.Column(0).GetInt64(0))
}
2 changes: 1 addition & 1 deletion pkg/expression/builtin_other_vec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func TestInDecimal(t *testing.T) {
require.NotEqual(t, input.Column(0).GetDecimal(i).GetDigitsFrac(), input.Column(1).GetDecimal(i).GetDigitsFrac())
}
result := chunk.NewColumn(ft, 1024)
require.Nil(t, inFunc.vecEvalInt(ctx, input, result))
require.NoError(t, vecEvalType(ctx, inFunc, types.ETInt, input, result))
for i := 0; i < 1024; i++ {
require.Equal(t, int64(1), result.GetInt64(0))
}
Expand Down
11 changes: 6 additions & 5 deletions pkg/expression/builtin_regexp_vec_const_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func genVecBuiltinRegexpBenchCaseForConstants(ctx BuildContext) (baseFunc builti
func TestVectorizedBuiltinRegexpForConstants(t *testing.T) {
ctx := mock.NewContext()
bf, childrenFieldTypes, input, output := genVecBuiltinRegexpBenchCaseForConstants(ctx)
err := bf.vecEvalInt(ctx, input, output)
err := vecEvalType(ctx, bf, types.ETInt, input, output)
require.NoError(t, err)
i64s := output.Int64s()

Expand All @@ -73,11 +73,12 @@ func TestVectorizedBuiltinRegexpForConstants(t *testing.T) {
return fmt.Sprintf("func: builtinRegexpUTF8Sig, row: %v, rowData: %v", row, input.GetRow(row).GetDatumRow(childrenFieldTypes))
}
for row := it.Begin(); row != it.End(); row = it.Next() {
val, isNull, err := bf.evalInt(ctx, row)
val, err := evalBuiltinFunc(bf, ctx, row)
require.NoError(t, err)
require.Equal(t, output.IsNull(i), isNull, commentf(i))
if !isNull {
require.Equal(t, i64s[i], val, commentf(i))
require.Equal(t, output.IsNull(i), val.IsNull(), commentf(i))
if !val.IsNull() {
require.Equal(t, types.KindInt64, val.Kind(), commentf(i))
require.Equal(t, i64s[i], val.GetInt64(), commentf(i))
}
i++
}
Expand Down

0 comments on commit fdcecc4

Please sign in to comment.