Skip to content

Commit

Permalink
fixed model call so that uses type argument
Browse files Browse the repository at this point in the history
  • Loading branch information
leahmcguire committed Mar 5, 2015
1 parent ea09b28 commit 900b586
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -310,21 +310,21 @@ object NaiveBayes {
*
* The model type can be set to either Multinomial NB ([[http://tinyurl.com/lsdw6p]])
* or Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The Multinomial NB can handle
* discrete count data and can be called by setting the model type to "Multinomial".
* discrete count data and can be called by setting the model type to "multinomial".
* For example, it can be used with word counts or TF_IDF vectors of documents.
* The Bernoulli model fits presence or absence (0-1) counts. By making every vector a
* 0-1 vector and setting the model type to "Bernoulli", the fits and predicts as
* 0-1 vector and setting the model type to "bernoulli", the fits and predicts as
* Bernoulli NB.
*
* @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
* vector or a count vector.
* @param lambda The smoothing parameter
*
* @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be
* Multinomial or Bernoulli
* multinomial or bernoulli
*/
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
new NaiveBayes(lambda, Multinomial).run(input)
new NaiveBayes(lambda, MODELTYPE.fromString(modelType)).run(input)
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()

val model = NaiveBayes.train(testRDD, 1.0, "Multinomial")
val model = NaiveBayes.train(testRDD, 1.0, "multinomial")
validateModelFit(pi, theta, model)

val validationData = NaiveBayesSuite.generateNaiveBayesInput(
Expand Down Expand Up @@ -161,7 +161,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()

val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli")
val model = NaiveBayes.train(testRDD, 1.0, "bernoulli")
validateModelFit(pi, theta, model)

val validationData = NaiveBayesSuite.generateNaiveBayesInput(
Expand Down

0 comments on commit 900b586

Please sign in to comment.