diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 87294a0e21441..14fba3e655748 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -17,10 +17,11 @@
 
 package org.apache.spark.sql.catalyst.expressions;
 
-import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.util.ObjectPool;
 import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.bitset.BitSetMethods;
+import org.apache.spark.unsafe.hash.Murmur3_x86_32;
 import org.apache.spark.unsafe.types.UTF8String;
@@ -345,7 +346,7 @@ public double getDouble(int i) {
    * This method is only supported on UnsafeRows that do not use ObjectPools.
    */
   @Override
-  public InternalRow copy() {
+  public UnsafeRow copy() {
     if (pool != null) {
       throw new UnsupportedOperationException(
         "Copy is not supported for UnsafeRows that use object pools");
@@ -365,6 +366,33 @@ public InternalRow copy() {
     }
   }
 
+  @Override
+  public int hashCode() {
+    return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof UnsafeRow) {
+      UnsafeRow o = (UnsafeRow) other;
+      return ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset,
+        sizeInBytes);
+    }
+    return false;
+  }
+
+  // This is for debugging
+  @Override
+  public String toString() {
+    StringBuilder build = new StringBuilder("[");
+    for (int i = 0; i < sizeInBytes; i += 8) {
+      build.append(PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i));
+      build.append(',');
+    }
+    build.append(']');
+    return build.toString();
+  }
+
   @Override
   public boolean anyNull() {
     return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes);
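Illustrative sketch (not part of the patch): with the `equals`/`hashCode` added above, two UnsafeRows holding identical bytes behave as the same hash-map key. The schema and values here are invented for illustration; `UnsafeProjection.create(StructType)` comes from the Projection.scala change further below.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val schema = StructType(StructField("a", IntegerType, nullable = false) :: Nil)
val toUnsafe = UnsafeProjection.create(schema)

// The generated projection reuses its output row, so copy() before keeping a reference around.
val row1 = toUnsafe(InternalRow(42)).copy()
val row2 = toUnsafe(InternalRow(42)).copy()

assert(row1 == row2)                    // byte-wise comparison via ByteArrayMethods.arrayEquals
assert(row1.hashCode == row2.hashCode)  // Murmur3_x86_32.hashUnsafeWords over the row's bytes
```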
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 24b01ea55110e..43578b52c0026 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -85,8 +85,12 @@ abstract class UnsafeProjection extends Projection {
 object UnsafeProjection {
   def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType))
 
-  def create(fields: Seq[DataType]): UnsafeProjection = {
+  def create(fields: Array[DataType]): UnsafeProjection = {
     val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))
+    create(exprs)
+  }
+
+  def create(exprs: Seq[Expression]): UnsafeProjection = {
     GenerateUnsafeProjection.generate(exprs)
   }
 }
@@ -96,6 +100,8 @@ object UnsafeProjection {
  */
 case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection {
 
+  def this(schema: StructType) = this(schema.fields.map(_.dataType))
+
   private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) =>
     new BoundReference(idx, dt, true)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
index 885ab091fcdf5..702deb04acb67 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -111,7 +111,7 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
 /**
  * Function for writing a column into an UnsafeRow.
  */
-private abstract class UnsafeColumnWriter {
+abstract class UnsafeColumnWriter {
   /**
    * Write a value into an UnsafeRow.
    *
@@ -130,7 +130,7 @@ private abstract class UnsafeColumnWriter {
   def getSize(source: InternalRow, column: Int): Int
 }
 
-private object UnsafeColumnWriter {
+object UnsafeColumnWriter {
 
   def forType(dataType: DataType): UnsafeColumnWriter = {
     dataType match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 7ffdce60d2955..03a51afa6f555 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -23,7 +23,7 @@ import scala.concurrent.duration._
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{BindReferences, UnsafeColumnWriter, Expression}
 import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.util.ThreadUtils
@@ -62,7 +62,14 @@ case class BroadcastHashJoin(
   private val broadcastFuture = future {
     // Note that we use .execute().collect() because we don't want to convert data to Scala types
     val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
-    val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length)
+    val hashed = if (left.codegenEnabled &&
+        buildKeys.map(_.dataType).forall(UnsafeColumnWriter.canEmbed(_))) {
+      UnsafeHashedRelation(input.iterator,
+        buildKeys.map(BindReferences.bindReference(_, buildPlan.output)),
+        buildPlan.schema)
+    } else {
+      HashedRelation(input.iterator, buildSideKeyGenerator, input.length)
+    }
     sparkContext.broadcast(hashed)
   }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 6b51f5d4151d3..0e1089bf4b643 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -17,12 +17,13 @@
 
 package org.apache.spark.sql.execution.joins
 
-import java.io.{ObjectInput, ObjectOutput, Externalizable}
+import java.io.{Externalizable, ObjectInput, ObjectOutput}
 import java.util.{HashMap => JavaHashMap}
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Projection
+import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeProjection, UnsafeRow, Projection}
 import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.util.collection.CompactBuffer
@@ -98,7 +99,6 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR
   }
 }
 
-// TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys.
 
@@ -148,3 +148,91 @@
     }
   }
 }
+
+
+/**
+ * A HashedRelation for UnsafeRow, which maps a key (an UnsafeRow of the join keys) to a
+ * sequence of matched rows (UnsafeRows). It is currently backed by a Java HashMap.
+ *
+ * TODO(davies): use BytesToBytesMap
+ */
+private[joins] final class UnsafeHashedRelation(
+    private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]],
+    private var keyTypes: Array[DataType])
+  extends HashedRelation with Externalizable {
+
+  def this() = this(null, null)  // Needed for serialization
+
+  // UnsafeProjection is not thread-safe, so keep one per thread
+  @transient lazy val keyProjection = new ThreadLocal[UnsafeProjection]
+
+  override def get(key: InternalRow): CompactBuffer[InternalRow] = {
+    val unsafeKey = if (key.isInstanceOf[UnsafeRow]) {
+      key.asInstanceOf[UnsafeRow]
+    } else {
+      var proj = keyProjection.get()
+      if (proj eq null) {
+        proj = UnsafeProjection.create(keyTypes)
+        keyProjection.set(proj)
+      }
+      proj(key)
+    }
+    // Rely on type erasure in Scala: a CompactBuffer[UnsafeRow] is returned as CompactBuffer[InternalRow]
+    hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]]
+  }
+
+  override def writeExternal(out: ObjectOutput): Unit = {
+    writeBytes(out, SparkSqlSerializer.serialize(keyTypes))
+    writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+  }
+
+  override def readExternal(in: ObjectInput): Unit = {
+    keyTypes = SparkSqlSerializer.deserialize(readBytes(in))
+    hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+  }
+}
+
+private[joins] object UnsafeHashedRelation {
+
+  def apply(
+      input: Iterator[InternalRow],
+      buildKey: Seq[Expression],
+      rowSchema: StructType,
+      sizeEstimate: Int = 64): HashedRelation = {
+
+    // TODO: Use BytesToBytesMap.
+    val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
+    var currentRow: InternalRow = null
+    val rowProj = UnsafeProjection.create(rowSchema)
+    val keyGenerator = UnsafeProjection.create(buildKey)
+
+    // Create a mapping of buildKeys -> rows
+    while (input.hasNext) {
+      currentRow = input.next()
+      val unsafeRow = if (currentRow.isInstanceOf[UnsafeRow]) {
+        currentRow.asInstanceOf[UnsafeRow]
+      } else {
+        rowProj(currentRow)
+      }
+      val rowKey = keyGenerator(unsafeRow)
+      if (!rowKey.anyNull) {
+        val existingMatchList = hashTable.get(rowKey)
+        val matchList = if (existingMatchList == null) {
+          val newMatchList = new CompactBuffer[UnsafeRow]()
+          hashTable.put(rowKey.copy(), newMatchList)
+          newMatchList
+        } else {
+          existingMatchList
+        }
+        matchList += unsafeRow.copy()
+      }
+    }
+
+    val keySchema = buildKey.map(_.dataType).toArray
+    new UnsafeHashedRelation(hashTable, keySchema)
+  }
+}
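Illustrative sketch (not part of the patch), mirroring the new test added to HashedRelationSuite further below; the column name and values are invented:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val rows = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
val schema = StructType(StructField("a", IntegerType, nullable = true) :: Nil)

// Build side: key on column 0. Rows and keys are converted to UnsafeRow while building.
val buildKey = Seq(BoundReference(0, IntegerType, nullable = false))
val hashed = UnsafeHashedRelation(rows.iterator, buildKey, schema)

// Probe side: a non-UnsafeRow key is converted through the thread-local UnsafeProjection.
val matches = hashed.get(InternalRow(2))   // CompactBuffer with the two matching rows
val missing = hashed.get(InternalRow(10))  // null when the key is not present
```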
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index 5439e10a60b2a..8f5b77bdfe217 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{BindReferences, UnsafeColumnWriter, Expression}
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
@@ -44,8 +44,16 @@ case class ShuffledHashJoin(
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
   protected override def doExecute(): RDD[InternalRow] = {
+    val codegenEnabled = left.codegenEnabled
     buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
-      val hashed = HashedRelation(buildIter, buildSideKeyGenerator)
+      val hashed = if (codegenEnabled &&
+          buildKeys.map(_.dataType).forall(UnsafeColumnWriter.canEmbed(_))) {
+        UnsafeHashedRelation(buildIter,
+          buildKeys.map(BindReferences.bindReference(_, buildPlan.output)),
+          buildPlan.schema)
+      } else {
+        HashedRelation(buildIter, buildSideKeyGenerator)
+      }
       hashJoin(streamIter, hashed)
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 9d9858b1c6151..28194df7be906 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.joins
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Projection
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.types.{StructField, StructType, IntegerType}
 import org.apache.spark.util.collection.CompactBuffer
@@ -35,13 +37,13 @@ class HashedRelationSuite extends SparkFunSuite {
     val hashed = HashedRelation(data.iterator, keyProjection)
     assert(hashed.isInstanceOf[GeneralHashedRelation])
 
-    assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0)))
-    assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1)))
+    assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0)))
+    assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1)))
     assert(hashed.get(InternalRow(10)) === null)
 
     val data2 = CompactBuffer[InternalRow](data(2))
     data2 += data(2)
-    assert(hashed.get(data(2)) == data2)
+    assert(hashed.get(data(2)) === data2)
   }
 
   test("UniqueKeyHashedRelation") {
@@ -49,15 +51,39 @@ class HashedRelationSuite extends SparkFunSuite {
     val hashed = HashedRelation(data.iterator, keyProjection)
     assert(hashed.isInstanceOf[UniqueKeyHashedRelation])
 
-    assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0)))
-    assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1)))
-    assert(hashed.get(data(2)) == CompactBuffer[InternalRow](data(2)))
+    assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0)))
+    assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1)))
+    assert(hashed.get(data(2)) === CompactBuffer[InternalRow](data(2)))
     assert(hashed.get(InternalRow(10)) === null)
 
     val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation]
-    assert(uniqHashed.getValue(data(0)) == data(0))
-    assert(uniqHashed.getValue(data(1)) == data(1))
-    assert(uniqHashed.getValue(data(2)) == data(2))
-    assert(uniqHashed.getValue(InternalRow(10)) == null)
+    assert(uniqHashed.getValue(data(0)) === data(0))
+    assert(uniqHashed.getValue(data(1)) === data(1))
+    assert(uniqHashed.getValue(data(2)) === data(2))
+    assert(uniqHashed.getValue(InternalRow(10)) === null)
+  }
+
+  test("UnsafeHashedRelation") {
+    val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
+    val buildKey = Seq(BoundReference(0, IntegerType, false))
+    val schema = StructType(StructField("a", IntegerType, true) :: Nil)
+    val hashed = UnsafeHashedRelation(data.iterator, buildKey, schema)
+    assert(hashed.isInstanceOf[UnsafeHashedRelation])
+
+    val toUnsafe = UnsafeProjection.create(schema)
+    assert(hashed.get(data(0)) === CompactBuffer[UnsafeRow](toUnsafe(data(0))))
+    assert(hashed.get(data(1)) === CompactBuffer[UnsafeRow](toUnsafe(data(1))))
+    assert(hashed.get(InternalRow(10)) === null)
+
+    val data2 = CompactBuffer[InternalRow](toUnsafe(data(2)).copy())
+    data2 += toUnsafe(data(2)).copy()
+    assert(hashed.get(data(2)) === data2)
+
+    val hashed2 = SparkSqlSerializer.deserialize(SparkSqlSerializer.serialize(hashed))
+      .asInstanceOf[UnsafeHashedRelation]
+    assert(hashed2.get(data(0)) === CompactBuffer[UnsafeRow](toUnsafe(data(0))))
+    assert(hashed2.get(data(1)) === CompactBuffer[UnsafeRow](toUnsafe(data(1))))
+    assert(hashed2.get(InternalRow(10)) === null)
+    assert(hashed2.get(data(2)) === data2)
   }
 }
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
index 27462c7fa5e62..a936a9b81e16e 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
@@ -72,7 +72,7 @@ public static boolean isSet(Object baseObject, long baseOffset, int index) {
    */
   public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidthInWords) {
     long addr = baseOffset;
-    for (int i = 0; i < bitSetWidthInWords; i++, addr += WORD_SIZE) {
+    for (int i = 0; i < bitSetWidthInWords; i += WORD_SIZE, addr += WORD_SIZE) {
       if (PlatformDependent.UNSAFE.getLong(baseObject, addr) != 0) {
         return true;
       }
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
index 85cd02469adb7..61f483ced3217 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
@@ -44,12 +44,16 @@ public int hashInt(int input) {
     return fmix(h1, 4);
   }
 
-  public int hashUnsafeWords(Object baseObject, long baseOffset, int lengthInBytes) {
+  public int hashUnsafeWords(Object base, long offset, int lengthInBytes) {
+    return hashUnsafeWords(base, offset, lengthInBytes, seed);
+  }
+
+  public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) {
     // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method.
     assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)";
     int h1 = seed;
-    for (int offset = 0; offset < lengthInBytes; offset += 4) {
-      int halfWord = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset);
+    for (int i = 0; i < lengthInBytes; i += 4) {
+      int halfWord = PlatformDependent.UNSAFE.getInt(base, offset + i);
       int k1 = mixK1(halfWord);
       h1 = mixH1(h1, k1);
     }
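Illustrative sketch (not part of the patch) of the new static entry point, which UnsafeRow.hashCode calls with seed 42; the byte values are arbitrary, and the length must stay a multiple of 8 per the assert above:

```scala
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.hash.Murmur3_x86_32

val bytes = Array.tabulate[Byte](16)(_.toByte)  // two 8-byte words
val h1 = Murmur3_x86_32.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, bytes.length, 42)
val h2 = Murmur3_x86_32.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, bytes.length, 42)
assert(h1 == h2)  // identical bytes and seed always hash to the same value
```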