Skip to content

Commit

Permalink
Add HashOuterJoin
Browse files Browse the repository at this point in the history
  • Loading branch information
chenghao-intel committed Jul 31, 2014
1 parent 2ac37db commit 55baef7
Show file tree
Hide file tree
Showing 3 changed files with 394 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil

case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
execution.HashOuterJoin(
leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil

case _ => Nil
}
}
Expand Down
194 changes: 192 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,135 @@ case object BuildLeft extends BuildSide
@DeveloperApi
case object BuildRight extends BuildSide

/**
* Constant Value for Binary Join Node
*/
object BinaryJoinNode {
val SINGLE_NULL_LIST = Seq[Row](null)
val EMPTY_NULL_LIST = Seq[Row]()
}

// TODO If join key was null should be considered as equal? In Hive this is configurable.

/**
* Output the tuples for the matched (with the same join key) join group, base on the join types,
* Both input iterators should be repeatable.
*/
trait BinaryRepeatableIteratorNode extends BinaryNode {
self: Product =>

val leftNullRow = new GenericRow(left.output.length)
val rightNullRow = new GenericRow(right.output.length)

val joinedRow = new JoinedRow()

val boundCondition = InterpretedPredicate(
condition
.map(c => BindReferences.bindReference(c, left.output ++ right.output))
.getOrElse(Literal(true)))

def condition: Option[Expression]
def joinType: JoinType

// TODO we need to rewrite all of the iterators with our own implementation instead of the scala
// iterator for performance / memory usage reason.

def leftOuterIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row])
: Iterator[Row] = {
leftIter.iterator.flatMap { l =>
joinedRow.withLeft(l)
var matched = false
(if (!key.anyNull) rightIter else BinaryJoinNode.EMPTY_NULL_LIST).collect {
case r if (boundCondition(joinedRow.withRight(r))) => {
matched = true
joinedRow
}
} ++ BinaryJoinNode.SINGLE_NULL_LIST.collect {
case dummy if (!matched) => {
joinedRow.withRight(rightNullRow)
}
}
}
}

// TODO need to unit test this, currently it's the dead code, but should be used in SortMergeJoin
def leftSemiIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row])
: Iterator[Row] = {
leftIter.iterator.filter { l =>
joinedRow.withLeft(l)
(if (!key.anyNull) rightIter else BinaryJoinNode.EMPTY_NULL_LIST).exists {
case r => (boundCondition(joinedRow.withRight(r)))
}
}
}

def rightOuterIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row])
: Iterator[Row] = {
rightIter.iterator.flatMap{r =>
joinedRow.withRight(r)
var matched = false
(if (!key.anyNull) leftIter else BinaryJoinNode.EMPTY_NULL_LIST).collect {
case l if (boundCondition(joinedRow.withLeft(l))) => {
matched = true
joinedRow
}
} ++ BinaryJoinNode.SINGLE_NULL_LIST.collect {
case dummy if (!matched) => {
joinedRow.withLeft(leftNullRow)
}
}
}
}

def fullOuterIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row])
: Iterator[Row] = {
if (!key.anyNull) {
val rightMatchedSet = scala.collection.mutable.Set[Int]()
leftIter.iterator.flatMap[Row] { l =>
joinedRow.withLeft(l)
var matched = false
rightIter.zipWithIndex.collect {
case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> {
matched = true
rightMatchedSet.add(idx)
joinedRow
}
} ++ BinaryJoinNode.SINGLE_NULL_LIST.collect {
case dummy if (!matched) => {
joinedRow.withRight(rightNullRow)
}
}
} ++ rightIter.zipWithIndex.collect {
case (r, idx) if (!rightMatchedSet.contains(idx)) => {
joinedRow(leftNullRow, r)
}
}
} else {
leftIter.iterator.map[Row] { l =>
joinedRow(l, rightNullRow)
} ++ rightIter.iterator.map[Row] { r =>
joinedRow(leftNullRow, r)
}
}
}

// TODO need to unit test this, currently it's the dead code, but should be used in SortMergeJoin
def innerIterator(key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row])
: Iterator[Row] = {
// ignore the join filter for inner join, we assume it will done in the select filter
if (!key.anyNull) {
leftIter.iterator.flatMap { l =>
joinedRow.withLeft(l)
rightIter.iterator.collect {
case r if boundCondition(joinedRow.withRight(r)) => joinedRow
}
}
} else {
BinaryJoinNode.EMPTY_NULL_LIST.iterator
}
}
}

trait HashJoin {
self: SparkPlan =>

Expand Down Expand Up @@ -72,7 +201,7 @@ trait HashJoin {
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if(!rowKey.anyNull) {
if (!rowKey.anyNull) {
val existingMatchList = hashTable.get(rowKey)
val matchList = if (existingMatchList == null) {
val newMatchList = new ArrayBuffer[Row]()
Expand Down Expand Up @@ -136,6 +265,67 @@ trait HashJoin {
}
}

/**
* :: DeveloperApi ::
* Performs a hash join of two child relations by shuffling the data using the join keys.
* This operator requires loading both tables into memory.
*/
@DeveloperApi
case class HashOuterJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryRepeatableIteratorNode {

override def outputPartitioning: Partitioning = left.outputPartitioning

override def requiredChildDistribution =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

def output = left.output ++ right.output

private[this] def buildHashTable(iter: Iterator[Row], keyGenerator: Projection)
: Map[Row, ArrayBuffer[Row]] = {
// TODO: Use Spark's HashMap implementation.
val hashTable = scala.collection.mutable.Map[Row, ArrayBuffer[Row]]()
while (iter.hasNext) {
val currentRow = iter.next()
val rowKey = keyGenerator(currentRow)

val existingMatchList = hashTable.getOrElseUpdate(rowKey, {new ArrayBuffer[Row]()})
existingMatchList += currentRow.copy()
}

hashTable.toMap[Row, ArrayBuffer[Row]]
}

def execute() = {
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
// TODO this probably can be replaced by external sort (sort merged join?)
val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
val rightHashTable= buildHashTable(rightIter, newProjection(rightKeys, right.output))

joinType match {
case LeftOuter => leftHashTable.keysIterator.flatMap { key =>
leftOuterIterator(key, leftHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST),
rightHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST))
}
case RightOuter => rightHashTable.keysIterator.flatMap { key =>
rightOuterIterator(key, leftHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST),
rightHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST))
}
case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
fullOuterIterator(key, leftHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST),
rightHashTable.getOrElse(key, BinaryJoinNode.EMPTY_NULL_LIST))
}
case x => throw new Exception(s"Need to add implementation for $x")
}
}
}
}

/**
* :: DeveloperApi ::
* Performs an inner hash join of two child relations by first shuffling the data using the join
Expand Down Expand Up @@ -189,7 +379,7 @@ case class LeftSemiJoinHash(
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if(!rowKey.anyNull) {
if (!rowKey.anyNull) {
val keyExists = hashSet.contains(rowKey)
if (!keyExists) {
hashSet.add(rowKey)
Expand Down
Loading

0 comments on commit 55baef7

Please sign in to comment.