Skip to content

Commit

Permalink
add lastIndexOf
Browse files Browse the repository at this point in the history
  • Loading branch information
zhichao-li committed Jul 22, 2015
1 parent 52d7b03 commit d92951b
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -387,16 +387,33 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr
}
}

private def ordinalIndexOf(str: UTF8String, delim: UTF8String, count: Int): Int = {
/**
 * Finds the `ordinal`-th occurrence of `searchStr` in `str`, counting occurrences from the
 * end of the string.
 *
 * This is a convenience wrapper around `ordinalIndexOf` that always searches backwards.
 * The `lastIndex` argument is accepted but ignored: the backward search is forced
 * regardless of the value passed in.
 *
 * @return whatever `ordinalIndexOf` returns for a backward search
 */
private def lastOrdinalIndexOf(
    str: UTF8String, searchStr: UTF8String, ordinal: Int, lastIndex: Boolean = false): Int =
  ordinalIndexOf(str, searchStr, ordinal, lastIndex = true)

/**
 * Finds the character position of the `ordinal`-th occurrence of `searchStr` in `str`.
 *
 * @param str the string to search; a null yields -1
 * @param searchStr the string to look for; a null yields -1
 * @param ordinal which occurrence to locate (1-based); values <= 0 yield -1
 * @param lastIndex when true, occurrences are counted from the end of `str` using
 *                  `lastIndexOf`; when false, from the start using `indexOf`
 * @return the position reported by the underlying search for the requested occurrence,
 *         or the negative value it returned once no further occurrence exists. For an
 *         empty `searchStr` the result is the character length of `str` when searching
 *         backwards and 0 when searching forwards.
 */
private def ordinalIndexOf(
    str: UTF8String, searchStr: UTF8String, ordinal: Int, lastIndex: Boolean = false): Int = {
  if (str == null || searchStr == null || ordinal <= 0) {
    return -1
  }
  val strNumChars = str.numChars()
  if (searchStr.numBytes() == 0) {
    return if (lastIndex) strNumChars else 0
  }
  var found = 0
  // The loop probes at `index + 1` (forward) or `index - 1` (backward), so the cursor must
  // start one step outside the string. Starting the forward cursor at 0 would skip an
  // occurrence located at position 0.
  var index = if (lastIndex) strNumChars else -1
  do {
    index = if (lastIndex) {
      str.lastIndexOf(searchStr, index - 1)
    } else {
      str.indexOf(searchStr, index + 1)
    }
    if (index < 0) {
      // Fewer than `ordinal` occurrences exist; propagate the "not found" value.
      return index
    }
    found += 1
  } while (found < ordinal)
  index
}

Expand All @@ -407,24 +424,21 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr
if (strUtf8.numBytes() == 0 || delimUtf8.numBytes() == 0 || count == 0) {
return UTF8String.fromString("")
}
val res: UTF8String =
if (count > 0) {
val idx = ordinalIndexOf(strUtf8, delimUtf8, count)
if (idx != -1) {
strUtf8.substring(0, idx)
} else {
strUtf8
}
val res = if (count > 0) {
val idx = ordinalIndexOf(strUtf8, delimUtf8, count)
if (idx != -1) {
strUtf8.substring(0, idx)
} else {
val str = strUtf8.toString
val delim = delimUtf8.toString
val idx = StringUtils.lastOrdinalIndexOf(str, delim, -count)
if (idx != -1) {
UTF8String.fromString(str.substring(idx + 1))
} else {
UTF8String.fromString(str)
}
strUtf8
}
} else {
val idx = lastOrdinalIndexOf(strUtf8, delimUtf8, -count)
if (idx != -1) {
strUtf8.substring(idx + 1, strUtf8.numChars())
} else {
strUtf8
}
}
res
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,11 @@ class StringFunctionsSuite extends QueryTest {
checkAnswer(
df.select(substring_index(lit(null), "ab", 2)),
Row(null))
// scalastyle:off
checkAnswer(
df.select(substring_index(lit("大千世界大千世界"), "", 2)),
Row("大千世界大"))
// scalastyle:on
checkAnswer(
df.selectExpr("""substring_index(a, ",", 2)"""),
Row("ac,ab"))
Expand Down
63 changes: 63 additions & 0 deletions unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,69 @@ public int indexOf(UTF8String v, int start) {
return -1;
}

/**
 * Role of a single byte inside a UTF-8 sequence, as assigned by checkByteType from the byte's
 * two most significant bits: 11 maps to FIRSTBYTE, 10 maps to MIDBYTE, and anything else maps
 * to SINGLEBYTECHAR.
 */
private enum ByteType {FIRSTBYTE, MIDBYTE, SINGLEBYTECHAR};

/**
 * Classifies one byte of a UTF-8 sequence by its two most significant bits.
 *
 * Takes a primitive byte so that inspecting a string byte-by-byte does not box every value.
 *
 * @param b the byte to classify
 * @return FIRSTBYTE when the top bits are 11, MIDBYTE when they are 10, and SINGLEBYTECHAR
 *         otherwise
 */
private ByteType checkByteType(byte b) {
  // Mask to an unsigned value first so the shift cannot drag sign bits into the result.
  int firstTwoBits = (b & 0xFF) >>> 6;
  if (firstTwoBits == 3) {
    return ByteType.FIRSTBYTE;
  } else if (firstTwoBits == 2) {
    return ByteType.MIDBYTE;
  } else {
    return ByteType.SINGLEBYTECHAR;
  }
}

/**
 * Walks backwards from the given byte to the byte that opens the code point containing it.
 *
 * @param bytePos index of any byte belonging to the code point
 * @return index of the nearest byte at or before bytePos that is classified as FIRSTBYTE or
 *         SINGLEBYTECHAR
 * @throws RuntimeException if no such byte exists at or before bytePos
 */
private int firstOfCurrentCodePoint(int bytePos) {
  for (int pos = bytePos; pos >= 0; pos--) {
    ByteType type = checkByteType(getByte(pos));
    boolean startsCodePoint = type == ByteType.FIRSTBYTE || type == ByteType.SINGLEBYTECHAR;
    if (startsCodePoint) {
      return pos;
    }
  }
  throw new RuntimeException("Invalid utf8 string");
}

private int endByte(int startCodePoint) {
int i = numBytes -1; // position in byte
int c = numChars() - 1; // position in character
while (i >=0 && c > startCodePoint) {
i = firstOfCurrentCodePoint(i) - 1;
c -= 1;
}
return i;
}

/**
 * Searches backwards for {@code v} and returns the character position where the match starts.
 *
 * The search considers matches whose final character lies at or before startCodePoint.
 * A startCodePoint beyond the last character is treated as the last character, so the
 * returned position is always a real position within this string.
 *
 * @param v the string to look for
 * @param startCodePoint character position (not byte offset) at which the backward search begins
 * @return the character position of the match; 0 when {@code v} is empty; -1 when this string
 *         is empty, startCodePoint is negative, or no match is found
 */
public int lastIndexOf(UTF8String v, int startCodePoint) {
  if (v.numBytes == 0) {
    return 0;
  }
  if (numBytes == 0 || startCodePoint < 0) {
    // Nothing can end at or before a negative position.
    return -1;
  }
  // Clamp the character cursor to the last character. endByte() already clamps the byte
  // cursor, so without this the two would disagree and the reported position would be
  // too large by the amount startCodePoint overshoots.
  int count = Math.min(startCodePoint, numChars() - 1);
  int fromIndexEnd = endByte(count);
  int vNumChars = v.numChars();
  while (fromIndexEnd >= 0) {
    int matchStart = fromIndexEnd - v.numBytes + 1;
    if (matchStart < 0) {
      // Fewer bytes remain than the needle has; no earlier position can match either.
      return -1;
    }
    if (ByteArrayMethods.arrayEquals(base, offset + matchStart, v.base, v.offset, v.numBytes)) {
      // count is the position of the match's last character; convert it to the first.
      return count - vNumChars + 1;
    }
    // Move both cursors back by exactly one character.
    fromIndexEnd = firstOfCurrentCodePoint(fromIndexEnd) - 1;
    count--;
  }
  return -1;
}

/**
* Returns str, right-padded with pad to a length of len
* For example:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,22 @@ public void indexOf() {
assertEquals(3, fromString("数据砖头").indexOf(fromString("头"), 0));
}

// Exercises UTF8String.lastIndexOf with empty, ASCII and multi-byte inputs.
@Test
public void lastIndexOf() {
// Empty haystack: an empty needle reports 0, a non-empty needle reports -1.
assertEquals(0, fromString("").lastIndexOf(fromString(""), 0));
assertEquals(-1, fromString("").lastIndexOf(fromString("l"), 0));
// Empty needle in a non-empty haystack also reports 0.
assertEquals(0, fromString("hello").lastIndexOf(fromString(""), 0));
// Single-character needle: the start position limits how far right a match may be.
assertEquals(-1, fromString("hello").lastIndexOf(fromString("l"), 0));
assertEquals(3, fromString("hello").lastIndexOf(fromString("l"), 3));
// Needle that does not occur at all.
assertEquals(-1, fromString("hello").lastIndexOf(fromString("a"), 4));
// Multi-character needle.
assertEquals(2, fromString("hello").lastIndexOf(fromString("ll"), 4));
assertEquals(-1, fromString("hello").lastIndexOf(fromString("ll"), 0));
// Multi-byte characters: expected values are character positions, not byte offsets.
assertEquals(5, fromString("数据砖头数据砖头").lastIndexOf(fromString("据砖"), 7));
assertEquals(0, fromString("数据砖头").lastIndexOf(fromString("数"), 3));
assertEquals(0, fromString("数据砖头").lastIndexOf(fromString("数"), 0));
assertEquals(3, fromString("数据砖头").lastIndexOf(fromString("头"), 3));
}

@Test
public void reverse() {
assertEquals(fromString("olleh"), fromString("hello").reverse());
Expand Down

0 comments on commit d92951b

Please sign in to comment.