Skip to content

Commit

Permalink
Fix Scala test.
Browse files · Browse the repository at this point in the history
  • Loading branch information
yhuai committed Feb 12, 2015
1 parent 572870d commit 7793dcb
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,18 @@ trait OverrideCatalog extends Catalog {
}

abstract override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
val dbName = if (!caseSensitive) {
if (databaseName.isDefined) Some(databaseName.get.toLowerCase) else None
} else {
databaseName
}

val temporaryTables = overrides.filter {
// If a temporary table does not have an associated database, we should return its name.
case ((None, _), _) => true
// If a temporary table does have an associated database, we should return it if the database
// matches the given database name.
case ((db: Some[String], _), _) if db == databaseName => true
case ((db: Some[String], _), _) if db == dbName => true
case _ => false
}.map {
case ((_, tableName), _) => (tableName, true)
Expand Down
35 changes: 20 additions & 15 deletions sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,43 @@

package org.apache.spark.sql

import org.scalatest.BeforeAndAfterAll
import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}

class ListTablesSuite extends QueryTest with BeforeAndAfterAll {
class ListTablesSuite extends QueryTest with BeforeAndAfter {

import org.apache.spark.sql.test.TestSQLContext.implicits._

val df =
sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDataFrame("key", "value")

override def beforeAll(): Unit = {
(1 to 10).foreach(i => df.registerTempTable(s"table$i"))
before {
df.registerTempTable("ListTablesSuiteTable")
}

override def afterAll(): Unit = {
(1 to 10).foreach(i => catalog.unregisterTable(Seq(s"table$i")))
after {
catalog.unregisterTable(Seq("ListTablesSuiteTable"))
}

test("get all tables") {
checkAnswer(tables(), (1 to 10).map(i => Row(s"table$i", true)))
checkAnswer(
tables().filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))

catalog.unregisterTable(Seq("ListTablesSuiteTable"))
assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
}

test("getting all Tables with a database name has no impact on returned table names") {
checkAnswer(tables("DB"), (1 to 10).map(i => Row(s"table$i", true)))
checkAnswer(
tables("DB").filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))

catalog.unregisterTable(Seq("ListTablesSuiteTable"))
assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
}

test("query the returned DataFrame of tables") {
Expand All @@ -53,15 +63,10 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll {
StructField("isTemporary", BooleanType, false) :: Nil)
assert(schema === tableDF.schema)

checkAnswer(
tableDF.select("tableName"),
(1 to 10).map(i => Row(s"table$i"))
)

tableDF.registerTempTable("tables")
checkAnswer(
sql("SELECT isTemporary, tableName from tables WHERE isTemporary"),
(1 to 10).map(i => Row(true, s"table$i"))
sql("SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"),
Row(true, "ListTablesSuiteTable")
)
checkAnswer(
tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,78 +23,55 @@ import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}

class ListTablesSuite extends QueryTest with BeforeAndAfterAll {

import org.apache.spark.sql.hive.test.TestHive.implicits._

val sqlContext = TestHive
val df =
sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDataFrame("key", "value")

override def beforeAll(): Unit = {
// The catalog in HiveContext is a case insensitive one.
(1 to 10).foreach(i => catalog.registerTable(Seq(s"Table$i"), df.logicalPlan))
(1 to 10).foreach(i => catalog.registerTable(Seq("db1", s"db1TempTable$i"), df.logicalPlan))
(1 to 10).foreach {
i => sql(s"CREATE TABLE hivetable$i (key int, value string)")
}
sql("CREATE DATABASE IF NOT EXISTS db1")
(1 to 10).foreach {
i => sql(s"CREATE TABLE db1.db1hivetable$i (key int, value string)")
}
catalog.registerTable(Seq("ListTablesSuiteTable"), df.logicalPlan)
catalog.registerTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"), df.logicalPlan)
sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)")
sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB")
sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)")
}

override def afterAll(): Unit = {
(1 to 10).foreach(i => catalog.unregisterTable(Seq(s"Table$i")))
(1 to 10).foreach(i => catalog.unregisterTable(Seq("db1", s"db1TempTable$i")))
(1 to 10).foreach {
i => sql(s"DROP TABLE IF EXISTS hivetable$i")
}
(1 to 10).foreach {
i => sql(s"DROP TABLE IF EXISTS db1.db1hivetable$i")
}
sql("DROP DATABASE IF EXISTS db1")
catalog.unregisterTable(Seq("ListTablesSuiteTable"))
catalog.unregisterTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"))
sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable")
sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable")
sql("DROP DATABASE IF EXISTS ListTablesSuiteDB")
}

test("get all tables of current database") {
val allTables = tables()
// We are using default DB.
val expectedTables =
(1 to 10).map(i => Row(s"table$i", true)) ++
(1 to 10).map(i => Row(s"hivetable$i", false))
checkAnswer(tables(), expectedTables)
checkAnswer(
allTables.filter("tableName = 'listtablessuitetable'"),
Row("listtablessuitetable", true))
assert(allTables.filter("tableName = 'indblisttablessuitetable'").count() === 0)
checkAnswer(
allTables.filter("tableName = 'hivelisttablessuitetable'"),
Row("hivelisttablessuitetable", false))
assert(allTables.filter("tableName = 'hiveindblisttablessuitetable'").count() === 0)
}

test("getting all tables with a database name") {
val expectedTables =
// We are expecting to see Table1 to Table10 since there is no database associated with them.
(1 to 10).map(i => Row(s"table$i", true)) ++
(1 to 10).map(i => Row(s"db1temptable$i", true)) ++
(1 to 10).map(i => Row(s"db1hivetable$i", false))
checkAnswer(tables("db1"), expectedTables)
}

test("query the returned DataFrame of tables") {
val tableDF = tables()
val schema = StructType(
StructField("tableName", StringType, true) ::
StructField("isTemporary", BooleanType, false) :: Nil)
assert(schema === tableDF.schema)

val allTables = tables("ListTablesSuiteDB")
checkAnswer(
tableDF.filter("NOT isTemporary").select("tableName"),
(1 to 10).map(i => Row(s"hivetable$i"))
)

tableDF.registerTempTable("tables")
allTables.filter("tableName = 'listtablessuitetable'"),
Row("listtablessuitetable", true))
checkAnswer(
sql("SELECT isTemporary, tableName from tables WHERE isTemporary"),
(1 to 10).map(i => Row(true, s"table$i"))
)
allTables.filter("tableName = 'indblisttablessuitetable'"),
Row("indblisttablessuitetable", true))
assert(allTables.filter("tableName = 'hivelisttablessuitetable'").count() === 0)
checkAnswer(
tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
Row("tables", true))
dropTempTable("tables")
allTables.filter("tableName = 'hiveindblisttablessuitetable'"),
Row("hiveindblisttablessuitetable", false))
}
}

0 comments on commit 7793dcb

Please sign in to comment.