Skip to content

Commit

Permalink
[SPARK-9154] implicit input cast, added tests for null, support for n…
Browse files Browse the repository at this point in the history
…ull primitives
  • Loading branch information
tarekbecker committed Jul 21, 2015
1 parent 10b4de8 commit a943d3e
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
/**
* Returns the input formatted according do printf-style format strings
*/
case class StringFormat(children: Expression*) extends Expression {
case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes {

require(children.nonEmpty, "printf() should take at least 1 argument")

Expand All @@ -486,6 +486,10 @@ case class StringFormat(children: Expression*) extends Expression {
private def format: Expression = children(0)
private def args: Seq[Expression] = children.tail

override def inputTypes: Seq[AbstractDataType] =
children.zipWithIndex.map(x => if (x._2 == 0) StringType else AnyDataType)


override def eval(input: InternalRow): Any = {
val pattern = format.eval(input)
if (pattern == null) {
Expand All @@ -504,15 +508,25 @@ case class StringFormat(children: Expression*) extends Expression {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val pattern = children.head.gen(ctx)

val argListGen = children.tail.map(_.gen(ctx))
val argListCode = argListGen.map(_.code + "\n")
val argListString = argListGen.foldLeft("")((s, v) => s + s", ${v.primitive}")
val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx)))
val argListCode = argListGen.map(_._2.code + "\n")

val argListString = argListGen.foldLeft("")((s, v) => {
val nullSafeString =
if (ctx.boxedType(v._1) != ctx.javaType(v._1)) {
// Java primitives get boxed in order to allow null values.
s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " +
s"new ${ctx.boxedType(v._1)}(${v._2.primitive})"
} else {
s"(${v._2.isNull}) ? null : ${v._2.primitive}"
}
s + "," + nullSafeString
})

val form = ctx.freshName("formatter")
val formatter = classOf[java.util.Formatter].getName
val sb = ctx.freshName("sb")
val stringBuffer = classOf[StringBuffer].getName

s"""
${pattern.code}
boolean ${ev.isNull} = ${pattern.isNull};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,18 +351,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("FORMAT") {
val f = 'f.string.at(0)
val d1 = 'd.int.at(1)
val s1 = 's.string.at(2)

val row1 = create_row("aa%d%s", 12, "cc")
val row2 = create_row(null, 12, "cc")
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null))
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc")

checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1)
checkEvaluation(StringFormat(f, d1, s1), null, row2)
checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null)
checkEvaluation(
StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc")
checkEvaluation(
StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null")
}

test("INSTR") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,16 @@ class StringFunctionsSuite extends QueryTest {
checkAnswer(
df.selectExpr("printf(a, b, c)"),
Row("aa123cc"))

val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c")

checkAnswer(
df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
Row("aa123cc", "aa123cc"))

checkAnswer(
df2.selectExpr("printf(a, b, c)"),
Row("aa123cc"))
}

test("string instr function") {
Expand Down

0 comments on commit a943d3e

Please sign in to comment.