Skip to content

Commit

Permalink
[SPARK-3299][SQL]Public API in SQLContext to list tables
Browse files Browse the repository at this point in the history
https://issues.apache.org/jira/browse/SPARK-3299

Author: Yin Huai <yhuai@databricks.com>

Closes apache#4547 from yhuai/tables and squashes the following commits:

6c8f92e [Yin Huai] Add tableNames.
acbb281 [Yin Huai] Update Python test.
7793dcb [Yin Huai] Fix scala test.
572870d [Yin Huai] Address comments.
aba2e88 [Yin Huai] Format.
12c86df [Yin Huai] Add tables() to SQLContext to return a DataFrame containing existing tables.
  • Loading branch information
yhuai authored and marmbrus committed Feb 13, 2015
1 parent c025a46 commit 1d0596a
Show file tree
Hide file tree
Showing 6 changed files with 265 additions and 0 deletions.
34 changes: 34 additions & 0 deletions python/pyspark/sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,40 @@ def table(self, tableName):
"""
return DataFrame(self._ssql_ctx.table(tableName), self)

def tables(self, dbName=None):
    """Returns a DataFrame containing names of tables in the given database.

    If `dbName` is not specified, the current database will be used.

    The returned DataFrame has two columns, tableName and isTemporary
    (a column with BooleanType indicating if a table is a temporary one or not).

    >>> sqlCtx.registerRDDAsTable(df, "table1")
    >>> df2 = sqlCtx.tables()
    >>> df2.filter("tableName = 'table1'").first()
    Row(tableName=u'table1', isTemporary=True)
    """
    # Call the no-argument overload on the wrapped context when no database
    # name was supplied; otherwise forward the name.
    if dbName is None:
        jdf = self._ssql_ctx.tables()
    else:
        jdf = self._ssql_ctx.tables(dbName)
    return DataFrame(jdf, self)

def tableNames(self, dbName=None):
    """Returns a list of names of tables in the database `dbName`.

    If `dbName` is not specified, the current database will be used.

    :param dbName: optional database name; when None the no-argument
        overload of the wrapped context is called.
    :return: a Python list holding the names produced by the wrapped context.

    >>> sqlCtx.registerRDDAsTable(df, "table1")
    >>> "table1" in sqlCtx.tableNames()
    True
    >>> "table1" in sqlCtx.tableNames("db")
    True
    """
    if dbName is None:
        names = self._ssql_ctx.tableNames()
    else:
        names = self._ssql_ctx.tableNames(dbName)
    # list() copies the iterable returned by the wrapped context into a plain
    # Python list; this replaces the identity comprehension used previously.
    return list(names)

def cacheTable(self, tableName):
"""Caches the specified table in-memory.

Delegates to `cacheTable` on the wrapped `_ssql_ctx` object.
"""
self._ssql_ctx.cacheTable(tableName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ trait Catalog {
tableIdentifier: Seq[String],
alias: Option[String] = None): LogicalPlan

/**
 * Returns tuples of (tableName, isTemporary) for all tables in the given database.
 * isTemporary is a Boolean value indicating whether a table is temporary.
 *
 * @param databaseName the database to list; implementations decide how `None` is treated.
 */
def getTables(databaseName: Option[String]): Seq[(String, Boolean)]

def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit

def unregisterTable(tableIdentifier: Seq[String]): Unit
Expand Down Expand Up @@ -101,6 +107,12 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
// properly qualified with this alias.
alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers)
}

/**
 * Lists every entry of `tables` as a (tableName, isTemporary) pair.
 * The `databaseName` argument is not consulted, and each table is reported as temporary.
 */
override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
  val listed = tables.map(entry => (entry._1, true))
  listed.toSeq
}
}

/**
Expand Down Expand Up @@ -137,6 +149,27 @@ trait OverrideCatalog extends Catalog {
withAlias.getOrElse(super.lookupRelation(tableIdentifier, alias))
}

/**
 * Returns the temporary tables held in `overrides` that are visible for `databaseName`,
 * followed by whatever the underlying catalog reports for the same database.
 *
 * A temporary table is included when it has no associated database, or when its database
 * equals the requested one. When the catalog is not case sensitive the requested name is
 * lower-cased before the comparison.
 *
 * @param databaseName the database to list; passed unchanged to the underlying catalog.
 * @return (tableName, isTemporary) pairs; entries coming from `overrides` are flagged `true`.
 */
abstract override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
  // Normalise the requested name only for the comparison against `overrides` keys;
  // the parent catalog still receives the caller's original value.
  val dbName = if (caseSensitive) databaseName else databaseName.map(_.toLowerCase)

  val temporaryTables = overrides.collect {
    // If a temporary table does not have an associated database, we should return its name.
    case ((None, tableName), _) => (tableName, true)
    // If a temporary table does have an associated database, we should return it if the
    // database matches the given database name.
    case ((db @ Some(_), tableName), _) if db == dbName => (tableName, true)
  }.toSeq

  temporaryTables ++ super.getTables(databaseName)
}

override def registerTable(
tableIdentifier: Seq[String],
plan: LogicalPlan): Unit = {
Expand Down Expand Up @@ -172,6 +205,10 @@ object EmptyCatalog extends Catalog {
throw new UnsupportedOperationException
}

// Listing tables is not supported here: every call fails immediately.
override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] =
  throw new UnsupportedOperationException

// Registering tables is not supported here: every call fails immediately.
def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit =
  throw new UnsupportedOperationException
Expand Down
36 changes: 36 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,42 @@ class SQLContext(@transient val sparkContext: SparkContext)
def table(tableName: String): DataFrame = {
  // Resolve the single-part identifier through the catalog, then wrap the plan.
  val relation = catalog.lookupRelation(Seq(tableName))
  DataFrame(this, relation)
}

/**
 * Returns a [[DataFrame]] containing names of existing tables in the current database
 * (no database name is passed to the catalog).
 * The returned DataFrame has two columns, tableName and isTemporary (a column with BooleanType
 * indicating if a table is a temporary one or not).
 */
def tables(): DataFrame = {
  val tableEntries = catalog.getTables(None)
  createDataFrame(tableEntries).toDataFrame("tableName", "isTemporary")
}

/**
 * Returns a [[DataFrame]] containing names of existing tables in the given database.
 * The returned DataFrame has two columns, tableName and isTemporary (a column with BooleanType
 * indicating if a table is a temporary one or not).
 *
 * @param databaseName name of the database whose tables are listed.
 */
def tables(databaseName: String): DataFrame = {
  val tableEntries = catalog.getTables(Some(databaseName))
  createDataFrame(tableEntries).toDataFrame("tableName", "isTemporary")
}

/**
 * Returns an array of names of tables in the current database
 * (no database name is passed to the catalog).
 */
def tableNames(): Array[String] = {
  // Keep only the name from each (tableName, isTemporary) pair.
  val tableEntries = catalog.getTables(None)
  tableEntries.map(_._1).toArray
}

/**
 * Returns an array of names of tables in the given database.
 *
 * @param databaseName name of the database whose tables are listed.
 */
def tableNames(databaseName: String): Array[String] = {
  // Keep only the name from each (tableName, isTemporary) pair.
  val tableEntries = catalog.getTables(Some(databaseName))
  tableEntries.map(_._1).toArray
}

protected[sql] class SparkPlanner extends SparkStrategies {
val sparkContext: SparkContext = self.sparkContext

Expand Down
76 changes: 76 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}

/**
 * Checks the table-listing API (`tables()` / `tables(dbName)`) using a temporary table that is
 * registered before each test and unregistered after it.
 */
class ListTablesSuite extends QueryTest with BeforeAndAfter {

import org.apache.spark.sql.test.TestSQLContext.implicits._

// Ten (key, value) rows; this DataFrame backs the temporary table used by every test.
val df =
sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDataFrame("key", "value")

before {
df.registerTempTable("ListTablesSuiteTable")
}

after {
catalog.unregisterTable(Seq("ListTablesSuiteTable"))
}

test("get all tables") {
// The registered temporary table must be listed and flagged as temporary.
checkAnswer(
tables().filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))

// Once unregistered, it must no longer be listed.
catalog.unregisterTable(Seq("ListTablesSuiteTable"))
assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
}

test("getting all Tables with a database name has no impact on returned table names") {
// Same expectations as the test above, but listing with an explicit database name first.
checkAnswer(
tables("DB").filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))

catalog.unregisterTable(Seq("ListTablesSuiteTable"))
assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
}

test("query the returned DataFrame of tables") {
val tableDF = tables()
// Expected schema: a nullable string column and a non-nullable boolean column.
val schema = StructType(
StructField("tableName", StringType, true) ::
StructField("isTemporary", BooleanType, false) :: Nil)
assert(schema === tableDF.schema)

// The listing is itself registered as a temporary table named "tables" and queried via SQL.
tableDF.registerTempTable("tables")
checkAnswer(
sql("SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"),
Row(true, "ListTablesSuiteTable")
)
// A fresh listing must now include the "tables" temporary table as well.
checkAnswer(
tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
Row("tables", true))
// NOTE(review): this cleanup is skipped if an assertion above fails.
dropTempTable("tables")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
}
}

/**
 * Lists the tables that `client` reports for the requested database, each flagged as
 * non-temporary. When no database is given, the session's current database is used.
 */
override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
  val targetDb = databaseName match {
    case Some(name) => name
    // Only consult the session state when the caller did not name a database.
    case None => hive.sessionState.getCurrentDatabase
  }
  client.getAllTables(targetDb).map(name => (name, false))
}

/**
* Create table with specified database, table name, table description and schema
* @param databaseName Database Name
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.hive

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.Row

/**
 * Checks the table-listing API against `TestHive` with a mix of temporary tables registered
 * in the catalog and tables created through SQL, both with and without an explicit database.
 */
class ListTablesSuite extends QueryTest with BeforeAndAfterAll {

import org.apache.spark.sql.hive.test.TestHive.implicits._

// Ten (key, value) rows; its logical plan backs both temporary tables registered below.
val df =
sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDataFrame("key", "value")

override def beforeAll(): Unit = {
// The catalog in HiveContext is a case insensitive one.
// Two temporary tables: one without a database, one under ListTablesSuiteDB.
catalog.registerTable(Seq("ListTablesSuiteTable"), df.logicalPlan)
catalog.registerTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"), df.logicalPlan)
// Two tables created through SQL: one unqualified, one inside ListTablesSuiteDB.
sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)")
sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB")
sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)")
}

override def afterAll(): Unit = {
// Undo everything beforeAll created; the database is dropped after its table.
catalog.unregisterTable(Seq("ListTablesSuiteTable"))
catalog.unregisterTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"))
sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable")
sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable")
sql("DROP DATABASE IF EXISTS ListTablesSuiteDB")
}

test("get all tables of current database") {
val allTables = tables()
// We are using default DB.
// Expected names are lower case; presumably a consequence of the case-insensitive catalog.
// Listed: the database-less temporary table (temporary) and the unqualified SQL table
// (not temporary). Not listed: either table that lives under ListTablesSuiteDB.
checkAnswer(
allTables.filter("tableName = 'listtablessuitetable'"),
Row("listtablessuitetable", true))
assert(allTables.filter("tableName = 'indblisttablessuitetable'").count() === 0)
checkAnswer(
allTables.filter("tableName = 'hivelisttablessuitetable'"),
Row("hivelisttablessuitetable", false))
assert(allTables.filter("tableName = 'hiveindblisttablessuitetable'").count() === 0)
}

test("getting all tables with a database name") {
val allTables = tables("ListTablesSuiteDB")
// Listed: the database-less temporary table, the temporary table registered under
// ListTablesSuiteDB, and the SQL table created in ListTablesSuiteDB.
// Not listed: the unqualified SQL table.
checkAnswer(
allTables.filter("tableName = 'listtablessuitetable'"),
Row("listtablessuitetable", true))
checkAnswer(
allTables.filter("tableName = 'indblisttablessuitetable'"),
Row("indblisttablessuitetable", true))
assert(allTables.filter("tableName = 'hivelisttablessuitetable'").count() === 0)
checkAnswer(
allTables.filter("tableName = 'hiveindblisttablessuitetable'"),
Row("hiveindblisttablessuitetable", false))
}
}

0 comments on commit 1d0596a

Please sign in to comment.