[SPARK-6083] Make Python API example consistent in NaiveBayes
MechCoder committed Mar 1, 2015
1 parent e0e64ba commit 65bbbe9
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions docs/mllib-naive-bayes.md
@@ -115,22 +115,31 @@ used for evaluation and prediction.

 Note that the Python API does not yet support model save/load but will in the future.

-<!-- TODO: Make Python's example consistent with Scala's and Java's. -->
 {% highlight python %}
-from pyspark.mllib.regression import LabeledPoint
 from pyspark.mllib.classification import NaiveBayes
+from pyspark.mllib.linalg import Vectors
+from pyspark.mllib.regression import LabeledPoint

+data = sc.textFile("data/mllib/sample_naive_bayes_data.txt")
+
+# Preprocessing
+splitData = data.map(lambda line: line.split(','))
+parsedData = splitData.map(
+    lambda parts: LabeledPoint(
+        float(parts[0]),
+        Vectors.dense(map(lambda x: float(x), parts[1].split(' ')))
+    )
+)
+
-# an RDD of LabeledPoint
-data = sc.parallelize([
-  LabeledPoint(0.0, [0.0, 0.0])
-  ... # more labeled points
-])
+# Split data into training (60%) and test (40%)
+training, test = parsedData.randomSplit([0.6, 0.4], seed = 0)

 # Train a naive Bayes model.
-model = NaiveBayes.train(data, 1.0)
+model = NaiveBayes.train(training, 1.0)

-# Make prediction.
-prediction = model.predict([0.0, 0.0])
+# Make prediction and test accuracy.
+predictionAndLabel = test.map(lambda p : (model.predict(p.features), p.label))
+accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count()
 {% endhighlight %}

 </div>
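
The example added in this commit uses Python 2 idioms: `lambda (x, v): ...` relies on tuple parameter unpacking, which Python 3 removed, and `map(...)` returns an iterator rather than a list in Python 3. The following is a minimal sketch of the same parsing and accuracy steps in plain Python 3, with no Spark dependency. Lists stand in for RDDs and tuples stand in for `LabeledPoint`. The helper names `parse_line` and `accuracy`, the sample lines, and the predictions are hypothetical, chosen only to illustrate the logic. This is not the pyspark API.

```python
# Plain-Python 3 stand-in for the parsing and accuracy steps in the diff above.
# Assumption: each input line has the form "label,f1 f2 f3", as implied by the
# split(',') and split(' ') calls in the committed example.


def parse_line(line):
    """Parse 'label,f1 f2 f3' into (label, [f1, f2, f3])."""
    label, features = line.split(',')
    # A list comprehension yields a concrete list, unlike map() in Python 3.
    return float(label), [float(x) for x in features.split(' ')]


def accuracy(prediction_and_label):
    """Return the fraction of (prediction, label) pairs that agree."""
    pairs = list(prediction_and_label)
    # Index into the pair instead of unpacking it in the lambda signature.
    correct = sum(1 for pair in pairs if pair[0] == pair[1])
    return correct / len(pairs)


# Hypothetical input lines, not taken from sample_naive_bayes_data.txt.
lines = ["0,1 0 0", "1,0 2 0", "2,0 0 3"]
parsed = [parse_line(line) for line in lines]

# Hypothetical predictions, used only to exercise accuracy().
predictions = [0.0, 1.0, 0.0]
pairs = [(pred, label) for pred, (label, _) in zip(predictions, parsed)]
result = accuracy(pairs)
```

Under the same assumption, the Spark version of the filter could be written as `lambda pair: pair[0] == pair[1]`, which parses on both Python 2 and Python 3.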