[SPARK-13977] [SQL] Brings back Shuffled hash join
## What changes were proposed in this pull request?

ShuffledHashJoin (including the outer-join variant) was removed in 1.6 in favor of SortMergeJoin, which is more robust and also fast.

ShuffledHashJoin is still useful when: 1) one table is much smaller than the other, so building a hash table on the smaller table is cheaper than sorting the larger one, and 2) any single partition of the small table fits in memory.

This PR brings back ShuffledHashJoin, essentially reverting apache#9645 and fixing the conflicts. It also merges the outer-join and left-semi-join variants into the same class. This PR does not implement full outer join, because that cannot be done efficiently with this operator (it would require building hash tables on both sides).

A simple benchmark (one table 5x smaller than the other) shows that ShuffledHashJoin can be 2x faster than SortMergeJoin.
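
For readers who want to try this out, a minimal sketch of how the new path can be exercised; the `large`/`small` DataFrames, their relative sizes, and the `key` column are hypothetical, not part of this PR:

```scala
// Hypothetical repro sketch. Assumes a sqlContext plus two DataFrames where
// `small` is at most 1/3 the size of `large` and each of its shuffle
// partitions fits in memory.
sqlContext.setConf("spark.sql.join.preferSortMergeJoin", "false")

val joined = large.join(small, large("key") === small("key"))
joined.explain()  // the physical plan should now contain ShuffledHashJoin
```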

## How was this patch tested?

Added new unit tests for ShuffledHashJoin.
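
The tests themselves are not shown in this excerpt; a hedged sketch of the kind of assertion involved (hypothetical suite and table names, not the actual tests) could look like:

```scala
// Hypothetical test sketch; assumes a ScalaTest suite with a sqlContext in scope.
test("shuffled hash join is selected when the build side is small") {
  sqlContext.setConf("spark.sql.join.preferSortMergeJoin", "false")
  val big = sqlContext.range(0, 100000).selectExpr("id AS k", "id AS v1")
  val small = sqlContext.range(0, 100).selectExpr("id AS k", "id AS v2")
  val joined = big.join(small, "k")
  // The executed plan should pick ShuffledHashJoin rather than SortMergeJoin.
  assert(joined.queryExecution.executedPlan.toString.contains("ShuffledHashJoin"))
  assert(joined.count() === 100)
}
```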

Author: Davies Liu <davies@databricks.com>

Closes apache#11788 from davies/shuffle_join.
Davies Liu authored and roygao94 committed Mar 22, 2016
1 parent c255d7f commit 9d29f22
Showing 13 changed files with 277 additions and 118 deletions.
@@ -17,7 +17,6 @@

package org.apache.spark.sql.execution

- import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -29,7 +28,8 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.execution.command.{DescribeCommand => RunnableDescribeCommand, _}
- import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _}
+ import org.apache.spark.sql.execution.datasources.{DescribeCommand => LogicalDescribeCommand, _}
+ import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
import org.apache.spark.sql.internal.SQLConf

@@ -69,8 +69,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right)))
// Find left semi joins where at least some predicates can be evaluated by matching join keys
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
- joins.LeftSemiJoinHash(
- leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil
+ Seq(joins.ShuffledHashJoin(
+ leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right)))
case _ => Nil
}
}
@@ -80,8 +80,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
*/
object CanBroadcast {
def unapply(plan: LogicalPlan): Option[LogicalPlan] = {
- if (conf.autoBroadcastJoinThreshold > 0 &&
- plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold) {
+ if (plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold) {
Some(plan)
} else {
None
@@ -101,10 +100,41 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side
* of the join will be broadcasted and the other side will be streamed, with no shuffling
* performed. If both sides of the join are eligible to be broadcasted then the
+ * - Shuffle hash join: if a single partition is small enough to build a hash table.
+ * - Sort merge: if the matching join keys are sortable.
*/
object EquiJoinSelection extends Strategy with PredicateHelper {

+ /**
+ * Matches a plan whose single partition should be small enough to build a hash table.
+ */
+ def canBuildHashMap(plan: LogicalPlan): Boolean = {
+ plan.statistics.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions
+ }
+
+ /**
+ * Returns whether plan a is much smaller (3X) than plan b.
+ *
+ * The cost of building a hash map is higher than sorting, so we should only build a hash map
+ * on a table that is much smaller than the other one. Since we do not have statistics for the
+ * number of rows, use the size in bytes as the estimate here.
+ */
+ private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = {
+ a.statistics.sizeInBytes * 3 <= b.statistics.sizeInBytes
+ }
+
+ /**
+ * Returns whether we should use shuffle hash join or not.
+ *
+ * We should only use shuffle hash join when:
+ * 1) any single partition of the small table can fit in memory.
+ * 2) the small table is much smaller (3X) than the other one.
+ */
+ private def shouldShuffleHashJoin(left: LogicalPlan, right: LogicalPlan): Boolean = {
+ canBuildHashMap(left) && muchSmaller(left, right) ||
+ canBuildHashMap(right) && muchSmaller(right, left)
+ }

def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {

// --- Inner joins --------------------------------------------------------------------------
@@ -117,6 +147,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
Seq(joins.BroadcastHashJoin(
leftKeys, rightKeys, Inner, BuildLeft, condition, planLater(left), planLater(right)))

+ case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
+ if !conf.preferSortMergeJoin && shouldShuffleHashJoin(left, right) ||
+ !RowOrdering.isOrderable(leftKeys) =>
+ val buildSide =
+ if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
+ BuildRight
+ } else {
+ BuildLeft
+ }
+ Seq(joins.ShuffledHashJoin(
+ leftKeys, rightKeys, Inner, buildSide, condition, planLater(left), planLater(right)))

case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
if RowOrdering.isOrderable(leftKeys) =>
joins.SortMergeJoin(
@@ -134,6 +176,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
Seq(joins.BroadcastHashJoin(
leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right)))

+ case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right)
+ if !conf.preferSortMergeJoin && canBuildHashMap(right) && muchSmaller(right, left) ||
+ !RowOrdering.isOrderable(leftKeys) =>
+ Seq(joins.ShuffledHashJoin(
+ leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right)))
+
+ case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right)
+ if !conf.preferSortMergeJoin && canBuildHashMap(left) && muchSmaller(left, right) ||
+ !RowOrdering.isOrderable(leftKeys) =>
+ Seq(joins.ShuffledHashJoin(
+ leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right)))

case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
if RowOrdering.isOrderable(leftKeys) =>
joins.SortMergeJoin(
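As a concrete illustration of the two heuristics above, here is a hedged back-of-the-envelope check using the 10 MB default broadcast threshold from this PR's SQLConf hunk and the usual default of 200 shuffle partitions (the table sizes are made up):

```scala
// Worked example of canBuildHashMap and muchSmaller; sizes are hypothetical.
val autoBroadcastJoinThreshold = 10L * 1024 * 1024 // 10 MB, the default below
val numShufflePartitions = 200                     // default spark.sql.shuffle.partitions

// canBuildHashMap: the plan must be under threshold * partitions (~2 GB here),
// so that a single shuffled partition is expected to fit in memory.
def canBuildHashMap(sizeInBytes: Long): Boolean =
  sizeInBytes < autoBroadcastJoinThreshold * numShufflePartitions

// muchSmaller: the build side must be at most one third of the other side.
def muchSmaller(a: Long, b: Long): Boolean = a * 3 <= b

val small = 1L << 30 // a 1 GB table
val large = 5L << 30 // a 5 GB table
assert(canBuildHashMap(small) && muchSmaller(small, large)) // shuffled hash join applies
```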

One file was deleted in this commit (its contents are not shown in this excerpt).

@@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.joins

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics

/**
* Performs a hash join of two child relations by first shuffling the data using the join
* keys.
*/
case class ShuffledHashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
buildSide: BuildSide,
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan)
extends BinaryNode with HashJoin {

override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))

override def outputPartitioning: Partitioning = joinType match {
case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
case LeftSemi => left.outputPartitioning
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
case x =>
throw new IllegalArgumentException(s"ShuffledHashJoin should not take $x as the JoinType")
}

override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")

streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
val hashed = HashedRelation(buildIter.map(_.copy()), buildSideKeyGenerator)
val joinedRow = new JoinedRow
joinType match {
case Inner =>
hashJoin(streamIter, hashed, numOutputRows)

case LeftSemi =>
hashSemiJoin(streamIter, hashed, numOutputRows)

case LeftOuter =>
val keyGenerator = streamSideKeyGenerator
val resultProj = createResultProjection
streamIter.flatMap(currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withLeft(currentRow)
leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj, numOutputRows)
})

case RightOuter =>
val keyGenerator = streamSideKeyGenerator
val resultProj = createResultProjection
streamIter.flatMap(currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withRight(currentRow)
rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj, numOutputRows)
})

case x =>
throw new IllegalArgumentException(
s"ShuffledHashJoin should not take $x as the JoinType")
}
}
}
}
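
The heart of `doExecute` above is `RDD.zipPartitions`: once `requiredChildDistribution` has forced both children into the same clustering, partition i of the streamed side lines up with partition i of the build side. A standalone sketch of that pattern on plain RDDs (hypothetical data, with an ad-hoc map standing in for `HashedRelation`; assumes a SparkContext `sc`):

```scala
// Standalone sketch of the zipPartitions join pattern used by doExecute.
// Both RDDs are hash-partitioned the same way, so matching keys land in
// matching partitions; a per-partition map plays the role of HashedRelation.
import org.apache.spark.HashPartitioner

val left = sc.parallelize(Seq(1 -> "a", 2 -> "b", 3 -> "c")).partitionBy(new HashPartitioner(4))
val right = sc.parallelize(Seq(1 -> "x", 3 -> "z")).partitionBy(new HashPartitioner(4))

val joined = left.zipPartitions(right) { (streamIter, buildIter) =>
  val hashed = buildIter.toSeq.groupBy(_._1) // build the per-partition hash table
  streamIter.flatMap { case (k, v) =>
    hashed.getOrElse(k, Nil).map { case (_, w) => (k, v, w) } // inner join semantics
  }
}
// joined.collect().toSet == Set((1, "a", "x"), (3, "c", "z"))
```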
(The SortMergeJoin.scala hunks below are indentation and line-wrapping cleanups; page extraction has stripped leading whitespace, so several removed and added lines read identically.)

@@ -665,11 +665,11 @@ private[joins] class SortMergeJoinScanner(
* An iterator for outputting rows in left outer join.
*/
private class LeftOuterIterator(
- smjScanner: SortMergeJoinScanner,
- rightNullRow: InternalRow,
- boundCondition: InternalRow => Boolean,
- resultProj: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric)
+ smjScanner: SortMergeJoinScanner,
+ rightNullRow: InternalRow,
+ boundCondition: InternalRow => Boolean,
+ resultProj: InternalRow => InternalRow,
+ numOutputRows: LongSQLMetric)
extends OneSideOuterIterator(
smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) {

@@ -681,13 +681,12 @@ private class LeftOuterIterator(
* An iterator for outputting rows in right outer join.
*/
private class RightOuterIterator(
- smjScanner: SortMergeJoinScanner,
- leftNullRow: InternalRow,
- boundCondition: InternalRow => Boolean,
- resultProj: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric)
- extends OneSideOuterIterator(
- smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) {
+ smjScanner: SortMergeJoinScanner,
+ leftNullRow: InternalRow,
+ boundCondition: InternalRow => Boolean,
+ resultProj: InternalRow => InternalRow,
+ numOutputRows: LongSQLMetric)
+ extends OneSideOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) {

protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row)
protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row)
@@ -710,11 +709,11 @@ private class RightOuterIterator(
* @param numOutputRows an accumulator metric for the number of rows output
*/
private abstract class OneSideOuterIterator(
- smjScanner: SortMergeJoinScanner,
- bufferedSideNullRow: InternalRow,
- boundCondition: InternalRow => Boolean,
- resultProj: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric) extends RowIterator {
+ smjScanner: SortMergeJoinScanner,
+ bufferedSideNullRow: InternalRow,
+ boundCondition: InternalRow => Boolean,
+ resultProj: InternalRow => InternalRow,
+ numOutputRows: LongSQLMetric) extends RowIterator {

// A row to store the joined result, reused many times
protected[this] val joinedRow: JoinedRow = new JoinedRow()
@@ -777,14 +776,14 @@
}

private class SortMergeFullOuterJoinScanner(
- leftKeyGenerator: Projection,
- rightKeyGenerator: Projection,
- keyOrdering: Ordering[InternalRow],
- leftIter: RowIterator,
- rightIter: RowIterator,
- boundCondition: InternalRow => Boolean,
- leftNullRow: InternalRow,
- rightNullRow: InternalRow) {
+ leftKeyGenerator: Projection,
+ rightKeyGenerator: Projection,
+ keyOrdering: Ordering[InternalRow],
+ leftIter: RowIterator,
+ rightIter: RowIterator,
+ boundCondition: InternalRow => Boolean,
+ leftNullRow: InternalRow,
+ rightNullRow: InternalRow) {
private[this] val joinedRow: JoinedRow = new JoinedRow()
private[this] var leftRow: InternalRow = _
private[this] var leftRowKey: InternalRow = _
@@ -950,10 +949,9 @@ private class SortMergeFullOuterJoinScanner(
}

private class FullOuterIterator(
- smjScanner: SortMergeFullOuterJoinScanner,
- resultProj: InternalRow => InternalRow,
- numRows: LongSQLMetric
- ) extends RowIterator {
+ smjScanner: SortMergeFullOuterJoinScanner,
+ resultProj: InternalRow => InternalRow,
+ numRows: LongSQLMetric) extends RowIterator {
private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow()

override def advanceNext(): Boolean = {
@@ -236,6 +236,11 @@ object SQLConf {
doc = "When true, enable partition pruning for in-memory columnar tables.",
isPublic = false)

+ val PREFER_SORTMERGEJOIN = booleanConf("spark.sql.join.preferSortMergeJoin",
+ defaultValue = Some(true),
+ doc = "When true, prefer sort merge join over shuffle hash join",
+ isPublic = false)

val AUTO_BROADCASTJOIN_THRESHOLD = intConf("spark.sql.autoBroadcastJoinThreshold",
defaultValue = Some(10 * 1024 * 1024),
doc = "Configures the maximum size in bytes for a table that will be broadcast to all worker " +
@@ -586,6 +591,8 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin

def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)

+ def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN)

def defaultSizeInBytes: Long =
getConf(DEFAULT_SIZE_IN_BYTES, autoBroadcastJoinThreshold + 1L)
