Skip to content

Commit

Permalink
Finish all todos in suite
Browse files Browse the repository at this point in the history
  • Loading branch information
yjshen committed Jul 31, 2015
1 parent 52f51a0 commit c0800e6
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 44 deletions.
6 changes: 6 additions & 0 deletions unsafe/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@
<dependency>
<groupId>org.scalacheck</groupId>
<artifactId>scalacheck_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
Expand Down
19 changes: 9 additions & 10 deletions unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,9 @@ public UTF8String trim() {
int s = 0;
int e = this.numBytes - 1;
// skip all leading bytes in the range 0x00..0x20 (ASCII control characters and space)
while (s < this.numBytes && getByte(s) == 0x20) s++;
while (s < this.numBytes && getByte(s) <= 0x20 && getByte(s) >= 0x00) s++;
// skip all trailing bytes in the range 0x00..0x20 (ASCII control characters and space)
while (e >= 0 && getByte(e) == 0x20) e--;

while (e >= 0 && getByte(e) <= 0x20 && getByte(e) >= 0x00) e--;
if (s > e) {
// empty string
return UTF8String.fromBytes(new byte[0]);
Expand All @@ -316,7 +315,7 @@ public UTF8String trim() {
public UTF8String trimLeft() {
int s = 0;
// skip all leading bytes in the range 0x00..0x20 (ASCII control characters and space)
while (s < this.numBytes && getByte(s) == 0x20) s++;
while (s < this.numBytes && getByte(s) <= 0x20 && getByte(s) >= 0x00) s++;
if (s == this.numBytes) {
// empty string
return UTF8String.fromBytes(new byte[0]);
Expand All @@ -328,7 +327,7 @@ public UTF8String trimLeft() {
public UTF8String trimRight() {
int e = numBytes - 1;
// skip all trailing bytes in the range 0x00..0x20 (ASCII control characters and space)
while (e >= 0 && getByte(e) == 0x20) e--;
while (e >= 0 && getByte(e) <= 0x20 && getByte(e) >= 0x00) e--;

if (e < 0) {
// empty string
Expand All @@ -354,7 +353,7 @@ public UTF8String reverse() {
}

public UTF8String repeat(int times) {
if (times <=0) {
if (times <= 0) {
return EMPTY_UTF8;
}

Expand Down Expand Up @@ -414,7 +413,7 @@ public int indexOf(UTF8String v, int start) {
*/
public UTF8String rpad(int len, UTF8String pad) {
int spaces = len - this.numChars(); // number of char need to pad
if (spaces <= 0) {
if (spaces <= 0 || pad.numChars() == 0) {
// no padding at all, return the substring of the current string
return substring(0, len);
} else {
Expand All @@ -429,7 +428,7 @@ public UTF8String rpad(int len, UTF8String pad) {
int idx = 0;
while (idx < count) {
copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes);
++idx;
++ idx;
offset += pad.numBytes;
}
copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes);
Expand All @@ -446,7 +445,7 @@ public UTF8String rpad(int len, UTF8String pad) {
*/
public UTF8String lpad(int len, UTF8String pad) {
int spaces = len - this.numChars(); // number of char need to pad
if (spaces <= 0) {
if (spaces <= 0 || pad.numChars() == 0) {
// no padding at all, return the substring of the current string
return substring(0, len);
} else {
Expand All @@ -461,7 +460,7 @@ public UTF8String lpad(int len, UTF8String pad) {
int idx = 0;
while (idx < count) {
copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes);
++idx;
++ idx;
offset += pad.numBytes;
}
copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,6 @@ public void pad() {
assertEquals(fromString("hello?????"), fromString("hello").rpad(10, fromString("?????")));
assertEquals(fromString("???????"), EMPTY_UTF8.rpad(7, fromString("?????")));


assertEquals(fromString("数据砖"), fromString("数据砖头").lpad(3, fromString("????")));
assertEquals(fromString("?数据砖头"), fromString("数据砖头").lpad(5, fromString("????")));
assertEquals(fromString("??数据砖头"), fromString("数据砖头").lpad(6, fromString("????")));
Expand All @@ -289,6 +288,18 @@ public void pad() {
assertEquals(
fromString("数据砖头孙行者孙行者孙行"),
fromString("数据砖头").rpad(12, fromString("孙行者")));

assertEquals(EMPTY_UTF8, fromString("数据砖头").lpad(-10, fromString("孙行者")));
assertEquals(EMPTY_UTF8, fromString("数据砖头").lpad(-10, EMPTY_UTF8));
assertEquals(fromString("数据砖头"), fromString("数据砖头").lpad(5, EMPTY_UTF8));
assertEquals(fromString("数据砖"), fromString("数据砖头").lpad(3, EMPTY_UTF8));
assertEquals(EMPTY_UTF8, EMPTY_UTF8.lpad(3, EMPTY_UTF8));

assertEquals(EMPTY_UTF8, fromString("数据砖头").rpad(-10, fromString("孙行者")));
assertEquals(EMPTY_UTF8, fromString("数据砖头").rpad(-10, EMPTY_UTF8));
assertEquals(fromString("数据砖头"), fromString("数据砖头").rpad(5, EMPTY_UTF8));
assertEquals(fromString("数据砖"), fromString("数据砖头").rpad(3, EMPTY_UTF8));
assertEquals(EMPTY_UTF8, EMPTY_UTF8.rpad(3, EMPTY_UTF8));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,37 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.unsafe.types

import org.apache.commons.lang3.StringUtils

import org.scalacheck.{Arbitrary, Gen}
import org.scalatest.prop.GeneratorDrivenPropertyChecks
// scalastyle:off
import org.scalatest.{FunSuite, Matchers}

import org.apache.spark.unsafe.types.UTF8String.{fromString => toUTF8}

class UTF8StringPropertyChecks extends FunSuite with GeneratorDrivenPropertyChecks with Matchers {
// scalastyle:on

test("toString") {
forAll { (s: String) =>
assert(s === toUTF8(s).toString())
assert(toUTF8(s).toString() === s)
}
}

Expand Down Expand Up @@ -42,41 +63,35 @@ class UTF8StringPropertyChecks extends FunSuite with GeneratorDrivenPropertyChec

test("toUpperCase") {
forAll { (s: String) =>
assert(s.toUpperCase === toUTF8(s).toUpperCase.toString)
assert(toUTF8(s).toUpperCase === toUTF8(s.toUpperCase))
}
}

test("toLowerCase") {
forAll { (s: String) =>
assert(s.toLowerCase === toUTF8(s).toLowerCase.toString)
assert(toUTF8(s).toLowerCase === toUTF8(s.toLowerCase))
}
}

test("compare") {
forAll { (s1: String, s2: String) =>
assert(Math.signum(s1.compareTo(s2)) === Math.signum(toUTF8(s1).compareTo(toUTF8(s2))))
assert(Math.signum(toUTF8(s1).compareTo(toUTF8(s2))) === Math.signum(s1.compareTo(s2)))
}
}

test("substring") {
forAll { (s: String) =>
for (start <- 0 to s.length; end <- 0 to s.length) {
withClue(s"start=$start, end=$end") {
assert(s.substring(start, end) === toUTF8(s).substring(start, end).toString)
}
for (start <- 0 to s.length; end <- 0 to s.length; if start <= end) {
assert(toUTF8(s).substring(start, end).toString === s.substring(start, end))
}
}
}

// TODO: substringSQL

test("contains") {
forAll { (s: String) =>
for (start <- 0 to s.length; end <- 0 to s.length) {
for (start <- 0 to s.length; end <- 0 to s.length; if start <= end) {
val substring = s.substring(start, end)
withClue(s"substring=$substring") {
assert(s.contains(substring) === toUTF8(s).contains(toUTF8(substring)))
}
assert(toUTF8(s).contains(toUTF8(substring)) === s.contains(substring))
}
}
}
Expand All @@ -86,48 +101,146 @@ class UTF8StringPropertyChecks extends FunSuite with GeneratorDrivenPropertyChec
val randomString: Gen[String] = Arbitrary.arbString.arbitrary

test("trim, trimLeft, trimRight") {
// lTrim and rTrim are both modified from java.lang.String.trim
def lTrim(s: String): String = {
var st = 0
val array: Array[Char] = s.toCharArray
while ((st < s.length) && (array(st) <= ' ')) {
st += 1
}
if (st > 0) s.substring(st, s.length) else s
}
def rTrim(s: String): String = {
var len = s.length
val array: Array[Char] = s.toCharArray
while ((len > 0) && (array(len - 1) <= ' ')) {
len -= 1
}
if (len < s.length) s.substring(0, len) else s
}

forAll(
whitespaceString,
randomString,
whitespaceString
) { (start: String, middle: String, end: String) =>
val s = start + middle + end
assert(s.trim() === toUTF8(s).trim().toString)
assert(s.stripMargin === toUTF8(s).trimLeft().toString)
assert(s.reverse.stripMargin.reverse === toUTF8(s).trimRight().toString)
assert(toUTF8(s).trim() === toUTF8(s.trim()))
assert(toUTF8(s).trimLeft() === toUTF8(lTrim(s)))
assert(toUTF8(s).trimRight() === toUTF8(rTrim(s)))
}
}

test("reverse") {
forAll() { (s: String) =>
assert(s.reverse === toUTF8(s).reverse.toString)
forAll { (s: String) =>
assert(toUTF8(s).reverse === toUTF8(s.reverse))
}
}

// TODO: repeat
// TODO: indexOf
// TODO: lpad
// TODO: rpad
test("indexOf") {
  // Every substring of `s` must be found by UTF8String.indexOf (searching from 0)
  // at the same position that java.lang.String.indexOf reports.
  forAll { (s: String) =>
    for {
      start <- 0 to s.length
      end <- start to s.length
    } {
      val needle = s.substring(start, end)
      assert(toUTF8(s).indexOf(toUTF8(needle), 0) === s.indexOf(needle))
    }
  }
}

// Bounded Int generator in [-100, 100], used for the repeat count and pad length arguments
// in the tests below so that negative, zero and positive values are all exercised.
val randomInt = Gen.choose(-100, 100)

test("repeat") {
  // Reference implementation on java.lang.String: a non-positive count yields "".
  def repeat(str: String, times: Int): String =
    if (times <= 0) "" else str * times

  // The repeat count comes from the bounded `randomInt` generator: ScalaCheck's default
  // Int generator tends to produce counts so large that the test could hang.
  forAll(randomString, randomInt) { (s: String, times: Int) =>
    val expected = toUTF8(repeat(s, times))
    assert(toUTF8(s).repeat(times) === expected)
  }
}

test("lpad, rpad") {
  // Reference implementation of padding on java.lang.String.
  //  - length <= 0                   => ""
  //  - length <= origin.length       => origin truncated to `length` chars
  //  - pad is empty                  => origin unchanged (nothing to pad with)
  //  - otherwise                     => `pad` repeated (and cut to fit) on the chosen side
  def padding(origin: String, pad: String, length: Int, isLPad: Boolean): String = {
    if (length <= 0) {
      ""
    } else if (length <= origin.length) {
      origin.substring(0, length)
    } else if (pad.isEmpty) {
      origin
    } else {
      val toPad = length - origin.length
      val fullPads = pad * (toPad / pad.length)
      // substring(0, 0) is "", so no special case is needed when toPad divides evenly.
      val partPad = pad.substring(0, toPad % pad.length)
      if (isLPad) fullPads + partPad + origin else origin + fullPads + partPad
    }
  }

  forAll(randomString, randomString, randomInt) { (s: String, pad: String, length: Int) =>
    assert(toUTF8(s).lpad(length, toUTF8(pad)) ===
      toUTF8(padding(s, pad, length, isLPad = true)))
    assert(toUTF8(s).rpad(length, toUTF8(pad)) ===
      toUTF8(padding(s, pad, length, isLPad = false)))
  }
}

// Generator of string lists that may contain null elements: each element is drawn from
// either the constant null or `randomString`. The identifier is a misspelling of
// "nullableSeq"; it is left as is because the concat and concatWs tests refer to it by name.
val nullalbeSeq = Gen.listOf(Gen.oneOf[String](null: String, randomString))

test("concat") {
forAll() { (inputs: Seq[String]) =>
// TODO: test case where at least one of the inputs is null
assert(inputs.mkString === UTF8String.concat(inputs.map(toUTF8): _*).toString)
def concat(orgin: Seq[String]): String =
if (orgin.exists(_ == null)) null else orgin.mkString

forAll { (inputs: Seq[String]) =>
assert(UTF8String.concat(inputs.map(toUTF8): _*) === toUTF8(inputs.mkString))
}
forAll (nullalbeSeq) { (inputs: Seq[String]) =>
assert(UTF8String.concat(inputs.map(toUTF8): _*) === toUTF8(concat(inputs)))
}
}

test("concatWs") {
forAll() { (sep: String, inputs: Seq[String]) =>
// TODO: handle case where at least one of the inputs is null
assert(
inputs.mkString(sep) === UTF8String.concatWs(toUTF8(sep), inputs.map(toUTF8): _*).toString)
def concatWs(sep: String, inputs: Seq[String]): String = {
if (sep == null) return null
inputs.filter(_ != null).mkString(sep)
}

forAll { (sep: String, inputs: Seq[String]) =>
assert(UTF8String.concatWs(toUTF8(sep), inputs.map(toUTF8): _*) ===
toUTF8(inputs.mkString(sep)))
}
forAll(randomString, nullalbeSeq) {(sep: String, inputs: Seq[String]) =>
assert(UTF8String.concatWs(toUTF8(sep), inputs.map(toUTF8): _*) ===
toUTF8(concatWs(sep, inputs)))
}
}

// TODO: split
// TODO: enable this when we find a proper way to generate valid patterns
ignore("split") {
  // Compares UTF8String.split with java.lang.String.split for the same pattern and limit.
  // Registered with `ignore`, so it does not run.
  forAll { (s: String, pattern: String, limit: Int) =>
    val actual = toUTF8(s).split(toUTF8(pattern), limit)
    val expected = s.split(pattern, limit).map(part => toUTF8(part))
    assert(actual === expected)
  }
}

// TODO: levenshteinDistance that tests against StringUtils' implementation
test("levenshteinDistance") {
  // Cross-checks UTF8String.levenshteinDistance against commons-lang3's StringUtils.
  forAll { (one: String, another: String) =>
    val actual = toUTF8(one).levenshteinDistance(toUTF8(another))
    val expected = StringUtils.getLevenshteinDistance(one, another)
    assert(actual === expected)
  }
}

// TODO: equals(), hashCode(), and compare()
test("hashCode") {
  // Two UTF8String instances built independently from the same String must hash alike.
  forAll { (s: String) =>
    val first = toUTF8(s)
    val second = toUTF8(s)
    assert(first.hashCode() === second.hashCode())
  }
}

test("equals") {
  // UTF8String equality must agree with java.lang.String equality on the source strings.
  forAll { (one: String, another: String) =>
    val utf8Equal = toUTF8(one).equals(toUTF8(another))
    val stringEqual = one.equals(another)
    assert(utf8Equal === stringEqual)
  }
}
}

0 comments on commit c0800e6

Please sign in to comment.