Commit bea4a50: Unsafe HashJoin
Davies Liu committed Jul 17, 2015
1 parent eba6a1a
Showing 9 changed files with 194 additions and 27 deletions.
org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -17,10 +17,11 @@

package org.apache.spark.sql.catalyst.expressions;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.ObjectPool;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.types.UTF8String;


@@ -345,7 +346,7 @@ public double getDouble(int i) {
* This method is only supported on UnsafeRows that do not use ObjectPools.
*/
@Override
public InternalRow copy() {
public UnsafeRow copy() {
if (pool != null) {
throw new UnsupportedOperationException(
"Copy is not supported for UnsafeRows that use object pools");
@@ -365,6 +366,33 @@ public InternalRow copy() {
}
}

@Override
public int hashCode() {
return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42);
}

@Override
public boolean equals(Object other) {
if (other instanceof UnsafeRow) {
UnsafeRow o = (UnsafeRow) other;
return ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset,
sizeInBytes);
}
return false;
}

// This is for debugging
@Override
public String toString(){
StringBuilder build = new StringBuilder("[");
for (int i = 0; i < sizeInBytes; i += 8) {
build.append(PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i));
build.append(',');
}
build.append(']');
return build.toString();
}

@Override
public boolean anyNull() {
return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes);
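The new hashCode/equals above compare an UnsafeRow purely by its backing bytes, which is what lets it serve as a hash-map key in the unsafe hash join later in this commit. A minimal sketch of that contract, written as it might be tried in a spark-shell built from this branch (only names that appear in this diff are used):

```scala
import java.util.{HashMap => JHashMap}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val schema   = StructType(StructField("a", IntegerType, nullable = false) :: Nil)
val toUnsafe = UnsafeProjection.create(schema)

// Two rows built independently but with identical backing bytes.
val k1 = toUnsafe(InternalRow(1)).copy()   // copy() now returns UnsafeRow
val k2 = toUnsafe(InternalRow(1)).copy()
assert(k1 == k2 && k1.hashCode == k2.hashCode)

// Byte-wise equality and hashing make UnsafeRow usable as an ordinary HashMap key.
val table = new JHashMap[UnsafeRow, String]()
table.put(k1, "build-side row")
assert(table.get(k2) == "build-side row")
```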
org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -85,8 +85,12 @@ abstract class UnsafeProjection extends Projection {
object UnsafeProjection {
def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType))

def create(fields: Seq[DataType]): UnsafeProjection = {
def create(fields: Array[DataType]): UnsafeProjection = {
val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))
create(exprs)
}

def create(exprs: Seq[Expression]): UnsafeProjection = {
GenerateUnsafeProjection.generate(exprs)
}
}
@@ -96,6 +100,8 @@ object UnsafeProjection {
*/
case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection {

def this(schema: StructType) = this(schema.fields.map(_.dataType))

private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) =>
new BoundReference(idx, dt, true)
}
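The changes above add an Array[DataType] overload and an Expression-based overload of UnsafeProjection.create, plus an auxiliary StructType constructor for FromUnsafeProjection. A hedged round-trip sketch using only those entry points (spark-shell style, on this branch):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeProjection}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}

val schema = StructType(StructField("id", LongType, nullable = false) :: Nil)

val toUnsafe   = UnsafeProjection.create(Array[DataType](LongType))  // new Array overload
val fromUnsafe = new FromUnsafeProjection(schema)                    // new auxiliary constructor

val unsafeRow = toUnsafe(InternalRow(42L))   // compact, binary-backed row
val generic   = fromUnsafe(unsafeRow)        // converted back to a generic InternalRow
```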
org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -111,7 +111,7 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
/**
* Function for writing a column into an UnsafeRow.
*/
private abstract class UnsafeColumnWriter {
abstract class UnsafeColumnWriter {
/**
* Write a value into an UnsafeRow.
*
@@ -130,7 +130,7 @@ private abstract class UnsafeColumnWriter {
def getSize(source: InternalRow, column: Int): Int
}

private object UnsafeColumnWriter {
object UnsafeColumnWriter {

def forType(dataType: DataType): UnsafeColumnWriter = {
dataType match {
org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -23,7 +23,7 @@ import scala.concurrent.duration._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{BindReferences, UnsafeColumnWriter, Expression}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.util.ThreadUtils
@@ -62,7 +62,14 @@ case class BroadcastHashJoin(
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length)
val hashed = if (left.codegenEnabled &&
buildKeys.map(_.dataType).forall(UnsafeColumnWriter.canEmbed(_))) {
UnsafeHashedRelation(input.iterator,
buildKeys.map(BindReferences.bindReference(_, buildPlan.output)),
buildPlan.schema)
} else {
HashedRelation(input.iterator, buildSideKeyGenerator, input.length)
}
sparkContext.broadcast(hashed)
}(BroadcastHashJoin.broadcastHashJoinExecutionContext)

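BroadcastHashJoin here (and ShuffledHashJoin further down) gate the unsafe path on the same condition: codegen must be enabled and every build-key type must be embeddable in an UnsafeRow. Distilled into one predicate, as a sketch only; UnsafeColumnWriter.canEmbed's signature is inferred from its call sites in this diff:

```scala
import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeColumnWriter}

// True only when codegen is on and all build-key types can be embedded in an UnsafeRow.
def useUnsafeRelation(codegenEnabled: Boolean, buildKeys: Seq[Expression]): Boolean =
  codegenEnabled && buildKeys.map(_.dataType).forall(UnsafeColumnWriter.canEmbed)
```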
org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -17,12 +17,13 @@

package org.apache.spark.sql.execution.joins

import java.io.{ObjectInput, ObjectOutput, Externalizable}
import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.util.{HashMap => JavaHashMap}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Projection
import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeProjection, UnsafeRow, Projection}
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.collection.CompactBuffer


@@ -98,7 +99,6 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR
}
}


// TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys.


@@ -148,3 +148,91 @@ private[joins] object HashedRelation {
}
}
}


/**
 * A HashedRelation for UnsafeRow, currently backed by a java.util.HashMap that maps each key to
 * a sequence of values.
*
* TODO(davies): use BytesToBytesMap
*/
private[joins] final class UnsafeHashedRelation(
private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]],
private var keyTypes: Array[DataType])
extends HashedRelation with Externalizable {

def this() = this(null, null) // Needed for serialization

// UnsafeProjection is not thread safe
@transient lazy val keyProjection = new ThreadLocal[UnsafeProjection]

override def get(key: InternalRow): CompactBuffer[InternalRow] = {
val unsafeKey = if (key.isInstanceOf[UnsafeRow]) {
key.asInstanceOf[UnsafeRow]
} else {
var proj = keyProjection.get()
if (proj eq null) {
proj = UnsafeProjection.create(keyTypes)
keyProjection.set(proj)
}
proj(key)
}
// rely on type erasure in Scala
hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]]
}

override def writeExternal(out: ObjectOutput): Unit = {
writeBytes(out, SparkSqlSerializer.serialize(keyTypes))
val bytes = SparkSqlSerializer.serialize(hashTable)
println(s"before write ${hashTable}")
println(s"write bytes ${bytes.toString}")
writeBytes(out, bytes)
}

override def readExternal(in: ObjectInput): Unit = {
keyTypes = SparkSqlSerializer.deserialize(readBytes(in))
hashTable = SparkSqlSerializer.deserialize(readBytes(in))
println(s"loaded ${hashTable}")
}
}

private[joins] object UnsafeHashedRelation {

def apply(
input: Iterator[InternalRow],
buildKey: Seq[Expression],
rowSchema: StructType,
sizeEstimate: Int = 64): HashedRelation = {

// TODO: Use BytesToBytesMap.
val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
var currentRow: InternalRow = null
val rowProj = UnsafeProjection.create(rowSchema)
val keyGenerator = UnsafeProjection.create(buildKey)

// Create a mapping of buildKeys -> rows
while (input.hasNext) {
currentRow = input.next()
val unsafeRow = if (currentRow.isInstanceOf[UnsafeRow]) {
currentRow.asInstanceOf[UnsafeRow]
} else {
rowProj(currentRow)
}
val rowKey = keyGenerator(unsafeRow)
if (!rowKey.anyNull) {
val existingMatchList = hashTable.get(rowKey)
val matchList = if (existingMatchList == null) {
val newMatchList = new CompactBuffer[UnsafeRow]()
hashTable.put(rowKey.copy(), newMatchList)
newMatchList
} else {
existingMatchList
}
matchList += unsafeRow.copy()
}
}

val keySchema = buildKey.map(_.dataType).toArray
new UnsafeHashedRelation(hashTable, keySchema)
}
}
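A compact usage sketch of the new UnsafeHashedRelation, mirroring the test added below; the wrapper object is hypothetical, and it has to live in the joins package because the class is private[joins]:

```scala
package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object UnsafeHashedRelationExample {
  def main(args: Array[String]): Unit = {
    val rows     = Seq(InternalRow(1), InternalRow(2), InternalRow(2)).iterator
    val buildKey = Seq(BoundReference(0, IntegerType, false))
    val schema   = StructType(StructField("a", IntegerType, nullable = true) :: Nil)

    // Build: each row is converted to an UnsafeRow and grouped under a copied key.
    val relation = UnsafeHashedRelation(rows, buildKey, schema)

    // Probe: a generic key is converted through the per-thread UnsafeProjection.
    val matches = relation.get(InternalRow(2))
    assert(matches.size == 2)
    assert(relation.get(InternalRow(9)) == null)
  }
}
```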
org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{BindReferences, UnsafeColumnWriter, Expression}
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}

@@ -44,8 +44,16 @@ case class ShuffledHashJoin(
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

protected override def doExecute(): RDD[InternalRow] = {
val codegenEnabled = left.codegenEnabled
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
val hashed = HashedRelation(buildIter, buildSideKeyGenerator)
val hashed = if (codegenEnabled &&
buildKeys.map(_.dataType).forall(UnsafeColumnWriter.canEmbed(_))) {
UnsafeHashedRelation(buildIter,
buildKeys.map(BindReferences.bindReference(_, buildPlan.output)),
buildPlan.schema)
} else {
HashedRelation(buildIter, buildSideKeyGenerator)
}
hashJoin(streamIter, hashed)
}
}
org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.joins

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Projection
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.types.{StructField, StructType, IntegerType}
import org.apache.spark.util.collection.CompactBuffer


@@ -35,29 +37,53 @@ class HashedRelationSuite extends SparkFunSuite {
val hashed = HashedRelation(data.iterator, keyProjection)
assert(hashed.isInstanceOf[GeneralHashedRelation])

assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0)))
assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1)))
assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0)))
assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1)))
assert(hashed.get(InternalRow(10)) === null)

val data2 = CompactBuffer[InternalRow](data(2))
data2 += data(2)
assert(hashed.get(data(2)) == data2)
assert(hashed.get(data(2)) === data2)
}

test("UniqueKeyHashedRelation") {
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2))
val hashed = HashedRelation(data.iterator, keyProjection)
assert(hashed.isInstanceOf[UniqueKeyHashedRelation])

assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0)))
assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1)))
assert(hashed.get(data(2)) == CompactBuffer[InternalRow](data(2)))
assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0)))
assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1)))
assert(hashed.get(data(2)) === CompactBuffer[InternalRow](data(2)))
assert(hashed.get(InternalRow(10)) === null)

val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation]
assert(uniqHashed.getValue(data(0)) == data(0))
assert(uniqHashed.getValue(data(1)) == data(1))
assert(uniqHashed.getValue(data(2)) == data(2))
assert(uniqHashed.getValue(InternalRow(10)) == null)
assert(uniqHashed.getValue(data(0)) === data(0))
assert(uniqHashed.getValue(data(1)) === data(1))
assert(uniqHashed.getValue(data(2)) === data(2))
assert(uniqHashed.getValue(InternalRow(10)) === null)
}

test("UnsafeHashedRelation") {
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
val buildKey = Seq(BoundReference(0, IntegerType, false))
val schema = StructType(StructField("a", IntegerType, true) :: Nil)
val hashed = UnsafeHashedRelation(data.iterator, buildKey, schema)
assert(hashed.isInstanceOf[UnsafeHashedRelation])

val toUnsafe = UnsafeProjection.create(schema)
assert(hashed.get(data(0)) === CompactBuffer[UnsafeRow](toUnsafe(data(0))))
assert(hashed.get(data(1)) === CompactBuffer[UnsafeRow](toUnsafe(data(1))))
assert(hashed.get(InternalRow(10)) === null)

val data2 = CompactBuffer[InternalRow](toUnsafe(data(2)).copy())
data2 += toUnsafe(data(2)).copy()
assert(hashed.get(data(2)) === data2)

val hashed2 = SparkSqlSerializer.deserialize(SparkSqlSerializer.serialize(hashed))
.asInstanceOf[UnsafeHashedRelation]
assert(hashed2.get(data(0)) === CompactBuffer[UnsafeRow](toUnsafe(data(0))))
assert(hashed2.get(data(1)) === CompactBuffer[UnsafeRow](toUnsafe(data(1))))
assert(hashed2.get(InternalRow(10)) === null)
assert(hashed2.get(data(2)) === data2)
}
}
org/apache/spark/unsafe/bitset/BitSetMethods.java
@@ -72,7 +72,7 @@ public static boolean isSet(Object baseObject, long baseOffset, int index) {
*/
public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidthInWords) {
long addr = baseOffset;
for (int i = 0; i < bitSetWidthInWords; i++, addr += WORD_SIZE) {
for (int i = 0; i < bitSetWidthInWords; i += 8 * WORD_SIZE, addr += WORD_SIZE) {
if (PlatformDependent.UNSAFE.getLong(baseObject, addr) != 0) {
return true;
}
org/apache/spark/unsafe/hash/Murmur3_x86_32.java
@@ -44,12 +44,16 @@ public int hashInt(int input) {
return fmix(h1, 4);
}

public int hashUnsafeWords(Object baseObject, long baseOffset, int lengthInBytes) {
public int hashUnsafeWords(Object base, long offset, int lengthInBytes) {
return hashUnsafeWords(base, offset, lengthInBytes, seed);
}

public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) {
// This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method.
assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)";
int h1 = seed;
for (int offset = 0; offset < lengthInBytes; offset += 4) {
int halfWord = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset);
for (int i = 0; i < lengthInBytes; i += 4) {
int halfWord = PlatformDependent.UNSAFE.getInt(base, offset + i);
int k1 = mixK1(halfWord);
h1 = mixH1(h1, k1);
}
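The new static hashUnsafeWords overload is what UnsafeRow.hashCode calls with seed 42. A small hedged sketch of a direct call over a long[] region; it assumes PlatformDependent exposes LONG_ARRAY_OFFSET, as used elsewhere in the unsafe module:

```scala
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.hash.Murmur3_x86_32

// Hash three 8-byte words starting at the array's base offset, seed 42 as in UnsafeRow.hashCode.
val words = Array(1L, 2L, 3L)
val hash  = Murmur3_x86_32.hashUnsafeWords(
  words, PlatformDependent.LONG_ARRAY_OFFSET, words.length * 8, 42)
println(hash)
```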
