diff --git a/core/src/main/scala/com/pingcap/tispark/TiDBRelation.scala b/core/src/main/scala/com/pingcap/tispark/TiDBRelation.scala
index 13e5783029..7d60f25253 100644
--- a/core/src/main/scala/com/pingcap/tispark/TiDBRelation.scala
+++ b/core/src/main/scala/com/pingcap/tispark/TiDBRelation.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.tispark.{TiHandleRDD, TiRDD}
 import org.apache.spark.sql.types.StructType
 
-class TiDBRelation(session: TiSession, tableRef: TiTableReference, meta: MetaManager)(
+case class TiDBRelation(session: TiSession, tableRef: TiTableReference, meta: MetaManager)(
     @transient val sqlContext: SQLContext
 ) extends BaseRelation {
   val table: TiTableInfo = meta
@@ -83,4 +83,15 @@ class TiDBRelation(session: TiSession, tableRef: TiTableReference, meta: MetaMan
       sqlContext.sparkSession
     )
   }
+
+  override def equals(obj: Any): Boolean = obj match {
+    case other: TiDBRelation =>
+      this.table.equals(other.table)
+    case _ =>
+      false
+  }
+
+  // Keep hashCode consistent with the table-only equals above; the
+  // case-class-generated hashCode would also mix in session/tableRef/meta.
+  override def hashCode(): Int = table.hashCode()
 }
diff --git a/core/src/main/scala/com/pingcap/tispark/TiTableReference.scala b/core/src/main/scala/com/pingcap/tispark/TiTableReference.scala
index 41229c2a6e..220df2145d 100644
--- a/core/src/main/scala/com/pingcap/tispark/TiTableReference.scala
+++ b/core/src/main/scala/com/pingcap/tispark/TiTableReference.scala
@@ -15,7 +15,4 @@
 
 package com.pingcap.tispark
 
-class TiTableReference(val databaseName: String,
-                       val tableName: String,
-                       val sizeInBytes: Long = Long.MaxValue)
-    extends Serializable
+case class TiTableReference(databaseName: String, tableName: String, sizeInBytes: Long = Long.MaxValue)
diff --git a/core/src/main/scala/org/apache/spark/sql/TiContext.scala b/core/src/main/scala/org/apache/spark/sql/TiContext.scala
index 19a87a0f96..9af09038f1 100644
--- a/core/src/main/scala/org/apache/spark/sql/TiContext.scala
+++ b/core/src/main/scala/org/apache/spark/sql/TiContext.scala
@@ -146,9 +146,9 @@ class TiContext(val sparkSession: SparkSession) extends Serializable with Loggin
 
   @Deprecated
   def getDataFrame(dbName: String, tableName: String): DataFrame = {
-    val tiRelation = new TiDBRelation(
+    val tiRelation = TiDBRelation(
       tiSession,
-      new TiTableReference(dbName, tableName),
+      TiTableReference(dbName, tableName),
       meta
     )(sqlContext)
     sqlContext.baseRelationToDataFrame(tiRelation)
@@ -191,9 +191,9 @@ class TiContext(val sparkSession: SparkSession) extends Serializable with Loggin
       sizeInBytes = StatisticsManager.estimateTableSize(table)
 
       if (!sqlContext.sparkSession.catalog.tableExists("`" + tableName + "`")) {
-        val rel: TiDBRelation = new TiDBRelation(
+        val rel: TiDBRelation = TiDBRelation(
           tiSession,
-          new TiTableReference(dbName, tableName, sizeInBytes),
+          TiTableReference(dbName, tableName, sizeInBytes),
           meta
         )(sqlContext)
 
diff --git a/core/src/main/scala/org/apache/spark/sql/extensions/TiParser.scala b/core/src/main/scala/org/apache/spark/sql/extensions/TiParser.scala
index 8df27c7318..a827b74573 100644
--- a/core/src/main/scala/org/apache/spark/sql/extensions/TiParser.scala
+++ b/core/src/main/scala/org/apache/spark/sql/extensions/TiParser.scala
@@ -6,7 +6,7 @@ import org.apache.spark.sql.catalyst.parser._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.execution.SparkSqlParser
-import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand}
+import org.apache.spark.sql.execution.command.{CacheTableCommand, CreateViewCommand, ExplainCommand, UncacheTableCommand}
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.sql.{SparkSession, TiContext}
 
@@ -17,6 +17,18 @@ case class TiParser(getOrCreateTiContext: SparkSession => TiContext)(sparkSessio
 
   private lazy val internal = new SparkSqlParser(sparkSession.sqlContext.conf)
 
+  // Qualifies a bare table name with the current TiDB catalog database so
+  // Spark resolves it against TiSpark instead of its own session catalog.
+  private def qualifyTableIdentifierInternal(tableIdentifier: TableIdentifier): TableIdentifier =
+    TableIdentifier(
+      tableIdentifier.table,
+      Some(tableIdentifier.database.getOrElse(tiContext.tiCatalog.getCurrentDatabase))
+    )
+
+  // True when the identifier has no database part AND does not name a temp view.
+  private def notTempView(tableIdentifier: TableIdentifier) =
+    tableIdentifier.database.isEmpty && tiContext.sessionCatalog.getTempView(tableIdentifier.table).isEmpty
+
   /**
    * WAR to lead Spark to consider this relation being on local files.
    * Otherwise Spark will lookup this relation in his session catalog.
@@ -24,15 +36,8 @@ case class TiParser(getOrCreateTiContext: SparkSession => TiContext)(sparkSessio
    */
   private val qualifyTableIdentifier: PartialFunction[LogicalPlan, LogicalPlan] = {
     case r @ UnresolvedRelation(tableIdentifier)
-        if tableIdentifier.database.isEmpty && tiContext.sessionCatalog
-          .getTempView(tableIdentifier.table)
-          .isEmpty =>
-      r.copy(
-        TableIdentifier(
-          tableIdentifier.table,
-          Some(tableIdentifier.database.getOrElse(tiContext.tiCatalog.getCurrentDatabase))
-        )
-      )
+        if notTempView(tableIdentifier) =>
+      r.copy(qualifyTableIdentifierInternal(tableIdentifier))
     case f @ Filter(condition, _) =>
       f.copy(
         condition = condition.transform {
@@ -59,6 +64,16 @@ case class TiParser(getOrCreateTiContext: SparkSession => TiContext)(sparkSessio
       cv.copy(child = child.transform(qualifyTableIdentifier))
     case e @ ExplainCommand(logicalPlan, _, _, _) =>
       e.copy(logicalPlan = logicalPlan.transform(qualifyTableIdentifier))
+    case c @ CacheTableCommand(tableIdentifier, plan, _)
+        if plan.isEmpty && notTempView(tableIdentifier) =>
+      // Caching an unqualified catalog table.
+      c.copy(qualifyTableIdentifierInternal(tableIdentifier))
+    case c @ CacheTableCommand(_, plan, _) if plan.isDefined =>
+      c.copy(plan = Some(plan.get.transform(qualifyTableIdentifier)))
+    case u @ UncacheTableCommand(tableIdentifier, _)
+        if notTempView(tableIdentifier) =>
+      // Uncaching an unqualified catalog table.
+      u.copy(qualifyTableIdentifierInternal(tableIdentifier))
   }
 
   override def parsePlan(sqlText: String): LogicalPlan =
diff --git a/core/src/main/scala/org/apache/spark/sql/extensions/TiResolutionRule.scala b/core/src/main/scala/org/apache/spark/sql/extensions/TiResolutionRule.scala
index 39729a0a39..8d892db6b0 100644
--- a/core/src/main/scala/org/apache/spark/sql/extensions/TiResolutionRule.scala
+++ b/core/src/main/scala/org/apache/spark/sql/extensions/TiResolutionRule.scala
@@ -33,9 +33,9 @@ case class TiResolutionRule(getOrCreateTiContext: SparkSession => TiContext)(
       StatisticsManager.loadStatisticsInfo(table.get)
     }
     val sizeInBytes = StatisticsManager.estimateTableSize(table.get)
-    new TiDBRelation(
+    TiDBRelation(
       tiSession,
-      new TiTableReference(dbName, tableName, sizeInBytes),
+      TiTableReference(dbName, tableName, sizeInBytes),
       meta
     )(sqlContext)
   }