From 55baef712438bd53f9de72fd8bf6d977399610d8 Mon Sep 17 00:00:00 2001
From: Cheng Hao
Date: Thu, 31 Jul 2014 12:30:04 +0800
Subject: [PATCH] Add HashOuterJoin

---
 .../spark/sql/execution/SparkStrategies.scala |   4 +
 .../apache/spark/sql/execution/joins.scala    | 194 ++++++++++++++-
 .../org/apache/spark/sql/JoinSuite.scala      | 226 +++++++++++++++---
 3 files changed, 394 insertions(+), 30 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5f1fe99f75c9d..4342644d98bfb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -94,6 +94,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
             leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
         condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil
 
+      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
+        execution.HashOuterJoin(
+          leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
+
       case _ => Nil
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index 2750ddbce896f..0c173a381e56b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -37,6 +37,135 @@ case object BuildLeft extends BuildSide
 @DeveloperApi
 case object BuildRight extends BuildSide
 
+/**
+ * Constant row lists shared by the binary join node implementations.
+ */
+object BinaryJoinNode {
+  val SINGLE_NULL_LIST = Seq[Row](null)  // exactly one element, which is a null Row
+  val EMPTY_NULL_LIST = Seq[Row]()  // no elements
+}
+
+// TODO Should null join keys be considered equal to each other? In Hive this is configurable.
+
+/**
+ * Mix-in for binary join operators that emit the joined rows of one join-key group at a time.
+ * Both sides are passed as [[Iterable]] since an iterator method may traverse them repeatedly.
+ */
+trait BinaryRepeatableIteratorNode extends BinaryNode {
+  self: Product =>
+
+  val leftNullRow = new GenericRow(left.output.length)
+  val rightNullRow = new GenericRow(right.output.length)
+
+  val joinedRow = new JoinedRow()
+
+  val boundCondition = InterpretedPredicate(
+    condition
+      .map(c => BindReferences.bindReference(c, left.output ++ right.output))
+      .getOrElse(Literal(true)))
+
+  def condition: Option[Expression]
+  def joinType: JoinType
+
+  // TODO we need to rewrite all of the iterators with our own implementation instead of the scala
+  // iterator for performance / memory usage reason.
+
+  /**
+   * Left outer join of one key group: every left row paired with each right row that satisfies
+   * the condition, or with an all-null right row when nothing matches or `key` contains a null.
+   * Matches are built eagerly, so each one is a copy rather than the shared, mutable `joinedRow`.
+   */
+  def leftOuterIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row])
+  : Iterator[Row] = {
+    // A key containing a null never matches anything, so there are no candidates to probe.
+    val candidates = if (!key.anyNull) rightIter else BinaryJoinNode.EMPTY_NULL_LIST
+    leftIter.iterator.flatMap { l =>
+      joinedRow.withLeft(l)
+      val matches = candidates.collect {
+        case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy()
+      }
+      if (matches.isEmpty) Seq(joinedRow.withRight(rightNullRow).copy()) else matches
+    }
+  }
+
+  // Left semi join of one key group: the left rows having at least one matching right row.
+  // TODO need to unit test this, currently it's dead code, but should be used in SortMergeJoin
+  def leftSemiIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row])
+  : Iterator[Row] = {
+    leftIter.iterator.filter { l =>
+      joinedRow.withLeft(l)
+      (if (!key.anyNull) rightIter else BinaryJoinNode.EMPTY_NULL_LIST).exists {
+        case r => (boundCondition(joinedRow.withRight(r)))
+      }
+    }
+  }
+
+  /**
+   * Right outer join of one key group, the mirror image of `leftOuterIterator`: every right row
+   * paired with each left row satisfying the condition, or with an all-null left row otherwise.
+   */
+  def rightOuterIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row])
+  : Iterator[Row] = {
+    val candidates = if (!key.anyNull) leftIter else BinaryJoinNode.EMPTY_NULL_LIST
+    rightIter.iterator.flatMap { r =>
+      joinedRow.withRight(r)
+      // `collect` is strict and `joinedRow` changes on every probe, hence the copies.
+      val matches = candidates.collect {
+        case l if boundCondition(joinedRow.withLeft(l)) => joinedRow.copy()
+      }
+      if (matches.isEmpty) Seq(joinedRow.withLeft(leftNullRow).copy()) else matches
+    }
+  }
+
+  /**
+   * Full outer join of one key group: the left outer result followed by every right row that no
+   * left row matched, padded with nulls on the left. A key containing a null matches nothing, so
+   * in that case each row of either side is emitted once, padded with nulls on the other side.
+   */
+  def fullOuterIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row])
+  : Iterator[Row] = {
+    if (!key.anyNull) {
+      // Positions (within rightIter) of the right rows matched by at least one left row.
+      val rightMatchedSet = scala.collection.mutable.Set[Int]()
+      leftIter.iterator.flatMap[Row] { l =>
+        joinedRow.withLeft(l)
+        val matches = rightIter.zipWithIndex.collect {
+          case (r, idx) if boundCondition(joinedRow.withRight(r)) =>
+            rightMatchedSet.add(idx)
+            joinedRow.copy()
+        }
+        if (matches.isEmpty) Seq(joinedRow.withRight(rightNullRow).copy()) else matches
+      // Iterator.++ takes its operand by name, so rightMatchedSet is complete when this runs.
+      } ++ rightIter.zipWithIndex.collect {
+        case (r, idx) if !rightMatchedSet.contains(idx) => joinedRow(leftNullRow, r).copy()
+      }
+    } else {
+      leftIter.iterator.map[Row] { l =>
+        joinedRow(l, rightNullRow)
+      } ++ rightIter.iterator.map[Row] { r =>
+        joinedRow(leftNullRow, r)
+      }
+    }
+  }
+
+  // Inner join of one key group; a key containing a null produces no rows.
+  // TODO need to unit test this, currently it's dead code, but should be used in SortMergeJoin
+  def innerIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row])
+  : Iterator[Row] = {
+    // The join condition is applied here; emitted rows are the shared, mutable `joinedRow`.
+    if (!key.anyNull) {
+      leftIter.iterator.flatMap { l =>
+        joinedRow.withLeft(l)
+        rightIter.iterator.collect {
+          case r if boundCondition(joinedRow.withRight(r)) => joinedRow
+        }
+      }
+    } else {
+      BinaryJoinNode.EMPTY_NULL_LIST.iterator
+    }
+  }
+}
+
 trait HashJoin {
   self: SparkPlan =>
 
@@ -72,7 +201,7 @@ trait HashJoin {
     while (buildIter.hasNext) {
       currentRow = buildIter.next()
       val rowKey = buildSideKeyGenerator(currentRow)
-      if(!rowKey.anyNull) {
+      if (!rowKey.anyNull) {
         val existingMatchList = hashTable.get(rowKey)
         val matchList = if (existingMatchList == null) {
           val newMatchList = new ArrayBuffer[Row]()
@@ -136,6 +265,67 @@ trait HashJoin {
   }
 }
 
+/**
+ * :: DeveloperApi ::
+ * Performs a hash
join of two child relations by shuffling the data using the join keys.
+ * It serves left, right and full outer joins and requires loading both tables into memory.
+ */
+@DeveloperApi
+case class HashOuterJoin(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    joinType: JoinType,
+    condition: Option[Expression],
+    left: SparkPlan,
+    right: SparkPlan) extends BinaryRepeatableIteratorNode {
+
+  // Rows of the non-preserved side may be all null, so its partitioning cannot be promised.
+  override def outputPartitioning: Partitioning = joinType match {
+    case RightOuter => right.outputPartitioning
+    case FullOuter => org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning(
+      left.outputPartitioning.numPartitions)
+    case _ => left.outputPartitioning
+  }
+
+  override def requiredChildDistribution =
+    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+  def output = left.output ++ right.output
+
+  private[this] def buildHashTable(iter: Iterator[Row], keyGenerator: Projection)
+  : Map[Row, ArrayBuffer[Row]] = {
+    // Groups a copy of every row by its join key. TODO: Use Spark's HashMap implementation.
+    val hashTable = scala.collection.mutable.Map[Row, ArrayBuffer[Row]]()
+    iter.foreach { row =>
+      hashTable.getOrElseUpdate(keyGenerator(row), new ArrayBuffer[Row]()) += row.copy()
+    }
+    hashTable.toMap[Row, ArrayBuffer[Row]]
+  }
+
+  def execute() = {
+    left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
+      // TODO this probably can be replaced by external sort (sort merged join?)
+      // Hash both sides of the partition by join key, then join one key group at a time.
+      val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
+      val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
+      val empty = BinaryJoinNode.EMPTY_NULL_LIST
+      joinType match {
+        case LeftOuter => leftHashTable.iterator.flatMap { case (key, rows) =>
+          leftOuterIterator(key, rows, rightHashTable.getOrElse(key, empty))
+        }
+        case RightOuter => rightHashTable.iterator.flatMap { case (key, rows) =>
+          rightOuterIterator(key, leftHashTable.getOrElse(key, empty), rows)
+        }
+        case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
+          fullOuterIterator(key, leftHashTable.getOrElse(key, empty),
+            rightHashTable.getOrElse(key, empty))
+        }
+        case x => throw new Exception(s"Need to add implementation for $x")
+      }
+    }
+  }
+}
+
 /**
  * :: DeveloperApi ::
  * Performs an inner hash join of two child relations by first shuffling the data using the join
@@ -189,7 +379,7 @@ case class LeftSemiJoinHash(
       while (buildIter.hasNext) {
         currentRow = buildIter.next()
         val rowKey = buildSideKeyGenerator(currentRow)
-        if(!rowKey.anyNull) {
+        if (!rowKey.anyNull) {
           val keyExists = hashSet.contains(rowKey)
           if (!keyExists) {
             hashSet.add(rowKey)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 025c396ef0629..4380daa00569d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -17,15 +17,42 @@
 
 package org.apache.spark.sql
 
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner}
+import
org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner, LeftSemi}
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext._
 
-class JoinSuite extends QueryTest {
+class JoinSuite extends QueryTest with BeforeAndAfterEach {
 
   // Ensures tables are loaded.
   TestData
 
+  var left: UnresolvedRelation = _
+  var right: UnresolvedRelation = _
+
+  override def beforeEach() {
+    super.beforeEach()
+    left = UnresolvedRelation(None, "left", None)
+    right = UnresolvedRelation(None, "right", None)
+  }
+
+  override def afterEach() {
+    super.afterEach()
+
+    TestSQLContext.catalog.unregisterTable(None, "left")
+    TestSQLContext.catalog.unregisterTable(None, "right")
+  }
+
+  def check(run: () => Unit) {
+    // TODO hack the logical statistics for cost based optimization.
+    run()
+  }
+
   test("equi-join is hash-join") {
     val x = testData2.as('x)
     val y = testData2.as('y)
@@ -34,6 +61,56 @@ class JoinSuite extends QueryTest {
     assert(planned.size === 1)
   }
 
+  test("join operator selection") {
+    def assertJoin(sqlString: String, c: Class[_]): Unit = {
+      val rdd = sql(sqlString)
+      val physical = rdd.queryExecution.sparkPlan
+      val operators = physical.collect {
+        case j: ShuffledHashJoin => j
+        case j: HashOuterJoin => j
+        case j: LeftSemiJoinHash => j
+        case j: BroadcastHashJoin => j
+        case j: LeftSemiJoinBNL => j
+        case j: CartesianProduct => j
+        case j: BroadcastNestedLoopJoin => j
+      }
+
+      assert(operators.size === 1)
+      if (operators(0).getClass() != c) {
+        fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical")
+      }
+    }
+
+    // Each pair is a query and the class of the one join operator expected in its physical plan.
+    // Without an equi-join key: CartesianProduct, or LeftSemiJoinBNL for a left semi join.
+    // With one: ShuffledHashJoin (inner), HashOuterJoin (outer) or LeftSemiJoinHash (left semi).
+    val cases1 = Seq(
+      ("SELECT * FROM testData left semi join testData2 ON key = a", classOf[LeftSemiJoinHash]),
+      ("SELECT * FROM testData left semi join testData2", classOf[LeftSemiJoinBNL]),
+      ("SELECT * FROM testData join testData2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData join testData2 where key=2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData left join testData2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData right join testData2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData full outer join testData2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData left join testData2 where key=2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData right join testData2 where key=2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData full outer join testData2 where key=2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData join testData2 where key>a", classOf[CartesianProduct]),
+      ("SELECT * FROM testData full outer join testData2 where key>a", classOf[CartesianProduct]),
+      ("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]),
+      ("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]),
+      ("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin]),
+      ("SELECT * FROM testData left join testData2 ON key = a", classOf[HashOuterJoin]),
+      ("SELECT * FROM testData right join testData2 ON key = a where key=2",
+        classOf[HashOuterJoin]),
+      ("SELECT * FROM testData right join testData2 ON key = a and key=2",
+        classOf[HashOuterJoin]),
+      ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin])
+      // TODO add BroadcastNestedLoopJoin
+    )
+    cases1.foreach { c => assertJoin(c._1, c._2) }
+  }
+
   test("multiple-key equi-join is hash-join") {
     val x = testData2.as('x)
     val y = testData2.as('y)
@@ -106,38 +183,131 @@ class JoinSuite extends QueryTest {
   }
 
   test("left outer join") {
-    checkAnswer(
-      upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)),
-      (1, "A", 1, "a") ::
-      (2, "B", 2, "b") ::
-      (3, "C", 3, "c") ::
-      (4, "D", 4, "d") ::
-      (5, "E", null, null) ::
-      (6, "F", null, null) :: Nil)
+    lowerCaseData.registerAsTable("right")
+    upperCaseData.registerAsTable("left")
+    def run() {
+      checkAnswer(
+        left.join(right, LeftOuter, Some('n === 'N)),
+        (1, "A", 1, "a") ::
+        (2, "B", 2, "b") ::
+        (3, "C", 3, "c") ::
+        (4, "D", 4, "d") ::
+        (5, "E", null, null) ::
+        (6, "F", null, null) :: Nil)
+
+      checkAnswer(
+        left.join(right, LeftOuter, Some('n === 'N && 'n > 1)),
+        (1, "A", null, null) ::
+        (2, "B", 2, "b") ::
+        (3, "C", 3, "c") ::
+        (4, "D", 4, "d") ::
+        (5, "E", null, null) ::
+        (6, "F", null, null) :: Nil)
+
+      checkAnswer(
+        left.join(right, LeftOuter, Some('n === 'N && 'N > 1)),
+        (1, "A", null, null) ::
+        (2, "B", 2, "b") ::
+        (3, "C", 3, "c") ::
+        (4, "D", 4, "d") ::
+        (5, "E", null, null) ::
+        (6, "F", null, null) :: Nil)
+
+      checkAnswer(
+        left.join(right, LeftOuter, Some('n === 'N && 'l > 'L)),
+        (1, "A", 1, "a") ::
+        (2, "B", 2, "b") ::
+        (3, "C", 3, "c") ::
+        (4, "D", 4, "d") ::
+        (5, "E", null, null) ::
+        (6, "F", null, null) :: Nil)
+    }
+
+    check(run)
   }
 
   test("right outer join") {
-    checkAnswer(
-      lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)),
-      (1, "a", 1, "A") ::
-      (2, "b", 2, "B") ::
-      (3, "c", 3, "C") ::
-      (4, "d", 4, "D") ::
-      (null, null, 5, "E") ::
-      (null, null, 6, "F") :: Nil)
+    lowerCaseData.registerAsTable("left")
+    upperCaseData.registerAsTable("right")
+
+    def run() {
+      checkAnswer(
+        left.join(right, RightOuter, Some('n === 'N)),
+        (1, "a", 1, "A") ::
+        (2, "b", 2, "B") ::
+        (3, "c", 3, "C") ::
+        (4, "d", 4, "D") ::
+        (null, null, 5, "E") ::
+        (null, null, 6, "F") :: Nil)
+      // The extra predicate rejects the pair for key 1, so that right row is null-padded.
+      checkAnswer(
+        left.join(right, RightOuter, Some('n === 'N && 'n > 1)),
+        (null, null, 1, "A") ::
+        (2, "b", 2, "B") ::
+        (3, "c", 3, "C") ::
+        (4, "d", 4, "D") ::
+        (null, null, 5, "E") ::
+        (null, null, 6, "F") :: Nil)
+      // Same expectation with the extra predicate written against the other key column.
+      checkAnswer(
+        left.join(right, RightOuter, Some('n === 'N && 'N > 1)),
+        (null, null, 1, "A") ::
+        (2, "b", 2, "B") ::
+        (3, "c", 3, "C") ::
+        (4, "d", 4, "D") ::
+        (null, null, 5, "E") ::
+        (null, null, 6, "F") :: Nil)
+      // Here the extra predicate holds for every key-matched pair, so the result is unchanged.
+      checkAnswer(
+        left.join(right, RightOuter, Some('n === 'N && 'l > 'L)),
+        (1, "a", 1, "A") ::
+        (2, "b", 2, "B") ::
+        (3, "c", 3, "C") ::
+        (4, "d", 4, "D") ::
+        (null, null, 5, "E") ::
+        (null, null, 6, "F") :: Nil)
+    }
+
+    check(run)
   }
 
   test("full outer join") {
-    val left = upperCaseData.where('N <= 4).as('left)
-    val right = upperCaseData.where('N >= 3).as('right)
+    upperCaseData.where('N <= 4).registerAsTable("left")
+    upperCaseData.where('N >= 3).registerAsTable("right")
 
-    checkAnswer(
-      left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)),
-      (1, "A", null, null) ::
-      (2, "B", null, null) ::
-      (3, "C", 3, "C") ::
-      (4, "D", 4, "D") ::
-      (null, null, 5, "E") ::
-      (null, null, 6, "F") :: Nil)
+    def run() {
+      checkAnswer(
+        left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)),
+        (1, "A", null, null) ::
+        (2, "B", null, null) ::
+        (3, "C", 3, "C") ::
+        (4, "D", 4, "D") ::
+        (null, null, 5, "E") ::
+        (null, null, 6, "F") :: Nil)
+
+      checkAnswer(
+        left.join(right, FullOuter,
+          Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))),
+        (1, "A", null, null) ::
+        (2, "B", null, null) ::
+        (3, "C", null, null) ::
+        (null, null, 3, "C") ::
+        (4, "D", 4, "D") ::
+        (null, null, 5, "E") ::
+        (null, null, 6, "F") :: Nil)
+
+      checkAnswer(
+        left.join(right, FullOuter,
+          Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))),
+        (1, "A", null, null) ::
+        (2, "B", null, null) ::
+        (3, "C", null, null) ::
+        (null, null, 3, "C") ::
+        (4, "D", 4, "D") ::
+        (null, null, 5, "E") ::
+        (null, null, 6, "F") :: Nil)
+    }
+
+    check(run)
   }
 }