Skip to content

Commit

Permalink
skeletal framework
Browse files Browse the repository at this point in the history
Signed-off-by: Manish Amde <manish9ue@gmail.com>
  • Loading branch information
manishamde committed Feb 28, 2014
1 parent 12738c1 commit cd53eae
Show file tree
Hide file tree
Showing 12 changed files with 318 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.classification

class ClassificationTree {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.regression

class RegressionTree {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree

import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.{Split, Bin, DecisionTreeModel}


class DecisionTree(val strategy : Strategy) {

def train(input : RDD[LabeledPoint]) : DecisionTreeModel = {

//Cache input RDD for speedup during multiple passes
input.cache()

//TODO: Find all splits and bins using quantiles including support for categorical features, single-pass
val (splits, bins) = DecisionTree.find_splits_bins(input, strategy)

//TODO: Level-wise training of tree and obtain Decision Tree model


return new DecisionTreeModel()
}

}

object DecisionTree {
def find_splits_bins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = {
val numSplits = strategy.numSplits
//TODO: Justify this calculation
val requiredSamples : Long = numSplits*numSplits
val count : Long = input.count()
val numSamples : Long = if (requiredSamples < count) requiredSamples else count
val numFeatures = input.take(1)(0).features.length
(Array.ofDim[Split](numFeatures,numSplits),Array.ofDim[Bin](numFeatures,numSplits))
}

}
15 changes: 15 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
This package contains the default implementation of the decision tree algorithm.

The decision tree algorithm supports:
+ information loss calculation with entropy and gini for classification and variance for regression
+ node model pruning
+ printing to dot files
+ unit tests

#Performance testing

#Future Extensions

+ Random forests
+ Boosting
+ Extremely randomized trees
28 changes: 28 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree

import org.apache.spark.mllib.tree.impurity.Impurity

class Strategy (
val kind : String,
val impurity : Impurity,
val maxDepth : Int,
val numSplits : Int,
val quantileCalculationStrategy : String = "sampleAndSort") {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.impurity

object Entropy extends Impurity {

def log2(x: Double) = scala.math.log(x) / scala.math.log(2)

def calculate(c0: Double, c1: Double): Double = {
if (c0 == 0 || c1 == 0) {
0
} else {
val total = c0 + c1
val f0 = c0 / total
val f1 = c1 / total
-(f0 * log2(f0)) - (f1 * log2(f1))
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.impurity

object Gini extends Impurity {

def calculate(c0 : Double, c1 : Double): Double = {
val total = c0 + c1
val f0 = c0 / total
val f1 = c1 / total
1 - f0*f0 - f1*f1
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.impurity

trait Impurity {

def calculate(c0 : Double, c1 : Double): Double

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.impurity

import javax.naming.OperationNotSupportedException

object Variance extends Impurity {
def calculate(c0: Double, c1: Double): Double = throw new OperationNotSupportedException("Variance.calculate")
}
21 changes: 21 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.model

case class Bin(kind : String, lowSplit : Split, highSplit : Split) {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.model

class DecisionTreeModel {

}
29 changes: 29 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.model

case class Split(
val feature: Int,
val threshold : Double,
val kind : String) {

}

class dummyLowSplit(kind : String) extends Split(Int.MinValue, Double.MinValue, kind)

class dummyHighSplit(kind : String) extends Split(Int.MaxValue, Double.MaxValue, kind)

0 comments on commit cd53eae

Please sign in to comment.