diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 0ce8bf1bea574..3a4cecd327d10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -37,135 +37,6 @@ case object BuildLeft extends BuildSide @DeveloperApi case object BuildRight extends BuildSide -/** - * Constant Value for Binary Join Node - */ -object BinaryJoinNode { - val SINGLE_NULL_LIST = Seq[Row](null) - val EMPTY_NULL_LIST = Seq[Row]() -} - -// TODO If join key was null should be considered as equal? In Hive this is configurable. - -/** - * Output the tuples for the matched (with the same join key) join group, base on the join types, - * Both input iterators should be repeatable. - */ -trait BinaryRepeatableIteratorNode extends BinaryNode { - self: Product => - - val leftNullRow = new GenericRow(left.output.length) - val rightNullRow = new GenericRow(right.output.length) - - val joinedRow = new JoinedRow() - - val boundCondition = InterpretedPredicate( - condition - .map(c => BindReferences.bindReference(c, left.output ++ right.output)) - .getOrElse(Literal(true))) - - def condition: Option[Expression] - def joinType: JoinType - - // TODO we need to rewrite all of the iterators with our own implementation instead of the scala - // iterator for performance / memory usage reason. - - def leftOuterIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]) - : Iterator[Row] = { - leftIter.iterator.flatMap { l => - joinedRow.withLeft(l) - var matched = false - (if (!key.anyNull) rightIter else BinaryJoinNode.EMPTY_NULL_LIST).collect { - case r if (boundCondition(joinedRow.withRight(r))) => { - matched = true - joinedRow.copy - } - } ++ BinaryJoinNode.SINGLE_NULL_LIST.collect { - case dummy if (!matched) => { - joinedRow.withRight(rightNullRow).copy - } - } - } - } - - // TODO need to unit test this, currently it's the dead code, but should be used in SortMergeJoin - def leftSemiIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]) - : Iterator[Row] = { - leftIter.iterator.filter { l => - joinedRow.withLeft(l) - (if (!key.anyNull) rightIter else BinaryJoinNode.EMPTY_NULL_LIST).exists { - case r => (boundCondition(joinedRow.withRight(r))) - } - } - } - - def rightOuterIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]) - : Iterator[Row] = { - rightIter.iterator.flatMap{r => - joinedRow.withRight(r) - var matched = false - (if (!key.anyNull) leftIter else BinaryJoinNode.EMPTY_NULL_LIST).collect { - case l if (boundCondition(joinedRow.withLeft(l))) => { - matched = true - joinedRow.copy - } - } ++ BinaryJoinNode.SINGLE_NULL_LIST.collect { - case dummy if (!matched) => { - joinedRow.withLeft(leftNullRow).copy - } - } - } - } - - def fullOuterIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]) - : Iterator[Row] = { - if (!key.anyNull) { - val rightMatchedSet = scala.collection.mutable.Set[Int]() - leftIter.iterator.flatMap[Row] { l => - joinedRow.withLeft(l) - var matched = false - rightIter.zipWithIndex.collect { - case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { - matched = true - rightMatchedSet.add(idx) - joinedRow.copy - } - } ++ BinaryJoinNode.SINGLE_NULL_LIST.collect { - case dummy if (!matched) => { - joinedRow.withRight(rightNullRow).copy - } - } - } ++ rightIter.zipWithIndex.collect { - case (r, idx) if (!rightMatchedSet.contains(idx)) => { - joinedRow(leftNullRow, r).copy - } - } - } else { - leftIter.iterator.map[Row] { l => - joinedRow(l, rightNullRow).copy - } ++ rightIter.iterator.map[Row] { r => - joinedRow(leftNullRow, r).copy - } - } - } - - // TODO need to unit test this, currently it's the dead code, but should be used in SortMergeJoin - def innerIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]) - : Iterator[Row] = { - // ignore the join filter for inner join, we assume it will done in the select filter - if (!key.anyNull) { - leftIter.iterator.flatMap { l => - joinedRow.withLeft(l) - rightIter.iterator.collect { - case r if boundCondition(joinedRow.withRight(r)) => joinedRow - } - } - } else { - BinaryJoinNode.EMPTY_NULL_LIST.iterator - } - } -} - trait HashJoin { self: SparkPlan => @@ -265,10 +136,18 @@ trait HashJoin { } } +/** + * Constant Value for Binary Join Node + */ +object HashOuterJoin { + val DUMMY_LIST = Seq[Row](null) + val EMPTY_LIST = Seq[Row]() +} + /** * :: DeveloperApi :: - * Performs a hash join of two child relations by shuffling the data using the join keys. - * This operator requires loading both tables into memory. + * Performs a hash based outer join for two child relations by shuffling the data using + * the join keys. This operator requires loading the associated partition in both side into memory. */ @DeveloperApi case class HashOuterJoin( @@ -277,7 +156,7 @@ case class HashOuterJoin( joinType: JoinType, condition: Option[Expression], left: SparkPlan, - right: SparkPlan) extends BinaryRepeatableIteratorNode { + right: SparkPlan) extends BinaryNode { override def outputPartitioning: Partitioning = left.outputPartitioning @@ -286,8 +165,113 @@ case class HashOuterJoin( def output = left.output ++ right.output - private[this] def buildHashTable(iter: Iterator[Row], keyGenerator: Projection) - : Map[Row, ArrayBuffer[Row]] = { + // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala + // iterator for performance purpose. + + private[this] def leftOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val rightNullRow = new GenericRow(right.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + leftIter.iterator.flatMap { l => + joinedRow.withLeft(l) + var matched = false + (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => + matched = true + joinedRow.copy + } else { + Nil + }) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { + // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the + // records in right side. + // If we didn't get any proper row, then append a single row with empty right + joinedRow.withRight(rightNullRow).copy + }) + } + } + + private[this] def rightOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val leftNullRow = new GenericRow(left.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + rightIter.iterator.flatMap { r => + joinedRow.withRight(r) + var matched = false + (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => + matched = true + joinedRow.copy + } else { + Nil + }) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { + // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the + // records in left side. + // If we didn't get any proper row, then append a single row with empty left. + joinedRow.withLeft(leftNullRow).copy + }) + } + } + + private[this] def fullOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val leftNullRow = new GenericRow(left.output.length) + val rightNullRow = new GenericRow(right.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + if (!key.anyNull) { + // Store the positions of records in right, if one of its associated row satisfy + // the join condition. + val rightMatchedSet = scala.collection.mutable.Set[Int]() + leftIter.iterator.flatMap[Row] { l => + joinedRow.withLeft(l) + var matched = false + rightIter.zipWithIndex.collect { + // 1. For those matched (satisfy the join condition) records with both sides filled, + // append them directly + + case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { + matched = true + // if the row satisfy the join condition, add its index into the matched set + rightMatchedSet.add(idx) + joinedRow.copy + } + } ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { + // 2. For those unmatched records in left, append additional records with empty right. + + // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all + // of the records in right side. + // If we didn't get any proper row, then append a single row with empty right. + joinedRow.withRight(rightNullRow).copy + }) + } ++ rightIter.zipWithIndex.collect { + // 3. For those unmatched records in right, append additional records with empty left. + + // Re-visiting the records in right, and append additional row with empty left, if its not + // in the matched set. + case (r, idx) if (!rightMatchedSet.contains(idx)) => { + joinedRow(leftNullRow, r).copy + } + } + } else { + leftIter.iterator.map[Row] { l => + joinedRow(l, rightNullRow).copy + } ++ rightIter.iterator.map[Row] { r => + joinedRow(leftNullRow, r).copy + } + } + } + + private[this] def buildHashTable( + iter: Iterator[Row], keyGenerator: Projection): Map[Row, ArrayBuffer[Row]] = { // TODO: Use Spark's HashMap implementation. val hashTable = scala.collection.mutable.Map[Row, ArrayBuffer[Row]]() while (iter.hasNext) { @@ -300,25 +284,30 @@ case class HashOuterJoin( hashTable.toMap[Row, ArrayBuffer[Row]] } - + def execute() = { left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => // TODO this probably can be replaced by external sort (sort merged join?) + // Build HashMap for current partition in left relation val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - val rightHashTable= buildHashTable(rightIter, newProjection(rightKeys, right.output)) + // Build HashMap for current partition in right relation + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) joinType match { case LeftOuter => leftHashTable.keysIterator.flatMap { key => - leftOuterIterator(key, leftHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST), - rightHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST)) + leftOuterIterator(key, leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), + rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) } case RightOuter => rightHashTable.keysIterator.flatMap { key => - rightOuterIterator(key, leftHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST), - rightHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST)) + rightOuterIterator(key, leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), + rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) } case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => - fullOuterIterator(key, leftHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST), - rightHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST)) + fullOuterIterator(key, + leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), + rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) } case x => throw new Exception(s"Need to add implementation for $x") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 4380daa00569d..037890682f7b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -32,27 +32,6 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Ensures tables are loaded. TestData - var left: UnresolvedRelation = _ - var right: UnresolvedRelation = _ - - override def beforeEach() { - super.beforeEach() - left = UnresolvedRelation(None, "left", None) - right = UnresolvedRelation(None, "right", None) - } - - override def afterEach() { - super.afterEach() - - TestSQLContext.catalog.unregisterTable(None, "left") - TestSQLContext.catalog.unregisterTable(None, "right") - } - - def check(run: () => Unit) { - // TODO hack the logical statistics for cost based optimization. - run() - } - test("equi-join is hash-join") { val x = testData2.as('x) val y = testData2.as('y) @@ -183,131 +162,112 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("left outer join") { - lowerCaseData.registerAsTable("right") - upperCaseData.registerAsTable("left") - def run() { - checkAnswer( - left.join(right, LeftOuter, Some('n === 'N)), - (1, "A", 1, "a") :: - (2, "B", 2, "b") :: - (3, "C", 3, "c") :: - (4, "D", 4, "d") :: - (5, "E", null, null) :: - (6, "F", null, null) :: Nil) - - checkAnswer( - left.join(right, LeftOuter, Some('n === 'N && 'n > 1)), - (1, "A", null, null) :: - (2, "B", 2, "b") :: - (3, "C", 3, "c") :: - (4, "D", 4, "d") :: - (5, "E", null, null) :: - (6, "F", null, null) :: Nil) - - checkAnswer( - left.join(right, LeftOuter, Some('n === 'N && 'N > 1)), - (1, "A", null, null) :: - (2, "B", 2, "b") :: - (3, "C", 3, "c") :: - (4, "D", 4, "d") :: - (5, "E", null, null) :: - (6, "F", null, null) :: Nil) - - checkAnswer( - left.join(right, LeftOuter, Some('n === 'N && 'l > 'L)), - (1, "A", 1, "a") :: - (2, "B", 2, "b") :: - (3, "C", 3, "c") :: - (4, "D", 4, "d") :: - (5, "E", null, null) :: - (6, "F", null, null) :: Nil) - } - - check(run) + checkAnswer( + upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)), + (1, "A", 1, "a") :: + (2, "B", 2, "b") :: + (3, "C", 3, "c") :: + (4, "D", 4, "d") :: + (5, "E", null, null) :: + (6, "F", null, null) :: Nil) + + checkAnswer( + upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)), + (1, "A", null, null) :: + (2, "B", 2, "b") :: + (3, "C", 3, "c") :: + (4, "D", 4, "d") :: + (5, "E", null, null) :: + (6, "F", null, null) :: Nil) + + checkAnswer( + upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)), + (1, "A", null, null) :: + (2, "B", 2, "b") :: + (3, "C", 3, "c") :: + (4, "D", 4, "d") :: + (5, "E", null, null) :: + (6, "F", null, null) :: Nil) + + checkAnswer( + upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)), + (1, "A", 1, "a") :: + (2, "B", 2, "b") :: + (3, "C", 3, "c") :: + (4, "D", 4, "d") :: + (5, "E", null, null) :: + (6, "F", null, null) :: Nil) } test("right outer join") { - lowerCaseData.registerAsTable("left") - upperCaseData.registerAsTable("right") - - val left = UnresolvedRelation(None, "left", None) - val right = UnresolvedRelation(None, "right", None) - - def run() { - checkAnswer( - left.join(right, RightOuter, Some('n === 'N)), - (1, "a", 1, "A") :: - (2, "b", 2, "B") :: - (3, "c", 3, "C") :: - (4, "d", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) - checkAnswer( - left.join(right, RightOuter, Some('n === 'N && 'n > 1)), - (null, null, 1, "A") :: - (2, "b", 2, "B") :: - (3, "c", 3, "C") :: - (4, "d", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) - checkAnswer( - left.join(right, RightOuter, Some('n === 'N && 'N > 1)), - (null, null, 1, "A") :: - (2, "b", 2, "B") :: - (3, "c", 3, "C") :: - (4, "d", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) - checkAnswer( - left.join(right, RightOuter, Some('n === 'N && 'l > 'L)), - (1, "a", 1, "A") :: - (2, "b", 2, "B") :: - (3, "c", 3, "C") :: - (4, "d", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) - } - - check(run) + checkAnswer( + lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)), + (1, "a", 1, "A") :: + (2, "b", 2, "B") :: + (3, "c", 3, "C") :: + (4, "d", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + checkAnswer( + lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'n > 1)), + (null, null, 1, "A") :: + (2, "b", 2, "B") :: + (3, "c", 3, "C") :: + (4, "d", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + checkAnswer( + lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'N > 1)), + (null, null, 1, "A") :: + (2, "b", 2, "B") :: + (3, "c", 3, "C") :: + (4, "d", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + checkAnswer( + lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'l > 'L)), + (1, "a", 1, "A") :: + (2, "b", 2, "B") :: + (3, "c", 3, "C") :: + (4, "d", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) } test("full outer join") { upperCaseData.where('N <= 4).registerAsTable("left") upperCaseData.where('N >= 3).registerAsTable("right") - def run() { - checkAnswer( - left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)), - (1, "A", null, null) :: - (2, "B", null, null) :: - (3, "C", 3, "C") :: - (4, "D", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) - - checkAnswer( - left.join(right, FullOuter, - Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))), - (1, "A", null, null) :: - (2, "B", null, null) :: - (3, "C", null, null) :: - (null, null, 3, "C") :: - (4, "D", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) - - checkAnswer( - left.join(right, FullOuter, - Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))), - (1, "A", null, null) :: - (2, "B", null, null) :: - (3, "C", null, null) :: - (null, null, 3, "C") :: - (4, "D", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) - } + val left = UnresolvedRelation(None, "left", None) + val right = UnresolvedRelation(None, "right", None) - check(run) + checkAnswer( + left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)), + (1, "A", null, null) :: + (2, "B", null, null) :: + (3, "C", 3, "C") :: + (4, "D", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + + checkAnswer( + left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))), + (1, "A", null, null) :: + (2, "B", null, null) :: + (3, "C", null, null) :: + (null, null, 3, "C") :: + (4, "D", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + + checkAnswer( + left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))), + (1, "A", null, null) :: + (2, "B", null, null) :: + (3, "C", null, null) :: + (null, null, 3, "C") :: + (4, "D", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) } }