From d4c7a7a3642a74ad40093c96c4bf45a62a470605 Mon Sep 17 00:00:00 2001
From: Tarek Auel
Date: Tue, 21 Jul 2015 15:47:40 -0700
Subject: [PATCH 1/2] [SPARK-9154] [SQL] codegen StringFormat

Jira: https://issues.apache.org/jira/browse/SPARK-9154

Fixes a bug from #7546.

marmbrus I can't reopen the other PR because I didn't close it. Can you trigger Jenkins?

Author: Tarek Auel

Closes #7571 from tarekauel/SPARK-9154 and squashes the following commits:

dcae272 [Tarek Auel] [SPARK-9154][SQL] build fix
1487602 [Tarek Auel] Merge remote-tracking branch 'upstream/master' into SPARK-9154
f512c5f [Tarek Auel] [SPARK-9154][SQL] build fix
a943d3e [Tarek Auel] [SPARK-9154] implicit input cast, added tests for null, support for null primitives
10b4de8 [Tarek Auel] [SPARK-9154][SQL] codegen removed fallback trait
cd8322b [Tarek Auel] [SPARK-9154][SQL] codegen string format
086caba [Tarek Auel] [SPARK-9154][SQL] codegen string format
---
 .../expressions/stringOperations.scala        | 42 ++++++++++++++++++-
 .../expressions/StringExpressionsSuite.scala  | 18 ++++----
 .../org/apache/spark/sql/functions.scala      | 11 +++++
 .../spark/sql/StringFunctionsSuite.scala      | 10 +++++
 4 files changed, 70 insertions(+), 11 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index fe57d17f1ec14..1f18a6e9ff8a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -526,7 +526,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
 /**
  * Returns the input formatted according to printf-style format strings
  */
-case class StringFormat(children: Expression*) extends Expression with CodegenFallback {
+case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes {

   require(children.nonEmpty, "printf() should take at least 1 argument")

@@ -536,6 +536,10 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa
   private def format: Expression = children(0)
   private def args: Seq[Expression] = children.tail

+  override def inputTypes: Seq[AbstractDataType] =
+    StringType :: List.fill(children.size - 1)(AnyDataType)
+
+
   override def eval(input: InternalRow): Any = {
     val pattern = format.eval(input)
     if (pattern == null) {
@@ -551,6 +555,42 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa
     }
   }

+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val pattern = children.head.gen(ctx)
+
+    val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx)))
+    val argListCode = argListGen.map(_._2.code + "\n")
+
+    val argListString = argListGen.foldLeft("")((s, v) => {
+      val nullSafeString =
+        if (ctx.boxedType(v._1) != ctx.javaType(v._1)) {
+          // Java primitives get boxed in order to allow null values.
+          s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " +
+            s"new ${ctx.boxedType(v._1)}(${v._2.primitive})"
+        } else {
+          s"(${v._2.isNull}) ? null : ${v._2.primitive}"
+        }
+      s + "," + nullSafeString
+    })
+
+    val form = ctx.freshName("formatter")
+    val formatter = classOf[java.util.Formatter].getName
+    val sb = ctx.freshName("sb")
+    val stringBuffer = classOf[StringBuffer].getName
+    s"""
+      ${pattern.code}
+      boolean ${ev.isNull} = ${pattern.isNull};
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      if (!${ev.isNull}) {
+        ${argListCode.mkString}
+        $stringBuffer $sb = new $stringBuffer();
+        $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US);
+        $form.format(${pattern.primitive}.toString() $argListString);
+        ${ev.primitive} = UTF8String.fromString($sb.toString());
+      }
+    """
+  }
+
   override def prettyName: String = "printf"
 }

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 96c540ab36f08..3c2d88731beb4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -351,18 +351,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   }

   test("FORMAT") {
-    val f = 'f.string.at(0)
-    val d1 = 'd.int.at(1)
-    val s1 = 's.int.at(2)
-
-    val row1 = create_row("aa%d%s", 12, "cc")
-    val row2 = create_row(null, 12, "cc")
-    checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
+    checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
     checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null))
-    checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
+    checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
+    checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc")

-    checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1)
-    checkEvaluation(StringFormat(f, d1, s1), null, row2)
+    checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null)
+    checkEvaluation(
+      StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc")
+    checkEvaluation(
+      StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null")
   }

   test("INSTR") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index d94d7335828c5..e5ff8ae7e3179 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1741,6 +1741,17 @@ object functions {
    */
   def rtrim(e: Column): Column = StringTrimRight(e.expr)

+  /**
+   * Format strings in printf-style.
+   *
+   * @group string_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def formatString(format: Column, arguments: Column*): Column = {
+    StringFormat((format +: arguments).map(_.expr): _*)
+  }
+
   /**
    * Format strings in printf-style.
    * NOTE: `format` is the string value of the formatter, not column name.
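Aside (illustration only, not part of the patch): the generated Java above mirrors the interpreted eval path. Each argument is evaluated, primitive results are boxed so that a null argument can still be passed along, and everything is handed to java.util.Formatter with Locale.US. A minimal hand-written Scala sketch of that runtime behaviour, reusing the "aa%d%s" pattern and the "aa12null" expectation from the tests:

    import java.util.{Formatter, Locale}

    // Box the int so that a null argument is representable as well; the codegen
    // emits the same boxing for primitive inputs.
    val args: Seq[AnyRef] = Seq(Int.box(12), null)

    val sb = new StringBuffer()
    new Formatter(sb, Locale.US).format("aa%d%s", args: _*)
    assert(sb.toString == "aa12null")
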
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index d1f855903ca4b..3702e73b4e74f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -132,6 +132,16 @@ class StringFunctionsSuite extends QueryTest {
     checkAnswer(
       df.selectExpr("printf(a, b, c)"),
       Row("aa123cc"))
+
+    val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c")
+
+    checkAnswer(
+      df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
+      Row("aa123cc", "aa123cc"))
+
+    checkAnswer(
+      df2.selectExpr("printf(a, b, c)"),
+      Row("aa123cc"))
   }

   test("string instr function") {

From a4c83cb1e4b066cd60264b6572fd3e51d160d26a Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Tue, 21 Jul 2015 19:14:07 -0700
Subject: [PATCH 2/2] [SPARK-9154][SQL] Rename formatString to format_string.

Also make format_string the canonical form, rather than printf.

Author: Reynold Xin

Closes #7579 from rxin/format_strings and squashes the following commits:

53ee54f [Reynold Xin] Fixed unit tests.
52357e1 [Reynold Xin] Add format_string alias.
b40a42a [Reynold Xin] [SPARK-9154][SQL] Rename formatString to format_string.
---
 .../catalyst/analysis/FunctionRegistry.scala   |  3 ++-
 .../expressions/stringOperations.scala         | 13 +++++--------
 .../expressions/StringExpressionsSuite.scala   | 14 +++++++-------
 .../scala/org/apache/spark/sql/functions.scala | 18 +++---------------
 .../spark/sql/StringFunctionsSuite.scala       | 12 +-----------
 5 files changed, 18 insertions(+), 42 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index e3d8d2adf2135..9c349838c28a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -168,7 +168,8 @@ object FunctionRegistry {
     expression[StringLocate]("locate"),
     expression[StringLPad]("lpad"),
     expression[StringTrimLeft]("ltrim"),
-    expression[StringFormat]("printf"),
+    expression[FormatString]("format_string"),
+    expression[FormatString]("printf"),
     expression[StringRPad]("rpad"),
     expression[StringRepeat]("repeat"),
     expression[StringReverse]("reverse"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 1f18a6e9ff8a5..cf187ad5a0a9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -526,29 +526,26 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
 /**
  * Returns the input formatted according to printf-style format strings
  */
-case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes {
+case class FormatString(children: Expression*) extends Expression with ImplicitCastInputTypes {

-  require(children.nonEmpty, "printf() should take at least 1 argument")
+  require(children.nonEmpty, "format_string() should take at least 1 argument")

   override def foldable: Boolean = children.forall(_.foldable)
   override def nullable: Boolean = children(0).nullable
   override def dataType: DataType = StringType

-  private def format: Expression = children(0)
-  private def args: Seq[Expression] = children.tail

   override def inputTypes: Seq[AbstractDataType] =
     StringType :: List.fill(children.size - 1)(AnyDataType)

-
   override def eval(input: InternalRow): Any = {
-    val pattern = format.eval(input)
+    val pattern = children(0).eval(input)
     if (pattern == null) {
       null
     } else {
       val sb = new StringBuffer()
       val formatter = new java.util.Formatter(sb, Locale.US)
-      val arglist = args.map(_.eval(input).asInstanceOf[AnyRef])
+      val arglist = children.tail.map(_.eval(input).asInstanceOf[AnyRef])
       formatter.format(pattern.asInstanceOf[UTF8String].toString, arglist: _*)

       UTF8String.fromString(sb.toString)
@@ -591,7 +588,7 @@ case class StringFormat(children: Expression*) extends Expression with ImplicitC
     """
   }

-  override def prettyName: String = "printf"
+  override def prettyName: String = "format_string"
 }

 /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 3c2d88731beb4..3d294fda5d103 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -351,16 +351,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   }

   test("FORMAT") {
-    checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
-    checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null))
-    checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
-    checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc")
+    checkEvaluation(FormatString(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
+    checkEvaluation(FormatString(Literal("aa")), "aa", create_row(null))
+    checkEvaluation(FormatString(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
+    checkEvaluation(FormatString(Literal("aa%d%s"), 12, "cc"), "aa12cc")

-    checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null)
+    checkEvaluation(FormatString(Literal.create(null, StringType), 12, "cc"), null)
     checkEvaluation(
-      StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc")
+      FormatString(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc")
     checkEvaluation(
-      StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null")
+      FormatString(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null")
   }

   test("INSTR") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index e5ff8ae7e3179..28159cbd5ab96 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1742,26 +1742,14 @@ object functions {
   def rtrim(e: Column): Column = StringTrimRight(e.expr)

   /**
-   * Format strings in printf-style.
+   * Formats the arguments in printf-style and returns the result as a string column.
    *
    * @group string_funcs
    * @since 1.5.0
    */
   @scala.annotation.varargs
-  def formatString(format: Column, arguments: Column*): Column = {
-    StringFormat((format +: arguments).map(_.expr): _*)
-  }
-
-  /**
-   * Format strings in printf-style.
-   * NOTE: `format` is the string value of the formatter, not column name.
-   *
-   * @group string_funcs
-   * @since 1.5.0
-   */
-  @scala.annotation.varargs
-  def formatString(format: String, arguNames: String*): Column = {
-    StringFormat(lit(format).expr +: arguNames.map(Column(_).expr): _*)
+  def format_string(format: String, arguments: Column*): Column = {
+    FormatString((lit(format) +: arguments).map(_.expr): _*)
   }

   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 3702e73b4e74f..0f9c986f649a1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -126,22 +126,12 @@ class StringFunctionsSuite extends QueryTest {
     val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c")

     checkAnswer(
-      df.select(formatString("aa%d%s", "b", "c")),
+      df.select(format_string("aa%d%s", $"b", $"c")),
       Row("aa123cc"))

     checkAnswer(
       df.selectExpr("printf(a, b, c)"),
       Row("aa123cc"))
-
-    val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c")
-
-    checkAnswer(
-      df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
-      Row("aa123cc", "aa123cc"))
-
-    checkAnswer(
-      df2.selectExpr("printf(a, b, c)"),
-      Row("aa123cc"))
   }

   test("string instr function") {
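
Aside (illustration only, not part of either patch): after the rename, the DataFrame-side entry point is format_string(format: String, arguments: Column*), while SQL resolves both format_string and printf to FormatString. A short usage sketch in the style of StringFunctionsSuite; it assumes the usual import of sqlContext.implicits._ for toDF and the $ column syntax:

    import org.apache.spark.sql.functions.format_string

    val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c")

    // DataFrame API: the format is a plain string, the arguments are columns.
    df.select(format_string("aa%d%s", $"b", $"c"))   // "aa123cc"

    // SQL: format_string is the canonical name, printf stays as an alias.
    df.selectExpr("format_string(a, b, c)")          // "aa123cc"
    df.selectExpr("printf(a, b, c)")                 // "aa123cc"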