diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 0678bdd02110e..f9476ff826a62 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -224,7 +224,6 @@ class SparkContext(config: SparkConf) extends Logging { /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration: Configuration = { - val env = SparkEnv.get val hadoopConf = SparkHadoopUtil.get.newConfiguration() // Explicitly check for S3 environment variables if (System.getenv("AWS_ACCESS_KEY_ID") != null && diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 14fa9d8135afe..4f3081433a542 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -543,6 +543,18 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) partitioner: Partitioner): JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] = fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, partitioner))) + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. + */ + def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1], + other2: JavaPairRDD[K, W2], + other3: JavaPairRDD[K, W3], + partitioner: Partitioner) + : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] = + fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3, partitioner))) + /** * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. @@ -558,6 +570,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] = fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2))) + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. + */ + def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1], + other2: JavaPairRDD[K, W2], + other3: JavaPairRDD[K, W3]) + : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] = + fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3))) + /** * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. @@ -574,6 +597,18 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] = fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numPartitions))) + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. + */ + def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1], + other2: JavaPairRDD[K, W2], + other3: JavaPairRDD[K, W3], + numPartitions: Int) + : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] = + fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3, numPartitions))) + /** Alias for cogroup. 
*/ def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JIterable[V], JIterable[W])] = fromRDD(cogroupResultToJava(rdd.groupWith(other))) @@ -583,6 +618,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] = fromRDD(cogroupResult2ToJava(rdd.groupWith(other1, other2))) + /** Alias for cogroup. */ + def groupWith[W1, W2, W3](other1: JavaPairRDD[K, W1], + other2: JavaPairRDD[K, W2], + other3: JavaPairRDD[K, W3]) + : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] = + fromRDD(cogroupResult3ToJava(rdd.groupWith(other1, other2, other3))) + /** * Return the list of values in the RDD for key `key`. This operation is done efficiently if the * RDD has a known partitioner by only searching the partition that the key maps to. @@ -786,6 +828,15 @@ object JavaPairRDD { .mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3))) } + private[spark] + def cogroupResult3ToJava[K: ClassTag, V, W1, W2, W3]( + rdd: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))]) + : RDD[(K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3]))] = { + rddToPairRDDFunctions(rdd) + .mapValues(x => + (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3), asJavaIterable(x._4))) + } + def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = { new JavaPairRDD[K, V](rdd) } diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 5da9615c9e9af..39150deab863c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -21,6 +21,8 @@ import scala.collection.mutable.ListBuffer import org.apache.log4j.Level +import org.apache.spark.util.MemoryParam + /** * Command-line parser for the driver client. */ @@ -51,8 +53,8 @@ private[spark] class ClientArguments(args: Array[String]) { cores = value.toInt parse(tail) - case ("--memory" | "-m") :: value :: tail => - memory = value.toInt + case ("--memory" | "-m") :: MemoryParam(value) :: tail => + memory = value parse(tail) case ("--supervise" | "-s") :: tail => diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index fe36c80e0be84..443d1c587c3ee 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -567,6 +567,28 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) new FlatMappedValuesRDD(self, cleanF) } + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. 
+ */ + def cogroup[W1, W2, W3](other1: RDD[(K, W1)], + other2: RDD[(K, W2)], + other3: RDD[(K, W3)], + partitioner: Partitioner) + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) { + throw new SparkException("Default partitioner cannot partition array keys.") + } + val cg = new CoGroupedRDD[K](Seq(self, other1, other2, other3), partitioner) + cg.mapValues { case Seq(vs, w1s, w2s, w3s) => + (vs.asInstanceOf[Seq[V]], + w1s.asInstanceOf[Seq[W1]], + w2s.asInstanceOf[Seq[W2]], + w3s.asInstanceOf[Seq[W3]]) + } + } + /** * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. @@ -599,6 +621,16 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } } + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. + */ + def cogroup[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)]) + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3)) + } + /** * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. @@ -633,6 +665,19 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) cogroup(other1, other2, new HashPartitioner(numPartitions)) } + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. + */ + def cogroup[W1, W2, W3](other1: RDD[(K, W1)], + other2: RDD[(K, W2)], + other3: RDD[(K, W3)], + numPartitions: Int) + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + cogroup(other1, other2, other3, new HashPartitioner(numPartitions)) + } + /** Alias for cogroup. */ def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] = { cogroup(other, defaultPartitioner(self, other)) @@ -644,6 +689,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) cogroup(other1, other2, defaultPartitioner(self, other1, other2)) } + /** Alias for cogroup. */ + def groupWith[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)]) + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3)) + } + /** * Return an RDD with the pairs from `this` whose keys are not in `other`. * diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 1633b185861b9..cebfd109d825f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -446,7 +446,7 @@ abstract class RDD[T: ClassTag]( * Return this RDD sorted by the given key function. 
*/ def sortBy[K]( - f: (T) ⇒ K, + f: (T) => K, ascending: Boolean = true, numPartitions: Int = this.partitions.size) (implicit ord: Ordering[K], ctag: ClassTag[K]): RDD[T] = diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 7ebed5105b9fd..2889e171f627e 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -91,8 +91,13 @@ private[spark] object MetadataCleaner { conf.set(MetadataCleanerType.systemProperty(cleanerType), delay.toString) } + /** + * Set the default delay time (in seconds). + * @param conf SparkConf instance + * @param delay default delay time to set + * @param resetAll whether to reset all to default + */ def setDelaySeconds(conf: SparkConf, delay: Int, resetAll: Boolean = true) { - // override for all ? conf.set("spark.cleaner.ttl", delay.toString) if (resetAll) { for (cleanerType <- MetadataCleanerType.values) { diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index e46298c6a9e63..761f2d6a77d33 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -21,6 +21,9 @@ import java.util.*; import scala.Tuple2; +import scala.Tuple3; +import scala.Tuple4; + import com.google.common.collect.Iterables; import com.google.common.collect.Iterators; @@ -304,6 +307,66 @@ public void cogroup() { cogrouped.collect(); } + @SuppressWarnings("unchecked") + @Test + public void cogroup3() { + JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList( + new Tuple2<String, String>("Apples", "Fruit"), + new Tuple2<String, String>("Oranges", "Fruit"), + new Tuple2<String, String>("Oranges", "Citrus") + )); + JavaPairRDD<String, Integer> prices = sc.parallelizePairs(Arrays.asList( + new Tuple2<String, Integer>("Oranges", 2), + new Tuple2<String, Integer>("Apples", 3) + )); + JavaPairRDD<String, Integer> quantities = sc.parallelizePairs(Arrays.asList( + new Tuple2<String, Integer>("Oranges", 21), + new Tuple2<String, Integer>("Apples", 42) + )); + + JavaPairRDD<String, Tuple3<Iterable<String>, Iterable<Integer>, Iterable<Integer>>> cogrouped = + categories.cogroup(prices, quantities); + Assert.assertEquals("[Fruit, Citrus]", + Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); + Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); + + + cogrouped.collect(); + } + + @SuppressWarnings("unchecked") + @Test + public void cogroup4() { + JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList( + new Tuple2<String, String>("Apples", "Fruit"), + new Tuple2<String, String>("Oranges", "Fruit"), + new Tuple2<String, String>("Oranges", "Citrus") + )); + JavaPairRDD<String, Integer> prices = sc.parallelizePairs(Arrays.asList( + new Tuple2<String, Integer>("Oranges", 2), + new Tuple2<String, Integer>("Apples", 3) + )); + JavaPairRDD<String, Integer> quantities = sc.parallelizePairs(Arrays.asList( + new Tuple2<String, Integer>("Oranges", 21), + new Tuple2<String, Integer>("Apples", 42) + )); + JavaPairRDD<String, String> countries = sc.parallelizePairs(Arrays.asList( + new Tuple2<String, String>("Oranges", "BR"), + new Tuple2<String, String>("Apples", "US") + )); + + JavaPairRDD<String, Tuple4<Iterable<String>, Iterable<Integer>, Iterable<Integer>, Iterable<String>>> cogrouped = + categories.cogroup(prices, quantities, countries); + Assert.assertEquals("[Fruit, Citrus]", + Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); + Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); + Assert.assertEquals("[BR]",
Iterables.toString(cogrouped.lookup("Oranges").get(0)._4())); + + cogrouped.collect(); + } + @SuppressWarnings("unchecked") @Test public void leftOuterJoin() { diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 0b9004448a63e..447e38ec9dbd0 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -249,6 +249,39 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { )) } + test("groupWith3") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val rdd3 = sc.parallelize(Array((1, 'a'), (3, 'b'), (4, 'c'), (4, 'd'))) + val joined = rdd1.groupWith(rdd2, rdd3).collect() + assert(joined.size === 4) + val joinedSet = joined.map(x => (x._1, + (x._2._1.toList, x._2._2.toList, x._2._3.toList))).toSet + assert(joinedSet === Set( + (1, (List(1, 2), List('x'), List('a'))), + (2, (List(1), List('y', 'z'), List())), + (3, (List(1), List(), List('b'))), + (4, (List(), List('w'), List('c', 'd'))) + )) + } + + test("groupWith4") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val rdd3 = sc.parallelize(Array((1, 'a'), (3, 'b'), (4, 'c'), (4, 'd'))) + val rdd4 = sc.parallelize(Array((2, '@'))) + val joined = rdd1.groupWith(rdd2, rdd3, rdd4).collect() + assert(joined.size === 4) + val joinedSet = joined.map(x => (x._1, + (x._2._1.toList, x._2._2.toList, x._2._3.toList, x._2._4.toList))).toSet + assert(joinedSet === Set( + (1, (List(1, 2), List('x'), List('a'), List())), + (2, (List(1), List('y', 'z'), List(), List('@'))), + (3, (List(1), List(), List('b'), List())), + (4, (List(), List('w'), List('c', 'd'), List())) + )) + } + test("zero-partition RDD") { val emptyDir = Files.createTempDir() emptyDir.deleteOnExit() diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index be506e0287a16..abd7b22310f1a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -239,11 +239,14 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers checkNonZeroAvg( taskInfoMetrics.map(_._2.executorDeserializeTime), stageInfo + " executorDeserializeTime") + + /* Test is disabled (SEE SPARK-2208) if (stageInfo.rddInfos.exists(_.name == d4.name)) { checkNonZeroAvg( taskInfoMetrics.map(_._2.shuffleReadMetrics.get.fetchWaitTime), stageInfo + " fetchWaitTime") } + */ taskInfoMetrics.foreach { case (taskInfo, taskMetrics) => taskMetrics.resultSize should be > (0l) diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 53d7f5c6072e6..02e228945bbd9 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -120,7 +120,7 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { // on SparkConf settings. 
def testAppenderSelection[ExpectedAppender: ClassTag, ExpectedRollingPolicy]( - properties: Seq[(String, String)], expectedRollingPolicyParam: Long = -1): FileAppender = { + properties: Seq[(String, String)], expectedRollingPolicyParam: Long = -1): Unit = { // Set spark conf properties val conf = new SparkConf @@ -129,8 +129,9 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { } // Create and test file appender - val inputStream = new PipedInputStream(new PipedOutputStream()) - val appender = FileAppender(inputStream, new File("stdout"), conf) + val testOutputStream = new PipedOutputStream() + val testInputStream = new PipedInputStream(testOutputStream) + val appender = FileAppender(testInputStream, testFile, conf) assert(appender.isInstanceOf[ExpectedAppender]) assert(appender.getClass.getSimpleName === classTag[ExpectedAppender].runtimeClass.getSimpleName) @@ -144,7 +145,8 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { } assert(policyParam === expectedRollingPolicyParam) } - appender + testOutputStream.close() + appender.awaitTermination() } import RollingFileAppender._ diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 4243ef480ba39..fecd8f2cc2d48 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -68,15 +68,29 @@ Most of the configs are the same for Spark on YARN as for other deployment modes - spark.yarn.executor.memoryOverhead - 384 + spark.yarn.dist.archives + (none) + + Comma separated list of archives to be extracted into the working directory of each executor. + + + + spark.yarn.dist.files + (none) + + Comma-separated list of files to be placed in the working directory of each executor. + + + + spark.yarn.executor.memoryOverhead + 384 The amount of off heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. spark.yarn.driver.memoryOverhead - 384 + 384 The amount of off heap memory (in megabytes) to be allocated per driver. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 00d0b18c27a8d..1a0073c9d487e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -419,7 +419,7 @@ class RowMatrix( /** Updates or verifies the number of rows. */ private def updateNumRows(m: Long) { if (nRows <= 0) { - nRows == m + nRows = m } else { require(nRows == m, s"The number of rows $m is different from what specified or previously computed: ${nRows}.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 8f187c9df5102..7bbed9c8fdbef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -60,7 +60,7 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4. * Smaller value will lead to higher accuracy with the cost of more iterations. 
*/ - def setConvergenceTol(tolerance: Int): this.type = { + def setConvergenceTol(tolerance: Double): this.type = { this.convergenceTol = tolerance this } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index 4b1850659a18e..fe7a9033cd5f4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -195,4 +195,38 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { assert(lossLBFGS3.length == 6) assert((lossLBFGS3(4) - lossLBFGS3(5)) / lossLBFGS3(4) < convergenceTol) } + + test("Optimize via class LBFGS.") { + val regParam = 0.2 + + // Prepare another non-zero weights to compare the loss in the first iteration. + val initialWeightsWithIntercept = Vectors.dense(0.3, 0.12) + val convergenceTol = 1e-12 + val maxNumIterations = 10 + + val lbfgsOptimizer = new LBFGS(gradient, squaredL2Updater) + .setNumCorrections(numCorrections) + .setConvergenceTol(convergenceTol) + .setMaxNumIterations(maxNumIterations) + .setRegParam(regParam) + + val weightLBFGS = lbfgsOptimizer.optimize(dataRDD, initialWeightsWithIntercept) + + val numGDIterations = 50 + val stepSize = 1.0 + val (weightGD, _) = GradientDescent.runMiniBatchSGD( + dataRDD, + gradient, + squaredL2Updater, + stepSize, + numGDIterations, + regParam, + miniBatchFrac, + initialWeightsWithIntercept) + + // for class LBFGS and the optimize method, we only look at the weights + assert(compareDouble(weightLBFGS(0), weightGD(0), 0.02) && + compareDouble(weightLBFGS(1), weightGD(1), 0.02), + "The weight differences between LBFGS and GD should be within 2%.") + } } diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 91ae8263f66b8..19235d5f79f85 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -43,12 +43,19 @@ def launch_gateway(): # Don't send ctrl-c / SIGINT to the Java gateway: def preexec_func(): signal.signal(signal.SIGINT, signal.SIG_IGN) - proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func) + proc = Popen(command, stdout=PIPE, stdin=PIPE, stderr=PIPE, preexec_fn=preexec_func) else: # preexec_fn not supported on Windows - proc = Popen(command, stdout=PIPE, stdin=PIPE) - # Determine which ephemeral port the server started on: - gateway_port = int(proc.stdout.readline()) + proc = Popen(command, stdout=PIPE, stdin=PIPE, stderr=PIPE) + + try: + # Determine which ephemeral port the server started on: + gateway_port = int(proc.stdout.readline()) + except: + error_code = proc.poll() + raise Exception("Launching GatewayServer failed with exit code %d: %s" % + (error_code, "".join(proc.stderr.readlines()))) + # Create a thread to echo output from the GatewayServer, which is required # for Java log output to show up: class EchoOutputThread(Thread): diff --git a/python/pyspark/join.py b/python/pyspark/join.py index 6f94d26ef86a9..5f3a7e71f7866 100644 --- a/python/pyspark/join.py +++ b/python/pyspark/join.py @@ -79,15 +79,15 @@ def dispatch(seq): return _do_python_join(rdd, other, numPartitions, dispatch) -def python_cogroup(rdd, other, numPartitions): - vs = rdd.map(lambda (k, v): (k, (1, v))) - ws = other.map(lambda (k, v): (k, (2, v))) +def python_cogroup(rdds, numPartitions): + def make_mapper(i): + return lambda (k, v): (k, (i, v)) + vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)] + union_vrdds = reduce(lambda 
acc, other: acc.union(other), vrdds) + rdd_len = len(vrdds) def dispatch(seq): - vbuf, wbuf = [], [] + bufs = [[] for i in range(rdd_len)] for (n, v) in seq: - if n == 1: - vbuf.append(v) - elif n == 2: - wbuf.append(v) - return (ResultIterable(vbuf), ResultIterable(wbuf)) - return vs.union(ws).groupByKey(numPartitions).mapValues(dispatch) + bufs[n].append(v) + return tuple(map(ResultIterable, bufs)) + return union_vrdds.groupByKey(numPartitions).mapValues(dispatch) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index a0b2c744f0e7f..1d55c35a8bf48 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -512,7 +512,7 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc = lambda x: x): [('a', 3), ('fleece', 7), ('had', 2), ('lamb', 5), ('little', 4), ('Mary', 1), ('was', 8), ('white', 9), ('whose', 6)] """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._defaultReducePartitions() bounds = list() @@ -1154,7 +1154,7 @@ def partitionBy(self, numPartitions, partitionFunc=None): set([]) """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._defaultReducePartitions() if partitionFunc is None: partitionFunc = lambda x: 0 if x is None else hash(x) @@ -1212,7 +1212,7 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, [('a', '11'), ('b', '1')] """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._defaultReducePartitions() def combineLocally(iterator): combiners = {} for x in iterator: @@ -1233,7 +1233,7 @@ def _mergeCombiners(iterator): combiners[k] = mergeCombiners(combiners[k], v) return combiners.iteritems() return shuffled.mapPartitions(_mergeCombiners) - + def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): """ Aggregate the values of each key, using given combine functions and a neutral "zero value". @@ -1245,7 +1245,7 @@ def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): """ def createZero(): return copy.deepcopy(zeroValue) - + return self.combineByKey(lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions) def foldByKey(self, zeroValue, func, numPartitions=None): @@ -1323,12 +1323,20 @@ def mapValues(self, f): map_values_fn = lambda (k, v): (k, f(v)) return self.map(map_values_fn, preservesPartitioning=True) - # TODO: support varargs cogroup of several RDDs. - def groupWith(self, other): + def groupWith(self, other, *others): """ - Alias for cogroup. + Alias for cogroup but with support for multiple RDDs. 
+ + >>> w = sc.parallelize([("a", 5), ("b", 6)]) + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2)]) + >>> z = sc.parallelize([("b", 42)]) + >>> map((lambda (x,y): (x, (list(y[0]), list(y[1]), list(y[2]), list(y[3])))), \ + sorted(list(w.groupWith(x, y, z).collect()))) + [('a', ([5], [1], [2], [])), ('b', ([6], [4], [], [42]))] + """ - return self.cogroup(other) + return python_cogroup((self, other) + others, numPartitions=None) # TODO: add variant with custom parittioner def cogroup(self, other, numPartitions=None): @@ -1342,7 +1350,7 @@ def cogroup(self, other, numPartitions=None): >>> map((lambda (x,y): (x, (list(y[0]), list(y[1])))), sorted(list(x.cogroup(y).collect()))) [('a', ([1], [2])), ('b', ([4], []))] """ - return python_cogroup(self, other, numPartitions) + return python_cogroup((self, other), numPartitions) def subtractByKey(self, other, numPartitions=None): """ @@ -1475,6 +1483,21 @@ def getStorageLevel(self): java_storage_level.replication()) return storage_level + def _defaultReducePartitions(self): + """ + Returns the default number of partitions to use during reduce tasks (e.g., groupBy). + If spark.default.parallelism is set, then we'll use the value from SparkContext + defaultParallelism, otherwise we'll use the number of partitions in this RDD. + + This mirrors the behavior of the Scala Partitioner#defaultPartitioner, intended to reduce + the likelihood of OOMs. Once PySpark adopts Partitioner-based APIs, this behavior will + be inherent. + """ + if self.ctx._conf.contains("spark.default.parallelism"): + return self.ctx.defaultParallelism + else: + return self.getNumPartitions() + # TODO: `lookup` is disabled because we can't make direct comparisons based # on the key; we need to compare the hash of the key to the hash of the # keys in the pairs. This could be an expensive operation, since those diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 46fcfbb9e26ba..0cc4592047b19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -66,43 +66,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected case class Keyword(str: String) protected implicit def asParser(k: Keyword): Parser[String] = - allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) - - protected class SqlLexical extends StdLexical { - case class FloatLit(chars: String) extends Token { - override def toString = chars - } - override lazy val token: Parser[Token] = ( - identChar ~ rep( identChar | digit ) ^^ - { case first ~ rest => processIdent(first :: rest mkString "") } - | rep1(digit) ~ opt('.' ~> rep(digit)) ^^ { - case i ~ None => NumericLit(i mkString "") - case i ~ Some(d) => FloatLit(i.mkString("") + "." 
+ d.mkString("")) - } - | '\'' ~ rep( chrExcept('\'', '\n', EofCh) ) ~ '\'' ^^ - { case '\'' ~ chars ~ '\'' => StringLit(chars mkString "") } - | '\"' ~ rep( chrExcept('\"', '\n', EofCh) ) ~ '\"' ^^ - { case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") } - | EofCh ^^^ EOF - | '\'' ~> failure("unclosed string literal") - | '\"' ~> failure("unclosed string literal") - | delim - | failure("illegal character") - ) - - override def identChar = letter | elem('.') | elem('_') - - override def whitespace: Parser[Any] = rep( - whitespaceChar - | '/' ~ '*' ~ comment - | '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') ) - | '#' ~ rep( chrExcept(EofCh, '\n') ) - | '-' ~ '-' ~ rep( chrExcept(EofCh, '\n') ) - | '/' ~ '*' ~ failure("unclosed comment") - ) - } - - override val lexical = new SqlLexical + lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) protected val ALL = Keyword("ALL") protected val AND = Keyword("AND") @@ -161,24 +125,9 @@ class SqlParser extends StandardTokenParsers with PackratParsers { this.getClass .getMethods .filter(_.getReturnType == classOf[Keyword]) - .map(_.invoke(this).asInstanceOf[Keyword]) - - /** Generate all variations of upper and lower case of a given string */ - private def allCaseVersions(s: String, prefix: String = ""): Stream[String] = { - if (s == "") { - Stream(prefix) - } else { - allCaseVersions(s.tail, prefix + s.head.toLower) ++ - allCaseVersions(s.tail, prefix + s.head.toUpper) - } - } + .map(_.invoke(this).asInstanceOf[Keyword].str) - lexical.reserved ++= reservedWords.flatMap(w => allCaseVersions(w.str)) - - lexical.delimiters += ( - "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", - ",", ";", "%", "{", "}", ":", "[", "]" - ) + override val lexical = new SqlLexical(reservedWords) protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = { exprs.zipWithIndex.map { @@ -309,13 +258,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers { comparisonExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1,e2) }) protected lazy val comparisonExpression: Parser[Expression] = - termExpression ~ "=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Equals(e1, e2) } | + termExpression ~ "=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => EqualTo(e1, e2) } | termExpression ~ "<" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThan(e1, e2) } | termExpression ~ "<=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThanOrEqual(e1, e2) } | termExpression ~ ">" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThan(e1, e2) } | termExpression ~ ">=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThanOrEqual(e1, e2) } | - termExpression ~ "!=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(Equals(e1, e2)) } | - termExpression ~ "<>" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(Equals(e1, e2)) } | + termExpression ~ "!=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } | + termExpression ~ "<>" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } | termExpression ~ RLIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } | termExpression ~ REGEXP ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } | termExpression ~ LIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => Like(e1, e2) } | @@ -383,7 +332,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars) protected lazy val baseExpression: PackratParser[Expression] = - expression ~ "[" ~ expression <~ "]" ^^ { + expression ~ "[" ~ expression 
<~ "]" ^^ { case base ~ _ ~ ordinal => GetItem(base, ordinal) } | TRUE ^^^ Literal(true, BooleanType) | @@ -399,3 +348,55 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected lazy val dataType: Parser[DataType] = STRING ^^^ StringType } + +class SqlLexical(val keywords: Seq[String]) extends StdLexical { + case class FloatLit(chars: String) extends Token { + override def toString = chars + } + + reserved ++= keywords.flatMap(w => allCaseVersions(w)) + + delimiters += ( + "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", + ",", ";", "%", "{", "}", ":", "[", "]" + ) + + override lazy val token: Parser[Token] = ( + identChar ~ rep( identChar | digit ) ^^ + { case first ~ rest => processIdent(first :: rest mkString "") } + | rep1(digit) ~ opt('.' ~> rep(digit)) ^^ { + case i ~ None => NumericLit(i mkString "") + case i ~ Some(d) => FloatLit(i.mkString("") + "." + d.mkString("")) + } + | '\'' ~ rep( chrExcept('\'', '\n', EofCh) ) ~ '\'' ^^ + { case '\'' ~ chars ~ '\'' => StringLit(chars mkString "") } + | '\"' ~ rep( chrExcept('\"', '\n', EofCh) ) ~ '\"' ^^ + { case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") } + | EofCh ^^^ EOF + | '\'' ~> failure("unclosed string literal") + | '\"' ~> failure("unclosed string literal") + | delim + | failure("illegal character") + ) + + override def identChar = letter | elem('_') | elem('.') + + override def whitespace: Parser[Any] = rep( + whitespaceChar + | '/' ~ '*' ~ comment + | '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') ) + | '#' ~ rep( chrExcept(EofCh, '\n') ) + | '-' ~ '-' ~ rep( chrExcept(EofCh, '\n') ) + | '/' ~ '*' ~ failure("unclosed comment") + ) + + /** Generate all variations of upper and lower case of a given string */ + def allCaseVersions(s: String, prefix: String = ""): Stream[String] = { + if (s == "") { + Stream(prefix) + } else { + allCaseVersions(s.tail, prefix + s.head.toLower) ++ + allCaseVersions(s.tail, prefix + s.head.toUpper) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 66bff660cadc2..76ddeba9cb312 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -33,7 +33,7 @@ object HiveTypeCoercion { } /** - * A collection of [[catalyst.rules.Rule Rules]] that can be used to coerce differing types that + * A collection of [[Rule Rules]] that can be used to coerce differing types that * participate in operations into compatible ones. Most of these rules are based on Hive semantics, * but they do not introduce any dependencies on the hive codebase. For this reason they remain in * Catalyst until we have a more standard set of coercions. @@ -53,8 +53,8 @@ trait HiveTypeCoercion { Nil /** - * Applies any changes to [[catalyst.expressions.AttributeReference AttributeReference]] data - * types that are made by other rules to instances higher in the query tree. + * Applies any changes to [[AttributeReference]] data types that are made by other rules to + * instances higher in the query tree. */ object PropagateTypes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -234,8 +234,8 @@ trait HiveTypeCoercion { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. 
case e if !e.childrenResolved => e - // No need to change Equals operators as that actually makes sense for boolean types. - case e: Equals => e + // No need to change EqualTo operators as that actually makes sense for boolean types. + case e: EqualTo => e // Otherwise turn them to Byte types so that there exists and ordering. case p: BinaryComparison if p.left.dataType == BooleanType && p.right.dataType == BooleanType => @@ -244,15 +244,20 @@ trait HiveTypeCoercion { } /** - * Casts to/from [[catalyst.types.BooleanType BooleanType]] are transformed into comparisons since + * Casts to/from [[BooleanType]] are transformed into comparisons since * the JVM does not consider Booleans to be numeric types. */ object BooleanCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - - case Cast(e, BooleanType) => Not(Equals(e, Literal(0))) + // Skip if the type is boolean type already. Note that this extra cast should be removed + // by optimizer.SimplifyCasts. + case Cast(e, BooleanType) if e.dataType == BooleanType => e + // If the data type is not boolean and is being cast boolean, turn it into a comparison + // with the numeric value, i.e. x != 0. This will coerce the type into numeric type. + case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0))) + // Turn true into 1, and false into 0 if casting boolean into other types. case Cast(e, dataType) if e.dataType == BooleanType => Cast(If(e, Literal(1), Literal(0)), dataType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index d177339d40ae5..26ad4837b0b01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.types._ * * // These unresolved attributes can be used to create more complicated expressions. * scala> 'a === 'b - * res2: org.apache.spark.sql.catalyst.expressions.Equals = ('a = 'b) + * res2: org.apache.spark.sql.catalyst.expressions.EqualTo = ('a = 'b) * * // SQL verbs can be used to construct logical query plans. 
* scala> import org.apache.spark.sql.catalyst.plans.logical._ @@ -76,8 +76,8 @@ package object dsl { def <= (other: Expression) = LessThanOrEqual(expr, other) def > (other: Expression) = GreaterThan(expr, other) def >= (other: Expression) = GreaterThanOrEqual(expr, other) - def === (other: Expression) = Equals(expr, other) - def !== (other: Expression) = Not(Equals(expr, other)) + def === (other: Expression) = EqualTo(expr, other) + def !== (other: Expression) = Not(EqualTo(expr, other)) def like(other: Expression) = Like(expr, other) def rlike(other: Expression) = RLike(expr, other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 4ebf6c4584b94..655d4a08fe93b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -68,7 +68,7 @@ class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] { } object BindReferences extends Logging { - def bindReference(expression: Expression, input: Seq[Attribute]): Expression = { + def bindReference[A <: Expression](expression: A, input: Seq[Attribute]): A = { expression.transform { case a: AttributeReference => attachTree(a, "Binding attribute") { val ordinal = input.indexWhere(_.exprId == a.exprId) @@ -83,6 +83,6 @@ object BindReferences extends Logging { BoundReference(ordinal, a) } } - } + }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible. } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 0b3a4e728ec54..1f9716e385e9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -24,72 +24,87 @@ import org.apache.spark.sql.catalyst.types._ /** Cast the child expression to the target data type. */ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { override def foldable = child.foldable - def nullable = (child.dataType, dataType) match { + + override def nullable = (child.dataType, dataType) match { case (StringType, _: NumericType) => true case (StringType, TimestampType) => true case _ => child.nullable } + override def toString = s"CAST($child, $dataType)" type EvaluatedType = Any - def nullOrCast[T](a: Any, func: T => Any): Any = if(a == null) { - null - } else { - func(a.asInstanceOf[T]) - } + // [[func]] assumes the input is no longer null because eval already does the null check. 
+ @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) // UDFToString - def castToString: Any => Any = child.dataType match { - case BinaryType => nullOrCast[Array[Byte]](_, new String(_, "UTF-8")) - case _ => nullOrCast[Any](_, _.toString) + private[this] def castToString: Any => Any = child.dataType match { + case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8")) + case _ => buildCast[Any](_, _.toString) } // BinaryConverter - def castToBinary: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, _.getBytes("UTF-8")) + private[this] def castToBinary: Any => Any = child.dataType match { + case StringType => buildCast[String](_, _.getBytes("UTF-8")) } // UDFToBoolean - def castToBoolean: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, _.length() != 0) - case TimestampType => nullOrCast[Timestamp](_, b => {(b.getTime() != 0 || b.getNanos() != 0)}) - case LongType => nullOrCast[Long](_, _ != 0) - case IntegerType => nullOrCast[Int](_, _ != 0) - case ShortType => nullOrCast[Short](_, _ != 0) - case ByteType => nullOrCast[Byte](_, _ != 0) - case DecimalType => nullOrCast[BigDecimal](_, _ != 0) - case DoubleType => nullOrCast[Double](_, _ != 0) - case FloatType => nullOrCast[Float](_, _ != 0) + private[this] def castToBoolean: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, _.length() != 0) + case TimestampType => + buildCast[Timestamp](_, b => b.getTime() != 0 || b.getNanos() != 0) + case LongType => + buildCast[Long](_, _ != 0) + case IntegerType => + buildCast[Int](_, _ != 0) + case ShortType => + buildCast[Short](_, _ != 0) + case ByteType => + buildCast[Byte](_, _ != 0) + case DecimalType => + buildCast[BigDecimal](_, _ != 0) + case DoubleType => + buildCast[Double](_, _ != 0) + case FloatType => + buildCast[Float](_, _ != 0) } // TimestampConverter - def castToTimestamp: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => { - // Throw away extra if more than 9 decimal places - val periodIdx = s.indexOf("."); - var n = s - if (periodIdx != -1) { - if (n.length() - periodIdx > 9) { + private[this] def castToTimestamp: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => { + // Throw away extra if more than 9 decimal places + val periodIdx = s.indexOf(".") + var n = s + if (periodIdx != -1 && n.length() - periodIdx > 9) { n = n.substring(0, periodIdx + 10) } - } - try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null} - }) - case BooleanType => nullOrCast[Boolean](_, b => new Timestamp((if(b) 1 else 0) * 1000)) - case LongType => nullOrCast[Long](_, l => new Timestamp(l * 1000)) - case IntegerType => nullOrCast[Int](_, i => new Timestamp(i * 1000)) - case ShortType => nullOrCast[Short](_, s => new Timestamp(s * 1000)) - case ByteType => nullOrCast[Byte](_, b => new Timestamp(b * 1000)) + try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null } + }) + case BooleanType => + buildCast[Boolean](_, b => new Timestamp((if (b) 1 else 0) * 1000)) + case LongType => + buildCast[Long](_, l => new Timestamp(l * 1000)) + case IntegerType => + buildCast[Int](_, i => new Timestamp(i * 1000)) + case ShortType => + buildCast[Short](_, s => new Timestamp(s * 1000)) + case ByteType => + buildCast[Byte](_, b => new Timestamp(b * 1000)) // TimestampWritable.decimalToTimestamp - case DecimalType => nullOrCast[BigDecimal](_, d => 
decimalToTimestamp(d)) + case DecimalType => + buildCast[BigDecimal](_, d => decimalToTimestamp(d)) // TimestampWritable.doubleToTimestamp - case DoubleType => nullOrCast[Double](_, d => decimalToTimestamp(d)) + case DoubleType => + buildCast[Double](_, d => decimalToTimestamp(d)) // TimestampWritable.floatToTimestamp - case FloatType => nullOrCast[Float](_, f => decimalToTimestamp(f)) + case FloatType => + buildCast[Float](_, f => decimalToTimestamp(f)) } - private def decimalToTimestamp(d: BigDecimal) = { + private[this] def decimalToTimestamp(d: BigDecimal) = { val seconds = d.longValue() val bd = (d - seconds) * 1000000000 val nanos = bd.intValue() @@ -104,85 +119,118 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { } // Timestamp to long, converting milliseconds to seconds - private def timestampToLong(ts: Timestamp) = ts.getTime / 1000 + private[this] def timestampToLong(ts: Timestamp) = ts.getTime / 1000 - private def timestampToDouble(ts: Timestamp) = { + private[this] def timestampToDouble(ts: Timestamp) = { // First part is the seconds since the beginning of time, followed by nanosecs. ts.getTime / 1000 + ts.getNanos.toDouble / 1000000000 } - def castToLong: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toLong catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1L else 0L) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t)) - case DecimalType => nullOrCast[BigDecimal](_, _.toLong) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) - } - - def castToInt: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toInt catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toInt) - case DecimalType => nullOrCast[BigDecimal](_, _.toInt) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) - } - - def castToShort: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toShort catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toShort else 0.toShort) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toShort) - case DecimalType => nullOrCast[BigDecimal](_, _.toShort) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort - } - - def castToByte: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toByte catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toByte else 0.toByte) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toByte) - case DecimalType => nullOrCast[BigDecimal](_, _.toByte) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte - } - - def castToDecimal: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try BigDecimal(s.toDouble) catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) BigDecimal(1) else BigDecimal(0)) + private[this] def castToLong: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toLong catch { + case _: NumberFormatException => null + }) + case BooleanType => + 
buildCast[Boolean](_, b => if (b) 1L else 0L) + case TimestampType => + buildCast[Timestamp](_, t => timestampToLong(t)) + case DecimalType => + buildCast[BigDecimal](_, _.toLong) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) + } + + private[this] def castToInt: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toInt catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) 1 else 0) + case TimestampType => + buildCast[Timestamp](_, t => timestampToLong(t).toInt) + case DecimalType => + buildCast[BigDecimal](_, _.toInt) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) + } + + private[this] def castToShort: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toShort catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) + case TimestampType => + buildCast[Timestamp](_, t => timestampToLong(t).toShort) + case DecimalType => + buildCast[BigDecimal](_, _.toShort) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort + } + + private[this] def castToByte: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toByte catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) + case TimestampType => + buildCast[Timestamp](_, t => timestampToLong(t).toByte) + case DecimalType => + buildCast[BigDecimal](_, _.toByte) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte + } + + private[this] def castToDecimal: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try BigDecimal(s.toDouble) catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) BigDecimal(1) else BigDecimal(0)) case TimestampType => // Note that we lose precision here. 
- nullOrCast[Timestamp](_, t => BigDecimal(timestampToDouble(t))) - case x: NumericType => b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)) - } - - def castToDouble: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toDouble catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1d else 0d) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t)) - case DecimalType => nullOrCast[BigDecimal](_, _.toDouble) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) - } - - def castToFloat: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toFloat catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1f else 0f) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toFloat) - case DecimalType => nullOrCast[BigDecimal](_, _.toFloat) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) + buildCast[Timestamp](_, t => BigDecimal(timestampToDouble(t))) + case x: NumericType => + b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)) + } + + private[this] def castToDouble: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toDouble catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) 1d else 0d) + case TimestampType => + buildCast[Timestamp](_, t => timestampToDouble(t)) + case DecimalType => + buildCast[BigDecimal](_, _.toDouble) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) + } + + private[this] def castToFloat: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toFloat catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) 1f else 0f) + case TimestampType => + buildCast[Timestamp](_, t => timestampToDouble(t).toFloat) + case DecimalType => + buildCast[BigDecimal](_, _.toFloat) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) } - private lazy val cast: Any => Any = dataType match { + private[this] lazy val cast: Any => Any = dataType match { case StringType => castToString case BinaryType => castToBinary case DecimalType => castToDecimal @@ -198,10 +246,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { override def eval(input: Row): Any = { val evaluated = child.eval(input) - if (evaluated == null) { - null - } else { - cast(evaluated) - } + if (evaluated == null) null else cast(evaluated) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 3912f5f4375fd..0411ce3aefda1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -33,14 +33,11 @@ abstract class Expression extends TreeNode[Expression] { * executed. 
* * The following conditions are used to determine suitability for constant folding: - * - A [[expressions.Coalesce Coalesce]] is foldable if all of its children are foldable - * - A [[expressions.BinaryExpression BinaryExpression]] is foldable if its both left and right - * child are foldable - * - A [[expressions.Not Not]], [[expressions.IsNull IsNull]], or - * [[expressions.IsNotNull IsNotNull]] is foldable if its child is foldable. - * - A [[expressions.Literal]] is foldable. - * - A [[expressions.Cast Cast]] or [[expressions.UnaryMinus UnaryMinus]] is foldable if its - * child is foldable. + * - A [[Coalesce]] is foldable if all of its children are foldable + * - A [[BinaryExpression]] is foldable if its both left and right child are foldable + * - A [[Not]], [[IsNull]], or [[IsNotNull]] is foldable if its child is foldable + * - A [[Literal]] is foldable + * - A [[Cast]] or [[UnaryMinus]] is foldable if its child is foldable */ def foldable: Boolean = false def nullable: Boolean @@ -58,7 +55,7 @@ abstract class Expression extends TreeNode[Expression] { lazy val resolved: Boolean = childrenResolved /** - * Returns the [[types.DataType DataType]] of the result of evaluating this expression. It is + * Returns the [[DataType]] of the result of evaluating this expression. It is * invalid to query the dataType of an unresolved expression (i.e., when `resolved` == false). */ def dataType: DataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 77b5429bad432..74ae723686cfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -208,6 +208,9 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow { class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { + def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) = + this(ordering.map(BindReferences.bindReference(_, inputSchema))) + def compare(a: Row, b: Row): Int = { var i = 0 while (i < ordering.size) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index b6aeae92f8bec..5d3bb25ad568c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -50,6 +50,8 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { null } else { if (child.dataType.isInstanceOf[ArrayType]) { + // TODO: consider using Array[_] for ArrayType child to avoid + // boxing of primitives val baseValue = value.asInstanceOf[Seq[_]] val o = key.asInstanceOf[Int] if (o >= baseValue.size || o < 0) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index a8145c37c20fa..66ae22e95b60e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -103,7 +103,7 @@ case class Alias(child: Expression, name: String) * A reference to an attribute produced by another operator in the tree. 
* * @param name The name of this attribute, should only be used during analysis or for debugging. - * @param dataType The [[types.DataType DataType]] of this attribute. + * @param dataType The [[DataType]] of this attribute. * @param nullable True if null is a valid value for this attribute. * @param exprId A globally unique id used to check if different AttributeReferences refer to the * same attribute. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 573ec052f4266..b6f2451b52e1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -24,7 +24,7 @@ package org.apache.spark.sql.catalyst * expression, a [[NamedExpression]] in addition to the standard collection of expressions. * * ==Standard Expressions== - * A library of standard expressions (e.g., [[Add]], [[Equals]]), aggregates (e.g., SUM, COUNT), + * A library of standard expressions (e.g., [[Add]], [[EqualTo]]), aggregates (e.g., SUM, COUNT), * and other computations (e.g. UDFs). Each expression type is capable of determining its output * schema as a function of its children's output schema. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 2902906df2844..b63406b94a4a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -52,7 +52,7 @@ trait PredicateHelper { * * For example consider a join between two relations R(a, b) and S(c, d). * - * `canEvaluate(Equals(a,b), R)` returns `true` where as `canEvaluate(Equals(a,c), R)` returns + * `canEvaluate(EqualTo(a,b), R)` returns `true` where as `canEvaluate(EqualTo(a,c), R)` returns * `false`. */ protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean = @@ -140,7 +140,7 @@ abstract class BinaryComparison extends BinaryPredicate { self: Product => } -case class Equals(left: Expression, right: Expression) extends BinaryComparison { +case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { def symbol = "=" override def eval(input: Row): Any = { val l = left.eval(input) @@ -233,10 +233,12 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression { branches.sliding(2, 2).collect { case Seq(cond, _) => cond }.toSeq @transient private[this] lazy val values = branches.sliding(2, 2).collect { case Seq(_, value) => value }.toSeq + @transient private[this] lazy val elseValue = + if (branches.length % 2 == 0) None else Option(branches.last) override def nullable = { // If no value is nullable and no elseValue is provided, the whole statement defaults to null. 
- values.exists(_.nullable) || (values.length % 2 == 0) + values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) } override lazy val resolved = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 25a347bec0e4c..b20b5de8c46eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -95,13 +95,13 @@ object ColumnPruning extends Rule[LogicalPlan] { Project(substitutedProjection, child) // Eliminate no-op Projects - case Project(projectList, child) if(child.output == projectList) => child + case Project(projectList, child) if child.output == projectList => child } } /** - * Replaces [[catalyst.expressions.Expression Expressions]] that can be statically evaluated with - * equivalent [[catalyst.expressions.Literal Literal]] values. This rule is more specific with + * Replaces [[Expression Expressions]] that can be statically evaluated with + * equivalent [[Literal]] values. This rule is more specific with * Null value propagation from bottom to top of the expression tree. */ object NullPropagation extends Rule[LogicalPlan] { @@ -110,8 +110,8 @@ object NullPropagation extends Rule[LogicalPlan] { case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType) case e @ Sum(Literal(c, _)) if c == 0 => Cast(Literal(0L), e.dataType) case e @ Average(Literal(c, _)) if c == 0 => Literal(0.0, e.dataType) - case e @ IsNull(c) if c.nullable == false => Literal(false, BooleanType) - case e @ IsNotNull(c) if c.nullable == false => Literal(true, BooleanType) + case e @ IsNull(c) if !c.nullable => Literal(false, BooleanType) + case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType) case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType) case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType) case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType) @@ -154,8 +154,8 @@ object NullPropagation extends Rule[LogicalPlan] { } /** - * Replaces [[catalyst.expressions.Expression Expressions]] that can be statically evaluated with - * equivalent [[catalyst.expressions.Literal Literal]] values. + * Replaces [[Expression Expressions]] that can be statically evaluated with + * equivalent [[Literal]] values. */ object ConstantFolding extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -197,7 +197,7 @@ object BooleanSimplification extends Rule[LogicalPlan] { } /** - * Combines two adjacent [[catalyst.plans.logical.Filter Filter]] operators into one, merging the + * Combines two adjacent [[Filter]] operators into one, merging the * conditions into one conjunctive predicate. */ object CombineFilters extends Rule[LogicalPlan] { @@ -223,9 +223,8 @@ object SimplifyFilters extends Rule[LogicalPlan] { } /** - * Pushes [[catalyst.plans.logical.Filter Filter]] operators through - * [[catalyst.plans.logical.Project Project]] operators, in-lining any - * [[catalyst.expressions.Alias Aliases]] that were defined in the projection. + * Pushes [[Filter]] operators through [[Project]] operators, in-lining any [[Alias Aliases]] + * that were defined in the projection. * * This heuristic is valid assuming the expression evaluation cost is minimal. 
*/ @@ -248,10 +247,10 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] { } /** - * Pushes down [[catalyst.plans.logical.Filter Filter]] operators where the `condition` can be + * Pushes down [[Filter]] operators where the `condition` can be * evaluated using only the attributes of the left or right side of a join. Other - * [[catalyst.plans.logical.Filter Filter]] conditions are moved into the `condition` of the - * [[catalyst.plans.logical.Join Join]]. + * [[Filter]] conditions are moved into the `condition` of the [[Join]]. + * * And also Pushes down the join filter, where the `condition` can be evaluated using only the * attributes of the left or right side of sub query when applicable. * @@ -345,8 +344,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { } /** - * Removes [[catalyst.expressions.Cast Casts]] that are unnecessary because the input is already - * the correct type. + * Removes [[Cast Casts]] that are unnecessary because the input is already the correct type. */ object SimplifyCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { @@ -355,7 +353,7 @@ object SimplifyCasts extends Rule[LogicalPlan] { } /** - * Combines two adjacent [[catalyst.plans.logical.Limit Limit]] operators into one, merging the + * Combines two adjacent [[Limit]] operators into one, merging the * expressions into one single expression. */ object CombineLimits extends Rule[LogicalPlan] { @@ -366,7 +364,7 @@ object CombineLimits extends Rule[LogicalPlan] { } /** - * Removes the inner [[catalyst.expressions.CaseConversionExpression]] that are unnecessary because + * Removes the inner [[CaseConversionExpression]] that are unnecessary because * the inner conversion is overwritten by the outer one. */ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 820ecfb78b52e..a43bef389c4bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -136,14 +136,14 @@ object HashFilteredJoin extends Logging with PredicateHelper { val Join(left, right, joinType, _) = join val (joinPredicates, otherPredicates) = allPredicates.flatMap(splitConjunctivePredicates).partition { - case Equals(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) || + case EqualTo(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) || (canEvaluate(l, right) && canEvaluate(r, left)) => true case _ => false } val joinKeys = joinPredicates.map { - case Equals(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => (l, r) - case Equals(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => (r, l) + case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => (l, r) + case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => (r, l) } // Do not consider this strategy if there are no join keys. 
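The hunk above adapts HashFilteredJoin to the Equals -> EqualTo rename: equality conjuncts whose two sides can each be evaluated against exactly one join side become hash join keys, while everything else stays a post-join filter. A minimal standalone sketch of that split, using a toy expression ADT rather than catalyst Expressions (all names below are illustrative, not code from this patch):

// Standalone sketch (not code from this patch): the equi-join key extraction idea,
// with expressions modelled as a toy ADT instead of catalyst Expressions.
object EquiJoinKeySketch {
  sealed trait Expr
  case class Attr(name: String, fromLeft: Boolean) extends Expr
  case class EqualTo(l: Expr, r: Expr) extends Expr
  case class GreaterThan(l: Expr, r: Expr) extends Expr

  private def isLeft(e: Expr): Boolean = e match {
    case Attr(_, fromLeft) => fromLeft
    case _ => false
  }
  private def isRight(e: Expr): Boolean = !isLeft(e) && e.isInstanceOf[Attr]

  /** Returns (join keys as (leftKey, rightKey) pairs, residual predicates). */
  def split(predicates: Seq[Expr]): (Seq[(Expr, Expr)], Seq[Expr]) = {
    val (equiJoinPredicates, otherPredicates) = predicates.partition {
      case EqualTo(l, r) => (isLeft(l) && isRight(r)) || (isLeft(r) && isRight(l))
      case _ => false
    }
    val joinKeys = equiJoinPredicates.collect {
      case EqualTo(l, r) if isLeft(l) => (l, r)
      case EqualTo(l, r) if isLeft(r) => (r, l)
    }
    (joinKeys, otherPredicates)
  }

  def main(args: Array[String]): Unit = {
    val a = Attr("a", fromLeft = true)   // attribute of the left relation R(a, b)
    val c = Attr("c", fromLeft = false)  // attribute of the right relation S(c, d)
    val (keys, rest) = split(Seq(EqualTo(a, c), GreaterThan(a, c)))
    println(keys) // one key pair: (a, c)
    println(rest) // GreaterThan(a, c) remains a post-join filter
  }
}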
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 00e2d3bc24be9..7b82e19b2e714 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.types.{ArrayType, DataType, StructField, StructType} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index d70ef6e826cc2..98018799c72d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -41,19 +41,19 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { /** * Returns true if this expression and all its children have been resolved to a specific schema * and false if it is still contains any unresolved placeholders. Implementations of LogicalPlan - * can override this (e.g. [[catalyst.analysis.UnresolvedRelation UnresolvedRelation]] should - * return `false`). + * can override this (e.g. + * [[org.apache.spark.sql.catalyst.analysis.UnresolvedRelation UnresolvedRelation]] + * should return `false`). */ lazy val resolved: Boolean = !expressions.exists(!_.resolved) && childrenResolved /** * Returns true if all its children of this query plan have been resolved. */ - def childrenResolved = !children.exists(!_.resolved) + def childrenResolved: Boolean = !children.exists(!_.resolved) /** - * Optionally resolves the given string to a - * [[catalyst.expressions.NamedExpression NamedExpression]]. The attribute is expressed as + * Optionally resolves the given string to a [[NamedExpression]]. The attribute is expressed as * as string in the following form: `[scope].AttributeName.[nested].[fields]...`. */ def resolve(name: String): Option[NamedExpression] = { @@ -93,7 +93,7 @@ abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] { self: Product => // Leaf nodes by definition cannot reference any input attributes. - def references = Set.empty + override def references = Set.empty } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index b777cf4249196..3e0639867b278 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -27,7 +27,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend } /** - * Applies a [[catalyst.expressions.Generator Generator]] to a stream of input rows, combining the + * Applies a [[Generator]] to a stream of input rows, combining the * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional * programming with one important additional feature, which allows the input rows to be joined with * their output. 
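The Generate scaladoc above describes a flatMap whose output can optionally be joined back to the input row that produced it. A standalone sketch of that contract on plain Scala collections (illustrative only; the real operator works on catalyst Rows and Generators):

// Standalone sketch (not code from this patch) of the Generate semantics described above:
// with join = false only the generator output is kept, while join = true appends each
// generated row to the input row that produced it (a lateral-view / explode style result).
object GenerateSketch {
  type Row = Seq[Any]

  def generate(input: Seq[Row], generator: Row => Seq[Row], join: Boolean): Seq[Row] =
    input.flatMap { row =>
      val generated = generator(row)
      if (join) generated.map(row ++ _) else generated
    }

  def main(args: Array[String]): Unit = {
    val input: Seq[Row] = Seq(Seq("a", List(1, 2)), Seq("b", List(3)))
    // Explodes the second column into one output row per element.
    val explode: Row => Seq[Row] = row => row(1).asInstanceOf[List[Int]].map(Seq(_))

    println(generate(input, explode, join = false)) // 3 rows, generator output only
    println(generate(input, explode, join = true))  // 3 rows, each prefixed by its input row
  }
}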
@@ -46,32 +46,32 @@ case class Generate( child: LogicalPlan) extends UnaryNode { - protected def generatorOutput = + protected def generatorOutput: Seq[Attribute] = alias .map(a => generator.output.map(_.withQualifiers(a :: Nil))) .getOrElse(generator.output) - def output = + override def output = if (join) child.output ++ generatorOutput else generatorOutput - def references = + override def references = if (join) child.outputSet else generator.references } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { - def output = child.output - def references = condition.references + override def output = child.output + override def references = condition.references } case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { // TODO: These aren't really the same attributes as nullability etc might change. - def output = left.output + override def output = left.output override lazy val resolved = childrenResolved && !left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType } - def references = Set.empty + override def references = Set.empty } case class Join( @@ -80,8 +80,8 @@ case class Join( joinType: JoinType, condition: Option[Expression]) extends BinaryNode { - def references = condition.map(_.references).getOrElse(Set.empty) - def output = joinType match { + override def references = condition.map(_.references).getOrElse(Set.empty) + override def output = joinType match { case LeftSemi => left.output case _ => @@ -96,9 +96,9 @@ case class InsertIntoTable( overwrite: Boolean) extends LogicalPlan { // The table being inserted into is a child for the purposes of transformations. - def children = table :: child :: Nil - def references = Set.empty - def output = child.output + override def children = table :: child :: Nil + override def references = Set.empty + override def output = child.output override lazy val resolved = childrenResolved && child.output.zip(table.output).forall { case (childAttr, tableAttr) => childAttr.dataType == tableAttr.dataType @@ -109,20 +109,20 @@ case class InsertIntoCreatedTable( databaseName: Option[String], tableName: String, child: LogicalPlan) extends UnaryNode { - def references = Set.empty - def output = child.output + override def references = Set.empty + override def output = child.output } case class WriteToFile( path: String, child: LogicalPlan) extends UnaryNode { - def references = Set.empty - def output = child.output + override def references = Set.empty + override def output = child.output } case class Sort(order: Seq[SortOrder], child: LogicalPlan) extends UnaryNode { - def output = child.output - def references = order.flatMap(_.references).toSet + override def output = child.output + override def references = order.flatMap(_.references).toSet } case class Aggregate( @@ -131,18 +131,19 @@ case class Aggregate( child: LogicalPlan) extends UnaryNode { - def output = aggregateExpressions.map(_.toAttribute) - def references = (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet + override def output = aggregateExpressions.map(_.toAttribute) + override def references = + (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet } case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { - def output = child.output - def references = limitExpr.references + override def output = child.output + override def references = limitExpr.references } case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { - def output = 
child.output.map(_.withQualifiers(alias :: Nil)) - def references = Set.empty + override def output = child.output.map(_.withQualifiers(alias :: Nil)) + override def references = Set.empty } /** @@ -159,7 +160,7 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode { case otherType => otherType } - val output = child.output.map { + override val output = child.output.map { case a: AttributeReference => AttributeReference( a.name.toLowerCase, @@ -170,21 +171,21 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode { case other => other } - def references = Set.empty + override def references = Set.empty } case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan) extends UnaryNode { - def output = child.output - def references = Set.empty + override def output = child.output + override def references = Set.empty } case class Distinct(child: LogicalPlan) extends UnaryNode { - def output = child.output - def references = child.outputSet + override def output = child.output + override def references = child.outputSet } case object NoRelation extends LeafNode { - def output = Nil + override def output = Nil } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index 3299e86b85941..1d5f033f0d274 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -60,3 +60,19 @@ case class ExplainCommand(plan: LogicalPlan) extends Command { * Returned for the "CACHE TABLE tableName" and "UNCACHE TABLE tableName" command. */ case class CacheCommand(tableName: String, doCache: Boolean) extends Command + +/** + * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command. + * @param table The table to be described. + * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false. + * It is effective only when the table is a Hive table. + */ +case class DescribeCommand( + table: LogicalPlan, + isExtended: Boolean) extends Command { + override def output = Seq( + // Column names are based on Hive. + BoundReference(0, AttributeReference("col_name", StringType, nullable = false)()), + BoundReference(1, AttributeReference("data_type", StringType, nullable = false)()), + BoundReference(2, AttributeReference("comment", StringType, nullable = false)())) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index ffb3a92f8f340..4bb022cf238af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -46,7 +46,7 @@ case object AllTuples extends Distribution /** * Represents data where tuples that share the same values for the `clustering` - * [[catalyst.expressions.Expression Expressions]] will be co-located. Based on the context, this + * [[Expression Expressions]] will be co-located. Based on the context, this * can mean such tuples are either co-located in the same partition or they will be contiguous * within a single partition. 
*/ @@ -60,7 +60,7 @@ case class ClusteredDistribution(clustering: Seq[Expression]) extends Distributi /** * Represents data where tuples have been ordered according to the `ordering` - * [[catalyst.expressions.Expression Expressions]]. This is a strictly stronger guarantee than + * [[Expression Expressions]]. This is a strictly stronger guarantee than * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the same value for * the ordering expressions are contiguous and will never be split across partitions. */ @@ -79,19 +79,17 @@ sealed trait Partitioning { val numPartitions: Int /** - * Returns true iff the guarantees made by this - * [[catalyst.plans.physical.Partitioning Partitioning]] are sufficient to satisfy - * the partitioning scheme mandated by the `required` - * [[catalyst.plans.physical.Distribution Distribution]], i.e. the current dataset does not - * need to be re-partitioned for the `required` Distribution (it is possible that tuples within - * a partition need to be reorganized). + * Returns true iff the guarantees made by this [[Partitioning]] are sufficient + * to satisfy the partitioning scheme mandated by the `required` [[Distribution]], + * i.e. the current dataset does not need to be re-partitioned for the `required` + * Distribution (it is possible that tuples within a partition need to be reorganized). */ def satisfies(required: Distribution): Boolean /** * Returns true iff all distribution guarantees made by this partitioning can also be made * for the `other` specified partitioning. - * For example, two [[catalyst.plans.physical.HashPartitioning HashPartitioning]]s are + * For example, two [[HashPartitioning HashPartitioning]]s are * only compatible if the `numPartitions` of them is the same. */ def compatibleWith(other: Partitioning): Boolean diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index da34bd3a21503..bb77bccf86176 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -19,9 +19,71 @@ package org.apache.spark.sql.catalyst.types import java.sql.Timestamp -import scala.reflect.runtime.universe.{typeTag, TypeTag} +import scala.util.parsing.combinator.RegexParsers -import org.apache.spark.sql.catalyst.expressions.Expression +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror} + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} +import org.apache.spark.util.Utils + +/** + * + */ +object DataType extends RegexParsers { + protected lazy val primitiveType: Parser[DataType] = + "StringType" ^^^ StringType | + "FloatType" ^^^ FloatType | + "IntegerType" ^^^ IntegerType | + "ByteType" ^^^ ByteType | + "ShortType" ^^^ ShortType | + "DoubleType" ^^^ DoubleType | + "LongType" ^^^ LongType | + "BinaryType" ^^^ BinaryType | + "BooleanType" ^^^ BooleanType | + "DecimalType" ^^^ DecimalType | + "TimestampType" ^^^ TimestampType + + protected lazy val arrayType: Parser[DataType] = + "ArrayType" ~> "(" ~> dataType <~ ")" ^^ ArrayType + + protected lazy val mapType: Parser[DataType] = + "MapType" ~> "(" ~> dataType ~ "," ~ dataType <~ ")" ^^ { + case t1 ~ _ ~ t2 => MapType(t1, t2) + } + + protected lazy val structField: Parser[StructField] = + ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> 
boolVal <~ ")") ^^ { + case name ~ tpe ~ nullable => + StructField(name, tpe, nullable = nullable) + } + + protected lazy val boolVal: Parser[Boolean] = + "true" ^^^ true | + "false" ^^^ false + + + protected lazy val structType: Parser[DataType] = + "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { + case fields => new StructType(fields) + } + + protected lazy val dataType: Parser[DataType] = + arrayType | + mapType | + structType | + primitiveType + + /** + * Parses a string representation of a DataType. + * + * TODO: Generate parser as pickler... + */ + def apply(asString: String): DataType = parseAll(dataType, asString) match { + case Success(result, _) => result + case failure: NoSuccess => sys.error(s"Unsupported dataType: $asString, $failure") + } +} abstract class DataType { /** Matches any expression that evaluates to this DataType */ @@ -29,25 +91,36 @@ abstract class DataType { case e: Expression if e.dataType == this => true case _ => false } + + def isPrimitive: Boolean = false } case object NullType extends DataType +trait PrimitiveType extends DataType { + override def isPrimitive = true +} + abstract class NativeType extends DataType { type JvmType @transient val tag: TypeTag[JvmType] val ordering: Ordering[JvmType] + + @transient val classTag = { + val mirror = runtimeMirror(Utils.getSparkClassLoader) + ClassTag[JvmType](mirror.runtimeClass(tag.tpe)) + } } -case object StringType extends NativeType { +case object StringType extends NativeType with PrimitiveType { type JvmType = String @transient lazy val tag = typeTag[JvmType] val ordering = implicitly[Ordering[JvmType]] } -case object BinaryType extends DataType { +case object BinaryType extends DataType with PrimitiveType { type JvmType = Array[Byte] } -case object BooleanType extends NativeType { +case object BooleanType extends NativeType with PrimitiveType { type JvmType = Boolean @transient lazy val tag = typeTag[JvmType] val ordering = implicitly[Ordering[JvmType]] @@ -63,7 +136,7 @@ case object TimestampType extends NativeType { } } -abstract class NumericType extends NativeType { +abstract class NumericType extends NativeType with PrimitiveType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). 
This gets @@ -154,6 +227,17 @@ case object FloatType extends FractionalType { case class ArrayType(elementType: DataType) extends DataType case class StructField(name: String, dataType: DataType, nullable: Boolean) -case class StructType(fields: Seq[StructField]) extends DataType + +object StructType { + def fromAttributes(attributes: Seq[Attribute]): StructType = { + StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable))) + } + + // def apply(fields: Seq[StructField]) = new StructType(fields.toIndexedSeq) +} + +case class StructType(fields: Seq[StructField]) extends DataType { + def toAttributes = fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) +} case class MapType(keyType: DataType, valueType: DataType) extends DataType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 8c3b062d0f801..84d72814778ba 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -333,6 +333,49 @@ class ExpressionEvaluationSuite extends FunSuite { Literal("^Ba*n", StringType) :: c2 :: Nil), true, row) } + test("case when") { + val row = new GenericRow(Array[Any](null, false, true, "a", "b", "c")) + val c1 = 'a.boolean.at(0) + val c2 = 'a.boolean.at(1) + val c3 = 'a.boolean.at(2) + val c4 = 'a.string.at(3) + val c5 = 'a.string.at(4) + val c6 = 'a.string.at(5) + + checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row) + checkEvaluation(CaseWhen(Seq(Literal(null, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal(false, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal(true, BooleanType), c4, c6)), "a", row) + + checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row) + checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row) + checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row) + + assert(CaseWhen(Seq(c2, c4, c6)).nullable === true) + assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true) + assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true) + + val c4_notNull = 'a.boolean.notNull.at(3) + val c5_notNull = 'a.boolean.notNull.at(4) + val c6_notNull = 'a.boolean.notNull.at(5) + + assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false) + assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true) + + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false) + assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable === true) + + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true) + } + test("complex type") { val row = new GenericRow(Array[Any]( "^Ba*n", // 0 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index cea97c584f7e1..0ff82064012a8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -195,8 +195,8 @@ class ConstantFoldingSuite extends PlanTest { Add(Literal(null, IntegerType), 1) as 'c9, Add(1, Literal(null, IntegerType)) as 'c10, - Equals(Literal(null, IntegerType), 1) as 'c11, - Equals(1, Literal(null, IntegerType)) as 'c12, + EqualTo(Literal(null, IntegerType), 1) as 'c11, + EqualTo(1, Literal(null, IntegerType)) as 'c12, Like(Literal(null, StringType), "abc") as 'c13, Like("abc", Literal(null, StringType)) as 'c14, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 65db4f9290f29..b566413bffcf7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -94,7 +94,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ def parquetFile(path: String): SchemaRDD = - new SchemaRDD(this, parquet.ParquetRelation(path)) + new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration))) /** * Loads a JSON file (one object per line), returning the result as a [[SchemaRDD]]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index ff9842267ffe0..ff6deeda2394d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -99,7 +99,9 @@ class JavaSQLContext(val sqlContext: SQLContext) { * Loads a parquet file, returning the result as a [[JavaSchemaRDD]]. */ def parquetFile(path: String): JavaSchemaRDD = - new JavaSchemaRDD(sqlContext, ParquetRelation(path)) + new JavaSchemaRDD( + sqlContext, + ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration))) /** * Loads a JSON file (one object per line), returning the result as a [[JavaSchemaRDD]]. 
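With the relation now created from Some(sparkContext.hadoopConfiguration) (and likewise through the Java API), Hadoop-level settings made on the driver's configuration are visible when reading Parquet files. A hedged usage sketch; the path and configuration keys below are illustrative only, not taken from this patch:

// Hedged usage sketch: settings placed on the SparkContext's Hadoop configuration
// are now picked up by parquetFile(), since the relation is created with
// Some(sparkContext.hadoopConfiguration). Path and keys below are illustrative only.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object ParquetHadoopConfSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("parquet-conf-sketch"))
    val sqlContext = new SQLContext(sc)

    // Illustrative Hadoop-level settings; any key set here travels with the read.
    sc.hadoopConfiguration.set("fs.s3n.awsAccessKeyId", "<access-key>")
    sc.hadoopConfiguration.set("fs.s3n.awsSecretAccessKey", "<secret-key>")

    val parquetRdd = sqlContext.parquetFile("s3n://some-bucket/some-table") // illustrative path
    println(parquetRdd.count())

    sc.stop()
  }
}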
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index cef294167f146..f46fa0516566f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf} import org.apache.spark.rdd.ShuffledRDD -import org.apache.spark.sql.{SQLConf, SQLContext, Row} +import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.{MutableProjection, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{NoBind, MutableProjection, RowOrdering} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair @@ -31,7 +31,7 @@ import org.apache.spark.util.MutablePair * :: DeveloperApi :: */ @DeveloperApi -case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { +case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode with NoBind { override def outputPartitioning = newPartitioning @@ -42,7 +42,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una case HashPartitioning(expressions, numPartitions) => // TODO: Eliminate redundant expressions in grouping key and value. val rdd = child.execute().mapPartitions { iter => - val hashExpressions = new MutableProjection(expressions) + val hashExpressions = new MutableProjection(expressions, child.output) val mutablePair = new MutablePair[Row, Row]() iter.map(r => mutablePair.update(hashExpressions(r), r)) } @@ -53,7 +53,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una case RangePartitioning(sortingExpressions, numPartitions) => // TODO: RangePartitioner should take an Ordering. - implicit val ordering = new RowOrdering(sortingExpressions) + implicit val ordering = new RowOrdering(sortingExpressions, child.output) val rdd = child.execute().mapPartitions { iter => val mutablePair = new MutablePair[Row, Null](null, null) @@ -82,9 +82,10 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } /** - * Ensures that the [[catalyst.plans.physical.Partitioning Partitioning]] of input data meets the - * [[catalyst.plans.physical.Distribution Distribution]] requirements for each operator by inserting - * [[Exchange]] Operators where required. + * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]] + * of input data meets the + * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for + * each operator by inserting [[Exchange]] Operators where required. */ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] { // TODO: Determine the number of partitions. 
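Exchange above now builds MutableProjection and RowOrdering against child.output, relying on the RowOrdering auxiliary constructor added in Row.scala earlier in this patch, which binds each SortOrder to an input schema. A rough illustration of that constructor in isolation; the attributes and row values are made up for the example, and it assumes the catalyst constructors behave as shown in this patch:

// Rough illustration (assumes catalyst internals as shown in this patch): the new
// RowOrdering(ordering, inputSchema) constructor binds each SortOrder to the given
// schema, so an ordering can be built straight from a child operator's output.
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.{IntegerType, StringType}

object RowOrderingSketch {
  def main(args: Array[String]): Unit = {
    val key   = AttributeReference("key", IntegerType)()
    val value = AttributeReference("value", StringType)()
    val inputSchema = Seq(key, value)

    // Comparable to what Exchange now does with `child.output` for range partitioning.
    val ordering = new RowOrdering(Seq(SortOrder(key, Ascending)), inputSchema)

    val rowA = new GenericRow(Array[Any](1, "a"))
    val rowB = new GenericRow(Array[Any](2, "b"))
    println(ordering.compare(rowA, rowB)) // negative: rowA sorts before rowB on `key`
  }
}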
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 97cf0f043d9f9..461ca0cd6e6f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -217,7 +217,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.WriteToFile(path, child) => val relation = ParquetRelation.create(path, child, sparkContext.hadoopConfiguration) - InsertIntoParquetTable(relation, planLater(child), overwrite=true)(sparkContext) :: Nil + // Note: overwrite=false because otherwise the metadata we just created will be deleted + InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sparkContext) :: Nil case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) => InsertIntoParquetTable(table, planLater(child), overwrite)(sparkContext) :: Nil case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => @@ -313,9 +314,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.SetCommand(key, value) => Seq(execution.SetCommand(key, value, plan.output)(context)) - case logical.ExplainCommand(child) => - val sparkPlan = context.executePlan(child).sparkPlan - Seq(execution.ExplainCommand(sparkPlan, plan.output)(context)) + case logical.ExplainCommand(logicalPlan) => + Seq(execution.ExplainCommand(logicalPlan, plan.output)(context)) case logical.CacheCommand(tableName, cache) => Seq(execution.CacheCommand(tableName, cache)(context)) case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 39b3246c875df..acb1b0f4dc229 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -21,6 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan trait Command { /** @@ -71,16 +72,23 @@ case class SetCommand( } /** + * An explain command for users to see how a command will be executed. + * + * Note that this command takes in a logical plan, runs the optimizer on the logical plan + * (but do NOT actually execute it). + * * :: DeveloperApi :: */ @DeveloperApi case class ExplainCommand( - child: SparkPlan, output: Seq[Attribute])( + logicalPlan: LogicalPlan, output: Seq[Attribute])( @transient context: SQLContext) - extends UnaryNode with Command { + extends LeafNode with Command { - // Actually "EXPLAIN" command doesn't cause any side effect. - override protected[sql] lazy val sideEffectResult: Seq[String] = this.toString.split("\n") + // Run through the optimizer to generate the physical plan. 
+ override protected[sql] lazy val sideEffectResult: Seq[String] = { + "Physical execution plan:" +: context.executePlan(logicalPlan).executedPlan.toString.split("\n") + } def execute(): RDD[Row] = { val explanation = sideEffectResult.map(row => new GenericRow(Array[Any](row))) @@ -113,3 +121,24 @@ case class CacheCommand(tableName: String, doCache: Boolean)(@transient context: override def output: Seq[Attribute] = Seq.empty } + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class DescribeCommand(child: SparkPlan, output: Seq[Attribute])( + @transient context: SQLContext) + extends LeafNode with Command { + + override protected[sql] lazy val sideEffectResult: Seq[(String, String, String)] = { + Seq(("# Registered as a temporary table", null, null)) ++ + child.output.map(field => (field.name, field.dataType.toString, null)) + } + + override def execute(): RDD[Row] = { + val rows = sideEffectResult.map { + case (name, dataType, comment) => new GenericRow(Array[Any](name, dataType, comment)) + } + context.sparkContext.parallelize(rows, 1) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala new file mode 100644 index 0000000000000..889a408e3c393 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -0,0 +1,667 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} + +import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} +import parquet.schema.MessageType + +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.expressions.{GenericRow, Row, Attribute} +import org.apache.spark.sql.parquet.CatalystConverter.FieldType + +/** + * Collection of converters of Parquet types (group and primitive types) that + * model arrays and maps. The conversions are partly based on the AvroParquet + * converters that are part of Parquet in order to be able to process these + * types. + * + * There are several types of converters: + *
+ *   - [[org.apache.spark.sql.parquet.CatalystPrimitiveConverter]] for primitive
+ *     (numeric, boolean and String) types
+ *   - [[org.apache.spark.sql.parquet.CatalystNativeArrayConverter]] for arrays
+ *     of native JVM element types; note: currently null values are not supported!
+ *   - [[org.apache.spark.sql.parquet.CatalystArrayConverter]] for arrays of
+ *     arbitrary element types (including nested element types); note: currently
+ *     null values are not supported!
+ *   - [[org.apache.spark.sql.parquet.CatalystStructConverter]] for structs
+ *   - [[org.apache.spark.sql.parquet.CatalystMapConverter]] for maps; note:
+ *     currently null values are not supported!
+ *   - [[org.apache.spark.sql.parquet.CatalystPrimitiveRowConverter]] for rows
+ *     of only primitive element types
+ *   - [[org.apache.spark.sql.parquet.CatalystGroupConverter]] for other nested
+ *     records, including the top-level row record
+ */ + +private[sql] object CatalystConverter { + // The type internally used for fields + type FieldType = StructField + + // This is mostly Parquet convention (see, e.g., `ConversionPatterns`). + // Note that "array" for the array elements is chosen by ParquetAvro. + // Using a different value will result in Parquet silently dropping columns. + val ARRAY_ELEMENTS_SCHEMA_NAME = "array" + val MAP_KEY_SCHEMA_NAME = "key" + val MAP_VALUE_SCHEMA_NAME = "value" + val MAP_SCHEMA_NAME = "map" + + // TODO: consider using Array[T] for arrays to avoid boxing of primitive types + type ArrayScalaType[T] = Seq[T] + type StructScalaType[T] = Seq[T] + type MapScalaType[K, V] = Map[K, V] + + protected[parquet] def createConverter( + field: FieldType, + fieldIndex: Int, + parent: CatalystConverter): Converter = { + val fieldType: DataType = field.dataType + fieldType match { + // For native JVM types we use a converter with native arrays + case ArrayType(elementType: NativeType) => { + new CatalystNativeArrayConverter(elementType, fieldIndex, parent) + } + // This is for other types of arrays, including those with nested fields + case ArrayType(elementType: DataType) => { + new CatalystArrayConverter(elementType, fieldIndex, parent) + } + case StructType(fields: Seq[StructField]) => { + new CatalystStructConverter(fields.toArray, fieldIndex, parent) + } + case MapType(keyType: DataType, valueType: DataType) => { + new CatalystMapConverter( + Array( + new FieldType(MAP_KEY_SCHEMA_NAME, keyType, false), + new FieldType(MAP_VALUE_SCHEMA_NAME, valueType, true)), + fieldIndex, + parent) + } + // Strings, Shorts and Bytes do not have a corresponding type in Parquet + // so we need to treat them separately + case StringType => { + new CatalystPrimitiveConverter(parent, fieldIndex) { + override def addBinary(value: Binary): Unit = + parent.updateString(fieldIndex, value) + } + } + case ShortType => { + new CatalystPrimitiveConverter(parent, fieldIndex) { + override def addInt(value: Int): Unit = + parent.updateShort(fieldIndex, value.asInstanceOf[ShortType.JvmType]) + } + } + case ByteType => { + new CatalystPrimitiveConverter(parent, fieldIndex) { + override def addInt(value: Int): Unit = + parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.JvmType]) + } + } + // All other primitive types use the default converter + case ctype: NativeType => { // note: need the type tag here! + new CatalystPrimitiveConverter(parent, fieldIndex) + } + case _ => throw new RuntimeException( + s"unable to convert datatype ${field.dataType.toString} in CatalystConverter") + } + } + + protected[parquet] def createRootConverter( + parquetSchema: MessageType, + attributes: Seq[Attribute]): CatalystConverter = { + // For non-nested types we use the optimized Row converter + if (attributes.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType))) { + new CatalystPrimitiveRowConverter(attributes.toArray) + } else { + new CatalystGroupConverter(attributes.toArray) + } + } +} + +private[parquet] abstract class CatalystConverter extends GroupConverter { + /** + * The number of fields this group has + */ + protected[parquet] val size: Int + + /** + * The index of this converter in the parent + */ + protected[parquet] val index: Int + + /** + * The parent converter + */ + protected[parquet] val parent: CatalystConverter + + /** + * Called by child converters to update their value in its parent (this). + * Note that if possible the more specific update methods below should be used + * to avoid auto-boxing of native JVM types. 
+ * + * @param fieldIndex + * @param value + */ + protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit + + protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = + updateField(fieldIndex, value) + + protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = + updateField(fieldIndex, value) + + protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = + updateField(fieldIndex, value) + + protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit = + updateField(fieldIndex, value) + + protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit = + updateField(fieldIndex, value) + + protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit = + updateField(fieldIndex, value) + + protected[parquet] def updateFloat(fieldIndex: Int, value: Float): Unit = + updateField(fieldIndex, value) + + protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = + updateField(fieldIndex, value.getBytes) + + protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = + updateField(fieldIndex, value.toStringUsingUTF8) + + protected[parquet] def isRootConverter: Boolean = parent == null + + protected[parquet] def clearBuffer(): Unit + + /** + * Should only be called in the root (group) converter! + * + * @return + */ + def getCurrentRecord: Row = throw new UnsupportedOperationException +} + +/** + * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record + * to a [[org.apache.spark.sql.catalyst.expressions.Row]] object. + * + * @param schema The corresponding Catalyst schema in the form of a list of attributes. + */ +private[parquet] class CatalystGroupConverter( + protected[parquet] val schema: Array[FieldType], + protected[parquet] val index: Int, + protected[parquet] val parent: CatalystConverter, + protected[parquet] var current: ArrayBuffer[Any], + protected[parquet] var buffer: ArrayBuffer[Row]) + extends CatalystConverter { + + def this(schema: Array[FieldType], index: Int, parent: CatalystConverter) = + this( + schema, + index, + parent, + current=null, + buffer=new ArrayBuffer[Row]( + CatalystArrayConverter.INITIAL_ARRAY_SIZE)) + + /** + * This constructor is used for the root converter only! + */ + def this(attributes: Array[Attribute]) = + this(attributes.map(a => new FieldType(a.name, a.dataType, a.nullable)), 0, null) + + protected [parquet] val converters: Array[Converter] = + schema.map(field => + CatalystConverter.createConverter(field, schema.indexOf(field), this)) + .toArray + + override val size = schema.size + + override def getCurrentRecord: Row = { + assert(isRootConverter, "getCurrentRecord should only be called in root group converter!") + // TODO: use iterators if possible + // Note: this will ever only be called in the root converter when the record has been + // fully processed. Therefore it will be difficult to use mutable rows instead, since + // any non-root converter never would be sure when it would be safe to re-use the buffer. 
+ new GenericRow(current.toArray) + } + + override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) + + // for child converters to update upstream values + override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { + current.update(fieldIndex, value) + } + + override protected[parquet] def clearBuffer(): Unit = buffer.clear() + + override def start(): Unit = { + current = ArrayBuffer.fill(size)(null) + converters.foreach { + converter => if (!converter.isPrimitive) { + converter.asInstanceOf[CatalystConverter].clearBuffer + } + } + } + + override def end(): Unit = { + if (!isRootConverter) { + assert(current!=null) // there should be no empty groups + buffer.append(new GenericRow(current.toArray)) + parent.updateField(index, new GenericRow(buffer.toArray.asInstanceOf[Array[Any]])) + } + } +} + +/** + * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record + * to a [[org.apache.spark.sql.catalyst.expressions.Row]] object. Note that his + * converter is optimized for rows of primitive types (non-nested records). + */ +private[parquet] class CatalystPrimitiveRowConverter( + protected[parquet] val schema: Array[FieldType], + protected[parquet] var current: ParquetRelation.RowType) + extends CatalystConverter { + + // This constructor is used for the root converter only + def this(attributes: Array[Attribute]) = + this( + attributes.map(a => new FieldType(a.name, a.dataType, a.nullable)), + new ParquetRelation.RowType(attributes.length)) + + protected [parquet] val converters: Array[Converter] = + schema.map(field => + CatalystConverter.createConverter(field, schema.indexOf(field), this)) + .toArray + + override val size = schema.size + + override val index = 0 + + override val parent = null + + // Should be only called in root group converter! 
+ override def getCurrentRecord: ParquetRelation.RowType = current + + override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) + + // for child converters to update upstream values + override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { + throw new UnsupportedOperationException // child converters should use the + // specific update methods below + } + + override protected[parquet] def clearBuffer(): Unit = {} + + override def start(): Unit = { + var i = 0 + while (i < size) { + current.setNullAt(i) + i = i + 1 + } + } + + override def end(): Unit = {} + + // Overriden here to avoid auto-boxing for primitive types + override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = + current.setBoolean(fieldIndex, value) + + override protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = + current.setInt(fieldIndex, value) + + override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = + current.setLong(fieldIndex, value) + + override protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit = + current.setShort(fieldIndex, value) + + override protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit = + current.setByte(fieldIndex, value) + + override protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit = + current.setDouble(fieldIndex, value) + + override protected[parquet] def updateFloat(fieldIndex: Int, value: Float): Unit = + current.setFloat(fieldIndex, value) + + override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = + current.update(fieldIndex, value.getBytes) + + override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = + current.setString(fieldIndex, value.toStringUsingUTF8) +} + +/** + * A `parquet.io.api.PrimitiveConverter` that converts Parquet types to Catalyst types. + * + * @param parent The parent group converter. + * @param fieldIndex The index inside the record. + */ +private[parquet] class CatalystPrimitiveConverter( + parent: CatalystConverter, + fieldIndex: Int) extends PrimitiveConverter { + override def addBinary(value: Binary): Unit = + parent.updateBinary(fieldIndex, value) + + override def addBoolean(value: Boolean): Unit = + parent.updateBoolean(fieldIndex, value) + + override def addDouble(value: Double): Unit = + parent.updateDouble(fieldIndex, value) + + override def addFloat(value: Float): Unit = + parent.updateFloat(fieldIndex, value) + + override def addInt(value: Int): Unit = + parent.updateInt(fieldIndex, value) + + override def addLong(value: Long): Unit = + parent.updateLong(fieldIndex, value) +} + +object CatalystArrayConverter { + val INITIAL_ARRAY_SIZE = 20 +} + +/** + * A `parquet.io.api.GroupConverter` that converts a single-element groups that + * match the characteristics of an array (see + * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an + * [[org.apache.spark.sql.catalyst.types.ArrayType]]. 
+ * + * @param elementType The type of the array elements (complex or primitive) + * @param index The position of this (array) field inside its parent converter + * @param parent The parent converter + * @param buffer A data buffer + */ +private[parquet] class CatalystArrayConverter( + val elementType: DataType, + val index: Int, + protected[parquet] val parent: CatalystConverter, + protected[parquet] var buffer: Buffer[Any]) + extends CatalystConverter { + + def this(elementType: DataType, index: Int, parent: CatalystConverter) = + this( + elementType, + index, + parent, + new ArrayBuffer[Any](CatalystArrayConverter.INITIAL_ARRAY_SIZE)) + + protected[parquet] val converter: Converter = CatalystConverter.createConverter( + new CatalystConverter.FieldType( + CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + elementType, + false), + fieldIndex=0, + parent=this) + + override def getConverter(fieldIndex: Int): Converter = converter + + // arrays have only one (repeated) field, which is its elements + override val size = 1 + + override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { + // fieldIndex is ignored (assumed to be zero but not checked) + if(value == null) { + throw new IllegalArgumentException("Null values inside Parquet arrays are not supported!") + } + buffer += value + } + + override protected[parquet] def clearBuffer(): Unit = { + buffer.clear() + } + + override def start(): Unit = { + if (!converter.isPrimitive) { + converter.asInstanceOf[CatalystConverter].clearBuffer + } + } + + override def end(): Unit = { + assert(parent != null) + // here we need to make sure to use ArrayScalaType + parent.updateField(index, buffer.toArray.toSeq) + clearBuffer() + } +} + +/** + * A `parquet.io.api.GroupConverter` that converts a single-element groups that + * match the characteristics of an array (see + * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an + * [[org.apache.spark.sql.catalyst.types.ArrayType]]. 
+ * + * @param elementType The type of the array elements (native) + * @param index The position of this (array) field inside its parent converter + * @param parent The parent converter + * @param capacity The (initial) capacity of the buffer + */ +private[parquet] class CatalystNativeArrayConverter( + val elementType: NativeType, + val index: Int, + protected[parquet] val parent: CatalystConverter, + protected[parquet] var capacity: Int = CatalystArrayConverter.INITIAL_ARRAY_SIZE) + extends CatalystConverter { + + type NativeType = elementType.JvmType + + private var buffer: Array[NativeType] = elementType.classTag.newArray(capacity) + + private var elements: Int = 0 + + protected[parquet] val converter: Converter = CatalystConverter.createConverter( + new CatalystConverter.FieldType( + CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + elementType, + false), + fieldIndex=0, + parent=this) + + override def getConverter(fieldIndex: Int): Converter = converter + + // arrays have only one (repeated) field, which is its elements + override val size = 1 + + override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = + throw new UnsupportedOperationException + + // Overriden here to avoid auto-boxing for primitive types + override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = { + checkGrowBuffer() + buffer(elements) = value.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = { + checkGrowBuffer() + buffer(elements) = value.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit = { + checkGrowBuffer() + buffer(elements) = value.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit = { + checkGrowBuffer() + buffer(elements) = value.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = { + checkGrowBuffer() + buffer(elements) = value.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit = { + checkGrowBuffer() + buffer(elements) = value.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateFloat(fieldIndex: Int, value: Float): Unit = { + checkGrowBuffer() + buffer(elements) = value.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = { + checkGrowBuffer() + buffer(elements) = value.getBytes.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = { + checkGrowBuffer() + buffer(elements) = value.toStringUsingUTF8.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def clearBuffer(): Unit = { + elements = 0 + } + + override def start(): Unit = {} + + override def end(): Unit = { + assert(parent != null) + // here we need to make sure to use ArrayScalaType + parent.updateField( + index, + buffer.slice(0, elements).toSeq) + clearBuffer() + } + + private def checkGrowBuffer(): Unit = { + if (elements >= capacity) { + val newCapacity = 2 * capacity + val tmp: Array[NativeType] = elementType.classTag.newArray(newCapacity) + Array.copy(buffer, 0, tmp, 0, capacity) + buffer = tmp + capacity = newCapacity + } + } +} + +/** + * This converter is for multi-element groups of 
primitive or complex types + * that have repetition level optional or required (so struct fields). + * + * @param schema The corresponding Catalyst schema in the form of a list of + * attributes. + * @param index + * @param parent + */ +private[parquet] class CatalystStructConverter( + override protected[parquet] val schema: Array[FieldType], + override protected[parquet] val index: Int, + override protected[parquet] val parent: CatalystConverter) + extends CatalystGroupConverter(schema, index, parent) { + + override protected[parquet] def clearBuffer(): Unit = {} + + // TODO: think about reusing the buffer + override def end(): Unit = { + assert(!isRootConverter) + // here we need to make sure to use StructScalaType + // Note: we need to actually make a copy of the array since we + // may be in a nested field + parent.updateField(index, new GenericRow(current.toArray)) + } +} + +/** + * A `parquet.io.api.GroupConverter` that converts two-element groups that + * match the characteristics of a map (see + * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an + * [[org.apache.spark.sql.catalyst.types.MapType]]. + * + * @param schema + * @param index + * @param parent + */ +private[parquet] class CatalystMapConverter( + protected[parquet] val schema: Array[FieldType], + override protected[parquet] val index: Int, + override protected[parquet] val parent: CatalystConverter) + extends CatalystConverter { + + private val map = new HashMap[Any, Any]() + + private val keyValueConverter = new CatalystConverter { + private var currentKey: Any = null + private var currentValue: Any = null + val keyConverter = CatalystConverter.createConverter(schema(0), 0, this) + val valueConverter = CatalystConverter.createConverter(schema(1), 1, this) + + override def getConverter(fieldIndex: Int): Converter = { + if (fieldIndex == 0) keyConverter else valueConverter + } + + override def end(): Unit = CatalystMapConverter.this.map += currentKey -> currentValue + + override def start(): Unit = { + currentKey = null + currentValue = null + } + + override protected[parquet] val size: Int = 2 + override protected[parquet] val index: Int = 0 + override protected[parquet] val parent: CatalystConverter = CatalystMapConverter.this + + override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { + fieldIndex match { + case 0 => + currentKey = value + case 1 => + currentValue = value + case _ => + new RuntimePermission(s"trying to update Map with fieldIndex $fieldIndex") + } + } + + override protected[parquet] def clearBuffer(): Unit = {} + } + + override protected[parquet] val size: Int = 1 + + override protected[parquet] def clearBuffer(): Unit = {} + + override def start(): Unit = { + map.clear() + } + + override def end(): Unit = { + // here we need to make sure to use MapScalaType + parent.updateField(index, map.toMap) + } + + override def getConverter(fieldIndex: Int): Converter = keyValueConverter + + override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = + throw new UnsupportedOperationException +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index 052b0a9196717..cc575bedd8fcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -205,9 +205,9 @@ object ParquetFilters { Some(new AndFilter(leftFilter.get, rightFilter.get)) } } - 
case p @ Equals(left: Literal, right: NamedExpression) if !right.nullable => + case p @ EqualTo(left: Literal, right: NamedExpression) if !right.nullable => Some(createEqualityFilter(right.name, left, p)) - case p @ Equals(left: NamedExpression, right: Literal) if !left.nullable => + case p @ EqualTo(left: NamedExpression, right: Literal) if !left.nullable => Some(createEqualityFilter(left.name, right, p)) case p @ LessThan(left: Literal, right: NamedExpression) if !right.nullable => Some(createLessThanFilter(right.name, left, p)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index f6ae1ddd1e647..cc1371ec5d60c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -20,17 +20,12 @@ package org.apache.spark.sql.parquet import java.io.IOException import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.permission.FsAction -import org.apache.hadoop.mapreduce.Job -import parquet.hadoop.util.ContextUtil -import parquet.hadoop.{ParquetOutputFormat, Footer, ParquetFileWriter, ParquetFileReader} -import parquet.hadoop.metadata.{CompressionCodecName, FileMetaData, ParquetMetadata} -import parquet.io.api.{Binary, RecordConsumer} -import parquet.schema.{Type => ParquetType, PrimitiveType => ParquetPrimitiveType, MessageType, MessageTypeParser} -import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} -import parquet.schema.Type.Repetition +import parquet.hadoop.ParquetOutputFormat +import parquet.hadoop.metadata.CompressionCodecName +import parquet.schema.MessageType import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} @@ -53,10 +48,13 @@ import scala.collection.JavaConversions._ * * @param path The path to the Parquet file. 
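+ * @param conf An optional, transient Hadoop Configuration used when the relation reads
+ *             the file's Parquet metadata; if it is `None`, a default configuration is
+ *             created instead.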
*/ -private[sql] case class ParquetRelation(path: String) - extends LeafNode - with MultiInstanceRelation - with SizeEstimatableRelation[SQLContext] { +private[sql] case class ParquetRelation( + path: String, + @transient conf: Option[Configuration] = None) + extends LeafNode + with MultiInstanceRelation + with SizeEstimatableRelation[SQLContext] { + self: Product => def estimatedSize(context: SQLContext): Long = { @@ -69,14 +67,12 @@ private[sql] case class ParquetRelation(path: String) /** Schema derived from ParquetFile */ def parquetSchema: MessageType = ParquetTypesConverter - .readMetaData(new Path(path)) + .readMetaData(new Path(path), conf) .getFileMetaData .getSchema /** Attributes */ - override val output = - ParquetTypesConverter - .convertToAttributes(parquetSchema) + override val output = ParquetTypesConverter.readSchemaFromFile(new Path(path), conf) override def newInstance = ParquetRelation(path).asInstanceOf[this.type] @@ -151,7 +147,9 @@ private[sql] object ParquetRelation { } ParquetRelation.enableLogForwarding() ParquetTypesConverter.writeMetaData(attributes, path, conf) - new ParquetRelation(path.toString) + new ParquetRelation(path.toString, Some(conf)) { + override val output = attributes + } } private def checkPath(pathStr: String, allowExisting: Boolean, conf: Configuration): Path = { @@ -180,151 +178,3 @@ private[sql] object ParquetRelation { path } } - -private[parquet] object ParquetTypesConverter { - def toDataType(parquetType : ParquetPrimitiveTypeName): DataType = parquetType match { - // for now map binary to string type - // TODO: figure out how Parquet uses strings or why we can't use them in a MessageType schema - case ParquetPrimitiveTypeName.BINARY => StringType - case ParquetPrimitiveTypeName.BOOLEAN => BooleanType - case ParquetPrimitiveTypeName.DOUBLE => DoubleType - case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY => ArrayType(ByteType) - case ParquetPrimitiveTypeName.FLOAT => FloatType - case ParquetPrimitiveTypeName.INT32 => IntegerType - case ParquetPrimitiveTypeName.INT64 => LongType - case ParquetPrimitiveTypeName.INT96 => - // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? 
- sys.error("Warning: potential loss of precision: converting INT96 to long") - LongType - case _ => sys.error( - s"Unsupported parquet datatype $parquetType") - } - - def fromDataType(ctype: DataType): ParquetPrimitiveTypeName = ctype match { - case StringType => ParquetPrimitiveTypeName.BINARY - case BooleanType => ParquetPrimitiveTypeName.BOOLEAN - case DoubleType => ParquetPrimitiveTypeName.DOUBLE - case ArrayType(ByteType) => ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY - case FloatType => ParquetPrimitiveTypeName.FLOAT - case IntegerType => ParquetPrimitiveTypeName.INT32 - case LongType => ParquetPrimitiveTypeName.INT64 - case _ => sys.error(s"Unsupported datatype $ctype") - } - - def consumeType(consumer: RecordConsumer, ctype: DataType, record: Row, index: Int): Unit = { - ctype match { - case StringType => consumer.addBinary( - Binary.fromByteArray( - record(index).asInstanceOf[String].getBytes("utf-8") - ) - ) - case IntegerType => consumer.addInteger(record.getInt(index)) - case LongType => consumer.addLong(record.getLong(index)) - case DoubleType => consumer.addDouble(record.getDouble(index)) - case FloatType => consumer.addFloat(record.getFloat(index)) - case BooleanType => consumer.addBoolean(record.getBoolean(index)) - case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") - } - } - - def getSchema(schemaString : String) : MessageType = - MessageTypeParser.parseMessageType(schemaString) - - def convertToAttributes(parquetSchema: MessageType) : Seq[Attribute] = { - parquetSchema.getColumns.map { - case (desc) => - val ctype = toDataType(desc.getType) - val name: String = desc.getPath.mkString(".") - new AttributeReference(name, ctype, false)() - } - } - - // TODO: allow nesting? - def convertFromAttributes(attributes: Seq[Attribute]): MessageType = { - val fields: Seq[ParquetType] = attributes.map { - a => new ParquetPrimitiveType(Repetition.OPTIONAL, fromDataType(a.dataType), a.name) - } - new MessageType("root", fields) - } - - def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration) { - if (origPath == null) { - throw new IllegalArgumentException("Unable to write Parquet metadata: path is null") - } - val fs = origPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException( - s"Unable to write Parquet metadata: path $origPath is incorrectly formatted") - } - val path = origPath.makeQualified(fs) - if (fs.exists(path) && !fs.getFileStatus(path).isDir) { - throw new IllegalArgumentException(s"Expected to write to directory $path but found file") - } - val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) - if (fs.exists(metadataPath)) { - try { - fs.delete(metadataPath, true) - } catch { - case e: IOException => - throw new IOException(s"Unable to delete previous PARQUET_METADATA_FILE at $metadataPath") - } - } - val extraMetadata = new java.util.HashMap[String, String]() - extraMetadata.put("path", path.toString) - // TODO: add extra data, e.g., table name, date, etc.? - - val parquetSchema: MessageType = - ParquetTypesConverter.convertFromAttributes(attributes) - val metaData: FileMetaData = new FileMetaData( - parquetSchema, - extraMetadata, - "Spark") - - ParquetRelation.enableLogForwarding() - ParquetFileWriter.writeMetadataFile( - conf, - path, - new Footer(path, new ParquetMetadata(metaData, Nil)) :: Nil) - } - - /** - * Try to read Parquet metadata at the given Path. We first see if there is a summary file - * in the parent directory. If so, this is used. 
Else we read the actual footer at the given - * location. - * @param origPath The path at which we expect one (or more) Parquet files. - * @return The `ParquetMetadata` containing among other things the schema. - */ - def readMetaData(origPath: Path): ParquetMetadata = { - if (origPath == null) { - throw new IllegalArgumentException("Unable to read Parquet metadata: path is null") - } - val job = new Job() - // TODO: since this is called from ParquetRelation (LogicalPlan) we don't have access - // to SparkContext's hadoopConfig; in principle the default FileSystem may be different(?!) - val conf = ContextUtil.getConfiguration(job) - val fs: FileSystem = origPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException(s"Incorrectly formatted Parquet metadata path $origPath") - } - val path = origPath.makeQualified(fs) - if (!fs.getFileStatus(path).isDir) { - throw new IllegalArgumentException( - s"Expected $path for be a directory with Parquet files/metadata") - } - ParquetRelation.enableLogForwarding() - val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) - // if this is a new table that was just created we will find only the metadata file - if (fs.exists(metadataPath) && fs.isFile(metadataPath)) { - ParquetFileReader.readFooter(conf, metadataPath) - } else { - // there may be one or more Parquet files in the given directory - val footers = ParquetFileReader.readFooters(conf, fs.getFileStatus(path)) - // TODO: for now we assume that all footers (if there is more than one) have identical - // metadata; we may want to add a check here at some point - if (footers.size() == 0) { - throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path") - } - footers(0).getParquetMetadata - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 65ba1246fbf9a..624f2e2fa13f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -36,6 +36,7 @@ import parquet.schema.MessageType import org.apache.spark.{Logging, SerializableWritable, SparkContext, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} +import org.apache.spark.sql.catalyst.types.StructType import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} /** @@ -64,10 +65,13 @@ case class ParquetTableScan( NewFileInputFormat.addInputPath(job, path) } - // Store Parquet schema in `Configuration` + // Store both requested and original schema in `Configuration` conf.set( - RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA, - ParquetTypesConverter.convertFromAttributes(output).toString) + RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, + ParquetTypesConverter.convertToString(output)) + conf.set( + RowWriteSupport.SPARK_ROW_SCHEMA, + ParquetTypesConverter.convertToString(relation.output)) // Store record filtering predicate in `Configuration` // Note 1: the input format ignores all predicates that cannot be expressed @@ -166,13 +170,18 @@ case class InsertIntoParquetTable( val job = new Job(sc.hadoopConfiguration) - ParquetOutputFormat.setWriteSupportClass( - job, - classOf[org.apache.spark.sql.parquet.RowWriteSupport]) + val writeSupport = + if (child.output.map(_.dataType).forall(_.isPrimitive)) { + logger.debug("Initializing MutableRowWriteSupport") + 
classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport] + } else { + classOf[org.apache.spark.sql.parquet.RowWriteSupport] + } + + ParquetOutputFormat.setWriteSupportClass(job, writeSupport) - // TODO: move that to function in object val conf = ContextUtil.getConfiguration(job) - conf.set(RowWriteSupport.PARQUET_ROW_SCHEMA, relation.parquetSchema.toString) + RowWriteSupport.setSchema(relation.output, conf) val fspath = new Path(relation.path) val fs = fspath.getFileSystem(conf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 71ba0fecce47a..bfcbdeb34a92f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -29,21 +29,23 @@ import parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.execution.SparkSqlSerializer +import com.google.common.io.BaseEncoding /** * A `parquet.io.api.RecordMaterializer` for Rows. * *@param root The root group converter for the record. */ -private[parquet] class RowRecordMaterializer(root: CatalystGroupConverter) +private[parquet] class RowRecordMaterializer(root: CatalystConverter) extends RecordMaterializer[Row] { - def this(parquetSchema: MessageType) = - this(new CatalystGroupConverter(ParquetTypesConverter.convertToAttributes(parquetSchema))) + def this(parquetSchema: MessageType, attributes: Seq[Attribute]) = + this(CatalystConverter.createRootConverter(parquetSchema, attributes)) override def getCurrentRecord: Row = root.getCurrentRecord - override def getRootConverter: GroupConverter = root + override def getRootConverter: GroupConverter = root.asInstanceOf[GroupConverter] } /** @@ -56,68 +58,94 @@ private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { stringMap: java.util.Map[String, String], fileSchema: MessageType, readContext: ReadContext): RecordMaterializer[Row] = { - log.debug(s"preparing for read with file schema $fileSchema") - new RowRecordMaterializer(readContext.getRequestedSchema) + log.debug(s"preparing for read with Parquet file schema $fileSchema") + // Note: this very much imitates AvroParquet + val parquetSchema = readContext.getRequestedSchema + var schema: Seq[Attribute] = null + + if (readContext.getReadSupportMetadata != null) { + // first try to find the read schema inside the metadata (can result from projections) + if ( + readContext + .getReadSupportMetadata + .get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA) != null) { + schema = ParquetTypesConverter.convertFromString( + readContext.getReadSupportMetadata.get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA)) + } else { + // if unavailable, try the schema that was read originally from the file or provided + // during the creation of the Parquet relation + if (readContext.getReadSupportMetadata.get(RowReadSupport.SPARK_METADATA_KEY) != null) { + schema = ParquetTypesConverter.convertFromString( + readContext.getReadSupportMetadata.get(RowReadSupport.SPARK_METADATA_KEY)) + } + } + } + // if both unavailable, fall back to deducing the schema from the given Parquet schema + if (schema == null) { + log.debug("falling back to Parquet read schema") + schema = ParquetTypesConverter.convertToAttributes(parquetSchema) + } + log.debug(s"list of attributes 
that will be read: $schema") + new RowRecordMaterializer(parquetSchema, schema) } override def init( configuration: Configuration, keyValueMetaData: java.util.Map[String, String], fileSchema: MessageType): ReadContext = { - val requested_schema_string = - configuration.get(RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA, fileSchema.toString) - val requested_schema = - MessageTypeParser.parseMessageType(requested_schema_string) - log.debug(s"read support initialized for requested schema $requested_schema") - ParquetRelation.enableLogForwarding() - new ReadContext(requested_schema, keyValueMetaData) + var parquetSchema: MessageType = fileSchema + var metadata: java.util.Map[String, String] = new java.util.HashMap[String, String]() + val requestedAttributes = RowReadSupport.getRequestedSchema(configuration) + + if (requestedAttributes != null) { + parquetSchema = ParquetTypesConverter.convertFromAttributes(requestedAttributes) + metadata.put( + RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, + ParquetTypesConverter.convertToString(requestedAttributes)) + } + + val origAttributesStr: String = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) + if (origAttributesStr != null) { + metadata.put(RowReadSupport.SPARK_METADATA_KEY, origAttributesStr) + } + + return new ReadSupport.ReadContext(parquetSchema, metadata) } } private[parquet] object RowReadSupport { - val PARQUET_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" + val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" + val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata" + + private def getRequestedSchema(configuration: Configuration): Seq[Attribute] = { + val schemaString = configuration.get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA) + if (schemaString == null) null else ParquetTypesConverter.convertFromString(schemaString) + } } /** * A `parquet.hadoop.api.WriteSupport` for Row ojects. */ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { - def setSchema(schema: MessageType, configuration: Configuration) { - // for testing - this.schema = schema - // TODO: could use Attributes themselves instead of Parquet schema? 
- configuration.set( - RowWriteSupport.PARQUET_ROW_SCHEMA, - schema.toString) - configuration.set( - ParquetOutputFormat.WRITER_VERSION, - ParquetProperties.WriterVersion.PARQUET_1_0.toString) - } - - def getSchema(configuration: Configuration): MessageType = { - MessageTypeParser.parseMessageType(configuration.get(RowWriteSupport.PARQUET_ROW_SCHEMA)) - } - private var schema: MessageType = null - private var writer: RecordConsumer = null - private var attributes: Seq[Attribute] = null + private[parquet] var writer: RecordConsumer = null + private[parquet] var attributes: Seq[Attribute] = null override def init(configuration: Configuration): WriteSupport.WriteContext = { - schema = if (schema == null) getSchema(configuration) else schema - attributes = ParquetTypesConverter.convertToAttributes(schema) - log.debug(s"write support initialized for requested schema $schema") + attributes = if (attributes == null) RowWriteSupport.getSchema(configuration) else attributes + + log.debug(s"write support initialized for requested schema $attributes") ParquetRelation.enableLogForwarding() new WriteSupport.WriteContext( - schema, + ParquetTypesConverter.convertFromAttributes(attributes), new java.util.HashMap[java.lang.String, java.lang.String]()) } override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { writer = recordConsumer - log.debug(s"preparing for write with schema $schema") + log.debug(s"preparing for write with schema $attributes") } - // TODO: add groups (nested fields) override def write(record: Row): Unit = { if (attributes.size > record.size) { throw new IndexOutOfBoundsException( @@ -130,98 +158,176 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { // null values indicate optional fields but we do not check currently if (record(index) != null && record(index) != Nil) { writer.startField(attributes(index).name, index) - ParquetTypesConverter.consumeType(writer, attributes(index).dataType, record, index) + writeValue(attributes(index).dataType, record(index)) writer.endField(attributes(index).name, index) } index = index + 1 } writer.endMessage() } -} -private[parquet] object RowWriteSupport { - val PARQUET_ROW_SCHEMA: String = "org.apache.spark.sql.parquet.row.schema" -} - -/** - * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record to a `Row` object. - * - * @param schema The corresponding Catalyst schema in the form of a list of attributes. 
- */ -private[parquet] class CatalystGroupConverter( - schema: Seq[Attribute], - protected[parquet] val current: ParquetRelation.RowType) extends GroupConverter { - - def this(schema: Seq[Attribute]) = this(schema, new ParquetRelation.RowType(schema.length)) - - val converters: Array[Converter] = schema.map { - a => a.dataType match { - case ctype: NativeType => - // note: for some reason matching for StringType fails so use this ugly if instead - if (ctype == StringType) { - new CatalystPrimitiveStringConverter(this, schema.indexOf(a)) - } else { - new CatalystPrimitiveConverter(this, schema.indexOf(a)) - } - case _ => throw new RuntimeException( - s"unable to convert datatype ${a.dataType.toString} in CatalystGroupConverter") + private[parquet] def writeValue(schema: DataType, value: Any): Unit = { + if (value != null && value != Nil) { + schema match { + case t @ ArrayType(_) => writeArray( + t, + value.asInstanceOf[CatalystConverter.ArrayScalaType[_]]) + case t @ MapType(_, _) => writeMap( + t, + value.asInstanceOf[CatalystConverter.MapScalaType[_, _]]) + case t @ StructType(_) => writeStruct( + t, + value.asInstanceOf[CatalystConverter.StructScalaType[_]]) + case _ => writePrimitive(schema.asInstanceOf[PrimitiveType], value) + } } - }.toArray + } - override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) + private[parquet] def writePrimitive(schema: PrimitiveType, value: Any): Unit = { + if (value != null && value != Nil) { + schema match { + case StringType => writer.addBinary( + Binary.fromByteArray( + value.asInstanceOf[String].getBytes("utf-8") + ) + ) + case IntegerType => writer.addInteger(value.asInstanceOf[Int]) + case ShortType => writer.addInteger(value.asInstanceOf[Int]) + case LongType => writer.addLong(value.asInstanceOf[Long]) + case ByteType => writer.addInteger(value.asInstanceOf[Int]) + case DoubleType => writer.addDouble(value.asInstanceOf[Double]) + case FloatType => writer.addFloat(value.asInstanceOf[Float]) + case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) + case _ => sys.error(s"Do not know how to writer $schema to consumer") + } + } + } - private[parquet] def getCurrentRecord: ParquetRelation.RowType = current + private[parquet] def writeStruct( + schema: StructType, + struct: CatalystConverter.StructScalaType[_]): Unit = { + if (struct != null && struct != Nil) { + val fields = schema.fields.toArray + writer.startGroup() + var i = 0 + while(i < fields.size) { + if (struct(i) != null && struct(i) != Nil) { + writer.startField(fields(i).name, i) + writeValue(fields(i).dataType, struct(i)) + writer.endField(fields(i).name, i) + } + i = i + 1 + } + writer.endGroup() + } + } - override def start(): Unit = { - var i = 0 - while (i < schema.length) { - current.setNullAt(i) - i = i + 1 + // TODO: support null values, see + // https://issues.apache.org/jira/browse/SPARK-1649 + private[parquet] def writeArray( + schema: ArrayType, + array: CatalystConverter.ArrayScalaType[_]): Unit = { + val elementType = schema.elementType + writer.startGroup() + if (array.size > 0) { + writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) + var i = 0 + while(i < array.size) { + writeValue(elementType, array(i)) + i = i + 1 + } + writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) } + writer.endGroup() } - override def end(): Unit = {} + // TODO: support null values, see + // https://issues.apache.org/jira/browse/SPARK-1649 + private[parquet] def writeMap( + schema: MapType, + map: CatalystConverter.MapScalaType[_, _]): 
Unit = { + writer.startGroup() + if (map.size > 0) { + writer.startField(CatalystConverter.MAP_SCHEMA_NAME, 0) + writer.startGroup() + writer.startField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0) + for(key <- map.keys) { + writeValue(schema.keyType, key) + } + writer.endField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0) + writer.startField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1) + for(value <- map.values) { + writeValue(schema.valueType, value) + } + writer.endField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1) + writer.endGroup() + writer.endField(CatalystConverter.MAP_SCHEMA_NAME, 0) + } + writer.endGroup() + } } -/** - * A `parquet.io.api.PrimitiveConverter` that converts Parquet types to Catalyst types. - * - * @param parent The parent group converter. - * @param fieldIndex The index inside the record. - */ -private[parquet] class CatalystPrimitiveConverter( - parent: CatalystGroupConverter, - fieldIndex: Int) extends PrimitiveConverter { - // TODO: consider refactoring these together with ParquetTypesConverter - override def addBinary(value: Binary): Unit = - parent.getCurrentRecord.update(fieldIndex, value.getBytes) +// Optimized for non-nested rows +private[parquet] class MutableRowWriteSupport extends RowWriteSupport { + override def write(record: Row): Unit = { + if (attributes.size > record.size) { + throw new IndexOutOfBoundsException( + s"Trying to write more fields than contained in row (${attributes.size}>${record.size})") + } - override def addBoolean(value: Boolean): Unit = - parent.getCurrentRecord.setBoolean(fieldIndex, value) + var index = 0 + writer.startMessage() + while(index < attributes.size) { + // null values indicate optional fields but we do not check currently + if (record(index) != null && record(index) != Nil) { + writer.startField(attributes(index).name, index) + consumeType(attributes(index).dataType, record, index) + writer.endField(attributes(index).name, index) + } + index = index + 1 + } + writer.endMessage() + } - override def addDouble(value: Double): Unit = - parent.getCurrentRecord.setDouble(fieldIndex, value) + private def consumeType( + ctype: DataType, + record: Row, + index: Int): Unit = { + ctype match { + case StringType => writer.addBinary( + Binary.fromByteArray( + record(index).asInstanceOf[String].getBytes("utf-8") + ) + ) + case IntegerType => writer.addInteger(record.getInt(index)) + case ShortType => writer.addInteger(record.getShort(index)) + case LongType => writer.addLong(record.getLong(index)) + case ByteType => writer.addInteger(record.getByte(index)) + case DoubleType => writer.addDouble(record.getDouble(index)) + case FloatType => writer.addFloat(record.getFloat(index)) + case BooleanType => writer.addBoolean(record.getBoolean(index)) + case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") + } + } +} - override def addFloat(value: Float): Unit = - parent.getCurrentRecord.setFloat(fieldIndex, value) +private[parquet] object RowWriteSupport { + val SPARK_ROW_SCHEMA: String = "org.apache.spark.sql.parquet.row.attributes" - override def addInt(value: Int): Unit = - parent.getCurrentRecord.setInt(fieldIndex, value) + def getSchema(configuration: Configuration): Seq[Attribute] = { + val schemaString = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) + if (schemaString == null) { + throw new RuntimeException("Missing schema!") + } + ParquetTypesConverter.convertFromString(schemaString) + } - override def addLong(value: Long): Unit = - parent.getCurrentRecord.setLong(fieldIndex, value) + def setSchema(schema: 
Seq[Attribute], configuration: Configuration) { + val encoded = ParquetTypesConverter.convertToString(schema) + configuration.set(SPARK_ROW_SCHEMA, encoded) + configuration.set( + ParquetOutputFormat.WRITER_VERSION, + ParquetProperties.WriterVersion.PARQUET_1_0.toString) + } } -/** - * A `parquet.io.api.PrimitiveConverter` that converts Parquet strings (fixed-length byte arrays) - * into Catalyst Strings. - * - * @param parent The parent group converter. - * @param fieldIndex The index inside the record. - */ -private[parquet] class CatalystPrimitiveStringConverter( - parent: CatalystGroupConverter, - fieldIndex: Int) extends CatalystPrimitiveConverter(parent, fieldIndex) { - override def addBinary(value: Binary): Unit = - parent.getCurrentRecord.setString(fieldIndex, value.toStringUsingUTF8) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala index 46c7172985642..1dc58633a2a68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala @@ -17,14 +17,19 @@ package org.apache.spark.sql.parquet +import java.io.File + import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.mapreduce.Job import parquet.example.data.{GroupWriter, Group} import parquet.example.data.simple.SimpleGroup -import parquet.hadoop.ParquetWriter +import parquet.hadoop.{ParquetReader, ParquetFileReader, ParquetWriter} import parquet.hadoop.api.WriteSupport import parquet.hadoop.api.WriteSupport.WriteContext +import parquet.hadoop.example.GroupReadSupport +import parquet.hadoop.util.ContextUtil import parquet.io.api.RecordConsumer import parquet.schema.{MessageType, MessageTypeParser} @@ -51,13 +56,13 @@ private[sql] object ParquetTestData { val testSchema = """message myrecord { - |optional boolean myboolean; - |optional int32 myint; - |optional binary mystring; - |optional int64 mylong; - |optional float myfloat; - |optional double mydouble; - |}""".stripMargin + optional boolean myboolean; + optional int32 myint; + optional binary mystring; + optional int64 mylong; + optional float myfloat; + optional double mydouble; + }""" // field names for test assertion error messages val testSchemaFieldNames = Seq( @@ -71,23 +76,23 @@ private[sql] object ParquetTestData { val subTestSchema = """ - |message myrecord { - |optional boolean myboolean; - |optional int64 mylong; - |} - """.stripMargin + message myrecord { + optional boolean myboolean; + optional int64 mylong; + } + """ val testFilterSchema = """ - |message myrecord { - |required boolean myboolean; - |required int32 myint; - |required binary mystring; - |required int64 mylong; - |required float myfloat; - |required double mydouble; - |} - """.stripMargin + message myrecord { + required boolean myboolean; + required int32 myint; + required binary mystring; + required int64 mylong; + required float myfloat; + required double mydouble; + } + """ // field names for test assertion error messages val subTestSchemaFieldNames = Seq( @@ -100,9 +105,110 @@ private[sql] object ParquetTestData { lazy val testData = new ParquetRelation(testDir.toURI.toString) + val testNestedSchema1 = + // based on blogpost example, source: + // https://blog.twitter.com/2013/dremel-made-simple-with-parquet + // note: instead of string we have to use binary (?) 
otherwise + // Parquet gives us: + // IllegalArgumentException: expected one of [INT64, INT32, BOOLEAN, + // BINARY, FLOAT, DOUBLE, INT96, FIXED_LEN_BYTE_ARRAY] + // Also repeated primitives seem tricky to convert (AvroParquet + // only uses them in arrays?) so only use at most one in each group + // and nothing else in that group (-> is mapped to array)! + // The "values" inside ownerPhoneNumbers is a keyword currently + // so that array types can be translated correctly. + """ + message AddressBook { + required binary owner; + optional group ownerPhoneNumbers { + repeated binary array; + } + optional group contacts { + repeated group array { + required binary name; + optional binary phoneNumber; + } + } + } + """ + + + val testNestedSchema2 = + """ + message TestNested2 { + required int32 firstInt; + optional int32 secondInt; + optional group longs { + repeated int64 array; + } + required group entries { + repeated group array { + required double value; + optional boolean truth; + } + } + optional group outerouter { + repeated group array { + repeated group array { + repeated int32 array; + } + } + } + } + """ + + val testNestedSchema3 = + """ + message TestNested3 { + required int32 x; + optional group booleanNumberPairs { + repeated group array { + required int32 key; + optional group value { + repeated group array { + required double nestedValue; + optional boolean truth; + } + } + } + } + } + """ + + val testNestedSchema4 = + """ + message TestNested4 { + required int32 x; + optional group data1 { + repeated group map { + required binary key; + required int32 value; + } + } + required group data2 { + repeated group map { + required binary key; + required group value { + required int64 payload1; + optional binary payload2; + } + } + } + } + """ + + val testNestedDir1 = Utils.createTempDir() + val testNestedDir2 = Utils.createTempDir() + val testNestedDir3 = Utils.createTempDir() + val testNestedDir4 = Utils.createTempDir() + + lazy val testNestedData1 = new ParquetRelation(testNestedDir1.toURI.toString) + lazy val testNestedData2 = new ParquetRelation(testNestedDir2.toURI.toString) + def writeFile() = { - testDir.delete + testDir.delete() val path: Path = new Path(new Path(testDir.toURI), new Path("part-r-0.parquet")) + val job = new Job() val schema: MessageType = MessageTypeParser.parseMessageType(testSchema) val writeSupport = new TestGroupWriteSupport(schema) val writer = new ParquetWriter[Group](path, writeSupport) @@ -150,5 +256,149 @@ private[sql] object ParquetTestData { } writer.close() } + + def writeNestedFile1() { + // example data from https://blog.twitter.com/2013/dremel-made-simple-with-parquet + testNestedDir1.delete() + val path: Path = new Path(new Path(testNestedDir1.toURI), new Path("part-r-0.parquet")) + val schema: MessageType = MessageTypeParser.parseMessageType(testNestedSchema1) + + val r1 = new SimpleGroup(schema) + r1.add(0, "Julien Le Dem") + r1.addGroup(1) + .append(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, "555 123 4567") + .append(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, "555 666 1337") + .append(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, "XXX XXX XXXX") + val contacts = r1.addGroup(2) + contacts.addGroup(0) + .append("name", "Dmitriy Ryaboy") + .append("phoneNumber", "555 987 6543") + contacts.addGroup(0) + .append("name", "Chris Aniszczyk") + + val r2 = new SimpleGroup(schema) + r2.add(0, "A. 
Nonymous") + + val writeSupport = new TestGroupWriteSupport(schema) + val writer = new ParquetWriter[Group](path, writeSupport) + writer.write(r1) + writer.write(r2) + writer.close() + } + + def writeNestedFile2() { + testNestedDir2.delete() + val path: Path = new Path(new Path(testNestedDir2.toURI), new Path("part-r-0.parquet")) + val schema: MessageType = MessageTypeParser.parseMessageType(testNestedSchema2) + + val r1 = new SimpleGroup(schema) + r1.add(0, 1) + r1.add(1, 7) + val longs = r1.addGroup(2) + longs.add(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME , 1.toLong << 32) + longs.add(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 1.toLong << 33) + longs.add(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 1.toLong << 34) + val booleanNumberPair = r1.addGroup(3).addGroup(0) + booleanNumberPair.add("value", 2.5) + booleanNumberPair.add("truth", false) + val top_level = r1.addGroup(4) + val second_level_a = top_level.addGroup(0) + val second_level_b = top_level.addGroup(0) + val third_level_aa = second_level_a.addGroup(0) + val third_level_ab = second_level_a.addGroup(0) + val third_level_c = second_level_b.addGroup(0) + third_level_aa.add( + CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + 7) + third_level_ab.add( + CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + 8) + third_level_c.add( + CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + 9) + + val writeSupport = new TestGroupWriteSupport(schema) + val writer = new ParquetWriter[Group](path, writeSupport) + writer.write(r1) + writer.close() + } + + def writeNestedFile3() { + testNestedDir3.delete() + val path: Path = new Path(new Path(testNestedDir3.toURI), new Path("part-r-0.parquet")) + val schema: MessageType = MessageTypeParser.parseMessageType(testNestedSchema3) + + val r1 = new SimpleGroup(schema) + r1.add(0, 1) + val booleanNumberPairs = r1.addGroup(1) + val g1 = booleanNumberPairs.addGroup(0) + g1.add(0, 1) + val nested1 = g1.addGroup(1) + val ng1 = nested1.addGroup(0) + ng1.add(0, 1.5) + ng1.add(1, false) + val ng2 = nested1.addGroup(0) + ng2.add(0, 2.5) + ng2.add(1, true) + val g2 = booleanNumberPairs.addGroup(0) + g2.add(0, 2) + val ng3 = g2.addGroup(1) + .addGroup(0) + ng3.add(0, 3.5) + ng3.add(1, false) + + val writeSupport = new TestGroupWriteSupport(schema) + val writer = new ParquetWriter[Group](path, writeSupport) + writer.write(r1) + writer.close() + } + + def writeNestedFile4() { + testNestedDir4.delete() + val path: Path = new Path(new Path(testNestedDir4.toURI), new Path("part-r-0.parquet")) + val schema: MessageType = MessageTypeParser.parseMessageType(testNestedSchema4) + + val r1 = new SimpleGroup(schema) + r1.add(0, 7) + val map1 = r1.addGroup(1) + val keyValue1 = map1.addGroup(0) + keyValue1.add(0, "key1") + keyValue1.add(1, 1) + val keyValue2 = map1.addGroup(0) + keyValue2.add(0, "key2") + keyValue2.add(1, 2) + val map2 = r1.addGroup(2) + val keyValue3 = map2.addGroup(0) + // TODO: currently only string key type supported + keyValue3.add(0, "seven") + val valueGroup1 = keyValue3.addGroup(1) + valueGroup1.add(0, 42.toLong) + valueGroup1.add(1, "the answer") + val keyValue4 = map2.addGroup(0) + // TODO: currently only string key type supported + keyValue4.add(0, "eight") + val valueGroup2 = keyValue4.addGroup(1) + valueGroup2.add(0, 49.toLong) + + val writeSupport = new TestGroupWriteSupport(schema) + val writer = new ParquetWriter[Group](path, writeSupport) + writer.write(r1) + writer.close() + } + + // TODO: this is not actually used anywhere but useful for debugging + /* def readNestedFile(file: File, schemaString: 
String): Unit = { + val configuration = new Configuration() + val path = new Path(new Path(file.toURI), new Path("part-r-0.parquet")) + val fs: FileSystem = path.getFileSystem(configuration) + val schema: MessageType = MessageTypeParser.parseMessageType(schemaString) + assert(schema != null) + val outputStatus: FileStatus = fs.getFileStatus(new Path(path.toString)) + val footers = ParquetFileReader.readFooter(configuration, outputStatus) + assert(footers != null) + val reader = new ParquetReader(new Path(path.toString), new GroupReadSupport()) + val first = reader.read() + assert(first != null) + } */ } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala new file mode 100644 index 0000000000000..f9046368e7ced --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -0,0 +1,408 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import java.io.IOException + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.mapreduce.Job + +import parquet.hadoop.{ParquetFileReader, Footer, ParquetFileWriter} +import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData} +import parquet.hadoop.util.ContextUtil +import parquet.schema.{Type => ParquetType, PrimitiveType => ParquetPrimitiveType, MessageType} +import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns} +import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} +import parquet.schema.Type.Repetition + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} +import org.apache.spark.sql.catalyst.types._ + +// Implicits +import scala.collection.JavaConversions._ + +private[parquet] object ParquetTypesConverter extends Logging { + def isPrimitiveType(ctype: DataType): Boolean = + classOf[PrimitiveType] isAssignableFrom ctype.getClass + + def toPrimitiveDataType(parquetType : ParquetPrimitiveTypeName): DataType = parquetType match { + case ParquetPrimitiveTypeName.BINARY => StringType + case ParquetPrimitiveTypeName.BOOLEAN => BooleanType + case ParquetPrimitiveTypeName.DOUBLE => DoubleType + case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY => ArrayType(ByteType) + case ParquetPrimitiveTypeName.FLOAT => FloatType + case ParquetPrimitiveTypeName.INT32 => IntegerType + case ParquetPrimitiveTypeName.INT64 => LongType + case ParquetPrimitiveTypeName.INT96 => + // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? 
+ sys.error("Potential loss of precision: cannot convert INT96") + case _ => sys.error( + s"Unsupported parquet datatype $parquetType") + } + + /** + * Converts a given Parquet `Type` into the corresponding + * [[org.apache.spark.sql.catalyst.types.DataType]]. + * + * We apply the following conversion rules: + *
    + *
+ *   - Primitive types are converted to the corresponding primitive type.
+ *   - Group types that have a single field that is itself a group, which has repetition
+ *     level `REPEATED`, are treated as follows:
+ *       - If the nested group has name `values`, the surrounding group is converted
+ *         into an [[ArrayType]] with the corresponding field type (primitive or
+ *         complex) as element type.
+ *       - If the nested group has name `map` and two fields (named `key` and `value`),
+ *         the surrounding group is converted into a [[MapType]] with the corresponding
+ *         key and value (value possibly complex) types. Note that we currently assume
+ *         map values are not nullable.
+ *       - Other group types are converted into a [[StructType]] with the corresponding
+ *         field types.
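+ *
+ * As an illustrative sketch (the group and field names here are hypothetical and not
+ * taken from this patch), a group declared as
+ * {{{
+ *   optional group owner {
+ *     required binary name;
+ *     optional int32 age;
+ *   }
+ * }}}
+ * has no single repeated child group, so it is converted into a [[StructType]] with a
+ * non-nullable string field `name` and a nullable integer field `age`.
+ *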
+ * Note that fields are determined to be `nullable` if and only if their Parquet repetition + * level is not `REQUIRED`. + * + * @param parquetType The type to convert. + * @return The corresponding Catalyst type. + */ + def toDataType(parquetType: ParquetType): DataType = { + def correspondsToMap(groupType: ParquetGroupType): Boolean = { + if (groupType.getFieldCount != 1 || groupType.getFields.apply(0).isPrimitive) { + false + } else { + // This mostly follows the convention in ``parquet.schema.ConversionPatterns`` + val keyValueGroup = groupType.getFields.apply(0).asGroupType() + keyValueGroup.getRepetition == Repetition.REPEATED && + keyValueGroup.getName == CatalystConverter.MAP_SCHEMA_NAME && + keyValueGroup.getFieldCount == 2 && + keyValueGroup.getFields.apply(0).getName == CatalystConverter.MAP_KEY_SCHEMA_NAME && + keyValueGroup.getFields.apply(1).getName == CatalystConverter.MAP_VALUE_SCHEMA_NAME + } + } + + def correspondsToArray(groupType: ParquetGroupType): Boolean = { + groupType.getFieldCount == 1 && + groupType.getFieldName(0) == CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME && + groupType.getFields.apply(0).getRepetition == Repetition.REPEATED + } + + if (parquetType.isPrimitive) { + toPrimitiveDataType(parquetType.asPrimitiveType.getPrimitiveTypeName) + } else { + val groupType = parquetType.asGroupType() + parquetType.getOriginalType match { + // if the schema was constructed programmatically there may be hints how to convert + // it inside the metadata via the OriginalType field + case ParquetOriginalType.LIST => { // TODO: check enums! + assert(groupType.getFieldCount == 1) + val field = groupType.getFields.apply(0) + new ArrayType(toDataType(field)) + } + case ParquetOriginalType.MAP => { + assert( + !groupType.getFields.apply(0).isPrimitive, + "Parquet Map type malformatted: expected nested group for map!") + val keyValueGroup = groupType.getFields.apply(0).asGroupType() + assert( + keyValueGroup.getFieldCount == 2, + "Parquet Map type malformatted: nested group should have 2 (key, value) fields!") + val keyType = toDataType(keyValueGroup.getFields.apply(0)) + assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) + val valueType = toDataType(keyValueGroup.getFields.apply(1)) + assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) + new MapType(keyType, valueType) + } + case _ => { + // Note: the order of these checks is important! + if (correspondsToMap(groupType)) { // MapType + val keyValueGroup = groupType.getFields.apply(0).asGroupType() + val keyType = toDataType(keyValueGroup.getFields.apply(0)) + assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) + val valueType = toDataType(keyValueGroup.getFields.apply(1)) + assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) + new MapType(keyType, valueType) + } else if (correspondsToArray(groupType)) { // ArrayType + val elementType = toDataType(groupType.getFields.apply(0)) + new ArrayType(elementType) + } else { // everything else: StructType + val fields = groupType + .getFields + .map(ptype => new StructField( + ptype.getName, + toDataType(ptype), + ptype.getRepetition != Repetition.REQUIRED)) + new StructType(fields) + } + } + } + } + } + + /** + * For a given Catalyst [[org.apache.spark.sql.catalyst.types.DataType]] return + * the name of the corresponding Parquet primitive type or None if the given type + * is not primitive. 
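+ * For example, `IntegerType` maps to `INT32` and `StringType` to `BINARY`, while a
+ * complex type such as a [[StructType]] yields `None`.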
+ * + * @param ctype The type to convert + * @return The name of the corresponding Parquet primitive type + */ + def fromPrimitiveDataType(ctype: DataType): + Option[ParquetPrimitiveTypeName] = ctype match { + case StringType => Some(ParquetPrimitiveTypeName.BINARY) + case BooleanType => Some(ParquetPrimitiveTypeName.BOOLEAN) + case DoubleType => Some(ParquetPrimitiveTypeName.DOUBLE) + case ArrayType(ByteType) => + Some(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY) + case FloatType => Some(ParquetPrimitiveTypeName.FLOAT) + case IntegerType => Some(ParquetPrimitiveTypeName.INT32) + // There is no type for Byte or Short so we promote them to INT32. + case ShortType => Some(ParquetPrimitiveTypeName.INT32) + case ByteType => Some(ParquetPrimitiveTypeName.INT32) + case LongType => Some(ParquetPrimitiveTypeName.INT64) + case _ => None + } + + /** + * Converts a given Catalyst [[org.apache.spark.sql.catalyst.types.DataType]] into + * the corresponding Parquet `Type`. + * + * The conversion follows the rules below: + *
    + *
+ *   - Primitive types are converted into Parquet's primitive types.
+ *   - [[org.apache.spark.sql.catalyst.types.StructType]]s are converted
+ *     into Parquet's `GroupType` with the corresponding field types.
+ *   - [[org.apache.spark.sql.catalyst.types.ArrayType]]s are converted
+ *     into a 2-level nested group, where the outer group has the inner
+ *     group as sole field. The inner group has name `values` and
+ *     repetition level `REPEATED` and has the element type of
+ *     the array as schema. We use Parquet's `ConversionPatterns` for this
+ *     purpose.
+ *   - [[org.apache.spark.sql.catalyst.types.MapType]]s are converted
+ *     into a nested (2-level) Parquet `GroupType` with two fields: a key
+ *     type and a value type. The nested group has repetition level
+ *     `REPEATED` and name `map`. We use Parquet's `ConversionPatterns`
+ *     for this purpose.
+ *
+ * Parquet's repetition level is generally set according to the following rule:
+ *   - If the call to `fromDataType` is recursive inside an enclosing `ArrayType` or
+ *     `MapType`, then the repetition level is set to `REPEATED`.
+ *   - Otherwise, if the attribute whose type is converted is `nullable`, the Parquet
+ *     type gets repetition level `OPTIONAL` and otherwise `REQUIRED`.
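+ *
+ * As a small illustration (a hypothetical attribute, not one defined in this patch): a
+ * nullable top-level `IntegerType` attribute named `x` is converted into
+ * `optional int32 x;`, while the same attribute marked non-nullable becomes
+ * `required int32 x;`.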
+ * + *@param ctype The type to convert + * @param name The name of the [[org.apache.spark.sql.catalyst.expressions.Attribute]] + * whose type is converted + * @param nullable When true indicates that the attribute is nullable + * @param inArray When true indicates that this is a nested attribute inside an array. + * @return The corresponding Parquet type. + */ + def fromDataType( + ctype: DataType, + name: String, + nullable: Boolean = true, + inArray: Boolean = false): ParquetType = { + val repetition = + if (inArray) { + Repetition.REPEATED + } else { + if (nullable) Repetition.OPTIONAL else Repetition.REQUIRED + } + val primitiveType = fromPrimitiveDataType(ctype) + if (primitiveType.isDefined) { + new ParquetPrimitiveType(repetition, primitiveType.get, name) + } else { + ctype match { + case ArrayType(elementType) => { + val parquetElementType = fromDataType( + elementType, + CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + nullable = false, + inArray = true) + ConversionPatterns.listType(repetition, name, parquetElementType) + } + case StructType(structFields) => { + val fields = structFields.map { + field => fromDataType(field.dataType, field.name, field.nullable, inArray = false) + } + new ParquetGroupType(repetition, name, fields) + } + case MapType(keyType, valueType) => { + val parquetKeyType = + fromDataType( + keyType, + CatalystConverter.MAP_KEY_SCHEMA_NAME, + nullable = false, + inArray = false) + val parquetValueType = + fromDataType( + valueType, + CatalystConverter.MAP_VALUE_SCHEMA_NAME, + nullable = false, + inArray = false) + ConversionPatterns.mapType( + repetition, + name, + parquetKeyType, + parquetValueType) + } + case _ => sys.error(s"Unsupported datatype $ctype") + } + } + } + + def convertToAttributes(parquetSchema: ParquetType): Seq[Attribute] = { + parquetSchema + .asGroupType() + .getFields + .map( + field => + new AttributeReference( + field.getName, + toDataType(field), + field.getRepetition != Repetition.REQUIRED)()) + } + + def convertFromAttributes(attributes: Seq[Attribute]): MessageType = { + val fields = attributes.map( + attribute => + fromDataType(attribute.dataType, attribute.name, attribute.nullable)) + new MessageType("root", fields) + } + + def convertFromString(string: String): Seq[Attribute] = { + DataType(string) match { + case s: StructType => s.toAttributes + case other => sys.error(s"Can convert $string to row") + } + } + + def convertToString(schema: Seq[Attribute]): String = { + StructType.fromAttributes(schema).toString + } + + def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration): Unit = { + if (origPath == null) { + throw new IllegalArgumentException("Unable to write Parquet metadata: path is null") + } + val fs = origPath.getFileSystem(conf) + if (fs == null) { + throw new IllegalArgumentException( + s"Unable to write Parquet metadata: path $origPath is incorrectly formatted") + } + val path = origPath.makeQualified(fs) + if (fs.exists(path) && !fs.getFileStatus(path).isDir) { + throw new IllegalArgumentException(s"Expected to write to directory $path but found file") + } + val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) + if (fs.exists(metadataPath)) { + try { + fs.delete(metadataPath, true) + } catch { + case e: IOException => + throw new IOException(s"Unable to delete previous PARQUET_METADATA_FILE at $metadataPath") + } + } + val extraMetadata = new java.util.HashMap[String, String]() + extraMetadata.put( + RowReadSupport.SPARK_METADATA_KEY, + 
ParquetTypesConverter.convertToString(attributes)) + // TODO: add extra data, e.g., table name, date, etc.? + + val parquetSchema: MessageType = + ParquetTypesConverter.convertFromAttributes(attributes) + val metaData: FileMetaData = new FileMetaData( + parquetSchema, + extraMetadata, + "Spark") + + ParquetRelation.enableLogForwarding() + ParquetFileWriter.writeMetadataFile( + conf, + path, + new Footer(path, new ParquetMetadata(metaData, Nil)) :: Nil) + } + + /** + * Try to read Parquet metadata at the given Path. We first see if there is a summary file + * in the parent directory. If so, this is used. Else we read the actual footer at the given + * location. + * @param origPath The path at which we expect one (or more) Parquet files. + * @param configuration The Hadoop configuration to use. + * @return The `ParquetMetadata` containing among other things the schema. + */ + def readMetaData(origPath: Path, configuration: Option[Configuration]): ParquetMetadata = { + if (origPath == null) { + throw new IllegalArgumentException("Unable to read Parquet metadata: path is null") + } + val job = new Job() + val conf = configuration.getOrElse(ContextUtil.getConfiguration(job)) + val fs: FileSystem = origPath.getFileSystem(conf) + if (fs == null) { + throw new IllegalArgumentException(s"Incorrectly formatted Parquet metadata path $origPath") + } + val path = origPath.makeQualified(fs) + if (!fs.getFileStatus(path).isDir) { + throw new IllegalArgumentException( + s"Expected $path for be a directory with Parquet files/metadata") + } + ParquetRelation.enableLogForwarding() + val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) + // if this is a new table that was just created we will find only the metadata file + if (fs.exists(metadataPath) && fs.isFile(metadataPath)) { + ParquetFileReader.readFooter(conf, metadataPath) + } else { + // there may be one or more Parquet files in the given directory + val footers = ParquetFileReader.readFooters(conf, fs.getFileStatus(path)) + // TODO: for now we assume that all footers (if there is more than one) have identical + // metadata; we may want to add a check here at some point + if (footers.size() == 0) { + throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path") + } + footers(0).getParquetMetadata + } + } + + /** + * Reads in Parquet Metadata from the given path and tries to extract the schema + * (Catalyst attributes) from the application-specific key-value map. If this + * is empty it falls back to converting from the Parquet file schema which + * may lead to an upcast of types (e.g., {byte, short} to int). + * + * @param origPath The path at which we expect one (or more) Parquet files. + * @param conf The Hadoop configuration to use. + * @return A list of attributes that make up the schema. 
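+ *         This is the method that the updated [[ParquetRelation]] uses to derive its
+ *         `output` attributes.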
+ */ + def readSchemaFromFile(origPath: Path, conf: Option[Configuration]): Seq[Attribute] = { + val keyValueMetadata: java.util.Map[String, String] = + readMetaData(origPath, conf) + .getFileMetaData + .getKeyValueMetaData + if (keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY) != null) { + convertFromString(keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY)) + } else { + val attributes = convertToAttributes( + readMetaData(origPath, conf).getFileMetaData.getSchema) + log.warn(s"Falling back to schema conversion from Parquet types; result: $attributes") + attributes + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 9810520bb9ae6..7714eb1b5628a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -19,26 +19,23 @@ package org.apache.spark.sql.parquet import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} -import org.apache.hadoop.fs.{Path, FileSystem} -import org.apache.hadoop.mapreduce.Job - import parquet.hadoop.ParquetFileWriter import parquet.hadoop.util.ContextUtil import parquet.schema.MessageTypeParser +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.mapreduce.Job +import org.apache.spark.SparkContext import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.sql.catalyst.{SqlLexical, SqlParser} +import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType} +import org.apache.spark.sql.catalyst.util.getTempFilePath import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.TestData -import org.apache.spark.sql.SchemaRDD -import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.catalyst.expressions.Equals -import org.apache.spark.sql.catalyst.types.IntegerType +import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.util.Utils -// Implicits -import org.apache.spark.sql.test.TestSQLContext._ case class TestRDDEntry(key: Int, value: String) @@ -56,15 +53,36 @@ case class OptionalReflectData( doubleField: Option[Double], booleanField: Option[Boolean]) +case class Nested(i: Int, s: String) + +case class Data(array: Seq[Int], nested: Nested) + +case class AllDataTypes( + stringField: String, + intField: Int, + longField: Long, + floatField: Float, + doubleField: Double, + shortField: Short, + byteField: Byte, + booleanField: Boolean) + class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { - import TestData._ TestData // Load test data tables. 
var testRDD: SchemaRDD = null + // TODO: remove this once SqlParser can parse nested select statements + var nestedParserSqlContext: NestedParserSQLContext = null + override def beforeAll() { + nestedParserSqlContext = new NestedParserSQLContext(TestSQLContext.sparkContext) ParquetTestData.writeFile() ParquetTestData.writeFilterFile() + ParquetTestData.writeNestedFile1() + ParquetTestData.writeNestedFile2() + ParquetTestData.writeNestedFile3() + ParquetTestData.writeNestedFile4() testRDD = parquetFile(ParquetTestData.testDir.toString) testRDD.registerAsTable("testsource") parquetFile(ParquetTestData.testFilterDir.toString) @@ -74,9 +92,33 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA override def afterAll() { Utils.deleteRecursively(ParquetTestData.testDir) Utils.deleteRecursively(ParquetTestData.testFilterDir) + Utils.deleteRecursively(ParquetTestData.testNestedDir1) + Utils.deleteRecursively(ParquetTestData.testNestedDir2) + Utils.deleteRecursively(ParquetTestData.testNestedDir3) + Utils.deleteRecursively(ParquetTestData.testNestedDir4) // here we should also unregister the table?? } + test("Read/Write All Types") { + val tempDir = getTempFilePath("parquetTest").getCanonicalPath + val range = (0 to 255) + TestSQLContext.sparkContext.parallelize(range) + .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0)) + .saveAsParquetFile(tempDir) + val result = parquetFile(tempDir).collect() + range.foreach { + i => + assert(result(i).getString(0) == s"$i", s"row $i String field did not match, got ${result(i).getString(0)}") + assert(result(i).getInt(1) === i) + assert(result(i).getLong(2) === i.toLong) + assert(result(i).getFloat(3) === i.toFloat) + assert(result(i).getDouble(4) === i.toDouble) + assert(result(i).getShort(5) === i.toShort) + assert(result(i).getByte(6) === i.toByte) + assert(result(i).getBoolean(7) === (i % 2 == 0)) + } + } + test("self-join parquet files") { val x = ParquetTestData.testData.as('x) val y = ParquetTestData.testData.as('y) @@ -154,7 +196,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA path, TestSQLContext.sparkContext.hadoopConfiguration) assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) - val metaData = ParquetTypesConverter.readMetaData(path) + val metaData = ParquetTypesConverter.readMetaData(path, Some(ContextUtil.getConfiguration(job))) assert(metaData != null) ParquetTestData .testData @@ -197,10 +239,37 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(rdd_copy(i).apply(1) === rdd_orig(i).value, s"value in line $i") } Utils.deleteRecursively(file) - assert(true) } - test("insert (appending) to same table via Scala API") { + test("Insert (overwrite) via Scala API") { + val dirname = Utils.createTempDir() + val source_rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) + .map(i => TestRDDEntry(i, s"val_$i")) + source_rdd.registerAsTable("source") + val dest_rdd = createParquetFile[TestRDDEntry](dirname.toString) + dest_rdd.registerAsTable("dest") + sql("INSERT OVERWRITE INTO dest SELECT * FROM source").collect() + val rdd_copy1 = sql("SELECT * FROM dest").collect() + assert(rdd_copy1.size === 100) + assert(rdd_copy1(0).apply(0) === 1) + assert(rdd_copy1(0).apply(1) === "val_1") + // TODO: why does collecting break things? It seems InsertIntoParquet::execute() is + // executed twice otherwise?! 
+ sql("INSERT INTO dest SELECT * FROM source") + val rdd_copy2 = sql("SELECT * FROM dest").collect() + assert(rdd_copy2.size === 200) + assert(rdd_copy2(0).apply(0) === 1) + assert(rdd_copy2(0).apply(1) === "val_1") + assert(rdd_copy2(99).apply(0) === 100) + assert(rdd_copy2(99).apply(1) === "val_100") + assert(rdd_copy2(100).apply(0) === 1) + assert(rdd_copy2(100).apply(1) === "val_1") + Utils.deleteRecursively(dirname) + } + + test("Insert (appending) to same table via Scala API") { + // TODO: why does collecting break things? It seems InsertIntoParquet::execute() is + // executed twice otherwise?! sql("INSERT INTO testsource SELECT * FROM testsource") val double_rdd = sql("SELECT * FROM testsource").collect() assert(double_rdd != null) @@ -245,7 +314,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("create RecordFilter for simple predicates") { val attribute1 = new AttributeReference("first", IntegerType, false)() - val predicate1 = new Equals(attribute1, new Literal(1, IntegerType)) + val predicate1 = new EqualTo(attribute1, new Literal(1, IntegerType)) val filter1 = ParquetFilters.createFilter(predicate1) assert(filter1.isDefined) assert(filter1.get.predicate == predicate1, "predicates do not match") @@ -363,4 +432,272 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val query = sql(s"SELECT mystring FROM testfiltersource WHERE myint < 10") assert(query.collect().size === 10) } + + test("Importing nested Parquet file (Addressbook)") { + val result = TestSQLContext + .parquetFile(ParquetTestData.testNestedDir1.toString) + .toSchemaRDD + .collect() + assert(result != null) + assert(result.size === 2) + val first_record = result(0) + val second_record = result(1) + assert(first_record != null) + assert(second_record != null) + assert(first_record.size === 3) + assert(second_record(1) === null) + assert(second_record(2) === null) + assert(second_record(0) === "A. 
Nonymous") + assert(first_record(0) === "Julien Le Dem") + val first_owner_numbers = first_record(1) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]] + val first_contacts = first_record(2) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]] + assert(first_owner_numbers != null) + assert(first_owner_numbers(0) === "555 123 4567") + assert(first_owner_numbers(2) === "XXX XXX XXXX") + assert(first_contacts(0) + .asInstanceOf[CatalystConverter.StructScalaType[_]].size === 2) + val first_contacts_entry_one = first_contacts(0) + .asInstanceOf[CatalystConverter.StructScalaType[_]] + assert(first_contacts_entry_one(0) === "Dmitriy Ryaboy") + assert(first_contacts_entry_one(1) === "555 987 6543") + val first_contacts_entry_two = first_contacts(1) + .asInstanceOf[CatalystConverter.StructScalaType[_]] + assert(first_contacts_entry_two(0) === "Chris Aniszczyk") + } + + test("Importing nested Parquet file (nested numbers)") { + val result = TestSQLContext + .parquetFile(ParquetTestData.testNestedDir2.toString) + .toSchemaRDD + .collect() + assert(result.size === 1, "number of top-level rows incorrect") + assert(result(0).size === 5, "number of fields in row incorrect") + assert(result(0)(0) === 1) + assert(result(0)(1) === 7) + val subresult1 = result(0)(2).asInstanceOf[CatalystConverter.ArrayScalaType[_]] + assert(subresult1.size === 3) + assert(subresult1(0) === (1.toLong << 32)) + assert(subresult1(1) === (1.toLong << 33)) + assert(subresult1(2) === (1.toLong << 34)) + val subresult2 = result(0)(3) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) + .asInstanceOf[CatalystConverter.StructScalaType[_]] + assert(subresult2.size === 2) + assert(subresult2(0) === 2.5) + assert(subresult2(1) === false) + val subresult3 = result(0)(4) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]] + assert(subresult3.size === 2) + assert(subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size === 2) + val subresult4 = subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]] + assert(subresult4(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 7) + assert(subresult4(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 8) + assert(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size === 1) + assert(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 9) + } + + test("Simple query on addressbook") { + val data = TestSQLContext + .parquetFile(ParquetTestData.testNestedDir1.toString) + .toSchemaRDD + val tmp = data.where('owner === "Julien Le Dem").select('owner as 'a, 'contacts as 'c).collect() + assert(tmp.size === 1) + assert(tmp(0)(0) === "Julien Le Dem") + } + + test("Projection in addressbook") { + val data = nestedParserSqlContext + .parquetFile(ParquetTestData.testNestedDir1.toString) + .toSchemaRDD + data.registerAsTable("data") + val query = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM data") + val tmp = query.collect() + assert(tmp.size === 2) + assert(tmp(0).size === 2) + assert(tmp(0)(0) === "Julien Le Dem") + assert(tmp(0)(1) === "Chris Aniszczyk") + assert(tmp(1)(0) === "A. 
Nonymous") + assert(tmp(1)(1) === null) + } + + test("Simple query on nested int data") { + val data = nestedParserSqlContext + .parquetFile(ParquetTestData.testNestedDir2.toString) + .toSchemaRDD + data.registerAsTable("data") + val result1 = nestedParserSqlContext.sql("SELECT entries[0].value FROM data").collect() + assert(result1.size === 1) + assert(result1(0).size === 1) + assert(result1(0)(0) === 2.5) + val result2 = nestedParserSqlContext.sql("SELECT entries[0] FROM data").collect() + assert(result2.size === 1) + val subresult1 = result2(0)(0).asInstanceOf[CatalystConverter.StructScalaType[_]] + assert(subresult1.size === 2) + assert(subresult1(0) === 2.5) + assert(subresult1(1) === false) + val result3 = nestedParserSqlContext.sql("SELECT outerouter FROM data").collect() + val subresult2 = result3(0)(0) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]] + assert(subresult2(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 7) + assert(subresult2(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 8) + assert(result3(0)(0) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](1) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 9) + } + + test("nested structs") { + val data = nestedParserSqlContext + .parquetFile(ParquetTestData.testNestedDir3.toString) + .toSchemaRDD + data.registerAsTable("data") + val result1 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect() + assert(result1.size === 1) + assert(result1(0).size === 1) + assert(result1(0)(0) === false) + val result2 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[0].value[1].truth FROM data").collect() + assert(result2.size === 1) + assert(result2(0).size === 1) + assert(result2(0)(0) === true) + val result3 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[1].value[0].truth FROM data").collect() + assert(result3.size === 1) + assert(result3(0).size === 1) + assert(result3(0)(0) === false) + } + + test("simple map") { + val data = TestSQLContext + .parquetFile(ParquetTestData.testNestedDir4.toString) + .toSchemaRDD + data.registerAsTable("mapTable") + val result1 = sql("SELECT data1 FROM mapTable").collect() + assert(result1.size === 1) + assert(result1(0)(0) + .asInstanceOf[CatalystConverter.MapScalaType[String, _]] + .getOrElse("key1", 0) === 1) + assert(result1(0)(0) + .asInstanceOf[CatalystConverter.MapScalaType[String, _]] + .getOrElse("key2", 0) === 2) + val result2 = sql("""SELECT data1["key1"] FROM mapTable""").collect() + assert(result2(0)(0) === 1) + } + + test("map with struct values") { + val data = nestedParserSqlContext + .parquetFile(ParquetTestData.testNestedDir4.toString) + .toSchemaRDD + data.registerAsTable("mapTable") + val result1 = nestedParserSqlContext.sql("SELECT data2 FROM mapTable").collect() + assert(result1.size === 1) + val entry1 = result1(0)(0) + .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] + .getOrElse("seven", null) + assert(entry1 != null) + assert(entry1(0) === 42) + assert(entry1(1) === "the answer") + val entry2 = result1(0)(0) + .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] + .getOrElse("eight", null) + assert(entry2 != null) + assert(entry2(0) === 49) + assert(entry2(1) === null) + val result2 = nestedParserSqlContext.sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM 
mapTable""").collect() + assert(result2.size === 1) + assert(result2(0)(0) === 42.toLong) + assert(result2(0)(1) === "the answer") + } + + test("Writing out Addressbook and reading it back in") { + // TODO: find out why CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME + // has no effect in this test case + val tmpdir = Utils.createTempDir() + Utils.deleteRecursively(tmpdir) + val result = nestedParserSqlContext + .parquetFile(ParquetTestData.testNestedDir1.toString) + .toSchemaRDD + result.saveAsParquetFile(tmpdir.toString) + nestedParserSqlContext + .parquetFile(tmpdir.toString) + .toSchemaRDD + .registerAsTable("tmpcopy") + val tmpdata = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM tmpcopy").collect() + assert(tmpdata.size === 2) + assert(tmpdata(0).size === 2) + assert(tmpdata(0)(0) === "Julien Le Dem") + assert(tmpdata(0)(1) === "Chris Aniszczyk") + assert(tmpdata(1)(0) === "A. Nonymous") + assert(tmpdata(1)(1) === null) + Utils.deleteRecursively(tmpdir) + } + + test("Writing out Map and reading it back in") { + val data = nestedParserSqlContext + .parquetFile(ParquetTestData.testNestedDir4.toString) + .toSchemaRDD + val tmpdir = Utils.createTempDir() + Utils.deleteRecursively(tmpdir) + data.saveAsParquetFile(tmpdir.toString) + nestedParserSqlContext + .parquetFile(tmpdir.toString) + .toSchemaRDD + .registerAsTable("tmpmapcopy") + val result1 = nestedParserSqlContext.sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect() + assert(result1.size === 1) + assert(result1(0)(0) === 2) + val result2 = nestedParserSqlContext.sql("SELECT data2 FROM tmpmapcopy").collect() + assert(result2.size === 1) + val entry1 = result2(0)(0) + .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] + .getOrElse("seven", null) + assert(entry1 != null) + assert(entry1(0) === 42) + assert(entry1(1) === "the answer") + val entry2 = result2(0)(0) + .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] + .getOrElse("eight", null) + assert(entry2 != null) + assert(entry2(0) === 49) + assert(entry2(1) === null) + val result3 = nestedParserSqlContext.sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM tmpmapcopy""").collect() + assert(result3.size === 1) + assert(result3(0)(0) === 42.toLong) + assert(result3(0)(1) === "the answer") + Utils.deleteRecursively(tmpdir) + } +} + +// TODO: the code below is needed temporarily until the standard parser is able to parse +// nested field expressions correctly +class NestedParserSQLContext(@transient override val sparkContext: SparkContext) extends SQLContext(sparkContext) { + override protected[sql] val parser = new NestedSqlParser() +} + +class NestedSqlLexical(override val keywords: Seq[String]) extends SqlLexical(keywords) { + override def identChar = letter | elem('_') + delimiters += (".") +} + +class NestedSqlParser extends SqlParser { + override val lexical = new NestedSqlLexical(reservedWords) + + override protected lazy val baseExpression: PackratParser[Expression] = + expression ~ "[" ~ expression <~ "]" ^^ { + case base ~ _ ~ ordinal => GetItem(base, ordinal) + } | + expression ~ "." 
~ ident ^^ { + case base ~ _ ~ fieldName => GetField(base, fieldName) + } | + TRUE ^^^ Literal(true, BooleanType) | + FALSE ^^^ Literal(false, BooleanType) | + cast | + "(" ~> expression <~ ")" | + function | + "-" ~> literal ^^ UnaryMinus | + ident ^^ UnresolvedAttribute | + "*" ^^^ Star(None) | + literal } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index bf084584d41dd..7aedfcd74189b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.{Command => PhysicalCommand} +import org.apache.spark.sql.hive.execution.DescribeHiveTableCommand /** * Starts up an instance of hive where metadata is stored locally. An in-process metadata data is @@ -291,6 +292,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * execution is simply passed back to Hive. */ def stringResult(): Seq[String] = executedPlan match { + case describeHiveTableCommand: DescribeHiveTableCommand => + // If it is a describe command for a Hive table, we want to have the output format + // be similar with Hive. + describeHiveTableCommand.hiveString case command: PhysicalCommand => command.sideEffectResult.map(_.toString) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index af513dff189a6..9c483752f9dc0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -207,7 +207,9 @@ object HiveMetastoreTypes extends RegexParsers { } protected lazy val structType: Parser[DataType] = - "struct" ~> "<" ~> repsep(structField,",") <~ ">" ^^ StructType + "struct" ~> "<" ~> repsep(structField,",") <~ ">" ^^ { + case fields => new StructType(fields) + } protected lazy val dataType: Parser[DataType] = arrayType | diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 844673f66d103..b073dc3895f05 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -52,7 +52,6 @@ private[hive] case class AddFile(filePath: String) extends Command private[hive] object HiveQl { protected val nativeCommands = Seq( "TOK_DESCFUNCTION", - "TOK_DESCTABLE", "TOK_DESCDATABASE", "TOK_SHOW_TABLESTATUS", "TOK_SHOWDATABASES", @@ -120,6 +119,12 @@ private[hive] object HiveQl { "TOK_SWITCHDATABASE" ) + // Commands that we do not need to explain. + protected val noExplainCommands = Seq( + "TOK_CREATETABLE", + "TOK_DESCTABLE" + ) ++ nativeCommands + /** * A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations * similar to [[catalyst.trees.TreeNode]]. @@ -199,6 +204,9 @@ private[hive] object HiveQl { class ParseException(sql: String, cause: Throwable) extends Exception(s"Failed to parse: $sql", cause) + class SemanticException(msg: String) + extends Exception(s"Error in semantic analysis: $msg") + /** * Returns the AST for the given SQL string. 
*/ @@ -362,13 +370,20 @@ private[hive] object HiveQl { } } + protected def extractDbNameTableName(tableNameParts: Node): (Option[String], String) = { + val (db, tableName) = + tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match { + case Seq(tableOnly) => (None, tableOnly) + case Seq(databaseName, table) => (Some(databaseName), table) + } + + (db, tableName) + } + protected def nodeToPlan(node: Node): LogicalPlan = node match { // Just fake explain for any of the native commands. - case Token("TOK_EXPLAIN", explainArgs) if nativeCommands contains explainArgs.head.getText => - ExplainCommand(NoRelation) - // Create tables aren't native commands due to CTAS queries, but we still don't need to - // explain them. - case Token("TOK_EXPLAIN", explainArgs) if explainArgs.head.getText == "TOK_CREATETABLE" => + case Token("TOK_EXPLAIN", explainArgs) + if noExplainCommands.contains(explainArgs.head.getText) => ExplainCommand(NoRelation) case Token("TOK_EXPLAIN", explainArgs) => // Ignore FORMATTED if present. @@ -377,6 +392,39 @@ private[hive] object HiveQl { // TODO: support EXTENDED? ExplainCommand(nodeToPlan(query)) + case Token("TOK_DESCTABLE", describeArgs) => + // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL + val Some(tableType) :: formatted :: extended :: pretty :: Nil = + getClauses(Seq("TOK_TABTYPE", "FORMATTED", "EXTENDED", "PRETTY"), describeArgs) + if (formatted.isDefined || pretty.isDefined) { + // FORMATTED and PRETTY are not supported and this statement will be treated as + // a Hive native command. + NativePlaceholder + } else { + tableType match { + case Token("TOK_TABTYPE", nameParts) if nameParts.size == 1 => { + nameParts.head match { + case Token(".", dbName :: tableName :: Nil) => + // It is describing a table with the format like "describe db.table". + // TODO: Actually, a user may mean tableName.columnName. Need to resolve this issue. + val (db, tableName) = extractDbNameTableName(nameParts.head) + DescribeCommand( + UnresolvedRelation(db, tableName, None), extended.isDefined) + case Token(".", dbName :: tableName :: colName :: Nil) => + // It is describing a column with the format like "describe db.table column". + NativePlaceholder + case tableName => + // It is describing a table with the format like "describe table". + DescribeCommand( + UnresolvedRelation(None, tableName.getText, None), + extended.isDefined) + } + } + // All other cases. + case _ => NativePlaceholder + } + } + case Token("TOK_CREATETABLE", children) if children.collect { case t@Token("TOK_QUERY", _) => t }.nonEmpty => // TODO: Parse other clauses. @@ -414,11 +462,8 @@ private[hive] object HiveQl { s"Unhandled clauses: ${notImplemented.flatten.map(dumpTree(_)).mkString("\n")}") } - val (db, tableName) = - tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match { - case Seq(tableOnly) => (None, tableOnly) - case Seq(databaseName, table) => (Some(databaseName), table) - } + val (db, tableName) = extractDbNameTableName(tableNameParts) + InsertIntoCreatedTable(db, tableName, nodeToPlan(query)) // If its not a "CREATE TABLE AS" like above then just pass it back to hive as a native command. 
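// --- Illustrative sketch, not part of the patch --------------------------------------------
// The db/table split performed by the extractDbNameTableName helper added above, restated
// over plain strings instead of Hive AST Token nodes. `TableNameSketch` and
// `splitTableIdentifier` are hypothetical names used only for this example.
object TableNameSketch {
  def splitTableIdentifier(parts: Seq[String]): (Option[String], String) = parts match {
    case Seq(tableOnly) => (None, tableOnly)                      // e.g. DESCRIBE src
    case Seq(databaseName, table) => (Some(databaseName), table)  // e.g. DESCRIBE default.src
    case other => sys.error(s"Unexpected table identifier: ${other.mkString(".")}")
  }

  def main(args: Array[String]): Unit = {
    assert(splitTableIdentifier(Seq("src")) == (None -> "src"))
    assert(splitTableIdentifier(Seq("default", "src")) == (Some("default") -> "src"))
  }
}
// --------------------------------------------------------------------------------------------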
@@ -438,6 +483,7 @@ private[hive] object HiveQl { whereClause :: groupByClause :: orderByClause :: + havingClause :: sortByClause :: clusterByClause :: distributeByClause :: @@ -452,6 +498,7 @@ private[hive] object HiveQl { "TOK_WHERE", "TOK_GROUPBY", "TOK_ORDERBY", + "TOK_HAVING", "TOK_SORTBY", "TOK_CLUSTERBY", "TOK_DISTRIBUTEBY", @@ -516,7 +563,6 @@ private[hive] object HiveQl { withWhere) }.getOrElse(withWhere) - // The projection of the query can either be a normal projection, an aggregation // (if there is a group by) or a script transformation. val withProject = transformation.getOrElse { @@ -534,21 +580,28 @@ private[hive] object HiveQl { val withDistinct = if (selectDistinctClause.isDefined) Distinct(withProject) else withProject + val withHaving = havingClause.map { h => + val havingExpr = h.getChildren.toSeq match { case Seq(hexpr) => nodeToExpr(hexpr) } + // Note that we added a cast to boolean. If the expression itself is already boolean, + // the optimizer will get rid of the unnecessary cast. + Filter(Cast(havingExpr, BooleanType), withDistinct) + }.getOrElse(withDistinct) + val withSort = (orderByClause, sortByClause, distributeByClause, clusterByClause) match { case (Some(totalOrdering), None, None, None) => - Sort(totalOrdering.getChildren.map(nodeToSortOrder), withDistinct) + Sort(totalOrdering.getChildren.map(nodeToSortOrder), withHaving) case (None, Some(perPartitionOrdering), None, None) => - SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), withDistinct) + SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), withHaving) case (None, None, Some(partitionExprs), None) => - Repartition(partitionExprs.getChildren.map(nodeToExpr), withDistinct) + Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving) case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), - Repartition(partitionExprs.getChildren.map(nodeToExpr), withDistinct)) + Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving)) case (None, None, None, Some(clusterExprs)) => SortPartitions(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)), - Repartition(clusterExprs.getChildren.map(nodeToExpr), withDistinct)) - case (None, None, None, None) => withDistinct + Repartition(clusterExprs.getChildren.map(nodeToExpr), withHaving)) + case (None, None, None, None) => withHaving case _ => sys.error("Unsupported set of ordering / distribution clauses.") } @@ -656,7 +709,7 @@ private[hive] object HiveQl { val joinConditions = joinExpressions.sliding(2).map { case Seq(c1, c2) => - val predicates = (c1, c2).zipped.map { case (e1, e2) => Equals(e1, e2): Expression } + val predicates = (c1, c2).zipped.map { case (e1, e2) => EqualTo(e1, e2): Expression } predicates.reduceLeft(And) }.toBuffer @@ -736,11 +789,7 @@ private[hive] object HiveQl { val Some(tableNameParts) :: partitionClause :: Nil = getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) - val (db, tableName) = - tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match { - case Seq(tableOnly) => (None, tableOnly) - case Seq(databaseName, table) => (Some(databaseName), table) - } + val (db, tableName) = extractDbNameTableName(tableNameParts) val partitionKeys = partitionClause.map(_.getChildren.map { // Parse partitions. We also make keys case insensitive. 
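// --- Usage sketch, not part of the patch ----------------------------------------------------
// What the HAVING handling added above enables, mirroring the new HiveQuerySuite tests later
// in this diff. Assumes the TestHive helpers used by those tests are in scope and that the
// referenced tables exist. Because the HAVING expression is wrapped in Cast(..., BooleanType),
// a non-boolean predicate such as `HAVING c` is also accepted, and the optimizer removes the
// cast when the expression is already boolean.
import org.apache.spark.sql.hive.test.TestHive._

object HavingUsageSketch {
  def main(args: Array[String]): Unit = {
    hql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3").collect()
    hql("SELECT key, count(*) c FROM src GROUP BY key HAVING c").collect()
    hql("SELECT key FROM src HAVING key > 490").collect()  // HAVING without GROUP BY -> plain filter
  }
}
// --------------------------------------------------------------------------------------------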
@@ -886,9 +935,9 @@ private[hive] object HiveQl { case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right)) /* Comparisons */ - case Token("=", left :: right:: Nil) => Equals(nodeToExpr(left), nodeToExpr(right)) - case Token("!=", left :: right:: Nil) => Not(Equals(nodeToExpr(left), nodeToExpr(right))) - case Token("<>", left :: right:: Nil) => Not(Equals(nodeToExpr(left), nodeToExpr(right))) + case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) + case Token("!=", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) + case Token("<>", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right)) case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right)) case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right)) @@ -928,7 +977,7 @@ private[hive] object HiveQl { // FIXME (SPARK-2155): the key will get evaluated for multiple times in CaseWhen's eval(). // Hence effectful / non-deterministic key expressions are *not* supported at the moment. // We should consider adding new Expressions to get around this. - Seq(Equals(nodeToExpr(branches(0)), nodeToExpr(condVal)), + Seq(EqualTo(nodeToExpr(branches(0)), nodeToExpr(condVal)), nodeToExpr(value)) case Seq(elseVal) => Seq(nodeToExpr(elseVal)) }.toSeq.reduce(_ ++ _) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 0ac0ee9071f36..af7687b40429b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -81,6 +81,16 @@ private[hive] trait HiveStrategies { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.NativeCommand(sql) => NativeCommand(sql, plan.output)(context) :: Nil + case describe: logical.DescribeCommand => { + val resolvedTable = context.executePlan(describe.table).analyzed + resolvedTable match { + case t: MetastoreRelation => + Seq(DescribeHiveTableCommand( + t, describe.output, describe.isExtended)(context)) + case o: LogicalPlan => + Seq(DescribeCommand(planLater(o), describe.output)(context)) + } + } case _ => Nil } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala index a839231449161..2de2db28a7e04 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.hive.execution import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.MetaStoreUtils +import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.ql.Context import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Hive} +import org.apache.hadoop.hive.ql.metadata.formatting.MetaDataFormatUtils import org.apache.hadoop.hive.ql.plan.{TableDesc, FileSinkDesc} import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption @@ -344,12 +346,16 @@ case class InsertIntoHiveTable( writer.commitJob() } 
+ override def execute() = result + /** * Inserts all the rows in the table into Hive. Row objects are properly serialized with the * `org.apache.hadoop.hive.serde2.SerDe` and the * `org.apache.hadoop.mapred.OutputFormat` provided by the table definition. + * + * Note: this is run once and then kept to avoid double insertions. */ - def execute() = { + private lazy val result: RDD[Row] = { val childRdd = child.execute() assert(childRdd != null) @@ -367,12 +373,18 @@ case class InsertIntoHiveTable( ObjectInspectorCopyOption.JAVA) .asInstanceOf[StructObjectInspector] + + val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray + val outputData = new Array[Any](fieldOIs.length) iter.map { row => - // Casts Strings to HiveVarchars when necessary. - val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector) - val mappedRow = row.zip(fieldOIs).map(wrap) + var i = 0 + while (i < row.length) { + // Casts Strings to HiveVarchars when necessary. + outputData(i) = wrap(row(i), fieldOIs(i)) + i += 1 + } - serializer.serialize(mappedRow.toArray, standardOI) + serializer.serialize(outputData, standardOI) } } @@ -452,3 +464,61 @@ case class NativeCommand( override def otherCopyArgs = context :: Nil } + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class DescribeHiveTableCommand( + table: MetastoreRelation, + output: Seq[Attribute], + isExtended: Boolean)( + @transient context: HiveContext) + extends LeafNode with Command { + + // Strings with the format like Hive. It is used for result comparison in our unit tests. + lazy val hiveString: Seq[String] = { + val alignment = 20 + val delim = "\t" + + sideEffectResult.map { + case (name, dataType, comment) => + String.format("%-" + alignment + "s", name) + delim + + String.format("%-" + alignment + "s", dataType) + delim + + String.format("%-" + alignment + "s", Option(comment).getOrElse("None")) + } + } + + override protected[sql] lazy val sideEffectResult: Seq[(String, String, String)] = { + // Trying to mimic the format of Hive's output. But not exactly the same. 
+ var results: Seq[(String, String, String)] = Nil + + val columns: Seq[FieldSchema] = table.hiveQlTable.getCols + val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols + results ++= columns.map(field => (field.getName, field.getType, field.getComment)) + if (!partitionColumns.isEmpty) { + val partColumnInfo = + partitionColumns.map(field => (field.getName, field.getType, field.getComment)) + results ++= + partColumnInfo ++ + Seq(("# Partition Information", "", "")) ++ + Seq((s"# ${output.get(0).name}", output.get(1).name, output.get(2).name)) ++ + partColumnInfo + } + + if (isExtended) { + results ++= Seq(("Detailed Table Information", table.hiveQlTable.getTTable.toString, "")) + } + + results + } + + override def execute(): RDD[Row] = { + val rows = sideEffectResult.map { + case (name, dataType, comment) => new GenericRow(Array[Any](name, dataType, comment)) + } + context.sparkContext.parallelize(rows, 1) + } + + override def otherCopyArgs = context :: Nil +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 24c929ff7430d..08ef4d9b6bb93 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -144,6 +144,12 @@ abstract class HiveComparisonTest case _: SetCommand => Seq("0") case _: LogicalNativeCommand => answer.filterNot(nonDeterministicLine).filterNot(_ == "") case _: ExplainCommand => answer + case _: DescribeCommand => + // Filter out non-deterministic lines and lines which do not have actual results but + // can introduce problems because of the way Hive formats these lines. + // Then, remove empty lines. Do not sort the results. + answer.filterNot( + r => nonDeterministicLine(r) || ignoredLine(r)).map(_.trim).filterNot(_ == "") case plan => if (isSorted(plan)) answer else answer.sorted } orderedAnswer.map(cleanPaths) @@ -169,6 +175,16 @@ abstract class HiveComparisonTest protected def nonDeterministicLine(line: String) = nonDeterministicLineIndicators.exists(line contains _) + // This list contains indicators for those lines which do not have actual results and we + // want to ignore. + lazy val ignoredLineIndicators = Seq( + "# Partition Information", + "# col_name" + ) + + protected def ignoredLine(line: String) = + ignoredLineIndicators.exists(line contains _) + /** * Removes non-deterministic paths from `str` so cached answers will compare correctly. 
*/ @@ -329,11 +345,17 @@ abstract class HiveComparisonTest if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) && preparedHive != catalyst) { - val hivePrintOut = s"== HIVE - ${hive.size} row(s) ==" +: preparedHive + val hivePrintOut = s"== HIVE - ${preparedHive.size} row(s) ==" +: preparedHive val catalystPrintOut = s"== CATALYST - ${catalyst.size} row(s) ==" +: catalyst val resultComparison = sideBySide(hivePrintOut, catalystPrintOut).mkString("\n") + println("hive output") + hive.foreach(println) + + println("catalyst printout") + catalyst.foreach(println) + if (recomputeCache) { logger.warn(s"Clearing cache files for failed test $testCaseName") hiveCacheFiles.foreach(_.delete()) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index ee194dbcb77b2..cdfc2d0c17384 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -78,7 +78,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alter_merge", "alter_concatenate_indexed_table", "protectmode2", - "describe_table", + //"describe_table", "describe_comment_nonascii", "udf5", "udf_java_method", @@ -177,7 +177,16 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // After stop taking the `stringOrError` route, exceptions are thrown from these cases. // See SPARK-2129 for details. "join_view", - "mergejoins_mixed" + "mergejoins_mixed", + + // Returning the result of a describe state as a JSON object is not supported. + "describe_table_json", + "describe_database_json", + "describe_formatted_view_partitioned_json", + + // Hive returns the results of describe as plain text. Comments with multiple lines + // introduce extra lines in the Hive results, which make the result comparison fail. + "describe_comment_indent" ) /** @@ -292,11 +301,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "default_partition_name", "delimiter", "desc_non_existent_tbl", - "describe_comment_indent", - "describe_database_json", "describe_formatted_view_partitioned", - "describe_formatted_view_partitioned_json", - "describe_table_json", "diff_part_input_formats", "disable_file_format_check", "drop_function", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index fe698f0fc57b8..d855310253bf3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -21,13 +21,21 @@ import scala.util.Try import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{SchemaRDD, execution, Row} +import org.apache.spark.sql.{SchemaRDD, Row} + +case class TestData(a: Int, b: String) /** * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. 
*/ class HiveQuerySuite extends HiveComparisonTest { + test("CREATE TABLE AS runs once") { + hql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect() + assert(hql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1, + "Incorrect number of rows in created table") + } + createQueryTest("between", "SELECT * FROM src WHERE key Between 1 and 2") @@ -202,12 +210,9 @@ class HiveQuerySuite extends HiveComparisonTest { } } - private val explainCommandClassName = - classOf[execution.ExplainCommand].getSimpleName.stripSuffix("$") - def isExplanation(result: SchemaRDD) = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } - explanation.size > 1 && explanation.head.startsWith(explainCommandClassName) + explanation.size > 1 && explanation.head.startsWith("Physical execution plan") } test("SPARK-1704: Explain commands as a SchemaRDD") { @@ -219,6 +224,27 @@ class HiveQuerySuite extends HiveComparisonTest { TestHive.reset() } + test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") { + val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3)) + .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)} + TestHive.sparkContext.parallelize(fixture).registerAsTable("having_test") + val results = + hql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") + .collect() + .map(x => Pair(x.getString(0), x.getInt(1))) + + assert(results === Array(Pair("foo", 4))) + TestHive.reset() + } + + test("SPARK-2180: HAVING with non-boolean clause raises no exceptions") { + hql("select key, count(*) c from src group by key having c").collect() + } + + test("SPARK-2225: turn HAVING without GROUP BY into a simple filter") { + assert(hql("select key from src having key > 490").collect().size < 100) + } + test("Query Hive native command execution result") { val tableName = "test_native_commands" @@ -237,13 +263,6 @@ class HiveQuerySuite extends HiveComparisonTest { .map(_.getString(0)) .contains(tableName)) - assertResult(Array(Array("key", "int", "None"), Array("value", "string", "None"))) { - hql(s"DESCRIBE $tableName") - .select('result) - .collect() - .map(_.getString(0).split("\t").map(_.trim)) - } - assert(isExplanation(hql(s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key"))) TestHive.reset() @@ -260,6 +279,97 @@ class HiveQuerySuite extends HiveComparisonTest { assert(Try(q0.count()).isSuccess) } + test("DESCRIBE commands") { + hql(s"CREATE TABLE test_describe_commands1 (key INT, value STRING) PARTITIONED BY (dt STRING)") + + hql( + """FROM src INSERT OVERWRITE TABLE test_describe_commands1 PARTITION (dt='2008-06-08') + |SELECT key, value + """.stripMargin) + + // Describe a table + assertResult( + Array( + Array("key", "int", null), + Array("value", "string", null), + Array("dt", "string", null), + Array("# Partition Information", "", ""), + Array("# col_name", "data_type", "comment"), + Array("dt", "string", null)) + ) { + hql("DESCRIBE test_describe_commands1") + .select('col_name, 'data_type, 'comment) + .collect() + } + + // Describe a table with a fully qualified table name + assertResult( + Array( + Array("key", "int", null), + Array("value", "string", null), + Array("dt", "string", null), + Array("# Partition Information", "", ""), + Array("# col_name", "data_type", "comment"), + Array("dt", "string", null)) + ) { + hql("DESCRIBE default.test_describe_commands1") + .select('col_name, 'data_type, 'comment) + .collect() + } + + // Describe a column is a native command + 
assertResult(Array(Array("value", "string", "from deserializer"))) { + hql("DESCRIBE test_describe_commands1 value") + .select('result) + .collect() + .map(_.getString(0).split("\t").map(_.trim)) + } + + // Describe a column is a native command + assertResult(Array(Array("value", "string", "from deserializer"))) { + hql("DESCRIBE default.test_describe_commands1 value") + .select('result) + .collect() + .map(_.getString(0).split("\t").map(_.trim)) + } + + // Describe a partition is a native command + assertResult( + Array( + Array("key", "int", "None"), + Array("value", "string", "None"), + Array("dt", "string", "None"), + Array("", "", ""), + Array("# Partition Information", "", ""), + Array("# col_name", "data_type", "comment"), + Array("", "", ""), + Array("dt", "string", "None")) + ) { + hql("DESCRIBE test_describe_commands1 PARTITION (dt='2008-06-08')") + .select('result) + .collect() + .map(_.getString(0).split("\t").map(_.trim)) + } + + // Describe a registered temporary table. + val testData: SchemaRDD = + TestHive.sparkContext.parallelize( + TestData(1, "str1") :: + TestData(1, "str2") :: Nil) + testData.registerAsTable("test_describe_commands2") + + assertResult( + Array( + Array("# Registered as a temporary table", null, null), + Array("a", "IntegerType", null), + Array("b", "StringType", null)) + ) { + hql("DESCRIBE test_describe_commands2") + .select('col_name, 'data_type, 'comment) + .collect() + } + } + test("parse HQL set commands") { // Adapted from its SQL counterpart. val testKey = "spark.sql.key.usedfortestonly" @@ -352,3 +462,6 @@ class HiveQuerySuite extends HiveComparisonTest { // since they modify /clear stuff. } + +// for SPARK-2180 test +case class HavingRow(key: Int, value: String, attr: Int) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index e030c8ee3dfc8..7436de264a1e1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo} +import org.apache.spark.sql.execution.Project +import org.apache.spark.sql.hive.test.TestHive + /** - * A set of tests that validate type promotion rules. + * A set of tests that validate type promotion and coercion rules. 
*/ class HiveTypeCoercionSuite extends HiveComparisonTest { val baseTypes = Seq("1", "1.0", "1L", "1S", "1Y", "'1'") @@ -28,4 +32,23 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { createQueryTest(s"$i + $j", s"SELECT $i + $j FROM src LIMIT 1") } } + + test("[SPARK-2210] boolean cast on boolean value should be removed") { + val q = "select cast(cast(key=0 as boolean) as boolean) from src" + val project = TestHive.hql(q).queryExecution.executedPlan.collect { case e: Project => e }.head + + // No cast expression introduced + project.transformAllExpressions { case c: Cast => + fail(s"unexpected cast $c") + c + } + + // Only one equality check + var numEquals = 0 + project.transformAllExpressions { case e: EqualTo => + numEquals += 1 + e + } + assert(numEquals === 1) + } } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index fd3ef9e1fa2de..62f9b3cf5ab88 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -21,8 +21,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap} import org.apache.spark.SparkConf import org.apache.spark.scheduler.InputFormatInfo -import org.apache.spark.util.IntParam -import org.apache.spark.util.MemoryParam +import org.apache.spark.util.{Utils, IntParam, MemoryParam} // TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware ! @@ -45,6 +44,18 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { parseArgs(args.toList) + // env variable SPARK_YARN_DIST_ARCHIVES/SPARK_YARN_DIST_FILES set in yarn-client then + // it should default to hdfs:// + files = Option(files).getOrElse(sys.env.get("SPARK_YARN_DIST_FILES").orNull) + archives = Option(archives).getOrElse(sys.env.get("SPARK_YARN_DIST_ARCHIVES").orNull) + + // spark.yarn.dist.archives/spark.yarn.dist.files defaults to use file:// if not specified, + // for both yarn-client and yarn-cluster + files = Option(files).getOrElse(sparkConf.getOption("spark.yarn.dist.files"). + map(p => Utils.resolveURIs(p)).orNull) + archives = Option(archives).getOrElse(sparkConf.getOption("spark.yarn.dist.archives"). + map(p => Utils.resolveURIs(p)).orNull) + private def parseArgs(inputArgs: List[String]): Unit = { val userArgsBuffer: ArrayBuffer[String] = new ArrayBuffer[String]() val inputFormatMap: HashMap[String, InputFormatInfo] = new HashMap[String, InputFormatInfo]() diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 858bcaa95b409..8f2267599914c 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -162,7 +162,7 @@ trait ClientBase extends Logging { val fs = FileSystem.get(conf) val remoteFs = originalPath.getFileSystem(conf) var newPath = originalPath - if (! 
compareFs(remoteFs, fs)) { + if (!compareFs(remoteFs, fs)) { newPath = new Path(dstDir, originalPath.getName()) logInfo("Uploading " + originalPath + " to " + newPath) FileUtil.copy(remoteFs, originalPath, fs, newPath, false, conf) @@ -250,6 +250,7 @@ trait ClientBase extends Logging { } } } + logInfo("Prepared Local resources " + localResources) sparkConf.set(ClientBase.CONF_SPARK_YARN_SECONDARY_JARS, cachedSecondaryJarLinks.mkString(",")) UserGroupInformation.getCurrentUser().addCredentials(credentials) diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 039cf4f276119..412dfe38d55eb 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -70,9 +70,7 @@ private[spark] class YarnClientSchedulerBackend( ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), ("--executor-cores", "SPARK_EXECUTOR_CORES", "spark.executor.cores"), ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue"), - ("--name", "SPARK_YARN_APP_NAME", "spark.app.name"), - ("--files", "SPARK_YARN_DIST_FILES", "spark.yarn.dist.files"), - ("--archives", "SPARK_YARN_DIST_ARCHIVES", "spark.yarn.dist.archives")) + ("--name", "SPARK_YARN_APP_NAME", "spark.app.name")) .foreach { case (optName, envVar, sysProp) => addArg(optName, envVar, sysProp, argsArrayBuf) } logDebug("ClientArguments called with: " + argsArrayBuf)
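// --- Illustrative sketch, not part of the patch ---------------------------------------------
// The precedence the ClientArguments change above gives to distributed files/archives in YARN
// mode: an explicit --files/--archives value wins, then the SPARK_YARN_DIST_* environment
// variable, then the spark.yarn.dist.* configuration entry (only the last one goes through
// Utils.resolveURIs). Plain Maps stand in for the real environment and SparkConf lookups;
// `DistFilesPrecedenceSketch` and `resolve` are hypothetical names used only here.
object DistFilesPrecedenceSketch {
  def resolve(
      cliValue: Option[String],
      env: Map[String, String],
      conf: Map[String, String],
      envKey: String,
      confKey: String): Option[String] = {
    cliValue
      .orElse(env.get(envKey))
      .orElse(conf.get(confKey))  // ClientArguments additionally applies Utils.resolveURIs here
  }

  def main(args: Array[String]): Unit = {
    val env = Map("SPARK_YARN_DIST_FILES" -> "hdfs:///user/test/a.txt")
    val conf = Map("spark.yarn.dist.files" -> "/tmp/b.txt")
    // No --files flag given: the environment variable takes precedence over the conf entry.
    assert(resolve(None, env, conf, "SPARK_YARN_DIST_FILES", "spark.yarn.dist.files") ==
      Some("hdfs:///user/test/a.txt"))
    // An explicit --files value always wins.
    assert(resolve(Some("local.txt"), env, conf, "SPARK_YARN_DIST_FILES", "spark.yarn.dist.files") ==
      Some("local.txt"))
  }
}
// ---------------------------------------------------------------------------------------------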