diff --git a/.rat-excludes b/.rat-excludes index 52b2dfac5cf2b..15344dfb292db 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -22,6 +22,7 @@ spark-env.sh.template log4j-defaults.properties sorttable.js .*txt +.*json .*data .*log cloudpickle.py diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 599c3ac9b57c0..a8bc141208a94 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -87,3 +87,24 @@ span.kill-link { span.kill-link a { color: gray; } + +span.expand-details { + font-size: 10pt; + cursor: pointer; + color: grey; + float: right; +} + +.stage-details { + max-height: 100px; + overflow-y: auto; + margin: 0; + transition: max-height 0.5s ease-out, padding 0.5s ease-out; +} + +.stage-details.collapsed { + max-height: 0; + padding-top: 0; + padding-bottom: 0; + border: none; +} diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 35970c2f50892..f9476ff826a62 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -49,7 +49,7 @@ import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, Me import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils} +import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils} /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -224,7 +224,6 @@ class SparkContext(config: SparkConf) extends Logging { /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration: Configuration = { - val env = SparkEnv.get val hadoopConf = SparkHadoopUtil.get.newConfiguration() // Explicitly check for S3 environment variables if (System.getenv("AWS_ACCESS_KEY_ID") != null && @@ -1036,9 +1035,11 @@ class SparkContext(config: SparkConf) extends Logging { * Capture the current user callsite and return a formatted version for printing. If the user * has overridden the call site, this will return the user's version. 
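// Illustrative sketch (not part of the patch): with this change a call site carries a
// short form (used in log lines and stage names) and a long form (a user-code stack
// trace shown in the web UI). The "externalCallSite" local property still overrides
// the short form and leaves the long form empty. `sc` and the values are hypothetical:
sc.setLocalProperty("externalCallSite", "loadUsers at MyApp.scala:42")
sc.parallelize(1 to 10).count()   // logs "Starting job: loadUsers at MyApp.scala:42"
sc.setLocalProperty("externalCallSite", null)
sc.parallelize(1 to 10).count()   // logs e.g. "Starting job: count at MyApp.scala:45"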
*/ - private[spark] def getCallSite(): String = { - val defaultCallSite = Utils.getCallSiteInfo - Option(getLocalProperty("externalCallSite")).getOrElse(defaultCallSite.toString) + private[spark] def getCallSite(): CallSite = { + Option(getLocalProperty("externalCallSite")) match { + case Some(callSite) => CallSite(callSite, long = "") + case None => Utils.getCallSite + } } /** @@ -1058,11 +1059,11 @@ class SparkContext(config: SparkConf) extends Logging { } val callSite = getCallSite val cleanedFunc = clean(func) - logInfo("Starting job: " + callSite) + logInfo("Starting job: " + callSite.short) val start = System.nanoTime dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, resultHandler, localProperties.get) - logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") + logInfo("Job finished: " + callSite.short + ", took " + (System.nanoTime - start) / 1e9 + " s") rdd.doCheckpoint() } @@ -1143,11 +1144,11 @@ class SparkContext(config: SparkConf) extends Logging { evaluator: ApproximateEvaluator[U, R], timeout: Long): PartialResult[R] = { val callSite = getCallSite - logInfo("Starting job: " + callSite) + logInfo("Starting job: " + callSite.short) val start = System.nanoTime val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.get) - logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") + logInfo("Job finished: " + callSite.short + ", took " + (System.nanoTime - start) / 1e9 + " s") result } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 14fa9d8135afe..4f3081433a542 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -543,6 +543,18 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) partitioner: Partitioner): JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] = fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, partitioner))) + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. + */ + def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1], + other2: JavaPairRDD[K, W2], + other3: JavaPairRDD[K, W3], + partitioner: Partitioner) + : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] = + fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3, partitioner))) + /** * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. @@ -558,6 +570,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] = fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2))) + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. 
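// Illustrative sketch (not part of the patch): these Java overloads mirror the new
// Scala PairRDDFunctions.cogroup over three other RDDs added later in this diff. Every
// key present in any of the four RDDs appears once, paired with four Iterables (empty
// where the key is missing). The data and `sc` below are hypothetical:
import org.apache.spark.SparkContext._   // for rddToPairRDDFunctions
val categories = sc.parallelize(Seq(("Apples", "Fruit"), ("Oranges", "Citrus")))
val prices     = sc.parallelize(Seq(("Apples", 3), ("Oranges", 2)))
val quantities = sc.parallelize(Seq(("Apples", 42)))
val countries  = sc.parallelize(Seq(("Oranges", "BR")))
val grouped = categories.cogroup(prices, quantities, countries)
// grouped: RDD[(String, (Iterable[String], Iterable[Int], Iterable[Int], Iterable[String]))]
// e.g. ("Oranges", (Iterable("Citrus"), Iterable(2), Iterable(), Iterable("BR")))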
+ */ + def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1], + other2: JavaPairRDD[K, W2], + other3: JavaPairRDD[K, W3]) + : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] = + fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3))) + /** * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. @@ -574,6 +597,18 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] = fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numPartitions))) + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. + */ + def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1], + other2: JavaPairRDD[K, W2], + other3: JavaPairRDD[K, W3], + numPartitions: Int) + : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] = + fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3, numPartitions))) + /** Alias for cogroup. */ def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JIterable[V], JIterable[W])] = fromRDD(cogroupResultToJava(rdd.groupWith(other))) @@ -583,6 +618,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] = fromRDD(cogroupResult2ToJava(rdd.groupWith(other1, other2))) + /** Alias for cogroup. */ + def groupWith[W1, W2, W3](other1: JavaPairRDD[K, W1], + other2: JavaPairRDD[K, W2], + other3: JavaPairRDD[K, W3]) + : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] = + fromRDD(cogroupResult3ToJava(rdd.groupWith(other1, other2, other3))) + /** * Return the list of values in the RDD for key `key`. This operation is done efficiently if the * RDD has a known partitioner by only searching the partition that the key maps to. @@ -786,6 +828,15 @@ object JavaPairRDD { .mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3))) } + private[spark] + def cogroupResult3ToJava[K: ClassTag, V, W1, W2, W3]( + rdd: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))]) + : RDD[(K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3]))] = { + rddToPairRDDFunctions(rdd) + .mapValues(x => + (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3), asJavaIterable(x._4))) + } + def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = { new JavaPairRDD[K, V](rdd) } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 23d13710794af..86fb374bef1e3 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -17,10 +17,13 @@ package org.apache.spark.api.java +import java.util.Comparator + import scala.language.implicitConversions import scala.reflect.ClassTag import org.apache.spark._ +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -172,6 +175,19 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) rdd.setName(name) this } + + /** + * Return this RDD sorted by the given key function. 
+ */ + def sortBy[S](f: JFunction[T, S], ascending: Boolean, numPartitions: Int): JavaRDD[T] = { + import scala.collection.JavaConverters._ + def fn = (x: T) => f.call(x) + import com.google.common.collect.Ordering // shadows scala.math.Ordering + implicit val ordering = Ordering.natural().asInstanceOf[Ordering[S]] + implicit val ctag: ClassTag[S] = fakeClassTag + wrapRDD(rdd.sortBy(fn, ascending, numPartitions)) + } + } object JavaRDD { diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 5da9615c9e9af..39150deab863c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -21,6 +21,8 @@ import scala.collection.mutable.ListBuffer import org.apache.log4j.Level +import org.apache.spark.util.MemoryParam + /** * Command-line parser for the driver client. */ @@ -51,8 +53,8 @@ private[spark] class ClientArguments(args: Array[String]) { cores = value.toInt parse(tail) - case ("--memory" | "-m") :: value :: tail => - memory = value.toInt + case ("--memory" | "-m") :: MemoryParam(value) :: tail => + memory = value parse(tail) case ("--supervise" | "-s") :: tail => diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala index 37dfa7fec0831..9f34d01e6db48 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala @@ -19,9 +19,9 @@ package org.apache.spark.deploy private[spark] object ExecutorState extends Enumeration { - val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST = Value + val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST, EXITED = Value type ExecutorState = Value - def isFinished(state: ExecutorState): Boolean = Seq(KILLED, FAILED, LOST).contains(state) + def isFinished(state: ExecutorState): Boolean = Seq(KILLED, FAILED, LOST, EXITED).contains(state) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index 46b9f4dc7d3ba..72d0589689e71 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -20,6 +20,7 @@ package org.apache.spark.deploy.master import java.util.Date import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import akka.actor.ActorRef @@ -36,6 +37,7 @@ private[spark] class ApplicationInfo( @transient var state: ApplicationState.Value = _ @transient var executors: mutable.HashMap[Int, ExecutorInfo] = _ + @transient var removedExecutors: ArrayBuffer[ExecutorInfo] = _ @transient var coresGranted: Int = _ @transient var endTime: Long = _ @transient var appSource: ApplicationSource = _ @@ -51,6 +53,7 @@ private[spark] class ApplicationInfo( endTime = -1L appSource = new ApplicationSource(this) nextExecutorId = 0 + removedExecutors = new ArrayBuffer[ExecutorInfo] } private def newExecutorId(useID: Option[Int] = None): Int = { @@ -74,6 +77,7 @@ private[spark] class ApplicationInfo( def removeExecutor(exec: ExecutorInfo) { if (executors.contains(exec.id)) { + removedExecutors += executors(exec.id) executors -= exec.id coresGranted -= exec.cores } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala 
b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala index 76db61dd619c6..d417070c51016 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala @@ -34,4 +34,19 @@ private[spark] class ExecutorInfo( } def fullId: String = application.id + "/" + id + + override def equals(other: Any): Boolean = { + other match { + case info: ExecutorInfo => + fullId == info.fullId && + worker.id == info.worker.id && + cores == info.cores && + memory == info.memory + case _ => false + } + } + + override def toString: String = fullId + + override def hashCode: Int = toString.hashCode() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index c6dec305bffcb..33ffcbd216954 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -303,10 +303,11 @@ private[spark] class Master( appInfo.removeExecutor(exec) exec.worker.removeExecutor(exec) + val normalExit = exitStatus.exists(_ == 0) // Only retry certain number of times so we don't go into an infinite loop. - if (appInfo.incrementRetryCount < ApplicationState.MAX_NUM_RETRY) { + if (!normalExit && appInfo.incrementRetryCount < ApplicationState.MAX_NUM_RETRY) { schedule() - } else { + } else if (!normalExit) { logError("Application %s with ID %s failed %d times, removing it".format( appInfo.desc.name, appInfo.id, appInfo.retryCount)) removeApplication(appInfo, ApplicationState.FAILED) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index b5cd4d2ea963f..34fa1429c86de 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -25,7 +25,7 @@ import scala.xml.Node import akka.pattern.ask import org.json4s.JValue -import org.apache.spark.deploy.JsonProtocol +import org.apache.spark.deploy.{ExecutorState, JsonProtocol} import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master.ExecutorInfo import org.apache.spark.ui.{WebUIPage, UIUtils} @@ -57,43 +57,55 @@ private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app }) val executorHeaders = Seq("ExecutorID", "Worker", "Cores", "Memory", "State", "Logs") - val executors = app.executors.values.toSeq - val executorTable = UIUtils.listingTable(executorHeaders, executorRow, executors) + val allExecutors = (app.executors.values ++ app.removedExecutors).toSet.toSeq + // This includes executors that are either still running or have exited cleanly + val executors = allExecutors.filter { exec => + !ExecutorState.isFinished(exec.state) || exec.state == ExecutorState.EXITED + } + val removedExecutors = allExecutors.diff(executors) + val executorsTable = UIUtils.listingTable(executorHeaders, executorRow, executors) + val removedExecutorsTable = UIUtils.listingTable(executorHeaders, executorRow, removedExecutors) val content = -
      [Scala XML literal; markup lost in extraction. The hunk rebuilds the page body:
      an application-properties list (ID {app.id}, Name {app.desc.name}, User
      {app.desc.user}, Cores — "Unlimited (%s granted)" when app.desc.maxCores is empty,
      otherwise "%s (%s granted, %s left)" — Executor Memory
      {Utils.megabytesToString(app.desc.memoryPerSlave)}, Submit Date {app.submitDate},
      State {app.state}, and the Application Detail UI link) followed by an
      "Executor Summary" section. The old body rendered {executorTable} there; the new
      body renders {executorsTable} and, when removedExecutors.nonEmpty, appends a
      "Removed Executors" heading followed by {removedExecutorsTable}.]
; UIUtils.basicSparkPage(content, "Application: " + app.desc.name) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index d09136de49807..6433aac1c23e0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -154,11 +154,10 @@ private[spark] class ExecutorRunner( Files.write(header, stderr, Charsets.UTF_8) stderrAppender = FileAppender(process.getErrorStream, stderr, conf) - // Wait for it to exit; this is actually a bad thing if it happens, because we expect to run - // long-lived processes only. However, in the future, we might restart the executor a few - // times on the same machine. + // Wait for it to exit; executor may exit with code 0 (when driver instructs it to shutdown) + // or with nonzero exit code val exitCode = process.waitFor() - state = ExecutorState.FAILED + state = ExecutorState.EXITED val message = "Command exited with code " + exitCode worker ! ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode)) } catch { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index d4513118ced05..327b905032800 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -46,74 +46,62 @@ private[spark] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] val workerState = Await.result(stateFuture, timeout) - val executorHeaders = Seq("ExecutorID", "Cores", "Memory", "Job Details", "Logs") + val executorHeaders = Seq("ExecutorID", "Cores", "State", "Memory", "Job Details", "Logs") + val runningExecutors = workerState.executors val runningExecutorTable = - UIUtils.listingTable(executorHeaders, executorRow, workerState.executors) + UIUtils.listingTable(executorHeaders, executorRow, runningExecutors) + val finishedExecutors = workerState.finishedExecutors val finishedExecutorTable = - UIUtils.listingTable(executorHeaders, executorRow, workerState.finishedExecutors) + UIUtils.listingTable(executorHeaders, executorRow, finishedExecutors) val driverHeaders = Seq("DriverID", "Main Class", "State", "Cores", "Memory", "Logs", "Notes") val runningDrivers = workerState.drivers.sortBy(_.driverId).reverse val runningDriverTable = UIUtils.listingTable(driverHeaders, driverRow, runningDrivers) val finishedDrivers = workerState.finishedDrivers.sortBy(_.driverId).reverse - def finishedDriverTable = UIUtils.listingTable(driverHeaders, driverRow, finishedDrivers) + val finishedDriverTable = UIUtils.listingTable(driverHeaders, driverRow, finishedDrivers) // For now we only show driver information if the user has submitted drivers to the cluster. // This is until we integrate the notion of drivers and applications in the UI. - def hasDrivers = runningDrivers.length > 0 || finishedDrivers.length > 0 val content = -
      [Scala XML literal; markup lost in extraction. The hunk rebuilds the page body.
      Both versions start with a worker-details list (ID {workerState.workerId},
      Master URL {workerState.masterUrl}, Cores {workerState.cores}
      ({workerState.coresUsed} Used), Memory {Utils.megabytesToString(workerState.memory)}
      ({Utils.megabytesToString(workerState.memoryUsed)} Used)) and a "Back to Master"
      link. The old body then emitted separate sections — "Running Executors
      {workerState.executors.size}" with {runningExecutorTable}, "Running Drivers
      {workerState.drivers.size}" with {runningDriverTable} and "Finished Drivers" with
      {finishedDriverTable} (the driver sections only if hasDrivers, wrapped in
      scalastyle:off/on), and "Finished Executors" with {finishedExecutorTable}. The new
      body uses one section that always shows "Running Executors
      ({runningExecutors.size})" with {runningExecutorTable}, then conditionally appends
      "Running Drivers ({runningDrivers.size})" with {runningDriverTable},
      "Finished Executors ({finishedExecutors.size})" with {finishedExecutorTable} and
      "Finished Drivers ({finishedDrivers.size})" with {finishedDriverTable} when each
      list is nonEmpty.]
; UIUtils.basicSparkPage(content, "Spark Worker at %s:%s".format( workerState.host, workerState.port)) } @@ -122,6 +110,7 @@ private[spark] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { {executor.execId} {executor.cores} + {executor.state} {Utils.megabytesToString(executor.memory)} diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index b6ad9b6c3e168..443d1c587c3ee 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -567,6 +567,28 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) new FlatMappedValuesRDD(self, cleanF) } + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. + */ + def cogroup[W1, W2, W3](other1: RDD[(K, W1)], + other2: RDD[(K, W2)], + other3: RDD[(K, W3)], + partitioner: Partitioner) + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) { + throw new SparkException("Default partitioner cannot partition array keys.") + } + val cg = new CoGroupedRDD[K](Seq(self, other1, other2, other3), partitioner) + cg.mapValues { case Seq(vs, w1s, w2s, w3s) => + (vs.asInstanceOf[Seq[V]], + w1s.asInstanceOf[Seq[W1]], + w2s.asInstanceOf[Seq[W2]], + w3s.asInstanceOf[Seq[W3]]) + } + } + /** * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. @@ -599,6 +621,16 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } } + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. + */ + def cogroup[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)]) + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3)) + } + /** * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. @@ -633,6 +665,19 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) cogroup(other1, other2, new HashPartitioner(numPartitions)) } + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. + */ + def cogroup[W1, W2, W3](other1: RDD[(K, W1)], + other2: RDD[(K, W2)], + other3: RDD[(K, W3)], + numPartitions: Int) + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + cogroup(other1, other2, other3, new HashPartitioner(numPartitions)) + } + /** Alias for cogroup. */ def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] = { cogroup(other, defaultPartitioner(self, other)) @@ -644,6 +689,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) cogroup(other1, other2, defaultPartitioner(self, other1, other2)) } + /** Alias for cogroup. 
*/ + def groupWith[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)]) + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3)) + } + /** * Return an RDD with the pairs from `this` whose keys are not in `other`. * @@ -787,8 +838,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val outfmt = job.getOutputFormatClass val jobFormat = outfmt.newInstance - if (self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true) && - jobFormat.isInstanceOf[NewFileOutputFormat[_, _]]) { + if (self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true)) { // FileOutputFormat ignores the filesystem parameter jobFormat.checkOutputSpecs(job) } @@ -854,8 +904,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " + valueClass.getSimpleName + ")") - if (self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true) && - outputFormatInstance.isInstanceOf[FileOutputFormat[_, _]]) { + if (self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true)) { // FileOutputFormat ignores the filesystem parameter val ignoredFs = FileSystem.get(conf) conf.getOutputFormat.checkOutputSpecs(ignoredFs, conf) diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index 2425929fc73c5..66c71bf7e8bb5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -117,6 +117,15 @@ private object ParallelCollectionRDD { if (numSlices < 1) { throw new IllegalArgumentException("Positive number of slices required") } + // Sequences need to be sliced at the same set of index positions for operations + // like RDD.zip() to behave as expected + def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = { + (0 until numSlices).iterator.map(i => { + val start = ((i * length) / numSlices).toInt + val end = (((i + 1) * length) / numSlices).toInt + (start, end) + }) + } seq match { case r: Range.Inclusive => { val sign = if (r.step < 0) { @@ -128,18 +137,17 @@ private object ParallelCollectionRDD { r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices) } case r: Range => { - (0 until numSlices).map(i => { - val start = ((i * r.length.toLong) / numSlices).toInt - val end = (((i + 1) * r.length.toLong) / numSlices).toInt - new Range(r.start + start * r.step, r.start + end * r.step, r.step) - }).asInstanceOf[Seq[Seq[T]]] + positions(r.length, numSlices).map({ + case (start, end) => + new Range(r.start + start * r.step, r.start + end * r.step, r.step) + }).toSeq.asInstanceOf[Seq[Seq[T]]] } case nr: NumericRange[_] => { // For ranges of Long, Double, BigInteger, etc val slices = new ArrayBuffer[Seq[T]](numSlices) - val sliceSize = (nr.size + numSlices - 1) / numSlices // Round up to catch everything var r = nr - for (i <- 0 until numSlices) { + for ((start, end) <- positions(nr.length, numSlices)) { + val sliceSize = end - start slices += r.take(sliceSize).asInstanceOf[Seq[T]] r = r.drop(sliceSize) } @@ -147,11 +155,10 @@ private object ParallelCollectionRDD { } case _ => { val array = seq.toArray // To prevent O(n^2) operations for List etc - (0 until numSlices).map(i => { - val start = ((i * array.length.toLong) / numSlices).toInt - val end = (((i + 1) * array.length.toLong) / numSlices).toInt - array.slice(start, end).toSeq - }) + 
positions(array.length, numSlices).map({ + case (start, end) => + array.slice(start, end).toSeq + }).toSeq } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 446f369c9ea16..cebfd109d825f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -40,7 +40,7 @@ import org.apache.spark.partial.CountEvaluator import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{BoundedPriorityQueue, Utils} +import org.apache.spark.util.{BoundedPriorityQueue, CallSite, Utils} import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils} @@ -442,6 +442,18 @@ abstract class RDD[T: ClassTag]( */ def ++(other: RDD[T]): RDD[T] = this.union(other) + /** + * Return this RDD sorted by the given key function. + */ + def sortBy[K]( + f: (T) => K, + ascending: Boolean = true, + numPartitions: Int = this.partitions.size) + (implicit ord: Ordering[K], ctag: ClassTag[K]): RDD[T] = + this.keyBy[K](f) + .sortByKey(ascending, numPartitions) + .values + /** * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. @@ -1062,11 +1074,11 @@ abstract class RDD[T: ClassTag]( * Returns the top K (largest) elements from this RDD as defined by the specified * implicit Ordering[T]. This does the opposite of [[takeOrdered]]. For example: * {{{ - * sc.parallelize([10, 4, 2, 12, 3]).top(1) - * // returns [12] + * sc.parallelize(Seq(10, 4, 2, 12, 3)).top(1) + * // returns Array(12) * - * sc.parallelize([2, 3, 4, 5, 6]).top(2) - * // returns [6, 5] + * sc.parallelize(Seq(2, 3, 4, 5, 6)).top(2) + * // returns Array(6, 5) * }}} * * @param num the number of top elements to return @@ -1080,11 +1092,11 @@ abstract class RDD[T: ClassTag]( * implicit Ordering[T] and maintains the ordering. This does the opposite of [[top]]. * For example: * {{{ - * sc.parallelize([10, 4, 2, 12, 3]).takeOrdered(1) - * // returns [12] + * sc.parallelize(Seq(10, 4, 2, 12, 3)).takeOrdered(1) + * // returns Array(2) * - * sc.parallelize([2, 3, 4, 5, 6]).takeOrdered(2) - * // returns [2, 3] + * sc.parallelize(Seq(2, 3, 4, 5, 6)).takeOrdered(2) + * // returns Array(2, 3) * }}} * * @param num the number of top elements to return @@ -1189,8 +1201,8 @@ abstract class RDD[T: ClassTag]( private var storageLevel: StorageLevel = StorageLevel.NONE /** User code that created this RDD (e.g. `textFile`, `parallelize`). 
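// Illustrative sketch (not part of the patch): the sortBy added above is just
// keyBy(f).sortByKey(ascending, numPartitions).values, so any key function with an
// implicit Ordering works. The data and `sc` are hypothetical:
val words = sc.parallelize(Seq("pear", "fig", "banana"))
words.sortBy(_.length).collect()                   // Array("fig", "pear", "banana")
words.sortBy(x => x, ascending = false).collect()  // Array("pear", "fig", "banana")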
*/ - @transient private[spark] val creationSiteInfo = Utils.getCallSiteInfo - private[spark] def getCreationSite: String = Option(creationSiteInfo).getOrElse("").toString + @transient private[spark] val creationSite = Utils.getCallSite + private[spark] def getCreationSite: String = Option(creationSite).map(_.short).getOrElse("") private[spark] def elementClassTag: ClassTag[T] = classTag[T] diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala index 9257f48559c9e..b755d8fb15757 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import java.util.Properties import org.apache.spark.TaskContext +import org.apache.spark.util.CallSite /** * Tracks information about an active job in the DAGScheduler. @@ -29,7 +30,7 @@ private[spark] class ActiveJob( val finalStage: Stage, val func: (TaskContext, Iterator[_]) => _, val partitions: Array[Int], - val callSite: String, + val callSite: CallSite, val listener: JobListener, val properties: Properties) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 3c85b5a2ae776..b3ebaa547de0d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -38,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId} -import org.apache.spark.util.{SystemClock, Clock, Utils} +import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils} /** * The high-level scheduling layer that implements stage-oriented scheduling. 
It computes a DAG of @@ -195,7 +195,9 @@ class DAGScheduler( case Some(stage) => stage case None => val stage = - newOrUsedStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId) + newOrUsedStage( + shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId, + shuffleDep.rdd.creationSite) shuffleToMapStage(shuffleDep.shuffleId) = stage stage } @@ -212,7 +214,7 @@ class DAGScheduler( numTasks: Int, shuffleDep: Option[ShuffleDependency[_, _, _]], jobId: Int, - callSite: Option[String] = None) + callSite: CallSite) : Stage = { val id = nextStageId.getAndIncrement() @@ -235,7 +237,7 @@ class DAGScheduler( numTasks: Int, shuffleDep: ShuffleDependency[_, _, _], jobId: Int, - callSite: Option[String] = None) + callSite: CallSite) : Stage = { val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite) @@ -413,7 +415,7 @@ class DAGScheduler( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], - callSite: String, + callSite: CallSite, allowLocal: Boolean, resultHandler: (Int, U) => Unit, properties: Properties = null): JobWaiter[U] = @@ -443,7 +445,7 @@ class DAGScheduler( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], - callSite: String, + callSite: CallSite, allowLocal: Boolean, resultHandler: (Int, U) => Unit, properties: Properties = null) @@ -452,7 +454,7 @@ class DAGScheduler( waiter.awaitResult() match { case JobSucceeded => {} case JobFailed(exception: Exception) => - logInfo("Failed to run " + callSite) + logInfo("Failed to run " + callSite.short) throw exception } } @@ -461,7 +463,7 @@ class DAGScheduler( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, evaluator: ApproximateEvaluator[U, R], - callSite: String, + callSite: CallSite, timeout: Long, properties: Properties = null) : PartialResult[R] = @@ -666,7 +668,7 @@ class DAGScheduler( func: (TaskContext, Iterator[_]) => _, partitions: Array[Int], allowLocal: Boolean, - callSite: String, + callSite: CallSite, listener: JobListener, properties: Properties = null) { @@ -674,7 +676,7 @@ class DAGScheduler( try { // New stage creation may throw an exception if, for example, jobs are run on a // HadoopRDD whose underlying HDFS files have been deleted. 
- finalStage = newStage(finalRDD, partitions.size, None, jobId, Some(callSite)) + finalStage = newStage(finalRDD, partitions.size, None, jobId, callSite) } catch { case e: Exception => logWarning("Creating new stage failed due to exception - job: " + jobId, e) @@ -685,7 +687,7 @@ class DAGScheduler( val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() logInfo("Got job %s (%s) with %d output partitions (allowLocal=%s)".format( - job.jobId, callSite, partitions.length, allowLocal)) + job.jobId, callSite.short, partitions.length, allowLocal)) logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")") logInfo("Parents of final stage: " + finalStage.parents) logInfo("Missing parents: " + getMissingParentStages(finalStage)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 23f57441b4b11..2b6f7e4205c32 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -25,6 +25,7 @@ import scala.language.existentials import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD +import org.apache.spark.util.CallSite /** * Types of events that can be handled by the DAGScheduler. The DAGScheduler uses an event queue @@ -40,7 +41,7 @@ private[scheduler] case class JobSubmitted( func: (TaskContext, Iterator[_]) => _, partitions: Array[Int], allowLocal: Boolean, - callSite: String, + callSite: CallSite, listener: JobListener, properties: Properties = null) extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 3bf9713f728c6..9a4be43ee219f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.CallSite /** * A stage is a set of independent tasks all computing the same function that need to run as part @@ -35,6 +36,11 @@ import org.apache.spark.storage.BlockManagerId * Each Stage also has a jobId, identifying the job that first submitted the stage. When FIFO * scheduling is used, this allows Stages from earlier jobs to be computed first or recovered * faster on failure. + * + * The callSite provides a location in user code which relates to the stage. For a shuffle map + * stage, the callSite gives the user code that created the RDD being shuffled. For a result + * stage, the callSite gives the user code that executes the associated action (e.g. count()). 
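// Illustrative sketch (not part of the patch): the two forms of the call site surface
// in different places. The values below are hypothetical:
val callSite = CallSite(
  short = "count at MyApp.scala:17",   // becomes stage.name, shown in the stage table
  long = "org.apache.spark.rdd.RDD.count(RDD.scala)\nMyApp$.main(MyApp.scala:17)")
// stage.details = callSite.long backs the new "+show details" toggle in the stage table,
// and both values reach listeners through the extended StageInfo constructor:
val info = new StageInfo(stageId = 3, name = callSite.short, numTasks = 8,
  rddInfos = Seq.empty, details = callSite.long)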
+ * */ private[spark] class Stage( val id: Int, @@ -43,7 +49,7 @@ private[spark] class Stage( val shuffleDep: Option[ShuffleDependency[_, _, _]], // Output shuffle if stage is a map stage val parents: List[Stage], val jobId: Int, - callSite: Option[String]) + val callSite: CallSite) extends Logging { val isShuffleMap = shuffleDep.isDefined @@ -100,7 +106,8 @@ private[spark] class Stage( id } - val name = callSite.getOrElse(rdd.getCreationSite) + val name = callSite.short + val details = callSite.long override def toString = "Stage " + id diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index b42e231e11f91..7644e3f351b3c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -25,7 +25,12 @@ import org.apache.spark.storage.RDDInfo * Stores information about a stage to pass from the scheduler to SparkListeners. */ @DeveloperApi -class StageInfo(val stageId: Int, val name: String, val numTasks: Int, val rddInfos: Seq[RDDInfo]) { +class StageInfo( + val stageId: Int, + val name: String, + val numTasks: Int, + val rddInfos: Seq[RDDInfo], + val details: String) { /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */ var submissionTime: Option[Long] = None /** Time when all tasks in the stage completed or when the stage was cancelled. */ @@ -52,6 +57,6 @@ private[spark] object StageInfo { def fromStage(stage: Stage): StageInfo = { val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd) val rddInfos = Seq(RDDInfo.fromRdd(stage.rdd)) ++ ancestorRddInfos - new StageInfo(stage.id, stage.name, stage.numTasks, rddInfos) + new StageInfo(stage.id, stage.name, stage.numTasks, rddInfos, stage.details) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index f52bc7075104b..d2f7baf928b62 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -363,6 +363,15 @@ private[spark] class BlockManager( val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { + // Double check to make sure the block is still there. There is a small chance that the + // block has been removed by removeBlock (which also synchronizes on the blockInfo object). + // Note that this only checks metadata tracking. If user intentionally deleted the block + // on disk or from off heap storage without using removeBlock, this conditional check will + // still pass but eventually we will get an exception because we can't find the block. + if (blockInfo.get(blockId).isEmpty) { + logWarning(s"Block $blockId had been removed") + return None + } // If another thread is writing the block, wait for it to become ready. 
if (!info.waitForReady()) { diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 023fd6e4d8baa..5a72e216872a6 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -26,7 +26,7 @@ class RDDInfo( val id: Int, val name: String, val numPartitions: Int, - val storageLevel: StorageLevel) + var storageLevel: StorageLevel) extends Ordered[RDDInfo] { var numCachedPartitions = 0 @@ -36,8 +36,8 @@ class RDDInfo( override def toString = { import Utils.bytesToString - ("RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; " + - "TachyonSize: %s; DiskSize: %s").format( + ("RDD \"%s\" (%d) StorageLevel: %s; CachedPartitions: %d; TotalPartitions: %d; " + + "MemorySize: %s; TachyonSize: %s; DiskSize: %s").format( name, id, storageLevel.toString, numCachedPartitions, numPartitions, bytesToString(memSize), bytesToString(tachyonSize), bytesToString(diskSize)) } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index a6e6627d54e01..c694fc8c347ec 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -37,7 +37,11 @@ class StorageStatusListener extends SparkListener { val filteredStatus = storageStatusList.find(_.blockManagerId.executorId == execId) filteredStatus.foreach { storageStatus => updatedBlocks.foreach { case (blockId, updatedStatus) => - storageStatus.blocks(blockId) = updatedStatus + if (updatedStatus.storageLevel == StorageLevel.NONE) { + storageStatus.blocks.remove(blockId) + } else { + storageStatus.blocks(blockId) = updatedStatus + } } } } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 6f3252a2f6d31..f3bde1df45c79 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -89,10 +89,13 @@ private[spark] object StorageUtils { // Add up memory, disk and Tachyon sizes val persistedBlocks = blocks.filter { status => status.memSize + status.diskSize + status.tachyonSize > 0 } + val _storageLevel = + if (persistedBlocks.length > 0) persistedBlocks(0).storageLevel else StorageLevel.NONE val memSize = persistedBlocks.map(_.memSize).reduceOption(_ + _).getOrElse(0L) val diskSize = persistedBlocks.map(_.diskSize).reduceOption(_ + _).getOrElse(0L) val tachyonSize = persistedBlocks.map(_.tachyonSize).reduceOption(_ + _).getOrElse(0L) rddInfoMap.get(rddId).map { rddInfo => + rddInfo.storageLevel = _storageLevel rddInfo.numCachedPartitions = persistedBlocks.length rddInfo.memSize = memSize rddInfo.diskSize = diskSize diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 153434a2032be..a3f824a4e1f57 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -91,9 +91,17 @@ private[ui] class StageTableBase( {s.name} + val details = if (s.details.nonEmpty) ( + + +show details + + + ) + listener.stageIdToDescription.get(s.stageId) .map(d =>
<div><em>{d}</em></div><div>{nameLink} {killLink}</div>)
-      .getOrElse(<div>{killLink}{nameLink}</div>)
+      .getOrElse(<div>{killLink} {nameLink} {details}</div>
) } protected def stageRow(s: StageInfo): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 09825087bb048..7cecbfe62a382 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -184,6 +184,7 @@ private[spark] object JsonProtocol { ("Stage Name" -> stageInfo.name) ~ ("Number of Tasks" -> stageInfo.numTasks) ~ ("RDD Info" -> rddInfo) ~ + ("Details" -> stageInfo.details) ~ ("Submission Time" -> submissionTime) ~ ("Completion Time" -> completionTime) ~ ("Failure Reason" -> failureReason) ~ @@ -469,12 +470,13 @@ private[spark] object JsonProtocol { val stageName = (json \ "Stage Name").extract[String] val numTasks = (json \ "Number of Tasks").extract[Int] val rddInfos = (json \ "RDD Info").extract[List[JValue]].map(rddInfoFromJson) + val details = (json \ "Details").extractOpt[String].getOrElse("") val submissionTime = Utils.jsonOption(json \ "Submission Time").map(_.extract[Long]) val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]) val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String]) val emittedTaskSizeWarning = (json \ "Emitted Task Size Warning").extract[Boolean] - val stageInfo = new StageInfo(stageId, stageName, numTasks, rddInfos) + val stageInfo = new StageInfo(stageId, stageName, numTasks, rddInfos, details) stageInfo.submissionTime = submissionTime stageInfo.completionTime = completionTime stageInfo.failureReason = failureReason diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 7ebed5105b9fd..2889e171f627e 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -91,8 +91,13 @@ private[spark] object MetadataCleaner { conf.set(MetadataCleanerType.systemProperty(cleanerType), delay.toString) } + /** + * Set the default delay time (in seconds). + * @param conf SparkConf instance + * @param delay default delay time to set + * @param resetAll whether to reset all to default + */ def setDelaySeconds(conf: SparkConf, delay: Int, resetAll: Boolean = true) { - // override for all ? conf.set("spark.cleaner.ttl", delay.toString) if (resetAll) { for (cleanerType <- MetadataCleanerType.values) { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 4ce28bb0cf059..a2454e120a8ab 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -43,6 +43,9 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.ExecutorUncaughtExceptionHandler import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} +/** CallSite represents a place in user code. It can have a short and a long form. */ +private[spark] case class CallSite(val short: String, val long: String) + /** * Various utility methods used by Spark. 
*/ @@ -799,21 +802,12 @@ private[spark] object Utils extends Logging { */ private val SPARK_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r - private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String, - val firstUserLine: Int, val firstUserClass: String) { - - /** Returns a printable version of the call site info suitable for logs. */ - override def toString = { - "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine) - } - } - /** * When called inside a class in the spark package, returns the name of the user code class * (outside the spark package) that called into Spark, as well as which Spark method they called. * This is used, for example, to tell users where in their code each RDD got created. */ - def getCallSiteInfo: CallSiteInfo = { + def getCallSite: CallSite = { val trace = Thread.currentThread.getStackTrace() .filterNot(_.getMethodName.contains("getStackTrace")) @@ -824,11 +818,11 @@ private[spark] object Utils extends Logging { var lastSparkMethod = "" var firstUserFile = "" var firstUserLine = 0 - var finished = false - var firstUserClass = "" + var insideSpark = true + var callStack = new ArrayBuffer[String]() :+ "" for (el <- trace) { - if (!finished) { + if (insideSpark) { if (SPARK_CLASS_REGEX.findFirstIn(el.getClassName).isDefined) { lastSparkMethod = if (el.getMethodName == "") { // Spark method is a constructor; get its class name @@ -836,15 +830,21 @@ private[spark] object Utils extends Logging { } else { el.getMethodName } + callStack(0) = el.toString // Put last Spark method on top of the stack trace. } else { firstUserLine = el.getLineNumber firstUserFile = el.getFileName - firstUserClass = el.getClassName - finished = true + callStack += el.toString + insideSpark = false } + } else { + callStack += el.toString } } - new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass) + val callStackDepth = System.getProperty("spark.callstack.depth", "20").toInt + CallSite( + short = "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine), + long = callStack.take(callStackDepth).mkString("\n")) } /** Return a string containing part of a file from byte 'start' to 'end'. 
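// Illustrative sketch (not part of the patch; Utils is private[spark], so this is only
// callable from Spark-internal code). getCallSite walks the stack once and returns both
// forms; the file names and line numbers are hypothetical:
val site = Utils.getCallSite
site.short   // "collect at MyApp.scala:23" — last Spark method at first user file:line
site.long    // that Spark frame plus up to spark.callstack.depth (default 20) caller
             // frames, one per line; this string becomes Stage.details / "+show details"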
*/ diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index ef41bfb88de9d..761f2d6a77d33 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -21,6 +21,9 @@ import java.util.*; import scala.Tuple2; +import scala.Tuple3; +import scala.Tuple4; + import com.google.common.collect.Iterables; import com.google.common.collect.Iterators; @@ -180,6 +183,39 @@ public void sortByKey() { Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); } + @Test + public void sortBy() { + List> pairs = new ArrayList>(); + pairs.add(new Tuple2(0, 4)); + pairs.add(new Tuple2(3, 2)); + pairs.add(new Tuple2(-1, 1)); + + JavaRDD> rdd = sc.parallelize(pairs); + + // compare on first value + JavaRDD> sortedRDD = rdd.sortBy(new Function, Integer>() { + public Integer call(Tuple2 t) throws Exception { + return t._1(); + } + }, true, 2); + + Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + List> sortedPairs = sortedRDD.collect(); + Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); + + // compare on second value + sortedRDD = rdd.sortBy(new Function, Integer>() { + public Integer call(Tuple2 t) throws Exception { + return t._2(); + } + }, true, 2); + Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + sortedPairs = sortedRDD.collect(); + Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(2)); + } + @Test public void foreach() { final Accumulator accum = sc.accumulator(0); @@ -271,6 +307,66 @@ public void cogroup() { cogrouped.collect(); } + @SuppressWarnings("unchecked") + @Test + public void cogroup3() { + JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( + new Tuple2("Apples", "Fruit"), + new Tuple2("Oranges", "Fruit"), + new Tuple2("Oranges", "Citrus") + )); + JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( + new Tuple2("Oranges", 2), + new Tuple2("Apples", 3) + )); + JavaPairRDD quantities = sc.parallelizePairs(Arrays.asList( + new Tuple2("Oranges", 21), + new Tuple2("Apples", 42) + )); + + JavaPairRDD, Iterable, Iterable>> cogrouped = + categories.cogroup(prices, quantities); + Assert.assertEquals("[Fruit, Citrus]", + Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); + Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); + + + cogrouped.collect(); + } + + @SuppressWarnings("unchecked") + @Test + public void cogroup4() { + JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( + new Tuple2("Apples", "Fruit"), + new Tuple2("Oranges", "Fruit"), + new Tuple2("Oranges", "Citrus") + )); + JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( + new Tuple2("Oranges", 2), + new Tuple2("Apples", 3) + )); + JavaPairRDD quantities = sc.parallelizePairs(Arrays.asList( + new Tuple2("Oranges", 21), + new Tuple2("Apples", 42) + )); + JavaPairRDD countries = sc.parallelizePairs(Arrays.asList( + new Tuple2("Oranges", "BR"), + new Tuple2("Apples", "US") + )); + + JavaPairRDD, Iterable, Iterable, Iterable>> cogrouped = + categories.cogroup(prices, quantities, countries); + Assert.assertEquals("[Fruit, Citrus]", + Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); + 
Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); + Assert.assertEquals("[BR]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._4())); + + cogrouped.collect(); + } + @SuppressWarnings("unchecked") @Test public void leftOuterJoin() { diff --git a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala index cd3887dcc7371..1fde4badda949 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala @@ -70,7 +70,7 @@ package object testPackage extends Assertions { def runCallSiteTest(sc: SparkContext) { val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) val rddCreationSite = rdd.getCreationSite - val curCallSite = sc.getCallSite() // note: 2 lines after definition of "rdd" + val curCallSite = sc.getCallSite().short // note: 2 lines after definition of "rdd" val rddCreationLine = rddCreationSite match { case CALL_SITE_REGEX(func, file, line) => { diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 0b9004448a63e..447e38ec9dbd0 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -249,6 +249,39 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { )) } + test("groupWith3") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val rdd3 = sc.parallelize(Array((1, 'a'), (3, 'b'), (4, 'c'), (4, 'd'))) + val joined = rdd1.groupWith(rdd2, rdd3).collect() + assert(joined.size === 4) + val joinedSet = joined.map(x => (x._1, + (x._2._1.toList, x._2._2.toList, x._2._3.toList))).toSet + assert(joinedSet === Set( + (1, (List(1, 2), List('x'), List('a'))), + (2, (List(1), List('y', 'z'), List())), + (3, (List(1), List(), List('b'))), + (4, (List(), List('w'), List('c', 'd'))) + )) + } + + test("groupWith4") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val rdd3 = sc.parallelize(Array((1, 'a'), (3, 'b'), (4, 'c'), (4, 'd'))) + val rdd4 = sc.parallelize(Array((2, '@'))) + val joined = rdd1.groupWith(rdd2, rdd3, rdd4).collect() + assert(joined.size === 4) + val joinedSet = joined.map(x => (x._1, + (x._2._1.toList, x._2._2.toList, x._2._3.toList, x._2._4.toList))).toSet + assert(joinedSet === Set( + (1, (List(1, 2), List('x'), List('a'), List())), + (2, (List(1), List('y', 'z'), List(), List('@'))), + (3, (List(1), List(), List('b'), List())), + (4, (List(), List('w'), List('c', 'd'), List())) + )) + } + test("zero-partition RDD") { val emptyDir = Files.createTempDir() emptyDir.deleteOnExit() diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala index 4df36558b6d4b..1b112f1a41ca9 100644 --- a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala @@ -111,6 +111,24 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices.forall(_.isInstanceOf[Range])) } + test("identical slice sizes between Range and NumericRange") { + val r = ParallelCollectionRDD.slice(1 to 
7, 4) + val nr = ParallelCollectionRDD.slice(1L to 7L, 4) + assert(r.size === 4) + for (i <- 0 until r.size) { + assert(r(i).size === nr(i).size) + } + } + + test("identical slice sizes between List and NumericRange") { + val r = ParallelCollectionRDD.slice(List(1, 2), 4) + val nr = ParallelCollectionRDD.slice(1L to 2L, 4) + assert(r.size === 4) + for (i <- 0 until r.size) { + assert(r(i).size === nr(i).size) + } + } + test("large ranges don't overflow") { val N = 100 * 1000 * 1000 val data = 0 until N diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index e94a1e76d410c..0e5625b7645d5 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -26,6 +26,8 @@ import org.apache.spark._ import org.apache.spark.SparkContext._ import org.apache.spark.util.Utils +import org.apache.spark.rdd.RDDSuiteUtils._ + class RDDSuite extends FunSuite with SharedSparkContext { test("basic operations") { @@ -585,14 +587,63 @@ class RDDSuite extends FunSuite with SharedSparkContext { } } + test("sortByKey") { + val data = sc.parallelize(Seq("5|50|A","4|60|C", "6|40|B")) + + val col1 = Array("4|60|C", "5|50|A", "6|40|B") + val col2 = Array("6|40|B", "5|50|A", "4|60|C") + val col3 = Array("5|50|A", "6|40|B", "4|60|C") + + assert(data.sortBy(_.split("\\|")(0)).collect() === col1) + assert(data.sortBy(_.split("\\|")(1)).collect() === col2) + assert(data.sortBy(_.split("\\|")(2)).collect() === col3) + } + + test("sortByKey ascending parameter") { + val data = sc.parallelize(Seq("5|50|A","4|60|C", "6|40|B")) + + val asc = Array("4|60|C", "5|50|A", "6|40|B") + val desc = Array("6|40|B", "5|50|A", "4|60|C") + + assert(data.sortBy(_.split("\\|")(0), true).collect() === asc) + assert(data.sortBy(_.split("\\|")(0), false).collect() === desc) + } + + test("sortByKey with explicit ordering") { + val data = sc.parallelize(Seq("Bob|Smith|50", + "Jane|Smith|40", + "Thomas|Williams|30", + "Karen|Williams|60")) + + val ageOrdered = Array("Thomas|Williams|30", + "Jane|Smith|40", + "Bob|Smith|50", + "Karen|Williams|60") + + // last name, then first name + val nameOrdered = Array("Bob|Smith|50", + "Jane|Smith|40", + "Karen|Williams|60", + "Thomas|Williams|30") + + val parse = (s: String) => { + val split = s.split("\\|") + Person(split(0), split(1), split(2).toInt) + } + + import scala.reflect.classTag + assert(data.sortBy(parse, true, 2)(AgeOrdering, classTag[Person]).collect() === ageOrdered) + assert(data.sortBy(parse, true, 2)(NameOrdering, classTag[Person]).collect() === nameOrdered) + } + test("intersection") { val all = sc.parallelize(1 to 10) val evens = sc.parallelize(2 to 10 by 2) val intersection = Array(2, 4, 6, 8, 10) // intersection is commutative - assert(all.intersection(evens).collect.sorted === intersection) - assert(evens.intersection(all).collect.sorted === intersection) + assert(all.intersection(evens).collect().sorted === intersection) + assert(evens.intersection(all).collect().sorted === intersection) } test("intersection strips duplicates in an input") { @@ -600,8 +651,8 @@ class RDDSuite extends FunSuite with SharedSparkContext { val b = sc.parallelize(Seq(1,1,2,3)) val intersection = Array(1,2,3) - assert(a.intersection(b).collect.sorted === intersection) - assert(b.intersection(a).collect.sorted === intersection) + assert(a.intersection(b).collect().sorted === intersection) + assert(b.intersection(a).collect().sorted === intersection) } 
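The `sortBy` tests added above exercise the new `RDD.sortBy` operator. A minimal usage sketch, assuming only an existing `SparkContext` named `sc` (the data and variable names are illustrative):

{% highlight scala %}
// Illustrative only: mirrors what the sortBy tests above exercise.
val records = sc.parallelize(Seq("5|50|A", "4|60|C", "6|40|B"))

// Sort by the first pipe-delimited field, ascending (the default).
val byFirstField = records.sortBy(_.split("\\|")(0)).collect()

// Sort by the second field in descending order, using two partitions.
val bySecondDesc = records.sortBy(_.split("\\|")(1), ascending = false, numPartitions = 2).collect()
{% endhighlight %}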
test("zipWithIndex") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala similarity index 51% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveSuite.scala rename to core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala index 9c5d7c81f7c09..4762fc17855ce 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala @@ -15,27 +15,17 @@ * limitations under the License. */ -package org.apache.spark.sql.hive.api.java +package org.apache.spark.rdd -import org.scalatest.FunSuite +object RDDSuiteUtils { + case class Person(first: String, last: String, age: Int) -import org.apache.spark.api.java.JavaSparkContext -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.hive.test.TestHive - -// Implicits -import scala.collection.JavaConversions._ - -class JavaHiveSQLSuite extends FunSuite { - ignore("SELECT * FROM src") { - val javaCtx = new JavaSparkContext(TestSQLContext.sparkContext) - // There is a little trickery here to avoid instantiating two HiveContexts in the same JVM - val javaSqlCtx = new JavaHiveContext(javaCtx) { - override val sqlContext = TestHive - } + object AgeOrdering extends Ordering[Person] { + def compare(a:Person, b:Person) = a.age compare b.age + } - assert( - javaSqlCtx.hql("SELECT * FROM src").collect().map(_.getInt(0)) === - TestHive.sql("SELECT * FROM src").collect().map(_.getInt(0)).toSeq) + object NameOrdering extends Ordering[Person] { + def compare(a:Person, b:Person) = + implicitly[Ordering[Tuple2[String,String]]].compare((a.last, a.first), (b.last, b.first)) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 7506d56d7e26d..45368328297d3 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} +import org.apache.spark.util.CallSite class BuggyDAGEventProcessActor extends Actor { val state = 0 @@ -211,7 +212,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F allowLocal: Boolean = false, listener: JobListener = jobListener): Int = { val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, func, partitions, allowLocal, null, listener)) + runEvent(JobSubmitted(jobId, rdd, func, partitions, allowLocal, CallSite("", ""), listener)) jobId } @@ -251,7 +252,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F override def toString = "DAGSchedulerSuite Local RDD" } val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, jobListener)) + runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) assert(results === Map(0 -> 42)) assertDataStructuresEmpty } @@ -265,7 +266,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F override def toString = "DAGSchedulerSuite Local RDD" } val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, 
jobComputeFunc, Array(0), true, null, jobListener)) + runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) assert(results.size == 0) assertDataStructuresEmpty } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index be506e0287a16..abd7b22310f1a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -239,11 +239,14 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers checkNonZeroAvg( taskInfoMetrics.map(_._2.executorDeserializeTime), stageInfo + " executorDeserializeTime") + + /* Test is disabled (SEE SPARK-2208) if (stageInfo.rddInfos.exists(_.name == d4.name)) { checkNonZeroAvg( taskInfoMetrics.map(_._2.shuffleReadMetrics.get.fetchWaitTime), stageInfo + " fetchWaitTime") } + */ taskInfoMetrics.foreach { case (taskInfo, taskMetrics) => taskMetrics.resultSize should be > (0l) diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 91b4c7b0dd962..c3a14f48de38e 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -32,12 +32,12 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val listener = new JobProgressListener(conf) def createStageStartEvent(stageId: Int) = { - val stageInfo = new StageInfo(stageId, stageId.toString, 0, null) + val stageInfo = new StageInfo(stageId, stageId.toString, 0, null, "") SparkListenerStageSubmitted(stageInfo) } def createStageEndEvent(stageId: Int) = { - val stageInfo = new StageInfo(stageId, stageId.toString, 0, null) + val stageInfo = new StageInfo(stageId, stageId.toString, 0, null, "") SparkListenerStageCompleted(stageInfo) } diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 53d7f5c6072e6..02e228945bbd9 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -120,7 +120,7 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { // on SparkConf settings. 
def testAppenderSelection[ExpectedAppender: ClassTag, ExpectedRollingPolicy]( - properties: Seq[(String, String)], expectedRollingPolicyParam: Long = -1): FileAppender = { + properties: Seq[(String, String)], expectedRollingPolicyParam: Long = -1): Unit = { // Set spark conf properties val conf = new SparkConf @@ -129,8 +129,9 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { } // Create and test file appender - val inputStream = new PipedInputStream(new PipedOutputStream()) - val appender = FileAppender(inputStream, new File("stdout"), conf) + val testOutputStream = new PipedOutputStream() + val testInputStream = new PipedInputStream(testOutputStream) + val appender = FileAppender(testInputStream, testFile, conf) assert(appender.isInstanceOf[ExpectedAppender]) assert(appender.getClass.getSimpleName === classTag[ExpectedAppender].runtimeClass.getSimpleName) @@ -144,7 +145,8 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { } assert(policyParam === expectedRollingPolicyParam) } - appender + testOutputStream.close() + appender.awaitTermination() } import RollingFileAppender._ diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 3031015256ec9..f72389b6b323f 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -117,6 +117,17 @@ class JsonProtocolSuite extends FunSuite { testBlockId(StreamBlockId(1, 2L)) } + test("Backward compatibility") { + // StageInfo.details was added after 1.0.0. + val info = makeStageInfo(1, 2, 3, 4L, 5L) + assert(info.details.nonEmpty) + val newJson = JsonProtocol.stageInfoToJson(info) + val oldJson = newJson.removeField { case (field, _) => field == "Details" } + val newInfo = JsonProtocol.stageInfoFromJson(oldJson) + assert(info.name === newInfo.name) + assert("" === newInfo.details) + } + /** -------------------------- * | Helper test running methods | @@ -235,6 +246,7 @@ class JsonProtocolSuite extends FunSuite { (0 until info1.rddInfos.size).foreach { i => assertEquals(info1.rddInfos(i), info2.rddInfos(i)) } + assert(info1.details === info2.details) } private def assertEquals(info1: RDDInfo, info2: RDDInfo) { @@ -438,7 +450,7 @@ class JsonProtocolSuite extends FunSuite { private def makeStageInfo(a: Int, b: Int, c: Int, d: Long, e: Long) = { val rddInfos = (1 to a % 5).map { i => makeRddInfo(a % i, b % i, c % i, d % i, e % i) } - new StageInfo(a, "greetings", b, rddInfos) + new StageInfo(a, "greetings", b, rddInfos, "details") } private def makeTaskInfo(a: Long, b: Int, c: Long) = { diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index 97e8f4e9661b6..ae9ede58e8e60 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -147,9 +147,9 @@ are developed, see the linear methods section for example. -The SGD method -[GradientDescent.runMiniBatchSGD](api/scala/index.html#org.apache.spark.mllib.optimization.GradientDescent) -has the following parameters: +The SGD class +[GradientDescent](api/scala/index.html#org.apache.spark.mllib.optimization.GradientDescent) +sets the following parameters: * `Gradient` is a class that computes the stochastic gradient of the function being optimized, i.e., with respect to a single training example, at the @@ -171,7 +171,7 @@ each iteration, to compute the gradient direction. 
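A minimal sketch of how the parameters listed above are passed to `GradientDescent.runMiniBatchSGD`, mirroring the call added to `LBFGSSuite` later in this patch; the gradient, updater, and numeric values here are illustrative assumptions rather than recommended settings:

{% highlight scala %}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.optimization.{GradientDescent, LogisticGradient, SquaredL2Updater}

// Assumes an existing SparkContext `sc`; the tiny data set is for illustration only.
val trainingData = sc.parallelize(Seq(
  (1.0, Vectors.dense(1.0, 0.5)),
  (0.0, Vectors.dense(-1.0, -0.5))))

val (weights, lossHistory) = GradientDescent.runMiniBatchSGD(
  trainingData,
  new LogisticGradient(),   // stochastic gradient of the loss for a single example
  new SquaredL2Updater(),   // applies the gradient step plus L2 regularization
  1.0,                      // stepSize
  100,                      // numIterations
  0.1,                      // regParam
  1.0,                      // miniBatchFraction (1.0 = use all of the data each iteration)
  Vectors.dense(0.0, 0.0))  // initialWeights
{% endhighlight %}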
Available algorithms for gradient descent: -* [GradientDescent.runMiniBatchSGD](api/scala/index.html#org.apache.spark.mllib.optimization.GradientDescent) +* [GradientDescent](api/scala/index.html#org.apache.spark.mllib.optimization.GradientDescent) ### L-BFGS L-BFGS is currently only a low-level optimization primitive in `MLlib`. If you want to use L-BFGS in various diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 79784682bfd1b..65d75b85efda6 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -377,13 +377,15 @@ Some notes on reading files with Spark: * The `textFile` method also takes an optional second argument for controlling the number of slices of the file. By default, Spark creates one slice for each block of the file (blocks being 64MB by default in HDFS), but you can also ask for a higher number of slices by passing a larger value. Note that you cannot have fewer slices than blocks. -Apart from reading files as a collection of lines, -`SparkContext.wholeTextFiles` lets you read a directory containing multiple small text files, and returns each of them as (filename, content) pairs. This is in contrast with `textFile`, which would return one record per line in each file. +Apart from text files, Spark's Python API also supports several other data formats: -### SequenceFile and Hadoop InputFormats +* `SparkContext.wholeTextFiles` lets you read a directory containing multiple small text files, and returns each of them as (filename, content) pairs. This is in contrast with `textFile`, which would return one record per line in each file. + +* `RDD.saveAsPickleFile` and `SparkContext.pickleFile` support saving an RDD in a simple format consisting of pickled Python objects. Batching is used on pickle serialization, with default batch size 10. -In addition to reading text files, PySpark supports reading ```SequenceFile``` -and any arbitrary ```InputFormat```. +* Details on reading `SequenceFile` and arbitrary Hadoop `InputFormat` are given below. + +### SequenceFile and Hadoop InputFormats **Note** this feature is currently marked ```Experimental``` and is intended for advanced users. It may be replaced in future with read/write support based on SparkSQL, in which case SparkSQL is the preferred approach. @@ -760,6 +762,11 @@ val counts = pairs.reduceByKey((a, b) => a + b) We could also use `counts.sortByKey()`, for example, to sort the pairs alphabetically, and finally `counts.collect()` to bring them back to the driver program as an array of objects. +**Note:** when using custom objects as the key in key-value pair operations, you must be sure that a +custom `equals()` method is accompanied with a matching `hashCode()` method. For full details, see +the contract outlined in the [Object.hashCode() +documentation](http://docs.oracle.com/javase/7/docs/api/java/lang/Object.html#hashCode()). +
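To make the note above concrete: a hand-written key class has to keep `equals()` and `hashCode()` consistent with each other. The class below is hypothetical and shown only as a sketch of that contract; in Scala a `case class` generates both methods automatically and is usually the simpler choice.

{% highlight scala %}
// Hypothetical key type, used only to illustrate the equals()/hashCode() contract.
class AccountKey(val id: Int, val region: String) extends Serializable {
  override def equals(other: Any): Boolean = other match {
    case that: AccountKey => id == that.id && region == that.region
    case _ => false
  }
  // Equal keys must produce equal hashes, otherwise operations such as reduceByKey
  // and groupByKey can scatter the "same" key across different partitions.
  override def hashCode(): Int = 31 * id + region.hashCode
}

// In practice, `case class AccountKey(id: Int, region: String)` gives you both for free.
{% endhighlight %}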
@@ -792,6 +799,10 @@ JavaPairRDD counts = pairs.reduceByKey((a, b) -> a + b); We could also use `counts.sortByKey()`, for example, to sort the pairs alphabetically, and finally `counts.collect()` to bring them back to the driver program as an array of objects. +**Note:** when using custom objects as the key in key-value pair operations, you must be sure that a +custom `equals()` method is accompanied with a matching `hashCode()` method. For full details, see +the contract outlined in the [Object.hashCode() +documentation](http://docs.oracle.com/javase/7/docs/api/java/lang/Object.html#hashCode()).
@@ -888,7 +899,7 @@ for details. reduceByKey(func, [numTasks]) - When called on a dataset of (K, V) pairs, returns a dataset of (K, V) pairs where the values for each key are aggregated using the given reduce function. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. + When called on a dataset of (K, V) pairs, returns a dataset of (K, V) pairs where the values for each key are aggregated using the given reduce function func, which must be of type (V,V) => V. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. aggregateByKey(zeroValue)(seqOp, combOp, [numTasks]) @@ -1056,7 +1067,10 @@ storage levels is: Store RDD in serialized format in Tachyon. Compared to MEMORY_ONLY_SER, OFF_HEAP reduces garbage collection overhead and allows executors to be smaller and to share a pool of memory, making it attractive in environments with - large heaps or multiple concurrent applications. + large heaps or multiple concurrent applications. Furthermore, as the RDDs reside in Tachyon, + the crash of an executor does not lead to losing the in-memory cache. In this mode, the memory + in Tachyon is discardable. Thus, Tachyon does not attempt to reconstruct a block that it evicts + from memory. diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index af1788f2aa151..fecd8f2cc2d48 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -67,6 +67,34 @@ Most of the configs are the same for Spark on YARN as for other deployment modes The address of the Spark history server (i.e. host.com:18080). The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. + + spark.yarn.dist.archives + (none) + + Comma separated list of archives to be extracted into the working directory of each executor. + + + + spark.yarn.dist.files + (none) + + Comma-separated list of files to be placed in the working directory of each executor. + + + + spark.yarn.executor.memoryOverhead + 384 + + The amount of off heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. + + + + spark.yarn.driver.memoryOverhead + 384 + + The amount of off heap memory (in megabytes) to be allocated per driver. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. + + By default, Spark on YARN will use a Spark jar installed locally, but the Spark JAR can also be in a world-readable location on HDFS. This allows YARN to cache it on nodes so that it doesn't need to be distributed each time an application runs. To point to a JAR on HDFS, `export SPARK_JAR=hdfs:///some/path`. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 4623bb4247d77..522c83884ef42 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -17,20 +17,20 @@ Spark. At the core of this component is a new type of RDD, [Row](api/scala/index.html#org.apache.spark.sql.catalyst.expressions.Row) objects along with a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table in a traditional relational database. 
A SchemaRDD can be created from an existing RDD, [Parquet](http://parquet.io) -file, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). +file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`.
-Spark SQL allows relational queries expressed in SQL, HiveQL, or Scala to be executed using +Spark SQL allows relational queries expressed in SQL or HiveQL to be executed using Spark. At the core of this component is a new type of RDD, [JavaSchemaRDD](api/scala/index.html#org.apache.spark.sql.api.java.JavaSchemaRDD). JavaSchemaRDDs are composed of [Row](api/scala/index.html#org.apache.spark.sql.api.java.Row) objects along with a schema that describes the data types of each column in the row. A JavaSchemaRDD is similar to a table in a traditional relational database. A JavaSchemaRDD can be created from an existing RDD, [Parquet](http://parquet.io) -file, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). +file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/).
@@ -41,7 +41,7 @@ Spark. At the core of this component is a new type of RDD, [Row](api/python/pyspark.sql.Row-class.html) objects along with a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table in a traditional relational database. A SchemaRDD can be created from an existing RDD, [Parquet](http://parquet.io) -file, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). +file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). All of the examples on this page use sample data included in the Spark distribution and can be run in the `pyspark` shell.
@@ -64,8 +64,8 @@ descendants. To create a basic SQLContext, all you need is a SparkContext. val sc: SparkContext // An existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) -// Importing the SQL context gives access to all the public SQL functions and implicit conversions. -import sqlContext._ +// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD. +import sqlContext.createSchemaRDD {% endhighlight %} @@ -77,8 +77,8 @@ The entry point into all relational functionality in Spark is the of its descendants. To create a basic JavaSQLContext, all you need is a JavaSparkContext. {% highlight java %} -JavaSparkContext ctx = ...; // An existing JavaSparkContext. -JavaSQLContext sqlCtx = new org.apache.spark.sql.api.java.JavaSQLContext(ctx); +JavaSparkContext sc = ...; // An existing JavaSparkContext. +JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); {% endhighlight %} @@ -91,14 +91,33 @@ of its decedents. To create a basic SQLContext, all you need is a SparkContext. {% highlight python %} from pyspark.sql import SQLContext -sqlCtx = SQLContext(sc) +sqlContext = SQLContext(sc) {% endhighlight %} -## Running SQL on RDDs +# Data Sources + +
+
+Spark SQL supports operating on a variety of data sources through the `SchemaRDD` interface. +Once a dataset has been loaded, it can be registered as a table and even joined with data from other sources. +
+ +
+Spark SQL supports operating on a variety of data sources through the `JavaSchemaRDD` interface. +Once a dataset has been loaded, it can be registered as a table and even joined with data from other sources. +
+ +
+Spark SQL supports operating on a variety of data sources through the `SchemaRDD` interface. +Once a dataset has been loaded, it can be registered as a table and even joined with data from other sources. +
+
+ +## RDDs
@@ -111,8 +130,10 @@ types such as Sequences or Arrays. This RDD can be implicitly converted to a Sch registered as a table. Tables can be used in subsequent SQL statements. {% highlight scala %} +// sc is an existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) -import sqlContext._ +// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD. +import sqlContext.createSchemaRDD // Define the schema using a case class. // Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit, @@ -124,7 +145,7 @@ val people = sc.textFile("examples/src/main/resources/people.txt").map(_.split(" people.registerAsTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. -val teenagers = sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") +val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") // The results of SQL queries are SchemaRDDs and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. @@ -170,12 +191,11 @@ A schema can be applied to an existing RDD by calling `applySchema` and providin for the JavaBean. {% highlight java %} - -JavaSparkContext ctx = ...; // An existing JavaSparkContext. -JavaSQLContext sqlCtx = new org.apache.spark.sql.api.java.JavaSQLContext(ctx) +// sc is an existing JavaSparkContext. +JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc) // Load a text file and convert each line to a JavaBean. -JavaRDD people = ctx.textFile("examples/src/main/resources/people.txt").map( +JavaRDD people = sc.textFile("examples/src/main/resources/people.txt").map( new Function() { public Person call(String line) throws Exception { String[] parts = line.split(","); @@ -189,11 +209,11 @@ JavaRDD people = ctx.textFile("examples/src/main/resources/people.txt"). }); // Apply a schema to an RDD of JavaBeans and register it as a table. -JavaSchemaRDD schemaPeople = sqlCtx.applySchema(people, Person.class); +JavaSchemaRDD schemaPeople = sqlContext.applySchema(people, Person.class); schemaPeople.registerAsTable("people"); // SQL can be run over RDDs that have been registered as tables. -JavaSchemaRDD teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") +JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") // The results of SQL queries are SchemaRDDs and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. @@ -215,6 +235,10 @@ row. Any RDD of dictionaries can converted to a SchemaRDD and then registered as can be used in subsequent SQL statements. {% highlight python %} +# sc is an existing SparkContext. +from pyspark.sql import SQLContext +sqlContext = SQLContext(sc) + # Load a text file and convert each line to a dictionary. lines = sc.textFile("examples/src/main/resources/people.txt") parts = lines.map(lambda l: l.split(",")) @@ -223,14 +247,16 @@ people = parts.map(lambda p: {"name": p[0], "age": int(p[1])}) # Infer the schema, and register the SchemaRDD as a table. # In future versions of PySpark we would like to add support for registering RDDs with other # datatypes as tables -peopleTable = sqlCtx.inferSchema(people) -peopleTable.registerAsTable("people") +schemaPeople = sqlContext.inferSchema(people) +schemaPeople.registerAsTable("people") # SQL can be run over SchemaRDDs that have been registered as a table. 
-teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") +teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") # The results of SQL queries are RDDs and support all the normal RDD operations. teenNames = teenagers.map(lambda p: "Name: " + p.name) +for teenName in teenNames.collect(): + print teenName {% endhighlight %}
@@ -241,7 +267,7 @@ teenNames = teenagers.map(lambda p: "Name: " + p.name) Users that want a more complete dialect of SQL should look at the HiveQL support provided by `HiveContext`. -## Using Parquet +## Parquet Files [Parquet](http://parquet.io) is a columnar format that is supported by many other data processing systems. Spark SQL provides support for both reading and writing Parquet files that automatically preserves the schema @@ -252,22 +278,23 @@ of the original data. Using the data from the above example:
{% highlight scala %} -val sqlContext = new org.apache.spark.sql.SQLContext(sc) -import sqlContext._ +// sqlContext from the previous example is used in this example. +// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD. +import sqlContext.createSchemaRDD val people: RDD[Person] = ... // An RDD of case class objects, from the previous example. -// The RDD is implicitly converted to a SchemaRDD, allowing it to be stored using Parquet. +// The RDD is implicitly converted to a SchemaRDD by createSchemaRDD, allowing it to be stored using Parquet. people.saveAsParquetFile("people.parquet") // Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. -// The result of loading a Parquet file is also a JavaSchemaRDD. +// The result of loading a Parquet file is also a SchemaRDD. val parquetFile = sqlContext.parquetFile("people.parquet") //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerAsTable("parquetFile") -val teenagers = sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") -teenagers.collect().foreach(println) +val teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") +teenagers.map(t => "Name: " + t(0)).collect().foreach(println) {% endhighlight %}
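As a small follow-on to the example above (a sketch that reuses its `people` and `parquetFile` values), the preserved schema can be inspected on both sides of the Parquet round trip:

{% highlight scala %}
// Sketch only: `people` and `parquetFile` come from the example above.
people.printSchema()       // schema of the original SchemaRDD
parquetFile.printSchema()  // same column names and types, read back from the Parquet file
{% endhighlight %}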
@@ -275,6 +302,7 @@ teenagers.collect().foreach(println)
{% highlight java %} +// sqlContext from the previous example is used in this example. JavaSchemaRDD schemaPeople = ... // The JavaSchemaRDD from the previous example. @@ -283,13 +311,16 @@ schemaPeople.saveAsParquetFile("people.parquet"); // Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a JavaSchemaRDD. -JavaSchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet"); +JavaSchemaRDD parquetFile = sqlContext.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerAsTable("parquetFile"); -JavaSchemaRDD teenagers = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); - - +JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); +List teenagerNames = teenagers.map(new Function() { + public String call(Row row) { + return "Name: " + row.getString(0); + } +}).collect(); {% endhighlight %}
@@ -297,50 +328,149 @@ JavaSchemaRDD teenagers = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >=
{% highlight python %} +# sqlContext from the previous example is used in this example. -peopleTable # The SchemaRDD from the previous example. +schemaPeople # The SchemaRDD from the previous example. # SchemaRDDs can be saved as Parquet files, maintaining the schema information. -peopleTable.saveAsParquetFile("people.parquet") +schemaPeople.saveAsParquetFile("people.parquet") # Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. # The result of loading a parquet file is also a SchemaRDD. -parquetFile = sqlCtx.parquetFile("people.parquet") +parquetFile = sqlContext.parquetFile("people.parquet") # Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerAsTable("parquetFile"); -teenagers = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") - +teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") +teenNames = teenagers.map(lambda p: "Name: " + p.name) +for teenName in teenNames.collect(): + print teenName {% endhighlight %}
-## Writing Language-Integrated Relational Queries +## JSON Datasets +
-**Language-Integrated queries are currently only supported in Scala.** +
+Spark SQL can automatically infer the schema of a JSON dataset and load it as a SchemaRDD. +This conversion can be done using one of two methods in a SQLContext: -Spark SQL also supports a domain specific language for writing queries. Once again, -using the data from the above examples: +* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. +* `jsonRdd` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. {% highlight scala %} +// sc is an existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) -import sqlContext._ -val people: RDD[Person] = ... // An RDD of case class objects, from the first example. -// The following is the same as 'SELECT name FROM people WHERE age >= 10 AND age <= 19' -val teenagers = people.where('age >= 10).where('age <= 19).select('name) +// A JSON dataset is pointed to by path. +// The path can be either a single text file or a directory storing text files. +val path = "examples/src/main/resources/people.json" +// Create a SchemaRDD from the file(s) pointed to by path +val people = sqlContext.jsonFile(path) + +// The inferred schema can be visualized using the printSchema() method. +people.printSchema() +// root +// |-- age: IntegerType +// |-- name: StringType + +// Register this SchemaRDD as a table. +people.registerAsTable("people") + +// SQL statements can be run by using the sql methods provided by sqlContext. +val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") + +// Alternatively, a SchemaRDD can be created for a JSON dataset represented by +// an RDD[String] storing one JSON object per string. +val anotherPeopleRDD = sc.parallelize( + """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil) +val anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) {% endhighlight %} -The DSL uses Scala symbols to represent columns in the underlying table, which are identifiers -prefixed with a tick (`'`). Implicit conversions turn these symbols into expressions that are -evaluated by the SQL execution engine. A full list of the functions supported can be found in the -[ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD). +
- +
+Spark SQL can automatically infer the schema of a JSON dataset and load it as a JavaSchemaRDD. +This conversion can be done using one of two methods in a JavaSQLContext : -# Hive Support +* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. +* `jsonRdd` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. + +{% highlight java %} +// sc is an existing JavaSparkContext. +JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); + +// A JSON dataset is pointed to by path. +// The path can be either a single text file or a directory storing text files. +String path = "examples/src/main/resources/people.json"; +// Create a JavaSchemaRDD from the file(s) pointed to by path +JavaSchemaRDD people = sqlContext.jsonFile(path); + +// The inferred schema can be visualized using the printSchema() method. +people.printSchema(); +// root +// |-- age: IntegerType +// |-- name: StringType + +// Register this JavaSchemaRDD as a table. +people.registerAsTable("people"); + +// SQL statements can be run by using the sql methods provided by sqlContext. +JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + +// Alternatively, a JavaSchemaRDD can be created for a JSON dataset represented by +// an RDD[String] storing one JSON object per string. +List jsonData = Arrays.asList( + "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); +JavaRDD anotherPeopleRDD = sc.parallelize(jsonData); +JavaSchemaRDD anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD); +{% endhighlight %} +
+ +
+Spark SQL can automatically infer the schema of a JSON dataset and load it as a SchemaRDD. +This conversion can be done using one of two methods in a SQLContext: + +* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. +* `jsonRdd` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. + +{% highlight python %} +# sc is an existing SparkContext. +from pyspark.sql import SQLContext +sqlContext = SQLContext(sc) + +# A JSON dataset is pointed to by path. +# The path can be either a single text file or a directory storing text files. +path = "examples/src/main/resources/people.json" +# Create a SchemaRDD from the file(s) pointed to by path +people = sqlContext.jsonFile(path) + +# The inferred schema can be visualized using the printSchema() method. +people.printSchema() +# root +# |-- age: IntegerType +# |-- name: StringType + +# Register this SchemaRDD as a table. +people.registerAsTable("people") + +# SQL statements can be run by using the sql methods provided by sqlContext. +teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") + +# Alternatively, a SchemaRDD can be created for a JSON dataset represented by +# an RDD[String] storing one JSON object per string. +anotherPeopleRDD = sc.parallelize([ + '{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}']) +anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) +{% endhighlight %} +
+ +
+ +## Hive Tables Spark SQL also supports reading and writing data stored in [Apache Hive](http://hive.apache.org/). However, since Hive has a large number of dependencies, it is not included in the default Spark assembly. @@ -362,17 +492,14 @@ which is similar to `HiveContext`, but creates a local copy of the `metastore` a automatically. {% highlight scala %} -val sc: SparkContext // An existing SparkContext. +// sc is an existing SparkContext. val hiveContext = new org.apache.spark.sql.hive.HiveContext(sc) -// Importing the SQL context gives access to all the public SQL functions and implicit conversions. -import hiveContext._ - -hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") +hiveContext.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +hiveContext.hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") // Queries are expressed in HiveQL -hql("FROM src SELECT key, value").collect().foreach(println) +hiveContext.hql("FROM src SELECT key, value").collect().foreach(println) {% endhighlight %} @@ -385,14 +512,14 @@ the `sql` method a `JavaHiveContext` also provides an `hql` methods, which allow expressed in HiveQL. {% highlight java %} -JavaSparkContext ctx = ...; // An existing JavaSparkContext. -JavaHiveContext hiveCtx = new org.apache.spark.sql.hive.api.java.HiveContext(ctx); +// sc is an existing JavaSparkContext. +JavaHiveContext hiveContext = new org.apache.spark.sql.hive.api.java.HiveContext(sc); -hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); -hiveCtx.hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); +hiveContext.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); +hiveContext.hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); // Queries are expressed in HiveQL. -Row[] results = hiveCtx.hql("FROM src SELECT key, value").collect(); +Row[] results = hiveContext.hql("FROM src SELECT key, value").collect(); {% endhighlight %} @@ -406,17 +533,44 @@ the `sql` method a `HiveContext` also provides an `hql` methods, which allows qu expressed in HiveQL. {% highlight python %} - +# sc is an existing SparkContext. from pyspark.sql import HiveContext -hiveCtx = HiveContext(sc) +hiveContext = HiveContext(sc) -hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -hiveCtx.hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") +hiveContext.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +hiveContext.hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") # Queries can be expressed in HiveQL. -results = hiveCtx.hql("FROM src SELECT key, value").collect() +results = hiveContext.hql("FROM src SELECT key, value").collect() {% endhighlight %} + + +# Writing Language-Integrated Relational Queries + +**Language-Integrated queries are currently only supported in Scala.** + +Spark SQL also supports a domain specific language for writing queries. Once again, +using the data from the above examples: + +{% highlight scala %} +// sc is an existing SparkContext. +val sqlContext = new org.apache.spark.sql.SQLContext(sc) +// Importing the SQL context gives access to all the public SQL functions and implicit conversions. +import sqlContext._ +val people: RDD[Person] = ... // An RDD of case class objects, from the first example. 
+ +// The following is the same as 'SELECT name FROM people WHERE age >= 10 AND age <= 19' +val teenagers = people.where('age >= 10).where('age <= 19).select('name) +teenagers.map(t => "Name: " + t(0)).collect().foreach(println) +{% endhighlight %} + +The DSL uses Scala symbols to represent columns in the underlying table, which are identifiers +prefixed with a tick (`'`). Implicit conversions turn these symbols into expressions that are +evaluated by the SQL execution engine. A full list of the functions supported can be found in the +[ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD). + + \ No newline at end of file diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index bbee67f54c6b8..ce8e58d64a7ed 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -950,7 +950,7 @@ is 200 milliseconds. An alternative to receiving data with multiple input streams / receivers is to explicitly repartition the input data stream (using `inputStream.repartition()`). -This distributes the received batches of data across all the machines in the cluster +This distributes the received batches of data across specified number of machines in the cluster before further processing. ### Level of Parallelism in Data Processing diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 52a89cb2481ca..a40311d9fcf02 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -689,9 +689,23 @@ def ssh(host, opts, command): time.sleep(30) tries = tries + 1 +# Backported from Python 2.7 for compatiblity with 2.6 (See SPARK-1990) +def _check_output(*popenargs, **kwargs): + if 'stdout' in kwargs: + raise ValueError('stdout argument not allowed, it will be overridden.') + process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) + output, unused_err = process.communicate() + retcode = process.poll() + if retcode: + cmd = kwargs.get("args") + if cmd is None: + cmd = popenargs[0] + raise subprocess.CalledProcessError(retcode, cmd, output=output) + return output + def ssh_read(host, opts, command): - return subprocess.check_output( + return _check_output( ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)]) diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index ad5ec84b71e69..607df3eddd550 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -18,6 +18,7 @@ package org.apache.spark.examples.sql; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import org.apache.spark.SparkConf; @@ -56,6 +57,7 @@ public static void main(String[] args) throws Exception { JavaSparkContext ctx = new JavaSparkContext(sparkConf); JavaSQLContext sqlCtx = new JavaSQLContext(ctx); + System.out.println("=== Data source: RDD ==="); // Load a text file and convert each line to a Java Bean. JavaRDD people = ctx.textFile("examples/src/main/resources/people.txt").map( new Function() { @@ -84,16 +86,88 @@ public String call(Row row) { return "Name: " + row.getString(0); } }).collect(); + for (String name: teenagerNames) { + System.out.println(name); + } + System.out.println("=== Data source: Parquet File ==="); // JavaSchemaRDDs can be saved as parquet files, maintaining the schema information. schemaPeople.saveAsParquetFile("people.parquet"); - // Read in the parquet file created above. 
Parquet files are self-describing so the schema is preserved. + // Read in the parquet file created above. + // Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a JavaSchemaRDD. JavaSchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerAsTable("parquetFile"); - JavaSchemaRDD teenagers2 = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); + JavaSchemaRDD teenagers2 = + sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); + teenagerNames = teenagers2.map(new Function() { + public String call(Row row) { + return "Name: " + row.getString(0); + } + }).collect(); + for (String name: teenagerNames) { + System.out.println(name); + } + + System.out.println("=== Data source: JSON Dataset ==="); + // A JSON dataset is pointed by path. + // The path can be either a single text file or a directory storing text files. + String path = "examples/src/main/resources/people.json"; + // Create a JavaSchemaRDD from the file(s) pointed by path + JavaSchemaRDD peopleFromJsonFile = sqlCtx.jsonFile(path); + + // Because the schema of a JSON dataset is automatically inferred, to write queries, + // it is better to take a look at what is the schema. + peopleFromJsonFile.printSchema(); + // The schema of people is ... + // root + // |-- age: IntegerType + // |-- name: StringType + + // Register this JavaSchemaRDD as a table. + peopleFromJsonFile.registerAsTable("people"); + + // SQL statements can be run by using the sql methods provided by sqlCtx. + JavaSchemaRDD teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + + // The results of SQL queries are JavaSchemaRDDs and support all the normal RDD operations. + // The columns of a row in the result can be accessed by ordinal. + teenagerNames = teenagers3.map(new Function() { + public String call(Row row) { return "Name: " + row.getString(0); } + }).collect(); + for (String name: teenagerNames) { + System.out.println(name); + } + + // Alternatively, a JavaSchemaRDD can be created for a JSON dataset represented by + // a RDD[String] storing one JSON object per string. + List jsonData = Arrays.asList( + "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); + JavaRDD anotherPeopleRDD = ctx.parallelize(jsonData); + JavaSchemaRDD peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD); + + // Take a look at the schema of this new JavaSchemaRDD. + peopleFromJsonRDD.printSchema(); + // The schema of anotherPeople is ... 
+ // root + // |-- address: StructType + // | |-- city: StringType + // | |-- state: StringType + // |-- name: StringType + + peopleFromJsonRDD.registerAsTable("people2"); + + JavaSchemaRDD peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); + List nameAndCity = peopleWithCity.map(new Function() { + public String call(Row row) { + return "Name: " + row.getString(0) + ", City: " + row.getString(1); + } + }).collect(); + for (String name: nameAndCity) { + System.out.println(name); + } } } diff --git a/examples/src/main/resources/people.json b/examples/src/main/resources/people.json new file mode 100644 index 0000000000000..50a859cbd7ee8 --- /dev/null +++ b/examples/src/main/resources/people.json @@ -0,0 +1,3 @@ +{"name":"Michael"} +{"name":"Andy", "age":30} +{"name":"Justin", "age":19} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 00d0b18c27a8d..1a0073c9d487e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -419,7 +419,7 @@ class RowMatrix( /** Updates or verifies the number of rows. */ private def updateNumRows(m: Long) { if (nRows <= 0) { - nRows == m + nRows = m } else { require(nRows == m, s"The number of rows $m is different from what specified or previously computed: ${nRows}.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 8f187c9df5102..7bbed9c8fdbef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -60,7 +60,7 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4. * Smaller value will lead to higher accuracy with the cost of more iterations. */ - def setConvergenceTol(tolerance: Int): this.type = { + def setConvergenceTol(tolerance: Double): this.type = { this.convergenceTol = tolerance this } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index 4b1850659a18e..fe7a9033cd5f4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -195,4 +195,38 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { assert(lossLBFGS3.length == 6) assert((lossLBFGS3(4) - lossLBFGS3(5)) / lossLBFGS3(4) < convergenceTol) } + + test("Optimize via class LBFGS.") { + val regParam = 0.2 + + // Prepare another non-zero weights to compare the loss in the first iteration. 
+ val initialWeightsWithIntercept = Vectors.dense(0.3, 0.12) + val convergenceTol = 1e-12 + val maxNumIterations = 10 + + val lbfgsOptimizer = new LBFGS(gradient, squaredL2Updater) + .setNumCorrections(numCorrections) + .setConvergenceTol(convergenceTol) + .setMaxNumIterations(maxNumIterations) + .setRegParam(regParam) + + val weightLBFGS = lbfgsOptimizer.optimize(dataRDD, initialWeightsWithIntercept) + + val numGDIterations = 50 + val stepSize = 1.0 + val (weightGD, _) = GradientDescent.runMiniBatchSGD( + dataRDD, + gradient, + squaredL2Updater, + stepSize, + numGDIterations, + regParam, + miniBatchFrac, + initialWeightsWithIntercept) + + // for class LBFGS and the optimize method, we only look at the weights + assert(compareDouble(weightLBFGS(0), weightGD(0), 0.02) && + compareDouble(weightLBFGS(1), weightGD(1), 0.02), + "The weight differences between LBFGS and GD should be within 2%.") + } } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 2d60a44f04f6f..7bb39dc77120b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -76,7 +76,7 @@ object SparkBuild extends Build { lazy val catalyst = Project("catalyst", file("sql/catalyst"), settings = catalystSettings) dependsOn(core) - lazy val sql = Project("sql", file("sql/core"), settings = sqlCoreSettings) dependsOn(core, catalyst) + lazy val sql = Project("sql", file("sql/core"), settings = sqlCoreSettings) dependsOn(core) dependsOn(catalyst % "compile->compile;test->test") lazy val hive = Project("hive", file("sql/hive"), settings = hiveSettings) dependsOn(sql) @@ -501,9 +501,23 @@ object SparkBuild extends Build { def sqlCoreSettings = sharedSettings ++ Seq( name := "spark-sql", libraryDependencies ++= Seq( - "com.twitter" % "parquet-column" % parquetVersion, - "com.twitter" % "parquet-hadoop" % parquetVersion - ) + "com.twitter" % "parquet-column" % parquetVersion, + "com.twitter" % "parquet-hadoop" % parquetVersion, + "com.fasterxml.jackson.core" % "jackson-databind" % "2.3.0" // json4s-jackson 3.2.6 requires jackson-databind 2.3.0. 
+ ), + initialCommands in console := + """ + |import org.apache.spark.sql.catalyst.analysis._ + |import org.apache.spark.sql.catalyst.dsl._ + |import org.apache.spark.sql.catalyst.errors._ + |import org.apache.spark.sql.catalyst.expressions._ + |import org.apache.spark.sql.catalyst.plans.logical._ + |import org.apache.spark.sql.catalyst.rules._ + |import org.apache.spark.sql.catalyst.types._ + |import org.apache.spark.sql.catalyst.util._ + |import org.apache.spark.sql.execution + |import org.apache.spark.sql.test.TestSQLContext._ + |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin ) // Since we don't include hive in the main assembly this project also acts as an alternative diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 91ae8263f66b8..19235d5f79f85 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -43,12 +43,19 @@ def launch_gateway(): # Don't send ctrl-c / SIGINT to the Java gateway: def preexec_func(): signal.signal(signal.SIGINT, signal.SIG_IGN) - proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func) + proc = Popen(command, stdout=PIPE, stdin=PIPE, stderr=PIPE, preexec_fn=preexec_func) else: # preexec_fn not supported on Windows - proc = Popen(command, stdout=PIPE, stdin=PIPE) - # Determine which ephemeral port the server started on: - gateway_port = int(proc.stdout.readline()) + proc = Popen(command, stdout=PIPE, stdin=PIPE, stderr=PIPE) + + try: + # Determine which ephemeral port the server started on: + gateway_port = int(proc.stdout.readline()) + except: + error_code = proc.poll() + raise Exception("Launching GatewayServer failed with exit code %d: %s" % + (error_code, "".join(proc.stderr.readlines()))) + # Create a thread to echo output from the GatewayServer, which is required # for Java log output to show up: class EchoOutputThread(Thread): diff --git a/python/pyspark/join.py b/python/pyspark/join.py index 6f94d26ef86a9..5f3a7e71f7866 100644 --- a/python/pyspark/join.py +++ b/python/pyspark/join.py @@ -79,15 +79,15 @@ def dispatch(seq): return _do_python_join(rdd, other, numPartitions, dispatch) -def python_cogroup(rdd, other, numPartitions): - vs = rdd.map(lambda (k, v): (k, (1, v))) - ws = other.map(lambda (k, v): (k, (2, v))) +def python_cogroup(rdds, numPartitions): + def make_mapper(i): + return lambda (k, v): (k, (i, v)) + vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)] + union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds) + rdd_len = len(vrdds) def dispatch(seq): - vbuf, wbuf = [], [] + bufs = [[] for i in range(rdd_len)] for (n, v) in seq: - if n == 1: - vbuf.append(v) - elif n == 2: - wbuf.append(v) - return (ResultIterable(vbuf), ResultIterable(wbuf)) - return vs.union(ws).groupByKey(numPartitions).mapValues(dispatch) + bufs[n].append(v) + return tuple(map(ResultIterable, bufs)) + return union_vrdds.groupByKey(numPartitions).mapValues(dispatch) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index ddd22850a819c..1d55c35a8bf48 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -512,7 +512,7 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc = lambda x: x): [('a', 3), ('fleece', 7), ('had', 2), ('lamb', 5), ('little', 4), ('Mary', 1), ('was', 8), ('white', 9), ('whose', 6)] """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._defaultReducePartitions() bounds = list() @@ -549,6 +549,18 @@ def mapFunc(iterator): 
.mapPartitions(mapFunc,preservesPartitioning=True) .flatMap(lambda x: x, preservesPartitioning=True)) + def sortBy(self, keyfunc, ascending=True, numPartitions=None): + """ + Sorts this RDD by the given keyfunc + + >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)] + >>> sc.parallelize(tmp).sortBy(lambda x: x[0]).collect() + [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)] + >>> sc.parallelize(tmp).sortBy(lambda x: x[1]).collect() + [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)] + """ + return self.keyBy(keyfunc).sortByKey(ascending, numPartitions).values() + def glom(self): """ Return an RDD created by coalescing all elements within each partition @@ -845,7 +857,7 @@ def top(self, num): Note: It returns the list sorted in descending order. >>> sc.parallelize([10, 4, 2, 12, 3]).top(1) [12] - >>> sc.parallelize([2, 3, 4, 5, 6], 2).cache().top(2) + >>> sc.parallelize([2, 3, 4, 5, 6], 2).top(2) [6, 5] """ def topIterator(iterator): @@ -1142,7 +1154,7 @@ def partitionBy(self, numPartitions, partitionFunc=None): set([]) """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._defaultReducePartitions() if partitionFunc is None: partitionFunc = lambda x: 0 if x is None else hash(x) @@ -1200,7 +1212,7 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, [('a', '11'), ('b', '1')] """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._defaultReducePartitions() def combineLocally(iterator): combiners = {} for x in iterator: @@ -1221,7 +1233,7 @@ def _mergeCombiners(iterator): combiners[k] = mergeCombiners(combiners[k], v) return combiners.iteritems() return shuffled.mapPartitions(_mergeCombiners) - + def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): """ Aggregate the values of each key, using given combine functions and a neutral "zero value". @@ -1233,7 +1245,7 @@ def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): """ def createZero(): return copy.deepcopy(zeroValue) - + return self.combineByKey(lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions) def foldByKey(self, zeroValue, func, numPartitions=None): @@ -1311,12 +1323,20 @@ def mapValues(self, f): map_values_fn = lambda (k, v): (k, f(v)) return self.map(map_values_fn, preservesPartitioning=True) - # TODO: support varargs cogroup of several RDDs. - def groupWith(self, other): + def groupWith(self, other, *others): """ - Alias for cogroup. + Alias for cogroup but with support for multiple RDDs. 
+ + >>> w = sc.parallelize([("a", 5), ("b", 6)]) + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2)]) + >>> z = sc.parallelize([("b", 42)]) + >>> map((lambda (x,y): (x, (list(y[0]), list(y[1]), list(y[2]), list(y[3])))), \ + sorted(list(w.groupWith(x, y, z).collect()))) + [('a', ([5], [1], [2], [])), ('b', ([6], [4], [], [42]))] + """ - return self.cogroup(other) + return python_cogroup((self, other) + others, numPartitions=None) # TODO: add variant with custom parittioner def cogroup(self, other, numPartitions=None): @@ -1330,7 +1350,7 @@ def cogroup(self, other, numPartitions=None): >>> map((lambda (x,y): (x, (list(y[0]), list(y[1])))), sorted(list(x.cogroup(y).collect()))) [('a', ([1], [2])), ('b', ([4], []))] """ - return python_cogroup(self, other, numPartitions) + return python_cogroup((self, other), numPartitions) def subtractByKey(self, other, numPartitions=None): """ @@ -1448,9 +1468,12 @@ def toDebugString(self): def getStorageLevel(self): """ Get the RDD's current storage level. + >>> rdd1 = sc.parallelize([1,2]) >>> rdd1.getStorageLevel() StorageLevel(False, False, False, False, 1) + >>> print(rdd1.getStorageLevel()) + Serialized 1x Replicated """ java_storage_level = self._jrdd.getStorageLevel() storage_level = StorageLevel(java_storage_level.useDisk(), @@ -1460,6 +1483,21 @@ def getStorageLevel(self): java_storage_level.replication()) return storage_level + def _defaultReducePartitions(self): + """ + Returns the default number of partitions to use during reduce tasks (e.g., groupBy). + If spark.default.parallelism is set, then we'll use the value from SparkContext + defaultParallelism, otherwise we'll use the number of partitions in this RDD. + + This mirrors the behavior of the Scala Partitioner#defaultPartitioner, intended to reduce + the likelihood of OOMs. Once PySpark adopts Partitioner-based APIs, this behavior will + be inherent. + """ + if self.ctx._conf.contains("spark.default.parallelism"): + return self.ctx.defaultParallelism + else: + return self.getNumPartitions() + # TODO: `lookup` is disabled because we can't make direct comparisons based # on the key; we need to compare the hash of the key to the hash of the # keys in the pairs. This could be an expensive operation, since those diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 960d0a82448aa..5051c82da32a7 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -15,7 +15,8 @@ # limitations under the License. # -from pyspark.rdd import RDD +from pyspark.rdd import RDD, PipelinedRDD +from pyspark.serializers import BatchedSerializer, PickleSerializer from py4j.protocol import Py4JError @@ -76,12 +77,25 @@ def inferSchema(self, rdd): """Infer and apply a schema to an RDD of L{dict}s. We peek at the first row of the RDD to determine the fields names - and types, and then use that to extract all the dictionaries. + and types, and then use that to extract all the dictionaries. Nested + collections are supported, which include array, dict, list, set, and + tuple. >>> srdd = sqlCtx.inferSchema(rdd) >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"}, ... {"field1" : 3, "field2": "row3"}] True + + >>> from array import array + >>> srdd = sqlCtx.inferSchema(nestedRdd1) + >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}}, + ... 
{"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}] + True + + >>> srdd = sqlCtx.inferSchema(nestedRdd2) + >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)}, + ... {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}] + True """ if (rdd.__class__ is SchemaRDD): raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__) @@ -123,6 +137,53 @@ def parquetFile(self, path): jschema_rdd = self._ssql_ctx.parquetFile(path) return SchemaRDD(jschema_rdd, self) + + def jsonFile(self, path): + """Loads a text file storing one JSON object per line, + returning the result as a L{SchemaRDD}. + It goes through the entire dataset once to determine the schema. + + >>> import tempfile, shutil + >>> jsonFile = tempfile.mkdtemp() + >>> shutil.rmtree(jsonFile) + >>> ofn = open(jsonFile, 'w') + >>> for json in jsonStrings: + ... print>>ofn, json + >>> ofn.close() + >>> srdd = sqlCtx.jsonFile(jsonFile) + >>> sqlCtx.registerRDDAsTable(srdd, "table1") + >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2, field3 as f3 from table1") + >>> srdd2.collect() == [{"f1": 1, "f2": "row1", "f3":{"field4":11}}, + ... {"f1": 2, "f2": "row2", "f3":{"field4":22}}, + ... {"f1": 3, "f2": "row3", "f3":{"field4":33}}] + True + """ + jschema_rdd = self._ssql_ctx.jsonFile(path) + return SchemaRDD(jschema_rdd, self) + + def jsonRDD(self, rdd): + """Loads an RDD storing one JSON object per string, returning the result as a L{SchemaRDD}. + It goes through the entire dataset once to determine the schema. + + >>> srdd = sqlCtx.jsonRDD(json) + >>> sqlCtx.registerRDDAsTable(srdd, "table1") + >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2, field3 as f3 from table1") + >>> srdd2.collect() == [{"f1": 1, "f2": "row1", "f3":{"field4":11}}, + ... {"f1": 2, "f2": "row2", "f3":{"field4":22}}, + ... {"f1": 3, "f2": "row3", "f3":{"field4":33}}] + True + """ + def func(split, iterator): + for x in iterator: + if not isinstance(x, basestring): + x = unicode(x) + yield x.encode("utf-8") + keyed = PipelinedRDD(rdd, func) + keyed._bypass_serializer = True + jrdd = keyed._jrdd.map(self._jvm.BytesToString()) + jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) + return SchemaRDD(jschema_rdd, self) + def sql(self, sqlQuery): """Return a L{SchemaRDD} representing the result of the given query. @@ -251,7 +312,7 @@ class SchemaRDD(RDD): For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the L{SchemaRDD} is not operated on directly, as it's underlying - implementation is a RDD composed of Java objects. Instead it is + implementation is an RDD composed of Java objects. Instead it is converted to a PythonRDD in the JVM, on which Python operations can be done. """ @@ -323,6 +384,14 @@ def saveAsTable(self, tableName): """Creates a new table with the contents of this SchemaRDD.""" self._jschema_rdd.saveAsTable(tableName) + def schemaString(self): + """Returns the output schema in the tree format.""" + return self._jschema_rdd.schemaString() + + def printSchema(self): + """Prints out the schema in the tree format.""" + print self.schemaString() + def count(self): """Return the number of elements in this RDD. @@ -346,7 +415,8 @@ def _toPython(self): # TODO: This is inefficient, we should construct the Python Row object # in Java land in the javaToPython function. 
May require a custom # pickle serializer in Pyrolite - return RDD(jrdd, self._sc, self._sc.serializer).map(lambda d: Row(d)) + return RDD(jrdd, self._sc, BatchedSerializer( + PickleSerializer())).map(lambda d: Row(d)) # We override the default cache/persist/checkpoint behavior as we want to cache the underlying # SchemaRDD object in the JVM, not the PythonRDD checkpointed by the super class @@ -411,6 +481,7 @@ def subtract(self, other, numPartitions=None): def _test(): import doctest + from array import array from pyspark.context import SparkContext globs = globals().copy() # The small batch size here ensures that we see multiple batches, @@ -420,6 +491,17 @@ def _test(): globs['sqlCtx'] = SQLContext(sc) globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}]) + jsonStrings = ['{"field1": 1, "field2": "row1", "field3":{"field4":11}}', + '{"field1" : 2, "field2": "row2", "field3":{"field4":22}}', + '{"field1" : 3, "field2": "row3", "field3":{"field4":33}}'] + globs['jsonStrings'] = jsonStrings + globs['json'] = sc.parallelize(jsonStrings) + globs['nestedRdd1'] = sc.parallelize([ + {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}}, + {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]) + globs['nestedRdd2'] = sc.parallelize([ + {"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)}, + {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]) (failure_count, test_count) = doctest.testmod(globs=globs,optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py index 7b6660eab231b..3a18ea54eae4c 100644 --- a/python/pyspark/storagelevel.py +++ b/python/pyspark/storagelevel.py @@ -36,6 +36,15 @@ def __repr__(self): return "StorageLevel(%s, %s, %s, %s, %s)" % ( self.useDisk, self.useMemory, self.useOffHeap, self.deserialized, self.replication) + def __str__(self): + result = "" + result += "Disk " if self.useDisk else "" + result += "Memory " if self.useMemory else "" + result += "Tachyon " if self.useOffHeap else "" + result += "Deserialized " if self.deserialized else "Serialized " + result += "%sx Replicated" % self.replication + return result + StorageLevel.DISK_ONLY = StorageLevel(True, False, False, False) StorageLevel.DISK_ONLY_2 = StorageLevel(True, False, False, False, 2) StorageLevel.MEMORY_ONLY = StorageLevel(False, True, False, True) diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 6c78c34486010..01d7b569080ea 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -66,6 +66,34 @@ org.scalatest scalatest-maven-plugin + + + + org.apache.maven.plugins + maven-jar-plugin + + + + test-jar + + + + test-jar-on-compile + compile + + test-jar + + + + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 46fcfbb9e26ba..0cc4592047b19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -66,43 +66,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected case class Keyword(str: String) protected implicit def asParser(k: Keyword): Parser[String] = - allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) - - protected class SqlLexical extends StdLexical { - case class FloatLit(chars: String) extends Token { - override def toString = chars - } - 
override lazy val token: Parser[Token] = ( - identChar ~ rep( identChar | digit ) ^^ - { case first ~ rest => processIdent(first :: rest mkString "") } - | rep1(digit) ~ opt('.' ~> rep(digit)) ^^ { - case i ~ None => NumericLit(i mkString "") - case i ~ Some(d) => FloatLit(i.mkString("") + "." + d.mkString("")) - } - | '\'' ~ rep( chrExcept('\'', '\n', EofCh) ) ~ '\'' ^^ - { case '\'' ~ chars ~ '\'' => StringLit(chars mkString "") } - | '\"' ~ rep( chrExcept('\"', '\n', EofCh) ) ~ '\"' ^^ - { case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") } - | EofCh ^^^ EOF - | '\'' ~> failure("unclosed string literal") - | '\"' ~> failure("unclosed string literal") - | delim - | failure("illegal character") - ) - - override def identChar = letter | elem('.') | elem('_') - - override def whitespace: Parser[Any] = rep( - whitespaceChar - | '/' ~ '*' ~ comment - | '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') ) - | '#' ~ rep( chrExcept(EofCh, '\n') ) - | '-' ~ '-' ~ rep( chrExcept(EofCh, '\n') ) - | '/' ~ '*' ~ failure("unclosed comment") - ) - } - - override val lexical = new SqlLexical + lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) protected val ALL = Keyword("ALL") protected val AND = Keyword("AND") @@ -161,24 +125,9 @@ class SqlParser extends StandardTokenParsers with PackratParsers { this.getClass .getMethods .filter(_.getReturnType == classOf[Keyword]) - .map(_.invoke(this).asInstanceOf[Keyword]) - - /** Generate all variations of upper and lower case of a given string */ - private def allCaseVersions(s: String, prefix: String = ""): Stream[String] = { - if (s == "") { - Stream(prefix) - } else { - allCaseVersions(s.tail, prefix + s.head.toLower) ++ - allCaseVersions(s.tail, prefix + s.head.toUpper) - } - } + .map(_.invoke(this).asInstanceOf[Keyword].str) - lexical.reserved ++= reservedWords.flatMap(w => allCaseVersions(w.str)) - - lexical.delimiters += ( - "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", - ",", ";", "%", "{", "}", ":", "[", "]" - ) + override val lexical = new SqlLexical(reservedWords) protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = { exprs.zipWithIndex.map { @@ -309,13 +258,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers { comparisonExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1,e2) }) protected lazy val comparisonExpression: Parser[Expression] = - termExpression ~ "=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Equals(e1, e2) } | + termExpression ~ "=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => EqualTo(e1, e2) } | termExpression ~ "<" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThan(e1, e2) } | termExpression ~ "<=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThanOrEqual(e1, e2) } | termExpression ~ ">" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThan(e1, e2) } | termExpression ~ ">=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThanOrEqual(e1, e2) } | - termExpression ~ "!=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(Equals(e1, e2)) } | - termExpression ~ "<>" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(Equals(e1, e2)) } | + termExpression ~ "!=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } | + termExpression ~ "<>" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } | termExpression ~ RLIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } | termExpression ~ REGEXP ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } | termExpression ~ LIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => Like(e1, 
e2) } | @@ -383,7 +332,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars) protected lazy val baseExpression: PackratParser[Expression] = - expression ~ "[" ~ expression <~ "]" ^^ { + expression ~ "[" ~ expression <~ "]" ^^ { case base ~ _ ~ ordinal => GetItem(base, ordinal) } | TRUE ^^^ Literal(true, BooleanType) | @@ -399,3 +348,55 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected lazy val dataType: Parser[DataType] = STRING ^^^ StringType } + +class SqlLexical(val keywords: Seq[String]) extends StdLexical { + case class FloatLit(chars: String) extends Token { + override def toString = chars + } + + reserved ++= keywords.flatMap(w => allCaseVersions(w)) + + delimiters += ( + "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", + ",", ";", "%", "{", "}", ":", "[", "]" + ) + + override lazy val token: Parser[Token] = ( + identChar ~ rep( identChar | digit ) ^^ + { case first ~ rest => processIdent(first :: rest mkString "") } + | rep1(digit) ~ opt('.' ~> rep(digit)) ^^ { + case i ~ None => NumericLit(i mkString "") + case i ~ Some(d) => FloatLit(i.mkString("") + "." + d.mkString("")) + } + | '\'' ~ rep( chrExcept('\'', '\n', EofCh) ) ~ '\'' ^^ + { case '\'' ~ chars ~ '\'' => StringLit(chars mkString "") } + | '\"' ~ rep( chrExcept('\"', '\n', EofCh) ) ~ '\"' ^^ + { case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") } + | EofCh ^^^ EOF + | '\'' ~> failure("unclosed string literal") + | '\"' ~> failure("unclosed string literal") + | delim + | failure("illegal character") + ) + + override def identChar = letter | elem('_') | elem('.') + + override def whitespace: Parser[Any] = rep( + whitespaceChar + | '/' ~ '*' ~ comment + | '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') ) + | '#' ~ rep( chrExcept(EofCh, '\n') ) + | '-' ~ '-' ~ rep( chrExcept(EofCh, '\n') ) + | '/' ~ '*' ~ failure("unclosed comment") + ) + + /** Generate all variations of upper and lower case of a given string */ + def allCaseVersions(s: String, prefix: String = ""): Stream[String] = { + if (s == "") { + Stream(prefix) + } else { + allCaseVersions(s.tail, prefix + s.head.toLower) ++ + allCaseVersions(s.tail, prefix + s.head.toUpper) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 326feea6fee91..76ddeba9cb312 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -22,8 +22,18 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.types._ +object HiveTypeCoercion { + // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. 
+ // The conversion for integral and floating point types have a linear widening hierarchy: + val numericPrecedence = + Seq(NullType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType) + // Boolean is only wider than Void + val booleanPrecedence = Seq(NullType, BooleanType) + val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: booleanPrecedence :: Nil +} + /** - * A collection of [[catalyst.rules.Rule Rules]] that can be used to coerce differing types that + * A collection of [[Rule Rules]] that can be used to coerce differing types that * participate in operations into compatible ones. Most of these rules are based on Hive semantics, * but they do not introduce any dependencies on the hive codebase. For this reason they remain in * Catalyst until we have a more standard set of coercions. @@ -31,12 +41,20 @@ import org.apache.spark.sql.catalyst.types._ trait HiveTypeCoercion { val typeCoercionRules = - List(PropagateTypes, ConvertNaNs, WidenTypes, PromoteStrings, BooleanComparisons, BooleanCasts, - StringToIntegralCasts, FunctionArgumentConversion) + PropagateTypes :: + ConvertNaNs :: + WidenTypes :: + PromoteStrings :: + BooleanComparisons :: + BooleanCasts :: + StringToIntegralCasts :: + FunctionArgumentConversion :: + CastNulls :: + Nil /** - * Applies any changes to [[catalyst.expressions.AttributeReference AttributeReference]] data - * types that are made by other rules to instances higher in the query tree. + * Applies any changes to [[AttributeReference]] data types that are made by other rules to + * instances higher in the query tree. */ object PropagateTypes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -108,19 +126,18 @@ trait HiveTypeCoercion { * * Additionally, all types when UNION-ed with strings will be promoted to strings. * Other string conversions are handled by PromoteStrings. + * + * Widening types might result in loss of precision in the following cases: + * - IntegerType to FloatType + * - LongType to FloatType + * - LongType to DoubleType */ object WidenTypes extends Rule[LogicalPlan] { - // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. - // The conversion for integral and floating point types have a linear widening hierarchy: - val numericPrecedence = - Seq(NullType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType) - // Boolean is only wider than Void - val booleanPrecedence = Seq(NullType, BooleanType) - val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: booleanPrecedence :: Nil def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { // Try and find a promotion rule that contains both types in question. - val applicableConversion = allPromotions.find(p => p.contains(t1) && p.contains(t2)) + val applicableConversion = + HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2)) // If found return the widest common type, otherwise None applicableConversion.map(_.filter(t => t == t1 || t == t2).last) @@ -217,8 +234,8 @@ trait HiveTypeCoercion { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - // No need to change Equals operators as that actually makes sense for boolean types. - case e: Equals => e + // No need to change EqualTo operators as that actually makes sense for boolean types. 
+ case e: EqualTo => e // Otherwise turn them to Byte types so that there exists and ordering. case p: BinaryComparison if p.left.dataType == BooleanType && p.right.dataType == BooleanType => @@ -227,15 +244,20 @@ trait HiveTypeCoercion { } /** - * Casts to/from [[catalyst.types.BooleanType BooleanType]] are transformed into comparisons since + * Casts to/from [[BooleanType]] are transformed into comparisons since * the JVM does not consider Booleans to be numeric types. */ object BooleanCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - - case Cast(e, BooleanType) => Not(Equals(e, Literal(0))) + // Skip if the type is boolean type already. Note that this extra cast should be removed + // by optimizer.SimplifyCasts. + case Cast(e, BooleanType) if e.dataType == BooleanType => e + // If the data type is not boolean and is being cast boolean, turn it into a comparison + // with the numeric value, i.e. x != 0. This will coerce the type into numeric type. + case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0))) + // Turn true into 1, and false into 0 if casting boolean into other types. case Cast(e, dataType) if e.dataType == BooleanType => Cast(If(e, Literal(1), Literal(0)), dataType) } @@ -282,4 +304,33 @@ trait HiveTypeCoercion { Average(Cast(e, DoubleType)) } } + + /** + * Ensures that NullType gets casted to some other types under certain circumstances. + */ + object CastNulls extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case cw @ CaseWhen(branches) => + val valueTypes = branches.sliding(2, 2).map { + case Seq(_, value) if value.resolved => Some(value.dataType) + case Seq(elseVal) if elseVal.resolved => Some(elseVal.dataType) + case _ => None + }.toSeq + if (valueTypes.distinct.size == 2 && valueTypes.exists(_ == Some(NullType))) { + val otherType = valueTypes.filterNot(_ == Some(NullType))(0).get + val transformedBranches = branches.sliding(2, 2).map { + case Seq(cond, value) if value.resolved && value.dataType == NullType => + Seq(cond, Cast(value, otherType)) + case Seq(elseVal) if elseVal.resolved && elseVal.dataType == NullType => + Seq(Cast(elseVal, otherType)) + case s => s + }.reduce(_ ++ _) + CaseWhen(transformedBranches) + } else { + // It is possible to have more types due to the possibility of short-circuiting. + cw + } + } + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index d177339d40ae5..26ad4837b0b01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.types._ * * // These unresolved attributes can be used to create more complicated expressions. * scala> 'a === 'b - * res2: org.apache.spark.sql.catalyst.expressions.Equals = ('a = 'b) + * res2: org.apache.spark.sql.catalyst.expressions.EqualTo = ('a = 'b) * * // SQL verbs can be used to construct logical query plans. 
* scala> import org.apache.spark.sql.catalyst.plans.logical._ @@ -76,8 +76,8 @@ package object dsl { def <= (other: Expression) = LessThanOrEqual(expr, other) def > (other: Expression) = GreaterThan(expr, other) def >= (other: Expression) = GreaterThanOrEqual(expr, other) - def === (other: Expression) = Equals(expr, other) - def !== (other: Expression) = Not(Equals(expr, other)) + def === (other: Expression) = EqualTo(expr, other) + def !== (other: Expression) = Not(EqualTo(expr, other)) def like(other: Expression) = Like(expr, other) def rlike(other: Expression) = RLike(expr, other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 4ebf6c4584b94..655d4a08fe93b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -68,7 +68,7 @@ class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] { } object BindReferences extends Logging { - def bindReference(expression: Expression, input: Seq[Attribute]): Expression = { + def bindReference[A <: Expression](expression: A, input: Seq[Attribute]): A = { expression.transform { case a: AttributeReference => attachTree(a, "Binding attribute") { val ordinal = input.indexWhere(_.exprId == a.exprId) @@ -83,6 +83,6 @@ object BindReferences extends Logging { BoundReference(ordinal, a) } } - } + }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible. } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 0b3a4e728ec54..1f9716e385e9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -24,72 +24,87 @@ import org.apache.spark.sql.catalyst.types._ /** Cast the child expression to the target data type. */ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { override def foldable = child.foldable - def nullable = (child.dataType, dataType) match { + + override def nullable = (child.dataType, dataType) match { case (StringType, _: NumericType) => true case (StringType, TimestampType) => true case _ => child.nullable } + override def toString = s"CAST($child, $dataType)" type EvaluatedType = Any - def nullOrCast[T](a: Any, func: T => Any): Any = if(a == null) { - null - } else { - func(a.asInstanceOf[T]) - } + // [[func]] assumes the input is no longer null because eval already does the null check. 
+ @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) // UDFToString - def castToString: Any => Any = child.dataType match { - case BinaryType => nullOrCast[Array[Byte]](_, new String(_, "UTF-8")) - case _ => nullOrCast[Any](_, _.toString) + private[this] def castToString: Any => Any = child.dataType match { + case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8")) + case _ => buildCast[Any](_, _.toString) } // BinaryConverter - def castToBinary: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, _.getBytes("UTF-8")) + private[this] def castToBinary: Any => Any = child.dataType match { + case StringType => buildCast[String](_, _.getBytes("UTF-8")) } // UDFToBoolean - def castToBoolean: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, _.length() != 0) - case TimestampType => nullOrCast[Timestamp](_, b => {(b.getTime() != 0 || b.getNanos() != 0)}) - case LongType => nullOrCast[Long](_, _ != 0) - case IntegerType => nullOrCast[Int](_, _ != 0) - case ShortType => nullOrCast[Short](_, _ != 0) - case ByteType => nullOrCast[Byte](_, _ != 0) - case DecimalType => nullOrCast[BigDecimal](_, _ != 0) - case DoubleType => nullOrCast[Double](_, _ != 0) - case FloatType => nullOrCast[Float](_, _ != 0) + private[this] def castToBoolean: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, _.length() != 0) + case TimestampType => + buildCast[Timestamp](_, b => b.getTime() != 0 || b.getNanos() != 0) + case LongType => + buildCast[Long](_, _ != 0) + case IntegerType => + buildCast[Int](_, _ != 0) + case ShortType => + buildCast[Short](_, _ != 0) + case ByteType => + buildCast[Byte](_, _ != 0) + case DecimalType => + buildCast[BigDecimal](_, _ != 0) + case DoubleType => + buildCast[Double](_, _ != 0) + case FloatType => + buildCast[Float](_, _ != 0) } // TimestampConverter - def castToTimestamp: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => { - // Throw away extra if more than 9 decimal places - val periodIdx = s.indexOf("."); - var n = s - if (periodIdx != -1) { - if (n.length() - periodIdx > 9) { + private[this] def castToTimestamp: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => { + // Throw away extra if more than 9 decimal places + val periodIdx = s.indexOf(".") + var n = s + if (periodIdx != -1 && n.length() - periodIdx > 9) { n = n.substring(0, periodIdx + 10) } - } - try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null} - }) - case BooleanType => nullOrCast[Boolean](_, b => new Timestamp((if(b) 1 else 0) * 1000)) - case LongType => nullOrCast[Long](_, l => new Timestamp(l * 1000)) - case IntegerType => nullOrCast[Int](_, i => new Timestamp(i * 1000)) - case ShortType => nullOrCast[Short](_, s => new Timestamp(s * 1000)) - case ByteType => nullOrCast[Byte](_, b => new Timestamp(b * 1000)) + try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null } + }) + case BooleanType => + buildCast[Boolean](_, b => new Timestamp((if (b) 1 else 0) * 1000)) + case LongType => + buildCast[Long](_, l => new Timestamp(l * 1000)) + case IntegerType => + buildCast[Int](_, i => new Timestamp(i * 1000)) + case ShortType => + buildCast[Short](_, s => new Timestamp(s * 1000)) + case ByteType => + buildCast[Byte](_, b => new Timestamp(b * 1000)) // TimestampWritable.decimalToTimestamp - case DecimalType => nullOrCast[BigDecimal](_, d => 
decimalToTimestamp(d)) + case DecimalType => + buildCast[BigDecimal](_, d => decimalToTimestamp(d)) // TimestampWritable.doubleToTimestamp - case DoubleType => nullOrCast[Double](_, d => decimalToTimestamp(d)) + case DoubleType => + buildCast[Double](_, d => decimalToTimestamp(d)) // TimestampWritable.floatToTimestamp - case FloatType => nullOrCast[Float](_, f => decimalToTimestamp(f)) + case FloatType => + buildCast[Float](_, f => decimalToTimestamp(f)) } - private def decimalToTimestamp(d: BigDecimal) = { + private[this] def decimalToTimestamp(d: BigDecimal) = { val seconds = d.longValue() val bd = (d - seconds) * 1000000000 val nanos = bd.intValue() @@ -104,85 +119,118 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { } // Timestamp to long, converting milliseconds to seconds - private def timestampToLong(ts: Timestamp) = ts.getTime / 1000 + private[this] def timestampToLong(ts: Timestamp) = ts.getTime / 1000 - private def timestampToDouble(ts: Timestamp) = { + private[this] def timestampToDouble(ts: Timestamp) = { // First part is the seconds since the beginning of time, followed by nanosecs. ts.getTime / 1000 + ts.getNanos.toDouble / 1000000000 } - def castToLong: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toLong catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1L else 0L) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t)) - case DecimalType => nullOrCast[BigDecimal](_, _.toLong) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) - } - - def castToInt: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toInt catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toInt) - case DecimalType => nullOrCast[BigDecimal](_, _.toInt) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) - } - - def castToShort: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toShort catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toShort else 0.toShort) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toShort) - case DecimalType => nullOrCast[BigDecimal](_, _.toShort) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort - } - - def castToByte: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toByte catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toByte else 0.toByte) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toByte) - case DecimalType => nullOrCast[BigDecimal](_, _.toByte) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte - } - - def castToDecimal: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try BigDecimal(s.toDouble) catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) BigDecimal(1) else BigDecimal(0)) + private[this] def castToLong: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toLong catch { + case _: NumberFormatException => null + }) + case BooleanType => + 
buildCast[Boolean](_, b => if (b) 1L else 0L) + case TimestampType => + buildCast[Timestamp](_, t => timestampToLong(t)) + case DecimalType => + buildCast[BigDecimal](_, _.toLong) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) + } + + private[this] def castToInt: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toInt catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) 1 else 0) + case TimestampType => + buildCast[Timestamp](_, t => timestampToLong(t).toInt) + case DecimalType => + buildCast[BigDecimal](_, _.toInt) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) + } + + private[this] def castToShort: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toShort catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) + case TimestampType => + buildCast[Timestamp](_, t => timestampToLong(t).toShort) + case DecimalType => + buildCast[BigDecimal](_, _.toShort) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort + } + + private[this] def castToByte: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toByte catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) + case TimestampType => + buildCast[Timestamp](_, t => timestampToLong(t).toByte) + case DecimalType => + buildCast[BigDecimal](_, _.toByte) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte + } + + private[this] def castToDecimal: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try BigDecimal(s.toDouble) catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) BigDecimal(1) else BigDecimal(0)) case TimestampType => // Note that we lose precision here. 
- nullOrCast[Timestamp](_, t => BigDecimal(timestampToDouble(t))) - case x: NumericType => b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)) - } - - def castToDouble: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toDouble catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1d else 0d) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t)) - case DecimalType => nullOrCast[BigDecimal](_, _.toDouble) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) - } - - def castToFloat: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toFloat catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1f else 0f) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toFloat) - case DecimalType => nullOrCast[BigDecimal](_, _.toFloat) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) + buildCast[Timestamp](_, t => BigDecimal(timestampToDouble(t))) + case x: NumericType => + b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)) + } + + private[this] def castToDouble: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toDouble catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) 1d else 0d) + case TimestampType => + buildCast[Timestamp](_, t => timestampToDouble(t)) + case DecimalType => + buildCast[BigDecimal](_, _.toDouble) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) + } + + private[this] def castToFloat: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toFloat catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) 1f else 0f) + case TimestampType => + buildCast[Timestamp](_, t => timestampToDouble(t).toFloat) + case DecimalType => + buildCast[BigDecimal](_, _.toFloat) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) } - private lazy val cast: Any => Any = dataType match { + private[this] lazy val cast: Any => Any = dataType match { case StringType => castToString case BinaryType => castToBinary case DecimalType => castToDecimal @@ -198,10 +246,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { override def eval(input: Row): Any = { val evaluated = child.eval(input) - if (evaluated == null) { - null - } else { - cast(evaluated) - } + if (evaluated == null) null else cast(evaluated) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 41398ff956edd..0411ce3aefda1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -28,21 +28,16 @@ abstract class Expression extends TreeNode[Expression] { /** The narrowest possible type that is produced when this expression is evaluated. */ type EvaluatedType <: Any - def dataType: DataType - /** * Returns true when an expression is a candidate for static evaluation before the query is * executed. 
* * The following conditions are used to determine suitability for constant folding: - * - A [[expressions.Coalesce Coalesce]] is foldable if all of its children are foldable - * - A [[expressions.BinaryExpression BinaryExpression]] is foldable if its both left and right - * child are foldable - * - A [[expressions.Not Not]], [[expressions.IsNull IsNull]], or - * [[expressions.IsNotNull IsNotNull]] is foldable if its child is foldable. - * - A [[expressions.Literal]] is foldable. - * - A [[expressions.Cast Cast]] or [[expressions.UnaryMinus UnaryMinus]] is foldable if its - * child is foldable. + * - A [[Coalesce]] is foldable if all of its children are foldable + * - A [[BinaryExpression]] is foldable if both its left and right children are foldable + * - A [[Not]], [[IsNull]], or [[IsNotNull]] is foldable if its child is foldable + * - A [[Literal]] is foldable + * - A [[Cast]] or [[UnaryMinus]] is foldable if its child is foldable */ def foldable: Boolean = false def nullable: Boolean @@ -53,12 +48,18 @@ abstract class Expression extends TreeNode[Expression] { /** * Returns `true` if this expression and all its children have been resolved to a specific schema - * and `false` if it is still contains any unresolved placeholders. Implementations of expressions + * and `false` if it still contains any unresolved placeholders. Implementations of expressions * should override this if the resolution of this type of expression involves more than just * the resolution of its children. */ lazy val resolved: Boolean = childrenResolved + /** + * Returns the [[DataType]] of the result of evaluating this expression. It is + * invalid to query the dataType of an unresolved expression (i.e., when `resolved` == false). + */ + def dataType: DataType + /** * Returns true if all the children of this expression have been resolved to a specific schema * and false if any still contains any unresolved placeholders.
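A quick illustration of the foldable contract documented above (not part of the patch; a minimal sketch using the Catalyst expression classes that appear elsewhere in this diff). Literal children make a BinaryExpression such as Add foldable, Cast inherits foldability from its child, and this is roughly what lets the ConstantFolding rule further down in this patch collapse such subtrees into a single Literal before execution:

    import org.apache.spark.sql.catalyst.expressions.{Add, Cast, Literal}
    import org.apache.spark.sql.catalyst.types.DoubleType

    val three = Add(Literal(1), Literal(2))    // both children are Literals, so Add is foldable
    assert(three.foldable)
    val asDouble = Cast(three, DoubleType)     // Cast of a foldable child is foldable
    assert(asDouble.foldable)
    // Roughly what ConstantFolding does with a foldable expression: evaluate it once
    // and keep only the resulting Literal(3.0, DoubleType).
    val folded = Literal(asDouble.eval(null), asDouble.dataType)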
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 77b5429bad432..74ae723686cfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -208,6 +208,9 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow { class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { + def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) = + this(ordering.map(BindReferences.bindReference(_, inputSchema))) + def compare(a: Row, b: Row): Int = { var i = 0 while (i < ordering.size) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index b6aeae92f8bec..5d3bb25ad568c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -50,6 +50,8 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { null } else { if (child.dataType.isInstanceOf[ArrayType]) { + // TODO: consider using Array[_] for ArrayType child to avoid + // boxing of primitives val baseValue = value.asInstanceOf[Seq[_]] val o = key.asInstanceOf[Int] if (o >= baseValue.size || o < 0) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index a8145c37c20fa..66ae22e95b60e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -103,7 +103,7 @@ case class Alias(child: Expression, name: String) * A reference to an attribute produced by another operator in the tree. * * @param name The name of this attribute, should only be used during analysis or for debugging. - * @param dataType The [[types.DataType DataType]] of this attribute. + * @param dataType The [[DataType]] of this attribute. * @param nullable True if null is a valid value for this attribute. * @param exprId A globally unique id used to check if different AttributeReferences refer to the * same attribute. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 573ec052f4266..b6f2451b52e1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -24,7 +24,7 @@ package org.apache.spark.sql.catalyst * expression, a [[NamedExpression]] in addition to the standard collection of expressions. * * ==Standard Expressions== - * A library of standard expressions (e.g., [[Add]], [[Equals]]), aggregates (e.g., SUM, COUNT), + * A library of standard expressions (e.g., [[Add]], [[EqualTo]]), aggregates (e.g., SUM, COUNT), * and other computations (e.g. UDFs). Each expression type is capable of determining its output * schema as a function of its children's output schema. 
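The new RowOrdering constructor above relies on the typed bindReference change in BoundAttribute.scala earlier in this diff: binding a SortOrder against the input schema must come back as a SortOrder, not a bare Expression, or the delegating this(...) call would not type-check. A minimal, hedged sketch of how the convenience constructor might be used (the attribute and rows here are invented for illustration):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types.IntegerType

    val a = AttributeReference("a", IntegerType, nullable = false)()
    // The second argument binds 'a' to ordinal 0 of the input rows via bindReference.
    val ordering = new RowOrdering(SortOrder(a, Ascending) :: Nil, Seq(a))
    val rows = Seq(new GenericRow(Array[Any](2)), new GenericRow(Array[Any](1)))
    rows.sorted(ordering).map(_.getInt(0))   // Seq(1, 2)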
* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index d111578530506..b63406b94a4a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.types.BooleanType @@ -53,7 +52,7 @@ trait PredicateHelper { * * For example consider a join between two relations R(a, b) and S(c, d). * - * `canEvaluate(Equals(a,b), R)` returns `true` where as `canEvaluate(Equals(a,c), R)` returns + * `canEvaluate(EqualTo(a,b), R)` returns `true` where as `canEvaluate(EqualTo(a,c), R)` returns * `false`. */ protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean = @@ -141,7 +140,7 @@ abstract class BinaryComparison extends BinaryPredicate { self: Product => } -case class Equals(left: Expression, right: Expression) extends BinaryComparison { +case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { def symbol = "=" override def eval(input: Row): Any = { val l = left.eval(input) @@ -202,3 +201,80 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def toString = s"if ($predicate) $trueValue else $falseValue" } + +// scalastyle:off +/** + * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". + * Refer to this link for the corresponding semantics: + * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions + * + * The other form of case statements "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END" gets + * translated to this form at parsing time. Namely, such a statement gets translated to + * "CASE WHEN a=b THEN c [WHEN a=d THEN e]* [ELSE f] END". + * + * Note that `branches` are considered in consecutive pairs (cond, val), and the optional last + * element is the value for the default catch-all case (if provided). Hence, `branches` consists of + * at least two elements, and can have an odd or even length. + */ +// scalastyle:on +case class CaseWhen(branches: Seq[Expression]) extends Expression { + type EvaluatedType = Any + def children = branches + def references = children.flatMap(_.references).toSet + def dataType = { + if (!resolved) { + throw new UnresolvedException(this, "cannot resolve due to differing types in some branches") + } + branches(1).dataType + } + + @transient private[this] lazy val branchesArr = branches.toArray + @transient private[this] lazy val predicates = + branches.sliding(2, 2).collect { case Seq(cond, _) => cond }.toSeq + @transient private[this] lazy val values = + branches.sliding(2, 2).collect { case Seq(_, value) => value }.toSeq + @transient private[this] lazy val elseValue = + if (branches.length % 2 == 0) None else Option(branches.last) + + override def nullable = { + // If no value is nullable and no elseValue is provided, the whole statement defaults to null. 
+ values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) + } + + override lazy val resolved = { + if (!childrenResolved) { + false + } else { + val allCondBooleans = predicates.forall(_.dataType == BooleanType) + val dataTypesEqual = values.map(_.dataType).distinct.size <= 1 + allCondBooleans && dataTypesEqual + } + } + + /** Written in imperative fashion for performance considerations. Same for CaseKeyWhen. */ + override def eval(input: Row): Any = { + val len = branchesArr.length + var i = 0 + // If all branches fail and an elseVal is not provided, the whole statement + // defaults to null, according to Hive's semantics. + var res: Any = null + while (i < len - 1) { + if (branchesArr(i).eval(input) == true) { + res = branchesArr(i + 1).eval(input) + return res + } + i += 2 + } + if (i == len - 1) { + res = branchesArr(i).eval(input) + } + res + } + + override def toString = { + "CASE" + branches.sliding(2, 2).map { + case Seq(cond, value) => s" WHEN $cond THEN $value" + case Seq(elseValue) => s" ELSE $elseValue" + }.mkString + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 25a347bec0e4c..b20b5de8c46eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -95,13 +95,13 @@ object ColumnPruning extends Rule[LogicalPlan] { Project(substitutedProjection, child) // Eliminate no-op Projects - case Project(projectList, child) if(child.output == projectList) => child + case Project(projectList, child) if child.output == projectList => child } } /** - * Replaces [[catalyst.expressions.Expression Expressions]] that can be statically evaluated with - * equivalent [[catalyst.expressions.Literal Literal]] values. This rule is more specific with + * Replaces [[Expression Expressions]] that can be statically evaluated with + * equivalent [[Literal]] values. This rule is more specific with * Null value propagation from bottom to top of the expression tree. */ object NullPropagation extends Rule[LogicalPlan] { @@ -110,8 +110,8 @@ object NullPropagation extends Rule[LogicalPlan] { case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType) case e @ Sum(Literal(c, _)) if c == 0 => Cast(Literal(0L), e.dataType) case e @ Average(Literal(c, _)) if c == 0 => Literal(0.0, e.dataType) - case e @ IsNull(c) if c.nullable == false => Literal(false, BooleanType) - case e @ IsNotNull(c) if c.nullable == false => Literal(true, BooleanType) + case e @ IsNull(c) if !c.nullable => Literal(false, BooleanType) + case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType) case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType) case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType) case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType) @@ -154,8 +154,8 @@ object NullPropagation extends Rule[LogicalPlan] { } /** - * Replaces [[catalyst.expressions.Expression Expressions]] that can be statically evaluated with - * equivalent [[catalyst.expressions.Literal Literal]] values. + * Replaces [[Expression Expressions]] that can be statically evaluated with + * equivalent [[Literal]] values. 
*/ object ConstantFolding extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -197,7 +197,7 @@ object BooleanSimplification extends Rule[LogicalPlan] { } /** - * Combines two adjacent [[catalyst.plans.logical.Filter Filter]] operators into one, merging the + * Combines two adjacent [[Filter]] operators into one, merging the * conditions into one conjunctive predicate. */ object CombineFilters extends Rule[LogicalPlan] { @@ -223,9 +223,8 @@ object SimplifyFilters extends Rule[LogicalPlan] { } /** - * Pushes [[catalyst.plans.logical.Filter Filter]] operators through - * [[catalyst.plans.logical.Project Project]] operators, in-lining any - * [[catalyst.expressions.Alias Aliases]] that were defined in the projection. + * Pushes [[Filter]] operators through [[Project]] operators, in-lining any [[Alias Aliases]] + * that were defined in the projection. * * This heuristic is valid assuming the expression evaluation cost is minimal. */ @@ -248,10 +247,10 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] { } /** - * Pushes down [[catalyst.plans.logical.Filter Filter]] operators where the `condition` can be + * Pushes down [[Filter]] operators where the `condition` can be * evaluated using only the attributes of the left or right side of a join. Other - * [[catalyst.plans.logical.Filter Filter]] conditions are moved into the `condition` of the - * [[catalyst.plans.logical.Join Join]]. + * [[Filter]] conditions are moved into the `condition` of the [[Join]]. + * * And also Pushes down the join filter, where the `condition` can be evaluated using only the * attributes of the left or right side of sub query when applicable. * @@ -345,8 +344,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { } /** - * Removes [[catalyst.expressions.Cast Casts]] that are unnecessary because the input is already - * the correct type. + * Removes [[Cast Casts]] that are unnecessary because the input is already the correct type. */ object SimplifyCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { @@ -355,7 +353,7 @@ object SimplifyCasts extends Rule[LogicalPlan] { } /** - * Combines two adjacent [[catalyst.plans.logical.Limit Limit]] operators into one, merging the + * Combines two adjacent [[Limit]] operators into one, merging the * expressions into one single expression. */ object CombineLimits extends Rule[LogicalPlan] { @@ -366,7 +364,7 @@ object CombineLimits extends Rule[LogicalPlan] { } /** - * Removes the inner [[catalyst.expressions.CaseConversionExpression]] that are unnecessary because + * Removes the inner [[CaseConversionExpression]] that are unnecessary because * the inner conversion is overwritten by the outer one. 
*/ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 820ecfb78b52e..a43bef389c4bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -136,14 +136,14 @@ object HashFilteredJoin extends Logging with PredicateHelper { val Join(left, right, joinType, _) = join val (joinPredicates, otherPredicates) = allPredicates.flatMap(splitConjunctivePredicates).partition { - case Equals(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) || + case EqualTo(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) || (canEvaluate(l, right) && canEvaluate(r, left)) => true case _ => false } val joinKeys = joinPredicates.map { - case Equals(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => (l, r) - case Equals(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => (r, l) + case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => (l, r) + case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => (r, l) } // Do not consider this strategy if there are no join keys. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 8199a80f5d6bd..7b82e19b2e714 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.types.{ArrayType, DataType, StructField, StructType} abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] { self: PlanType with Product => @@ -123,4 +124,53 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy case other => Nil }.toSeq } + + protected def generateSchemaString(schema: Seq[Attribute]): String = { + val builder = new StringBuilder + builder.append("root\n") + val prefix = " |" + schema.foreach { attribute => + val name = attribute.name + val dataType = attribute.dataType + dataType match { + case fields: StructType => + builder.append(s"$prefix-- $name: $StructType\n") + generateSchemaString(fields, s"$prefix |", builder) + case ArrayType(fields: StructType) => + builder.append(s"$prefix-- $name: $ArrayType[$StructType]\n") + generateSchemaString(fields, s"$prefix |", builder) + case ArrayType(elementType: DataType) => + builder.append(s"$prefix-- $name: $ArrayType[$elementType]\n") + case _ => builder.append(s"$prefix-- $name: $dataType\n") + } + } + + builder.toString() + } + + protected def generateSchemaString( + schema: StructType, + prefix: String, + builder: StringBuilder): StringBuilder = { + schema.fields.foreach { + case StructField(name, fields: StructType, _) => + builder.append(s"$prefix-- $name: $StructType\n") + generateSchemaString(fields, s"$prefix |", builder) + case StructField(name, ArrayType(fields: StructType), _) => + builder.append(s"$prefix-- $name: $ArrayType[$StructType]\n") + generateSchemaString(fields, s"$prefix |", builder) + case StructField(name, ArrayType(elementType: 
DataType), _) => + builder.append(s"$prefix-- $name: $ArrayType[$elementType]\n") + case StructField(name, fieldType: DataType, _) => + builder.append(s"$prefix-- $name: $fieldType\n") + } + + builder + } + + /** Returns the output schema in the tree format. */ + def schemaString: String = generateSchemaString(output) + + /** Prints out the schema in the tree format */ + def printSchema(): Unit = println(schemaString) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 0933a31c362d8..edc37e3877c0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -41,19 +41,19 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { /** * Returns true if this expression and all its children have been resolved to a specific schema * and false if it is still contains any unresolved placeholders. Implementations of LogicalPlan - * can override this (e.g. [[catalyst.analysis.UnresolvedRelation UnresolvedRelation]] should - * return `false`). + * can override this (e.g. + * [[org.apache.spark.sql.catalyst.analysis.UnresolvedRelation UnresolvedRelation]] + * should return `false`). */ lazy val resolved: Boolean = !expressions.exists(!_.resolved) && childrenResolved /** * Returns true if all its children of this query plan have been resolved. */ - def childrenResolved = !children.exists(!_.resolved) + def childrenResolved: Boolean = !children.exists(!_.resolved) /** - * Optionally resolves the given string to a - * [[catalyst.expressions.NamedExpression NamedExpression]]. The attribute is expressed as + * Optionally resolves the given string to a [[NamedExpression]]. The attribute is expressed as * as string in the following form: `[scope].AttributeName.[nested].[fields]...`. */ def resolve(name: String): Option[NamedExpression] = { @@ -93,7 +93,7 @@ abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] { self: Product => // Leaf nodes by definition cannot reference any input attributes. - def references = Set.empty + override def references = Set.empty } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index b777cf4249196..3e0639867b278 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -27,7 +27,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend } /** - * Applies a [[catalyst.expressions.Generator Generator]] to a stream of input rows, combining the + * Applies a [[Generator]] to a stream of input rows, combining the * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional * programming with one important additional feature, which allows the input rows to be joined with * their output. 
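Referring back to the schemaString / printSchema helpers added to QueryPlan above, the following is a hedged sketch (not part of the patch) of the tree format they produce; the attribute names are made up for illustration.

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.types.{IntegerType, StringType}

// Hypothetical output attributes of some QueryPlan node.
val output = Seq(
  AttributeReference("name", StringType, nullable = true)(),
  AttributeReference("age", IntegerType, nullable = true)())

// For a plan whose `output` is the sequence above, plan.schemaString (and
// printSchema(), which simply prints it) renders:
// root
//  |-- name: StringType
//  |-- age: IntegerType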
@@ -46,32 +46,32 @@ case class Generate( child: LogicalPlan) extends UnaryNode { - protected def generatorOutput = + protected def generatorOutput: Seq[Attribute] = alias .map(a => generator.output.map(_.withQualifiers(a :: Nil))) .getOrElse(generator.output) - def output = + override def output = if (join) child.output ++ generatorOutput else generatorOutput - def references = + override def references = if (join) child.outputSet else generator.references } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { - def output = child.output - def references = condition.references + override def output = child.output + override def references = condition.references } case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { // TODO: These aren't really the same attributes as nullability etc might change. - def output = left.output + override def output = left.output override lazy val resolved = childrenResolved && !left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType } - def references = Set.empty + override def references = Set.empty } case class Join( @@ -80,8 +80,8 @@ case class Join( joinType: JoinType, condition: Option[Expression]) extends BinaryNode { - def references = condition.map(_.references).getOrElse(Set.empty) - def output = joinType match { + override def references = condition.map(_.references).getOrElse(Set.empty) + override def output = joinType match { case LeftSemi => left.output case _ => @@ -96,9 +96,9 @@ case class InsertIntoTable( overwrite: Boolean) extends LogicalPlan { // The table being inserted into is a child for the purposes of transformations. - def children = table :: child :: Nil - def references = Set.empty - def output = child.output + override def children = table :: child :: Nil + override def references = Set.empty + override def output = child.output override lazy val resolved = childrenResolved && child.output.zip(table.output).forall { case (childAttr, tableAttr) => childAttr.dataType == tableAttr.dataType @@ -109,20 +109,20 @@ case class InsertIntoCreatedTable( databaseName: Option[String], tableName: String, child: LogicalPlan) extends UnaryNode { - def references = Set.empty - def output = child.output + override def references = Set.empty + override def output = child.output } case class WriteToFile( path: String, child: LogicalPlan) extends UnaryNode { - def references = Set.empty - def output = child.output + override def references = Set.empty + override def output = child.output } case class Sort(order: Seq[SortOrder], child: LogicalPlan) extends UnaryNode { - def output = child.output - def references = order.flatMap(_.references).toSet + override def output = child.output + override def references = order.flatMap(_.references).toSet } case class Aggregate( @@ -131,18 +131,19 @@ case class Aggregate( child: LogicalPlan) extends UnaryNode { - def output = aggregateExpressions.map(_.toAttribute) - def references = (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet + override def output = aggregateExpressions.map(_.toAttribute) + override def references = + (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet } case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { - def output = child.output - def references = limitExpr.references + override def output = child.output + override def references = limitExpr.references } case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { - def output = 
child.output.map(_.withQualifiers(alias :: Nil)) - def references = Set.empty + override def output = child.output.map(_.withQualifiers(alias :: Nil)) + override def references = Set.empty } /** @@ -159,7 +160,7 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode { case otherType => otherType } - val output = child.output.map { + override val output = child.output.map { case a: AttributeReference => AttributeReference( a.name.toLowerCase, @@ -170,21 +171,21 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode { case other => other } - def references = Set.empty + override def references = Set.empty } case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan) extends UnaryNode { - def output = child.output - def references = Set.empty + override def output = child.output + override def references = Set.empty } case class Distinct(child: LogicalPlan) extends UnaryNode { - def output = child.output - def references = child.outputSet + override def output = child.output + override def references = child.outputSet } case object NoRelation extends LeafNode { - def output = Nil + override def output = Nil } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index 3299e86b85941..1d5f033f0d274 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -60,3 +60,19 @@ case class ExplainCommand(plan: LogicalPlan) extends Command { * Returned for the "CACHE TABLE tableName" and "UNCACHE TABLE tableName" command. */ case class CacheCommand(tableName: String, doCache: Boolean) extends Command + +/** + * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command. + * @param table The table to be described. + * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false. + * It is effective only when the table is a Hive table. + */ +case class DescribeCommand( + table: LogicalPlan, + isExtended: Boolean) extends Command { + override def output = Seq( + // Column names are based on Hive. + BoundReference(0, AttributeReference("col_name", StringType, nullable = false)()), + BoundReference(1, AttributeReference("data_type", StringType, nullable = false)()), + BoundReference(2, AttributeReference("comment", StringType, nullable = false)())) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index ffb3a92f8f340..4bb022cf238af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -46,7 +46,7 @@ case object AllTuples extends Distribution /** * Represents data where tuples that share the same values for the `clustering` - * [[catalyst.expressions.Expression Expressions]] will be co-located. Based on the context, this + * [[Expression Expressions]] will be co-located. Based on the context, this * can mean such tuples are either co-located in the same partition or they will be contiguous * within a single partition. 
*/ @@ -60,7 +60,7 @@ case class ClusteredDistribution(clustering: Seq[Expression]) extends Distributi /** * Represents data where tuples have been ordered according to the `ordering` - * [[catalyst.expressions.Expression Expressions]]. This is a strictly stronger guarantee than + * [[Expression Expressions]]. This is a strictly stronger guarantee than * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the same value for * the ordering expressions are contiguous and will never be split across partitions. */ @@ -79,19 +79,17 @@ sealed trait Partitioning { val numPartitions: Int /** - * Returns true iff the guarantees made by this - * [[catalyst.plans.physical.Partitioning Partitioning]] are sufficient to satisfy - * the partitioning scheme mandated by the `required` - * [[catalyst.plans.physical.Distribution Distribution]], i.e. the current dataset does not - * need to be re-partitioned for the `required` Distribution (it is possible that tuples within - * a partition need to be reorganized). + * Returns true iff the guarantees made by this [[Partitioning]] are sufficient + * to satisfy the partitioning scheme mandated by the `required` [[Distribution]], + * i.e. the current dataset does not need to be re-partitioned for the `required` + * Distribution (it is possible that tuples within a partition need to be reorganized). */ def satisfies(required: Distribution): Boolean /** * Returns true iff all distribution guarantees made by this partitioning can also be made * for the `other` specified partitioning. - * For example, two [[catalyst.plans.physical.HashPartitioning HashPartitioning]]s are + * For example, two [[HashPartitioning HashPartitioning]]s are * only compatible if the `numPartitions` of them is the same. */ def compatibleWith(other: Partitioning): Boolean diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 0369129393a08..cd04bdf02cf84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -187,6 +187,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { arg } + case Some(arg: TreeNode[_]) if children contains arg => + val newChild = arg.asInstanceOf[BaseType].transformDown(rule) + if (!(newChild fastEquals arg)) { + changed = true + Some(newChild) + } else { + Some(arg) + } case m: Map[_,_] => m case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => @@ -231,6 +239,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { arg } + case Some(arg: TreeNode[_]) if children contains arg => + val newChild = arg.asInstanceOf[BaseType].transformUp(rule) + if (!(newChild fastEquals arg)) { + changed = true + Some(newChild) + } else { + Some(arg) + } case m: Map[_,_] => m case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => @@ -273,7 +289,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } catch { case e: java.lang.IllegalArgumentException => throw new TreeNodeException( - this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName?") + this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName? 
" + + s"Exception message: ${e.getMessage}.") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index da34bd3a21503..bb77bccf86176 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -19,9 +19,71 @@ package org.apache.spark.sql.catalyst.types import java.sql.Timestamp -import scala.reflect.runtime.universe.{typeTag, TypeTag} +import scala.util.parsing.combinator.RegexParsers -import org.apache.spark.sql.catalyst.expressions.Expression +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror} + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} +import org.apache.spark.util.Utils + +/** + * + */ +object DataType extends RegexParsers { + protected lazy val primitiveType: Parser[DataType] = + "StringType" ^^^ StringType | + "FloatType" ^^^ FloatType | + "IntegerType" ^^^ IntegerType | + "ByteType" ^^^ ByteType | + "ShortType" ^^^ ShortType | + "DoubleType" ^^^ DoubleType | + "LongType" ^^^ LongType | + "BinaryType" ^^^ BinaryType | + "BooleanType" ^^^ BooleanType | + "DecimalType" ^^^ DecimalType | + "TimestampType" ^^^ TimestampType + + protected lazy val arrayType: Parser[DataType] = + "ArrayType" ~> "(" ~> dataType <~ ")" ^^ ArrayType + + protected lazy val mapType: Parser[DataType] = + "MapType" ~> "(" ~> dataType ~ "," ~ dataType <~ ")" ^^ { + case t1 ~ _ ~ t2 => MapType(t1, t2) + } + + protected lazy val structField: Parser[StructField] = + ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { + case name ~ tpe ~ nullable => + StructField(name, tpe, nullable = nullable) + } + + protected lazy val boolVal: Parser[Boolean] = + "true" ^^^ true | + "false" ^^^ false + + + protected lazy val structType: Parser[DataType] = + "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { + case fields => new StructType(fields) + } + + protected lazy val dataType: Parser[DataType] = + arrayType | + mapType | + structType | + primitiveType + + /** + * Parses a string representation of a DataType. + * + * TODO: Generate parser as pickler... 
+ */ + def apply(asString: String): DataType = parseAll(dataType, asString) match { + case Success(result, _) => result + case failure: NoSuccess => sys.error(s"Unsupported dataType: $asString, $failure") + } +} abstract class DataType { /** Matches any expression that evaluates to this DataType */ @@ -29,25 +91,36 @@ abstract class DataType { case e: Expression if e.dataType == this => true case _ => false } + + def isPrimitive: Boolean = false } case object NullType extends DataType +trait PrimitiveType extends DataType { + override def isPrimitive = true +} + abstract class NativeType extends DataType { type JvmType @transient val tag: TypeTag[JvmType] val ordering: Ordering[JvmType] + + @transient val classTag = { + val mirror = runtimeMirror(Utils.getSparkClassLoader) + ClassTag[JvmType](mirror.runtimeClass(tag.tpe)) + } } -case object StringType extends NativeType { +case object StringType extends NativeType with PrimitiveType { type JvmType = String @transient lazy val tag = typeTag[JvmType] val ordering = implicitly[Ordering[JvmType]] } -case object BinaryType extends DataType { +case object BinaryType extends DataType with PrimitiveType { type JvmType = Array[Byte] } -case object BooleanType extends NativeType { +case object BooleanType extends NativeType with PrimitiveType { type JvmType = Boolean @transient lazy val tag = typeTag[JvmType] val ordering = implicitly[Ordering[JvmType]] @@ -63,7 +136,7 @@ case object TimestampType extends NativeType { } } -abstract class NumericType extends NativeType { +abstract class NumericType extends NativeType with PrimitiveType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). 
This gets @@ -154,6 +227,17 @@ case object FloatType extends FractionalType { case class ArrayType(elementType: DataType) extends DataType case class StructField(name: String, dataType: DataType, nullable: Boolean) -case class StructType(fields: Seq[StructField]) extends DataType + +object StructType { + def fromAttributes(attributes: Seq[Attribute]): StructType = { + StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable))) + } + + // def apply(fields: Seq[StructField]) = new StructType(fields.toIndexedSeq) +} + +case class StructType(fields: Seq[StructField]) extends DataType { + def toAttributes = fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) +} case class MapType(keyType: DataType, valueType: DataType) extends DataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 49fc4f70fdfae..d8da45ae70c4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -115,7 +115,7 @@ package object util { } /* FIX ME - implicit class debugLogging(a: AnyRef) { + implicit class debugLogging(a: Any) { def debugLogging() { org.apache.log4j.Logger.getLogger(a.getClass.getName).setLevel(org.apache.log4j.Level.DEBUG) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 1132a30b42767..84d72814778ba 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -35,7 +35,7 @@ class ExpressionEvaluationSuite extends FunSuite { /** * Checks for three-valued-logic. Based on: * http://en.wikipedia.org/wiki/Null_(SQL)#Comparisons_with_NULL_and_the_three-valued_logic_.283VL.29 - * + * I.e. in flat cpo "False -> Unknown -> True", OR is lowest upper bound, AND is greatest lower bound. 
* p q p OR q p AND q p = q * True True True True True * True False True False False @@ -333,6 +333,49 @@ class ExpressionEvaluationSuite extends FunSuite { Literal("^Ba*n", StringType) :: c2 :: Nil), true, row) } + test("case when") { + val row = new GenericRow(Array[Any](null, false, true, "a", "b", "c")) + val c1 = 'a.boolean.at(0) + val c2 = 'a.boolean.at(1) + val c3 = 'a.boolean.at(2) + val c4 = 'a.string.at(3) + val c5 = 'a.string.at(4) + val c6 = 'a.string.at(5) + + checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row) + checkEvaluation(CaseWhen(Seq(Literal(null, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal(false, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal(true, BooleanType), c4, c6)), "a", row) + + checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row) + checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row) + checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row) + + assert(CaseWhen(Seq(c2, c4, c6)).nullable === true) + assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true) + assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true) + + val c4_notNull = 'a.boolean.notNull.at(3) + val c5_notNull = 'a.boolean.notNull.at(4) + val c6_notNull = 'a.boolean.notNull.at(5) + + assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false) + assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true) + + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false) + assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable === true) + + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true) + } + test("complex type") { val row = new GenericRow(Array[Any]( "^Ba*n", // 0 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index 714f01843c0f5..4896f1b955f01 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ -class CombiningLimitsSuite extends OptimizerTest { +class CombiningLimitsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 6efc0e211eb21..0ff82064012a8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.types._ @@ -27,7 +28,7 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ -class ConstantFoldingSuite extends OptimizerTest { +class ConstantFoldingSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = @@ -194,8 +195,8 @@ class ConstantFoldingSuite extends OptimizerTest { Add(Literal(null, IntegerType), 1) as 'c9, Add(1, Literal(null, IntegerType)) as 'c10, - Equals(Literal(null, IntegerType), 1) as 'c11, - Equals(1, Literal(null, IntegerType)) as 'c12, + EqualTo(Literal(null, IntegerType), 1) as 'c11, + EqualTo(1, Literal(null, IntegerType)) as 'c12, Like(Literal(null, StringType), "abc") as 'c13, Like("abc", Literal(null, StringType)) as 'c14, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 1f67c80e54906..ebb123c1f909e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -20,13 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.LeftOuter -import org.apache.spark.sql.catalyst.plans.RightOuter +import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ -class FilterPushdownSuite extends OptimizerTest { +class FilterPushdownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala index df1409fe7baee..22992fb6f50d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules._ /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -class SimplifyCaseConversionExpressionsSuite extends OptimizerTest { +class SimplifyCaseConversionExpressionsSuite extends PlanTest { object 
Optimize extends RuleExecutor[LogicalPlan] { val batches = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala similarity index 88% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerTest.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 89982d5cd8d74..7e9f47ef21df8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -15,19 +15,18 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.optimizer +package org.apache.spark.sql.catalyst.plans import org.scalatest.FunSuite -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ /** - * Provides helper methods for comparing plans produced by optimization rules with the expected - * result + * Provides helper methods for comparing plans. */ -class OptimizerTest extends FunSuite { +class PlanTest extends FunSuite { /** * Since attribute references are given globally unique ids during analysis, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 1ddc41a731ff5..6344874538d67 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -22,6 +22,17 @@ import scala.collection.mutable.ArrayBuffer import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types.{StringType, NullType} + +case class Dummy(optKey: Option[Expression]) extends Expression { + def children = optKey.toSeq + def references = Set.empty[Attribute] + def nullable = true + def dataType = NullType + override lazy val resolved = true + type EvaluatedType = Any + def eval(input: Row) = null.asInstanceOf[Any] +} class TreeNodeSuite extends FunSuite { test("top node changed") { @@ -75,4 +86,20 @@ class TreeNodeSuite extends FunSuite { assert(expected === actual) } + + test("transform works on nodes with Option children") { + val dummy1 = Dummy(Some(Literal("1", StringType))) + val dummy2 = Dummy(None) + val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } + + var actual = dummy1 transformDown toZero + assert(actual === Dummy(Some(Literal(0)))) + + actual = dummy1 transformUp toZero + assert(actual === Dummy(Some(Literal(0)))) + + actual = dummy2 transform toZero + assert(actual === Dummy(None)) + } + } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index e65ca6be485e3..8210fd1f210d1 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -43,6 +43,13 @@ spark-catalyst_${scala.binary.version} ${project.version} + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + test-jar + test + com.twitter parquet-column @@ -53,6 +60,11 @@ parquet-hadoop ${parquet.version} + + com.fasterxml.jackson.core + jackson-databind + 2.3.0 + org.scalatest scalatest_${scala.binary.version} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 378ff54531118..ab376e5504d35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -22,24 +22,22 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration -import org.apache.spark.SparkContext import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental} import org.apache.spark.rdd.RDD - import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.{ScalaReflection, dsl} +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.dsl.ExpressionConversions import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor - import org.apache.spark.sql.columnar.InMemoryRelation - import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.SparkStrategies - +import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.SparkContext /** * :: AlphaComponent :: @@ -53,7 +51,7 @@ import org.apache.spark.sql.parquet.ParquetRelation class SQLContext(@transient val sparkContext: SparkContext) extends Logging with SQLConf - with dsl.ExpressionConversions + with ExpressionConversions with Serializable { self => @@ -96,7 +94,40 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ def parquetFile(path: String): SchemaRDD = - new SchemaRDD(this, parquet.ParquetRelation(path)) + new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration))) + + /** + * Loads a JSON file (one object per line), returning the result as a [[SchemaRDD]]. + * It goes through the entire dataset once to determine the schema. + * + * @group userf + */ + def jsonFile(path: String): SchemaRDD = jsonFile(path, 1.0) + + /** + * :: Experimental :: + */ + @Experimental + def jsonFile(path: String, samplingRatio: Double): SchemaRDD = { + val json = sparkContext.textFile(path) + jsonRDD(json, samplingRatio) + } + + /** + * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a + * [[SchemaRDD]]. + * It goes through the entire dataset once to determine the schema. + * + * @group userf + */ + def jsonRDD(json: RDD[String]): SchemaRDD = jsonRDD(json, 1.0) + + /** + * :: Experimental :: + */ + @Experimental + def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = + new SchemaRDD(this, JsonRDD.inferSchema(json, samplingRatio)) /** * :: Experimental :: @@ -276,6 +307,8 @@ class SQLContext(@transient val sparkContext: SparkContext) lazy val optimizedPlan = optimizer(analyzed) // TODO: Don't just pick the first one... lazy val sparkPlan = planner(optimizedPlan).next() + // executedPlan should not be used to initialize any SparkPlan. It should be + // only used for execution. lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) /** Internal version of the RDD. Avoids copies and has no schema */ @@ -298,19 +331,28 @@ class SQLContext(@transient val sparkContext: SparkContext) /** * Peek at the first row of the RDD and infer its schema. - * TODO: We only support primitive types, add support for nested types. + * TODO: consolidate this with the type system developed in SPARK-2060. 
*/ private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = { + import scala.collection.JavaConversions._ + def typeFor(obj: Any): DataType = obj match { + case c: java.lang.String => StringType + case c: java.lang.Integer => IntegerType + case c: java.lang.Long => LongType + case c: java.lang.Double => DoubleType + case c: java.lang.Boolean => BooleanType + case c: java.util.List[_] => ArrayType(typeFor(c.head)) + case c: java.util.Set[_] => ArrayType(typeFor(c.head)) + case c: java.util.Map[_, _] => + val (key, value) = c.head + MapType(typeFor(key), typeFor(value)) + case c if c.getClass.isArray => + val elem = c.asInstanceOf[Array[_]].head + ArrayType(typeFor(elem)) + case c => throw new Exception(s"Object of type $c cannot be used") + } val schema = rdd.first().map { case (fieldName, obj) => - val dataType = obj.getClass match { - case c: Class[_] if c == classOf[java.lang.String] => StringType - case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType - case c: Class[_] if c == classOf[java.lang.Long] => LongType - case c: Class[_] if c == classOf[java.lang.Double] => DoubleType - case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType - case c => throw new Exception(s"Object of type $c cannot be used") - } - AttributeReference(fieldName, dataType, true)() + AttributeReference(fieldName, typeFor(obj), true)() }.toSeq val rowRdd = rdd.mapPartitions { iter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 821ac850ac3f5..7c0efb4566610 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.catalyst.types.BooleanType +import org.apache.spark.sql.catalyst.types.{DataType, StructType, BooleanType} import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} import org.apache.spark.api.java.JavaRDD import java.util.{Map => JMap} @@ -41,8 +41,10 @@ import java.util.{Map => JMap} * whose elements are scala case classes into a SchemaRDD. This conversion can also be done * explicitly using the `createSchemaRDD` function on a [[SQLContext]]. * - * A `SchemaRDD` can also be created by loading data in from external sources, for example, - * by using the `parquetFile` method on [[SQLContext]]. + * A `SchemaRDD` can also be created by loading data in from external sources. + * Examples are loading data from Parquet files by using by using the + * `parquetFile` method on [[SQLContext]], and loading JSON datasets + * by using `jsonFile` and `jsonRDD` methods on [[SQLContext]]. * * == SQL Queries == * A SchemaRDD can be registered as a table in the [[SQLContext]] that was used to create it. Once @@ -341,22 +343,41 @@ class SchemaRDD( */ def toJavaSchemaRDD: JavaSchemaRDD = new JavaSchemaRDD(sqlContext, logicalPlan) + /** + * Converts a JavaRDD to a PythonRDD. It is used by pyspark. 
+ */ private[sql] def javaToPython: JavaRDD[Array[Byte]] = { - val fieldNames: Seq[String] = this.queryExecution.analyzed.output.map(_.name) + def rowToMap(row: Row, structType: StructType): JMap[String, Any] = { + val fields = structType.fields.map(field => (field.name, field.dataType)) + val map: JMap[String, Any] = new java.util.HashMap + row.zip(fields).foreach { + case (obj, (name, dataType)) => + dataType match { + case struct: StructType => map.put(name, rowToMap(obj.asInstanceOf[Row], struct)) + case other => map.put(name, obj) + } + } + + map + } + + // TODO: Actually, the schema of a row should be represented by a StructType instead of + // a Seq[Attribute]. Once we have finished that change, we can just use rowToMap to + // construct the Map for python. + val fields: Seq[(String, DataType)] = this.queryExecution.analyzed.output.map( + field => (field.name, field.dataType)) this.mapPartitions { iter => val pickle = new Pickler iter.map { row => val map: JMap[String, Any] = new java.util.HashMap - // TODO: We place the map in an ArrayList so that the object is pickled to a List[Dict]. - // Ideally we should be able to pickle an object directly into a Python collection so we - // don't have to create an ArrayList every time. - val arr: java.util.ArrayList[Any] = new java.util.ArrayList - row.zip(fieldNames).foreach { case (obj, name) => - map.put(name, obj) + row.zip(fields).foreach { case (obj, (name, dataType)) => + dataType match { + case struct: StructType => map.put(name, rowToMap(obj.asInstanceOf[Row], struct)) + case other => map.put(name, obj) + } } - arr.add(map) - pickle.dumps(arr) - } + map + }.grouped(10).map(batched => pickle.dumps(batched.toArray)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index 656be965a8fd9..fe81721943202 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -122,4 +122,10 @@ private[sql] trait SchemaRDDLike { @Experimental def saveAsTable(tableName: String): Unit = sqlContext.executePlan(InsertIntoCreatedTable(None, tableName, logicalPlan)).toRdd + + /** Returns the output schema in the tree format. */ + def schemaString: String = queryExecution.analyzed.schemaString + + /** Prints out the schema in the tree format. 
*/ + def printSchema(): Unit = println(schemaString) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 6f7d431b9a819..ff6deeda2394d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.sql.json.JsonRDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} import org.apache.spark.sql.catalyst.types._ @@ -40,19 +41,13 @@ class JavaSQLContext(val sqlContext: SQLContext) { /** * Executes a query expressed in SQL, returning the result as a JavaSchemaRDD */ - def sql(sqlQuery: String): JavaSchemaRDD = { - val result = new JavaSchemaRDD(sqlContext, sqlContext.parseSql(sqlQuery)) - // We force query optimization to happen right away instead of letting it happen lazily like - // when using the query DSL. This is so DDL commands behave as expected. This is only - // generates the RDD lineage for DML queries, but do not perform any execution. - result.queryExecution.toRdd - result - } + def sql(sqlQuery: String): JavaSchemaRDD = + new JavaSchemaRDD(sqlContext, sqlContext.parseSql(sqlQuery)) /** * :: Experimental :: * Creates an empty parquet file with the schema of class `beanClass`, which can be registered as - * a table. This registered table can be used as the target of future insertInto` operations. + * a table. This registered table can be used as the target of future `insertInto` operations. * * {{{ * JavaSQLContext sqlCtx = new JavaSQLContext(...) @@ -62,7 +57,7 @@ class JavaSQLContext(val sqlContext: SQLContext) { * }}} * * @param beanClass A java bean class object that will be used to determine the schema of the - * parquet file. s + * parquet file. * @param path The path where the directory containing parquet metadata should be created. * Data inserted into this table will also be stored at this location. * @param allowExisting When false, an exception will be thrown if this directory already exists. @@ -100,13 +95,32 @@ class JavaSQLContext(val sqlContext: SQLContext) { new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd))) } - /** * Loads a parquet file, returning the result as a [[JavaSchemaRDD]]. */ def parquetFile(path: String): JavaSchemaRDD = - new JavaSchemaRDD(sqlContext, ParquetRelation(path)) + new JavaSchemaRDD( + sqlContext, + ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration))) + /** + * Loads a JSON file (one object per line), returning the result as a [[JavaSchemaRDD]]. + * It goes through the entire dataset once to determine the schema. + * + * @group userf + */ + def jsonFile(path: String): JavaSchemaRDD = + jsonRDD(sqlContext.sparkContext.textFile(path)) + + /** + * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a + * [[JavaSchemaRDD]]. + * It goes through the entire dataset once to determine the schema. + * + * @group userf + */ + def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = + new JavaSchemaRDD(sqlContext, JsonRDD.inferSchema(json, 1.0)) /** * Registers the given RDD as a temporary table in the catalog. 
Temporary tables exist only diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index cef294167f146..f46fa0516566f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf} import org.apache.spark.rdd.ShuffledRDD -import org.apache.spark.sql.{SQLConf, SQLContext, Row} +import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.{MutableProjection, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{NoBind, MutableProjection, RowOrdering} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair @@ -31,7 +31,7 @@ import org.apache.spark.util.MutablePair * :: DeveloperApi :: */ @DeveloperApi -case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { +case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode with NoBind { override def outputPartitioning = newPartitioning @@ -42,7 +42,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una case HashPartitioning(expressions, numPartitions) => // TODO: Eliminate redundant expressions in grouping key and value. val rdd = child.execute().mapPartitions { iter => - val hashExpressions = new MutableProjection(expressions) + val hashExpressions = new MutableProjection(expressions, child.output) val mutablePair = new MutablePair[Row, Row]() iter.map(r => mutablePair.update(hashExpressions(r), r)) } @@ -53,7 +53,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una case RangePartitioning(sortingExpressions, numPartitions) => // TODO: RangePartitioner should take an Ordering. - implicit val ordering = new RowOrdering(sortingExpressions) + implicit val ordering = new RowOrdering(sortingExpressions, child.output) val rdd = child.execute().mapPartitions { iter => val mutablePair = new MutablePair[Row, Null](null, null) @@ -82,9 +82,10 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } /** - * Ensures that the [[catalyst.plans.physical.Partitioning Partitioning]] of input data meets the - * [[catalyst.plans.physical.Distribution Distribution]] requirements for each operator by inserting - * [[Exchange]] Operators where required. + * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]] + * of input data meets the + * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for + * each operator by inserting [[Exchange]] Operators where required. */ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] { // TODO: Determine the number of partitions. 
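A usage sketch for the JSON loading APIs introduced above (SQLContext.jsonFile / jsonRDD and their JavaSQLContext counterparts), combined with the new printSchema helper; the application name, file path, and JSON fields below are hypothetical and only illustrate the intended call pattern.

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

val sc = new SparkContext("local", "JsonExample")   // hypothetical app setup
val sqlContext = new SQLContext(sc)

// One JSON object per line; the schema is inferred in a single pass over the data.
val people = sqlContext.jsonFile("examples/src/main/resources/people.json")  // hypothetical path
people.printSchema()  // prints the inferred schema in the tree format

// Schema inference also works on an in-memory RDD[String] of JSON records.
val records = sc.parallelize("""{"name":"Alice","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil)
val another = sqlContext.jsonRDD(records)
another.printSchema()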
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 2233216a6ec52..4694f25d6d630 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -154,7 +154,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.WriteToFile(path, child) => val relation = ParquetRelation.create(path, child, sparkContext.hadoopConfiguration) - InsertIntoParquetTable(relation, planLater(child), overwrite=true)(sparkContext) :: Nil + // Note: overwrite=false because otherwise the metadata we just created will be deleted + InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sparkContext) :: Nil case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) => InsertIntoParquetTable(table, planLater(child), overwrite)(sparkContext) :: Nil case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => @@ -250,9 +251,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.SetCommand(key, value) => Seq(execution.SetCommand(key, value, plan.output)(context)) - case logical.ExplainCommand(child) => - val executedPlan = context.executePlan(child).executedPlan - Seq(execution.ExplainCommand(executedPlan, plan.output)(context)) + case logical.ExplainCommand(logicalPlan) => + Seq(execution.ExplainCommand(logicalPlan, plan.output)(context)) case logical.CacheCommand(tableName, cache) => Seq(execution.CacheCommand(tableName, cache)(context)) case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 0377290af5926..acb1b0f4dc229 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -21,6 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan trait Command { /** @@ -71,20 +72,27 @@ case class SetCommand( } /** + * An explain command for users to see how a command will be executed. + * + * Note that this command takes in a logical plan, runs the optimizer on the logical plan + * (but do NOT actually execute it). + * * :: DeveloperApi :: */ @DeveloperApi case class ExplainCommand( - child: SparkPlan, output: Seq[Attribute])( + logicalPlan: LogicalPlan, output: Seq[Attribute])( @transient context: SQLContext) - extends UnaryNode with Command { + extends LeafNode with Command { - // Actually "EXPLAIN" command doesn't cause any side effect. - override protected[sql] lazy val sideEffectResult: Seq[String] = this.toString.split("\n") + // Run through the optimizer to generate the physical plan. 
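+ // Illustrative note, not part of the original patch: for "EXPLAIN <query>" this therefore
+ // returns one row containing "Physical execution plan:" followed by one row per line of the
+ // physical plan's string representation; nothing is actually executed.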
+ override protected[sql] lazy val sideEffectResult: Seq[String] = { + "Physical execution plan:" +: context.executePlan(logicalPlan).executedPlan.toString.split("\n") + } def execute(): RDD[Row] = { - val explanation = sideEffectResult.mkString("\n") - context.sparkContext.parallelize(Seq(new GenericRow(Array[Any](explanation))), 1) + val explanation = sideEffectResult.map(row => new GenericRow(Array[Any](row))) + context.sparkContext.parallelize(explanation, 1) } override def otherCopyArgs = context :: Nil @@ -113,3 +121,24 @@ case class CacheCommand(tableName: String, doCache: Boolean)(@transient context: override def output: Seq[Attribute] = Seq.empty } + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class DescribeCommand(child: SparkPlan, output: Seq[Attribute])( + @transient context: SQLContext) + extends LeafNode with Command { + + override protected[sql] lazy val sideEffectResult: Seq[(String, String, String)] = { + Seq(("# Registered as a temporary table", null, null)) ++ + child.output.map(field => (field.name, field.dataType.toString, null)) + } + + override def execute(): RDD[Row] = { + val rows = sideEffectResult.map { + case (name, dataType, comment) => new GenericRow(Array[Any](name, dataType, comment)) + } + context.sparkContext.parallelize(rows, 1) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala new file mode 100644 index 0000000000000..edf86775579d8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.json + +import scala.collection.JavaConversions._ +import scala.math.BigDecimal + +import com.fasterxml.jackson.databind.ObjectMapper + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} +import org.apache.spark.sql.Logging + +private[sql] object JsonRDD extends Logging { + + private[sql] def inferSchema( + json: RDD[String], + samplingRatio: Double = 1.0): LogicalPlan = { + require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") + val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1) + val allKeys = parseJson(schemaData).map(allKeysWithValueTypes).reduce(_ ++ _) + val baseSchema = createSchema(allKeys) + + createLogicalPlan(json, baseSchema) + } + + private def createLogicalPlan( + json: RDD[String], + baseSchema: StructType): LogicalPlan = { + val schema = nullTypeToStringType(baseSchema) + + SparkLogicalPlan(ExistingRdd(asAttributes(schema), parseJson(json).map(asRow(_, schema)))) + } + + private def createSchema(allKeys: Set[(String, DataType)]): StructType = { + // Resolve type conflicts + val resolved = allKeys.groupBy { + case (key, dataType) => key + }.map { + // Now, keys and types are organized in the format of + // key -> Set(type1, type2, ...). + case (key, typeSet) => { + val fieldName = key.substring(1, key.length - 1).split("`.`").toSeq + val dataType = typeSet.map { + case (_, dataType) => dataType + }.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2)) + + (fieldName, dataType) + } + } + + def makeStruct(values: Seq[Seq[String]], prefix: Seq[String]): StructType = { + val (topLevel, structLike) = values.partition(_.size == 1) + val topLevelFields = topLevel.filter { + name => resolved.get(prefix ++ name).get match { + case ArrayType(StructType(Nil)) => false + case ArrayType(_) => true + case struct: StructType => false + case _ => true + } + }.map { + a => StructField(a.head, resolved.get(prefix ++ a).get, nullable = true) + } + + val structFields: Seq[StructField] = structLike.groupBy(_(0)).map { + case (name, fields) => { + val nestedFields = fields.map(_.tail) + val structType = makeStruct(nestedFields, prefix :+ name) + val dataType = resolved.get(prefix :+ name).get + dataType match { + case array: ArrayType => Some(StructField(name, ArrayType(structType), nullable = true)) + case struct: StructType => Some(StructField(name, structType, nullable = true)) + // dataType is StringType means that we have resolved type conflicts involving + // primitive types and complex types. So, the type of name has been relaxed to + // StringType. Also, this field should have already been put in topLevelFields. + case StringType => None + } + } + }.flatMap(field => field).toSeq + + StructType( + (topLevelFields ++ structFields).sortBy { + case StructField(name, _, _) => name + }) + } + + makeStruct(resolved.keySet.toSeq, Nil) + } + + /** + * Returns the most general data type for two given data types. + */ + private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { + // Try and find a promotion rule that contains both types in question. 
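+    // Illustrative examples, not in the original patch (assuming the numeric precedence chain is
+    // included in HiveTypeCoercion.allPromotions): compatibleType(IntegerType, LongType) resolves
+    // to LongType via that chain; compatibleType(StringType, NullType) falls through to the match
+    // below and returns StringType; and an irreconcilable pair such as a struct versus a primitive
+    // ends up as StringType via the final case.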
+ val applicableConversion = HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p + .contains(t2)) + + // If found return the widest common type, otherwise None + val returnType = applicableConversion.map(_.filter(t => t == t1 || t == t2).last) + + if (returnType.isDefined) { + returnType.get + } else { + // t1 or t2 is a StructType, ArrayType, or an unexpected type. + (t1, t2) match { + case (other: DataType, NullType) => other + case (NullType, other: DataType) => other + case (StructType(fields1), StructType(fields2)) => { + val newFields = (fields1 ++ fields2).groupBy(field => field.name).map { + case (name, fieldTypes) => { + val dataType = fieldTypes.map(field => field.dataType).reduce( + (type1: DataType, type2: DataType) => compatibleType(type1, type2)) + StructField(name, dataType, true) + } + } + StructType(newFields.toSeq.sortBy { + case StructField(name, _, _) => name + }) + } + case (ArrayType(elementType1), ArrayType(elementType2)) => + ArrayType(compatibleType(elementType1, elementType2)) + // TODO: We should use JsonObjectStringType to mark that values of field will be + // strings and every string is a Json object. + case (_, _) => StringType + } + } + } + + private def typeOfPrimitiveValue(value: Any): DataType = { + value match { + case value: java.lang.String => StringType + case value: java.lang.Integer => IntegerType + case value: java.lang.Long => LongType + // Since we do not have a data type backed by BigInteger, + // when we see a Java BigInteger, we use DecimalType. + case value: java.math.BigInteger => DecimalType + case value: java.lang.Double => DoubleType + case value: java.math.BigDecimal => DecimalType + case value: java.lang.Boolean => BooleanType + case null => NullType + // Unexpected data type. + case _ => StringType + } + } + + /** + * Returns the element type of an JSON array. We go through all elements of this array + * to detect any possible type conflict. We use [[compatibleType]] to resolve + * type conflicts. Right now, when the element of an array is another array, we + * treat the element as String. + */ + private def typeOfArray(l: Seq[Any]): ArrayType = { + val elements = l.flatMap(v => Option(v)) + if (elements.isEmpty) { + // If this JSON array is empty, we use NullType as a placeholder. + // If this array is not empty in other JSON objects, we can resolve + // the type after we have passed through all JSON objects. + ArrayType(NullType) + } else { + val elementType = elements.map { + e => e match { + case map: Map[_, _] => StructType(Nil) + // We have an array of arrays. If those element arrays do not have the same + // element types, we will return ArrayType[StringType]. + case seq: Seq[_] => typeOfArray(seq) + case value => typeOfPrimitiveValue(value) + } + }.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2)) + + ArrayType(elementType) + } + } + + /** + * Figures out all key names and data types of values from a parsed JSON object + * (in the format of Map[Stirng, Any]). When the value of a key is an JSON object, we + * only use a placeholder (StructType(Nil)) to mark that it should be a struct + * instead of getting all fields of this struct because a field does not appear + * in this JSON object can appear in other JSON objects. + */ + private def allKeysWithValueTypes(m: Map[String, Any]): Set[(String, DataType)] = { + m.map{ + // Quote the key with backticks to handle cases which have dots + // in the field name. 
+      case (key, dataType) => (s"`$key`", dataType)
+    }.flatMap {
+      case (key: String, struct: Map[String, Any]) => {
+        // The value associated with the key is a JSON object.
+        allKeysWithValueTypes(struct).map {
+          case (k, dataType) => (s"$key.$k", dataType)
+        } ++ Set((key, StructType(Nil)))
+      }
+      case (key: String, array: List[Any]) => {
+        // The value associated with the key is an array.
+        typeOfArray(array) match {
+          case ArrayType(StructType(Nil)) => {
+            // The elements of this array are structs.
+            array.asInstanceOf[List[Map[String, Any]]].flatMap {
+              element => allKeysWithValueTypes(element)
+            }.map {
+              case (k, dataType) => (s"$key.$k", dataType)
+            } :+ (key, ArrayType(StructType(Nil)))
+          }
+          case ArrayType(elementType) => (key, ArrayType(elementType)) :: Nil
+        }
+      }
+      case (key: String, value) => (key, typeOfPrimitiveValue(value)) :: Nil
+    }.toSet
+  }
+
+  /**
+   * Converts a Java Map/List to a Scala Map/List.
+   * We do not use Jackson's Scala module here because
+   * DefaultScalaModule in jackson-module-scala will make
+   * the parsing very slow.
+   */
+  private def scalafy(obj: Any): Any = obj match {
+    case map: java.util.Map[String, Object] =>
+      // .map(identity) is used as a workaround of non-serializable Map
+      // generated by .mapValues.
+      // This issue is documented at https://issues.scala-lang.org/browse/SI-7005
+      map.toMap.mapValues(scalafy).map(identity)
+    case list: java.util.List[Object] =>
+      list.toList.map(scalafy)
+    case atom => atom
+  }
+
+  private def parseJson(json: RDD[String]): RDD[Map[String, Any]] = {
+    // According to [Jackson-72: https://jira.codehaus.org/browse/JACKSON-72],
+    // ObjectMapper will not return BigDecimal when
+    // "DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS" is disabled
+    // (see NumberDeserializer.deserialize for the logic).
+    // But, we do not want to enable this feature because it will use BigDecimal
+    // for every float number, which will be slow.
+    // So, right now, we will have Infinity for those BigDecimal numbers.
+    // TODO: Support BigDecimal.
+    json.mapPartitions(iter => {
+      // When there is a key appearing multiple times (a duplicate key),
+      // the ObjectMapper will take the last value associated with this duplicate key.
+      // For example: for {"key": 1, "key":2}, we will get "key"->2.
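// [Editor's sketch] A minimal, standalone demonstration of the duplicate-key behaviour
// described in the comment above, assuming only that jackson-databind (the dependency this
// file already uses) is on the classpath:
import com.fasterxml.jackson.databind.ObjectMapper

object DuplicateKeySketch {
  def main(args: Array[String]): Unit = {
    val mapper = new ObjectMapper()
    val parsed = mapper.readValue("""{"key": 1, "key": 2}""", classOf[java.util.Map[String, Any]])
    // Jackson keeps the last value seen for a duplicate key, so this prints 2.
    println(parsed.get("key"))
  }
}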
+ val mapper = new ObjectMapper() + iter.map(record => mapper.readValue(record, classOf[java.util.Map[String, Any]])) + }).map(scalafy).map(_.asInstanceOf[Map[String, Any]]) + } + + private def toLong(value: Any): Long = { + value match { + case value: java.lang.Integer => value.asInstanceOf[Int].toLong + case value: java.lang.Long => value.asInstanceOf[Long] + } + } + + private def toDouble(value: Any): Double = { + value match { + case value: java.lang.Integer => value.asInstanceOf[Int].toDouble + case value: java.lang.Long => value.asInstanceOf[Long].toDouble + case value: java.lang.Double => value.asInstanceOf[Double] + } + } + + private def toDecimal(value: Any): BigDecimal = { + value match { + case value: java.lang.Integer => BigDecimal(value) + case value: java.lang.Long => BigDecimal(value) + case value: java.math.BigInteger => BigDecimal(value) + case value: java.lang.Double => BigDecimal(value) + case value: java.math.BigDecimal => BigDecimal(value) + } + } + + private def toJsonArrayString(seq: Seq[Any]): String = { + val builder = new StringBuilder + builder.append("[") + var count = 0 + seq.foreach { + element => + if (count > 0) builder.append(",") + count += 1 + builder.append(toString(element)) + } + builder.append("]") + + builder.toString() + } + + private def toJsonObjectString(map: Map[String, Any]): String = { + val builder = new StringBuilder + builder.append("{") + var count = 0 + map.foreach { + case (key, value) => + if (count > 0) builder.append(",") + count += 1 + builder.append(s"""\"${key}\":${toString(value)}""") + } + builder.append("}") + + builder.toString() + } + + private def toString(value: Any): String = { + value match { + case value: Map[String, Any] => toJsonObjectString(value) + case value: Seq[Any] => toJsonArrayString(value) + case value => Option(value).map(_.toString).orNull + } + } + + private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any ={ + if (value == null) { + null + } else { + desiredType match { + case ArrayType(elementType) => + value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) + case StringType => toString(value) + case IntegerType => value.asInstanceOf[IntegerType.JvmType] + case LongType => toLong(value) + case DoubleType => toDouble(value) + case DecimalType => toDecimal(value) + case BooleanType => value.asInstanceOf[BooleanType.JvmType] + case NullType => null + } + } + } + + private def asRow(json: Map[String,Any], schema: StructType): Row = { + val row = new GenericMutableRow(schema.fields.length) + schema.fields.zipWithIndex.foreach { + // StructType + case (StructField(name, fields: StructType, _), i) => + row.update(i, json.get(name).flatMap(v => Option(v)).map( + v => asRow(v.asInstanceOf[Map[String, Any]], fields)).orNull) + + // ArrayType(StructType) + case (StructField(name, ArrayType(structType: StructType), _), i) => + row.update(i, + json.get(name).flatMap(v => Option(v)).map( + v => v.asInstanceOf[Seq[Any]].map( + e => asRow(e.asInstanceOf[Map[String, Any]], structType))).orNull) + + // Other cases + case (StructField(name, dataType, _), i) => + row.update(i, json.get(name).flatMap(v => Option(v)).map( + enforceCorrectType(_, dataType)).getOrElse(null)) + } + + row + } + + private def nullTypeToStringType(struct: StructType): StructType = { + val fields = struct.fields.map { + case StructField(fieldName, dataType, nullable) => { + val newType = dataType match { + case NullType => StringType + case ArrayType(NullType) => ArrayType(StringType) + case struct: StructType => 
nullTypeToStringType(struct) + case other: DataType => other + } + StructField(fieldName, newType, nullable) + } + } + + StructType(fields) + } + + private def asAttributes(struct: StructType): Seq[AttributeReference] = { + struct.fields.map(f => AttributeReference(f.name, f.dataType, nullable = true)()) + } + + private def asStruct(attributes: Seq[AttributeReference]): StructType = { + val fields = attributes.map { + case AttributeReference(name, dataType, nullable) => StructField(name, dataType, nullable) + } + + StructType(fields) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala new file mode 100644 index 0000000000000..889a408e3c393 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -0,0 +1,667 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} + +import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} +import parquet.schema.MessageType + +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.expressions.{GenericRow, Row, Attribute} +import org.apache.spark.sql.parquet.CatalystConverter.FieldType + +/** + * Collection of converters of Parquet types (group and primitive types) that + * model arrays and maps. The conversions are partly based on the AvroParquet + * converters that are part of Parquet in order to be able to process these + * types. + * + * There are several types of converters: + *
+ * <ul>
+ *   <li>[[org.apache.spark.sql.parquet.CatalystPrimitiveConverter]] for primitive
+ *   (numeric, boolean and String) types</li>
+ *   <li>[[org.apache.spark.sql.parquet.CatalystNativeArrayConverter]] for arrays
+ *   of native JVM element types; note: currently null values are not supported!</li>
+ *   <li>[[org.apache.spark.sql.parquet.CatalystArrayConverter]] for arrays of
+ *   arbitrary element types (including nested element types); note: currently
+ *   null values are not supported!</li>
+ *   <li>[[org.apache.spark.sql.parquet.CatalystStructConverter]] for structs</li>
+ *   <li>[[org.apache.spark.sql.parquet.CatalystMapConverter]] for maps; note:
+ *   currently null values are not supported!</li>
+ *   <li>[[org.apache.spark.sql.parquet.CatalystPrimitiveRowConverter]] for rows
+ *   of only primitive element types</li>
+ *   <li>[[org.apache.spark.sql.parquet.CatalystGroupConverter]] for other nested
+ *   records, including the top-level row record</li>
+ * </ul>
+ */ + +private[sql] object CatalystConverter { + // The type internally used for fields + type FieldType = StructField + + // This is mostly Parquet convention (see, e.g., `ConversionPatterns`). + // Note that "array" for the array elements is chosen by ParquetAvro. + // Using a different value will result in Parquet silently dropping columns. + val ARRAY_ELEMENTS_SCHEMA_NAME = "array" + val MAP_KEY_SCHEMA_NAME = "key" + val MAP_VALUE_SCHEMA_NAME = "value" + val MAP_SCHEMA_NAME = "map" + + // TODO: consider using Array[T] for arrays to avoid boxing of primitive types + type ArrayScalaType[T] = Seq[T] + type StructScalaType[T] = Seq[T] + type MapScalaType[K, V] = Map[K, V] + + protected[parquet] def createConverter( + field: FieldType, + fieldIndex: Int, + parent: CatalystConverter): Converter = { + val fieldType: DataType = field.dataType + fieldType match { + // For native JVM types we use a converter with native arrays + case ArrayType(elementType: NativeType) => { + new CatalystNativeArrayConverter(elementType, fieldIndex, parent) + } + // This is for other types of arrays, including those with nested fields + case ArrayType(elementType: DataType) => { + new CatalystArrayConverter(elementType, fieldIndex, parent) + } + case StructType(fields: Seq[StructField]) => { + new CatalystStructConverter(fields.toArray, fieldIndex, parent) + } + case MapType(keyType: DataType, valueType: DataType) => { + new CatalystMapConverter( + Array( + new FieldType(MAP_KEY_SCHEMA_NAME, keyType, false), + new FieldType(MAP_VALUE_SCHEMA_NAME, valueType, true)), + fieldIndex, + parent) + } + // Strings, Shorts and Bytes do not have a corresponding type in Parquet + // so we need to treat them separately + case StringType => { + new CatalystPrimitiveConverter(parent, fieldIndex) { + override def addBinary(value: Binary): Unit = + parent.updateString(fieldIndex, value) + } + } + case ShortType => { + new CatalystPrimitiveConverter(parent, fieldIndex) { + override def addInt(value: Int): Unit = + parent.updateShort(fieldIndex, value.asInstanceOf[ShortType.JvmType]) + } + } + case ByteType => { + new CatalystPrimitiveConverter(parent, fieldIndex) { + override def addInt(value: Int): Unit = + parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.JvmType]) + } + } + // All other primitive types use the default converter + case ctype: NativeType => { // note: need the type tag here! + new CatalystPrimitiveConverter(parent, fieldIndex) + } + case _ => throw new RuntimeException( + s"unable to convert datatype ${field.dataType.toString} in CatalystConverter") + } + } + + protected[parquet] def createRootConverter( + parquetSchema: MessageType, + attributes: Seq[Attribute]): CatalystConverter = { + // For non-nested types we use the optimized Row converter + if (attributes.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType))) { + new CatalystPrimitiveRowConverter(attributes.toArray) + } else { + new CatalystGroupConverter(attributes.toArray) + } + } +} + +private[parquet] abstract class CatalystConverter extends GroupConverter { + /** + * The number of fields this group has + */ + protected[parquet] val size: Int + + /** + * The index of this converter in the parent + */ + protected[parquet] val index: Int + + /** + * The parent converter + */ + protected[parquet] val parent: CatalystConverter + + /** + * Called by child converters to update their value in its parent (this). + * Note that if possible the more specific update methods below should be used + * to avoid auto-boxing of native JVM types. 
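// [Editor's sketch] All of the typed update* variants funnel into updateField(Int, Any) by
// default, so a converter only overrides them when it can store values unboxed (as the
// primitive row converter further down does). A stripped-down, stand-alone rendering of
// that pattern (the names here are illustrative, not Spark's):
object UpdateMethodSketch {
  abstract class Sink {
    def updateField(fieldIndex: Int, value: Any): Unit
    def updateInt(fieldIndex: Int, value: Int): Unit = updateField(fieldIndex, value)       // boxes
    def updateDouble(fieldIndex: Int, value: Double): Unit = updateField(fieldIndex, value) // boxes
  }
  class BoxingSink extends Sink {            // generic converter: every value is boxed
    val slots = new Array[Any](4)
    def updateField(fieldIndex: Int, value: Any): Unit = slots(fieldIndex) = value
  }
  class PrimitiveSink extends Sink {         // specialized converter: ints stay unboxed
    val ints = new Array[Int](4)
    def updateField(fieldIndex: Int, value: Any): Unit =
      throw new UnsupportedOperationException("use the typed setters")
    override def updateInt(fieldIndex: Int, value: Int): Unit = ints(fieldIndex) = value
  }
  def main(args: Array[String]): Unit = {
    val sink = new PrimitiveSink
    sink.updateInt(0, 42)
    println(sink.ints(0)) // 42
  }
}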
+ * + * @param fieldIndex + * @param value + */ + protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit + + protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = + updateField(fieldIndex, value) + + protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = + updateField(fieldIndex, value) + + protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = + updateField(fieldIndex, value) + + protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit = + updateField(fieldIndex, value) + + protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit = + updateField(fieldIndex, value) + + protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit = + updateField(fieldIndex, value) + + protected[parquet] def updateFloat(fieldIndex: Int, value: Float): Unit = + updateField(fieldIndex, value) + + protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = + updateField(fieldIndex, value.getBytes) + + protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = + updateField(fieldIndex, value.toStringUsingUTF8) + + protected[parquet] def isRootConverter: Boolean = parent == null + + protected[parquet] def clearBuffer(): Unit + + /** + * Should only be called in the root (group) converter! + * + * @return + */ + def getCurrentRecord: Row = throw new UnsupportedOperationException +} + +/** + * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record + * to a [[org.apache.spark.sql.catalyst.expressions.Row]] object. + * + * @param schema The corresponding Catalyst schema in the form of a list of attributes. + */ +private[parquet] class CatalystGroupConverter( + protected[parquet] val schema: Array[FieldType], + protected[parquet] val index: Int, + protected[parquet] val parent: CatalystConverter, + protected[parquet] var current: ArrayBuffer[Any], + protected[parquet] var buffer: ArrayBuffer[Row]) + extends CatalystConverter { + + def this(schema: Array[FieldType], index: Int, parent: CatalystConverter) = + this( + schema, + index, + parent, + current=null, + buffer=new ArrayBuffer[Row]( + CatalystArrayConverter.INITIAL_ARRAY_SIZE)) + + /** + * This constructor is used for the root converter only! + */ + def this(attributes: Array[Attribute]) = + this(attributes.map(a => new FieldType(a.name, a.dataType, a.nullable)), 0, null) + + protected [parquet] val converters: Array[Converter] = + schema.map(field => + CatalystConverter.createConverter(field, schema.indexOf(field), this)) + .toArray + + override val size = schema.size + + override def getCurrentRecord: Row = { + assert(isRootConverter, "getCurrentRecord should only be called in root group converter!") + // TODO: use iterators if possible + // Note: this will ever only be called in the root converter when the record has been + // fully processed. Therefore it will be difficult to use mutable rows instead, since + // any non-root converter never would be sure when it would be safe to re-use the buffer. 
+ new GenericRow(current.toArray) + } + + override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) + + // for child converters to update upstream values + override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { + current.update(fieldIndex, value) + } + + override protected[parquet] def clearBuffer(): Unit = buffer.clear() + + override def start(): Unit = { + current = ArrayBuffer.fill(size)(null) + converters.foreach { + converter => if (!converter.isPrimitive) { + converter.asInstanceOf[CatalystConverter].clearBuffer + } + } + } + + override def end(): Unit = { + if (!isRootConverter) { + assert(current!=null) // there should be no empty groups + buffer.append(new GenericRow(current.toArray)) + parent.updateField(index, new GenericRow(buffer.toArray.asInstanceOf[Array[Any]])) + } + } +} + +/** + * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record + * to a [[org.apache.spark.sql.catalyst.expressions.Row]] object. Note that his + * converter is optimized for rows of primitive types (non-nested records). + */ +private[parquet] class CatalystPrimitiveRowConverter( + protected[parquet] val schema: Array[FieldType], + protected[parquet] var current: ParquetRelation.RowType) + extends CatalystConverter { + + // This constructor is used for the root converter only + def this(attributes: Array[Attribute]) = + this( + attributes.map(a => new FieldType(a.name, a.dataType, a.nullable)), + new ParquetRelation.RowType(attributes.length)) + + protected [parquet] val converters: Array[Converter] = + schema.map(field => + CatalystConverter.createConverter(field, schema.indexOf(field), this)) + .toArray + + override val size = schema.size + + override val index = 0 + + override val parent = null + + // Should be only called in root group converter! 
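// [Editor's sketch] This mutable-row converter is the fast path that createRootConverter
// (earlier in this file) selects only when every attribute is primitive; any nested type
// falls back to CatalystGroupConverter. The decision itself, in isolation, with a toy type
// hierarchy and a hypothetical isPrimitive standing in for ParquetTypesConverter.isPrimitiveType:
object ConverterChoiceSketch {
  sealed trait DType
  case object IntT extends DType
  case object StringT extends DType
  final case class StructT(fields: Seq[DType]) extends DType
  def isPrimitive(t: DType): Boolean = t match {
    case StructT(_) => false
    case _          => true
  }
  def chooseRootConverter(schema: Seq[DType]): String =
    if (schema.forall(isPrimitive)) "CatalystPrimitiveRowConverter" else "CatalystGroupConverter"
  def main(args: Array[String]): Unit = {
    println(chooseRootConverter(Seq(IntT, StringT)))             // CatalystPrimitiveRowConverter
    println(chooseRootConverter(Seq(IntT, StructT(Seq(IntT)))))  // CatalystGroupConverter
  }
}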
+ override def getCurrentRecord: ParquetRelation.RowType = current + + override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) + + // for child converters to update upstream values + override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { + throw new UnsupportedOperationException // child converters should use the + // specific update methods below + } + + override protected[parquet] def clearBuffer(): Unit = {} + + override def start(): Unit = { + var i = 0 + while (i < size) { + current.setNullAt(i) + i = i + 1 + } + } + + override def end(): Unit = {} + + // Overriden here to avoid auto-boxing for primitive types + override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = + current.setBoolean(fieldIndex, value) + + override protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = + current.setInt(fieldIndex, value) + + override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = + current.setLong(fieldIndex, value) + + override protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit = + current.setShort(fieldIndex, value) + + override protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit = + current.setByte(fieldIndex, value) + + override protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit = + current.setDouble(fieldIndex, value) + + override protected[parquet] def updateFloat(fieldIndex: Int, value: Float): Unit = + current.setFloat(fieldIndex, value) + + override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = + current.update(fieldIndex, value.getBytes) + + override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = + current.setString(fieldIndex, value.toStringUsingUTF8) +} + +/** + * A `parquet.io.api.PrimitiveConverter` that converts Parquet types to Catalyst types. + * + * @param parent The parent group converter. + * @param fieldIndex The index inside the record. + */ +private[parquet] class CatalystPrimitiveConverter( + parent: CatalystConverter, + fieldIndex: Int) extends PrimitiveConverter { + override def addBinary(value: Binary): Unit = + parent.updateBinary(fieldIndex, value) + + override def addBoolean(value: Boolean): Unit = + parent.updateBoolean(fieldIndex, value) + + override def addDouble(value: Double): Unit = + parent.updateDouble(fieldIndex, value) + + override def addFloat(value: Float): Unit = + parent.updateFloat(fieldIndex, value) + + override def addInt(value: Int): Unit = + parent.updateInt(fieldIndex, value) + + override def addLong(value: Long): Unit = + parent.updateLong(fieldIndex, value) +} + +object CatalystArrayConverter { + val INITIAL_ARRAY_SIZE = 20 +} + +/** + * A `parquet.io.api.GroupConverter` that converts a single-element groups that + * match the characteristics of an array (see + * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an + * [[org.apache.spark.sql.catalyst.types.ArrayType]]. 
+ * + * @param elementType The type of the array elements (complex or primitive) + * @param index The position of this (array) field inside its parent converter + * @param parent The parent converter + * @param buffer A data buffer + */ +private[parquet] class CatalystArrayConverter( + val elementType: DataType, + val index: Int, + protected[parquet] val parent: CatalystConverter, + protected[parquet] var buffer: Buffer[Any]) + extends CatalystConverter { + + def this(elementType: DataType, index: Int, parent: CatalystConverter) = + this( + elementType, + index, + parent, + new ArrayBuffer[Any](CatalystArrayConverter.INITIAL_ARRAY_SIZE)) + + protected[parquet] val converter: Converter = CatalystConverter.createConverter( + new CatalystConverter.FieldType( + CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + elementType, + false), + fieldIndex=0, + parent=this) + + override def getConverter(fieldIndex: Int): Converter = converter + + // arrays have only one (repeated) field, which is its elements + override val size = 1 + + override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { + // fieldIndex is ignored (assumed to be zero but not checked) + if(value == null) { + throw new IllegalArgumentException("Null values inside Parquet arrays are not supported!") + } + buffer += value + } + + override protected[parquet] def clearBuffer(): Unit = { + buffer.clear() + } + + override def start(): Unit = { + if (!converter.isPrimitive) { + converter.asInstanceOf[CatalystConverter].clearBuffer + } + } + + override def end(): Unit = { + assert(parent != null) + // here we need to make sure to use ArrayScalaType + parent.updateField(index, buffer.toArray.toSeq) + clearBuffer() + } +} + +/** + * A `parquet.io.api.GroupConverter` that converts a single-element groups that + * match the characteristics of an array (see + * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an + * [[org.apache.spark.sql.catalyst.types.ArrayType]]. 
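// [Editor's sketch] The native-array converter described here accumulates values into a
// primitive array and doubles its capacity when full (see checkGrowBuffer further down).
// The same technique in a dependency-free form, specialized to Int:
object NativeBufferSketch {
  private var capacity = 4
  private var elements = 0
  private var buffer = new Array[Int](capacity)
  def add(value: Int): Unit = {
    if (elements >= capacity) {            // grow by doubling, copying the old contents
      val grown = new Array[Int](capacity * 2)
      Array.copy(buffer, 0, grown, 0, capacity)
      buffer = grown
      capacity *= 2
    }
    buffer(elements) = value
    elements += 1
  }
  def result: Seq[Int] = buffer.slice(0, elements).toSeq
  def main(args: Array[String]): Unit = {
    (1 to 10).foreach(add)
    println(result) // the ten values, accumulated without boxing
  }
}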
+ * + * @param elementType The type of the array elements (native) + * @param index The position of this (array) field inside its parent converter + * @param parent The parent converter + * @param capacity The (initial) capacity of the buffer + */ +private[parquet] class CatalystNativeArrayConverter( + val elementType: NativeType, + val index: Int, + protected[parquet] val parent: CatalystConverter, + protected[parquet] var capacity: Int = CatalystArrayConverter.INITIAL_ARRAY_SIZE) + extends CatalystConverter { + + type NativeType = elementType.JvmType + + private var buffer: Array[NativeType] = elementType.classTag.newArray(capacity) + + private var elements: Int = 0 + + protected[parquet] val converter: Converter = CatalystConverter.createConverter( + new CatalystConverter.FieldType( + CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + elementType, + false), + fieldIndex=0, + parent=this) + + override def getConverter(fieldIndex: Int): Converter = converter + + // arrays have only one (repeated) field, which is its elements + override val size = 1 + + override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = + throw new UnsupportedOperationException + + // Overriden here to avoid auto-boxing for primitive types + override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = { + checkGrowBuffer() + buffer(elements) = value.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = { + checkGrowBuffer() + buffer(elements) = value.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit = { + checkGrowBuffer() + buffer(elements) = value.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit = { + checkGrowBuffer() + buffer(elements) = value.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = { + checkGrowBuffer() + buffer(elements) = value.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit = { + checkGrowBuffer() + buffer(elements) = value.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateFloat(fieldIndex: Int, value: Float): Unit = { + checkGrowBuffer() + buffer(elements) = value.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = { + checkGrowBuffer() + buffer(elements) = value.getBytes.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = { + checkGrowBuffer() + buffer(elements) = value.toStringUsingUTF8.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def clearBuffer(): Unit = { + elements = 0 + } + + override def start(): Unit = {} + + override def end(): Unit = { + assert(parent != null) + // here we need to make sure to use ArrayScalaType + parent.updateField( + index, + buffer.slice(0, elements).toSeq) + clearBuffer() + } + + private def checkGrowBuffer(): Unit = { + if (elements >= capacity) { + val newCapacity = 2 * capacity + val tmp: Array[NativeType] = elementType.classTag.newArray(newCapacity) + Array.copy(buffer, 0, tmp, 0, capacity) + buffer = tmp + capacity = newCapacity + } + } +} + +/** + * This converter is for multi-element groups of 
primitive or complex types + * that have repetition level optional or required (so struct fields). + * + * @param schema The corresponding Catalyst schema in the form of a list of + * attributes. + * @param index + * @param parent + */ +private[parquet] class CatalystStructConverter( + override protected[parquet] val schema: Array[FieldType], + override protected[parquet] val index: Int, + override protected[parquet] val parent: CatalystConverter) + extends CatalystGroupConverter(schema, index, parent) { + + override protected[parquet] def clearBuffer(): Unit = {} + + // TODO: think about reusing the buffer + override def end(): Unit = { + assert(!isRootConverter) + // here we need to make sure to use StructScalaType + // Note: we need to actually make a copy of the array since we + // may be in a nested field + parent.updateField(index, new GenericRow(current.toArray)) + } +} + +/** + * A `parquet.io.api.GroupConverter` that converts two-element groups that + * match the characteristics of a map (see + * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an + * [[org.apache.spark.sql.catalyst.types.MapType]]. + * + * @param schema + * @param index + * @param parent + */ +private[parquet] class CatalystMapConverter( + protected[parquet] val schema: Array[FieldType], + override protected[parquet] val index: Int, + override protected[parquet] val parent: CatalystConverter) + extends CatalystConverter { + + private val map = new HashMap[Any, Any]() + + private val keyValueConverter = new CatalystConverter { + private var currentKey: Any = null + private var currentValue: Any = null + val keyConverter = CatalystConverter.createConverter(schema(0), 0, this) + val valueConverter = CatalystConverter.createConverter(schema(1), 1, this) + + override def getConverter(fieldIndex: Int): Converter = { + if (fieldIndex == 0) keyConverter else valueConverter + } + + override def end(): Unit = CatalystMapConverter.this.map += currentKey -> currentValue + + override def start(): Unit = { + currentKey = null + currentValue = null + } + + override protected[parquet] val size: Int = 2 + override protected[parquet] val index: Int = 0 + override protected[parquet] val parent: CatalystConverter = CatalystMapConverter.this + + override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { + fieldIndex match { + case 0 => + currentKey = value + case 1 => + currentValue = value + case _ => + new RuntimePermission(s"trying to update Map with fieldIndex $fieldIndex") + } + } + + override protected[parquet] def clearBuffer(): Unit = {} + } + + override protected[parquet] val size: Int = 1 + + override protected[parquet] def clearBuffer(): Unit = {} + + override def start(): Unit = { + map.clear() + } + + override def end(): Unit = { + // here we need to make sure to use MapScalaType + parent.updateField(index, map.toMap) + } + + override def getConverter(fieldIndex: Int): Converter = keyValueConverter + + override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = + throw new UnsupportedOperationException +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index 052b0a9196717..cc575bedd8fcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -205,9 +205,9 @@ object ParquetFilters { Some(new AndFilter(leftFilter.get, rightFilter.get)) } } - 
case p @ Equals(left: Literal, right: NamedExpression) if !right.nullable => + case p @ EqualTo(left: Literal, right: NamedExpression) if !right.nullable => Some(createEqualityFilter(right.name, left, p)) - case p @ Equals(left: NamedExpression, right: Literal) if !left.nullable => + case p @ EqualTo(left: NamedExpression, right: Literal) if !left.nullable => Some(createEqualityFilter(left.name, right, p)) case p @ LessThan(left: Literal, right: NamedExpression) if !right.nullable => Some(createLessThanFilter(right.name, left, p)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 32813a66de3c3..96c131a7f8af1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -20,25 +20,16 @@ package org.apache.spark.sql.parquet import java.io.IOException import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.permission.FsAction -import org.apache.hadoop.mapreduce.Job -import parquet.hadoop.util.ContextUtil -import parquet.hadoop.{ParquetOutputFormat, Footer, ParquetFileWriter, ParquetFileReader} -import parquet.hadoop.metadata.{CompressionCodecName, FileMetaData, ParquetMetadata} -import parquet.io.api.{Binary, RecordConsumer} -import parquet.schema.{Type => ParquetType, PrimitiveType => ParquetPrimitiveType, MessageType, MessageTypeParser} -import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} -import parquet.schema.Type.Repetition +import parquet.hadoop.ParquetOutputFormat +import parquet.hadoop.metadata.CompressionCodecName +import parquet.schema.MessageType import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} -import org.apache.spark.sql.catalyst.types._ - -// Implicits -import scala.collection.JavaConversions._ /** * Relation that consists of data stored in a Parquet columnar format. @@ -52,21 +43,20 @@ import scala.collection.JavaConversions._ * * @param path The path to the Parquet file. 
*/ -private[sql] case class ParquetRelation(val path: String) - extends LeafNode with MultiInstanceRelation { +private[sql] case class ParquetRelation( + val path: String, + @transient val conf: Option[Configuration] = None) extends LeafNode with MultiInstanceRelation { self: Product => /** Schema derived from ParquetFile */ def parquetSchema: MessageType = ParquetTypesConverter - .readMetaData(new Path(path)) + .readMetaData(new Path(path), conf) .getFileMetaData .getSchema /** Attributes */ - override val output = - ParquetTypesConverter - .convertToAttributes(parquetSchema) + override val output = ParquetTypesConverter.readSchemaFromFile(new Path(path), conf) override def newInstance = ParquetRelation(path).asInstanceOf[this.type] @@ -141,7 +131,9 @@ private[sql] object ParquetRelation { } ParquetRelation.enableLogForwarding() ParquetTypesConverter.writeMetaData(attributes, path, conf) - new ParquetRelation(path.toString) + new ParquetRelation(path.toString, Some(conf)) { + override val output = attributes + } } private def checkPath(pathStr: String, allowExisting: Boolean, conf: Configuration): Path = { @@ -170,151 +162,3 @@ private[sql] object ParquetRelation { path } } - -private[parquet] object ParquetTypesConverter { - def toDataType(parquetType : ParquetPrimitiveTypeName): DataType = parquetType match { - // for now map binary to string type - // TODO: figure out how Parquet uses strings or why we can't use them in a MessageType schema - case ParquetPrimitiveTypeName.BINARY => StringType - case ParquetPrimitiveTypeName.BOOLEAN => BooleanType - case ParquetPrimitiveTypeName.DOUBLE => DoubleType - case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY => ArrayType(ByteType) - case ParquetPrimitiveTypeName.FLOAT => FloatType - case ParquetPrimitiveTypeName.INT32 => IntegerType - case ParquetPrimitiveTypeName.INT64 => LongType - case ParquetPrimitiveTypeName.INT96 => - // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? 
- sys.error("Warning: potential loss of precision: converting INT96 to long") - LongType - case _ => sys.error( - s"Unsupported parquet datatype $parquetType") - } - - def fromDataType(ctype: DataType): ParquetPrimitiveTypeName = ctype match { - case StringType => ParquetPrimitiveTypeName.BINARY - case BooleanType => ParquetPrimitiveTypeName.BOOLEAN - case DoubleType => ParquetPrimitiveTypeName.DOUBLE - case ArrayType(ByteType) => ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY - case FloatType => ParquetPrimitiveTypeName.FLOAT - case IntegerType => ParquetPrimitiveTypeName.INT32 - case LongType => ParquetPrimitiveTypeName.INT64 - case _ => sys.error(s"Unsupported datatype $ctype") - } - - def consumeType(consumer: RecordConsumer, ctype: DataType, record: Row, index: Int): Unit = { - ctype match { - case StringType => consumer.addBinary( - Binary.fromByteArray( - record(index).asInstanceOf[String].getBytes("utf-8") - ) - ) - case IntegerType => consumer.addInteger(record.getInt(index)) - case LongType => consumer.addLong(record.getLong(index)) - case DoubleType => consumer.addDouble(record.getDouble(index)) - case FloatType => consumer.addFloat(record.getFloat(index)) - case BooleanType => consumer.addBoolean(record.getBoolean(index)) - case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") - } - } - - def getSchema(schemaString : String) : MessageType = - MessageTypeParser.parseMessageType(schemaString) - - def convertToAttributes(parquetSchema: MessageType) : Seq[Attribute] = { - parquetSchema.getColumns.map { - case (desc) => - val ctype = toDataType(desc.getType) - val name: String = desc.getPath.mkString(".") - new AttributeReference(name, ctype, false)() - } - } - - // TODO: allow nesting? - def convertFromAttributes(attributes: Seq[Attribute]): MessageType = { - val fields: Seq[ParquetType] = attributes.map { - a => new ParquetPrimitiveType(Repetition.OPTIONAL, fromDataType(a.dataType), a.name) - } - new MessageType("root", fields) - } - - def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration) { - if (origPath == null) { - throw new IllegalArgumentException("Unable to write Parquet metadata: path is null") - } - val fs = origPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException( - s"Unable to write Parquet metadata: path $origPath is incorrectly formatted") - } - val path = origPath.makeQualified(fs) - if (fs.exists(path) && !fs.getFileStatus(path).isDir) { - throw new IllegalArgumentException(s"Expected to write to directory $path but found file") - } - val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) - if (fs.exists(metadataPath)) { - try { - fs.delete(metadataPath, true) - } catch { - case e: IOException => - throw new IOException(s"Unable to delete previous PARQUET_METADATA_FILE at $metadataPath") - } - } - val extraMetadata = new java.util.HashMap[String, String]() - extraMetadata.put("path", path.toString) - // TODO: add extra data, e.g., table name, date, etc.? - - val parquetSchema: MessageType = - ParquetTypesConverter.convertFromAttributes(attributes) - val metaData: FileMetaData = new FileMetaData( - parquetSchema, - extraMetadata, - "Spark") - - ParquetRelation.enableLogForwarding() - ParquetFileWriter.writeMetadataFile( - conf, - path, - new Footer(path, new ParquetMetadata(metaData, Nil)) :: Nil) - } - - /** - * Try to read Parquet metadata at the given Path. We first see if there is a summary file - * in the parent directory. If so, this is used. 
Else we read the actual footer at the given - * location. - * @param origPath The path at which we expect one (or more) Parquet files. - * @return The `ParquetMetadata` containing among other things the schema. - */ - def readMetaData(origPath: Path): ParquetMetadata = { - if (origPath == null) { - throw new IllegalArgumentException("Unable to read Parquet metadata: path is null") - } - val job = new Job() - // TODO: since this is called from ParquetRelation (LogicalPlan) we don't have access - // to SparkContext's hadoopConfig; in principle the default FileSystem may be different(?!) - val conf = ContextUtil.getConfiguration(job) - val fs: FileSystem = origPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException(s"Incorrectly formatted Parquet metadata path $origPath") - } - val path = origPath.makeQualified(fs) - if (!fs.getFileStatus(path).isDir) { - throw new IllegalArgumentException( - s"Expected $path for be a directory with Parquet files/metadata") - } - ParquetRelation.enableLogForwarding() - val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) - // if this is a new table that was just created we will find only the metadata file - if (fs.exists(metadataPath) && fs.isFile(metadataPath)) { - ParquetFileReader.readFooter(conf, metadataPath) - } else { - // there may be one or more Parquet files in the given directory - val footers = ParquetFileReader.readFooters(conf, fs.getFileStatus(path)) - // TODO: for now we assume that all footers (if there is more than one) have identical - // metadata; we may want to add a check here at some point - if (footers.size() == 0) { - throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path") - } - footers(0).getParquetMetadata - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 65ba1246fbf9a..624f2e2fa13f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -36,6 +36,7 @@ import parquet.schema.MessageType import org.apache.spark.{Logging, SerializableWritable, SparkContext, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} +import org.apache.spark.sql.catalyst.types.StructType import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} /** @@ -64,10 +65,13 @@ case class ParquetTableScan( NewFileInputFormat.addInputPath(job, path) } - // Store Parquet schema in `Configuration` + // Store both requested and original schema in `Configuration` conf.set( - RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA, - ParquetTypesConverter.convertFromAttributes(output).toString) + RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, + ParquetTypesConverter.convertToString(output)) + conf.set( + RowWriteSupport.SPARK_ROW_SCHEMA, + ParquetTypesConverter.convertToString(relation.output)) // Store record filtering predicate in `Configuration` // Note 1: the input format ignores all predicates that cannot be expressed @@ -166,13 +170,18 @@ case class InsertIntoParquetTable( val job = new Job(sc.hadoopConfiguration) - ParquetOutputFormat.setWriteSupportClass( - job, - classOf[org.apache.spark.sql.parquet.RowWriteSupport]) + val writeSupport = + if (child.output.map(_.dataType).forall(_.isPrimitive)) { + logger.debug("Initializing MutableRowWriteSupport") + 
classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport] + } else { + classOf[org.apache.spark.sql.parquet.RowWriteSupport] + } + + ParquetOutputFormat.setWriteSupportClass(job, writeSupport) - // TODO: move that to function in object val conf = ContextUtil.getConfiguration(job) - conf.set(RowWriteSupport.PARQUET_ROW_SCHEMA, relation.parquetSchema.toString) + RowWriteSupport.setSchema(relation.output, conf) val fspath = new Path(relation.path) val fs = fspath.getFileSystem(conf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 71ba0fecce47a..bfcbdeb34a92f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -29,21 +29,23 @@ import parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.execution.SparkSqlSerializer +import com.google.common.io.BaseEncoding /** * A `parquet.io.api.RecordMaterializer` for Rows. * *@param root The root group converter for the record. */ -private[parquet] class RowRecordMaterializer(root: CatalystGroupConverter) +private[parquet] class RowRecordMaterializer(root: CatalystConverter) extends RecordMaterializer[Row] { - def this(parquetSchema: MessageType) = - this(new CatalystGroupConverter(ParquetTypesConverter.convertToAttributes(parquetSchema))) + def this(parquetSchema: MessageType, attributes: Seq[Attribute]) = + this(CatalystConverter.createRootConverter(parquetSchema, attributes)) override def getCurrentRecord: Row = root.getCurrentRecord - override def getRootConverter: GroupConverter = root + override def getRootConverter: GroupConverter = root.asInstanceOf[GroupConverter] } /** @@ -56,68 +58,94 @@ private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { stringMap: java.util.Map[String, String], fileSchema: MessageType, readContext: ReadContext): RecordMaterializer[Row] = { - log.debug(s"preparing for read with file schema $fileSchema") - new RowRecordMaterializer(readContext.getRequestedSchema) + log.debug(s"preparing for read with Parquet file schema $fileSchema") + // Note: this very much imitates AvroParquet + val parquetSchema = readContext.getRequestedSchema + var schema: Seq[Attribute] = null + + if (readContext.getReadSupportMetadata != null) { + // first try to find the read schema inside the metadata (can result from projections) + if ( + readContext + .getReadSupportMetadata + .get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA) != null) { + schema = ParquetTypesConverter.convertFromString( + readContext.getReadSupportMetadata.get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA)) + } else { + // if unavailable, try the schema that was read originally from the file or provided + // during the creation of the Parquet relation + if (readContext.getReadSupportMetadata.get(RowReadSupport.SPARK_METADATA_KEY) != null) { + schema = ParquetTypesConverter.convertFromString( + readContext.getReadSupportMetadata.get(RowReadSupport.SPARK_METADATA_KEY)) + } + } + } + // if both unavailable, fall back to deducing the schema from the given Parquet schema + if (schema == null) { + log.debug("falling back to Parquet read schema") + schema = ParquetTypesConverter.convertToAttributes(parquetSchema) + } + log.debug(s"list of attributes 
that will be read: $schema") + new RowRecordMaterializer(parquetSchema, schema) } override def init( configuration: Configuration, keyValueMetaData: java.util.Map[String, String], fileSchema: MessageType): ReadContext = { - val requested_schema_string = - configuration.get(RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA, fileSchema.toString) - val requested_schema = - MessageTypeParser.parseMessageType(requested_schema_string) - log.debug(s"read support initialized for requested schema $requested_schema") - ParquetRelation.enableLogForwarding() - new ReadContext(requested_schema, keyValueMetaData) + var parquetSchema: MessageType = fileSchema + var metadata: java.util.Map[String, String] = new java.util.HashMap[String, String]() + val requestedAttributes = RowReadSupport.getRequestedSchema(configuration) + + if (requestedAttributes != null) { + parquetSchema = ParquetTypesConverter.convertFromAttributes(requestedAttributes) + metadata.put( + RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, + ParquetTypesConverter.convertToString(requestedAttributes)) + } + + val origAttributesStr: String = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) + if (origAttributesStr != null) { + metadata.put(RowReadSupport.SPARK_METADATA_KEY, origAttributesStr) + } + + return new ReadSupport.ReadContext(parquetSchema, metadata) } } private[parquet] object RowReadSupport { - val PARQUET_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" + val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" + val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata" + + private def getRequestedSchema(configuration: Configuration): Seq[Attribute] = { + val schemaString = configuration.get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA) + if (schemaString == null) null else ParquetTypesConverter.convertFromString(schemaString) + } } /** * A `parquet.hadoop.api.WriteSupport` for Row ojects. */ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { - def setSchema(schema: MessageType, configuration: Configuration) { - // for testing - this.schema = schema - // TODO: could use Attributes themselves instead of Parquet schema? 
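// [Editor's sketch] The read/write supports above pass the Catalyst schema through the
// Hadoop Configuration as a single string property. A dependency-free sketch of that
// round trip, with a mutable Map standing in for Configuration and a '|'-separated
// encoding standing in for ParquetTypesConverter.convertToString/convertFromString
// (both stand-ins are assumptions for illustration only):
object SchemaRoundTripSketch {
  val KEY = "org.apache.spark.sql.parquet.row.requested_schema"
  def setSchema(names: Seq[String], conf: scala.collection.mutable.Map[String, String]): Unit =
    conf(KEY) = names.mkString("|")
  def getSchema(conf: scala.collection.Map[String, String]): Option[Seq[String]] =
    conf.get(KEY).map(_.split('|').toSeq)
  def main(args: Array[String]): Unit = {
    val conf = scala.collection.mutable.Map.empty[String, String]
    setSchema(Seq("myboolean", "myint", "mystring"), conf)
    println(getSchema(conf)) // Some(List(myboolean, myint, mystring))
  }
}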
- configuration.set( - RowWriteSupport.PARQUET_ROW_SCHEMA, - schema.toString) - configuration.set( - ParquetOutputFormat.WRITER_VERSION, - ParquetProperties.WriterVersion.PARQUET_1_0.toString) - } - - def getSchema(configuration: Configuration): MessageType = { - MessageTypeParser.parseMessageType(configuration.get(RowWriteSupport.PARQUET_ROW_SCHEMA)) - } - private var schema: MessageType = null - private var writer: RecordConsumer = null - private var attributes: Seq[Attribute] = null + private[parquet] var writer: RecordConsumer = null + private[parquet] var attributes: Seq[Attribute] = null override def init(configuration: Configuration): WriteSupport.WriteContext = { - schema = if (schema == null) getSchema(configuration) else schema - attributes = ParquetTypesConverter.convertToAttributes(schema) - log.debug(s"write support initialized for requested schema $schema") + attributes = if (attributes == null) RowWriteSupport.getSchema(configuration) else attributes + + log.debug(s"write support initialized for requested schema $attributes") ParquetRelation.enableLogForwarding() new WriteSupport.WriteContext( - schema, + ParquetTypesConverter.convertFromAttributes(attributes), new java.util.HashMap[java.lang.String, java.lang.String]()) } override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { writer = recordConsumer - log.debug(s"preparing for write with schema $schema") + log.debug(s"preparing for write with schema $attributes") } - // TODO: add groups (nested fields) override def write(record: Row): Unit = { if (attributes.size > record.size) { throw new IndexOutOfBoundsException( @@ -130,98 +158,176 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { // null values indicate optional fields but we do not check currently if (record(index) != null && record(index) != Nil) { writer.startField(attributes(index).name, index) - ParquetTypesConverter.consumeType(writer, attributes(index).dataType, record, index) + writeValue(attributes(index).dataType, record(index)) writer.endField(attributes(index).name, index) } index = index + 1 } writer.endMessage() } -} -private[parquet] object RowWriteSupport { - val PARQUET_ROW_SCHEMA: String = "org.apache.spark.sql.parquet.row.schema" -} - -/** - * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record to a `Row` object. - * - * @param schema The corresponding Catalyst schema in the form of a list of attributes. 
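// [Editor's sketch] The writeValue/writeStruct/writeArray/writeMap helpers added just below
// recurse on the Catalyst DataType, wrapping nested values in startGroup/startField events.
// The shape of that recursion on a toy value tree, with println standing in for Parquet's
// RecordConsumer (toy types; this is not the RowWriteSupport API itself):
object WriteValueSketch {
  sealed trait Value
  final case class Prim(v: Any) extends Value
  final case class Arr(elems: Seq[Value]) extends Value
  final case class Struct(fields: Seq[(String, Value)]) extends Value

  def write(value: Value, indent: String = ""): Unit = value match {
    case Prim(v) =>
      println(s"${indent}addValue($v)")
    case Arr(elems) =>
      println(s"${indent}startGroup()")
      elems.foreach(e => write(e, indent + "  "))
      println(s"${indent}endGroup()")
    case Struct(fields) =>
      println(s"${indent}startGroup()")
      fields.foreach { case (name, v) =>
        println(s"${indent}startField($name)")
        write(v, indent + "  ")
        println(s"${indent}endField($name)")
      }
      println(s"${indent}endGroup()")
  }

  def main(args: Array[String]): Unit =
    write(Struct(Seq("owner" -> Prim("Julien"), "phones" -> Arr(Seq(Prim("555 123 4567"))))))
}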
- */ -private[parquet] class CatalystGroupConverter( - schema: Seq[Attribute], - protected[parquet] val current: ParquetRelation.RowType) extends GroupConverter { - - def this(schema: Seq[Attribute]) = this(schema, new ParquetRelation.RowType(schema.length)) - - val converters: Array[Converter] = schema.map { - a => a.dataType match { - case ctype: NativeType => - // note: for some reason matching for StringType fails so use this ugly if instead - if (ctype == StringType) { - new CatalystPrimitiveStringConverter(this, schema.indexOf(a)) - } else { - new CatalystPrimitiveConverter(this, schema.indexOf(a)) - } - case _ => throw new RuntimeException( - s"unable to convert datatype ${a.dataType.toString} in CatalystGroupConverter") + private[parquet] def writeValue(schema: DataType, value: Any): Unit = { + if (value != null && value != Nil) { + schema match { + case t @ ArrayType(_) => writeArray( + t, + value.asInstanceOf[CatalystConverter.ArrayScalaType[_]]) + case t @ MapType(_, _) => writeMap( + t, + value.asInstanceOf[CatalystConverter.MapScalaType[_, _]]) + case t @ StructType(_) => writeStruct( + t, + value.asInstanceOf[CatalystConverter.StructScalaType[_]]) + case _ => writePrimitive(schema.asInstanceOf[PrimitiveType], value) + } } - }.toArray + } - override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) + private[parquet] def writePrimitive(schema: PrimitiveType, value: Any): Unit = { + if (value != null && value != Nil) { + schema match { + case StringType => writer.addBinary( + Binary.fromByteArray( + value.asInstanceOf[String].getBytes("utf-8") + ) + ) + case IntegerType => writer.addInteger(value.asInstanceOf[Int]) + case ShortType => writer.addInteger(value.asInstanceOf[Int]) + case LongType => writer.addLong(value.asInstanceOf[Long]) + case ByteType => writer.addInteger(value.asInstanceOf[Int]) + case DoubleType => writer.addDouble(value.asInstanceOf[Double]) + case FloatType => writer.addFloat(value.asInstanceOf[Float]) + case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) + case _ => sys.error(s"Do not know how to writer $schema to consumer") + } + } + } - private[parquet] def getCurrentRecord: ParquetRelation.RowType = current + private[parquet] def writeStruct( + schema: StructType, + struct: CatalystConverter.StructScalaType[_]): Unit = { + if (struct != null && struct != Nil) { + val fields = schema.fields.toArray + writer.startGroup() + var i = 0 + while(i < fields.size) { + if (struct(i) != null && struct(i) != Nil) { + writer.startField(fields(i).name, i) + writeValue(fields(i).dataType, struct(i)) + writer.endField(fields(i).name, i) + } + i = i + 1 + } + writer.endGroup() + } + } - override def start(): Unit = { - var i = 0 - while (i < schema.length) { - current.setNullAt(i) - i = i + 1 + // TODO: support null values, see + // https://issues.apache.org/jira/browse/SPARK-1649 + private[parquet] def writeArray( + schema: ArrayType, + array: CatalystConverter.ArrayScalaType[_]): Unit = { + val elementType = schema.elementType + writer.startGroup() + if (array.size > 0) { + writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) + var i = 0 + while(i < array.size) { + writeValue(elementType, array(i)) + i = i + 1 + } + writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) } + writer.endGroup() } - override def end(): Unit = {} + // TODO: support null values, see + // https://issues.apache.org/jira/browse/SPARK-1649 + private[parquet] def writeMap( + schema: MapType, + map: CatalystConverter.MapScalaType[_, _]): 
Unit = { + writer.startGroup() + if (map.size > 0) { + writer.startField(CatalystConverter.MAP_SCHEMA_NAME, 0) + writer.startGroup() + writer.startField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0) + for(key <- map.keys) { + writeValue(schema.keyType, key) + } + writer.endField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0) + writer.startField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1) + for(value <- map.values) { + writeValue(schema.valueType, value) + } + writer.endField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1) + writer.endGroup() + writer.endField(CatalystConverter.MAP_SCHEMA_NAME, 0) + } + writer.endGroup() + } } -/** - * A `parquet.io.api.PrimitiveConverter` that converts Parquet types to Catalyst types. - * - * @param parent The parent group converter. - * @param fieldIndex The index inside the record. - */ -private[parquet] class CatalystPrimitiveConverter( - parent: CatalystGroupConverter, - fieldIndex: Int) extends PrimitiveConverter { - // TODO: consider refactoring these together with ParquetTypesConverter - override def addBinary(value: Binary): Unit = - parent.getCurrentRecord.update(fieldIndex, value.getBytes) +// Optimized for non-nested rows +private[parquet] class MutableRowWriteSupport extends RowWriteSupport { + override def write(record: Row): Unit = { + if (attributes.size > record.size) { + throw new IndexOutOfBoundsException( + s"Trying to write more fields than contained in row (${attributes.size}>${record.size})") + } - override def addBoolean(value: Boolean): Unit = - parent.getCurrentRecord.setBoolean(fieldIndex, value) + var index = 0 + writer.startMessage() + while(index < attributes.size) { + // null values indicate optional fields but we do not check currently + if (record(index) != null && record(index) != Nil) { + writer.startField(attributes(index).name, index) + consumeType(attributes(index).dataType, record, index) + writer.endField(attributes(index).name, index) + } + index = index + 1 + } + writer.endMessage() + } - override def addDouble(value: Double): Unit = - parent.getCurrentRecord.setDouble(fieldIndex, value) + private def consumeType( + ctype: DataType, + record: Row, + index: Int): Unit = { + ctype match { + case StringType => writer.addBinary( + Binary.fromByteArray( + record(index).asInstanceOf[String].getBytes("utf-8") + ) + ) + case IntegerType => writer.addInteger(record.getInt(index)) + case ShortType => writer.addInteger(record.getShort(index)) + case LongType => writer.addLong(record.getLong(index)) + case ByteType => writer.addInteger(record.getByte(index)) + case DoubleType => writer.addDouble(record.getDouble(index)) + case FloatType => writer.addFloat(record.getFloat(index)) + case BooleanType => writer.addBoolean(record.getBoolean(index)) + case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") + } + } +} - override def addFloat(value: Float): Unit = - parent.getCurrentRecord.setFloat(fieldIndex, value) +private[parquet] object RowWriteSupport { + val SPARK_ROW_SCHEMA: String = "org.apache.spark.sql.parquet.row.attributes" - override def addInt(value: Int): Unit = - parent.getCurrentRecord.setInt(fieldIndex, value) + def getSchema(configuration: Configuration): Seq[Attribute] = { + val schemaString = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) + if (schemaString == null) { + throw new RuntimeException("Missing schema!") + } + ParquetTypesConverter.convertFromString(schemaString) + } - override def addLong(value: Long): Unit = - parent.getCurrentRecord.setLong(fieldIndex, value) + def setSchema(schema: 
Seq[Attribute], configuration: Configuration) { + val encoded = ParquetTypesConverter.convertToString(schema) + configuration.set(SPARK_ROW_SCHEMA, encoded) + configuration.set( + ParquetOutputFormat.WRITER_VERSION, + ParquetProperties.WriterVersion.PARQUET_1_0.toString) + } } -/** - * A `parquet.io.api.PrimitiveConverter` that converts Parquet strings (fixed-length byte arrays) - * into Catalyst Strings. - * - * @param parent The parent group converter. - * @param fieldIndex The index inside the record. - */ -private[parquet] class CatalystPrimitiveStringConverter( - parent: CatalystGroupConverter, - fieldIndex: Int) extends CatalystPrimitiveConverter(parent, fieldIndex) { - override def addBinary(value: Binary): Unit = - parent.getCurrentRecord.setString(fieldIndex, value.toStringUsingUTF8) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala index 46c7172985642..1dc58633a2a68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala @@ -17,14 +17,19 @@ package org.apache.spark.sql.parquet +import java.io.File + import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.mapreduce.Job import parquet.example.data.{GroupWriter, Group} import parquet.example.data.simple.SimpleGroup -import parquet.hadoop.ParquetWriter +import parquet.hadoop.{ParquetReader, ParquetFileReader, ParquetWriter} import parquet.hadoop.api.WriteSupport import parquet.hadoop.api.WriteSupport.WriteContext +import parquet.hadoop.example.GroupReadSupport +import parquet.hadoop.util.ContextUtil import parquet.io.api.RecordConsumer import parquet.schema.{MessageType, MessageTypeParser} @@ -51,13 +56,13 @@ private[sql] object ParquetTestData { val testSchema = """message myrecord { - |optional boolean myboolean; - |optional int32 myint; - |optional binary mystring; - |optional int64 mylong; - |optional float myfloat; - |optional double mydouble; - |}""".stripMargin + optional boolean myboolean; + optional int32 myint; + optional binary mystring; + optional int64 mylong; + optional float myfloat; + optional double mydouble; + }""" // field names for test assertion error messages val testSchemaFieldNames = Seq( @@ -71,23 +76,23 @@ private[sql] object ParquetTestData { val subTestSchema = """ - |message myrecord { - |optional boolean myboolean; - |optional int64 mylong; - |} - """.stripMargin + message myrecord { + optional boolean myboolean; + optional int64 mylong; + } + """ val testFilterSchema = """ - |message myrecord { - |required boolean myboolean; - |required int32 myint; - |required binary mystring; - |required int64 mylong; - |required float myfloat; - |required double mydouble; - |} - """.stripMargin + message myrecord { + required boolean myboolean; + required int32 myint; + required binary mystring; + required int64 mylong; + required float myfloat; + required double mydouble; + } + """ // field names for test assertion error messages val subTestSchemaFieldNames = Seq( @@ -100,9 +105,110 @@ private[sql] object ParquetTestData { lazy val testData = new ParquetRelation(testDir.toURI.toString) + val testNestedSchema1 = + // based on blogpost example, source: + // https://blog.twitter.com/2013/dremel-made-simple-with-parquet + // note: instead of string we have to use binary (?) 
otherwise + // Parquet gives us: + // IllegalArgumentException: expected one of [INT64, INT32, BOOLEAN, + // BINARY, FLOAT, DOUBLE, INT96, FIXED_LEN_BYTE_ARRAY] + // Also repeated primitives seem tricky to convert (AvroParquet + // only uses them in arrays?) so only use at most one in each group + // and nothing else in that group (-> is mapped to array)! + // The "values" inside ownerPhoneNumbers is a keyword currently + // so that array types can be translated correctly. + """ + message AddressBook { + required binary owner; + optional group ownerPhoneNumbers { + repeated binary array; + } + optional group contacts { + repeated group array { + required binary name; + optional binary phoneNumber; + } + } + } + """ + + + val testNestedSchema2 = + """ + message TestNested2 { + required int32 firstInt; + optional int32 secondInt; + optional group longs { + repeated int64 array; + } + required group entries { + repeated group array { + required double value; + optional boolean truth; + } + } + optional group outerouter { + repeated group array { + repeated group array { + repeated int32 array; + } + } + } + } + """ + + val testNestedSchema3 = + """ + message TestNested3 { + required int32 x; + optional group booleanNumberPairs { + repeated group array { + required int32 key; + optional group value { + repeated group array { + required double nestedValue; + optional boolean truth; + } + } + } + } + } + """ + + val testNestedSchema4 = + """ + message TestNested4 { + required int32 x; + optional group data1 { + repeated group map { + required binary key; + required int32 value; + } + } + required group data2 { + repeated group map { + required binary key; + required group value { + required int64 payload1; + optional binary payload2; + } + } + } + } + """ + + val testNestedDir1 = Utils.createTempDir() + val testNestedDir2 = Utils.createTempDir() + val testNestedDir3 = Utils.createTempDir() + val testNestedDir4 = Utils.createTempDir() + + lazy val testNestedData1 = new ParquetRelation(testNestedDir1.toURI.toString) + lazy val testNestedData2 = new ParquetRelation(testNestedDir2.toURI.toString) + def writeFile() = { - testDir.delete + testDir.delete() val path: Path = new Path(new Path(testDir.toURI), new Path("part-r-0.parquet")) + val job = new Job() val schema: MessageType = MessageTypeParser.parseMessageType(testSchema) val writeSupport = new TestGroupWriteSupport(schema) val writer = new ParquetWriter[Group](path, writeSupport) @@ -150,5 +256,149 @@ private[sql] object ParquetTestData { } writer.close() } + + def writeNestedFile1() { + // example data from https://blog.twitter.com/2013/dremel-made-simple-with-parquet + testNestedDir1.delete() + val path: Path = new Path(new Path(testNestedDir1.toURI), new Path("part-r-0.parquet")) + val schema: MessageType = MessageTypeParser.parseMessageType(testNestedSchema1) + + val r1 = new SimpleGroup(schema) + r1.add(0, "Julien Le Dem") + r1.addGroup(1) + .append(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, "555 123 4567") + .append(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, "555 666 1337") + .append(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, "XXX XXX XXXX") + val contacts = r1.addGroup(2) + contacts.addGroup(0) + .append("name", "Dmitriy Ryaboy") + .append("phoneNumber", "555 987 6543") + contacts.addGroup(0) + .append("name", "Chris Aniszczyk") + + val r2 = new SimpleGroup(schema) + r2.add(0, "A. 
Nonymous") + + val writeSupport = new TestGroupWriteSupport(schema) + val writer = new ParquetWriter[Group](path, writeSupport) + writer.write(r1) + writer.write(r2) + writer.close() + } + + def writeNestedFile2() { + testNestedDir2.delete() + val path: Path = new Path(new Path(testNestedDir2.toURI), new Path("part-r-0.parquet")) + val schema: MessageType = MessageTypeParser.parseMessageType(testNestedSchema2) + + val r1 = new SimpleGroup(schema) + r1.add(0, 1) + r1.add(1, 7) + val longs = r1.addGroup(2) + longs.add(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME , 1.toLong << 32) + longs.add(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 1.toLong << 33) + longs.add(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 1.toLong << 34) + val booleanNumberPair = r1.addGroup(3).addGroup(0) + booleanNumberPair.add("value", 2.5) + booleanNumberPair.add("truth", false) + val top_level = r1.addGroup(4) + val second_level_a = top_level.addGroup(0) + val second_level_b = top_level.addGroup(0) + val third_level_aa = second_level_a.addGroup(0) + val third_level_ab = second_level_a.addGroup(0) + val third_level_c = second_level_b.addGroup(0) + third_level_aa.add( + CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + 7) + third_level_ab.add( + CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + 8) + third_level_c.add( + CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + 9) + + val writeSupport = new TestGroupWriteSupport(schema) + val writer = new ParquetWriter[Group](path, writeSupport) + writer.write(r1) + writer.close() + } + + def writeNestedFile3() { + testNestedDir3.delete() + val path: Path = new Path(new Path(testNestedDir3.toURI), new Path("part-r-0.parquet")) + val schema: MessageType = MessageTypeParser.parseMessageType(testNestedSchema3) + + val r1 = new SimpleGroup(schema) + r1.add(0, 1) + val booleanNumberPairs = r1.addGroup(1) + val g1 = booleanNumberPairs.addGroup(0) + g1.add(0, 1) + val nested1 = g1.addGroup(1) + val ng1 = nested1.addGroup(0) + ng1.add(0, 1.5) + ng1.add(1, false) + val ng2 = nested1.addGroup(0) + ng2.add(0, 2.5) + ng2.add(1, true) + val g2 = booleanNumberPairs.addGroup(0) + g2.add(0, 2) + val ng3 = g2.addGroup(1) + .addGroup(0) + ng3.add(0, 3.5) + ng3.add(1, false) + + val writeSupport = new TestGroupWriteSupport(schema) + val writer = new ParquetWriter[Group](path, writeSupport) + writer.write(r1) + writer.close() + } + + def writeNestedFile4() { + testNestedDir4.delete() + val path: Path = new Path(new Path(testNestedDir4.toURI), new Path("part-r-0.parquet")) + val schema: MessageType = MessageTypeParser.parseMessageType(testNestedSchema4) + + val r1 = new SimpleGroup(schema) + r1.add(0, 7) + val map1 = r1.addGroup(1) + val keyValue1 = map1.addGroup(0) + keyValue1.add(0, "key1") + keyValue1.add(1, 1) + val keyValue2 = map1.addGroup(0) + keyValue2.add(0, "key2") + keyValue2.add(1, 2) + val map2 = r1.addGroup(2) + val keyValue3 = map2.addGroup(0) + // TODO: currently only string key type supported + keyValue3.add(0, "seven") + val valueGroup1 = keyValue3.addGroup(1) + valueGroup1.add(0, 42.toLong) + valueGroup1.add(1, "the answer") + val keyValue4 = map2.addGroup(0) + // TODO: currently only string key type supported + keyValue4.add(0, "eight") + val valueGroup2 = keyValue4.addGroup(1) + valueGroup2.add(0, 49.toLong) + + val writeSupport = new TestGroupWriteSupport(schema) + val writer = new ParquetWriter[Group](path, writeSupport) + writer.write(r1) + writer.close() + } + + // TODO: this is not actually used anywhere but useful for debugging + /* def readNestedFile(file: File, schemaString: 
String): Unit = { + val configuration = new Configuration() + val path = new Path(new Path(file.toURI), new Path("part-r-0.parquet")) + val fs: FileSystem = path.getFileSystem(configuration) + val schema: MessageType = MessageTypeParser.parseMessageType(schemaString) + assert(schema != null) + val outputStatus: FileStatus = fs.getFileStatus(new Path(path.toString)) + val footers = ParquetFileReader.readFooter(configuration, outputStatus) + assert(footers != null) + val reader = new ParquetReader(new Path(path.toString), new GroupReadSupport()) + val first = reader.read() + assert(first != null) + } */ } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala new file mode 100644 index 0000000000000..f9046368e7ced --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -0,0 +1,408 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import java.io.IOException + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.mapreduce.Job + +import parquet.hadoop.{ParquetFileReader, Footer, ParquetFileWriter} +import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData} +import parquet.hadoop.util.ContextUtil +import parquet.schema.{Type => ParquetType, PrimitiveType => ParquetPrimitiveType, MessageType} +import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns} +import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} +import parquet.schema.Type.Repetition + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} +import org.apache.spark.sql.catalyst.types._ + +// Implicits +import scala.collection.JavaConversions._ + +private[parquet] object ParquetTypesConverter extends Logging { + def isPrimitiveType(ctype: DataType): Boolean = + classOf[PrimitiveType] isAssignableFrom ctype.getClass + + def toPrimitiveDataType(parquetType : ParquetPrimitiveTypeName): DataType = parquetType match { + case ParquetPrimitiveTypeName.BINARY => StringType + case ParquetPrimitiveTypeName.BOOLEAN => BooleanType + case ParquetPrimitiveTypeName.DOUBLE => DoubleType + case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY => ArrayType(ByteType) + case ParquetPrimitiveTypeName.FLOAT => FloatType + case ParquetPrimitiveTypeName.INT32 => IntegerType + case ParquetPrimitiveTypeName.INT64 => LongType + case ParquetPrimitiveTypeName.INT96 => + // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? 
+ sys.error("Potential loss of precision: cannot convert INT96") + case _ => sys.error( + s"Unsupported parquet datatype $parquetType") + } + + /** + * Converts a given Parquet `Type` into the corresponding + * [[org.apache.spark.sql.catalyst.types.DataType]]. + * + * We apply the following conversion rules: + *
<ul>
+ *   <li> Primitive types are converted to the corresponding primitive type.</li>
+ *   <li> Group types that have a single field that is itself a group, which has repetition
+ * level `REPEATED`, are treated as follows:<ul>
+ *     <li> If the nested group has name `values`, the surrounding group is converted
+ * into an [[ArrayType]] with the corresponding field type (primitive or
+ * complex) as element type.</li>
+ *     <li> If the nested group has name `map` and two fields (named `key` and `value`),
+ * the surrounding group is converted into a [[MapType]]
+ * with the corresponding key and value (value possibly complex) types.
+ * Note that we currently assume map values are not nullable.</li>
+ *     <li> Other group types are converted into a [[StructType]] with the corresponding
+ * field types.</li></ul></li>
+ * </ul>
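// Illustrative sketch (schema and names are hypothetical, not from the patch): how the rules
// above resolve a small nested schema on read, using the converter utilities defined later in
// this file. The `map`/`key`/`value` naming follows the test schemas elsewhere in this diff.
import parquet.schema.MessageTypeParser

val sketchSchema = MessageTypeParser.parseMessageType(
  """message sketch {
       required binary owner;
       optional group data {
         repeated group map {
           required binary key;
           required int32 value;
         }
       }
     }""")
// owner -> StringType, not nullable (its repetition is REQUIRED)
// data  -> MapType(StringType, IntegerType), via the `map`/`key`/`value` convention above
val inferredAttributes = ParquetTypesConverter.convertToAttributes(sketchSchema)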
+ * Note that fields are determined to be `nullable` if and only if their Parquet repetition + * level is not `REQUIRED`. + * + * @param parquetType The type to convert. + * @return The corresponding Catalyst type. + */ + def toDataType(parquetType: ParquetType): DataType = { + def correspondsToMap(groupType: ParquetGroupType): Boolean = { + if (groupType.getFieldCount != 1 || groupType.getFields.apply(0).isPrimitive) { + false + } else { + // This mostly follows the convention in ``parquet.schema.ConversionPatterns`` + val keyValueGroup = groupType.getFields.apply(0).asGroupType() + keyValueGroup.getRepetition == Repetition.REPEATED && + keyValueGroup.getName == CatalystConverter.MAP_SCHEMA_NAME && + keyValueGroup.getFieldCount == 2 && + keyValueGroup.getFields.apply(0).getName == CatalystConverter.MAP_KEY_SCHEMA_NAME && + keyValueGroup.getFields.apply(1).getName == CatalystConverter.MAP_VALUE_SCHEMA_NAME + } + } + + def correspondsToArray(groupType: ParquetGroupType): Boolean = { + groupType.getFieldCount == 1 && + groupType.getFieldName(0) == CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME && + groupType.getFields.apply(0).getRepetition == Repetition.REPEATED + } + + if (parquetType.isPrimitive) { + toPrimitiveDataType(parquetType.asPrimitiveType.getPrimitiveTypeName) + } else { + val groupType = parquetType.asGroupType() + parquetType.getOriginalType match { + // if the schema was constructed programmatically there may be hints how to convert + // it inside the metadata via the OriginalType field + case ParquetOriginalType.LIST => { // TODO: check enums! + assert(groupType.getFieldCount == 1) + val field = groupType.getFields.apply(0) + new ArrayType(toDataType(field)) + } + case ParquetOriginalType.MAP => { + assert( + !groupType.getFields.apply(0).isPrimitive, + "Parquet Map type malformatted: expected nested group for map!") + val keyValueGroup = groupType.getFields.apply(0).asGroupType() + assert( + keyValueGroup.getFieldCount == 2, + "Parquet Map type malformatted: nested group should have 2 (key, value) fields!") + val keyType = toDataType(keyValueGroup.getFields.apply(0)) + assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) + val valueType = toDataType(keyValueGroup.getFields.apply(1)) + assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) + new MapType(keyType, valueType) + } + case _ => { + // Note: the order of these checks is important! + if (correspondsToMap(groupType)) { // MapType + val keyValueGroup = groupType.getFields.apply(0).asGroupType() + val keyType = toDataType(keyValueGroup.getFields.apply(0)) + assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) + val valueType = toDataType(keyValueGroup.getFields.apply(1)) + assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) + new MapType(keyType, valueType) + } else if (correspondsToArray(groupType)) { // ArrayType + val elementType = toDataType(groupType.getFields.apply(0)) + new ArrayType(elementType) + } else { // everything else: StructType + val fields = groupType + .getFields + .map(ptype => new StructField( + ptype.getName, + toDataType(ptype), + ptype.getRepetition != Repetition.REQUIRED)) + new StructType(fields) + } + } + } + } + } + + /** + * For a given Catalyst [[org.apache.spark.sql.catalyst.types.DataType]] return + * the name of the corresponding Parquet primitive type or None if the given type + * is not primitive. 
+ * + * @param ctype The type to convert + * @return The name of the corresponding Parquet primitive type + */ + def fromPrimitiveDataType(ctype: DataType): + Option[ParquetPrimitiveTypeName] = ctype match { + case StringType => Some(ParquetPrimitiveTypeName.BINARY) + case BooleanType => Some(ParquetPrimitiveTypeName.BOOLEAN) + case DoubleType => Some(ParquetPrimitiveTypeName.DOUBLE) + case ArrayType(ByteType) => + Some(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY) + case FloatType => Some(ParquetPrimitiveTypeName.FLOAT) + case IntegerType => Some(ParquetPrimitiveTypeName.INT32) + // There is no type for Byte or Short so we promote them to INT32. + case ShortType => Some(ParquetPrimitiveTypeName.INT32) + case ByteType => Some(ParquetPrimitiveTypeName.INT32) + case LongType => Some(ParquetPrimitiveTypeName.INT64) + case _ => None + } + + /** + * Converts a given Catalyst [[org.apache.spark.sql.catalyst.types.DataType]] into + * the corresponding Parquet `Type`. + * + * The conversion follows the rules below: + *
<ul>
+ *   <li> Primitive types are converted into Parquet's primitive types.</li>
+ *   <li> [[org.apache.spark.sql.catalyst.types.StructType]]s are converted
+ * into Parquet's `GroupType` with the corresponding field types.</li>
+ *   <li> [[org.apache.spark.sql.catalyst.types.ArrayType]]s are converted
+ * into a 2-level nested group, where the outer group has the inner
+ * group as sole field. The inner group has name `values` and
+ * repetition level `REPEATED` and has the element type of
+ * the array as schema. We use Parquet's `ConversionPatterns` for this
+ * purpose.</li>
+ *   <li> [[org.apache.spark.sql.catalyst.types.MapType]]s are converted
+ * into a nested (2-level) Parquet `GroupType` with two fields: a key
+ * type and a value type. The nested group has repetition level
+ * `REPEATED` and name `map`. We use Parquet's `ConversionPatterns`
+ * for this purpose.</li>
+ * </ul>
+ * Parquet's repetition level is generally set according to the following rule:
+ * <ul>
+ *   <li> If the call to `fromDataType` is recursive inside an enclosing `ArrayType` or
+ * `MapType`, then the repetition level is set to `REPEATED`.</li>
+ *   <li> Otherwise, if the attribute whose type is converted is `nullable`, the Parquet
+ * type gets repetition level `OPTIONAL` and otherwise `REQUIRED`.</li>
+ * </ul>
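// Illustrative sketch (field names "tags" and "data" are hypothetical): the write-side mapping
// described above. The element, key and value field names come from the CatalystConverter
// constants used in this file.
import org.apache.spark.sql.catalyst.types.{ArrayType, IntegerType, MapType, StringType}

// Roughly: required group tags { repeated binary <ARRAY_ELEMENTS_SCHEMA_NAME>; }
val tagsField = ParquetTypesConverter.fromDataType(
  ArrayType(StringType), name = "tags", nullable = false)

// Roughly: optional group data { repeated group map { required binary key; required int32 value; } }
val dataField = ParquetTypesConverter.fromDataType(
  MapType(StringType, IntegerType), name = "data", nullable = true)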
+ * + *@param ctype The type to convert + * @param name The name of the [[org.apache.spark.sql.catalyst.expressions.Attribute]] + * whose type is converted + * @param nullable When true indicates that the attribute is nullable + * @param inArray When true indicates that this is a nested attribute inside an array. + * @return The corresponding Parquet type. + */ + def fromDataType( + ctype: DataType, + name: String, + nullable: Boolean = true, + inArray: Boolean = false): ParquetType = { + val repetition = + if (inArray) { + Repetition.REPEATED + } else { + if (nullable) Repetition.OPTIONAL else Repetition.REQUIRED + } + val primitiveType = fromPrimitiveDataType(ctype) + if (primitiveType.isDefined) { + new ParquetPrimitiveType(repetition, primitiveType.get, name) + } else { + ctype match { + case ArrayType(elementType) => { + val parquetElementType = fromDataType( + elementType, + CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + nullable = false, + inArray = true) + ConversionPatterns.listType(repetition, name, parquetElementType) + } + case StructType(structFields) => { + val fields = structFields.map { + field => fromDataType(field.dataType, field.name, field.nullable, inArray = false) + } + new ParquetGroupType(repetition, name, fields) + } + case MapType(keyType, valueType) => { + val parquetKeyType = + fromDataType( + keyType, + CatalystConverter.MAP_KEY_SCHEMA_NAME, + nullable = false, + inArray = false) + val parquetValueType = + fromDataType( + valueType, + CatalystConverter.MAP_VALUE_SCHEMA_NAME, + nullable = false, + inArray = false) + ConversionPatterns.mapType( + repetition, + name, + parquetKeyType, + parquetValueType) + } + case _ => sys.error(s"Unsupported datatype $ctype") + } + } + } + + def convertToAttributes(parquetSchema: ParquetType): Seq[Attribute] = { + parquetSchema + .asGroupType() + .getFields + .map( + field => + new AttributeReference( + field.getName, + toDataType(field), + field.getRepetition != Repetition.REQUIRED)()) + } + + def convertFromAttributes(attributes: Seq[Attribute]): MessageType = { + val fields = attributes.map( + attribute => + fromDataType(attribute.dataType, attribute.name, attribute.nullable)) + new MessageType("root", fields) + } + + def convertFromString(string: String): Seq[Attribute] = { + DataType(string) match { + case s: StructType => s.toAttributes + case other => sys.error(s"Can convert $string to row") + } + } + + def convertToString(schema: Seq[Attribute]): String = { + StructType.fromAttributes(schema).toString + } + + def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration): Unit = { + if (origPath == null) { + throw new IllegalArgumentException("Unable to write Parquet metadata: path is null") + } + val fs = origPath.getFileSystem(conf) + if (fs == null) { + throw new IllegalArgumentException( + s"Unable to write Parquet metadata: path $origPath is incorrectly formatted") + } + val path = origPath.makeQualified(fs) + if (fs.exists(path) && !fs.getFileStatus(path).isDir) { + throw new IllegalArgumentException(s"Expected to write to directory $path but found file") + } + val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) + if (fs.exists(metadataPath)) { + try { + fs.delete(metadataPath, true) + } catch { + case e: IOException => + throw new IOException(s"Unable to delete previous PARQUET_METADATA_FILE at $metadataPath") + } + } + val extraMetadata = new java.util.HashMap[String, String]() + extraMetadata.put( + RowReadSupport.SPARK_METADATA_KEY, + 
ParquetTypesConverter.convertToString(attributes)) + // TODO: add extra data, e.g., table name, date, etc.? + + val parquetSchema: MessageType = + ParquetTypesConverter.convertFromAttributes(attributes) + val metaData: FileMetaData = new FileMetaData( + parquetSchema, + extraMetadata, + "Spark") + + ParquetRelation.enableLogForwarding() + ParquetFileWriter.writeMetadataFile( + conf, + path, + new Footer(path, new ParquetMetadata(metaData, Nil)) :: Nil) + } + + /** + * Try to read Parquet metadata at the given Path. We first see if there is a summary file + * in the parent directory. If so, this is used. Else we read the actual footer at the given + * location. + * @param origPath The path at which we expect one (or more) Parquet files. + * @param configuration The Hadoop configuration to use. + * @return The `ParquetMetadata` containing among other things the schema. + */ + def readMetaData(origPath: Path, configuration: Option[Configuration]): ParquetMetadata = { + if (origPath == null) { + throw new IllegalArgumentException("Unable to read Parquet metadata: path is null") + } + val job = new Job() + val conf = configuration.getOrElse(ContextUtil.getConfiguration(job)) + val fs: FileSystem = origPath.getFileSystem(conf) + if (fs == null) { + throw new IllegalArgumentException(s"Incorrectly formatted Parquet metadata path $origPath") + } + val path = origPath.makeQualified(fs) + if (!fs.getFileStatus(path).isDir) { + throw new IllegalArgumentException( + s"Expected $path for be a directory with Parquet files/metadata") + } + ParquetRelation.enableLogForwarding() + val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) + // if this is a new table that was just created we will find only the metadata file + if (fs.exists(metadataPath) && fs.isFile(metadataPath)) { + ParquetFileReader.readFooter(conf, metadataPath) + } else { + // there may be one or more Parquet files in the given directory + val footers = ParquetFileReader.readFooters(conf, fs.getFileStatus(path)) + // TODO: for now we assume that all footers (if there is more than one) have identical + // metadata; we may want to add a check here at some point + if (footers.size() == 0) { + throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path") + } + footers(0).getParquetMetadata + } + } + + /** + * Reads in Parquet Metadata from the given path and tries to extract the schema + * (Catalyst attributes) from the application-specific key-value map. If this + * is empty it falls back to converting from the Parquet file schema which + * may lead to an upcast of types (e.g., {byte, short} to int). + * + * @param origPath The path at which we expect one (or more) Parquet files. + * @param conf The Hadoop configuration to use. + * @return A list of attributes that make up the schema. 
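// Illustrative sketch (the path is hypothetical): recovering a Catalyst schema for an existing
// Parquet directory; when the Spark key-value metadata is absent this falls back to converting
// the Parquet file schema, as described above.
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.catalyst.expressions.Attribute

val recoveredSchema: Seq[Attribute] =
  ParquetTypesConverter.readSchemaFromFile(new Path("/tmp/parquet-table"), conf = None)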
+ */ + def readSchemaFromFile(origPath: Path, conf: Option[Configuration]): Seq[Attribute] = { + val keyValueMetadata: java.util.Map[String, String] = + readMetaData(origPath, conf) + .getFileMetaData + .getKeyValueMetaData + if (keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY) != null) { + convertFromString(keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY)) + } else { + val attributes = convertToAttributes( + readMetaData(origPath, conf).getFileMetaData.getSchema) + log.warn(s"Falling back to schema conversion from Parquet types; result: $attributes") + attributes + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index d7f6abaf5d381..ef84ead2e6e8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql -import org.scalatest.FunSuite - import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ -class QueryTest extends FunSuite { +class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer matches the expected result. * @param rdd the [[SchemaRDD]] to be executed diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala index 9fff7222fe840..020baf0c7ec6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala @@ -22,6 +22,7 @@ import scala.beans.BeanProperty import org.scalatest.FunSuite import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.test.TestSQLContext // Implicits @@ -111,4 +112,48 @@ class JavaSQLSuite extends FunSuite { """.stripMargin).collect.head.row === Seq.fill(8)(null)) } + + test("loads JSON datasets") { + val jsonString = + """{"string":"this is a simple string.", + "integer":10, + "long":21474836470, + "bigInteger":92233720368547758070, + "double":1.7976931348623157E308, + "boolean":true, + "null":null + }""".replaceAll("\n", " ") + val rdd = javaCtx.parallelize(jsonString :: Nil) + + var schemaRDD = javaSqlCtx.jsonRDD(rdd) + + schemaRDD.registerAsTable("jsonTable1") + + assert( + javaSqlCtx.sql("select * from jsonTable1").collect.head.row === + Seq(BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.")) + + val file = getTempFilePath("json") + val path = file.toString + rdd.saveAsTextFile(path) + schemaRDD = javaSqlCtx.jsonFile(path) + + schemaRDD.registerAsTable("jsonTable2") + + assert( + javaSqlCtx.sql("select * from jsonTable2").collect.head.row === + Seq(BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala new file mode 100644 index 0000000000000..10bd9f08f0238 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -0,0 +1,519 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.json + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.plans.logical.LeafNode +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType} +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.test.TestSQLContext._ + +protected case class Schema(output: Seq[Attribute]) extends LeafNode + +class JsonSuite extends QueryTest { + import TestJsonData._ + TestJsonData + + test("Type promotion") { + def checkTypePromotion(expected: Any, actual: Any) { + assert(expected.getClass == actual.getClass, + s"Failed to promote ${actual.getClass} to ${expected.getClass}.") + assert(expected == actual, + s"Promoted value ${actual}(${actual.getClass}) does not equal the expected value " + + s"${expected}(${expected.getClass}).") + } + + val intNumber: Int = 2147483647 + checkTypePromotion(intNumber, enforceCorrectType(intNumber, IntegerType)) + checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType)) + checkTypePromotion(intNumber.toDouble, enforceCorrectType(intNumber, DoubleType)) + checkTypePromotion(BigDecimal(intNumber), enforceCorrectType(intNumber, DecimalType)) + + val longNumber: Long = 9223372036854775807L + checkTypePromotion(longNumber, enforceCorrectType(longNumber, LongType)) + checkTypePromotion(longNumber.toDouble, enforceCorrectType(longNumber, DoubleType)) + checkTypePromotion(BigDecimal(longNumber), enforceCorrectType(longNumber, DecimalType)) + + val doubleNumber: Double = 1.7976931348623157E308d + checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType)) + checkTypePromotion(BigDecimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType)) + } + + test("Get compatible type") { + def checkDataType(t1: DataType, t2: DataType, expected: DataType) { + var actual = compatibleType(t1, t2) + assert(actual == expected, + s"Expected $expected as the most general data type for $t1 and $t2, found $actual") + actual = compatibleType(t2, t1) + assert(actual == expected, + s"Expected $expected as the most general data type for $t1 and $t2, found $actual") + } + + // NullType + checkDataType(NullType, BooleanType, BooleanType) + checkDataType(NullType, IntegerType, IntegerType) + checkDataType(NullType, LongType, LongType) + checkDataType(NullType, DoubleType, DoubleType) + checkDataType(NullType, DecimalType, DecimalType) + checkDataType(NullType, StringType, StringType) + checkDataType(NullType, ArrayType(IntegerType), ArrayType(IntegerType)) + checkDataType(NullType, StructType(Nil), StructType(Nil)) + checkDataType(NullType, NullType, NullType) + + // BooleanType + checkDataType(BooleanType, BooleanType, BooleanType) + checkDataType(BooleanType, IntegerType, StringType) + checkDataType(BooleanType, LongType, StringType) + 
checkDataType(BooleanType, DoubleType, StringType) + checkDataType(BooleanType, DecimalType, StringType) + checkDataType(BooleanType, StringType, StringType) + checkDataType(BooleanType, ArrayType(IntegerType), StringType) + checkDataType(BooleanType, StructType(Nil), StringType) + + // IntegerType + checkDataType(IntegerType, IntegerType, IntegerType) + checkDataType(IntegerType, LongType, LongType) + checkDataType(IntegerType, DoubleType, DoubleType) + checkDataType(IntegerType, DecimalType, DecimalType) + checkDataType(IntegerType, StringType, StringType) + checkDataType(IntegerType, ArrayType(IntegerType), StringType) + checkDataType(IntegerType, StructType(Nil), StringType) + + // LongType + checkDataType(LongType, LongType, LongType) + checkDataType(LongType, DoubleType, DoubleType) + checkDataType(LongType, DecimalType, DecimalType) + checkDataType(LongType, StringType, StringType) + checkDataType(LongType, ArrayType(IntegerType), StringType) + checkDataType(LongType, StructType(Nil), StringType) + + // DoubleType + checkDataType(DoubleType, DoubleType, DoubleType) + checkDataType(DoubleType, DecimalType, DecimalType) + checkDataType(DoubleType, StringType, StringType) + checkDataType(DoubleType, ArrayType(IntegerType), StringType) + checkDataType(DoubleType, StructType(Nil), StringType) + + // DoubleType + checkDataType(DecimalType, DecimalType, DecimalType) + checkDataType(DecimalType, StringType, StringType) + checkDataType(DecimalType, ArrayType(IntegerType), StringType) + checkDataType(DecimalType, StructType(Nil), StringType) + + // StringType + checkDataType(StringType, StringType, StringType) + checkDataType(StringType, ArrayType(IntegerType), StringType) + checkDataType(StringType, StructType(Nil), StringType) + + // ArrayType + checkDataType(ArrayType(IntegerType), ArrayType(IntegerType), ArrayType(IntegerType)) + checkDataType(ArrayType(IntegerType), ArrayType(LongType), ArrayType(LongType)) + checkDataType(ArrayType(IntegerType), ArrayType(StringType), ArrayType(StringType)) + checkDataType(ArrayType(IntegerType), StructType(Nil), StringType) + + // StructType + checkDataType(StructType(Nil), StructType(Nil), StructType(Nil)) + checkDataType( + StructType(StructField("f1", IntegerType, true) :: Nil), + StructType(StructField("f1", IntegerType, true) :: Nil), + StructType(StructField("f1", IntegerType, true) :: Nil)) + checkDataType( + StructType(StructField("f1", IntegerType, true) :: Nil), + StructType(Nil), + StructType(StructField("f1", IntegerType, true) :: Nil)) + checkDataType( + StructType( + StructField("f1", IntegerType, true) :: + StructField("f2", IntegerType, true) :: Nil), + StructType(StructField("f1", LongType, true) :: Nil) , + StructType( + StructField("f1", LongType, true) :: + StructField("f2", IntegerType, true) :: Nil)) + checkDataType( + StructType( + StructField("f1", IntegerType, true) :: Nil), + StructType( + StructField("f2", IntegerType, true) :: Nil), + StructType( + StructField("f1", IntegerType, true) :: + StructField("f2", IntegerType, true) :: Nil)) + checkDataType( + StructType( + StructField("f1", IntegerType, true) :: Nil), + DecimalType, + StringType) + } + + test("Primitive field and type inferring") { + val jsonSchemaRDD = jsonRDD(primitiveFieldAndType) + + val expectedSchema = + AttributeReference("bigInteger", DecimalType, true)() :: + AttributeReference("boolean", BooleanType, true)() :: + AttributeReference("double", DoubleType, true)() :: + AttributeReference("integer", IntegerType, true)() :: + AttributeReference("long", 
LongType, true)() :: + AttributeReference("null", StringType, true)() :: + AttributeReference("string", StringType, true)() :: Nil + + comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + + jsonSchemaRDD.registerAsTable("jsonTable") + + checkAnswer( + sql("select * from jsonTable"), + (BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") :: Nil + ) + } + + test("Complex field and type inferring") { + val jsonSchemaRDD = jsonRDD(complexFieldAndType) + + val expectedSchema = + AttributeReference("arrayOfArray1", ArrayType(ArrayType(StringType)), true)() :: + AttributeReference("arrayOfArray2", ArrayType(ArrayType(DoubleType)), true)() :: + AttributeReference("arrayOfBigInteger", ArrayType(DecimalType), true)() :: + AttributeReference("arrayOfBoolean", ArrayType(BooleanType), true)() :: + AttributeReference("arrayOfDouble", ArrayType(DoubleType), true)() :: + AttributeReference("arrayOfInteger", ArrayType(IntegerType), true)() :: + AttributeReference("arrayOfLong", ArrayType(LongType), true)() :: + AttributeReference("arrayOfNull", ArrayType(StringType), true)() :: + AttributeReference("arrayOfString", ArrayType(StringType), true)() :: + AttributeReference("arrayOfStruct", ArrayType( + StructType(StructField("field1", BooleanType, true) :: + StructField("field2", StringType, true) :: Nil)), true)() :: + AttributeReference("struct", StructType( + StructField("field1", BooleanType, true) :: + StructField("field2", DecimalType, true) :: Nil), true)() :: + AttributeReference("structWithArrayFields", StructType( + StructField("field1", ArrayType(IntegerType), true) :: + StructField("field2", ArrayType(StringType), true) :: Nil), true)() :: Nil + + comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + + jsonSchemaRDD.registerAsTable("jsonTable") + + // Access elements of a primitive array. + checkAnswer( + sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from jsonTable"), + ("str1", "str2", null) :: Nil + ) + + // Access an array of null values. + checkAnswer( + sql("select arrayOfNull from jsonTable"), + Seq(Seq(null, null, null, null)) :: Nil + ) + + // Access elements of a BigInteger array (we use DecimalType internally). + checkAnswer( + sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from jsonTable"), + (BigDecimal("922337203685477580700"), BigDecimal("-922337203685477580800"), null) :: Nil + ) + + // Access elements of an array of arrays. + checkAnswer( + sql("select arrayOfArray1[0], arrayOfArray1[1] from jsonTable"), + (Seq("1", "2", "3"), Seq("str1", "str2")) :: Nil + ) + + // Access elements of an array of arrays. + checkAnswer( + sql("select arrayOfArray2[0], arrayOfArray2[1] from jsonTable"), + (Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) :: Nil + ) + + // Access elements of an array inside a filed with the type of ArrayType(ArrayType). + checkAnswer( + sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from jsonTable"), + ("str2", 2.1) :: Nil + ) + + // Access elements of an array of structs. + checkAnswer( + sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2] from jsonTable"), + (true :: "str1" :: Nil, false :: null :: Nil, null) :: Nil + ) + + // Access a struct and fields inside of it. 
+ checkAnswer( + sql("select struct, struct.field1, struct.field2 from jsonTable"), + ( + Seq(true, BigDecimal("92233720368547758070")), + true, + BigDecimal("92233720368547758070")) :: Nil + ) + + // Access an array field of a struct. + checkAnswer( + sql("select structWithArrayFields.field1, structWithArrayFields.field2 from jsonTable"), + (Seq(4, 5, 6), Seq("str1", "str2")) :: Nil + ) + + // Access elements of an array field of a struct. + checkAnswer( + sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from jsonTable"), + (5, null) :: Nil + ) + } + + ignore("Complex field and type inferring (Ignored)") { + val jsonSchemaRDD = jsonRDD(complexFieldAndType) + jsonSchemaRDD.registerAsTable("jsonTable") + + // Right now, "field1" and "field2" are treated as aliases. We should fix it. + checkAnswer( + sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), + (true, "str1") :: Nil + ) + + // Right now, the analyzer cannot resolve arrayOfStruct.field1 and arrayOfStruct.field2. + // Getting all values of a specific field from an array of structs. + checkAnswer( + sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"), + (Seq(true, false), Seq("str1", null)) :: Nil + ) + } + + test("Type conflict in primitive field values") { + val jsonSchemaRDD = jsonRDD(primitiveFieldValueTypeConflict) + + val expectedSchema = + AttributeReference("num_bool", StringType, true)() :: + AttributeReference("num_num_1", LongType, true)() :: + AttributeReference("num_num_2", DecimalType, true)() :: + AttributeReference("num_num_3", DoubleType, true)() :: + AttributeReference("num_str", StringType, true)() :: + AttributeReference("str_bool", StringType, true)() :: Nil + + comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + + jsonSchemaRDD.registerAsTable("jsonTable") + + checkAnswer( + sql("select * from jsonTable"), + ("true", 11L, null, 1.1, "13.1", "str1") :: + ("12", null, BigDecimal("21474836470.9"), null, null, "true") :: + ("false", 21474836470L, BigDecimal("92233720368547758070"), 100, "str1", "false") :: + (null, 21474836570L, BigDecimal(1.1), 21474836470L, "92233720368547758070", null) :: Nil + ) + + // Number and Boolean conflict: resolve the type as number in this query. + checkAnswer( + sql("select num_bool - 10 from jsonTable where num_bool > 11"), + 2 + ) + + // Widening to LongType + checkAnswer( + sql("select num_num_1 - 100 from jsonTable where num_num_1 > 11"), + Seq(21474836370L) :: Seq(21474836470L) :: Nil + ) + + checkAnswer( + sql("select num_num_1 - 100 from jsonTable where num_num_1 > 10"), + Seq(-89) :: Seq(21474836370L) :: Seq(21474836470L) :: Nil + ) + + // Widening to DecimalType + checkAnswer( + sql("select num_num_2 + 1.2 from jsonTable where num_num_2 > 1.1"), + Seq(BigDecimal("21474836472.1")) :: Seq(BigDecimal("92233720368547758071.2")) :: Nil + ) + + // Widening to DoubleType + checkAnswer( + sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"), + Seq(101.2) :: Seq(21474836471.2) :: Nil + ) + + // Number and String conflict: resolve the type as number in this query. + checkAnswer( + sql("select num_str + 1.2 from jsonTable where num_str > 14"), + 92233720368547758071.2 + ) + + // String and Boolean conflict: resolve the type as string. 
+ checkAnswer( + sql("select * from jsonTable where str_bool = 'str1'"), + ("true", 11L, null, 1.1, "13.1", "str1") :: Nil + ) + } + + ignore("Type conflict in primitive field values (Ignored)") { + val jsonSchemaRDD = jsonRDD(primitiveFieldValueTypeConflict) + jsonSchemaRDD.registerAsTable("jsonTable") + + // Right now, the analyzer does not promote strings in a boolean expreesion. + // Number and Boolean conflict: resolve the type as boolean in this query. + checkAnswer( + sql("select num_bool from jsonTable where NOT num_bool"), + false + ) + + checkAnswer( + sql("select str_bool from jsonTable where NOT str_bool"), + false + ) + + // Right now, the analyzer does not know that num_bool should be treated as a boolean. + // Number and Boolean conflict: resolve the type as boolean in this query. + checkAnswer( + sql("select num_bool from jsonTable where num_bool"), + true + ) + + checkAnswer( + sql("select str_bool from jsonTable where str_bool"), + false + ) + + // Right now, we have a parsing error. + // Number and String conflict: resolve the type as number in this query. + checkAnswer( + sql("select num_str + 1.2 from jsonTable where num_str > 92233720368547758060"), + BigDecimal("92233720368547758061.2") + ) + + // The plan of the following DSL is + // Project [(CAST(num_str#65:4, DoubleType) + 1.2) AS num#78] + // Filter (CAST(CAST(num_str#65:4, DoubleType), DecimalType) > 92233720368547758060) + // ExistingRdd [num_bool#61,num_num_1#62L,num_num_2#63,num_num_3#64,num_str#65,str_bool#66] + // We should directly cast num_str to DecimalType and also need to do the right type promotion + // in the Project. + checkAnswer( + jsonSchemaRDD. + where('num_str > BigDecimal("92233720368547758060")). + select('num_str + 1.2 as Symbol("num")), + BigDecimal("92233720368547758061.2") + ) + + // The following test will fail. The type of num_str is StringType. + // So, to evaluate num_str + 1.2, we first need to use Cast to convert the type. + // In our test data, one value of num_str is 13.1. + // The result of (CAST(num_str#65:4, DoubleType) + 1.2) for this value is 14.299999999999999, + // which is not 14.3. + // Number and String conflict: resolve the type as number in this query. 
+ checkAnswer( + sql("select num_str + 1.2 from jsonTable where num_str > 13"), + Seq(14.3) :: Seq(92233720368547758071.2) :: Nil + ) + } + + test("Type conflict in complex field values") { + val jsonSchemaRDD = jsonRDD(complexFieldValueTypeConflict) + + val expectedSchema = + AttributeReference("array", ArrayType(IntegerType), true)() :: + AttributeReference("num_struct", StringType, true)() :: + AttributeReference("str_array", StringType, true)() :: + AttributeReference("struct", StructType( + StructField("field", StringType, true) :: Nil), true)() :: + AttributeReference("struct_array", StringType, true)() :: Nil + + comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + + jsonSchemaRDD.registerAsTable("jsonTable") + + checkAnswer( + sql("select * from jsonTable"), + (Seq(), "11", "[1,2,3]", Seq(null), "[]") :: + (null, """{"field":false}""", null, null, "{}") :: + (Seq(4, 5, 6), null, "str", Seq(null), "[7,8,9]") :: + (Seq(7), "{}","[str1,str2,33]", Seq("str"), """{"field":true}""") :: Nil + ) + } + + test("Type conflict in array elements") { + val jsonSchemaRDD = jsonRDD(arrayElementTypeConflict) + + val expectedSchema = + AttributeReference("array", ArrayType(StringType), true)() :: Nil + + comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + + jsonSchemaRDD.registerAsTable("jsonTable") + + checkAnswer( + sql("select * from jsonTable"), + Seq(Seq("1", "1.1", "true", null, "[]", "{}", "[2,3,4]", + """{"field":str}""")) :: Nil + ) + + // Treat an element as a number. + checkAnswer( + sql("select array[0] + 1 from jsonTable"), + 2 + ) + } + + test("Handling missing fields") { + val jsonSchemaRDD = jsonRDD(missingFields) + + val expectedSchema = + AttributeReference("a", BooleanType, true)() :: + AttributeReference("b", LongType, true)() :: + AttributeReference("c", ArrayType(IntegerType), true)() :: + AttributeReference("d", StructType( + StructField("field", BooleanType, true) :: Nil), true)() :: + AttributeReference("e", StringType, true)() :: Nil + + comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + + jsonSchemaRDD.registerAsTable("jsonTable") + } + + test("Loading a JSON dataset from a text file") { + val file = getTempFilePath("json") + val path = file.toString + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + val jsonSchemaRDD = jsonFile(path) + + val expectedSchema = + AttributeReference("bigInteger", DecimalType, true)() :: + AttributeReference("boolean", BooleanType, true)() :: + AttributeReference("double", DoubleType, true)() :: + AttributeReference("integer", IntegerType, true)() :: + AttributeReference("long", LongType, true)() :: + AttributeReference("null", StringType, true)() :: + AttributeReference("string", StringType, true)() :: Nil + + comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + + jsonSchemaRDD.registerAsTable("jsonTable") + + checkAnswer( + sql("select * from jsonTable"), + (BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") :: Nil + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala new file mode 100644 index 0000000000000..065e04046e8a6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.json + +import org.apache.spark.sql.test.TestSQLContext + +object TestJsonData { + + val primitiveFieldAndType = + TestSQLContext.sparkContext.parallelize( + """{"string":"this is a simple string.", + "integer":10, + "long":21474836470, + "bigInteger":92233720368547758070, + "double":1.7976931348623157E308, + "boolean":true, + "null":null + }""" :: Nil) + + val complexFieldAndType = + TestSQLContext.sparkContext.parallelize( + """{"struct":{"field1": true, "field2": 92233720368547758070}, + "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, + "arrayOfString":["str1", "str2"], + "arrayOfInteger":[1, 2147483647, -2147483648], + "arrayOfLong":[21474836470, 9223372036854775807, -9223372036854775808], + "arrayOfBigInteger":[922337203685477580700, -922337203685477580800], + "arrayOfDouble":[1.2, 1.7976931348623157E308, 4.9E-324, 2.2250738585072014E-308], + "arrayOfBoolean":[true, false, true], + "arrayOfNull":[null, null, null, null], + "arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}], + "arrayOfArray1":[[1, 2, 3], ["str1", "str2"]], + "arrayOfArray2":[[1, 2, 3], [1.1, 2.1, 3.1]] + }""" :: Nil) + + val primitiveFieldValueTypeConflict = + TestSQLContext.sparkContext.parallelize( + """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, + "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: + """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, + "num_bool":12, "num_str":null, "str_bool":true}""" :: + """{"num_num_1":21474836470, "num_num_2":92233720368547758070, "num_num_3": 100, + "num_bool":false, "num_str":"str1", "str_bool":false}""" :: + """{"num_num_1":21474836570, "num_num_2":1.1, "num_num_3": 21474836470, + "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil) + + val complexFieldValueTypeConflict = + TestSQLContext.sparkContext.parallelize( + """{"num_struct":11, "str_array":[1, 2, 3], + "array":[], "struct_array":[], "struct": {}}""" :: + """{"num_struct":{"field":false}, "str_array":null, + "array":null, "struct_array":{}, "struct": null}""" :: + """{"num_struct":null, "str_array":"str", + "array":[4, 5, 6], "struct_array":[7, 8, 9], "struct": {"field":null}}""" :: + """{"num_struct":{}, "str_array":["str1", "str2", 33], + "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil) + + val arrayElementTypeConflict = + TestSQLContext.sparkContext.parallelize( + """{"array": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}]}""" :: Nil) + + val missingFields = + TestSQLContext.sparkContext.parallelize( + """{"a":true}""" :: + """{"b":21474836470}""" :: + """{"c":[33, 44]}""" :: + """{"d":{"field":true}}""" :: + """{"e":"str"}""" :: Nil) +} diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 9810520bb9ae6..7714eb1b5628a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -19,26 +19,23 @@ package org.apache.spark.sql.parquet import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} -import org.apache.hadoop.fs.{Path, FileSystem} -import org.apache.hadoop.mapreduce.Job - import parquet.hadoop.ParquetFileWriter import parquet.hadoop.util.ContextUtil import parquet.schema.MessageTypeParser +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.mapreduce.Job +import org.apache.spark.SparkContext import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.sql.catalyst.{SqlLexical, SqlParser} +import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType} +import org.apache.spark.sql.catalyst.util.getTempFilePath import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.TestData -import org.apache.spark.sql.SchemaRDD -import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.catalyst.expressions.Equals -import org.apache.spark.sql.catalyst.types.IntegerType +import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.util.Utils -// Implicits -import org.apache.spark.sql.test.TestSQLContext._ case class TestRDDEntry(key: Int, value: String) @@ -56,15 +53,36 @@ case class OptionalReflectData( doubleField: Option[Double], booleanField: Option[Boolean]) +case class Nested(i: Int, s: String) + +case class Data(array: Seq[Int], nested: Nested) + +case class AllDataTypes( + stringField: String, + intField: Int, + longField: Long, + floatField: Float, + doubleField: Double, + shortField: Short, + byteField: Byte, + booleanField: Boolean) + class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { - import TestData._ TestData // Load test data tables. var testRDD: SchemaRDD = null + // TODO: remove this once SqlParser can parse nested select statements + var nestedParserSqlContext: NestedParserSQLContext = null + override def beforeAll() { + nestedParserSqlContext = new NestedParserSQLContext(TestSQLContext.sparkContext) ParquetTestData.writeFile() ParquetTestData.writeFilterFile() + ParquetTestData.writeNestedFile1() + ParquetTestData.writeNestedFile2() + ParquetTestData.writeNestedFile3() + ParquetTestData.writeNestedFile4() testRDD = parquetFile(ParquetTestData.testDir.toString) testRDD.registerAsTable("testsource") parquetFile(ParquetTestData.testFilterDir.toString) @@ -74,9 +92,33 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA override def afterAll() { Utils.deleteRecursively(ParquetTestData.testDir) Utils.deleteRecursively(ParquetTestData.testFilterDir) + Utils.deleteRecursively(ParquetTestData.testNestedDir1) + Utils.deleteRecursively(ParquetTestData.testNestedDir2) + Utils.deleteRecursively(ParquetTestData.testNestedDir3) + Utils.deleteRecursively(ParquetTestData.testNestedDir4) // here we should also unregister the table?? 
} + test("Read/Write All Types") { + val tempDir = getTempFilePath("parquetTest").getCanonicalPath + val range = (0 to 255) + TestSQLContext.sparkContext.parallelize(range) + .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0)) + .saveAsParquetFile(tempDir) + val result = parquetFile(tempDir).collect() + range.foreach { + i => + assert(result(i).getString(0) == s"$i", s"row $i String field did not match, got ${result(i).getString(0)}") + assert(result(i).getInt(1) === i) + assert(result(i).getLong(2) === i.toLong) + assert(result(i).getFloat(3) === i.toFloat) + assert(result(i).getDouble(4) === i.toDouble) + assert(result(i).getShort(5) === i.toShort) + assert(result(i).getByte(6) === i.toByte) + assert(result(i).getBoolean(7) === (i % 2 == 0)) + } + } + test("self-join parquet files") { val x = ParquetTestData.testData.as('x) val y = ParquetTestData.testData.as('y) @@ -154,7 +196,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA path, TestSQLContext.sparkContext.hadoopConfiguration) assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) - val metaData = ParquetTypesConverter.readMetaData(path) + val metaData = ParquetTypesConverter.readMetaData(path, Some(ContextUtil.getConfiguration(job))) assert(metaData != null) ParquetTestData .testData @@ -197,10 +239,37 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(rdd_copy(i).apply(1) === rdd_orig(i).value, s"value in line $i") } Utils.deleteRecursively(file) - assert(true) } - test("insert (appending) to same table via Scala API") { + test("Insert (overwrite) via Scala API") { + val dirname = Utils.createTempDir() + val source_rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) + .map(i => TestRDDEntry(i, s"val_$i")) + source_rdd.registerAsTable("source") + val dest_rdd = createParquetFile[TestRDDEntry](dirname.toString) + dest_rdd.registerAsTable("dest") + sql("INSERT OVERWRITE INTO dest SELECT * FROM source").collect() + val rdd_copy1 = sql("SELECT * FROM dest").collect() + assert(rdd_copy1.size === 100) + assert(rdd_copy1(0).apply(0) === 1) + assert(rdd_copy1(0).apply(1) === "val_1") + // TODO: why does collecting break things? It seems InsertIntoParquet::execute() is + // executed twice otherwise?! + sql("INSERT INTO dest SELECT * FROM source") + val rdd_copy2 = sql("SELECT * FROM dest").collect() + assert(rdd_copy2.size === 200) + assert(rdd_copy2(0).apply(0) === 1) + assert(rdd_copy2(0).apply(1) === "val_1") + assert(rdd_copy2(99).apply(0) === 100) + assert(rdd_copy2(99).apply(1) === "val_100") + assert(rdd_copy2(100).apply(0) === 1) + assert(rdd_copy2(100).apply(1) === "val_1") + Utils.deleteRecursively(dirname) + } + + test("Insert (appending) to same table via Scala API") { + // TODO: why does collecting break things? It seems InsertIntoParquet::execute() is + // executed twice otherwise?! 
sql("INSERT INTO testsource SELECT * FROM testsource") val double_rdd = sql("SELECT * FROM testsource").collect() assert(double_rdd != null) @@ -245,7 +314,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("create RecordFilter for simple predicates") { val attribute1 = new AttributeReference("first", IntegerType, false)() - val predicate1 = new Equals(attribute1, new Literal(1, IntegerType)) + val predicate1 = new EqualTo(attribute1, new Literal(1, IntegerType)) val filter1 = ParquetFilters.createFilter(predicate1) assert(filter1.isDefined) assert(filter1.get.predicate == predicate1, "predicates do not match") @@ -363,4 +432,272 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val query = sql(s"SELECT mystring FROM testfiltersource WHERE myint < 10") assert(query.collect().size === 10) } + + test("Importing nested Parquet file (Addressbook)") { + val result = TestSQLContext + .parquetFile(ParquetTestData.testNestedDir1.toString) + .toSchemaRDD + .collect() + assert(result != null) + assert(result.size === 2) + val first_record = result(0) + val second_record = result(1) + assert(first_record != null) + assert(second_record != null) + assert(first_record.size === 3) + assert(second_record(1) === null) + assert(second_record(2) === null) + assert(second_record(0) === "A. Nonymous") + assert(first_record(0) === "Julien Le Dem") + val first_owner_numbers = first_record(1) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]] + val first_contacts = first_record(2) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]] + assert(first_owner_numbers != null) + assert(first_owner_numbers(0) === "555 123 4567") + assert(first_owner_numbers(2) === "XXX XXX XXXX") + assert(first_contacts(0) + .asInstanceOf[CatalystConverter.StructScalaType[_]].size === 2) + val first_contacts_entry_one = first_contacts(0) + .asInstanceOf[CatalystConverter.StructScalaType[_]] + assert(first_contacts_entry_one(0) === "Dmitriy Ryaboy") + assert(first_contacts_entry_one(1) === "555 987 6543") + val first_contacts_entry_two = first_contacts(1) + .asInstanceOf[CatalystConverter.StructScalaType[_]] + assert(first_contacts_entry_two(0) === "Chris Aniszczyk") + } + + test("Importing nested Parquet file (nested numbers)") { + val result = TestSQLContext + .parquetFile(ParquetTestData.testNestedDir2.toString) + .toSchemaRDD + .collect() + assert(result.size === 1, "number of top-level rows incorrect") + assert(result(0).size === 5, "number of fields in row incorrect") + assert(result(0)(0) === 1) + assert(result(0)(1) === 7) + val subresult1 = result(0)(2).asInstanceOf[CatalystConverter.ArrayScalaType[_]] + assert(subresult1.size === 3) + assert(subresult1(0) === (1.toLong << 32)) + assert(subresult1(1) === (1.toLong << 33)) + assert(subresult1(2) === (1.toLong << 34)) + val subresult2 = result(0)(3) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) + .asInstanceOf[CatalystConverter.StructScalaType[_]] + assert(subresult2.size === 2) + assert(subresult2(0) === 2.5) + assert(subresult2(1) === false) + val subresult3 = result(0)(4) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]] + assert(subresult3.size === 2) + assert(subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size === 2) + val subresult4 = subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]] + assert(subresult4(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 7) + assert(subresult4(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 8) + 
assert(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size === 1) + assert(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 9) + } + + test("Simple query on addressbook") { + val data = TestSQLContext + .parquetFile(ParquetTestData.testNestedDir1.toString) + .toSchemaRDD + val tmp = data.where('owner === "Julien Le Dem").select('owner as 'a, 'contacts as 'c).collect() + assert(tmp.size === 1) + assert(tmp(0)(0) === "Julien Le Dem") + } + + test("Projection in addressbook") { + val data = nestedParserSqlContext + .parquetFile(ParquetTestData.testNestedDir1.toString) + .toSchemaRDD + data.registerAsTable("data") + val query = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM data") + val tmp = query.collect() + assert(tmp.size === 2) + assert(tmp(0).size === 2) + assert(tmp(0)(0) === "Julien Le Dem") + assert(tmp(0)(1) === "Chris Aniszczyk") + assert(tmp(1)(0) === "A. Nonymous") + assert(tmp(1)(1) === null) + } + + test("Simple query on nested int data") { + val data = nestedParserSqlContext + .parquetFile(ParquetTestData.testNestedDir2.toString) + .toSchemaRDD + data.registerAsTable("data") + val result1 = nestedParserSqlContext.sql("SELECT entries[0].value FROM data").collect() + assert(result1.size === 1) + assert(result1(0).size === 1) + assert(result1(0)(0) === 2.5) + val result2 = nestedParserSqlContext.sql("SELECT entries[0] FROM data").collect() + assert(result2.size === 1) + val subresult1 = result2(0)(0).asInstanceOf[CatalystConverter.StructScalaType[_]] + assert(subresult1.size === 2) + assert(subresult1(0) === 2.5) + assert(subresult1(1) === false) + val result3 = nestedParserSqlContext.sql("SELECT outerouter FROM data").collect() + val subresult2 = result3(0)(0) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]] + assert(subresult2(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 7) + assert(subresult2(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 8) + assert(result3(0)(0) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](1) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 9) + } + + test("nested structs") { + val data = nestedParserSqlContext + .parquetFile(ParquetTestData.testNestedDir3.toString) + .toSchemaRDD + data.registerAsTable("data") + val result1 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect() + assert(result1.size === 1) + assert(result1(0).size === 1) + assert(result1(0)(0) === false) + val result2 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[0].value[1].truth FROM data").collect() + assert(result2.size === 1) + assert(result2(0).size === 1) + assert(result2(0)(0) === true) + val result3 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[1].value[0].truth FROM data").collect() + assert(result3.size === 1) + assert(result3(0).size === 1) + assert(result3(0)(0) === false) + } + + test("simple map") { + val data = TestSQLContext + .parquetFile(ParquetTestData.testNestedDir4.toString) + .toSchemaRDD + data.registerAsTable("mapTable") + val result1 = sql("SELECT data1 FROM mapTable").collect() + assert(result1.size === 1) + assert(result1(0)(0) + .asInstanceOf[CatalystConverter.MapScalaType[String, _]] + .getOrElse("key1", 0) === 1) + assert(result1(0)(0) + .asInstanceOf[CatalystConverter.MapScalaType[String, _]] + .getOrElse("key2", 0) === 2) + val 
result2 = sql("""SELECT data1["key1"] FROM mapTable""").collect() + assert(result2(0)(0) === 1) + } + + test("map with struct values") { + val data = nestedParserSqlContext + .parquetFile(ParquetTestData.testNestedDir4.toString) + .toSchemaRDD + data.registerAsTable("mapTable") + val result1 = nestedParserSqlContext.sql("SELECT data2 FROM mapTable").collect() + assert(result1.size === 1) + val entry1 = result1(0)(0) + .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] + .getOrElse("seven", null) + assert(entry1 != null) + assert(entry1(0) === 42) + assert(entry1(1) === "the answer") + val entry2 = result1(0)(0) + .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] + .getOrElse("eight", null) + assert(entry2 != null) + assert(entry2(0) === 49) + assert(entry2(1) === null) + val result2 = nestedParserSqlContext.sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM mapTable""").collect() + assert(result2.size === 1) + assert(result2(0)(0) === 42.toLong) + assert(result2(0)(1) === "the answer") + } + + test("Writing out Addressbook and reading it back in") { + // TODO: find out why CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME + // has no effect in this test case + val tmpdir = Utils.createTempDir() + Utils.deleteRecursively(tmpdir) + val result = nestedParserSqlContext + .parquetFile(ParquetTestData.testNestedDir1.toString) + .toSchemaRDD + result.saveAsParquetFile(tmpdir.toString) + nestedParserSqlContext + .parquetFile(tmpdir.toString) + .toSchemaRDD + .registerAsTable("tmpcopy") + val tmpdata = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM tmpcopy").collect() + assert(tmpdata.size === 2) + assert(tmpdata(0).size === 2) + assert(tmpdata(0)(0) === "Julien Le Dem") + assert(tmpdata(0)(1) === "Chris Aniszczyk") + assert(tmpdata(1)(0) === "A. 
Nonymous") + assert(tmpdata(1)(1) === null) + Utils.deleteRecursively(tmpdir) + } + + test("Writing out Map and reading it back in") { + val data = nestedParserSqlContext + .parquetFile(ParquetTestData.testNestedDir4.toString) + .toSchemaRDD + val tmpdir = Utils.createTempDir() + Utils.deleteRecursively(tmpdir) + data.saveAsParquetFile(tmpdir.toString) + nestedParserSqlContext + .parquetFile(tmpdir.toString) + .toSchemaRDD + .registerAsTable("tmpmapcopy") + val result1 = nestedParserSqlContext.sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect() + assert(result1.size === 1) + assert(result1(0)(0) === 2) + val result2 = nestedParserSqlContext.sql("SELECT data2 FROM tmpmapcopy").collect() + assert(result2.size === 1) + val entry1 = result2(0)(0) + .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] + .getOrElse("seven", null) + assert(entry1 != null) + assert(entry1(0) === 42) + assert(entry1(1) === "the answer") + val entry2 = result2(0)(0) + .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] + .getOrElse("eight", null) + assert(entry2 != null) + assert(entry2(0) === 49) + assert(entry2(1) === null) + val result3 = nestedParserSqlContext.sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM tmpmapcopy""").collect() + assert(result3.size === 1) + assert(result3(0)(0) === 42.toLong) + assert(result3(0)(1) === "the answer") + Utils.deleteRecursively(tmpdir) + } +} + +// TODO: the code below is needed temporarily until the standard parser is able to parse +// nested field expressions correctly +class NestedParserSQLContext(@transient override val sparkContext: SparkContext) extends SQLContext(sparkContext) { + override protected[sql] val parser = new NestedSqlParser() +} + +class NestedSqlLexical(override val keywords: Seq[String]) extends SqlLexical(keywords) { + override def identChar = letter | elem('_') + delimiters += (".") +} + +class NestedSqlParser extends SqlParser { + override val lexical = new NestedSqlLexical(reservedWords) + + override protected lazy val baseExpression: PackratParser[Expression] = + expression ~ "[" ~ expression <~ "]" ^^ { + case base ~ _ ~ ordinal => GetItem(base, ordinal) + } | + expression ~ "." ~ ident ^^ { + case base ~ _ ~ fieldName => GetField(base, fieldName) + } | + TRUE ^^^ Literal(true, BooleanType) | + FALSE ^^^ Literal(false, BooleanType) | + cast | + "(" ~> expression <~ ")" | + function | + "-" ~> literal ^^ UnaryMinus | + ident ^^ UnresolvedAttribute | + "*" ^^^ Star(None) | + literal } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 96e0ec5136331..7695242a81601 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.{Command => PhysicalCommand} +import org.apache.spark.sql.hive.execution.DescribeHiveTableCommand /** * Starts up an instance of hive where metadata is stored locally. 
An in-process metadata data is @@ -250,7 +251,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, - ShortType, DecimalType) + ShortType, DecimalType, TimestampType) protected def toHiveString(a: (Any, DataType)): String = a match { case (struct: Row, StructType(fields)) => @@ -291,6 +292,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * execution is simply passed back to Hive. */ def stringResult(): Seq[String] = executedPlan match { + case describeHiveTableCommand: DescribeHiveTableCommand => + // If it is a describe command for a Hive table, we want to have the output format + // be similar with Hive. + describeHiveTableCommand.hiveString case command: PhysicalCommand => command.sideEffectResult.map(_.toString) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 68284344afd55..f923d68932f83 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -208,7 +208,9 @@ object HiveMetastoreTypes extends RegexParsers { } protected lazy val structType: Parser[DataType] = - "struct" ~> "<" ~> repsep(structField,",") <~ ">" ^^ StructType + "struct" ~> "<" ~> repsep(structField,",") <~ ">" ^^ { + case fields => new StructType(fields) + } protected lazy val dataType: Parser[DataType] = arrayType | diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index b745d8ffd8f17..c69e3dba6b467 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -52,7 +52,6 @@ private[hive] case class AddFile(filePath: String) extends Command private[hive] object HiveQl { protected val nativeCommands = Seq( "TOK_DESCFUNCTION", - "TOK_DESCTABLE", "TOK_DESCDATABASE", "TOK_SHOW_TABLESTATUS", "TOK_SHOWDATABASES", @@ -120,6 +119,12 @@ private[hive] object HiveQl { "TOK_SWITCHDATABASE" ) + // Commands that we do not need to explain. + protected val noExplainCommands = Seq( + "TOK_CREATETABLE", + "TOK_DESCTABLE" + ) ++ nativeCommands + /** * A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations * similar to [[catalyst.trees.TreeNode]]. @@ -199,6 +204,9 @@ private[hive] object HiveQl { class ParseException(sql: String, cause: Throwable) extends Exception(s"Failed to parse: $sql", cause) + class SemanticException(msg: String) + extends Exception(s"Error in semantic analysis: $msg") + /** * Returns the AST for the given SQL string. */ @@ -362,13 +370,20 @@ private[hive] object HiveQl { } } + protected def extractDbNameTableName(tableNameParts: Node): (Option[String], String) = { + val (db, tableName) = + tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match { + case Seq(tableOnly) => (None, tableOnly) + case Seq(databaseName, table) => (Some(databaseName), table) + } + + (db, tableName) + } + protected def nodeToPlan(node: Node): LogicalPlan = node match { // Just fake explain for any of the native commands. 
- case Token("TOK_EXPLAIN", explainArgs) if nativeCommands contains explainArgs.head.getText => - ExplainCommand(NoRelation) - // Create tables aren't native commands due to CTAS queries, but we still don't need to - // explain them. - case Token("TOK_EXPLAIN", explainArgs) if explainArgs.head.getText == "TOK_CREATETABLE" => + case Token("TOK_EXPLAIN", explainArgs) + if noExplainCommands.contains(explainArgs.head.getText) => ExplainCommand(NoRelation) case Token("TOK_EXPLAIN", explainArgs) => // Ignore FORMATTED if present. @@ -377,6 +392,39 @@ private[hive] object HiveQl { // TODO: support EXTENDED? ExplainCommand(nodeToPlan(query)) + case Token("TOK_DESCTABLE", describeArgs) => + // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL + val Some(tableType) :: formatted :: extended :: pretty :: Nil = + getClauses(Seq("TOK_TABTYPE", "FORMATTED", "EXTENDED", "PRETTY"), describeArgs) + if (formatted.isDefined || pretty.isDefined) { + // FORMATTED and PRETTY are not supported and this statement will be treated as + // a Hive native command. + NativePlaceholder + } else { + tableType match { + case Token("TOK_TABTYPE", nameParts) if nameParts.size == 1 => { + nameParts.head match { + case Token(".", dbName :: tableName :: Nil) => + // It is describing a table with the format like "describe db.table". + // TODO: Actually, a user may mean tableName.columnName. Need to resolve this issue. + val (db, tableName) = extractDbNameTableName(nameParts.head) + DescribeCommand( + UnresolvedRelation(db, tableName, None), extended.isDefined) + case Token(".", dbName :: tableName :: colName :: Nil) => + // It is describing a column with the format like "describe db.table column". + NativePlaceholder + case tableName => + // It is describing a table with the format like "describe table". + DescribeCommand( + UnresolvedRelation(None, tableName.getText, None), + extended.isDefined) + } + } + // All other cases. + case _ => NativePlaceholder + } + } + case Token("TOK_CREATETABLE", children) if children.collect { case t@Token("TOK_QUERY", _) => t }.nonEmpty => // TODO: Parse other clauses. @@ -414,11 +462,8 @@ private[hive] object HiveQl { s"Unhandled clauses: ${notImplemented.flatten.map(dumpTree(_)).mkString("\n")}") } - val (db, tableName) = - tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match { - case Seq(tableOnly) => (None, tableOnly) - case Seq(databaseName, table) => (Some(databaseName), table) - } + val (db, tableName) = extractDbNameTableName(tableNameParts) + InsertIntoCreatedTable(db, tableName, nodeToPlan(query)) // If its not a "CREATE TABLE AS" like above then just pass it back to hive as a native command. 
@@ -438,6 +483,7 @@ private[hive] object HiveQl { whereClause :: groupByClause :: orderByClause :: + havingClause :: sortByClause :: clusterByClause :: distributeByClause :: @@ -452,6 +498,7 @@ private[hive] object HiveQl { "TOK_WHERE", "TOK_GROUPBY", "TOK_ORDERBY", + "TOK_HAVING", "TOK_SORTBY", "TOK_CLUSTERBY", "TOK_DISTRIBUTEBY", @@ -534,21 +581,34 @@ private[hive] object HiveQl { val withDistinct = if (selectDistinctClause.isDefined) Distinct(withProject) else withProject + val withHaving = havingClause.map { h => + + if (groupByClause == None) { + throw new SemanticException("HAVING specified without GROUP BY") + } + + val havingExpr = h.getChildren.toSeq match { + case Seq(hexpr) => nodeToExpr(hexpr) + } + + Filter(Cast(havingExpr, BooleanType), withDistinct) + }.getOrElse(withDistinct) + val withSort = (orderByClause, sortByClause, distributeByClause, clusterByClause) match { case (Some(totalOrdering), None, None, None) => - Sort(totalOrdering.getChildren.map(nodeToSortOrder), withDistinct) + Sort(totalOrdering.getChildren.map(nodeToSortOrder), withHaving) case (None, Some(perPartitionOrdering), None, None) => - SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), withDistinct) + SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), withHaving) case (None, None, Some(partitionExprs), None) => - Repartition(partitionExprs.getChildren.map(nodeToExpr), withDistinct) + Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving) case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), - Repartition(partitionExprs.getChildren.map(nodeToExpr), withDistinct)) + Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving)) case (None, None, None, Some(clusterExprs)) => SortPartitions(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)), - Repartition(clusterExprs.getChildren.map(nodeToExpr), withDistinct)) - case (None, None, None, None) => withDistinct + Repartition(clusterExprs.getChildren.map(nodeToExpr), withHaving)) + case (None, None, None, None) => withHaving case _ => sys.error("Unsupported set of ordering / distribution clauses.") } @@ -656,7 +716,7 @@ private[hive] object HiveQl { val joinConditions = joinExpressions.sliding(2).map { case Seq(c1, c2) => - val predicates = (c1, c2).zipped.map { case (e1, e2) => Equals(e1, e2): Expression } + val predicates = (c1, c2).zipped.map { case (e1, e2) => EqualTo(e1, e2): Expression } predicates.reduceLeft(And) }.toBuffer @@ -736,11 +796,7 @@ private[hive] object HiveQl { val Some(tableNameParts) :: partitionClause :: Nil = getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) - val (db, tableName) = - tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match { - case Seq(tableOnly) => (None, tableOnly) - case Seq(databaseName, table) => (Some(databaseName), table) - } + val (db, tableName) = extractDbNameTableName(tableNameParts) val partitionKeys = partitionClause.map(_.getChildren.map { // Parse partitions. We also make keys case insensitive. 
@@ -811,6 +867,8 @@ private[hive] object HiveQl { val IN = "(?i)IN".r val DIV = "(?i)DIV".r val BETWEEN = "(?i)BETWEEN".r + val WHEN = "(?i)WHEN".r + val CASE = "(?i)CASE".r protected def nodeToExpr(node: Node): Expression = node match { /* Attribute References */ @@ -884,9 +942,9 @@ private[hive] object HiveQl { case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right)) /* Comparisons */ - case Token("=", left :: right:: Nil) => Equals(nodeToExpr(left), nodeToExpr(right)) - case Token("!=", left :: right:: Nil) => Not(Equals(nodeToExpr(left), nodeToExpr(right))) - case Token("<>", left :: right:: Nil) => Not(Equals(nodeToExpr(left), nodeToExpr(right))) + case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) + case Token("!=", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) + case Token("<>", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right)) case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right)) case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right)) @@ -917,6 +975,21 @@ private[hive] object HiveQl { case Token(OR(), left :: right:: Nil) => Or(nodeToExpr(left), nodeToExpr(right)) case Token(NOT(), child :: Nil) => Not(nodeToExpr(child)) + /* Case statements */ + case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) => + CaseWhen(branches.map(nodeToExpr)) + case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) => + val transformed = branches.drop(1).sliding(2, 2).map { + case Seq(condVal, value) => + // FIXME (SPARK-2155): the key will get evaluated for multiple times in CaseWhen's eval(). + // Hence effectful / non-deterministic key expressions are *not* supported at the moment. + // We should consider adding new Expressions to get around this. 
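+ // Each (condVal, value) pair is rewritten to (EqualTo(key, condVal), value), so the keyed + // CASE form is translated into the same CaseWhen expression as the key-less form above.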
+ Seq(EqualTo(nodeToExpr(branches(0)), nodeToExpr(condVal)), + nodeToExpr(value)) + case Seq(elseVal) => Seq(nodeToExpr(elseVal)) + }.toSeq.reduce(_ ++ _) + CaseWhen(transformed) + /* Complex datatype manipulation */ case Token("[", child :: ordinal :: Nil) => GetItem(nodeToExpr(child), nodeToExpr(ordinal)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 0ac0ee9071f36..af7687b40429b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -81,6 +81,16 @@ private[hive] trait HiveStrategies { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.NativeCommand(sql) => NativeCommand(sql, plan.output)(context) :: Nil + case describe: logical.DescribeCommand => { + val resolvedTable = context.executePlan(describe.table).analyzed + resolvedTable match { + case t: MetastoreRelation => + Seq(DescribeHiveTableCommand( + t, describe.output, describe.isExtended)(context)) + case o: LogicalPlan => + Seq(DescribeCommand(planLater(o), describe.output)(context)) + } + } case _ => Nil } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala index 6df76fa825101..c9ee162191c96 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala @@ -31,12 +31,6 @@ class JavaHiveContext(sparkContext: JavaSparkContext) extends JavaSQLContext(spa /** * Executes a query expressed in HiveQL, returning the result as a JavaSchemaRDD. */ - def hql(hqlQuery: String): JavaSchemaRDD = { - val result = new JavaSchemaRDD(sqlContext, HiveQl.parseSql(hqlQuery)) - // We force query optimization to happen right away instead of letting it happen lazily like - // when using the query DSL. This is so DDL commands behave as expected. This is only - // generates the RDD lineage for DML queries, but do not perform any execution. 
- result.queryExecution.toRdd - result - } + def hql(hqlQuery: String): JavaSchemaRDD = + new JavaSchemaRDD(sqlContext, HiveQl.parseSql(hqlQuery)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala index a839231449161..2de2db28a7e04 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.hive.execution import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.MetaStoreUtils +import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.ql.Context import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Hive} +import org.apache.hadoop.hive.ql.metadata.formatting.MetaDataFormatUtils import org.apache.hadoop.hive.ql.plan.{TableDesc, FileSinkDesc} import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption @@ -344,12 +346,16 @@ case class InsertIntoHiveTable( writer.commitJob() } + override def execute() = result + /** * Inserts all the rows in the table into Hive. Row objects are properly serialized with the * `org.apache.hadoop.hive.serde2.SerDe` and the * `org.apache.hadoop.mapred.OutputFormat` provided by the table definition. + * + * Note: this is run once and then kept to avoid double insertions. */ - def execute() = { + private lazy val result: RDD[Row] = { val childRdd = child.execute() assert(childRdd != null) @@ -367,12 +373,18 @@ case class InsertIntoHiveTable( ObjectInspectorCopyOption.JAVA) .asInstanceOf[StructObjectInspector] + + val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray + val outputData = new Array[Any](fieldOIs.length) iter.map { row => - // Casts Strings to HiveVarchars when necessary. - val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector) - val mappedRow = row.zip(fieldOIs).map(wrap) + var i = 0 + while (i < row.length) { + // Casts Strings to HiveVarchars when necessary. + outputData(i) = wrap(row(i), fieldOIs(i)) + i += 1 + } - serializer.serialize(mappedRow.toArray, standardOI) + serializer.serialize(outputData, standardOI) } } @@ -452,3 +464,61 @@ case class NativeCommand( override def otherCopyArgs = context :: Nil } + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class DescribeHiveTableCommand( + table: MetastoreRelation, + output: Seq[Attribute], + isExtended: Boolean)( + @transient context: HiveContext) + extends LeafNode with Command { + + // Strings with the format like Hive. It is used for result comparison in our unit tests. + lazy val hiveString: Seq[String] = { + val alignment = 20 + val delim = "\t" + + sideEffectResult.map { + case (name, dataType, comment) => + String.format("%-" + alignment + "s", name) + delim + + String.format("%-" + alignment + "s", dataType) + delim + + String.format("%-" + alignment + "s", Option(comment).getOrElse("None")) + } + } + + override protected[sql] lazy val sideEffectResult: Seq[(String, String, String)] = { + // Trying to mimic the format of Hive's output. But not exactly the same. 
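+ // Each result row is a (column name, data type, comment) triple.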
+ var results: Seq[(String, String, String)] = Nil + + val columns: Seq[FieldSchema] = table.hiveQlTable.getCols + val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols + results ++= columns.map(field => (field.getName, field.getType, field.getComment)) + if (!partitionColumns.isEmpty) { + val partColumnInfo = + partitionColumns.map(field => (field.getName, field.getType, field.getComment)) + results ++= + partColumnInfo ++ + Seq(("# Partition Information", "", "")) ++ + Seq((s"# ${output.get(0).name}", output.get(1).name, output.get(2).name)) ++ + partColumnInfo + } + + if (isExtended) { + results ++= Seq(("Detailed Table Information", table.hiveQlTable.getTTable.toString, "")) + } + + results + } + + override def execute(): RDD[Row] = { + val rows = sideEffectResult.map { + case (name, dataType, comment) => new GenericRow(Array[Any](name, dataType, comment)) + } + context.sparkContext.parallelize(rows, 1) + } + + override def otherCopyArgs = context :: Nil +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 572902042337f..ad5e24c62c621 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -187,7 +187,8 @@ private[hive] case class HiveSimpleUdf(name: String, children: Seq[Expression]) val primitiveClasses = Seq( Integer.TYPE, classOf[java.lang.Integer], classOf[java.lang.String], java.lang.Double.TYPE, classOf[java.lang.Double], java.lang.Long.TYPE, classOf[java.lang.Long], - classOf[HiveDecimal], java.lang.Byte.TYPE, classOf[java.lang.Byte] + classOf[HiveDecimal], java.lang.Byte.TYPE, classOf[java.lang.Byte], + classOf[java.sql.Timestamp] ) val matchingConstructor = argClass.getConstructors.find { c => c.getParameterTypes.size == 1 && primitiveClasses.contains(c.getParameterTypes.head) @@ -334,6 +335,9 @@ private[hive] trait HiveInspectors { case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector + case StructType(fields) => + ObjectInspectorFactory.getStandardStructObjectInspector( + fields.map(f => f.name), fields.map(f => toInspector(f.dataType))) } def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match { diff --git a/sql/hive/src/test/resources/data/files/testUdf/part-00000 b/sql/hive/src/test/resources/data/files/testUdf/part-00000 new file mode 100755 index 0000000000000..240a5c1a63c5c Binary files /dev/null and b/sql/hive/src/test/resources/data/files/testUdf/part-00000 differ diff --git a/sql/hive/src/test/resources/golden/case statements WITHOUT key #1-0-36750f0f6727c287c471309689ff7563 b/sql/hive/src/test/resources/golden/case statements WITHOUT key #1-0-36750f0f6727c287c471309689ff7563 new file mode 100644 index 0000000000000..816fe57d162dc --- /dev/null +++ b/sql/hive/src/test/resources/golden/case statements WITHOUT key #1-0-36750f0f6727c287c471309689ff7563 @@ -0,0 +1,14 @@ +NULL +3 +3 +3 +NULL +NULL +3 +3 +3 +3 +NULL +3 +3 +3 diff --git a/sql/hive/src/test/resources/golden/case statements WITHOUT key #2-0-e3a2b981ebff7e273537dd6c43ece0c0 b/sql/hive/src/test/resources/golden/case statements WITHOUT key #2-0-e3a2b981ebff7e273537dd6c43ece0c0 new file mode 100644 index 0000000000000..4cca081e6e294 --- /dev/null +++ 
b/sql/hive/src/test/resources/golden/case statements WITHOUT key #2-0-e3a2b981ebff7e273537dd6c43ece0c0 @@ -0,0 +1,14 @@ +4 +3 +3 +3 +4 +4 +3 +3 +3 +3 +4 +3 +3 +3 diff --git a/sql/hive/src/test/resources/golden/case statements WITHOUT key #3-0-be5efc0574a97ec465e2686f4a724bd5 b/sql/hive/src/test/resources/golden/case statements WITHOUT key #3-0-be5efc0574a97ec465e2686f4a724bd5 new file mode 100644 index 0000000000000..8d0416a8f8d9c --- /dev/null +++ b/sql/hive/src/test/resources/golden/case statements WITHOUT key #3-0-be5efc0574a97ec465e2686f4a724bd5 @@ -0,0 +1,14 @@ +2 +3 +3 +3 +2 +2 +3 +3 +3 +3 +NULL +3 +3 +3 diff --git a/sql/hive/src/test/resources/golden/case statements WITHOUT key #4-0-631f824a91b7230657bea7a05e393a1e b/sql/hive/src/test/resources/golden/case statements WITHOUT key #4-0-631f824a91b7230657bea7a05e393a1e new file mode 100644 index 0000000000000..6ed452bcd870d --- /dev/null +++ b/sql/hive/src/test/resources/golden/case statements WITHOUT key #4-0-631f824a91b7230657bea7a05e393a1e @@ -0,0 +1,14 @@ +2 +3 +3 +3 +2 +2 +3 +3 +3 +3 +0 +3 +3 +3 diff --git a/sql/hive/src/test/resources/golden/case statements with key #1-0-616830b2011da0990e87a188fb609299 b/sql/hive/src/test/resources/golden/case statements with key #1-0-616830b2011da0990e87a188fb609299 new file mode 100644 index 0000000000000..3f5a2fbbe99fd --- /dev/null +++ b/sql/hive/src/test/resources/golden/case statements with key #1-0-616830b2011da0990e87a188fb609299 @@ -0,0 +1,14 @@ +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL diff --git a/sql/hive/src/test/resources/golden/case statements with key #2-0-6c5b5a997949f9e5ab9676b60e95657b b/sql/hive/src/test/resources/golden/case statements with key #2-0-6c5b5a997949f9e5ab9676b60e95657b new file mode 100644 index 0000000000000..e1ca6e76d1f8f --- /dev/null +++ b/sql/hive/src/test/resources/golden/case statements with key #2-0-6c5b5a997949f9e5ab9676b60e95657b @@ -0,0 +1,14 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +3 +0 +0 +0 diff --git a/sql/hive/src/test/resources/golden/case statements with key #3-0-a241862582c47d9e98be95339d35c7c4 b/sql/hive/src/test/resources/golden/case statements with key #3-0-a241862582c47d9e98be95339d35c7c4 new file mode 100644 index 0000000000000..896207fdbcf3d --- /dev/null +++ b/sql/hive/src/test/resources/golden/case statements with key #3-0-a241862582c47d9e98be95339d35c7c4 @@ -0,0 +1,14 @@ +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +3 +NULL +NULL +NULL diff --git a/sql/hive/src/test/resources/golden/case statements with key #4-0-ea87ca38ead8858d2337792dcd430226 b/sql/hive/src/test/resources/golden/case statements with key #4-0-ea87ca38ead8858d2337792dcd430226 new file mode 100644 index 0000000000000..e1ca6e76d1f8f --- /dev/null +++ b/sql/hive/src/test/resources/golden/case statements with key #4-0-ea87ca38ead8858d2337792dcd430226 @@ -0,0 +1,14 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +3 +0 +0 +0 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-10-dbc23736a61d9482d13cacada02a7a09 b/sql/hive/src/test/resources/golden/timestamp_udf-10-dbc23736a61d9482d13cacada02a7a09 new file mode 100644 index 0000000000000..b3c4eec4c2209 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-10-dbc23736a61d9482d13cacada02a7a09 @@ -0,0 +1 @@ +2011-05-06 07:08:09.1234567 2011-05-06 02:08:09.1234567 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-11-442cf850a0cc1f1dcfdeaeffbffb2c35 b/sql/hive/src/test/resources/golden/timestamp_udf-11-442cf850a0cc1f1dcfdeaeffbffb2c35 new file mode 100644 
index 0000000000000..f69f13ed1fb94 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-11-442cf850a0cc1f1dcfdeaeffbffb2c35 @@ -0,0 +1 @@ +2011-05-06 07:08:09.1234567 2011-05-06 02:08:09.1234567 2011-05-06 07:08:09.1234567 2011-05-06 02:08:09.1234567 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-12-51959036fd4ac4f1e24f4e06eb9b0b6 b/sql/hive/src/test/resources/golden/timestamp_udf-12-51959036fd4ac4f1e24f4e06eb9b0b6 new file mode 100644 index 0000000000000..f14f17e692822 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-12-51959036fd4ac4f1e24f4e06eb9b0b6 @@ -0,0 +1 @@ +2011-05-06 07:08:09.1234567 2011-05-06 12:08:09.1234567 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-13-6ab3f356deaf807e8accc37e1f4849a b/sql/hive/src/test/resources/golden/timestamp_udf-13-6ab3f356deaf807e8accc37e1f4849a new file mode 100644 index 0000000000000..7881bff731be1 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-13-6ab3f356deaf807e8accc37e1f4849a @@ -0,0 +1 @@ +2011-05-06 07:08:09.1234567 2011-05-06 12:08:09.1234567 2011-05-06 07:08:09.1234567 2011-05-06 12:08:09.1234567 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-14-c745a1016461403526d44928a269c1de b/sql/hive/src/test/resources/golden/timestamp_udf-14-c745a1016461403526d44928a269c1de new file mode 100644 index 0000000000000..2c5e9e9656202 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-14-c745a1016461403526d44928a269c1de @@ -0,0 +1 @@ +1304690889 2011 5 6 6 18 7 8 9 2011-05-06 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-15-7ab76c4458c7f78038c8b1df0fdeafbe b/sql/hive/src/test/resources/golden/timestamp_udf-15-7ab76c4458c7f78038c8b1df0fdeafbe new file mode 100644 index 0000000000000..19497254f8f7e --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-15-7ab76c4458c7f78038c8b1df0fdeafbe @@ -0,0 +1 @@ +2011-05-11 2011-04-26 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-16-b36e87e17ca24d82072220bff559c718 b/sql/hive/src/test/resources/golden/timestamp_udf-16-b36e87e17ca24d82072220bff559c718 new file mode 100644 index 0000000000000..816f56e43eaba --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-16-b36e87e17ca24d82072220bff559c718 @@ -0,0 +1 @@ +0 3333 -3333 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-17-dad44d2d4a421286e9da080271bd2639 b/sql/hive/src/test/resources/golden/timestamp_udf-17-dad44d2d4a421286e9da080271bd2639 new file mode 100644 index 0000000000000..a4182d1e39db9 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-17-dad44d2d4a421286e9da080271bd2639 @@ -0,0 +1 @@ +2011-05-06 02:08:09.1234567 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-18-cb033ecad964a2623bc633ac1d3f752a b/sql/hive/src/test/resources/golden/timestamp_udf-18-cb033ecad964a2623bc633ac1d3f752a new file mode 100644 index 0000000000000..02ccd3a2e97ce --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-18-cb033ecad964a2623bc633ac1d3f752a @@ -0,0 +1 @@ +2011-05-06 12:08:09.1234567 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-19-79914c5347620c6e62a8e0b9a95984af b/sql/hive/src/test/resources/golden/timestamp_udf-19-79914c5347620c6e62a8e0b9a95984af new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-20-59fc1842a23369235d42ed040d45fb3d b/sql/hive/src/test/resources/golden/timestamp_udf-20-59fc1842a23369235d42ed040d45fb3d new file mode 100644 index 
0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-4-80ce02ec84ee8abcb046367ca37279cc b/sql/hive/src/test/resources/golden/timestamp_udf-4-80ce02ec84ee8abcb046367ca37279cc new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-5-1124399033bcadf3874fb48f593392d b/sql/hive/src/test/resources/golden/timestamp_udf-5-1124399033bcadf3874fb48f593392d new file mode 100644 index 0000000000000..2c5e9e9656202 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-5-1124399033bcadf3874fb48f593392d @@ -0,0 +1 @@ +1304690889 2011 5 6 6 18 7 8 9 2011-05-06 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-6-5810193ce35d38c23f4fc4b4979d60a4 b/sql/hive/src/test/resources/golden/timestamp_udf-6-5810193ce35d38c23f4fc4b4979d60a4 new file mode 100644 index 0000000000000..19497254f8f7e --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-6-5810193ce35d38c23f4fc4b4979d60a4 @@ -0,0 +1 @@ +2011-05-11 2011-04-26 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-7-250e640a6a818f989f3f3280b00f64f9 b/sql/hive/src/test/resources/golden/timestamp_udf-7-250e640a6a818f989f3f3280b00f64f9 new file mode 100644 index 0000000000000..816f56e43eaba --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-7-250e640a6a818f989f3f3280b00f64f9 @@ -0,0 +1 @@ +0 3333 -3333 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-8-975df43df015d86422965af456f87a94 b/sql/hive/src/test/resources/golden/timestamp_udf-8-975df43df015d86422965af456f87a94 new file mode 100644 index 0000000000000..a4182d1e39db9 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-8-975df43df015d86422965af456f87a94 @@ -0,0 +1 @@ +2011-05-06 02:08:09.1234567 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-9-287614364eaa3fb82aad08c6b62cc938 b/sql/hive/src/test/resources/golden/timestamp_udf-9-287614364eaa3fb82aad08c6b62cc938 new file mode 100644 index 0000000000000..02ccd3a2e97ce --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-9-287614364eaa3fb82aad08c6b62cc938 @@ -0,0 +1 @@ +2011-05-06 12:08:09.1234567 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala new file mode 100644 index 0000000000000..10c8069a624e6 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.api.java + +import scala.util.Try + +import org.scalatest.FunSuite + +import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.sql.api.java.JavaSchemaRDD +import org.apache.spark.sql.execution.ExplainCommand +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.TestSQLContext + +// Implicits +import scala.collection.JavaConversions._ + +class JavaHiveQLSuite extends FunSuite { + lazy val javaCtx = new JavaSparkContext(TestSQLContext.sparkContext) + + // There is a little trickery here to avoid instantiating two HiveContexts in the same JVM + lazy val javaHiveCtx = new JavaHiveContext(javaCtx) { + override val sqlContext = TestHive + } + + ignore("SELECT * FROM src") { + assert( + javaHiveCtx.hql("SELECT * FROM src").collect().map(_.getInt(0)) === + TestHive.sql("SELECT * FROM src").collect().map(_.getInt(0)).toSeq) + } + + private val explainCommandClassName = + classOf[ExplainCommand].getSimpleName.stripSuffix("$") + + def isExplanation(result: JavaSchemaRDD) = { + val explanation = result.collect().map(_.getString(0)) + explanation.size > 1 && explanation.head.startsWith(explainCommandClassName) + } + + ignore("Query Hive native command execution result") { + val tableName = "test_native_commands" + + assertResult(0) { + javaHiveCtx.hql(s"DROP TABLE IF EXISTS $tableName").count() + } + + assertResult(0) { + javaHiveCtx.hql(s"CREATE TABLE $tableName(key INT, value STRING)").count() + } + + javaHiveCtx.hql("SHOW TABLES").registerAsTable("show_tables") + + assert( + javaHiveCtx + .hql("SELECT result FROM show_tables") + .collect() + .map(_.getString(0)) + .contains(tableName)) + + assertResult(Array(Array("key", "int", "None"), Array("value", "string", "None"))) { + javaHiveCtx.hql(s"DESCRIBE $tableName").registerAsTable("describe_table") + + javaHiveCtx + .hql("SELECT result FROM describe_table") + .collect() + .map(_.getString(0).split("\t").map(_.trim)) + .toArray + } + + assert(isExplanation(javaHiveCtx.hql( + s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key"))) + + TestHive.reset() + } + + ignore("Exactly once semantics for DDL and command statements") { + val tableName = "test_exactly_once" + val q0 = javaHiveCtx.hql(s"CREATE TABLE $tableName(key INT, value STRING)") + + // If the table was not created, the following assertion would fail + assert(Try(TestHive.table(tableName)).isSuccess) + + // If the CREATE TABLE command got executed again, the following assertion would fail + assert(Try(q0.count()).isSuccess) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 24c929ff7430d..08ef4d9b6bb93 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -144,6 +144,12 @@ abstract class HiveComparisonTest case _: SetCommand => Seq("0") case _: LogicalNativeCommand => answer.filterNot(nonDeterministicLine).filterNot(_ == "") case _: ExplainCommand => answer + case _: DescribeCommand => + // Filter out non-deterministic lines and lines which do not have actual results but + // can introduce problems because of the way Hive formats these lines. + // Then, remove empty lines. Do not sort the results. 
+ answer.filterNot( + r => nonDeterministicLine(r) || ignoredLine(r)).map(_.trim).filterNot(_ == "") case plan => if (isSorted(plan)) answer else answer.sorted } orderedAnswer.map(cleanPaths) @@ -169,6 +175,16 @@ abstract class HiveComparisonTest protected def nonDeterministicLine(line: String) = nonDeterministicLineIndicators.exists(line contains _) + // This list contains indicators for those lines which do not have actual results and we + // want to ignore. + lazy val ignoredLineIndicators = Seq( + "# Partition Information", + "# col_name" + ) + + protected def ignoredLine(line: String) = + ignoredLineIndicators.exists(line contains _) + /** * Removes non-deterministic paths from `str` so cached answers will compare correctly. */ @@ -329,11 +345,17 @@ abstract class HiveComparisonTest if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) && preparedHive != catalyst) { - val hivePrintOut = s"== HIVE - ${hive.size} row(s) ==" +: preparedHive + val hivePrintOut = s"== HIVE - ${preparedHive.size} row(s) ==" +: preparedHive val catalystPrintOut = s"== CATALYST - ${catalyst.size} row(s) ==" +: catalyst val resultComparison = sideBySide(hivePrintOut, catalystPrintOut).mkString("\n") + println("hive output") + hive.foreach(println) + + println("catalyst printout") + catalyst.foreach(println) + if (recomputeCache) { logger.warn(s"Clearing cache files for failed test $testCaseName") hiveCacheFiles.foreach(_.delete()) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index ee194dbcb77b2..cdfc2d0c17384 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -78,7 +78,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alter_merge", "alter_concatenate_indexed_table", "protectmode2", - "describe_table", + //"describe_table", "describe_comment_nonascii", "udf5", "udf_java_method", @@ -177,7 +177,16 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // After stop taking the `stringOrError` route, exceptions are thrown from these cases. // See SPARK-2129 for details. "join_view", - "mergejoins_mixed" + "mergejoins_mixed", + + // Returning the result of a describe state as a JSON object is not supported. + "describe_table_json", + "describe_database_json", + "describe_formatted_view_partitioned_json", + + // Hive returns the results of describe as plain text. Comments with multiple lines + // introduce extra lines in the Hive results, which make the result comparison fail. 
+ "describe_comment_indent" ) /** @@ -292,11 +301,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "default_partition_name", "delimiter", "desc_non_existent_tbl", - "describe_comment_indent", - "describe_database_json", "describe_formatted_view_partitioned", - "describe_formatted_view_partitioned_json", - "describe_table_json", "diff_part_input_formats", "disable_file_format_check", "drop_function", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 0d656c556965d..80185098bf24f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -21,13 +21,21 @@ import scala.util.Try import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{SchemaRDD, execution, Row} +import org.apache.spark.sql.{SchemaRDD, Row} + +case class TestData(a: Int, b: String) /** * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. */ class HiveQuerySuite extends HiveComparisonTest { + test("CREATE TABLE AS runs once") { + hql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect() + assert(hql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1, + "Incorrect number of rows in created table") + } + createQueryTest("between", "SELECT * FROM src WHERE key Between 1 and 2") @@ -164,12 +172,47 @@ class HiveQuerySuite extends HiveComparisonTest { hql("SELECT * FROM src").toString } - private val explainCommandClassName = - classOf[execution.ExplainCommand].getSimpleName.stripSuffix("$") + createQueryTest("case statements with key #1", + "SELECT (CASE 1 WHEN 2 THEN 3 END) FROM src where key < 15") + + createQueryTest("case statements with key #2", + "SELECT (CASE key WHEN 2 THEN 3 ELSE 0 END) FROM src WHERE key < 15") + + createQueryTest("case statements with key #3", + "SELECT (CASE key WHEN 2 THEN 3 WHEN NULL THEN 4 END) FROM src WHERE key < 15") + + createQueryTest("case statements with key #4", + "SELECT (CASE key WHEN 2 THEN 3 WHEN NULL THEN 4 ELSE 0 END) FROM src WHERE key < 15") + + createQueryTest("case statements WITHOUT key #1", + "SELECT (CASE WHEN key > 2 THEN 3 END) FROM src WHERE key < 15") + + createQueryTest("case statements WITHOUT key #2", + "SELECT (CASE WHEN key > 2 THEN 3 ELSE 4 END) FROM src WHERE key < 15") + + createQueryTest("case statements WITHOUT key #3", + "SELECT (CASE WHEN key > 2 THEN 3 WHEN 2 > key THEN 2 END) FROM src WHERE key < 15") + + createQueryTest("case statements WITHOUT key #4", + "SELECT (CASE WHEN key > 2 THEN 3 WHEN 2 > key THEN 2 ELSE 0 END) FROM src WHERE key < 15") + + test("implement identity function using case statement") { + val actual = hql("SELECT (CASE key WHEN key THEN key END) FROM src").collect().toSet + val expected = hql("SELECT key FROM src").collect().toSet + assert(actual === expected) + } + + // TODO: adopt this test when Spark SQL has the functionality / framework to report errors. + // See https://github.com/apache/spark/pull/1055#issuecomment-45820167 for a discussion. 
+ ignore("non-boolean conditions in a CaseWhen are illegal") { + intercept[Exception] { + hql("SELECT (CASE WHEN key > 2 THEN 3 WHEN 1 THEN 2 ELSE 0 END) FROM src").collect() + } + } def isExplanation(result: SchemaRDD) = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } - explanation.size == 1 && explanation.head.startsWith(explainCommandClassName) + explanation.size > 1 && explanation.head.startsWith("Physical execution plan") } test("SPARK-1704: Explain commands as a SchemaRDD") { @@ -181,28 +224,51 @@ class HiveQuerySuite extends HiveComparisonTest { TestHive.reset() } - test("Query Hive native command execution result") { - val tableName = "test_native_commands" + test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") { + val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3)) + .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)} + + TestHive.sparkContext.parallelize(fixture).registerAsTable("having_test") + + val results = + hql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") + .collect() + .map(x => Pair(x.getString(0), x.getInt(1))) + + assert(results === Array(Pair("foo", 4))) + + TestHive.reset() + } - val q0 = hql(s"DROP TABLE IF EXISTS $tableName") - assert(q0.count() == 0) + test("SPARK-2180: HAVING without GROUP BY raises exception") { + intercept[Exception] { + hql("SELECT value, attr FROM having_test HAVING attr > 3") + } + } + + test("SPARK-2180: HAVING with non-boolean clause raises no exceptions") { + val results = hql("select key, count(*) c from src group by key having c").collect() + } - val q1 = hql(s"CREATE TABLE $tableName(key INT, value STRING)") - assert(q1.count() == 0) + test("Query Hive native command execution result") { + val tableName = "test_native_commands" - val q2 = hql("SHOW TABLES") - val tables = q2.select('result).collect().map { case Row(table: String) => table } - assert(tables.contains(tableName)) + assertResult(0) { + hql(s"DROP TABLE IF EXISTS $tableName").count() + } - val q3 = hql(s"DESCRIBE $tableName") - assertResult(Array(Array("key", "int", "None"), Array("value", "string", "None"))) { - q3.select('result).collect().map { case Row(fieldDesc: String) => - fieldDesc.split("\t").map(_.trim) - } + assertResult(0) { + hql(s"CREATE TABLE $tableName(key INT, value STRING)").count() } - val q4 = hql(s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key") - assert(isExplanation(q4)) + assert( + hql("SHOW TABLES") + .select('result) + .collect() + .map(_.getString(0)) + .contains(tableName)) + + assert(isExplanation(hql(s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key"))) TestHive.reset() } @@ -218,6 +284,97 @@ class HiveQuerySuite extends HiveComparisonTest { assert(Try(q0.count()).isSuccess) } + test("DESCRIBE commands") { + hql(s"CREATE TABLE test_describe_commands1 (key INT, value STRING) PARTITIONED BY (dt STRING)") + + hql( + """FROM src INSERT OVERWRITE TABLE test_describe_commands1 PARTITION (dt='2008-06-08') + |SELECT key, value + """.stripMargin) + + // Describe a table + assertResult( + Array( + Array("key", "int", null), + Array("value", "string", null), + Array("dt", "string", null), + Array("# Partition Information", "", ""), + Array("# col_name", "data_type", "comment"), + Array("dt", "string", null)) + ) { + hql("DESCRIBE test_describe_commands1") + .select('col_name, 'data_type, 'comment) + .collect() + } + + // Describe a table with a fully qualified table name + assertResult( + Array( 
+ Array("key", "int", null), + Array("value", "string", null), + Array("dt", "string", null), + Array("# Partition Information", "", ""), + Array("# col_name", "data_type", "comment"), + Array("dt", "string", null)) + ) { + hql("DESCRIBE default.test_describe_commands1") + .select('col_name, 'data_type, 'comment) + .collect() + } + + // Describe a column is a native command + assertResult(Array(Array("value", "string", "from deserializer"))) { + hql("DESCRIBE test_describe_commands1 value") + .select('result) + .collect() + .map(_.getString(0).split("\t").map(_.trim)) + } + + // Describe a column is a native command + assertResult(Array(Array("value", "string", "from deserializer"))) { + hql("DESCRIBE default.test_describe_commands1 value") + .select('result) + .collect() + .map(_.getString(0).split("\t").map(_.trim)) + } + + // Describe a partition is a native command + assertResult( + Array( + Array("key", "int", "None"), + Array("value", "string", "None"), + Array("dt", "string", "None"), + Array("", "", ""), + Array("# Partition Information", "", ""), + Array("# col_name", "data_type", "comment"), + Array("", "", ""), + Array("dt", "string", "None")) + ) { + hql("DESCRIBE test_describe_commands1 PARTITION (dt='2008-06-08')") + .select('result) + .collect() + .map(_.getString(0).split("\t").map(_.trim)) + } + + // Describe a registered temporary table. + val testData: SchemaRDD = + TestHive.sparkContext.parallelize( + TestData(1, "str1") :: + TestData(1, "str2") :: Nil) + testData.registerAsTable("test_describe_commands2") + + assertResult( + Array( + Array("# Registered as a temporary table", null, null), + Array("a", "IntegerType", null), + Array("b", "StringType", null)) + ) { + hql("DESCRIBE test_describe_commands2") + .select('col_name, 'data_type, 'comment) + .collect() + } + } + test("parse HQL set commands") { // Adapted from its SQL counterpart. val testKey = "spark.sql.key.usedfortestonly" @@ -310,3 +467,6 @@ class HiveQuerySuite extends HiveComparisonTest { // since they modify /clear stuff. } + +// for SPARK-2180 test +case class HavingRow(key: Int, value: String, attr: Int) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index e030c8ee3dfc8..7436de264a1e1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo} +import org.apache.spark.sql.execution.Project +import org.apache.spark.sql.hive.test.TestHive + /** - * A set of tests that validate type promotion rules. + * A set of tests that validate type promotion and coercion rules. 
*/ class HiveTypeCoercionSuite extends HiveComparisonTest { val baseTypes = Seq("1", "1.0", "1L", "1S", "1Y", "'1'") @@ -28,4 +32,23 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { createQueryTest(s"$i + $j", s"SELECT $i + $j FROM src LIMIT 1") } } + + test("[SPARK-2210] boolean cast on boolean value should be removed") { + val q = "select cast(cast(key=0 as boolean) as boolean) from src" + val project = TestHive.hql(q).queryExecution.executedPlan.collect { case e: Project => e }.head + + // No cast expression introduced + project.transformAllExpressions { case c: Cast => + fail(s"unexpected cast $c") + c + } + + // Only one equality check + var numEquals = 0 + project.transformAllExpressions { case e: EqualTo => + numEquals += 1 + e + } + assert(numEquals === 1) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala new file mode 100644 index 0000000000000..a9e3f42a3adfc --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.hive.test.TestHive +import org.apache.hadoop.conf.Configuration +import org.apache.spark.SparkContext._ +import java.util +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.hive.serde2.{SerDeStats, AbstractSerDe} +import org.apache.hadoop.io.{NullWritable, Writable} +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorFactory, ObjectInspector} +import java.util.Properties +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory +import scala.collection.JavaConversions._ +import java.io.{DataOutput, DataInput} +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject + +/** + * A test suite for Hive custom UDFs. 
+ */ +class HiveUdfSuite extends HiveComparisonTest { + + TestHive.hql( + """ + |CREATE EXTERNAL TABLE hiveUdfTestTable ( + | pair STRUCT + |) + |PARTITIONED BY (partition STRING) + |ROW FORMAT SERDE '%s' + |STORED AS SEQUENCEFILE + """.stripMargin.format(classOf[PairSerDe].getName) + ) + + TestHive.hql( + "ALTER TABLE hiveUdfTestTable ADD IF NOT EXISTS PARTITION(partition='testUdf') LOCATION '%s'" + .format(this.getClass.getClassLoader.getResource("data/files/testUdf").getFile) + ) + + TestHive.hql("CREATE TEMPORARY FUNCTION testUdf AS '%s'".format(classOf[PairUdf].getName)) + + TestHive.hql("SELECT testUdf(pair) FROM hiveUdfTestTable") + + TestHive.hql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") +} + +class TestPair(x: Int, y: Int) extends Writable with Serializable { + def this() = this(0, 0) + var entry: (Int, Int) = (x, y) + + override def write(output: DataOutput): Unit = { + output.writeInt(entry._1) + output.writeInt(entry._2) + } + + override def readFields(input: DataInput): Unit = { + val x = input.readInt() + val y = input.readInt() + entry = (x, y) + } +} + +class PairSerDe extends AbstractSerDe { + override def initialize(p1: Configuration, p2: Properties): Unit = {} + + override def getObjectInspector: ObjectInspector = { + ObjectInspectorFactory + .getStandardStructObjectInspector( + Seq("pair"), + Seq(ObjectInspectorFactory.getStandardStructObjectInspector( + Seq("id", "value"), + Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector)) + )) + } + + override def getSerializedClass: Class[_ <: Writable] = classOf[TestPair] + + override def getSerDeStats: SerDeStats = null + + override def serialize(p1: scala.Any, p2: ObjectInspector): Writable = null + + override def deserialize(value: Writable): AnyRef = { + val pair = value.asInstanceOf[TestPair] + + val row = new util.ArrayList[util.ArrayList[AnyRef]] + row.add(new util.ArrayList[AnyRef](2)) + row(0).add(Integer.valueOf(pair.entry._1)) + row(0).add(Integer.valueOf(pair.entry._2)) + + row + } +} + +class PairUdf extends GenericUDF { + override def initialize(p1: Array[ObjectInspector]): ObjectInspector = + ObjectInspectorFactory.getStandardStructObjectInspector( + Seq("id", "value"), + Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, PrimitiveObjectInspectorFactory.javaIntObjectInspector) + ) + + override def evaluate(args: Array[DeferredObject]): AnyRef = { + println("Type = %s".format(args(0).getClass.getName)) + Integer.valueOf(args(0).get.asInstanceOf[TestPair].entry._2) + } + + override def getDisplayString(p1: Array[String]): String = "" +} + + + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 75cabdbf8da26..391e40924f38a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -74,7 +74,7 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont /** Get information on received blocks. 
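One note on the SerDe plumbing above: the SEQUENCEFILE-backed table only round-trips if `TestPair.write` and `readFields` are symmetric. A quick standalone check of that contract (a sketch using only JDK streams plus the classes defined in this suite):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

    // Serialize a pair through the Writable interface, then read it back.
    val out = new ByteArrayOutputStream()
    new TestPair(1, 2).write(new DataOutputStream(out))

    val restored = new TestPair()
    restored.readFields(new DataInputStream(new ByteArrayInputStream(out.toByteArray)))
    assert(restored.entry == (1, 2))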
*/ private[streaming] def getReceivedBlockInfo(time: Time) = { - receivedBlockInfo(time) + receivedBlockInfo.get(time).getOrElse(Array.empty[ReceivedBlockInfo]) } /** diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 4ccddc214c8ad..82f79d88a3009 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -71,7 +71,7 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa val capability = Records.newRecord(classOf[Resource]).asInstanceOf[Resource] // Memory for the ApplicationMaster. - capability.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + capability.setMemory(args.amMemory + memoryOverhead) amContainer.setResource(capability) appContext.setQueue(args.amQueue) @@ -115,7 +115,7 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa val minResMemory = newApp.getMinimumResourceCapability().getMemory() val amMemory = ((args.amMemory / minResMemory) * minResMemory) + ((if ((args.amMemory % minResMemory) == 0) 0 else minResMemory) - - YarnAllocationHandler.MEMORY_OVERHEAD) + memoryOverhead) amMemory } diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index b6ecae1e652fe..bfdb6232f5113 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -92,13 +92,15 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp appAttemptId = getApplicationAttemptId() resourceManager = registerWithResourceManager() + val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster() // Compute number of threads for akka val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory() if (minimumMemory > 0) { - val mem = args.executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD + val mem = args.executorMemory + sparkConf.getInt("spark.yarn.executor.memoryOverhead", + YarnAllocationHandler.MEMORY_OVERHEAD) val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0) if (numCore > 0) { diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index 856391e52b2df..80e0162e9f277 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -88,6 +88,10 @@ private[yarn] class YarnAllocationHandler( // Containers to be released in next request to RM private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean] + // Additional memory overhead - in mb. 
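On the ReceiverInputDStream change a few hunks up: indexing the `receivedBlockInfo` map directly throws for a batch time that never reported any blocks, whereas the new `get(...).getOrElse(...)` form degrades to an empty result. A minimal sketch of the difference, using stand-in types rather than the streaming classes themselves:

    import scala.collection.mutable

    val receivedBlockInfo = new mutable.HashMap[Long, Array[String]]()  // stand-ins for Time / ReceivedBlockInfo
    // receivedBlockInfo(42L)                  // old pattern: NoSuchElementException for an unknown batch time
    val infos = receivedBlockInfo.get(42L).getOrElse(Array.empty[String])  // new pattern: simply "no blocks"
    assert(infos.isEmpty)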
+ private def memoryOverhead: Int = sparkConf.getInt("spark.yarn.executor.memoryOverhead", + YarnAllocationHandler.MEMORY_OVERHEAD) + private val numExecutorsRunning = new AtomicInteger() // Used to generate a unique id per executor private val executorIdCounter = new AtomicInteger() @@ -99,7 +103,7 @@ private[yarn] class YarnAllocationHandler( def getNumExecutorsFailed: Int = numExecutorsFailed.intValue def isResourceConstraintSatisfied(container: Container): Boolean = { - container.getResource.getMemory >= (executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + container.getResource.getMemory >= (executorMemory + memoryOverhead) } def allocateContainers(executorsToRequest: Int) { @@ -229,7 +233,7 @@ private[yarn] class YarnAllocationHandler( val containerId = container.getId assert( container.getResource.getMemory >= - (executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD)) + (executorMemory + memoryOverhead)) if (numExecutorsRunningNow > maxExecutors) { logInfo("""Ignoring container %s at host %s, since we already have the required number of @@ -450,7 +454,7 @@ private[yarn] class YarnAllocationHandler( if (numExecutors > 0) { logInfo("Allocating %d executor containers with %d of memory each.".format(numExecutors, - executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD)) + executorMemory + memoryOverhead)) } else { logDebug("Empty allocation req .. release : " + releasedContainerList) } @@ -505,7 +509,7 @@ private[yarn] class YarnAllocationHandler( val rsrcRequest = Records.newRecord(classOf[ResourceRequest]) val memCapability = Records.newRecord(classOf[Resource]) // There probably is some overhead here, let's reserve a bit more memory. - memCapability.setMemory(executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + memCapability.setMemory(executorMemory + memoryOverhead) rsrcRequest.setCapability(memCapability) val pri = Records.newRecord(classOf[Priority]) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index fd3ef9e1fa2de..62f9b3cf5ab88 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -21,8 +21,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap} import org.apache.spark.SparkConf import org.apache.spark.scheduler.InputFormatInfo -import org.apache.spark.util.IntParam -import org.apache.spark.util.MemoryParam +import org.apache.spark.util.{Utils, IntParam, MemoryParam} // TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware ! @@ -45,6 +44,18 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { parseArgs(args.toList) + // env variable SPARK_YARN_DIST_ARCHIVES/SPARK_YARN_DIST_FILES set in yarn-client then + // it should default to hdfs:// + files = Option(files).getOrElse(sys.env.get("SPARK_YARN_DIST_FILES").orNull) + archives = Option(archives).getOrElse(sys.env.get("SPARK_YARN_DIST_ARCHIVES").orNull) + + // spark.yarn.dist.archives/spark.yarn.dist.files defaults to use file:// if not specified, + // for both yarn-client and yarn-cluster + files = Option(files).getOrElse(sparkConf.getOption("spark.yarn.dist.files"). + map(p => Utils.resolveURIs(p)).orNull) + archives = Option(archives).getOrElse(sparkConf.getOption("spark.yarn.dist.archives"). 
+    map(p => Utils.resolveURIs(p)).orNull)
+
   private def parseArgs(inputArgs: List[String]): Unit = {
     val userArgsBuffer: ArrayBuffer[String] = new ArrayBuffer[String]()
     val inputFormatMap: HashMap[String, InputFormatInfo] = new HashMap[String, InputFormatInfo]()
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
index 6861b503000ca..8f2267599914c 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
@@ -65,6 +65,10 @@ trait ClientBase extends Logging {
   val APP_FILE_PERMISSION: FsPermission =
     FsPermission.createImmutable(Integer.parseInt("644", 8).toShort)
 
+  // Additional memory overhead - in mb.
+  protected def memoryOverhead: Int = sparkConf.getInt("spark.yarn.driver.memoryOverhead",
+    YarnAllocationHandler.MEMORY_OVERHEAD)
+
   // TODO(harvey): This could just go in ClientArguments.
   def validateArgs() = {
     Map(
@@ -72,10 +76,10 @@ trait ClientBase extends Logging {
        "Error: You must specify a user jar when running in standalone mode!"),
      (args.userClass == null) -> "Error: You must specify a user class!",
      (args.numExecutors <= 0) -> "Error: You must specify at least 1 executor!",
-      (args.amMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> ("Error: AM memory size must be" +
-        "greater than: " + YarnAllocationHandler.MEMORY_OVERHEAD),
-      (args.executorMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> ("Error: Executor memory size" +
-        "must be greater than: " + YarnAllocationHandler.MEMORY_OVERHEAD.toString)
+      (args.amMemory <= memoryOverhead) -> ("Error: AM memory size must be " +
+        "greater than: " + memoryOverhead),
+      (args.executorMemory <= memoryOverhead) -> ("Error: Executor memory size " +
+        "must be greater than: " + memoryOverhead.toString)
     ).foreach { case(cond, errStr) =>
       if (cond) {
         logError(errStr)
@@ -101,7 +105,7 @@ trait ClientBase extends Logging {
       logError(errorMessage)
       throw new IllegalArgumentException(errorMessage)
     }
-    val amMem = args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD
+    val amMem = args.amMemory + memoryOverhead
     if (amMem > maxMem) {
       val errorMessage = "Required AM memory (%d) is above the max threshold (%d) of this cluster."
@@ -158,7 +162,7 @@ trait ClientBase extends Logging {
     val fs = FileSystem.get(conf)
     val remoteFs = originalPath.getFileSystem(conf)
     var newPath = originalPath
-    if (!
compareFs(remoteFs, fs)) { + if (!compareFs(remoteFs, fs)) { newPath = new Path(dstDir, originalPath.getName()) logInfo("Uploading " + originalPath + " to " + newPath) FileUtil.copy(remoteFs, originalPath, fs, newPath, false, conf) @@ -246,6 +250,7 @@ trait ClientBase extends Logging { } } } + logInfo("Prepared Local resources " + localResources) sparkConf.set(ClientBase.CONF_SPARK_YARN_SECONDARY_JARS, cachedSecondaryJarLinks.mkString(",")) UserGroupInformation.getCurrentUser().addCredentials(credentials) diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 039cf4f276119..412dfe38d55eb 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -70,9 +70,7 @@ private[spark] class YarnClientSchedulerBackend( ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), ("--executor-cores", "SPARK_EXECUTOR_CORES", "spark.executor.cores"), ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue"), - ("--name", "SPARK_YARN_APP_NAME", "spark.app.name"), - ("--files", "SPARK_YARN_DIST_FILES", "spark.yarn.dist.files"), - ("--archives", "SPARK_YARN_DIST_ARCHIVES", "spark.yarn.dist.archives")) + ("--name", "SPARK_YARN_APP_NAME", "spark.app.name")) .foreach { case (optName, envVar, sysProp) => addArg(optName, envVar, sysProp, argsArrayBuf) } logDebug("ClientArguments called with: " + argsArrayBuf) diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 80a8bceb17269..15f3c4f180ea3 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -84,7 +84,7 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa // Memory for the ApplicationMaster. val memoryResource = Records.newRecord(classOf[Resource]).asInstanceOf[Resource] - memoryResource.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + memoryResource.setMemory(args.amMemory + memoryOverhead) appContext.setResource(memoryResource) // Finally, submit and monitor the application. @@ -117,7 +117,7 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa // val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory() // var amMemory = ((args.amMemory / minResMemory) * minResMemory) + // ((if ((args.amMemory % minResMemory) == 0) 0 else minResMemory) - - // YarnAllocationHandler.MEMORY_OVERHEAD) + // memoryOverhead ) args.amMemory } diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index a979fe4d62630..29ccec2adcac3 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -90,6 +90,10 @@ private[yarn] class YarnAllocationHandler( // Containers to be released in next request to RM private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean] + // Additional memory overhead - in mb. 
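The `--files`/`--archives` mappings removed from YarnClientSchedulerBackend above are replaced by the new defaulting logic in ClientArguments: an explicit argument wins, then the SPARK_YARN_DIST_* environment variables, then the spark.yarn.dist.* properties (which are additionally passed through Utils.resolveURIs). A rough sketch of that precedence with a made-up helper name; the real code mutates the `files`/`archives` fields in place:

    import org.apache.spark.SparkConf

    // Hypothetical helper mirroring the ClientArguments defaulting order for --files.
    def defaultDistFiles(explicit: Option[String], conf: SparkConf): Option[String] =
      explicit
        .orElse(sys.env.get("SPARK_YARN_DIST_FILES"))
        .orElse(conf.getOption("spark.yarn.dist.files"))  // the patch also resolves these via Utils.resolveURIs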
+ private def memoryOverhead: Int = sparkConf.getInt("spark.yarn.executor.memoryOverhead", + YarnAllocationHandler.MEMORY_OVERHEAD) + // Number of container requests that have been sent to, but not yet allocated by the // ApplicationMaster. private val numPendingAllocate = new AtomicInteger() @@ -106,7 +110,7 @@ private[yarn] class YarnAllocationHandler( def getNumExecutorsFailed: Int = numExecutorsFailed.intValue def isResourceConstraintSatisfied(container: Container): Boolean = { - container.getResource.getMemory >= (executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + container.getResource.getMemory >= (executorMemory + memoryOverhead) } def releaseContainer(container: Container) { @@ -248,7 +252,7 @@ private[yarn] class YarnAllocationHandler( val executorHostname = container.getNodeId.getHost val containerId = container.getId - val executorMemoryOverhead = (executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + val executorMemoryOverhead = (executorMemory + memoryOverhead) assert(container.getResource.getMemory >= executorMemoryOverhead) if (numExecutorsRunningNow > maxExecutors) { @@ -477,7 +481,7 @@ private[yarn] class YarnAllocationHandler( numPendingAllocate.addAndGet(numExecutors) logInfo("Will Allocate %d executor containers, each with %d memory".format( numExecutors, - (executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD))) + (executorMemory + memoryOverhead))) } else { logDebug("Empty allocation request ...") } @@ -537,7 +541,7 @@ private[yarn] class YarnAllocationHandler( priority: Int ): ArrayBuffer[ContainerRequest] = { - val memoryRequest = executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD + val memoryRequest = executorMemory + memoryOverhead val resource = Resource.newInstance(memoryRequest, executorCores) val prioritySetting = Records.newRecord(classOf[Priority])
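Taken together, the YARN changes make the per-container memory padding configurable rather than hard-coded: `spark.yarn.executor.memoryOverhead` pads every executor container and `spark.yarn.driver.memoryOverhead` pads the ApplicationMaster, with YarnAllocationHandler.MEMORY_OVERHEAD still used as the fallback. A sketch of the effect on the requested container size; the 384 MB literal below merely stands in for that compiled-in constant:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.yarn.executor.memoryOverhead", "512")  // MB added on top of each executor's heap
      .set("spark.yarn.driver.memoryOverhead", "512")    // MB added on top of the AM memory (ClientBase)

    // What the allocator ends up asking YARN for per executor, e.g. with 4096 MB executors:
    val executorMemory = 4096
    val memoryRequest = executorMemory + conf.getInt("spark.yarn.executor.memoryOverhead", 384)
    // memoryRequest (here 4608) is what flows into Resource.newInstance(memoryRequest, executorCores)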