From bd01434de02815e1322c09dbefc75947df8fff23 Mon Sep 17 00:00:00 2001 From: Jian Zhang Date: Fri, 22 Sep 2017 13:52:10 +0800 Subject: [PATCH] expression: remove the usage of "TypeClass" in "builtin_arithmetic.go" (#4575) --- expression/builtin_arithmetic.go | 79 ++++++++++++++++---------------- 1 file changed, 40 insertions(+), 39 deletions(-) diff --git a/expression/builtin_arithmetic.go b/expression/builtin_arithmetic.go index b9ca84b547ead..f68e3bce9e0bb 100644 --- a/expression/builtin_arithmetic.go +++ b/expression/builtin_arithmetic.go @@ -58,19 +58,20 @@ var ( // performed with the / operator. const precIncrement = 4 -// numericContextResultType returns TypeClass for numeric function's parameters. -// the returned TypeClass should be one of: ClassInt, ClassDecimal, ClassReal -func numericContextResultType(ft *types.FieldType) types.TypeClass { +// numericContextResultType returns evalTp for numeric function's parameters. +// the returned evalTp should be one of: tpInt, tpDecimal, tpReal +func numericContextResultType(ft *types.FieldType) evalTp { if types.IsTypeTemporal(ft.Tp) { if ft.Decimal > 0 { - return types.ClassDecimal + return tpDecimal } - return types.ClassInt + return tpInt } - if ft.ToClass() == types.ClassString { - return types.ClassReal + evalTp4Ft := fieldTp2EvalTp(ft) + if evalTp4Ft != tpDecimal && evalTp4Ft != tpInt { + evalTp4Ft = tpReal } - return ft.ToClass() + return evalTp4Ft } // setFlenDecimal4Int is called to set proper `Flen` and `Decimal` of return @@ -137,15 +138,15 @@ func (c *arithmeticPlusFunctionClass) getFunction(ctx context.Context, args []Ex if err := c.verifyArgs(args); err != nil { return nil, errors.Trace(err) } - tpA, tpB := args[0].GetType(), args[1].GetType() - tcA, tcB := numericContextResultType(tpA), numericContextResultType(tpB) - if tcA == types.ClassReal || tcB == types.ClassReal { + lhsTp, rhsTp := args[0].GetType(), args[1].GetType() + lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp) + if lhsEvalTp == tpReal || rhsEvalTp == tpReal { bf := newBaseBuiltinFuncWithTp(args, ctx, tpReal, tpReal, tpReal) setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), true) sig := &builtinArithmeticPlusRealSig{baseRealBuiltinFunc{bf}} sig.setPbCode(tipb.ScalarFuncSig_PlusReal) return sig.setSelf(sig), nil - } else if tcA == types.ClassDecimal || tcB == types.ClassDecimal { + } else if lhsEvalTp == tpDecimal || rhsEvalTp == tpDecimal { bf := newBaseBuiltinFuncWithTp(args, ctx, tpDecimal, tpDecimal, tpDecimal) setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), false) sig := &builtinArithmeticPlusDecimalSig{baseDecimalBuiltinFunc{bf}} @@ -261,15 +262,15 @@ func (c *arithmeticMinusFunctionClass) getFunction(ctx context.Context, args []E if err := c.verifyArgs(args); err != nil { return nil, errors.Trace(err) } - tpA, tpB := args[0].GetType(), args[1].GetType() - tcA, tcB := numericContextResultType(tpA), numericContextResultType(tpB) - if tcA == types.ClassReal || tcB == types.ClassReal { + lhsTp, rhsTp := args[0].GetType(), args[1].GetType() + lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp) + if lhsEvalTp == tpReal || rhsEvalTp == tpReal { bf := newBaseBuiltinFuncWithTp(args, ctx, tpReal, tpReal, tpReal) setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), true) sig := &builtinArithmeticMinusRealSig{baseRealBuiltinFunc{bf}} sig.setPbCode(tipb.ScalarFuncSig_MinusReal) return sig.setSelf(sig), nil - } else if tcA == types.ClassDecimal || tcB == types.ClassDecimal { + } else if lhsEvalTp == tpDecimal || rhsEvalTp == tpDecimal { bf := newBaseBuiltinFuncWithTp(args, ctx, tpDecimal, tpDecimal, tpDecimal) setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), false) sig := &builtinArithmeticMinusDecimalSig{baseDecimalBuiltinFunc{bf}} @@ -382,15 +383,15 @@ func (c *arithmeticMultiplyFunctionClass) getFunction(ctx context.Context, args if err := c.verifyArgs(args); err != nil { return nil, errors.Trace(err) } - tpA, tpB := args[0].GetType(), args[1].GetType() - tcA, tcB := numericContextResultType(tpA), numericContextResultType(tpB) - if tcA == types.ClassReal || tcB == types.ClassReal { + lhsTp, rhsTp := args[0].GetType(), args[1].GetType() + lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp) + if lhsEvalTp == tpReal || rhsEvalTp == tpReal { bf := newBaseBuiltinFuncWithTp(args, ctx, tpReal, tpReal, tpReal) setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), true) sig := &builtinArithmeticMultiplyRealSig{baseRealBuiltinFunc{bf}} sig.setPbCode(tipb.ScalarFuncSig_MultiplyReal) return sig.setSelf(sig), nil - } else if tcA == types.ClassDecimal || tcB == types.ClassDecimal { + } else if lhsEvalTp == tpDecimal || rhsEvalTp == tpDecimal { bf := newBaseBuiltinFuncWithTp(args, ctx, tpDecimal, tpDecimal, tpDecimal) setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), false) sig := &builtinArithmeticMultiplyDecimalSig{baseDecimalBuiltinFunc{bf}} @@ -398,7 +399,7 @@ func (c *arithmeticMultiplyFunctionClass) getFunction(ctx context.Context, args return sig.setSelf(sig), nil } else { bf := newBaseBuiltinFuncWithTp(args, ctx, tpInt, tpInt, tpInt) - if mysql.HasUnsignedFlag(tpA.Flag) || mysql.HasUnsignedFlag(tpB.Flag) { + if mysql.HasUnsignedFlag(lhsTp.Flag) || mysql.HasUnsignedFlag(rhsTp.Flag) { bf.tp.Flag |= mysql.UnsignedFlag setFlenDecimal4Int(bf.tp, args[0].GetType(), args[1].GetType()) sig := &builtinArithmeticMultiplyIntUnsignedSig{baseIntBuiltinFunc{bf}} @@ -496,16 +497,16 @@ func (c *arithmeticDivideFunctionClass) getFunction(ctx context.Context, args [] if err := c.verifyArgs(args); err != nil { return nil, errors.Trace(err) } - tpA, tpB := args[0].GetType(), args[1].GetType() - tcA, tcB := numericContextResultType(tpA), numericContextResultType(tpB) - if tcA == types.ClassReal || tcB == types.ClassReal { + lhsTp, rhsTp := args[0].GetType(), args[1].GetType() + lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp) + if lhsEvalTp == tpReal || rhsEvalTp == tpReal { bf := newBaseBuiltinFuncWithTp(args, ctx, tpReal, tpReal, tpReal) c.setType4DivReal(bf.tp) sig := &builtinArithmeticDivideRealSig{baseRealBuiltinFunc{bf}} return sig.setSelf(sig), nil } bf := newBaseBuiltinFuncWithTp(args, ctx, tpDecimal, tpDecimal, tpDecimal) - c.setType4DivDecimal(bf.tp, tpA, tpB) + c.setType4DivDecimal(bf.tp, lhsTp, rhsTp) sig := &builtinArithmeticDivideDecimalSig{baseDecimalBuiltinFunc{bf}} return sig.setSelf(sig), nil } @@ -562,18 +563,18 @@ func (c *arithmeticIntDivideFunctionClass) getFunction(ctx context.Context, args return nil, errors.Trace(err) } - tpA, tpB := args[0].GetType(), args[1].GetType() - tcA, tcB := numericContextResultType(tpA), numericContextResultType(tpB) - if tcA == types.ClassInt && tcB == types.ClassInt { + lhsTp, rhsTp := args[0].GetType(), args[1].GetType() + lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp) + if lhsEvalTp == tpInt && rhsEvalTp == tpInt { bf := newBaseBuiltinFuncWithTp(args, ctx, tpInt, tpInt, tpInt) - if mysql.HasUnsignedFlag(tpA.Flag) || mysql.HasUnsignedFlag(tpB.Flag) { + if mysql.HasUnsignedFlag(lhsTp.Flag) || mysql.HasUnsignedFlag(rhsTp.Flag) { bf.tp.Flag |= mysql.UnsignedFlag } sig := &builtinArithmeticIntDivideIntSig{baseIntBuiltinFunc{bf}} return sig.setSelf(sig), nil } bf := newBaseBuiltinFuncWithTp(args, ctx, tpInt, tpDecimal, tpDecimal) - if mysql.HasUnsignedFlag(tpA.Flag) || mysql.HasUnsignedFlag(tpB.Flag) { + if mysql.HasUnsignedFlag(lhsTp.Flag) || mysql.HasUnsignedFlag(rhsTp.Flag) { bf.tp.Flag |= mysql.UnsignedFlag } sig := &builtinArithmeticIntDivideDecimalSig{baseIntBuiltinFunc{bf}} @@ -681,27 +682,27 @@ func (c *arithmeticModFunctionClass) getFunction(ctx context.Context, args []Exp if err := c.verifyArgs(args); err != nil { return nil, errors.Trace(err) } - tpA, tpB := args[0].GetType(), args[1].GetType() - tcA, tcB := numericContextResultType(tpA), numericContextResultType(tpB) - if tcA == types.ClassReal || tcB == types.ClassReal { + lhsTp, rhsTp := args[0].GetType(), args[1].GetType() + lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp) + if lhsEvalTp == tpReal || rhsEvalTp == tpReal { bf := newBaseBuiltinFuncWithTp(args, ctx, tpReal, tpReal, tpReal) - c.setType4ModRealOrDecimal(bf.tp, tpA, tpB, false) - if mysql.HasUnsignedFlag(tpA.Flag) { + c.setType4ModRealOrDecimal(bf.tp, lhsTp, rhsTp, false) + if mysql.HasUnsignedFlag(lhsTp.Flag) { bf.tp.Flag |= mysql.UnsignedFlag } sig := &builtinArithmeticModRealSig{baseRealBuiltinFunc{bf}} return sig.setSelf(sig), nil - } else if tcA == types.ClassDecimal || tcB == types.ClassDecimal { + } else if lhsEvalTp == tpDecimal || rhsEvalTp == tpDecimal { bf := newBaseBuiltinFuncWithTp(args, ctx, tpDecimal, tpDecimal, tpDecimal) - c.setType4ModRealOrDecimal(bf.tp, tpA, tpB, true) - if mysql.HasUnsignedFlag(tpA.Flag) { + c.setType4ModRealOrDecimal(bf.tp, lhsTp, rhsTp, true) + if mysql.HasUnsignedFlag(lhsTp.Flag) { bf.tp.Flag |= mysql.UnsignedFlag } sig := &builtinArithmeticModDecimalSig{baseDecimalBuiltinFunc{bf}} return sig.setSelf(sig), nil } else { bf := newBaseBuiltinFuncWithTp(args, ctx, tpInt, tpInt, tpInt) - if mysql.HasUnsignedFlag(tpA.Flag) { + if mysql.HasUnsignedFlag(lhsTp.Flag) { bf.tp.Flag |= mysql.UnsignedFlag } sig := &builtinArithmeticModIntSig{baseIntBuiltinFunc{bf}}