From e95479cd63b3259beddea278befd0bdee89bb17e Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 27 Mar 2014 14:37:51 -0700 Subject: [PATCH] Add tests for unpersisting broadcast There is not currently a way to query the blocks on the executors, an operation that is deceptively simple to accomplish. This commit adds this mechanism in order to verify that blocks are in fact persisted/unpersisted on the executors in the tests. --- .../apache/spark/broadcast/Broadcast.scala | 16 +- .../spark/broadcast/HttpBroadcast.scala | 13 +- .../spark/broadcast/TorrentBroadcast.scala | 13 +- .../apache/spark/storage/BlockManager.scala | 20 +- .../spark/storage/BlockManagerMaster.scala | 18 ++ .../storage/BlockManagerMasterActor.scala | 24 +- .../spark/storage/BlockManagerMessages.scala | 7 + .../storage/BlockManagerSlaveActor.scala | 7 +- .../org/apache/spark/BroadcastSuite.scala | 254 +++++++++++++++--- 9 files changed, 309 insertions(+), 63 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index d75b9acfb7aa0..3a2fef05861e6 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -48,16 +48,26 @@ import java.io.Serializable * @tparam T Type of the data contained in the broadcast variable. */ abstract class Broadcast[T](val id: Long) extends Serializable { + + /** + * Whether this Broadcast is actually usable. This should be false once persisted state is + * removed from the driver. + */ + protected var isValid: Boolean = true + def value: T /** - * Remove all persisted state associated with this broadcast. + * Remove all persisted state associated with this broadcast. Overriding implementations + * should set isValid to false if persisted state is also removed from the driver. + * * @param removeFromDriver Whether to remove state from the driver. + * If true, the resulting broadcast should no longer be valid. */ def unpersist(removeFromDriver: Boolean) - // We cannot have an abstract readObject here due to some weird issues with - // readObject having to be 'private' in sub-classes. + // We cannot define abstract readObject and writeObject here due to some weird issues + // with these methods having to be 'private' in sub-classes. override def toString = "Broadcast(" + id + ")" } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 4985d4202ed6b..d5e3d60a5b2b7 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -17,8 +17,8 @@ package org.apache.spark.broadcast -import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream} -import java.net.{URL, URLConnection, URI} +import java.io.{File, FileOutputStream, ObjectInputStream, ObjectOutputStream, OutputStream} +import java.net.{URI, URL, URLConnection} import java.util.concurrent.TimeUnit import it.unimi.dsi.fastutil.io.{FastBufferedInputStream, FastBufferedOutputStream} @@ -49,10 +49,17 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea * @param removeFromDriver Whether to remove state from the driver. */ override def unpersist(removeFromDriver: Boolean) { + isValid = !removeFromDriver HttpBroadcast.unpersist(id, removeFromDriver) } - // Called by JVM when deserializing an object + // Used by the JVM when serializing this object + private def writeObject(out: ObjectOutputStream) { + assert(isValid, "Attempted to serialize a broadcast variable that has been destroyed!") + out.defaultWriteObject() + } + + // Used by the JVM when deserializing this object private def readObject(in: ObjectInputStream) { in.defaultReadObject() HttpBroadcast.synchronized { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 51f1592cef752..ace71575f5390 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -17,12 +17,12 @@ package org.apache.spark.broadcast -import java.io._ +import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream} import scala.math import scala.util.Random -import org.apache.spark._ +import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel} import org.apache.spark.util.Utils @@ -76,10 +76,17 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo * @param removeFromDriver Whether to remove state from the driver. */ override def unpersist(removeFromDriver: Boolean) { + isValid = !removeFromDriver TorrentBroadcast.unpersist(id, removeFromDriver) } - // Called by JVM when deserializing an object + // Used by the JVM when serializing this object + private def writeObject(out: ObjectOutputStream) { + assert(isValid, "Attempted to serialize a broadcast variable that has been destroyed!") + out.defaultWriteObject() + } + + // Used by the JVM when deserializing this object private def readObject(in: ObjectInputStream) { in.defaultReadObject() TorrentBroadcast.synchronized { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 3c0941e195724..78dc32b4b1525 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props} import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream} import sun.nio.ch.DirectBuffer -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException, MapOutputTracker} +import org.apache.spark._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer @@ -58,7 +58,7 @@ private[spark] class BlockManager( private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] - private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) + private[storage] val memoryStore = new MemoryStore(this, maxMemory) private[storage] val diskStore = new DiskStore(this, diskBlockManager) // If we use Netty for shuffle, start a new Netty-based shuffle sender service. @@ -210,9 +210,9 @@ private[spark] class BlockManager( } /** - * Get storage level of local block. If no info exists for the block, then returns null. + * Get storage level of local block. If no info exists for the block, return None. */ - def getLevel(blockId: BlockId): StorageLevel = blockInfo.get(blockId).map(_.level).orNull + def getLevel(blockId: BlockId): Option[StorageLevel] = blockInfo.get(blockId).map(_.level) /** * Tell the master about the current storage status of a block. This will send a block update @@ -496,9 +496,8 @@ private[spark] class BlockManager( /** * A short circuited method to get a block writer that can write data directly to disk. - * The Block will be appended to the File specified by filename. - * This is currently used for writing shuffle files out. Callers should handle error - * cases. + * The Block will be appended to the File specified by filename. This is currently used for + * writing shuffle files out. Callers should handle error cases. */ def getDiskWriter( blockId: BlockId, @@ -816,8 +815,7 @@ private[spark] class BlockManager( * @return The number of blocks removed. */ def removeRdd(rddId: Int): Int = { - // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps - // from RDD.id to blocks. + // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks. logInfo("Removing RDD " + rddId) val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) } @@ -827,13 +825,13 @@ private[spark] class BlockManager( /** * Remove all blocks belonging to the given broadcast. */ - def removeBroadcast(broadcastId: Long) { + def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) { logInfo("Removing broadcast " + broadcastId) val blocksToRemove = blockInfo.keys.filter(_.isBroadcast).collect { case bid: BroadcastBlockId if bid.broadcastId == broadcastId => bid case bid: BroadcastHelperBlockId if bid.broadcastId.broadcastId == broadcastId => bid } - blocksToRemove.foreach { blockId => removeBlock(blockId) } + blocksToRemove.foreach { blockId => removeBlock(blockId, removeFromDriver) } } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 4579c0d959553..674322e3034c8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -147,6 +147,24 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log askDriverWithReply[Array[StorageStatus]](GetStorageStatus) } + /** + * Mainly for testing. Ask the driver to query all executors for their storage levels + * regarding this block. This provides an avenue for the driver to learn the storage + * levels of blocks it has not been informed of. + * + * WARNING: This could lead to deadlocks if there are any outstanding messages the + * executors are already expecting from the driver. In this case, while the driver is + * waiting for the executors to respond to its GetStorageLevel query, the executors + * are also waiting for a response from the driver to a prior message. + * + * The interim solution is to wait for a brief window of time to pass before asking. + * This should suffice, since this mechanism is largely introduced for testing only. + */ + def askForStorageLevels(blockId: BlockId, waitTimeMs: Long = 1000) = { + Thread.sleep(waitTimeMs) + askDriverWithReply[Map[BlockManagerId, StorageLevel]](AskForStorageLevels(blockId)) + } + /** Stop the driver actor, called only on the Spark driver node */ def stop() { if (driverActor != null) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 4cc4227fd87e2..f83c26dafe2e9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -21,7 +21,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.mutable import scala.collection.JavaConversions._ -import scala.concurrent.Future +import scala.concurrent.{Await, Future} import scala.concurrent.duration._ import akka.actor.{Actor, ActorRef, Cancellable} @@ -126,6 +126,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case HeartBeat(blockManagerId) => sender ! heartBeat(blockManagerId) + case AskForStorageLevels(blockId) => + sender ! askForStorageLevels(blockId) + case other => logWarning("Got unknown message: " + other) } @@ -158,6 +161,11 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus blockManagerInfo.values.foreach { bm => bm.slaveActor ! removeMsg } } + /** + * Delegate RemoveBroadcast messages to each BlockManager because the master may not notified + * of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed + * from the executors, but not from the driver. + */ private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) { // TODO(aor): Consolidate usages of val removeMsg = RemoveBroadcast(broadcastId) @@ -246,6 +254,19 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus }.toArray } + // For testing. Ask all block managers for the given block's local storage level, if any. + private def askForStorageLevels(blockId: BlockId): Map[BlockManagerId, StorageLevel] = { + val getStorageLevel = GetStorageLevel(blockId) + blockManagerInfo.values.flatMap { info => + val future = info.slaveActor.ask(getStorageLevel)(akkaTimeout) + val result = Await.result(future, akkaTimeout) + if (result != null) { + // If the block does not exist on the slave, the slave replies None + result.asInstanceOf[Option[StorageLevel]].map { reply => (info.blockManagerId, reply) } + } else None + }.toMap + } + private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { @@ -329,6 +350,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Note that this logic will select the same node multiple times if there aren't enough peers Array.tabulate[BlockManagerId](size) { i => peers((selfIndex + i + 1) % peers.length) }.toSeq } + } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 3ea710ebc786e..1d3e94c4b6533 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -43,6 +43,9 @@ private[storage] object BlockManagerMessages { case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) extends ToBlockManagerSlave + // For testing. Ask the slave for the block's storage level. + case class GetStorageLevel(blockId: BlockId) extends ToBlockManagerSlave + ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. @@ -116,4 +119,8 @@ private[storage] object BlockManagerMessages { case object ExpireDeadHosts extends ToBlockManagerMaster case object GetStorageStatus extends ToBlockManagerMaster + + // For testing. Have the master ask all slaves for the given block's storage level. + case class AskForStorageLevels(blockId: BlockId) extends ToBlockManagerMaster + } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 8c2ccbe6a7e66..85b8ec40c0ea3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -47,7 +47,10 @@ class BlockManagerSlaveActor( mapOutputTracker.unregisterShuffle(shuffleId) } - case RemoveBroadcast(broadcastId, _) => - blockManager.removeBroadcast(broadcastId) + case RemoveBroadcast(broadcastId, removeFromDriver) => + blockManager.removeBroadcast(broadcastId, removeFromDriver) + + case GetStorageLevel(blockId) => + sender ! blockManager.getLevel(blockId) } } diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index e022accee6d08..a462654197ea0 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -19,67 +19,241 @@ package org.apache.spark import org.scalatest.FunSuite +import org.apache.spark.storage._ +import org.apache.spark.broadcast.HttpBroadcast +import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId} + class BroadcastSuite extends FunSuite with LocalSparkContext { - override def afterEach() { - super.afterEach() - System.clearProperty("spark.broadcast.factory") - } + private val httpConf = broadcastConf("HttpBroadcastFactory") + private val torrentConf = broadcastConf("TorrentBroadcastFactory") test("Using HttpBroadcast locally") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") - sc = new SparkContext("local", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === Set((1, 10), (2, 10))) + sc = new SparkContext("local", "test", httpConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === Set((1, 10), (2, 10))) } test("Accessing HttpBroadcast variables from multiple threads") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") - sc = new SparkContext("local[10]", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) + sc = new SparkContext("local[10]", "test", httpConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) } test("Accessing HttpBroadcast variables in a local cluster") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") val numSlaves = 4 - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", httpConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) } test("Using TorrentBroadcast locally") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") - sc = new SparkContext("local", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === Set((1, 10), (2, 10))) + sc = new SparkContext("local", "test", torrentConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === Set((1, 10), (2, 10))) } test("Accessing TorrentBroadcast variables from multiple threads") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") - sc = new SparkContext("local[10]", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) + sc = new SparkContext("local[10]", "test", torrentConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) } test("Accessing TorrentBroadcast variables in a local cluster") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") val numSlaves = 4 - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", torrentConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + } + + test("Unpersisting HttpBroadcast on executors only") { + testUnpersistHttpBroadcast(2, removeFromDriver = false) + } + + test("Unpersisting HttpBroadcast on executors and driver") { + testUnpersistHttpBroadcast(2, removeFromDriver = true) + } + + test("Unpersisting TorrentBroadcast on executors only") { + testUnpersistTorrentBroadcast(2, removeFromDriver = false) + } + + test("Unpersisting TorrentBroadcast on executors and driver") { + testUnpersistTorrentBroadcast(2, removeFromDriver = true) + } + + /** + * Verify the persistence of state associated with an HttpBroadcast in a local-cluster. + * + * This test creates a broadcast variable, uses it on all executors, and then unpersists it. + * In between each step, this test verifies that the broadcast blocks and the broadcast file + * are present only on the expected nodes. + */ + private def testUnpersistHttpBroadcast(numSlaves: Int, removeFromDriver: Boolean) { + def getBlockIds(id: Long) = Seq[BlockId](BroadcastBlockId(id)) + + // Verify that the broadcast file is created, and blocks are persisted only on the driver + def afterCreation(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + assert(blockIds.size === 1) + val broadcastBlockId = blockIds.head.asInstanceOf[BroadcastBlockId] + val levels = bmm.askForStorageLevels(broadcastBlockId, waitTimeMs = 0) + assert(levels.size === 1) + levels.head match { case (bm, level) => + assert(bm.executorId === "") + assert(level === StorageLevel.MEMORY_AND_DISK) + } + assert(HttpBroadcast.getFile(broadcastBlockId.broadcastId).exists) + } + + // Verify that blocks are persisted in both the executors and the driver + def afterUsingBroadcast(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + assert(blockIds.size === 1) + val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0) + assert(levels.size === numSlaves + 1) + levels.foreach { case (_, level) => + assert(level === StorageLevel.MEMORY_AND_DISK) + } + } + + // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver + // is true. In the latter case, also verify that the broadcast file is deleted on the driver. + def afterUnpersist(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + assert(blockIds.size === 1) + val broadcastBlockId = blockIds.head.asInstanceOf[BroadcastBlockId] + val levels = bmm.askForStorageLevels(broadcastBlockId, waitTimeMs = 0) + assert(levels.size === (if (removeFromDriver) 0 else 1)) + assert(removeFromDriver === !HttpBroadcast.getFile(broadcastBlockId.broadcastId).exists) + } + + testUnpersistBroadcast(numSlaves, httpConf, getBlockIds, afterCreation, + afterUsingBroadcast, afterUnpersist, removeFromDriver) + } + + /** + * Verify the persistence of state associated with an TorrentBroadcast in a local-cluster. + * + * This test creates a broadcast variable, uses it on all executors, and then unpersists it. + * In between each step, this test verifies that the broadcast blocks are present only on the + * expected nodes. + */ + private def testUnpersistTorrentBroadcast(numSlaves: Int, removeFromDriver: Boolean) { + def getBlockIds(id: Long) = { + val broadcastBlockId = BroadcastBlockId(id) + val metaBlockId = BroadcastHelperBlockId(broadcastBlockId, "meta") + // Assume broadcast value is small enough to fit into 1 piece + val pieceBlockId = BroadcastHelperBlockId(broadcastBlockId, "piece0") + Seq[BlockId](broadcastBlockId, metaBlockId, pieceBlockId) + } + + // Verify that blocks are persisted only on the driver + def afterCreation(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + blockIds.foreach { blockId => + val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0) + assert(levels.size === 1) + levels.head match { case (bm, level) => + assert(bm.executorId === "") + assert(level === StorageLevel.MEMORY_AND_DISK) + } + } + } + + // Verify that blocks are persisted in both the executors and the driver + def afterUsingBroadcast(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + blockIds.foreach { blockId => + val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0) + blockId match { + case BroadcastHelperBlockId(_, "meta") => + // Meta data is only on the driver + assert(levels.size === 1) + levels.head match { case (bm, _) => assert(bm.executorId === "") } + case _ => + // Other blocks are on both the executors and the driver + assert(levels.size === numSlaves + 1) + levels.foreach { case (_, level) => + assert(level === StorageLevel.MEMORY_AND_DISK) + } + } + } + } + + // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver + // is true. + def afterUnpersist(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + val expectedNumBlocks = if (removeFromDriver) 0 else 1 + var waitTimeMs = 1000L + blockIds.foreach { blockId => + // Allow a second for the messages triggered by unpersist to propagate to prevent deadlocks + val levels = bmm.askForStorageLevels(blockId, waitTimeMs) + assert(levels.size === expectedNumBlocks) + waitTimeMs = 0L + } + } + + testUnpersistBroadcast(numSlaves, torrentConf, getBlockIds, afterCreation, + afterUsingBroadcast, afterUnpersist, removeFromDriver) + } + + /** + * This test runs in 4 steps: + * + * 1) Create broadcast variable, and verify that all state is persisted on the driver. + * 2) Use the broadcast variable on all executors, and verify that all state is persisted + * on both the driver and the executors. + * 3) Unpersist the broadcast, and verify that all state is removed where they should be. + * 4) [Optional] If removeFromDriver is false, we verify that the broadcast is re-usable. + */ + private def testUnpersistBroadcast( + numSlaves: Int, + broadcastConf: SparkConf, + getBlockIds: Long => Seq[BlockId], + afterCreation: (Seq[BlockId], BlockManagerMaster) => Unit, + afterUsingBroadcast: (Seq[BlockId], BlockManagerMaster) => Unit, + afterUnpersist: (Seq[BlockId], BlockManagerMaster) => Unit, + removeFromDriver: Boolean) { + + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf) + val blockManagerMaster = sc.env.blockManager.master + val list = List[Int](1, 2, 3, 4) + + // Create broadcast variable + val broadcast = sc.broadcast(list) + val blocks = getBlockIds(broadcast.id) + afterCreation(blocks, blockManagerMaster) + + // Use broadcast variable on all executors + val results = sc.parallelize(1 to numSlaves, numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + afterUsingBroadcast(blocks, blockManagerMaster) + + // Unpersist broadcast + broadcast.unpersist(removeFromDriver) + afterUnpersist(blocks, blockManagerMaster) + + if (!removeFromDriver) { + // The broadcast variable is not completely destroyed (i.e. state still exists on driver) + // Using the variable again should yield the same answer as before. + val results = sc.parallelize(1 to numSlaves, numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + } + } + + /** Helper method to create a SparkConf that uses the given broadcast factory. */ + private def broadcastConf(factoryName: String): SparkConf = { + val conf = new SparkConf + conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.%s".format(factoryName)) + conf } }