Skip to content

Commit

Permalink
decision stump functionality working
Browse files Browse the repository at this point in the history
Signed-off-by: Manish Amde <manish9ue@gmail.com>
  • Loading branch information
manishamde committed Feb 28, 2014
1 parent 03f534c commit dad0afc
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 43 deletions.
123 changes: 88 additions & 35 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ package org.apache.spark.mllib.tree
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.tree.model._
import org.apache.spark.Logging
import org.apache.spark.{SparkContext, Logging}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.Split
import org.apache.spark.mllib.tree.impurity.Gini


class DecisionTree(val strategy : Strategy) {
Expand All @@ -46,8 +47,13 @@ class DecisionTree(val strategy : Strategy) {
//Find best split for all nodes at a level
val numNodes= scala.math.pow(2,level).toInt
//TODO: Change the input parent impurities values
val bestSplits = DecisionTree.findBestSplits(input, Array(0.0), strategy, level, filters,splits,bins)
val splits_stats_for_level = DecisionTree.findBestSplits(input, Array(2.0), strategy, level, filters,splits,bins)
for (tmp <- splits_stats_for_level){
println("final best split = " + tmp._1)
}
//TODO: update filters and decision tree model
require(scala.math.pow(2,level)==splits_stats_for_level.length)

}

return new DecisionTreeModel()
Expand Down Expand Up @@ -77,7 +83,7 @@ object DecisionTree extends Serializable {
level: Int,
filters : Array[List[Filter]],
splits : Array[Array[Split]],
bins : Array[Array[Bin]]) : Array[Split] = {
bins : Array[Array[Bin]]) : Array[(Split, Double, Long, Long)] = {

//Common calculations for multiple nested methods
val numNodes = scala.math.pow(2, level).toInt
Expand All @@ -94,8 +100,9 @@ object DecisionTree extends Serializable {
List[Filter]()
} else {
val nodeFilterIndex = scala.math.pow(2, level).toInt + nodeIndex
val parentFilterIndex = nodeFilterIndex / 2
filters(parentFilterIndex)
//val parentFilterIndex = nodeFilterIndex / 2
//TODO: Check left or right filter
filters(nodeFilterIndex)
}
}

Expand Down Expand Up @@ -230,30 +237,34 @@ object DecisionTree extends Serializable {
//binAggregates.foreach(x => println(x))


def calculateGainForSplit(leftNodeAgg: Array[Array[Double]], featureIndex: Int, index: Int, rightNodeAgg: Array[Array[Double]], topImpurity: Double): Double = {
def calculateGainForSplit(leftNodeAgg: Array[Array[Double]],
featureIndex: Int,
index: Int,
rightNodeAgg: Array[Array[Double]],
topImpurity: Double) : (Double, Long, Long) = {

val left0Count = leftNodeAgg(featureIndex)(2 * index)
val left1Count = leftNodeAgg(featureIndex)(2 * index + 1)
val leftCount = left0Count + left1Count

if (leftCount == 0) return 0

//println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount)
val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)

val right0Count = rightNodeAgg(featureIndex)(2 * index)
val right1Count = rightNodeAgg(featureIndex)(2 * index + 1)
val rightCount = right0Count + right1Count

if (rightCount == 0) return 0
if (leftCount == 0) return (0, leftCount.toLong, rightCount.toLong)

//println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount)
val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)

if (rightCount == 0) return (0, leftCount.toLong, rightCount.toLong)

//println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount)
val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)

val leftWeight = leftCount.toDouble / (leftCount + rightCount)
val rightWeight = rightCount.toDouble / (leftCount + rightCount)

topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity
(topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity, leftCount.toLong, rightCount.toLong)

}

Expand Down Expand Up @@ -295,9 +306,10 @@ object DecisionTree extends Serializable {
(leftNodeAgg, rightNodeAgg)
}

def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double): Array[Array[Double]] = {
def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double)
: Array[Array[(Double,Long,Long)]] = {

val gains = Array.ofDim[Double](numFeatures, numSplits - 1)
val gains = Array.ofDim[(Double,Long,Long)](numFeatures, numSplits - 1)

for (featureIndex <- 0 until numFeatures) {
for (index <- 0 until numSplits -1) {
Expand All @@ -313,40 +325,44 @@ object DecisionTree extends Serializable {
@param binData Array[Double] of size 2*numSplits*numFeatures
*/
def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : Split = {
def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : (Split, Double, Long, Long) = {
println("node impurity = " + nodeImpurity)
val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)

//println("gains.size = " + gains.size)
//println("gains(0).size = " + gains(0).size)

val (bestFeatureIndex,bestSplitIndex) = {
val (bestFeatureIndex,bestSplitIndex, gain, leftCount, rightCount) = {
var bestFeatureIndex = 0
var bestSplitIndex = 0
var maxGain = Double.MinValue
var leftSamples = Long.MinValue
var rightSamples = Long.MinValue
for (featureIndex <- 0 until numFeatures) {
for (splitIndex <- 0 until numSplits - 1){
val gain = gains(featureIndex)(splitIndex)
//println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain)
if(gain > maxGain) {
maxGain = gain
if(gain._1 > maxGain) {
maxGain = gain._1
leftSamples = gain._2
rightSamples = gain._3
bestFeatureIndex = featureIndex
bestSplitIndex = splitIndex
println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex + ", maxGain = " + maxGain)
println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex
+ ", maxGain = " + maxGain + ", leftSamples = " + leftSamples + ",rightSamples = " + rightSamples)
}
}
}
(bestFeatureIndex,bestSplitIndex)
(bestFeatureIndex,bestSplitIndex,maxGain,leftSamples,rightSamples)
}

splits(bestFeatureIndex)(bestSplitIndex)

//TODo: Return array of node stats with split and impurity information
(splits(bestFeatureIndex)(bestSplitIndex),gain,leftCount,rightCount)
//TODO: Return array of node stats with split and impurity information
}

//Calculate best splits for all nodes at a given level
val bestSplits = new Array[Split](numNodes)
val bestSplits = new Array[(Split, Double, Long, Long)](numNodes)
for (node <- 0 until numNodes){
val shift = 2*node*numSplits*numFeatures
val binsForNode = binAggregates.slice(shift,shift+2*numSplits*numFeatures)
Expand Down Expand Up @@ -381,9 +397,6 @@ object DecisionTree extends Serializable {
val sampledInput = input.sample(false, fraction, 42).collect()
val numSamples = sampledInput.length

//TODO: Remove this requirement
require(numSamples > numSplits, "length of input samples should be greater than numSplits")

//Find the number of features by looking at the first sample
val numFeatures = input.take(1)(0).features.length

Expand All @@ -395,14 +408,22 @@ object DecisionTree extends Serializable {
//Find all splits
for (featureIndex <- 0 until numFeatures){
val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
val stride : Double = numSamples.toDouble/numSplits

println("stride = " + stride)

for (index <- 0 until numSplits-1) {
val sampleIndex = (index+1)*stride.toInt
val split = new Split(featureIndex,featureSamples(sampleIndex),"continuous")
splits(featureIndex)(index) = split
if (numSamples < numSplits) {
//TODO: Test this
println("numSamples = " + numSamples + ", less than numSplits = " + numSplits)
for (index <- 0 until numSplits-1) {
val split = new Split(featureIndex,featureSamples(index),"continuous")
splits(featureIndex)(index) = split
}
} else {
val stride : Double = numSamples.toDouble/numSplits
println("stride = " + stride)
for (index <- 0 until numSplits-1) {
val sampleIndex = (index+1)*stride.toInt
val split = new Split(featureIndex,featureSamples(sampleIndex),"continuous")
splits(featureIndex)(index) = split
}
}
}

Expand Down Expand Up @@ -430,4 +451,36 @@ object DecisionTree extends Serializable {
}
}

/**
 * Command-line entry point for the current decision-stump experiment.
 *
 * Expected arguments:
 *   args(0) - Spark master URL
 *   args(1) - path to the comma-separated labeled input data (see [[loadLabeledData]])
 */
def main(args: Array[String]) {
  // Fail fast with a usage hint instead of an ArrayIndexOutOfBoundsException.
  require(args.length >= 2, "Usage: DecisionTree <master> <input path>")

  val sc = new SparkContext(args(0), "DecisionTree")
  try {
    val data = loadLabeledData(sc, args(1))

    // Hard-coded experimental configuration: binary classification with Gini
    // impurity, a depth-2 tree and 569 candidate splits per feature.
    // TODO confirm these should eventually come from the command line.
    val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = 2, numSplits = 569)

    // The returned model is not used further yet; training is run for its
    // console output while the stump functionality is being developed.
    new DecisionTree(strategy).train(data)
  } finally {
    // Always release the SparkContext, even if training throws.
    sc.stop()
  }
}

/**
 * Load labeled data from a file. The data format used here is
 * <L>,<f1>,<f2>,...
 * i.e. one comma-separated record per line, where <f1>, <f2>, ... are feature
 * values in Double and <L> is the corresponding label as Double.
 * (The whole line is split on "," — the label and the features share the same
 * separator.)
 *
 * @param sc SparkContext
 * @param dir Directory to the input data files.
 * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is
 *         the label, and the second element represents the feature values (an array of Double).
 */
def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = {
  sc.textFile(dir).map { line =>
    val parts = line.trim().split(",")
    // First field is the label; everything after it is the feature vector.
    val label = parts(0).toDouble
    val features = parts.tail.map(_.toDouble)
    LabeledPoint(label, features)
  }
}



}
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,13 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins(0).length==100)
assert(splits(0).length==99)
assert(bins(0).length==100)
println(splits(1)(98))
val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
assert(bestSplits.length == 1)
println(bestSplits(0))
assert(0==bestSplits(0)._1.feature)
assert(10==bestSplits(0)._1.threshold)
assert(0==bestSplits(0)._2)
assert(10==bestSplits(0)._3)
assert(990==bestSplits(0)._4)
}

test("stump with fixed label 1 for Gini"){
Expand All @@ -86,10 +89,13 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins(0).length==100)
assert(splits(0).length==99)
assert(bins(0).length==100)
println(splits(1)(98))
val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
assert(bestSplits.length == 1)
println(bestSplits(0))
assert(0==bestSplits(0)._1.feature)
assert(10==bestSplits(0)._1.threshold)
assert(0==bestSplits(0)._2)
assert(10==bestSplits(0)._3)
assert(990==bestSplits(0)._4)
}


Expand All @@ -105,10 +111,13 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins(0).length==100)
assert(splits(0).length==99)
assert(bins(0).length==100)
println(splits(1)(98))
val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
assert(bestSplits.length == 1)
println(bestSplits(0))
assert(0==bestSplits(0)._1.feature)
assert(10==bestSplits(0)._1.threshold)
assert(0==bestSplits(0)._2)
assert(10==bestSplits(0)._3)
assert(990==bestSplits(0)._4)
}

test("stump with fixed label 1 for Entropy"){
Expand All @@ -123,10 +132,13 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins(0).length==100)
assert(splits(0).length==99)
assert(bins(0).length==100)
println(splits(1)(98))
val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
assert(bestSplits.length == 1)
println(bestSplits(0))
assert(0==bestSplits(0)._1.feature)
assert(10==bestSplits(0)._1.threshold)
assert(0==bestSplits(0)._2)
assert(10==bestSplits(0)._3)
assert(990==bestSplits(0)._4)
}


Expand Down

0 comments on commit dad0afc

Please sign in to comment.