From ac863e9048e1aedf587aaf5e90d09b91bbf8ec25 Mon Sep 17 00:00:00 2001
From: "zhichao.li"
Date: Wed, 22 Jul 2015 15:52:56 +0800
Subject: [PATCH] reduce the calling of numChars
---
.../expressions/stringOperations.scala | 78 ++--------
.../apache/spark/unsafe/types/UTF8String.java | 140 +++++++++++++++++-
.../spark/unsafe/types/UTF8StringSuite.java | 1 +
3 files changed, 148 insertions(+), 71 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 35ec2c991a94b..2d921124c38ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -21,10 +21,7 @@ import java.text.DecimalFormat
import java.util.Locale
import java.util.regex.{MatchResult, Pattern}
-import org.apache.commons.lang.StringUtils
-
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -371,75 +368,22 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr
override def nullable: Boolean = strExpr.nullable || delimExpr.nullable || countExpr.nullable
override def children: Seq[Expression] = Seq(strExpr, delimExpr, countExpr)
override def prettyName: String = "substring_index"
- override def toString: String = s"substring_index($strExpr, $delimExpr, $countExpr)"
override def eval(input: InternalRow): Any = {
val str = strExpr.eval(input)
- val delim = delimExpr.eval(input)
- val count = countExpr.eval(input)
- if (str == null || delim == null || count == null) {
- null
- } else {
- subStrIndex(
- str.asInstanceOf[UTF8String],
- delim.asInstanceOf[UTF8String],
- count.asInstanceOf[Int])
- }
- }
-
- private def lastOrdinalIndexOf(
- str: UTF8String, searchStr: UTF8String, ordinal: Int, lastIndex: Boolean = false): Int = {
- ordinalIndexOf(str, searchStr, ordinal, true)
- }
-
- private def ordinalIndexOf(
- str: UTF8String, searchStr: UTF8String, ordinal: Int, lastIndex: Boolean = false): Int = {
- if (str == null || searchStr == null || ordinal <= 0) {
- return -1
- }
- val strNumChars = str.numChars()
- if (searchStr.numBytes() == 0) {
- return if (lastIndex) {strNumChars} else {0}
- }
- var found = 0
- var index = if (lastIndex) {strNumChars} else {0}
- do {
- if (lastIndex) {
- index = str.lastIndexOf(searchStr, index - 1)
- } else {
- index = str.indexOf(searchStr, index + 1)
- }
- if (index < 0) {
- return index
- }
- found += 1
- } while (found < ordinal)
- index
- }
-
- private def subStrIndex(strUtf8: UTF8String, delimUtf8: UTF8String, count: Int): UTF8String = {
- if (strUtf8 == null || delimUtf8 == null || count == null) {
- return null
- }
- if (strUtf8.numBytes() == 0 || delimUtf8.numBytes() == 0 || count == 0) {
- return UTF8String.fromString("")
- }
- val res = if (count > 0) {
- val idx = ordinalIndexOf(strUtf8, delimUtf8, count)
- if (idx != -1) {
- strUtf8.substring(0, idx)
- } else {
- strUtf8
- }
- } else {
- val idx = lastOrdinalIndexOf(strUtf8, delimUtf8, -count)
- if (idx != -1) {
- strUtf8.substring(idx + delimUtf8.numChars(), strUtf8.numChars())
- } else {
- strUtf8
+ if (str != null) {
+ val delim = delimExpr.eval(input)
+ if (delim != null) {
+ val count = countExpr.eval(input)
+ if (count != null) {
+ return UTF8String.subStringIndex(
+ str.asInstanceOf[UTF8String],
+ delim.asInstanceOf[UTF8String],
+ count.asInstanceOf[Int])
+ }
}
}
- res
+ null
}
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 78d767ee4de12..11e34f95bce8a 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -165,6 +165,27 @@ public UTF8String substring(final int start, final int until) {
return fromBytes(bytes);
}
+ /**
+  * Returns the tail of this string, beginning at code point {@code start} and running to
+  * the end of the string.
+  *
+  * A {@code start} of zero or less yields a copy of the whole string. A {@code start} that
+  * is at least {@code numBytes} yields a new empty string without scanning, since a string
+  * cannot contain more code points than bytes.
+  *
+  * @param start the position of the first code point to keep
+  * @return a new UTF8String holding a copy of the selected bytes
+  */
+ public UTF8String substring(final int start) {
+   if (start >= numBytes) {
+     return fromBytes(new byte[0]);
+   }
+
+   // Walk forward one code point at a time until `start` code points have been skipped
+   // (or the bytes run out), tracking the byte offset reached.
+   int byteOffset = 0;
+   for (int skipped = 0; skipped < start && byteOffset < numBytes; skipped++) {
+     byteOffset += numBytesForFirstByte(getByte(byteOffset));
+   }
+
+   final int remaining = numBytes - byteOffset;
+   final byte[] tail = new byte[remaining];
+   copyMemory(base, offset + byteOffset, tail, BYTE_ARRAY_OFFSET, remaining);
+   return fromBytes(tail);
+ }
+
public UTF8String substringSQL(int pos, int length) {
// Information regarding the pos calculation:
// Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and
@@ -391,7 +412,19 @@ private int indexEnd(int startCodePoint) {
return i;
}
+ /**
+  * Returns the index within this string of the last occurrence of the
+  * specified substring, searching backward starting at the specified index.
+  *
+  * This overload counts the code points of {@code v} itself (via {@code v.numChars()}) and
+  * delegates to the three-argument overload; callers that already know that count can call
+  * the three-argument overload directly to avoid recomputing it.
+  *
+  * @param v the substring to search for.
+  * @param startCodePoint the index to start search from
+  * @return the index of the last occurrence of the specified substring,
+  *         searching backward from the specified index,
+  *         or {@code -1} if there is no such occurrence.
+  */
public int lastIndexOf(UTF8String v, int startCodePoint) {
+ return lastIndexOf(v, v.numChars(), startCodePoint);
+ }
+ public int lastIndexOf(UTF8String v, int vNumChars, int startCodePoint) {
if (v.numBytes == 0) {
return 0;
}
@@ -399,22 +432,121 @@ public int lastIndexOf(UTF8String v, int startCodePoint) {
return -1;
}
int fromIndexEnd = indexEnd(startCodePoint);
- int count = startCodePoint;
- int vNumChars = v.numChars();
do {
if (fromIndexEnd - v.numBytes + 1 < 0 ) {
return -1;
}
if (ByteArrayMethods.arrayEquals(
base, offset + fromIndexEnd - v.numBytes + 1, v.base, v.offset, v.numBytes)) {
- return count - vNumChars + 1;
+ int count = 0; // count from right most to the match end in byte.
+ while (fromIndexEnd >= 0) {
+ count++;
+ fromIndexEnd = firstOfCurrentCodePoint(fromIndexEnd) - 1;
+ }
+ return count - vNumChars;
}
fromIndexEnd = firstOfCurrentCodePoint(fromIndexEnd) - 1;
- count--;
} while (fromIndexEnd >= 0);
return -1;
}
+ /**
+  * Finds the n-th last index of {@code searchStr} within {@code str}, handling {@code null}.
+  * The search is delegated to {@code doOrdinalIndexOf} in backward mode.
+  *
+  * @param str the string to check, may be {@code null}
+  * @param searchStr the string to find, may be {@code null}
+  * @param searchStrNumChars number of code points in {@code searchStr}
+  * @param ordinal the n-th last {@code searchStr} to find
+  * @return the n-th last index of the search string,
+  *         {@code -1} if there is no match or either string is {@code null}
+  */
+ public static int lastOrdinalIndexOf(
+ UTF8String str,
+ UTF8String searchStr,
+ int searchStrNumChars,
+ int ordinal) {
+ return doOrdinalIndexOf(str, searchStr, searchStrNumChars, ordinal, true);
+ }
+ /**
+  * Finds the n-th index of {@code searchStr} within {@code str}, handling {@code null}.
+  * A {@code null} string will return {@code -1}.
+  * The search is delegated to {@code doOrdinalIndexOf} in forward mode.
+  *
+  * @param str the string to check, may be {@code null}
+  * @param searchStr the string to find, may be {@code null}
+  * @param searchStrNumChars number of code points in {@code searchStr}
+  * @param ordinal the n-th {@code searchStr} to find
+  * @return the n-th index of the search string,
+  *         {@code -1} if there is no match or either string is {@code null}
+  */
+ public static int ordinalIndexOf(
+ UTF8String str,
+ UTF8String searchStr,
+ int searchStrNumChars,
+ int ordinal) {
+ return doOrdinalIndexOf(str, searchStr, searchStrNumChars, ordinal, false);
+ }
+
+ private static int doOrdinalIndexOf(
+ UTF8String str,
+ UTF8String searchStr,
+ int searchStrNumChars,
+ int ordinal,
+ boolean lastIndex) {
+ if (str == null || searchStr == null || ordinal <= 0) {
+ return -1;
+ }
+ // Only calc numChars when lastIndex == true sicnc the calculation is expensive
+ int strNumChars = 0;
+ if (lastIndex) {
+ strNumChars = str.numChars();
+ }
+ if (searchStr.numBytes == 0) {
+ return lastIndex ? strNumChars : 0;
+ }
+ int found = 0;
+ int index = lastIndex ? strNumChars : 0;
+ do {
+ if (lastIndex) {
+ index = str.lastIndexOf(searchStr, searchStrNumChars, index - 1);
+ } else {
+ index = str.indexOf(searchStr, index + 1);
+ }
+ if (index < 0) {
+ return index;
+ }
+ found += 1;
+ } while (found < ordinal);
+ return index;
+ }
+ /**
+  * Returns the substring of {@code str} before {@code count} occurrences of the delimiter
+  * {@code delim}.
+  * If {@code count} is positive, everything to the left of the final delimiter (counting
+  * from the left) is returned. If {@code count} is negative, everything to the right of the
+  * final delimiter (counting from the right) is returned. When the requested occurrence is
+  * not found, {@code str} itself is returned. An empty {@code str}, an empty {@code delim},
+  * or a {@code count} of zero yields the empty string.
+  * substring_index performs a case-sensitive match when searching for delim.
+  */
+ public static UTF8String subStringIndex(UTF8String str, UTF8String delim, int count) {
+   final boolean nothingToSplit = str.numBytes == 0 || delim.numBytes == 0 || count == 0;
+   if (nothingToSplit) {
+     return UTF8String.EMPTY_UTF8;
+   }
+   // Count the delimiter's code points once and reuse the value for every search below.
+   final int delimNumChars = delim.numChars();
+   if (count > 0) {
+     final int cut = ordinalIndexOf(str, delim, delimNumChars, count);
+     return cut == -1 ? str : str.substring(0, cut);
+   }
+   final int cut = lastOrdinalIndexOf(str, delim, delimNumChars, -count);
+   // Skip past the delimiter itself so the result starts right after it.
+   return cut == -1 ? str : str.substring(cut + delimNumChars);
+ }
+
/**
* Returns str, right-padded with pad to a length of len
* For example:
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index bee232b11a11e..df69d8e655ead 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -226,6 +226,7 @@ public void lastIndexOf() {
assertEquals(0, fromString("").lastIndexOf(fromString(""), 0));
assertEquals(-1, fromString("").lastIndexOf(fromString("l"), 0));
assertEquals(0, fromString("hello").lastIndexOf(fromString(""), 0));
+ assertEquals(0, fromString("hello").lastIndexOf(fromString("h"), 4));
assertEquals(-1, fromString("hello").lastIndexOf(fromString("l"), 0));
assertEquals(3, fromString("hello").lastIndexOf(fromString("l"), 3));
assertEquals(-1, fromString("hello").lastIndexOf(fromString("a"), 4));