diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 121b6d9e97d19..de4b4b799fabd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.sql.Strategy
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -29,7 +28,8 @@ import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
 import org.apache.spark.sql.execution.command.{DescribeCommand => RunnableDescribeCommand, _}
-import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _}
+import org.apache.spark.sql.execution.datasources.{DescribeCommand => LogicalDescribeCommand, _}
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
 import org.apache.spark.sql.internal.SQLConf
 
@@ -69,8 +69,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right)))
       // Find left semi joins where at least some predicates can be evaluated by matching join keys
       case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
-        joins.LeftSemiJoinHash(
-          leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil
+        Seq(joins.ShuffledHashJoin(
+          leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right)))
       case _ => Nil
     }
   }
@@ -80,8 +80,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
    */
   object CanBroadcast {
     def unapply(plan: LogicalPlan): Option[LogicalPlan] = {
-      if (conf.autoBroadcastJoinThreshold > 0 &&
-          plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold) {
+      if (plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold) {
         Some(plan)
       } else {
         None
@@ -101,10 +100,43 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side
   * of the join will be broadcasted and the other side will be streamed, with no shuffling
   * performed. If both sides of the join are eligible to be broadcasted then the
+  * - Shuffle hash join: if a single partition is small enough to build a hash table.
   * - Sort merge: if the matching join keys are sortable.
   */
  object EquiJoinSelection extends Strategy with PredicateHelper {
+
+    /**
+     * Matches a plan whose single partition should be small enough to build a hash table.
+     */
+    def canBuildHashMap(plan: LogicalPlan): Boolean = {
+      plan.statistics.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions
+    }
+
+    /**
+     * Returns whether plan a is much smaller (3X) than plan b.
+     *
+     * Building a hash map is more expensive than sorting, so we should only build a hash map on
+     * a table that is much smaller than the other one. Since we do not have statistics on the
+     * number of rows, we use the size in bytes as the estimate.
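+     * For example, under this 3X rule a plan estimated at 1GB is "much smaller" than a 3GB
+     * plan, but a 2GB plan is not.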
+     */
+    private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = {
+      a.statistics.sizeInBytes * 3 <= b.statistics.sizeInBytes
+    }
+
+    /**
+     * Returns whether we should use shuffle hash join or not.
+     *
+     * We should only use shuffle hash join when:
+     *   1) any single partition of a small table could fit in memory.
+     *   2) the smaller table is much smaller (3X) than the other one.
+     */
+    private def shouldShuffleHashJoin(left: LogicalPlan, right: LogicalPlan): Boolean = {
+      canBuildHashMap(left) && muchSmaller(left, right) ||
+        canBuildHashMap(right) && muchSmaller(right, left)
+    }
+
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
 
       // --- Inner joins --------------------------------------------------------------------------
 
@@ -117,6 +149,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         Seq(joins.BroadcastHashJoin(
           leftKeys, rightKeys, Inner, BuildLeft, condition, planLater(left), planLater(right)))
 
+      case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
+        if !conf.preferSortMergeJoin && shouldShuffleHashJoin(left, right) ||
+          !RowOrdering.isOrderable(leftKeys) =>
+        val buildSide =
+          if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
+            BuildRight
+          } else {
+            BuildLeft
+          }
+        Seq(joins.ShuffledHashJoin(
+          leftKeys, rightKeys, Inner, buildSide, condition, planLater(left), planLater(right)))
+
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
         if RowOrdering.isOrderable(leftKeys) =>
         joins.SortMergeJoin(
@@ -134,6 +178,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         Seq(joins.BroadcastHashJoin(
           leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right)))
 
+      case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right)
+        if !conf.preferSortMergeJoin && canBuildHashMap(right) && muchSmaller(right, left) ||
+          !RowOrdering.isOrderable(leftKeys) =>
+        Seq(joins.ShuffledHashJoin(
+          leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right)))
+
+      case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right)
+        if !conf.preferSortMergeJoin && canBuildHashMap(left) && muchSmaller(left, right) ||
+          !RowOrdering.isOrderable(leftKeys) =>
+        Seq(joins.ShuffledHashJoin(
+          leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right)))
+
       case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
         if RowOrdering.isOrderable(leftKeys) =>
         joins.SortMergeJoin(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
deleted file mode 100644
index fa549b4d51336..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.LeftSemi
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
-import org.apache.spark.sql.execution.metric.SQLMetrics
-
-/**
- * Build the right table's join keys into a HashedRelation, and iteratively go through the left
- * table, to find if the join keys are in the HashedRelation.
- */
-case class LeftSemiJoinHash(
-    leftKeys: Seq[Expression],
-    rightKeys: Seq[Expression],
-    left: SparkPlan,
-    right: SparkPlan,
-    condition: Option[Expression]) extends BinaryNode with HashJoin {
-
-  override val joinType = LeftSemi
-  override val buildSide = BuildRight
-
-  override private[sql] lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
-
-  override def outputPartitioning: Partitioning = left.outputPartitioning
-
-  override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
-
-  protected override def doExecute(): RDD[InternalRow] = {
-    val numOutputRows = longMetric("numOutputRows")
-
-    right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
-      val hashRelation = HashedRelation(buildIter.map(_.copy()), buildSideKeyGenerator)
-      hashSemiJoin(streamIter, hashRelation, numOutputRows)
-    }
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
new file mode 100644
index 0000000000000..1e8879ae01530
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.sql.execution.metric.SQLMetrics
+
+/**
+ * Performs a hash join of two child relations by first shuffling the data using the join keys.
+ * Supports inner, left semi, left outer, and right outer joins.
+ */
+case class ShuffledHashJoin(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    joinType: JoinType,
+    buildSide: BuildSide,
+    condition: Option[Expression],
+    left: SparkPlan,
+    right: SparkPlan)
+  extends BinaryNode with HashJoin {
+
+  override private[sql] lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
+  override def outputPartitioning: Partitioning = joinType match {
+    case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+    case LeftSemi => left.outputPartitioning
+    case LeftOuter => left.outputPartitioning
+    case RightOuter => right.outputPartitioning
+    case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+    case x =>
+      throw new IllegalArgumentException(s"ShuffledHashJoin should not take $x as the JoinType")
+  }
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val numOutputRows = longMetric("numOutputRows")
+
+    streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
+      // Copy the build-side rows before buffering them in the HashedRelation, since the
+      // underlying iterator may reuse row objects.
+      val hashed = HashedRelation(buildIter.map(_.copy()), buildSideKeyGenerator)
+      val joinedRow = new JoinedRow
+      joinType match {
+        case Inner =>
+          hashJoin(streamIter, hashed, numOutputRows)
+
+        case LeftSemi =>
+          hashSemiJoin(streamIter, hashed, numOutputRows)
+
+        case LeftOuter =>
+          val keyGenerator = streamSideKeyGenerator
+          val resultProj = createResultProjection
+          streamIter.flatMap(currentRow => {
+            val rowKey = keyGenerator(currentRow)
+            joinedRow.withLeft(currentRow)
+            leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj, numOutputRows)
+          })
+
+        case RightOuter =>
+          val keyGenerator = streamSideKeyGenerator
+          val resultProj = createResultProjection
+          streamIter.flatMap(currentRow => {
+            val rowKey = keyGenerator(currentRow)
+            joinedRow.withRight(currentRow)
+            rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj, numOutputRows)
+          })
+
+        case x =>
+          throw new IllegalArgumentException(
+            s"ShuffledHashJoin should not take $x as the JoinType")
+      }
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 807b39ace6266..60bd8ea39a7af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -665,11 +665,11 @@ private[joins] class SortMergeJoinScanner(
  * An iterator for outputting rows in left outer join.
  */
 private class LeftOuterIterator(
-  smjScanner: SortMergeJoinScanner,
-  rightNullRow: InternalRow,
-  boundCondition: InternalRow => Boolean,
-  resultProj: InternalRow => InternalRow,
-  numOutputRows: LongSQLMetric)
+    smjScanner: SortMergeJoinScanner,
+    rightNullRow: InternalRow,
+    boundCondition: InternalRow => Boolean,
+    resultProj: InternalRow => InternalRow,
+    numOutputRows: LongSQLMetric)
   extends OneSideOuterIterator(
     smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) {
 
@@ -681,13 +681,12 @@ private class LeftOuterIterator(
  * An iterator for outputting rows in right outer join.
  */
 private class RightOuterIterator(
-  smjScanner: SortMergeJoinScanner,
-  leftNullRow: InternalRow,
-  boundCondition: InternalRow => Boolean,
-  resultProj: InternalRow => InternalRow,
-  numOutputRows: LongSQLMetric)
-  extends OneSideOuterIterator(
-    smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) {
+    smjScanner: SortMergeJoinScanner,
+    leftNullRow: InternalRow,
+    boundCondition: InternalRow => Boolean,
+    resultProj: InternalRow => InternalRow,
+    numOutputRows: LongSQLMetric)
+  extends OneSideOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) {
 
   protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row)
   protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row)
@@ -710,11 +709,11 @@ private class RightOuterIterator(
  * @param numOutputRows an accumulator metric for the number of rows output
  */
 private abstract class OneSideOuterIterator(
-  smjScanner: SortMergeJoinScanner,
-  bufferedSideNullRow: InternalRow,
-  boundCondition: InternalRow => Boolean,
-  resultProj: InternalRow => InternalRow,
-  numOutputRows: LongSQLMetric) extends RowIterator {
+    smjScanner: SortMergeJoinScanner,
+    bufferedSideNullRow: InternalRow,
+    boundCondition: InternalRow => Boolean,
+    resultProj: InternalRow => InternalRow,
+    numOutputRows: LongSQLMetric) extends RowIterator {
 
   // A row to store the joined result, reused many times
   protected[this] val joinedRow: JoinedRow = new JoinedRow()
@@ -777,14 +776,14 @@ private abstract class OneSideOuterIterator(
 }
 
 private class SortMergeFullOuterJoinScanner(
-  leftKeyGenerator: Projection,
-  rightKeyGenerator: Projection,
-  keyOrdering: Ordering[InternalRow],
-  leftIter: RowIterator,
-  rightIter: RowIterator,
-  boundCondition: InternalRow => Boolean,
-  leftNullRow: InternalRow,
-  rightNullRow: InternalRow) {
+    leftKeyGenerator: Projection,
+    rightKeyGenerator: Projection,
+    keyOrdering: Ordering[InternalRow],
+    leftIter: RowIterator,
+    rightIter: RowIterator,
+    boundCondition: InternalRow => Boolean,
+    leftNullRow: InternalRow,
+    rightNullRow: InternalRow) {
   private[this] val joinedRow: JoinedRow = new JoinedRow()
   private[this] var leftRow: InternalRow = _
   private[this] var leftRowKey: InternalRow = _
@@ -950,10 +949,9 @@ private class SortMergeFullOuterJoinScanner(
 }
 
 private class FullOuterIterator(
-  smjScanner: SortMergeFullOuterJoinScanner,
-  resultProj: InternalRow => InternalRow,
-  numRows: LongSQLMetric
-) extends RowIterator {
+    smjScanner: SortMergeFullOuterJoinScanner,
+    resultProj: InternalRow => InternalRow,
+    numRows: LongSQLMetric) extends RowIterator {
   private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow()
 
   override def advanceNext(): Boolean = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 9aabe2d0abe1c..c308161413e02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -236,6 +236,11 @@ object SQLConf {
     doc = "When true, enable partition pruning for in-memory columnar tables.",
     isPublic = false)
 
+  val PREFER_SORTMERGEJOIN = booleanConf("spark.sql.join.preferSortMergeJoin",
+    defaultValue = Some(true),
+    doc = "When true, prefer sort merge join over shuffle hash join.",
+    isPublic = false)
+
   val AUTO_BROADCASTJOIN_THRESHOLD = intConf("spark.sql.autoBroadcastJoinThreshold",
     defaultValue = Some(10 * 1024 * 1024),
     doc = "Configures the maximum size in bytes for a table that will be broadcast to all worker " +
@@ -586,6 +591,8 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Logging {
 
   def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
 
+  def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN)
+
   def defaultSizeInBytes: Long = getConf(DEFAULT_SIZE_IN_BYTES, autoBroadcastJoinThreshold + 1L)
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 03d6df8c28bef..dfffa4bc8b1c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -45,8 +45,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     val df = sql(sqlString)
     val physical = df.queryExecution.sparkPlan
     val operators = physical.collect {
-      case j: LeftSemiJoinHash => j
       case j: BroadcastHashJoin => j
+      case j: ShuffledHashJoin => j
      case j: CartesianProduct => j
       case j: BroadcastNestedLoopJoin => j
       case j: SortMergeJoin => j
@@ -63,7 +63,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
 
     withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") {
       Seq(
-        ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
+        ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
         ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[BroadcastNestedLoopJoin]),
         ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]),
         ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
@@ -434,7 +434,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
 
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       Seq(
-        ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash])
+        ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[ShuffledHashJoin])
       ).foreach {
         case (query, joinClass) => assertJoin(query, joinClass)
       }
@@ -460,7 +460,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
       Seq(
         ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
-          classOf[LeftSemiJoinHash]),
+          classOf[ShuffledHashJoin]),
         ("SELECT * FROM testData LEFT SEMI JOIN testData2",
           classOf[BroadcastNestedLoopJoin]),
         ("SELECT * FROM testData JOIN testData2",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index d293ff66fbc6b..a16bd77bfe4a7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -247,7 +247,27 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     */
   }
 
-  ignore("rube") {
ignore("shuffle hash join") { + val N = 4 << 20 + sqlContext.setConf("spark.sql.shuffle.partitions", "2") + sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "10000000") + sqlContext.setConf("spark.sql.join.preferSortMergeJoin", "false") + runBenchmark("shuffle hash join", N) { + val df1 = sqlContext.range(N).selectExpr(s"id as k1") + val df2 = sqlContext.range(N / 5).selectExpr(s"id * 3 as k2") + df1.join(df2, col("k1") === col("k2")).count() + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + shuffle hash join codegen=false 1168 / 1902 3.6 278.6 1.0X + shuffle hash join codegen=true 850 / 1196 4.9 202.8 1.4X + */ + } + + ignore("cube") { val N = 5 << 20 runBenchmark("cube", N) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 9cd50abda6f00..e9b65539b0d62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchange, ReuseExchange, ShuffleExchange} -import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin, SortMergeJoin} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -143,7 +143,7 @@ class PlannerSuite extends SharedSQLContext { val sortMergeJoins = planned.collect { case join: SortMergeJoin => join } assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") - assert(sortMergeJoins.isEmpty, "Should not use sort merge join") + assert(sortMergeJoins.isEmpty, "Should not use shuffled hash join") sqlContext.clearCache() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 814e25d10e5cc..cf2681050e000 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -101,6 +101,20 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { EnsureRequirements(sqlContext.sessionState.conf).apply(broadcastJoin) } + def makeShuffledHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan, + side: BuildSide) = { + val shuffledHashJoin = + joins.ShuffledHashJoin(leftKeys, rightKeys, Inner, side, None, leftPlan, rightPlan) + val filteredJoin = + boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) + EnsureRequirements(sqlContext.sessionState.conf).apply(filteredJoin) + } + def makeSortMergeJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], @@ -136,6 +150,30 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } } + test(s"$testName using ShuffledHashJoin (build=left)") { + extractJoinParts().foreach { case 
+        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+          checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
+            makeShuffledHashJoin(
+              leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft),
+            expectedAnswer.map(Row.fromTuple),
+            sortAnswers = true)
+        }
+      }
+    }
+
+    test(s"$testName using ShuffledHashJoin (build=right)") {
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+          checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
+            makeShuffledHashJoin(
+              leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight),
+            expectedAnswer.map(Row.fromTuple),
+            sortAnswers = true)
+        }
+      }
+    }
+
     test(s"$testName using SortMergeJoin") {
       extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 1c8b2ea808b30..4cacb20aa0791 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -76,6 +76,22 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
       ExtractEquiJoinKeys.unapply(join)
     }
 
+    if (joinType != FullOuter) {
+      test(s"$testName using ShuffledHashJoin") {
+        extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+            val buildSide = if (joinType == LeftOuter) BuildRight else BuildLeft
+            checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+              EnsureRequirements(sqlContext.sessionState.conf).apply(
+                ShuffledHashJoin(
+                  leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right)),
+              expectedAnswer.map(Row.fromTuple),
+              sortAnswers = true)
+          }
+        }
+      }
+    }
+
     if (joinType != FullOuter) {
       test(s"$testName using BroadcastHashJoin") {
         val buildSide = joinType match {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
index 5eb6a745239ab..985a96f684541 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -72,12 +72,13 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
       ExtractEquiJoinKeys.unapply(join)
     }
 
-    test(s"$testName using LeftSemiJoinHash") {
+    test(s"$testName using ShuffledHashJoin") {
       extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
             EnsureRequirements(left.sqlContext.sessionState.conf).apply(
-              LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
+              ShuffledHashJoin(
+                leftKeys, rightKeys, LeftSemi, BuildRight, boundCondition, left, right)),
             expectedAnswer.map(Row.fromTuple),
             sortAnswers = true)
         }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 988852a4fc0b9..fa68c1a91dbcf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -263,32 +263,20 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
     )
   }
 
-  test("LeftSemiJoinHash metrics") {
+  test("ShuffledHashJoin metrics") {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
       val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")
       val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value")
       // Assume the execution plan is
-      // ... -> LeftSemiJoinHash(nodeId = 0)
+      // ... -> ShuffledHashJoin(nodeId = 0)
       val df = df1.join(df2, $"key" === $"key2", "leftsemi")
       testSparkPlanMetrics(df, 1, Map(
-        0L -> ("LeftSemiJoinHash", Map(
+        0L -> ("ShuffledHashJoin", Map(
           "number of output rows" -> 2L)))
       )
     }
   }
 
-  test("LeftSemiJoinBNL metrics") {
-    val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")
-    val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value")
-    // Assume the execution plan is
-    // ... -> LeftSemiJoinBNL(nodeId = 0)
-    val df = df1.join(df2, $"key" < $"key2", "leftsemi")
-    testSparkPlanMetrics(df, 2, Map(
-      0L -> ("LeftSemiJoinBNL", Map(
-        "number of output rows" -> 2L)))
-    )
-  }
-
   test("CartesianProduct metrics") {
     val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
     testDataForJoin.registerTempTable("testDataForJoin")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 1468be4670f26..151aacbdd1c44 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -230,7 +230,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton {
       assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off")
 
       val shj = df.queryExecution.sparkPlan.collect {
-        case j: LeftSemiJoinHash => j
+        case j: ShuffledHashJoin => j
       }
       assert(shj.size === 1,
-        "LeftSemiJoinHash should be planned when BroadcastHashJoin is turned off")
+        "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off")
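
Note: the following sketch (not part of the patch) shows how the new planner path can be exercised end to end. The configuration keys are the ones this patch introduces or reads (spark.sql.join.preferSortMergeJoin, spark.sql.autoBroadcastJoinThreshold, spark.sql.shuffle.partitions); the table sizes, row counts, and variable names are illustrative assumptions chosen so that the smaller side satisfies canBuildHashMap and muchSmaller while staying above the broadcast threshold.

import org.apache.spark.sql.functions.col

// Keep the broadcast threshold at its 10MB default and opt out of sort merge join,
// so EquiJoinSelection can pick ShuffledHashJoin for an eligible pair of tables.
sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "10485760")
sqlContext.setConf("spark.sql.shuffle.partitions", "200")
sqlContext.setConf("spark.sql.join.preferSortMergeJoin", "false")

// Assumed sizes: the smaller side (~40MB of longs) is too big to broadcast, fits the
// canBuildHashMap bound (40MB < 10MB * 200 partitions = 2GB), and is at least 3X
// smaller than the larger side (~400MB), satisfying muchSmaller.
val big = sqlContext.range(50000000L).selectExpr("id as k1")
val small = sqlContext.range(5000000L).selectExpr("id * 3 as k2")

// ShuffledHashJoin, not SortMergeJoin, should appear in the physical plan.
println(big.join(small, col("k1") === col("k2")).queryExecution.sparkPlan)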