Skip to content

Commit

Permalink
working version of multi-level split calculation
Browse files Browse the repository at this point in the history
Signed-off-by: Manish Amde <manish9ue@gmail.com>
  • Loading branch information
manishamde committed Feb 28, 2014
1 parent 4798aae commit 80e8c66
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 28 deletions.
75 changes: 53 additions & 22 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.mllib.tree.model.Split
import org.apache.spark.mllib.tree.impurity.Gini


class DecisionTree(val strategy : Strategy) {
class DecisionTree(val strategy : Strategy) extends Logging {

def train(input : RDD[LabeledPoint]) : DecisionTreeModel = {

Expand All @@ -42,20 +42,43 @@ class DecisionTree(val strategy : Strategy) {

val maxNumNodes = scala.math.pow(2,maxDepth).toInt - 1
val filters = new Array[List[Filter]](maxNumNodes)
filters(0) = List()
val parentImpurities = new Array[Double](maxNumNodes)
//Dummy value for top node (calculate from scratch during first split calculation)
parentImpurities(0) = Double.MinValue

for (level <- 0 until maxDepth){

println("#####################################")
println("level = " + level)
println("#####################################")

//Find best split for all nodes at a level
val numNodes= scala.math.pow(2,level).toInt
//TODO: Change the input parent impurities values
val splits_stats_for_level = DecisionTree.findBestSplits(input, Array(2.0), strategy, level, filters,splits,bins)
for (tmp <- splits_stats_for_level){
println("final best split = " + tmp._1)
val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, level, filters,splits,bins)
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex){
for (i <- 0 to 1){
val nodeIndex = (scala.math.pow(2,level+1)).toInt - 1 + 2*index + i
if(level < maxDepth - 1){
val impurity = if (i == 0) nodeSplitStats._2.leftImpurity else nodeSplitStats._2.rightImpurity
println("nodeIndex = " + nodeIndex + ", impurity = " + impurity)
parentImpurities(nodeIndex) = impurity
println("updating nodeIndex = " + nodeIndex)
filters(nodeIndex) = new Filter(nodeSplitStats._1, if(i == 0) - 1 else 1) :: filters((nodeIndex-1)/2)
for (filter <- filters(nodeIndex)){
println(filter)
}
}
}
println("final best split = " + nodeSplitStats._1)
}
//TODO: update filters and decision tree model
require(scala.math.pow(2,level)==splits_stats_for_level.length)
require(scala.math.pow(2,level)==splitsStatsForLevel.length)


}

//TODO: Extract decision tree model

return new DecisionTreeModel()
}

Expand Down Expand Up @@ -99,7 +122,7 @@ object DecisionTree extends Serializable {
if (level == 0) {
List[Filter]()
} else {
val nodeFilterIndex = scala.math.pow(2, level).toInt + nodeIndex
val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex
//val parentFilterIndex = nodeFilterIndex / 2
//TODO: Check left or right filter
filters(nodeFilterIndex)
Expand Down Expand Up @@ -155,11 +178,11 @@ object DecisionTree extends Serializable {
// calculating bin index and label per feature per node
val arr = new Array[Double](1+(numFeatures * numNodes))
arr(0) = labeledPoint.label
for (nodeIndex <- 0 until numNodes) {
val parentFilters = findParentFilters(nodeIndex)
for (index <- 0 until numNodes) {
val parentFilters = findParentFilters(index)
//Find out whether the sample qualifies for the particular node
val sampleValid = isSampleValid(parentFilters, labeledPoint)
val shift = 1 + numFeatures * nodeIndex
val shift = 1 + numFeatures * index
if (!sampleValid) {
//Add to invalid bin index -1
for (featureIndex <- 0 until numFeatures) {
Expand Down Expand Up @@ -251,22 +274,26 @@ object DecisionTree extends Serializable {
val right1Count = rightNodeAgg(featureIndex)(2 * index + 1)
val rightCount = right0Count + right1Count

val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count)

if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,0,topImpurity,rightCount.toLong)
if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,leftCount.toLong,Double.MinValue,0)

//println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount)
val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)


//println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount)
val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)

val leftWeight = leftCount.toDouble / (leftCount + rightCount)
val rightWeight = rightCount.toDouble / (leftCount + rightCount)

val gain = topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity
val gain = {
if (level > 0) {
impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
} else {
impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
}
}

new InformationGainStats(gain,topImpurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong)
new InformationGainStats(gain,impurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong)

}

Expand Down Expand Up @@ -339,7 +366,7 @@ object DecisionTree extends Serializable {
var bestFeatureIndex = 0
var bestSplitIndex = 0
//Initialization with infeasible values
var bestGainStats = new InformationGainStats(-1.0,-1.0,-1.0,0,-1.0,0)
var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,0,-1.0,0)
// var maxGain = Double.MinValue
// var leftSamples = Long.MinValue
// var rightSamples = Long.MinValue
Expand All @@ -351,8 +378,8 @@ object DecisionTree extends Serializable {
bestGainStats = gainStats
bestFeatureIndex = featureIndex
bestSplitIndex = splitIndex
println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex
+ ", gain stats = " + bestGainStats)
//println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex)
//println( "gain stats = " + bestGainStats)
}
}
}
Expand All @@ -365,9 +392,12 @@ object DecisionTree extends Serializable {
//Calculate best splits for all nodes at a given level
val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
for (node <- 0 until numNodes){
val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node
val shift = 2*node*numSplits*numFeatures
val binsForNode = binAggregates.slice(shift,shift+2*numSplits*numFeatures)
val parentNodeImpurity = parentImpurities(node/2)
println("nodeImpurityIndex = " + nodeImpurityIndex)
val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
println("node impurity = " + parentNodeImpurity)
bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
}

Expand Down Expand Up @@ -456,8 +486,9 @@ object DecisionTree extends Serializable {

val sc = new SparkContext(args(0), "DecisionTree")
val data = loadLabeledData(sc, args(1))
val maxDepth = args(2).toInt

val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = 2, numSplits = 569)
val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = maxDepth, numSplits = 569)
val model = new DecisionTree(strategy).train(data)

sc.stop()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@ package org.apache.spark.mllib.tree.impurity

object Gini extends Impurity {

def calculate(c0 : Double, c1 : Double): Double = {
val total = c0 + c1
val f0 = c0 / total
val f1 = c1 / total
1 - f0*f0 - f1*f1
}
def calculate(c0 : Double, c1 : Double): Double = {
if (c0 == 0 || c1 == 0) {
0
} else {
val total = c0 + c1
val f0 = c0 / total
val f1 = c1 / total
1 - f0*f0 - f1*f1
}
}

}

0 comments on commit 80e8c66

Please sign in to comment.