diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index d75b9acfb7aa0..3a2fef05861e6 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -48,16 +48,26 @@ import java.io.Serializable * @tparam T Type of the data contained in the broadcast variable. */ abstract class Broadcast[T](val id: Long) extends Serializable { + + /** + * Whether this Broadcast is actually usable. This should be false once persisted state is + * removed from the driver. + */ + protected var isValid: Boolean = true + def value: T /** - * Remove all persisted state associated with this broadcast. + * Remove all persisted state associated with this broadcast. Overriding implementations + * should set isValid to false if persisted state is also removed from the driver. + * * @param removeFromDriver Whether to remove state from the driver. + * If true, the resulting broadcast should no longer be valid. */ def unpersist(removeFromDriver: Boolean) - // We cannot have an abstract readObject here due to some weird issues with - // readObject having to be 'private' in sub-classes. + // We cannot define abstract readObject and writeObject here due to some weird issues + // with these methods having to be 'private' in sub-classes. override def toString = "Broadcast(" + id + ")" } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 4985d4202ed6b..d5e3d60a5b2b7 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -17,8 +17,8 @@ package org.apache.spark.broadcast -import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream} -import java.net.{URL, URLConnection, URI} +import java.io.{File, FileOutputStream, ObjectInputStream, ObjectOutputStream, OutputStream} +import java.net.{URI, URL, URLConnection} import java.util.concurrent.TimeUnit import it.unimi.dsi.fastutil.io.{FastBufferedInputStream, FastBufferedOutputStream} @@ -49,10 +49,17 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea * @param removeFromDriver Whether to remove state from the driver. */ override def unpersist(removeFromDriver: Boolean) { + isValid = !removeFromDriver HttpBroadcast.unpersist(id, removeFromDriver) } - // Called by JVM when deserializing an object + // Used by the JVM when serializing this object + private def writeObject(out: ObjectOutputStream) { + assert(isValid, "Attempted to serialize a broadcast variable that has been destroyed!") + out.defaultWriteObject() + } + + // Used by the JVM when deserializing this object private def readObject(in: ObjectInputStream) { in.defaultReadObject() HttpBroadcast.synchronized { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 51f1592cef752..ace71575f5390 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -17,12 +17,12 @@ package org.apache.spark.broadcast -import java.io._ +import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream} import scala.math import scala.util.Random -import org.apache.spark._ +import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel} import org.apache.spark.util.Utils @@ -76,10 +76,17 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo * @param removeFromDriver Whether to remove state from the driver. */ override def unpersist(removeFromDriver: Boolean) { + isValid = !removeFromDriver TorrentBroadcast.unpersist(id, removeFromDriver) } - // Called by JVM when deserializing an object + // Used by the JVM when serializing this object + private def writeObject(out: ObjectOutputStream) { + assert(isValid, "Attempted to serialize a broadcast variable that has been destroyed!") + out.defaultWriteObject() + } + + // Used by the JVM when deserializing this object private def readObject(in: ObjectInputStream) { in.defaultReadObject() TorrentBroadcast.synchronized { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 3c0941e195724..78dc32b4b1525 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props} import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream} import sun.nio.ch.DirectBuffer -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException, MapOutputTracker} +import org.apache.spark._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer @@ -58,7 +58,7 @@ private[spark] class BlockManager( private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] - private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) + private[storage] val memoryStore = new MemoryStore(this, maxMemory) private[storage] val diskStore = new DiskStore(this, diskBlockManager) // If we use Netty for shuffle, start a new Netty-based shuffle sender service. @@ -210,9 +210,9 @@ private[spark] class BlockManager( } /** - * Get storage level of local block. If no info exists for the block, then returns null. + * Get storage level of local block. If no info exists for the block, return None. */ - def getLevel(blockId: BlockId): StorageLevel = blockInfo.get(blockId).map(_.level).orNull + def getLevel(blockId: BlockId): Option[StorageLevel] = blockInfo.get(blockId).map(_.level) /** * Tell the master about the current storage status of a block. This will send a block update @@ -496,9 +496,8 @@ private[spark] class BlockManager( /** * A short circuited method to get a block writer that can write data directly to disk. - * The Block will be appended to the File specified by filename. - * This is currently used for writing shuffle files out. Callers should handle error - * cases. + * The Block will be appended to the File specified by filename. This is currently used for + * writing shuffle files out. Callers should handle error cases. */ def getDiskWriter( blockId: BlockId, @@ -816,8 +815,7 @@ private[spark] class BlockManager( * @return The number of blocks removed. */ def removeRdd(rddId: Int): Int = { - // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps - // from RDD.id to blocks. + // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks. logInfo("Removing RDD " + rddId) val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) } @@ -827,13 +825,13 @@ private[spark] class BlockManager( /** * Remove all blocks belonging to the given broadcast. */ - def removeBroadcast(broadcastId: Long) { + def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) { logInfo("Removing broadcast " + broadcastId) val blocksToRemove = blockInfo.keys.filter(_.isBroadcast).collect { case bid: BroadcastBlockId if bid.broadcastId == broadcastId => bid case bid: BroadcastHelperBlockId if bid.broadcastId.broadcastId == broadcastId => bid } - blocksToRemove.foreach { blockId => removeBlock(blockId) } + blocksToRemove.foreach { blockId => removeBlock(blockId, removeFromDriver) } } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 4579c0d959553..674322e3034c8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -147,6 +147,24 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log askDriverWithReply[Array[StorageStatus]](GetStorageStatus) } + /** + * Mainly for testing. Ask the driver to query all executors for their storage levels + * regarding this block. This provides an avenue for the driver to learn the storage + * levels of blocks it has not been informed of. + * + * WARNING: This could lead to deadlocks if there are any outstanding messages the + * executors are already expecting from the driver. In this case, while the driver is + * waiting for the executors to respond to its GetStorageLevel query, the executors + * are also waiting for a response from the driver to a prior message. + * + * The interim solution is to wait for a brief window of time to pass before asking. + * This should suffice, since this mechanism is largely introduced for testing only. + */ + def askForStorageLevels(blockId: BlockId, waitTimeMs: Long = 1000) = { + Thread.sleep(waitTimeMs) + askDriverWithReply[Map[BlockManagerId, StorageLevel]](AskForStorageLevels(blockId)) + } + /** Stop the driver actor, called only on the Spark driver node */ def stop() { if (driverActor != null) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 4cc4227fd87e2..f83c26dafe2e9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -21,7 +21,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.mutable import scala.collection.JavaConversions._ -import scala.concurrent.Future +import scala.concurrent.{Await, Future} import scala.concurrent.duration._ import akka.actor.{Actor, ActorRef, Cancellable} @@ -126,6 +126,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case HeartBeat(blockManagerId) => sender ! heartBeat(blockManagerId) + case AskForStorageLevels(blockId) => + sender ! askForStorageLevels(blockId) + case other => logWarning("Got unknown message: " + other) } @@ -158,6 +161,11 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus blockManagerInfo.values.foreach { bm => bm.slaveActor ! removeMsg } } + /** + * Delegate RemoveBroadcast messages to each BlockManager because the master may not notified + * of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed + * from the executors, but not from the driver. + */ private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) { // TODO(aor): Consolidate usages of val removeMsg = RemoveBroadcast(broadcastId) @@ -246,6 +254,19 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus }.toArray } + // For testing. Ask all block managers for the given block's local storage level, if any. + private def askForStorageLevels(blockId: BlockId): Map[BlockManagerId, StorageLevel] = { + val getStorageLevel = GetStorageLevel(blockId) + blockManagerInfo.values.flatMap { info => + val future = info.slaveActor.ask(getStorageLevel)(akkaTimeout) + val result = Await.result(future, akkaTimeout) + if (result != null) { + // If the block does not exist on the slave, the slave replies None + result.asInstanceOf[Option[StorageLevel]].map { reply => (info.blockManagerId, reply) } + } else None + }.toMap + } + private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { @@ -329,6 +350,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Note that this logic will select the same node multiple times if there aren't enough peers Array.tabulate[BlockManagerId](size) { i => peers((selfIndex + i + 1) % peers.length) }.toSeq } + } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 3ea710ebc786e..1d3e94c4b6533 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -43,6 +43,9 @@ private[storage] object BlockManagerMessages { case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) extends ToBlockManagerSlave + // For testing. Ask the slave for the block's storage level. + case class GetStorageLevel(blockId: BlockId) extends ToBlockManagerSlave + ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. @@ -116,4 +119,8 @@ private[storage] object BlockManagerMessages { case object ExpireDeadHosts extends ToBlockManagerMaster case object GetStorageStatus extends ToBlockManagerMaster + + // For testing. Have the master ask all slaves for the given block's storage level. + case class AskForStorageLevels(blockId: BlockId) extends ToBlockManagerMaster + } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 8c2ccbe6a7e66..85b8ec40c0ea3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -47,7 +47,10 @@ class BlockManagerSlaveActor( mapOutputTracker.unregisterShuffle(shuffleId) } - case RemoveBroadcast(broadcastId, _) => - blockManager.removeBroadcast(broadcastId) + case RemoveBroadcast(broadcastId, removeFromDriver) => + blockManager.removeBroadcast(broadcastId, removeFromDriver) + + case GetStorageLevel(blockId) => + sender ! blockManager.getLevel(blockId) } } diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index e022accee6d08..a462654197ea0 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -19,67 +19,241 @@ package org.apache.spark import org.scalatest.FunSuite +import org.apache.spark.storage._ +import org.apache.spark.broadcast.HttpBroadcast +import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId} + class BroadcastSuite extends FunSuite with LocalSparkContext { - override def afterEach() { - super.afterEach() - System.clearProperty("spark.broadcast.factory") - } + private val httpConf = broadcastConf("HttpBroadcastFactory") + private val torrentConf = broadcastConf("TorrentBroadcastFactory") test("Using HttpBroadcast locally") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") - sc = new SparkContext("local", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === Set((1, 10), (2, 10))) + sc = new SparkContext("local", "test", httpConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === Set((1, 10), (2, 10))) } test("Accessing HttpBroadcast variables from multiple threads") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") - sc = new SparkContext("local[10]", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) + sc = new SparkContext("local[10]", "test", httpConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) } test("Accessing HttpBroadcast variables in a local cluster") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") val numSlaves = 4 - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", httpConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) } test("Using TorrentBroadcast locally") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") - sc = new SparkContext("local", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === Set((1, 10), (2, 10))) + sc = new SparkContext("local", "test", torrentConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === Set((1, 10), (2, 10))) } test("Accessing TorrentBroadcast variables from multiple threads") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") - sc = new SparkContext("local[10]", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) + sc = new SparkContext("local[10]", "test", torrentConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) } test("Accessing TorrentBroadcast variables in a local cluster") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") val numSlaves = 4 - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", torrentConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + } + + test("Unpersisting HttpBroadcast on executors only") { + testUnpersistHttpBroadcast(2, removeFromDriver = false) + } + + test("Unpersisting HttpBroadcast on executors and driver") { + testUnpersistHttpBroadcast(2, removeFromDriver = true) + } + + test("Unpersisting TorrentBroadcast on executors only") { + testUnpersistTorrentBroadcast(2, removeFromDriver = false) + } + + test("Unpersisting TorrentBroadcast on executors and driver") { + testUnpersistTorrentBroadcast(2, removeFromDriver = true) + } + + /** + * Verify the persistence of state associated with an HttpBroadcast in a local-cluster. + * + * This test creates a broadcast variable, uses it on all executors, and then unpersists it. + * In between each step, this test verifies that the broadcast blocks and the broadcast file + * are present only on the expected nodes. + */ + private def testUnpersistHttpBroadcast(numSlaves: Int, removeFromDriver: Boolean) { + def getBlockIds(id: Long) = Seq[BlockId](BroadcastBlockId(id)) + + // Verify that the broadcast file is created, and blocks are persisted only on the driver + def afterCreation(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + assert(blockIds.size === 1) + val broadcastBlockId = blockIds.head.asInstanceOf[BroadcastBlockId] + val levels = bmm.askForStorageLevels(broadcastBlockId, waitTimeMs = 0) + assert(levels.size === 1) + levels.head match { case (bm, level) => + assert(bm.executorId === "") + assert(level === StorageLevel.MEMORY_AND_DISK) + } + assert(HttpBroadcast.getFile(broadcastBlockId.broadcastId).exists) + } + + // Verify that blocks are persisted in both the executors and the driver + def afterUsingBroadcast(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + assert(blockIds.size === 1) + val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0) + assert(levels.size === numSlaves + 1) + levels.foreach { case (_, level) => + assert(level === StorageLevel.MEMORY_AND_DISK) + } + } + + // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver + // is true. In the latter case, also verify that the broadcast file is deleted on the driver. + def afterUnpersist(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + assert(blockIds.size === 1) + val broadcastBlockId = blockIds.head.asInstanceOf[BroadcastBlockId] + val levels = bmm.askForStorageLevels(broadcastBlockId, waitTimeMs = 0) + assert(levels.size === (if (removeFromDriver) 0 else 1)) + assert(removeFromDriver === !HttpBroadcast.getFile(broadcastBlockId.broadcastId).exists) + } + + testUnpersistBroadcast(numSlaves, httpConf, getBlockIds, afterCreation, + afterUsingBroadcast, afterUnpersist, removeFromDriver) + } + + /** + * Verify the persistence of state associated with an TorrentBroadcast in a local-cluster. + * + * This test creates a broadcast variable, uses it on all executors, and then unpersists it. + * In between each step, this test verifies that the broadcast blocks are present only on the + * expected nodes. + */ + private def testUnpersistTorrentBroadcast(numSlaves: Int, removeFromDriver: Boolean) { + def getBlockIds(id: Long) = { + val broadcastBlockId = BroadcastBlockId(id) + val metaBlockId = BroadcastHelperBlockId(broadcastBlockId, "meta") + // Assume broadcast value is small enough to fit into 1 piece + val pieceBlockId = BroadcastHelperBlockId(broadcastBlockId, "piece0") + Seq[BlockId](broadcastBlockId, metaBlockId, pieceBlockId) + } + + // Verify that blocks are persisted only on the driver + def afterCreation(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + blockIds.foreach { blockId => + val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0) + assert(levels.size === 1) + levels.head match { case (bm, level) => + assert(bm.executorId === "") + assert(level === StorageLevel.MEMORY_AND_DISK) + } + } + } + + // Verify that blocks are persisted in both the executors and the driver + def afterUsingBroadcast(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + blockIds.foreach { blockId => + val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0) + blockId match { + case BroadcastHelperBlockId(_, "meta") => + // Meta data is only on the driver + assert(levels.size === 1) + levels.head match { case (bm, _) => assert(bm.executorId === "") } + case _ => + // Other blocks are on both the executors and the driver + assert(levels.size === numSlaves + 1) + levels.foreach { case (_, level) => + assert(level === StorageLevel.MEMORY_AND_DISK) + } + } + } + } + + // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver + // is true. + def afterUnpersist(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + val expectedNumBlocks = if (removeFromDriver) 0 else 1 + var waitTimeMs = 1000L + blockIds.foreach { blockId => + // Allow a second for the messages triggered by unpersist to propagate to prevent deadlocks + val levels = bmm.askForStorageLevels(blockId, waitTimeMs) + assert(levels.size === expectedNumBlocks) + waitTimeMs = 0L + } + } + + testUnpersistBroadcast(numSlaves, torrentConf, getBlockIds, afterCreation, + afterUsingBroadcast, afterUnpersist, removeFromDriver) + } + + /** + * This test runs in 4 steps: + * + * 1) Create broadcast variable, and verify that all state is persisted on the driver. + * 2) Use the broadcast variable on all executors, and verify that all state is persisted + * on both the driver and the executors. + * 3) Unpersist the broadcast, and verify that all state is removed where they should be. + * 4) [Optional] If removeFromDriver is false, we verify that the broadcast is re-usable. + */ + private def testUnpersistBroadcast( + numSlaves: Int, + broadcastConf: SparkConf, + getBlockIds: Long => Seq[BlockId], + afterCreation: (Seq[BlockId], BlockManagerMaster) => Unit, + afterUsingBroadcast: (Seq[BlockId], BlockManagerMaster) => Unit, + afterUnpersist: (Seq[BlockId], BlockManagerMaster) => Unit, + removeFromDriver: Boolean) { + + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf) + val blockManagerMaster = sc.env.blockManager.master + val list = List[Int](1, 2, 3, 4) + + // Create broadcast variable + val broadcast = sc.broadcast(list) + val blocks = getBlockIds(broadcast.id) + afterCreation(blocks, blockManagerMaster) + + // Use broadcast variable on all executors + val results = sc.parallelize(1 to numSlaves, numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + afterUsingBroadcast(blocks, blockManagerMaster) + + // Unpersist broadcast + broadcast.unpersist(removeFromDriver) + afterUnpersist(blocks, blockManagerMaster) + + if (!removeFromDriver) { + // The broadcast variable is not completely destroyed (i.e. state still exists on driver) + // Using the variable again should yield the same answer as before. + val results = sc.parallelize(1 to numSlaves, numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + } + } + + /** Helper method to create a SparkConf that uses the given broadcast factory. */ + private def broadcastConf(factoryName: String): SparkConf = { + val conf = new SparkConf + conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.%s".format(factoryName)) + conf } }