
Commit 03f534c

some more tests
Signed-off-by: Manish Amde <manish9ue@gmail.com>
manishamde committed Feb 28, 2014
1 parent 0012a77 commit 03f534c
Showing 2 changed files with 102 additions and 23 deletions.
43 changes: 27 additions & 16 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -121,7 +121,7 @@ object DecisionTree extends Serializable {

/*Finds the right bin for the given feature*/
def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = {
println("finding bin for labeled point " + labeledPoint.features(featureIndex))
//println("finding bin for labeled point " + labeledPoint.features(featureIndex))
//TODO: Do binary search
for (binIndex <- 0 until strategy.numSplits) {
val bin = bins(featureIndex)(binIndex)
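This hunk is truncated here. The TODO above asks for a binary search in place of the linear scan over bins; a minimal sketch of that idea, assuming the bin boundaries for one feature are available as a sorted array of upper thresholds (an assumed representation for illustration, not the commit's Bin class):

// Hedged sketch: binary search over sorted per-feature bin upper bounds.
// `thresholds` is an assumed representation, not the Bin objects used above.
def findBinBinarySearch(value: Double, thresholds: Array[Double]): Int = {
  var lo = 0
  var hi = thresholds.length - 1
  while (lo < hi) {
    val mid = (lo + hi) / 2
    if (value <= thresholds(mid)) hi = mid else lo = mid + 1
  }
  lo // index of the first bin whose upper bound is >= value (last bin otherwise)
}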
@@ -227,21 +227,27 @@ object DecisionTree extends Serializable {

val binAggregates = binMappedRDD.aggregate(Array.fill[Double](2*numSplits*numFeatures*numNodes)(0))(binSeqOp,binCombOp)
println("binAggregates.length = " + binAggregates.length)
binAggregates.foreach(x => println(x))
//binAggregates.foreach(x => println(x))
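The aggregate above reduces the input to a flat array of length 2 * numSplits * numFeatures * numNodes (two label counts per bin, per feature, per node). The exact ordering is defined by binSeqOp, which is not shown in this hunk; one plausible flat-index scheme consistent with that size, purely as an assumed illustration:

// Hedged sketch of a flat-index layout for an array of size
// 2 * numSplits * numFeatures * numNodes. The commit's binSeqOp (not shown
// in this hunk) defines the actual ordering; this is only an assumed example.
def flatIndex(nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Int,
              numSplits: Int, numFeatures: Int): Int = {
  2 * numSplits * numFeatures * nodeIndex +
    2 * numSplits * featureIndex +
    2 * binIndex +
    label
}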


def calculateGainForSplit(leftNodeAgg: Array[Array[Double]], featureIndex: Int, index: Int, rightNodeAgg: Array[Array[Double]], topImpurity: Double): Double = {

val left0Count = leftNodeAgg(featureIndex)(2 * index)
val left1Count = leftNodeAgg(featureIndex)(2 * index + 1)
val leftCount = left0Count + left1Count
println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount)

if (leftCount == 0) return 0

//println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount)
val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)

val right0Count = rightNodeAgg(featureIndex)(2 * index)
val right1Count = rightNodeAgg(featureIndex)(2 * index + 1)
val rightCount = right0Count + right1Count
println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount)

if (rightCount == 0) return 0

//println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount)
val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)

val leftWeight = leftCount.toDouble / (leftCount + rightCount)
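The rest of calculateGainForSplit is truncated in this view. The counts and impurities computed above feed the usual weighted information-gain formula; a minimal, self-contained sketch of that final step, with the impurities passed in rather than computed via strategy.impurity:

// Hedged sketch of the gain this function computes:
// gain = topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity.
def weightedGain(leftCount: Double, leftImpurity: Double,
                 rightCount: Double, rightImpurity: Double,
                 topImpurity: Double): Double = {
  if (leftCount == 0 || rightCount == 0) return 0.0 // mirrors the early returns above
  val total = leftCount + rightCount
  val leftWeight = leftCount / total
  val rightWeight = rightCount / total
  topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity
}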
@@ -261,21 +267,21 @@ object DecisionTree extends Serializable {
def extractLeftRightNodeAggregates(binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = {
val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1))
val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1))
println("binData.length = " + binData.length)
println("binData.sum = " + binData.sum)
//println("binData.length = " + binData.length)
//println("binData.sum = " + binData.sum)
for (featureIndex <- 0 until numFeatures) {
println("featureIndex = " + featureIndex)
//println("featureIndex = " + featureIndex)
val shift = 2*featureIndex*numSplits
leftNodeAgg(featureIndex)(0) = binData(shift + 0)
println("binData(shift + 0) = " + binData(shift + 0))
//println("binData(shift + 0) = " + binData(shift + 0))
leftNodeAgg(featureIndex)(1) = binData(shift + 1)
println("binData(shift + 1) = " + binData(shift + 1))
//println("binData(shift + 1) = " + binData(shift + 1))
rightNodeAgg(featureIndex)(2 * (numSplits - 2)) = binData(shift + (2 * (numSplits - 1)))
println(binData(shift + (2 * (numSplits - 1))))
//println(binData(shift + (2 * (numSplits - 1))))
rightNodeAgg(featureIndex)(2 * (numSplits - 2) + 1) = binData(shift + (2 * (numSplits - 1)) + 1)
println(binData(shift + (2 * (numSplits - 1)) + 1))
//println(binData(shift + (2 * (numSplits - 1)) + 1))
for (splitIndex <- 1 until numSplits - 1) {
println("splitIndex = " + splitIndex)
//println("splitIndex = " + splitIndex)
leftNodeAgg(featureIndex)(2 * splitIndex)
= binData(shift + 2*splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2)
leftNodeAgg(featureIndex)(2 * splitIndex + 1)
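The remainder of this loop is truncated in the view. The idea, for each feature, is a prefix pass for the left aggregates and a suffix pass for the right aggregates over the per-bin label counts, so that every candidate split can later read both sides in constant time. A minimal sketch for a single feature, using (label0, label1) pairs per bin purely for readability:

// Hedged sketch of the left/right accumulation for one feature.
// binCounts(i) holds the (label 0, label 1) counts that fell into bin i;
// assumes at least two bins.
def leftRightAggregates(binCounts: Array[(Double, Double)])
    : (Array[(Double, Double)], Array[(Double, Double)]) = {
  val n = binCounts.length                       // number of bins (numSplits)
  val left = new Array[(Double, Double)](n - 1)  // cumulative counts from the left
  val right = new Array[(Double, Double)](n - 1) // cumulative counts from the right
  left(0) = binCounts(0)
  right(n - 2) = binCounts(n - 1)
  for (i <- 1 until n - 1) {
    left(i) = (left(i - 1)._1 + binCounts(i)._1, left(i - 1)._2 + binCounts(i)._2)
    val j = n - 2 - i
    right(j) = (right(j + 1)._1 + binCounts(j + 1)._1,
                right(j + 1)._2 + binCounts(j + 1)._2)
  }
  (left, right)
}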
@@ -295,7 +301,7 @@ object DecisionTree extends Serializable {

for (featureIndex <- 0 until numFeatures) {
for (index <- 0 until numSplits -1) {
println("splitIndex = " + index)
//println("splitIndex = " + index)
gains(featureIndex)(index) = calculateGainForSplit(leftNodeAgg, featureIndex, index, rightNodeAgg, nodeImpurity)
}
}
@@ -312,8 +318,8 @@ object DecisionTree extends Serializable {
val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)

println("gains.size = " + gains.size)
println("gains(0).size = " + gains(0).size)
//println("gains.size = " + gains.size)
//println("gains(0).size = " + gains(0).size)

val (bestFeatureIndex,bestSplitIndex) = {
var bestFeatureIndex = 0
@@ -322,7 +328,7 @@ object DecisionTree extends Serializable {
for (featureIndex <- 0 until numFeatures) {
for (splitIndex <- 0 until numSplits - 1){
val gain = gains(featureIndex)(splitIndex)
println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain)
//println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain)
if(gain > maxGain) {
maxGain = gain
bestFeatureIndex = featureIndex
@@ -335,6 +341,8 @@ object DecisionTree extends Serializable {
}

splits(bestFeatureIndex)(bestSplitIndex)

//TODO: Return an array of node stats with split and impurity information
}
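For reference, the argmax over the gains array computed by the var-based loop above can also be written with Scala collections; a sketch, not the commit's code:

// Hedged sketch: pick the (featureIndex, splitIndex) pair with maximal gain.
def bestFeatureAndSplit(gains: Array[Array[Double]]): (Int, Int) = {
  val candidates =
    for {
      featureIndex <- gains.indices
      splitIndex   <- gains(featureIndex).indices
    } yield (featureIndex, splitIndex, gains(featureIndex)(splitIndex))
  val (bestFeature, bestSplit, _) = candidates.maxBy(_._3)
  (bestFeature, bestSplit)
}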

//Calculate best splits for all nodes at a given level
@@ -388,6 +396,9 @@ object DecisionTree extends Serializable {
for (featureIndex <- 0 until numFeatures){
val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
val stride : Double = numSamples.toDouble/numSplits

println("stride = " + stride)

for (index <- 0 until numSplits-1) {
val sampleIndex = (index+1)*stride.toInt
val split = new Split(featureIndex,featureSamples(sampleIndex),"continuous")
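This hunk is also truncated. The split selection above is a quantile-style scheme: the sampled feature values are sorted and roughly every stride-th value becomes a split threshold. A minimal sketch of just that computation for one feature:

// Hedged sketch: pick numSplits - 1 thresholds from a sorted sample of one feature.
def continuousSplitThresholds(sortedSamples: Array[Double], numSplits: Int): Array[Double] = {
  val stride: Double = sortedSamples.length.toDouble / numSplits
  Array.tabulate(numSplits - 1) { index =>
    val sampleIndex = (index + 1) * stride.toInt // same index expression as in the hunk above
    sortedSamples(sampleIndex)
  }
}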
82 changes: 75 additions & 7 deletions mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.SparkContext._
import org.jblas._
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.impurity.Gini
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini}
import org.apache.spark.mllib.tree.model.Filter

class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
@@ -44,7 +44,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
}

test("split and bin calculation"){
val arr = DecisionTreeSuite.generateReverseOrderedLabeledPoints()
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy("regression",Gini,3,100,"sort")
@@ -56,8 +56,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
println(splits(1)(98))
}

test("stump"){
val arr = DecisionTreeSuite.generateReverseOrderedLabeledPoints()
test("stump with fixed label 0 for Gini"){
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy("regression",Gini,3,100,"sort")
@@ -69,17 +69,85 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(splits(0).length==99)
assert(bins(0).length==100)
println(splits(1)(98))
DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
assert(bestSplits.length == 1)
println(bestSplits(0))
}
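The fixed-label tests above and below feed a pure node (every label identical) into the stump builder, so both Gini and entropy impurity of the parent are zero and no split can improve on it. A small standalone sketch (not part of the commit) of that fact, assuming the standard two-class definitions:

// Hedged sketch: impurity of a pure two-class node is zero.
def gini(c0: Double, c1: Double): Double = {
  val total = c0 + c1
  if (total == 0) 0.0 else 1.0 - math.pow(c0 / total, 2) - math.pow(c1 / total, 2)
}

def entropy(c0: Double, c1: Double): Double = {
  val total = c0 + c1
  def term(c: Double): Double =
    if (c == 0) 0.0 else -(c / total) * (math.log(c / total) / math.log(2))
  if (total == 0) 0.0 else term(c0) + term(c1)
}

// gini(1000, 0) == 0.0 and entropy(1000, 0) == 0.0 for the 1000 label-0 points above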

test("stump with fixed label 1 for Gini"){
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy("regression",Gini,3,100,"sort")
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
assert(splits.length==2)
assert(splits(0).length==99)
assert(bins.length==2)
assert(bins(0).length==100)
assert(splits(0).length==99)
assert(bins(0).length==100)
println(splits(1)(98))
val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
assert(bestSplits.length == 1)
println(bestSplits(0))
}


test("stump with fixed label 0 for Entropy"){
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy("regression",Entropy,3,100,"sort")
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
assert(splits.length==2)
assert(splits(0).length==99)
assert(bins.length==2)
assert(bins(0).length==100)
assert(splits(0).length==99)
assert(bins(0).length==100)
println(splits(1)(98))
val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
assert(bestSplits.length == 1)
println(bestSplits(0))
}

test("stump with fixed label 1 for Entropy"){
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy("regression",Entropy,3,100,"sort")
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
assert(splits.length==2)
assert(splits(0).length==99)
assert(bins.length==2)
assert(bins(0).length==100)
assert(splits(0).length==99)
assert(bins(0).length==100)
println(splits(1)(98))
val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
assert(bestSplits.length == 1)
println(bestSplits(0))
}


}

object DecisionTreeSuite {

def generateReverseOrderedLabeledPoints() : Array[LabeledPoint] = {
def generateOrderedLabeledPointsWithLabel0() : Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
for (i <- 0 until 1000){
val lp = new LabeledPoint(0.0,Array(i.toDouble,1000.0-i))
arr(i) = lp
}
arr
}


def generateOrderedLabeledPointsWithLabel1() : Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
for (i <- 0 until 1000){
val lp = new LabeledPoint(1.0,Array(i.toDouble,1000.0-i))
val lp = new LabeledPoint(1.0,Array(i.toDouble,999.0-i))
arr(i) = lp
}
arr
}
}
