diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 269eb81bdbce3..6c20677c2b0d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -387,16 +387,33 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr } } - private def ordinalIndexOf(str: UTF8String, delim: UTF8String, count: Int): Int = { + private def lastOrdinalIndexOf( + str: UTF8String, searchStr: UTF8String, ordinal: Int, lastIndex: Boolean = false): Int = { + ordinalIndexOf(str, searchStr, ordinal, true) + } + + private def ordinalIndexOf( + str: UTF8String, searchStr: UTF8String, ordinal: Int, lastIndex: Boolean = false): Int = { + if (str == null || searchStr == null || ordinal <= 0) { + return -1 + } + val strNumChars = str.numChars() + if (searchStr.numBytes() == 0) { + return if (lastIndex) {strNumChars} else {0} + } var found = 0 - var index = -1 + var index = if (lastIndex) {strNumChars} else {0} do { - index = str.indexOf(delim, index + 1) + if (lastIndex) { + index = str.lastIndexOf(searchStr, index - 1) + } else { + index = str.indexOf(searchStr, index + 1) + } if (index < 0) { return index } found += 1 - } while (found < count) + } while (found < ordinal) index } @@ -407,24 +424,21 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr if (strUtf8.numBytes() == 0 || delimUtf8.numBytes() == 0 || count == 0) { return UTF8String.fromString("") } - val res: UTF8String = - if (count > 0) { - val idx = ordinalIndexOf(strUtf8, delimUtf8, count) - if (idx != -1) { - strUtf8.substring(0, idx) - } else { - strUtf8 - } + val res = if (count > 0) { + val idx = ordinalIndexOf(strUtf8, delimUtf8, count) + if (idx != -1) { + strUtf8.substring(0, idx) } else { - val str = strUtf8.toString - val delim = delimUtf8.toString - val idx = StringUtils.lastOrdinalIndexOf(str, delim, -count) - if (idx != -1) { - UTF8String.fromString(str.substring(idx + 1)) - } else { - UTF8String.fromString(str) - } + strUtf8 } + } else { + val idx = lastOrdinalIndexOf(strUtf8, delimUtf8, -count) + if (idx != -1) { + strUtf8.substring(idx + 1, strUtf8.numChars()) + } else { + strUtf8 + } + } res } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index bc5ce2c49a7ad..9a0eb16ec700b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -170,9 +170,11 @@ class StringFunctionsSuite extends QueryTest { checkAnswer( df.select(substring_index(lit(null), "ab", 2)), Row(null)) + // scalastyle:off checkAnswer( df.select(substring_index(lit("大千世界大千世界"), "千", 2)), Row("大千世界大")) + // scalastyle:on checkAnswer( df.selectExpr("""substring_index(a, ",", 2)"""), Row("ac,ab")) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 946d355f1fc28..f3d14178ce3cc 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -352,6 +352,69 @@ public int indexOf(UTF8String v, int start) { return -1; } + private enum ByteType {FIRSTBYTE, MIDBYTE, SINGLEBYTECHAR}; + + private ByteType checkByteType(Byte b) { + int firstTwoBits = (b >>> 6) & 0x03; + if (firstTwoBits == 3) { + return ByteType.FIRSTBYTE; + } else if (firstTwoBits == 2) { + return ByteType.MIDBYTE; + } else { + return ByteType.SINGLEBYTECHAR; + } + } + + /** + * Return the first byte position for a given byte which shared the same code point. + * @param bytePos any byte within the code point + * @return the first byte position of a given code point, throw exception if not a valid UTF8 str + */ + private int firstOfCurrentCodePoint(int bytePos) { + while (bytePos >= 0) { + if (ByteType.FIRSTBYTE == checkByteType(getByte(bytePos)) + || ByteType.SINGLEBYTECHAR == checkByteType(getByte(bytePos))) { + return bytePos; + } + bytePos--; + } + throw new RuntimeException("Invalid utf8 string"); + } + + private int endByte(int startCodePoint) { + int i = numBytes -1; // position in byte + int c = numChars() - 1; // position in character + while (i >=0 && c > startCodePoint) { + i = firstOfCurrentCodePoint(i) - 1; + c -= 1; + } + return i; + } + + public int lastIndexOf(UTF8String v, int startCodePoint) { + if (v.numBytes == 0) { + return 0; + } + if (numBytes == 0) { + return -1; + } + int fromIndexEnd = endByte(startCodePoint); + int count = startCodePoint; + int vNumChars = v.numChars(); + do { + if (fromIndexEnd - v.numBytes + 1 < 0 ) { + return -1; + } + if (ByteArrayMethods.arrayEquals( + base, offset + fromIndexEnd - v.numBytes + 1, v.base, v.offset, v.numBytes)) { + return count - vNumChars + 1; + } + fromIndexEnd = firstOfCurrentCodePoint(fromIndexEnd) - 1; + count--; + } while (fromIndexEnd >= 0); + return -1; + } + /** * Returns str, right-padded with pad to a length of len * For example: diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index e2a5628ff4d93..bee232b11a11e 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -221,6 +221,22 @@ public void indexOf() { assertEquals(3, fromString("数据砖头").indexOf(fromString("头"), 0)); } + @Test + public void lastIndexOf() { + assertEquals(0, fromString("").lastIndexOf(fromString(""), 0)); + assertEquals(-1, fromString("").lastIndexOf(fromString("l"), 0)); + assertEquals(0, fromString("hello").lastIndexOf(fromString(""), 0)); + assertEquals(-1, fromString("hello").lastIndexOf(fromString("l"), 0)); + assertEquals(3, fromString("hello").lastIndexOf(fromString("l"), 3)); + assertEquals(-1, fromString("hello").lastIndexOf(fromString("a"), 4)); + assertEquals(2, fromString("hello").lastIndexOf(fromString("ll"), 4)); + assertEquals(-1, fromString("hello").lastIndexOf(fromString("ll"), 0)); + assertEquals(5, fromString("数据砖头数据砖头").lastIndexOf(fromString("据砖"), 7)); + assertEquals(0, fromString("数据砖头").lastIndexOf(fromString("数"), 3)); + assertEquals(0, fromString("数据砖头").lastIndexOf(fromString("数"), 0)); + assertEquals(3, fromString("数据砖头").lastIndexOf(fromString("头"), 3)); + } + @Test public void reverse() { assertEquals(fromString("olleh"), fromString("hello").reverse());