From 05bf4e4aff0d052a53d3e64c43688f07e27fec50 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 4 Aug 2014 20:39:18 -0700 Subject: [PATCH 01/13] [SPARK-2323] Exception in accumulator update should not crash DAGScheduler & SparkContext Author: Reynold Xin Closes #1772 from rxin/accumulator-dagscheduler and squashes the following commits: 6a58520 [Reynold Xin] [SPARK-2323] Exception in accumulator update should not crash DAGScheduler & SparkContext. --- .../org/apache/spark/scheduler/DAGScheduler.scala | 9 +++++++-- .../apache/spark/scheduler/DAGSchedulerSuite.scala | 11 +++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index d87c3048985fc..9fa3a4e9c71ae 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -904,8 +904,13 @@ class DAGScheduler( event.reason match { case Success => if (event.accumUpdates != null) { - // TODO: fail the stage if the accumulator update fails... - Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted + try { + Accumulators.add(event.accumUpdates) + } catch { + // If we see an exception during accumulator update, just log the error and move on. + case e: Exception => + logError(s"Failed to update accumulators for $task", e) + } } stage.pendingTasks -= task task match { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 36e238b4c9434..8c1b0fed11f72 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -622,8 +622,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F assertDataStructuresEmpty } - // TODO: Fix this and un-ignore the test. - ignore("misbehaved accumulator should not crash DAGScheduler and SparkContext") { + test("misbehaved accumulator should not crash DAGScheduler and SparkContext") { val acc = new Accumulator[Int](0, new AccumulatorParam[Int] { override def addAccumulator(t1: Int, t2: Int): Int = t1 + t2 override def zero(initialValue: Int): Int = 0 @@ -633,14 +632,10 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F }) // Run this on executors - intercept[SparkDriverExecutionException] { - sc.parallelize(1 to 10, 2).foreach { item => acc.add(1) } - } + sc.parallelize(1 to 10, 2).foreach { item => acc.add(1) } // Run this within a local thread - intercept[SparkDriverExecutionException] { - sc.parallelize(1 to 10, 2).map { item => acc.add(1) }.take(1) - } + sc.parallelize(1 to 10, 2).map { item => acc.add(1) }.take(1) // Make sure we can still run local commands as well as cluster commands. assert(sc.parallelize(1 to 10, 2).count() === 10) From 066765d60d21b6b9943862b788e4a4bd07396e6c Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 4 Aug 2014 23:27:53 -0700 Subject: [PATCH 02/13] SPARK-2685. Update ExternalAppendOnlyMap to avoid buffer.remove() Replaces this with an O(1) operation that does not have to shift over the whole tail of the array into the gap produced by the element removed. Author: Matei Zaharia Closes #1773 from mateiz/SPARK-2685 and squashes the following commits: 1ea028a [Matei Zaharia] Update comments in StreamBuffer and EAOM, and reuse ArrayBuffers eb1abfd [Matei Zaharia] Update ExternalAppendOnlyMap to avoid buffer.remove() --- .../collection/ExternalAppendOnlyMap.scala | 50 +++++++++++++------ 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 5d10a1f84493c..1f7d2dc838ebc 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -286,30 +286,32 @@ class ExternalAppendOnlyMap[K, V, C]( private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered) inputStreams.foreach { it => - val kcPairs = getMorePairs(it) + val kcPairs = new ArrayBuffer[(K, C)] + readNextHashCode(it, kcPairs) if (kcPairs.length > 0) { mergeHeap.enqueue(new StreamBuffer(it, kcPairs)) } } /** - * Fetch from the given iterator until a key of different hash is retrieved. + * Fill a buffer with the next set of keys with the same hash code from a given iterator. We + * read streams one hash code at a time to ensure we don't miss elements when they are merged. + * + * Assumes the given iterator is in sorted order of hash code. * - * In the event of key hash collisions, this ensures no pairs are hidden from being merged. - * Assume the given iterator is in sorted order. + * @param it iterator to read from + * @param buf buffer to write the results into */ - private def getMorePairs(it: BufferedIterator[(K, C)]): ArrayBuffer[(K, C)] = { - val kcPairs = new ArrayBuffer[(K, C)] + private def readNextHashCode(it: BufferedIterator[(K, C)], buf: ArrayBuffer[(K, C)]): Unit = { if (it.hasNext) { var kc = it.next() - kcPairs += kc + buf += kc val minHash = hashKey(kc) while (it.hasNext && it.head._1.hashCode() == minHash) { kc = it.next() - kcPairs += kc + buf += kc } } - kcPairs } /** @@ -321,7 +323,9 @@ class ExternalAppendOnlyMap[K, V, C]( while (i < buffer.pairs.length) { val pair = buffer.pairs(i) if (pair._1 == key) { - buffer.pairs.remove(i) + // Note that there's at most one pair in the buffer with a given key, since we always + // merge stuff in a map before spilling, so it's safe to return after the first we find + removeFromBuffer(buffer.pairs, i) return mergeCombiners(baseCombiner, pair._2) } i += 1 @@ -329,6 +333,19 @@ class ExternalAppendOnlyMap[K, V, C]( baseCombiner } + /** + * Remove the index'th element from an ArrayBuffer in constant time, swapping another element + * into its place. This is more efficient than the ArrayBuffer.remove method because it does + * not have to shift all the elements in the array over. It works for our array buffers because + * we don't care about the order of elements inside, we just want to search them for a key. + */ + private def removeFromBuffer[T](buffer: ArrayBuffer[T], index: Int): T = { + val elem = buffer(index) + buffer(index) = buffer(buffer.size - 1) // This also works if index == buffer.size - 1 + buffer.reduceToSize(buffer.size - 1) + elem + } + /** * Return true if there exists an input stream that still has unvisited pairs. */ @@ -346,7 +363,7 @@ class ExternalAppendOnlyMap[K, V, C]( val minBuffer = mergeHeap.dequeue() val minPairs = minBuffer.pairs val minHash = minBuffer.minKeyHash - val minPair = minPairs.remove(0) + val minPair = removeFromBuffer(minPairs, 0) val minKey = minPair._1 var minCombiner = minPair._2 assert(hashKey(minPair) == minHash) @@ -363,7 +380,7 @@ class ExternalAppendOnlyMap[K, V, C]( // Repopulate each visited stream buffer and add it back to the queue if it is non-empty mergedBuffers.foreach { buffer => if (buffer.isEmpty) { - buffer.pairs ++= getMorePairs(buffer.iterator) + readNextHashCode(buffer.iterator, buffer.pairs) } if (!buffer.isEmpty) { mergeHeap.enqueue(buffer) @@ -375,10 +392,13 @@ class ExternalAppendOnlyMap[K, V, C]( /** * A buffer for streaming from a map iterator (in-memory or on-disk) sorted by key hash. - * Each buffer maintains the lowest-ordered keys in the corresponding iterator. Due to - * hash collisions, it is possible for multiple keys to be "tied" for being the lowest. + * Each buffer maintains all of the key-value pairs with what is currently the lowest hash + * code among keys in the stream. There may be multiple keys if there are hash collisions. + * Note that because when we spill data out, we only spill one value for each key, there is + * at most one element for each key. * - * StreamBuffers are ordered by the minimum key hash found across all of their own pairs. + * StreamBuffers are ordered by the minimum key hash currently available in their stream so + * that we can put them into a heap and sort that. */ private class StreamBuffer( val iterator: BufferedIterator[(K, C)], From 4fde28c2063f673ec7f51d514ba62a73321960a1 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 4 Aug 2014 23:41:03 -0700 Subject: [PATCH 03/13] SPARK-2711. Create a ShuffleMemoryManager to track memory for all spilling collections This tracks memory properly if there are multiple spilling collections in the same task (which was a problem before), and also implements an algorithm that lets each thread grow up to 1 / 2N of the memory pool (where N is the number of threads) before spilling, which avoids an inefficiency with small spills we had before (some threads would spill many times at 0-1 MB because the pool was allocated elsewhere). Author: Matei Zaharia Closes #1707 from mateiz/spark-2711 and squashes the following commits: debf75b [Matei Zaharia] Review comments 24f28f3 [Matei Zaharia] Small rename c8f3a8b [Matei Zaharia] Update ShuffleMemoryManager to be able to partially grant requests 315e3a5 [Matei Zaharia] Some review comments b810120 [Matei Zaharia] Create central manager to track memory for all spilling collections --- .../scala/org/apache/spark/SparkEnv.scala | 10 +- .../org/apache/spark/executor/Executor.scala | 5 +- .../spark/shuffle/ShuffleMemoryManager.scala | 125 ++++++++ .../collection/ExternalAppendOnlyMap.scala | 48 +-- .../util/collection/ExternalSorter.scala | 49 +-- .../shuffle/ShuffleMemoryManagerSuite.scala | 294 ++++++++++++++++++ 6 files changed, 450 insertions(+), 81 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala create mode 100644 core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 0bce531aaba3e..dd8e4ac66dc66 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -35,7 +35,7 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.ConnectionManager import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.ShuffleManager +import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ import org.apache.spark.util.{AkkaUtils, Utils} @@ -66,12 +66,9 @@ class SparkEnv ( val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, + val shuffleMemoryManager: ShuffleMemoryManager, val conf: SparkConf) extends Logging { - // A mapping of thread ID to amount of memory, in bytes, used for shuffle aggregations - // All accesses should be manually synchronized - val shuffleMemoryMap = mutable.HashMap[Long, Long]() - private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() // A general, soft-reference map for metadata needed during HadoopRDD split computation @@ -252,6 +249,8 @@ object SparkEnv extends Logging { val shuffleManager = instantiateClass[ShuffleManager]( "spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") + val shuffleMemoryManager = new ShuffleMemoryManager(conf) + // Warn about deprecated spark.cache.class property if (conf.contains("spark.cache.class")) { logWarning("The spark.cache.class property is no longer being used! Specify storage " + @@ -273,6 +272,7 @@ object SparkEnv extends Logging { httpFileServer, sparkFilesDir, metricsSystem, + shuffleMemoryManager, conf) } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 1bb1b4aae91bb..c2b9c660ddaec 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -276,10 +276,7 @@ private[spark] class Executor( } } finally { // Release memory used by this thread for shuffles - val shuffleMemoryMap = env.shuffleMemoryMap - shuffleMemoryMap.synchronized { - shuffleMemoryMap.remove(Thread.currentThread().getId) - } + env.shuffleMemoryManager.releaseMemoryForThisThread() // Release memory used by this thread for unrolling blocks env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() runningTasks.remove(taskId) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala new file mode 100644 index 0000000000000..ee91a368b76ea --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import scala.collection.mutable + +import org.apache.spark.{Logging, SparkException, SparkConf} + +/** + * Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling + * collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory + * from this pool and release it as it spills data out. When a task ends, all its memory will be + * released by the Executor. + * + * This class tries to ensure that each thread gets a reasonable share of memory, instead of some + * thread ramping up to a large amount first and then causing others to spill to disk repeatedly. + * If there are N threads, it ensures that each thread can acquire at least 1 / 2N of the memory + * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the + * set of active threads and redo the calculations of 1 / 2N and 1 / N in waiting threads whenever + * this set changes. This is all done by synchronizing access on "this" to mutate state and using + * wait() and notifyAll() to signal changes. + */ +private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { + private val threadMemory = new mutable.HashMap[Long, Long]() // threadId -> memory bytes + + def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf)) + + /** + * Try to acquire up to numBytes memory for the current thread, and return the number of bytes + * obtained, or 0 if none can be allocated. This call may block until there is enough free memory + * in some situations, to make sure each thread has a chance to ramp up to at least 1 / 2N of the + * total memory pool (where N is the # of active threads) before it is forced to spill. This can + * happen if the number of threads increases but an older thread had a lot of memory already. + */ + def tryToAcquire(numBytes: Long): Long = synchronized { + val threadId = Thread.currentThread().getId + assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) + + // Add this thread to the threadMemory map just so we can keep an accurate count of the number + // of active threads, to let other threads ramp down their memory in calls to tryToAcquire + if (!threadMemory.contains(threadId)) { + threadMemory(threadId) = 0L + notifyAll() // Will later cause waiting threads to wake up and check numThreads again + } + + // Keep looping until we're either sure that we don't want to grant this request (because this + // thread would have more than 1 / numActiveThreads of the memory) or we have enough free + // memory to give it (we always let each thread get at least 1 / (2 * numActiveThreads)). + while (true) { + val numActiveThreads = threadMemory.keys.size + val curMem = threadMemory(threadId) + val freeMemory = maxMemory - threadMemory.values.sum + + // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads + val maxToGrant = math.min(numBytes, (maxMemory / numActiveThreads) - curMem) + + if (curMem < maxMemory / (2 * numActiveThreads)) { + // We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking; + // if we can't give it this much now, wait for other threads to free up memory + // (this happens if older threads allocated lots of memory before N grew) + if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) { + val toGrant = math.min(maxToGrant, freeMemory) + threadMemory(threadId) += toGrant + return toGrant + } else { + logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free") + wait() + } + } else { + // Only give it as much memory as is free, which might be none if it reached 1 / numThreads + val toGrant = math.min(maxToGrant, freeMemory) + threadMemory(threadId) += toGrant + return toGrant + } + } + 0L // Never reached + } + + /** Release numBytes bytes for the current thread. */ + def release(numBytes: Long): Unit = synchronized { + val threadId = Thread.currentThread().getId + val curMem = threadMemory.getOrElse(threadId, 0L) + if (curMem < numBytes) { + throw new SparkException( + s"Internal error: release called on ${numBytes} bytes but thread only has ${curMem}") + } + threadMemory(threadId) -= numBytes + notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed + } + + /** Release all memory for the current thread and mark it as inactive (e.g. when a task ends). */ + def releaseMemoryForThisThread(): Unit = synchronized { + val threadId = Thread.currentThread().getId + threadMemory.remove(threadId) + notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed + } +} + +private object ShuffleMemoryManager { + /** + * Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction + * of the memory pool and a safety factor since collections can sometimes grow bigger than + * the size we target before we estimate their sizes again. + */ + def getMaxMemory(conf: SparkConf): Long = { + val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) + val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8) + (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 1f7d2dc838ebc..cc0423856cefb 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -71,13 +71,7 @@ class ExternalAppendOnlyMap[K, V, C]( private val spilledMaps = new ArrayBuffer[DiskMapIterator] private val sparkConf = SparkEnv.get.conf private val diskBlockManager = blockManager.diskBlockManager - - // Collective memory threshold shared across all running tasks - private val maxMemoryThreshold = { - val memoryFraction = sparkConf.getDouble("spark.shuffle.memoryFraction", 0.2) - val safetyFraction = sparkConf.getDouble("spark.shuffle.safetyFraction", 0.8) - (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong - } + private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager // Number of pairs inserted since last spill; note that we count them even if a value is merged // with a previous key in case we're doing something like groupBy where the result grows @@ -140,28 +134,15 @@ class ExternalAppendOnlyMap[K, V, C]( if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && currentMap.estimateSize() >= myMemoryThreshold) { - val currentSize = currentMap.estimateSize() - var shouldSpill = false - val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap - - // Atomically check whether there is sufficient memory in the global pool for - // this map to grow and, if possible, allocate the required amount - shuffleMemoryMap.synchronized { - val threadId = Thread.currentThread().getId - val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId) - val availableMemory = maxMemoryThreshold - - (shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L)) - - // Try to allocate at least 2x more memory, otherwise spill - shouldSpill = availableMemory < currentSize * 2 - if (!shouldSpill) { - shuffleMemoryMap(threadId) = currentSize * 2 - myMemoryThreshold = currentSize * 2 - } - } - // Do not synchronize spills - if (shouldSpill) { - spill(currentSize) + // Claim up to double our current memory from the shuffle memory pool + val currentMemory = currentMap.estimateSize() + val amountToRequest = 2 * currentMemory - myMemoryThreshold + val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) + myMemoryThreshold += granted + if (myMemoryThreshold <= currentMemory) { + // We were granted too little memory to grow further (either tryToAcquire returned 0, + // or we already had more memory than myMemoryThreshold); spill the current collection + spill(currentMemory) // Will also release memory back to ShuffleMemoryManager } } currentMap.changeValue(curEntry._1, update) @@ -245,12 +226,9 @@ class ExternalAppendOnlyMap[K, V, C]( currentMap = new SizeTrackingAppendOnlyMap[K, C] spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes)) - // Reset the amount of shuffle memory used by this map in the global pool - val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap - shuffleMemoryMap.synchronized { - shuffleMemoryMap(Thread.currentThread().getId) = 0 - } - myMemoryThreshold = 0 + // Release our memory back to the shuffle pool so that other threads can grab it + shuffleMemoryManager.release(myMemoryThreshold) + myMemoryThreshold = 0L elementsRead = 0 _memoryBytesSpilled += mapSize diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index b04c50bd3e196..101c83b264f63 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -78,6 +78,7 @@ private[spark] class ExternalSorter[K, V, C]( private val blockManager = SparkEnv.get.blockManager private val diskBlockManager = blockManager.diskBlockManager + private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager private val ser = Serializer.getSerializer(serializer) private val serInstance = ser.newInstance() @@ -116,13 +117,6 @@ private[spark] class ExternalSorter[K, V, C]( private var _memoryBytesSpilled = 0L private var _diskBytesSpilled = 0L - // Collective memory threshold shared across all running tasks - private val maxMemoryThreshold = { - val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) - val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8) - (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong - } - // How much of the shared memory pool this collection has claimed private var myMemoryThreshold = 0L @@ -218,31 +212,15 @@ private[spark] class ExternalSorter[K, V, C]( if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && collection.estimateSize() >= myMemoryThreshold) { - // TODO: This logic doesn't work if there are two external collections being used in the same - // task (e.g. to read shuffle output and write it out into another shuffle) [SPARK-2711] - - val currentSize = collection.estimateSize() - var shouldSpill = false - val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap - - // Atomically check whether there is sufficient memory in the global pool for - // us to double our threshold - shuffleMemoryMap.synchronized { - val threadId = Thread.currentThread().getId - val previouslyClaimedMemory = shuffleMemoryMap.get(threadId) - val availableMemory = maxMemoryThreshold - - (shuffleMemoryMap.values.sum - previouslyClaimedMemory.getOrElse(0L)) - - // Try to allocate at least 2x more memory, otherwise spill - shouldSpill = availableMemory < currentSize * 2 - if (!shouldSpill) { - shuffleMemoryMap(threadId) = currentSize * 2 - myMemoryThreshold = currentSize * 2 - } - } - // Do not hold lock during spills - if (shouldSpill) { - spill(currentSize, usingMap) + // Claim up to double our current memory from the shuffle memory pool + val currentMemory = collection.estimateSize() + val amountToRequest = 2 * currentMemory - myMemoryThreshold + val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) + myMemoryThreshold += granted + if (myMemoryThreshold <= currentMemory) { + // We were granted too little memory to grow further (either tryToAcquire returned 0, + // or we already had more memory than myMemoryThreshold); spill the current collection + spill(currentMemory, usingMap) // Will also release memory back to ShuffleMemoryManager } } } @@ -327,11 +305,8 @@ private[spark] class ExternalSorter[K, V, C]( buffer = new SizeTrackingPairBuffer[(Int, K), C] } - // Reset the amount of shuffle memory used by this map in the global pool - val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap - shuffleMemoryMap.synchronized { - shuffleMemoryMap(Thread.currentThread().getId) = 0 - } + // Release our memory back to the shuffle pool so that other threads can grab it + shuffleMemoryManager.release(myMemoryThreshold) myMemoryThreshold = 0 spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)) diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala new file mode 100644 index 0000000000000..d31bc22ee74f7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.CountDownLatch + +class ShuffleMemoryManagerSuite extends FunSuite with Timeouts { + /** Launch a thread with the given body block and return it. */ + private def startThread(name: String)(body: => Unit): Thread = { + val thread = new Thread("ShuffleMemorySuite " + name) { + override def run() { + body + } + } + thread.start() + thread + } + + test("single thread requesting memory") { + val manager = new ShuffleMemoryManager(1000L) + + assert(manager.tryToAcquire(100L) === 100L) + assert(manager.tryToAcquire(400L) === 400L) + assert(manager.tryToAcquire(400L) === 400L) + assert(manager.tryToAcquire(200L) === 100L) + assert(manager.tryToAcquire(100L) === 0L) + assert(manager.tryToAcquire(100L) === 0L) + + manager.release(500L) + assert(manager.tryToAcquire(300L) === 300L) + assert(manager.tryToAcquire(300L) === 200L) + + manager.releaseMemoryForThisThread() + assert(manager.tryToAcquire(1000L) === 1000L) + assert(manager.tryToAcquire(100L) === 0L) + } + + test("two threads requesting full memory") { + // Two threads request 500 bytes first, wait for each other to get it, and then request + // 500 more; we should immediately return 0 as both are now at 1 / N + + val manager = new ShuffleMemoryManager(1000L) + + class State { + var t1Result1 = -1L + var t2Result1 = -1L + var t1Result2 = -1L + var t2Result2 = -1L + } + val state = new State + + val t1 = startThread("t1") { + val r1 = manager.tryToAcquire(500L) + state.synchronized { + state.t1Result1 = r1 + state.notifyAll() + while (state.t2Result1 === -1L) { + state.wait() + } + } + val r2 = manager.tryToAcquire(500L) + state.synchronized { state.t1Result2 = r2 } + } + + val t2 = startThread("t2") { + val r1 = manager.tryToAcquire(500L) + state.synchronized { + state.t2Result1 = r1 + state.notifyAll() + while (state.t1Result1 === -1L) { + state.wait() + } + } + val r2 = manager.tryToAcquire(500L) + state.synchronized { state.t2Result2 = r2 } + } + + failAfter(20 seconds) { + t1.join() + t2.join() + } + + assert(state.t1Result1 === 500L) + assert(state.t2Result1 === 500L) + assert(state.t1Result2 === 0L) + assert(state.t2Result2 === 0L) + } + + + test("threads cannot grow past 1 / N") { + // Two threads request 250 bytes first, wait for each other to get it, and then request + // 500 more; we should only grant 250 bytes to each of them on this second request + + val manager = new ShuffleMemoryManager(1000L) + + class State { + var t1Result1 = -1L + var t2Result1 = -1L + var t1Result2 = -1L + var t2Result2 = -1L + } + val state = new State + + val t1 = startThread("t1") { + val r1 = manager.tryToAcquire(250L) + state.synchronized { + state.t1Result1 = r1 + state.notifyAll() + while (state.t2Result1 === -1L) { + state.wait() + } + } + val r2 = manager.tryToAcquire(500L) + state.synchronized { state.t1Result2 = r2 } + } + + val t2 = startThread("t2") { + val r1 = manager.tryToAcquire(250L) + state.synchronized { + state.t2Result1 = r1 + state.notifyAll() + while (state.t1Result1 === -1L) { + state.wait() + } + } + val r2 = manager.tryToAcquire(500L) + state.synchronized { state.t2Result2 = r2 } + } + + failAfter(20 seconds) { + t1.join() + t2.join() + } + + assert(state.t1Result1 === 250L) + assert(state.t2Result1 === 250L) + assert(state.t1Result2 === 250L) + assert(state.t2Result2 === 250L) + } + + test("threads can block to get at least 1 / 2N memory") { + // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps + // for a bit and releases 250 bytes, which should then be greanted to t2. Further requests + // by t2 will return false right away because it now has 1 / 2N of the memory. + + val manager = new ShuffleMemoryManager(1000L) + + class State { + var t1Requested = false + var t2Requested = false + var t1Result = -1L + var t2Result = -1L + var t2Result2 = -1L + var t2WaitTime = 0L + } + val state = new State + + val t1 = startThread("t1") { + state.synchronized { + state.t1Result = manager.tryToAcquire(1000L) + state.t1Requested = true + state.notifyAll() + while (!state.t2Requested) { + state.wait() + } + } + // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make + // sure the other thread blocks for some time otherwise + Thread.sleep(300) + manager.release(250L) + } + + val t2 = startThread("t2") { + state.synchronized { + while (!state.t1Requested) { + state.wait() + } + state.t2Requested = true + state.notifyAll() + } + val startTime = System.currentTimeMillis() + val result = manager.tryToAcquire(250L) + val endTime = System.currentTimeMillis() + state.synchronized { + state.t2Result = result + // A second call should return 0 because we're now already at 1 / 2N + state.t2Result2 = manager.tryToAcquire(100L) + state.t2WaitTime = endTime - startTime + } + } + + failAfter(20 seconds) { + t1.join() + t2.join() + } + + // Both threads should've been able to acquire their memory; the second one will have waited + // until the first one acquired 1000 bytes and then released 250 + state.synchronized { + assert(state.t1Result === 1000L, "t1 could not allocate memory") + assert(state.t2Result === 250L, "t2 could not allocate memory") + assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})") + assert(state.t2Result2 === 0L, "t1 got extra memory the second time") + } + } + + test("releaseMemoryForThisThread") { + // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps + // for a bit and releases all its memory. t2 should now be able to grab all the memory. + + val manager = new ShuffleMemoryManager(1000L) + + class State { + var t1Requested = false + var t2Requested = false + var t1Result = -1L + var t2Result1 = -1L + var t2Result2 = -1L + var t2Result3 = -1L + var t2WaitTime = 0L + } + val state = new State + + val t1 = startThread("t1") { + state.synchronized { + state.t1Result = manager.tryToAcquire(1000L) + state.t1Requested = true + state.notifyAll() + while (!state.t2Requested) { + state.wait() + } + } + // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make + // sure the other thread blocks for some time otherwise + Thread.sleep(300) + manager.releaseMemoryForThisThread() + } + + val t2 = startThread("t2") { + state.synchronized { + while (!state.t1Requested) { + state.wait() + } + state.t2Requested = true + state.notifyAll() + } + val startTime = System.currentTimeMillis() + val r1 = manager.tryToAcquire(500L) + val endTime = System.currentTimeMillis() + val r2 = manager.tryToAcquire(500L) + val r3 = manager.tryToAcquire(500L) + state.synchronized { + state.t2Result1 = r1 + state.t2Result2 = r2 + state.t2Result3 = r3 + state.t2WaitTime = endTime - startTime + } + } + + failAfter(20 seconds) { + t1.join() + t2.join() + } + + // Both threads should've been able to acquire their memory; the second one will have waited + // until the first one acquired 1000 bytes and then released all of it + state.synchronized { + assert(state.t1Result === 1000L, "t1 could not allocate memory") + assert(state.t2Result1 === 500L, "t2 didn't get 500 bytes the first time") + assert(state.t2Result2 === 500L, "t2 didn't get 500 bytes the second time") + assert(state.t2Result3 === 0L, s"t2 got more bytes a third time (${state.t2Result3})") + assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})") + } + } +} From a646a365e3beb8d0cd7e492e625ce68ee9439a07 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 5 Aug 2014 00:39:07 -0700 Subject: [PATCH 04/13] [SPARK-2857] Correct properties to set Master / Worker ports `master.ui.port` and `worker.ui.port` were never picked up by SparkConf, simply because they are not prefixed with "spark." Unfortunately, this is also currently the documented way of setting these values. Author: Andrew Or Closes #1779 from andrewor14/master-worker-port and squashes the following commits: 8475e95 [Andrew Or] Update docs to reflect changes in configs 4db3d5d [Andrew Or] Stop using configs that don't actually work --- .../org/apache/spark/deploy/master/MasterArguments.scala | 4 ++-- .../scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala | 2 +- docs/spark-standalone.md | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index a87781fb93850..4b0dbbe543d3f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -38,8 +38,8 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { if (System.getenv("SPARK_MASTER_WEBUI_PORT") != null) { webUiPort = System.getenv("SPARK_MASTER_WEBUI_PORT").toInt } - if (conf.contains("master.ui.port")) { - webUiPort = conf.get("master.ui.port").toInt + if (conf.contains("spark.master.ui.port")) { + webUiPort = conf.get("spark.master.ui.port").toInt } parse(args.toList) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 0ad2edba2227f..a9f531e9e4cae 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -58,6 +58,6 @@ private[spark] object WorkerWebUI { val STATIC_RESOURCE_BASE = SparkUI.STATIC_RESOURCE_DIR def getUIPort(requestedPort: Option[Int], conf: SparkConf): Int = { - requestedPort.getOrElse(conf.getInt("worker.ui.port", WorkerWebUI.DEFAULT_PORT)) + requestedPort.getOrElse(conf.getInt("spark.worker.ui.port", WorkerWebUI.DEFAULT_PORT)) } } diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 2fb30765f35e8..293a7ac9bc9aa 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -314,7 +314,7 @@ configure those ports. Standalone Cluster Master 8080 Web UI - master.ui.port + spark.master.ui.port Jetty-based @@ -338,7 +338,7 @@ configure those ports. Worker 8081 Web UI - worker.ui.port + spark.worker.ui.port Jetty-based From 9862c614c06507aa7624208f1d7ed5bc027ca52e Mon Sep 17 00:00:00 2001 From: wangfei Date: Tue, 5 Aug 2014 00:51:07 -0700 Subject: [PATCH 05/13] [SPARK-1779] Throw an exception if memory fractions are not between 0 and 1 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Author: wangfei Author: wangfei Closes #714 from scwf/memoryFraction and squashes the following commits: 6e385b9 [wangfei] Update SparkConf.scala da6ee59 [wangfei] add configs 829a195 [wangfei] add indent 717c0ca [wangfei] updated to make more concise fc45476 [wangfei] validate memoryfraction in sparkconf 2e79b3d [wangfei] && => || 43621bd [wangfei] && => || cf38bcf [wangfei] throw IllegalArgumentException 14d18ac [wangfei] throw IllegalArgumentException dff1f0f [wangfei] Update BlockManager.scala 764965f [wangfei] Update ExternalAppendOnlyMap.scala a59d76b [wangfei] Throw exception when memoryFracton is out of range 7b899c2 [wangfei] 【SPARK-1779】 --- .../main/scala/org/apache/spark/SparkConf.scala | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 38700847c80f4..cce7a23d3b9fc 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -238,6 +238,20 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } } + // Validate memory fractions + val memoryKeys = Seq( + "spark.storage.memoryFraction", + "spark.shuffle.memoryFraction", + "spark.shuffle.safetyFraction", + "spark.storage.unrollFraction", + "spark.storage.safetyFraction") + for (key <- memoryKeys) { + val value = getDouble(key, 0.5) + if (value > 1 || value < 0) { + throw new IllegalArgumentException("$key should be between 0 and 1 (was '$value').") + } + } + // Check for legacy configs sys.env.get("SPARK_JAVA_OPTS").foreach { value => val warning = From 184048f80b6fa160c89d5bb47b937a0a89534a95 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 5 Aug 2014 01:30:46 -0700 Subject: [PATCH 06/13] [SPARK-2856] Decrease initial buffer size for Kryo to 64KB. Author: Reynold Xin Closes #1780 from rxin/kryo-init-size and squashes the following commits: 551b935 [Reynold Xin] [SPARK-2856] Decrease initial buffer size for Kryo to 64KB. --- .../scala/org/apache/spark/serializer/KryoSerializer.scala | 4 +++- docs/configuration.md | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index e60b802a86a14..407cb9db6ee9a 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -47,7 +47,9 @@ class KryoSerializer(conf: SparkConf) with Logging with Serializable { - private val bufferSize = conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024 + private val bufferSize = + (conf.getDouble("spark.kryoserializer.buffer.mb", 0.064) * 1024 * 1024).toInt + private val maxBufferSize = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) * 1024 * 1024 private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) diff --git a/docs/configuration.md b/docs/configuration.md index 870343f1c0bd2..b3dee3f131411 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -412,7 +412,7 @@ Apart from these, the following properties are also available, and may be useful spark.kryoserializer.buffer.mb - 2 + 0.064 Initial size of Kryo's serialization buffer, in megabytes. Note that there will be one buffer per core on each worker. This buffer will grow up to From e87075df977a539e4a1684045a7bd66c36285174 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 5 Aug 2014 10:40:28 -0700 Subject: [PATCH 07/13] [SPARK-1022][Streaming] Add Kafka real unit test This PR is a updated version of (https://github.com/apache/spark/pull/557) to actually test sending and receiving data through Kafka, and fix previous flaky issues. @tdas, would you mind reviewing this PR? Thanks a lot. Author: jerryshao Closes #1751 from jerryshao/kafka-unit-test and squashes the following commits: b6a505f [jerryshao] code refactor according to comments 5222330 [jerryshao] Change JavaKafkaStreamSuite to better test it 5525f10 [jerryshao] Fix flaky issue of Kafka real unit test 4559310 [jerryshao] Minor changes for Kafka unit test 860f649 [jerryshao] Minor style changes, and tests ignored due to flakiness 796d4ca [jerryshao] Add real Kafka streaming test --- external/kafka/pom.xml | 6 + .../streaming/kafka/JavaKafkaStreamSuite.java | 125 +++++++++-- .../streaming/kafka/KafkaStreamSuite.scala | 197 ++++++++++++++++-- 3 files changed, 293 insertions(+), 35 deletions(-) diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index daf03360bc5f5..2aee99949223a 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -70,6 +70,12 @@ + + net.sf.jopt-simple + jopt-simple + 3.2 + test + org.scalatest scalatest_${scala.binary.version} diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index 9f8046bf00f8f..0571454c01dae 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -17,31 +17,118 @@ package org.apache.spark.streaming.kafka; +import java.io.Serializable; import java.util.HashMap; +import java.util.List; + +import scala.Predef; +import scala.Tuple2; +import scala.collection.JavaConverters; + +import junit.framework.Assert; -import org.apache.spark.streaming.api.java.JavaPairReceiverInputDStream; -import org.junit.Test; -import com.google.common.collect.Maps; import kafka.serializer.StringDecoder; + +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.function.Function; import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.LocalJavaStreamingContext; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +import org.junit.Test; +import org.junit.After; +import org.junit.Before; + +public class JavaKafkaStreamSuite extends LocalJavaStreamingContext implements Serializable { + private transient KafkaStreamSuite testSuite = new KafkaStreamSuite(); + + @Before + @Override + public void setUp() { + testSuite.beforeFunction(); + System.clearProperty("spark.driver.port"); + //System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock"); + ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + } + + @After + @Override + public void tearDown() { + ssc.stop(); + ssc = null; + System.clearProperty("spark.driver.port"); + testSuite.afterFunction(); + } -public class JavaKafkaStreamSuite extends LocalJavaStreamingContext { @Test - public void testKafkaStream() { - HashMap topics = Maps.newHashMap(); - - // tests the API, does not actually test data receiving - JavaPairReceiverInputDStream test1 = - KafkaUtils.createStream(ssc, "localhost:12345", "group", topics); - JavaPairReceiverInputDStream test2 = KafkaUtils.createStream(ssc, "localhost:12345", "group", topics, - StorageLevel.MEMORY_AND_DISK_SER_2()); - - HashMap kafkaParams = Maps.newHashMap(); - kafkaParams.put("zookeeper.connect", "localhost:12345"); - kafkaParams.put("group.id","consumer-group"); - JavaPairReceiverInputDStream test3 = KafkaUtils.createStream(ssc, - String.class, String.class, StringDecoder.class, StringDecoder.class, - kafkaParams, topics, StorageLevel.MEMORY_AND_DISK_SER_2()); + public void testKafkaStream() throws InterruptedException { + String topic = "topic1"; + HashMap topics = new HashMap(); + topics.put(topic, 1); + + HashMap sent = new HashMap(); + sent.put("a", 5); + sent.put("b", 3); + sent.put("c", 10); + + testSuite.createTopic(topic); + HashMap tmp = new HashMap(sent); + testSuite.produceAndSendMessage(topic, + JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap( + Predef.>conforms())); + + HashMap kafkaParams = new HashMap(); + kafkaParams.put("zookeeper.connect", testSuite.zkConnect()); + kafkaParams.put("group.id", "test-consumer-" + KafkaTestUtils.random().nextInt(10000)); + kafkaParams.put("auto.offset.reset", "smallest"); + + JavaPairDStream stream = KafkaUtils.createStream(ssc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + kafkaParams, + topics, + StorageLevel.MEMORY_ONLY_SER()); + + final HashMap result = new HashMap(); + + JavaDStream words = stream.map( + new Function, String>() { + @Override + public String call(Tuple2 tuple2) throws Exception { + return tuple2._2(); + } + } + ); + + words.countByValue().foreachRDD( + new Function, Void>() { + @Override + public Void call(JavaPairRDD rdd) throws Exception { + List> ret = rdd.collect(); + for (Tuple2 r : ret) { + if (result.containsKey(r._1())) { + result.put(r._1(), result.get(r._1()) + r._2()); + } else { + result.put(r._1(), r._2()); + } + } + + return null; + } + } + ); + + ssc.start(); + ssc.awaitTermination(3000); + + Assert.assertEquals(sent.size(), result.size()); + for (String k : sent.keySet()) { + Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue()); + } } } diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index e6f2c4a5cf5d1..c0b55e9340253 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -17,28 +17,193 @@ package org.apache.spark.streaming.kafka -import kafka.serializer.StringDecoder +import java.io.File +import java.net.InetSocketAddress +import java.util.{Properties, Random} + +import scala.collection.mutable + +import kafka.admin.CreateTopicCommand +import kafka.common.TopicAndPartition +import kafka.producer.{KeyedMessage, ProducerConfig, Producer} +import kafka.utils.ZKStringSerializer +import kafka.serializer.{StringDecoder, StringEncoder} +import kafka.server.{KafkaConfig, KafkaServer} + +import org.I0Itec.zkclient.ZkClient + +import org.apache.zookeeper.server.ZooKeeperServer +import org.apache.zookeeper.server.NIOServerCnxnFactory + import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.util.Utils class KafkaStreamSuite extends TestSuiteBase { + import KafkaTestUtils._ + + val zkConnect = "localhost:2181" + val zkConnectionTimeout = 6000 + val zkSessionTimeout = 6000 + + val brokerPort = 9092 + val brokerProps = getBrokerConfig(brokerPort, zkConnect) + val brokerConf = new KafkaConfig(brokerProps) + + protected var zookeeper: EmbeddedZookeeper = _ + protected var zkClient: ZkClient = _ + protected var server: KafkaServer = _ + protected var producer: Producer[String, String] = _ + + override def useManualClock = false + + override def beforeFunction() { + // Zookeeper server startup + zookeeper = new EmbeddedZookeeper(zkConnect) + logInfo("==================== 0 ====================") + zkClient = new ZkClient(zkConnect, zkSessionTimeout, zkConnectionTimeout, ZKStringSerializer) + logInfo("==================== 1 ====================") - test("kafka input stream") { + // Kafka broker startup + server = new KafkaServer(brokerConf) + logInfo("==================== 2 ====================") + server.startup() + logInfo("==================== 3 ====================") + Thread.sleep(2000) + logInfo("==================== 4 ====================") + super.beforeFunction() + } + + override def afterFunction() { + producer.close() + server.shutdown() + brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } + + zkClient.close() + zookeeper.shutdown() + + super.afterFunction() + } + + test("Kafka input stream") { val ssc = new StreamingContext(master, framework, batchDuration) - val topics = Map("my-topic" -> 1) - - // tests the API, does not actually test data receiving - val test1: ReceiverInputDStream[(String, String)] = - KafkaUtils.createStream(ssc, "localhost:1234", "group", topics) - val test2: ReceiverInputDStream[(String, String)] = - KafkaUtils.createStream(ssc, "localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK_SER_2) - val kafkaParams = Map("zookeeper.connect"->"localhost:12345","group.id"->"consumer-group") - val test3: ReceiverInputDStream[(String, String)] = - KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( - ssc, kafkaParams, topics, StorageLevel.MEMORY_AND_DISK_SER_2) - - // TODO: Actually test receiving data + val topic = "topic1" + val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) + createTopic(topic) + produceAndSendMessage(topic, sent) + + val kafkaParams = Map("zookeeper.connect" -> zkConnect, + "group.id" -> s"test-consumer-${random.nextInt(10000)}", + "auto.offset.reset" -> "smallest") + + val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( + ssc, + kafkaParams, + Map(topic -> 1), + StorageLevel.MEMORY_ONLY) + val result = new mutable.HashMap[String, Long]() + stream.map { case (k, v) => v } + .countByValue() + .foreachRDD { r => + val ret = r.collect() + ret.toMap.foreach { kv => + val count = result.getOrElseUpdate(kv._1, 0) + kv._2 + result.put(kv._1, count) + } + } + ssc.start() + ssc.awaitTermination(3000) + + assert(sent.size === result.size) + sent.keys.foreach { k => assert(sent(k) === result(k).toInt) } + ssc.stop() } + + private def createTestMessage(topic: String, sent: Map[String, Int]) + : Seq[KeyedMessage[String, String]] = { + val messages = for ((s, freq) <- sent; i <- 0 until freq) yield { + new KeyedMessage[String, String](topic, s) + } + messages.toSeq + } + + def createTopic(topic: String) { + CreateTopicCommand.createTopic(zkClient, topic, 1, 1, "0") + logInfo("==================== 5 ====================") + // wait until metadata is propagated + waitUntilMetadataIsPropagated(Seq(server), topic, 0, 1000) + } + + def produceAndSendMessage(topic: String, sent: Map[String, Int]) { + val brokerAddr = brokerConf.hostName + ":" + brokerConf.port + producer = new Producer[String, String](new ProducerConfig(getProducerConfig(brokerAddr))) + producer.send(createTestMessage(topic, sent): _*) + logInfo("==================== 6 ====================") + } +} + +object KafkaTestUtils { + val random = new Random() + + def getBrokerConfig(port: Int, zkConnect: String): Properties = { + val props = new Properties() + props.put("broker.id", "0") + props.put("host.name", "localhost") + props.put("port", port.toString) + props.put("log.dir", Utils.createTempDir().getAbsolutePath) + props.put("zookeeper.connect", zkConnect) + props.put("log.flush.interval.messages", "1") + props.put("replica.socket.timeout.ms", "1500") + props + } + + def getProducerConfig(brokerList: String): Properties = { + val props = new Properties() + props.put("metadata.broker.list", brokerList) + props.put("serializer.class", classOf[StringEncoder].getName) + props + } + + def waitUntilTrue(condition: () => Boolean, waitTime: Long): Boolean = { + val startTime = System.currentTimeMillis() + while (true) { + if (condition()) + return true + if (System.currentTimeMillis() > startTime + waitTime) + return false + Thread.sleep(waitTime.min(100L)) + } + // Should never go to here + throw new RuntimeException("unexpected error") + } + + def waitUntilMetadataIsPropagated(servers: Seq[KafkaServer], topic: String, partition: Int, + timeout: Long) { + assert(waitUntilTrue(() => + servers.foldLeft(true)(_ && _.apis.leaderCache.keySet.contains( + TopicAndPartition(topic, partition))), timeout), + s"Partition [$topic, $partition] metadata not propagated after timeout") + } + + class EmbeddedZookeeper(val zkConnect: String) { + val random = new Random() + val snapshotDir = Utils.createTempDir() + val logDir = Utils.createTempDir() + + val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500) + val (ip, port) = { + val splits = zkConnect.split(":") + (splits(0), splits(1).toInt) + } + val factory = new NIOServerCnxnFactory() + factory.configure(new InetSocketAddress(ip, port), 16) + factory.startup(zookeeper) + + def shutdown() { + factory.shutdown() + Utils.deleteRecursively(snapshotDir) + Utils.deleteRecursively(logDir) + } + } } From 2c0f705e26ca3dfc43a1e9a0722c0e57f67c970a Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Tue, 5 Aug 2014 12:48:26 -0500 Subject: [PATCH 08/13] SPARK-1528 - spark on yarn, add support for accessing remote HDFS Add a config (spark.yarn.access.namenodes) to allow applications running on yarn to access other secure HDFS cluster. User just specifies the namenodes of the other clusters and we get Tokens for those and ship them with the spark application. Author: Thomas Graves Closes #1159 from tgravescs/spark-1528 and squashes the following commits: ddbcd16 [Thomas Graves] review comments 0ac8501 [Thomas Graves] SPARK-1528 - add support for accessing remote HDFS --- docs/running-on-yarn.md | 7 +++ .../apache/spark/deploy/yarn/ClientBase.scala | 56 +++++++++++++------ .../spark/deploy/yarn/ClientBaseSuite.scala | 55 +++++++++++++++++- 3 files changed, 101 insertions(+), 17 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 0362f5a223319..573930dbf4e54 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -106,6 +106,13 @@ Most of the configs are the same for Spark on YARN as for other deployment modes set this configuration to "hdfs:///some/path". + + spark.yarn.access.namenodes + (none) + + A list of secure HDFS namenodes your Spark application is going to access. For example, `spark.yarn.access.namenodes=hdfs://nn1.com:8032,hdfs://nn2.com:8032`. The Spark application must have acess to the namenodes listed and Kerberos must be properly configured to be able to access them (either in the same realm or in a trusted realm). Spark acquires security tokens for each of the namenodes so that the Spark application can access those remote HDFS clusters. + + # Launching Spark on YARN diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index b7e8636e02eb2..ed8f56ab8b75e 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission import org.apache.hadoop.mapred.Master import org.apache.hadoop.mapreduce.MRJobConfig -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment @@ -191,23 +191,11 @@ trait ClientBase extends Logging { // Upload Spark and the application JAR to the remote file system if necessary. Add them as // local resources to the application master. val fs = FileSystem.get(conf) - - val delegTokenRenewer = Master.getMasterPrincipal(conf) - if (UserGroupInformation.isSecurityEnabled()) { - if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { - val errorMessage = "Can't get Master Kerberos principal for use as renewer" - logError(errorMessage) - throw new SparkException(errorMessage) - } - } val dst = new Path(fs.getHomeDirectory(), appStagingDir) - val replication = sparkConf.getInt("spark.yarn.submit.file.replication", 3).toShort - - if (UserGroupInformation.isSecurityEnabled()) { - val dstFs = dst.getFileSystem(conf) - dstFs.addDelegationTokens(delegTokenRenewer, credentials) - } + val nns = ClientBase.getNameNodesToAccess(sparkConf) + dst + ClientBase.obtainTokensForNamenodes(nns, conf, credentials) + val replication = sparkConf.getInt("spark.yarn.submit.file.replication", 3).toShort val localResources = HashMap[String, LocalResource]() FileSystem.mkdirs(fs, dst, new FsPermission(STAGING_DIR_PERMISSION)) @@ -614,4 +602,40 @@ object ClientBase extends Logging { YarnSparkHadoopUtil.addToEnvironment(env, Environment.CLASSPATH.name, path, File.pathSeparator) + /** + * Get the list of namenodes the user may access. + */ + private[yarn] def getNameNodesToAccess(sparkConf: SparkConf): Set[Path] = { + sparkConf.get("spark.yarn.access.namenodes", "").split(",").map(_.trim()).filter(!_.isEmpty) + .map(new Path(_)).toSet + } + + private[yarn] def getTokenRenewer(conf: Configuration): String = { + val delegTokenRenewer = Master.getMasterPrincipal(conf) + logDebug("delegation token renewer is: " + delegTokenRenewer) + if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { + val errorMessage = "Can't get Master Kerberos principal for use as renewer" + logError(errorMessage) + throw new SparkException(errorMessage) + } + delegTokenRenewer + } + + /** + * Obtains tokens for the namenodes passed in and adds them to the credentials. + */ + private[yarn] def obtainTokensForNamenodes(paths: Set[Path], conf: Configuration, + creds: Credentials) { + if (UserGroupInformation.isSecurityEnabled()) { + val delegTokenRenewer = getTokenRenewer(conf) + + paths.foreach { + dst => + val dstFs = dst.getFileSystem(conf) + logDebug("getting token for namenode: " + dst) + dstFs.addDelegationTokens(delegTokenRenewer, creds) + } + } + } + } diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala index 686714dc36488..68cc2890f3a22 100644 --- a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala @@ -31,6 +31,8 @@ import org.apache.hadoop.yarn.api.records.ContainerLaunchContext import org.apache.hadoop.yarn.conf.YarnConfiguration import org.mockito.Matchers._ import org.mockito.Mockito._ + + import org.scalatest.FunSuite import org.scalatest.Matchers @@ -38,7 +40,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ HashMap => MutableHashMap } import scala.util.Try -import org.apache.spark.SparkConf +import org.apache.spark.{SparkException, SparkConf} import org.apache.spark.util.Utils class ClientBaseSuite extends FunSuite with Matchers { @@ -138,6 +140,57 @@ class ClientBaseSuite extends FunSuite with Matchers { } } + test("check access nns empty") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "") + val nns = ClientBase.getNameNodesToAccess(sparkConf) + nns should be(Set()) + } + + test("check access nns unset") { + val sparkConf = new SparkConf() + val nns = ClientBase.getNameNodesToAccess(sparkConf) + nns should be(Set()) + } + + test("check access nns") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032") + val nns = ClientBase.getNameNodesToAccess(sparkConf) + nns should be(Set(new Path("hdfs://nn1:8032"))) + } + + test("check access nns space") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032, ") + val nns = ClientBase.getNameNodesToAccess(sparkConf) + nns should be(Set(new Path("hdfs://nn1:8032"))) + } + + test("check access two nns") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032,hdfs://nn2:8032") + val nns = ClientBase.getNameNodesToAccess(sparkConf) + nns should be(Set(new Path("hdfs://nn1:8032"), new Path("hdfs://nn2:8032"))) + } + + test("check token renewer") { + val hadoopConf = new Configuration() + hadoopConf.set("yarn.resourcemanager.address", "myrm:8033") + hadoopConf.set("yarn.resourcemanager.principal", "yarn/myrm:8032@SPARKTEST.COM") + val renewer = ClientBase.getTokenRenewer(hadoopConf) + renewer should be ("yarn/myrm:8032@SPARKTEST.COM") + } + + test("check token renewer default") { + val hadoopConf = new Configuration() + val caught = + intercept[SparkException] { + ClientBase.getTokenRenewer(hadoopConf) + } + assert(caught.getMessage === "Can't get Master Kerberos principal for use as renewer") + } + object Fixtures { val knownDefYarnAppCP: Seq[String] = From 1c5555a23d3aa40423d658cfbf2c956ad415a6b1 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Tue, 5 Aug 2014 12:52:52 -0500 Subject: [PATCH 09/13] SPARK-1890 and SPARK-1891- add admin and modify acls It was easier to combine these 2 jira since they touch many of the same places. This pr adds the following: - adds modify acls - adds admin acls (list of admins/users that get added to both view and modify acls) - modify Kill button on UI to take modify acls into account - changes config name of spark.ui.acls.enable to spark.acls.enable since I choose poorly in original name. We keep backwards compatibility so people can still use spark.ui.acls.enable. The acls should apply to any web ui as well as any CLI interfaces. - send view and modify acls information on to YARN so that YARN interfaces can use (yarn cli for killing applications for example). Author: Thomas Graves Closes #1196 from tgravescs/SPARK-1890 and squashes the following commits: 8292eb1 [Thomas Graves] review comments b92ec89 [Thomas Graves] remove unneeded variable from applistener 4c765f4 [Thomas Graves] Add in admin acls 72eb0ac [Thomas Graves] Add modify acls --- .../org/apache/spark/SecurityManager.scala | 107 +++++++++++++++--- .../deploy/history/FsHistoryProvider.scala | 4 +- .../scheduler/ApplicationEventListener.scala | 4 +- .../apache/spark/ui/jobs/JobProgressTab.scala | 2 +- .../apache/spark/SecurityManagerSuite.scala | 83 ++++++++++++-- docs/configuration.md | 27 ++++- docs/security.md | 7 +- .../apache/spark/deploy/yarn/ClientBase.scala | 9 +- 8 files changed, 206 insertions(+), 37 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 74aa441619bd2..25c2c9fc6af7c 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -41,10 +41,19 @@ import org.apache.spark.deploy.SparkHadoopUtil * secure the UI if it has data that other users should not be allowed to see. The javax * servlet filter specified by the user can authenticate the user and then once the user * is logged in, Spark can compare that user versus the view acls to make sure they are - * authorized to view the UI. The configs 'spark.ui.acls.enable' and 'spark.ui.view.acls' + * authorized to view the UI. The configs 'spark.acls.enable' and 'spark.ui.view.acls' * control the behavior of the acls. Note that the person who started the application * always has view access to the UI. * + * Spark has a set of modify acls (`spark.modify.acls`) that controls which users have permission + * to modify a single application. This would include things like killing the application. By + * default the person who started the application has modify access. For modify access through + * the UI, you must have a filter that does authentication in place for the modify acls to work + * properly. + * + * Spark also has a set of admin acls (`spark.admin.acls`) which is a set of users/administrators + * who always have permission to view or modify the Spark application. + * * Spark does not currently support encryption after authentication. * * At this point spark has multiple communication protocols that need to be secured and @@ -137,18 +146,32 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { private val sparkSecretLookupKey = "sparkCookie" private val authOn = sparkConf.getBoolean("spark.authenticate", false) - private var uiAclsOn = sparkConf.getBoolean("spark.ui.acls.enable", false) + // keep spark.ui.acls.enable for backwards compatibility with 1.0 + private var aclsOn = sparkConf.getOption("spark.acls.enable").getOrElse( + sparkConf.get("spark.ui.acls.enable", "false")).toBoolean + + // admin acls should be set before view or modify acls + private var adminAcls: Set[String] = + stringToSet(sparkConf.get("spark.admin.acls", "")) private var viewAcls: Set[String] = _ + + // list of users who have permission to modify the application. This should + // apply to both UI and CLI for things like killing the application. + private var modifyAcls: Set[String] = _ + // always add the current user and SPARK_USER to the viewAcls - private val defaultAclUsers = Seq[String](System.getProperty("user.name", ""), + private val defaultAclUsers = Set[String](System.getProperty("user.name", ""), Option(System.getenv("SPARK_USER")).getOrElse("")) + setViewAcls(defaultAclUsers, sparkConf.get("spark.ui.view.acls", "")) + setModifyAcls(defaultAclUsers, sparkConf.get("spark.modify.acls", "")) private val secretKey = generateSecretKey() logInfo("SecurityManager: authentication " + (if (authOn) "enabled" else "disabled") + - "; ui acls " + (if (uiAclsOn) "enabled" else "disabled") + - "; users with view permissions: " + viewAcls.toString()) + "; ui acls " + (if (aclsOn) "enabled" else "disabled") + + "; users with view permissions: " + viewAcls.toString() + + "; users with modify permissions: " + modifyAcls.toString()) // Set our own authenticator to properly negotiate user/password for HTTP connections. // This is needed by the HTTP client fetching from the HttpServer. Put here so its @@ -169,18 +192,51 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { ) } - private[spark] def setViewAcls(defaultUsers: Seq[String], allowedUsers: String) { - viewAcls = (defaultUsers ++ allowedUsers.split(',')).map(_.trim()).filter(!_.isEmpty).toSet + /** + * Split a comma separated String, filter out any empty items, and return a Set of strings + */ + private def stringToSet(list: String): Set[String] = { + list.split(',').map(_.trim).filter(!_.isEmpty).toSet + } + + /** + * Admin acls should be set before the view or modify acls. If you modify the admin + * acls you should also set the view and modify acls again to pick up the changes. + */ + def setViewAcls(defaultUsers: Set[String], allowedUsers: String) { + viewAcls = (adminAcls ++ defaultUsers ++ stringToSet(allowedUsers)) logInfo("Changing view acls to: " + viewAcls.mkString(",")) } - private[spark] def setViewAcls(defaultUser: String, allowedUsers: String) { - setViewAcls(Seq[String](defaultUser), allowedUsers) + def setViewAcls(defaultUser: String, allowedUsers: String) { + setViewAcls(Set[String](defaultUser), allowedUsers) + } + + def getViewAcls: String = viewAcls.mkString(",") + + /** + * Admin acls should be set before the view or modify acls. If you modify the admin + * acls you should also set the view and modify acls again to pick up the changes. + */ + def setModifyAcls(defaultUsers: Set[String], allowedUsers: String) { + modifyAcls = (adminAcls ++ defaultUsers ++ stringToSet(allowedUsers)) + logInfo("Changing modify acls to: " + modifyAcls.mkString(",")) + } + + def getModifyAcls: String = modifyAcls.mkString(",") + + /** + * Admin acls should be set before the view or modify acls. If you modify the admin + * acls you should also set the view and modify acls again to pick up the changes. + */ + def setAdminAcls(adminUsers: String) { + adminAcls = stringToSet(adminUsers) + logInfo("Changing admin acls to: " + adminAcls.mkString(",")) } - private[spark] def setUIAcls(aclSetting: Boolean) { - uiAclsOn = aclSetting - logInfo("Changing acls enabled to: " + uiAclsOn) + def setAcls(aclSetting: Boolean) { + aclsOn = aclSetting + logInfo("Changing acls enabled to: " + aclsOn) } /** @@ -224,22 +280,39 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { * Check to see if Acls for the UI are enabled * @return true if UI authentication is enabled, otherwise false */ - def uiAclsEnabled(): Boolean = uiAclsOn + def aclsEnabled(): Boolean = aclsOn /** * Checks the given user against the view acl list to see if they have - * authorization to view the UI. If the UI acls must are disabled - * via spark.ui.acls.enable, all users have view access. + * authorization to view the UI. If the UI acls are disabled + * via spark.acls.enable, all users have view access. If the user is null + * it is assumed authentication is off and all users have access. * * @param user to see if is authorized * @return true is the user has permission, otherwise false */ def checkUIViewPermissions(user: String): Boolean = { - logDebug("user=" + user + " uiAclsEnabled=" + uiAclsEnabled() + " viewAcls=" + + logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " viewAcls=" + viewAcls.mkString(",")) - if (uiAclsEnabled() && (user != null) && (!viewAcls.contains(user))) false else true + if (aclsEnabled() && (user != null) && (!viewAcls.contains(user))) false else true } + /** + * Checks the given user against the modify acl list to see if they have + * authorization to modify the application. If the UI acls are disabled + * via spark.acls.enable, all users have modify access. If the user is null + * it is assumed authentication isn't turned on and all users have access. + * + * @param user to see if is authorized + * @return true is the user has permission, otherwise false + */ + def checkModifyPermissions(user: String): Boolean = { + logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " modifyAcls=" + + modifyAcls.mkString(",")) + if (aclsEnabled() && (user != null) && (!modifyAcls.contains(user))) false else true + } + + /** * Check to see if authentication for the Spark communication protocols is enabled * @return true if authentication is enabled, otherwise false diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 6d2d4cef1ee46..cc06540ee0647 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -189,7 +189,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis if (ui != null) { val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) - ui.getSecurityManager.setUIAcls(uiAclsEnabled) + ui.getSecurityManager.setAcls(uiAclsEnabled) + // make sure to set admin acls before view acls so properly picked up + ui.getSecurityManager.setAdminAcls(appListener.adminAcls) ui.getSecurityManager.setViewAcls(appListener.sparkUser, appListener.viewAcls) } (appInfo, ui) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala index cd5d44ad4a7e6..162158babc35b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala @@ -29,7 +29,7 @@ private[spark] class ApplicationEventListener extends SparkListener { var startTime = -1L var endTime = -1L var viewAcls = "" - var enableViewAcls = false + var adminAcls = "" def applicationStarted = startTime != -1 @@ -55,7 +55,7 @@ private[spark] class ApplicationEventListener extends SparkListener { val environmentDetails = environmentUpdate.environmentDetails val allProperties = environmentDetails("Spark Properties").toMap viewAcls = allProperties.getOrElse("spark.ui.view.acls", "") - enableViewAcls = allProperties.getOrElse("spark.ui.acls.enable", "false").toBoolean + adminAcls = allProperties.getOrElse("spark.admin.acls", "") } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala index 3308c8c8a3d37..8a01ec80c9dd6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala @@ -41,7 +41,7 @@ private[ui] class JobProgressTab(parent: SparkUI) extends WebUITab(parent, "stag def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) def handleKillRequest(request: HttpServletRequest) = { - if (killEnabled) { + if ((killEnabled) && (parent.securityManager.checkModifyPermissions(request.getRemoteUser))) { val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean val stageId = Option(request.getParameter("id")).getOrElse("-1").toInt if (stageId >= 0 && killFlag && listener.activeStages.contains(stageId)) { diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index e39093e24d68a..fcca0867b8072 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -31,7 +31,7 @@ class SecurityManagerSuite extends FunSuite { conf.set("spark.ui.view.acls", "user1,user2") val securityManager = new SecurityManager(conf); assert(securityManager.isAuthenticationEnabled() === true) - assert(securityManager.uiAclsEnabled() === true) + assert(securityManager.aclsEnabled() === true) assert(securityManager.checkUIViewPermissions("user1") === true) assert(securityManager.checkUIViewPermissions("user2") === true) assert(securityManager.checkUIViewPermissions("user3") === false) @@ -41,17 +41,17 @@ class SecurityManagerSuite extends FunSuite { val conf = new SparkConf conf.set("spark.ui.view.acls", "user1,user2") val securityManager = new SecurityManager(conf); - securityManager.setUIAcls(true) - assert(securityManager.uiAclsEnabled() === true) - securityManager.setUIAcls(false) - assert(securityManager.uiAclsEnabled() === false) + securityManager.setAcls(true) + assert(securityManager.aclsEnabled() === true) + securityManager.setAcls(false) + assert(securityManager.aclsEnabled() === false) // acls are off so doesn't matter what view acls set to assert(securityManager.checkUIViewPermissions("user4") === true) - securityManager.setUIAcls(true) - assert(securityManager.uiAclsEnabled() === true) - securityManager.setViewAcls(ArrayBuffer[String]("user5"), "user6,user7") + securityManager.setAcls(true) + assert(securityManager.aclsEnabled() === true) + securityManager.setViewAcls(Set[String]("user5"), "user6,user7") assert(securityManager.checkUIViewPermissions("user1") === false) assert(securityManager.checkUIViewPermissions("user5") === true) assert(securityManager.checkUIViewPermissions("user6") === true) @@ -59,5 +59,72 @@ class SecurityManagerSuite extends FunSuite { assert(securityManager.checkUIViewPermissions("user8") === false) assert(securityManager.checkUIViewPermissions(null) === true) } + + test("set security modify acls") { + val conf = new SparkConf + conf.set("spark.modify.acls", "user1,user2") + + val securityManager = new SecurityManager(conf); + securityManager.setAcls(true) + assert(securityManager.aclsEnabled() === true) + securityManager.setAcls(false) + assert(securityManager.aclsEnabled() === false) + + // acls are off so doesn't matter what view acls set to + assert(securityManager.checkModifyPermissions("user4") === true) + + securityManager.setAcls(true) + assert(securityManager.aclsEnabled() === true) + securityManager.setModifyAcls(Set("user5"), "user6,user7") + assert(securityManager.checkModifyPermissions("user1") === false) + assert(securityManager.checkModifyPermissions("user5") === true) + assert(securityManager.checkModifyPermissions("user6") === true) + assert(securityManager.checkModifyPermissions("user7") === true) + assert(securityManager.checkModifyPermissions("user8") === false) + assert(securityManager.checkModifyPermissions(null) === true) + } + + test("set security admin acls") { + val conf = new SparkConf + conf.set("spark.admin.acls", "user1,user2") + conf.set("spark.ui.view.acls", "user3") + conf.set("spark.modify.acls", "user4") + + val securityManager = new SecurityManager(conf); + securityManager.setAcls(true) + assert(securityManager.aclsEnabled() === true) + + assert(securityManager.checkModifyPermissions("user1") === true) + assert(securityManager.checkModifyPermissions("user2") === true) + assert(securityManager.checkModifyPermissions("user4") === true) + assert(securityManager.checkModifyPermissions("user3") === false) + assert(securityManager.checkModifyPermissions("user5") === false) + assert(securityManager.checkModifyPermissions(null) === true) + assert(securityManager.checkUIViewPermissions("user1") === true) + assert(securityManager.checkUIViewPermissions("user2") === true) + assert(securityManager.checkUIViewPermissions("user3") === true) + assert(securityManager.checkUIViewPermissions("user4") === false) + assert(securityManager.checkUIViewPermissions("user5") === false) + assert(securityManager.checkUIViewPermissions(null) === true) + + securityManager.setAdminAcls("user6") + securityManager.setViewAcls(Set[String]("user8"), "user9") + securityManager.setModifyAcls(Set("user11"), "user9") + assert(securityManager.checkModifyPermissions("user6") === true) + assert(securityManager.checkModifyPermissions("user11") === true) + assert(securityManager.checkModifyPermissions("user9") === true) + assert(securityManager.checkModifyPermissions("user1") === false) + assert(securityManager.checkModifyPermissions("user4") === false) + assert(securityManager.checkModifyPermissions(null) === true) + assert(securityManager.checkUIViewPermissions("user6") === true) + assert(securityManager.checkUIViewPermissions("user8") === true) + assert(securityManager.checkUIViewPermissions("user9") === true) + assert(securityManager.checkUIViewPermissions("user1") === false) + assert(securityManager.checkUIViewPermissions("user3") === false) + assert(securityManager.checkUIViewPermissions(null) === true) + + } + + } diff --git a/docs/configuration.md b/docs/configuration.md index b3dee3f131411..25adea210cba0 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -815,13 +815,13 @@ Apart from these, the following properties are also available, and may be useful - spark.ui.acls.enable + spark.acls.enable false - Whether Spark web ui acls should are enabled. If enabled, this checks to see if the user has - access permissions to view the web ui. See spark.ui.view.acls for more details. - Also note this requires the user to be known, if the user comes across as null no checks - are done. Filters can be used to authenticate and set the user. + Whether Spark acls should are enabled. If enabled, this checks to see if the user has + access permissions to view or modify the job. Note this requires the user to be known, + so if the user comes across as null no checks are done. Filters can be used with the UI + to authenticate and set the user. @@ -832,6 +832,23 @@ Apart from these, the following properties are also available, and may be useful user that started the Spark job has view access. + + spark.modify.acls + Empty + + Comma separated list of users that have modify access to the Spark job. By default only the + user that started the Spark job has access to modify it (kill it for example). + + + + spark.admin.acls + Empty + + Comma separated list of users/administrators that have view and modify access to all Spark jobs. + This can be used if you run on a shared cluster and have a set of administrators or devs who + help debug when things work. + + #### Spark Streaming diff --git a/docs/security.md b/docs/security.md index 90ba678033b19..8312f8d017e1f 100644 --- a/docs/security.md +++ b/docs/security.md @@ -8,8 +8,11 @@ Spark currently supports authentication via a shared secret. Authentication can * For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. * For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. -The Spark UI can also be secured by using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html) via the `spark.ui.filters` setting. A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user versus the view ACLs to make sure they are authorized to view the UI. The configs `spark.ui.acls.enable` and `spark.ui.view.acls` control the behavior of the ACLs. Note that the user who started the application always has view access to the UI. -On YARN, the Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. +The Spark UI can also be secured by using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html) via the `spark.ui.filters` setting. A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user versus the view ACLs to make sure they are authorized to view the UI. The configs `spark.acls.enable` and `spark.ui.view.acls` control the behavior of the ACLs. Note that the user who started the application always has view access to the UI. On YARN, the Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. + +Spark also supports modify ACLs to control who has access to modify a running Spark application. This includes things like killing the application or a task. This is controlled by the configs `spark.acls.enable` and `spark.modify.acls`. Note that if you are authenticating the web UI, in order to use the kill button on the web UI it might be necessary to add the users in the modify acls to the view acls also. On YARN, the modify acls are passed in and control who has modify access via YARN interfaces. + +Spark allows for a set of administrators to be specified in the acls who always have view and modify permissions to all the applications. is controlled by the config `spark.admin.acls`. This is useful on a shared cluster where you might have administrators or support staff who help users debug applications. If your applications are using event logging, the directory where the event logs go (`spark.eventLog.dir`) should be manually created and have the proper permissions set on it. If you want those log files secured, the permissions should be set to `drwxrwxrwxt` for that directory. The owner of the directory should be the super user who is running the history server and the group permissions should be restricted to super user group. This will allow all users to write to the directory but will prevent unprivileged users from removing or renaming a file unless they own the file or directory. The event log files will be created by Spark with permissions such that only the user and group have read and write access. diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index ed8f56ab8b75e..44e025b8f60ba 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -37,7 +37,7 @@ import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.util.Records -import org.apache.spark.{SparkException, Logging, SparkConf, SparkContext} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} /** * The entry point (starting in Client#main() and Client#run()) for launching Spark on YARN. The @@ -405,6 +405,13 @@ trait ClientBase extends Logging { amContainer.setCommands(printableCommands) setupSecurityToken(amContainer) + + // send the acl settings into YARN to control who has access via YARN interfaces + val securityManager = new SecurityManager(sparkConf) + val acls = Map[ApplicationAccessType, String] ( + ApplicationAccessType.VIEW_APP -> securityManager.getViewAcls, + ApplicationAccessType.MODIFY_APP -> securityManager.getModifyAcls) + amContainer.setApplicationACLs(acls) amContainer } } From 6e821e3d1ae1ed23459bc7f1098510b968130152 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 5 Aug 2014 11:17:50 -0700 Subject: [PATCH 10/13] [SPARK-2860][SQL] Fix coercion of CASE WHEN. Author: Michael Armbrust Closes #1785 from marmbrus/caseNull and squashes the following commits: 126006d [Michael Armbrust] better error message 2fe357f [Michael Armbrust] Fix coercion of CASE WHEN. --- .../catalyst/analysis/HiveTypeCoercion.scala | 56 +++++++++++-------- ...ll case-0-581cdfe70091e546414b202da2cebdcb | 1 + .../sql/hive/execution/HiveQuerySuite.scala | 3 + 3 files changed, 36 insertions(+), 24 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/null case-0-581cdfe70091e546414b202da2cebdcb diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index e94f2a3bea63e..15eb5982a4a91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -49,10 +49,21 @@ trait HiveTypeCoercion { BooleanCasts :: StringToIntegralCasts :: FunctionArgumentConversion :: - CastNulls :: + CaseWhenCoercion :: Division :: Nil + trait TypeWidening { + def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { + // Try and find a promotion rule that contains both types in question. + val applicableConversion = + HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2)) + + // If found return the widest common type, otherwise None + applicableConversion.map(_.filter(t => t == t1 || t == t2).last) + } + } + /** * Applies any changes to [[AttributeReference]] data types that are made by other rules to * instances higher in the query tree. @@ -133,16 +144,7 @@ trait HiveTypeCoercion { * - LongType to FloatType * - LongType to DoubleType */ - object WidenTypes extends Rule[LogicalPlan] { - - def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { - // Try and find a promotion rule that contains both types in question. - val applicableConversion = - HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2)) - - // If found return the widest common type, otherwise None - applicableConversion.map(_.filter(t => t == t1 || t == t2).last) - } + object WidenTypes extends Rule[LogicalPlan] with TypeWidening { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case u @ Union(left, right) if u.childrenResolved && !u.resolved => @@ -336,28 +338,34 @@ trait HiveTypeCoercion { } /** - * Ensures that NullType gets casted to some other types under certain circumstances. + * Coerces the type of different branches of a CASE WHEN statement to a common type. */ - object CastNulls extends Rule[LogicalPlan] { + object CaseWhenCoercion extends Rule[LogicalPlan] with TypeWidening { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case cw @ CaseWhen(branches) => + case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) => val valueTypes = branches.sliding(2, 2).map { - case Seq(_, value) if value.resolved => Some(value.dataType) - case Seq(elseVal) if elseVal.resolved => Some(elseVal.dataType) - case _ => None + case Seq(_, value) => value.dataType + case Seq(elseVal) => elseVal.dataType }.toSeq - if (valueTypes.distinct.size == 2 && valueTypes.exists(_ == Some(NullType))) { - val otherType = valueTypes.filterNot(_ == Some(NullType))(0).get + + logDebug(s"Input values for null casting ${valueTypes.mkString(",")}") + + if (valueTypes.distinct.size > 1) { + val commonType = valueTypes.reduce { (v1, v2) => + findTightestCommonType(v1, v2) + .getOrElse(sys.error( + s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2")) + } val transformedBranches = branches.sliding(2, 2).map { - case Seq(cond, value) if value.resolved && value.dataType == NullType => - Seq(cond, Cast(value, otherType)) - case Seq(elseVal) if elseVal.resolved && elseVal.dataType == NullType => - Seq(Cast(elseVal, otherType)) + case Seq(cond, value) if value.dataType != commonType => + Seq(cond, Cast(value, commonType)) + case Seq(elseVal) if elseVal.dataType != commonType => + Seq(Cast(elseVal, commonType)) case s => s }.reduce(_ ++ _) CaseWhen(transformedBranches) } else { - // It is possible to have more types due to the possibility of short-circuiting. + // Types match up. Hopefully some other rule fixes whatever is wrong with resolution. cw } } diff --git a/sql/hive/src/test/resources/golden/null case-0-581cdfe70091e546414b202da2cebdcb b/sql/hive/src/test/resources/golden/null case-0-581cdfe70091e546414b202da2cebdcb new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/null case-0-581cdfe70091e546414b202da2cebdcb @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index aa810a291231a..2f0be49b6a6d7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -32,6 +32,9 @@ case class TestData(a: Int, b: String) */ class HiveQuerySuite extends HiveComparisonTest { + createQueryTest("null case", + "SELECT case when(true) then 1 else null end FROM src LIMIT 1") + createQueryTest("single case", """SELECT case when true then 1 else 2 end FROM src LIMIT 1""") From ac3440f4f3c4b79070ffec7db0b08ad062b4df90 Mon Sep 17 00:00:00 2001 From: "Guancheng (G.C.) Chen" Date: Tue, 5 Aug 2014 11:50:08 -0700 Subject: [PATCH 11/13] [SPARK-2859] Update url of Kryo project in related docs JIRA Issue: https://issues.apache.org/jira/browse/SPARK-2859 Kryo project has been migrated from googlecode to github, hence we need to update its URL in related docs such as tuning.md. Author: Guancheng (G.C.) Chen Closes #1782 from gchen/kryo-docs and squashes the following commits: b62543c [Guancheng (G.C.) Chen] update url of Kryo project --- docs/tuning.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/tuning.md b/docs/tuning.md index 4917c11bc1147..8fb2a0433b1a8 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -32,7 +32,7 @@ in your operations) and performance. It provides two serialization libraries: [`java.io.Externalizable`](http://docs.oracle.com/javase/6/docs/api/java/io/Externalizable.html). Java serialization is flexible but often quite slow, and leads to large serialized formats for many classes. -* [Kryo serialization](http://code.google.com/p/kryo/): Spark can also use +* [Kryo serialization](https://github.com/EsotericSoftware/kryo): Spark can also use the Kryo library (version 2) to serialize objects more quickly. Kryo is significantly faster and more compact than Java serialization (often as much as 10x), but does not support all `Serializable` types and requires you to *register* the classes you'll use in the program in advance @@ -68,7 +68,7 @@ conf.set("spark.kryo.registrator", "mypackage.MyRegistrator") val sc = new SparkContext(conf) {% endhighlight %} -The [Kryo documentation](http://code.google.com/p/kryo/) describes more advanced +The [Kryo documentation](https://github.com/EsotericSoftware/kryo) describes more advanced registration options, such as adding custom serialization code. If your objects are large, you may also need to increase the `spark.kryoserializer.buffer.mb` From 74f82c71b03d265a7d0c98ce196ca8c44de002e8 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 5 Aug 2014 13:08:23 -0700 Subject: [PATCH 12/13] SPARK-2380: Support displaying accumulator values in the web UI This patch adds support for giving accumulators user-visible names and displaying accumulator values in the web UI. This allows users to create custom counters that can display in the UI. The current approach displays both the accumulator deltas caused by each task and a "current" value of the accumulator totals for each stage, which gets update as tasks finish. Currently in Spark developers have been extending the `TaskMetrics` functionality to provide custom instrumentation for RDD's. This provides a potentially nicer alternative of going through the existing accumulator framework (actually `TaskMetrics` and accumulators are on an awkward collision course as we add more features to the former). The current patch demo's how we can use the feature to provide instrumentation for RDD input sizes. The nice thing about going through accumulators is that users can read the current value of the data being tracked in their programs. This could be useful to e.g. decide to short-circuit a Spark stage depending on how things are going. ![counters](https://cloud.githubusercontent.com/assets/320616/3488815/6ee7bc34-0505-11e4-84ce-e36d9886e2cf.png) Author: Patrick Wendell Closes #1309 from pwendell/metrics and squashes the following commits: 8815308 [Patrick Wendell] Merge remote-tracking branch 'apache/master' into HEAD 93fbe0f [Patrick Wendell] Other minor fixes cc43f68 [Patrick Wendell] Updating unit tests c991b1b [Patrick Wendell] Moving some code into the Accumulators class 9a9ba3c [Patrick Wendell] More merge fixes c5ace9e [Patrick Wendell] More merge conflicts 1da15e3 [Patrick Wendell] Merge remote-tracking branch 'apache/master' into metrics 9860c55 [Patrick Wendell] Potential solution to posting listener events 0bb0e33 [Patrick Wendell] Remove "display" variable and assume display = name.isDefined 0ec4ac7 [Patrick Wendell] Java API's e95bf69 [Patrick Wendell] Stash be97261 [Patrick Wendell] Style fix 8407308 [Patrick Wendell] Removing examples in Hadoop and RDD class 64d405f [Patrick Wendell] Adding missing file 5d8b156 [Patrick Wendell] Changes based on Kay's review. 9f18bad [Patrick Wendell] Minor style changes and tests 7a63abc [Patrick Wendell] Adding Json serialization and responding to Reynold's feedback ad85076 [Patrick Wendell] Example of using named accumulators for custom RDD metrics. 0b72660 [Patrick Wendell] Initial WIP example of supporing globally named accumulators. --- .../scala/org/apache/spark/Accumulators.scala | 19 ++++-- .../scala/org/apache/spark/SparkContext.scala | 19 ++++++ .../spark/api/java/JavaSparkContext.scala | 59 ++++++++++++++++++ .../spark/scheduler/AccumulableInfo.scala | 46 ++++++++++++++ .../apache/spark/scheduler/DAGScheduler.scala | 24 ++++++- .../apache/spark/scheduler/StageInfo.scala | 4 ++ .../org/apache/spark/scheduler/TaskInfo.scala | 9 +++ .../spark/ui/jobs/JobProgressListener.scala | 10 ++- .../org/apache/spark/ui/jobs/StagePage.scala | 21 ++++++- .../org/apache/spark/ui/jobs/UIData.scala | 3 +- .../org/apache/spark/util/JsonProtocol.scala | 39 +++++++++++- .../apache/spark/util/JsonProtocolSuite.scala | 62 +++++++++++++++---- docs/programming-guide.md | 6 +- 13 files changed, 294 insertions(+), 27 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 9c55bfbb47626..12f2fe031cb1d 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -36,15 +36,21 @@ import org.apache.spark.serializer.JavaSerializer * * @param initialValue initial value of accumulator * @param param helper object defining how to add elements of type `R` and `T` + * @param name human-readable name for use in Spark's web UI * @tparam R the full accumulated data (result type) * @tparam T partial data that can be added in */ class Accumulable[R, T] ( @transient initialValue: R, - param: AccumulableParam[R, T]) + param: AccumulableParam[R, T], + val name: Option[String]) extends Serializable { - val id = Accumulators.newId + def this(@transient initialValue: R, param: AccumulableParam[R, T]) = + this(initialValue, param, None) + + val id: Long = Accumulators.newId + @transient private var value_ = initialValue // Current value on master val zero = param.zero(initialValue) // Zero value to be passed to workers private var deserialized = false @@ -219,8 +225,10 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa * @param param helper object defining how to add elements of type `T` * @tparam T result type */ -class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T]) - extends Accumulable[T,T](initialValue, param) +class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String]) + extends Accumulable[T,T](initialValue, param, name) { + def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None) +} /** * A simpler version of [[org.apache.spark.AccumulableParam]] where the only data type you can add @@ -281,4 +289,7 @@ private object Accumulators { } } } + + def stringifyPartialValue(partialValue: Any) = "%s".format(partialValue) + def stringifyValue(value: Any) = "%s".format(value) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 9ba21cfcde01a..e132955f0f850 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -760,6 +760,15 @@ class SparkContext(config: SparkConf) extends Logging { def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = new Accumulator(initialValue, param) + /** + * Create an [[org.apache.spark.Accumulator]] variable of a given type, with a name for display + * in the Spark UI. Tasks can "add" values to the accumulator using the `+=` method. Only the + * driver can access the accumulator's `value`. + */ + def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) = { + new Accumulator(initialValue, param, Some(name)) + } + /** * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values * with `+=`. Only the driver can access the accumuable's `value`. @@ -769,6 +778,16 @@ class SparkContext(config: SparkConf) extends Logging { def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) = new Accumulable(initialValue, param) + /** + * Create an [[org.apache.spark.Accumulable]] shared variable, with a name for display in the + * Spark UI. Tasks can add values to the accumuable using the `+=` operator. Only the driver can + * access the accumuable's `value`. + * @tparam T accumulator type + * @tparam R type that can be added to the accumulator + */ + def accumulable[T, R](initialValue: T, name: String)(implicit param: AccumulableParam[T, R]) = + new Accumulable(initialValue, param, Some(name)) + /** * Create an accumulator from a "mutable collection" type. * diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index d9d1c5955ca99..e0a4815940db3 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -429,6 +429,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] = sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]] + /** + * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def intAccumulator(initialValue: Int, name: String): Accumulator[java.lang.Integer] = + sc.accumulator(initialValue, name)(IntAccumulatorParam) + .asInstanceOf[Accumulator[java.lang.Integer]] + /** * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values * to using the `add` method. Only the master can access the accumulator's `value`. @@ -436,12 +446,31 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] = sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]] + /** + * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def doubleAccumulator(initialValue: Double, name: String): Accumulator[java.lang.Double] = + sc.accumulator(initialValue, name)(DoubleAccumulatorParam) + .asInstanceOf[Accumulator[java.lang.Double]] + /** * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values * to using the `add` method. Only the master can access the accumulator's `value`. */ def accumulator(initialValue: Int): Accumulator[java.lang.Integer] = intAccumulator(initialValue) + /** + * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def accumulator(initialValue: Int, name: String): Accumulator[java.lang.Integer] = + intAccumulator(initialValue, name) + /** * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values * to using the `add` method. Only the master can access the accumulator's `value`. @@ -449,6 +478,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def accumulator(initialValue: Double): Accumulator[java.lang.Double] = doubleAccumulator(initialValue) + + /** + * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def accumulator(initialValue: Double, name: String): Accumulator[java.lang.Double] = + doubleAccumulator(initialValue, name) + /** * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" * values to using the `add` method. Only the master can access the accumulator's `value`. @@ -456,6 +495,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] = sc.accumulator(initialValue)(accumulatorParam) + /** + * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" + * values to using the `add` method. Only the master can access the accumulator's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def accumulator[T](initialValue: T, name: String, accumulatorParam: AccumulatorParam[T]) + : Accumulator[T] = + sc.accumulator(initialValue, name)(accumulatorParam) + /** * Create an [[org.apache.spark.Accumulable]] shared variable of the given type, to which tasks * can "add" values with `add`. Only the master can access the accumuable's `value`. @@ -463,6 +512,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] = sc.accumulable(initialValue)(param) + /** + * Create an [[org.apache.spark.Accumulable]] shared variable of the given type, to which tasks + * can "add" values with `add`. Only the master can access the accumuable's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def accumulable[T, R](initialValue: T, name: String, param: AccumulableParam[T, R]) + : Accumulable[T, R] = + sc.accumulable(initialValue, name)(param) + /** * Broadcast a read-only variable to the cluster, returning a * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala new file mode 100644 index 0000000000000..fa83372bb4d11 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage. + */ +@DeveloperApi +class AccumulableInfo ( + val id: Long, + val name: String, + val update: Option[String], // represents a partial update within a task + val value: String) { + + override def equals(other: Any): Boolean = other match { + case acc: AccumulableInfo => + this.id == acc.id && this.name == acc.name && + this.update == acc.update && this.value == acc.value + case _ => false + } +} + +object AccumulableInfo { + def apply(id: Long, name: String, update: Option[String], value: String) = + new AccumulableInfo(id, name, update, value) + + def apply(id: Long, name: String, value: String) = new AccumulableInfo(id, name, None, value) +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 9fa3a4e9c71ae..430e45ada5808 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -883,8 +883,14 @@ class DAGScheduler( val task = event.task val stageId = task.stageId val taskType = Utils.getFormattedClassName(task) - listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo, - event.taskMetrics)) + + // The success case is dealt with separately below, since we need to compute accumulator + // updates before posting. + if (event.reason != Success) { + listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo, + event.taskMetrics)) + } + if (!stageIdToStage.contains(task.stageId)) { // Skip all the actions if the stage has been cancelled. return @@ -906,12 +912,26 @@ class DAGScheduler( if (event.accumUpdates != null) { try { Accumulators.add(event.accumUpdates) + event.accumUpdates.foreach { case (id, partialValue) => + val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]] + // To avoid UI cruft, ignore cases where value wasn't updated + if (acc.name.isDefined && partialValue != acc.zero) { + val name = acc.name.get + val stringPartialValue = Accumulators.stringifyPartialValue(partialValue) + val stringValue = Accumulators.stringifyValue(acc.value) + stage.info.accumulables(id) = AccumulableInfo(id, name, stringValue) + event.taskInfo.accumulables += + AccumulableInfo(id, name, Some(stringPartialValue), stringValue) + } + } } catch { // If we see an exception during accumulator update, just log the error and move on. case e: Exception => logError(s"Failed to update accumulators for $task", e) } } + listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo, + event.taskMetrics)) stage.pendingTasks -= task task match { case rt: ResultTask[_, _] => diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 480891550eb60..2a407e47a05bd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import scala.collection.mutable.HashMap + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.storage.RDDInfo @@ -37,6 +39,8 @@ class StageInfo( var completionTime: Option[Long] = None /** If the stage failed, the reason why. */ var failureReason: Option[String] = None + /** Terminal values of accumulables updated during this stage. */ + val accumulables = HashMap[Long, AccumulableInfo]() def stageFailed(reason: String) { failureReason = Some(reason) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index ca0595f35143e..6fa1f2c880f7a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import scala.collection.mutable.ListBuffer + import org.apache.spark.annotation.DeveloperApi /** @@ -41,6 +43,13 @@ class TaskInfo( */ var gettingResultTime: Long = 0 + /** + * Intermediate updates to accumulables during this task. Note that it is valid for the same + * accumulable to be updated multiple times in a single task or for two accumulables with the + * same name but different IDs to exist in a task. + */ + val accumulables = ListBuffer[AccumulableInfo]() + /** * The time when the task has completed successfully (including the time to remotely fetch * results, if necessary). diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index da2f5d3172fe2..a57a354620163 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -17,7 +17,7 @@ package org.apache.spark.ui.jobs -import scala.collection.mutable.{HashMap, ListBuffer} +import scala.collection.mutable.{HashMap, ListBuffer, Map} import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi @@ -65,6 +65,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { new StageUIData }) + for ((id, info) <- stageCompleted.stageInfo.accumulables) { + stageData.accumulables(id) = info + } + poolToActiveStages.get(stageData.schedulingPool).foreach(_.remove(stageId)) activeStages.remove(stageId) if (stage.failureReason.isEmpty) { @@ -130,6 +134,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { new StageUIData }) + for (accumulableInfo <- info.accumulables) { + stageData.accumulables(accumulableInfo.id) = accumulableInfo + } + val execSummaryMap = stageData.executorSummary val execSummary = execSummaryMap.getOrElseUpdate(info.executorId, new ExecutorSummary) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index cab26b9e2f7d3..8bc1ba758cf77 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -20,11 +20,12 @@ package org.apache.spark.ui.jobs import java.util.Date import javax.servlet.http.HttpServletRequest -import scala.xml.Node +import scala.xml.{Node, Unparsed} import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils} import org.apache.spark.ui.jobs.UIData._ import org.apache.spark.util.{Utils, Distribution} +import org.apache.spark.scheduler.AccumulableInfo /** Page showing statistics and task list for a given stage */ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { @@ -51,6 +52,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime) val numCompleted = tasks.count(_.taskInfo.finished) + val accumulables = listener.stageIdToData(stageId).accumulables val hasInput = stageData.inputBytes > 0 val hasShuffleRead = stageData.shuffleReadBytes > 0 val hasShuffleWrite = stageData.shuffleWriteBytes > 0 @@ -95,10 +97,15 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { // scalastyle:on + val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") + def accumulableRow(acc: AccumulableInfo) = {acc.name}{acc.value} + val accumulableTable = UIUtils.listingTable(accumulableHeaders, accumulableRow, + accumulables.values.toSeq) + val taskHeaders: Seq[String] = Seq( "Index", "ID", "Attempt", "Status", "Locality Level", "Executor", - "Launch Time", "Duration", "GC Time") ++ + "Launch Time", "Duration", "GC Time", "Accumulators") ++ {if (hasInput) Seq("Input") else Nil} ++ {if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++ {if (hasShuffleWrite) Seq("Write Time", "Shuffle Write") else Nil} ++ @@ -208,11 +215,16 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { Some(UIUtils.listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true)) } val executorTable = new ExecutorTable(stageId, parent) + + val maybeAccumulableTable: Seq[Node] = + if (accumulables.size > 0) {

Accumulators

++ accumulableTable } else Seq() + val content = summary ++

Summary Metrics for {numCompleted} Completed Tasks

++
{summaryTable.getOrElse("No tasks have reported metrics yet.")}
++

Aggregated Metrics by Executor

++ executorTable.toNodeSeq ++ + maybeAccumulableTable ++

Tasks

++ taskTable UIUtils.headerSparkPage(content, basePath, appName, "Details for Stage %d".format(stageId), @@ -279,6 +291,11 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { {if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""} + + {Unparsed( + info.accumulables.map{acc => s"${acc.name}: ${acc.update.get}"}.mkString("
") + )} +