[SQL][SPARK-2212] Hash Outer Join
This patch adds support for hash-based outer joins. Currently, an outer join between big relations falls back to `BroadcastNestedLoopJoin`, which is extremely slow. This PR instead builds two hash tables, one for each relation, within the same partition, which greatly reduces the number of table scans.
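
For intuition, here is a minimal sketch of the per-partition idea (illustrative only: simplified types, `Option` instead of SQL nulls, no extra join condition, and none of the null-key handling of the real operator in the diff below):

```
// Illustrative sketch (not the patch's code): hash full outer join of one
// co-partitioned pair of partitions over simplified (key, value) rows.
def fullOuterJoinPartition[K, L, R](
    left: Iterator[(K, L)],
    right: Iterator[(K, R)]): Iterator[(Option[L], Option[R])] = {
  // Build a hash table for each side of the partition.
  val leftTable = left.toSeq.groupBy(_._1).mapValues(_.map(_._2))
  val rightTable = right.toSeq.groupBy(_._1).mapValues(_.map(_._2))
  // Visit every key exactly once; pad the missing side with None.
  (leftTable.keySet ++ rightTable.keySet).iterator.flatMap { key =>
    val ls = leftTable.getOrElse(key, Seq.empty)
    val rs = rightTable.getOrElse(key, Seq.empty)
    val joined: Seq[(Option[L], Option[R])] =
      if (ls.isEmpty) rs.map(r => (None, Some(r)))
      else if (rs.isEmpty) ls.map(l => (Some(l), None))
      else for (l <- ls; r <- rs) yield (Some(l), Some(r))
    joined
  }
}
```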

Here is the testing code that I used:
```
package org.apache.spark.sql.hive

import org.apache.spark.SparkContext
import org.apache.spark.SparkConf
import org.apache.spark.sql._

case class Record(key: String, value: String)

object JoinTablePrepare extends App {
  import TestHive2._
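  // TestHive2 is assumed to be the author's local variant of the standard TestHive context.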

  val rdd = sparkContext.parallelize((1 to 3000000).map(i => Record(s"${i % 828193}", s"val_$i")))

  runSqlHive("SHOW TABLES")
  runSqlHive("DROP TABLE if exists a")
  runSqlHive("DROP TABLE if exists b")
  runSqlHive("DROP TABLE if exists result")
  rdd.registerAsTable("records")

  runSqlHive("""CREATE TABLE a (key STRING, value STRING)
                 | ROW FORMAT SERDE
                 | 'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe'
                 | STORED AS RCFILE
               """.stripMargin)
  runSqlHive("""CREATE TABLE b (key STRING, value STRING)
                 | ROW FORMAT SERDE
                 | 'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe'
                 | STORED AS RCFILE
               """.stripMargin)
  runSqlHive("""CREATE TABLE result (key STRING, value STRING)
                 | ROW FORMAT SERDE
                 | 'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe'
                 | STORED AS RCFILE
               """.stripMargin)

  hql(s"""from records
             | insert into table a
             | select key, value
           """.stripMargin)
  hql(s"""from records
             | insert into table b select key + 100000, value
           """.stripMargin)
}

object JoinTablePerformanceTest extends App {
  import TestHive2._

  hql("SHOW TABLES")
  hql("set spark.sql.shuffle.partitions=20")

  val leftOuterJoin = "insert overwrite table result select a.key, b.value from a left outer join b on a.key=b.key"
  val rightOuterJoin = "insert overwrite table result select a.key, b.value from a right outer join b on a.key=b.key"
  val fullOuterJoin = "insert overwrite table result select a.key, b.value from a full outer join b on a.key=b.key"

  val results = ("LeftOuterJoin", benchmark(leftOuterJoin)) :: ("LeftOuterJoin", benchmark(leftOuterJoin)) ::
                ("RightOuterJoin", benchmark(rightOuterJoin)) :: ("RightOuterJoin", benchmark(rightOuterJoin)) ::
                ("FullOuterJoin", benchmark(fullOuterJoin)) :: ("FullOuterJoin", benchmark(fullOuterJoin)) :: Nil
  val explains = hql(s"explain $leftOuterJoin").collect ++ hql(s"explain $rightOuterJoin").collect ++ hql(s"explain $fullOuterJoin").collect
  println(explains.mkString(",\n"))
  results.foreach { case (prompt, result) => {
      println(s"$prompt: took ${result._1} ms (${result._2} records)")
    }
  }

  def benchmark(cmd: String) = {
    val begin = System.currentTimeMillis()
    val result = hql(cmd)
    val end = System.currentTimeMillis()
    val count = hql("select count(1) from result").collect.mkString("")
    ((end - begin), count)
  }
}
```
And the results are shown below:
```
[Physical execution plan:],
[InsertIntoHiveTable (MetastoreRelation default, result, None), Map(), true],
[ Project [key#95,value#98]],
[  HashOuterJoin [key#95], [key#97], LeftOuter, None],
[   Exchange (HashPartitioning [key#95], 20)],
[    HiveTableScan [key#95], (MetastoreRelation default, a, None), None],
[   Exchange (HashPartitioning [key#97], 20)],
[    HiveTableScan [key#97,value#98], (MetastoreRelation default, b, None), None],
[Physical execution plan:],
[InsertIntoHiveTable (MetastoreRelation default, result, None), Map(), true],
[ Project [key#102,value#105]],
[  HashOuterJoin [key#102], [key#104], RightOuter, None],
[   Exchange (HashPartitioning [key#102], 20)],
[    HiveTableScan [key#102], (MetastoreRelation default, a, None), None],
[   Exchange (HashPartitioning [key#104], 20)],
[    HiveTableScan [key#104,value#105], (MetastoreRelation default, b, None), None],
[Physical execution plan:],
[InsertIntoHiveTable (MetastoreRelation default, result, None), Map(), true],
[ Project [key#109,value#112]],
[  HashOuterJoin [key#109], [key#111], FullOuter, None],
[   Exchange (HashPartitioning [key#109], 20)],
[    HiveTableScan [key#109], (MetastoreRelation default, a, None), None],
[   Exchange (HashPartitioning [key#111], 20)],
[    HiveTableScan [key#111,value#112], (MetastoreRelation default, b, None), None]
LeftOuterJoin: took 16072 ms ([3000000] records)
LeftOuterJoin: took 14394 ms ([3000000] records)
RightOuterJoin: took 14802 ms ([3000000] records)
RightOuterJoin: took 14747 ms ([3000000] records)
FullOuterJoin: took 17715 ms ([6000000] records)
FullOuterJoin: took 17629 ms ([6000000] records)
```

Without this PR, the benchmark seemingly never finishes.
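
One detail of the patch worth calling out before the diff: an unmatched row gets its null-padded counterpart appended via a single-element `Seq(null)` (`DUMMY_LIST`) filtered on a `matched` flag. A standalone sketch of that idiom (hypothetical demo code, not part of the patch):

```
// Hypothetical demo of the DUMMY_LIST idiom used in the patch below:
// Seq(null) has exactly one element, so filtering it on !matched yields either
// one placeholder (emit a null-padded row) or nothing at all.
object DummyListDemo extends App {
  val DUMMY_LIST = Seq[AnyRef](null)
  def nullPaddedIfUnmatched(matched: Boolean): Seq[String] =
    DUMMY_LIST.filter(_ => !matched).map(_ => "row with null-padded side")
  println(nullPaddedIfUnmatched(matched = false)) // List(row with null-padded side)
  println(nullPaddedIfUnmatched(matched = true))  // List()
}
```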

Author: Cheng Hao <hao.cheng@intel.com>

Closes apache#1147 from chenghao-intel/hash_based_outer_join and squashes the following commits:

65c599e [Cheng Hao] Fix issues with the community comments
72b1394 [Cheng Hao] Fix bug of stale value in joinedRow
55baef7 [Cheng Hao] Add HashOuterJoin
chenghao-intel authored and conviva-zz committed Sep 4, 2014
1 parent 7e43d7e commit 5420551
Showing 3 changed files with 319 additions and 6 deletions.
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -94,6 +94,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil

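// Added by this patch: outer equi-joins are now planned as a shuffled
// HashOuterJoin instead of falling back to BroadcastNestedLoopJoin.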
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
execution.HashOuterJoin(
leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil

case _ => Nil
}
}
183 changes: 181 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -72,7 +72,7 @@ trait HashJoin {
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if(!rowKey.anyNull) {
if (!rowKey.anyNull) {
val existingMatchList = hashTable.get(rowKey)
val matchList = if (existingMatchList == null) {
val newMatchList = new ArrayBuffer[Row]()
@@ -136,6 +136,185 @@ trait HashJoin {
}
}

/**
 * Constants used by [[HashOuterJoin]]: DUMMY_LIST is a single-element sequence used to
 * conditionally append one null-padded row, and EMPTY_LIST stands in for join keys with
 * no matching rows on one side.
 */
object HashOuterJoin {
val DUMMY_LIST = Seq[Row](null)
val EMPTY_LIST = Seq[Row]()
}

/**
* :: DeveloperApi ::
* Performs a hash-based outer join of two child relations by shuffling the data using
* the join keys. This operator requires loading the associated partitions of both sides
* into memory.
*/
@DeveloperApi
case class HashOuterJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryNode {

override def outputPartitioning: Partitioning = left.outputPartitioning

override def requiredChildDistribution =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

def output = left.output ++ right.output

// TODO: rewrite all of these iterators with our own implementation instead of the Scala
// iterators, for performance reasons.

private[this] def leftOuterIterator(
key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
val joinedRow = new JoinedRow()
val rightNullRow = new GenericRow(right.output.length)
val boundCondition =
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)

leftIter.iterator.flatMap { l =>
joinedRow.withLeft(l)
var matched = false
(if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) =>
matched = true
joinedRow.copy
} else {
Nil
}) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map(_ => {
// HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add an additional
// row, as we don't know whether we need to append it until we have finished iterating
// over all of the records on the right side.
// If we didn't get any matching row, append a single row with an empty (null) right side.
joinedRow.withRight(rightNullRow).copy
})
}
}

private[this] def rightOuterIterator(
key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
val joinedRow = new JoinedRow()
val leftNullRow = new GenericRow(left.output.length)
val boundCondition =
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)

rightIter.iterator.flatMap { r =>
joinedRow.withRight(r)
var matched = false
(if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) =>
matched = true
joinedRow.copy
} else {
Nil
}) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map(_ => {
// HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add an additional
// row, as we don't know whether we need to append it until we have finished iterating
// over all of the records on the left side.
// If we didn't get any matching row, append a single row with an empty (null) left side.
joinedRow.withLeft(leftNullRow).copy
})
}
}

private[this] def fullOuterIterator(
key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
val joinedRow = new JoinedRow()
val leftNullRow = new GenericRow(left.output.length)
val rightNullRow = new GenericRow(right.output.length)
val boundCondition =
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)

if (!key.anyNull) {
// Record the positions of right-side rows that satisfied the join condition
// at least once.
val rightMatchedSet = scala.collection.mutable.Set[Int]()
leftIter.iterator.flatMap[Row] { l =>
joinedRow.withLeft(l)
var matched = false
rightIter.zipWithIndex.collect {
// 1. For matched records (those satisfying the join condition) with both sides
//    filled, append them directly.

case (r, idx) if (boundCondition(joinedRow.withRight(r))) => {
matched = true
// If the row satisfies the join condition, add its index to the matched set.
rightMatchedSet.add(idx)
joinedRow.copy
}
} ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map(_ => {
// 2. For unmatched records on the left, append rows with an empty (null) right side.

// HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add an additional
// row, as we don't know whether we need to append it until we have finished iterating
// over all of the records on the right side.
// If we didn't get any matching row, append a single row with an empty (null) right side.
joinedRow.withRight(rightNullRow).copy
})
} ++ rightIter.zipWithIndex.collect {
// 3. For unmatched records on the right, append rows with an empty (null) left side.

// Re-visit the records on the right and, for each one not in the matched set,
// append a row with an empty left side.
case (r, idx) if (!rightMatchedSet.contains(idx)) => {
joinedRow(leftNullRow, r).copy
}
}
} else {
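// Null join keys never match, so emit every row from both sides, null-padded.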
leftIter.iterator.map[Row] { l =>
joinedRow(l, rightNullRow).copy
} ++ rightIter.iterator.map[Row] { r =>
joinedRow(leftNullRow, r).copy
}
}
}

private[this] def buildHashTable(
iter: Iterator[Row], keyGenerator: Projection): Map[Row, ArrayBuffer[Row]] = {
// TODO: Use Spark's HashMap implementation.
val hashTable = scala.collection.mutable.Map[Row, ArrayBuffer[Row]]()
while (iter.hasNext) {
val currentRow = iter.next()
val rowKey = keyGenerator(currentRow)

val existingMatchList = hashTable.getOrElseUpdate(rowKey, {new ArrayBuffer[Row]()})
existingMatchList += currentRow.copy()
}

hashTable.toMap[Row, ArrayBuffer[Row]]
}

def execute() = {
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
// TODO: this probably can be replaced by an external sort (sort-merge join?)
// Build HashMap for current partition in left relation
val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
// Build HashMap for current partition in right relation
val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))

val boundCondition =
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
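// Note: this boundCondition is unused here; each of the leftOuterIterator,
// rightOuterIterator, and fullOuterIterator helpers builds its own predicate.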
joinType match {
case LeftOuter => leftHashTable.keysIterator.flatMap { key =>
leftOuterIterator(key, leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST),
rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST))
}
case RightOuter => rightHashTable.keysIterator.flatMap { key =>
rightOuterIterator(key, leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST),
rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST))
}
case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
fullOuterIterator(key,
leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST),
rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST))
}
case x => throw new Exception(s"Need to add implementation for $x")
}
}
}
}

/**
* :: DeveloperApi ::
* Performs an inner hash join of two child relations by first shuffling the data using the join
@@ -189,7 +368,7 @@ case class LeftSemiJoinHash(
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if(!rowKey.anyNull) {
if (!rowKey.anyNull) {
val keyExists = hashSet.contains(rowKey)
if (!keyExists) {
hashSet.add(rowKey)
138 changes: 134 additions & 4 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -17,11 +17,17 @@

package org.apache.spark.sql

import org.scalatest.BeforeAndAfterEach

import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner}
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner, LeftSemi}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._

class JoinSuite extends QueryTest {
class JoinSuite extends QueryTest with BeforeAndAfterEach {

// Ensures tables are loaded.
TestData
@@ -34,6 +40,56 @@ class JoinSuite extends QueryTest {
assert(planned.size === 1)
}

test("join operator selection") {
def assertJoin(sqlString: String, c: Class[_]): Any = {
val rdd = sql(sqlString)
val physical = rdd.queryExecution.sparkPlan
val operators = physical.collect {
case j: ShuffledHashJoin => j
case j: HashOuterJoin => j
case j: LeftSemiJoinHash => j
case j: BroadcastHashJoin => j
case j: LeftSemiJoinBNL => j
case j: CartesianProduct => j
case j: BroadcastNestedLoopJoin => j
}

assert(operators.size === 1)
if (operators(0).getClass() != c) {
fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical")
}
}

val cases1 = Seq(
("SELECT * FROM testData left semi join testData2 ON key = a", classOf[LeftSemiJoinHash]),
("SELECT * FROM testData left semi join testData2", classOf[LeftSemiJoinBNL]),
("SELECT * FROM testData join testData2", classOf[CartesianProduct]),
("SELECT * FROM testData join testData2 where key=2", classOf[CartesianProduct]),
("SELECT * FROM testData left join testData2", classOf[CartesianProduct]),
("SELECT * FROM testData right join testData2", classOf[CartesianProduct]),
("SELECT * FROM testData full outer join testData2", classOf[CartesianProduct]),
("SELECT * FROM testData left join testData2 where key=2", classOf[CartesianProduct]),
("SELECT * FROM testData right join testData2 where key=2", classOf[CartesianProduct]),
("SELECT * FROM testData full outer join testData2 where key=2", classOf[CartesianProduct]),
("SELECT * FROM testData join testData2 where key>a", classOf[CartesianProduct]),
("SELECT * FROM testData full outer join testData2 where key>a", classOf[CartesianProduct]),
("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]),
("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]),
("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin]),
("SELECT * FROM testData left join testData2 ON key = a", classOf[HashOuterJoin]),
("SELECT * FROM testData right join testData2 ON key = a where key=2",
classOf[HashOuterJoin]),
("SELECT * FROM testData right join testData2 ON key = a and key=2",
classOf[HashOuterJoin]),
("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin]),
("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]),
("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]),
("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin])
// TODO add BroadcastNestedLoopJoin
)
cases1.foreach { c => assertJoin(c._1, c._2) }
}

test("multiple-key equi-join is hash-join") {
val x = testData2.as('x)
val y = testData2.as('y)
@@ -114,6 +170,33 @@ class JoinSuite extends QueryTest {
(4, "D", 4, "d") ::
(5, "E", null, null) ::
(6, "F", null, null) :: Nil)

checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)),
(1, "A", null, null) ::
(2, "B", 2, "b") ::
(3, "C", 3, "c") ::
(4, "D", 4, "d") ::
(5, "E", null, null) ::
(6, "F", null, null) :: Nil)

checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)),
(1, "A", null, null) ::
(2, "B", 2, "b") ::
(3, "C", 3, "c") ::
(4, "D", 4, "d") ::
(5, "E", null, null) ::
(6, "F", null, null) :: Nil)

checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)),
(1, "A", 1, "a") ::
(2, "B", 2, "b") ::
(3, "C", 3, "c") ::
(4, "D", 4, "d") ::
(5, "E", null, null) ::
(6, "F", null, null) :: Nil)
}

test("right outer join") {
@@ -125,11 +208,38 @@ class JoinSuite extends QueryTest {
(4, "d", 4, "D") ::
(null, null, 5, "E") ::
(null, null, 6, "F") :: Nil)
checkAnswer(
lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'n > 1)),
(null, null, 1, "A") ::
(2, "b", 2, "B") ::
(3, "c", 3, "C") ::
(4, "d", 4, "D") ::
(null, null, 5, "E") ::
(null, null, 6, "F") :: Nil)
checkAnswer(
lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'N > 1)),
(null, null, 1, "A") ::
(2, "b", 2, "B") ::
(3, "c", 3, "C") ::
(4, "d", 4, "D") ::
(null, null, 5, "E") ::
(null, null, 6, "F") :: Nil)
checkAnswer(
lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'l > 'L)),
(1, "a", 1, "A") ::
(2, "b", 2, "B") ::
(3, "c", 3, "C") ::
(4, "d", 4, "D") ::
(null, null, 5, "E") ::
(null, null, 6, "F") :: Nil)
}

test("full outer join") {
val left = upperCaseData.where('N <= 4).as('left)
val right = upperCaseData.where('N >= 3).as('right)
upperCaseData.where('N <= 4).registerAsTable("left")
upperCaseData.where('N >= 3).registerAsTable("right")

val left = UnresolvedRelation(None, "left", None)
val right = UnresolvedRelation(None, "right", None)

checkAnswer(
left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)),
@@ -139,5 +249,25 @@ class JoinSuite extends QueryTest {
(4, "D", 4, "D") ::
(null, null, 5, "E") ::
(null, null, 6, "F") :: Nil)

checkAnswer(
left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))),
(1, "A", null, null) ::
(2, "B", null, null) ::
(3, "C", null, null) ::
(null, null, 3, "C") ::
(4, "D", 4, "D") ::
(null, null, 5, "E") ::
(null, null, 6, "F") :: Nil)

checkAnswer(
left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))),
(1, "A", null, null) ::
(2, "B", null, null) ::
(3, "C", null, null) ::
(null, null, 3, "C") ::
(4, "D", 4, "D") ::
(null, null, 5, "E") ::
(null, null, 6, "F") :: Nil)
}
}
