diff --git a/.gitignore b/.gitignore index b54a3058de659..4f177c82ae5e0 100644 --- a/.gitignore +++ b/.gitignore @@ -7,7 +7,7 @@ sbt/*.jar .settings .cache -.generated-mima-excludes +.generated-mima* /build/ work/ out/ diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index 7df43a555d562..2cf4e381c1c88 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -38,8 +38,10 @@ else JAR_CMD="jar" fi -# First check if we have a dependencies jar. If so, include binary classes with the deps jar -if [ -f "$ASSEMBLY_DIR"/spark-assembly*hadoop*-deps.jar ]; then +# A developer option to prepend more recently compiled Spark classes +if [ -n "$SPARK_PREPEND_CLASSES" ]; then + echo "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark"\ + "classes ahead of assembly." >&2 CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes" @@ -51,17 +53,31 @@ if [ -f "$ASSEMBLY_DIR"/spark-assembly*hadoop*-deps.jar ]; then CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SCALA_VERSION/classes" +fi - ASSEMBLY_JAR=$(ls "$ASSEMBLY_DIR"/spark-assembly*hadoop*-deps.jar 2>/dev/null) +# Use spark-assembly jar from either RELEASE or assembly directory +if [ -f "$FWDIR/RELEASE" ]; then + assembly_folder="$FWDIR"/lib else - # Else use spark-assembly jar from either RELEASE or assembly directory - if [ -f "$FWDIR/RELEASE" ]; then - ASSEMBLY_JAR=$(ls "$FWDIR"/lib/spark-assembly*hadoop*.jar 2>/dev/null) - else - ASSEMBLY_JAR=$(ls "$ASSEMBLY_DIR"/spark-assembly*hadoop*.jar 2>/dev/null) - fi + assembly_folder="$ASSEMBLY_DIR" fi +num_jars=$(ls "$assembly_folder" | grep "spark-assembly.*hadoop.*\.jar" | wc -l) +if [ "$num_jars" -eq "0" ]; then + echo "Failed to find Spark assembly in $assembly_folder" + echo "You need to build Spark before running this program." + exit 1 +fi +if [ "$num_jars" -gt "1" ]; then + jars_list=$(ls "$assembly_folder" | grep "spark-assembly.*hadoop.*.jar") + echo "Found multiple Spark assembly jars in $assembly_folder:" + echo "$jars_list" + echo "Please remove all but one jar." + exit 1 +fi + +ASSEMBLY_JAR=$(ls "$assembly_folder"/spark-assembly*hadoop*.jar 2>/dev/null) + # Verify that versions of java used to build the jars and run Spark are compatible jar_error_check=$("$JAR_CMD" -tf "$ASSEMBLY_JAR" nonexistent/class/path 2>&1) if [[ "$jar_error_check" =~ "invalid CEN header" ]]; then diff --git a/bin/pyspark b/bin/pyspark index 114cbbc3a8a8e..0b5ed40e2157d 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -45,7 +45,7 @@ fi . 
$FWDIR/bin/load-spark-env.sh # Figure out which Python executable to use -if [ -z "$PYSPARK_PYTHON" ] ; then +if [[ -z "$PYSPARK_PYTHON" ]]; then PYSPARK_PYTHON="python" fi export PYSPARK_PYTHON @@ -59,7 +59,7 @@ export OLD_PYTHONSTARTUP=$PYTHONSTARTUP export PYTHONSTARTUP=$FWDIR/python/pyspark/shell.py # If IPython options are specified, assume user wants to run IPython -if [ -n "$IPYTHON_OPTS" ]; then +if [[ -n "$IPYTHON_OPTS" ]]; then IPYTHON=1 fi @@ -76,6 +76,16 @@ for i in "$@"; do done export PYSPARK_SUBMIT_ARGS +# For pyspark tests +if [[ -n "$SPARK_TESTING" ]]; then + if [[ -n "$PYSPARK_DOC_TEST" ]]; then + exec "$PYSPARK_PYTHON" -m doctest $1 + else + exec "$PYSPARK_PYTHON" $1 + fi + exit +fi + # If a python file is provided, directly run spark-submit. if [[ "$1" =~ \.py$ ]]; then echo -e "\nWARNING: Running python applications through ./bin/pyspark is deprecated as of Spark 1.0." 1>&2 @@ -86,10 +96,6 @@ else if [[ "$IPYTHON" = "1" ]]; then exec ipython $IPYTHON_OPTS else - if [[ -n $SPARK_TESTING ]]; then - exec "$PYSPARK_PYTHON" -m doctest - else - exec "$PYSPARK_PYTHON" - fi + exec "$PYSPARK_PYTHON" fi fi diff --git a/bin/spark-class b/bin/spark-class index e884511010c6c..cfe363a71da31 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -108,23 +108,6 @@ fi export JAVA_OPTS # Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in CommandUtils.scala! -if [ ! -f "$FWDIR/RELEASE" ]; then - # Exit if the user hasn't compiled Spark - num_jars=$(ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/ | grep "spark-assembly.*hadoop.*.jar" | wc -l) - jars_list=$(ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/ | grep "spark-assembly.*hadoop.*.jar") - if [ "$num_jars" -eq "0" ]; then - echo "Failed to find Spark assembly in $FWDIR/assembly/target/scala-$SCALA_VERSION/" >&2 - echo "You need to build Spark before running this program." >&2 - exit 1 - fi - if [ "$num_jars" -gt "1" ]; then - echo "Found multiple Spark assembly jars in $FWDIR/assembly/target/scala-$SCALA_VERSION:" >&2 - echo "$jars_list" - echo "Please remove all but one jar." - exit 1 - fi -fi - TOOLS_DIR="$FWDIR"/tools SPARK_TOOLS_JAR="" if [ -e "$TOOLS_DIR"/target/scala-$SCALA_VERSION/*assembly*[0-9Tg].jar ]; then diff --git a/core/pom.xml b/core/pom.xml index c3d6b00a443f1..bd6767e03bb9d 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -67,6 +67,12 @@ org.apache.commons commons-lang3 + + org.apache.commons + commons-math3 + 3.3 + test + com.google.code.findbugs jsr305 diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index 811610c657b62..315ed91f81df3 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -32,10 +32,14 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { private val loading = new HashSet[RDDBlockId]() /** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. 
*/ - def getOrCompute[T](rdd: RDD[T], split: Partition, context: TaskContext, + def getOrCompute[T]( + rdd: RDD[T], + split: Partition, + context: TaskContext, storageLevel: StorageLevel): Iterator[T] = { + val key = RDDBlockId(rdd.id, split.index) - logDebug("Looking for partition " + key) + logDebug(s"Looking for partition $key") blockManager.get(key) match { case Some(values) => // Partition is already materialized, so just return its values @@ -45,7 +49,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { // Mark the split as loading (unless someone else marks it first) loading.synchronized { if (loading.contains(key)) { - logInfo("Another thread is loading %s, waiting for it to finish...".format(key)) + logInfo(s"Another thread is loading $key, waiting for it to finish...") while (loading.contains(key)) { try { loading.wait() @@ -54,7 +58,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { logWarning(s"Got an exception while waiting for another thread to load $key", e) } } - logInfo("Finished waiting for %s".format(key)) + logInfo(s"Finished waiting for $key") /* See whether someone else has successfully loaded it. The main way this would fail * is for the RDD-level cache eviction policy if someone else has loaded the same RDD * partition but we didn't want to make space for it. However, that case is unlikely @@ -64,7 +68,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { case Some(values) => return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]]) case None => - logInfo("Whoever was loading %s failed; we'll try it ourselves".format(key)) + logInfo(s"Whoever was loading $key failed; we'll try it ourselves") loading.add(key) } } else { @@ -73,7 +77,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { } try { // If we got here, we have to load the split - logInfo("Partition %s not found, computing it".format(key)) + logInfo(s"Partition $key not found, computing it") val computedValues = rdd.computeOrReadCheckpoint(split, context) // Persist the result, so long as the task is not running locally @@ -97,8 +101,8 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { case Some(values) => values.asInstanceOf[Iterator[T]] case None => - logInfo("Failure to store %s".format(key)) - throw new Exception("Block manager failed to return persisted valued") + logInfo(s"Failure to store $key") + throw new SparkException("Block manager failed to return persisted value") } } else { // In this case the RDD is cached to an array buffer. This will save the results diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index e2d2250982daa..bf3c3a6ceb5ef 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -96,7 +96,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Register a ShuffleDependency for cleanup when it is garbage collected. 
*/ - def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) { + def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]) { registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId)) } diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 2c31cc20211ff..c8c194a111aac 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -20,6 +20,7 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ShuffleHandle /** * :: DeveloperApi :: @@ -50,19 +51,24 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { * Represents a dependency on the output of a shuffle stage. * @param rdd the parent RDD * @param partitioner partitioner used to partition the shuffle output - * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to null, + * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None, * the default serializer, as specified by `spark.serializer` config option, will * be used. */ @DeveloperApi -class ShuffleDependency[K, V]( +class ShuffleDependency[K, V, C]( @transient rdd: RDD[_ <: Product2[K, V]], val partitioner: Partitioner, - val serializer: Serializer = null) + val serializer: Option[Serializer] = None, + val keyOrdering: Option[Ordering[K]] = None, + val aggregator: Option[Aggregator[K, V, C]] = None) extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { val shuffleId: Int = rdd.context.newShuffleId() + val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle( + shuffleId, rdd.partitions.size, this) + rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d721aba709600..35970c2f50892 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -290,6 +290,9 @@ class SparkContext(config: SparkConf) extends Logging { value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} { executorEnvs(envKey) = value } + Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => + executorEnvs("SPARK_PREPEND_CLASSES") = v + } // The Mesos scheduler backend relies on this environment variable to set executor memory. // TODO: Set this only in the Mesos scheduler. executorEnvs("SPARK_EXECUTOR_MEMORY") = executorMemory + "m" @@ -297,7 +300,7 @@ class SparkContext(config: SparkConf) extends Logging { // Set SPARK_USER for user who is running SparkContext. val sparkUser = Option { - Option(System.getProperty("user.name")).getOrElse(System.getenv("SPARK_USER")) + Option(System.getenv("SPARK_USER")).getOrElse(System.getProperty("user.name")) }.getOrElse { SparkContext.SPARK_UNKNOWN_USER } @@ -431,12 +434,21 @@ class SparkContext(config: SparkConf) extends Logging { // Methods for creating RDDs - /** Distribute a local Scala collection to form an RDD. */ + /** Distribute a local Scala collection to form an RDD. + * + * @note Parallelize acts lazily. 
If `seq` is a mutable collection and is + * altered after the call to parallelize and before the first action on the + * RDD, the resultant RDD will reflect the modified collection. Pass a copy of + * the argument to avoid this. + */ def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]()) } - /** Distribute a local Scala collection to form an RDD. */ + /** Distribute a local Scala collection to form an RDD. + * + * This method is identical to `parallelize`. + */ def makeRDD[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { parallelize(seq, numSlices) } @@ -823,9 +835,11 @@ class SparkContext(config: SparkConf) extends Logging { } /** + * :: DeveloperApi :: * Return information about what RDDs are cached, if they are in mem or on disk, how much space * they take, etc. */ + @DeveloperApi def getRDDStorageInfo: Array[RDDInfo] = { StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this) } @@ -837,8 +851,10 @@ class SparkContext(config: SparkConf) extends Logging { def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap /** + * :: DeveloperApi :: * Return information about blocks stored in all of the slaves */ + @DeveloperApi def getExecutorStorageStatus: Array[StorageStatus] = { env.blockManager.master.getStorageStatus } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 720151a6b0f84..8dfa8cc4b5b3f 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -34,6 +34,7 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.ConnectionManager import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage._ import org.apache.spark.util.{AkkaUtils, Utils} @@ -56,7 +57,7 @@ class SparkEnv ( val closureSerializer: Serializer, val cacheManager: CacheManager, val mapOutputTracker: MapOutputTracker, - val shuffleFetcher: ShuffleFetcher, + val shuffleManager: ShuffleManager, val broadcastManager: BroadcastManager, val blockManager: BlockManager, val connectionManager: ConnectionManager, @@ -80,7 +81,7 @@ class SparkEnv ( pythonWorkers.foreach { case(key, worker) => worker.stop() } httpFileServer.stop() mapOutputTracker.stop() - shuffleFetcher.stop() + shuffleManager.stop() broadcastManager.stop() blockManager.stop() blockManager.master.stop() @@ -163,13 +164,20 @@ object SparkEnv extends Logging { def instantiateClass[T](propertyName: String, defaultClassName: String): T = { val name = conf.get(propertyName, defaultClassName) val cls = Class.forName(name, true, Utils.getContextOrSparkClassLoader) - // First try with the constructor that takes SparkConf. If we can't find one, - // use a no-arg constructor instead. 
+ // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just + // SparkConf, then one taking no arguments try { - cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T] + cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE) + .newInstance(conf, new java.lang.Boolean(isDriver)) + .asInstanceOf[T] } catch { case _: NoSuchMethodException => - cls.getConstructor().newInstance().asInstanceOf[T] + try { + cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T] + } catch { + case _: NoSuchMethodException => + cls.getConstructor().newInstance().asInstanceOf[T] + } } } @@ -219,9 +227,6 @@ object SparkEnv extends Logging { val cacheManager = new CacheManager(blockManager) - val shuffleFetcher = instantiateClass[ShuffleFetcher]( - "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher") - val httpFileServer = new HttpFileServer(securityManager) httpFileServer.initialize() conf.set("spark.fileserver.uri", httpFileServer.serverUri) @@ -242,6 +247,9 @@ object SparkEnv extends Logging { "." } + val shuffleManager = instantiateClass[ShuffleManager]( + "spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") + // Warn about deprecated spark.cache.class property if (conf.contains("spark.cache.class")) { logWarning("The spark.cache.class property is no longer being used! Specify storage " + @@ -255,7 +263,7 @@ object SparkEnv extends Logging { closureSerializer, cacheManager, mapOutputTracker, - shuffleFetcher, + shuffleManager, broadcastManager, blockManager, connectionManager, diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 7dcfbf741c4f1..14fa9d8135afe 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -228,6 +228,50 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) : PartialResult[java.util.Map[K, BoundedDouble]] = rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap) + /** + * Aggregate the values of each key, using given combine functions and a neutral "zero value". + * This function can return a different result type, U, than the type of the values in this RDD, + * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's, + * as in scala.TraversableOnce. The former operation is used for merging values within a + * partition, and the latter is used for merging values between partitions. To avoid memory + * allocation, both of these functions are allowed to modify and return their first argument + * instead of creating a new U. + */ + def aggregateByKey[U](zeroValue: U, partitioner: Partitioner, seqFunc: JFunction2[U, V, U], + combFunc: JFunction2[U, U, U]): JavaPairRDD[K, U] = { + implicit val ctag: ClassTag[U] = fakeClassTag + fromRDD(rdd.aggregateByKey(zeroValue, partitioner)(seqFunc, combFunc)) + } + + /** + * Aggregate the values of each key, using given combine functions and a neutral "zero value". + * This function can return a different result type, U, than the type of the values in this RDD, + * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's, + * as in scala.TraversableOnce. The former operation is used for merging values within a + * partition, and the latter is used for merging values between partitions. 
To avoid memory + * allocation, both of these functions are allowed to modify and return their first argument + * instead of creating a new U. + */ + def aggregateByKey[U](zeroValue: U, numPartitions: Int, seqFunc: JFunction2[U, V, U], + combFunc: JFunction2[U, U, U]): JavaPairRDD[K, U] = { + implicit val ctag: ClassTag[U] = fakeClassTag + fromRDD(rdd.aggregateByKey(zeroValue, numPartitions)(seqFunc, combFunc)) + } + + /** + * Aggregate the values of each key, using given combine functions and a neutral "zero value". + * This function can return a different result type, U, than the type of the values in this RDD, + * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's. + * The former operation is used for merging values within a partition, and the latter is used for + * merging values between partitions. To avoid memory allocation, both of these functions are + * allowed to modify and return their first argument instead of creating a new U. + */ + def aggregateByKey[U](zeroValue: U, seqFunc: JFunction2[U, V, U], combFunc: JFunction2[U, U, U]): + JavaPairRDD[K, U] = { + implicit val ctag: ClassTag[U] = fakeClassTag + fromRDD(rdd.aggregateByKey(zeroValue)(seqFunc, combFunc)) + } + /** * Merge the values for each key using an associative function and a neutral "zero value" which * may be added to the result an arbitrary number of times, and must not change the result diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index aeb159adc31d9..c371dc3a51c73 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -81,7 +81,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends case "kill" => val driverId = driverArgs.driverId - val killFuture = masterActor ! RequestKillDriver(driverId) + masterActor ! RequestKillDriver(driverId) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index d27e0e1f15c65..d09136de49807 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -23,9 +23,10 @@ import akka.actor.ActorRef import com.google.common.base.Charsets import com.google.common.io.Files -import org.apache.spark.Logging +import org.apache.spark.{SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged +import org.apache.spark.util.logging.FileAppender /** * Manages the execution of one executor process. @@ -42,12 +43,15 @@ private[spark] class ExecutorRunner( val sparkHome: File, val workDir: File, val workerUrl: String, + val conf: SparkConf, var state: ExecutorState.Value) extends Logging { val fullId = appId + "/" + execId var workerThread: Thread = null var process: Process = null + var stdoutAppender: FileAppender = null + var stderrAppender: FileAppender = null // NOTE: This is now redundant with the automated shut-down enforced by the Executor. It might // make sense to remove this in the future. 
@@ -76,6 +80,13 @@ private[spark] class ExecutorRunner( if (process != null) { logInfo("Killing process!") process.destroy() + process.waitFor() + if (stdoutAppender != null) { + stdoutAppender.stop() + } + if (stderrAppender != null) { + stderrAppender.stop() + } val exitCode = process.waitFor() worker ! ExecutorStateChanged(appId, execId, state, message, Some(exitCode)) } @@ -137,11 +148,11 @@ private[spark] class ExecutorRunner( // Redirect its stdout and stderr to files val stdout = new File(executorDir, "stdout") - CommandUtils.redirectStream(process.getInputStream, stdout) + stdoutAppender = FileAppender(process.getInputStream, stdout, conf) val stderr = new File(executorDir, "stderr") Files.write(header, stderr, Charsets.UTF_8) - CommandUtils.redirectStream(process.getErrorStream, stderr) + stderrAppender = FileAppender(process.getErrorStream, stderr, conf) // Wait for it to exit; this is actually a bad thing if it happens, because we expect to run // long-lived processes only. However, in the future, we might restart the executor a few diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 100de26170a50..a0ecaf709f8e2 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -235,7 +235,7 @@ private[spark] class Worker( val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_, self, workerId, host, appDesc.sparkHome.map(userSparkHome => new File(userSparkHome)).getOrElse(sparkHome), - workDir, akkaUrl, ExecutorState.RUNNING) + workDir, akkaUrl, conf, ExecutorState.RUNNING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 8381f59672ea3..6a5ffb1b71bfb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -24,8 +24,10 @@ import scala.xml.Node import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils +import org.apache.spark.Logging +import org.apache.spark.util.logging.{FileAppender, RollingFileAppender} -private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") { +private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with Logging { private val worker = parent.worker private val workDir = parent.workDir @@ -39,21 +41,18 @@ private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") { val offset = Option(request.getParameter("offset")).map(_.toLong) val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) - val path = (appId, executorId, driverId) match { + val logDir = (appId, executorId, driverId) match { case (Some(a), Some(e), None) => - s"${workDir.getPath}/$appId/$executorId/$logType" + s"${workDir.getPath}/$appId/$executorId/" case (None, None, Some(d)) => - s"${workDir.getPath}/$driverId/$logType" + s"${workDir.getPath}/$driverId/" case _ => throw new Exception("Request must specify either application or driver identifiers") } - val (startByte, endByte) = getByteRange(path, offset, byteLength) - val file = new File(path) - val logLength = file.length - - val pre = s"==== Bytes $startByte-$endByte of $logLength of $path ====\n" - pre + Utils.offsetBytes(path, 
startByte, endByte) + val (logText, startByte, endByte, logLength) = getLog(logDir, logType, offset, byteLength) + val pre = s"==== Bytes $startByte-$endByte of $logLength of $logDir$logType ====\n" + pre + logText } def render(request: HttpServletRequest): Seq[Node] = { @@ -65,19 +64,16 @@ private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") { val offset = Option(request.getParameter("offset")).map(_.toLong) val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) - val (path, params) = (appId, executorId, driverId) match { + val (logDir, params) = (appId, executorId, driverId) match { case (Some(a), Some(e), None) => - (s"${workDir.getPath}/$a/$e/$logType", s"appId=$a&executorId=$e") + (s"${workDir.getPath}/$a/$e/", s"appId=$a&executorId=$e") case (None, None, Some(d)) => - (s"${workDir.getPath}/$d/$logType", s"driverId=$d") + (s"${workDir.getPath}/$d/", s"driverId=$d") case _ => throw new Exception("Request must specify either application or driver identifiers") } - val (startByte, endByte) = getByteRange(path, offset, byteLength) - val file = new File(path) - val logLength = file.length - val logText = {Utils.offsetBytes(path, startByte, endByte)} + val (logText, startByte, endByte, logLength) = getLog(logDir, logType, offset, byteLength) val linkToMaster =

<p><a href={worker.activeMasterWebUiUrl}>Back to Master</a></p>

val range = Bytes {startByte.toString} - {endByte.toString} of {logLength} @@ -127,23 +123,37 @@ private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") { UIUtils.basicSparkPage(content, logType + " log page for " + appId) } - /** Determine the byte range for a log or log page. */ - private def getByteRange(path: String, offset: Option[Long], byteLength: Int): (Long, Long) = { - val defaultBytes = 100 * 1024 - val maxBytes = 1024 * 1024 - val file = new File(path) - val logLength = file.length() - val getOffset = offset.getOrElse(logLength - defaultBytes) - val startByte = - if (getOffset < 0) { - 0L - } else if (getOffset > logLength) { - logLength - } else { - getOffset + /** Get the part of the log files given the offset and desired length of bytes */ + private def getLog( + logDirectory: String, + logType: String, + offsetOption: Option[Long], + byteLength: Int + ): (String, Long, Long, Long) = { + try { + val files = RollingFileAppender.getSortedRolledOverFiles(logDirectory, logType) + logDebug(s"Sorted log files of type $logType in $logDirectory:\n${files.mkString("\n")}") + + val totalLength = files.map { _.length }.sum + val offset = offsetOption.getOrElse(totalLength - byteLength) + val startIndex = { + if (offset < 0) { + 0L + } else if (offset > totalLength) { + totalLength + } else { + offset + } } - val logPageLength = math.min(byteLength, maxBytes) - val endByte = math.min(startByte + logPageLength, logLength) - (startByte, endByte) + val endIndex = math.min(startIndex + totalLength, totalLength) + logDebug(s"Getting log from $startIndex to $endIndex") + val logText = Utils.offsetBytes(files, startIndex, endIndex) + logDebug(s"Got log of length ${logText.length} bytes") + (logText, startIndex, endIndex, totalLength) + } catch { + case e: Exception => + logError(s"Error getting $logType logs from directory $logDirectory", e) + ("Error getting logs due to exception: " + e.getMessage, 0, 0, 0) + } } } diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala index 3ffaaab23d0f5..3b6298a26d7c5 100644 --- a/core/src/main/scala/org/apache/spark/network/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/Connection.scala @@ -210,7 +210,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, var nextMessageToBeUsed = 0 def addMessage(message: Message) { - messages.synchronized{ + messages.synchronized { /* messages += message */ messages.enqueue(message) logDebug("Added [" + message + "] to outbox for sending to " + @@ -223,7 +223,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, while (!messages.isEmpty) { /* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */ /* val message = messages(nextMessageToBeUsed) */ - val message = messages.dequeue + val message = messages.dequeue() val chunk = message.getChunkForSending(defaultChunkSize) if (chunk.isDefined) { messages.enqueue(message) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 5dd5fd0047c0d..cf1c985c2fff9 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -250,14 +250,14 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, try { while(!selectorThread.isInterrupted) { while (! 
registerRequests.isEmpty) { - val conn: SendingConnection = registerRequests.dequeue + val conn: SendingConnection = registerRequests.dequeue() addListeners(conn) conn.connect() addConnection(conn) } while(!keyInterestChangeRequests.isEmpty) { - val (key, ops) = keyInterestChangeRequests.dequeue + val (key, ops) = keyInterestChangeRequests.dequeue() try { if (key.isValid) { @@ -532,9 +532,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } return } - var securityMsgResp = SecurityMessage.fromResponse(replyToken, + val securityMsgResp = SecurityMessage.fromResponse(replyToken, securityMsg.getConnectionId.toString()) - var message = securityMsgResp.toBufferMessage + val message = securityMsgResp.toBufferMessage if (message == null) throw new Exception("Error creating security message") sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) } catch { @@ -568,9 +568,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, logDebug("Server sasl not completed: " + connection.connectionId) } if (replyToken != null) { - var securityMsgResp = SecurityMessage.fromResponse(replyToken, + val securityMsgResp = SecurityMessage.fromResponse(replyToken, securityMsg.getConnectionId) - var message = securityMsgResp.toBufferMessage + val message = securityMsgResp.toBufferMessage if (message == null) throw new Exception("Error creating security Message") sendSecurityMessage(connection.getRemoteConnectionManagerId(), message) } @@ -618,7 +618,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, return true } } - return false + false } private def handleMessage( @@ -694,9 +694,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, var firstResponse: Array[Byte] = null try { firstResponse = conn.sparkSaslClient.firstToken() - var securityMsg = SecurityMessage.fromResponse(firstResponse, + val securityMsg = SecurityMessage.fromResponse(firstResponse, conn.connectionId.toString()) - var message = securityMsg.toBufferMessage + val message = securityMsg.toBufferMessage if (message == null) throw new Exception("Error creating security message") connectionsAwaitingSasl += ((conn.connectionId, conn)) sendSecurityMessage(connManagerId, message) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 9ff76892aed32..5951865e56c9d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -27,6 +27,7 @@ import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap} import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ShuffleHandle private[spark] sealed trait CoGroupSplitDep extends Serializable @@ -44,7 +45,7 @@ private[spark] case class NarrowCoGroupSplitDep( } } -private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep +private[spark] case class ShuffleCoGroupSplitDep(handle: ShuffleHandle) extends CoGroupSplitDep private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) extends Partition with Serializable { @@ -74,10 +75,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: private type CoGroupValue = (Any, Int) // Int is dependency number private type CoGroupCombiner = Seq[CoGroup] - private var serializer: 
Serializer = null + private var serializer: Option[Serializer] = None + /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ def setSerializer(serializer: Serializer): CoGroupedRDD[K] = { - this.serializer = serializer + this.serializer = Option(serializer) this } @@ -88,7 +90,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: new OneToOneDependency(rdd) } else { logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency[Any, Any](rdd, part, serializer) + new ShuffleDependency[K, Any, CoGroupCombiner](rdd, part, serializer) } } } @@ -100,8 +102,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) => // Assume each RDD contributed a single dependency, and get it dependencies(j) match { - case s: ShuffleDependency[_, _] => - new ShuffleCoGroupSplitDep(s.shuffleId) + case s: ShuffleDependency[_, _, _] => + new ShuffleCoGroupSplitDep(s.shuffleHandle) case _ => new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) } @@ -126,11 +128,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]] rddIterators += ((it, depNum)) - case ShuffleCoGroupSplitDep(shuffleId) => + case ShuffleCoGroupSplitDep(handle) => // Read map outputs of shuffle - val fetcher = SparkEnv.get.shuffleFetcher - val ser = Serializer.getSerializer(serializer) - val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser) + val it = SparkEnv.get.shuffleManager + .getReader(handle, split.index, split.index + 1, context) + .read() rddIterators += ((it, depNum)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 6547755764dcf..2aa111d600e9b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -139,10 +139,13 @@ class HadoopRDD[K, V]( // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the // local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects. - val newJobConf = new JobConf(broadcastedConf.value.value) - initLocalJobConfFuncOpt.map(f => f(newJobConf)) - HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) - newJobConf + // synchronize to prevent ConcurrentModificationException (Spark-1097, Hadoop-10456) + broadcastedConf.synchronized { + val newJobConf = new JobConf(broadcastedConf.value.value) + initLocalJobConfFuncOpt.map(f => f(newJobConf)) + HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) + newJobConf + } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 8909980957058..b6ad9b6c3e168 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -118,6 +118,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions)) } + /** + * Aggregate the values of each key, using given combine functions and a neutral "zero value". 
+ * This function can return a different result type, U, than the type of the values in this RDD, + * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's, + * as in scala.TraversableOnce. The former operation is used for merging values within a + * partition, and the latter is used for merging values between partitions. To avoid memory + * allocation, both of these functions are allowed to modify and return their first argument + * instead of creating a new U. + */ + def aggregateByKey[U: ClassTag](zeroValue: U, partitioner: Partitioner)(seqOp: (U, V) => U, + combOp: (U, U) => U): RDD[(K, U)] = { + // Serialize the zero value to a byte array so that we can get a new clone of it on each key + val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue) + val zeroArray = new Array[Byte](zeroBuffer.limit) + zeroBuffer.get(zeroArray) + + lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance() + def createZero() = cachedSerializer.deserialize[U](ByteBuffer.wrap(zeroArray)) + + combineByKey[U]((v: V) => seqOp(createZero(), v), seqOp, combOp, partitioner) + } + + /** + * Aggregate the values of each key, using given combine functions and a neutral "zero value". + * This function can return a different result type, U, than the type of the values in this RDD, + * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's, + * as in scala.TraversableOnce. The former operation is used for merging values within a + * partition, and the latter is used for merging values between partitions. To avoid memory + * allocation, both of these functions are allowed to modify and return their first argument + * instead of creating a new U. + */ + def aggregateByKey[U: ClassTag](zeroValue: U, numPartitions: Int)(seqOp: (U, V) => U, + combOp: (U, U) => U): RDD[(K, U)] = { + aggregateByKey(zeroValue, new HashPartitioner(numPartitions))(seqOp, combOp) + } + + /** + * Aggregate the values of each key, using given combine functions and a neutral "zero value". + * This function can return a different result type, U, than the type of the values in this RDD, + * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's, + * as in scala.TraversableOnce. The former operation is used for merging values within a + * partition, and the latter is used for merging values between partitions. To avoid memory + * allocation, both of these functions are allowed to modify and return their first argument + * instead of creating a new U. 
+ */ + def aggregateByKey[U: ClassTag](zeroValue: U)(seqOp: (U, V) => U, + combOp: (U, U) => U): RDD[(K, U)] = { + aggregateByKey(zeroValue, defaultPartitioner(self))(seqOp, combOp) + } + /** * Merge the values for each key using an associative function and a neutral "zero value" which * may be added to the result an arbitrary number of times, and must not change the result diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 54bdc3e7cbc7a..446f369c9ea16 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -42,7 +42,7 @@ import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{BoundedPriorityQueue, Utils} import org.apache.spark.util.collection.OpenHashMap -import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler} +import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils} /** * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable, @@ -378,46 +378,56 @@ abstract class RDD[T: ClassTag]( }.toArray } - def takeSample(withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] = - { - var fraction = 0.0 - var total = 0 - val multiplier = 3.0 - val initialCount = this.count() - var maxSelected = 0 + /** + * Return a fixed-size sampled subset of this RDD in an array + * + * @param withReplacement whether sampling is done with replacement + * @param num size of the returned sample + * @param seed seed for the random number generator + * @return sample of specified size in an array + */ + def takeSample(withReplacement: Boolean, + num: Int, + seed: Long = Utils.random.nextLong): Array[T] = { + val numStDev = 10.0 if (num < 0) { throw new IllegalArgumentException("Negative number of elements requested") + } else if (num == 0) { + return new Array[T](0) } + val initialCount = this.count() if (initialCount == 0) { return new Array[T](0) } - if (initialCount > Integer.MAX_VALUE - 1) { - maxSelected = Integer.MAX_VALUE - 1 - } else { - maxSelected = initialCount.toInt + val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt + if (num > maxSampleSize) { + throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " + + s"$numStDev * math.sqrt(Int.MaxValue)") } - if (num > initialCount && !withReplacement) { - total = maxSelected - fraction = multiplier * (maxSelected + 1) / initialCount - } else { - fraction = multiplier * (num + 1) / initialCount - total = num + val rand = new Random(seed) + if (!withReplacement && num >= initialCount) { + return Utils.randomizeInPlace(this.collect(), rand) } - val rand = new Random(seed) + val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, + withReplacement) + var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() // If the first sample didn't turn out large enough, keep trying to take samples; // this shouldn't happen often because we use a big multiplier for the initial size - while (samples.length < total) { + var numIters = 0 + while (samples.length < num) { + logWarning(s"Needed to re-sample due to insufficient sample size. 
Repeat #$numIters") samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + numIters += 1 } - Utils.randomizeInPlace(samples, rand).take(total) + Utils.randomizeInPlace(samples, rand).take(num) } /** @@ -1180,7 +1190,7 @@ abstract class RDD[T: ClassTag]( /** User code that created this RDD (e.g. `textFile`, `parallelize`). */ @transient private[spark] val creationSiteInfo = Utils.getCallSiteInfo - private[spark] def getCreationSite: String = creationSiteInfo.toString + private[spark] def getCreationSite: String = Option(creationSiteInfo).getOrElse("").toString private[spark] def elementClassTag: ClassTag[T] = classTag[T] diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 802b0bdfb2d59..bb108ef163c56 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -42,10 +42,11 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag]( part: Partitioner) extends RDD[P](prev.context, Nil) { - private var serializer: Serializer = null + private var serializer: Option[Serializer] = None + /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = { - this.serializer = serializer + this.serializer = Option(serializer) this } @@ -60,9 +61,10 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag]( } override def compute(split: Partition, context: TaskContext): Iterator[P] = { - val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId - val ser = Serializer.getSerializer(serializer) - SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, ser) + val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, V]] + SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) + .read() + .asInstanceOf[Iterator[P]] } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index 9a09c05bbc959..ed24ea22a661c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -54,10 +54,11 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) { - private var serializer: Serializer = null + private var serializer: Option[Serializer] = None + /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = { - this.serializer = serializer + this.serializer = Option(serializer) this } @@ -79,8 +80,8 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( // Each CoGroupPartition will depend on rdd1 and rdd2 array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) => dependencies(j) match { - case s: ShuffleDependency[_, _] => - new ShuffleCoGroupSplitDep(s.shuffleId) + case s: ShuffleDependency[_, _, _] => + new ShuffleCoGroupSplitDep(s.shuffleHandle) case _ => new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) } @@ -93,7 +94,6 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = { val partition = p.asInstanceOf[CoGroupPartition] - 
val ser = Serializer.getSerializer(serializer) val map = new JHashMap[K, ArrayBuffer[V]] def getSeq(k: K): ArrayBuffer[V] = { val seq = map.get(k) @@ -109,9 +109,10 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( case NarrowCoGroupSplitDep(rdd, _, itsSplit) => rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op) - case ShuffleCoGroupSplitDep(shuffleId) => - val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index, - context, ser) + case ShuffleCoGroupSplitDep(handle) => + val iter = SparkEnv.get.shuffleManager + .getReader(handle, partition.index, partition.index + 1, context) + .read() iter.foreach(op) } // the first dep is rdd1; add all values to the map diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index e09a4221e8315..3c85b5a2ae776 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -190,7 +190,7 @@ class DAGScheduler( * The jobId value passed in will be used if the stage doesn't already exist with * a lower jobId (jobId always increases across jobs.) */ - private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], jobId: Int): Stage = { + private def getShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): Stage = { shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => @@ -210,7 +210,7 @@ class DAGScheduler( private def newStage( rdd: RDD[_], numTasks: Int, - shuffleDep: Option[ShuffleDependency[_,_]], + shuffleDep: Option[ShuffleDependency[_, _, _]], jobId: Int, callSite: Option[String] = None) : Stage = @@ -233,7 +233,7 @@ class DAGScheduler( private def newOrUsedStage( rdd: RDD[_], numTasks: Int, - shuffleDep: ShuffleDependency[_,_], + shuffleDep: ShuffleDependency[_, _, _], jobId: Int, callSite: Option[String] = None) : Stage = @@ -269,7 +269,7 @@ class DAGScheduler( // we can't do it in its constructor because # of partitions is unknown for (dep <- r.dependencies) { dep match { - case shufDep: ShuffleDependency[_,_] => + case shufDep: ShuffleDependency[_, _, _] => parents += getShuffleMapStage(shufDep, jobId) case _ => visit(dep.rdd) @@ -290,7 +290,7 @@ class DAGScheduler( if (getCacheLocs(rdd).contains(Nil)) { for (dep <- rdd.dependencies) { dep match { - case shufDep: ShuffleDependency[_,_] => + case shufDep: ShuffleDependency[_, _, _] => val mapStage = getShuffleMapStage(shufDep, stage.jobId) if (!mapStage.isAvailable) { missing += mapStage @@ -1088,7 +1088,7 @@ class DAGScheduler( visitedRdds += rdd for (dep <- rdd.dependencies) { dep match { - case shufDep: ShuffleDependency[_,_] => + case shufDep: ShuffleDependency[_, _, _] => val mapStage = getShuffleMapStage(shufDep, stage.jobId) if (!mapStage.isAvailable) { visitedStages += mapStage diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index ed0f56f1abdf5..0098b5a59d1a5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -29,6 +29,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.rdd.{RDD, RDDCheckpointData} import org.apache.spark.serializer.Serializer import org.apache.spark.storage._ +import org.apache.spark.shuffle.ShuffleWriter 
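// --- Illustrative sketch (editorial example, not part of this patch) ---------------------
// With the three-type-parameter ShuffleDependency introduced earlier in this patch, the
// optional serializer, key ordering and aggregator travel on the dependency itself instead
// of being resolved at fetch time, and the constructor registers the shuffle with SparkEnv's
// ShuffleManager to obtain the shuffleHandle that map and reduce tasks pass to
// getWriter/getReader. A rough construction, assuming an existing pairRdd: RDD[(String, Int)],
// an Aggregator[String, Int, Int] named `agg` (hypothetical), and imports of
// org.apache.spark.{HashPartitioner, ShuffleDependency}:
val dep = new ShuffleDependency[String, Int, Int](
  pairRdd,
  new HashPartitioner(8),
  serializer = None,                    // fall back to the spark.serializer default
  keyOrdering = Some(Ordering.String),  // request sorted keys
  aggregator = Some(agg))               // enable map-side combining
// dep.shuffleHandle is what ShuffleMapTask and the shuffle-reading RDDs use below.
// ------------------------------------------------------------------------------------------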
private[spark] object ShuffleMapTask { @@ -37,7 +38,7 @@ private[spark] object ShuffleMapTask { // expensive on the master node if it needs to launch thousands of tasks. private val serializedInfoCache = new HashMap[Int, Array[Byte]] - def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { + def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_, _, _]): Array[Byte] = { synchronized { val old = serializedInfoCache.get(stageId).orNull if (old != null) { @@ -56,12 +57,12 @@ private[spark] object ShuffleMapTask { } } - def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_]) = { + def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_, _, _]) = { val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) val ser = SparkEnv.get.closureSerializer.newInstance() val objIn = ser.deserializeStream(in) val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]] + val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_, _, _]] (rdd, dep) } @@ -99,7 +100,7 @@ private[spark] object ShuffleMapTask { private[spark] class ShuffleMapTask( stageId: Int, var rdd: RDD[_], - var dep: ShuffleDependency[_,_], + var dep: ShuffleDependency[_, _, _], _partitionId: Int, @transient private var locs: Seq[TaskLocation]) extends Task[MapStatus](stageId, _partitionId) @@ -141,66 +142,22 @@ private[spark] class ShuffleMapTask( } override def runTask(context: TaskContext): MapStatus = { - val numOutputSplits = dep.partitioner.numPartitions metrics = Some(context.taskMetrics) - - val blockManager = SparkEnv.get.blockManager - val shuffleBlockManager = blockManager.shuffleBlockManager - var shuffle: ShuffleWriterGroup = null - var success = false - + var writer: ShuffleWriter[Any, Any] = null try { - // Obtain all the block writers for shuffle blocks. - val ser = Serializer.getSerializer(dep.serializer) - shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser) - - // Write the map output to its associated buckets. + val manager = SparkEnv.get.shuffleManager + writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) for (elem <- rdd.iterator(split, context)) { - val pair = elem.asInstanceOf[Product2[Any, Any]] - val bucketId = dep.partitioner.getPartition(pair._1) - shuffle.writers(bucketId).write(pair) - } - - // Commit the writes. Get the size of each bucket block (total block size). - var totalBytes = 0L - var totalTime = 0L - val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter => - writer.commit() - writer.close() - val size = writer.fileSegment().length - totalBytes += size - totalTime += writer.timeWriting() - MapOutputTracker.compressSize(size) + writer.write(elem.asInstanceOf[Product2[Any, Any]]) } - - // Update shuffle metrics. - val shuffleMetrics = new ShuffleWriteMetrics - shuffleMetrics.shuffleBytesWritten = totalBytes - shuffleMetrics.shuffleWriteTime = totalTime - metrics.get.shuffleWriteMetrics = Some(shuffleMetrics) - - success = true - new MapStatus(blockManager.blockManagerId, compressedSizes) - } catch { case e: Exception => - // If there is an exception from running the task, revert the partial writes - // and throw the exception upstream to Spark. 
- if (shuffle != null && shuffle.writers != null) { - for (writer <- shuffle.writers) { - writer.revertPartialWrites() - writer.close() + return writer.stop(success = true).get + } catch { + case e: Exception => + if (writer != null) { + writer.stop(success = false) } - } - throw e + throw e } finally { - // Release the writers back to the shuffle block manager. - if (shuffle != null && shuffle.writers != null) { - try { - shuffle.releaseWriters(success) - } catch { - case e: Exception => logError("Failed to release shuffle writers", e) - } - } - // Execute the callbacks on task completion. context.executeOnCompleteCallbacks() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 5c1fc30e4a557..3bf9713f728c6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -40,7 +40,7 @@ private[spark] class Stage( val id: Int, val rdd: RDD[_], val numTasks: Int, - val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage + val shuffleDep: Option[ShuffleDependency[_, _, _]], // Output shuffle if stage is a map stage val parents: List[Stage], val jobId: Int, callSite: Option[String]) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 99d305b36a959..df59f444b7a0e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -71,7 +71,8 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul val loader = Thread.currentThread.getContextClassLoader taskSetManager.abort("ClassNotFound with classloader: " + loader) case ex: Exception => - taskSetManager.abort("Exception while deserializing and fetching task: %s".format(ex)) + logError("Exception while getting task result", ex) + taskSetManager.abort("Exception while getting task result: %s".format(ex)) } } }) diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index ee26970a3d874..f2f5cea469c61 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -52,6 +52,10 @@ object Serializer { def getSerializer(serializer: Serializer): Serializer = { if (serializer == null) SparkEnv.get.serializer else serializer } + + def getSerializer(serializer: Option[Serializer]): Serializer = { + serializer.getOrElse(SparkEnv.get.serializer) + } } diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala similarity index 66% rename from core/src/main/scala/org/apache/spark/ShuffleFetcher.scala rename to core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala index a4f69b6b22b2c..b36c457d6d514 100644 --- a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala @@ -15,22 +15,16 @@ * limitations under the License. 
*/ -package org.apache.spark +package org.apache.spark.shuffle +import org.apache.spark.{ShuffleDependency, Aggregator, Partitioner} import org.apache.spark.serializer.Serializer -private[spark] abstract class ShuffleFetcher { - - /** - * Fetch the shuffle outputs for a given ShuffleDependency. - * @return An iterator over the elements of the fetched shuffle outputs. - */ - def fetch[T]( - shuffleId: Int, - reduceId: Int, - context: TaskContext, - serializer: Serializer = SparkEnv.get.serializer): Iterator[T] - - /** Stop the fetcher */ - def stop() {} -} +/** + * A basic ShuffleHandle implementation that just captures registerShuffle's parameters. + */ +private[spark] class BaseShuffleHandle[K, V, C]( + shuffleId: Int, + val numMaps: Int, + val dependency: ShuffleDependency[K, V, C]) + extends ShuffleHandle(shuffleId) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala new file mode 100644 index 0000000000000..13c7115f88afa --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +/** + * An opaque handle to a shuffle, used by a ShuffleManager to pass information about it to tasks. + * + * @param shuffleId ID of the shuffle + */ +private[spark] abstract class ShuffleHandle(val shuffleId: Int) extends Serializable {} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala new file mode 100644 index 0000000000000..9c859b8b4a118 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import org.apache.spark.{TaskContext, ShuffleDependency} + +/** + * Pluggable interface for shuffle systems. 
A ShuffleManager is created in SparkEnv on both the + * driver and executors, based on the spark.shuffle.manager setting. The driver registers shuffles + * with it, and executors (or tasks running locally in the driver) can ask to read and write data. + * + * NOTE: this will be instantiated by SparkEnv so its constructor can take a SparkConf and + * boolean isDriver as parameters. + */ +private[spark] trait ShuffleManager { + /** + * Register a shuffle with the manager and obtain a handle for it to pass to tasks. + */ + def registerShuffle[K, V, C]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle + + /** Get a writer for a given partition. Called on executors by map tasks. */ + def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V] + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. + */ + def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): ShuffleReader[K, C] + + /** Remove a shuffle's metadata from the ShuffleManager. */ + def unregisterShuffle(shuffleId: Int) + + /** Shut down this ShuffleManager. */ + def stop(): Unit +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala new file mode 100644 index 0000000000000..b30e366d06006 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +/** + * Obtained inside a reduce task to read combined records from the mappers. + */ +private[spark] trait ShuffleReader[K, C] { + /** Read the combined key-values for this reduce task */ + def read(): Iterator[Product2[K, C]] + + /** Close this reader */ + def stop(): Unit +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala new file mode 100644 index 0000000000000..ead3ebd652ca5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import org.apache.spark.scheduler.MapStatus + +/** + * Obtained inside a map task to write out records to the shuffle system. + */ +private[spark] trait ShuffleWriter[K, V] { + /** Write a record to this task's output */ + def write(record: Product2[K, V]): Unit + + /** Close this writer, passing along whether the map completed */ + def stop(success: Boolean): Option[MapStatus] +} diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala similarity index 96% rename from core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala rename to core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index a67392441ed29..b05b6ea345df3 100644 --- a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.shuffle.hash import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap @@ -24,17 +24,16 @@ import org.apache.spark.executor.ShuffleReadMetrics import org.apache.spark.serializer.Serializer import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util.CompletionIterator +import org.apache.spark._ -private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { - - override def fetch[T]( +private[hash] object BlockStoreShuffleFetcher extends Logging { + def fetch[T]( shuffleId: Int, reduceId: Int, context: TaskContext, serializer: Serializer) : Iterator[T] = { - logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) val blockManager = SparkEnv.get.blockManager diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala new file mode 100644 index 0000000000000..5b0940ecce29d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
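Taken together, ShuffleHandle, ShuffleManager, ShuffleReader and ShuffleWriter form the pluggable surface that the hash-based implementation below plugs into. As a rough, self-contained illustration of how the pieces fit, here is a deliberately simplified copy of the traits plus a do-nothing manager; the names (Handle, Manager, NoOpManager, ...) are invented for the sketch and drop the ShuffleDependency, TaskContext and MapStatus parameters of the real interfaces.

    object PluggableShuffleSketch {
      // Simplified stand-ins for the pluggable-shuffle traits, for illustration only.
      trait Handle { def shuffleId: Int }
      trait Reader[K, C] { def read(): Iterator[Product2[K, C]] }
      trait Writer[K, V] {
        def write(record: Product2[K, V]): Unit
        def stop(success: Boolean): Unit
      }
      trait Manager {
        def registerShuffle(shuffleId: Int, numMaps: Int): Handle
        def getWriter[K, V](handle: Handle, mapId: Int): Writer[K, V]
        def getReader[K, C](handle: Handle, startPartition: Int, endPartition: Int): Reader[K, C]
        def unregisterShuffle(shuffleId: Int): Unit
        def stop(): Unit
      }

      // A do-nothing manager: registerShuffle only captures its arguments (the same idea as
      // BaseShuffleHandle), and the writers/readers are inert.
      class NoOpManager extends Manager {
        private case class SimpleHandle(shuffleId: Int, numMaps: Int) extends Handle
        override def registerShuffle(shuffleId: Int, numMaps: Int): Handle =
          SimpleHandle(shuffleId, numMaps)
        override def getWriter[K, V](handle: Handle, mapId: Int): Writer[K, V] =
          new Writer[K, V] {
            def write(record: Product2[K, V]): Unit = ()   // drop the record
            def stop(success: Boolean): Unit = ()
          }
        override def getReader[K, C](handle: Handle, startPartition: Int,
            endPartition: Int): Reader[K, C] =
          new Reader[K, C] { def read(): Iterator[Product2[K, C]] = Iterator.empty }
        override def unregisterShuffle(shuffleId: Int): Unit = ()
        override def stop(): Unit = ()
      }
    }

The design point worth noting is that the handle returned by registerShuffle is the only thing the driver hands to tasks; everything else is resolved on the executor through the ShuffleManager instance that SparkEnv creates from spark.shuffle.manager.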
+ */ + +package org.apache.spark.shuffle.hash + +import org.apache.spark._ +import org.apache.spark.shuffle._ + +/** + * A ShuffleManager using hashing, that creates one output file per reduce partition on each + * mapper (possibly reusing these across waves of tasks). + */ +class HashShuffleManager(conf: SparkConf) extends ShuffleManager { + /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */ + override def registerShuffle[K, V, C]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. + */ + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): ShuffleReader[K, C] = { + new HashShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) + } + + /** Get a writer for a given partition. Called on executors by map tasks. */ + override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext) + : ShuffleWriter[K, V] = { + new HashShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) + } + + /** Remove a shuffle's metadata from the ShuffleManager. */ + override def unregisterShuffle(shuffleId: Int): Unit = {} + + /** Shut down this ShuffleManager. */ + override def stop(): Unit = {} +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala new file mode 100644 index 0000000000000..f6a790309a587 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.hash + +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.TaskContext + +class HashShuffleReader[K, C]( + handle: BaseShuffleHandle[K, _, C], + startPartition: Int, + endPartition: Int, + context: TaskContext) + extends ShuffleReader[K, C] +{ + require(endPartition == startPartition + 1, + "Hash shuffle currently only supports fetching one partition") + + /** Read the combined key-values for this reduce task */ + override def read(): Iterator[Product2[K, C]] = { + BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, + Serializer.getSerializer(handle.dependency.serializer)) + } + + /** Close this reader */ + override def stop(): Unit = ??? 
+} diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala new file mode 100644 index 0000000000000..4c6749098c110 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.hash + +import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter} +import org.apache.spark.{Logging, MapOutputTracker, SparkEnv, TaskContext} +import org.apache.spark.storage.{BlockObjectWriter} +import org.apache.spark.serializer.Serializer +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.scheduler.MapStatus + +class HashShuffleWriter[K, V]( + handle: BaseShuffleHandle[K, V, _], + mapId: Int, + context: TaskContext) + extends ShuffleWriter[K, V] with Logging { + + private val dep = handle.dependency + private val numOutputSplits = dep.partitioner.numPartitions + private val metrics = context.taskMetrics + private var stopping = false + + private val blockManager = SparkEnv.get.blockManager + private val shuffleBlockManager = blockManager.shuffleBlockManager + private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null)) + private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser) + + /** Write a record to this task's output */ + override def write(record: Product2[K, V]): Unit = { + val pair = record.asInstanceOf[Product2[Any, Any]] + val bucketId = dep.partitioner.getPartition(pair._1) + shuffle.writers(bucketId).write(pair) + } + + /** Close this writer, passing along whether the map completed */ + override def stop(success: Boolean): Option[MapStatus] = { + try { + if (stopping) { + return None + } + stopping = true + if (success) { + try { + return Some(commitWritesAndBuildStatus()) + } catch { + case e: Exception => + revertWrites() + throw e + } + } else { + revertWrites() + return None + } + } finally { + // Release the writers back to the shuffle block manager. + if (shuffle != null && shuffle.writers != null) { + try { + shuffle.releaseWriters(success) + } catch { + case e: Exception => logError("Failed to release shuffle writers", e) + } + } + } + } + + private def commitWritesAndBuildStatus(): MapStatus = { + // Commit the writes. Get the size of each bucket block (total block size). + var totalBytes = 0L + var totalTime = 0L + val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter => + writer.commit() + writer.close() + val size = writer.fileSegment().length + totalBytes += size + totalTime += writer.timeWriting() + MapOutputTracker.compressSize(size) + } + + // Update shuffle metrics. 
+ val shuffleMetrics = new ShuffleWriteMetrics + shuffleMetrics.shuffleBytesWritten = totalBytes + shuffleMetrics.shuffleWriteTime = totalTime + metrics.shuffleWriteMetrics = Some(shuffleMetrics) + + new MapStatus(blockManager.blockManagerId, compressedSizes) + } + + private def revertWrites(): Unit = { + if (shuffle != null && shuffle.writers != null) { + for (writer <- shuffle.writers) { + writer.revertPartialWrites() + writer.close() + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala index c8f397609a0b4..22fdf73e9d1f4 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala @@ -29,9 +29,9 @@ private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolea setInitThread() private def setInitThread() { - // Set current thread as init thread - waitForReady will not block this thread - // (in case there is non trivial initialization which ends up calling waitForReady as part of - // initialization itself) + /* Set current thread as init thread - waitForReady will not block this thread + * (in case there is non trivial initialization which ends up calling waitForReady + * as part of initialization itself) */ BlockInfo.blockInfoInitThreads.put(this, Thread.currentThread()) } @@ -42,7 +42,9 @@ private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolea def waitForReady(): Boolean = { if (pending && initThread != Thread.currentThread()) { synchronized { - while (pending) this.wait() + while (pending) { + this.wait() + } } } !failed @@ -50,8 +52,8 @@ private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolea /** Mark this BlockInfo as ready (i.e. block is finished writing) */ def markReady(sizeInBytes: Long) { - require (sizeInBytes >= 0, "sizeInBytes was negative: " + sizeInBytes) - assert (pending) + require(sizeInBytes >= 0, s"sizeInBytes was negative: $sizeInBytes") + assert(pending) size = sizeInBytes BlockInfo.blockInfoInitThreads.remove(this) synchronized { @@ -61,7 +63,7 @@ private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolea /** Mark this BlockInfo as ready but failed */ def markFailure() { - assert (pending) + assert(pending) size = BlockInfo.BLOCK_FAILED BlockInfo.blockInfoInitThreads.remove(this) synchronized { @@ -71,9 +73,9 @@ private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolea } private object BlockInfo { - // initThread is logically a BlockInfo field, but we store it here because - // it's only needed while this block is in the 'pending' state and we want - // to minimize BlockInfo's memory footprint. + /* initThread is logically a BlockInfo field, but we store it here because + * it's only needed while this block is in the 'pending' state and we want + * to minimize BlockInfo's memory footprint. 
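Most of the BlockInfo edits above are cosmetic (braces around the wait loop, interpolated messages), but they sit on top of a small concurrency idiom that is easy to miss: one thread marks the block ready or failed exactly once, and every other thread blocks in waitForReady until then. A minimal standalone sketch of that wait/notifyAll handshake, ignoring BlockInfo's init-thread optimization (class and method names here are invented for the sketch):

    object ReadyLatchSketch {
      class PendingResult {
        private var pending = true
        private var failed = false

        def waitForReady(): Boolean = synchronized {
          while (pending) {
            wait()               // loop to guard against spurious wake-ups
          }
          !failed
        }

        def markReady(): Unit = synchronized {
          pending = false
          notifyAll()
        }

        def markFailed(): Unit = synchronized {
          failed = true
          pending = false
          notifyAll()
        }
      }
    }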
*/ private val blockInfoInitThreads = new ConcurrentHashMap[BlockInfo, Thread] private val BLOCK_PENDING: Long = -1L diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 9cd79d262ea53..f52bc7075104b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -28,46 +28,48 @@ import scala.util.Random import akka.actor.{ActorSystem, Cancellable, Props} import sun.nio.ch.DirectBuffer -import org.apache.spark.{Logging, MapOutputTracker, SecurityManager, SparkConf, SparkEnv, SparkException} +import org.apache.spark._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer import org.apache.spark.util._ -private[spark] sealed trait Values - -private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends Values -private[spark] case class IteratorValues(iterator: Iterator[Any]) extends Values -private[spark] case class ArrayBufferValues(buffer: ArrayBuffer[Any]) extends Values +private[spark] sealed trait BlockValues +private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues +private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues +private[spark] case class ArrayBufferValues(buffer: ArrayBuffer[Any]) extends BlockValues private[spark] class BlockManager( executorId: String, actorSystem: ActorSystem, val master: BlockManagerMaster, - val defaultSerializer: Serializer, + defaultSerializer: Serializer, maxMemory: Long, - val _conf: SparkConf, + val conf: SparkConf, securityManager: SecurityManager, mapOutputTracker: MapOutputTracker) extends Logging { - def conf = _conf val shuffleBlockManager = new ShuffleBlockManager(this) val diskBlockManager = new DiskBlockManager(shuffleBlockManager, - conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) + conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) + val connectionManager = new ConnectionManager(0, conf, securityManager) + + implicit val futureExecContext = connectionManager.futureExecContext private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] + // Actual storage of where blocks are kept + private var tachyonInitialized = false private[storage] val memoryStore = new MemoryStore(this, maxMemory) private[storage] val diskStore = new DiskStore(this, diskBlockManager) - var tachyonInitialized = false private[storage] lazy val tachyonStore: TachyonStore = { val storeDir = conf.get("spark.tachyonStore.baseDir", "/tmp_spark_tachyon") val appFolderName = conf.get("spark.tachyonStore.folderName") - val tachyonStorePath = s"${storeDir}/${appFolderName}/${this.executorId}" + val tachyonStorePath = s"$storeDir/$appFolderName/${this.executorId}" val tachyonMaster = conf.get("spark.tachyonStore.url", "tachyon://localhost:19998") - val tachyonBlockManager = new TachyonBlockManager( - shuffleBlockManager, tachyonStorePath, tachyonMaster) + val tachyonBlockManager = + new TachyonBlockManager(shuffleBlockManager, tachyonStorePath, tachyonMaster) tachyonInitialized = true new TachyonStore(this, tachyonBlockManager) } @@ -79,43 +81,39 @@ private[spark] class BlockManager( if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0 } - val connectionManager = new ConnectionManager(0, conf, securityManager) - implicit val futureExecContext = connectionManager.futureExecContext - val blockManagerId = 
BlockManagerId( executorId, connectionManager.id.host, connectionManager.id.port, nettyPort) // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory // for receiving shuffle outputs) - val maxBytesInFlight = - conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024 + val maxBytesInFlight = conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024 // Whether to compress broadcast variables that are stored - val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true) + private val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true) // Whether to compress shuffle output that are stored - val compressShuffle = conf.getBoolean("spark.shuffle.compress", true) + private val compressShuffle = conf.getBoolean("spark.shuffle.compress", true) // Whether to compress RDD partitions that are stored serialized - val compressRdds = conf.getBoolean("spark.rdd.compress", false) + private val compressRdds = conf.getBoolean("spark.rdd.compress", false) // Whether to compress shuffle output temporarily spilled to disk - val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) + private val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) - val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf) - - val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this, mapOutputTracker)), + private val slaveActor = actorSystem.actorOf( + Props(new BlockManagerSlaveActor(this, mapOutputTracker)), name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) - // Pending re-registration action being executed asynchronously or null if none - // is pending. Accesses should synchronize on asyncReregisterLock. - var asyncReregisterTask: Future[Unit] = null - val asyncReregisterLock = new Object + // Pending re-registration action being executed asynchronously or null if none is pending. + // Accesses should synchronize on asyncReregisterLock. + private var asyncReregisterTask: Future[Unit] = null + private val asyncReregisterLock = new Object - private def heartBeat() { + private def heartBeat(): Unit = { if (!master.sendHeartBeat(blockManagerId)) { reregister() } } - var heartBeatTask: Cancellable = null + private val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf) + private var heartBeatTask: Cancellable = null private val metadataCleaner = new MetadataCleaner( MetadataCleanerType.BLOCK_MANAGER, this.dropOldNonBroadcastBlocks, conf) @@ -124,11 +122,11 @@ private[spark] class BlockManager( initialize() - // The compression codec to use. Note that the "lazy" val is necessary because we want to delay - // the initialization of the compression codec until it is first used. The reason is that a Spark - // program could be using a user-defined codec in a third party jar, which is loaded in - // Executor.updateDependencies. When the BlockManager is initialized, user level jars hasn't been - // loaded yet. + /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay + * the initialization of the compression codec until it is first used. The reason is that a Spark + * program could be using a user-defined codec in a third party jar, which is loaded in + * Executor.updateDependencies. When the BlockManager is initialized, user level jars hasn't been + * loaded yet. 
*/ private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) /** @@ -150,7 +148,7 @@ private[spark] class BlockManager( * Initialize the BlockManager. Register to the BlockManagerMaster, and start the * BlockManagerWorker actor. */ - private def initialize() { + private def initialize(): Unit = { master.registerBlockManager(blockManagerId, maxMemory, slaveActor) BlockManagerWorker.startBlockManagerWorker(this) if (!BlockManager.getDisableHeartBeatsForTesting(conf)) { @@ -170,12 +168,12 @@ private[spark] class BlockManager( * heart beat attempt or new block registration and another try to re-register all blocks * will be made then. */ - private def reportAllBlocks() { - logInfo("Reporting " + blockInfo.size + " blocks to the master.") + private def reportAllBlocks(): Unit = { + logInfo(s"Reporting ${blockInfo.size} blocks to the master.") for ((blockId, info) <- blockInfo) { val status = getCurrentBlockStatus(blockId, info) if (!tryToReportBlockStatus(blockId, info, status)) { - logError("Failed to report " + blockId + " to master; giving up.") + logError(s"Failed to report $blockId to master; giving up.") return } } @@ -187,7 +185,7 @@ private[spark] class BlockManager( * * Note that this method must be called without any BlockInfo locks held. */ - def reregister() { + private def reregister(): Unit = { // TODO: We might need to rate limit re-registering. logInfo("BlockManager re-registering with master") master.registerBlockManager(blockManagerId, maxMemory, slaveActor) @@ -197,7 +195,7 @@ private[spark] class BlockManager( /** * Re-register with the master sometime soon. */ - def asyncReregister() { + private def asyncReregister(): Unit = { asyncReregisterLock.synchronized { if (asyncReregisterTask == null) { asyncReregisterTask = Future[Unit] { @@ -213,7 +211,7 @@ private[spark] class BlockManager( /** * For testing. Wait for any pending asynchronous re-registration; otherwise, do nothing. */ - def waitForAsyncReregister() { + def waitForAsyncReregister(): Unit = { val task = asyncReregisterTask if (task != null) { Await.ready(task, Duration.Inf) @@ -251,18 +249,18 @@ private[spark] class BlockManager( * it is still valid). This ensures that update in master will compensate for the increase in * memory on slave. */ - def reportBlockStatus( + private def reportBlockStatus( blockId: BlockId, info: BlockInfo, status: BlockStatus, - droppedMemorySize: Long = 0L) { + droppedMemorySize: Long = 0L): Unit = { val needReregister = !tryToReportBlockStatus(blockId, info, status, droppedMemorySize) if (needReregister) { - logInfo("Got told to re-register updating block " + blockId) + logInfo(s"Got told to re-register updating block $blockId") // Re-registering will report our new block for free. asyncReregister() } - logDebug("Told master about block " + blockId) + logDebug(s"Told master about block $blockId") } /** @@ -293,10 +291,10 @@ private[spark] class BlockManager( * and the updated in-memory and on-disk sizes. 
*/ private def getCurrentBlockStatus(blockId: BlockId, info: BlockInfo): BlockStatus = { - val (newLevel, inMemSize, onDiskSize, inTachyonSize) = info.synchronized { + info.synchronized { info.level match { case null => - (StorageLevel.NONE, 0L, 0L, 0L) + BlockStatus(StorageLevel.NONE, 0L, 0L, 0L) case level => val inMem = level.useMemory && memoryStore.contains(blockId) val inTachyon = level.useOffHeap && tachyonStore.contains(blockId) @@ -307,19 +305,18 @@ private[spark] class BlockManager( val memSize = if (inMem) memoryStore.getSize(blockId) else 0L val tachyonSize = if (inTachyon) tachyonStore.getSize(blockId) else 0L val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L - (storageLevel, memSize, diskSize, tachyonSize) + BlockStatus(storageLevel, memSize, diskSize, tachyonSize) } } - BlockStatus(newLevel, inMemSize, onDiskSize, inTachyonSize) } /** * Get locations of an array of blocks. */ - def getLocationBlockIds(blockIds: Array[BlockId]): Array[Seq[BlockManagerId]] = { + private def getLocationBlockIds(blockIds: Array[BlockId]): Array[Seq[BlockManagerId]] = { val startTimeMs = System.currentTimeMillis val locations = master.getLocations(blockIds).toArray - logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) + logDebug("Got multiple block location in %s".format(Utils.getUsedTimeMs(startTimeMs))) locations } @@ -329,15 +326,16 @@ private[spark] class BlockManager( * never deletes (recent) items. */ def getLocalFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { - diskStore.getValues(blockId, serializer).orElse( - sys.error("Block " + blockId + " not found on disk, though it should be")) + diskStore.getValues(blockId, serializer).orElse { + throw new BlockException(blockId, s"Block $blockId not found on disk, though it should be") + } } /** * Get block from local block manager. */ def getLocal(blockId: BlockId): Option[Iterator[Any]] = { - logDebug("Getting local block " + blockId) + logDebug(s"Getting local block $blockId") doGetLocal(blockId, asValues = true).asInstanceOf[Option[Iterator[Any]]] } @@ -345,7 +343,7 @@ private[spark] class BlockManager( * Get block from the local block manager as serialized bytes. */ def getLocalBytes(blockId: BlockId): Option[ByteBuffer] = { - logDebug("Getting local block " + blockId + " as bytes") + logDebug(s"Getting local block $blockId as bytes") // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work if (blockId.isShuffle) { @@ -353,7 +351,8 @@ private[spark] class BlockManager( case Some(bytes) => Some(bytes) case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") + throw new BlockException( + blockId, s"Block $blockId not found on disk, though it should be") } } else { doGetLocal(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]] @@ -368,16 +367,16 @@ private[spark] class BlockManager( // If another thread is writing the block, wait for it to become ready. if (!info.waitForReady()) { // If we get here, the block write failed. 
- logWarning("Block " + blockId + " was marked as failure.") + logWarning(s"Block $blockId was marked as failure.") return None } val level = info.level - logDebug("Level for block " + blockId + " is " + level) + logDebug(s"Level for block $blockId is $level") // Look for the block in memory if (level.useMemory) { - logDebug("Getting block " + blockId + " from memory") + logDebug(s"Getting block $blockId from memory") val result = if (asValues) { memoryStore.getValues(blockId) } else { @@ -387,51 +386,51 @@ private[spark] class BlockManager( case Some(values) => return Some(values) case None => - logDebug("Block " + blockId + " not found in memory") + logDebug(s"Block $blockId not found in memory") } } // Look for the block in Tachyon if (level.useOffHeap) { - logDebug("Getting block " + blockId + " from tachyon") + logDebug(s"Getting block $blockId from tachyon") if (tachyonStore.contains(blockId)) { tachyonStore.getBytes(blockId) match { - case Some(bytes) => { + case Some(bytes) => if (!asValues) { return Some(bytes) } else { return Some(dataDeserialize(blockId, bytes)) } - } case None => - logDebug("Block " + blockId + " not found in tachyon") + logDebug(s"Block $blockId not found in tachyon") } } } - // Look for block on disk, potentially storing it back into memory if required: + // Look for block on disk, potentially storing it back in memory if required if (level.useDisk) { - logDebug("Getting block " + blockId + " from disk") + logDebug(s"Getting block $blockId from disk") val bytes: ByteBuffer = diskStore.getBytes(blockId) match { - case Some(bytes) => bytes + case Some(b) => b case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") + throw new BlockException( + blockId, s"Block $blockId not found on disk, though it should be") } - assert (0 == bytes.position()) + assert(0 == bytes.position()) if (!level.useMemory) { - // If the block shouldn't be stored in memory, we can just return it: + // If the block shouldn't be stored in memory, we can just return it if (asValues) { return Some(dataDeserialize(blockId, bytes)) } else { return Some(bytes) } } else { - // Otherwise, we also have to store something in the memory store: + // Otherwise, we also have to store something in the memory store if (!level.deserialized || !asValues) { - // We'll store the bytes in memory if the block's storage level includes - // "memory serialized", or if it should be cached as objects in memory - // but we only requested its serialized bytes: + /* We'll store the bytes in memory if the block's storage level includes + * "memory serialized", or if it should be cached as objects in memory + * but we only requested its serialized bytes. */ val copyForMemory = ByteBuffer.allocate(bytes.limit) copyForMemory.put(bytes) memoryStore.putBytes(blockId, copyForMemory, level) @@ -442,16 +441,17 @@ private[spark] class BlockManager( } else { val values = dataDeserialize(blockId, bytes) if (level.deserialized) { - // Cache the values before returning them: + // Cache the values before returning them // TODO: Consider creating a putValues that also takes in a iterator? 
val valuesBuffer = new ArrayBuffer[Any] valuesBuffer ++= values - memoryStore.putValues(blockId, valuesBuffer, level, true).data match { - case Left(values2) => - return Some(values2) - case _ => - throw new Exception("Memory store did not return back an iterator") - } + memoryStore.putValues(blockId, valuesBuffer, level, returnValues = true).data + match { + case Left(values2) => + return Some(values2) + case _ => + throw new SparkException("Memory store did not return an iterator") + } } else { return Some(values) } @@ -460,7 +460,7 @@ private[spark] class BlockManager( } } } else { - logDebug("Block " + blockId + " not registered locally") + logDebug(s"Block $blockId not registered locally") } None } @@ -469,7 +469,7 @@ private[spark] class BlockManager( * Get block from remote block managers. */ def getRemote(blockId: BlockId): Option[Iterator[Any]] = { - logDebug("Getting remote block " + blockId) + logDebug(s"Getting remote block $blockId") doGetRemote(blockId, asValues = true).asInstanceOf[Option[Iterator[Any]]] } @@ -477,7 +477,7 @@ private[spark] class BlockManager( * Get block from remote block managers as serialized bytes. */ def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = { - logDebug("Getting remote block " + blockId + " as bytes") + logDebug(s"Getting remote block $blockId as bytes") doGetRemote(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]] } @@ -485,7 +485,7 @@ private[spark] class BlockManager( require(blockId != null, "BlockId is null") val locations = Random.shuffle(master.getLocations(blockId)) for (loc <- locations) { - logDebug("Getting remote block " + blockId + " from " + loc) + logDebug(s"Getting remote block $blockId from $loc") val data = BlockManagerWorker.syncGetBlock( GetBlock(blockId), ConnectionManagerId(loc.host, loc.port)) if (data != null) { @@ -495,9 +495,9 @@ private[spark] class BlockManager( return Some(data) } } - logDebug("The value of block " + blockId + " is null") + logDebug(s"The value of block $blockId is null") } - logDebug("Block " + blockId + " not found") + logDebug(s"Block $blockId not found") None } @@ -507,12 +507,12 @@ private[spark] class BlockManager( def get(blockId: BlockId): Option[Iterator[Any]] = { val local = getLocal(blockId) if (local.isDefined) { - logInfo("Found block %s locally".format(blockId)) + logInfo(s"Found block $blockId locally") return local } val remote = getRemote(blockId) if (remote.isDefined) { - logInfo("Found block %s remotely".format(blockId)) + logInfo(s"Found block $blockId remotely") return remote } None @@ -533,7 +533,6 @@ private[spark] class BlockManager( } else { new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer) } - iter.initialize() iter } @@ -543,6 +542,7 @@ private[spark] class BlockManager( values: Iterator[Any], level: StorageLevel, tellMaster: Boolean): Seq[(BlockId, BlockStatus)] = { + require(values != null, "Values is null") doPut(blockId, IteratorValues(values), level, tellMaster) } @@ -562,8 +562,8 @@ private[spark] class BlockManager( } /** - * Put a new block of values to the block manager. Return a list of blocks updated as a - * result of this put. + * Put a new block of values to the block manager. + * Return a list of blocks updated as a result of this put. */ def put( blockId: BlockId, @@ -575,8 +575,8 @@ private[spark] class BlockManager( } /** - * Put a new block of serialized bytes to the block manager. Return a list of blocks updated - * as a result of this put. 
+ * Put a new block of serialized bytes to the block manager. + * Return a list of blocks updated as a result of this put. */ def putBytes( blockId: BlockId, @@ -589,7 +589,7 @@ private[spark] class BlockManager( private def doPut( blockId: BlockId, - data: Values, + data: BlockValues, level: StorageLevel, tellMaster: Boolean = true): Seq[(BlockId, BlockStatus)] = { @@ -599,20 +599,18 @@ private[spark] class BlockManager( // Return value val updatedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - // Remember the block's storage level so that we can correctly drop it to disk if it needs - // to be dropped right after it got put into memory. Note, however, that other threads will - // not be able to get() this block until we call markReady on its BlockInfo. + /* Remember the block's storage level so that we can correctly drop it to disk if it needs + * to be dropped right after it got put into memory. Note, however, that other threads will + * not be able to get() this block until we call markReady on its BlockInfo. */ val putBlockInfo = { val tinfo = new BlockInfo(level, tellMaster) // Do atomically ! val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo) - if (oldBlockOpt.isDefined) { if (oldBlockOpt.get.waitForReady()) { - logWarning("Block " + blockId + " already exists on this machine; not re-adding it") + logWarning(s"Block $blockId already exists on this machine; not re-adding it") return updatedBlocks } - // TODO: So the block info exists - but previous attempt to load it (?) failed. // What do we do now ? Retry on it ? oldBlockOpt.get @@ -623,10 +621,10 @@ private[spark] class BlockManager( val startTimeMs = System.currentTimeMillis - // If we're storing values and we need to replicate the data, we'll want access to the values, - // but because our put will read the whole iterator, there will be no values left. For the - // case where the put serializes data, we'll remember the bytes, above; but for the case where - // it doesn't, such as deserialized storage, let's rely on the put returning an Iterator. + /* If we're storing values and we need to replicate the data, we'll want access to the values, + * but because our put will read the whole iterator, there will be no values left. For the + * case where the put serializes data, we'll remember the bytes, above; but for the case where + * it doesn't, such as deserialized storage, let's rely on the put returning an Iterator. */ var valuesAfterPut: Iterator[Any] = null // Ditto for the bytes after the put @@ -637,78 +635,62 @@ private[spark] class BlockManager( // If we're storing bytes, then initiate the replication before storing them locally. // This is faster as data is already serialized and ready to send. 
- val replicationFuture = if (data.isInstanceOf[ByteBufferValues] && level.replication > 1) { - // Duplicate doesn't copy the bytes, just creates a wrapper - val bufferView = data.asInstanceOf[ByteBufferValues].buffer.duplicate() - Future { - replicate(blockId, bufferView, level) - } - } else { - null + val replicationFuture = data match { + case b: ByteBufferValues if level.replication > 1 => + // Duplicate doesn't copy the bytes, but just creates a wrapper + val bufferView = b.buffer.duplicate() + Future { replicate(blockId, bufferView, level) } + case _ => null } putBlockInfo.synchronized { - logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) - + " to get into synchronized block") + logTrace("Put for block %s took %s to get into synchronized block" + .format(blockId, Utils.getUsedTimeMs(startTimeMs))) var marked = false try { - if (level.useMemory) { - // Save it just to memory first, even if it also has useDisk set to true; we will - // drop it to disk later if the memory store can't hold it. - val res = data match { - case IteratorValues(iterator) => - memoryStore.putValues(blockId, iterator, level, true) - case ArrayBufferValues(array) => - memoryStore.putValues(blockId, array, level, true) - case ByteBufferValues(bytes) => - bytes.rewind() - memoryStore.putBytes(blockId, bytes, level) - } - size = res.size - res.data match { - case Right(newBytes) => bytesAfterPut = newBytes - case Left(newIterator) => valuesAfterPut = newIterator - } - // Keep track of which blocks are dropped from memory - res.droppedBlocks.foreach { block => updatedBlocks += block } - } else if (level.useOffHeap) { - // Save to Tachyon. - val res = data match { - case IteratorValues(iterator) => - tachyonStore.putValues(blockId, iterator, level, false) - case ArrayBufferValues(array) => - tachyonStore.putValues(blockId, array, level, false) - case ByteBufferValues(bytes) => - bytes.rewind() - tachyonStore.putBytes(blockId, bytes, level) - } - size = res.size - res.data match { - case Right(newBytes) => bytesAfterPut = newBytes - case _ => - } - } else { - // Save directly to disk. - // Don't get back the bytes unless we replicate them. - val askForBytes = level.replication > 1 - - val res = data match { - case IteratorValues(iterator) => - diskStore.putValues(blockId, iterator, level, askForBytes) - case ArrayBufferValues(array) => - diskStore.putValues(blockId, array, level, askForBytes) - case ByteBufferValues(bytes) => - bytes.rewind() - diskStore.putBytes(blockId, bytes, level) - } - size = res.size - res.data match { - case Right(newBytes) => bytesAfterPut = newBytes - case _ => + // returnValues - Whether to return the values put + // blockStore - The type of storage to put these values into + val (returnValues, blockStore: BlockStore) = { + if (level.useMemory) { + // Put it in memory first, even if it also has useDisk set to true; + // We will drop it to disk later if the memory store can't hold it. 
+ (true, memoryStore) + } else if (level.useOffHeap) { + // Use tachyon for off-heap storage + (false, tachyonStore) + } else if (level.useDisk) { + // Don't get back the bytes from put unless we replicate them + (level.replication > 1, diskStore) + } else { + assert(level == StorageLevel.NONE) + throw new BlockException( + blockId, s"Attempted to put block $blockId without specifying storage level!") } } + // Actually put the values + val result = data match { + case IteratorValues(iterator) => + blockStore.putValues(blockId, iterator, level, returnValues) + case ArrayBufferValues(array) => + blockStore.putValues(blockId, array, level, returnValues) + case ByteBufferValues(bytes) => + bytes.rewind() + blockStore.putBytes(blockId, bytes, level) + } + size = result.size + result.data match { + case Left (newIterator) if level.useMemory => valuesAfterPut = newIterator + case Right (newBytes) => bytesAfterPut = newBytes + case _ => + } + + // Keep track of which blocks are dropped from memory + if (level.useMemory) { + result.droppedBlocks.foreach { updatedBlocks += _ } + } + val putBlockStatus = getCurrentBlockStatus(blockId, putBlockInfo) if (putBlockStatus.storageLevel != StorageLevel.NONE) { // Now that the block is in either the memory, tachyon, or disk store, @@ -728,18 +710,21 @@ private[spark] class BlockManager( // could've inserted a new BlockInfo before we remove it. blockInfo.remove(blockId) putBlockInfo.markFailure() - logWarning("Putting block " + blockId + " failed") + logWarning(s"Putting block $blockId failed") } } } - logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs)) + logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs))) // Either we're storing bytes and we asynchronously started replication, or we're storing // values and need to serialize and replicate them now: if (level.replication > 1) { data match { - case ByteBufferValues(bytes) => Await.ready(replicationFuture, Duration.Inf) - case _ => { + case ByteBufferValues(bytes) => + if (replicationFuture != null) { + Await.ready(replicationFuture, Duration.Inf) + } + case _ => val remoteStartTime = System.currentTimeMillis // Serialize the block if not already done if (bytesAfterPut == null) { @@ -750,20 +735,19 @@ private[spark] class BlockManager( bytesAfterPut = dataSerialize(blockId, valuesAfterPut) } replicate(blockId, bytesAfterPut, level) - logDebug("Put block " + blockId + " remotely took " + - Utils.getUsedTimeMs(remoteStartTime)) - } + logDebug("Put block %s remotely took %s" + .format(blockId, Utils.getUsedTimeMs(remoteStartTime))) } } BlockManager.dispose(bytesAfterPut) if (level.replication > 1) { - logDebug("Put for block " + blockId + " with replication took " + - Utils.getUsedTimeMs(startTimeMs)) + logDebug("Putting block %s with replication took %s" + .format(blockId, Utils.getUsedTimeMs(startTimeMs))) } else { - logDebug("Put for block " + blockId + " without replication took " + - Utils.getUsedTimeMs(startTimeMs)) + logDebug("Putting block %s without replication took %s" + .format(blockId, Utils.getUsedTimeMs(startTimeMs))) } updatedBlocks @@ -773,7 +757,7 @@ private[spark] class BlockManager( * Replicate block to another node. 
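The doPut rewrite above replaces three nearly identical branches (memory, tachyon, disk) with a single selection step followed by a single dispatch on the value type. A standalone sketch of that shape, using placeholder Level, Store and PutValues types rather than Spark's StorageLevel, BlockStore and BlockValues:

    object PutSketch {
      sealed trait PutValues
      case class IteratorVals(it: Iterator[Any]) extends PutValues
      case class BytesVals(bytes: Array[Byte]) extends PutValues

      trait Store {
        def putIterator(it: Iterator[Any]): Long   // returns the stored size
        def putBytes(bytes: Array[Byte]): Long
      }

      case class Level(useMemory: Boolean, useOffHeap: Boolean, useDisk: Boolean,
          replication: Int)

      def doPut(level: Level, data: PutValues,
          memory: Store, offHeap: Store, disk: Store): (Boolean, Long) = {
        // Pick one target store, and whether the caller wants the stored values back,
        // instead of repeating the put logic once per store.
        val (returnValues, store) =
          if (level.useMemory) (true, memory)
          else if (level.useOffHeap) (false, offHeap)
          else if (level.useDisk) (level.replication > 1, disk)
          else throw new IllegalArgumentException("no storage level specified")

        // One dispatch on the value type, whatever store was chosen.
        val size = data match {
          case IteratorVals(it) => store.putIterator(it)
          case BytesVals(bytes) => store.putBytes(bytes)
        }
        (returnValues, size)
      }
    }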
*/ @volatile var cachedPeers: Seq[BlockManagerId] = null - private def replicate(blockId: BlockId, data: ByteBuffer, level: StorageLevel) { + private def replicate(blockId: BlockId, data: ByteBuffer, level: StorageLevel): Unit = { val tLevel = StorageLevel( level.useDisk, level.useMemory, level.useOffHeap, level.deserialized, 1) if (cachedPeers == null) { @@ -782,15 +766,16 @@ private[spark] class BlockManager( for (peer: BlockManagerId <- cachedPeers) { val start = System.nanoTime data.rewind() - logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is " - + data.limit() + " Bytes. To node: " + peer) - if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel), - new ConnectionManagerId(peer.host, peer.port))) { - logError("Failed to call syncPutBlock to " + peer) + logDebug(s"Try to replicate $blockId once; The size of the data is ${data.limit()} Bytes. " + + s"To node: $peer") + val putBlock = PutBlock(blockId, data, tLevel) + val cmId = new ConnectionManagerId(peer.host, peer.port) + val syncPutBlockSuccess = BlockManagerWorker.syncPutBlock(putBlock, cmId) + if (!syncPutBlockSuccess) { + logError(s"Failed to call syncPutBlock to $peer") } - logDebug("Replicated BlockId " + blockId + " once used " + - (System.nanoTime - start) / 1e6 + " s; The size of the data is " + - data.limit() + " bytes.") + logDebug("Replicating BlockId %s once used %fs; The size of the data is %d bytes." + .format(blockId, (System.nanoTime - start) / 1e6, data.limit())) } } @@ -822,17 +807,17 @@ private[spark] class BlockManager( blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer]): Option[BlockStatus] = { - logInfo("Dropping block " + blockId + " from memory") + logInfo(s"Dropping block $blockId from memory") val info = blockInfo.get(blockId).orNull // If the block has not already been dropped - if (info != null) { + if (info != null) { info.synchronized { // required ? As of now, this will be invoked only for blocks which are ready // But in case this changes in future, adding for consistency sake. if (!info.waitForReady()) { // If we get here, the block write failed. - logWarning("Block " + blockId + " was marked as failure. Nothing to drop") + logWarning(s"Block $blockId was marked as failure. Nothing to drop") return None } @@ -841,10 +826,10 @@ private[spark] class BlockManager( // Drop to disk, if storage level requires if (level.useDisk && !diskStore.contains(blockId)) { - logInfo("Writing block " + blockId + " to disk") + logInfo(s"Writing block $blockId to disk") data match { case Left(elements) => - diskStore.putValues(blockId, elements, level, false) + diskStore.putValues(blockId, elements, level, returnValues = false) case Right(bytes) => diskStore.putBytes(blockId, bytes, level) } @@ -858,7 +843,7 @@ private[spark] class BlockManager( if (blockIsRemoved) { blockIsUpdated = true } else { - logWarning("Block " + blockId + " could not be dropped from memory as it does not exist") + logWarning(s"Block $blockId could not be dropped from memory as it does not exist") } val status = getCurrentBlockStatus(blockId, info) @@ -883,7 +868,7 @@ private[spark] class BlockManager( */ def removeRdd(rddId: Int): Int = { // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks. 
- logInfo("Removing RDD " + rddId) + logInfo(s"Removing RDD $rddId") val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) } blocksToRemove.size @@ -893,7 +878,7 @@ private[spark] class BlockManager( * Remove all blocks belonging to the given broadcast. */ def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = { - logInfo("Removing broadcast " + broadcastId) + logInfo(s"Removing broadcast $broadcastId") val blocksToRemove = blockInfo.keys.collect { case bid @ BroadcastBlockId(`broadcastId`, _) => bid } @@ -904,40 +889,42 @@ private[spark] class BlockManager( /** * Remove a block from both memory and disk. */ - def removeBlock(blockId: BlockId, tellMaster: Boolean = true) { - logInfo("Removing block " + blockId) + def removeBlock(blockId: BlockId, tellMaster: Boolean = true): Unit = { + logInfo(s"Removing block $blockId") val info = blockInfo.get(blockId).orNull - if (info != null) info.synchronized { - // Removals are idempotent in disk store and memory store. At worst, we get a warning. - val removedFromMemory = memoryStore.remove(blockId) - val removedFromDisk = diskStore.remove(blockId) - val removedFromTachyon = if (tachyonInitialized) tachyonStore.remove(blockId) else false - if (!removedFromMemory && !removedFromDisk && !removedFromTachyon) { - logWarning("Block " + blockId + " could not be removed as it was not found in either " + - "the disk, memory, or tachyon store") - } - blockInfo.remove(blockId) - if (tellMaster && info.tellMaster) { - val status = getCurrentBlockStatus(blockId, info) - reportBlockStatus(blockId, info, status) + if (info != null) { + info.synchronized { + // Removals are idempotent in disk store and memory store. At worst, we get a warning. + val removedFromMemory = memoryStore.remove(blockId) + val removedFromDisk = diskStore.remove(blockId) + val removedFromTachyon = if (tachyonInitialized) tachyonStore.remove(blockId) else false + if (!removedFromMemory && !removedFromDisk && !removedFromTachyon) { + logWarning(s"Block $blockId could not be removed as it was not found in either " + + "the disk, memory, or tachyon store") + } + blockInfo.remove(blockId) + if (tellMaster && info.tellMaster) { + val status = getCurrentBlockStatus(blockId, info) + reportBlockStatus(blockId, info, status) + } } } else { // The block has already been removed; do nothing. 
- logWarning("Asked to remove block " + blockId + ", which does not exist") + logWarning(s"Asked to remove block $blockId, which does not exist") } } - private def dropOldNonBroadcastBlocks(cleanupTime: Long) { - logInfo("Dropping non broadcast blocks older than " + cleanupTime) + private def dropOldNonBroadcastBlocks(cleanupTime: Long): Unit = { + logInfo(s"Dropping non broadcast blocks older than $cleanupTime") dropOldBlocks(cleanupTime, !_.isBroadcast) } - private def dropOldBroadcastBlocks(cleanupTime: Long) { - logInfo("Dropping broadcast blocks older than " + cleanupTime) + private def dropOldBroadcastBlocks(cleanupTime: Long): Unit = { + logInfo(s"Dropping broadcast blocks older than $cleanupTime") dropOldBlocks(cleanupTime, _.isBroadcast) } - private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)) { + private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)): Unit = { val iterator = blockInfo.getEntrySet.iterator while (iterator.hasNext) { val entry = iterator.next() @@ -945,17 +932,11 @@ private[spark] class BlockManager( if (time < cleanupTime && shouldDrop(id)) { info.synchronized { val level = info.level - if (level.useMemory) { - memoryStore.remove(id) - } - if (level.useDisk) { - diskStore.remove(id) - } - if (level.useOffHeap) { - tachyonStore.remove(id) - } + if (level.useMemory) { memoryStore.remove(id) } + if (level.useDisk) { diskStore.remove(id) } + if (level.useOffHeap) { tachyonStore.remove(id) } iterator.remove() - logInfo("Dropped block " + id) + logInfo(s"Dropped block $id") } val status = getCurrentBlockStatus(id, info) reportBlockStatus(id, info, status) @@ -963,12 +944,14 @@ private[spark] class BlockManager( } } - def shouldCompress(blockId: BlockId): Boolean = blockId match { - case ShuffleBlockId(_, _, _) => compressShuffle - case BroadcastBlockId(_, _) => compressBroadcast - case RDDBlockId(_, _) => compressRdds - case TempBlockId(_) => compressShuffleSpill - case _ => false + private def shouldCompress(blockId: BlockId): Boolean = { + blockId match { + case _: ShuffleBlockId => compressShuffle + case _: BroadcastBlockId => compressBroadcast + case _: RDDBlockId => compressRdds + case _: TempBlockId => compressShuffleSpill + case _ => false + } } /** @@ -990,7 +973,7 @@ private[spark] class BlockManager( blockId: BlockId, outputStream: OutputStream, values: Iterator[Any], - serializer: Serializer = defaultSerializer) { + serializer: Serializer = defaultSerializer): Unit = { val byteStream = new BufferedOutputStream(outputStream) val ser = serializer.newInstance() ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() @@ -1016,16 +999,16 @@ private[spark] class BlockManager( serializer: Serializer = defaultSerializer): Iterator[Any] = { bytes.rewind() - def getIterator = { + def getIterator: Iterator[Any] = { val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true)) serializer.newInstance().deserializeStream(stream).asIterator } if (blockId.isShuffle) { - // Reducer may need to read many local shuffle blocks and will wrap them into Iterators - // at the beginning. The wrapping will cost some memory (compression instance - // initialization, etc.). Reducer reads shuffle blocks one by one so we could do the - // wrapping lazily to save memory. + /* Reducer may need to read many local shuffle blocks and will wrap them into Iterators + * at the beginning. The wrapping will cost some memory (compression instance + * initialization, etc.). 
Reducer reads shuffle blocks one by one so we could do the + * wrapping lazily to save memory. */ class LazyProxyIterator(f: => Iterator[Any]) extends Iterator[Any] { lazy val proxy = f override def hasNext: Boolean = proxy.hasNext @@ -1037,7 +1020,7 @@ private[spark] class BlockManager( } } - def stop() { + def stop(): Unit = { if (heartBeatTask != null) { heartBeatTask.cancel() } @@ -1059,9 +1042,9 @@ private[spark] class BlockManager( private[spark] object BlockManager extends Logging { - val ID_GENERATOR = new IdGenerator + private val ID_GENERATOR = new IdGenerator - def getMaxMemory(conf: SparkConf): Long = { + private def getMaxMemory(conf: SparkConf): Long = { val memoryFraction = conf.getDouble("spark.storage.memoryFraction", 0.6) (Runtime.getRuntime.maxMemory * memoryFraction).toLong } @@ -1078,9 +1061,9 @@ private[spark] object BlockManager extends Logging { * waiting for the GC to find it because that could lead to huge numbers of open files. There's * unfortunately no standard API to do this. */ - def dispose(buffer: ByteBuffer) { + def dispose(buffer: ByteBuffer): Unit = { if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { - logTrace("Unmapping " + buffer) + logTrace(s"Unmapping $buffer") if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) { buffer.asInstanceOf[DirectBuffer].cleaner().clean() } @@ -1093,7 +1076,7 @@ private[spark] object BlockManager extends Logging { blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[BlockManagerId]] = { // blockManagerMaster != null is used in tests - assert (env != null || blockManagerMaster != null) + assert(env != null || blockManagerMaster != null) val blockLocations: Seq[Seq[BlockManagerId]] = if (blockManagerMaster == null) { env.blockManager.getLocationBlockIds(blockIds) } else { diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 3a7243a1ba19c..2ec46d416f37d 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -40,9 +40,9 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 private val subDirsPerLocalDir = shuffleManager.conf.getInt("spark.diskStore.subDirectories", 64) - // Create one local directory for each path mentioned in spark.local.dir; then, inside this - // directory, create multiple subdirectories that we will hash files into, in order to avoid - // having really large inodes at the top level. + /* Create one local directory for each path mentioned in spark.local.dir; then, inside this + * directory, create multiple subdirectories that we will hash files into, in order to avoid + * having really large inodes at the top level. 
*/ private val localDirs: Array[File] = createLocalDirs() private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) private var shuffleSender : ShuffleSender = null @@ -114,7 +114,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD } private def createLocalDirs(): Array[File] = { - logDebug("Creating local directories at root dirs '" + rootDirs + "'") + logDebug(s"Creating local directories at root dirs '$rootDirs'") val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") rootDirs.split(",").map { rootDir => var foundLocalDir = false @@ -126,21 +126,20 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD tries += 1 try { localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) - localDir = new File(rootDir, "spark-local-" + localDirId) + localDir = new File(rootDir, s"spark-local-$localDirId") if (!localDir.exists) { foundLocalDir = localDir.mkdirs() } } catch { case e: Exception => - logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e) + logWarning(s"Attempt $tries to create local dir $localDir failed", e) } } if (!foundLocalDir) { - logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + - " attempts to create local dir in " + rootDir) + logError(s"Failed $MAX_DIR_CREATION_ATTEMPTS attempts to create local dir in $rootDir") System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) } - logInfo("Created local directory at " + localDir) + logInfo(s"Created local directory at $localDir") localDir } } @@ -163,7 +162,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) } catch { case e: Exception => - logError("Exception while deleting local spark dir: " + localDir, e) + logError(s"Exception while deleting local spark dir: $localDir", e) } } } @@ -175,7 +174,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD private[storage] def startShuffleBlockSender(port: Int): Int = { shuffleSender = new ShuffleSender(port, this) - logInfo("Created ShuffleSender binding to port : " + shuffleSender.port) + logInfo(s"Created ShuffleSender binding to port: ${shuffleSender.port}") shuffleSender.port } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 0ab9fad422717..ebff0cb5ba153 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -39,41 +39,39 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage diskManager.getBlockLocation(blockId).length } - override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) : PutResult = { + override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel): PutResult = { // So that we do not modify the input offsets ! 
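// Editor's sketch (not part of the patch): ByteBuffer.duplicate() shares the underlying
// bytes but carries its own position and limit, so draining the duplicate to the channel
// below cannot disturb the caller's offsets, and no data is copied:
//   val original = java.nio.ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))
//   val dup = original.duplicate()
//   dup.get(); dup.get()               // advances only dup's position
//   assert(original.position() == 0)   // the caller's buffer is untouched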
// duplicate does not copy buffer, so inexpensive val bytes = _bytes.duplicate() - logDebug("Attempting to put block " + blockId) + logDebug(s"Attempting to put block $blockId") val startTime = System.currentTimeMillis val file = diskManager.getFile(blockId) - val channel = new FileOutputStream(file).getChannel() + val channel = new FileOutputStream(file).getChannel while (bytes.remaining > 0) { channel.write(bytes) } channel.close() val finishTime = System.currentTimeMillis logDebug("Block %s stored as %s file on disk in %d ms".format( - file.getName, Utils.bytesToString(bytes.limit), (finishTime - startTime))) - return PutResult(bytes.limit(), Right(bytes.duplicate())) + file.getName, Utils.bytesToString(bytes.limit), finishTime - startTime)) + PutResult(bytes.limit(), Right(bytes.duplicate())) } override def putValues( blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel, - returnValues: Boolean) - : PutResult = { - return putValues(blockId, values.toIterator, level, returnValues) + returnValues: Boolean): PutResult = { + putValues(blockId, values.toIterator, level, returnValues) } override def putValues( blockId: BlockId, values: Iterator[Any], level: StorageLevel, - returnValues: Boolean) - : PutResult = { + returnValues: Boolean): PutResult = { - logDebug("Attempting to write values for block " + blockId) + logDebug(s"Attempting to write values for block $blockId") val startTime = System.currentTimeMillis val file = diskManager.getFile(blockId) val outputStream = new FileOutputStream(file) @@ -95,7 +93,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage override def getBytes(blockId: BlockId): Option[ByteBuffer] = { val segment = diskManager.getBlockLocation(blockId) - val channel = new RandomAccessFile(segment.file, "r").getChannel() + val channel = new RandomAccessFile(segment.file, "r").getChannel try { // For small files, directly read rather than memory map @@ -131,7 +129,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage file.delete() } else { if (fileSegment.length < file.length()) { - logWarning("Could not delete block associated with only a part of a file: " + blockId) + logWarning(s"Could not delete block associated with only a part of a file: $blockId") } false } diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 488f1ea9628f5..084a566c48560 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -24,6 +24,8 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.util.{SizeEstimator, Utils} +private case class MemoryEntry(value: Any, size: Long, deserialized: Boolean) + /** * Stores blocks in memory, either as ArrayBuffers of deserialized Java objects or as * serialized ByteBuffers. @@ -31,15 +33,13 @@ import org.apache.spark.util.{SizeEstimator, Utils} private class MemoryStore(blockManager: BlockManager, maxMemory: Long) extends BlockStore(blockManager) { - case class Entry(value: Any, size: Long, deserialized: Boolean) - - private val entries = new LinkedHashMap[BlockId, Entry](32, 0.75f, true) + private val entries = new LinkedHashMap[BlockId, MemoryEntry](32, 0.75f, true) @volatile private var currentMemory = 0L // Object used to ensure that only one thread is putting blocks and if necessary, dropping // blocks from the memory store. 
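// Editor's aside (not part of the patch): `entries` above is constructed with
// accessOrder = true as the third LinkedHashMap argument, so iteration visits blocks
// from least- to most-recently used, which lets the eviction pass in ensureFreeSpace
// further down drop blocks in LRU order. A minimal illustration of that ordering:
//   val m = new java.util.LinkedHashMap[String, Int](32, 0.75f, true)
//   m.put("a", 1); m.put("b", 2); m.put("c", 3)
//   m.get("a")            // touching "a" moves it to the most-recently-used end
//   // iteration order is now b, c, a, so "b" would be the first eviction candidate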
private val putLock = new Object() - logInfo("MemoryStore started with capacity %s.".format(Utils.bytesToString(maxMemory))) + logInfo("MemoryStore started with capacity %s".format(Utils.bytesToString(maxMemory))) def freeMemory: Long = maxMemory - currentMemory @@ -101,7 +101,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } else if (entry.deserialized) { Some(blockManager.dataSerialize(blockId, entry.value.asInstanceOf[ArrayBuffer[Any]].iterator)) } else { - Some(entry.value.asInstanceOf[ByteBuffer].duplicate()) // Doesn't actually copy the data + Some(entry.value.asInstanceOf[ByteBuffer].duplicate()) // Doesn't actually copy the data } } @@ -124,8 +124,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) val entry = entries.remove(blockId) if (entry != null) { currentMemory -= entry.size - logInfo("Block %s of size %d dropped from memory (free %d)".format( - blockId, entry.size, freeMemory)) + logInfo(s"Block $blockId of size ${entry.size} dropped from memory (free $freeMemory)") true } else { false @@ -181,18 +180,14 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) droppedBlocks ++= freeSpaceResult.droppedBlocks if (enoughFreeSpace) { - val entry = new Entry(value, size, deserialized) + val entry = new MemoryEntry(value, size, deserialized) entries.synchronized { entries.put(blockId, entry) currentMemory += size } - if (deserialized) { - logInfo("Block %s stored as values to memory (estimated size %s, free %s)".format( - blockId, Utils.bytesToString(size), Utils.bytesToString(freeMemory))) - } else { - logInfo("Block %s stored as bytes to memory (size %s, free %s)".format( - blockId, Utils.bytesToString(size), Utils.bytesToString(freeMemory))) - } + val valuesOrBytes = if (deserialized) "values" else "bytes" + logInfo("Block %s stored as %s in memory (estimated size %s, free %s)".format( + blockId, valuesOrBytes, Utils.bytesToString(size), Utils.bytesToString(freeMemory))) putSuccess = true } else { // Tell the block manager that we couldn't put it in memory so that it can drop it to @@ -221,13 +216,12 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) * Return whether there is enough free space, along with the blocks dropped in the process. 
*/ private def ensureFreeSpace(blockIdToAdd: BlockId, space: Long): ResultWithDroppedBlocks = { - logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format( - space, currentMemory, maxMemory)) + logInfo(s"ensureFreeSpace($space) called with curMem=$currentMemory, maxMem=$maxMemory") val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] if (space > maxMemory) { - logInfo("Will not store " + blockIdToAdd + " as it is larger than our memory limit") + logInfo(s"Will not store $blockIdToAdd as it is larger than our memory limit") return ResultWithDroppedBlocks(success = false, droppedBlocks) } @@ -252,7 +246,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } if (maxMemory - (currentMemory - selectedMemory) >= space) { - logInfo(selectedBlocks.size + " blocks selected for dropping") + logInfo(s"${selectedBlocks.size} blocks selected for dropping") for (blockId <- selectedBlocks) { val entry = entries.synchronized { entries.get(blockId) } // This should never be null as only one thread should be dropping diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index 2d8ff1194a5dc..1e35abaab5353 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -34,11 +34,11 @@ import org.apache.spark.annotation.DeveloperApi */ @DeveloperApi class StorageLevel private( - private var useDisk_ : Boolean, - private var useMemory_ : Boolean, - private var useOffHeap_ : Boolean, - private var deserialized_ : Boolean, - private var replication_ : Int = 1) + private var _useDisk: Boolean, + private var _useMemory: Boolean, + private var _useOffHeap: Boolean, + private var _deserialized: Boolean, + private var _replication: Int = 1) extends Externalizable { // TODO: Also add fields for caching priority, dataset ID, and flushing. 
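Editor's sketch, not part of the patch: these StorageLevel hunks rename the private fields but keep the existing single-byte wire format used by toInt, writeExternal, and readExternal, which packs the four booleans into one byte (disk = 8, memory = 4, off-heap = 2, deserialized = 1) and sends the replication count as a second byte. A minimal round trip of that encoding, with hypothetical helper names:

  def encode(useDisk: Boolean, useMemory: Boolean, useOffHeap: Boolean, deserialized: Boolean): Int = {
    var ret = 0
    if (useDisk) ret |= 8
    if (useMemory) ret |= 4
    if (useOffHeap) ret |= 2
    if (deserialized) ret |= 1
    ret
  }
  def decode(flags: Int): (Boolean, Boolean, Boolean, Boolean) =
    ((flags & 8) != 0, (flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0)

  // disk + memory + deserialized (the MEMORY_AND_DISK combination) encodes as 8 | 4 | 1 = 13
  assert(encode(useDisk = true, useMemory = true, useOffHeap = false, deserialized = true) == 13)
  assert(decode(13) == (true, true, false, true))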
@@ -48,13 +48,13 @@ class StorageLevel private( def this() = this(false, true, false, false) // For deserialization - def useDisk = useDisk_ - def useMemory = useMemory_ - def useOffHeap = useOffHeap_ - def deserialized = deserialized_ - def replication = replication_ + def useDisk = _useDisk + def useMemory = _useMemory + def useOffHeap = _useOffHeap + def deserialized = _deserialized + def replication = _replication - assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes") + assert(replication < 40, "Replication restricted to be less than 40 for calculating hash codes") if (useOffHeap) { require(!useDisk, "Off-heap storage level does not support using disk") @@ -63,8 +63,9 @@ class StorageLevel private( require(replication == 1, "Off-heap storage level does not support multiple replication") } - override def clone(): StorageLevel = new StorageLevel( - this.useDisk, this.useMemory, this.useOffHeap, this.deserialized, this.replication) + override def clone(): StorageLevel = { + new StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication) + } override def equals(other: Any): Boolean = other match { case s: StorageLevel => @@ -77,20 +78,20 @@ class StorageLevel private( false } - def isValid = ((useMemory || useDisk || useOffHeap) && (replication > 0)) + def isValid = (useMemory || useDisk || useOffHeap) && (replication > 0) def toInt: Int = { var ret = 0 - if (useDisk_) { + if (_useDisk) { ret |= 8 } - if (useMemory_) { + if (_useMemory) { ret |= 4 } - if (useOffHeap_) { + if (_useOffHeap) { ret |= 2 } - if (deserialized_) { + if (_deserialized) { ret |= 1 } ret @@ -98,32 +99,34 @@ class StorageLevel private( override def writeExternal(out: ObjectOutput) { out.writeByte(toInt) - out.writeByte(replication_) + out.writeByte(_replication) } override def readExternal(in: ObjectInput) { val flags = in.readByte() - useDisk_ = (flags & 8) != 0 - useMemory_ = (flags & 4) != 0 - useOffHeap_ = (flags & 2) != 0 - deserialized_ = (flags & 1) != 0 - replication_ = in.readByte() + _useDisk = (flags & 8) != 0 + _useMemory = (flags & 4) != 0 + _useOffHeap = (flags & 2) != 0 + _deserialized = (flags & 1) != 0 + _replication = in.readByte() } @throws(classOf[IOException]) private def readResolve(): Object = StorageLevel.getCachedStorageLevel(this) - override def toString: String = "StorageLevel(%b, %b, %b, %b, %d)".format( - useDisk, useMemory, useOffHeap, deserialized, replication) + override def toString: String = { + s"StorageLevel($useDisk, $useMemory, $useOffHeap, $deserialized, $replication)" + } override def hashCode(): Int = toInt * 41 + replication - def description : String = { + + def description: String = { var result = "" result += (if (useDisk) "Disk " else "") result += (if (useMemory) "Memory " else "") result += (if (useOffHeap) "Tachyon " else "") result += (if (deserialized) "Deserialized " else "Serialized ") - result += "%sx Replicated".format(replication) + result += s"${replication}x Replicated" result } } @@ -165,7 +168,7 @@ object StorageLevel { case "MEMORY_AND_DISK_SER" => MEMORY_AND_DISK_SER case "MEMORY_AND_DISK_SER_2" => MEMORY_AND_DISK_SER_2 case "OFF_HEAP" => OFF_HEAP - case _ => throw new IllegalArgumentException("Invalid StorageLevel: " + s) + case _ => throw new IllegalArgumentException(s"Invalid StorageLevel: $s") } /** @@ -173,26 +176,37 @@ object StorageLevel { * Create a new StorageLevel object without setting useOffHeap. 
*/ @DeveloperApi - def apply(useDisk: Boolean, useMemory: Boolean, useOffHeap: Boolean, - deserialized: Boolean, replication: Int) = getCachedStorageLevel( + def apply( + useDisk: Boolean, + useMemory: Boolean, + useOffHeap: Boolean, + deserialized: Boolean, + replication: Int) = { + getCachedStorageLevel( new StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication)) + } /** * :: DeveloperApi :: * Create a new StorageLevel object. */ @DeveloperApi - def apply(useDisk: Boolean, useMemory: Boolean, - deserialized: Boolean, replication: Int = 1) = getCachedStorageLevel( - new StorageLevel(useDisk, useMemory, false, deserialized, replication)) + def apply( + useDisk: Boolean, + useMemory: Boolean, + deserialized: Boolean, + replication: Int = 1) = { + getCachedStorageLevel(new StorageLevel(useDisk, useMemory, false, deserialized, replication)) + } /** * :: DeveloperApi :: * Create a new StorageLevel object from its integer representation. */ @DeveloperApi - def apply(flags: Int, replication: Int): StorageLevel = + def apply(flags: Int, replication: Int): StorageLevel = { getCachedStorageLevel(new StorageLevel(flags, replication)) + } /** * :: DeveloperApi :: @@ -205,8 +219,8 @@ object StorageLevel { getCachedStorageLevel(obj) } - private[spark] - val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() + private[spark] val storageLevelCache = + new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() private[spark] def getCachedStorageLevel(level: StorageLevel): StorageLevel = { storageLevelCache.putIfAbsent(level, level) diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala index c37e76f893605..d8ff4ff6bd42c 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala @@ -22,15 +22,10 @@ import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -import tachyon.client.{WriteType, ReadType} +import tachyon.client.{ReadType, WriteType} import org.apache.spark.Logging import org.apache.spark.util.Utils -import org.apache.spark.serializer.Serializer - - -private class Entry(val size: Long) - /** * Stores BlockManager blocks on Tachyon. 
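Editor's sketch, not part of the patch: getCachedStorageLevel in the StorageLevel hunk above interns levels through a ConcurrentHashMap whose keys and values are the same objects, so readResolve can replace every deserialized StorageLevel with a canonical instance. The idiom, shown with a stand-in element type:

  import java.util.concurrent.ConcurrentHashMap

  val cache = new ConcurrentHashMap[String, String]()
  def intern(s: String): String = {
    cache.putIfAbsent(s, s) // no-op if an equal key is already present
    cache.get(s)            // always hands back the first instance that was stored
  }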
@@ -46,8 +41,8 @@ private class TachyonStore( tachyonManager.getFile(blockId.name).length } - override def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel): PutResult = { - putToTachyonStore(blockId, bytes, true) + override def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel): PutResult = { + putIntoTachyonStore(blockId, bytes, returnValues = true) } override def putValues( @@ -55,7 +50,7 @@ private class TachyonStore( values: ArrayBuffer[Any], level: StorageLevel, returnValues: Boolean): PutResult = { - return putValues(blockId, values.toIterator, level, returnValues) + putValues(blockId, values.toIterator, level, returnValues) } override def putValues( @@ -63,12 +58,12 @@ private class TachyonStore( values: Iterator[Any], level: StorageLevel, returnValues: Boolean): PutResult = { - logDebug("Attempting to write values for block " + blockId) - val _bytes = blockManager.dataSerialize(blockId, values) - putToTachyonStore(blockId, _bytes, returnValues) + logDebug(s"Attempting to write values for block $blockId") + val bytes = blockManager.dataSerialize(blockId, values) + putIntoTachyonStore(blockId, bytes, returnValues) } - private def putToTachyonStore( + private def putIntoTachyonStore( blockId: BlockId, bytes: ByteBuffer, returnValues: Boolean): PutResult = { @@ -76,7 +71,7 @@ private class TachyonStore( // duplicate does not copy buffer, so inexpensive val byteBuffer = bytes.duplicate() byteBuffer.rewind() - logDebug("Attempting to put block " + blockId + " into Tachyon") + logDebug(s"Attempting to put block $blockId into Tachyon") val startTime = System.currentTimeMillis val file = tachyonManager.getFile(blockId) val os = file.getOutStream(WriteType.TRY_CACHE) @@ -84,7 +79,7 @@ private class TachyonStore( os.close() val finishTime = System.currentTimeMillis logDebug("Block %s stored as %s file in Tachyon in %d ms".format( - blockId, Utils.bytesToString(byteBuffer.limit), (finishTime - startTime))) + blockId, Utils.bytesToString(byteBuffer.limit), finishTime - startTime)) if (returnValues) { PutResult(bytes.limit(), Right(bytes.duplicate())) @@ -106,10 +101,9 @@ private class TachyonStore( getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer)) } - override def getBytes(blockId: BlockId): Option[ByteBuffer] = { val file = tachyonManager.getFile(blockId) - if (file == null || file.getLocationHosts().size == 0) { + if (file == null || file.getLocationHosts.size == 0) { return None } val is = file.getInStream(ReadType.CACHE) @@ -121,16 +115,15 @@ private class TachyonStore( val fetchSize = is.read(bs, 0, size.asInstanceOf[Int]) buffer = ByteBuffer.wrap(bs) if (fetchSize != size) { - logWarning("Failed to fetch the block " + blockId + " from Tachyon : Size " + size + - " is not equal to fetched size " + fetchSize) + logWarning(s"Failed to fetch the block $blockId from Tachyon: Size $size " + + s"is not equal to fetched size $fetchSize") return None } } } catch { - case ioe: IOException => { - logWarning("Failed to fetch the block " + blockId + " from Tachyon", ioe) - return None - } + case ioe: IOException => + logWarning(s"Failed to fetch the block $blockId from Tachyon", ioe) + return None } Some(buffer) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 3b1b6df089b8e..4ce28bb0cf059 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -862,6 +862,59 @@ private[spark] object 
Utils extends Logging { Source.fromBytes(buff).mkString } + /** + * Return a string containing data across a set of files. The `startIndex` + * and `endIndex` are based on the cumulative size of all the files taken in + * the given order. See figure below for more details. + */ + def offsetBytes(files: Seq[File], start: Long, end: Long): String = { + val fileLengths = files.map { _.length } + val startIndex = math.max(start, 0) + val endIndex = math.min(end, fileLengths.sum) + val fileToLength = files.zip(fileLengths).toMap + logDebug("Log files: \n" + fileToLength.mkString("\n")) + + val stringBuffer = new StringBuffer((endIndex - startIndex).toInt) + var sum = 0L + for (file <- files) { + val startIndexOfFile = sum + val endIndexOfFile = sum + fileToLength(file) + logDebug(s"Processing file $file, " + + s"with start index = $startIndexOfFile, end index = $endIndex") + + /* + ____________ + range 1: | | + | case A | + + files: |==== file 1 ====|====== file 2 ======|===== file 3 =====| + + | case B . case C . case D | + range 2: |___________.____________________.______________| + */ + + if (startIndex <= startIndexOfFile && endIndex >= endIndexOfFile) { + // Case C: read the whole file + stringBuffer.append(offsetBytes(file.getAbsolutePath, 0, fileToLength(file))) + } else if (startIndex > startIndexOfFile && startIndex < endIndexOfFile) { + // Case A and B: read from [start of required range] to [end of file / end of range] + val effectiveStartIndex = startIndex - startIndexOfFile + val effectiveEndIndex = math.min(endIndex - startIndexOfFile, fileToLength(file)) + stringBuffer.append(Utils.offsetBytes( + file.getAbsolutePath, effectiveStartIndex, effectiveEndIndex)) + } else if (endIndex > startIndexOfFile && endIndex < endIndexOfFile) { + // Case D: read from [start of file] to [end of required range] + val effectiveStartIndex = math.max(startIndex - startIndexOfFile, 0) + val effectiveEndIndex = endIndex - startIndexOfFile + stringBuffer.append(Utils.offsetBytes( + file.getAbsolutePath, effectiveStartIndex, effectiveEndIndex)) + } + sum += fileToLength(file) + logDebug(s"After processing file $file, string built is ${stringBuffer.toString}") + } + stringBuffer.toString + } + /** * Clone an object using a Spark serializer. */ diff --git a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala new file mode 100644 index 0000000000000..8e9c3036d09c2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.util.logging + +import java.io.{File, FileOutputStream, InputStream} + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.util.{IntParam, Utils} + +/** + * Continuously appends the data from an input stream into the given file. + */ +private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSize: Int = 8192) + extends Logging { + @volatile private var outputStream: FileOutputStream = null + @volatile private var markedForStop = false // has the appender been asked to stopped + @volatile private var stopped = false // has the appender stopped + + // Thread that reads the input stream and writes to file + private val writingThread = new Thread("File appending thread for " + file) { + setDaemon(true) + override def run() { + Utils.logUncaughtExceptions { + appendStreamToFile() + } + } + } + writingThread.start() + + /** + * Wait for the appender to stop appending, either because input stream is closed + * or because of any error in appending + */ + def awaitTermination() { + synchronized { + if (!stopped) { + wait() + } + } + } + + /** Stop the appender */ + def stop() { + markedForStop = true + } + + /** Continuously read chunks from the input stream and append to the file */ + protected def appendStreamToFile() { + try { + logDebug("Started appending thread") + openFile() + val buf = new Array[Byte](bufferSize) + var n = 0 + while (!markedForStop && n != -1) { + n = inputStream.read(buf) + if (n != -1) { + appendToFile(buf, n) + } + } + } catch { + case e: Exception => + logError(s"Error writing stream to file $file", e) + } finally { + closeFile() + synchronized { + stopped = true + notifyAll() + } + } + } + + /** Append bytes to the file output stream */ + protected def appendToFile(bytes: Array[Byte], len: Int) { + if (outputStream == null) { + openFile() + } + outputStream.write(bytes, 0, len) + } + + /** Open the file output stream */ + protected def openFile() { + outputStream = new FileOutputStream(file, false) + logDebug(s"Opened file $file") + } + + /** Close the file output stream */ + protected def closeFile() { + outputStream.flush() + outputStream.close() + logDebug(s"Closed file $file") + } +} + +/** + * Companion object to [[org.apache.spark.util.logging.FileAppender]] which has helper + * functions to choose the correct type of FileAppender based on SparkConf configuration. 
+ */ +private[spark] object FileAppender extends Logging { + + /** Create the right appender based on Spark configuration */ + def apply(inputStream: InputStream, file: File, conf: SparkConf): FileAppender = { + + import RollingFileAppender._ + + val rollingStrategy = conf.get(STRATEGY_PROPERTY, STRATEGY_DEFAULT) + val rollingSizeBytes = conf.get(SIZE_PROPERTY, SIZE_DEFAULT) + val rollingInterval = conf.get(INTERVAL_PROPERTY, INTERVAL_DEFAULT) + + def createTimeBasedAppender() = { + val validatedParams: Option[(Long, String)] = rollingInterval match { + case "daily" => + logInfo(s"Rolling executor logs enabled for $file with daily rolling") + Some(24 * 60 * 60 * 1000L, "--YYYY-MM-dd") + case "hourly" => + logInfo(s"Rolling executor logs enabled for $file with hourly rolling") + Some(60 * 60 * 1000L, "--YYYY-MM-dd--HH") + case "minutely" => + logInfo(s"Rolling executor logs enabled for $file with rolling every minute") + Some(60 * 1000L, "--YYYY-MM-dd--HH-mm") + case IntParam(seconds) => + logInfo(s"Rolling executor logs enabled for $file with rolling $seconds seconds") + Some(seconds * 1000L, "--YYYY-MM-dd--HH-mm-ss") + case _ => + logWarning(s"Illegal interval for rolling executor logs [$rollingInterval], " + + s"rolling logs not enabled") + None + } + validatedParams.map { + case (interval, pattern) => + new RollingFileAppender( + inputStream, file, new TimeBasedRollingPolicy(interval, pattern), conf) + }.getOrElse { + new FileAppender(inputStream, file) + } + } + + def createSizeBasedAppender() = { + rollingSizeBytes match { + case IntParam(bytes) => + logInfo(s"Rolling executor logs enabled for $file with rolling every $bytes bytes") + new RollingFileAppender(inputStream, file, new SizeBasedRollingPolicy(bytes), conf) + case _ => + logWarning( + s"Illegal size [$rollingSizeBytes] for rolling executor logs, rolling logs not enabled") + new FileAppender(inputStream, file) + } + } + + rollingStrategy match { + case "" => + new FileAppender(inputStream, file) + case "time" => + createTimeBasedAppender() + case "size" => + createSizeBasedAppender() + case _ => + logWarning( + s"Illegal strategy [$rollingStrategy] for rolling executor logs, " + + s"rolling logs not enabled") + new FileAppender(inputStream, file) + } + } +} + + diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala new file mode 100644 index 0000000000000..1bbbd20cf076f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.util.logging + +import java.io.{File, FileFilter, InputStream} + +import org.apache.commons.io.FileUtils +import org.apache.spark.SparkConf +import RollingFileAppender._ + +/** + * Continuously appends data from input stream into the given file, and rolls + * over the file after the given interval. The rolled over files are named + * based on the given pattern. + * + * @param inputStream Input stream to read data from + * @param activeFile File to write data to + * @param rollingPolicy Policy based on which files will be rolled over. + * @param conf SparkConf that is used to pass on extra configurations + * @param bufferSize Optional buffer size. Used mainly for testing. + */ +private[spark] class RollingFileAppender( + inputStream: InputStream, + activeFile: File, + val rollingPolicy: RollingPolicy, + conf: SparkConf, + bufferSize: Int = DEFAULT_BUFFER_SIZE + ) extends FileAppender(inputStream, activeFile, bufferSize) { + + private val maxRetainedFiles = conf.getInt(RETAINED_FILES_PROPERTY, -1) + + /** Stop the appender */ + override def stop() { + super.stop() + } + + /** Append bytes to file after rolling over is necessary */ + override protected def appendToFile(bytes: Array[Byte], len: Int) { + if (rollingPolicy.shouldRollover(len)) { + rollover() + rollingPolicy.rolledOver() + } + super.appendToFile(bytes, len) + rollingPolicy.bytesWritten(len) + } + + /** Rollover the file, by closing the output stream and moving it over */ + private def rollover() { + try { + closeFile() + moveFile() + openFile() + if (maxRetainedFiles > 0) { + deleteOldFiles() + } + } catch { + case e: Exception => + logError(s"Error rolling over $activeFile", e) + } + } + + /** Move the active log file to a new rollover file */ + private def moveFile() { + val rolloverSuffix = rollingPolicy.generateRolledOverFileSuffix() + val rolloverFile = new File( + activeFile.getParentFile, activeFile.getName + rolloverSuffix).getAbsoluteFile + try { + logDebug(s"Attempting to rollover file $activeFile to file $rolloverFile") + if (activeFile.exists) { + if (!rolloverFile.exists) { + FileUtils.moveFile(activeFile, rolloverFile) + logInfo(s"Rolled over $activeFile to $rolloverFile") + } else { + // In case the rollover file name clashes, make a unique file name. + // The resultant file names are long and ugly, so this is used only + // if there is a name collision. This can be avoided by the using + // the right pattern such that name collisions do not occur. 
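// Editor's note (not part of the patch): the suffix patterns chosen in FileAppender.apply
// above, e.g. "--YYYY-MM-dd--HH-mm-ss" for second-granularity rolling, already produce one
// distinct suffix per rollover interval, so this fallback should only trigger when two
// rollovers format to the same timestamp:
//   val suffix = new java.text.SimpleDateFormat("--YYYY-MM-dd--HH-mm-ss")
//     .format(java.util.Calendar.getInstance.getTime)   // e.g. "--2014-06-10--12-30-05"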
+ var i = 0 + var altRolloverFile: File = null + do { + altRolloverFile = new File(activeFile.getParent, + s"${activeFile.getName}$rolloverSuffix--$i").getAbsoluteFile + i += 1 + } while (i < 10000 && altRolloverFile.exists) + + logWarning(s"Rollover file $rolloverFile already exists, " + + s"rolled over $activeFile to file $altRolloverFile") + FileUtils.moveFile(activeFile, altRolloverFile) + } + } else { + logWarning(s"File $activeFile does not exist") + } + } + } + + /** Retain only last few files */ + private[util] def deleteOldFiles() { + try { + val rolledoverFiles = activeFile.getParentFile.listFiles(new FileFilter { + def accept(f: File): Boolean = { + f.getName.startsWith(activeFile.getName) && f != activeFile + } + }).sorted + val filesToBeDeleted = rolledoverFiles.take( + math.max(0, rolledoverFiles.size - maxRetainedFiles)) + filesToBeDeleted.foreach { file => + logInfo(s"Deleting file executor log file ${file.getAbsolutePath}") + file.delete() + } + } catch { + case e: Exception => + logError("Error cleaning logs in directory " + activeFile.getParentFile.getAbsolutePath, e) + } + } +} + +/** + * Companion object to [[org.apache.spark.util.logging.RollingFileAppender]]. Defines + * names of configurations that configure rolling file appenders. + */ +private[spark] object RollingFileAppender { + val STRATEGY_PROPERTY = "spark.executor.logs.rolling.strategy" + val STRATEGY_DEFAULT = "" + val INTERVAL_PROPERTY = "spark.executor.logs.rolling.time.interval" + val INTERVAL_DEFAULT = "daily" + val SIZE_PROPERTY = "spark.executor.logs.rolling.size.maxBytes" + val SIZE_DEFAULT = (1024 * 1024).toString + val RETAINED_FILES_PROPERTY = "spark.executor.logs.rolling.maxRetainedFiles" + val DEFAULT_BUFFER_SIZE = 8192 + + /** + * Get the sorted list of rolled over files. This assumes that the all the rolled + * over file names are prefixed with the `activeFileName`, and the active file + * name has the latest logs. So it sorts all the rolled over logs (that are + * prefixed with `activeFileName`) and appends the active file + */ + def getSortedRolledOverFiles(directory: String, activeFileName: String): Seq[File] = { + val rolledOverFiles = new File(directory).getAbsoluteFile.listFiles.filter { file => + val fileName = file.getName + fileName.startsWith(activeFileName) && fileName != activeFileName + }.sorted + val activeFile = { + val file = new File(directory, activeFileName).getAbsoluteFile + if (file.exists) Some(file) else None + } + rolledOverFiles ++ activeFile + } +} diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala new file mode 100644 index 0000000000000..84e5c3c917dcb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.logging + +import java.text.SimpleDateFormat +import java.util.Calendar + +import org.apache.spark.Logging + +/** + * Defines the policy based on which [[org.apache.spark.util.logging.RollingFileAppender]] will + * generate rolling files. + */ +private[spark] trait RollingPolicy { + + /** Whether rollover should be initiated at this moment */ + def shouldRollover(bytesToBeWritten: Long): Boolean + + /** Notify that rollover has occurred */ + def rolledOver() + + /** Notify that bytes have been written */ + def bytesWritten(bytes: Long) + + /** Get the desired name of the rollover file */ + def generateRolledOverFileSuffix(): String +} + +/** + * Defines a [[org.apache.spark.util.logging.RollingPolicy]] by which files will be rolled + * over at a fixed interval. + */ +private[spark] class TimeBasedRollingPolicy( + var rolloverIntervalMillis: Long, + rollingFileSuffixPattern: String, + checkIntervalConstraint: Boolean = true // set to false while testing + ) extends RollingPolicy with Logging { + + import TimeBasedRollingPolicy._ + if (checkIntervalConstraint && rolloverIntervalMillis < MINIMUM_INTERVAL_SECONDS * 1000L) { + logWarning(s"Rolling interval [${rolloverIntervalMillis/1000L} seconds] is too small. " + + s"Setting the interval to the acceptable minimum of $MINIMUM_INTERVAL_SECONDS seconds.") + rolloverIntervalMillis = MINIMUM_INTERVAL_SECONDS * 1000L + } + + @volatile private var nextRolloverTime = calculateNextRolloverTime() + private val formatter = new SimpleDateFormat(rollingFileSuffixPattern) + + /** Should rollover if current time has exceeded next rollover time */ + def shouldRollover(bytesToBeWritten: Long): Boolean = { + System.currentTimeMillis > nextRolloverTime + } + + /** Rollover has occurred, so find the next time to rollover */ + def rolledOver() { + nextRolloverTime = calculateNextRolloverTime() + logDebug(s"Current time: ${System.currentTimeMillis}, next rollover time: " + nextRolloverTime) + } + + def bytesWritten(bytes: Long) { } // nothing to do + + private def calculateNextRolloverTime(): Long = { + val now = System.currentTimeMillis() + val targetTime = ( + math.ceil(now.toDouble / rolloverIntervalMillis) * rolloverIntervalMillis + ).toLong + logDebug(s"Next rollover time is $targetTime") + targetTime + } + + def generateRolledOverFileSuffix(): String = { + formatter.format(Calendar.getInstance.getTime) + } +} + +private[spark] object TimeBasedRollingPolicy { + val MINIMUM_INTERVAL_SECONDS = 60L // 1 minute +} + +/** + * Defines a [[org.apache.spark.util.logging.RollingPolicy]] by which files will be rolled + * over after reaching a particular size. + */ +private[spark] class SizeBasedRollingPolicy( + var rolloverSizeBytes: Long, + checkSizeConstraint: Boolean = true // set to false while testing + ) extends RollingPolicy with Logging { + + import SizeBasedRollingPolicy._ + if (checkSizeConstraint && rolloverSizeBytes < MINIMUM_SIZE_BYTES) { + logWarning(s"Rolling size [$rolloverSizeBytes bytes] is too small. 
" + + s"Setting the size to the acceptable minimum of $MINIMUM_SIZE_BYTES bytes.") + rolloverSizeBytes = MINIMUM_SIZE_BYTES + } + + @volatile private var bytesWrittenSinceRollover = 0L + val formatter = new SimpleDateFormat("--YYYY-MM-dd--HH-mm-ss--SSSS") + + /** Should rollover if the next set of bytes is going to exceed the size limit */ + def shouldRollover(bytesToBeWritten: Long): Boolean = { + logInfo(s"$bytesToBeWritten + $bytesWrittenSinceRollover > $rolloverSizeBytes") + bytesToBeWritten + bytesWrittenSinceRollover > rolloverSizeBytes + } + + /** Rollover has occurred, so reset the counter */ + def rolledOver() { + bytesWrittenSinceRollover = 0 + } + + /** Increment the bytes that have been written in the current file */ + def bytesWritten(bytes: Long) { + bytesWrittenSinceRollover += bytes + } + + /** Get the desired name of the rollover file */ + def generateRolledOverFileSuffix(): String = { + formatter.format(Calendar.getInstance.getTime) + } +} + +private[spark] object SizeBasedRollingPolicy { + val MINIMUM_SIZE_BYTES = RollingFileAppender.DEFAULT_BUFFER_SIZE * 10 +} + diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 4dc8ada00a3e8..247f10173f1e9 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -70,7 +70,7 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false) } /** - * Return a sampler with is the complement of the range specified of the current sampler. + * Return a sampler that is the complement of the range specified of the current sampler. */ def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement) diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala new file mode 100644 index 0000000000000..a79e3ee756fc6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.random + +private[spark] object SamplingUtils { + + /** + * Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of + * the time. + * + * How the sampling rate is determined: + * Let p = num / total, where num is the sample size and total is the total number of + * datapoints in the RDD. We're trying to compute q > p such that + * - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q), + * where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to total), + * i.e. 
the failure rate of not having a sufficiently large sample < 0.0001. + * Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for + * num > 12, but we need a slightly larger q (9 empirically determined). + * - when sampling without replacement, we're drawing each datapoint with prob_i + * ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success + * rate, where success rate is defined the same as in sampling with replacement. + * + * @param sampleSizeLowerBound sample size + * @param total size of RDD + * @param withReplacement whether sampling with replacement + * @return a sampling rate that guarantees sufficient sample size with 99.99% success rate + */ + def computeFractionForSampleSize(sampleSizeLowerBound: Int, total: Long, + withReplacement: Boolean): Double = { + val fraction = sampleSizeLowerBound.toDouble / total + if (withReplacement) { + val numStDev = if (sampleSizeLowerBound < 12) 9 else 5 + fraction + numStDev * math.sqrt(fraction / total) + } else { + val delta = 1e-4 + val gamma = - math.log(delta) / total + math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)) + } + } +} diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 50a62129116f1..ef41bfb88de9d 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -317,6 +317,37 @@ public Integer call(Integer a, Integer b) { Assert.assertEquals(33, sum); } + @Test + public void aggregateByKey() { + JavaPairRDD pairs = sc.parallelizePairs( + Arrays.asList( + new Tuple2(1, 1), + new Tuple2(1, 1), + new Tuple2(3, 2), + new Tuple2(5, 1), + new Tuple2(5, 3)), 2); + + Map> sets = pairs.aggregateByKey(new HashSet(), + new Function2, Integer, Set>() { + @Override + public Set call(Set a, Integer b) { + a.add(b); + return a; + } + }, + new Function2, Set, Set>() { + @Override + public Set call(Set a, Set b) { + a.addAll(b); + return a; + } + }).collectAsMap(); + Assert.assertEquals(3, sets.size()); + Assert.assertEquals(new HashSet(Arrays.asList(1)), sets.get(1)); + Assert.assertEquals(new HashSet(Arrays.asList(2)), sets.get(3)); + Assert.assertEquals(new HashSet(Arrays.asList(1, 3)), sets.get(5)); + } + @SuppressWarnings("unchecked") @Test public void foldByKey() { diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 4e7c34e6d1ada..3aab88e9e9196 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark import scala.collection.mutable import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.SparkContext._ -class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkContext { +class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext { implicit def setAccum[A] = new AccumulableParam[mutable.Set[A], A] { diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index dc2db66df60e0..13b415cccb647 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -201,7 +201,7 @@ class ContextCleanerSuite extends FunSuite with 
BeforeAndAfter with LocalSparkCo def newPairRDD = newRDD.map(_ -> 1) def newShuffleRDD = newPairRDD.reduceByKey(_ + _) def newBroadcast = sc.broadcast(1 to 100) - def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _]]) = { + def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _, _]]) = { def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = { rdd.dependencies ++ rdd.dependencies.flatMap { dep => getAllDependencies(dep.rdd) @@ -211,8 +211,8 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo // Get all the shuffle dependencies val shuffleDeps = getAllDependencies(rdd) - .filter(_.isInstanceOf[ShuffleDependency[_, _]]) - .map(_.asInstanceOf[ShuffleDependency[_, _]]) + .filter(_.isInstanceOf[ShuffleDependency[_, _, _]]) + .map(_.asInstanceOf[ShuffleDependency[_, _, _]]) (rdd, shuffleDeps) } diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 14ddd6f1ec08f..41c294f727b3c 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark import org.scalatest.BeforeAndAfter import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts._ -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.scalatest.time.{Millis, Span} import org.apache.spark.SparkContext._ @@ -31,7 +31,7 @@ class NotSerializableClass class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} -class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter +class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter with LocalSparkContext { val clusterUrl = "local-cluster[2,1,512]" diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 2c8ef405c944c..a57430e829ced 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -25,7 +25,7 @@ import scala.concurrent.duration._ import scala.concurrent.future import org.scalatest.{BeforeAndAfter, FunSuite} -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.SparkContext._ import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} @@ -35,7 +35,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} * (e.g. count) as well as multi-job action (e.g. take). We test the local and cluster schedulers * in both FIFO and fair scheduling modes. 
*/ -class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAfter +class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter with LocalSparkContext { override def afterEach() { diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index be6508a40ea61..47112ce66d695 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.SparkContext._ import org.apache.spark.ShuffleSuite.NonJavaSerializableClass @@ -26,7 +26,7 @@ import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.MutablePair -class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { +class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { val conf = new SparkConf(loadDefaults = false) @@ -58,7 +58,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { // default Java serializer cannot handle the non serializable class. val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)]( b, new HashPartitioner(NUM_BLOCKS)).setSerializer(new KryoSerializer(conf)) - val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId assert(c.count === 10) @@ -97,7 +97,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10)) .setSerializer(new KryoSerializer(conf)) - val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => @@ -122,7 +122,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { // NOTE: The default Java serializer should create zero-sized blocks val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10)) - val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => diff --git a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala index d6b93f5fedd3b..4161aede1d1d0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.deploy import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers -class ClientSuite extends FunSuite with ShouldMatchers { +class ClientSuite extends FunSuite with Matchers { test("correctly validates driver jar URL's") { ClientArguments.isValidJarUrl("http://someHost:8080/foo.jar") should be (true) ClientArguments.isValidJarUrl("file://some/path/to/a/jarFile.jar") should be (true) diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala 
b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index bfae32dae0dc5..01ab2d549325c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.FunSuite import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, RecoveryState, WorkerInfo} import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} +import org.apache.spark.SparkConf class JsonProtocolSuite extends FunSuite { @@ -116,7 +117,8 @@ class JsonProtocolSuite extends FunSuite { } def createExecutorRunner(): ExecutorRunner = { new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host", - new File("sparkHome"), new File("workDir"), "akka://worker", ExecutorState.RUNNING) + new File("sparkHome"), new File("workDir"), "akka://worker", + new SparkConf, ExecutorState.RUNNING) } def createDriverRunner(): DriverRunner = { new DriverRunner("driverId", new File("workDir"), new File("sparkHome"), createDriverDesc(), diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 02427a4a83506..565c53e9529ff 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -25,9 +25,9 @@ import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkException, Test import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.util.Utils import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers -class SparkSubmitSuite extends FunSuite with ShouldMatchers { +class SparkSubmitSuite extends FunSuite with Matchers { def beforeAll() { System.setProperty("spark.testing", "true") } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index 8ae387fa0be6f..e5f748d55500d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -22,6 +22,7 @@ import java.io.File import org.scalatest.FunSuite import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} +import org.apache.spark.SparkConf class ExecutorRunnerTest extends FunSuite { test("command includes appId") { @@ -32,7 +33,7 @@ class ExecutorRunnerTest extends FunSuite { sparkHome, "appUiUrl") val appId = "12345-worker321-9876" val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome.getOrElse(".")), - f("ooga"), "blah", ExecutorState.RUNNING) + f("ooga"), "blah", new SparkConf, ExecutorState.RUNNING) assert(er.getCommandSeq.last === appId) } diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 9ddafc451878d..0b9004448a63e 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -30,6 +30,19 @@ import org.apache.spark.SparkContext._ import org.apache.spark.{Partitioner, SharedSparkContext} class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { + test("aggregateByKey") { + val pairs = 
sc.parallelize(Array((1, 1), (1, 1), (3, 2), (5, 1), (5, 3)), 2) + + val sets = pairs.aggregateByKey(new HashSet[Int]())(_ += _, _ ++= _).collect() + assert(sets.size === 3) + val valuesFor1 = sets.find(_._1 == 1).get._2 + assert(valuesFor1.toList.sorted === List(1)) + val valuesFor3 = sets.find(_._1 == 3).get._2 + assert(valuesFor3.toList.sorted === List(2)) + val valuesFor5 = sets.find(_._1 == 5).get._2 + assert(valuesFor5.toList.sorted === List(1, 3)) + } + test("groupByKey") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) val groups = pairs.groupByKey().collect() diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 55af1666df662..e94a1e76d410c 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.FunSuite import org.apache.spark._ import org.apache.spark.SparkContext._ -import org.apache.spark.rdd._ +import org.apache.spark.util.Utils class RDDSuite extends FunSuite with SharedSparkContext { @@ -66,6 +66,13 @@ class RDDSuite extends FunSuite with SharedSparkContext { } } + test("serialization") { + val empty = new EmptyRDD[Int](sc) + val serial = Utils.serialize(empty) + val deserial: EmptyRDD[Int] = Utils.deserialize(serial) + assert(!deserial.toString().isEmpty()) + } + test("countApproxDistinct") { def error(est: Long, size: Long) = math.abs(est - size) / size.toDouble @@ -498,55 +505,56 @@ class RDDSuite extends FunSuite with SharedSparkContext { } test("takeSample") { - val data = sc.parallelize(1 to 100, 2) + val n = 1000000 + val data = sc.parallelize(1 to n, 2) for (num <- List(5, 20, 100)) { val sample = data.takeSample(withReplacement=false, num=num) assert(sample.size === num) // Got exactly num elements assert(sample.toSet.size === num) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=false, 20, seed) assert(sample.size === 20) // Got exactly 20 elements assert(sample.toSet.size === 20) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=false, 200, seed) + val sample = data.takeSample(withReplacement=false, 100, seed) assert(sample.size === 100) // Got only 100 elements assert(sample.toSet.size === 100) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=true, 20, seed) assert(sample.size === 20) // Got exactly 20 elements - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } { val sample = data.takeSample(withReplacement=true, num=20) assert(sample.size === 20) // Got exactly 100 elements assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements") - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } { - val sample = data.takeSample(withReplacement=true, 
num=100) - assert(sample.size === 100) // Got exactly 100 elements + val sample = data.takeSample(withReplacement=true, num=n) + assert(sample.size === n) // Got exactly 100 elements // Chance of getting all distinct elements is astronomically low, so test we got < 100 - assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 100, seed) - assert(sample.size === 100) // Got exactly 100 elements + val sample = data.takeSample(withReplacement=true, n, seed) + assert(sample.size === n) // Got exactly 100 elements // Chance of getting all distinct elements is astronomically low, so test we got < 100 - assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") + assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 200, seed) - assert(sample.size === 200) // Got exactly 200 elements + val sample = data.takeSample(withReplacement=true, 2 * n, seed) + assert(sample.size === 2 * n) // Got exactly 200 elements // Chance of getting all distinct elements is still quite low, so test we got < 100 - assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") + assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") } } diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala index d0619559bb457..656917628f7a8 100644 --- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.rdd import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.{Logging, SharedSparkContext} import org.apache.spark.SparkContext._ -class SortingSuite extends FunSuite with SharedSparkContext with ShouldMatchers with Logging { +class SortingSuite extends FunSuite with SharedSparkContext with Matchers with Logging { test("sortByKey") { val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 5426e578a9ddd..be506e0287a16 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -22,13 +22,13 @@ import java.util.concurrent.Semaphore import scala.collection.mutable import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.{LocalSparkContext, SparkContext} import org.apache.spark.SparkContext._ import org.apache.spark.executor.TaskMetrics -class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers +class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers with BeforeAndAfter with BeforeAndAfterAll { /** Length of time to wait while draining listener events. 
*/ diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 81bd8257bc155..d7dbe5164b7f6 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -28,7 +28,7 @@ import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ -import org.scalatest.matchers.ShouldMatchers._ +import org.scalatest.Matchers import org.scalatest.time.SpanSugar._ import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} @@ -39,7 +39,8 @@ import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, U import scala.language.implicitConversions import scala.language.postfixOps -class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester { +class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter + with PrivateMethodTester { private val conf = new SparkConf(false) var store: BlockManager = null var store2: BlockManager = null diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 8c06a2d9aa4ab..91b4c7b0dd962 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -18,14 +18,14 @@ package org.apache.spark.ui.jobs import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.{LocalSparkContext, SparkConf, Success} import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics} import org.apache.spark.scheduler._ import org.apache.spark.util.Utils -class JobProgressListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers { +class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matchers { test("test LRU eviction of stages") { val conf = new SparkConf() conf.set("spark.ui.retainedStages", 5.toString) diff --git a/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala b/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala index 63642461e4465..090d48ec921a1 100644 --- a/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.util import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers /** * */ -class DistributionSuite extends FunSuite with ShouldMatchers { +class DistributionSuite extends FunSuite with Matchers { test("summary") { val d = new Distribution((1 to 100).toArray.map{_.toDouble}) val stats = d.statCounter diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala new file mode 100644 index 0000000000000..53d7f5c6072e6 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io._ + +import scala.collection.mutable.HashSet +import scala.reflect._ + +import org.apache.commons.io.{FileUtils, IOUtils} +import org.apache.spark.{Logging, SparkConf} +import org.scalatest.{BeforeAndAfter, FunSuite} +import org.apache.spark.util.logging.{RollingFileAppender, SizeBasedRollingPolicy, TimeBasedRollingPolicy, FileAppender} + +class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { + + val testFile = new File("FileAppenderSuite-test-" + System.currentTimeMillis).getAbsoluteFile + + before { + cleanup() + } + + after { + cleanup() + } + + test("basic file appender") { + val testString = (1 to 1000).mkString(", ") + val inputStream = IOUtils.toInputStream(testString) + val appender = new FileAppender(inputStream, testFile) + inputStream.close() + appender.awaitTermination() + assert(FileUtils.readFileToString(testFile) === testString) + } + + test("rolling file appender - time-based rolling") { + // setup input stream and appender + val testOutputStream = new PipedOutputStream() + val testInputStream = new PipedInputStream(testOutputStream, 100 * 1000) + val rolloverIntervalMillis = 100 + val durationMillis = 1000 + val numRollovers = durationMillis / rolloverIntervalMillis + val textToAppend = (1 to numRollovers).map( _.toString * 10 ) + + val appender = new RollingFileAppender(testInputStream, testFile, + new TimeBasedRollingPolicy(rolloverIntervalMillis, s"--HH-mm-ss-SSSS", false), + new SparkConf(), 10) + + testRolling(appender, testOutputStream, textToAppend, rolloverIntervalMillis) + } + + test("rolling file appender - size-based rolling") { + // setup input stream and appender + val testOutputStream = new PipedOutputStream() + val testInputStream = new PipedInputStream(testOutputStream, 100 * 1000) + val rolloverSize = 1000 + val textToAppend = (1 to 3).map( _.toString * 1000 ) + + val appender = new RollingFileAppender(testInputStream, testFile, + new SizeBasedRollingPolicy(rolloverSize, false), new SparkConf(), 99) + + val files = testRolling(appender, testOutputStream, textToAppend, 0) + files.foreach { file => + logInfo(file.toString + ": " + file.length + " bytes") + assert(file.length <= rolloverSize) + } + } + + test("rolling file appender - cleaning") { + // setup input stream and appender + val testOutputStream = new PipedOutputStream() + val testInputStream = new PipedInputStream(testOutputStream, 100 * 1000) + val conf = new SparkConf().set(RollingFileAppender.RETAINED_FILES_PROPERTY, "10") + val appender = new RollingFileAppender(testInputStream, testFile, + new SizeBasedRollingPolicy(1000, false), conf, 10) + + // send data to appender through the input stream, and wait for the data to be written + val allGeneratedFiles = new HashSet[String]() + val items = (1 to 10).map { _.toString * 10000 } + for (i <- 0 until items.size) { + testOutputStream.write(items(i).getBytes("UTF8")) + testOutputStream.flush() + 
allGeneratedFiles ++= RollingFileAppender.getSortedRolledOverFiles( + testFile.getParentFile.toString, testFile.getName).map(_.toString) + + Thread.sleep(10) + } + testOutputStream.close() + appender.awaitTermination() + logInfo("Appender closed") + + // verify whether the earliest file has been deleted + val rolledOverFiles = allGeneratedFiles.filter { _ != testFile.toString }.toArray.sorted + logInfo(s"All rolled over files generated:${rolledOverFiles.size}\n" + rolledOverFiles.mkString("\n")) + assert(rolledOverFiles.size > 2) + val earliestRolledOverFile = rolledOverFiles.head + val existingRolledOverFiles = RollingFileAppender.getSortedRolledOverFiles( + testFile.getParentFile.toString, testFile.getName).map(_.toString) + logInfo("Existing rolled over files:\n" + existingRolledOverFiles.mkString("\n")) + assert(!existingRolledOverFiles.toSet.contains(earliestRolledOverFile)) + } + + test("file appender selection") { + // Test whether FileAppender.apply() returns the right type of the FileAppender based + // on SparkConf settings. + + def testAppenderSelection[ExpectedAppender: ClassTag, ExpectedRollingPolicy]( + properties: Seq[(String, String)], expectedRollingPolicyParam: Long = -1): FileAppender = { + + // Set spark conf properties + val conf = new SparkConf + properties.foreach { p => + conf.set(p._1, p._2) + } + + // Create and test file appender + val inputStream = new PipedInputStream(new PipedOutputStream()) + val appender = FileAppender(inputStream, new File("stdout"), conf) + assert(appender.isInstanceOf[ExpectedAppender]) + assert(appender.getClass.getSimpleName === + classTag[ExpectedAppender].runtimeClass.getSimpleName) + if (appender.isInstanceOf[RollingFileAppender]) { + val rollingPolicy = appender.asInstanceOf[RollingFileAppender].rollingPolicy + rollingPolicy.isInstanceOf[ExpectedRollingPolicy] + val policyParam = if (rollingPolicy.isInstanceOf[TimeBasedRollingPolicy]) { + rollingPolicy.asInstanceOf[TimeBasedRollingPolicy].rolloverIntervalMillis + } else { + rollingPolicy.asInstanceOf[SizeBasedRollingPolicy].rolloverSizeBytes + } + assert(policyParam === expectedRollingPolicyParam) + } + appender + } + + import RollingFileAppender._ + + def rollingStrategy(strategy: String) = Seq(STRATEGY_PROPERTY -> strategy) + def rollingSize(size: String) = Seq(SIZE_PROPERTY -> size) + def rollingInterval(interval: String) = Seq(INTERVAL_PROPERTY -> interval) + + val msInDay = 24 * 60 * 60 * 1000L + val msInHour = 60 * 60 * 1000L + val msInMinute = 60 * 1000L + + // test no strategy -> no rolling + testAppenderSelection[FileAppender, Any](Seq.empty) + + // test time based rolling strategy + testAppenderSelection[RollingFileAppender, Any](rollingStrategy("time"), msInDay) + testAppenderSelection[RollingFileAppender, TimeBasedRollingPolicy]( + rollingStrategy("time") ++ rollingInterval("daily"), msInDay) + testAppenderSelection[RollingFileAppender, TimeBasedRollingPolicy]( + rollingStrategy("time") ++ rollingInterval("hourly"), msInHour) + testAppenderSelection[RollingFileAppender, TimeBasedRollingPolicy]( + rollingStrategy("time") ++ rollingInterval("minutely"), msInMinute) + testAppenderSelection[RollingFileAppender, TimeBasedRollingPolicy]( + rollingStrategy("time") ++ rollingInterval("123456789"), 123456789 * 1000L) + testAppenderSelection[FileAppender, Any]( + rollingStrategy("time") ++ rollingInterval("xyz")) + + // test size based rolling strategy + testAppenderSelection[RollingFileAppender, SizeBasedRollingPolicy]( + rollingStrategy("size") ++ rollingSize("123456789"), 
123456789) + testAppenderSelection[FileAppender, Any](rollingSize("xyz")) + + // test illegal strategy + testAppenderSelection[FileAppender, Any](rollingStrategy("xyz")) + } + + /** + * Run the rolling file appender with data and see whether all the data was written correctly + * across rolled over files. + */ + def testRolling( + appender: FileAppender, + outputStream: OutputStream, + textToAppend: Seq[String], + sleepTimeBetweenTexts: Long + ): Seq[File] = { + // send data to appender through the input stream, and wait for the data to be written + val expectedText = textToAppend.mkString("") + for (i <- 0 until textToAppend.size) { + outputStream.write(textToAppend(i).getBytes("UTF8")) + outputStream.flush() + Thread.sleep(sleepTimeBetweenTexts) + } + logInfo("Data sent to appender") + outputStream.close() + appender.awaitTermination() + logInfo("Appender closed") + + // verify whether all the data written to rolled over files is same as expected + val generatedFiles = RollingFileAppender.getSortedRolledOverFiles( + testFile.getParentFile.toString, testFile.getName) + logInfo("Filtered files: \n" + generatedFiles.mkString("\n")) + assert(generatedFiles.size > 1) + val allText = generatedFiles.map { file => + FileUtils.readFileToString(file) + }.mkString("") + assert(allText === expectedText) + generatedFiles + } + + /** Delete all the generated rolledover files */ + def cleanup() { + testFile.getParentFile.listFiles.filter { file => + file.getName.startsWith(testFile.getName) + }.foreach { _.delete() } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala index 32d74d0500b72..cf438a3d72a06 100644 --- a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala @@ -22,9 +22,9 @@ import java.util.NoSuchElementException import scala.collection.mutable.Buffer import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers -class NextIteratorSuite extends FunSuite with ShouldMatchers { +class NextIteratorSuite extends FunSuite with Matchers { test("one iteration") { val i = new StubIterator(Buffer(1)) i.hasNext should be === true diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 0aad882ed76a8..1ee936bc78f49 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -140,6 +140,38 @@ class UtilsSuite extends FunSuite { Utils.deleteRecursively(tmpDir2) } + test("reading offset bytes across multiple files") { + val tmpDir = Files.createTempDir() + tmpDir.deleteOnExit() + val files = (1 to 3).map(i => new File(tmpDir, i.toString)) + Files.write("0123456789", files(0), Charsets.UTF_8) + Files.write("abcdefghij", files(1), Charsets.UTF_8) + Files.write("ABCDEFGHIJ", files(2), Charsets.UTF_8) + + // Read first few bytes in the 1st file + assert(Utils.offsetBytes(files, 0, 5) === "01234") + + // Read bytes within the 1st file + assert(Utils.offsetBytes(files, 5, 8) === "567") + + // Read bytes across 1st and 2nd file + assert(Utils.offsetBytes(files, 8, 18) === "89abcdefgh") + + // Read bytes across 1st, 2nd and 3rd file + assert(Utils.offsetBytes(files, 5, 24) === "56789abcdefghijABCD") + + // Read some nonexistent bytes in the beginning + assert(Utils.offsetBytes(files, -5, 18) === "0123456789abcdefgh") + + // 
Read some nonexistent bytes at the end + assert(Utils.offsetBytes(files, 18, 35) === "ijABCDEFGHIJ") + + // Read some nonexistent bytes on both ends + assert(Utils.offsetBytes(files, -5, 35) === "0123456789abcdefghijABCDEFGHIJ") + + Utils.deleteRecursively(tmpDir) + } + test("deserialize long value") { val testval : Long = 9730889947L val bbuf = ByteBuffer.allocate(8) diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index b024c89d94d33..6a70877356409 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.util.collection import scala.collection.mutable.HashSet import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.util.SizeEstimator -class OpenHashMapSuite extends FunSuite with ShouldMatchers { +class OpenHashMapSuite extends FunSuite with Matchers { test("size for specialized, primitive value (int)") { val capacity = 1024 diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index ff4a98f5dcd4a..68a03e3a0970f 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.util.collection import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.util.SizeEstimator -class OpenHashSetSuite extends FunSuite with ShouldMatchers { +class OpenHashSetSuite extends FunSuite with Matchers { test("size for specialized, primitive int") { val loadFactor = 0.7 diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala index e3fca173908e9..8c7df7d73dcd3 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.util.collection import scala.collection.mutable.HashSet import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.util.SizeEstimator -class PrimitiveKeyOpenHashMapSuite extends FunSuite with ShouldMatchers { +class PrimitiveKeyOpenHashMapSuite extends FunSuite with Matchers { test("size for specialized, primitive key, value (int, int)") { val capacity = 1024 diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala new file mode 100644 index 0000000000000..accfe2e9b7f2a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.random + +import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution} +import org.scalatest.FunSuite + +class SamplingUtilsSuite extends FunSuite { + + test("computeFraction") { + // test that the computed fraction guarantees enough data points + // in the sample with a failure rate <= 0.0001 + val n = 100000 + + for (s <- 1 to 15) { + val frac = SamplingUtils.computeFractionForSampleSize(s, n, true) + val poisson = new PoissonDistribution(frac * n) + assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") + } + for (s <- List(20, 100, 1000)) { + val frac = SamplingUtils.computeFractionForSampleSize(s, n, true) + val poisson = new PoissonDistribution(frac * n) + assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") + } + for (s <- List(1, 10, 100, 1000)) { + val frac = SamplingUtils.computeFractionForSampleSize(s, n, false) + val binomial = new BinomialDistribution(n, frac) + assert(binomial.inverseCumulativeProbability(0.0001)*n >= s, "Computed fraction is too low") + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index 0865c6386f7cd..e15fd59a5a8bb 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.util.random import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.util.Utils.times import scala.language.reflectiveCalls -class XORShiftRandomSuite extends FunSuite with ShouldMatchers { +class XORShiftRandomSuite extends FunSuite with Matchers { def fixture = new { val seed = 1L diff --git a/dev/mima b/dev/mima index ab6bd4469b0e8..b68800d6d0173 100755 --- a/dev/mima +++ b/dev/mima @@ -23,6 +23,9 @@ set -o pipefail FWDIR="$(cd `dirname $0`/..; pwd)" cd $FWDIR +echo -e "q\n" | sbt/sbt oldDeps/update + +export SPARK_CLASSPATH=`find lib_managed \( -name '*spark*jar' -a -type f \) -printf "%p:" ` ./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore echo -e "q\n" | sbt/sbt mima-report-binary-issues | grep -v -e "info.*Resolving" ret_val=$? @@ -31,5 +34,5 @@ if [ $ret_val != 0 ]; then echo "NOTE: Exceptions to binary compatibility can be added in project/MimaExcludes.scala" fi -rm -f .generated-mima-excludes +rm -f .generated-mima* exit $ret_val diff --git a/docs/configuration.md b/docs/configuration.md index 71fafa573467f..b84104cc7e653 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -784,6 +784,45 @@ Apart from these, the following properties are also available, and may be useful higher memory usage in Spark. 
+ + spark.executor.logs.rolling.strategy + (none) + + Set the strategy of rolling of executor logs. By default it is disabled. It can + be set to "time" (time-based rolling) or "size" (size-based rolling). For "time", + use spark.executor.logs.rolling.time.interval to set the rolling interval. + For "size", use spark.executor.logs.rolling.size.maxBytes to set + the maximum file size for rolling. + + + + spark.executor.logs.rolling.time.interval + daily + + Set the time interval by which the executor logs will be rolled over. + Rolling is disabled by default. Valid values are `daily`, `hourly`, `minutely` or + any interval in seconds. See spark.executor.logs.rolling.maxRetainedFiles + for automatic cleaning of old logs. + + + + spark.executor.logs.rolling.size.maxBytes + (none) + + Set the max size of the file by which the executor logs will be rolled over. + Rolling is disabled by default. Value is set in terms of bytes. + See spark.executor.logs.rolling.maxRetainedFiles + for automatic cleaning of old logs. + + + + spark.executor.logs.rolling.maxRetainedFiles + (none) + + Sets the number of latest rolling log files that are going to be retained by the system. + Older log files will be deleted. Disabled by default. + + #### Cluster Managers diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 7989e02dfb732..79784682bfd1b 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -890,6 +890,10 @@ for details. reduceByKey(func, [numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, V) pairs where the values for each key are aggregated using the given reduce function. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. + + aggregateByKey(zeroValue)(seqOp, combOp, [numTasks]) + When called on a dataset of (K, V) pairs, returns a dataset of (K, U) pairs where the values for each key are aggregated using the given combine functions and a neutral "zero" value. Allows an aggregated value type that is different than the input value type, while avoiding unnecessary allocations. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. + sortByKey([ascending], [numTasks]) When called on a dataset of (K, V) pairs where K implements Ordered, returns a dataset of (K, V) pairs sorted by keys in ascending or descending order, as specified in the boolean ascending argument. 
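For reference, a minimal Scala sketch of the aggregateByKey transformation described in the programming-guide entry above, mirroring the new PairRDDFunctionsSuite test; it assumes an existing SparkContext named sc and is illustrative only, not part of the patch. The aggregated value type (a mutable HashSet[Int]) differs from the input value type (Int), which is exactly the case aggregateByKey is designed for:

    import scala.collection.mutable.HashSet

    // (K, V) pairs split across two partitions so both seqOp and combOp run
    val pairs = sc.parallelize(Seq((1, 1), (1, 1), (3, 2), (5, 1), (5, 3)), 2)

    // zeroValue: a fresh empty set per key
    // seqOp:     fold one value into the partition-local set
    // combOp:    merge partition-local sets for the same key
    val sets = pairs
      .aggregateByKey(new HashSet[Int]())(_ += _, _ ++= _)
      .collect()

    sets.foreach { case (k, vs) => println(s"$k -> ${vs.toList.sorted.mkString(", ")}") }
    // expected: 1 -> 1; 3 -> 2; 5 -> 1, 3 (key ordering may vary)

As with reduceByKey, an optional numPartitions argument controls the number of reduce tasks.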
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 6eb41e7ba36fb..28e201d279f41 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -50,6 +50,8 @@ object MovieLensALS { numIterations: Int = 20, lambda: Double = 1.0, rank: Int = 10, + numUserBlocks: Int = -1, + numProductBlocks: Int = -1, implicitPrefs: Boolean = false) def main(args: Array[String]) { @@ -67,8 +69,14 @@ object MovieLensALS { .text(s"lambda (smoothing constant), default: ${defaultParams.lambda}") .action((x, c) => c.copy(lambda = x)) opt[Unit]("kryo") - .text(s"use Kryo serialization") + .text("use Kryo serialization") .action((_, c) => c.copy(kryo = true)) + opt[Int]("numUserBlocks") + .text(s"number of user blocks, default: ${defaultParams.numUserBlocks} (auto)") + .action((x, c) => c.copy(numUserBlocks = x)) + opt[Int]("numProductBlocks") + .text(s"number of product blocks, default: ${defaultParams.numProductBlocks} (auto)") + .action((x, c) => c.copy(numProductBlocks = x)) opt[Unit]("implicitPrefs") .text("use implicit preference") .action((_, c) => c.copy(implicitPrefs = true)) @@ -160,6 +168,8 @@ object MovieLensALS { .setIterations(params.numIterations) .setLambda(params.lambda) .setImplicitPrefs(params.implicitPrefs) + .setUserBlocks(params.numUserBlocks) + .setProductBlocks(params.numProductBlocks) .run(training) val rmse = computeRmse(model, test, params.implicitPrefs) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 5be33f1d5c428..ed35e34ad45ab 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -71,12 +71,12 @@ class SparkFlumeEvent() extends Externalizable { for (i <- 0 until numHeaders) { val keyLength = in.readInt() val keyBuff = new Array[Byte](keyLength) - in.read(keyBuff) + in.readFully(keyBuff) val key : String = Utils.deserialize(keyBuff) val valLength = in.readInt() val valBuff = new Array[Byte](valLength) - in.read(valBuff) + in.readFully(valBuff) val value : String = Utils.deserialize(valBuff) headers.put(key, value) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index d743bd7dd1825..cc56fd6ef28d6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -61,7 +61,7 @@ private[recommendation] case class InLinkBlock( * A more compact class to represent a rating than Tuple3[Int, Int, Double]. */ @Experimental -case class Rating(val user: Int, val product: Int, val rating: Double) +case class Rating(user: Int, product: Int, rating: Double) /** * Alternating Least Squares matrix factorization. @@ -93,7 +93,8 @@ case class Rating(val user: Int, val product: Int, val rating: Double) * preferences rather than explicit ratings given to items. 
*/ class ALS private ( - private var numBlocks: Int, + private var numUserBlocks: Int, + private var numProductBlocks: Int, private var rank: Int, private var iterations: Int, private var lambda: Double, @@ -106,14 +107,31 @@ class ALS private ( * Constructs an ALS instance with default parameters: {numBlocks: -1, rank: 10, iterations: 10, * lambda: 0.01, implicitPrefs: false, alpha: 1.0}. */ - def this() = this(-1, 10, 10, 0.01, false, 1.0) + def this() = this(-1, -1, 10, 10, 0.01, false, 1.0) /** - * Set the number of blocks to parallelize the computation into; pass -1 for an auto-configured - * number of blocks. Default: -1. + * Set the number of blocks for both user blocks and product blocks to parallelize the computation + * into; pass -1 for an auto-configured number of blocks. Default: -1. */ def setBlocks(numBlocks: Int): ALS = { - this.numBlocks = numBlocks + this.numUserBlocks = numBlocks + this.numProductBlocks = numBlocks + this + } + + /** + * Set the number of user blocks to parallelize the computation. + */ + def setUserBlocks(numUserBlocks: Int): ALS = { + this.numUserBlocks = numUserBlocks + this + } + + /** + * Set the number of product blocks to parallelize the computation. + */ + def setProductBlocks(numProductBlocks: Int): ALS = { + this.numProductBlocks = numProductBlocks this } @@ -176,31 +194,32 @@ class ALS private ( def run(ratings: RDD[Rating]): MatrixFactorizationModel = { val sc = ratings.context - val numBlocks = if (this.numBlocks == -1) { + val numUserBlocks = if (this.numUserBlocks == -1) { math.max(sc.defaultParallelism, ratings.partitions.size / 2) } else { - this.numBlocks + this.numUserBlocks } - - val partitioner = new Partitioner { - val numPartitions = numBlocks - - def getPartition(x: Any): Int = { - Utils.nonNegativeMod(byteswap32(x.asInstanceOf[Int]), numPartitions) - } + val numProductBlocks = if (this.numProductBlocks == -1) { + math.max(sc.defaultParallelism, ratings.partitions.size / 2) + } else { + this.numProductBlocks } - val ratingsByUserBlock = ratings.map{ rating => - (partitioner.getPartition(rating.user), rating) + val userPartitioner = new ALSPartitioner(numUserBlocks) + val productPartitioner = new ALSPartitioner(numProductBlocks) + + val ratingsByUserBlock = ratings.map { rating => + (userPartitioner.getPartition(rating.user), rating) } - val ratingsByProductBlock = ratings.map{ rating => - (partitioner.getPartition(rating.product), + val ratingsByProductBlock = ratings.map { rating => + (productPartitioner.getPartition(rating.product), Rating(rating.product, rating.user, rating.rating)) } - val (userInLinks, userOutLinks) = makeLinkRDDs(numBlocks, ratingsByUserBlock, partitioner) + val (userInLinks, userOutLinks) = + makeLinkRDDs(numUserBlocks, numProductBlocks, ratingsByUserBlock, productPartitioner) val (productInLinks, productOutLinks) = - makeLinkRDDs(numBlocks, ratingsByProductBlock, partitioner) + makeLinkRDDs(numProductBlocks, numUserBlocks, ratingsByProductBlock, userPartitioner) userInLinks.setName("userInLinks") userOutLinks.setName("userOutLinks") productInLinks.setName("productInLinks") @@ -232,27 +251,27 @@ class ALS private ( users.setName(s"users-$iter").persist() val YtY = Some(sc.broadcast(computeYtY(users))) val previousProducts = products - products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda, - alpha, YtY) + products = updateFeatures(numProductBlocks, users, userOutLinks, productInLinks, + userPartitioner, rank, lambda, alpha, YtY) previousProducts.unpersist() 
logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations)) products.setName(s"products-$iter").persist() val XtX = Some(sc.broadcast(computeYtY(products))) val previousUsers = users - users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda, - alpha, XtX) + users = updateFeatures(numUserBlocks, products, productOutLinks, userInLinks, + productPartitioner, rank, lambda, alpha, XtX) previousUsers.unpersist() } } else { for (iter <- 1 to iterations) { // perform ALS update logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations)) - products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda, - alpha, YtY = None) + products = updateFeatures(numProductBlocks, users, userOutLinks, productInLinks, + userPartitioner, rank, lambda, alpha, YtY = None) products.setName(s"products-$iter") logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations)) - users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda, - alpha, YtY = None) + users = updateFeatures(numUserBlocks, products, productOutLinks, userInLinks, + productPartitioner, rank, lambda, alpha, YtY = None) users.setName(s"users-$iter") } } @@ -340,9 +359,10 @@ class ALS private ( /** * Flatten out blocked user or product factors into an RDD of (id, factor vector) pairs */ - private def unblockFactors(blockedFactors: RDD[(Int, Array[Array[Double]])], - outLinks: RDD[(Int, OutLinkBlock)]) = { - blockedFactors.join(outLinks).flatMap{ case (b, (factors, outLinkBlock)) => + private def unblockFactors( + blockedFactors: RDD[(Int, Array[Array[Double]])], + outLinks: RDD[(Int, OutLinkBlock)]): RDD[(Int, Array[Double])] = { + blockedFactors.join(outLinks).flatMap { case (b, (factors, outLinkBlock)) => for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i)) } } @@ -351,14 +371,14 @@ class ALS private ( * Make the out-links table for a block of the users (or products) dataset given the list of * (user, product, rating) values for the users in that block (or the opposite for products). */ - private def makeOutLinkBlock(numBlocks: Int, ratings: Array[Rating], - partitioner: Partitioner): OutLinkBlock = { + private def makeOutLinkBlock(numProductBlocks: Int, ratings: Array[Rating], + productPartitioner: Partitioner): OutLinkBlock = { val userIds = ratings.map(_.user).distinct.sorted val numUsers = userIds.length val userIdToPos = userIds.zipWithIndex.toMap - val shouldSend = Array.fill(numUsers)(new BitSet(numBlocks)) + val shouldSend = Array.fill(numUsers)(new BitSet(numProductBlocks)) for (r <- ratings) { - shouldSend(userIdToPos(r.user))(partitioner.getPartition(r.product)) = true + shouldSend(userIdToPos(r.user))(productPartitioner.getPartition(r.product)) = true } OutLinkBlock(userIds, shouldSend) } @@ -367,18 +387,17 @@ class ALS private ( * Make the in-links table for a block of the users (or products) dataset given a list of * (user, product, rating) values for the users in that block (or the opposite for products). 
*/ - private def makeInLinkBlock(numBlocks: Int, ratings: Array[Rating], - partitioner: Partitioner): InLinkBlock = { + private def makeInLinkBlock(numProductBlocks: Int, ratings: Array[Rating], + productPartitioner: Partitioner): InLinkBlock = { val userIds = ratings.map(_.user).distinct.sorted - val numUsers = userIds.length val userIdToPos = userIds.zipWithIndex.toMap // Split out our ratings by product block - val blockRatings = Array.fill(numBlocks)(new ArrayBuffer[Rating]) + val blockRatings = Array.fill(numProductBlocks)(new ArrayBuffer[Rating]) for (r <- ratings) { - blockRatings(partitioner.getPartition(r.product)) += r + blockRatings(productPartitioner.getPartition(r.product)) += r } - val ratingsForBlock = new Array[Array[(Array[Int], Array[Double])]](numBlocks) - for (productBlock <- 0 until numBlocks) { + val ratingsForBlock = new Array[Array[(Array[Int], Array[Double])]](numProductBlocks) + for (productBlock <- 0 until numProductBlocks) { // Create an array of (product, Seq(Rating)) ratings val groupedRatings = blockRatings(productBlock).groupBy(_.product).toArray // Sort them by product ID @@ -400,14 +419,16 @@ class ALS private ( * the users (or (blockId, (p, u, r)) for the products). We create these simultaneously to avoid * having to shuffle the (blockId, (u, p, r)) RDD twice, or to cache it. */ - private def makeLinkRDDs(numBlocks: Int, ratings: RDD[(Int, Rating)], partitioner: Partitioner) - : (RDD[(Int, InLinkBlock)], RDD[(Int, OutLinkBlock)]) = - { - val grouped = ratings.partitionBy(new HashPartitioner(numBlocks)) + private def makeLinkRDDs( + numUserBlocks: Int, + numProductBlocks: Int, + ratingsByUserBlock: RDD[(Int, Rating)], + productPartitioner: Partitioner): (RDD[(Int, InLinkBlock)], RDD[(Int, OutLinkBlock)]) = { + val grouped = ratingsByUserBlock.partitionBy(new HashPartitioner(numUserBlocks)) val links = grouped.mapPartitionsWithIndex((blockId, elements) => { - val ratings = elements.map{_._2}.toArray - val inLinkBlock = makeInLinkBlock(numBlocks, ratings, partitioner) - val outLinkBlock = makeOutLinkBlock(numBlocks, ratings, partitioner) + val ratings = elements.map(_._2).toArray + val inLinkBlock = makeInLinkBlock(numProductBlocks, ratings, productPartitioner) + val outLinkBlock = makeOutLinkBlock(numProductBlocks, ratings, productPartitioner) Iterator.single((blockId, (inLinkBlock, outLinkBlock))) }, true) val inLinks = links.mapValues(_._1) @@ -439,26 +460,24 @@ class ALS private ( * It returns an RDD of new feature vectors for each user block. 
*/ private def updateFeatures( + numUserBlocks: Int, products: RDD[(Int, Array[Array[Double]])], productOutLinks: RDD[(Int, OutLinkBlock)], userInLinks: RDD[(Int, InLinkBlock)], - partitioner: Partitioner, + productPartitioner: Partitioner, rank: Int, lambda: Double, alpha: Double, - YtY: Option[Broadcast[DoubleMatrix]]) - : RDD[(Int, Array[Array[Double]])] = - { - val numBlocks = products.partitions.size + YtY: Option[Broadcast[DoubleMatrix]]): RDD[(Int, Array[Array[Double]])] = { productOutLinks.join(products).flatMap { case (bid, (outLinkBlock, factors)) => - val toSend = Array.fill(numBlocks)(new ArrayBuffer[Array[Double]]) - for (p <- 0 until outLinkBlock.elementIds.length; userBlock <- 0 until numBlocks) { + val toSend = Array.fill(numUserBlocks)(new ArrayBuffer[Array[Double]]) + for (p <- 0 until outLinkBlock.elementIds.length; userBlock <- 0 until numUserBlocks) { if (outLinkBlock.shouldSend(p)(userBlock)) { toSend(userBlock) += factors(p) } } toSend.zipWithIndex.map{ case (buf, idx) => (idx, (bid, buf.toArray)) } - }.groupByKey(partitioner) + }.groupByKey(productPartitioner) .join(userInLinks) .mapValues{ case (messages, inLinkBlock) => updateBlock(messages, inLinkBlock, rank, lambda, alpha, YtY) @@ -475,7 +494,7 @@ class ALS private ( { // Sort the incoming block factor messages by block ID and make them an array val blockFactors = messages.toSeq.sortBy(_._1).map(_._2).toArray // Array[Array[Double]] - val numBlocks = blockFactors.length + val numProductBlocks = blockFactors.length val numUsers = inLinkBlock.elementIds.length // We'll sum up the XtXes using vectors that represent only the lower-triangular part, since @@ -488,9 +507,12 @@ class ALS private ( val tempXtX = DoubleMatrix.zeros(triangleSize) val fullXtX = DoubleMatrix.zeros(rank, rank) + // Count the number of ratings each user gives to provide user-specific regularization + val numRatings = Array.fill(numUsers)(0) + // Compute the XtX and Xy values for each user by adding products it rated in each product // block - for (productBlock <- 0 until numBlocks) { + for (productBlock <- 0 until numProductBlocks) { var p = 0 while (p < blockFactors(productBlock).length) { val x = wrapDoubleArray(blockFactors(productBlock)(p)) @@ -500,6 +522,7 @@ class ALS private ( if (implicitPrefs) { var i = 0 while (i < us.length) { + numRatings(us(i)) += 1 // Extension to the original paper to handle rs(i) < 0. confidence is a function // of |rs(i)| instead so that it is never negative: val confidence = 1 + alpha * abs(rs(i)) @@ -515,6 +538,7 @@ class ALS private ( } else { var i = 0 while (i < us.length) { + numRatings(us(i)) += 1 userXtX(us(i)).addi(tempXtX) SimpleBlas.axpy(rs(i), x, userXy(us(i))) i += 1 @@ -531,9 +555,10 @@ class ALS private ( // Compute the full XtX matrix from the lower-triangular part we got above fillFullMatrix(userXtX(index), fullXtX) // Add regularization + val regParam = numRatings(index) * lambda var i = 0 while (i < rank) { - fullXtX.data(i * rank + i) += lambda + fullXtX.data(i * rank + i) += regParam i += 1 } // Solve the resulting matrix, which is symmetric and positive-definite @@ -579,6 +604,23 @@ class ALS private ( } } +/** + * Partitioner for ALS. 
+ */ +private[recommendation] class ALSPartitioner(override val numPartitions: Int) extends Partitioner { + override def getPartition(key: Any): Int = { + Utils.nonNegativeMod(byteswap32(key.asInstanceOf[Int]), numPartitions) + } + + override def equals(obj: Any): Boolean = { + obj match { + case p: ALSPartitioner => + this.numPartitions == p.numPartitions + case _ => + false + } + } +} /** * Top-level methods for calling Alternating Least Squares (ALS) matrix factorization. @@ -606,7 +648,7 @@ object ALS { blocks: Int, seed: Long ): MatrixFactorizationModel = { - new ALS(blocks, rank, iterations, lambda, false, 1.0, seed).run(ratings) + new ALS(blocks, blocks, rank, iterations, lambda, false, 1.0, seed).run(ratings) } /** @@ -629,7 +671,7 @@ object ALS { lambda: Double, blocks: Int ): MatrixFactorizationModel = { - new ALS(blocks, rank, iterations, lambda, false, 1.0).run(ratings) + new ALS(blocks, blocks, rank, iterations, lambda, false, 1.0).run(ratings) } /** @@ -689,7 +731,7 @@ object ALS { alpha: Double, seed: Long ): MatrixFactorizationModel = { - new ALS(blocks, rank, iterations, lambda, true, alpha, seed).run(ratings) + new ALS(blocks, blocks, rank, iterations, lambda, true, alpha, seed).run(ratings) } /** @@ -714,7 +756,7 @@ object ALS { blocks: Int, alpha: Double ): MatrixFactorizationModel = { - new ALS(blocks, rank, iterations, lambda, true, alpha).run(ratings) + new ALS(blocks, blocks, rank, iterations, lambda, true, alpha).run(ratings) } /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 4d7b984e3ec29..44b757b6a1fb7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -21,7 +21,7 @@ import scala.util.Random import scala.collection.JavaConversions._ import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ @@ -56,7 +56,7 @@ object LogisticRegressionSuite { } } -class LogisticRegressionSuite extends FunSuite with LocalSparkContext with ShouldMatchers { +class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Matchers { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => prediction != expected.label diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index 8a16284118cf7..951b4f7c6e6f4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -21,7 +21,7 @@ import scala.util.Random import scala.collection.JavaConversions._ import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.LocalSparkContext @@ -61,7 +61,7 @@ object GradientDescentSuite { } } -class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMatchers { +class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers { test("Assert the loss 
is decreasing.") { val nPoints = 10000 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index 820eca9b1bf65..4b1850659a18e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.mllib.optimization import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LocalSparkContext -class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { +class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { val nPoints = 10000 val A = 2.0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index 37c9b9d085841..81bebec8c7a39 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -121,6 +121,10 @@ class ALSSuite extends FunSuite with LocalSparkContext { testALS(100, 200, 2, 15, 0.7, 0.4, true, false, true) } + test("rank-2 matrices with different user and product blocks") { + testALS(100, 200, 2, 15, 0.7, 0.4, numUserBlocks = 4, numProductBlocks = 2) + } + test("pseudorandomness") { val ratings = sc.parallelize(ALSSuite.generateRatings(10, 20, 5, 0.5, false, false)._1, 2) val model11 = ALS.train(ratings, 5, 1, 1.0, 2, 1) @@ -153,35 +157,52 @@ class ALSSuite extends FunSuite with LocalSparkContext { } test("NNALS, rank 2") { - testALS(100, 200, 2, 15, 0.7, 0.4, false, false, false, -1, false) + testALS(100, 200, 2, 15, 0.7, 0.4, false, false, false, -1, -1, false) } /** * Test if we can correctly factorize R = U * P where U and P are of known rank. 
* - * @param users number of users - * @param products number of products - * @param features number of features (rank of problem) - * @param iterations number of iterations to run - * @param samplingRate what fraction of the user-product pairs are known + * @param users number of users + * @param products number of products + * @param features number of features (rank of problem) + * @param iterations number of iterations to run + * @param samplingRate what fraction of the user-product pairs are known * @param matchThreshold max difference allowed to consider a predicted rating correct - * @param implicitPrefs flag to test implicit feedback - * @param bulkPredict flag to test bulk prediciton + * @param implicitPrefs flag to test implicit feedback + * @param bulkPredict flag to test bulk prediciton * @param negativeWeights whether the generated data can contain negative values - * @param numBlocks number of blocks to partition users and products into + * @param numUserBlocks number of user blocks to partition users into + * @param numProductBlocks number of product blocks to partition products into * @param negativeFactors whether the generated user/product factors can have negative entries */ - def testALS(users: Int, products: Int, features: Int, iterations: Int, - samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false, - bulkPredict: Boolean = false, negativeWeights: Boolean = false, numBlocks: Int = -1, - negativeFactors: Boolean = true) - { + def testALS( + users: Int, + products: Int, + features: Int, + iterations: Int, + samplingRate: Double, + matchThreshold: Double, + implicitPrefs: Boolean = false, + bulkPredict: Boolean = false, + negativeWeights: Boolean = false, + numUserBlocks: Int = -1, + numProductBlocks: Int = -1, + negativeFactors: Boolean = true) { val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products, features, samplingRate, implicitPrefs, negativeWeights, negativeFactors) - val model = (new ALS().setBlocks(numBlocks).setRank(features).setIterations(iterations) - .setAlpha(1.0).setImplicitPrefs(implicitPrefs).setLambda(0.01).setSeed(0L) - .setNonnegative(!negativeFactors).run(sc.parallelize(sampledRatings))) + val model = new ALS() + .setUserBlocks(numUserBlocks) + .setProductBlocks(numProductBlocks) + .setRank(features) + .setIterations(iterations) + .setAlpha(1.0) + .setImplicitPrefs(implicitPrefs) + .setLambda(0.01) + .setSeed(0L) + .setNonnegative(!negativeFactors) + .run(sc.parallelize(sampledRatings)) val predictedU = new DoubleMatrix(users, features) for ((u, vec) <- model.userFeatures.collect(); i <- 0 until features) { @@ -208,8 +229,9 @@ class ALSSuite extends FunSuite with LocalSparkContext { val prediction = predictedRatings.get(u, p) val correct = trueRatings.get(u, p) if (math.abs(prediction - correct) > matchThreshold) { - fail("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format( - u, p, correct, prediction, trueRatings, predictedRatings, predictedU, predictedP)) + fail(("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s") + .format(u, p, correct, prediction, trueRatings, predictedRatings, predictedU, + predictedP)) } } } else { diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 1477809943573..bb2d73741c3bf 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -15,16 +15,26 @@ * limitations under the License. 
*/ -import com.typesafe.tools.mima.core.{MissingTypesProblem, MissingClassProblem, ProblemFilters} +import com.typesafe.tools.mima.core._ +import com.typesafe.tools.mima.core.MissingClassProblem +import com.typesafe.tools.mima.core.MissingTypesProblem import com.typesafe.tools.mima.core.ProblemFilters._ import com.typesafe.tools.mima.plugin.MimaKeys.{binaryIssueFilters, previousArtifact} import com.typesafe.tools.mima.plugin.MimaPlugin.mimaDefaultSettings import sbt._ object MimaBuild { + + def excludeMember(fullName: String) = Seq( + ProblemFilters.exclude[MissingMethodProblem](fullName), + ProblemFilters.exclude[MissingFieldProblem](fullName), + ProblemFilters.exclude[IncompatibleResultTypeProblem](fullName), + ProblemFilters.exclude[IncompatibleMethTypeProblem](fullName), + ProblemFilters.exclude[IncompatibleFieldTypeProblem](fullName) + ) + // Exclude a single class and its corresponding object - def excludeClass(className: String) = { - Seq( + def excludeClass(className: String) = Seq( excludePackage(className), ProblemFilters.exclude[MissingClassProblem](className), ProblemFilters.exclude[MissingTypesProblem](className), @@ -32,7 +42,7 @@ object MimaBuild { ProblemFilters.exclude[MissingClassProblem](className + "$"), ProblemFilters.exclude[MissingTypesProblem](className + "$") ) - } + // Exclude a Spark class, that is in the package org.apache.spark def excludeSparkClass(className: String) = { excludeClass("org.apache.spark." + className) @@ -49,20 +59,25 @@ object MimaBuild { val defaultExcludes = Seq() // Read package-private excludes from file - val excludeFilePath = (base.getAbsolutePath + "/.generated-mima-excludes") - val excludeFile = file(excludeFilePath) + val classExcludeFilePath = file(base.getAbsolutePath + "/.generated-mima-class-excludes") + val memberExcludeFilePath = file(base.getAbsolutePath + "/.generated-mima-member-excludes") + val ignoredClasses: Seq[String] = - if (!excludeFile.exists()) { + if (!classExcludeFilePath.exists()) { Seq() } else { - IO.read(excludeFile).split("\n") + IO.read(classExcludeFilePath).split("\n") } + val ignoredMembers: Seq[String] = + if (!memberExcludeFilePath.exists()) { + Seq() + } else { + IO.read(memberExcludeFilePath).split("\n") + } - - val externalExcludeFileClasses = ignoredClasses.flatMap(excludeClass) - - defaultExcludes ++ externalExcludeFileClasses ++ MimaExcludes.excludes + defaultExcludes ++ ignoredClasses.flatMap(excludeClass) ++ + ignoredMembers.flatMap(excludeMember) ++ MimaExcludes.excludes } def mimaSettings(sparkHome: File) = mimaDefaultSettings ++ Seq( diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index dd7efceb23c96..042fdfcc47261 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -52,11 +52,27 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"), ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1") + "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.MemoryStore.Entry"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$" + + "createZero$1") + ) ++ + Seq( // Ignore some private methods in ALS. 
+ ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), + ProblemFilters.exclude[MissingMethodProblem]( // The only public constructor is the one without arguments. + "org.apache.spark.mllib.recommendation.ALS.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7") ) ++ MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++ MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++ - MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") + MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++ + MimaBuild.excludeSparkClass("storage.Values") ++ + MimaBuild.excludeSparkClass("storage.Entry") ++ + MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") case v if v.startsWith("1.0") => Seq( MimaBuild.excludeSparkPackage("api.java"), diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 069913dbaac56..2d60a44f04f6f 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -59,6 +59,10 @@ object SparkBuild extends Build { lazy val core = Project("core", file("core"), settings = coreSettings) + /** Following project only exists to pull previous artifacts of Spark for generating + Mima ignores. For more information see: SPARK 2071 */ + lazy val oldDeps = Project("oldDeps", file("dev"), settings = oldDepsSettings) + def replDependencies = Seq[ProjectReference](core, graphx, bagel, mllib, sql) ++ maybeHiveRef lazy val repl = Project("repl", file("repl"), settings = replSettings) @@ -86,7 +90,16 @@ object SparkBuild extends Build { lazy val assemblyProj = Project("assembly", file("assembly"), settings = assemblyProjSettings) .dependsOn(core, graphx, bagel, mllib, streaming, repl, sql) dependsOn(maybeYarn: _*) dependsOn(maybeHive: _*) dependsOn(maybeGanglia: _*) - lazy val assembleDeps = TaskKey[Unit]("assemble-deps", "Build assembly of dependencies and packages Spark projects") + lazy val assembleDepsTask = TaskKey[Unit]("assemble-deps") + lazy val assembleDeps = assembleDepsTask := { + println() + println("**** NOTE ****") + println("'sbt/sbt assemble-deps' is no longer supported.") + println("Instead create a normal assembly and:") + println(" export SPARK_PREPEND_CLASSES=1 (toggle on)") + println(" unset SPARK_PREPEND_CLASSES (toggle off)") + println() + } // A configuration to set an alternative publishLocalConfiguration lazy val MavenCompile = config("m2r") extend(Compile) @@ -336,6 +349,7 @@ object SparkBuild extends Build { libraryDependencies ++= Seq( "com.google.guava" % "guava" % "14.0.1", "org.apache.commons" % "commons-lang3" % "3.3.2", + "org.apache.commons" % "commons-math3" % "3.3" % "test", "com.google.code.findbugs" % "jsr305" % "1.3.9", "log4j" % "log4j" % "1.2.17", "org.slf4j" % "slf4j-api" % slf4jVersion, @@ -369,6 +383,7 @@ object SparkBuild extends Build { "net.sf.py4j" % "py4j" % "0.8.1" ), libraryDependencies ++= maybeAvro, + assembleDeps, previousArtifact := sparkPreviousArtifact("spark-core") ) @@ -580,9 +595,7 @@ object SparkBuild extends Build { def assemblyProjSettings = sharedSettings ++ Seq( name := "spark-assembly", - assembleDeps in Compile <<= (packageProjects.map(packageBin in Compile in _) ++ Seq(packageDependency in Compile)).dependOn, - jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" }, - jarName in packageDependency <<= version map { v => "spark-assembly-" + v + "-hadoop" + 
hadoopVersion + "-deps.jar" } + jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" } ) ++ assemblySettings ++ extraAssemblySettings def extraAssemblySettings() = Seq( @@ -598,6 +611,17 @@ object SparkBuild extends Build { } ) + def oldDepsSettings() = Defaults.defaultSettings ++ Seq( + name := "old-deps", + scalaVersion := "2.10.4", + retrieveManaged := true, + retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", + libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", + "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter", + "spark-streaming", "spark-mllib", "spark-bagel", "spark-graphx", + "spark-core").map(sparkPreviousArtifact(_).get intransitive()) + ) + def twitterSettings() = sharedSettings ++ Seq( name := "spark-streaming-twitter", previousArtifact := sparkPreviousArtifact("spark-streaming-twitter"), diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index a411a5d5914e0..e609b60a0f968 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -454,7 +454,7 @@ def _squared_distance(v1, v2): v2 = _convert_vector(v2) if type(v1) == ndarray and type(v2) == ndarray: diff = v1 - v2 - return diff.dot(diff) + return numpy.dot(diff, diff) elif type(v1) == ndarray: return v2.squared_distance(v1) else: @@ -469,10 +469,12 @@ def _dot(vec, target): calling numpy.dot of the two vectors, but for SciPy ones, we have to transpose them because they're column vectors. """ - if type(vec) == ndarray or type(vec) == SparseVector: + if type(vec) == ndarray: + return numpy.dot(vec, target) + elif type(vec) == SparseVector: return vec.dot(target) elif type(vec) == list: - return _convert_vector(vec).dot(target) + return numpy.dot(_convert_vector(vec), target) else: return vec.transpose().dot(target)[0] diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 9c69c79236edc..ddd22850a819c 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -31,6 +31,7 @@ import warnings import heapq from random import Random +from math import sqrt, log from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ @@ -202,9 +203,9 @@ def cache(self): def persist(self, storageLevel): """ - Set this RDD's storage level to persist its values across operations after the first time - it is computed. This can only be used to assign a new storage level if the RDD does not - have a storage level set yet. + Set this RDD's storage level to persist its values across operations + after the first time it is computed. This can only be used to assign + a new storage level if the RDD does not have a storage level set yet. """ self.is_cached = True javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel) @@ -213,7 +214,8 @@ def persist(self, storageLevel): def unpersist(self): """ - Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + Mark the RDD as non-persistent, and remove all blocks for it from + memory and disk. """ self.is_cached = False self._jrdd.unpersist() @@ -357,48 +359,87 @@ def sample(self, withReplacement, fraction, seed=None): # this is ported from scala/spark/RDD.scala def takeSample(self, withReplacement, num, seed=None): """ - Return a fixed-size sampled subset of this RDD (currently requires numpy). + Return a fixed-size sampled subset of this RDD (currently requires + numpy). 
- >>> sc.parallelize(range(0, 10)).takeSample(True, 10, 1) #doctest: +SKIP - [4, 2, 1, 8, 2, 7, 0, 4, 1, 4] + >>> rdd = sc.parallelize(range(0, 10)) + >>> len(rdd.takeSample(True, 20, 1)) + 20 + >>> len(rdd.takeSample(False, 5, 2)) + 5 + >>> len(rdd.takeSample(False, 15, 3)) + 10 """ + numStDev = 10.0 + + if num < 0: + raise ValueError("Sample size cannot be negative.") + elif num == 0: + return [] - fraction = 0.0 - total = 0 - multiplier = 3.0 initialCount = self.count() - maxSelected = 0 + if initialCount == 0: + return [] - if (num < 0): - raise ValueError + rand = Random(seed) - if (initialCount == 0): - return list() + if (not withReplacement) and num >= initialCount: + # shuffle current RDD and return + samples = self.collect() + rand.shuffle(samples) + return samples - if initialCount > sys.maxint - 1: - maxSelected = sys.maxint - 1 - else: - maxSelected = initialCount - - if num > initialCount and not withReplacement: - total = maxSelected - fraction = multiplier * (maxSelected + 1) / initialCount - else: - fraction = multiplier * (num + 1) / initialCount - total = num + maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint)) + if num > maxSampleSize: + raise ValueError("Sample size cannot be greater than %d." % maxSampleSize) + fraction = RDD._computeFractionForSampleSize(num, initialCount, withReplacement) samples = self.sample(withReplacement, fraction, seed).collect() # If the first sample didn't turn out large enough, keep trying to take samples; # this shouldn't happen often because we use a big multiplier for their initial size. # See: scala/spark/RDD.scala - rand = Random(seed) - while len(samples) < total: - samples = self.sample(withReplacement, fraction, rand.randint(0, sys.maxint)).collect() - - sampler = RDDSampler(withReplacement, fraction, rand.randint(0, sys.maxint)) - sampler.shuffle(samples) - return samples[0:total] + while len(samples) < num: + # TODO: add log warning for when more than one iteration was run + seed = rand.randint(0, sys.maxint) + samples = self.sample(withReplacement, fraction, seed).collect() + + rand.shuffle(samples) + + return samples[0:num] + + @staticmethod + def _computeFractionForSampleSize(sampleSizeLowerBound, total, withReplacement): + """ + Returns a sampling rate that guarantees a sample of + size >= sampleSizeLowerBound 99.99% of the time. + + How the sampling rate is determined: + Let p = num / total, where num is the sample size and total is the + total number of data points in the RDD. We're trying to compute + q > p such that + - when sampling with replacement, we're drawing each data point + with prob_i ~ Pois(q), where we want to guarantee + Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to + total), i.e. the failure rate of not having a sufficiently large + sample < 0.0001. Setting q = p + 5 * sqrt(p/total) is sufficient + to guarantee 0.9999 success rate for num > 12, but we need a + slightly larger q (9 empirically determined). + - when sampling without replacement, we're drawing each data point + with prob_i ~ Binomial(total, fraction) and our choice of q + guarantees 1-delta, or 0.9999 success rate, where success rate is + defined the same as in sampling with replacement. 
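+
+        As a rough sanity check of the bound above (the numbers here are
+        chosen purely for illustration): with total = 1000 and num = 10,
+        sampling without replacement gives fraction = 0.01 and
+        gamma = -log(0.00005) / 1000 ~= 0.0099, so the returned rate is
+        0.01 + 0.0099 + sqrt(0.0099^2 + 2 * 0.0099 * 0.01) ~= 0.037,
+        roughly 3.7x the naive rate. That oversampling margin is what lets
+        a single sample() pass return at least num elements with the
+        advertised 99.99% success rate.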
+ """ + fraction = float(sampleSizeLowerBound) / total + if withReplacement: + numStDev = 5 + if (sampleSizeLowerBound < 12): + numStDev = 9 + return fraction + numStDev * sqrt(fraction / total) + else: + delta = 0.00005 + gamma = - log(delta) / total + return min(1, fraction + gamma + sqrt(gamma * gamma + 2 * gamma * fraction)) def union(self, other): """ @@ -422,8 +463,8 @@ def union(self, other): def intersection(self, other): """ - Return the intersection of this RDD and another one. The output will not - contain any duplicate elements, even if the input RDDs did. + Return the intersection of this RDD and another one. The output will + not contain any duplicate elements, even if the input RDDs did. Note that this method performs a shuffle internally. @@ -665,8 +706,8 @@ def aggregate(self, zeroValue, seqOp, combOp): modify C{t2}. The first function (seqOp) can return a different result type, U, than - the type of this RDD. Thus, we need one operation for merging a T into an U - and one operation for merging two U + the type of this RDD. Thus, we need one operation for merging a T into + an U and one operation for merging two U >>> seqOp = (lambda x, y: (x[0] + y, x[1] + 1)) >>> combOp = (lambda x, y: (x[0] + y[0], x[1] + y[1])) @@ -695,7 +736,7 @@ def max(self): def min(self): """ - Find the maximum item in this RDD. + Find the minimum item in this RDD. >>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).min() 1.0 @@ -759,8 +800,9 @@ def stdev(self): def sampleStdev(self): """ - Compute the sample standard deviation of this RDD's elements (which corrects for bias in - estimating the standard deviation by dividing by N-1 instead of N). + Compute the sample standard deviation of this RDD's elements (which + corrects for bias in estimating the standard deviation by dividing by + N-1 instead of N). >>> sc.parallelize([1, 2, 3]).sampleStdev() 1.0 @@ -769,8 +811,8 @@ def sampleStdev(self): def sampleVariance(self): """ - Compute the sample variance of this RDD's elements (which corrects for bias in - estimating the variance by dividing by N-1 instead of N). + Compute the sample variance of this RDD's elements (which corrects + for bias in estimating the variance by dividing by N-1 instead of N). >>> sc.parallelize([1, 2, 3]).sampleVariance() 1.0 @@ -822,8 +864,8 @@ def merge(a, b): def takeOrdered(self, num, key=None): """ - Get the N elements from a RDD ordered in ascending order or as specified - by the optional key function. + Get the N elements from a RDD ordered in ascending order or as + specified by the optional key function. >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6) [1, 2, 3, 4, 5, 6] @@ -912,8 +954,9 @@ def first(self): def saveAsPickleFile(self, path, batchSize=10): """ - Save this RDD as a SequenceFile of serialized objects. The serializer used is - L{pyspark.serializers.PickleSerializer}, default batch size is 10. + Save this RDD as a SequenceFile of serialized objects. The serializer + used is L{pyspark.serializers.PickleSerializer}, default batch size + is 10. >>> tmpFile = NamedTemporaryFile(delete=True) >>> tmpFile.close() @@ -1178,19 +1221,37 @@ def _mergeCombiners(iterator): combiners[k] = mergeCombiners(combiners[k], v) return combiners.iteritems() return shuffled.mapPartitions(_mergeCombiners) + + def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): + """ + Aggregate the values of each key, using given combine functions and a neutral "zero value". 
+ This function can return a different result type, U, than the type of the values in this RDD, + V. Thus, we need one operation for merging a V into a U and one operation for merging two U's, + The former operation is used for merging values within a partition, and the latter is used + for merging values between partitions. To avoid memory allocation, both of these functions are + allowed to modify and return their first argument instead of creating a new U. + """ + def createZero(): + return copy.deepcopy(zeroValue) + + return self.combineByKey(lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions) def foldByKey(self, zeroValue, func, numPartitions=None): """ - Merge the values for each key using an associative function "func" and a neutral "zeroValue" - which may be added to the result an arbitrary number of times, and must not change - the result (e.g., 0 for addition, or 1 for multiplication.). + Merge the values for each key using an associative function "func" + and a neutral "zeroValue" which may be added to the result an + arbitrary number of times, and must not change the result + (e.g., 0 for addition, or 1 for multiplication.). >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> from operator import add >>> rdd.foldByKey(0, add).collect() [('a', 2), ('b', 1)] """ - return self.combineByKey(lambda v: func(zeroValue, v), func, func, numPartitions) + def createZero(): + return copy.deepcopy(zeroValue) + + return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions) # TODO: support variant with custom partitioner @@ -1200,8 +1261,8 @@ def groupByKey(self, numPartitions=None): Hash-partitions the resulting RDD with into numPartitions partitions. Note: If you are grouping in order to perform an aggregation (such as a - sum or average) over each key, using reduceByKey will provide much better - performance. + sum or average) over each key, using reduceByKey will provide much + better performance. >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect())) @@ -1261,8 +1322,8 @@ def groupWith(self, other): def cogroup(self, other, numPartitions=None): """ For each key k in C{self} or C{other}, return a resulting RDD that - contains a tuple with the list of values for that key in C{self} as well - as C{other}. + contains a tuple with the list of values for that key in C{self} as + well as C{other}. >>> x = sc.parallelize([("a", 1), ("b", 4)]) >>> y = sc.parallelize([("a", 2)]) @@ -1273,8 +1334,8 @@ def cogroup(self, other, numPartitions=None): def subtractByKey(self, other, numPartitions=None): """ - Return each (key, value) pair in C{self} that has no pair with matching key - in C{other}. + Return each (key, value) pair in C{self} that has no pair with matching + key in C{other}. >>> x = sc.parallelize([("a", 1), ("b", 4), ("b", 5), ("a", 2)]) >>> y = sc.parallelize([("a", 3), ("c", None)]) @@ -1312,10 +1373,10 @@ def repartition(self, numPartitions): """ Return a new RDD that has exactly numPartitions partitions. - Can increase or decrease the level of parallelism in this RDD. Internally, this uses - a shuffle to redistribute data. - If you are decreasing the number of partitions in this RDD, consider using `coalesce`, - which can avoid performing a shuffle. + Can increase or decrease the level of parallelism in this RDD. + Internally, this uses a shuffle to redistribute data. 
+ If you are decreasing the number of partitions in this RDD, consider + using `coalesce`, which can avoid performing a shuffle. >>> rdd = sc.parallelize([1,2,3,4,5,6,7], 4) >>> sorted(rdd.glom().collect()) [[1], [2, 3], [4, 5], [6, 7]] @@ -1340,9 +1401,10 @@ def coalesce(self, numPartitions, shuffle=False): def zip(self, other): """ - Zips this RDD with another one, returning key-value pairs with the first element in each RDD - second element in each RDD, etc. Assumes that the two RDDs have the same number of - partitions and the same number of elements in each partition (e.g. one was made through + Zips this RDD with another one, returning key-value pairs with the + first element in each RDD second element in each RDD, etc. Assumes + that the two RDDs have the same number of partitions and the same + number of elements in each partition (e.g. one was made through a map on the other). >>> x = sc.parallelize(range(0,5)) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index b4e9618cc25b5..960d0a82448aa 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -117,7 +117,7 @@ def parquetFile(self, path): >>> srdd = sqlCtx.inferSchema(rdd) >>> srdd.saveAsParquetFile(parquetFile) >>> srdd2 = sqlCtx.parquetFile(parquetFile) - >>> srdd.collect() == srdd2.collect() + >>> sorted(srdd.collect()) == sorted(srdd2.collect()) True """ jschema_rdd = self._ssql_ctx.parquetFile(path) @@ -141,7 +141,7 @@ def table(self, tableName): >>> srdd = sqlCtx.inferSchema(rdd) >>> sqlCtx.registerRDDAsTable(srdd, "table1") >>> srdd2 = sqlCtx.table("table1") - >>> srdd.collect() == srdd2.collect() + >>> sorted(srdd.collect()) == sorted(srdd2.collect()) True """ return SchemaRDD(self._ssql_ctx.table(tableName), self) @@ -293,7 +293,7 @@ def saveAsParquetFile(self, path): >>> srdd = sqlCtx.inferSchema(rdd) >>> srdd.saveAsParquetFile(parquetFile) >>> srdd2 = sqlCtx.parquetFile(parquetFile) - >>> srdd2.collect() == srdd.collect() + >>> sorted(srdd2.collect()) == sorted(srdd.collect()) True """ self._jschema_rdd.saveAsParquetFile(path) @@ -307,7 +307,7 @@ def registerAsTable(self, name): >>> srdd = sqlCtx.inferSchema(rdd) >>> srdd.registerAsTable("test") >>> srdd2 = sqlCtx.sql("select * from test") - >>> srdd.collect() == srdd2.collect() + >>> sorted(srdd.collect()) == sorted(srdd2.collect()) True """ self._jschema_rdd.registerAsTable(name) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 184ee810b861b..c15bb457759ed 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -188,6 +188,21 @@ def test_deleting_input_files(self): os.unlink(tempFile.name) self.assertRaises(Exception, lambda: filtered_data.count()) + def testAggregateByKey(self): + data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2) + def seqOp(x, y): + x.add(y) + return x + + def combOp(x, y): + x |= y + return x + + sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect()) + self.assertEqual(3, len(sets)) + self.assertEqual(set([1]), sets[1]) + self.assertEqual(set([2]), sets[3]) + self.assertEqual(set([1, 3]), sets[5]) class TestIO(PySparkTestCase): diff --git a/python/run-tests b/python/run-tests index 3b4501178c89f..9282aa47e8375 100755 --- a/python/run-tests +++ b/python/run-tests @@ -44,7 +44,6 @@ function run_test() { echo -en "\033[0m" # No color exit -1 fi - } echo "Running PySpark tests. Output is in python/unit-tests.log." 
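To make the new aggregateByKey semantics concrete, here is a minimal PySpark sketch of the behaviour that testAggregateByKey above exercises. It assumes a live SparkContext bound to sc, the function names are illustrative only, and the point of the deep-copied zero value added in rdd.py is that a mutable zero (a set here) is never shared between keys.

    def seq_op(acc, value):
        # Merge one value into a per-key accumulator; runs within a partition.
        acc.add(value)
        return acc

    def comb_op(left, right):
        # Merge two per-key accumulators; runs when partitions are combined.
        left |= right
        return left

    data = sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)
    by_key = dict(data.aggregateByKey(set(), seq_op, comb_op).collect())
    # by_key == {1: set([1]), 3: set([2]), 5: set([1, 3])}, matching the test,
    # and the empty set passed as the zero is left untouched because
    # aggregateByKey folds each key's first value into a fresh copy of it.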
@@ -55,9 +54,13 @@ run_test "pyspark/conf.py" if [ -n "$_RUN_SQL_TESTS" ]; then run_test "pyspark/sql.py" fi +# These tests are included in the module-level docs, and so must +# be handled on a higher level rather than within the python file. +export PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" run_test "pyspark/accumulators.py" run_test "pyspark/serializers.py" +unset PYSPARK_DOC_TEST run_test "pyspark/tests.py" run_test "pyspark/mllib/_common.py" run_test "pyspark/mllib/classification.py" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 36758f3114e59..46fcfbb9e26ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -111,6 +111,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val APPROXIMATE = Keyword("APPROXIMATE") protected val AVG = Keyword("AVG") protected val BY = Keyword("BY") + protected val CACHE = Keyword("CACHE") protected val CAST = Keyword("CAST") protected val COUNT = Keyword("COUNT") protected val DESC = Keyword("DESC") @@ -149,7 +150,9 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val SEMI = Keyword("SEMI") protected val STRING = Keyword("STRING") protected val SUM = Keyword("SUM") + protected val TABLE = Keyword("TABLE") protected val TRUE = Keyword("TRUE") + protected val UNCACHE = Keyword("UNCACHE") protected val UNION = Keyword("UNION") protected val WHERE = Keyword("WHERE") @@ -189,7 +192,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } | UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } ) - | insert + | insert | cache ) protected lazy val select: Parser[LogicalPlan] = @@ -220,6 +223,11 @@ class SqlParser extends StandardTokenParsers with PackratParsers { InsertIntoTable(r, Map[String, Option[String]](), s, overwrite) } + protected lazy val cache: Parser[LogicalPlan] = + (CACHE ^^^ true | UNCACHE ^^^ false) ~ TABLE ~ ident ^^ { + case doCache ~ _ ~ tableName => CacheCommand(tableName, doCache) + } + protected lazy val projections: Parser[Seq[Expression]] = repsep(projection, ",") protected lazy val projection: Parser[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 3cf163f9a9a75..d177339d40ae5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -175,6 +175,8 @@ package object dsl { def where(condition: Expression) = Filter(condition, logicalPlan) + def limit(limitExpr: Expression) = Limit(limitExpr, logicalPlan) + def join( otherPlan: LogicalPlan, joinType: JoinType = Inner, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 420303408451f..c074b7bb01e57 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -76,7 +76,8 @@ trait CaseConversionExpression { type 
EvaluatedType = Any def convert(v: String): String - + + override def foldable: Boolean = child.foldable def nullable: Boolean = child.nullable def dataType: DataType = StringType @@ -142,6 +143,8 @@ case class RLike(left: Expression, right: Expression) case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { override def convert(v: String): String = v.toUpperCase() + + override def toString() = s"Upper($child)" } /** @@ -150,4 +153,6 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { override def convert(v: String): String = v.toLowerCase() + + override def toString() = s"Lower($child)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ccb8245cc2e7d..25a347bec0e4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -29,12 +29,15 @@ import org.apache.spark.sql.catalyst.types._ object Optimizer extends RuleExecutor[LogicalPlan] { val batches = + Batch("Combine Limits", FixedPoint(100), + CombineLimits) :: Batch("ConstantFolding", FixedPoint(100), NullPropagation, ConstantFolding, BooleanSimplification, SimplifyFilters, - SimplifyCasts) :: + SimplifyCasts, + SimplifyCaseConversionExpressions) :: Batch("Filter Pushdown", FixedPoint(100), CombineFilters, PushPredicateThroughProject, @@ -104,8 +107,8 @@ object ColumnPruning extends Rule[LogicalPlan] { object NullPropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case e @ Count(Literal(null, _)) => Literal(0, e.dataType) - case e @ Sum(Literal(c, _)) if c == 0 => Literal(0, e.dataType) + case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType) + case e @ Sum(Literal(c, _)) if c == 0 => Cast(Literal(0L), e.dataType) case e @ Average(Literal(c, _)) if c == 0 => Literal(0.0, e.dataType) case e @ IsNull(c) if c.nullable == false => Literal(false, BooleanType) case e @ IsNotNull(c) if c.nullable == false => Literal(true, BooleanType) @@ -130,18 +133,6 @@ object NullPropagation extends Rule[LogicalPlan] { case Literal(candidate, _) if candidate == v => true case _ => false })) => Literal(true, BooleanType) - case e: UnaryMinus => e.child match { - case Literal(null, _) => Literal(null, e.dataType) - case _ => e - } - case e: Cast => e.child match { - case Literal(null, _) => Literal(null, e.dataType) - case _ => e - } - case e: Not => e.child match { - case Literal(null, _) => Literal(null, e.dataType) - case _ => e - } // Put exceptional cases above if any case e: BinaryArithmetic => e.children match { case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) @@ -362,3 +353,29 @@ object SimplifyCasts extends Rule[LogicalPlan] { case Cast(e, dataType) if e.dataType == dataType => e } } + +/** + * Combines two adjacent [[catalyst.plans.logical.Limit Limit]] operators into one, merging the + * expressions into one single expression. 
+ */ +object CombineLimits extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case ll @ Limit(le, nl @ Limit(ne, grandChild)) => + Limit(If(LessThan(ne, le), ne, le), grandChild) + } +} + +/** + * Removes the inner [[catalyst.expressions.CaseConversionExpression]] that are unnecessary because + * the inner conversion is overwritten by the outer one. + */ +object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsUp { + case Upper(Upper(child)) => Upper(child) + case Upper(Lower(child)) => Upper(child) + case Lower(Upper(child)) => Lower(child) + case Lower(Lower(child)) => Lower(child) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 7eeb98aea6368..0933a31c362d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.types.{StringType, StructType} +import org.apache.spark.sql.catalyst.types.StructType import org.apache.spark.sql.catalyst.trees abstract class LogicalPlan extends QueryPlan[LogicalPlan] { @@ -96,39 +96,6 @@ abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] { def references = Set.empty } -/** - * A logical node that represents a non-query command to be executed by the system. For example, - * commands can be used by parsers to represent DDL operations. - */ -abstract class Command extends LeafNode { - self: Product => - def output: Seq[Attribute] = Seq.empty // TODO: SPARK-2081 should fix this -} - -/** - * Returned for commands supported by a given parser, but not catalyst. In general these are DDL - * commands that are passed directly to another system. - */ -case class NativeCommand(cmd: String) extends Command - -/** - * Commands of the form "SET (key) (= value)". - */ -case class SetCommand(key: Option[String], value: Option[String]) extends Command { - override def output = Seq( - AttributeReference("key", StringType, nullable = false)(), - AttributeReference("value", StringType, nullable = false)() - ) -} - -/** - * Returned by a parser when the users only wants to see what query plan would be executed, without - * actually performing the execution. - */ -case class ExplainCommand(plan: LogicalPlan) extends Command { - override def output = Seq(AttributeReference("plan", StringType, nullable = false)()) -} - /** * A logical plan node with single child. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index d3347b622f3d8..b777cf4249196 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -135,9 +135,9 @@ case class Aggregate( def references = (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet } -case class Limit(limit: Expression, child: LogicalPlan) extends UnaryNode { +case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { def output = child.output - def references = limit.references + def references = limitExpr.references } case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala new file mode 100644 index 0000000000000..3299e86b85941 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference} +import org.apache.spark.sql.catalyst.types.StringType + +/** + * A logical node that represents a non-query command to be executed by the system. For example, + * commands can be used by parsers to represent DDL operations. + */ +abstract class Command extends LeafNode { + self: Product => + def output: Seq[Attribute] = Seq.empty +} + +/** + * Returned for commands supported by a given parser, but not catalyst. In general these are DDL + * commands that are passed directly to another system. + */ +case class NativeCommand(cmd: String) extends Command { + override def output = + Seq(BoundReference(0, AttributeReference("result", StringType, nullable = false)())) +} + +/** + * Commands of the form "SET (key) (= value)". + */ +case class SetCommand(key: Option[String], value: Option[String]) extends Command { + override def output = Seq( + BoundReference(0, AttributeReference("key", StringType, nullable = false)()), + BoundReference(1, AttributeReference("value", StringType, nullable = false)())) +} + +/** + * Returned by a parser when the users only wants to see what query plan would be executed, without + * actually performing the execution. 
+ */ +case class ExplainCommand(plan: LogicalPlan) extends Command { + override def output = + Seq(BoundReference(0, AttributeReference("plan", StringType, nullable = false)())) +} + +/** + * Returned for the "CACHE TABLE tableName" and "UNCACHE TABLE tableName" command. + */ +case class CacheCommand(tableName: String, doCache: Boolean) extends Command diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala new file mode 100644 index 0000000000000..714f01843c0f5 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class CombiningLimitsSuite extends OptimizerTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Combine Limit", FixedPoint(2), + CombineLimits) :: + Batch("Constant Folding", FixedPoint(3), + NullPropagation, + ConstantFolding, + BooleanSimplification) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + test("limits: combines two limits") { + val originalQuery = + testRelation + .select('a) + .limit(10) + .limit(5) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .limit(5).analyze + + comparePlans(optimized, correctAnswer) + } + + test("limits: combines three limits") { + val originalQuery = + testRelation + .select('a) + .limit(2) + .limit(7) + .limit(5) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .limit(2).analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 20dfba847790c..6efc0e211eb21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import 
org.apache.spark.sql.catalyst.types.{DoubleType, IntegerType} +import org.apache.spark.sql.catalyst.types._ // For implicit conversions import org.apache.spark.sql.catalyst.dsl.plans._ @@ -173,4 +173,63 @@ class ConstantFoldingSuite extends OptimizerTest { comparePlans(optimized, correctAnswer) } + + test("Constant folding test: expressions have null literals") { + val originalQuery = + testRelation + .select( + IsNull(Literal(null)) as 'c1, + IsNotNull(Literal(null)) as 'c2, + + GetItem(Literal(null, ArrayType(IntegerType)), 1) as 'c3, + GetItem(Literal(Seq(1), ArrayType(IntegerType)), Literal(null, IntegerType)) as 'c4, + GetField( + Literal(null, StructType(Seq(StructField("a", IntegerType, true)))), + "a") as 'c5, + + UnaryMinus(Literal(null, IntegerType)) as 'c6, + Cast(Literal(null), IntegerType) as 'c7, + Not(Literal(null, BooleanType)) as 'c8, + + Add(Literal(null, IntegerType), 1) as 'c9, + Add(1, Literal(null, IntegerType)) as 'c10, + + Equals(Literal(null, IntegerType), 1) as 'c11, + Equals(1, Literal(null, IntegerType)) as 'c12, + + Like(Literal(null, StringType), "abc") as 'c13, + Like("abc", Literal(null, StringType)) as 'c14, + + Upper(Literal(null, StringType)) as 'c15) + + val optimized = Optimize(originalQuery.analyze) + + val correctAnswer = + testRelation + .select( + Literal(true) as 'c1, + Literal(false) as 'c2, + + Literal(null, IntegerType) as 'c3, + Literal(null, IntegerType) as 'c4, + Literal(null, IntegerType) as 'c5, + + Literal(null, IntegerType) as 'c6, + Literal(null, IntegerType) as 'c7, + Literal(null, BooleanType) as 'c8, + + Literal(null, IntegerType) as 'c9, + Literal(null, IntegerType) as 'c10, + + Literal(null, BooleanType) as 'c11, + Literal(null, BooleanType) as 'c12, + + Literal(null, BooleanType) as 'c13, + Literal(null, BooleanType) as 'c14, + + Literal(null, StringType) as 'c15) + .analyze + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 02cc665f8a8c7..1f67c80e54906 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -20,14 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.FullOuter import org.apache.spark.sql.catalyst.plans.LeftOuter import org.apache.spark.sql.catalyst.plans.RightOuter import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.junit.Test class FilterPushdownSuite extends OptimizerTest { @@ -164,7 +161,7 @@ class FilterPushdownSuite extends OptimizerTest { comparePlans(optimized, correctAnswer) } - + test("joins: push down left outer join #1") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala new file mode 100644 index 
0000000000000..df1409fe7baee --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +/* Implicit conversions */ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ + +class SimplifyCaseConversionExpressionsSuite extends OptimizerTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Simplify CaseConversionExpressions", Once, + SimplifyCaseConversionExpressions) :: Nil + } + + val testRelation = LocalRelation('a.string) + + test("simplify UPPER(UPPER(str))") { + val originalQuery = + testRelation + .select(Upper(Upper('a)) as 'u) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .select(Upper('a) as 'u) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("simplify UPPER(LOWER(str))") { + val originalQuery = + testRelation + .select(Upper(Lower('a)) as 'u) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .select(Upper('a) as 'u) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("simplify LOWER(UPPER(str))") { + val originalQuery = + testRelation + .select(Lower(Upper('a)) as 'l) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = testRelation + .select(Lower('a) as 'l) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("simplify LOWER(LOWER(str))") { + val originalQuery = + testRelation + .select(Lower(Lower('a)) as 'l) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = testRelation + .select(Lower('a) as 'l) + .analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 021e0e8245a0d..378ff54531118 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -31,10 +31,10 @@ import org.apache.spark.sql.catalyst.{ScalaReflection, dsl} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.optimizer.Optimizer -import org.apache.spark.sql.catalyst.plans.logical.{SetCommand, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor -import 
org.apache.spark.sql.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.SparkStrategies @@ -147,14 +147,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - def sql(sqlText: String): SchemaRDD = { - val result = new SchemaRDD(this, parseSql(sqlText)) - // We force query optimization to happen right away instead of letting it happen lazily like - // when using the query DSL. This is so DDL commands behave as expected. This is only - // generates the RDD lineage for DML queries, but do not perform any execution. - result.queryExecution.toRdd - result - } + def sql(sqlText: String): SchemaRDD = new SchemaRDD(this, parseSql(sqlText)) /** Returns the specified table as a SchemaRDD */ def table(tableName: String): SchemaRDD = @@ -166,10 +159,9 @@ class SQLContext(@transient val sparkContext: SparkContext) val useCompression = sparkContext.conf.getBoolean("spark.sql.inMemoryColumnarStorage.compressed", false) val asInMemoryRelation = - InMemoryColumnarTableScan( - currentTable.output, executePlan(currentTable).executedPlan, useCompression) + InMemoryRelation(useCompression, executePlan(currentTable).executedPlan) - catalog.registerTable(None, tableName, SparkLogicalPlan(asInMemoryRelation)) + catalog.registerTable(None, tableName, asInMemoryRelation) } /** Removes the specified table from the in-memory cache. */ @@ -177,17 +169,26 @@ class SQLContext(@transient val sparkContext: SparkContext) EliminateAnalysisOperators(catalog.lookupRelation(None, tableName)) match { // This is kind of a hack to make sure that if this was just an RDD registered as a table, // we reregister the RDD as a table. - case SparkLogicalPlan(inMem @ InMemoryColumnarTableScan(_, e: ExistingRdd, _)) => + case inMem @ InMemoryRelation(_, _, e: ExistingRdd) => inMem.cachedColumnBuffers.unpersist() catalog.unregisterTable(None, tableName) catalog.registerTable(None, tableName, SparkLogicalPlan(e)) - case SparkLogicalPlan(inMem: InMemoryColumnarTableScan) => + case inMem: InMemoryRelation => inMem.cachedColumnBuffers.unpersist() catalog.unregisterTable(None, tableName) case plan => throw new IllegalArgumentException(s"Table $tableName is not cached: $plan") } } + /** Returns true if the table is currently cached in-memory. 
*/ + def isCached(tableName: String): Boolean = { + val relation = catalog.lookupRelation(None, tableName) + EliminateAnalysisOperators(relation) match { + case _: InMemoryRelation => true + case _ => false + } + } + protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext = self.sparkContext @@ -199,6 +200,7 @@ class SQLContext(@transient val sparkContext: SparkContext) PartialAggregation :: LeftSemiJoin :: HashJoin :: + InMemoryScans :: ParquetOperations :: BasicOperators :: CartesianProduct :: @@ -250,8 +252,7 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] val planner = new SparkPlanner @transient - protected[sql] lazy val emptyResult = - sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) + protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[Row], 1) /** * Prepares a planned SparkPlan for execution by binding references to specific ordinals, and @@ -271,22 +272,6 @@ class SQLContext(@transient val sparkContext: SparkContext) protected abstract class QueryExecution { def logical: LogicalPlan - def eagerlyProcess(plan: LogicalPlan): RDD[Row] = plan match { - case SetCommand(key, value) => - // Only this case needs to be executed eagerly. The other cases will - // be taken care of when the actual results are being extracted. - // In the case of HiveContext, sqlConf is overridden to also pass the - // pair into its HiveConf. - if (key.isDefined && value.isDefined) { - set(key.get, value.get) - } - // It doesn't matter what we return here, since this is only used - // to force the evaluation to happen eagerly. To query the results, - // one must use SchemaRDD operations to extract them. - emptyResult - case _ => executedPlan.execute() - } - lazy val analyzed = analyzer(logical) lazy val optimizedPlan = optimizer(analyzed) // TODO: Don't just pick the first one... @@ -294,12 +279,7 @@ class SQLContext(@transient val sparkContext: SparkContext) lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) /** Internal version of the RDD. Avoids copies and has no schema */ - lazy val toRdd: RDD[Row] = { - logical match { - case s: SetCommand => eagerlyProcess(s) - case _ => executedPlan.execute() - } - } + lazy val toRdd: RDD[Row] = executedPlan.execute() protected def stringOrError[A](f: => A): String = try f.toString catch { case e: Throwable => e.toString } @@ -321,7 +301,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * TODO: We only support primitive types, add support for nested types. 
*/ private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = { - val schema = rdd.first.map { case (fieldName, obj) => + val schema = rdd.first().map { case (fieldName, obj) => val dataType = obj.getClass match { case c: Class[_] if c == classOf[java.lang.String] => StringType case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 8855c4e876917..821ac850ac3f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -97,7 +97,7 @@ import java.util.{Map => JMap} @AlphaComponent class SchemaRDD( @transient val sqlContext: SQLContext, - @transient protected[spark] val logicalPlan: LogicalPlan) + @transient val baseLogicalPlan: LogicalPlan) extends RDD[Row](sqlContext.sparkContext, Nil) with SchemaRDDLike { def baseSchemaRDD = this @@ -178,14 +178,18 @@ class SchemaRDD( def orderBy(sortExprs: SortOrder*): SchemaRDD = new SchemaRDD(sqlContext, Sort(sortExprs, logicalPlan)) + @deprecated("use limit with integer argument", "1.1.0") + def limit(limitExpr: Expression): SchemaRDD = + new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan)) + /** - * Limits the results by the given expressions. + * Limits the results by the given integer. * {{{ * schemaRDD.limit(10) * }}} */ - def limit(limitExpr: Expression): SchemaRDD = - new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan)) + def limit(limitNum: Int): SchemaRDD = + new SchemaRDD(sqlContext, Limit(Literal(limitNum), logicalPlan)) /** * Performs a grouping followed by an aggregation. @@ -374,6 +378,8 @@ class SchemaRDD( override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect() + override def take(num: Int): Array[Row] = limit(num).collect() + // ======================================================================= // Base RDD functions that do NOT change schema // ======================================================================= diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index 3a895e15a4508..656be965a8fd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -20,13 +20,14 @@ package org.apache.spark.sql import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.SparkLogicalPlan /** * Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java) */ private[sql] trait SchemaRDDLike { @transient val sqlContext: SQLContext - @transient protected[spark] val logicalPlan: LogicalPlan + @transient val baseLogicalPlan: LogicalPlan private[sql] def baseSchemaRDD: SchemaRDD @@ -48,7 +49,17 @@ private[sql] trait SchemaRDDLike { */ @transient @DeveloperApi - lazy val queryExecution = sqlContext.executePlan(logicalPlan) + lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan) + + @transient protected[spark] val logicalPlan: LogicalPlan = baseLogicalPlan match { + // For various commands (like DDL) and queries with side effects, we force query optimization to + // happen right away to let these side effects take place eagerly. 
+ case _: Command | _: InsertIntoTable | _: InsertIntoCreatedTable | _: WriteToFile => + queryExecution.toRdd + SparkLogicalPlan(queryExecution.executedPlan) + case _ => + baseLogicalPlan + } override def toString = s"""${super.toString} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index 22f57b758dd02..aff6ffe9f3478 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -37,7 +37,7 @@ import org.apache.spark.storage.StorageLevel */ class JavaSchemaRDD( @transient val sqlContext: SQLContext, - @transient protected[spark] val logicalPlan: LogicalPlan) + @transient val baseLogicalPlan: LogicalPlan) extends JavaRDDLike[Row, JavaRDD[Row]] with SchemaRDDLike { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index fdf28e1bb1261..e1e4f24c6c66c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -17,18 +17,29 @@ package org.apache.spark.sql.columnar +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Attribute} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{SparkPlan, LeafNode} import org.apache.spark.sql.Row import org.apache.spark.SparkConf -private[sql] case class InMemoryColumnarTableScan( - attributes: Seq[Attribute], - child: SparkPlan, - useCompression: Boolean) - extends LeafNode { +object InMemoryRelation { + def apply(useCompression: Boolean, child: SparkPlan): InMemoryRelation = + new InMemoryRelation(child.output, useCompression, child) +} - override def output: Seq[Attribute] = attributes +private[sql] case class InMemoryRelation( + output: Seq[Attribute], + useCompression: Boolean, + child: SparkPlan) + extends LogicalPlan with MultiInstanceRelation { + + override def children = Seq.empty + override def references = Set.empty + + override def newInstance() = + new InMemoryRelation(output.map(_.newInstance), useCompression, child).asInstanceOf[this.type] lazy val cachedColumnBuffers = { val output = child.output @@ -55,14 +66,26 @@ private[sql] case class InMemoryColumnarTableScan( cached.count() cached } +} + +private[sql] case class InMemoryColumnarTableScan( + attributes: Seq[Attribute], + relation: InMemoryRelation) + extends LeafNode { + + override def output: Seq[Attribute] = attributes override def execute() = { - cachedColumnBuffers.mapPartitions { iterator => + relation.cachedColumnBuffers.mapPartitions { iterator => val columnBuffers = iterator.next() assert(!iterator.hasNext) new Iterator[Row] { - val columnAccessors = columnBuffers.map(ColumnAccessor(_)) + // Find the ordinals of the requested columns. If none are requested, use the first. 
+ val requestedColumns = + if (attributes.isEmpty) Seq(0) else attributes.map(relation.output.indexOf(_)) + + val columnAccessors = requestedColumns.map(columnBuffers(_)).map(ColumnAccessor(_)) val nextRow = new GenericMutableRow(columnAccessors.length) override def next() = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 4613df103943d..07967fe75e882 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -77,8 +77,6 @@ case class SparkLogicalPlan(alreadyPlanned: SparkPlan) SparkLogicalPlan( alreadyPlanned match { case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd) - case scan @ InMemoryColumnarTableScan(output, _, _) => - scan.copy(attributes = output.map(_.newInstance)) case _ => sys.error("Multiple instance of the same relation detected.") }).asInstanceOf[this.type] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0455748d40eec..2233216a6ec52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{SQLConf, SQLContext, execution} +import org.apache.spark.sql.{SQLContext, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.parquet._ +import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => @@ -156,7 +157,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { InsertIntoParquetTable(relation, planLater(child), overwrite=true)(sparkContext) :: Nil case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) => InsertIntoParquetTable(table, planLater(child), overwrite)(sparkContext) :: Nil - case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => { + case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => val prunePushedDownFilters = if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) { (filters: Seq[Expression]) => { @@ -185,12 +186,23 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { filters, prunePushedDownFilters, ParquetTableScan(_, relation, filters)(sparkContext)) :: Nil - } case _ => Nil } } + object InMemoryScans extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalOperation(projectList, filters, mem: InMemoryRelation) => + pruneFilterProject( + projectList, + filters, + identity[Seq[Expression]], // No filters are pushed down. + InMemoryColumnarTableScan(_, mem)) :: Nil + case _ => Nil + } + } + // Can we automate these 'pass through' operations? 
object BasicOperators extends Strategy { def numPartitions = self.numPartitions @@ -237,12 +249,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case class CommandStrategy(context: SQLContext) extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.SetCommand(key, value) => - Seq(execution.SetCommandPhysical(key, value, plan.output)(context)) + Seq(execution.SetCommand(key, value, plan.output)(context)) case logical.ExplainCommand(child) => - val qe = context.executePlan(child) - Seq(execution.ExplainCommandPhysical(qe.executedPlan, plan.output)(context)) + val executedPlan = context.executePlan(child).executedPlan + Seq(execution.ExplainCommand(executedPlan, plan.output)(context)) + case logical.CacheCommand(tableName, cache) => + Seq(execution.CacheCommand(tableName, cache)(context)) case _ => Nil } } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 9364506691f38..0377290af5926 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -22,46 +22,94 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute} +trait Command { + /** + * A concrete command should override this lazy field to wrap up any side effects caused by the + * command or any other computation that should be evaluated exactly once. The value of this field + * can be used as the contents of the corresponding RDD generated from the physical plan of this + * command. + * + * The `execute()` method of all the physical command classes should reference `sideEffectResult` + * so that the command can be executed eagerly right after the command query is created. + */ + protected[sql] lazy val sideEffectResult: Seq[Any] = Seq.empty[Any] +} + /** * :: DeveloperApi :: */ @DeveloperApi -case class SetCommandPhysical(key: Option[String], value: Option[String], output: Seq[Attribute]) - (@transient context: SQLContext) extends LeafNode { - def execute(): RDD[Row] = (key, value) match { - // Set value for key k; the action itself would - // have been performed in QueryExecution eagerly. - case (Some(k), Some(v)) => context.emptyResult +case class SetCommand( + key: Option[String], value: Option[String], output: Seq[Attribute])( + @transient context: SQLContext) + extends LeafNode with Command { + + override protected[sql] lazy val sideEffectResult: Seq[(String, String)] = (key, value) match { + // Set value for key k. + case (Some(k), Some(v)) => + context.set(k, v) + Array(k -> v) + // Query the value bound to key k. - case (Some(k), None) => - val resultString = context.getOption(k) match { - case Some(v) => s"$k=$v" - case None => s"$k is undefined" - } - context.sparkContext.parallelize(Seq(new GenericRow(Array[Any](resultString))), 1) + case (Some(k), _) => + Array(k -> context.getOption(k).getOrElse("")) + // Query all key-value pairs that are set in the SQLConf of the context. case (None, None) => - val pairs = context.getAll - val rows = pairs.map { case (k, v) => - new GenericRow(Array[Any](s"$k=$v")) - }.toSeq - // Assume config parameters can fit into one split (machine) ;) - context.sparkContext.parallelize(rows, 1) - // The only other case is invalid semantics and is impossible. 
- case _ => context.emptyResult + context.getAll + + case _ => + throw new IllegalArgumentException() } + + def execute(): RDD[Row] = { + val rows = sideEffectResult.map { case (k, v) => new GenericRow(Array[Any](k, v)) } + context.sparkContext.parallelize(rows, 1) + } + + override def otherCopyArgs = context :: Nil } /** * :: DeveloperApi :: */ @DeveloperApi -case class ExplainCommandPhysical(child: SparkPlan, output: Seq[Attribute]) - (@transient context: SQLContext) extends UnaryNode { +case class ExplainCommand( + child: SparkPlan, output: Seq[Attribute])( + @transient context: SQLContext) + extends UnaryNode with Command { + + // Actually "EXPLAIN" command doesn't cause any side effect. + override protected[sql] lazy val sideEffectResult: Seq[String] = this.toString.split("\n") + def execute(): RDD[Row] = { - val planString = new GenericRow(Array[Any](child.toString)) - context.sparkContext.parallelize(Seq(planString)) + val explanation = sideEffectResult.mkString("\n") + context.sparkContext.parallelize(Seq(new GenericRow(Array[Any](explanation))), 1) } override def otherCopyArgs = context :: Nil } + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class CacheCommand(tableName: String, doCache: Boolean)(@transient context: SQLContext) + extends LeafNode with Command { + + override protected[sql] lazy val sideEffectResult = { + if (doCache) { + context.cacheTable(tableName) + } else { + context.uncacheTable(tableName) + } + Seq.empty[Any] + } + + override def execute(): RDD[Row] = { + sideEffectResult + context.emptyResult + } + + override def output: Seq[Attribute] = Seq.empty +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 88ff3d49a79b3..8d7a5ba59f96a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -169,7 +169,7 @@ case class LeftSemiJoinHash( def execute() = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashTable = new java.util.HashSet[Row]() + val hashSet = new java.util.HashSet[Row]() var currentRow: Row = null // Create a Hash set of buildKeys @@ -177,43 +177,17 @@ case class LeftSemiJoinHash( currentRow = buildIter.next() val rowKey = buildSideKeyGenerator(currentRow) if(!rowKey.anyNull) { - val keyExists = hashTable.contains(rowKey) + val keyExists = hashSet.contains(rowKey) if (!keyExists) { - hashTable.add(rowKey) + hashSet.add(rowKey) } } } - new Iterator[Row] { - private[this] var currentStreamedRow: Row = _ - private[this] var currentHashMatched: Boolean = false - - private[this] val joinKeys = streamSideKeyGenerator() - - override final def hasNext: Boolean = - streamIter.hasNext && fetchNext() - - override final def next() = { - currentStreamedRow - } - - /** - * Searches the streamed iterator for the next row that has at least one match in hashtable. - * - * @return true if the search is successful, and false the streamed iterator runs out of - * tuples. 
- */ - private final def fetchNext(): Boolean = { - currentHashMatched = false - while (!currentHashMatched && streamIter.hasNext) { - currentStreamedRow = streamIter.next() - if (!joinKeys(currentStreamedRow).anyNull) { - currentHashMatched = hashTable.contains(joinKeys.currentValue) - } - } - currentHashMatched - } - } + val joinKeys = streamSideKeyGenerator() + streamIter.filter(current => { + !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue) + }) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 0331f90272a99..c794da4da4069 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.columnar.InMemoryColumnarTableScan -import org.apache.spark.sql.execution.SparkLogicalPlan +import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} import org.apache.spark.sql.test.TestSQLContext class CachedTableSuite extends QueryTest { @@ -34,7 +33,7 @@ class CachedTableSuite extends QueryTest { ) TestSQLContext.table("testData").queryExecution.analyzed match { - case SparkLogicalPlan(_ : InMemoryColumnarTableScan) => // Found evidence of caching + case _ : InMemoryRelation => // Found evidence of caching case noCache => fail(s"No cache node found in plan $noCache") } @@ -46,7 +45,7 @@ class CachedTableSuite extends QueryTest { ) TestSQLContext.table("testData").queryExecution.analyzed match { - case cachePlan @ SparkLogicalPlan(_ : InMemoryColumnarTableScan) => + case cachePlan: InMemoryRelation => fail(s"Table still cached after uncache: $cachePlan") case noCache => // Table uncached successfully } @@ -61,13 +60,33 @@ class CachedTableSuite extends QueryTest { test("SELECT Star Cached Table") { TestSQLContext.sql("SELECT * FROM testData").registerAsTable("selectStar") TestSQLContext.cacheTable("selectStar") - TestSQLContext.sql("SELECT * FROM selectStar") + TestSQLContext.sql("SELECT * FROM selectStar WHERE key = 1").collect() TestSQLContext.uncacheTable("selectStar") } test("Self-join cached") { + val unCachedAnswer = + TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() TestSQLContext.cacheTable("testData") - TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key") + checkAnswer( + TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), + unCachedAnswer.toSeq) TestSQLContext.uncacheTable("testData") } + + test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { + TestSQLContext.sql("CACHE TABLE testData") + TestSQLContext.table("testData").queryExecution.executedPlan match { + case _: InMemoryColumnarTableScan => // Found evidence of caching + case _ => fail(s"Table 'testData' should be cached") + } + assert(TestSQLContext.isCached("testData"), "Table 'testData' should be cached") + + TestSQLContext.sql("UNCACHE TABLE testData") + TestSQLContext.table("testData").queryExecution.executedPlan match { + case _: InMemoryColumnarTableScan => fail(s"Table 'testData' should not be cached") + case _ => // Found evidence of uncaching + } + assert(!TestSQLContext.isCached("testData"), "Table 'testData' should not be cached") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 5eb73a4eff980..08293f7f0ca30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -28,6 +28,7 @@ class SQLConfSuite extends QueryTest { val testVal = "test.val.0" test("programmatic ways of basic setting and getting") { + clear() assert(getOption(testKey).isEmpty) assert(getAll.toSet === Set()) @@ -48,6 +49,7 @@ class SQLConfSuite extends QueryTest { } test("parse SQL set commands") { + clear() sql(s"set $testKey=$testVal") assert(get(testKey, testVal + "_") == testVal) assert(TestSQLContext.get(testKey, testVal + "_") == testVal) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index de02bbc7e4700..e9360b0fc7910 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -141,7 +141,7 @@ class SQLQuerySuite extends QueryTest { sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"), Seq((2147483645.0,1),(2.0,2))) } - + test("count") { checkAnswer( sql("SELECT COUNT(*) FROM testData2"), @@ -332,7 +332,7 @@ class SQLQuerySuite extends QueryTest { (3, "C"), (4, "D"))) } - + test("system function upper()") { checkAnswer( sql("SELECT n,UPPER(l) FROM lowerCaseData"), @@ -349,7 +349,7 @@ class SQLQuerySuite extends QueryTest { (2, "ABC"), (3, null))) } - + test("system function lower()") { checkAnswer( sql("SELECT N,LOWER(L) FROM upperCaseData"), @@ -382,26 +382,27 @@ class SQLQuerySuite extends QueryTest { sql(s"SET $testKey=$testVal") checkAnswer( sql("SET"), - Seq(Seq(s"$testKey=$testVal")) + Seq(Seq(testKey, testVal)) ) sql(s"SET ${testKey + testKey}=${testVal + testVal}") checkAnswer( sql("set"), Seq( - Seq(s"$testKey=$testVal"), - Seq(s"${testKey + testKey}=${testVal + testVal}")) + Seq(testKey, testVal), + Seq(testKey + testKey, testVal + testVal)) ) // "set key" checkAnswer( sql(s"SET $testKey"), - Seq(Seq(s"$testKey=$testVal")) + Seq(Seq(testKey, testVal)) ) checkAnswer( sql(s"SET $nonexistentKey"), - Seq(Seq(s"$nonexistentKey is undefined")) + Seq(Seq(nonexistentKey, "")) ) + clear() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 31c5dfba92954..86727b93f3659 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -28,14 +28,14 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("simple columnar query") { val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan - val scan = SparkLogicalPlan(InMemoryColumnarTableScan(plan.output, plan, true)) + val scan = InMemoryRelation(useCompression = true, plan) checkAnswer(scan, testData.collect().toSeq) } test("projection") { val plan = TestSQLContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan - val scan = SparkLogicalPlan(InMemoryColumnarTableScan(plan.output, plan, true)) + val scan = InMemoryRelation(useCompression = true, plan) checkAnswer(scan, testData.collect().map { case Row(key: Int, value: String) => value -> key @@ -44,7 +44,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("SPARK-1436 regression: in-memory columns must be able to be 
accessed multiple times") { val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan - val scan = SparkLogicalPlan(InMemoryColumnarTableScan(plan.output, plan, true)) + val scan = InMemoryRelation(useCompression = true, plan) checkAnswer(scan, testData.collect().toSeq) checkAnswer(scan, testData.collect().toSeq) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 64978215542ec..96e0ec5136331 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -15,8 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql -package hive +package org.apache.spark.sql.hive import java.io.{BufferedReader, File, InputStreamReader, PrintStream} import java.util.{ArrayList => JArrayList} @@ -32,12 +31,13 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Analyzer, OverrideCatalog} -import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.QueryExecutionException +import org.apache.spark.sql.execution.{Command => PhysicalCommand} /** * Starts up an instance of hive where metadata is stored locally. An in-process metadata data is @@ -71,14 +71,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /** * Executes a query expressed in HiveQL using Spark, returning the result as a SchemaRDD. */ - def hiveql(hqlQuery: String): SchemaRDD = { - val result = new SchemaRDD(this, HiveQl.parseSql(hqlQuery)) - // We force query optimization to happen right away instead of letting it happen lazily like - // when using the query DSL. This is so DDL commands behave as expected. This is only - // generates the RDD lineage for DML queries, but does not perform any execution. - result.queryExecution.toRdd - result - } + def hiveql(hqlQuery: String): SchemaRDD = new SchemaRDD(this, HiveQl.parseSql(hqlQuery)) /** An alias for `hiveql`. */ def hql(hqlQuery: String): SchemaRDD = hiveql(hqlQuery) @@ -164,7 +157,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /** * Runs the specified SQL query using Hive. */ - protected def runSqlHive(sql: String): Seq[String] = { + protected[sql] def runSqlHive(sql: String): Seq[String] = { val maxResults = 100000 val results = runHive(sql, 100000) // It is very confusing when you only get back some of the results... 
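hiveql can drop its eager call to queryExecution.toRdd because eager execution of DDL and commands is now handled by the Command trait introduced earlier in this patch (see commands.scala and the SchemaRDDLike change at the top of this section): each physical command wraps its side effect in a lazy sideEffectResult, so it runs exactly once no matter how many consumers touch it. A minimal standalone sketch of that pattern follows; the names here are illustrative only, not Spark API.

trait SideEffectingCommand {
  // Evaluated at most once, the first time any consumer touches it.
  protected lazy val sideEffectResult: Seq[Any] = Seq.empty
}

class CreateDummyTable(name: String) extends SideEffectingCommand {
  override protected lazy val sideEffectResult: Seq[Any] = {
    println(s"creating table $name")  // stand-in for a real DDL side effect
    Seq.empty
  }
  def run(): Seq[Any] = sideEffectResult
}

object ExactlyOnceDemo extends App {
  val cmd = new CreateDummyTable("t")
  cmd.run()  // side effect happens here
  cmd.run()  // lazy val already evaluated; no second side effect
}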
@@ -228,8 +221,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override val strategies: Seq[Strategy] = Seq( CommandStrategy(self), + HiveCommandStrategy(self), TakeOrdered, ParquetOperations, + InMemoryScans, HiveTableScans, DataSinks, Scripts, @@ -251,25 +246,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override lazy val optimizedPlan = optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed))) - override lazy val toRdd: RDD[Row] = { - def processCmd(cmd: String): RDD[Row] = { - val output = runSqlHive(cmd) - if (output.size == 0) { - emptyResult - } else { - val asRows = output.map(r => new GenericRow(r.split("\t").asInstanceOf[Array[Any]])) - sparkContext.parallelize(asRows, 1) - } - } - - logical match { - case s: SetCommand => eagerlyProcess(s) - case _ => analyzed match { - case NativeCommand(cmd) => processCmd(cmd) - case _ => executedPlan.execute().map(_.copy()) - } - } - } + override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy()) protected val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, @@ -297,7 +274,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { struct.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ))=> + case (seq: Seq[_], ArrayType(typ)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") case (map: Map[_,_], MapType(kType, vType)) => map.map { @@ -313,10 +290,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * Returns the result as a hive compatible sequence of strings. For native commands, the * execution is simply passed back to Hive. */ - def stringResult(): Seq[String] = analyzed match { - case NativeCommand(cmd) => runSqlHive(cmd) - case ExplainCommand(plan) => executePlan(plan).toString.split("\n") - case query => + def stringResult(): Seq[String] = executedPlan match { + case command: PhysicalCommand => + command.sideEffectResult.map(_.toString) + + case other => val result: Seq[Seq[Any]] = toRdd.collect().toSeq // We need the types so we can output struct field names val types = analyzed.output.map(_.dataType) @@ -327,8 +305,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override def simpleString: String = logical match { - case _: NativeCommand => "" - case _: SetCommand => "" + case _: NativeCommand => "" + case _: SetCommand => "" case _ => executedPlan.toString } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index a91b520765349..68284344afd55 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.execution.SparkLogicalPlan import org.apache.spark.sql.hive.execution.{HiveTableScan, InsertIntoHiveTable} -import org.apache.spark.sql.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -130,8 +130,9 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) => 
castChildOutput(p, table, child) - case p @ logical.InsertIntoTable(SparkLogicalPlan(InMemoryColumnarTableScan( - _, HiveTableScan(_, table, _), _)), _, child, _) => + case p @ logical.InsertIntoTable( + InMemoryRelation(_, _, + HiveTableScan(_, table, _)), _, child, _) => castChildOutput(p, table, child) } @@ -236,6 +237,7 @@ object HiveMetastoreTypes extends RegexParsers { case BinaryType => "binary" case BooleanType => "boolean" case DecimalType => "decimal" + case TimestampType => "timestamp" } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 4e74d9bc909fa..b745d8ffd8f17 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -218,15 +218,19 @@ private[hive] object HiveQl { case Array(key, value) => // "set key=value" SetCommand(Some(key), Some(value)) } - } else if (sql.toLowerCase.startsWith("add jar")) { + } else if (sql.trim.toLowerCase.startsWith("cache table")) { + CacheCommand(sql.drop(12).trim, true) + } else if (sql.trim.toLowerCase.startsWith("uncache table")) { + CacheCommand(sql.drop(14).trim, false) + } else if (sql.trim.toLowerCase.startsWith("add jar")) { AddJar(sql.drop(8)) - } else if (sql.toLowerCase.startsWith("add file")) { + } else if (sql.trim.toLowerCase.startsWith("add file")) { AddFile(sql.drop(9)) - } else if (sql.startsWith("dfs")) { + } else if (sql.trim.startsWith("dfs")) { DfsCommand(sql) - } else if (sql.startsWith("source")) { + } else if (sql.trim.startsWith("source")) { SourceCommand(sql.split(" ").toSeq match { case Seq("source", filePath) => filePath }) - } else if (sql.startsWith("!")) { + } else if (sql.trim.startsWith("!")) { ShellCommand(sql.drop(1)) } else { val tree = getAst(sql) @@ -839,11 +843,11 @@ private[hive] object HiveQl { case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => SumDistinct(nodeToExpr(arg)) case Token("TOK_FUNCTION", Token(MAX(), Nil) :: arg :: Nil) => Max(nodeToExpr(arg)) case Token("TOK_FUNCTION", Token(MIN(), Nil) :: arg :: Nil) => Min(nodeToExpr(arg)) - + /* System functions about string operations */ case Token("TOK_FUNCTION", Token(UPPER(), Nil) :: arg :: Nil) => Upper(nodeToExpr(arg)) case Token("TOK_FUNCTION", Token(LOWER(), Nil) :: arg :: Nil) => Lower(nodeToExpr(arg)) - + /* Casts */ case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), StringType) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 8b51957162e04..0ac0ee9071f36 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.execution._ -import org.apache.spark.sql.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.columnar.InMemoryRelation private[hive] trait HiveStrategies { // Possibly being too clever with types here... or not clever enough. 
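With CACHE TABLE and UNCACHE TABLE wired into HiveQl above, caching can be driven entirely from SQL. A short usage sketch, assuming an existing HiveContext named hiveContext (the CachedTableSuite changes later in this patch exercise the same calls against TestHive and TestSQLContext):

// `hiveContext` is assumed to be an existing HiveContext (e.g. TestHive).
hiveContext.hql("CACHE TABLE src")
assert(hiveContext.isCached("src"), "Table 'src' should be cached")

hiveContext.hql("UNCACHE TABLE src")
assert(!hiveContext.isCached("src"), "Table 'src' should not be cached")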
@@ -44,8 +44,9 @@ private[hive] trait HiveStrategies { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) => InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil - case logical.InsertIntoTable(SparkLogicalPlan(InMemoryColumnarTableScan( - _, HiveTableScan(_, table, _), _)), partition, child, overwrite) => + case logical.InsertIntoTable( + InMemoryRelation(_, _, + HiveTableScan(_, table, _)), partition, child, overwrite) => InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil case _ => Nil } @@ -75,4 +76,12 @@ private[hive] trait HiveStrategies { Nil } } + + case class HiveCommandStrategy(context: HiveContext) extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.NativeCommand(sql) => + NativeCommand(sql, plan.output)(context) :: Nil + case _ => Nil + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 041e813598d1b..9386008d02d51 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -32,7 +32,7 @@ import org.apache.hadoop.hive.serde2.avro.AvroSerDe import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, NativeCommand} +import org.apache.spark.sql.catalyst.plans.logical.{CacheCommand, LogicalPlan, NativeCommand} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.hive._ @@ -103,7 +103,7 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) { val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) } else { - new File("sql" + File.separator + "hive" + File.separator + "src" + File.separator + "test" + + new File("sql" + File.separator + "hive" + File.separator + "src" + File.separator + "test" + File.separator + "resources") } @@ -130,6 +130,7 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) { override lazy val analyzed = { val describedTables = logical match { case NativeCommand(describedTable(tbl)) => tbl :: Nil + case CacheCommand(tbl, _) => tbl :: Nil case _ => Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala index 29b4b9b006e45..a839231449161 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala @@ -32,14 +32,15 @@ import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Serializer} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred._ +import org.apache.spark import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{BooleanType, DataType} import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive._ -import org.apache.spark.{TaskContext, SparkException} import org.apache.spark.util.MutablePair +import org.apache.spark.{TaskContext, SparkException} /* Implicits */ import 
scala.collection.JavaConversions._ @@ -57,7 +58,7 @@ case class HiveTableScan( attributes: Seq[Attribute], relation: MetastoreRelation, partitionPruningPred: Option[Expression])( - @transient val sc: HiveContext) + @transient val context: HiveContext) extends LeafNode with HiveInspectors { @@ -75,7 +76,7 @@ case class HiveTableScan( } @transient - val hadoopReader = new HadoopTableReader(relation.tableDesc, sc) + val hadoopReader = new HadoopTableReader(relation.tableDesc, context) /** * The hive object inspector for this table, which can be used to extract values from the @@ -156,7 +157,7 @@ case class HiveTableScan( hiveConf.set(serdeConstants.LIST_COLUMNS, columnInternalNames) } - addColumnMetadataToConf(sc.hiveconf) + addColumnMetadataToConf(context.hiveconf) @transient def inputRdd = if (!relation.hiveQlTable.isPartitioned) { @@ -428,3 +429,26 @@ case class InsertIntoHiveTable( sc.sparkContext.makeRDD(Nil, 1) } } + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class NativeCommand( + sql: String, output: Seq[Attribute])( + @transient context: HiveContext) + extends LeafNode with Command { + + override protected[sql] lazy val sideEffectResult: Seq[String] = context.runSqlHive(sql) + + override def execute(): RDD[spark.sql.Row] = { + if (sideEffectResult.size == 0) { + context.emptyResult + } else { + val rows = sideEffectResult.map(r => new GenericRow(Array[Any](r))) + context.sparkContext.parallelize(rows, 1) + } + } + + override def otherCopyArgs = context :: Nil +} diff --git a/sql/hive/src/test/resources/golden/semijoin-0-1631b71327abf75b96116036b977b26c b/sql/hive/src/test/resources/golden/semijoin-0-1631b71327abf75b96116036b977b26c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-1-e7c99e1f46d502edbb0925d75aab5f0c b/sql/hive/src/test/resources/golden/semijoin-1-e7c99e1f46d502edbb0925d75aab5f0c new file mode 100644 index 0000000000000..2ed47ab83dd02 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-1-e7c99e1f46d502edbb0925d75aab5f0c @@ -0,0 +1,11 @@ +0 val_0 +0 val_0 +0 val_0 +2 val_2 +4 val_4 +5 val_5 +5 val_5 +5 val_5 +8 val_8 +9 val_9 +10 val_10 diff --git a/sql/hive/src/test/resources/golden/semijoin-10-ffd4fb3a903a6725ccb97d5451a3fec6 b/sql/hive/src/test/resources/golden/semijoin-10-ffd4fb3a903a6725ccb97d5451a3fec6 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-11-246a40dcafe077f02397e30d227c330 b/sql/hive/src/test/resources/golden/semijoin-11-246a40dcafe077f02397e30d227c330 new file mode 100644 index 0000000000000..a24bd8c6379e3 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-11-246a40dcafe077f02397e30d227c330 @@ -0,0 +1,8 @@ +0 val_0 +0 val_0 +0 val_0 +4 val_2 +8 val_4 +10 val_5 +10 val_5 +10 val_5 diff --git a/sql/hive/src/test/resources/golden/semijoin-12-6d93a9d332ba490835b17f261a5467df b/sql/hive/src/test/resources/golden/semijoin-12-6d93a9d332ba490835b17f261a5467df new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-13-18282d38b6efc0017089ab89b661764f b/sql/hive/src/test/resources/golden/semijoin-13-18282d38b6efc0017089ab89b661764f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-14-19cfcefb10e1972bec0ffd421cd79de7 b/sql/hive/src/test/resources/golden/semijoin-14-19cfcefb10e1972bec0ffd421cd79de7 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git 
a/sql/hive/src/test/resources/golden/semijoin-15-1de6eb3f357bd1c4d02ab4d19d43589 b/sql/hive/src/test/resources/golden/semijoin-15-1de6eb3f357bd1c4d02ab4d19d43589 new file mode 100644 index 0000000000000..03c61a908b071 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-15-1de6eb3f357bd1c4d02ab4d19d43589 @@ -0,0 +1,11 @@ +val_0 +val_0 +val_0 +val_10 +val_2 +val_4 +val_5 +val_5 +val_5 +val_8 +val_9 diff --git a/sql/hive/src/test/resources/golden/semijoin-16-d3a72a90515ac4a8d8e9ac923bcda3d b/sql/hive/src/test/resources/golden/semijoin-16-d3a72a90515ac4a8d8e9ac923bcda3d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-17-f0f8720cfd11fd71af17b310dc2d1019 b/sql/hive/src/test/resources/golden/semijoin-17-f0f8720cfd11fd71af17b310dc2d1019 new file mode 100644 index 0000000000000..2dcdfd1217ced --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-17-f0f8720cfd11fd71af17b310dc2d1019 @@ -0,0 +1,3 @@ +0 val_0 +0 val_0 +0 val_0 diff --git a/sql/hive/src/test/resources/golden/semijoin-18-f7b2ce472443982e32d954cbb5c96765 b/sql/hive/src/test/resources/golden/semijoin-18-f7b2ce472443982e32d954cbb5c96765 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-19-b1f1c7f701abe81c01e72fb98f0bd13f b/sql/hive/src/test/resources/golden/semijoin-19-b1f1c7f701abe81c01e72fb98f0bd13f new file mode 100644 index 0000000000000..a3670515e8cc2 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-19-b1f1c7f701abe81c01e72fb98f0bd13f @@ -0,0 +1,3 @@ +val_10 +val_8 +val_9 diff --git a/sql/hive/src/test/resources/golden/semijoin-2-deb9c3286ae8e851b1fdb270085b16bc b/sql/hive/src/test/resources/golden/semijoin-2-deb9c3286ae8e851b1fdb270085b16bc new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-20-b7a8ebaeb42b2eaba7d97cadc3fd96c1 b/sql/hive/src/test/resources/golden/semijoin-20-b7a8ebaeb42b2eaba7d97cadc3fd96c1 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-21-480418a0646cf7260b494b9eb4821bb6 b/sql/hive/src/test/resources/golden/semijoin-21-480418a0646cf7260b494b9eb4821bb6 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-22-b6aebd98f7636cda7b24e0bf84d7ba41 b/sql/hive/src/test/resources/golden/semijoin-22-b6aebd98f7636cda7b24e0bf84d7ba41 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-23-ed730ccdf552c07e7a82ba6b7fd3fbda b/sql/hive/src/test/resources/golden/semijoin-23-ed730ccdf552c07e7a82ba6b7fd3fbda new file mode 100644 index 0000000000000..72bc6a6a88f6e --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-23-ed730ccdf552c07e7a82ba6b7fd3fbda @@ -0,0 +1,5 @@ +4 val_2 +8 val_4 +10 val_5 +10 val_5 +10 val_5 diff --git a/sql/hive/src/test/resources/golden/semijoin-24-d16b37134de78980b2bf96029e8265c3 b/sql/hive/src/test/resources/golden/semijoin-24-d16b37134de78980b2bf96029e8265c3 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-25-be2bd011cc80b480b271a08dbf381e9b b/sql/hive/src/test/resources/golden/semijoin-25-be2bd011cc80b480b271a08dbf381e9b new file mode 100644 index 0000000000000..d89ea1757c712 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-25-be2bd011cc80b480b271a08dbf381e9b @@ -0,0 +1,19 @@ +0 +0 +0 +0 +0 +0 +2 +4 +4 +5 +5 +5 +8 +8 
+9 +10 +10 +10 +10 diff --git a/sql/hive/src/test/resources/golden/semijoin-26-f1d3bab29f1ebafa148dbe3816e1da25 b/sql/hive/src/test/resources/golden/semijoin-26-f1d3bab29f1ebafa148dbe3816e1da25 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-27-391c256254d171973f02e7f33672ce1d b/sql/hive/src/test/resources/golden/semijoin-27-391c256254d171973f02e7f33672ce1d new file mode 100644 index 0000000000000..dbbdae75a52a4 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-27-391c256254d171973f02e7f33672ce1d @@ -0,0 +1,4 @@ +0 val_0 +0 val_0 +0 val_0 +8 val_8 diff --git a/sql/hive/src/test/resources/golden/semijoin-28-b56400f6d9372f353cf7292a2182e963 b/sql/hive/src/test/resources/golden/semijoin-28-b56400f6d9372f353cf7292a2182e963 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-29-9efeef3d3c38e22a74d379978178c4f5 b/sql/hive/src/test/resources/golden/semijoin-29-9efeef3d3c38e22a74d379978178c4f5 new file mode 100644 index 0000000000000..07c61afb5124b --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-29-9efeef3d3c38e22a74d379978178c4f5 @@ -0,0 +1,14 @@ +0 val_0 0 val_0 +0 val_0 0 val_0 +0 val_0 0 val_0 +0 val_0 0 val_0 +0 val_0 0 val_0 +0 val_0 0 val_0 +0 val_0 0 val_0 +0 val_0 0 val_0 +0 val_0 0 val_0 +4 val_4 4 val_2 +8 val_8 8 val_4 +10 val_10 10 val_5 +10 val_10 10 val_5 +10 val_10 10 val_5 diff --git a/sql/hive/src/test/resources/golden/semijoin-3-b4d4317dd3a10e18502f20f5c5250389 b/sql/hive/src/test/resources/golden/semijoin-3-b4d4317dd3a10e18502f20f5c5250389 new file mode 100644 index 0000000000000..bf51e8f5d9eb5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-3-b4d4317dd3a10e18502f20f5c5250389 @@ -0,0 +1,11 @@ +0 val_0 +0 val_0 +0 val_0 +4 val_2 +8 val_4 +10 val_5 +10 val_5 +10 val_5 +16 val_8 +18 val_9 +20 val_10 diff --git a/sql/hive/src/test/resources/golden/semijoin-30-dd901d00fce5898b03a57cbc3028a70a b/sql/hive/src/test/resources/golden/semijoin-30-dd901d00fce5898b03a57cbc3028a70a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-31-e5dc4d8185e63e984aa4e3a2e08bd67 b/sql/hive/src/test/resources/golden/semijoin-31-e5dc4d8185e63e984aa4e3a2e08bd67 new file mode 100644 index 0000000000000..d6283e34d8ffc --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-31-e5dc4d8185e63e984aa4e3a2e08bd67 @@ -0,0 +1,14 @@ +0 val_0 +0 val_0 +0 val_0 +0 val_0 +0 val_0 +0 val_0 +2 val_2 +4 val_4 +5 val_5 +5 val_5 +5 val_5 +8 val_8 +9 val_9 +10 val_10 diff --git a/sql/hive/src/test/resources/golden/semijoin-32-23017c7663f2710265a7e2a4a1606d39 b/sql/hive/src/test/resources/golden/semijoin-32-23017c7663f2710265a7e2a4a1606d39 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-33-ed499f94c6e6ac847ef5187b3b43bbc5 b/sql/hive/src/test/resources/golden/semijoin-33-ed499f94c6e6ac847ef5187b3b43bbc5 new file mode 100644 index 0000000000000..080180f9d0f0e --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-33-ed499f94c6e6ac847ef5187b3b43bbc5 @@ -0,0 +1,14 @@ +0 +0 +0 +0 +0 +0 +4 +4 +8 +8 +10 +10 +10 +10 diff --git a/sql/hive/src/test/resources/golden/semijoin-34-5e1b832090ab73c141c1167d5b25a490 b/sql/hive/src/test/resources/golden/semijoin-34-5e1b832090ab73c141c1167d5b25a490 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git 
a/sql/hive/src/test/resources/golden/semijoin-35-8d5731f26232f6e26dd8012461b08d99 b/sql/hive/src/test/resources/golden/semijoin-35-8d5731f26232f6e26dd8012461b08d99 new file mode 100644 index 0000000000000..4a64d5c625790 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-35-8d5731f26232f6e26dd8012461b08d99 @@ -0,0 +1,26 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +4 +4 +8 +8 +10 +10 +10 +10 diff --git a/sql/hive/src/test/resources/golden/semijoin-36-b1159823dca8025926407f8aa921238d b/sql/hive/src/test/resources/golden/semijoin-36-b1159823dca8025926407f8aa921238d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-37-a15b9074f999ce836be5329591b968d0 b/sql/hive/src/test/resources/golden/semijoin-37-a15b9074f999ce836be5329591b968d0 new file mode 100644 index 0000000000000..1420c786fb228 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-37-a15b9074f999ce836be5329591b968d0 @@ -0,0 +1,29 @@ +NULL +NULL +NULL +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +4 +4 +8 +8 +10 +10 +10 +10 diff --git a/sql/hive/src/test/resources/golden/semijoin-38-f37547c73a48ce3ba089531b176e6ba b/sql/hive/src/test/resources/golden/semijoin-38-f37547c73a48ce3ba089531b176e6ba new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-39-c22a6f11368affcb80a9c80e653a47a8 b/sql/hive/src/test/resources/golden/semijoin-39-c22a6f11368affcb80a9c80e653a47a8 new file mode 100644 index 0000000000000..1420c786fb228 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-39-c22a6f11368affcb80a9c80e653a47a8 @@ -0,0 +1,29 @@ +NULL +NULL +NULL +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +4 +4 +8 +8 +10 +10 +10 +10 diff --git a/sql/hive/src/test/resources/golden/semijoin-4-dfdad5a2742f93e8ea888191460809c0 b/sql/hive/src/test/resources/golden/semijoin-4-dfdad5a2742f93e8ea888191460809c0 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-40-32071a51e2ba6e86b1c5e40de55aae63 b/sql/hive/src/test/resources/golden/semijoin-40-32071a51e2ba6e86b1c5e40de55aae63 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-41-cf74f73a33b1af8902b7970a9350b092 b/sql/hive/src/test/resources/golden/semijoin-41-cf74f73a33b1af8902b7970a9350b092 new file mode 100644 index 0000000000000..aef9483bb0bc9 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-41-cf74f73a33b1af8902b7970a9350b092 @@ -0,0 +1,29 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +4 +4 +8 +8 +10 +10 +10 +10 +16 +18 +20 diff --git a/sql/hive/src/test/resources/golden/semijoin-42-6b4257a74fca627785c967c99547f4c0 b/sql/hive/src/test/resources/golden/semijoin-42-6b4257a74fca627785c967c99547f4c0 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-43-e8a166ac2e94bf8d1da0fe91b0db2c81 b/sql/hive/src/test/resources/golden/semijoin-43-e8a166ac2e94bf8d1da0fe91b0db2c81 new file mode 100644 index 0000000000000..0bc413ef2e09e --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-43-e8a166ac2e94bf8d1da0fe91b0db2c81 @@ -0,0 +1,31 @@ +NULL +NULL +NULL +NULL +NULL +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +4 +4 +8 +8 +10 +10 +10 +10 diff --git a/sql/hive/src/test/resources/golden/semijoin-44-945aaa3a24359ef73acab1e99500d5ea b/sql/hive/src/test/resources/golden/semijoin-44-945aaa3a24359ef73acab1e99500d5ea new file 
mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-45-3fd94ffd4f1eb6cf83dcc064599bf12b b/sql/hive/src/test/resources/golden/semijoin-45-3fd94ffd4f1eb6cf83dcc064599bf12b new file mode 100644 index 0000000000000..3131e64446f66 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-45-3fd94ffd4f1eb6cf83dcc064599bf12b @@ -0,0 +1,42 @@ +NULL +NULL +NULL +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +2 +4 +4 +5 +5 +5 +8 +8 +9 +10 +10 +10 +10 +10 +10 +10 +10 +10 +10 +10 +10 diff --git a/sql/hive/src/test/resources/golden/semijoin-46-620e01f81f6e5254b4bbe8fab4043ec0 b/sql/hive/src/test/resources/golden/semijoin-46-620e01f81f6e5254b4bbe8fab4043ec0 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-47-f0140e4ee92508ba241f91c157b7af9c b/sql/hive/src/test/resources/golden/semijoin-47-f0140e4ee92508ba241f91c157b7af9c new file mode 100644 index 0000000000000..ff30bedb81861 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-47-f0140e4ee92508ba241f91c157b7af9c @@ -0,0 +1,35 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +4 +4 +8 +8 +10 +10 +10 +10 +10 +10 +10 +10 +10 +10 +16 +18 +20 diff --git a/sql/hive/src/test/resources/golden/semijoin-48-8a04442e84f99a584c2882d0af8c25d8 b/sql/hive/src/test/resources/golden/semijoin-48-8a04442e84f99a584c2882d0af8c25d8 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-49-df1d6705d3624be72036318a6b42f04c b/sql/hive/src/test/resources/golden/semijoin-49-df1d6705d3624be72036318a6b42f04c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-5-d3c2f84a12374b307c58a69aba4ec70d b/sql/hive/src/test/resources/golden/semijoin-5-d3c2f84a12374b307c58a69aba4ec70d new file mode 100644 index 0000000000000..60f6eacee9b14 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-5-d3c2f84a12374b307c58a69aba4ec70d @@ -0,0 +1,22 @@ +0 val_0 +0 val_0 +0 val_0 +0 val_0 +0 val_0 +0 val_0 +2 val_2 +4 val_2 +4 val_4 +5 val_5 +5 val_5 +5 val_5 +8 val_4 +8 val_8 +9 val_9 +10 val_10 +10 val_5 +10 val_5 +10 val_5 +16 val_8 +18 val_9 +20 val_10 diff --git a/sql/hive/src/test/resources/golden/semijoin-6-90bb51b1330230d10a14fb7517457aa0 b/sql/hive/src/test/resources/golden/semijoin-6-90bb51b1330230d10a14fb7517457aa0 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-7-333d72e8bce6d11a35fc7a30418f225b b/sql/hive/src/test/resources/golden/semijoin-7-333d72e8bce6d11a35fc7a30418f225b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-8-d46607be851a6f4e27e98cbbefdee994 b/sql/hive/src/test/resources/golden/semijoin-8-d46607be851a6f4e27e98cbbefdee994 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-9-f7adaf0f77ce6ff8c3a4807f428d8de2 b/sql/hive/src/test/resources/golden/semijoin-9-f7adaf0f77ce6ff8c3a4807f428d8de2 new file mode 100644 index 0000000000000..5baaac9bebf6d --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-9-f7adaf0f77ce6ff8c3a4807f428d8de2 @@ -0,0 +1,6 @@ +0 val_0 +0 val_0 +0 val_0 +4 val_4 +8 val_8 +10 val_10 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index f9a162ef4e3c0..3132d0112c708 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.execution.SparkLogicalPlan -import org.apache.spark.sql.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} import org.apache.spark.sql.hive.execution.HiveComparisonTest import org.apache.spark.sql.hive.test.TestHive @@ -34,7 +34,7 @@ class CachedTableSuite extends HiveComparisonTest { test("check that table is cached and uncache") { TestHive.table("src").queryExecution.analyzed match { - case SparkLogicalPlan(_ : InMemoryColumnarTableScan) => // Found evidence of caching + case _ : InMemoryRelation => // Found evidence of caching case noCache => fail(s"No cache node found in plan $noCache") } TestHive.uncacheTable("src") @@ -45,7 +45,7 @@ class CachedTableSuite extends HiveComparisonTest { test("make sure table is uncached") { TestHive.table("src").queryExecution.analyzed match { - case cachePlan @ SparkLogicalPlan(_ : InMemoryColumnarTableScan) => + case cachePlan: InMemoryRelation => fail(s"Table still cached after uncache: $cachePlan") case noCache => // Table uncached successfully } @@ -56,4 +56,20 @@ class CachedTableSuite extends HiveComparisonTest { TestHive.uncacheTable("src") } } + + test("'CACHE TABLE' and 'UNCACHE TABLE' HiveQL statement") { + TestHive.hql("CACHE TABLE src") + TestHive.table("src").queryExecution.executedPlan match { + case _: InMemoryColumnarTableScan => // Found evidence of caching + case _ => fail(s"Table 'src' should be cached") + } + assert(TestHive.isCached("src"), "Table 'src' should be cached") + + TestHive.hql("UNCACHE TABLE src") + TestHive.table("src").queryExecution.executedPlan match { + case _: InMemoryColumnarTableScan => fail(s"Table 'src' should not be cached") + case _ => // Found evidence of uncaching + } + assert(!TestHive.isCached("src"), "Table 'src' should not be cached") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 357c7e654bd20..24c929ff7430d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -24,6 +24,7 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen} import org.apache.spark.sql.Logging import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.{NativeCommand => LogicalNativeCommand} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.hive.test.TestHive @@ -141,7 +142,7 @@ abstract class HiveComparisonTest // Hack: Hive simply prints the result of a SET command to screen, // and does not return it as a query answer. 
case _: SetCommand => Seq("0") - case _: NativeCommand => answer.filterNot(nonDeterministicLine).filterNot(_ == "") + case _: LogicalNativeCommand => answer.filterNot(nonDeterministicLine).filterNot(_ == "") case _: ExplainCommand => answer case plan => if (isSorted(plan)) answer else answer.sorted } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index fb8f272d5abfe..ee194dbcb77b2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -172,7 +172,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "case_sensitivity", // Flaky test, Hive sometimes returns different set of 10 rows. - "lateral_view_outer" + "lateral_view_outer", + + // After stop taking the `stringOrError` route, exceptions are thrown from these cases. + // See SPARK-2129 for details. + "join_view", + "mergejoins_mixed" ) /** @@ -476,7 +481,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "join_reorder3", "join_reorder4", "join_star", - "join_view", "lateral_view", "lateral_view_cp", "lateral_view_ppd", @@ -507,7 +511,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "merge1", "merge2", "mergejoins", - "mergejoins_mixed", "multigroupby_singlemr", "multi_insert_gby", "multi_insert_gby3", @@ -597,6 +600,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "select_unquote_and", "select_unquote_not", "select_unquote_or", + "semijoin", "serde_regex", "serde_reported_schema", "set_variable_sub", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 6c239b02ed09a..0d656c556965d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.Row -import org.apache.spark.sql.hive.test.TestHive._ +import scala.util.Try + import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.{SchemaRDD, execution, Row} /** * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. 
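Both the SQLQuerySuite changes above and the HiveQuerySuite changes that follow adapt to the new SET result shape: a SET query now returns one row per key, with the key and value as two separate columns rather than a single "key=value" string. A sketch of reading that shape back out, assuming an existing SQLContext named sqlContext and using the same test key and value as the suites in this patch:

import org.apache.spark.sql.Row

// `sqlContext` is assumed to be an existing SQLContext (or a HiveContext via sql()).
sqlContext.sql("SET spark.sql.key.usedfortestonly=test.val.0")
val pairs = sqlContext.sql("SET").collect().map {
  case Row(key: String, value: String) => key -> value
}
// pairs now contains ("spark.sql.key.usedfortestonly", "test.val.0")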
@@ -162,16 +164,60 @@ class HiveQuerySuite extends HiveComparisonTest { hql("SELECT * FROM src").toString } + private val explainCommandClassName = + classOf[execution.ExplainCommand].getSimpleName.stripSuffix("$") + + def isExplanation(result: SchemaRDD) = { + val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } + explanation.size == 1 && explanation.head.startsWith(explainCommandClassName) + } + test("SPARK-1704: Explain commands as a SchemaRDD") { hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + val rdd = hql("explain select key, count(value) from src group by key") - assert(rdd.collect().size == 1) - assert(rdd.toString.contains("ExplainCommand")) - assert(rdd.filter(row => row.toString.contains("ExplainCommand")).collect().size == 0, - "actual contents of the result should be the plans of the query to be explained") + assert(isExplanation(rdd)) + TestHive.reset() } + test("Query Hive native command execution result") { + val tableName = "test_native_commands" + + val q0 = hql(s"DROP TABLE IF EXISTS $tableName") + assert(q0.count() == 0) + + val q1 = hql(s"CREATE TABLE $tableName(key INT, value STRING)") + assert(q1.count() == 0) + + val q2 = hql("SHOW TABLES") + val tables = q2.select('result).collect().map { case Row(table: String) => table } + assert(tables.contains(tableName)) + + val q3 = hql(s"DESCRIBE $tableName") + assertResult(Array(Array("key", "int", "None"), Array("value", "string", "None"))) { + q3.select('result).collect().map { case Row(fieldDesc: String) => + fieldDesc.split("\t").map(_.trim) + } + } + + val q4 = hql(s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key") + assert(isExplanation(q4)) + + TestHive.reset() + } + + test("Exactly once semantics for DDL and command statements") { + val tableName = "test_exactly_once" + val q0 = hql(s"CREATE TABLE $tableName(key INT, value STRING)") + + // If the table was not created, the following assertion would fail + assert(Try(table(tableName)).isSuccess) + + // If the CREATE TABLE command got executed again, the following assertion would fail + assert(Try(q0.count()).isSuccess) + } + test("parse HQL set commands") { // Adapted from its SQL counterpart. val testKey = "spark.sql.key.usedfortestonly" @@ -195,52 +241,69 @@ class HiveQuerySuite extends HiveComparisonTest { test("SET commands semantics for a HiveContext") { // Adapted from its SQL counterpart. val testKey = "spark.sql.key.usedfortestonly" - var testVal = "test.val.0" + val testVal = "test.val.0" val nonexistentKey = "nonexistent" - def fromRows(row: Array[Row]): Array[String] = row.map(_.getString(0)) + def rowsToPairs(rows: Array[Row]) = rows.map { case Row(key: String, value: String) => + key -> value + } clear() // "set" itself returns all config variables currently specified in SQLConf. 
- assert(hql("set").collect().size == 0) + assert(hql("SET").collect().size == 0) + + assertResult(Array(testKey -> testVal)) { + rowsToPairs(hql(s"SET $testKey=$testVal").collect()) + } - // "set key=val" - hql(s"SET $testKey=$testVal") - assert(fromRows(hql("SET").collect()) sameElements Array(s"$testKey=$testVal")) assert(hiveconf.get(testKey, "") == testVal) + assertResult(Array(testKey -> testVal)) { + rowsToPairs(hql("SET").collect()) + } hql(s"SET ${testKey + testKey}=${testVal + testVal}") - assert(fromRows(hql("SET").collect()) sameElements - Array( - s"$testKey=$testVal", - s"${testKey + testKey}=${testVal + testVal}")) assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) + assertResult(Array(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { + rowsToPairs(hql("SET").collect()) + } // "set key" - assert(fromRows(hql(s"SET $testKey").collect()) sameElements - Array(s"$testKey=$testVal")) - assert(fromRows(hql(s"SET $nonexistentKey").collect()) sameElements - Array(s"$nonexistentKey is undefined")) + assertResult(Array(testKey -> testVal)) { + rowsToPairs(hql(s"SET $testKey").collect()) + } + + assertResult(Array(nonexistentKey -> "")) { + rowsToPairs(hql(s"SET $nonexistentKey").collect()) + } // Assert that sql() should have the same effects as hql() by repeating the above using sql(). clear() - assert(sql("set").collect().size == 0) + assert(sql("SET").collect().size == 0) + + assertResult(Array(testKey -> testVal)) { + rowsToPairs(sql(s"SET $testKey=$testVal").collect()) + } - sql(s"SET $testKey=$testVal") - assert(fromRows(sql("SET").collect()) sameElements Array(s"$testKey=$testVal")) assert(hiveconf.get(testKey, "") == testVal) + assertResult(Array(testKey -> testVal)) { + rowsToPairs(sql("SET").collect()) + } sql(s"SET ${testKey + testKey}=${testVal + testVal}") - assert(fromRows(sql("SET").collect()) sameElements - Array( - s"$testKey=$testVal", - s"${testKey + testKey}=${testVal + testVal}")) assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) + assertResult(Array(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { + rowsToPairs(sql("SET").collect()) + } - assert(fromRows(sql(s"SET $testKey").collect()) sameElements - Array(s"$testKey=$testVal")) - assert(fromRows(sql(s"SET $nonexistentKey").collect()) sameElements - Array(s"$nonexistentKey is undefined")) + assertResult(Array(testKey -> testVal)) { + rowsToPairs(sql(s"SET $testKey").collect()) + } + + assertResult(Array(nonexistentKey -> "")) { + rowsToPairs(sql(s"SET $nonexistentKey").collect()) + } + + clear() } // Put tests that depend on specific Hive settings before these last two test, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala index 86753360a07e4..a0aeacbc733bd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala @@ -27,6 +27,7 @@ private[streaming] class ContextWaiter { } def notifyStop() = synchronized { + stopped = true notifyAll() } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala index 303d149d285e1..d9ac3c91f6e36 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala @@ -29,7 +29,6 @@ import 
org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import scala.language.postfixOps /** Testsuite for testing the network receiver behavior */ class NetworkReceiverSuite extends FunSuite with Timeouts { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index cd86019f63e7e..7b33d3b235466 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -223,6 +223,18 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w } } + test("awaitTermination after stop") { + ssc = new StreamingContext(master, appName, batchDuration) + val inputStream = addInputStream(ssc) + inputStream.map(x => x).register() + + failAfter(10000 millis) { + ssc.start() + ssc.stop() + ssc.awaitTermination() + } + } + test("awaitTermination with error in task") { ssc = new StreamingContext(master, appName, batchDuration) val inputStream = addInputStream(ssc) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index ef0efa552ceaf..2861f5335ae36 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -27,12 +27,12 @@ import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.scheduler._ -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.Logging -class StreamingListenerSuite extends TestSuiteBase with ShouldMatchers { +class StreamingListenerSuite extends TestSuiteBase with Matchers { val input = (1 to 4).map(Seq(_)).toSeq val operation = (d: DStream[Int]) => d.map(x => x) diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 6a261e19a35cd..03a73f92b275e 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -40,74 +40,78 @@ object GenerateMIMAIgnore { private val classLoader = Thread.currentThread().getContextClassLoader private val mirror = runtimeMirror(classLoader) - private def classesPrivateWithin(packageName: String): Set[String] = { + + private def isDeveloperApi(sym: unv.Symbol) = + sym.annotations.exists(_.tpe =:= unv.typeOf[org.apache.spark.annotation.DeveloperApi]) + + private def isExperimental(sym: unv.Symbol) = + sym.annotations.exists(_.tpe =:= unv.typeOf[org.apache.spark.annotation.Experimental]) + + + private def isPackagePrivate(sym: unv.Symbol) = + !sym.privateWithin.fullName.startsWith("") + + private def isPackagePrivateModule(moduleSymbol: unv.ModuleSymbol) = + !moduleSymbol.privateWithin.fullName.startsWith("") + + /** + * For every class checks via scala reflection if the class itself or contained members + * have DeveloperApi or Experimental annotations or they are package private. + * Returns the tuple of such classes and members. 
+ */ + private def privateWithin(packageName: String): (Set[String], Set[String]) = { val classes = getClasses(packageName) val ignoredClasses = mutable.HashSet[String]() + val ignoredMembers = mutable.HashSet[String]() - def isPackagePrivate(className: String) = { + for (className <- classes) { try { - /* Couldn't figure out if it's possible to determine a-priori whether a given symbol - is a module or class. */ - - val privateAsClass = mirror - .classSymbol(Class.forName(className, false, classLoader)) - .privateWithin - .fullName - .startsWith(packageName) - - val privateAsModule = mirror - .staticModule(className) - .privateWithin - .fullName - .startsWith(packageName) - - privateAsClass || privateAsModule - } catch { - case _: Throwable => { - println("Error determining visibility: " + className) - false + val classSymbol = mirror.classSymbol(Class.forName(className, false, classLoader)) + val moduleSymbol = mirror.staticModule(className) // TODO: see if it is necessary. + val directlyPrivateSpark = + isPackagePrivate(classSymbol) || isPackagePrivateModule(moduleSymbol) + val developerApi = isDeveloperApi(classSymbol) + val experimental = isExperimental(classSymbol) + + /* Inner classes defined within a private[spark] class or object are effectively + invisible, so we account for them as package private. */ + lazy val indirectlyPrivateSpark = { + val maybeOuter = className.toString.takeWhile(_ != '$') + if (maybeOuter != className) { + isPackagePrivate(mirror.classSymbol(Class.forName(maybeOuter, false, classLoader))) || + isPackagePrivateModule(mirror.staticModule(maybeOuter)) + } else { + false + } + } + if (directlyPrivateSpark || indirectlyPrivateSpark || developerApi || experimental) { + ignoredClasses += className + } else { + // check if this class has package-private/annotated members. + ignoredMembers ++= getAnnotatedOrPackagePrivateMembers(classSymbol) } - } - } - def isDeveloperApi(className: String) = { - try { - val clazz = mirror.classSymbol(Class.forName(className, false, classLoader)) - clazz.annotations.exists(_.tpe =:= unv.typeOf[org.apache.spark.annotation.DeveloperApi]) } catch { - case _: Throwable => { - println("Error determining Annotations: " + className) - false - } + case _: Throwable => println("Error instrumenting class:" + className) } } + (ignoredClasses.flatMap(c => Seq(c, c.replace("$", "#"))).toSet, ignoredMembers.toSet) + } - for (className <- classes) { - val directlyPrivateSpark = isPackagePrivate(className) - val developerApi = isDeveloperApi(className) - - /* Inner classes defined within a private[spark] class or object are effectively - invisible, so we account for them as package private. */ - val indirectlyPrivateSpark = { - val maybeOuter = className.toString.takeWhile(_ != '$') - if (maybeOuter != className) { - isPackagePrivate(maybeOuter) - } else { - false - } - } - if (directlyPrivateSpark || indirectlyPrivateSpark || developerApi) { - ignoredClasses += className - } - } - ignoredClasses.flatMap(c => Seq(c, c.replace("$", "#"))).toSet + private def getAnnotatedOrPackagePrivateMembers(classSymbol: unv.ClassSymbol) = { + classSymbol.typeSignature.members + .filter(x => isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x)).map(_.fullName) } def main(args: Array[String]) { - scala.tools.nsc.io.File(".generated-mima-excludes"). 
- writeAll(classesPrivateWithin("org.apache.spark").mkString("\n")) - println("Created : .generated-mima-excludes in current directory.") + val (privateClasses, privateMembers) = privateWithin("org.apache.spark") + scala.tools.nsc.io.File(".generated-mima-class-excludes"). + writeAll(privateClasses.mkString("\n")) + println("Created : .generated-mima-class-excludes in current directory.") + scala.tools.nsc.io.File(".generated-mima-member-excludes"). + writeAll(privateMembers.mkString("\n")) + println("Created : .generated-mima-member-excludes in current directory.") } @@ -140,10 +144,17 @@ object GenerateMIMAIgnore { * Get all classes in a package from a jar file. */ private def getClassesFromJar(jarPath: String, packageName: String) = { + import scala.collection.mutable val jar = new JarFile(new File(jarPath)) val enums = jar.entries().map(_.getName).filter(_.startsWith(packageName)) - val classes = for (entry <- enums if entry.endsWith(".class")) - yield Class.forName(entry.replace('/', '.').stripSuffix(".class"), false, classLoader) + val classes = mutable.HashSet[Class[_]]() + for (entry <- enums if entry.endsWith(".class")) { + try { + classes += Class.forName(entry.replace('/', '.').stripSuffix(".class"), false, classLoader) + } catch { + case _: Throwable => println("Unable to load:" + entry) + } + } classes } } diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 8f0ecb855718e..1cc9c33cd2d02 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -277,7 +277,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, yarnAllocator.allocateContainers( math.max(args.numExecutors - yarnAllocator.getNumExecutorsRunning, 0)) ApplicationMaster.incrementAllocatorLoop(1) - Thread.sleep(100) + Thread.sleep(ApplicationMaster.ALLOCATE_HEARTBEAT_INTERVAL) } } finally { // In case of exceptions, etc - ensure that count is at least ALLOCATOR_LOOP_WAIT_COUNT, @@ -416,6 +416,7 @@ object ApplicationMaster { // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be // optimal as more containers are available. Might need to handle this better. private val ALLOCATOR_LOOP_WAIT_COUNT = 30 + private val ALLOCATE_HEARTBEAT_INTERVAL = 100 def incrementAllocatorLoop(by: Int) { val count = yarnAllocatorLoop.getAndAdd(by) @@ -467,13 +468,22 @@ object ApplicationMaster { }) } - // Wait for initialization to complete and atleast 'some' nodes can get allocated. + modified + } + + + /** + * Returns when we've either + * 1) received all the requested executors, + * 2) waited ALLOCATOR_LOOP_WAIT_COUNT * ALLOCATE_HEARTBEAT_INTERVAL ms, + * 3) hit an error that causes us to terminate trying to get containers. 
+ */ + def waitForInitialAllocations() { yarnAllocatorLoop.synchronized { while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT) { yarnAllocatorLoop.wait(1000L) } } - modified } def main(argStrings: Array[String]) { diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 8226207de42b8..4ccddc214c8ad 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -85,7 +85,6 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa def run() { val appId = runApp() monitorApplication(appId) - System.exit(0) } def logClusterResourceDetails() { @@ -179,8 +178,17 @@ object Client { System.setProperty("SPARK_YARN_MODE", "true") val sparkConf = new SparkConf - val args = new ClientArguments(argStrings, sparkConf) - new Client(args, sparkConf).run + try { + val args = new ClientArguments(argStrings, sparkConf) + new Client(args, sparkConf).run() + } catch { + case e: Exception => { + Console.err.println(e.getMessage) + System.exit(1) + } + } + + System.exit(0) } } diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index a3bd91590fc25..b6ecae1e652fe 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -271,6 +271,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp .asInstanceOf[FinishApplicationMasterRequest] finishReq.setAppAttemptId(appAttemptId) finishReq.setFinishApplicationStatus(status) + finishReq.setTrackingUrl(sparkConf.get("spark.yarn.historyServer.address", "")) resourceManager.finishApplicationMaster(finishReq) } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index b2c413b6d267c..fd3ef9e1fa2de 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -125,11 +125,11 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { case Nil => if (userClass == null) { - printUsageAndExit(1) + throw new IllegalArgumentException(getUsageMessage()) } case _ => - printUsageAndExit(1, args) + throw new IllegalArgumentException(getUsageMessage(args)) } } @@ -138,11 +138,10 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { } - def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { - if (unknownParam != null) { - System.err.println("Unknown/unsupported param " + unknownParam) - } - System.err.println( + def getUsageMessage(unknownParam: Any = null): String = { + val message = if (unknownParam != null) s"Unknown/unsupported param $unknownParam\n" else "" + + message + "Usage: org.apache.spark.deploy.yarn.Client [options] \n" + "Options:\n" + " --jar JAR_PATH Path to your application's JAR file (required in yarn-cluster mode)\n" + @@ -158,8 +157,5 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { " --addJars jars Comma separated list of local jars that want SparkContext.addJar to work with.\n" + " --files files Comma separated list of files to be distributed with the job.\n" + " 
--archives archives Comma separated list of archives to be distributed with the job." - ) - System.exit(exitCode) } - } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 801e8b381588f..6861b503000ca 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -19,7 +19,6 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException} -import java.nio.ByteBuffer import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, ListBuffer, Map} @@ -37,8 +36,8 @@ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.util.{Apps, Records} -import org.apache.spark.{Logging, SparkConf, SparkContext} +import org.apache.hadoop.yarn.util.Records +import org.apache.spark.{SparkException, Logging, SparkConf, SparkContext} /** * The entry point (starting in Client#main() and Client#run()) for launching Spark on YARN. The @@ -80,7 +79,7 @@ trait ClientBase extends Logging { ).foreach { case(cond, errStr) => if (cond) { logError(errStr) - args.printUsageAndExit(1) + throw new IllegalArgumentException(args.getUsageMessage()) } } } @@ -95,15 +94,20 @@ trait ClientBase extends Logging { // If we have requested more then the clusters max for a single resource then exit. if (args.executorMemory > maxMem) { - logError("Required executor memory (%d MB), is above the max threshold (%d MB) of this cluster.". - format(args.executorMemory, maxMem)) - System.exit(1) + val errorMessage = + "Required executor memory (%d MB), is above the max threshold (%d MB) of this cluster." + .format(args.executorMemory, maxMem) + + logError(errorMessage) + throw new IllegalArgumentException(errorMessage) } val amMem = args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD if (amMem > maxMem) { - logError("Required AM memory (%d) is above the max threshold (%d) of this cluster". - format(args.amMemory, maxMem)) - System.exit(1) + + val errorMessage = "Required AM memory (%d) is above the max threshold (%d) of this cluster." 
+ .format(args.amMemory, maxMem) + logError(errorMessage) + throw new IllegalArgumentException(errorMessage) } // We could add checks to make sure the entire cluster has enough resources but that involves @@ -169,14 +173,13 @@ trait ClientBase extends Logging { destPath } - def qualifyForLocal(localURI: URI): Path = { + private def qualifyForLocal(localURI: URI): Path = { var qualifiedURI = localURI - // If not specified assume these are in the local filesystem to keep behavior like Hadoop + // If not specified, assume these are in the local filesystem to keep behavior like Hadoop if (qualifiedURI.getScheme() == null) { qualifiedURI = new URI(FileSystem.getLocal(conf).makeQualified(new Path(qualifiedURI)).toString) } - val qualPath = new Path(qualifiedURI) - qualPath + new Path(qualifiedURI) } def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = { @@ -188,8 +191,9 @@ trait ClientBase extends Logging { val delegTokenRenewer = Master.getMasterPrincipal(conf) if (UserGroupInformation.isSecurityEnabled()) { if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { - logError("Can't get Master Kerberos principal for use as renewer") - System.exit(1) + val errorMessage = "Can't get Master Kerberos principal for use as renewer" + logError(errorMessage) + throw new SparkException(errorMessage) } } val dst = new Path(fs.getHomeDirectory(), appStagingDir) @@ -305,13 +309,13 @@ trait ClientBase extends Logging { val amMemory = calculateAMMemory(newApp) - val JAVA_OPTS = ListBuffer[String]() + val javaOpts = ListBuffer[String]() // Add Xmx for AM memory - JAVA_OPTS += "-Xmx" + amMemory + "m" + javaOpts += "-Xmx" + amMemory + "m" val tmpDir = new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) - JAVA_OPTS += "-Djava.io.tmpdir=" + tmpDir + javaOpts += "-Djava.io.tmpdir=" + tmpDir // TODO: Remove once cpuset version is pushed out. // The context is, default gc for server class machines ends up using all cores to do gc - @@ -325,11 +329,11 @@ trait ClientBase extends Logging { if (useConcurrentAndIncrementalGC) { // In our expts, using (default) throughput collector has severe perf ramifications in // multi-tenant machines - JAVA_OPTS += "-XX:+UseConcMarkSweepGC" - JAVA_OPTS += "-XX:+CMSIncrementalMode" - JAVA_OPTS += "-XX:+CMSIncrementalPacing" - JAVA_OPTS += "-XX:CMSIncrementalDutyCycleMin=0" - JAVA_OPTS += "-XX:CMSIncrementalDutyCycle=10" + javaOpts += "-XX:+UseConcMarkSweepGC" + javaOpts += "-XX:+CMSIncrementalMode" + javaOpts += "-XX:+CMSIncrementalPacing" + javaOpts += "-XX:CMSIncrementalDutyCycleMin=0" + javaOpts += "-XX:CMSIncrementalDutyCycle=10" } // SPARK_JAVA_OPTS is deprecated, but for backwards compatibility: @@ -344,22 +348,22 @@ trait ClientBase extends Logging { // If we are being launched in client mode, forward the spark-conf options // onto the executor launcher for ((k, v) <- sparkConf.getAll) { - JAVA_OPTS += "-D" + k + "=" + "\\\"" + v + "\\\"" + javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" } } else { // If we are being launched in standalone mode, capture and forward any spark // system properties (e.g. set by spark-class). 
for ((k, v) <- sys.props.filterKeys(_.startsWith("spark"))) { - JAVA_OPTS += "-D" + k + "=" + "\\\"" + v + "\\\"" + javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" } - sys.props.get("spark.driver.extraJavaOptions").foreach(opts => JAVA_OPTS += opts) - sys.props.get("spark.driver.libraryPath").foreach(p => JAVA_OPTS += s"-Djava.library.path=$p") + sys.props.get("spark.driver.extraJavaOptions").foreach(opts => javaOpts += opts) + sys.props.get("spark.driver.libraryPath").foreach(p => javaOpts += s"-Djava.library.path=$p") } - JAVA_OPTS += ClientBase.getLog4jConfiguration(localResources) + javaOpts += ClientBase.getLog4jConfiguration(localResources) // Command for the ApplicationMaster val commands = Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server") ++ - JAVA_OPTS ++ + javaOpts ++ Seq(args.amClass, "--class", args.userClass, "--jar ", args.userJar, userArgsToString(args), "--executor-memory", args.executorMemory.toString, diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala index 32f8861dc9503..43dbb2464f929 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records} +import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import org.apache.spark.{Logging, SparkConf} @@ -46,19 +46,19 @@ trait ExecutorRunnableUtil extends Logging { executorCores: Int, localResources: HashMap[String, LocalResource]): List[String] = { // Extra options for the JVM - val JAVA_OPTS = ListBuffer[String]() + val javaOpts = ListBuffer[String]() // Set the JVM memory val executorMemoryString = executorMemory + "m" - JAVA_OPTS += "-Xms" + executorMemoryString + " -Xmx" + executorMemoryString + " " + javaOpts += "-Xms" + executorMemoryString + " -Xmx" + executorMemoryString + " " // Set extra Java options for the executor, if defined sys.props.get("spark.executor.extraJavaOptions").foreach { opts => - JAVA_OPTS += opts + javaOpts += opts } - JAVA_OPTS += "-Djava.io.tmpdir=" + + javaOpts += "-Djava.io.tmpdir=" + new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) - JAVA_OPTS += ClientBase.getLog4jConfiguration(localResources) + javaOpts += ClientBase.getLog4jConfiguration(localResources) // Certain configs need to be passed here because they are needed before the Executor // registers with the Scheduler and transfers the spark configs. Since the Executor backend @@ -66,10 +66,10 @@ trait ExecutorRunnableUtil extends Logging { // authentication settings. sparkConf.getAll. filter { case (k, v) => k.startsWith("spark.auth") || k.startsWith("spark.akka") }. - foreach { case (k, v) => JAVA_OPTS += "-D" + k + "=" + "\\\"" + v + "\\\"" } + foreach { case (k, v) => javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" } sparkConf.getAkkaConf. - foreach { case (k, v) => JAVA_OPTS += "-D" + k + "=" + "\\\"" + v + "\\\"" } + foreach { case (k, v) => javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" } // Commenting it out for now - so that people can refer to the properties if required. Remove // it once cpuset version is pushed out. 
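A side note on the JAVA_OPTS -> javaOpts hunks above: both ClientBase and ExecutorRunnableUtil assemble the child JVM command line by turning selected SparkConf entries into -D system properties with escaped quotes. The following is a minimal standalone sketch of that pattern, not code from this patch; JavaOptsSketch, buildChildJavaOpts and heapMb are illustrative names.

import scala.collection.mutable.ListBuffer
import org.apache.spark.SparkConf

object JavaOptsSketch {
  // Build JVM options for a child process, forwarding only the Spark settings
  // that must be visible before the executor registers with the scheduler.
  def buildChildJavaOpts(conf: SparkConf, heapMb: Int): Seq[String] = {
    val javaOpts = ListBuffer[String]()
    javaOpts += "-Xmx" + heapMb + "m"
    for ((k, v) <- conf.getAll
         if k.startsWith("spark.auth") || k.startsWith("spark.akka")) {
      // Same quoting as the patch: each entry becomes -Dkey=\"value\"
      javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\""
    }
    javaOpts.toList
  }
}

The escaped quotes mirror what the patch emits, presumably so that values containing spaces survive the shell command YARN eventually builds for the container.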
@@ -88,11 +88,11 @@ trait ExecutorRunnableUtil extends Logging { // multi-tennent machines // The options are based on // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use%20the%20Concurrent%20Low%20Pause%20Collector|outline - JAVA_OPTS += " -XX:+UseConcMarkSweepGC " - JAVA_OPTS += " -XX:+CMSIncrementalMode " - JAVA_OPTS += " -XX:+CMSIncrementalPacing " - JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 " - JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 " + javaOpts += " -XX:+UseConcMarkSweepGC " + javaOpts += " -XX:+CMSIncrementalMode " + javaOpts += " -XX:+CMSIncrementalPacing " + javaOpts += " -XX:CMSIncrementalDutyCycleMin=0 " + javaOpts += " -XX:CMSIncrementalDutyCycle=10 " } */ @@ -104,7 +104,7 @@ trait ExecutorRunnableUtil extends Logging { // TODO: If the OOM is not recoverable by rescheduling it on different node, then do // 'something' to fail job ... akin to blacklisting trackers in mapred ? "-XX:OnOutOfMemoryError='kill %p'") ++ - JAVA_OPTS ++ + javaOpts ++ Seq("org.apache.spark.executor.CoarseGrainedExecutorBackend", masterAddress.toString, slaveId.toString, diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala index a4638cc863611..39cdd2e8a522b 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala @@ -33,10 +33,11 @@ private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) def this(sc: SparkContext) = this(sc, new Configuration()) - // Nothing else for now ... initialize application master : which needs sparkContext to determine how to allocate - // Note that only the first creation of SparkContext influences (and ideally, there must be only one SparkContext, right ?) - // Subsequent creations are ignored - since nodes are already allocated by then. - + // Nothing else for now ... initialize application master : which needs a SparkContext to + // determine how to allocate. + // Note that only the first creation of a SparkContext influences (and ideally, there must be + // only one SparkContext, right ?). Subsequent creations are ignored since executors are already + // allocated by then. 
// By default, rack is unknown override def getRackForHost(hostPort: String): Option[String] = { @@ -48,6 +49,7 @@ private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) override def postStartHook() { val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc) if (sparkContextInitialized){ + ApplicationMaster.waitForInitialAllocations() // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt Thread.sleep(3000L) } diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 33a60d978c586..6244332f23737 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -19,13 +19,12 @@ package org.apache.spark.deploy.yarn import java.io.IOException import java.util.concurrent.CopyOnWriteArrayList -import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} +import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConversions._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.net.NetUtils import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.protocolrecords._ @@ -33,8 +32,7 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.ipc.YarnRPC -import org.apache.hadoop.yarn.util.{ConverterUtils, Records} +import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.hadoop.yarn.webapp.util.WebAppUtils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} @@ -77,17 +75,18 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, // than user specified and /tmp. System.setProperty("spark.local.dir", getLocalDirs()) - // set the web ui port to be ephemeral for yarn so we don't conflict with + // Set the web ui port to be ephemeral for yarn so we don't conflict with // other spark processes running on the same box System.setProperty("spark.ui.port", "0") - // when running the AM, the Spark master is always "yarn-cluster" + // When running the AM, the Spark master is always "yarn-cluster" System.setProperty("spark.master", "yarn-cluster") - // Use priority 30 as it's higher then HDFS. It's same priority as MapReduce is using. + // Use priority 30 as it's higher than HDFS. It's the same priority MapReduce is using. ShutdownHookManager.get().addShutdownHook(new AppMasterShutdownHook(this), 30) - appAttemptId = getApplicationAttemptId() + appAttemptId = ApplicationMaster.getApplicationAttemptId() + logInfo("ApplicationAttemptId: " + appAttemptId) isLastAMRetry = appAttemptId.getAttemptId() >= maxAppAttempts amClient = AMRMClient.createAMRMClient() amClient.init(yarnConf) @@ -99,7 +98,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, ApplicationMaster.register(this) // Call this to force generation of secret so it gets populated into the - // hadoop UGI. This has to happen before the startUserClass which does a + // Hadoop UGI. 
This has to happen before the startUserClass which does a // doAs in order for the credentials to be passed on to the executor containers. val securityMgr = new SecurityManager(sparkConf) @@ -121,7 +120,10 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, // Allocate all containers allocateExecutors() - // Wait for the user class to Finish + // Launch thread that will heartbeat to the RM so it won't think the app has died. + launchReporterThread() + + // Wait for the user class to finish userThread.join() System.exit(0) @@ -141,7 +143,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, "spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.params", params) } - /** Get the Yarn approved local directories. */ + // Get the Yarn approved local directories. private def getLocalDirs(): String = { // Hadoop 0.23 and 2.x have different Environment variable names for the // local dirs, so lets check both. We assume one of the 2 is set. @@ -150,18 +152,9 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, .orElse(Option(System.getenv("LOCAL_DIRS"))) localDirs match { - case None => throw new Exception("Yarn Local dirs can't be empty") + case None => throw new Exception("Yarn local dirs can't be empty") case Some(l) => l } - } - - private def getApplicationAttemptId(): ApplicationAttemptId = { - val envs = System.getenv() - val containerIdString = envs.get(ApplicationConstants.Environment.CONTAINER_ID.name()) - val containerId = ConverterUtils.toContainerId(containerIdString) - val appAttemptId = containerId.getApplicationAttemptId() - logInfo("ApplicationAttemptId: " + appAttemptId) - appAttemptId } private def registerApplicationMaster(): RegisterApplicationMasterResponse = { @@ -173,25 +166,23 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, logInfo("Starting the user JAR in a separate Thread") val mainMethod = Class.forName( args.userClass, - false /* initialize */ , + false, Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]]) val t = new Thread { override def run() { - - var successed = false + var succeeded = false try { // Copy - var mainArgs: Array[String] = new Array[String](args.userArgs.size) + val mainArgs = new Array[String](args.userArgs.size) args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size) mainMethod.invoke(null, mainArgs) - // some job script has "System.exit(0)" at the end, for example SparkPi, SparkLR - // userThread will stop here unless it has uncaught exception thrown out - // It need shutdown hook to set SUCCEEDED - successed = true + // Some apps have "System.exit(0)" at the end. The user thread will stop here unless + // it has an uncaught exception thrown out. It needs a shutdown hook to set SUCCEEDED. 
+ succeeded = true } finally { - logDebug("finishing main") + logDebug("Finishing main") isLastAMRetry = true - if (successed) { + if (succeeded) { ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED) } else { ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.FAILED) @@ -199,11 +190,12 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, } } } + t.setName("Driver") t.start() t } - // This need to happen before allocateExecutors() + // This needs to happen before allocateExecutors() private def waitForSparkContextInitialized() { logInfo("Waiting for Spark context initialization") try { @@ -231,7 +223,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, sparkContext.preferredNodeLocationData, sparkContext.getConf) } else { - logWarning("Unable to retrieve SparkContext inspite of waiting for %d, maxNumTries = %d". + logWarning("Unable to retrieve SparkContext in spite of waiting for %d, maxNumTries = %d". format(numTries * waitTime, maxNumTries)) this.yarnAllocator = YarnAllocationHandler.newAllocator( yarnConf, @@ -242,48 +234,37 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, } } } finally { - // In case of exceptions, etc - ensure that count is at least ALLOCATOR_LOOP_WAIT_COUNT : - // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks. - ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT) + // In case of exceptions, etc - ensure that the loop in + // ApplicationMaster#sparkContextInitialized() breaks. + ApplicationMaster.doneWithInitialAllocations() } } private def allocateExecutors() { try { - logInfo("Allocating " + args.numExecutors + " executors.") - // Wait until all containers have finished + logInfo("Requesting" + args.numExecutors + " executors.") + // Wait until all containers have launched yarnAllocator.addResourceRequests(args.numExecutors) yarnAllocator.allocateResources() // Exits the loop if the user thread exits. + + var iters = 0 while (yarnAllocator.getNumExecutorsRunning < args.numExecutors && userThread.isAlive) { checkNumExecutorsFailed() allocateMissingExecutor() yarnAllocator.allocateResources() - ApplicationMaster.incrementAllocatorLoop(1) - Thread.sleep(100) + if (iters == ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT) { + ApplicationMaster.doneWithInitialAllocations() + } + Thread.sleep(ApplicationMaster.ALLOCATE_HEARTBEAT_INTERVAL) + iters += 1 } } finally { - // In case of exceptions, etc - ensure that count is at least ALLOCATOR_LOOP_WAIT_COUNT, - // so that the loop in ApplicationMaster#sparkContextInitialized() breaks. - ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT) + // In case of exceptions, etc - ensure that the loop in + // ApplicationMaster#sparkContextInitialized() breaks. + ApplicationMaster.doneWithInitialAllocations() } logInfo("All executors have launched.") - - // Launch a progress reporter thread, else the app will get killed after expiration - // (def: 10mins) timeout. - if (userThread.isAlive) { - // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses. - val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) - - // we want to be reasonably responsive without causing too many requests to RM. - val schedulerInterval = - sparkConf.getLong("spark.yarn.scheduler.heartbeat.interval-ms", 5000) - - // must be <= timeoutInterval / 2. 
- val interval = math.min(timeoutInterval / 2, schedulerInterval) - - launchReporterThread(interval) - } } private def allocateMissingExecutor() { @@ -303,47 +284,35 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, } } - private def launchReporterThread(_sleepTime: Long): Thread = { - val sleepTime = if (_sleepTime <= 0) 0 else _sleepTime + private def launchReporterThread(): Thread = { + // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses. + val expiryInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) + + // we want to be reasonably responsive without causing too many requests to RM. + val schedulerInterval = + sparkConf.getLong("spark.yarn.scheduler.heartbeat.interval-ms", 5000) + + // must be <= timeoutInterval / 2. + val interval = math.max(0, math.min(expiryInterval / 2, schedulerInterval)) val t = new Thread { override def run() { while (userThread.isAlive) { checkNumExecutorsFailed() allocateMissingExecutor() - sendProgress() - Thread.sleep(sleepTime) + logDebug("Sending progress") + yarnAllocator.allocateResources() + Thread.sleep(interval) } } } // Setting to daemon status, though this is usually not a good idea. t.setDaemon(true) t.start() - logInfo("Started progress reporter thread - sleep time : " + sleepTime) + logInfo("Started progress reporter thread - heartbeat interval : " + interval) t } - private def sendProgress() { - logDebug("Sending progress") - // Simulated with an allocate request with no nodes requested. - yarnAllocator.allocateResources() - } - - /* - def printContainers(containers: List[Container]) = { - for (container <- containers) { - logInfo("Launching shell command on a new container." - + ", containerId=" + container.getId() - + ", containerNode=" + container.getNodeId().getHost() - + ":" + container.getNodeId().getPort() - + ", containerNodeURI=" + container.getNodeHttpAddress() - + ", containerState" + container.getState() - + ", containerResourceMemory" - + container.getResource().getMemory()) - } - } - */ - def finishApplicationMaster(status: FinalApplicationStatus, diagnostics: String = "") { synchronized { if (isFinished) { @@ -351,7 +320,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, } isFinished = true - logInfo("finishApplicationMaster with " + status) + logInfo("Unregistering ApplicationMaster with " + status) if (registered) { val trackingUrl = sparkConf.get("spark.yarn.historyServer.address", "") amClient.unregisterApplicationMaster(status, diagnostics, trackingUrl) @@ -386,7 +355,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, def run() { logInfo("AppMaster received a signal.") - // we need to clean up staging dir before HDFS is shut down + // We need to clean up staging dir before HDFS is shut down // make sure we don't delete it until this is the last AM if (appMaster.isLastAMRetry) appMaster.cleanupStagingDir() } @@ -401,21 +370,24 @@ object ApplicationMaster { // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be // optimal as more containers are available. Might need to handle this better. 
private val ALLOCATOR_LOOP_WAIT_COUNT = 30 + private val ALLOCATE_HEARTBEAT_INTERVAL = 100 private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]() val sparkContextRef: AtomicReference[SparkContext] = - new AtomicReference[SparkContext](null /* initialValue */) + new AtomicReference[SparkContext](null) - val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0) + // Variable used to notify the YarnClusterScheduler that it should stop waiting + // for the initial set of executors to be started and get on with its business. + val doneWithInitialAllocationsMonitor = new Object() - def incrementAllocatorLoop(by: Int) { - val count = yarnAllocatorLoop.getAndAdd(by) - if (count >= ALLOCATOR_LOOP_WAIT_COUNT) { - yarnAllocatorLoop.synchronized { - // to wake threads off wait ... - yarnAllocatorLoop.notifyAll() - } + @volatile var isDoneWithInitialAllocations = false + + def doneWithInitialAllocations() { + isDoneWithInitialAllocations = true + doneWithInitialAllocationsMonitor.synchronized { + // to wake threads off wait ... + doneWithInitialAllocationsMonitor.notifyAll() } } @@ -423,7 +395,10 @@ object ApplicationMaster { applicationMasters.add(master) } - // TODO(harvey): See whether this should be discarded - it isn't used anywhere atm... + /** + * Called from YarnClusterScheduler to notify the AM code that a SparkContext has been + * initialized in the user code. + */ def sparkContextInitialized(sc: SparkContext): Boolean = { var modified = false sparkContextRef.synchronized { @@ -431,7 +406,7 @@ object ApplicationMaster { sparkContextRef.notifyAll() } - // Add a shutdown hook - as a best case effort in case users do not call sc.stop or do + // Add a shutdown hook - as a best effort in case users do not call sc.stop or do // System.exit. // Should not really have to do this, but it helps YARN to evict resources earlier. // Not to mention, prevent the Client from declaring failure even though we exited properly. @@ -454,13 +429,29 @@ object ApplicationMaster { }) } - // Wait for initialization to complete and atleast 'some' nodes can get allocated. - yarnAllocatorLoop.synchronized { - while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT) { - yarnAllocatorLoop.wait(1000L) + // Wait for initialization to complete and at least 'some' nodes to get allocated. + modified + } + + /** + * Returns when we've either + * 1) received all the requested executors, + * 2) waited ALLOCATOR_LOOP_WAIT_COUNT * ALLOCATE_HEARTBEAT_INTERVAL ms, + * 3) hit an error that causes us to terminate trying to get containers. 
+ */ + def waitForInitialAllocations() { + doneWithInitialAllocationsMonitor.synchronized { + while (!isDoneWithInitialAllocations) { + doneWithInitialAllocationsMonitor.wait(1000L) } } - modified + } + + def getApplicationAttemptId(): ApplicationAttemptId = { + val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()) + val containerId = ConverterUtils.toContainerId(containerIdString) + val appAttemptId = containerId.getApplicationAttemptId() + appAttemptId } def main(argStrings: Array[String]) { diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 393edd1f2d670..80a8bceb17269 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -21,14 +21,12 @@ import java.nio.ByteBuffer import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.DataOutputBuffer -import org.apache.hadoop.yarn.api._ -import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.YarnClient import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC -import org.apache.hadoop.yarn.util.{Apps, Records} +import org.apache.hadoop.yarn.util.Records import org.apache.spark.{Logging, SparkConf} @@ -97,12 +95,11 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa def run() { val appId = runApp() monitorApplication(appId) - System.exit(0) } def logClusterResourceDetails() { val clusterMetrics: YarnClusterMetrics = yarnClient.getYarnClusterMetrics - logInfo("Got Cluster metric info from ApplicationsManager (ASM), number of NodeManagers: " + + logInfo("Got Cluster metric info from ResourceManager, number of NodeManagers: " + clusterMetrics.getNumNodeManagers) val queueInfo: QueueInfo = yarnClient.getQueueInfo(args.amQueue) @@ -133,7 +130,7 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa def submitApp(appContext: ApplicationSubmissionContext) = { // Submit the application to the applications manager. - logInfo("Submitting application to ASM") + logInfo("Submitting application to ResourceManager") yarnClient.submitApplication(appContext) } @@ -149,7 +146,7 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa Thread.sleep(interval) val report = yarnClient.getApplicationReport(appId) - logInfo("Application report from ASM: \n" + + logInfo("Application report from ResourceManager: \n" + "\t application identifier: " + appId.toString() + "\n" + "\t appId: " + appId.getId() + "\n" + "\t clientToAMToken: " + report.getClientToAMToken() + "\n" + @@ -188,9 +185,18 @@ object Client { // see Client#setupLaunchEnv(). 
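Stepping back from the patch for a moment: the doneWithInitialAllocations() / waitForInitialAllocations() pair introduced in the ApplicationMaster hunk above replaces the old AtomicInteger counting loop with a plain monitor plus a volatile flag. The flag makes a notification that arrives before the waiter harmless, the while loop tolerates spurious wakeups, and the bounded wait means the scheduler re-checks the flag at least once a second. A minimal sketch of that idiom, independent of the patch (InitialAllocationLatch, markDone and awaitDone are illustrative names):

object InitialAllocationLatch {
  private val monitor = new Object()
  @volatile private var done = false

  // Called once the initial round of executor allocation is over.
  def markDone(): Unit = {
    done = true
    monitor.synchronized {
      monitor.notifyAll() // wake any thread blocked in awaitDone()
    }
  }

  // Blocks until markDone() has been called, re-checking the flag each second.
  def awaitDone(): Unit = monitor.synchronized {
    while (!done) {
      monitor.wait(1000L)
    }
  }
}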
System.setProperty("SPARK_YARN_MODE", "true") val sparkConf = new SparkConf() - val args = new ClientArguments(argStrings, sparkConf) - new Client(args, sparkConf).run() + try { + val args = new ClientArguments(argStrings, sparkConf) + new Client(args, sparkConf).run() + } catch { + case e: Exception => { + Console.err.println(e.getMessage) + System.exit(1) + } + } + + System.exit(0) } } diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index d93e5bb0225d5..f71ad036ce0f2 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -72,8 +72,8 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp override def preStart() { logInfo("Listen to driver: " + driverUrl) driver = context.actorSelection(driverUrl) - // Send a hello message thus the connection is actually established, - // thus we can monitor Lifecycle Events. + // Send a hello message to establish the connection, after which + // we can monitor Lifecycle Events. driver ! "Hello" context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) } @@ -95,7 +95,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp amClient.init(yarnConf) amClient.start() - appAttemptId = getApplicationAttemptId() + appAttemptId = ApplicationMaster.getApplicationAttemptId() registerApplicationMaster() waitForSparkMaster() @@ -115,7 +115,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp val interval = math.min(timeoutInterval / 2, schedulerInterval) reporterThread = launchReporterThread(interval) - + // Wait for the reporter thread to Finish. reporterThread.join() @@ -134,25 +134,16 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) .orElse(Option(System.getenv("LOCAL_DIRS"))) - + localDirs match { case None => throw new Exception("Yarn Local dirs can't be empty") case Some(l) => l } - } - - private def getApplicationAttemptId(): ApplicationAttemptId = { - val envs = System.getenv() - val containerIdString = envs.get(ApplicationConstants.Environment.CONTAINER_ID.name()) - val containerId = ConverterUtils.toContainerId(containerIdString) - val appAttemptId = containerId.getApplicationAttemptId() - logInfo("ApplicationAttemptId: " + appAttemptId) - appAttemptId } private def registerApplicationMaster(): RegisterApplicationMasterResponse = { logInfo("Registering the ApplicationMaster") - // TODO:(Raymond) Find out Spark UI address and fill in here? + // TODO: Find out client's Spark UI address and fill in here? amClient.registerApplicationMaster(Utils.localHostName(), 0, "") } @@ -185,8 +176,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp private def allocateExecutors() { - - // Fixme: should get preferredNodeLocationData from SparkContext, just fake a empty one for now. + // TODO: should get preferredNodeLocationData from SparkContext, just fake a empty one for now. 
val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map() @@ -198,8 +188,8 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp preferredNodeLocationData, sparkConf) - logInfo("Allocating " + args.numExecutors + " executors.") - // Wait until all containers have finished + logInfo("Requesting " + args.numExecutors + " executors.") + // Wait until all containers have launched yarnAllocator.addResourceRequests(args.numExecutors) yarnAllocator.allocateResources() while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed)) { @@ -221,7 +211,6 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp } } - // TODO: We might want to extend this to allocate more containers in case they die ! private def launchReporterThread(_sleepTime: Long): Thread = { val sleepTime = if (_sleepTime <= 0) 0 else _sleepTime @@ -229,7 +218,8 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp override def run() { while (!driverClosed) { allocateMissingExecutor() - sendProgress() + logDebug("Sending progress") + yarnAllocator.allocateResources() Thread.sleep(sleepTime) } } @@ -241,20 +231,14 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp t } - private def sendProgress() { - logDebug("Sending progress") - // simulated with an allocate request with no nodes requested ... - yarnAllocator.allocateResources() - } - def finishApplicationMaster(status: FinalApplicationStatus) { - logInfo("finish ApplicationMaster with " + status) - amClient.unregisterApplicationMaster(status, "" /* appMessage */ , "" /* appTrackingUrl */) + logInfo("Unregistering ApplicationMaster with " + status) + val trackingUrl = sparkConf.get("spark.yarn.historyServer.address", "") + amClient.unregisterApplicationMaster(status, "" /* appMessage */ , trackingUrl) } } - object ExecutorLauncher { def main(argStrings: Array[String]) { val args = new ApplicationMasterArguments(argStrings)