
Commit

additional code for creating intermediate RDD
Signed-off-by: Manish Amde <manish9ue@gmail.com>
manishamde committed Feb 28, 2014
1 parent 92cedce commit 8bca1e2
Showing 4 changed files with 120 additions and 26 deletions.
124 changes: 100 additions & 24 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -37,7 +37,6 @@ class DecisionTree(val strategy : Strategy) {
val (splits, bins) = DecisionTree.find_splits_bins(input, strategy)

//TODO: Level-wise training of tree and obtain Decision Tree model

val maxDepth = strategy.maxDepth

val maxNumNodes = scala.math.pow(2,maxDepth).toInt - 1
@@ -55,8 +54,20 @@

}

- object DecisionTree extends Logging {
+ object DecisionTree extends Serializable {

/*
Returns an Array[Split] of optimal splits for all nodes at a given level
@param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree
@param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for constructing the DecisionTree
@param level Level of the tree
@param filters Filters for all nodes at a given level
@param splits possible splits for all features
@param bins possible bins for all features
@return Array[Split] containing the best split for each node at the given level.
*/
def findBestSplits(
input : RDD[LabeledPoint],
strategy: Strategy,
@@ -65,6 +76,16 @@ object DecisionTree extends Logging {
splits : Array[Array[Split]],
bins : Array[Array[Bin]]) : Array[Split] = {

//TODO: Move these calculations outside
val numNodes = scala.math.pow(2, level).toInt
println("numNodes = " + numNodes)
//Find the number of features by looking at the first sample
val numFeatures = input.take(1)(0).features.length
println("numFeatures = " + numFeatures)
val numSplits = strategy.numSplits
println("numSplits = " + numSplits)

/* Find the filters used before reaching the current node */
def findParentFilters(nodeIndex: Int): List[Filter] = {
if (level == 0) {
List[Filter]()
@@ -75,6 +96,10 @@
}
}

/* Find whether the sample is a valid input for the current node,
i.e., whether it passes through all the filters for the current node.
*/
def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = {

for (filter <- parentFilters) {
@@ -91,79 +116,130 @@ object DecisionTree extends Logging {
true
}
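The individual filter comparisons are collapsed by the hunk boundary above. As a rough sketch, each check presumably compares the sample's feature value against the filter's split threshold in the direction the filter encodes; the Split and Filter shapes below are assumptions for illustration, not taken from this patch:

    // Hypothetical stand-ins for the model classes; field shapes are assumed.
    case class Split(feature: Int, threshold: Double)
    case class Filter(split: Split, comparison: Int) // -1 = left child, +1 = right child

    // A sample passes a filter when its feature value lies on the filter's side.
    def passesFilter(filter: Filter, features: Array[Double]): Boolean = {
      val value = features(filter.split.feature)
      if (filter.comparison == -1) value <= filter.split.threshold
      else value > filter.split.threshold
    }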

/*Finds the right bin for the given feature*/
def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = {

//TODO: Do binary search
for (binIndex <- 0 until strategy.numSplits) {
val bin = bins(featureIndex)(binIndex)
//TODO: Remove this requirement post basic functional testing
require(bin.lowSplit.feature == featureIndex)
require(bin.highSplit.feature == featureIndex)
//TODO: Remove this requirement post basic functional testing
val lowThreshold = bin.lowSplit.threshold
val highThreshold = bin.highSplit.threshold
val features = labeledPoint.features
- if ((lowThreshold < features(featureIndex)) & (highThreshold < features(featureIndex))) {
+ if ((lowThreshold <= features(featureIndex)) & (highThreshold > features(featureIndex))) {
return binIndex
}
}
throw new UnknownError("no bin was found.")

}
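The TODO above asks for binary search in place of the linear scan. One possible sketch over ascending bin boundaries, offered as an illustration rather than this commit's code:

    // Binary search for the bin whose [low, high) range contains value,
    // given numBins + 1 ascending boundaries. Returns -1 if no bin matches.
    def searchBin(boundaries: Array[Double], value: Double): Int = {
      var lo = 0
      var hi = boundaries.length - 2
      while (lo <= hi) {
        val mid = (lo + hi) / 2
        if (value < boundaries(mid)) hi = mid - 1
        else if (value >= boundaries(mid + 1)) lo = mid + 1
        else return mid
      }
      -1
    }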
- def findBinsForLevel: Array[Double] = {
- val numNodes = scala.math.pow(2, level).toInt
- //Find the number of features by looking at the first sample
- val numFeatures = input.take(1)(0).features.length
/* Finds bins for all nodes (and all features) at a given level.
For k features and l nodes the storage layout is:
label, b11, b12, .., b1k, b21, b22, .., b2k, .., bl1, bl2, .., blk
An invalid sample for a node is denoted by marking the bin for feature 1 as -1.
*/
+ def findBinsForLevel(labeledPoint : LabeledPoint) : Array[Double] = {


//TODO: Bit pack more by removing redundant label storage
// calculating bin index and label per feature per node
- val arr = new Array[Double](2 * numFeatures * numNodes)
+ val arr = new Array[Double](1 + (numFeatures * numNodes))
arr(0) = labeledPoint.label
for (nodeIndex <- 0 until numNodes) {
val parentFilters = findParentFilters(nodeIndex)
//Find out whether the sample qualifies for the particular node
val sampleValid = isSampleValid(parentFilters, labeledPoint)
- val shift = 2 * numFeatures * nodeIndex
- if (sampleValid) {
+ val shift = 1 + numFeatures * nodeIndex
+ if (!sampleValid) {
//Mark the invalid sample with bin index -1
- for (featureIndex <- shift until (shift + numFeatures) by 2) {
- arr(featureIndex + 1) = -1
- arr(featureIndex + 2) = labeledPoint.label
+ for (featureIndex <- 0 until numFeatures) {
+ arr(shift + featureIndex) = -1
+ //TODO: Break since marking one bin is sufficient
}
} else {
for (featureIndex <- 0 until numFeatures) {
- arr(shift + (featureIndex * 2) + 1) = findBin(featureIndex, labeledPoint)
- arr(shift + (featureIndex * 2) + 2) = labeledPoint.label
+ //println("shift+featureIndex =" + (shift+featureIndex))
+ arr(shift + featureIndex) = findBin(featureIndex, labeledPoint)
}
}

}
arr
}
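To make the storage layout concrete, a toy illustration (not from the patch) with numNodes = 2 and numFeatures = 3:

    // arr(0)          -> label
    // arr(1)..arr(3)  -> bin indices of features 0..2 at node 0
    // arr(4)..arr(6)  -> bin indices of features 0..2 at node 1
    val arr: Array[Double] = Array(1.0, 4, 7, 2, -1, -1, -1)
    // The sample has label 1.0 and falls into bins 4, 7, 2 at node 0;
    // it fails a filter for node 1, so node 1's bins are marked -1.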

- val binMappedRDD = input.map(labeledPoint => findBinsForLevel)
/*
Performs a sequential aggregation over a partition
@param agg Array[Double] storing aggregate calculation of size numSplits*numFeatures*numNodes for classification
and 3*numSplits*numFeatures*numNodes for regression
@param arr Array[Double] of size 1+(numFeatures*numNodes)
@return Array[Double] storing aggregate calculation of size numSplits*numFeatures*numNodes for classification
and 3*numSplits*numFeatures*numNodes for regression
*/
def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = {
for (node <- 0 until numNodes) {
val validSignalIndex = 1+numFeatures*node
val isSampleValidForNode = if(arr(validSignalIndex) != -1) true else false
if(isSampleValidForNode) {
for (feature <- 0 until numFeatures){
val arrShift = 1 + numFeatures*node
val aggShift = numSplits*numFeatures*node
val arrIndex = arrShift + feature
val aggIndex = aggShift + feature*numSplits + arr(arrIndex).toInt
agg(aggIndex) = agg(aggIndex) + 1
}
}
}
agg
}

def binCombOp(par1 : Array[Double], par2: Array[Double]) : Array[Double] = {
par1
}
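As committed, binCombOp returns par1 unchanged and silently drops the second partition's counts, presumably as a placeholder. A complete combine would sum the two partial aggregates element-wise, along these lines (a sketch, not this patch's code):

    def binCombOpSketch(par1: Array[Double], par2: Array[Double]): Array[Double] = {
      require(par1.length == par2.length, "partial aggregates must be the same size")
      val combined = new Array[Double](par1.length)
      var i = 0
      while (i < par1.length) {
        combined(i) = par1(i) + par2(i) // element-wise sum of bin counts
        i += 1
      }
      combined
    }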

println("input = " + input.count)
+ val binMappedRDD = input.map(x => findBinsForLevel(x))
println("binMappedRDD.count = " + binMappedRDD.count)
//calculate bin aggregates

val binAggregates = binMappedRDD.aggregate(Array.fill[Double](numSplits*numFeatures*numNodes)(0))(binSeqOp,binCombOp)

//find best split
println("binAggregates.length = " + binAggregates.length)


- Array[Split]()
+ val bestSplits = new Array[Split](numNodes)
for (node <- 0 until numNodes){
val binsForNode = binAggregates.slice(numSplits * numFeatures * node, numSplits * numFeatures * (node + 1))
}

bestSplits
}

/*
Returns splits and bins for decision tree calculation.
@param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree
@param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for constructing the DecisionTree
@return a tuple of (splits, bins) where splits is an Array[Array[Split]] of size (numFeatures, numSplits-1)
and bins is an Array[Array[Bin]] of size (numFeatures, numSplits)
*/
def find_splits_bins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = {

val numSplits = strategy.numSplits
- logDebug("numSplits = " + numSplits)
+ println("numSplits = " + numSplits)

//Calculate the number of samples for approximate quantile calculation
//TODO: Justify this calculation
val requiredSamples = numSplits*numSplits
val count = input.count()
val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
- logDebug("fraction of data used for calculating quantiles = " + fraction)
+ println("fraction of data used for calculating quantiles = " + fraction)

//sampled input for RDD calculation
val sampledInput = input.sample(false, fraction, 42).collect()
val numSamples = sampledInput.length

//TODO: Remove this requirement
require(numSamples > numSplits, "length of input samples should be greater than numSplits")

//Find the number of features by looking at the first sample
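The rest of find_splits_bins is collapsed in this view. With the "sort" quantile strategy, the threshold selection it builds toward plausibly amounts to equal-frequency binning over the sorted sample, as in this illustration (assumed, not shown in the hunk):

    // Equal-frequency thresholds for a single feature.
    val featureSamples = Array(3.0, 1.0, 4.0, 1.5, 9.0, 2.6, 5.0, 3.5).sorted
    val numSplits = 4
    val stride = featureSamples.length / numSplits // samples per bin
    val thresholds = (1 until numSplits).map(i => featureSamples(i * stride - 1))
    // thresholds == Vector(1.5, 3.0, 4.0), i.e. numSplits - 1 candidate splits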
mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala

@@ -18,7 +18,7 @@ package org.apache.spark.mllib.tree

import org.apache.spark.mllib.tree.impurity.Impurity

- class Strategy (
+ case class Strategy (
val kind : String,
val impurity : Impurity,
val maxDepth : Int,
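Making Strategy a case class buys serializability (needed once a Strategy is captured in RDD closures) and structural equality for free. Judging from the test below, construction looks like the following; the parameters past maxDepth are inferred from usage, since this hunk truncates the parameter list:

    val strategy = new Strategy("regression", Gini, 3, 100, "sort")
    // kind = "regression", impurity = Gini, maxDepth = 3,
    // followed by what appear to be numSplits = 100 and a "sort" quantile strategy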
mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala

@@ -16,7 +16,7 @@
*/
package org.apache.spark.mllib.tree.impurity

- trait Impurity {
+ trait Impurity extends Serializable {

def calculate(c0 : Double, c1 : Double): Double

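Extending Serializable lets concrete impurity objects, such as the Gini used in the tests, ship inside Spark closures. For context, a minimal two-class Gini implementation of this calculate signature might look as follows (a sketch, not this repository's code):

    object GiniSketch extends Impurity {
      // Gini impurity for binary labels: 1 - p0^2 - p1^2, where p0 and p1
      // are the class proportions among the c0 + c1 counts.
      def calculate(c0: Double, c1: Double): Double = {
        val total = c0 + c1
        if (total == 0) 0.0
        else {
          val p0 = c0 / total
          val p1 = c1 / total
          1.0 - p0 * p0 - p1 * p1
        }
      }
    }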
mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

@@ -28,6 +28,7 @@ import org.jblas._
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.impurity.Gini
import org.apache.spark.mllib.tree.model.Filter

class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {

@@ -54,6 +55,23 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins(0).length==100)
println(splits(1)(98))
}

test("stump"){
val arr = DecisionTreeSuite.generateReverseOrderedLabeledPoints()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy("regression",Gini,3,100,"sort")
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
assert(splits.length==2)
assert(splits(0).length==99)
assert(bins.length==2)
assert(bins(0).length==100)
assert(splits(0).length==99)
assert(bins(0).length==100)
println(splits(1)(98))
DecisionTree.findBestSplits(rdd,strategy,0,Array[List[Filter]](),splits,bins)
}

}

object DecisionTreeSuite {
