Skip to content

Commit

Permalink
added command line parsing
Browse files Browse the repository at this point in the history
Signed-off-by: Manish Amde <manish9ue@gmail.com>
  • Loading branch information
manishamde committed Feb 28, 2014
1 parent 98ec8d5 commit 02c595c
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 29 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ object DecisionTree extends Serializable with Logging {

/*Combines the aggregates from partitions
@param agg1 Array containing aggregates from one or more partitions
@param agg2 Array contianing aggregates from one or more partitions
@param agg2 Array containing aggregates from one or more partitions
@return Combined aggregate from agg1 and agg2
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,69 @@
package org.apache.spark.mllib.tree

import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.mllib.tree.impurity.Gini
import org.apache.spark.mllib.tree.impurity.{Gini,Entropy,Variance}
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.DecisionTreeModel

object DecisionTreeRunner extends Logging {

val usage = """
Usage: DecisionTreeRunner <master>[slices] --kind <Classification,Regression> --trainDataDir path --testDataDir path [--maxDepth num] [--impurity <Gini,Entropy,Variance>] [--maxBins num]
"""


def main(args: Array[String]) {

if (args.length < 2) {
System.err.println(usage)
System.exit(1)
}

val sc = new SparkContext(args(0), "DecisionTree")
val data = loadLabeledData(sc, args(1))
val maxDepth = args(2).toInt
val maxBins = args(3).toInt

val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = maxDepth, maxBins = maxBins)
val model = new DecisionTree(strategy).train(data)

val accuracy = accuracyScore(model, data)
val arglist = args.toList.drop(1)
type OptionMap = Map[Symbol, Any]

/**
 * Recursively consumes the argument list, accumulating each recognised
 * `--flag value` pair into the option map under the flag's symbol key.
 *
 * @param map  options collected so far
 * @param list arguments that still have to be processed
 * @return the accumulated map once every argument has been consumed; if an
 *         argument is not recognised (including a known flag that is the last
 *         argument and therefore has no value) an error is printed to stderr
 *         and the JVM exits with status 1
 */
@scala.annotation.tailrec
def nextOption(map: OptionMap, list: List[String]): OptionMap = {
  // startsWith is safe on the empty string, unlike indexing s(0).
  def isSwitch(s: String): Boolean = s.startsWith("-")
  list match {
    case Nil => map
    case "--kind" :: value :: tail => nextOption(map.updated('kind, value), tail)
    case "--impurity" :: value :: tail => nextOption(map.updated('impurity, value), tail)
    case "--maxDepth" :: value :: tail => nextOption(map.updated('maxDepth, value), tail)
    case "--maxBins" :: value :: tail => nextOption(map.updated('maxBins, value), tail)
    case "--trainDataDir" :: value :: tail => nextOption(map.updated('trainDataDir, value), tail)
    case "--testDataDir" :: value :: tail => nextOption(map.updated('testDataDir, value), tail)
    // A single trailing positional argument is kept under 'infile; a trailing
    // switch is NOT accepted here so that a flag with a missing value is reported.
    case value :: Nil if !isSwitch(value) => nextOption(map.updated('infile, value), Nil)
    case option :: _ =>
      // Report on stderr, like the usage message, and use sys.exit instead of
      // the deprecated Predef.exit.
      System.err.println("Unknown option "+option)
      sys.exit(1)
  }
}
// Parse the remaining command line arguments into a key -> value map.
val options = nextOption(Map(), arglist)
logDebug(options.toString())

// Looks up a mandatory option and fails with a readable message (instead of a
// bare NoSuchElementException from Option.get) when it was not supplied.
def requiredOption(key: Symbol, flag: String): String =
  options.get(key) match {
    case Some(value) => value.toString
    case None =>
      System.err.println("Missing required option " + flag)
      sys.exit(1)
  }

// Resolve and validate every option up front so that a missing or invalid
// argument is reported before any training work is started.
val trainDataDir = requiredOption('trainDataDir, "--trainDataDir")
val testDataDir = requiredOption('testDataDir, "--testDataDir")

// The parser stores the problem kind under 'kind (not 'type).
val typeStr = requiredOption('kind, "--kind")
//TODO: Create enum
// The usage text spells the kinds capitalised, so pick the default impurity
// case-insensitively.
// NOTE(review): typeStr is handed to Strategy unchanged; confirm which
// spelling of the kind Strategy expects.
val defaultImpurity = if (typeStr.equalsIgnoreCase("classification")) "Gini" else "Variance"
val impurityStr = options.getOrElse('impurity, defaultImpurity).toString
val impurity = impurityStr match {
  case "Gini" => Gini
  case "Entropy" => Entropy
  case "Variance" => Variance
  case other =>
    // Previously an unknown name surfaced as a scala.MatchError.
    System.err.println("Unknown impurity " + other)
    sys.exit(1)
}
val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt
val maxBins = options.getOrElse('maxBins, "100").toString.toInt

val trainData = loadLabeledData(sc, trainDataDir)

// Pass the impurity selected above rather than a hard-coded Gini.
val strategy = new Strategy(kind = typeStr, impurity = impurity, maxDepth = maxDepth, maxBins = maxBins)
val model = new DecisionTree(strategy).train(trainData)

val testData = loadLabeledData(sc, testDataDir)
val accuracy = accuracyScore(model, testData)
logDebug("accuracy = " + accuracy)

sc.stop()
Expand Down

0 comments on commit 02c595c

Please sign in to comment.