Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: handle max_allowed_packet warnings for pad functions (#7171) #7244

Merged
merged 3 commits into from Aug 2, 2018
Merged
Changes from 1 commit
Commits
File filter...
Filter file types
Jump to…
Jump to file or symbol
Failed to load files and symbols.

Always

Just for now

Next

expression: handle max_allowed_packet warnings for pad functions (#7171)

  • Loading branch information...
zz-jason committed Jul 31, 2018
commit bfb563227eb2b1c4d697e3d1dc42664dd951feda
@@ -31,6 +31,7 @@ import (
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/charset"
"github.com/pingcap/tidb/util/hack"
@@ -1759,24 +1760,33 @@ func (c *lpadFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio
bf.tp.Flen = getFlen4LpadAndRpad(bf.ctx, args[1])
SetBinFlagOrBinStr(args[0].GetType(), bf.tp)
SetBinFlagOrBinStr(args[2].GetType(), bf.tp)

valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, errors.Trace(err)
}

if types.IsBinaryStr(args[0].GetType()) || types.IsBinaryStr(args[2].GetType()) {
sig := &builtinLpadBinarySig{bf}
sig := &builtinLpadBinarySig{bf, maxAllowedPacket}
return sig, nil
}
if bf.tp.Flen *= 4; bf.tp.Flen > mysql.MaxBlobWidth {
bf.tp.Flen = mysql.MaxBlobWidth
}
sig := &builtinLpadSig{bf}
sig := &builtinLpadSig{bf, maxAllowedPacket}
return sig, nil
}

type builtinLpadBinarySig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinLpadBinarySig) Clone() builtinFunc {
newSig := &builtinLpadBinarySig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

@@ -1795,6 +1805,11 @@ func (b *builtinLpadBinarySig) evalString(row types.Row) (string, bool, error) {
}
targetLength := int(length)

if uint64(targetLength) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("lpad", b.maxAllowedPacket))
return "", true, nil
}

padStr, isNull, err := b.args[2].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
@@ -1814,11 +1829,13 @@ func (b *builtinLpadBinarySig) evalString(row types.Row) (string, bool, error) {

type builtinLpadSig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinLpadSig) Clone() builtinFunc {
newSig := &builtinLpadSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

@@ -1837,6 +1854,11 @@ func (b *builtinLpadSig) evalString(row types.Row) (string, bool, error) {
}
targetLength := int(length)

if uint64(targetLength)*uint64(mysql.MaxBytesOfCharacter) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("lpad", b.maxAllowedPacket))
return "", true, nil
}

padStr, isNull, err := b.args[2].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
@@ -1866,24 +1888,33 @@ func (c *rpadFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio
bf.tp.Flen = getFlen4LpadAndRpad(bf.ctx, args[1])
SetBinFlagOrBinStr(args[0].GetType(), bf.tp)
SetBinFlagOrBinStr(args[2].GetType(), bf.tp)

valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, errors.Trace(err)
}

if types.IsBinaryStr(args[0].GetType()) || types.IsBinaryStr(args[2].GetType()) {
sig := &builtinRpadBinarySig{bf}
sig := &builtinRpadBinarySig{bf, maxAllowedPacket}
return sig, nil
}
if bf.tp.Flen *= 4; bf.tp.Flen > mysql.MaxBlobWidth {
bf.tp.Flen = mysql.MaxBlobWidth
}
sig := &builtinRpadSig{bf}
sig := &builtinRpadSig{bf, maxAllowedPacket}
return sig, nil
}

type builtinRpadBinarySig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinRpadBinarySig) Clone() builtinFunc {
newSig := &builtinRpadBinarySig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

@@ -1901,6 +1932,10 @@ func (b *builtinRpadBinarySig) evalString(row types.Row) (string, bool, error) {
return "", true, errors.Trace(err)
}
targetLength := int(length)
if uint64(targetLength) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("rpad", b.maxAllowedPacket))
return "", true, nil
}

padStr, isNull, err := b.args[2].EvalString(b.ctx, row)
if isNull || err != nil {
@@ -1921,11 +1956,13 @@ func (b *builtinRpadBinarySig) evalString(row types.Row) (string, bool, error) {

type builtinRpadSig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinRpadSig) Clone() builtinFunc {
newSig := &builtinRpadSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

@@ -1944,6 +1981,11 @@ func (b *builtinRpadSig) evalString(row types.Row) (string, bool, error) {
}
targetLength := int(length)

if uint64(targetLength)*uint64(mysql.MaxBytesOfCharacter) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("rpad", b.maxAllowedPacket))
return "", true, nil
}

padStr, isNull, err := b.args[2].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
@@ -23,6 +23,7 @@ import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/charset"
"github.com/pingcap/tidb/util/mock"
@@ -1265,6 +1266,47 @@ func (s *testEvaluatorSuite) TestRpad(c *C) {
}
}

func (s *testEvaluatorSuite) TestRpadSig(c *C) {
colTypes := []*types.FieldType{
{Tp: mysql.TypeVarchar},
{Tp: mysql.TypeLonglong},
{Tp: mysql.TypeVarchar},
}
resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 1000}

args := []Expression{
&Column{Index: 0, RetType: colTypes[0]},
&Column{Index: 1, RetType: colTypes[1]},
&Column{Index: 2, RetType: colTypes[2]},
}

base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType}
rpad := &builtinRpadSig{base, 1000}

input := chunk.NewChunkWithCapacity(colTypes, 10)
input.AppendString(0, "abc")
input.AppendString(0, "abc")
input.AppendInt64(1, 6)
input.AppendInt64(1, 10000)
input.AppendString(2, "123")
input.AppendString(2, "123")

res, isNull, err := rpad.evalString(input.GetRow(0))
c.Assert(res, Equals, "abc123")
c.Assert(isNull, IsFalse)
c.Assert(err, IsNil)

res, isNull, err = rpad.evalString(input.GetRow(1))
c.Assert(res, Equals, "")
c.Assert(isNull, IsTrue)
c.Assert(err, IsNil)

warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(len(warnings), Equals, 1)
lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue)
}

func (s *testEvaluatorSuite) TestInstr(c *C) {
defer testleak.AfterTest(c)()
tbl := []struct {
@@ -22,15 +22,18 @@ import (

// Error instances.
var (
// All the exported errors are defined here:
ErrIncorrectParameterCount = terror.ClassExpression.New(mysql.ErrWrongParamcountToNativeFct, mysql.MySQLErrName[mysql.ErrWrongParamcountToNativeFct])
ErrDivisionByZero = terror.ClassExpression.New(mysql.ErrDivisionByZero, mysql.MySQLErrName[mysql.ErrDivisionByZero])

// All the un-exported errors are defined here:
errFunctionNotExists = terror.ClassExpression.New(mysql.ErrSpDoesNotExist, mysql.MySQLErrName[mysql.ErrSpDoesNotExist])
errZlibZData = terror.ClassTypes.New(mysql.ErrZlibZData, mysql.MySQLErrName[mysql.ErrZlibZData])
errIncorrectArgs = terror.ClassExpression.New(mysql.ErrWrongArguments, mysql.MySQLErrName[mysql.ErrWrongArguments])
errUnknownCharacterSet = terror.ClassExpression.New(mysql.ErrUnknownCharacterSet, mysql.MySQLErrName[mysql.ErrUnknownCharacterSet])
errDefaultValue = terror.ClassExpression.New(mysql.ErrInvalidDefault, "invalid default value")
errDeprecatedSyntaxNoReplacement = terror.ClassExpression.New(mysql.ErrWarnDeprecatedSyntaxNoReplacement, mysql.MySQLErrName[mysql.ErrWarnDeprecatedSyntaxNoReplacement])
errWarnAllowedPacketOverflowed = terror.ClassExpression.New(mysql.ErrWarnAllowedPacketOverflowed, mysql.MySQLErrName[mysql.ErrWarnAllowedPacketOverflowed])
)

func init() {
@@ -43,6 +46,7 @@ func init() {
mysql.ErrUnknownCharacterSet: mysql.ErrUnknownCharacterSet,
mysql.ErrInvalidDefault: mysql.ErrInvalidDefault,
mysql.ErrWarnDeprecatedSyntaxNoReplacement: mysql.ErrWarnDeprecatedSyntaxNoReplacement,
mysql.ErrWarnAllowedPacketOverflowed: mysql.ErrWarnAllowedPacketOverflowed,
}
terror.ErrClassToMySQLCodes[terror.ClassExpression] = expressionMySQLErrCodes
}
@@ -45,6 +45,7 @@ func (s *testEvaluatorSuite) SetUpSuite(c *C) {
s.Parser = parser.New()
s.ctx = mock.NewContext()
s.ctx.GetSessionVars().StmtCtx.TimeZone = time.Local
s.ctx.GetSessionVars().SetSystemVar("max_allowed_packet", "67108864")
}

func (s *testEvaluatorSuite) TearDownSuite(c *C) {
@@ -55,6 +56,7 @@ func (s *testEvaluatorSuite) SetUpTest(c *C) {
}

func (s *testEvaluatorSuite) TearDownTest(c *C) {
s.ctx.GetSessionVars().StmtCtx.SetWarnings(nil)
testleak.AfterTest(c)()
}

@@ -316,7 +316,7 @@ var MySQLErrName = map[uint16]string{
ErrUnknownTimeZone: "Unknown or incorrect time zone: '%-.64s'",
ErrWarnInvalidTimestamp: "Invalid TIMESTAMP value in column '%s' at row %d",
ErrInvalidCharacterString: "Invalid %s character string: '%.64s'",
ErrWarnAllowedPacketOverflowed: "Result of %s() was larger than maxAllowedPacket (%d) - truncated",
ErrWarnAllowedPacketOverflowed: "Result of %s() was larger than max_allowed_packet (%d) - truncated",
ErrConflictingDeclarations: "Conflicting declarations: '%s%s' and '%s%s'",
ErrSpNoRecursiveCreate: "Can't create a %s from within another stored routine",
ErrSpAlreadyExists: "%s %s already exists",
ProTip! Use n and p to navigate between commits in a pull request.
You can’t perform that action at this time.