From 87d890cc105a7f41478433b28f53c9aa431db211 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 21 Jul 2015 11:18:39 -0700 Subject: [PATCH 01/14] Revert "[SPARK-9154] [SQL] codegen StringFormat" This reverts commit 7f072c3d5ec50c65d76bd9f28fac124fce96a89e. Revert #7546 Author: Michael Armbrust Closes #7570 from marmbrus/revert9154 and squashes the following commits: ed2c32a [Michael Armbrust] Revert "[SPARK-9154] [SQL] codegen StringFormat" --- .../expressions/stringOperations.scala | 42 +------------------ .../expressions/StringExpressionsSuite.scala | 18 ++++---- .../spark/sql/StringFunctionsSuite.scala | 10 ----- 3 files changed, 11 insertions(+), 59 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 280ae0e546358..fe57d17f1ec14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -526,7 +526,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) /** * Returns the input formatted according do printf-style format strings */ -case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes { +case class StringFormat(children: Expression*) extends Expression with CodegenFallback { require(children.nonEmpty, "printf() should take at least 1 argument") @@ -536,10 +536,6 @@ case class StringFormat(children: Expression*) extends Expression with ImplicitC private def format: Expression = children(0) private def args: Seq[Expression] = children.tail - override def inputTypes: Seq[AbstractDataType] = - children.zipWithIndex.map(x => if (x._2 == 0) StringType else AnyDataType) - - override def eval(input: InternalRow): Any = { val pattern = format.eval(input) if (pattern == null) { @@ -555,42 +551,6 @@ case class StringFormat(children: Expression*) extends Expression with ImplicitC } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val pattern = children.head.gen(ctx) - - val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx))) - val argListCode = argListGen.map(_._2.code + "\n") - - val argListString = argListGen.foldLeft("")((s, v) => { - val nullSafeString = - if (ctx.boxedType(v._1) != ctx.javaType(v._1)) { - // Java primitives get boxed in order to allow null values. - s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " + - s"new ${ctx.boxedType(v._1)}(${v._2.primitive})" - } else { - s"(${v._2.isNull}) ? 
null : ${v._2.primitive}" - } - s + "," + nullSafeString - }) - - val form = ctx.freshName("formatter") - val formatter = classOf[java.util.Formatter].getName - val sb = ctx.freshName("sb") - val stringBuffer = classOf[StringBuffer].getName - s""" - ${pattern.code} - boolean ${ev.isNull} = ${pattern.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${argListCode.mkString} - $stringBuffer $sb = new $stringBuffer(); - $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US); - $form.format(${pattern.primitive}.toString() $argListString); - ${ev.primitive} = UTF8String.fromString($sb.toString()); - } - """ - } - override def prettyName: String = "printf" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 3c2d88731beb4..96c540ab36f08 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -351,16 +351,18 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("FORMAT") { - checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") + val f = 'f.string.at(0) + val d1 = 'd.int.at(1) + val s1 = 's.int.at(2) + + val row1 = create_row("aa%d%s", 12, "cc") + val row2 = create_row(null, 12, "cc") + checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null)) - checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") - checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc") + checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) - checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null) - checkEvaluation( - StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc") - checkEvaluation( - StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null") + checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1) + checkEvaluation(StringFormat(f, d1, s1), null, row2) } test("INSTR") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 3702e73b4e74f..d1f855903ca4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -132,16 +132,6 @@ class StringFunctionsSuite extends QueryTest { checkAnswer( df.selectExpr("printf(a, b, c)"), Row("aa123cc")) - - val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c") - - checkAnswer( - df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")), - Row("aa123cc", "aa123cc")) - - checkAnswer( - df2.selectExpr("printf(a, b, c)"), - Row("aa123cc")) } test("string instr function") { From 9ba7c64decfc92853bd281e9e7bfb95211080dd4 Mon Sep 17 00:00:00 2001 From: "navis.ryu" Date: Tue, 21 Jul 2015 11:52:52 -0700 Subject: [PATCH 02/14] [SPARK-8357] Fix unsafe memory leak on empty inputs in GeneratedAggregate This patch fixes a managed memory leak in GeneratedAggregate. 
The leak occurs when the unsafe aggregation path is used to perform grouped aggregation on an empty input; in this case, GeneratedAggregate allocates an UnsafeFixedWidthAggregationMap that is never cleaned up because `next()` is never called on the aggregate result iterator. This patch fixes this by short-circuiting on empty inputs. This patch is an updated version of #6810. Closes #6810. Author: navis.ryu Author: Josh Rosen Closes #7560 from JoshRosen/SPARK-8357 and squashes the following commits: 3486ce4 [Josh Rosen] Some minor cleanup c649310 [Josh Rosen] Revert SparkPlan change: 3c7db0f [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-8357 adc8239 [Josh Rosen] Back out Projection changes. c5419b3 [navis.ryu] addressed comments 143e1ef [navis.ryu] fixed format & added test for CCE case 735972f [navis.ryu] used new conf apis 1a02a55 [navis.ryu] Rolled-back test-conf cleanup & fixed possible CCE & added more tests 51178e8 [navis.ryu] addressed comments 4d326b9 [navis.ryu] fixed test fails 15c5afc [navis.ryu] added a test as suggested by JoshRosen d396589 [navis.ryu] added comments 1b07556 [navis.ryu] [SPARK-8357] [SQL] Memory leakage on unsafe aggregation path with empty input --- .../sql/execution/GeneratedAggregate.scala | 14 +++++- .../org/apache/spark/sql/SQLQuerySuite.scala | 9 ++++ .../spark/sql/execution/AggregateSuite.scala | 48 +++++++++++++++++++ 3 files changed, 70 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index c069da016f9f0..ecde9c57139a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -266,7 +266,18 @@ case class GeneratedAggregate( val joinedRow = new JoinedRow3 - if (groupingExpressions.isEmpty) { + if (!iter.hasNext) { + // This is an empty input, so return early so that we do not allocate data structures + // that won't be cleaned up (see SPARK-8357). + if (groupingExpressions.isEmpty) { + // This is a global aggregate, so return an empty aggregation buffer. + val resultProjection = resultProjectionBuilder() + Iterator(resultProjection(newAggregationBuffer(EmptyRow))) + } else { + // This is a grouped aggregate, so return an empty iterator. + Iterator[InternalRow]() + } + } else if (groupingExpressions.isEmpty) { // TODO: Codegening anything other than the updateProjection is probably over kill. 
val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] var currentRow: InternalRow = null @@ -280,6 +291,7 @@ case class GeneratedAggregate( val resultProjection = resultProjectionBuilder() Iterator(resultProjection(buffer)) } else if (unsafeEnabled) { + assert(iter.hasNext, "There should be at least one row for this path") log.info("Using Unsafe-based aggregator") val aggregationMap = new UnsafeFixedWidthAggregationMap( newAggregationBuffer, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 61d5f2061ae18..beee10173fbc4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -648,6 +648,15 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(2, 1, 2, 2, 1)) } + test("count of empty table") { + withTempTable("t") { + Seq.empty[(Int, Int)].toDF("a", "b").registerTempTable("t") + checkAnswer( + sql("select count(a) from t"), + Row(0)) + } + } + test("inner join where, one match per row") { checkAnswer( sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala new file mode 100644 index 0000000000000..20def6bef0c17 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.test.TestSQLContext + +class AggregateSuite extends SparkPlanTest { + + test("SPARK-8357 unsafe aggregation path should not leak memory with empty input") { + val codegenDefault = TestSQLContext.getConf(SQLConf.CODEGEN_ENABLED) + val unsafeDefault = TestSQLContext.getConf(SQLConf.UNSAFE_ENABLED) + try { + TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, true) + TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, true) + val df = Seq.empty[(Int, Int)].toDF("a", "b") + checkAnswer( + df, + GeneratedAggregate( + partial = true, + Seq(df.col("b").expr), + Seq(Alias(Count(df.col("a").expr), "cnt")()), + unsafeEnabled = true, + _: SparkPlan), + Seq.empty + ) + } finally { + TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) + TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, unsafeDefault) + } + } +} From 60c0ce134d90ef18852ed2c637d2f240b7f99ab9 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 21 Jul 2015 11:56:38 -0700 Subject: [PATCH 03/14] [SPARK-8906][SQL] Move all internal data source classes into execution.datasources. This way, the sources package contains only public facing interfaces. Author: Reynold Xin Closes #7565 from rxin/move-ds and squashes the following commits: 7661aff [Reynold Xin] Mima 9d5196a [Reynold Xin] Rearranged imports. 3dd7174 [Reynold Xin] [SPARK-8906][SQL] Move all internal data source classes into execution.datasources. --- project/MimaExcludes.scala | 47 +++++++++++++++++++ .../org/apache/spark/sql/DataFrame.scala | 2 +- .../apache/spark/sql/DataFrameReader.scala | 4 +- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../org/apache/spark/sql/SQLContext.scala | 9 ++-- .../spark/sql/execution/SparkStrategies.scala | 4 +- .../SqlNewHadoopRDD.scala | 9 ++-- .../datasources}/DataSourceStrategy.scala | 11 ++--- .../datasources}/LogicalRelation.scala | 7 +-- .../datasources}/PartitioningUtils.scala | 5 +- .../datasources}/commands.scala | 5 +- .../datasources}/ddl.scala | 9 ++-- .../datasources}/rules.scala | 10 ++-- .../apache/spark/sql/parquet/newParquet.scala | 5 +- .../apache/spark/sql/sources/filters.scala | 4 ++ .../apache/spark/sql/sources/interfaces.scala | 4 +- .../org/apache/spark/sql/json/JsonSuite.scala | 2 +- .../sql/parquet/ParquetFilterSuite.scala | 2 +- .../ParquetPartitionDiscoverySuite.scala | 4 +- .../sources/CreateTableAsSelectSuite.scala | 1 + .../sql/sources/ResolvedDataSourceSuite.scala | 1 + .../apache/spark/sql/hive/HiveContext.scala | 6 +-- .../spark/sql/hive/HiveMetastoreCatalog.scala | 11 +++-- .../org/apache/spark/sql/hive/HiveQl.scala | 2 +- .../spark/sql/hive/HiveStrategies.scala | 2 +- .../spark/sql/hive/execution/commands.scala | 1 + .../spark/sql/hive/orc/OrcRelation.scala | 1 + .../sql/hive/MetastoreDataSourcesSuite.scala | 2 +- .../hive/execution/HiveComparisonTest.scala | 4 +- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- .../apache/spark/sql/hive/parquetSuites.scala | 2 +- .../sql/sources/hadoopFsRelationSuites.scala | 6 ++- 32 files changed, 124 insertions(+), 62 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/{sources => execution}/SqlNewHadoopRDD.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{sources => execution/datasources}/DataSourceStrategy.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{sources => execution/datasources}/LogicalRelation.scala (88%) rename 
sql/core/src/main/scala/org/apache/spark/sql/{sources => execution/datasources}/PartitioningUtils.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{sources => execution/datasources}/commands.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{sources => execution/datasources}/ddl.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{sources => execution/datasources}/rules.scala (94%) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a2595ff6c22f4..fa36629c37a35 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -104,6 +104,53 @@ object MimaExcludes { // SPARK-7422 add argmax for sparse vectors ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.linalg.Vector.argmax") + ) ++ Seq( + // SPARK-8906 Move all internal data source classes into execution.datasources + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopPartition"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DefaultWriterContainer"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DynamicPartitionWriterContainer"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.BaseWriterContainer"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable"), + 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CaseInsensitiveMap"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException") ) case v if v.startsWith("1.4") => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 830fba35bb7bc..323ff17357fda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -38,8 +38,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} +import org.apache.spark.sql.execution.datasources.CreateTableUsingAsSelect import org.apache.spark.sql.json.JacksonGenerator -import org.apache.spark.sql.sources.CreateTableUsingAsSelect import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index f1c1ddf898986..e9d782cdcd667 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -20,16 +20,16 @@ package org.apache.spark.sql import java.util.Properties import org.apache.hadoop.fs.Path -import org.apache.spark.{Logging, Partition} +import org.apache.spark.{Logging, Partition} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD +import 
org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} import org.apache.spark.sql.json.JSONRelation import org.apache.spark.sql.parquet.ParquetRelation2 -import org.apache.spark.sql.sources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.types.StructType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 3e7b9cd7976c3..ee0201a9d4cb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -22,8 +22,8 @@ import java.util.Properties import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable +import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource} import org.apache.spark.sql.jdbc.{JDBCWriteDetails, JdbcUtils} -import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect} /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 2dda3ad1211fa..8b4528b5d52fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -39,8 +39,9 @@ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} -import org.apache.spark.sql.execution.{Filter, _} -import org.apache.spark.sql.sources._ +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -146,11 +147,11 @@ class SQLContext(@transient val sparkContext: SparkContext) new Analyzer(catalog, functionRegistry, conf) { override val extendedResolutionRules = ExtractPythonUDFs :: - sources.PreInsertCastAndRename :: + PreInsertCastAndRename :: Nil override val extendedCheckRules = Seq( - sources.PreWriteCheck(catalog) + datasources.PreWriteCheck(catalog) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 240332a80af0f..8cef7f200d2dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.{SQLContext, Strategy, execution} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ @@ -25,10 +26,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} +import 
org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} import org.apache.spark.sql.parquet._ -import org.apache.spark.sql.sources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} import org.apache.spark.sql.types._ -import org.apache.spark.sql.{SQLContext, Strategy, execution} private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala index 2bdc341021256..e1c1a6c06268f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala @@ -15,24 +15,23 @@ * limitations under the License. */ -package org.apache.spark.sql.sources +package org.apache.spark.sql.execution import java.text.SimpleDateFormat import java.util.Date +import org.apache.spark.{Partition => SparkPartition, _} import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} -import org.apache.spark.broadcast.Broadcast - -import org.apache.spark.{Partition => SparkPartition, _} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil -import org.apache.spark.rdd.{RDD, HadoopRDD} import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD +import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 70c9e06927582..2b400926177fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -15,22 +15,21 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.sources +package org.apache.spark.sql.execution.datasources import org.apache.spark.{Logging, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.{InternalRow, expressions} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.sql.{SaveMode, Strategy, execution, sources} -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.sql.{SaveMode, Strategy, execution, sources, _} import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.{SerializableConfiguration, Utils} /** * A Strategy for planning scans over data sources defined using the sources API. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala similarity index 88% rename from sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index f374abffdd505..a7123dc845fa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -14,11 +14,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.sources +package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeMap} -import org.apache.spark.sql.catalyst.plans.logical.{Statistics, LeafNode, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.sources.BaseRelation /** * Used to link a [[BaseRelation]] in to a logical query plan. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 8b2a45d8e970a..6b4a359db22d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -15,9 +15,9 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.sources +package org.apache.spark.sql.execution.datasources -import java.lang.{Double => JDouble, Float => JFloat, Integer => JInteger, Long => JLong} +import java.lang.{Double => JDouble, Long => JLong} import java.math.{BigDecimal => JBigDecimal} import scala.collection.mutable.ArrayBuffer @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types._ + private[sql] case class Partition(values: InternalRow, path: String) private[sql] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala index 5c6ef2dc90c73..84a0441e145c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources +package org.apache.spark.sql.execution.datasources import java.util.{Date, UUID} @@ -24,7 +24,6 @@ import scala.collection.JavaConversions.asScalaIterator import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat} - import org.apache.spark._ import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil @@ -35,9 +34,11 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StringType import org.apache.spark.util.SerializableConfiguration + private[sql] case class InsertIntoDataSource( logicalRelation: LogicalRelation, query: LogicalPlan, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 5a8c97c773ee6..c8033d3c0470a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -15,23 +15,22 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.sources +package org.apache.spark.sql.execution.datasources import scala.language.{existentials, implicitConversions} import scala.util.matching.Regex import org.apache.hadoop.fs.Path - import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.catalyst.AbstractSparkSQLParser +import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext, SaveMode} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, InternalRow} import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext, SaveMode} import org.apache.spark.util.Utils /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala similarity index 94% rename from sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 40ee048e2653e..11bb49b8d83de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -15,15 +15,15 @@ * limitations under the License. */ -package org.apache.spark.sql.sources +package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.{SaveMode, AnalysisException} -import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, Catalog} -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Alias} +import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.catalyst.analysis.{Catalog, EliminateSubQueries} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation, InsertableRelation} /** * A rule to do pre-insert data type casting and field renaming. 
Before we insert into diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index e683eb0126004..2f9f880c70690 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -35,15 +35,18 @@ import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType +import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.{SqlNewHadoopPartition, SqlNewHadoopRDD} +import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} + private[sql] class DefaultSource extends HadoopFsRelationProvider { override def createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 24e86ca415c51..4d942e4f9287a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.sources +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines all the filters that we can push down to the data sources. +//////////////////////////////////////////////////////////////////////////////////////////////////// + /** * A filter predicate for data sources. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 2cd8b358d81c6..7cd005b959488 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.sources import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer import scala.util.Try import org.apache.hadoop.conf.Configuration @@ -33,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.execution.RDDConversions +import org.apache.spark.sql.execution.datasources.{PartitioningUtils, PartitionSpec, Partition} import org.apache.spark.sql.types.StructType import org.apache.spark.sql._ import org.apache.spark.util.SerializableConfiguration @@ -523,7 +523,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio }) } - private[sources] final def buildScan( + private[sql] final def buildScan( requiredColumns: Array[String], filters: Array[Filter], inputPaths: Array[String], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 3475f9dd6787e..1d04513a44672 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -26,8 +26,8 @@ import org.scalactic.Tolerance._ import org.apache.spark.sql.{QueryTest, Row, SQLConf} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.json.InferSchema.compatibleType -import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index a2763c78b6450..23df102cd951d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -24,7 +24,7 @@ import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.sources.LogicalRelation +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index 37b0a9fbf7a4e..4f98776b91160 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -28,11 +28,11 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal -import 
org.apache.spark.sql.sources.PartitioningUtils._ -import org.apache.spark.sql.sources.{LogicalRelation, Partition, PartitionSpec} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionSpec, Partition, PartitioningUtils} import org.apache.spark.sql.types._ import org.apache.spark.sql._ import org.apache.spark.unsafe.types.UTF8String +import PartitioningUtils._ // The data where the partitioning key exists only in the directory structure. case class ParquetData(intField: Int, stringField: String) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index a71088430bfd5..1907e643c85dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -22,6 +22,7 @@ import java.io.{File, IOException} import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.datasources.DDLException import org.apache.spark.util.Utils class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 296b0d6f74a0c..3cbf5467b253a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.datasources.ResolvedDataSource class ResolvedDataSourceSuite extends SparkFunSuite { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 4684d48aff889..cec7685bb6859 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -44,9 +44,9 @@ import org.apache.spark.sql.catalyst.ParserDialect import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand} +import org.apache.spark.sql.execution.datasources.{PreWriteCheck, PreInsertCastAndRename, DataSourceStrategy} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} -import org.apache.spark.sql.sources.DataSourceStrategy import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -384,11 +384,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { catalog.PreInsertionCasts :: ExtractPythonUDFs :: ResolveHiveWindowFunction :: - sources.PreInsertCastAndRename :: + PreInsertCastAndRename :: Nil override val extendedCheckRules = Seq( - sources.PreWriteCheck(catalog) + PreWriteCheck(catalog) ) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index b15261b7914dd..0a2121c955871 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ 
-17,6 +17,8 @@ package org.apache.spark.sql.hive +import scala.collection.JavaConversions._ + import com.google.common.base.Objects import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} @@ -28,6 +30,7 @@ import org.apache.hadoop.hive.ql.metadata._ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.Logging +import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ @@ -35,14 +38,12 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.execution.datasources.{Partition => ParquetPartition, PartitionSpec, CreateTableUsingAsSelect, ResolvedDataSource, LogicalRelation} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.parquet.ParquetRelation2 -import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode, sources} -/* Implicit conversions */ -import scala.collection.JavaConversions._ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: HiveContext) extends Catalog with Logging { @@ -278,7 +279,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive parquetRelation.paths.toSet == pathsInMetastore.toSet && logical.schema.sameType(metastoreSchema) && parquetRelation.partitionSpec == partitionSpecInMetastore.getOrElse { - PartitionSpec(StructType(Nil), Array.empty[sources.Partition]) + PartitionSpec(StructType(Nil), Array.empty[datasources.Partition]) } if (useCached) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 7fc517b646b20..f5574509b0b38 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.execution.ExplainCommand -import org.apache.spark.sql.sources.DescribeCommand +import org.apache.spark.sql.execution.datasources.DescribeCommand import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 9638a8201e190..a22c3292eff94 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -30,9 +30,9 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand, _} +import 
org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand} import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.sources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand} import org.apache.spark.sql.types.StringType diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 71fa3e9c33ad9..a47f9a4feb21b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 48d35a60a759b..de63ee56dd8e6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -37,6 +37,7 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreTypes, HiveShim} import org.apache.spark.sql.sources.{Filter, _} import org.apache.spark.sql.types.StructType diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index d910af22c3dd1..e403f32efaf91 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -28,12 +28,12 @@ import org.apache.hadoop.mapred.InvalidInputException import org.apache.spark.Logging import org.apache.spark.sql._ +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.parquet.ParquetRelation2 -import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index c9dd4c0935a72..efb04bf3d5097 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -22,11 +22,11 @@ import java.io._ import org.scalatest.{BeforeAndAfterAll, GivenWhenThen} import 
org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.sources.DescribeCommand -import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} +import org.apache.spark.sql.execution.datasources.DescribeCommand import org.apache.spark.sql.hive.test.TestHive /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 05a1f0094e5e1..03428265422e6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -23,12 +23,12 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.errors.DialectException +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} import org.apache.spark.sql.parquet.ParquetRelation2 -import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ case class Nested1(f1: Nested2) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 9d79a4b007d66..82a8daf8b4b09 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -23,12 +23,12 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} +import org.apache.spark.sql.execution.datasources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.parquet.{ParquetRelation2, ParquetTableScan} -import org.apache.spark.sql.sources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index afecf9675e11f..1cef83fd5e990 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.sources -import scala.collection.JavaConversions._ - import java.io.File +import scala.collection.JavaConversions._ + import com.google.common.io.Files import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -31,10 +31,12 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.{SparkException, SparkFunSuite} 
import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ + abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { override lazy val sqlContext: SQLContext = TestHive From c07838b5a9cdf96c0f49055ea1c397e0f0e915d2 Mon Sep 17 00:00:00 2001 From: Dennis Huo Date: Tue, 21 Jul 2015 13:12:11 -0700 Subject: [PATCH 04/14] [SPARK-9206] [SQL] Fix HiveContext classloading for GCS connector. IsolatedClientLoader.isSharedClass includes all of com.google.\*, presumably for Guava, protobuf, and/or other shared Google libraries, but needs to count com.google.cloud.\* as "hive classes" when determining which ClassLoader to use. Otherwise, things like HiveContext.parquetFile will throw a ClassCastException when fs.defaultFS is set to a Google Cloud Storage (gs://) path. On StackOverflow: http://stackoverflow.com/questions/31478955 EDIT: Adding yhuai who worked on the relevant classloading isolation pieces. Author: Dennis Huo Closes #7549 from dennishuo/dhuo-fix-hivecontext-gcs and squashes the following commits: 1f8db07 [Dennis Huo] Fix HiveContext classloading for GCS connector. --- .../org/apache/spark/sql/hive/client/IsolatedClientLoader.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 3d609a66f3664..97fb98199991b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -125,7 +125,7 @@ private[hive] class IsolatedClientLoader( name.contains("log4j") || name.startsWith("org.apache.spark.") || name.startsWith("scala.") || - name.startsWith("com.google") || + (name.startsWith("com.google") && !name.startsWith("com.google.cloud")) || name.startsWith("java.lang.") || name.startsWith("java.net") || sharedPrefixes.exists(name.startsWith) From d4c7a7a3642a74ad40093c96c4bf45a62a470605 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Tue, 21 Jul 2015 15:47:40 -0700 Subject: [PATCH 05/14] [SPARK-9154] [SQL] codegen StringFormat Jira: https://issues.apache.org/jira/browse/SPARK-9154 Fixes the bug from #7546. marmbrus I can't reopen the other PR because I didn't close it. Can you trigger Jenkins?
Author: Tarek Auel Closes #7571 from tarekauel/SPARK-9154 and squashes the following commits: dcae272 [Tarek Auel] [SPARK-9154][SQL] build fix 1487602 [Tarek Auel] Merge remote-tracking branch 'upstream/master' into SPARK-9154 f512c5f [Tarek Auel] [SPARK-9154][SQL] build fix a943d3e [Tarek Auel] [SPARK-9154] implicit input cast, added tests for null, support for null primitives 10b4de8 [Tarek Auel] [SPARK-9154][SQL] codegen removed fallback trait cd8322b [Tarek Auel] [SPARK-9154][SQL] codegen string format 086caba [Tarek Auel] [SPARK-9154][SQL] codegen string format --- .../expressions/stringOperations.scala | 42 ++++++++++++++++++- .../expressions/StringExpressionsSuite.scala | 18 ++++---- .../org/apache/spark/sql/functions.scala | 11 +++++ .../spark/sql/StringFunctionsSuite.scala | 10 +++++ 4 files changed, 70 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index fe57d17f1ec14..1f18a6e9ff8a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -526,7 +526,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) /** * Returns the input formatted according do printf-style format strings */ -case class StringFormat(children: Expression*) extends Expression with CodegenFallback { +case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes { require(children.nonEmpty, "printf() should take at least 1 argument") @@ -536,6 +536,10 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa private def format: Expression = children(0) private def args: Seq[Expression] = children.tail + override def inputTypes: Seq[AbstractDataType] = + StringType :: List.fill(children.size - 1)(AnyDataType) + + override def eval(input: InternalRow): Any = { val pattern = format.eval(input) if (pattern == null) { @@ -551,6 +555,42 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val pattern = children.head.gen(ctx) + + val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx))) + val argListCode = argListGen.map(_._2.code + "\n") + + val argListString = argListGen.foldLeft("")((s, v) => { + val nullSafeString = + if (ctx.boxedType(v._1) != ctx.javaType(v._1)) { + // Java primitives get boxed in order to allow null values. + s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " + + s"new ${ctx.boxedType(v._1)}(${v._2.primitive})" + } else { + s"(${v._2.isNull}) ? 
null : ${v._2.primitive}" + } + s + "," + nullSafeString + }) + + val form = ctx.freshName("formatter") + val formatter = classOf[java.util.Formatter].getName + val sb = ctx.freshName("sb") + val stringBuffer = classOf[StringBuffer].getName + s""" + ${pattern.code} + boolean ${ev.isNull} = ${pattern.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${argListCode.mkString} + $stringBuffer $sb = new $stringBuffer(); + $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US); + $form.format(${pattern.primitive}.toString() $argListString); + ${ev.primitive} = UTF8String.fromString($sb.toString()); + } + """ + } + override def prettyName: String = "printf" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 96c540ab36f08..3c2d88731beb4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -351,18 +351,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("FORMAT") { - val f = 'f.string.at(0) - val d1 = 'd.int.at(1) - val s1 = 's.int.at(2) - - val row1 = create_row("aa%d%s", 12, "cc") - val row2 = create_row(null, 12, "cc") - checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) + checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null)) - checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) + checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") + checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc") - checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1) - checkEvaluation(StringFormat(f, d1, s1), null, row2) + checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null) + checkEvaluation( + StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc") + checkEvaluation( + StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null") } test("INSTR") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d94d7335828c5..e5ff8ae7e3179 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1741,6 +1741,17 @@ object functions { */ def rtrim(e: Column): Column = StringTrimRight(e.expr) + /** + * Format strings in printf-style. + * + * @group string_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def formatString(format: Column, arguments: Column*): Column = { + StringFormat((format +: arguments).map(_.expr): _*) + } + /** * Format strings in printf-style. * NOTE: `format` is the string value of the formatter, not column name. 
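To illustrate the two formatString overloads above (an editor's sketch, not
part of the patch; it assumes a SQLContext in scope with
`import sqlContext.implicits._`, mirroring the updated tests):

    import org.apache.spark.sql.functions._

    val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c")
    df.select(formatString($"a", $"b", $"c"))   // format read from a Column   => "aa123cc"
    df.select(formatString("aa%d%s", "b", "c")) // literal format, column names => "aa123cc"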
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index d1f855903ca4b..3702e73b4e74f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -132,6 +132,16 @@ class StringFunctionsSuite extends QueryTest { checkAnswer( df.selectExpr("printf(a, b, c)"), Row("aa123cc")) + + val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c") + + checkAnswer( + df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")), + Row("aa123cc", "aa123cc")) + + checkAnswer( + df2.selectExpr("printf(a, b, c)"), + Row("aa123cc")) } test("string instr function") { From a4c83cb1e4b066cd60264b6572fd3e51d160d26a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 21 Jul 2015 19:14:07 -0700 Subject: [PATCH 06/14] [SPARK-9154][SQL] Rename formatString to format_string. Also make format_string the canonical form, rather than printf. Author: Reynold Xin Closes #7579 from rxin/format_strings and squashes the following commits: 53ee54f [Reynold Xin] Fixed unit tests. 52357e1 [Reynold Xin] Add format_string alias. b40a42a [Reynold Xin] [SPARK-9154][SQL] Rename formatString to format_string. --- .../catalyst/analysis/FunctionRegistry.scala | 3 ++- .../expressions/stringOperations.scala | 13 +++++-------- .../expressions/StringExpressionsSuite.scala | 14 +++++++------- .../scala/org/apache/spark/sql/functions.scala | 18 +++--------------- .../spark/sql/StringFunctionsSuite.scala | 12 +----------- 5 files changed, 18 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e3d8d2adf2135..9c349838c28a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -168,7 +168,8 @@ object FunctionRegistry { expression[StringLocate]("locate"), expression[StringLPad]("lpad"), expression[StringTrimLeft]("ltrim"), - expression[StringFormat]("printf"), + expression[FormatString]("format_string"), + expression[FormatString]("printf"), expression[StringRPad]("rpad"), expression[StringRepeat]("repeat"), expression[StringReverse]("reverse"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 1f18a6e9ff8a5..cf187ad5a0a9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -526,29 +526,26 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) /** * Returns the input formatted according do printf-style format strings */ -case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes { +case class FormatString(children: Expression*) extends Expression with ImplicitCastInputTypes { - require(children.nonEmpty, "printf() should take at least 1 argument") + require(children.nonEmpty, "format_string() should take at least 1 argument") override def foldable: Boolean = children.forall(_.foldable) override def nullable: Boolean = 
children(0).nullable override def dataType: DataType = StringType - private def format: Expression = children(0) - private def args: Seq[Expression] = children.tail override def inputTypes: Seq[AbstractDataType] = StringType :: List.fill(children.size - 1)(AnyDataType) - override def eval(input: InternalRow): Any = { - val pattern = format.eval(input) + val pattern = children(0).eval(input) if (pattern == null) { null } else { val sb = new StringBuffer() val formatter = new java.util.Formatter(sb, Locale.US) - val arglist = args.map(_.eval(input).asInstanceOf[AnyRef]) + val arglist = children.tail.map(_.eval(input).asInstanceOf[AnyRef]) formatter.format(pattern.asInstanceOf[UTF8String].toString, arglist: _*) UTF8String.fromString(sb.toString) @@ -591,7 +588,7 @@ case class StringFormat(children: Expression*) extends Expression with ImplicitC """ } - override def prettyName: String = "printf" + override def prettyName: String = "format_string" } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 3c2d88731beb4..3d294fda5d103 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -351,16 +351,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("FORMAT") { - checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") - checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null)) - checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") - checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc") + checkEvaluation(FormatString(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") + checkEvaluation(FormatString(Literal("aa")), "aa", create_row(null)) + checkEvaluation(FormatString(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") + checkEvaluation(FormatString(Literal("aa%d%s"), 12, "cc"), "aa12cc") - checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null) + checkEvaluation(FormatString(Literal.create(null, StringType), 12, "cc"), null) checkEvaluation( - StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc") + FormatString(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc") checkEvaluation( - StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null") + FormatString(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null") } test("INSTR") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e5ff8ae7e3179..28159cbd5ab96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1742,26 +1742,14 @@ object functions { def rtrim(e: Column): Column = StringTrimRight(e.expr) /** - * Format strings in printf-style. + * Formats the arguments in printf-style and returns the result as a string column. * * @group string_funcs * @since 1.5.0 */ @scala.annotation.varargs - def formatString(format: Column, arguments: Column*): Column = { - StringFormat((format +: arguments).map(_.expr): _*) - } - - /** - * Format strings in printf-style. 
- * NOTE: `format` is the string value of the formatter, not column name. - * - * @group string_funcs - * @since 1.5.0 - */ - @scala.annotation.varargs - def formatString(format: String, arguNames: String*): Column = { - StringFormat(lit(format).expr +: arguNames.map(Column(_).expr): _*) + def format_string(format: String, arguments: Column*): Column = { + FormatString((lit(format) +: arguments).map(_.expr): _*) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 3702e73b4e74f..0f9c986f649a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -126,22 +126,12 @@ class StringFunctionsSuite extends QueryTest { val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c") checkAnswer( - df.select(formatString("aa%d%s", "b", "c")), + df.select(format_string("aa%d%s", $"b", $"c")), Row("aa123cc")) checkAnswer( df.selectExpr("printf(a, b, c)"), Row("aa123cc")) - - val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c") - - checkAnswer( - df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")), - Row("aa123cc", "aa123cc")) - - checkAnswer( - df2.selectExpr("printf(a, b, c)"), - Row("aa123cc")) } test("string instr function") { From 63f4bcc73f5a09c1790cc3c333f08b18609de6a4 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Tue, 21 Jul 2015 22:50:27 -0700 Subject: [PATCH 07/14] [SPARK-9121] [SPARKR] Get rid of the warnings about `no visible global function definition` in SparkR [[SPARK-9121] Get rid of the warnings about `no visible global function definition` in SparkR - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-9121) ## The Result of `dev/lint-r` [The result of lint-r for SPARK-9121 at the revision:1ddd0f2f1688560f88470e312b72af04364e2d49 when I have sent a PR](https://gist.github.com/yu-iskw/6f55953425901725edf6) Author: Yu ISHIKAWA Closes #7567 from yu-iskw/SPARK-9121 and squashes the following commits: c8cfd63 [Yu ISHIKAWA] Fix the typo b1f19ed [Yu ISHIKAWA] Add a validate statement for local SparkR 1a03987 [Yu ISHIKAWA] Load the `testthat` package in `dev/lint-r.R`, instead of using the full path of function. 3a5e0ab [Yu ISHIKAWA] [SPARK-9121][SparkR] Get rid of the warnings about `no visible global function definition` in SparkR --- dev/lint-r.R | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/dev/lint-r.R b/dev/lint-r.R index dcb1a184291e1..48bd6246096ae 100644 --- a/dev/lint-r.R +++ b/dev/lint-r.R @@ -15,15 +15,21 @@ # limitations under the License. # +argv <- commandArgs(TRUE) +SPARK_ROOT_DIR <- as.character(argv[1]) + # Installs lintr from Github. # NOTE: The CRAN's version is too old to adapt to our rules. if ("lintr" %in% row.names(installed.packages()) == FALSE) { devtools::install_github("jimhester/lintr") } -library(lintr) -argv <- commandArgs(TRUE) -SPARK_ROOT_DIR <- as.character(argv[1]) +library(lintr) +library(methods) +library(testthat) +if (! 
library(SparkR, lib.loc = file.path(SPARK_ROOT_DIR, "R", "lib"), logical.return = TRUE)) { + stop("You should install SparkR in a local directory with `R/install-dev.sh`.") +} path.to.package <- file.path(SPARK_ROOT_DIR, "R", "pkg") lint_package(path.to.package, cache = FALSE) From f4785f5b82c57bce41d3dc26ed9e3c9e794c7558 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 21 Jul 2015 23:00:13 -0700 Subject: [PATCH 08/14] [SPARK-9232] [SQL] Duplicate code in JSONRelation Author: Andrew Or Closes #7576 from andrewor14/clean-up-json-relation and squashes the following commits: ea80803 [Andrew Or] Clean up duplicate code --- .../apache/spark/sql/json/JSONRelation.scala | 50 ++++++++----------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 25802d054ac00..922794ac9aac5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.json import java.io.IOException -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.rdd.RDD import org.apache.spark.sql.AnalysisException @@ -87,20 +87,7 @@ private[sql] class DefaultSource case SaveMode.Append => sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}") case SaveMode.Overwrite => { - var success: Boolean = false - try { - success = fs.delete(filesystemPath, true) - } catch { - case e: IOException => - throw new IOException( - s"Unable to clear output directory ${filesystemPath.toString} prior" - + s" to writing to JSON table:\n${e.toString}") - } - if (!success) { - throw new IOException( - s"Unable to clear output directory ${filesystemPath.toString} prior" - + s" to writing to JSON table.") - } + JSONRelation.delete(filesystemPath, fs) true } case SaveMode.ErrorIfExists => @@ -195,20 +182,7 @@ private[sql] class JSONRelation( if (overwrite) { if (fs.exists(filesystemPath)) { - var success: Boolean = false - try { - success = fs.delete(filesystemPath, true) - } catch { - case e: IOException => - throw new IOException( - s"Unable to clear output directory ${filesystemPath.toString} prior" - + s" to writing to JSON table:\n${e.toString}") - } - if (!success) { - throw new IOException( - s"Unable to clear output directory ${filesystemPath.toString} prior" - + s" to writing to JSON table.") - } + JSONRelation.delete(filesystemPath, fs) } // Write the data. data.toJSON.saveAsTextFile(filesystemPath.toString) @@ -228,3 +202,21 @@ private[sql] class JSONRelation( case _ => false } } + +private object JSONRelation { + + /** Delete the specified directory to overwrite it with new JSON data. 
*/
+  def delete(dir: Path, fs: FileSystem): Unit = {
+    var success: Boolean = false
+    val failMessage = s"Unable to clear output directory $dir prior to writing to JSON table"
+    try {
+      success = fs.delete(dir, true /* recursive */)
+    } catch {
+      case e: IOException =>
+        throw new IOException(s"$failMessage\n${e.toString}")
+    }
+    if (!success) {
+      throw new IOException(failMessage)
+    }
+  }
+}

From c03299a18b4e076cabb4b7833a1e7632c5c0dabe Mon Sep 17 00:00:00 2001
From: Yin Huai
Date: Tue, 21 Jul 2015 23:26:11 -0700
Subject: [PATCH 09/14] [SPARK-4233] [SPARK-4367] [SPARK-3947] [SPARK-3056] [SQL] Aggregation Improvement

This is the first PR for the aggregation improvement, which is tracked by
https://issues.apache.org/jira/browse/SPARK-4366 (umbrella JIRA). This PR
contains work for its subtasks, SPARK-3056, SPARK-3947, SPARK-4233, and
SPARK-4367.

This PR introduces a new code path for evaluating aggregate functions. This
code path is guarded by `spark.sql.useAggregate2`, and the flag defaults to
true.

This new code path contains:
* A new aggregate function interface (`AggregateFunction2`) and 7 built-in
  aggregate functions based on this new interface (`AVG`, `COUNT`, `FIRST`,
  `LAST`, `MAX`, `MIN`, `SUM`)
* A UDAF interface (`UserDefinedAggregateFunction`) based on the new code
  path and two example UDAFs (`MyDoubleAvg` and `MyDoubleSum`).
* A sort-based aggregate operator (`Aggregate2Sort`) for the new aggregate
  function interface.
* A sort-based aggregate operator (`FinalAndCompleteAggregate2Sort`) for
  distinct aggregations (for distinct aggregations the query plan will use
  `Aggregate2Sort` and `FinalAndCompleteAggregate2Sort` together).

With this change, when `spark.sql.useAggregate2` is `true`, the flow of
compiling an aggregation query is:
1. Our analyzer looks up functions and returns aggregate functions built
   based on the old aggregate function interface.
2. When our planner is compiling the physical plan, it tries to convert all
   aggregate functions to the ones built based on the new interface. The
   planner will fall back to the old code path if any of the following
   conditions is true:
   * code-gen is disabled.
   * there is any function that cannot be converted (right now, Hive UDAFs).
   * the schema of the grouping expressions contains a complex data type.
   * there are multiple distinct columns.

Right now, the new code path handles a single distinct column in the query
(you can have multiple aggregate functions using that distinct column). For
a query having an aggregate function with DISTINCT and regular aggregate
functions, the generated plan will do partial aggregations for those regular
aggregate functions.

Thanks chenghao-intel for his initial work on it.

Author: Yin Huai
Author: Michael Armbrust

Closes #7458 from yhuai/UDAF and squashes the following commits:

7865f5e [Yin Huai] Put the catalyst expression in the comment of the generated code for it.
b04d6c8 [Yin Huai] Remove unnecessary change.
f1d5901 [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
35b0520 [Yin Huai] Use semanticEquals to replace grouping expressions in the output of the aggregate operator.
3b43b24 [Yin Huai] bug fix.
00eb298 [Yin Huai] Make it compile.
a3ca551 [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
e0afca3 [Yin Huai] Gracefully fall back to old aggregation code path.
8a8ac4a [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
88c7d4d [Yin Huai] Enable spark.sql.useAggregate2 by default for testing purpose.
dc96fd1 [Yin Huai] Many updates:
85c9c4b [Yin Huai] newline.
43de3de [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
c3614d7 [Yin Huai] Handle single distinct column.
68b8ee9 [Yin Huai] Support single distinct column set. WIP
3013579 [Yin Huai] Format.
d678aee [Yin Huai] Remove AggregateExpressionSuite.scala since our built-in aggregate functions will be based on AlgebraicAggregate and we need to have another way to test it.
e243ca6 [Yin Huai] Add aggregation iterators.
a101960 [Yin Huai] Change MyJavaUDAF to MyDoubleSum.
594cdf5 [Yin Huai] Change existing AggregateExpression to AggregateExpression1 and add an AggregateExpression as the common interface for both AggregateExpression1 and AggregateExpression2.
380880f [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
0a827b3 [Yin Huai] Add comments and doc. Move some classes to the right places.
a19fea6 [Yin Huai] Add UDAF interface.
262d4c4 [Yin Huai] Make it compile.
b2e358e [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
6edb5ac [Yin Huai] Format update.
70b169c [Yin Huai] Remove groupOrdering.
4721936 [Yin Huai] Add CheckAggregateFunction to extendedCheckRules.
d821a34 [Yin Huai] Cleanup.
32aea9c [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
5b46d41 [Yin Huai] Bug fix.
aff9534 [Yin Huai] Make Aggregate2Sort work with both algebraic AggregateFunctions and non-algebraic AggregateFunctions.
2857b55 [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
4435f20 [Yin Huai] Add ConvertAggregateFunction to HiveContext's analyzer.
1b490ed [Michael Armbrust] make hive test
8cfa6a9 [Michael Armbrust] add test
1b0bb3f [Yin Huai] Do not bind references in AlgebraicAggregate and use code gen for all places.
072209f [Yin Huai] Bug fix: Handle expressions in grouping columns that are not attribute references.
f7d9e54 [Michael Armbrust] Merge remote-tracking branch 'apache/master' into UDAF
39ee975 [Yin Huai] Code cleanup: Remove unnecessary AttributeReferences.
b7720ba [Yin Huai] Add an analysis rule to convert aggregate functions to the new version.
5c00f3f [Michael Armbrust] First draft of codegen
6bbc6ba [Michael Armbrust] now with correct answers!
f7996d0 [Michael Armbrust] Add AlgebraicAggregate dded1c5 [Yin Huai] wip --- .../apache/spark/sql/catalyst/SqlParser.scala | 3 +- .../sql/catalyst/analysis/Analyzer.scala | 24 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 1 + .../sql/catalyst/analysis/unresolved.scala | 5 +- .../catalyst/expressions/BoundAttribute.scala | 2 +- .../sql/catalyst/expressions/Expression.scala | 3 +- .../expressions/aggregate/functions.scala | 292 +++++++ .../expressions/aggregate/interfaces.scala | 206 +++++ .../sql/catalyst/expressions/aggregates.scala | 100 +-- .../codegen/GenerateMutableProjection.scala | 21 +- .../sql/catalyst/planning/patterns.scala | 4 +- .../plans/logical/basicOperators.scala | 1 + .../scala/org/apache/spark/sql/SQLConf.scala | 5 + .../org/apache/spark/sql/SQLContext.scala | 4 + .../apache/spark/sql/UDAFRegistration.scala | 35 + .../spark/sql/execution/Aggregate.scala | 12 +- .../apache/spark/sql/execution/Exchange.scala | 11 +- .../sql/execution/GeneratedAggregate.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 100 ++- .../aggregate/aggregateOperators.scala | 173 ++++ .../aggregate/sortBasedIterators.scala | 749 ++++++++++++++++++ .../spark/sql/execution/aggregate/utils.scala | 364 +++++++++ .../sql/expressions/aggregate/udaf.scala | 280 +++++++ .../org/apache/spark/sql/functions.scala | 4 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 4 +- .../spark/sql/execution/PlannerSuite.scala | 26 +- .../HiveWindowFunctionQuerySuite.scala | 1 + .../SortMergeCompatibilitySuite.scala | 7 + .../apache/spark/sql/hive/HiveContext.scala | 1 + .../org/apache/spark/sql/hive/HiveQl.scala | 7 +- .../org/apache/spark/sql/hive/hiveUDFs.scala | 8 +- .../spark/sql/hive/aggregate/MyDoubleAvg.java | 107 +++ .../spark/sql/hive/aggregate/MyDoubleSum.java | 100 +++ ...f_unhex-0-50131c0ba7b7a6b65c789a5a8497bada | 1 + ...f_unhex-1-11eb3cc5216d5446f4165007203acc47 | 1 + ...f_unhex-2-a660886085b8651852b9b77934848ae4 | 14 + ...f_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e | 1 + ...f_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 | 1 + .../execution/AggregationQuerySuite.scala | 507 ++++++++++++ 39 files changed, 3087 insertions(+), 100 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala create mode 100644 sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java create mode 100644 sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java create mode 100644 sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada create mode 100644 sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 create mode 100644 sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 create mode 100644 sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e create mode 100644 
sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3
 create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index d4ef04c2294a2..c04bd6cd85187 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -266,11 +266,12 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
         }
       }
     | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^
-      { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) }
+      { case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) }
     | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^
       { case udfName ~ exprs => lexical.normalizeKeyword(udfName) match {
         case "sum" => SumDistinct(exprs.head)
         case "count" => CountDistinct(exprs)
+        case name => UnresolvedFunction(name, exprs, isDistinct = true)
         case _ => throw new AnalysisException(s"function $udfName does not support DISTINCT")
       }
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e58f3f64947f3..8cadbc57e87e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2}
 import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -277,7 +278,7 @@ class Analyzer(
       Project(
         projectList.flatMap {
           case s: Star => s.expand(child.output, resolver)
-          case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) =>
+          case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) =>
            val expandedArgs = args.flatMap {
              case s: Star => s.expand(child.output, resolver)
              case o => o :: Nil
@@ -517,9 +518,26 @@ class Analyzer(
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
       case q: LogicalPlan =>
         q transformExpressions {
-          case u @ UnresolvedFunction(name, children) =>
+          case u @ UnresolvedFunction(name, children, isDistinct) =>
            withPosition(u) {
-              registry.lookupFunction(name, children)
+              registry.lookupFunction(name, children) match {
+                // We get an aggregate function built based on the AggregateFunction2 interface.
+                // So, we wrap it in AggregateExpression2.
+                case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct)
+                // Currently, our old aggregate function interface supports SUM(DISTINCT ...)
+                // and COUNT(DISTINCT ...).
+                case sumDistinct: SumDistinct => sumDistinct
+                case countDistinct: CountDistinct => countDistinct
+                // DISTINCT is not meaningful with Max and Min.
+                case max: Max if isDistinct => max
+                case min: Min if isDistinct => min
+                // For other aggregate functions, the DISTINCT keyword is not supported for now.
+                // Once we have converted to the new code path, we will allow using the DISTINCT keyword.
+ case other if isDistinct => + failAnalysis(s"$name does not support DISTINCT keyword.") + // If it does not have DISTINCT keyword, we will return it as is. + case other => other + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index c7f9713344c50..c203fcecf20fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 0daee1990a6e0..03da45b09f928 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -73,7 +73,10 @@ object UnresolvedAttribute { def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name)) } -case class UnresolvedFunction(name: String, children: Seq[Expression]) +case class UnresolvedFunction( + name: String, + children: Seq[Expression], + isDistinct: Boolean) extends Expression with Unevaluable { override def dataType: DataType = throw new UnresolvedException(this, "dataType") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index b09aea03318da..b10a3c877434b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types._ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) extends LeafExpression with NamedExpression { - override def toString: String = s"input[$ordinal]" + override def toString: String = s"input[$ordinal, $dataType]" override def eval(input: InternalRow): Any = input(ordinal) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index aada25276adb7..29ae47e842ddb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -96,7 +96,8 @@ abstract class Expression extends TreeNode[Expression] { val primitive = ctx.freshName("primitive") val ve = GeneratedExpressionCode("", isNull, primitive) ve.code = genCode(ctx, ve) - ve + // Add `this` in the comment. 
+ ve.copy(s"/* $this */\n" + ve.code) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala new file mode 100644 index 0000000000000..b924af4cc84d8 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -0,0 +1,292 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +case class Average(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = resultType + + // Expected input data type. + // TODO: Once we remove the old code path, we can use our analyzer to cast NullType + // to the default data type of the NumericType. + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + + private val resultType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 4, scale + 4) + case DecimalType.Unlimited => DecimalType.Unlimited + case _ => DoubleType + } + + private val sumDataType = child.dataType match { + case _ @ DecimalType() => DecimalType.Unlimited + case _ => DoubleType + } + + private val currentSum = AttributeReference("currentSum", sumDataType)() + private val currentCount = AttributeReference("currentCount", LongType)() + + override val bufferAttributes = currentSum :: currentCount :: Nil + + override val initialValues = Seq( + /* currentSum = */ Cast(Literal(0), sumDataType), + /* currentCount = */ Literal(0L) + ) + + override val updateExpressions = Seq( + /* currentSum = */ + Add( + currentSum, + Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), + /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L) + ) + + override val mergeExpressions = Seq( + /* currentSum = */ currentSum.left + currentSum.right, + /* currentCount = */ currentCount.left + currentCount.right + ) + + // If all input are nulls, currentCount will be 0 and we will get null after the division. + override val evaluateExpression = Cast(currentSum, resultType) / Cast(currentCount, resultType) +} + +case class Count(child: Expression) extends AlgebraicAggregate { + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = false + + // Return data type. + override def dataType: DataType = LongType + + // Expected input data type. 
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val currentCount = AttributeReference("currentCount", LongType)() + + override val bufferAttributes = currentCount :: Nil + + override val initialValues = Seq( + /* currentCount = */ Literal(0L) + ) + + override val updateExpressions = Seq( + /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L) + ) + + override val mergeExpressions = Seq( + /* currentCount = */ currentCount.left + currentCount.right + ) + + override val evaluateExpression = Cast(currentCount, LongType) +} + +case class First(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // First is not a deterministic function. + override def deterministic: Boolean = false + + // Return data type. + override def dataType: DataType = child.dataType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val first = AttributeReference("first", child.dataType)() + + override val bufferAttributes = first :: Nil + + override val initialValues = Seq( + /* first = */ Literal.create(null, child.dataType) + ) + + override val updateExpressions = Seq( + /* first = */ If(IsNull(first), child, first) + ) + + override val mergeExpressions = Seq( + /* first = */ If(IsNull(first.left), first.right, first.left) + ) + + override val evaluateExpression = first +} + +case class Last(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Last is not a deterministic function. + override def deterministic: Boolean = false + + // Return data type. + override def dataType: DataType = child.dataType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val last = AttributeReference("last", child.dataType)() + + override val bufferAttributes = last :: Nil + + override val initialValues = Seq( + /* last = */ Literal.create(null, child.dataType) + ) + + override val updateExpressions = Seq( + /* last = */ If(IsNull(child), last, child) + ) + + override val mergeExpressions = Seq( + /* last = */ If(IsNull(last.right), last.left, last.right) + ) + + override val evaluateExpression = last +} + +case class Max(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = child.dataType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val max = AttributeReference("max", child.dataType)() + + override val bufferAttributes = max :: Nil + + override val initialValues = Seq( + /* max = */ Literal.create(null, child.dataType) + ) + + override val updateExpressions = Seq( + /* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child)))) + ) + + override val mergeExpressions = { + val greatest = Greatest(Seq(max.left, max.right)) + Seq( + /* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest)) + ) + } + + override val evaluateExpression = max +} + +case class Min(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. 
+ override def dataType: DataType = child.dataType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val min = AttributeReference("min", child.dataType)() + + override val bufferAttributes = min :: Nil + + override val initialValues = Seq( + /* min = */ Literal.create(null, child.dataType) + ) + + override val updateExpressions = Seq( + /* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child)))) + ) + + override val mergeExpressions = { + val least = Least(Seq(min.left, min.right)) + Seq( + /* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least)) + ) + } + + override val evaluateExpression = min +} + +case class Sum(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = resultType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType)) + + private val resultType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 4, scale + 4) + case DecimalType.Unlimited => DecimalType.Unlimited + case _ => child.dataType + } + + private val sumDataType = child.dataType match { + case _ @ DecimalType() => DecimalType.Unlimited + case _ => child.dataType + } + + private val currentSum = AttributeReference("currentSum", sumDataType)() + + private val zero = Cast(Literal(0), sumDataType) + + override val bufferAttributes = currentSum :: Nil + + override val initialValues = Seq( + /* currentSum = */ Literal.create(null, sumDataType) + ) + + override val updateExpressions = Seq( + /* currentSum = */ + Coalesce(Seq(Add(Coalesce(Seq(currentSum, zero)), Cast(child, sumDataType)), currentSum)) + ) + + override val mergeExpressions = { + val add = Add(Coalesce(Seq(currentSum.left, zero)), Cast(currentSum.right, sumDataType)) + Seq( + /* currentSum = */ + Coalesce(Seq(add, currentSum.left)) + ) + } + + override val evaluateExpression = Cast(currentSum, resultType) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala new file mode 100644 index 0000000000000..577ede73cb01f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ + +/** The mode of an [[AggregateFunction1]]. */ +private[sql] sealed trait AggregateMode + +/** + * An [[AggregateFunction1]] with [[Partial]] mode is used for partial aggregation. + * This function updates the given aggregation buffer with the original input of this + * function. When it has processed all input rows, the aggregation buffer is returned. + */ +private[sql] case object Partial extends AggregateMode + +/** + * An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers + * containing intermediate results for this function. + * This function updates the given aggregation buffer by merging multiple aggregation buffers. + * When it has processed all input rows, the aggregation buffer is returned. + */ +private[sql] case object PartialMerge extends AggregateMode + +/** + * An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers + * containing intermediate results for this function and the generate final result. + * This function updates the given aggregation buffer by merging multiple aggregation buffers. + * When it has processed all input rows, the final result of this function is returned. + */ +private[sql] case object Final extends AggregateMode + +/** + * An [[AggregateFunction2]] with [[Partial]] mode is used to evaluate this function directly + * from original input rows without any partial aggregation. + * This function updates the given aggregation buffer with the original input of this + * function. When it has processed all input rows, the final result of this function is returned. + */ +private[sql] case object Complete extends AggregateMode + +/** + * A place holder expressions used in code-gen, it does not change the corresponding value + * in the row. + */ +private[sql] case object NoOp extends Expression with Unevaluable { + override def nullable: Boolean = true + override def eval(input: InternalRow): Any = { + throw new TreeNodeException( + this, s"No function to evaluate expression. type: ${this.nodeName}") + } + override def dataType: DataType = NullType + override def children: Seq[Expression] = Nil +} + +/** + * A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field + * (`isDistinct`) indicating if DISTINCT keyword is specified for this function. 
+ * @param aggregateFunction + * @param mode + * @param isDistinct + */ +private[sql] case class AggregateExpression2( + aggregateFunction: AggregateFunction2, + mode: AggregateMode, + isDistinct: Boolean) extends AggregateExpression { + + override def children: Seq[Expression] = aggregateFunction :: Nil + override def dataType: DataType = aggregateFunction.dataType + override def foldable: Boolean = false + override def nullable: Boolean = aggregateFunction.nullable + + override def references: AttributeSet = { + val childReferemces = mode match { + case Partial | Complete => aggregateFunction.references.toSeq + case PartialMerge | Final => aggregateFunction.bufferAttributes + } + + AttributeSet(childReferemces) + } + + override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)" +} + +abstract class AggregateFunction2 + extends Expression with ImplicitCastInputTypes { + + self: Product => + + /** An aggregate function is not foldable. */ + override def foldable: Boolean = false + + /** + * The offset of this function's buffer in the underlying buffer shared with other functions. + */ + var bufferOffset: Int = 0 + + /** The schema of the aggregation buffer. */ + def bufferSchema: StructType + + /** Attributes of fields in bufferSchema. */ + def bufferAttributes: Seq[AttributeReference] + + /** Clones bufferAttributes. */ + def cloneBufferAttributes: Seq[Attribute] + + /** + * Initializes its aggregation buffer located in `buffer`. + * It will use bufferOffset to find the starting point of + * its buffer in the given `buffer` shared with other functions. + */ + def initialize(buffer: MutableRow): Unit + + /** + * Updates its aggregation buffer located in `buffer` based on the given `input`. + * It will use bufferOffset to find the starting point of its buffer in the given `buffer` + * shared with other functions. + */ + def update(buffer: MutableRow, input: InternalRow): Unit + + /** + * Updates its aggregation buffer located in `buffer1` by combining intermediate results + * in the current buffer and intermediate results from another buffer `buffer2`. + * It will use bufferOffset to find the starting point of its buffer in the given `buffer1` + * and `buffer2`. + */ + def merge(buffer1: MutableRow, buffer2: InternalRow): Unit + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") +} + +/** + * A helper class for aggregate functions that can be implemented in terms of catalyst expressions. + */ +abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable { + self: Product => + + val initialValues: Seq[Expression] + val updateExpressions: Seq[Expression] + val mergeExpressions: Seq[Expression] + val evaluateExpression: Expression + + override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance()) + + /** + * A helper class for representing an attribute used in merging two + * aggregation buffers. When merging two buffers, `bufferLeft` and `bufferRight`, + * we merge buffer values and then update bufferLeft. A [[RichAttribute]] + * of an [[AttributeReference]] `a` has two functions `left` and `right`, + * which represent `a` in `bufferLeft` and `bufferRight`, respectively. + * @param a + */ + implicit class RichAttribute(a: AttributeReference) { + /** Represents this attribute at the mutable buffer side. 
*/ + def left: AttributeReference = a + + /** Represents this attribute at the input buffer side (the data value is read-only). */ + def right: AttributeReference = cloneBufferAttributes(bufferAttributes.indexOf(a)) + } + + /** An AlgebraicAggregate's bufferSchema is derived from bufferAttributes. */ + override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes) + + override def initialize(buffer: MutableRow): Unit = { + var i = 0 + while (i < bufferAttributes.size) { + buffer(i + bufferOffset) = initialValues(i).eval() + i += 1 + } + } + + override def update(buffer: MutableRow, input: InternalRow): Unit = { + throw new UnsupportedOperationException( + "AlgebraicAggregate's update should not be called directly") + } + + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + throw new UnsupportedOperationException( + "AlgebraicAggregate's merge should not be called directly") + } + + override def eval(buffer: InternalRow): Any = { + throw new UnsupportedOperationException( + "AlgebraicAggregate's eval should not be called directly") + } +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index d705a1286065c..e07c920a41d0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -27,7 +27,9 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet -trait AggregateExpression extends Expression with Unevaluable { +trait AggregateExpression extends Expression with Unevaluable + +trait AggregateExpression1 extends AggregateExpression { /** * Aggregate expressions should not be foldable. @@ -38,7 +40,7 @@ trait AggregateExpression extends Expression with Unevaluable { * Creates a new instance that can be used to compute this aggregate expression for a group * of input rows/ */ - def newInstance(): AggregateFunction + def newInstance(): AggregateFunction1 } /** @@ -54,10 +56,10 @@ case class SplitEvaluation( partialEvaluations: Seq[NamedExpression]) /** - * An [[AggregateExpression]] that can be partially computed without seeing all relevant tuples. + * An [[AggregateExpression1]] that can be partially computed without seeing all relevant tuples. * These partial evaluations can then be combined to compute the actual answer. */ -trait PartialAggregate extends AggregateExpression { +trait PartialAggregate1 extends AggregateExpression1 { /** * Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation. @@ -67,13 +69,13 @@ trait PartialAggregate extends AggregateExpression { /** * A specific implementation of an aggregate function. Used to wrap a generic - * [[AggregateExpression]] with an algorithm that will be used to compute one specific result. + * [[AggregateExpression1]] with an algorithm that will be used to compute one specific result. 
*/ -abstract class AggregateFunction - extends LeafExpression with AggregateExpression with Serializable { +abstract class AggregateFunction1 + extends LeafExpression with AggregateExpression1 with Serializable { /** Base should return the generic aggregate expression that this function is computing */ - val base: AggregateExpression + val base: AggregateExpression1 override def nullable: Boolean = base.nullable override def dataType: DataType = base.dataType @@ -81,12 +83,12 @@ abstract class AggregateFunction def update(input: InternalRow): Unit // Do we really need this? - override def newInstance(): AggregateFunction = { + override def newInstance(): AggregateFunction1 = { makeCopy(productIterator.map { case a: AnyRef => a }.toArray) } } -case class Min(child: Expression) extends UnaryExpression with PartialAggregate { +case class Min(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = true override def dataType: DataType = child.dataType @@ -102,7 +104,7 @@ case class Min(child: Expression) extends UnaryExpression with PartialAggregate TypeUtils.checkForOrderingExpr(child.dataType, "function min") } -case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class MinFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType) @@ -119,7 +121,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr override def eval(input: InternalRow): Any = currentMin.value } -case class Max(child: Expression) extends UnaryExpression with PartialAggregate { +case class Max(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = true override def dataType: DataType = child.dataType @@ -135,7 +137,7 @@ case class Max(child: Expression) extends UnaryExpression with PartialAggregate TypeUtils.checkForOrderingExpr(child.dataType, "function max") } -case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class MaxFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType) @@ -152,7 +154,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr override def eval(input: InternalRow): Any = currentMax.value } -case class Count(child: Expression) extends UnaryExpression with PartialAggregate { +case class Count(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = false override def dataType: LongType.type = LongType @@ -165,7 +167,7 @@ case class Count(child: Expression) extends UnaryExpression with PartialAggregat override def newInstance(): CountFunction = new CountFunction(child, this) } -case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class CountFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. 
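      // Editor's aside (a sketch, not part of this patch): how an operator
      // drives the old AggregateFunction1 contract, one instance per group.
      // `boundChild` and `groupRows` are hypothetical names for a bound input
      // expression and the rows of one group.
      //
      //   val fn = Count(boundChild).newInstance() // fresh per-group state
      //   groupRows.foreach(fn.update)             // fold each input row in
      //   val result = fn.eval(null)               // input row is ignored here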
var count: Long = _ @@ -180,7 +182,7 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag override def eval(input: InternalRow): Any = count } -case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate { +case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate1 { def this() = this(null) override def children: Seq[Expression] = expressions @@ -200,8 +202,8 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate case class CountDistinctFunction( @transient expr: Seq[Expression], - @transient base: AggregateExpression) - extends AggregateFunction { + @transient base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -220,7 +222,7 @@ case class CountDistinctFunction( override def eval(input: InternalRow): Any = seen.size.toLong } -case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression { +case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression1 { def this() = this(null) override def children: Seq[Expression] = expressions @@ -233,8 +235,8 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpress case class CollectHashSetFunction( @transient expr: Seq[Expression], - @transient base: AggregateExpression) - extends AggregateFunction { + @transient base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -255,7 +257,7 @@ case class CollectHashSetFunction( } } -case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression { +case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression1 { def this() = this(null) override def children: Seq[Expression] = inputSet :: Nil @@ -269,8 +271,8 @@ case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression case class CombineSetsAndCountFunction( @transient inputSet: Expression, - @transient base: AggregateExpression) - extends AggregateFunction { + @transient base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -305,7 +307,7 @@ private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] { } case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) - extends UnaryExpression with AggregateExpression { + extends UnaryExpression with AggregateExpression1 { override def nullable: Boolean = false override def dataType: DataType = HyperLogLogUDT @@ -317,9 +319,9 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) case class ApproxCountDistinctPartitionFunction( expr: Expression, - base: AggregateExpression, + base: AggregateExpression1, relativeSD: Double) - extends AggregateFunction { + extends AggregateFunction1 { def this() = this(null, null, 0) // Required for serialization. 
private val hyperLogLog = new HyperLogLog(relativeSD) @@ -335,7 +337,7 @@ case class ApproxCountDistinctPartitionFunction( } case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) - extends UnaryExpression with AggregateExpression { + extends UnaryExpression with AggregateExpression1 { override def nullable: Boolean = false override def dataType: LongType.type = LongType @@ -347,9 +349,9 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) case class ApproxCountDistinctMergeFunction( expr: Expression, - base: AggregateExpression, + base: AggregateExpression1, relativeSD: Double) - extends AggregateFunction { + extends AggregateFunction1 { def this() = this(null, null, 0) // Required for serialization. private val hyperLogLog = new HyperLogLog(relativeSD) @@ -363,7 +365,7 @@ case class ApproxCountDistinctMergeFunction( } case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) - extends UnaryExpression with PartialAggregate { + extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = false override def dataType: LongType.type = LongType @@ -381,7 +383,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this) } -case class Average(child: Expression) extends UnaryExpression with PartialAggregate { +case class Average(child: Expression) extends UnaryExpression with PartialAggregate1 { override def prettyName: String = "avg" @@ -427,8 +429,8 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg TypeUtils.checkForNumericExpr(child.dataType, "function average") } -case class AverageFunction(expr: Expression, base: AggregateExpression) - extends AggregateFunction { +case class AverageFunction(expr: Expression, base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -474,7 +476,7 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) } } -case class Sum(child: Expression) extends UnaryExpression with PartialAggregate { +case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = true @@ -509,7 +511,7 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate TypeUtils.checkForNumericExpr(child.dataType, "function sum") } -case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class SumFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. 
private val calcType = @@ -554,7 +556,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr * <-- null <-- no data * null <-- null <-- no data */ -case class CombineSum(child: Expression) extends AggregateExpression { +case class CombineSum(child: Expression) extends AggregateExpression1 { def this() = this(null) override def children: Seq[Expression] = child :: Nil @@ -564,8 +566,8 @@ case class CombineSum(child: Expression) extends AggregateExpression { override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) } -case class CombineSumFunction(expr: Expression, base: AggregateExpression) - extends AggregateFunction { +case class CombineSumFunction(expr: Expression, base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -601,7 +603,7 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression) } } -case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate { +case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 { def this() = this(null) override def nullable: Boolean = true @@ -627,8 +629,8 @@ case class SumDistinct(child: Expression) extends UnaryExpression with PartialAg TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct") } -case class SumDistinctFunction(expr: Expression, base: AggregateExpression) - extends AggregateFunction { +case class SumDistinctFunction(expr: Expression, base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -653,7 +655,7 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression) } } -case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { +case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression1 { def this() = this(null, null) override def children: Seq[Expression] = inputSet :: Nil @@ -667,8 +669,8 @@ case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends Agg case class CombineSetsAndSumFunction( @transient inputSet: Expression, - @transient base: AggregateExpression) - extends AggregateFunction { + @transient base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -695,7 +697,7 @@ case class CombineSetsAndSumFunction( } } -case class First(child: Expression) extends UnaryExpression with PartialAggregate { +case class First(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = true override def dataType: DataType = child.dataType override def toString: String = s"FIRST($child)" @@ -709,7 +711,7 @@ case class First(child: Expression) extends UnaryExpression with PartialAggregat override def newInstance(): FirstFunction = new FirstFunction(child, this) } -case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class FirstFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. 
var result: Any = null @@ -723,7 +725,7 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag override def eval(input: InternalRow): Any = result } -case class Last(child: Expression) extends UnaryExpression with PartialAggregate { +case class Last(child: Expression) extends UnaryExpression with PartialAggregate1 { override def references: AttributeSet = child.references override def nullable: Boolean = true override def dataType: DataType = child.dataType @@ -738,7 +740,7 @@ case class Last(child: Expression) extends UnaryExpression with PartialAggregate override def newInstance(): LastFunction = new LastFunction(child, this) } -case class LastFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class LastFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. var result: Any = null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 03b4b3c216f49..d838268f46956 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import scala.collection.mutable.ArrayBuffer @@ -38,15 +39,17 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { val ctx = newCodeGenContext() - val projectionCode = expressions.zipWithIndex.map { case (e, i) => - val evaluationCode = e.gen(ctx) - evaluationCode.code + - s""" - if(${evaluationCode.isNull}) - mutableRow.setNullAt($i); - else - ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; - """ + val projectionCode = expressions.zipWithIndex.map { + case (NoOp, _) => "" + case (e, i) => + val evaluationCode = e.gen(ctx) + evaluationCode.code + + s""" + if(${evaluationCode.isNull}) + mutableRow.setNullAt($i); + else + ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; + """ } // collect projections into blocks as function has 64kb codesize limit in JVM val projectionBlocks = new ArrayBuffer[String]() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 179a348d5baac..b8e3b0d53a505 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -129,10 +129,10 @@ object PartialAggregation { case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => // Collect all aggregate expressions. val allAggregates = - aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a}) + aggregateExpressions.flatMap(_ collect { case a: AggregateExpression1 => a}) // Collect all aggregate expressions that can be computed partially. 
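    // For example (editor's illustration, not part of the patch): Average is a
    // PartialAggregate1 because it can be split into per-partition Sum and Count
    // evaluations whose results are combined afterwards, whereas an expression
    // such as CollectHashSet above extends only AggregateExpression1 and so
    // cannot be computed partially.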
val partialAggregates = - aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p}) + aggregateExpressions.flatMap(_ collect { case p: PartialAggregate1 => p}) // Only do partial aggregation if supported by all aggregate expressions. if (allAggregates.size == partialAggregates.size) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 986c315b3173a..6aefa9f67556a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 78c780bdc5797..1474b170ba896 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -402,6 +402,9 @@ private[spark] object SQLConf { defaultValue = Some(true), isPublic = false) + val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2", + defaultValue = Some(true), doc = "") + val USE_SQL_SERIALIZER2 = booleanConf( "spark.sql.useSerializer2", defaultValue = Some(true), isPublic = false) @@ -473,6 +476,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED) + private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2) + private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2) private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 8b4528b5d52fe..49bfe74b680af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -285,6 +285,9 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient val udf: UDFRegistration = new UDFRegistration(this) + @transient + val udaf: UDAFRegistration = new UDAFRegistration(this) + /** * Returns true if the table is currently cached in-memory. * @group cachemgmt @@ -863,6 +866,7 @@ class SQLContext(@transient val sparkContext: SparkContext) DDLStrategy :: TakeOrderedAndProject :: HashAggregation :: + Aggregation :: LeftSemiJoin :: HashJoin :: InMemoryScans :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala new file mode 100644 index 0000000000000..5b872f5e3eecd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.{Expression} +import org.apache.spark.sql.expressions.aggregate.{ScalaUDAF, UserDefinedAggregateFunction} + +class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging { + + private val functionRegistry = sqlContext.functionRegistry + + def register( + name: String, + func: UserDefinedAggregateFunction): UserDefinedAggregateFunction = { + def builder(children: Seq[Expression]) = ScalaUDAF(children, func) + functionRegistry.registerFunction(name, builder) + func + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index 3cd60a2aa55ed..c2c945321db95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -68,14 +68,14 @@ case class Aggregate( * output. */ case class ComputedAggregate( - unbound: AggregateExpression, - aggregate: AggregateExpression, + unbound: AggregateExpression1, + aggregate: AggregateExpression1, resultAttribute: AttributeReference) /** A list of aggregates that need to be computed for each group. */ private[this] val computedAggregates = aggregateExpressions.flatMap { agg => agg.collect { - case a: AggregateExpression => + case a: AggregateExpression1 => ComputedAggregate( a, BindReferences.bindReference(a, child.output), @@ -87,8 +87,8 @@ case class Aggregate( private[this] val computedSchema = computedAggregates.map(_.resultAttribute) /** Creates a new aggregate buffer for a group. 
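 *  (Editor's note: the "buffer" here is an Array[AggregateFunction1], one
 *  function instance per computed aggregate; update(row) is then called on
 *  every function for each input row of the group.)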
*/ - private[this] def newAggregateBuffer(): Array[AggregateFunction] = { - val buffer = new Array[AggregateFunction](computedAggregates.length) + private[this] def newAggregateBuffer(): Array[AggregateFunction1] = { + val buffer = new Array[AggregateFunction1](computedAggregates.length) var i = 0 while (i < computedAggregates.length) { buffer(i) = computedAggregates(i).aggregate.newInstance() @@ -146,7 +146,7 @@ case class Aggregate( } } else { child.execute().mapPartitions { iter => - val hashTable = new HashMap[InternalRow, Array[AggregateFunction]] + val hashTable = new HashMap[InternalRow, Array[AggregateFunction1]] val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output) var currentRow: InternalRow = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 2750053594f99..d31e265a293e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -247,8 +247,15 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ } def addSortIfNecessary(child: SparkPlan): SparkPlan = { - if (rowOrdering.nonEmpty && child.outputOrdering != rowOrdering) { - sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) + + if (rowOrdering.nonEmpty) { + // If child.outputOrdering is [a, b] and rowOrdering is [a], we do not need to sort. + val minSize = Seq(rowOrdering.size, child.outputOrdering.size).min + if (minSize == 0 || rowOrdering.take(minSize) != child.outputOrdering.take(minSize)) { + sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) + } else { + child + } } else { child } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index ecde9c57139a6..0e63f2fe29cb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -69,7 +69,7 @@ case class GeneratedAggregate( protected override def doExecute(): RDD[InternalRow] = { val aggregatesToCompute = aggregateExpressions.flatMap { a => - a.collect { case agg: AggregateExpression => agg} + a.collect { case agg: AggregateExpression1 => agg} } // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 8cef7f200d2dc..f54aa2027f6a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{SQLContext, Strategy, execution} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2} import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} @@ -148,7 +149,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { if 
canBeCodeGened( allAggregates(partialComputation) ++ allAggregates(rewrittenAggregateExpressions)) && - codegenEnabled => + codegenEnabled && + !canBeConvertedToNewAggregation(plan) => execution.GeneratedAggregate( partial = false, namedGroupingAttributes, @@ -167,7 +169,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { rewrittenAggregateExpressions, groupingExpressions, partialComputation, - child) => + child) if !canBeConvertedToNewAggregation(plan) => execution.Aggregate( partial = false, namedGroupingAttributes, @@ -181,7 +183,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => Nil } - def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists { + def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = { + aggregate.Utils.tryConvert( + plan, + sqlContext.conf.useSqlAggregate2, + sqlContext.conf.codegenEnabled).isDefined + } + + def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = !aggs.exists { case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => false // The generated set implementation is pretty limited ATM. case CollectHashSet(exprs) if exprs.size == 1 && @@ -189,10 +198,74 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => true } - def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression] = - exprs.flatMap(_.collect { case a: AggregateExpression => a }) + def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] = + exprs.flatMap(_.collect { case a: AggregateExpression1 => a }) } + /** + * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface. + */ + object Aggregation extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case p: logical.Aggregate => + val converted = + aggregate.Utils.tryConvert( + p, + sqlContext.conf.useSqlAggregate2, + sqlContext.conf.codegenEnabled) + converted match { + case None => Nil // Cannot convert to new aggregation code path. + case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) => + // Extracts all distinct aggregate expressions from the resultExpressions. + val aggregateExpressions = resultExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression2 => agg + } + }.toSet.toSeq + // For those distinct aggregate expressions, we create a map from the + // aggregate function to the corresponding attribute of the function. + val aggregateFunctionMap = aggregateExpressions.map { agg => + val aggregateFunction = agg.aggregateFunction + (aggregateFunction, agg.isDistinct) -> + Alias(aggregateFunction, aggregateFunction.toString)().toAttribute + }.toMap + + val (functionsWithDistinct, functionsWithoutDistinct) = + aggregateExpressions.partition(_.isDistinct) + if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + // This is a sanity check. We should not reach here when we have multiple distinct + // column sets (aggregate.NewAggregation will not match). 
+            sys.error(
+              "Multiple distinct column sets are not supported by the new aggregation " +
+                "code path.")
+          }
+
+          val aggregateOperator =
+            if (functionsWithDistinct.isEmpty) {
+              aggregate.Utils.planAggregateWithoutDistinct(
+                groupingExpressions,
+                aggregateExpressions,
+                aggregateFunctionMap,
+                resultExpressions,
+                planLater(child))
+            } else {
+              aggregate.Utils.planAggregateWithOneDistinct(
+                groupingExpressions,
+                functionsWithDistinct,
+                functionsWithoutDistinct,
+                aggregateFunctionMap,
+                resultExpressions,
+                planLater(child))
+            }
+
+          aggregateOperator
+        }
+
+      case _ => Nil
+    }
+  }
+
+
 object BroadcastNestedLoopJoin extends Strategy {
   def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
     case logical.Join(left, right, joinType, condition) =>
@@ -336,8 +409,21 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.Filter(condition, planLater(child)) :: Nil
       case e @ logical.Expand(_, _, _, child) =>
         execution.Expand(e.projections, e.output, planLater(child)) :: Nil
-      case logical.Aggregate(group, agg, child) =>
-        execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
+      case a @ logical.Aggregate(group, agg, child) => {
+        val useNewAggregation =
+          aggregate.Utils.tryConvert(
+            a,
+            sqlContext.conf.useSqlAggregate2,
+            sqlContext.conf.codegenEnabled).isDefined
+        if (useNewAggregation) {
+          // If this logical.Aggregate can be planned to use the new aggregation code path
+          // (i.e. it can be planned by the Aggregation strategy), we will not use the old
+          // aggregation code path.
+          Nil
+        } else {
+          execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
+        }
+      }
       case logical.Window(projectList, windowExpressions, spec, child) =>
         execution.Window(projectList, windowExpressions, spec, planLater(child)) :: Nil
       case logical.Sample(lb, ub, withReplacement, seed, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
new file mode 100644
index 0000000000000..0c9082897f390
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution} +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} + +case class Aggregate2Sort( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression2], + aggregateAttributes: Seq[Attribute], + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + override def canProcessUnsafeRows: Boolean = true + + override def references: AttributeSet = { + val referencesInResults = + AttributeSet(resultExpressions.flatMap(_.references)) -- AttributeSet(aggregateAttributes) + + AttributeSet( + groupingExpressions.flatMap(_.references) ++ + aggregateExpressions.flatMap(_.references) ++ + referencesInResults) + } + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.length == 0 => AllTuples :: Nil + case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + // TODO: We should not sort the input rows if they are just in reversed order. + groupingExpressions.map(SortOrder(_, Ascending)) :: Nil + } + + override def outputOrdering: Seq[SortOrder] = { + // It is possible that the child.outputOrdering starts with the required + // ordering expressions (e.g. we require [a] as the sort expression and the + // child's outputOrdering is [a, b]). We can only guarantee the output rows + // are sorted by values of groupingExpressions. 
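+    // (Editor's illustration) e.g. with grouping [a] and child ordering [a, b],
+    // output rows are in fact sorted by [a, b], but this operator only
+    // promises SortOrder(a, Ascending) to downstream operators.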
+ groupingExpressions.map(SortOrder(_, Ascending)) + } + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + if (aggregateExpressions.length == 0) { + new GroupingIterator( + groupingExpressions, + resultExpressions, + newMutableProjection, + child.output, + iter) + } else { + val aggregationIterator: SortAggregationIterator = { + aggregateExpressions.map(_.mode).distinct.toList match { + case Partial :: Nil => + new PartialSortAggregationIterator( + groupingExpressions, + aggregateExpressions, + newMutableProjection, + child.output, + iter) + case PartialMerge :: Nil => + new PartialMergeSortAggregationIterator( + groupingExpressions, + aggregateExpressions, + newMutableProjection, + child.output, + iter) + case Final :: Nil => + new FinalSortAggregationIterator( + groupingExpressions, + aggregateExpressions, + aggregateAttributes, + resultExpressions, + newMutableProjection, + child.output, + iter) + case other => + sys.error( + s"Could not evaluate ${aggregateExpressions} because we do not support evaluate " + + s"modes $other in this operator.") + } + } + + aggregationIterator + } + } + } +} + +case class FinalAndCompleteAggregate2Sort( + previousGroupingExpressions: Seq[NamedExpression], + groupingExpressions: Seq[NamedExpression], + finalAggregateExpressions: Seq[AggregateExpression2], + finalAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + override def references: AttributeSet = { + val referencesInResults = + AttributeSet(resultExpressions.flatMap(_.references)) -- + AttributeSet(finalAggregateExpressions) -- + AttributeSet(completeAggregateExpressions) + + AttributeSet( + groupingExpressions.flatMap(_.references) ++ + finalAggregateExpressions.flatMap(_.references) ++ + completeAggregateExpressions.flatMap(_.references) ++ + referencesInResults) + } + + override def requiredChildDistribution: List[Distribution] = { + if (groupingExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + groupingExpressions.map(SortOrder(_, Ascending)) :: Nil + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + + new FinalAndCompleteSortAggregationIterator( + previousGroupingExpressions.length, + groupingExpressions, + finalAggregateExpressions, + finalAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + resultExpressions, + newMutableProjection, + child.output, + iter) + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala new file mode 100644 index 0000000000000..ce1cbdc9cb090 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala @@ -0,0 +1,749 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.types.NullType + +import scala.collection.mutable.ArrayBuffer + +/** + * An iterator used to evaluate aggregate functions. It assumes that input rows + * are already grouped by values of `groupingExpressions`. + */ +private[sql] abstract class SortAggregationIterator( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression2], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]) + extends Iterator[InternalRow] { + + /////////////////////////////////////////////////////////////////////////// + // Static fields for this iterator + /////////////////////////////////////////////////////////////////////////// + + protected val aggregateFunctions: Array[AggregateFunction2] = { + var bufferOffset = initialBufferOffset + val functions = new Array[AggregateFunction2](aggregateExpressions.length) + var i = 0 + while (i < aggregateExpressions.length) { + val func = aggregateExpressions(i).aggregateFunction + val funcWithBoundReferences = aggregateExpressions(i).mode match { + case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] => + // We need to create BoundReferences if the function is not an + // AlgebraicAggregate (it does not support code-gen) and the mode of + // this function is Partial or Complete because we will call eval of this + // function's children in the update method of this aggregate function. + // Those eval calls require BoundReferences to work. + BindReferences.bindReference(func, inputAttributes) + case _ => func + } + // Set bufferOffset for this function. It is important that setting bufferOffset + // happens after all potential bindReference operations because bindReference + // will create a new instance of the function. + funcWithBoundReferences.bufferOffset = bufferOffset + bufferOffset += funcWithBoundReferences.bufferSchema.length + functions(i) = funcWithBoundReferences + i += 1 + } + functions + } + + // All non-algebraic aggregate functions. + protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { + aggregateFunctions.collect { + case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func + }.toArray + } + + // Positions of those non-algebraic aggregate functions in aggregateFunctions. + // For example, we have func1, func2, func3, func4 in aggregateFunctions, and + // func2 and func3 are non-algebraic aggregate functions. + // nonAlgebraicAggregateFunctionPositions will be [1, 2]. 
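+  // (Editor's illustration, hypothetical values) with initialBufferOffset = 2
+  // and aggregateFunctions = [avg, max], whose buffers are assumed to be
+  // (sum, count) and (max), the offsets assigned above are avg.bufferOffset = 2
+  // and max.bufferOffset = 4, so the shared buffer is laid out as
+  // [placeholder, placeholder, sum, count, max].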
+  protected val nonAlgebraicAggregateFunctionPositions: Array[Int] = {
+    val positions = new ArrayBuffer[Int]()
+    var i = 0
+    while (i < aggregateFunctions.length) {
+      aggregateFunctions(i) match {
+        case agg: AlgebraicAggregate => // Handled by the algebraic projections.
+        case _ => positions += i
+      }
+      i += 1
+    }
+    positions.toArray
+  }
+
+  // This is used to project expressions for the grouping expressions.
+  protected val groupGenerator =
+    newMutableProjection(groupingExpressions, inputAttributes)()
+
+  // The underlying buffer shared by all aggregate functions.
+  protected val buffer: MutableRow = {
+    // The number of elements of the underlying buffer of this operator.
+    // All aggregate functions share this underlying buffer and find their
+    // buffer values through bufferOffset.
+    var size = initialBufferOffset
+    var i = 0
+    while (i < aggregateFunctions.length) {
+      size += aggregateFunctions(i).bufferSchema.length
+      i += 1
+    }
+    new GenericMutableRow(size)
+  }
+
+  protected val joinedRow = new JoinedRow4
+
+  protected val placeholderExpressions = Seq.fill(initialBufferOffset)(NoOp)
+
+  // This projection is used to initialize buffer values for all AlgebraicAggregates.
+  protected val algebraicInitialProjection = {
+    val initExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+      case ae: AlgebraicAggregate => ae.initialValues
+      case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+    }
+    newMutableProjection(initExpressions, Nil)().target(buffer)
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Mutable states
+  ///////////////////////////////////////////////////////////////////////////
+
+  // The grouping key of the current group.
+  protected var currentGroupingKey: InternalRow = _
+  // The grouping key of the next group.
+  protected var nextGroupingKey: InternalRow = _
+  // The first row of the next group.
+  protected var firstRowInNextGroup: InternalRow = _
+  // Indicates whether there is a new group of rows to process.
+  protected var hasNewGroup: Boolean = true
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Private methods
+  ///////////////////////////////////////////////////////////////////////////
+
+  /** Initializes buffer values for all aggregate functions. */
+  protected def initializeBuffer(): Unit = {
+    algebraicInitialProjection(EmptyRow)
+    var i = 0
+    while (i < nonAlgebraicAggregateFunctions.length) {
+      nonAlgebraicAggregateFunctions(i).initialize(buffer)
+      i += 1
+    }
+  }
+
+  protected def initialize(): Unit = {
+    if (inputIter.hasNext) {
+      initializeBuffer()
+      val currentRow = inputIter.next().copy()
+      // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
+      // we make a copy here.
+      nextGroupingKey = groupGenerator(currentRow).copy()
+      firstRowInNextGroup = currentRow
+    } else {
+      // The input iterator is empty.
+      hasNewGroup = false
+    }
+  }
+
+  /** Processes rows in the current group. It stops when it finds a new group. */
+  private def processCurrentGroup(): Unit = {
+    currentGroupingKey = nextGroupingKey
+    // Now, we will start to find all rows belonging to this group.
+    // We create a variable to track whether we have seen the next group.
+    var findNextPartition = false
+    // firstRowInNextGroup is the first row of this group. We first process it.
+    processRow(firstRowInNextGroup)
+    // The search will stop when we see the next group or there is no
+    // input row left in the input iterator.
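+    // (Editor's illustration) given grouped input [k1, k1, k2], the first call
+    // consumes both k1 rows, saves the k2 row in firstRowInNextGroup, and sets
+    // nextGroupingKey to k2 for the next call.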
+    while (inputIter.hasNext && !findNextPartition) {
+      val currentRow = inputIter.next()
+      // Get the grouping key based on the grouping expressions.
+      // For the comparison below, we do not need to make a copy of groupingKey.
+      val groupingKey = groupGenerator(currentRow)
+      // Check whether the current row still belongs to the current group.
+      if (currentGroupingKey == groupingKey) {
+        processRow(currentRow)
+      } else {
+        // We have found a new group.
+        findNextPartition = true
+        nextGroupingKey = groupingKey.copy()
+        firstRowInNextGroup = currentRow.copy()
+      }
+    }
+    // We have not seen a new group. It means that there is no row left in the
+    // input iterator. The current group is the last group of the iterator.
+    if (!findNextPartition) {
+      hasNewGroup = false
+    }
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Public methods
+  ///////////////////////////////////////////////////////////////////////////
+
+  override final def hasNext: Boolean = hasNewGroup
+
+  override final def next(): InternalRow = {
+    if (hasNext) {
+      // Process the current group.
+      processCurrentGroup()
+      // Generate the output row for the current group.
+      val outputRow = generateOutput()
+      // Initialize buffer values for the next group.
+      initializeBuffer()
+
+      outputRow
+    } else {
+      // No more results.
+      throw new NoSuchElementException
+    }
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Methods that need to be implemented
+  ///////////////////////////////////////////////////////////////////////////
+
+  protected def initialBufferOffset: Int
+
+  protected def processRow(row: InternalRow): Unit
+
+  protected def generateOutput(): InternalRow
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Initialize this iterator
+  ///////////////////////////////////////////////////////////////////////////
+
+  initialize()
+}
+
+/**
+ * An iterator only used to group input rows according to values of `groupingExpressions`.
+ * It assumes that input rows are already grouped by values of `groupingExpressions`.
+ */
+class GroupingIterator(
+    groupingExpressions: Seq[NamedExpression],
+    resultExpressions: Seq[NamedExpression],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow])
+  extends SortAggregationIterator(
+    groupingExpressions,
+    Nil,
+    newMutableProjection,
+    inputAttributes,
+    inputIter) {
+
+  private val resultProjection =
+    newMutableProjection(resultExpressions, groupingExpressions.map(_.toAttribute))()
+
+  override protected def initialBufferOffset: Int = 0
+
+  override protected def processRow(row: InternalRow): Unit = {
+    // Since we only do grouping, there is nothing to do here.
+  }
+
+  override protected def generateOutput(): InternalRow = {
+    resultProjection(currentGroupingKey)
+  }
+}
+
+/**
+ * An iterator used to do partial aggregations (for those aggregate functions with mode Partial).
+ * It assumes that input rows are already grouped by values of `groupingExpressions`.
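+ * For example (editor's illustration):
+ * {{{
+ *   SELECT key, AVG(value) FROM src GROUP BY key
+ * }}}
+ * running in Partial mode emits one row per key of the form [key, sum, count],
+ * assuming Average's buffer attributes are (sum, count).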
+ * The format of its output rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ */
+class PartialSortAggregationIterator(
+    groupingExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[AggregateExpression2],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow])
+  extends SortAggregationIterator(
+    groupingExpressions,
+    aggregateExpressions,
+    newMutableProjection,
+    inputAttributes,
+    inputIter) {
+
+  // This projection is used to update buffer values for all AlgebraicAggregates.
+  private val algebraicUpdateProjection = {
+    val bufferSchema = aggregateFunctions.flatMap {
+      case ae: AlgebraicAggregate => ae.bufferAttributes
+      case agg: AggregateFunction2 => agg.bufferAttributes
+    }
+    val updateExpressions = aggregateFunctions.flatMap {
+      case ae: AlgebraicAggregate => ae.updateExpressions
+      case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+    }
+    newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer)
+  }
+
+  override protected def initialBufferOffset: Int = 0
+
+  override protected def processRow(row: InternalRow): Unit = {
+    // Process all algebraic aggregate functions.
+    algebraicUpdateProjection(joinedRow(buffer, row))
+    // Process all non-algebraic aggregate functions.
+    var i = 0
+    while (i < nonAlgebraicAggregateFunctions.length) {
+      nonAlgebraicAggregateFunctions(i).update(buffer, row)
+      i += 1
+    }
+  }
+
+  override protected def generateOutput(): InternalRow = {
+    // We just output the grouping expressions and the underlying buffer.
+    joinedRow(currentGroupingKey, buffer).copy()
+  }
+}
+
+/**
+ * An iterator used to do partial merge aggregations (for those aggregate functions with mode
+ * PartialMerge). It assumes that input rows are already grouped by values of
+ * `groupingExpressions`.
+ * The format of its input rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ *
+ * The format of its internal buffer is:
+ * |placeholder1|...|placeholderN|aggregationBuffer1|...|aggregationBufferN|
+ * Every placeholder is for a grouping expression.
+ * The actual buffers are stored after placeholderN.
+ * The reason that we have placeholders here is to make the underlying buffer
+ * have the same length as an input row.
+ *
+ * The format of its output rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ */
+class PartialMergeSortAggregationIterator(
+    groupingExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[AggregateExpression2],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow])
+  extends SortAggregationIterator(
+    groupingExpressions,
+    aggregateExpressions,
+    newMutableProjection,
+    inputAttributes,
+    inputIter) {
+
+  private val placeholderAttributes =
+    Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
+
+  // This projection is used to merge buffer values for all AlgebraicAggregates.
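+  // (Editor's illustration) for AVG, the merge expressions are essentially
+  // sum = sum.left + sum.right and count = count.left + count.right, where
+  // left reads the current buffer and right reads the incoming buffer
+  // (see AlgebraicAggregate's bufferAttributes / cloneBufferAttributes).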
+  private val algebraicMergeProjection = {
+    val bufferSchemata =
+      placeholderAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.bufferAttributes
+        case agg: AggregateFunction2 => agg.bufferAttributes
+      } ++ placeholderAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+        case agg: AggregateFunction2 => agg.cloneBufferAttributes
+      }
+    val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+      case ae: AlgebraicAggregate => ae.mergeExpressions
+      case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+    }
+
+    newMutableProjection(mergeExpressions, bufferSchemata)()
+  }
+
+  // This projection is used to extract aggregation buffers from the underlying buffer.
+  // We need it because the underlying buffer has placeholders at its beginning.
+  private val extractsBufferValues = {
+    val expressions = aggregateFunctions.flatMap {
+      case agg => agg.bufferAttributes
+    }
+
+    newMutableProjection(expressions, inputAttributes)()
+  }
+
+  override protected def initialBufferOffset: Int = groupingExpressions.length
+
+  override protected def processRow(row: InternalRow): Unit = {
+    // Process all algebraic aggregate functions.
+    algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))
+    // Process all non-algebraic aggregate functions.
+    var i = 0
+    while (i < nonAlgebraicAggregateFunctions.length) {
+      nonAlgebraicAggregateFunctions(i).merge(buffer, row)
+      i += 1
+    }
+  }
+
+  override protected def generateOutput(): InternalRow = {
+    // We output grouping expressions and aggregation buffers.
+    joinedRow(currentGroupingKey, extractsBufferValues(buffer))
+  }
+}
+
+/**
+ * An iterator used to do final aggregations (for those aggregate functions with mode
+ * Final). It assumes that input rows are already grouped by values of
+ * `groupingExpressions`.
+ * The format of its input rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ *
+ * The format of its internal buffer is:
+ * |placeholder1|...|placeholderN|aggregationBuffer1|...|aggregationBufferN|
+ * Every placeholder is for a grouping expression.
+ * The actual buffers are stored after placeholderN.
+ * The reason that we have placeholders here is to make the underlying buffer
+ * have the same length as an input row.
+ *
+ * The format of its output rows is represented by the schema of `resultExpressions`.
+ */
+class FinalSortAggregationIterator(
+    groupingExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[AggregateExpression2],
+    aggregateAttributes: Seq[Attribute],
+    resultExpressions: Seq[NamedExpression],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow])
+  extends SortAggregationIterator(
+    groupingExpressions,
+    aggregateExpressions,
+    newMutableProjection,
+    inputAttributes,
+    inputIter) {
+
+  // The result of aggregate functions.
+  private val aggregateResult: MutableRow = new GenericMutableRow(aggregateAttributes.length)
+
+  // The projection used to generate the output rows of this operator.
+  // This is only used when we are generating final results of aggregate functions.
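+  // (Editor's illustration) for SELECT key, AVG(v) + 1 ... GROUP BY key, the
+  // resultExpressions reference key and the attribute standing in for AVG(v),
+  // so this projection computes [key, avg + 1] from [key, avgResult].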
+  private val resultProjection =
+    newMutableProjection(
+      resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)()
+
+  private val offsetAttributes =
+    Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
+
+  // This projection is used to merge buffer values for all AlgebraicAggregates.
+  private val algebraicMergeProjection = {
+    val bufferSchemata =
+      offsetAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.bufferAttributes
+        case agg: AggregateFunction2 => agg.bufferAttributes
+      } ++ offsetAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+        case agg: AggregateFunction2 => agg.cloneBufferAttributes
+      }
+    val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+      case ae: AlgebraicAggregate => ae.mergeExpressions
+      case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+    }
+
+    newMutableProjection(mergeExpressions, bufferSchemata)()
+  }
+
+  // This projection is used to evaluate all AlgebraicAggregates.
+  private val algebraicEvalProjection = {
+    val bufferSchemata =
+      offsetAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.bufferAttributes
+        case agg: AggregateFunction2 => agg.bufferAttributes
+      } ++ offsetAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+        case agg: AggregateFunction2 => agg.cloneBufferAttributes
+      }
+    val evalExpressions = aggregateFunctions.map {
+      case ae: AlgebraicAggregate => ae.evaluateExpression
+      case agg: AggregateFunction2 => NoOp
+    }
+
+    newMutableProjection(evalExpressions, bufferSchemata)()
+  }
+
+  override protected def initialBufferOffset: Int = groupingExpressions.length
+
+  override def initialize(): Unit = {
+    if (inputIter.hasNext) {
+      initializeBuffer()
+      val currentRow = inputIter.next().copy()
+      // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
+      // we make a copy here.
+      nextGroupingKey = groupGenerator(currentRow).copy()
+      firstRowInNextGroup = currentRow
+    } else {
+      if (groupingExpressions.isEmpty) {
+        // If there is no grouping expression, we need to generate a single row as the output.
+        initializeBuffer()
+        // Right now, the buffer only contains initial buffer values. Because
+        // merging two buffers with initial values will generate a row that
+        // still stores initial values, we use a copy of the current buffer as the currentRow.
+        val currentRow = buffer.copy()
+        nextGroupingKey = groupGenerator(currentRow).copy()
+        firstRowInNextGroup = currentRow
+      } else {
+        // The input iterator is empty.
+        hasNewGroup = false
+      }
+    }
+  }
+
+  override protected def processRow(row: InternalRow): Unit = {
+    // Process all algebraic aggregate functions.
+    algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))
+    // Process all non-algebraic aggregate functions.
+    var i = 0
+    while (i < nonAlgebraicAggregateFunctions.length) {
+      nonAlgebraicAggregateFunctions(i).merge(buffer, row)
+      i += 1
+    }
+  }
+
+  override protected def generateOutput(): InternalRow = {
+    // Generate results for all algebraic aggregate functions.
+    algebraicEvalProjection.target(aggregateResult)(buffer)
+    // Generate results for all non-algebraic aggregate functions.
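+    // (Editor's note) algebraicEvalProjection leaves the slots of
+    // non-algebraic functions untouched (they map to NoOp); the loop below
+    // fills those slots using nonAlgebraicAggregateFunctionPositions.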
+    var i = 0
+    while (i < nonAlgebraicAggregateFunctions.length) {
+      aggregateResult.update(
+        nonAlgebraicAggregateFunctionPositions(i),
+        nonAlgebraicAggregateFunctions(i).eval(buffer))
+      i += 1
+    }
+    resultProjection(joinedRow(currentGroupingKey, aggregateResult))
+  }
+}
+
+/**
+ * An iterator used to do both final aggregations (for those aggregate functions with mode
+ * Final) and complete aggregations (for those aggregate functions with mode Complete).
+ * It assumes that input rows are already grouped by values of `groupingExpressions`.
+ * The format of its input rows is:
+ * |groupingExpr1|...|groupingExprN|col1|...|colM|aggregationBuffer1|...|aggregationBufferN|
+ * col1 to colM are columns used by aggregate functions with Complete mode.
+ * aggregationBuffer1 to aggregationBufferN are buffers used by aggregate functions with
+ * Final mode.
+ *
+ * The format of its internal buffer is:
+ * |placeholder1|...|placeholder(N+M)|aggregationBuffer1|...|aggregationBuffer(N+M)|
+ * The first N placeholders represent slots of grouping expressions.
+ * The next M placeholders represent slots of col1 to colM.
+ * For aggregation buffers, the first N are used by the N aggregate functions with
+ * mode Final, and the last M are used by the M aggregate functions with mode
+ * Complete. The reason that we have placeholders here is to make the underlying buffer
+ * have the same length as an input row.
+ *
+ * The format of its output rows is represented by the schema of `resultExpressions`.
+ */
+class FinalAndCompleteSortAggregationIterator(
+    override protected val initialBufferOffset: Int,
+    groupingExpressions: Seq[NamedExpression],
+    finalAggregateExpressions: Seq[AggregateExpression2],
+    finalAggregateAttributes: Seq[Attribute],
+    completeAggregateExpressions: Seq[AggregateExpression2],
+    completeAggregateAttributes: Seq[Attribute],
+    resultExpressions: Seq[NamedExpression],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow])
+  extends SortAggregationIterator(
+    groupingExpressions,
+    // TODO: document the ordering
+    finalAggregateExpressions ++ completeAggregateExpressions,
+    newMutableProjection,
+    inputAttributes,
+    inputIter) {
+
+  // The result of aggregate functions.
+  private val aggregateResult: MutableRow =
+    new GenericMutableRow(completeAggregateAttributes.length + finalAggregateAttributes.length)
+
+  // The projection used to generate the output rows of this operator.
+  // This is only used when we are generating final results of aggregate functions.
+  private val resultProjection = {
+    val inputSchema =
+      groupingExpressions.map(_.toAttribute) ++
+        finalAggregateAttributes ++
+        completeAggregateAttributes
+    newMutableProjection(resultExpressions, inputSchema)()
+  }
+
+  private val offsetAttributes =
+    Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
+
+  // All aggregate functions with mode Final.
+  private val finalAggregateFunctions: Array[AggregateFunction2] = {
+    val functions = new Array[AggregateFunction2](finalAggregateExpressions.length)
+    var i = 0
+    while (i < finalAggregateExpressions.length) {
+      functions(i) = aggregateFunctions(i)
+      i += 1
+    }
+    functions
+  }
+
+  // All non-algebraic aggregate functions with mode Final.
+ private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { + finalAggregateFunctions.collect { + case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func + }.toArray + } + + // All aggregate functions with mode Complete. + private val completeAggregateFunctions: Array[AggregateFunction2] = { + val functions = new Array[AggregateFunction2](completeAggregateExpressions.length) + var i = 0 + while (i < completeAggregateExpressions.length) { + functions(i) = aggregateFunctions(finalAggregateFunctions.length + i) + i += 1 + } + functions + } + + // All non-algebraic aggregate functions with mode Complete. + private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { + completeAggregateFunctions.collect { + case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func + }.toArray + } + + // This projection is used to merge buffer values for all AlgebraicAggregates with mode + // Final. + private val finalAlgebraicMergeProjection = { + val numCompleteOffsetAttributes = + completeAggregateFunctions.map(_.bufferAttributes.length).sum + val completeOffsetAttributes = + Seq.fill(numCompleteOffsetAttributes)(AttributeReference("placeholder", NullType)()) + val completeOffsetExpressions = Seq.fill(numCompleteOffsetAttributes)(NoOp) + + val bufferSchemata = + offsetAttributes ++ finalAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } ++ completeOffsetAttributes ++ offsetAttributes ++ finalAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.cloneBufferAttributes + case agg: AggregateFunction2 => agg.cloneBufferAttributes + } ++ completeOffsetAttributes + val mergeExpressions = + placeholderExpressions ++ finalAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } ++ completeOffsetExpressions + + newMutableProjection(mergeExpressions, bufferSchemata)() + } + + // This projection is used to update buffer values for all AlgebraicAggregates with mode + // Complete. + private val completeAlgebraicUpdateProjection = { + val numFinalOffsetAttributes = finalAggregateFunctions.map(_.bufferAttributes.length).sum + val finalOffsetAttributes = + Seq.fill(numFinalOffsetAttributes)(AttributeReference("placeholder", NullType)()) + val finalOffsetExpressions = Seq.fill(numFinalOffsetAttributes)(NoOp) + + val bufferSchema = + offsetAttributes ++ finalOffsetAttributes ++ completeAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } + val updateExpressions = + placeholderExpressions ++ finalOffsetExpressions ++ completeAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer) + } + + // This projection is used to evaluate all AlgebraicAggregates. 
+  private val algebraicEvalProjection = {
+    val bufferSchemata =
+      offsetAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.bufferAttributes
+        case agg: AggregateFunction2 => agg.bufferAttributes
+      } ++ offsetAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+        case agg: AggregateFunction2 => agg.cloneBufferAttributes
+      }
+    val evalExpressions = aggregateFunctions.map {
+      case ae: AlgebraicAggregate => ae.evaluateExpression
+      case agg: AggregateFunction2 => NoOp
+    }
+
+    newMutableProjection(evalExpressions, bufferSchemata)()
+  }
+
+  override def initialize(): Unit = {
+    if (inputIter.hasNext) {
+      initializeBuffer()
+      val currentRow = inputIter.next().copy()
+      // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
+      // we make a copy here.
+      nextGroupingKey = groupGenerator(currentRow).copy()
+      firstRowInNextGroup = currentRow
+    } else {
+      if (groupingExpressions.isEmpty) {
+        // If there is no grouping expression, we need to generate a single row as the output.
+        initializeBuffer()
+        // Right now, the buffer only contains initial buffer values. Because merging two
+        // buffers with initial values generates a row that still stores initial values,
+        // we set currentRow to a copy of the current buffer.
+        val currentRow = buffer.copy()
+        nextGroupingKey = groupGenerator(currentRow).copy()
+        firstRowInNextGroup = currentRow
+      } else {
+        // The input iterator is empty.
+        hasNewGroup = false
+      }
+    }
+  }
+
+  override protected def processRow(row: InternalRow): Unit = {
+    val input = joinedRow(buffer, row)
+    // For all aggregate functions with mode Complete, update buffers.
+    completeAlgebraicUpdateProjection(input)
+    var i = 0
+    while (i < completeNonAlgebraicAggregateFunctions.length) {
+      completeNonAlgebraicAggregateFunctions(i).update(buffer, row)
+      i += 1
+    }
+
+    // For all aggregate functions with mode Final, merge buffers.
+    finalAlgebraicMergeProjection.target(buffer)(input)
+    i = 0
+    while (i < finalNonAlgebraicAggregateFunctions.length) {
+      finalNonAlgebraicAggregateFunctions(i).merge(buffer, row)
+      i += 1
+    }
+  }
+
+  override protected def generateOutput(): InternalRow = {
+    // Generate results for all algebraic aggregate functions.
+    algebraicEvalProjection.target(aggregateResult)(buffer)
+    // Generate results for all non-algebraic aggregate functions.
+    var i = 0
+    while (i < nonAlgebraicAggregateFunctions.length) {
+      aggregateResult.update(
+        nonAlgebraicAggregateFunctionPositions(i),
+        nonAlgebraicAggregateFunctions(i).eval(buffer))
+      i += 1
+    }
+
+    resultProjection(joinedRow(currentGroupingKey, aggregateResult))
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
new file mode 100644
index 0000000000000..1cb27710e0480
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
+
+/**
+ * Utility functions used by the query planner to convert our plan to the new aggregation
+ * code path.
+ */
+object Utils {
+  // Right now, we do not support complex types in the grouping key schema.
+  private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = {
+    val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists {
+      case array: ArrayType => true
+      case map: MapType => true
+      case struct: StructType => true
+      case _ => false
+    }
+
+    !hasComplexTypes
+  }
+
+  private def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
+    case p: Aggregate if supportsGroupingKeySchema(p) =>
+      val converted = p.transformExpressionsDown {
+        case expressions.Average(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Average(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Count(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Count(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        // We do not support multiple COUNT DISTINCT columns for now.
+        case expressions.CountDistinct(children) if children.length == 1 =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Count(children.head),
+            mode = aggregate.Complete,
+            isDistinct = true)
+
+        case expressions.First(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.First(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Last(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Last(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Max(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Max(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Min(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Min(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Sum(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Sum(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.SumDistinct(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Sum(child),
+            mode = aggregate.Complete,
+            isDistinct = true)
+      }
+      // Check if there is any expressions.AggregateExpression1 left.
+      // If so, we cannot convert this plan.
+      val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
+        // For each expression, check whether it contains AggregateExpression1.
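+        // (find traverses the expression tree and returns the first node satisfying the
+        // predicate, so .isDefined below means at least one old-interface aggregate remains.)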
+        expr.find {
+          case agg: expressions.AggregateExpression1 => true
+          case other => false
+        }.isDefined
+      }
+
+      // Check if there are multiple distinct columns.
+      val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
+        expr.collect {
+          case agg: AggregateExpression2 => agg
+        }
+      }.toSet.toSeq
+      val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
+      val hasMultipleDistinctColumnSets =
+        functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1
+
+      if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None
+
+    case other => None
+  }
+
+  private def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = {
+    // If the plan cannot be converted, we do a final check to see whether the original
+    // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so,
+    // we need to throw an exception.
+    val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr =>
+      expr.collect {
+        case agg: AggregateExpression2 => agg.aggregateFunction
+      }
+    }.distinct
+    if (aggregateFunction2s.nonEmpty) {
+      // For functions implemented based on the new interface, prepare a list of function names.
+      val invalidFunctions = {
+        if (aggregateFunction2s.length > 1) {
+          s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " +
+            s"and ${aggregateFunction2s.head.nodeName} are"
+        } else {
+          s"${aggregateFunction2s.head.nodeName} is"
+        }
+      }
+      val errorMessage =
+        s"${invalidFunctions} implemented based on the new Aggregate Function " +
+          s"interface and cannot be used with functions implemented based on " +
+          s"the old Aggregate Function interface."
+      throw new AnalysisException(errorMessage)
+    }
+  }
+
+  def tryConvert(
+      plan: LogicalPlan,
+      useNewAggregation: Boolean,
+      codeGenEnabled: Boolean): Option[Aggregate] = plan match {
+    case p: Aggregate if useNewAggregation && codeGenEnabled =>
+      val converted = tryConvert(p)
+      if (converted.isDefined) {
+        converted
+      } else {
+        checkInvalidAggregateFunction2(p)
+        None
+      }
+    case p: Aggregate =>
+      checkInvalidAggregateFunction2(p)
+      None
+    case other => None
+  }
+
+  def planAggregateWithoutDistinct(
+      groupingExpressions: Seq[Expression],
+      aggregateExpressions: Seq[AggregateExpression2],
+      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
+      resultExpressions: Seq[NamedExpression],
+      child: SparkPlan): Seq[SparkPlan] = {
+    // 1. Create an Aggregate Operator for partial aggregations.
+    val namedGroupingExpressions = groupingExpressions.map {
+      case ne: NamedExpression => ne -> ne
+      // If the expression is not a NamedExpression, we add an alias.
+      // So, when we generate the result of the operator, the Aggregate Operator
+      // can directly get the Seq of attributes representing the grouping expressions.
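+      // (For example, a grouping expression such as key + 1 becomes
+      // Alias(key + 1, "(key + 1)")(), so downstream operators can refer to it by its
+      // attribute; the alias name is simply other.toString.)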
+      case other =>
+        val withAlias = Alias(other, other.toString)()
+        other -> withAlias
+    }
+    val groupExpressionMap = namedGroupingExpressions.toMap
+    val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
+    val partialAggregateExpressions = aggregateExpressions.map {
+      case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+        AggregateExpression2(aggregateFunction, Partial, isDistinct)
+    }
+    val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
+      agg.aggregateFunction.bufferAttributes
+    }
+    val partialAggregate =
+      Aggregate2Sort(
+        None: Option[Seq[Expression]],
+        namedGroupingExpressions.map(_._2),
+        partialAggregateExpressions,
+        partialAggregateAttributes,
+        namedGroupingAttributes ++ partialAggregateAttributes,
+        child)
+
+    // 2. Create an Aggregate Operator for final aggregations.
+    val finalAggregateExpressions = aggregateExpressions.map {
+      case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+        AggregateExpression2(aggregateFunction, Final, isDistinct)
+    }
+    val finalAggregateAttributes =
+      finalAggregateExpressions.map {
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+      }
+    val rewrittenResultExpressions = resultExpressions.map { expr =>
+      expr.transformDown {
+        case agg: AggregateExpression2 =>
+          aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
+        case expression =>
+          // We do not rely on direct equality here since attributes may differ
+          // cosmetically. Instead, we use semanticEquals.
+          groupExpressionMap.collectFirst {
+            case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+          }.getOrElse(expression)
+      }.asInstanceOf[NamedExpression]
+    }
+    val finalAggregate = Aggregate2Sort(
+      Some(namedGroupingAttributes),
+      namedGroupingAttributes,
+      finalAggregateExpressions,
+      finalAggregateAttributes,
+      rewrittenResultExpressions,
+      partialAggregate)
+
+    finalAggregate :: Nil
+  }
+
+  def planAggregateWithOneDistinct(
+      groupingExpressions: Seq[Expression],
+      functionsWithDistinct: Seq[AggregateExpression2],
+      functionsWithoutDistinct: Seq[AggregateExpression2],
+      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
+      resultExpressions: Seq[NamedExpression],
+      child: SparkPlan): Seq[SparkPlan] = {
+
+    // 1. Create an Aggregate Operator for partial aggregations.
+    // The grouping expressions are the original groupingExpressions and the
+    // distinct columns. For example, for avg(distinct value) ... group by key,
+    // the grouping expressions of this Aggregate Operator will be [key, value].
+    val namedGroupingExpressions = groupingExpressions.map {
+      case ne: NamedExpression => ne -> ne
+      // If the expression is not a NamedExpression, we add an alias.
+      // So, when we generate the result of the operator, the Aggregate Operator
+      // can directly get the Seq of attributes representing the grouping expressions.
+      case other =>
+        val withAlias = Alias(other, other.toString)()
+        other -> withAlias
+    }
+    val groupExpressionMap = namedGroupingExpressions.toMap
+    val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
+
+    // It is safe to call head here since functionsWithDistinct has at least one
+    // AggregateExpression2.
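+    // (All entries in functionsWithDistinct share the same distinct column set; plans with
+    // more than one distinct column set are rejected by the hasMultipleDistinctColumnSets
+    // check in tryConvert, so taking head loses no information.)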
+    val distinctColumnExpressions =
+      functionsWithDistinct.head.aggregateFunction.children
+    val namedDistinctColumnExpressions = distinctColumnExpressions.map {
+      case ne: NamedExpression => ne -> ne
+      case other =>
+        val withAlias = Alias(other, other.toString)()
+        other -> withAlias
+    }
+    val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap
+    val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute)
+
+    val partialAggregateExpressions = functionsWithoutDistinct.map {
+      case AggregateExpression2(aggregateFunction, mode, _) =>
+        AggregateExpression2(aggregateFunction, Partial, false)
+    }
+    val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
+      agg.aggregateFunction.bufferAttributes
+    }
+    val partialAggregate =
+      Aggregate2Sort(
+        None: Option[Seq[Expression]],
+        (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2),
+        partialAggregateExpressions,
+        partialAggregateAttributes,
+        namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes,
+        child)
+
+    // 2. Create an Aggregate Operator for partial merge aggregations.
+    val partialMergeAggregateExpressions = functionsWithoutDistinct.map {
+      case AggregateExpression2(aggregateFunction, mode, _) =>
+        AggregateExpression2(aggregateFunction, PartialMerge, false)
+    }
+    val partialMergeAggregateAttributes =
+      partialMergeAggregateExpressions.map {
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+      }
+    val partialMergeAggregate =
+      Aggregate2Sort(
+        Some(namedGroupingAttributes),
+        namedGroupingAttributes ++ distinctColumnAttributes,
+        partialMergeAggregateExpressions,
+        partialMergeAggregateAttributes,
+        namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes,
+        partialAggregate)
+
+    // 3. Create an Aggregate Operator for the final and complete aggregations.
+    val finalAggregateExpressions = functionsWithoutDistinct.map {
+      case AggregateExpression2(aggregateFunction, mode, _) =>
+        AggregateExpression2(aggregateFunction, Final, false)
+    }
+    val finalAggregateAttributes =
+      finalAggregateExpressions.map {
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+      }
+    val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
+      // The children of an AggregateFunction with the DISTINCT keyword have already
+      // been evaluated. Here, we need to replace the original children with
+      // AttributeReferences.
+      case agg @ AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+        val rewrittenAggregateFunction = aggregateFunction.transformDown {
+          case expr if distinctColumnExpressionMap.contains(expr) =>
+            distinctColumnExpressionMap(expr).toAttribute
+        }.asInstanceOf[AggregateFunction2]
+        // We rewrite the aggregate function to a non-distinct aggregation because
+        // its input will have distinct arguments.
+        val rewrittenAggregateExpression =
+          AggregateExpression2(rewrittenAggregateFunction, Complete, false)
+
+        val aggregateFunctionAttribute = aggregateFunctionMap(agg.aggregateFunction, isDistinct)
+        (rewrittenAggregateExpression -> aggregateFunctionAttribute)
+    }.unzip
+
+    val rewrittenResultExpressions = resultExpressions.map { expr =>
+      expr.transform {
+        case agg: AggregateExpression2 =>
+          aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
+        case expression =>
+          // We do not rely on direct equality here since attributes may differ
+          // cosmetically. Instead, we use semanticEquals.
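+          // (For example, a query may write "GROUP BY Key - 100" but select "kEY - 100";
+          // the two expressions are semantically equal even though their attribute
+          // references differ in capitalization.)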
+          groupExpressionMap.collectFirst {
+            case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+          }.getOrElse(expression)
+      }.asInstanceOf[NamedExpression]
+    }
+    val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort(
+      namedGroupingAttributes ++ distinctColumnAttributes,
+      namedGroupingAttributes,
+      finalAggregateExpressions,
+      finalAggregateAttributes,
+      completeAggregateExpressions,
+      completeAggregateAttributes,
+      rewrittenResultExpressions,
+      partialMergeAggregate)
+
+    finalAndCompleteAggregate :: Nil
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
new file mode 100644
index 0000000000000..6c49a906c848a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions.aggregate
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
+import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.Row
+
+/**
+ * The abstract class for implementing user-defined aggregate functions.
+ */
+abstract class UserDefinedAggregateFunction extends Serializable {
+
+  /**
+   * A [[StructType]] that represents the data types of the input arguments of this
+   * aggregate function. For example, if a [[UserDefinedAggregateFunction]] expects two
+   * input arguments with types [[DoubleType]] and [[LongType]], the returned [[StructType]]
+   * will look like
+   *
+   * ```
+   * StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType)))
+   * ```
+   *
+   * The name of a field of this [[StructType]] is only used to identify the corresponding
+   * input argument. Users can choose names to identify the input arguments.
+   */
+  def inputSchema: StructType
+
+  /**
+   * A [[StructType]] that represents the data types of values in the aggregation buffer.
+   * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
+   * (i.e. two intermediate values) with types [[DoubleType]] and [[LongType]],
+   * the returned [[StructType]] will look like
+   *
+   * ```
+   * StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType)))
+   * ```
+   *
+   * The name of a field of this [[StructType]] is only used to identify the corresponding
+   * buffer value. Users can choose names to identify the buffer values.
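+   *
+   * For instance, an average-like UDAF (a sketch mirroring the MyDoubleAvg test class added
+   * later in this patch) could declare its buffer as:
+   *
+   * ```
+   * StructType(Seq(StructField("bufferSum", DoubleType), StructField("bufferCount", LongType)))
+   * ```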
+ */ + def bufferSchema: StructType + + /** + * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]]. + */ + def returnDataType: DataType + + /** Indicates if this function is deterministic. */ + def deterministic: Boolean + + /** + * Initializes the given aggregation buffer. Initial values set by this method should satisfy + * the condition that when merging two buffers with initial values, the new buffer should + * still store initial values. + */ + def initialize(buffer: MutableAggregationBuffer): Unit + + /** Updates the given aggregation buffer `buffer` with new input data from `input`. */ + def update(buffer: MutableAggregationBuffer, input: Row): Unit + + /** Merges two aggregation buffers and stores the updated buffer values back in `buffer1`. */ + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit + + /** + * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given + * aggregation buffer. + */ + def evaluate(buffer: Row): Any +} + +private[sql] abstract class AggregationBuffer( + toCatalystConverters: Array[Any => Any], + toScalaConverters: Array[Any => Any], + bufferOffset: Int) + extends Row { + + override def length: Int = toCatalystConverters.length + + protected val offsets: Array[Int] = { + val newOffsets = new Array[Int](length) + var i = 0 + while (i < newOffsets.length) { + newOffsets(i) = bufferOffset + i + i += 1 + } + newOffsets + } +} + +/** + * A Mutable [[Row]] representing an mutable aggregation buffer. + */ +class MutableAggregationBuffer private[sql] ( + toCatalystConverters: Array[Any => Any], + toScalaConverters: Array[Any => Any], + bufferOffset: Int, + var underlyingBuffer: MutableRow) + extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) { + + override def get(i: Int): Any = { + if (i >= length || i < 0) { + throw new IllegalArgumentException( + s"Could not access ${i}th value in this buffer because it only has $length values.") + } + toScalaConverters(i)(underlyingBuffer(offsets(i))) + } + + def update(i: Int, value: Any): Unit = { + if (i >= length || i < 0) { + throw new IllegalArgumentException( + s"Could not update ${i}th value in this buffer because it only has $length values.") + } + underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value)) + } + + override def copy(): MutableAggregationBuffer = { + new MutableAggregationBuffer( + toCatalystConverters, + toScalaConverters, + bufferOffset, + underlyingBuffer) + } +} + +/** + * A [[Row]] representing an immutable aggregation buffer. + */ +class InputAggregationBuffer private[sql] ( + toCatalystConverters: Array[Any => Any], + toScalaConverters: Array[Any => Any], + bufferOffset: Int, + var underlyingInputBuffer: Row) + extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) { + + override def get(i: Int): Any = { + if (i >= length || i < 0) { + throw new IllegalArgumentException( + s"Could not access ${i}th value in this buffer because it only has $length values.") + } + toScalaConverters(i)(underlyingInputBuffer(offsets(i))) + } + + override def copy(): InputAggregationBuffer = { + new InputAggregationBuffer( + toCatalystConverters, + toScalaConverters, + bufferOffset, + underlyingInputBuffer) + } +} + +/** + * The internal wrapper used to hook a [[UserDefinedAggregateFunction]] `udaf` in the + * internal aggregation code path. 
+ * @param children + * @param udaf + */ +case class ScalaUDAF( + children: Seq[Expression], + udaf: UserDefinedAggregateFunction) + extends AggregateFunction2 with Logging { + + require( + children.length == udaf.inputSchema.length, + s"$udaf only accepts ${udaf.inputSchema.length} arguments, " + + s"but ${children.length} are provided.") + + override def nullable: Boolean = true + + override def dataType: DataType = udaf.returnDataType + + override def deterministic: Boolean = udaf.deterministic + + override val inputTypes: Seq[DataType] = udaf.inputSchema.map(_.dataType) + + override val bufferSchema: StructType = udaf.bufferSchema + + override val bufferAttributes: Seq[AttributeReference] = bufferSchema.toAttributes + + override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance()) + + val childrenSchema: StructType = { + val inputFields = children.zipWithIndex.map { + case (child, index) => + StructField(s"input$index", child.dataType, child.nullable, Metadata.empty) + } + StructType(inputFields) + } + + lazy val inputProjection = { + val inputAttributes = childrenSchema.toAttributes + log.debug( + s"Creating MutableProj: $children, inputSchema: $inputAttributes.") + try { + GenerateMutableProjection.generate(children, inputAttributes)() + } catch { + case e: Exception => + log.error("Failed to generate mutable projection, fallback to interpreted", e) + new InterpretedMutableProjection(children, inputAttributes) + } + } + + val inputToScalaConverters: Any => Any = + CatalystTypeConverters.createToScalaConverter(childrenSchema) + + val bufferValuesToCatalystConverters: Array[Any => Any] = bufferSchema.fields.map { field => + CatalystTypeConverters.createToCatalystConverter(field.dataType) + } + + val bufferValuesToScalaConverters: Array[Any => Any] = bufferSchema.fields.map { field => + CatalystTypeConverters.createToScalaConverter(field.dataType) + } + + lazy val inputAggregateBuffer: InputAggregationBuffer = + new InputAggregationBuffer( + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + bufferOffset, + null) + + lazy val mutableAggregateBuffer: MutableAggregationBuffer = + new MutableAggregationBuffer( + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + bufferOffset, + null) + + + override def initialize(buffer: MutableRow): Unit = { + mutableAggregateBuffer.underlyingBuffer = buffer + + udaf.initialize(mutableAggregateBuffer) + } + + override def update(buffer: MutableRow, input: InternalRow): Unit = { + mutableAggregateBuffer.underlyingBuffer = buffer + + udaf.update( + mutableAggregateBuffer, + inputToScalaConverters(inputProjection(input)).asInstanceOf[Row]) + } + + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + mutableAggregateBuffer.underlyingBuffer = buffer1 + inputAggregateBuffer.underlyingInputBuffer = buffer2 + + udaf.merge(mutableAggregateBuffer, inputAggregateBuffer) + } + + override def eval(buffer: InternalRow = null): Any = { + inputAggregateBuffer.underlyingInputBuffer = buffer + + udaf.evaluate(inputAggregateBuffer) + } + + override def toString: String = { + s"""${udaf.getClass.getSimpleName}(${children.mkString(",")})""" + } + + override def nodeName: String = udaf.getClass.getSimpleName +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 28159cbd5ab96..bfeecbe8b2ab5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2420,7 +2420,7 @@ object functions {
    * @since 1.5.0
    */
   def callUDF(udfName: String, cols: Column*): Column = {
-    UnresolvedFunction(udfName, cols.map(_.expr))
+    UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false)
   }
 
   /**
@@ -2449,7 +2449,7 @@ object functions {
       exprs(i) = cols(i).expr
       i += 1
     }
-    UnresolvedFunction(udfName, exprs)
+    UnresolvedFunction(udfName, exprs, isDistinct = false)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index beee10173fbc4..ab8dce603c117 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -23,6 +23,7 @@ import java.sql.Timestamp
 
 import org.apache.spark.sql.catalyst.DefaultParserDialect
 import org.apache.spark.sql.catalyst.errors.DialectException
+import org.apache.spark.sql.execution.aggregate.Aggregate2Sort
 import org.apache.spark.sql.execution.GeneratedAggregate
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.TestData._
@@ -204,6 +205,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
       var hasGeneratedAgg = false
       df.queryExecution.executedPlan.foreach {
         case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true
+        case newAggregate: Aggregate2Sort => hasGeneratedAgg = true
         case _ =>
       }
       if (!hasGeneratedAgg) {
@@ -285,7 +287,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
       // Aggregate with Code generation handling all null values
       testCodeGen(
         "SELECT sum('a'), avg('a'), count(null) FROM testData",
-        Row(0, null, 0) :: Nil)
+        Row(null, null, 0) :: Nil)
     } finally {
       sqlContext.dropTempTable("testData3x")
       sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 3dd24130af81a..3d71deb13e884 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.TestSQLContext._
@@ -30,6 +31,20 @@ import org.apache.spark.sql.{Row, SQLConf, execution}
 
 class PlannerSuite extends SparkFunSuite {
 
+  private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
+    val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption)
+    val planned =
+      plannedOption.getOrElse(
+        fail(s"Could not plan aggregation query $query. Is it an aggregation query?"))
+    val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
+
+    // For the new aggregation code path, there will be three aggregate operators for
+    // distinct aggregations.
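+    // (The three operators are the partial, partial-merge, and final aggregates produced
+    // by planAggregateWithOneDistinct; the old code path produces two, which is why the
+    // assert below accepts either count.)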
+    assert(
+      aggregations.size == 2 || aggregations.size == 3,
+      s"The plan of query $query does not have partial aggregations.")
+  }
+
   test("unions are collapsed") {
     val query = testData.unionAll(testData).unionAll(testData).logicalPlan
     val planned = BasicOperators(query).head
@@ -42,23 +57,18 @@ class PlannerSuite extends SparkFunSuite {
 
   test("count is partially aggregated") {
     val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
-    val planned = HashAggregation(query).head
-    val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
-
-    assert(aggregations.size === 2)
+    testPartialAggregationPlan(query)
   }
 
   test("count distinct is partially aggregated") {
     val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
-    val planned = HashAggregation(query)
-    assert(planned.nonEmpty)
+    testPartialAggregationPlan(query)
   }
 
   test("mixed aggregates are partially aggregated") {
     val query =
       testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
-    val planned = HashAggregation(query)
-    assert(planned.nonEmpty)
+    testPartialAggregationPlan(query)
   }
 
   test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
index 31a49a3683338..24a758f53170a 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
@@ -833,6 +833,7 @@ abstract class HiveWindowFunctionQueryFileBaseSuite
     "windowing_adjust_rowcontainer_sz"
   )
 
+  // Only run those query tests in the realWhiteList (do not try other ignored query files).
   override def testCases: Seq[(String, File)] = super.testCases.filter {
     case (name, _) => realWhiteList.contains(name)
   }
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
index f458567e5d7ea..1fe4fe9629c02 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.hive.execution
 
+import java.io.File
+
 import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.hive.test.TestHive
 
@@ -159,4 +161,9 @@ class SortMergeCompatibilitySuite extends HiveCompatibilitySuite {
     "join_reorder4",
     "join_star"
   )
+
+  // Only run those query tests in the realWhiteList (do not try other ignored query files).
+ override def testCases: Seq[(String, File)] = super.testCases.filter { + case (name, _) => realWhiteList.contains(name) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index cec7685bb6859..4cdb83c5116f9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -451,6 +451,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { DataSinks, Scripts, HashAggregation, + Aggregation, LeftSemiJoin, HashJoin, BasicOperators, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index f5574509b0b38..8518e333e8058 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1464,9 +1464,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C /* UDFs - Must be last otherwise will preempt built in functions */ case Token("TOK_FUNCTION", Token(name, Nil) :: args) => - UnresolvedFunction(name, args.map(nodeToExpr)) + UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false) + // Aggregate function with DISTINCT keyword. + case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) => + UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true) case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => - UnresolvedFunction(name, UnresolvedStar(None) :: Nil) + UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false) /* Literals */ case Token("TOK_NULL", Nil) => Literal.create(null, NullType) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 4d23c7035c03d..3259b50acc765 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -409,7 +409,7 @@ private[hive] case class HiveWindowFunction( private[hive] case class HiveGenericUDAF( funcWrapper: HiveFunctionWrapper, - children: Seq[Expression]) extends AggregateExpression + children: Seq[Expression]) extends AggregateExpression1 with HiveInspectors { type UDFType = AbstractGenericUDAFResolver @@ -441,7 +441,7 @@ private[hive] case class HiveGenericUDAF( /** It is used as a wrapper for the hive functions which uses UDAF interface */ private[hive] case class HiveUDAF( funcWrapper: HiveFunctionWrapper, - children: Seq[Expression]) extends AggregateExpression + children: Seq[Expression]) extends AggregateExpression1 with HiveInspectors { type UDFType = UDAF @@ -550,9 +550,9 @@ private[hive] case class HiveGenericUDTF( private[hive] case class HiveUDAFFunction( funcWrapper: HiveFunctionWrapper, exprs: Seq[Expression], - base: AggregateExpression, + base: AggregateExpression1, isUDAFBridgeRequired: Boolean = false) - extends AggregateFunction + extends AggregateFunction1 with HiveInspectors { def this() = this(null, null, null) diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java new file mode 100644 index 0000000000000..5c9d0e97a99c6 --- /dev/null +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.hive.aggregate;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+public class MyDoubleAvg extends UserDefinedAggregateFunction {
+
+  private StructType _inputDataType;
+
+  private StructType _bufferSchema;
+
+  private DataType _returnDataType;
+
+  public MyDoubleAvg() {
+    List<StructField> inputFields = new ArrayList<StructField>();
+    inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
+    _inputDataType = DataTypes.createStructType(inputFields);
+
+    List<StructField> bufferFields = new ArrayList<StructField>();
+    bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true));
+    bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true));
+    _bufferSchema = DataTypes.createStructType(bufferFields);
+
+    _returnDataType = DataTypes.DoubleType;
+  }
+
+  @Override public StructType inputSchema() {
+    return _inputDataType;
+  }
+
+  @Override public StructType bufferSchema() {
+    return _bufferSchema;
+  }
+
+  @Override public DataType returnDataType() {
+    return _returnDataType;
+  }
+
+  @Override public boolean deterministic() {
+    return true;
+  }
+
+  @Override public void initialize(MutableAggregationBuffer buffer) {
+    buffer.update(0, null);
+    buffer.update(1, 0L);
+  }
+
+  @Override public void update(MutableAggregationBuffer buffer, Row input) {
+    if (!input.isNullAt(0)) {
+      if (buffer.isNullAt(0)) {
+        buffer.update(0, input.getDouble(0));
+        buffer.update(1, 1L);
+      } else {
+        Double newValue = input.getDouble(0) + buffer.getDouble(0);
+        buffer.update(0, newValue);
+        buffer.update(1, buffer.getLong(1) + 1L);
+      }
+    }
+  }
+
+  @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+    if (!buffer2.isNullAt(0)) {
+      if (buffer1.isNullAt(0)) {
+        buffer1.update(0, buffer2.getDouble(0));
+        buffer1.update(1, buffer2.getLong(1));
+      } else {
+        Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
+        buffer1.update(0, newValue);
+        buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1));
+      }
+    }
+  }
+
+  @Override public Object evaluate(Row buffer) {
+    if (buffer.isNullAt(0)) {
+      return null;
+    } else {
+      // Intentionally shifts the average by 100.0 so tests can tell this UDAF's result
+      // apart from the built-in AVG.
+      return buffer.getDouble(0) / buffer.getLong(1) + 100.0;
+    }
+  }
+}
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
new file mode 100644
index 0000000000000..1d4587a27c787
--- /dev/null
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.hive.aggregate;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.Row;
+
+public class MyDoubleSum extends UserDefinedAggregateFunction {
+
+  private StructType _inputDataType;
+
+  private StructType _bufferSchema;
+
+  private DataType _returnDataType;
+
+  public MyDoubleSum() {
+    List<StructField> inputFields = new ArrayList<StructField>();
+    inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
+    _inputDataType = DataTypes.createStructType(inputFields);
+
+    List<StructField> bufferFields = new ArrayList<StructField>();
+    bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true));
+    _bufferSchema = DataTypes.createStructType(bufferFields);
+
+    _returnDataType = DataTypes.DoubleType;
+  }
+
+  @Override public StructType inputSchema() {
+    return _inputDataType;
+  }
+
+  @Override public StructType bufferSchema() {
+    return _bufferSchema;
+  }
+
+  @Override public DataType returnDataType() {
+    return _returnDataType;
+  }
+
+  @Override public boolean deterministic() {
+    return true;
+  }
+
+  @Override public void initialize(MutableAggregationBuffer buffer) {
+    buffer.update(0, null);
+  }
+
+  @Override public void update(MutableAggregationBuffer buffer, Row input) {
+    if (!input.isNullAt(0)) {
+      if (buffer.isNullAt(0)) {
+        buffer.update(0, input.getDouble(0));
+      } else {
+        Double newValue = input.getDouble(0) + buffer.getDouble(0);
+        buffer.update(0, newValue);
+      }
+    }
+  }
+
+  @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+    if (!buffer2.isNullAt(0)) {
+      if (buffer1.isNullAt(0)) {
+        buffer1.update(0, buffer2.getDouble(0));
+      } else {
+        Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
+        buffer1.update(0, newValue);
+      }
+    }
+  }
+
+  @Override public Object evaluate(Row buffer) {
+    if (buffer.isNullAt(0)) {
+      return null;
+    } else {
+      return buffer.getDouble(0);
+    }
+  }
+}
diff --git a/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++
b/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 b/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 new file mode 100644 index 0000000000000..44b2a42cc26c5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 @@ -0,0 +1 @@ +unhex(str) - Converts hexadecimal argument to binary diff --git a/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 b/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 new file mode 100644 index 0000000000000..97af3b812a429 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 @@ -0,0 +1,14 @@ +unhex(str) - Converts hexadecimal argument to binary +Performs the inverse operation of HEX(str). That is, it interprets +each pair of hexadecimal digits in the argument as a number and +converts it to the byte representation of the number. The +resulting characters are returned as a binary string. + +Example: +> SELECT DECODE(UNHEX('4D7953514C'), 'UTF-8') from src limit 1; +'MySQL' + +The characters in the argument string must be legal hexadecimal +digits: '0' .. '9', 'A' .. 'F', 'a' .. 'f'. If UNHEX() encounters +any nonhexadecimal digits in the argument, it returns NULL. Also, +if there are an odd number of characters a leading 0 is appended. diff --git a/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e b/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e new file mode 100644 index 0000000000000..b4a6f2b692227 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e @@ -0,0 +1 @@ +MySQL 1267 a -4 diff --git a/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 b/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 new file mode 100644 index 0000000000000..3a67adaf0a9a8 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 @@ -0,0 +1 @@ +NULL NULL NULL diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala new file mode 100644 index 0000000000000..0375eb79add95 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -0,0 +1,507 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.execution.aggregate.Aggregate2Sort +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.scalatest.BeforeAndAfterAll +import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} + +class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { + + override val sqlContext = TestHive + import sqlContext.implicits._ + + var originalUseAggregate2: Boolean = _ + + override def beforeAll(): Unit = { + originalUseAggregate2 = sqlContext.conf.useSqlAggregate2 + sqlContext.sql("set spark.sql.useAggregate2=true") + val data1 = Seq[(Integer, Integer)]( + (1, 10), + (null, -60), + (1, 20), + (1, 30), + (2, 0), + (null, -10), + (2, -1), + (2, null), + (2, null), + (null, 100), + (3, null), + (null, null), + (3, null)).toDF("key", "value") + data1.write.saveAsTable("agg1") + + val data2 = Seq[(Integer, Integer, Integer)]( + (1, 10, -10), + (null, -60, 60), + (1, 30, -30), + (1, 30, 30), + (2, 1, 1), + (null, -10, 10), + (2, -1, null), + (2, 1, 1), + (2, null, 1), + (null, 100, -10), + (3, null, 3), + (null, null, null), + (3, null, null)).toDF("key", "value1", "value2") + data2.write.saveAsTable("agg2") + + val emptyDF = sqlContext.createDataFrame( + sqlContext.sparkContext.emptyRDD[Row], + StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil)) + emptyDF.registerTempTable("emptyTable") + + // Register UDAFs + sqlContext.udaf.register("mydoublesum", new MyDoubleSum) + sqlContext.udaf.register("mydoubleavg", new MyDoubleAvg) + } + + override def afterAll(): Unit = { + sqlContext.sql("DROP TABLE IF EXISTS agg1") + sqlContext.sql("DROP TABLE IF EXISTS agg2") + sqlContext.dropTempTable("emptyTable") + sqlContext.sql(s"set spark.sql.useAggregate2=$originalUseAggregate2") + } + + test("empty table") { + // If there is no GROUP BY clause and the table is empty, we will generate a single row. + checkAnswer( + sqlContext.sql( + """ + |SELECT + | AVG(value), + | COUNT(*), + | COUNT(key), + | COUNT(value), + | FIRST(key), + | LAST(value), + | MAX(key), + | MIN(value), + | SUM(key) + |FROM emptyTable + """.stripMargin), + Row(null, 0, 0, 0, null, null, null, null, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | AVG(value), + | COUNT(*), + | COUNT(key), + | COUNT(value), + | FIRST(key), + | LAST(value), + | MAX(key), + | MIN(value), + | SUM(key), + | COUNT(DISTINCT value) + |FROM emptyTable + """.stripMargin), + Row(null, 0, 0, 0, null, null, null, null, null, 0) :: Nil) + + // If there is a GROUP BY clause and the table is empty, there is no output. 
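+    // (Aggregation buffers are created per group, and an empty input produces no groups,
+    // so no rows are emitted.)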
+ checkAnswer( + sqlContext.sql( + """ + |SELECT + | AVG(value), + | COUNT(*), + | COUNT(value), + | FIRST(value), + | LAST(value), + | MAX(value), + | MIN(value), + | SUM(value), + | COUNT(DISTINCT value) + |FROM emptyTable + |GROUP BY key + """.stripMargin), + Nil) + } + + test("only do grouping") { + checkAnswer( + sqlContext.sql( + """ + |SELECT key + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT DISTINCT value1, key + |FROM agg2 + """.stripMargin), + Row(10, 1) :: + Row(-60, null) :: + Row(30, 1) :: + Row(1, 2) :: + Row(-10, null) :: + Row(-1, 2) :: + Row(null, 2) :: + Row(100, null) :: + Row(null, 3) :: + Row(null, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT value1, key + |FROM agg2 + |GROUP BY key, value1 + """.stripMargin), + Row(10, 1) :: + Row(-60, null) :: + Row(30, 1) :: + Row(1, 2) :: + Row(-10, null) :: + Row(-1, 2) :: + Row(null, 2) :: + Row(100, null) :: + Row(null, 3) :: + Row(null, null) :: Nil) + } + + test("case in-sensitive resolution") { + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(value), kEY - 100 + |FROM agg1 + |GROUP BY Key - 100 + """.stripMargin), + Row(20.0, -99) :: Row(-0.5, -98) :: Row(null, -97) :: Row(10.0, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT sum(distinct value1), kEY - 100, count(distinct value1) + |FROM agg2 + |GROUP BY Key - 100 + """.stripMargin), + Row(40, -99, 2) :: Row(0, -98, 2) :: Row(null, -97, 0) :: Row(30, null, 3) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT valUe * key - 100 + |FROM agg1 + |GROUP BY vAlue * keY - 100 + """.stripMargin), + Row(-90) :: + Row(-80) :: + Row(-70) :: + Row(-100) :: + Row(-102) :: + Row(null) :: Nil) + } + + test("test average no key in output") { + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(-0.5) :: Row(20.0) :: Row(null) :: Row(10.0) :: Nil) + } + + test("test average") { + checkAnswer( + sqlContext.sql( + """ + |SELECT key, avg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(value), key + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(20.0, 1) :: Row(-0.5, 2) :: Row(null, 3) :: Row(10.0, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(value) + 1.5, key + 10 + |FROM agg1 + |GROUP BY key + 10 + """.stripMargin), + Row(21.5, 11) :: Row(1.0, 12) :: Row(null, 13) :: Row(11.5, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(value) FROM agg1 + """.stripMargin), + Row(11.125) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(null) + """.stripMargin), + Row(null) :: Nil) + } + + test("udaf") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | mydoublesum(value + 1.5 * key), + | mydoubleavg(value), + | avg(value - key), + | mydoublesum(value - 1.5 * key), + | avg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(1, 64.5, 120.0, 19.0, 55.5, 20.0) :: + Row(2, 5.0, 99.5, -2.5, -7.0, -0.5) :: + Row(3, null, null, null, null, null) :: + Row(null, null, 110.0, null, null, 10.0) :: Nil) + } + + test("non-AlgebraicAggregate aggreguate function") { + checkAnswer( + sqlContext.sql( + """ + |SELECT mydoublesum(value), key + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Row(30.0, null) :: Nil) + + checkAnswer( + sqlContext.sql( 
+ """ + |SELECT mydoublesum(value) FROM agg1 + """.stripMargin), + Row(89.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT mydoublesum(null) + """.stripMargin), + Row(null) :: Nil) + } + + test("non-AlgebraicAggregate and AlgebraicAggregate aggreguate function") { + checkAnswer( + sqlContext.sql( + """ + |SELECT mydoublesum(value), key, avg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(60.0, 1, 20.0) :: + Row(-1.0, 2, -0.5) :: + Row(null, 3, null) :: + Row(30.0, null, 10.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | mydoublesum(value + 1.5 * key), + | avg(value - key), + | key, + | mydoublesum(value - 1.5 * key), + | avg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(64.5, 19.0, 1, 55.5, 20.0) :: + Row(5.0, -2.5, 2, -7.0, -0.5) :: + Row(null, null, 3, null, null) :: + Row(null, null, null, null, 10.0) :: Nil) + } + + test("single distinct column set") { + // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword. + checkAnswer( + sqlContext.sql( + """ + |SELECT + | min(distinct value1), + | sum(distinct value1), + | avg(value1), + | avg(value2), + | max(distinct value1) + |FROM agg2 + """.stripMargin), + Row(-60, 70.0, 101.0/9.0, 5.6, 100.0)) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | mydoubleavg(distinct value1), + | avg(value1), + | avg(value2), + | key, + | mydoubleavg(value1 - 1), + | mydoubleavg(distinct value1) * 0.1, + | avg(value1 + value2) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) :: + Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) :: + Row(null, null, 3.0, 3, null, null, null) :: + Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | mydoubleavg(distinct value1), + | mydoublesum(value2), + | mydoublesum(distinct value1), + | mydoubleavg(distinct value1), + | mydoubleavg(value1) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) :: + Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) :: + Row(3, null, 3.0, null, null, null) :: + Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) + } + + test("test count") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value2), + | value1, + | count(*), + | count(1), + | key + |FROM agg2 + |GROUP BY key, value1 + """.stripMargin), + Row(1, 10, 1, 1, 1) :: + Row(1, -60, 1, 1, null) :: + Row(2, 30, 2, 2, 1) :: + Row(2, 1, 2, 2, 2) :: + Row(1, -10, 1, 1, null) :: + Row(0, -1, 1, 1, 2) :: + Row(1, null, 1, 1, 2) :: + Row(1, 100, 1, 1, null) :: + Row(1, null, 2, 2, 3) :: + Row(0, null, 1, 1, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value2), + | value1, + | count(*), + | count(1), + | key, + | count(DISTINCT abs(value2)) + |FROM agg2 + |GROUP BY key, value1 + """.stripMargin), + Row(1, 10, 1, 1, 1, 1) :: + Row(1, -60, 1, 1, null, 1) :: + Row(2, 30, 2, 2, 1, 1) :: + Row(2, 1, 2, 2, 2, 1) :: + Row(1, -10, 1, 1, null, 1) :: + Row(0, -1, 1, 1, 2, 0) :: + Row(1, null, 1, 1, 2, 1) :: + Row(1, 100, 1, 1, null, 1) :: + Row(1, null, 2, 2, 3, 1) :: + Row(0, null, 1, 1, null, 0) :: Nil) + } + + test("error handling") { + sqlContext.sql(s"set spark.sql.useAggregate2=false") + var errorMessage = intercept[AnalysisException] { + sqlContext.sql( + """ + |SELECT + | key, + | sum(value + 1.5 * key), + | mydoublesum(value), + | mydoubleavg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin).collect() + }.getMessage + 
assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) + + // TODO: once we support Hive UDAF in the new interface, + // we can remove the following two tests. + sqlContext.sql(s"set spark.sql.useAggregate2=true") + errorMessage = intercept[AnalysisException] { + sqlContext.sql( + """ + |SELECT + | key, + | mydoublesum(value + 1.5 * key), + | stddev_samp(value) + |FROM agg1 + |GROUP BY key + """.stripMargin).collect() + }.getMessage + assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) + + // This will fall back to the old aggregate + val newAggregateOperators = sqlContext.sql( + """ + |SELECT + | key, + | sum(value + 1.5 * key), + | stddev_samp(value) + |FROM agg1 + |GROUP BY key + """.stripMargin).queryExecution.executedPlan.collect { + case agg: Aggregate2Sort => agg + } + val message = + "We should fallback to the old aggregation code path if there is any aggregate function " + + "that cannot be converted to the new interface." + assert(newAggregateOperators.isEmpty, message) + + sqlContext.sql(s"set spark.sql.useAggregate2=true") + } +} From b55a36bc30a628d76baa721d38789fc219eccc27 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 22 Jul 2015 09:32:42 -0700 Subject: [PATCH 10/14] [SPARK-9254] [BUILD] [HOTFIX] sbt-launch-lib.bash should support HTTP/HTTPS redirection Target file(s) can be hosted on CDN nodes. HTTP/HTTPS redirection must be supported to download these files. Author: Cheng Lian Closes #7597 from liancheng/spark-9254 and squashes the following commits: fd266ca [Cheng Lian] Uses `--fail' to make curl return non-zero value and remove garbage output when the download fails a7cbfb3 [Cheng Lian] Supports HTTP/HTTPS redirection --- build/sbt-launch-lib.bash | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash index 504be48b358fa..7930a38b9674a 100755 --- a/build/sbt-launch-lib.bash +++ b/build/sbt-launch-lib.bash @@ -51,9 +51,13 @@ acquire_sbt_jar () { printf "Attempting to fetch sbt\n" JAR_DL="${JAR}.part" if [ $(command -v curl) ]; then - (curl --silent ${URL1} > "${JAR_DL}" || curl --silent ${URL2} > "${JAR_DL}") && mv "${JAR_DL}" "${JAR}" + (curl --fail --location --silent ${URL1} > "${JAR_DL}" ||\ + (rm -f "${JAR_DL}" && curl --fail --location --silent ${URL2} > "${JAR_DL}")) &&\ + mv "${JAR_DL}" "${JAR}" elif [ $(command -v wget) ]; then - (wget --quiet ${URL1} -O "${JAR_DL}" || wget --quiet ${URL2} -O "${JAR_DL}") && mv "${JAR_DL}" "${JAR}" + (wget --quiet ${URL1} -O "${JAR_DL}" ||\ + (rm -f "${JAR_DL}" && wget --quiet ${URL2} -O "${JAR_DL}")) &&\ + mv "${JAR_DL}" "${JAR}" else printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n" exit -1 From 76520955fddbda87a5c53d0a394dedc91dce67e8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 22 Jul 2015 11:45:51 -0700 Subject: [PATCH 11/14] [SPARK-9082] [SQL] Filter using non-deterministic expressions should not be pushed down Author: Wenchen Fan Closes #7446 from cloud-fan/filter and squashes the following commits: 330021e [Wenchen Fan] add exists to tree node 2cab68c [Wenchen Fan] more enhance 949be07 [Wenchen Fan] push down part of predicate if possible 3912f84 [Wenchen Fan] address comments 8ce15ca [Wenchen Fan] fix bug 557158e [Wenchen Fan] Filter using non-deterministic expressions should not be pushed down --- .../sql/catalyst/optimizer/Optimizer.scala | 50 +++++++++++++++---- 
.../optimizer/FilterPushdownSuite.scala | 45 ++++++++++++++++- 2 files changed, 84 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e42f0b9a247e3..d2db3dd3d078e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -541,20 +541,50 @@ object SimplifyFilters extends Rule[LogicalPlan] { * * This heuristic is valid assuming the expression evaluation cost is minimal. */ -object PushPredicateThroughProject extends Rule[LogicalPlan] { +object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case filter @ Filter(condition, project @ Project(fields, grandChild)) => - val sourceAliases = fields.collect { case a @ Alias(c, _) => - (a.toAttribute: Attribute) -> c - }.toMap - project.copy(child = filter.copy( - replaceAlias(condition, sourceAliases), - grandChild)) + // Create a map of Aliases to their values from the child projection. + // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b). + val aliasMap = AttributeMap(fields.collect { + case a: Alias => (a.toAttribute, a.child) + }) + + // Split the condition into small conditions by `And`, so that we can push down part of this + // condition without nondeterministic expressions. + val andConditions = splitConjunctivePredicates(condition) + val nondeterministicConditions = andConditions.filter(hasNondeterministic(_, aliasMap)) + + // If there is no nondeterministic conditions, push down the whole condition. + if (nondeterministicConditions.isEmpty) { + project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) + } else { + // If they are all nondeterministic conditions, leave it un-changed. + if (nondeterministicConditions.length == andConditions.length) { + filter + } else { + val deterministicConditions = andConditions.filterNot(hasNondeterministic(_, aliasMap)) + // Push down the small conditions without nondeterministic expressions. + val pushedCondition = deterministicConditions.map(replaceAlias(_, aliasMap)).reduce(And) + Filter(nondeterministicConditions.reduce(And), + project.copy(child = Filter(pushedCondition, grandChild))) + } + } + } + + private def hasNondeterministic( + condition: Expression, + sourceAliases: AttributeMap[Expression]) = { + condition.collect { + case a: Attribute if sourceAliases.contains(a) => sourceAliases(a) + }.exists(!_.deterministic) } - private def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]) = { - condition transform { - case a: AttributeReference => sourceAliases.getOrElse(a, a) + // Substitute any attributes that are produced by the child projection, so that we safely + // eliminate it. 
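+  // e.g. with child projection 'SELECT a + b AS c, d' and pushed condition 'c > 5',
+  // the rewritten condition is 'a + b > 5', which references only grandChild output.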
+ private def replaceAlias(condition: Expression, sourceAliases: AttributeMap[Expression]) = { + condition.transform { + case a: Attribute => sourceAliases.getOrElse(a, a) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index dc28b3ffb59ee..0f1fde2fb0f67 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries -import org.apache.spark.sql.catalyst.expressions.{SortOrder, Ascending, Count, Explode} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.rules._ @@ -146,6 +146,49 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("nondeterministic: can't push down filter through project") { + val originalQuery = testRelation + .select(Rand(10).as('rand), 'a) + .where('rand > 5 || 'a > 5) + .analyze + + val optimized = Optimize.execute(originalQuery) + + comparePlans(optimized, originalQuery) + } + + test("nondeterministic: push down part of filter through project") { + val originalQuery = testRelation + .select(Rand(10).as('rand), 'a) + .where('rand > 5 && 'a > 5) + .analyze + + val optimized = Optimize.execute(originalQuery) + + val correctAnswer = testRelation + .where('a > 5) + .select(Rand(10).as('rand), 'a) + .where('rand > 5) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("nondeterministic: push down filter through project") { + val originalQuery = testRelation + .select(Rand(10).as('rand), 'a) + .where('a > 5 && 'a < 10) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = testRelation + .where('a > 5 && 'a < 10) + .select(Rand(10).as('rand), 'a) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("filters: combines filters") { val originalQuery = testRelation .select('a) From 86f80e2b4759e574fe3eb91695f81b644db87242 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 22 Jul 2015 12:19:59 -0700 Subject: [PATCH 12/14] [SPARK-9165] [SQL] codegen for CreateArray, CreateStruct and CreateNamedStruct JIRA: https://issues.apache.org/jira/browse/SPARK-9165 Author: Yijie Shen Closes #7537 from yjshen/array_struct_codegen and squashes the following commits: 3a6dce6 [Yijie Shen] use infix notion in createArray test 5e90f0a [Yijie Shen] resolve comments: classOf 39cefb8 [Yijie Shen] codegen for createArray createStruct & createNamedStruct --- .../expressions/complexTypeCreator.scala | 65 +++++++++++++++++-- .../expressions/ComplexTypeSuite.scala | 16 +++++ 2 files changed, 76 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index f9fd04c02aaef..20b1eaab8e303 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.mutable + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** * Returns an Array containing the evaluation of all children expressions. */ -case class CreateArray(children: Seq[Expression]) extends Expression with CodegenFallback { +case class CreateArray(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) @@ -45,14 +47,31 @@ case class CreateArray(children: Seq[Expression]) extends Expression with Codege children.map(_.eval(input)) } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName + s""" + boolean ${ev.isNull} = false; + $arraySeqClass ${ev.primitive} = new $arraySeqClass(${children.size}); + """ + + children.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ + }.mkString("\n") + } + override def prettyName: String = "array" } /** * Returns a Row containing the evaluation of all children expressions. - * TODO: [[CreateStruct]] does not support codegen. */ -case class CreateStruct(children: Seq[Expression]) extends Expression with CodegenFallback { +case class CreateStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) @@ -76,6 +95,24 @@ case class CreateStruct(children: Seq[Expression]) extends Expression with Codeg InternalRow(children.map(_.eval(input)): _*) } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val rowClass = classOf[GenericMutableRow].getName + s""" + boolean ${ev.isNull} = false; + final $rowClass ${ev.primitive} = new $rowClass(${children.size}); + """ + + children.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ + }.mkString("\n") + } + override def prettyName: String = "struct" } @@ -84,7 +121,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression with Codeg * * @param children Seq(name1, val1, name2, val2, ...) 
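 *                 e.g. named_struct('a', 1, 'b', 2) yields a struct with fields a = 1 and b = 2.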
*/ -case class CreateNamedStruct(children: Seq[Expression]) extends Expression with CodegenFallback { +case class CreateNamedStruct(children: Seq[Expression]) extends Expression { private lazy val (nameExprs, valExprs) = children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip @@ -122,5 +159,23 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression with InternalRow(valExprs.map(_.eval(input)): _*) } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val rowClass = classOf[GenericMutableRow].getName + s""" + boolean ${ev.isNull} = false; + final $rowClass ${ev.primitive} = new $rowClass(${valExprs.size}); + """ + + valExprs.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ + }.mkString("\n") + } + override def prettyName: String = "named_struct" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index e3042143632aa..a8aee8f634e03 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -117,6 +117,22 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null) } + test("CreateArray") { + val intSeq = Seq(5, 10, 15, 20, 25) + val longSeq = intSeq.map(_.toLong) + val strSeq = intSeq.map(_.toString) + checkEvaluation(CreateArray(intSeq.map(Literal(_))), intSeq, EmptyRow) + checkEvaluation(CreateArray(longSeq.map(Literal(_))), longSeq, EmptyRow) + checkEvaluation(CreateArray(strSeq.map(Literal(_))), strSeq, EmptyRow) + + val intWithNull = intSeq.map(Literal(_)) :+ Literal.create(null, IntegerType) + val longWithNull = longSeq.map(Literal(_)) :+ Literal.create(null, LongType) + val strWithNull = strSeq.map(Literal(_)) :+ Literal.create(null, StringType) + checkEvaluation(CreateArray(intWithNull), intSeq :+ null, EmptyRow) + checkEvaluation(CreateArray(longWithNull), longSeq :+ null, EmptyRow) + checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow) + } + test("CreateStruct") { val row = create_row(1, 2, 3) val c1 = 'a.int.at(0) From e0b7ba59a1ace9b78a1ad6f3f07fe153db20b52c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 22 Jul 2015 13:02:43 -0700 Subject: [PATCH 13/14] [SPARK-9024] Unsafe HashJoin/HashOuterJoin/HashSemiJoin This PR introduce unsafe version (using UnsafeRow) of HashJoin, HashOuterJoin and HashSemiJoin, including the broadcast one and shuffle one (except FullOuterJoin, which is better to be implemented using SortMergeJoin). It use HashMap to store UnsafeRow right now, will change to use BytesToBytesMap for better performance (in another PR). 
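As a rough sketch of the build side of this approach (illustrative only: groupRows and
its keyOf/copy parameters are hypothetical stand-ins for UnsafeHashedRelation.apply,
UnsafeProjection, and UnsafeRow.copy(), not code from this patch):

    import java.util.{HashMap => JavaHashMap}
    import scala.collection.mutable.ArrayBuffer

    // Group build-side rows by their join key. With UnsafeRow keys, hashCode/equals
    // operate on the row's underlying bytes, so grouping needs no deserialization.
    // Keys and rows are copied because an input iterator may reuse a backing buffer.
    def groupRows[R](rows: Iterator[R], keyOf: R => R, copy: R => R): JavaHashMap[R, ArrayBuffer[R]] = {
      val table = new JavaHashMap[R, ArrayBuffer[R]]()
      for (row <- rows) {
        val key = keyOf(row)
        var bucket = table.get(key)
        if (bucket == null) {
          bucket = new ArrayBuffer[R]()
          table.put(copy(key), bucket)
        }
        bucket += copy(row)
      }
      table
    }

The per-entry copies mirror the rowKey.copy() and unsafeRow.copy() calls in
UnsafeHashedRelation.apply in the diff below.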
Author: Davies Liu Closes #7480 from davies/unsafe_join and squashes the following commits: 6294b1e [Davies Liu] fix projection 10583f1 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join dede020 [Davies Liu] fix test 84c9807 [Davies Liu] address comments a05b4f6 [Davies Liu] support UnsafeRow in LeftSemiJoinBNL and BroadcastNestedLoopJoin 611d2ed [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join 9481ae8 [Davies Liu] return UnsafeRow after join() ca2b40f [Davies Liu] revert unrelated change 68f5cd9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join 0f4380d [Davies Liu] ada a comment 69e38f5 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join 1a40f02 [Davies Liu] refactor ab1690f [Davies Liu] address comments 60371f2 [Davies Liu] use UnsafeRow in SemiJoin a6c0b7d [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join 184b852 [Davies Liu] fix style 6acbb11 [Davies Liu] fix tests 95d0762 [Davies Liu] remove println bea4a50 [Davies Liu] Unsafe HashJoin --- .../sql/catalyst/expressions/UnsafeRow.java | 50 ++++++++++- .../execution/UnsafeExternalRowSorter.java | 10 +-- .../catalyst/expressions/BoundAttribute.scala | 19 ++++- .../sql/catalyst/expressions/Projection.scala | 34 +++++++- .../execution/joins/BroadcastHashJoin.scala | 2 +- .../joins/BroadcastHashOuterJoin.scala | 32 ++----- .../joins/BroadcastLeftSemiJoinHash.scala | 5 +- .../joins/BroadcastNestedLoopJoin.scala | 37 +++++--- .../spark/sql/execution/joins/HashJoin.scala | 43 ++++++++-- .../sql/execution/joins/HashOuterJoin.scala | 82 +++++++++++++++--- .../sql/execution/joins/HashSemiJoin.scala | 74 ++++++++++------ .../sql/execution/joins/HashedRelation.scala | 85 ++++++++++++++++++- .../sql/execution/joins/LeftSemiJoinBNL.scala | 3 + .../execution/joins/LeftSemiJoinHash.scala | 4 +- .../execution/joins/ShuffledHashJoin.scala | 2 +- .../joins/ShuffledHashOuterJoin.scala | 13 +-- .../sql/execution/rowFormatConverters.scala | 21 +++-- .../org/apache/spark/sql/UnsafeRowSuite.scala | 4 +- .../execution/joins/HashedRelationSuite.scala | 49 ++++++++--- .../spark/unsafe/hash/Murmur3_x86_32.java | 10 ++- 20 files changed, 444 insertions(+), 135 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 6ce03a48e9538..7f08bf7b742dc 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -20,10 +20,11 @@ import java.io.IOException; import java.io.OutputStream; -import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ObjectPool; import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; +import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.types.UTF8String; @@ -354,7 +355,7 @@ public double getDouble(int i) { * This method is only supported on UnsafeRows that do not use ObjectPools. 
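   * (Such rows hold references into a shared, externally owned ObjectPool, so a
   * byte-wise copy would not be self-contained; the implementation below therefore
   * throws UnsupportedOperationException when a pool is attached.)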
*/ @Override - public InternalRow copy() { + public UnsafeRow copy() { if (pool != null) { throw new UnsupportedOperationException( "Copy is not supported for UnsafeRows that use object pools"); @@ -404,8 +405,51 @@ public void writeToStream(OutputStream out, byte[] writeBuffer) throws IOExcepti } } + @Override + public int hashCode() { + return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42); + } + + @Override + public boolean equals(Object other) { + if (other instanceof UnsafeRow) { + UnsafeRow o = (UnsafeRow) other; + return (sizeInBytes == o.sizeInBytes) && + ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset, + sizeInBytes); + } + return false; + } + + /** + * Returns the underlying bytes for this UnsafeRow. + */ + public byte[] getBytes() { + if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET + && (((byte[]) baseObject).length == sizeInBytes)) { + return (byte[]) baseObject; + } else { + byte[] bytes = new byte[sizeInBytes]; + PlatformDependent.copyMemory(baseObject, baseOffset, bytes, + PlatformDependent.BYTE_ARRAY_OFFSET, sizeInBytes); + return bytes; + } + } + + // This is for debugging + @Override + public String toString() { + StringBuilder build = new StringBuilder("["); + for (int i = 0; i < sizeInBytes; i += 8) { + build.append(PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i)); + build.append(','); + } + build.append(']'); + return build.toString(); + } + @Override public boolean anyNull() { - return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes); + return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes / 8); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index d1d81c87bb052..39fd6e1bc6d13 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -28,11 +28,10 @@ import org.apache.spark.TaskContext; import org.apache.spark.sql.AbstractScalaRowIterator; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeColumnWriter; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.util.ObjectPool; -import org.apache.spark.sql.types.*; +import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; import org.apache.spark.util.collection.unsafe.sort.RecordComparator; @@ -176,12 +175,7 @@ public Iterator sort(Iterator inputIterator) throws IO */ public static boolean supportsSchema(StructType schema) { // TODO: add spilling note to explain why we do this for now: - for (StructField field : schema.fields()) { - if (!UnsafeColumnWriter.canEmbed(field.dataType())) { - return false; - } - } - return true; + return UnsafeProjection.canSupport(schema); } private static final class RowComparator extends RecordComparator { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index b10a3c877434b..4a13b687bf4ce 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -21,7 +21,6 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.types._ /** @@ -34,7 +33,23 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def toString: String = s"input[$ordinal, $dataType]" - override def eval(input: InternalRow): Any = input(ordinal) + // Use special getter for primitive types (for UnsafeRow) + override def eval(input: InternalRow): Any = { + if (input.isNullAt(ordinal)) { + null + } else { + dataType match { + case BooleanType => input.getBoolean(ordinal) + case ByteType => input.getByte(ordinal) + case ShortType => input.getShort(ordinal) + case IntegerType | DateType => input.getInt(ordinal) + case LongType | TimestampType => input.getLong(ordinal) + case FloatType => input.getFloat(ordinal) + case DoubleType => input.getDouble(ordinal) + case _ => input.get(ordinal) + } + } + } override def name: String = s"i[$ordinal]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 24b01ea55110e..69758e653eba0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -83,12 +83,42 @@ abstract class UnsafeProjection extends Projection { } object UnsafeProjection { + + /* + * Returns whether UnsafeProjection can support given StructType, Array[DataType] or + * Seq[Expression]. + */ + def canSupport(schema: StructType): Boolean = canSupport(schema.fields.map(_.dataType)) + def canSupport(types: Array[DataType]): Boolean = types.forall(UnsafeColumnWriter.canEmbed(_)) + def canSupport(exprs: Seq[Expression]): Boolean = canSupport(exprs.map(_.dataType).toArray) + + /** + * Returns an UnsafeProjection for given StructType. + */ def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType)) - def create(fields: Seq[DataType]): UnsafeProjection = { + /** + * Returns an UnsafeProjection for given Array of DataTypes. + */ + def create(fields: Array[DataType]): UnsafeProjection = { val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)) + create(exprs) + } + + /** + * Returns an UnsafeProjection for given sequence of Expressions (bounded). + */ + def create(exprs: Seq[Expression]): UnsafeProjection = { GenerateUnsafeProjection.generate(exprs) } + + /** + * Returns an UnsafeProjection for given sequence of Expressions, which will be bound to + * `inputSchema`. 
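+   * e.g. create(Seq('a + 1), child.output) rewrites 'a into a BoundReference at its
+   * ordinal in child.output before generating the projection.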
+ */ + def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { + create(exprs.map(BindReferences.bindReference(_, inputSchema))) + } } /** @@ -96,6 +126,8 @@ object UnsafeProjection { */ case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection { + def this(schema: StructType) = this(schema.fields.map(_.dataType)) + private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) => new BoundReference(idx, dt, true) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 7ffdce60d2955..abaa4a6ce86a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -62,7 +62,7 @@ case class BroadcastHashJoin( private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length) + val hashed = buildHashRelation(input.iterator) sparkContext.broadcast(hashed) }(BroadcastHashJoin.broadcastHashJoinExecutionContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index ab757fc7de6cd..c9d1a880f4ef4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.joins +import scala.concurrent._ +import scala.concurrent.duration._ + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -26,10 +29,6 @@ import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.util.ThreadUtils -import scala.collection.JavaConversions._ -import scala.concurrent._ -import scala.concurrent.duration._ - /** * :: DeveloperApi :: * Performs a outer hash join for two child relations. 
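 * (For LeftOuter the right child is built into the broadcast hash table while the left
 * side is streamed; for RightOuter the sides are swapped. This patch moves that
 * buildPlan/streamedPlan selection up into the HashOuterJoin trait.)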
When the output RDD of this operator is @@ -58,28 +57,11 @@ case class BroadcastHashOuterJoin( override def requiredChildDistribution: Seq[Distribution] = UnspecifiedDistribution :: UnspecifiedDistribution :: Nil - private[this] lazy val (buildPlan, streamedPlan) = joinType match { - case RightOuter => (left, right) - case LeftOuter => (right, left) - case x => - throw new IllegalArgumentException( - s"BroadcastHashOuterJoin should not take $x as the JoinType") - } - - private[this] lazy val (buildKeys, streamedKeys) = joinType match { - case RightOuter => (leftKeys, rightKeys) - case LeftOuter => (rightKeys, leftKeys) - case x => - throw new IllegalArgumentException( - s"BroadcastHashOuterJoin should not take $x as the JoinType") - } - @transient private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - // buildHashTable uses code-generated rows as keys, which are not serializable - val hashed = buildHashTable(input.iterator, newProjection(buildKeys, buildPlan.output)) + val hashed = buildHashRelation(input.iterator) sparkContext.broadcast(hashed) }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext) @@ -89,21 +71,21 @@ case class BroadcastHashOuterJoin( streamedPlan.execute().mapPartitions { streamedIter => val joinedRow = new JoinedRow() val hashTable = broadcastRelation.value - val keyGenerator = newProjection(streamedKeys, streamedPlan.output) + val keyGenerator = streamedKeyGenerator joinType match { case LeftOuter => streamedIter.flatMap(currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashTable.getOrElse(rowKey, EMPTY_LIST)) + leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey)) }) case RightOuter => streamedIter.flatMap(currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) + rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow) }) case x => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 2750f58b005ac..f71c0ce352904 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -40,15 +40,14 @@ case class BroadcastLeftSemiJoinHash( val buildIter = right.execute().map(_.copy()).collect().toIterator if (condition.isEmpty) { - // rowKey may be not serializable (from codegen) - val hashSet = buildKeyHashSet(buildIter, copy = true) + val hashSet = buildKeyHashSet(buildIter) val broadcastedRelation = sparkContext.broadcast(hashSet) left.execute().mapPartitions { streamIter => hashSemiJoin(streamIter, broadcastedRelation.value) } } else { - val hashRelation = HashedRelation(buildIter, rightKeyGenerator) + val hashRelation = buildHashRelation(buildIter) val broadcastedRelation = sparkContext.broadcast(hashRelation) left.execute().mapPartitions { streamIter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 60b4266fad8b1..700636966f8be 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -44,6 +44,19 @@ case class BroadcastNestedLoopJoin( case BuildLeft => (right, left) } + override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + + @transient private[this] lazy val resultProjection: Projection = { + if (outputsUnsafeRows) { + UnsafeProjection.create(schema) + } else { + new Projection { + override def apply(r: InternalRow): InternalRow = r + } + } + } + override def outputPartitioning: Partitioning = streamed.outputPartitioning override def output: Seq[Attribute] = { @@ -74,6 +87,7 @@ case class BroadcastNestedLoopJoin( val includedBroadcastTuples = new scala.collection.mutable.BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow + val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) @@ -86,11 +100,11 @@ case class BroadcastNestedLoopJoin( val broadcastedRow = broadcastedRelation.value(i) buildSide match { case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => - matchedRows += joinedRow(streamedRow, broadcastedRow).copy() + matchedRows += resultProjection(joinedRow(streamedRow, broadcastedRow)).copy() streamRowMatched = true includedBroadcastTuples += i case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => - matchedRows += joinedRow(broadcastedRow, streamedRow).copy() + matchedRows += resultProjection(joinedRow(broadcastedRow, streamedRow)).copy() streamRowMatched = true includedBroadcastTuples += i case _ => @@ -100,9 +114,9 @@ case class BroadcastNestedLoopJoin( (streamRowMatched, joinType, buildSide) match { case (false, LeftOuter | FullOuter, BuildRight) => - matchedRows += joinedRow(streamedRow, rightNulls).copy() + matchedRows += resultProjection(joinedRow(streamedRow, rightNulls)).copy() case (false, RightOuter | FullOuter, BuildLeft) => - matchedRows += joinedRow(leftNulls, streamedRow).copy() + matchedRows += resultProjection(joinedRow(leftNulls, streamedRow)).copy() case _ => } } @@ -110,12 +124,9 @@ case class BroadcastNestedLoopJoin( } val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) - val allIncludedBroadcastTuples = - if (includedBroadcastTuples.count == 0) { - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) - } else { - includedBroadcastTuples.reduce(_ ++ _) - } + val allIncludedBroadcastTuples = includedBroadcastTuples.fold( + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + )(_ ++ _) val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) @@ -127,8 +138,10 @@ case class BroadcastNestedLoopJoin( while (i < rel.length) { if (!allIncludedBroadcastTuples.contains(i)) { (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i)) - case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls) + case (RightOuter | FullOuter, BuildRight) => + buf += resultProjection(new JoinedRow(leftNulls, rel(i))) + case (LeftOuter | FullOuter, BuildLeft) => + buf += resultProjection(new JoinedRow(rel(i), rightNulls)) case _ => } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index ff85ea3f6a410..ae34409bcfcca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -44,11 +44,20 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output - @transient protected lazy val buildSideKeyGenerator: Projection = - newProjection(buildKeys, buildPlan.output) + protected[this] def supportUnsafe: Boolean = { + (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) + && UnsafeProjection.canSupport(self.schema)) + } + + override def outputsUnsafeRows: Boolean = supportUnsafe + override def canProcessUnsafeRows: Boolean = supportUnsafe - @transient protected lazy val streamSideKeyGenerator: () => MutableProjection = - newMutableProjection(streamedKeys, streamedPlan.output) + @transient protected lazy val streamSideKeyGenerator: Projection = + if (supportUnsafe) { + UnsafeProjection.create(streamedKeys, streamedPlan.output) + } else { + newMutableProjection(streamedKeys, streamedPlan.output)() + } protected def hashJoin( streamIter: Iterator[InternalRow], @@ -61,8 +70,17 @@ trait HashJoin { // Mutable per row objects. private[this] val joinRow = new JoinedRow2 + private[this] val resultProjection: Projection = { + if (supportUnsafe) { + UnsafeProjection.create(self.schema) + } else { + new Projection { + override def apply(r: InternalRow): InternalRow = r + } + } + } - private[this] val joinKeys = streamSideKeyGenerator() + private[this] val joinKeys = streamSideKeyGenerator override final def hasNext: Boolean = (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || @@ -74,7 +92,7 @@ trait HashJoin { case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) } currentMatchPosition += 1 - ret + resultProjection(ret) } /** @@ -89,8 +107,9 @@ trait HashJoin { while (currentHashMatches == null && streamIter.hasNext) { currentStreamedRow = streamIter.next() - if (!joinKeys(currentStreamedRow).anyNull) { - currentHashMatches = hashedRelation.get(joinKeys.currentValue) + val key = joinKeys(currentStreamedRow) + if (!key.anyNull) { + currentHashMatches = hashedRelation.get(key) } } @@ -103,4 +122,12 @@ trait HashJoin { } } } + + protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { + if (supportUnsafe) { + UnsafeHashedRelation(buildIter, buildKeys, buildPlan) + } else { + HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output)) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 74a7db7761758..6bf2f82954046 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} -import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.util.collection.CompactBuffer @@ -38,7 +38,7 @@ trait HashOuterJoin { val left: 
SparkPlan val right: SparkPlan -override def outputPartitioning: Partitioning = joinType match { + override def outputPartitioning: Partitioning = joinType match { case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) @@ -59,6 +59,49 @@ override def outputPartitioning: Partitioning = joinType match { } } + protected[this] lazy val (buildPlan, streamedPlan) = joinType match { + case RightOuter => (left, right) + case LeftOuter => (right, left) + case x => + throw new IllegalArgumentException( + s"HashOuterJoin should not take $x as the JoinType") + } + + protected[this] lazy val (buildKeys, streamedKeys) = joinType match { + case RightOuter => (leftKeys, rightKeys) + case LeftOuter => (rightKeys, leftKeys) + case x => + throw new IllegalArgumentException( + s"HashOuterJoin should not take $x as the JoinType") + } + + protected[this] def supportUnsafe: Boolean = { + (self.codegenEnabled && joinType != FullOuter + && UnsafeProjection.canSupport(buildKeys) + && UnsafeProjection.canSupport(self.schema)) + } + + override def outputsUnsafeRows: Boolean = supportUnsafe + override def canProcessUnsafeRows: Boolean = supportUnsafe + + protected[this] def streamedKeyGenerator(): Projection = { + if (supportUnsafe) { + UnsafeProjection.create(streamedKeys, streamedPlan.output) + } else { + newProjection(streamedKeys, streamedPlan.output) + } + } + + @transient private[this] lazy val resultProjection: Projection = { + if (supportUnsafe) { + UnsafeProjection.create(self.schema) + } else { + new Projection { + override def apply(r: InternalRow): InternalRow = r + } + } + } + @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() @@ -76,16 +119,20 @@ override def outputPartitioning: Partitioning = joinType match { rightIter: Iterable[InternalRow]): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { - val temp = rightIter.collect { - case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy() + val temp = if (rightIter != null) { + rightIter.collect { + case r if boundCondition(joinedRow.withRight(r)) => resultProjection(joinedRow).copy() + } + } else { + List.empty } if (temp.isEmpty) { - joinedRow.withRight(rightNullRow).copy :: Nil + resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil } else { temp } } else { - joinedRow.withRight(rightNullRow).copy :: Nil + resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil } } ret.iterator @@ -97,17 +144,21 @@ override def outputPartitioning: Partitioning = joinType match { joinedRow: JoinedRow): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { - val temp = leftIter.collect { - case l if boundCondition(joinedRow.withLeft(l)) => - joinedRow.copy() + val temp = if (leftIter != null) { + leftIter.collect { + case l if boundCondition(joinedRow.withLeft(l)) => + resultProjection(joinedRow).copy() + } + } else { + List.empty } if (temp.isEmpty) { - joinedRow.withLeft(leftNullRow).copy :: Nil + resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil } else { temp } } else { - joinedRow.withLeft(leftNullRow).copy :: Nil + resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil } } ret.iterator @@ -159,6 +210,7 @@ override def outputPartitioning: Partitioning = joinType match { } } + // This is only used by FullOuter protected[this] def buildHashTable( iter: 
Iterator[InternalRow], keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = { @@ -178,4 +230,12 @@ override def outputPartitioning: Partitioning = joinType match { hashTable } + + protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { + if (supportUnsafe) { + UnsafeHashedRelation(buildIter, buildKeys, buildPlan) + } else { + HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output)) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 1b983bc3a90f9..7f49264d40354 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -32,34 +32,45 @@ trait HashSemiJoin { override def output: Seq[Attribute] = left.output - @transient protected lazy val rightKeyGenerator: Projection = - newProjection(rightKeys, right.output) + protected[this] def supportUnsafe: Boolean = { + (self.codegenEnabled && UnsafeProjection.canSupport(leftKeys) + && UnsafeProjection.canSupport(rightKeys) + && UnsafeProjection.canSupport(left.schema)) + } + + override def outputsUnsafeRows: Boolean = right.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = supportUnsafe + + @transient protected lazy val leftKeyGenerator: Projection = + if (supportUnsafe) { + UnsafeProjection.create(leftKeys, left.output) + } else { + newMutableProjection(leftKeys, left.output)() + } - @transient protected lazy val leftKeyGenerator: () => MutableProjection = - newMutableProjection(leftKeys, left.output) + @transient protected lazy val rightKeyGenerator: Projection = + if (supportUnsafe) { + UnsafeProjection.create(rightKeys, right.output) + } else { + newMutableProjection(rightKeys, right.output)() + } @transient private lazy val boundCondition = newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - protected def buildKeyHashSet( - buildIter: Iterator[InternalRow], - copy: Boolean): java.util.Set[InternalRow] = { + protected def buildKeyHashSet(buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = { val hashSet = new java.util.HashSet[InternalRow]() var currentRow: InternalRow = null // Create a Hash set of buildKeys + val rightKey = rightKeyGenerator while (buildIter.hasNext) { currentRow = buildIter.next() - val rowKey = rightKeyGenerator(currentRow) + val rowKey = rightKey(currentRow) if (!rowKey.anyNull) { val keyExists = hashSet.contains(rowKey) if (!keyExists) { - if (copy) { - hashSet.add(rowKey.copy()) - } else { - // rowKey may be not serializable (from codegen) - hashSet.add(rowKey) - } + hashSet.add(rowKey.copy()) } } } @@ -67,25 +78,34 @@ trait HashSemiJoin { } protected def hashSemiJoin( - streamIter: Iterator[InternalRow], - hashedRelation: HashedRelation): Iterator[InternalRow] = { - val joinKeys = leftKeyGenerator() - val joinedRow = new JoinedRow + streamIter: Iterator[InternalRow], + hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = { + val joinKeys = leftKeyGenerator streamIter.filter(current => { - lazy val rowBuffer = hashedRelation.get(joinKeys.currentValue) - !joinKeys(current).anyNull && rowBuffer != null && rowBuffer.exists { - (build: InternalRow) => boundCondition(joinedRow(current, build)) - } + val key = joinKeys(current) + !key.anyNull && hashSet.contains(key) }) } + protected def buildHashRelation(buildIter: 
Iterator[InternalRow]): HashedRelation = { + if (supportUnsafe) { + UnsafeHashedRelation(buildIter, rightKeys, right) + } else { + HashedRelation(buildIter, newProjection(rightKeys, right.output)) + } + } + protected def hashSemiJoin( streamIter: Iterator[InternalRow], - hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = { - val joinKeys = leftKeyGenerator() + hashedRelation: HashedRelation): Iterator[InternalRow] = { + val joinKeys = leftKeyGenerator val joinedRow = new JoinedRow - streamIter.filter(current => { - !joinKeys(current.copy()).anyNull && hashSet.contains(joinKeys.currentValue) - }) + streamIter.filter { current => + val key = joinKeys(current) + lazy val rowBuffer = hashedRelation.get(key) + !key.anyNull && rowBuffer != null && rowBuffer.exists { + (row: InternalRow) => boundCondition(joinedRow(current, row)) + } + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 6b51f5d4151d3..8d5731afd59b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.execution.joins -import java.io.{ObjectInput, ObjectOutput, Externalizable} +import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.util.{HashMap => JavaHashMap} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Projection -import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.collection.CompactBuffer @@ -98,7 +99,6 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR } } - // TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys. @@ -148,3 +148,80 @@ private[joins] object HashedRelation { } } } + + +/** + * A HashedRelation for UnsafeRow, which is backed by BytesToBytesMap that maps the key into a + * sequence of values. + * + * TODO(davies): use BytesToBytesMap + */ +private[joins] final class UnsafeHashedRelation( + private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) + extends HashedRelation with Externalizable { + + def this() = this(null) // Needed for serialization + + override def get(key: InternalRow): CompactBuffer[InternalRow] = { + val unsafeKey = key.asInstanceOf[UnsafeRow] + // Thanks to type eraser + hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]] + } + + override def writeExternal(out: ObjectOutput): Unit = { + writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + } + + override def readExternal(in: ObjectInput): Unit = { + hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + } +} + +private[joins] object UnsafeHashedRelation { + + def apply( + input: Iterator[InternalRow], + buildKeys: Seq[Expression], + buildPlan: SparkPlan, + sizeEstimate: Int = 64): HashedRelation = { + val boundedKeys = buildKeys.map(BindReferences.bindReference(_, buildPlan.output)) + apply(input, boundedKeys, buildPlan.schema, sizeEstimate) + } + + // Used for tests + def apply( + input: Iterator[InternalRow], + buildKeys: Seq[Expression], + rowSchema: StructType, + sizeEstimate: Int): HashedRelation = { + + // TODO: Use BytesToBytesMap. 
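+    // (BytesToBytesMap stores keys and values as raw bytes, avoiding JavaHashMap's
+    // per-entry object overhead; per the commit message, that switch is deferred
+    // to a follow-up PR.)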
+ val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate) + val toUnsafe = UnsafeProjection.create(rowSchema) + val keyGenerator = UnsafeProjection.create(buildKeys) + + // Create a mapping of buildKeys -> rows + while (input.hasNext) { + val currentRow = input.next() + val unsafeRow = if (currentRow.isInstanceOf[UnsafeRow]) { + currentRow.asInstanceOf[UnsafeRow] + } else { + toUnsafe(currentRow) + } + val rowKey = keyGenerator(unsafeRow) + if (!rowKey.anyNull) { + val existingMatchList = hashTable.get(rowKey) + val matchList = if (existingMatchList == null) { + val newMatchList = new CompactBuffer[UnsafeRow]() + hashTable.put(rowKey.copy(), newMatchList) + newMatchList + } else { + existingMatchList + } + matchList += unsafeRow.copy() + } + } + + new UnsafeHashedRelation(hashTable) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index db5be9f453674..4443455ef11fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -39,6 +39,9 @@ case class LeftSemiJoinBNL( override def output: Seq[Attribute] = left.output + override def outputsUnsafeRows: Boolean = streamed.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + /** The Streamed Relation */ override def left: SparkPlan = streamed diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 9eaac817d9268..874712a4e739f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -43,10 +43,10 @@ case class LeftSemiJoinHash( protected override def doExecute(): RDD[InternalRow] = { right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) => if (condition.isEmpty) { - val hashSet = buildKeyHashSet(buildIter, copy = false) + val hashSet = buildKeyHashSet(buildIter) hashSemiJoin(streamIter, hashSet) } else { - val hashRelation = HashedRelation(buildIter, rightKeyGenerator) + val hashRelation = buildHashRelation(buildIter) hashSemiJoin(streamIter, hashRelation) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 5439e10a60b2a..948d0ccebceb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -45,7 +45,7 @@ case class ShuffledHashJoin( protected override def doExecute(): RDD[InternalRow] = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashed = HashedRelation(buildIter, buildSideKeyGenerator) + val hashed = buildHashRelation(buildIter) hashJoin(streamIter, hashed) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index ab0a6ad56acde..f54f1edd38ec8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -50,24 +50,25 @@ case class ShuffledHashOuterJoin( // TODO this probably can be replaced by external sort (sort merged join?) joinType match { case LeftOuter => - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - val keyGenerator = newProjection(leftKeys, left.output) + val hashed = buildHashRelation(rightIter) + val keyGenerator = streamedKeyGenerator() leftIter.flatMap( currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST)) + leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey)) }) case RightOuter => - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - val keyGenerator = newProjection(rightKeys, right.output) + val hashed = buildHashRelation(leftIter) + val keyGenerator = streamedKeyGenerator() rightIter.flatMap ( currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) + rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow) }) case FullOuter => + // TODO(davies): use UnsafeRow val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala index 421d510e6782d..29f3beb3cb3c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala @@ -29,6 +29,9 @@ import org.apache.spark.sql.catalyst.rules.Rule */ @DeveloperApi case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode { + + require(UnsafeProjection.canSupport(child.schema), s"Cannot convert ${child.schema} to Unsafe") + override def output: Seq[Attribute] = child.output override def outputsUnsafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = false @@ -93,11 +96,19 @@ private[sql] object EnsureRowFormats extends Rule[SparkPlan] { } case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) => if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) { - // If this operator's children produce both unsafe and safe rows, then convert everything - // to unsafe rows - operator.withNewChildren { - operator.children.map { - c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c + // If this operator's children produce both unsafe and safe rows, + // convert everything unsafe rows if all the schema of them are support by UnsafeRow + if (operator.children.forall(c => UnsafeProjection.canSupport(c.schema))) { + operator.withNewChildren { + operator.children.map { + c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c + } + } + } else { + operator.withNewChildren { + operator.children.map { + c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c + } } } } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index 3854dc1b7a3d1..d36e2639376e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -22,7 +22,7 @@ import java.io.ByteArrayOutputStream import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} -import org.apache.spark.sql.types.{IntegerType, StringType} +import org.apache.spark.sql.types.{DataType, IntegerType, StringType} import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.memory.MemoryAllocator import org.apache.spark.unsafe.types.UTF8String @@ -31,7 +31,7 @@ class UnsafeRowSuite extends SparkFunSuite { test("writeToStream") { val row = InternalRow.apply(UTF8String.fromString("hello"), UTF8String.fromString("world"), 123) val arrayBackedUnsafeRow: UnsafeRow = - UnsafeProjection.create(Seq(StringType, StringType, IntegerType)).apply(row) + UnsafeProjection.create(Array[DataType](StringType, StringType, IntegerType)).apply(row) assert(arrayBackedUnsafeRow.getBaseObject.isInstanceOf[Array[Byte]]) val bytesFromArrayBackedRow: Array[Byte] = { val baos = new ByteArrayOutputStream() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 9d9858b1c6151..9dd2220f0967e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Projection +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.types.{StructField, StructType, IntegerType} import org.apache.spark.util.collection.CompactBuffer @@ -35,13 +37,13 @@ class HashedRelationSuite extends SparkFunSuite { val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) - assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0))) - assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1))) + assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) + assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1))) assert(hashed.get(InternalRow(10)) === null) val data2 = CompactBuffer[InternalRow](data(2)) data2 += data(2) - assert(hashed.get(data(2)) == data2) + assert(hashed.get(data(2)) === data2) } test("UniqueKeyHashedRelation") { @@ -49,15 +51,40 @@ class HashedRelationSuite extends SparkFunSuite { val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) - assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0))) - assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1))) - assert(hashed.get(data(2)) == CompactBuffer[InternalRow](data(2))) + assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) + assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1))) + assert(hashed.get(data(2)) === CompactBuffer[InternalRow](data(2))) assert(hashed.get(InternalRow(10)) === null) val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation] - assert(uniqHashed.getValue(data(0)) == data(0)) - assert(uniqHashed.getValue(data(1)) == data(1)) - assert(uniqHashed.getValue(data(2)) == data(2)) - 
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
index 85cd02469adb7..61f483ced3217 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
@@ -44,12 +44,16 @@ public int hashInt(int input) {
     return fmix(h1, 4);
   }

-  public int hashUnsafeWords(Object baseObject, long baseOffset, int lengthInBytes) {
+  public int hashUnsafeWords(Object base, long offset, int lengthInBytes) {
+    return hashUnsafeWords(base, offset, lengthInBytes, seed);
+  }
+
+  public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) {
     // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method.
     assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)";
     int h1 = seed;
-    for (int offset = 0; offset < lengthInBytes; offset += 4) {
-      int halfWord = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset);
+    for (int i = 0; i < lengthInBytes; i += 4) {
+      int halfWord = PlatformDependent.UNSAFE.getInt(base, offset + i);
       int k1 = mixK1(halfWord);
       h1 = mixH1(h1, k1);
     }
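The new static overload makes the word-aligned hash callable without constructing a hasher, e.g. from generated code. A rough usage sketch against an on-heap byte array follows; it assumes PlatformDependent.BYTE_ARRAY_OFFSET as the base offset for byte arrays and a seed-taking Murmur3_x86_32 constructor, both of which should be checked against the exact revision:

    import org.apache.spark.unsafe.PlatformDependent
    import org.apache.spark.unsafe.hash.Murmur3_x86_32

    object HashWordsSketch extends App {
      // 16 bytes = two 8-byte words, satisfying the word-alignment assertion.
      val bytes = Array.tabulate[Byte](16)(i => i.toByte)

      // Static call: no hasher instance needed, the seed is passed explicitly.
      val h1 = Murmur3_x86_32.hashUnsafeWords(
        bytes, PlatformDependent.BYTE_ARRAY_OFFSET, bytes.length, 42)

      // Equivalent instance call, which now just forwards to the static version.
      val h2 = new Murmur3_x86_32(42)
        .hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, bytes.length)

      assert(h1 == h2) // same seed, same bytes => same hash
    }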
From 8486cd853104255b4eb013860bba793eef4e74e7 Mon Sep 17 00:00:00 2001
From: Feynman Liang
Date: Wed, 22 Jul 2015 13:06:01 -0700
Subject: [PATCH 14/14] [SPARK-9224] [MLLIB] OnlineLDA Performance Improvements

In-place updates, fewer transposes, and vectorized operations in the OnlineLDA implementation.

Author: Feynman Liang

Closes #7454 from feynmanliang/OnlineLDA-perf-improvements and squashes the following commits:

78b0f5a [Feynman Liang] Make in-place variables vals, fix BLAS error
7f62a55 [Feynman Liang] --amend
c62cb1e [Feynman Liang] Outer product for stats, revert Range slicing
aead650 [Feynman Liang] Range slice, in-place update, reduce transposes
---
 .../spark/mllib/clustering/LDAOptimizer.scala | 59 +++++++++----------
 1 file changed, 27 insertions(+), 32 deletions(-)
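Stripped of the RDD plumbing, the per-document update in the diff below is a small fixed-point iteration between gamma and phi. The following standalone Breeze sketch shows its shape, with made-up toy values for k, the term counts, and alpha, and a random matrix standing in for the expElogbeta(ids, ::) slice that the real code takes from the global topic matrix:

    import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, sum}
    import breeze.numerics.{abs, digamma, exp}

    object VariationalEStepSketch extends App {
      // Toy dimensions: k topics, a document with three distinct terms.
      val k = 3
      val cts = BDV(2.0, 1.0, 4.0)                       // term counts, one per distinct term
      val alpha = 0.5
      val expElogbetad = BDM.rand[Double](cts.length, k) // stand-in for expElogbeta(ids, ::)

      // q(theta | gamma) initialization; the real code samples from Gamma(shape, 1/shape).
      val gammad = BDV.fill(k)(1.0)
      val expElogthetad = exp(digamma(gammad) - digamma(sum(gammad)))
      val phinorm = expElogbetad * expElogthetad + 1e-100

      var meanchange = 1.0
      while (meanchange > 1e-3) {
        val lastgamma = gammad.copy
        // In-place updates (:=) avoid reallocating k-vectors on every sweep,
        // which is the point of the patch.
        gammad := (expElogthetad :* (expElogbetad.t * (cts :/ phinorm))) + alpha
        expElogthetad := exp(digamma(gammad) - digamma(sum(gammad)))
        phinorm := expElogbetad * expElogthetad + 1e-100
        meanchange = sum(abs(gammad - lastgamma)) / k
      }
      println(s"converged gamma: $gammad")
    }

Working with K-vectors directly (instead of 1 x K row vectors, as before) removes a transpose from every loop iteration, and the sufficient-statistics update collapses to a single outer product.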
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 8e5154b902d1d..b960ae6c0708d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -19,15 +19,15 @@ package org.apache.spark.mllib.clustering

 import java.util.Random

-import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, sum, normalize, kron}
-import breeze.numerics.{digamma, exp, abs}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum}
+import breeze.numerics.{abs, digamma, exp}
 import breeze.stats.distributions.{Gamma, RandBasis}

 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.graphx._
 import org.apache.spark.graphx.impl.GraphImpl
 import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
-import org.apache.spark.mllib.linalg.{Matrices, SparseVector, DenseVector, Vector}
+import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector}
 import org.apache.spark.rdd.RDD

 /**
@@ -370,7 +370,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
     iteration += 1
     val k = this.k
     val vocabSize = this.vocabSize
-    val Elogbeta = dirichletExpectation(lambda)
+    val Elogbeta = dirichletExpectation(lambda).t
     val expElogbeta = exp(Elogbeta)
     val alpha = this.alpha
     val gammaShape = this.gammaShape
@@ -385,41 +385,36 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
         case v => throw new IllegalArgumentException("Online LDA does not support vector type "
           + v.getClass)
       }
+      if (!ids.isEmpty) {
+
+        // Initialize the variational distribution q(theta|gamma) for the mini-batch
+        val gammad: BDV[Double] =
+          new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k)                   // K
+        val expElogthetad: BDV[Double] = exp(digamma(gammad) - digamma(sum(gammad))) // K
+        val expElogbetad: BDM[Double] = expElogbeta(ids, ::).toDenseMatrix           // ids * K
+
+        val phinorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100            // ids
+        var meanchange = 1D
+        val ctsVector = new BDV[Double](cts)                                         // ids
+
+        // Iterate between gamma and phi until convergence
+        while (meanchange > 1e-3) {
+          val lastgamma = gammad.copy
+          //        K                  K * ids               ids
+          gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha
+          expElogthetad := exp(digamma(gammad) - digamma(sum(gammad)))
+          phinorm := expElogbetad * expElogthetad :+ 1e-100
+          meanchange = sum(abs(gammad - lastgamma)) / k
+        }

-      // Initialize the variational distribution q(theta|gamma) for the mini-batch
-      var gammad = new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k).t // 1 * K
-      var Elogthetad = digamma(gammad) - digamma(sum(gammad))                 // 1 * K
-      var expElogthetad = exp(Elogthetad)                                     // 1 * K
-      val expElogbetad = expElogbeta(::, ids).toDenseMatrix                   // K * ids
-
-      var phinorm = expElogthetad * expElogbetad + 1e-100                     // 1 * ids
-      var meanchange = 1D
-      val ctsVector = new BDV[Double](cts).t                                  // 1 * ids
-
-      // Iterate between gamma and phi until convergence
-      while (meanchange > 1e-3) {
-        val lastgamma = gammad
-        //       1*K          1 * ids            ids * k
-        gammad = (expElogthetad :* ((ctsVector / phinorm) * expElogbetad.t)) + alpha
-        Elogthetad = digamma(gammad) - digamma(sum(gammad))
-        expElogthetad = exp(Elogthetad)
-        phinorm = expElogthetad * expElogbetad + 1e-100
-        meanchange = sum(abs(gammad - lastgamma)) / k
-      }
-
-      val m1 = expElogthetad.t
-      val m2 = (ctsVector / phinorm).t.toDenseVector
-      var i = 0
-      while (i < ids.size) {
-        stat(::, ids(i)) := stat(::, ids(i)) + m1 * m2(i)
-        i += 1
+        stat(::, ids) := expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix
       }
     }
     Iterator(stat)
   }

     val statsSum: BDM[Double] = stats.reduce(_ += _)
-    val batchResult = statsSum :* expElogbeta
+    val batchResult = statsSum :* expElogbeta.t

     // Note that this is an optimization to avoid batch.count
     update(batchResult, iteration, (miniBatchFraction * corpusSize).ceil.toInt)