Add support for cache table commands #467

Merged 1 commit on Nov 4, 2018
9 changes: 8 additions & 1 deletion core/src/main/scala/com/pingcap/tispark/TiDBRelation.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.tispark.{TiHandleRDD, TiRDD}
 import org.apache.spark.sql.types.StructType

-class TiDBRelation(session: TiSession, tableRef: TiTableReference, meta: MetaManager)(
+case class TiDBRelation(session: TiSession, tableRef: TiTableReference, meta: MetaManager)(
     @transient val sqlContext: SQLContext
 ) extends BaseRelation {
   val table: TiTableInfo = meta
@@ -83,4 +83,11 @@ class TiDBRelation(session: TiSession, tableRef: TiTableReference, meta: MetaManager)(
       sqlContext.sparkSession
     )
   }
+
+  override def equals(obj: Any): Boolean = obj match {
+    case other: TiDBRelation =>
+      this.table.equals(other.table)
+    case _ =>
+      false
+  }
 }
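Why the equals override (together with the switch to a case class) matters for this PR: Spark's CacheManager decides whether a new plan can reuse cached data by comparing logical plans, which ultimately compares the leaf relations they contain. Keying equality on the resolved TiTableInfo means two scans of the same TiDB table count as the same relation. A minimal sketch, not part of the diff; tiSession, meta, sqlContext, and the tpch.lineitem names are assumed for illustration:

    val relA = TiDBRelation(tiSession, TiTableReference("tpch", "lineitem"), meta)(sqlContext)
    val relB = TiDBRelation(tiSession, TiTableReference("tpch", "lineitem", sizeInBytes = 1L), meta)(sqlContext)
    // relA == relB should hold: both resolve to the same TiTableInfo, so a second
    // scan of tpch.lineitem can hit the cache entry created by the first scan.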
core/src/main/scala/com/pingcap/tispark/TiTableReference.scala
@@ -15,7 +15,4 @@

 package com.pingcap.tispark

-class TiTableReference(val databaseName: String,
-                       val tableName: String,
-                       val sizeInBytes: Long = Long.MaxValue)
-    extends Serializable
+case class TiTableReference(databaseName: String, tableName: String, sizeInBytes: Long = Long.MaxValue)
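Making TiTableReference a case class gives it structural equality and hashCode for free; the old plain class compared by reference identity, so two references to the same table were never equal. Illustrative check (hypothetical names):

    TiTableReference("tpch", "nation") == TiTableReference("tpch", "nation")
    // true for the case class; false for the old plain-class version unless
    // both sides were literally the same instance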
8 changes: 4 additions & 4 deletions core/src/main/scala/org/apache/spark/sql/TiContext.scala
@@ -146,9 +146,9 @@ class TiContext(val sparkSession: SparkSession) extends Serializable with Logging

   @Deprecated
   def getDataFrame(dbName: String, tableName: String): DataFrame = {
-    val tiRelation = new TiDBRelation(
+    val tiRelation = TiDBRelation(
       tiSession,
-      new TiTableReference(dbName, tableName),
+      TiTableReference(dbName, tableName),
       meta
     )(sqlContext)
     sqlContext.baseRelationToDataFrame(tiRelation)
@@ -191,9 +191,9 @@ class TiContext(val sparkSession: SparkSession) extends Serializable with Logging
       sizeInBytes = StatisticsManager.estimateTableSize(table)

     if (!sqlContext.sparkSession.catalog.tableExists("`" + tableName + "`")) {
-      val rel: TiDBRelation = new TiDBRelation(
+      val rel: TiDBRelation = TiDBRelation(
         tiSession,
-        new TiTableReference(dbName, tableName, sizeInBytes),
+        TiTableReference(dbName, tableName, sizeInBytes),
         meta
       )(sqlContext)
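A note on the size estimate threaded through these call sites: StatisticsManager.estimateTableSize supplies sizeInBytes, which TiDBRelation presumably surfaces through BaseRelation so Spark's planner can, for example, weigh the table against spark.sql.autoBroadcastJoinThreshold when picking join strategies. A sketch of that intent, reusing names from the diff above and assuming the relation exposes the estimate:

    val sizeInBytes = StatisticsManager.estimateTableSize(table)
    val rel = TiDBRelation(tiSession, TiTableReference(dbName, tableName, sizeInBytes), meta)(sqlContext)
    // rel.sizeInBytes would then reflect the estimate instead of Long.MaxValue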
31 changes: 21 additions & 10 deletions core/src/main/scala/org/apache/spark/sql/extensions/TiParser.scala
@@ -6,7 +6,7 @@ import org.apache.spark.sql.catalyst.parser._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.execution.SparkSqlParser
-import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand}
+import org.apache.spark.sql.execution.command.{CacheTableCommand, CreateViewCommand, ExplainCommand, UncacheTableCommand}
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.sql.{SparkSession, TiContext}
@@ -17,22 +17,23 @@ case class TiParser(getOrCreateTiContext: SparkSession => TiContext)(sparkSession

   private lazy val internal = new SparkSqlParser(sparkSession.sqlContext.conf)

+  private def qualifyTableIdentifierInternal(tableIdentifier: TableIdentifier): TableIdentifier =
+    TableIdentifier(
+      tableIdentifier.table,
+      Some(tableIdentifier.database.getOrElse(tiContext.tiCatalog.getCurrentDatabase))
+    )
+
+  private def notTempView(tableIdentifier: TableIdentifier) = tableIdentifier.database.isEmpty && tiContext.sessionCatalog.getTempView(tableIdentifier.table).isEmpty
+
   /**
    * Workaround (WAR) to lead Spark to treat this relation as one based on local files.
    * Otherwise Spark will look up this relation in its session catalog.
    * See [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveRelations.resolveRelation]] for details.
    */
   private val qualifyTableIdentifier: PartialFunction[LogicalPlan, LogicalPlan] = {
     case r @ UnresolvedRelation(tableIdentifier)
-        if tableIdentifier.database.isEmpty && tiContext.sessionCatalog
-          .getTempView(tableIdentifier.table)
-          .isEmpty =>
-      r.copy(
-        TableIdentifier(
-          tableIdentifier.table,
-          Some(tableIdentifier.database.getOrElse(tiContext.tiCatalog.getCurrentDatabase))
-        )
-      )
+        if tableIdentifier.database.isEmpty && notTempView(tableIdentifier) =>
+      r.copy(qualifyTableIdentifierInternal(tableIdentifier))
     case f @ Filter(condition, _) =>
       f.copy(
         condition = condition.transform {
@@ -59,6 +60,16 @@ case class TiParser(getOrCreateTiContext: SparkSession => TiContext)(sparkSession
       cv.copy(child = child.transform(qualifyTableIdentifier))
     case e @ ExplainCommand(logicalPlan, _, _, _) =>
       e.copy(logicalPlan = logicalPlan.transform(qualifyTableIdentifier))
+    case c @ CacheTableCommand(tableIdentifier, plan, _)
+        if plan.isEmpty && notTempView(tableIdentifier) =>
+      // Caching an unqualified catalog table.
+      c.copy(qualifyTableIdentifierInternal(tableIdentifier))
+    case c @ CacheTableCommand(_, plan, _) if plan.isDefined =>
+      c.copy(plan = Some(plan.get.transform(qualifyTableIdentifier)))
+    case u @ UncacheTableCommand(tableIdentifier, _)
+        if notTempView(tableIdentifier) =>
+      // Uncaching an unqualified catalog table.
+      u.copy(qualifyTableIdentifierInternal(tableIdentifier))
   }

   override def parsePlan(sqlText: String): LogicalPlan =
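Taken together, the parser now qualifies plain CACHE TABLE / UNCACHE TABLE statements with the current database before Spark resolves them, and CACHE TABLE ... AS SELECT gets the same rewrite applied to the embedded query plan. A hedged usage sketch, assuming the TiSpark extensions are registered and tpch is the current database:

    spark.sql("CACHE TABLE lineitem")          // qualified to tpch.lineitem, then cached
    spark.sql("SELECT COUNT(*) FROM lineitem") // answered from the cached relation
    spark.sql("UNCACHE TABLE lineitem")        // qualified the same way before uncaching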
@@ -33,9 +33,9 @@ case class TiResolutionRule(getOrCreateTiContext: SparkSession => TiContext)(
         StatisticsManager.loadStatisticsInfo(table.get)
       }
       val sizeInBytes = StatisticsManager.estimateTableSize(table.get)
-      new TiDBRelation(
+      TiDBRelation(
         tiSession,
-        new TiTableReference(dbName, tableName, sizeInBytes),
+        TiTableReference(dbName, tableName, sizeInBytes),
         meta
       )(sqlContext)
Expand Down