Skip to content

Commit

Permalink
Address comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
yhuai committed Feb 12, 2015
1 parent aba2e88 commit 572870d
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 24 deletions.
2 changes: 1 addition & 1 deletion python/pyspark/sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ def table(self, tableName):
def tables(self, dbName=None):
"""Returns a DataFrame containing names of table in the given database.
If `dbName` is `None`, the database will be the current database.
If `dbName` is not specified, the current database will be used.
The returned DataFrame has two columns, tableName and isTemporary
(a column with BooleanType indicating if a table is a temporary one or not).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ trait Catalog {
alias: Option[String] = None): LogicalPlan

/**
* Returns names and flags indicating if a table is temporary or not of all tables in the
* database identified by `databaseIdentifier`.
* Returns tuples of (tableName, isTemporary) for all tables in the given database.
* isTemporary is a Boolean value indicates if a table is a temporary or not.
*/
def getTables(databaseIdentifier: Seq[String]): Seq[(String, Boolean)]
def getTables(databaseName: Option[String]): Seq[(String, Boolean)]

def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit

Expand Down Expand Up @@ -66,10 +66,6 @@ trait Catalog {
protected def getDBTable(tableIdent: Seq[String]) : (Option[String], String) = {
(tableIdent.lift(tableIdent.size - 2), tableIdent.last)
}

protected def getDBName(databaseIdentifier: Seq[String]): Option[String] = {
databaseIdentifier.lift(databaseIdentifier.size - 1)
}
}

class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
Expand Down Expand Up @@ -112,7 +108,7 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers)
}

override def getTables(databaseIdentifier: Seq[String]): Seq[(String, Boolean)] = {
override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
tables.map {
case (name, _) => (name, true)
}.toSeq
Expand Down Expand Up @@ -153,20 +149,19 @@ trait OverrideCatalog extends Catalog {
withAlias.getOrElse(super.lookupRelation(tableIdentifier, alias))
}

abstract override def getTables(databaseIdentifier: Seq[String]): Seq[(String, Boolean)] = {
val dbName = getDBName(databaseIdentifier)
abstract override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
val temporaryTables = overrides.filter {
// If a temporary table does not have an associated database, we should return its name.
case ((None, _), _) => true
// If a temporary table does have an associated database, we should return it if the database
// matches the given database name.
case ((db: Some[String], _), _) if db == dbName => true
case ((db: Some[String], _), _) if db == databaseName => true
case _ => false
}.map {
case ((_, tableName), _) => (tableName, true)
}.toSeq

temporaryTables ++ super.getTables(databaseIdentifier)
temporaryTables ++ super.getTables(databaseName)
}

override def registerTable(
Expand Down Expand Up @@ -204,7 +199,7 @@ object EmptyCatalog extends Catalog {
throw new UnsupportedOperationException
}

override def getTables(databaseIdentifier: Seq[String]): Seq[(String, Boolean)] = {
override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
throw new UnsupportedOperationException
}

Expand Down
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* indicating if a table is a temporary one or not).
*/
def tables(databaseName: String): DataFrame = {
createDataFrame(catalog.getTables(Seq(databaseName))).toDataFrame("tableName", "isTemporary")
createDataFrame(catalog.getTables(Some(databaseName))).toDataFrame("tableName", "isTemporary")
}

/**
Expand All @@ -749,7 +749,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* indicating if a table is a temporary one or not).
*/
def tables(): DataFrame = {
createDataFrame(catalog.getTables(Seq.empty[String])).toDataFrame("tableName", "isTemporary")
createDataFrame(catalog.getTables(None)).toDataFrame("tableName", "isTemporary")
}

protected[sql] class SparkPlanner extends SparkStrategies {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll {
}

override def afterAll(): Unit = {
catalog.unregisterAllTables()
(1 to 10).foreach(i => catalog.unregisterTable(Seq(s"table$i")))
}

test("get All Tables") {
test("get all tables") {
checkAnswer(tables(), (1 to 10).map(i => Row(s"table$i", true)))
}

test("getting All Tables with a database name has not impact on returned table names") {
test("getting all Tables with a database name has no impact on returned table names") {
checkAnswer(tables("DB"), (1 to 10).map(i => Row(s"table$i", true)))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
}
}

override def getTables(databaseIdentifier: Seq[String]): Seq[(String, Boolean)] = {
val dbName = getDBName(databaseIdentifier).getOrElse(hive.sessionState.getCurrentDatabase)
override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
val dbName = databaseName.getOrElse(hive.sessionState.getCurrentDatabase)
client.getAllTables(dbName).map(tableName => (tableName, false))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll {
}

override def afterAll(): Unit = {
catalog.unregisterAllTables()
(1 to 10).foreach(i => catalog.unregisterTable(Seq(s"Table$i")))
(1 to 10).foreach(i => catalog.unregisterTable(Seq("db1", s"db1TempTable$i")))
(1 to 10).foreach {
i => sql(s"DROP TABLE IF EXISTS hivetable$i")
}
Expand All @@ -57,15 +58,15 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll {
sql("DROP DATABASE IF EXISTS db1")
}

test("get All Tables of current database") {
test("get all tables of current database") {
// We are using default DB.
val expectedTables =
(1 to 10).map(i => Row(s"table$i", true)) ++
(1 to 10).map(i => Row(s"hivetable$i", false))
checkAnswer(tables(), expectedTables)
}

test("getting All Tables with a database name has not impact on returned table names") {
test("getting all tables with a database name") {
val expectedTables =
// We are expecting to see Table1 to Table10 since there is no database associated with them.
(1 to 10).map(i => Row(s"table$i", true)) ++
Expand Down

0 comments on commit 572870d

Please sign in to comment.