Add support for cache table commands (#467)
birdstorm authored and ilovesoup committed Nov 4, 2018
1 parent 7511268 commit 5e5fade
Showing 5 changed files with 36 additions and 21 deletions.
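
For context, a rough usage sketch of what this commit enables (database and table names here are illustrative, not from the patch): Spark SQL's CACHE TABLE and UNCACHE TABLE commands can now target unqualified TiDB tables, which would previously be looked up, and missed, in Spark's own session catalog.

```scala
import org.apache.spark.sql.SparkSession

// Assumed setup: a SparkSession already wired with the TiSpark extensions
// (TiParser + TiResolutionRule) and a TiDB database "tpch" with "lineitem".
val spark = SparkSession.builder().appName("cache-table-demo").getOrCreate()

spark.sql("USE tpch")
spark.sql("CACHE TABLE lineitem")                 // identifier qualified by TiParser
spark.sql("SELECT COUNT(*) FROM lineitem").show() // answered from Spark's cache
spark.sql("UNCACHE TABLE lineitem")               // qualified the same way
```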
9 changes: 8 additions & 1 deletion core/src/main/scala/com/pingcap/tispark/TiDBRelation.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.tispark.{TiHandleRDD, TiRDD}
import org.apache.spark.sql.types.StructType

-class TiDBRelation(session: TiSession, tableRef: TiTableReference, meta: MetaManager)(
+case class TiDBRelation(session: TiSession, tableRef: TiTableReference, meta: MetaManager)(
@transient val sqlContext: SQLContext
) extends BaseRelation {
val table: TiTableInfo = meta
@@ -83,4 +83,11 @@ class TiDBRelation(session: TiSession, tableRef: TiTableReference, meta: MetaManager)(
sqlContext.sparkSession
)
}
+
+  override def equals(obj: Any): Boolean = obj match {
+    case other: TiDBRelation =>
+      this.table.equals(other.table)
+    case _ =>
+      false
+  }
}
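
The equals override above is presumably what makes cache lookups succeed: Spark's CacheManager finds cached data by comparing logical plans, which ultimately compare their BaseRelations, so two independent resolutions of the same TiDB table must be equal. A minimal sketch of the intended semantics, assuming tiSession, meta, and sqlContext are in scope as they are inside TiContext:

```scala
// Hypothetical check: both relations resolve the same table, so the
// overridden equals (comparing the underlying table info) returns true.
val r1 = TiDBRelation(tiSession, TiTableReference("tpch", "lineitem"), meta)(sqlContext)
val r2 = TiDBRelation(tiSession, TiTableReference("tpch", "lineitem"), meta)(sqlContext)
assert(r1 == r2) // without this, CACHE TABLE followed by a query would miss the cache
```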
5 changes: 1 addition & 4 deletions core/src/main/scala/com/pingcap/tispark/TiTableReference.scala
@@ -15,7 +15,4 @@

package com.pingcap.tispark

-class TiTableReference(val databaseName: String,
-                       val tableName: String,
-                       val sizeInBytes: Long = Long.MaxValue)
-    extends Serializable
+case class TiTableReference(databaseName: String, tableName: String, sizeInBytes: Long = Long.MaxValue)
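
Making TiTableReference a case class gives it value-based equals and hashCode for free, so references to the same table compare equal wherever plans embedding them are compared. A one-line illustrative check (names assumed):

```scala
// Case classes compare by field values, not by reference identity.
assert(TiTableReference("tpch", "nation") == TiTableReference("tpch", "nation"))
```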
8 changes: 4 additions & 4 deletions core/src/main/scala/org/apache/spark/sql/TiContext.scala
@@ -146,9 +146,9 @@ class TiContext(val sparkSession: SparkSession) extends Serializable with Logging

@Deprecated
def getDataFrame(dbName: String, tableName: String): DataFrame = {
-    val tiRelation = new TiDBRelation(
+    val tiRelation = TiDBRelation(
tiSession,
-      new TiTableReference(dbName, tableName),
+      TiTableReference(dbName, tableName),
meta
)(sqlContext)
sqlContext.baseRelationToDataFrame(tiRelation)
@@ -191,9 +191,9 @@ class TiContext(val sparkSession: SparkSession) extends Serializable with Logging
sizeInBytes = StatisticsManager.estimateTableSize(table)

if (!sqlContext.sparkSession.catalog.tableExists("`" + tableName + "`")) {
-    val rel: TiDBRelation = new TiDBRelation(
+    val rel: TiDBRelation = TiDBRelation(
tiSession,
-      new TiTableReference(dbName, tableName, sizeInBytes),
+      TiTableReference(dbName, tableName, sizeInBytes),
meta
)(sqlContext)

31 changes: 21 additions & 10 deletions core/src/main/scala/org/apache/spark/sql/extensions/TiParser.scala
@@ -6,7 +6,7 @@ import org.apache.spark.sql.catalyst.parser._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.execution.SparkSqlParser
-import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand}
+import org.apache.spark.sql.execution.command.{CacheTableCommand, CreateViewCommand, ExplainCommand, UncacheTableCommand}
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.{SparkSession, TiContext}

@@ -17,22 +17,23 @@ case class TiParser(getOrCreateTiContext: SparkSession => TiContext)(sparkSession

private lazy val internal = new SparkSqlParser(sparkSession.sqlContext.conf)

+  private def qualifyTableIdentifierInternal(tableIdentifier: TableIdentifier): TableIdentifier =
+    TableIdentifier(
+      tableIdentifier.table,
+      Some(tableIdentifier.database.getOrElse(tiContext.tiCatalog.getCurrentDatabase))
+    )
+
+  private def notTempView(tableIdentifier: TableIdentifier) = tableIdentifier.database.isEmpty && tiContext.sessionCatalog.getTempView(tableIdentifier.table).isEmpty
+
/**
   * Workaround to lead Spark to consider this relation as being on local files,
   * so that Spark does not look it up in its session catalog.
   * See [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveRelations.resolveRelation]] for details.
*/
private val qualifyTableIdentifier: PartialFunction[LogicalPlan, LogicalPlan] = {
case r @ UnresolvedRelation(tableIdentifier)
-        if tableIdentifier.database.isEmpty && tiContext.sessionCatalog
-          .getTempView(tableIdentifier.table)
-          .isEmpty =>
-      r.copy(
-        TableIdentifier(
-          tableIdentifier.table,
-          Some(tableIdentifier.database.getOrElse(tiContext.tiCatalog.getCurrentDatabase))
-        )
-      )
+        if tableIdentifier.database.isEmpty && notTempView(tableIdentifier) =>
+      r.copy(qualifyTableIdentifierInternal(tableIdentifier))
case f @ Filter(condition, _) =>
f.copy(
condition = condition.transform {
@@ -59,6 +60,16 @@
cv.copy(child = child.transform(qualifyTableIdentifier))
case e @ ExplainCommand(logicalPlan, _, _, _) =>
e.copy(logicalPlan = logicalPlan.transform(qualifyTableIdentifier))
+    case c @ CacheTableCommand(tableIdentifier, plan, _)
+        if plan.isEmpty && notTempView(tableIdentifier) =>
+      // Caching an unqualified catalog table.
+      c.copy(qualifyTableIdentifierInternal(tableIdentifier))
+    case c @ CacheTableCommand(_, plan, _) if plan.isDefined =>
+      c.copy(plan = Some(plan.get.transform(qualifyTableIdentifier)))
+    case u @ UncacheTableCommand(tableIdentifier, _)
+        if notTempView(tableIdentifier) =>
+      // Uncaching an unqualified catalog table.
+      u.copy(qualifyTableIdentifierInternal(tableIdentifier))
}

override def parsePlan(sqlText: String): LogicalPlan =
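
The new parser branches funnel through the same qualification helper (the AS SELECT variant instead transforms the relations inside its inner plan): an identifier with no database part that does not name a temp view gains the current TiDB database, so Spark's analyzer will not try to resolve it in its own catalog. A self-contained sketch of that step, with the current database name assumed:

```scala
import org.apache.spark.sql.catalyst.TableIdentifier

// Mirrors qualifyTableIdentifierInternal: keep an explicit database part,
// otherwise fall back to the current database.
def qualify(id: TableIdentifier, currentDb: String): TableIdentifier =
  TableIdentifier(id.table, Some(id.database.getOrElse(currentDb)))

// CACHE TABLE lineitem effectively becomes CACHE TABLE tpch.lineitem.
assert(qualify(TableIdentifier("lineitem"), "tpch") == TableIdentifier("lineitem", Some("tpch")))
```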
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/sql/extensions/TiResolutionRule.scala
@@ -33,9 +33,9 @@ case class TiResolutionRule(getOrCreateTiContext: SparkSession => TiContext)(
StatisticsManager.loadStatisticsInfo(table.get)
}
val sizeInBytes = StatisticsManager.estimateTableSize(table.get)
-    new TiDBRelation(
+    TiDBRelation(
tiSession,
-      new TiTableReference(dbName, tableName, sizeInBytes),
+      TiTableReference(dbName, tableName, sizeInBytes),
meta
)(sqlContext)
}
