diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e3d8d2adf2135..5aa3b4e46b688 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -168,7 +168,7 @@ object FunctionRegistry { expression[StringLocate]("locate"), expression[StringLPad]("lpad"), expression[StringTrimLeft]("ltrim"), - expression[StringFormat]("printf"), + expression[FormatString]("printf"), expression[StringRPad]("rpad"), expression[StringRepeat]("repeat"), expression[StringReverse]("reverse"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 1f18a6e9ff8a5..cf187ad5a0a9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -526,29 +526,26 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) /** * Returns the input formatted according do printf-style format strings */ -case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes { +case class FormatString(children: Expression*) extends Expression with ImplicitCastInputTypes { - require(children.nonEmpty, "printf() should take at least 1 argument") + require(children.nonEmpty, "format_string() should take at least 1 argument") override def foldable: Boolean = children.forall(_.foldable) override def nullable: Boolean = children(0).nullable override def dataType: DataType = StringType - private def format: Expression = children(0) - private def args: Seq[Expression] = children.tail override def inputTypes: Seq[AbstractDataType] = StringType :: List.fill(children.size - 1)(AnyDataType) - override def eval(input: InternalRow): Any = { - val pattern = format.eval(input) + val pattern = children(0).eval(input) if (pattern == null) { null } else { val sb = new StringBuffer() val formatter = new java.util.Formatter(sb, Locale.US) - val arglist = args.map(_.eval(input).asInstanceOf[AnyRef]) + val arglist = children.tail.map(_.eval(input).asInstanceOf[AnyRef]) formatter.format(pattern.asInstanceOf[UTF8String].toString, arglist: _*) UTF8String.fromString(sb.toString) @@ -591,7 +588,7 @@ case class StringFormat(children: Expression*) extends Expression with ImplicitC """ } - override def prettyName: String = "printf" + override def prettyName: String = "format_string" } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 3c2d88731beb4..3d294fda5d103 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -351,16 +351,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("FORMAT") { - checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") - checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null)) - checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") - checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc") + checkEvaluation(FormatString(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") + checkEvaluation(FormatString(Literal("aa")), "aa", create_row(null)) + checkEvaluation(FormatString(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") + checkEvaluation(FormatString(Literal("aa%d%s"), 12, "cc"), "aa12cc") - checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null) + checkEvaluation(FormatString(Literal.create(null, StringType), 12, "cc"), null) checkEvaluation( - StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc") + FormatString(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc") checkEvaluation( - StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null") + FormatString(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null") } test("INSTR") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e5ff8ae7e3179..28159cbd5ab96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1742,26 +1742,14 @@ object functions { def rtrim(e: Column): Column = StringTrimRight(e.expr) /** - * Format strings in printf-style. + * Formats the arguments in printf-style and returns the result as a string column. * * @group string_funcs * @since 1.5.0 */ @scala.annotation.varargs - def formatString(format: Column, arguments: Column*): Column = { - StringFormat((format +: arguments).map(_.expr): _*) - } - - /** - * Format strings in printf-style. - * NOTE: `format` is the string value of the formatter, not column name. - * - * @group string_funcs - * @since 1.5.0 - */ - @scala.annotation.varargs - def formatString(format: String, arguNames: String*): Column = { - StringFormat(lit(format).expr +: arguNames.map(Column(_).expr): _*) + def format_string(format: String, arguments: Column*): Column = { + FormatString((lit(format) +: arguments).map(_.expr): _*) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 3702e73b4e74f..83ca766e1548c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -125,10 +125,6 @@ class StringFunctionsSuite extends QueryTest { test("string formatString function") { val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c") - checkAnswer( - df.select(formatString("aa%d%s", "b", "c")), - Row("aa123cc")) - checkAnswer( df.selectExpr("printf(a, b, c)"), Row("aa123cc")) @@ -136,7 +132,7 @@ class StringFunctionsSuite extends QueryTest { val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c") checkAnswer( - df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")), + df2.select(format_string("a", $"b", $"c"), format_string("aa%d%s", $"b", $"c")), Row("aa123cc", "aa123cc")) checkAnswer(