# Naive Bayes with PySpark

This notebook trains and evaluates a Naive Bayes classifier with PySpark.

## Imports

In [None]:
# Set SPARK_HOME
# environ["SPARK_HOME"] = "/home/students/spark-2.2.0"

import os

import findspark
findspark.init()

from pyspark import SparkConf
from pyspark import SparkContext
from pyspark.sql import SQLContext

from pyspark.ml.classification import NaiveBayes
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

## Get Some Context

In [None]:
# Create (or reuse) the SparkContext and wrap it in a SQLContext.
# getOrCreate() hands back the already-running context when this cell is
# executed again, instead of failing because a second SparkContext cannot be
# started while the first one is still active. On a fresh kernel it starts a
# new context with the application name configured below.
sc = SparkContext.getOrCreate(
    conf=SparkConf().setAppName("Naive Bayes Classification with Spark")
)
sqlContext = SQLContext(sc)

## Load and Prepare the Data

In [None]:
# Location of the LIBSVM-formatted sample data set.
# Set the DATA_FILE environment variable to point at your own copy; when it is
# not set, the notebook falls back to the original hard-coded path, which is
# specific to one machine and will not exist elsewhere.
DATA_FILE = os.environ.get(
    "DATA_FILE",
    "/Users/robert.dempsey/Dev/daamlobd/data/mllib/sample_libsvm_data.txt",
)

In [None]:
# Read the LIBSVM file at DATA_FILE into a DataFrame, then preview five rows
data = sqlContext.read.load(DATA_FILE, format="libsvm")
data.show(n=5)

In [None]:
# View a single row: fetch only the first record of `data`; as the last
# expression in the cell, the result is shown as the cell's output.
data.take(1)

## Fit a Naive Bayes Model

In [None]:
# Randomly partition the rows 60% / 40% into training and test sets.
# The fixed seed (1234) is passed so the split can be repeated across runs.
splits = data.randomSplit(weights=[0.6, 0.4], seed=1234)
train, test = splits

In [None]:
# Configure the (not yet trained) estimator: a multinomial Naive Bayes
# classifier with the smoothing parameter set to 1.0
nb = NaiveBayes(
    modelType="multinomial",
    smoothing=1.0,
)

In [None]:
# Train the model: fit the configured NaiveBayes estimator on the training
# split only, keeping the test split unseen for the evaluation further down.
nb_model = nb.fit(train)

In [None]:
# Inspect the fitted model's `pi` attribute (per the PySpark NaiveBayesModel
# documentation this is the log of the class priors).
nb_model.pi

## Create Predictions

In [None]:
# Run the fitted model over the held-out test rows and preview five results
predictions = nb_model.transform(test)
predictions.show(n=5)

## Model Evaluation

### MulticlassClassificationEvaluator

The MulticlassClassificationEvaluator expects two input columns: prediction and label.

Available metrics:
* f1: a measure of a test's accuracy considering both [precision and recall](https://en.wikipedia.org/wiki/Precision_and_recall). Best value is 1.
  * precision: the fraction of retrieved documents that are relevant to the query
  * recall: the fraction of the relevant documents that are successfully retrieved
* weightedPrecision
* weightedRecall
* accuracy: the fraction of predictions that are correct. Best value is 1.

In [None]:
# Score the test-set predictions on each metric listed below (not only
# accuracy), building one MulticlassClassificationEvaluator per metric name,
# and collect the results keyed by that name.
metrics = ['f1', 'weightedPrecision', 'weightedRecall', 'accuracy']

measurements = {
    metric: MulticlassClassificationEvaluator(
        labelCol="label",
        predictionCol="prediction",
        metricName=metric,
    ).evaluate(predictions)
    for metric in metrics
}

# Report every collected score as "<metric name>: <value>"
for name, score in measurements.items():
    print("{}: {}".format(name, score))

## Shut it Down

In [None]:
# Stop the SparkContext created near the top of the notebook now that the
# analysis is finished.
sc.stop()