From a05b4f6e0eb6cf9d983120f470571cb8fdefa3ee Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Tue, 21 Jul 2015 13:55:23 -0700
Subject: [PATCH] support UnsafeRow in LeftSemiJoinBNL and BroadcastNestedLoopJoin

---
 .../expressions/UnsafeRowConverter.scala      |  4 +--
 .../joins/BroadcastNestedLoopJoin.scala       | 35 ++++++++++++-------
 .../sql/execution/joins/LeftSemiJoinBNL.scala |  3 ++
 .../sql/execution/rowFormatConverters.scala   | 18 +++++++---
 4 files changed, 41 insertions(+), 19 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
index 702deb04acb67..885ab091fcdf5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -111,7 +111,7 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
 /**
  * Function for writing a column into an UnsafeRow.
  */
-abstract class UnsafeColumnWriter {
+private abstract class UnsafeColumnWriter {
   /**
    * Write a value into an UnsafeRow.
    *
@@ -130,7 +130,7 @@ abstract class UnsafeColumnWriter {
   def getSize(source: InternalRow, column: Int): Int
 }
 
-object UnsafeColumnWriter {
+private object UnsafeColumnWriter {
 
   def forType(dataType: DataType): UnsafeColumnWriter = {
     dataType match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index 60b4266fad8b1..7be83ddfc17a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -44,6 +44,17 @@ case class BroadcastNestedLoopJoin(
     case BuildLeft => (right, left)
   }
 
+  override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows
+  override def canProcessUnsafeRows: Boolean = true
+
+  @transient private[this] lazy val resultProjection: Projection = {
+    if (outputsUnsafeRows) {
+      UnsafeProjection.create(schema)
+    } else {
+      ((r: InternalRow) => r).asInstanceOf[Projection]
+    }
+  }
+
   override def outputPartitioning: Partitioning = streamed.outputPartitioning
 
   override def output: Seq[Attribute] = {
@@ -74,6 +85,7 @@ case class BroadcastNestedLoopJoin(
       val includedBroadcastTuples =
         new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
       val joinedRow = new JoinedRow
+
       val leftNulls = new GenericMutableRow(left.output.size)
       val rightNulls = new GenericMutableRow(right.output.size)
 
@@ -86,11 +98,11 @@ case class BroadcastNestedLoopJoin(
           val broadcastedRow = broadcastedRelation.value(i)
           buildSide match {
             case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
-              matchedRows += joinedRow(streamedRow, broadcastedRow).copy()
+              matchedRows += resultProjection(joinedRow(streamedRow, broadcastedRow)).copy()
               streamRowMatched = true
               includedBroadcastTuples += i
             case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) =>
-              matchedRows += joinedRow(broadcastedRow, streamedRow).copy()
+              matchedRows += resultProjection(joinedRow(broadcastedRow, streamedRow)).copy()
               streamRowMatched = true
               includedBroadcastTuples += i
             case _ =>
@@ -100,9 +112,9 @@ case class BroadcastNestedLoopJoin(
 
         (streamRowMatched, joinType, buildSide) match {
           case (false, LeftOuter | FullOuter, BuildRight) =>
-            matchedRows += joinedRow(streamedRow, rightNulls).copy()
+            matchedRows += resultProjection(joinedRow(streamedRow, rightNulls)).copy()
           case (false, RightOuter | FullOuter, BuildLeft) =>
-            matchedRows += joinedRow(leftNulls, streamedRow).copy()
+            matchedRows += resultProjection(joinedRow(leftNulls, streamedRow)).copy()
           case _ =>
         }
       }
@@ -110,12 +122,9 @@ case class BroadcastNestedLoopJoin(
     }
 
     val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2)
-    val allIncludedBroadcastTuples =
-      if (includedBroadcastTuples.count == 0) {
-        new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
-      } else {
-        includedBroadcastTuples.reduce(_ ++ _)
-      }
+    val allIncludedBroadcastTuples = includedBroadcastTuples.fold(
+      new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
+    )(_ ++ _)
 
     val leftNulls = new GenericMutableRow(left.output.size)
     val rightNulls = new GenericMutableRow(right.output.size)
@@ -127,8 +136,10 @@ case class BroadcastNestedLoopJoin(
       while (i < rel.length) {
         if (!allIncludedBroadcastTuples.contains(i)) {
           (joinType, buildSide) match {
-            case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i))
-            case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls)
+            case (RightOuter | FullOuter, BuildRight) =>
+              buf += resultProjection(new JoinedRow(leftNulls, rel(i)))
+            case (LeftOuter | FullOuter, BuildLeft) =>
+              buf += resultProjection(new JoinedRow(rel(i), rightNulls))
             case _ =>
           }
         }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
index db5be9f453674..4443455ef11fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
@@ -39,6 +39,9 @@ case class LeftSemiJoinBNL(
 
   override def output: Seq[Attribute] = left.output
 
+  override def outputsUnsafeRows: Boolean = streamed.outputsUnsafeRows
+  override def canProcessUnsafeRows: Boolean = true
+
   /** The Streamed Relation */
   override def left: SparkPlan = streamed
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
index 14a9f0523ae0d..29f3beb3cb3c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
@@ -96,11 +96,19 @@ private[sql] object EnsureRowFormats extends Rule[SparkPlan] {
       }
     case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) =>
       if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) {
-        // If this operator's children produce both unsafe and safe rows, then convert everything
-        // to unsafe rows
-        operator.withNewChildren {
-          operator.children.map {
-            c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c
+        // If this operator's children produce both unsafe and safe rows,
+        // convert everything to unsafe rows if all of their schemas are supported by UnsafeRow
+        if (operator.children.forall(c => UnsafeProjection.canSupport(c.schema))) {
+          operator.withNewChildren {
+            operator.children.map {
+              c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c
+            }
+          }
+        } else {
+          operator.withNewChildren {
+            operator.children.map {
+              c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c
+            }
           }
         }
       } else {
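
Note: the core idiom in the BroadcastNestedLoopJoin change is choosing a single output Projection up front — a real UnsafeProjection when the join emits UnsafeRows, an identity function otherwise — so the per-row join loops never branch on row format. The following minimal, self-contained sketch mirrors that pattern; Row, Projection, ConvertingProjection, and NestedLoopJoinSketch are simplified stand-ins invented for illustration, not Spark's actual types.

// Simplified stand-ins for Spark's InternalRow and Projection; these toy
// types are illustrative assumptions, not Spark's API.
final case class Row(values: Seq[Any])

trait Projection extends (Row => Row)

// Stands in for UnsafeProjection.create(schema); in Spark this would
// serialize the row into the compact UnsafeRow binary format.
final class ConvertingProjection extends Projection {
  def apply(row: Row): Row = Row(row.values)
}

// Identity projection, used when downstream operators expect safe rows.
object IdentityProjection extends Projection {
  def apply(row: Row): Row = row
}

final class NestedLoopJoinSketch(outputsUnsafeRows: Boolean) {
  // Mirrors the patch: pick the projection once, lazily, so the per-row
  // join loops never test the row format.
  private lazy val resultProjection: Projection =
    if (outputsUnsafeRows) new ConvertingProjection else IdentityProjection

  // In Spark the result is .copy()-ed as well, because JoinedRow buffers are reused.
  def emit(joined: Row): Row = resultProjection(joined)
}

Hoisting the format decision out of the loop is the same trick the patch plays with ((r: InternalRow) => r).asInstanceOf[Projection] as the identity case.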
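Note: the rowFormatConverters.scala change makes EnsureRowFormats prefer converting mixed children to unsafe rows, falling back to safe rows when some child schema cannot be represented as an UnsafeRow. A rough sketch of that decision follows, using toy Plan types as assumed stand-ins rather than Spark's SparkPlan API (schemaSupported stands in for UnsafeProjection.canSupport(schema)).

object EnsureRowFormatsSketch {
  sealed trait Plan {
    def outputsUnsafeRows: Boolean
    def schemaSupported: Boolean // stand-in for UnsafeProjection.canSupport(schema)
  }
  final case class Leaf(outputsUnsafeRows: Boolean, schemaSupported: Boolean) extends Plan
  final case class ConvertToUnsafe(child: Plan) extends Plan {
    def outputsUnsafeRows: Boolean = true
    def schemaSupported: Boolean = child.schemaSupported
  }
  final case class ConvertToSafe(child: Plan) extends Plan {
    def outputsUnsafeRows: Boolean = false
    def schemaSupported: Boolean = child.schemaSupported
  }

  def ensureConsistentFormat(children: Seq[Plan]): Seq[Plan] =
    if (children.map(_.outputsUnsafeRows).toSet.size != 1) {
      if (children.forall(_.schemaSupported)) {
        // Mixed formats and every schema is UnsafeRow-compatible: go all-unsafe.
        children.map(c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c)
      } else {
        // Some schema cannot be an UnsafeRow: fall back to all-safe.
        children.map(c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c)
      }
    } else {
      children // already consistent; nothing to do
    }
}

For example, ensureConsistentFormat(Seq(Leaf(true, true), Leaf(false, true))) wraps only the second child in ConvertToUnsafe, while a single unsupported schema such as Leaf(false, false) in the mix flips every unsafe child to ConvertToSafe instead.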