diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index 9c38dd2ee4e53..fd2ac78b25dbd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -80,7 +80,7 @@ object DecimalPrecision extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // fix decimal precision for expressions - case q => q.transformExpressions( + case q => q.transformExpressionsUp( decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 66d9b4c8e351f..f98d5c0e52eaa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -92,8 +92,14 @@ class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter { checkType(Average(d1), DecimalType(6, 5)) checkType(Add(Add(d1, d2), d1), DecimalType(7, 2)) + checkType(Add(Add(d1, d1), d1), DecimalType(4, 1)) + checkType(Add(d1, Add(d1, d1)), DecimalType(4, 1)) checkType(Add(Add(Add(d1, d2), d1), d2), DecimalType(8, 2)) checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2)) + checkType(Subtract(Subtract(d2, d1), d1), DecimalType(7, 2)) + checkType(Multiply(Multiply(d1, d1), d2), DecimalType(11, 4)) + checkType(Divide(d2, Add(d1, d1)), DecimalType(10, 6)) + checkType(Sum(Add(d1, d1)), DecimalType(13, 1)) } test("Comparison operations") {