From ac863e9048e1aedf587aaf5e90d09b91bbf8ec25 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Wed, 22 Jul 2015 15:52:56 +0800 Subject: [PATCH] reduce the calling of numChars --- .../expressions/stringOperations.scala | 78 ++-------- .../apache/spark/unsafe/types/UTF8String.java | 140 +++++++++++++++++- .../spark/unsafe/types/UTF8StringSuite.java | 1 + 3 files changed, 148 insertions(+), 71 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 35ec2c991a94b..2d921124c38ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -21,10 +21,7 @@ import java.text.DecimalFormat import java.util.Locale import java.util.regex.{MatchResult, Pattern} -import org.apache.commons.lang.StringUtils - import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -371,75 +368,22 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr override def nullable: Boolean = strExpr.nullable || delimExpr.nullable || countExpr.nullable override def children: Seq[Expression] = Seq(strExpr, delimExpr, countExpr) override def prettyName: String = "substring_index" - override def toString: String = s"substring_index($strExpr, $delimExpr, $countExpr)" override def eval(input: InternalRow): Any = { val str = strExpr.eval(input) - val delim = delimExpr.eval(input) - val count = countExpr.eval(input) - if (str == null || delim == null || count == null) { - null - } else { - subStrIndex( - str.asInstanceOf[UTF8String], - delim.asInstanceOf[UTF8String], - count.asInstanceOf[Int]) - } - } - - private def lastOrdinalIndexOf( - str: UTF8String, searchStr: UTF8String, ordinal: Int, lastIndex: Boolean = false): Int = { - ordinalIndexOf(str, searchStr, ordinal, true) - } - - private def ordinalIndexOf( - str: UTF8String, searchStr: UTF8String, ordinal: Int, lastIndex: Boolean = false): Int = { - if (str == null || searchStr == null || ordinal <= 0) { - return -1 - } - val strNumChars = str.numChars() - if (searchStr.numBytes() == 0) { - return if (lastIndex) {strNumChars} else {0} - } - var found = 0 - var index = if (lastIndex) {strNumChars} else {0} - do { - if (lastIndex) { - index = str.lastIndexOf(searchStr, index - 1) - } else { - index = str.indexOf(searchStr, index + 1) - } - if (index < 0) { - return index - } - found += 1 - } while (found < ordinal) - index - } - - private def subStrIndex(strUtf8: UTF8String, delimUtf8: UTF8String, count: Int): UTF8String = { - if (strUtf8 == null || delimUtf8 == null || count == null) { - return null - } - if (strUtf8.numBytes() == 0 || delimUtf8.numBytes() == 0 || count == 0) { - return UTF8String.fromString("") - } - val res = if (count > 0) { - val idx = ordinalIndexOf(strUtf8, delimUtf8, count) - if (idx != -1) { - strUtf8.substring(0, idx) - } else { - strUtf8 - } - } else { - val idx = lastOrdinalIndexOf(strUtf8, delimUtf8, -count) - if (idx != -1) { - strUtf8.substring(idx + delimUtf8.numChars(), strUtf8.numChars()) - } else { - strUtf8 + if (str != null) { + val delim = delimExpr.eval(input) + if (delim != null) { + val count = countExpr.eval(input) + if (count != null) { + return UTF8String.subStringIndex( + str.asInstanceOf[UTF8String], + delim.asInstanceOf[UTF8String], + count.asInstanceOf[Int]) + } } } - res + null } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 78d767ee4de12..11e34f95bce8a 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -165,6 +165,27 @@ public UTF8String substring(final int start, final int until) { return fromBytes(bytes); } + /** + * Returns a substring of this from start to end. + * @param start the position of first code point + */ + public UTF8String substring(final int start) { + if (start >= numBytes) { + return fromBytes(new byte[0]); + } + + int i = 0; + int c = 0; + while (i < numBytes && c < start) { + i += numBytesForFirstByte(getByte(i)); + c += 1; + } + + byte[] bytes = new byte[numBytes - i]; + copyMemory(base, offset + i, bytes, BYTE_ARRAY_OFFSET, numBytes - i); + return fromBytes(bytes); + } + public UTF8String substringSQL(int pos, int length) { // Information regarding the pos calculation: // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and @@ -391,7 +412,19 @@ private int indexEnd(int startCodePoint) { return i; } + /** + * Returns the index within this string of the last occurrence of the + * specified substring, searching backward starting at the specified index. + * @param v the substring to search for. + * @param startCodePoint the index to start search from + * @return the index of the last occurrence of the specified substring, + * searching backward from the specified index, + * or {@code -1} if there is no such occurrence. + */ public int lastIndexOf(UTF8String v, int startCodePoint) { + return lastIndexOf(v, v.numChars(), startCodePoint); + } + public int lastIndexOf(UTF8String v, int vNumChars, int startCodePoint) { if (v.numBytes == 0) { return 0; } @@ -399,22 +432,121 @@ public int lastIndexOf(UTF8String v, int startCodePoint) { return -1; } int fromIndexEnd = indexEnd(startCodePoint); - int count = startCodePoint; - int vNumChars = v.numChars(); do { if (fromIndexEnd - v.numBytes + 1 < 0 ) { return -1; } if (ByteArrayMethods.arrayEquals( base, offset + fromIndexEnd - v.numBytes + 1, v.base, v.offset, v.numBytes)) { - return count - vNumChars + 1; + int count = 0; // count from right most to the match end in byte. + while (fromIndexEnd >= 0) { + count++; + fromIndexEnd = firstOfCurrentCodePoint(fromIndexEnd) - 1; + } + return count - vNumChars; } fromIndexEnd = firstOfCurrentCodePoint(fromIndexEnd) - 1; - count--; } while (fromIndexEnd >= 0); return -1; } + /** + * Finds the n-th last index within a String. + * This method uses {@link String#lastIndexOf(String)}.

+ * + * @param str the String to check, may be null + * @param searchStr the String to find, may be null + * @param searchStrNumChars num of code ponts of the searchStr + * @param ordinal the n-th last searchStr to find + * @return the n-th last index of the search String, + * -1 if no match or null string input + */ + public static int lastOrdinalIndexOf( + UTF8String str, + UTF8String searchStr, + int searchStrNumChars, + int ordinal) { + return doOrdinalIndexOf(str, searchStr, searchStrNumChars, ordinal, true); + } + /** + * Finds the n-th index within a String, handling null. + * A null String will return -1 + * + * @param str the String to check, may be null + * @param searchStr the String to find, may be null + * @param searchStrNumChars num of code points of searchStr + * @param ordinal the n-th searchStr to find + * @return the n-th index of the search String, + * -1 if no match or null string input + */ + public static int ordinalIndexOf( + UTF8String str, + UTF8String searchStr, + int searchStrNumChars, + int ordinal) { + return doOrdinalIndexOf(str, searchStr, searchStrNumChars, ordinal, false); + } + + private static int doOrdinalIndexOf( + UTF8String str, + UTF8String searchStr, + int searchStrNumChars, + int ordinal, + boolean lastIndex) { + if (str == null || searchStr == null || ordinal <= 0) { + return -1; + } + // Only calc numChars when lastIndex == true sicnc the calculation is expensive + int strNumChars = 0; + if (lastIndex) { + strNumChars = str.numChars(); + } + if (searchStr.numBytes == 0) { + return lastIndex ? strNumChars : 0; + } + int found = 0; + int index = lastIndex ? strNumChars : 0; + do { + if (lastIndex) { + index = str.lastIndexOf(searchStr, searchStrNumChars, index - 1); + } else { + index = str.indexOf(searchStr, index + 1); + } + if (index < 0) { + return index; + } + found += 1; + } while (found < ordinal); + return index; + } + /** + * Returns the substring from string str before count occurrences of the delimiter delim. + * If count is positive, everything the left of the final delimiter (counting from left) is + * returned. If count is negative, every to the right of the final delimiter (counting from the + * right) is returned. substring_index performs a case-sensitive match when searching for delim. + */ + public static UTF8String subStringIndex(UTF8String str, UTF8String delim, int count) { + if (str.numBytes == 0 || delim.numBytes == 0 || count == 0) { + return UTF8String.EMPTY_UTF8; + } + int delimNumChars = delim.numChars(); + if (count > 0) { + int idx = ordinalIndexOf(str, delim, delimNumChars, count); + if (idx != -1) { + return str.substring(0, idx); + } else { + return str; + } + } else { + int idx = lastOrdinalIndexOf(str, delim, delimNumChars, -count); + if (idx != -1) { + return str.substring(idx + delimNumChars); + } else { + return str; + } + } + } + /** * Returns str, right-padded with pad to a length of len * For example: diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index bee232b11a11e..df69d8e655ead 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -226,6 +226,7 @@ public void lastIndexOf() { assertEquals(0, fromString("").lastIndexOf(fromString(""), 0)); assertEquals(-1, fromString("").lastIndexOf(fromString("l"), 0)); assertEquals(0, fromString("hello").lastIndexOf(fromString(""), 0)); + assertEquals(0, fromString("hello").lastIndexOf(fromString("h"), 4)); assertEquals(-1, fromString("hello").lastIndexOf(fromString("l"), 0)); assertEquals(3, fromString("hello").lastIndexOf(fromString("l"), 3)); assertEquals(-1, fromString("hello").lastIndexOf(fromString("a"), 4));