From 706f814fe39756a32768c91043c920a8f2c4acd7 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Mon, 14 Mar 2016 23:58:57 -0700
Subject: [PATCH] [SPARK-13890][SQL] Remove some internal classes' dependency on SQLContext

## What changes were proposed in this pull request?

In general, it is better for internal classes not to depend on the user-facing class (in this case SQLContext); this reduces coupling between the user-facing API and the internal implementation. This patch removes the SQLContext dependency from some internal classes such as SparkPlanner and SparkOptimizer.

As part of this patch, I also removed the following internal methods from SQLContext:

```
protected[sql] def functionRegistry: FunctionRegistry
protected[sql] def optimizer: Optimizer
protected[sql] def sqlParser: ParserInterface
protected[sql] def planner: SparkPlanner
protected[sql] def continuousQueryManager
protected[sql] def prepareForExecution: RuleExecutor[SparkPlan]
```

## How was this patch tested?

Existing unit/integration tests.

Author: Reynold Xin

Closes #11712 from rxin/sqlContext-planner.
---
 .../apache/spark/sql/DataFrameReader.scala    |  3 ++-
 .../apache/spark/sql/DataFrameWriter.scala    |  6 ++---
 .../scala/org/apache/spark/sql/Dataset.scala  |  6 ++---
 .../spark/sql/ExperimentalMethods.scala       |  2 +-
 .../org/apache/spark/sql/SQLContext.scala     | 23 +++++-----------
 .../spark/sql/execution/QueryExecution.scala  |  6 ++---
 .../spark/sql/execution/SparkOptimizer.scala  | 11 ++++----
 .../spark/sql/execution/SparkPlanner.scala    | 12 ++++++---
 .../spark/sql/execution/SparkStrategies.scala |  4 +--
 .../sql/execution/WholeStageCodegen.scala     |  6 ++---
 .../sql/execution/command/commands.scala      |  7 ++---
 .../exchange/EnsureRequirements.scala         | 12 ++++-----
 .../sql/execution/exchange/Exchange.scala     |  6 ++---
 .../apache/spark/sql/execution/subquery.scala |  8 +++---
 .../org/apache/spark/sql/functions.scala      |  2 +-
 .../spark/sql/internal/SessionState.scala     | 16 +++++++-----
 .../org/apache/spark/sql/JoinSuite.scala      |  4 +--
 .../apache/spark/sql/SQLContextSuite.scala    |  2 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala  |  3 ++-
 .../spark/sql/execution/PlannerSuite.scala    | 26 +++++++++----------
 .../spark/sql/execution/SparkPlanTest.scala   |  2 +-
 .../execution/joins/BroadcastJoinSuite.scala  |  3 ++-
 .../sql/execution/joins/InnerJoinSuite.scala  |  4 +--
 .../sql/execution/joins/OuterJoinSuite.scala  |  2 +-
 .../sql/execution/joins/SemiJoinSuite.scala   |  2 +-
 .../apache/spark/sql/hive/HiveContext.scala   |  6 ++---
 .../spark/sql/hive/HiveSessionState.scala     |  4 +--
 .../sql/sources/BucketedWriteSuite.scala      |  2 +-
 28 files changed, 95 insertions(+), 95 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 52b567ea250b1..76b8d71ac9359 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -394,7 +394,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
    */
   def table(tableName: String): DataFrame = {
     Dataset.newDataFrame(sqlContext,
-      sqlContext.catalog.lookupRelation(sqlContext.sqlParser.parseTableIdentifier(tableName)))
+      sqlContext.catalog.lookupRelation(
+        sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 3349b8421b3e8..de87f4d7c24ef 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -242,7 +242,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { options = extraOptions.toMap, partitionColumns = normalizedParCols.getOrElse(Nil)) - df.sqlContext.continuousQueryManager.startQuery( + df.sqlContext.sessionState.continuousQueryManager.startQuery( extraOptions.getOrElse("queryName", StreamExecution.nextName), df, dataSource.createSink()) } @@ -255,7 +255,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def insertInto(tableName: String): Unit = { - insertInto(df.sqlContext.sqlParser.parseTableIdentifier(tableName)) + insertInto(df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)) } private def insertInto(tableIdent: TableIdentifier): Unit = { @@ -354,7 +354,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def saveAsTable(tableName: String): Unit = { - saveAsTable(df.sqlContext.sqlParser.parseTableIdentifier(tableName)) + saveAsTable(df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)) } private def saveAsTable(tableIdent: TableIdentifier): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b5079cf2763ff..ef239a1e2f324 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -818,7 +818,7 @@ class Dataset[T] private[sql]( @scala.annotation.varargs def selectExpr(exprs: String*): DataFrame = { select(exprs.map { expr => - Column(sqlContext.sqlParser.parseExpression(expr)) + Column(sqlContext.sessionState.sqlParser.parseExpression(expr)) }: _*) } @@ -919,7 +919,7 @@ class Dataset[T] private[sql]( * @since 1.3.0 */ def filter(conditionExpr: String): Dataset[T] = { - filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr))) + filter(Column(sqlContext.sessionState.sqlParser.parseExpression(conditionExpr))) } /** @@ -943,7 +943,7 @@ class Dataset[T] private[sql]( * @since 1.5.0 */ def where(conditionExpr: String): Dataset[T] = { - filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr))) + filter(Column(sqlContext.sessionState.sqlParser.parseExpression(conditionExpr))) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala index deed45d273c33..d7cd84fd246c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.rules.Rule * @since 1.3.0 */ @Experimental -class ExperimentalMethods protected[sql](sqlContext: SQLContext) { +class ExperimentalMethods private[sql]() { /** * Allows extra strategies to be injected into the query planner at runtime. 
Note this API diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 36fe57f78be1d..0f5d1c8cab519 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -121,14 +121,7 @@ class SQLContext private[sql]( protected[sql] lazy val sessionState: SessionState = new SessionState(self) protected[sql] def conf: SQLConf = sessionState.conf protected[sql] def catalog: Catalog = sessionState.catalog - protected[sql] def functionRegistry: FunctionRegistry = sessionState.functionRegistry protected[sql] def analyzer: Analyzer = sessionState.analyzer - protected[sql] def optimizer: Optimizer = sessionState.optimizer - protected[sql] def sqlParser: ParserInterface = sessionState.sqlParser - protected[sql] def planner: SparkPlanner = sessionState.planner - protected[sql] def continuousQueryManager = sessionState.continuousQueryManager - protected[sql] def prepareForExecution: RuleExecutor[SparkPlan] = - sessionState.prepareForExecution /** * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s @@ -197,7 +190,7 @@ class SQLContext private[sql]( */ def getAllConfs: immutable.Map[String, String] = conf.getAllConfs - protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser.parsePlan(sql) + protected[sql] def parseSql(sql: String): LogicalPlan = sessionState.sqlParser.parsePlan(sql) protected[sql] def executeSql(sql: String): QueryExecution = executePlan(parseSql(sql)) @@ -244,7 +237,7 @@ class SQLContext private[sql]( */ @Experimental @transient - val experimental: ExperimentalMethods = new ExperimentalMethods(this) + def experimental: ExperimentalMethods = sessionState.experimentalMethods /** * :: Experimental :: @@ -641,7 +634,7 @@ class SQLContext private[sql]( tableName: String, source: String, options: Map[String, String]): DataFrame = { - val tableIdent = sqlParser.parseTableIdentifier(tableName) + val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) val cmd = CreateTableUsing( tableIdent, @@ -687,7 +680,7 @@ class SQLContext private[sql]( source: String, schema: StructType, options: Map[String, String]): DataFrame = { - val tableIdent = sqlParser.parseTableIdentifier(tableName) + val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) val cmd = CreateTableUsing( tableIdent, @@ -706,7 +699,7 @@ class SQLContext private[sql]( * only during the lifetime of this instance of SQLContext. */ private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = { - catalog.registerTable(sqlParser.parseTableIdentifier(tableName), df.logicalPlan) + catalog.registerTable(sessionState.sqlParser.parseTableIdentifier(tableName), df.logicalPlan) } /** @@ -800,7 +793,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def table(tableName: String): DataFrame = { - table(sqlParser.parseTableIdentifier(tableName)) + table(sessionState.sqlParser.parseTableIdentifier(tableName)) } private def table(tableIdent: TableIdentifier): DataFrame = { @@ -837,9 +830,7 @@ class SQLContext private[sql]( * * @since 2.0.0 */ - def streams: ContinuousQueryManager = { - continuousQueryManager - } + def streams: ContinuousQueryManager = sessionState.continuousQueryManager /** * Returns the names of tables in the current database as an array. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 9e60c1cd6141c..5b4254f741ab1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -45,16 +45,16 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { sqlContext.cacheManager.useCachedData(analyzed) } - lazy val optimizedPlan: LogicalPlan = sqlContext.optimizer.execute(withCachedData) + lazy val optimizedPlan: LogicalPlan = sqlContext.sessionState.optimizer.execute(withCachedData) lazy val sparkPlan: SparkPlan = { SQLContext.setActive(sqlContext) - sqlContext.planner.plan(ReturnAnswer(optimizedPlan)).next() + sqlContext.sessionState.planner.plan(ReturnAnswer(optimizedPlan)).next() } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = sqlContext.prepareForExecution.execute(sparkPlan) + lazy val executedPlan: SparkPlan = sqlContext.sessionState.prepareForExecution.execute(sparkPlan) /** Internal version of the RDD. Avoids copies and has no schema */ lazy val toRdd: RDD[InternalRow] = executedPlan.execute() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index edaf3b36aa52e..cbde777d98415 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.optimizer._ +import org.apache.spark.sql.ExperimentalMethods +import org.apache.spark.sql.catalyst.optimizer.Optimizer -class SparkOptimizer(val sqlContext: SQLContext) - extends Optimizer { - override def batches: Seq[Batch] = super.batches :+ Batch( - "User Provided Optimizers", FixedPoint(100), sqlContext.experimental.extraOptimizations: _*) +class SparkOptimizer(experimentalMethods: ExperimentalMethods) extends Optimizer { + override def batches: Seq[Batch] = super.batches :+ Batch( + "User Provided Optimizers", FixedPoint(100), experimentalMethods.extraOptimizations: _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 292d366e727d3..9da2c74c62fc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -21,14 +21,18 @@ import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy} +import org.apache.spark.sql.internal.SQLConf -class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { - val sparkContext: SparkContext = sqlContext.sparkContext +class SparkPlanner( + val sparkContext: SparkContext, + val conf: SQLConf, + val experimentalMethods: ExperimentalMethods) + extends SparkStrategies { - def numPartitions: Int = sqlContext.conf.numShufflePartitions + def numPartitions: Int = conf.numShufflePartitions def strategies: Seq[Strategy] = - sqlContext.experimental.extraStrategies ++ ( + 
experimentalMethods.extraStrategies ++ ( FileSourceStrategy :: DataSourceStrategy :: DDLStrategy :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6352c48c76ea5..113cf9ae2f222 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -80,8 +80,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object CanBroadcast { def unapply(plan: LogicalPlan): Option[LogicalPlan] = { - if (sqlContext.conf.autoBroadcastJoinThreshold > 0 && - plan.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold) { + if (conf.autoBroadcastJoinThreshold > 0 && + plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold) { Some(plan) } else { None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 8fb4705581a38..81676d3ebb346 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution import org.apache.spark.broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -29,6 +28,7 @@ import org.apache.spark.sql.catalyst.util.toCommentSafeString import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.execution.metric.LongSQLMetricValue +import org.apache.spark.sql.internal.SQLConf /** * An interface for those physical operators that support codegen. @@ -427,7 +427,7 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup /** * Find the chained plans that support codegen, collapse them together as WholeStageCodegen. 
*/ -private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Rule[SparkPlan] { +case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { private def supportCodegen(e: Expression): Boolean = e match { case e: LeafExpression => true @@ -472,7 +472,7 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru } def apply(plan: SparkPlan): SparkPlan = { - if (sqlContext.conf.wholeStageEnabled) { + if (conf.wholeStageEnabled) { insertWholeStageCodegen(plan) } else { plan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 6e36a15a6d033..e711797c1b51a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -358,13 +358,14 @@ case class ShowFunctions(db: Option[String], pattern: Option[String]) extends Ru case Some(p) => try { val regex = java.util.regex.Pattern.compile(p) - sqlContext.functionRegistry.listFunction().filter(regex.matcher(_).matches()).map(Row(_)) + sqlContext.sessionState.functionRegistry.listFunction() + .filter(regex.matcher(_).matches()).map(Row(_)) } catch { // probably will failed in the regex that user provided, then returns empty row. case _: Throwable => Seq.empty[Row] } case None => - sqlContext.functionRegistry.listFunction().map(Row(_)) + sqlContext.sessionState.functionRegistry.listFunction().map(Row(_)) } } @@ -395,7 +396,7 @@ case class DescribeFunction( } override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.functionRegistry.lookupFunction(functionName) match { + sqlContext.sessionState.functionRegistry.lookupFunction(functionName) match { case Some(info) => val result = Row(s"Function: ${info.getName}") :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 709a4246365dd..4864db7f2ac9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.execution.exchange -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ +import org.apache.spark.sql.internal.SQLConf /** * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]] @@ -30,15 +30,15 @@ import org.apache.spark.sql.execution._ * each operator by inserting [[ShuffleExchange]] Operators where required. Also ensure that the * input partition ordering requirements are met. 
*/ -private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] { - private def defaultNumPreShufflePartitions: Int = sqlContext.conf.numShufflePartitions +case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { + private def defaultNumPreShufflePartitions: Int = conf.numShufflePartitions - private def targetPostShuffleInputSize: Long = sqlContext.conf.targetPostShuffleInputSize + private def targetPostShuffleInputSize: Long = conf.targetPostShuffleInputSize - private def adaptiveExecutionEnabled: Boolean = sqlContext.conf.adaptiveExecutionEnabled + private def adaptiveExecutionEnabled: Boolean = conf.adaptiveExecutionEnabled private def minNumPostShufflePartitions: Option[Int] = { - val minNumPostShufflePartitions = sqlContext.conf.minNumPostShufflePartitions + val minNumPostShufflePartitions = conf.minNumPostShufflePartitions if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index 12513e9106707..9eaadea1b11ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -22,11 +22,11 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType /** @@ -64,10 +64,10 @@ case class ReusedExchange(override val output: Seq[Attribute], child: Exchange) * Find out duplicated exchanges in the spark plan, then use the same exchange for all the * references. */ -private[sql] case class ReuseExchange(sqlContext: SQLContext) extends Rule[SparkPlan] { +case class ReuseExchange(conf: SQLConf) extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = { - if (!sqlContext.conf.exchangeReuseEnabled) { + if (!conf.exchangeReuseEnabled) { return plan } // Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index e6d7480b0422c..0d580703f5547 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.DataType /** @@ -62,12 +62,12 @@ case class ScalarSubquery( /** * Convert the subquery from logical plan into executed plan. 
*/ -case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] { +case class PlanSubqueries(sessionState: SessionState) extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = { plan.transformAllExpressions { case subquery: expressions.ScalarSubquery => - val sparkPlan = sqlContext.planner.plan(ReturnAnswer(subquery.query)).next() - val executedPlan = sqlContext.prepareForExecution.execute(sparkPlan) + val sparkPlan = sessionState.planner.plan(ReturnAnswer(subquery.query)).next() + val executedPlan = sessionState.prepareForExecution.execute(sparkPlan) ScalarSubquery(executedPlan, subquery.exprId) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 326c1e5a7cc03..dd4aa9e93ae4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1161,7 +1161,7 @@ object functions { * @group normal_funcs */ def expr(expr: String): Column = { - val parser = SQLContext.getActive().map(_.sqlParser).getOrElse(new CatalystQl()) + val parser = SQLContext.getActive().map(_.sessionState.sqlParser).getOrElse(new CatalystQl()) Column(parser.parseExpression(expr)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 98ada4d58af7e..e6be0ab3bc420 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.internal -import org.apache.spark.sql.{ContinuousQueryManager, SQLContext, UDFRegistration} +import org.apache.spark.sql.{ContinuousQueryManager, ExperimentalMethods, SQLContext, UDFRegistration} import org.apache.spark.sql.catalyst.analysis.{Analyzer, Catalog, FunctionRegistry, SimpleCatalog} import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface @@ -40,6 +40,8 @@ private[sql] class SessionState(ctx: SQLContext) { */ lazy val conf = new SQLConf + lazy val experimentalMethods = new ExperimentalMethods + /** * Internal catalog for managing table and database states. */ @@ -73,7 +75,7 @@ private[sql] class SessionState(ctx: SQLContext) { /** * Logical query plan optimizer. */ - lazy val optimizer: Optimizer = new SparkOptimizer(ctx) + lazy val optimizer: Optimizer = new SparkOptimizer(experimentalMethods) /** * Parser that extracts expressions, plans, table identifiers etc. from SQL texts. @@ -83,7 +85,7 @@ private[sql] class SessionState(ctx: SQLContext) { /** * Planner that converts optimized logical plans to physical plans. 
*/ - lazy val planner: SparkPlanner = new SparkPlanner(ctx) + lazy val planner: SparkPlanner = new SparkPlanner(ctx.sparkContext, conf, experimentalMethods) /** * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal @@ -91,10 +93,10 @@ private[sql] class SessionState(ctx: SQLContext) { */ lazy val prepareForExecution = new RuleExecutor[SparkPlan] { override val batches: Seq[Batch] = Seq( - Batch("Subquery", Once, PlanSubqueries(ctx)), - Batch("Add exchange", Once, EnsureRequirements(ctx)), - Batch("Whole stage codegen", Once, CollapseCodegenStages(ctx)), - Batch("Reuse duplicated exchanges", Once, ReuseExchange(ctx)) + Batch("Subquery", Once, PlanSubqueries(SessionState.this)), + Batch("Add exchange", Once, EnsureRequirements(conf)), + Batch("Whole stage codegen", Once, CollapseCodegenStages(conf)), + Batch("Reuse duplicated exchanges", Once, ReuseExchange(conf)) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 2bd29ef19b649..50647c28402eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -37,7 +37,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = sqlContext.planner.EquiJoinSelection(join) + val planned = sqlContext.sessionState.planner.EquiJoinSelection(join) assert(planned.size === 1) } @@ -139,7 +139,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = sqlContext.planner.EquiJoinSelection(join) + val planned = sqlContext.sessionState.planner.EquiJoinSelection(join) assert(planned.size === 1) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index ec19d97d8cec2..2ad92b52c4ff0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -76,6 +76,6 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext{ test("Catalyst optimization passes are modifiable at runtime") { val sqlContext = SQLContext.getOrCreate(sc) sqlContext.experimental.extraOptimizations = Seq(DummyRule) - assert(sqlContext.optimizer.batches.flatMap(_.rules).contains(DummyRule)) + assert(sqlContext.sessionState.optimizer.batches.flatMap(_.rules).contains(DummyRule)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 98d0008489f4d..836fb1ce853c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -54,7 +54,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("show functions") { def getFunctions(pattern: String): Seq[Row] = { val regex = java.util.regex.Pattern.compile(pattern) - sqlContext.functionRegistry.listFunction().filter(regex.matcher(_).matches()).map(Row(_)) + sqlContext.sessionState.functionRegistry.listFunction() + .filter(regex.matcher(_).matches()).map(Row(_)) } checkAnswer(sql("SHOW functions"), getFunctions(".*")) Seq("^c.*", 
".*e$", "log.*", ".*date.*").foreach { pattern => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index ab0a7ff628962..88fbcda296cac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -37,7 +37,7 @@ class PlannerSuite extends SharedSQLContext { setupTestData() private def testPartialAggregationPlan(query: LogicalPlan): Unit = { - val planner = sqlContext.planner + val planner = sqlContext.sessionState.planner import planner._ val plannedOption = Aggregation(query).headOption val planned = @@ -294,7 +294,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") @@ -314,7 +314,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) } @@ -332,7 +332,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") @@ -352,7 +352,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { fail(s"Exchange should not have been added:\n$outputPlan") @@ -375,7 +375,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(outputOrdering, outputOrdering) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { fail(s"No Exchanges should have been added:\n$outputPlan") @@ -391,7 +391,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq(orderingB)), requiredChildDistribution = Seq(UnspecifiedDistribution) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) 
assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case s: Sort => true }.isEmpty) { fail(s"Sort should have been added:\n$outputPlan") @@ -407,7 +407,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq(orderingA)), requiredChildDistribution = Seq(UnspecifiedDistribution) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case s: Sort => true }.nonEmpty) { fail(s"No sorts should have been added:\n$outputPlan") @@ -424,7 +424,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq(orderingA, orderingB)), requiredChildDistribution = Seq(UnspecifiedDistribution) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case s: Sort => true }.isEmpty) { fail(s"Sort should have been added:\n$outputPlan") @@ -443,7 +443,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq.empty)), None) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.size == 2) { fail(s"Topmost Exchange should have been eliminated:\n$outputPlan") @@ -463,7 +463,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq.empty)), None) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.size == 1) { fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan") @@ -491,7 +491,7 @@ class PlannerSuite extends SharedSQLContext { shuffle, shuffle) - val outputPlan = ReuseExchange(sqlContext).apply(inputPlan) + val outputPlan = ReuseExchange(sqlContext.sessionState.conf).apply(inputPlan) if (outputPlan.collect { case e: ReusedExchange => true }.size != 1) { fail(s"Should re-use the shuffle:\n$outputPlan") } @@ -507,7 +507,7 @@ class PlannerSuite extends SharedSQLContext { ShuffleExchange(finalPartitioning, inputPlan), ShuffleExchange(finalPartitioning, inputPlan)) - val outputPlan2 = ReuseExchange(sqlContext).apply(inputPlan2) + val outputPlan2 = ReuseExchange(sqlContext.sessionState.conf).apply(inputPlan2) if (outputPlan2.collect { case e: ReusedExchange => true }.size != 2) { fail(s"Should re-use the two shuffles:\n$outputPlan2") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index aa928cfc8096f..ed0d3f56e5ca9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -233,7 +233,7 @@ object SparkPlanTest { private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { // A very simple resolver to make writing tests easier. In contrast to the real resolver // this is always case sensitive and does not try to handle scoping or complex type resolution. 
- val resolvedPlan = sqlContext.prepareForExecution.execute( + val resolvedPlan = sqlContext.sessionState.prepareForExecution.execute( outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index a256ee95a153c..6d5b777733f41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -63,7 +63,8 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { // Comparison at the end is for broadcast left semi join val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") val df3 = df1.join(broadcast(df2), joinExpression, joinType) - val plan = EnsureRequirements(sqlContext).apply(df3.queryExecution.sparkPlan) + val plan = + EnsureRequirements(sqlContext.sessionState.conf).apply(df3.queryExecution.sparkPlan) assert(plan.collect { case p: T => p }.size === 1) plan.executeCollect() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 7eb15249ebbd6..eeb44404e9e47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -98,7 +98,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { boundCondition, leftPlan, rightPlan) - EnsureRequirements(sqlContext).apply(broadcastJoin) + EnsureRequirements(sqlContext.sessionState.conf).apply(broadcastJoin) } def makeSortMergeJoin( @@ -109,7 +109,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { rightPlan: SparkPlan) = { val sortMergeJoin = joins.SortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan) - EnsureRequirements(sqlContext).apply(sortMergeJoin) + EnsureRequirements(sqlContext.sessionState.conf).apply(sortMergeJoin) } test(s"$testName using BroadcastHashJoin (build=left)") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 0d1c29fe574a6..45254864309eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -98,7 +98,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(sqlContext).apply( + EnsureRequirements(sqlContext.sessionState.conf).apply( SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), expectedAnswer.map(Row.fromTuple), sortAnswers = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index bc341db5571be..d8c9564f1e4fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -76,7 +76,7 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(left.sqlContext).apply( + EnsureRequirements(left.sqlContext.sessionState.conf).apply( LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)), expectedAnswer.map(Row.fromTuple), sortAnswers = true) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 8244dd4230102..a78b7b0cc4961 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -348,12 +348,12 @@ class HiveContext private[hive]( * @since 1.3.0 */ def refreshTable(tableName: String): Unit = { - val tableIdent = sqlParser.parseTableIdentifier(tableName) + val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) catalog.refreshTable(tableIdent) } protected[hive] def invalidateTable(tableName: String): Unit = { - val tableIdent = sqlParser.parseTableIdentifier(tableName) + val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) catalog.invalidateTable(tableIdent) } @@ -367,7 +367,7 @@ class HiveContext private[hive]( * @since 1.2.0 */ def analyze(tableName: String) { - val tableIdent = sqlParser.parseTableIdentifier(tableName) + val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) val relation = EliminateSubqueryAliases(catalog.lookupRelation(tableIdent)) relation match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index cbb6333336383..d9cd96d66f493 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -74,11 +74,11 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) * Planner that takes into account Hive-specific strategies. 
*/ override lazy val planner: SparkPlanner = { - new SparkPlanner(ctx) with HiveStrategies { + new SparkPlanner(ctx.sparkContext, conf, experimentalMethods) with HiveStrategies { override val hiveContext = ctx override def strategies: Seq[Strategy] = { - ctx.experimental.extraStrategies ++ Seq( + experimentalMethods.extraStrategies ++ Seq( FileSourceStrategy, DataSourceStrategy, HiveCommandStrategy(ctx), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index d77c88fa4b384..33c1bb059e2fe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -69,7 +69,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") def tableDir: File = { - val identifier = hiveContext.sqlParser.parseTableIdentifier("bucketed_table") + val identifier = hiveContext.sessionState.sqlParser.parseTableIdentifier("bucketed_table") new File(URI.create(hiveContext.catalog.hiveDefaultTableFilePath(identifier))) }