Skip to content

Commit

Permalink
expression: remove the usage of "TypeClass" in "builtin_arithmetic.go" (
Browse files Browse the repository at this point in the history
  • Loading branch information
zz-jason committed Sep 22, 2017
1 parent 3c02f6b commit bd01434
Showing 1 changed file with 40 additions and 39 deletions.
79 changes: 40 additions & 39 deletions expression/builtin_arithmetic.go
Expand Up @@ -58,19 +58,20 @@ var (
// performed with the / operator.
const precIncrement = 4

// numericContextResultType returns TypeClass for numeric function's parameters.
// the returned TypeClass should be one of: ClassInt, ClassDecimal, ClassReal
func numericContextResultType(ft *types.FieldType) types.TypeClass {
// numericContextResultType returns evalTp for numeric function's parameters.
// the returned evalTp should be one of: tpInt, tpDecimal, tpReal
func numericContextResultType(ft *types.FieldType) evalTp {
if types.IsTypeTemporal(ft.Tp) {
if ft.Decimal > 0 {
return types.ClassDecimal
return tpDecimal
}
return types.ClassInt
return tpInt
}
if ft.ToClass() == types.ClassString {
return types.ClassReal
evalTp4Ft := fieldTp2EvalTp(ft)
if evalTp4Ft != tpDecimal && evalTp4Ft != tpInt {
evalTp4Ft = tpReal
}
return ft.ToClass()
return evalTp4Ft
}

// setFlenDecimal4Int is called to set proper `Flen` and `Decimal` of return
Expand Down Expand Up @@ -137,15 +138,15 @@ func (c *arithmeticPlusFunctionClass) getFunction(ctx context.Context, args []Ex
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}
tpA, tpB := args[0].GetType(), args[1].GetType()
tcA, tcB := numericContextResultType(tpA), numericContextResultType(tpB)
if tcA == types.ClassReal || tcB == types.ClassReal {
lhsTp, rhsTp := args[0].GetType(), args[1].GetType()
lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp)
if lhsEvalTp == tpReal || rhsEvalTp == tpReal {
bf := newBaseBuiltinFuncWithTp(args, ctx, tpReal, tpReal, tpReal)
setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), true)
sig := &builtinArithmeticPlusRealSig{baseRealBuiltinFunc{bf}}
sig.setPbCode(tipb.ScalarFuncSig_PlusReal)
return sig.setSelf(sig), nil
} else if tcA == types.ClassDecimal || tcB == types.ClassDecimal {
} else if lhsEvalTp == tpDecimal || rhsEvalTp == tpDecimal {
bf := newBaseBuiltinFuncWithTp(args, ctx, tpDecimal, tpDecimal, tpDecimal)
setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), false)
sig := &builtinArithmeticPlusDecimalSig{baseDecimalBuiltinFunc{bf}}
Expand Down Expand Up @@ -261,15 +262,15 @@ func (c *arithmeticMinusFunctionClass) getFunction(ctx context.Context, args []E
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}
tpA, tpB := args[0].GetType(), args[1].GetType()
tcA, tcB := numericContextResultType(tpA), numericContextResultType(tpB)
if tcA == types.ClassReal || tcB == types.ClassReal {
lhsTp, rhsTp := args[0].GetType(), args[1].GetType()
lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp)
if lhsEvalTp == tpReal || rhsEvalTp == tpReal {
bf := newBaseBuiltinFuncWithTp(args, ctx, tpReal, tpReal, tpReal)
setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), true)
sig := &builtinArithmeticMinusRealSig{baseRealBuiltinFunc{bf}}
sig.setPbCode(tipb.ScalarFuncSig_MinusReal)
return sig.setSelf(sig), nil
} else if tcA == types.ClassDecimal || tcB == types.ClassDecimal {
} else if lhsEvalTp == tpDecimal || rhsEvalTp == tpDecimal {
bf := newBaseBuiltinFuncWithTp(args, ctx, tpDecimal, tpDecimal, tpDecimal)
setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), false)
sig := &builtinArithmeticMinusDecimalSig{baseDecimalBuiltinFunc{bf}}
Expand Down Expand Up @@ -382,23 +383,23 @@ func (c *arithmeticMultiplyFunctionClass) getFunction(ctx context.Context, args
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}
tpA, tpB := args[0].GetType(), args[1].GetType()
tcA, tcB := numericContextResultType(tpA), numericContextResultType(tpB)
if tcA == types.ClassReal || tcB == types.ClassReal {
lhsTp, rhsTp := args[0].GetType(), args[1].GetType()
lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp)
if lhsEvalTp == tpReal || rhsEvalTp == tpReal {
bf := newBaseBuiltinFuncWithTp(args, ctx, tpReal, tpReal, tpReal)
setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), true)
sig := &builtinArithmeticMultiplyRealSig{baseRealBuiltinFunc{bf}}
sig.setPbCode(tipb.ScalarFuncSig_MultiplyReal)
return sig.setSelf(sig), nil
} else if tcA == types.ClassDecimal || tcB == types.ClassDecimal {
} else if lhsEvalTp == tpDecimal || rhsEvalTp == tpDecimal {
bf := newBaseBuiltinFuncWithTp(args, ctx, tpDecimal, tpDecimal, tpDecimal)
setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), false)
sig := &builtinArithmeticMultiplyDecimalSig{baseDecimalBuiltinFunc{bf}}
sig.setPbCode(tipb.ScalarFuncSig_MultiplyDecimal)
return sig.setSelf(sig), nil
} else {
bf := newBaseBuiltinFuncWithTp(args, ctx, tpInt, tpInt, tpInt)
if mysql.HasUnsignedFlag(tpA.Flag) || mysql.HasUnsignedFlag(tpB.Flag) {
if mysql.HasUnsignedFlag(lhsTp.Flag) || mysql.HasUnsignedFlag(rhsTp.Flag) {
bf.tp.Flag |= mysql.UnsignedFlag
setFlenDecimal4Int(bf.tp, args[0].GetType(), args[1].GetType())
sig := &builtinArithmeticMultiplyIntUnsignedSig{baseIntBuiltinFunc{bf}}
Expand Down Expand Up @@ -496,16 +497,16 @@ func (c *arithmeticDivideFunctionClass) getFunction(ctx context.Context, args []
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}
tpA, tpB := args[0].GetType(), args[1].GetType()
tcA, tcB := numericContextResultType(tpA), numericContextResultType(tpB)
if tcA == types.ClassReal || tcB == types.ClassReal {
lhsTp, rhsTp := args[0].GetType(), args[1].GetType()
lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp)
if lhsEvalTp == tpReal || rhsEvalTp == tpReal {
bf := newBaseBuiltinFuncWithTp(args, ctx, tpReal, tpReal, tpReal)
c.setType4DivReal(bf.tp)
sig := &builtinArithmeticDivideRealSig{baseRealBuiltinFunc{bf}}
return sig.setSelf(sig), nil
}
bf := newBaseBuiltinFuncWithTp(args, ctx, tpDecimal, tpDecimal, tpDecimal)
c.setType4DivDecimal(bf.tp, tpA, tpB)
c.setType4DivDecimal(bf.tp, lhsTp, rhsTp)
sig := &builtinArithmeticDivideDecimalSig{baseDecimalBuiltinFunc{bf}}
return sig.setSelf(sig), nil
}
Expand Down Expand Up @@ -562,18 +563,18 @@ func (c *arithmeticIntDivideFunctionClass) getFunction(ctx context.Context, args
return nil, errors.Trace(err)
}

tpA, tpB := args[0].GetType(), args[1].GetType()
tcA, tcB := numericContextResultType(tpA), numericContextResultType(tpB)
if tcA == types.ClassInt && tcB == types.ClassInt {
lhsTp, rhsTp := args[0].GetType(), args[1].GetType()
lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp)
if lhsEvalTp == tpInt && rhsEvalTp == tpInt {
bf := newBaseBuiltinFuncWithTp(args, ctx, tpInt, tpInt, tpInt)
if mysql.HasUnsignedFlag(tpA.Flag) || mysql.HasUnsignedFlag(tpB.Flag) {
if mysql.HasUnsignedFlag(lhsTp.Flag) || mysql.HasUnsignedFlag(rhsTp.Flag) {
bf.tp.Flag |= mysql.UnsignedFlag
}
sig := &builtinArithmeticIntDivideIntSig{baseIntBuiltinFunc{bf}}
return sig.setSelf(sig), nil
}
bf := newBaseBuiltinFuncWithTp(args, ctx, tpInt, tpDecimal, tpDecimal)
if mysql.HasUnsignedFlag(tpA.Flag) || mysql.HasUnsignedFlag(tpB.Flag) {
if mysql.HasUnsignedFlag(lhsTp.Flag) || mysql.HasUnsignedFlag(rhsTp.Flag) {
bf.tp.Flag |= mysql.UnsignedFlag
}
sig := &builtinArithmeticIntDivideDecimalSig{baseIntBuiltinFunc{bf}}
Expand Down Expand Up @@ -681,27 +682,27 @@ func (c *arithmeticModFunctionClass) getFunction(ctx context.Context, args []Exp
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}
tpA, tpB := args[0].GetType(), args[1].GetType()
tcA, tcB := numericContextResultType(tpA), numericContextResultType(tpB)
if tcA == types.ClassReal || tcB == types.ClassReal {
lhsTp, rhsTp := args[0].GetType(), args[1].GetType()
lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp)
if lhsEvalTp == tpReal || rhsEvalTp == tpReal {
bf := newBaseBuiltinFuncWithTp(args, ctx, tpReal, tpReal, tpReal)
c.setType4ModRealOrDecimal(bf.tp, tpA, tpB, false)
if mysql.HasUnsignedFlag(tpA.Flag) {
c.setType4ModRealOrDecimal(bf.tp, lhsTp, rhsTp, false)
if mysql.HasUnsignedFlag(lhsTp.Flag) {
bf.tp.Flag |= mysql.UnsignedFlag
}
sig := &builtinArithmeticModRealSig{baseRealBuiltinFunc{bf}}
return sig.setSelf(sig), nil
} else if tcA == types.ClassDecimal || tcB == types.ClassDecimal {
} else if lhsEvalTp == tpDecimal || rhsEvalTp == tpDecimal {
bf := newBaseBuiltinFuncWithTp(args, ctx, tpDecimal, tpDecimal, tpDecimal)
c.setType4ModRealOrDecimal(bf.tp, tpA, tpB, true)
if mysql.HasUnsignedFlag(tpA.Flag) {
c.setType4ModRealOrDecimal(bf.tp, lhsTp, rhsTp, true)
if mysql.HasUnsignedFlag(lhsTp.Flag) {
bf.tp.Flag |= mysql.UnsignedFlag
}
sig := &builtinArithmeticModDecimalSig{baseDecimalBuiltinFunc{bf}}
return sig.setSelf(sig), nil
} else {
bf := newBaseBuiltinFuncWithTp(args, ctx, tpInt, tpInt, tpInt)
if mysql.HasUnsignedFlag(tpA.Flag) {
if mysql.HasUnsignedFlag(lhsTp.Flag) {
bf.tp.Flag |= mysql.UnsignedFlag
}
sig := &builtinArithmeticModIntSig{baseIntBuiltinFunc{bf}}
Expand Down

0 comments on commit bd01434

Please sign in to comment.