From a943d3e60649f4267e40376c0bb1ff30ae024436 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Mon, 20 Jul 2015 23:26:58 -0700 Subject: [PATCH] [SPARK-9154] implicit input cast, added tests for null, support for null primitives --- .../expressions/stringOperations.scala | 24 +++++++++++++++---- .../expressions/StringExpressionsSuite.scala | 18 +++++++------- .../spark/sql/StringFunctionsSuite.scala | 10 ++++++++ 3 files changed, 37 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 81979ab5d2dce..08b17420d6cbe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -476,7 +476,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) /** * Returns the input formatted according do printf-style format strings */ -case class StringFormat(children: Expression*) extends Expression { +case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes { require(children.nonEmpty, "printf() should take at least 1 argument") @@ -486,6 +486,10 @@ case class StringFormat(children: Expression*) extends Expression { private def format: Expression = children(0) private def args: Seq[Expression] = children.tail + override def inputTypes: Seq[AbstractDataType] = + children.zipWithIndex.map(x => if (x._2 == 0) StringType else AnyDataType) + + override def eval(input: InternalRow): Any = { val pattern = format.eval(input) if (pattern == null) { @@ -504,15 +508,25 @@ case class StringFormat(children: Expression*) extends Expression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val pattern = children.head.gen(ctx) - val argListGen = children.tail.map(_.gen(ctx)) - val argListCode = argListGen.map(_.code + "\n") - val argListString = argListGen.foldLeft("")((s, v) => s + s", ${v.primitive}") + val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx))) + val argListCode = argListGen.map(_._2.code + "\n") + + val argListString = argListGen.foldLeft("")((s, v) => { + val nullSafeString = + if (ctx.boxedType(v._1) != ctx.javaType(v._1)) { + // Java primitives get boxed in order to allow null values. + s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " + + s"new ${ctx.boxedType(v._1)}(${v._2.primitive})" + } else { + s"(${v._2.isNull}) ? null : ${v._2.primitive}" + } + s + "," + nullSafeString + }) val form = ctx.freshName("formatter") val formatter = classOf[java.util.Formatter].getName val sb = ctx.freshName("sb") val stringBuffer = classOf[StringBuffer].getName - s""" ${pattern.code} boolean ${ev.isNull} = ${pattern.isNull}; diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 93bb538663cec..63d09fd6375cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -351,18 +351,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("FORMAT") { - val f = 'f.string.at(0) - val d1 = 'd.int.at(1) - val s1 = 's.string.at(2) - - val row1 = create_row("aa%d%s", 12, "cc") - val row2 = create_row(null, 12, "cc") - checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) + checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null)) - checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) + checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") + checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc") - checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1) - checkEvaluation(StringFormat(f, d1, s1), null, row2) + checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null) + checkEvaluation( + StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc") + checkEvaluation( + StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null") } test("INSTR") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index fe4de8d8b855f..274ec8f4675e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -120,6 +120,16 @@ class StringFunctionsSuite extends QueryTest { checkAnswer( df.selectExpr("printf(a, b, c)"), Row("aa123cc")) + + val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c") + + checkAnswer( + df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")), + Row("aa123cc", "aa123cc")) + + checkAnswer( + df2.selectExpr("printf(a, b, c)"), + Row("aa123cc")) } test("string instr function") {