Skip to content

Commit

Permalink
[SPARK-9154][SQL] Rename formatString to format_string.
Browse files Browse the repository at this point in the history
Also make format_string the canonical form, rather than printf.
  • Loading branch information
rxin committed Jul 21, 2015
1 parent d4c7a7a commit b40a42a
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ object FunctionRegistry {
expression[StringLocate]("locate"),
expression[StringLPad]("lpad"),
expression[StringTrimLeft]("ltrim"),
expression[StringFormat]("printf"),
expression[FormatString]("printf"),
expression[StringRPad]("rpad"),
expression[StringRepeat]("repeat"),
expression[StringReverse]("reverse"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -526,29 +526,26 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
/**
* Returns the input formatted according do printf-style format strings
*/
case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes {
case class FormatString(children: Expression*) extends Expression with ImplicitCastInputTypes {

require(children.nonEmpty, "printf() should take at least 1 argument")
require(children.nonEmpty, "format_string() should take at least 1 argument")

override def foldable: Boolean = children.forall(_.foldable)
override def nullable: Boolean = children(0).nullable
override def dataType: DataType = StringType
private def format: Expression = children(0)
private def args: Seq[Expression] = children.tail

override def inputTypes: Seq[AbstractDataType] =
StringType :: List.fill(children.size - 1)(AnyDataType)


override def eval(input: InternalRow): Any = {
val pattern = format.eval(input)
val pattern = children(0).eval(input)
if (pattern == null) {
null
} else {
val sb = new StringBuffer()
val formatter = new java.util.Formatter(sb, Locale.US)

val arglist = args.map(_.eval(input).asInstanceOf[AnyRef])
val arglist = children.tail.map(_.eval(input).asInstanceOf[AnyRef])
formatter.format(pattern.asInstanceOf[UTF8String].toString, arglist: _*)

UTF8String.fromString(sb.toString)
Expand Down Expand Up @@ -591,7 +588,7 @@ case class StringFormat(children: Expression*) extends Expression with ImplicitC
"""
}

override def prettyName: String = "printf"
override def prettyName: String = "format_string"
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,16 +351,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("FORMAT") {
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null))
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc")
checkEvaluation(FormatString(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
checkEvaluation(FormatString(Literal("aa")), "aa", create_row(null))
checkEvaluation(FormatString(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
checkEvaluation(FormatString(Literal("aa%d%s"), 12, "cc"), "aa12cc")

checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null)
checkEvaluation(FormatString(Literal.create(null, StringType), 12, "cc"), null)
checkEvaluation(
StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc")
FormatString(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc")
checkEvaluation(
StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null")
FormatString(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null")
}

test("INSTR") {
Expand Down
18 changes: 3 additions & 15 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1742,26 +1742,14 @@ object functions {
def rtrim(e: Column): Column = StringTrimRight(e.expr)

/**
* Format strings in printf-style.
* Formats the arguments in printf-style and returns the result as a string column.
*
* @group string_funcs
* @since 1.5.0
*/
@scala.annotation.varargs
def formatString(format: Column, arguments: Column*): Column = {
StringFormat((format +: arguments).map(_.expr): _*)
}

/**
* Format strings in printf-style.
* NOTE: `format` is the string value of the formatter, not column name.
*
* @group string_funcs
* @since 1.5.0
*/
@scala.annotation.varargs
def formatString(format: String, arguNames: String*): Column = {
StringFormat(lit(format).expr +: arguNames.map(Column(_).expr): _*)
def format_string(format: String, arguments: Column*): Column = {
FormatString((lit(format) +: arguments).map(_.expr): _*)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,14 @@ class StringFunctionsSuite extends QueryTest {
test("string formatString function") {
val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c")

checkAnswer(
df.select(formatString("aa%d%s", "b", "c")),
Row("aa123cc"))

checkAnswer(
df.selectExpr("printf(a, b, c)"),
Row("aa123cc"))

val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c")

checkAnswer(
df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
df2.select(format_string("a", $"b", $"c"), format_string("aa%d%s", $"b", $"c")),
Row("aa123cc", "aa123cc"))

checkAnswer(
Expand Down

0 comments on commit b40a42a

Please sign in to comment.