
sql module
rxin committed Apr 8, 2015
1 parent 04ec7ac commit 7e0db5e
Showing 20 changed files with 120 additions and 85 deletions.
@@ -92,7 +92,8 @@ class CachedTableSuite extends QueryTest {
 
   test("too big for memory") {
     val data = "*" * 10000
-    sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF().registerTempTable("bigData")
+    sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF()
+      .registerTempTable("bigData")
     table("bigData").persist(StorageLevel.MEMORY_AND_DISK)
     assert(table("bigData").count() === 200000L)
     table("bigData").unpersist(blocking = true)
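Most hunks in this commit, like the one above, only re-wrap lines that exceed scalastyle's 100-character limit. A minimal standalone sketch of the wrapping convention, with plain Scala collections standing in for the RDD chain (not part of the commit):

object WrapDemo {
  def main(args: Array[String]): Unit = {
    // Break a long fluent chain before the '.', indenting the continuation
    // by two extra spaces, so every physical line stays under 100 characters.
    val preview = (1 to 200000).view
      .map(_ => "*" * 10)
      .take(3)
      .mkString(", ")
    println(preview)
  }
}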
@@ -329,8 +329,9 @@ class DataFrameSuite extends QueryTest {
     checkAnswer(
       decimalData.agg(avg('a cast DecimalType(10, 2))),
       Row(new java.math.BigDecimal(2.0)))
+    // non-partial
     checkAnswer(
-      decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial
+      decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))),
       Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
   }
 
@@ -67,7 +67,7 @@ class QueryTest extends PlanTest {
     checkAnswer(df, Seq(expectedAnswer))
   }
 
-  def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = {
+  def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext) {
     test(sqlString) {
       checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
     }
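Note the direction of this hunk: `sqlTest` loses its explicit `: Unit =`, most likely because keeping it would push the signature past the 100-character limit, so Scala's procedure syntax is used instead; `sortTest()` in the next file gains an annotation. A hypothetical side-by-side of the two forms (procedure syntax was idiomatic in the Scala 2.10/2.11 era of this commit and was deprecated later, in 2.13):

object ProcedureSyntaxDemo {
  // Procedure syntax: no '=' before the body, the result type is implicitly Unit.
  def greet(name: String) { println(s"Hello, $name") }

  // The same method with an explicit result type, as sortTest() gains below.
  def greetExplicit(name: String): Unit = { println(s"Hello, $name") }

  def main(args: Array[String]): Unit = {
    greet("fred")
    greetExplicit("mary")
  }
}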
30 changes: 22 additions & 8 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -182,7 +182,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
       Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.002")))
 
     checkAnswer(sql(
-      "SELECT time FROM timestamps WHERE time IN ('1969-12-31 16:00:00.001','1969-12-31 16:00:00.002')"),
+      """
+        |SELECT time FROM timestamps
+        |WHERE time IN ('1969-12-31 16:00:00.001','1969-12-31 16:00:00.002')
+      """.stripMargin),
       Seq(Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001")),
         Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.002"))))
 
@@ -248,7 +251,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
       Row("1"))
   }
 
-  def sortTest() = {
+  def sortTest(): Unit = {
     checkAnswer(
       sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"),
       Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2)))
@@ -327,7 +330,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
 
   test("from follow multiple brackets") {
     checkAnswer(sql(
-      "select key from ((select * from testData limit 1) union all (select * from testData limit 1)) x limit 1"),
+      """
+        |select key from ((select * from testData limit 1)
+        |  union all (select * from testData limit 1)) x limit 1
+      """.stripMargin),
       Row(1)
     )
 
@@ -337,7 +343,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
     )
 
     checkAnswer(sql(
-      "select key from (select * from testData limit 1 union all select * from testData limit 1) x limit 1"),
+      """
+        |select key from
+        |  (select * from testData limit 1 union all select * from testData limit 1) x
+        |  limit 1
+      """.stripMargin),
       Row(1)
     )
   }
@@ -384,7 +394,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
       Seq(Row(1, 0), Row(2, 1)))
 
     checkAnswer(
-      sql("SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"),
+      sql(
+        """
+          |SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3
+        """.stripMargin),
       Row(2, 1, 2, 2, 1))
   }
 
@@ -997,7 +1010,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
   }
 
   test("SPARK-3483 Special chars in column names") {
-    val data = sparkContext.parallelize(Seq("""{"key?number1": "value1", "key.number2": "value2"}"""))
+    val data = sparkContext.parallelize(
+      Seq("""{"key?number1": "value1", "key.number2": "value2"}"""))
     jsonRDD(data).registerTempTable("records")
     sql("SELECT `key?number1` FROM records")
   }
@@ -1082,8 +1096,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
   }
 
   test("SPARK-6145: ORDER BY test for nested fields") {
-    jsonRDD(sparkContext.makeRDD(
-      """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)).registerTempTable("nestedOrder")
+    jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil))
+      .registerTempTable("nestedOrder")
 
     checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1))
     checkAnswer(sql("SELECT a.b FROM nestedOrder ORDER BY a.b"), Row(1))
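Several hunks in this file rewrite long single-line SQL strings as multi-line literals. A small sketch, independent of Spark, of how `stripMargin` keeps the wrapped query free of stray indentation:

object StripMarginDemo {
  def main(args: Array[String]): Unit = {
    // stripMargin strips each line's leading whitespace up to and including
    // the '|', so the indentation of the Scala source never leaks into the
    // query text. The leading and trailing newlines that remain are harmless
    // whitespace as far as a SQL parser is concerned.
    val query =
      """
        |SELECT time FROM timestamps
        |WHERE time IN ('1969-12-31 16:00:00.001','1969-12-31 16:00:00.002')
      """.stripMargin
    println(query)
  }
}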
@@ -80,7 +80,7 @@ class ScalaReflectionRelationSuite extends FunSuite {
 
   test("query case class RDD") {
     val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
-                           new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3))
+      new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3))
     val rdd = sparkContext.parallelize(data :: Nil)
     rdd.toDF().registerTempTable("reflectData")
 
@@ -103,7 +103,8 @@ class ScalaReflectionRelationSuite extends FunSuite {
     val rdd = sparkContext.parallelize(data :: Nil)
     rdd.toDF().registerTempTable("reflectOptionalData")
 
-    assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null)))
+    assert(sql("SELECT * FROM reflectOptionalData").collect().head ===
+      Row.fromSeq(Seq.fill(7)(null)))
   }
 
   // Equality is broken for Arrays, so we test that separately.
@@ -63,7 +63,7 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
     }
   }
 
-  override def userClass = classOf[MyDenseVector]
+  override def userClass: Class[MyDenseVector] = classOf[MyDenseVector]
 
   private[spark] override def asNullable: MyDenseVectorUDT = this
 }
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
 import org.apache.spark.sql.types.{Decimal, DataType, NativeType}
 
 object ColumnarTestUtils {
-  def makeNullRow(length: Int) = {
+  def makeNullRow(length: Int): GenericMutableRow = {
     val row = new GenericMutableRow(length)
     (0 until length).foreach(row.setNullAt)
     row
@@ -93,7 +93,7 @@ object ColumnarTestUtils {
 
   def makeUniqueValuesAndSingleValueRows[T <: NativeType](
       columnType: NativeColumnType[T],
-      count: Int) = {
+      count: Int): (Seq[T#JvmType], Seq[GenericMutableRow]) = {
 
     val values = makeUniqueRandomValues(columnType, count)
     val rows = values.map { value =>
@@ -31,7 +31,8 @@ class TestNullableColumnAccessor[T <: DataType, JvmType](
   with NullableColumnAccessor
 
 object TestNullableColumnAccessor {
-  def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType]) = {
+  def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType])
+    : TestNullableColumnAccessor[T, JvmType] = {
     // Skips the column type ID
     buffer.getInt()
     new TestNullableColumnAccessor(buffer, columnType)
@@ -27,7 +27,8 @@ class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T
   with NullableColumnBuilder
 
 object TestNullableColumnBuilder {
-  def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0) = {
+  def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0)
+    : TestNullableColumnBuilder[T, JvmType] = {
     val builder = new TestNullableColumnBuilder(columnType)
     builder.initialize(initialSize)
     builder
@@ -35,7 +35,7 @@ object TestCompressibleColumnBuilder {
   def apply[T <: NativeType](
       columnStats: ColumnStats,
       columnType: NativeColumnType[T],
-      scheme: CompressionScheme) = {
+      scheme: CompressionScheme): TestCompressibleColumnBuilder[T] = {
 
     val builder = new TestCompressibleColumnBuilder(columnStats, columnType, Seq(scheme))
     builder.initialize(0, "", useCompression = true)
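The five preceding files all make the same change: test helpers with inferred result types gain explicit annotations. A hypothetical example of why inference is risky on public members:

object ReturnTypeDemo {
  // Inferred: the public signature silently tracks the body, so rewriting the
  // implementation can change the API without any edit to this line.
  def makeBuffer(n: Int) = new StringBuilder(n)

  // Explicit: the signature is pinned down and readable at the call site.
  def makeBufferExplicit(n: Int): StringBuilder = new StringBuilder(n)

  def main(args: Array[String]): Unit = {
    println(makeBufferExplicit(16).append("ok").result())
  }
}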
@@ -30,4 +30,4 @@ class DebuggingSuite extends FunSuite {
   test("DataFrame.typeCheck()") {
     testData.typeCheck()
   }
-}
\ No newline at end of file
+}
98 changes: 50 additions & 48 deletions sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -45,10 +45,12 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
 
     conn = DriverManager.getConnection(url, properties)
     conn.prepareStatement("create schema test").executeUpdate()
-    conn.prepareStatement("create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate()
+    conn.prepareStatement(
+      "create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate()
     conn.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate()
     conn.prepareStatement("insert into test.people values ('mary', 2)").executeUpdate()
-    conn.prepareStatement("insert into test.people values ('joe ''foo'' \"bar\"', 3)").executeUpdate()
+    conn.prepareStatement(
+      "insert into test.people values ('joe ''foo'' \"bar\"', 3)").executeUpdate()
     conn.commit()
 
     sql(
@@ -132,83 +134,83 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
   }
 
   test("SELECT *") {
-    assert(sql("SELECT * FROM foobar").collect().size == 3)
+    assert(sql("SELECT * FROM foobar").collect().size === 3)
   }
 
   test("SELECT * WHERE (simple predicates)") {
-    assert(sql("SELECT * FROM foobar WHERE THEID < 1").collect().size == 0)
-    assert(sql("SELECT * FROM foobar WHERE THEID != 2").collect().size == 2)
-    assert(sql("SELECT * FROM foobar WHERE THEID = 1").collect().size == 1)
-    assert(sql("SELECT * FROM foobar WHERE NAME = 'fred'").collect().size == 1)
-    assert(sql("SELECT * FROM foobar WHERE NAME > 'fred'").collect().size == 2)
-    assert(sql("SELECT * FROM foobar WHERE NAME != 'fred'").collect().size == 2)
+    assert(sql("SELECT * FROM foobar WHERE THEID < 1").collect().size === 0)
+    assert(sql("SELECT * FROM foobar WHERE THEID != 2").collect().size === 2)
+    assert(sql("SELECT * FROM foobar WHERE THEID = 1").collect().size === 1)
+    assert(sql("SELECT * FROM foobar WHERE NAME = 'fred'").collect().size === 1)
+    assert(sql("SELECT * FROM foobar WHERE NAME > 'fred'").collect().size === 2)
+    assert(sql("SELECT * FROM foobar WHERE NAME != 'fred'").collect().size === 2)
   }
 
   test("SELECT * WHERE (quoted strings)") {
-    assert(sql("select * from foobar").where('NAME === "joe 'foo' \"bar\"").collect().size == 1)
+    assert(sql("select * from foobar").where('NAME === "joe 'foo' \"bar\"").collect().size === 1)
   }
 
   test("SELECT first field") {
     val names = sql("SELECT NAME FROM foobar").collect().map(x => x.getString(0)).sortWith(_ < _)
-    assert(names.size == 3)
+    assert(names.size === 3)
     assert(names(0).equals("fred"))
     assert(names(1).equals("joe 'foo' \"bar\""))
     assert(names(2).equals("mary"))
   }
 
   test("SELECT second field") {
     val ids = sql("SELECT THEID FROM foobar").collect().map(x => x.getInt(0)).sortWith(_ < _)
-    assert(ids.size == 3)
-    assert(ids(0) == 1)
-    assert(ids(1) == 2)
-    assert(ids(2) == 3)
+    assert(ids.size === 3)
+    assert(ids(0) === 1)
+    assert(ids(1) === 2)
+    assert(ids(2) === 3)
   }
 
   test("SELECT * partitioned") {
     assert(sql("SELECT * FROM parts").collect().size == 3)
   }
 
   test("SELECT WHERE (simple predicates) partitioned") {
-    assert(sql("SELECT * FROM parts WHERE THEID < 1").collect().size == 0)
-    assert(sql("SELECT * FROM parts WHERE THEID != 2").collect().size == 2)
-    assert(sql("SELECT THEID FROM parts WHERE THEID = 1").collect().size == 1)
+    assert(sql("SELECT * FROM parts WHERE THEID < 1").collect().size === 0)
+    assert(sql("SELECT * FROM parts WHERE THEID != 2").collect().size === 2)
+    assert(sql("SELECT THEID FROM parts WHERE THEID = 1").collect().size === 1)
   }
 
   test("SELECT second field partitioned") {
     val ids = sql("SELECT THEID FROM parts").collect().map(x => x.getInt(0)).sortWith(_ < _)
-    assert(ids.size == 3)
-    assert(ids(0) == 1)
-    assert(ids(1) == 2)
-    assert(ids(2) == 3)
+    assert(ids.size === 3)
+    assert(ids(0) === 1)
+    assert(ids(1) === 2)
+    assert(ids(2) === 3)
   }
 
   test("Basic API") {
-    assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE").collect.size == 3)
+    assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE").collect().size === 3)
   }
 
   test("Partitioning via JDBCPartitioningInfo API") {
     assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3)
-      .collect.size == 3)
+      .collect.size === 3)
   }
 
   test("Partitioning via list-of-where-clauses API") {
     val parts = Array[String]("THEID < 2", "THEID >= 2")
-    assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts).collect.size == 3)
+    assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts).collect().size === 3)
   }
 
   test("H2 integral types") {
     val rows = sql("SELECT * FROM inttypes WHERE A IS NOT NULL").collect()
-    assert(rows.size == 1)
-    assert(rows(0).getInt(0) == 1)
-    assert(rows(0).getBoolean(1) == false)
-    assert(rows(0).getInt(2) == 3)
-    assert(rows(0).getInt(3) == 4)
-    assert(rows(0).getLong(4) == 1234567890123L)
+    assert(rows.size === 1)
+    assert(rows(0).getInt(0) === 1)
+    assert(rows(0).getBoolean(1) === false)
+    assert(rows(0).getInt(2) === 3)
+    assert(rows(0).getInt(3) === 4)
+    assert(rows(0).getLong(4) === 1234567890123L)
   }
 
   test("H2 null entries") {
     val rows = sql("SELECT * FROM inttypes WHERE A IS NULL").collect()
-    assert(rows.size == 1)
+    assert(rows.size === 1)
     assert(rows(0).isNullAt(0))
     assert(rows(0).isNullAt(1))
     assert(rows(0).isNullAt(2))
@@ -230,27 +232,27 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
     val rows = sql("SELECT * FROM timetypes").collect()
     val cal = new GregorianCalendar(java.util.Locale.ROOT)
     cal.setTime(rows(0).getAs[java.sql.Timestamp](0))
-    assert(cal.get(Calendar.HOUR_OF_DAY) == 12)
-    assert(cal.get(Calendar.MINUTE) == 34)
-    assert(cal.get(Calendar.SECOND) == 56)
+    assert(cal.get(Calendar.HOUR_OF_DAY) === 12)
+    assert(cal.get(Calendar.MINUTE) === 34)
+    assert(cal.get(Calendar.SECOND) === 56)
     cal.setTime(rows(0).getAs[java.sql.Timestamp](1))
-    assert(cal.get(Calendar.YEAR) == 1996)
-    assert(cal.get(Calendar.MONTH) == 0)
-    assert(cal.get(Calendar.DAY_OF_MONTH) == 1)
+    assert(cal.get(Calendar.YEAR) === 1996)
+    assert(cal.get(Calendar.MONTH) === 0)
+    assert(cal.get(Calendar.DAY_OF_MONTH) === 1)
     cal.setTime(rows(0).getAs[java.sql.Timestamp](2))
-    assert(cal.get(Calendar.YEAR) == 2002)
-    assert(cal.get(Calendar.MONTH) == 1)
-    assert(cal.get(Calendar.DAY_OF_MONTH) == 20)
-    assert(cal.get(Calendar.HOUR) == 11)
-    assert(cal.get(Calendar.MINUTE) == 22)
-    assert(cal.get(Calendar.SECOND) == 33)
-    assert(rows(0).getAs[java.sql.Timestamp](2).getNanos == 543543543)
+    assert(cal.get(Calendar.YEAR) === 2002)
+    assert(cal.get(Calendar.MONTH) === 1)
+    assert(cal.get(Calendar.DAY_OF_MONTH) === 20)
+    assert(cal.get(Calendar.HOUR) === 11)
+    assert(cal.get(Calendar.MINUTE) === 22)
+    assert(cal.get(Calendar.SECOND) === 33)
+    assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543543)
   }
 
   test("H2 floating-point types") {
     val rows = sql("SELECT * FROM flttypes").collect()
-    assert(rows(0).getDouble(0) == 1.00000000000000022) // Yes, I meant ==.
-    assert(rows(0).getDouble(1) == 1.00000011920928955) // Yes, I meant ==.
+    assert(rows(0).getDouble(0) === 1.00000000000000022) // Yes, I meant ==.
+    assert(rows(0).getDouble(1) === 1.00000011920928955) // Yes, I meant ==.
     assert(rows(0).getAs[BigDecimal](2)
       .equals(new BigDecimal("123456789012345.54321543215432100000")))
   }
@@ -264,7 +266,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
        | user 'testUser', password 'testPass')
      """.stripMargin.replaceAll("\n", " "))
     val rows = sql("SELECT * FROM hack").collect()
-    assert(rows(0).getDouble(0) == 1.00000011920928955) // Yes, I meant ==.
+    assert(rows(0).getDouble(0) === 1.00000011920928955) // Yes, I meant ==.
     // For some reason, H2 computes this square incorrectly...
     assert(math.abs(rows(0).getDouble(1) - 1.00000023841859331) < 1e-12)
   }
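The bulk of JDBCSuite's diff swaps `==` for ScalaTest's `===` inside `assert`. Both compile, but `===` embeds the two operand values in the failure message, where a plain boolean assertion on older ScalaTest releases reported only "false was not true". A minimal sketch (FunSuite matches the era of this commit; current ScalaTest spells it AnyFunSuite):

import org.scalatest.FunSuite

class TripleEqualsDemo extends FunSuite {
  test("=== names both operands when it fails") {
    val ids = Seq(1, 2, 3)
    // If ids.size were 4, === would fail with "4 did not equal 3",
    // pointing straight at the mismatch.
    assert(ids.size === 3)
  }
}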