Fix issues with the community comments
chenghao-intel committed Aug 1, 2014
1 parent 72b1394 commit 65c599e
Showing 2 changed files with 228 additions and 279 deletions.
273 changes: 131 additions & 142 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -37,135 +37,6 @@ case object BuildLeft extends BuildSide
@DeveloperApi
case object BuildRight extends BuildSide

/**
* Constant Value for Binary Join Node
*/
object BinaryJoinNode {
val SINGLE_NULL_LIST = Seq[Row](null)
val EMPTY_NULL_LIST = Seq[Row]()
}

// TODO Should a null join key be considered equal to another null join key? In Hive this is configurable.

/**
* Output the tuples for a matched join group (rows sharing the same join key), based on the
* join type. Both input iterators should be repeatable.
*/
trait BinaryRepeatableIteratorNode extends BinaryNode {
self: Product =>

val leftNullRow = new GenericRow(left.output.length)
val rightNullRow = new GenericRow(right.output.length)

val joinedRow = new JoinedRow()

val boundCondition = InterpretedPredicate(
condition
.map(c => BindReferences.bindReference(c, left.output ++ right.output))
.getOrElse(Literal(true)))

def condition: Option[Expression]
def joinType: JoinType

// TODO we need to rewrite all of these iterators with our own implementation instead of the
// Scala iterators, for performance / memory usage reasons.

def leftOuterIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row])
: Iterator[Row] = {
leftIter.iterator.flatMap { l =>
joinedRow.withLeft(l)
var matched = false
(if (!key.anyNull) rightIter else BinaryJoinNode.EMPTY_NULL_LIST).collect {
case r if (boundCondition(joinedRow.withRight(r))) => {
matched = true
joinedRow.copy
}
} ++ BinaryJoinNode.SINGLE_NULL_LIST.collect {
case dummy if (!matched) => {
joinedRow.withRight(rightNullRow).copy
}
}
}
}

// TODO Needs a unit test; currently this is dead code, but it should be used in SortMergeJoin.
def leftSemiIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row])
: Iterator[Row] = {
leftIter.iterator.filter { l =>
joinedRow.withLeft(l)
(if (!key.anyNull) rightIter else BinaryJoinNode.EMPTY_NULL_LIST).exists {
case r => (boundCondition(joinedRow.withRight(r)))
}
}
}

def rightOuterIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row])
: Iterator[Row] = {
rightIter.iterator.flatMap{r =>
joinedRow.withRight(r)
var matched = false
(if (!key.anyNull) leftIter else BinaryJoinNode.EMPTY_NULL_LIST).collect {
case l if (boundCondition(joinedRow.withLeft(l))) => {
matched = true
joinedRow.copy
}
} ++ BinaryJoinNode.SINGLE_NULL_LIST.collect {
case dummy if (!matched) => {
joinedRow.withLeft(leftNullRow).copy
}
}
}
}

def fullOuterIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row])
: Iterator[Row] = {
if (!key.anyNull) {
val rightMatchedSet = scala.collection.mutable.Set[Int]()
leftIter.iterator.flatMap[Row] { l =>
joinedRow.withLeft(l)
var matched = false
rightIter.zipWithIndex.collect {
case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> {
matched = true
rightMatchedSet.add(idx)
joinedRow.copy
}
} ++ BinaryJoinNode.SINGLE_NULL_LIST.collect {
case dummy if (!matched) => {
joinedRow.withRight(rightNullRow).copy
}
}
} ++ rightIter.zipWithIndex.collect {
case (r, idx) if (!rightMatchedSet.contains(idx)) => {
joinedRow(leftNullRow, r).copy
}
}
} else {
leftIter.iterator.map[Row] { l =>
joinedRow(l, rightNullRow).copy
} ++ rightIter.iterator.map[Row] { r =>
joinedRow(leftNullRow, r).copy
}
}
}

// TODO Needs a unit test; currently this is dead code, but it should be used in SortMergeJoin.
def innerIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row])
: Iterator[Row] = {
// Ignore the join filter for inner joins; we assume it will be done in the select filter.
if (!key.anyNull) {
leftIter.iterator.flatMap { l =>
joinedRow.withLeft(l)
rightIter.iterator.collect {
case r if boundCondition(joinedRow.withRight(r)) => joinedRow
}
}
} else {
BinaryJoinNode.EMPTY_NULL_LIST.iterator
}
}
}

trait HashJoin {
self: SparkPlan =>

@@ -265,10 +136,18 @@ trait HashJoin {
}
}

/**
* Constant values for the hash outer join operator.
*/
object HashOuterJoin {
val DUMMY_LIST = Seq[Row](null)
val EMPTY_LIST = Seq[Row]()
}

/**
* :: DeveloperApi ::
* Performs a hash join of two child relations by shuffling the data using the join keys.
* This operator requires loading both tables into memory.
* Performs a hash-based outer join of two child relations by shuffling the data using
* the join keys. This operator requires loading the associated partitions of both sides into memory.
*/
@DeveloperApi
case class HashOuterJoin(
@@ -277,7 +156,7 @@ case class HashOuterJoin(
joinType: JoinType,
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryRepeatableIteratorNode {
right: SparkPlan) extends BinaryNode {

override def outputPartitioning: Partitioning = left.outputPartitioning

@@ -286,8 +165,113 @@ case class HashOuterJoin(

def output = left.output ++ right.output

private[this] def buildHashTable(iter: Iterator[Row], keyGenerator: Projection)
: Map[Row, ArrayBuffer[Row]] = {
// TODO we need to rewrite all of these iterators with our own implementation instead of the
// Scala iterators, for performance reasons.

private[this] def leftOuterIterator(
key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
val joinedRow = new JoinedRow()
val rightNullRow = new GenericRow(right.output.length)
val boundCondition =
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)

leftIter.iterator.flatMap { l =>
joinedRow.withLeft(l)
var matched = false
(if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) =>
matched = true
joinedRow.copy
} else {
Nil
}) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => {
// HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add an additional row,
// as we don't know whether we need to append it until we have finished iterating over all
// of the records on the right side.
// If no right row matched, append a single row with a null-filled right side
// (a standalone sketch of this trick follows this method).
joinedRow.withRight(rightNullRow).copy
})
}
}
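
The DUMMY_LIST trick used above is easiest to see in isolation. Below is a minimal, standalone Scala sketch of the same pattern on plain collections rather than Spark's Row/JoinedRow types; the object name, sample values, and condition function are illustrative only and not part of this patch.

object DummyListTrickSketch {
  def main(args: Array[String]): Unit = {
    // A single placeholder element, mirroring HashOuterJoin.DUMMY_LIST.
    val DUMMY_LIST = Seq[String](null)

    val leftRow   = "l1"
    val rightRows = Seq("r1", "r2", "r3")
    def condition(l: String, r: String): Boolean = r.endsWith("2")  // made-up join condition

    var matched = false
    val joined = rightRows.collect {
      case r if condition(leftRow, r) =>
        matched = true
        (leftRow, r)
    } ++ DUMMY_LIST.filter(_ => !matched).map(_ => (leftRow, null: String))
    // collect is strict, so by the time the filter runs, `matched` reflects the whole right
    // side: the padding row (leftRow, null) is appended only when no right row matched.
    println(joined)  // List((l1,r2)); it would be List((l1,null)) if nothing had matched
  }
}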

private[this] def rightOuterIterator(
key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
val joinedRow = new JoinedRow()
val leftNullRow = new GenericRow(left.output.length)
val boundCondition =
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)

rightIter.iterator.flatMap { r =>
joinedRow.withRight(r)
var matched = false
(if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) =>
matched = true
joinedRow.copy
} else {
Nil
}) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => {
// HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add an additional row,
// as we don't know whether we need to append it until we have finished iterating over all
// of the records on the left side.
// If no left row matched, append a single row with a null-filled left side.
joinedRow.withLeft(leftNullRow).copy
})
}
}

private[this] def fullOuterIterator(
key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
val joinedRow = new JoinedRow()
val leftNullRow = new GenericRow(left.output.length)
val rightNullRow = new GenericRow(right.output.length)
val boundCondition =
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)

if (!key.anyNull) {
// Store the positions of the right-side records whose rows satisfied the join condition
// (a condensed, standalone sketch of this bookkeeping follows this method).
val rightMatchedSet = scala.collection.mutable.Set[Int]()
leftIter.iterator.flatMap[Row] { l =>
joinedRow.withLeft(l)
var matched = false
rightIter.zipWithIndex.collect {
// 1. For matched records (those that satisfy the join condition) with both sides filled,
// append them directly.

case (r, idx) if (boundCondition(joinedRow.withRight(r))) => {
matched = true
// If the row satisfies the join condition, add its index to the matched set.
rightMatchedSet.add(idx)
joinedRow.copy
}
} ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => {
// 2. For left-side records with no match, append a row with a null-filled right side.

// HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add an additional row,
// as we don't know whether we need to append it until we have finished iterating over all
// of the records on the right side.
// If no right row matched, append a single row with a null-filled right side.
joinedRow.withRight(rightNullRow).copy
})
} ++ rightIter.zipWithIndex.collect {
// 3. For right-side records with no match, append a row with a null-filled left side.

// Re-visit the records on the right side and, for each index that is not in the matched
// set, append an additional row with a null-filled left side.
case (r, idx) if (!rightMatchedSet.contains(idx)) => {
joinedRow(leftNullRow, r).copy
}
}
} else {
leftIter.iterator.map[Row] { l =>
joinedRow(l, rightNullRow).copy
} ++ rightIter.iterator.map[Row] { r =>
joinedRow(leftNullRow, r).copy
}
}
}
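
The full-outer case above adds one more piece of bookkeeping: remembering which right-side rows were matched so the leftovers can be emitted with a null left side. Here is a condensed, standalone sketch of that idea on plain Scala collections; all names and sample data are illustrative, not part of the patch.

object FullOuterPaddingSketch {
  def main(args: Array[String]): Unit = {
    val leftRows  = Seq("a", "b")
    val rightRows = Seq("a", "c")
    def condition(l: String, r: String): Boolean = l == r  // made-up join condition

    val rightMatched = scala.collection.mutable.Set[Int]()
    val leftSide = leftRows.flatMap { l =>
      var matched = false
      val hits = rightRows.zipWithIndex.collect {
        case (r, idx) if condition(l, r) =>
          matched = true
          rightMatched += idx  // remember which right rows have been used
          (l, r)
      }
      if (matched) hits else Seq((l, null: String))  // pad unmatched left rows
    }
    val rightSide = rightRows.zipWithIndex.collect {
      case (r, idx) if !rightMatched.contains(idx) => (null: String, r)
    }
    println(leftSide ++ rightSide)
    // List((a,a), (b,null), (null,c)) -- every row from both inputs appears exactly once
  }
}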

private[this] def buildHashTable(
iter: Iterator[Row], keyGenerator: Projection): Map[Row, ArrayBuffer[Row]] = {
// TODO: Use Spark's HashMap implementation.
val hashTable = scala.collection.mutable.Map[Row, ArrayBuffer[Row]]()
while (iter.hasNext) {
@@ -300,25 +284,30 @@ case class HashOuterJoin(

hashTable.toMap[Row, ArrayBuffer[Row]]
}
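
Setting aside the TODO about Spark's own HashMap, buildHashTable is essentially a groupBy of the partition's rows by their projected join key. A tiny plain-Scala illustration under that reading; keyOf and the pair-shaped rows are simplifications, not Spark API.

object BuildHashTableSketch {
  // `keyOf` stands in for the generated key projection; rows are simplified to (Int, String).
  def groupRowsByKey(rows: Iterator[(Int, String)],
                     keyOf: ((Int, String)) => Int): Map[Int, Seq[(Int, String)]] =
    rows.toSeq.groupBy(keyOf)

  def main(args: Array[String]): Unit = {
    val rows = Iterator((1, "a"), (2, "b"), (1, "c"))
    println(groupRowsByKey(rows, _._1))
    // groups: 1 -> List((1,a), (1,c)), 2 -> List((2,b))
  }
}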

def execute() = {
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
// TODO this probably can be replaced by an external sort (sort-merge join?)
// Build HashMap for current partition in left relation
val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
val rightHashTable= buildHashTable(rightIter, newProjection(rightKeys, right.output))
// Build HashMap for current partition in right relation
val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))

val boundCondition =
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
joinType match {
case LeftOuter => leftHashTable.keysIterator.flatMap { key =>
leftOuterIterator(key, leftHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST),
rightHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST))
leftOuterIterator(key, leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST),
rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST))
}
case RightOuter => rightHashTable.keysIterator.flatMap { key =>
rightOuterIterator(key, leftHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST),
rightHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST))
rightOuterIterator(key, leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST),
rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST))
}
case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
fullOuterIterator(key, leftHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST),
rightHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST))
fullOuterIterator(key,
leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST),
rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST))
}
case x => throw new Exception(s"Need to add implementation for $x")
}
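
For orientation, the execute() method shown above pairs up the per-key row groups from the two hash tables and hands each pair to the iterator matching the join type. Below is a condensed, standalone sketch of that dispatch for the FullOuter case, using plain Scala Maps in place of the partition-local hash tables; the join condition and null-key handling are omitted, and every name is illustrative rather than part of the patch.

object OuterJoinDispatchSketch {
  def main(args: Array[String]): Unit = {
    val EMPTY      = Seq.empty[String]
    val leftTable  = Map(1 -> Seq("l1"), 2 -> Seq("l2"))  // join key -> rows with that key
    val rightTable = Map(2 -> Seq("r2"), 3 -> Seq("r3"))

    // Stand-in for fullOuterIterator: pair up the groups, padding a missing side with null.
    def fullOuter(lefts: Seq[String], rights: Seq[String]): Seq[(String, String)] =
      if (lefts.isEmpty) rights.map(r => (null: String, r))
      else if (rights.isEmpty) lefts.map(l => (l, null: String))
      else for (l <- lefts; r <- rights) yield (l, r)

    val result = (leftTable.keySet ++ rightTable.keySet).toList.sorted.flatMap { key =>
      fullOuter(leftTable.getOrElse(key, EMPTY), rightTable.getOrElse(key, EMPTY))
    }
    println(result)  // List((l1,null), (l2,r2), (null,r3))
  }
}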

