Skip to content

Commit

Permalink
removing dummy bin calculation for categorical variables
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed Mar 12, 2014
1 parent 6068356 commit 2116360
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,8 @@ object DecisionTree extends Serializable with Logging {
}
throw new UnknownError("no bin was found for continuous variable.")
} else {

- for (binIndex <- 0 until strategy.numBins) {
+ val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex)
+ for (binIndex <- 0 until numCategoricalBins) {
val bin = bins(featureIndex)(binIndex)
val category = bin.category
val features = labeledPoint.features
Expand Down Expand Up @@ -917,13 +917,6 @@ object DecisionTree extends Serializable with Logging {
bins(featureIndex)(numBins-1)
= new Bin(splits(featureIndex)(numBins-2),new DummyHighSplit(featureIndex,
Continuous), Continuous, Double.MinValue)
- } else {
- val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
- for (i <- maxFeatureValue until numBins){
- bins(featureIndex)(i)
- = new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
- new DummyCategoricalSplit(featureIndex, Categorical), Categorical, Double.MaxValue)
- }
}
}
(splits,bins)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
+ val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 2,
+ 1-> 2))
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
assert(splits.length==2)
assert(bins.length==2)
Expand Down Expand Up @@ -120,7 +121,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins(0)(1).highSplit.categories.contains(1.0))
assert(bins(0)(1).highSplit.categories.contains(0.0))

- assert(bins(0)(2).category == Double.MaxValue)
+ assert(bins(0)(2) == null)

assert(bins(1)(0).category == 0.0)
assert(bins(1)(0).lowSplit.categories.length == 0)
Expand All @@ -134,15 +135,16 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins(1)(1).highSplit.categories.contains(0.0))
assert(bins(1)(1).highSplit.categories.contains(1.0))

- assert(bins(1)(2).category == Double.MaxValue)
+ assert(bins(1)(2) == null)

}

test("split and bin calculations for categorical variables with no sample for one category"){
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+ val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3,
+ 1-> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)

//Checking splits
Expand Down Expand Up @@ -217,7 +219,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins(0)(2).highSplit.categories.contains(0.0))
assert(bins(0)(2).highSplit.categories.contains(2.0))

- assert(bins(0)(3).category == Double.MaxValue)
+ assert(bins(0)(3) == null)

assert(bins(1)(0).category == 0.0)
assert(bins(1)(0).lowSplit.categories.length == 0)
Expand All @@ -240,7 +242,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins(1)(2).highSplit.categories.contains(1.0))
assert(bins(1)(2).highSplit.categories.contains(2.0))

- assert(bins(1)(3).category == Double.MaxValue)
+ assert(bins(1)(3) == null)


}
Expand All @@ -249,10 +251,12 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+ val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3,
+ 1-> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
strategy.numBins = 100
- val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins)
+ val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
+ Array[List[Filter]](), splits, bins)

val split = bestSplits(0)._1
assert(split.categories.length == 1)
Expand All @@ -272,10 +276,12 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Regression,Variance,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+ val strategy = new Strategy(Regression,Variance,3,100,categoricalFeaturesInfo = Map(0 -> 3,
+ 1-> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
strategy.numBins = 100
- val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins)
+ val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
+ Array[List[Filter]](), splits, bins)

val split = bestSplits(0)._1
assert(split.categories.length == 1)
Expand Down Expand Up @@ -305,7 +311,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins(0).length==100)

strategy.numBins = 100
- val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins)
+ val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
+ Array[List[Filter]](), splits, bins)
assert(bestSplits.length == 1)
assert(0==bestSplits(0)._1.feature)
assert(10==bestSplits(0)._1.threshold)
Expand All @@ -329,7 +336,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins(0).length==100)

strategy.numBins = 100
- val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
+ val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
+ Array[List[Filter]](), splits, bins)
assert(bestSplits.length == 1)
assert(0==bestSplits(0)._1.feature)
assert(10==bestSplits(0)._1.threshold)
Expand All @@ -355,7 +363,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins(0).length==100)

strategy.numBins = 100
- val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
+ val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
+ Array[List[Filter]](), splits, bins)
assert(bestSplits.length == 1)
assert(0==bestSplits(0)._1.feature)
assert(10==bestSplits(0)._1.threshold)
Expand All @@ -379,7 +388,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins(0).length==100)

strategy.numBins = 100
- val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
+ val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
+ Array[List[Filter]](), splits, bins)
assert(bestSplits.length == 1)
assert(0==bestSplits(0)._1.feature)
assert(10==bestSplits(0)._1.threshold)
Expand Down

0 comments on commit 2116360

Please sign in to comment.