From 29660443077619ee854025b8d0d3d64181724054 Mon Sep 17 00:00:00 2001 From: joyyoj Date: Tue, 10 Jun 2014 17:26:17 -0700 Subject: [PATCH 01/38] [SPARK-1998] SparkFlumeEvent with body bigger than 1020 bytes are not re... flume event sent to Spark will fail if the body is too large and numHeaders is greater than zero Author: joyyoj Closes #951 from joyyoj/master and squashes the following commits: f4660c5 [joyyoj] [SPARK-1998] SparkFlumeEvent with body bigger than 1020 bytes are not read properly --- .../org/apache/spark/streaming/flume/FlumeInputDStream.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 5be33f1d5c428..ed35e34ad45ab 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -71,12 +71,12 @@ class SparkFlumeEvent() extends Externalizable { for (i <- 0 until numHeaders) { val keyLength = in.readInt() val keyBuff = new Array[Byte](keyLength) - in.read(keyBuff) + in.readFully(keyBuff) val key : String = Utils.deserialize(keyBuff) val valLength = in.readInt() val valBuff = new Array[Byte](valLength) - in.read(valBuff) + in.readFully(valBuff) val value : String = Utils.deserialize(valBuff) headers.put(key, value) From 4823bf470ec1b47a6f404834d4453e61d3dcbec9 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 10 Jun 2014 20:22:02 -0700 Subject: [PATCH 02/38] [SPARK-1940] Enabling rolling of executor logs, and automatic cleanup of old executor logs Currently, in the default log4j configuration, all the executor logs get sent to the file [executor-working-dir]/stderr. This does not all log files to be rolled, so old logs cannot be removed. Using log4j RollingFileAppender allows log4j logs to be rolled, but all the logs get sent to a different set of files, other than the files stdout and stderr . So the logs are not visible in the Spark web UI any more as Spark web UI only reads the files stdout and stderr. Furthermore, it still does not allow the stdout and stderr to be cleared periodically in case a large amount of stuff gets written to them (e.g. by explicit `println` inside map function). This PR solves this by implementing a simple `RollingFileAppender` within Spark (disabled by default). When enabled (using configuration parameter `spark.executor.rollingLogs.enabled`), the logs can get rolled over either by time interval (set with `spark.executor.rollingLogs.interval`, set to daily by default), or by size of logs (set with `spark.executor.rollingLogs.size`). Finally, old logs can be automatically deleted by specifying how many of the latest log files to keep (set with `spark.executor.rollingLogs.keepLastN`). The web UI has also been modified to show the logs across the rolled-over files. You can test this locally (without waiting a whole day) by setting configuration `spark.executor.rollingLogs.enabled=true` and `spark.executor.rollingLogs.interval=minutely`. Continuously generate logs by running spark jobs and the generated logs files would look like this (`stderr` and `stdout` are the most current log file that are being written to). ``` stderr stderr--2014-05-27--14-37 stderr--2014-05-27--14-47 stderr--2014-05-27--15-05 stdout stdout--2014-05-27--14-47 ``` The web ui should show logs across these files. Author: Tathagata Das Closes #895 from tdas/rolling-logs and squashes the following commits: fd8f87f [Tathagata Das] Minor change. d326aee [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into rolling-logs ad956c1 [Tathagata Das] Scala style fix. 1f0a6ec [Tathagata Das] Some more changes based on Patrick's PR comments. c8bfe4e [Tathagata Das] Refactore FileAppender to a package spark.util.logging and broke up the file into multiple files. Changed configuration parameter names. 4224409 [Tathagata Das] Style fix. 108a9f8 [Tathagata Das] Added better constraint handling for rolling policies. f7da977 [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into rolling-logs 9134495 [Tathagata Das] Simplified rolling logs by removing Daily/Hourly/MinutelyRollingFileAppender, and removing the setting rollingLogs.enabled 312d874 [Tathagata Das] Minor fixes based on PR comments. 8a67d83 [Tathagata Das] Fixed comments. b36cfd6 [Tathagata Das] Implemented RollingPolicy, TimeBasedRollingPolicy and SizeBasedRollingPolicy, and changed RollingFileAppender accordingly. b7e8272 [Tathagata Das] Style fix, 374c9a9 [Tathagata Das] Added missing license. 24354ea [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into rolling-logs 6cc09c7 [Tathagata Das] Fixed bugs in rolling logs, and added more debug statements. adf4910 [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into rolling-logs 931f8fb [Tathagata Das] Changed log viewer in Spark web UI to handle rolling log files. cb4fb6d [Tathagata Das] Added FileAppender and RollingFileAppender to generate rolling executor logs. --- .../spark/deploy/worker/ExecutorRunner.scala | 17 +- .../apache/spark/deploy/worker/Worker.scala | 2 +- .../spark/deploy/worker/ui/LogPage.scala | 78 +++--- .../scala/org/apache/spark/util/Utils.scala | 53 +++++ .../spark/util/logging/FileAppender.scala | 180 ++++++++++++++ .../util/logging/RollingFileAppender.scala | 163 +++++++++++++ .../spark/util/logging/RollingPolicy.scala | 139 +++++++++++ .../spark/deploy/JsonProtocolSuite.scala | 4 +- .../deploy/worker/ExecutorRunnerTest.scala | 3 +- .../apache/spark/util/FileAppenderSuite.scala | 225 ++++++++++++++++++ .../org/apache/spark/util/UtilsSuite.scala | 32 +++ docs/configuration.md | 39 +++ 12 files changed, 895 insertions(+), 40 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala create mode 100644 core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala create mode 100644 core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala create mode 100644 core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index d27e0e1f15c65..d09136de49807 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -23,9 +23,10 @@ import akka.actor.ActorRef import com.google.common.base.Charsets import com.google.common.io.Files -import org.apache.spark.Logging +import org.apache.spark.{SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged +import org.apache.spark.util.logging.FileAppender /** * Manages the execution of one executor process. @@ -42,12 +43,15 @@ private[spark] class ExecutorRunner( val sparkHome: File, val workDir: File, val workerUrl: String, + val conf: SparkConf, var state: ExecutorState.Value) extends Logging { val fullId = appId + "/" + execId var workerThread: Thread = null var process: Process = null + var stdoutAppender: FileAppender = null + var stderrAppender: FileAppender = null // NOTE: This is now redundant with the automated shut-down enforced by the Executor. It might // make sense to remove this in the future. @@ -76,6 +80,13 @@ private[spark] class ExecutorRunner( if (process != null) { logInfo("Killing process!") process.destroy() + process.waitFor() + if (stdoutAppender != null) { + stdoutAppender.stop() + } + if (stderrAppender != null) { + stderrAppender.stop() + } val exitCode = process.waitFor() worker ! ExecutorStateChanged(appId, execId, state, message, Some(exitCode)) } @@ -137,11 +148,11 @@ private[spark] class ExecutorRunner( // Redirect its stdout and stderr to files val stdout = new File(executorDir, "stdout") - CommandUtils.redirectStream(process.getInputStream, stdout) + stdoutAppender = FileAppender(process.getInputStream, stdout, conf) val stderr = new File(executorDir, "stderr") Files.write(header, stderr, Charsets.UTF_8) - CommandUtils.redirectStream(process.getErrorStream, stderr) + stderrAppender = FileAppender(process.getErrorStream, stderr, conf) // Wait for it to exit; this is actually a bad thing if it happens, because we expect to run // long-lived processes only. However, in the future, we might restart the executor a few diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 100de26170a50..a0ecaf709f8e2 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -235,7 +235,7 @@ private[spark] class Worker( val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_, self, workerId, host, appDesc.sparkHome.map(userSparkHome => new File(userSparkHome)).getOrElse(sparkHome), - workDir, akkaUrl, ExecutorState.RUNNING) + workDir, akkaUrl, conf, ExecutorState.RUNNING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 8381f59672ea3..6a5ffb1b71bfb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -24,8 +24,10 @@ import scala.xml.Node import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils +import org.apache.spark.Logging +import org.apache.spark.util.logging.{FileAppender, RollingFileAppender} -private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") { +private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with Logging { private val worker = parent.worker private val workDir = parent.workDir @@ -39,21 +41,18 @@ private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") { val offset = Option(request.getParameter("offset")).map(_.toLong) val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) - val path = (appId, executorId, driverId) match { + val logDir = (appId, executorId, driverId) match { case (Some(a), Some(e), None) => - s"${workDir.getPath}/$appId/$executorId/$logType" + s"${workDir.getPath}/$appId/$executorId/" case (None, None, Some(d)) => - s"${workDir.getPath}/$driverId/$logType" + s"${workDir.getPath}/$driverId/" case _ => throw new Exception("Request must specify either application or driver identifiers") } - val (startByte, endByte) = getByteRange(path, offset, byteLength) - val file = new File(path) - val logLength = file.length - - val pre = s"==== Bytes $startByte-$endByte of $logLength of $path ====\n" - pre + Utils.offsetBytes(path, startByte, endByte) + val (logText, startByte, endByte, logLength) = getLog(logDir, logType, offset, byteLength) + val pre = s"==== Bytes $startByte-$endByte of $logLength of $logDir$logType ====\n" + pre + logText } def render(request: HttpServletRequest): Seq[Node] = { @@ -65,19 +64,16 @@ private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") { val offset = Option(request.getParameter("offset")).map(_.toLong) val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) - val (path, params) = (appId, executorId, driverId) match { + val (logDir, params) = (appId, executorId, driverId) match { case (Some(a), Some(e), None) => - (s"${workDir.getPath}/$a/$e/$logType", s"appId=$a&executorId=$e") + (s"${workDir.getPath}/$a/$e/", s"appId=$a&executorId=$e") case (None, None, Some(d)) => - (s"${workDir.getPath}/$d/$logType", s"driverId=$d") + (s"${workDir.getPath}/$d/", s"driverId=$d") case _ => throw new Exception("Request must specify either application or driver identifiers") } - val (startByte, endByte) = getByteRange(path, offset, byteLength) - val file = new File(path) - val logLength = file.length - val logText = {Utils.offsetBytes(path, startByte, endByte)} + val (logText, startByte, endByte, logLength) = getLog(logDir, logType, offset, byteLength) val linkToMaster =

Back to Master

val range = Bytes {startByte.toString} - {endByte.toString} of {logLength} @@ -127,23 +123,37 @@ private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") { UIUtils.basicSparkPage(content, logType + " log page for " + appId) } - /** Determine the byte range for a log or log page. */ - private def getByteRange(path: String, offset: Option[Long], byteLength: Int): (Long, Long) = { - val defaultBytes = 100 * 1024 - val maxBytes = 1024 * 1024 - val file = new File(path) - val logLength = file.length() - val getOffset = offset.getOrElse(logLength - defaultBytes) - val startByte = - if (getOffset < 0) { - 0L - } else if (getOffset > logLength) { - logLength - } else { - getOffset + /** Get the part of the log files given the offset and desired length of bytes */ + private def getLog( + logDirectory: String, + logType: String, + offsetOption: Option[Long], + byteLength: Int + ): (String, Long, Long, Long) = { + try { + val files = RollingFileAppender.getSortedRolledOverFiles(logDirectory, logType) + logDebug(s"Sorted log files of type $logType in $logDirectory:\n${files.mkString("\n")}") + + val totalLength = files.map { _.length }.sum + val offset = offsetOption.getOrElse(totalLength - byteLength) + val startIndex = { + if (offset < 0) { + 0L + } else if (offset > totalLength) { + totalLength + } else { + offset + } } - val logPageLength = math.min(byteLength, maxBytes) - val endByte = math.min(startByte + logPageLength, logLength) - (startByte, endByte) + val endIndex = math.min(startIndex + totalLength, totalLength) + logDebug(s"Getting log from $startIndex to $endIndex") + val logText = Utils.offsetBytes(files, startIndex, endIndex) + logDebug(s"Got log of length ${logText.length} bytes") + (logText, startIndex, endIndex, totalLength) + } catch { + case e: Exception => + logError(s"Error getting $logType logs from directory $logDirectory", e) + ("Error getting logs due to exception: " + e.getMessage, 0, 0, 0) + } } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 3b1b6df089b8e..4ce28bb0cf059 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -862,6 +862,59 @@ private[spark] object Utils extends Logging { Source.fromBytes(buff).mkString } + /** + * Return a string containing data across a set of files. The `startIndex` + * and `endIndex` is based on the cumulative size of all the files take in + * the given order. See figure below for more details. + */ + def offsetBytes(files: Seq[File], start: Long, end: Long): String = { + val fileLengths = files.map { _.length } + val startIndex = math.max(start, 0) + val endIndex = math.min(end, fileLengths.sum) + val fileToLength = files.zip(fileLengths).toMap + logDebug("Log files: \n" + fileToLength.mkString("\n")) + + val stringBuffer = new StringBuffer((endIndex - startIndex).toInt) + var sum = 0L + for (file <- files) { + val startIndexOfFile = sum + val endIndexOfFile = sum + fileToLength(file) + logDebug(s"Processing file $file, " + + s"with start index = $startIndexOfFile, end index = $endIndex") + + /* + ____________ + range 1: | | + | case A | + + files: |==== file 1 ====|====== file 2 ======|===== file 3 =====| + + | case B . case C . case D | + range 2: |___________.____________________.______________| + */ + + if (startIndex <= startIndexOfFile && endIndex >= endIndexOfFile) { + // Case C: read the whole file + stringBuffer.append(offsetBytes(file.getAbsolutePath, 0, fileToLength(file))) + } else if (startIndex > startIndexOfFile && startIndex < endIndexOfFile) { + // Case A and B: read from [start of required range] to [end of file / end of range] + val effectiveStartIndex = startIndex - startIndexOfFile + val effectiveEndIndex = math.min(endIndex - startIndexOfFile, fileToLength(file)) + stringBuffer.append(Utils.offsetBytes( + file.getAbsolutePath, effectiveStartIndex, effectiveEndIndex)) + } else if (endIndex > startIndexOfFile && endIndex < endIndexOfFile) { + // Case D: read from [start of file] to [end of require range] + val effectiveStartIndex = math.max(startIndex - startIndexOfFile, 0) + val effectiveEndIndex = endIndex - startIndexOfFile + stringBuffer.append(Utils.offsetBytes( + file.getAbsolutePath, effectiveStartIndex, effectiveEndIndex)) + } + sum += fileToLength(file) + logDebug(s"After processing file $file, string built is ${stringBuffer.toString}}") + } + stringBuffer.toString + } + /** * Clone an object using a Spark serializer. */ diff --git a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala new file mode 100644 index 0000000000000..8e9c3036d09c2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.logging + +import java.io.{File, FileOutputStream, InputStream} + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.util.{IntParam, Utils} + +/** + * Continuously appends the data from an input stream into the given file. + */ +private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSize: Int = 8192) + extends Logging { + @volatile private var outputStream: FileOutputStream = null + @volatile private var markedForStop = false // has the appender been asked to stopped + @volatile private var stopped = false // has the appender stopped + + // Thread that reads the input stream and writes to file + private val writingThread = new Thread("File appending thread for " + file) { + setDaemon(true) + override def run() { + Utils.logUncaughtExceptions { + appendStreamToFile() + } + } + } + writingThread.start() + + /** + * Wait for the appender to stop appending, either because input stream is closed + * or because of any error in appending + */ + def awaitTermination() { + synchronized { + if (!stopped) { + wait() + } + } + } + + /** Stop the appender */ + def stop() { + markedForStop = true + } + + /** Continuously read chunks from the input stream and append to the file */ + protected def appendStreamToFile() { + try { + logDebug("Started appending thread") + openFile() + val buf = new Array[Byte](bufferSize) + var n = 0 + while (!markedForStop && n != -1) { + n = inputStream.read(buf) + if (n != -1) { + appendToFile(buf, n) + } + } + } catch { + case e: Exception => + logError(s"Error writing stream to file $file", e) + } finally { + closeFile() + synchronized { + stopped = true + notifyAll() + } + } + } + + /** Append bytes to the file output stream */ + protected def appendToFile(bytes: Array[Byte], len: Int) { + if (outputStream == null) { + openFile() + } + outputStream.write(bytes, 0, len) + } + + /** Open the file output stream */ + protected def openFile() { + outputStream = new FileOutputStream(file, false) + logDebug(s"Opened file $file") + } + + /** Close the file output stream */ + protected def closeFile() { + outputStream.flush() + outputStream.close() + logDebug(s"Closed file $file") + } +} + +/** + * Companion object to [[org.apache.spark.util.logging.FileAppender]] which has helper + * functions to choose the correct type of FileAppender based on SparkConf configuration. + */ +private[spark] object FileAppender extends Logging { + + /** Create the right appender based on Spark configuration */ + def apply(inputStream: InputStream, file: File, conf: SparkConf): FileAppender = { + + import RollingFileAppender._ + + val rollingStrategy = conf.get(STRATEGY_PROPERTY, STRATEGY_DEFAULT) + val rollingSizeBytes = conf.get(SIZE_PROPERTY, STRATEGY_DEFAULT) + val rollingInterval = conf.get(INTERVAL_PROPERTY, INTERVAL_DEFAULT) + + def createTimeBasedAppender() = { + val validatedParams: Option[(Long, String)] = rollingInterval match { + case "daily" => + logInfo(s"Rolling executor logs enabled for $file with daily rolling") + Some(24 * 60 * 60 * 1000L, "--YYYY-MM-dd") + case "hourly" => + logInfo(s"Rolling executor logs enabled for $file with hourly rolling") + Some(60 * 60 * 1000L, "--YYYY-MM-dd--HH") + case "minutely" => + logInfo(s"Rolling executor logs enabled for $file with rolling every minute") + Some(60 * 1000L, "--YYYY-MM-dd--HH-mm") + case IntParam(seconds) => + logInfo(s"Rolling executor logs enabled for $file with rolling $seconds seconds") + Some(seconds * 1000L, "--YYYY-MM-dd--HH-mm-ss") + case _ => + logWarning(s"Illegal interval for rolling executor logs [$rollingInterval], " + + s"rolling logs not enabled") + None + } + validatedParams.map { + case (interval, pattern) => + new RollingFileAppender( + inputStream, file, new TimeBasedRollingPolicy(interval, pattern), conf) + }.getOrElse { + new FileAppender(inputStream, file) + } + } + + def createSizeBasedAppender() = { + rollingSizeBytes match { + case IntParam(bytes) => + logInfo(s"Rolling executor logs enabled for $file with rolling every $bytes bytes") + new RollingFileAppender(inputStream, file, new SizeBasedRollingPolicy(bytes), conf) + case _ => + logWarning( + s"Illegal size [$rollingSizeBytes] for rolling executor logs, rolling logs not enabled") + new FileAppender(inputStream, file) + } + } + + rollingStrategy match { + case "" => + new FileAppender(inputStream, file) + case "time" => + createTimeBasedAppender() + case "size" => + createSizeBasedAppender() + case _ => + logWarning( + s"Illegal strategy [$rollingStrategy] for rolling executor logs, " + + s"rolling logs not enabled") + new FileAppender(inputStream, file) + } + } +} + + diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala new file mode 100644 index 0000000000000..1bbbd20cf076f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.logging + +import java.io.{File, FileFilter, InputStream} + +import org.apache.commons.io.FileUtils +import org.apache.spark.SparkConf +import RollingFileAppender._ + +/** + * Continuously appends data from input stream into the given file, and rolls + * over the file after the given interval. The rolled over files are named + * based on the given pattern. + * + * @param inputStream Input stream to read data from + * @param activeFile File to write data to + * @param rollingPolicy Policy based on which files will be rolled over. + * @param conf SparkConf that is used to pass on extra configurations + * @param bufferSize Optional buffer size. Used mainly for testing. + */ +private[spark] class RollingFileAppender( + inputStream: InputStream, + activeFile: File, + val rollingPolicy: RollingPolicy, + conf: SparkConf, + bufferSize: Int = DEFAULT_BUFFER_SIZE + ) extends FileAppender(inputStream, activeFile, bufferSize) { + + private val maxRetainedFiles = conf.getInt(RETAINED_FILES_PROPERTY, -1) + + /** Stop the appender */ + override def stop() { + super.stop() + } + + /** Append bytes to file after rolling over is necessary */ + override protected def appendToFile(bytes: Array[Byte], len: Int) { + if (rollingPolicy.shouldRollover(len)) { + rollover() + rollingPolicy.rolledOver() + } + super.appendToFile(bytes, len) + rollingPolicy.bytesWritten(len) + } + + /** Rollover the file, by closing the output stream and moving it over */ + private def rollover() { + try { + closeFile() + moveFile() + openFile() + if (maxRetainedFiles > 0) { + deleteOldFiles() + } + } catch { + case e: Exception => + logError(s"Error rolling over $activeFile", e) + } + } + + /** Move the active log file to a new rollover file */ + private def moveFile() { + val rolloverSuffix = rollingPolicy.generateRolledOverFileSuffix() + val rolloverFile = new File( + activeFile.getParentFile, activeFile.getName + rolloverSuffix).getAbsoluteFile + try { + logDebug(s"Attempting to rollover file $activeFile to file $rolloverFile") + if (activeFile.exists) { + if (!rolloverFile.exists) { + FileUtils.moveFile(activeFile, rolloverFile) + logInfo(s"Rolled over $activeFile to $rolloverFile") + } else { + // In case the rollover file name clashes, make a unique file name. + // The resultant file names are long and ugly, so this is used only + // if there is a name collision. This can be avoided by the using + // the right pattern such that name collisions do not occur. + var i = 0 + var altRolloverFile: File = null + do { + altRolloverFile = new File(activeFile.getParent, + s"${activeFile.getName}$rolloverSuffix--$i").getAbsoluteFile + i += 1 + } while (i < 10000 && altRolloverFile.exists) + + logWarning(s"Rollover file $rolloverFile already exists, " + + s"rolled over $activeFile to file $altRolloverFile") + FileUtils.moveFile(activeFile, altRolloverFile) + } + } else { + logWarning(s"File $activeFile does not exist") + } + } + } + + /** Retain only last few files */ + private[util] def deleteOldFiles() { + try { + val rolledoverFiles = activeFile.getParentFile.listFiles(new FileFilter { + def accept(f: File): Boolean = { + f.getName.startsWith(activeFile.getName) && f != activeFile + } + }).sorted + val filesToBeDeleted = rolledoverFiles.take( + math.max(0, rolledoverFiles.size - maxRetainedFiles)) + filesToBeDeleted.foreach { file => + logInfo(s"Deleting file executor log file ${file.getAbsolutePath}") + file.delete() + } + } catch { + case e: Exception => + logError("Error cleaning logs in directory " + activeFile.getParentFile.getAbsolutePath, e) + } + } +} + +/** + * Companion object to [[org.apache.spark.util.logging.RollingFileAppender]]. Defines + * names of configurations that configure rolling file appenders. + */ +private[spark] object RollingFileAppender { + val STRATEGY_PROPERTY = "spark.executor.logs.rolling.strategy" + val STRATEGY_DEFAULT = "" + val INTERVAL_PROPERTY = "spark.executor.logs.rolling.time.interval" + val INTERVAL_DEFAULT = "daily" + val SIZE_PROPERTY = "spark.executor.logs.rolling.size.maxBytes" + val SIZE_DEFAULT = (1024 * 1024).toString + val RETAINED_FILES_PROPERTY = "spark.executor.logs.rolling.maxRetainedFiles" + val DEFAULT_BUFFER_SIZE = 8192 + + /** + * Get the sorted list of rolled over files. This assumes that the all the rolled + * over file names are prefixed with the `activeFileName`, and the active file + * name has the latest logs. So it sorts all the rolled over logs (that are + * prefixed with `activeFileName`) and appends the active file + */ + def getSortedRolledOverFiles(directory: String, activeFileName: String): Seq[File] = { + val rolledOverFiles = new File(directory).getAbsoluteFile.listFiles.filter { file => + val fileName = file.getName + fileName.startsWith(activeFileName) && fileName != activeFileName + }.sorted + val activeFile = { + val file = new File(directory, activeFileName).getAbsoluteFile + if (file.exists) Some(file) else None + } + rolledOverFiles ++ activeFile + } +} diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala new file mode 100644 index 0000000000000..84e5c3c917dcb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.logging + +import java.text.SimpleDateFormat +import java.util.Calendar + +import org.apache.spark.Logging + +/** + * Defines the policy based on which [[org.apache.spark.util.logging.RollingFileAppender]] will + * generate rolling files. + */ +private[spark] trait RollingPolicy { + + /** Whether rollover should be initiated at this moment */ + def shouldRollover(bytesToBeWritten: Long): Boolean + + /** Notify that rollover has occurred */ + def rolledOver() + + /** Notify that bytes have been written */ + def bytesWritten(bytes: Long) + + /** Get the desired name of the rollover file */ + def generateRolledOverFileSuffix(): String +} + +/** + * Defines a [[org.apache.spark.util.logging.RollingPolicy]] by which files will be rolled + * over at a fixed interval. + */ +private[spark] class TimeBasedRollingPolicy( + var rolloverIntervalMillis: Long, + rollingFileSuffixPattern: String, + checkIntervalConstraint: Boolean = true // set to false while testing + ) extends RollingPolicy with Logging { + + import TimeBasedRollingPolicy._ + if (checkIntervalConstraint && rolloverIntervalMillis < MINIMUM_INTERVAL_SECONDS * 1000L) { + logWarning(s"Rolling interval [${rolloverIntervalMillis/1000L} seconds] is too small. " + + s"Setting the interval to the acceptable minimum of $MINIMUM_INTERVAL_SECONDS seconds.") + rolloverIntervalMillis = MINIMUM_INTERVAL_SECONDS * 1000L + } + + @volatile private var nextRolloverTime = calculateNextRolloverTime() + private val formatter = new SimpleDateFormat(rollingFileSuffixPattern) + + /** Should rollover if current time has exceeded next rollover time */ + def shouldRollover(bytesToBeWritten: Long): Boolean = { + System.currentTimeMillis > nextRolloverTime + } + + /** Rollover has occurred, so find the next time to rollover */ + def rolledOver() { + nextRolloverTime = calculateNextRolloverTime() + logDebug(s"Current time: ${System.currentTimeMillis}, next rollover time: " + nextRolloverTime) + } + + def bytesWritten(bytes: Long) { } // nothing to do + + private def calculateNextRolloverTime(): Long = { + val now = System.currentTimeMillis() + val targetTime = ( + math.ceil(now.toDouble / rolloverIntervalMillis) * rolloverIntervalMillis + ).toLong + logDebug(s"Next rollover time is $targetTime") + targetTime + } + + def generateRolledOverFileSuffix(): String = { + formatter.format(Calendar.getInstance.getTime) + } +} + +private[spark] object TimeBasedRollingPolicy { + val MINIMUM_INTERVAL_SECONDS = 60L // 1 minute +} + +/** + * Defines a [[org.apache.spark.util.logging.RollingPolicy]] by which files will be rolled + * over after reaching a particular size. + */ +private[spark] class SizeBasedRollingPolicy( + var rolloverSizeBytes: Long, + checkSizeConstraint: Boolean = true // set to false while testing + ) extends RollingPolicy with Logging { + + import SizeBasedRollingPolicy._ + if (checkSizeConstraint && rolloverSizeBytes < MINIMUM_SIZE_BYTES) { + logWarning(s"Rolling size [$rolloverSizeBytes bytes] is too small. " + + s"Setting the size to the acceptable minimum of $MINIMUM_SIZE_BYTES bytes.") + rolloverSizeBytes = MINIMUM_SIZE_BYTES + } + + @volatile private var bytesWrittenSinceRollover = 0L + val formatter = new SimpleDateFormat("--YYYY-MM-dd--HH-mm-ss--SSSS") + + /** Should rollover if the next set of bytes is going to exceed the size limit */ + def shouldRollover(bytesToBeWritten: Long): Boolean = { + logInfo(s"$bytesToBeWritten + $bytesWrittenSinceRollover > $rolloverSizeBytes") + bytesToBeWritten + bytesWrittenSinceRollover > rolloverSizeBytes + } + + /** Rollover has occurred, so reset the counter */ + def rolledOver() { + bytesWrittenSinceRollover = 0 + } + + /** Increment the bytes that have been written in the current file */ + def bytesWritten(bytes: Long) { + bytesWrittenSinceRollover += bytes + } + + /** Get the desired name of the rollover file */ + def generateRolledOverFileSuffix(): String = { + formatter.format(Calendar.getInstance.getTime) + } +} + +private[spark] object SizeBasedRollingPolicy { + val MINIMUM_SIZE_BYTES = RollingFileAppender.DEFAULT_BUFFER_SIZE * 10 +} + diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index bfae32dae0dc5..01ab2d549325c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.FunSuite import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, RecoveryState, WorkerInfo} import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} +import org.apache.spark.SparkConf class JsonProtocolSuite extends FunSuite { @@ -116,7 +117,8 @@ class JsonProtocolSuite extends FunSuite { } def createExecutorRunner(): ExecutorRunner = { new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host", - new File("sparkHome"), new File("workDir"), "akka://worker", ExecutorState.RUNNING) + new File("sparkHome"), new File("workDir"), "akka://worker", + new SparkConf, ExecutorState.RUNNING) } def createDriverRunner(): DriverRunner = { new DriverRunner("driverId", new File("workDir"), new File("sparkHome"), createDriverDesc(), diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index 8ae387fa0be6f..e5f748d55500d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -22,6 +22,7 @@ import java.io.File import org.scalatest.FunSuite import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} +import org.apache.spark.SparkConf class ExecutorRunnerTest extends FunSuite { test("command includes appId") { @@ -32,7 +33,7 @@ class ExecutorRunnerTest extends FunSuite { sparkHome, "appUiUrl") val appId = "12345-worker321-9876" val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome.getOrElse(".")), - f("ooga"), "blah", ExecutorState.RUNNING) + f("ooga"), "blah", new SparkConf, ExecutorState.RUNNING) assert(er.getCommandSeq.last === appId) } diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala new file mode 100644 index 0000000000000..53d7f5c6072e6 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io._ + +import scala.collection.mutable.HashSet +import scala.reflect._ + +import org.apache.commons.io.{FileUtils, IOUtils} +import org.apache.spark.{Logging, SparkConf} +import org.scalatest.{BeforeAndAfter, FunSuite} +import org.apache.spark.util.logging.{RollingFileAppender, SizeBasedRollingPolicy, TimeBasedRollingPolicy, FileAppender} + +class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { + + val testFile = new File("FileAppenderSuite-test-" + System.currentTimeMillis).getAbsoluteFile + + before { + cleanup() + } + + after { + cleanup() + } + + test("basic file appender") { + val testString = (1 to 1000).mkString(", ") + val inputStream = IOUtils.toInputStream(testString) + val appender = new FileAppender(inputStream, testFile) + inputStream.close() + appender.awaitTermination() + assert(FileUtils.readFileToString(testFile) === testString) + } + + test("rolling file appender - time-based rolling") { + // setup input stream and appender + val testOutputStream = new PipedOutputStream() + val testInputStream = new PipedInputStream(testOutputStream, 100 * 1000) + val rolloverIntervalMillis = 100 + val durationMillis = 1000 + val numRollovers = durationMillis / rolloverIntervalMillis + val textToAppend = (1 to numRollovers).map( _.toString * 10 ) + + val appender = new RollingFileAppender(testInputStream, testFile, + new TimeBasedRollingPolicy(rolloverIntervalMillis, s"--HH-mm-ss-SSSS", false), + new SparkConf(), 10) + + testRolling(appender, testOutputStream, textToAppend, rolloverIntervalMillis) + } + + test("rolling file appender - size-based rolling") { + // setup input stream and appender + val testOutputStream = new PipedOutputStream() + val testInputStream = new PipedInputStream(testOutputStream, 100 * 1000) + val rolloverSize = 1000 + val textToAppend = (1 to 3).map( _.toString * 1000 ) + + val appender = new RollingFileAppender(testInputStream, testFile, + new SizeBasedRollingPolicy(rolloverSize, false), new SparkConf(), 99) + + val files = testRolling(appender, testOutputStream, textToAppend, 0) + files.foreach { file => + logInfo(file.toString + ": " + file.length + " bytes") + assert(file.length <= rolloverSize) + } + } + + test("rolling file appender - cleaning") { + // setup input stream and appender + val testOutputStream = new PipedOutputStream() + val testInputStream = new PipedInputStream(testOutputStream, 100 * 1000) + val conf = new SparkConf().set(RollingFileAppender.RETAINED_FILES_PROPERTY, "10") + val appender = new RollingFileAppender(testInputStream, testFile, + new SizeBasedRollingPolicy(1000, false), conf, 10) + + // send data to appender through the input stream, and wait for the data to be written + val allGeneratedFiles = new HashSet[String]() + val items = (1 to 10).map { _.toString * 10000 } + for (i <- 0 until items.size) { + testOutputStream.write(items(i).getBytes("UTF8")) + testOutputStream.flush() + allGeneratedFiles ++= RollingFileAppender.getSortedRolledOverFiles( + testFile.getParentFile.toString, testFile.getName).map(_.toString) + + Thread.sleep(10) + } + testOutputStream.close() + appender.awaitTermination() + logInfo("Appender closed") + + // verify whether the earliest file has been deleted + val rolledOverFiles = allGeneratedFiles.filter { _ != testFile.toString }.toArray.sorted + logInfo(s"All rolled over files generated:${rolledOverFiles.size}\n" + rolledOverFiles.mkString("\n")) + assert(rolledOverFiles.size > 2) + val earliestRolledOverFile = rolledOverFiles.head + val existingRolledOverFiles = RollingFileAppender.getSortedRolledOverFiles( + testFile.getParentFile.toString, testFile.getName).map(_.toString) + logInfo("Existing rolled over files:\n" + existingRolledOverFiles.mkString("\n")) + assert(!existingRolledOverFiles.toSet.contains(earliestRolledOverFile)) + } + + test("file appender selection") { + // Test whether FileAppender.apply() returns the right type of the FileAppender based + // on SparkConf settings. + + def testAppenderSelection[ExpectedAppender: ClassTag, ExpectedRollingPolicy]( + properties: Seq[(String, String)], expectedRollingPolicyParam: Long = -1): FileAppender = { + + // Set spark conf properties + val conf = new SparkConf + properties.foreach { p => + conf.set(p._1, p._2) + } + + // Create and test file appender + val inputStream = new PipedInputStream(new PipedOutputStream()) + val appender = FileAppender(inputStream, new File("stdout"), conf) + assert(appender.isInstanceOf[ExpectedAppender]) + assert(appender.getClass.getSimpleName === + classTag[ExpectedAppender].runtimeClass.getSimpleName) + if (appender.isInstanceOf[RollingFileAppender]) { + val rollingPolicy = appender.asInstanceOf[RollingFileAppender].rollingPolicy + rollingPolicy.isInstanceOf[ExpectedRollingPolicy] + val policyParam = if (rollingPolicy.isInstanceOf[TimeBasedRollingPolicy]) { + rollingPolicy.asInstanceOf[TimeBasedRollingPolicy].rolloverIntervalMillis + } else { + rollingPolicy.asInstanceOf[SizeBasedRollingPolicy].rolloverSizeBytes + } + assert(policyParam === expectedRollingPolicyParam) + } + appender + } + + import RollingFileAppender._ + + def rollingStrategy(strategy: String) = Seq(STRATEGY_PROPERTY -> strategy) + def rollingSize(size: String) = Seq(SIZE_PROPERTY -> size) + def rollingInterval(interval: String) = Seq(INTERVAL_PROPERTY -> interval) + + val msInDay = 24 * 60 * 60 * 1000L + val msInHour = 60 * 60 * 1000L + val msInMinute = 60 * 1000L + + // test no strategy -> no rolling + testAppenderSelection[FileAppender, Any](Seq.empty) + + // test time based rolling strategy + testAppenderSelection[RollingFileAppender, Any](rollingStrategy("time"), msInDay) + testAppenderSelection[RollingFileAppender, TimeBasedRollingPolicy]( + rollingStrategy("time") ++ rollingInterval("daily"), msInDay) + testAppenderSelection[RollingFileAppender, TimeBasedRollingPolicy]( + rollingStrategy("time") ++ rollingInterval("hourly"), msInHour) + testAppenderSelection[RollingFileAppender, TimeBasedRollingPolicy]( + rollingStrategy("time") ++ rollingInterval("minutely"), msInMinute) + testAppenderSelection[RollingFileAppender, TimeBasedRollingPolicy]( + rollingStrategy("time") ++ rollingInterval("123456789"), 123456789 * 1000L) + testAppenderSelection[FileAppender, Any]( + rollingStrategy("time") ++ rollingInterval("xyz")) + + // test size based rolling strategy + testAppenderSelection[RollingFileAppender, SizeBasedRollingPolicy]( + rollingStrategy("size") ++ rollingSize("123456789"), 123456789) + testAppenderSelection[FileAppender, Any](rollingSize("xyz")) + + // test illegal strategy + testAppenderSelection[FileAppender, Any](rollingStrategy("xyz")) + } + + /** + * Run the rolling file appender with data and see whether all the data was written correctly + * across rolled over files. + */ + def testRolling( + appender: FileAppender, + outputStream: OutputStream, + textToAppend: Seq[String], + sleepTimeBetweenTexts: Long + ): Seq[File] = { + // send data to appender through the input stream, and wait for the data to be written + val expectedText = textToAppend.mkString("") + for (i <- 0 until textToAppend.size) { + outputStream.write(textToAppend(i).getBytes("UTF8")) + outputStream.flush() + Thread.sleep(sleepTimeBetweenTexts) + } + logInfo("Data sent to appender") + outputStream.close() + appender.awaitTermination() + logInfo("Appender closed") + + // verify whether all the data written to rolled over files is same as expected + val generatedFiles = RollingFileAppender.getSortedRolledOverFiles( + testFile.getParentFile.toString, testFile.getName) + logInfo("Filtered files: \n" + generatedFiles.mkString("\n")) + assert(generatedFiles.size > 1) + val allText = generatedFiles.map { file => + FileUtils.readFileToString(file) + }.mkString("") + assert(allText === expectedText) + generatedFiles + } + + /** Delete all the generated rolledover files */ + def cleanup() { + testFile.getParentFile.listFiles.filter { file => + file.getName.startsWith(testFile.getName) + }.foreach { _.delete() } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 0aad882ed76a8..1ee936bc78f49 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -140,6 +140,38 @@ class UtilsSuite extends FunSuite { Utils.deleteRecursively(tmpDir2) } + test("reading offset bytes across multiple files") { + val tmpDir = Files.createTempDir() + tmpDir.deleteOnExit() + val files = (1 to 3).map(i => new File(tmpDir, i.toString)) + Files.write("0123456789", files(0), Charsets.UTF_8) + Files.write("abcdefghij", files(1), Charsets.UTF_8) + Files.write("ABCDEFGHIJ", files(2), Charsets.UTF_8) + + // Read first few bytes in the 1st file + assert(Utils.offsetBytes(files, 0, 5) === "01234") + + // Read bytes within the 1st file + assert(Utils.offsetBytes(files, 5, 8) === "567") + + // Read bytes across 1st and 2nd file + assert(Utils.offsetBytes(files, 8, 18) === "89abcdefgh") + + // Read bytes across 1st, 2nd and 3rd file + assert(Utils.offsetBytes(files, 5, 24) === "56789abcdefghijABCD") + + // Read some nonexistent bytes in the beginning + assert(Utils.offsetBytes(files, -5, 18) === "0123456789abcdefgh") + + // Read some nonexistent bytes at the end + assert(Utils.offsetBytes(files, 18, 35) === "ijABCDEFGHIJ") + + // Read some nonexistent bytes on both ends + assert(Utils.offsetBytes(files, -5, 35) === "0123456789abcdefghijABCDEFGHIJ") + + Utils.deleteRecursively(tmpDir) + } + test("deserialize long value") { val testval : Long = 9730889947L val bbuf = ByteBuffer.allocate(8) diff --git a/docs/configuration.md b/docs/configuration.md index 71fafa573467f..b84104cc7e653 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -784,6 +784,45 @@ Apart from these, the following properties are also available, and may be useful higher memory usage in Spark. + + spark.executor.logs.rolling.strategy + (none) + + Set the strategy of rolling of executor logs. By default it is disabled. It can + be set to "time" (time-based rolling) or "size" (size-based rolling). For "time", + use spark.executor.logs.rolling.time.interval to set the rolling interval. + For "size", use spark.executor.logs.rolling.size.maxBytes to set + the maximum file size for rolling. + + + + spark.executor.logs.rolling.time.interval + daily + + Set the time interval by which the executor logs will be rolled over. + Rolling is disabled by default. Valid values are `daily`, `hourly`, `minutely` or + any interval in seconds. See spark.executor.logs.rolling.maxRetainedFiles + for automatic cleaning of old logs. + + + + spark.executor.logs.rolling.size.maxBytes + (none) + + Set the max size of the file by which the executor logs will be rolled over. + Rolling is disabled by default. Value is set in terms of bytes. + See spark.executor.logs.rolling.maxRetainedFiles + for automatic cleaning of old logs. + + + + spark.executor.logs.rolling.maxRetainedFiles + (none) + + Sets the number of latest rolling log files that are going to be retained by the system. + Older log files will be deleted. Disabled by default. + + #### Cluster Managers From c48b6222ea60f2a31741297f03737c89fd572e52 Mon Sep 17 00:00:00 2001 From: witgo Date: Tue, 10 Jun 2014 20:24:05 -0700 Subject: [PATCH 03/38] Resolve scalatest warnings during build Author: witgo Closes #1032 from witgo/ShouldMatchers and squashes the following commits: 7ebf34c [witgo] Resolve scalatest warnings during build --- core/src/test/scala/org/apache/spark/AccumulatorSuite.scala | 4 ++-- core/src/test/scala/org/apache/spark/DistributedSuite.scala | 4 ++-- .../test/scala/org/apache/spark/JobCancellationSuite.scala | 4 ++-- core/src/test/scala/org/apache/spark/ShuffleSuite.scala | 4 ++-- .../src/test/scala/org/apache/spark/deploy/ClientSuite.scala | 4 ++-- .../scala/org/apache/spark/deploy/SparkSubmitSuite.scala | 4 ++-- core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala | 4 ++-- .../org/apache/spark/scheduler/SparkListenerSuite.scala | 4 ++-- .../scala/org/apache/spark/storage/BlockManagerSuite.scala | 5 +++-- .../org/apache/spark/ui/jobs/JobProgressListenerSuite.scala | 4 ++-- .../test/scala/org/apache/spark/util/DistributionSuite.scala | 4 ++-- .../test/scala/org/apache/spark/util/NextIteratorSuite.scala | 4 ++-- .../org/apache/spark/util/collection/OpenHashMapSuite.scala | 4 ++-- .../org/apache/spark/util/collection/OpenHashSetSuite.scala | 4 ++-- .../spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala | 4 ++-- .../org/apache/spark/util/random/XORShiftRandomSuite.scala | 4 ++-- .../spark/mllib/classification/LogisticRegressionSuite.scala | 4 ++-- .../spark/mllib/optimization/GradientDescentSuite.scala | 4 ++-- .../org/apache/spark/mllib/optimization/LBFGSSuite.scala | 4 ++-- .../org/apache/spark/streaming/NetworkReceiverSuite.scala | 1 - .../org/apache/spark/streaming/StreamingListenerSuite.scala | 4 ++-- 21 files changed, 41 insertions(+), 41 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 4e7c34e6d1ada..3aab88e9e9196 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark import scala.collection.mutable import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.SparkContext._ -class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkContext { +class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext { implicit def setAccum[A] = new AccumulableParam[mutable.Set[A], A] { diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 14ddd6f1ec08f..41c294f727b3c 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark import org.scalatest.BeforeAndAfter import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts._ -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.scalatest.time.{Millis, Span} import org.apache.spark.SparkContext._ @@ -31,7 +31,7 @@ class NotSerializableClass class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} -class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter +class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter with LocalSparkContext { val clusterUrl = "local-cluster[2,1,512]" diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 2c8ef405c944c..a57430e829ced 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -25,7 +25,7 @@ import scala.concurrent.duration._ import scala.concurrent.future import org.scalatest.{BeforeAndAfter, FunSuite} -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.SparkContext._ import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} @@ -35,7 +35,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} * (e.g. count) as well as multi-job action (e.g. take). We test the local and cluster schedulers * in both FIFO and fair scheduling modes. */ -class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAfter +class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter with LocalSparkContext { override def afterEach() { diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index be6508a40ea61..7b0607dd3ed2d 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.SparkContext._ import org.apache.spark.ShuffleSuite.NonJavaSerializableClass @@ -26,7 +26,7 @@ import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.MutablePair -class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { +class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { val conf = new SparkConf(loadDefaults = false) diff --git a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala index d6b93f5fedd3b..4161aede1d1d0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.deploy import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers -class ClientSuite extends FunSuite with ShouldMatchers { +class ClientSuite extends FunSuite with Matchers { test("correctly validates driver jar URL's") { ClientArguments.isValidJarUrl("http://someHost:8080/foo.jar") should be (true) ClientArguments.isValidJarUrl("file://some/path/to/a/jarFile.jar") should be (true) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 02427a4a83506..565c53e9529ff 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -25,9 +25,9 @@ import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkException, Test import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.util.Utils import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers -class SparkSubmitSuite extends FunSuite with ShouldMatchers { +class SparkSubmitSuite extends FunSuite with Matchers { def beforeAll() { System.setProperty("spark.testing", "true") } diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala index d0619559bb457..656917628f7a8 100644 --- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.rdd import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.{Logging, SharedSparkContext} import org.apache.spark.SparkContext._ -class SortingSuite extends FunSuite with SharedSparkContext with ShouldMatchers with Logging { +class SortingSuite extends FunSuite with SharedSparkContext with Matchers with Logging { test("sortByKey") { val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 5426e578a9ddd..be506e0287a16 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -22,13 +22,13 @@ import java.util.concurrent.Semaphore import scala.collection.mutable import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.{LocalSparkContext, SparkContext} import org.apache.spark.SparkContext._ import org.apache.spark.executor.TaskMetrics -class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers +class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers with BeforeAndAfter with BeforeAndAfterAll { /** Length of time to wait while draining listener events. */ diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 81bd8257bc155..d7dbe5164b7f6 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -28,7 +28,7 @@ import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ -import org.scalatest.matchers.ShouldMatchers._ +import org.scalatest.Matchers import org.scalatest.time.SpanSugar._ import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} @@ -39,7 +39,8 @@ import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, U import scala.language.implicitConversions import scala.language.postfixOps -class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester { +class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter + with PrivateMethodTester { private val conf = new SparkConf(false) var store: BlockManager = null var store2: BlockManager = null diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 8c06a2d9aa4ab..91b4c7b0dd962 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -18,14 +18,14 @@ package org.apache.spark.ui.jobs import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.{LocalSparkContext, SparkConf, Success} import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics} import org.apache.spark.scheduler._ import org.apache.spark.util.Utils -class JobProgressListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers { +class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matchers { test("test LRU eviction of stages") { val conf = new SparkConf() conf.set("spark.ui.retainedStages", 5.toString) diff --git a/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala b/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala index 63642461e4465..090d48ec921a1 100644 --- a/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.util import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers /** * */ -class DistributionSuite extends FunSuite with ShouldMatchers { +class DistributionSuite extends FunSuite with Matchers { test("summary") { val d = new Distribution((1 to 100).toArray.map{_.toDouble}) val stats = d.statCounter diff --git a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala index 32d74d0500b72..cf438a3d72a06 100644 --- a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala @@ -22,9 +22,9 @@ import java.util.NoSuchElementException import scala.collection.mutable.Buffer import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers -class NextIteratorSuite extends FunSuite with ShouldMatchers { +class NextIteratorSuite extends FunSuite with Matchers { test("one iteration") { val i = new StubIterator(Buffer(1)) i.hasNext should be === true diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index b024c89d94d33..6a70877356409 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.util.collection import scala.collection.mutable.HashSet import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.util.SizeEstimator -class OpenHashMapSuite extends FunSuite with ShouldMatchers { +class OpenHashMapSuite extends FunSuite with Matchers { test("size for specialized, primitive value (int)") { val capacity = 1024 diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index ff4a98f5dcd4a..68a03e3a0970f 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.util.collection import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.util.SizeEstimator -class OpenHashSetSuite extends FunSuite with ShouldMatchers { +class OpenHashSetSuite extends FunSuite with Matchers { test("size for specialized, primitive int") { val loadFactor = 0.7 diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala index e3fca173908e9..8c7df7d73dcd3 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.util.collection import scala.collection.mutable.HashSet import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.util.SizeEstimator -class PrimitiveKeyOpenHashMapSuite extends FunSuite with ShouldMatchers { +class PrimitiveKeyOpenHashMapSuite extends FunSuite with Matchers { test("size for specialized, primitive key, value (int, int)") { val capacity = 1024 diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index 0865c6386f7cd..e15fd59a5a8bb 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.util.random import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.util.Utils.times import scala.language.reflectiveCalls -class XORShiftRandomSuite extends FunSuite with ShouldMatchers { +class XORShiftRandomSuite extends FunSuite with Matchers { def fixture = new { val seed = 1L diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 4d7b984e3ec29..44b757b6a1fb7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -21,7 +21,7 @@ import scala.util.Random import scala.collection.JavaConversions._ import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ @@ -56,7 +56,7 @@ object LogisticRegressionSuite { } } -class LogisticRegressionSuite extends FunSuite with LocalSparkContext with ShouldMatchers { +class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Matchers { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => prediction != expected.label diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index 8a16284118cf7..951b4f7c6e6f4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -21,7 +21,7 @@ import scala.util.Random import scala.collection.JavaConversions._ import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.LocalSparkContext @@ -61,7 +61,7 @@ object GradientDescentSuite { } } -class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMatchers { +class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers { test("Assert the loss is decreasing.") { val nPoints = 10000 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index 820eca9b1bf65..4b1850659a18e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.mllib.optimization import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LocalSparkContext -class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { +class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { val nPoints = 10000 val A = 2.0 diff --git a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala index 303d149d285e1..d9ac3c91f6e36 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala @@ -29,7 +29,6 @@ import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import scala.language.postfixOps /** Testsuite for testing the network receiver behavior */ class NetworkReceiverSuite extends FunSuite with Timeouts { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index ef0efa552ceaf..2861f5335ae36 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -27,12 +27,12 @@ import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.scheduler._ -import org.scalatest.matchers.ShouldMatchers +import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.Logging -class StreamingListenerSuite extends TestSuiteBase with ShouldMatchers { +class StreamingListenerSuite extends TestSuiteBase with Matchers { val input = (1 to 4).map(Seq(_)).toSeq val operation = (d: DStream[Int]) => d.map(x => x) From a2052a44f3c08076d8d1ac505c8eb5395141bf79 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Tue, 10 Jun 2014 21:49:08 -0700 Subject: [PATCH 04/38] [SPARK-2065] give launched instances names MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This update resolves [SPARK-2065](https://issues.apache.org/jira/browse/SPARK-2065). It gives launched EC2 instances descriptive names by using instance tags. Launched instances now show up in the EC2 console with these names. I used `format()` with named parameters, which I believe is the recommended practice for string formatting in Python, but which doesn’t seem to be used elsewhere in the script. Author: Nicholas Chammas Author: nchammas Closes #1043 from nchammas/master and squashes the following commits: 69f6e22 [Nicholas Chammas] PEP8 fixes 2627247 [Nicholas Chammas] broke up lines before they hit 100 chars 6544b7e [Nicholas Chammas] [SPARK-2065] give launched instances names 69da6cf [nchammas] Merge pull request #1 from apache/master --- ec2/spark_ec2.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 9d5748ba4bc23..52a89cb2481ca 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -200,6 +200,7 @@ def get_spark_shark_version(opts): sys.exit(1) return (version, spark_shark_map[version]) + # Attempt to resolve an appropriate AMI given the architecture and # region of the request. def get_spark_ami(opts): @@ -418,6 +419,16 @@ def launch_cluster(conn, opts, cluster_name): master_nodes = master_res.instances print "Launched master in %s, regid = %s" % (zone, master_res.id) + # Give the instances descriptive names + for master in master_nodes: + master.add_tag( + key='Name', + value='spark-{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) + for slave in slave_nodes: + slave.add_tag( + key='Name', + value='spark-{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) + # Return all the instances return (master_nodes, slave_nodes) From 601032f5bfe2dcdc240bfcc553f401e6facbf5ec Mon Sep 17 00:00:00 2001 From: Zongheng Yang Date: Tue, 10 Jun 2014 21:59:01 -0700 Subject: [PATCH 05/38] HOTFIX: clear() configs in SQLConf-related unit tests. Thanks goes to @liancheng, who pointed out that `sql/test-only *.SQLConfSuite *.SQLQuerySuite` passed but `sql/test-only *.SQLQuerySuite *.SQLConfSuite` failed. The reason is that some tests use the same test keys and without clear()'ing, they get carried over to other tests. This hotfix simply adds some `clear()` calls. This problem was not evident on Jenkins before probably because `parallelExecution` is not set to `false` for `sqlCoreSettings`. Author: Zongheng Yang Closes #1040 from concretevitamin/sqlconf-tests and squashes the following commits: 6d14ceb [Zongheng Yang] HOTFIX: clear() confs in SQLConf related unit tests. --- sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala | 2 ++ .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 1 + 2 files changed, 3 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 5eb73a4eff980..08293f7f0ca30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -28,6 +28,7 @@ class SQLConfSuite extends QueryTest { val testVal = "test.val.0" test("programmatic ways of basic setting and getting") { + clear() assert(getOption(testKey).isEmpty) assert(getAll.toSet === Set()) @@ -48,6 +49,7 @@ class SQLConfSuite extends QueryTest { } test("parse SQL set commands") { + clear() sql(s"set $testKey=$testVal") assert(get(testKey, testVal + "_") == testVal) assert(TestSQLContext.get(testKey, testVal + "_") == testVal) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index de02bbc7e4700..c1fc99f077431 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -402,6 +402,7 @@ class SQLQuerySuite extends QueryTest { sql(s"SET $nonexistentKey"), Seq(Seq(s"$nonexistentKey is undefined")) ) + clear() } } From 0402bd77ec786d1fa6cfd7f9cc3aa97c7ab16fd8 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 10 Jun 2014 23:13:48 -0700 Subject: [PATCH 06/38] [SPARK-2093] [SQL] NullPropagation should use exact type value. `NullPropagation` should use exact type value when transform `Count` or `Sum`. Author: Takuya UESHIN Closes #1034 from ueshin/issues/SPARK-2093 and squashes the following commits: 65b6ff1 [Takuya UESHIN] Modify the literal value of the result of transformation from Sum to long value. 830c20b [Takuya UESHIN] Add Cast to the result of transformation from Count. 9314806 [Takuya UESHIN] Fix NullPropagation to use exact type value. --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ccb8245cc2e7d..e41fd2db74858 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -104,8 +104,8 @@ object ColumnPruning extends Rule[LogicalPlan] { object NullPropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case e @ Count(Literal(null, _)) => Literal(0, e.dataType) - case e @ Sum(Literal(c, _)) if c == 0 => Literal(0, e.dataType) + case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType) + case e @ Sum(Literal(c, _)) if c == 0 => Cast(Literal(0L), e.dataType) case e @ Average(Literal(c, _)) if c == 0 => Literal(0.0, e.dataType) case e @ IsNull(c) if c.nullable == false => Literal(false, BooleanType) case e @ IsNotNull(c) if c.nullable == false => Literal(true, BooleanType) From 0266a0c8a70e0fbaeb0df63031f7a750ffc31a80 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 11 Jun 2014 00:06:50 -0700 Subject: [PATCH 07/38] [SPARK-1968][SQL] SQL/HiveQL command for caching/uncaching tables JIRA issue: [SPARK-1968](https://issues.apache.org/jira/browse/SPARK-1968) This PR added support for SQL/HiveQL command for caching/uncaching tables: ``` scala> sql("CACHE TABLE src") ... res0: org.apache.spark.sql.SchemaRDD = SchemaRDD[0] at RDD at SchemaRDD.scala:98 == Query Plan == CacheCommandPhysical src, true scala> table("src") ... res1: org.apache.spark.sql.SchemaRDD = SchemaRDD[3] at RDD at SchemaRDD.scala:98 == Query Plan == InMemoryColumnarTableScan [key#0,value#1], (HiveTableScan [key#0,value#1], (MetastoreRelation default, src, None), None), false scala> isCached("src") res2: Boolean = true scala> sql("CACHE TABLE src") ... res3: org.apache.spark.sql.SchemaRDD = SchemaRDD[4] at RDD at SchemaRDD.scala:98 == Query Plan == CacheCommandPhysical src, false scala> table("src") ... res4: org.apache.spark.sql.SchemaRDD = SchemaRDD[11] at RDD at SchemaRDD.scala:98 == Query Plan == HiveTableScan [key#2,value#3], (MetastoreRelation default, src, None), None scala> isCached("src") res5: Boolean = false ``` Things also work for `hql`. Author: Cheng Lian Closes #1038 from liancheng/sqlCacheTable and squashes the following commits: ecb7194 [Cheng Lian] Trimmed the SQL string before parsing special commands 6f4ce42 [Cheng Lian] Moved logical command classes to a separate file 3458a24 [Cheng Lian] Added comment for public API f0ffacc [Cheng Lian] Added isCached() predicate 15ec6d2 [Cheng Lian] Added "(UN)CACHE TABLE" SQL/HiveQL statements --- .../apache/spark/sql/catalyst/SqlParser.scala | 10 +++- .../catalyst/plans/logical/LogicalPlan.scala | 35 +---------- .../sql/catalyst/plans/logical/commands.scala | 60 +++++++++++++++++++ .../org/apache/spark/sql/SQLContext.scala | 9 +++ .../spark/sql/execution/SparkStrategies.scala | 7 ++- .../apache/spark/sql/execution/commands.scala | 23 +++++++ .../apache/spark/sql/CachedTableSuite.scala | 16 +++++ .../org/apache/spark/sql/hive/HiveQl.scala | 18 +++--- .../org/apache/spark/sql/hive/TestHive.scala | 5 +- .../spark/sql/hive/CachedTableSuite.scala | 16 +++++ 10 files changed, 152 insertions(+), 47 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 36758f3114e59..46fcfbb9e26ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -111,6 +111,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val APPROXIMATE = Keyword("APPROXIMATE") protected val AVG = Keyword("AVG") protected val BY = Keyword("BY") + protected val CACHE = Keyword("CACHE") protected val CAST = Keyword("CAST") protected val COUNT = Keyword("COUNT") protected val DESC = Keyword("DESC") @@ -149,7 +150,9 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val SEMI = Keyword("SEMI") protected val STRING = Keyword("STRING") protected val SUM = Keyword("SUM") + protected val TABLE = Keyword("TABLE") protected val TRUE = Keyword("TRUE") + protected val UNCACHE = Keyword("UNCACHE") protected val UNION = Keyword("UNION") protected val WHERE = Keyword("WHERE") @@ -189,7 +192,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } | UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } ) - | insert + | insert | cache ) protected lazy val select: Parser[LogicalPlan] = @@ -220,6 +223,11 @@ class SqlParser extends StandardTokenParsers with PackratParsers { InsertIntoTable(r, Map[String, Option[String]](), s, overwrite) } + protected lazy val cache: Parser[LogicalPlan] = + (CACHE ^^^ true | UNCACHE ^^^ false) ~ TABLE ~ ident ^^ { + case doCache ~ _ ~ tableName => CacheCommand(tableName, doCache) + } + protected lazy val projections: Parser[Seq[Expression]] = repsep(projection, ",") protected lazy val projection: Parser[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 7eeb98aea6368..0933a31c362d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.types.{StringType, StructType} +import org.apache.spark.sql.catalyst.types.StructType import org.apache.spark.sql.catalyst.trees abstract class LogicalPlan extends QueryPlan[LogicalPlan] { @@ -96,39 +96,6 @@ abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] { def references = Set.empty } -/** - * A logical node that represents a non-query command to be executed by the system. For example, - * commands can be used by parsers to represent DDL operations. - */ -abstract class Command extends LeafNode { - self: Product => - def output: Seq[Attribute] = Seq.empty // TODO: SPARK-2081 should fix this -} - -/** - * Returned for commands supported by a given parser, but not catalyst. In general these are DDL - * commands that are passed directly to another system. - */ -case class NativeCommand(cmd: String) extends Command - -/** - * Commands of the form "SET (key) (= value)". - */ -case class SetCommand(key: Option[String], value: Option[String]) extends Command { - override def output = Seq( - AttributeReference("key", StringType, nullable = false)(), - AttributeReference("value", StringType, nullable = false)() - ) -} - -/** - * Returned by a parser when the users only wants to see what query plan would be executed, without - * actually performing the execution. - */ -case class ExplainCommand(plan: LogicalPlan) extends Command { - override def output = Seq(AttributeReference("plan", StringType, nullable = false)()) -} - /** * A logical plan node with single child. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala new file mode 100644 index 0000000000000..d05c9652753e0 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} +import org.apache.spark.sql.catalyst.types.StringType + +/** + * A logical node that represents a non-query command to be executed by the system. For example, + * commands can be used by parsers to represent DDL operations. + */ +abstract class Command extends LeafNode { + self: Product => + def output: Seq[Attribute] = Seq.empty // TODO: SPARK-2081 should fix this +} + +/** + * Returned for commands supported by a given parser, but not catalyst. In general these are DDL + * commands that are passed directly to another system. + */ +case class NativeCommand(cmd: String) extends Command + +/** + * Commands of the form "SET (key) (= value)". + */ +case class SetCommand(key: Option[String], value: Option[String]) extends Command { + override def output = Seq( + AttributeReference("key", StringType, nullable = false)(), + AttributeReference("value", StringType, nullable = false)() + ) +} + +/** + * Returned by a parser when the users only wants to see what query plan would be executed, without + * actually performing the execution. + */ +case class ExplainCommand(plan: LogicalPlan) extends Command { + override def output = Seq(AttributeReference("plan", StringType, nullable = false)()) +} + +/** + * Returned for the "CACHE TABLE tableName" and "UNCACHE TABLE tableName" command. + */ +case class CacheCommand(tableName: String, doCache: Boolean) extends Command + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 021e0e8245a0d..264192ed1aa26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -188,6 +188,15 @@ class SQLContext(@transient val sparkContext: SparkContext) } } + /** Returns true if the table is currently cached in-memory. */ + def isCached(tableName: String): Boolean = { + val relation = catalog.lookupRelation(None, tableName) + EliminateAnalysisOperators(relation) match { + case SparkLogicalPlan(_: InMemoryColumnarTableScan) => true + case _ => false + } + } + protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext = self.sparkContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0455748d40eec..f2f95dfe27e69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -239,10 +239,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.SetCommand(key, value) => Seq(execution.SetCommandPhysical(key, value, plan.output)(context)) case logical.ExplainCommand(child) => - val qe = context.executePlan(child) - Seq(execution.ExplainCommandPhysical(qe.executedPlan, plan.output)(context)) + val executedPlan = context.executePlan(child).executedPlan + Seq(execution.ExplainCommandPhysical(executedPlan, plan.output)(context)) + case logical.CacheCommand(tableName, cache) => + Seq(execution.CacheCommandPhysical(tableName, cache)(context)) case _ => Nil } } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 9364506691f38..be26d19e66862 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -65,3 +65,26 @@ case class ExplainCommandPhysical(child: SparkPlan, output: Seq[Attribute]) override def otherCopyArgs = context :: Nil } + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class CacheCommandPhysical(tableName: String, doCache: Boolean)(@transient context: SQLContext) + extends LeafNode { + + lazy val commandSideEffect = { + if (doCache) { + context.cacheTable(tableName) + } else { + context.uncacheTable(tableName) + } + } + + override def execute(): RDD[Row] = { + commandSideEffect + context.emptyResult + } + + override def output: Seq[Attribute] = Seq.empty +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 0331f90272a99..ebca3adc2ff01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -70,4 +70,20 @@ class CachedTableSuite extends QueryTest { TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key") TestSQLContext.uncacheTable("testData") } + + test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { + TestSQLContext.sql("CACHE TABLE testData") + TestSQLContext.table("testData").queryExecution.executedPlan match { + case _: InMemoryColumnarTableScan => // Found evidence of caching + case _ => fail(s"Table 'testData' should be cached") + } + assert(TestSQLContext.isCached("testData"), "Table 'testData' should be cached") + + TestSQLContext.sql("UNCACHE TABLE testData") + TestSQLContext.table("testData").queryExecution.executedPlan match { + case _: InMemoryColumnarTableScan => fail(s"Table 'testData' should not be cached") + case _ => // Found evidence of uncaching + } + assert(!TestSQLContext.isCached("testData"), "Table 'testData' should not be cached") + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 4e74d9bc909fa..b745d8ffd8f17 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -218,15 +218,19 @@ private[hive] object HiveQl { case Array(key, value) => // "set key=value" SetCommand(Some(key), Some(value)) } - } else if (sql.toLowerCase.startsWith("add jar")) { + } else if (sql.trim.toLowerCase.startsWith("cache table")) { + CacheCommand(sql.drop(12).trim, true) + } else if (sql.trim.toLowerCase.startsWith("uncache table")) { + CacheCommand(sql.drop(14).trim, false) + } else if (sql.trim.toLowerCase.startsWith("add jar")) { AddJar(sql.drop(8)) - } else if (sql.toLowerCase.startsWith("add file")) { + } else if (sql.trim.toLowerCase.startsWith("add file")) { AddFile(sql.drop(9)) - } else if (sql.startsWith("dfs")) { + } else if (sql.trim.startsWith("dfs")) { DfsCommand(sql) - } else if (sql.startsWith("source")) { + } else if (sql.trim.startsWith("source")) { SourceCommand(sql.split(" ").toSeq match { case Seq("source", filePath) => filePath }) - } else if (sql.startsWith("!")) { + } else if (sql.trim.startsWith("!")) { ShellCommand(sql.drop(1)) } else { val tree = getAst(sql) @@ -839,11 +843,11 @@ private[hive] object HiveQl { case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => SumDistinct(nodeToExpr(arg)) case Token("TOK_FUNCTION", Token(MAX(), Nil) :: arg :: Nil) => Max(nodeToExpr(arg)) case Token("TOK_FUNCTION", Token(MIN(), Nil) :: arg :: Nil) => Min(nodeToExpr(arg)) - + /* System functions about string operations */ case Token("TOK_FUNCTION", Token(UPPER(), Nil) :: arg :: Nil) => Upper(nodeToExpr(arg)) case Token("TOK_FUNCTION", Token(LOWER(), Nil) :: arg :: Nil) => Lower(nodeToExpr(arg)) - + /* Casts */ case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), StringType) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 041e813598d1b..9386008d02d51 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -32,7 +32,7 @@ import org.apache.hadoop.hive.serde2.avro.AvroSerDe import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, NativeCommand} +import org.apache.spark.sql.catalyst.plans.logical.{CacheCommand, LogicalPlan, NativeCommand} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.hive._ @@ -103,7 +103,7 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) { val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) } else { - new File("sql" + File.separator + "hive" + File.separator + "src" + File.separator + "test" + + new File("sql" + File.separator + "hive" + File.separator + "src" + File.separator + "test" + File.separator + "resources") } @@ -130,6 +130,7 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) { override lazy val analyzed = { val describedTables = logical match { case NativeCommand(describedTable(tbl)) => tbl :: Nil + case CacheCommand(tbl, _) => tbl :: Nil case _ => Nil } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index f9a162ef4e3c0..91ac03ca30cd7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -56,4 +56,20 @@ class CachedTableSuite extends HiveComparisonTest { TestHive.uncacheTable("src") } } + + test("'CACHE TABLE' and 'UNCACHE TABLE' HiveQL statement") { + TestHive.hql("CACHE TABLE src") + TestHive.table("src").queryExecution.executedPlan match { + case _: InMemoryColumnarTableScan => // Found evidence of caching + case _ => fail(s"Table 'src' should be cached") + } + assert(TestHive.isCached("src"), "Table 'src' should be cached") + + TestHive.hql("UNCACHE TABLE src") + TestHive.table("src").queryExecution.executedPlan match { + case _: InMemoryColumnarTableScan => fail(s"Table 'src' should not be cached") + case _ => // Found evidence of uncaching + } + assert(!TestHive.isCached("src"), "Table 'src' should not be cached") + } } From 0f1dc3a73d6687f0027a68eb9a4a13a7c7848205 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 11 Jun 2014 00:22:40 -0700 Subject: [PATCH 08/38] [SPARK-2091][MLLIB] use numpy.dot instead of ndarray.dot `ndarray.dot` is not available in numpy 1.4. This PR makes pyspark/mllib compatible with numpy 1.4. Author: Xiangrui Meng Closes #1035 from mengxr/numpy-1.4 and squashes the following commits: 7ad2f0c [Xiangrui Meng] use numpy.dot instead of ndarray.dot --- python/pyspark/mllib/_common.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index a411a5d5914e0..e609b60a0f968 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -454,7 +454,7 @@ def _squared_distance(v1, v2): v2 = _convert_vector(v2) if type(v1) == ndarray and type(v2) == ndarray: diff = v1 - v2 - return diff.dot(diff) + return numpy.dot(diff, diff) elif type(v1) == ndarray: return v2.squared_distance(v1) else: @@ -469,10 +469,12 @@ def _dot(vec, target): calling numpy.dot of the two vectors, but for SciPy ones, we have to transpose them because they're column vectors. """ - if type(vec) == ndarray or type(vec) == SparseVector: + if type(vec) == ndarray: + return numpy.dot(vec, target) + elif type(vec) == SparseVector: return vec.dot(target) elif type(vec) == list: - return _convert_vector(vec).dot(target) + return numpy.dot(_convert_vector(vec), target) else: return vec.transpose().dot(target)[0] From 6e11930310e3864790d0e30f0df7bf691cbeb85d Mon Sep 17 00:00:00 2001 From: "Qiuzhuang.Lian" Date: Wed, 11 Jun 2014 00:36:06 -0700 Subject: [PATCH 09/38] SPARK-2107: FilterPushdownSuite doesn't need Junit jar. Author: Qiuzhuang.Lian Closes #1046 from Qiuzhuang/master and squashes the following commits: 0a9921a [Qiuzhuang.Lian] SPARK-2107: FilterPushdownSuite doesn't need Junit jar. --- .../spark/sql/catalyst/optimizer/FilterPushdownSuite.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 02cc665f8a8c7..0cada785b6630 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -20,14 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.FullOuter import org.apache.spark.sql.catalyst.plans.LeftOuter import org.apache.spark.sql.catalyst.plans.RightOuter import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.junit.Test class FilterPushdownSuite extends OptimizerTest { From 2a4225dd944441d3f735625bb6bae6fca8fd06ca Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Wed, 11 Jun 2014 07:57:28 -0500 Subject: [PATCH 10/38] SPARK-1639. Tidy up some Spark on YARN code This contains a bunch of small tidyings of the Spark on YARN code. I focused on the yarn stable code. @tgravescs, let me know if you'd like me to make these for the alpha code as well. Author: Sandy Ryza Closes #561 from sryza/sandy-spark-1639 and squashes the following commits: 72b6a02 [Sandy Ryza] Fix comment and set name on driver thread c2190b2 [Sandy Ryza] SPARK-1639. Tidy up some Spark on YARN code --- .../spark/deploy/yarn/ApplicationMaster.scala | 16 +- .../apache/spark/deploy/yarn/ClientBase.scala | 38 ++-- .../deploy/yarn/ExecutorRunnableUtil.scala | 28 +-- .../cluster/YarnClusterScheduler.scala | 10 +- .../spark/deploy/yarn/ApplicationMaster.scala | 197 +++++++++--------- .../org/apache/spark/deploy/yarn/Client.scala | 10 +- .../spark/deploy/yarn/ExecutorLauncher.scala | 40 ++-- 7 files changed, 161 insertions(+), 178 deletions(-) diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 8f0ecb855718e..1cc9c33cd2d02 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -277,7 +277,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, yarnAllocator.allocateContainers( math.max(args.numExecutors - yarnAllocator.getNumExecutorsRunning, 0)) ApplicationMaster.incrementAllocatorLoop(1) - Thread.sleep(100) + Thread.sleep(ApplicationMaster.ALLOCATE_HEARTBEAT_INTERVAL) } } finally { // In case of exceptions, etc - ensure that count is at least ALLOCATOR_LOOP_WAIT_COUNT, @@ -416,6 +416,7 @@ object ApplicationMaster { // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be // optimal as more containers are available. Might need to handle this better. private val ALLOCATOR_LOOP_WAIT_COUNT = 30 + private val ALLOCATE_HEARTBEAT_INTERVAL = 100 def incrementAllocatorLoop(by: Int) { val count = yarnAllocatorLoop.getAndAdd(by) @@ -467,13 +468,22 @@ object ApplicationMaster { }) } - // Wait for initialization to complete and atleast 'some' nodes can get allocated. + modified + } + + + /** + * Returns when we've either + * 1) received all the requested executors, + * 2) waited ALLOCATOR_LOOP_WAIT_COUNT * ALLOCATE_HEARTBEAT_INTERVAL ms, + * 3) hit an error that causes us to terminate trying to get containers. + */ + def waitForInitialAllocations() { yarnAllocatorLoop.synchronized { while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT) { yarnAllocatorLoop.wait(1000L) } } - modified } def main(argStrings: Array[String]) { diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 801e8b381588f..29a35680c0e72 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -19,7 +19,6 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException} -import java.nio.ByteBuffer import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, ListBuffer, Map} @@ -37,7 +36,7 @@ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.util.{Apps, Records} +import org.apache.hadoop.yarn.util.Records import org.apache.spark.{Logging, SparkConf, SparkContext} /** @@ -169,14 +168,13 @@ trait ClientBase extends Logging { destPath } - def qualifyForLocal(localURI: URI): Path = { + private def qualifyForLocal(localURI: URI): Path = { var qualifiedURI = localURI - // If not specified assume these are in the local filesystem to keep behavior like Hadoop + // If not specified, assume these are in the local filesystem to keep behavior like Hadoop if (qualifiedURI.getScheme() == null) { qualifiedURI = new URI(FileSystem.getLocal(conf).makeQualified(new Path(qualifiedURI)).toString) } - val qualPath = new Path(qualifiedURI) - qualPath + new Path(qualifiedURI) } def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = { @@ -305,13 +303,13 @@ trait ClientBase extends Logging { val amMemory = calculateAMMemory(newApp) - val JAVA_OPTS = ListBuffer[String]() + val javaOpts = ListBuffer[String]() // Add Xmx for AM memory - JAVA_OPTS += "-Xmx" + amMemory + "m" + javaOpts += "-Xmx" + amMemory + "m" val tmpDir = new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) - JAVA_OPTS += "-Djava.io.tmpdir=" + tmpDir + javaOpts += "-Djava.io.tmpdir=" + tmpDir // TODO: Remove once cpuset version is pushed out. // The context is, default gc for server class machines ends up using all cores to do gc - @@ -325,11 +323,11 @@ trait ClientBase extends Logging { if (useConcurrentAndIncrementalGC) { // In our expts, using (default) throughput collector has severe perf ramifications in // multi-tenant machines - JAVA_OPTS += "-XX:+UseConcMarkSweepGC" - JAVA_OPTS += "-XX:+CMSIncrementalMode" - JAVA_OPTS += "-XX:+CMSIncrementalPacing" - JAVA_OPTS += "-XX:CMSIncrementalDutyCycleMin=0" - JAVA_OPTS += "-XX:CMSIncrementalDutyCycle=10" + javaOpts += "-XX:+UseConcMarkSweepGC" + javaOpts += "-XX:+CMSIncrementalMode" + javaOpts += "-XX:+CMSIncrementalPacing" + javaOpts += "-XX:CMSIncrementalDutyCycleMin=0" + javaOpts += "-XX:CMSIncrementalDutyCycle=10" } // SPARK_JAVA_OPTS is deprecated, but for backwards compatibility: @@ -344,22 +342,22 @@ trait ClientBase extends Logging { // If we are being launched in client mode, forward the spark-conf options // onto the executor launcher for ((k, v) <- sparkConf.getAll) { - JAVA_OPTS += "-D" + k + "=" + "\\\"" + v + "\\\"" + javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" } } else { // If we are being launched in standalone mode, capture and forward any spark // system properties (e.g. set by spark-class). for ((k, v) <- sys.props.filterKeys(_.startsWith("spark"))) { - JAVA_OPTS += "-D" + k + "=" + "\\\"" + v + "\\\"" + javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" } - sys.props.get("spark.driver.extraJavaOptions").foreach(opts => JAVA_OPTS += opts) - sys.props.get("spark.driver.libraryPath").foreach(p => JAVA_OPTS += s"-Djava.library.path=$p") + sys.props.get("spark.driver.extraJavaOptions").foreach(opts => javaOpts += opts) + sys.props.get("spark.driver.libraryPath").foreach(p => javaOpts += s"-Djava.library.path=$p") } - JAVA_OPTS += ClientBase.getLog4jConfiguration(localResources) + javaOpts += ClientBase.getLog4jConfiguration(localResources) // Command for the ApplicationMaster val commands = Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server") ++ - JAVA_OPTS ++ + javaOpts ++ Seq(args.amClass, "--class", args.userClass, "--jar ", args.userJar, userArgsToString(args), "--executor-memory", args.executorMemory.toString, diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala index 32f8861dc9503..43dbb2464f929 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records} +import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import org.apache.spark.{Logging, SparkConf} @@ -46,19 +46,19 @@ trait ExecutorRunnableUtil extends Logging { executorCores: Int, localResources: HashMap[String, LocalResource]): List[String] = { // Extra options for the JVM - val JAVA_OPTS = ListBuffer[String]() + val javaOpts = ListBuffer[String]() // Set the JVM memory val executorMemoryString = executorMemory + "m" - JAVA_OPTS += "-Xms" + executorMemoryString + " -Xmx" + executorMemoryString + " " + javaOpts += "-Xms" + executorMemoryString + " -Xmx" + executorMemoryString + " " // Set extra Java options for the executor, if defined sys.props.get("spark.executor.extraJavaOptions").foreach { opts => - JAVA_OPTS += opts + javaOpts += opts } - JAVA_OPTS += "-Djava.io.tmpdir=" + + javaOpts += "-Djava.io.tmpdir=" + new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) - JAVA_OPTS += ClientBase.getLog4jConfiguration(localResources) + javaOpts += ClientBase.getLog4jConfiguration(localResources) // Certain configs need to be passed here because they are needed before the Executor // registers with the Scheduler and transfers the spark configs. Since the Executor backend @@ -66,10 +66,10 @@ trait ExecutorRunnableUtil extends Logging { // authentication settings. sparkConf.getAll. filter { case (k, v) => k.startsWith("spark.auth") || k.startsWith("spark.akka") }. - foreach { case (k, v) => JAVA_OPTS += "-D" + k + "=" + "\\\"" + v + "\\\"" } + foreach { case (k, v) => javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" } sparkConf.getAkkaConf. - foreach { case (k, v) => JAVA_OPTS += "-D" + k + "=" + "\\\"" + v + "\\\"" } + foreach { case (k, v) => javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" } // Commenting it out for now - so that people can refer to the properties if required. Remove // it once cpuset version is pushed out. @@ -88,11 +88,11 @@ trait ExecutorRunnableUtil extends Logging { // multi-tennent machines // The options are based on // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use%20the%20Concurrent%20Low%20Pause%20Collector|outline - JAVA_OPTS += " -XX:+UseConcMarkSweepGC " - JAVA_OPTS += " -XX:+CMSIncrementalMode " - JAVA_OPTS += " -XX:+CMSIncrementalPacing " - JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 " - JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 " + javaOpts += " -XX:+UseConcMarkSweepGC " + javaOpts += " -XX:+CMSIncrementalMode " + javaOpts += " -XX:+CMSIncrementalPacing " + javaOpts += " -XX:CMSIncrementalDutyCycleMin=0 " + javaOpts += " -XX:CMSIncrementalDutyCycle=10 " } */ @@ -104,7 +104,7 @@ trait ExecutorRunnableUtil extends Logging { // TODO: If the OOM is not recoverable by rescheduling it on different node, then do // 'something' to fail job ... akin to blacklisting trackers in mapred ? "-XX:OnOutOfMemoryError='kill %p'") ++ - JAVA_OPTS ++ + javaOpts ++ Seq("org.apache.spark.executor.CoarseGrainedExecutorBackend", masterAddress.toString, slaveId.toString, diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala index a4638cc863611..39cdd2e8a522b 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala @@ -33,10 +33,11 @@ private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) def this(sc: SparkContext) = this(sc, new Configuration()) - // Nothing else for now ... initialize application master : which needs sparkContext to determine how to allocate - // Note that only the first creation of SparkContext influences (and ideally, there must be only one SparkContext, right ?) - // Subsequent creations are ignored - since nodes are already allocated by then. - + // Nothing else for now ... initialize application master : which needs a SparkContext to + // determine how to allocate. + // Note that only the first creation of a SparkContext influences (and ideally, there must be + // only one SparkContext, right ?). Subsequent creations are ignored since executors are already + // allocated by then. // By default, rack is unknown override def getRackForHost(hostPort: String): Option[String] = { @@ -48,6 +49,7 @@ private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) override def postStartHook() { val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc) if (sparkContextInitialized){ + ApplicationMaster.waitForInitialAllocations() // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt Thread.sleep(3000L) } diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 33a60d978c586..6244332f23737 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -19,13 +19,12 @@ package org.apache.spark.deploy.yarn import java.io.IOException import java.util.concurrent.CopyOnWriteArrayList -import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} +import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConversions._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.net.NetUtils import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.protocolrecords._ @@ -33,8 +32,7 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.ipc.YarnRPC -import org.apache.hadoop.yarn.util.{ConverterUtils, Records} +import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.hadoop.yarn.webapp.util.WebAppUtils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} @@ -77,17 +75,18 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, // than user specified and /tmp. System.setProperty("spark.local.dir", getLocalDirs()) - // set the web ui port to be ephemeral for yarn so we don't conflict with + // Set the web ui port to be ephemeral for yarn so we don't conflict with // other spark processes running on the same box System.setProperty("spark.ui.port", "0") - // when running the AM, the Spark master is always "yarn-cluster" + // When running the AM, the Spark master is always "yarn-cluster" System.setProperty("spark.master", "yarn-cluster") - // Use priority 30 as it's higher then HDFS. It's same priority as MapReduce is using. + // Use priority 30 as it's higher than HDFS. It's the same priority MapReduce is using. ShutdownHookManager.get().addShutdownHook(new AppMasterShutdownHook(this), 30) - appAttemptId = getApplicationAttemptId() + appAttemptId = ApplicationMaster.getApplicationAttemptId() + logInfo("ApplicationAttemptId: " + appAttemptId) isLastAMRetry = appAttemptId.getAttemptId() >= maxAppAttempts amClient = AMRMClient.createAMRMClient() amClient.init(yarnConf) @@ -99,7 +98,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, ApplicationMaster.register(this) // Call this to force generation of secret so it gets populated into the - // hadoop UGI. This has to happen before the startUserClass which does a + // Hadoop UGI. This has to happen before the startUserClass which does a // doAs in order for the credentials to be passed on to the executor containers. val securityMgr = new SecurityManager(sparkConf) @@ -121,7 +120,10 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, // Allocate all containers allocateExecutors() - // Wait for the user class to Finish + // Launch thread that will heartbeat to the RM so it won't think the app has died. + launchReporterThread() + + // Wait for the user class to finish userThread.join() System.exit(0) @@ -141,7 +143,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, "spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.params", params) } - /** Get the Yarn approved local directories. */ + // Get the Yarn approved local directories. private def getLocalDirs(): String = { // Hadoop 0.23 and 2.x have different Environment variable names for the // local dirs, so lets check both. We assume one of the 2 is set. @@ -150,18 +152,9 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, .orElse(Option(System.getenv("LOCAL_DIRS"))) localDirs match { - case None => throw new Exception("Yarn Local dirs can't be empty") + case None => throw new Exception("Yarn local dirs can't be empty") case Some(l) => l } - } - - private def getApplicationAttemptId(): ApplicationAttemptId = { - val envs = System.getenv() - val containerIdString = envs.get(ApplicationConstants.Environment.CONTAINER_ID.name()) - val containerId = ConverterUtils.toContainerId(containerIdString) - val appAttemptId = containerId.getApplicationAttemptId() - logInfo("ApplicationAttemptId: " + appAttemptId) - appAttemptId } private def registerApplicationMaster(): RegisterApplicationMasterResponse = { @@ -173,25 +166,23 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, logInfo("Starting the user JAR in a separate Thread") val mainMethod = Class.forName( args.userClass, - false /* initialize */ , + false, Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]]) val t = new Thread { override def run() { - - var successed = false + var succeeded = false try { // Copy - var mainArgs: Array[String] = new Array[String](args.userArgs.size) + val mainArgs = new Array[String](args.userArgs.size) args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size) mainMethod.invoke(null, mainArgs) - // some job script has "System.exit(0)" at the end, for example SparkPi, SparkLR - // userThread will stop here unless it has uncaught exception thrown out - // It need shutdown hook to set SUCCEEDED - successed = true + // Some apps have "System.exit(0)" at the end. The user thread will stop here unless + // it has an uncaught exception thrown out. It needs a shutdown hook to set SUCCEEDED. + succeeded = true } finally { - logDebug("finishing main") + logDebug("Finishing main") isLastAMRetry = true - if (successed) { + if (succeeded) { ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED) } else { ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.FAILED) @@ -199,11 +190,12 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, } } } + t.setName("Driver") t.start() t } - // This need to happen before allocateExecutors() + // This needs to happen before allocateExecutors() private def waitForSparkContextInitialized() { logInfo("Waiting for Spark context initialization") try { @@ -231,7 +223,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, sparkContext.preferredNodeLocationData, sparkContext.getConf) } else { - logWarning("Unable to retrieve SparkContext inspite of waiting for %d, maxNumTries = %d". + logWarning("Unable to retrieve SparkContext in spite of waiting for %d, maxNumTries = %d". format(numTries * waitTime, maxNumTries)) this.yarnAllocator = YarnAllocationHandler.newAllocator( yarnConf, @@ -242,48 +234,37 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, } } } finally { - // In case of exceptions, etc - ensure that count is at least ALLOCATOR_LOOP_WAIT_COUNT : - // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks. - ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT) + // In case of exceptions, etc - ensure that the loop in + // ApplicationMaster#sparkContextInitialized() breaks. + ApplicationMaster.doneWithInitialAllocations() } } private def allocateExecutors() { try { - logInfo("Allocating " + args.numExecutors + " executors.") - // Wait until all containers have finished + logInfo("Requesting" + args.numExecutors + " executors.") + // Wait until all containers have launched yarnAllocator.addResourceRequests(args.numExecutors) yarnAllocator.allocateResources() // Exits the loop if the user thread exits. + + var iters = 0 while (yarnAllocator.getNumExecutorsRunning < args.numExecutors && userThread.isAlive) { checkNumExecutorsFailed() allocateMissingExecutor() yarnAllocator.allocateResources() - ApplicationMaster.incrementAllocatorLoop(1) - Thread.sleep(100) + if (iters == ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT) { + ApplicationMaster.doneWithInitialAllocations() + } + Thread.sleep(ApplicationMaster.ALLOCATE_HEARTBEAT_INTERVAL) + iters += 1 } } finally { - // In case of exceptions, etc - ensure that count is at least ALLOCATOR_LOOP_WAIT_COUNT, - // so that the loop in ApplicationMaster#sparkContextInitialized() breaks. - ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT) + // In case of exceptions, etc - ensure that the loop in + // ApplicationMaster#sparkContextInitialized() breaks. + ApplicationMaster.doneWithInitialAllocations() } logInfo("All executors have launched.") - - // Launch a progress reporter thread, else the app will get killed after expiration - // (def: 10mins) timeout. - if (userThread.isAlive) { - // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses. - val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) - - // we want to be reasonably responsive without causing too many requests to RM. - val schedulerInterval = - sparkConf.getLong("spark.yarn.scheduler.heartbeat.interval-ms", 5000) - - // must be <= timeoutInterval / 2. - val interval = math.min(timeoutInterval / 2, schedulerInterval) - - launchReporterThread(interval) - } } private def allocateMissingExecutor() { @@ -303,47 +284,35 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, } } - private def launchReporterThread(_sleepTime: Long): Thread = { - val sleepTime = if (_sleepTime <= 0) 0 else _sleepTime + private def launchReporterThread(): Thread = { + // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses. + val expiryInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) + + // we want to be reasonably responsive without causing too many requests to RM. + val schedulerInterval = + sparkConf.getLong("spark.yarn.scheduler.heartbeat.interval-ms", 5000) + + // must be <= timeoutInterval / 2. + val interval = math.max(0, math.min(expiryInterval / 2, schedulerInterval)) val t = new Thread { override def run() { while (userThread.isAlive) { checkNumExecutorsFailed() allocateMissingExecutor() - sendProgress() - Thread.sleep(sleepTime) + logDebug("Sending progress") + yarnAllocator.allocateResources() + Thread.sleep(interval) } } } // Setting to daemon status, though this is usually not a good idea. t.setDaemon(true) t.start() - logInfo("Started progress reporter thread - sleep time : " + sleepTime) + logInfo("Started progress reporter thread - heartbeat interval : " + interval) t } - private def sendProgress() { - logDebug("Sending progress") - // Simulated with an allocate request with no nodes requested. - yarnAllocator.allocateResources() - } - - /* - def printContainers(containers: List[Container]) = { - for (container <- containers) { - logInfo("Launching shell command on a new container." - + ", containerId=" + container.getId() - + ", containerNode=" + container.getNodeId().getHost() - + ":" + container.getNodeId().getPort() - + ", containerNodeURI=" + container.getNodeHttpAddress() - + ", containerState" + container.getState() - + ", containerResourceMemory" - + container.getResource().getMemory()) - } - } - */ - def finishApplicationMaster(status: FinalApplicationStatus, diagnostics: String = "") { synchronized { if (isFinished) { @@ -351,7 +320,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, } isFinished = true - logInfo("finishApplicationMaster with " + status) + logInfo("Unregistering ApplicationMaster with " + status) if (registered) { val trackingUrl = sparkConf.get("spark.yarn.historyServer.address", "") amClient.unregisterApplicationMaster(status, diagnostics, trackingUrl) @@ -386,7 +355,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, def run() { logInfo("AppMaster received a signal.") - // we need to clean up staging dir before HDFS is shut down + // We need to clean up staging dir before HDFS is shut down // make sure we don't delete it until this is the last AM if (appMaster.isLastAMRetry) appMaster.cleanupStagingDir() } @@ -401,21 +370,24 @@ object ApplicationMaster { // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be // optimal as more containers are available. Might need to handle this better. private val ALLOCATOR_LOOP_WAIT_COUNT = 30 + private val ALLOCATE_HEARTBEAT_INTERVAL = 100 private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]() val sparkContextRef: AtomicReference[SparkContext] = - new AtomicReference[SparkContext](null /* initialValue */) + new AtomicReference[SparkContext](null) - val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0) + // Variable used to notify the YarnClusterScheduler that it should stop waiting + // for the initial set of executors to be started and get on with its business. + val doneWithInitialAllocationsMonitor = new Object() - def incrementAllocatorLoop(by: Int) { - val count = yarnAllocatorLoop.getAndAdd(by) - if (count >= ALLOCATOR_LOOP_WAIT_COUNT) { - yarnAllocatorLoop.synchronized { - // to wake threads off wait ... - yarnAllocatorLoop.notifyAll() - } + @volatile var isDoneWithInitialAllocations = false + + def doneWithInitialAllocations() { + isDoneWithInitialAllocations = true + doneWithInitialAllocationsMonitor.synchronized { + // to wake threads off wait ... + doneWithInitialAllocationsMonitor.notifyAll() } } @@ -423,7 +395,10 @@ object ApplicationMaster { applicationMasters.add(master) } - // TODO(harvey): See whether this should be discarded - it isn't used anywhere atm... + /** + * Called from YarnClusterScheduler to notify the AM code that a SparkContext has been + * initialized in the user code. + */ def sparkContextInitialized(sc: SparkContext): Boolean = { var modified = false sparkContextRef.synchronized { @@ -431,7 +406,7 @@ object ApplicationMaster { sparkContextRef.notifyAll() } - // Add a shutdown hook - as a best case effort in case users do not call sc.stop or do + // Add a shutdown hook - as a best effort in case users do not call sc.stop or do // System.exit. // Should not really have to do this, but it helps YARN to evict resources earlier. // Not to mention, prevent the Client from declaring failure even though we exited properly. @@ -454,13 +429,29 @@ object ApplicationMaster { }) } - // Wait for initialization to complete and atleast 'some' nodes can get allocated. - yarnAllocatorLoop.synchronized { - while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT) { - yarnAllocatorLoop.wait(1000L) + // Wait for initialization to complete and at least 'some' nodes to get allocated. + modified + } + + /** + * Returns when we've either + * 1) received all the requested executors, + * 2) waited ALLOCATOR_LOOP_WAIT_COUNT * ALLOCATE_HEARTBEAT_INTERVAL ms, + * 3) hit an error that causes us to terminate trying to get containers. + */ + def waitForInitialAllocations() { + doneWithInitialAllocationsMonitor.synchronized { + while (!isDoneWithInitialAllocations) { + doneWithInitialAllocationsMonitor.wait(1000L) } } - modified + } + + def getApplicationAttemptId(): ApplicationAttemptId = { + val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()) + val containerId = ConverterUtils.toContainerId(containerIdString) + val appAttemptId = containerId.getApplicationAttemptId() + appAttemptId } def main(argStrings: Array[String]) { diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 393edd1f2d670..24027618c1f35 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -21,14 +21,12 @@ import java.nio.ByteBuffer import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.DataOutputBuffer -import org.apache.hadoop.yarn.api._ -import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.YarnClient import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC -import org.apache.hadoop.yarn.util.{Apps, Records} +import org.apache.hadoop.yarn.util.Records import org.apache.spark.{Logging, SparkConf} @@ -102,7 +100,7 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa def logClusterResourceDetails() { val clusterMetrics: YarnClusterMetrics = yarnClient.getYarnClusterMetrics - logInfo("Got Cluster metric info from ApplicationsManager (ASM), number of NodeManagers: " + + logInfo("Got Cluster metric info from ResourceManager, number of NodeManagers: " + clusterMetrics.getNumNodeManagers) val queueInfo: QueueInfo = yarnClient.getQueueInfo(args.amQueue) @@ -133,7 +131,7 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa def submitApp(appContext: ApplicationSubmissionContext) = { // Submit the application to the applications manager. - logInfo("Submitting application to ASM") + logInfo("Submitting application to ResourceManager") yarnClient.submitApplication(appContext) } @@ -149,7 +147,7 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa Thread.sleep(interval) val report = yarnClient.getApplicationReport(appId) - logInfo("Application report from ASM: \n" + + logInfo("Application report from ResourceManager: \n" + "\t application identifier: " + appId.toString() + "\n" + "\t appId: " + appId.getId() + "\n" + "\t clientToAMToken: " + report.getClientToAMToken() + "\n" + diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index d93e5bb0225d5..4f8854a09a1e5 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -72,8 +72,8 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp override def preStart() { logInfo("Listen to driver: " + driverUrl) driver = context.actorSelection(driverUrl) - // Send a hello message thus the connection is actually established, - // thus we can monitor Lifecycle Events. + // Send a hello message to establish the connection, after which + // we can monitor Lifecycle Events. driver ! "Hello" context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) } @@ -95,7 +95,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp amClient.init(yarnConf) amClient.start() - appAttemptId = getApplicationAttemptId() + appAttemptId = ApplicationMaster.getApplicationAttemptId() registerApplicationMaster() waitForSparkMaster() @@ -141,18 +141,9 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp } } - private def getApplicationAttemptId(): ApplicationAttemptId = { - val envs = System.getenv() - val containerIdString = envs.get(ApplicationConstants.Environment.CONTAINER_ID.name()) - val containerId = ConverterUtils.toContainerId(containerIdString) - val appAttemptId = containerId.getApplicationAttemptId() - logInfo("ApplicationAttemptId: " + appAttemptId) - appAttemptId - } - private def registerApplicationMaster(): RegisterApplicationMasterResponse = { logInfo("Registering the ApplicationMaster") - // TODO:(Raymond) Find out Spark UI address and fill in here? + // TODO: Find out client's Spark UI address and fill in here? amClient.registerApplicationMaster(Utils.localHostName(), 0, "") } @@ -185,8 +176,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp private def allocateExecutors() { - - // Fixme: should get preferredNodeLocationData from SparkContext, just fake a empty one for now. + // TODO: should get preferredNodeLocationData from SparkContext, just fake a empty one for now. val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map() @@ -198,8 +188,8 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp preferredNodeLocationData, sparkConf) - logInfo("Allocating " + args.numExecutors + " executors.") - // Wait until all containers have finished + logInfo("Requesting " + args.numExecutors + " executors.") + // Wait until all containers have launched yarnAllocator.addResourceRequests(args.numExecutors) yarnAllocator.allocateResources() while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed)) { @@ -221,7 +211,6 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp } } - // TODO: We might want to extend this to allocate more containers in case they die ! private def launchReporterThread(_sleepTime: Long): Thread = { val sleepTime = if (_sleepTime <= 0) 0 else _sleepTime @@ -229,7 +218,8 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp override def run() { while (!driverClosed) { allocateMissingExecutor() - sendProgress() + logDebug("Sending progress") + yarnAllocator.allocateResources() Thread.sleep(sleepTime) } } @@ -241,20 +231,14 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp t } - private def sendProgress() { - logDebug("Sending progress") - // simulated with an allocate request with no nodes requested ... - yarnAllocator.allocateResources() - } - def finishApplicationMaster(status: FinalApplicationStatus) { - logInfo("finish ApplicationMaster with " + status) - amClient.unregisterApplicationMaster(status, "" /* appMessage */ , "" /* appTrackingUrl */) + logInfo("Unregistering ApplicationMaster with " + status) + val trackingUrl = sparkConf.get("spark.yarn.historyServer.address", "") + amClient.unregisterApplicationMaster(status, "" /* appMessage */ , trackingUrl) } } - object ExecutorLauncher { def main(argStrings: Array[String]) { val args = new ApplicationMasterArguments(argStrings) From 5b754b45f301a310e9232f3a5270201ebfc16076 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Wed, 11 Jun 2014 10:47:06 -0700 Subject: [PATCH 11/38] [SPARK-2069] MIMA false positives Fixes SPARK 2070 and 2071 Author: Prashant Sharma Closes #1021 from ScrapCodes/SPARK-2070/package-private-methods and squashes the following commits: 7979a57 [Prashant Sharma] addressed code review comments 558546d [Prashant Sharma] A little fancy error message. 59275ab [Prashant Sharma] SPARK-2071 Mima ignores classes and its members from previous versions too. 0c4ff2b [Prashant Sharma] SPARK-2070 Ignore methods along with annotated classes. --- .gitignore | 2 +- dev/mima | 5 +- project/MimaBuild.scala | 39 ++++-- project/SparkBuild.scala | 15 +++ .../spark/tools/GenerateMIMAIgnore.scala | 123 ++++++++++-------- 5 files changed, 114 insertions(+), 70 deletions(-) diff --git a/.gitignore b/.gitignore index b54a3058de659..4f177c82ae5e0 100644 --- a/.gitignore +++ b/.gitignore @@ -7,7 +7,7 @@ sbt/*.jar .settings .cache -.generated-mima-excludes +.generated-mima* /build/ work/ out/ diff --git a/dev/mima b/dev/mima index ab6bd4469b0e8..b68800d6d0173 100755 --- a/dev/mima +++ b/dev/mima @@ -23,6 +23,9 @@ set -o pipefail FWDIR="$(cd `dirname $0`/..; pwd)" cd $FWDIR +echo -e "q\n" | sbt/sbt oldDeps/update + +export SPARK_CLASSPATH=`find lib_managed \( -name '*spark*jar' -a -type f \) -printf "%p:" ` ./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore echo -e "q\n" | sbt/sbt mima-report-binary-issues | grep -v -e "info.*Resolving" ret_val=$? @@ -31,5 +34,5 @@ if [ $ret_val != 0 ]; then echo "NOTE: Exceptions to binary compatibility can be added in project/MimaExcludes.scala" fi -rm -f .generated-mima-excludes +rm -f .generated-mima* exit $ret_val diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 1477809943573..bb2d73741c3bf 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -15,16 +15,26 @@ * limitations under the License. */ -import com.typesafe.tools.mima.core.{MissingTypesProblem, MissingClassProblem, ProblemFilters} +import com.typesafe.tools.mima.core._ +import com.typesafe.tools.mima.core.MissingClassProblem +import com.typesafe.tools.mima.core.MissingTypesProblem import com.typesafe.tools.mima.core.ProblemFilters._ import com.typesafe.tools.mima.plugin.MimaKeys.{binaryIssueFilters, previousArtifact} import com.typesafe.tools.mima.plugin.MimaPlugin.mimaDefaultSettings import sbt._ object MimaBuild { + + def excludeMember(fullName: String) = Seq( + ProblemFilters.exclude[MissingMethodProblem](fullName), + ProblemFilters.exclude[MissingFieldProblem](fullName), + ProblemFilters.exclude[IncompatibleResultTypeProblem](fullName), + ProblemFilters.exclude[IncompatibleMethTypeProblem](fullName), + ProblemFilters.exclude[IncompatibleFieldTypeProblem](fullName) + ) + // Exclude a single class and its corresponding object - def excludeClass(className: String) = { - Seq( + def excludeClass(className: String) = Seq( excludePackage(className), ProblemFilters.exclude[MissingClassProblem](className), ProblemFilters.exclude[MissingTypesProblem](className), @@ -32,7 +42,7 @@ object MimaBuild { ProblemFilters.exclude[MissingClassProblem](className + "$"), ProblemFilters.exclude[MissingTypesProblem](className + "$") ) - } + // Exclude a Spark class, that is in the package org.apache.spark def excludeSparkClass(className: String) = { excludeClass("org.apache.spark." + className) @@ -49,20 +59,25 @@ object MimaBuild { val defaultExcludes = Seq() // Read package-private excludes from file - val excludeFilePath = (base.getAbsolutePath + "/.generated-mima-excludes") - val excludeFile = file(excludeFilePath) + val classExcludeFilePath = file(base.getAbsolutePath + "/.generated-mima-class-excludes") + val memberExcludeFilePath = file(base.getAbsolutePath + "/.generated-mima-member-excludes") + val ignoredClasses: Seq[String] = - if (!excludeFile.exists()) { + if (!classExcludeFilePath.exists()) { Seq() } else { - IO.read(excludeFile).split("\n") + IO.read(classExcludeFilePath).split("\n") } + val ignoredMembers: Seq[String] = + if (!memberExcludeFilePath.exists()) { + Seq() + } else { + IO.read(memberExcludeFilePath).split("\n") + } - - val externalExcludeFileClasses = ignoredClasses.flatMap(excludeClass) - - defaultExcludes ++ externalExcludeFileClasses ++ MimaExcludes.excludes + defaultExcludes ++ ignoredClasses.flatMap(excludeClass) ++ + ignoredMembers.flatMap(excludeMember) ++ MimaExcludes.excludes } def mimaSettings(sparkHome: File) = mimaDefaultSettings ++ Seq( diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 069913dbaac56..ecd9d7068068d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -59,6 +59,10 @@ object SparkBuild extends Build { lazy val core = Project("core", file("core"), settings = coreSettings) + /** Following project only exists to pull previous artifacts of Spark for generating + Mima ignores. For more information see: SPARK 2071 */ + lazy val oldDeps = Project("oldDeps", file("dev"), settings = oldDepsSettings) + def replDependencies = Seq[ProjectReference](core, graphx, bagel, mllib, sql) ++ maybeHiveRef lazy val repl = Project("repl", file("repl"), settings = replSettings) @@ -598,6 +602,17 @@ object SparkBuild extends Build { } ) + def oldDepsSettings() = Defaults.defaultSettings ++ Seq( + name := "old-deps", + scalaVersion := "2.10.4", + retrieveManaged := true, + retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", + libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", + "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter", + "spark-streaming", "spark-mllib", "spark-bagel", "spark-graphx", + "spark-core").map(sparkPreviousArtifact(_).get intransitive()) + ) + def twitterSettings() = sharedSettings ++ Seq( name := "spark-streaming-twitter", previousArtifact := sparkPreviousArtifact("spark-streaming-twitter"), diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 6a261e19a35cd..03a73f92b275e 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -40,74 +40,78 @@ object GenerateMIMAIgnore { private val classLoader = Thread.currentThread().getContextClassLoader private val mirror = runtimeMirror(classLoader) - private def classesPrivateWithin(packageName: String): Set[String] = { + + private def isDeveloperApi(sym: unv.Symbol) = + sym.annotations.exists(_.tpe =:= unv.typeOf[org.apache.spark.annotation.DeveloperApi]) + + private def isExperimental(sym: unv.Symbol) = + sym.annotations.exists(_.tpe =:= unv.typeOf[org.apache.spark.annotation.Experimental]) + + + private def isPackagePrivate(sym: unv.Symbol) = + !sym.privateWithin.fullName.startsWith("") + + private def isPackagePrivateModule(moduleSymbol: unv.ModuleSymbol) = + !moduleSymbol.privateWithin.fullName.startsWith("") + + /** + * For every class checks via scala reflection if the class itself or contained members + * have DeveloperApi or Experimental annotations or they are package private. + * Returns the tuple of such classes and members. + */ + private def privateWithin(packageName: String): (Set[String], Set[String]) = { val classes = getClasses(packageName) val ignoredClasses = mutable.HashSet[String]() + val ignoredMembers = mutable.HashSet[String]() - def isPackagePrivate(className: String) = { + for (className <- classes) { try { - /* Couldn't figure out if it's possible to determine a-priori whether a given symbol - is a module or class. */ - - val privateAsClass = mirror - .classSymbol(Class.forName(className, false, classLoader)) - .privateWithin - .fullName - .startsWith(packageName) - - val privateAsModule = mirror - .staticModule(className) - .privateWithin - .fullName - .startsWith(packageName) - - privateAsClass || privateAsModule - } catch { - case _: Throwable => { - println("Error determining visibility: " + className) - false + val classSymbol = mirror.classSymbol(Class.forName(className, false, classLoader)) + val moduleSymbol = mirror.staticModule(className) // TODO: see if it is necessary. + val directlyPrivateSpark = + isPackagePrivate(classSymbol) || isPackagePrivateModule(moduleSymbol) + val developerApi = isDeveloperApi(classSymbol) + val experimental = isExperimental(classSymbol) + + /* Inner classes defined within a private[spark] class or object are effectively + invisible, so we account for them as package private. */ + lazy val indirectlyPrivateSpark = { + val maybeOuter = className.toString.takeWhile(_ != '$') + if (maybeOuter != className) { + isPackagePrivate(mirror.classSymbol(Class.forName(maybeOuter, false, classLoader))) || + isPackagePrivateModule(mirror.staticModule(maybeOuter)) + } else { + false + } + } + if (directlyPrivateSpark || indirectlyPrivateSpark || developerApi || experimental) { + ignoredClasses += className + } else { + // check if this class has package-private/annotated members. + ignoredMembers ++= getAnnotatedOrPackagePrivateMembers(classSymbol) } - } - } - def isDeveloperApi(className: String) = { - try { - val clazz = mirror.classSymbol(Class.forName(className, false, classLoader)) - clazz.annotations.exists(_.tpe =:= unv.typeOf[org.apache.spark.annotation.DeveloperApi]) } catch { - case _: Throwable => { - println("Error determining Annotations: " + className) - false - } + case _: Throwable => println("Error instrumenting class:" + className) } } + (ignoredClasses.flatMap(c => Seq(c, c.replace("$", "#"))).toSet, ignoredMembers.toSet) + } - for (className <- classes) { - val directlyPrivateSpark = isPackagePrivate(className) - val developerApi = isDeveloperApi(className) - - /* Inner classes defined within a private[spark] class or object are effectively - invisible, so we account for them as package private. */ - val indirectlyPrivateSpark = { - val maybeOuter = className.toString.takeWhile(_ != '$') - if (maybeOuter != className) { - isPackagePrivate(maybeOuter) - } else { - false - } - } - if (directlyPrivateSpark || indirectlyPrivateSpark || developerApi) { - ignoredClasses += className - } - } - ignoredClasses.flatMap(c => Seq(c, c.replace("$", "#"))).toSet + private def getAnnotatedOrPackagePrivateMembers(classSymbol: unv.ClassSymbol) = { + classSymbol.typeSignature.members + .filter(x => isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x)).map(_.fullName) } def main(args: Array[String]) { - scala.tools.nsc.io.File(".generated-mima-excludes"). - writeAll(classesPrivateWithin("org.apache.spark").mkString("\n")) - println("Created : .generated-mima-excludes in current directory.") + val (privateClasses, privateMembers) = privateWithin("org.apache.spark") + scala.tools.nsc.io.File(".generated-mima-class-excludes"). + writeAll(privateClasses.mkString("\n")) + println("Created : .generated-mima-class-excludes in current directory.") + scala.tools.nsc.io.File(".generated-mima-member-excludes"). + writeAll(privateMembers.mkString("\n")) + println("Created : .generated-mima-member-excludes in current directory.") } @@ -140,10 +144,17 @@ object GenerateMIMAIgnore { * Get all classes in a package from a jar file. */ private def getClassesFromJar(jarPath: String, packageName: String) = { + import scala.collection.mutable val jar = new JarFile(new File(jarPath)) val enums = jar.entries().map(_.getName).filter(_.startsWith(packageName)) - val classes = for (entry <- enums if entry.endsWith(".class")) - yield Class.forName(entry.replace('/', '.').stripSuffix(".class"), false, classLoader) + val classes = mutable.HashSet[Class[_]]() + for (entry <- enums if entry.endsWith(".class")) { + try { + classes += Class.forName(entry.replace('/', '.').stripSuffix(".class"), false, classLoader) + } catch { + case _: Throwable => println("Unable to load:" + entry) + } + } classes } } From e508f599f88baaa31a3498fb0bdbafdbc303119e Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Wed, 11 Jun 2014 10:49:34 -0700 Subject: [PATCH 12/38] [SPARK-2108] Mark SparkContext methods that return block information as developer API's Author: Prashant Sharma Closes #1047 from ScrapCodes/SPARK-2108/mark-as-dev-api and squashes the following commits: 073ee34 [Prashant Sharma] [SPARK-2108] Mark SparkContext methods that return block information as developer API's --- core/src/main/scala/org/apache/spark/SparkContext.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d721aba709600..8bdaf0bf76e85 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -823,9 +823,11 @@ class SparkContext(config: SparkConf) extends Logging { } /** + * :: DeveloperApi :: * Return information about what RDDs are cached, if they are in mem or on disk, how much space * they take, etc. */ + @DeveloperApi def getRDDStorageInfo: Array[RDDInfo] = { StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this) } @@ -837,8 +839,10 @@ class SparkContext(config: SparkConf) extends Logging { def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap /** + * :: DeveloperApi :: * Return information about blocks stored in all of the slaves */ + @DeveloperApi def getExecutorStorageStatus: Array[StorageStatus] = { env.blockManager.master.getStorageStatus } From 4d5c12aa1c54c49377a4bafe3bcc4993d5e1a552 Mon Sep 17 00:00:00 2001 From: Lars Albertsson Date: Wed, 11 Jun 2014 10:54:42 -0700 Subject: [PATCH 13/38] SPARK-2113: awaitTermination() after stop() will hang in Spark Stremaing Author: Lars Albertsson Closes #1001 from lallea/contextwaiter_stopped and squashes the following commits: 93cd314 [Lars Albertsson] Mend StreamingContext stop() followed by awaitTermination(). --- .../org/apache/spark/streaming/ContextWaiter.scala | 1 + .../spark/streaming/StreamingContextSuite.scala | 12 ++++++++++++ 2 files changed, 13 insertions(+) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala index 86753360a07e4..a0aeacbc733bd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala @@ -27,6 +27,7 @@ private[streaming] class ContextWaiter { } def notifyStop() = synchronized { + stopped = true notifyAll() } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index cd86019f63e7e..7b33d3b235466 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -223,6 +223,18 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w } } + test("awaitTermination after stop") { + ssc = new StreamingContext(master, appName, batchDuration) + val inputStream = addInputStream(ssc) + inputStream.map(x => x).register() + + failAfter(10000 millis) { + ssc.start() + ssc.stop() + ssc.awaitTermination() + } + } + test("awaitTermination with error in task") { ssc = new StreamingContext(master, appName, batchDuration) val inputStream = addInputStream(ssc) From 4107cce58c41160a0dc20339621eacdf8a8b1191 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Wed, 11 Jun 2014 12:01:04 -0700 Subject: [PATCH 14/38] [SPARK-2042] Prevent unnecessary shuffle triggered by take() This PR implements `take()` on a `SchemaRDD` by inserting a logical limit that is followed by a `collect()`. This is also accompanied by adding a catalyst optimizer rule for collapsing adjacent limits. Doing so prevents an unnecessary shuffle that is sometimes triggered by `take()`. Author: Sameer Agarwal Closes #1048 from sameeragarwal/master and squashes the following commits: 3eeb848 [Sameer Agarwal] Fixing Tests 1b76ff1 [Sameer Agarwal] Deprecating limit(limitExpr: Expression) in v1.1.0 b723ac4 [Sameer Agarwal] Added limit folding tests a0ff7c4 [Sameer Agarwal] Adding catalyst rule to fold two consecutive limits 8d42d03 [Sameer Agarwal] Implement trigger() as limit() followed by collect() --- .../spark/sql/catalyst/dsl/package.scala | 2 + .../sql/catalyst/optimizer/Optimizer.scala | 13 ++++ .../plans/logical/basicOperators.scala | 4 +- .../optimizer/CombiningLimitsSuite.scala | 71 +++++++++++++++++++ .../org/apache/spark/sql/SchemaRDD.scala | 12 +++- 5 files changed, 97 insertions(+), 5 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 3cf163f9a9a75..d177339d40ae5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -175,6 +175,8 @@ package object dsl { def where(condition: Expression) = Filter(condition, logicalPlan) + def limit(limitExpr: Expression) = Limit(limitExpr, logicalPlan) + def join( otherPlan: LogicalPlan, joinType: JoinType = Inner, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e41fd2db74858..28d1aa2e3aafc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -29,6 +29,8 @@ import org.apache.spark.sql.catalyst.types._ object Optimizer extends RuleExecutor[LogicalPlan] { val batches = + Batch("Combine Limits", FixedPoint(100), + CombineLimits) :: Batch("ConstantFolding", FixedPoint(100), NullPropagation, ConstantFolding, @@ -362,3 +364,14 @@ object SimplifyCasts extends Rule[LogicalPlan] { case Cast(e, dataType) if e.dataType == dataType => e } } + +/** + * Combines two adjacent [[catalyst.plans.logical.Limit Limit]] operators into one, merging the + * expressions into one single expression. + */ +object CombineLimits extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case ll @ Limit(le, nl @ Limit(ne, grandChild)) => + Limit(If(LessThan(ne, le), ne, le), grandChild) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index d3347b622f3d8..b777cf4249196 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -135,9 +135,9 @@ case class Aggregate( def references = (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet } -case class Limit(limit: Expression, child: LogicalPlan) extends UnaryNode { +case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { def output = child.output - def references = limit.references + def references = limitExpr.references } case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala new file mode 100644 index 0000000000000..714f01843c0f5 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class CombiningLimitsSuite extends OptimizerTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Combine Limit", FixedPoint(2), + CombineLimits) :: + Batch("Constant Folding", FixedPoint(3), + NullPropagation, + ConstantFolding, + BooleanSimplification) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + test("limits: combines two limits") { + val originalQuery = + testRelation + .select('a) + .limit(10) + .limit(5) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .limit(5).analyze + + comparePlans(optimized, correctAnswer) + } + + test("limits: combines three limits") { + val originalQuery = + testRelation + .select('a) + .limit(2) + .limit(7) + .limit(5) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .limit(2).analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 8855c4e876917..7ad8edf5a5a6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -178,14 +178,18 @@ class SchemaRDD( def orderBy(sortExprs: SortOrder*): SchemaRDD = new SchemaRDD(sqlContext, Sort(sortExprs, logicalPlan)) + @deprecated("use limit with integer argument", "1.1.0") + def limit(limitExpr: Expression): SchemaRDD = + new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan)) + /** - * Limits the results by the given expressions. + * Limits the results by the given integer. * {{{ * schemaRDD.limit(10) * }}} */ - def limit(limitExpr: Expression): SchemaRDD = - new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan)) + def limit(limitNum: Int): SchemaRDD = + new SchemaRDD(sqlContext, Limit(Literal(limitNum), logicalPlan)) /** * Performs a grouping followed by an aggregation. @@ -374,6 +378,8 @@ class SchemaRDD( override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect() + override def take(num: Int): Array[Row] = limit(num).collect() + // ======================================================================= // Base RDD functions that do NOT change schema // ======================================================================= From ce6deb1e5b4cd40c97730fcf5dc89cb2f624bce2 Mon Sep 17 00:00:00 2001 From: Daoyuan Date: Wed, 11 Jun 2014 12:08:28 -0700 Subject: [PATCH 15/38] [SQL] Code Cleanup: Left Semi Hash Join Some improvement for PR #837, add another case to white list and use `filter` to build result iterator. Author: Daoyuan Closes #1049 from adrian-wang/clean-LeftSemiJoinHash and squashes the following commits: b314d5a [Daoyuan] change hashSet name 27579a9 [Daoyuan] add semijoin to white list and use filter to create new iterator in LeftSemiJoinBNL Signed-off-by: Michael Armbrust --- .../apache/spark/sql/execution/joins.scala | 40 ++++-------------- ...emijoin-0-1631b71327abf75b96116036b977b26c | 0 ...emijoin-1-e7c99e1f46d502edbb0925d75aab5f0c | 11 +++++ ...mijoin-10-ffd4fb3a903a6725ccb97d5451a3fec6 | 0 ...emijoin-11-246a40dcafe077f02397e30d227c330 | 8 ++++ ...mijoin-12-6d93a9d332ba490835b17f261a5467df | 0 ...mijoin-13-18282d38b6efc0017089ab89b661764f | 0 ...mijoin-14-19cfcefb10e1972bec0ffd421cd79de7 | 0 ...emijoin-15-1de6eb3f357bd1c4d02ab4d19d43589 | 11 +++++ ...emijoin-16-d3a72a90515ac4a8d8e9ac923bcda3d | 0 ...mijoin-17-f0f8720cfd11fd71af17b310dc2d1019 | 3 ++ ...mijoin-18-f7b2ce472443982e32d954cbb5c96765 | 0 ...mijoin-19-b1f1c7f701abe81c01e72fb98f0bd13f | 3 ++ ...emijoin-2-deb9c3286ae8e851b1fdb270085b16bc | 0 ...mijoin-20-b7a8ebaeb42b2eaba7d97cadc3fd96c1 | 0 ...mijoin-21-480418a0646cf7260b494b9eb4821bb6 | 0 ...mijoin-22-b6aebd98f7636cda7b24e0bf84d7ba41 | 0 ...mijoin-23-ed730ccdf552c07e7a82ba6b7fd3fbda | 5 +++ ...mijoin-24-d16b37134de78980b2bf96029e8265c3 | 0 ...mijoin-25-be2bd011cc80b480b271a08dbf381e9b | 19 +++++++++ ...mijoin-26-f1d3bab29f1ebafa148dbe3816e1da25 | 0 ...mijoin-27-391c256254d171973f02e7f33672ce1d | 4 ++ ...mijoin-28-b56400f6d9372f353cf7292a2182e963 | 0 ...mijoin-29-9efeef3d3c38e22a74d379978178c4f5 | 14 +++++++ ...emijoin-3-b4d4317dd3a10e18502f20f5c5250389 | 11 +++++ ...mijoin-30-dd901d00fce5898b03a57cbc3028a70a | 0 ...emijoin-31-e5dc4d8185e63e984aa4e3a2e08bd67 | 14 +++++++ ...mijoin-32-23017c7663f2710265a7e2a4a1606d39 | 0 ...mijoin-33-ed499f94c6e6ac847ef5187b3b43bbc5 | 14 +++++++ ...mijoin-34-5e1b832090ab73c141c1167d5b25a490 | 0 ...mijoin-35-8d5731f26232f6e26dd8012461b08d99 | 26 ++++++++++++ ...mijoin-36-b1159823dca8025926407f8aa921238d | 0 ...mijoin-37-a15b9074f999ce836be5329591b968d0 | 29 +++++++++++++ ...emijoin-38-f37547c73a48ce3ba089531b176e6ba | 0 ...mijoin-39-c22a6f11368affcb80a9c80e653a47a8 | 29 +++++++++++++ ...emijoin-4-dfdad5a2742f93e8ea888191460809c0 | 0 ...mijoin-40-32071a51e2ba6e86b1c5e40de55aae63 | 0 ...mijoin-41-cf74f73a33b1af8902b7970a9350b092 | 29 +++++++++++++ ...mijoin-42-6b4257a74fca627785c967c99547f4c0 | 0 ...mijoin-43-e8a166ac2e94bf8d1da0fe91b0db2c81 | 31 ++++++++++++++ ...mijoin-44-945aaa3a24359ef73acab1e99500d5ea | 0 ...mijoin-45-3fd94ffd4f1eb6cf83dcc064599bf12b | 42 +++++++++++++++++++ ...mijoin-46-620e01f81f6e5254b4bbe8fab4043ec0 | 0 ...mijoin-47-f0140e4ee92508ba241f91c157b7af9c | 35 ++++++++++++++++ ...mijoin-48-8a04442e84f99a584c2882d0af8c25d8 | 0 ...mijoin-49-df1d6705d3624be72036318a6b42f04c | 0 ...emijoin-5-d3c2f84a12374b307c58a69aba4ec70d | 22 ++++++++++ ...emijoin-6-90bb51b1330230d10a14fb7517457aa0 | 0 ...emijoin-7-333d72e8bce6d11a35fc7a30418f225b | 0 ...emijoin-8-d46607be851a6f4e27e98cbbefdee994 | 0 ...emijoin-9-f7adaf0f77ce6ff8c3a4807f428d8de2 | 6 +++ .../execution/HiveCompatibilitySuite.scala | 1 + 52 files changed, 374 insertions(+), 33 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/semijoin-0-1631b71327abf75b96116036b977b26c create mode 100644 sql/hive/src/test/resources/golden/semijoin-1-e7c99e1f46d502edbb0925d75aab5f0c create mode 100644 sql/hive/src/test/resources/golden/semijoin-10-ffd4fb3a903a6725ccb97d5451a3fec6 create mode 100644 sql/hive/src/test/resources/golden/semijoin-11-246a40dcafe077f02397e30d227c330 create mode 100644 sql/hive/src/test/resources/golden/semijoin-12-6d93a9d332ba490835b17f261a5467df create mode 100644 sql/hive/src/test/resources/golden/semijoin-13-18282d38b6efc0017089ab89b661764f create mode 100644 sql/hive/src/test/resources/golden/semijoin-14-19cfcefb10e1972bec0ffd421cd79de7 create mode 100644 sql/hive/src/test/resources/golden/semijoin-15-1de6eb3f357bd1c4d02ab4d19d43589 create mode 100644 sql/hive/src/test/resources/golden/semijoin-16-d3a72a90515ac4a8d8e9ac923bcda3d create mode 100644 sql/hive/src/test/resources/golden/semijoin-17-f0f8720cfd11fd71af17b310dc2d1019 create mode 100644 sql/hive/src/test/resources/golden/semijoin-18-f7b2ce472443982e32d954cbb5c96765 create mode 100644 sql/hive/src/test/resources/golden/semijoin-19-b1f1c7f701abe81c01e72fb98f0bd13f create mode 100644 sql/hive/src/test/resources/golden/semijoin-2-deb9c3286ae8e851b1fdb270085b16bc create mode 100644 sql/hive/src/test/resources/golden/semijoin-20-b7a8ebaeb42b2eaba7d97cadc3fd96c1 create mode 100644 sql/hive/src/test/resources/golden/semijoin-21-480418a0646cf7260b494b9eb4821bb6 create mode 100644 sql/hive/src/test/resources/golden/semijoin-22-b6aebd98f7636cda7b24e0bf84d7ba41 create mode 100644 sql/hive/src/test/resources/golden/semijoin-23-ed730ccdf552c07e7a82ba6b7fd3fbda create mode 100644 sql/hive/src/test/resources/golden/semijoin-24-d16b37134de78980b2bf96029e8265c3 create mode 100644 sql/hive/src/test/resources/golden/semijoin-25-be2bd011cc80b480b271a08dbf381e9b create mode 100644 sql/hive/src/test/resources/golden/semijoin-26-f1d3bab29f1ebafa148dbe3816e1da25 create mode 100644 sql/hive/src/test/resources/golden/semijoin-27-391c256254d171973f02e7f33672ce1d create mode 100644 sql/hive/src/test/resources/golden/semijoin-28-b56400f6d9372f353cf7292a2182e963 create mode 100644 sql/hive/src/test/resources/golden/semijoin-29-9efeef3d3c38e22a74d379978178c4f5 create mode 100644 sql/hive/src/test/resources/golden/semijoin-3-b4d4317dd3a10e18502f20f5c5250389 create mode 100644 sql/hive/src/test/resources/golden/semijoin-30-dd901d00fce5898b03a57cbc3028a70a create mode 100644 sql/hive/src/test/resources/golden/semijoin-31-e5dc4d8185e63e984aa4e3a2e08bd67 create mode 100644 sql/hive/src/test/resources/golden/semijoin-32-23017c7663f2710265a7e2a4a1606d39 create mode 100644 sql/hive/src/test/resources/golden/semijoin-33-ed499f94c6e6ac847ef5187b3b43bbc5 create mode 100644 sql/hive/src/test/resources/golden/semijoin-34-5e1b832090ab73c141c1167d5b25a490 create mode 100644 sql/hive/src/test/resources/golden/semijoin-35-8d5731f26232f6e26dd8012461b08d99 create mode 100644 sql/hive/src/test/resources/golden/semijoin-36-b1159823dca8025926407f8aa921238d create mode 100644 sql/hive/src/test/resources/golden/semijoin-37-a15b9074f999ce836be5329591b968d0 create mode 100644 sql/hive/src/test/resources/golden/semijoin-38-f37547c73a48ce3ba089531b176e6ba create mode 100644 sql/hive/src/test/resources/golden/semijoin-39-c22a6f11368affcb80a9c80e653a47a8 create mode 100644 sql/hive/src/test/resources/golden/semijoin-4-dfdad5a2742f93e8ea888191460809c0 create mode 100644 sql/hive/src/test/resources/golden/semijoin-40-32071a51e2ba6e86b1c5e40de55aae63 create mode 100644 sql/hive/src/test/resources/golden/semijoin-41-cf74f73a33b1af8902b7970a9350b092 create mode 100644 sql/hive/src/test/resources/golden/semijoin-42-6b4257a74fca627785c967c99547f4c0 create mode 100644 sql/hive/src/test/resources/golden/semijoin-43-e8a166ac2e94bf8d1da0fe91b0db2c81 create mode 100644 sql/hive/src/test/resources/golden/semijoin-44-945aaa3a24359ef73acab1e99500d5ea create mode 100644 sql/hive/src/test/resources/golden/semijoin-45-3fd94ffd4f1eb6cf83dcc064599bf12b create mode 100644 sql/hive/src/test/resources/golden/semijoin-46-620e01f81f6e5254b4bbe8fab4043ec0 create mode 100644 sql/hive/src/test/resources/golden/semijoin-47-f0140e4ee92508ba241f91c157b7af9c create mode 100644 sql/hive/src/test/resources/golden/semijoin-48-8a04442e84f99a584c2882d0af8c25d8 create mode 100644 sql/hive/src/test/resources/golden/semijoin-49-df1d6705d3624be72036318a6b42f04c create mode 100644 sql/hive/src/test/resources/golden/semijoin-5-d3c2f84a12374b307c58a69aba4ec70d create mode 100644 sql/hive/src/test/resources/golden/semijoin-6-90bb51b1330230d10a14fb7517457aa0 create mode 100644 sql/hive/src/test/resources/golden/semijoin-7-333d72e8bce6d11a35fc7a30418f225b create mode 100644 sql/hive/src/test/resources/golden/semijoin-8-d46607be851a6f4e27e98cbbefdee994 create mode 100644 sql/hive/src/test/resources/golden/semijoin-9-f7adaf0f77ce6ff8c3a4807f428d8de2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 88ff3d49a79b3..8d7a5ba59f96a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -169,7 +169,7 @@ case class LeftSemiJoinHash( def execute() = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashTable = new java.util.HashSet[Row]() + val hashSet = new java.util.HashSet[Row]() var currentRow: Row = null // Create a Hash set of buildKeys @@ -177,43 +177,17 @@ case class LeftSemiJoinHash( currentRow = buildIter.next() val rowKey = buildSideKeyGenerator(currentRow) if(!rowKey.anyNull) { - val keyExists = hashTable.contains(rowKey) + val keyExists = hashSet.contains(rowKey) if (!keyExists) { - hashTable.add(rowKey) + hashSet.add(rowKey) } } } - new Iterator[Row] { - private[this] var currentStreamedRow: Row = _ - private[this] var currentHashMatched: Boolean = false - - private[this] val joinKeys = streamSideKeyGenerator() - - override final def hasNext: Boolean = - streamIter.hasNext && fetchNext() - - override final def next() = { - currentStreamedRow - } - - /** - * Searches the streamed iterator for the next row that has at least one match in hashtable. - * - * @return true if the search is successful, and false the streamed iterator runs out of - * tuples. - */ - private final def fetchNext(): Boolean = { - currentHashMatched = false - while (!currentHashMatched && streamIter.hasNext) { - currentStreamedRow = streamIter.next() - if (!joinKeys(currentStreamedRow).anyNull) { - currentHashMatched = hashTable.contains(joinKeys.currentValue) - } - } - currentHashMatched - } - } + val joinKeys = streamSideKeyGenerator() + streamIter.filter(current => { + !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue) + }) } } } diff --git a/sql/hive/src/test/resources/golden/semijoin-0-1631b71327abf75b96116036b977b26c b/sql/hive/src/test/resources/golden/semijoin-0-1631b71327abf75b96116036b977b26c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-1-e7c99e1f46d502edbb0925d75aab5f0c b/sql/hive/src/test/resources/golden/semijoin-1-e7c99e1f46d502edbb0925d75aab5f0c new file mode 100644 index 0000000000000..2ed47ab83dd02 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-1-e7c99e1f46d502edbb0925d75aab5f0c @@ -0,0 +1,11 @@ +0 val_0 +0 val_0 +0 val_0 +2 val_2 +4 val_4 +5 val_5 +5 val_5 +5 val_5 +8 val_8 +9 val_9 +10 val_10 diff --git a/sql/hive/src/test/resources/golden/semijoin-10-ffd4fb3a903a6725ccb97d5451a3fec6 b/sql/hive/src/test/resources/golden/semijoin-10-ffd4fb3a903a6725ccb97d5451a3fec6 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-11-246a40dcafe077f02397e30d227c330 b/sql/hive/src/test/resources/golden/semijoin-11-246a40dcafe077f02397e30d227c330 new file mode 100644 index 0000000000000..a24bd8c6379e3 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-11-246a40dcafe077f02397e30d227c330 @@ -0,0 +1,8 @@ +0 val_0 +0 val_0 +0 val_0 +4 val_2 +8 val_4 +10 val_5 +10 val_5 +10 val_5 diff --git a/sql/hive/src/test/resources/golden/semijoin-12-6d93a9d332ba490835b17f261a5467df b/sql/hive/src/test/resources/golden/semijoin-12-6d93a9d332ba490835b17f261a5467df new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-13-18282d38b6efc0017089ab89b661764f b/sql/hive/src/test/resources/golden/semijoin-13-18282d38b6efc0017089ab89b661764f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-14-19cfcefb10e1972bec0ffd421cd79de7 b/sql/hive/src/test/resources/golden/semijoin-14-19cfcefb10e1972bec0ffd421cd79de7 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-15-1de6eb3f357bd1c4d02ab4d19d43589 b/sql/hive/src/test/resources/golden/semijoin-15-1de6eb3f357bd1c4d02ab4d19d43589 new file mode 100644 index 0000000000000..03c61a908b071 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-15-1de6eb3f357bd1c4d02ab4d19d43589 @@ -0,0 +1,11 @@ +val_0 +val_0 +val_0 +val_10 +val_2 +val_4 +val_5 +val_5 +val_5 +val_8 +val_9 diff --git a/sql/hive/src/test/resources/golden/semijoin-16-d3a72a90515ac4a8d8e9ac923bcda3d b/sql/hive/src/test/resources/golden/semijoin-16-d3a72a90515ac4a8d8e9ac923bcda3d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-17-f0f8720cfd11fd71af17b310dc2d1019 b/sql/hive/src/test/resources/golden/semijoin-17-f0f8720cfd11fd71af17b310dc2d1019 new file mode 100644 index 0000000000000..2dcdfd1217ced --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-17-f0f8720cfd11fd71af17b310dc2d1019 @@ -0,0 +1,3 @@ +0 val_0 +0 val_0 +0 val_0 diff --git a/sql/hive/src/test/resources/golden/semijoin-18-f7b2ce472443982e32d954cbb5c96765 b/sql/hive/src/test/resources/golden/semijoin-18-f7b2ce472443982e32d954cbb5c96765 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-19-b1f1c7f701abe81c01e72fb98f0bd13f b/sql/hive/src/test/resources/golden/semijoin-19-b1f1c7f701abe81c01e72fb98f0bd13f new file mode 100644 index 0000000000000..a3670515e8cc2 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-19-b1f1c7f701abe81c01e72fb98f0bd13f @@ -0,0 +1,3 @@ +val_10 +val_8 +val_9 diff --git a/sql/hive/src/test/resources/golden/semijoin-2-deb9c3286ae8e851b1fdb270085b16bc b/sql/hive/src/test/resources/golden/semijoin-2-deb9c3286ae8e851b1fdb270085b16bc new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-20-b7a8ebaeb42b2eaba7d97cadc3fd96c1 b/sql/hive/src/test/resources/golden/semijoin-20-b7a8ebaeb42b2eaba7d97cadc3fd96c1 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-21-480418a0646cf7260b494b9eb4821bb6 b/sql/hive/src/test/resources/golden/semijoin-21-480418a0646cf7260b494b9eb4821bb6 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-22-b6aebd98f7636cda7b24e0bf84d7ba41 b/sql/hive/src/test/resources/golden/semijoin-22-b6aebd98f7636cda7b24e0bf84d7ba41 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-23-ed730ccdf552c07e7a82ba6b7fd3fbda b/sql/hive/src/test/resources/golden/semijoin-23-ed730ccdf552c07e7a82ba6b7fd3fbda new file mode 100644 index 0000000000000..72bc6a6a88f6e --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-23-ed730ccdf552c07e7a82ba6b7fd3fbda @@ -0,0 +1,5 @@ +4 val_2 +8 val_4 +10 val_5 +10 val_5 +10 val_5 diff --git a/sql/hive/src/test/resources/golden/semijoin-24-d16b37134de78980b2bf96029e8265c3 b/sql/hive/src/test/resources/golden/semijoin-24-d16b37134de78980b2bf96029e8265c3 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-25-be2bd011cc80b480b271a08dbf381e9b b/sql/hive/src/test/resources/golden/semijoin-25-be2bd011cc80b480b271a08dbf381e9b new file mode 100644 index 0000000000000..d89ea1757c712 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-25-be2bd011cc80b480b271a08dbf381e9b @@ -0,0 +1,19 @@ +0 +0 +0 +0 +0 +0 +2 +4 +4 +5 +5 +5 +8 +8 +9 +10 +10 +10 +10 diff --git a/sql/hive/src/test/resources/golden/semijoin-26-f1d3bab29f1ebafa148dbe3816e1da25 b/sql/hive/src/test/resources/golden/semijoin-26-f1d3bab29f1ebafa148dbe3816e1da25 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-27-391c256254d171973f02e7f33672ce1d b/sql/hive/src/test/resources/golden/semijoin-27-391c256254d171973f02e7f33672ce1d new file mode 100644 index 0000000000000..dbbdae75a52a4 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-27-391c256254d171973f02e7f33672ce1d @@ -0,0 +1,4 @@ +0 val_0 +0 val_0 +0 val_0 +8 val_8 diff --git a/sql/hive/src/test/resources/golden/semijoin-28-b56400f6d9372f353cf7292a2182e963 b/sql/hive/src/test/resources/golden/semijoin-28-b56400f6d9372f353cf7292a2182e963 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-29-9efeef3d3c38e22a74d379978178c4f5 b/sql/hive/src/test/resources/golden/semijoin-29-9efeef3d3c38e22a74d379978178c4f5 new file mode 100644 index 0000000000000..07c61afb5124b --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-29-9efeef3d3c38e22a74d379978178c4f5 @@ -0,0 +1,14 @@ +0 val_0 0 val_0 +0 val_0 0 val_0 +0 val_0 0 val_0 +0 val_0 0 val_0 +0 val_0 0 val_0 +0 val_0 0 val_0 +0 val_0 0 val_0 +0 val_0 0 val_0 +0 val_0 0 val_0 +4 val_4 4 val_2 +8 val_8 8 val_4 +10 val_10 10 val_5 +10 val_10 10 val_5 +10 val_10 10 val_5 diff --git a/sql/hive/src/test/resources/golden/semijoin-3-b4d4317dd3a10e18502f20f5c5250389 b/sql/hive/src/test/resources/golden/semijoin-3-b4d4317dd3a10e18502f20f5c5250389 new file mode 100644 index 0000000000000..bf51e8f5d9eb5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-3-b4d4317dd3a10e18502f20f5c5250389 @@ -0,0 +1,11 @@ +0 val_0 +0 val_0 +0 val_0 +4 val_2 +8 val_4 +10 val_5 +10 val_5 +10 val_5 +16 val_8 +18 val_9 +20 val_10 diff --git a/sql/hive/src/test/resources/golden/semijoin-30-dd901d00fce5898b03a57cbc3028a70a b/sql/hive/src/test/resources/golden/semijoin-30-dd901d00fce5898b03a57cbc3028a70a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-31-e5dc4d8185e63e984aa4e3a2e08bd67 b/sql/hive/src/test/resources/golden/semijoin-31-e5dc4d8185e63e984aa4e3a2e08bd67 new file mode 100644 index 0000000000000..d6283e34d8ffc --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-31-e5dc4d8185e63e984aa4e3a2e08bd67 @@ -0,0 +1,14 @@ +0 val_0 +0 val_0 +0 val_0 +0 val_0 +0 val_0 +0 val_0 +2 val_2 +4 val_4 +5 val_5 +5 val_5 +5 val_5 +8 val_8 +9 val_9 +10 val_10 diff --git a/sql/hive/src/test/resources/golden/semijoin-32-23017c7663f2710265a7e2a4a1606d39 b/sql/hive/src/test/resources/golden/semijoin-32-23017c7663f2710265a7e2a4a1606d39 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-33-ed499f94c6e6ac847ef5187b3b43bbc5 b/sql/hive/src/test/resources/golden/semijoin-33-ed499f94c6e6ac847ef5187b3b43bbc5 new file mode 100644 index 0000000000000..080180f9d0f0e --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-33-ed499f94c6e6ac847ef5187b3b43bbc5 @@ -0,0 +1,14 @@ +0 +0 +0 +0 +0 +0 +4 +4 +8 +8 +10 +10 +10 +10 diff --git a/sql/hive/src/test/resources/golden/semijoin-34-5e1b832090ab73c141c1167d5b25a490 b/sql/hive/src/test/resources/golden/semijoin-34-5e1b832090ab73c141c1167d5b25a490 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-35-8d5731f26232f6e26dd8012461b08d99 b/sql/hive/src/test/resources/golden/semijoin-35-8d5731f26232f6e26dd8012461b08d99 new file mode 100644 index 0000000000000..4a64d5c625790 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-35-8d5731f26232f6e26dd8012461b08d99 @@ -0,0 +1,26 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +4 +4 +8 +8 +10 +10 +10 +10 diff --git a/sql/hive/src/test/resources/golden/semijoin-36-b1159823dca8025926407f8aa921238d b/sql/hive/src/test/resources/golden/semijoin-36-b1159823dca8025926407f8aa921238d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-37-a15b9074f999ce836be5329591b968d0 b/sql/hive/src/test/resources/golden/semijoin-37-a15b9074f999ce836be5329591b968d0 new file mode 100644 index 0000000000000..1420c786fb228 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-37-a15b9074f999ce836be5329591b968d0 @@ -0,0 +1,29 @@ +NULL +NULL +NULL +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +4 +4 +8 +8 +10 +10 +10 +10 diff --git a/sql/hive/src/test/resources/golden/semijoin-38-f37547c73a48ce3ba089531b176e6ba b/sql/hive/src/test/resources/golden/semijoin-38-f37547c73a48ce3ba089531b176e6ba new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-39-c22a6f11368affcb80a9c80e653a47a8 b/sql/hive/src/test/resources/golden/semijoin-39-c22a6f11368affcb80a9c80e653a47a8 new file mode 100644 index 0000000000000..1420c786fb228 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-39-c22a6f11368affcb80a9c80e653a47a8 @@ -0,0 +1,29 @@ +NULL +NULL +NULL +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +4 +4 +8 +8 +10 +10 +10 +10 diff --git a/sql/hive/src/test/resources/golden/semijoin-4-dfdad5a2742f93e8ea888191460809c0 b/sql/hive/src/test/resources/golden/semijoin-4-dfdad5a2742f93e8ea888191460809c0 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-40-32071a51e2ba6e86b1c5e40de55aae63 b/sql/hive/src/test/resources/golden/semijoin-40-32071a51e2ba6e86b1c5e40de55aae63 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-41-cf74f73a33b1af8902b7970a9350b092 b/sql/hive/src/test/resources/golden/semijoin-41-cf74f73a33b1af8902b7970a9350b092 new file mode 100644 index 0000000000000..aef9483bb0bc9 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-41-cf74f73a33b1af8902b7970a9350b092 @@ -0,0 +1,29 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +4 +4 +8 +8 +10 +10 +10 +10 +16 +18 +20 diff --git a/sql/hive/src/test/resources/golden/semijoin-42-6b4257a74fca627785c967c99547f4c0 b/sql/hive/src/test/resources/golden/semijoin-42-6b4257a74fca627785c967c99547f4c0 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-43-e8a166ac2e94bf8d1da0fe91b0db2c81 b/sql/hive/src/test/resources/golden/semijoin-43-e8a166ac2e94bf8d1da0fe91b0db2c81 new file mode 100644 index 0000000000000..0bc413ef2e09e --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-43-e8a166ac2e94bf8d1da0fe91b0db2c81 @@ -0,0 +1,31 @@ +NULL +NULL +NULL +NULL +NULL +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +4 +4 +8 +8 +10 +10 +10 +10 diff --git a/sql/hive/src/test/resources/golden/semijoin-44-945aaa3a24359ef73acab1e99500d5ea b/sql/hive/src/test/resources/golden/semijoin-44-945aaa3a24359ef73acab1e99500d5ea new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-45-3fd94ffd4f1eb6cf83dcc064599bf12b b/sql/hive/src/test/resources/golden/semijoin-45-3fd94ffd4f1eb6cf83dcc064599bf12b new file mode 100644 index 0000000000000..3131e64446f66 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-45-3fd94ffd4f1eb6cf83dcc064599bf12b @@ -0,0 +1,42 @@ +NULL +NULL +NULL +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +2 +4 +4 +5 +5 +5 +8 +8 +9 +10 +10 +10 +10 +10 +10 +10 +10 +10 +10 +10 +10 diff --git a/sql/hive/src/test/resources/golden/semijoin-46-620e01f81f6e5254b4bbe8fab4043ec0 b/sql/hive/src/test/resources/golden/semijoin-46-620e01f81f6e5254b4bbe8fab4043ec0 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-47-f0140e4ee92508ba241f91c157b7af9c b/sql/hive/src/test/resources/golden/semijoin-47-f0140e4ee92508ba241f91c157b7af9c new file mode 100644 index 0000000000000..ff30bedb81861 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-47-f0140e4ee92508ba241f91c157b7af9c @@ -0,0 +1,35 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +4 +4 +8 +8 +10 +10 +10 +10 +10 +10 +10 +10 +10 +10 +16 +18 +20 diff --git a/sql/hive/src/test/resources/golden/semijoin-48-8a04442e84f99a584c2882d0af8c25d8 b/sql/hive/src/test/resources/golden/semijoin-48-8a04442e84f99a584c2882d0af8c25d8 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-49-df1d6705d3624be72036318a6b42f04c b/sql/hive/src/test/resources/golden/semijoin-49-df1d6705d3624be72036318a6b42f04c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-5-d3c2f84a12374b307c58a69aba4ec70d b/sql/hive/src/test/resources/golden/semijoin-5-d3c2f84a12374b307c58a69aba4ec70d new file mode 100644 index 0000000000000..60f6eacee9b14 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-5-d3c2f84a12374b307c58a69aba4ec70d @@ -0,0 +1,22 @@ +0 val_0 +0 val_0 +0 val_0 +0 val_0 +0 val_0 +0 val_0 +2 val_2 +4 val_2 +4 val_4 +5 val_5 +5 val_5 +5 val_5 +8 val_4 +8 val_8 +9 val_9 +10 val_10 +10 val_5 +10 val_5 +10 val_5 +16 val_8 +18 val_9 +20 val_10 diff --git a/sql/hive/src/test/resources/golden/semijoin-6-90bb51b1330230d10a14fb7517457aa0 b/sql/hive/src/test/resources/golden/semijoin-6-90bb51b1330230d10a14fb7517457aa0 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-7-333d72e8bce6d11a35fc7a30418f225b b/sql/hive/src/test/resources/golden/semijoin-7-333d72e8bce6d11a35fc7a30418f225b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-8-d46607be851a6f4e27e98cbbefdee994 b/sql/hive/src/test/resources/golden/semijoin-8-d46607be851a6f4e27e98cbbefdee994 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/semijoin-9-f7adaf0f77ce6ff8c3a4807f428d8de2 b/sql/hive/src/test/resources/golden/semijoin-9-f7adaf0f77ce6ff8c3a4807f428d8de2 new file mode 100644 index 0000000000000..5baaac9bebf6d --- /dev/null +++ b/sql/hive/src/test/resources/golden/semijoin-9-f7adaf0f77ce6ff8c3a4807f428d8de2 @@ -0,0 +1,6 @@ +0 val_0 +0 val_0 +0 val_0 +4 val_4 +8 val_8 +10 val_10 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index fb8f272d5abfe..3581617c269a6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -597,6 +597,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "select_unquote_and", "select_unquote_not", "select_unquote_or", + "semijoin", "serde_regex", "serde_reported_schema", "set_variable_sub", From fe78b8b6f7e3fe519659134c6fcaf7344077ead8 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 11 Jun 2014 12:11:46 -0700 Subject: [PATCH 16/38] HOTFIX: A few PySpark tests were not actually run This is a hot fix for the hot fix in fb499be1ac935b6f91046ec8ff23ac1267c82342. The changes in that commit did not actually cause the `doctest` module in python to be loaded for the following tests: - pyspark/broadcast.py - pyspark/accumulators.py - pyspark/serializers.py (@pwendell I might have told you the wrong thing) Author: Andrew Or Closes #1053 from andrewor14/python-test-fix and squashes the following commits: d2e5401 [Andrew Or] Explain why these tests are handled differently 0bd6fdd [Andrew Or] Fix 3 pyspark tests not being invoked --- bin/pyspark | 20 +++++++++++++------- python/run-tests | 5 ++++- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/bin/pyspark b/bin/pyspark index 114cbbc3a8a8e..0b5ed40e2157d 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -45,7 +45,7 @@ fi . $FWDIR/bin/load-spark-env.sh # Figure out which Python executable to use -if [ -z "$PYSPARK_PYTHON" ] ; then +if [[ -z "$PYSPARK_PYTHON" ]]; then PYSPARK_PYTHON="python" fi export PYSPARK_PYTHON @@ -59,7 +59,7 @@ export OLD_PYTHONSTARTUP=$PYTHONSTARTUP export PYTHONSTARTUP=$FWDIR/python/pyspark/shell.py # If IPython options are specified, assume user wants to run IPython -if [ -n "$IPYTHON_OPTS" ]; then +if [[ -n "$IPYTHON_OPTS" ]]; then IPYTHON=1 fi @@ -76,6 +76,16 @@ for i in "$@"; do done export PYSPARK_SUBMIT_ARGS +# For pyspark tests +if [[ -n "$SPARK_TESTING" ]]; then + if [[ -n "$PYSPARK_DOC_TEST" ]]; then + exec "$PYSPARK_PYTHON" -m doctest $1 + else + exec "$PYSPARK_PYTHON" $1 + fi + exit +fi + # If a python file is provided, directly run spark-submit. if [[ "$1" =~ \.py$ ]]; then echo -e "\nWARNING: Running python applications through ./bin/pyspark is deprecated as of Spark 1.0." 1>&2 @@ -86,10 +96,6 @@ else if [[ "$IPYTHON" = "1" ]]; then exec ipython $IPYTHON_OPTS else - if [[ -n $SPARK_TESTING ]]; then - exec "$PYSPARK_PYTHON" -m doctest - else - exec "$PYSPARK_PYTHON" - fi + exec "$PYSPARK_PYTHON" fi fi diff --git a/python/run-tests b/python/run-tests index 3b4501178c89f..9282aa47e8375 100755 --- a/python/run-tests +++ b/python/run-tests @@ -44,7 +44,6 @@ function run_test() { echo -en "\033[0m" # No color exit -1 fi - } echo "Running PySpark tests. Output is in python/unit-tests.log." @@ -55,9 +54,13 @@ run_test "pyspark/conf.py" if [ -n "$_RUN_SQL_TESTS" ]; then run_test "pyspark/sql.py" fi +# These tests are included in the module-level docs, and so must +# be handled on a higher level rather than within the python file. +export PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" run_test "pyspark/accumulators.py" run_test "pyspark/serializers.py" +unset PYSPARK_DOC_TEST run_test "pyspark/tests.py" run_test "pyspark/mllib/_common.py" run_test "pyspark/mllib/classification.py" From 14e6dc94f68e57de82841c4ebbb573797a53869c Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 11 Jun 2014 15:54:41 -0700 Subject: [PATCH 17/38] HOTFIX: PySpark tests should be order insensitive. This has been messing up the SQL PySpark tests on Jenkins. Author: Patrick Wendell Closes #1054 from pwendell/pyspark and squashes the following commits: 1eb5487 [Patrick Wendell] False change 06f062d [Patrick Wendell] HOTFIX: PySpark tests should be order insensitive --- python/pyspark/sql.py | 8 ++++---- .../src/main/scala/org/apache/spark/sql/SchemaRDD.scala | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index b4e9618cc25b5..960d0a82448aa 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -117,7 +117,7 @@ def parquetFile(self, path): >>> srdd = sqlCtx.inferSchema(rdd) >>> srdd.saveAsParquetFile(parquetFile) >>> srdd2 = sqlCtx.parquetFile(parquetFile) - >>> srdd.collect() == srdd2.collect() + >>> sorted(srdd.collect()) == sorted(srdd2.collect()) True """ jschema_rdd = self._ssql_ctx.parquetFile(path) @@ -141,7 +141,7 @@ def table(self, tableName): >>> srdd = sqlCtx.inferSchema(rdd) >>> sqlCtx.registerRDDAsTable(srdd, "table1") >>> srdd2 = sqlCtx.table("table1") - >>> srdd.collect() == srdd2.collect() + >>> sorted(srdd.collect()) == sorted(srdd2.collect()) True """ return SchemaRDD(self._ssql_ctx.table(tableName), self) @@ -293,7 +293,7 @@ def saveAsParquetFile(self, path): >>> srdd = sqlCtx.inferSchema(rdd) >>> srdd.saveAsParquetFile(parquetFile) >>> srdd2 = sqlCtx.parquetFile(parquetFile) - >>> srdd2.collect() == srdd.collect() + >>> sorted(srdd2.collect()) == sorted(srdd.collect()) True """ self._jschema_rdd.saveAsParquetFile(path) @@ -307,7 +307,7 @@ def registerAsTable(self, name): >>> srdd = sqlCtx.inferSchema(rdd) >>> srdd.registerAsTable("test") >>> srdd2 = sqlCtx.sql("select * from test") - >>> srdd.collect() == srdd2.collect() + >>> sorted(srdd.collect()) == sorted(srdd2.collect()) True """ self._jschema_rdd.registerAsTable(name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 7ad8edf5a5a6e..44b19bca460b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -33,6 +33,7 @@ import org.apache.spark.api.java.JavaRDD import java.util.{Map => JMap} /** + * ***FALSE CHANGE*** * :: AlphaComponent :: * An RDD of [[Row]] objects that has an associated schema. In addition to standard RDD functions, * SchemaRDDs can be used in relational queries, as shown in the examples below. From d45e0c6b986b54129313fc4a44c741ab4b04462d Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 11 Jun 2014 15:55:41 -0700 Subject: [PATCH 18/38] HOTFIX: Forgot to remove false change in previous commit --- sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 44b19bca460b0..7ad8edf5a5a6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -33,7 +33,6 @@ import org.apache.spark.api.java.JavaRDD import java.util.{Map => JMap} /** - * ***FALSE CHANGE*** * :: AlphaComponent :: * An RDD of [[Row]] objects that has an associated schema. In addition to standard RDD functions, * SchemaRDDs can be used in relational queries, as shown in the examples below. From 9a2448daf984d5bb550dfe0d9e28cbb80ef5cb51 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 11 Jun 2014 17:58:35 -0700 Subject: [PATCH 19/38] [SPARK-2052] [SQL] Add optimization for CaseConversionExpression's. Add optimization for `CaseConversionExpression`'s. Author: Takuya UESHIN Closes #990 from ueshin/issues/SPARK-2052 and squashes the following commits: 2568666 [Takuya UESHIN] Move some rules back. dde7ede [Takuya UESHIN] Add tests to check if ConstantFolding can handle null literals and remove the unneeded rules from NullPropagation. c4eea67 [Takuya UESHIN] Fix toString methods. 23e2363 [Takuya UESHIN] Make CaseConversionExpressions foldable if the child is foldable. 0ff7568 [Takuya UESHIN] Add tests for collapsing case statements. 3977d80 [Takuya UESHIN] Add optimization for CaseConversionExpression's. --- .../expressions/stringOperations.scala | 7 +- .../sql/catalyst/optimizer/Optimizer.scala | 30 +++--- .../optimizer/ConstantFoldingSuite.scala | 61 ++++++++++++- ...mplifyCaseConversionExpressionsSuite.scala | 91 +++++++++++++++++++ 4 files changed, 174 insertions(+), 15 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 420303408451f..c074b7bb01e57 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -76,7 +76,8 @@ trait CaseConversionExpression { type EvaluatedType = Any def convert(v: String): String - + + override def foldable: Boolean = child.foldable def nullable: Boolean = child.nullable def dataType: DataType = StringType @@ -142,6 +143,8 @@ case class RLike(left: Expression, right: Expression) case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { override def convert(v: String): String = v.toUpperCase() + + override def toString() = s"Upper($child)" } /** @@ -150,4 +153,6 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { override def convert(v: String): String = v.toLowerCase() + + override def toString() = s"Lower($child)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 28d1aa2e3aafc..25a347bec0e4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -36,7 +36,8 @@ object Optimizer extends RuleExecutor[LogicalPlan] { ConstantFolding, BooleanSimplification, SimplifyFilters, - SimplifyCasts) :: + SimplifyCasts, + SimplifyCaseConversionExpressions) :: Batch("Filter Pushdown", FixedPoint(100), CombineFilters, PushPredicateThroughProject, @@ -132,18 +133,6 @@ object NullPropagation extends Rule[LogicalPlan] { case Literal(candidate, _) if candidate == v => true case _ => false })) => Literal(true, BooleanType) - case e: UnaryMinus => e.child match { - case Literal(null, _) => Literal(null, e.dataType) - case _ => e - } - case e: Cast => e.child match { - case Literal(null, _) => Literal(null, e.dataType) - case _ => e - } - case e: Not => e.child match { - case Literal(null, _) => Literal(null, e.dataType) - case _ => e - } // Put exceptional cases above if any case e: BinaryArithmetic => e.children match { case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) @@ -375,3 +364,18 @@ object CombineLimits extends Rule[LogicalPlan] { Limit(If(LessThan(ne, le), ne, le), grandChild) } } + +/** + * Removes the inner [[catalyst.expressions.CaseConversionExpression]] that are unnecessary because + * the inner conversion is overwritten by the outer one. + */ +object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsUp { + case Upper(Upper(child)) => Upper(child) + case Upper(Lower(child)) => Upper(child) + case Lower(Upper(child)) => Lower(child) + case Lower(Lower(child)) => Lower(child) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 20dfba847790c..6efc0e211eb21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.types.{DoubleType, IntegerType} +import org.apache.spark.sql.catalyst.types._ // For implicit conversions import org.apache.spark.sql.catalyst.dsl.plans._ @@ -173,4 +173,63 @@ class ConstantFoldingSuite extends OptimizerTest { comparePlans(optimized, correctAnswer) } + + test("Constant folding test: expressions have null literals") { + val originalQuery = + testRelation + .select( + IsNull(Literal(null)) as 'c1, + IsNotNull(Literal(null)) as 'c2, + + GetItem(Literal(null, ArrayType(IntegerType)), 1) as 'c3, + GetItem(Literal(Seq(1), ArrayType(IntegerType)), Literal(null, IntegerType)) as 'c4, + GetField( + Literal(null, StructType(Seq(StructField("a", IntegerType, true)))), + "a") as 'c5, + + UnaryMinus(Literal(null, IntegerType)) as 'c6, + Cast(Literal(null), IntegerType) as 'c7, + Not(Literal(null, BooleanType)) as 'c8, + + Add(Literal(null, IntegerType), 1) as 'c9, + Add(1, Literal(null, IntegerType)) as 'c10, + + Equals(Literal(null, IntegerType), 1) as 'c11, + Equals(1, Literal(null, IntegerType)) as 'c12, + + Like(Literal(null, StringType), "abc") as 'c13, + Like("abc", Literal(null, StringType)) as 'c14, + + Upper(Literal(null, StringType)) as 'c15) + + val optimized = Optimize(originalQuery.analyze) + + val correctAnswer = + testRelation + .select( + Literal(true) as 'c1, + Literal(false) as 'c2, + + Literal(null, IntegerType) as 'c3, + Literal(null, IntegerType) as 'c4, + Literal(null, IntegerType) as 'c5, + + Literal(null, IntegerType) as 'c6, + Literal(null, IntegerType) as 'c7, + Literal(null, BooleanType) as 'c8, + + Literal(null, IntegerType) as 'c9, + Literal(null, IntegerType) as 'c10, + + Literal(null, BooleanType) as 'c11, + Literal(null, BooleanType) as 'c12, + + Literal(null, BooleanType) as 'c13, + Literal(null, BooleanType) as 'c14, + + Literal(null, StringType) as 'c15) + .analyze + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala new file mode 100644 index 0000000000000..df1409fe7baee --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +/* Implicit conversions */ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ + +class SimplifyCaseConversionExpressionsSuite extends OptimizerTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Simplify CaseConversionExpressions", Once, + SimplifyCaseConversionExpressions) :: Nil + } + + val testRelation = LocalRelation('a.string) + + test("simplify UPPER(UPPER(str))") { + val originalQuery = + testRelation + .select(Upper(Upper('a)) as 'u) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .select(Upper('a) as 'u) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("simplify UPPER(LOWER(str))") { + val originalQuery = + testRelation + .select(Upper(Lower('a)) as 'u) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .select(Upper('a) as 'u) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("simplify LOWER(UPPER(str))") { + val originalQuery = + testRelation + .select(Lower(Upper('a)) as 'l) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = testRelation + .select(Lower('a) as 'l) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("simplify LOWER(LOWER(str))") { + val originalQuery = + testRelation + .select(Lower(Lower('a)) as 'l) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = testRelation + .select(Lower('a) as 'l) + .analyze + + comparePlans(optimized, correctAnswer) + } +} From d9203350b06a9c737421ba244162b1365402c01b Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Wed, 11 Jun 2014 18:16:33 -0700 Subject: [PATCH 20/38] [SPARK-1672][MLLIB] Separate user and product partitioning in ALS Some clean up work following #593. 1. Allow to set different number user blocks and number product blocks in `ALS`. 2. Update `MovieLensALS` to reflect the change. Author: Tor Myklebust Author: Xiangrui Meng Closes #1014 from mengxr/SPARK-1672 and squashes the following commits: 0e910dd [Xiangrui Meng] change private[this] to private[recommendation] 36420c7 [Xiangrui Meng] set exclusion rules for ALS 9128b77 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-1672 294efe9 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-1672 9bab77b [Xiangrui Meng] clean up add numUserBlocks and numProductBlocks to MovieLensALS 84c8e8c [Xiangrui Meng] Merge branch 'master' into SPARK-1672 d17a8bf [Xiangrui Meng] merge master a4925fd [Tor Myklebust] Style. bd8a75c [Tor Myklebust] Merge branch 'master' of github.com:apache/spark into alsseppar 021f54b [Tor Myklebust] Separate user and product blocks. dcf583a [Tor Myklebust] Remove the partitioner member variable; instead, thread that needle everywhere it needs to go. 23d6f91 [Tor Myklebust] Stop making the partitioner configurable. 495784f [Tor Myklebust] Merge branch 'master' of https://github.com/apache/spark 674933a [Tor Myklebust] Fix style. 40edc23 [Tor Myklebust] Fix missing space. f841345 [Tor Myklebust] Fix daft bug creating 'pairs', also for -> foreach. 5ec9e6c [Tor Myklebust] Clean a couple of things up using 'map'. 36a0f43 [Tor Myklebust] Make the partitioner private. d872b09 [Tor Myklebust] Add negative id ALS test. df27697 [Tor Myklebust] Support custom partitioners. Currently we use the same partitioner for users and products. c90b6d8 [Tor Myklebust] Scramble user and product ids before bucketing. c774d7d [Tor Myklebust] Make the partitioner a member variable and use it instead of modding directly. --- .../spark/examples/mllib/MovieLensALS.scala | 12 +- .../spark/mllib/recommendation/ALS.scala | 164 +++++++++++------- .../spark/mllib/recommendation/ALSSuite.scala | 60 +++++-- project/MimaExcludes.scala | 8 + 4 files changed, 160 insertions(+), 84 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 6eb41e7ba36fb..28e201d279f41 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -50,6 +50,8 @@ object MovieLensALS { numIterations: Int = 20, lambda: Double = 1.0, rank: Int = 10, + numUserBlocks: Int = -1, + numProductBlocks: Int = -1, implicitPrefs: Boolean = false) def main(args: Array[String]) { @@ -67,8 +69,14 @@ object MovieLensALS { .text(s"lambda (smoothing constant), default: ${defaultParams.lambda}") .action((x, c) => c.copy(lambda = x)) opt[Unit]("kryo") - .text(s"use Kryo serialization") + .text("use Kryo serialization") .action((_, c) => c.copy(kryo = true)) + opt[Int]("numUserBlocks") + .text(s"number of user blocks, default: ${defaultParams.numUserBlocks} (auto)") + .action((x, c) => c.copy(numUserBlocks = x)) + opt[Int]("numProductBlocks") + .text(s"number of product blocks, default: ${defaultParams.numProductBlocks} (auto)") + .action((x, c) => c.copy(numProductBlocks = x)) opt[Unit]("implicitPrefs") .text("use implicit preference") .action((_, c) => c.copy(implicitPrefs = true)) @@ -160,6 +168,8 @@ object MovieLensALS { .setIterations(params.numIterations) .setLambda(params.lambda) .setImplicitPrefs(params.implicitPrefs) + .setUserBlocks(params.numUserBlocks) + .setProductBlocks(params.numProductBlocks) .run(training) val rmse = computeRmse(model, test, params.implicitPrefs) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index d743bd7dd1825..f1ae7b85b4a69 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -61,7 +61,7 @@ private[recommendation] case class InLinkBlock( * A more compact class to represent a rating than Tuple3[Int, Int, Double]. */ @Experimental -case class Rating(val user: Int, val product: Int, val rating: Double) +case class Rating(user: Int, product: Int, rating: Double) /** * Alternating Least Squares matrix factorization. @@ -93,7 +93,8 @@ case class Rating(val user: Int, val product: Int, val rating: Double) * preferences rather than explicit ratings given to items. */ class ALS private ( - private var numBlocks: Int, + private var numUserBlocks: Int, + private var numProductBlocks: Int, private var rank: Int, private var iterations: Int, private var lambda: Double, @@ -106,14 +107,31 @@ class ALS private ( * Constructs an ALS instance with default parameters: {numBlocks: -1, rank: 10, iterations: 10, * lambda: 0.01, implicitPrefs: false, alpha: 1.0}. */ - def this() = this(-1, 10, 10, 0.01, false, 1.0) + def this() = this(-1, -1, 10, 10, 0.01, false, 1.0) /** - * Set the number of blocks to parallelize the computation into; pass -1 for an auto-configured - * number of blocks. Default: -1. + * Set the number of blocks for both user blocks and product blocks to parallelize the computation + * into; pass -1 for an auto-configured number of blocks. Default: -1. */ def setBlocks(numBlocks: Int): ALS = { - this.numBlocks = numBlocks + this.numUserBlocks = numBlocks + this.numProductBlocks = numBlocks + this + } + + /** + * Set the number of user blocks to parallelize the computation. + */ + def setUserBlocks(numUserBlocks: Int): ALS = { + this.numUserBlocks = numUserBlocks + this + } + + /** + * Set the number of product blocks to parallelize the computation. + */ + def setProductBlocks(numProductBlocks: Int): ALS = { + this.numProductBlocks = numProductBlocks this } @@ -176,31 +194,32 @@ class ALS private ( def run(ratings: RDD[Rating]): MatrixFactorizationModel = { val sc = ratings.context - val numBlocks = if (this.numBlocks == -1) { + val numUserBlocks = if (this.numUserBlocks == -1) { math.max(sc.defaultParallelism, ratings.partitions.size / 2) } else { - this.numBlocks + this.numUserBlocks } - - val partitioner = new Partitioner { - val numPartitions = numBlocks - - def getPartition(x: Any): Int = { - Utils.nonNegativeMod(byteswap32(x.asInstanceOf[Int]), numPartitions) - } + val numProductBlocks = if (this.numProductBlocks == -1) { + math.max(sc.defaultParallelism, ratings.partitions.size / 2) + } else { + this.numProductBlocks } - val ratingsByUserBlock = ratings.map{ rating => - (partitioner.getPartition(rating.user), rating) + val userPartitioner = new ALSPartitioner(numUserBlocks) + val productPartitioner = new ALSPartitioner(numProductBlocks) + + val ratingsByUserBlock = ratings.map { rating => + (userPartitioner.getPartition(rating.user), rating) } - val ratingsByProductBlock = ratings.map{ rating => - (partitioner.getPartition(rating.product), + val ratingsByProductBlock = ratings.map { rating => + (productPartitioner.getPartition(rating.product), Rating(rating.product, rating.user, rating.rating)) } - val (userInLinks, userOutLinks) = makeLinkRDDs(numBlocks, ratingsByUserBlock, partitioner) + val (userInLinks, userOutLinks) = + makeLinkRDDs(numUserBlocks, numProductBlocks, ratingsByUserBlock, productPartitioner) val (productInLinks, productOutLinks) = - makeLinkRDDs(numBlocks, ratingsByProductBlock, partitioner) + makeLinkRDDs(numProductBlocks, numUserBlocks, ratingsByProductBlock, userPartitioner) userInLinks.setName("userInLinks") userOutLinks.setName("userOutLinks") productInLinks.setName("productInLinks") @@ -232,27 +251,27 @@ class ALS private ( users.setName(s"users-$iter").persist() val YtY = Some(sc.broadcast(computeYtY(users))) val previousProducts = products - products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda, - alpha, YtY) + products = updateFeatures(numProductBlocks, users, userOutLinks, productInLinks, + userPartitioner, rank, lambda, alpha, YtY) previousProducts.unpersist() logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations)) products.setName(s"products-$iter").persist() val XtX = Some(sc.broadcast(computeYtY(products))) val previousUsers = users - users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda, - alpha, XtX) + users = updateFeatures(numUserBlocks, products, productOutLinks, userInLinks, + productPartitioner, rank, lambda, alpha, XtX) previousUsers.unpersist() } } else { for (iter <- 1 to iterations) { // perform ALS update logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations)) - products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda, - alpha, YtY = None) + products = updateFeatures(numProductBlocks, users, userOutLinks, productInLinks, + userPartitioner, rank, lambda, alpha, YtY = None) products.setName(s"products-$iter") logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations)) - users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda, - alpha, YtY = None) + users = updateFeatures(numUserBlocks, products, productOutLinks, userInLinks, + productPartitioner, rank, lambda, alpha, YtY = None) users.setName(s"users-$iter") } } @@ -340,9 +359,10 @@ class ALS private ( /** * Flatten out blocked user or product factors into an RDD of (id, factor vector) pairs */ - private def unblockFactors(blockedFactors: RDD[(Int, Array[Array[Double]])], - outLinks: RDD[(Int, OutLinkBlock)]) = { - blockedFactors.join(outLinks).flatMap{ case (b, (factors, outLinkBlock)) => + private def unblockFactors( + blockedFactors: RDD[(Int, Array[Array[Double]])], + outLinks: RDD[(Int, OutLinkBlock)]): RDD[(Int, Array[Double])] = { + blockedFactors.join(outLinks).flatMap { case (b, (factors, outLinkBlock)) => for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i)) } } @@ -351,14 +371,14 @@ class ALS private ( * Make the out-links table for a block of the users (or products) dataset given the list of * (user, product, rating) values for the users in that block (or the opposite for products). */ - private def makeOutLinkBlock(numBlocks: Int, ratings: Array[Rating], - partitioner: Partitioner): OutLinkBlock = { + private def makeOutLinkBlock(numProductBlocks: Int, ratings: Array[Rating], + productPartitioner: Partitioner): OutLinkBlock = { val userIds = ratings.map(_.user).distinct.sorted val numUsers = userIds.length val userIdToPos = userIds.zipWithIndex.toMap - val shouldSend = Array.fill(numUsers)(new BitSet(numBlocks)) + val shouldSend = Array.fill(numUsers)(new BitSet(numProductBlocks)) for (r <- ratings) { - shouldSend(userIdToPos(r.user))(partitioner.getPartition(r.product)) = true + shouldSend(userIdToPos(r.user))(productPartitioner.getPartition(r.product)) = true } OutLinkBlock(userIds, shouldSend) } @@ -367,18 +387,17 @@ class ALS private ( * Make the in-links table for a block of the users (or products) dataset given a list of * (user, product, rating) values for the users in that block (or the opposite for products). */ - private def makeInLinkBlock(numBlocks: Int, ratings: Array[Rating], - partitioner: Partitioner): InLinkBlock = { + private def makeInLinkBlock(numProductBlocks: Int, ratings: Array[Rating], + productPartitioner: Partitioner): InLinkBlock = { val userIds = ratings.map(_.user).distinct.sorted - val numUsers = userIds.length val userIdToPos = userIds.zipWithIndex.toMap // Split out our ratings by product block - val blockRatings = Array.fill(numBlocks)(new ArrayBuffer[Rating]) + val blockRatings = Array.fill(numProductBlocks)(new ArrayBuffer[Rating]) for (r <- ratings) { - blockRatings(partitioner.getPartition(r.product)) += r + blockRatings(productPartitioner.getPartition(r.product)) += r } - val ratingsForBlock = new Array[Array[(Array[Int], Array[Double])]](numBlocks) - for (productBlock <- 0 until numBlocks) { + val ratingsForBlock = new Array[Array[(Array[Int], Array[Double])]](numProductBlocks) + for (productBlock <- 0 until numProductBlocks) { // Create an array of (product, Seq(Rating)) ratings val groupedRatings = blockRatings(productBlock).groupBy(_.product).toArray // Sort them by product ID @@ -400,14 +419,16 @@ class ALS private ( * the users (or (blockId, (p, u, r)) for the products). We create these simultaneously to avoid * having to shuffle the (blockId, (u, p, r)) RDD twice, or to cache it. */ - private def makeLinkRDDs(numBlocks: Int, ratings: RDD[(Int, Rating)], partitioner: Partitioner) - : (RDD[(Int, InLinkBlock)], RDD[(Int, OutLinkBlock)]) = - { - val grouped = ratings.partitionBy(new HashPartitioner(numBlocks)) + private def makeLinkRDDs( + numUserBlocks: Int, + numProductBlocks: Int, + ratingsByUserBlock: RDD[(Int, Rating)], + productPartitioner: Partitioner): (RDD[(Int, InLinkBlock)], RDD[(Int, OutLinkBlock)]) = { + val grouped = ratingsByUserBlock.partitionBy(new HashPartitioner(numUserBlocks)) val links = grouped.mapPartitionsWithIndex((blockId, elements) => { - val ratings = elements.map{_._2}.toArray - val inLinkBlock = makeInLinkBlock(numBlocks, ratings, partitioner) - val outLinkBlock = makeOutLinkBlock(numBlocks, ratings, partitioner) + val ratings = elements.map(_._2).toArray + val inLinkBlock = makeInLinkBlock(numProductBlocks, ratings, productPartitioner) + val outLinkBlock = makeOutLinkBlock(numProductBlocks, ratings, productPartitioner) Iterator.single((blockId, (inLinkBlock, outLinkBlock))) }, true) val inLinks = links.mapValues(_._1) @@ -439,26 +460,24 @@ class ALS private ( * It returns an RDD of new feature vectors for each user block. */ private def updateFeatures( + numUserBlocks: Int, products: RDD[(Int, Array[Array[Double]])], productOutLinks: RDD[(Int, OutLinkBlock)], userInLinks: RDD[(Int, InLinkBlock)], - partitioner: Partitioner, + productPartitioner: Partitioner, rank: Int, lambda: Double, alpha: Double, - YtY: Option[Broadcast[DoubleMatrix]]) - : RDD[(Int, Array[Array[Double]])] = - { - val numBlocks = products.partitions.size + YtY: Option[Broadcast[DoubleMatrix]]): RDD[(Int, Array[Array[Double]])] = { productOutLinks.join(products).flatMap { case (bid, (outLinkBlock, factors)) => - val toSend = Array.fill(numBlocks)(new ArrayBuffer[Array[Double]]) - for (p <- 0 until outLinkBlock.elementIds.length; userBlock <- 0 until numBlocks) { + val toSend = Array.fill(numUserBlocks)(new ArrayBuffer[Array[Double]]) + for (p <- 0 until outLinkBlock.elementIds.length; userBlock <- 0 until numUserBlocks) { if (outLinkBlock.shouldSend(p)(userBlock)) { toSend(userBlock) += factors(p) } } toSend.zipWithIndex.map{ case (buf, idx) => (idx, (bid, buf.toArray)) } - }.groupByKey(partitioner) + }.groupByKey(productPartitioner) .join(userInLinks) .mapValues{ case (messages, inLinkBlock) => updateBlock(messages, inLinkBlock, rank, lambda, alpha, YtY) @@ -475,7 +494,7 @@ class ALS private ( { // Sort the incoming block factor messages by block ID and make them an array val blockFactors = messages.toSeq.sortBy(_._1).map(_._2).toArray // Array[Array[Double]] - val numBlocks = blockFactors.length + val numProductBlocks = blockFactors.length val numUsers = inLinkBlock.elementIds.length // We'll sum up the XtXes using vectors that represent only the lower-triangular part, since @@ -490,7 +509,7 @@ class ALS private ( // Compute the XtX and Xy values for each user by adding products it rated in each product // block - for (productBlock <- 0 until numBlocks) { + for (productBlock <- 0 until numProductBlocks) { var p = 0 while (p < blockFactors(productBlock).length) { val x = wrapDoubleArray(blockFactors(productBlock)(p)) @@ -579,6 +598,23 @@ class ALS private ( } } +/** + * Partitioner for ALS. + */ +private[recommendation] class ALSPartitioner(override val numPartitions: Int) extends Partitioner { + override def getPartition(key: Any): Int = { + Utils.nonNegativeMod(byteswap32(key.asInstanceOf[Int]), numPartitions) + } + + override def equals(obj: Any): Boolean = { + obj match { + case p: ALSPartitioner => + this.numPartitions == p.numPartitions + case _ => + false + } + } +} /** * Top-level methods for calling Alternating Least Squares (ALS) matrix factorization. @@ -606,7 +642,7 @@ object ALS { blocks: Int, seed: Long ): MatrixFactorizationModel = { - new ALS(blocks, rank, iterations, lambda, false, 1.0, seed).run(ratings) + new ALS(blocks, blocks, rank, iterations, lambda, false, 1.0, seed).run(ratings) } /** @@ -629,7 +665,7 @@ object ALS { lambda: Double, blocks: Int ): MatrixFactorizationModel = { - new ALS(blocks, rank, iterations, lambda, false, 1.0).run(ratings) + new ALS(blocks, blocks, rank, iterations, lambda, false, 1.0).run(ratings) } /** @@ -689,7 +725,7 @@ object ALS { alpha: Double, seed: Long ): MatrixFactorizationModel = { - new ALS(blocks, rank, iterations, lambda, true, alpha, seed).run(ratings) + new ALS(blocks, blocks, rank, iterations, lambda, true, alpha, seed).run(ratings) } /** @@ -714,7 +750,7 @@ object ALS { blocks: Int, alpha: Double ): MatrixFactorizationModel = { - new ALS(blocks, rank, iterations, lambda, true, alpha).run(ratings) + new ALS(blocks, blocks, rank, iterations, lambda, true, alpha).run(ratings) } /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index 37c9b9d085841..81bebec8c7a39 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -121,6 +121,10 @@ class ALSSuite extends FunSuite with LocalSparkContext { testALS(100, 200, 2, 15, 0.7, 0.4, true, false, true) } + test("rank-2 matrices with different user and product blocks") { + testALS(100, 200, 2, 15, 0.7, 0.4, numUserBlocks = 4, numProductBlocks = 2) + } + test("pseudorandomness") { val ratings = sc.parallelize(ALSSuite.generateRatings(10, 20, 5, 0.5, false, false)._1, 2) val model11 = ALS.train(ratings, 5, 1, 1.0, 2, 1) @@ -153,35 +157,52 @@ class ALSSuite extends FunSuite with LocalSparkContext { } test("NNALS, rank 2") { - testALS(100, 200, 2, 15, 0.7, 0.4, false, false, false, -1, false) + testALS(100, 200, 2, 15, 0.7, 0.4, false, false, false, -1, -1, false) } /** * Test if we can correctly factorize R = U * P where U and P are of known rank. * - * @param users number of users - * @param products number of products - * @param features number of features (rank of problem) - * @param iterations number of iterations to run - * @param samplingRate what fraction of the user-product pairs are known + * @param users number of users + * @param products number of products + * @param features number of features (rank of problem) + * @param iterations number of iterations to run + * @param samplingRate what fraction of the user-product pairs are known * @param matchThreshold max difference allowed to consider a predicted rating correct - * @param implicitPrefs flag to test implicit feedback - * @param bulkPredict flag to test bulk prediciton + * @param implicitPrefs flag to test implicit feedback + * @param bulkPredict flag to test bulk prediciton * @param negativeWeights whether the generated data can contain negative values - * @param numBlocks number of blocks to partition users and products into + * @param numUserBlocks number of user blocks to partition users into + * @param numProductBlocks number of product blocks to partition products into * @param negativeFactors whether the generated user/product factors can have negative entries */ - def testALS(users: Int, products: Int, features: Int, iterations: Int, - samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false, - bulkPredict: Boolean = false, negativeWeights: Boolean = false, numBlocks: Int = -1, - negativeFactors: Boolean = true) - { + def testALS( + users: Int, + products: Int, + features: Int, + iterations: Int, + samplingRate: Double, + matchThreshold: Double, + implicitPrefs: Boolean = false, + bulkPredict: Boolean = false, + negativeWeights: Boolean = false, + numUserBlocks: Int = -1, + numProductBlocks: Int = -1, + negativeFactors: Boolean = true) { val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products, features, samplingRate, implicitPrefs, negativeWeights, negativeFactors) - val model = (new ALS().setBlocks(numBlocks).setRank(features).setIterations(iterations) - .setAlpha(1.0).setImplicitPrefs(implicitPrefs).setLambda(0.01).setSeed(0L) - .setNonnegative(!negativeFactors).run(sc.parallelize(sampledRatings))) + val model = new ALS() + .setUserBlocks(numUserBlocks) + .setProductBlocks(numProductBlocks) + .setRank(features) + .setIterations(iterations) + .setAlpha(1.0) + .setImplicitPrefs(implicitPrefs) + .setLambda(0.01) + .setSeed(0L) + .setNonnegative(!negativeFactors) + .run(sc.parallelize(sampledRatings)) val predictedU = new DoubleMatrix(users, features) for ((u, vec) <- model.userFeatures.collect(); i <- 0 until features) { @@ -208,8 +229,9 @@ class ALSSuite extends FunSuite with LocalSparkContext { val prediction = predictedRatings.get(u, p) val correct = trueRatings.get(u, p) if (math.abs(prediction - correct) > matchThreshold) { - fail("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format( - u, p, correct, prediction, trueRatings, predictedRatings, predictedU, predictedP)) + fail(("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s") + .format(u, p, correct, prediction, trueRatings, predictedRatings, predictedU, + predictedP)) } } } else { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index dd7efceb23c96..c80ab9a9f8e60 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -54,6 +54,14 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1") ) ++ + Seq( // Ignore some private methods in ALS. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), + ProblemFilters.exclude[MissingMethodProblem]( // The only public constructor is the one without arguments. + "org.apache.spark.mllib.recommendation.ALS.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7") + ) ++ MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++ MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++ MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") From 508fd371d6dbb826fd8a00787d347235b549e189 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Wed, 11 Jun 2014 20:45:29 -0700 Subject: [PATCH 21/38] [SPARK-2044] Pluggable interface for shuffles This is a first cut at moving shuffle logic behind a pluggable interface, as described at https://issues.apache.org/jira/browse/SPARK-2044, to let us more easily experiment with new shuffle implementations. It moves the existing shuffle code to a class HashShuffleManager behind a general ShuffleManager interface. Two things are still missing to make this complete: * MapOutputTracker needs to be hidden behind the ShuffleManager interface; this will also require adding methods to ShuffleManager that will let the DAGScheduler interact with it as it does with the MapOutputTracker today * The code to do map-sides and reduce-side combine in ShuffledRDD, PairRDDFunctions, etc needs to be moved into the ShuffleManager's readers and writers However, some of these may also be done later after we merge the current interface. Author: Matei Zaharia Closes #1009 from mateiz/pluggable-shuffle and squashes the following commits: 7a09862 [Matei Zaharia] review comments be33d3f [Matei Zaharia] review comments 1513d4e [Matei Zaharia] Add ASF header ac56831 [Matei Zaharia] Bug fix and better error message 4f681ba [Matei Zaharia] Move write part of ShuffleMapTask to ShuffleManager f6f011d [Matei Zaharia] Move hash shuffle reader behind ShuffleManager interface 55c7717 [Matei Zaharia] Changed RDD code to use ShuffleReader 75cc044 [Matei Zaharia] Partial work to move hash shuffle in --- .../org/apache/spark/ContextCleaner.scala | 2 +- .../scala/org/apache/spark/Dependency.scala | 12 +- .../scala/org/apache/spark/SparkEnv.scala | 28 +++-- .../org/apache/spark/rdd/CoGroupedRDD.scala | 22 ++-- .../org/apache/spark/rdd/ShuffledRDD.scala | 12 +- .../org/apache/spark/rdd/SubtractedRDD.scala | 17 +-- .../apache/spark/scheduler/DAGScheduler.scala | 12 +- .../spark/scheduler/ShuffleMapTask.scala | 73 +++--------- .../org/apache/spark/scheduler/Stage.scala | 2 +- .../spark/scheduler/TaskResultGetter.scala | 3 +- .../apache/spark/serializer/Serializer.scala | 4 + .../BaseShuffleHandle.scala} | 26 ++-- .../apache/spark/shuffle/ShuffleHandle.scala | 25 ++++ .../apache/spark/shuffle/ShuffleManager.scala | 57 +++++++++ .../apache/spark/shuffle/ShuffleReader.scala | 29 +++++ .../apache/spark/shuffle/ShuffleWriter.scala | 31 +++++ .../hash}/BlockStoreShuffleFetcher.scala | 9 +- .../shuffle/hash/HashShuffleManager.scala | 60 ++++++++++ .../shuffle/hash/HashShuffleReader.scala | 42 +++++++ .../shuffle/hash/HashShuffleWriter.scala | 111 ++++++++++++++++++ .../apache/spark/ContextCleanerSuite.scala | 6 +- .../scala/org/apache/spark/ShuffleSuite.scala | 6 +- 22 files changed, 459 insertions(+), 130 deletions(-) rename core/src/main/scala/org/apache/spark/{ShuffleFetcher.scala => shuffle/BaseShuffleHandle.scala} (66%) create mode 100644 core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala create mode 100644 core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala create mode 100644 core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala create mode 100644 core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala rename core/src/main/scala/org/apache/spark/{ => shuffle/hash}/BlockStoreShuffleFetcher.scala (96%) create mode 100644 core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala create mode 100644 core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala create mode 100644 core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index e2d2250982daa..bf3c3a6ceb5ef 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -96,7 +96,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Register a ShuffleDependency for cleanup when it is garbage collected. */ - def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) { + def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]) { registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId)) } diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 2c31cc20211ff..c8c194a111aac 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -20,6 +20,7 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ShuffleHandle /** * :: DeveloperApi :: @@ -50,19 +51,24 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { * Represents a dependency on the output of a shuffle stage. * @param rdd the parent RDD * @param partitioner partitioner used to partition the shuffle output - * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to null, + * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None, * the default serializer, as specified by `spark.serializer` config option, will * be used. */ @DeveloperApi -class ShuffleDependency[K, V]( +class ShuffleDependency[K, V, C]( @transient rdd: RDD[_ <: Product2[K, V]], val partitioner: Partitioner, - val serializer: Serializer = null) + val serializer: Option[Serializer] = None, + val keyOrdering: Option[Ordering[K]] = None, + val aggregator: Option[Aggregator[K, V, C]] = None) extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { val shuffleId: Int = rdd.context.newShuffleId() + val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle( + shuffleId, rdd.partitions.size, this) + rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 720151a6b0f84..8dfa8cc4b5b3f 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -34,6 +34,7 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.ConnectionManager import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage._ import org.apache.spark.util.{AkkaUtils, Utils} @@ -56,7 +57,7 @@ class SparkEnv ( val closureSerializer: Serializer, val cacheManager: CacheManager, val mapOutputTracker: MapOutputTracker, - val shuffleFetcher: ShuffleFetcher, + val shuffleManager: ShuffleManager, val broadcastManager: BroadcastManager, val blockManager: BlockManager, val connectionManager: ConnectionManager, @@ -80,7 +81,7 @@ class SparkEnv ( pythonWorkers.foreach { case(key, worker) => worker.stop() } httpFileServer.stop() mapOutputTracker.stop() - shuffleFetcher.stop() + shuffleManager.stop() broadcastManager.stop() blockManager.stop() blockManager.master.stop() @@ -163,13 +164,20 @@ object SparkEnv extends Logging { def instantiateClass[T](propertyName: String, defaultClassName: String): T = { val name = conf.get(propertyName, defaultClassName) val cls = Class.forName(name, true, Utils.getContextOrSparkClassLoader) - // First try with the constructor that takes SparkConf. If we can't find one, - // use a no-arg constructor instead. + // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just + // SparkConf, then one taking no arguments try { - cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T] + cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE) + .newInstance(conf, new java.lang.Boolean(isDriver)) + .asInstanceOf[T] } catch { case _: NoSuchMethodException => - cls.getConstructor().newInstance().asInstanceOf[T] + try { + cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T] + } catch { + case _: NoSuchMethodException => + cls.getConstructor().newInstance().asInstanceOf[T] + } } } @@ -219,9 +227,6 @@ object SparkEnv extends Logging { val cacheManager = new CacheManager(blockManager) - val shuffleFetcher = instantiateClass[ShuffleFetcher]( - "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher") - val httpFileServer = new HttpFileServer(securityManager) httpFileServer.initialize() conf.set("spark.fileserver.uri", httpFileServer.serverUri) @@ -242,6 +247,9 @@ object SparkEnv extends Logging { "." } + val shuffleManager = instantiateClass[ShuffleManager]( + "spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") + // Warn about deprecated spark.cache.class property if (conf.contains("spark.cache.class")) { logWarning("The spark.cache.class property is no longer being used! Specify storage " + @@ -255,7 +263,7 @@ object SparkEnv extends Logging { closureSerializer, cacheManager, mapOutputTracker, - shuffleFetcher, + shuffleManager, broadcastManager, blockManager, connectionManager, diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 9ff76892aed32..5951865e56c9d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -27,6 +27,7 @@ import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap} import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ShuffleHandle private[spark] sealed trait CoGroupSplitDep extends Serializable @@ -44,7 +45,7 @@ private[spark] case class NarrowCoGroupSplitDep( } } -private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep +private[spark] case class ShuffleCoGroupSplitDep(handle: ShuffleHandle) extends CoGroupSplitDep private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) extends Partition with Serializable { @@ -74,10 +75,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: private type CoGroupValue = (Any, Int) // Int is dependency number private type CoGroupCombiner = Seq[CoGroup] - private var serializer: Serializer = null + private var serializer: Option[Serializer] = None + /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ def setSerializer(serializer: Serializer): CoGroupedRDD[K] = { - this.serializer = serializer + this.serializer = Option(serializer) this } @@ -88,7 +90,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: new OneToOneDependency(rdd) } else { logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency[Any, Any](rdd, part, serializer) + new ShuffleDependency[K, Any, CoGroupCombiner](rdd, part, serializer) } } } @@ -100,8 +102,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) => // Assume each RDD contributed a single dependency, and get it dependencies(j) match { - case s: ShuffleDependency[_, _] => - new ShuffleCoGroupSplitDep(s.shuffleId) + case s: ShuffleDependency[_, _, _] => + new ShuffleCoGroupSplitDep(s.shuffleHandle) case _ => new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) } @@ -126,11 +128,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]] rddIterators += ((it, depNum)) - case ShuffleCoGroupSplitDep(shuffleId) => + case ShuffleCoGroupSplitDep(handle) => // Read map outputs of shuffle - val fetcher = SparkEnv.get.shuffleFetcher - val ser = Serializer.getSerializer(serializer) - val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser) + val it = SparkEnv.get.shuffleManager + .getReader(handle, split.index, split.index + 1, context) + .read() rddIterators += ((it, depNum)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 802b0bdfb2d59..bb108ef163c56 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -42,10 +42,11 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag]( part: Partitioner) extends RDD[P](prev.context, Nil) { - private var serializer: Serializer = null + private var serializer: Option[Serializer] = None + /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = { - this.serializer = serializer + this.serializer = Option(serializer) this } @@ -60,9 +61,10 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag]( } override def compute(split: Partition, context: TaskContext): Iterator[P] = { - val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId - val ser = Serializer.getSerializer(serializer) - SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, ser) + val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, V]] + SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) + .read() + .asInstanceOf[Iterator[P]] } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index 9a09c05bbc959..ed24ea22a661c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -54,10 +54,11 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) { - private var serializer: Serializer = null + private var serializer: Option[Serializer] = None + /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = { - this.serializer = serializer + this.serializer = Option(serializer) this } @@ -79,8 +80,8 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( // Each CoGroupPartition will depend on rdd1 and rdd2 array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) => dependencies(j) match { - case s: ShuffleDependency[_, _] => - new ShuffleCoGroupSplitDep(s.shuffleId) + case s: ShuffleDependency[_, _, _] => + new ShuffleCoGroupSplitDep(s.shuffleHandle) case _ => new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) } @@ -93,7 +94,6 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = { val partition = p.asInstanceOf[CoGroupPartition] - val ser = Serializer.getSerializer(serializer) val map = new JHashMap[K, ArrayBuffer[V]] def getSeq(k: K): ArrayBuffer[V] = { val seq = map.get(k) @@ -109,9 +109,10 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( case NarrowCoGroupSplitDep(rdd, _, itsSplit) => rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op) - case ShuffleCoGroupSplitDep(shuffleId) => - val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index, - context, ser) + case ShuffleCoGroupSplitDep(handle) => + val iter = SparkEnv.get.shuffleManager + .getReader(handle, partition.index, partition.index + 1, context) + .read() iter.foreach(op) } // the first dep is rdd1; add all values to the map diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index e09a4221e8315..3c85b5a2ae776 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -190,7 +190,7 @@ class DAGScheduler( * The jobId value passed in will be used if the stage doesn't already exist with * a lower jobId (jobId always increases across jobs.) */ - private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], jobId: Int): Stage = { + private def getShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): Stage = { shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => @@ -210,7 +210,7 @@ class DAGScheduler( private def newStage( rdd: RDD[_], numTasks: Int, - shuffleDep: Option[ShuffleDependency[_,_]], + shuffleDep: Option[ShuffleDependency[_, _, _]], jobId: Int, callSite: Option[String] = None) : Stage = @@ -233,7 +233,7 @@ class DAGScheduler( private def newOrUsedStage( rdd: RDD[_], numTasks: Int, - shuffleDep: ShuffleDependency[_,_], + shuffleDep: ShuffleDependency[_, _, _], jobId: Int, callSite: Option[String] = None) : Stage = @@ -269,7 +269,7 @@ class DAGScheduler( // we can't do it in its constructor because # of partitions is unknown for (dep <- r.dependencies) { dep match { - case shufDep: ShuffleDependency[_,_] => + case shufDep: ShuffleDependency[_, _, _] => parents += getShuffleMapStage(shufDep, jobId) case _ => visit(dep.rdd) @@ -290,7 +290,7 @@ class DAGScheduler( if (getCacheLocs(rdd).contains(Nil)) { for (dep <- rdd.dependencies) { dep match { - case shufDep: ShuffleDependency[_,_] => + case shufDep: ShuffleDependency[_, _, _] => val mapStage = getShuffleMapStage(shufDep, stage.jobId) if (!mapStage.isAvailable) { missing += mapStage @@ -1088,7 +1088,7 @@ class DAGScheduler( visitedRdds += rdd for (dep <- rdd.dependencies) { dep match { - case shufDep: ShuffleDependency[_,_] => + case shufDep: ShuffleDependency[_, _, _] => val mapStage = getShuffleMapStage(shufDep, stage.jobId) if (!mapStage.isAvailable) { visitedStages += mapStage diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index ed0f56f1abdf5..0098b5a59d1a5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -29,6 +29,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.rdd.{RDD, RDDCheckpointData} import org.apache.spark.serializer.Serializer import org.apache.spark.storage._ +import org.apache.spark.shuffle.ShuffleWriter private[spark] object ShuffleMapTask { @@ -37,7 +38,7 @@ private[spark] object ShuffleMapTask { // expensive on the master node if it needs to launch thousands of tasks. private val serializedInfoCache = new HashMap[Int, Array[Byte]] - def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { + def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_, _, _]): Array[Byte] = { synchronized { val old = serializedInfoCache.get(stageId).orNull if (old != null) { @@ -56,12 +57,12 @@ private[spark] object ShuffleMapTask { } } - def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_]) = { + def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_, _, _]) = { val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) val ser = SparkEnv.get.closureSerializer.newInstance() val objIn = ser.deserializeStream(in) val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]] + val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_, _, _]] (rdd, dep) } @@ -99,7 +100,7 @@ private[spark] object ShuffleMapTask { private[spark] class ShuffleMapTask( stageId: Int, var rdd: RDD[_], - var dep: ShuffleDependency[_,_], + var dep: ShuffleDependency[_, _, _], _partitionId: Int, @transient private var locs: Seq[TaskLocation]) extends Task[MapStatus](stageId, _partitionId) @@ -141,66 +142,22 @@ private[spark] class ShuffleMapTask( } override def runTask(context: TaskContext): MapStatus = { - val numOutputSplits = dep.partitioner.numPartitions metrics = Some(context.taskMetrics) - - val blockManager = SparkEnv.get.blockManager - val shuffleBlockManager = blockManager.shuffleBlockManager - var shuffle: ShuffleWriterGroup = null - var success = false - + var writer: ShuffleWriter[Any, Any] = null try { - // Obtain all the block writers for shuffle blocks. - val ser = Serializer.getSerializer(dep.serializer) - shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser) - - // Write the map output to its associated buckets. + val manager = SparkEnv.get.shuffleManager + writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) for (elem <- rdd.iterator(split, context)) { - val pair = elem.asInstanceOf[Product2[Any, Any]] - val bucketId = dep.partitioner.getPartition(pair._1) - shuffle.writers(bucketId).write(pair) - } - - // Commit the writes. Get the size of each bucket block (total block size). - var totalBytes = 0L - var totalTime = 0L - val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter => - writer.commit() - writer.close() - val size = writer.fileSegment().length - totalBytes += size - totalTime += writer.timeWriting() - MapOutputTracker.compressSize(size) + writer.write(elem.asInstanceOf[Product2[Any, Any]]) } - - // Update shuffle metrics. - val shuffleMetrics = new ShuffleWriteMetrics - shuffleMetrics.shuffleBytesWritten = totalBytes - shuffleMetrics.shuffleWriteTime = totalTime - metrics.get.shuffleWriteMetrics = Some(shuffleMetrics) - - success = true - new MapStatus(blockManager.blockManagerId, compressedSizes) - } catch { case e: Exception => - // If there is an exception from running the task, revert the partial writes - // and throw the exception upstream to Spark. - if (shuffle != null && shuffle.writers != null) { - for (writer <- shuffle.writers) { - writer.revertPartialWrites() - writer.close() + return writer.stop(success = true).get + } catch { + case e: Exception => + if (writer != null) { + writer.stop(success = false) } - } - throw e + throw e } finally { - // Release the writers back to the shuffle block manager. - if (shuffle != null && shuffle.writers != null) { - try { - shuffle.releaseWriters(success) - } catch { - case e: Exception => logError("Failed to release shuffle writers", e) - } - } - // Execute the callbacks on task completion. context.executeOnCompleteCallbacks() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 5c1fc30e4a557..3bf9713f728c6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -40,7 +40,7 @@ private[spark] class Stage( val id: Int, val rdd: RDD[_], val numTasks: Int, - val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage + val shuffleDep: Option[ShuffleDependency[_, _, _]], // Output shuffle if stage is a map stage val parents: List[Stage], val jobId: Int, callSite: Option[String]) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 99d305b36a959..df59f444b7a0e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -71,7 +71,8 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul val loader = Thread.currentThread.getContextClassLoader taskSetManager.abort("ClassNotFound with classloader: " + loader) case ex: Exception => - taskSetManager.abort("Exception while deserializing and fetching task: %s".format(ex)) + logError("Exception while getting task result", ex) + taskSetManager.abort("Exception while getting task result: %s".format(ex)) } } }) diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index ee26970a3d874..f2f5cea469c61 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -52,6 +52,10 @@ object Serializer { def getSerializer(serializer: Serializer): Serializer = { if (serializer == null) SparkEnv.get.serializer else serializer } + + def getSerializer(serializer: Option[Serializer]): Serializer = { + serializer.getOrElse(SparkEnv.get.serializer) + } } diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala similarity index 66% rename from core/src/main/scala/org/apache/spark/ShuffleFetcher.scala rename to core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala index a4f69b6b22b2c..b36c457d6d514 100644 --- a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala @@ -15,22 +15,16 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.shuffle +import org.apache.spark.{ShuffleDependency, Aggregator, Partitioner} import org.apache.spark.serializer.Serializer -private[spark] abstract class ShuffleFetcher { - - /** - * Fetch the shuffle outputs for a given ShuffleDependency. - * @return An iterator over the elements of the fetched shuffle outputs. - */ - def fetch[T]( - shuffleId: Int, - reduceId: Int, - context: TaskContext, - serializer: Serializer = SparkEnv.get.serializer): Iterator[T] - - /** Stop the fetcher */ - def stop() {} -} +/** + * A basic ShuffleHandle implementation that just captures registerShuffle's parameters. + */ +private[spark] class BaseShuffleHandle[K, V, C]( + shuffleId: Int, + val numMaps: Int, + val dependency: ShuffleDependency[K, V, C]) + extends ShuffleHandle(shuffleId) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala new file mode 100644 index 0000000000000..13c7115f88afa --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +/** + * An opaque handle to a shuffle, used by a ShuffleManager to pass information about it to tasks. + * + * @param shuffleId ID of the shuffle + */ +private[spark] abstract class ShuffleHandle(val shuffleId: Int) extends Serializable {} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala new file mode 100644 index 0000000000000..9c859b8b4a118 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import org.apache.spark.{TaskContext, ShuffleDependency} + +/** + * Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on both the + * driver and executors, based on the spark.shuffle.manager setting. The driver registers shuffles + * with it, and executors (or tasks running locally in the driver) can ask to read and write data. + * + * NOTE: this will be instantiated by SparkEnv so its constructor can take a SparkConf and + * boolean isDriver as parameters. + */ +private[spark] trait ShuffleManager { + /** + * Register a shuffle with the manager and obtain a handle for it to pass to tasks. + */ + def registerShuffle[K, V, C]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle + + /** Get a writer for a given partition. Called on executors by map tasks. */ + def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V] + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. + */ + def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): ShuffleReader[K, C] + + /** Remove a shuffle's metadata from the ShuffleManager. */ + def unregisterShuffle(shuffleId: Int) + + /** Shut down this ShuffleManager. */ + def stop(): Unit +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala new file mode 100644 index 0000000000000..b30e366d06006 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +/** + * Obtained inside a reduce task to read combined records from the mappers. + */ +private[spark] trait ShuffleReader[K, C] { + /** Read the combined key-values for this reduce task */ + def read(): Iterator[Product2[K, C]] + + /** Close this reader */ + def stop(): Unit +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala new file mode 100644 index 0000000000000..ead3ebd652ca5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import org.apache.spark.scheduler.MapStatus + +/** + * Obtained inside a map task to write out records to the shuffle system. + */ +private[spark] trait ShuffleWriter[K, V] { + /** Write a record to this task's output */ + def write(record: Product2[K, V]): Unit + + /** Close this writer, passing along whether the map completed */ + def stop(success: Boolean): Option[MapStatus] +} diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala similarity index 96% rename from core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala rename to core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index a67392441ed29..b05b6ea345df3 100644 --- a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.shuffle.hash import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap @@ -24,17 +24,16 @@ import org.apache.spark.executor.ShuffleReadMetrics import org.apache.spark.serializer.Serializer import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util.CompletionIterator +import org.apache.spark._ -private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { - - override def fetch[T]( +private[hash] object BlockStoreShuffleFetcher extends Logging { + def fetch[T]( shuffleId: Int, reduceId: Int, context: TaskContext, serializer: Serializer) : Iterator[T] = { - logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) val blockManager = SparkEnv.get.blockManager diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala new file mode 100644 index 0000000000000..5b0940ecce29d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.hash + +import org.apache.spark._ +import org.apache.spark.shuffle._ + +/** + * A ShuffleManager using hashing, that creates one output file per reduce partition on each + * mapper (possibly reusing these across waves of tasks). + */ +class HashShuffleManager(conf: SparkConf) extends ShuffleManager { + /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */ + override def registerShuffle[K, V, C]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. + */ + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): ShuffleReader[K, C] = { + new HashShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) + } + + /** Get a writer for a given partition. Called on executors by map tasks. */ + override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext) + : ShuffleWriter[K, V] = { + new HashShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) + } + + /** Remove a shuffle's metadata from the ShuffleManager. */ + override def unregisterShuffle(shuffleId: Int): Unit = {} + + /** Shut down this ShuffleManager. */ + override def stop(): Unit = {} +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala new file mode 100644 index 0000000000000..f6a790309a587 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.hash + +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.TaskContext + +class HashShuffleReader[K, C]( + handle: BaseShuffleHandle[K, _, C], + startPartition: Int, + endPartition: Int, + context: TaskContext) + extends ShuffleReader[K, C] +{ + require(endPartition == startPartition + 1, + "Hash shuffle currently only supports fetching one partition") + + /** Read the combined key-values for this reduce task */ + override def read(): Iterator[Product2[K, C]] = { + BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, + Serializer.getSerializer(handle.dependency.serializer)) + } + + /** Close this reader */ + override def stop(): Unit = ??? +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala new file mode 100644 index 0000000000000..4c6749098c110 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.hash + +import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter} +import org.apache.spark.{Logging, MapOutputTracker, SparkEnv, TaskContext} +import org.apache.spark.storage.{BlockObjectWriter} +import org.apache.spark.serializer.Serializer +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.scheduler.MapStatus + +class HashShuffleWriter[K, V]( + handle: BaseShuffleHandle[K, V, _], + mapId: Int, + context: TaskContext) + extends ShuffleWriter[K, V] with Logging { + + private val dep = handle.dependency + private val numOutputSplits = dep.partitioner.numPartitions + private val metrics = context.taskMetrics + private var stopping = false + + private val blockManager = SparkEnv.get.blockManager + private val shuffleBlockManager = blockManager.shuffleBlockManager + private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null)) + private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser) + + /** Write a record to this task's output */ + override def write(record: Product2[K, V]): Unit = { + val pair = record.asInstanceOf[Product2[Any, Any]] + val bucketId = dep.partitioner.getPartition(pair._1) + shuffle.writers(bucketId).write(pair) + } + + /** Close this writer, passing along whether the map completed */ + override def stop(success: Boolean): Option[MapStatus] = { + try { + if (stopping) { + return None + } + stopping = true + if (success) { + try { + return Some(commitWritesAndBuildStatus()) + } catch { + case e: Exception => + revertWrites() + throw e + } + } else { + revertWrites() + return None + } + } finally { + // Release the writers back to the shuffle block manager. + if (shuffle != null && shuffle.writers != null) { + try { + shuffle.releaseWriters(success) + } catch { + case e: Exception => logError("Failed to release shuffle writers", e) + } + } + } + } + + private def commitWritesAndBuildStatus(): MapStatus = { + // Commit the writes. Get the size of each bucket block (total block size). + var totalBytes = 0L + var totalTime = 0L + val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter => + writer.commit() + writer.close() + val size = writer.fileSegment().length + totalBytes += size + totalTime += writer.timeWriting() + MapOutputTracker.compressSize(size) + } + + // Update shuffle metrics. + val shuffleMetrics = new ShuffleWriteMetrics + shuffleMetrics.shuffleBytesWritten = totalBytes + shuffleMetrics.shuffleWriteTime = totalTime + metrics.shuffleWriteMetrics = Some(shuffleMetrics) + + new MapStatus(blockManager.blockManagerId, compressedSizes) + } + + private def revertWrites(): Unit = { + if (shuffle != null && shuffle.writers != null) { + for (writer <- shuffle.writers) { + writer.revertPartialWrites() + writer.close() + } + } + } +} diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index dc2db66df60e0..13b415cccb647 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -201,7 +201,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo def newPairRDD = newRDD.map(_ -> 1) def newShuffleRDD = newPairRDD.reduceByKey(_ + _) def newBroadcast = sc.broadcast(1 to 100) - def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _]]) = { + def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _, _]]) = { def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = { rdd.dependencies ++ rdd.dependencies.flatMap { dep => getAllDependencies(dep.rdd) @@ -211,8 +211,8 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo // Get all the shuffle dependencies val shuffleDeps = getAllDependencies(rdd) - .filter(_.isInstanceOf[ShuffleDependency[_, _]]) - .map(_.asInstanceOf[ShuffleDependency[_, _]]) + .filter(_.isInstanceOf[ShuffleDependency[_, _, _]]) + .map(_.asInstanceOf[ShuffleDependency[_, _, _]]) (rdd, shuffleDeps) } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 7b0607dd3ed2d..47112ce66d695 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -58,7 +58,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { // default Java serializer cannot handle the non serializable class. val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)]( b, new HashPartitioner(NUM_BLOCKS)).setSerializer(new KryoSerializer(conf)) - val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId assert(c.count === 10) @@ -97,7 +97,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10)) .setSerializer(new KryoSerializer(conf)) - val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => @@ -122,7 +122,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { // NOTE: The default Java serializer should create zero-sized blocks val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10)) - val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => From e056320cc8e2af56983a198d26717485304695d7 Mon Sep 17 00:00:00 2001 From: Yadong Date: Wed, 11 Jun 2014 20:58:39 -0700 Subject: [PATCH 22/38] 'killFuture' is never used Author: Yadong Closes #1052 from watermen/bug-fix1 and squashes the following commits: 409d09a [Yadong] 'killFuture' is never used --- core/src/main/scala/org/apache/spark/deploy/Client.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index aeb159adc31d9..c371dc3a51c73 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -81,7 +81,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends case "kill" => val driverId = driverArgs.driverId - val killFuture = masterActor ! RequestKillDriver(driverId) + masterActor ! RequestKillDriver(driverId) } } From 4d8ae709fb8d986634c97d21036391ed4685db1a Mon Sep 17 00:00:00 2001 From: Henry Saputra Date: Wed, 11 Jun 2014 23:17:51 -0700 Subject: [PATCH 23/38] Cleanup on Connection and ConnectionManager Simple cleanup on Connection and ConnectionManager to make IDE happy while working of issue: 1. Replace var with var 2. Add parentheses to Queue#dequeu to be consistent with side-effects. 3. Remove return on final line of a method. Author: Henry Saputra Closes #1060 from hsaputra/cleanup_connection_classes and squashes the following commits: 245fd09 [Henry Saputra] Cleanup on Connection and ConnectionManager to make IDE happy while working of issue: 1. Replace var with var 2. Add parentheses to Queue#dequeu to be consistent with side-effects. 3. Remove return on final line of a method. --- .../org/apache/spark/network/Connection.scala | 4 ++-- .../spark/network/ConnectionManager.scala | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala index 3ffaaab23d0f5..3b6298a26d7c5 100644 --- a/core/src/main/scala/org/apache/spark/network/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/Connection.scala @@ -210,7 +210,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, var nextMessageToBeUsed = 0 def addMessage(message: Message) { - messages.synchronized{ + messages.synchronized { /* messages += message */ messages.enqueue(message) logDebug("Added [" + message + "] to outbox for sending to " + @@ -223,7 +223,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, while (!messages.isEmpty) { /* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */ /* val message = messages(nextMessageToBeUsed) */ - val message = messages.dequeue + val message = messages.dequeue() val chunk = message.getChunkForSending(defaultChunkSize) if (chunk.isDefined) { messages.enqueue(message) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 5dd5fd0047c0d..cf1c985c2fff9 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -250,14 +250,14 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, try { while(!selectorThread.isInterrupted) { while (! registerRequests.isEmpty) { - val conn: SendingConnection = registerRequests.dequeue + val conn: SendingConnection = registerRequests.dequeue() addListeners(conn) conn.connect() addConnection(conn) } while(!keyInterestChangeRequests.isEmpty) { - val (key, ops) = keyInterestChangeRequests.dequeue + val (key, ops) = keyInterestChangeRequests.dequeue() try { if (key.isValid) { @@ -532,9 +532,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } return } - var securityMsgResp = SecurityMessage.fromResponse(replyToken, + val securityMsgResp = SecurityMessage.fromResponse(replyToken, securityMsg.getConnectionId.toString()) - var message = securityMsgResp.toBufferMessage + val message = securityMsgResp.toBufferMessage if (message == null) throw new Exception("Error creating security message") sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) } catch { @@ -568,9 +568,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, logDebug("Server sasl not completed: " + connection.connectionId) } if (replyToken != null) { - var securityMsgResp = SecurityMessage.fromResponse(replyToken, + val securityMsgResp = SecurityMessage.fromResponse(replyToken, securityMsg.getConnectionId) - var message = securityMsgResp.toBufferMessage + val message = securityMsgResp.toBufferMessage if (message == null) throw new Exception("Error creating security Message") sendSecurityMessage(connection.getRemoteConnectionManagerId(), message) } @@ -618,7 +618,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, return true } } - return false + false } private def handleMessage( @@ -694,9 +694,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, var firstResponse: Array[Byte] = null try { firstResponse = conn.sparkSaslClient.firstToken() - var securityMsg = SecurityMessage.fromResponse(firstResponse, + val securityMsg = SecurityMessage.fromResponse(firstResponse, conn.connectionId.toString()) - var message = securityMsg.toBufferMessage + val message = securityMsg.toBufferMessage if (message == null) throw new Exception("Error creating security message") connectionsAwaitingSasl += ((conn.connectionId, conn)) sendSecurityMessage(connManagerId, message) From 43d53d51c9ee2626d9de91faa3b192979b86821d Mon Sep 17 00:00:00 2001 From: Jeff Thompson Date: Thu, 12 Jun 2014 08:10:51 -0700 Subject: [PATCH 24/38] fixed typo in docstring for min() Hi, I found this typo while learning spark and thought I'd do a pull request. Author: Jeff Thompson Closes #1065 from jkthompson/docstring-typo-minmax and squashes the following commits: 29b6a26 [Jeff Thompson] fixed typo in docstring for min() --- python/pyspark/rdd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 9c69c79236edc..8a215fc51130a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -695,7 +695,7 @@ def max(self): def min(self): """ - Find the maximum item in this RDD. + Find the minimum item in this RDD. >>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).min() 1.0 From ce92a9c18f033ac9fa2f12143fab00a90e0f4577 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Thu, 12 Jun 2014 08:14:25 -0700 Subject: [PATCH 25/38] SPARK-554. Add aggregateByKey. Author: Sandy Ryza Closes #705 from sryza/sandy-spark-554 and squashes the following commits: 2302b8f [Sandy Ryza] Add MIMA exclude f52e0ad [Sandy Ryza] Fix Python tests for real 2f3afa3 [Sandy Ryza] Fix Python test 0b735e9 [Sandy Ryza] Fix line lengths ae56746 [Sandy Ryza] Fix doc (replace T with V) c2be415 [Sandy Ryza] Java and Python aggregateByKey 23bf400 [Sandy Ryza] SPARK-554. Add aggregateByKey. --- .../apache/spark/api/java/JavaPairRDD.scala | 44 ++++++++++++++++ .../apache/spark/rdd/PairRDDFunctions.scala | 50 +++++++++++++++++++ .../java/org/apache/spark/JavaAPISuite.java | 31 ++++++++++++ .../spark/rdd/PairRDDFunctionsSuite.scala | 13 +++++ docs/programming-guide.md | 4 ++ project/MimaExcludes.scala | 5 +- python/pyspark/rdd.py | 19 ++++++- python/pyspark/tests.py | 15 ++++++ 8 files changed, 179 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 7dcfbf741c4f1..14fa9d8135afe 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -228,6 +228,50 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) : PartialResult[java.util.Map[K, BoundedDouble]] = rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap) + /** + * Aggregate the values of each key, using given combine functions and a neutral "zero value". + * This function can return a different result type, U, than the type of the values in this RDD, + * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's, + * as in scala.TraversableOnce. The former operation is used for merging values within a + * partition, and the latter is used for merging values between partitions. To avoid memory + * allocation, both of these functions are allowed to modify and return their first argument + * instead of creating a new U. + */ + def aggregateByKey[U](zeroValue: U, partitioner: Partitioner, seqFunc: JFunction2[U, V, U], + combFunc: JFunction2[U, U, U]): JavaPairRDD[K, U] = { + implicit val ctag: ClassTag[U] = fakeClassTag + fromRDD(rdd.aggregateByKey(zeroValue, partitioner)(seqFunc, combFunc)) + } + + /** + * Aggregate the values of each key, using given combine functions and a neutral "zero value". + * This function can return a different result type, U, than the type of the values in this RDD, + * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's, + * as in scala.TraversableOnce. The former operation is used for merging values within a + * partition, and the latter is used for merging values between partitions. To avoid memory + * allocation, both of these functions are allowed to modify and return their first argument + * instead of creating a new U. + */ + def aggregateByKey[U](zeroValue: U, numPartitions: Int, seqFunc: JFunction2[U, V, U], + combFunc: JFunction2[U, U, U]): JavaPairRDD[K, U] = { + implicit val ctag: ClassTag[U] = fakeClassTag + fromRDD(rdd.aggregateByKey(zeroValue, numPartitions)(seqFunc, combFunc)) + } + + /** + * Aggregate the values of each key, using given combine functions and a neutral "zero value". + * This function can return a different result type, U, than the type of the values in this RDD, + * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's. + * The former operation is used for merging values within a partition, and the latter is used for + * merging values between partitions. To avoid memory allocation, both of these functions are + * allowed to modify and return their first argument instead of creating a new U. + */ + def aggregateByKey[U](zeroValue: U, seqFunc: JFunction2[U, V, U], combFunc: JFunction2[U, U, U]): + JavaPairRDD[K, U] = { + implicit val ctag: ClassTag[U] = fakeClassTag + fromRDD(rdd.aggregateByKey(zeroValue)(seqFunc, combFunc)) + } + /** * Merge the values for each key using an associative function and a neutral "zero value" which * may be added to the result an arbitrary number of times, and must not change the result diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 8909980957058..b6ad9b6c3e168 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -118,6 +118,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions)) } + /** + * Aggregate the values of each key, using given combine functions and a neutral "zero value". + * This function can return a different result type, U, than the type of the values in this RDD, + * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's, + * as in scala.TraversableOnce. The former operation is used for merging values within a + * partition, and the latter is used for merging values between partitions. To avoid memory + * allocation, both of these functions are allowed to modify and return their first argument + * instead of creating a new U. + */ + def aggregateByKey[U: ClassTag](zeroValue: U, partitioner: Partitioner)(seqOp: (U, V) => U, + combOp: (U, U) => U): RDD[(K, U)] = { + // Serialize the zero value to a byte array so that we can get a new clone of it on each key + val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue) + val zeroArray = new Array[Byte](zeroBuffer.limit) + zeroBuffer.get(zeroArray) + + lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance() + def createZero() = cachedSerializer.deserialize[U](ByteBuffer.wrap(zeroArray)) + + combineByKey[U]((v: V) => seqOp(createZero(), v), seqOp, combOp, partitioner) + } + + /** + * Aggregate the values of each key, using given combine functions and a neutral "zero value". + * This function can return a different result type, U, than the type of the values in this RDD, + * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's, + * as in scala.TraversableOnce. The former operation is used for merging values within a + * partition, and the latter is used for merging values between partitions. To avoid memory + * allocation, both of these functions are allowed to modify and return their first argument + * instead of creating a new U. + */ + def aggregateByKey[U: ClassTag](zeroValue: U, numPartitions: Int)(seqOp: (U, V) => U, + combOp: (U, U) => U): RDD[(K, U)] = { + aggregateByKey(zeroValue, new HashPartitioner(numPartitions))(seqOp, combOp) + } + + /** + * Aggregate the values of each key, using given combine functions and a neutral "zero value". + * This function can return a different result type, U, than the type of the values in this RDD, + * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's, + * as in scala.TraversableOnce. The former operation is used for merging values within a + * partition, and the latter is used for merging values between partitions. To avoid memory + * allocation, both of these functions are allowed to modify and return their first argument + * instead of creating a new U. + */ + def aggregateByKey[U: ClassTag](zeroValue: U)(seqOp: (U, V) => U, + combOp: (U, U) => U): RDD[(K, U)] = { + aggregateByKey(zeroValue, defaultPartitioner(self))(seqOp, combOp) + } + /** * Merge the values for each key using an associative function and a neutral "zero value" which * may be added to the result an arbitrary number of times, and must not change the result diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 50a62129116f1..ef41bfb88de9d 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -317,6 +317,37 @@ public Integer call(Integer a, Integer b) { Assert.assertEquals(33, sum); } + @Test + public void aggregateByKey() { + JavaPairRDD pairs = sc.parallelizePairs( + Arrays.asList( + new Tuple2(1, 1), + new Tuple2(1, 1), + new Tuple2(3, 2), + new Tuple2(5, 1), + new Tuple2(5, 3)), 2); + + Map> sets = pairs.aggregateByKey(new HashSet(), + new Function2, Integer, Set>() { + @Override + public Set call(Set a, Integer b) { + a.add(b); + return a; + } + }, + new Function2, Set, Set>() { + @Override + public Set call(Set a, Set b) { + a.addAll(b); + return a; + } + }).collectAsMap(); + Assert.assertEquals(3, sets.size()); + Assert.assertEquals(new HashSet(Arrays.asList(1)), sets.get(1)); + Assert.assertEquals(new HashSet(Arrays.asList(2)), sets.get(3)); + Assert.assertEquals(new HashSet(Arrays.asList(1, 3)), sets.get(5)); + } + @SuppressWarnings("unchecked") @Test public void foldByKey() { diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 9ddafc451878d..0b9004448a63e 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -30,6 +30,19 @@ import org.apache.spark.SparkContext._ import org.apache.spark.{Partitioner, SharedSparkContext} class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { + test("aggregateByKey") { + val pairs = sc.parallelize(Array((1, 1), (1, 1), (3, 2), (5, 1), (5, 3)), 2) + + val sets = pairs.aggregateByKey(new HashSet[Int]())(_ += _, _ ++= _).collect() + assert(sets.size === 3) + val valuesFor1 = sets.find(_._1 == 1).get._2 + assert(valuesFor1.toList.sorted === List(1)) + val valuesFor3 = sets.find(_._1 == 3).get._2 + assert(valuesFor3.toList.sorted === List(2)) + val valuesFor5 = sets.find(_._1 == 5).get._2 + assert(valuesFor5.toList.sorted === List(1, 3)) + } + test("groupByKey") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) val groups = pairs.groupByKey().collect() diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 7989e02dfb732..79784682bfd1b 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -890,6 +890,10 @@ for details. reduceByKey(func, [numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, V) pairs where the values for each key are aggregated using the given reduce function. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. + + aggregateByKey(zeroValue)(seqOp, combOp, [numTasks]) + When called on a dataset of (K, V) pairs, returns a dataset of (K, U) pairs where the values for each key are aggregated using the given combine functions and a neutral "zero" value. Allows an aggregated value type that is different than the input value type, while avoiding unnecessary allocations. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. + sortByKey([ascending], [numTasks]) When called on a dataset of (K, V) pairs where K implements Ordered, returns a dataset of (K, V) pairs sorted by keys in ascending or descending order, as specified in the boolean ascending argument. diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index c80ab9a9f8e60..ee629794f60ad 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -52,7 +52,10 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"), ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1") + "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$" + + "createZero$1") ) ++ Seq( // Ignore some private methods in ALS. ProblemFilters.exclude[MissingMethodProblem]( diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 8a215fc51130a..735389c69831c 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1178,6 +1178,20 @@ def _mergeCombiners(iterator): combiners[k] = mergeCombiners(combiners[k], v) return combiners.iteritems() return shuffled.mapPartitions(_mergeCombiners) + + def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): + """ + Aggregate the values of each key, using given combine functions and a neutral "zero value". + This function can return a different result type, U, than the type of the values in this RDD, + V. Thus, we need one operation for merging a V into a U and one operation for merging two U's, + The former operation is used for merging values within a partition, and the latter is used + for merging values between partitions. To avoid memory allocation, both of these functions are + allowed to modify and return their first argument instead of creating a new U. + """ + def createZero(): + return copy.deepcopy(zeroValue) + + return self.combineByKey(lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions) def foldByKey(self, zeroValue, func, numPartitions=None): """ @@ -1190,7 +1204,10 @@ def foldByKey(self, zeroValue, func, numPartitions=None): >>> rdd.foldByKey(0, add).collect() [('a', 2), ('b', 1)] """ - return self.combineByKey(lambda v: func(zeroValue, v), func, func, numPartitions) + def createZero(): + return copy.deepcopy(zeroValue) + + return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions) # TODO: support variant with custom partitioner diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 184ee810b861b..c15bb457759ed 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -188,6 +188,21 @@ def test_deleting_input_files(self): os.unlink(tempFile.name) self.assertRaises(Exception, lambda: filtered_data.count()) + def testAggregateByKey(self): + data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2) + def seqOp(x, y): + x.add(y) + return x + + def combOp(x, y): + x |= y + return x + + sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect()) + self.assertEqual(3, len(sets)) + self.assertEqual(set([1]), sets[1]) + self.assertEqual(set([2]), sets[3]) + self.assertEqual(set([1, 3]), sets[5]) class TestIO(PySparkTestCase): From 83c226d454722d5dea186d48070fb98652d0dafb Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Thu, 12 Jun 2014 12:53:07 -0700 Subject: [PATCH 26/38] [SPARK-2088] fix NPE in toString After deserialization, the transient field creationSiteInfo does not get backfilled with the default value, but the toString method, which is invoked by the serializer, expects the field to always be non-null. An NPE is thrown when toString is called by the serializer when creationSiteInfo is null. Author: Doris Xin Closes #1028 from dorx/toStringNPE and squashes the following commits: f20021e [Doris Xin] unit test for toString after desrialization 6f0a586 [Doris Xin] Merge branch 'master' into toStringNPE f47fecf [Doris Xin] Merge branch 'master' into toStringNPE 76199c6 [Doris Xin] [SPARK-2088] fix NPE in toString --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 2 +- core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 54bdc3e7cbc7a..b6fc4b13ad4d7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1180,7 +1180,7 @@ abstract class RDD[T: ClassTag]( /** User code that created this RDD (e.g. `textFile`, `parallelize`). */ @transient private[spark] val creationSiteInfo = Utils.getCallSiteInfo - private[spark] def getCreationSite: String = creationSiteInfo.toString + private[spark] def getCreationSite: String = Option(creationSiteInfo).getOrElse("").toString private[spark] def elementClassTag: ClassTag[T] = classTag[T] diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 55af1666df662..2e2ccc5a1859e 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.FunSuite import org.apache.spark._ import org.apache.spark.SparkContext._ -import org.apache.spark.rdd._ +import org.apache.spark.util.Utils class RDDSuite extends FunSuite with SharedSparkContext { @@ -66,6 +66,13 @@ class RDDSuite extends FunSuite with SharedSparkContext { } } + test("serialization") { + val empty = new EmptyRDD[Int](sc) + val serial = Utils.serialize(empty) + val deserial: EmptyRDD[Int] = Utils.deserialize(serial) + assert(!deserial.toString().isEmpty()) + } + test("countApproxDistinct") { def error(est: Long, size: Long) = math.abs(est - size) / size.toDouble From ecde5b837534b11d365fcab78089820990b815cf Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 12 Jun 2014 16:19:36 -0500 Subject: [PATCH 27/38] [SPARK-2080] Yarn: report HS URL in client mode, correct user in cluster mode. Yarn client mode was not setting the app's tracking URL to the History Server's URL when configured by the user. Now client mode behaves the same as cluster mode. In SparkContext.scala, the "user.name" system property had precedence over the SPARK_USER environment variable. This means that SPARK_USER was never used, since "user.name" is always set by the JVM. In Yarn cluster mode, this means the application always reported itself as being run by user "yarn" (or whatever user was running the Yarn NM). One could argue that the correct fix would be to use UGI.getCurrentUser() here, but at least for Yarn that will match what SPARK_USER is set to. Author: Marcelo Vanzin This patch had conflicts when merged, resolved by Committer: Thomas Graves Closes #1002 from vanzin/yarn-client-url and squashes the following commits: 4046e04 [Marcelo Vanzin] Set HS link in yarn-alpha also. 4c692d9 [Marcelo Vanzin] Yarn: report HS URL in client mode, correct user in cluster mode. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 2 +- .../org/apache/spark/deploy/yarn/ExecutorLauncher.scala | 1 + .../org/apache/spark/deploy/yarn/ExecutorLauncher.scala | 6 +++--- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 8bdaf0bf76e85..df151861958a2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -297,7 +297,7 @@ class SparkContext(config: SparkConf) extends Logging { // Set SPARK_USER for user who is running SparkContext. val sparkUser = Option { - Option(System.getProperty("user.name")).getOrElse(System.getenv("SPARK_USER")) + Option(System.getenv("SPARK_USER")).getOrElse(System.getProperty("user.name")) }.getOrElse { SparkContext.SPARK_UNKNOWN_USER } diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index a3bd91590fc25..b6ecae1e652fe 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -271,6 +271,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp .asInstanceOf[FinishApplicationMasterRequest] finishReq.setAppAttemptId(appAttemptId) finishReq.setFinishApplicationStatus(status) + finishReq.setTrackingUrl(sparkConf.get("spark.yarn.historyServer.address", "")) resourceManager.finishApplicationMaster(finishReq) } diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index 4f8854a09a1e5..f71ad036ce0f2 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -115,7 +115,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp val interval = math.min(timeoutInterval / 2, schedulerInterval) reporterThread = launchReporterThread(interval) - + // Wait for the reporter thread to Finish. reporterThread.join() @@ -134,12 +134,12 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) .orElse(Option(System.getenv("LOCAL_DIRS"))) - + localDirs match { case None => throw new Exception("Yarn Local dirs can't be empty") case Some(l) => l } - } + } private def registerApplicationMaster(): RegisterApplicationMasterResponse = { logInfo("Registering the ApplicationMaster") From 1c04652c8f18566baafb13dbae355f8ad2ad8d37 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 12 Jun 2014 15:43:32 -0700 Subject: [PATCH 28/38] SPARK-1843: Replace assemble-deps with env variable. (This change is actually small, I moved some logic into compute-classpath that was previously in spark-class). Assemble deps has existed for a while to allow developers to run local code with new changes quickly. When I'm developing I typically use a simpler approach which just prepends the Spark classes to the classpath before the assembly jar. This is well defined in the JVM and the Spark classes take precedence over those in the assembly. This approach is portable across both builds which is the main reason I'd like to switch to it. It's also a bit easier to toggle on and off quickly. The way you use this is the following: ``` $ ./bin/spark-shell # Use spark with the normal assembly $ export SPARK_PREPEND_CLASSES=true $ ./bin/spark-shell # Now it's using compiled classes $ unset SPARK_PREPEND_CLASSES $ ./bin/spark-shell # Back to normal ``` Author: Patrick Wendell Closes #877 from pwendell/assemble-deps and squashes the following commits: 8a11345 [Patrick Wendell] Merge remote-tracking branch 'apache/master' into assemble-deps faa3168 [Patrick Wendell] Adding a warning for compatibility 3f151a7 [Patrick Wendell] Small fix bbfb73c [Patrick Wendell] Review feedback 328e9f8 [Patrick Wendell] SPARK-1843: Replace assemble-deps with env variable. --- bin/compute-classpath.sh | 34 ++++++++++++++----- bin/spark-class | 17 ---------- .../scala/org/apache/spark/SparkContext.scala | 3 ++ project/SparkBuild.scala | 16 ++++++--- 4 files changed, 40 insertions(+), 30 deletions(-) diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index 7df43a555d562..2cf4e381c1c88 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -38,8 +38,10 @@ else JAR_CMD="jar" fi -# First check if we have a dependencies jar. If so, include binary classes with the deps jar -if [ -f "$ASSEMBLY_DIR"/spark-assembly*hadoop*-deps.jar ]; then +# A developer option to prepend more recently compiled Spark classes +if [ -n "$SPARK_PREPEND_CLASSES" ]; then + echo "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark"\ + "classes ahead of assembly." >&2 CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes" @@ -51,17 +53,31 @@ if [ -f "$ASSEMBLY_DIR"/spark-assembly*hadoop*-deps.jar ]; then CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SCALA_VERSION/classes" +fi - ASSEMBLY_JAR=$(ls "$ASSEMBLY_DIR"/spark-assembly*hadoop*-deps.jar 2>/dev/null) +# Use spark-assembly jar from either RELEASE or assembly directory +if [ -f "$FWDIR/RELEASE" ]; then + assembly_folder="$FWDIR"/lib else - # Else use spark-assembly jar from either RELEASE or assembly directory - if [ -f "$FWDIR/RELEASE" ]; then - ASSEMBLY_JAR=$(ls "$FWDIR"/lib/spark-assembly*hadoop*.jar 2>/dev/null) - else - ASSEMBLY_JAR=$(ls "$ASSEMBLY_DIR"/spark-assembly*hadoop*.jar 2>/dev/null) - fi + assembly_folder="$ASSEMBLY_DIR" fi +num_jars=$(ls "$assembly_folder" | grep "spark-assembly.*hadoop.*\.jar" | wc -l) +if [ "$num_jars" -eq "0" ]; then + echo "Failed to find Spark assembly in $assembly_folder" + echo "You need to build Spark before running this program." + exit 1 +fi +if [ "$num_jars" -gt "1" ]; then + jars_list=$(ls "$assembly_folder" | grep "spark-assembly.*hadoop.*.jar") + echo "Found multiple Spark assembly jars in $assembly_folder:" + echo "$jars_list" + echo "Please remove all but one jar." + exit 1 +fi + +ASSEMBLY_JAR=$(ls "$assembly_folder"/spark-assembly*hadoop*.jar 2>/dev/null) + # Verify that versions of java used to build the jars and run Spark are compatible jar_error_check=$("$JAR_CMD" -tf "$ASSEMBLY_JAR" nonexistent/class/path 2>&1) if [[ "$jar_error_check" =~ "invalid CEN header" ]]; then diff --git a/bin/spark-class b/bin/spark-class index e884511010c6c..cfe363a71da31 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -108,23 +108,6 @@ fi export JAVA_OPTS # Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in CommandUtils.scala! -if [ ! -f "$FWDIR/RELEASE" ]; then - # Exit if the user hasn't compiled Spark - num_jars=$(ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/ | grep "spark-assembly.*hadoop.*.jar" | wc -l) - jars_list=$(ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/ | grep "spark-assembly.*hadoop.*.jar") - if [ "$num_jars" -eq "0" ]; then - echo "Failed to find Spark assembly in $FWDIR/assembly/target/scala-$SCALA_VERSION/" >&2 - echo "You need to build Spark before running this program." >&2 - exit 1 - fi - if [ "$num_jars" -gt "1" ]; then - echo "Found multiple Spark assembly jars in $FWDIR/assembly/target/scala-$SCALA_VERSION:" >&2 - echo "$jars_list" - echo "Please remove all but one jar." - exit 1 - fi -fi - TOOLS_DIR="$FWDIR"/tools SPARK_TOOLS_JAR="" if [ -e "$TOOLS_DIR"/target/scala-$SCALA_VERSION/*assembly*[0-9Tg].jar ]; then diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index df151861958a2..8fbda2c667cf7 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -290,6 +290,9 @@ class SparkContext(config: SparkConf) extends Logging { value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} { executorEnvs(envKey) = value } + Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => + executorEnvs("SPARK_PREPEND_CLASSES") = v + } // The Mesos scheduler backend relies on this environment variable to set executor memory. // TODO: Set this only in the Mesos scheduler. executorEnvs("SPARK_EXECUTOR_MEMORY") = executorMemory + "m" diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index ecd9d7068068d..8b4885d3bbbdb 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -90,7 +90,16 @@ object SparkBuild extends Build { lazy val assemblyProj = Project("assembly", file("assembly"), settings = assemblyProjSettings) .dependsOn(core, graphx, bagel, mllib, streaming, repl, sql) dependsOn(maybeYarn: _*) dependsOn(maybeHive: _*) dependsOn(maybeGanglia: _*) - lazy val assembleDeps = TaskKey[Unit]("assemble-deps", "Build assembly of dependencies and packages Spark projects") + lazy val assembleDepsTask = TaskKey[Unit]("assemble-deps") + lazy val assembleDeps = assembleDepsTask := { + println() + println("**** NOTE ****") + println("'sbt/sbt assemble-deps' is no longer supported.") + println("Instead create a normal assembly and:") + println(" export SPARK_PREPEND_CLASSES=1 (toggle on)") + println(" unset SPARK_PREPEND_CLASSES (toggle off)") + println() + } // A configuration to set an alternative publishLocalConfiguration lazy val MavenCompile = config("m2r") extend(Compile) @@ -373,6 +382,7 @@ object SparkBuild extends Build { "net.sf.py4j" % "py4j" % "0.8.1" ), libraryDependencies ++= maybeAvro, + assembleDeps, previousArtifact := sparkPreviousArtifact("spark-core") ) @@ -584,9 +594,7 @@ object SparkBuild extends Build { def assemblyProjSettings = sharedSettings ++ Seq( name := "spark-assembly", - assembleDeps in Compile <<= (packageProjects.map(packageBin in Compile in _) ++ Seq(packageDependency in Compile)).dependOn, - jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" }, - jarName in packageDependency <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + "-deps.jar" } + jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" } ) ++ assemblySettings ++ extraAssemblySettings def extraAssemblySettings() = Seq( From a6e0afdcf0174425e8a6ff20b2bc2e3a7a374f19 Mon Sep 17 00:00:00 2001 From: Shuo Xiang Date: Thu, 12 Jun 2014 17:37:06 -0700 Subject: [PATCH 29/38] SPARK-2085: [MLlib] Apply user-specific regularization instead of uniform regularization in ALS The current implementation of ALS takes a single regularization parameter and apply it on both of the user factors and the product factors. This kind of regularization can be less effective while user number is significantly larger than the number of products (and vice versa). For example, if we have 10M users and 1K product, regularization on user factors will dominate. Following the discussion in [this thread](http://apache-spark-user-list.1001560.n3.nabble.com/possible-bug-in-Spark-s-ALS-implementation-tt2567.html#a2704), the implementation in this PR will regularize each factor vector by #ratings * lambda. Author: Shuo Xiang Closes #1026 from coderxiang/als-reg and squashes the following commits: 93dfdb4 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' into als-reg b98f19c [Shuo Xiang] merge latest master 52c7b58 [Shuo Xiang] Apply user-specific regularization instead of uniform regularization in Alternating Least Squares (ALS) --- .../scala/org/apache/spark/mllib/recommendation/ALS.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index f1ae7b85b4a69..cc56fd6ef28d6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -507,6 +507,9 @@ class ALS private ( val tempXtX = DoubleMatrix.zeros(triangleSize) val fullXtX = DoubleMatrix.zeros(rank, rank) + // Count the number of ratings each user gives to provide user-specific regularization + val numRatings = Array.fill(numUsers)(0) + // Compute the XtX and Xy values for each user by adding products it rated in each product // block for (productBlock <- 0 until numProductBlocks) { @@ -519,6 +522,7 @@ class ALS private ( if (implicitPrefs) { var i = 0 while (i < us.length) { + numRatings(us(i)) += 1 // Extension to the original paper to handle rs(i) < 0. confidence is a function // of |rs(i)| instead so that it is never negative: val confidence = 1 + alpha * abs(rs(i)) @@ -534,6 +538,7 @@ class ALS private ( } else { var i = 0 while (i < us.length) { + numRatings(us(i)) += 1 userXtX(us(i)).addi(tempXtX) SimpleBlas.axpy(rs(i), x, userXy(us(i))) i += 1 @@ -550,9 +555,10 @@ class ALS private ( // Compute the full XtX matrix from the lower-triangular part we got above fillFullMatrix(userXtX(index), fullXtX) // Add regularization + val regParam = numRatings(index) * lambda var i = 0 while (i < rank) { - fullXtX.data(i * rank + i) += lambda + fullXtX.data(i * rank + i) += regParam i += 1 } // Solve the resulting matrix, which is symmetric and positive-definite From 0154587ab71d1b864f97497dbb38bc52b87675be Mon Sep 17 00:00:00 2001 From: Ariel Rabkin Date: Thu, 12 Jun 2014 17:51:33 -0700 Subject: [PATCH 30/38] document laziness of parallelize Took me several hours to figure out this behavior. It would be good to highlight it in the documentation. Author: Ariel Rabkin Closes #1070 from asrabkin/master and squashes the following commits: 29a076e [Ariel Rabkin] doc fix --- .../main/scala/org/apache/spark/SparkContext.scala | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 8fbda2c667cf7..35970c2f50892 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -434,12 +434,21 @@ class SparkContext(config: SparkConf) extends Logging { // Methods for creating RDDs - /** Distribute a local Scala collection to form an RDD. */ + /** Distribute a local Scala collection to form an RDD. + * + * @note Parallelize acts lazily. If `seq` is a mutable collection and is + * altered after the call to parallelize and before the first action on the + * RDD, the resultant RDD will reflect the modified collection. Pass a copy of + * the argument to avoid this. + */ def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]()) } - /** Distribute a local Scala collection to form an RDD. */ + /** Distribute a local Scala collection to form an RDD. + * + * This method is identical to `parallelize`. + */ def makeRDD[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { parallelize(seq, numSlices) } From 1de1d703bf6b7ca14f7b40bbefe9bf6fd6c8ce47 Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Thu, 12 Jun 2014 19:44:27 -0700 Subject: [PATCH 31/38] SPARK-1939 Refactor takeSample method in RDD to use ScaSRS Modified the takeSample method in RDD to use the ScaSRS sampling technique to improve performance. Added a private method that computes sampling rate > sample_size/total to ensure sufficient sample size with success rate >= 0.9999. Added a unit test for the private method to validate choice of sampling rate. Author: Doris Xin Author: dorx Author: Xiangrui Meng Closes #916 from dorx/takeSample and squashes the following commits: 5b061ae [Doris Xin] merge master 444e750 [Doris Xin] edge cases 3de882b [dorx] Merge pull request #2 from mengxr/SPARK-1939 82dde31 [Xiangrui Meng] update pyspark's takeSample 48d954d [Doris Xin] remove unused imports from RDDSuite fb1452f [Doris Xin] allowing num to be greater than count in all cases 1481b01 [Doris Xin] washing test tubes and making coffee dc699f3 [Doris Xin] give back imports removed by accident in rdd.py 64e445b [Doris Xin] logwarnning as soon as it enters the while loop 55518ed [Doris Xin] added TODO for logging in rdd.py eff89e2 [Doris Xin] addressed reviewer comments. ecab508 [Doris Xin] "fixed checkstyle violation 0a9b3e3 [Doris Xin] "reviewer comment addressed" f80f270 [Doris Xin] Merge branch 'master' into takeSample ae3ad04 [Doris Xin] fixed edge cases to prevent overflow 065ebcd [Doris Xin] Merge branch 'master' into takeSample 9bdd36e [Doris Xin] Check sample size and move computeFraction e3fd6a6 [Doris Xin] Merge branch 'master' into takeSample 7cab53a [Doris Xin] fixed import bug in rdd.py ffea61a [Doris Xin] SPARK-1939: Refactor takeSample method in RDD 1441977 [Doris Xin] SPARK-1939 Refactor takeSample method in RDD to use ScaSRS --- core/pom.xml | 5 + .../main/scala/org/apache/spark/rdd/RDD.scala | 52 +++--- .../spark/util/random/RandomSampler.scala | 2 +- .../spark/util/random/SamplingUtils.scala | 55 ++++++ .../scala/org/apache/spark/rdd/RDDSuite.scala | 35 ++-- .../util/random/SamplingUtilsSuite.scala | 46 +++++ project/SparkBuild.scala | 1 + python/pyspark/rdd.py | 167 +++++++++++------- 8 files changed, 263 insertions(+), 100 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala create mode 100644 core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala diff --git a/core/pom.xml b/core/pom.xml index c3d6b00a443f1..be56911b9e45a 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -67,6 +67,11 @@ org.apache.commons commons-lang3 + + org.apache.commons + commons-math3 + test + com.google.code.findbugs jsr305 diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index b6fc4b13ad4d7..446f369c9ea16 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -42,7 +42,7 @@ import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{BoundedPriorityQueue, Utils} import org.apache.spark.util.collection.OpenHashMap -import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler} +import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils} /** * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable, @@ -378,46 +378,56 @@ abstract class RDD[T: ClassTag]( }.toArray } - def takeSample(withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] = - { - var fraction = 0.0 - var total = 0 - val multiplier = 3.0 - val initialCount = this.count() - var maxSelected = 0 + /** + * Return a fixed-size sampled subset of this RDD in an array + * + * @param withReplacement whether sampling is done with replacement + * @param num size of the returned sample + * @param seed seed for the random number generator + * @return sample of specified size in an array + */ + def takeSample(withReplacement: Boolean, + num: Int, + seed: Long = Utils.random.nextLong): Array[T] = { + val numStDev = 10.0 if (num < 0) { throw new IllegalArgumentException("Negative number of elements requested") + } else if (num == 0) { + return new Array[T](0) } + val initialCount = this.count() if (initialCount == 0) { return new Array[T](0) } - if (initialCount > Integer.MAX_VALUE - 1) { - maxSelected = Integer.MAX_VALUE - 1 - } else { - maxSelected = initialCount.toInt + val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt + if (num > maxSampleSize) { + throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " + + s"$numStDev * math.sqrt(Int.MaxValue)") } - if (num > initialCount && !withReplacement) { - total = maxSelected - fraction = multiplier * (maxSelected + 1) / initialCount - } else { - fraction = multiplier * (num + 1) / initialCount - total = num + val rand = new Random(seed) + if (!withReplacement && num >= initialCount) { + return Utils.randomizeInPlace(this.collect(), rand) } - val rand = new Random(seed) + val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, + withReplacement) + var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() // If the first sample didn't turn out large enough, keep trying to take samples; // this shouldn't happen often because we use a big multiplier for the initial size - while (samples.length < total) { + var numIters = 0 + while (samples.length < num) { + logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters") samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + numIters += 1 } - Utils.randomizeInPlace(samples, rand).take(total) + Utils.randomizeInPlace(samples, rand).take(num) } /** diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 4dc8ada00a3e8..247f10173f1e9 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -70,7 +70,7 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false) } /** - * Return a sampler with is the complement of the range specified of the current sampler. + * Return a sampler that is the complement of the range specified of the current sampler. */ def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement) diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala new file mode 100644 index 0000000000000..a79e3ee756fc6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.random + +private[spark] object SamplingUtils { + + /** + * Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of + * the time. + * + * How the sampling rate is determined: + * Let p = num / total, where num is the sample size and total is the total number of + * datapoints in the RDD. We're trying to compute q > p such that + * - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q), + * where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to total), + * i.e. the failure rate of not having a sufficiently large sample < 0.0001. + * Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for + * num > 12, but we need a slightly larger q (9 empirically determined). + * - when sampling without replacement, we're drawing each datapoint with prob_i + * ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success + * rate, where success rate is defined the same as in sampling with replacement. + * + * @param sampleSizeLowerBound sample size + * @param total size of RDD + * @param withReplacement whether sampling with replacement + * @return a sampling rate that guarantees sufficient sample size with 99.99% success rate + */ + def computeFractionForSampleSize(sampleSizeLowerBound: Int, total: Long, + withReplacement: Boolean): Double = { + val fraction = sampleSizeLowerBound.toDouble / total + if (withReplacement) { + val numStDev = if (sampleSizeLowerBound < 12) 9 else 5 + fraction + numStDev * math.sqrt(fraction / total) + } else { + val delta = 1e-4 + val gamma = - math.log(delta) / total + math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 2e2ccc5a1859e..e94a1e76d410c 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -505,55 +505,56 @@ class RDDSuite extends FunSuite with SharedSparkContext { } test("takeSample") { - val data = sc.parallelize(1 to 100, 2) + val n = 1000000 + val data = sc.parallelize(1 to n, 2) for (num <- List(5, 20, 100)) { val sample = data.takeSample(withReplacement=false, num=num) assert(sample.size === num) // Got exactly num elements assert(sample.toSet.size === num) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=false, 20, seed) assert(sample.size === 20) // Got exactly 20 elements assert(sample.toSet.size === 20) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=false, 200, seed) + val sample = data.takeSample(withReplacement=false, 100, seed) assert(sample.size === 100) // Got only 100 elements assert(sample.toSet.size === 100) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=true, 20, seed) assert(sample.size === 20) // Got exactly 20 elements - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } { val sample = data.takeSample(withReplacement=true, num=20) assert(sample.size === 20) // Got exactly 100 elements assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements") - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } { - val sample = data.takeSample(withReplacement=true, num=100) - assert(sample.size === 100) // Got exactly 100 elements + val sample = data.takeSample(withReplacement=true, num=n) + assert(sample.size === n) // Got exactly 100 elements // Chance of getting all distinct elements is astronomically low, so test we got < 100 - assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 100, seed) - assert(sample.size === 100) // Got exactly 100 elements + val sample = data.takeSample(withReplacement=true, n, seed) + assert(sample.size === n) // Got exactly 100 elements // Chance of getting all distinct elements is astronomically low, so test we got < 100 - assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") + assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 200, seed) - assert(sample.size === 200) // Got exactly 200 elements + val sample = data.takeSample(withReplacement=true, 2 * n, seed) + assert(sample.size === 2 * n) // Got exactly 200 elements // Chance of getting all distinct elements is still quite low, so test we got < 100 - assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") + assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") } } diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala new file mode 100644 index 0000000000000..accfe2e9b7f2a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.random + +import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution} +import org.scalatest.FunSuite + +class SamplingUtilsSuite extends FunSuite { + + test("computeFraction") { + // test that the computed fraction guarantees enough data points + // in the sample with a failure rate <= 0.0001 + val n = 100000 + + for (s <- 1 to 15) { + val frac = SamplingUtils.computeFractionForSampleSize(s, n, true) + val poisson = new PoissonDistribution(frac * n) + assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") + } + for (s <- List(20, 100, 1000)) { + val frac = SamplingUtils.computeFractionForSampleSize(s, n, true) + val poisson = new PoissonDistribution(frac * n) + assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") + } + for (s <- List(1, 10, 100, 1000)) { + val frac = SamplingUtils.computeFractionForSampleSize(s, n, false) + val binomial = new BinomialDistribution(n, frac) + assert(binomial.inverseCumulativeProbability(0.0001)*n >= s, "Computed fraction is too low") + } + } +} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 8b4885d3bbbdb..2d60a44f04f6f 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -349,6 +349,7 @@ object SparkBuild extends Build { libraryDependencies ++= Seq( "com.google.guava" % "guava" % "14.0.1", "org.apache.commons" % "commons-lang3" % "3.3.2", + "org.apache.commons" % "commons-math3" % "3.3" % "test", "com.google.code.findbugs" % "jsr305" % "1.3.9", "log4j" % "log4j" % "1.2.17", "org.slf4j" % "slf4j-api" % slf4jVersion, diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 735389c69831c..ddd22850a819c 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -31,6 +31,7 @@ import warnings import heapq from random import Random +from math import sqrt, log from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ @@ -202,9 +203,9 @@ def cache(self): def persist(self, storageLevel): """ - Set this RDD's storage level to persist its values across operations after the first time - it is computed. This can only be used to assign a new storage level if the RDD does not - have a storage level set yet. + Set this RDD's storage level to persist its values across operations + after the first time it is computed. This can only be used to assign + a new storage level if the RDD does not have a storage level set yet. """ self.is_cached = True javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel) @@ -213,7 +214,8 @@ def persist(self, storageLevel): def unpersist(self): """ - Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + Mark the RDD as non-persistent, and remove all blocks for it from + memory and disk. """ self.is_cached = False self._jrdd.unpersist() @@ -357,48 +359,87 @@ def sample(self, withReplacement, fraction, seed=None): # this is ported from scala/spark/RDD.scala def takeSample(self, withReplacement, num, seed=None): """ - Return a fixed-size sampled subset of this RDD (currently requires numpy). + Return a fixed-size sampled subset of this RDD (currently requires + numpy). - >>> sc.parallelize(range(0, 10)).takeSample(True, 10, 1) #doctest: +SKIP - [4, 2, 1, 8, 2, 7, 0, 4, 1, 4] + >>> rdd = sc.parallelize(range(0, 10)) + >>> len(rdd.takeSample(True, 20, 1)) + 20 + >>> len(rdd.takeSample(False, 5, 2)) + 5 + >>> len(rdd.takeSample(False, 15, 3)) + 10 """ + numStDev = 10.0 + + if num < 0: + raise ValueError("Sample size cannot be negative.") + elif num == 0: + return [] - fraction = 0.0 - total = 0 - multiplier = 3.0 initialCount = self.count() - maxSelected = 0 + if initialCount == 0: + return [] - if (num < 0): - raise ValueError + rand = Random(seed) - if (initialCount == 0): - return list() + if (not withReplacement) and num >= initialCount: + # shuffle current RDD and return + samples = self.collect() + rand.shuffle(samples) + return samples - if initialCount > sys.maxint - 1: - maxSelected = sys.maxint - 1 - else: - maxSelected = initialCount - - if num > initialCount and not withReplacement: - total = maxSelected - fraction = multiplier * (maxSelected + 1) / initialCount - else: - fraction = multiplier * (num + 1) / initialCount - total = num + maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint)) + if num > maxSampleSize: + raise ValueError("Sample size cannot be greater than %d." % maxSampleSize) + fraction = RDD._computeFractionForSampleSize(num, initialCount, withReplacement) samples = self.sample(withReplacement, fraction, seed).collect() # If the first sample didn't turn out large enough, keep trying to take samples; # this shouldn't happen often because we use a big multiplier for their initial size. # See: scala/spark/RDD.scala - rand = Random(seed) - while len(samples) < total: - samples = self.sample(withReplacement, fraction, rand.randint(0, sys.maxint)).collect() - - sampler = RDDSampler(withReplacement, fraction, rand.randint(0, sys.maxint)) - sampler.shuffle(samples) - return samples[0:total] + while len(samples) < num: + # TODO: add log warning for when more than one iteration was run + seed = rand.randint(0, sys.maxint) + samples = self.sample(withReplacement, fraction, seed).collect() + + rand.shuffle(samples) + + return samples[0:num] + + @staticmethod + def _computeFractionForSampleSize(sampleSizeLowerBound, total, withReplacement): + """ + Returns a sampling rate that guarantees a sample of + size >= sampleSizeLowerBound 99.99% of the time. + + How the sampling rate is determined: + Let p = num / total, where num is the sample size and total is the + total number of data points in the RDD. We're trying to compute + q > p such that + - when sampling with replacement, we're drawing each data point + with prob_i ~ Pois(q), where we want to guarantee + Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to + total), i.e. the failure rate of not having a sufficiently large + sample < 0.0001. Setting q = p + 5 * sqrt(p/total) is sufficient + to guarantee 0.9999 success rate for num > 12, but we need a + slightly larger q (9 empirically determined). + - when sampling without replacement, we're drawing each data point + with prob_i ~ Binomial(total, fraction) and our choice of q + guarantees 1-delta, or 0.9999 success rate, where success rate is + defined the same as in sampling with replacement. + """ + fraction = float(sampleSizeLowerBound) / total + if withReplacement: + numStDev = 5 + if (sampleSizeLowerBound < 12): + numStDev = 9 + return fraction + numStDev * sqrt(fraction / total) + else: + delta = 0.00005 + gamma = - log(delta) / total + return min(1, fraction + gamma + sqrt(gamma * gamma + 2 * gamma * fraction)) def union(self, other): """ @@ -422,8 +463,8 @@ def union(self, other): def intersection(self, other): """ - Return the intersection of this RDD and another one. The output will not - contain any duplicate elements, even if the input RDDs did. + Return the intersection of this RDD and another one. The output will + not contain any duplicate elements, even if the input RDDs did. Note that this method performs a shuffle internally. @@ -665,8 +706,8 @@ def aggregate(self, zeroValue, seqOp, combOp): modify C{t2}. The first function (seqOp) can return a different result type, U, than - the type of this RDD. Thus, we need one operation for merging a T into an U - and one operation for merging two U + the type of this RDD. Thus, we need one operation for merging a T into + an U and one operation for merging two U >>> seqOp = (lambda x, y: (x[0] + y, x[1] + 1)) >>> combOp = (lambda x, y: (x[0] + y[0], x[1] + y[1])) @@ -759,8 +800,9 @@ def stdev(self): def sampleStdev(self): """ - Compute the sample standard deviation of this RDD's elements (which corrects for bias in - estimating the standard deviation by dividing by N-1 instead of N). + Compute the sample standard deviation of this RDD's elements (which + corrects for bias in estimating the standard deviation by dividing by + N-1 instead of N). >>> sc.parallelize([1, 2, 3]).sampleStdev() 1.0 @@ -769,8 +811,8 @@ def sampleStdev(self): def sampleVariance(self): """ - Compute the sample variance of this RDD's elements (which corrects for bias in - estimating the variance by dividing by N-1 instead of N). + Compute the sample variance of this RDD's elements (which corrects + for bias in estimating the variance by dividing by N-1 instead of N). >>> sc.parallelize([1, 2, 3]).sampleVariance() 1.0 @@ -822,8 +864,8 @@ def merge(a, b): def takeOrdered(self, num, key=None): """ - Get the N elements from a RDD ordered in ascending order or as specified - by the optional key function. + Get the N elements from a RDD ordered in ascending order or as + specified by the optional key function. >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6) [1, 2, 3, 4, 5, 6] @@ -912,8 +954,9 @@ def first(self): def saveAsPickleFile(self, path, batchSize=10): """ - Save this RDD as a SequenceFile of serialized objects. The serializer used is - L{pyspark.serializers.PickleSerializer}, default batch size is 10. + Save this RDD as a SequenceFile of serialized objects. The serializer + used is L{pyspark.serializers.PickleSerializer}, default batch size + is 10. >>> tmpFile = NamedTemporaryFile(delete=True) >>> tmpFile.close() @@ -1195,9 +1238,10 @@ def createZero(): def foldByKey(self, zeroValue, func, numPartitions=None): """ - Merge the values for each key using an associative function "func" and a neutral "zeroValue" - which may be added to the result an arbitrary number of times, and must not change - the result (e.g., 0 for addition, or 1 for multiplication.). + Merge the values for each key using an associative function "func" + and a neutral "zeroValue" which may be added to the result an + arbitrary number of times, and must not change the result + (e.g., 0 for addition, or 1 for multiplication.). >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> from operator import add @@ -1217,8 +1261,8 @@ def groupByKey(self, numPartitions=None): Hash-partitions the resulting RDD with into numPartitions partitions. Note: If you are grouping in order to perform an aggregation (such as a - sum or average) over each key, using reduceByKey will provide much better - performance. + sum or average) over each key, using reduceByKey will provide much + better performance. >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect())) @@ -1278,8 +1322,8 @@ def groupWith(self, other): def cogroup(self, other, numPartitions=None): """ For each key k in C{self} or C{other}, return a resulting RDD that - contains a tuple with the list of values for that key in C{self} as well - as C{other}. + contains a tuple with the list of values for that key in C{self} as + well as C{other}. >>> x = sc.parallelize([("a", 1), ("b", 4)]) >>> y = sc.parallelize([("a", 2)]) @@ -1290,8 +1334,8 @@ def cogroup(self, other, numPartitions=None): def subtractByKey(self, other, numPartitions=None): """ - Return each (key, value) pair in C{self} that has no pair with matching key - in C{other}. + Return each (key, value) pair in C{self} that has no pair with matching + key in C{other}. >>> x = sc.parallelize([("a", 1), ("b", 4), ("b", 5), ("a", 2)]) >>> y = sc.parallelize([("a", 3), ("c", None)]) @@ -1329,10 +1373,10 @@ def repartition(self, numPartitions): """ Return a new RDD that has exactly numPartitions partitions. - Can increase or decrease the level of parallelism in this RDD. Internally, this uses - a shuffle to redistribute data. - If you are decreasing the number of partitions in this RDD, consider using `coalesce`, - which can avoid performing a shuffle. + Can increase or decrease the level of parallelism in this RDD. + Internally, this uses a shuffle to redistribute data. + If you are decreasing the number of partitions in this RDD, consider + using `coalesce`, which can avoid performing a shuffle. >>> rdd = sc.parallelize([1,2,3,4,5,6,7], 4) >>> sorted(rdd.glom().collect()) [[1], [2, 3], [4, 5], [6, 7]] @@ -1357,9 +1401,10 @@ def coalesce(self, numPartitions, shuffle=False): def zip(self, other): """ - Zips this RDD with another one, returning key-value pairs with the first element in each RDD - second element in each RDD, etc. Assumes that the two RDDs have the same number of - partitions and the same number of elements in each partition (e.g. one was made through + Zips this RDD with another one, returning key-value pairs with the + first element in each RDD second element in each RDD, etc. Assumes + that the two RDDs have the same number of partitions and the same + number of elements in each partition (e.g. one was made through a map on the other). >>> x = sc.parallelize(range(0,5)) From 44daec5abd4c271ea0003ecdabab92cc958dea13 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 12 Jun 2014 20:40:58 -0700 Subject: [PATCH 32/38] [Minor] Fix style, formatting and naming in BlockManager etc. This is a precursor to a bigger change. I wanted to separate out the relatively insignificant changes so the ultimate PR is not inflated. (Warning: this PR is full of unimportant nitpicks) Author: Andrew Or Closes #1058 from andrewor14/bm-minor and squashes the following commits: 8e12eaf [Andrew Or] SparkException -> BlockException c36fd53 [Andrew Or] Make parts of BlockManager more readable 0a5f378 [Andrew Or] Entry -> MemoryEntry e9762a5 [Andrew Or] Tone down string interpolation (minor reverts) c4de9ac [Andrew Or] Merge branch 'master' of github.com:apache/spark into bm-minor b3470f1 [Andrew Or] More string interpolation (minor) 7f9dcab [Andrew Or] Use string interpolation (minor) 94a425b [Andrew Or] Refactor against duplicate code + minor changes 8a6a7dc [Andrew Or] Exception -> SparkException 97c410f [Andrew Or] Deal with MIMA excludes 2480f1d [Andrew Or] Fixes in StorgeLevel.scala abb0163 [Andrew Or] Style, formatting and naming fixes --- .../scala/org/apache/spark/CacheManager.scala | 20 +- .../org/apache/spark/storage/BlockInfo.scala | 22 +- .../apache/spark/storage/BlockManager.scala | 479 +++++++++--------- .../spark/storage/DiskBlockManager.scala | 21 +- .../org/apache/spark/storage/DiskStore.scala | 24 +- .../apache/spark/storage/MemoryStore.scala | 32 +- .../apache/spark/storage/StorageLevel.scala | 88 ++-- .../apache/spark/storage/TachyonStore.scala | 39 +- project/MimaExcludes.scala | 7 +- 9 files changed, 362 insertions(+), 370 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index 811610c657b62..315ed91f81df3 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -32,10 +32,14 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { private val loading = new HashSet[RDDBlockId]() /** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */ - def getOrCompute[T](rdd: RDD[T], split: Partition, context: TaskContext, + def getOrCompute[T]( + rdd: RDD[T], + split: Partition, + context: TaskContext, storageLevel: StorageLevel): Iterator[T] = { + val key = RDDBlockId(rdd.id, split.index) - logDebug("Looking for partition " + key) + logDebug(s"Looking for partition $key") blockManager.get(key) match { case Some(values) => // Partition is already materialized, so just return its values @@ -45,7 +49,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { // Mark the split as loading (unless someone else marks it first) loading.synchronized { if (loading.contains(key)) { - logInfo("Another thread is loading %s, waiting for it to finish...".format(key)) + logInfo(s"Another thread is loading $key, waiting for it to finish...") while (loading.contains(key)) { try { loading.wait() @@ -54,7 +58,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { logWarning(s"Got an exception while waiting for another thread to load $key", e) } } - logInfo("Finished waiting for %s".format(key)) + logInfo(s"Finished waiting for $key") /* See whether someone else has successfully loaded it. The main way this would fail * is for the RDD-level cache eviction policy if someone else has loaded the same RDD * partition but we didn't want to make space for it. However, that case is unlikely @@ -64,7 +68,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { case Some(values) => return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]]) case None => - logInfo("Whoever was loading %s failed; we'll try it ourselves".format(key)) + logInfo(s"Whoever was loading $key failed; we'll try it ourselves") loading.add(key) } } else { @@ -73,7 +77,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { } try { // If we got here, we have to load the split - logInfo("Partition %s not found, computing it".format(key)) + logInfo(s"Partition $key not found, computing it") val computedValues = rdd.computeOrReadCheckpoint(split, context) // Persist the result, so long as the task is not running locally @@ -97,8 +101,8 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { case Some(values) => values.asInstanceOf[Iterator[T]] case None => - logInfo("Failure to store %s".format(key)) - throw new Exception("Block manager failed to return persisted valued") + logInfo(s"Failure to store $key") + throw new SparkException("Block manager failed to return persisted value") } } else { // In this case the RDD is cached to an array buffer. This will save the results diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala index c8f397609a0b4..22fdf73e9d1f4 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala @@ -29,9 +29,9 @@ private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolea setInitThread() private def setInitThread() { - // Set current thread as init thread - waitForReady will not block this thread - // (in case there is non trivial initialization which ends up calling waitForReady as part of - // initialization itself) + /* Set current thread as init thread - waitForReady will not block this thread + * (in case there is non trivial initialization which ends up calling waitForReady + * as part of initialization itself) */ BlockInfo.blockInfoInitThreads.put(this, Thread.currentThread()) } @@ -42,7 +42,9 @@ private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolea def waitForReady(): Boolean = { if (pending && initThread != Thread.currentThread()) { synchronized { - while (pending) this.wait() + while (pending) { + this.wait() + } } } !failed @@ -50,8 +52,8 @@ private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolea /** Mark this BlockInfo as ready (i.e. block is finished writing) */ def markReady(sizeInBytes: Long) { - require (sizeInBytes >= 0, "sizeInBytes was negative: " + sizeInBytes) - assert (pending) + require(sizeInBytes >= 0, s"sizeInBytes was negative: $sizeInBytes") + assert(pending) size = sizeInBytes BlockInfo.blockInfoInitThreads.remove(this) synchronized { @@ -61,7 +63,7 @@ private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolea /** Mark this BlockInfo as ready but failed */ def markFailure() { - assert (pending) + assert(pending) size = BlockInfo.BLOCK_FAILED BlockInfo.blockInfoInitThreads.remove(this) synchronized { @@ -71,9 +73,9 @@ private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolea } private object BlockInfo { - // initThread is logically a BlockInfo field, but we store it here because - // it's only needed while this block is in the 'pending' state and we want - // to minimize BlockInfo's memory footprint. + /* initThread is logically a BlockInfo field, but we store it here because + * it's only needed while this block is in the 'pending' state and we want + * to minimize BlockInfo's memory footprint. */ private val blockInfoInitThreads = new ConcurrentHashMap[BlockInfo, Thread] private val BLOCK_PENDING: Long = -1L diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 9cd79d262ea53..f52bc7075104b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -28,46 +28,48 @@ import scala.util.Random import akka.actor.{ActorSystem, Cancellable, Props} import sun.nio.ch.DirectBuffer -import org.apache.spark.{Logging, MapOutputTracker, SecurityManager, SparkConf, SparkEnv, SparkException} +import org.apache.spark._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer import org.apache.spark.util._ -private[spark] sealed trait Values - -private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends Values -private[spark] case class IteratorValues(iterator: Iterator[Any]) extends Values -private[spark] case class ArrayBufferValues(buffer: ArrayBuffer[Any]) extends Values +private[spark] sealed trait BlockValues +private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues +private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues +private[spark] case class ArrayBufferValues(buffer: ArrayBuffer[Any]) extends BlockValues private[spark] class BlockManager( executorId: String, actorSystem: ActorSystem, val master: BlockManagerMaster, - val defaultSerializer: Serializer, + defaultSerializer: Serializer, maxMemory: Long, - val _conf: SparkConf, + val conf: SparkConf, securityManager: SecurityManager, mapOutputTracker: MapOutputTracker) extends Logging { - def conf = _conf val shuffleBlockManager = new ShuffleBlockManager(this) val diskBlockManager = new DiskBlockManager(shuffleBlockManager, - conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) + conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) + val connectionManager = new ConnectionManager(0, conf, securityManager) + + implicit val futureExecContext = connectionManager.futureExecContext private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] + // Actual storage of where blocks are kept + private var tachyonInitialized = false private[storage] val memoryStore = new MemoryStore(this, maxMemory) private[storage] val diskStore = new DiskStore(this, diskBlockManager) - var tachyonInitialized = false private[storage] lazy val tachyonStore: TachyonStore = { val storeDir = conf.get("spark.tachyonStore.baseDir", "/tmp_spark_tachyon") val appFolderName = conf.get("spark.tachyonStore.folderName") - val tachyonStorePath = s"${storeDir}/${appFolderName}/${this.executorId}" + val tachyonStorePath = s"$storeDir/$appFolderName/${this.executorId}" val tachyonMaster = conf.get("spark.tachyonStore.url", "tachyon://localhost:19998") - val tachyonBlockManager = new TachyonBlockManager( - shuffleBlockManager, tachyonStorePath, tachyonMaster) + val tachyonBlockManager = + new TachyonBlockManager(shuffleBlockManager, tachyonStorePath, tachyonMaster) tachyonInitialized = true new TachyonStore(this, tachyonBlockManager) } @@ -79,43 +81,39 @@ private[spark] class BlockManager( if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0 } - val connectionManager = new ConnectionManager(0, conf, securityManager) - implicit val futureExecContext = connectionManager.futureExecContext - val blockManagerId = BlockManagerId( executorId, connectionManager.id.host, connectionManager.id.port, nettyPort) // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory // for receiving shuffle outputs) - val maxBytesInFlight = - conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024 + val maxBytesInFlight = conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024 // Whether to compress broadcast variables that are stored - val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true) + private val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true) // Whether to compress shuffle output that are stored - val compressShuffle = conf.getBoolean("spark.shuffle.compress", true) + private val compressShuffle = conf.getBoolean("spark.shuffle.compress", true) // Whether to compress RDD partitions that are stored serialized - val compressRdds = conf.getBoolean("spark.rdd.compress", false) + private val compressRdds = conf.getBoolean("spark.rdd.compress", false) // Whether to compress shuffle output temporarily spilled to disk - val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) + private val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) - val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf) - - val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this, mapOutputTracker)), + private val slaveActor = actorSystem.actorOf( + Props(new BlockManagerSlaveActor(this, mapOutputTracker)), name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) - // Pending re-registration action being executed asynchronously or null if none - // is pending. Accesses should synchronize on asyncReregisterLock. - var asyncReregisterTask: Future[Unit] = null - val asyncReregisterLock = new Object + // Pending re-registration action being executed asynchronously or null if none is pending. + // Accesses should synchronize on asyncReregisterLock. + private var asyncReregisterTask: Future[Unit] = null + private val asyncReregisterLock = new Object - private def heartBeat() { + private def heartBeat(): Unit = { if (!master.sendHeartBeat(blockManagerId)) { reregister() } } - var heartBeatTask: Cancellable = null + private val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf) + private var heartBeatTask: Cancellable = null private val metadataCleaner = new MetadataCleaner( MetadataCleanerType.BLOCK_MANAGER, this.dropOldNonBroadcastBlocks, conf) @@ -124,11 +122,11 @@ private[spark] class BlockManager( initialize() - // The compression codec to use. Note that the "lazy" val is necessary because we want to delay - // the initialization of the compression codec until it is first used. The reason is that a Spark - // program could be using a user-defined codec in a third party jar, which is loaded in - // Executor.updateDependencies. When the BlockManager is initialized, user level jars hasn't been - // loaded yet. + /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay + * the initialization of the compression codec until it is first used. The reason is that a Spark + * program could be using a user-defined codec in a third party jar, which is loaded in + * Executor.updateDependencies. When the BlockManager is initialized, user level jars hasn't been + * loaded yet. */ private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) /** @@ -150,7 +148,7 @@ private[spark] class BlockManager( * Initialize the BlockManager. Register to the BlockManagerMaster, and start the * BlockManagerWorker actor. */ - private def initialize() { + private def initialize(): Unit = { master.registerBlockManager(blockManagerId, maxMemory, slaveActor) BlockManagerWorker.startBlockManagerWorker(this) if (!BlockManager.getDisableHeartBeatsForTesting(conf)) { @@ -170,12 +168,12 @@ private[spark] class BlockManager( * heart beat attempt or new block registration and another try to re-register all blocks * will be made then. */ - private def reportAllBlocks() { - logInfo("Reporting " + blockInfo.size + " blocks to the master.") + private def reportAllBlocks(): Unit = { + logInfo(s"Reporting ${blockInfo.size} blocks to the master.") for ((blockId, info) <- blockInfo) { val status = getCurrentBlockStatus(blockId, info) if (!tryToReportBlockStatus(blockId, info, status)) { - logError("Failed to report " + blockId + " to master; giving up.") + logError(s"Failed to report $blockId to master; giving up.") return } } @@ -187,7 +185,7 @@ private[spark] class BlockManager( * * Note that this method must be called without any BlockInfo locks held. */ - def reregister() { + private def reregister(): Unit = { // TODO: We might need to rate limit re-registering. logInfo("BlockManager re-registering with master") master.registerBlockManager(blockManagerId, maxMemory, slaveActor) @@ -197,7 +195,7 @@ private[spark] class BlockManager( /** * Re-register with the master sometime soon. */ - def asyncReregister() { + private def asyncReregister(): Unit = { asyncReregisterLock.synchronized { if (asyncReregisterTask == null) { asyncReregisterTask = Future[Unit] { @@ -213,7 +211,7 @@ private[spark] class BlockManager( /** * For testing. Wait for any pending asynchronous re-registration; otherwise, do nothing. */ - def waitForAsyncReregister() { + def waitForAsyncReregister(): Unit = { val task = asyncReregisterTask if (task != null) { Await.ready(task, Duration.Inf) @@ -251,18 +249,18 @@ private[spark] class BlockManager( * it is still valid). This ensures that update in master will compensate for the increase in * memory on slave. */ - def reportBlockStatus( + private def reportBlockStatus( blockId: BlockId, info: BlockInfo, status: BlockStatus, - droppedMemorySize: Long = 0L) { + droppedMemorySize: Long = 0L): Unit = { val needReregister = !tryToReportBlockStatus(blockId, info, status, droppedMemorySize) if (needReregister) { - logInfo("Got told to re-register updating block " + blockId) + logInfo(s"Got told to re-register updating block $blockId") // Re-registering will report our new block for free. asyncReregister() } - logDebug("Told master about block " + blockId) + logDebug(s"Told master about block $blockId") } /** @@ -293,10 +291,10 @@ private[spark] class BlockManager( * and the updated in-memory and on-disk sizes. */ private def getCurrentBlockStatus(blockId: BlockId, info: BlockInfo): BlockStatus = { - val (newLevel, inMemSize, onDiskSize, inTachyonSize) = info.synchronized { + info.synchronized { info.level match { case null => - (StorageLevel.NONE, 0L, 0L, 0L) + BlockStatus(StorageLevel.NONE, 0L, 0L, 0L) case level => val inMem = level.useMemory && memoryStore.contains(blockId) val inTachyon = level.useOffHeap && tachyonStore.contains(blockId) @@ -307,19 +305,18 @@ private[spark] class BlockManager( val memSize = if (inMem) memoryStore.getSize(blockId) else 0L val tachyonSize = if (inTachyon) tachyonStore.getSize(blockId) else 0L val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L - (storageLevel, memSize, diskSize, tachyonSize) + BlockStatus(storageLevel, memSize, diskSize, tachyonSize) } } - BlockStatus(newLevel, inMemSize, onDiskSize, inTachyonSize) } /** * Get locations of an array of blocks. */ - def getLocationBlockIds(blockIds: Array[BlockId]): Array[Seq[BlockManagerId]] = { + private def getLocationBlockIds(blockIds: Array[BlockId]): Array[Seq[BlockManagerId]] = { val startTimeMs = System.currentTimeMillis val locations = master.getLocations(blockIds).toArray - logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) + logDebug("Got multiple block location in %s".format(Utils.getUsedTimeMs(startTimeMs))) locations } @@ -329,15 +326,16 @@ private[spark] class BlockManager( * never deletes (recent) items. */ def getLocalFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { - diskStore.getValues(blockId, serializer).orElse( - sys.error("Block " + blockId + " not found on disk, though it should be")) + diskStore.getValues(blockId, serializer).orElse { + throw new BlockException(blockId, s"Block $blockId not found on disk, though it should be") + } } /** * Get block from local block manager. */ def getLocal(blockId: BlockId): Option[Iterator[Any]] = { - logDebug("Getting local block " + blockId) + logDebug(s"Getting local block $blockId") doGetLocal(blockId, asValues = true).asInstanceOf[Option[Iterator[Any]]] } @@ -345,7 +343,7 @@ private[spark] class BlockManager( * Get block from the local block manager as serialized bytes. */ def getLocalBytes(blockId: BlockId): Option[ByteBuffer] = { - logDebug("Getting local block " + blockId + " as bytes") + logDebug(s"Getting local block $blockId as bytes") // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work if (blockId.isShuffle) { @@ -353,7 +351,8 @@ private[spark] class BlockManager( case Some(bytes) => Some(bytes) case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") + throw new BlockException( + blockId, s"Block $blockId not found on disk, though it should be") } } else { doGetLocal(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]] @@ -368,16 +367,16 @@ private[spark] class BlockManager( // If another thread is writing the block, wait for it to become ready. if (!info.waitForReady()) { // If we get here, the block write failed. - logWarning("Block " + blockId + " was marked as failure.") + logWarning(s"Block $blockId was marked as failure.") return None } val level = info.level - logDebug("Level for block " + blockId + " is " + level) + logDebug(s"Level for block $blockId is $level") // Look for the block in memory if (level.useMemory) { - logDebug("Getting block " + blockId + " from memory") + logDebug(s"Getting block $blockId from memory") val result = if (asValues) { memoryStore.getValues(blockId) } else { @@ -387,51 +386,51 @@ private[spark] class BlockManager( case Some(values) => return Some(values) case None => - logDebug("Block " + blockId + " not found in memory") + logDebug(s"Block $blockId not found in memory") } } // Look for the block in Tachyon if (level.useOffHeap) { - logDebug("Getting block " + blockId + " from tachyon") + logDebug(s"Getting block $blockId from tachyon") if (tachyonStore.contains(blockId)) { tachyonStore.getBytes(blockId) match { - case Some(bytes) => { + case Some(bytes) => if (!asValues) { return Some(bytes) } else { return Some(dataDeserialize(blockId, bytes)) } - } case None => - logDebug("Block " + blockId + " not found in tachyon") + logDebug(s"Block $blockId not found in tachyon") } } } - // Look for block on disk, potentially storing it back into memory if required: + // Look for block on disk, potentially storing it back in memory if required if (level.useDisk) { - logDebug("Getting block " + blockId + " from disk") + logDebug(s"Getting block $blockId from disk") val bytes: ByteBuffer = diskStore.getBytes(blockId) match { - case Some(bytes) => bytes + case Some(b) => b case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") + throw new BlockException( + blockId, s"Block $blockId not found on disk, though it should be") } - assert (0 == bytes.position()) + assert(0 == bytes.position()) if (!level.useMemory) { - // If the block shouldn't be stored in memory, we can just return it: + // If the block shouldn't be stored in memory, we can just return it if (asValues) { return Some(dataDeserialize(blockId, bytes)) } else { return Some(bytes) } } else { - // Otherwise, we also have to store something in the memory store: + // Otherwise, we also have to store something in the memory store if (!level.deserialized || !asValues) { - // We'll store the bytes in memory if the block's storage level includes - // "memory serialized", or if it should be cached as objects in memory - // but we only requested its serialized bytes: + /* We'll store the bytes in memory if the block's storage level includes + * "memory serialized", or if it should be cached as objects in memory + * but we only requested its serialized bytes. */ val copyForMemory = ByteBuffer.allocate(bytes.limit) copyForMemory.put(bytes) memoryStore.putBytes(blockId, copyForMemory, level) @@ -442,16 +441,17 @@ private[spark] class BlockManager( } else { val values = dataDeserialize(blockId, bytes) if (level.deserialized) { - // Cache the values before returning them: + // Cache the values before returning them // TODO: Consider creating a putValues that also takes in a iterator? val valuesBuffer = new ArrayBuffer[Any] valuesBuffer ++= values - memoryStore.putValues(blockId, valuesBuffer, level, true).data match { - case Left(values2) => - return Some(values2) - case _ => - throw new Exception("Memory store did not return back an iterator") - } + memoryStore.putValues(blockId, valuesBuffer, level, returnValues = true).data + match { + case Left(values2) => + return Some(values2) + case _ => + throw new SparkException("Memory store did not return an iterator") + } } else { return Some(values) } @@ -460,7 +460,7 @@ private[spark] class BlockManager( } } } else { - logDebug("Block " + blockId + " not registered locally") + logDebug(s"Block $blockId not registered locally") } None } @@ -469,7 +469,7 @@ private[spark] class BlockManager( * Get block from remote block managers. */ def getRemote(blockId: BlockId): Option[Iterator[Any]] = { - logDebug("Getting remote block " + blockId) + logDebug(s"Getting remote block $blockId") doGetRemote(blockId, asValues = true).asInstanceOf[Option[Iterator[Any]]] } @@ -477,7 +477,7 @@ private[spark] class BlockManager( * Get block from remote block managers as serialized bytes. */ def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = { - logDebug("Getting remote block " + blockId + " as bytes") + logDebug(s"Getting remote block $blockId as bytes") doGetRemote(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]] } @@ -485,7 +485,7 @@ private[spark] class BlockManager( require(blockId != null, "BlockId is null") val locations = Random.shuffle(master.getLocations(blockId)) for (loc <- locations) { - logDebug("Getting remote block " + blockId + " from " + loc) + logDebug(s"Getting remote block $blockId from $loc") val data = BlockManagerWorker.syncGetBlock( GetBlock(blockId), ConnectionManagerId(loc.host, loc.port)) if (data != null) { @@ -495,9 +495,9 @@ private[spark] class BlockManager( return Some(data) } } - logDebug("The value of block " + blockId + " is null") + logDebug(s"The value of block $blockId is null") } - logDebug("Block " + blockId + " not found") + logDebug(s"Block $blockId not found") None } @@ -507,12 +507,12 @@ private[spark] class BlockManager( def get(blockId: BlockId): Option[Iterator[Any]] = { val local = getLocal(blockId) if (local.isDefined) { - logInfo("Found block %s locally".format(blockId)) + logInfo(s"Found block $blockId locally") return local } val remote = getRemote(blockId) if (remote.isDefined) { - logInfo("Found block %s remotely".format(blockId)) + logInfo(s"Found block $blockId remotely") return remote } None @@ -533,7 +533,6 @@ private[spark] class BlockManager( } else { new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer) } - iter.initialize() iter } @@ -543,6 +542,7 @@ private[spark] class BlockManager( values: Iterator[Any], level: StorageLevel, tellMaster: Boolean): Seq[(BlockId, BlockStatus)] = { + require(values != null, "Values is null") doPut(blockId, IteratorValues(values), level, tellMaster) } @@ -562,8 +562,8 @@ private[spark] class BlockManager( } /** - * Put a new block of values to the block manager. Return a list of blocks updated as a - * result of this put. + * Put a new block of values to the block manager. + * Return a list of blocks updated as a result of this put. */ def put( blockId: BlockId, @@ -575,8 +575,8 @@ private[spark] class BlockManager( } /** - * Put a new block of serialized bytes to the block manager. Return a list of blocks updated - * as a result of this put. + * Put a new block of serialized bytes to the block manager. + * Return a list of blocks updated as a result of this put. */ def putBytes( blockId: BlockId, @@ -589,7 +589,7 @@ private[spark] class BlockManager( private def doPut( blockId: BlockId, - data: Values, + data: BlockValues, level: StorageLevel, tellMaster: Boolean = true): Seq[(BlockId, BlockStatus)] = { @@ -599,20 +599,18 @@ private[spark] class BlockManager( // Return value val updatedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - // Remember the block's storage level so that we can correctly drop it to disk if it needs - // to be dropped right after it got put into memory. Note, however, that other threads will - // not be able to get() this block until we call markReady on its BlockInfo. + /* Remember the block's storage level so that we can correctly drop it to disk if it needs + * to be dropped right after it got put into memory. Note, however, that other threads will + * not be able to get() this block until we call markReady on its BlockInfo. */ val putBlockInfo = { val tinfo = new BlockInfo(level, tellMaster) // Do atomically ! val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo) - if (oldBlockOpt.isDefined) { if (oldBlockOpt.get.waitForReady()) { - logWarning("Block " + blockId + " already exists on this machine; not re-adding it") + logWarning(s"Block $blockId already exists on this machine; not re-adding it") return updatedBlocks } - // TODO: So the block info exists - but previous attempt to load it (?) failed. // What do we do now ? Retry on it ? oldBlockOpt.get @@ -623,10 +621,10 @@ private[spark] class BlockManager( val startTimeMs = System.currentTimeMillis - // If we're storing values and we need to replicate the data, we'll want access to the values, - // but because our put will read the whole iterator, there will be no values left. For the - // case where the put serializes data, we'll remember the bytes, above; but for the case where - // it doesn't, such as deserialized storage, let's rely on the put returning an Iterator. + /* If we're storing values and we need to replicate the data, we'll want access to the values, + * but because our put will read the whole iterator, there will be no values left. For the + * case where the put serializes data, we'll remember the bytes, above; but for the case where + * it doesn't, such as deserialized storage, let's rely on the put returning an Iterator. */ var valuesAfterPut: Iterator[Any] = null // Ditto for the bytes after the put @@ -637,78 +635,62 @@ private[spark] class BlockManager( // If we're storing bytes, then initiate the replication before storing them locally. // This is faster as data is already serialized and ready to send. - val replicationFuture = if (data.isInstanceOf[ByteBufferValues] && level.replication > 1) { - // Duplicate doesn't copy the bytes, just creates a wrapper - val bufferView = data.asInstanceOf[ByteBufferValues].buffer.duplicate() - Future { - replicate(blockId, bufferView, level) - } - } else { - null + val replicationFuture = data match { + case b: ByteBufferValues if level.replication > 1 => + // Duplicate doesn't copy the bytes, but just creates a wrapper + val bufferView = b.buffer.duplicate() + Future { replicate(blockId, bufferView, level) } + case _ => null } putBlockInfo.synchronized { - logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) - + " to get into synchronized block") + logTrace("Put for block %s took %s to get into synchronized block" + .format(blockId, Utils.getUsedTimeMs(startTimeMs))) var marked = false try { - if (level.useMemory) { - // Save it just to memory first, even if it also has useDisk set to true; we will - // drop it to disk later if the memory store can't hold it. - val res = data match { - case IteratorValues(iterator) => - memoryStore.putValues(blockId, iterator, level, true) - case ArrayBufferValues(array) => - memoryStore.putValues(blockId, array, level, true) - case ByteBufferValues(bytes) => - bytes.rewind() - memoryStore.putBytes(blockId, bytes, level) - } - size = res.size - res.data match { - case Right(newBytes) => bytesAfterPut = newBytes - case Left(newIterator) => valuesAfterPut = newIterator - } - // Keep track of which blocks are dropped from memory - res.droppedBlocks.foreach { block => updatedBlocks += block } - } else if (level.useOffHeap) { - // Save to Tachyon. - val res = data match { - case IteratorValues(iterator) => - tachyonStore.putValues(blockId, iterator, level, false) - case ArrayBufferValues(array) => - tachyonStore.putValues(blockId, array, level, false) - case ByteBufferValues(bytes) => - bytes.rewind() - tachyonStore.putBytes(blockId, bytes, level) - } - size = res.size - res.data match { - case Right(newBytes) => bytesAfterPut = newBytes - case _ => - } - } else { - // Save directly to disk. - // Don't get back the bytes unless we replicate them. - val askForBytes = level.replication > 1 - - val res = data match { - case IteratorValues(iterator) => - diskStore.putValues(blockId, iterator, level, askForBytes) - case ArrayBufferValues(array) => - diskStore.putValues(blockId, array, level, askForBytes) - case ByteBufferValues(bytes) => - bytes.rewind() - diskStore.putBytes(blockId, bytes, level) - } - size = res.size - res.data match { - case Right(newBytes) => bytesAfterPut = newBytes - case _ => + // returnValues - Whether to return the values put + // blockStore - The type of storage to put these values into + val (returnValues, blockStore: BlockStore) = { + if (level.useMemory) { + // Put it in memory first, even if it also has useDisk set to true; + // We will drop it to disk later if the memory store can't hold it. + (true, memoryStore) + } else if (level.useOffHeap) { + // Use tachyon for off-heap storage + (false, tachyonStore) + } else if (level.useDisk) { + // Don't get back the bytes from put unless we replicate them + (level.replication > 1, diskStore) + } else { + assert(level == StorageLevel.NONE) + throw new BlockException( + blockId, s"Attempted to put block $blockId without specifying storage level!") } } + // Actually put the values + val result = data match { + case IteratorValues(iterator) => + blockStore.putValues(blockId, iterator, level, returnValues) + case ArrayBufferValues(array) => + blockStore.putValues(blockId, array, level, returnValues) + case ByteBufferValues(bytes) => + bytes.rewind() + blockStore.putBytes(blockId, bytes, level) + } + size = result.size + result.data match { + case Left (newIterator) if level.useMemory => valuesAfterPut = newIterator + case Right (newBytes) => bytesAfterPut = newBytes + case _ => + } + + // Keep track of which blocks are dropped from memory + if (level.useMemory) { + result.droppedBlocks.foreach { updatedBlocks += _ } + } + val putBlockStatus = getCurrentBlockStatus(blockId, putBlockInfo) if (putBlockStatus.storageLevel != StorageLevel.NONE) { // Now that the block is in either the memory, tachyon, or disk store, @@ -728,18 +710,21 @@ private[spark] class BlockManager( // could've inserted a new BlockInfo before we remove it. blockInfo.remove(blockId) putBlockInfo.markFailure() - logWarning("Putting block " + blockId + " failed") + logWarning(s"Putting block $blockId failed") } } } - logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs)) + logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs))) // Either we're storing bytes and we asynchronously started replication, or we're storing // values and need to serialize and replicate them now: if (level.replication > 1) { data match { - case ByteBufferValues(bytes) => Await.ready(replicationFuture, Duration.Inf) - case _ => { + case ByteBufferValues(bytes) => + if (replicationFuture != null) { + Await.ready(replicationFuture, Duration.Inf) + } + case _ => val remoteStartTime = System.currentTimeMillis // Serialize the block if not already done if (bytesAfterPut == null) { @@ -750,20 +735,19 @@ private[spark] class BlockManager( bytesAfterPut = dataSerialize(blockId, valuesAfterPut) } replicate(blockId, bytesAfterPut, level) - logDebug("Put block " + blockId + " remotely took " + - Utils.getUsedTimeMs(remoteStartTime)) - } + logDebug("Put block %s remotely took %s" + .format(blockId, Utils.getUsedTimeMs(remoteStartTime))) } } BlockManager.dispose(bytesAfterPut) if (level.replication > 1) { - logDebug("Put for block " + blockId + " with replication took " + - Utils.getUsedTimeMs(startTimeMs)) + logDebug("Putting block %s with replication took %s" + .format(blockId, Utils.getUsedTimeMs(startTimeMs))) } else { - logDebug("Put for block " + blockId + " without replication took " + - Utils.getUsedTimeMs(startTimeMs)) + logDebug("Putting block %s without replication took %s" + .format(blockId, Utils.getUsedTimeMs(startTimeMs))) } updatedBlocks @@ -773,7 +757,7 @@ private[spark] class BlockManager( * Replicate block to another node. */ @volatile var cachedPeers: Seq[BlockManagerId] = null - private def replicate(blockId: BlockId, data: ByteBuffer, level: StorageLevel) { + private def replicate(blockId: BlockId, data: ByteBuffer, level: StorageLevel): Unit = { val tLevel = StorageLevel( level.useDisk, level.useMemory, level.useOffHeap, level.deserialized, 1) if (cachedPeers == null) { @@ -782,15 +766,16 @@ private[spark] class BlockManager( for (peer: BlockManagerId <- cachedPeers) { val start = System.nanoTime data.rewind() - logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is " - + data.limit() + " Bytes. To node: " + peer) - if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel), - new ConnectionManagerId(peer.host, peer.port))) { - logError("Failed to call syncPutBlock to " + peer) + logDebug(s"Try to replicate $blockId once; The size of the data is ${data.limit()} Bytes. " + + s"To node: $peer") + val putBlock = PutBlock(blockId, data, tLevel) + val cmId = new ConnectionManagerId(peer.host, peer.port) + val syncPutBlockSuccess = BlockManagerWorker.syncPutBlock(putBlock, cmId) + if (!syncPutBlockSuccess) { + logError(s"Failed to call syncPutBlock to $peer") } - logDebug("Replicated BlockId " + blockId + " once used " + - (System.nanoTime - start) / 1e6 + " s; The size of the data is " + - data.limit() + " bytes.") + logDebug("Replicating BlockId %s once used %fs; The size of the data is %d bytes." + .format(blockId, (System.nanoTime - start) / 1e6, data.limit())) } } @@ -822,17 +807,17 @@ private[spark] class BlockManager( blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer]): Option[BlockStatus] = { - logInfo("Dropping block " + blockId + " from memory") + logInfo(s"Dropping block $blockId from memory") val info = blockInfo.get(blockId).orNull // If the block has not already been dropped - if (info != null) { + if (info != null) { info.synchronized { // required ? As of now, this will be invoked only for blocks which are ready // But in case this changes in future, adding for consistency sake. if (!info.waitForReady()) { // If we get here, the block write failed. - logWarning("Block " + blockId + " was marked as failure. Nothing to drop") + logWarning(s"Block $blockId was marked as failure. Nothing to drop") return None } @@ -841,10 +826,10 @@ private[spark] class BlockManager( // Drop to disk, if storage level requires if (level.useDisk && !diskStore.contains(blockId)) { - logInfo("Writing block " + blockId + " to disk") + logInfo(s"Writing block $blockId to disk") data match { case Left(elements) => - diskStore.putValues(blockId, elements, level, false) + diskStore.putValues(blockId, elements, level, returnValues = false) case Right(bytes) => diskStore.putBytes(blockId, bytes, level) } @@ -858,7 +843,7 @@ private[spark] class BlockManager( if (blockIsRemoved) { blockIsUpdated = true } else { - logWarning("Block " + blockId + " could not be dropped from memory as it does not exist") + logWarning(s"Block $blockId could not be dropped from memory as it does not exist") } val status = getCurrentBlockStatus(blockId, info) @@ -883,7 +868,7 @@ private[spark] class BlockManager( */ def removeRdd(rddId: Int): Int = { // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks. - logInfo("Removing RDD " + rddId) + logInfo(s"Removing RDD $rddId") val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) } blocksToRemove.size @@ -893,7 +878,7 @@ private[spark] class BlockManager( * Remove all blocks belonging to the given broadcast. */ def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = { - logInfo("Removing broadcast " + broadcastId) + logInfo(s"Removing broadcast $broadcastId") val blocksToRemove = blockInfo.keys.collect { case bid @ BroadcastBlockId(`broadcastId`, _) => bid } @@ -904,40 +889,42 @@ private[spark] class BlockManager( /** * Remove a block from both memory and disk. */ - def removeBlock(blockId: BlockId, tellMaster: Boolean = true) { - logInfo("Removing block " + blockId) + def removeBlock(blockId: BlockId, tellMaster: Boolean = true): Unit = { + logInfo(s"Removing block $blockId") val info = blockInfo.get(blockId).orNull - if (info != null) info.synchronized { - // Removals are idempotent in disk store and memory store. At worst, we get a warning. - val removedFromMemory = memoryStore.remove(blockId) - val removedFromDisk = diskStore.remove(blockId) - val removedFromTachyon = if (tachyonInitialized) tachyonStore.remove(blockId) else false - if (!removedFromMemory && !removedFromDisk && !removedFromTachyon) { - logWarning("Block " + blockId + " could not be removed as it was not found in either " + - "the disk, memory, or tachyon store") - } - blockInfo.remove(blockId) - if (tellMaster && info.tellMaster) { - val status = getCurrentBlockStatus(blockId, info) - reportBlockStatus(blockId, info, status) + if (info != null) { + info.synchronized { + // Removals are idempotent in disk store and memory store. At worst, we get a warning. + val removedFromMemory = memoryStore.remove(blockId) + val removedFromDisk = diskStore.remove(blockId) + val removedFromTachyon = if (tachyonInitialized) tachyonStore.remove(blockId) else false + if (!removedFromMemory && !removedFromDisk && !removedFromTachyon) { + logWarning(s"Block $blockId could not be removed as it was not found in either " + + "the disk, memory, or tachyon store") + } + blockInfo.remove(blockId) + if (tellMaster && info.tellMaster) { + val status = getCurrentBlockStatus(blockId, info) + reportBlockStatus(blockId, info, status) + } } } else { // The block has already been removed; do nothing. - logWarning("Asked to remove block " + blockId + ", which does not exist") + logWarning(s"Asked to remove block $blockId, which does not exist") } } - private def dropOldNonBroadcastBlocks(cleanupTime: Long) { - logInfo("Dropping non broadcast blocks older than " + cleanupTime) + private def dropOldNonBroadcastBlocks(cleanupTime: Long): Unit = { + logInfo(s"Dropping non broadcast blocks older than $cleanupTime") dropOldBlocks(cleanupTime, !_.isBroadcast) } - private def dropOldBroadcastBlocks(cleanupTime: Long) { - logInfo("Dropping broadcast blocks older than " + cleanupTime) + private def dropOldBroadcastBlocks(cleanupTime: Long): Unit = { + logInfo(s"Dropping broadcast blocks older than $cleanupTime") dropOldBlocks(cleanupTime, _.isBroadcast) } - private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)) { + private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)): Unit = { val iterator = blockInfo.getEntrySet.iterator while (iterator.hasNext) { val entry = iterator.next() @@ -945,17 +932,11 @@ private[spark] class BlockManager( if (time < cleanupTime && shouldDrop(id)) { info.synchronized { val level = info.level - if (level.useMemory) { - memoryStore.remove(id) - } - if (level.useDisk) { - diskStore.remove(id) - } - if (level.useOffHeap) { - tachyonStore.remove(id) - } + if (level.useMemory) { memoryStore.remove(id) } + if (level.useDisk) { diskStore.remove(id) } + if (level.useOffHeap) { tachyonStore.remove(id) } iterator.remove() - logInfo("Dropped block " + id) + logInfo(s"Dropped block $id") } val status = getCurrentBlockStatus(id, info) reportBlockStatus(id, info, status) @@ -963,12 +944,14 @@ private[spark] class BlockManager( } } - def shouldCompress(blockId: BlockId): Boolean = blockId match { - case ShuffleBlockId(_, _, _) => compressShuffle - case BroadcastBlockId(_, _) => compressBroadcast - case RDDBlockId(_, _) => compressRdds - case TempBlockId(_) => compressShuffleSpill - case _ => false + private def shouldCompress(blockId: BlockId): Boolean = { + blockId match { + case _: ShuffleBlockId => compressShuffle + case _: BroadcastBlockId => compressBroadcast + case _: RDDBlockId => compressRdds + case _: TempBlockId => compressShuffleSpill + case _ => false + } } /** @@ -990,7 +973,7 @@ private[spark] class BlockManager( blockId: BlockId, outputStream: OutputStream, values: Iterator[Any], - serializer: Serializer = defaultSerializer) { + serializer: Serializer = defaultSerializer): Unit = { val byteStream = new BufferedOutputStream(outputStream) val ser = serializer.newInstance() ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() @@ -1016,16 +999,16 @@ private[spark] class BlockManager( serializer: Serializer = defaultSerializer): Iterator[Any] = { bytes.rewind() - def getIterator = { + def getIterator: Iterator[Any] = { val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true)) serializer.newInstance().deserializeStream(stream).asIterator } if (blockId.isShuffle) { - // Reducer may need to read many local shuffle blocks and will wrap them into Iterators - // at the beginning. The wrapping will cost some memory (compression instance - // initialization, etc.). Reducer reads shuffle blocks one by one so we could do the - // wrapping lazily to save memory. + /* Reducer may need to read many local shuffle blocks and will wrap them into Iterators + * at the beginning. The wrapping will cost some memory (compression instance + * initialization, etc.). Reducer reads shuffle blocks one by one so we could do the + * wrapping lazily to save memory. */ class LazyProxyIterator(f: => Iterator[Any]) extends Iterator[Any] { lazy val proxy = f override def hasNext: Boolean = proxy.hasNext @@ -1037,7 +1020,7 @@ private[spark] class BlockManager( } } - def stop() { + def stop(): Unit = { if (heartBeatTask != null) { heartBeatTask.cancel() } @@ -1059,9 +1042,9 @@ private[spark] class BlockManager( private[spark] object BlockManager extends Logging { - val ID_GENERATOR = new IdGenerator + private val ID_GENERATOR = new IdGenerator - def getMaxMemory(conf: SparkConf): Long = { + private def getMaxMemory(conf: SparkConf): Long = { val memoryFraction = conf.getDouble("spark.storage.memoryFraction", 0.6) (Runtime.getRuntime.maxMemory * memoryFraction).toLong } @@ -1078,9 +1061,9 @@ private[spark] object BlockManager extends Logging { * waiting for the GC to find it because that could lead to huge numbers of open files. There's * unfortunately no standard API to do this. */ - def dispose(buffer: ByteBuffer) { + def dispose(buffer: ByteBuffer): Unit = { if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { - logTrace("Unmapping " + buffer) + logTrace(s"Unmapping $buffer") if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) { buffer.asInstanceOf[DirectBuffer].cleaner().clean() } @@ -1093,7 +1076,7 @@ private[spark] object BlockManager extends Logging { blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[BlockManagerId]] = { // blockManagerMaster != null is used in tests - assert (env != null || blockManagerMaster != null) + assert(env != null || blockManagerMaster != null) val blockLocations: Seq[Seq[BlockManagerId]] = if (blockManagerMaster == null) { env.blockManager.getLocationBlockIds(blockIds) } else { diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 3a7243a1ba19c..2ec46d416f37d 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -40,9 +40,9 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 private val subDirsPerLocalDir = shuffleManager.conf.getInt("spark.diskStore.subDirectories", 64) - // Create one local directory for each path mentioned in spark.local.dir; then, inside this - // directory, create multiple subdirectories that we will hash files into, in order to avoid - // having really large inodes at the top level. + /* Create one local directory for each path mentioned in spark.local.dir; then, inside this + * directory, create multiple subdirectories that we will hash files into, in order to avoid + * having really large inodes at the top level. */ private val localDirs: Array[File] = createLocalDirs() private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) private var shuffleSender : ShuffleSender = null @@ -114,7 +114,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD } private def createLocalDirs(): Array[File] = { - logDebug("Creating local directories at root dirs '" + rootDirs + "'") + logDebug(s"Creating local directories at root dirs '$rootDirs'") val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") rootDirs.split(",").map { rootDir => var foundLocalDir = false @@ -126,21 +126,20 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD tries += 1 try { localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) - localDir = new File(rootDir, "spark-local-" + localDirId) + localDir = new File(rootDir, s"spark-local-$localDirId") if (!localDir.exists) { foundLocalDir = localDir.mkdirs() } } catch { case e: Exception => - logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e) + logWarning(s"Attempt $tries to create local dir $localDir failed", e) } } if (!foundLocalDir) { - logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + - " attempts to create local dir in " + rootDir) + logError(s"Failed $MAX_DIR_CREATION_ATTEMPTS attempts to create local dir in $rootDir") System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) } - logInfo("Created local directory at " + localDir) + logInfo(s"Created local directory at $localDir") localDir } } @@ -163,7 +162,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) } catch { case e: Exception => - logError("Exception while deleting local spark dir: " + localDir, e) + logError(s"Exception while deleting local spark dir: $localDir", e) } } } @@ -175,7 +174,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD private[storage] def startShuffleBlockSender(port: Int): Int = { shuffleSender = new ShuffleSender(port, this) - logInfo("Created ShuffleSender binding to port : " + shuffleSender.port) + logInfo(s"Created ShuffleSender binding to port: ${shuffleSender.port}") shuffleSender.port } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 0ab9fad422717..ebff0cb5ba153 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -39,41 +39,39 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage diskManager.getBlockLocation(blockId).length } - override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) : PutResult = { + override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel): PutResult = { // So that we do not modify the input offsets ! // duplicate does not copy buffer, so inexpensive val bytes = _bytes.duplicate() - logDebug("Attempting to put block " + blockId) + logDebug(s"Attempting to put block $blockId") val startTime = System.currentTimeMillis val file = diskManager.getFile(blockId) - val channel = new FileOutputStream(file).getChannel() + val channel = new FileOutputStream(file).getChannel while (bytes.remaining > 0) { channel.write(bytes) } channel.close() val finishTime = System.currentTimeMillis logDebug("Block %s stored as %s file on disk in %d ms".format( - file.getName, Utils.bytesToString(bytes.limit), (finishTime - startTime))) - return PutResult(bytes.limit(), Right(bytes.duplicate())) + file.getName, Utils.bytesToString(bytes.limit), finishTime - startTime)) + PutResult(bytes.limit(), Right(bytes.duplicate())) } override def putValues( blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel, - returnValues: Boolean) - : PutResult = { - return putValues(blockId, values.toIterator, level, returnValues) + returnValues: Boolean): PutResult = { + putValues(blockId, values.toIterator, level, returnValues) } override def putValues( blockId: BlockId, values: Iterator[Any], level: StorageLevel, - returnValues: Boolean) - : PutResult = { + returnValues: Boolean): PutResult = { - logDebug("Attempting to write values for block " + blockId) + logDebug(s"Attempting to write values for block $blockId") val startTime = System.currentTimeMillis val file = diskManager.getFile(blockId) val outputStream = new FileOutputStream(file) @@ -95,7 +93,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage override def getBytes(blockId: BlockId): Option[ByteBuffer] = { val segment = diskManager.getBlockLocation(blockId) - val channel = new RandomAccessFile(segment.file, "r").getChannel() + val channel = new RandomAccessFile(segment.file, "r").getChannel try { // For small files, directly read rather than memory map @@ -131,7 +129,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage file.delete() } else { if (fileSegment.length < file.length()) { - logWarning("Could not delete block associated with only a part of a file: " + blockId) + logWarning(s"Could not delete block associated with only a part of a file: $blockId") } false } diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 488f1ea9628f5..084a566c48560 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -24,6 +24,8 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.util.{SizeEstimator, Utils} +private case class MemoryEntry(value: Any, size: Long, deserialized: Boolean) + /** * Stores blocks in memory, either as ArrayBuffers of deserialized Java objects or as * serialized ByteBuffers. @@ -31,15 +33,13 @@ import org.apache.spark.util.{SizeEstimator, Utils} private class MemoryStore(blockManager: BlockManager, maxMemory: Long) extends BlockStore(blockManager) { - case class Entry(value: Any, size: Long, deserialized: Boolean) - - private val entries = new LinkedHashMap[BlockId, Entry](32, 0.75f, true) + private val entries = new LinkedHashMap[BlockId, MemoryEntry](32, 0.75f, true) @volatile private var currentMemory = 0L // Object used to ensure that only one thread is putting blocks and if necessary, dropping // blocks from the memory store. private val putLock = new Object() - logInfo("MemoryStore started with capacity %s.".format(Utils.bytesToString(maxMemory))) + logInfo("MemoryStore started with capacity %s".format(Utils.bytesToString(maxMemory))) def freeMemory: Long = maxMemory - currentMemory @@ -101,7 +101,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } else if (entry.deserialized) { Some(blockManager.dataSerialize(blockId, entry.value.asInstanceOf[ArrayBuffer[Any]].iterator)) } else { - Some(entry.value.asInstanceOf[ByteBuffer].duplicate()) // Doesn't actually copy the data + Some(entry.value.asInstanceOf[ByteBuffer].duplicate()) // Doesn't actually copy the data } } @@ -124,8 +124,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) val entry = entries.remove(blockId) if (entry != null) { currentMemory -= entry.size - logInfo("Block %s of size %d dropped from memory (free %d)".format( - blockId, entry.size, freeMemory)) + logInfo(s"Block $blockId of size ${entry.size} dropped from memory (free $freeMemory)") true } else { false @@ -181,18 +180,14 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) droppedBlocks ++= freeSpaceResult.droppedBlocks if (enoughFreeSpace) { - val entry = new Entry(value, size, deserialized) + val entry = new MemoryEntry(value, size, deserialized) entries.synchronized { entries.put(blockId, entry) currentMemory += size } - if (deserialized) { - logInfo("Block %s stored as values to memory (estimated size %s, free %s)".format( - blockId, Utils.bytesToString(size), Utils.bytesToString(freeMemory))) - } else { - logInfo("Block %s stored as bytes to memory (size %s, free %s)".format( - blockId, Utils.bytesToString(size), Utils.bytesToString(freeMemory))) - } + val valuesOrBytes = if (deserialized) "values" else "bytes" + logInfo("Block %s stored as %s in memory (estimated size %s, free %s)".format( + blockId, valuesOrBytes, Utils.bytesToString(size), Utils.bytesToString(freeMemory))) putSuccess = true } else { // Tell the block manager that we couldn't put it in memory so that it can drop it to @@ -221,13 +216,12 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) * Return whether there is enough free space, along with the blocks dropped in the process. */ private def ensureFreeSpace(blockIdToAdd: BlockId, space: Long): ResultWithDroppedBlocks = { - logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format( - space, currentMemory, maxMemory)) + logInfo(s"ensureFreeSpace($space) called with curMem=$currentMemory, maxMem=$maxMemory") val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] if (space > maxMemory) { - logInfo("Will not store " + blockIdToAdd + " as it is larger than our memory limit") + logInfo(s"Will not store $blockIdToAdd as it is larger than our memory limit") return ResultWithDroppedBlocks(success = false, droppedBlocks) } @@ -252,7 +246,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } if (maxMemory - (currentMemory - selectedMemory) >= space) { - logInfo(selectedBlocks.size + " blocks selected for dropping") + logInfo(s"${selectedBlocks.size} blocks selected for dropping") for (blockId <- selectedBlocks) { val entry = entries.synchronized { entries.get(blockId) } // This should never be null as only one thread should be dropping diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index 2d8ff1194a5dc..1e35abaab5353 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -34,11 +34,11 @@ import org.apache.spark.annotation.DeveloperApi */ @DeveloperApi class StorageLevel private( - private var useDisk_ : Boolean, - private var useMemory_ : Boolean, - private var useOffHeap_ : Boolean, - private var deserialized_ : Boolean, - private var replication_ : Int = 1) + private var _useDisk: Boolean, + private var _useMemory: Boolean, + private var _useOffHeap: Boolean, + private var _deserialized: Boolean, + private var _replication: Int = 1) extends Externalizable { // TODO: Also add fields for caching priority, dataset ID, and flushing. @@ -48,13 +48,13 @@ class StorageLevel private( def this() = this(false, true, false, false) // For deserialization - def useDisk = useDisk_ - def useMemory = useMemory_ - def useOffHeap = useOffHeap_ - def deserialized = deserialized_ - def replication = replication_ + def useDisk = _useDisk + def useMemory = _useMemory + def useOffHeap = _useOffHeap + def deserialized = _deserialized + def replication = _replication - assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes") + assert(replication < 40, "Replication restricted to be less than 40 for calculating hash codes") if (useOffHeap) { require(!useDisk, "Off-heap storage level does not support using disk") @@ -63,8 +63,9 @@ class StorageLevel private( require(replication == 1, "Off-heap storage level does not support multiple replication") } - override def clone(): StorageLevel = new StorageLevel( - this.useDisk, this.useMemory, this.useOffHeap, this.deserialized, this.replication) + override def clone(): StorageLevel = { + new StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication) + } override def equals(other: Any): Boolean = other match { case s: StorageLevel => @@ -77,20 +78,20 @@ class StorageLevel private( false } - def isValid = ((useMemory || useDisk || useOffHeap) && (replication > 0)) + def isValid = (useMemory || useDisk || useOffHeap) && (replication > 0) def toInt: Int = { var ret = 0 - if (useDisk_) { + if (_useDisk) { ret |= 8 } - if (useMemory_) { + if (_useMemory) { ret |= 4 } - if (useOffHeap_) { + if (_useOffHeap) { ret |= 2 } - if (deserialized_) { + if (_deserialized) { ret |= 1 } ret @@ -98,32 +99,34 @@ class StorageLevel private( override def writeExternal(out: ObjectOutput) { out.writeByte(toInt) - out.writeByte(replication_) + out.writeByte(_replication) } override def readExternal(in: ObjectInput) { val flags = in.readByte() - useDisk_ = (flags & 8) != 0 - useMemory_ = (flags & 4) != 0 - useOffHeap_ = (flags & 2) != 0 - deserialized_ = (flags & 1) != 0 - replication_ = in.readByte() + _useDisk = (flags & 8) != 0 + _useMemory = (flags & 4) != 0 + _useOffHeap = (flags & 2) != 0 + _deserialized = (flags & 1) != 0 + _replication = in.readByte() } @throws(classOf[IOException]) private def readResolve(): Object = StorageLevel.getCachedStorageLevel(this) - override def toString: String = "StorageLevel(%b, %b, %b, %b, %d)".format( - useDisk, useMemory, useOffHeap, deserialized, replication) + override def toString: String = { + s"StorageLevel($useDisk, $useMemory, $useOffHeap, $deserialized, $replication)" + } override def hashCode(): Int = toInt * 41 + replication - def description : String = { + + def description: String = { var result = "" result += (if (useDisk) "Disk " else "") result += (if (useMemory) "Memory " else "") result += (if (useOffHeap) "Tachyon " else "") result += (if (deserialized) "Deserialized " else "Serialized ") - result += "%sx Replicated".format(replication) + result += s"${replication}x Replicated" result } } @@ -165,7 +168,7 @@ object StorageLevel { case "MEMORY_AND_DISK_SER" => MEMORY_AND_DISK_SER case "MEMORY_AND_DISK_SER_2" => MEMORY_AND_DISK_SER_2 case "OFF_HEAP" => OFF_HEAP - case _ => throw new IllegalArgumentException("Invalid StorageLevel: " + s) + case _ => throw new IllegalArgumentException(s"Invalid StorageLevel: $s") } /** @@ -173,26 +176,37 @@ object StorageLevel { * Create a new StorageLevel object without setting useOffHeap. */ @DeveloperApi - def apply(useDisk: Boolean, useMemory: Boolean, useOffHeap: Boolean, - deserialized: Boolean, replication: Int) = getCachedStorageLevel( + def apply( + useDisk: Boolean, + useMemory: Boolean, + useOffHeap: Boolean, + deserialized: Boolean, + replication: Int) = { + getCachedStorageLevel( new StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication)) + } /** * :: DeveloperApi :: * Create a new StorageLevel object. */ @DeveloperApi - def apply(useDisk: Boolean, useMemory: Boolean, - deserialized: Boolean, replication: Int = 1) = getCachedStorageLevel( - new StorageLevel(useDisk, useMemory, false, deserialized, replication)) + def apply( + useDisk: Boolean, + useMemory: Boolean, + deserialized: Boolean, + replication: Int = 1) = { + getCachedStorageLevel(new StorageLevel(useDisk, useMemory, false, deserialized, replication)) + } /** * :: DeveloperApi :: * Create a new StorageLevel object from its integer representation. */ @DeveloperApi - def apply(flags: Int, replication: Int): StorageLevel = + def apply(flags: Int, replication: Int): StorageLevel = { getCachedStorageLevel(new StorageLevel(flags, replication)) + } /** * :: DeveloperApi :: @@ -205,8 +219,8 @@ object StorageLevel { getCachedStorageLevel(obj) } - private[spark] - val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() + private[spark] val storageLevelCache = + new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() private[spark] def getCachedStorageLevel(level: StorageLevel): StorageLevel = { storageLevelCache.putIfAbsent(level, level) diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala index c37e76f893605..d8ff4ff6bd42c 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala @@ -22,15 +22,10 @@ import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -import tachyon.client.{WriteType, ReadType} +import tachyon.client.{ReadType, WriteType} import org.apache.spark.Logging import org.apache.spark.util.Utils -import org.apache.spark.serializer.Serializer - - -private class Entry(val size: Long) - /** * Stores BlockManager blocks on Tachyon. @@ -46,8 +41,8 @@ private class TachyonStore( tachyonManager.getFile(blockId.name).length } - override def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel): PutResult = { - putToTachyonStore(blockId, bytes, true) + override def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel): PutResult = { + putIntoTachyonStore(blockId, bytes, returnValues = true) } override def putValues( @@ -55,7 +50,7 @@ private class TachyonStore( values: ArrayBuffer[Any], level: StorageLevel, returnValues: Boolean): PutResult = { - return putValues(blockId, values.toIterator, level, returnValues) + putValues(blockId, values.toIterator, level, returnValues) } override def putValues( @@ -63,12 +58,12 @@ private class TachyonStore( values: Iterator[Any], level: StorageLevel, returnValues: Boolean): PutResult = { - logDebug("Attempting to write values for block " + blockId) - val _bytes = blockManager.dataSerialize(blockId, values) - putToTachyonStore(blockId, _bytes, returnValues) + logDebug(s"Attempting to write values for block $blockId") + val bytes = blockManager.dataSerialize(blockId, values) + putIntoTachyonStore(blockId, bytes, returnValues) } - private def putToTachyonStore( + private def putIntoTachyonStore( blockId: BlockId, bytes: ByteBuffer, returnValues: Boolean): PutResult = { @@ -76,7 +71,7 @@ private class TachyonStore( // duplicate does not copy buffer, so inexpensive val byteBuffer = bytes.duplicate() byteBuffer.rewind() - logDebug("Attempting to put block " + blockId + " into Tachyon") + logDebug(s"Attempting to put block $blockId into Tachyon") val startTime = System.currentTimeMillis val file = tachyonManager.getFile(blockId) val os = file.getOutStream(WriteType.TRY_CACHE) @@ -84,7 +79,7 @@ private class TachyonStore( os.close() val finishTime = System.currentTimeMillis logDebug("Block %s stored as %s file in Tachyon in %d ms".format( - blockId, Utils.bytesToString(byteBuffer.limit), (finishTime - startTime))) + blockId, Utils.bytesToString(byteBuffer.limit), finishTime - startTime)) if (returnValues) { PutResult(bytes.limit(), Right(bytes.duplicate())) @@ -106,10 +101,9 @@ private class TachyonStore( getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer)) } - override def getBytes(blockId: BlockId): Option[ByteBuffer] = { val file = tachyonManager.getFile(blockId) - if (file == null || file.getLocationHosts().size == 0) { + if (file == null || file.getLocationHosts.size == 0) { return None } val is = file.getInStream(ReadType.CACHE) @@ -121,16 +115,15 @@ private class TachyonStore( val fetchSize = is.read(bs, 0, size.asInstanceOf[Int]) buffer = ByteBuffer.wrap(bs) if (fetchSize != size) { - logWarning("Failed to fetch the block " + blockId + " from Tachyon : Size " + size + - " is not equal to fetched size " + fetchSize) + logWarning(s"Failed to fetch the block $blockId from Tachyon: Size $size " + + s"is not equal to fetched size $fetchSize") return None } } } catch { - case ioe: IOException => { - logWarning("Failed to fetch the block " + blockId + " from Tachyon", ioe) - return None - } + case ioe: IOException => + logWarning(s"Failed to fetch the block $blockId from Tachyon", ioe) + return None } Some(buffer) } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index ee629794f60ad..042fdfcc47261 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -53,6 +53,8 @@ object MimaExcludes { "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.MemoryStore.Entry"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$" + "createZero$1") @@ -67,7 +69,10 @@ object MimaExcludes { ) ++ MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++ MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++ - MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") + MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++ + MimaBuild.excludeSparkClass("storage.Values") ++ + MimaBuild.excludeSparkClass("storage.Entry") ++ + MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") case v if v.startsWith("1.0") => Seq( MimaBuild.excludeSparkPackage("api.java"), From f95ac686bcba4e677254120735b0eb7a29f20d63 Mon Sep 17 00:00:00 2001 From: John Zhao Date: Thu, 12 Jun 2014 21:39:00 -0700 Subject: [PATCH 33/38] [SPARK-1516]Throw exception in yarn client instead of run system.exit directly. All the changes is in the package of "org.apache.spark.deploy.yarn": 1) Throw exception in ClinetArguments and ClientBase instead of exit directly. 2) in Client's main method, if exception is caught, it will exit with code 1, otherwise exit with code 0. After the fix, if user integrate the spark yarn client into their applications, when the argument is wrong or the running is finished, the application won't be terminated. Author: John Zhao Closes #490 from codeboyyong/jira_1516_systemexit_inyarnclient and squashes the following commits: 138cb48 [John Zhao] [SPARK-1516]Throw exception in yarn clinet instead of run system.exit directly. All the changes is in the package of "org.apache.spark.deploy.yarn": 1) Add a ClientException with an exitCode 2) Throws exception in ClinetArguments and ClientBase instead of exit directly 3) in Client's main method, catch exception and exit with the exitCode. --- .../org/apache/spark/deploy/yarn/Client.scala | 14 +++++++--- .../spark/deploy/yarn/ClientArguments.scala | 16 +++++------- .../apache/spark/deploy/yarn/ClientBase.scala | 26 ++++++++++++------- .../org/apache/spark/deploy/yarn/Client.scala | 14 +++++++--- 4 files changed, 44 insertions(+), 26 deletions(-) diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 8226207de42b8..4ccddc214c8ad 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -85,7 +85,6 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa def run() { val appId = runApp() monitorApplication(appId) - System.exit(0) } def logClusterResourceDetails() { @@ -179,8 +178,17 @@ object Client { System.setProperty("SPARK_YARN_MODE", "true") val sparkConf = new SparkConf - val args = new ClientArguments(argStrings, sparkConf) - new Client(args, sparkConf).run + try { + val args = new ClientArguments(argStrings, sparkConf) + new Client(args, sparkConf).run() + } catch { + case e: Exception => { + Console.err.println(e.getMessage) + System.exit(1) + } + } + + System.exit(0) } } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index b2c413b6d267c..fd3ef9e1fa2de 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -125,11 +125,11 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { case Nil => if (userClass == null) { - printUsageAndExit(1) + throw new IllegalArgumentException(getUsageMessage()) } case _ => - printUsageAndExit(1, args) + throw new IllegalArgumentException(getUsageMessage(args)) } } @@ -138,11 +138,10 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { } - def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { - if (unknownParam != null) { - System.err.println("Unknown/unsupported param " + unknownParam) - } - System.err.println( + def getUsageMessage(unknownParam: Any = null): String = { + val message = if (unknownParam != null) s"Unknown/unsupported param $unknownParam\n" else "" + + message + "Usage: org.apache.spark.deploy.yarn.Client [options] \n" + "Options:\n" + " --jar JAR_PATH Path to your application's JAR file (required in yarn-cluster mode)\n" + @@ -158,8 +157,5 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { " --addJars jars Comma separated list of local jars that want SparkContext.addJar to work with.\n" + " --files files Comma separated list of files to be distributed with the job.\n" + " --archives archives Comma separated list of archives to be distributed with the job." - ) - System.exit(exitCode) } - } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 29a35680c0e72..6861b503000ca 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -37,7 +37,7 @@ import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.util.Records -import org.apache.spark.{Logging, SparkConf, SparkContext} +import org.apache.spark.{SparkException, Logging, SparkConf, SparkContext} /** * The entry point (starting in Client#main() and Client#run()) for launching Spark on YARN. The @@ -79,7 +79,7 @@ trait ClientBase extends Logging { ).foreach { case(cond, errStr) => if (cond) { logError(errStr) - args.printUsageAndExit(1) + throw new IllegalArgumentException(args.getUsageMessage()) } } } @@ -94,15 +94,20 @@ trait ClientBase extends Logging { // If we have requested more then the clusters max for a single resource then exit. if (args.executorMemory > maxMem) { - logError("Required executor memory (%d MB), is above the max threshold (%d MB) of this cluster.". - format(args.executorMemory, maxMem)) - System.exit(1) + val errorMessage = + "Required executor memory (%d MB), is above the max threshold (%d MB) of this cluster." + .format(args.executorMemory, maxMem) + + logError(errorMessage) + throw new IllegalArgumentException(errorMessage) } val amMem = args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD if (amMem > maxMem) { - logError("Required AM memory (%d) is above the max threshold (%d) of this cluster". - format(args.amMemory, maxMem)) - System.exit(1) + + val errorMessage = "Required AM memory (%d) is above the max threshold (%d) of this cluster." + .format(args.amMemory, maxMem) + logError(errorMessage) + throw new IllegalArgumentException(errorMessage) } // We could add checks to make sure the entire cluster has enough resources but that involves @@ -186,8 +191,9 @@ trait ClientBase extends Logging { val delegTokenRenewer = Master.getMasterPrincipal(conf) if (UserGroupInformation.isSecurityEnabled()) { if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { - logError("Can't get Master Kerberos principal for use as renewer") - System.exit(1) + val errorMessage = "Can't get Master Kerberos principal for use as renewer" + logError(errorMessage) + throw new SparkException(errorMessage) } } val dst = new Path(fs.getHomeDirectory(), appStagingDir) diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 24027618c1f35..80a8bceb17269 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -95,7 +95,6 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa def run() { val appId = runApp() monitorApplication(appId) - System.exit(0) } def logClusterResourceDetails() { @@ -186,9 +185,18 @@ object Client { // see Client#setupLaunchEnv(). System.setProperty("SPARK_YARN_MODE", "true") val sparkConf = new SparkConf() - val args = new ClientArguments(argStrings, sparkConf) - new Client(args, sparkConf).run() + try { + val args = new ClientArguments(argStrings, sparkConf) + new Client(args, sparkConf).run() + } catch { + case e: Exception => { + Console.err.println(e.getMessage) + System.exit(1) + } + } + + System.exit(0) } } From 13f8cfdc04589b986554310965e83fe658085683 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 12 Jun 2014 23:09:41 -0700 Subject: [PATCH 34/38] [SPARK-2135][SQL] Use planner for in-memory scans Author: Michael Armbrust Closes #1072 from marmbrus/cachedStars and squashes the following commits: 8757c8e [Michael Armbrust] Use planner for in-memory scans. --- .../org/apache/spark/sql/SQLContext.scala | 14 +++---- .../columnar/InMemoryColumnarTableScan.scala | 39 +++++++++++++++---- .../spark/sql/execution/SparkPlan.scala | 2 - .../spark/sql/execution/SparkStrategies.scala | 13 +++++++ .../apache/spark/sql/CachedTableSuite.scala | 15 ++++--- .../columnar/InMemoryColumnarQuerySuite.scala | 6 +-- .../apache/spark/sql/hive/HiveContext.scala | 1 + .../spark/sql/hive/HiveMetastoreCatalog.scala | 7 ++-- .../spark/sql/hive/HiveStrategies.scala | 7 ++-- .../spark/sql/hive/CachedTableSuite.scala | 6 +-- 10 files changed, 75 insertions(+), 35 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 264192ed1aa26..b6a2f1b9d1278 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.plans.logical.{SetCommand, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.SparkStrategies @@ -166,10 +166,9 @@ class SQLContext(@transient val sparkContext: SparkContext) val useCompression = sparkContext.conf.getBoolean("spark.sql.inMemoryColumnarStorage.compressed", false) val asInMemoryRelation = - InMemoryColumnarTableScan( - currentTable.output, executePlan(currentTable).executedPlan, useCompression) + InMemoryRelation(useCompression, executePlan(currentTable).executedPlan) - catalog.registerTable(None, tableName, SparkLogicalPlan(asInMemoryRelation)) + catalog.registerTable(None, tableName, asInMemoryRelation) } /** Removes the specified table from the in-memory cache. */ @@ -177,11 +176,11 @@ class SQLContext(@transient val sparkContext: SparkContext) EliminateAnalysisOperators(catalog.lookupRelation(None, tableName)) match { // This is kind of a hack to make sure that if this was just an RDD registered as a table, // we reregister the RDD as a table. - case SparkLogicalPlan(inMem @ InMemoryColumnarTableScan(_, e: ExistingRdd, _)) => + case inMem @ InMemoryRelation(_, _, e: ExistingRdd) => inMem.cachedColumnBuffers.unpersist() catalog.unregisterTable(None, tableName) catalog.registerTable(None, tableName, SparkLogicalPlan(e)) - case SparkLogicalPlan(inMem: InMemoryColumnarTableScan) => + case inMem: InMemoryRelation => inMem.cachedColumnBuffers.unpersist() catalog.unregisterTable(None, tableName) case plan => throw new IllegalArgumentException(s"Table $tableName is not cached: $plan") @@ -192,7 +191,7 @@ class SQLContext(@transient val sparkContext: SparkContext) def isCached(tableName: String): Boolean = { val relation = catalog.lookupRelation(None, tableName) EliminateAnalysisOperators(relation) match { - case SparkLogicalPlan(_: InMemoryColumnarTableScan) => true + case _: InMemoryRelation => true case _ => false } } @@ -208,6 +207,7 @@ class SQLContext(@transient val sparkContext: SparkContext) PartialAggregation :: LeftSemiJoin :: HashJoin :: + InMemoryScans :: ParquetOperations :: BasicOperators :: CartesianProduct :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index fdf28e1bb1261..e1e4f24c6c66c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -17,18 +17,29 @@ package org.apache.spark.sql.columnar +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Attribute} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{SparkPlan, LeafNode} import org.apache.spark.sql.Row import org.apache.spark.SparkConf -private[sql] case class InMemoryColumnarTableScan( - attributes: Seq[Attribute], - child: SparkPlan, - useCompression: Boolean) - extends LeafNode { +object InMemoryRelation { + def apply(useCompression: Boolean, child: SparkPlan): InMemoryRelation = + new InMemoryRelation(child.output, useCompression, child) +} - override def output: Seq[Attribute] = attributes +private[sql] case class InMemoryRelation( + output: Seq[Attribute], + useCompression: Boolean, + child: SparkPlan) + extends LogicalPlan with MultiInstanceRelation { + + override def children = Seq.empty + override def references = Set.empty + + override def newInstance() = + new InMemoryRelation(output.map(_.newInstance), useCompression, child).asInstanceOf[this.type] lazy val cachedColumnBuffers = { val output = child.output @@ -55,14 +66,26 @@ private[sql] case class InMemoryColumnarTableScan( cached.count() cached } +} + +private[sql] case class InMemoryColumnarTableScan( + attributes: Seq[Attribute], + relation: InMemoryRelation) + extends LeafNode { + + override def output: Seq[Attribute] = attributes override def execute() = { - cachedColumnBuffers.mapPartitions { iterator => + relation.cachedColumnBuffers.mapPartitions { iterator => val columnBuffers = iterator.next() assert(!iterator.hasNext) new Iterator[Row] { - val columnAccessors = columnBuffers.map(ColumnAccessor(_)) + // Find the ordinals of the requested columns. If none are requested, use the first. + val requestedColumns = + if (attributes.isEmpty) Seq(0) else attributes.map(relation.output.indexOf(_)) + + val columnAccessors = requestedColumns.map(columnBuffers(_)).map(ColumnAccessor(_)) val nextRow = new GenericMutableRow(columnAccessors.length) override def next() = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 4613df103943d..07967fe75e882 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -77,8 +77,6 @@ case class SparkLogicalPlan(alreadyPlanned: SparkPlan) SparkLogicalPlan( alreadyPlanned match { case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd) - case scan @ InMemoryColumnarTableScan(output, _, _) => - scan.copy(attributes = output.map(_.newInstance)) case _ => sys.error("Multiple instance of the same relation detected.") }).asInstanceOf[this.type] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f2f95dfe27e69..a657911afe538 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.parquet._ +import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => @@ -191,6 +192,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + object InMemoryScans extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalOperation(projectList, filters, mem: InMemoryRelation) => + pruneFilterProject( + projectList, + filters, + identity[Seq[Expression]], // No filters are pushed down. + InMemoryColumnarTableScan(_, mem)) :: Nil + case _ => Nil + } + } + // Can we automate these 'pass through' operations? object BasicOperators extends Strategy { def numPartitions = self.numPartitions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index ebca3adc2ff01..c794da4da4069 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.columnar.InMemoryColumnarTableScan -import org.apache.spark.sql.execution.SparkLogicalPlan +import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} import org.apache.spark.sql.test.TestSQLContext class CachedTableSuite extends QueryTest { @@ -34,7 +33,7 @@ class CachedTableSuite extends QueryTest { ) TestSQLContext.table("testData").queryExecution.analyzed match { - case SparkLogicalPlan(_ : InMemoryColumnarTableScan) => // Found evidence of caching + case _ : InMemoryRelation => // Found evidence of caching case noCache => fail(s"No cache node found in plan $noCache") } @@ -46,7 +45,7 @@ class CachedTableSuite extends QueryTest { ) TestSQLContext.table("testData").queryExecution.analyzed match { - case cachePlan @ SparkLogicalPlan(_ : InMemoryColumnarTableScan) => + case cachePlan: InMemoryRelation => fail(s"Table still cached after uncache: $cachePlan") case noCache => // Table uncached successfully } @@ -61,13 +60,17 @@ class CachedTableSuite extends QueryTest { test("SELECT Star Cached Table") { TestSQLContext.sql("SELECT * FROM testData").registerAsTable("selectStar") TestSQLContext.cacheTable("selectStar") - TestSQLContext.sql("SELECT * FROM selectStar") + TestSQLContext.sql("SELECT * FROM selectStar WHERE key = 1").collect() TestSQLContext.uncacheTable("selectStar") } test("Self-join cached") { + val unCachedAnswer = + TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() TestSQLContext.cacheTable("testData") - TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key") + checkAnswer( + TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), + unCachedAnswer.toSeq) TestSQLContext.uncacheTable("testData") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 31c5dfba92954..86727b93f3659 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -28,14 +28,14 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("simple columnar query") { val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan - val scan = SparkLogicalPlan(InMemoryColumnarTableScan(plan.output, plan, true)) + val scan = InMemoryRelation(useCompression = true, plan) checkAnswer(scan, testData.collect().toSeq) } test("projection") { val plan = TestSQLContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan - val scan = SparkLogicalPlan(InMemoryColumnarTableScan(plan.output, plan, true)) + val scan = InMemoryRelation(useCompression = true, plan) checkAnswer(scan, testData.collect().map { case Row(key: Int, value: String) => value -> key @@ -44,7 +44,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan - val scan = SparkLogicalPlan(InMemoryColumnarTableScan(plan.output, plan, true)) + val scan = InMemoryRelation(useCompression = true, plan) checkAnswer(scan, testData.collect().toSeq) checkAnswer(scan, testData.collect().toSeq) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 64978215542ec..9cd13f6ae0d57 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -230,6 +230,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { CommandStrategy(self), TakeOrdered, ParquetOperations, + InMemoryScans, HiveTableScans, DataSinks, Scripts, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index a91b520765349..e9e6497f7e5fa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.execution.SparkLogicalPlan import org.apache.spark.sql.hive.execution.{HiveTableScan, InsertIntoHiveTable} -import org.apache.spark.sql.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -130,8 +130,9 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) => castChildOutput(p, table, child) - case p @ logical.InsertIntoTable(SparkLogicalPlan(InMemoryColumnarTableScan( - _, HiveTableScan(_, table, _), _)), _, child, _) => + case p @ logical.InsertIntoTable( + InMemoryRelation(_, _, + HiveTableScan(_, table, _)), _, child, _) => castChildOutput(p, table, child) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 8b51957162e04..ed6cd5a11dba4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.execution._ -import org.apache.spark.sql.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.columnar.InMemoryRelation private[hive] trait HiveStrategies { // Possibly being too clever with types here... or not clever enough. @@ -44,8 +44,9 @@ private[hive] trait HiveStrategies { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) => InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil - case logical.InsertIntoTable(SparkLogicalPlan(InMemoryColumnarTableScan( - _, HiveTableScan(_, table, _), _)), partition, child, overwrite) => + case logical.InsertIntoTable( + InMemoryRelation(_, _, + HiveTableScan(_, table, _)), partition, child, overwrite) => InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil case _ => Nil } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 91ac03ca30cd7..3132d0112c708 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.execution.SparkLogicalPlan -import org.apache.spark.sql.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} import org.apache.spark.sql.hive.execution.HiveComparisonTest import org.apache.spark.sql.hive.test.TestHive @@ -34,7 +34,7 @@ class CachedTableSuite extends HiveComparisonTest { test("check that table is cached and uncache") { TestHive.table("src").queryExecution.analyzed match { - case SparkLogicalPlan(_ : InMemoryColumnarTableScan) => // Found evidence of caching + case _ : InMemoryRelation => // Found evidence of caching case noCache => fail(s"No cache node found in plan $noCache") } TestHive.uncacheTable("src") @@ -45,7 +45,7 @@ class CachedTableSuite extends HiveComparisonTest { test("make sure table is uncached") { TestHive.table("src").queryExecution.analyzed match { - case cachePlan @ SparkLogicalPlan(_ : InMemoryColumnarTableScan) => + case cachePlan: InMemoryRelation => fail(s"Table still cached after uncache: $cachePlan") case noCache => // Table uncached successfully } From b3736e3d2ff9ccb83a18eefec661739105a38df5 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 13 Jun 2014 02:59:38 -0700 Subject: [PATCH 35/38] [HOTFIX] add math3 version to pom Passed `mvn package`. Author: Xiangrui Meng Closes #1075 from mengxr/takeSample-fix and squashes the following commits: 45b4590 [Xiangrui Meng] add math3 version to pom --- core/pom.xml | 1 + 1 file changed, 1 insertion(+) diff --git a/core/pom.xml b/core/pom.xml index be56911b9e45a..bd6767e03bb9d 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -70,6 +70,7 @@ org.apache.commons commons-math3 + 3.3 test From 70c8116c0aecba293234edc44a7f8e58e5008649 Mon Sep 17 00:00:00 2001 From: nravi Date: Fri, 13 Jun 2014 10:52:21 -0700 Subject: [PATCH 36/38] Workaround in Spark for ConcurrentModification issue (JIRA Hadoop-10456, Spark-1097) This fix has gone into Hadoop 2.4.1. For developers using < 2.4.1, it would be good to have a workaround in Spark as well. Fix has been tested for performance as well, no regressions found. Author: nravi Closes #1000 from nishkamravi2/master and squashes the following commits: eb663ca [nravi] Merge branch 'master' of https://github.com/apache/spark df2aeb1 [nravi] Improved fix for ConcurrentModificationIssue (Spark-1097, Hadoop-10456) 6b840f0 [nravi] Undo the fix for SPARK-1758 (the problem is fixed) 5108700 [nravi] Fix in Spark for the Concurrent thread modification issue (SPARK-1097, HADOOP-10456) 681b36f [nravi] Fix for SPARK-1758: failing test org.apache.spark.JavaAPISuite.wholeTextFiles --- .../main/scala/org/apache/spark/rdd/HadoopRDD.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 6547755764dcf..2aa111d600e9b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -139,10 +139,13 @@ class HadoopRDD[K, V]( // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the // local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects. - val newJobConf = new JobConf(broadcastedConf.value.value) - initLocalJobConfFuncOpt.map(f => f(newJobConf)) - HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) - newJobConf + // synchronize to prevent ConcurrentModificationException (Spark-1097, Hadoop-10456) + broadcastedConf.synchronized { + val newJobConf = new JobConf(broadcastedConf.value.value) + initLocalJobConfFuncOpt.map(f => f(newJobConf)) + HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) + newJobConf + } } } From 1c2fd015b05b65abc83c4874ada825deac578af8 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 13 Jun 2014 12:55:15 -0700 Subject: [PATCH 37/38] [SPARK-1964][SQL] Add timestamp to HiveMetastoreTypes.toMetastoreType Author: Michael Armbrust Closes #1061 from marmbrus/timestamp and squashes the following commits: 79c3903 [Michael Armbrust] Add timestamp to HiveMetastoreTypes.toMetastoreType() --- .../scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index e9e6497f7e5fa..68284344afd55 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -237,6 +237,7 @@ object HiveMetastoreTypes extends RegexParsers { case BinaryType => "binary" case BooleanType => "boolean" case DecimalType => "decimal" + case TimestampType => "timestamp" } } From ac96d9657c9a9f89a455a1b671c059d896012d41 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 13 Jun 2014 12:59:48 -0700 Subject: [PATCH 38/38] [SPARK-2094][SQL] "Exactly once" semantics for DDL and command statements ## Related JIRA issues - Main issue: - [SPARK-2094](https://issues.apache.org/jira/browse/SPARK-2094): Ensure exactly once semantics for DDL/Commands - Issues resolved as dependencies: - [SPARK-2081](https://issues.apache.org/jira/browse/SPARK-2081): Undefine output() from the abstract class Command and implement it in concrete subclasses - [SPARK-2128](https://issues.apache.org/jira/browse/SPARK-2128): No plan for DESCRIBE - [SPARK-1852](https://issues.apache.org/jira/browse/SPARK-1852): SparkSQL Queries with Sorts run before the user asks them to - Other related issue: - [SPARK-2129](https://issues.apache.org/jira/browse/SPARK-2129): NPE thrown while lookup a view Two test cases, `join_view` and `mergejoin_mixed`, within the `HiveCompatibilitySuite` are removed from the whitelist to workaround this issue. ## PR Overview This PR defines physical plans for DDL statements and commands and wraps their side effects in a lazy field `PhysicalCommand.sideEffectResult`, so that they are executed eagerly and exactly once. Also, as a positive side effect, now DDL statements and commands can be turned into proper `SchemaRDD`s and let user query the execution results. This PR defines schemas for the following DDL/commands: - EXPLAIN command - `plan`: String, the plan explanation - SET command - `key`: String, the key(s) of the propert(y/ies) being set or queried - `value`: String, the value(s) of the propert(y/ies) being queried - Other Hive native command - `result`: String, execution result returned by Hive **NOTE**: We should refine schemas for different native commands by defining physical plans for them in the future. ## Examples ### EXPLAIN command Take the "EXPLAIN" command as an example, we first execute the command and obtain a `SchemaRDD` at the same time, then query the `plan` field with the schema DSL: ``` scala> loadTestTable("src") ... scala> val q0 = hql("EXPLAIN SELECT key, COUNT(*) FROM src GROUP BY key") ... q0: org.apache.spark.sql.SchemaRDD = SchemaRDD[0] at RDD at SchemaRDD.scala:98 == Query Plan == ExplainCommandPhysical [plan#11:0] Aggregate false, [key#4], [key#4,SUM(PartialCount#6L) AS c_1#2L] Exchange (HashPartitioning [key#4:0], 200) Exchange (HashPartitioning [key#4:0], 200) Aggregate true, [key#4], [key#4,COUNT(1) AS PartialCount#6L] HiveTableScan [key#4], (MetastoreRelation default, src, None), None scala> q0.select('plan).collect() ... [ExplainCommandPhysical [plan#24:0] Aggregate false, [key#17], [key#17,SUM(PartialCount#19L) AS c_1#2L] Exchange (HashPartitioning [key#17:0], 200) Exchange (HashPartitioning [key#17:0], 200) Aggregate true, [key#17], [key#17,COUNT(1) AS PartialCount#19L] HiveTableScan [key#17], (MetastoreRelation default, src, None), None] scala> ``` ### SET command In this example we query all the properties set in `SQLConf`, register the result as a table, and then query the table with HiveQL: ``` scala> val q1 = hql("SET") ... q1: org.apache.spark.sql.SchemaRDD = SchemaRDD[7] at RDD at SchemaRDD.scala:98 == Query Plan == scala> q1.registerAsTable("properties") scala> hql("SELECT key, value FROM properties ORDER BY key LIMIT 10").foreach(println) ... == Query Plan == TakeOrdered 10, [key#51:0 ASC] Project [key#51:0,value#52:1] SetCommandPhysical None, None, [key#55:0,value#56:1]), which has no missing parents 14/06/12 12:19:27 INFO scheduler.DAGScheduler: Submitting 1 missing tasks from Stage 5 (SchemaRDD[21] at RDD at SchemaRDD.scala:98 == Query Plan == TakeOrdered 10, [key#51:0 ASC] Project [key#51:0,value#52:1] SetCommandPhysical None, None, [key#55:0,value#56:1]) ... [datanucleus.autoCreateSchema,true] [datanucleus.autoStartMechanismMode,checked] [datanucleus.cache.level2,false] [datanucleus.cache.level2.type,none] [datanucleus.connectionPoolingType,BONECP] [datanucleus.fixedDatastore,false] [datanucleus.identifierFactory,datanucleus1] [datanucleus.plugin.pluginRegistryBundleCheck,LOG] [datanucleus.rdbms.useLegacyNativeValueStrategy,true] [datanucleus.storeManagerType,rdbms] scala> ``` ### "Exactly once" semantics At last, an example of the "exactly once" semantics: ``` scala> val q2 = hql("CREATE TABLE t1(key INT, value STRING)") ... q2: org.apache.spark.sql.SchemaRDD = SchemaRDD[28] at RDD at SchemaRDD.scala:98 == Query Plan == scala> table("t1") ... res9: org.apache.spark.sql.SchemaRDD = SchemaRDD[32] at RDD at SchemaRDD.scala:98 == Query Plan == HiveTableScan [key#58,value#59], (MetastoreRelation default, t1, None), None scala> q2.collect() ... res10: Array[org.apache.spark.sql.Row] = Array([]) scala> ``` As we can see, the "CREATE TABLE" command is executed eagerly right after the `SchemaRDD` is created, and referencing the `SchemaRDD` again won't trigger a duplicated execution. Author: Cheng Lian Closes #1071 from liancheng/exactlyOnceCommand and squashes the following commits: d005b03 [Cheng Lian] Made "SET key=value" returns the newly set key value pair f6c7715 [Cheng Lian] Added test cases for DDL/command statement RDDs 1d00937 [Cheng Lian] Makes SchemaRDD DSLs work for DDL/command statement RDDs 5c7e680 [Cheng Lian] Bug fix: wrong type used in pattern matching 48aa2e5 [Cheng Lian] Refined SQLContext.emptyResult as an empty RDD[Row] cc64f32 [Cheng Lian] Renamed physical plan classes for DDL/commands 74789c1 [Cheng Lian] Fixed failing test cases 0ad343a [Cheng Lian] Added physical plan for DDL and commands to ensure the "exactly once" semantics --- .../sql/catalyst/plans/logical/commands.scala | 18 +-- .../optimizer/FilterPushdownSuite.scala | 2 +- .../org/apache/spark/sql/SQLContext.scala | 39 +----- .../org/apache/spark/sql/SchemaRDD.scala | 2 +- .../org/apache/spark/sql/SchemaRDDLike.scala | 15 ++- .../spark/sql/api/java/JavaSchemaRDD.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 11 +- .../apache/spark/sql/execution/commands.scala | 81 ++++++++---- .../org/apache/spark/sql/SQLQuerySuite.scala | 16 +-- .../apache/spark/sql/hive/HiveContext.scala | 55 +++----- .../spark/sql/hive/HiveStrategies.scala | 8 ++ .../sql/hive/execution/hiveOperators.scala | 32 ++++- .../hive/execution/HiveComparisonTest.scala | 3 +- .../execution/HiveCompatibilitySuite.scala | 9 +- .../sql/hive/execution/HiveQuerySuite.scala | 125 +++++++++++++----- 15 files changed, 251 insertions(+), 167 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index d05c9652753e0..3299e86b85941 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference} import org.apache.spark.sql.catalyst.types.StringType /** @@ -26,23 +26,25 @@ import org.apache.spark.sql.catalyst.types.StringType */ abstract class Command extends LeafNode { self: Product => - def output: Seq[Attribute] = Seq.empty // TODO: SPARK-2081 should fix this + def output: Seq[Attribute] = Seq.empty } /** * Returned for commands supported by a given parser, but not catalyst. In general these are DDL * commands that are passed directly to another system. */ -case class NativeCommand(cmd: String) extends Command +case class NativeCommand(cmd: String) extends Command { + override def output = + Seq(BoundReference(0, AttributeReference("result", StringType, nullable = false)())) +} /** * Commands of the form "SET (key) (= value)". */ case class SetCommand(key: Option[String], value: Option[String]) extends Command { override def output = Seq( - AttributeReference("key", StringType, nullable = false)(), - AttributeReference("value", StringType, nullable = false)() - ) + BoundReference(0, AttributeReference("key", StringType, nullable = false)()), + BoundReference(1, AttributeReference("value", StringType, nullable = false)())) } /** @@ -50,11 +52,11 @@ case class SetCommand(key: Option[String], value: Option[String]) extends Comman * actually performing the execution. */ case class ExplainCommand(plan: LogicalPlan) extends Command { - override def output = Seq(AttributeReference("plan", StringType, nullable = false)()) + override def output = + Seq(BoundReference(0, AttributeReference("plan", StringType, nullable = false)())) } /** * Returned for the "CACHE TABLE tableName" and "UNCACHE TABLE tableName" command. */ case class CacheCommand(tableName: String, doCache: Boolean) extends Command - diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 0cada785b6630..1f67c80e54906 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -161,7 +161,7 @@ class FilterPushdownSuite extends OptimizerTest { comparePlans(optimized, correctAnswer) } - + test("joins: push down left outer join #1") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index b6a2f1b9d1278..378ff54531118 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.{ScalaReflection, dsl} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.optimizer.Optimizer -import org.apache.spark.sql.catalyst.plans.logical.{SetCommand, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.columnar.InMemoryRelation @@ -147,14 +147,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - def sql(sqlText: String): SchemaRDD = { - val result = new SchemaRDD(this, parseSql(sqlText)) - // We force query optimization to happen right away instead of letting it happen lazily like - // when using the query DSL. This is so DDL commands behave as expected. This is only - // generates the RDD lineage for DML queries, but do not perform any execution. - result.queryExecution.toRdd - result - } + def sql(sqlText: String): SchemaRDD = new SchemaRDD(this, parseSql(sqlText)) /** Returns the specified table as a SchemaRDD */ def table(tableName: String): SchemaRDD = @@ -259,8 +252,7 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] val planner = new SparkPlanner @transient - protected[sql] lazy val emptyResult = - sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) + protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[Row], 1) /** * Prepares a planned SparkPlan for execution by binding references to specific ordinals, and @@ -280,22 +272,6 @@ class SQLContext(@transient val sparkContext: SparkContext) protected abstract class QueryExecution { def logical: LogicalPlan - def eagerlyProcess(plan: LogicalPlan): RDD[Row] = plan match { - case SetCommand(key, value) => - // Only this case needs to be executed eagerly. The other cases will - // be taken care of when the actual results are being extracted. - // In the case of HiveContext, sqlConf is overridden to also pass the - // pair into its HiveConf. - if (key.isDefined && value.isDefined) { - set(key.get, value.get) - } - // It doesn't matter what we return here, since this is only used - // to force the evaluation to happen eagerly. To query the results, - // one must use SchemaRDD operations to extract them. - emptyResult - case _ => executedPlan.execute() - } - lazy val analyzed = analyzer(logical) lazy val optimizedPlan = optimizer(analyzed) // TODO: Don't just pick the first one... @@ -303,12 +279,7 @@ class SQLContext(@transient val sparkContext: SparkContext) lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) /** Internal version of the RDD. Avoids copies and has no schema */ - lazy val toRdd: RDD[Row] = { - logical match { - case s: SetCommand => eagerlyProcess(s) - case _ => executedPlan.execute() - } - } + lazy val toRdd: RDD[Row] = executedPlan.execute() protected def stringOrError[A](f: => A): String = try f.toString catch { case e: Throwable => e.toString } @@ -330,7 +301,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * TODO: We only support primitive types, add support for nested types. */ private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = { - val schema = rdd.first.map { case (fieldName, obj) => + val schema = rdd.first().map { case (fieldName, obj) => val dataType = obj.getClass match { case c: Class[_] if c == classOf[java.lang.String] => StringType case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 7ad8edf5a5a6e..821ac850ac3f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -97,7 +97,7 @@ import java.util.{Map => JMap} @AlphaComponent class SchemaRDD( @transient val sqlContext: SQLContext, - @transient protected[spark] val logicalPlan: LogicalPlan) + @transient val baseLogicalPlan: LogicalPlan) extends RDD[Row](sqlContext.sparkContext, Nil) with SchemaRDDLike { def baseSchemaRDD = this diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index 3a895e15a4508..656be965a8fd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -20,13 +20,14 @@ package org.apache.spark.sql import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.SparkLogicalPlan /** * Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java) */ private[sql] trait SchemaRDDLike { @transient val sqlContext: SQLContext - @transient protected[spark] val logicalPlan: LogicalPlan + @transient val baseLogicalPlan: LogicalPlan private[sql] def baseSchemaRDD: SchemaRDD @@ -48,7 +49,17 @@ private[sql] trait SchemaRDDLike { */ @transient @DeveloperApi - lazy val queryExecution = sqlContext.executePlan(logicalPlan) + lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan) + + @transient protected[spark] val logicalPlan: LogicalPlan = baseLogicalPlan match { + // For various commands (like DDL) and queries with side effects, we force query optimization to + // happen right away to let these side effects take place eagerly. + case _: Command | _: InsertIntoTable | _: InsertIntoCreatedTable | _: WriteToFile => + queryExecution.toRdd + SparkLogicalPlan(queryExecution.executedPlan) + case _ => + baseLogicalPlan + } override def toString = s"""${super.toString} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index 22f57b758dd02..aff6ffe9f3478 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -37,7 +37,7 @@ import org.apache.spark.storage.StorageLevel */ class JavaSchemaRDD( @transient val sqlContext: SQLContext, - @transient protected[spark] val logicalPlan: LogicalPlan) + @transient val baseLogicalPlan: LogicalPlan) extends JavaRDDLike[Row, JavaRDD[Row]] with SchemaRDDLike { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index a657911afe538..2233216a6ec52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{SQLConf, SQLContext, execution} +import org.apache.spark.sql.{SQLContext, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ @@ -157,7 +157,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { InsertIntoParquetTable(relation, planLater(child), overwrite=true)(sparkContext) :: Nil case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) => InsertIntoParquetTable(table, planLater(child), overwrite)(sparkContext) :: Nil - case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => { + case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => val prunePushedDownFilters = if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) { (filters: Seq[Expression]) => { @@ -186,7 +186,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { filters, prunePushedDownFilters, ParquetTableScan(_, relation, filters)(sparkContext)) :: Nil - } case _ => Nil } @@ -250,12 +249,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case class CommandStrategy(context: SQLContext) extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.SetCommand(key, value) => - Seq(execution.SetCommandPhysical(key, value, plan.output)(context)) + Seq(execution.SetCommand(key, value, plan.output)(context)) case logical.ExplainCommand(child) => val executedPlan = context.executePlan(child).executedPlan - Seq(execution.ExplainCommandPhysical(executedPlan, plan.output)(context)) + Seq(execution.ExplainCommand(executedPlan, plan.output)(context)) case logical.CacheCommand(tableName, cache) => - Seq(execution.CacheCommandPhysical(tableName, cache)(context)) + Seq(execution.CacheCommand(tableName, cache)(context)) case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index be26d19e66862..0377290af5926 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -22,45 +22,69 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute} +trait Command { + /** + * A concrete command should override this lazy field to wrap up any side effects caused by the + * command or any other computation that should be evaluated exactly once. The value of this field + * can be used as the contents of the corresponding RDD generated from the physical plan of this + * command. + * + * The `execute()` method of all the physical command classes should reference `sideEffectResult` + * so that the command can be executed eagerly right after the command query is created. + */ + protected[sql] lazy val sideEffectResult: Seq[Any] = Seq.empty[Any] +} + /** * :: DeveloperApi :: */ @DeveloperApi -case class SetCommandPhysical(key: Option[String], value: Option[String], output: Seq[Attribute]) - (@transient context: SQLContext) extends LeafNode { - def execute(): RDD[Row] = (key, value) match { - // Set value for key k; the action itself would - // have been performed in QueryExecution eagerly. - case (Some(k), Some(v)) => context.emptyResult +case class SetCommand( + key: Option[String], value: Option[String], output: Seq[Attribute])( + @transient context: SQLContext) + extends LeafNode with Command { + + override protected[sql] lazy val sideEffectResult: Seq[(String, String)] = (key, value) match { + // Set value for key k. + case (Some(k), Some(v)) => + context.set(k, v) + Array(k -> v) + // Query the value bound to key k. - case (Some(k), None) => - val resultString = context.getOption(k) match { - case Some(v) => s"$k=$v" - case None => s"$k is undefined" - } - context.sparkContext.parallelize(Seq(new GenericRow(Array[Any](resultString))), 1) + case (Some(k), _) => + Array(k -> context.getOption(k).getOrElse("")) + // Query all key-value pairs that are set in the SQLConf of the context. case (None, None) => - val pairs = context.getAll - val rows = pairs.map { case (k, v) => - new GenericRow(Array[Any](s"$k=$v")) - }.toSeq - // Assume config parameters can fit into one split (machine) ;) - context.sparkContext.parallelize(rows, 1) - // The only other case is invalid semantics and is impossible. - case _ => context.emptyResult + context.getAll + + case _ => + throw new IllegalArgumentException() } + + def execute(): RDD[Row] = { + val rows = sideEffectResult.map { case (k, v) => new GenericRow(Array[Any](k, v)) } + context.sparkContext.parallelize(rows, 1) + } + + override def otherCopyArgs = context :: Nil } /** * :: DeveloperApi :: */ @DeveloperApi -case class ExplainCommandPhysical(child: SparkPlan, output: Seq[Attribute]) - (@transient context: SQLContext) extends UnaryNode { +case class ExplainCommand( + child: SparkPlan, output: Seq[Attribute])( + @transient context: SQLContext) + extends UnaryNode with Command { + + // Actually "EXPLAIN" command doesn't cause any side effect. + override protected[sql] lazy val sideEffectResult: Seq[String] = this.toString.split("\n") + def execute(): RDD[Row] = { - val planString = new GenericRow(Array[Any](child.toString)) - context.sparkContext.parallelize(Seq(planString)) + val explanation = sideEffectResult.mkString("\n") + context.sparkContext.parallelize(Seq(new GenericRow(Array[Any](explanation))), 1) } override def otherCopyArgs = context :: Nil @@ -70,19 +94,20 @@ case class ExplainCommandPhysical(child: SparkPlan, output: Seq[Attribute]) * :: DeveloperApi :: */ @DeveloperApi -case class CacheCommandPhysical(tableName: String, doCache: Boolean)(@transient context: SQLContext) - extends LeafNode { +case class CacheCommand(tableName: String, doCache: Boolean)(@transient context: SQLContext) + extends LeafNode with Command { - lazy val commandSideEffect = { + override protected[sql] lazy val sideEffectResult = { if (doCache) { context.cacheTable(tableName) } else { context.uncacheTable(tableName) } + Seq.empty[Any] } override def execute(): RDD[Row] = { - commandSideEffect + sideEffectResult context.emptyResult } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c1fc99f077431..e9360b0fc7910 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -141,7 +141,7 @@ class SQLQuerySuite extends QueryTest { sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"), Seq((2147483645.0,1),(2.0,2))) } - + test("count") { checkAnswer( sql("SELECT COUNT(*) FROM testData2"), @@ -332,7 +332,7 @@ class SQLQuerySuite extends QueryTest { (3, "C"), (4, "D"))) } - + test("system function upper()") { checkAnswer( sql("SELECT n,UPPER(l) FROM lowerCaseData"), @@ -349,7 +349,7 @@ class SQLQuerySuite extends QueryTest { (2, "ABC"), (3, null))) } - + test("system function lower()") { checkAnswer( sql("SELECT N,LOWER(L) FROM upperCaseData"), @@ -382,25 +382,25 @@ class SQLQuerySuite extends QueryTest { sql(s"SET $testKey=$testVal") checkAnswer( sql("SET"), - Seq(Seq(s"$testKey=$testVal")) + Seq(Seq(testKey, testVal)) ) sql(s"SET ${testKey + testKey}=${testVal + testVal}") checkAnswer( sql("set"), Seq( - Seq(s"$testKey=$testVal"), - Seq(s"${testKey + testKey}=${testVal + testVal}")) + Seq(testKey, testVal), + Seq(testKey + testKey, testVal + testVal)) ) // "set key" checkAnswer( sql(s"SET $testKey"), - Seq(Seq(s"$testKey=$testVal")) + Seq(Seq(testKey, testVal)) ) checkAnswer( sql(s"SET $nonexistentKey"), - Seq(Seq(s"$nonexistentKey is undefined")) + Seq(Seq(nonexistentKey, "")) ) clear() } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 9cd13f6ae0d57..96e0ec5136331 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -15,8 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql -package hive +package org.apache.spark.sql.hive import java.io.{BufferedReader, File, InputStreamReader, PrintStream} import java.util.{ArrayList => JArrayList} @@ -32,12 +31,13 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Analyzer, OverrideCatalog} -import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.QueryExecutionException +import org.apache.spark.sql.execution.{Command => PhysicalCommand} /** * Starts up an instance of hive where metadata is stored locally. An in-process metadata data is @@ -71,14 +71,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /** * Executes a query expressed in HiveQL using Spark, returning the result as a SchemaRDD. */ - def hiveql(hqlQuery: String): SchemaRDD = { - val result = new SchemaRDD(this, HiveQl.parseSql(hqlQuery)) - // We force query optimization to happen right away instead of letting it happen lazily like - // when using the query DSL. This is so DDL commands behave as expected. This is only - // generates the RDD lineage for DML queries, but does not perform any execution. - result.queryExecution.toRdd - result - } + def hiveql(hqlQuery: String): SchemaRDD = new SchemaRDD(this, HiveQl.parseSql(hqlQuery)) /** An alias for `hiveql`. */ def hql(hqlQuery: String): SchemaRDD = hiveql(hqlQuery) @@ -164,7 +157,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /** * Runs the specified SQL query using Hive. */ - protected def runSqlHive(sql: String): Seq[String] = { + protected[sql] def runSqlHive(sql: String): Seq[String] = { val maxResults = 100000 val results = runHive(sql, 100000) // It is very confusing when you only get back some of the results... @@ -228,6 +221,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override val strategies: Seq[Strategy] = Seq( CommandStrategy(self), + HiveCommandStrategy(self), TakeOrdered, ParquetOperations, InMemoryScans, @@ -252,25 +246,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override lazy val optimizedPlan = optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed))) - override lazy val toRdd: RDD[Row] = { - def processCmd(cmd: String): RDD[Row] = { - val output = runSqlHive(cmd) - if (output.size == 0) { - emptyResult - } else { - val asRows = output.map(r => new GenericRow(r.split("\t").asInstanceOf[Array[Any]])) - sparkContext.parallelize(asRows, 1) - } - } - - logical match { - case s: SetCommand => eagerlyProcess(s) - case _ => analyzed match { - case NativeCommand(cmd) => processCmd(cmd) - case _ => executedPlan.execute().map(_.copy()) - } - } - } + override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy()) protected val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, @@ -298,7 +274,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { struct.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ))=> + case (seq: Seq[_], ArrayType(typ)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") case (map: Map[_,_], MapType(kType, vType)) => map.map { @@ -314,10 +290,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * Returns the result as a hive compatible sequence of strings. For native commands, the * execution is simply passed back to Hive. */ - def stringResult(): Seq[String] = analyzed match { - case NativeCommand(cmd) => runSqlHive(cmd) - case ExplainCommand(plan) => executePlan(plan).toString.split("\n") - case query => + def stringResult(): Seq[String] = executedPlan match { + case command: PhysicalCommand => + command.sideEffectResult.map(_.toString) + + case other => val result: Seq[Seq[Any]] = toRdd.collect().toSeq // We need the types so we can output struct field names val types = analyzed.output.map(_.dataType) @@ -328,8 +305,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override def simpleString: String = logical match { - case _: NativeCommand => "" - case _: SetCommand => "" + case _: NativeCommand => "" + case _: SetCommand => "" case _ => executedPlan.toString } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index ed6cd5a11dba4..0ac0ee9071f36 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -76,4 +76,12 @@ private[hive] trait HiveStrategies { Nil } } + + case class HiveCommandStrategy(context: HiveContext) extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.NativeCommand(sql) => + NativeCommand(sql, plan.output)(context) :: Nil + case _ => Nil + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala index 29b4b9b006e45..a839231449161 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala @@ -32,14 +32,15 @@ import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Serializer} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred._ +import org.apache.spark import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{BooleanType, DataType} import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive._ -import org.apache.spark.{TaskContext, SparkException} import org.apache.spark.util.MutablePair +import org.apache.spark.{TaskContext, SparkException} /* Implicits */ import scala.collection.JavaConversions._ @@ -57,7 +58,7 @@ case class HiveTableScan( attributes: Seq[Attribute], relation: MetastoreRelation, partitionPruningPred: Option[Expression])( - @transient val sc: HiveContext) + @transient val context: HiveContext) extends LeafNode with HiveInspectors { @@ -75,7 +76,7 @@ case class HiveTableScan( } @transient - val hadoopReader = new HadoopTableReader(relation.tableDesc, sc) + val hadoopReader = new HadoopTableReader(relation.tableDesc, context) /** * The hive object inspector for this table, which can be used to extract values from the @@ -156,7 +157,7 @@ case class HiveTableScan( hiveConf.set(serdeConstants.LIST_COLUMNS, columnInternalNames) } - addColumnMetadataToConf(sc.hiveconf) + addColumnMetadataToConf(context.hiveconf) @transient def inputRdd = if (!relation.hiveQlTable.isPartitioned) { @@ -428,3 +429,26 @@ case class InsertIntoHiveTable( sc.sparkContext.makeRDD(Nil, 1) } } + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class NativeCommand( + sql: String, output: Seq[Attribute])( + @transient context: HiveContext) + extends LeafNode with Command { + + override protected[sql] lazy val sideEffectResult: Seq[String] = context.runSqlHive(sql) + + override def execute(): RDD[spark.sql.Row] = { + if (sideEffectResult.size == 0) { + context.emptyResult + } else { + val rows = sideEffectResult.map(r => new GenericRow(Array[Any](r))) + context.sparkContext.parallelize(rows, 1) + } + } + + override def otherCopyArgs = context :: Nil +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 357c7e654bd20..24c929ff7430d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -24,6 +24,7 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen} import org.apache.spark.sql.Logging import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.{NativeCommand => LogicalNativeCommand} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.hive.test.TestHive @@ -141,7 +142,7 @@ abstract class HiveComparisonTest // Hack: Hive simply prints the result of a SET command to screen, // and does not return it as a query answer. case _: SetCommand => Seq("0") - case _: NativeCommand => answer.filterNot(nonDeterministicLine).filterNot(_ == "") + case _: LogicalNativeCommand => answer.filterNot(nonDeterministicLine).filterNot(_ == "") case _: ExplainCommand => answer case plan => if (isSorted(plan)) answer else answer.sorted } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 3581617c269a6..ee194dbcb77b2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -172,7 +172,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "case_sensitivity", // Flaky test, Hive sometimes returns different set of 10 rows. - "lateral_view_outer" + "lateral_view_outer", + + // After stop taking the `stringOrError` route, exceptions are thrown from these cases. + // See SPARK-2129 for details. + "join_view", + "mergejoins_mixed" ) /** @@ -476,7 +481,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "join_reorder3", "join_reorder4", "join_star", - "join_view", "lateral_view", "lateral_view_cp", "lateral_view_ppd", @@ -507,7 +511,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "merge1", "merge2", "mergejoins", - "mergejoins_mixed", "multigroupby_singlemr", "multi_insert_gby", "multi_insert_gby3", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 6c239b02ed09a..0d656c556965d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.Row -import org.apache.spark.sql.hive.test.TestHive._ +import scala.util.Try + import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.{SchemaRDD, execution, Row} /** * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. @@ -162,16 +164,60 @@ class HiveQuerySuite extends HiveComparisonTest { hql("SELECT * FROM src").toString } + private val explainCommandClassName = + classOf[execution.ExplainCommand].getSimpleName.stripSuffix("$") + + def isExplanation(result: SchemaRDD) = { + val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } + explanation.size == 1 && explanation.head.startsWith(explainCommandClassName) + } + test("SPARK-1704: Explain commands as a SchemaRDD") { hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + val rdd = hql("explain select key, count(value) from src group by key") - assert(rdd.collect().size == 1) - assert(rdd.toString.contains("ExplainCommand")) - assert(rdd.filter(row => row.toString.contains("ExplainCommand")).collect().size == 0, - "actual contents of the result should be the plans of the query to be explained") + assert(isExplanation(rdd)) + TestHive.reset() } + test("Query Hive native command execution result") { + val tableName = "test_native_commands" + + val q0 = hql(s"DROP TABLE IF EXISTS $tableName") + assert(q0.count() == 0) + + val q1 = hql(s"CREATE TABLE $tableName(key INT, value STRING)") + assert(q1.count() == 0) + + val q2 = hql("SHOW TABLES") + val tables = q2.select('result).collect().map { case Row(table: String) => table } + assert(tables.contains(tableName)) + + val q3 = hql(s"DESCRIBE $tableName") + assertResult(Array(Array("key", "int", "None"), Array("value", "string", "None"))) { + q3.select('result).collect().map { case Row(fieldDesc: String) => + fieldDesc.split("\t").map(_.trim) + } + } + + val q4 = hql(s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key") + assert(isExplanation(q4)) + + TestHive.reset() + } + + test("Exactly once semantics for DDL and command statements") { + val tableName = "test_exactly_once" + val q0 = hql(s"CREATE TABLE $tableName(key INT, value STRING)") + + // If the table was not created, the following assertion would fail + assert(Try(table(tableName)).isSuccess) + + // If the CREATE TABLE command got executed again, the following assertion would fail + assert(Try(q0.count()).isSuccess) + } + test("parse HQL set commands") { // Adapted from its SQL counterpart. val testKey = "spark.sql.key.usedfortestonly" @@ -195,52 +241,69 @@ class HiveQuerySuite extends HiveComparisonTest { test("SET commands semantics for a HiveContext") { // Adapted from its SQL counterpart. val testKey = "spark.sql.key.usedfortestonly" - var testVal = "test.val.0" + val testVal = "test.val.0" val nonexistentKey = "nonexistent" - def fromRows(row: Array[Row]): Array[String] = row.map(_.getString(0)) + def rowsToPairs(rows: Array[Row]) = rows.map { case Row(key: String, value: String) => + key -> value + } clear() // "set" itself returns all config variables currently specified in SQLConf. - assert(hql("set").collect().size == 0) + assert(hql("SET").collect().size == 0) + + assertResult(Array(testKey -> testVal)) { + rowsToPairs(hql(s"SET $testKey=$testVal").collect()) + } - // "set key=val" - hql(s"SET $testKey=$testVal") - assert(fromRows(hql("SET").collect()) sameElements Array(s"$testKey=$testVal")) assert(hiveconf.get(testKey, "") == testVal) + assertResult(Array(testKey -> testVal)) { + rowsToPairs(hql("SET").collect()) + } hql(s"SET ${testKey + testKey}=${testVal + testVal}") - assert(fromRows(hql("SET").collect()) sameElements - Array( - s"$testKey=$testVal", - s"${testKey + testKey}=${testVal + testVal}")) assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) + assertResult(Array(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { + rowsToPairs(hql("SET").collect()) + } // "set key" - assert(fromRows(hql(s"SET $testKey").collect()) sameElements - Array(s"$testKey=$testVal")) - assert(fromRows(hql(s"SET $nonexistentKey").collect()) sameElements - Array(s"$nonexistentKey is undefined")) + assertResult(Array(testKey -> testVal)) { + rowsToPairs(hql(s"SET $testKey").collect()) + } + + assertResult(Array(nonexistentKey -> "")) { + rowsToPairs(hql(s"SET $nonexistentKey").collect()) + } // Assert that sql() should have the same effects as hql() by repeating the above using sql(). clear() - assert(sql("set").collect().size == 0) + assert(sql("SET").collect().size == 0) + + assertResult(Array(testKey -> testVal)) { + rowsToPairs(sql(s"SET $testKey=$testVal").collect()) + } - sql(s"SET $testKey=$testVal") - assert(fromRows(sql("SET").collect()) sameElements Array(s"$testKey=$testVal")) assert(hiveconf.get(testKey, "") == testVal) + assertResult(Array(testKey -> testVal)) { + rowsToPairs(sql("SET").collect()) + } sql(s"SET ${testKey + testKey}=${testVal + testVal}") - assert(fromRows(sql("SET").collect()) sameElements - Array( - s"$testKey=$testVal", - s"${testKey + testKey}=${testVal + testVal}")) assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) + assertResult(Array(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { + rowsToPairs(sql("SET").collect()) + } - assert(fromRows(sql(s"SET $testKey").collect()) sameElements - Array(s"$testKey=$testVal")) - assert(fromRows(sql(s"SET $nonexistentKey").collect()) sameElements - Array(s"$nonexistentKey is undefined")) + assertResult(Array(testKey -> testVal)) { + rowsToPairs(sql(s"SET $testKey").collect()) + } + + assertResult(Array(nonexistentKey -> "")) { + rowsToPairs(sql(s"SET $nonexistentKey").collect()) + } + + clear() } // Put tests that depend on specific Hive settings before these last two test,