Skip to content

Commit

Permalink
fix substringIndex
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Jul 31, 2015
1 parent f2d29a1 commit 3ce7802
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 284 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ object FunctionRegistry {
expression[StringSplit]("split"),
expression[Substring]("substr"),
expression[Substring]("substring"),
expression[Substring_index]("substring_index"),
expression[SubstringIndex]("substring_index"),
expression[StringTrim]("trim"),
expression[UnBase64]("unbase64"),
expression[Upper]("ucase"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -427,52 +427,22 @@ case class StringInstr(str: Expression, substr: Expression)
* returned. If count is negative, everything to the right of the final delimiter (counting from
* the right) is returned. substring_index performs a case-sensitive match when searching for
* delim.
*/
case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr: Expression)
extends Expression with ImplicitCastInputTypes {
case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression)
extends TernaryExpression with ImplicitCastInputTypes {

override def dataType: DataType = StringType
override def foldable: Boolean = strExpr.foldable && delimExpr.foldable && countExpr.foldable
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
override def nullable: Boolean = strExpr.nullable || delimExpr.nullable || countExpr.nullable
override def children: Seq[Expression] = Seq(strExpr, delimExpr, countExpr)
override def prettyName: String = "substring_index"

override def eval(input: InternalRow): Any = {
val str = strExpr.eval(input)
if (str != null) {
val delim = delimExpr.eval(input)
if (delim != null) {
val count = countExpr.eval(input)
if (count != null) {
return str.asInstanceOf[UTF8String].subStringIndex(
delim.asInstanceOf[UTF8String],
count.asInstanceOf[Int])
}
}
}
null
override def nullSafeEval(str: Any, delim: Any, count: Any): Any = {
str.asInstanceOf[UTF8String].subStringIndex(
delim.asInstanceOf[UTF8String],
count.asInstanceOf[Int])
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val str = strExpr.gen(ctx)
val delim = delimExpr.gen(ctx)
val count = countExpr.gen(ctx)
val resultCode = s"${str.primitive}.subStringIndex(${delim.primitive}, ${count.primitive})"
s"""
${str.code}
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${str.isNull}) {
${delim.code}
if (!${delim.isNull}) {
${count.code}
if (!${count.isNull}) {
${ev.isNull} = false;
${ev.primitive} = $resultCode;
}
}
}
"""
defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,32 +190,32 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

test("string substring_index function") {
checkEvaluation(
Substring_index(Literal("www.apache.org"), Literal("."), Literal(3)), "www.apache.org")
SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(3)), "www.apache.org")
checkEvaluation(
Substring_index(Literal("www.apache.org"), Literal("."), Literal(2)), "www.apache")
SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(2)), "www.apache")
checkEvaluation(
Substring_index(Literal("www.apache.org"), Literal("."), Literal(1)), "www")
SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(1)), "www")
checkEvaluation(
Substring_index(Literal("www.apache.org"), Literal("."), Literal(0)), "")
SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(0)), "")
checkEvaluation(
Substring_index(Literal("www.apache.org"), Literal("."), Literal(-3)), "www.apache.org")
SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-3)), "www.apache.org")
checkEvaluation(
Substring_index(Literal("www.apache.org"), Literal("."), Literal(-2)), "apache.org")
SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-2)), "apache.org")
checkEvaluation(
Substring_index(Literal("www.apache.org"), Literal("."), Literal(-1)), "org")
SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-1)), "org")
checkEvaluation(
Substring_index(Literal(""), Literal("."), Literal(-2)), "")
SubstringIndex(Literal(""), Literal("."), Literal(-2)), "")
checkEvaluation(
Substring_index(Literal.create(null, StringType), Literal("."), Literal(-2)), null)
checkEvaluation(Substring_index(
SubstringIndex(Literal.create(null, StringType), Literal("."), Literal(-2)), null)
checkEvaluation(SubstringIndex(
Literal("www.apache.org"), Literal.create(null, StringType), Literal(-2)), null)
// non ascii chars
// scalastyle:off
checkEvaluation(
Substring_index(Literal("大千世界大千世界"), Literal( ""), Literal(2)), "大千世界大")
SubstringIndex(Literal("大千世界大千世界"), Literal( ""), Literal(2)), "大千世界大")
// scalastyle:on
checkEvaluation(
Substring_index(Literal("www||apache||org"), Literal( "||"), Literal(2)), "www||apache")
SubstringIndex(Literal("www||apache||org"), Literal( "||"), Literal(2)), "www||apache")
}

test("LIKE literal Regular Expression") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1796,7 +1796,7 @@ object functions {
* @group string_funcs
*/
def substring_index(str: Column, delim: String, count: Int): Column =
Substring_index(str.expr, lit(delim).expr, lit(count).expr)
SubstringIndex(str.expr, lit(delim).expr, lit(count).expr)

/**
* Locate the position of the first occurrence of substr.
Expand Down
Loading

0 comments on commit 3ce7802

Please sign in to comment.