diff --git a/core/pom.xml b/core/pom.xml index 04d4b9cc1068e..7c60cf10c3dc2 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -192,8 +192,8 @@ org.tachyonproject - tachyon - 0.4.1-thrift + tachyon-client + 0.5.0 org.apache.hadoop diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 3935c8772252e..ab2594cfc02eb 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -34,8 +34,8 @@ abstract class Dependency[T] extends Serializable { /** * :: DeveloperApi :: - * Base class for dependencies where each partition of the parent RDD is used by at most one - * partition of the child RDD. Narrow dependencies allow for pipelined execution. + * Base class for dependencies where each partition of the child RDD depends on a small number + * of partitions of the parent RDD. Narrow dependencies allow for pipelined execution. */ @DeveloperApi abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] { diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala new file mode 100644 index 0000000000000..24ccce21b62ca --- /dev/null +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import akka.actor.Actor +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.scheduler.TaskScheduler + +/** + * A heartbeat from executors to the driver. This is a shared message used by several internal + * components to convey liveness or execution information for in-progress tasks. + */ +private[spark] case class Heartbeat( + executorId: String, + taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics + blockManagerId: BlockManagerId) + +private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) + +/** + * Lives in the driver to receive heartbeats from executors.. + */ +private[spark] class HeartbeatReceiver(scheduler: TaskScheduler) extends Actor { + override def receive = { + case Heartbeat(executorId, taskMetrics, blockManagerId) => + val response = HeartbeatResponse( + !scheduler.executorHeartbeatReceived(executorId, taskMetrics, blockManagerId)) + sender ! response + } +} diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 50d8e93e1f0d7..807ef3e9c9d60 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -45,10 +45,7 @@ trait Logging { initializeIfNecessary() var className = this.getClass.getName // Ignore trailing $'s in the class names for Scala objects - if (className.endsWith("$")) { - className = className.substring(0, className.length - 1) - } - log_ = LoggerFactory.getLogger(className) + log_ = LoggerFactory.getLogger(className.stripSuffix("$")) } log_ } @@ -110,23 +107,27 @@ trait Logging { } private def initializeLogging() { - // If Log4j is being used, but is not initialized, load a default properties file - val binder = StaticLoggerBinder.getSingleton - val usingLog4j = binder.getLoggerFactoryClassStr.endsWith("Log4jLoggerFactory") - val log4jInitialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements - if (!log4jInitialized && usingLog4j) { + // Don't use a logger in here, as this is itself occurring during initialization of a logger + // If Log4j 1.2 is being used, but is not initialized, load a default properties file + val binderClass = StaticLoggerBinder.getSingleton.getLoggerFactoryClassStr + // This distinguishes the log4j 1.2 binding, currently + // org.slf4j.impl.Log4jLoggerFactory, from the log4j 2.0 binding, currently + // org.apache.logging.slf4j.Log4jLoggerFactory + val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass) + val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements + if (!log4j12Initialized && usingLog4j12) { val defaultLogProps = "org/apache/spark/log4j-defaults.properties" Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { case Some(url) => PropertyConfigurator.configure(url) - log.info(s"Using Spark's default log4j profile: $defaultLogProps") + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") case None => System.err.println(s"Spark was unable to load $defaultLogProps") } } Logging.initialized = true - // Force a call into slf4j to initialize it. Avoids this happening from mutliple threads + // Force a call into slf4j to initialize it. Avoids this happening from multiple threads // and triggering this: http://mailman.qos.ch/pipermail/slf4j-dev/2010-April/002956.html log } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 17e9ef902621a..9ba21cfcde01a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -36,6 +36,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, Sequence import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob} import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.mesos.MesosNativeLibrary +import akka.actor.Props import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast @@ -307,6 +308,8 @@ class SparkContext(config: SparkConf) extends Logging { // Create and start the scheduler private[spark] var taskScheduler = SparkContext.createTaskScheduler(this, master) + private val heartbeatReceiver = env.actorSystem.actorOf( + Props(new HeartbeatReceiver(taskScheduler)), "HeartbeatReceiver") @volatile private[spark] var dagScheduler: DAGScheduler = _ try { dagScheduler = new DAGScheduler(this) @@ -455,7 +458,7 @@ class SparkContext(config: SparkConf) extends Logging { /** Distribute a local Scala collection to form an RDD, with one or more * location preferences (hostnames of Spark nodes) for each object. * Create a new partition for each collection item. */ - def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = { + def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = { val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs) } @@ -992,7 +995,9 @@ class SparkContext(config: SparkConf) extends Logging { val dagSchedulerCopy = dagScheduler dagScheduler = null if (dagSchedulerCopy != null) { + env.metricsSystem.report() metadataCleaner.cancel() + env.actorSystem.stop(heartbeatReceiver) cleaner.foreach(_.stop()) dagSchedulerCopy.stop() taskScheduler = null @@ -1453,9 +1458,9 @@ object SparkContext extends Logging { /** Creates a task scheduler based on a given master URL. Extracted for testing. */ private def createTaskScheduler(sc: SparkContext, master: String): TaskScheduler = { // Regular expression used for local[N] and local[*] master formats - val LOCAL_N_REGEX = """local\[([0-9\*]+)\]""".r + val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r // Regular expression for local[N, maxRetries], used in tests with failing tasks - val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r + val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r // Regular expression for simulating a Spark cluster of [N, cores, memory] locally val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r // Regular expression for connecting to Spark deploy clusters @@ -1485,8 +1490,12 @@ object SparkContext extends Logging { scheduler case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => + def localCpuCount = Runtime.getRuntime.availableProcessors() + // local[*, M] means the number of cores on the computer with M failures + // local[N, M] means exactly N threads with M failures + val threadCount = if (threads == "*") localCpuCount else threads.toInt val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true) - val backend = new LocalBackend(scheduler, threads.toInt) + val backend = new LocalBackend(scheduler, threadCount) scheduler.initialize(backend) scheduler diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 6ee731b22c03c..92c809d854167 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -193,13 +193,7 @@ object SparkEnv extends Logging { logInfo("Registering " + name) actorSystem.actorOf(Props(newActor), name = name) } else { - val driverHost: String = conf.get("spark.driver.host", "localhost") - val driverPort: Int = conf.getInt("spark.driver.port", 7077) - Utils.checkHost(driverHost, "Expected hostname") - val url = s"akka.tcp://spark@$driverHost:$driverPort/user/$name" - val timeout = AkkaUtils.lookupTimeout(conf) - logInfo(s"Connecting to $name: $url") - Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) + AkkaUtils.makeDriverRef(name, conf, actorSystem) } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 47708cb2e78bd..76d4193e96aea 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -783,6 +783,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) sortByKey(comp, ascending) } + /** + * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling + * `collect` or `save` on the resulting RDD will return or output an ordered list of records + * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in + * order of the keys). + */ + def sortByKey(ascending: Boolean, numPartitions: Int): JavaPairRDD[K, V] = { + val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]] + sortByKey(comp, ascending, numPartitions) + } + /** * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling * `collect` or `save` on the resulting RDD will return or output an ordered list of records diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index a9d758bf998c3..94d666aa92025 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -731,19 +731,30 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536) + /** + * We try to reuse a single Socket to transfer accumulator updates, as they are all added + * by the DAGScheduler's single-threaded actor anyway. + */ + @transient var socket: Socket = _ + + def openSocket(): Socket = synchronized { + if (socket == null || socket.isClosed) { + socket = new Socket(serverHost, serverPort) + } + socket + } + override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]]) - : JList[Array[Byte]] = { + : JList[Array[Byte]] = synchronized { if (serverHost == null) { // This happens on the worker node, where we just want to remember all the updates val1.addAll(val2) val1 } else { // This happens on the master, where we pass the updates to Python through a socket - val socket = new Socket(serverHost, serverPort) - // SPARK-2282: Immediately reuse closed sockets because we create one per task. - socket.setReuseAddress(true) + val socket = openSocket() val in = socket.getInputStream val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize)) out.writeInt(val2.size) @@ -757,7 +768,6 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: if (byteRead == -1) { throw new SparkException("EOF reached before Python server acknowledged") } - socket.close() null } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 21f8667819c44..a70ecdb375373 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -154,6 +154,8 @@ private[spark] class Master( } override def postStop() { + masterMetricsSystem.report() + applicationMetricsSystem.report() // prevent the CompleteRecovery message sending to restarted master if (recoveryCompletionTask != null) { recoveryCompletionTask.cancel() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index ce425443051b0..fb5252da96519 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -357,6 +357,7 @@ private[spark] class Worker( } override def postStop() { + metricsSystem.report() registrationRetryTimer.foreach(_.cancel()) executors.values.foreach(_.kill()) drivers.values.foreach(_.kill()) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 860b47e056451..af736de405397 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -88,6 +88,7 @@ private[spark] class CoarseGrainedExecutorBackend( case StopExecutor => logInfo("Driver commanded a shutdown") + executor.stop() context.stop(self) context.system.shutdown() } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 3b69bc4ca4142..1bb1b4aae91bb 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -23,7 +23,7 @@ import java.nio.ByteBuffer import java.util.concurrent._ import scala.collection.JavaConversions._ -import scala.collection.mutable.HashMap +import scala.collection.mutable.{ArrayBuffer, HashMap} import org.apache.spark._ import org.apache.spark.scheduler._ @@ -48,6 +48,8 @@ private[spark] class Executor( private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0)) + @volatile private var isStopped = false + // No ip or host:port - just hostname Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname") // must not have port specified. @@ -107,6 +109,8 @@ private[spark] class Executor( // Maintains the list of running tasks. private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] + startDriverHeartbeater() + def launchTask( context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) { val tr = new TaskRunner(context, taskId, taskName, serializedTask) @@ -121,6 +125,12 @@ private[spark] class Executor( } } + def stop() { + env.metricsSystem.report() + isStopped = true + threadPool.shutdown() + } + /** Get the Yarn approved local directories. */ private def getYarnLocalDirs(): String = { // Hadoop 0.23 and 2.x have different Environment variable names for the @@ -137,11 +147,12 @@ private[spark] class Executor( } class TaskRunner( - execBackend: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) + execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer) extends Runnable { @volatile private var killed = false - @volatile private var task: Task[Any] = _ + @volatile var task: Task[Any] = _ + @volatile var attemptedTask: Option[Task[Any]] = None def kill(interruptThread: Boolean) { logInfo(s"Executor is trying to kill $taskName (TID $taskId)") @@ -158,7 +169,6 @@ private[spark] class Executor( val ser = SparkEnv.get.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) - var attemptedTask: Option[Task[Any]] = None var taskStart: Long = 0 def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum val startGCTime = gcTime @@ -200,7 +210,6 @@ private[spark] class Executor( val afterSerialization = System.currentTimeMillis() for (m <- task.metrics) { - m.hostname = Utils.localHostName() m.executorDeserializeTime = taskStart - startTime m.executorRunTime = taskFinish - taskStart m.jvmGCTime = gcTime - startGCTime @@ -350,4 +359,42 @@ private[spark] class Executor( } } } + + def startDriverHeartbeater() { + val interval = conf.getInt("spark.executor.heartbeatInterval", 10000) + val timeout = AkkaUtils.lookupTimeout(conf) + val retryAttempts = AkkaUtils.numRetries(conf) + val retryIntervalMs = AkkaUtils.retryWaitMs(conf) + val heartbeatReceiverRef = AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem) + + val t = new Thread() { + override def run() { + // Sleep a random interval so the heartbeats don't end up in sync + Thread.sleep(interval + (math.random * interval).asInstanceOf[Int]) + + while (!isStopped) { + val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]() + for (taskRunner <- runningTasks.values()) { + if (!taskRunner.attemptedTask.isEmpty) { + Option(taskRunner.task).flatMap(_.metrics).foreach { metrics => + tasksMetrics += ((taskRunner.taskId, metrics)) + } + } + } + + val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) + val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, + retryAttempts, retryIntervalMs, timeout) + if (response.reregisterBlockManager) { + logWarning("Told to re-register on heartbeat") + env.blockManager.reregister() + } + Thread.sleep(interval) + } + } + } + t.setDaemon(true) + t.setName("Driver Heartbeater") + t.start() + } } diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 21fe643b8d71f..56cd8723a3a22 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -23,6 +23,14 @@ import org.apache.spark.storage.{BlockId, BlockStatus} /** * :: DeveloperApi :: * Metrics tracked during the execution of a task. + * + * This class is used to house metrics both for in-progress and completed tasks. In executors, + * both the task thread and the heartbeat thread write to the TaskMetrics. The heartbeat thread + * reads it to send in-progress metrics, and the task thread reads it to send metrics along with + * the completed task. + * + * So, when adding new fields, take into consideration that the whole object can be serialized for + * shipping off at any time to consumers of the SparkListener interface. */ @DeveloperApi class TaskMetrics extends Serializable { @@ -143,7 +151,7 @@ class ShuffleReadMetrics extends Serializable { /** * Absolute time when this task finished reading shuffle data */ - var shuffleFinishTime: Long = _ + var shuffleFinishTime: Long = -1 /** * Number of blocks fetched in this shuffle by this task (remote or local) diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 651511da1b7fe..6ef817d0e587e 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -91,6 +91,10 @@ private[spark] class MetricsSystem private (val instance: String, sinks.foreach(_.stop) } + def report(): Unit = { + sinks.foreach(_.report()) + } + def registerSource(source: Source) { sources += source try { diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala index 05852f1f98993..81b9056b40fb8 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala @@ -57,5 +57,9 @@ private[spark] class ConsoleSink(val property: Properties, val registry: MetricR override def stop() { reporter.stop() } + + override def report() { + reporter.report() + } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala index 542dce65366b2..9d5f2ae9328ad 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala @@ -66,5 +66,9 @@ private[spark] class CsvSink(val property: Properties, val registry: MetricRegis override def stop() { reporter.stop() } + + override def report() { + reporter.report() + } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index aeb4ad44a0647..d7b5f5c40efae 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -81,4 +81,8 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric override def stop() { reporter.stop() } + + override def report() { + reporter.report() + } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala index ed27234b4e760..2588fe2c9edb8 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala @@ -35,4 +35,6 @@ private[spark] class JmxSink(val property: Properties, val registry: MetricRegis reporter.stop() } + override def report() { } + } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 571539ba5e467..2f65bc8b46609 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -57,4 +57,6 @@ private[spark] class MetricsServlet(val property: Properties, val registry: Metr override def start() { } override def stop() { } + + override def report() { } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala index 6f2b5a06027ea..0d83d8c425ca4 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala @@ -20,4 +20,5 @@ package org.apache.spark.metrics.sink private[spark] trait Sink { def start: Unit def stop: Unit + def report(): Unit } diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index e7221e3032c11..11ebafbf6d457 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -49,8 +49,8 @@ private[spark] case class CoalescedRDDPartition( } /** - * Computes how many of the parents partitions have getPreferredLocation - * as one of their preferredLocations + * Computes the fraction of the parents' partitions containing preferredLocation within + * their getPreferredLocs. * @return locality of this coalesced partition between 0 and 1 */ def localFraction: Double = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 50186d097a632..d87c3048985fc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -21,7 +21,7 @@ import java.io.NotSerializableException import java.util.Properties import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps @@ -29,7 +29,6 @@ import scala.reflect.ClassTag import scala.util.control.NonFatal import akka.actor._ -import akka.actor.OneForOneStrategy import akka.actor.SupervisorStrategy.Stop import akka.pattern.ask import akka.util.Timeout @@ -39,8 +38,9 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD -import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId} +import org.apache.spark.storage._ import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils} +import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat /** * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of @@ -154,6 +154,23 @@ class DAGScheduler( eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics) } + /** + * Update metrics for in-progress tasks and let the master know that the BlockManager is still + * alive. Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. + */ + def executorHeartbeatReceived( + execId: String, + taskMetrics: Array[(Long, Int, TaskMetrics)], // (taskId, stageId, metrics) + blockManagerId: BlockManagerId): Boolean = { + listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) + implicit val timeout = Timeout(600 seconds) + + Await.result( + blockManagerMaster.driverActor ? BlockManagerHeartbeat(blockManagerId), + timeout.duration).asInstanceOf[Boolean] + } + // Called by TaskScheduler when an executor fails. def executorLost(execId: String) { eventProcessActor ! ExecutorLost(execId) @@ -194,11 +211,15 @@ class DAGScheduler( shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => + // We are going to register ancestor shuffle dependencies + registerShuffleDependencies(shuffleDep, jobId) + // Then register current shuffleDep val stage = newOrUsedStage( shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId, shuffleDep.rdd.creationSite) shuffleToMapStage(shuffleDep.shuffleId) = stage + stage } } @@ -263,6 +284,9 @@ class DAGScheduler( private def getParentStages(rdd: RDD[_], jobId: Int): List[Stage] = { val parents = new HashSet[Stage] val visited = new HashSet[RDD[_]] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by recursively visiting + val waitingForVisit = new Stack[RDD[_]] def visit(r: RDD[_]) { if (!visited(r)) { visited += r @@ -273,18 +297,69 @@ class DAGScheduler( case shufDep: ShuffleDependency[_, _, _] => parents += getShuffleMapStage(shufDep, jobId) case _ => - visit(dep.rdd) + waitingForVisit.push(dep.rdd) } } } } - visit(rdd) + waitingForVisit.push(rdd) + while (!waitingForVisit.isEmpty) { + visit(waitingForVisit.pop()) + } parents.toList } + // Find ancestor missing shuffle dependencies and register into shuffleToMapStage + private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], jobId: Int) = { + val parentsWithNoMapStage = getAncestorShuffleDependencies(shuffleDep.rdd) + while (!parentsWithNoMapStage.isEmpty) { + val currentShufDep = parentsWithNoMapStage.pop() + val stage = + newOrUsedStage( + currentShufDep.rdd, currentShufDep.rdd.partitions.size, currentShufDep, jobId, + currentShufDep.rdd.creationSite) + shuffleToMapStage(currentShufDep.shuffleId) = stage + } + } + + // Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet + private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = { + val parents = new Stack[ShuffleDependency[_, _, _]] + val visited = new HashSet[RDD[_]] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by recursively visiting + val waitingForVisit = new Stack[RDD[_]] + def visit(r: RDD[_]) { + if (!visited(r)) { + visited += r + for (dep <- r.dependencies) { + dep match { + case shufDep: ShuffleDependency[_, _, _] => + if (!shuffleToMapStage.contains(shufDep.shuffleId)) { + parents.push(shufDep) + } + + waitingForVisit.push(shufDep.rdd) + case _ => + waitingForVisit.push(dep.rdd) + } + } + } + } + + waitingForVisit.push(rdd) + while (!waitingForVisit.isEmpty) { + visit(waitingForVisit.pop()) + } + parents + } + private def getMissingParentStages(stage: Stage): List[Stage] = { val missing = new HashSet[Stage] val visited = new HashSet[RDD[_]] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by recursively visiting + val waitingForVisit = new Stack[RDD[_]] def visit(rdd: RDD[_]) { if (!visited(rdd)) { visited += rdd @@ -297,13 +372,16 @@ class DAGScheduler( missing += mapStage } case narrowDep: NarrowDependency[_] => - visit(narrowDep.rdd) + waitingForVisit.push(narrowDep.rdd) } } } } } - visit(stage.rdd) + waitingForVisit.push(stage.rdd) + while (!waitingForVisit.isEmpty) { + visit(waitingForVisit.pop()) + } missing.toList } @@ -1102,6 +1180,9 @@ class DAGScheduler( } val visitedRdds = new HashSet[RDD[_]] val visitedStages = new HashSet[Stage] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by recursively visiting + val waitingForVisit = new Stack[RDD[_]] def visit(rdd: RDD[_]) { if (!visitedRdds(rdd)) { visitedRdds += rdd @@ -1111,15 +1192,18 @@ class DAGScheduler( val mapStage = getShuffleMapStage(shufDep, stage.jobId) if (!mapStage.isAvailable) { visitedStages += mapStage - visit(mapStage.rdd) + waitingForVisit.push(mapStage.rdd) } // Otherwise there's no need to follow the dependency back case narrowDep: NarrowDependency[_] => - visit(narrowDep.rdd) + waitingForVisit.push(narrowDep.rdd) } } } } - visit(stage.rdd) + waitingForVisit.push(stage.rdd) + while (!waitingForVisit.isEmpty) { + visit(waitingForVisit.pop()) + } visitedRdds.contains(target.rdd) } @@ -1131,6 +1215,22 @@ class DAGScheduler( */ private[spark] def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = synchronized { + getPreferredLocsInternal(rdd, partition, new HashSet) + } + + /** Recursive implementation for getPreferredLocs. */ + private def getPreferredLocsInternal( + rdd: RDD[_], + partition: Int, + visited: HashSet[(RDD[_],Int)]) + : Seq[TaskLocation] = + { + // If the partition has already been visited, no need to re-visit. + // This avoids exponential path exploration. SPARK-695 + if (!visited.add((rdd,partition))) { + // Nil has already been returned for previously visited partitions. + return Nil + } // If the partition is cached, return the cache locations val cached = getCacheLocs(rdd)(partition) if (!cached.isEmpty) { @@ -1147,7 +1247,7 @@ class DAGScheduler( rdd.dependencies.foreach { case n: NarrowDependency[_] => for (inPart <- n.getParents(partition)) { - val locs = getPreferredLocs(n.rdd, inPart) + val locs = getPreferredLocsInternal(n.rdd, inPart, visited) if (locs != Nil) { return locs } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 82163eadd56e9..d01d318633877 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -75,6 +75,12 @@ case class SparkListenerBlockManagerRemoved(blockManagerId: BlockManagerId) @DeveloperApi case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerExecutorMetricsUpdate( + execId: String, + taskMetrics: Seq[(Long, Int, TaskMetrics)]) + extends SparkListenerEvent + @DeveloperApi case class SparkListenerApplicationStart(appName: String, time: Long, sparkUser: String) extends SparkListenerEvent @@ -158,6 +164,11 @@ trait SparkListener { * Called when the application ends */ def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd) { } + + /** + * Called when the driver receives task metrics from an executor in a heartbeat. + */ + def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index ed9fb24bc8ce8..e79ffd7a3587d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -68,6 +68,8 @@ private[spark] trait SparkListenerBus extends Logging { foreachListener(_.onApplicationStart(applicationStart)) case applicationEnd: SparkListenerApplicationEnd => foreachListener(_.onApplicationEnd(applicationEnd)) + case metricsUpdate: SparkListenerExecutorMetricsUpdate => + foreachListener(_.onExecutorMetricsUpdate(metricsUpdate)) case SparkListenerShutdown => } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 5871edeb856ad..5c5e421404a21 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -26,6 +26,8 @@ import org.apache.spark.TaskContext import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.ByteBufferInputStream +import org.apache.spark.util.Utils + /** * A unit of execution. We have two kinds of Task's in Spark: @@ -44,6 +46,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex final def run(attemptId: Long): T = { context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false) + context.taskMetrics.hostname = Utils.localHostName(); taskThread = Thread.currentThread() if (_killed) { kill(interruptThread = false) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 819c35257b5a7..1a0b877c8a5e1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -18,6 +18,8 @@ package org.apache.spark.scheduler import org.apache.spark.scheduler.SchedulingMode.SchedulingMode +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.BlockManagerId /** * Low-level task scheduler interface, currently implemented exclusively by TaskSchedulerImpl. @@ -54,4 +56,12 @@ private[spark] trait TaskScheduler { // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs. def defaultParallelism(): Int + + /** + * Update metrics for in-progress tasks and let the master know that the BlockManager is still + * alive. Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. + */ + def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], + blockManagerId: BlockManagerId): Boolean } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index be3673c48eda8..d2f764fc22f54 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -32,6 +32,9 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.util.Utils +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.BlockManagerId +import akka.actor.Props /** * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend. @@ -320,6 +323,26 @@ private[spark] class TaskSchedulerImpl( } } + /** + * Update metrics for in-progress tasks and let the master know that the BlockManager is still + * alive. Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. + */ + override def executorHeartbeatReceived( + execId: String, + taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics + blockManagerId: BlockManagerId): Boolean = { + val metricsWithStageIds = taskMetrics.flatMap { + case (id, metrics) => { + taskIdToTaskSetId.get(id) + .flatMap(activeTaskSets.get) + .map(_.stageId) + .map(x => (id, x, metrics)) + } + } + dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId) + } + def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) { taskSetManager.handleTaskGettingResult(tid) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 5b897597fa285..3d1cf312ccc97 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -23,8 +23,9 @@ import akka.actor.{Actor, ActorRef, Props} import org.apache.spark.{Logging, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState -import org.apache.spark.executor.{Executor, ExecutorBackend} +import org.apache.spark.executor.{TaskMetrics, Executor, ExecutorBackend} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} +import org.apache.spark.storage.BlockManagerId private case class ReviveOffers() @@ -32,6 +33,8 @@ private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: private case class KillTask(taskId: Long, interruptThread: Boolean) +private case class StopExecutor() + /** * Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on * LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend @@ -63,6 +66,9 @@ private[spark] class LocalActor( case KillTask(taskId, interruptThread) => executor.killTask(taskId, interruptThread) + + case StopExecutor => + executor.stop() } def reviveOffers() { @@ -91,6 +97,7 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: } override def stop() { + localActor ! StopExecutor } override def reviveOffers() { diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index e32ad9c036ad4..7c9dc8e5f88ef 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -20,6 +20,7 @@ package org.apache.spark.shuffle.hash import org.apache.spark.{InterruptibleIterator, TaskContext} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.util.collection.ExternalSorter private[spark] class HashShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], @@ -35,8 +36,8 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, - Serializer.getSerializer(dep.serializer)) + val ser = Serializer.getSerializer(dep.serializer) + val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser) val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { @@ -54,16 +55,13 @@ private[spark] class HashShuffleReader[K, C]( // Sort the output if there is a sort ordering defined. dep.keyOrdering match { case Some(keyOrd: Ordering[K]) => - // Define a Comparator for the whole record based on the key Ordering. - val cmp = new Ordering[Product2[K, C]] { - override def compare(o1: Product2[K, C], o2: Product2[K, C]): Int = { - keyOrd.compare(o1._1, o2._1) - } - } - val sortBuffer: Array[Product2[K, C]] = aggregatedIter.toArray - // TODO: do external sort. - scala.util.Sorting.quickSort(sortBuffer)(cmp) - sortBuffer.iterator + // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled, + // the ExternalSorter won't spill to disk. + val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser)) + sorter.write(aggregatedIter) + context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled + context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled + sorter.iterator case None => aggregatedIter } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 69905a960a2ca..ccf830e118ee7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -200,14 +200,17 @@ object BlockFetcherIterator { // these all at once because they will just memory-map some files, so they won't consume // any memory that might exceed our maxBytesInFlight for (id <- localBlocksToFetch) { - getLocalFromDisk(id, serializer) match { - case Some(iter) => { - // Pass 0 as size since it's not in flight - results.put(new FetchResult(id, 0, () => iter)) - logDebug("Got local block " + id) - } - case None => { - throw new BlockException(id, "Could not get block " + id + " from local machine") + try { + // getLocalFromDisk never return None but throws BlockException + val iter = getLocalFromDisk(id, serializer).get + // Pass 0 as size since it's not in flight + results.put(new FetchResult(id, 0, () => iter)) + logDebug("Got local block " + id) + } catch { + case e: Exception => { + logError(s"Error occurred while fetching local blocks", e) + results.put(new FetchResult(id, -1, null)) + return } } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index d746526639e58..c0a06017945f0 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -116,15 +116,6 @@ private[spark] class BlockManager( private var asyncReregisterTask: Future[Unit] = null private val asyncReregisterLock = new Object - private def heartBeat(): Unit = { - if (!master.sendHeartBeat(blockManagerId)) { - reregister() - } - } - - private val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf) - private var heartBeatTask: Cancellable = null - private val metadataCleaner = new MetadataCleaner( MetadataCleanerType.BLOCK_MANAGER, this.dropOldNonBroadcastBlocks, conf) private val broadcastCleaner = new MetadataCleaner( @@ -161,11 +152,6 @@ private[spark] class BlockManager( private def initialize(): Unit = { master.registerBlockManager(blockManagerId, maxMemory, slaveActor) BlockManagerWorker.startBlockManagerWorker(this) - if (!BlockManager.getDisableHeartBeatsForTesting(conf)) { - heartBeatTask = actorSystem.scheduler.schedule(0.seconds, heartBeatFrequency.milliseconds) { - Utils.tryOrExit { heartBeat() } - } - } } /** @@ -195,7 +181,7 @@ private[spark] class BlockManager( * * Note that this method must be called without any BlockInfo locks held. */ - private def reregister(): Unit = { + def reregister(): Unit = { // TODO: We might need to rate limit re-registering. logInfo("BlockManager re-registering with master") master.registerBlockManager(blockManagerId, maxMemory, slaveActor) @@ -1065,9 +1051,6 @@ private[spark] class BlockManager( } def stop(): Unit = { - if (heartBeatTask != null) { - heartBeatTask.cancel() - } connectionManager.stop() shuffleBlockManager.stop() diskBlockManager.stop() @@ -1095,12 +1078,6 @@ private[spark] object BlockManager extends Logging { (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong } - def getHeartBeatFrequency(conf: SparkConf): Long = - conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000) / 4 - - def getDisableHeartBeatsForTesting(conf: SparkConf): Boolean = - conf.getBoolean("spark.test.disableBlockManagerHeartBeat", false) - /** * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that * might cause errors if one attempts to read from the unmapped buffer, but it's better than diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 7897fade2df2b..669307765d1fa 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -21,7 +21,6 @@ import scala.concurrent.{Await, Future} import scala.concurrent.ExecutionContext.Implicits.global import akka.actor._ -import akka.pattern.ask import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.BlockManagerMessages._ @@ -29,8 +28,8 @@ import org.apache.spark.util.AkkaUtils private[spark] class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Logging { - val AKKA_RETRY_ATTEMPTS: Int = conf.getInt("spark.akka.num.retries", 3) - val AKKA_RETRY_INTERVAL_MS: Int = conf.getInt("spark.akka.retry.wait", 3000) + private val AKKA_RETRY_ATTEMPTS: Int = AkkaUtils.numRetries(conf) + private val AKKA_RETRY_INTERVAL_MS: Int = AkkaUtils.retryWaitMs(conf) val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" @@ -42,15 +41,6 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log logInfo("Removed " + execId + " successfully in removeExecutor") } - /** - * Send the driver actor a heart beat from the slave. Returns true if everything works out, - * false if the driver does not know about the given block manager, which means the block - * manager should re-register. - */ - def sendHeartBeat(blockManagerId: BlockManagerId): Boolean = { - askDriverWithReply[Boolean](HeartBeat(blockManagerId)) - } - /** Register the BlockManager's id with the driver. */ def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { logInfo("Trying to register BlockManager") @@ -223,33 +213,8 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log * throw a SparkException if this fails. */ private def askDriverWithReply[T](message: Any): T = { - // TODO: Consider removing multiple attempts - if (driverActor == null) { - throw new SparkException("Error sending message to BlockManager as driverActor is null " + - "[message = " + message + "]") - } - var attempts = 0 - var lastException: Exception = null - while (attempts < AKKA_RETRY_ATTEMPTS) { - attempts += 1 - try { - val future = driverActor.ask(message)(timeout) - val result = Await.result(future, timeout) - if (result == null) { - throw new SparkException("BlockManagerMaster returned null") - } - return result.asInstanceOf[T] - } catch { - case ie: InterruptedException => throw ie - case e: Exception => - lastException = e - logWarning("Error sending message to BlockManagerMaster in " + attempts + " attempts", e) - } - Thread.sleep(AKKA_RETRY_INTERVAL_MS) - } - - throw new SparkException( - "Error sending message to BlockManagerMaster [message = " + message + "]", lastException) + AkkaUtils.askWithReply(message, driverActor, AKKA_RETRY_ATTEMPTS, AKKA_RETRY_INTERVAL_MS, + timeout) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index c17cf55b0d907..bd31e3c5a187f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -52,25 +52,24 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private val akkaTimeout = AkkaUtils.askTimeout(conf) - val slaveTimeout = conf.get("spark.storage.blockManagerSlaveTimeoutMs", - "" + (BlockManager.getHeartBeatFrequency(conf) * 3)).toLong + val slaveTimeout = conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", + math.max(conf.getInt("spark.executor.heartbeatInterval", 10000) * 3, 45000)) - val checkTimeoutInterval = conf.get("spark.storage.blockManagerTimeoutIntervalMs", - "60000").toLong + val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", + 60000) var timeoutCheckingTask: Cancellable = null override def preStart() { - if (!BlockManager.getDisableHeartBeatsForTesting(conf)) { - import context.dispatcher - timeoutCheckingTask = context.system.scheduler.schedule(0.seconds, - checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) - } + import context.dispatcher + timeoutCheckingTask = context.system.scheduler.schedule(0.seconds, + checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) super.preStart() } def receive = { case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => + logInfo("received a register") register(blockManagerId, maxMemSize, slaveActor) sender ! true @@ -129,8 +128,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case ExpireDeadHosts => expireDeadHosts() - case HeartBeat(blockManagerId) => - sender ! heartBeat(blockManagerId) + case BlockManagerHeartbeat(blockManagerId) => + sender ! heartbeatReceived(blockManagerId) case other => logWarning("Got unknown message: " + other) @@ -216,7 +215,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus val minSeenTime = now - slaveTimeout val toRemove = new mutable.HashSet[BlockManagerId] for (info <- blockManagerInfo.values) { - if (info.lastSeenMs < minSeenTime) { + if (info.lastSeenMs < minSeenTime && info.blockManagerId.executorId != "") { logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats: " + (now - info.lastSeenMs) + "ms exceeds " + slaveTimeout + "ms") toRemove += info.blockManagerId @@ -230,7 +229,11 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) } - private def heartBeat(blockManagerId: BlockManagerId): Boolean = { + /** + * Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. + */ + private def heartbeatReceived(blockManagerId: BlockManagerId): Boolean = { if (!blockManagerInfo.contains(blockManagerId)) { blockManagerId.executorId == "" && !isLocal } else { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 2b53bf33b5fba..10b65286fb7db 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -21,7 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import akka.actor.ActorRef -private[storage] object BlockManagerMessages { +private[spark] object BlockManagerMessages { ////////////////////////////////////////////////////////////////////////////////// // Messages from the master to slaves. ////////////////////////////////////////////////////////////////////////////////// @@ -53,8 +53,6 @@ private[storage] object BlockManagerMessages { sender: ActorRef) extends ToBlockManagerMaster - case class HeartBeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - class UpdateBlockInfo( var blockManagerId: BlockManagerId, var blockId: BlockId, @@ -124,5 +122,7 @@ private[storage] object BlockManagerMessages { case class GetMatchingBlockIds(filter: BlockId => Boolean, askSlaves: Boolean = true) extends ToBlockManagerMaster + case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster + case object ExpireDeadHosts extends ToBlockManagerMaster } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index efb527b4f03e6..da2f5d3172fe2 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -130,32 +130,16 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { new StageUIData }) - // create executor summary map if necessary - val executorSummaryMap = stageData.executorSummary - executorSummaryMap.getOrElseUpdate(key = info.executorId, op = new ExecutorSummary) - - executorSummaryMap.get(info.executorId).foreach { y => - // first update failed-task, succeed-task - taskEnd.reason match { - case Success => - y.succeededTasks += 1 - case _ => - y.failedTasks += 1 - } - - // update duration - y.taskTime += info.duration - - val metrics = taskEnd.taskMetrics - if (metrics != null) { - metrics.inputMetrics.foreach { y.inputBytes += _.bytesRead } - metrics.shuffleReadMetrics.foreach { y.shuffleRead += _.remoteBytesRead } - metrics.shuffleWriteMetrics.foreach { y.shuffleWrite += _.shuffleBytesWritten } - y.memoryBytesSpilled += metrics.memoryBytesSpilled - y.diskBytesSpilled += metrics.diskBytesSpilled - } + val execSummaryMap = stageData.executorSummary + val execSummary = execSummaryMap.getOrElseUpdate(info.executorId, new ExecutorSummary) + + taskEnd.reason match { + case Success => + execSummary.succeededTasks += 1 + case _ => + execSummary.failedTasks += 1 } - + execSummary.taskTime += info.duration stageData.numActiveTasks -= 1 val (errorMessage, metrics): (Option[String], Option[TaskMetrics]) = @@ -171,28 +155,75 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { (Some(e.toErrorString), None) } + if (!metrics.isEmpty) { + val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.taskMetrics) + updateAggregateMetrics(stageData, info.executorId, metrics.get, oldMetrics) + } - val taskRunTime = metrics.map(_.executorRunTime).getOrElse(0L) - stageData.executorRunTime += taskRunTime - val inputBytes = metrics.flatMap(_.inputMetrics).map(_.bytesRead).getOrElse(0L) - stageData.inputBytes += inputBytes - - val shuffleRead = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead).getOrElse(0L) - stageData.shuffleReadBytes += shuffleRead - - val shuffleWrite = - metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleBytesWritten).getOrElse(0L) - stageData.shuffleWriteBytes += shuffleWrite - - val memoryBytesSpilled = metrics.map(_.memoryBytesSpilled).getOrElse(0L) - stageData.memoryBytesSpilled += memoryBytesSpilled + val taskData = stageData.taskData.getOrElseUpdate(info.taskId, new TaskUIData(info)) + taskData.taskInfo = info + taskData.taskMetrics = metrics + taskData.errorMessage = errorMessage + } + } - val diskBytesSpilled = metrics.map(_.diskBytesSpilled).getOrElse(0L) - stageData.diskBytesSpilled += diskBytesSpilled + /** + * Upon receiving new metrics for a task, updates the per-stage and per-executor-per-stage + * aggregate metrics by calculating deltas between the currently recorded metrics and the new + * metrics. + */ + def updateAggregateMetrics( + stageData: StageUIData, + execId: String, + taskMetrics: TaskMetrics, + oldMetrics: Option[TaskMetrics]) { + val execSummary = stageData.executorSummary.getOrElseUpdate(execId, new ExecutorSummary) + + val shuffleWriteDelta = + (taskMetrics.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L) + - oldMetrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleBytesWritten).getOrElse(0L)) + stageData.shuffleWriteBytes += shuffleWriteDelta + execSummary.shuffleWrite += shuffleWriteDelta + + val shuffleReadDelta = + (taskMetrics.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L) + - oldMetrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead).getOrElse(0L)) + stageData.shuffleReadBytes += shuffleReadDelta + execSummary.shuffleRead += shuffleReadDelta + + val diskSpillDelta = + taskMetrics.diskBytesSpilled - oldMetrics.map(_.diskBytesSpilled).getOrElse(0L) + stageData.diskBytesSpilled += diskSpillDelta + execSummary.diskBytesSpilled += diskSpillDelta + + val memorySpillDelta = + taskMetrics.memoryBytesSpilled - oldMetrics.map(_.memoryBytesSpilled).getOrElse(0L) + stageData.memoryBytesSpilled += memorySpillDelta + execSummary.memoryBytesSpilled += memorySpillDelta + + val timeDelta = + taskMetrics.executorRunTime - oldMetrics.map(_.executorRunTime).getOrElse(0L) + stageData.executorRunTime += timeDelta + } - stageData.taskData(info.taskId) = new TaskUIData(info, metrics, errorMessage) + override def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { + for ((taskId, sid, taskMetrics) <- executorMetricsUpdate.taskMetrics) { + val stageData = stageIdToData.getOrElseUpdate(sid, { + logWarning("Metrics update for task in unknown stage " + sid) + new StageUIData + }) + val taskData = stageData.taskData.get(taskId) + taskData.map { t => + if (!t.taskInfo.finished) { + updateAggregateMetrics(stageData, executorMetricsUpdate.execId, taskMetrics, + t.taskMetrics) + + // Overwrite task metrics + t.taskMetrics = Some(taskMetrics) + } + } } - } // end of onTaskEnd + } override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { synchronized { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index be11a11695b01..2f96f7909c199 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -55,8 +55,11 @@ private[jobs] object UIData { var executorSummary = new HashMap[String, ExecutorSummary] } + /** + * These are kept mutable and reused throughout a task's lifetime to avoid excessive reallocation. + */ case class TaskUIData( - taskInfo: TaskInfo, - taskMetrics: Option[TaskMetrics] = None, - errorMessage: Option[String] = None) + var taskInfo: TaskInfo, + var taskMetrics: Option[TaskMetrics] = None, + var errorMessage: Option[String] = None) } diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 9930c717492f2..feafd654e9e71 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -18,13 +18,16 @@ package org.apache.spark.util import scala.collection.JavaConversions.mapAsJavaMap +import scala.concurrent.Await import scala.concurrent.duration.{Duration, FiniteDuration} -import akka.actor.{ActorSystem, ExtendedActorSystem} +import akka.actor.{Actor, ActorRef, ActorSystem, ExtendedActorSystem} +import akka.pattern.ask + import com.typesafe.config.ConfigFactory import org.apache.log4j.{Level, Logger} -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf} /** * Various utility classes for working with Akka. @@ -124,4 +127,63 @@ private[spark] object AkkaUtils extends Logging { /** Space reserved for extra data in an Akka message besides serialized task or task result. */ val reservedSizeBytes = 200 * 1024 + + /** Returns the configured number of times to retry connecting */ + def numRetries(conf: SparkConf): Int = { + conf.getInt("spark.akka.num.retries", 3) + } + + /** Returns the configured number of milliseconds to wait on each retry */ + def retryWaitMs(conf: SparkConf): Int = { + conf.getInt("spark.akka.retry.wait", 3000) + } + + /** + * Send a message to the given actor and get its result within a default timeout, or + * throw a SparkException if this fails. + */ + def askWithReply[T]( + message: Any, + actor: ActorRef, + retryAttempts: Int, + retryInterval: Int, + timeout: FiniteDuration): T = { + // TODO: Consider removing multiple attempts + if (actor == null) { + throw new SparkException("Error sending message as driverActor is null " + + "[message = " + message + "]") + } + var attempts = 0 + var lastException: Exception = null + while (attempts < retryAttempts) { + attempts += 1 + try { + val future = actor.ask(message)(timeout) + val result = Await.result(future, timeout) + if (result == null) { + throw new SparkException("Actor returned null") + } + return result.asInstanceOf[T] + } catch { + case ie: InterruptedException => throw ie + case e: Exception => + lastException = e + logWarning("Error sending message in " + attempts + " attempts", e) + } + Thread.sleep(retryInterval) + } + + throw new SparkException( + "Error sending message [message = " + message + "]", lastException) + } + + def makeDriverRef(name: String, conf: SparkConf, actorSystem: ActorSystem): ActorRef = { + val driverHost: String = conf.get("spark.driver.host", "localhost") + val driverPort: Int = conf.getInt("spark.driver.port", 7077) + Utils.checkHost(driverHost, "Expected hostname") + val url = s"akka.tcp://spark@$driverHost:$driverPort/user/$name" + val timeout = AkkaUtils.lookupTimeout(conf) + logInfo(s"Connecting to $name: $url") + Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) + } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 67e3be21c3c93..495a0d48633a4 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark -import org.scalatest.{FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfterEach, FunSuite, PrivateMethodTester} import org.apache.spark.scheduler.{TaskScheduler, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend} @@ -25,12 +25,12 @@ import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, Me import org.apache.spark.scheduler.local.LocalBackend class SparkContextSchedulerCreationSuite - extends FunSuite with PrivateMethodTester with LocalSparkContext with Logging { + extends FunSuite with PrivateMethodTester with Logging with BeforeAndAfterEach { def createTaskScheduler(master: String): TaskSchedulerImpl = { // Create local SparkContext to setup a SparkEnv. We don't actually want to start() the // real schedulers, so we don't want to create a full SparkContext with the desired scheduler. - sc = new SparkContext("local", "test") + val sc = new SparkContext("local", "test") val createTaskSchedulerMethod = PrivateMethod[TaskScheduler]('createTaskScheduler) val sched = SparkContext invokePrivate createTaskSchedulerMethod(sc, master) sched.asInstanceOf[TaskSchedulerImpl] @@ -68,6 +68,15 @@ class SparkContextSchedulerCreationSuite } } + test("local-*-n-failures") { + val sched = createTaskScheduler("local[* ,2]") + assert(sched.maxTaskFailures === 2) + sched.backend match { + case s: LocalBackend => assert(s.totalCores === Runtime.getRuntime.availableProcessors()) + case _ => fail() + } + } + test("local-n-failures") { val sched = createTaskScheduler("local[4, 2]") assert(sched.maxTaskFailures === 2) @@ -77,6 +86,20 @@ class SparkContextSchedulerCreationSuite } } + test("bad-local-n") { + val e = intercept[SparkException] { + createTaskScheduler("local[2*]") + } + assert(e.getMessage.contains("Could not parse Master URL")) + } + + test("bad-local-n-failures") { + val e = intercept[SparkException] { + createTaskScheduler("local[2*,4]") + } + assert(e.getMessage.contains("Could not parse Master URL")) + } + test("local-default-parallelism") { val defaultParallelism = System.getProperty("spark.default.parallelism") System.setProperty("spark.default.parallelism", "16") diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 9021662bcf712..36e238b4c9434 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -23,12 +23,15 @@ import scala.language.reflectiveCalls import akka.actor._ import akka.testkit.{ImplicitSender, TestKit, TestActorRef} import org.scalatest.{BeforeAndAfter, FunSuiteLike} +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.CallSite +import org.apache.spark.executor.TaskMetrics class BuggyDAGEventProcessActor extends Actor { val state = 0 @@ -63,7 +66,7 @@ class MyRDD( class DAGSchedulerSuiteDummyException extends Exception class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with FunSuiteLike - with ImplicitSender with BeforeAndAfter with LocalSparkContext { + with ImplicitSender with BeforeAndAfter with LocalSparkContext with Timeouts { val conf = new SparkConf /** Set of TaskSets the DAGScheduler has requested executed. */ @@ -77,6 +80,8 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F override def schedulingMode: SchedulingMode = SchedulingMode.NONE override def start() = {} override def stop() = {} + override def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], + blockManagerId: BlockManagerId): Boolean = true override def submitTasks(taskSet: TaskSet) = { // normally done by TaskSetManager taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) @@ -291,6 +296,18 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F assertDataStructuresEmpty } + test("avoid exponential blowup when getting preferred locs list") { + // Build up a complex dependency graph with repeated zip operations, without preferred locations. + var rdd: RDD[_] = new MyRDD(sc, 1, Nil) + (1 to 30).foreach(_ => rdd = rdd.zip(rdd)) + // getPreferredLocs runs quickly, indicating that exponential graph traversal is avoided. + failAfter(10 seconds) { + val preferredLocs = scheduler.getPreferredLocs(rdd,0) + // No preferred locations are returned. + assert(preferredLocs.length === 0) + } + } + test("unserializable task") { val unserializableRdd = new MyRDD(sc, 1, Nil) { class UnserializableClass @@ -342,6 +359,8 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 + override def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], + blockManagerId: BlockManagerId): Boolean = true } val noKillScheduler = new DAGScheduler( sc, diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala new file mode 100644 index 0000000000000..8dca2ebb312f5 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.scalatest.{FunSuite, Matchers} +import org.scalatest.PrivateMethodTester._ + +import org.mockito.Mockito._ +import org.mockito.Matchers.{any, eq => meq} +import org.mockito.stubbing.Answer +import org.mockito.invocation.InvocationOnMock + +import org.apache.spark._ +import org.apache.spark.storage.BlockFetcherIterator._ +import org.apache.spark.network.{ConnectionManager, ConnectionManagerId, + Message} + +class BlockFetcherIteratorSuite extends FunSuite with Matchers { + + test("block fetch from local fails using BasicBlockFetcherIterator") { + val blockManager = mock(classOf[BlockManager]) + val connManager = mock(classOf[ConnectionManager]) + doReturn(connManager).when(blockManager).connectionManager + doReturn(BlockManagerId("test-client", "test-client", 1, 0)).when(blockManager).blockManagerId + + doReturn((48 * 1024 * 1024).asInstanceOf[Long]).when(blockManager).maxBytesInFlight + + val blIds = Array[BlockId]( + ShuffleBlockId(0,0,0), + ShuffleBlockId(0,1,0), + ShuffleBlockId(0,2,0), + ShuffleBlockId(0,3,0), + ShuffleBlockId(0,4,0)) + + val optItr = mock(classOf[Option[Iterator[Any]]]) + val answer = new Answer[Option[Iterator[Any]]] { + override def answer(invocation: InvocationOnMock) = Option[Iterator[Any]] { + throw new Exception + } + } + + // 3rd block is going to fail + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(0)), any()) + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(1)), any()) + doAnswer(answer).when(blockManager).getLocalFromDisk(meq(blIds(2)), any()) + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(3)), any()) + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(4)), any()) + + val bmId = BlockManagerId("test-client", "test-client",1 , 0) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) + ) + + val iterator = new BasicBlockFetcherIterator(blockManager, + blocksByAddress, null) + + iterator.initialize() + + // 3rd getLocalFromDisk invocation should be failed + verify(blockManager, times(3)).getLocalFromDisk(any(), any()) + + assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements") + // the 2nd element of the tuple returned by iterator.next should be defined when fetching successfully + assert(iterator.next._2.isDefined, "1st element should be defined but is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element") + assert(iterator.next._2.isDefined, "2nd element should be defined but is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements") + // 3rd fetch should be failed + assert(!iterator.next._2.isDefined, "3rd element should not be defined but is actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements") + // Don't call next() after fetching non-defined element even if thare are rest of elements in the iterator. + // Otherwise, BasicBlockFetcherIterator hangs up. + } + + + test("block fetch from local succeed using BasicBlockFetcherIterator") { + val blockManager = mock(classOf[BlockManager]) + val connManager = mock(classOf[ConnectionManager]) + doReturn(connManager).when(blockManager).connectionManager + doReturn(BlockManagerId("test-client", "test-client", 1, 0)).when(blockManager).blockManagerId + + doReturn((48 * 1024 * 1024).asInstanceOf[Long]).when(blockManager).maxBytesInFlight + + val blIds = Array[BlockId]( + ShuffleBlockId(0,0,0), + ShuffleBlockId(0,1,0), + ShuffleBlockId(0,2,0), + ShuffleBlockId(0,3,0), + ShuffleBlockId(0,4,0)) + + val optItr = mock(classOf[Option[Iterator[Any]]]) + + // All blocks should be fetched successfully + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(0)), any()) + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(1)), any()) + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(2)), any()) + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(3)), any()) + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(4)), any()) + + val bmId = BlockManagerId("test-client", "test-client",1 , 0) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) + ) + + val iterator = new BasicBlockFetcherIterator(blockManager, + blocksByAddress, null) + + iterator.initialize() + + // getLocalFromDis should be invoked for all of 5 blocks + verify(blockManager, times(5)).getLocalFromDisk(any(), any()) + + assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements") + assert(iterator.next._2.isDefined, "All elements should be defined but 1st element is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element") + assert(iterator.next._2.isDefined, "All elements should be defined but 2nd element is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements") + assert(iterator.next._2.isDefined, "All elements should be defined but 3rd element is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements") + assert(iterator.next._2.isDefined, "All elements should be defined but 4th element is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 4 elements") + assert(iterator.next._2.isDefined, "All elements should be defined but 5th element is not actually defined") + } + +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index dd4fd535d3577..0ac0269d7cfc1 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -19,25 +19,28 @@ package org.apache.spark.storage import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays +import java.util.concurrent.TimeUnit import akka.actor._ -import org.apache.spark.SparkConf -import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} -import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} +import akka.pattern.ask +import akka.util.Timeout + import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ import org.scalatest.Matchers -import org.scalatest.time.SpanSugar._ import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} import org.apache.spark.executor.DataReadMethod import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} import scala.collection.mutable.ArrayBuffer +import scala.concurrent.Await +import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps @@ -76,7 +79,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter oldArch = System.setProperty("os.arch", "amd64") conf.set("os.arch", "amd64") conf.set("spark.test.useCompressedOops", "true") - conf.set("spark.storage.disableBlockManagerHeartBeat", "true") conf.set("spark.driver.port", boundPort.toString) conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") @@ -344,7 +346,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("reregistration on heart beat") { - val heartBeat = PrivateMethod[Unit]('heartBeat) store = makeBlockManager(2000) val a1 = new Array[Byte](400) @@ -356,13 +357,15 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - store invokePrivate heartBeat() - assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master") + implicit val timeout = Timeout(30, TimeUnit.SECONDS) + val reregister = !Await.result( + master.driverActor ? BlockManagerHeartbeat(store.blockManagerId), + timeout.duration).asInstanceOf[Boolean] + assert(reregister == true) } test("reregistration on block update") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + store = makeBlockManager(2000) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) @@ -380,7 +383,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("reregistration doesn't dead lock") { - val heartBeat = PrivateMethod[Unit]('heartBeat) store = makeBlockManager(2000) val a1 = new Array[Byte](400) val a2 = List(new Array[Byte](400)) @@ -400,7 +402,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } val t3 = new Thread { override def run() { - store invokePrivate heartBeat() + store.reregister() } } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 86a271eb67000..cb8252515238e 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -21,7 +21,8 @@ import org.scalatest.FunSuite import org.scalatest.Matchers import org.apache.spark._ -import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics} +import org.apache.spark.{LocalSparkContext, SparkConf, Success} +import org.apache.spark.executor.{ShuffleWriteMetrics, ShuffleReadMetrics, TaskMetrics} import org.apache.spark.scheduler._ import org.apache.spark.util.Utils @@ -129,4 +130,87 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc assert(listener.stageIdToData(task.stageId).numCompleteTasks === 1) assert(listener.stageIdToData(task.stageId).numFailedTasks === failCount) } + + test("test update metrics") { + val conf = new SparkConf() + val listener = new JobProgressListener(conf) + + val taskType = Utils.getFormattedClassName(new ShuffleMapTask(0)) + val execId = "exe-1" + + def makeTaskMetrics(base: Int) = { + val taskMetrics = new TaskMetrics() + val shuffleReadMetrics = new ShuffleReadMetrics() + val shuffleWriteMetrics = new ShuffleWriteMetrics() + taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics) + taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics) + shuffleReadMetrics.remoteBytesRead = base + 1 + shuffleReadMetrics.remoteBlocksFetched = base + 2 + shuffleWriteMetrics.shuffleBytesWritten = base + 3 + taskMetrics.executorRunTime = base + 4 + taskMetrics.diskBytesSpilled = base + 5 + taskMetrics.memoryBytesSpilled = base + 6 + taskMetrics + } + + def makeTaskInfo(taskId: Long, finishTime: Int = 0) = { + val taskInfo = new TaskInfo(taskId, 0, 1, 0L, execId, "host1", TaskLocality.NODE_LOCAL, + false) + taskInfo.finishTime = finishTime + taskInfo + } + + listener.onTaskStart(SparkListenerTaskStart(0, makeTaskInfo(1234L))) + listener.onTaskStart(SparkListenerTaskStart(0, makeTaskInfo(1235L))) + listener.onTaskStart(SparkListenerTaskStart(1, makeTaskInfo(1236L))) + listener.onTaskStart(SparkListenerTaskStart(1, makeTaskInfo(1237L))) + + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate(execId, Array( + (1234L, 0, makeTaskMetrics(0)), + (1235L, 0, makeTaskMetrics(100)), + (1236L, 1, makeTaskMetrics(200))))) + + var stage0Data = listener.stageIdToData.get(0).get + var stage1Data = listener.stageIdToData.get(1).get + assert(stage0Data.shuffleReadBytes == 102) + assert(stage1Data.shuffleReadBytes == 201) + assert(stage0Data.shuffleWriteBytes == 106) + assert(stage1Data.shuffleWriteBytes == 203) + assert(stage0Data.executorRunTime == 108) + assert(stage1Data.executorRunTime == 204) + assert(stage0Data.diskBytesSpilled == 110) + assert(stage1Data.diskBytesSpilled == 205) + assert(stage0Data.memoryBytesSpilled == 112) + assert(stage1Data.memoryBytesSpilled == 206) + assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get + .totalBlocksFetched == 2) + assert(stage0Data.taskData.get(1235L).get.taskMetrics.get.shuffleReadMetrics.get + .totalBlocksFetched == 102) + assert(stage1Data.taskData.get(1236L).get.taskMetrics.get.shuffleReadMetrics.get + .totalBlocksFetched == 202) + + // task that was included in a heartbeat + listener.onTaskEnd(SparkListenerTaskEnd(0, taskType, Success, makeTaskInfo(1234L, 1), + makeTaskMetrics(300))) + // task that wasn't included in a heartbeat + listener.onTaskEnd(SparkListenerTaskEnd(1, taskType, Success, makeTaskInfo(1237L, 1), + makeTaskMetrics(400))) + + stage0Data = listener.stageIdToData.get(0).get + stage1Data = listener.stageIdToData.get(1).get + assert(stage0Data.shuffleReadBytes == 402) + assert(stage1Data.shuffleReadBytes == 602) + assert(stage0Data.shuffleWriteBytes == 406) + assert(stage1Data.shuffleWriteBytes == 606) + assert(stage0Data.executorRunTime == 408) + assert(stage1Data.executorRunTime == 608) + assert(stage0Data.diskBytesSpilled == 410) + assert(stage1Data.diskBytesSpilled == 610) + assert(stage0Data.memoryBytesSpilled == 412) + assert(stage1Data.memoryBytesSpilled == 612) + assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get + .totalBlocksFetched == 302) + assert(stage1Data.taskData.get(1237L).get.taskMetrics.get.shuffleReadMetrics.get + .totalBlocksFetched == 402) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index ddb5df40360e9..65a71e5a83698 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -190,6 +190,11 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}") } } + + // sortByKey - should spill ~17 times + val rddE = sc.parallelize(0 until 100000).map(i => (i/4, i)) + val resultE = rddE.sortByKey().collect().toSeq + assert(resultE === (0 until 100000).map(i => (i/4, i)).toSeq) } test("spilling in local cluster with many reduce tasks") { @@ -256,6 +261,11 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}") } } + + // sortByKey - should spill ~8 times per executor + val rddE = sc.parallelize(0 until 100000).map(i => (i/4, i)) + val resultE = rddE.sortByKey().collect().toSeq + assert(resultE === (0 until 100000).map(i => (i/4, i)).toSeq) } test("cleanup of intermediate files in sorter") { diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index c44320239bbbf..53df9b5a3f1d5 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -29,7 +29,6 @@ import re import subprocess import sys -import tempfile import urllib2 try: @@ -39,15 +38,15 @@ JIRA_IMPORTED = False # Location of your Spark git development area -SPARK_HOME = os.environ.get("SPARK_HOME", "/home/patrick/Documents/spark") +SPARK_HOME = os.environ.get("SPARK_HOME", os.getcwd()) # Remote name which points to the Gihub site PR_REMOTE_NAME = os.environ.get("PR_REMOTE_NAME", "apache-github") # Remote name which points to Apache git PUSH_REMOTE_NAME = os.environ.get("PUSH_REMOTE_NAME", "apache") # ASF JIRA username -JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "pwendell") +JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "") # ASF JIRA password -JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "1234") +JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "") GITHUB_BASE = "https://github.com/apache/spark/pull" GITHUB_API_BASE = "https://api.github.com/repos/apache/spark" @@ -129,7 +128,7 @@ def merge_pr(pr_num, target_ref): merge_message_flags = [] merge_message_flags += ["-m", title] - if body != None: + if body is not None: # We remove @ symbols from the body to avoid triggering e-mails # to people every time someone creates a public fork of Spark. merge_message_flags += ["-m", body.replace("@", "")] @@ -179,7 +178,14 @@ def cherry_pick(pr_num, merge_hash, default_branch): run_cmd("git fetch %s %s:%s" % (PUSH_REMOTE_NAME, pick_ref, pick_branch_name)) run_cmd("git checkout %s" % pick_branch_name) - run_cmd("git cherry-pick -sx %s" % merge_hash) + + try: + run_cmd("git cherry-pick -sx %s" % merge_hash) + except Exception as e: + msg = "Error cherry-picking: %s\nWould you like to manually fix-up this merge?" % e + continue_maybe(msg) + msg = "Okay, please fix any conflicts and finish the cherry-pick. Finished?" + continue_maybe(msg) continue_maybe("Pick complete (local ref %s). Push to %s?" % ( pick_branch_name, PUSH_REMOTE_NAME)) @@ -280,6 +286,7 @@ def get_version_json(version_str): pr_num = raw_input("Which pull request would you like to merge? (e.g. 34): ") pr = get_json("%s/pulls/%s" % (GITHUB_API_BASE, pr_num)) +pr_events = get_json("%s/issues/%s/events" % (GITHUB_API_BASE, pr_num)) url = pr["url"] title = pr["title"] @@ -289,19 +296,23 @@ def get_version_json(version_str): base_ref = pr["head"]["ref"] pr_repo_desc = "%s/%s" % (user_login, base_ref) -if pr["merged"] is True: +# Merged pull requests don't appear as merged in the GitHub API; +# Instead, they're closed by asfgit. +merge_commits = \ + [e for e in pr_events if e["actor"]["login"] == "asfgit" and e["event"] == "closed"] + +if merge_commits: + merge_hash = merge_commits[0]["commit_id"] + message = get_json("%s/commits/%s" % (GITHUB_API_BASE, merge_hash))["commit"]["message"] + print "Pull request %s has already been merged, assuming you want to backport" % pr_num - merge_commit_desc = run_cmd([ - 'git', 'log', '--merges', '--first-parent', - '--grep=pull request #%s' % pr_num, '--oneline']).split("\n")[0] - if merge_commit_desc == "": + commit_is_downloaded = run_cmd(['git', 'rev-parse', '--quiet', '--verify', + "%s^{commit}" % merge_hash]).strip() != "" + if not commit_is_downloaded: fail("Couldn't find any merge commit for #%s, you may need to update HEAD." % pr_num) - merge_hash = merge_commit_desc[:7] - message = merge_commit_desc[8:] - - print "Found: %s" % message - maybe_cherry_pick(pr_num, merge_hash, latest_branch) + print "Found commit %s:\n%s" % (merge_hash, message) + cherry_pick(pr_num, merge_hash, latest_branch) sys.exit(0) if not bool(pr["mergeable"]): @@ -323,9 +334,13 @@ def get_version_json(version_str): merged_refs = merged_refs + [cherry_pick(pr_num, merge_hash, latest_branch)] if JIRA_IMPORTED: - continue_maybe("Would you like to update an associated JIRA?") - jira_comment = "Issue resolved by pull request %s\n[%s/%s]" % (pr_num, GITHUB_BASE, pr_num) - resolve_jira(title, merged_refs, jira_comment) + if JIRA_USERNAME and JIRA_PASSWORD: + continue_maybe("Would you like to update an associated JIRA?") + jira_comment = "Issue resolved by pull request %s\n[%s/%s]" % (pr_num, GITHUB_BASE, pr_num) + resolve_jira(title, merged_refs, jira_comment) + else: + print "JIRA_USERNAME and JIRA_PASSWORD not set" + print "Exiting without trying to close the associated JIRA." else: print "Could not find jira-python library. Run 'sudo pip install jira-python' to install." print "Exiting without trying to close the associated JIRA." diff --git a/docs/configuration.md b/docs/configuration.md index ea69057b5be10..2a71d7b820e5f 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -541,6 +541,13 @@ Apart from these, the following properties are also available, and may be useful output directories. We recommend that users do not disable this except if trying to achieve compatibility with previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand. + + spark.executor.heartbeatInterval + 10000 + Interval (milliseconds) between each executor's heartbeats to the driver. Heartbeats let + the driver know that the executor is still alive and update it with metrics for in-progress + tasks. + #### Networking diff --git a/docs/monitoring.md b/docs/monitoring.md index 84073fe4d949a..d07ec4a57a2cc 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -33,7 +33,7 @@ application's UI after the application has finished. If Spark is run on Mesos or YARN, it is still possible to reconstruct the UI of a finished application through Spark's history server, provided that the application's event logs exist. -You can start a the history server by executing: +You can start the history server by executing: ./sbin/start-history-server.sh @@ -106,7 +106,7 @@ follows: Indicates whether the history server should use kerberos to login. This is useful if the history server is accessing HDFS files on a secure Hadoop cluster. If this is - true it looks uses the configs spark.history.kerberos.principal and + true, it uses the configs spark.history.kerberos.principal and spark.history.kerberos.keytab. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index a047d32b6ee6c..7261badd411a9 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -769,3 +769,13 @@ To start the Spark SQL CLI, run the following in the Spark directory: Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. You may run `./bin/spark-sql --help` for a complete list of all available options. + +# Cached tables + +Spark SQL can cache tables using an in-memory columnar format by calling `cacheTable("tableName")`. +Then Spark SQL will scan only required columns and will automatically tune compression to minimize +memory usage and GC pressure. You can call `uncacheTable("tableName")` to remove the table from memory. + +Note that if you just call `cache` rather than `cacheTable`, tables will _not_ be cached in +in-memory columnar format. So we strongly recommend using `cacheTable` whenever you want to +cache tables. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 90a0eef60c200..7b8b7933434c4 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -939,7 +939,7 @@ Receiving multiple data streams can therefore be achieved by creating multiple i and configuring them to receive different partitions of the data stream from the source(s). For example, a single Kafka input stream receiving two topics of data can be split into two Kafka input streams, each receiving only one topic. This would run two receivers on two workers, -thus allowing data to received in parallel, and increasing overall throughput. +thus allowing data to be received in parallel, and increasing overall throughput. Another parameter that should be considered is the receiver's blocking interval. For most receivers, the received data is coalesced together into large blocks of data before storing inside Spark's memory. @@ -980,7 +980,7 @@ If the number of tasks launched per second is high (say, 50 or more per second), of sending out tasks to the slaves maybe significant and will make it hard to achieve sub-second latencies. The overhead can be reduced by the following changes: -* **Task Serialization**: Using Kryo serialization for serializing tasks can reduced the task +* **Task Serialization**: Using Kryo serialization for serializing tasks can reduce the task sizes, and therefore reduce the time taken to send them to the slaves. * **Execution mode**: Running Spark in Standalone mode or coarse-grained Mesos mode leads to diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 6db9bf3cf5be6..cf3d2cca81ff6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -21,7 +21,6 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree, impurity} import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} @@ -36,6 +35,9 @@ import org.apache.spark.rdd.RDD * ./bin/run-example org.apache.spark.examples.mllib.DecisionTreeRunner [options] * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + * + * Note: This script treats all features as real-valued (not categorical). + * To include categorical features, modify categoricalFeaturesInfo. */ object DecisionTreeRunner { @@ -48,11 +50,12 @@ object DecisionTreeRunner { case class Params( input: String = null, + dataFormat: String = "libsvm", algo: Algo = Classification, - numClassesForClassification: Int = 2, - maxDepth: Int = 5, + maxDepth: Int = 4, impurity: ImpurityType = Gini, - maxBins: Int = 100) + maxBins: Int = 100, + fracTest: Double = 0.2) def main(args: Array[String]) { val defaultParams = Params() @@ -69,25 +72,31 @@ object DecisionTreeRunner { opt[Int]("maxDepth") .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") .action((x, c) => c.copy(maxDepth = x)) - opt[Int]("numClassesForClassification") - .text(s"number of classes for classification, " - + s"default: ${defaultParams.numClassesForClassification}") - .action((x, c) => c.copy(numClassesForClassification = x)) opt[Int]("maxBins") .text(s"max number of bins, default: ${defaultParams.maxBins}") .action((x, c) => c.copy(maxBins = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) arg[String]("") .text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)") .required() .action((x, c) => c.copy(input = x)) checkConfig { params => - if (params.algo == Classification && - (params.impurity == Gini || params.impurity == Entropy)) { - success - } else if (params.algo == Regression && params.impurity == Variance) { - success + if (params.fracTest < 0 || params.fracTest > 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") } else { - failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.") + if (params.algo == Classification && + (params.impurity == Gini || params.impurity == Entropy)) { + success + } else if (params.algo == Regression && params.impurity == Variance) { + success + } else { + failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.") + } } } } @@ -100,16 +109,57 @@ object DecisionTreeRunner { } def run(params: Params) { + val conf = new SparkConf().setAppName("DecisionTreeRunner") val sc = new SparkContext(conf) // Load training data and cache it. - val examples = MLUtils.loadLabeledPoints(sc, params.input).cache() + val origExamples = params.dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache() + case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache() + } + // For classification, re-index classes if needed. + val (examples, numClasses) = params.algo match { + case Classification => { + // classCounts: class --> # examples in class + val classCounts = origExamples.map(_.label).countByValue() + val sortedClasses = classCounts.keys.toList.sorted + val numClasses = classCounts.size + // classIndexMap: class --> index in 0,...,numClasses-1 + val classIndexMap = { + if (classCounts.keySet != Set(0.0, 1.0)) { + sortedClasses.zipWithIndex.toMap + } else { + Map[Double, Int]() + } + } + val examples = { + if (classIndexMap.isEmpty) { + origExamples + } else { + origExamples.map(lp => LabeledPoint(classIndexMap(lp.label), lp.features)) + } + } + val numExamples = examples.count() + println(s"numClasses = $numClasses.") + println(s"Per-class example fractions, counts:") + println(s"Class\tFrac\tCount") + sortedClasses.foreach { c => + val frac = classCounts(c) / numExamples.toDouble + println(s"$c\t$frac\t${classCounts(c)}") + } + (examples, numClasses) + } + case Regression => + (origExamples, 0) + case _ => + throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } - val splits = examples.randomSplit(Array(0.8, 0.2)) + // Split into training, test. + val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest)) val training = splits(0).cache() val test = splits(1).cache() - val numTraining = training.count() val numTest = test.count() @@ -129,17 +179,19 @@ object DecisionTreeRunner { impurity = impurityCalculator, maxDepth = params.maxDepth, maxBins = params.maxBins, - numClassesForClassification = params.numClassesForClassification) + numClassesForClassification = numClasses) val model = DecisionTree.train(training, strategy) + println(model) + if (params.algo == Classification) { val accuracy = accuracyScore(model, test) - println(s"Test accuracy = $accuracy.") + println(s"Test accuracy = $accuracy") } if (params.algo == Regression) { val mse = meanSquaredError(model, test) - println(s"Test mean squared error = $mse.") + println(s"Test mean squared error = $mse") } sc.stop() diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index 38095e88dcea9..e20e2c8f26991 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.kafka import scala.collection.Map -import scala.reflect.ClassTag +import scala.reflect.{classTag, ClassTag} import java.util.Properties import java.util.concurrent.Executors @@ -48,8 +48,8 @@ private[streaming] class KafkaInputDStream[ K: ClassTag, V: ClassTag, - U <: Decoder[_]: Manifest, - T <: Decoder[_]: Manifest]( + U <: Decoder[_]: ClassTag, + T <: Decoder[_]: ClassTag]( @transient ssc_ : StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], @@ -66,8 +66,8 @@ private[streaming] class KafkaReceiver[ K: ClassTag, V: ClassTag, - U <: Decoder[_]: Manifest, - T <: Decoder[_]: Manifest]( + U <: Decoder[_]: ClassTag, + T <: Decoder[_]: ClassTag]( kafkaParams: Map[String, String], topics: Map[String, Int], storageLevel: StorageLevel @@ -103,10 +103,10 @@ class KafkaReceiver[ tryZookeeperConsumerGroupCleanup(zkConnect, kafkaParams("group.id")) } - val keyDecoder = manifest[U].runtimeClass.getConstructor(classOf[VerifiableProperties]) + val keyDecoder = classTag[U].runtimeClass.getConstructor(classOf[VerifiableProperties]) .newInstance(consumerConfig.props) .asInstanceOf[Decoder[K]] - val valueDecoder = manifest[T].runtimeClass.getConstructor(classOf[VerifiableProperties]) + val valueDecoder = classTag[T].runtimeClass.getConstructor(classOf[VerifiableProperties]) .newInstance(consumerConfig.props) .asInstanceOf[Decoder[V]] diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 86bb91f362d29..48668f763e41e 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -65,7 +65,7 @@ object KafkaUtils { * in its own thread. * @param storageLevel Storage level to use for storing the received objects */ - def createStream[K: ClassTag, V: ClassTag, U <: Decoder[_]: Manifest, T <: Decoder[_]: Manifest]( + def createStream[K: ClassTag, V: ClassTag, U <: Decoder[_]: ClassTag, T <: Decoder[_]: ClassTag]( ssc: StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], @@ -89,8 +89,6 @@ object KafkaUtils { groupId: String, topics: JMap[String, JInt] ): JavaPairReceiverInputDStream[String, String] = { - implicit val cmt: ClassTag[String] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[String]] createStream(jssc.ssc, zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*)) } @@ -111,8 +109,6 @@ object KafkaUtils { topics: JMap[String, JInt], storageLevel: StorageLevel ): JavaPairReceiverInputDStream[String, String] = { - implicit val cmt: ClassTag[String] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[String]] createStream(jssc.ssc, zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) } @@ -140,13 +136,11 @@ object KafkaUtils { topics: JMap[String, JInt], storageLevel: StorageLevel ): JavaPairReceiverInputDStream[K, V] = { - implicit val keyCmt: ClassTag[K] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] - implicit val valueCmt: ClassTag[V] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] + implicit val keyCmt: ClassTag[K] = ClassTag(keyTypeClass) + implicit val valueCmt: ClassTag[V] = ClassTag(valueTypeClass) - implicit val keyCmd: Manifest[U] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[U]] - implicit val valueCmd: Manifest[T] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[T]] + implicit val keyCmd: ClassTag[U] = ClassTag(keyDecoderClass) + implicit val valueCmd: ClassTag[T] = ClassTag(valueDecoderClass) createStream[K, V, U, T]( jssc.ssc, kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) diff --git a/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala b/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala index d03d7774e8c80..3b1880e143513 100644 --- a/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala +++ b/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala @@ -82,5 +82,9 @@ class GangliaSink(val property: Properties, val registry: MetricRegistry, override def stop() { reporter.stop() } + + override def report() { + reporter.report() + } } diff --git a/make-distribution.sh b/make-distribution.sh index 0a3283ecec6f8..1441497b3995a 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -128,7 +128,7 @@ if [[ ! "$JAVA_VERSION" =~ "1.6" && -z "$SKIP_JAVA_TEST" ]]; then if [[ ! $REPLY =~ ^[Yy]$ ]]; then echo "Okay, exiting." exit 1 - fi + fi fi if [ "$NAME" == "none" ]; then @@ -173,7 +173,7 @@ cp $FWDIR/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" # Copy example sources (needed for python and SQL) mkdir -p "$DISTDIR/examples/src/main" -cp -r $FWDIR/examples/src/main "$DISTDIR/examples/src/" +cp -r $FWDIR/examples/src/main "$DISTDIR/examples/src/" if [ "$SPARK_HIVE" == "true" ]; then cp $FWDIR/lib_managed/jars/datanucleus*.jar "$DISTDIR/lib/" @@ -199,7 +199,7 @@ cp -r "$FWDIR/ec2" "$DISTDIR" # Download and copy in tachyon, if requested if [ "$SPARK_TACHYON" == "true" ]; then - TACHYON_VERSION="0.4.1" + TACHYON_VERSION="0.5.0" TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/tachyon-${TACHYON_VERSION}-bin.tar.gz" TMPD=`mktemp -d 2>/dev/null || mktemp -d -t 'disttmp'` diff --git a/mllib/pom.xml b/mllib/pom.xml index cb0fa7b97cb15..45046eca5b18c 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -60,6 +60,10 @@ junit junit + + org.apache.commons + commons-math3 + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 954621ee8b933..d2e8ccf208970 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -24,10 +24,12 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} +import org.apache.spark.mllib.random.{RandomRDDGenerators => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -453,4 +455,99 @@ class PythonMLLibAPI extends Serializable { val ratings = ratingsBytesJRDD.rdd.map(unpackRating) ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) } + + // Used by the *RDD methods to get default seed if not passed in from pyspark + private def getSeedOrDefault(seed: java.lang.Long): Long = { + if (seed == null) Utils.random.nextLong else seed + } + + // Used by *RDD methods to get default numPartitions if not passed in from pyspark + private def getNumPartitionsOrDefault(numPartitions: java.lang.Integer, + jsc: JavaSparkContext): Int = { + if (numPartitions == null) { + jsc.sc.defaultParallelism + } else { + numPartitions + } + } + + // Note: for the following methods, numPartitions and seed are boxed to allow nulls to be passed + // in for either argument from pyspark + + /** + * Java stub for Python mllib RandomRDDGenerators.uniformRDD() + */ + def uniformRDD(jsc: JavaSparkContext, + size: Long, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.uniformRDD(jsc.sc, size, parts, s).map(serializeDouble) + } + + /** + * Java stub for Python mllib RandomRDDGenerators.normalRDD() + */ + def normalRDD(jsc: JavaSparkContext, + size: Long, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.normalRDD(jsc.sc, size, parts, s).map(serializeDouble) + } + + /** + * Java stub for Python mllib RandomRDDGenerators.poissonRDD() + */ + def poissonRDD(jsc: JavaSparkContext, + mean: Double, + size: Long, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.poissonRDD(jsc.sc, mean, size, parts, s).map(serializeDouble) + } + + /** + * Java stub for Python mllib RandomRDDGenerators.uniformVectorRDD() + */ + def uniformVectorRDD(jsc: JavaSparkContext, + numRows: Long, + numCols: Int, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.uniformVectorRDD(jsc.sc, numRows, numCols, parts, s).map(serializeDoubleVector) + } + + /** + * Java stub for Python mllib RandomRDDGenerators.normalVectorRDD() + */ + def normalVectorRDD(jsc: JavaSparkContext, + numRows: Long, + numCols: Int, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.normalVectorRDD(jsc.sc, numRows, numCols, parts, s).map(serializeDoubleVector) + } + + /** + * Java stub for Python mllib RandomRDDGenerators.poissonVectorRDD() + */ + def poissonVectorRDD(jsc: JavaSparkContext, + mean: Double, + numRows: Long, + numCols: Int, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(serializeDoubleVector) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala new file mode 100644 index 0000000000000..0f6d5809e098f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import java.lang.{Iterable => JavaIterable} + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils + +/** + * :: Experimental :: + * Maps a sequence of terms to their term frequencies using the hashing trick. + * + * @param numFeatures number of features (default: 1000000) + */ +@Experimental +class HashingTF(val numFeatures: Int) extends Serializable { + + def this() = this(1000000) + + /** + * Returns the index of the input term. + */ + def indexOf(term: Any): Int = Utils.nonNegativeMod(term.##, numFeatures) + + /** + * Transforms the input document into a sparse term frequency vector. + */ + def transform(document: Iterable[_]): Vector = { + val termFrequencies = mutable.HashMap.empty[Int, Double] + document.foreach { term => + val i = indexOf(term) + termFrequencies.put(i, termFrequencies.getOrElse(i, 0.0) + 1.0) + } + Vectors.sparse(numFeatures, termFrequencies.toSeq) + } + + /** + * Transforms the input document into a sparse term frequency vector (Java version). + */ + def transform(document: JavaIterable[_]): Vector = { + transform(document.asScala) + } + + /** + * Transforms the input document to term frequency vectors. + */ + def transform[D <: Iterable[_]](dataset: RDD[D]): RDD[Vector] = { + dataset.map(this.transform) + } + + /** + * Transforms the input document to term frequency vectors (Java version). + */ + def transform[D <: JavaIterable[_]](dataset: JavaRDD[D]): JavaRDD[Vector] = { + dataset.rdd.map(this.transform).toJavaRDD() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala new file mode 100644 index 0000000000000..7ed611a857acc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import breeze.linalg.{DenseVector => BDV} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * Inverse document frequency (IDF). + * The standard formulation is used: `idf = log((m + 1) / (d(t) + 1))`, where `m` is the total + * number of documents and `d(t)` is the number of documents that contain term `t`. + */ +@Experimental +class IDF { + + // TODO: Allow different IDF formulations. + + private var brzIdf: BDV[Double] = _ + + /** + * Computes the inverse document frequency. + * @param dataset an RDD of term frequency vectors + */ + def fit(dataset: RDD[Vector]): this.type = { + brzIdf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator)( + seqOp = (df, v) => df.add(v), + combOp = (df1, df2) => df1.merge(df2) + ).idf() + this + } + + /** + * Computes the inverse document frequency. + * @param dataset a JavaRDD of term frequency vectors + */ + def fit(dataset: JavaRDD[Vector]): this.type = { + fit(dataset.rdd) + } + + /** + * Transforms term frequency (TF) vectors to TF-IDF vectors. + * @param dataset an RDD of term frequency vectors + * @return an RDD of TF-IDF vectors + */ + def transform(dataset: RDD[Vector]): RDD[Vector] = { + if (!initialized) { + throw new IllegalStateException("Haven't learned IDF yet. Call fit first.") + } + val theIdf = brzIdf + val bcIdf = dataset.context.broadcast(theIdf) + dataset.mapPartitions { iter => + val thisIdf = bcIdf.value + iter.map { v => + val n = v.size + v match { + case sv: SparseVector => + val nnz = sv.indices.size + val newValues = new Array[Double](nnz) + var k = 0 + while (k < nnz) { + newValues(k) = sv.values(k) * thisIdf(sv.indices(k)) + k += 1 + } + Vectors.sparse(n, sv.indices, newValues) + case dv: DenseVector => + val newValues = new Array[Double](n) + var j = 0 + while (j < n) { + newValues(j) = dv.values(j) * thisIdf(j) + j += 1 + } + Vectors.dense(newValues) + case other => + throw new UnsupportedOperationException( + s"Only sparse and dense vectors are supported but got ${other.getClass}.") + } + } + } + } + + /** + * Transforms term frequency (TF) vectors to TF-IDF vectors (Java version). + * @param dataset a JavaRDD of term frequency vectors + * @return a JavaRDD of TF-IDF vectors + */ + def transform(dataset: JavaRDD[Vector]): JavaRDD[Vector] = { + transform(dataset.rdd).toJavaRDD() + } + + /** Returns the IDF vector. */ + def idf(): Vector = { + if (!initialized) { + throw new IllegalStateException("Haven't learned IDF yet. Call fit first.") + } + Vectors.fromBreeze(brzIdf) + } + + private def initialized: Boolean = brzIdf != null +} + +private object IDF { + + /** Document frequency aggregator. */ + class DocumentFrequencyAggregator extends Serializable { + + /** number of documents */ + private var m = 0L + /** document frequency vector */ + private var df: BDV[Long] = _ + + /** Adds a new document. */ + def add(doc: Vector): this.type = { + if (isEmpty) { + df = BDV.zeros(doc.size) + } + doc match { + case sv: SparseVector => + val nnz = sv.indices.size + var k = 0 + while (k < nnz) { + if (sv.values(k) > 0) { + df(sv.indices(k)) += 1L + } + k += 1 + } + case dv: DenseVector => + val n = dv.size + var j = 0 + while (j < n) { + if (dv.values(j) > 0.0) { + df(j) += 1L + } + j += 1 + } + case other => + throw new UnsupportedOperationException( + s"Only sparse and dense vectors are supported but got ${other.getClass}.") + } + m += 1L + this + } + + /** Merges another. */ + def merge(other: DocumentFrequencyAggregator): this.type = { + if (!other.isEmpty) { + m += other.m + if (df == null) { + df = other.df.copy + } else { + df += other.df + } + } + this + } + + private def isEmpty: Boolean = m == 0L + + /** Returns the current IDF vector. */ + def idf(): BDV[Double] = { + if (isEmpty) { + throw new IllegalStateException("Haven't seen any document yet.") + } + val n = df.length + val inv = BDV.zeros[Double](n) + var j = 0 + while (j < n) { + inv(j) = math.log((m + 1.0)/ (df(j) + 1.0)) + j += 1 + } + inv + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala index d7ee2d3f46846..021d651d4dbaa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala @@ -26,14 +26,17 @@ import org.apache.spark.util.Utils /** * :: Experimental :: - * Generator methods for creating RDDs comprised of i.i.d samples from some distribution. + * Generator methods for creating RDDs comprised of i.i.d. samples from some distribution. */ @Experimental object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the uniform distribution on [0.0, 1.0]. + * Generates an RDD comprised of i.i.d. samples from the uniform distribution on [0.0, 1.0]. + * + * To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use + * `RandomRDDGenerators.uniformRDD(sc, n, p, seed).map(v => a + (b - a) * v)`. * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. @@ -49,7 +52,10 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the uniform distribution on [0.0, 1.0]. + * Generates an RDD comprised of i.i.d. samples from the uniform distribution on [0.0, 1.0]. + * + * To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use + * `RandomRDDGenerators.uniformRDD(sc, n, p).map(v => a + (b - a) * v)`. * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. @@ -63,9 +69,12 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the uniform distribution on [0.0, 1.0]. + * Generates an RDD comprised of i.i.d. samples from the uniform distribution on [0.0, 1.0]. * sc.defaultParallelism used for the number of partitions in the RDD. * + * To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use + * `RandomRDDGenerators.uniformRDD(sc, n).map(v => a + (b - a) * v)`. + * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. * @return RDD[Double] comprised of i.i.d. samples ~ U[0.0, 1.0]. @@ -77,7 +86,10 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the standard normal distribution. + * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. + * + * To transform the distribution in the generated RDD from standard normal to some other normal + * N(mean, sigma), use `RandomRDDGenerators.normalRDD(sc, n, p, seed).map(v => mean + sigma * v)`. * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. @@ -93,7 +105,10 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the standard normal distribution. + * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. + * + * To transform the distribution in the generated RDD from standard normal to some other normal + * N(mean, sigma), use `RandomRDDGenerators.normalRDD(sc, n, p).map(v => mean + sigma * v)`. * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. @@ -107,9 +122,12 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the standard normal distribution. + * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. * sc.defaultParallelism used for the number of partitions in the RDD. * + * To transform the distribution in the generated RDD from standard normal to some other normal + * N(mean, sigma), use `RandomRDDGenerators.normalRDD(sc, n).map(v => mean + sigma * v)`. + * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0). @@ -121,7 +139,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the Poisson distribution with the input mean. + * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. * * @param sc SparkContext used to create the RDD. * @param mean Mean, or lambda, for the Poisson distribution. @@ -142,7 +160,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the Poisson distribution with the input mean. + * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. * * @param sc SparkContext used to create the RDD. * @param mean Mean, or lambda, for the Poisson distribution. @@ -157,7 +175,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the Poisson distribution with the input mean. + * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. * sc.defaultParallelism used for the number of partitions in the RDD. * * @param sc SparkContext used to create the RDD. @@ -172,7 +190,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples produced by the input DistributionGenerator. + * Generates an RDD comprised of i.i.d. samples produced by the input DistributionGenerator. * * @param sc SparkContext used to create the RDD. * @param generator DistributionGenerator used to populate the RDD. @@ -192,7 +210,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples produced by the input DistributionGenerator. + * Generates an RDD comprised of i.i.d. samples produced by the input DistributionGenerator. * * @param sc SparkContext used to create the RDD. * @param generator DistributionGenerator used to populate the RDD. @@ -210,7 +228,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples produced by the input DistributionGenerator. + * Generates an RDD comprised of i.i.d. samples produced by the input DistributionGenerator. * sc.defaultParallelism used for the number of partitions in the RDD. * * @param sc SparkContext used to create the RDD. @@ -229,7 +247,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * uniform distribution on [0.0 1.0]. * * @param sc SparkContext used to create the RDD. @@ -251,14 +269,14 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * uniform distribution on [0.0 1.0]. * * @param sc SparkContext used to create the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d samples ~ U[0.0, 1.0]. + * @return RDD[Vector] with vectors containing i.i.d. samples ~ U[0.0, 1.0]. */ @Experimental def uniformVectorRDD(sc: SparkContext, @@ -270,14 +288,14 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * uniform distribution on [0.0 1.0]. * sc.defaultParallelism used for the number of partitions in the RDD. * * @param sc SparkContext used to create the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d samples ~ U[0.0, 1.0]. + * @return RDD[Vector] with vectors containing i.i.d. samples ~ U[0.0, 1.0]. */ @Experimental def uniformVectorRDD(sc: SparkContext, numRows: Long, numCols: Int): RDD[Vector] = { @@ -286,7 +304,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * standard normal distribution. * * @param sc SparkContext used to create the RDD. @@ -294,7 +312,7 @@ object RandomRDDGenerators { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Vector] with vectors containing i.i.d samples ~ N(0.0, 1.0). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0). */ @Experimental def normalVectorRDD(sc: SparkContext, @@ -308,14 +326,14 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * standard normal distribution. * * @param sc SparkContext used to create the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d samples ~ N(0.0, 1.0). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0). */ @Experimental def normalVectorRDD(sc: SparkContext, @@ -327,14 +345,14 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * standard normal distribution. * sc.defaultParallelism used for the number of partitions in the RDD. * * @param sc SparkContext used to create the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d samples ~ N(0.0, 1.0). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0). */ @Experimental def normalVectorRDD(sc: SparkContext, numRows: Long, numCols: Int): RDD[Vector] = { @@ -343,7 +361,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * Poisson distribution with the input mean. * * @param sc SparkContext used to create the RDD. @@ -352,7 +370,7 @@ object RandomRDDGenerators { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Vector] with vectors containing i.i.d samples ~ Pois(mean). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). */ @Experimental def poissonVectorRDD(sc: SparkContext, @@ -367,7 +385,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * Poisson distribution with the input mean. * * @param sc SparkContext used to create the RDD. @@ -375,7 +393,7 @@ object RandomRDDGenerators { * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d samples ~ Pois(mean). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). */ @Experimental def poissonVectorRDD(sc: SparkContext, @@ -388,7 +406,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * Poisson distribution with the input mean. * sc.defaultParallelism used for the number of partitions in the RDD. * @@ -396,7 +414,7 @@ object RandomRDDGenerators { * @param mean Mean, or lambda, for the Poisson distribution. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d samples ~ Pois(mean). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). */ @Experimental def poissonVectorRDD(sc: SparkContext, @@ -408,7 +426,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples produced by the + * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the * input DistributionGenerator. * * @param sc SparkContext used to create the RDD. @@ -417,7 +435,7 @@ object RandomRDDGenerators { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Vector] with vectors containing i.i.d samples produced by generator. + * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. */ @Experimental def randomVectorRDD(sc: SparkContext, @@ -431,7 +449,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples produced by the + * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the * input DistributionGenerator. * * @param sc SparkContext used to create the RDD. @@ -439,7 +457,7 @@ object RandomRDDGenerators { * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d samples produced by generator. + * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. */ @Experimental def randomVectorRDD(sc: SparkContext, @@ -452,7 +470,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples produced by the + * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the * input DistributionGenerator. * sc.defaultParallelism used for the number of partitions in the RDD. * @@ -460,7 +478,7 @@ object RandomRDDGenerators { * @param generator DistributionGenerator used to populate the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d samples produced by generator. + * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. */ @Experimental def randomVectorRDD(sc: SparkContext, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index d208cfb917f3d..36d262fed425a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -290,8 +290,8 @@ class ALS private ( val usersOut = unblockFactors(users, userOutLinks) val productsOut = unblockFactors(products, productOutLinks) - usersOut.setName("usersOut").persist() - productsOut.setName("productsOut").persist() + usersOut.setName("usersOut").persist(StorageLevel.MEMORY_AND_DISK) + productsOut.setName("productsOut").persist(StorageLevel.MEMORY_AND_DISK) // Materialize usersOut and productsOut. usersOut.count() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 899286d235a9d..a1a76fcbe9f9c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -65,6 +65,48 @@ class MatrixFactorizationModel private[mllib] ( } } + /** + * Recommends products to a user. + * + * @param user the user to recommend products to + * @param num how many products to return. The number returned may be less than this. + * @return [[Rating]] objects, each of which contains the given user ID, a product ID, and a + * "score" in the rating field. Each represents one recommended product, and they are sorted + * by score, decreasing. The first returned is the one predicted to be most strongly + * recommended to the user. The score is an opaque value that indicates how strongly + * recommended the product is. + */ + def recommendProducts(user: Int, num: Int): Array[Rating] = + recommend(userFeatures.lookup(user).head, productFeatures, num) + .map(t => Rating(user, t._1, t._2)) + + /** + * Recommends users to a product. That is, this returns users who are most likely to be + * interested in a product. + * + * @param product the product to recommend users to + * @param num how many users to return. The number returned may be less than this. + * @return [[Rating]] objects, each of which contains a user ID, the given product ID, and a + * "score" in the rating field. Each represents one recommended user, and they are sorted + * by score, decreasing. The first returned is the one predicted to be most strongly + * recommended to the product. The score is an opaque value that indicates how strongly + * recommended the user is. + */ + def recommendUsers(product: Int, num: Int): Array[Rating] = + recommend(productFeatures.lookup(product).head, userFeatures, num) + .map(t => Rating(t._1, product, t._2)) + + private def recommend( + recommendToFeatures: Array[Double], + recommendableFeatures: RDD[(Int, Array[Double])], + num: Int): Array[(Int, Double)] = { + val recommendToVector = new DoubleMatrix(recommendToFeatures) + val scored = recommendableFeatures.map { case (id,features) => + (id, recommendToVector.dot(new DoubleMatrix(features))) + } + scored.top(num)(Ordering.by(_._2)) + } + /** * :: DeveloperApi :: * Predict the rating of many users for many products. @@ -80,6 +122,4 @@ class MatrixFactorizationModel private[mllib] ( predict(usersProducts).map(rate => pythonAPI.serializeRating(rate)) } - // TODO: Figure out what other good bulk prediction methods would look like. - // Probably want a way to get the top users for a product or vice-versa. } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index 68f3867ba6c11..9d6de9b6e1f60 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -30,7 +30,7 @@ object Statistics { /** * Compute the Pearson correlation matrix for the input RDD of Vectors. - * Returns NaN if either vector has 0 variance. + * Columns with 0 covariance produce NaN entries in the correlation matrix. * * @param X an RDD[Vector] for which the correlation matrix is to be computed. * @return Pearson correlation matrix comparing columns in X. @@ -39,7 +39,7 @@ object Statistics { /** * Compute the correlation matrix for the input RDD of Vectors using the specified method. - * Methods currently supported: `pearson` (default), `spearman` + * Methods currently supported: `pearson` (default), `spearman`. * * Note that for Spearman, a rank correlation, we need to create an RDD[Double] for each column * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], @@ -55,20 +55,26 @@ object Statistics { /** * Compute the Pearson correlation for the input RDDs. - * Columns with 0 covariance produce NaN entries in the correlation matrix. + * Returns NaN if either vector has 0 variance. + * + * Note: the two input RDDs need to have the same number of partitions and the same number of + * elements in each partition. * - * @param x RDD[Double] of the same cardinality as y - * @param y RDD[Double] of the same cardinality as x + * @param x RDD[Double] of the same cardinality as y. + * @param y RDD[Double] of the same cardinality as x. * @return A Double containing the Pearson correlation between the two input RDD[Double]s */ def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) /** * Compute the correlation for the input RDDs using the specified method. - * Methods currently supported: pearson (default), spearman + * Methods currently supported: `pearson` (default), `spearman`. + * + * Note: the two input RDDs need to have the same number of partitions and the same number of + * elements in each partition. * - * @param x RDD[Double] of the same cardinality as y - * @param y RDD[Double] of the same cardinality as x + * @param x RDD[Double] of the same cardinality as y. + * @param y RDD[Double] of the same cardinality as x. * @param method String specifying the method to use for computing correlation. * Supported: `pearson` (default), `spearman` *@return A Double containing the correlation between the two input RDD[Double]s using the diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala index 1f7de630e778c..9bd0c2cd05de4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala @@ -89,20 +89,18 @@ private[stat] object SpearmanCorrelation extends Correlation with Logging { val ranks: RDD[(Long, Double)] = sorted.mapPartitions { iter => // add an extra element to signify the end of the list so that flatMap can flush the last // batch of duplicates - val padded = iter ++ - Iterator[((Double, Long), Long)](((Double.NaN, -1L), -1L)) - var lastVal = 0.0 - var firstRank = 0.0 - val idBuffer = new ArrayBuffer[Long]() + val end = -1L + val padded = iter ++ Iterator[((Double, Long), Long)](((Double.NaN, end), end)) + val firstEntry = padded.next() + var lastVal = firstEntry._1._1 + var firstRank = firstEntry._2.toDouble + val idBuffer = ArrayBuffer(firstEntry._1._2) padded.flatMap { case ((v, id), rank) => - if (v == lastVal && id != Long.MinValue) { + if (v == lastVal && id != end) { idBuffer += id Iterator.empty } else { - val entries = if (idBuffer.size == 0) { - // edge case for the first value matching the initial value of lastVal - Iterator.empty - } else if (idBuffer.size == 1) { + val entries = if (idBuffer.size == 1) { Iterator((idBuffer(0), firstRank)) } else { val averageRank = firstRank + (idBuffer.size - 1.0) / 2.0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index ad32e3f4560fe..7d123dd6ae996 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -31,8 +31,8 @@ import org.apache.spark.util.random.XORShiftRandom /** * :: Experimental :: - * A class that implements a decision tree algorithm for classification and regression. It - * supports both continuous and categorical features. + * A class which implements a decision tree learning algorithm for classification and regression. + * It supports both continuous and categorical features. * @param strategy The configuration parameters for the tree algorithm which specify the type * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. @@ -42,8 +42,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo /** * Method to train a decision tree model over an RDD - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * @return a DecisionTreeModel that can be used for prediction + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @return DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint]): DecisionTreeModel = { @@ -60,7 +60,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree - val maxNumNodes = math.pow(2, maxDepth).toInt - 1 + val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1 // Initialize an array to hold filters applied to points for each node. val filters = new Array[List[Filter]](maxNumNodes) // The filter at the top node is an empty list. @@ -100,7 +100,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo var level = 0 var break = false - while (level < maxDepth && !break) { + while (level <= maxDepth && !break) { logDebug("#####################################") logDebug("level = " + level) @@ -152,7 +152,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val split = nodeSplitStats._1 val stats = nodeSplitStats._2 val nodeIndex = math.pow(2, level).toInt - 1 + index - val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1) + val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) nodes(nodeIndex) = node @@ -173,7 +173,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo while (i <= 1) { // Calculate the index of the node from the node level and the index at the current level. val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i - if (level < maxDepth - 1) { + if (level < maxDepth) { val impurity = if (i == 0) { nodeSplitStats._2.leftImpurity } else { @@ -197,17 +197,16 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo object DecisionTree extends Serializable with Logging { /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. The parameters for the algorithm are specified using the strategy parameter. + * Method to train a decision tree model. + * The method supports binary and multiclass classification and regression. * - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. * @param strategy The configuration parameters for the tree algorithm which specify the type * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. - * @return a DecisionTreeModel that can be used for prediction + * @return DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { new DecisionTree(strategy).train(input) @@ -219,12 +218,14 @@ object DecisionTree extends Serializable with Logging { * binary classification, the label for each instance should either be 0 or 1 to denote the two * classes. * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. * @param algo algorithm, classification or regression * @param impurity impurity criterion used for information gain calculation - * @param maxDepth maxDepth maximum depth of the tree - * @return a DecisionTreeModel that can be used for prediction + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * @return DecisionTreeModel that can be used for prediction */ def train( input: RDD[LabeledPoint], @@ -241,13 +242,15 @@ object DecisionTree extends Serializable with Logging { * binary classification, the label for each instance should either be 0 or 1 to denote the two * classes. * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. * @param algo algorithm, classification or regression * @param impurity impurity criterion used for information gain calculation - * @param maxDepth maxDepth maximum depth of the tree + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param numClassesForClassification number of classes for classification. Default value of 2. - * @return a DecisionTreeModel that can be used for prediction + * @return DecisionTreeModel that can be used for prediction */ def train( input: RDD[LabeledPoint], @@ -266,11 +269,13 @@ object DecisionTree extends Serializable with Logging { * 1 to denote the two classes. The method also supports categorical features inputs where the * number of categories can specified using the categoricalFeaturesInfo option. * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data for DecisionTree + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. * @param algo classification or regression * @param impurity criterion used for information gain calculation - * @param maxDepth maximum depth of the tree + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param numClassesForClassification number of classes for classification. Default value of 2. * @param maxBins maximum number of bins used for splitting features * @param quantileCalculationStrategy algorithm for calculating quantiles @@ -279,7 +284,7 @@ object DecisionTree extends Serializable with Logging { * an entry (n -> k) implies the feature n is categorical with k * categories 0, 1, 2, ... , k-1. It's important to note that * features are zero-indexed. - * @return a DecisionTreeModel that can be used for prediction + * @return DecisionTreeModel that can be used for prediction */ def train( input: RDD[LabeledPoint], @@ -301,11 +306,10 @@ object DecisionTree extends Serializable with Logging { * Returns an array of optimal splits for all nodes at a given level. Splits the task into * multiple groups if the level-wise training task could lead to memory overflow. * - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param parentImpurities Impurities for all parent nodes for the current level * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree + * parameters for constructing the DecisionTree * @param level Level of the tree * @param filters Filters for all nodes at a given level * @param splits possible splits for all features @@ -348,11 +352,10 @@ object DecisionTree extends Serializable with Logging { /** * Returns an array of optimal splits for a group of nodes at a given level * - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param parentImpurities Impurities for all parent nodes for the current level * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree + * parameters for constructing the DecisionTree * @param level Level of the tree * @param filters Filters for all nodes at a given level * @param splits possible splits for all features @@ -373,7 +376,7 @@ object DecisionTree extends Serializable with Logging { groupIndex: Int = 0): Array[(Split, InformationGainStats)] = { /* - * The high-level description for the best split optimizations are noted here. + * The high-level descriptions of the best split optimizations are noted here. * * *Level-wise training* * We perform bin calculations for all nodes at the given level to avoid making multiple @@ -396,18 +399,27 @@ object DecisionTree extends Serializable with Logging { * drastically reduce the communication overhead. */ - // common calculations for multiple nested methods + // Common calculations for multiple nested methods: + + // numNodes: Number of nodes in this (level of tree, group), + // where nodes at deeper (larger) levels may be divided into groups. val numNodes = math.pow(2, level).toInt / numGroups logDebug("numNodes = " + numNodes) + // Find the number of features by looking at the first sample. val numFeatures = input.first().features.size logDebug("numFeatures = " + numFeatures) + + // numBins: Number of bins = 1 + number of possible splits val numBins = bins(0).length logDebug("numBins = " + numBins) + val numClasses = strategy.numClassesForClassification logDebug("numClasses = " + numClasses) + val isMulticlassClassification = strategy.isMulticlassClassification logDebug("isMulticlassClassification = " + isMulticlassClassification) + val isMulticlassClassificationWithCategoricalFeatures = strategy.isMulticlassWithCategoricalFeatures logDebug("isMultiClassWithCategoricalFeatures = " + @@ -465,10 +477,13 @@ object DecisionTree extends Serializable with Logging { } /** - * Find bin for one feature. + * Find bin for one (labeledPoint, feature). */ - def findBin(featureIndex: Int, labeledPoint: LabeledPoint, - isFeatureContinuous: Boolean, isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { + def findBin( + featureIndex: Int, + labeledPoint: LabeledPoint, + isFeatureContinuous: Boolean, + isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { val binForFeatures = bins(featureIndex) val feature = labeledPoint.features(featureIndex) @@ -535,7 +550,9 @@ object DecisionTree extends Serializable with Logging { } else { // Perform sequential search to find bin for categorical features. val binIndex = { - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + val isUnorderedFeature = + isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits + if (isUnorderedFeature) { sequentialBinSearchForUnorderedCategoricalFeatureInClassification() } else { sequentialBinSearchForOrderedCategoricalFeatureInClassification() @@ -555,6 +572,14 @@ object DecisionTree extends Serializable with Logging { * where b_ij is an integer between 0 and numBins - 1 for regressions and binary * classification and the categorical feature value in multiclass classification. * Invalid sample is denoted by noting bin for feature 1 as -1. + * + * For unordered features, the "bin index" returned is actually the feature value (category). + * + * @return Array of size 1 + numFeatures * numNodes, where + * arr(0) = label for labeledPoint, and + * arr(1 + numFeatures * nodeIndex + featureIndex) = + * bin index for this labeledPoint + * (or InvalidBinIndex if labeledPoint is not handled by this node) */ def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { // Calculate bin index and label per feature per node. @@ -598,9 +623,21 @@ object DecisionTree extends Serializable with Logging { // Find feature bins for all nodes at a level. val binMappedRDD = input.map(x => findBinsForLevel(x)) - def updateBinForOrderedFeature(arr: Array[Double], agg: Array[Double], nodeIndex: Int, - label: Double, featureIndex: Int) = { - + /** + * Increment aggregate in location for (node, feature, bin, label). + * + * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. + * Array of size 1 + (numFeatures * numNodes). + * @param agg Array storing aggregate calculation, of size: + * numClasses * numBins * numFeatures * numNodes. + * Indexed by (node, feature, bin, label) where label is the least significant bit. + */ + def updateBinForOrderedFeature( + arr: Array[Double], + agg: Array[Double], + nodeIndex: Int, + label: Double, + featureIndex: Int): Unit = { // Find the bin index for this feature. val arrShift = 1 + numFeatures * nodeIndex val arrIndex = arrShift + featureIndex @@ -612,44 +649,58 @@ object DecisionTree extends Serializable with Logging { agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1 } - def updateBinForUnorderedFeature(nodeIndex: Int, featureIndex: Int, arr: Array[Double], - label: Double, agg: Array[Double], rightChildShift: Int) = { + /** + * Increment aggregate in location for (nodeIndex, featureIndex, [bins], label), + * where [bins] ranges over all bins. + * Updates left or right side of aggregate depending on split. + * + * @param arr arr(0) = label. + * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category) + * @param agg Indexed by (left/right, node, feature, bin, label) + * where label is the least significant bit. + * The left/right specifier is a 0/1 index indicating left/right child info. + * @param rightChildShift Offset for right side of agg. + */ + def updateBinForUnorderedFeature( + nodeIndex: Int, + featureIndex: Int, + arr: Array[Double], + label: Double, + agg: Array[Double], + rightChildShift: Int): Unit = { // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex + val arrIndex = 1 + numFeatures * nodeIndex + featureIndex + val featureValue = arr(arrIndex).toInt // Update the left or right count for one bin. - val aggShift = numClasses * numBins * numFeatures * nodeIndex - val aggIndex - = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses + val aggShift = + numClasses * numBins * numFeatures * nodeIndex + + numClasses * numBins * featureIndex + + label.toInt // Find all matching bins and increment their values val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 var binIndex = 0 while (binIndex < numCategoricalBins) { - val labelInt = label.toInt - if (bins(featureIndex)(binIndex).highSplit.categories.contains(labelInt)) { - agg(aggIndex + binIndex) - = agg(aggIndex + binIndex) + 1 + val aggIndex = aggShift + binIndex * numClasses + if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) { + agg(aggIndex) += 1 } else { - agg(rightChildShift + aggIndex + binIndex) - = agg(rightChildShift + aggIndex + binIndex) + 1 + agg(rightChildShift + aggIndex) += 1 } binIndex += 1 } } /** - * Performs a sequential aggregation over a partition for classification. For l nodes, - * k features, either the left count or the right count of one of the p bins is - * incremented based upon whether the feature is classified as 0 or 1. + * Helper for binSeqOp. * - * @param agg Array[Double] storing aggregate calculation of size - * numClasses * numSplits * numFeatures*numNodes for classification - * @param arr Array[Double] of size 1 + (numFeatures * numNodes) - * @return Array[Double] storing aggregate calculation of size - * 2 * numSplits * numFeatures * numNodes for classification + * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. + * Array of size 1 + (numFeatures * numNodes). + * @param agg Array storing aggregate calculation, of size: + * numClasses * numBins * numFeatures * numNodes. + * Indexed by (node, feature, bin, label) where label is the least significant bit. */ - def orderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + def binaryOrNotCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -671,17 +722,21 @@ object DecisionTree extends Serializable with Logging { } /** - * Performs a sequential aggregation over a partition for classification. For l nodes, - * k features, either the left count or the right count of one of the p bins is - * incremented based upon whether the feature is classified as 0 or 1. + * Helper for binSeqOp. * - * @param agg Array[Double] storing aggregate calculation of size - * numClasses * numSplits * numFeatures*numNodes for classification - * @param arr Array[Double] of size 1 + (numFeatures * numNodes) - * @return Array[Double] storing aggregate calculation of size - * 2 * numClasses * numSplits * numFeatures * numNodes for classification + * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. + * Array of size 1 + (numFeatures * numNodes). + * For ordered features, + * arr(1 + featureIndex + nodeIndex * numFeatures) = bin index. + * For unordered features, + * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category). + * @param agg Array storing aggregate calculation. + * For ordered features, this is of size: + * numClasses * numBins * numFeatures * numNodes. + * For unordered features, this is of size: + * 2 * numClasses * numBins * numFeatures * numNodes. */ - def unorderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + def multiclassWithCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -717,16 +772,17 @@ object DecisionTree extends Serializable with Logging { } /** - * Performs a sequential aggregation over a partition for regression. For l nodes, k features, + * Performs a sequential aggregation over a partition for regression. + * For l nodes, k features, * the count, sum, sum of squares of one of the p bins is incremented. * - * @param agg Array[Double] storing aggregate calculation of size - * 3 * numSplits * numFeatures * numNodes for classification - * @param arr Array[Double] of size 1 + (numFeatures * numNodes) - * @return Array[Double] storing aggregate calculation of size - * 3 * numSplits * numFeatures * numNodes for regression + * @param agg Array storing aggregate calculation, updated by this function. + * Size: 3 * numBins * numFeatures * numNodes + * @param arr Bin mapping from findBinsForLevel. + * Array of size 1 + (numFeatures * numNodes). + * @return agg */ - def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -757,14 +813,30 @@ object DecisionTree extends Serializable with Logging { /** * Performs a sequential aggregation over a partition. + * For l nodes, k features, + * For classification: + * Either the left count or the right count of one of the bins is + * incremented based upon whether the feature is classified as 0 or 1. + * For regression: + * The count, sum, sum of squares of one of the bins is incremented. + * + * @param agg Array storing aggregate calculation, updated by this function. + * Size for classification: + * numClasses * numBins * numFeatures * numNodes for ordered features, or + * 2 * numClasses * numBins * numFeatures * numNodes for unordered features. + * Size for regression: + * 3 * numBins * numFeatures * numNodes. + * @param arr Bin mapping from findBinsForLevel. + * Array of size 1 + (numFeatures * numNodes). + * @return agg */ def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { strategy.algo match { case Classification => if(isMulticlassClassificationWithCategoricalFeatures) { - unorderedClassificationBinSeqOp(arr, agg) + multiclassWithCategoricalBinSeqOp(arr, agg) } else { - orderedClassificationBinSeqOp(arr, agg) + binaryOrNotCategoricalBinSeqOp(arr, agg) } case Regression => regressionBinSeqOp(arr, agg) } @@ -815,20 +887,10 @@ object DecisionTree extends Serializable with Logging { topImpurity: Double): InformationGainStats = { strategy.algo match { case Classification => - var classIndex = 0 - val leftCounts: Array[Double] = new Array[Double](numClasses) - val rightCounts: Array[Double] = new Array[Double](numClasses) - var leftTotalCount = 0.0 - var rightTotalCount = 0.0 - while (classIndex < numClasses) { - val leftClassCount = leftNodeAgg(featureIndex)(splitIndex)(classIndex) - val rightClassCount = rightNodeAgg(featureIndex)(splitIndex)(classIndex) - leftCounts(classIndex) = leftClassCount - leftTotalCount += leftClassCount - rightCounts(classIndex) = rightClassCount - rightTotalCount += rightClassCount - classIndex += 1 - } + val leftCounts: Array[Double] = leftNodeAgg(featureIndex)(splitIndex) + val rightCounts: Array[Double] = rightNodeAgg(featureIndex)(splitIndex) + val leftTotalCount = leftCounts.sum + val rightTotalCount = rightCounts.sum val impurity = { if (level > 0) { @@ -845,33 +907,17 @@ object DecisionTree extends Serializable with Logging { } } - if (leftTotalCount == 0) { - return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue, 1) - } - if (rightTotalCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, 1) - } - - val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount) - val rightImpurity = strategy.impurity.calculate(rightCounts, rightTotalCount) - - val leftWeight = leftTotalCount / (leftTotalCount + rightTotalCount) - val rightWeight = rightTotalCount / (leftTotalCount + rightTotalCount) - - val gain = { - if (level > 0) { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } else { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } - } - val totalCount = leftTotalCount + rightTotalCount + if (totalCount == 0) { + // Return arbitrary prediction. + return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) + } // Sum of count for each label - val leftRightCounts: Array[Double] - = leftCounts.zip(rightCounts) - .map{case (leftCount, rightCount) => leftCount + rightCount} + val leftRightCounts: Array[Double] = + leftCounts.zip(rightCounts).map { case (leftCount, rightCount) => + leftCount + rightCount + } def indexOfLargestArrayElement(array: Array[Double]): Int = { val result = array.foldLeft(-1, Double.MinValue, 0) { @@ -885,6 +931,22 @@ object DecisionTree extends Serializable with Logging { val predict = indexOfLargestArrayElement(leftRightCounts) val prob = leftRightCounts(predict) / totalCount + val leftImpurity = if (leftTotalCount == 0) { + topImpurity + } else { + strategy.impurity.calculate(leftCounts, leftTotalCount) + } + val rightImpurity = if (rightTotalCount == 0) { + topImpurity + } else { + strategy.impurity.calculate(rightCounts, rightTotalCount) + } + + val leftWeight = leftTotalCount / totalCount + val rightWeight = rightTotalCount / totalCount + + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) case Regression => val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) @@ -937,10 +999,18 @@ object DecisionTree extends Serializable with Logging { /** * Extracts left and right split aggregates. - * @param binData Array[Double] of size 2*numFeatures*numSplits - * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\], - * Array[Array[Array[Double\]\]\]) where each array is of size(numFeature, - * (numBins - 1), numClasses) + * @param binData Aggregate array slice from getBinDataForNode. + * For classification: + * For unordered features, this is leftChildData ++ rightChildData, + * each of which is indexed by (feature, split/bin, class), + * with class being the least significant bit. + * For ordered features, this is of size numClasses * numBins * numFeatures. + * For regression: + * This is of size 2 * numFeatures * numBins. + * @return (leftNodeAgg, rightNodeAgg) pair of arrays. + * For classification, each array is of size (numFeatures, (numBins - 1), numClasses). + * For regression, each array is of size (numFeatures, (numBins - 1), 3). + * */ def extractLeftRightNodeAggregates( binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { @@ -983,6 +1053,11 @@ object DecisionTree extends Serializable with Logging { } } + /** + * Reshape binData for this feature. + * Indexes binData as (feature, split, class) with class as the least significant bit. + * @param leftNodeAgg leftNodeAgg(featureIndex)(splitIndex)(classIndex) = aggregate value + */ def findAggForUnorderedFeatureClassification( leftNodeAgg: Array[Array[Array[Double]]], rightNodeAgg: Array[Array[Array[Double]]], @@ -1107,7 +1182,7 @@ object DecisionTree extends Serializable with Logging { /** * Find the best split for a node. - * @param binData Array[Double] of size 2 * numSplits * numFeatures + * @param binData Bin data slice for this node, given by getBinDataForNode. * @param nodeImpurity impurity of the top node * @return tuple of split and information gain */ @@ -1133,7 +1208,7 @@ object DecisionTree extends Serializable with Logging { while (featureIndex < numFeatures) { // Iterate over all splits. var splitIndex = 0 - val maxSplitIndex : Double = { + val maxSplitIndex: Double = { val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { numBins - 1 @@ -1162,8 +1237,8 @@ object DecisionTree extends Serializable with Logging { (bestFeatureIndex, bestSplitIndex, bestGainStats) } + logDebug("best split = " + splits(bestFeatureIndex)(bestSplitIndex)) logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex)) - logDebug("best split bin = " + splits(bestFeatureIndex)(bestSplitIndex)) (splits(bestFeatureIndex)(bestSplitIndex), gainStats) } @@ -1214,8 +1289,17 @@ object DecisionTree extends Serializable with Logging { bestSplits } - private def getElementsPerNode(numFeatures: Int, numBins: Int, numClasses: Int, - isMulticlassClassificationWithCategoricalFeatures: Boolean, algo: Algo): Int = { + /** + * Get the number of values to be stored per node in the bin aggregates. + * + * @param numBins Number of bins = 1 + number of possible splits. + */ + private def getElementsPerNode( + numFeatures: Int, + numBins: Int, + numClasses: Int, + isMulticlassClassificationWithCategoricalFeatures: Boolean, + algo: Algo): Int = { algo match { case Classification => if (isMulticlassClassificationWithCategoricalFeatures) { @@ -1228,18 +1312,40 @@ object DecisionTree extends Serializable with Logging { } /** - * Returns split and bins for decision tree calculation. - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree + * Returns splits and bins for decision tree calculation. + * Continuous and categorical features are handled differently. + * + * Continuous features: + * For each feature, there are numBins - 1 possible splits representing the possible binary + * decisions at each node in the tree. + * + * Categorical features: + * For each feature, there is 1 bin per split. + * Splits and bins are handled in 2 ways: + * (a) For multiclass classification with a low-arity feature + * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), + * the feature is split based on subsets of categories. + * There are 2^(maxFeatureValue - 1) - 1 splits. + * (b) For regression and binary classification, + * and for multiclass classification with a high-arity feature, + * there is one split per category. + + * Categorical case (a) features are called unordered features. + * Other cases are called ordered features. + * + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree - * @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree - * .model.Split] of size (numFeatures, numSplits-1) and bins is an Array of [org.apache - * .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1) + * parameters for construction the DecisionTree + * @return A tuple of (splits,bins). + * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] + * of size (numFeatures, numBins - 1). + * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] + * of size (numFeatures, numBins). */ protected[tree] def findSplitsBins( input: RDD[LabeledPoint], strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { + val count = input.count() // Find the number of features by looking at the first sample @@ -1271,7 +1377,8 @@ object DecisionTree extends Serializable with Logging { logDebug("fraction of data used for calculating quantiles = " + fraction) // sampled input for RDD calculation - val sampledInput = input.sample(false, fraction, new XORShiftRandom().nextInt()).collect() + val sampledInput = + input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect() val numSamples = sampledInput.length val stride: Double = numSamples.toDouble / numBins @@ -1294,8 +1401,10 @@ object DecisionTree extends Serializable with Logging { val stride: Double = numSamples.toDouble / numBins logDebug("stride = " + stride) for (index <- 0 until numBins - 1) { - val sampleIndex = (index + 1) * stride.toInt - val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List()) + val sampleIndex = index * stride.toInt + // Set threshold halfway in between 2 samples. + val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0 + val split = new Split(featureIndex, threshold, Continuous, List()) splits(featureIndex)(index) = split } } else { // Categorical feature @@ -1304,8 +1413,10 @@ object DecisionTree extends Serializable with Logging { = numBins > math.pow(2, featureCategories.toInt - 1) - 1 // Use different bin/split calculation strategy for categorical features in multiclass - // classification that satisfy the space constraint - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + // classification that satisfy the space constraint. + val isUnorderedFeature = + isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits + if (isUnorderedFeature) { // 2^(maxFeatureValue- 1) - 1 combinations var index = 0 while (index < math.pow(2.0, featureCategories - 1).toInt - 1) { @@ -1330,8 +1441,13 @@ object DecisionTree extends Serializable with Logging { } index += 1 } - } else { - + } else { // ordered feature + /* For a given categorical feature, use a subsample of the data + * to choose how to arrange possible splits. + * This examines each category and computes a centroid. + * These centroids are later used to sort the possible splits. + * centroidForCategories is a mapping: category (for the given feature) --> centroid + */ val centroidForCategories = { if (isMulticlassClassification) { // For categorical variables in multiclass classification, @@ -1341,7 +1457,7 @@ object DecisionTree extends Serializable with Logging { .groupBy(_._1) .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble)) .map(x => (x._1, x._2.values.toArray)) - .map(x => (x._1, strategy.impurity.calculate(x._2,x._2.sum))) + .map(x => (x._1, strategy.impurity.calculate(x._2, x._2.sum))) } else { // regression or binary classification // For categorical variables in regression and binary classification, // each bin is a category. The bins are sorted and they @@ -1352,7 +1468,7 @@ object DecisionTree extends Serializable with Logging { } } - logDebug("centriod for categories = " + centroidForCategories.mkString(",")) + logDebug("centroid for categories = " + centroidForCategories.mkString(",")) // Check for missing categorical variables and putting them last in the sorted list. val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() @@ -1367,7 +1483,7 @@ object DecisionTree extends Serializable with Logging { // bins sorted by centroids val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2) - logDebug("centriod for categorical variable = " + categoriesSortedByCentroid) + logDebug("centroid for categorical variable = " + categoriesSortedByCentroid) var categoriesForSplit = List[Double]() categoriesSortedByCentroid.iterator.zipWithIndex.foreach { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 7c027ac2fda6b..5c65b537b6867 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -27,7 +27,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * Stores all the configuration options for tree construction * @param algo classification or regression * @param impurity criterion used for information gain calculation - * @param maxDepth maximum depth of the tree + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param numClassesForClassification number of classes for classification. Default value is 2 * leads to binary classification * @param maxBins maximum number of bins used for splitting features @@ -52,7 +53,9 @@ class Strategy ( val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), val maxMemoryInMB: Int = 128) extends Serializable { - require(numClassesForClassification >= 2) + if (algo == Classification) { + require(numClassesForClassification >= 2) + } val isMulticlassClassification = numClassesForClassification > 2 val isMulticlassWithCategoricalFeatures = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index a0e2d91762782..9297c20596527 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -34,10 +34,13 @@ object Entropy extends Impurity { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = { + if (totalCount == 0) { + return 0 + } val numClasses = counts.length var impurity = 0.0 var classIndex = 0 @@ -58,6 +61,7 @@ object Entropy extends Impurity { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels + * @return information value, or 0 if count = 0 */ @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 48144b5e6d1e4..2874bcf496484 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -33,10 +33,13 @@ object Gini extends Impurity { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = { + if (totalCount == 0) { + return 0 + } val numClasses = counts.length var impurity = 1.0 var classIndex = 0 @@ -54,6 +57,7 @@ object Gini extends Impurity { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels + * @return information value, or 0 if count = 0 */ @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 7b2a9320cc21d..92b0c7b4a6fbc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -31,7 +31,7 @@ trait Impurity extends Serializable { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi def calculate(counts: Array[Double], totalCount: Double): Double @@ -42,7 +42,7 @@ trait Impurity extends Serializable { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels - * @return information value + * @return information value, or 0 if count = 0 */ @DeveloperApi def calculate(count: Double, sum: Double, sumSquares: Double): Double diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 97149a99ead59..698a1a2a8e899 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -31,7 +31,7 @@ object Variance extends Impurity { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = @@ -43,9 +43,13 @@ object Variance extends Impurity { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels + * @return information value, or 0 if count = 0 */ @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = { + if (count == 0) { + return 0 + } val squaredLoss = sumSquares - (sum * sum) / count squaredLoss / count } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index bf692ca8c4bd7..3d3406b5d5f22 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -24,7 +24,8 @@ import org.apache.spark.mllib.linalg.Vector /** * :: Experimental :: - * Model to store the decision tree parameters + * Decision tree model for classification or regression. + * This model stores the decision tree structure and parameters. * @param topNode root node * @param algo algorithm type -- classification or regression */ @@ -50,4 +51,32 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable def predict(features: RDD[Vector]): RDD[Double] = { features.map(x => predict(x)) } + + /** + * Get number of nodes in tree, including leaf nodes. + */ + def numNodes: Int = { + 1 + topNode.numDescendants + } + + /** + * Get depth of tree. + * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. + */ + def depth: Int = { + topNode.subtreeDepth + } + + /** + * Print full model. + */ + override def toString: String = algo match { + case Classification => + s"DecisionTreeModel classifier\n" + topNode.subtreeToString(2) + case Regression => + s"DecisionTreeModel regressor\n" + topNode.subtreeToString(2) + case _ => throw new IllegalArgumentException( + s"DecisionTreeModel given unknown algo parameter: $algo.") + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 682f213f411a7..944f11c2c2e4f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -91,4 +91,60 @@ class Node ( } } } + + /** + * Get the number of nodes in tree below this node, including leaf nodes. + * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2. + */ + private[tree] def numDescendants: Int = { + if (isLeaf) { + 0 + } else { + 2 + leftNode.get.numDescendants + rightNode.get.numDescendants + } + } + + /** + * Get depth of tree from this node. + * E.g.: Depth 0 means this is a leaf node. + */ + private[tree] def subtreeDepth: Int = { + if (isLeaf) { + 0 + } else { + 1 + math.max(leftNode.get.subtreeDepth, rightNode.get.subtreeDepth) + } + } + + /** + * Recursive print function. + * @param indentFactor The number of spaces to add to each level of indentation. + */ + private[tree] def subtreeToString(indentFactor: Int = 0): String = { + + def splitToString(split: Split, left: Boolean): String = { + split.featureType match { + case Continuous => if (left) { + s"(feature ${split.feature} <= ${split.threshold})" + } else { + s"(feature ${split.feature} > ${split.threshold})" + } + case Categorical => if (left) { + s"(feature ${split.feature} in ${split.categories.mkString("{",",","}")})" + } else { + s"(feature ${split.feature} not in ${split.categories.mkString("{",",","}")})" + } + } + } + val prefix: String = " " * indentFactor + if (isLeaf) { + prefix + s"Predict: $predict\n" + } else { + prefix + s"If ${splitToString(split.get, left=true)}\n" + + leftNode.get.subtreeToString(indentFactor + 1) + + prefix + s"Else ${splitToString(split.get, left=false)}\n" + + rightNode.get.subtreeToString(indentFactor + 1) + } + } + } diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java new file mode 100644 index 0000000000000..e8d99f4ae43ae --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; + +public class JavaTfIdfSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaTfIdfSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void tfIdf() { + // The tests are to check Java compatibility. + HashingTF tf = new HashingTF(); + JavaRDD> documents = sc.parallelize(Lists.newArrayList( + Lists.newArrayList("this is a sentence".split(" ")), + Lists.newArrayList("this is another sentence".split(" ")), + Lists.newArrayList("this is still a sentence".split(" "))), 2); + JavaRDD termFreqs = tf.transform(documents); + termFreqs.collect(); + IDF idf = new IDF(); + JavaRDD tfIdfs = idf.fit(termFreqs).transform(termFreqs); + List localTfIdfs = tfIdfs.collect(); + int indexOfThis = tf.indexOf("this"); + for (Vector v: localTfIdfs) { + Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index bf2365f82044c..f6ca9643227f8 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -20,6 +20,11 @@ import java.io.Serializable; import java.util.List; +import scala.Tuple2; +import scala.Tuple3; + +import org.jblas.DoubleMatrix; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -28,8 +33,6 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.jblas.DoubleMatrix; - public class JavaALSSuite implements Serializable { private transient JavaSparkContext sc; @@ -44,21 +47,28 @@ public void tearDown() { sc = null; } - static void validatePrediction(MatrixFactorizationModel model, int users, int products, int features, - DoubleMatrix trueRatings, double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) { + static void validatePrediction( + MatrixFactorizationModel model, + int users, + int products, + int features, + DoubleMatrix trueRatings, + double matchThreshold, + boolean implicitPrefs, + DoubleMatrix truePrefs) { DoubleMatrix predictedU = new DoubleMatrix(users, features); - List> userFeatures = model.userFeatures().toJavaRDD().collect(); + List> userFeatures = model.userFeatures().toJavaRDD().collect(); for (int i = 0; i < features; ++i) { - for (scala.Tuple2 userFeature : userFeatures) { + for (Tuple2 userFeature : userFeatures) { predictedU.put((Integer)userFeature._1(), i, userFeature._2()[i]); } } DoubleMatrix predictedP = new DoubleMatrix(products, features); - List> productFeatures = + List> productFeatures = model.productFeatures().toJavaRDD().collect(); for (int i = 0; i < features; ++i) { - for (scala.Tuple2 productFeature : productFeatures) { + for (Tuple2 productFeature : productFeatures) { predictedP.put((Integer)productFeature._1(), i, productFeature._2()[i]); } } @@ -75,7 +85,8 @@ static void validatePrediction(MatrixFactorizationModel model, int users, int pr } } } else { - // For implicit prefs we use the confidence-weighted RMSE to test (ref Mahout's implicit ALS tests) + // For implicit prefs we use the confidence-weighted RMSE to test + // (ref Mahout's implicit ALS tests) double sqErr = 0.0; double denom = 0.0; for (int u = 0; u < users; ++u) { @@ -100,7 +111,7 @@ public void runALSUsingStaticMethods() { int iterations = 15; int users = 50; int products = 100; - scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( users, products, features, 0.7, false, false); JavaRDD data = sc.parallelize(testData._1()); @@ -114,14 +125,14 @@ public void runALSUsingConstructor() { int iterations = 15; int users = 100; int products = 200; - scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( users, products, features, 0.7, false, false); JavaRDD data = sc.parallelize(testData._1()); MatrixFactorizationModel model = new ALS().setRank(features) - .setIterations(iterations) - .run(data.rdd()); + .setIterations(iterations) + .run(data.rdd()); validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3()); } @@ -131,7 +142,7 @@ public void runImplicitALSUsingStaticMethods() { int iterations = 15; int users = 80; int products = 160; - scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( users, products, features, 0.7, true, false); JavaRDD data = sc.parallelize(testData._1()); @@ -145,7 +156,7 @@ public void runImplicitALSUsingConstructor() { int iterations = 15; int users = 100; int products = 200; - scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( users, products, features, 0.7, true, false); JavaRDD data = sc.parallelize(testData._1()); @@ -163,12 +174,42 @@ public void runImplicitALSWithNegativeWeight() { int iterations = 15; int users = 80; int products = 160; - scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( users, products, features, 0.7, true, true); JavaRDD data = sc.parallelize(testData._1()); - MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations); + MatrixFactorizationModel model = new ALS().setRank(features) + .setIterations(iterations) + .setImplicitPrefs(true) + .setSeed(8675309L) + .run(data.rdd()); validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); } + @Test + public void runRecommend() { + int features = 5; + int iterations = 10; + int users = 200; + int products = 50; + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + users, products, features, 0.7, true, false); + JavaRDD data = sc.parallelize(testData._1()); + MatrixFactorizationModel model = new ALS().setRank(features) + .setIterations(iterations) + .setImplicitPrefs(true) + .setSeed(8675309L) + .run(data.rdd()); + validateRecommendations(model.recommendProducts(1, 10), 10); + validateRecommendations(model.recommendUsers(1, 20), 20); + } + + private static void validateRecommendations(Rating[] recommendations, int howMany) { + Assert.assertEquals(howMany, recommendations.length); + for (int i = 1; i < recommendations.length; i++) { + Assert.assertTrue(recommendations[i-1].rating() >= recommendations[i].rating()); + } + Assert.assertTrue(recommendations[0].rating() > 0.7); + } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala new file mode 100644 index 0000000000000..a599e0d938569 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.LocalSparkContext + +class HashingTFSuite extends FunSuite with LocalSparkContext { + + test("hashing tf on a single doc") { + val hashingTF = new HashingTF(1000) + val doc = "a a b b c d".split(" ") + val n = hashingTF.numFeatures + val termFreqs = Seq( + (hashingTF.indexOf("a"), 2.0), + (hashingTF.indexOf("b"), 2.0), + (hashingTF.indexOf("c"), 1.0), + (hashingTF.indexOf("d"), 1.0)) + assert(termFreqs.map(_._1).forall(i => i >= 0 && i < n), + "index must be in range [0, #features)") + assert(termFreqs.map(_._1).toSet.size === 4, "expecting perfect hashing") + val expected = Vectors.sparse(n, termFreqs) + assert(hashingTF.transform(doc) === expected) + } + + test("hashing tf on an RDD") { + val hashingTF = new HashingTF + val localDocs: Seq[Seq[String]] = Seq( + "a a b b b c d".split(" "), + "a b c d a b c".split(" "), + "c b a c b a a".split(" ")) + val docs = sc.parallelize(localDocs, 2) + assert(hashingTF.transform(docs).collect().toSet === localDocs.map(hashingTF.transform).toSet) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala new file mode 100644 index 0000000000000..78a2804ff204b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class IDFSuite extends FunSuite with LocalSparkContext { + + test("idf") { + val n = 4 + val localTermFrequencies = Seq( + Vectors.sparse(n, Array(1, 3), Array(1.0, 2.0)), + Vectors.dense(0.0, 1.0, 2.0, 3.0), + Vectors.sparse(n, Array(1), Array(1.0)) + ) + val m = localTermFrequencies.size + val termFrequencies = sc.parallelize(localTermFrequencies, 2) + val idf = new IDF + intercept[IllegalStateException] { + idf.idf() + } + intercept[IllegalStateException] { + idf.transform(termFrequencies) + } + idf.fit(termFrequencies) + val expected = Vectors.dense(Array(0, 3, 1, 2).map { x => + math.log((m.toDouble + 1.0) / (x + 1.0)) + }) + assert(idf.idf() ~== expected absTol 1e-12) + val tfidf = idf.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap() + assert(tfidf.size === 3) + val tfidf0 = tfidf(0L).asInstanceOf[SparseVector] + assert(tfidf0.indices === Array(1, 3)) + assert(Vectors.dense(tfidf0.values) ~== + Vectors.dense(1.0 * expected(1), 2.0 * expected(3)) absTol 1e-12) + val tfidf1 = tfidf(1L).asInstanceOf[DenseVector] + assert(Vectors.dense(tfidf1.values) ~== + Vectors.dense(0.0, 1.0 * expected(1), 2.0 * expected(2), 3.0 * expected(3)) absTol 1e-12) + val tfidf2 = tfidf(2L).asInstanceOf[SparseVector] + assert(tfidf2.indices === Array(1)) + assert(tfidf2.values(0) ~== (1.0 * expected(1)) absTol 1e-12) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index bce4251426df7..a3f76f77a5dcc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -31,6 +31,7 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { // test input data val xData = Array(1.0, 0.0, -2.0) val yData = Array(4.0, 5.0, 3.0) + val zeros = new Array[Double](3) val data = Seq( Vectors.dense(1.0, 0.0, 0.0, -2.0), Vectors.dense(4.0, 5.0, 0.0, 3.0), @@ -46,6 +47,18 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { val p1 = Statistics.corr(x, y, "pearson") assert(approxEqual(expected, default)) assert(approxEqual(expected, p1)) + + // numPartitions >= size for input RDDs + for (numParts <- List(xData.size, xData.size * 2)) { + val x1 = sc.parallelize(xData, numParts) + val y1 = sc.parallelize(yData, numParts) + val p2 = Statistics.corr(x1, y1) + assert(approxEqual(expected, p2)) + } + + // RDD of zero variance + val z = sc.parallelize(zeros) + assert(Statistics.corr(x, z).isNaN()) } test("corr(x, y) spearman") { @@ -54,6 +67,18 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { val expected = 0.5 val s1 = Statistics.corr(x, y, "spearman") assert(approxEqual(expected, s1)) + + // numPartitions >= size for input RDDs + for (numParts <- List(xData.size, xData.size * 2)) { + val x1 = sc.parallelize(xData, numParts) + val y1 = sc.parallelize(yData, numParts) + val s2 = Statistics.corr(x1, y1, "spearman") + assert(approxEqual(expected, s2)) + } + + // RDD of zero variance => zero variance in ranks + val z = sc.parallelize(zeros) + assert(Statistics.corr(x, z, "spearman").isNaN()) } test("corr(X) default, pearson") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 5961a618c59d9..10462db700628 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -20,8 +20,7 @@ package org.apache.spark.mllib.tree import org.scalatest.FunSuite import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.Filter -import org.apache.spark.mllib.tree.model.Split +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split} import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ @@ -31,6 +30,18 @@ import org.apache.spark.mllib.regression.LabeledPoint class DecisionTreeSuite extends FunSuite with LocalSparkContext { + def validateClassifier( + model: DecisionTreeModel, + input: Seq[LabeledPoint], + requiredAccuracy: Double) { + val predictions = input.map(x => model.predict(x.features)) + val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => + prediction != expected.label + } + val accuracy = (input.length - numOffPredictions).toDouble / input.length + assert(accuracy >= requiredAccuracy) + } + test("split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) @@ -50,7 +61,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, - maxDepth = 3, + maxDepth = 2, numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) @@ -130,7 +141,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, - maxDepth = 3, + maxDepth = 2, numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) @@ -236,7 +247,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("extract categories from a number for multiclass classification") { val l = DecisionTree.extractMultiClassCategories(13, 10) assert(l.length === 3) - assert(List(3.0, 2.0, 0.0).toSeq == l.toSeq) + assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq) } test("split and bin calculations for unordered categorical variables with multiclass " + @@ -247,7 +258,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, - maxDepth = 3, + maxDepth = 2, numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) @@ -341,7 +352,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, - maxDepth = 3, + maxDepth = 2, numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) @@ -397,7 +408,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Classification, Gini, numClassesForClassification = 2, - maxDepth = 3, + maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) @@ -413,7 +424,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = bestSplits(0)._2 assert(stats.gain > 0) assert(stats.predict === 1) - assert(stats.prob == 0.6) + assert(stats.prob === 0.6) assert(stats.impurity > 0.2) } @@ -424,7 +435,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Regression, Variance, - maxDepth = 3, + maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) @@ -439,7 +450,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = bestSplits(0)._2 assert(stats.gain > 0) - assert(stats.predict == 0.6) + assert(stats.predict === 0.6) assert(stats.impurity > 0.2) } @@ -460,7 +471,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._1.threshold === 10) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) @@ -483,7 +493,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._1.threshold === 10) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) @@ -507,7 +516,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._1.threshold === 10) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) @@ -531,7 +539,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._1.threshold === 10) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) @@ -587,7 +594,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) @@ -602,12 +609,78 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplit.featureType === Categorical) } + test("stump with 1 continuous variable for binary classification, to check off-by-1 error") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0)) + arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0)) + arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0)) + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, + numClassesForClassification = 2) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 1.0) + assert(model.numNodes === 3) + assert(model.depth === 1) + } + + test("stump with 2 continuous variables for binary classification") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) + arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) + arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) + arr(3) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0)))) + + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, + numClassesForClassification = 2) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 1.0) + assert(model.numNodes === 3) + assert(model.depth === 1) + assert(model.topNode.split.get.feature === 1) + } + + test("stump with categorical variables for multiclass classification, with just enough bins") { + val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features + val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, + numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + assert(strategy.isMulticlassClassification) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 1.0) + assert(model.numNodes === 3) + assert(model.depth === 1) + + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestSplit = bestSplits(0)._1 + assert(bestSplit.feature === 0) + assert(bestSplit.categories.length === 1) + assert(bestSplit.categories.contains(1)) + assert(bestSplit.featureType === Categorical) + val gain = bestSplits(0)._2 + assert(gain.leftImpurity === 0) + assert(gain.rightImpurity === 0) + } + test("stump with continuous variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3) assert(strategy.isMulticlassClassification) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 0.9) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) @@ -625,9 +698,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with continuous + categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3)) assert(strategy.isMulticlassClassification) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 0.9) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) @@ -644,7 +721,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for ordered multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 5a835f58207cf..537ca0dcf267d 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -71,7 +71,12 @@ object MimaExcludes { "org.apache.spark.storage.TachyonStore.putValues") ) ++ Seq( - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.FlumeReceiver.this") + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.flume.FlumeReceiver.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.streaming.kafka.KafkaUtils.createStream"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.streaming.kafka.KafkaReceiver.this") ) ++ Seq( // Ignore some private methods in ALS. ProblemFilters.exclude[MissingMethodProblem]( diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 312c75d112cbf..c58555fc9d2c5 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -49,6 +49,16 @@ Main entry point for accessing data stored in Apache Hive.. """ +# The following block allows us to import python's random instead of mllib.random for scripts in +# mllib that depend on top level pyspark packages, which transitively depend on python's random. +# Since Python's import logic looks for modules in the current package first, we eliminate +# mllib.random as a candidate for C{import random} by removing the first search path, the script's +# location, in order to force the loader to look in Python's top-level modules for C{random}. +import sys +s = sys.path.pop(0) +import random +sys.path.insert(0, s) + from pyspark.conf import SparkConf from pyspark.context import SparkContext from pyspark.sql import SQLContext diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 2204e9c9ca701..45d36e5d0e764 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -86,6 +86,7 @@ Exception:... """ +import select import struct import SocketServer import threading @@ -209,19 +210,38 @@ def addInPlace(self, value1, value2): class _UpdateRequestHandler(SocketServer.StreamRequestHandler): + """ + This handler will keep polling updates from the same socket until the + server is shutdown. + """ + def handle(self): from pyspark.accumulators import _accumulatorRegistry - num_updates = read_int(self.rfile) - for _ in range(num_updates): - (aid, update) = pickleSer._read_with_length(self.rfile) - _accumulatorRegistry[aid] += update - # Write a byte in acknowledgement - self.wfile.write(struct.pack("!b", 1)) + while not self.server.server_shutdown: + # Poll every 1 second for new data -- don't block in case of shutdown. + r, _, _ = select.select([self.rfile], [], [], 1) + if self.rfile in r: + num_updates = read_int(self.rfile) + for _ in range(num_updates): + (aid, update) = pickleSer._read_with_length(self.rfile) + _accumulatorRegistry[aid] += update + # Write a byte in acknowledgement + self.wfile.write(struct.pack("!b", 1)) + +class AccumulatorServer(SocketServer.TCPServer): + """ + A simple TCP server that intercepts shutdown() in order to interrupt + our continuous polling on the handler. + """ + server_shutdown = False + def shutdown(self): + self.server_shutdown = True + SocketServer.TCPServer.shutdown(self) def _start_update_server(): """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" - server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler) + server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler) thread = threading.Thread(target=server.serve_forever) thread.daemon = True thread.start() diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 71f4ad1a8d44e..54720c2324ca6 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -255,4 +255,8 @@ def _test(): exit(-1) if __name__ == "__main__": + # remove current path from list of search paths to avoid importing mllib.random + # for C{import random}, which is done in an external dependency of pyspark during doctests. + import sys + sys.path.pop(0) _test() diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py new file mode 100644 index 0000000000000..36e710dbae7a8 --- /dev/null +++ b/python/pyspark/mllib/random.py @@ -0,0 +1,182 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Python package for random data generation. +""" + + +from pyspark.rdd import RDD +from pyspark.mllib._common import _deserialize_double, _deserialize_double_vector +from pyspark.serializers import NoOpSerializer + +class RandomRDDGenerators: + """ + Generator methods for creating RDDs comprised of i.i.d samples from + some distribution. + """ + + @staticmethod + def uniformRDD(sc, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d. samples from the + uniform distribution on [0.0, 1.0]. + + To transform the distribution in the generated RDD from U[0.0, 1.0] + to U[a, b], use + C{RandomRDDGenerators.uniformRDD(sc, n, p, seed)\ + .map(lambda v: a + (b - a) * v)} + + >>> x = RandomRDDGenerators.uniformRDD(sc, 100).collect() + >>> len(x) + 100 + >>> max(x) <= 1.0 and min(x) >= 0.0 + True + >>> RandomRDDGenerators.uniformRDD(sc, 100, 4).getNumPartitions() + 4 + >>> parts = RandomRDDGenerators.uniformRDD(sc, 100, seed=4).getNumPartitions() + >>> parts == sc.defaultParallelism + True + """ + jrdd = sc._jvm.PythonMLLibAPI().uniformRDD(sc._jsc, size, numPartitions, seed) + uniform = RDD(jrdd, sc, NoOpSerializer()) + return uniform.map(lambda bytes: _deserialize_double(bytearray(bytes))) + + @staticmethod + def normalRDD(sc, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d samples from the standard normal + distribution. + + To transform the distribution in the generated RDD from standard normal + to some other normal N(mean, sigma), use + C{RandomRDDGenerators.normal(sc, n, p, seed)\ + .map(lambda v: mean + sigma * v)} + + >>> x = RandomRDDGenerators.normalRDD(sc, 1000, seed=1L) + >>> stats = x.stats() + >>> stats.count() + 1000L + >>> abs(stats.mean() - 0.0) < 0.1 + True + >>> abs(stats.stdev() - 1.0) < 0.1 + True + """ + jrdd = sc._jvm.PythonMLLibAPI().normalRDD(sc._jsc, size, numPartitions, seed) + normal = RDD(jrdd, sc, NoOpSerializer()) + return normal.map(lambda bytes: _deserialize_double(bytearray(bytes))) + + @staticmethod + def poissonRDD(sc, mean, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d samples from the Poisson + distribution with the input mean. + + >>> mean = 100.0 + >>> x = RandomRDDGenerators.poissonRDD(sc, mean, 1000, seed=1L) + >>> stats = x.stats() + >>> stats.count() + 1000L + >>> abs(stats.mean() - mean) < 0.5 + True + >>> from math import sqrt + >>> abs(stats.stdev() - sqrt(mean)) < 0.5 + True + """ + jrdd = sc._jvm.PythonMLLibAPI().poissonRDD(sc._jsc, mean, size, numPartitions, seed) + poisson = RDD(jrdd, sc, NoOpSerializer()) + return poisson.map(lambda bytes: _deserialize_double(bytearray(bytes))) + + @staticmethod + def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d samples drawn + from the uniform distribution on [0.0 1.0]. + + >>> import numpy as np + >>> mat = np.matrix(RandomRDDGenerators.uniformVectorRDD(sc, 10, 10).collect()) + >>> mat.shape + (10, 10) + >>> mat.max() <= 1.0 and mat.min() >= 0.0 + True + >>> RandomRDDGenerators.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() + 4 + """ + jrdd = sc._jvm.PythonMLLibAPI() \ + .uniformVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed) + uniform = RDD(jrdd, sc, NoOpSerializer()) + return uniform.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) + + @staticmethod + def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d samples drawn + from the standard normal distribution. + + >>> import numpy as np + >>> mat = np.matrix(RandomRDDGenerators.normalVectorRDD(sc, 100, 100, seed=1L).collect()) + >>> mat.shape + (100, 100) + >>> abs(mat.mean() - 0.0) < 0.1 + True + >>> abs(mat.std() - 1.0) < 0.1 + True + """ + jrdd = sc._jvm.PythonMLLibAPI() \ + .normalVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed) + normal = RDD(jrdd, sc, NoOpSerializer()) + return normal.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) + + @staticmethod + def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d samples drawn + from the Poisson distribution with the input mean. + + >>> import numpy as np + >>> mean = 100.0 + >>> rdd = RandomRDDGenerators.poissonVectorRDD(sc, mean, 100, 100, seed=1L) + >>> mat = np.mat(rdd.collect()) + >>> mat.shape + (100, 100) + >>> abs(mat.mean() - mean) < 0.5 + True + >>> from math import sqrt + >>> abs(mat.std() - sqrt(mean)) < 0.5 + True + """ + jrdd = sc._jvm.PythonMLLibAPI() \ + .poissonVectorRDD(sc._jsc, mean, numRows, numCols, numPartitions, seed) + poisson = RDD(jrdd, sc, NoOpSerializer()) + return poisson.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) + + +def _test(): + import doctest + from pyspark.context import SparkContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/run-tests b/python/run-tests index 29f755fc0dcd3..5049e15ce5f8a 100755 --- a/python/run-tests +++ b/python/run-tests @@ -67,6 +67,7 @@ run_test "pyspark/mllib/_common.py" run_test "pyspark/mllib/classification.py" run_test "pyspark/mllib/clustering.py" run_test "pyspark/mllib/linalg.py" +run_test "pyspark/mllib/random.py" run_test "pyspark/mllib/recommendation.py" run_test "pyspark/mllib/regression.py" run_test "pyspark/mllib/tests.py" diff --git a/repl/pom.xml b/repl/pom.xml index 4ebb1b82f0e8c..68f4504450778 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -55,6 +55,12 @@ ${project.version} runtime + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test + org.eclipse.jetty jetty-server diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 6f9fa0d9f2b25..42c7e511dc3f5 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -230,6 +230,20 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, case xs => xs find (_.name == cmd) } } + private var fallbackMode = false + + private def toggleFallbackMode() { + val old = fallbackMode + fallbackMode = !old + System.setProperty("spark.repl.fallback", fallbackMode.toString) + echo(s""" + |Switched ${if (old) "off" else "on"} fallback mode without restarting. + | If you have defined classes in the repl, it would + |be good to redefine them incase you plan to use them. If you still run + |into issues it would be good to restart the repl and turn on `:fallback` + |mode as first command. + """.stripMargin) + } /** Show the history */ lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { @@ -299,6 +313,9 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand), shCommand, nullary("silent", "disable/enable automatic printing of results", verbosity), + nullary("fallback", """ + |disable/enable advanced repl changes, these fix some issues but may introduce others. + |This mode will be removed once these fixes stablize""".stripMargin, toggleFallbackMode), cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand), nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand) ) diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 3842c291d0b7b..f60bbb4662af1 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -892,11 +892,16 @@ import org.apache.spark.util.Utils def definedTypeSymbol(name: String) = definedSymbols(newTypeName(name)) def definedTermSymbol(name: String) = definedSymbols(newTermName(name)) + val definedClasses = handlers.exists { + case _: ClassHandler => true + case _ => false + } + /** Code to import bound names from previous lines - accessPath is code to * append to objectName to access anything bound by request. */ val SparkComputedImports(importsPreamble, importsTrailer, accessPath) = - importsCode(referencedNames.toSet) + importsCode(referencedNames.toSet, definedClasses) /** Code to access a variable with the specified name */ def fullPath(vname: String) = { diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala b/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala index 9099e052f5796..193a42dcded12 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala @@ -108,8 +108,9 @@ trait SparkImports { * last one imported is actually usable. */ case class SparkComputedImports(prepend: String, append: String, access: String) + def fallback = System.getProperty("spark.repl.fallback", "false").toBoolean - protected def importsCode(wanted: Set[Name]): SparkComputedImports = { + protected def importsCode(wanted: Set[Name], definedClass: Boolean): SparkComputedImports = { /** Narrow down the list of requests from which imports * should be taken. Removes requests which cannot contribute * useful imports for the specified set of wanted names. @@ -124,8 +125,14 @@ trait SparkImports { // Single symbol imports might be implicits! See bug #1752. Rather than // try to finesse this, we will mimic all imports for now. def keepHandler(handler: MemberHandler) = handler match { - case _: ImportHandler => true - case x => x.definesImplicit || (x.definedNames exists wanted) + /* This case clause tries to "precisely" import only what is required. And in this + * it may miss out on some implicits, because implicits are not known in `wanted`. Thus + * it is suitable for defining classes. AFAIK while defining classes implicits are not + * needed.*/ + case h: ImportHandler if definedClass && !fallback => + h.importedNames.exists(x => wanted.contains(x)) + case _: ImportHandler => true + case x => x.definesImplicit || (x.definedNames exists wanted) } reqs match { @@ -182,7 +189,7 @@ trait SparkImports { // ambiguity errors will not be generated. Also, quote // the name of the variable, so that we don't need to // handle quoting keywords separately. - case x: ClassHandler => + case x: ClassHandler if !fallback => // I am trying to guess if the import is a defined class // This is an ugly hack, I am not 100% sure of the consequences. // Here we, let everything but "defined classes" use the import with val. diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index e2d8d5ff38dbe..c8763eb277052 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -256,6 +256,33 @@ class ReplSuite extends FunSuite { assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) } + + test("SPARK-2576 importing SQLContext.createSchemaRDD.") { + // We need to use local-cluster to test this case. + val output = runInterpreter("local-cluster[1,1,512]", + """ + |val sqlContext = new org.apache.spark.sql.SQLContext(sc) + |import sqlContext.createSchemaRDD + |case class TestCaseClass(value: Int) + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toSchemaRDD.collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + + test("SPARK-2632 importing a method from non serializable class and not using it.") { + val output = runInterpreter("local", + """ + |class TestClass() { def testMethod = 3 } + |val t = new TestClass + |import t.testMethod + |case class TestCaseClass(value: Int) + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + if (System.getenv("MESOS_NATIVE_LIBRARY") != null) { test("running on Mesos") { val output = runInterpreter("localquiet", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 72add5e20e8b4..c1154eb81c319 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.Map + import org.apache.spark.sql.catalyst.types._ /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 422839dab770d..3d41acb79e5fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.Map + import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.types._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index e030d6e13d472..e75373d5a74a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -182,7 +182,7 @@ class ScalaReflectionSuite extends FunSuite { assert(DecimalType === typeOfObject(BigDecimal("1.7976931348623157E318"))) // TimestampType - assert(TimestampType === typeOfObject(java.sql.Timestamp.valueOf("2014-7-25 10:26:00"))) + assert(TimestampType === typeOfObject(java.sql.Timestamp.valueOf("2014-07-25 10:26:00"))) // NullType assert(NullType === typeOfObject(null)) diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ArrayType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/ArrayType.java similarity index 90% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/ArrayType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/ArrayType.java index 17334ca31b2b7..b73a371e93001 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ArrayType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/ArrayType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing Lists. @@ -25,8 +25,8 @@ * {@code null} values. * * To create an {@link ArrayType}, - * {@link org.apache.spark.sql.api.java.types.DataType#createArrayType(DataType)} or - * {@link org.apache.spark.sql.api.java.types.DataType#createArrayType(DataType, boolean)} + * {@link DataType#createArrayType(DataType)} or + * {@link DataType#createArrayType(DataType, boolean)} * should be used. */ public class ArrayType extends DataType { diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BinaryType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/BinaryType.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/BinaryType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/BinaryType.java index 61703179850e9..7daad60f62a0b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BinaryType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/BinaryType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing byte[] values. diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BooleanType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/BooleanType.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/BooleanType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/BooleanType.java index 8fa24d85d1238..5a1f52725631b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BooleanType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/BooleanType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing boolean and Boolean values. diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ByteType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/ByteType.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/ByteType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/ByteType.java index 2de32978e2705..e5cdf06b21bbe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ByteType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/ByteType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing byte and Byte values. diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java similarity index 99% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java index f84e5a490a905..3eccddef88134 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; import java.util.HashSet; import java.util.List; diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DecimalType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/DecimalType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java index 9250491a2d2ca..bc54c078d7a4e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DecimalType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing java.math.BigDecimal values. diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DoubleType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DoubleType.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/DoubleType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/DoubleType.java index 3e86917fddc4b..f0060d0bcf9f5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DoubleType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DoubleType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing double and Double values. diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/FloatType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/FloatType.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/FloatType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/FloatType.java index fa860d40176ef..4a6a37f69176a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/FloatType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/FloatType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing float and Float values. diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/IntegerType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/IntegerType.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/IntegerType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/IntegerType.java index bd973eca2c3ce..bfd70490bbbbb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/IntegerType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/IntegerType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing int and Integer values. diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/LongType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/LongType.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/LongType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/LongType.java index e00233304cefa..af13a46eb165c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/LongType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/LongType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing long and Long values. diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/MapType.java similarity index 91% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/MapType.java index 94936e2e4ee7a..063e6b34abc48 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/MapType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing Maps. A MapType object comprises two fields, @@ -27,8 +27,8 @@ * For values of a MapType column, keys are not allowed to have {@code null} values. * * To create a {@link MapType}, - * {@link org.apache.spark.sql.api.java.types.DataType#createMapType(DataType, DataType)} or - * {@link org.apache.spark.sql.api.java.types.DataType#createMapType(DataType, DataType, boolean)} + * {@link DataType#createMapType(DataType, DataType)} or + * {@link DataType#createMapType(DataType, DataType, boolean)} * should be used. */ public class MapType extends DataType { diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ShortType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/ShortType.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/ShortType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/ShortType.java index 98f9507acf121..7d7604b4e3d2d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ShortType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/ShortType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing short and Short values. diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StringType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/StringType.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/StringType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/StringType.java index b8e7dbe646071..f4ba0c07c9c6e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StringType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/StringType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing String values. diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructField.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java similarity index 94% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructField.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java index 54e9c11ea415e..b48e2a2c5f953 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructField.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * A StructField object represents a field in a StructType object. @@ -26,7 +26,7 @@ * values. * * To create a {@link StructField}, - * {@link org.apache.spark.sql.api.java.types.DataType#createStructField(String, DataType, boolean)} + * {@link DataType#createStructField(String, DataType, boolean)} * should be used. */ public class StructField { diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructType.java similarity index 86% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/StructType.java index 33a42f4b16265..a4b501efd9a10 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructType.java @@ -15,18 +15,17 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; import java.util.Arrays; -import java.util.List; /** * The data type representing Rows. * A StructType object comprises an array of StructFields. * * To create an {@link StructType}, - * {@link org.apache.spark.sql.api.java.types.DataType#createStructType(java.util.List)} or - * {@link org.apache.spark.sql.api.java.types.DataType#createStructType(StructField[])} + * {@link DataType#createStructType(java.util.List)} or + * {@link DataType#createStructType(StructField[])} * should be used. */ public class StructType extends DataType { diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/TimestampType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/TimestampType.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/TimestampType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/TimestampType.java index 65295779f71ec..06d44c731cdfe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/TimestampType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/TimestampType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing java.sql.Timestamp values. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package-info.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/package-info.java similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/package-info.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/package-info.java index 53603614518f5..67007a9f0d1a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package-info.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/package-info.java @@ -18,4 +18,4 @@ /** * Allows the execution of relational queries, including those expressed in SQL using Spark. */ -package org.apache.spark.sql; \ No newline at end of file +package org.apache.spark.sql.api.java; diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java deleted file mode 100644 index f169ac65e226f..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -/** - * Allows users to get and create Spark SQL data types. - */ -package org.apache.spark.sql.api.java.types; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index c1c18a0cd0ed6..809dd038f94aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -23,9 +23,8 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} -import org.apache.spark.sql.api.java.types.{StructType => JStructType} import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql._ +import org.apache.spark.sql.{SQLContext, StructType => SStructType} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} @@ -104,9 +103,9 @@ class JavaSQLContext(val sqlContext: SQLContext) { * provided schema. Otherwise, there will be runtime exception. */ @DeveloperApi - def applySchema(rowRDD: JavaRDD[Row], schema: JStructType): JavaSchemaRDD = { + def applySchema(rowRDD: JavaRDD[Row], schema: StructType): JavaSchemaRDD = { val scalaRowRDD = rowRDD.rdd.map(r => r.row) - val scalaSchema = asScalaDataType(schema).asInstanceOf[StructType] + val scalaSchema = asScalaDataType(schema).asInstanceOf[SStructType] val logicalPlan = SparkLogicalPlan(ExistingRdd(scalaSchema.toAttributes, scalaRowRDD))(sqlContext) new JavaSchemaRDD(sqlContext, logicalPlan) @@ -133,7 +132,7 @@ class JavaSQLContext(val sqlContext: SQLContext) { * returning the result as a JavaSchemaRDD. */ @Experimental - def jsonFile(path: String, schema: JStructType): JavaSchemaRDD = + def jsonFile(path: String, schema: StructType): JavaSchemaRDD = jsonRDD(sqlContext.sparkContext.textFile(path), schema) /** @@ -155,10 +154,10 @@ class JavaSQLContext(val sqlContext: SQLContext) { * returning the result as a JavaSchemaRDD. */ @Experimental - def jsonRDD(json: JavaRDD[String], schema: JStructType): JavaSchemaRDD = { + def jsonRDD(json: JavaRDD[String], schema: StructType): JavaSchemaRDD = { val appliedScalaSchema = Option(asScalaDataType(schema)).getOrElse( - JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[StructType] + JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[SStructType] val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) val logicalPlan = SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))(sqlContext) @@ -181,22 +180,37 @@ class JavaSQLContext(val sqlContext: SQLContext) { val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") fields.map { property => val (dataType, nullable) = property.getPropertyType match { - case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) - case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) - case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) - case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) - case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false) - case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false) - case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false) - case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false) - - case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true) - case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true) - case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true) - case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true) - case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true) - case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) - case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) + case c: Class[_] if c == classOf[java.lang.String] => + (org.apache.spark.sql.StringType, true) + case c: Class[_] if c == java.lang.Short.TYPE => + (org.apache.spark.sql.ShortType, false) + case c: Class[_] if c == java.lang.Integer.TYPE => + (org.apache.spark.sql.IntegerType, false) + case c: Class[_] if c == java.lang.Long.TYPE => + (org.apache.spark.sql.LongType, false) + case c: Class[_] if c == java.lang.Double.TYPE => + (org.apache.spark.sql.DoubleType, false) + case c: Class[_] if c == java.lang.Byte.TYPE => + (org.apache.spark.sql.ByteType, false) + case c: Class[_] if c == java.lang.Float.TYPE => + (org.apache.spark.sql.FloatType, false) + case c: Class[_] if c == java.lang.Boolean.TYPE => + (org.apache.spark.sql.BooleanType, false) + + case c: Class[_] if c == classOf[java.lang.Short] => + (org.apache.spark.sql.ShortType, true) + case c: Class[_] if c == classOf[java.lang.Integer] => + (org.apache.spark.sql.IntegerType, true) + case c: Class[_] if c == classOf[java.lang.Long] => + (org.apache.spark.sql.LongType, true) + case c: Class[_] if c == classOf[java.lang.Double] => + (org.apache.spark.sql.DoubleType, true) + case c: Class[_] if c == classOf[java.lang.Byte] => + (org.apache.spark.sql.ByteType, true) + case c: Class[_] if c == classOf[java.lang.Float] => + (org.apache.spark.sql.FloatType, true) + case c: Class[_] if c == classOf[java.lang.Boolean] => + (org.apache.spark.sql.BooleanType, true) } AttributeReference(property.getName, dataType, nullable)() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index 824574149858c..4d799b4038fdd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -22,7 +22,6 @@ import java.util.{List => JList} import org.apache.spark.Partitioner import org.apache.spark.api.java.{JavaRDDLike, JavaRDD} import org.apache.spark.api.java.function.{Function => JFunction} -import org.apache.spark.sql.api.java.types.StructType import org.apache.spark.sql.types.util.DataTypeConversions import org.apache.spark.sql.{SQLContext, SchemaRDD, SchemaRDDLike} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 74f5630fbddf1..c416a745739b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -154,6 +154,7 @@ private[sql] object ColumnBuilder { case STRING.typeId => new StringColumnBuilder case BINARY.typeId => new BinaryColumnBuilder case GENERIC.typeId => new GenericColumnBuilder + case TIMESTAMP.typeId => new TimestampColumnBuilder }).asInstanceOf[ColumnBuilder] builder.initialize(initialSize, columnName, useCompression) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5f1fe99f75c9d..8bec015c7b465 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -94,6 +94,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => + execution.HashOuterJoin( + leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil + case _ => Nil } } @@ -155,8 +159,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object BroadcastNestedLoopJoin extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, joinType, condition) => + val buildSide = + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) BuildRight else BuildLeft execution.BroadcastNestedLoopJoin( - planLater(left), planLater(right), joinType, condition) :: Nil + planLater(left), planLater(right), buildSide, joinType, condition) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 2750ddbce896f..82f0a74b630bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -72,7 +72,7 @@ trait HashJoin { while (buildIter.hasNext) { currentRow = buildIter.next() val rowKey = buildSideKeyGenerator(currentRow) - if(!rowKey.anyNull) { + if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) val matchList = if (existingMatchList == null) { val newMatchList = new ArrayBuffer[Row]() @@ -136,6 +136,185 @@ trait HashJoin { } } +/** + * Constant Value for Binary Join Node + */ +object HashOuterJoin { + val DUMMY_LIST = Seq[Row](null) + val EMPTY_LIST = Seq[Row]() +} + +/** + * :: DeveloperApi :: + * Performs a hash based outer join for two child relations by shuffling the data using + * the join keys. This operator requires loading the associated partition in both side into memory. + */ +@DeveloperApi +case class HashOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + def output = left.output ++ right.output + + // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala + // iterator for performance purpose. + + private[this] def leftOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val rightNullRow = new GenericRow(right.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + leftIter.iterator.flatMap { l => + joinedRow.withLeft(l) + var matched = false + (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => + matched = true + joinedRow.copy + } else { + Nil + }) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { + // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the + // records in right side. + // If we didn't get any proper row, then append a single row with empty right + joinedRow.withRight(rightNullRow).copy + }) + } + } + + private[this] def rightOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val leftNullRow = new GenericRow(left.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + rightIter.iterator.flatMap { r => + joinedRow.withRight(r) + var matched = false + (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => + matched = true + joinedRow.copy + } else { + Nil + }) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { + // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the + // records in left side. + // If we didn't get any proper row, then append a single row with empty left. + joinedRow.withLeft(leftNullRow).copy + }) + } + } + + private[this] def fullOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val leftNullRow = new GenericRow(left.output.length) + val rightNullRow = new GenericRow(right.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + if (!key.anyNull) { + // Store the positions of records in right, if one of its associated row satisfy + // the join condition. + val rightMatchedSet = scala.collection.mutable.Set[Int]() + leftIter.iterator.flatMap[Row] { l => + joinedRow.withLeft(l) + var matched = false + rightIter.zipWithIndex.collect { + // 1. For those matched (satisfy the join condition) records with both sides filled, + // append them directly + + case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { + matched = true + // if the row satisfy the join condition, add its index into the matched set + rightMatchedSet.add(idx) + joinedRow.copy + } + } ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { + // 2. For those unmatched records in left, append additional records with empty right. + + // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all + // of the records in right side. + // If we didn't get any proper row, then append a single row with empty right. + joinedRow.withRight(rightNullRow).copy + }) + } ++ rightIter.zipWithIndex.collect { + // 3. For those unmatched records in right, append additional records with empty left. + + // Re-visiting the records in right, and append additional row with empty left, if its not + // in the matched set. + case (r, idx) if (!rightMatchedSet.contains(idx)) => { + joinedRow(leftNullRow, r).copy + } + } + } else { + leftIter.iterator.map[Row] { l => + joinedRow(l, rightNullRow).copy + } ++ rightIter.iterator.map[Row] { r => + joinedRow(leftNullRow, r).copy + } + } + } + + private[this] def buildHashTable( + iter: Iterator[Row], keyGenerator: Projection): Map[Row, ArrayBuffer[Row]] = { + // TODO: Use Spark's HashMap implementation. + val hashTable = scala.collection.mutable.Map[Row, ArrayBuffer[Row]]() + while (iter.hasNext) { + val currentRow = iter.next() + val rowKey = keyGenerator(currentRow) + + val existingMatchList = hashTable.getOrElseUpdate(rowKey, {new ArrayBuffer[Row]()}) + existingMatchList += currentRow.copy() + } + + hashTable.toMap[Row, ArrayBuffer[Row]] + } + + def execute() = { + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + // TODO this probably can be replaced by external sort (sort merged join?) + // Build HashMap for current partition in left relation + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + // Build HashMap for current partition in right relation + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + joinType match { + case LeftOuter => leftHashTable.keysIterator.flatMap { key => + leftOuterIterator(key, leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), + rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + } + case RightOuter => rightHashTable.keysIterator.flatMap { key => + rightOuterIterator(key, leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), + rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + } + case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => + fullOuterIterator(key, + leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), + rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + } + case x => throw new Exception(s"Need to add implementation for $x") + } + } + } +} + /** * :: DeveloperApi :: * Performs an inner hash join of two child relations by first shuffling the data using the join @@ -189,7 +368,7 @@ case class LeftSemiJoinHash( while (buildIter.hasNext) { currentRow = buildIter.next() val rowKey = buildSideKeyGenerator(currentRow) - if(!rowKey.anyNull) { + if (!rowKey.anyNull) { val keyExists = hashSet.contains(rowKey) if (!keyExists) { hashSet.add(rowKey) @@ -314,10 +493,19 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod */ @DeveloperApi case class BroadcastNestedLoopJoin( - streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression]) - extends BinaryNode { + left: SparkPlan, + right: SparkPlan, + buildSide: BuildSide, + joinType: JoinType, + condition: Option[Expression]) extends BinaryNode { // TODO: Override requiredChildDistribution. + /** BuildRight means the right relation <=> the broadcast relation. */ + val (streamed, broadcast) = buildSide match { + case BuildRight => (left, right) + case BuildLeft => (right, left) + } + override def outputPartitioning: Partitioning = streamed.outputPartitioning override def output = { @@ -333,11 +521,6 @@ case class BroadcastNestedLoopJoin( } } - /** The Streamed Relation */ - def left = streamed - /** The Broadcast relation */ - def right = broadcast - @transient lazy val boundCondition = InterpretedPredicate( condition @@ -348,57 +531,78 @@ case class BroadcastNestedLoopJoin( val broadcastedRelation = sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) - val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter => + /** All rows that either match both-way, or rows from streamed joined with nulls. */ + val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => val matchedRows = new ArrayBuffer[Row] // TODO: Use Spark's BitSet. - val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size) + val includedBroadcastTuples = + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow + val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) streamedIter.foreach { streamedRow => var i = 0 - var matched = false + var streamRowMatched = false while (i < broadcastedRelation.value.size) { // TODO: One bitset per partition instead of per row. val broadcastedRow = broadcastedRelation.value(i) - if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { - matchedRows += joinedRow(streamedRow, broadcastedRow).copy() - matched = true - includedBroadcastTuples += i + buildSide match { + case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => + matchedRows += joinedRow(streamedRow, broadcastedRow).copy() + streamRowMatched = true + includedBroadcastTuples += i + case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => + matchedRows += joinedRow(broadcastedRow, streamedRow).copy() + streamRowMatched = true + includedBroadcastTuples += i + case _ => } i += 1 } - if (!matched && (joinType == LeftOuter || joinType == FullOuter)) { - matchedRows += joinedRow(streamedRow, rightNulls).copy() + (streamRowMatched, joinType, buildSide) match { + case (false, LeftOuter | FullOuter, BuildRight) => + matchedRows += joinedRow(streamedRow, rightNulls).copy() + case (false, RightOuter | FullOuter, BuildLeft) => + matchedRows += joinedRow(leftNulls, streamedRow).copy() + case _ => } } Iterator((matchedRows, includedBroadcastTuples)) } - val includedBroadcastTuples = streamedPlusMatches.map(_._2) + val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) val allIncludedBroadcastTuples = if (includedBroadcastTuples.count == 0) { new scala.collection.mutable.BitSet(broadcastedRelation.value.size) } else { - streamedPlusMatches.map(_._2).reduce(_ ++ _) + includedBroadcastTuples.reduce(_ ++ _) } val leftNulls = new GenericMutableRow(left.output.size) - val rightOuterMatches: Seq[Row] = - if (joinType == RightOuter || joinType == FullOuter) { - broadcastedRelation.value.zipWithIndex.filter { - case (row, i) => !allIncludedBroadcastTuples.contains(i) - }.map { - case (row, _) => new JoinedRow(leftNulls, row) + val rightNulls = new GenericMutableRow(right.output.size) + /** Rows from broadcasted joined with nulls. */ + val broadcastRowsWithNulls: Seq[Row] = { + val arrBuf: collection.mutable.ArrayBuffer[Row] = collection.mutable.ArrayBuffer() + var i = 0 + val rel = broadcastedRelation.value + while (i < rel.length) { + if (!allIncludedBroadcastTuples.contains(i)) { + (joinType, buildSide) match { + case (RightOuter | FullOuter, BuildRight) => arrBuf += new JoinedRow(leftNulls, rel(i)) + case (LeftOuter | FullOuter, BuildLeft) => arrBuf += new JoinedRow(rel(i), rightNulls) + case _ => + } } - } else { - Vector() + i += 1 } + arrBuf.toSeq + } // TODO: Breaks lineage. sparkContext.union( - streamedPlusMatches.flatMap(_._1), sparkContext.makeRDD(rightOuterMatches)) + matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index bd29ee421bbc4..70db1ebd3a3e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.json +import scala.collection.Map import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} import scala.math.BigDecimal diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index d1aa3c8d53757..77353f4eb0227 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.types.util import org.apache.spark.sql._ -import org.apache.spark.sql.api.java.types.{DataType => JDataType, StructField => JStructField} +import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField} import scala.collection.JavaConverters._ @@ -74,37 +74,37 @@ protected[sql] object DataTypeConversions { * Returns the equivalent DataType in Scala for the given DataType in Java. */ def asScalaDataType(javaDataType: JDataType): DataType = javaDataType match { - case stringType: org.apache.spark.sql.api.java.types.StringType => + case stringType: org.apache.spark.sql.api.java.StringType => StringType - case binaryType: org.apache.spark.sql.api.java.types.BinaryType => + case binaryType: org.apache.spark.sql.api.java.BinaryType => BinaryType - case booleanType: org.apache.spark.sql.api.java.types.BooleanType => + case booleanType: org.apache.spark.sql.api.java.BooleanType => BooleanType - case timestampType: org.apache.spark.sql.api.java.types.TimestampType => + case timestampType: org.apache.spark.sql.api.java.TimestampType => TimestampType - case decimalType: org.apache.spark.sql.api.java.types.DecimalType => + case decimalType: org.apache.spark.sql.api.java.DecimalType => DecimalType - case doubleType: org.apache.spark.sql.api.java.types.DoubleType => + case doubleType: org.apache.spark.sql.api.java.DoubleType => DoubleType - case floatType: org.apache.spark.sql.api.java.types.FloatType => + case floatType: org.apache.spark.sql.api.java.FloatType => FloatType - case byteType: org.apache.spark.sql.api.java.types.ByteType => + case byteType: org.apache.spark.sql.api.java.ByteType => ByteType - case integerType: org.apache.spark.sql.api.java.types.IntegerType => + case integerType: org.apache.spark.sql.api.java.IntegerType => IntegerType - case longType: org.apache.spark.sql.api.java.types.LongType => + case longType: org.apache.spark.sql.api.java.LongType => LongType - case shortType: org.apache.spark.sql.api.java.types.ShortType => + case shortType: org.apache.spark.sql.api.java.ShortType => ShortType - case arrayType: org.apache.spark.sql.api.java.types.ArrayType => + case arrayType: org.apache.spark.sql.api.java.ArrayType => ArrayType(asScalaDataType(arrayType.getElementType), arrayType.isContainsNull) - case mapType: org.apache.spark.sql.api.java.types.MapType => + case mapType: org.apache.spark.sql.api.java.MapType => MapType( asScalaDataType(mapType.getKeyType), asScalaDataType(mapType.getValueType), mapType.isValueContainsNull) - case structType: org.apache.spark.sql.api.java.types.StructType => + case structType: org.apache.spark.sql.api.java.StructType => StructType(structType.getFields.map(asScalaStructField)) } } diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java index 8ee4591105010..3c92906d82864 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java @@ -28,9 +28,6 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.sql.api.java.types.DataType; -import org.apache.spark.sql.api.java.types.StructField; -import org.apache.spark.sql.api.java.types.StructType; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java index 96a503962f7d1..d099a48a1f4b6 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java @@ -24,8 +24,6 @@ import org.junit.Test; import org.apache.spark.sql.types.util.DataTypeConversions; -import org.apache.spark.sql.api.java.types.DataType; -import org.apache.spark.sql.api.java.types.StructField; public class JavaSideDataTypeConversionSuite { public void checkDataType(DataType javaDataType) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 025c396ef0629..037890682f7b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -17,11 +17,17 @@ package org.apache.spark.sql +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner} +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner, LeftSemi} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ -class JoinSuite extends QueryTest { +class JoinSuite extends QueryTest with BeforeAndAfterEach { // Ensures tables are loaded. TestData @@ -34,6 +40,56 @@ class JoinSuite extends QueryTest { assert(planned.size === 1) } + test("join operator selection") { + def assertJoin(sqlString: String, c: Class[_]): Any = { + val rdd = sql(sqlString) + val physical = rdd.queryExecution.sparkPlan + val operators = physical.collect { + case j: ShuffledHashJoin => j + case j: HashOuterJoin => j + case j: LeftSemiJoinHash => j + case j: BroadcastHashJoin => j + case j: LeftSemiJoinBNL => j + case j: CartesianProduct => j + case j: BroadcastNestedLoopJoin => j + } + + assert(operators.size === 1) + if (operators(0).getClass() != c) { + fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical") + } + } + + val cases1 = Seq( + ("SELECT * FROM testData left semi join testData2 ON key = a", classOf[LeftSemiJoinHash]), + ("SELECT * FROM testData left semi join testData2", classOf[LeftSemiJoinBNL]), + ("SELECT * FROM testData join testData2", classOf[CartesianProduct]), + ("SELECT * FROM testData join testData2 where key=2", classOf[CartesianProduct]), + ("SELECT * FROM testData left join testData2", classOf[CartesianProduct]), + ("SELECT * FROM testData right join testData2", classOf[CartesianProduct]), + ("SELECT * FROM testData full outer join testData2", classOf[CartesianProduct]), + ("SELECT * FROM testData left join testData2 where key=2", classOf[CartesianProduct]), + ("SELECT * FROM testData right join testData2 where key=2", classOf[CartesianProduct]), + ("SELECT * FROM testData full outer join testData2 where key=2", classOf[CartesianProduct]), + ("SELECT * FROM testData join testData2 where key>a", classOf[CartesianProduct]), + ("SELECT * FROM testData full outer join testData2 where key>a", classOf[CartesianProduct]), + ("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData left join testData2 ON key = a", classOf[HashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a where key=2", + classOf[HashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key=2", + classOf[HashOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin]), + ("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin]) + // TODO add BroadcastNestedLoopJoin + ) + cases1.foreach { c => assertJoin(c._1, c._2) } + } + test("multiple-key equi-join is hash-join") { val x = testData2.as('x) val y = testData2.as('y) @@ -114,6 +170,33 @@ class JoinSuite extends QueryTest { (4, "D", 4, "d") :: (5, "E", null, null) :: (6, "F", null, null) :: Nil) + + checkAnswer( + upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)), + (1, "A", null, null) :: + (2, "B", 2, "b") :: + (3, "C", 3, "c") :: + (4, "D", 4, "d") :: + (5, "E", null, null) :: + (6, "F", null, null) :: Nil) + + checkAnswer( + upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)), + (1, "A", null, null) :: + (2, "B", 2, "b") :: + (3, "C", 3, "c") :: + (4, "D", 4, "d") :: + (5, "E", null, null) :: + (6, "F", null, null) :: Nil) + + checkAnswer( + upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)), + (1, "A", 1, "a") :: + (2, "B", 2, "b") :: + (3, "C", 3, "c") :: + (4, "D", 4, "d") :: + (5, "E", null, null) :: + (6, "F", null, null) :: Nil) } test("right outer join") { @@ -125,11 +208,38 @@ class JoinSuite extends QueryTest { (4, "d", 4, "D") :: (null, null, 5, "E") :: (null, null, 6, "F") :: Nil) + checkAnswer( + lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'n > 1)), + (null, null, 1, "A") :: + (2, "b", 2, "B") :: + (3, "c", 3, "C") :: + (4, "d", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + checkAnswer( + lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'N > 1)), + (null, null, 1, "A") :: + (2, "b", 2, "B") :: + (3, "c", 3, "C") :: + (4, "d", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + checkAnswer( + lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'l > 'L)), + (1, "a", 1, "A") :: + (2, "b", 2, "B") :: + (3, "c", 3, "C") :: + (4, "d", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) } test("full outer join") { - val left = upperCaseData.where('N <= 4).as('left) - val right = upperCaseData.where('N >= 3).as('right) + upperCaseData.where('N <= 4).registerAsTable("left") + upperCaseData.where('N >= 3).registerAsTable("right") + + val left = UnresolvedRelation(None, "left", None) + val right = UnresolvedRelation(None, "right", None) checkAnswer( left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)), @@ -139,5 +249,25 @@ class JoinSuite extends QueryTest { (4, "D", 4, "D") :: (null, null, 5, "E") :: (null, null, 6, "F") :: Nil) + + checkAnswer( + left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))), + (1, "A", null, null) :: + (2, "B", null, null) :: + (3, "C", null, null) :: + (null, null, 3, "C") :: + (4, "D", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + + checkAnswer( + left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))), + (1, "A", null, null) :: + (2, "B", null, null) :: + (3, "C", null, null) :: + (null, null, 3, "C") :: + (4, "D", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index bebb490645420..5c571d35d1bb9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -505,5 +505,24 @@ class SQLQuerySuite extends QueryTest { (2, null) :: (3, null) :: (4, 2147483644) :: Nil) + + // The value of a MapType column can be a mutable map. + val rowRDD3 = unparsedStrings.map { r => + val values = r.split(",").map(_.trim) + val v4 = try values(3).toInt catch { + case _: NumberFormatException => null + } + Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) + } + + val schemaRDD3 = applySchema(rowRDD3, schema2) + schemaRDD3.registerAsTable("applySchema3") + + checkAnswer( + sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), + (1, null) :: + (2, null) :: + (3, null) :: + (4, 2147483644) :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala index 46de6fe239228..ff1debff0f8c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.sql.api.java import org.apache.spark.sql.types.util.DataTypeConversions import org.scalatest.FunSuite -import org.apache.spark.sql._ +import org.apache.spark.sql.{DataType => SDataType, StructField => SStructField} +import org.apache.spark.sql.{StructType => SStructType} import DataTypeConversions._ class ScalaSideDataTypeConversionSuite extends FunSuite { - def checkDataType(scalaDataType: DataType) { + def checkDataType(scalaDataType: SDataType) { val javaDataType = asJavaDataType(scalaDataType) val actual = asScalaDataType(javaDataType) assert(scalaDataType === actual, s"Converted data type ${actual} " + @@ -34,48 +35,52 @@ class ScalaSideDataTypeConversionSuite extends FunSuite { test("convert data types") { // Simple DataTypes. - checkDataType(StringType) - checkDataType(BinaryType) - checkDataType(BooleanType) - checkDataType(TimestampType) - checkDataType(DecimalType) - checkDataType(DoubleType) - checkDataType(FloatType) - checkDataType(ByteType) - checkDataType(IntegerType) - checkDataType(LongType) - checkDataType(ShortType) + checkDataType(org.apache.spark.sql.StringType) + checkDataType(org.apache.spark.sql.BinaryType) + checkDataType(org.apache.spark.sql.BooleanType) + checkDataType(org.apache.spark.sql.TimestampType) + checkDataType(org.apache.spark.sql.DecimalType) + checkDataType(org.apache.spark.sql.DoubleType) + checkDataType(org.apache.spark.sql.FloatType) + checkDataType(org.apache.spark.sql.ByteType) + checkDataType(org.apache.spark.sql.IntegerType) + checkDataType(org.apache.spark.sql.LongType) + checkDataType(org.apache.spark.sql.ShortType) // Simple ArrayType. - val simpleScalaArrayType = ArrayType(StringType, true) + val simpleScalaArrayType = + org.apache.spark.sql.ArrayType(org.apache.spark.sql.StringType, true) checkDataType(simpleScalaArrayType) // Simple MapType. - val simpleScalaMapType = MapType(StringType, LongType) + val simpleScalaMapType = + org.apache.spark.sql.MapType(org.apache.spark.sql.StringType, org.apache.spark.sql.LongType) checkDataType(simpleScalaMapType) // Simple StructType. - val simpleScalaStructType = StructType( - StructField("a", DecimalType, false) :: - StructField("b", BooleanType, true) :: - StructField("c", LongType, true) :: - StructField("d", BinaryType, false) :: Nil) + val simpleScalaStructType = SStructType( + SStructField("a", org.apache.spark.sql.DecimalType, false) :: + SStructField("b", org.apache.spark.sql.BooleanType, true) :: + SStructField("c", org.apache.spark.sql.LongType, true) :: + SStructField("d", org.apache.spark.sql.BinaryType, false) :: Nil) checkDataType(simpleScalaStructType) // Complex StructType. - val complexScalaStructType = StructType( - StructField("simpleArray", simpleScalaArrayType, true) :: - StructField("simpleMap", simpleScalaMapType, true) :: - StructField("simpleStruct", simpleScalaStructType, true) :: - StructField("boolean", BooleanType, false) :: Nil) + val complexScalaStructType = SStructType( + SStructField("simpleArray", simpleScalaArrayType, true) :: + SStructField("simpleMap", simpleScalaMapType, true) :: + SStructField("simpleStruct", simpleScalaStructType, true) :: + SStructField("boolean", org.apache.spark.sql.BooleanType, false) :: Nil) checkDataType(complexScalaStructType) // Complex ArrayType. - val complexScalaArrayType = ArrayType(complexScalaStructType, true) + val complexScalaArrayType = + org.apache.spark.sql.ArrayType(complexScalaStructType, true) checkDataType(complexScalaArrayType) // Complex MapType. - val complexScalaMapType = MapType(complexScalaStructType, complexScalaArrayType, false) + val complexScalaMapType = + org.apache.spark.sql.MapType(complexScalaStructType, complexScalaArrayType, false) checkDataType(complexScalaMapType) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 35ab14cbc353d..3baa6f8ec0c83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -41,7 +41,7 @@ object TestNullableColumnAccessor { class NullableColumnAccessorSuite extends FunSuite { import ColumnarTestUtils._ - Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC).foreach { + Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, TIMESTAMP).foreach { testNullableColumnAccessor(_) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index d8898527baa39..dc813fe146c47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -37,7 +37,7 @@ object TestNullableColumnBuilder { class NullableColumnBuilderSuite extends FunSuite { import ColumnarTestUtils._ - Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC).foreach { + Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, TIMESTAMP).foreach { testNullableColumnBuilder(_) } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 27268ecb923e9..cb17d7ce58ea0 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -288,8 +288,10 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { out.println(cmd) } - ret = driver.run(cmd).getResponseCode + val rc = driver.run(cmd) + ret = rc.getResponseCode if (ret != 0) { + console.printError(rc.getErrorMessage()) driver.close() return ret } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 5202aa9903e03..a56b19a4bcda0 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -53,10 +53,9 @@ private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveCo } override def run(command: String): CommandProcessorResponse = { - val execution = context.executePlan(context.hql(command).logicalPlan) - // TODO unify the error code try { + val execution = context.executePlan(context.hql(command).logicalPlan) hiveResponse = execution.stringResult() tableSchema = getResultSetSchema(execution) new CommandProcessorResponse(0) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 27b444daba2d4..7e3b8727bebed 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -131,12 +131,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { @transient protected[hive] lazy val sessionState = { val ss = new SessionState(hiveconf) set(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf. + + ss.err = new PrintStream(outputBuffer, true, "UTF-8") + ss.out = new PrintStream(outputBuffer, true, "UTF-8") + ss } - sessionState.err = new PrintStream(outputBuffer, true, "UTF-8") - sessionState.out = new PrintStream(outputBuffer, true, "UTF-8") - override def set(key: String, value: String): Unit = { super.set(key, value) runSqlHive(s"SET $key=$value") diff --git a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-1436cccda63b78dd6e43a399da6cc474 b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-1436cccda63b78dd6e43a399da6cc474 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-1-8d9bf54373f45bc35f8cb6e82771b154 b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-1-8d9bf54373f45bc35f8cb6e82771b154 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-2-7816c17905012cf381abf93d230faa8d b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-2-7816c17905012cf381abf93d230faa8d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-3-90089a6db3c3d8ee5ff5ea6b9153b3cc b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-3-90089a6db3c3d8ee5ff5ea6b9153b3cc new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-8caed2a6e80250a6d38a59388679c298 b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-4-8caed2a6e80250a6d38a59388679c298 similarity index 100% rename from sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-8caed2a6e80250a6d38a59388679c298 rename to sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-4-8caed2a6e80250a6d38a59388679c298 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index bcb00f871d185..c5736723b47c0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -17,32 +17,25 @@ package org.apache.spark.sql.hive.execution -import org.scalatest.{BeforeAndAfterAll, FunSuite} - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.hive.test.TestHive - class HiveTableScanSuite extends HiveComparisonTest { - // MINOR HACK: You must run a query before calling reset the first time. - TestHive.hql("SHOW TABLES") - TestHive.reset() - - TestHive.hql("""CREATE TABLE part_scan_test (key STRING, value STRING) PARTITIONED BY (ds STRING) - | ROW FORMAT SERDE - | 'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe' - | STORED AS RCFILE - """.stripMargin) - TestHive.hql("""FROM src - | INSERT INTO TABLE part_scan_test PARTITION (ds='2010-01-01') - | SELECT 100,100 LIMIT 1 - """.stripMargin) - TestHive.hql("""ALTER TABLE part_scan_test SET SERDE - | 'org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe' - """.stripMargin) - TestHive.hql("""FROM src INSERT INTO TABLE part_scan_test PARTITION (ds='2010-01-02') - | SELECT 200,200 LIMIT 1 - """.stripMargin) - createQueryTest("partition_based_table_scan_with_different_serde", - "SELECT * from part_scan_test", false) + createQueryTest("partition_based_table_scan_with_different_serde", + """ + |CREATE TABLE part_scan_test (key STRING, value STRING) PARTITIONED BY (ds STRING) + |ROW FORMAT SERDE + |'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe' + |STORED AS RCFILE; + | + |FROM src + |INSERT INTO TABLE part_scan_test PARTITION (ds='2010-01-01') + |SELECT 100,100 LIMIT 1; + | + |ALTER TABLE part_scan_test SET SERDE + |'org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe'; + | + |FROM src INSERT INTO TABLE part_scan_test PARTITION (ds='2010-01-02') + |SELECT 200,200 LIMIT 1; + | + |SELECT * from part_scan_test; + """.stripMargin) } diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 16ff89a8a9809..bcf6d43ab34eb 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -114,9 +114,10 @@ object GenerateMIMAIgnore { private def getAnnotatedOrPackagePrivateMembers(classSymbol: unv.ClassSymbol) = { classSymbol.typeSignature.members.filterNot(x => - x.fullName.startsWith("java") || x.fullName.startsWith("scala")) - .filter(x => isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x)).map(_.fullName) ++ - getInnerFunctions(classSymbol) + x.fullName.startsWith("java") || x.fullName.startsWith("scala") + ).filter(x => + isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x) + ).map(_.fullName) ++ getInnerFunctions(classSymbol) } def main(args: Array[String]) {