
Commit

Fix tests + simplify sc.getRDDStorageInfo
This strengthens the test that was failing with additional assertions and simplifies how sc.getRDDStorageInfo builds its result.
andrewor14 committed Jul 31, 2014
1 parent da8e322 commit b12fcd7
Showing 5 changed files with 24 additions and 23 deletions.
6 changes: 3 additions & 3 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -47,7 +47,7 @@ import org.apache.spark.scheduler._
 import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend}
 import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
 import org.apache.spark.scheduler.local.LocalBackend
-import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
+import org.apache.spark.storage._
 import org.apache.spark.ui.SparkUI
 import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils}

@@ -840,9 +840,9 @@ class SparkContext(config: SparkConf) extends Logging {
    */
   @DeveloperApi
   def getRDDStorageInfo: Array[RDDInfo] = {
-    val rddInfos = StorageUtils.makeRddInfo(this)
+    val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray
     StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus)
-    rddInfos.toArray
+    rddInfos.filter(_.isCached)
   }
 
   /**
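For context, a minimal sketch of how a driver program might observe the simplified getRDDStorageInfo behavior. It mirrors the updated SparkContextInfoSuite below; the object name and local-mode configuration are illustrative assumptions, not part of this commit.

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical example object; only sc.getRDDStorageInfo and the RDDInfo fields
// touched by this commit (isCached, memSize) are taken from the diff.
object StorageInfoSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("storage-info-sketch"))
    val rdd = sc.makeRDD(1 to 4, 2).cache()

    // Nothing has been computed yet, so no RDD actually persists data.
    assert(sc.getRDDStorageInfo.isEmpty)

    rdd.count() // materialize and cache the partitions

    // Exactly one RDD now reports real storage usage.
    sc.getRDDStorageInfo.foreach { info =>
      println(s"RDD ${info.id}: cached=${info.isCached}, memSize=${info.memSize}")
    }
    sc.stop()
  }
}
```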
@@ -420,7 +420,9 @@ case class BlockStatus(
     storageLevel: StorageLevel,
     memSize: Long,
     diskSize: Long,
-    tachyonSize: Long)
+    tachyonSize: Long) {
+  def isCached: Boolean = memSize + diskSize + tachyonSize > 0
+}
 
 private[spark] class BlockManagerInfo(
     val blockManagerId: BlockManagerId,
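A small REPL-style sketch of what the new BlockStatus.isCached helper expresses, using made-up sizes; this assumes BlockStatus is constructible from user code, which may not hold if it is package-private.

```scala
import org.apache.spark.storage.{BlockStatus, StorageLevel}

// Hypothetical block statuses: a block counts as cached if it occupies any
// memory, disk, or Tachyon space, regardless of its declared storage level.
val inMemory = BlockStatus(StorageLevel.MEMORY_ONLY, memSize = 200L, diskSize = 0L, tachyonSize = 0L)
val dropped  = BlockStatus(StorageLevel.MEMORY_ONLY, memSize = 0L, diskSize = 0L, tachyonSize = 0L)

assert(inMemory.isCached)  // 200 + 0 + 0 > 0
assert(!dropped.isCached)  // all sizes are zero, so the block holds no data
```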
2 changes: 2 additions & 0 deletions core/src/main/scala/org/apache/spark/storage/RDDInfo.scala
@@ -34,6 +34,8 @@ class RDDInfo(
   var diskSize = 0L
   var tachyonSize = 0L
 
+  def isCached: Boolean = (memSize + diskSize + tachyonSize > 0) && numCachedPartitions > 0
+
   override def toString = {
     import Utils.bytesToString
     ("RDD \"%s\" (%d) StorageLevel: %s; CachedPartitions: %d; TotalPartitions: %d; " +
13 changes: 1 addition & 12 deletions core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
@@ -110,17 +110,6 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) {
 /** Helper methods for storage-related objects. */
 private[spark] object StorageUtils {
 
-  /** Return a list of RDDInfo based on the RDDs cached in the given SparkContext. */
-  def makeRddInfo(sc: SparkContext): Seq[RDDInfo] = {
-    sc.persistentRdds.values.toSeq.map { rdd =>
-      val name = Option(rdd.name).getOrElse(rdd.id.toString)
-      val numPartitions = rdd.partitions.size
-      val storageLevel = rdd.getStorageLevel
-      val rddInfo = new RDDInfo(rdd.id, name, numPartitions, storageLevel)
-      rddInfo
-    }
-  }
-
   /**
    * Update the given list of RDDInfo with the given list of storage statuses.
    * This method overwrites the old values stored in the RDDInfo's.
@@ -142,7 +131,7 @@ private[spark] object StorageUtils {
         .flatMap(_.rddBlocks(rddId))
         .filter { case (bid, _) => !newBlockIds.contains(bid) } // avoid duplicates
       val blocks = (oldBlocks ++ newBlocks).map { case (_, bstatus) => bstatus }
-      val persistedBlocks = blocks.filter { s => s.memSize + s.diskSize + s.tachyonSize > 0 }
+      val persistedBlocks = blocks.filter(_.isCached)
 
       // Assume all blocks belonging to the same RDD have the same storage level
       val storageLevel = blocks.headOption.map(_.storageLevel).getOrElse(StorageLevel.NONE)
22 changes: 15 additions & 7 deletions core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark
 
 import org.scalatest.{Assertions, FunSuite}
+import org.apache.spark.storage.StorageLevel
 
 class SparkContextInfoSuite extends FunSuite with LocalSparkContext {
   test("getPersistentRDDs only returns RDDs that are marked as cached") {
@@ -35,26 +36,33 @@ class SparkContextInfoSuite extends FunSuite with LocalSparkContext {
   test("getPersistentRDDs returns an immutable map") {
     sc = new SparkContext("local", "test")
     val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
-
     val myRdds = sc.getPersistentRDDs
     assert(myRdds.size === 1)
-    assert(myRdds.values.head === rdd1)
+    assert(myRdds(0) === rdd1)
+    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
 
+    // myRdds2 should have 2 RDDs, but myRdds should not change
     val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache()
-
-    // getPersistentRDDs should have 2 RDDs, but myRdds should not change
-    assert(sc.getPersistentRDDs.size === 2)
+    val myRdds2 = sc.getPersistentRDDs
+    assert(myRdds2.size === 2)
+    assert(myRdds2(0) === rdd1)
+    assert(myRdds2(1) === rdd2)
+    assert(myRdds2(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
+    assert(myRdds2(1).getStorageLevel === StorageLevel.MEMORY_ONLY)
     assert(myRdds.size === 1)
+    assert(myRdds(0) === rdd1)
+    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
   }
 
   test("getRDDStorageInfo only reports on RDDs that actually persist data") {
     sc = new SparkContext("local", "test")
     val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
-
     assert(sc.getRDDStorageInfo.size === 0)
-
     rdd.collect()
     assert(sc.getRDDStorageInfo.size === 1)
+    assert(sc.getRDDStorageInfo.head.isCached)
+    assert(sc.getRDDStorageInfo.head.memSize > 0)
+    assert(sc.getRDDStorageInfo.head.storageLevel === StorageLevel.MEMORY_ONLY)
   }
 
   test("call sites report correct locations") {
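Not part of this commit: a sketch of one natural follow-on check in the same style, assuming that unpersisting the RDD removes it from the persisted set so that getRDDStorageInfo becomes empty again.

```scala
// Sketch only: could be appended to the getRDDStorageInfo test above,
// reusing its sc and rdd values.
rdd.unpersist(blocking = true)
assert(sc.getRDDStorageInfo.size === 0)
```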
