Commit

CR feedback
holdenk committed Apr 9, 2014
1 parent b78804e commit e8741a7
Showing 2 changed files with 24 additions and 15 deletions.
mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala (18 changes: 7 additions & 11 deletions)
@@ -22,7 +22,7 @@ import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV,

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
-import scala.reflect._
+import scala.reflect.ClassTag

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
@@ -181,21 +181,17 @@ object MLUtils {
dataStr.saveAsTextFile(dir)
}

-def meanSquaredError(a: Double, b: Double): Double = {
-  (a-b)*(a-b)
-}

/**
* Return a k element list of pairs of RDDs with the first element of each pair
-* containing a unique 1/Kth of the data and the second element contain the composite of that.
+* containing a unique 1/Kth of the data and the second element containing the complement of that.
*/
-def kFoldRdds[T : ClassTag](rdd: RDD[T], folds: Int, seed: Int): List[Pair[RDD[T], RDD[T]]] = {
+def kFold[T : ClassTag](rdd: RDD[T], folds: Int, seed: Int): List[Pair[RDD[T], RDD[T]]] = {
val foldsF = folds.toFloat
1.to(folds).map(fold => ((
-new PartitionwiseSampledRDD(rdd, new BernoulliSampler[T]((fold-1)/foldsF,fold/foldsF, false),
-  seed),
-new PartitionwiseSampledRDD(rdd, new BernoulliSampler[T]((fold-1)/foldsF,fold/foldsF, true),
-  seed)
+new PartitionwiseSampledRDD(rdd, new BernoulliSampler[T]((fold-1)/foldsF,fold/foldsF,
+  complement = false), seed),
+new PartitionwiseSampledRDD(rdd, new BernoulliSampler[T]((fold-1)/foldsF,fold/foldsF,
+  complement = true), seed)
))).toList
}
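For orientation only (not part of this commit), here is a minimal sketch of how the renamed kFold helper could be driven from a SparkContext. It assumes the signature shown above, where each pair is (a unique 1/kth slice, the complement of that slice); the object name, local master URL, fold count, and seed below are illustrative.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.MLUtils

object KFoldSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "KFoldSketch")
    val data = sc.parallelize(1 to 100, 2)
    // Each pair holds (a unique 1/kth slice, the complement of that slice).
    val folds = MLUtils.kFold(data, 3, 11)
    folds.zipWithIndex.foreach { case ((slice, complement), i) =>
      // A real caller would train on `complement` and evaluate on `slice`.
      println("fold " + i + ": slice=" + slice.count() + " complement=" + complement.count())
    }
    sc.stop()
  }
}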

mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala (21 changes: 17 additions & 4 deletions)
@@ -18,6 +18,8 @@
package org.apache.spark.mllib.util

import java.io.File
+import scala.math
+import scala.util.Random

import org.scalatest.FunSuite

@@ -136,19 +138,30 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
new LinearRegressionModel(Array(1.0), 0)
}

test("kfoldRdd") {
test("kFold") {
val data = sc.parallelize(1 to 100, 2)
val collectedData = data.collect().sorted
-val twoFoldedRdd = MLUtils.kFoldRdds(data, 2, 1)
+val twoFoldedRdd = MLUtils.kFold(data, 2, 1)
assert(twoFoldedRdd(0)._1.collect().sorted === twoFoldedRdd(1)._2.collect().sorted)
assert(twoFoldedRdd(0)._2.collect().sorted === twoFoldedRdd(1)._1.collect().sorted)
for (folds <- 2 to 10) {
for (seed <- 1 to 5) {
-val foldedRdds = MLUtils.kFoldRdds(data, folds, seed)
+val foldedRdds = MLUtils.kFold(data, folds, seed)
assert(foldedRdds.size === folds)
foldedRdds.map{case (test, train) =>
val result = test.union(train).collect().sorted
-assert(test.collect().size > 0, "Non empty test data")
+val testSize = test.collect().size.toFloat
+assert(testSize > 0, "Non empty test data")
+val p = 1 / folds.toFloat
+// Within 3 standard deviations of the mean
+val range = 3 * math.sqrt(100 * p * (1-p))
+val expected = 100 * p
+val lowerBound = expected - range
+val upperBound = expected + range
+assert(testSize > lowerBound,
+  "Test data (" + testSize + ") smaller than expected (" + lowerBound +")" )
+assert(testSize < upperBound,
+  "Test data (" + testSize + ") larger than expected (" + upperBound +")" )
assert(train.collect().size > 0, "Non empty training data")
assert(result === collectedData,
"Each training+test set combined contains all of the data")