[SPARK-13917] [SQL] generate broadcast semi join
## What changes were proposed in this pull request?

This PR brings codegen support for broadcast left-semi join.

## How was this patch tested?

Existing tests. Added a benchmark; the results show a 7X speedup.
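
For orientation, a minimal sketch (names and sizes are illustrative, not taken from this commit) of the kind of query this affects: a left semi join whose right side is small enough to broadcast, which the planner now compiles through `BroadcastHashJoin` with whole-stage codegen.

```scala
import org.apache.spark.sql.functions.col

// Hypothetical example: `dim` is small enough to broadcast under the default
// spark.sql.autoBroadcastJoinThreshold, so the semi join should be planned as
// BroadcastHashJoin(..., LeftSemi, BuildRight, ...).
val dim = sqlContext.range(100).select(col("id").as("k"))
sqlContext.range(1 << 20).join(dim, col("id") === col("k"), "leftsemi").count()
```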

Author: Davies Liu <davies@databricks.com>

Closes apache#11742 from davies/gen_semi.
Davies Liu authored and roygao94 committed Mar 22, 2016
1 parent 6c3cae0 commit 2c48c2e
Showing 11 changed files with 124 additions and 139 deletions.
@@ -65,8 +65,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ExtractEquiJoinKeys(
LeftSemi, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
-        joins.BroadcastLeftSemiJoinHash(
-          leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil
+        Seq(joins.BroadcastHashJoin(
+          leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right)))
// Find left semi joins where at least some predicates can be evaluated by matching join keys
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
joins.LeftSemiJoinHash(
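A quick way to observe the strategy change above, sketched after the pattern of the test suites touched later in this commit (assumes `testData`/`testData2` registered as in JoinSuite, and a broadcast threshold that admits `testData2`):

```scala
// Sketch: the same LEFT SEMI JOIN that used to plan a BroadcastLeftSemiJoinHash
// should now collect as a BroadcastHashJoin in the physical plan.
val df = sqlContext.sql("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a")
val bhj = df.queryExecution.sparkPlan.collect {
  case j: org.apache.spark.sql.execution.joins.BroadcastHashJoin => j
}
assert(bhj.size == 1, "expected a broadcast hash semi join")
```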
@@ -23,7 +23,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
-import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -92,6 +92,9 @@ case class BroadcastHashJoin(
rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
}

case LeftSemi =>
hashSemiJoin(streamedIter, hashTable, numOutputRows)

case x =>
throw new IllegalArgumentException(
s"BroadcastHashJoin should not take $x as the JoinType")
@@ -108,11 +111,13 @@
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
-    if (joinType == Inner) {
-      codegenInner(ctx, input)
-    } else {
-      // LeftOuter and RightOuter
-      codegenOuter(ctx, input)
+    joinType match {
+      case Inner => codegenInner(ctx, input)
+      case LeftOuter | RightOuter => codegenOuter(ctx, input)
+      case LeftSemi => codegenSemi(ctx, input)
+      case x =>
+        throw new IllegalArgumentException(
+          s"BroadcastHashJoin should not take $x as the JoinType")
}
}

@@ -322,4 +327,68 @@ case class BroadcastHashJoin(
""".stripMargin
}
}

/**
* Generates the code for left semi join.
*/
private def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
val matched = ctx.freshName("matched")
val buildVars = genBuildSideVars(ctx, matched)
val numOutput = metricTerm(ctx, "numOutputRows")

val checkCondition = if (condition.isDefined) {
val expr = condition.get
// evaluate the variables from the build side that are used by the condition
val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
// filter the output via condition
ctx.currentVars = input ++ buildVars
val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx)
s"""
|$eval
|${ev.code}
|if (${ev.isNull} || !${ev.value}) continue;
""".stripMargin
} else {
""
}

if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
s"""
|// generate join key for stream side
|${keyEv.code}
|// find matches from HashedRelation
|UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
|if ($matched == null) continue;
|$checkCondition
|$numOutput.add(1);
|${consume(ctx, input)}
""".stripMargin
} else {
val matches = ctx.freshName("matches")
val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
val i = ctx.freshName("i")
val size = ctx.freshName("size")
val found = ctx.freshName("found")
s"""
|// generate join key for stream side
|${keyEv.code}
|// find matches from HashedRelation
|$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
|if ($matches == null) continue;
|int $size = $matches.size();
|boolean $found = false;
|for (int $i = 0; $i < $size; $i++) {
| UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
| $checkCondition
| $found = true;
| break;
|}
|if (!$found) continue;
|$numOutput.add(1);
|${consume(ctx, input)}
""".stripMargin
}
}
}
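The two generated shapes in `codegenSemi` amount to a point probe for unique keys (`getValue`) versus a bucket scan that stops at the first row passing the condition (`get` plus the `found`/`break` loop). A plain-Scala sketch of the same logic, with illustrative types rather than the actual `HashedRelation` API:

```scala
// Unique keys: a single lookup decides membership (the getValue branch).
def semiProbeUnique[K, V](table: Map[K, V], key: K)(ok: V => Boolean): Boolean =
  table.get(key).exists(ok)

// Non-unique keys: scan the bucket, short-circuiting like the generated break.
def semiProbe[K, V](table: Map[K, Seq[V]], key: K)(ok: V => Boolean): Boolean =
  table.getOrElse(key, Nil).exists(ok)
```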

This file was deleted.

@@ -46,8 +46,8 @@ trait HashJoin {
left.output ++ right.output.map(_.withNullability(true))
case RightOuter =>
left.output.map(_.withNullability(true)) ++ right.output
-    case FullOuter =>
-      left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+    case LeftSemi =>
+      left.output
case x =>
throw new IllegalArgumentException(s"HashJoin should not take $x as the JoinType")
}
@@ -104,7 +104,7 @@ trait HashJoin {
keyExpr :: Nil
}

-  protected val canJoinKeyFitWithinLong: Boolean = {
+  protected lazy val canJoinKeyFitWithinLong: Boolean = {
val sameTypes = buildKeys.map(_.dataType) == streamedKeys.map(_.dataType)
val key = rewriteKeyExpr(buildKeys)
sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType]
@@ -258,4 +258,21 @@ trait HashJoin {
}
ret.iterator
}

protected def hashSemiJoin(
streamIter: Iterator[InternalRow],
hashedRelation: HashedRelation,
numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
val joinKeys = streamSideKeyGenerator
val joinedRow = new JoinedRow
streamIter.filter { current =>
val key = joinKeys(current)
lazy val rowBuffer = hashedRelation.get(key)
val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists {
(row: InternalRow) => boundCondition(joinedRow(current, row))
})
if (r) numOutputRows += 1
r
}
}
}
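For reference, the contract `hashSemiJoin` implements: each streamed row is kept at most once, exactly when its join key is non-null and some build-side row under that key passes the condition. A self-contained sketch (illustrative types, not part of the commit):

```scala
// Minimal sketch: the Map stands in for the HashedRelation; None models a
// null join key, which never matches (the !key.anyNull check above).
def semiJoin[K, L, R](
    stream: Iterator[L],
    streamKey: L => Option[K],
    build: Map[K, Seq[R]],
    condition: (L, R) => Boolean): Iterator[L] =
  stream.filter { row =>
    streamKey(row).exists(k => build.getOrElse(k, Nil).exists(condition(row, _)))
  }
```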

This file was deleted.

@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.LeftSemi
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -33,7 +34,10 @@ case class LeftSemiJoinHash(
rightKeys: Seq[Expression],
left: SparkPlan,
right: SparkPlan,
-    condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
+    condition: Option[Expression]) extends BinaryNode with HashJoin {
+
+  override val joinType = LeftSemi
+  override val buildSide = BuildRight

override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
@@ -47,7 +51,7 @@
val numOutputRows = longMetric("numOutputRows")

right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
-      val hashRelation = HashedRelation(buildIter.map(_.copy()), rightKeyGenerator)
+      val hashRelation = HashedRelation(buildIter.map(_.copy()), buildSideKeyGenerator)
hashSemiJoin(streamIter, hashRelation, numOutputRows)
}
}
4 changes: 2 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -49,7 +49,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
case j: BroadcastHashJoin => j
case j: CartesianProduct => j
case j: BroadcastNestedLoopJoin => j
-      case j: BroadcastLeftSemiJoinHash => j
+      case j: BroadcastHashJoin => j
case j: SortMergeJoin => j
}

@@ -427,7 +427,7 @@
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") {
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
-        classOf[BroadcastLeftSemiJoinHash])
+        classOf[BroadcastHashJoin])
).foreach {
case (query, joinClass) => assertJoin(query, joinClass)
}
@@ -38,7 +38,7 @@ import org.apache.spark.util.Benchmark
class BenchmarkWholeStageCodegen extends SparkFunSuite {
lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
.set("spark.sql.shuffle.partitions", "1")
.set("spark.sql.autoBroadcastJoinThreshold", "0")
.set("spark.sql.autoBroadcastJoinThreshold", "1")
lazy val sc = SparkContext.getOrCreate(conf)
lazy val sqlContext = SQLContext.getOrCreate(sc)

@@ -200,6 +200,18 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
outer join w long codegen=false 15280 / 16497 6.9 145.7 1.0X
outer join w long codegen=true 769 / 796 136.3 7.3 19.9X
*/

runBenchmark("semi join w long", N) {
sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "leftsemi").count()
}

/**
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
semi join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
semi join w long codegen=false 5804 / 5969 18.1 55.3 1.0X
semi join w long codegen=true 814 / 934 128.8 7.8 7.1X
*/
}

ignore("sort merge join") {
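On the benchmark above: a sketch of how the codegen=false/true rows are presumably produced, assuming `runBenchmark` reruns its body with whole-stage codegen toggled and that `N`, `M` and `dim` are as defined earlier in the benchmark:

```scala
// Illustrative: rerun the same semi join with whole-stage codegen off, then on.
sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "leftsemi").count()
sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "leftsemi").count()
```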
@@ -79,7 +79,7 @@
}

test("unsafe broadcast left semi join updates peak execution memory") {
-    testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi")
+    testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast left semi join", "leftsemi")
}

}
@@ -84,11 +84,12 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
}
}

test(s"$testName using BroadcastLeftSemiJoinHash") {
test(s"$testName using BroadcastHashJoin") {
extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-          BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
+          BroadcastHashJoin(
+            leftKeys, rightKeys, LeftSemi, BuildRight, boundCondition, left, right),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
@@ -212,7 +212,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton {
// Using `sparkPlan` because for relevant patterns in HashJoin to be
// matched, other strategies need to be applied.
var bhj = df.queryExecution.sparkPlan.collect {
-      case j: BroadcastLeftSemiJoinHash => j
+      case j: BroadcastHashJoin => j
}
assert(bhj.size === 1,
s"actual query plans do not contain broadcast join: ${df.queryExecution}")
@@ -225,7 +225,7 @@
sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1")
df = sql(leftSemiJoinQuery)
bhj = df.queryExecution.sparkPlan.collect {
-      case j: BroadcastLeftSemiJoinHash => j
+      case j: BroadcastHashJoin => j
}
assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off")

