From c8709dcfd1237ffa19ee9286e99ddf2718a616d8 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 4 Jun 2015 10:28:59 -0700 Subject: [PATCH 01/17] [SPARK-7956] [SQL] Use Janino to compile SQL expressions into bytecode In order to reduce the overhead of codegen, this PR switch to use Janino to compile SQL expressions into bytecode. After this, the time used to compile a SQL expression is decreased from 100ms to 5ms, which is necessary to turn on codegen for general workload, also tests. cc rxin Author: Davies Liu Closes #6479 from davies/janino and squashes the following commits: cc689f5 [Davies Liu] remove globalLock 262d848 [Davies Liu] Merge branch 'master' of github.com:apache/spark into janino eec3a33 [Davies Liu] address comments from Josh f37c8c3 [Davies Liu] fix DecimalType and cast to String 202298b [Davies Liu] Merge branch 'master' of github.com:apache/spark into janino a21e968 [Davies Liu] fix style 0ed3dc6 [Davies Liu] Merge branch 'master' of github.com:apache/spark into janino 551a851 [Davies Liu] fix tests c3bdffa [Davies Liu] remove print 6089ce5 [Davies Liu] change logging level 7e46ac3 [Davies Liu] fix style d8f0f6c [Davies Liu] Merge branch 'master' of github.com:apache/spark into janino da4926a [Davies Liu] fix tests 03660f3 [Davies Liu] WIP: use Janino to compile Java source f2629cd [Davies Liu] Merge branch 'master' of github.com:apache/spark into janino f7d66cf [Davies Liu] use template based string for codegen --- .../spark/util/collection/OpenHashSet.scala | 12 +- pom.xml | 10 - project/SparkBuild.scala | 11 - sql/catalyst/pom.xml | 16 +- .../sql/catalyst/expressions/UnsafeRow.java | 101 +-- .../org/apache/spark/sql/BaseMutableRow.java | 68 ++ .../scala/org/apache/spark/sql/BaseRow.java | 190 +++++ .../expressions/codegen/CodeGenerator.scala | 797 +++++++++--------- .../codegen/GenerateMutableProjection.scala | 87 +- .../codegen/GenerateOrdering.scala | 146 ++-- .../codegen/GeneratePredicate.scala | 44 +- .../codegen/GenerateProjection.scala | 316 ++++--- .../expressions/codegen/package.scala | 6 - .../ExpressionEvaluationSuite.scala | 15 +- .../GeneratedEvaluationSuite.scala | 5 +- .../GeneratedMutableEvaluationSuite.scala | 7 +- .../org/apache/spark/sql/DataFrameSuite.scala | 11 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 162 ++-- 18 files changed, 1116 insertions(+), 888 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 1501111a06655..64e7102e3654c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -20,6 +20,8 @@ package org.apache.spark.util.collection import scala.reflect._ import com.google.common.hash.Hashing +import org.apache.spark.annotation.Private + /** * A simple, fast hash set optimized for non-null insertion-only use case, where keys are never * removed. @@ -37,7 +39,7 @@ import com.google.common.hash.Hashing * It uses quadratic probing with a power-of-2 hash table size, which is guaranteed * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing). */ -private[spark] +@Private class OpenHashSet[@specialized(Long, Int) T: ClassTag]( initialCapacity: Int, loadFactor: Double) @@ -110,6 +112,14 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( rehashIfNeeded(k, grow, move) } + def union(other: OpenHashSet[T]): OpenHashSet[T] = { + val iterator = other.iterator + while (iterator.hasNext) { + add(iterator.next()) + } + this + } + /** * Add an element to the set. This one differs from add in that it doesn't trigger rehashing. * The caller is responsible for calling rehashIfNeeded. diff --git a/pom.xml b/pom.xml index d03d33bf02468..bcb6ef96a1206 100644 --- a/pom.xml +++ b/pom.xml @@ -118,7 +118,6 @@ 2.3.4-spark 1.6 spark - 2.0.1 0.21.1 shaded-protobuf 1.7.10 @@ -1217,15 +1216,6 @@ -target ${java.version} - - - - org.scalamacros - paradise_${scala.version} - ${scala.macros.version} - - diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 9a849639233bc..f65031fe25ac2 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -178,9 +178,6 @@ object SparkBuild extends PomBuild { /* Enable unidoc only for the root spark project */ enable(Unidoc.settings)(spark) - /* Catalyst macro settings */ - enable(Catalyst.settings)(catalyst) - /* Spark SQL Core console settings */ enable(SQL.settings)(sql) @@ -275,14 +272,6 @@ object OldDeps { ) } -object Catalyst { - lazy val settings = Seq( - addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full), - // Quasiquotes break compiling scala doc... - // TODO: Investigate fixing this. - sources in (Compile, doc) ~= (_ filter (_.getName contains "codegen"))) -} - object SQL { lazy val settings = Seq( initialCommands in console := diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index bf0a7327a58a2..f4b1cc3a4ffe7 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -36,10 +36,6 @@ - - org.scala-lang - scala-compiler - org.scala-lang scala-reflect @@ -67,6 +63,11 @@ scalacheck_${scala.binary.version} test + + org.codehaus.janino + janino + 2.7.8 + target/scala-${scala.binary.version}/classes @@ -108,13 +109,6 @@ !scala-2.11 - - - org.scalamacros - quasiquotes_${scala.binary.version} - ${scala.macros.version} - - diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index bb546b3086b33..ec97fe603c44f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,23 +17,25 @@ package org.apache.spark.sql.catalyst.expressions; -import scala.collection.Map; +import javax.annotation.Nullable; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + import scala.collection.Seq; import scala.collection.mutable.ArraySeq; -import javax.annotation.Nullable; -import java.math.BigDecimal; -import java.sql.Date; -import java.util.*; - import org.apache.spark.sql.Row; +import org.apache.spark.sql.BaseMutableRow; import org.apache.spark.sql.types.DataType; -import static org.apache.spark.sql.types.DataTypes.*; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.UTF8String; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.bitset.BitSetMethods; +import static org.apache.spark.sql.types.DataTypes.*; + /** * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. * @@ -49,7 +51,7 @@ * * Instances of `UnsafeRow` act as pointers to row data stored in this format. */ -public final class UnsafeRow implements MutableRow { +public final class UnsafeRow extends BaseMutableRow { private Object baseObject; private long baseOffset; @@ -227,21 +229,11 @@ public int size() { return numFields; } - @Override - public int length() { - return size(); - } - @Override public StructType schema() { return schema; } - @Override - public Object apply(int i) { - return get(i); - } - @Override public Object get(int i) { assertIndexIsValid(i); @@ -339,60 +331,7 @@ public String getString(int i) { return getUTF8String(i).toString(); } - @Override - public BigDecimal getDecimal(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(int i) { - throw new UnsupportedOperationException(); - } - @Override - public Seq getSeq(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public List getList(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Map getMap(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public scala.collection.immutable.Map getValuesMap(Seq fieldNames) { - throw new UnsupportedOperationException(); - } - - @Override - public java.util.Map getJavaMap(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Row getStruct(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public T getAs(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public T getAs(String fieldName) { - throw new UnsupportedOperationException(); - } - - @Override - public int fieldIndex(String name) { - throw new UnsupportedOperationException(); - } @Override public Row copy() { @@ -412,24 +351,4 @@ public Seq toSeq() { } return values; } - - @Override - public String toString() { - return mkString("[", ",", "]"); - } - - @Override - public String mkString() { - return toSeq().mkString(); - } - - @Override - public String mkString(String sep) { - return toSeq().mkString(sep); - } - - @Override - public String mkString(String start, String sep, String end) { - return toSeq().mkString(start, sep, end); - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java new file mode 100644 index 0000000000000..acec2bf4520f2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql; + +import org.apache.spark.sql.catalyst.expressions.MutableRow; + +public abstract class BaseMutableRow extends BaseRow implements MutableRow { + + @Override + public void update(int ordinal, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setInt(int ordinal, int value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setLong(int ordinal, long value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setDouble(int ordinal, double value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setBoolean(int ordinal, boolean value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setShort(int ordinal, short value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setByte(int ordinal, byte value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setFloat(int ordinal, float value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setString(int ordinal, String value) { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java new file mode 100644 index 0000000000000..d138b43a3482b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql; + +import java.math.BigDecimal; +import java.sql.Date; +import java.util.List; + +import scala.collection.Seq; +import scala.collection.mutable.ArraySeq; + +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.StructType; + +public abstract class BaseRow implements Row { + + @Override + final public int length() { + return size(); + } + + @Override + public boolean anyNull() { + final int n = size(); + for (int i=0; i < n; i++) { + if (isNullAt(i)) { + return true; + } + } + return false; + } + + @Override + public StructType schema() { throw new UnsupportedOperationException(); } + + @Override + final public Object apply(int i) { + return get(i); + } + + @Override + public int getInt(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public long getLong(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public float getFloat(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public double getDouble(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public byte getByte(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public short getShort(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean getBoolean(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public String getString(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public BigDecimal getDecimal(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Date getDate(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Seq getSeq(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public List getList(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public scala.collection.Map getMap(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public scala.collection.immutable.Map getValuesMap(Seq fieldNames) { + throw new UnsupportedOperationException(); + } + + @Override + public java.util.Map getJavaMap(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Row getStruct(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public T getAs(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public T getAs(String fieldName) { + throw new UnsupportedOperationException(); + } + + @Override + public int fieldIndex(String name) { + throw new UnsupportedOperationException(); + } + + @Override + public Row copy() { + final int n = size(); + Object[] arr = new Object[n]; + for (int i = 0; i < n; i++) { + arr[i] = get(i); + } + return new GenericRow(arr); + } + + @Override + public Seq toSeq() { + final int n = size(); + final ArraySeq values = new ArraySeq(n); + for (int i = 0; i < n; i++) { + values.update(i, get(i)); + } + return values; + } + + @Override + public String toString() { + return mkString("[", ",", "]"); + } + + @Override + public String mkString() { + return toSeq().mkString(); + } + + @Override + public String mkString(String sep) { + return toSeq().mkString(sep); + } + + @Override + public String mkString(String start, String sep, String end) { + return toSeq().mkString(start, sep, end); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 36964af68dd8d..cd604121b7dd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import com.google.common.cache.{CacheLoader, CacheBuilder} - +import scala.collection.mutable import scala.language.existentials +import com.google.common.cache.{CacheBuilder, CacheLoader} +import org.codehaus.janino.ClassBodyEvaluator + import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ @@ -36,23 +38,15 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] * expressions. */ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { - import scala.reflect.runtime.{universe => ru} - import scala.reflect.runtime.universe._ - - import scala.tools.reflect.ToolBox - - protected val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox() - protected val rowType = typeOf[Row] - protected val mutableRowType = typeOf[MutableRow] - protected val genericRowType = typeOf[GenericRow] - protected val genericMutableRowType = typeOf[GenericMutableRow] - - protected val projectionType = typeOf[Projection] - protected val mutableProjectionType = typeOf[MutableProjection] + protected val rowType = classOf[Row].getName + protected val stringType = classOf[UTF8String].getName + protected val decimalType = classOf[Decimal].getName + protected val exprType = classOf[Expression].getName + protected val mutableRowType = classOf[MutableRow].getName + protected val genericMutableRowType = classOf[GenericMutableRow].getName private val curId = new java.util.concurrent.atomic.AtomicInteger() - private val javaSeparator = "$" /** * Can be flipped on manually in the console to add (expensive) expression evaluation trace code. @@ -74,6 +68,20 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin /** Binds an input expression to a given input schema */ protected def bind(in: InType, inputSchema: Seq[Attribute]): InType + /** + * Compile the Java source code into a Java class, using Janino. + * + * It will track the time used to compile + */ + protected def compile(code: String): Class[_] = { + val startTime = System.nanoTime() + val clazz = new ClassBodyEvaluator(code).getClazz() + val endTime = System.nanoTime() + def timeMs: Double = (endTime - startTime).toDouble / 1000000 + logDebug(s"Compiled Java code (${code.size} bytes) in $timeMs ms") + clazz + } + /** * A cache of generated classes. * @@ -87,7 +95,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin .maximumSize(1000) .build( new CacheLoader[InType, OutType]() { - override def load(in: InType): OutType = globalLock.synchronized { + override def load(in: InType): OutType = { val startTime = System.nanoTime() val result = create(in) val endTime = System.nanoTime() @@ -110,8 +118,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` * function.) */ - protected def freshName(prefix: String): TermName = { - newTermName(s"$prefix$javaSeparator${curId.getAndIncrement}") + protected def freshName(prefix: String): String = { + s"$prefix${curId.getAndIncrement}" } /** @@ -125,32 +133,51 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * @param objectTerm A possibly boxed version of the result of evaluating this expression. */ protected case class EvaluatedExpression( - code: Seq[Tree], - nullTerm: TermName, - primitiveTerm: TermName, - objectTerm: TermName) + code: String, + nullTerm: String, + primitiveTerm: String, + objectTerm: String) + + /** + * A context for codegen, which is used to bookkeeping the expressions those are not supported + * by codegen, then they are evaluated directly. The unsupported expression is appended at the + * end of `references`, the position of it is kept in the code, used to access and evaluate it. + */ + protected class CodeGenContext { + /** + * Holding all the expressions those do not support codegen, will be evaluated directly. + */ + val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() + } + + /** + * Create a new codegen context for expression evaluator, used to store those + * expressions that don't support codegen + */ + def newCodeGenContext(): CodeGenContext = { + new CodeGenContext() + } /** * Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that * can be used to determine the result of evaluating the expression on an input row. */ - def expressionEvaluator(e: Expression): EvaluatedExpression = { + def expressionEvaluator(e: Expression, ctx: CodeGenContext): EvaluatedExpression = { val primitiveTerm = freshName("primitiveTerm") val nullTerm = freshName("nullTerm") val objectTerm = freshName("objectTerm") implicit class Evaluate1(e: Expression) { - def castOrNull(f: TermName => Tree, dataType: DataType): Seq[Tree] = { - val eval = expressionEvaluator(e) - eval.code ++ - q""" - val $nullTerm = ${eval.nullTerm} - val $primitiveTerm = - if($nullTerm) - ${defaultPrimitive(dataType)} - else - ${f(eval.primitiveTerm)} - """.children + def castOrNull(f: String => String, dataType: DataType): String = { + val eval = expressionEvaluator(e, ctx) + eval.code + + s""" + boolean $nullTerm = ${eval.nullTerm}; + ${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)}; + if (!$nullTerm) { + $primitiveTerm = ${f(eval.primitiveTerm)}; + } + """ } } @@ -163,529 +190,505 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * * @param f a function from two primitive term names to a tree that evaluates them. */ - def evaluate(f: (TermName, TermName) => Tree): Seq[Tree] = + def evaluate(f: (String, String) => String): String = evaluateAs(expressions._1.dataType)(f) - def evaluateAs(resultType: DataType)(f: (TermName, TermName) => Tree): Seq[Tree] = { + def evaluateAs(resultType: DataType)(f: (String, String) => String): String = { // TODO: Right now some timestamp tests fail if we enforce this... if (expressions._1.dataType != expressions._2.dataType) { log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}") } - val eval1 = expressionEvaluator(expressions._1) - val eval2 = expressionEvaluator(expressions._2) + val eval1 = expressionEvaluator(expressions._1, ctx) + val eval2 = expressionEvaluator(expressions._2, ctx) val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm) - eval1.code ++ eval2.code ++ - q""" - val $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm} - val $primitiveTerm: ${termForType(resultType)} = - if($nullTerm) { - ${defaultPrimitive(resultType)} - } else { - $resultCode.asInstanceOf[${termForType(resultType)}] - } - """.children : Seq[Tree] + eval1.code + eval2.code + + s""" + boolean $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm}; + ${primitiveForType(resultType)} $primitiveTerm = ${defaultPrimitive(resultType)}; + if(!$nullTerm) { + $primitiveTerm = (${primitiveForType(resultType)})($resultCode); + } + """ } } - val inputTuple = newTermName(s"i") + val inputTuple = "i" // TODO: Skip generation of null handling code when expression are not nullable. - val primitiveEvaluation: PartialFunction[Expression, Seq[Tree]] = { + val primitiveEvaluation: PartialFunction[Expression, String] = { case b @ BoundReference(ordinal, dataType, nullable) => - val nullValue = q"$inputTuple.isNullAt($ordinal)" - q""" - val $nullTerm: Boolean = $nullValue - val $primitiveTerm: ${termForType(dataType)} = - if($nullTerm) - ${defaultPrimitive(dataType)} - else - ${getColumn(inputTuple, dataType, ordinal)} - """.children + s""" + final boolean $nullTerm = $inputTuple.isNullAt($ordinal); + final ${primitiveForType(dataType)} $primitiveTerm = $nullTerm ? + ${defaultPrimitive(dataType)} : (${getColumn(inputTuple, dataType, ordinal)}); + """ case expressions.Literal(null, dataType) => - q""" - val $nullTerm = true - val $primitiveTerm: ${termForType(dataType)} = null.asInstanceOf[${termForType(dataType)}] - """.children - - case expressions.Literal(value: Boolean, dataType) => - q""" - val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value - """.children - - case expressions.Literal(value: UTF8String, dataType) => - q""" - val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = - org.apache.spark.sql.types.UTF8String(${value.getBytes}) - """.children - - case expressions.Literal(value: Int, dataType) => - q""" - val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value - """.children - - case expressions.Literal(value: Long, dataType) => - q""" - val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value - """.children - - case Cast(e @ BinaryType(), StringType) => - val eval = expressionEvaluator(e) - eval.code ++ - q""" - val $nullTerm = ${eval.nullTerm} - val $primitiveTerm = - if($nullTerm) - ${defaultPrimitive(StringType)} - else - org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) - """.children + s""" + final boolean $nullTerm = true; + ${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)}; + """ + + case expressions.Literal(value: UTF8String, StringType) => + val arr = s"new byte[]{${value.getBytes.map(_.toString).mkString(", ")}}" + s""" + final boolean $nullTerm = false; + ${stringType} $primitiveTerm = + new ${stringType}().set(${arr}); + """ + + case expressions.Literal(value, FloatType) => + s""" + final boolean $nullTerm = false; + float $primitiveTerm = ${value}f; + """ + + case expressions.Literal(value, dt @ DecimalType()) => + s""" + final boolean $nullTerm = false; + ${primitiveForType(dt)} $primitiveTerm = new ${primitiveForType(dt)}().set($value); + """ + + case expressions.Literal(value, dataType) => + s""" + final boolean $nullTerm = false; + ${primitiveForType(dataType)} $primitiveTerm = $value; + """ + + case Cast(child @ BinaryType(), StringType) => + child.castOrNull(c => + s"new ${stringType}().set($c)", + StringType) case Cast(child @ DateType(), StringType) => child.castOrNull(c => - q"""org.apache.spark.sql.types.UTF8String( + s"""new ${stringType}().set( org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""", StringType) - case Cast(child @ NumericType(), IntegerType) => - child.castOrNull(c => q"$c.toInt", IntegerType) + case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => + child.castOrNull(c => s"(${primitiveForType(dt)})($c?1:0)", dt) - case Cast(child @ NumericType(), LongType) => - child.castOrNull(c => q"$c.toLong", LongType) + case Cast(child @ DecimalType(), IntegerType) => + child.castOrNull(c => s"($c).toInt()", IntegerType) - case Cast(child @ NumericType(), DoubleType) => - child.castOrNull(c => q"$c.toDouble", DoubleType) + case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => + child.castOrNull(c => s"($c).to${termForType(dt)}()", dt) - case Cast(child @ NumericType(), FloatType) => - child.castOrNull(c => q"$c.toFloat", FloatType) + case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => + child.castOrNull(c => s"(${primitiveForType(dt)})($c)", dt) // Special handling required for timestamps in hive test cases since the toString function // does not match the expected output. case Cast(e, StringType) if e.dataType != TimestampType => - val eval = expressionEvaluator(e) - eval.code ++ - q""" - val $nullTerm = ${eval.nullTerm} - val $primitiveTerm = - if($nullTerm) - ${defaultPrimitive(StringType)} - else - org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString) - """.children + e.castOrNull(c => + s"new ${stringType}().set(String.valueOf($c))", + StringType) case EqualTo(e1 @ BinaryType(), e2 @ BinaryType()) => (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => - q""" - java.util.Arrays.equals($eval1.asInstanceOf[Array[Byte]], - $eval2.asInstanceOf[Array[Byte]]) - """ + s"java.util.Arrays.equals((byte[])$eval1, (byte[])$eval2)" } case EqualTo(e1, e2) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" } - - /* TODO: Fix null semantics. - case In(e1, list) if !list.exists(!_.isInstanceOf[expressions.Literal]) => - val eval = expressionEvaluator(e1) - - val checks = list.map { - case expressions.Literal(v: String, dataType) => - q"if(${eval.primitiveTerm} == $v) return true" - case expressions.Literal(v: Int, dataType) => - q"if(${eval.primitiveTerm} == $v) return true" - } - - val funcName = newTermName(s"isIn${curId.getAndIncrement()}") - - q""" - def $funcName: Boolean = { - ..${eval.code} - if(${eval.nullTerm}) return false - ..$checks - return false - } - val $nullTerm = false - val $primitiveTerm = $funcName - """.children - */ + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 == $eval2" } case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 > $eval2" } + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 > $eval2" } case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 >= $eval2" } + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 >= $eval2" } case LessThan(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 < $eval2" } + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 < $eval2" } case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 <= $eval2" } + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 <= $eval2" } case And(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) - - q""" - ..${eval1.code} - var $nullTerm = false - var $primitiveTerm: ${termForType(BooleanType)} = false - - if (!${eval1.nullTerm} && ${eval1.primitiveTerm} == false) { + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) + s""" + ${eval1.code} + boolean $nullTerm = false; + boolean $primitiveTerm = false; + + if (!${eval1.nullTerm} && !${eval1.primitiveTerm}) { } else { - ..${eval2.code} - if (!${eval2.nullTerm} && ${eval2.primitiveTerm} == false) { + ${eval2.code} + if (!${eval2.nullTerm} && !${eval2.primitiveTerm}) { } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - $primitiveTerm = true + $primitiveTerm = true; } else { - $nullTerm = true + $nullTerm = true; } } - """.children + """ case Or(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) - q""" - ..${eval1.code} - var $nullTerm = false - var $primitiveTerm: ${termForType(BooleanType)} = false + s""" + ${eval1.code} + boolean $nullTerm = false; + boolean $primitiveTerm = false; if (!${eval1.nullTerm} && ${eval1.primitiveTerm}) { - $primitiveTerm = true + $primitiveTerm = true; } else { - ..${eval2.code} + ${eval2.code} if (!${eval2.nullTerm} && ${eval2.primitiveTerm}) { - $primitiveTerm = true + $primitiveTerm = true; } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - $primitiveTerm = false + $primitiveTerm = false; } else { - $nullTerm = true + $nullTerm = true; } } - """.children + """ case Not(child) => // Uh, bad function name... - child.castOrNull(c => q"!$c", BooleanType) - - case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" } - case Subtract(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 - $eval2" } - case Multiply(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 * $eval2" } + child.castOrNull(c => s"!$c", BooleanType) + + case Add(e1 @ DecimalType(), e2 @ DecimalType()) => + (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$plus($eval2)" } + case Subtract(e1 @ DecimalType(), e2 @ DecimalType()) => + (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$minus($eval2)" } + case Multiply(e1 @ DecimalType(), e2 @ DecimalType()) => + (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$times($eval2)" } + case Divide(e1 @ DecimalType(), e2 @ DecimalType()) => + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) + eval1.code + eval2.code + + s""" + boolean $nullTerm = false; + ${primitiveForType(e1.dataType)} $primitiveTerm = null; + if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) { + $nullTerm = true; + } else { + $primitiveTerm = ${eval1.primitiveTerm}.$$div${eval2.primitiveTerm}); + } + """ + case Remainder(e1 @ DecimalType(), e2 @ DecimalType()) => + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) + eval1.code + eval2.code + + s""" + boolean $nullTerm = false; + ${primitiveForType(e1.dataType)} $primitiveTerm = 0; + if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) { + $nullTerm = true; + } else { + $primitiveTerm = ${eval1.primitiveTerm}.remainder(${eval2.primitiveTerm}); + } + """ + + case Add(e1, e2) => + (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 + $eval2" } + case Subtract(e1, e2) => + (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 - $eval2" } + case Multiply(e1, e2) => + (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 * $eval2" } case Divide(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) - - eval1.code ++ eval2.code ++ - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(e1.dataType)} = 0 - - if (${eval1.nullTerm} || ${eval2.nullTerm} ) { - $nullTerm = true - } else if (${eval2.primitiveTerm} == 0) - $nullTerm = true - else { - $primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm} + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) + eval1.code + eval2.code + + s""" + boolean $nullTerm = false; + ${primitiveForType(e1.dataType)} $primitiveTerm = 0; + if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) { + $nullTerm = true; + } else { + $primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm}; } - """.children - + """ case Remainder(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) - - eval1.code ++ eval2.code ++ - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(e1.dataType)} = 0 - - if (${eval1.nullTerm} || ${eval2.nullTerm} ) { - $nullTerm = true - } else if (${eval2.primitiveTerm} == 0) - $nullTerm = true - else { - $nullTerm = false - $primitiveTerm = ${eval1.primitiveTerm} % ${eval2.primitiveTerm} + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) + eval1.code + eval2.code + + s""" + boolean $nullTerm = false; + ${primitiveForType(e1.dataType)} $primitiveTerm = 0; + if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) { + $nullTerm = true; + } else { + $primitiveTerm = ${eval1.primitiveTerm} % ${eval2.primitiveTerm}; } - """.children + """ case IsNotNull(e) => - val eval = expressionEvaluator(e) - q""" - ..${eval.code} - var $nullTerm = false - var $primitiveTerm: ${termForType(BooleanType)} = !${eval.nullTerm} - """.children + val eval = expressionEvaluator(e, ctx) + s""" + ${eval.code} + boolean $nullTerm = false; + boolean $primitiveTerm = !${eval.nullTerm}; + """ case IsNull(e) => - val eval = expressionEvaluator(e) - q""" - ..${eval.code} - var $nullTerm = false - var $primitiveTerm: ${termForType(BooleanType)} = ${eval.nullTerm} - """.children - - case c @ Coalesce(children) => - q""" - var $nullTerm = true - var $primitiveTerm: ${termForType(c.dataType)} = ${defaultPrimitive(c.dataType)} - """.children ++ + val eval = expressionEvaluator(e, ctx) + s""" + ${eval.code} + boolean $nullTerm = false; + boolean $primitiveTerm = ${eval.nullTerm}; + """ + + case e @ Coalesce(children) => + s""" + boolean $nullTerm = true; + ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; + """ + children.map { c => - val eval = expressionEvaluator(c) - q""" + val eval = expressionEvaluator(c, ctx) + s""" if($nullTerm) { - ..${eval.code} + ${eval.code} if(!${eval.nullTerm}) { - $nullTerm = false - $primitiveTerm = ${eval.primitiveTerm} + $nullTerm = false; + $primitiveTerm = ${eval.primitiveTerm}; } } """ - } + }.mkString("\n") - case i @ expressions.If(condition, trueValue, falseValue) => - val condEval = expressionEvaluator(condition) - val trueEval = expressionEvaluator(trueValue) - val falseEval = expressionEvaluator(falseValue) + case e @ expressions.If(condition, trueValue, falseValue) => + val condEval = expressionEvaluator(condition, ctx) + val trueEval = expressionEvaluator(trueValue, ctx) + val falseEval = expressionEvaluator(falseValue, ctx) - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(i.dataType)} = ${defaultPrimitive(i.dataType)} - ..${condEval.code} + s""" + boolean $nullTerm = false; + ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; + ${condEval.code} if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) { - ..${trueEval.code} - $nullTerm = ${trueEval.nullTerm} - $primitiveTerm = ${trueEval.primitiveTerm} + ${trueEval.code} + $nullTerm = ${trueEval.nullTerm}; + $primitiveTerm = ${trueEval.primitiveTerm}; } else { - ..${falseEval.code} - $nullTerm = ${falseEval.nullTerm} - $primitiveTerm = ${falseEval.primitiveTerm} + ${falseEval.code} + $nullTerm = ${falseEval.nullTerm}; + $primitiveTerm = ${falseEval.primitiveTerm}; } - """.children + """ case NewSet(elementType) => - q""" - val $nullTerm = false - val $primitiveTerm = new ${hashSetForType(elementType)}() - """.children + s""" + boolean $nullTerm = false; + ${hashSetForType(elementType)} $primitiveTerm = new ${hashSetForType(elementType)}(); + """ case AddItemToSet(item, set) => - val itemEval = expressionEvaluator(item) - val setEval = expressionEvaluator(set) + val itemEval = expressionEvaluator(item, ctx) + val setEval = expressionEvaluator(set, ctx) val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType + val htype = hashSetForType(elementType) - itemEval.code ++ setEval.code ++ - q""" - if (!${itemEval.nullTerm}) { - ${setEval.primitiveTerm} - .asInstanceOf[${hashSetForType(elementType)}] - .add(${itemEval.primitiveTerm}) + itemEval.code + setEval.code + + s""" + if (!${itemEval.nullTerm} && !${setEval.nullTerm}) { + (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm}); } - - val $nullTerm = false - val $primitiveTerm = ${setEval.primitiveTerm} - """.children + boolean $nullTerm = false; + ${htype} $primitiveTerm = ($htype)${setEval.primitiveTerm}; + """ case CombineSets(left, right) => - val leftEval = expressionEvaluator(left) - val rightEval = expressionEvaluator(right) + val leftEval = expressionEvaluator(left, ctx) + val rightEval = expressionEvaluator(right, ctx) val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType + val htype = hashSetForType(elementType) - leftEval.code ++ rightEval.code ++ - q""" - val $nullTerm = false - var $primitiveTerm: ${hashSetForType(elementType)} = null - - { - val leftSet = ${leftEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}] - val rightSet = ${rightEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}] - val iterator = rightSet.iterator - while (iterator.hasNext) { - leftSet.add(iterator.next()) - } - $primitiveTerm = leftSet - } - """.children + leftEval.code + rightEval.code + + s""" + boolean $nullTerm = false; + ${htype} $primitiveTerm = + (${htype})${leftEval.primitiveTerm}; + $primitiveTerm.union((${htype})${rightEval.primitiveTerm}); + """ - case MaxOf(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) + case MaxOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] => + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) - eval1.code ++ eval2.code ++ - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)} + eval1.code + eval2.code + + s""" + boolean $nullTerm = false; + ${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)}; if (${eval1.nullTerm}) { - $nullTerm = ${eval2.nullTerm} - $primitiveTerm = ${eval2.primitiveTerm} + $nullTerm = ${eval2.nullTerm}; + $primitiveTerm = ${eval2.primitiveTerm}; } else if (${eval2.nullTerm}) { - $nullTerm = ${eval1.nullTerm} - $primitiveTerm = ${eval1.primitiveTerm} + $nullTerm = ${eval1.nullTerm}; + $primitiveTerm = ${eval1.primitiveTerm}; } else { if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) { - $primitiveTerm = ${eval1.primitiveTerm} + $primitiveTerm = ${eval1.primitiveTerm}; } else { - $primitiveTerm = ${eval2.primitiveTerm} + $primitiveTerm = ${eval2.primitiveTerm}; } } - """.children + """ - case MinOf(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) + case MinOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] => + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) - eval1.code ++ eval2.code ++ - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)} + eval1.code + eval2.code + + s""" + boolean $nullTerm = false; + ${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)}; if (${eval1.nullTerm}) { - $nullTerm = ${eval2.nullTerm} - $primitiveTerm = ${eval2.primitiveTerm} + $nullTerm = ${eval2.nullTerm}; + $primitiveTerm = ${eval2.primitiveTerm}; } else if (${eval2.nullTerm}) { - $nullTerm = ${eval1.nullTerm} - $primitiveTerm = ${eval1.primitiveTerm} + $nullTerm = ${eval1.nullTerm}; + $primitiveTerm = ${eval1.primitiveTerm}; } else { if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) { - $primitiveTerm = ${eval1.primitiveTerm} + $primitiveTerm = ${eval1.primitiveTerm}; } else { - $primitiveTerm = ${eval2.primitiveTerm} + $primitiveTerm = ${eval2.primitiveTerm}; } } - """.children + """ case UnscaledValue(child) => - val childEval = expressionEvaluator(child) - - childEval.code ++ - q""" - var $nullTerm = ${childEval.nullTerm} - var $primitiveTerm: Long = if (!$nullTerm) { - ${childEval.primitiveTerm}.toUnscaledLong - } else { - ${defaultPrimitive(LongType)} - } - """.children + val childEval = expressionEvaluator(child, ctx) + + childEval.code + + s""" + boolean $nullTerm = ${childEval.nullTerm}; + long $primitiveTerm = $nullTerm ? -1 : ${childEval.primitiveTerm}.toUnscaledLong(); + """ case MakeDecimal(child, precision, scale) => - val childEval = expressionEvaluator(child) + val eval = expressionEvaluator(child, ctx) - childEval.code ++ - q""" - var $nullTerm = ${childEval.nullTerm} - var $primitiveTerm: org.apache.spark.sql.types.Decimal = - ${defaultPrimitive(DecimalType())} + eval.code + + s""" + boolean $nullTerm = ${eval.nullTerm}; + org.apache.spark.sql.types.Decimal $primitiveTerm = ${defaultPrimitive(DecimalType())}; if (!$nullTerm) { - $primitiveTerm = new org.apache.spark.sql.types.Decimal() - $primitiveTerm = $primitiveTerm.setOrNull(${childEval.primitiveTerm}, $precision, $scale) - $nullTerm = $primitiveTerm == null + $primitiveTerm = new org.apache.spark.sql.types.Decimal(); + $primitiveTerm = $primitiveTerm.setOrNull(${eval.primitiveTerm}, $precision, $scale); + $nullTerm = $primitiveTerm == null; } - """.children + """ } // If there was no match in the partial function above, we fall back on calling the interpreted // expression evaluator. - val code: Seq[Tree] = + val code: String = primitiveEvaluation.lift.apply(e).getOrElse { - log.debug(s"No rules to generate $e") - val tree = reify { e } - q""" - val $objectTerm = $tree.eval(i) - val $nullTerm = $objectTerm == null - val $primitiveTerm = $objectTerm.asInstanceOf[${termForType(e.dataType)}] - """.children - } - - // Only inject debugging code if debugging is turned on. - val debugCode = - if (debugLogging) { - val localLogger = log - val localLoggerTree = reify { localLogger } - q""" - $localLoggerTree.debug( - ${e.toString} + ": " + (if ($nullTerm) "null" else $primitiveTerm.toString)) - """ :: Nil - } else { - Nil + logError(s"No rules to generate $e") + ctx.references += e + s""" + /* expression: ${e} */ + Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i); + boolean $nullTerm = $objectTerm == null; + ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; + if (!$nullTerm) $primitiveTerm = (${termForType(e.dataType)})$objectTerm; + """ } - EvaluatedExpression(code ++ debugCode, nullTerm, primitiveTerm, objectTerm) + EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm) } - protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { + protected def getColumn(inputRow: String, dataType: DataType, ordinal: Int) = { dataType match { - case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]" - case dt: DataType if isNativeType(dt) => q"$inputRow.${accessorForType(dt)}($ordinal)" - case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]" + case StringType => s"(${stringType})$inputRow.apply($ordinal)" + case dt: DataType if isNativeType(dt) => s"$inputRow.${accessorForType(dt)}($ordinal)" + case _ => s"(${termForType(dataType)})$inputRow.apply($ordinal)" } } protected def setColumn( - destinationRow: TermName, + destinationRow: String, dataType: DataType, ordinal: Int, - value: TermName) = { + value: String): String = { dataType match { - case StringType => q"$destinationRow.update($ordinal, $value)" + case StringType => s"$destinationRow.update($ordinal, $value)" case dt: DataType if isNativeType(dt) => - q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" - case _ => q"$destinationRow.update($ordinal, $value)" + s"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" + case _ => s"$destinationRow.update($ordinal, $value)" } } - protected def accessorForType(dt: DataType) = newTermName(s"get${primitiveForType(dt)}") - protected def mutatorForType(dt: DataType) = newTermName(s"set${primitiveForType(dt)}") + protected def accessorForType(dt: DataType) = dt match { + case IntegerType => "getInt" + case other => s"get${termForType(dt)}" + } + + protected def mutatorForType(dt: DataType) = dt match { + case IntegerType => "setInt" + case other => s"set${termForType(dt)}" + } - protected def hashSetForType(dt: DataType) = dt match { - case IntegerType => typeOf[IntegerHashSet] - case LongType => typeOf[LongHashSet] + protected def hashSetForType(dt: DataType): String = dt match { + case IntegerType => classOf[IntegerHashSet].getName + case LongType => classOf[LongHashSet].getName case unsupportedType => sys.error(s"Code generation not support for hashset of type $unsupportedType") } - protected def primitiveForType(dt: DataType) = dt match { - case IntegerType => "Int" + protected def primitiveForType(dt: DataType): String = dt match { + case IntegerType => "int" + case LongType => "long" + case ShortType => "short" + case ByteType => "byte" + case DoubleType => "double" + case FloatType => "float" + case BooleanType => "boolean" + case dt: DecimalType => decimalType + case BinaryType => "byte[]" + case StringType => stringType + case DateType => "int" + case TimestampType => "java.sql.Timestamp" + case _ => "Object" + } + + protected def defaultPrimitive(dt: DataType): String = dt match { + case BooleanType => "false" + case FloatType => "-1.0f" + case ShortType => "-1" + case LongType => "-1" + case ByteType => "-1" + case DoubleType => "-1.0" + case IntegerType => "-1" + case DateType => "-1" + case dt: DecimalType => "null" + case StringType => "null" + case _ => "null" + } + + protected def termForType(dt: DataType): String = dt match { + case IntegerType => "Integer" case LongType => "Long" case ShortType => "Short" case ByteType => "Byte" case DoubleType => "Double" case FloatType => "Float" case BooleanType => "Boolean" - case StringType => "org.apache.spark.sql.types.UTF8String" - } - - protected def defaultPrimitive(dt: DataType) = dt match { - case BooleanType => ru.Literal(Constant(false)) - case FloatType => ru.Literal(Constant(-1.0.toFloat)) - case StringType => q"""org.apache.spark.sql.types.UTF8String("")""" - case ShortType => ru.Literal(Constant(-1.toShort)) - case LongType => ru.Literal(Constant(-1L)) - case ByteType => ru.Literal(Constant(-1.toByte)) - case DoubleType => ru.Literal(Constant(-1.toDouble)) - case DecimalType() => q"org.apache.spark.sql.types.Decimal(-1)" - case IntegerType => ru.Literal(Constant(-1)) - case DateType => ru.Literal(Constant(-1)) - case _ => ru.Literal(Constant(null)) - } - - protected def termForType(dt: DataType) = dt match { - case n: AtomicType => n.tag - case _ => typeTag[Any] + case dt: DecimalType => decimalType + case BinaryType => "byte[]" + case StringType => stringType + case DateType => "Integer" + case TimestampType => "java.sql.Timestamp" + case _ => "Object" } /** * List of data types that have special accessors and setters in [[Row]]. */ protected val nativeTypes = - Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) + Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType) /** * Returns true if the data type has a special accessor and setter in [[Row]]. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 840260703ab74..638b53fe0fe2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -19,15 +19,14 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +// MutableProjection is not accessible in Java +abstract class BaseMutableProjection extends MutableProjection {} + /** * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new * input [[Row]] for a fixed set of [[Expression Expressions]]. */ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] { - import scala.reflect.runtime.{universe => ru} - import scala.reflect.runtime.universe._ - - val mutableRowName = newTermName("mutableRow") protected def canonicalize(in: Seq[Expression]): Seq[Expression] = in.map(ExpressionCanonicalizer.execute) @@ -36,41 +35,61 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu in.map(BindReferences.bindReference(_, inputSchema)) protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { - val projectionCode = expressions.zipWithIndex.flatMap { case (e, i) => - val evaluationCode = expressionEvaluator(e) - - evaluationCode.code :+ - q""" - if(${evaluationCode.nullTerm}) - mutableRow.setNullAt($i) - else - ${setColumn(mutableRowName, e.dataType, i, evaluationCode.primitiveTerm)} - """ - } + val ctx = newCodeGenContext() + val projectionCode = expressions.zipWithIndex.map { case (e, i) => + val evaluationCode = expressionEvaluator(e, ctx) + evaluationCode.code + + s""" + if(${evaluationCode.nullTerm}) + mutableRow.setNullAt($i); + else + ${setColumn("mutableRow", e.dataType, i, evaluationCode.primitiveTerm)}; + """ + }.mkString("\n") + val code = s""" + import org.apache.spark.sql.Row; + + public SpecificProjection generate($exprType[] expr) { + return new SpecificProjection(expr); + } + + class SpecificProjection extends ${classOf[BaseMutableProjection].getName} { - val code = - q""" - () => { new $mutableProjectionType { + private $exprType[] expressions = null; + private $mutableRowType mutableRow = null; - private[this] var $mutableRowName: $mutableRowType = - new $genericMutableRowType(${expressions.size}) + public SpecificProjection($exprType[] expr) { + expressions = expr; + mutableRow = new $genericMutableRowType(${expressions.size}); + } - def target(row: $mutableRowType): $mutableProjectionType = { - $mutableRowName = row - this - } + public ${classOf[BaseMutableProjection].getName} target($mutableRowType row) { + mutableRow = row; + return this; + } - /* Provide immutable access to the last projected row. */ - def currentValue: $rowType = mutableRow + /* Provide immutable access to the last projected row. */ + public Row currentValue() { + return mutableRow; + } - def apply(i: $rowType): $rowType = { - ..$projectionCode - mutableRow - } - } } - """ + public Object apply(Object _i) { + Row i = (Row) _i; + $projectionCode - log.debug(s"code for ${expressions.mkString(",")}:\n$code") - toolBox.eval(code).asInstanceOf[() => MutableProjection] + return mutableRow; + } + } + """ + + + logDebug(s"code for ${expressions.mkString(",")}:\n$code") + + val c = compile(code) + // fetch the only one method `generate(Expression[])` + val m = c.getDeclaredMethods()(0) + () => { + m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[BaseMutableProjection] + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index b129c0d898bb7..0ff840dab393c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -18,18 +18,29 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.Logging +import org.apache.spark.annotation.Private +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{BinaryType, StringType, NumericType} +import org.apache.spark.sql.types.{BinaryType, NumericType} + +/** + * Inherits some default implementation for Java from `Ordering[Row]` + */ +@Private +class BaseOrdering extends Ordering[Row] { + def compare(a: Row, b: Row): Int = { + throw new UnsupportedOperationException + } +} /** * Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of * [[Expression Expressions]]. */ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] with Logging { - import scala.reflect.runtime.{universe => ru} import scala.reflect.runtime.universe._ - protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = + protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder]) protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] = @@ -38,73 +49,90 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit protected def create(ordering: Seq[SortOrder]): Ordering[Row] = { val a = newTermName("a") val b = newTermName("b") - val comparisons = ordering.zipWithIndex.map { case (order, i) => - val evalA = expressionEvaluator(order.child) - val evalB = expressionEvaluator(order.child) + val ctx = newCodeGenContext() + val comparisons = ordering.zipWithIndex.map { case (order, i) => + val evalA = expressionEvaluator(order.child, ctx) + val evalB = expressionEvaluator(order.child, ctx) + val asc = order.direction == Ascending val compare = order.child.dataType match { case BinaryType => - q""" - val x = ${if (order.direction == Ascending) evalA.primitiveTerm else evalB.primitiveTerm} - val y = ${if (order.direction != Ascending) evalB.primitiveTerm else evalA.primitiveTerm} - var i = 0 - while (i < x.length && i < y.length) { - val res = x(i).compareTo(y(i)) - if (res != 0) return res - i = i+1 - } - return x.length - y.length - """ + s""" + { + byte[] x = ${if (asc) evalA.primitiveTerm else evalB.primitiveTerm}; + byte[] y = ${if (!asc) evalB.primitiveTerm else evalA.primitiveTerm}; + int j = 0; + while (j < x.length && j < y.length) { + if (x[j] != y[j]) return x[j] - y[j]; + j = j + 1; + } + int d = x.length - y.length; + if (d != 0) { + return d; + } + }""" case _: NumericType => - q""" - val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm} - if(comp != 0) { - return ${if (order.direction == Ascending) q"comp.toInt" else q"-comp.toInt"} - } - """ - case StringType => - if (order.direction == Ascending) { - q"""return ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm})""" + s""" + if (${evalA.primitiveTerm} != ${evalB.primitiveTerm}) { + if (${evalA.primitiveTerm} > ${evalB.primitiveTerm}) { + return ${if (asc) "1" else "-1"}; + } else { + return ${if (asc) "-1" else "1"}; + } + }""" + case _ => + s""" + int comp = ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm}); + if (comp != 0) { + return ${if (asc) "comp" else "-comp"}; + }""" + } + + s""" + i = $a; + ${evalA.code} + i = $b; + ${evalB.code} + if (${evalA.nullTerm} && ${evalB.nullTerm}) { + // Nothing + } else if (${evalA.nullTerm}) { + return ${if (order.direction == Ascending) "-1" else "1"}; + } else if (${evalB.nullTerm}) { + return ${if (order.direction == Ascending) "1" else "-1"}; } else { - q"""return ${evalB.primitiveTerm}.compare(${evalA.primitiveTerm})""" + $compare } + """ + }.mkString("\n") + + val code = s""" + import org.apache.spark.sql.Row; + + public SpecificOrdering generate($exprType[] expr) { + return new SpecificOrdering(expr); } - q""" - i = $a - ..${evalA.code} - i = $b - ..${evalB.code} - if (${evalA.nullTerm} && ${evalB.nullTerm}) { - // Nothing - } else if (${evalA.nullTerm}) { - return ${if (order.direction == Ascending) q"-1" else q"1"} - } else if (${evalB.nullTerm}) { - return ${if (order.direction == Ascending) q"1" else q"-1"} - } else { - $compare + class SpecificOrdering extends ${typeOf[BaseOrdering]} { + + private $exprType[] expressions = null; + + public SpecificOrdering($exprType[] expr) { + expressions = expr; } - """ - } - val q"class $orderingName extends $orderingType { ..$body }" = reify { - class SpecificOrdering extends Ordering[Row] { - val o = ordering - } - }.tree.children.head - - val code = q""" - class $orderingName extends $orderingType { - ..$body - def compare(a: $rowType, b: $rowType): Int = { - var i: $rowType = null // Holds current row being evaluated. - ..$comparisons - return 0 + @Override + public int compare(Row a, Row b) { + Row i = null; // Holds current row being evaluated. + $comparisons + return 0; } - } - new $orderingName() - """ + }""" + logDebug(s"Generated Ordering: $code") - toolBox.eval(code).asInstanceOf[Ordering[Row]] + + val c = compile(code) + // fetch the only one method `generate(Expression[])` + val m = c.getDeclaredMethods()(0) + m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[BaseOrdering] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 40e163024360e..fb18769f00da3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -19,12 +19,17 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +/** + * Interface for generated predicate + */ +abstract class Predicate { + def eval(r: Row): Boolean +} + /** * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[Row]]. */ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { - import scala.reflect.runtime.{universe => ru} - import scala.reflect.runtime.universe._ protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in) @@ -32,17 +37,34 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { BindReferences.bindReference(in, inputSchema) protected def create(predicate: Expression): ((Row) => Boolean) = { - val cEval = expressionEvaluator(predicate) + val ctx = newCodeGenContext() + val eval = expressionEvaluator(predicate, ctx) + val code = s""" + import org.apache.spark.sql.Row; - val code = - q""" - (i: $rowType) => { - ..${cEval.code} - if (${cEval.nullTerm}) false else ${cEval.primitiveTerm} + public SpecificPredicate generate($exprType[] expr) { + return new SpecificPredicate(expr); + } + + class SpecificPredicate extends ${classOf[Predicate].getName} { + private final $exprType[] expressions; + public SpecificPredicate($exprType[] expr) { + expressions = expr; + } + + @Override + public boolean eval(Row i) { + ${eval.code} + return !${eval.nullTerm} && ${eval.primitiveTerm}; } - """ + }""" + + logDebug(s"Generated predicate '$predicate':\n$code") - log.debug(s"Generated predicate '$predicate':\n$code") - toolBox.eval(code).asInstanceOf[Row => Boolean] + val c = compile(code) + // fetch the only one method `generate(Expression[])` + val m = c.getDeclaredMethods()(0) + val p = m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[Predicate] + (r: Row) => p.eval(r) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 31c63a79ebc8c..d5be1fc12e0f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -17,9 +17,14 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import org.apache.spark.sql.BaseMutableRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ +/** + * Java can not access Projection (in package object) + */ +abstract class BaseProject extends Projection {} /** * Generates bytecode that produces a new [[Row]] object based on a fixed set of input @@ -27,7 +32,6 @@ import org.apache.spark.sql.types._ * generated based on the output types of the [[Expression]] to avoid boxing of primitive values. */ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { - import scala.reflect.runtime.{universe => ru} import scala.reflect.runtime.universe._ protected def canonicalize(in: Seq[Expression]): Seq[Expression] = @@ -38,201 +42,183 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { // Make Mutablility optional... protected def create(expressions: Seq[Expression]): Projection = { - val tupleLength = ru.Literal(Constant(expressions.length)) - val lengthDef = q"final val length = $tupleLength" - - /* TODO: Configurable... - val nullFunctions = - q""" - private final val nullSet = new org.apache.spark.util.collection.BitSet(length) - final def setNullAt(i: Int) = nullSet.set(i) - final def isNullAt(i: Int) = nullSet.get(i) - """ - */ - - val nullFunctions = - q""" - private[this] var nullBits = new Array[Boolean](${expressions.size}) - override def setNullAt(i: Int) = { nullBits(i) = true } - override def isNullAt(i: Int) = nullBits(i) - """.children - - val tupleElements = expressions.zipWithIndex.flatMap { + val ctx = newCodeGenContext() + val columns = expressions.zipWithIndex.map { case (e, i) => - val elementName = newTermName(s"c$i") - val evaluatedExpression = expressionEvaluator(e) - val iLit = ru.Literal(Constant(i)) + s"private ${primitiveForType(e.dataType)} c$i = ${defaultPrimitive(e.dataType)};\n" + }.mkString("\n ") - q""" - var ${newTermName(s"c$i")}: ${termForType(e.dataType)} = _ + val initColumns = expressions.zipWithIndex.map { + case (e, i) => + val eval = expressionEvaluator(e, ctx) + s""" { - ..${evaluatedExpression.code} - if(${evaluatedExpression.nullTerm}) - setNullAt($iLit) - else { - nullBits($iLit) = false - $elementName = ${evaluatedExpression.primitiveTerm} + // column$i + ${eval.code} + nullBits[$i] = ${eval.nullTerm}; + if(!${eval.nullTerm}) { + c$i = ${eval.primitiveTerm}; } } - """.children : Seq[Tree] - } + """ + }.mkString("\n") - val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)""" - val applyFunction = { - val cases = (0 until expressions.size).map { i => - val ordinal = ru.Literal(Constant(i)) - val elementName = newTermName(s"c$i") - val iLit = ru.Literal(Constant(i)) + val getCases = (0 until expressions.size).map { i => + s"case $i: return c$i;" + }.mkString("\n ") - q"if(i == $ordinal) { if(isNullAt($i)) return null else return $elementName }" - } - q"override def apply(i: Int): Any = { ..$cases; $accessorFailure }" - } - - val updateFunction = { - val cases = expressions.zipWithIndex.map {case (e, i) => - val ordinal = ru.Literal(Constant(i)) - val elementName = newTermName(s"c$i") - val iLit = ru.Literal(Constant(i)) - - q""" - if(i == $ordinal) { - if(value == null) { - setNullAt(i) - } else { - nullBits(i) = false - $elementName = value.asInstanceOf[${termForType(e.dataType)}] - } - return - }""" - } - q"override def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }" - } + val updateCases = expressions.zipWithIndex.map { case (e, i) => + s"case $i: { c$i = (${termForType(e.dataType)})value; return;}" + }.mkString("\n ") val specificAccessorFunctions = nativeTypes.map { dataType => - val ifStatements = expressions.zipWithIndex.flatMap { - // getString() is not used by expressions - case (e, i) if e.dataType == dataType && dataType != StringType => - val elementName = newTermName(s"c$i") - // TODO: The string of ifs gets pretty inefficient as the row grows in size. - // TODO: Optional null checks? - q"if(i == $i) return $elementName" :: Nil - case _ => Nil - } - dataType match { - // Row() need this interface to compile - case StringType => - q""" - override def getString(i: Int): String = { - $accessorFailure - }""" - case other => - q""" - override def ${accessorForType(dataType)}(i: Int): ${termForType(dataType)} = { - ..$ifStatements; - $accessorFailure - }""" + val cases = expressions.zipWithIndex.map { + case (e, i) if e.dataType == dataType => + s"case $i: return c$i;" + case _ => "" + }.mkString("\n ") + if (cases.count(_ != '\n') > 0) { + s""" + @Override + public ${primitiveForType(dataType)} ${accessorForType(dataType)}(int i) { + if (isNullAt(i)) { + return ${defaultPrimitive(dataType)}; + } + switch (i) { + $cases + } + return ${defaultPrimitive(dataType)}; + }""" + } else { + "" } - } + }.mkString("\n") val specificMutatorFunctions = nativeTypes.map { dataType => - val ifStatements = expressions.zipWithIndex.flatMap { - // setString() is not used by expressions - case (e, i) if e.dataType == dataType && dataType != StringType => - val elementName = newTermName(s"c$i") - // TODO: The string of ifs gets pretty inefficient as the row grows in size. - // TODO: Optional null checks? - q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil - case _ => Nil - } - dataType match { - case StringType => - // MutableRow() need this interface to compile - q""" - override def setString(i: Int, value: String) { - $accessorFailure - }""" - case other => - q""" - override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}) { - ..$ifStatements; - $accessorFailure - }""" + val cases = expressions.zipWithIndex.map { + case (e, i) if e.dataType == dataType => + s"case $i: { c$i = value; return; }" + case _ => "" + }.mkString("\n") + if (cases.count(_ != '\n') > 0) { + s""" + @Override + public void ${mutatorForType(dataType)}(int i, ${primitiveForType(dataType)} value) { + nullBits[i] = false; + switch (i) { + $cases + } + }""" + } else { + "" } - } + }.mkString("\n") val hashValues = expressions.zipWithIndex.map { case (e, i) => - val elementName = newTermName(s"c$i") + val col = newTermName(s"c$i") val nonNull = e.dataType match { - case BooleanType => q"if ($elementName) 0 else 1" - case ByteType | ShortType | IntegerType => q"$elementName.toInt" - case LongType => q"($elementName ^ ($elementName >>> 32)).toInt" - case FloatType => q"java.lang.Float.floatToIntBits($elementName)" + case BooleanType => s"$col ? 0 : 1" + case ByteType | ShortType | IntegerType | DateType => s"$col" + case LongType => s"$col ^ ($col >>> 32)" + case FloatType => s"Float.floatToIntBits($col)" case DoubleType => - q"{ val b = java.lang.Double.doubleToLongBits($elementName); (b ^ (b >>>32)).toInt }" - case _ => q"$elementName.hashCode" + s"Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32)" + case _ => s"$col.hashCode()" } - q"if (isNullAt($i)) 0 else $nonNull" + s"isNullAt($i) ? 0 : ($nonNull)" } - val hashUpdates: Seq[Tree] = hashValues.map(v => q"""result = 37 * result + $v""": Tree) + val hashUpdates: String = hashValues.map( v => + s""" + result *= 37; result += $v;""" + ).mkString("\n") - val hashCodeFunction = - q""" - override def hashCode(): Int = { - var result: Int = 37 - ..$hashUpdates - result - } + val columnChecks = expressions.zipWithIndex.map { case (e, i) => + s""" + if (isNullAt($i) != row.isNullAt($i) || !isNullAt($i) && !get($i).equals(row.get($i))) { + return false; + } """ + }.mkString("\n") - val columnChecks = (0 until expressions.size).map { i => - val elementName = newTermName(s"c$i") - q"if (this.$elementName != specificType.$elementName) return false" + val code = s""" + import org.apache.spark.sql.Row; + + public SpecificProjection generate($exprType[] expr) { + return new SpecificProjection(expr); } - val equalsFunction = - q""" - override def equals(other: Any): Boolean = other match { - case specificType: SpecificRow => - ..$columnChecks - return true - case other => super.equals(other) - } - """ + class SpecificProjection extends ${typeOf[BaseProject]} { + private $exprType[] expressions = null; + + public SpecificProjection($exprType[] expr) { + expressions = expr; + } - val allColumns = (0 until expressions.size).map { i => - val iLit = ru.Literal(Constant(i)) - q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }" + @Override + public Object apply(Object r) { + return new SpecificRow(expressions, (Row) r); + } } - val copyFunction = - q"override def copy() = new $genericRowType(Array[Any](..$allColumns))" - - val toSeqFunction = - q"override def toSeq: Seq[Any] = Seq(..$allColumns)" - - val classBody = - nullFunctions ++ ( - lengthDef +: - applyFunction +: - updateFunction +: - equalsFunction +: - hashCodeFunction +: - copyFunction +: - toSeqFunction +: - (tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions)) - - val code = q""" - final class SpecificRow(i: $rowType) extends $mutableRowType { - ..$classBody + final class SpecificRow extends ${typeOf[BaseMutableRow]} { + + $columns + + public SpecificRow($exprType[] expressions, Row i) { + $initColumns + } + + public int size() { return ${expressions.length};} + private boolean[] nullBits = new boolean[${expressions.length}]; + public void setNullAt(int i) { nullBits[i] = true; } + public boolean isNullAt(int i) { return nullBits[i]; } + + public Object get(int i) { + if (isNullAt(i)) return null; + switch (i) { + $getCases + } + return null; + } + public void update(int i, Object value) { + if (value == null) { + setNullAt(i); + return; + } + nullBits[i] = false; + switch (i) { + $updateCases + } + } + $specificAccessorFunctions + $specificMutatorFunctions + + @Override + public int hashCode() { + int result = 37; + $hashUpdates + return result; } - new $projectionType { def apply(r: $rowType) = new SpecificRow(r) } + @Override + public boolean equals(Object other) { + if (other instanceof Row) { + Row row = (Row) other; + if (row.length() != size()) return false; + $columnChecks + return true; + } + return super.equals(other); + } + } """ - log.debug( - s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${toolBox.typeCheck(code)}") - toolBox.eval(code).asInstanceOf[Projection] + logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${code}") + + val c = compile(code) + // fetch the only one method `generate(Expression[])` + val m = c.getDeclaredMethods()(0) + m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[Projection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala index 528e38a50a740..7f1b12cdd5800 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -27,12 +27,6 @@ import org.apache.spark.util.Utils */ package object codegen { - /** - * A lock to protect invoking the scala compiler at runtime, since it is not thread safe in Scala - * 2.10. - */ - protected[codegen] val globalLock = org.apache.spark.sql.catalyst.ScalaReflectionLock - /** Canonicalizes an expression so those that differ only by names can reuse the same code. */ object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] { val batches = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index b6927485f42bf..5df528770ca6e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -344,7 +344,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation("abdef" cast TimestampType, null) checkEvaluation("12.65" cast DecimalType.Unlimited, Decimal(12.65)) - checkEvaluation(Literal(1) cast LongType, 1) + checkEvaluation(Literal(1) cast LongType, 1.toLong) checkEvaluation(Cast(Literal(1000) cast TimestampType, LongType), 1.toLong) checkEvaluation(Cast(Literal(-1200) cast TimestampType, LongType), -2.toLong) checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) @@ -363,13 +363,16 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(Cast("abdef" cast BinaryType, StringType), "abdef") checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5) + Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), + 5.toLong) checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast - ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType), 0) + ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType), + 0.toShort) checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast TimestampType, ByteType), DecimalType.Unlimited), LongType), StringType), ShortType), null) checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast - DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), 0) + DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), + 0.toShort) checkEvaluation(Literal(true) cast IntegerType, 1) checkEvaluation(Literal(false) cast IntegerType, 0) checkEvaluation(Literal(true) cast StringType, "true") @@ -509,9 +512,9 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { val seconds = millis * 1000 + 2 val ts = new Timestamp(millis) val tss = new Timestamp(seconds) - checkEvaluation(Cast(ts, ShortType), 15) + checkEvaluation(Cast(ts, ShortType), 15.toShort) checkEvaluation(Cast(ts, IntegerType), 15) - checkEvaluation(Cast(ts, LongType), 15) + checkEvaluation(Cast(ts, LongType), 15.toLong) checkEvaluation(Cast(ts, FloatType), 15.002f) checkEvaluation(Cast(ts, DoubleType), 15.002) checkEvaluation(Cast(Cast(tss, ShortType), TimestampType), ts) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala index d7c437095e395..8cfd853afa35f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ -32,11 +32,12 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() } catch { case e: Throwable => - val evaluated = GenerateProjection.expressionEvaluator(expression) + val ctx = GenerateProjection.newCodeGenContext() + val evaluated = GenerateProjection.expressionEvaluator(expression, ctx) fail( s""" |Code generation of $expression failed: - |${evaluated.code.mkString("\n")} + |${evaluated.code} |$e """.stripMargin) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala index a40324b008e16..9ab1f7d7ad0db 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -28,7 +28,8 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { - lazy val evaluated = GenerateProjection.expressionEvaluator(expression) + val ctx = GenerateProjection.newCodeGenContext() + lazy val evaluated = GenerateProjection.expressionEvaluator(expression, ctx) val plan = try { GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) @@ -37,7 +38,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { fail( s""" |Code generation of $expression failed: - |${evaluated.code.mkString("\n")} + |${evaluated.code} |$e """.stripMargin) } @@ -49,7 +50,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { s""" |Mismatched hashCodes for values: $actual, $expectedRow |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} - |${evaluated.code.mkString("\n")} + |${evaluated.code} """.stripMargin) } if (actual != expectedRow) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 9aaec2b064d76..b41b1b77d049e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -451,10 +451,13 @@ class DataFrameSuite extends QueryTest { test("SPARK-6899") { val originalValue = TestSQLContext.conf.codegenEnabled TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, "true") - checkAnswer( - decimalData.agg(avg('a)), - Row(new java.math.BigDecimal(2.0))) - TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + try{ + checkAnswer( + decimalData.agg(avg('a)), + Row(new java.math.BigDecimal(2.0))) + } finally { + TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + } } test("SPARK-7133: Implement struct, array, and map field accessor") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 63f7d314fb699..55b68d8e2283c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -184,77 +184,79 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer(df, expectedResults) } - // Just to group rows. - testCodeGen( - "SELECT key FROM testData3x GROUP BY key", - (1 to 100).map(Row(_))) - // COUNT - testCodeGen( - "SELECT key, count(value) FROM testData3x GROUP BY key", - (1 to 100).map(i => Row(i, 3))) - testCodeGen( - "SELECT count(key) FROM testData3x", - Row(300) :: Nil) - // COUNT DISTINCT ON int - testCodeGen( - "SELECT value, count(distinct key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, 1))) - testCodeGen( - "SELECT count(distinct key) FROM testData3x", - Row(100) :: Nil) - // SUM - testCodeGen( - "SELECT value, sum(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, 3 * i))) - testCodeGen( - "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", - Row(5050 * 3, 5050 * 3.0) :: Nil) - // AVERAGE - testCodeGen( - "SELECT value, avg(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, i))) - testCodeGen( - "SELECT avg(key) FROM testData3x", - Row(50.5) :: Nil) - // MAX - testCodeGen( - "SELECT value, max(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, i))) - testCodeGen( - "SELECT max(key) FROM testData3x", - Row(100) :: Nil) - // MIN - testCodeGen( - "SELECT value, min(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, i))) - testCodeGen( - "SELECT min(key) FROM testData3x", - Row(1) :: Nil) - // Some combinations. - testCodeGen( - """ - |SELECT - | value, - | sum(key), - | max(key), - | min(key), - | avg(key), - | count(key), - | count(distinct key) - |FROM testData3x - |GROUP BY value - """.stripMargin, - (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1))) - testCodeGen( - "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x", - Row(100, 1, 50.5, 300, 100) :: Nil) - // Aggregate with Code generation handling all null values - testCodeGen( - "SELECT sum('a'), avg('a'), count(null) FROM testData", - Row(0, null, 0) :: Nil) - - dropTempTable("testData3x") - setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + try { + // Just to group rows. + testCodeGen( + "SELECT key FROM testData3x GROUP BY key", + (1 to 100).map(Row(_))) + // COUNT + testCodeGen( + "SELECT key, count(value) FROM testData3x GROUP BY key", + (1 to 100).map(i => Row(i, 3))) + testCodeGen( + "SELECT count(key) FROM testData3x", + Row(300) :: Nil) + // COUNT DISTINCT ON int + testCodeGen( + "SELECT value, count(distinct key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 1))) + testCodeGen( + "SELECT count(distinct key) FROM testData3x", + Row(100) :: Nil) + // SUM + testCodeGen( + "SELECT value, sum(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 3 * i))) + testCodeGen( + "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", + Row(5050 * 3, 5050 * 3.0) :: Nil) + // AVERAGE + testCodeGen( + "SELECT value, avg(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT avg(key) FROM testData3x", + Row(50.5) :: Nil) + // MAX + testCodeGen( + "SELECT value, max(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT max(key) FROM testData3x", + Row(100) :: Nil) + // MIN + testCodeGen( + "SELECT value, min(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT min(key) FROM testData3x", + Row(1) :: Nil) + // Some combinations. + testCodeGen( + """ + |SELECT + | value, + | sum(key), + | max(key), + | min(key), + | avg(key), + | count(key), + | count(distinct key) + |FROM testData3x + |GROUP BY value + """.stripMargin, + (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1))) + testCodeGen( + "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x", + Row(100, 1, 50.5, 300, 100) :: Nil) + // Aggregate with Code generation handling all null values + testCodeGen( + "SELECT sum('a'), avg('a'), count(null) FROM testData", + Row(0, null, 0) :: Nil) + } finally { + dropTempTable("testData3x") + setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + } } test("Add Parser of SQL COALESCE()") { @@ -463,9 +465,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val codegenbefore = conf.codegenEnabled setConf(SQLConf.EXTERNAL_SORT, "false") setConf(SQLConf.CODEGEN_ENABLED, "true") - sortTest() - setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) - setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + try{ + sortTest() + } finally { + setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) + setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + } } test("SPARK-6927 external sorting with codegen on") { @@ -473,9 +478,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val codegenbefore = conf.codegenEnabled setConf(SQLConf.CODEGEN_ENABLED, "true") setConf(SQLConf.EXTERNAL_SORT, "true") - sortTest() - setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) - setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + try { + sortTest() + } finally { + setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) + setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + } } test("limit") { From df7da07a86a30c684d5b07d955f1045a66715e3a Mon Sep 17 00:00:00 2001 From: Mike Dusenberry Date: Thu, 4 Jun 2015 11:30:07 -0700 Subject: [PATCH 02/17] [SPARK-7969] [SQL] Added a DataFrame.drop function that accepts a Column reference. Added a `DataFrame.drop` function that accepts a `Column` reference rather than a `String`, and added associated unit tests. Basically iterates through the `DataFrame` to find a column with an expression that is equivalent to that of the `Column` argument supplied to the function. Author: Mike Dusenberry Closes #6585 from dusenberrymw/SPARK-7969_Drop_method_on_Dataframes_should_handle_Column and squashes the following commits: 514727a [Mike Dusenberry] Updating the @since tag of the drop(Column) function doc to reflect version 1.4.1 instead of 1.4.0. 2f1bb4e [Mike Dusenberry] Adding an additional assert statement to the 'drop column after join' unit test in order to make sure the correct column was indeed left over. 6bf7c0e [Mike Dusenberry] Minor code formatting change. e583888 [Mike Dusenberry] Adding more Python doctests for the df.drop with column reference function to test joined datasets that have columns with the same name. 5f74401 [Mike Dusenberry] Updating DataFrame.drop with column reference function to use logicalPlan.output to prevent ambiguities resulting from columns with the same name. Also added associated unit tests for joined datasets with duplicate column names. 4b8bbe8 [Mike Dusenberry] Adding Python support for Dataframe.drop with a Column reference. 986129c [Mike Dusenberry] Added a DataFrame.drop function that accepts a Column reference rather than a String, and added associated unit tests. Basically iterates through the DataFrame to find a column with an expression that is equivalent to one supplied to the function. --- python/pyspark/sql/dataframe.py | 21 +++++++-- .../org/apache/spark/sql/DataFrame.scala | 16 +++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 45 +++++++++++++++++++ 3 files changed, 79 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 7673153abe0e2..03b01a1136e45 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1189,15 +1189,30 @@ def withColumnRenamed(self, existing, new): @since(1.4) @ignore_unicode_prefix - def drop(self, colName): + def drop(self, col): """Returns a new :class:`DataFrame` that drops the specified column. - :param colName: string, name of the column to drop. + :param col: a string name of the column to drop, or a + :class:`Column` to drop. >>> df.drop('age').collect() [Row(name=u'Alice'), Row(name=u'Bob')] + + >>> df.drop(df.age).collect() + [Row(name=u'Alice'), Row(name=u'Bob')] + + >>> df.join(df2, df.name == df2.name, 'inner').drop(df.name).collect() + [Row(age=5, height=85, name=u'Bob')] + + >>> df.join(df2, df.name == df2.name, 'inner').drop(df2.name).collect() + [Row(age=5, name=u'Bob', height=85)] """ - jdf = self._jdf.drop(colName) + if isinstance(col, basestring): + jdf = self._jdf.drop(col) + elif isinstance(col, Column): + jdf = self._jdf.drop(col._jc) + else: + raise TypeError("col should be a string or a Column") return DataFrame(jdf, self.sql_ctx) @since(1.3) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 034d887901975..d1a54ada7b191 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1082,6 +1082,22 @@ class DataFrame private[sql]( } } + /** + * Returns a new [[DataFrame]] with a column dropped. + * This version of drop accepts a Column rather than a name. + * This is a no-op if the DataFrame doesn't have a column + * with an equivalent expression. + * @group dfops + * @since 1.4.1 + */ + def drop(col: Column): DataFrame = { + val attrs = this.logicalPlan.output + val colsAfterDrop = attrs.filter { attr => + attr != col.expr + }.map(attr => Column(attr)) + select(colsAfterDrop : _*) + } + /** * Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]]. * This is an alias for `distinct`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b41b1b77d049e..8e81dacb8660f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -334,6 +334,51 @@ class DataFrameSuite extends QueryTest { assert(df.schema.map(_.name) === Seq("key", "value")) } + test("drop column using drop with column reference") { + val col = testData("key") + val df = testData.drop(col) + checkAnswer( + df, + testData.collect().map(x => Row(x.getString(1))).toSeq) + assert(df.schema.map(_.name) === Seq("value")) + } + + test("drop unknown column (no-op) with column reference") { + val col = Column("random") + val df = testData.drop(col) + checkAnswer( + df, + testData.collect().toSeq) + assert(df.schema.map(_.name) === Seq("key", "value")) + } + + test("drop unknown column with same name (no-op) with column reference") { + val col = Column("key") + val df = testData.drop(col) + checkAnswer( + df, + testData.collect().toSeq) + assert(df.schema.map(_.name) === Seq("key", "value")) + } + + test("drop column after join with duplicate columns using column reference") { + val newSalary = salary.withColumnRenamed("personId", "id") + val col = newSalary("id") + // this join will result in duplicate "id" columns + val joinedDf = person.join(newSalary, + person("id") === newSalary("id"), "inner") + // remove only the "id" column that was associated with newSalary + val df = joinedDf.drop(col) + checkAnswer( + df, + joinedDf.collect().map { + case Row(id: Int, name: String, age: Int, idToDrop: Int, salary: Double) => + Row(id, name, age, salary) + }.toSeq) + assert(df.schema.map(_.name) === Seq("id", "name", "age", "salary")) + assert(df("id") == person("id")) + } + test("withColumnRenamed") { val df = testData.toDF().withColumn("newCol", col("key") + 1) .withColumnRenamed("value", "valueRenamed") From cd3176bd86eafa09a5e11baf3636861c1f46e844 Mon Sep 17 00:00:00 2001 From: Thomas Omans Date: Thu, 4 Jun 2015 11:32:03 -0700 Subject: [PATCH 03/17] [SPARK-7743] [SQL] Parquet 1.7 Resolves [SPARK-7743](https://issues.apache.org/jira/browse/SPARK-7743). Trivial changes of versions, package names, as well as a small issue in `ParquetTableOperations.scala` ```diff - val readContext = getReadSupport(configuration).init( + val readContext = ParquetInputFormat.getReadSupportInstance(configuration).init( ``` Since ParquetInputFormat.getReadSupport was made package private in the latest release. Thanks -- Thomas Omans Author: Thomas Omans Closes #6597 from eggsby/SPARK-7743 and squashes the following commits: 2df0d1b [Thomas Omans] [SPARK-7743] [SQL] Upgrading parquet version to 1.7.0 --- .../src/main/python/parquet_inputformat.py | 2 +- pom.xml | 6 ++-- sql/core/pom.xml | 4 +-- .../DirectParquetOutputCommitter.scala | 6 ++-- .../spark/sql/parquet/ParquetConverter.scala | 6 ++-- .../spark/sql/parquet/ParquetFilters.scala | 10 +++--- .../spark/sql/parquet/ParquetRelation.scala | 10 +++--- .../sql/parquet/ParquetTableOperations.scala | 34 +++++++++---------- .../sql/parquet/ParquetTableSupport.scala | 12 +++---- .../spark/sql/parquet/ParquetTypes.scala | 14 ++++---- .../apache/spark/sql/parquet/newParquet.scala | 8 ++--- .../sql/parquet/timestamp/NanoTime.scala | 4 +-- .../apache/spark/sql/sources/commands.scala | 2 +- sql/core/src/test/resources/log4j.properties | 10 +++--- .../sql/parquet/ParquetFilterSuite.scala | 4 +-- .../spark/sql/parquet/ParquetIOSuite.scala | 18 +++++----- .../sql/parquet/ParquetSchemaSuite.scala | 2 +- 17 files changed, 76 insertions(+), 76 deletions(-) diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index 96ddac761d698..e1fd85b082c08 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -51,7 +51,7 @@ parquet_rdd = sc.newAPIHadoopFile( path, - 'parquet.avro.AvroParquetInputFormat', + 'org.apache.parquet.avro.AvroParquetInputFormat', 'java.lang.Void', 'org.apache.avro.generic.IndexedRecord', valueConverter='org.apache.spark.examples.pythonconverters.IndexedRecordToJavaConverter') diff --git a/pom.xml b/pom.xml index bcb6ef96a1206..abb9b55400340 100644 --- a/pom.xml +++ b/pom.xml @@ -136,7 +136,7 @@ 0.13.1 10.10.1.1 - 1.6.0rc3 + 1.7.0 1.2.4 8.1.14.v20131031 3.0.0.v201112011016 @@ -1080,13 +1080,13 @@ - com.twitter + org.apache.parquet parquet-column ${parquet.version} ${parquet.deps.scope} - com.twitter + org.apache.parquet parquet-hadoop ${parquet.version} ${parquet.deps.scope} diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 3192f81ffaecd..ed75475a87067 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -61,11 +61,11 @@ test - com.twitter + org.apache.parquet parquet-column - com.twitter + org.apache.parquet parquet-hadoop diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala index f5ce2718bec4a..62c4e92ebec68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala @@ -21,9 +21,9 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter -import parquet.Log -import parquet.hadoop.util.ContextUtil -import parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} +import org.apache.parquet.Log +import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) extends ParquetOutputCommitter(outputPath, context) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index caa9f045537d0..85c2ce740fe52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -23,9 +23,9 @@ import java.util.{TimeZone, Calendar} import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} import jodd.datetime.JDateTime -import parquet.column.Dictionary -import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} -import parquet.schema.MessageType +import org.apache.parquet.column.Dictionary +import org.apache.parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} +import org.apache.parquet.schema.MessageType import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.parquet.CatalystConverter.FieldType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index f0f4e7d147e75..88ae88e9684c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -21,11 +21,11 @@ import java.nio.ByteBuffer import com.google.common.io.BaseEncoding import org.apache.hadoop.conf.Configuration -import parquet.filter2.compat.FilterCompat -import parquet.filter2.compat.FilterCompat._ -import parquet.filter2.predicate.FilterApi._ -import parquet.filter2.predicate.{FilterApi, FilterPredicate} -import parquet.io.api.Binary +import org.apache.parquet.filter2.compat.FilterCompat +import org.apache.parquet.filter2.compat.FilterCompat._ +import org.apache.parquet.filter2.predicate.FilterApi._ +import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} +import org.apache.parquet.io.api.Binary import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.expressions._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index fcb9513ab66f6..09088ee91106c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -24,9 +24,9 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.permission.FsAction import org.apache.spark.sql.types.{StructType, DataType} -import parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat} -import parquet.hadoop.metadata.CompressionCodecName -import parquet.schema.MessageType +import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat} +import org.apache.parquet.hadoop.metadata.CompressionCodecName +import org.apache.parquet.schema.MessageType import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} @@ -107,7 +107,7 @@ private[sql] object ParquetRelation { // // Therefore we need to force the class to be loaded. // This should really be resolved by Parquet. - Class.forName(classOf[parquet.Log].getName) + Class.forName(classOf[org.apache.parquet.Log].getName) // Note: Logger.getLogger("parquet") has a default logger // that appends to Console which needs to be cleared. @@ -127,7 +127,7 @@ private[sql] object ParquetRelation { type RowType = org.apache.spark.sql.catalyst.expressions.GenericMutableRow // The compression type - type CompressionType = parquet.hadoop.metadata.CompressionCodecName + type CompressionType = org.apache.parquet.hadoop.metadata.CompressionCodecName // The parquet compression short names val shortParquetCompressionCodecNames = Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index cb7ae246d0d75..1e694f2feabee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -33,13 +33,13 @@ import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat => NewFileOutputFormat} -import parquet.hadoop._ -import parquet.hadoop.api.ReadSupport.ReadContext -import parquet.hadoop.api.{InitContext, ReadSupport} -import parquet.hadoop.metadata.GlobalMetaData -import parquet.hadoop.util.ContextUtil -import parquet.io.ParquetDecodingException -import parquet.schema.MessageType +import org.apache.parquet.hadoop._ +import org.apache.parquet.hadoop.api.ReadSupport.ReadContext +import org.apache.parquet.hadoop.api.{InitContext, ReadSupport} +import org.apache.parquet.hadoop.metadata.GlobalMetaData +import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.parquet.io.ParquetDecodingException +import org.apache.parquet.schema.MessageType import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mapred.SparkHadoopMapRedUtil @@ -78,7 +78,7 @@ private[sql] case class ParquetTableScan( }.toArray protected override def doExecute(): RDD[Row] = { - import parquet.filter2.compat.FilterCompat.FilterPredicateCompat + import org.apache.parquet.filter2.compat.FilterCompat.FilterPredicateCompat val sc = sqlContext.sparkContext val job = new Job(sc.hadoopConfiguration) @@ -136,7 +136,7 @@ private[sql] case class ParquetTableScan( baseRDD.mapPartitionsWithInputSplit { case (split, iter) => val partValue = "([^=]+)=([^=]+)".r val partValues = - split.asInstanceOf[parquet.hadoop.ParquetInputSplit] + split.asInstanceOf[org.apache.parquet.hadoop.ParquetInputSplit] .getPath .toString .split("/") @@ -378,7 +378,7 @@ private[sql] case class InsertIntoParquetTable( * to imported ones. */ private[parquet] class AppendingParquetOutputFormat(offset: Int) - extends parquet.hadoop.ParquetOutputFormat[Row] { + extends org.apache.parquet.hadoop.ParquetOutputFormat[Row] { // override to accept existing directories as valid output directory override def checkOutputSpecs(job: JobContext): Unit = {} var committer: OutputCommitter = null @@ -431,7 +431,7 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) * RecordFilter we want to use. */ private[parquet] class FilteringParquetRowInputFormat - extends parquet.hadoop.ParquetInputFormat[Row] with Logging { + extends org.apache.parquet.hadoop.ParquetInputFormat[Row] with Logging { private var fileStatuses = Map.empty[Path, FileStatus] @@ -439,7 +439,7 @@ private[parquet] class FilteringParquetRowInputFormat inputSplit: InputSplit, taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = { - import parquet.filter2.compat.FilterCompat.NoOpFilter + import org.apache.parquet.filter2.compat.FilterCompat.NoOpFilter val readSupport: ReadSupport[Row] = new RowReadSupport() @@ -501,7 +501,7 @@ private[parquet] class FilteringParquetRowInputFormat globalMetaData = new GlobalMetaData(globalMetaData.getSchema, mergedMetadata, globalMetaData.getCreatedBy) - val readContext = getReadSupport(configuration).init( + val readContext = ParquetInputFormat.getReadSupportInstance(configuration).init( new InitContext(configuration, globalMetaData.getKeyValueMetaData, globalMetaData.getSchema)) @@ -531,8 +531,8 @@ private[parquet] class FilteringParquetRowInputFormat minSplitSize: JLong, readContext: ReadContext): JList[ParquetInputSplit] = { - import parquet.filter2.compat.FilterCompat.Filter - import parquet.filter2.compat.RowGroupFilter + import org.apache.parquet.filter2.compat.FilterCompat.Filter + import org.apache.parquet.filter2.compat.RowGroupFilter import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.blockLocationCache @@ -547,7 +547,7 @@ private[parquet] class FilteringParquetRowInputFormat // https://github.com/apache/incubator-parquet-mr/pull/17 // is resolved val generateSplits = - Class.forName("parquet.hadoop.ClientSideMetadataSplitStrategy") + Class.forName("org.apache.parquet.hadoop.ClientSideMetadataSplitStrategy") .getDeclaredMethods.find(_.getName == "generateSplits").getOrElse( sys.error(s"Failed to reflectively invoke ClientSideMetadataSplitStrategy.generateSplits")) generateSplits.setAccessible(true) @@ -612,7 +612,7 @@ private[parquet] class FilteringParquetRowInputFormat // https://github.com/apache/incubator-parquet-mr/pull/17 // is resolved val generateSplits = - Class.forName("parquet.hadoop.TaskSideMetadataSplitStrategy") + Class.forName("org.apache.parquet.hadoop.TaskSideMetadataSplitStrategy") .getDeclaredMethods.find(_.getName == "generateTaskSideMDSplits").getOrElse( sys.error( s"Failed to reflectively invoke TaskSideMetadataSplitStrategy.generateTaskSideMDSplits")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 70a220cc43ab9..89db408b1c382 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.parquet import java.util.{HashMap => JHashMap} import org.apache.hadoop.conf.Configuration -import parquet.column.ParquetProperties -import parquet.hadoop.ParquetOutputFormat -import parquet.hadoop.api.ReadSupport.ReadContext -import parquet.hadoop.api.{ReadSupport, WriteSupport} -import parquet.io.api._ -import parquet.schema.MessageType +import org.apache.parquet.column.ParquetProperties +import org.apache.parquet.hadoop.ParquetOutputFormat +import org.apache.parquet.hadoop.api.ReadSupport.ReadContext +import org.apache.parquet.hadoop.api.{ReadSupport, WriteSupport} +import org.apache.parquet.io.api._ +import org.apache.parquet.schema.MessageType import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index f8a5d84549336..ba2a35b74ef82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -25,13 +25,13 @@ import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job -import parquet.format.converter.ParquetMetadataConverter -import parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} -import parquet.hadoop.util.ContextUtil -import parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} -import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} -import parquet.schema.Type.Repetition -import parquet.schema.{ConversionPatterns, DecimalMetadata, GroupType => ParquetGroupType, MessageType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes} +import org.apache.parquet.format.converter.ParquetMetadataConverter +import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} +import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} +import org.apache.parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} +import org.apache.parquet.schema.Type.Repetition +import org.apache.parquet.schema.{ConversionPatterns, DecimalMetadata, GroupType => ParquetGroupType, MessageType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes} import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index bf55e2383ab56..5dda440240e60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -29,10 +29,10 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat -import parquet.filter2.predicate.FilterApi -import parquet.hadoop._ -import parquet.hadoop.metadata.CompressionCodecName -import parquet.hadoop.util.ContextUtil +import org.apache.parquet.filter2.predicate.FilterApi +import org.apache.parquet.hadoop._ +import org.apache.parquet.hadoop.metadata.CompressionCodecName +import org.apache.parquet.hadoop.util.ContextUtil import org.apache.spark.{Partition => SparkPartition, SerializableWritable, Logging, SparkException} import org.apache.spark.broadcast.Broadcast diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala index 70bcca7526aae..4d5ed211ad0c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.parquet.timestamp import java.nio.{ByteBuffer, ByteOrder} -import parquet.Preconditions -import parquet.io.api.{Binary, RecordConsumer} +import org.apache.parquet.Preconditions +import org.apache.parquet.io.api.{Binary, RecordConsumer} private[parquet] class NanoTime extends Serializable { private var julianDay = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 71f016b1f14de..e9932c09107db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat} -import parquet.hadoop.util.ContextUtil +import org.apache.parquet.hadoop.util.ContextUtil import org.apache.spark._ import org.apache.spark.mapred.SparkHadoopMapRedUtil diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties index 28e90b9520b2c..12fb128149d32 100644 --- a/sql/core/src/test/resources/log4j.properties +++ b/sql/core/src/test/resources/log4j.properties @@ -36,11 +36,11 @@ log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %t %p %c{1}: %m%n log4j.appender.FA.Threshold = INFO # Some packages are noisy for no good reason. -log4j.additivity.parquet.hadoop.ParquetRecordReader=false -log4j.logger.parquet.hadoop.ParquetRecordReader=OFF +log4j.additivity.org.apache.parquet.hadoop.ParquetRecordReader=false +log4j.logger.org.apache.parquet.hadoop.ParquetRecordReader=OFF -log4j.additivity.parquet.hadoop.ParquetOutputCommitter=false -log4j.logger.parquet.hadoop.ParquetOutputCommitter=OFF +log4j.additivity.org.apache.parquet.hadoop.ParquetOutputCommitter=false +log4j.logger.org.apache.parquet.hadoop.ParquetOutputCommitter=OFF log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF @@ -52,5 +52,5 @@ log4j.additivity.hive.ql.metadata.Hive=false log4j.logger.hive.ql.metadata.Hive=OFF # Parquet related logging -log4j.logger.parquet.hadoop=WARN +log4j.logger.org.apache.parquet.hadoop=WARN log4j.logger.org.apache.spark.sql.parquet=INFO diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index bdc2ebabc5e9a..4aa5bcb7fdbca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.parquet import org.scalatest.BeforeAndAfterAll -import parquet.filter2.predicate.Operators._ -import parquet.filter2.predicate.{FilterPredicate, Operators} +import org.apache.parquet.filter2.predicate.Operators._ +import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index dd48bb350f26d..7f7c2cc1a6c26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -24,14 +24,14 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.scalatest.BeforeAndAfterAll -import parquet.example.data.simple.SimpleGroup -import parquet.example.data.{Group, GroupWriter} -import parquet.hadoop.api.WriteSupport -import parquet.hadoop.api.WriteSupport.WriteContext -import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData, CompressionCodecName} -import parquet.hadoop.{Footer, ParquetFileWriter, ParquetWriter} -import parquet.io.api.RecordConsumer -import parquet.schema.{MessageType, MessageTypeParser} +import org.apache.parquet.example.data.simple.SimpleGroup +import org.apache.parquet.example.data.{Group, GroupWriter} +import org.apache.parquet.hadoop.api.WriteSupport +import org.apache.parquet.hadoop.api.WriteSupport.WriteContext +import org.apache.parquet.hadoop.metadata.{ParquetMetadata, FileMetaData, CompressionCodecName} +import org.apache.parquet.hadoop.{Footer, ParquetFileWriter, ParquetWriter} +import org.apache.parquet.io.api.RecordConsumer +import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Row @@ -400,7 +400,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } finally { configuration.set("spark.sql.parquet.output.committer.class", - "parquet.hadoop.ParquetOutputCommitter") + "org.apache.parquet.hadoop.ParquetOutputCommitter") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index caec2a6f25489..8b1745124b8e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.parquet import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import parquet.schema.MessageTypeParser +import org.apache.parquet.schema.MessageTypeParser import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.ScalaReflection From 3dc005282a694e105f40e429b28b0a677743341f Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Thu, 4 Jun 2015 12:52:16 -0700 Subject: [PATCH 04/17] [SPARK-8027] [SPARKR] Move man pages creation to install-dev.sh This also helps us get rid of the sparkr-docs maven profile as docs are now built by just using -Psparkr when the roxygen2 package is available Related to discussion in #6567 cc pwendell srowen -- Let me know if this looks better Author: Shivaram Venkataraman Closes #6593 from shivaram/sparkr-pom-cleanup and squashes the following commits: b282241 [Shivaram Venkataraman] Remove sparkr-docs from release script as well 8f100a5 [Shivaram Venkataraman] Move man pages creation to install-dev.sh This also helps us get rid of the sparkr-docs maven profile as docs are now built by just using -Psparkr when the roxygen2 package is available --- R/create-docs.sh | 5 +---- R/install-dev.sh | 9 ++++++++- core/pom.xml | 23 ----------------------- dev/create-release/create-release.sh | 16 ++++++++-------- 4 files changed, 17 insertions(+), 36 deletions(-) diff --git a/R/create-docs.sh b/R/create-docs.sh index af47c0863bdd0..6a4687b06ecb9 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -30,10 +30,7 @@ set -e export FWDIR="$(cd "`dirname "$0"`"; pwd)" pushd $FWDIR -# Generate Rd file -Rscript -e 'library(devtools); devtools::document(pkg="./pkg", roclets=c("rd"))' - -# Install the package +# Install the package (this will also generate the Rd files) ./install-dev.sh # Now create HTML files diff --git a/R/install-dev.sh b/R/install-dev.sh index b9e2527035994..1edd551f8d243 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -34,5 +34,12 @@ LIB_DIR="$FWDIR/lib" mkdir -p $LIB_DIR -# Install R +pushd $FWDIR + +# Generate Rd files if devtools is installed +Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' + +# Install SparkR to $LIB_DIR R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ + +popd diff --git a/core/pom.xml b/core/pom.xml index e35694e9e98b4..40a64beccdc24 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -481,29 +481,6 @@ - - sparkr-docs - - - - org.codehaus.mojo - exec-maven-plugin - - - sparkr-pkg-docs - compile - - exec - - - - - ..${path.separator}R${path.separator}create-docs${script.extension} - - - - - diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 0b14a618e755c..54274a83f6d66 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -228,14 +228,14 @@ if [[ ! "$@" =~ --skip-package ]]; then # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds # share the same Zinc server. - make_binary_release "hadoop1" "-Psparkr -Psparkr-docs -Phadoop-1 -Phive -Phive-thriftserver" "3030" & - make_binary_release "hadoop1-scala2.11" "-Psparkr -Psparkr-docs -Phadoop-1 -Phive -Dscala-2.11" "3031" & - make_binary_release "cdh4" "-Psparkr -Psparkr-docs -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & - make_binary_release "hadoop2.3" "-Psparkr -Psparkr-docs -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & - make_binary_release "hadoop2.4" "-Psparkr -Psparkr-docs -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & - make_binary_release "mapr3" "-Pmapr3 -Psparkr -Psparkr-docs -Phive -Phive-thriftserver" "3035" & - make_binary_release "mapr4" "-Pmapr4 -Psparkr -Psparkr-docs -Pyarn -Phive -Phive-thriftserver" "3036" & - make_binary_release "hadoop2.4-without-hive" "-Psparkr -Psparkr-docs -Phadoop-2.4 -Pyarn" "3037" & + make_binary_release "hadoop1" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver" "3030" & + make_binary_release "hadoop1-scala2.11" "-Psparkr -Phadoop-1 -Phive -Dscala-2.11" "3031" & + make_binary_release "cdh4" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & + make_binary_release "hadoop2.3" "-Psparkr -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & + make_binary_release "hadoop2.4" "-Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & + make_binary_release "mapr3" "-Pmapr3 -Psparkr -Phive -Phive-thriftserver" "3035" & + make_binary_release "mapr4" "-Pmapr4 -Psparkr -Pyarn -Phive -Phive-thriftserver" "3036" & + make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn" "3037" & wait rm -rf spark-$RELEASE_VERSION-bin-*/ From 0526fea483066086dfc27d1606f74220fe822f7f Mon Sep 17 00:00:00 2001 From: Cheolsoo Park Date: Thu, 4 Jun 2015 13:27:35 -0700 Subject: [PATCH 05/17] [SPARK-6909][SQL] Remove Hive Shim code This is a follow-up on #6393. I am removing the following files in this PR. ``` ./sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala ./sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala ``` Basically, I re-factored the shim code as follows- * Rewrote code directly with Hive 0.13 methods, or * Converted code into private methods, or * Extracted code into separate classes But for leftover code that didn't fit in any of these cases, I created a HiveShim object. For eg, helper functions which wrap Hive 0.13 methods to work around Hive bugs are placed here. Author: Cheolsoo Park Closes #6604 from piaozhexiu/SPARK-6909 and squashes the following commits: 5dccc20 [Cheolsoo Park] Remove hive shim code --- .../hive/thriftserver/HiveThriftServer2.scala | 10 +- .../SparkExecuteStatementOperation.scala} | 102 +--- .../hive/thriftserver/SparkSQLCLIDriver.scala | 6 +- .../thriftserver/SparkSQLCLIService.scala | 7 +- ...rkSQLDriver.scala => SparkSQLDriver.scala} | 20 +- .../sql/hive/thriftserver/SparkSQLEnv.scala | 4 +- .../thriftserver/SparkSQLSessionManager.scala | 75 +++ .../HiveThriftServer2Suites.scala | 8 +- .../execution/HiveCompatibilitySuite.scala | 3 +- .../apache/spark/sql/hive/HiveContext.scala | 23 +- .../spark/sql/hive/HiveInspectors.scala | 187 +++++-- .../spark/sql/hive/HiveMetastoreCatalog.scala | 22 +- .../org/apache/spark/sql/hive/HiveQl.scala | 4 +- .../org/apache/spark/sql/hive/HiveShim.scala | 247 ++++++++++ .../apache/spark/sql/hive/TableReader.scala | 11 +- .../hive/execution/InsertIntoHiveTable.scala | 12 +- .../org/apache/spark/sql/hive/hiveUdfs.scala | 1 + .../spark/sql/hive/hiveWriterContainers.scala | 3 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 46 +- .../spark/sql/hive/StatisticsSuite.scala | 4 - .../sql/hive/execution/HiveQuerySuite.scala | 25 +- .../sql/hive/execution/SQLQuerySuite.scala | 58 ++- .../org/apache/spark/sql/hive/Shim13.scala | 457 ------------------ 23 files changed, 619 insertions(+), 716 deletions(-) rename sql/hive-thriftserver/{v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala => src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala} (66%) rename sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/{AbstractSparkSQLDriver.scala => SparkSQLDriver.scala} (86%) create mode 100644 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala delete mode 100644 sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 94687eeda4179..5b391d3dce882 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -17,9 +17,6 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - import org.apache.commons.logging.LogFactory import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars @@ -29,12 +26,15 @@ import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, SparkListenerJobStart} import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab -import org.apache.spark.sql.hive.{HiveContext, HiveShim} import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkContext} +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + /** * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a * `HiveThriftServer2` thrift server. @@ -51,7 +51,7 @@ object HiveThriftServer2 extends Logging { @DeveloperApi def startWithContext(sqlContext: HiveContext): Unit = { val server = new HiveThriftServer2(sqlContext) - sqlContext.setConf("spark.sql.hive.version", HiveShim.version) + sqlContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) server.init(sqlContext.hiveconf) server.start() listener = new HiveThriftServer2Listener(server, sqlContext.conf) diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala similarity index 66% rename from sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala rename to sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index b9d4f1c58c982..c0d1266212cdd 100644 --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -18,66 +18,31 @@ package org.apache.spark.sql.hive.thriftserver import java.sql.{Date, Timestamp} -import java.util.concurrent.Executors -import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, UUID} - -import org.apache.commons.logging.Log -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hive.service.cli.thrift.TProtocolVersion -import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager - -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, Map => SMap} +import java.util.{Map => JMap, UUID} import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hadoop.security.UserGroupInformation import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.ExecuteStatementOperation -import org.apache.hive.service.cli.session.{SessionManager, HiveSession} +import org.apache.hive.service.cli.session.HiveSession -import org.apache.spark.{SparkContext, Logging} -import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} +import org.apache.spark.Logging import org.apache.spark.sql.execution.SetCommand -import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} -/** - * A compatibility layer for interacting with Hive version 0.13.1. - */ -private[thriftserver] object HiveThriftServerShim { - val version = "0.13.1" - - def setServerUserName( - sparkServiceUGI: UserGroupInformation, - sparkCliService:SparkSQLCLIService) = { - setSuperField(sparkCliService, "serviceUGI", sparkServiceUGI) - } -} - -private[hive] class SparkSQLDriver(val _context: HiveContext = SparkSQLEnv.hiveContext) - extends AbstractSparkSQLDriver(_context) { - override def getResults(res: JList[_]): Boolean = { - if (hiveResponse == null) { - false - } else { - res.asInstanceOf[JArrayList[String]].addAll(hiveResponse) - hiveResponse = null - true - } - } -} +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer, Map => SMap} private[hive] class SparkExecuteStatementOperation( parentSession: HiveSession, statement: String, confOverlay: JMap[String, String], - runInBackground: Boolean = true)( - hiveContext: HiveContext, - sessionToActivePool: SMap[SessionHandle, String]) + runInBackground: Boolean = true) + (hiveContext: HiveContext, sessionToActivePool: SMap[SessionHandle, String]) // NOTE: `runInBackground` is set to `false` intentionally to disable asynchronous execution - extends ExecuteStatementOperation(parentSession, statement, confOverlay, false) with Logging { + extends ExecuteStatementOperation(parentSession, statement, confOverlay, false) + with Logging { private var result: DataFrame = _ private var iter: Iterator[SparkRow] = _ @@ -88,7 +53,7 @@ private[hive] class SparkExecuteStatementOperation( logDebug("CLOSING") } - def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) { + def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) { dataTypes(ordinal) match { case StringType => to += from.getString(ordinal) @@ -209,48 +174,3 @@ private[hive] class SparkExecuteStatementOperation( HiveThriftServer2.listener.onStatementFinish(statementId) } } - -private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) - extends SessionManager - with ReflectedCompositeService { - - private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) - - override def init(hiveConf: HiveConf) { - setSuperField(this, "hiveConf", hiveConf) - - val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) - setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) - getAncestorField[Log](this, 3, "LOG").info( - s"HiveServer2: Async execution pool size $backgroundPoolSize") - - setSuperField(this, "operationManager", sparkSqlOperationManager) - addService(sparkSqlOperationManager) - - initCompositeService(hiveConf) - } - - override def openSession( - protocol: TProtocolVersion, - username: String, - passwd: String, - sessionConf: java.util.Map[String, String], - withImpersonation: Boolean, - delegationToken: String): SessionHandle = { - hiveContext.openSession() - val sessionHandle = super.openSession( - protocol, username, passwd, sessionConf, withImpersonation, delegationToken) - val session = super.getSession(sessionHandle) - HiveThriftServer2.listener.onSessionCreated( - session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) - sessionHandle - } - - override def closeSession(sessionHandle: SessionHandle) { - HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString) - super.closeSession(sessionHandle) - sparkSqlOperationManager.sessionToActivePool -= sessionHandle - - hiveContext.detachSession() - } -} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 14f6f658d9b75..039cfa40d26b3 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -32,12 +32,12 @@ import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor} +import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor, CommandProcessorFactory} import org.apache.hadoop.hive.ql.session.SessionState import org.apache.thrift.transport.TSocket import org.apache.spark.Logging -import org.apache.spark.sql.hive.{HiveContext, HiveShim} +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.util.Utils private[hive] object SparkSQLCLIDriver { @@ -267,7 +267,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { } else { var ret = 0 val hconf = conf.asInstanceOf[HiveConf] - val proc: CommandProcessor = HiveShim.getCommandProcessor(Array(tokens(0)), hconf) + val proc: CommandProcessor = CommandProcessorFactory.get(Array(tokens(0)), hconf) if (proc != null) { if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor] || diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index 499e077d7294a..41f647d5f8c5a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -21,8 +21,6 @@ import java.io.IOException import java.util.{List => JList} import javax.security.auth.login.LoginException -import scala.collection.JavaConversions._ - import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.shims.ShimLoader @@ -34,7 +32,8 @@ import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import org.apache.spark.util.Utils + +import scala.collection.JavaConversions._ private[hive] class SparkSQLCLIService(hiveContext: HiveContext) extends CLIService @@ -52,7 +51,7 @@ private[hive] class SparkSQLCLIService(hiveContext: HiveContext) try { HiveAuthFactory.loginFromKeytab(hiveConf) sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) - HiveThriftServerShim.setServerUserName(sparkServiceUGI, this) + setSuperField(this, "serviceUGI", sparkServiceUGI) } catch { case e @ (_: IOException | _: LoginException) => throw new ServiceException("Unable to login to kerberos with given principal/keytab", e) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala similarity index 86% rename from sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala rename to sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 48ac9062af96a..77272aecf2835 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.JavaConversions._ +import java.util.{ArrayList => JArrayList, List => JList} import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} @@ -27,8 +27,12 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse import org.apache.spark.Logging import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -private[hive] abstract class AbstractSparkSQLDriver( - val context: HiveContext = SparkSQLEnv.hiveContext) extends Driver with Logging { +import scala.collection.JavaConversions._ + +private[hive] class SparkSQLDriver( + val context: HiveContext = SparkSQLEnv.hiveContext) + extends Driver + with Logging { private[hive] var tableSchema: Schema = _ private[hive] var hiveResponse: Seq[String] = _ @@ -71,6 +75,16 @@ private[hive] abstract class AbstractSparkSQLDriver( 0 } + override def getResults(res: JList[_]): Boolean = { + if (hiveResponse == null) { + false + } else { + res.asInstanceOf[JArrayList[String]].addAll(hiveResponse) + hiveResponse = null + true + } + } + override def getSchema: Schema = tableSchema override def destroy() { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 7c0c505e2d61e..79eda1f5123bf 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -22,7 +22,7 @@ import java.io.PrintStream import scala.collection.JavaConversions._ import org.apache.spark.scheduler.StatsReportListener -import org.apache.spark.sql.hive.{HiveShim, HiveContext} +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.util.Utils @@ -56,7 +56,7 @@ private[hive] object SparkSQLEnv extends Logging { hiveContext.metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) hiveContext.metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) - hiveContext.setConf("spark.sql.hive.version", HiveShim.version) + hiveContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) if (log.isDebugEnabled) { hiveContext.hiveconf.getAllProperties.toSeq.sorted.foreach { case (k, v) => diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala new file mode 100644 index 0000000000000..357b27f7401a3 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.util.concurrent.Executors + +import org.apache.commons.logging.Log +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hive.service.cli.SessionHandle +import org.apache.hive.service.cli.session.SessionManager +import org.apache.hive.service.cli.thrift.TProtocolVersion + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ +import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager + +private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) + extends SessionManager + with ReflectedCompositeService { + + private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) + + override def init(hiveConf: HiveConf) { + setSuperField(this, "hiveConf", hiveConf) + + val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) + setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) + getAncestorField[Log](this, 3, "LOG").info( + s"HiveServer2: Async execution pool size $backgroundPoolSize") + + setSuperField(this, "operationManager", sparkSqlOperationManager) + addService(sparkSqlOperationManager) + + initCompositeService(hiveConf) + } + + override def openSession(protocol: TProtocolVersion, + username: String, + passwd: String, + sessionConf: java.util.Map[String, String], + withImpersonation: Boolean, + delegationToken: String): SessionHandle = { + hiveContext.openSession() + val sessionHandle = super.openSession( + protocol, username, passwd, sessionConf, withImpersonation, delegationToken) + val session = super.getSession(sessionHandle) + HiveThriftServer2.listener.onSessionCreated( + session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) + sessionHandle + } + + override def closeSession(sessionHandle: SessionHandle) { + HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString) + super.closeSession(sessionHandle) + sparkSqlOperationManager.sessionToActivePool -= sessionHandle + + hiveContext.detachSession() + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index a93a3dee43511..f57c7083ea504 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -40,7 +40,7 @@ import org.apache.thrift.transport.TSocket import org.scalatest.BeforeAndAfterAll import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.hive.HiveShim +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.util.Utils object TestData { @@ -111,7 +111,8 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { withJdbcStatement { statement => val resultSet = statement.executeQuery("SET spark.sql.hive.version") resultSet.next() - assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") + assert(resultSet.getString(1) === + s"spark.sql.hive.version=${HiveContext.hiveExecutionVersion}") } } @@ -365,7 +366,8 @@ class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { withJdbcStatement { statement => val resultSet = statement.executeQuery("SET spark.sql.hive.version") resultSet.next() - assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") + assert(resultSet.getString(1) === + s"spark.sql.hive.version=${HiveContext.hiveExecutionVersion}") } } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 0b1917a392901..048f78b4daa8d 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -23,7 +23,6 @@ import java.util.{Locale, TimeZone} import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.hive.HiveShim import org.apache.spark.sql.hive.test.TestHive /** @@ -254,7 +253,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // the answer is sensitive for jdk version "udf_java_method" - ) ++ HiveShim.compatibilityBlackList + ) /** * The set of tests that are believed to be working in catalyst. Tests not on whiteList or diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index fbf2c7d8cbc06..800f51c5e2e86 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -17,37 +17,34 @@ package org.apache.spark.sql.hive -import java.io.{BufferedReader, File, InputStreamReader, PrintStream} +import java.io.File import java.net.{URL, URLClassLoader} import java.sql.Timestamp -import java.util.{ArrayList => JArrayList} -import org.apache.hadoop.hive.ql.parse.VariableSubstitution +import org.apache.hadoop.hive.common.StatsSetupConst +import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.spark.sql.catalyst.ParserDialect import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.mutable.HashMap import scala.language.implicitConversions import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.parse.VariableSubstitution -import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateSubQueries, OverrideCatalog, OverrideFunctionRegistry} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, QueryExecutionException, SetCommand} +import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} -import org.apache.spark.sql.sources.{DDLParser, DataSourceStrategy} +import org.apache.spark.sql.sources.DataSourceStrategy import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -331,7 +328,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val tableParameters = relation.hiveQlTable.getParameters val oldTotalSize = - Option(tableParameters.get(HiveShim.getStatsSetupConstTotalSize)) + Option(tableParameters.get(StatsSetupConst.TOTAL_SIZE)) .map(_.toLong) .getOrElse(0L) val newTotalSize = getFileSizeForTable(hiveconf, relation.hiveQlTable) @@ -342,7 +339,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.client.alterTable( relation.table.copy( properties = relation.table.properties + - (HiveShim.getStatsSetupConstTotalSize -> newTotalSize.toString))) + (StatsSetupConst.TOTAL_SIZE -> newTotalSize.toString))) } case otherRelation => throw new UnsupportedOperationException( @@ -564,7 +561,7 @@ private[hive] object HiveContext { case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8") case (decimal: java.math.BigDecimal, DecimalType()) => // Hive strips trailing zeros so use its toString - HiveShim.createDecimal(decimal).toString + HiveDecimal.create(decimal).toString case (other, tpe) if primitiveTypes contains tpe => other.toString } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 24cd335082639..c466203cd0220 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _} +import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfoFactory} import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} @@ -350,7 +351,7 @@ private[hive] trait HiveInspectors { new HiveVarchar(s, s.size) case _: JavaHiveDecimalObjectInspector => - (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toJavaBigDecimal) + (o: Any) => HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal) case _: JavaDateObjectInspector => (o: Any) => DateUtils.toJavaDate(o.asInstanceOf[Int]) @@ -439,31 +440,31 @@ private[hive] trait HiveInspectors { case _ if a == null => null case x: PrimitiveObjectInspector => x match { // TODO we don't support the HiveVarcharObjectInspector yet. - case _: StringObjectInspector if x.preferWritable() => HiveShim.getStringWritable(a) + case _: StringObjectInspector if x.preferWritable() => getStringWritable(a) case _: StringObjectInspector => a.asInstanceOf[UTF8String].toString() - case _: IntObjectInspector if x.preferWritable() => HiveShim.getIntWritable(a) + case _: IntObjectInspector if x.preferWritable() => getIntWritable(a) case _: IntObjectInspector => a.asInstanceOf[java.lang.Integer] - case _: BooleanObjectInspector if x.preferWritable() => HiveShim.getBooleanWritable(a) + case _: BooleanObjectInspector if x.preferWritable() => getBooleanWritable(a) case _: BooleanObjectInspector => a.asInstanceOf[java.lang.Boolean] - case _: FloatObjectInspector if x.preferWritable() => HiveShim.getFloatWritable(a) + case _: FloatObjectInspector if x.preferWritable() => getFloatWritable(a) case _: FloatObjectInspector => a.asInstanceOf[java.lang.Float] - case _: DoubleObjectInspector if x.preferWritable() => HiveShim.getDoubleWritable(a) + case _: DoubleObjectInspector if x.preferWritable() => getDoubleWritable(a) case _: DoubleObjectInspector => a.asInstanceOf[java.lang.Double] - case _: LongObjectInspector if x.preferWritable() => HiveShim.getLongWritable(a) + case _: LongObjectInspector if x.preferWritable() => getLongWritable(a) case _: LongObjectInspector => a.asInstanceOf[java.lang.Long] - case _: ShortObjectInspector if x.preferWritable() => HiveShim.getShortWritable(a) + case _: ShortObjectInspector if x.preferWritable() => getShortWritable(a) case _: ShortObjectInspector => a.asInstanceOf[java.lang.Short] - case _: ByteObjectInspector if x.preferWritable() => HiveShim.getByteWritable(a) + case _: ByteObjectInspector if x.preferWritable() => getByteWritable(a) case _: ByteObjectInspector => a.asInstanceOf[java.lang.Byte] case _: HiveDecimalObjectInspector if x.preferWritable() => - HiveShim.getDecimalWritable(a.asInstanceOf[Decimal]) + getDecimalWritable(a.asInstanceOf[Decimal]) case _: HiveDecimalObjectInspector => - HiveShim.createDecimal(a.asInstanceOf[Decimal].toJavaBigDecimal) - case _: BinaryObjectInspector if x.preferWritable() => HiveShim.getBinaryWritable(a) + HiveDecimal.create(a.asInstanceOf[Decimal].toJavaBigDecimal) + case _: BinaryObjectInspector if x.preferWritable() => getBinaryWritable(a) case _: BinaryObjectInspector => a.asInstanceOf[Array[Byte]] - case _: DateObjectInspector if x.preferWritable() => HiveShim.getDateWritable(a) + case _: DateObjectInspector if x.preferWritable() => getDateWritable(a) case _: DateObjectInspector => DateUtils.toJavaDate(a.asInstanceOf[Int]) - case _: TimestampObjectInspector if x.preferWritable() => HiveShim.getTimestampWritable(a) + case _: TimestampObjectInspector if x.preferWritable() => getTimestampWritable(a) case _: TimestampObjectInspector => a.asInstanceOf[java.sql.Timestamp] } case x: SettableStructObjectInspector => @@ -574,31 +575,31 @@ private[hive] trait HiveInspectors { */ def toInspector(expr: Expression): ObjectInspector = expr match { case Literal(value, StringType) => - HiveShim.getStringWritableConstantObjectInspector(value) + getStringWritableConstantObjectInspector(value) case Literal(value, IntegerType) => - HiveShim.getIntWritableConstantObjectInspector(value) + getIntWritableConstantObjectInspector(value) case Literal(value, DoubleType) => - HiveShim.getDoubleWritableConstantObjectInspector(value) + getDoubleWritableConstantObjectInspector(value) case Literal(value, BooleanType) => - HiveShim.getBooleanWritableConstantObjectInspector(value) + getBooleanWritableConstantObjectInspector(value) case Literal(value, LongType) => - HiveShim.getLongWritableConstantObjectInspector(value) + getLongWritableConstantObjectInspector(value) case Literal(value, FloatType) => - HiveShim.getFloatWritableConstantObjectInspector(value) + getFloatWritableConstantObjectInspector(value) case Literal(value, ShortType) => - HiveShim.getShortWritableConstantObjectInspector(value) + getShortWritableConstantObjectInspector(value) case Literal(value, ByteType) => - HiveShim.getByteWritableConstantObjectInspector(value) + getByteWritableConstantObjectInspector(value) case Literal(value, BinaryType) => - HiveShim.getBinaryWritableConstantObjectInspector(value) + getBinaryWritableConstantObjectInspector(value) case Literal(value, DateType) => - HiveShim.getDateWritableConstantObjectInspector(value) + getDateWritableConstantObjectInspector(value) case Literal(value, TimestampType) => - HiveShim.getTimestampWritableConstantObjectInspector(value) + getTimestampWritableConstantObjectInspector(value) case Literal(value, DecimalType()) => - HiveShim.getDecimalWritableConstantObjectInspector(value) + getDecimalWritableConstantObjectInspector(value) case Literal(_, NullType) => - HiveShim.getPrimitiveNullWritableConstantObjectInspector + getPrimitiveNullWritableConstantObjectInspector case Literal(value, ArrayType(dt, _)) => val listObjectInspector = toInspector(dt) if (value == null) { @@ -658,8 +659,8 @@ private[hive] trait HiveInspectors { case _: JavaFloatObjectInspector => FloatType case _: WritableBinaryObjectInspector => BinaryType case _: JavaBinaryObjectInspector => BinaryType - case w: WritableHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(w) - case j: JavaHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(j) + case w: WritableHiveDecimalObjectInspector => decimalTypeInfoToCatalyst(w) + case j: JavaHiveDecimalObjectInspector => decimalTypeInfoToCatalyst(j) case _: WritableDateObjectInspector => DateType case _: JavaDateObjectInspector => DateType case _: WritableTimestampObjectInspector => TimestampType @@ -668,10 +669,136 @@ private[hive] trait HiveInspectors { case _: JavaVoidObjectInspector => NullType } + private def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { + val info = inspector.getTypeInfo.asInstanceOf[DecimalTypeInfo] + DecimalType(info.precision(), info.scale()) + } + + private def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.stringTypeInfo, getStringWritable(value)) + + private def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.intTypeInfo, getIntWritable(value)) + + private def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.doubleTypeInfo, getDoubleWritable(value)) + + private def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.booleanTypeInfo, getBooleanWritable(value)) + + private def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.longTypeInfo, getLongWritable(value)) + + private def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.floatTypeInfo, getFloatWritable(value)) + + private def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.shortTypeInfo, getShortWritable(value)) + + private def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.byteTypeInfo, getByteWritable(value)) + + private def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.binaryTypeInfo, getBinaryWritable(value)) + + private def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.dateTypeInfo, getDateWritable(value)) + + private def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.timestampTypeInfo, getTimestampWritable(value)) + + private def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.decimalTypeInfo, getDecimalWritable(value)) + + private def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.voidTypeInfo, null) + + private def getStringWritable(value: Any): hadoopIo.Text = + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) + + private def getIntWritable(value: Any): hadoopIo.IntWritable = + if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) + + private def getDoubleWritable(value: Any): hiveIo.DoubleWritable = + if (value == null) { + null + } else { + new hiveIo.DoubleWritable(value.asInstanceOf[Double]) + } + + private def getBooleanWritable(value: Any): hadoopIo.BooleanWritable = + if (value == null) { + null + } else { + new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) + } + + private def getLongWritable(value: Any): hadoopIo.LongWritable = + if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]) + + private def getFloatWritable(value: Any): hadoopIo.FloatWritable = + if (value == null) { + null + } else { + new hadoopIo.FloatWritable(value.asInstanceOf[Float]) + } + + private def getShortWritable(value: Any): hiveIo.ShortWritable = + if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]) + + private def getByteWritable(value: Any): hiveIo.ByteWritable = + if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]) + + private def getBinaryWritable(value: Any): hadoopIo.BytesWritable = + if (value == null) { + null + } else { + new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) + } + + private def getDateWritable(value: Any): hiveIo.DateWritable = + if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int]) + + private def getTimestampWritable(value: Any): hiveIo.TimestampWritable = + if (value == null) { + null + } else { + new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) + } + + private def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = + if (value == null) { + null + } else { + // TODO precise, scale? + new hiveIo.HiveDecimalWritable( + HiveDecimal.create(value.asInstanceOf[Decimal].toJavaBigDecimal)) + } + implicit class typeInfoConversions(dt: DataType) { import org.apache.hadoop.hive.serde2.typeinfo._ import TypeInfoFactory._ + private def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match { + case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale) + case _ => new DecimalTypeInfo( + HiveShim.UNLIMITED_DECIMAL_PRECISION, + HiveShim.UNLIMITED_DECIMAL_SCALE) + } + def toTypeInfo: TypeInfo = dt match { case ArrayType(elemType, _) => getListTypeInfo(elemType.toTypeInfo) @@ -690,7 +817,7 @@ private[hive] trait HiveInspectors { case LongType => longTypeInfo case ShortType => shortTypeInfo case StringType => stringTypeInfo - case d: DecimalType => HiveShim.decimalTypeInfo(d) + case d: DecimalType => decimalTypeInfo(d) case DateType => dateTypeInfo case TimestampType => timestampTypeInfo case NullType => voidTypeInfo diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index ca1f49b546bd7..5a4651a887b7c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql.hive import com.google.common.base.Objects import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} + import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.metastore.Warehouse import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.ql.metadata._ -import org.apache.hadoop.hive.serde2.Deserializer +import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} @@ -37,7 +39,6 @@ import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode, sources} -import org.apache.spark.util.Utils /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -670,8 +671,8 @@ private[hive] case class MetastoreRelation @transient override lazy val statistics: Statistics = Statistics( sizeInBytes = { - val totalSize = hiveQlTable.getParameters.get(HiveShim.getStatsSetupConstTotalSize) - val rawDataSize = hiveQlTable.getParameters.get(HiveShim.getStatsSetupConstRawDataSize) + val totalSize = hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE) + val rawDataSize = hiveQlTable.getParameters.get(StatsSetupConst.RAW_DATA_SIZE) // TODO: check if this estimate is valid for tables after partition pruning. // NOTE: getting `totalSize` directly from params is kind of hacky, but this should be // relatively cheap if parameters for the table are populated into the metastore. An @@ -697,11 +698,7 @@ private[hive] case class MetastoreRelation } } - val tableDesc = HiveShim.getTableDesc( - Class.forName( - hiveQlTable.getSerializationLib, - true, - Utils.getContextOrSparkClassLoader).asInstanceOf[Class[Deserializer]], + val tableDesc = new TableDesc( hiveQlTable.getInputFormatClass, // The class of table should be org.apache.hadoop.hive.ql.metadata.Table because // getOutputFormatClass will use HiveFileFormatUtils.getOutputFormatSubstitute to @@ -743,6 +740,11 @@ private[hive] case class MetastoreRelation private[hive] object HiveMetastoreTypes { def toDataType(metastoreType: String): DataType = DataTypeParser.parse(metastoreType) + def decimalMetastoreString(decimalType: DecimalType): String = decimalType match { + case DecimalType.Fixed(precision, scale) => s"decimal($precision,$scale)" + case _ => s"decimal($HiveShim.UNLIMITED_DECIMAL_PRECISION,$HiveShim.UNLIMITED_DECIMAL_SCALE)" + } + def toMetastoreType(dt: DataType): String = dt match { case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>" case StructType(fields) => @@ -759,7 +761,7 @@ private[hive] object HiveMetastoreTypes { case BinaryType => "binary" case BooleanType => "boolean" case DateType => "date" - case d: DecimalType => HiveShim.decimalMetastoreString(d) + case d: DecimalType => decimalMetastoreString(d) case TimestampType => "timestamp" case NullType => "void" case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index a5ca3613c5e00..9544d12c9053c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.hive import java.sql.Date -import scala.collection.mutable.ArrayBuffer - import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.ql.{ErrorMsg, Context} @@ -39,6 +37,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.execution.ExplainCommand import org.apache.spark.sql.sources.DescribeCommand +import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema} import org.apache.spark.sql.types._ @@ -46,6 +45,7 @@ import org.apache.spark.util.random.RandomSampler /* Implicit conversions */ import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer /** * Used when we need to start parsing the AST before deciding that we are going to pass the command diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala new file mode 100644 index 0000000000000..fa5409f602444 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.{InputStream, OutputStream} +import java.rmi.server.UID + +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} +import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} +import org.apache.hadoop.hive.serde2.ColumnProjectionUtils +import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable +import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector +import org.apache.hadoop.io.Writable + +import org.apache.spark.Logging +import org.apache.spark.sql.types.Decimal +import org.apache.spark.util.Utils + +/* Implicit conversions */ +import scala.collection.JavaConversions._ +import scala.reflect.ClassTag + +private[hive] object HiveShim { + // Precision and scale to pass for unlimited decimals; these are the same as the precision and + // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs) + val UNLIMITED_DECIMAL_PRECISION = 38 + val UNLIMITED_DECIMAL_SCALE = 18 + + /* + * This function in hive-0.13 become private, but we have to do this to walkaround hive bug + */ + private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) { + val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "") + val result: StringBuilder = new StringBuilder(old) + var first: Boolean = old.isEmpty + + for (col <- cols) { + if (first) { + first = false + } else { + result.append(',') + } + result.append(col) + } + conf.set(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, result.toString) + } + + /* + * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null or empty + */ + def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { + if (ids != null && ids.size > 0) { + ColumnProjectionUtils.appendReadColumns(conf, ids) + } + if (names != null && names.size > 0) { + appendReadColumnNames(conf, names) + } + } + + /* + * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that + * is needed to initialize before serialization. + */ + def prepareWritable(w: Writable): Writable = { + w match { + case w: AvroGenericRecordWritable => + w.setRecordReaderID(new UID()) + case _ => + } + w + } + + def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { + if (hdoi.preferWritable()) { + Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue, + hdoi.precision(), hdoi.scale()) + } else { + Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) + } + } + + /** + * This class provides the UDF creation and also the UDF instance serialization and + * de-serialization cross process boundary. + * + * Detail discussion can be found at https://github.com/apache/spark/pull/3640 + * + * @param functionClassName UDF class name + */ + private[hive] case class HiveFunctionWrapper(var functionClassName: String) + extends java.io.Externalizable { + + // for Serialization + def this() = this(null) + + @transient + def deserializeObjectByKryo[T: ClassTag]( + kryo: Kryo, + in: InputStream, + clazz: Class[_]): T = { + val inp = new Input(in) + val t: T = kryo.readObject(inp, clazz).asInstanceOf[T] + inp.close() + t + } + + @transient + def serializeObjectByKryo( + kryo: Kryo, + plan: Object, + out: OutputStream) { + val output: Output = new Output(out) + kryo.writeObject(output, plan) + output.close() + } + + def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = { + deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz) + .asInstanceOf[UDFType] + } + + def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = { + serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out) + } + + private var instance: AnyRef = null + + def writeExternal(out: java.io.ObjectOutput) { + // output the function name + out.writeUTF(functionClassName) + + // Write a flag if instance is null or not + out.writeBoolean(instance != null) + if (instance != null) { + // Some of the UDF are serializable, but some others are not + // Hive Utilities can handle both cases + val baos = new java.io.ByteArrayOutputStream() + serializePlan(instance, baos) + val functionInBytes = baos.toByteArray + + // output the function bytes + out.writeInt(functionInBytes.length) + out.write(functionInBytes, 0, functionInBytes.length) + } + } + + def readExternal(in: java.io.ObjectInput) { + // read the function name + functionClassName = in.readUTF() + + if (in.readBoolean()) { + // if the instance is not null + // read the function in bytes + val functionInBytesLength = in.readInt() + val functionInBytes = new Array[Byte](functionInBytesLength) + in.read(functionInBytes, 0, functionInBytesLength) + + // deserialize the function object via Hive Utilities + instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes), + Utils.getContextOrSparkClassLoader.loadClass(functionClassName)) + } + } + + def createFunction[UDFType <: AnyRef](): UDFType = { + if (instance != null) { + instance.asInstanceOf[UDFType] + } else { + val func = Utils.getContextOrSparkClassLoader + .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] + if (!func.isInstanceOf[UDF]) { + // We cache the function if it's no the Simple UDF, + // as we always have to create new instance for Simple UDF + instance = func + } + func + } + } + } + + /* + * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. + * Fix it through wrapper. + * */ + implicit def wrapperToFileSinkDesc(w: ShimFileSinkDesc): FileSinkDesc = { + var f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed) + f.setCompressCodec(w.compressCodec) + f.setCompressType(w.compressType) + f.setTableInfo(w.tableInfo) + f.setDestTableId(w.destTableId) + f + } + + /* + * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. + * Fix it through wrapper. + */ + private[hive] class ShimFileSinkDesc( + var dir: String, + var tableInfo: TableDesc, + var compressed: Boolean) + extends Serializable with Logging { + var compressCodec: String = _ + var compressType: String = _ + var destTableId: Int = _ + + def setCompressed(compressed: Boolean) { + this.compressed = compressed + } + + def getDirName(): String = dir + + def setDestTableId(destTableId: Int) { + this.destTableId = destTableId + } + + def setTableInfo(tableInfo: TableDesc) { + this.tableInfo = tableInfo + } + + def setCompressCodec(intermediateCompressorCodec: String) { + compressCodec = intermediateCompressorCodec + } + + def setCompressType(intermediateCompressType: String) { + compressType = intermediateCompressType + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 294fc3bd7d5e9..334bfccc9d200 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -25,14 +25,13 @@ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.hive.serde2.Deserializer -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.primitive._ +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} -import org.apache.spark.SerializableWritable +import org.apache.spark.{Logging, SerializableWritable} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.Logging import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateUtils @@ -172,7 +171,7 @@ class HadoopTableReader( path.toString + tails } - val partPath = HiveShim.getDataLocationPath(partition) + val partPath = partition.getDataLocation val partNum = Utilities.getPartitionDesc(partition).getPartSpec.size(); var pathPatternStr = getPathPatternByPath(partNum, partPath) if (!pathPatternSet.contains(pathPatternStr)) { @@ -187,7 +186,7 @@ class HadoopTableReader( val hivePartitionRDDs = verifyPartitionPath(partitionToDeserializer) .map { case (partition, partDeserializer) => val partDesc = Utilities.getPartitionDesc(partition) - val partPath = HiveShim.getDataLocationPath(partition) + val partPath = partition.getDataLocation val inputPathStr = applyFilterIfNeeded(partPath, filterOpt) val ifc = partDesc.getInputFileFormatClass .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] @@ -325,7 +324,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { val soi = if (rawDeser.getObjectInspector.equals(tableDeser.getObjectInspector)) { rawDeser.getObjectInspector.asInstanceOf[StructObjectInspector] } else { - HiveShim.getConvertedOI( + ObjectInspectorConverters.getConvertedOI( rawDeser.getObjectInspector, tableDeser.getObjectInspector).asInstanceOf[StructObjectInspector] } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 8613332186f28..eeb472602be3c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -19,27 +19,25 @@ package org.apache.spark.sql.hive.execution import java.util -import scala.collection.JavaConversions._ - import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.metastore.MetaStoreUtils -import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} +import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} +import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.{ ShimFileSinkDesc => FileSinkDesc} -import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.{SerializableWritable, SparkException, TaskContext} +import scala.collection.JavaConversions._ + private[hive] case class InsertIntoHiveTable( table: MetastoreRelation, @@ -126,7 +124,7 @@ case class InsertIntoHiveTable( // instances within the closure, since Serializer is not serializable while TableDesc is. val tableDesc = table.tableDesc val tableLocation = table.hiveQlTable.getDataLocation - val tmpLocation = HiveShim.getExternalTmpPath(hiveContext, tableLocation) + val tmpLocation = hiveContext.getExternalTmpPath(tableLocation.toUri) val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) val isCompressed = sc.hiveconf.getBoolean( ConfVars.COMPRESSRESULT.varname, ConfVars.COMPRESSRESULT.defaultBoolVal) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 1658bb93b0b79..01f47352b2313 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.types._ /* Implicit conversions */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 2bb526b14be34..ee440e304ec19 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -35,8 +35,7 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.Row import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.hive.{ShimFileSinkDesc => FileSinkDesc} -import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.types._ /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 58e2d1fbfa73e..af586712e3235 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -561,30 +561,28 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA } } - if (HiveShim.version == "0.13.1") { - test("scan a parquet table created through a CTAS statement") { - withSQLConf( - "spark.sql.hive.convertMetastoreParquet" -> "true", - SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { - - withTempTable("jt") { - (1 to 10).map(i => i -> s"str$i").toDF("a", "b").registerTempTable("jt") - - withTable("test_parquet_ctas") { - sql( - """CREATE TABLE test_parquet_ctas STORED AS PARQUET - |AS SELECT tmp.a FROM jt tmp WHERE tmp.a < 5 - """.stripMargin) - - checkAnswer( - sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "), - Row(3) :: Row(4) :: Nil) - - table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(p: ParquetRelation2) => // OK - case _ => - fail(s"test_parquet_ctas should have be converted to ${classOf[ParquetRelation2]}") - } + test("scan a parquet table created through a CTAS statement") { + withSQLConf( + "spark.sql.hive.convertMetastoreParquet" -> "true", + SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { + + withTempTable("jt") { + (1 to 10).map(i => i -> s"str$i").toDF("a", "b").registerTempTable("jt") + + withTable("test_parquet_ctas") { + sql( + """CREATE TABLE test_parquet_ctas STORED AS PARQUET + |AS SELECT tmp.a FROM jt tmp WHERE tmp.a < 5 + """.stripMargin) + + checkAnswer( + sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "), + Row(3) :: Row(4) :: Nil) + + table("test_parquet_ctas").queryExecution.optimizedPlan match { + case LogicalRelation(p: ParquetRelation2) => // OK + case _ => + fail(s"test_parquet_ctas should have be converted to ${classOf[ParquetRelation2]}") } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 00a69de9e4262..e16e530555aee 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -79,10 +79,6 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() - // TODO: How does it works? needs to add it back for other hive version. - if (HiveShim.version =="0.12.0") { - assert(queryTotalSize("analyzeTable") === conf.defaultSizeInBytes) - } sql("ANALYZE TABLE analyzeTable COMPUTE STATISTICS noscan") assert(queryTotalSize("analyzeTable") === BigInt(11624)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 440b7c87b0da2..6d8d99ebc8164 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -874,15 +874,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |WITH serdeproperties('s1'='9') """.stripMargin) } - // Now only verify 0.12.0, and ignore other versions due to binary compatibility - // current TestSerDe.jar is from 0.12.0 - if (HiveShim.version == "0.12.0") { - sql(s"ADD JAR $testJar") - sql( - """ALTER TABLE alter1 SET SERDE 'org.apache.hadoop.hive.serde2.TestSerDe' - |WITH serdeproperties('s1'='9') - """.stripMargin) - } sql("DROP TABLE alter1") } @@ -890,15 +881,13 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // this is a test case from mapjoin_addjar.q val testJar = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath val testData = TestHive.getHiveFile("data/files/sample.json").getCanonicalPath - if (HiveShim.version == "0.13.1") { - sql(s"ADD JAR $testJar") - sql( - """CREATE TABLE t1(a string, b string) - |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'""".stripMargin) - sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE t1""") - sql("select * from src join t1 on src.key = t1.a") - sql("DROP TABLE t1") - } + sql(s"ADD JAR $testJar") + sql( + """CREATE TABLE t1(a string, b string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'""".stripMargin) + sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE t1""") + sql("select * from src join t1 on src.key = t1.a") + sql("DROP TABLE t1") } test("ADD FILE command") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index aba3becb1bce2..40a35674e4cb8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.hive.{HiveQLDialect, HiveShim, MetastoreRelation} +import org.apache.spark.sql.hive.{HiveQLDialect, MetastoreRelation} import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ @@ -330,35 +330,33 @@ class SQLQuerySuite extends QueryTest { "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22", "MANAGED_TABLE" ) - if (HiveShim.version =="0.13.1") { - val origUseParquetDataSource = conf.parquetUseDataSourceApi - try { - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") - sql( - """CREATE TABLE ctas5 - | STORED AS parquet AS - | SELECT key, value - | FROM src - | ORDER BY key, value""".stripMargin).collect() - - checkExistence(sql("DESC EXTENDED ctas5"), true, - "name:key", "type:string", "name:value", "ctas5", - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", - "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", - "MANAGED_TABLE" - ) - - val default = getConf("spark.sql.hive.convertMetastoreParquet", "true") - // use the Hive SerDe for parquet tables - sql("set spark.sql.hive.convertMetastoreParquet = false") - checkAnswer( - sql("SELECT key, value FROM ctas5 ORDER BY key, value"), - sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) - sql(s"set spark.sql.hive.convertMetastoreParquet = $default") - } finally { - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource.toString) - } + val origUseParquetDataSource = conf.parquetUseDataSourceApi + try { + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + sql( + """CREATE TABLE ctas5 + | STORED AS parquet AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin).collect() + + checkExistence(sql("DESC EXTENDED ctas5"), true, + "name:key", "type:string", "name:value", "ctas5", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", + "MANAGED_TABLE" + ) + + val default = getConf("spark.sql.hive.convertMetastoreParquet", "true") + // use the Hive SerDe for parquet tables + sql("set spark.sql.hive.convertMetastoreParquet = false") + checkAnswer( + sql("SELECT key, value FROM ctas5 ORDER BY key, value"), + sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) + sql(s"set spark.sql.hive.convertMetastoreParquet = $default") + } finally { + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource.toString) } } diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala deleted file mode 100644 index dbc5e029e2047..0000000000000 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ /dev/null @@ -1,457 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import java.rmi.server.UID -import java.util.{Properties, ArrayList => JArrayList} -import java.io.{OutputStream, InputStream} - -import scala.collection.JavaConversions._ -import scala.language.implicitConversions -import scala.reflect.ClassTag - -import com.esotericsoftware.kryo.Kryo -import com.esotericsoftware.kryo.io.Input -import com.esotericsoftware.kryo.io.Output -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.common.StatsSetupConst -import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.Context -import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} -import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} -import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} -import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory -import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable -import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorConverters, PrimitiveObjectInspector} -import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfo, TypeInfoFactory} -import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo} -import org.apache.hadoop.io.{NullWritable, Writable} -import org.apache.hadoop.mapred.InputFormat -import org.apache.hadoop.{io => hadoopIo} - -import org.apache.spark.Logging -import org.apache.spark.sql.types.{Decimal, DecimalType, UTF8String} -import org.apache.spark.util.Utils._ - -/** - * This class provides the UDF creation and also the UDF instance serialization and - * de-serialization cross process boundary. - * - * Detail discussion can be found at https://github.com/apache/spark/pull/3640 - * - * @param functionClassName UDF class name - */ -private[hive] case class HiveFunctionWrapper(var functionClassName: String) - extends java.io.Externalizable { - - // for Serialization - def this() = this(null) - - @transient - def deserializeObjectByKryo[T: ClassTag]( - kryo: Kryo, - in: InputStream, - clazz: Class[_]): T = { - val inp = new Input(in) - val t: T = kryo.readObject(inp,clazz).asInstanceOf[T] - inp.close() - t - } - - @transient - def serializeObjectByKryo( - kryo: Kryo, - plan: Object, - out: OutputStream ) { - val output: Output = new Output(out) - kryo.writeObject(output, plan) - output.close() - } - - def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = { - deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz) - .asInstanceOf[UDFType] - } - - def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = { - serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out) - } - - private var instance: AnyRef = null - - def writeExternal(out: java.io.ObjectOutput) { - // output the function name - out.writeUTF(functionClassName) - - // Write a flag if instance is null or not - out.writeBoolean(instance != null) - if (instance != null) { - // Some of the UDF are serializable, but some others are not - // Hive Utilities can handle both cases - val baos = new java.io.ByteArrayOutputStream() - serializePlan(instance, baos) - val functionInBytes = baos.toByteArray - - // output the function bytes - out.writeInt(functionInBytes.length) - out.write(functionInBytes, 0, functionInBytes.length) - } - } - - def readExternal(in: java.io.ObjectInput) { - // read the function name - functionClassName = in.readUTF() - - if (in.readBoolean()) { - // if the instance is not null - // read the function in bytes - val functionInBytesLength = in.readInt() - val functionInBytes = new Array[Byte](functionInBytesLength) - in.read(functionInBytes, 0, functionInBytesLength) - - // deserialize the function object via Hive Utilities - instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes), - getContextOrSparkClassLoader.loadClass(functionClassName)) - } - } - - def createFunction[UDFType <: AnyRef](): UDFType = { - if (instance != null) { - instance.asInstanceOf[UDFType] - } else { - val func = getContextOrSparkClassLoader - .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] - if (!func.isInstanceOf[UDF]) { - // We cache the function if it's no the Simple UDF, - // as we always have to create new instance for Simple UDF - instance = func - } - func - } - } -} - -/** - * A compatibility layer for interacting with Hive version 0.13.1. - */ -private[hive] object HiveShim { - val version = "0.13.1" - - def getTableDesc( - serdeClass: Class[_ <: Deserializer], - inputFormatClass: Class[_ <: InputFormat[_, _]], - outputFormatClass: Class[_], - properties: Properties) = { - new TableDesc(inputFormatClass, outputFormatClass, properties) - } - - - def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.stringTypeInfo, getStringWritable(value)) - - def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.intTypeInfo, getIntWritable(value)) - - def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.doubleTypeInfo, getDoubleWritable(value)) - - def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.booleanTypeInfo, getBooleanWritable(value)) - - def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.longTypeInfo, getLongWritable(value)) - - def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.floatTypeInfo, getFloatWritable(value)) - - def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.shortTypeInfo, getShortWritable(value)) - - def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.byteTypeInfo, getByteWritable(value)) - - def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.binaryTypeInfo, getBinaryWritable(value)) - - def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.dateTypeInfo, getDateWritable(value)) - - def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.timestampTypeInfo, getTimestampWritable(value)) - - def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.decimalTypeInfo, getDecimalWritable(value)) - - def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.voidTypeInfo, null) - - def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) - - def getIntWritable(value: Any): hadoopIo.IntWritable = - if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) - - def getDoubleWritable(value: Any): hiveIo.DoubleWritable = - if (value == null) { - null - } else { - new hiveIo.DoubleWritable(value.asInstanceOf[Double]) - } - - def getBooleanWritable(value: Any): hadoopIo.BooleanWritable = - if (value == null) { - null - } else { - new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) - } - - def getLongWritable(value: Any): hadoopIo.LongWritable = - if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]) - - def getFloatWritable(value: Any): hadoopIo.FloatWritable = - if (value == null) { - null - } else { - new hadoopIo.FloatWritable(value.asInstanceOf[Float]) - } - - def getShortWritable(value: Any): hiveIo.ShortWritable = - if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]) - - def getByteWritable(value: Any): hiveIo.ByteWritable = - if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]) - - def getBinaryWritable(value: Any): hadoopIo.BytesWritable = - if (value == null) { - null - } else { - new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) - } - - def getDateWritable(value: Any): hiveIo.DateWritable = - if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int]) - - def getTimestampWritable(value: Any): hiveIo.TimestampWritable = - if (value == null) { - null - } else { - new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) - } - - def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = - if (value == null) { - null - } else { - // TODO precise, scale? - new hiveIo.HiveDecimalWritable( - HiveShim.createDecimal(value.asInstanceOf[Decimal].toJavaBigDecimal)) - } - - def getPrimitiveNullWritable: NullWritable = NullWritable.get() - - def createDriverResultsArray = new JArrayList[Object] - - def processResults(results: JArrayList[Object]) = { - results.map { r => - r match { - case s: String => s - case a: Array[Object] => a(0).asInstanceOf[String] - } - } - } - - def getStatsSetupConstTotalSize = StatsSetupConst.TOTAL_SIZE - - def getStatsSetupConstRawDataSize = StatsSetupConst.RAW_DATA_SIZE - - def createDefaultDBIfNeeded(context: HiveContext) = { - context.runSqlHive("CREATE DATABASE default") - context.runSqlHive("USE default") - } - - def getCommandProcessor(cmd: Array[String], conf: HiveConf) = { - CommandProcessorFactory.get(cmd, conf) - } - - def createDecimal(bd: java.math.BigDecimal): HiveDecimal = { - HiveDecimal.create(bd) - } - - /* - * This function in hive-0.13 become private, but we have to do this to walkaround hive bug - */ - private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) { - val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "") - val result: StringBuilder = new StringBuilder(old) - var first: Boolean = old.isEmpty - - for (col <- cols) { - if (first) { - first = false - } else { - result.append(',') - } - result.append(col) - } - conf.set(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, result.toString) - } - - /* - * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null or empty - */ - def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { - if (ids != null && ids.size > 0) { - ColumnProjectionUtils.appendReadColumns(conf, ids) - } - if (names != null && names.size > 0) { - appendReadColumnNames(conf, names) - } - } - - def getExternalTmpPath(context: Context, path: Path) = { - context.getExternalTmpPath(path.toUri) - } - - def getDataLocationPath(p: Partition) = p.getDataLocation - - def getAllPartitionsOf(client: Hive, tbl: Table) = client.getAllPartitionsOf(tbl) - - def compatibilityBlackList = Seq() - - def setLocation(tbl: Table, crtTbl: CreateTableDesc): Unit = { - tbl.setDataLocation(new Path(crtTbl.getLocation())) - } - - /* - * Bug introdiced in hive-0.13. FileSinkDesc is serializable, but its member path is not. - * Fix it through wrapper. - * */ - implicit def wrapperToFileSinkDesc(w: ShimFileSinkDesc): FileSinkDesc = { - var f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed) - f.setCompressCodec(w.compressCodec) - f.setCompressType(w.compressType) - f.setTableInfo(w.tableInfo) - f.setDestTableId(w.destTableId) - f - } - - // Precision and scale to pass for unlimited decimals; these are the same as the precision and - // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs) - private val UNLIMITED_DECIMAL_PRECISION = 38 - private val UNLIMITED_DECIMAL_SCALE = 18 - - def decimalMetastoreString(decimalType: DecimalType): String = decimalType match { - case DecimalType.Fixed(precision, scale) => s"decimal($precision,$scale)" - case _ => s"decimal($UNLIMITED_DECIMAL_PRECISION,$UNLIMITED_DECIMAL_SCALE)" - } - - def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match { - case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale) - case _ => new DecimalTypeInfo(UNLIMITED_DECIMAL_PRECISION, UNLIMITED_DECIMAL_SCALE) - } - - def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { - val info = inspector.getTypeInfo.asInstanceOf[DecimalTypeInfo] - DecimalType(info.precision(), info.scale()) - } - - def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { - if (hdoi.preferWritable()) { - Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue, - hdoi.precision(), hdoi.scale()) - } else { - Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) - } - } - - def getConvertedOI(inputOI: ObjectInspector, outputOI: ObjectInspector): ObjectInspector = { - ObjectInspectorConverters.getConvertedOI(inputOI, outputOI) - } - - /* - * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that - * is needed to initialize before serialization. - */ - def prepareWritable(w: Writable): Writable = { - w match { - case w: AvroGenericRecordWritable => - w.setRecordReaderID(new UID()) - case _ => - } - w - } - - def setTblNullFormat(crtTbl: CreateTableDesc, tbl: Table) = { - if (crtTbl != null && crtTbl.getNullFormat() != null) { - tbl.setSerdeParam(serdeConstants.SERIALIZATION_NULL_FORMAT, crtTbl.getNullFormat()) - } - } -} - -/* - * Bug introduced in hive-0.13. FileSinkDesc is serilizable, but its member path is not. - * Fix it through wrapper. - */ -private[hive] class ShimFileSinkDesc( - var dir: String, - var tableInfo: TableDesc, - var compressed: Boolean) - extends Serializable with Logging { - var compressCodec: String = _ - var compressType: String = _ - var destTableId: Int = _ - - def setCompressed(compressed: Boolean) { - this.compressed = compressed - } - - def getDirName = dir - - def setDestTableId(destTableId: Int) { - this.destTableId = destTableId - } - - def setTableInfo(tableInfo: TableDesc) { - this.tableInfo = tableInfo - } - - def setCompressCodec(intermediateCompressorCodec: String) { - compressCodec = intermediateCompressorCodec - } - - def setCompressType(intermediateCompressType: String) { - compressType = intermediateCompressType - } -} From 65938422718383d17f084e577763e2c671726baa Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 4 Jun 2015 13:44:47 -0700 Subject: [PATCH 06/17] Fixed style issues for [SPARK-6909][SQL] Remove Hive Shim code. --- .../sql/hive/thriftserver/HiveThriftServer2.scala | 5 +++-- .../hive/thriftserver/SparkSQLSessionManager.scala | 14 ++++++++------ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 5b391d3dce882..c9da25253e13f 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.hive.thriftserver +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + import org.apache.commons.logging.LogFactory import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars @@ -32,8 +35,6 @@ import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkContext} -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer /** * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 357b27f7401a3..2d5ee68002286 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager + private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) extends SessionManager with ReflectedCompositeService { @@ -50,12 +51,13 @@ private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) initCompositeService(hiveConf) } - override def openSession(protocol: TProtocolVersion, - username: String, - passwd: String, - sessionConf: java.util.Map[String, String], - withImpersonation: Boolean, - delegationToken: String): SessionHandle = { + override def openSession( + protocol: TProtocolVersion, + username: String, + passwd: String, + sessionConf: java.util.Map[String, String], + withImpersonation: Boolean, + delegationToken: String): SessionHandle = { hiveContext.openSession() val sessionHandle = super.openSession( protocol, username, passwd, sessionConf, withImpersonation, delegationToken) From 2bcdf8c239d2ba79f64fb8878da83d4c2ec28b30 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 4 Jun 2015 13:52:53 -0700 Subject: [PATCH 07/17] [SPARK-7440][SQL] Remove physical Distinct operator in favor of Aggregate This patch replaces Distinct with Aggregate in the optimizer, so Distinct will become more efficient over time as we optimize Aggregate (via Tungsten). Author: Reynold Xin Closes #6637 from rxin/replace-distinct and squashes the following commits: b3cc50e [Reynold Xin] Mima excludes. 93d6117 [Reynold Xin] Code review feedback. 87e4741 [Reynold Xin] [SPARK-7440][SQL] Remove physical Distinct operator in favor of Aggregate. --- project/MimaExcludes.scala | 4 +- .../sql/catalyst/optimizer/Optimizer.scala | 14 +++++++ .../plans/logical/basicOperators.scala | 3 ++ .../ReplaceDistinctWithAggregateSuite.scala | 42 +++++++++++++++++++ .../org/apache/spark/sql/DataFrame.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 4 +- .../spark/sql/execution/basicOperators.scala | 31 -------------- 7 files changed, 65 insertions(+), 35 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 34371c9659423..73e4bfd78e577 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -46,7 +46,9 @@ object MimaExcludes { "org.apache.spark.api.java.JavaRDDLike.partitioner"), // Mima false positive (was a private[spark] class) ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.util.collection.PairIterator") + "org.apache.spark.util.collection.PairIterator"), + // SQL execution is considered private. + excludePackage("org.apache.spark.sql.execution") ) case v if v.startsWith("1.4") => Seq( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5c6379b8d44b0..0a17b10c521e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -36,6 +36,8 @@ object DefaultOptimizer extends Optimizer { // SubQueries are only needed for analysis and can be removed before execution. Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: + Batch("Distinct", FixedPoint(100), + ReplaceDistinctWithAggregate) :: Batch("Operator Reordering", FixedPoint(100), UnionPushdown, CombineFilters, @@ -696,3 +698,15 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] { LocalRelation(projectList.map(_.toAttribute), data.map(projection)) } } + +/** + * Replaces logical [[Distinct]] operator with an [[Aggregate]] operator. + * {{{ + * SELECT DISTINCT f1, f2 FROM t ==> SELECT f1, f2 FROM t GROUP BY f1, f2 + * }}} + */ +object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Distinct(child) => Aggregate(child.output, child.output, child) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 33a9e55a47dee..e77e5c27b687a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -339,6 +339,9 @@ case class Sample( override def output: Seq[Attribute] = child.output } +/** + * Returns a new logical plan that dedups input rows. + */ case class Distinct(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala new file mode 100644 index 0000000000000..df29a62ff0e15 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class ReplaceDistinctWithAggregateSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("ProjectCollapsing", Once, ReplaceDistinctWithAggregate) :: Nil + } + + test("replace distinct with aggregate") { + val input = LocalRelation('a.int, 'b.int) + + val query = Distinct(input) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = Aggregate(input.output, input.output, input) + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d1a54ada7b191..4a224153e1a37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1311,7 +1311,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - override def distinct: DataFrame = Distinct(logicalPlan) + override def distinct: DataFrame = dropDuplicates() /** * @group basic diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d0a1ad00560d3..7a1331a39151a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -284,8 +284,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case r: RunnableCommand => ExecutedCommand(r) :: Nil case logical.Distinct(child) => - execution.Distinct(partial = false, - execution.Distinct(partial = true, planLater(child))) :: Nil + throw new IllegalStateException( + "logical distinct operator should have been replaced by aggregate in the optimizer") case logical.Repartition(numPartitions, shuffle, child) => execution.Repartition(numPartitions, shuffle, planLater(child)) :: Nil case logical.SortPartitions(sortExprs, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index a30ade86441ca..fb42072f9d5a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -230,37 +230,6 @@ case class ExternalSort( override def outputOrdering: Seq[SortOrder] = sortOrder } -/** - * :: DeveloperApi :: - * Computes the set of distinct input rows using a HashSet. - * @param partial when true the distinct operation is performed partially, per partition, without - * shuffling the data. - * @param child the input query plan. - */ -@DeveloperApi -case class Distinct(partial: Boolean, child: SparkPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output - - override def requiredChildDistribution: Seq[Distribution] = - if (partial) UnspecifiedDistribution :: Nil else ClusteredDistribution(child.output) :: Nil - - protected override def doExecute(): RDD[Row] = { - child.execute().mapPartitions { iter => - val hashSet = new scala.collection.mutable.HashSet[Row]() - - var currentRow: Row = null - while (iter.hasNext) { - currentRow = iter.next() - if (!hashSet.contains(currentRow)) { - hashSet.add(currentRow.copy()) - } - } - - hashSet.iterator - } - } -} - /** * :: DeveloperApi :: * Return a new RDD that has exactly `numPartitions` partitions. From 63bc0c4430680cce230dd7a10d34da0492351446 Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Thu, 4 Jun 2015 16:24:50 -0700 Subject: [PATCH 08/17] [SPARK-8098] [WEBUI] Show correct length of bytes on log page The log page should only show desired length of bytes. Currently it shows bytes from the startIndex to the end of the file. The "Next" button on the page is always disabled. Author: Carson Wang Closes #6640 from carsonwang/logpage and squashes the following commits: 58cb3fd [Carson Wang] Show correct length of bytes on log page --- .../main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 53f8f9a46cf8d..5a1d06eb87db9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -159,7 +159,7 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with offset } } - val endIndex = math.min(startIndex + totalLength, totalLength) + val endIndex = math.min(startIndex + byteLength, totalLength) logDebug(s"Getting log from $startIndex to $endIndex") val logText = Utils.offsetBytes(files, startIndex, endIndex) logDebug(s"Got log of length ${logText.length} bytes") From 74dc2a90bcb05b64c3e7efc02d1451b0cbc2adba Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 4 Jun 2015 17:33:24 -0700 Subject: [PATCH 09/17] [SPARK-8106] [SQL] Set derby.system.durability=test to speed up Hive compatibility tests Derby has a `derby.system.durability` configuration property that can be used to disable I/O synchronization calls for writes. This sacrifices durability but can result in large performance gains, which is appropriate for tests. We should enable this in our test system properties in order to speed up the Hive compatibility tests. I saw 2-3x speedups locally with this change. See https://db.apache.org/derby/docs/10.8/ref/rrefproperdurability.html for more documentation of this property. Author: Josh Rosen Closes #6651 from JoshRosen/hive-compat-suite-speedup and squashes the following commits: b7a08a2 [Josh Rosen] Set derby.system.durability=test in our unit tests. --- pom.xml | 2 ++ project/SparkBuild.scala | 1 + 2 files changed, 3 insertions(+) diff --git a/pom.xml b/pom.xml index abb9b55400340..e28d4b9fc2b17 100644 --- a/pom.xml +++ b/pom.xml @@ -1254,6 +1254,7 @@ ${test.java.home} + test true ${spark.test.home} 1 @@ -1286,6 +1287,7 @@ ${test.java.home} + test true ${spark.test.home} 1 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f65031fe25ac2..ef3a175bac209 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -504,6 +504,7 @@ object TestSettings { javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", + javaOptions in Test += "-Dderby.system.durability=test", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") .map { case (k,v) => s"-D$k=$v" }.toSeq, javaOptions in Test += "-ea", From 8f16b94afb39e1641c02d4e0be18d34ef7c211cc Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 4 Jun 2015 22:15:58 -0700 Subject: [PATCH 10/17] [SPARK-8114][SQL] Remove some wildcard import on TestSQLContext._ I kept some of the sql import there to avoid changing too many lines. Author: Reynold Xin Closes #6661 from rxin/remove-wildcard-import-sqlcontext and squashes the following commits: c265347 [Reynold Xin] Fixed ListTablesSuite failure. de9d491 [Reynold Xin] Fixed tests. 73b5365 [Reynold Xin] Mima. 8f6b642 [Reynold Xin] Fixed style violation. 443f6e8 [Reynold Xin] [SPARK-8113][SQL] Remove some wildcard import on TestSQLContext._ --- .../sql/catalyst/analysis/Analyzer.scala | 12 +- .../apache/spark/sql/CachedTableSuite.scala | 160 +++++++++--------- .../spark/sql/ColumnExpressionSuite.scala | 15 +- .../spark/sql/DataFrameAggregateSuite.scala | 9 +- .../spark/sql/DataFrameFunctionsSuite.scala | 4 +- .../spark/sql/DataFrameImplicitsSuite.scala | 15 +- .../apache/spark/sql/DataFrameJoinSuite.scala | 9 +- .../spark/sql/DataFrameNaFunctionsSuite.scala | 5 +- .../apache/spark/sql/DataFrameStatSuite.scala | 8 +- .../org/apache/spark/sql/DataFrameSuite.scala | 68 ++++---- .../org/apache/spark/sql/JoinSuite.scala | 65 +++---- .../apache/spark/sql/ListTablesSuite.scala | 35 ++-- .../spark/sql/MathExpressionsSuite.scala | 44 +++-- .../scala/org/apache/spark/sql/RowSuite.scala | 7 +- .../org/apache/spark/sql/SQLConfSuite.scala | 67 ++++---- .../apache/spark/sql/SQLContextSuite.scala | 16 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 125 +++++++------- .../sql/ScalaReflectionRelationSuite.scala | 31 ++-- .../apache/spark/sql/SerializationSuite.scala | 5 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 28 ++- .../spark/sql/UserDefinedTypeSuite.scala | 23 ++- 21 files changed, 373 insertions(+), 378 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index bc17169f35a46..5883d938b676d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -235,9 +235,8 @@ class Analyzer( } /** - * Replaces [[UnresolvedAttribute]]s with concrete - * [[catalyst.expressions.AttributeReference AttributeReferences]] from a logical plan node's - * children. + * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from + * a logical plan node's children. */ object ResolveReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { @@ -455,7 +454,7 @@ class Analyzer( } /** - * Replaces [[UnresolvedFunction]]s with concrete [[catalyst.expressions.Expression Expressions]]. + * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -846,9 +845,8 @@ class Analyzer( } /** - * Removes [[catalyst.plans.logical.Subquery Subquery]] operators from the plan. Subqueries are - * only required to provide scoping information for attributes and can be removed once analysis is - * complete. + * Removes [[Subquery]] operators from the plan. Subqueries are only required to provide + * scoping information for attributes and can be removed once analysis is complete. */ object EliminateSubQueries extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 0772e5e187425..72e60d9aa75cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -25,8 +25,6 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.storage.{RDDBlockId, StorageLevel} case class BigData(s: String) @@ -34,8 +32,12 @@ case class BigData(s: String) class CachedTableSuite extends QueryTest { TestData // Load test tables. + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + import ctx.sql + def rddIdOf(tableName: String): Int = { - val executedPlan = table(tableName).queryExecution.executedPlan + val executedPlan = ctx.table(tableName).queryExecution.executedPlan executedPlan.collect { case InMemoryColumnarTableScan(_, _, relation) => relation.cachedColumnBuffers.id @@ -45,47 +47,47 @@ class CachedTableSuite extends QueryTest { } def isMaterialized(rddId: Int): Boolean = { - sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty + ctx.sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty } test("cache temp table") { testData.select('key).registerTempTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0) - cacheTable("tempTable") + ctx.cacheTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - uncacheTable("tempTable") + ctx.uncacheTable("tempTable") } test("unpersist an uncached table will not raise exception") { - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) testData.persist() - assert(None != cacheManager.lookupCachedData(testData)) + assert(None != ctx.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) } test("cache table as select") { sql("CACHE TABLE tempTable AS SELECT key FROM testData") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - uncacheTable("tempTable") + ctx.uncacheTable("tempTable") } test("uncaching temp table") { testData.select('key).registerTempTable("tempTable1") testData.select('key).registerTempTable("tempTable2") - cacheTable("tempTable1") + ctx.cacheTable("tempTable1") assertCached(sql("SELECT COUNT(*) FROM tempTable1")) assertCached(sql("SELECT COUNT(*) FROM tempTable2")) // Is this valid? - uncacheTable("tempTable2") + ctx.uncacheTable("tempTable2") // Should this be cached? assertCached(sql("SELECT COUNT(*) FROM tempTable1"), 0) @@ -93,103 +95,103 @@ class CachedTableSuite extends QueryTest { test("too big for memory") { val data = "*" * 10000 - sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() + ctx.sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() .registerTempTable("bigData") - table("bigData").persist(StorageLevel.MEMORY_AND_DISK) - assert(table("bigData").count() === 200000L) - table("bigData").unpersist(blocking = true) + ctx.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) + assert(ctx.table("bigData").count() === 200000L) + ctx.table("bigData").unpersist(blocking = true) } test("calling .cache() should use in-memory columnar caching") { - table("testData").cache() - assertCached(table("testData")) - table("testData").unpersist(blocking = true) + ctx.table("testData").cache() + assertCached(ctx.table("testData")) + ctx.table("testData").unpersist(blocking = true) } test("calling .unpersist() should drop in-memory columnar cache") { - table("testData").cache() - table("testData").count() - table("testData").unpersist(blocking = true) - assertCached(table("testData"), 0) + ctx.table("testData").cache() + ctx.table("testData").count() + ctx.table("testData").unpersist(blocking = true) + assertCached(ctx.table("testData"), 0) } test("isCached") { - cacheTable("testData") + ctx.cacheTable("testData") - assertCached(table("testData")) - assert(table("testData").queryExecution.withCachedData match { + assertCached(ctx.table("testData")) + assert(ctx.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => true case _ => false }) - uncacheTable("testData") - assert(!isCached("testData")) - assert(table("testData").queryExecution.withCachedData match { + ctx.uncacheTable("testData") + assert(!ctx.isCached("testData")) + assert(ctx.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => false case _ => true }) } test("SPARK-1669: cacheTable should be idempotent") { - assume(!table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) + assume(!ctx.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) - cacheTable("testData") - assertCached(table("testData")) + ctx.cacheTable("testData") + assertCached(ctx.table("testData")) assertResult(1, "InMemoryRelation not found, testData should have been cached") { - table("testData").queryExecution.withCachedData.collect { + ctx.table("testData").queryExecution.withCachedData.collect { case r: InMemoryRelation => r }.size } - cacheTable("testData") + ctx.cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { - table("testData").queryExecution.withCachedData.collect { + ctx.table("testData").queryExecution.withCachedData.collect { case r @ InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan, _) => r }.size } - uncacheTable("testData") + ctx.uncacheTable("testData") } test("read from cached table and uncache") { - cacheTable("testData") - checkAnswer(table("testData"), testData.collect().toSeq) - assertCached(table("testData")) + ctx.cacheTable("testData") + checkAnswer(ctx.table("testData"), testData.collect().toSeq) + assertCached(ctx.table("testData")) - uncacheTable("testData") - checkAnswer(table("testData"), testData.collect().toSeq) - assertCached(table("testData"), 0) + ctx.uncacheTable("testData") + checkAnswer(ctx.table("testData"), testData.collect().toSeq) + assertCached(ctx.table("testData"), 0) } test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - uncacheTable("testData") + ctx.uncacheTable("testData") } } test("SELECT star from cached table") { sql("SELECT * FROM testData").registerTempTable("selectStar") - cacheTable("selectStar") + ctx.cacheTable("selectStar") checkAnswer( sql("SELECT * FROM selectStar WHERE key = 1"), Seq(Row(1, "1"))) - uncacheTable("selectStar") + ctx.uncacheTable("selectStar") } test("Self-join cached") { val unCachedAnswer = sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() - cacheTable("testData") + ctx.cacheTable("testData") checkAnswer( sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), unCachedAnswer.toSeq) - uncacheTable("testData") + ctx.uncacheTable("testData") } test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { sql("CACHE TABLE testData") - assertCached(table("testData")) + assertCached(ctx.table("testData")) val rddId = rddIdOf("testData") assert( @@ -197,7 +199,7 @@ class CachedTableSuite extends QueryTest { "Eagerly cached in-memory table should have already been materialized") sql("UNCACHE TABLE testData") - assert(!isCached("testData"), "Table 'testData' should not be cached") + assert(!ctx.isCached("testData"), "Table 'testData' should not be cached") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") @@ -206,14 +208,14 @@ class CachedTableSuite extends QueryTest { test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - assertCached(table("testCacheTable")) + assertCached(ctx.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - uncacheTable("testCacheTable") + ctx.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -221,14 +223,14 @@ class CachedTableSuite extends QueryTest { test("CACHE TABLE tableName AS SELECT ...") { sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") - assertCached(table("testCacheTable")) + assertCached(ctx.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - uncacheTable("testCacheTable") + ctx.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -236,7 +238,7 @@ class CachedTableSuite extends QueryTest { test("CACHE LAZY TABLE tableName") { sql("CACHE LAZY TABLE testData") - assertCached(table("testData")) + assertCached(ctx.table("testData")) val rddId = rddIdOf("testData") assert( @@ -248,7 +250,7 @@ class CachedTableSuite extends QueryTest { isMaterialized(rddId), "Lazily cached in-memory table should have been materialized") - uncacheTable("testData") + ctx.uncacheTable("testData") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -256,7 +258,7 @@ class CachedTableSuite extends QueryTest { test("InMemoryRelation statistics") { sql("CACHE TABLE testData") - table("testData").queryExecution.withCachedData.collect { + ctx.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum assert(cached.statistics.sizeInBytes === actualSizeInBytes) @@ -265,38 +267,38 @@ class CachedTableSuite extends QueryTest { test("Drops temporary table") { testData.select('key).registerTempTable("t1") - table("t1") - dropTempTable("t1") - assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found")) + ctx.table("t1") + ctx.dropTempTable("t1") + assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) } test("Drops cached temporary table") { testData.select('key).registerTempTable("t1") testData.select('key).registerTempTable("t2") - cacheTable("t1") + ctx.cacheTable("t1") - assert(isCached("t1")) - assert(isCached("t2")) + assert(ctx.isCached("t1")) + assert(ctx.isCached("t2")) - dropTempTable("t1") - assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found")) - assert(!isCached("t2")) + ctx.dropTempTable("t1") + assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) + assert(!ctx.isCached("t2")) } test("Clear all cache") { sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - cacheTable("t1") - cacheTable("t2") - clearCache() - assert(cacheManager.isEmpty) + ctx.cacheTable("t1") + ctx.cacheTable("t2") + ctx.clearCache() + assert(ctx.cacheManager.isEmpty) sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - cacheTable("t1") - cacheTable("t2") + ctx.cacheTable("t1") + ctx.cacheTable("t2") sql("Clear CACHE") - assert(cacheManager.isEmpty) + assert(ctx.cacheManager.isEmpty) } test("Clear accumulators when uncacheTable to prevent memory leaking") { @@ -305,8 +307,8 @@ class CachedTableSuite extends QueryTest { Accumulators.synchronized { val accsSize = Accumulators.originals.size - cacheTable("t1") - cacheTable("t2") + ctx.cacheTable("t1") + ctx.cacheTable("t2") assert((accsSize + 2) == Accumulators.originals.size) } @@ -317,8 +319,8 @@ class CachedTableSuite extends QueryTest { Accumulators.synchronized { val accsSize = Accumulators.originals.size - uncacheTable("t1") - uncacheTable("t2") + ctx.uncacheTable("t1") + ctx.uncacheTable("t2") assert((accsSize - 2) == Accumulators.originals.size) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index bfba379d9a518..4f5484f1368d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -21,13 +21,14 @@ import org.scalatest.Matchers._ import org.apache.spark.sql.execution.Project import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ class ColumnExpressionSuite extends QueryTest { import org.apache.spark.sql.TestData._ + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("alias") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") assert(df.select(df("a").as("b")).columns.head === "b") @@ -213,7 +214,7 @@ class ColumnExpressionSuite extends QueryTest { } test("!==") { - val nullData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( + val nullData = ctx.createDataFrame(ctx.sparkContext.parallelize( Row(1, 1) :: Row(1, 2) :: Row(1, null) :: @@ -274,7 +275,7 @@ class ColumnExpressionSuite extends QueryTest { } test("between") { - val testData = TestSQLContext.sparkContext.parallelize( + val testData = ctx.sparkContext.parallelize( (0, 1, 2) :: (1, 2, 3) :: (2, 1, 0) :: @@ -287,7 +288,7 @@ class ColumnExpressionSuite extends QueryTest { checkAnswer(testData.filter($"a".between($"b", $"c")), expectAnswer) } - val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( + val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize( Row(false, false) :: Row(false, true) :: Row(true, false) :: @@ -413,7 +414,7 @@ class ColumnExpressionSuite extends QueryTest { test("monotonicallyIncreasingId") { // Make sure we have 2 partitions, each with 2 records. - val df = TestSQLContext.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter => + val df = ctx.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter => Iterator(Tuple1(1), Tuple1(2)) }.toDF("a") checkAnswer( @@ -423,7 +424,7 @@ class ColumnExpressionSuite extends QueryTest { } test("sparkPartitionId") { - val df = TestSQLContext.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") + val df = ctx.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") checkAnswer( df.select(sparkPartitionId()), Row(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 232f05c00918f..790b405c72697 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types.DecimalType class DataFrameAggregateSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("groupBy") { checkAnswer( testData2.groupBy("a").agg(sum($"b")), @@ -67,12 +68,12 @@ class DataFrameAggregateSuite extends QueryTest { Seq(Row(1, 3), Row(2, 3), Row(3, 3)) ) - TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "false") + ctx.conf.setConf("spark.sql.retainGroupColumns", "false") checkAnswer( testData2.groupBy("a").agg(sum($"b")), Seq(Row(3), Row(3), Row(3)) ) - TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "true") + ctx.conf.setConf("spark.sql.retainGroupColumns", "true") } test("agg without groups") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index b1e0faa310b68..53c2befb73702 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ /** @@ -27,6 +26,9 @@ import org.apache.spark.sql.types._ */ class DataFrameFunctionsSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("array with column name") { val df = Seq((0, 1)).toDF("a", "b") val row = df.select(array("a", "b")).first() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index 2d2367d6e7292..fbb30706a4943 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.TestSQLContext.{sparkContext => sc} -import org.apache.spark.sql.test.TestSQLContext.implicits._ - - class DataFrameImplicitsSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("RDD of tuples") { checkAnswer( - sc.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), + ctx.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), (1 to 10).map(i => Row(i, i.toString))) } @@ -37,19 +36,19 @@ class DataFrameImplicitsSuite extends QueryTest { test("RDD[Int]") { checkAnswer( - sc.parallelize(1 to 10).toDF("intCol"), + ctx.sparkContext.parallelize(1 to 10).toDF("intCol"), (1 to 10).map(i => Row(i))) } test("RDD[Long]") { checkAnswer( - sc.parallelize(1L to 10L).toDF("longCol"), + ctx.sparkContext.parallelize(1L to 10L).toDF("longCol"), (1L to 10L).map(i => Row(i))) } test("RDD[String]") { checkAnswer( - sc.parallelize(1 to 10).map(_.toString).toDF("stringCol"), + ctx.sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"), (1 to 10).map(i => Row(i.toString))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 787f3f175fea2..051d13e9a544f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ - class DataFrameJoinSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("join - join using") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") val df2 = Seq(1, 2, 3).map(i => (i, (i + 1).toString)).toDF("int", "str") @@ -49,7 +49,8 @@ class DataFrameJoinSuite extends QueryTest { checkAnswer( df1.join(df2, $"df1.key" === $"df2.key"), - sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq) + ctx.sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") + .collect().toSeq) } test("join - using aliases after self join") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 41b4f02e6a294..495701d4f616c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql import scala.collection.JavaConversions._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ - class DataFrameNaFunctionsSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + def createDF(): DataFrame = { Seq[(String, java.lang.Integer, java.lang.Double)]( ("Bob", 16, 176.5), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 438f479459dfe..0d3ff899dad72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ class DataFrameStatSuite extends SparkFunSuite { - val sqlCtx = TestSQLContext - def toLetter(i: Int): String = (i + 97).toChar.toString + private val sqlCtx = org.apache.spark.sql.test.TestSQLContext + import sqlCtx.implicits._ + + private def toLetter(i: Int): String = (i + 97).toChar.toString test("pearson correlation") { val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 8e81dacb8660f..bb8621abe64ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -21,17 +21,19 @@ import scala.language.postfixOps import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext} -import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint} class DataFrameSuite extends QueryTest { import org.apache.spark.sql.TestData._ + lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("analysis error should be eagerly reported") { - val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis + val oldSetting = ctx.conf.dataFrameEagerAnalysis // Eager analysis. - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") intercept[Exception] { testData.select('nonExistentName) } intercept[Exception] { @@ -45,11 +47,11 @@ class DataFrameSuite extends QueryTest { } // No more eager analysis once the flag is turned off - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false") + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false") testData.select('nonExistentName) // Set the flag back to original value before this test. - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) } test("dataframe toString") { @@ -67,12 +69,12 @@ class DataFrameSuite extends QueryTest { } test("invalid plan toString, debug mode") { - val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") + val oldSetting = ctx.conf.dataFrameEagerAnalysis + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") // Turn on debug mode so we can see invalid query plans. import org.apache.spark.sql.execution.debug._ - TestSQLContext.debug() + ctx.debug() val badPlan = testData.select('badColumn) @@ -81,7 +83,7 @@ class DataFrameSuite extends QueryTest { badPlan.toString) // Set the flag back to original value before this test. - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) } test("access complex data") { @@ -97,8 +99,8 @@ class DataFrameSuite extends QueryTest { } test("empty data frame") { - assert(TestSQLContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) - assert(TestSQLContext.emptyDataFrame.count() === 0) + assert(ctx.emptyDataFrame.columns.toSeq === Seq.empty[String]) + assert(ctx.emptyDataFrame.count() === 0) } test("head and take") { @@ -311,7 +313,7 @@ class DataFrameSuite extends QueryTest { } test("replace column using withColumn") { - val df2 = TestSQLContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") + val df2 = ctx.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) checkAnswer( df3.select("x"), @@ -392,7 +394,7 @@ class DataFrameSuite extends QueryTest { test("randomSplit") { val n = 600 - val data = TestSQLContext.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") for (seed <- 1 to 5) { val splits = data.randomSplit(Array[Double](1, 2, 3), seed) assert(splits.length == 3, "wrong number of splits") @@ -487,21 +489,21 @@ class DataFrameSuite extends QueryTest { } test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { - val rowRDD = TestSQLContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) + val rowRDD = ctx.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) - val df = TestSQLContext.createDataFrame(rowRDD, schema) + val df = ctx.createDataFrame(rowRDD, schema) df.rdd.collect() } test("SPARK-6899") { - val originalValue = TestSQLContext.conf.codegenEnabled - TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, "true") + val originalValue = ctx.conf.codegenEnabled + ctx.setConf(SQLConf.CODEGEN_ENABLED, "true") try{ checkAnswer( decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) } finally { - TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + ctx.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } } @@ -513,14 +515,14 @@ class DataFrameSuite extends QueryTest { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = TestSQLContext.read.json(TestSQLContext.sparkContext.makeRDD( + val df = ctx.read.json(ctx.sparkContext.makeRDD( """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) checkAnswer( df.select(df("`a.b`.c.`d..e`.`f`")), Row(1) ) - val df2 = TestSQLContext.read.json(TestSQLContext.sparkContext.makeRDD( + val df2 = ctx.read.json(ctx.sparkContext.makeRDD( """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) checkAnswer( df2.select(df2("`a b`.c.d e.f")), @@ -540,7 +542,7 @@ class DataFrameSuite extends QueryTest { } test("SPARK-7324 dropDuplicates") { - val testData = TestSQLContext.sparkContext.parallelize( + val testData = ctx.sparkContext.parallelize( (2, 1, 2) :: (1, 1, 1) :: (1, 2, 1) :: (2, 1, 2) :: (2, 2, 2) :: (2, 2, 1) :: @@ -588,49 +590,49 @@ class DataFrameSuite extends QueryTest { test("SPARK-7150 range api") { // numSlice is greater than length - val res1 = TestSQLContext.range(0, 10, 1, 15).select("id") + val res1 = ctx.range(0, 10, 1, 15).select("id") assert(res1.count == 10) assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res2 = TestSQLContext.range(3, 15, 3, 2).select("id") + val res2 = ctx.range(3, 15, 3, 2).select("id") assert(res2.count == 4) assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) - val res3 = TestSQLContext.range(1, -2).select("id") + val res3 = ctx.range(1, -2).select("id") assert(res3.count == 0) // start is positive, end is negative, step is negative - val res4 = TestSQLContext.range(1, -2, -2, 6).select("id") + val res4 = ctx.range(1, -2, -2, 6).select("id") assert(res4.count == 2) assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0))) // start, end, step are negative - val res5 = TestSQLContext.range(-3, -8, -2, 1).select("id") + val res5 = ctx.range(-3, -8, -2, 1).select("id") assert(res5.count == 3) assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15))) // start, end are negative, step is positive - val res6 = TestSQLContext.range(-8, -4, 2, 1).select("id") + val res6 = ctx.range(-8, -4, 2, 1).select("id") assert(res6.count == 2) assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14))) - val res7 = TestSQLContext.range(-10, -9, -20, 1).select("id") + val res7 = ctx.range(-10, -9, -20, 1).select("id") assert(res7.count == 0) - val res8 = TestSQLContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") + val res8 = ctx.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") assert(res8.count == 3) assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) - val res9 = TestSQLContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") + val res9 = ctx.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") assert(res9.count == 2) assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) // only end provided as argument - val res10 = TestSQLContext.range(10).select("id") + val res10 = ctx.range(10).select("id") assert(res10.count == 10) assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res11 = TestSQLContext.range(-1).select("id") + val res11 = ctx.range(-1).select("id") assert(res11.count == 0) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 407c789657834..ffd26c4f5a7c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -20,27 +20,28 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.functions._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Ensures tables are loaded. TestData + lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + import ctx.logicalPlanToSparkQuery + test("equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = planner.HashJoin(join) + val planned = ctx.planner.HashJoin(join) assert(planned.size === 1) } def assertJoin(sqlString: String, c: Class[_]): Any = { - val df = sql(sqlString) + val df = ctx.sql(sqlString) val physical = df.queryExecution.sparkPlan val operators = physical.collect { case j: ShuffledHashJoin => j @@ -61,9 +62,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("join operator selection") { - cacheManager.clearCache() + ctx.cacheManager.clearCache() - val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled + val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), @@ -94,22 +95,22 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } try { - conf.setConf("spark.sql.planner.sortMergeJoin", "true") + ctx.conf.setConf("spark.sql.planner.sortMergeJoin", "true") Seq( ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } finally { - conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + ctx.conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) } } test("broadcasted hash join operator selection") { - cacheManager.clearCache() - sql("CACHE TABLE testData") + ctx.cacheManager.clearCache() + ctx.sql("CACHE TABLE testData") - val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled + val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), @@ -117,7 +118,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } try { - conf.setConf("spark.sql.planner.sortMergeJoin", "true") + ctx.conf.setConf("spark.sql.planner.sortMergeJoin", "true") Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", @@ -126,17 +127,17 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } finally { - conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + ctx.conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) } - sql("UNCACHE TABLE testData") + ctx.sql("UNCACHE TABLE testData") } test("multiple-key equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = planner.HashJoin(join) + val planned = ctx.planner.HashJoin(join) assert(planned.size === 1) } @@ -241,7 +242,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are choosing left.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - sql( + ctx.sql( """ |SELECT l.N, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -255,7 +256,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, 1) :: Nil) checkAnswer( - sql( + ctx.sql( """ |SELECT r.a, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -301,7 +302,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are choosing right.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - sql( + ctx.sql( """ |SELECT l.a, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -310,7 +311,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 6)) checkAnswer( - sql( + ctx.sql( """ |SELECT r.N, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -362,7 +363,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator. checkAnswer( - sql( + ctx.sql( """ |SELECT l.a, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -371,7 +372,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 10)) checkAnswer( - sql( + ctx.sql( """ |SELECT r.N, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -386,7 +387,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 4) :: Nil) checkAnswer( - sql( + ctx.sql( """ |SELECT l.N, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -401,7 +402,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 4) :: Nil) checkAnswer( - sql( + ctx.sql( """ |SELECT r.a, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -411,11 +412,11 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("broadcasted left semi join operator selection") { - cacheManager.clearCache() - sql("CACHE TABLE testData") - val tmp = conf.autoBroadcastJoinThreshold + ctx.cacheManager.clearCache() + ctx.sql("CACHE TABLE testData") + val tmp = ctx.conf.autoBroadcastJoinThreshold - sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000") + ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000") Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[BroadcastLeftSemiJoinHash]) @@ -423,7 +424,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case (query, joinClass) => assertJoin(query, joinClass) } - sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") + ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) @@ -431,12 +432,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case (query, joinClass) => assertJoin(query, joinClass) } - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString) - sql("UNCACHE TABLE testData") + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString) + ctx.sql("UNCACHE TABLE testData") } test("left semi join") { - val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") + val df = ctx.sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") checkAnswer(df, Row(1, 1) :: Row(1, 2) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 3ce97c3fffdb4..2089660c52bf7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -19,49 +19,47 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} class ListTablesSuite extends QueryTest with BeforeAndAfter { - import org.apache.spark.sql.test.TestSQLContext.implicits._ + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ - val df = - sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") + private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value") before { df.registerTempTable("ListTablesSuiteTable") } after { - catalog.unregisterTable(Seq("ListTablesSuiteTable")) + ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) } test("get all tables") { checkAnswer( - tables().filter("tableName = 'ListTablesSuiteTable'"), + ctx.tables().filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( - sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), + ctx.sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - catalog.unregisterTable(Seq("ListTablesSuiteTable")) - assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } test("getting all Tables with a database name has no impact on returned table names") { checkAnswer( - tables("DB").filter("tableName = 'ListTablesSuiteTable'"), + ctx.tables("DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( - sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), + ctx.sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - catalog.unregisterTable(Seq("ListTablesSuiteTable")) - assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } test("query the returned DataFrame of tables") { @@ -69,19 +67,20 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) - Seq(tables(), sql("SHOW TABLes")).foreach { + Seq(ctx.tables(), ctx.sql("SHOW TABLes")).foreach { case tableDF => assert(expectedSchema === tableDF.schema) tableDF.registerTempTable("tables") checkAnswer( - sql("SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), + ctx.sql( + "SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), Row(true, "ListTablesSuiteTable") ) checkAnswer( - tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), + ctx.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), Row("tables", true)) - dropTempTable("tables") + ctx.dropTempTable("tables") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index dd68965444f5d..0a38af2b4c889 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -17,36 +17,29 @@ package org.apache.spark.sql -import java.lang.{Double => JavaDouble} - import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ - -private[this] object MathExpressionsTestData { - - case class DoubleData(a: JavaDouble, b: JavaDouble) - val doubleData = TestSQLContext.sparkContext.parallelize( - (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1))).toDF() - - val nnDoubleData = TestSQLContext.sparkContext.parallelize( - (1 to 10).map(i => DoubleData(i * 0.1, i * -0.1))).toDF() - - case class NullDoubles(a: JavaDouble) - val nullDoubles = - TestSQLContext.sparkContext.parallelize( - NullDoubles(1.0) :: - NullDoubles(2.0) :: - NullDoubles(3.0) :: - NullDoubles(null) :: Nil - ).toDF() + + +private object MathExpressionsTestData { + case class DoubleData(a: java.lang.Double, b: java.lang.Double) + case class NullDoubles(a: java.lang.Double) } class MathExpressionsSuite extends QueryTest { import MathExpressionsTestData._ - def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + private lazy val doubleData = (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1)).toDF() + + private lazy val nnDoubleData = (1 to 10).map(i => DoubleData(i * 0.1, i * -0.1)).toDF() + + private lazy val nullDoubles = + Seq(NullDoubles(1.0), NullDoubles(2.0), NullDoubles(3.0), NullDoubles(null)).toDF() + + private def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( c: Column => Column, f: T => T): Unit = { checkAnswer( @@ -65,7 +58,8 @@ class MathExpressionsSuite extends QueryTest { ) } - def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = { + private def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = + { checkAnswer( nnDoubleData.select(c('a)), (1 to 10).map(n => Row(f(n * 0.1))) @@ -89,7 +83,7 @@ class MathExpressionsSuite extends QueryTest { ) } - def testTwoToOneMathFunction( + private def testTwoToOneMathFunction( c: (Column, Column) => Column, d: (Column, Double) => Column, f: (Double, Double) => Double): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 513ac915dcb2a..d84b57af9c882 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -21,12 +21,13 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ class RowSuite extends SparkFunSuite { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("create row") { val expected = new GenericMutableRow(4) expected.update(0, 2147483647) @@ -56,7 +57,7 @@ class RowSuite extends SparkFunSuite { test("serialize w/ kryo") { val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first() - val serializer = new SparkSqlSerializer(TestSQLContext.sparkContext.getConf) + val serializer = new SparkSqlSerializer(ctx.sparkContext.getConf) val instance = serializer.newInstance() val ser = instance.serialize(row) val de = instance.deserialize(ser).asInstanceOf[Row] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 3a5f071e2f7cb..76d0dd1744a41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,67 +17,64 @@ package org.apache.spark.sql -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test._ - -/* Implicits */ -import TestSQLContext._ class SQLConfSuite extends QueryTest { - val testKey = "test.key.0" - val testVal = "test.val.0" + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + + private val testKey = "test.key.0" + private val testVal = "test.val.0" test("propagate from spark conf") { // We create a new context here to avoid order dependence with other tests that might call // clear(). - val newContext = new SQLContext(TestSQLContext.sparkContext) - assert(newContext.getConf("spark.sql.testkey", "false") == "true") + val newContext = new SQLContext(ctx.sparkContext) + assert(newContext.getConf("spark.sql.testkey", "false") === "true") } test("programmatic ways of basic setting and getting") { - conf.clear() - assert(getAllConfs.size === 0) + ctx.conf.clear() + assert(ctx.getAllConfs.size === 0) - setConf(testKey, testVal) - assert(getConf(testKey) == testVal) - assert(getConf(testKey, testVal + "_") == testVal) - assert(getAllConfs.contains(testKey)) + ctx.setConf(testKey, testVal) + assert(ctx.getConf(testKey) === testVal) + assert(ctx.getConf(testKey, testVal + "_") === testVal) + assert(ctx.getAllConfs.contains(testKey)) // Tests SQLConf as accessed from a SQLContext is mutable after // the latter is initialized, unlike SparkConf inside a SparkContext. - assert(TestSQLContext.getConf(testKey) == testVal) - assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal) - assert(TestSQLContext.getAllConfs.contains(testKey)) + assert(ctx.getConf(testKey) == testVal) + assert(ctx.getConf(testKey, testVal + "_") === testVal) + assert(ctx.getAllConfs.contains(testKey)) - conf.clear() + ctx.conf.clear() } test("parse SQL set commands") { - conf.clear() - sql(s"set $testKey=$testVal") - assert(getConf(testKey, testVal + "_") == testVal) - assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal) + ctx.conf.clear() + ctx.sql(s"set $testKey=$testVal") + assert(ctx.getConf(testKey, testVal + "_") === testVal) + assert(ctx.getConf(testKey, testVal + "_") === testVal) - sql("set some.property=20") - assert(getConf("some.property", "0") == "20") - sql("set some.property = 40") - assert(getConf("some.property", "0") == "40") + ctx.sql("set some.property=20") + assert(ctx.getConf("some.property", "0") === "20") + ctx.sql("set some.property = 40") + assert(ctx.getConf("some.property", "0") === "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" - sql(s"set $key=$vs") - assert(getConf(key, "0") == vs) + ctx.sql(s"set $key=$vs") + assert(ctx.getConf(key, "0") === vs) - sql(s"set $key=") - assert(getConf(key, "0") == "") + ctx.sql(s"set $key=") + assert(ctx.getConf(key, "0") === "") - conf.clear() + ctx.conf.clear() } test("deprecated property") { - conf.clear() - sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") - assert(getConf(SQLConf.SHUFFLE_PARTITIONS) == "10") + ctx.conf.clear() + ctx.sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") + assert(ctx.getConf(SQLConf.SHUFFLE_PARTITIONS) === "10") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index 797d123b48668..c8d8796568a41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -20,31 +20,29 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.TestSQLContext class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll { - private val testSqlContext = TestSQLContext - private val testSparkContext = TestSQLContext.sparkContext + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext override def afterAll(): Unit = { - SQLContext.setLastInstantiatedContext(testSqlContext) + SQLContext.setLastInstantiatedContext(ctx) } test("getOrCreate instantiates SQLContext") { SQLContext.clearLastInstantiatedContext() - val sqlContext = SQLContext.getOrCreate(testSparkContext) + val sqlContext = SQLContext.getOrCreate(ctx.sparkContext) assert(sqlContext != null, "SQLContext.getOrCreate returned null") - assert(SQLContext.getOrCreate(testSparkContext).eq(sqlContext), + assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), "SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate") } test("getOrCreate gets last explicitly instantiated SQLContext") { SQLContext.clearLastInstantiatedContext() - val sqlContext = new SQLContext(testSparkContext) - assert(SQLContext.getOrCreate(testSparkContext) != null, + val sqlContext = new SQLContext(ctx.sparkContext) + assert(SQLContext.getOrCreate(ctx.sparkContext) != null, "SQLContext.getOrCreate after explicitly created SQLContext returned null") - assert(SQLContext.getOrCreate(testSparkContext).eq(sqlContext), + assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), "SQLContext.getOrCreate after explicitly created SQLContext did not return the context") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 55b68d8e2283c..5babc4332cc77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -24,9 +24,7 @@ import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} -import org.apache.spark.sql.test.TestSQLContext.{udf => _, _} - +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ @@ -36,8 +34,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { // Make sure the tables are loaded. TestData - val sqlContext = TestSQLContext + val sqlContext = org.apache.spark.sql.test.TestSQLContext import sqlContext.implicits._ + import sqlContext.sql test("SPARK-6743: no columns from cache") { Seq( @@ -46,7 +45,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { (43, 81, 24) ).toDF("a", "b", "c").registerTempTable("cachedData") - cacheTable("cachedData") + sqlContext.cacheTable("cachedData") checkAnswer( sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"), Row(0) :: Row(81) :: Nil) @@ -94,14 +93,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SQL Dialect Switching to a new SQL parser") { - val newContext = new SQLContext(TestSQLContext.sparkContext) + val newContext = new SQLContext(sqlContext.sparkContext) newContext.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) assert(newContext.getSQLDialect().getClass === classOf[MyDialect]) assert(newContext.sql("SELECT 1").collect() === Array(Row(1))) } test("SQL Dialect Switch to an invalid parser with alias") { - val newContext = new SQLContext(TestSQLContext.sparkContext) + val newContext = new SQLContext(sqlContext.sparkContext) newContext.sql("SET spark.sql.dialect=MyTestClass") intercept[DialectException] { newContext.sql("SELECT 1") @@ -118,7 +117,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("grouping on nested fields") { - read.json(sparkContext.parallelize("""{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) + sqlContext.read.json(sqlContext.sparkContext.parallelize( + """{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) .registerTempTable("rows") checkAnswer( @@ -135,8 +135,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-6201 IN type conversion") { - read.json( - sparkContext.parallelize(Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) + sqlContext.read.json( + sqlContext.sparkContext.parallelize( + Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) .registerTempTable("d") checkAnswer( @@ -157,12 +158,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("aggregation with codegen") { - val originalValue = conf.codegenEnabled - setConf(SQLConf.CODEGEN_ENABLED, "true") + val originalValue = sqlContext.conf.codegenEnabled + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true") // Prepare a table that we can group some rows. - table("testData") - .unionAll(table("testData")) - .unionAll(table("testData")) + sqlContext.table("testData") + .unionAll(sqlContext.table("testData")) + .unionAll(sqlContext.table("testData")) .registerTempTable("testData3x") def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { @@ -254,8 +255,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { "SELECT sum('a'), avg('a'), count(null) FROM testData", Row(0, null, 0) :: Nil) } finally { - dropTempTable("testData3x") - setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + sqlContext.dropTempTable("testData3x") + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } } @@ -447,42 +448,42 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("sorting") { - val before = conf.externalSortEnabled - setConf(SQLConf.EXTERNAL_SORT, "false") + val before = sqlContext.conf.externalSortEnabled + sqlContext.setConf(SQLConf.EXTERNAL_SORT, "false") sortTest() - setConf(SQLConf.EXTERNAL_SORT, before.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, before.toString) } test("external sorting") { - val before = conf.externalSortEnabled - setConf(SQLConf.EXTERNAL_SORT, "true") + val before = sqlContext.conf.externalSortEnabled + sqlContext.setConf(SQLConf.EXTERNAL_SORT, "true") sortTest() - setConf(SQLConf.EXTERNAL_SORT, before.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, before.toString) } test("SPARK-6927 sorting with codegen on") { - val externalbefore = conf.externalSortEnabled - val codegenbefore = conf.codegenEnabled - setConf(SQLConf.EXTERNAL_SORT, "false") - setConf(SQLConf.CODEGEN_ENABLED, "true") + val externalbefore = sqlContext.conf.externalSortEnabled + val codegenbefore = sqlContext.conf.codegenEnabled + sqlContext.setConf(SQLConf.EXTERNAL_SORT, "false") + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true") try{ sortTest() } finally { - setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) - setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) } } test("SPARK-6927 external sorting with codegen on") { - val externalbefore = conf.externalSortEnabled - val codegenbefore = conf.codegenEnabled - setConf(SQLConf.CODEGEN_ENABLED, "true") - setConf(SQLConf.EXTERNAL_SORT, "true") + val externalbefore = sqlContext.conf.externalSortEnabled + val codegenbefore = sqlContext.conf.codegenEnabled + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true") + sqlContext.setConf(SQLConf.EXTERNAL_SORT, "true") try { sortTest() } finally { - setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) - setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) } } @@ -516,7 +517,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("Allow only a single WITH clause per query") { intercept[RuntimeException] { - sql("with q1 as (select * from testData) with q2 as (select * from q1) select * from q2") + sql( + "with q1 as (select * from testData) with q2 as (select * from q1) select * from q2") } } @@ -863,7 +865,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SET commands semantics using sql()") { - conf.clear() + sqlContext.conf.clear() val testKey = "test.key.0" val testVal = "test.val.0" val nonexistentKey = "nonexistent" @@ -895,17 +897,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { sql(s"SET $nonexistentKey"), Row(s"$nonexistentKey=") ) - conf.clear() + sqlContext.conf.clear() } test("SET commands with illegal or inappropriate argument") { - conf.clear() + sqlContext.conf.clear() // Set negative mapred.reduce.tasks for automatically determing // the number of reducers is not supported intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-1")) intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-01")) intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-2")) - conf.clear() + sqlContext.conf.clear() } test("apply schema") { @@ -923,7 +925,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(values(0).toInt, values(1), values(2).toBoolean, v4) } - val df1 = createDataFrame(rowRDD1, schema1) + val df1 = sqlContext.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), @@ -953,7 +955,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df2 = createDataFrame(rowRDD2, schema2) + val df2 = sqlContext.createDataFrame(rowRDD2, schema2) df2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), @@ -978,7 +980,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) } - val df3 = createDataFrame(rowRDD3, schema2) + val df3 = sqlContext.createDataFrame(rowRDD3, schema2) df3.registerTempTable("applySchema3") checkAnswer( @@ -1023,7 +1025,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { .build() val schemaWithMeta = new StructType(Array( schema("id"), schema("name").copy(metadata = metadata), schema("age"))) - val personWithMeta = createDataFrame(person.rdd, schemaWithMeta) + val personWithMeta = sqlContext.createDataFrame(person.rdd, schemaWithMeta) def validateMetadata(rdd: DataFrame): Unit = { assert(rdd.schema("name").metadata.getString(docKey) == docValue) } @@ -1038,7 +1040,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-3371 Renaming a function expression with group by gives error") { - TestSQLContext.udf.register("len", (s: String) => s.length) + sqlContext.udf.register("len", (s: String) => s.length) checkAnswer( sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), Row(1)) @@ -1219,9 +1221,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-3483 Special chars in column names") { - val data = sparkContext.parallelize( + val data = sqlContext.sparkContext.parallelize( Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) - read.json(data).registerTempTable("records") + sqlContext.read.json(data).registerTempTable("records") sql("SELECT `key?number1`, `key.number2` FROM records") } @@ -1262,13 +1264,15 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-4322 Grouping field with struct field as sub expression") { - read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data") + sqlContext.read.json(sqlContext.sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) + .registerTempTable("data") checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) - dropTempTable("data") + sqlContext.dropTempTable("data") - read.json(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") + sqlContext.read.json( + sqlContext.sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) - dropTempTable("data") + sqlContext.dropTempTable("data") } test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") { @@ -1287,10 +1291,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("Supporting relational operator '<=>' in Spark SQL") { val nullCheckData1 = TestData(1, "1") :: TestData(2, null) :: Nil - val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) + val rdd1 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) rdd1.toDF().registerTempTable("nulldata1") val nullCheckData2 = TestData(1, "1") :: TestData(2, null) :: Nil - val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) + val rdd2 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) rdd2.toDF().registerTempTable("nulldata2") checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + "nulldata2 on nulldata1.value <=> nulldata2.value"), @@ -1299,22 +1303,23 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("Multi-column COUNT(DISTINCT ...)") { val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil - val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) + val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("distinctData") checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) } test("SPARK-4699 case sensitivity SQL query") { - setConf(SQLConf.CASE_SENSITIVE, "false") + sqlContext.setConf(SQLConf.CASE_SENSITIVE, "false") val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil - val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) + val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("testTable1") checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1")) - setConf(SQLConf.CASE_SENSITIVE, "true") + sqlContext.setConf(SQLConf.CASE_SENSITIVE, "true") } test("SPARK-6145: ORDER BY test for nested fields") { - read.json(sparkContext.makeRDD("""{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) + sqlContext.read.json(sqlContext.sparkContext.makeRDD( + """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) .registerTempTable("nestedOrder") checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) @@ -1326,14 +1331,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-6145: special cases") { - read.json(sparkContext.makeRDD( + sqlContext.read.json(sqlContext.sparkContext.makeRDD( """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t") checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1)) checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1)) } test("SPARK-6898: complete support for special chars in column names") { - read.json(sparkContext.makeRDD( + sqlContext.read.json(sqlContext.sparkContext.makeRDD( """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index d2ede39f0a5f6..ece3d6fdf2af5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -21,7 +21,6 @@ import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.test.TestSQLContext._ case class ReflectData( stringField: String, @@ -75,15 +74,15 @@ case class ComplexReflectData( class ScalaReflectionRelationSuite extends SparkFunSuite { - import org.apache.spark.sql.test.TestSQLContext.implicits._ + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1, 2, 3)) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.toDF().registerTempTable("reflectData") + Seq(data).toDF().registerTempTable("reflectData") - assert(sql("SELECT * FROM reflectData").collect().head === + assert(ctx.sql("SELECT * FROM reflectData").collect().head === Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3))) @@ -91,27 +90,26 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { test("query case class RDD with nulls") { val data = NullReflectData(null, null, null, null, null, null, null) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.toDF().registerTempTable("reflectNullData") + Seq(data).toDF().registerTempTable("reflectNullData") - assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null))) + assert(ctx.sql("SELECT * FROM reflectNullData").collect().head === + Row.fromSeq(Seq.fill(7)(null))) } test("query case class RDD with Nones") { val data = OptionalReflectData(None, None, None, None, None, None, None) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.toDF().registerTempTable("reflectOptionalData") + Seq(data).toDF().registerTempTable("reflectOptionalData") - assert(sql("SELECT * FROM reflectOptionalData").collect().head === + assert(ctx.sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } // Equality is broken for Arrays, so we test that separately. test("query binary data") { - val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil) - rdd.toDF().registerTempTable("reflectBinary") + Seq(ReflectBinary(Array[Byte](1))).toDF().registerTempTable("reflectBinary") - val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]] + val result = ctx.sql("SELECT data FROM reflectBinary") + .collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) } @@ -127,10 +125,9 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { Map(10 -> 100L, 20 -> 200L), Map(10 -> Some(100L), 20 -> Some(200L), 30 -> None), Nested(None, "abc"))) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.toDF().registerTempTable("reflectComplexData") - assert(sql("SELECT * FROM reflectComplexData").collect().head === + Seq(data).toDF().registerTempTable("reflectComplexData") + assert(ctx.sql("SELECT * FROM reflectComplexData").collect().head === new GenericRow(Array[Any]( Seq(1, 2, 3), Seq(1, 2, null), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index 1e8cde606b67b..e55c9e460b791 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.sql.test.TestSQLContext class SerializationSuite extends SparkFunSuite { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + test("[SPARK-5235] SQLContext should be serializable") { - val sqlContext = new SQLContext(TestSQLContext.sparkContext) + val sqlContext = new SQLContext(ctx.sparkContext) new JavaSerializer(new SparkConf()).newInstance().serialize(sqlContext) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 1a9ba66416b21..064c040d2b771 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,43 +17,41 @@ package org.apache.spark.sql -import org.apache.spark.sql.test._ - -/* Implicits */ -import TestSQLContext._ -import TestSQLContext.implicits._ case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("Simple UDF") { - udf.register("strLenScala", (_: String).length) - assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) + ctx.udf.register("strLenScala", (_: String).length) + assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { - udf.register("random0", () => { Math.random()}) - assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) + ctx.udf.register("random0", () => { Math.random()}) + assert(ctx.sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { - udf.register("strLenScala", (_: String).length + (_: Int)) - assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) + ctx.udf.register("strLenScala", (_: String).length + (_: Int)) + assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } test("struct UDF") { - udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) + ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) val result = - sql("SELECT returnStruct('test', 'test2') as ret") + ctx.sql("SELECT returnStruct('test', 'test2') as ret") .select($"ret.f1").head().getString(0) assert(result === "test") } test("udf that is transformed") { - udf.register("makeStruct", (x: Int, y: Int) => (x, y)) + ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) // 1 + 1 is constant folded causing a transformation. - assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) + assert(ctx.sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index dc2d43a197f40..45c9f06941c10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql -import java.io.File - -import org.apache.spark.util.Utils - import scala.beans.{BeanInfo, BeanProperty} import com.clearspring.analytics.stream.cardinality.HyperLogLog @@ -28,12 +24,11 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql} -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashSet + @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { override def equals(other: Any): Boolean = other match { @@ -72,11 +67,13 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { } class UserDefinedTypeSuite extends QueryTest { - val points = Seq( - MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), - MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) - val pointsRDD = sparkContext.parallelize(points).toDF() + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + private lazy val pointsRDD = Seq( + MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), + MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))).toDF() test("register user type: MyDenseVector for MyLabeledPoint") { val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v } @@ -94,10 +91,10 @@ class UserDefinedTypeSuite extends QueryTest { } test("UDTs and UDFs") { - TestSQLContext.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) pointsRDD.registerTempTable("points") checkAnswer( - sql("SELECT testType(features) from points"), + ctx.sql("SELECT testType(features) from points"), Seq(Row(true), Row(true))) } From e5054605994b8777e629c02fcbf8a5a6cbd0b0fe Mon Sep 17 00:00:00 2001 From: Ted Blackman Date: Thu, 4 Jun 2015 22:21:11 -0700 Subject: [PATCH 11/17] [SPARK-8116][PYSPARK] Allow sc.range() to take a single argument. Author: Ted Blackman Closes #6656 from belisarius222/branch-1.4 and squashes the following commits: 747cbc2 [Ted Blackman] [SPARK-8116][PYSPARK] Allow sc.range() to take a single argument. (cherry picked from commit f02af7c8f7f43e4cfe3c412d2b5ea4128669ce22) Signed-off-by: Reynold Xin --- python/pyspark/context.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index aeb7ad4f2f83e..44d90f1437bc9 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -324,10 +324,12 @@ def stop(self): with SparkContext._lock: SparkContext._active_spark_context = None - def range(self, start, end, step=1, numSlices=None): + def range(self, start, end=None, step=1, numSlices=None): """ Create a new RDD of int containing elements from `start` to `end` - (exclusive), increased by `step` every element. + (exclusive), increased by `step` every element. Can be called the same + way as python's built-in range() function. If called with a single argument, + the argument is interpreted as `end`, and `start` is set to 0. :param start: the start value :param end: the end value (exclusive) @@ -335,9 +337,17 @@ def range(self, start, end, step=1, numSlices=None): :param numSlices: the number of partitions of the new RDD :return: An RDD of int + >>> sc.range(5).collect() + [0, 1, 2, 3, 4] + >>> sc.range(2, 4).collect() + [2, 3] >>> sc.range(1, 7, 2).collect() [1, 3, 5] """ + if end is None: + end = start + start = 0 + return self.parallelize(xrange(start, end, step), numSlices) def parallelize(self, c, numSlices=None): From 2777ed3948d26b14e342ba161e145009e31b8829 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 5 Jun 2015 07:45:25 +0200 Subject: [PATCH 12/17] [DOC][Minor]Specify the common sources available for collecting I was wondering what else common sources available until search the source code. Maybe better to make this clear. Author: Yijie Shen Closes #6641 from yijieshen/patch-1 and squashes the following commits: b5b99b4 [Yijie Shen] Make it clear that JvmSource is the only available additional source currently f23140c [Yijie Shen] [DOC][Minor]Specify the common sources available for collecting --- conf/metrics.properties.template | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index 7de0011a48ca8..7f17bc7eea4f5 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -4,7 +4,7 @@ # divided into instances which correspond to internal components. # Each instance can be configured to report its metrics to one or more sinks. # Accepted values for [instance] are "master", "worker", "executor", "driver", -# and "applications". A wild card "*" can be used as an instance name, in +# and "applications". A wildcard "*" can be used as an instance name, in # which case all instances will inherit the supplied property. # # Within an instance, a "source" specifies a particular set of grouped metrics. @@ -32,7 +32,7 @@ # name (see examples below). # 2. Some sinks involve a polling period. The minimum allowed polling period # is 1 second. -# 3. Wild card properties can be overridden by more specific properties. +# 3. Wildcard properties can be overridden by more specific properties. # For example, master.sink.console.period takes precedence over # *.sink.console.period. # 4. A metrics specific configuration @@ -47,6 +47,13 @@ # instance master and applications. MetricsServlet may not be configured by self. # +## List of available common sources and their properties. + +# org.apache.spark.metrics.source.JvmSource +# Note: Currently, JvmSource is the only available common source +# to add additionaly to an instance, to enable this, +# set the "class" option to its fully qulified class name (see examples below) + ## List of available sinks and their properties. # org.apache.spark.metrics.sink.ConsoleSink From 3a5c4da473a8a497004dfe6eacc0e6646651b227 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 5 Jun 2015 00:32:46 -0700 Subject: [PATCH 13/17] [MINOR] remove unused interpolation var in log message Completely trivial but I noticed this wrinkle in a log message today; `$sender` doesn't refer to anything and isn't interpolated here. Author: Sean Owen Closes #6650 from srowen/Interpolation and squashes the following commits: 518687a [Sean Owen] Actually interpolate log string 7edb866 [Sean Owen] Trivial: remove unused interpolation var in log message --- .../spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index fcad959540f5a..7c7f70d8a193b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -103,7 +103,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case None => // Ignoring the update since we don't know about the executor. logWarning(s"Ignored task status update ($taskId state $state) " + - "from unknown executor $sender with ID $executorId") + s"from unknown executor with ID $executorId") } } From da20c8ca37663738112b04657057858ee3e55072 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 5 Jun 2015 10:32:33 +0200 Subject: [PATCH 14/17] [MINOR] [BUILD] Change link to jenkins builds on github. Link to the tail of the console log, instead of the full log. That's bound to have the info the user is looking for, and at the same time loads way more quickly than the (huge) full log, which is just one click away if needed. Author: Marcelo Vanzin Closes #6664 from vanzin/jenkins-link and squashes the following commits: ba07ed8 [Marcelo Vanzin] [minor] [build] Change link to jenkins builds on github. --- dev/run-tests-jenkins | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 3cbd8666c8d68..641b0ff3c4be4 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -193,7 +193,7 @@ done test_result="$?" if [ "$test_result" -eq "124" ]; then - fail_message="**[Test build ${BUILD_DISPLAY_NAME} timed out](${BUILD_URL}consoleFull)** \ + fail_message="**[Test build ${BUILD_DISPLAY_NAME} timed out](${BUILD_URL}console)** \ for PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL}) \ after a configured wait of \`${TESTS_TIMEOUT}\`." @@ -233,7 +233,7 @@ done # post end message { result_message="\ - [Test build ${BUILD_DISPLAY_NAME} has finished](${BUILD_URL}consoleFull) for \ + [Test build ${BUILD_DISPLAY_NAME} has finished](${BUILD_URL}console) for \ PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." result_message="${result_message}\n${test_result_note}" From b16b5434ff44c42e4b3a337f9af147669ba44896 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 5 Jun 2015 14:11:38 +0200 Subject: [PATCH 15/17] [MINOR] [BUILD] Use custom temp directory during build. Even with all the efforts to cleanup the temp directories created by unit tests, Spark leaves a lot of garbage in /tmp after a test run. This change overrides java.io.tmpdir to place those files under the build directory instead. After an sbt full unit test run, I was left with > 400 MB of temp files. Since they're now under the build dir, it's much easier to clean them up. Also make a slight change to a unit test to make it not pollute the source directory with test data. Author: Marcelo Vanzin Closes #6653 from vanzin/unit-test-tmp and squashes the following commits: 31e2dd5 [Marcelo Vanzin] Fix tests that depend on each other. aa92944 [Marcelo Vanzin] [minor] [build] Use custom temp directory during build. --- .../spark/deploy/SparkSubmitUtilsSuite.scala | 22 ++++++++++--------- pom.xml | 4 +++- project/SparkBuild.scala | 1 + 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 8fda5c8b472c9..07d261cc428c4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -28,9 +28,12 @@ import org.apache.ivy.plugins.resolver.IBiblioResolver import org.apache.spark.SparkFunSuite import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate +import org.apache.spark.util.Utils class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { + private var tempIvyPath: String = _ + private val noOpOutputStream = new OutputStream { def write(b: Int) = {} } @@ -47,6 +50,7 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { super.beforeAll() // We don't want to write logs during testing SparkSubmitUtils.printStream = new BufferPrintStream + tempIvyPath = Utils.createTempDir(namePrefix = "ivy").getAbsolutePath() } test("incorrect maven coordinate throws error") { @@ -90,21 +94,20 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { } test("ivy path works correctly") { - val ivyPath = "dummy" + File.separator + "ivy" val md = SparkSubmitUtils.getModuleDescriptor val artifacts = for (i <- 0 until 3) yield new MDArtifact(md, s"jar-$i", "jar", "jar") - var jPaths = SparkSubmitUtils.resolveDependencyPaths(artifacts.toArray, new File(ivyPath)) + var jPaths = SparkSubmitUtils.resolveDependencyPaths(artifacts.toArray, new File(tempIvyPath)) for (i <- 0 until 3) { - val index = jPaths.indexOf(ivyPath) + val index = jPaths.indexOf(tempIvyPath) assert(index >= 0) - jPaths = jPaths.substring(index + ivyPath.length) + jPaths = jPaths.substring(index + tempIvyPath.length) } val main = MavenCoordinate("my.awesome.lib", "mylib", "0.1") IvyTestUtils.withRepository(main, None, None) { repo => // end to end val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, Option(repo), - Option(ivyPath), true) - assert(jarPath.indexOf(ivyPath) >= 0, "should use non-default ivy path") + Option(tempIvyPath), true) + assert(jarPath.indexOf(tempIvyPath) >= 0, "should use non-default ivy path") } } @@ -123,13 +126,12 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(jarPath.indexOf("mylib") >= 0, "should find artifact") } // Local ivy repository with modified home - val dummyIvyPath = "dummy" + File.separator + "ivy" - val dummyIvyLocal = new File(dummyIvyPath, "local" + File.separator) + val dummyIvyLocal = new File(tempIvyPath, "local" + File.separator) IvyTestUtils.withRepository(main, None, Some(dummyIvyLocal), true) { repo => val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, - Some(dummyIvyPath), true) + Some(tempIvyPath), true) assert(jarPath.indexOf("mylib") >= 0, "should find artifact") - assert(jarPath.indexOf(dummyIvyPath) >= 0, "should be in new ivy path") + assert(jarPath.indexOf(tempIvyPath) >= 0, "should be in new ivy path") } } diff --git a/pom.xml b/pom.xml index e28d4b9fc2b17..a848deffe7375 100644 --- a/pom.xml +++ b/pom.xml @@ -179,7 +179,7 @@ compile ${session.executionRootDirectory} @@ -1256,6 +1256,7 @@ test true + ${project.build.directory}/tmp ${spark.test.home} 1 false @@ -1289,6 +1290,7 @@ test true + ${project.build.directory}/tmp ${spark.test.home} 1 false diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index ef3a175bac209..921f1599fedef 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -496,6 +496,7 @@ object TestSettings { "SPARK_DIST_CLASSPATH" -> (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"), "JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home"))), + javaOptions in Test += s"-Djava.io.tmpdir=$sparkHome/target/tmp", javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.port.maxRetries=100", From 019dc9f558cf7c0b708d3b1f0882b0c19134ffb6 Mon Sep 17 00:00:00 2001 From: Akhil Das Date: Fri, 5 Jun 2015 14:23:23 +0200 Subject: [PATCH 16/17] [STREAMING] Update streaming-kafka-integration.md Fixed the broken links (Examples) in the documentation. Author: Akhil Das Closes #6666 from akhld/patch-2 and squashes the following commits: 2228b83 [Akhil Das] Update streaming-kafka-integration.md --- docs/streaming-kafka-integration.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 64714f0b799fc..d6d5605948a5a 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -29,7 +29,7 @@ Next, we discuss how to use this approach in your streaming application. [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]) You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala).
import org.apache.spark.streaming.kafka.*; @@ -39,7 +39,7 @@ Next, we discuss how to use this approach in your streaming application. [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]); You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java).
@@ -105,7 +105,7 @@ Next, we discuss how to use this approach in your streaming application. streamingContext, [map of Kafka parameters], [set of topics to consume]) See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala).
import org.apache.spark.streaming.kafka.*; @@ -116,7 +116,7 @@ Next, we discuss how to use this approach in your streaming application. [map of Kafka parameters], [set of topics to consume]); See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java).
@@ -153,4 +153,4 @@ Next, we discuss how to use this approach in your streaming application. Another thing to note is that since this approach does not use Receivers, the standard receiver-related (that is, [configurations](configuration.html) of the form `spark.streaming.receiver.*` ) will not apply to the input DStreams created by this approach (will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition` which is the maximum rate at which each Kafka partition will be read by this direct API. -3. **Deploying:** Similar to the first approach, you can package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR and the launch the application using `spark-submit`. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. \ No newline at end of file +3. **Deploying:** Similar to the first approach, you can package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR and the launch the application using `spark-submit`. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. From 700312e12f9588f01a592d6eac7bff7eb366ac8f Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 5 Jun 2015 14:32:00 +0200 Subject: [PATCH 17/17] [SPARK-6324] [CORE] Centralize handling of script usage messages. Reorganize code so that the launcher library handles most of the work of printing usage messages, instead of having an awkward protocol between the library and the scripts for that. This mostly applies to SparkSubmit, since the launcher lib does not do command line parsing for classes invoked in other ways, and thus cannot handle failures for those. Most scripts end up going through SparkSubmit, though, so it all works. The change adds a new, internal command line switch, "--usage-error", which prints the usage message and exits with a non-zero status. Scripts can override the command printed in the usage message by setting an environment variable - this avoids having to grep the output of SparkSubmit to remove references to the "spark-submit" script. The only sub-optimal part of the change is the special handling for the spark-sql usage, which is now done in SparkSubmitArguments. Author: Marcelo Vanzin Closes #5841 from vanzin/SPARK-6324 and squashes the following commits: 2821481 [Marcelo Vanzin] Merge branch 'master' into SPARK-6324 bf139b5 [Marcelo Vanzin] Filter output of Spark SQL CLI help. c6609bf [Marcelo Vanzin] Fix exit code never being used when printing usage messages. 6bc1b41 [Marcelo Vanzin] [SPARK-6324] [core] Centralize handling of script usage messages. --- bin/pyspark | 16 +--- bin/pyspark2.cmd | 1 + bin/spark-class | 13 +-- bin/spark-shell | 15 +--- bin/spark-shell2.cmd | 21 +---- bin/spark-sql | 39 +-------- bin/spark-submit | 12 --- bin/spark-submit2.cmd | 13 +-- bin/sparkR | 18 +--- .../org/apache/spark/deploy/SparkSubmit.scala | 10 +-- .../spark/deploy/SparkSubmitArguments.scala | 76 ++++++++++++++++- .../spark/deploy/SparkSubmitSuite.scala | 2 +- .../java/org/apache/spark/launcher/Main.java | 83 ++++++++++--------- .../launcher/SparkSubmitCommandBuilder.java | 18 +++- .../launcher/SparkSubmitOptionParser.java | 2 + 15 files changed, 147 insertions(+), 192 deletions(-) diff --git a/bin/pyspark b/bin/pyspark index 7cb19c51b43a2..f9dbddfa53560 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -17,24 +17,10 @@ # limitations under the License. # -# Figure out where Spark is installed export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" source "$SPARK_HOME"/bin/load-spark-env.sh - -function usage() { - if [ -n "$1" ]; then - echo $1 - fi - echo "Usage: ./bin/pyspark [options]" 1>&2 - "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit $2 -} -export -f usage - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage -fi +export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]" # In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython` # executable, while the worker would still be launched using PYSPARK_PYTHON. diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 09b4149c2a439..45e9e3def5121 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -21,6 +21,7 @@ rem Figure out where the Spark framework is installed set SPARK_HOME=%~dp0.. call %SPARK_HOME%\bin\load-spark-env.cmd +set _SPARK_CMD_USAGE=Usage: bin\pyspark.cmd [options] rem Figure out which Python to use. if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( diff --git a/bin/spark-class b/bin/spark-class index c49d97ce5cf25..7bb1afe4b44f5 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -16,18 +16,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # -set -e # Figure out where Spark is installed export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" . "$SPARK_HOME"/bin/load-spark-env.sh -if [ -z "$1" ]; then - echo "Usage: spark-class []" 1>&2 - exit 1 -fi - # Find the java binary if [ -n "${JAVA_HOME}" ]; then RUNNER="${JAVA_HOME}/bin/java" @@ -98,9 +92,4 @@ CMD=() while IFS= read -d '' -r ARG; do CMD+=("$ARG") done < <("$RUNNER" -cp "$LAUNCH_CLASSPATH" org.apache.spark.launcher.Main "$@") - -if [ "${CMD[0]}" = "usage" ]; then - "${CMD[@]}" -else - exec "${CMD[@]}" -fi +exec "${CMD[@]}" diff --git a/bin/spark-shell b/bin/spark-shell index b3761b5e1375b..a6dc863d83fc6 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -29,20 +29,7 @@ esac set -o posix export FWDIR="$(cd "`dirname "$0"`"/..; pwd)" - -usage() { - if [ -n "$1" ]; then - echo "$1" - fi - echo "Usage: ./bin/spark-shell [options]" - "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit "$2" -} -export -f usage - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage "" 0 -fi +export _SPARK_CMD_USAGE="Usage: ./bin/spark-shell [options]" # SPARK-4161: scala does not assume use of the java classpath, # so we need to add the "-Dscala.usejavacp=true" flag manually. We diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd index 00fd30fa38d36..251309d67f860 100644 --- a/bin/spark-shell2.cmd +++ b/bin/spark-shell2.cmd @@ -18,12 +18,7 @@ rem limitations under the License. rem set SPARK_HOME=%~dp0.. - -echo "%*" | findstr " \<--help\> \<-h\>" >nul -if %ERRORLEVEL% equ 0 ( - call :usage - exit /b 0 -) +set _SPARK_CMD_USAGE=Usage: .\bin\spark-shell.cmd [options] rem SPARK-4161: scala does not assume use of the java classpath, rem so we need to add the "-Dscala.usejavacp=true" flag manually. We @@ -37,16 +32,4 @@ if "x%SPARK_SUBMIT_OPTS%"=="x" ( set SPARK_SUBMIT_OPTS="%SPARK_SUBMIT_OPTS% -Dscala.usejavacp=true" :run_shell -call %SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %* -set SPARK_ERROR_LEVEL=%ERRORLEVEL% -if not "x%SPARK_LAUNCHER_USAGE_ERROR%"=="x" ( - call :usage - exit /b 1 -) -exit /b %SPARK_ERROR_LEVEL% - -:usage -echo %SPARK_LAUNCHER_USAGE_ERROR% -echo "Usage: .\bin\spark-shell.cmd [options]" >&2 -call %SPARK_HOME%\bin\spark-submit2.cmd --help 2>&1 | findstr /V "Usage" 1>&2 -goto :eof +%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %* diff --git a/bin/spark-sql b/bin/spark-sql index ca1729f4cfcb4..4ea7bc6e39c07 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -17,41 +17,6 @@ # limitations under the License. # -# -# Shell script for starting the Spark SQL CLI - -# Enter posix mode for bash -set -o posix - -# NOTE: This exact class name is matched downstream by SparkSubmit. -# Any changes need to be reflected there. -export CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" - -# Figure out where Spark is installed export FWDIR="$(cd "`dirname "$0"`"/..; pwd)" - -function usage { - if [ -n "$1" ]; then - echo "$1" - fi - echo "Usage: ./bin/spark-sql [options] [cli option]" - pattern="usage" - pattern+="\|Spark assembly has been built with Hive" - pattern+="\|NOTE: SPARK_PREPEND_CLASSES is set" - pattern+="\|Spark Command: " - pattern+="\|--help" - pattern+="\|=======" - - "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - echo - echo "CLI options:" - "$FWDIR"/bin/spark-class "$CLASS" --help 2>&1 | grep -v "$pattern" 1>&2 - exit "$2" -} -export -f usage - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage "" 0 -fi - -exec "$FWDIR"/bin/spark-submit --class "$CLASS" "$@" +export _SPARK_CMD_USAGE="Usage: ./bin/spark-sql [options] [cli option]" +exec "$FWDIR"/bin/spark-submit --class org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver "$@" diff --git a/bin/spark-submit b/bin/spark-submit index 0e0afe71a0f05..255378b0f077c 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -22,16 +22,4 @@ SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" # disable randomized hash for string in Python 3.3+ export PYTHONHASHSEED=0 -# Only define a usage function if an upstream script hasn't done so. -if ! type -t usage >/dev/null 2>&1; then - usage() { - if [ -n "$1" ]; then - echo "$1" - fi - "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit --help - exit "$2" - } - export -f usage -fi - exec "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@" diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd index d3fc4a5cc3f6e..651376e526928 100644 --- a/bin/spark-submit2.cmd +++ b/bin/spark-submit2.cmd @@ -24,15 +24,4 @@ rem disable randomized hash for string in Python 3.3+ set PYTHONHASHSEED=0 set CLASS=org.apache.spark.deploy.SparkSubmit -call %~dp0spark-class2.cmd %CLASS% %* -set SPARK_ERROR_LEVEL=%ERRORLEVEL% -if not "x%SPARK_LAUNCHER_USAGE_ERROR%"=="x" ( - call :usage - exit /b 1 -) -exit /b %SPARK_ERROR_LEVEL% - -:usage -echo %SPARK_LAUNCHER_USAGE_ERROR% -call %SPARK_HOME%\bin\spark-class2.cmd %CLASS% --help -goto :eof +%~dp0spark-class2.cmd %CLASS% %* diff --git a/bin/sparkR b/bin/sparkR index 8c918e2b09aef..464c29f369424 100755 --- a/bin/sparkR +++ b/bin/sparkR @@ -17,23 +17,7 @@ # limitations under the License. # -# Figure out where Spark is installed export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" - source "$SPARK_HOME"/bin/load-spark-env.sh - -function usage() { - if [ -n "$1" ]; then - echo $1 - fi - echo "Usage: ./bin/sparkR [options]" 1>&2 - "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit $2 -} -export -f usage - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage -fi - +export _SPARK_CMD_USAGE="Usage: ./bin/sparkR [options]" exec "$SPARK_HOME"/bin/spark-submit sparkr-shell-main "$@" diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 8cf4d58847d8e..3aa3f948e865d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -82,13 +82,13 @@ object SparkSubmit { private val CLASS_NOT_FOUND_EXIT_STATUS = 101 // Exposed for testing - private[spark] var exitFn: () => Unit = () => System.exit(1) + private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode) private[spark] var printStream: PrintStream = System.err private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) private[spark] def printErrorAndExit(str: String): Unit = { printStream.println("Error: " + str) printStream.println("Run with --help for usage help or --verbose for debug output") - exitFn() + exitFn(1) } private[spark] def printVersionAndExit(): Unit = { printStream.println("""Welcome to @@ -99,7 +99,7 @@ object SparkSubmit { /_/ """.format(SPARK_VERSION)) printStream.println("Type --help for more information.") - exitFn() + exitFn(0) } def main(args: Array[String]): Unit = { @@ -160,7 +160,7 @@ object SparkSubmit { // detect exceptions with empty stack traces here, and treat them differently. if (e.getStackTrace().length == 0) { printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") - exitFn() + exitFn(1) } else { throw e } @@ -700,7 +700,7 @@ object SparkSubmit { /** * Return whether the given main class represents a sql shell. */ - private def isSqlShell(mainClass: String): Boolean = { + private[deploy] def isSqlShell(mainClass: String): Boolean = { mainClass == "org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index cc6a7bd9f4119..b7429a901e162 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -17,12 +17,15 @@ package org.apache.spark.deploy +import java.io.{ByteArrayOutputStream, PrintStream} +import java.lang.reflect.InvocationTargetException import java.net.URI import java.util.{List => JList} import java.util.jar.JarFile import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.io.Source import org.apache.spark.deploy.SparkSubmitAction._ import org.apache.spark.launcher.SparkSubmitArgumentsParser @@ -412,6 +415,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case VERSION => SparkSubmit.printVersionAndExit() + case USAGE_ERROR => + printUsageAndExit(1) + case _ => throw new IllegalArgumentException(s"Unexpected argument '$opt'.") } @@ -449,11 +455,14 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S if (unknownParam != null) { outStream.println("Unknown/unsupported param " + unknownParam) } - outStream.println( + val command = sys.env.get("_SPARK_CMD_USAGE").getOrElse( """Usage: spark-submit [options] [app arguments] |Usage: spark-submit --kill [submission ID] --master [spark://...] - |Usage: spark-submit --status [submission ID] --master [spark://...] - | + |Usage: spark-submit --status [submission ID] --master [spark://...]""".stripMargin) + outStream.println(command) + + outStream.println( + """ |Options: | --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local. | --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or @@ -525,6 +534,65 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | delegation tokens periodically. """.stripMargin ) - SparkSubmit.exitFn() + + if (SparkSubmit.isSqlShell(mainClass)) { + outStream.println("CLI options:") + outStream.println(getSqlShellOptions()) + } + + SparkSubmit.exitFn(exitCode) } + + /** + * Run the Spark SQL CLI main class with the "--help" option and catch its output. Then filter + * the results to remove unwanted lines. + * + * Since the CLI will call `System.exit()`, we install a security manager to prevent that call + * from working, and restore the original one afterwards. + */ + private def getSqlShellOptions(): String = { + val currentOut = System.out + val currentErr = System.err + val currentSm = System.getSecurityManager() + try { + val out = new ByteArrayOutputStream() + val stream = new PrintStream(out) + System.setOut(stream) + System.setErr(stream) + + val sm = new SecurityManager() { + override def checkExit(status: Int): Unit = { + throw new SecurityException() + } + + override def checkPermission(perm: java.security.Permission): Unit = {} + } + System.setSecurityManager(sm) + + try { + Class.forName(mainClass).getMethod("main", classOf[Array[String]]) + .invoke(null, Array(HELP)) + } catch { + case e: InvocationTargetException => + // Ignore SecurityException, since we throw it above. + if (!e.getCause().isInstanceOf[SecurityException]) { + throw e + } + } + + stream.flush() + + // Get the output and discard any unnecessary lines from it. + Source.fromString(new String(out.toByteArray())).getLines + .filter { line => + !line.startsWith("log4j") && !line.startsWith("usage") + } + .mkString("\n") + } finally { + System.setSecurityManager(currentSm) + System.setOut(currentOut) + System.setErr(currentErr) + } + } + } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 46369457f000a..46ea28d0f18f6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -62,7 +62,7 @@ class SparkSubmitSuite SparkSubmit.printStream = printStream @volatile var exitedCleanly = false - SparkSubmit.exitFn = () => exitedCleanly = true + SparkSubmit.exitFn = (_) => exitedCleanly = true val thread = new Thread { override def run() = try { diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 929b29a49ed70..62492f9baf3bb 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -53,21 +53,33 @@ public static void main(String[] argsArray) throws Exception { List args = new ArrayList(Arrays.asList(argsArray)); String className = args.remove(0); - boolean printLaunchCommand; - boolean printUsage; + boolean printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND")); AbstractCommandBuilder builder; - try { - if (className.equals("org.apache.spark.deploy.SparkSubmit")) { + if (className.equals("org.apache.spark.deploy.SparkSubmit")) { + try { builder = new SparkSubmitCommandBuilder(args); - } else { - builder = new SparkClassCommandBuilder(className, args); + } catch (IllegalArgumentException e) { + printLaunchCommand = false; + System.err.println("Error: " + e.getMessage()); + System.err.println(); + + MainClassOptionParser parser = new MainClassOptionParser(); + try { + parser.parse(args); + } catch (Exception ignored) { + // Ignore parsing exceptions. + } + + List help = new ArrayList(); + if (parser.className != null) { + help.add(parser.CLASS); + help.add(parser.className); + } + help.add(parser.USAGE_ERROR); + builder = new SparkSubmitCommandBuilder(help); } - printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND")); - printUsage = false; - } catch (IllegalArgumentException e) { - builder = new UsageCommandBuilder(e.getMessage()); - printLaunchCommand = false; - printUsage = true; + } else { + builder = new SparkClassCommandBuilder(className, args); } Map env = new HashMap(); @@ -78,13 +90,7 @@ public static void main(String[] argsArray) throws Exception { } if (isWindows()) { - // When printing the usage message, we can't use "cmd /v" since that prevents the env - // variable from being seen in the caller script. So do not call prepareWindowsCommand(). - if (printUsage) { - System.out.println(join(" ", cmd)); - } else { - System.out.println(prepareWindowsCommand(cmd, env)); - } + System.out.println(prepareWindowsCommand(cmd, env)); } else { // In bash, use NULL as the arg separator since it cannot be used in an argument. List bashCmd = prepareBashCommand(cmd, env); @@ -135,33 +141,30 @@ private static List prepareBashCommand(List cmd, Map buildCommand(Map env) { - if (isWindows()) { - return Arrays.asList("set", "SPARK_LAUNCHER_USAGE_ERROR=" + message); - } else { - return Arrays.asList("usage", message, "1"); - } + protected boolean handleUnknown(String opt) { + return false; + } + + @Override + protected void handleExtraArgs(List extra) { + } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 7d387d406edae..3e5a2820b6c11 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -77,6 +77,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { } private final List sparkArgs; + private final boolean printHelp; /** * Controls whether mixing spark-submit arguments with app arguments is allowed. This is needed @@ -87,10 +88,11 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { SparkSubmitCommandBuilder() { this.sparkArgs = new ArrayList(); + this.printHelp = false; } SparkSubmitCommandBuilder(List args) { - this(); + this.sparkArgs = new ArrayList(); List submitArgs = args; if (args.size() > 0 && args.get(0).equals(PYSPARK_SHELL)) { this.allowsMixedArguments = true; @@ -104,14 +106,16 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { this.allowsMixedArguments = false; } - new OptionParser().parse(submitArgs); + OptionParser parser = new OptionParser(); + parser.parse(submitArgs); + this.printHelp = parser.helpRequested; } @Override public List buildCommand(Map env) throws IOException { - if (PYSPARK_SHELL_RESOURCE.equals(appResource)) { + if (PYSPARK_SHELL_RESOURCE.equals(appResource) && !printHelp) { return buildPySparkShellCommand(env); - } else if (SPARKR_SHELL_RESOURCE.equals(appResource)) { + } else if (SPARKR_SHELL_RESOURCE.equals(appResource) && !printHelp) { return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); @@ -311,6 +315,8 @@ private boolean isThriftServer(String mainClass) { private class OptionParser extends SparkSubmitOptionParser { + boolean helpRequested = false; + @Override protected boolean handle(String opt, String value) { if (opt.equals(MASTER)) { @@ -341,6 +347,9 @@ protected boolean handle(String opt, String value) { allowsMixedArguments = true; appResource = specialClasses.get(value); } + } else if (opt.equals(HELP) || opt.equals(USAGE_ERROR)) { + helpRequested = true; + sparkArgs.add(opt); } else { sparkArgs.add(opt); if (value != null) { @@ -360,6 +369,7 @@ protected boolean handleUnknown(String opt) { appArgs.add(opt); return true; } else { + checkArgument(!opt.startsWith("-"), "Unrecognized option: %s", opt); sparkArgs.add(opt); return false; } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java index 229000087688f..b88bba883ac65 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java @@ -61,6 +61,7 @@ class SparkSubmitOptionParser { // Options that do not take arguments. protected final String HELP = "--help"; protected final String SUPERVISE = "--supervise"; + protected final String USAGE_ERROR = "--usage-error"; protected final String VERBOSE = "--verbose"; protected final String VERSION = "--version"; @@ -120,6 +121,7 @@ class SparkSubmitOptionParser { final String[][] switches = { { HELP, "-h" }, { SUPERVISE }, + { USAGE_ERROR }, { VERBOSE, "-v" }, { VERSION }, };