From 2635c4146bfa82396829c7f6d24cbda2ba6c9f53 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 1 Mar 2016 08:43:02 -0800 Subject: [PATCH] [SPARK-13511] [SQL] Add wholestage codegen for limit JIRA: https://issues.apache.org/jira/browse/SPARK-13511 ## What changes were proposed in this pull request? Current limit operator doesn't support wholestage codegen. This is open to add support for it. In the `doConsume` of `GlobalLimit` and `LocalLimit`, we use a count term to count the processed rows. Once the row numbers catches the limit number, we set the variable `stopEarly` of `BufferedRowIterator` newly added in this pr to `true` that indicates we want to stop processing remaining rows. Then when the wholestage codegen framework checks `shouldStop()`, it will stop the processing of the row iterator. Before this, the executed plan for a query `sqlContext.range(N).limit(100).groupBy().sum()` is: TungstenAggregate(key=[], functions=[(sum(id#5L),mode=Final,isDistinct=false)], output=[sum(id)#6L]) +- TungstenAggregate(key=[], functions=[(sum(id#5L),mode=Partial,isDistinct=false)], output=[sum#9L]) +- GlobalLimit 100 +- Exchange SinglePartition, None +- LocalLimit 100 +- Range 0, 1, 1, 524288000, [id#5L] After add wholestage codegen support: WholeStageCodegen : +- TungstenAggregate(key=[], functions=[(sum(id#40L),mode=Final,isDistinct=false)], output=[sum(id)#41L]) : +- TungstenAggregate(key=[], functions=[(sum(id#40L),mode=Partial,isDistinct=false)], output=[sum#44L]) : +- GlobalLimit 100 : +- INPUT +- Exchange SinglePartition, None +- WholeStageCodegen : +- LocalLimit 100 : +- Range 0, 1, 1, 524288000, [id#40L] ## How was this patch tested? A test is added into BenchmarkWholeStageCodegen. Author: Liang-Chi Hsieh Closes #11391 from viirya/wholestage-limit. --- .../apache/spark/sql/execution/limit.scala | 35 +++++++++++++++++-- .../BenchmarkWholeStageCodegen.scala | 14 ++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index cd543d4195286..45175d36d5c9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -21,9 +21,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -48,7 +49,7 @@ case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode { /** * Helper trait which defines methods that are shared by both [[LocalLimit]] and [[GlobalLimit]]. */ -trait BaseLimit extends UnaryNode { +trait BaseLimit extends UnaryNode with CodegenSupport { val limit: Int override def output: Seq[Attribute] = child.output override def outputOrdering: Seq[SortOrder] = child.outputOrdering @@ -56,6 +57,36 @@ trait BaseLimit extends UnaryNode { protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.take(limit) } + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val stopEarly = ctx.freshName("stopEarly") + ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;") + + ctx.addNewFunction("shouldStop", s""" + @Override + protected boolean shouldStop() { + return !currentRows.isEmpty() || $stopEarly; + } + """) + val countTerm = ctx.freshName("count") + ctx.addMutableState("int", countTerm, s"$countTerm = 0;") + s""" + | if ($countTerm < $limit) { + | $countTerm += 1; + | ${consume(ctx, input)} + | } else { + | $stopEarly = true; + | } + """.stripMargin + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 6d6cc0186a962..2d3e34d0e1292 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -70,6 +70,20 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { */ } + ignore("range/limit/sum") { + val N = 500 << 20 + runBenchmark("range/limit/sum", N) { + sqlContext.range(N).limit(1000000).groupBy().sum().collect() + } + /* + Westmere E56xx/L56xx/X56xx (Nehalem-C) + range/limit/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + range/limit/sum codegen=false 609 / 672 861.6 1.2 1.0X + range/limit/sum codegen=true 561 / 621 935.3 1.1 1.1X + */ + } + ignore("stat functions") { val N = 100 << 20